From 9a48ba382116810229333fae7b85ff15dac403fd Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Mon, 8 Mar 2021 05:06:54 +0100
Subject: [PATCH] a ton of linter updates

---
 TTS/bin/synthesize.py                         |   2 +-
 TTS/bin/train_glow_tts.py                     |   2 +-
 TTS/bin/train_speedy_speech.py                |   4 +-
 TTS/bin/train_tacotron.py                     |   2 +
 TTS/bin/train_vocoder_gan.py                  |   2 +-
 TTS/server/server.py                          |   2 +-
 TTS/tts/datasets/TTSDataset.py                |   2 +-
 TTS/tts/datasets/preprocess.py                |   4 +-
 TTS/tts/layers/attentions.py                  |  26 ++--
 TTS/tts/layers/generic/gated_conv.py          |   2 +-
 TTS/tts/layers/generic/normalization.py       |   2 +-
 TTS/tts/layers/generic/res_conv_bn.py         |   2 +-
 TTS/tts/layers/generic/wavenet.py             |   2 +-
 TTS/tts/layers/glow_tts/encoder.py            |  20 ++--
 TTS/tts/layers/glow_tts/transformer.py        |   6 +-
 TTS/tts/layers/gst_layers.py                  |   4 +-
 TTS/tts/layers/losses.py                      |   9 +-
 TTS/tts/layers/speedy_speech/decoder.py       |  24 ++--
 TTS/tts/layers/speedy_speech/encoder.py       |  47 ++++----
 TTS/tts/models/speedy_speech.py               |  50 ++++----
 TTS/tts/utils/chinese_mandarin/numbers.py     |   8 +-
 .../chinese_mandarin/pinyinToPhonemes.py      |   6 +-
 TTS/tts/utils/generic_utils.py                |   2 +-
 TTS/tts/utils/io.py                           |   4 +-
 TTS/tts/utils/speakers.py                     |   6 +-
 TTS/tts/utils/ssim.py                         |  22 ++--
 TTS/tts/utils/synthesis.py                    |  14 ++-
 TTS/tts/utils/text/abbreviations.py           | 108 +++++++++---
 TTS/utils/io.py                               |   6 +-
 TTS/utils/manage.py                           |   5 -
 TTS/vocoder/layers/losses.py                  |   2 +-
 TTS/vocoder/layers/wavegrad.py                |   9 +-
 TTS/vocoder/models/wavegrad.py                |   6 +-
 TTS/vocoder/models/wavernn.py                 |  23 ++--
 TTS/vocoder/utils/distribution.py             |   8 +-
 TTS/vocoder/utils/io.py                       |   2 +-
 tests/test_glow_tts.py                        |   2 +-
 tests/test_layers.py                          |   3 +-
 tests/test_model_manager.py                   |   2 +-
 tests/test_speedy_speech_layers.py            |   4 +-
 tests/test_tacotron2_model.py                 |   2 +-
 tests/test_tacotron_model.py                  |   1 -
 tests/test_text_cleaners.py                   |   4 +-
 tests/test_text_processing.py                 |   2 +-
 tests/test_wavegrad_train.py                  |  20 ++--
 45 files changed, 244 insertions(+), 241 deletions(-)

diff --git a/TTS/bin/synthesize.py b/TTS/bin/synthesize.py
index 009affe5..1035c326 100755
--- a/TTS/bin/synthesize.py
+++ b/TTS/bin/synthesize.py
@@ -170,7 +170,7 @@ def main():
         args.vocoder_name = model_item['default_vocoder'] if args.vocoder_name is None else args.vocoder_name
 
     if args.vocoder_name is not None:
-        vocoder_path, vocoder_config_path, vocoder_item = manager.download_model(args.vocoder_name)
+        vocoder_path, vocoder_config_path, _ = manager.download_model(args.vocoder_name)
 
     # CASE3: load custome models
     if args.model_path is not None:
diff --git a/TTS/bin/train_glow_tts.py b/TTS/bin/train_glow_tts.py
index 2ec8eb1c..23695f70 100644
--- a/TTS/bin/train_glow_tts.py
+++ b/TTS/bin/train_glow_tts.py
@@ -573,7 +573,7 @@ def main(args):  # pylint: disable=redefined-outer-name
         if c.run_eval:
             target_loss = eval_avg_loss_dict['avg_loss']
         best_loss = save_best_model(target_loss, best_loss, model, optimizer,
-                                    global_step, epoch, c.r, OUT_PATH,
+                                    global_step, epoch, c.r, OUT_PATH, model_characters,
                                     keep_all_best=keep_all_best, keep_after=keep_after)
diff --git a/TTS/bin/train_speedy_speech.py b/TTS/bin/train_speedy_speech.py
index 9c8b490e..a2ac6028 100644
--- a/TTS/bin/train_speedy_speech.py
+++ b/TTS/bin/train_speedy_speech.py
@@ -1,8 +1,6 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 
-import argparse
-import glob
 import os
 import sys
 import time
@@ -535,7 +533,7 @@ def main(args):  # pylint: disable=redefined-outer-name
         if c.run_eval:
             target_loss = eval_avg_loss_dict['avg_loss']
         best_loss = save_best_model(target_loss, best_loss, model, optimizer,
-                                    global_step, epoch, c.r, OUT_PATH,
+                                    global_step, epoch, c.r, OUT_PATH, model_characters,
                                     keep_all_best=keep_all_best, keep_after=keep_after)
diff --git a/TTS/bin/train_tacotron.py b/TTS/bin/train_tacotron.py
index 86f2c9d6..2ceee8cc 100644
--- a/TTS/bin/train_tacotron.py
+++ b/TTS/bin/train_tacotron.py
@@ -648,12 +648,14 @@ def main(args):  # pylint: disable=redefined-outer-name
             epoch,
             c.r,
             OUT_PATH,
+            model_characters,
             keep_all_best=keep_all_best,
             keep_after=keep_after,
             scaler=scaler.state_dict() if c.mixed_precision else None
         )
 
 
+
 if __name__ == '__main__':
     args = parse_arguments(sys.argv)
     c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(
diff --git a/TTS/bin/train_vocoder_gan.py b/TTS/bin/train_vocoder_gan.py
index 708bf350..a4872361 100644
--- a/TTS/bin/train_vocoder_gan.py
+++ b/TTS/bin/train_vocoder_gan.py
@@ -50,7 +50,7 @@ def setup_loader(ap, is_val=False, verbose=False):
         sampler = DistributedSampler(dataset, shuffle=True) if num_gpus > 1 else None
         loader = DataLoader(dataset,
                             batch_size=1 if is_val else c.batch_size,
-                            shuffle=False if num_gpus > 1 else True,
+                            shuffle=num_gpus == 0,
                             drop_last=False,
                             sampler=sampler,
                             num_workers=c.num_val_loader_workers
diff --git a/TTS/server/server.py b/TTS/server/server.py
index 7cf98394..d2f56843 100644
--- a/TTS/server/server.py
+++ b/TTS/server/server.py
@@ -59,7 +59,7 @@ if args.list_models:
 # set models by the released models
 if args.model_name is not None:
     tts_checkpoint_file, tts_config_file, tts_json_dict = manager.download_model(args.model_name)
-    args.vocoder_name = tts_json_dict['default_vocoder'] if args.vocoder_name is None else args.vocoder_name
+    args.vocoder_name = tts_json_dict['default_vocoder'] if args.vocoder_name is None else args.vocoder_name
 
 if args.vocoder_name is not None:
     vocoder_checkpoint_file, vocoder_config_file, vocoder_json_dict = manager.download_model(args.vocoder_name)
diff --git a/TTS/tts/datasets/TTSDataset.py b/TTS/tts/datasets/TTSDataset.py
index 16329ad7..3d85e000 100644
--- a/TTS/tts/datasets/TTSDataset.py
+++ b/TTS/tts/datasets/TTSDataset.py
@@ -1,7 +1,7 @@
 import collections
 import os
 import random
-from multiprocessing import Manager, Pool
+from multiprocessing import Pool
 
 import numpy as np
 import torch
diff --git a/TTS/tts/datasets/preprocess.py b/TTS/tts/datasets/preprocess.py
index 7cb4edc1..439a4091 100644
--- a/TTS/tts/datasets/preprocess.py
+++ b/TTS/tts/datasets/preprocess.py
@@ -3,7 +3,7 @@ from glob import glob
 import re
 import sys
 from pathlib import Path
-from typing import List, Tuple
+from typing import List
 
 from tqdm import tqdm
 
@@ -377,7 +377,7 @@ def baker(root_path: str, meta_file: str) -> List[List[str]]:
     Args:
         root_path (str): path to the baker dataset
-        meta_file (str): name of the meta dataset containing names of wav to select and the transcript of the sentence
+        meta_file (str): name of the meta dataset containing names of wav to select and the transcript of the sentence
 
     Returns:
         List[List[str]]: List of (text, wav_path, speaker_name) associated with each sentences
     """
diff --git a/TTS/tts/layers/attentions.py b/TTS/tts/layers/attentions.py
index 047e3b23..f7c720a7 100644
--- a/TTS/tts/layers/attentions.py
+++ b/TTS/tts/layers/attentions.py
@@ -367,18 +367,18 @@ class MonotonicDynamicConvolutionAttention(nn.Module):
         beta (float, optional): [description]. Defaults to 0.9 from the paper.
     """
     def __init__(
-        self,
-        query_dim,
-        embedding_dim,  # pylint: disable=unused-argument
-        attention_dim,
-        static_filter_dim,
-        static_kernel_size,
-        dynamic_filter_dim,
-        dynamic_kernel_size,
-        prior_filter_len=11,
-        alpha=0.1,
-        beta=0.9,
-    ):
+            self,
+            query_dim,
+            embedding_dim,  # pylint: disable=unused-argument
+            attention_dim,
+            static_filter_dim,
+            static_kernel_size,
+            dynamic_filter_dim,
+            dynamic_kernel_size,
+            prior_filter_len=11,
+            alpha=0.1,
+            beta=0.9,
+    ):
         super().__init__()
         self._mask_value = 1e-8
         self.dynamic_filter_dim = dynamic_filter_dim
@@ -402,7 +402,7 @@ class MonotonicDynamicConvolutionAttention(nn.Module):
         self.v = nn.Linear(attention_dim, 1, bias=False)
 
         prior = betabinom.pmf(range(prior_filter_len), prior_filter_len - 1,
-                              alpha, beta)
+                              alpha, beta)
         self.register_buffer("prior", torch.FloatTensor(prior).flip(0))
 
     # pylint: disable=unused-argument
diff --git a/TTS/tts/layers/generic/gated_conv.py b/TTS/tts/layers/generic/gated_conv.py
index dbe0f0f0..ec95565a 100644
--- a/TTS/tts/layers/generic/gated_conv.py
+++ b/TTS/tts/layers/generic/gated_conv.py
@@ -40,4 +40,4 @@ class GatedConvBlock(nn.Module):
             o = nn.functional.glu(o, dim=1)
             o = res + o
             res = o
-        return o
\ No newline at end of file
+        return o
diff --git a/TTS/tts/layers/generic/normalization.py b/TTS/tts/layers/generic/normalization.py
index 5ccdeb47..e3dbb52f 100644
--- a/TTS/tts/layers/generic/normalization.py
+++ b/TTS/tts/layers/generic/normalization.py
@@ -104,4 +104,4 @@ class ActNorm(nn.Module):
                                dtype=self.logs.dtype)
 
         self.bias.data.copy_(bias_init)
-        self.logs.data.copy_(logs_init)
\ No newline at end of file
+        self.logs.data.copy_(logs_init)
diff --git a/TTS/tts/layers/generic/res_conv_bn.py b/TTS/tts/layers/generic/res_conv_bn.py
index 322cab94..964afd0a 100644
--- a/TTS/tts/layers/generic/res_conv_bn.py
+++ b/TTS/tts/layers/generic/res_conv_bn.py
@@ -97,7 +97,7 @@ class ResidualConv1dBNBlock(nn.Module):
         assert len(dilations) == num_res_blocks
         self.res_blocks = nn.ModuleList()
         for idx, dilation in enumerate(dilations):
-            block = Conv1dBNBlock(in_channels if idx==0 else hidden_channels,
+            block = Conv1dBNBlock(in_channels if idx == 0 else hidden_channels,
                                   out_channels if (idx + 1) == len(dilations) else hidden_channels,
                                   hidden_channels,
                                   kernel_size,
diff --git a/TTS/tts/layers/generic/wavenet.py b/TTS/tts/layers/generic/wavenet.py
index 9906aa4a..dbfa9f69 100644
--- a/TTS/tts/layers/generic/wavenet.py
+++ b/TTS/tts/layers/generic/wavenet.py
@@ -167,4 +167,4 @@ class WNBlocks(nn.Module):
         o = x
         for layer in self.wn_blocks:
             o = layer(o, x_mask, g)
-        return o
\ No newline at end of file
+        return o
diff --git a/TTS/tts/layers/glow_tts/encoder.py b/TTS/tts/layers/glow_tts/encoder.py
index 9a1508ee..8de006a9 100644
--- a/TTS/tts/layers/glow_tts/encoder.py
+++ b/TTS/tts/layers/glow_tts/encoder.py
@@ -98,11 +98,11 @@ class Encoder(nn.Module):
         if encoder_type.lower() == "rel_pos_transformer":
             if use_prenet:
                 self.prenet = ResidualConv1dLayerNormBlock(hidden_channels,
-                                                           hidden_channels,
-                                                           hidden_channels,
-                                                           kernel_size=5,
-                                                           num_layers=3,
-                                                           dropout_p=0.5)
+                                                           hidden_channels,
+                                                           hidden_channels,
+                                                           kernel_size=5,
+                                                           num_layers=3,
+                                                           dropout_p=0.5)
             self.encoder = RelativePositionTransformer(hidden_channels,
                                                        hidden_channels,
                                                        hidden_channels,
@@ -125,11 +125,11 @@ class Encoder(nn.Module):
         elif encoder_type.lower() == 'time_depth_separable':
             if use_prenet:
                 self.prenet = ResidualConv1dLayerNormBlock(hidden_channels,
-                                                           hidden_channels,
-                                                           hidden_channels,
-                                                           kernel_size=5,
-                                                           num_layers=3,
-                                                           dropout_p=0.5)
+                                                           hidden_channels,
+                                                           hidden_channels,
+                                                           kernel_size=5,
+                                                           num_layers=3,
+                                                           dropout_p=0.5)
             self.encoder = TimeDepthSeparableConvBlock(hidden_channels,
                                                        hidden_channels,
                                                        hidden_channels,
diff --git a/TTS/tts/layers/glow_tts/transformer.py b/TTS/tts/layers/glow_tts/transformer.py
index 4feadc80..77ea05f9 100644
--- a/TTS/tts/layers/glow_tts/transformer.py
+++ b/TTS/tts/layers/glow_tts/transformer.py
@@ -366,8 +366,10 @@ class RelativePositionTransformer(nn.Module):
                 self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
 
             self.ffn_layers.append(
-                FeedForwardNetwork(hidden_channels,
-                                   hidden_channels if (idx + 1) != self.num_layers else out_channels,
+                FeedForwardNetwork(
+                    hidden_channels,
+                    hidden_channels if
+                    (idx + 1) != self.num_layers else out_channels,
                     hidden_channels_ffn,
                     kernel_size,
                     dropout_p=dropout_p))
diff --git a/TTS/tts/layers/gst_layers.py b/TTS/tts/layers/gst_layers.py
index 381d881a..63e76070 100644
--- a/TTS/tts/layers/gst_layers.py
+++ b/TTS/tts/layers/gst_layers.py
@@ -75,7 +75,7 @@ class ReferenceEncoder(nn.Module):
         # x: 3D tensor [batch_size, post_conv_width,
         #               num_channels*post_conv_height]
         self.recurrence.flatten_parameters()
-        memory, out = self.recurrence(x)
+        _, out = self.recurrence(x)
         # out: 3D tensor [seq_len==1, batch_size, encoding_size=128]
         return out.squeeze(0)
@@ -173,4 +173,4 @@ class MultiHeadAttention(nn.Module):
             torch.split(out, 1, dim=0),
             dim=3).squeeze(0)  # [N, T_q, num_units]
 
-        return out
\ No newline at end of file
+        return out
diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py
index ef68d1d0..50575b80 100644
--- a/TTS/tts/layers/losses.py
+++ b/TTS/tts/layers/losses.py
@@ -2,13 +2,12 @@ import math
 import numpy as np
 import torch
 from torch import nn
-from inspect import signature
 from torch.nn import functional
 
 from TTS.tts.utils.generic_utils import sequence_mask
 from TTS.tts.utils.ssim import ssim
 
-# pylint: disable=abstract-method Method
+# pylint: disable=abstract-method
 # relates https://github.com/pytorch/pytorch/issues/42305
 class L1LossMasked(nn.Module):
     def __init__(self, seq_len_norm):
@@ -165,7 +165,7 @@ class BCELossMasked(nn.Module):
         target.requires_grad = False
         if length is not None:
             mask = sequence_mask(sequence_length=length,
-                                 max_len=target.size(1)).float()
+                                 max_len=target.size(1)).float()
             x = x * mask
             target = target * mask
             num_items = mask.sum()
@@ -310,10 +309,10 @@ class TacotronLoss(torch.nn.Module):
             if self.postnet_alpha > 0:
                 if self.config.model in ["Tacotron", "TacotronGST"]:
                     postnet_loss = self.criterion(postnet_output, linear_input,
-                                                  output_lens)
+                                                  output_lens)
                 else:
                     postnet_loss = self.criterion(postnet_output, mel_input,
-                                                  output_lens)
+                                                  output_lens)
         else:
             if self.decoder_alpha > 0:
                 decoder_loss = self.criterion(decoder_output, mel_input)
diff --git a/TTS/tts/layers/speedy_speech/decoder.py b/TTS/tts/layers/speedy_speech/decoder.py
index 5ffb3339..6d32c914 100644
--- a/TTS/tts/layers/speedy_speech/decoder.py
+++ b/TTS/tts/layers/speedy_speech/decoder.py
@@ -146,17 +146,17 @@ class Decoder(nn.Module):
 
     # pylint: disable=dangerous-default-value
     def __init__(
-        self,
-        out_channels,
-        in_hidden_channels,
-        decoder_type='residual_conv_bn',
-        decoder_params={
-            "kernel_size": 4,
-            "dilations": 4 * [1, 2, 4, 8] + [1],
-            "num_conv_blocks": 2,
-            "num_res_blocks": 17
-        },
-        c_in_channels=0):
+            self,
+            out_channels,
+            in_hidden_channels,
+            decoder_type='residual_conv_bn',
+            decoder_params={
+                "kernel_size": 4,
+                "dilations": 4 * [1, 2, 4, 8] + [1],
+                "num_conv_blocks": 2,
+                "num_res_blocks": 17
+            },
+            c_in_channels=0):
         super().__init__()
 
         if decoder_type == 'transformer':
@@ -189,4 +189,4 @@ class Decoder(nn.Module):
         """
         # TODO: implement multi-speaker
         o = self.decoder(x, x_mask, g)
-        return o
\ No newline at end of file
+        return o
diff --git a/TTS/tts/layers/speedy_speech/encoder.py b/TTS/tts/layers/speedy_speech/encoder.py
index d26b306c..8086286c 100644
--- a/TTS/tts/layers/speedy_speech/encoder.py
+++ b/TTS/tts/layers/speedy_speech/encoder.py
@@ -73,13 +73,12 @@ class RelativePositionTransformerEncoder(nn.Module):
     def __init__(self, in_channels, out_channels, hidden_channels, params):
         super().__init__()
         self.prenet = ResidualConv1dBNBlock(in_channels,
-                                            hidden_channels,
-                                            hidden_channels,
-                                            kernel_size=5,
-                                            num_res_blocks=3,
-                                            num_conv_blocks=1,
-                                            dilations=[1, 1, 1]
-                                            )
+                                            hidden_channels,
+                                            hidden_channels,
+                                            kernel_size=5,
+                                            num_res_blocks=3,
+                                            num_conv_blocks=1,
+                                            dilations=[1, 1, 1])
         self.rel_pos_transformer = RelativePositionTransformer(
             hidden_channels, out_channels, hidden_channels, **params)
@@ -104,9 +103,8 @@ class ResidualConv1dBNEncoder(nn.Module):
     """
     def __init__(self, in_channels, out_channels, hidden_channels, params):
         super().__init__()
-        self.prenet = nn.Sequential(
-            nn.Conv1d(in_channels, hidden_channels, 1),
-            nn.ReLU())
+        self.prenet = nn.Sequential(nn.Conv1d(in_channels, hidden_channels, 1),
+                                    nn.ReLU())
         self.res_conv_block = ResidualConv1dBNBlock(hidden_channels,
                                                     hidden_channels,
                                                     hidden_channels, **params)
@@ -162,17 +160,17 @@ class Encoder(nn.Module):
         }
     """
     def __init__(
-        self,
-        in_hidden_channels,
-        out_channels,
-        encoder_type='residual_conv_bn',
-        encoder_params={
-            "kernel_size": 4,
-            "dilations": 4 * [1, 2, 4] + [1],
-            "num_conv_blocks": 2,
-            "num_res_blocks": 13
-        },
-        c_in_channels=0):
+            self,
+            in_hidden_channels,
+            out_channels,
+            encoder_type='residual_conv_bn',
+            encoder_params={
+                "kernel_size": 4,
+                "dilations": 4 * [1, 2, 4] + [1],
+                "num_conv_blocks": 2,
+                "num_res_blocks": 13
+            },
+            c_in_channels=0):
         super().__init__()
         self.out_channels = out_channels
         self.in_channels = in_hidden_channels
@@ -183,10 +181,9 @@ class Encoder(nn.Module):
         # init encoder
         if encoder_type.lower() == "transformer":
             # text encoder
-            self.encoder = RelativePositionTransformerEncoder(in_hidden_channels,
-                                                              out_channels,
-                                                              in_hidden_channels,
-                                                              encoder_params)  # pylint: disable=unexpected-keyword-arg
+            self.encoder = RelativePositionTransformerEncoder(
+                in_hidden_channels, out_channels, in_hidden_channels,
+                encoder_params)  # pylint: disable=unexpected-keyword-arg
         elif encoder_type.lower() == 'residual_conv_bn':
             self.encoder = ResidualConv1dBNEncoder(in_hidden_channels,
                                                    out_channels,
diff --git a/TTS/tts/models/speedy_speech.py b/TTS/tts/models/speedy_speech.py
index 93496d59..886d6fd4 100644
--- a/TTS/tts/models/speedy_speech.py
+++ b/TTS/tts/models/speedy_speech.py
@@ -33,32 +33,32 @@ class SpeedySpeech(nn.Module):
         external_c (bool, optional): enable external speaker embeddings. Defaults to False.
         c_in_channels (int, optional): number of channels in speaker embedding vectors. Defaults to 0.
     """
-# pylint: disable=dangerous-default-value
+    # pylint: disable=dangerous-default-value
     def __init__(
-        self,
-        num_chars,
-        out_channels,
-        hidden_channels,
-        positional_encoding=True,
-        length_scale=1,
-        encoder_type='residual_conv_bn',
-        encoder_params={
-            "kernel_size": 4,
-            "dilations": 4 * [1, 2, 4] + [1],
-            "num_conv_blocks": 2,
-            "num_res_blocks": 13
-        },
-        decoder_type='residual_conv_bn',
-        decoder_params={
-            "kernel_size": 4,
-            "dilations": 4 * [1, 2, 4, 8] + [1],
-            "num_conv_blocks": 2,
-            "num_res_blocks": 17
-        },
-        num_speakers=0,
-        external_c=False,
-        c_in_channels=0):
+            self,
+            num_chars,
+            out_channels,
+            hidden_channels,
+            positional_encoding=True,
+            length_scale=1,
+            encoder_type='residual_conv_bn',
+            encoder_params={
+                "kernel_size": 4,
+                "dilations": 4 * [1, 2, 4] + [1],
+                "num_conv_blocks": 2,
+                "num_res_blocks": 13
+            },
+            decoder_type='residual_conv_bn',
+            decoder_params={
+                "kernel_size": 4,
+                "dilations": 4 * [1, 2, 4, 8] + [1],
+                "num_conv_blocks": 2,
+                "num_res_blocks": 17
+            },
+            num_speakers=0,
+            external_c=False,
+            c_in_channels=0):
         super().__init__()
         self.length_scale = float(length_scale) if isinstance(length_scale, int) else length_scale
@@ -171,7 +171,7 @@ class SpeedySpeech(nn.Module):
         """
         o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
         o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask)
-        o_de, attn= self._forward_decoder(o_en, o_en_dp, dr, x_mask, y_lengths, g=g)
+        o_de, attn = self._forward_decoder(o_en, o_en_dp, dr, x_mask, y_lengths, g=g)
         return o_de, o_dr_log.squeeze(1), attn
 
     def inference(self, x, x_lengths, g=None):  # pylint: disable=unused-argument
diff --git a/TTS/tts/utils/chinese_mandarin/numbers.py b/TTS/tts/utils/chinese_mandarin/numbers.py
index 0befe6b1..94c8fd03 100644
--- a/TTS/tts/utils/chinese_mandarin/numbers.py
+++ b/TTS/tts/utils/chinese_mandarin/numbers.py
@@ -10,7 +10,7 @@ import re
 import itertools
 
 
-def _num2chinese(num :str, big=False, simp=True, o=False, twoalt=False) -> str:
+def _num2chinese(num: str, big=False, simp=True, o=False, twoalt=False) -> str:
     """Convert numerical arabic numbers (0->9) to chinese hanzi numbers (〇 -> 九)
 
     Args:
@@ -32,7 +32,7 @@ def _num2chinese(num: str, big=False, simp=True, o=False, twoalt=False) -> str:
     nd = str(num)
     if abs(float(nd)) >= 1e48:
         raise ValueError('number out of range')
-    elif 'e' in nd:
+    if 'e' in nd:
         raise ValueError('scientific notation is not supported')
     c_symbol = '正负点' if simp else '正負點'
     if o:  # formal
@@ -69,7 +69,7 @@ def _num2chinese(num: str, big=False, simp=True, o=False, twoalt=False) -> str:
         if int(unit) == 0:  # 0000
             intresult.append(c_basic[0])
             continue
-        elif nu > 0 and int(unit) == 2:  # 0002
+        if nu > 0 and int(unit) == 2:  # 0002
             intresult.append(c_twoalt + c_unit2[nu - 1])
             continue
         ulist = []
@@ -128,4 +128,4 @@ def replace_numbers_to_characters_in_text(text: str) -> str:
         str: output text
     """
     text = re.sub(r'[0-9]+', _number_replace, text)
-    return text
\ No newline at end of file
+    return text
diff --git a/TTS/tts/utils/chinese_mandarin/pinyinToPhonemes.py b/TTS/tts/utils/chinese_mandarin/pinyinToPhonemes.py
index cdca44ac..a4722ff9 100644
--- a/TTS/tts/utils/chinese_mandarin/pinyinToPhonemes.py
+++ b/TTS/tts/utils/chinese_mandarin/pinyinToPhonemes.py
@@ -9,7 +9,7 @@ PINYIN_DICT = {
     "bai": ["bai"],
     "ban": ["ban"],
     "bang": ["bɑŋ"],
-    "bao": ["baʌ"],
+    "bao": ["baʌ"],
     # "be": ["be"], doesnt exist
     "bei": ["bɛi"],
     "ben": ["bœn"],
@@ -377,7 +377,7 @@ PINYIN_DICT = {
     "yong": ["ioŋ"],
     "you": ["io"],
     "yu": ["y"],
-    "yuan": ["yɛn"],
+    "yuan": ["yɛn"],
     "yue": ["ye"],
     "yun": ["yn"],
     "za": ["dza"],
@@ -417,4 +417,4 @@ PINYIN_DICT = {
     "zui": ["dzuei"],
     "zun": ["dzun"],
     "zuo": ["dzuo"],
-}
\ No newline at end of file
+}
diff --git a/TTS/tts/utils/generic_utils.py b/TTS/tts/utils/generic_utils.py
index d898aebd..0d236fbc 100644
--- a/TTS/tts/utils/generic_utils.py
+++ b/TTS/tts/utils/generic_utils.py
@@ -135,7 +135,7 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
     return model
 
 def is_tacotron(c):
-    return False if c['model'] in ['speedy_speech', 'glow_tts'] else True
+    return not c['model'] in ['speedy_speech', 'glow_tts']
 
 def check_config_tts(c):
     check_argument('model', c, enum_list=['tacotron', 'tacotron2', 'glow_tts', 'speedy_speech'], restricted=True, val_type=str)
diff --git a/TTS/tts/utils/io.py b/TTS/tts/utils/io.py
index fe94d98d..bcf5ff37 100644
--- a/TTS/tts/utils/io.py
+++ b/TTS/tts/utils/io.py
@@ -7,7 +7,7 @@ from TTS.utils.io import RenamingUnpickler
 
 
-def load_checkpoint(model, checkpoint_path, amp=None, use_cuda=False, eval=False):
+def load_checkpoint(model, checkpoint_path, amp=None, use_cuda=False, eval=False):  # pylint: disable=redefined-builtin
     """Load ```TTS.tts.models``` checkpoints.
 
     Args:
@@ -98,7 +98,7 @@ def save_checkpoint(model, optimizer, current_step, epoch, r, output_folder,
 
 
 def save_best_model(target_loss, best_loss, model, optimizer, current_step,
-                    epoch, r, output_folder, characters, **kwargs):
+                    epoch, r, output_folder, characters, **kwargs):
     """Save model checkpoint, intended for saving the best model after each epoch.
     It compares the current model loss with the best loss so far and saves the
     model if the current loss is better.
diff --git a/TTS/tts/utils/speakers.py b/TTS/tts/utils/speakers.py
index 43bb1f6a..feb1a845 100644
--- a/TTS/tts/utils/speakers.py
+++ b/TTS/tts/utils/speakers.py
@@ -63,11 +63,11 @@ def parse_speakers(c, args, meta_data_train, OUT_PATH):
             speaker_embedding_dim = None
             save_speaker_mapping(OUT_PATH, speaker_mapping)
             num_speakers = len(speaker_mapping)
-            print(" > Training with {} speakers: {}".format(len(speakers),
-                                                            ", ".join(speakers)))
+            print(" > Training with {} speakers: {}".format(
+                len(speakers), ", ".join(speakers)))
     else:
         num_speakers = 0
         speaker_embedding_dim = None
         speaker_mapping = None
-    return num_speakers, speaker_embedding_dim, speaker_mapping
\ No newline at end of file
+    return num_speakers, speaker_embedding_dim, speaker_mapping
diff --git a/TTS/tts/utils/ssim.py b/TTS/tts/utils/ssim.py
index 399d0898..8f4c4cae 100644
--- a/TTS/tts/utils/ssim.py
+++ b/TTS/tts/utils/ssim.py
@@ -17,17 +17,22 @@ def create_window(window_size, channel):
     window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
     return window
 
-def _ssim(img1, img2, window, window_size, channel, size_average = True):
-    mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
-    mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)
+
+def _ssim(img1, img2, window, window_size, channel, size_average=True):
+    mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
+    mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
 
     mu1_sq = mu1.pow(2)
     mu2_sq = mu2.pow(2)
     mu1_mu2 = mu1*mu2
 
-    sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
-    sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
-    sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2
+    sigma1_sq = F.conv2d(
+        img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
+    sigma2_sq = F.conv2d(
+        img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
+    sigma12 = F.conv2d(
+        img1 * img2, window, padding=window_size // 2,
+        groups=channel) - mu1_mu2
 
     C1 = 0.01**2
     C2 = 0.03**2
@@ -39,7 +44,7 @@ def _ssim(img1, img2, window, window_size, channel, size_average=True):
     return ssim_map.mean(1).mean(1).mean(1)
 
 class SSIM(torch.nn.Module):
-    def __init__(self, window_size = 11, size_average = True):
+    def __init__(self, window_size=11, size_average=True):
         super().__init__()
         self.window_size = window_size
         self.size_average = size_average
@@ -64,7 +69,8 @@ class SSIM(torch.nn.Module):
 
         return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
 
-def ssim(img1, img2, window_size = 11, size_average = True):
+
+def ssim(img1, img2, window_size=11, size_average=True):
     (_, channel, _, _) = img1.size()
     window = create_window(window_size, channel)
diff --git a/TTS/tts/utils/synthesis.py b/TTS/tts/utils/synthesis.py
index b35c7db3..a0524e8f 100644
--- a/TTS/tts/utils/synthesis.py
+++ b/TTS/tts/utils/synthesis.py
@@ -20,9 +20,13 @@ def text_to_seqvec(text, CONFIG):
                             add_blank=CONFIG['add_blank'] if 'add_blank' in CONFIG.keys() else False),
             dtype=np.int32)
     else:
-        seq = np.asarray(
-            text_to_sequence(text, text_cleaner, tp=CONFIG.characters if 'characters' in CONFIG.keys() else None,
-                             add_blank=CONFIG['add_blank'] if 'add_blank' in CONFIG.keys() else False), dtype=np.int32)
+        seq = np.asarray(text_to_sequence(
+            text,
+            text_cleaner,
+            tp=CONFIG.characters if 'characters' in CONFIG.keys() else None,
+            add_blank=CONFIG['add_blank']
+            if 'add_blank' in CONFIG.keys() else False),
+                         dtype=np.int32)
     return seq
@@ -77,9 +81,9 @@ def run_model_torch(model, inputs, CONFIG, truncated, speaker_id=None, style_mel
         inputs_lengths = torch.tensor(inputs.shape[1:2]).to(inputs.device)  # pylint: disable=not-callable
         if hasattr(model, 'module'):
             # distributed model
-            postnet_output, alignments= model.module.inference(inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings)
+            postnet_output, alignments = model.module.inference(inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings)
         else:
-            postnet_output, alignments= model.inference(inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings)
+            postnet_output, alignments = model.inference(inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings)
         postnet_output = postnet_output.permute(0, 2, 1)
         # these only belong to tacotron models.
         decoder_output = None
diff --git a/TTS/tts/utils/text/abbreviations.py b/TTS/tts/utils/text/abbreviations.py
index d14426e1..fe4c1cdc 100644
--- a/TTS/tts/utils/text/abbreviations.py
+++ b/TTS/tts/utils/text/abbreviations.py
@@ -2,60 +2,60 @@ import re
 
 # List of (regular expression, replacement) pairs for abbreviations in english:
 abbreviations_en = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1])
-    for x in [
-        ('mrs', 'misess'),
-        ('mr', 'mister'),
-        ('dr', 'doctor'),
-        ('st', 'saint'),
-        ('co', 'company'),
-        ('jr', 'junior'),
-        ('maj', 'major'),
-        ('gen', 'general'),
-        ('drs', 'doctors'),
-        ('rev', 'reverend'),
-        ('lt', 'lieutenant'),
-        ('hon', 'honorable'),
-        ('sgt', 'sergeant'),
-        ('capt', 'captain'),
-        ('esq', 'esquire'),
-        ('ltd', 'limited'),
-        ('col', 'colonel'),
-        ('ft', 'fort'),
-    ]]
+                    for x in [
+                        ('mrs', 'misess'),
+                        ('mr', 'mister'),
+                        ('dr', 'doctor'),
+                        ('st', 'saint'),
+                        ('co', 'company'),
+                        ('jr', 'junior'),
+                        ('maj', 'major'),
+                        ('gen', 'general'),
+                        ('drs', 'doctors'),
+                        ('rev', 'reverend'),
+                        ('lt', 'lieutenant'),
+                        ('hon', 'honorable'),
+                        ('sgt', 'sergeant'),
+                        ('capt', 'captain'),
+                        ('esq', 'esquire'),
+                        ('ltd', 'limited'),
+                        ('col', 'colonel'),
+                        ('ft', 'fort'),
+                    ]]
 
 # List of (regular expression, replacement) pairs for abbreviations in french:
 abbreviations_fr = [(re.compile('\\b%s\\.?' % x[0], re.IGNORECASE), x[1])
-    for x in [
-        ('M', 'monsieur'),
-        ('Mlle', 'mademoiselle'),
-        ('Mlles', 'mesdemoiselles'),
-        ('Mme', 'Madame'),
-        ('Mmes', 'Mesdames'),
-        ('N.B', 'nota bene'),
-        ('M', 'monsieur'),
-        ('p.c.q', 'parce que'),
-        ('Pr', 'professeur'),
-        ('qqch', 'quelque chose'),
-        ('rdv', 'rendez-vous'),
-        ('max', 'maximum'),
-        ('min', 'minimum'),
-        ('no', 'numéro'),
-        ('adr', 'adresse'),
-        ('dr', 'docteur'),
-        ('st', 'saint'),
-        ('co', 'companie'),
-        ('jr', 'junior'),
-        ('sgt', 'sergent'),
-        ('capt', 'capitain'),
-        ('col', 'colonel'),
-        ('av', 'avenue'),
-        ('av. J.-C', 'avant Jésus-Christ'),
-        ('apr. J.-C', 'après Jésus-Christ'),
-        ('art', 'article'),
-        ('boul', 'boulevard'),
-        ('c.-à-d', 'c’est-à-dire'),
-        ('etc', 'et cetera'),
-        ('ex', 'exemple'),
-        ('excl', 'exclusivement'),
-        ('boul', 'boulevard'),
-    ]]
\ No newline at end of file
+                    for x in [
+                        ('M', 'monsieur'),
+                        ('Mlle', 'mademoiselle'),
+                        ('Mlles', 'mesdemoiselles'),
+                        ('Mme', 'Madame'),
+                        ('Mmes', 'Mesdames'),
+                        ('N.B', 'nota bene'),
+                        ('M', 'monsieur'),
+                        ('p.c.q', 'parce que'),
+                        ('Pr', 'professeur'),
+                        ('qqch', 'quelque chose'),
+                        ('rdv', 'rendez-vous'),
+                        ('max', 'maximum'),
+                        ('min', 'minimum'),
+                        ('no', 'numéro'),
+                        ('adr', 'adresse'),
+                        ('dr', 'docteur'),
+                        ('st', 'saint'),
+                        ('co', 'companie'),
+                        ('jr', 'junior'),
+                        ('sgt', 'sergent'),
+                        ('capt', 'capitain'),
+                        ('col', 'colonel'),
+                        ('av', 'avenue'),
+                        ('av. J.-C', 'avant Jésus-Christ'),
+                        ('apr. J.-C', 'après Jésus-Christ'),
+                        ('art', 'article'),
+                        ('boul', 'boulevard'),
+                        ('c.-à-d', 'c’est-à-dire'),
+                        ('etc', 'et cetera'),
+                        ('ex', 'exemple'),
+                        ('excl', 'exclusivement'),
+                        ('boul', 'boulevard'),
+                    ]]
diff --git a/TTS/utils/io.py b/TTS/utils/io.py
index 30b7b7e2..1eb5e630 100644
--- a/TTS/utils/io.py
+++ b/TTS/utils/io.py
@@ -22,7 +22,7 @@ class AttrDict(dict):
 
 def read_json_with_comments(json_path):
     # fallback to json
-    with open(json_path, "r", encoding = "utf-8") as f:
+    with open(json_path, "r", encoding="utf-8") as f:
         input_str = f.read()
     # handle comments
     input_str = re.sub(r'\\\n', '', input_str)
@@ -40,7 +40,7 @@ def load_config(config_path: str) -> AttrDict:
 
     ext = os.path.splitext(config_path)[1]
     if ext in (".yml", ".yaml"):
-        with open(config_path, "r", encoding = "utf-8") as f:
+        with open(config_path, "r", encoding="utf-8") as f:
             data = yaml.safe_load(f)
     else:
         data = read_json_with_comments(config_path)
@@ -61,7 +61,7 @@ def copy_model_files(c, config_file, out_path, new_fields):
     """
     # copy config.json
    copy_config_path = os.path.join(out_path, 'config.json')
-    config_lines = open(config_file, "r", encoding = "utf-8").readlines()
+    config_lines = open(config_file, "r", encoding="utf-8").readlines()
     # add extra information fields
     for key, value in new_fields.items():
         if isinstance(value, str):
diff --git a/TTS/utils/manage.py b/TTS/utils/manage.py
index eb57c8be..12f930ac 100644
--- a/TTS/utils/manage.py
+++ b/TTS/utils/manage.py
@@ -144,8 +144,3 @@ class ModelManager(object):
                 if isinstance(key, str) and len(my_dict[key]) > 0:
                     return True
         return False
-
-
-
-
-
diff --git a/TTS/vocoder/layers/losses.py b/TTS/vocoder/layers/losses.py
index 1107b3c5..34c2f9b7 100644
--- a/TTS/vocoder/layers/losses.py
+++ b/TTS/vocoder/layers/losses.py
@@ -4,7 +4,7 @@ from torch import nn
 from torch.nn import functional as F
 
 
-class TorchSTFT(nn.Module):
+class TorchSTFT(nn.Module):  # pylint: disable=abstract-method
     def __init__(self, n_fft, hop_length, win_length, window='hann_window'):
         """ Torch based STFT operation """
         super(TorchSTFT, self).__init__()
diff --git a/TTS/vocoder/layers/wavegrad.py b/TTS/vocoder/layers/wavegrad.py
index d09b4950..81f03124 100644
--- a/TTS/vocoder/layers/wavegrad.py
+++ b/TTS/vocoder/layers/wavegrad.py
@@ -22,8 +22,10 @@ class PositionalEncoding(nn.Module):
 
     def forward(self, x, noise_level):
         if x.shape[2] > self.pe.shape[1]:
-            self.init_pe_matrix(x.shape[1] ,x.shape[2], x)
-        return x + noise_level[..., None, None] + self.pe[:, :x.size(2)].repeat(x.shape[0], 1, 1) / self.C
+            self.init_pe_matrix(x.shape[1], x.shape[2], x)
+        return x + noise_level[..., None,
+                               None] + self.pe[:, :x.size(2)].repeat(
+                                   x.shape[0], 1, 1) / self.C
 
     def init_pe_matrix(self, n_channels, max_len, x):
         pe = torch.zeros(max_len, n_channels)
@@ -171,5 +173,4 @@ class DBlock(nn.Module):
             self.res_block = weight_norm(self.res_block)
         for idx, layer in enumerate(self.main_block):
             if len(layer.state_dict()) != 0:
-                self.main_block[idx] = weight_norm(layer)
-
+                self.main_block[idx] = weight_norm(layer)
diff --git a/TTS/vocoder/models/wavegrad.py b/TTS/vocoder/models/wavegrad.py
index f4a5faa3..96951ad1 100644
--- a/TTS/vocoder/models/wavegrad.py
+++ b/TTS/vocoder/models/wavegrad.py
@@ -79,7 +79,7 @@ class Wavegrad(nn.Module):
         return x
 
     def load_noise_schedule(self, path):
-        beta = np.load(path, allow_pickle=True).item()['beta']
+        beta = np.load(path, allow_pickle=True).item()['beta']  # pylint: disable=unexpected-keyword-arg
         self.compute_noise_level(beta)
 
     @torch.no_grad()
@@ -91,8 +91,8 @@ class Wavegrad(nn.Module):
             y_n = torch.FloatTensor(y_n).unsqueeze(0).unsqueeze(0).to(x)
         sqrt_alpha_hat = self.noise_level.to(x)
         for n in range(len(self.alpha) - 1, -1, -1):
-            y_n = self.c1[n] * (y_n -
-                                self.c2[n] * self.forward(y_n, x, sqrt_alpha_hat[n].repeat(x.shape[0])))
+            y_n = self.c1[n] * (y_n - self.c2[n] * self.forward(
+                y_n, x, sqrt_alpha_hat[n].repeat(x.shape[0])))
             if n > 0:
                 z = torch.randn_like(y_n)
                 y_n += self.sigma[n - 1] * z
diff --git a/TTS/vocoder/models/wavernn.py b/TTS/vocoder/models/wavernn.py
index fdb71cff..dbcaea66 100644
--- a/TTS/vocoder/models/wavernn.py
+++ b/TTS/vocoder/models/wavernn.py
@@ -73,15 +73,15 @@ class Stretch2d(nn.Module):
 
 class UpsampleNetwork(nn.Module):
     def __init__(
-        self,
-        feat_dims,
-        upsample_scales,
-        compute_dims,
-        num_res_blocks,
-        res_out_dims,
-        pad,
-        use_aux_net,
-    ):
+            self,
+            feat_dims,
+            upsample_scales,
+            compute_dims,
+            num_res_blocks,
+            res_out_dims,
+            pad,
+            use_aux_net,
+    ):
         super().__init__()
         self.total_scale = np.cumproduct(upsample_scales)[-1]
         self.indent = pad * self.total_scale
@@ -118,9 +118,8 @@ class UpsampleNetwork(nn.Module):
 
 class Upsample(nn.Module):
-    def __init__(
-        self, scale, pad, num_res_blocks, feat_dims, compute_dims, res_out_dims, use_aux_net
-    ):
+    def __init__(self, scale, pad, num_res_blocks, feat_dims, compute_dims,
+                 res_out_dims, use_aux_net):
         super().__init__()
         self.scale = scale
         self.pad = pad
diff --git a/TTS/vocoder/utils/distribution.py b/TTS/vocoder/utils/distribution.py
index 6aba5e34..b0553ed0 100644
--- a/TTS/vocoder/utils/distribution.py
+++ b/TTS/vocoder/utils/distribution.py
@@ -44,9 +44,11 @@ def log_sum_exp(x):
 
 
 # It is adapted from https://github.com/r9y9/wavenet_vocoder/blob/master/wavenet_vocoder/mixture.py
-def discretized_mix_logistic_loss(
-    y_hat, y, num_classes=65536, log_scale_min=None, reduce=True
-):
+def discretized_mix_logistic_loss(y_hat,
+                                  y,
+                                  num_classes=65536,
+                                  log_scale_min=None,
+                                  reduce=True):
     if log_scale_min is None:
         log_scale_min = float(np.log(1e-14))
     y_hat = y_hat.permute(0, 2, 1)
diff --git a/TTS/vocoder/utils/io.py b/TTS/vocoder/utils/io.py
index 60def72a..f3bc9bad 100644
--- a/TTS/vocoder/utils/io.py
+++ b/TTS/vocoder/utils/io.py
@@ -7,7 +7,7 @@ import pickle as pickle_tts
 from TTS.utils.io import RenamingUnpickler
 
 
-def load_checkpoint(model, checkpoint_path, use_cuda=False, eval=False):
+def load_checkpoint(model, checkpoint_path, use_cuda=False, eval=False):  # pylint: disable=redefined-builtin
     try:
         state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
     except ModuleNotFoundError:
diff --git a/tests/test_glow_tts.py b/tests/test_glow_tts.py
index d290baf5..670b7b67 100644
--- a/tests/test_glow_tts.py
+++ b/tests/test_glow_tts.py
@@ -130,4 +130,4 @@ class GlowTTSTrainTest(unittest.TestCase):
                 assert (param != param_ref).any(
                 ), "param {} with shape {} not updated!! \n{}\n{}".format(
                     count, param.shape, param, param_ref)
-            count += 1
\ No newline at end of file
+            count += 1
diff --git a/tests/test_layers.py b/tests/test_layers.py
index 5426e195..1a07b750 100644
--- a/tests/test_layers.py
+++ b/tests/test_layers.py
@@ -166,7 +166,7 @@ class SSIMLossTests(unittest.TestCase):
         dummy_target = T.zeros(4, 8, 128).float()
         dummy_length = (T.ones(4) * 8).long()
         output = layer(dummy_input, dummy_target, dummy_length)
-        assert abs(output.item() - 1.0) < 1e-4 , "1.0 vs {}".format(output.item())
+        assert abs(output.item() - 1.0) < 1e-4, "1.0 vs {}".format(output.item())
 
         # test if padded values of input makes any difference
         dummy_input = T.ones(4, 8, 128).float()
@@ -217,4 +217,3 @@ class SSIMLossTests(unittest.TestCase):
                 (sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
         output = layer(dummy_input + mask, dummy_target, dummy_length)
         assert output.item() == 0, "0 vs {}".format(output.item())
-
diff --git a/tests/test_model_manager.py b/tests/test_model_manager.py
index 7807716f..4445b091 100644
--- a/tests/test_model_manager.py
+++ b/tests/test_model_manager.py
@@ -17,4 +17,4 @@ def test_if_all_models_available():
     folders = glob.glob(os.path.join(manager.output_prefix, '*'))
     assert len(folders) == len(model_names)
 
-    shutil.rmtree(manager.output_prefix)
\ No newline at end of file
+    shutil.rmtree(manager.output_prefix)
diff --git a/tests/test_speedy_speech_layers.py b/tests/test_speedy_speech_layers.py
index a5567ac3..53351fff 100644
--- a/tests/test_speedy_speech_layers.py
+++ b/tests/test_speedy_speech_layers.py
@@ -161,8 +161,8 @@ def test_speedy_speech():
                            x_lengths,
                            y_lengths,
                            durations,
-                           g=torch.rand((B,256)).to(device))
+                           g=torch.rand((B, 256)).to(device))
 
     assert list(o_de.shape) == [B, 80, T_de], f"{list(o_de.shape)}"
     assert list(attn.shape) == [B, T_de, T_en]
-    assert list(o_dr.shape) == [B, T_en]
\ No newline at end of file
+    assert list(o_dr.shape) == [B, T_en]
diff --git a/tests/test_tacotron2_model.py b/tests/test_tacotron2_model.py
index 38f4c737..4ac07118 100644
--- a/tests/test_tacotron2_model.py
+++ b/tests/test_tacotron2_model.py
@@ -292,4 +292,4 @@ class SCGSTMultiSpeakeTacotronTrainTest(unittest.TestCase):
                 assert (param != param_ref).any(
                 ), "param {} with shape {} not updated!! \n{}\n{}".format(
                     count, param.shape, param, param_ref)
-            count += 1
\ No newline at end of file
+            count += 1
diff --git a/tests/test_tacotron_model.py b/tests/test_tacotron_model.py
index c56a6565..f8e88160 100644
--- a/tests/test_tacotron_model.py
+++ b/tests/test_tacotron_model.py
@@ -356,4 +356,3 @@ class SCGSTMultiSpeakeTacotronTrainTest(unittest.TestCase):
                 ), "param {} with shape {} not updated!! \n{}\n{}".format(
                     count, param.shape, param, param_ref)
             count += 1
-
diff --git a/tests/test_text_cleaners.py b/tests/test_text_cleaners.py
index 7a2abe72..b301fb5a 100644
--- a/tests/test_text_cleaners.py
+++ b/tests/test_text_cleaners.py
@@ -17,5 +17,5 @@ def test_currency() -> None:
 
 
 def test_expand_numbers() -> None:
-    assert "minus one" == phoneme_cleaners("-1")
-    assert "one" == phoneme_cleaners("1")
+    assert phoneme_cleaners("-1") == 'minus one'
+    assert phoneme_cleaners("1") == 'one'
diff --git a/tests/test_text_processing.py b/tests/test_text_processing.py
index 2ea8e8f9..646f2592 100644
--- a/tests/test_text_processing.py
+++ b/tests/test_text_processing.py
@@ -17,7 +17,7 @@ def test_phoneme_to_sequence():
     lang = "en-us"
     sequence = phoneme_to_sequence(text, text_cleaner, lang)
     text_hat = sequence_to_phoneme(sequence)
-    sequence_with_params = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters)
+    _ = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters)
     text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters)
     gt = 'ɹiːsənt ɹᵻsɜːtʃ æt hɑːɹvɚd hɐz ʃoʊn mɛdᵻteɪɾɪŋ fɔːɹ æz lɪɾəl æz eɪt wiːks kæn æktʃuːəli ɪŋkɹiːs, ðə ɡɹeɪ mæɾɚɹ ɪnðə pɑːɹts ʌvðə bɹeɪn ɹᵻspɑːnsᵻbəl fɔːɹ ɪmoʊʃənəl ɹɛɡjʊleɪʃən ænd lɜːnɪŋ!'
     assert text_hat == text_hat_with_params == gt
diff --git a/tests/test_wavegrad_train.py b/tests/test_wavegrad_train.py
index 700e94d1..45f75e3b 100644
--- a/tests/test_wavegrad_train.py
+++ b/tests/test_wavegrad_train.py
@@ -20,18 +20,18 @@ class WavegradTrainTest(unittest.TestCase):
         criterion = torch.nn.L1Loss().to(device)
         model = Wavegrad(in_channels=80,
-                        out_channels=1,
-                        upsample_factors=[5, 5, 3, 2, 2],
-                        upsample_dilations=[[1, 2, 1, 2], [1, 2, 1, 2],
-                                            [1, 2, 4, 8], [1, 2, 4, 8],
-                                            [1, 2, 4, 8]])
+                         out_channels=1,
+                         upsample_factors=[5, 5, 3, 2, 2],
+                         upsample_dilations=[[1, 2, 1, 2], [1, 2, 1, 2],
+                                             [1, 2, 4, 8], [1, 2, 4, 8],
+                                             [1, 2, 4, 8]])
 
         model_ref = Wavegrad(in_channels=80,
-                            out_channels=1,
-                            upsample_factors=[5, 5, 3, 2, 2],
-                            upsample_dilations=[[1, 2, 1, 2], [1, 2, 1, 2],
-                                                [1, 2, 4, 8], [1, 2, 4, 8],
-                                                [1, 2, 4, 8]])
+                             out_channels=1,
+                             upsample_factors=[5, 5, 3, 2, 2],
+                             upsample_dilations=[[1, 2, 1, 2], [1, 2, 1, 2],
+                                                 [1, 2, 4, 8], [1, 2, 4, 8],
+                                                 [1, 2, 4, 8]])
         model.train()
         model.to(device)
         betas = np.linspace(1e-6, 1e-2, 1000)