mirror of https://github.com/coqui-ai/TTS.git
setup training scripts for computing phonemes before training optionally. And define data_loaders before starting training and re-use them instead of re-define for every train and eval calls. This is to enable better instance filtering based on input length.
This commit is contained in:
parent
7c3cdced1a
commit
affe1c1138
|
@ -59,6 +59,12 @@ def setup_loader(ap, r, is_val=False, verbose=False):
|
||||||
enable_eos_bos=c.enable_eos_bos_chars,
|
enable_eos_bos=c.enable_eos_bos_chars,
|
||||||
verbose=verbose,
|
verbose=verbose,
|
||||||
speaker_mapping=speaker_mapping if c.use_speaker_embedding and c.use_external_speaker_embedding_file else None)
|
speaker_mapping=speaker_mapping if c.use_speaker_embedding and c.use_external_speaker_embedding_file else None)
|
||||||
|
|
||||||
|
if c.use_phonemes and c.compute_input_seq_cache:
|
||||||
|
# precompute phonemes to have a better estimate of sequence lengths.
|
||||||
|
dataset.compute_input_seq(c.num_loader_workers)
|
||||||
|
dataset.sort_items()
|
||||||
|
|
||||||
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
|
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
|
||||||
loader = DataLoader(
|
loader = DataLoader(
|
||||||
dataset,
|
dataset,
|
||||||
|
@ -112,7 +118,7 @@ def format_data(data):
|
||||||
avg_text_length, avg_spec_length, attn_mask, item_idx
|
avg_text_length, avg_spec_length, attn_mask, item_idx
|
||||||
|
|
||||||
|
|
||||||
def data_depended_init(model, ap):
|
def data_depended_init(data_loader, model, ap):
|
||||||
"""Data depended initialization for activation normalization."""
|
"""Data depended initialization for activation normalization."""
|
||||||
if hasattr(model, 'module'):
|
if hasattr(model, 'module'):
|
||||||
for f in model.module.decoder.flows:
|
for f in model.module.decoder.flows:
|
||||||
|
@ -123,7 +129,6 @@ def data_depended_init(model, ap):
|
||||||
if getattr(f, "set_ddi", False):
|
if getattr(f, "set_ddi", False):
|
||||||
f.set_ddi(True)
|
f.set_ddi(True)
|
||||||
|
|
||||||
data_loader = setup_loader(ap, 1, is_val=False)
|
|
||||||
model.train()
|
model.train()
|
||||||
print(" > Data depended initialization ... ")
|
print(" > Data depended initialization ... ")
|
||||||
num_iter = 0
|
num_iter = 0
|
||||||
|
@ -152,10 +157,9 @@ def data_depended_init(model, ap):
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
def train(model, criterion, optimizer, scheduler,
|
def train(data_loader, model, criterion, optimizer, scheduler,
|
||||||
ap, global_step, epoch):
|
ap, global_step, epoch):
|
||||||
data_loader = setup_loader(ap, 1, is_val=False,
|
|
||||||
verbose=(epoch == 0))
|
|
||||||
model.train()
|
model.train()
|
||||||
epoch_time = 0
|
epoch_time = 0
|
||||||
keep_avg = KeepAverage()
|
keep_avg = KeepAverage()
|
||||||
|
@ -308,8 +312,7 @@ def train(model, criterion, optimizer, scheduler,
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def evaluate(model, criterion, ap, global_step, epoch):
|
def evaluate(data_loader, model, criterion, ap, global_step, epoch):
|
||||||
data_loader = setup_loader(ap, 1, is_val=True)
|
|
||||||
model.eval()
|
model.eval()
|
||||||
epoch_time = 0
|
epoch_time = 0
|
||||||
keep_avg = KeepAverage()
|
keep_avg = KeepAverage()
|
||||||
|
@ -533,14 +536,18 @@ def main(args): # pylint: disable=redefined-outer-name
|
||||||
if 'best_loss' not in locals():
|
if 'best_loss' not in locals():
|
||||||
best_loss = float('inf')
|
best_loss = float('inf')
|
||||||
|
|
||||||
|
# define dataloaders
|
||||||
|
train_loader = setup_loader(ap, 1, is_val=False, verbose=True)
|
||||||
|
eval_loader = setup_loader(ap, 1, is_val=True, verbose=True)
|
||||||
|
|
||||||
global_step = args.restore_step
|
global_step = args.restore_step
|
||||||
model = data_depended_init(model, ap)
|
model = data_depended_init(train_loader, model, ap)
|
||||||
for epoch in range(0, c.epochs):
|
for epoch in range(0, c.epochs):
|
||||||
c_logger.print_epoch_start(epoch, c.epochs)
|
c_logger.print_epoch_start(epoch, c.epochs)
|
||||||
train_avg_loss_dict, global_step = train(model, criterion, optimizer,
|
train_avg_loss_dict, global_step = train(train_loader, model, criterion, optimizer,
|
||||||
scheduler, ap, global_step,
|
scheduler, ap, global_step,
|
||||||
epoch)
|
epoch)
|
||||||
eval_avg_loss_dict = evaluate(model, criterion, ap, global_step, epoch)
|
eval_avg_loss_dict = evaluate(eval_loader , model, criterion, ap, global_step, epoch)
|
||||||
c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
|
c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
|
||||||
target_loss = train_avg_loss_dict['avg_loss']
|
target_loss = train_avg_loss_dict['avg_loss']
|
||||||
if c.run_eval:
|
if c.run_eval:
|
||||||
|
|
|
@ -61,6 +61,12 @@ def setup_loader(ap, r, is_val=False, verbose=False):
|
||||||
enable_eos_bos=c.enable_eos_bos_chars,
|
enable_eos_bos=c.enable_eos_bos_chars,
|
||||||
verbose=verbose,
|
verbose=verbose,
|
||||||
speaker_mapping=speaker_mapping if c.use_speaker_embedding and c.use_external_speaker_embedding_file else None)
|
speaker_mapping=speaker_mapping if c.use_speaker_embedding and c.use_external_speaker_embedding_file else None)
|
||||||
|
|
||||||
|
if c.use_phonemes and c.compute_input_seq_cache:
|
||||||
|
# precompute phonemes to have a better estimate of sequence lengths.
|
||||||
|
dataset.compute_input_seq(c.num_loader_workers)
|
||||||
|
dataset.sort_items()
|
||||||
|
|
||||||
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
|
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
|
||||||
loader = DataLoader(
|
loader = DataLoader(
|
||||||
dataset,
|
dataset,
|
||||||
|
@ -123,10 +129,8 @@ def format_data(data):
|
||||||
return text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, speaker_embeddings, max_text_length, max_spec_length
|
return text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, speaker_embeddings, max_text_length, max_spec_length
|
||||||
|
|
||||||
|
|
||||||
def train(model, criterion, optimizer, optimizer_st, scheduler,
|
def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler,
|
||||||
ap, global_step, epoch, scaler, scaler_st):
|
ap, global_step, epoch, scaler, scaler_st):
|
||||||
data_loader = setup_loader(ap, model.decoder.r, is_val=False,
|
|
||||||
verbose=(epoch == 0), speaker_mapping=speaker_mapping)
|
|
||||||
model.train()
|
model.train()
|
||||||
epoch_time = 0
|
epoch_time = 0
|
||||||
keep_avg = KeepAverage()
|
keep_avg = KeepAverage()
|
||||||
|
@ -324,8 +328,7 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def evaluate(model, criterion, ap, global_step, epoch):
|
def evaluate(data_loader, model, criterion, ap, global_step, epoch):
|
||||||
data_loader = setup_loader(ap, model.decoder.r, is_val=True, speaker_mapping=speaker_mapping)
|
|
||||||
model.eval()
|
model.eval()
|
||||||
epoch_time = 0
|
epoch_time = 0
|
||||||
keep_avg = KeepAverage()
|
keep_avg = KeepAverage()
|
||||||
|
@ -583,6 +586,13 @@ def main(args): # pylint: disable=redefined-outer-name
|
||||||
if 'best_loss' not in locals():
|
if 'best_loss' not in locals():
|
||||||
best_loss = float('inf')
|
best_loss = float('inf')
|
||||||
|
|
||||||
|
# define data loaders
|
||||||
|
train_loader = setup_loader(ap,
|
||||||
|
model.decoder.r,
|
||||||
|
is_val=False,
|
||||||
|
verbose=True)
|
||||||
|
eval_loader = setup_loader(ap, model.decoder.r, is_val=True)
|
||||||
|
|
||||||
global_step = args.restore_step
|
global_step = args.restore_step
|
||||||
for epoch in range(0, c.epochs):
|
for epoch in range(0, c.epochs):
|
||||||
c_logger.print_epoch_start(epoch, c.epochs)
|
c_logger.print_epoch_start(epoch, c.epochs)
|
||||||
|
@ -594,16 +604,27 @@ def main(args): # pylint: disable=redefined-outer-name
|
||||||
if c.bidirectional_decoder:
|
if c.bidirectional_decoder:
|
||||||
model.decoder_backward.set_r(r)
|
model.decoder_backward.set_r(r)
|
||||||
print("\n > Number of output frames:", model.decoder.r)
|
print("\n > Number of output frames:", model.decoder.r)
|
||||||
train_avg_loss_dict, global_step = train(model, criterion, optimizer,
|
train_avg_loss_dict, global_step = train(train_loader, model,
|
||||||
|
criterion, optimizer,
|
||||||
optimizer_st, scheduler, ap,
|
optimizer_st, scheduler, ap,
|
||||||
global_step, epoch, scaler, scaler_st)
|
global_step, epoch, scaler,
|
||||||
eval_avg_loss_dict = evaluate(model, criterion, ap, global_step, epoch)
|
scaler_st)
|
||||||
|
eval_avg_loss_dict = evaluate(eval_loader, model, criterion, ap,
|
||||||
|
global_step, epoch)
|
||||||
c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
|
c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
|
||||||
target_loss = train_avg_loss_dict['avg_postnet_loss']
|
target_loss = train_avg_loss_dict['avg_postnet_loss']
|
||||||
if c.run_eval:
|
if c.run_eval:
|
||||||
target_loss = eval_avg_loss_dict['avg_postnet_loss']
|
target_loss = eval_avg_loss_dict['avg_postnet_loss']
|
||||||
best_loss = save_best_model(target_loss, best_loss, model, optimizer, global_step, epoch, c.r,
|
best_loss = save_best_model(
|
||||||
OUT_PATH, scaler=scaler.state_dict() if c.mixed_precision else None)
|
target_loss,
|
||||||
|
best_loss,
|
||||||
|
model,
|
||||||
|
optimizer,
|
||||||
|
global_step,
|
||||||
|
epoch,
|
||||||
|
c.r,
|
||||||
|
OUT_PATH,
|
||||||
|
scaler=scaler.state_dict() if c.mixed_precision else None)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
|
@ -131,6 +131,7 @@
|
||||||
"batch_group_size": 4, //Number of batches to shuffle after bucketing.
|
"batch_group_size": 4, //Number of batches to shuffle after bucketing.
|
||||||
"min_seq_len": 6, // DATASET-RELATED: minimum text length to use in training
|
"min_seq_len": 6, // DATASET-RELATED: minimum text length to use in training
|
||||||
"max_seq_len": 153, // DATASET-RELATED: maximum text length
|
"max_seq_len": 153, // DATASET-RELATED: maximum text length
|
||||||
|
"compute_input_seq_cache": false, // if true, text sequences are computed before starting training. If phonemes are enabled, they are also computed at this stage.
|
||||||
|
|
||||||
// PATHS
|
// PATHS
|
||||||
"output_path": "/home/erogol/Models/LJSpeech/",
|
"output_path": "/home/erogol/Models/LJSpeech/",
|
||||||
|
|
|
@ -105,6 +105,7 @@
|
||||||
"min_seq_len": 3, // DATASET-RELATED: minimum text length to use in training
|
"min_seq_len": 3, // DATASET-RELATED: minimum text length to use in training
|
||||||
"max_seq_len": 500, // DATASET-RELATED: maximum text length
|
"max_seq_len": 500, // DATASET-RELATED: maximum text length
|
||||||
"compute_f0": false, // compute f0 values in data-loader
|
"compute_f0": false, // compute f0 values in data-loader
|
||||||
|
"compute_input_seq_cache": false, // if true, text sequences are computed before starting training. If phonemes are enabled, they are also computed at this stage.
|
||||||
|
|
||||||
// PATHS
|
// PATHS
|
||||||
"output_path": "/home/erogol/Models/LJSpeech/",
|
"output_path": "/home/erogol/Models/LJSpeech/",
|
||||||
|
|
|
@ -63,6 +63,7 @@ class MyDataset(Dataset):
|
||||||
self.enable_eos_bos = enable_eos_bos
|
self.enable_eos_bos = enable_eos_bos
|
||||||
self.speaker_mapping = speaker_mapping
|
self.speaker_mapping = speaker_mapping
|
||||||
self.verbose = verbose
|
self.verbose = verbose
|
||||||
|
self.input_seq_computed = False
|
||||||
if use_phonemes and not os.path.isdir(phoneme_cache_path):
|
if use_phonemes and not os.path.isdir(phoneme_cache_path):
|
||||||
os.makedirs(phoneme_cache_path, exist_ok=True)
|
os.makedirs(phoneme_cache_path, exist_ok=True)
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
|
@ -71,7 +72,6 @@ class MyDataset(Dataset):
|
||||||
if use_phonemes:
|
if use_phonemes:
|
||||||
print(" | > phoneme language: {}".format(phoneme_language))
|
print(" | > phoneme language: {}".format(phoneme_language))
|
||||||
print(" | > Number of instances : {}".format(len(self.items)))
|
print(" | > Number of instances : {}".format(len(self.items)))
|
||||||
self.sort_items()
|
|
||||||
|
|
||||||
def load_wav(self, filename):
|
def load_wav(self, filename):
|
||||||
audio = self.ap.load_wav(filename)
|
audio = self.ap.load_wav(filename)
|
||||||
|
|
Loading…
Reference in New Issue