linter fix

This commit is contained in:
erogol 2020-09-24 12:57:54 +02:00
parent 82376298d9
commit 665f7ca714
5 changed files with 53 additions and 107 deletions

View File

@ -69,7 +69,7 @@ class WN(torch.nn.Module):
num_layers,
c_in_channels=0,
dropout_p=0):
super(WN, self).__init__()
super().__init__()
assert kernel_size % 2 == 1
assert hidden_channels % 2 == 0
self.in_channels = in_channels
@ -148,70 +148,6 @@ class WN(torch.nn.Module):
for l in self.res_skip_layers:
torch.nn.utils.remove_weight_norm(l)
class ActNorm(nn.Module):
"""Activation Normalization bijector as an alternative to Batch Norm. It computes
mean and std from a sample data in advance and it uses these values
for normalization at training.
Args:
channels (int): input channels.
ddi (False): data depended initialization flag.
Shapes:
- inputs: (B, C, T)
- outputs: (B, C, T)
"""
def __init__(self, channels, ddi=False, **kwargs): # pylint: disable=unused-argument
super().__init__()
self.channels = channels
self.initialized = not ddi
self.logs = nn.Parameter(torch.zeros(1, channels, 1))
self.bias = nn.Parameter(torch.zeros(1, channels, 1))
def forward(self, x, x_mask=None, reverse=False, **kwargs): # pylint: disable=unused-argument
if x_mask is None:
x_mask = torch.ones(x.size(0), 1, x.size(2)).to(device=x.device,
dtype=x.dtype)
x_len = torch.sum(x_mask, [1, 2])
if not self.initialized:
self.initialize(x, x_mask)
self.initialized = True
if reverse:
z = (x - self.bias) * torch.exp(-self.logs) * x_mask
logdet = None
else:
z = (self.bias + torch.exp(self.logs) * x) * x_mask
logdet = torch.sum(self.logs) * x_len # [b]
return z, logdet
def store_inverse(self):
pass
def set_ddi(self, ddi):
self.initialized = not ddi
def initialize(self, x, x_mask):
with torch.no_grad():
denom = torch.sum(x_mask, [0, 2])
m = torch.sum(x * x_mask, [0, 2]) / denom
m_sq = torch.sum(x * x * x_mask, [0, 2]) / denom
v = m_sq - (m**2)
logs = 0.5 * torch.log(torch.clamp_min(v, 1e-6))
bias_init = (-m * torch.exp(-logs)).view(*self.bias.shape).to(
dtype=self.bias.dtype)
logs_init = (-logs).view(*self.logs.shape).to(
dtype=self.logs.dtype)
self.bias.data.copy_(bias_init)
self.logs.data.copy_(logs_init)
class InvConvNear(nn.Module):
def __init__(self, channels, num_splits=4, no_jacobian=False, **kwargs): # pylint: disable=unused-argument
super().__init__()

View File

@ -36,11 +36,10 @@ class TemporalBatchNorm1d(nn.BatchNorm1d):
affine=True,
track_running_stats=True,
momentum=0.1):
super(TemporalBatchNorm1d,
self).__init__(channels,
affine=affine,
track_running_stats=track_running_stats,
momentum=momentum)
super().__init__(channels,
affine=affine,
track_running_stats=track_running_stats,
momentum=momentum)
def forward(self, x):
return super().forward(x.transpose(2, 1)).transpose(2, 1)

View File

@ -11,7 +11,7 @@ class TimeDepthSeparableConv(nn.Module):
out_channels,
kernel_size,
bias=True):
super(TimeDepthSeparableConv, self).__init__()
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
@ -69,7 +69,7 @@ class TimeDepthSeparableConvBlock(nn.Module):
num_layers,
kernel_size,
bias=True):
super(TimeDepthSeparableConvBlock, self).__init__()
super().__init__()
assert (kernel_size - 1) % 2 == 0
assert num_layers > 1

View File

@ -7,9 +7,8 @@ from TTS.tts.utils.generic_utils import sequence_mask
class L1LossMasked(nn.Module):
def __init__(self, seq_len_norm):
super(L1LossMasked, self).__init__()
super().__init__()
self.seq_len_norm = seq_len_norm
def forward(self, x, target, length):
@ -28,25 +27,24 @@ class L1LossMasked(nn.Module):
"""
# mask: (batch, max_len, 1)
target.requires_grad = False
mask = sequence_mask(
sequence_length=length, max_len=target.size(1)).unsqueeze(2).float()
mask = sequence_mask(sequence_length=length,
max_len=target.size(1)).unsqueeze(2).float()
if self.seq_len_norm:
norm_w = mask / mask.sum(dim=1, keepdim=True)
out_weights = norm_w.div(target.shape[0] * target.shape[2])
mask = mask.expand_as(x)
loss = functional.l1_loss(
x * mask, target * mask, reduction='none')
loss = functional.l1_loss(x * mask,
target * mask,
reduction='none')
loss = loss.mul(out_weights.to(loss.device)).sum()
else:
mask = mask.expand_as(x)
loss = functional.l1_loss(
x * mask, target * mask, reduction='sum')
loss = functional.l1_loss(x * mask, target * mask, reduction='sum')
loss = loss / mask.sum()
return loss
class MSELossMasked(nn.Module):
def __init__(self, seq_len_norm):
super(MSELossMasked, self).__init__()
self.seq_len_norm = seq_len_norm
@ -67,19 +65,21 @@ class MSELossMasked(nn.Module):
"""
# mask: (batch, max_len, 1)
target.requires_grad = False
mask = sequence_mask(
sequence_length=length, max_len=target.size(1)).unsqueeze(2).float()
mask = sequence_mask(sequence_length=length,
max_len=target.size(1)).unsqueeze(2).float()
if self.seq_len_norm:
norm_w = mask / mask.sum(dim=1, keepdim=True)
out_weights = norm_w.div(target.shape[0] * target.shape[2])
mask = mask.expand_as(x)
loss = functional.mse_loss(
x * mask, target * mask, reduction='none')
loss = functional.mse_loss(x * mask,
target * mask,
reduction='none')
loss = loss.mul(out_weights.to(loss.device)).sum()
else:
mask = mask.expand_as(x)
loss = functional.mse_loss(
x * mask, target * mask, reduction='sum')
loss = functional.mse_loss(x * mask,
target * mask,
reduction='sum')
loss = loss / mask.sum()
return loss
@ -100,7 +100,6 @@ class AttentionEntropyLoss(nn.Module):
class BCELossMasked(nn.Module):
def __init__(self, pos_weight):
super(BCELossMasked, self).__init__()
self.pos_weight = pos_weight
@ -121,9 +120,13 @@ class BCELossMasked(nn.Module):
"""
# mask: (batch, max_len, 1)
target.requires_grad = False
mask = sequence_mask(sequence_length=length, max_len=target.size(1)).float()
mask = sequence_mask(sequence_length=length,
max_len=target.size(1)).float()
loss = functional.binary_cross_entropy_with_logits(
x * mask, target * mask, pos_weight=self.pos_weight, reduction='sum')
x * mask,
target * mask,
pos_weight=self.pos_weight,
reduction='sum')
loss = loss / mask.sum()
return loss
@ -139,7 +142,8 @@ class GuidedAttentionLoss(torch.nn.Module):
max_olen = max(olens)
ga_masks = torch.zeros((B, max_olen, max_ilen))
for idx, (ilen, olen) in enumerate(zip(ilens, olens)):
ga_masks[idx, :olen, :ilen] = self._make_ga_mask(ilen, olen, self.sigma)
ga_masks[idx, :olen, :ilen] = self._make_ga_mask(
ilen, olen, self.sigma)
return ga_masks
def forward(self, att_ws, ilens, olens):
@ -153,7 +157,8 @@ class GuidedAttentionLoss(torch.nn.Module):
def _make_ga_mask(ilen, olen, sigma):
grid_x, grid_y = torch.meshgrid(torch.arange(olen), torch.arange(ilen))
grid_x, grid_y = grid_x.float(), grid_y.float()
return 1.0 - torch.exp(-(grid_y / ilen - grid_x / olen) ** 2 / (2 * (sigma ** 2)))
return 1.0 - torch.exp(-(grid_y / ilen - grid_x / olen)**2 /
(2 * (sigma**2)))
@staticmethod
def _make_masks(ilens, olens):
@ -181,7 +186,8 @@ class TacotronLoss(torch.nn.Module):
self.criterion_ga = GuidedAttentionLoss(sigma=ga_sigma)
# stopnet loss
# pylint: disable=not-callable
self.criterion_st = BCELossMasked(pos_weight=torch.tensor(stopnet_pos_weight)) if c.stopnet else None
self.criterion_st = BCELossMasked(
pos_weight=torch.tensor(stopnet_pos_weight)) if c.stopnet else None
def forward(self, postnet_output, decoder_output, mel_input, linear_input,
stopnet_output, stopnet_target, output_lens, decoder_b_output,
@ -219,19 +225,25 @@ class TacotronLoss(torch.nn.Module):
# backward decoder loss (if enabled)
if self.config.bidirectional_decoder:
if self.config.loss_masking:
decoder_b_loss = self.criterion(torch.flip(decoder_b_output, dims=(1, )), mel_input, output_lens)
decoder_b_loss = self.criterion(
torch.flip(decoder_b_output, dims=(1, )), mel_input,
output_lens)
else:
decoder_b_loss = self.criterion(torch.flip(decoder_b_output, dims=(1, )), mel_input)
decoder_c_loss = torch.nn.functional.l1_loss(torch.flip(decoder_b_output, dims=(1, )), decoder_output)
decoder_b_loss = self.criterion(
torch.flip(decoder_b_output, dims=(1, )), mel_input)
decoder_c_loss = torch.nn.functional.l1_loss(
torch.flip(decoder_b_output, dims=(1, )), decoder_output)
loss += decoder_b_loss + decoder_c_loss
return_dict['decoder_b_loss'] = decoder_b_loss
return_dict['decoder_c_loss'] = decoder_c_loss
# double decoder consistency loss (if enabled)
if self.config.double_decoder_consistency:
decoder_b_loss = self.criterion(decoder_b_output, mel_input, output_lens)
decoder_b_loss = self.criterion(decoder_b_output, mel_input,
output_lens)
# decoder_c_loss = torch.nn.functional.l1_loss(decoder_b_output, decoder_output)
attention_c_loss = torch.nn.functional.l1_loss(alignments, alignments_backwards)
attention_c_loss = torch.nn.functional.l1_loss(
alignments, alignments_backwards)
loss += decoder_b_loss + attention_c_loss
return_dict['decoder_coarse_loss'] = decoder_b_loss
return_dict['decoder_ddc_loss'] = attention_c_loss
@ -248,7 +260,7 @@ class TacotronLoss(torch.nn.Module):
class GlowTTSLoss(torch.nn.Module):
def __init__(self):
super(GlowTTSLoss, self).__init__()
super().__init__()
self.constant_factor = 0.5 * math.log(2 * math.pi)
def forward(self, z, means, scales, log_det, y_lengths, o_dur_log,

View File

@ -12,14 +12,13 @@ class FullbandMelganGenerator(MelganGenerator):
upsample_factors=(2, 8, 2, 2),
res_kernel=3,
num_res_blocks=4):
super(FullbandMelganGenerator,
self).__init__(in_channels=in_channels,
out_channels=out_channels,
proj_kernel=proj_kernel,
base_channels=base_channels,
upsample_factors=upsample_factors,
res_kernel=res_kernel,
num_res_blocks=num_res_blocks)
super().__init__(in_channels=in_channels,
out_channels=out_channels,
proj_kernel=proj_kernel,
base_channels=base_channels,
upsample_factors=upsample_factors,
res_kernel=res_kernel,
num_res_blocks=num_res_blocks)
@torch.no_grad()
def inference(self, cond_features):