diff --git a/train.py b/train.py index 8f818ac2..62b8f074 100644 --- a/train.py +++ b/train.py @@ -125,11 +125,18 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler, # loss computation stop_loss = criterion_st(stop_tokens, stop_targets) - decoder_loss = criterion(decoder_output, mel_input, mel_lengths) - if c.model == "Tacotron": - postnet_loss = criterion(postnet_output, linear_input, mel_lengths) + if c.loss_masking: + decoder_loss = criterion(decoder_output, mel_input, mel_lengths) + if c.model == "Tacotron": + postnet_loss = criterion(postnet_output, linear_input, mel_lengths) + else: + postnet_loss = criterion(postnet_output, mel_input, mel_lengths) else: - postnet_loss = criterion(postnet_output, mel_input, mel_lengths) + decoder_loss = criterion(decoder_output, mel_input) + if c.model == "Tacotron": + postnet_loss = criterion(postnet_output, linear_input) + else: + postnet_loss = criterion(postnet_output, mel_input) loss = decoder_loss + postnet_loss # backpass and check the grad norm for spec losses @@ -283,11 +290,18 @@ def evaluate(model, criterion, criterion_st, ap, current_step, epoch): # loss computation stop_loss = criterion_st(stop_tokens, stop_targets) - decoder_loss = criterion(decoder_output, mel_input, mel_lengths) - if c.model == "Tacotron": - postnet_loss = criterion(postnet_output, linear_input, mel_lengths) + if c.loss_masking: + decoder_loss = criterion(decoder_output, mel_input, mel_lengths) + if c.model == "Tacotron": + postnet_loss = criterion(postnet_output, linear_input, mel_lengths) + else: + postnet_loss = criterion(postnet_output, mel_input, mel_lengths) else: - postnet_loss = criterion(postnet_output, mel_input, mel_lengths) + decoder_loss = criterion(decoder_output, mel_input) + if c.model == "Tacotron": + postnet_loss = criterion(postnet_output, linear_input) + else: + postnet_loss = criterion(postnet_output, mel_input) loss = decoder_loss + postnet_loss + stop_loss step_time = time.time() - start_time @@ -383,7 +397,10 @@ def main(args): optimizer_st = optim.Adam( model.decoder.stopnet.parameters(), lr=c.lr, weight_decay=0) - criterion = L1LossMasked() if c.model == "Tacotron" else MSELossMasked() + if c.loss_masking: + criterion = L1LossMasked() if c.model == "Tacotron" else MSELossMasked() + else: + criterion = nn.L1Loss() if c.model == "Tacotron" else nn.MSELoss() criterion_st = nn.BCEWithLogitsLoss() if args.restore_path: diff --git a/utils/generic_utils.py b/utils/generic_utils.py index b1197fc6..ef686962 100644 --- a/utils/generic_utils.py +++ b/utils/generic_utils.py @@ -56,7 +56,7 @@ def create_experiment_folder(root_path, model_name, debug): """ Create a folder with the current date and time """ date_str = datetime.datetime.now().strftime("%B-%d-%Y_%I+%M%p") # if debug: - # commit_hash = 'debug' + # commit_hash = 'debug' # else: commit_hash = get_commit_hash() output_folder = os.path.join( @@ -207,8 +207,8 @@ def sequence_mask(sequence_length, max_len=None): seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len) if sequence_length.is_cuda: seq_range_expand = seq_range_expand.cuda() - seq_length_expand = (sequence_length.unsqueeze(1) - .expand_as(seq_range_expand)) + seq_length_expand = ( + sequence_length.unsqueeze(1).expand_as(seq_range_expand)) # B x T_max return seq_range_expand < seq_length_expand @@ -239,16 +239,27 @@ def set_init_dict(model_dict, checkpoint, c): } # 4. overwrite entries in the existing state dict model_dict.update(pretrained_dict) - print(" | > {} / {} layers are restored.".format(len(pretrained_dict), len(model_dict))) + print(" | > {} / {} layers are restored.".format( + len(pretrained_dict), len(model_dict))) return model_dict def setup_model(num_chars, c): print(" > Using model: {}".format(c.model)) - MyModel = importlib.import_module('models.'+c.model.lower()) + MyModel = importlib.import_module('models.' + c.model.lower()) MyModel = getattr(MyModel, c.model) if c.model.lower() == "tacotron": - model = MyModel(num_chars=num_chars, r=c.r, attn_norm=c.attention_norm, memory_size=c.memory_size) + model = MyModel( + num_chars=num_chars, + r=c.r, + attn_norm=c.attention_norm, + memory_size=c.memory_size) elif c.model.lower() == "tacotron2": - model = MyModel(num_chars=num_chars, r=c.r, attn_norm=c.attention_norm, prenet_type=c.prenet_type, forward_attn=c.use_forward_attn) + model = MyModel( + num_chars=num_chars, + r=c.r, + attn_norm=c.attention_norm, + prenet_type=c.prenet_type, + forward_attn=c.use_forward_attn, + trans_agent=c.transition_agent) return model \ No newline at end of file