mirror of https://github.com/coqui-ai/TTS.git
Enable CPU training and fix restore_epoch
This commit is contained in:
parent
e00c3f5044
commit
1ae3b3e442
@@ -321,6 +321,7 @@ class Decoder(nn.Module):
         if memory is not None:
             # Grouping multiple frames if necessary
             if memory.size(-1) == self.memory_dim:
+                memory = memory.contiguous()
                 memory = memory.view(B, memory.size(1) // self.r, -1)
                     " !! Dimension mismatch {} vs {} * {}".format(
                         memory.size(-1), self.memory_dim, self.r)
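For context (not part of the diff): the new memory.contiguous() call matters because Tensor.view() requires contiguous storage, and a tensor that was transposed or permuted upstream makes .view() raise. A minimal sketch of that failure mode, with made-up shapes standing in for the real decoder memory:

import torch

# Made-up sizes standing in for the decoder memory: (batch, frames, memory_dim).
B, T, D = 2, 6, 80
r = 2  # frames grouped per decoder step

memory = torch.randn(B, D, T).transpose(1, 2)   # the transpose leaves the tensor non-contiguous

try:
    memory.view(B, T // r, D * r)               # .view() needs contiguous storage
except RuntimeError as err:
    print("view() failed:", err)

# .contiguous() copies the data into a contiguous buffer, after which .view() works.
grouped = memory.contiguous().view(B, T // r, D * r)
print(grouped.shape)                            # torch.Size([2, 3, 160])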
train.py (16 lines changed)
@@ -75,8 +75,11 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
         mask = sequence_mask(text_lengths)

         # forward pass
-        mel_output, linear_output, alignments, stop_tokens = torch.nn.parallel.data_parallel(
-            model, (text_input, mel_input, mask))
+        if use_cuda:
+            mel_output, linear_output, alignments, stop_tokens = torch.nn.parallel.data_parallel(
+                model, (text_input, mel_input, mask))
+        else:
+            mel_output, linear_output, alignments, stop_tokens = model(text_input, mel_input, mask)

         # loss computation
         stop_loss = criterion_st(stop_tokens, stop_targets)
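For context (not part of the diff): torch.nn.parallel.data_parallel assumes CUDA devices, so the commit guards it behind use_cuda and falls back to a plain forward call on CPU. A hedged sketch of the same branch with a toy module; the real arguments (text_input, mel_input, mask) are replaced by a single placeholder tensor:

import torch
import torch.nn as nn

use_cuda = torch.cuda.is_available()

model = nn.Linear(10, 4)        # placeholder for the Tacotron model
batch = torch.randn(8, 10)      # placeholder for (text_input, mel_input, mask)

if use_cuda:
    model = model.cuda()
    # scatters the batch across all visible GPUs, runs the replicas, gathers the outputs
    output = torch.nn.parallel.data_parallel(model, batch.cuda())
else:
    # single-process CPU path: call the module directly
    output = model(batch)

print(output.shape)             # torch.Size([8, 4])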
@@ -274,7 +277,6 @@ def evaluate(model, criterion, criterion_st, data_loader, ap, current_step):
             const_spec = linear_output[idx].data.cpu().numpy()
             gt_spec = linear_input[idx].data.cpu().numpy()
             align_img = alignments[idx].data.cpu().numpy()
-            \
             const_spec = plot_spectrogram(const_spec, ap)
             gt_spec = plot_spectrogram(gt_spec, ap)
             align_img = plot_alignment(align_img)
@@ -422,14 +424,14 @@ def main(args):
         best_loss = checkpoint['linear_loss']
         args.restore_step = checkpoint['step']
     else:
-        args.restore_step = -1
+        args.restore_step = 0
         print("\n > Starting a new training", flush=True)

     if use_cuda:
         model = model.cuda()
         criterion.cuda()
         criterion_st.cuda()

-    scheduler = AnnealLR(optimizer, warmup_steps=c.warmup_steps, last_epoch=args.restore_step)
+    scheduler = AnnealLR(optimizer, warmup_steps=c.warmup_steps, last_epoch=(args.restore_step - 1))

     num_params = count_parameters(model)
     print(" | > Model has {} parameters".format(num_params), flush=True)
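For context (not part of the diff): PyTorch LR schedulers use last_epoch=-1 as the "start from scratch" sentinel and otherwise expect the index of the last completed step. With restore_step initialised to 0 for fresh runs, restore_step - 1 is exactly -1, while a resumed run passes checkpoint['step'] - 1. A sketch of that convention using the stock StepLR in place of this repo's AnnealLR:

import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Fresh run: restore_step == 0, so last_epoch == -1 and the schedule starts clean.
restore_step = 0
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer, step_size=100, gamma=0.5, last_epoch=restore_step - 1)
print(scheduler.last_epoch)   # 0 after the scheduler's implicit initial step

# On resume, restore_step would come from checkpoint['step'] and
# last_epoch = restore_step - 1 is the index of the last finished step.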
@@ -472,6 +474,11 @@ if __name__ == '__main__':
         type=bool,
         default=False,
         help='do not ask for git has before run.')
+    parser.add_argument(
+        '--data_path',
+        type=str,
+        default='',
+        help='data path to overrite config.json')
     args = parser.parse_args()

     # setup output paths and read configs
@@ -484,6 +491,9 @@ if __name__ == '__main__':
     os.makedirs(AUDIO_PATH, exist_ok=True)
     shutil.copyfile(args.config_path, os.path.join(OUT_PATH, 'config.json'))

+    if args.data_path != "":
+        c.data_path = args.data_path
+
     # setup tensorboard
     LOG_DIR = OUT_PATH
     tb = SummaryWriter(LOG_DIR)
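For context (not part of the diff): the last two hunks add an optional --data_path flag and use it to overwrite the path read from config.json whenever the flag is non-empty. A self-contained sketch of the same override pattern; the SimpleNamespace stands in for the repo's parsed config object c:

import argparse
from types import SimpleNamespace

c = SimpleNamespace(data_path='/datasets/from_config')   # stand-in for the loaded config.json

parser = argparse.ArgumentParser()
parser.add_argument('--data_path', type=str, default='',
                    help='data path that overrides the value in config.json')
args = parser.parse_args(['--data_path', '/datasets/from_cli'])  # simulated CLI input

# an empty string means the flag was not given; keep the config value in that case
if args.data_path != "":
    c.data_path = args.data_path

print(c.data_path)   # /datasets/from_cli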