mirror of https://github.com/coqui-ai/TTS.git
Change scheduler to AnnealLR and catch audio synthesis errors at eval time
parent f5e87a0c70
commit bb526c296f

train.py | 21
train.py

@@ -16,7 +16,7 @@ from tensorboardX import SummaryWriter
 from utils.generic_utils import (
     synthesis, remove_experiment_folder, create_experiment_folder,
     save_checkpoint, save_best_model, load_config, lr_decay, count_parameters,
-    check_update, get_commit_hash, sequence_mask)
+    check_update, get_commit_hash, sequence_mask, AnnealLR)
 from utils.visual import plot_alignment, plot_spectrogram
 from models.tacotron import Tacotron
 from layers.losses import L1LossMasked
@@ -312,15 +312,13 @@ def evaluate(model, criterion, criterion_st, data_loader, ap, current_step):
     # test sentences
     ap.griffin_lim_iters = 60
     for idx, test_sentence in enumerate(test_sentences):
-        wav, linear_spec, alignments = synthesis(model, ap, test_sentence,
-                                                 use_cuda, c.text_cleaner)
-        try:
+        try:
+            wav, linear_spec, alignments = synthesis(model, ap, test_sentence,
+                                                     use_cuda, c.text_cleaner)
             wav_name = 'TestSentences/{}'.format(idx)
             tb.add_audio(
                 wav_name, wav, current_step, sample_rate=c.sample_rate)
-        except:
-            print(" !! Error as creating Test Sentence -", idx)
-            pass
-        align_img = alignments[0].data.cpu().numpy()
-        linear_spec = plot_spectrogram(linear_spec, ap)
-        align_img = plot_alignment(align_img)
+            align_img = alignments[0].data.cpu().numpy()
+            linear_spec = plot_spectrogram(linear_spec, ap)
+            align_img = plot_alignment(align_img)
@@ -328,6 +326,9 @@ def evaluate(model, criterion, criterion_st, data_loader, ap, current_step):
                           current_step)
             tb.add_figure('TestSentences/{}_Alignment'.format(idx), align_img,
                           current_step)
+        except:
+            print(" !! Error as creating Test Sentence -", idx)
+            pass
     return avg_linear_loss
 
 
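The two hunks above widen the per-sentence error handling in evaluate(): synthesis itself now runs inside the try, so a failing test sentence is skipped instead of aborting the whole evaluation. A minimal sketch of the resulting pattern, with hypothetical synthesize and log_audio callables standing in for synthesis() and tb.add_audio():

def evaluate_test_sentences(test_sentences, synthesize, log_audio):
    # One try/except per sentence: an error in synthesis or logging skips
    # that sentence instead of killing the evaluation loop.
    for idx, sentence in enumerate(test_sentences):
        try:
            wav = synthesize(sentence)
            log_audio('TestSentences/{}'.format(idx), wav)
        except Exception:
            # train.py uses a bare `except:`; Exception is shown here because
            # a bare except would also swallow KeyboardInterrupt.
            print(" !! Error as creating Test Sentence -", idx)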
@@ -337,14 +338,6 @@ def main(args):
     audio = importlib.import_module('utils.' + c.audio_processor)
     AudioProcessor = getattr(audio, 'AudioProcessor')
 
-    print(" > LR scheduler: {} ", c.lr_scheduler)
-    try:
-        scheduler = importlib.import_module('torch.optim.lr_scheduler')
-        scheduler = getattr(scheduler, c.lr_scheduler)
-    except:
-        scheduler = importlib.import_module('utils.generic_utils')
-        scheduler = getattr(scheduler, c.lr_scheduler)
-
     ap = AudioProcessor(
         sample_rate=c.sample_rate,
         num_mels=c.num_mels,
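The hunk above removes the config-driven scheduler lookup that this commit makes redundant. For reference, a minimal sketch of the removed pattern, with scheduler_name standing in for the c.lr_scheduler config value: resolve the class from torch's built-in lr_scheduler module first, then fall back to the project's own utils.generic_utils:

import importlib

def resolve_scheduler(scheduler_name):
    # Try torch's built-ins first (e.g. 'StepLR'), then project-defined
    # schedulers such as AnnealLR. The removed code used a bare `except:`;
    # AttributeError is what getattr raises when the name is missing.
    try:
        module = importlib.import_module('torch.optim.lr_scheduler')
        return getattr(module, scheduler_name)
    except AttributeError:
        module = importlib.import_module('utils.generic_utils')
        return getattr(module, scheduler_name)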
@@ -426,7 +419,7 @@ def main(args):
         criterion.cuda()
         criterion_st.cuda()
 
-    scheduler = StepLR(optimizer, step_size=c.decay_step, gamma=c.lr_decay)
+    scheduler = AnnealLR(optimizer, warmup_steps=c.warmup_steps)
     num_params = count_parameters(model)
     print(" | > Model has {} parameters".format(num_params), flush=True)
 
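A usage sketch for this swap, runnable only inside the repo so AnnealLR is importable; the model, optimizer, and loop are stand-ins, and warmup_steps=4000 is a hypothetical value for c.warmup_steps. Unlike the per-epoch StepLR it replaces, the Noam-style AnnealLR tracks the global step, so it is stepped once per training iteration:

import torch
from utils.generic_utils import AnnealLR  # as imported in train.py above

model = torch.nn.Linear(8, 8)                       # stand-in model
optimizer = torch.optim.Adam(model.parameters(), lr=0.002)
scheduler = AnnealLR(optimizer, warmup_steps=4000)  # c.warmup_steps in train.py

for step in range(100):                             # stand-in training loop
    optimizer.zero_grad()
    loss = model(torch.randn(4, 8)).pow(2).mean()
    loss.backward()
    optimizer.step()
    scheduler.step()  # per iteration: the warmup/decay curve follows the global step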
utils/generic_utils.py

@@ -143,16 +143,16 @@ def lr_decay(init_lr, global_step, warmup_steps):
 
 
 class AnnealLR(torch.optim.lr_scheduler._LRScheduler):
-    def __init__(self, optimizer, warmup_steps=0.1):
+    def __init__(self, optimizer, warmup_steps=0.1, last_epoch=-1):
         self.warmup_steps = float(warmup_steps)
         super(AnnealLR, self).__init__(optimizer, last_epoch)
 
     def get_lr(self):
+        step = max(self.last_epoch, 1)
         return [
-            base_lr * self.warmup_steps**0.5 * torch.min([
-                self.last_epoch * self.warmup_steps**-1.5, self.last_epoch**
-                -0.5
-            ]) for base_lr in self.base_lrs
+            base_lr * self.warmup_steps**0.5 * min(
+                step * self.warmup_steps**-1.5, step**-0.5)
+            for base_lr in self.base_lrs
         ]
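The rewritten get_lr() is the Noam warmup schedule from "Attention Is All You Need". The new body also fixes two bugs: torch.min([...]) over a plain Python list would raise a TypeError, and step**-0.5 at last_epoch == 0 would divide by zero, which max(self.last_epoch, 1) clamps away. A standalone numeric check of the formula, with hypothetical base_lr and warmup_steps values:

# lr = base_lr * warmup_steps**0.5 * min(step * warmup_steps**-1.5, step**-0.5)
# rises linearly for warmup_steps steps, peaks at base_lr, then decays as
# 1/sqrt(step). base_lr=0.002 and warmup_steps=4000 are illustrative only.
base_lr, warmup_steps = 0.002, 4000.0
for step in (1, 2000, 4000, 16000):
    lr = base_lr * warmup_steps**0.5 * min(step * warmup_steps**-1.5, step**-0.5)
    print(step, lr)
# 1     -> 5.0e-07  (start of warmup)
# 2000  -> 1.0e-03  (halfway through warmup)
# 4000  -> 2.0e-03  (peak: lr == base_lr when step == warmup_steps)
# 16000 -> 1.0e-03  (1/sqrt decay after warmup)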