diff --git a/mozilla_voice_tts/bin/train_encoder.py b/mozilla_voice_tts/bin/train_encoder.py index c89469b9..f9bfea7f 100644 --- a/mozilla_voice_tts/bin/train_encoder.py +++ b/mozilla_voice_tts/bin/train_encoder.py @@ -10,21 +10,21 @@ import traceback import torch from torch.utils.data import DataLoader -from mozilla_voice_tts.generic_utils import count_parameters from mozilla_voice_tts.speaker_encoder.dataset import MyDataset from mozilla_voice_tts.speaker_encoder.generic_utils import save_best_model -from mozilla_voice_tts.speaker_encoder.losses import GE2ELoss +from mozilla_voice_tts.speaker_encoder.losses import GE2ELoss, AngleProtoLoss from mozilla_voice_tts.speaker_encoder.model import SpeakerEncoder from mozilla_voice_tts.speaker_encoder.visual import plot_embeddings from mozilla_voice_tts.tts.datasets.preprocess import load_meta_data -from mozilla_voice_tts.tts.utils.audio import AudioProcessor from mozilla_voice_tts.tts.utils.generic_utils import ( create_experiment_folder, get_git_branch, remove_experiment_folder, set_init_dict) from mozilla_voice_tts.tts.utils.io import copy_config_file, load_config -from mozilla_voice_tts.tts.utils.radam import RAdam -from mozilla_voice_tts.tts.utils.tensorboard_logger import TensorboardLogger -from mozilla_voice_tts.tts.utils.training import NoamLR, check_update +from mozilla_voice_tts.utils.audio import AudioProcessor +from mozilla_voice_tts.utils.generic_utils import count_parameters +from mozilla_voice_tts.utils.radam import RAdam +from mozilla_voice_tts.utils.tensorboard_logger import TensorboardLogger +from mozilla_voice_tts.utils.training import NoamLR, check_update torch.backends.cudnn.enabled = True torch.backends.cudnn.benchmark = True @@ -146,7 +146,7 @@ def main(args): # pylint: disable=redefined-outer-name elif c.loss == "angleproto": criterion = AngleProtoLoss() else: - raise Exception("The %s not is a loss supported" %c.loss) + raise Exception("The %s not is a loss supported" % c.loss) if args.restore_path: checkpoint = torch.load(args.restore_path) @@ -192,6 +192,7 @@ def main(args): # pylint: disable=redefined-outer-name _, global_step = train(model, criterion, optimizer, scheduler, ap, global_step) + if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument( diff --git a/mozilla_voice_tts/speaker_encoder/loss.py b/mozilla_voice_tts/speaker_encoder/loss.py deleted file mode 100644 index 6f83be63..00000000 --- a/mozilla_voice_tts/speaker_encoder/loss.py +++ /dev/null @@ -1,163 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F - - -# adapted from https://github.com/cvqluu/GE2E-Loss -class GE2ELoss(nn.Module): - def __init__(self, init_w=10.0, init_b=-5.0, loss_method="softmax"): - """ - Implementation of the Generalized End-to-End loss defined in https://arxiv.org/abs/1710.10467 [1] - Accepts an input of size (N, M, D) - where N is the number of speakers in the batch, - M is the number of utterances per speaker, - and D is the dimensionality of the embedding vector (e.g. 
d-vector) - Args: - - init_w (float): defines the initial value of w in Equation (5) of [1] - - init_b (float): definies the initial value of b in Equation (5) of [1] - """ - super(GE2ELoss, self).__init__() - # pylint: disable=E1102 - self.w = nn.Parameter(torch.tensor(init_w)) - # pylint: disable=E1102 - self.b = nn.Parameter(torch.tensor(init_b)) - self.loss_method = loss_method - - print('Initialised Generalized End-to-End loss') - - assert self.loss_method in ["softmax", "contrast"] - - if self.loss_method == "softmax": - self.embed_loss = self.embed_loss_softmax - if self.loss_method == "contrast": - self.embed_loss = self.embed_loss_contrast - - # pylint: disable=R0201 - def calc_new_centroids(self, dvecs, centroids, spkr, utt): - """ - Calculates the new centroids excluding the reference utterance - """ - excl = torch.cat((dvecs[spkr, :utt], dvecs[spkr, utt + 1 :])) - excl = torch.mean(excl, 0) - new_centroids = [] - for i, centroid in enumerate(centroids): - if i == spkr: - new_centroids.append(excl) - else: - new_centroids.append(centroid) - return torch.stack(new_centroids) - - def calc_cosine_sim(self, dvecs, centroids): - """ - Make the cosine similarity matrix with dims (N,M,N) - """ - cos_sim_matrix = [] - for spkr_idx, speaker in enumerate(dvecs): - cs_row = [] - for utt_idx, utterance in enumerate(speaker): - new_centroids = self.calc_new_centroids( - dvecs, centroids, spkr_idx, utt_idx - ) - # vector based cosine similarity for speed - cs_row.append( - torch.clamp( - torch.mm( - utterance.unsqueeze(1).transpose(0, 1), - new_centroids.transpose(0, 1), - ) - / (torch.norm(utterance) * torch.norm(new_centroids, dim=1)), - 1e-6, - ) - ) - cs_row = torch.cat(cs_row, dim=0) - cos_sim_matrix.append(cs_row) - return torch.stack(cos_sim_matrix) - - # pylint: disable=R0201 - def embed_loss_softmax(self, dvecs, cos_sim_matrix): - """ - Calculates the loss on each embedding $L(e_{ji})$ by taking softmax - """ - N, M, _ = dvecs.shape - L = [] - for j in range(N): - L_row = [] - for i in range(M): - L_row.append(-F.log_softmax(cos_sim_matrix[j, i], 0)[j]) - L_row = torch.stack(L_row) - L.append(L_row) - return torch.stack(L) - - # pylint: disable=R0201 - def embed_loss_contrast(self, dvecs, cos_sim_matrix): - """ - Calculates the loss on each embedding $L(e_{ji})$ by contrast loss with closest centroid - """ - N, M, _ = dvecs.shape - L = [] - for j in range(N): - L_row = [] - for i in range(M): - centroids_sigmoids = torch.sigmoid(cos_sim_matrix[j, i]) - excl_centroids_sigmoids = torch.cat( - (centroids_sigmoids[:j], centroids_sigmoids[j + 1 :]) - ) - L_row.append( - 1.0 - - torch.sigmoid(cos_sim_matrix[j, i, j]) - + torch.max(excl_centroids_sigmoids) - ) - L_row = torch.stack(L_row) - L.append(L_row) - return torch.stack(L) - - def forward(self, dvecs): - """ - Calculates the GE2E loss for an input of dimensions (num_speakers, num_utts_per_speaker, dvec_feats) - """ - centroids = torch.mean(dvecs, 1) - cos_sim_matrix = self.calc_cosine_sim(dvecs, centroids) - torch.clamp(self.w, 1e-6) - cos_sim_matrix = self.w * cos_sim_matrix + self.b - L = self.embed_loss(dvecs, cos_sim_matrix) - return L.mean() - -# adapted from https://github.com/clovaai/voxceleb_trainer/blob/master/loss/angleproto.py -class AngleProtoLoss(nn.Module): - """ - Implementation of the Angular Prototypical loss defined in https://arxiv.org/abs/2003.11982 - Accepts an input of size (N, M, D) - where N is the number of speakers in the batch, - M is the number of utterances per speaker, - and D is the dimensionality of 
the embedding vector - Args: - - init_w (float): defines the initial value of w - - init_b (float): definies the initial value of b - """ - def __init__(self, init_w=10.0, init_b=-5.0): - super(AngleProtoLoss, self).__init__() - # pylint: disable=E1102 - self.w = nn.Parameter(torch.tensor(init_w)) - # pylint: disable=E1102 - self.b = nn.Parameter(torch.tensor(init_b)) - self.criterion = torch.nn.CrossEntropyLoss() - self.use_cuda = torch.cuda.is_available() - - print('Initialised Angular Prototypical loss') - - def forward(self, x): - """ - Calculates the AngleProto loss for an input of dimensions (num_speakers, num_utts_per_speaker, dvec_feats) - """ - out_anchor = torch.mean(x[:,1:,:],1) - out_positive = x[:,0,:] - num_speakers = out_anchor.size()[0] - - cos_sim_matrix = F.cosine_similarity(out_positive.unsqueeze(-1).expand(-1,-1,num_speakers),out_anchor.unsqueeze(-1).expand(-1,-1,num_speakers).transpose(0,2)) - torch.clamp(self.w, 1e-6) - cos_sim_matrix = cos_sim_matrix * self.w + self.b - label = torch.from_numpy(np.asarray(range(0,num_speakers))) - if self.use_cuda: - label = label.cuda() - L = self.criterion(cos_sim_matrix, label) - return L \ No newline at end of file diff --git a/mozilla_voice_tts/speaker_encoder/losses.py b/mozilla_voice_tts/speaker_encoder/losses.py index 750648e5..35ff73fa 100644 --- a/mozilla_voice_tts/speaker_encoder/losses.py +++ b/mozilla_voice_tts/speaker_encoder/losses.py @@ -157,4 +157,4 @@ class AngleProtoLoss(nn.Module): cos_sim_matrix = cos_sim_matrix * self.w + self.b label = torch.from_numpy(np.asarray(range(0, num_speakers))).to(cos_sim_matrix.device) L = self.criterion(cos_sim_matrix, label) - return L \ No newline at end of file + return L diff --git a/mozilla_voice_tts/tts/datasets/preprocess.py b/mozilla_voice_tts/tts/datasets/preprocess.py index 7865652a..ece3bcb6 100644 --- a/mozilla_voice_tts/tts/datasets/preprocess.py +++ b/mozilla_voice_tts/tts/datasets/preprocess.py @@ -93,9 +93,10 @@ def mozilla_de(root_path, meta_file): def mailabs(root_path, meta_files=None): """Normalizes M-AI-Labs meta data files to TTS format""" - speaker_regex = re.compile("by_book/(male|female)/(?P[^/]+)/") + speaker_regex = re.compile( + "by_book/(male|female)/(?P[^/]+)/") if meta_files is None: - csv_files = glob(root_path+"/**/metadata.csv", recursive=True) + csv_files = glob(root_path + "/**/metadata.csv", recursive=True) else: csv_files = meta_files # meta_files = [f.strip() for f in meta_files.split(",")] @@ -115,12 +116,15 @@ def mailabs(root_path, meta_files=None): if meta_files is None: wav_file = os.path.join(folder, 'wavs', cols[0] + '.wav') else: - wav_file = os.path.join(root_path, folder.replace("metadata.csv", ""), 'wavs', cols[0] + '.wav') + wav_file = os.path.join(root_path, + folder.replace("metadata.csv", ""), + 'wavs', cols[0] + '.wav') if os.path.isfile(wav_file): text = cols[1].strip() items.append([text, wav_file, speaker_name]) else: - raise RuntimeError("> File %s does not exist!"%(wav_file)) + raise RuntimeError("> File %s does not exist!" % + (wav_file)) return items @@ -185,7 +189,8 @@ def libri_tts(root_path, meta_files=None): text = cols[1] items.append([text, wav_file, speaker_name]) for item in items: - assert os.path.exists(item[1]), f" [!] wav files don't exist - {item[1]}" + assert os.path.exists( + item[1]), f" [!] 
wav files don't exist - {item[1]}" return items @@ -197,7 +202,8 @@ def custom_turkish(root_path, meta_file): with open(txt_file, 'r', encoding='utf-8') as ttf: for line in ttf: cols = line.split('|') - wav_file = os.path.join(root_path, 'wavs', cols[0].strip() + '.wav') + wav_file = os.path.join(root_path, 'wavs', + cols[0].strip() + '.wav') if not os.path.exists(wav_file): skipped_files.append(wav_file) continue @@ -206,6 +212,7 @@ def custom_turkish(root_path, meta_file): print(f" [!] {len(skipped_files)} files skipped. They don't exist...") return items + # ToDo: add the dataset link when the dataset is released publicly def brspeech(root_path, meta_file): '''BRSpeech 3.0 beta''' @@ -223,20 +230,25 @@ def brspeech(root_path, meta_file): items.append([text, wav_file, speaker_name]) return items + def vctk(root_path, meta_files=None, wavs_path='wav48'): """homepages.inf.ed.ac.uk/jyamagis/release/VCTK-Corpus.tar.gz""" test_speakers = meta_files items = [] - meta_files = glob(f"{os.path.join(root_path,'txt')}/**/*.txt", recursive=True) + meta_files = glob(f"{os.path.join(root_path,'txt')}/**/*.txt", + recursive=True) for meta_file in meta_files: - _, speaker_id, txt_file = os.path.relpath(meta_file, root_path).split(os.sep) + _, speaker_id, txt_file = os.path.relpath(meta_file, + root_path).split(os.sep) file_id = txt_file.split('.')[0] - if isinstance(test_speakers, list): # if is list ignore this speakers ids + if isinstance(test_speakers, + list): # if is list ignore this speakers ids if speaker_id in test_speakers: continue with open(meta_file) as file_text: text = file_text.readlines()[0] - wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id+'.wav') + wav_file = os.path.join(root_path, wavs_path, speaker_id, + file_id + '.wav') items.append([text, wav_file, speaker_id]) return items \ No newline at end of file diff --git a/mozilla_voice_tts/tts/models/tacotron.py b/mozilla_voice_tts/tts/models/tacotron.py index ac88133b..1dcf2fc8 100644 --- a/mozilla_voice_tts/tts/models/tacotron.py +++ b/mozilla_voice_tts/tts/models/tacotron.py @@ -6,6 +6,7 @@ from mozilla_voice_tts.tts.layers.gst_layers import GST from mozilla_voice_tts.tts.layers.tacotron import Decoder, Encoder, PostCBHG from mozilla_voice_tts.tts.models.tacotron_abstract import TacotronAbstract + class Tacotron(TacotronAbstract): def __init__(self, num_chars, @@ -42,8 +43,8 @@ class Tacotron(TacotronAbstract): forward_attn, trans_agent, forward_attn_mask, location_attn, attn_K, separate_stopnet, bidirectional_decoder, double_decoder_consistency, - ddc_r, encoder_in_features, decoder_in_features, - speaker_embedding_dim, gst, gst_embedding_dim, + ddc_r, encoder_in_features, decoder_in_features, + speaker_embedding_dim, gst, gst_embedding_dim, gst_num_heads, gst_style_tokens) # speaker embedding layers diff --git a/mozilla_voice_tts/tts/models/tacotron2.py b/mozilla_voice_tts/tts/models/tacotron2.py index 9fa640b0..a9ba442c 100644 --- a/mozilla_voice_tts/tts/models/tacotron2.py +++ b/mozilla_voice_tts/tts/models/tacotron2.py @@ -1,15 +1,9 @@ import torch from torch import nn -<<<<<<< HEAD:mozilla_voice_tts/tts/models/tacotron2.py from mozilla_voice_tts.tts.layers.gst_layers import GST from mozilla_voice_tts.tts.layers.tacotron2 import Decoder, Encoder, Postnet from mozilla_voice_tts.tts.models.tacotron_abstract import TacotronAbstract -======= -from TTS.tts.layers.gst_layers import GST -from TTS.tts.layers.tacotron2 import Decoder, Encoder, Postnet -from TTS.tts.models.tacotron_abstract import TacotronAbstract 
->>>>>>> bugfix in DDC now DDC work on Tacotron1:TTS/tts/models/tacotron2.py # TODO: match function arguments with tacotron class Tacotron2(TacotronAbstract): @@ -47,8 +41,8 @@ class Tacotron2(TacotronAbstract): forward_attn, trans_agent, forward_attn_mask, location_attn, attn_K, separate_stopnet, bidirectional_decoder, double_decoder_consistency, - ddc_r, encoder_in_features, decoder_in_features, - speaker_embedding_dim, gst, gst_embedding_dim, + ddc_r, encoder_in_features, decoder_in_features, + speaker_embedding_dim, gst, gst_embedding_dim, gst_num_heads, gst_style_tokens) # speaker embedding layer @@ -61,7 +55,7 @@ class Tacotron2(TacotronAbstract): # speaker and gst embeddings is concat in decoder input if self.num_speakers > 1: self.decoder_in_features += speaker_embedding_dim # add speaker embedding dim - + # embedding layer self.embedding = nn.Embedding(num_chars, 512, padding_idx=0) diff --git a/mozilla_voice_tts/tts/models/tacotron_abstract.py b/mozilla_voice_tts/tts/models/tacotron_abstract.py index 0077f3e4..d98d03b7 100644 --- a/mozilla_voice_tts/tts/models/tacotron_abstract.py +++ b/mozilla_voice_tts/tts/models/tacotron_abstract.py @@ -28,8 +28,8 @@ class TacotronAbstract(ABC, nn.Module): bidirectional_decoder=False, double_decoder_consistency=False, ddc_r=None, - encoder_in_features=512, - decoder_in_features=512, + encoder_in_features=512, + decoder_in_features=512, speaker_embedding_dim=None, gst=False, gst_embedding_dim=512, diff --git a/synthesize.py b/synthesize.py deleted file mode 100644 index bd720123..00000000 --- a/synthesize.py +++ /dev/null @@ -1,182 +0,0 @@ -# pylint: disable=redefined-outer-name, unused-argument -import os -import time -import argparse -import torch -import json -import string - -from TTS.utils.synthesis import synthesis -from TTS.utils.generic_utils import setup_model -from TTS.utils.io import load_config -from TTS.utils.text.symbols import make_symbols, symbols, phonemes -from TTS.utils.audio import AudioProcessor - - -def tts(model, - vocoder_model, - C, - VC, - text, - ap, - ap_vocoder, - use_cuda, - batched_vocoder, - speaker_id=None, - figures=False): - t_1 = time.time() - use_vocoder_model = vocoder_model is not None - waveform, alignment, _, postnet_output, stop_tokens, _ = synthesis( - model, text, C, use_cuda, ap, speaker_id, style_wav=C.gst['gst_style_input'], - truncated=False, enable_eos_bos_chars=C.enable_eos_bos_chars, - use_griffin_lim=(not use_vocoder_model), do_trim_silence=True) - - if C.model == "Tacotron" and use_vocoder_model: - postnet_output = ap.out_linear_to_mel(postnet_output.T).T - # correct if there is a scale difference b/w two models - if use_vocoder_model: - postnet_output = ap._denormalize(postnet_output) - postnet_output = ap_vocoder._normalize(postnet_output) - vocoder_input = torch.FloatTensor(postnet_output.T).unsqueeze(0) - waveform = vocoder_model.generate( - vocoder_input.cuda() if use_cuda else vocoder_input, - batched=batched_vocoder, - target=8000, - overlap=400) - print(" > Run-time: {}".format(time.time() - t_1)) - return alignment, postnet_output, stop_tokens, waveform - - -if __name__ == "__main__": - - global symbols, phonemes - - parser = argparse.ArgumentParser() - parser.add_argument('text', type=str, help='Text to generate speech.') - parser.add_argument('config_path', - type=str, - help='Path to model config file.') - parser.add_argument( - 'model_path', - type=str, - help='Path to model file.', - ) - parser.add_argument( - 'out_path', - type=str, - help='Path to save final wav file. 
Wav file will be names as the text given.', - ) - parser.add_argument('--use_cuda', - type=bool, - help='Run model on CUDA.', - default=False) - parser.add_argument( - '--vocoder_path', - type=str, - help= - 'Path to vocoder model file. If it is not defined, model uses GL as vocoder. Please make sure that you installed vocoder library before (WaveRNN).', - default="", - ) - parser.add_argument('--vocoder_config_path', - type=str, - help='Path to vocoder model config file.', - default="") - parser.add_argument( - '--batched_vocoder', - type=bool, - help="If True, vocoder model uses faster batch processing.", - default=True) - parser.add_argument('--speakers_json', - type=str, - help="JSON file for multi-speaker model.", - default="") - parser.add_argument( - '--speaker_id', - type=int, - help="target speaker_id if the model is multi-speaker.", - default=None) - args = parser.parse_args() - - if args.vocoder_path != "": - assert args.use_cuda, " [!] Enable cuda for vocoder." - from WaveRNN.models.wavernn import Model as VocoderModel - - # load the config - C = load_config(args.config_path) - C.forward_attn_mask = True - - # load the audio processor - ap = AudioProcessor(**C.audio) - - # if the vocabulary was passed, replace the default - if 'characters' in C.keys(): - symbols, phonemes = make_symbols(**C.characters) - - # load speakers - if args.speakers_json != '': - speakers = json.load(open(args.speakers_json, 'r')) - num_speakers = len(speakers) - else: - num_speakers = 0 - - # load the model - num_chars = len(phonemes) if C.use_phonemes else len(symbols) - model = setup_model(num_chars, num_speakers, C) - cp = torch.load(args.model_path) - model.load_state_dict(cp['model']) - model.eval() - if args.use_cuda: - model.cuda() - model.decoder.set_r(cp['r']) - - # load vocoder model - if args.vocoder_path != "": - VC = load_config(args.vocoder_config_path) - ap_vocoder = AudioProcessor(**VC.audio) - bits = 10 - vocoder_model = VocoderModel(rnn_dims=512, - fc_dims=512, - mode=VC.mode, - mulaw=VC.mulaw, - pad=VC.pad, - upsample_factors=VC.upsample_factors, - feat_dims=VC.audio["num_mels"], - compute_dims=128, - res_out_dims=128, - res_blocks=10, - hop_length=ap.hop_length, - sample_rate=ap.sample_rate, - use_aux_net=True, - use_upsample_net=True) - - check = torch.load(args.vocoder_path) - vocoder_model.load_state_dict(check['model']) - vocoder_model.eval() - if args.use_cuda: - vocoder_model.cuda() - else: - vocoder_model = None - VC = None - ap_vocoder = None - - # synthesize voice - print(" > Text: {}".format(args.text)) - _, _, _, wav = tts(model, - vocoder_model, - C, - VC, - args.text, - ap, - ap_vocoder, - args.use_cuda, - args.batched_vocoder, - speaker_id=args.speaker_id, - figures=False) - - # save the results - file_name = args.text.replace(" ", "_") - file_name = file_name.translate( - str.maketrans('', '', string.punctuation.replace('_', ''))) + '.wav' - out_path = os.path.join(args.out_path, file_name) - print(" > Saving output to {}".format(out_path)) - ap.save_wav(wav, out_path) diff --git a/tests/test_tacotron2_model.py b/tests/test_tacotron2_model.py index 0ff79f6e..28d39de5 100644 --- a/tests/test_tacotron2_model.py +++ b/tests/test_tacotron2_model.py @@ -76,61 +76,6 @@ class TacotronTrainTest(unittest.TestCase): count += 1 -class TacotronGSTTrainTest(unittest.TestCase): - def test_train_step(self): - input_dummy = torch.randint(0, 24, (8, 128)).long().to(device) - input_lengths = torch.randint(100, 128, (8, )).long().to(device) - input_lengths = torch.sort(input_lengths, 
descending=True)[0] - mel_spec = torch.rand(8, 30, c.audio['num_mels']).to(device) - mel_postnet_spec = torch.rand(8, 30, c.audio['num_mels']).to(device) - mel_lengths = torch.randint(20, 30, (8, )).long().to(device) - mel_lengths[0] = 30 - stop_targets = torch.zeros(8, 30, 1).float().to(device) - speaker_ids = torch.randint(0, 5, (8, )).long().to(device) - - for idx in mel_lengths: - stop_targets[:, int(idx.item()):, 0] = 1.0 - - stop_targets = stop_targets.view(input_dummy.shape[0], - stop_targets.size(1) // c.r, -1) - stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze() - - criterion = MSELossMasked(seq_len_norm=False).to(device) - criterion_st = nn.BCEWithLogitsLoss().to(device) - model = Tacotron2(num_chars=24, - gst=True, - r=c.r, - num_speakers=5).to(device) - model.train() - model_ref = copy.deepcopy(model) - count = 0 - for param, param_ref in zip(model.parameters(), - model_ref.parameters()): - assert (param - param_ref).sum() == 0, param - count += 1 - optimizer = optim.Adam(model.parameters(), lr=c.lr) - for i in range(5): - mel_out, mel_postnet_out, align, stop_tokens = model.forward( - input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids) - assert torch.sigmoid(stop_tokens).data.max() <= 1.0 - assert torch.sigmoid(stop_tokens).data.min() >= 0.0 - optimizer.zero_grad() - loss = criterion(mel_out, mel_spec, mel_lengths) - stop_loss = criterion_st(stop_tokens, stop_targets) - loss = loss + criterion(mel_postnet_out, mel_postnet_spec, mel_lengths) + stop_loss - loss.backward() - optimizer.step() - # check parameter changes - count = 0 - for param, param_ref in zip(model.parameters(), - model_ref.parameters()): - # ignore pre-higway layer since it works conditional - # if count not in [145, 59]: - assert (param != param_ref).any( - ), "param {} with shape {} not updated!! \n{}\n{}".format( - count, param.shape, param, param_ref) - count += 1 - class MultiSpeakeTacotronTrainTest(unittest.TestCase): @staticmethod def test_train_step(): @@ -185,8 +130,8 @@ class MultiSpeakeTacotronTrainTest(unittest.TestCase): count += 1 class TacotronGSTTrainTest(unittest.TestCase): - @staticmethod - def test_train_step(): + #pylint: disable=no-self-use + def test_train_step(self): # with random gst mel style input_dummy = torch.randint(0, 24, (8, 128)).long().to(device) input_lengths = torch.randint(100, 128, (8, )).long().to(device)
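# Illustrative sketch of the criteria consolidated into speaker_encoder/losses.py (the
# module the trainer now imports from instead of the removed loss.py): both GE2ELoss and
# AngleProtoLoss consume a batch of d-vectors shaped
# (num_speakers, utterances_per_speaker, embedding_dim) and return a scalar loss.
# The sizes and values below are arbitrary examples, not defaults from any config.
import torch

from mozilla_voice_tts.speaker_encoder.losses import AngleProtoLoss, GE2ELoss

dvecs = torch.randn(4, 5, 256)  # 4 speakers, 5 utterances each, 256-dim embeddings

ge2e = GE2ELoss(loss_method="softmax")  # "contrast" is the other supported method
angle_proto = AngleProtoLoss()          # utterance 0 is the positive; the mean of the rest is the anchor

print(ge2e(dvecs).item())         # mean GE2E loss over all (speaker, utterance) pairs
print(angle_proto(dvecs).item())  # cross-entropy over the scaled cosine-similarity matrix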