Merge branch 'dev' of https://github.com/mozilla/TTS into dev

This commit is contained in:
erogol 2020-03-12 03:27:52 +01:00
commit 871588c391
5 changed files with 22 additions and 8 deletions

View File

@ -1,7 +1,7 @@
{
"model": "Tacotron2", // one of the model in models/
"run_name": "ljspeech-stft_params",
"run_description": "tacotron2 cosntant stf parameters",
"run_name": "ljspeech",
"run_description": "tacotron2 with guided attention and -1 1 normalization and no preemphasis",
// AUDIO PARAMETERS
"audio":{

View File

@ -13,6 +13,7 @@ class MyDataset(Dataset):
def __init__(self,
outputs_per_step,
text_cleaner,
compute_linear_spec,
ap,
meta_data,
tp=None,
@ -28,6 +29,7 @@ class MyDataset(Dataset):
Args:
outputs_per_step (int): number of time frames predicted per step.
text_cleaner (str): text cleaner used for the dataset.
compute_linear_spec (bool): compute linear spectrogram if True.
ap (TTS.utils.AudioProcessor): audio processor object.
meta_data (list): list of dataset instances.
batch_group_size (int): (0) range of batch randomization after sorting
@ -47,6 +49,7 @@ class MyDataset(Dataset):
self.outputs_per_step = outputs_per_step
self.sample_rate = ap.sample_rate
self.cleaners = text_cleaner
self.compute_linear_spec = compute_linear_spec
self.min_seq_len = min_seq_len
self.max_seq_len = max_seq_len
self.ap = ap
@ -193,7 +196,6 @@ class MyDataset(Dataset):
# compute features
mel = [self.ap.melspectrogram(w).astype('float32') for w in wav]
linear = [self.ap.spectrogram(w).astype('float32') for w in wav]
mel_lengths = [m.shape[1] for m in mel]
@ -208,25 +210,29 @@ class MyDataset(Dataset):
# PAD sequences with longest instance in the batch
text = prepare_data(text).astype(np.int32)
wav = prepare_data(wav)
# PAD features with longest instance
linear = prepare_tensor(linear, self.outputs_per_step)
mel = prepare_tensor(mel, self.outputs_per_step)
assert mel.shape[2] == linear.shape[2]
# B x D x T --> B x T x D
linear = linear.transpose(0, 2, 1)
mel = mel.transpose(0, 2, 1)
# convert things to pytorch
text_lenghts = torch.LongTensor(text_lenghts)
text = torch.LongTensor(text)
linear = torch.FloatTensor(linear).contiguous()
mel = torch.FloatTensor(mel).contiguous()
mel_lengths = torch.LongTensor(mel_lengths)
stop_targets = torch.FloatTensor(stop_targets)
# compute linear spectrogram
if self.compute_linear_spec:
linear = [self.ap.spectrogram(w).astype('float32') for w in wav]
linear = prepare_tensor(linear, self.outputs_per_step)
linear = linear.transpose(0, 2, 1)
assert mel.shape[1] == linear.shape[1]
linear = torch.FloatTensor(linear).contiguous()
else:
linear = None
return text, text_lenghts, speaker_name, linear, mel, mel_lengths, \
stop_targets, item_idxs

View File

@ -36,6 +36,7 @@ class TestTTSDataset(unittest.TestCase):
dataset = TTSDataset.MyDataset(
r,
c.text_cleaner,
compute_linear_spec=True,
ap=self.ap,
meta_data=items,
tp=c.characters if 'characters' in c.keys() else None,

View File

@ -47,6 +47,7 @@ def setup_loader(ap, r, is_val=False, verbose=False):
dataset = MyDataset(
r,
c.text_cleaner,
compute_linear_spec=True if c.model.lower() is 'tacotron' else False,
meta_data=meta_data_eval if is_val else meta_data_train,
ap=ap,
tp=c.characters if 'characters' in c.keys() else None,
@ -695,6 +696,9 @@ if __name__ == '__main__':
LOG_DIR = OUT_PATH
tb_logger = Logger(LOG_DIR)
# write model desc to tensorboard
tb_logger.tb_add_text('model-description', c['run_description'], 0)
try:
main(args)
except KeyboardInterrupt:

View File

@ -75,3 +75,6 @@ class Logger(object):
def tb_test_figures(self, step, figures):
self.dict_to_tb_figure("TestFigures", figures, step)
def tb_add_text(self, title, text, step):
self.writer.add_text(title, text, step)