mirror of https://github.com/coqui-ai/TTS.git
add stop token to tacotron testing
This commit is contained in:
parent
02d50089c4
commit
b7e4d214f8
12
config.json
12
config.json
|
@ -12,7 +12,7 @@
|
||||||
"text_cleaner": "english_cleaners",
|
"text_cleaner": "english_cleaners",
|
||||||
|
|
||||||
"epochs": 50,
|
"epochs": 50,
|
||||||
"lr": 0.004,
|
"lr": 0.002,
|
||||||
"warmup_steps": 4000,
|
"warmup_steps": 4000,
|
||||||
"batch_size": 32,
|
"batch_size": 32,
|
||||||
"eval_batch_size":32,
|
"eval_batch_size":32,
|
||||||
|
@ -23,14 +23,14 @@
|
||||||
"griffin_lim_iters": 60,
|
"griffin_lim_iters": 60,
|
||||||
"power": 1.2,
|
"power": 1.2,
|
||||||
|
|
||||||
"dataset": "TWEB",
|
"dataset": "LJSpeech",
|
||||||
"meta_file_train": "transcript_train.txt",
|
"meta_file_train": "metadata_train.csv",
|
||||||
"meta_file_val": "transcript_val.txt",
|
"meta_file_val": "metadata_val.csv",
|
||||||
"data_path": "/data/shared/BibleSpeech/",
|
"data_path": "/data/shared/KeithIto/LJSpeech-1.0/",
|
||||||
"min_seq_len": 0,
|
"min_seq_len": 0,
|
||||||
"num_loader_workers": 8,
|
"num_loader_workers": 8,
|
||||||
|
|
||||||
"checkpoint": true,
|
"checkpoint": true,
|
||||||
"save_step": 908,
|
"save_step": 600,
|
||||||
"output_path": "/data/shared/erogol_models/"
|
"output_path": "/data/shared/erogol_models/"
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,6 +5,7 @@ import unittest
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from torch import optim
|
from torch import optim
|
||||||
|
from torch import nn
|
||||||
from TTS.utils.generic_utils import load_config
|
from TTS.utils.generic_utils import load_config
|
||||||
from TTS.layers.losses import L1LossMasked
|
from TTS.layers.losses import L1LossMasked
|
||||||
from TTS.models.tacotron import Tacotron
|
from TTS.models.tacotron import Tacotron
|
||||||
|
@ -24,7 +25,11 @@ class TacotronTrainTest(unittest.TestCase):
|
||||||
mel_spec = torch.rand(8, 30, c.num_mels).to(device)
|
mel_spec = torch.rand(8, 30, c.num_mels).to(device)
|
||||||
linear_spec = torch.rand(8, 30, c.num_freq).to(device)
|
linear_spec = torch.rand(8, 30, c.num_freq).to(device)
|
||||||
mel_lengths = torch.randint(20, 30, (8,)).long().to(device)
|
mel_lengths = torch.randint(20, 30, (8,)).long().to(device)
|
||||||
|
stop_targets = torch.zeros(8, 30, 1).float().to(device)
|
||||||
|
for idx in mel_lengths:
|
||||||
|
stop_targets[:, int(idx.item()):, 0] = 1.0
|
||||||
criterion = L1LossMasked().to(device)
|
criterion = L1LossMasked().to(device)
|
||||||
|
criterion_st = nn.BCELoss().to(device)
|
||||||
model = Tacotron(c.embedding_size,
|
model = Tacotron(c.embedding_size,
|
||||||
c.num_freq,
|
c.num_freq,
|
||||||
c.num_mels,
|
c.num_mels,
|
||||||
|
@ -37,17 +42,20 @@ class TacotronTrainTest(unittest.TestCase):
|
||||||
count += 1
|
count += 1
|
||||||
optimizer = optim.Adam(model.parameters(), lr=c.lr)
|
optimizer = optim.Adam(model.parameters(), lr=c.lr)
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
mel_out, linear_out, align = model.forward(input, mel_spec)
|
mel_out, linear_out, align, stop_tokens = model.forward(input, mel_spec)
|
||||||
|
assert stop_tokens.data.max() <= 1.0
|
||||||
|
assert stop_tokens.data.min() >= 0.0
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
loss = criterion(mel_out, mel_spec, mel_lengths)
|
loss = criterion(mel_out, mel_spec, mel_lengths)
|
||||||
loss = 0.5 * loss + 0.5 * criterion(linear_out, linear_spec, mel_lengths)
|
stop_loss = criterion_st(stop_tokens, stop_targets)
|
||||||
|
loss = loss + criterion(linear_out, linear_spec, mel_lengths) + stop_loss
|
||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
# check parameter changes
|
# check parameter changes
|
||||||
count = 0
|
count = 0
|
||||||
for param, param_ref in zip(model.parameters(), model_ref.parameters()):
|
for param, param_ref in zip(model.parameters(), model_ref.parameters()):
|
||||||
# ignore pre-higway layer since it works conditional
|
# ignore pre-higway layer since it works conditional
|
||||||
if count not in [139, 59]:
|
if count not in [141, 59]:
|
||||||
assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(count, param.shape, param, param_ref)
|
assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(count, param.shape, param, param_ref)
|
||||||
count += 1
|
count += 1
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue