diff --git a/TTS/bin/train_glow_tts.py b/TTS/bin/train_glow_tts.py
index cf9d98d2..3d34d978 100644
--- a/TTS/bin/train_glow_tts.py
+++ b/TTS/bin/train_glow_tts.py
@@ -8,16 +8,14 @@ import sys
 import time
 import traceback
 
-import numpy as np
 import torch
-from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.data import DataLoader
 
 from TTS.tts.datasets.preprocess import load_meta_data
 from TTS.tts.datasets.TTSDataset import MyDataset
 from TTS.tts.layers.losses import GlowTTSLoss
 from TTS.tts.utils.distribute import (DistributedSampler, init_distributed,
                                       reduce_tensor)
-from TTS.tts.utils.generic_utils import check_config, setup_model
+from TTS.tts.utils.generic_utils import setup_model
 from TTS.tts.utils.io import save_best_model, save_checkpoint
 from TTS.tts.utils.measures import alignment_diagonal_score
 from TTS.tts.utils.speakers import (get_speakers, load_speaker_mapping,
@@ -33,21 +31,20 @@ from TTS.utils.generic_utils import (KeepAverage, count_parameters,
 from TTS.utils.io import copy_config_file, load_config
 from TTS.utils.radam import RAdam
 from TTS.utils.tensorboard_logger import TensorboardLogger
-from TTS.utils.training import (NoamLR, adam_weight_decay, check_update,
-                                gradual_training_scheduler, set_weight_decay,
+from TTS.utils.training import (NoamLR, check_update,
                                 setup_torch_training_env)
 
 use_cuda, num_gpus = setup_torch_training_env(True, False)
 
-
 def setup_loader(ap, r, is_val=False, verbose=False):
+
     if is_val and not c.run_eval:
         loader = None
     else:
         dataset = MyDataset(
             r,
             c.text_cleaner,
-            compute_linear_spec=True if c.model.lower() == 'tacotron' else False,
+            compute_linear_spec=False,
             meta_data=meta_data_eval if is_val else meta_data_train,
             ap=ap,
             tp=c.characters if 'characters' in c.keys() else None,
@@ -125,11 +122,11 @@ def data_depended_init(model, ap):
     model.train()
     print(" > Data depended initialization ... ")
     with torch.no_grad():
-        for num_iter, data in enumerate(data_loader):
+        for _, data in enumerate(data_loader):
 
             # format data
-            text_input, text_lengths, mel_input, mel_lengths, speaker_ids,\
-                avg_text_length, avg_spec_length, attn_mask = format_data(data)
+            text_input, text_lengths, mel_input, mel_lengths, _,\
+                _, _, attn_mask = format_data(data)
 
             # forward pass model
             _ = model.forward(
@@ -165,7 +162,7 @@ def train(model, criterion, optimizer, scheduler,
         start_time = time.time()
 
         # format data
-        text_input, text_lengths, mel_input, mel_lengths, speaker_ids,\
+        text_input, text_lengths, mel_input, mel_lengths, _,\
             avg_text_length, avg_spec_length, attn_mask = format_data(data)
         loader_time = time.time() - end_time
 
@@ -187,7 +184,7 @@ def train(model, criterion, optimizer, scheduler,
 
         # backward pass
         if amp is not None:
-            with amp.scale_loss( loss_dict['loss'], optimizer) as scaled_loss:
+            with amp.scale_loss(loss_dict['loss'], optimizer) as scaled_loss:
                 scaled_loss.backward()
         else:
             loss_dict['loss'].backward()
@@ -312,8 +309,8 @@ def evaluate(model, criterion, ap, global_step, epoch):
             start_time = time.time()
 
             # format data
-            text_input, text_lengths, mel_input, mel_lengths, speaker_ids,\
-                avg_text_length, avg_spec_length, attn_mask = format_data(data)
+            text_input, text_lengths, mel_input, mel_lengths, _,\
+                _, _, attn_mask = format_data(data)
 
             # forward pass model
             z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model.forward(
@@ -321,7 +318,7 @@ def evaluate(model, criterion, ap, global_step, epoch):
 
             # compute loss
             loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths,
-                                o_dur_log, o_total_dur, text_lengths)
+                                  o_dur_log, o_total_dur, text_lengths)
 
             # step time
             step_time = time.time() - start_time
@@ -405,7 +402,7 @@ def evaluate(model, criterion, ap, global_step, epoch):
         style_wav = c.get("style_wav_for_test")
         for idx, test_sentence in enumerate(test_sentences):
             try:
-                wav, alignment, decoder_output, postnet_output, stop_tokens, inputs = synthesis(
+                wav, alignment, _, postnet_output, _, _ = synthesis(
                     model,
                     test_sentence,
                     c,
@@ -428,7 +425,7 @@ def evaluate(model, criterion, ap, global_step, epoch):
                     postnet_output, ap)
                 test_figures['{}-alignment'.format(idx)] = plot_alignment(
                     alignment)
-            except:
+            except: #pylint: disable=bare-except
                 print(" !! Error creating Test Sentence -", idx)
                 traceback.print_exc()
         tb_logger.tb_test_audios(global_step, test_audios,
@@ -503,7 +500,7 @@ def main(args):  # pylint: disable=redefined-outer-name
             if c.reinit_layers:
                 raise RuntimeError
             model.load_state_dict(checkpoint['model'])
-        except:
+        except: #pylint: disable=bare-except
             print(" > Partial model initialization.")
             model_dict = model.state_dict()
             model_dict = set_init_dict(model_dict, checkpoint['model'], c)
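NOTE: the backward pass kept above relies on NVIDIA apex's loss-scaling context manager. Below is a minimal, hedged sketch of that pattern, assuming apex and a CUDA device are available; `net` and `opt` are hypothetical stand-ins for this script's model and optimizer, not names from the patch.

# Sketch of apex AMP loss scaling as used in train() above.
import torch
from apex import amp  # assumes NVIDIA apex is installed

net = torch.nn.Linear(10, 10).cuda()
opt = torch.optim.SGD(net.parameters(), lr=0.1)
net, opt = amp.initialize(net, opt, opt_level='O1')

loss = net(torch.randn(4, 10, device='cuda')).pow(2).mean()
with amp.scale_loss(loss, opt) as scaled_loss:
    scaled_loss.backward()  # backward on the scaled loss to avoid fp16 underflow
opt.step()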
") with torch.no_grad(): - for num_iter, data in enumerate(data_loader): + for _, data in enumerate(data_loader): # format data - text_input, text_lengths, mel_input, mel_lengths, speaker_ids,\ - avg_text_length, avg_spec_length, attn_mask = format_data(data) + text_input, text_lengths, mel_input, mel_lengths, _,\ + _, _, attn_mask = format_data(data) # forward pass model _ = model.forward( @@ -165,7 +162,7 @@ def train(model, criterion, optimizer, scheduler, start_time = time.time() # format data - text_input, text_lengths, mel_input, mel_lengths, speaker_ids,\ + text_input, text_lengths, mel_input, mel_lengths, _,\ avg_text_length, avg_spec_length, attn_mask = format_data(data) loader_time = time.time() - end_time @@ -187,7 +184,7 @@ def train(model, criterion, optimizer, scheduler, # backward pass if amp is not None: - with amp.scale_loss( loss_dict['loss'], optimizer) as scaled_loss: + with amp.scale_loss(loss_dict['loss'], optimizer) as scaled_loss: scaled_loss.backward() else: loss_dict['loss'].backward() @@ -312,8 +309,8 @@ def evaluate(model, criterion, ap, global_step, epoch): start_time = time.time() # format data - text_input, text_lengths, mel_input, mel_lengths, speaker_ids,\ - avg_text_length, avg_spec_length, attn_mask = format_data(data) + text_input, text_lengths, mel_input, mel_lengths, _,\ + _, _, attn_mask = format_data(data) # forward pass model z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model.forward( @@ -321,7 +318,7 @@ def evaluate(model, criterion, ap, global_step, epoch): # compute loss loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths, - o_dur_log, o_total_dur, text_lengths) + o_dur_log, o_total_dur, text_lengths) # step time step_time = time.time() - start_time @@ -405,7 +402,7 @@ def evaluate(model, criterion, ap, global_step, epoch): style_wav = c.get("style_wav_for_test") for idx, test_sentence in enumerate(test_sentences): try: - wav, alignment, decoder_output, postnet_output, stop_tokens, inputs = synthesis( + wav, alignment, _, postnet_output, _, _ = synthesis( model, test_sentence, c, @@ -428,7 +425,7 @@ def evaluate(model, criterion, ap, global_step, epoch): postnet_output, ap) test_figures['{}-alignment'.format(idx)] = plot_alignment( alignment) - except: + except: #pylint: disable=bare-except print(" !! Error creating Test Sentence -", idx) traceback.print_exc() tb_logger.tb_test_audios(global_step, test_audios, @@ -503,7 +500,7 @@ def main(args): # pylint: disable=redefined-outer-name if c.reinit_layers: raise RuntimeError model.load_state_dict(checkpoint['model']) - except: + except: #pylint: disable=bare-except print(" > Partial model initialization.") model_dict = model.state_dict() model_dict = set_init_dict(model_dict, checkpoint['model'], c) diff --git a/TTS/tts/datasets/TTSDataset.py b/TTS/tts/datasets/TTSDataset.py index a92b880f..ab8f3f88 100644 --- a/TTS/tts/datasets/TTSDataset.py +++ b/TTS/tts/datasets/TTSDataset.py @@ -100,13 +100,13 @@ class MyDataset(Dataset): try: phonemes = np.load(cache_path) except FileNotFoundError: - phonemes = self._generate_and_cache_phoneme_sequence(text, - cache_path) + phonemes = self._generate_and_cache_phoneme_sequence( + text, cache_path) except (ValueError, IOError): print(" > ERROR: failed loading phonemes for {}. 
" "Recomputing.".format(wav_file)) - phonemes = self._generate_and_cache_phoneme_sequence(text, - cache_path) + phonemes = self._generate_and_cache_phoneme_sequence( + text, cache_path) if self.enable_eos_bos: phonemes = pad_with_eos_bos(phonemes, tp=self.tp) phonemes = np.asarray(phonemes, dtype=np.int32) @@ -116,18 +116,19 @@ class MyDataset(Dataset): item = self.items[idx] if len(item) == 4: - text, wav_file, speaker_name, attn_file = item + text, wav_file, speaker_name, attn_file = item else: - text, wav_file, speaker_name = item - attn = None + text, wav_file, speaker_name = item + attn = None wav = np.asarray(self.load_wav(wav_file), dtype=np.float32) if self.use_phonemes: text = self._load_or_generate_phoneme_sequence(wav_file, text) else: - text = np.asarray( - text_to_sequence(text, [self.cleaners], tp=self.tp), dtype=np.int32) + text = np.asarray(text_to_sequence(text, [self.cleaners], + tp=self.tp), + dtype=np.int32) assert text.size > 0, self.items[idx][1] assert wav.size > 0, self.items[idx][1] @@ -172,8 +173,9 @@ class MyDataset(Dataset): print(" | > Max length sequence: {}".format(np.max(lengths))) print(" | > Min length sequence: {}".format(np.min(lengths))) print(" | > Avg length sequence: {}".format(np.mean(lengths))) - print(" | > Num. instances discarded by max-min (max={}, min={}) seq limits: {}".format( - self.max_seq_len, self.min_seq_len, len(ignored))) + print( + " | > Num. instances discarded by max-min (max={}, min={}) seq limits: {}" + .format(self.max_seq_len, self.min_seq_len, len(ignored))) print(" | > Batch group size: {}.".format(self.batch_group_size)) def __len__(self): @@ -206,12 +208,19 @@ class MyDataset(Dataset): ] text = [batch[idx]['text'] for idx in ids_sorted_decreasing] - speaker_name = [batch[idx]['speaker_name'] - for idx in ids_sorted_decreasing] + speaker_name = [ + batch[idx]['speaker_name'] for idx in ids_sorted_decreasing + ] # get speaker embeddings - if self.speaker_mapping is not None: - wav_files_names = [batch[idx]['wav_file_name'] for idx in ids_sorted_decreasing] - speaker_embedding = [self.speaker_mapping[w]['embedding'] for w in wav_files_names] + if self.speaker_mapping is not None: + wav_files_names = [ + batch[idx]['wav_file_name'] + for idx in ids_sorted_decreasing + ] + speaker_embedding = [ + self.speaker_mapping[w]['embedding'] + for w in wav_files_names + ] else: speaker_embedding = None # compute features @@ -221,7 +230,8 @@ class MyDataset(Dataset): # compute 'stop token' targets stop_targets = [ - np.array([0.] * (mel_len - 1) + [1.]) for mel_len in mel_lengths + np.array([0.] 
diff --git a/TTS/tts/layers/glow_tts/decoder.py b/TTS/tts/layers/glow_tts/decoder.py
index 43811821..67329a2a 100644
--- a/TTS/tts/layers/glow_tts/decoder.py
+++ b/TTS/tts/layers/glow_tts/decoder.py
@@ -1,8 +1,6 @@
 import torch
 from torch import nn
-from torch.nn import functional as F
 
-from TTS.tts.utils.generic_utils import sequence_mask
 from TTS.tts.layers.glow_tts.glow import InvConvNear, CouplingBlock
 from TTS.tts.layers.glow_tts.normalization import ActNorm
 
@@ -54,8 +52,7 @@ class Decoder(nn.Module):
                  num_splits=4,
                  num_sqz=2,
                  sigmoid_scale=False,
-                 c_in_channels=0,
-                 feat_channels=None):
+                 c_in_channels=0):
         super().__init__()
 
         self.in_channels = in_channels
diff --git a/TTS/tts/layers/glow_tts/encoder.py b/TTS/tts/layers/glow_tts/encoder.py
index df0e0462..c5af85ec 100644
--- a/TTS/tts/layers/glow_tts/encoder.py
+++ b/TTS/tts/layers/glow_tts/encoder.py
@@ -5,7 +5,7 @@ from torch import nn
 from TTS.tts.layers.glow_tts.transformer import Transformer
 from TTS.tts.layers.glow_tts.gated_conv import GatedConvBlock
 from TTS.tts.utils.generic_utils import sequence_mask
-from TTS.tts.layers.glow_tts.glow import ConvLayerNorm, LayerNorm
+from TTS.tts.layers.glow_tts.glow import ConvLayerNorm
 from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
 from TTS.tts.layers.glow_tts.time_depth_sep_conv import TimeDepthSeparableConvBlock
diff --git a/TTS/tts/layers/glow_tts/gated_conv.py b/TTS/tts/layers/glow_tts/gated_conv.py
index 2417ea63..dbe0f0f0 100644
--- a/TTS/tts/layers/glow_tts/gated_conv.py
+++ b/TTS/tts/layers/glow_tts/gated_conv.py
@@ -1,4 +1,3 @@
-import torch
 from torch import nn
 
 from .normalization import LayerNorm
diff --git a/TTS/tts/layers/glow_tts/monotonic_align/setup.py b/TTS/tts/layers/glow_tts/monotonic_align/setup.py
index 30c22480..1d669ea0 100644
--- a/TTS/tts/layers/glow_tts/monotonic_align/setup.py
+++ b/TTS/tts/layers/glow_tts/monotonic_align/setup.py
@@ -2,8 +2,6 @@ from distutils.core import setup
 from Cython.Build import cythonize
 import numpy
 
-setup(
-    name = 'monotonic_align',
-    ext_modules = cythonize("core.pyx"),
-    include_dirs=[numpy.get_include()]
-)
+setup(name='monotonic_align',
+      ext_modules=cythonize("core.pyx"),
+      include_dirs=[numpy.get_include()])
diff --git a/TTS/tts/layers/glow_tts/normalization.py b/TTS/tts/layers/glow_tts/normalization.py
index 70444abc..0930f48c 100644
--- a/TTS/tts/layers/glow_tts/normalization.py
+++ b/TTS/tts/layers/glow_tts/normalization.py
@@ -31,11 +31,19 @@ class LayerNorm(nn.Module):
 
 class TemporalBatchNorm1d(nn.BatchNorm1d):
     """Normalize each channel separately over time and batch.
     """
-    def __init__(self, channels, affine=True, track_running_stats=True, momentum=0.1):
-        super(TemporalBatchNorm1d, self).__init__(channels, affine=affine, track_running_stats=track_running_stats, momentum=momentum)
+    def __init__(self,
+                 channels,
+                 affine=True,
+                 track_running_stats=True,
+                 momentum=0.1):
+        super(TemporalBatchNorm1d,
+              self).__init__(channels,
+                             affine=affine,
+                             track_running_stats=track_running_stats,
+                             momentum=momentum)
 
     def forward(self, x):
-        return super().forward(x.transpose(2,1)).transpose(2,1)
+        return super().forward(x.transpose(2, 1)).transpose(2, 1)
 
 
 class ActNorm(nn.Module):
@@ -51,7 +59,6 @@ class ActNorm(nn.Module):
         - inputs: (B, C, T)
         - outputs: (B, C, T)
     """
-
     def __init__(self, channels, ddi=False, **kwargs):  # pylint: disable=unused-argument
         super().__init__()
         self.channels = channels
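NOTE: `TemporalBatchNorm1d` above takes (batch, time, channels) input and bridges to `nn.BatchNorm1d`, which expects (batch, channels, time), by transposing in and transposing back. A small self-contained check of that layout trick, with arbitrary sizes:

# Layout bridge used by TemporalBatchNorm1d.forward() above.
import torch
from torch import nn

bn = nn.BatchNorm1d(8)                     # 8 channels
x = torch.randn(4, 100, 8)                 # (batch, time, channels)
y = bn(x.transpose(2, 1)).transpose(2, 1)  # normalize each channel over batch and time
assert y.shape == x.shape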
""" - def __init__(self, channels, affine=True, track_running_stats=True, momentum=0.1): - super(TemporalBatchNorm1d, self).__init__(channels, affine=affine, track_running_stats=track_running_stats, momentum=momentum) + def __init__(self, + channels, + affine=True, + track_running_stats=True, + momentum=0.1): + super(TemporalBatchNorm1d, + self).__init__(channels, + affine=affine, + track_running_stats=track_running_stats, + momentum=momentum) def forward(self, x): - return super().forward(x.transpose(2,1)).transpose(2,1) + return super().forward(x.transpose(2, 1)).transpose(2, 1) class ActNorm(nn.Module): @@ -51,7 +59,6 @@ class ActNorm(nn.Module): - inputs: (B, C, T) - outputs: (B, C, T) """ - def __init__(self, channels, ddi=False, **kwargs): # pylint: disable=unused-argument super().__init__() self.channels = channels diff --git a/TTS/tts/layers/glow_tts/time_depth_sep_conv.py b/TTS/tts/layers/glow_tts/time_depth_sep_conv.py index 19fc7035..732e7d96 100644 --- a/TTS/tts/layers/glow_tts/time_depth_sep_conv.py +++ b/TTS/tts/layers/glow_tts/time_depth_sep_conv.py @@ -1,8 +1,6 @@ import torch from torch import nn -from .normalization import LayerNorm - class TimeDepthSeparableConv(nn.Module): """Time depth separable convolution as in https://arxiv.org/pdf/1904.02619.pdf diff --git a/TTS/tts/layers/glow_tts/transformer.py b/TTS/tts/layers/glow_tts/transformer.py index 5cccea19..4b1c88a7 100644 --- a/TTS/tts/layers/glow_tts/transformer.py +++ b/TTS/tts/layers/glow_tts/transformer.py @@ -1,6 +1,4 @@ -import copy import math -import numpy as np import torch from torch import nn from torch.nn import functional as F @@ -106,7 +104,7 @@ class RelativePositionMultiHeadAttention(nn.Module): scores = scores.masked_fill(mask == 0, -1e4) if self.input_length is not None: block_mask = torch.ones_like(scores).triu( - -self.input_length).tril(self.input_length) + -1 * self.input_length).tril(self.input_length) scores = scores * block_mask + -1e4 * (1 - block_mask) # attention score normalization p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s] @@ -126,7 +124,8 @@ class RelativePositionMultiHeadAttention(nn.Module): b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t] return output, p_attn - def _matmul_with_relative_values(self, p_attn, re): + @staticmethod + def _matmul_with_relative_values(p_attn, re): """ Args: p_attn (Tensor): attention weights. diff --git a/TTS/tts/models/glow_tts.py b/TTS/tts/models/glow_tts.py index 50f08c93..902de699 100644 --- a/TTS/tts/models/glow_tts.py +++ b/TTS/tts/models/glow_tts.py @@ -65,8 +65,8 @@ class GlowTts(nn.Module): self.hidden_channels_enc = hidden_channels_enc self.hidden_channels_dec = hidden_channels_dec self.use_encoder_prenet = use_encoder_prenet - self.noise_scale=0.66 - self.length_scale=1. + self.noise_scale = 0.66 + self.length_scale = 1. 
diff --git a/TTS/tts/utils/synthesis.py b/TTS/tts/utils/synthesis.py
index 48083a2a..f810e213 100644
--- a/TTS/tts/utils/synthesis.py
+++ b/TTS/tts/utils/synthesis.py
@@ -58,7 +58,7 @@ def run_model_torch(model, inputs, CONFIG, truncated, speaker_id=None, style_mel
             decoder_output, postnet_output, alignments, stop_tokens = model.inference(
                 inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings)
     elif 'glow' in CONFIG.model.lower():
-        inputs_lengths = torch.tensor(inputs.shape[1:2]).to(inputs.device)
+        inputs_lengths = torch.tensor(inputs.shape[1:2]).to(inputs.device)  # pylint: disable=not-callable
         postnet_output, _, _, _, alignments, _, _ = model.inference(inputs, inputs_lengths)
         postnet_output = postnet_output.permute(0, 2, 1)
         # these only belong to tacotron models.
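NOTE: the glow branch above derives the lengths tensor from the padded input itself. Slicing `shape[1:2]` (rather than indexing `shape[1]`) keeps a one-element sequence, so `torch.tensor` yields a rank-1 tensor. A tiny illustration with a hypothetical input:

import torch

inputs = torch.zeros(1, 42, dtype=torch.long)  # (batch=1, T) padded character ids
lengths = torch.tensor(inputs.shape[1:2])      # tensor([42]), shape (1,)
assert lengths.dim() == 1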
diff --git a/TTS/utils/io.py b/TTS/utils/io.py
index 07ec63a0..d7acaa5d 100644
--- a/TTS/utils/io.py
+++ b/TTS/utils/io.py
@@ -2,13 +2,15 @@ import re
 import json
 import pickle as pickle_tts
 
+
 class RenamingUnpickler(pickle_tts.Unpickler):
     """Overload default pickler to solve module renaming problem"""
     def find_class(self, module, name):
-        if 'TTS' in module :
+        if 'TTS' in module:
             module = module.replace('TTS', 'TTS')
         return super().find_class(module, name)
 
+
 class AttrDict(dict):
     """A custom dict which converts dict keys to class attributes"""
diff --git a/setup.py b/setup.py
index 3126cf6d..a31380f0 100644
--- a/setup.py
+++ b/setup.py
@@ -16,20 +16,24 @@ try:
     from Cython.Build import cythonize
 except ImportError:
     # create closure for deferred import
-    def cythonize (*args, ** kwargs ):
-        from Cython.Build import cythonize
-        return cythonize(*args, ** kwargs)
+    def cythonize(*args, **kwargs):  #pylint: disable=redefined-outer-name
+        from Cython.Build import cythonize  #pylint: disable=redefined-outer-name, import-outside-toplevel
+        return cythonize(*args, **kwargs)
 
 parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False)
-parser.add_argument('--checkpoint', type=str, help='Path to checkpoint file to embed in wheel.')
-parser.add_argument('--model_config', type=str, help='Path to model configuration file to embed in wheel.')
+parser.add_argument('--checkpoint',
+                    type=str,
+                    help='Path to checkpoint file to embed in wheel.')
+parser.add_argument('--model_config',
+                    type=str,
+                    help='Path to model configuration file to embed in wheel.')
 args, unknown_args = parser.parse_known_args()
 
 # Remove our arguments from argv so that setuptools doesn't see them
 sys.argv = [sys.argv[0]] + unknown_args
 
-version = '0.0.4'
+version = '0.0.5'
 # Adapted from https://github.com/pytorch/pytorch
 cwd = os.path.dirname(os.path.abspath(__file__))
@@ -37,8 +41,8 @@ if os.getenv('TTS_PYTORCH_BUILD_VERSION'):
     version = os.getenv('TTS_PYTORCH_BUILD_VERSION')
 else:
     try:
-        sha = subprocess.check_output(
-            ['git', 'rev-parse', 'HEAD'], cwd=cwd).decode('ascii').strip()
+        sha = subprocess.check_output(['git', 'rev-parse', 'HEAD'],
+                                      cwd=cwd).decode('ascii').strip()
         version += '+' + sha[:7]
     except subprocess.CalledProcessError:
         pass
@@ -49,7 +53,7 @@ else:
 # Handle Cython code
 def find_pyx(path='.'):
     pyx_files = []
-    for root, dirs, filenames in os.walk(path):
+    for root, _, filenames in os.walk(path):
         for fname in filenames:
             if fname.endswith('.pyx'):
                 pyx_files.append(os.path.join(root, fname))
@@ -91,20 +95,14 @@ if 'bdist_wheel' in unknown_args and args.checkpoint and args.model_config:
 
 
 def pip_install(package_name):
-    subprocess.call(
-        [sys.executable, '-m', 'pip', 'install', package_name]
-    )
+    subprocess.call([sys.executable, '-m', 'pip', 'install', package_name])
 
 
 reqs_from_file = open('requirements.txt').readlines()
 reqs_without_tf = [r for r in reqs_from_file if not r.startswith('tensorflow')]
 tf_req = [r for r in reqs_from_file if r.startswith('tensorflow')]
 
-requirements = {
-    'install_requires': reqs_without_tf,
-    'pip_install': tf_req
-}
-
+requirements = {'install_requires': reqs_without_tf, 'pip_install': tf_req}
 
 setup(
     name='TTS',
@@ -114,11 +112,7 @@ setup(
     author_email='egolge@mozilla.com',
     description='Text to Speech with Deep Learning',
     license='MPL-2.0',
-    entry_points={
-        'console_scripts': [
-            'tts-server = TTS.server.server:main'
-        ]
-    },
+    entry_points={'console_scripts': ['tts-server = TTS.server.server:main']},
     include_dirs=[numpy.get_include()],
     ext_modules=cythonize(find_pyx(), language_level=3),
     packages=find_packages(include=['TTS*']),
@@ -145,8 +139,7 @@ setup(
         "Operating System :: POSIX :: Linux",
         'License :: OSI Approved :: Mozilla Public License 2.0 (MPL 2.0)',
         "Topic :: Software Development :: Libraries :: Python Modules :: Speech :: Sound/Audio :: Multimedia :: Artificial Intelligence",
-    ]
-)
+    ])
 
 # for some reason having tensorflow in 'install_requires'
 # breaks some of the dependencies.
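NOTE: the version tagging kept above follows PyTorch's scheme of appending a short git SHA to the base version. A standalone equivalent of that logic, assuming it runs inside a git checkout:

# Produces e.g. '0.0.5+<7-char sha>' inside a git repo, mirroring the setup.py logic.
import subprocess

version = '0.0.5'
try:
    sha = subprocess.check_output(['git', 'rev-parse', 'HEAD']).decode('ascii').strip()
    version += '+' + sha[:7]
except subprocess.CalledProcessError:
    pass  # not a git checkout; keep the plain version
print(version)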
diff --git a/tests/test_glow_tts.py b/tests/test_glow_tts.py
index 6f3cdb81..f625e943 100644
--- a/tests/test_glow_tts.py
+++ b/tests/test_glow_tts.py
@@ -4,7 +4,7 @@ import unittest
 
 import torch
 from tests import get_tests_input_path
-from torch import nn, optim
+from torch import optim
 
 from TTS.tts.layers.losses import GlowTTSLoss
 from TTS.tts.models.glow_tts import GlowTts
@@ -42,64 +42,60 @@ class GlowTTSTrainTest(unittest.TestCase):
         criterion = criterion = GlowTTSLoss()
 
         # model to train
-        model = GlowTts(
-            num_chars=32,
-            hidden_channels=128,
-            filter_channels=32,
-            filter_channels_dp=32,
-            out_channels=80,
-            kernel_size=3,
-            num_heads=2,
-            num_layers_enc=6,
-            dropout_p=0.1,
-            num_flow_blocks_dec=12,
-            kernel_size_dec=5,
-            dilation_rate=5,
-            num_block_layers=4,
-            dropout_p_dec=0.,
-            num_speakers=0,
-            c_in_channels=0,
-            num_splits=4,
-            num_sqz=1,
-            sigmoid_scale=False,
-            rel_attn_window_size=None,
-            input_length=None,
-            mean_only=False,
-            hidden_channels_enc=None,
-            hidden_channels_dec=None,
-            use_encoder_prenet=False,
-            encoder_type="transformer"
-        ).to(device)
+        model = GlowTts(num_chars=32,
+                        hidden_channels=128,
+                        filter_channels=32,
+                        filter_channels_dp=32,
+                        out_channels=80,
+                        kernel_size=3,
+                        num_heads=2,
+                        num_layers_enc=6,
+                        dropout_p=0.1,
+                        num_flow_blocks_dec=12,
+                        kernel_size_dec=5,
+                        dilation_rate=5,
+                        num_block_layers=4,
+                        dropout_p_dec=0.,
+                        num_speakers=0,
+                        c_in_channels=0,
+                        num_splits=4,
+                        num_sqz=1,
+                        sigmoid_scale=False,
+                        rel_attn_window_size=None,
+                        input_length=None,
+                        mean_only=False,
+                        hidden_channels_enc=None,
+                        hidden_channels_dec=None,
+                        use_encoder_prenet=False,
+                        encoder_type="transformer").to(device)
 
         # reference model to compare model weights
-        model_ref = GlowTts(
-            num_chars=32,
-            hidden_channels=128,
-            filter_channels=32,
-            filter_channels_dp=32,
-            out_channels=80,
-            kernel_size=3,
-            num_heads=2,
-            num_layers_enc=6,
-            dropout_p=0.1,
-            num_flow_blocks_dec=12,
-            kernel_size_dec=5,
-            dilation_rate=5,
-            num_block_layers=4,
-            dropout_p_dec=0.,
-            num_speakers=0,
-            c_in_channels=0,
-            num_splits=4,
-            num_sqz=1,
-            sigmoid_scale=False,
-            rel_attn_window_size=None,
-            input_length=None,
-            mean_only=False,
-            hidden_channels_enc=None,
-            hidden_channels_dec=None,
-            use_encoder_prenet=False,
-            encoder_type="transformer"
-        ).to(device)
+        model_ref = GlowTts(num_chars=32,
+                            hidden_channels=128,
+                            filter_channels=32,
+                            filter_channels_dp=32,
+                            out_channels=80,
+                            kernel_size=3,
+                            num_heads=2,
+                            num_layers_enc=6,
+                            dropout_p=0.1,
+                            num_flow_blocks_dec=12,
+                            kernel_size_dec=5,
+                            dilation_rate=5,
+                            num_block_layers=4,
+                            dropout_p_dec=0.,
+                            num_speakers=0,
+                            c_in_channels=0,
+                            num_splits=4,
+                            num_sqz=1,
+                            sigmoid_scale=False,
+                            rel_attn_window_size=None,
+                            input_length=None,
+                            mean_only=False,
+                            hidden_channels_enc=None,
+                            hidden_channels_dec=None,
+                            use_encoder_prenet=False,
+                            encoder_type="transformer").to(device)
 
         model.train()
         print(" > Num parameters for GlowTTS model:%s" %
@@ -120,7 +116,7 @@ class GlowTTSTrainTest(unittest.TestCase):
                 input_dummy, input_lengths, mel_spec, mel_lengths, None)
             optimizer.zero_grad()
             loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths,
-                                o_dur_log, o_total_dur, input_lengths)
+                                  o_dur_log, o_total_dur, input_lengths)
             loss = loss_dict['loss']
             loss.backward()
             optimizer.step()
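NOTE: the test above pairs the trained model with an untouched `model_ref` so it can assert, after a few optimizer steps, that parameters actually moved. A hedged sketch of that pattern with a toy module (not the repo's code):

import torch
from torch import nn, optim

model = nn.Linear(4, 4)
model_ref = nn.Linear(4, 4)
model_ref.load_state_dict(model.state_dict())  # identical starting weights

opt = optim.Adam(model.parameters(), lr=1e-3)
for _ in range(3):
    loss = model(torch.randn(8, 4)).pow(2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()

# every parameter should differ from the untouched reference copy
for p, p_ref in zip(model.parameters(), model_ref.parameters()):
    assert not torch.equal(p, p_ref), "a parameter was not updated"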