mirror of https://github.com/coqui-ai/TTS.git

commit 4336e1d338 (parent d5cefe0546)
fix unittests for the latest updates
@@ -13,7 +13,6 @@ from utils.data import (prepare_data, pad_per_step, prepare_tensor,
 class MyDataset(Dataset):
     def __init__(self,
-                 root_path,
                  outputs_per_step,
                  text_cleaner,
                  ap,

@@ -28,13 +27,10 @@ class MyDataset(Dataset):
                  verbose=False):
         """
         Args:
-            root_path (str): root path for the data folder.
             outputs_per_step (int): number of time frames predicted per step.
             text_cleaner (str): text cleaner used for the dataset.
             ap (TTS.utils.AudioProcessor): audio processor object.
             meta_data (list): list of dataset instances.
-            speaker_id_cache_path (str): path where the speaker name to id
-                mapping is stored
             batch_group_size (int): (0) range of batch randomization after sorting
                 sequences by length.
             min_seq_len (int): (0) minimum sequence length to be processed

@@ -47,7 +43,6 @@ class MyDataset(Dataset):
             enable_eos_bos (bool): enable end of sentence and beginning of sentences characters.
             verbose (bool): print diagnostic information.
         """
-        self.root_path = root_path
         self.batch_group_size = batch_group_size
         self.items = meta_data
         self.outputs_per_step = outputs_per_step

@@ -65,7 +60,6 @@ class MyDataset(Dataset):
             os.makedirs(phoneme_cache_path, exist_ok=True)
         if self.verbose:
             print("\n > DataLoader initialization")
-            print(" | > Data path: {}".format(root_path))
             print(" | > Use phonemes: {}".format(self.use_phonemes))
             if use_phonemes:
                 print(" | > phoneme language: {}".format(phoneme_language))

@@ -70,6 +70,8 @@ class Tacotron(nn.Module):
         return mel_outputs, linear_outputs, alignments, stop_tokens

     def _add_speaker_embedding(self, encoder_outputs, speaker_ids):
+        if hasattr(self, "speaker_embedding") and speaker_ids is None:
+            raise RuntimeError(" [!] Model has speaker embedding layer but speaker_id is not provided")
         if hasattr(self, "speaker_embedding") and speaker_ids is not None:
             speaker_embeddings = self.speaker_embedding(speaker_ids)
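The net effect of the new guard in Tacotron's _add_speaker_embedding: a model built with a speaker_embedding layer now fails loudly instead of silently skipping the embedding when speaker ids are missing. A minimal sketch of the behavior, assuming the model was configured with a speaker embedding layer (the construction arguments are not shown in this diff):

    # model has self.speaker_embedding, but no speaker_ids are passed:
    outputs = model.forward(text, text_lengths, mel_specs)
    # -> RuntimeError(" [!] Model has speaker embedding layer but speaker_id is not provided")
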
@@ -89,7 +89,9 @@ class Tacotron2(nn.Module):
         return mel_outputs, mel_outputs_postnet, alignments, stop_tokens

     def _add_speaker_embedding(self, encoder_outputs, speaker_ids):
-        if hasattr(self, "speaker_embedding") and speaker_ids is not None:
+        if hasattr(self, "speaker_embedding") and speaker_ids is None:
+            raise RuntimeError(" [!] Model has speaker embedding layer but speaker_id is not provided")
+        elif hasattr(self, "speaker_embedding") and speaker_ids is not None:
             speaker_embeddings = self.speaker_embedding(speaker_ids)

             speaker_embeddings.unsqueeze_(1)
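Tacotron2 gets the same check, written as an if/elif chain. The in-place unsqueeze_(1) that follows is what makes the per-utterance embedding broadcastable over the encoder's time axis; a shape sketch (the final combine step, add vs. concat, is outside this hunk and is an assumption):

    speaker_embeddings = self.speaker_embedding(speaker_ids)  # [B, embed_dim]
    speaker_embeddings.unsqueeze_(1)                          # [B, 1, embed_dim]
    # presumably expanded to [B, T_enc, embed_dim] and combined with encoder_outputs
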
@@ -38,11 +38,25 @@ class CBHGTests(unittest.TestCase):

 class DecoderTests(unittest.TestCase):
     def test_in_out(self):
-        layer = Decoder(in_features=256, memory_dim=80, r=2, memory_size=4, attn_windowing=False, attn_norm="sigmoid")
+        layer = Decoder(
+            in_features=256,
+            memory_dim=80,
+            r=2,
+            memory_size=4,
+            attn_windowing=False,
+            attn_norm="sigmoid",
+            prenet_type='original',
+            prenet_dropout=True,
+            forward_attn=True,
+            trans_agent=True,
+            forward_attn_mask=True,
+            location_attn=True,
+            separate_stopnet=True)
         dummy_input = T.rand(4, 8, 256)
         dummy_memory = T.rand(4, 2, 80)

-        output, alignment, stop_tokens = layer(dummy_input, dummy_memory, mask=None)
+        output, alignment, stop_tokens = layer(
+            dummy_input, dummy_memory, mask=None)

         assert output.shape[0] == 4
         assert output.shape[1] == 1, "size not {}".format(output.shape[1])
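The assertion output.shape[1] == 1 follows from the decoder's reduction factor: with r=2 frames emitted per decoder step and a 2-frame dummy memory, the decoder runs exactly one step. A quick sanity check of that arithmetic:

    T_memory = 2                  # dummy_memory is T.rand(4, 2, 80)
    r = 2                         # frames generated per decoder step
    assert T_memory // r == 1     # expected number of decoder output steps
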
@@ -29,13 +29,12 @@ class TestTTSDataset(unittest.TestCase):
         self.ap = AudioProcessor(**c.audio)

     def _create_dataloader(self, batch_size, r, bgs):
+        items = ljspeech(c.data_path,'metadata.csv')
         dataset = TTSDataset.MyDataset(
-            c.data_path,
-            'metadata.csv',
             r,
             c.text_cleaner,
-            preprocessor=ljspeech,
             ap=self.ap,
+            meta_data=items,
             batch_group_size=bgs,
             min_seq_len=c.min_seq_len,
             max_seq_len=float("inf"),
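MyDataset no longer runs the preprocessor itself; the test now calls ljspeech up front and hands the parsed items in via meta_data. A hedged sketch of the same construction pattern outside the test harness (the example path, cleaner name, and items layout are assumptions, not confirmed by this diff):

    items = ljspeech('/data/LJSpeech-1.1/', 'metadata.csv')  # list of parsed metadata rows
    dataset = TTSDataset.MyDataset(
        r,                     # outputs_per_step
        'english_cleaners',    # text_cleaner (assumed value)
        ap=ap,                 # an AudioProcessor instance
        meta_data=items)
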
@@ -58,17 +57,19 @@ class TestTTSDataset(unittest.TestCase):
                 break
             text_input = data[0]
             text_lengths = data[1]
-            linear_input = data[2]
-            mel_input = data[3]
-            mel_lengths = data[4]
-            stop_target = data[5]
-            item_idx = data[6]
+            speaker_name = data[2]
+            linear_input = data[3]
+            mel_input = data[4]
+            mel_lengths = data[5]
+            stop_target = data[6]
+            item_idx = data[7]

             neg_values = text_input[text_input < 0]
             check_count = len(neg_values)
             assert check_count == 0, \
                 " !! Negative values in text_input: {}".format(check_count)
             # TODO: more assertion here
+            assert type(speaker_name[0]) is str
             assert linear_input.shape[0] == c.batch_size
             assert linear_input.shape[2] == self.ap.num_freq
             assert mel_input.shape[0] == c.batch_size
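The collate output grew from 7 to 8 fields, with the speaker name inserted at index 2 and everything after it shifted by one, so any consumer indexing the batch positionally has to be updated the same way. The full unpacking implied by the test:

    text_input, text_lengths, speaker_name, linear_input, \
        mel_input, mel_lengths, stop_target, item_idx = data
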
@@ -92,11 +93,12 @@ class TestTTSDataset(unittest.TestCase):
                 break
             text_input = data[0]
             text_lengths = data[1]
-            linear_input = data[2]
-            mel_input = data[3]
-            mel_lengths = data[4]
-            stop_target = data[5]
-            item_idx = data[6]
+            speaker_name = data[2]
+            linear_input = data[3]
+            mel_input = data[4]
+            mel_lengths = data[5]
+            stop_target = data[6]
+            item_idx = data[7]

             avg_length = mel_lengths.numpy().mean()
             assert avg_length >= last_length
@@ -112,11 +114,12 @@ class TestTTSDataset(unittest.TestCase):
                 break
             text_input = data[0]
             text_lengths = data[1]
-            linear_input = data[2]
-            mel_input = data[3]
-            mel_lengths = data[4]
-            stop_target = data[5]
-            item_idx = data[6]
+            speaker_name = data[2]
+            linear_input = data[3]
+            mel_input = data[4]
+            mel_lengths = data[5]
+            stop_target = data[6]
+            item_idx = data[7]

             # check mel_spec consistency
             wav = self.ap.load_wav(item_idx[0])
@@ -159,11 +162,12 @@ class TestTTSDataset(unittest.TestCase):
                 break
             text_input = data[0]
             text_lengths = data[1]
-            linear_input = data[2]
-            mel_input = data[3]
-            mel_lengths = data[4]
-            stop_target = data[5]
-            item_idx = data[6]
+            speaker_name = data[2]
+            linear_input = data[3]
+            mel_input = data[4]
+            mel_lengths = data[5]
+            stop_target = data[6]
+            item_idx = data[7]

             if mel_lengths[0] > mel_lengths[1]:
                 idx = 0

@@ -27,6 +27,7 @@ class TacotronTrainTest(unittest.TestCase):
         mel_postnet_spec = torch.rand(8, 30, c.audio['num_mels']).to(device)
         mel_lengths = torch.randint(20, 30, (8, )).long().to(device)
         stop_targets = torch.zeros(8, 30, 1).float().to(device)
+        speaker_ids = torch.randint(0, 5, (8, )).long().to(device)

         for idx in mel_lengths:
             stop_targets[:, int(idx.item()):, 0] = 1.0
@@ -37,7 +38,7 @@ class TacotronTrainTest(unittest.TestCase):

         criterion = MSELossMasked().to(device)
         criterion_st = nn.BCEWithLogitsLoss().to(device)
-        model = Tacotron2(24, c.r).to(device)
+        model = Tacotron2(24, c.r, 5).to(device)
         model.train()
         model_ref = copy.deepcopy(model)
         count = 0
@@ -48,7 +49,7 @@ class TacotronTrainTest(unittest.TestCase):
         optimizer = optim.Adam(model.parameters(), lr=c.lr)
         for i in range(5):
             mel_out, mel_postnet_out, align, stop_tokens = model.forward(
-                input, input_lengths, mel_spec)
+                input, input_lengths, mel_spec, speaker_ids)
             assert torch.sigmoid(stop_tokens).data.max() <= 1.0
             assert torch.sigmoid(stop_tokens).data.min() >= 0.0
             optimizer.zero_grad()
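Pulling the three Tacotron2 test hunks together, the multi-speaker training step now looks roughly like this (the meaning of the positional arguments, 24 input symbols and 5 speakers, is inferred from context rather than stated in the diff):

    speaker_ids = torch.randint(0, 5, (8, )).long().to(device)
    model = Tacotron2(24, c.r, 5).to(device)
    mel_out, mel_postnet_out, align, stop_tokens = model.forward(
        input, input_lengths, mel_spec, speaker_ids)
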
@@ -32,6 +32,7 @@ class TacotronTrainTest(unittest.TestCase):
         linear_spec = torch.rand(8, 30, c.audio['num_freq']).to(device)
         mel_lengths = torch.randint(20, 30, (8, )).long().to(device)
         stop_targets = torch.zeros(8, 30, 1).float().to(device)
+        speaker_ids = torch.randint(0, 5, (8, )).long().to(device)

         for idx in mel_lengths:
             stop_targets[:, int(idx.item()):, 0] = 1.0
@@ -45,6 +46,7 @@ class TacotronTrainTest(unittest.TestCase):
         criterion_st = nn.BCEWithLogitsLoss().to(device)
         model = Tacotron(
             32,
+            5,
             linear_dim=c.audio['num_freq'],
             mel_dim=c.audio['num_mels'],
             r=c.r,
@@ -60,7 +62,7 @@ class TacotronTrainTest(unittest.TestCase):
         optimizer = optim.Adam(model.parameters(), lr=c.lr)
         for i in range(5):
             mel_out, linear_out, align, stop_tokens = model.forward(
-                input, input_lengths, mel_spec)
+                input, input_lengths, mel_spec, speaker_ids)
             optimizer.zero_grad()
             loss = criterion(mel_out, mel_spec, mel_lengths)
             stop_loss = criterion_st(stop_tokens, stop_targets)

train.py
@@ -58,7 +58,6 @@ def setup_loader(is_val=False, verbose=False):
         loader = None
     else:
         dataset = MyDataset(
-            c.data_path,
            c.r,
            c.text_cleaner,
            meta_data=meta_data_eval if is_val else meta_data_train,
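With c.data_path dropped from the MyDataset call, setup_loader relies on meta_data_train and meta_data_eval existing beforehand, e.g. produced by running the dataset preprocessor once at startup. A hypothetical sketch of that wiring (load_meta_data is an illustrative name, not confirmed by this diff):

    # somewhere before setup_loader is called (hypothetical helper):
    meta_data_train, meta_data_eval = load_meta_data(c)
    train_loader = setup_loader(is_val=False)
    eval_loader = setup_loader(is_val=True)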