mirror of https://github.com/coqui-ai/TTS.git
make speaker_mapping a global variable to prevent reload. Fix glow-tts training
This commit is contained in:
parent a757b203bc
commit 7c3cdced1a
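Read together, the hunks below make two changes: format_data() previously re-read the speaker mapping from disk via load_speaker_mapping(OUT_PATH) on every batch, with speaker_mapping threaded through every function signature; the commit instead assigns it once as a module-level global in main(). It also removes a duplicated optimizer step in the Glow-TTS train loop. A minimal sketch of the caching pattern, with a local stub standing in for the real TTS.tts.utils.speakers.load_speaker_mapping helper (the stub and file layout are assumptions for illustration):

import json
import os

speaker_mapping = None  # module-level cache, assigned exactly once in main()

def load_speaker_mapping(out_path):
    # stub for TTS.tts.utils.speakers.load_speaker_mapping
    with open(os.path.join(out_path, "speakers.json"), "r") as f:
        return json.load(f)

def format_data(speaker_names):
    # after the commit: index the global; no per-batch disk read
    return [speaker_mapping[name] for name in speaker_names]

def main(out_path):
    global speaker_mapping
    speaker_mapping = load_speaker_mapping(out_path)  # one read, reused everywhere
    print(format_data(["ljspeech"]))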
TTS/bin/train_glow_tts.py
@@ -7,41 +7,37 @@ import os
 import sys
 import time
 import traceback
+from random import randrange
+
 import torch
-from random import randrange
+# DISTRIBUTED
+from torch.nn.parallel import DistributedDataParallel as DDP_th
 from torch.utils.data import DataLoader
+from torch.utils.data.distributed import DistributedSampler
 from TTS.tts.datasets.preprocess import load_meta_data
 from TTS.tts.datasets.TTSDataset import MyDataset
 from TTS.tts.layers.losses import GlowTTSLoss
-from TTS.tts.utils.generic_utils import setup_model, check_config_tts
+from TTS.tts.utils.generic_utils import check_config_tts, setup_model
 from TTS.tts.utils.io import save_best_model, save_checkpoint
 from TTS.tts.utils.measures import alignment_diagonal_score
-from TTS.tts.utils.speakers import parse_speakers, load_speaker_mapping
+from TTS.tts.utils.speakers import parse_speakers
 from TTS.tts.utils.synthesis import synthesis
 from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
 from TTS.utils.audio import AudioProcessor
 from TTS.utils.console_logger import ConsoleLogger
+from TTS.utils.distribute import init_distributed, reduce_tensor
 from TTS.utils.generic_utils import (KeepAverage, count_parameters,
                                      create_experiment_folder, get_git_branch,
                                      remove_experiment_folder, set_init_dict)
 from TTS.utils.io import copy_config_file, load_config
 from TTS.utils.radam import RAdam
 from TTS.utils.tensorboard_logger import TensorboardLogger
-from TTS.utils.training import (NoamLR, check_update,
-                                setup_torch_training_env)
+from TTS.utils.training import NoamLR, setup_torch_training_env
 
-# DISTRIBUTED
-from torch.nn.parallel import DistributedDataParallel as DDP_th
-from torch.utils.data.distributed import DistributedSampler
-from TTS.utils.distribute import init_distributed, reduce_tensor
-
 
 use_cuda, num_gpus = setup_torch_training_env(True, False)
 
 
-def setup_loader(ap, r, is_val=False, verbose=False, speaker_mapping=None):
+def setup_loader(ap, r, is_val=False, verbose=False):
     if is_val and not c.run_eval:
         loader = None
     else:
@@ -78,29 +74,29 @@ def setup_loader(ap, r, is_val=False, verbose=False, speaker_mapping=None):
 
 
 def format_data(data):
-    if c.use_speaker_embedding:
-        speaker_mapping = load_speaker_mapping(OUT_PATH)
-
     # setup input data
     text_input = data[0]
     text_lengths = data[1]
     speaker_names = data[2]
     mel_input = data[4].permute(0, 2, 1)  # B x D x T
     mel_lengths = data[5]
-    attn_mask = data[8]
+    item_idx = data[7]
+    attn_mask = data[9]
     avg_text_length = torch.mean(text_lengths.float())
     avg_spec_length = torch.mean(mel_lengths.float())
 
     if c.use_speaker_embedding:
         if c.use_external_speaker_embedding_file:
-            speaker_ids = data[8]
+            # return precomputed embedding vector
+            speaker_c = data[8]
         else:
-            speaker_ids = [
+            # return speaker_id to be used by an embedding layer
+            speaker_c = [
                 speaker_mapping[speaker_name] for speaker_name in speaker_names
             ]
-            speaker_ids = torch.LongTensor(speaker_ids)
+            speaker_c = torch.LongTensor(speaker_c)
     else:
-        speaker_ids = None
+        speaker_c = None
 
     # dispatch data to GPU
     if use_cuda:
@@ -108,15 +104,15 @@ def format_data(data):
         text_lengths = text_lengths.cuda(non_blocking=True)
         mel_input = mel_input.cuda(non_blocking=True)
         mel_lengths = mel_lengths.cuda(non_blocking=True)
-        if speaker_ids is not None:
-            speaker_ids = speaker_ids.cuda(non_blocking=True)
+        if speaker_c is not None:
+            speaker_c = speaker_c.cuda(non_blocking=True)
         if attn_mask is not None:
             attn_mask = attn_mask.cuda(non_blocking=True)
-    return text_input, text_lengths, mel_input, mel_lengths, speaker_ids,\
-        avg_text_length, avg_spec_length, attn_mask
+    return text_input, text_lengths, mel_input, mel_lengths, speaker_c,\
+        avg_text_length, avg_spec_length, attn_mask, item_idx
 
 
-def data_depended_init(model, ap, speaker_mapping=None):
+def data_depended_init(model, ap):
     """Data depended initialization for activation normalization."""
     if hasattr(model, 'module'):
         for f in model.module.decoder.flows:
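The return signature above grows from eight values to nine, and attn_mask moves from data[8] to data[9]. A hedged reading of the batch layout, inferred only from this diff (the TTSDataset collate code is not shown, so treat the labels as assumptions):

# Inferred layout of `data` as unpacked by format_data() after this commit:
#   data[0] -> text_input        data[5] -> mel_lengths
#   data[1] -> text_lengths      data[7] -> item_idx (per-item identifiers)
#   data[2] -> speaker_names     data[8] -> speaker ids or external embeddings
#   data[4] -> mel_input         data[9] -> attn_mask (was data[8] before)
text_input, text_lengths, mel_input, mel_lengths, speaker_c, \
    avg_text_length, avg_spec_length, attn_mask, item_idx = format_data(data)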
@@ -127,20 +123,23 @@ def data_depended_init(model, ap, speaker_mapping=None):
         if getattr(f, "set_ddi", False):
             f.set_ddi(True)
 
-    data_loader = setup_loader(ap, 1, is_val=False, speaker_mapping=speaker_mapping)
+    data_loader = setup_loader(ap, 1, is_val=False)
     model.train()
     print(" > Data depended initialization ... ")
+    num_iter = 0
     with torch.no_grad():
         for _, data in enumerate(data_loader):
 
             # format data
-            text_input, text_lengths, mel_input, mel_lengths, speaker_ids,\
-                _, _, attn_mask = format_data(data)
+            text_input, text_lengths, mel_input, mel_lengths, spekaer_embed,\
+                _, _, attn_mask, item_idx = format_data(data)
 
             # forward pass model
             _ = model.forward(
-                text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_ids)
+                text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=spekaer_embed)
+            if num_iter == c.data_dep_init_iter:
                 break
+            num_iter += 1
 
     if hasattr(model, 'module'):
         for f in model.module.decoder.flows:
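For readers unfamiliar with the docstring's "data depended initialization": Glow-style activation-norm layers set their scale and bias from the statistics of the first batches they see, which is why this loop now runs a fixed number of warm-up batches (the new c.data_dep_init_iter config field) with set_ddi(True) before real training starts. A toy illustration of the mechanism, not the library's actual ActNorm implementation:

import torch
import torch.nn as nn

class ActNorm(nn.Module):
    """Toy activation normalization with data-dependent init (Glow-style)."""
    def __init__(self, channels):
        super().__init__()
        self.logs = nn.Parameter(torch.zeros(1, channels, 1))
        self.bias = nn.Parameter(torch.zeros(1, channels, 1))
        self.ddi = False

    def set_ddi(self, ddi):
        self.ddi = ddi

    def forward(self, x):  # x: [B, C, T]
        if self.ddi:
            # initialize from batch statistics so the output starts
            # roughly zero-mean / unit-variance, then train normally
            with torch.no_grad():
                m = x.mean(dim=(0, 2), keepdim=True)
                v = x.std(dim=(0, 2), keepdim=True) + 1e-6
                self.bias.copy_(-m / v)
                self.logs.copy_(torch.log(1.0 / v))
            self.ddi = False
        return x * torch.exp(self.logs) + self.bias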
@@ -154,9 +153,9 @@ def data_depended_init(model, ap, speaker_mapping=None):
 
 
 def train(model, criterion, optimizer, scheduler,
-          ap, global_step, epoch, speaker_mapping=None):
+          ap, global_step, epoch):
     data_loader = setup_loader(ap, 1, is_val=False,
-                               verbose=(epoch == 0), speaker_mapping=speaker_mapping)
+                               verbose=(epoch == 0))
     model.train()
     epoch_time = 0
     keep_avg = KeepAverage()
@@ -172,8 +171,8 @@ def train(model, criterion, optimizer, scheduler,
         start_time = time.time()
 
         # format data
-        text_input, text_lengths, mel_input, mel_lengths, speaker_ids,\
-            avg_text_length, avg_spec_length, attn_mask = format_data(data)
+        text_input, text_lengths, mel_input, mel_lengths, speaker_c,\
+            avg_text_length, avg_spec_length, attn_mask, item_idx = format_data(data)
 
         loader_time = time.time() - end_time
 
@@ -203,10 +202,6 @@ def train(model, criterion, optimizer, scheduler,
                                        c.grad_clip)
         optimizer.step()
 
-
-        grad_norm, _ = check_update(model, c.grad_clip, ignore_stopnet=True)
-        optimizer.step()
-
         # setup lr
         if c.noam_schedule:
             scheduler.step()
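This hunk appears to be the "Fix glow-tts training" half of the commit message: the old loop called optimizer.step() a second time after a leftover check_update(), so each batch applied the (re-clipped) gradients twice. The corrected per-batch order, sketched with names reused from the hunk above and the stock PyTorch clipping helper in place of TTS's check_update wrapper:

optimizer.zero_grad()
loss_dict['loss'].backward()
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), c.grad_clip)
optimizer.step()  # exactly once per batch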
@@ -215,7 +210,7 @@ def train(model, criterion, optimizer, scheduler,
         current_lr = optimizer.param_groups[0]['lr']
 
         # compute alignment error (the lower the better )
-        align_error = 1 - alignment_diagonal_score(alignments)
+        align_error = 1 - alignment_diagonal_score(alignments, binary=True)
         loss_dict['align_error'] = align_error
 
         step_time = time.time() - start_time
@@ -276,7 +271,7 @@ def train(model, criterion, optimizer, scheduler,
 
             # Diagnostic visualizations
             # direct pass on model for spec predictions
-            target_speaker = None if speaker_ids is None else speaker_ids[:1]
+            target_speaker = None if speaker_c is None else speaker_c[:1]
             spec_pred, *_ = model.inference(text_input[:1], text_lengths[:1], g=target_speaker)
             spec_pred = spec_pred.permute(0, 2, 1)
             gt_spec = mel_input.permute(0, 2, 1)
@@ -313,8 +308,8 @@ def train(model, criterion, optimizer, scheduler,
 
 
 @torch.no_grad()
-def evaluate(model, criterion, ap, global_step, epoch, speaker_mapping):
-    data_loader = setup_loader(ap, 1, is_val=True, speaker_mapping=speaker_mapping)
+def evaluate(model, criterion, ap, global_step, epoch):
+    data_loader = setup_loader(ap, 1, is_val=True)
     model.eval()
     epoch_time = 0
     keep_avg = KeepAverage()
@@ -324,12 +319,12 @@ def evaluate(model, criterion, ap, global_step, epoch, speaker_mapping):
         start_time = time.time()
 
         # format data
-        text_input, text_lengths, mel_input, mel_lengths, speaker_ids,\
-            _, _, attn_mask = format_data(data)
+        text_input, text_lengths, mel_input, mel_lengths, speaker_c,\
+            _, _, attn_mask, item_idx = format_data(data)
 
         # forward pass model
         z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model.forward(
-            text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_ids)
+            text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_c)
 
         # compute loss
         loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths,
@@ -370,7 +365,7 @@ def evaluate(model, criterion, ap, global_step, epoch, speaker_mapping):
     if args.rank == 0:
         # Diagnostic visualizations
         # direct pass on model for spec predictions
-        target_speaker = None if speaker_ids is None else speaker_ids[:1]
+        target_speaker = None if speaker_c is None else speaker_c[:1]
         if hasattr(model, 'module'):
             spec_pred, *_ = model.module.inference(text_input[:1], text_lengths[:1], g=target_speaker)
         else:
@@ -464,7 +459,7 @@ def evaluate(model, criterion, ap, global_step, epoch, speaker_mapping):
 # FIXME: move args definition/parsing inside of main?
 def main(args):  # pylint: disable=redefined-outer-name
     # pylint: disable=global-variable-undefined
-    global meta_data_train, meta_data_eval, symbols, phonemes
+    global meta_data_train, meta_data_eval, symbols, phonemes, speaker_mapping
     # Audio processor
     ap = AudioProcessor(**c.audio)
     if 'characters' in c.keys():
@@ -539,13 +534,13 @@ def main(args):  # pylint: disable=redefined-outer-name
     best_loss = float('inf')
 
     global_step = args.restore_step
-    model = data_depended_init(model, ap, speaker_mapping)
+    model = data_depended_init(model, ap)
     for epoch in range(0, c.epochs):
         c_logger.print_epoch_start(epoch, c.epochs)
         train_avg_loss_dict, global_step = train(model, criterion, optimizer,
                                                  scheduler, ap, global_step,
-                                                 epoch, speaker_mapping)
-        eval_avg_loss_dict = evaluate(model, criterion, ap, global_step, epoch, speaker_mapping=speaker_mapping)
+                                                 epoch)
+        eval_avg_loss_dict = evaluate(model, criterion, ap, global_step, epoch)
         c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
         target_loss = train_avg_loss_dict['avg_loss']
         if c.run_eval:
TTS/bin/train_tacotron.py
@@ -18,7 +18,7 @@ from TTS.tts.layers.losses import TacotronLoss
 from TTS.tts.utils.generic_utils import check_config_tts, setup_model
 from TTS.tts.utils.io import save_best_model, save_checkpoint
 from TTS.tts.utils.measures import alignment_diagonal_score
-from TTS.tts.utils.speakers import load_speaker_mapping, parse_speakers
+from TTS.tts.utils.speakers import parse_speakers
 from TTS.tts.utils.synthesis import synthesis
 from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
@@ -39,7 +39,7 @@ from TTS.utils.training import (NoamLR, adam_weight_decay, check_update,
 use_cuda, num_gpus = setup_torch_training_env(True, False)
 
 
-def setup_loader(ap, r, is_val=False, verbose=False, speaker_mapping=None):
+def setup_loader(ap, r, is_val=False, verbose=False):
     if is_val and not c.run_eval:
         loader = None
     else:
@@ -74,10 +74,7 @@ def setup_loader(ap, r, is_val=False, verbose=False, speaker_mapping=None):
                             pin_memory=False)
     return loader
 
-def format_data(data, speaker_mapping=None):
-    if speaker_mapping is None and c.use_speaker_embedding and not c.use_external_speaker_embedding_file:
-        speaker_mapping = load_speaker_mapping(OUT_PATH)
-
+def format_data(data):
     # setup input data
     text_input = data[0]
     text_lengths = data[1]
@@ -127,7 +124,7 @@ def format_data(data, speaker_mapping=None):
 
 
 def train(model, criterion, optimizer, optimizer_st, scheduler,
-          ap, global_step, epoch, scaler, scaler_st, speaker_mapping=None):
+          ap, global_step, epoch, scaler, scaler_st):
     data_loader = setup_loader(ap, model.decoder.r, is_val=False,
                                verbose=(epoch == 0), speaker_mapping=speaker_mapping)
     model.train()
@@ -144,7 +141,7 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,
         start_time = time.time()
 
         # format data
-        text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, speaker_embeddings, max_text_length, max_spec_length = format_data(data, speaker_mapping)
+        text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, speaker_embeddings, max_text_length, max_spec_length = format_data(data)
         loader_time = time.time() - end_time
 
         global_step += 1
@@ -327,7 +324,7 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,
 
 
 @torch.no_grad()
-def evaluate(model, criterion, ap, global_step, epoch, speaker_mapping=None):
+def evaluate(model, criterion, ap, global_step, epoch):
     data_loader = setup_loader(ap, model.decoder.r, is_val=True, speaker_mapping=speaker_mapping)
     model.eval()
     epoch_time = 0
@@ -338,7 +335,7 @@ def evaluate(model, criterion, ap, global_step, epoch, speaker_mapping=None):
             start_time = time.time()
 
             # format data
-            text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, speaker_embeddings, _, _ = format_data(data, speaker_mapping)
+            text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, speaker_embeddings, _, _ = format_data(data)
             assert mel_input.shape[1] % model.decoder.r == 0
 
             # forward pass model
@@ -493,7 +490,7 @@ def evaluate(model, criterion, ap, global_step, epoch, speaker_mapping=None):
 # FIXME: move args definition/parsing inside of main?
 def main(args):  # pylint: disable=redefined-outer-name
     # pylint: disable=global-variable-undefined
-    global meta_data_train, meta_data_eval, symbols, phonemes
+    global meta_data_train, meta_data_eval, symbols, phonemes, speaker_mapping
     # Audio processor
     ap = AudioProcessor(**c.audio)
     if 'characters' in c.keys():
@@ -599,8 +596,8 @@ def main(args):  # pylint: disable=redefined-outer-name
         print("\n > Number of output frames:", model.decoder.r)
         train_avg_loss_dict, global_step = train(model, criterion, optimizer,
                                                  optimizer_st, scheduler, ap,
-                                                 global_step, epoch, scaler, scaler_st, speaker_mapping)
-        eval_avg_loss_dict = evaluate(model, criterion, ap, global_step, epoch, speaker_mapping)
+                                                 global_step, epoch, scaler, scaler_st)
+        eval_avg_loss_dict = evaluate(model, criterion, ap, global_step, epoch)
         c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
         target_loss = train_avg_loss_dict['avg_postnet_loss']
         if c.run_eval:
TTS/tts/models/glow_tts.py
@@ -104,6 +104,7 @@ class GlowTts(nn.Module):
             c_in_channels=self.c_in_channels)
 
         if num_speakers > 1 and not external_speaker_embedding_dim:
+            # speaker embedding layer
             self.emb_g = nn.Embedding(num_speakers, self.c_in_channels)
             nn.init.uniform_(self.emb_g.weight, -0.1, 0.1)
 
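The newly commented branch is the id-based path that pairs with the speaker_c LongTensor built in format_data(): each integer speaker id selects a learned conditioning vector. A standalone sketch with assumed sizes (4 speakers, 80 conditioning channels; the real values come from the model config, and how GlowTts consumes g internally is not shown in this diff):

import torch
import torch.nn as nn

emb_g = nn.Embedding(4, 80)                # num_speakers x c_in_channels
nn.init.uniform_(emb_g.weight, -0.1, 0.1)  # same init as the diff

speaker_c = torch.LongTensor([2, 0])       # ids looked up via speaker_mapping
g = emb_g(speaker_c).unsqueeze(-1)         # [B, 80, 1]; broadcast over time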