mirror of https://github.com/coqui-ai/TTS.git

Loss bug fix - target_flat vs target

commit 497a6991c7 (parent c45666f417)
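Summary of the change: L1LossMasked was computing the element-wise L1 loss against the unflattened `target` instead of `target_flat`. This commit fixes that, migrates the tests to the PyTorch 0.4 API (no `Variable` wrapper, `.item()` in place of `.data[0]`), splits the loader tests into `TestLJSpeechDataset` and a new `TestTWEBDataset` (with `data_path_LJSpeech` / `data_path_TWEB` config keys), and adds a Tacotron train-step test.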
@@ -46,7 +46,7 @@ class L1LossMasked(nn.Module):
         # target_flat: (batch * max_len, dim)
         target_flat = target.view(-1, target.shape[-1])
         # losses_flat: (batch * max_len, dim)
-        losses_flat = functional.l1_loss(input, target, size_average=False,
+        losses_flat = functional.l1_loss(input, target_flat, size_average=False,
                                          reduce=False)
         # losses: (batch, max_len, dim)
         losses = losses_flat.view(*target.size())
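Note on the fix: both operands of the element-wise loss are expected in flattened (batch * max_len, dim) form, so passing the 3D `target` computes the loss against mismatched shapes. A minimal sketch of the flatten/compute/reshape pattern, in the modern PyTorch spelling (`reduction='none'` replaces the old `size_average=False, reduce=False`); the mask normalization below is illustrative, not the repo's exact code:

import torch
import torch.nn.functional as F

def masked_l1(input, target, lengths):
    # input, target: (batch, max_len, dim); lengths: (batch,) valid frame counts
    input_flat = input.view(-1, input.shape[-1])      # (batch * max_len, dim)
    target_flat = target.view(-1, target.shape[-1])   # (batch * max_len, dim)
    # element-wise loss; comparing against the unflattened `target` here is
    # exactly the bug this commit fixes
    losses_flat = F.l1_loss(input_flat, target_flat, reduction='none')
    losses = losses_flat.view(*target.size())         # (batch, max_len, dim)
    # zero the padded steps, then average over the valid elements only
    mask = (torch.arange(target.size(1))[None, :] < lengths[:, None]).unsqueeze(2).float()
    return (losses * mask).sum() / (mask.sum() * target.size(2))

loss = masked_l1(torch.ones(4, 8, 128), torch.zeros(4, 8, 128), (torch.ones(4) * 8).long())
assert loss.item() == 1.0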
@@ -9,7 +9,7 @@ class PrenetTests(unittest.TestCase):

     def test_in_out(self):
         layer = Prenet(128, out_features=[256, 128])
-        dummy_input = T.autograd.Variable(T.rand(4, 128))
+        dummy_input = T.rand(4, 128)

         print(layer)
         output = layer(dummy_input)
@@ -21,7 +21,7 @@ class CBHGTests(unittest.TestCase):

     def test_in_out(self):
         layer = CBHG(128, K=6, projections=[128, 128], num_highways=2)
-        dummy_input = T.autograd.Variable(T.rand(4, 8, 128))
+        dummy_input = T.rand(4, 8, 128)

         print(layer)
         output = layer(dummy_input)
@@ -34,8 +34,8 @@ class DecoderTests(unittest.TestCase):

     def test_in_out(self):
         layer = Decoder(in_features=256, memory_dim=80, r=2)
-        dummy_input = T.autograd.Variable(T.rand(4, 8, 256))
-        dummy_memory = T.autograd.Variable(T.rand(4, 2, 80))
+        dummy_input = T.rand(4, 8, 256)
+        dummy_memory = T.rand(4, 2, 80)

         output, alignment = layer(dummy_input, dummy_memory)

@@ -48,7 +48,7 @@ class EncoderTests(unittest.TestCase):

     def test_in_out(self):
         layer = Encoder(128)
-        dummy_input = T.autograd.Variable(T.rand(4, 8, 128))
+        dummy_input = T.rand(4, 8, 128)

         print(layer)
         output = layer(dummy_input)
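Note: the layer-test edits above (and the loss tests below) track the PyTorch 0.4 API change. `torch.autograd.Variable` was merged into `torch.Tensor`, so the wrapper is a no-op, and 0-dim results are read with `.item()` rather than indexing `.data`. A two-line illustration:

import torch as T

x = T.rand(4, 128)        # was: T.autograd.Variable(T.rand(4, 128))
v = x.mean().item()       # was: x.mean().data[0]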
@@ -62,24 +62,22 @@ class L1LossMaskedTests(unittest.TestCase):

     def test_in_out(self):
         layer = L1LossMasked()
-        dummy_input = T.autograd.Variable(T.ones(4, 8, 128).float())
-        dummy_target = T.autograd.Variable(T.ones(4, 8, 128).float())
-        dummy_length = T.autograd.Variable((T.ones(4) * 8).long())
+        dummy_input = T.ones(4, 8, 128).float()
+        dummy_target = T.ones(4, 8, 128).float()
+        dummy_length = (T.ones(4) * 8).long()
         output = layer(dummy_input, dummy_target, dummy_length)
-        assert output.shape[0] == 0
-        assert len(output.shape) == 1
-        assert output.data[0] == 0.0
+        assert output.item() == 0.0

-        dummy_input = T.autograd.Variable(T.ones(4, 8, 128).float())
-        dummy_target = T.autograd.Variable(T.zeros(4, 8, 128).float())
-        dummy_length = T.autograd.Variable((T.ones(4) * 8).long())
+        dummy_input = T.ones(4, 8, 128).float()
+        dummy_target = T.zeros(4, 8, 128).float()
+        dummy_length = (T.ones(4) * 8).long()
         output = layer(dummy_input, dummy_target, dummy_length)
-        assert output.data[0] == 1.0, "1.0 vs {}".format(output.data[0])
+        assert output.item() == 1.0, "1.0 vs {}".format(output.data[0])

-        dummy_input = T.autograd.Variable(T.ones(4, 8, 128).float())
-        dummy_target = T.autograd.Variable(T.zeros(4, 8, 128).float())
-        dummy_length = T.autograd.Variable((T.arange(5, 9)).long())
+        dummy_input = T.ones(4, 8, 128).float()
+        dummy_target = T.zeros(4, 8, 128).float()
+        dummy_length = (T.arange(5, 9)).long()
         mask = ((_sequence_mask(dummy_length).float() - 1.0)
                 * 100.0).unsqueeze(2)
         output = layer(dummy_input + mask, dummy_target, dummy_length)
-        assert output.data[0] == 1.0, "1.0 vs {}".format(output.data[0])
+        assert output.item() == 1.0, "1.0 vs {}".format(output.data[0])
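Note: the third case above checks that padded steps really are ignored. `(_sequence_mask(dummy_length).float() - 1.0) * 100.0` is 0 at valid steps and -100 at padded ones, so `dummy_input + mask` corrupts only the padding, and a correctly masked loss must still return exactly 1.0. A sketch of the assumed `_sequence_mask` contract (the real helper lives in the repo; this only illustrates its behavior):

import torch as T

def sequence_mask(lengths, max_len=None):
    # True for valid steps, False for padding -- the assumed contract
    max_len = max_len or int(lengths.max())
    return T.arange(max_len)[None, :] < lengths[:, None]

lengths = T.arange(5, 9)                            # batch of 4, lengths 5..8
mask = ((sequence_mask(lengths).float() - 1.0) * 100.0).unsqueeze(2)
assert (mask == 0).sum() == (5 + 6 + 7 + 8)         # valid steps untouched
assert mask.min() == -100.0                         # padding heavily corrupted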
@@ -5,21 +5,22 @@ import numpy as np
 from torch.utils.data import DataLoader
 from TTS.utils.generic_utils import load_config
 from TTS.datasets.LJSpeech import LJSpeechDataset
+from TTS.datasets.TWEB import TWEBDataset


 file_path = os.path.dirname(os.path.realpath(__file__))
 c = load_config(os.path.join(file_path, 'test_config.json'))


-class TestDataset(unittest.TestCase):
+class TestLJSpeechDataset(unittest.TestCase):

     def __init__(self, *args, **kwargs):
-        super(TestDataset, self).__init__(*args, **kwargs)
+        super(TestLJSpeechDataset, self).__init__(*args, **kwargs)
         self.max_loader_iter = 4

     def test_loader(self):
-        dataset = LJSpeechDataset(os.path.join(c.data_path, 'metadata.csv'),
-                                  os.path.join(c.data_path, 'wavs'),
+        dataset = LJSpeechDataset(os.path.join(c.data_path_LJSpeech, 'metadata.csv'),
+                                  os.path.join(c.data_path_LJSpeech, 'wavs'),
                                   c.r,
                                   c.sample_rate,
                                   c.text_cleaner,
@@ -58,8 +59,8 @@ class TestDataset(unittest.TestCase):
             assert mel_input.shape[2] == c.num_mels

     def test_padding(self):
-        dataset = LJSpeechDataset(os.path.join(c.data_path, 'metadata.csv'),
-                                  os.path.join(c.data_path, 'wavs'),
+        dataset = LJSpeechDataset(os.path.join(c.data_path_LJSpeech, 'metadata.csv'),
+                                  os.path.join(c.data_path_LJSpeech, 'wavs'),
                                   1,
                                   c.sample_rate,
                                   c.text_cleaner,
@@ -141,3 +142,136 @@ class TestDataset(unittest.TestCase):
             # check batch conditions
             assert (mel_input * stop_target.unsqueeze(2)).sum() == 0
             assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
+
+
+class TestTWEBDataset(unittest.TestCase):
+
+    def __init__(self, *args, **kwargs):
+        super(TestTWEBDataset, self).__init__(*args, **kwargs)
+        self.max_loader_iter = 4
+
+    def test_loader(self):
+        dataset = TWEBDataset(os.path.join(c.data_path_TWEB, 'transcript.txt'),
+                              os.path.join(c.data_path_TWEB, 'wavs'),
+                              c.r,
+                              c.sample_rate,
+                              c.text_cleaner,
+                              c.num_mels,
+                              c.min_level_db,
+                              c.frame_shift_ms,
+                              c.frame_length_ms,
+                              c.preemphasis,
+                              c.ref_level_db,
+                              c.num_freq,
+                              c.power
+                              )
+
+        dataloader = DataLoader(dataset, batch_size=2,
+                                shuffle=True, collate_fn=dataset.collate_fn,
+                                drop_last=True, num_workers=c.num_loader_workers)
+
+        for i, data in enumerate(dataloader):
+            if i == self.max_loader_iter:
+                break
+            text_input = data[0]
+            text_lengths = data[1]
+            linear_input = data[2]
+            mel_input = data[3]
+            mel_lengths = data[4]
+            stop_target = data[5]
+            item_idx = data[6]
+
+            neg_values = text_input[text_input < 0]
+            check_count = len(neg_values)
+            assert check_count == 0, \
+                " !! Negative values in text_input: {}".format(check_count)
+            # TODO: more assertion here
+            assert linear_input.shape[0] == c.batch_size
+            assert mel_input.shape[0] == c.batch_size
+            assert mel_input.shape[2] == c.num_mels
+
+    def test_padding(self):
+        dataset = TWEBDataset(os.path.join(c.data_path_TWEB, 'transcript.txt'),
+                              os.path.join(c.data_path_TWEB, 'wavs'),
+                              1,
+                              c.sample_rate,
+                              c.text_cleaner,
+                              c.num_mels,
+                              c.min_level_db,
+                              c.frame_shift_ms,
+                              c.frame_length_ms,
+                              c.preemphasis,
+                              c.ref_level_db,
+                              c.num_freq,
+                              c.power
+                              )
+
+        # Test for batch size 1
+        dataloader = DataLoader(dataset, batch_size=1,
+                                shuffle=False, collate_fn=dataset.collate_fn,
+                                drop_last=False, num_workers=c.num_loader_workers)
+
+        for i, data in enumerate(dataloader):
+            if i == self.max_loader_iter:
+                break
+
+            text_input = data[0]
+            text_lengths = data[1]
+            linear_input = data[2]
+            mel_input = data[3]
+            mel_lengths = data[4]
+            stop_target = data[5]
+            item_idx = data[6]
+
+            # check the last time step to be zero padded
+            assert mel_input[0, -1].sum() == 0
+            assert mel_input[0, -2].sum() != 0, "{} -- {}".format(item_idx, i)
+            assert linear_input[0, -1].sum() == 0
+            assert linear_input[0, -2].sum() != 0
+            assert stop_target[0, -1] == 1
+            assert stop_target[0, -2] == 0
+            assert stop_target.sum() == 1
+            assert len(mel_lengths.shape) == 1
+            assert mel_lengths[0] == mel_input[0].shape[0]
+
+        # Test for batch size 2
+        dataloader = DataLoader(dataset, batch_size=2,
+                                shuffle=False, collate_fn=dataset.collate_fn,
+                                drop_last=False, num_workers=c.num_loader_workers)
+
+        for i, data in enumerate(dataloader):
+            if i == self.max_loader_iter:
+                break
+            text_input = data[0]
+            text_lengths = data[1]
+            linear_input = data[2]
+            mel_input = data[3]
+            mel_lengths = data[4]
+            stop_target = data[5]
+            item_idx = data[6]
+
+            if mel_lengths[0] > mel_lengths[1]:
+                idx = 0
+            else:
+                idx = 1
+
+            # check the first item in the batch
+            assert mel_input[idx, -1].sum() == 0
+            assert mel_input[idx, -2].sum() != 0, mel_input
+            assert linear_input[idx, -1].sum() == 0
+            assert linear_input[idx, -2].sum() != 0
+            assert stop_target[idx, -1] == 1
+            assert stop_target[idx, -2] == 0
+            assert stop_target[idx].sum() == 1
+            assert len(mel_lengths.shape) == 1
+            assert mel_lengths[idx] == mel_input[idx].shape[0]
+
+            # check the second item in the batch
+            assert mel_input[1-idx, -1].sum() == 0
+            assert linear_input[1-idx, -1].sum() == 0
+            assert stop_target[1-idx, -1] == 1
+            assert len(mel_lengths.shape) == 1
+
+            # check batch conditions
+            assert (mel_input * stop_target.unsqueeze(2)).sum() == 0
+            assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
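Note on the padding assertions: the collate function pads each spectrogram to the batch maximum plus a final all-zero frame, and `stop_target` is 0 on real frames and 1 on padded ones, which is why `mel_input * stop_target.unsqueeze(2)` must sum to zero. A hedged sketch of that convention (names are illustrative; the real logic lives in the dataset's `collate_fn`):

import torch

def make_stop_target(lengths, padded_len):
    # 0 on real frames, 1 on padded frames; with one extra frame past the
    # true length, a batch-size-1 target sums to exactly 1, as asserted above
    target = torch.zeros(len(lengths), padded_len)
    for b, length in enumerate(lengths):
        target[b, int(length):] = 1.0
    return target

stop = make_stop_target([7], padded_len=8)
assert stop.sum() == 1 and stop[0, -1] == 1 and stop[0, -2] == 0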
@@ -0,0 +1,63 @@
+import os
+import copy
+import torch
+import unittest
+import numpy as np
+
+from torch import optim
+from torch import nn
+from TTS.utils.generic_utils import load_config
+from TTS.layers.losses import L1LossMasked
+from TTS.models.tacotron import Tacotron
+
+torch.manual_seed(1)
+use_cuda = torch.cuda.is_available()
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+file_path = os.path.dirname(os.path.realpath(__file__))
+c = load_config(os.path.join(file_path, 'test_config.json'))
+
+
+class TacotronTrainTest(unittest.TestCase):
+
+    def test_train_step(self):
+        input = torch.randint(0, 24, (8, 128)).long().to(device)
+        mel_spec = torch.rand(8, 30, c.num_mels).to(device)
+        linear_spec = torch.rand(8, 30, c.num_freq).to(device)
+        mel_lengths = torch.randint(20, 30, (8,)).long().to(device)
+        stop_targets = torch.zeros(8, 30, 1).float().to(device)
+        for idx in mel_lengths:
+            stop_targets[:, int(idx.item()):, 0] = 1.0
+        criterion = L1LossMasked().to(device)
+        criterion_st = nn.BCELoss().to(device)
+        model = Tacotron(c.embedding_size,
+                         c.num_freq,
+                         c.num_mels,
+                         c.r).to(device)
+        model.train()
+        model_ref = copy.deepcopy(model)
+        count = 0
+        for param, param_ref in zip(model.parameters(), model_ref.parameters()):
+            assert (param - param_ref).sum() == 0, param
+            count += 1
+        optimizer = optim.Adam(model.parameters(), lr=c.lr)
+        for i in range(5):
+            mel_out, linear_out, align, stop_tokens = model.forward(input, mel_spec)
+            assert stop_tokens.data.max() <= 1.0
+            assert stop_tokens.data.min() >= 0.0
+            optimizer.zero_grad()
+            loss = criterion(mel_out, mel_spec, mel_lengths)
+            stop_loss = criterion_st(stop_tokens, stop_targets)
+            loss = loss + criterion(linear_out, linear_spec, mel_lengths) + stop_loss
+            loss.backward()
+            optimizer.step()
+        # check parameter changes
+        count = 0
+        for param, param_ref in zip(model.parameters(), model_ref.parameters()):
+            # ignore pre-highway layer since it is used conditionally
+            if count not in [141, 59]:
+                assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(count, param.shape, param, param_ref)
+            count += 1
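Note: the new test uses a common "did the optimizer touch every weight?" pattern: deep-copy the freshly built model, run a few train steps, then assert each parameter differs from its snapshot (skipping indices 141 and 59 for the conditionally used pre-highway layer). One detail worth flagging: the stop-target loop fills `stop_targets[:, idx:, 0]` for the whole batch on every iteration, so the result reflects the shortest length; a per-row fill would be `stop_targets[b, int(l):, 0] = 1.0`. A minimal self-contained sketch of the update-check pattern itself:

import copy
import torch
from torch import nn, optim

model = nn.Linear(4, 2)
model_ref = copy.deepcopy(model)      # frozen snapshot of the initial weights
optimizer = optim.Adam(model.parameters(), lr=1e-3)

for _ in range(3):
    loss = model(torch.rand(8, 4)).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

for i, (p, p_ref) in enumerate(zip(model.parameters(), model_ref.parameters())):
    assert (p != p_ref).any(), "param {} not updated".format(i)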