mirror of https://github.com/coqui-ai/TTS.git
updates and debugs
This commit is contained in:
parent
dadefb5dbc
commit
56697ac8cf
|
@ -15,7 +15,7 @@
|
||||||
"lr": 0.003,
|
"lr": 0.003,
|
||||||
"lr_patience": 5,
|
"lr_patience": 5,
|
||||||
"lr_decay": 0.5,
|
"lr_decay": 0.5,
|
||||||
"batch_size": 98,
|
"batch_size": 180,
|
||||||
"r": 5,
|
"r": 5,
|
||||||
|
|
||||||
"griffin_lim_iters": 60,
|
"griffin_lim_iters": 60,
|
||||||
|
@ -23,7 +23,8 @@
|
||||||
|
|
||||||
"num_loader_workers": 32,
|
"num_loader_workers": 32,
|
||||||
|
|
||||||
"save_step": 200,
|
"checkpoint": false,
|
||||||
|
"save_step": 69,
|
||||||
"data_path": "/data/shared/KeithIto/LJSpeech-1.0",
|
"data_path": "/data/shared/KeithIto/LJSpeech-1.0",
|
||||||
"output_path": "result",
|
"output_path": "result",
|
||||||
"log_dir": "/home/erogol/projects/TTS/logs/"
|
"log_dir": "/home/erogol/projects/TTS/logs/"
|
||||||
|
|
|
@ -72,9 +72,14 @@ class LJSpeechDataset(Dataset):
|
||||||
timesteps = mel.shape[2]
|
timesteps = mel.shape[2]
|
||||||
|
|
||||||
# PAD with zeros that can be divided by outputs per step
|
# PAD with zeros that can be divided by outputs per step
|
||||||
if timesteps % self.outputs_per_step != 0:
|
if (timesteps + 1) % self.outputs_per_step != 0:
|
||||||
linear = pad_per_step(linear, self.outputs_per_step)
|
pad_len = self.outputs_per_step - \
|
||||||
mel = pad_per_step(mel, self.outputs_per_step)
|
((timesteps + 1) % self.outputs_per_step)
|
||||||
|
pad_len += 1
|
||||||
|
else:
|
||||||
|
pad_len = 1
|
||||||
|
linear = pad_per_step(linear, pad_len)
|
||||||
|
mel = pad_per_step(mel, pad_len)
|
||||||
|
|
||||||
# reshape jombo
|
# reshape jombo
|
||||||
linear = linear.transpose(0, 2, 1)
|
linear = linear.transpose(0, 2, 1)
|
||||||
|
|
|
@ -192,6 +192,14 @@ class Encoder(nn.Module):
|
||||||
self.cbhg = CBHG(128, K=16, projections=[128, 128])
|
self.cbhg = CBHG(128, K=16, projections=[128, 128])
|
||||||
|
|
||||||
def forward(self, inputs):
|
def forward(self, inputs):
|
||||||
|
r"""
|
||||||
|
Args:
|
||||||
|
inputs (FloatTensor): embedding features
|
||||||
|
|
||||||
|
Shapes:
|
||||||
|
- inputs: batch x time x embedding_size
|
||||||
|
- outputs: batch x time x 128
|
||||||
|
"""
|
||||||
inputs = self.prenet(inputs)
|
inputs = self.prenet(inputs)
|
||||||
return self.cbhg(inputs)
|
return self.cbhg(inputs)
|
||||||
|
|
||||||
|
@ -200,12 +208,9 @@ class Decoder(nn.Module):
|
||||||
r"""Decoder module.
|
r"""Decoder module.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
memory_dim (int): memory vector sample size
|
in_features (int): input vector (encoder output) sample size.
|
||||||
r (int): number of outputs per time step
|
memory_dim (int): memory vector (prev. time-step output) sample size.
|
||||||
|
r (int): number of outputs per time step.
|
||||||
Shape:
|
|
||||||
- input:
|
|
||||||
- output:
|
|
||||||
"""
|
"""
|
||||||
def __init__(self, in_features, memory_dim, r):
|
def __init__(self, in_features, memory_dim, r):
|
||||||
super(Decoder, self).__init__()
|
super(Decoder, self).__init__()
|
||||||
|
@ -263,9 +268,7 @@ class Decoder(nn.Module):
|
||||||
|
|
||||||
# Grouping multiple frames if necessary
|
# Grouping multiple frames if necessary
|
||||||
if memory.size(-1) == self.memory_dim:
|
if memory.size(-1) == self.memory_dim:
|
||||||
print(" > Blamento", memory.shape)
|
|
||||||
memory = memory.view(B, memory.size(1) // self.r, -1)
|
memory = memory.view(B, memory.size(1) // self.r, -1)
|
||||||
print(" > Blamento", memory.shape)
|
|
||||||
assert memory.size(-1) == self.memory_dim * self.r,\
|
assert memory.size(-1) == self.memory_dim * self.r,\
|
||||||
" !! Dimension mismatch {} vs {} * {}".format(memory.size(-1),
|
" !! Dimension mismatch {} vs {} * {}".format(memory.size(-1),
|
||||||
self.memory_dim, self.r)
|
self.memory_dim, self.r)
|
||||||
|
|
|
@ -20,7 +20,7 @@ class Tacotron(nn.Module):
|
||||||
# Trying smaller std
|
# Trying smaller std
|
||||||
self.embedding.weight.data.normal_(0, 0.3)
|
self.embedding.weight.data.normal_(0, 0.3)
|
||||||
self.encoder = Encoder(embedding_dim)
|
self.encoder = Encoder(embedding_dim)
|
||||||
self.decoder = Decoder(mel_dim, r)
|
self.decoder = Decoder(256, mel_dim, r)
|
||||||
|
|
||||||
self.postnet = CBHG(mel_dim, K=8, projections=[256, mel_dim])
|
self.postnet = CBHG(mel_dim, K=8, projections=[256, mel_dim])
|
||||||
self.last_linear = nn.Linear(mel_dim * 2, freq_dim)
|
self.last_linear = nn.Linear(mel_dim * 2, freq_dim)
|
||||||
|
|
File diff suppressed because one or more lines are too long
|
@ -41,6 +41,7 @@ class TestDataset(unittest.TestCase):
|
||||||
break
|
break
|
||||||
text_input = data[0]
|
text_input = data[0]
|
||||||
text_lengths = data[1]
|
text_lengths = data[1]
|
||||||
|
linear_input = data[2]
|
||||||
mel_input = data[3]
|
mel_input = data[3]
|
||||||
item_idx = data[4]
|
item_idx = data[4]
|
||||||
|
|
||||||
|
@ -48,7 +49,45 @@ class TestDataset(unittest.TestCase):
|
||||||
check_count = len(neg_values)
|
check_count = len(neg_values)
|
||||||
assert check_count == 0, \
|
assert check_count == 0, \
|
||||||
" !! Negative values in text_input: {}".format(check_count)
|
" !! Negative values in text_input: {}".format(check_count)
|
||||||
|
# TODO: more assertion here
|
||||||
|
assert linear_input.shape[0] == c.batch_size
|
||||||
|
assert mel_input.shape[0] == c.batch_size
|
||||||
|
assert mel_input.shape[2] == c.num_mels
|
||||||
|
|
||||||
|
def test_padding(self):
|
||||||
|
dataset = LJSpeechDataset(os.path.join(c.data_path, 'metadata.csv'),
|
||||||
|
os.path.join(c.data_path, 'wavs'),
|
||||||
|
1,
|
||||||
|
c.sample_rate,
|
||||||
|
c.text_cleaner,
|
||||||
|
c.num_mels,
|
||||||
|
c.min_level_db,
|
||||||
|
c.frame_shift_ms,
|
||||||
|
c.frame_length_ms,
|
||||||
|
c.preemphasis,
|
||||||
|
c.ref_level_db,
|
||||||
|
c.num_freq,
|
||||||
|
c.power
|
||||||
|
)
|
||||||
|
|
||||||
|
dataloader = DataLoader(dataset, batch_size=1,
|
||||||
|
shuffle=True, collate_fn=dataset.collate_fn,
|
||||||
|
drop_last=True, num_workers=c.num_loader_workers)
|
||||||
|
|
||||||
|
for i, data in enumerate(dataloader):
|
||||||
|
if i == self.max_loader_iter:
|
||||||
|
break
|
||||||
|
text_input = data[0]
|
||||||
|
text_lengths = data[1]
|
||||||
|
linear_input = data[2]
|
||||||
|
mel_input = data[3]
|
||||||
|
item_idx = data[4]
|
||||||
|
|
||||||
|
# check the last time step to be zero padded
|
||||||
|
assert mel_input[0, -1].sum() == 0
|
||||||
|
assert mel_input[0, -2].sum() != 0
|
||||||
|
assert linear_input[0, -1].sum() == 0
|
||||||
|
assert linear_input[0, -2].sum() != 0
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
"lr": 0.003,
|
"lr": 0.003,
|
||||||
"lr_patience": 5,
|
"lr_patience": 5,
|
||||||
"lr_decay": 0.5,
|
"lr_decay": 0.5,
|
||||||
"batch_size": 8,
|
"batch_size": 2,
|
||||||
"r": 5,
|
"r": 5,
|
||||||
|
|
||||||
"griffin_lim_iters": 60,
|
"griffin_lim_iters": 60,
|
||||||
|
|
19
train.py
19
train.py
|
@ -20,7 +20,7 @@ from tensorboardX import SummaryWriter
|
||||||
|
|
||||||
from utils.generic_utils import (Progbar, remove_experiment_folder,
|
from utils.generic_utils import (Progbar, remove_experiment_folder,
|
||||||
create_experiment_folder, save_checkpoint,
|
create_experiment_folder, save_checkpoint,
|
||||||
load_config, lr_decay)
|
save_best_model, load_config, lr_decay)
|
||||||
from utils.model import get_param_size
|
from utils.model import get_param_size
|
||||||
from utils.visual import plot_alignment, plot_spectrogram
|
from utils.visual import plot_alignment, plot_spectrogram
|
||||||
from datasets.LJSpeech import LJSpeechDataset
|
from datasets.LJSpeech import LJSpeechDataset
|
||||||
|
@ -101,7 +101,7 @@ def main(args):
|
||||||
optimizer.load_state_dict(checkpoint['optimizer'])
|
optimizer.load_state_dict(checkpoint['optimizer'])
|
||||||
print("\n > Model restored from step %d\n" % args.restore_step)
|
print("\n > Model restored from step %d\n" % args.restore_step)
|
||||||
start_epoch = checkpoint['step'] // len(dataloader)
|
start_epoch = checkpoint['step'] // len(dataloader)
|
||||||
|
best_loss = checkpoint['linear_loss']
|
||||||
else:
|
else:
|
||||||
start_epoch = 0
|
start_epoch = 0
|
||||||
print("\n > Starting a new training")
|
print("\n > Starting a new training")
|
||||||
|
@ -144,6 +144,7 @@ def main(args):
|
||||||
|
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
# Add a single frame of zeros to Mel Specs for better end detection
|
||||||
#try:
|
#try:
|
||||||
# mel_input = np.concatenate((np.zeros(
|
# mel_input = np.concatenate((np.zeros(
|
||||||
# [c.batch_size, 1, c.num_mels], dtype=np.float32),
|
# [c.batch_size, 1, c.num_mels], dtype=np.float32),
|
||||||
|
@ -214,9 +215,11 @@ def main(args):
|
||||||
|
|
||||||
if current_step % c.save_step == 0:
|
if current_step % c.save_step == 0:
|
||||||
|
|
||||||
|
if c.checkpoint:
|
||||||
# save model
|
# save model
|
||||||
best_loss = save_checkpoint(model, loss.data[0],
|
save_checkpoint(model, optimizer, linear_loss.data[0],
|
||||||
best_loss, out_path=OUT_PATH)
|
best_loss, OUT_PATH,
|
||||||
|
current_step, epoch)
|
||||||
|
|
||||||
# Diagnostic visualizations
|
# Diagnostic visualizations
|
||||||
const_spec = linear_output[0].data.cpu().numpy()
|
const_spec = linear_output[0].data.cpu().numpy()
|
||||||
|
@ -243,6 +246,14 @@ def main(args):
|
||||||
print(audio_signal.max())
|
print(audio_signal.max())
|
||||||
print(audio_signal.min())
|
print(audio_signal.min())
|
||||||
|
|
||||||
|
|
||||||
|
# average loss after the epoch
|
||||||
|
avg_epoch_loss = np.mean(
|
||||||
|
progbar.sum_values['linear_loss'][0] / max(1, progbar.sum_values['linear_loss'][1]))
|
||||||
|
best_loss = save_best_model(model, optimizer, avg_epoch_loss,
|
||||||
|
best_loss, OUT_PATH,
|
||||||
|
current_step, epoch)
|
||||||
|
|
||||||
#lr_scheduler.step(loss.data[0])
|
#lr_scheduler.step(loss.data[0])
|
||||||
tb.add_scalar('Time/EpochTime', epoch_time, epoch)
|
tb.add_scalar('Time/EpochTime', epoch_time, epoch)
|
||||||
epoch_time = 0
|
epoch_time = 0
|
||||||
|
|
|
@ -14,9 +14,8 @@ def prepare_data(inputs):
|
||||||
return np.stack([pad_data(x, max_len) for x in inputs])
|
return np.stack([pad_data(x, max_len) for x in inputs])
|
||||||
|
|
||||||
|
|
||||||
def pad_per_step(inputs, outputs_per_step):
|
def pad_per_step(inputs, pad_len):
|
||||||
"""zero pad inputs if it is not divisible with outputs_per_step (r)"""
|
|
||||||
timesteps = inputs.shape[-1]
|
timesteps = inputs.shape[-1]
|
||||||
return np.pad(inputs, [[0, 0], [0, 0],
|
return np.pad(inputs, [[0, 0], [0, 0],
|
||||||
[0, outputs_per_step - (timesteps % outputs_per_step)]],
|
[0, pad_len]],
|
||||||
mode='constant', constant_values=0.0)
|
mode='constant', constant_values=0.0)
|
||||||
|
|
|
@ -48,7 +48,8 @@ def copy_config_file(config_file, path):
|
||||||
shutil.copyfile(config_file, out_path)
|
shutil.copyfile(config_file, out_path)
|
||||||
|
|
||||||
|
|
||||||
def save_checkpoint(model, model_loss, best_loss, out_path):
|
def save_checkpoint(model, optimizer, model_loss, best_loss, out_path,
|
||||||
|
current_step, epoch):
|
||||||
checkpoint_path = 'checkpoint_{}.pth.tar'.format(current_step)
|
checkpoint_path = 'checkpoint_{}.pth.tar'.format(current_step)
|
||||||
checkpoint_path = os.path.join(out_path, checkpoint_path)
|
checkpoint_path = os.path.join(out_path, checkpoint_path)
|
||||||
print("\n | > Checkpoint saving : {}".format(checkpoint_path))
|
print("\n | > Checkpoint saving : {}".format(checkpoint_path))
|
||||||
|
@ -56,16 +57,24 @@ def save_checkpoint(model, model_loss, best_loss, out_path):
|
||||||
'optimizer': optimizer.state_dict(),
|
'optimizer': optimizer.state_dict(),
|
||||||
'step': current_step,
|
'step': current_step,
|
||||||
'epoch': epoch,
|
'epoch': epoch,
|
||||||
'total_loss': loss.data[0],
|
'linear_loss': model_loss,
|
||||||
'linear_loss': linear_loss.data[0],
|
|
||||||
'mel_loss': mel_loss.data[0],
|
|
||||||
'date': datetime.date.today().strftime("%B %d, %Y")}
|
'date': datetime.date.today().strftime("%B %d, %Y")}
|
||||||
torch.save(state, checkpoint_path)
|
torch.save(state, checkpoint_path)
|
||||||
|
|
||||||
|
|
||||||
|
def save_best_model(model, optimizer, model_loss, best_loss, out_path,
|
||||||
|
current_step, epoch):
|
||||||
if model_loss < best_loss:
|
if model_loss < best_loss:
|
||||||
|
state = {'model': model.state_dict(),
|
||||||
|
'optimizer': optimizer.state_dict(),
|
||||||
|
'step': current_step,
|
||||||
|
'epoch': epoch,
|
||||||
|
'linear_loss': model_loss,
|
||||||
|
'date': datetime.date.today().strftime("%B %d, %Y")}
|
||||||
best_loss = model_loss
|
best_loss = model_loss
|
||||||
bestmodel_path = 'best_model.pth.tar'.format(current_step)
|
bestmodel_path = 'best_model.pth.tar'
|
||||||
bestmodel_path = os.path.join(out_path, bestmodel_path)
|
bestmodel_path = os.path.join(out_path, bestmodel_path)
|
||||||
print("\n | > Best model saving with loss {} : {}".format(model_loss, bestmodel_path))
|
print("\n | > Best model saving with loss {0:.2f} : {1:}".format(model_loss, bestmodel_path))
|
||||||
torch.save(state, bestmodel_path)
|
torch.save(state, bestmodel_path)
|
||||||
return best_loss
|
return best_loss
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue