mirror of https://github.com/coqui-ai/TTS.git
Merge branch 'dev' of https://github.com/mozilla/TTS into dev
This commit is contained in:
commit
871588c391
|
@ -1,7 +1,7 @@
|
|||
{
|
||||
"model": "Tacotron2", // one of the model in models/
|
||||
"run_name": "ljspeech-stft_params",
|
||||
"run_description": "tacotron2 cosntant stf parameters",
|
||||
"run_name": "ljspeech",
|
||||
"run_description": "tacotron2 with guided attention and -1 1 normalization and no preemphasis",
|
||||
|
||||
// AUDIO PARAMETERS
|
||||
"audio":{
|
||||
|
|
|
@ -13,6 +13,7 @@ class MyDataset(Dataset):
|
|||
def __init__(self,
|
||||
outputs_per_step,
|
||||
text_cleaner,
|
||||
compute_linear_spec,
|
||||
ap,
|
||||
meta_data,
|
||||
tp=None,
|
||||
|
@ -28,6 +29,7 @@ class MyDataset(Dataset):
|
|||
Args:
|
||||
outputs_per_step (int): number of time frames predicted per step.
|
||||
text_cleaner (str): text cleaner used for the dataset.
|
||||
compute_linear_spec (bool): compute linear spectrogram if True.
|
||||
ap (TTS.utils.AudioProcessor): audio processor object.
|
||||
meta_data (list): list of dataset instances.
|
||||
batch_group_size (int): (0) range of batch randomization after sorting
|
||||
|
@ -47,6 +49,7 @@ class MyDataset(Dataset):
|
|||
self.outputs_per_step = outputs_per_step
|
||||
self.sample_rate = ap.sample_rate
|
||||
self.cleaners = text_cleaner
|
||||
self.compute_linear_spec = compute_linear_spec
|
||||
self.min_seq_len = min_seq_len
|
||||
self.max_seq_len = max_seq_len
|
||||
self.ap = ap
|
||||
|
@ -193,7 +196,6 @@ class MyDataset(Dataset):
|
|||
|
||||
# compute features
|
||||
mel = [self.ap.melspectrogram(w).astype('float32') for w in wav]
|
||||
linear = [self.ap.spectrogram(w).astype('float32') for w in wav]
|
||||
|
||||
mel_lengths = [m.shape[1] for m in mel]
|
||||
|
||||
|
@ -208,25 +210,29 @@ class MyDataset(Dataset):
|
|||
|
||||
# PAD sequences with longest instance in the batch
|
||||
text = prepare_data(text).astype(np.int32)
|
||||
wav = prepare_data(wav)
|
||||
|
||||
# PAD features with longest instance
|
||||
linear = prepare_tensor(linear, self.outputs_per_step)
|
||||
mel = prepare_tensor(mel, self.outputs_per_step)
|
||||
assert mel.shape[2] == linear.shape[2]
|
||||
|
||||
# B x D x T --> B x T x D
|
||||
linear = linear.transpose(0, 2, 1)
|
||||
mel = mel.transpose(0, 2, 1)
|
||||
|
||||
# convert things to pytorch
|
||||
text_lenghts = torch.LongTensor(text_lenghts)
|
||||
text = torch.LongTensor(text)
|
||||
linear = torch.FloatTensor(linear).contiguous()
|
||||
mel = torch.FloatTensor(mel).contiguous()
|
||||
mel_lengths = torch.LongTensor(mel_lengths)
|
||||
stop_targets = torch.FloatTensor(stop_targets)
|
||||
|
||||
# compute linear spectrogram
|
||||
if self.compute_linear_spec:
|
||||
linear = [self.ap.spectrogram(w).astype('float32') for w in wav]
|
||||
linear = prepare_tensor(linear, self.outputs_per_step)
|
||||
linear = linear.transpose(0, 2, 1)
|
||||
assert mel.shape[1] == linear.shape[1]
|
||||
linear = torch.FloatTensor(linear).contiguous()
|
||||
else:
|
||||
linear = None
|
||||
return text, text_lenghts, speaker_name, linear, mel, mel_lengths, \
|
||||
stop_targets, item_idxs
|
||||
|
||||
|
|
|
@ -36,6 +36,7 @@ class TestTTSDataset(unittest.TestCase):
|
|||
dataset = TTSDataset.MyDataset(
|
||||
r,
|
||||
c.text_cleaner,
|
||||
compute_linear_spec=True,
|
||||
ap=self.ap,
|
||||
meta_data=items,
|
||||
tp=c.characters if 'characters' in c.keys() else None,
|
||||
|
|
4
train.py
4
train.py
|
@ -47,6 +47,7 @@ def setup_loader(ap, r, is_val=False, verbose=False):
|
|||
dataset = MyDataset(
|
||||
r,
|
||||
c.text_cleaner,
|
||||
compute_linear_spec=True if c.model.lower() is 'tacotron' else False,
|
||||
meta_data=meta_data_eval if is_val else meta_data_train,
|
||||
ap=ap,
|
||||
tp=c.characters if 'characters' in c.keys() else None,
|
||||
|
@ -695,6 +696,9 @@ if __name__ == '__main__':
|
|||
LOG_DIR = OUT_PATH
|
||||
tb_logger = Logger(LOG_DIR)
|
||||
|
||||
# write model desc to tensorboard
|
||||
tb_logger.tb_add_text('model-description', c['run_description'], 0)
|
||||
|
||||
try:
|
||||
main(args)
|
||||
except KeyboardInterrupt:
|
||||
|
|
|
@ -75,3 +75,6 @@ class Logger(object):
|
|||
|
||||
def tb_test_figures(self, step, figures):
|
||||
self.dict_to_tb_figure("TestFigures", figures, step)
|
||||
|
||||
def tb_add_text(self, title, text, step):
|
||||
self.writer.add_text(title, text, step)
|
||||
|
|
Loading…
Reference in New Issue