mirror of https://github.com/coqui-ai/TTS.git
fix unittests for the latest updates
This commit is contained in:
parent
d5cefe0546
commit
4336e1d338
|
@ -13,7 +13,6 @@ from utils.data import (prepare_data, pad_per_step, prepare_tensor,
|
|||
|
||||
class MyDataset(Dataset):
|
||||
def __init__(self,
|
||||
root_path,
|
||||
outputs_per_step,
|
||||
text_cleaner,
|
||||
ap,
|
||||
|
@ -28,13 +27,10 @@ class MyDataset(Dataset):
|
|||
verbose=False):
|
||||
"""
|
||||
Args:
|
||||
root_path (str): root path for the data folder.
|
||||
outputs_per_step (int): number of time frames predicted per step.
|
||||
text_cleaner (str): text cleaner used for the dataset.
|
||||
ap (TTS.utils.AudioProcessor): audio processor object.
|
||||
meta_data (list): list of dataset instances.
|
||||
speaker_id_cache_path (str): path where the speaker name to id
|
||||
mapping is stored
|
||||
batch_group_size (int): (0) range of batch randomization after sorting
|
||||
sequences by length.
|
||||
min_seq_len (int): (0) minimum sequence length to be processed
|
||||
|
@ -47,7 +43,6 @@ class MyDataset(Dataset):
|
|||
enable_eos_bos (bool): enable end of sentence and beginning of sentences characters.
|
||||
verbose (bool): print diagnostic information.
|
||||
"""
|
||||
self.root_path = root_path
|
||||
self.batch_group_size = batch_group_size
|
||||
self.items = meta_data
|
||||
self.outputs_per_step = outputs_per_step
|
||||
|
@ -65,7 +60,6 @@ class MyDataset(Dataset):
|
|||
os.makedirs(phoneme_cache_path, exist_ok=True)
|
||||
if self.verbose:
|
||||
print("\n > DataLoader initialization")
|
||||
print(" | > Data path: {}".format(root_path))
|
||||
print(" | > Use phonemes: {}".format(self.use_phonemes))
|
||||
if use_phonemes:
|
||||
print(" | > phoneme language: {}".format(phoneme_language))
|
||||
|
|
|
@ -70,6 +70,8 @@ class Tacotron(nn.Module):
|
|||
return mel_outputs, linear_outputs, alignments, stop_tokens
|
||||
|
||||
def _add_speaker_embedding(self, encoder_outputs, speaker_ids):
|
||||
if hasattr(self, "speaker_embedding") and speaker_ids is None:
|
||||
raise RuntimeError(" [!] Model has speaker embedding layer but speaker_id is not provided")
|
||||
if hasattr(self, "speaker_embedding") and speaker_ids is not None:
|
||||
speaker_embeddings = self.speaker_embedding(speaker_ids)
|
||||
|
||||
|
|
|
@ -89,7 +89,9 @@ class Tacotron2(nn.Module):
|
|||
return mel_outputs, mel_outputs_postnet, alignments, stop_tokens
|
||||
|
||||
def _add_speaker_embedding(self, encoder_outputs, speaker_ids):
|
||||
if hasattr(self, "speaker_embedding") and speaker_ids is not None:
|
||||
if hasattr(self, "speaker_embedding") and speaker_ids is None:
|
||||
raise RuntimeError(" [!] Model has speaker embedding layer but speaker_id is not provided")
|
||||
elif hasattr(self, "speaker_embedding") and speaker_ids is not None:
|
||||
speaker_embeddings = self.speaker_embedding(speaker_ids)
|
||||
|
||||
speaker_embeddings.unsqueeze_(1)
|
||||
|
|
|
@ -38,11 +38,25 @@ class CBHGTests(unittest.TestCase):
|
|||
|
||||
class DecoderTests(unittest.TestCase):
|
||||
def test_in_out(self):
|
||||
layer = Decoder(in_features=256, memory_dim=80, r=2, memory_size=4, attn_windowing=False, attn_norm="sigmoid")
|
||||
layer = Decoder(
|
||||
in_features=256,
|
||||
memory_dim=80,
|
||||
r=2,
|
||||
memory_size=4,
|
||||
attn_windowing=False,
|
||||
attn_norm="sigmoid",
|
||||
prenet_type='original',
|
||||
prenet_dropout=True,
|
||||
forward_attn=True,
|
||||
trans_agent=True,
|
||||
forward_attn_mask=True,
|
||||
location_attn=True,
|
||||
separate_stopnet=True)
|
||||
dummy_input = T.rand(4, 8, 256)
|
||||
dummy_memory = T.rand(4, 2, 80)
|
||||
|
||||
output, alignment, stop_tokens = layer(dummy_input, dummy_memory, mask=None)
|
||||
output, alignment, stop_tokens = layer(
|
||||
dummy_input, dummy_memory, mask=None)
|
||||
|
||||
assert output.shape[0] == 4
|
||||
assert output.shape[1] == 1, "size not {}".format(output.shape[1])
|
||||
|
|
|
@ -29,13 +29,12 @@ class TestTTSDataset(unittest.TestCase):
|
|||
self.ap = AudioProcessor(**c.audio)
|
||||
|
||||
def _create_dataloader(self, batch_size, r, bgs):
|
||||
items = ljspeech(c.data_path,'metadata.csv')
|
||||
dataset = TTSDataset.MyDataset(
|
||||
c.data_path,
|
||||
'metadata.csv',
|
||||
r,
|
||||
c.text_cleaner,
|
||||
preprocessor=ljspeech,
|
||||
ap=self.ap,
|
||||
meta_data=items,
|
||||
batch_group_size=bgs,
|
||||
min_seq_len=c.min_seq_len,
|
||||
max_seq_len=float("inf"),
|
||||
|
@ -58,17 +57,19 @@ class TestTTSDataset(unittest.TestCase):
|
|||
break
|
||||
text_input = data[0]
|
||||
text_lengths = data[1]
|
||||
linear_input = data[2]
|
||||
mel_input = data[3]
|
||||
mel_lengths = data[4]
|
||||
stop_target = data[5]
|
||||
item_idx = data[6]
|
||||
speaker_name = data[2]
|
||||
linear_input = data[3]
|
||||
mel_input = data[4]
|
||||
mel_lengths = data[5]
|
||||
stop_target = data[6]
|
||||
item_idx = data[7]
|
||||
|
||||
neg_values = text_input[text_input < 0]
|
||||
check_count = len(neg_values)
|
||||
assert check_count == 0, \
|
||||
" !! Negative values in text_input: {}".format(check_count)
|
||||
# TODO: more assertion here
|
||||
assert type(speaker_name[0]) is str
|
||||
assert linear_input.shape[0] == c.batch_size
|
||||
assert linear_input.shape[2] == self.ap.num_freq
|
||||
assert mel_input.shape[0] == c.batch_size
|
||||
|
@ -92,11 +93,12 @@ class TestTTSDataset(unittest.TestCase):
|
|||
break
|
||||
text_input = data[0]
|
||||
text_lengths = data[1]
|
||||
linear_input = data[2]
|
||||
mel_input = data[3]
|
||||
mel_lengths = data[4]
|
||||
stop_target = data[5]
|
||||
item_idx = data[6]
|
||||
speaker_name = data[2]
|
||||
linear_input = data[3]
|
||||
mel_input = data[4]
|
||||
mel_lengths = data[5]
|
||||
stop_target = data[6]
|
||||
item_idx = data[7]
|
||||
|
||||
avg_length = mel_lengths.numpy().mean()
|
||||
assert avg_length >= last_length
|
||||
|
@ -112,11 +114,12 @@ class TestTTSDataset(unittest.TestCase):
|
|||
break
|
||||
text_input = data[0]
|
||||
text_lengths = data[1]
|
||||
linear_input = data[2]
|
||||
mel_input = data[3]
|
||||
mel_lengths = data[4]
|
||||
stop_target = data[5]
|
||||
item_idx = data[6]
|
||||
speaker_name = data[2]
|
||||
linear_input = data[3]
|
||||
mel_input = data[4]
|
||||
mel_lengths = data[5]
|
||||
stop_target = data[6]
|
||||
item_idx = data[7]
|
||||
|
||||
# check mel_spec consistency
|
||||
wav = self.ap.load_wav(item_idx[0])
|
||||
|
@ -159,11 +162,12 @@ class TestTTSDataset(unittest.TestCase):
|
|||
break
|
||||
text_input = data[0]
|
||||
text_lengths = data[1]
|
||||
linear_input = data[2]
|
||||
mel_input = data[3]
|
||||
mel_lengths = data[4]
|
||||
stop_target = data[5]
|
||||
item_idx = data[6]
|
||||
speaker_name = data[2]
|
||||
linear_input = data[3]
|
||||
mel_input = data[4]
|
||||
mel_lengths = data[5]
|
||||
stop_target = data[6]
|
||||
item_idx = data[7]
|
||||
|
||||
if mel_lengths[0] > mel_lengths[1]:
|
||||
idx = 0
|
||||
|
|
|
@ -27,6 +27,7 @@ class TacotronTrainTest(unittest.TestCase):
|
|||
mel_postnet_spec = torch.rand(8, 30, c.audio['num_mels']).to(device)
|
||||
mel_lengths = torch.randint(20, 30, (8, )).long().to(device)
|
||||
stop_targets = torch.zeros(8, 30, 1).float().to(device)
|
||||
speaker_ids = torch.randint(0, 5, (8, )).long().to(device)
|
||||
|
||||
for idx in mel_lengths:
|
||||
stop_targets[:, int(idx.item()):, 0] = 1.0
|
||||
|
@ -37,7 +38,7 @@ class TacotronTrainTest(unittest.TestCase):
|
|||
|
||||
criterion = MSELossMasked().to(device)
|
||||
criterion_st = nn.BCEWithLogitsLoss().to(device)
|
||||
model = Tacotron2(24, c.r).to(device)
|
||||
model = Tacotron2(24, c.r, 5).to(device)
|
||||
model.train()
|
||||
model_ref = copy.deepcopy(model)
|
||||
count = 0
|
||||
|
@ -48,7 +49,7 @@ class TacotronTrainTest(unittest.TestCase):
|
|||
optimizer = optim.Adam(model.parameters(), lr=c.lr)
|
||||
for i in range(5):
|
||||
mel_out, mel_postnet_out, align, stop_tokens = model.forward(
|
||||
input, input_lengths, mel_spec)
|
||||
input, input_lengths, mel_spec, speaker_ids)
|
||||
assert torch.sigmoid(stop_tokens).data.max() <= 1.0
|
||||
assert torch.sigmoid(stop_tokens).data.min() >= 0.0
|
||||
optimizer.zero_grad()
|
||||
|
|
|
@ -32,6 +32,7 @@ class TacotronTrainTest(unittest.TestCase):
|
|||
linear_spec = torch.rand(8, 30, c.audio['num_freq']).to(device)
|
||||
mel_lengths = torch.randint(20, 30, (8, )).long().to(device)
|
||||
stop_targets = torch.zeros(8, 30, 1).float().to(device)
|
||||
speaker_ids = torch.randint(0, 5, (8, )).long().to(device)
|
||||
|
||||
for idx in mel_lengths:
|
||||
stop_targets[:, int(idx.item()):, 0] = 1.0
|
||||
|
@ -45,6 +46,7 @@ class TacotronTrainTest(unittest.TestCase):
|
|||
criterion_st = nn.BCEWithLogitsLoss().to(device)
|
||||
model = Tacotron(
|
||||
32,
|
||||
5,
|
||||
linear_dim=c.audio['num_freq'],
|
||||
mel_dim=c.audio['num_mels'],
|
||||
r=c.r,
|
||||
|
@ -60,7 +62,7 @@ class TacotronTrainTest(unittest.TestCase):
|
|||
optimizer = optim.Adam(model.parameters(), lr=c.lr)
|
||||
for i in range(5):
|
||||
mel_out, linear_out, align, stop_tokens = model.forward(
|
||||
input, input_lengths, mel_spec)
|
||||
input, input_lengths, mel_spec, speaker_ids)
|
||||
optimizer.zero_grad()
|
||||
loss = criterion(mel_out, mel_spec, mel_lengths)
|
||||
stop_loss = criterion_st(stop_tokens, stop_targets)
|
||||
|
|
Loading…
Reference in New Issue