mirror of https://github.com/coqui-ai/TTS.git
Formatting, fixing import statements, logging learning rate, remove optimizer restore cuda call
This commit is contained in:
parent
440f51b61d
commit
6550db5251
77
train.py
77
train.py
|
@ -14,9 +14,9 @@ from torch.utils.data import DataLoader
|
||||||
from tensorboardX import SummaryWriter
|
from tensorboardX import SummaryWriter
|
||||||
|
|
||||||
from utils.generic_utils import (
|
from utils.generic_utils import (
|
||||||
remove_experiment_folder, create_experiment_folder,
|
remove_experiment_folder, create_experiment_folder, save_checkpoint,
|
||||||
save_checkpoint, save_best_model, load_config, lr_decay, count_parameters,
|
save_best_model, load_config, lr_decay, count_parameters, check_update,
|
||||||
check_update, get_commit_hash, sequence_mask, AnnealLR)
|
get_commit_hash, sequence_mask, AnnealLR)
|
||||||
from utils.visual import plot_alignment, plot_spectrogram
|
from utils.visual import plot_alignment, plot_spectrogram
|
||||||
from models.tacotron import Tacotron
|
from models.tacotron import Tacotron
|
||||||
from layers.losses import L1LossMasked
|
from layers.losses import L1LossMasked
|
||||||
|
@ -25,6 +25,8 @@ from utils.synthesis import synthesis
|
||||||
|
|
||||||
torch.manual_seed(1)
|
torch.manual_seed(1)
|
||||||
use_cuda = torch.cuda.is_available()
|
use_cuda = torch.cuda.is_available()
|
||||||
|
print(" > Using CUDA: ", use_cuda)
|
||||||
|
print(" > Number of GPUs: ", torch.cuda.device_count())
|
||||||
|
|
||||||
|
|
||||||
def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
|
def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
|
||||||
|
@ -36,7 +38,8 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
|
||||||
avg_stop_loss = 0
|
avg_stop_loss = 0
|
||||||
avg_step_time = 0
|
avg_step_time = 0
|
||||||
print(" | > Epoch {}/{}".format(epoch, c.epochs), flush=True)
|
print(" | > Epoch {}/{}".format(epoch, c.epochs), flush=True)
|
||||||
n_priority_freq = int(3000 / (c.audio['sample_rate'] * 0.5) * c.audio['num_freq'])
|
n_priority_freq = int(
|
||||||
|
3000 / (c.audio['sample_rate'] * 0.5) * c.audio['num_freq'])
|
||||||
batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
|
batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
|
||||||
for num_iter, data in enumerate(data_loader):
|
for num_iter, data in enumerate(data_loader):
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
@ -80,7 +83,8 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
|
||||||
mel_output, linear_output, alignments, stop_tokens = torch.nn.parallel.data_parallel(
|
mel_output, linear_output, alignments, stop_tokens = torch.nn.parallel.data_parallel(
|
||||||
model, (text_input, mel_input, mask))
|
model, (text_input, mel_input, mask))
|
||||||
else:
|
else:
|
||||||
mel_output, linear_output, alignments, stop_tokens = model(text_input, mel_input, mask)
|
mel_output, linear_output, alignments, stop_tokens = model(
|
||||||
|
text_input, mel_input, mask)
|
||||||
|
|
||||||
# loss computation
|
# loss computation
|
||||||
stop_loss = criterion_st(stop_tokens, stop_targets)
|
stop_loss = criterion_st(stop_tokens, stop_targets)
|
||||||
|
@ -96,6 +100,7 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
|
||||||
# custom weight decay
|
# custom weight decay
|
||||||
for group in optimizer.param_groups:
|
for group in optimizer.param_groups:
|
||||||
for param in group['params']:
|
for param in group['params']:
|
||||||
|
current_lr = group['lr']
|
||||||
param.data = param.data.add(-c.wd * group['lr'], param.data)
|
param.data = param.data.add(-c.wd * group['lr'], param.data)
|
||||||
grad_norm, skip_flag = check_update(model, 1)
|
grad_norm, skip_flag = check_update(model, 1)
|
||||||
if skip_flag:
|
if skip_flag:
|
||||||
|
@ -106,13 +111,14 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
|
||||||
|
|
||||||
# backpass and check the grad norm for stop loss
|
# backpass and check the grad norm for stop loss
|
||||||
stop_loss.backward()
|
stop_loss.backward()
|
||||||
|
# custom weight decay
|
||||||
for group in optimizer_st.param_groups:
|
for group in optimizer_st.param_groups:
|
||||||
for param in group['params']:
|
for param in group['params']:
|
||||||
param.data = param.data.add(-c.wd * group['lr'], param.data)
|
param.data = param.data.add(-c.wd * group['lr'], param.data)
|
||||||
grad_norm_st, skip_flag = check_update(model.decoder.stopnet, 0.5)
|
grad_norm_st, skip_flag = check_update(model.decoder.stopnet, 0.5)
|
||||||
if skip_flag:
|
if skip_flag:
|
||||||
optimizer_st.zero_grad()
|
optimizer_st.zero_grad()
|
||||||
print(" | | > Iteration skipped fro stopnet!!")
|
print(" | > Iteration skipped fro stopnet!!")
|
||||||
continue
|
continue
|
||||||
optimizer_st.step()
|
optimizer_st.step()
|
||||||
|
|
||||||
|
@ -121,12 +127,12 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
|
||||||
|
|
||||||
if current_step % c.print_step == 0:
|
if current_step % c.print_step == 0:
|
||||||
print(
|
print(
|
||||||
" | | > Step:{}/{} GlobalStep:{} TotalLoss:{:.5f} LinearLoss:{:.5f} "
|
" | > Step:{}/{} GlobalStep:{} TotalLoss:{:.5f} LinearLoss:{:.5f} "
|
||||||
"MelLoss:{:.5f} StopLoss:{:.5f} GradNorm:{:.5f} "
|
"MelLoss:{:.5f} StopLoss:{:.5f} GradNorm:{:.5f} "
|
||||||
"GradNormST:{:.5f} StepTime:{:.2f}".format(
|
"GradNormST:{:.5f} StepTime:{:.2f} LR:{:.6f}".format(
|
||||||
num_iter, batch_n_iter, current_step, loss.item(),
|
num_iter, batch_n_iter, current_step, loss.item(),
|
||||||
linear_loss.item(), mel_loss.item(), stop_loss.item(),
|
linear_loss.item(), mel_loss.item(), stop_loss.item(),
|
||||||
grad_norm, grad_norm_st, step_time),
|
grad_norm, grad_norm_st, step_time, current_lr),
|
||||||
flush=True)
|
flush=True)
|
||||||
|
|
||||||
avg_linear_loss += linear_loss.item()
|
avg_linear_loss += linear_loss.item()
|
||||||
|
@ -186,7 +192,7 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
|
||||||
|
|
||||||
# print epoch stats
|
# print epoch stats
|
||||||
print(
|
print(
|
||||||
" | | > EPOCH END -- GlobalStep:{} AvgTotalLoss:{:.5f} "
|
" | > EPOCH END -- GlobalStep:{} AvgTotalLoss:{:.5f} "
|
||||||
"AvgLinearLoss:{:.5f} AvgMelLoss:{:.5f} "
|
"AvgLinearLoss:{:.5f} AvgMelLoss:{:.5f} "
|
||||||
"AvgStopLoss:{:.5f} EpochTime:{:.2f} "
|
"AvgStopLoss:{:.5f} EpochTime:{:.2f} "
|
||||||
"AvgStepTime:{:.2f}".format(current_step, avg_total_loss,
|
"AvgStepTime:{:.2f}".format(current_step, avg_total_loss,
|
||||||
|
@ -217,7 +223,8 @@ def evaluate(model, criterion, criterion_st, data_loader, ap, current_step):
|
||||||
"I'm sorry Dave. I'm afraid I can't do that.",
|
"I'm sorry Dave. I'm afraid I can't do that.",
|
||||||
"This cake is great. It's so delicious and moist."
|
"This cake is great. It's so delicious and moist."
|
||||||
]
|
]
|
||||||
n_priority_freq = int(3000 / (c.audio['sample_rate'] * 0.5) * c.audio['num_freq'])
|
n_priority_freq = int(
|
||||||
|
3000 / (c.audio['sample_rate'] * 0.5) * c.audio['num_freq'])
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
if data_loader is not None:
|
if data_loader is not None:
|
||||||
for num_iter, data in enumerate(data_loader):
|
for num_iter, data in enumerate(data_loader):
|
||||||
|
@ -263,7 +270,7 @@ def evaluate(model, criterion, criterion_st, data_loader, ap, current_step):
|
||||||
|
|
||||||
if num_iter % c.print_step == 0:
|
if num_iter % c.print_step == 0:
|
||||||
print(
|
print(
|
||||||
" | | > TotalLoss: {:.5f} LinearLoss: {:.5f} MelLoss:{:.5f} "
|
" | > TotalLoss: {:.5f} LinearLoss: {:.5f} MelLoss:{:.5f} "
|
||||||
"StopLoss: {:.5f} ".format(loss.item(),
|
"StopLoss: {:.5f} ".format(loss.item(),
|
||||||
linear_loss.item(),
|
linear_loss.item(),
|
||||||
mel_loss.item(),
|
mel_loss.item(),
|
||||||
|
@ -322,8 +329,8 @@ def evaluate(model, criterion, criterion_st, data_loader, ap, current_step):
|
||||||
ap.griffin_lim_iters = 60
|
ap.griffin_lim_iters = 60
|
||||||
for idx, test_sentence in enumerate(test_sentences):
|
for idx, test_sentence in enumerate(test_sentences):
|
||||||
try:
|
try:
|
||||||
wav, alignment, linear_spec, stop_tokens = synthesis(model, test_sentence, c,
|
wav, alignment, linear_spec, stop_tokens = synthesis(
|
||||||
use_cuda, ap)
|
model, test_sentence, c, use_cuda, ap)
|
||||||
|
|
||||||
file_path = os.path.join(AUDIO_PATH, str(current_step))
|
file_path = os.path.join(AUDIO_PATH, str(current_step))
|
||||||
os.makedirs(file_path, exist_ok=True)
|
os.makedirs(file_path, exist_ok=True)
|
||||||
|
@ -333,10 +340,12 @@ def evaluate(model, criterion, criterion_st, data_loader, ap, current_step):
|
||||||
|
|
||||||
wav_name = 'TestSentences/{}'.format(idx)
|
wav_name = 'TestSentences/{}'.format(idx)
|
||||||
tb.add_audio(
|
tb.add_audio(
|
||||||
wav_name, wav, current_step, sample_rate=c.audio['sample_rate'])
|
wav_name,
|
||||||
align_img = alignments[0].data.cpu().numpy()
|
wav,
|
||||||
|
current_step,
|
||||||
|
sample_rate=c.audio['sample_rate'])
|
||||||
linear_spec = plot_spectrogram(linear_spec, ap)
|
linear_spec = plot_spectrogram(linear_spec, ap)
|
||||||
align_img = plot_alignment(align_img)
|
align_img = plot_alignment(alignment)
|
||||||
tb.add_figure('TestSentences/{}_Spectrogram'.format(idx),
|
tb.add_figure('TestSentences/{}_Spectrogram'.format(idx),
|
||||||
linear_spec, current_step)
|
linear_spec, current_step)
|
||||||
tb.add_figure('TestSentences/{}_Alignment'.format(idx), align_img,
|
tb.add_figure('TestSentences/{}_Alignment'.format(idx), align_img,
|
||||||
|
@ -348,13 +357,15 @@ def evaluate(model, criterion, criterion_st, data_loader, ap, current_step):
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
|
# Conditional imports
|
||||||
preprocessor = importlib.import_module('datasets.preprocess')
|
preprocessor = importlib.import_module('datasets.preprocess')
|
||||||
preprocessor = getattr(preprocessor, c.dataset.lower())
|
preprocessor = getattr(preprocessor, c.dataset.lower())
|
||||||
MyDataset = importlib.import_module('datasets.'+c.data_loader)
|
MyDataset = importlib.import_module('datasets.' + c.data_loader)
|
||||||
MyDataset = getattr(MyDataset, "MyDataset")
|
MyDataset = getattr(MyDataset, "MyDataset")
|
||||||
audio = importlib.import_module('utils.' + c.audio['audio_processor'])
|
audio = importlib.import_module('utils.' + c.audio['audio_processor'])
|
||||||
AudioProcessor = getattr(audio, 'AudioProcessor')
|
AudioProcessor = getattr(audio, 'AudioProcessor')
|
||||||
|
|
||||||
|
# Audio processor
|
||||||
ap = AudioProcessor(**c.audio)
|
ap = AudioProcessor(**c.audio)
|
||||||
|
|
||||||
# Setup the dataset
|
# Setup the dataset
|
||||||
|
@ -365,7 +376,7 @@ def main(args):
|
||||||
c.text_cleaner,
|
c.text_cleaner,
|
||||||
preprocessor=preprocessor,
|
preprocessor=preprocessor,
|
||||||
ap=ap,
|
ap=ap,
|
||||||
batch_group_size=8*c.batch_size,
|
batch_group_size=8 * c.batch_size,
|
||||||
min_seq_len=c.min_seq_len)
|
min_seq_len=c.min_seq_len)
|
||||||
|
|
||||||
train_loader = DataLoader(
|
train_loader = DataLoader(
|
||||||
|
@ -379,7 +390,13 @@ def main(args):
|
||||||
|
|
||||||
if c.run_eval:
|
if c.run_eval:
|
||||||
val_dataset = MyDataset(
|
val_dataset = MyDataset(
|
||||||
c.data_path, c.meta_file_val, c.r, c.text_cleaner, preprocessor=preprocessor, ap=ap, batch_group_size=0)
|
c.data_path,
|
||||||
|
c.meta_file_val,
|
||||||
|
c.r,
|
||||||
|
c.text_cleaner,
|
||||||
|
preprocessor=preprocessor,
|
||||||
|
ap=ap,
|
||||||
|
batch_group_size=0)
|
||||||
|
|
||||||
val_loader = DataLoader(
|
val_loader = DataLoader(
|
||||||
val_dataset,
|
val_dataset,
|
||||||
|
@ -410,11 +427,6 @@ def main(args):
|
||||||
criterion.cuda()
|
criterion.cuda()
|
||||||
criterion_st.cuda()
|
criterion_st.cuda()
|
||||||
optimizer.load_state_dict(checkpoint['optimizer'])
|
optimizer.load_state_dict(checkpoint['optimizer'])
|
||||||
# optimizer_st.load_state_dict(checkpoint['optimizer_st'])
|
|
||||||
for state in optimizer.state.values():
|
|
||||||
for k, v in state.items():
|
|
||||||
if torch.is_tensor(v):
|
|
||||||
state[k] = v.cuda()
|
|
||||||
print(
|
print(
|
||||||
" > Model restored from step %d" % checkpoint['step'], flush=True)
|
" > Model restored from step %d" % checkpoint['step'], flush=True)
|
||||||
start_epoch = checkpoint['step'] // len(train_loader)
|
start_epoch = checkpoint['step'] // len(train_loader)
|
||||||
|
@ -428,7 +440,14 @@ def main(args):
|
||||||
criterion.cuda()
|
criterion.cuda()
|
||||||
criterion_st.cuda()
|
criterion_st.cuda()
|
||||||
|
|
||||||
scheduler = AnnealLR(optimizer, warmup_steps=c.warmup_steps, last_epoch=args.restore_step - 1)
|
if c.lr_decay:
|
||||||
|
scheduler = AnnealLR(
|
||||||
|
optimizer,
|
||||||
|
warmup_steps=c.warmup_steps,
|
||||||
|
last_epoch=args.restore_step - 1)
|
||||||
|
else:
|
||||||
|
scheduler = None
|
||||||
|
|
||||||
num_params = count_parameters(model)
|
num_params = count_parameters(model)
|
||||||
print(" | > Model has {} parameters".format(num_params), flush=True)
|
print(" | > Model has {} parameters".format(num_params), flush=True)
|
||||||
|
|
||||||
|
@ -472,11 +491,7 @@ if __name__ == '__main__':
|
||||||
default=False,
|
default=False,
|
||||||
help='do not ask for git has before run.')
|
help='do not ask for git has before run.')
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--data_path',
|
'--data_path', type=str, help='dataset path.', default='')
|
||||||
type=str,
|
|
||||||
help='dataset path.',
|
|
||||||
default=''
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# setup output paths and read configs
|
# setup output paths and read configs
|
||||||
|
|
|
@ -11,7 +11,7 @@ import subprocess
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from torch.autograd import Variable
|
from torch.autograd import Variable
|
||||||
from TTS.utils.text import text_to_sequence
|
from utils.text import text_to_sequence
|
||||||
|
|
||||||
|
|
||||||
class AttrDict(dict):
|
class AttrDict(dict):
|
||||||
|
|
|
@ -5,7 +5,7 @@ Defines the set of symbols used in text input to the model.
|
||||||
The default is a set of ASCII characters that works well for English or text that has been run
|
The default is a set of ASCII characters that works well for English or text that has been run
|
||||||
through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details.
|
through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details.
|
||||||
'''
|
'''
|
||||||
from TTS.utils.text import cmudict
|
from utils.text import cmudict
|
||||||
|
|
||||||
_pad = '_'
|
_pad = '_'
|
||||||
_eos = '~'
|
_eos = '~'
|
||||||
|
|
Loading…
Reference in New Issue