small gst config change

This commit is contained in:
SanjaESC 2020-07-13 08:51:37 +02:00 committed by thllwg
parent 69367bd2ae
commit 18007e389d
2 changed files with 3 additions and 59 deletions

View File

@ -2,7 +2,6 @@ import os
import copy import copy
import torch import torch
import unittest import unittest
import numpy as np
from torch import optim from torch import optim
from torch import nn from torch import nn
@ -21,7 +20,8 @@ c = load_config(os.path.join(file_path, 'test_config.json'))
class TacotronTrainTest(unittest.TestCase): class TacotronTrainTest(unittest.TestCase):
def test_train_step(self): @staticmethod
def test_train_step():
input_dummy = torch.randint(0, 24, (8, 128)).long().to(device) input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
input_lengths = torch.randint(100, 128, (8, )).long().to(device) input_lengths = torch.randint(100, 128, (8, )).long().to(device)
input_lengths = torch.sort(input_lengths, descending=True)[0] input_lengths = torch.sort(input_lengths, descending=True)[0]
@ -71,59 +71,3 @@ class TacotronTrainTest(unittest.TestCase):
), "param {} with shape {} not updated!! \n{}\n{}".format( ), "param {} with shape {} not updated!! \n{}\n{}".format(
count, param.shape, param, param_ref) count, param.shape, param, param_ref)
count += 1 count += 1
class TacotronGSTTrainTest(unittest.TestCase):
def test_train_step(self):
input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
input_lengths = torch.randint(100, 128, (8, )).long().to(device)
input_lengths = torch.sort(input_lengths, descending=True)[0]
mel_spec = torch.rand(8, 30, c.audio['num_mels']).to(device)
mel_postnet_spec = torch.rand(8, 30, c.audio['num_mels']).to(device)
mel_lengths = torch.randint(20, 30, (8, )).long().to(device)
mel_lengths[0] = 30
stop_targets = torch.zeros(8, 30, 1).float().to(device)
speaker_ids = torch.randint(0, 5, (8, )).long().to(device)
for idx in mel_lengths:
stop_targets[:, int(idx.item()):, 0] = 1.0
stop_targets = stop_targets.view(input_dummy.shape[0],
stop_targets.size(1) // c.r, -1)
stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze()
criterion = MSELossMasked(seq_len_norm=False).to(device)
criterion_st = nn.BCEWithLogitsLoss().to(device)
model = Tacotron2(num_chars=24,
gst=True,
r=c.r,
num_speakers=5).to(device)
model.train()
model_ref = copy.deepcopy(model)
count = 0
for param, param_ref in zip(model.parameters(),
model_ref.parameters()):
assert (param - param_ref).sum() == 0, param
count += 1
optimizer = optim.Adam(model.parameters(), lr=c.lr)
for i in range(5):
mel_out, mel_postnet_out, align, stop_tokens = model.forward(
input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids)
assert torch.sigmoid(stop_tokens).data.max() <= 1.0
assert torch.sigmoid(stop_tokens).data.min() >= 0.0
optimizer.zero_grad()
loss = criterion(mel_out, mel_spec, mel_lengths)
stop_loss = criterion_st(stop_tokens, stop_targets)
loss = loss + criterion(mel_postnet_out, mel_postnet_spec, mel_lengths) + stop_loss
loss.backward()
optimizer.step()
# check parameter changes
count = 0
for param, param_ref in zip(model.parameters(),
model_ref.parameters()):
# ignore pre-higway layer since it works conditional
# if count not in [145, 59]:
assert (param != param_ref).any(
), "param {} with shape {} not updated!! \n{}\n{}".format(
count, param.shape, param, param_ref)
count += 1

View File

@ -359,8 +359,8 @@ def check_config(c):
# GST # GST
_check_argument('use_gst', c, restricted=True, val_type=bool) _check_argument('use_gst', c, restricted=True, val_type=bool)
_check_argument('gst_style_input', c, restricted=True, val_type=str)
_check_argument('gst', c, restricted=True, val_type=dict) _check_argument('gst', c, restricted=True, val_type=dict)
_check_argument('gst_style_input', c['gst'], restricted=True, val_type=str)
_check_argument('gst_embedding_dim', c['gst'], restricted=True, val_type=int, min_val=1) _check_argument('gst_embedding_dim', c['gst'], restricted=True, val_type=int, min_val=1)
_check_argument('gst_num_heads', c['gst'], restricted=True, val_type=int, min_val=1) _check_argument('gst_num_heads', c['gst'], restricted=True, val_type=int, min_val=1)
_check_argument('gst_style_tokens', c['gst'], restricted=True, val_type=int, min_val=1) _check_argument('gst_style_tokens', c['gst'], restricted=True, val_type=int, min_val=1)