make style

This commit is contained in:
Eren Gölge 2021-05-27 17:25:00 +02:00
parent 469d2e620a
commit b500338faa
25 changed files with 524 additions and 747 deletions

View File

@ -153,15 +153,9 @@ def inference(
model_output = model_output.transpose(1, 2).detach().cpu().numpy() model_output = model_output.transpose(1, 2).detach().cpu().numpy()
elif "tacotron" in model_name: elif "tacotron" in model_name:
cond_input = {'speaker_ids': speaker_ids, 'x_vectors': speaker_embeddings} cond_input = {"speaker_ids": speaker_ids, "x_vectors": speaker_embeddings}
outputs = model( outputs = model(text_input, text_lengths, mel_input, mel_lengths, cond_input)
text_input, postnet_outputs = outputs["model_outputs"]
text_lengths,
mel_input,
mel_lengths,
cond_input
)
postnet_outputs = outputs['model_outputs']
# normalize tacotron output # normalize tacotron output
if model_name == "tacotron": if model_name == "tacotron":
mel_specs = [] mel_specs = []

View File

@ -8,13 +8,8 @@ from TTS.trainer import TrainerTTS
def main(): def main():
# try: # try:
args, config, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = init_training( args, config, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = init_training(sys.argv)
sys.argv) trainer = TrainerTTS(args, config, c_logger, tb_logger, output_path=OUT_PATH)
trainer = TrainerTTS(args,
config,
c_logger,
tb_logger,
output_path=OUT_PATH)
trainer.fit() trainer.fit()
# except KeyboardInterrupt: # except KeyboardInterrupt:
# remove_experiment_folder(OUT_PATH) # remove_experiment_folder(OUT_PATH)

View File

@ -84,7 +84,7 @@ class TrainerTTS:
self.best_loss = float("inf") self.best_loss = float("inf")
self.train_loader = None self.train_loader = None
self.eval_loader = None self.eval_loader = None
self.output_audio_path = os.path.join(output_path, 'test_audios') self.output_audio_path = os.path.join(output_path, "test_audios")
self.keep_avg_train = None self.keep_avg_train = None
self.keep_avg_eval = None self.keep_avg_eval = None
@ -138,8 +138,8 @@ class TrainerTTS:
if self.args.restore_path: if self.args.restore_path:
self.model, self.optimizer, self.scaler, self.restore_step = self.restore_model( self.model, self.optimizer, self.scaler, self.restore_step = self.restore_model(
self.config, args.restore_path, self.model, self.optimizer, self.config, args.restore_path, self.model, self.optimizer, self.scaler
self.scaler) )
# setup scheduler # setup scheduler
self.scheduler = self.get_scheduler(self.config, self.optimizer) self.scheduler = self.get_scheduler(self.config, self.optimizer)
@ -207,6 +207,7 @@ class TrainerTTS:
return None return None
if lr_scheduler.lower() == "noamlr": if lr_scheduler.lower() == "noamlr":
from TTS.utils.training import NoamLR from TTS.utils.training import NoamLR
scheduler = NoamLR scheduler = NoamLR
else: else:
scheduler = getattr(torch.optim, lr_scheduler) scheduler = getattr(torch.optim, lr_scheduler)
@ -261,8 +262,7 @@ class TrainerTTS:
ap=ap, ap=ap,
tp=self.config.characters, tp=self.config.characters,
add_blank=self.config["add_blank"], add_blank=self.config["add_blank"],
batch_group_size=0 if is_eval else batch_group_size=0 if is_eval else self.config.batch_group_size * self.config.batch_size,
self.config.batch_group_size * self.config.batch_size,
min_seq_len=self.config.min_seq_len, min_seq_len=self.config.min_seq_len,
max_seq_len=self.config.max_seq_len, max_seq_len=self.config.max_seq_len,
phoneme_cache_path=self.config.phoneme_cache_path, phoneme_cache_path=self.config.phoneme_cache_path,
@ -272,8 +272,8 @@ class TrainerTTS:
use_noise_augment=not is_eval, use_noise_augment=not is_eval,
verbose=verbose, verbose=verbose,
speaker_mapping=speaker_mapping speaker_mapping=speaker_mapping
if self.config.use_speaker_embedding if self.config.use_speaker_embedding and self.config.use_external_speaker_embedding_file
and self.config.use_external_speaker_embedding_file else None, else None,
) )
if self.config.use_phonemes and self.config.compute_input_seq_cache: if self.config.use_phonemes and self.config.compute_input_seq_cache:
@ -281,18 +281,15 @@ class TrainerTTS:
dataset.compute_input_seq(self.config.num_loader_workers) dataset.compute_input_seq(self.config.num_loader_workers)
dataset.sort_items() dataset.sort_items()
sampler = DistributedSampler( sampler = DistributedSampler(dataset) if self.num_gpus > 1 else None
dataset) if self.num_gpus > 1 else None
loader = DataLoader( loader = DataLoader(
dataset, dataset,
batch_size=self.config.eval_batch_size batch_size=self.config.eval_batch_size if is_eval else self.config.batch_size,
if is_eval else self.config.batch_size,
shuffle=False, shuffle=False,
collate_fn=dataset.collate_fn, collate_fn=dataset.collate_fn,
drop_last=False, drop_last=False,
sampler=sampler, sampler=sampler,
num_workers=self.config.num_val_loader_workers num_workers=self.config.num_val_loader_workers if is_eval else self.config.num_loader_workers,
if is_eval else self.config.num_loader_workers,
pin_memory=False, pin_memory=False,
) )
return loader return loader
@ -314,8 +311,7 @@ class TrainerTTS:
text_input = batch[0] text_input = batch[0]
text_lengths = batch[1] text_lengths = batch[1]
speaker_names = batch[2] speaker_names = batch[2]
linear_input = batch[3] if self.config.model.lower() in ["tacotron" linear_input = batch[3] if self.config.model.lower() in ["tacotron"] else None
] else None
mel_input = batch[4] mel_input = batch[4]
mel_lengths = batch[5] mel_lengths = batch[5]
stop_targets = batch[6] stop_targets = batch[6]
@ -331,10 +327,7 @@ class TrainerTTS:
speaker_embeddings = batch[8] speaker_embeddings = batch[8]
speaker_ids = None speaker_ids = None
else: else:
speaker_ids = [ speaker_ids = [self.speaker_manager.speaker_ids[speaker_name] for speaker_name in speaker_names]
self.speaker_manager.speaker_ids[speaker_name]
for speaker_name in speaker_names
]
speaker_ids = torch.LongTensor(speaker_ids) speaker_ids = torch.LongTensor(speaker_ids)
speaker_embeddings = None speaker_embeddings = None
else: else:
@ -346,7 +339,7 @@ class TrainerTTS:
durations = torch.zeros(attn_mask.shape[0], attn_mask.shape[2]) durations = torch.zeros(attn_mask.shape[0], attn_mask.shape[2])
for idx, am in enumerate(attn_mask): for idx, am in enumerate(attn_mask):
# compute raw durations # compute raw durations
c_idxs = am[:, :text_lengths[idx], :mel_lengths[idx]].max(1)[1] c_idxs = am[:, : text_lengths[idx], : mel_lengths[idx]].max(1)[1]
# c_idxs, counts = torch.unique_consecutive(c_idxs, return_counts=True) # c_idxs, counts = torch.unique_consecutive(c_idxs, return_counts=True)
c_idxs, counts = torch.unique(c_idxs, return_counts=True) c_idxs, counts = torch.unique(c_idxs, return_counts=True)
dur = torch.ones([text_lengths[idx]]).to(counts.dtype) dur = torch.ones([text_lengths[idx]]).to(counts.dtype)
@ -359,14 +352,11 @@ class TrainerTTS:
assert ( assert (
dur.sum() == mel_lengths[idx] dur.sum() == mel_lengths[idx]
), f" [!] total duration {dur.sum()} vs spectrogram length {mel_lengths[idx]}" ), f" [!] total duration {dur.sum()} vs spectrogram length {mel_lengths[idx]}"
durations[idx, :text_lengths[idx]] = dur durations[idx, : text_lengths[idx]] = dur
# set stop targets view, we predict a single stop token per iteration. # set stop targets view, we predict a single stop token per iteration.
stop_targets = stop_targets.view(text_input.shape[0], stop_targets = stop_targets.view(text_input.shape[0], stop_targets.size(1) // self.config.r, -1)
stop_targets.size(1) // self.config.r, stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze(2)
-1)
stop_targets = (stop_targets.sum(2) >
0.0).unsqueeze(2).float().squeeze(2)
# dispatch batch to GPU # dispatch batch to GPU
if self.use_cuda: if self.use_cuda:
@ -374,15 +364,10 @@ class TrainerTTS:
text_lengths = text_lengths.cuda(non_blocking=True) text_lengths = text_lengths.cuda(non_blocking=True)
mel_input = mel_input.cuda(non_blocking=True) mel_input = mel_input.cuda(non_blocking=True)
mel_lengths = mel_lengths.cuda(non_blocking=True) mel_lengths = mel_lengths.cuda(non_blocking=True)
linear_input = linear_input.cuda( linear_input = linear_input.cuda(non_blocking=True) if self.config.model.lower() in ["tacotron"] else None
non_blocking=True) if self.config.model.lower() in [
"tacotron"
] else None
stop_targets = stop_targets.cuda(non_blocking=True) stop_targets = stop_targets.cuda(non_blocking=True)
attn_mask = attn_mask.cuda( attn_mask = attn_mask.cuda(non_blocking=True) if attn_mask is not None else None
non_blocking=True) if attn_mask is not None else None durations = durations.cuda(non_blocking=True) if attn_mask is not None else None
durations = durations.cuda(
non_blocking=True) if attn_mask is not None else None
if speaker_ids is not None: if speaker_ids is not None:
speaker_ids = speaker_ids.cuda(non_blocking=True) speaker_ids = speaker_ids.cuda(non_blocking=True)
if speaker_embeddings is not None: if speaker_embeddings is not None:
@ -401,7 +386,7 @@ class TrainerTTS:
"x_vectors": speaker_embeddings, "x_vectors": speaker_embeddings,
"max_text_length": max_text_length, "max_text_length": max_text_length,
"max_spec_length": max_spec_length, "max_spec_length": max_spec_length,
"item_idx": item_idx "item_idx": item_idx,
} }
def train_step(self, batch: Dict, batch_n_steps: int, step: int, def train_step(self, batch: Dict, batch_n_steps: int, step: int,
@ -421,25 +406,20 @@ class TrainerTTS:
# check nan loss # check nan loss
if torch.isnan(loss_dict["loss"]).any(): if torch.isnan(loss_dict["loss"]).any():
raise RuntimeError( raise RuntimeError(f"Detected NaN loss at step {self.total_steps_done}.")
f"Detected NaN loss at step {self.total_steps_done}.")
# optimizer step # optimizer step
if self.config.mixed_precision: if self.config.mixed_precision:
# model optimizer step in mixed precision mode # model optimizer step in mixed precision mode
self.scaler.scale(loss_dict["loss"]).backward() self.scaler.scale(loss_dict["loss"]).backward()
self.scaler.unscale_(self.optimizer) self.scaler.unscale_(self.optimizer)
grad_norm, _ = check_update(self.model, grad_norm, _ = check_update(self.model, self.config.grad_clip, ignore_stopnet=True)
self.config.grad_clip,
ignore_stopnet=True)
self.scaler.step(self.optimizer) self.scaler.step(self.optimizer)
self.scaler.update() self.scaler.update()
else: else:
# main model optimizer step # main model optimizer step
loss_dict["loss"].backward() loss_dict["loss"].backward()
grad_norm, _ = check_update(self.model, grad_norm, _ = check_update(self.model, self.config.grad_clip, ignore_stopnet=True)
self.config.grad_clip,
ignore_stopnet=True)
self.optimizer.step() self.optimizer.step()
step_time = time.time() - step_start_time step_time = time.time() - step_start_time
@ -469,17 +449,15 @@ class TrainerTTS:
current_lr = self.optimizer.param_groups[0]["lr"] current_lr = self.optimizer.param_groups[0]["lr"]
if self.total_steps_done % self.config.print_step == 0: if self.total_steps_done % self.config.print_step == 0:
log_dict = { log_dict = {
"max_spec_length": [batch["max_spec_length"], "max_spec_length": [batch["max_spec_length"], 1], # value, precision
1], # value, precision
"max_text_length": [batch["max_text_length"], 1], "max_text_length": [batch["max_text_length"], 1],
"step_time": [step_time, 4], "step_time": [step_time, 4],
"loader_time": [loader_time, 2], "loader_time": [loader_time, 2],
"current_lr": current_lr, "current_lr": current_lr,
} }
self.c_logger.print_train_step(batch_n_steps, step, self.c_logger.print_train_step(
self.total_steps_done, log_dict, batch_n_steps, step, self.total_steps_done, log_dict, loss_dict, self.keep_avg_train.avg_values
loss_dict, )
self.keep_avg_train.avg_values)
if self.args.rank == 0: if self.args.rank == 0:
# Plot Training Iter Stats # Plot Training Iter Stats
@ -491,8 +469,7 @@ class TrainerTTS:
"step_time": step_time, "step_time": step_time,
} }
iter_stats.update(loss_dict) iter_stats.update(loss_dict)
self.tb_logger.tb_train_step_stats(self.total_steps_done, self.tb_logger.tb_train_step_stats(self.total_steps_done, iter_stats)
iter_stats)
if self.total_steps_done % self.config.save_step == 0: if self.total_steps_done % self.config.save_step == 0:
if self.config.checkpoint: if self.config.checkpoint:
@ -506,15 +483,12 @@ class TrainerTTS:
self.output_path, self.output_path,
model_loss=loss_dict["loss"], model_loss=loss_dict["loss"],
characters=self.model_characters, characters=self.model_characters,
scaler=self.scaler.state_dict() scaler=self.scaler.state_dict() if self.config.mixed_precision else None,
if self.config.mixed_precision else None,
) )
# training visualizations # training visualizations
figures, audios = self.model.train_log(self.ap, batch, outputs) figures, audios = self.model.train_log(self.ap, batch, outputs)
self.tb_logger.tb_train_figures(self.total_steps_done, figures) self.tb_logger.tb_train_figures(self.total_steps_done, figures)
self.tb_logger.tb_train_audios(self.total_steps_done, self.tb_logger.tb_train_audios(self.total_steps_done, {"TrainAudio": audios}, self.ap.sample_rate)
{"TrainAudio": audios},
self.ap.sample_rate)
self.total_steps_done += 1 self.total_steps_done += 1
self.on_train_step_end() self.on_train_step_end()
return outputs, loss_dict return outputs, loss_dict
@ -523,35 +497,28 @@ class TrainerTTS:
self.model.train() self.model.train()
epoch_start_time = time.time() epoch_start_time = time.time()
if self.use_cuda: if self.use_cuda:
batch_num_steps = int( batch_num_steps = int(len(self.train_loader.dataset) / (self.config.batch_size * self.num_gpus))
len(self.train_loader.dataset) /
(self.config.batch_size * self.num_gpus))
else: else:
batch_num_steps = int( batch_num_steps = int(len(self.train_loader.dataset) / self.config.batch_size)
len(self.train_loader.dataset) / self.config.batch_size)
self.c_logger.print_train_start() self.c_logger.print_train_start()
loader_start_time = time.time() loader_start_time = time.time()
for cur_step, batch in enumerate(self.train_loader): for cur_step, batch in enumerate(self.train_loader):
_, _ = self.train_step(batch, batch_num_steps, cur_step, _, _ = self.train_step(batch, batch_num_steps, cur_step, loader_start_time)
loader_start_time)
epoch_time = time.time() - epoch_start_time epoch_time = time.time() - epoch_start_time
# Plot self.epochs_done Stats # Plot self.epochs_done Stats
if self.args.rank == 0: if self.args.rank == 0:
epoch_stats = {"epoch_time": epoch_time} epoch_stats = {"epoch_time": epoch_time}
epoch_stats.update(self.keep_avg_train.avg_values) epoch_stats.update(self.keep_avg_train.avg_values)
self.tb_logger.tb_train_epoch_stats(self.total_steps_done, self.tb_logger.tb_train_epoch_stats(self.total_steps_done, epoch_stats)
epoch_stats)
if self.config.tb_model_param_stats: if self.config.tb_model_param_stats:
self.tb_logger.tb_model_weights(self.model, self.tb_logger.tb_model_weights(self.model, self.total_steps_done)
self.total_steps_done)
def eval_step(self, batch: Dict, step: int) -> Tuple[Dict, Dict]: def eval_step(self, batch: Dict, step: int) -> Tuple[Dict, Dict]:
with torch.no_grad(): with torch.no_grad():
step_start_time = time.time() step_start_time = time.time()
with torch.cuda.amp.autocast(enabled=self.config.mixed_precision): with torch.cuda.amp.autocast(enabled=self.config.mixed_precision):
outputs, loss_dict = self.model.eval_step( outputs, loss_dict = self.model.eval_step(batch, self.criterion)
batch, self.criterion)
step_time = time.time() - step_start_time step_time = time.time() - step_start_time
@ -572,8 +539,7 @@ class TrainerTTS:
self.keep_avg_eval.update_values(update_eval_values) self.keep_avg_eval.update_values(update_eval_values)
if self.config.print_eval: if self.config.print_eval:
self.c_logger.print_eval_step(step, loss_dict, self.c_logger.print_eval_step(step, loss_dict, self.keep_avg_eval.avg_values)
self.keep_avg_eval.avg_values)
return outputs, loss_dict return outputs, loss_dict
def eval_epoch(self) -> None: def eval_epoch(self) -> None:
@ -585,15 +551,13 @@ class TrainerTTS:
# format data # format data
batch = self.format_batch(batch) batch = self.format_batch(batch)
loader_time = time.time() - loader_start_time loader_time = time.time() - loader_start_time
self.keep_avg_eval.update_values({'avg_loader_time': loader_time}) self.keep_avg_eval.update_values({"avg_loader_time": loader_time})
outputs, _ = self.eval_step(batch, cur_step) outputs, _ = self.eval_step(batch, cur_step)
# Plot epoch stats and samples from the last batch. # Plot epoch stats and samples from the last batch.
if self.args.rank == 0: if self.args.rank == 0:
figures, eval_audios = self.model.eval_log(self.ap, batch, outputs) figures, eval_audios = self.model.eval_log(self.ap, batch, outputs)
self.tb_logger.tb_eval_figures(self.total_steps_done, figures) self.tb_logger.tb_eval_figures(self.total_steps_done, figures)
self.tb_logger.tb_eval_audios(self.total_steps_done, self.tb_logger.tb_eval_audios(self.total_steps_done, {"EvalAudio": eval_audios}, self.ap.sample_rate)
{"EvalAudio": eval_audios},
self.ap.sample_rate)
def test_run(self, ) -> None: def test_run(self, ) -> None:
print(" | > Synthesizing test sentences.") print(" | > Synthesizing test sentences.")
@ -608,9 +572,9 @@ class TrainerTTS:
self.config, self.config,
self.use_cuda, self.use_cuda,
self.ap, self.ap,
speaker_id=cond_inputs['speaker_id'], speaker_id=cond_inputs["speaker_id"],
x_vector=cond_inputs['x_vector'], x_vector=cond_inputs["x_vector"],
style_wav=cond_inputs['style_wav'], style_wav=cond_inputs["style_wav"],
enable_eos_bos_chars=self.config.enable_eos_bos_chars, enable_eos_bos_chars=self.config.enable_eos_bos_chars,
use_griffin_lim=True, use_griffin_lim=True,
do_trim_silence=False, do_trim_silence=False,
@ -623,10 +587,8 @@ class TrainerTTS:
"TestSentence_{}.wav".format(idx)) "TestSentence_{}.wav".format(idx))
self.ap.save_wav(wav, file_path) self.ap.save_wav(wav, file_path)
test_audios["{}-audio".format(idx)] = wav test_audios["{}-audio".format(idx)] = wav
test_figures["{}-prediction".format(idx)] = plot_spectrogram( test_figures["{}-prediction".format(idx)] = plot_spectrogram(model_outputs, self.ap, output_fig=False)
model_outputs, self.ap, output_fig=False) test_figures["{}-alignment".format(idx)] = plot_alignment(alignment, output_fig=False)
test_figures["{}-alignment".format(idx)] = plot_alignment(
alignment, output_fig=False)
self.tb_logger.tb_test_audios(self.total_steps_done, test_audios, self.tb_logger.tb_test_audios(self.total_steps_done, test_audios,
self.config.audio["sample_rate"]) self.config.audio["sample_rate"])
@ -641,11 +603,11 @@ class TrainerTTS:
if self.config.use_external_speaker_embedding_file if self.config.use_external_speaker_embedding_file
and self.config.use_speaker_embedding else None) and self.config.use_speaker_embedding else None)
# setup style_mel # setup style_mel
if self.config.has('gst_style_input'): if self.config.has("gst_style_input"):
style_wav = self.config.gst_style_input style_wav = self.config.gst_style_input
else: else:
style_wav = None style_wav = None
if style_wav is None and 'use_gst' in self.config and self.config.use_gst: if style_wav is None and "use_gst" in self.config and self.config.use_gst:
# inicialize GST with zero dict. # inicialize GST with zero dict.
style_wav = {} style_wav = {}
print( print(
@ -688,8 +650,7 @@ class TrainerTTS:
for epoch in range(0, self.config.epochs): for epoch in range(0, self.config.epochs):
self.on_epoch_start() self.on_epoch_start()
self.keep_avg_train = KeepAverage() self.keep_avg_train = KeepAverage()
self.keep_avg_eval = KeepAverage( self.keep_avg_eval = KeepAverage() if self.config.run_eval else None
) if self.config.run_eval else None
self.epochs_done = epoch self.epochs_done = epoch
self.c_logger.print_epoch_start(epoch, self.config.epochs) self.c_logger.print_epoch_start(epoch, self.config.epochs)
self.train_epoch() self.train_epoch()
@ -698,8 +659,8 @@ class TrainerTTS:
if epoch >= self.config.test_delay_epochs: if epoch >= self.config.test_delay_epochs:
self.test_run() self.test_run()
self.c_logger.print_epoch_end( self.c_logger.print_epoch_end(
epoch, self.keep_avg_eval.avg_values epoch, self.keep_avg_eval.avg_values if self.config.run_eval else self.keep_avg_train.avg_values
if self.config.run_eval else self.keep_avg_train.avg_values) )
self.save_best_model() self.save_best_model()
self.on_epoch_end() self.on_epoch_end()
@ -717,8 +678,7 @@ class TrainerTTS:
self.model_characters, self.model_characters,
keep_all_best=self.config.keep_all_best, keep_all_best=self.config.keep_all_best,
keep_after=self.config.keep_after, keep_after=self.config.keep_after,
scaler=self.scaler.state_dict() scaler=self.scaler.state_dict() if self.config.mixed_precision else None,
if self.config.mixed_precision else None,
) )
@staticmethod @staticmethod

View File

@ -93,7 +93,7 @@ class AlignTTSConfig(BaseTTSConfig):
# optimizer parameters # optimizer parameters
optimizer: str = "Adam" optimizer: str = "Adam"
optimizer_params: dict = field(default_factory=lambda: {'betas': [0.9, 0.998], 'weight_decay': 1e-6}) optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.998], "weight_decay": 1e-6})
lr_scheduler: str = None lr_scheduler: str = None
lr_scheduler_params: dict = None lr_scheduler_params: dict = None
lr: float = 1e-4 lr: float = 1e-4
@ -104,12 +104,13 @@ class AlignTTSConfig(BaseTTSConfig):
max_seq_len: int = 200 max_seq_len: int = 200
r: int = 1 r: int = 1
# testing # testing
test_sentences: List[str] = field(default_factory=lambda:[ test_sentences: List[str] = field(
"It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", default_factory=lambda: [
"Be a voice, not an echo.", "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
"I'm sorry Dave. I'm afraid I can't do that.", "Be a voice, not an echo.",
"This cake is great. It's so delicious and moist.", "I'm sorry Dave. I'm afraid I can't do that.",
"Prior to November 22, 1963." "This cake is great. It's so delicious and moist.",
]) "Prior to November 22, 1963.",
]
)

View File

@ -91,9 +91,9 @@ class GlowTTSConfig(BaseTTSConfig):
# optimizer parameters # optimizer parameters
optimizer: str = "RAdam" optimizer: str = "RAdam"
optimizer_params: dict = field(default_factory=lambda: {'betas': [0.9, 0.998], 'weight_decay': 1e-6}) optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.998], "weight_decay": 1e-6})
lr_scheduler: str = "NoamLR" lr_scheduler: str = "NoamLR"
lr_scheduler_params: dict = field(default_factory=lambda:{"warmup_steps": 4000}) lr_scheduler_params: dict = field(default_factory=lambda: {"warmup_steps": 4000})
grad_clip: float = 5.0 grad_clip: float = 5.0
lr: float = 1e-3 lr: float = 1e-3

View File

@ -174,7 +174,7 @@ class BaseTTSConfig(BaseTrainingConfig):
optimizer: str = MISSING optimizer: str = MISSING
optimizer_params: dict = MISSING optimizer_params: dict = MISSING
# scheduler # scheduler
lr_scheduler: str = '' lr_scheduler: str = ""
lr_scheduler_params: dict = field(default_factory=lambda: {}) lr_scheduler_params: dict = field(default_factory=lambda: {})
# testing # testing
test_sentences: List[str] = field(default_factory=lambda:[]) test_sentences: List[str] = field(default_factory=lambda: [])

View File

@ -101,7 +101,7 @@ class SpeedySpeechConfig(BaseTTSConfig):
# optimizer parameters # optimizer parameters
optimizer: str = "RAdam" optimizer: str = "RAdam"
optimizer_params: dict = field(default_factory=lambda: {'betas': [0.9, 0.998], 'weight_decay': 1e-6}) optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.998], "weight_decay": 1e-6})
lr_scheduler: str = None lr_scheduler: str = None
lr_scheduler_params: dict = None lr_scheduler_params: dict = None
lr: float = 1e-4 lr: float = 1e-4
@ -118,10 +118,12 @@ class SpeedySpeechConfig(BaseTTSConfig):
r: int = 1 # DO NOT CHANGE r: int = 1 # DO NOT CHANGE
# testing # testing
test_sentences: List[str] = field(default_factory=lambda:[ test_sentences: List[str] = field(
"It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", default_factory=lambda: [
"Be a voice, not an echo.", "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
"I'm sorry Dave. I'm afraid I can't do that.", "Be a voice, not an echo.",
"This cake is great. It's so delicious and moist.", "I'm sorry Dave. I'm afraid I can't do that.",
"Prior to November 22, 1963." "This cake is great. It's so delicious and moist.",
]) "Prior to November 22, 1963.",
]
)

View File

@ -160,9 +160,9 @@ class TacotronConfig(BaseTTSConfig):
# optimizer parameters # optimizer parameters
optimizer: str = "RAdam" optimizer: str = "RAdam"
optimizer_params: dict = field(default_factory=lambda: {'betas': [0.9, 0.998], 'weight_decay': 1e-6}) optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.998], "weight_decay": 1e-6})
lr_scheduler: str = "NoamLR" lr_scheduler: str = "NoamLR"
lr_scheduler_params: dict = field(default_factory=lambda:{"warmup_steps": 4000}) lr_scheduler_params: dict = field(default_factory=lambda: {"warmup_steps": 4000})
lr: float = 1e-4 lr: float = 1e-4
grad_clip: float = 5.0 grad_clip: float = 5.0
seq_len_norm: bool = False seq_len_norm: bool = False
@ -178,13 +178,15 @@ class TacotronConfig(BaseTTSConfig):
ga_alpha: float = 5.0 ga_alpha: float = 5.0
# testing # testing
test_sentences: List[str] = field(default_factory=lambda:[ test_sentences: List[str] = field(
"It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", default_factory=lambda: [
"Be a voice, not an echo.", "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
"I'm sorry Dave. I'm afraid I can't do that.", "Be a voice, not an echo.",
"This cake is great. It's so delicious and moist.", "I'm sorry Dave. I'm afraid I can't do that.",
"Prior to November 22, 1963." "This cake is great. It's so delicious and moist.",
]) "Prior to November 22, 1963.",
]
)
def check_values(self): def check_values(self):
if self.gradual_training: if self.gradual_training:

View File

@ -44,22 +44,18 @@ def load_meta_data(datasets, eval_split=True):
preprocessor = _get_preprocessor_by_name(name) preprocessor = _get_preprocessor_by_name(name)
# load train set # load train set
meta_data_train = preprocessor(root_path, meta_file_train) meta_data_train = preprocessor(root_path, meta_file_train)
print( print(f" | > Found {len(meta_data_train)} files in {Path(root_path).resolve()}")
f" | > Found {len(meta_data_train)} files in {Path(root_path).resolve()}"
)
# load evaluation split if set # load evaluation split if set
if eval_split: if eval_split:
if meta_file_val: if meta_file_val:
meta_data_eval = preprocessor(root_path, meta_file_val) meta_data_eval = preprocessor(root_path, meta_file_val)
else: else:
meta_data_eval, meta_data_train = split_dataset( meta_data_eval, meta_data_train = split_dataset(meta_data_train)
meta_data_train)
meta_data_eval_all += meta_data_eval meta_data_eval_all += meta_data_eval
meta_data_train_all += meta_data_train meta_data_train_all += meta_data_train
# load attention masks for duration predictor training # load attention masks for duration predictor training
if dataset.meta_file_attn_mask: if dataset.meta_file_attn_mask:
meta_data = dict( meta_data = dict(load_attention_mask_meta_data(dataset["meta_file_attn_mask"]))
load_attention_mask_meta_data(dataset["meta_file_attn_mask"]))
for idx, ins in enumerate(meta_data_train_all): for idx, ins in enumerate(meta_data_train_all):
attn_file = meta_data[ins[1]].strip() attn_file = meta_data[ins[1]].strip()
meta_data_train_all[idx].append(attn_file) meta_data_train_all[idx].append(attn_file)

View File

@ -506,5 +506,10 @@ class AlignTTSLoss(nn.Module):
spec_loss = self.spec_loss(decoder_output, decoder_target, decoder_output_lens) spec_loss = self.spec_loss(decoder_output, decoder_target, decoder_output_lens)
ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens) ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens)
dur_loss = self.dur_loss(dur_output.unsqueeze(2), dur_target.unsqueeze(2), input_lens) dur_loss = self.dur_loss(dur_output.unsqueeze(2), dur_target.unsqueeze(2), input_lens)
loss = self.spec_loss_alpha * spec_loss + self.ssim_alpha * ssim_loss + self.dur_loss_alpha * dur_loss + self.mdn_alpha * mdn_loss loss = (
self.spec_loss_alpha * spec_loss
+ self.ssim_alpha * ssim_loss
+ self.dur_loss_alpha * dur_loss
+ self.mdn_alpha * mdn_loss
)
return {"loss": loss, "loss_l1": spec_loss, "loss_ssim": ssim_loss, "loss_dur": dur_loss, "mdn_loss": mdn_loss} return {"loss": loss, "loss_l1": spec_loss, "loss_ssim": ssim_loss, "loss_dur": dur_loss, "mdn_loss": mdn_loss}

View File

@ -72,19 +72,9 @@ class AlignTTS(nn.Module):
hidden_channels=256, hidden_channels=256,
hidden_channels_dp=256, hidden_channels_dp=256,
encoder_type="fftransformer", encoder_type="fftransformer",
encoder_params={ encoder_params={"hidden_channels_ffn": 1024, "num_heads": 2, "num_layers": 6, "dropout_p": 0.1},
"hidden_channels_ffn": 1024,
"num_heads": 2,
"num_layers": 6,
"dropout_p": 0.1
},
decoder_type="fftransformer", decoder_type="fftransformer",
decoder_params={ decoder_params={"hidden_channels_ffn": 1024, "num_heads": 2, "num_layers": 6, "dropout_p": 0.1},
"hidden_channels_ffn": 1024,
"num_heads": 2,
"num_layers": 6,
"dropout_p": 0.1
},
length_scale=1, length_scale=1,
num_speakers=0, num_speakers=0,
external_c=False, external_c=False,
@ -93,14 +83,11 @@ class AlignTTS(nn.Module):
super().__init__() super().__init__()
self.phase = -1 self.phase = -1
self.length_scale = float(length_scale) if isinstance( self.length_scale = float(length_scale) if isinstance(length_scale, int) else length_scale
length_scale, int) else length_scale
self.emb = nn.Embedding(num_chars, hidden_channels) self.emb = nn.Embedding(num_chars, hidden_channels)
self.pos_encoder = PositionalEncoding(hidden_channels) self.pos_encoder = PositionalEncoding(hidden_channels)
self.encoder = Encoder(hidden_channels, hidden_channels, encoder_type, self.encoder = Encoder(hidden_channels, hidden_channels, encoder_type, encoder_params, c_in_channels)
encoder_params, c_in_channels) self.decoder = Decoder(out_channels, hidden_channels, decoder_type, decoder_params)
self.decoder = Decoder(out_channels, hidden_channels, decoder_type,
decoder_params)
self.duration_predictor = DurationPredictor(hidden_channels_dp) self.duration_predictor = DurationPredictor(hidden_channels_dp)
self.mod_layer = nn.Conv1d(hidden_channels, hidden_channels, 1) self.mod_layer = nn.Conv1d(hidden_channels, hidden_channels, 1)
@ -121,9 +108,9 @@ class AlignTTS(nn.Module):
mu = mu.transpose(1, 2).unsqueeze(2) # [B, T2, 1, D] mu = mu.transpose(1, 2).unsqueeze(2) # [B, T2, 1, D]
log_sigma = log_sigma.transpose(1, 2).unsqueeze(2) # [B, T2, 1, D] log_sigma = log_sigma.transpose(1, 2).unsqueeze(2) # [B, T2, 1, D]
expanded_y, expanded_mu = torch.broadcast_tensors(y, mu) expanded_y, expanded_mu = torch.broadcast_tensors(y, mu)
exponential = -0.5 * torch.mean(torch._C._nn.mse_loss( exponential = -0.5 * torch.mean(
expanded_y, expanded_mu, 0) / torch.pow(log_sigma.exp(), 2), torch._C._nn.mse_loss(expanded_y, expanded_mu, 0) / torch.pow(log_sigma.exp(), 2), dim=-1
dim=-1) # B, L, T ) # B, L, T
logp = exponential - 0.5 * log_sigma.mean(dim=-1) logp = exponential - 0.5 * log_sigma.mean(dim=-1)
return logp return logp
@ -157,9 +144,7 @@ class AlignTTS(nn.Module):
[1, 0, 0, 0, 0, 0, 0]] [1, 0, 0, 0, 0, 0, 0]]
""" """
attn = self.convert_dr_to_align(dr, x_mask, y_mask) attn = self.convert_dr_to_align(dr, x_mask, y_mask)
o_en_ex = torch.matmul( o_en_ex = torch.matmul(attn.squeeze(1).transpose(1, 2), en.transpose(1, 2)).transpose(1, 2)
attn.squeeze(1).transpose(1, 2), en.transpose(1,
2)).transpose(1, 2)
return o_en_ex, attn return o_en_ex, attn
def format_durations(self, o_dr_log, x_mask): def format_durations(self, o_dr_log, x_mask):
@ -193,8 +178,7 @@ class AlignTTS(nn.Module):
x_emb = torch.transpose(x_emb, 1, -1) x_emb = torch.transpose(x_emb, 1, -1)
# compute sequence masks # compute sequence masks
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype)
1).to(x.dtype)
# encoder pass # encoder pass
o_en = self.encoder(x_emb, x_mask) o_en = self.encoder(x_emb, x_mask)
@ -207,8 +191,7 @@ class AlignTTS(nn.Module):
return o_en, o_en_dp, x_mask, g return o_en, o_en_dp, x_mask, g
def _forward_decoder(self, o_en, o_en_dp, dr, x_mask, y_lengths, g): def _forward_decoder(self, o_en, o_en_dp, dr, x_mask, y_lengths, g):
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en_dp.dtype)
1).to(o_en_dp.dtype)
# expand o_en with durations # expand o_en with durations
o_en_ex, attn = self.expand_encoder_outputs(o_en, dr, x_mask, y_mask) o_en_ex, attn = self.expand_encoder_outputs(o_en, dr, x_mask, y_mask)
# positional encoding # positional encoding
@ -224,13 +207,13 @@ class AlignTTS(nn.Module):
def _forward_mdn(self, o_en, y, y_lengths, x_mask): def _forward_mdn(self, o_en, y, y_lengths, x_mask):
# MAS potentials and alignment # MAS potentials and alignment
mu, log_sigma = self.mdn_block(o_en) mu, log_sigma = self.mdn_block(o_en)
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en.dtype)
1).to(o_en.dtype) dr_mas, logp = self.compute_align_path(mu, log_sigma, y, x_mask, y_mask)
dr_mas, logp = self.compute_align_path(mu, log_sigma, y, x_mask,
y_mask)
return dr_mas, mu, log_sigma, logp return dr_mas, mu, log_sigma, logp
def forward(self, x, x_lengths, y, y_lengths, cond_input={"x_vectors": None}, phase=None): # pylint: disable=unused-argument def forward(
self, x, x_lengths, y, y_lengths, cond_input={"x_vectors": None}, phase=None
): # pylint: disable=unused-argument
""" """
Shapes: Shapes:
x: [B, T_max] x: [B, T_max]
@ -240,83 +223,58 @@ class AlignTTS(nn.Module):
g: [B, C] g: [B, C]
""" """
y = y.transpose(1, 2) y = y.transpose(1, 2)
g = cond_input['x_vectors'] if 'x_vectors' in cond_input else None g = cond_input["x_vectors"] if "x_vectors" in cond_input else None
o_de, o_dr_log, dr_mas_log, attn, mu, log_sigma, logp = None, None, None, None, None, None, None o_de, o_dr_log, dr_mas_log, attn, mu, log_sigma, logp = None, None, None, None, None, None, None
if phase == 0: if phase == 0:
# train encoder and MDN # train encoder and MDN
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g) o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
dr_mas, mu, log_sigma, logp = self._forward_mdn( dr_mas, mu, log_sigma, logp = self._forward_mdn(o_en, y, y_lengths, x_mask)
o_en, y, y_lengths, x_mask) y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en_dp.dtype)
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None),
1).to(o_en_dp.dtype)
attn = self.convert_dr_to_align(dr_mas, x_mask, y_mask) attn = self.convert_dr_to_align(dr_mas, x_mask, y_mask)
elif phase == 1: elif phase == 1:
# train decoder # train decoder
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g) o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
dr_mas, _, _, _ = self._forward_mdn(o_en, y, y_lengths, x_mask) dr_mas, _, _, _ = self._forward_mdn(o_en, y, y_lengths, x_mask)
o_de, attn = self._forward_decoder(o_en.detach(), o_de, attn = self._forward_decoder(o_en.detach(), o_en_dp.detach(), dr_mas.detach(), x_mask, y_lengths, g=g)
o_en_dp.detach(),
dr_mas.detach(),
x_mask,
y_lengths,
g=g)
elif phase == 2: elif phase == 2:
# train the whole except duration predictor # train the whole except duration predictor
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g) o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
dr_mas, mu, log_sigma, logp = self._forward_mdn( dr_mas, mu, log_sigma, logp = self._forward_mdn(o_en, y, y_lengths, x_mask)
o_en, y, y_lengths, x_mask) o_de, attn = self._forward_decoder(o_en, o_en_dp, dr_mas, x_mask, y_lengths, g=g)
o_de, attn = self._forward_decoder(o_en,
o_en_dp,
dr_mas,
x_mask,
y_lengths,
g=g)
elif phase == 3: elif phase == 3:
# train duration predictor # train duration predictor
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g) o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
o_dr_log = self.duration_predictor(x, x_mask) o_dr_log = self.duration_predictor(x, x_mask)
dr_mas, mu, log_sigma, logp = self._forward_mdn( dr_mas, mu, log_sigma, logp = self._forward_mdn(o_en, y, y_lengths, x_mask)
o_en, y, y_lengths, x_mask) o_de, attn = self._forward_decoder(o_en, o_en_dp, dr_mas, x_mask, y_lengths, g=g)
o_de, attn = self._forward_decoder(o_en,
o_en_dp,
dr_mas,
x_mask,
y_lengths,
g=g)
o_dr_log = o_dr_log.squeeze(1) o_dr_log = o_dr_log.squeeze(1)
else: else:
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g) o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask) o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask)
dr_mas, mu, log_sigma, logp = self._forward_mdn( dr_mas, mu, log_sigma, logp = self._forward_mdn(o_en, y, y_lengths, x_mask)
o_en, y, y_lengths, x_mask) o_de, attn = self._forward_decoder(o_en, o_en_dp, dr_mas, x_mask, y_lengths, g=g)
o_de, attn = self._forward_decoder(o_en,
o_en_dp,
dr_mas,
x_mask,
y_lengths,
g=g)
o_dr_log = o_dr_log.squeeze(1) o_dr_log = o_dr_log.squeeze(1)
dr_mas_log = torch.log(dr_mas + 1).squeeze(1) dr_mas_log = torch.log(dr_mas + 1).squeeze(1)
outputs = { outputs = {
'model_outputs': o_de.transpose(1, 2), "model_outputs": o_de.transpose(1, 2),
'alignments': attn, "alignments": attn,
'durations_log': o_dr_log, "durations_log": o_dr_log,
'durations_mas_log': dr_mas_log, "durations_mas_log": dr_mas_log,
'mu': mu, "mu": mu,
'log_sigma': log_sigma, "log_sigma": log_sigma,
'logp': logp "logp": logp,
} }
return outputs return outputs
@torch.no_grad() @torch.no_grad()
def inference(self, x, cond_input={'x_vectors': None}): # pylint: disable=unused-argument def inference(self, x, cond_input={"x_vectors": None}): # pylint: disable=unused-argument
""" """
Shapes: Shapes:
x: [B, T_max] x: [B, T_max]
x_lengths: [B] x_lengths: [B]
g: [B, C] g: [B, C]
""" """
g = cond_input['x_vectors'] if 'x_vectors' in cond_input else None g = cond_input["x_vectors"] if "x_vectors" in cond_input else None
x_lengths = torch.tensor(x.shape[1:2]).to(x.device) x_lengths = torch.tensor(x.shape[1:2]).to(x.device)
# pad input to prevent dropping the last word # pad input to prevent dropping the last word
# x = torch.nn.functional.pad(x, pad=(0, 5), mode='constant', value=0) # x = torch.nn.functional.pad(x, pad=(0, 5), mode='constant', value=0)
@ -326,46 +284,40 @@ class AlignTTS(nn.Module):
# duration predictor pass # duration predictor pass
o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1) o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1)
y_lengths = o_dr.sum(1) y_lengths = o_dr.sum(1)
o_de, attn = self._forward_decoder(o_en, o_de, attn = self._forward_decoder(o_en, o_en_dp, o_dr, x_mask, y_lengths, g=g)
o_en_dp, outputs = {"model_outputs": o_de.transpose(1, 2), "alignments": attn}
o_dr,
x_mask,
y_lengths,
g=g)
outputs = {'model_outputs': o_de.transpose(1, 2), 'alignments': attn}
return outputs return outputs
def train_step(self, batch: dict, criterion: nn.Module): def train_step(self, batch: dict, criterion: nn.Module):
text_input = batch['text_input'] text_input = batch["text_input"]
text_lengths = batch['text_lengths'] text_lengths = batch["text_lengths"]
mel_input = batch['mel_input'] mel_input = batch["mel_input"]
mel_lengths = batch['mel_lengths'] mel_lengths = batch["mel_lengths"]
x_vectors = batch['x_vectors'] x_vectors = batch["x_vectors"]
speaker_ids = batch['speaker_ids'] speaker_ids = batch["speaker_ids"]
cond_input = {'x_vectors': x_vectors, 'speaker_ids': speaker_ids} cond_input = {"x_vectors": x_vectors, "speaker_ids": speaker_ids}
outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, cond_input, self.phase) outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, cond_input, self.phase)
loss_dict = criterion( loss_dict = criterion(
outputs['logp'], outputs["logp"],
outputs['model_outputs'], outputs["model_outputs"],
mel_input, mel_input,
mel_lengths, mel_lengths,
outputs['durations_log'], outputs["durations_log"],
outputs['durations_mas_log'], outputs["durations_mas_log"],
text_lengths, text_lengths,
phase=self.phase, phase=self.phase,
) )
# compute alignment error (the lower the better ) # compute alignment error (the lower the better )
align_error = 1 - alignment_diagonal_score(outputs['alignments'], align_error = 1 - alignment_diagonal_score(outputs["alignments"], binary=True)
binary=True)
loss_dict["align_error"] = align_error loss_dict["align_error"] = align_error
return outputs, loss_dict return outputs, loss_dict
def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict): def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict):
model_outputs = outputs['model_outputs'] model_outputs = outputs["model_outputs"]
alignments = outputs['alignments'] alignments = outputs["alignments"]
mel_input = batch['mel_input'] mel_input = batch["mel_input"]
pred_spec = model_outputs[0].data.cpu().numpy() pred_spec = model_outputs[0].data.cpu().numpy()
gt_spec = mel_input[0].data.cpu().numpy() gt_spec = mel_input[0].data.cpu().numpy()
@ -387,7 +339,9 @@ class AlignTTS(nn.Module):
def eval_log(self, ap: AudioProcessor, batch: dict, outputs: dict): def eval_log(self, ap: AudioProcessor, batch: dict, outputs: dict):
return self.train_log(ap, batch, outputs) return self.train_log(ap, batch, outputs)
def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin def load_checkpoint(
self, config, checkpoint_path, eval=False
): # pylint: disable=unused-argument, redefined-builtin
state = torch.load(checkpoint_path, map_location=torch.device("cpu")) state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"]) self.load_state_dict(state["model"])
if eval: if eval:

View File

@ -38,6 +38,7 @@ class GlowTTS(nn.Module):
encoder_params (dict): encoder module parameters. encoder_params (dict): encoder module parameters.
speaker_embedding_dim (int): channels of external speaker embedding vectors. speaker_embedding_dim (int): channels of external speaker embedding vectors.
""" """
def __init__( def __init__(
self, self,
num_chars, num_chars,
@ -132,17 +133,17 @@ class GlowTTS(nn.Module):
@staticmethod @staticmethod
def compute_outputs(attn, o_mean, o_log_scale, x_mask): def compute_outputs(attn, o_mean, o_log_scale, x_mask):
# compute final values with the computed alignment # compute final values with the computed alignment
y_mean = torch.matmul( y_mean = torch.matmul(attn.squeeze(1).transpose(1, 2), o_mean.transpose(1, 2)).transpose(
attn.squeeze(1).transpose(1, 2), o_mean.transpose(1, 2)).transpose( 1, 2
1, 2) # [b, t', t], [b, t, d] -> [b, d, t'] ) # [b, t', t], [b, t, d] -> [b, d, t']
y_log_scale = torch.matmul( y_log_scale = torch.matmul(attn.squeeze(1).transpose(1, 2), o_log_scale.transpose(1, 2)).transpose(
attn.squeeze(1).transpose(1, 2), o_log_scale.transpose( 1, 2
1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t'] ) # [b, t', t], [b, t, d] -> [b, d, t']
# compute total duration with adjustment # compute total duration with adjustment
o_attn_dur = torch.log(1 + torch.sum(attn, -1)) * x_mask o_attn_dur = torch.log(1 + torch.sum(attn, -1)) * x_mask
return y_mean, y_log_scale, o_attn_dur return y_mean, y_log_scale, o_attn_dur
def forward(self, x, x_lengths, y, y_lengths=None, cond_input={'x_vectors':None}): def forward(self, x, x_lengths, y, y_lengths=None, cond_input={"x_vectors": None}):
""" """
Shapes: Shapes:
x: [B, T] x: [B, T]
@ -154,7 +155,7 @@ class GlowTTS(nn.Module):
y_max_length = y.size(2) y_max_length = y.size(2)
y = y.transpose(1, 2) y = y.transpose(1, 2)
# norm speaker embeddings # norm speaker embeddings
g = cond_input['x_vectors'] g = cond_input["x_vectors"]
if g is not None: if g is not None:
if self.speaker_embedding_dim: if self.speaker_embedding_dim:
g = F.normalize(g).unsqueeze(-1) g = F.normalize(g).unsqueeze(-1)
@ -162,54 +163,38 @@ class GlowTTS(nn.Module):
g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h, 1] g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h, 1]
# embedding pass # embedding pass
o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g)
x_lengths,
g=g)
# drop redisual frames wrt num_squeeze and set y_lengths. # drop redisual frames wrt num_squeeze and set y_lengths.
y, y_lengths, y_max_length, attn = self.preprocess( y, y_lengths, y_max_length, attn = self.preprocess(y, y_lengths, y_max_length, None)
y, y_lengths, y_max_length, None)
# create masks # create masks
y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(x_mask.dtype)
1).to(x_mask.dtype)
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
# decoder pass # decoder pass
z, logdet = self.decoder(y, y_mask, g=g, reverse=False) z, logdet = self.decoder(y, y_mask, g=g, reverse=False)
# find the alignment path # find the alignment path
with torch.no_grad(): with torch.no_grad():
o_scale = torch.exp(-2 * o_log_scale) o_scale = torch.exp(-2 * o_log_scale)
logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale, logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale, [1]).unsqueeze(-1) # [b, t, 1]
[1]).unsqueeze(-1) # [b, t, 1] logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * (z ** 2)) # [b, t, d] x [b, d, t'] = [b, t, t']
logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2), z) # [b, t, d] x [b, d, t'] = [b, t, t']
(z**2)) # [b, t, d] x [b, d, t'] = [b, t, t'] logp4 = torch.sum(-0.5 * (o_mean ** 2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1]
logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2),
z) # [b, t, d] x [b, d, t'] = [b, t, t']
logp4 = torch.sum(-0.5 * (o_mean**2) * o_scale,
[1]).unsqueeze(-1) # [b, t, 1]
logp = logp1 + logp2 + logp3 + logp4 # [b, t, t'] logp = logp1 + logp2 + logp3 + logp4 # [b, t, t']
attn = maximum_path(logp, attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()
attn_mask.squeeze(1)).unsqueeze(1).detach() y_mean, y_log_scale, o_attn_dur = self.compute_outputs(attn, o_mean, o_log_scale, x_mask)
y_mean, y_log_scale, o_attn_dur = self.compute_outputs(
attn, o_mean, o_log_scale, x_mask)
attn = attn.squeeze(1).permute(0, 2, 1) attn = attn.squeeze(1).permute(0, 2, 1)
outputs = { outputs = {
'model_outputs': z, "model_outputs": z,
'logdet': logdet, "logdet": logdet,
'y_mean': y_mean, "y_mean": y_mean,
'y_log_scale': y_log_scale, "y_log_scale": y_log_scale,
'alignments': attn, "alignments": attn,
'durations_log': o_dur_log, "durations_log": o_dur_log,
'total_durations_log': o_attn_dur "total_durations_log": o_attn_dur,
} }
return outputs return outputs
@torch.no_grad() @torch.no_grad()
def inference_with_MAS(self, def inference_with_MAS(self, x, x_lengths, y=None, y_lengths=None, attn=None, g=None):
x,
x_lengths,
y=None,
y_lengths=None,
attn=None,
g=None):
""" """
It's similar to the teacher forcing in Tacotron. It's similar to the teacher forcing in Tacotron.
It was proposed in: https://arxiv.org/abs/2104.05557 It was proposed in: https://arxiv.org/abs/2104.05557
@ -229,33 +214,24 @@ class GlowTTS(nn.Module):
g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h, 1] g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h, 1]
# embedding pass # embedding pass
o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g)
x_lengths,
g=g)
# drop redisual frames wrt num_squeeze and set y_lengths. # drop redisual frames wrt num_squeeze and set y_lengths.
y, y_lengths, y_max_length, attn = self.preprocess( y, y_lengths, y_max_length, attn = self.preprocess(y, y_lengths, y_max_length, None)
y, y_lengths, y_max_length, None)
# create masks # create masks
y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(x_mask.dtype)
1).to(x_mask.dtype)
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
# decoder pass # decoder pass
z, logdet = self.decoder(y, y_mask, g=g, reverse=False) z, logdet = self.decoder(y, y_mask, g=g, reverse=False)
# find the alignment path between z and encoder output # find the alignment path between z and encoder output
o_scale = torch.exp(-2 * o_log_scale) o_scale = torch.exp(-2 * o_log_scale)
logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale, logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale, [1]).unsqueeze(-1) # [b, t, 1]
[1]).unsqueeze(-1) # [b, t, 1] logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * (z ** 2)) # [b, t, d] x [b, d, t'] = [b, t, t']
logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2), z) # [b, t, d] x [b, d, t'] = [b, t, t']
(z**2)) # [b, t, d] x [b, d, t'] = [b, t, t'] logp4 = torch.sum(-0.5 * (o_mean ** 2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1]
logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2),
z) # [b, t, d] x [b, d, t'] = [b, t, t']
logp4 = torch.sum(-0.5 * (o_mean**2) * o_scale,
[1]).unsqueeze(-1) # [b, t, 1]
logp = logp1 + logp2 + logp3 + logp4 # [b, t, t'] logp = logp1 + logp2 + logp3 + logp4 # [b, t, t']
attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach() attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()
y_mean, y_log_scale, o_attn_dur = self.compute_outputs( y_mean, y_log_scale, o_attn_dur = self.compute_outputs(attn, o_mean, o_log_scale, x_mask)
attn, o_mean, o_log_scale, x_mask)
attn = attn.squeeze(1).permute(0, 2, 1) attn = attn.squeeze(1).permute(0, 2, 1)
# get predited aligned distribution # get predited aligned distribution
@ -264,13 +240,13 @@ class GlowTTS(nn.Module):
# reverse the decoder and predict using the aligned distribution # reverse the decoder and predict using the aligned distribution
y, logdet = self.decoder(z, y_mask, g=g, reverse=True) y, logdet = self.decoder(z, y_mask, g=g, reverse=True)
outputs = { outputs = {
'model_outputs': y, "model_outputs": y,
'logdet': logdet, "logdet": logdet,
'y_mean': y_mean, "y_mean": y_mean,
'y_log_scale': y_log_scale, "y_log_scale": y_log_scale,
'alignments': attn, "alignments": attn,
'durations_log': o_dur_log, "durations_log": o_dur_log,
'total_durations_log': o_attn_dur "total_durations_log": o_attn_dur,
} }
return outputs return outputs
@ -290,8 +266,7 @@ class GlowTTS(nn.Module):
else: else:
g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h, 1] g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h, 1]
y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(y.dtype)
1).to(y.dtype)
# decoder pass # decoder pass
z, logdet = self.decoder(y, y_mask, g=g, reverse=False) z, logdet = self.decoder(y, y_mask, g=g, reverse=False)
@ -310,37 +285,31 @@ class GlowTTS(nn.Module):
g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h] g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h]
# embedding pass # embedding pass
o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g)
x_lengths,
g=g)
# compute output durations # compute output durations
w = (torch.exp(o_dur_log) - 1) * x_mask * self.length_scale w = (torch.exp(o_dur_log) - 1) * x_mask * self.length_scale
w_ceil = torch.ceil(w) w_ceil = torch.ceil(w)
y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long() y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
y_max_length = None y_max_length = None
# compute masks # compute masks
y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(x_mask.dtype)
1).to(x_mask.dtype)
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
# compute attention mask # compute attention mask
attn = generate_path(w_ceil.squeeze(1), attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1)
attn_mask.squeeze(1)).unsqueeze(1) y_mean, y_log_scale, o_attn_dur = self.compute_outputs(attn, o_mean, o_log_scale, x_mask)
y_mean, y_log_scale, o_attn_dur = self.compute_outputs(
attn, o_mean, o_log_scale, x_mask)
z = (y_mean + torch.exp(y_log_scale) * torch.randn_like(y_mean) * z = (y_mean + torch.exp(y_log_scale) * torch.randn_like(y_mean) * self.inference_noise_scale) * y_mask
self.inference_noise_scale) * y_mask
# decoder pass # decoder pass
y, logdet = self.decoder(z, y_mask, g=g, reverse=True) y, logdet = self.decoder(z, y_mask, g=g, reverse=True)
attn = attn.squeeze(1).permute(0, 2, 1) attn = attn.squeeze(1).permute(0, 2, 1)
outputs = { outputs = {
'model_outputs': y, "model_outputs": y,
'logdet': logdet, "logdet": logdet,
'y_mean': y_mean, "y_mean": y_mean,
'y_log_scale': y_log_scale, "y_log_scale": y_log_scale,
'alignments': attn, "alignments": attn,
'durations_log': o_dur_log, "durations_log": o_dur_log,
'total_durations_log': o_attn_dur "total_durations_log": o_attn_dur,
} }
return outputs return outputs
@ -351,32 +320,34 @@ class GlowTTS(nn.Module):
batch (dict): [description] batch (dict): [description]
criterion (nn.Module): [description] criterion (nn.Module): [description]
""" """
text_input = batch['text_input'] text_input = batch["text_input"]
text_lengths = batch['text_lengths'] text_lengths = batch["text_lengths"]
mel_input = batch['mel_input'] mel_input = batch["mel_input"]
mel_lengths = batch['mel_lengths'] mel_lengths = batch["mel_lengths"]
x_vectors = batch['x_vectors'] x_vectors = batch["x_vectors"]
outputs = self.forward(text_input, outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, cond_input={"x_vectors": x_vectors})
text_lengths,
mel_input,
mel_lengths,
cond_input={"x_vectors": x_vectors})
loss_dict = criterion(outputs['model_outputs'], outputs['y_mean'], loss_dict = criterion(
outputs['y_log_scale'], outputs['logdet'], outputs["model_outputs"],
mel_lengths, outputs['durations_log'], outputs["y_mean"],
outputs['total_durations_log'], text_lengths) outputs["y_log_scale"],
outputs["logdet"],
mel_lengths,
outputs["durations_log"],
outputs["total_durations_log"],
text_lengths,
)
# compute alignment error (the lower the better ) # compute alignment error (the lower the better )
align_error = 1 - alignment_diagonal_score(outputs['alignments'], binary=True) align_error = 1 - alignment_diagonal_score(outputs["alignments"], binary=True)
loss_dict["align_error"] = align_error loss_dict["align_error"] = align_error
return outputs, loss_dict return outputs, loss_dict
def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict): def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict):
model_outputs = outputs['model_outputs'] model_outputs = outputs["model_outputs"]
alignments = outputs['alignments'] alignments = outputs["alignments"]
mel_input = batch['mel_input'] mel_input = batch["mel_input"]
pred_spec = model_outputs[0].data.cpu().numpy() pred_spec = model_outputs[0].data.cpu().numpy()
gt_spec = mel_input[0].data.cpu().numpy() gt_spec = mel_input[0].data.cpu().numpy()
@ -400,8 +371,7 @@ class GlowTTS(nn.Module):
def preprocess(self, y, y_lengths, y_max_length, attn=None): def preprocess(self, y, y_lengths, y_max_length, attn=None):
if y_max_length is not None: if y_max_length is not None:
y_max_length = (y_max_length // y_max_length = (y_max_length // self.num_squeeze) * self.num_squeeze
self.num_squeeze) * self.num_squeeze
y = y[:, :, :y_max_length] y = y[:, :, :y_max_length]
if attn is not None: if attn is not None:
attn = attn[:, :, :, :y_max_length] attn = attn[:, :, :, :y_max_length]
@ -411,7 +381,9 @@ class GlowTTS(nn.Module):
def store_inverse(self): def store_inverse(self):
self.decoder.store_inverse() self.decoder.store_inverse()
def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin def load_checkpoint(
self, config, checkpoint_path, eval=False
): # pylint: disable=unused-argument, redefined-builtin
state = torch.load(checkpoint_path, map_location=torch.device("cpu")) state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"]) self.load_state_dict(state["model"])
if eval: if eval:

View File

@ -49,12 +49,7 @@ class SpeedySpeech(nn.Module):
positional_encoding=True, positional_encoding=True,
length_scale=1, length_scale=1,
encoder_type="residual_conv_bn", encoder_type="residual_conv_bn",
encoder_params={ encoder_params={"kernel_size": 4, "dilations": 4 * [1, 2, 4] + [1], "num_conv_blocks": 2, "num_res_blocks": 13},
"kernel_size": 4,
"dilations": 4 * [1, 2, 4] + [1],
"num_conv_blocks": 2,
"num_res_blocks": 13
},
decoder_type="residual_conv_bn", decoder_type="residual_conv_bn",
decoder_params={ decoder_params={
"kernel_size": 4, "kernel_size": 4,
@ -68,17 +63,13 @@ class SpeedySpeech(nn.Module):
): ):
super().__init__() super().__init__()
self.length_scale = float(length_scale) if isinstance( self.length_scale = float(length_scale) if isinstance(length_scale, int) else length_scale
length_scale, int) else length_scale
self.emb = nn.Embedding(num_chars, hidden_channels) self.emb = nn.Embedding(num_chars, hidden_channels)
self.encoder = Encoder(hidden_channels, hidden_channels, encoder_type, self.encoder = Encoder(hidden_channels, hidden_channels, encoder_type, encoder_params, c_in_channels)
encoder_params, c_in_channels)
if positional_encoding: if positional_encoding:
self.pos_encoder = PositionalEncoding(hidden_channels) self.pos_encoder = PositionalEncoding(hidden_channels)
self.decoder = Decoder(out_channels, hidden_channels, decoder_type, self.decoder = Decoder(out_channels, hidden_channels, decoder_type, decoder_params)
decoder_params) self.duration_predictor = DurationPredictor(hidden_channels + c_in_channels)
self.duration_predictor = DurationPredictor(hidden_channels +
c_in_channels)
if num_speakers > 1 and not external_c: if num_speakers > 1 and not external_c:
# speaker embedding layer # speaker embedding layer
@ -105,9 +96,7 @@ class SpeedySpeech(nn.Module):
""" """
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
attn = generate_path(dr, attn_mask.squeeze(1)).to(en.dtype) attn = generate_path(dr, attn_mask.squeeze(1)).to(en.dtype)
o_en_ex = torch.matmul( o_en_ex = torch.matmul(attn.squeeze(1).transpose(1, 2), en.transpose(1, 2)).transpose(1, 2)
attn.squeeze(1).transpose(1, 2), en.transpose(1,
2)).transpose(1, 2)
return o_en_ex, attn return o_en_ex, attn
def format_durations(self, o_dr_log, x_mask): def format_durations(self, o_dr_log, x_mask):
@ -141,8 +130,7 @@ class SpeedySpeech(nn.Module):
x_emb = torch.transpose(x_emb, 1, -1) x_emb = torch.transpose(x_emb, 1, -1)
# compute sequence masks # compute sequence masks
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype)
1).to(x.dtype)
# encoder pass # encoder pass
o_en = self.encoder(x_emb, x_mask) o_en = self.encoder(x_emb, x_mask)
@ -155,8 +143,7 @@ class SpeedySpeech(nn.Module):
return o_en, o_en_dp, x_mask, g return o_en, o_en_dp, x_mask, g
def _forward_decoder(self, o_en, o_en_dp, dr, x_mask, y_lengths, g): def _forward_decoder(self, o_en, o_en_dp, dr, x_mask, y_lengths, g):
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en_dp.dtype)
1).to(o_en_dp.dtype)
# expand o_en with durations # expand o_en with durations
o_en_ex, attn = self.expand_encoder_outputs(o_en, dr, x_mask, y_mask) o_en_ex, attn = self.expand_encoder_outputs(o_en, dr, x_mask, y_mask)
# positional encoding # positional encoding
@ -169,15 +156,9 @@ class SpeedySpeech(nn.Module):
o_de = self.decoder(o_en_ex, y_mask, g=g) o_de = self.decoder(o_en_ex, y_mask, g=g)
return o_de, attn.transpose(1, 2) return o_de, attn.transpose(1, 2)
def forward(self, def forward(
x, self, x, x_lengths, y_lengths, dr, cond_input={"x_vectors": None, "speaker_ids": None}
x_lengths, ): # pylint: disable=unused-argument
y_lengths,
dr,
cond_input={
'x_vectors': None,
'speaker_ids': None
}): # pylint: disable=unused-argument
""" """
TODO: speaker embedding for speaker_ids TODO: speaker embedding for speaker_ids
Shapes: Shapes:
@ -187,91 +168,68 @@ class SpeedySpeech(nn.Module):
dr: [B, T_max] dr: [B, T_max]
g: [B, C] g: [B, C]
""" """
g = cond_input['x_vectors'] if 'x_vectors' in cond_input else None g = cond_input["x_vectors"] if "x_vectors" in cond_input else None
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g) o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask) o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask)
o_de, attn = self._forward_decoder(o_en, o_de, attn = self._forward_decoder(o_en, o_en_dp, dr, x_mask, y_lengths, g=g)
o_en_dp, outputs = {"model_outputs": o_de.transpose(1, 2), "durations_log": o_dr_log.squeeze(1), "alignments": attn}
dr,
x_mask,
y_lengths,
g=g)
outputs = {
'model_outputs': o_de.transpose(1, 2),
'durations_log': o_dr_log.squeeze(1),
'alignments': attn
}
return outputs return outputs
def inference(self, def inference(self, x, cond_input={"x_vectors": None, "speaker_ids": None}): # pylint: disable=unused-argument
x,
cond_input={
'x_vectors': None,
'speaker_ids': None
}): # pylint: disable=unused-argument
""" """
Shapes: Shapes:
x: [B, T_max] x: [B, T_max]
x_lengths: [B] x_lengths: [B]
g: [B, C] g: [B, C]
""" """
g = cond_input['x_vectors'] if 'x_vectors' in cond_input else None g = cond_input["x_vectors"] if "x_vectors" in cond_input else None
x_lengths = torch.tensor(x.shape[1:2]).to(x.device) x_lengths = torch.tensor(x.shape[1:2]).to(x.device)
# input sequence should be greated than the max convolution size # input sequence should be greated than the max convolution size
inference_padding = 5 inference_padding = 5
if x.shape[1] < 13: if x.shape[1] < 13:
inference_padding += 13 - x.shape[1] inference_padding += 13 - x.shape[1]
# pad input to prevent dropping the last word # pad input to prevent dropping the last word
x = torch.nn.functional.pad(x, x = torch.nn.functional.pad(x, pad=(0, inference_padding), mode="constant", value=0)
pad=(0, inference_padding),
mode="constant",
value=0)
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g) o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
# duration predictor pass # duration predictor pass
o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask) o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask)
o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1) o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1)
y_lengths = o_dr.sum(1) y_lengths = o_dr.sum(1)
o_de, attn = self._forward_decoder(o_en, o_de, attn = self._forward_decoder(o_en, o_en_dp, o_dr, x_mask, y_lengths, g=g)
o_en_dp, outputs = {"model_outputs": o_de.transpose(1, 2), "alignments": attn, "durations_log": None}
o_dr,
x_mask,
y_lengths,
g=g)
outputs = {
'model_outputs': o_de.transpose(1, 2),
'alignments': attn,
'durations_log': None
}
return outputs return outputs
def train_step(self, batch: dict, criterion: nn.Module): def train_step(self, batch: dict, criterion: nn.Module):
text_input = batch['text_input'] text_input = batch["text_input"]
text_lengths = batch['text_lengths'] text_lengths = batch["text_lengths"]
mel_input = batch['mel_input'] mel_input = batch["mel_input"]
mel_lengths = batch['mel_lengths'] mel_lengths = batch["mel_lengths"]
x_vectors = batch['x_vectors'] x_vectors = batch["x_vectors"]
speaker_ids = batch['speaker_ids'] speaker_ids = batch["speaker_ids"]
durations = batch['durations'] durations = batch["durations"]
cond_input = {'x_vectors': x_vectors, 'speaker_ids': speaker_ids} cond_input = {"x_vectors": x_vectors, "speaker_ids": speaker_ids}
outputs = self.forward(text_input, text_lengths, mel_lengths, outputs = self.forward(text_input, text_lengths, mel_lengths, durations, cond_input)
durations, cond_input)
# compute loss # compute loss
loss_dict = criterion(outputs['model_outputs'], mel_input, loss_dict = criterion(
mel_lengths, outputs['durations_log'], outputs["model_outputs"],
torch.log(1 + durations), text_lengths) mel_input,
mel_lengths,
outputs["durations_log"],
torch.log(1 + durations),
text_lengths,
)
# compute alignment error (the lower the better ) # compute alignment error (the lower the better )
align_error = 1 - alignment_diagonal_score(outputs['alignments'], align_error = 1 - alignment_diagonal_score(outputs["alignments"], binary=True)
binary=True)
loss_dict["align_error"] = align_error loss_dict["align_error"] = align_error
return outputs, loss_dict return outputs, loss_dict
def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict): def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict):
model_outputs = outputs['model_outputs'] model_outputs = outputs["model_outputs"]
alignments = outputs['alignments'] alignments = outputs["alignments"]
mel_input = batch['mel_input'] mel_input = batch["mel_input"]
pred_spec = model_outputs[0].data.cpu().numpy() pred_spec = model_outputs[0].data.cpu().numpy()
gt_spec = mel_input[0].data.cpu().numpy() gt_spec = mel_input[0].data.cpu().numpy()
@ -293,7 +251,9 @@ class SpeedySpeech(nn.Module):
def eval_log(self, ap: AudioProcessor, batch: dict, outputs: dict): def eval_log(self, ap: AudioProcessor, batch: dict, outputs: dict):
return self.train_log(ap, batch, outputs) return self.train_log(ap, batch, outputs)
def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin def load_checkpoint(
self, config, checkpoint_path, eval=False
): # pylint: disable=unused-argument, redefined-builtin
state = torch.load(checkpoint_path, map_location=torch.device("cpu")) state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"]) self.load_state_dict(state["model"])
if eval: if eval:

View File

@ -50,6 +50,7 @@ class Tacotron(TacotronAbstract):
gradual_trainin (List): Gradual training schedule. If None or `[]`, no gradual training is used. gradual_trainin (List): Gradual training schedule. If None or `[]`, no gradual training is used.
Defaults to `[]`. Defaults to `[]`.
""" """
def __init__( def __init__(
self, self,
num_chars, num_chars,
@ -78,7 +79,7 @@ class Tacotron(TacotronAbstract):
use_gst=False, use_gst=False,
gst=None, gst=None,
memory_size=5, memory_size=5,
gradual_training=[] gradual_training=[],
): ):
super().__init__( super().__init__(
num_chars, num_chars,
@ -106,15 +107,14 @@ class Tacotron(TacotronAbstract):
speaker_embedding_dim, speaker_embedding_dim,
use_gst, use_gst,
gst, gst,
gradual_training gradual_training,
) )
# speaker embedding layers # speaker embedding layers
if self.num_speakers > 1: if self.num_speakers > 1:
if not self.embeddings_per_sample: if not self.embeddings_per_sample:
speaker_embedding_dim = 256 speaker_embedding_dim = 256
self.speaker_embedding = nn.Embedding(self.num_speakers, self.speaker_embedding = nn.Embedding(self.num_speakers, speaker_embedding_dim)
speaker_embedding_dim)
self.speaker_embedding.weight.data.normal_(0, 0.3) self.speaker_embedding.weight.data.normal_(0, 0.3)
# speaker and gst embeddings is concat in decoder input # speaker and gst embeddings is concat in decoder input
@ -145,8 +145,7 @@ class Tacotron(TacotronAbstract):
separate_stopnet, separate_stopnet,
) )
self.postnet = PostCBHG(decoder_output_dim) self.postnet = PostCBHG(decoder_output_dim)
self.last_linear = nn.Linear(self.postnet.cbhg.gru_features * 2, self.last_linear = nn.Linear(self.postnet.cbhg.gru_features * 2, postnet_output_dim)
postnet_output_dim)
# setup prenet dropout # setup prenet dropout
self.decoder.prenet.dropout_at_inference = prenet_dropout_at_inference self.decoder.prenet.dropout_at_inference = prenet_dropout_at_inference
@ -183,12 +182,7 @@ class Tacotron(TacotronAbstract):
separate_stopnet, separate_stopnet,
) )
def forward(self, def forward(self, text, text_lengths, mel_specs=None, mel_lengths=None, cond_input=None):
text,
text_lengths,
mel_specs=None,
mel_lengths=None,
cond_input=None):
""" """
Shapes: Shapes:
text: [B, T_in] text: [B, T_in]
@ -197,100 +191,87 @@ class Tacotron(TacotronAbstract):
mel_lengths: [B] mel_lengths: [B]
cond_input: 'speaker_ids': [B, 1] and 'x_vectors':[B, C] cond_input: 'speaker_ids': [B, 1] and 'x_vectors':[B, C]
""" """
outputs = { outputs = {"alignments_backward": None, "decoder_outputs_backward": None}
'alignments_backward': None,
'decoder_outputs_backward': None
}
input_mask, output_mask = self.compute_masks(text_lengths, mel_lengths) input_mask, output_mask = self.compute_masks(text_lengths, mel_lengths)
# B x T_in x embed_dim # B x T_in x embed_dim
inputs = self.embedding(text) inputs = self.embedding(text)
# B x T_in x encoder_in_features # B x T_in x encoder_in_features
encoder_outputs = self.encoder(inputs) encoder_outputs = self.encoder(inputs)
# sequence masking # sequence masking
encoder_outputs = encoder_outputs * input_mask.unsqueeze(2).expand_as( encoder_outputs = encoder_outputs * input_mask.unsqueeze(2).expand_as(encoder_outputs)
encoder_outputs)
# global style token # global style token
if self.gst and self.use_gst: if self.gst and self.use_gst:
# B x gst_dim # B x gst_dim
encoder_outputs = self.compute_gst(encoder_outputs, mel_specs, encoder_outputs = self.compute_gst(encoder_outputs, mel_specs, cond_input["x_vectors"])
cond_input['x_vectors'])
# speaker embedding # speaker embedding
if self.num_speakers > 1: if self.num_speakers > 1:
if not self.embeddings_per_sample: if not self.embeddings_per_sample:
# B x 1 x speaker_embed_dim # B x 1 x speaker_embed_dim
speaker_embeddings = self.speaker_embedding(cond_input['speaker_ids'])[:, speaker_embeddings = self.speaker_embedding(cond_input["speaker_ids"])[:, None]
None]
else: else:
# B x 1 x speaker_embed_dim # B x 1 x speaker_embed_dim
speaker_embeddings = torch.unsqueeze(cond_input['x_vectors'], 1) speaker_embeddings = torch.unsqueeze(cond_input["x_vectors"], 1)
encoder_outputs = self._concat_speaker_embedding( encoder_outputs = self._concat_speaker_embedding(encoder_outputs, speaker_embeddings)
encoder_outputs, speaker_embeddings)
# decoder_outputs: B x decoder_in_features x T_out # decoder_outputs: B x decoder_in_features x T_out
# alignments: B x T_in x encoder_in_features # alignments: B x T_in x encoder_in_features
# stop_tokens: B x T_in # stop_tokens: B x T_in
decoder_outputs, alignments, stop_tokens = self.decoder( decoder_outputs, alignments, stop_tokens = self.decoder(encoder_outputs, mel_specs, input_mask)
encoder_outputs, mel_specs, input_mask)
# sequence masking # sequence masking
if output_mask is not None: if output_mask is not None:
decoder_outputs = decoder_outputs * output_mask.unsqueeze( decoder_outputs = decoder_outputs * output_mask.unsqueeze(1).expand_as(decoder_outputs)
1).expand_as(decoder_outputs)
# B x T_out x decoder_in_features # B x T_out x decoder_in_features
postnet_outputs = self.postnet(decoder_outputs) postnet_outputs = self.postnet(decoder_outputs)
# sequence masking # sequence masking
if output_mask is not None: if output_mask is not None:
postnet_outputs = postnet_outputs * output_mask.unsqueeze( postnet_outputs = postnet_outputs * output_mask.unsqueeze(2).expand_as(postnet_outputs)
2).expand_as(postnet_outputs)
# B x T_out x posnet_dim # B x T_out x posnet_dim
postnet_outputs = self.last_linear(postnet_outputs) postnet_outputs = self.last_linear(postnet_outputs)
# B x T_out x decoder_in_features # B x T_out x decoder_in_features
decoder_outputs = decoder_outputs.transpose(1, 2).contiguous() decoder_outputs = decoder_outputs.transpose(1, 2).contiguous()
if self.bidirectional_decoder: if self.bidirectional_decoder:
decoder_outputs_backward, alignments_backward = self._backward_pass( decoder_outputs_backward, alignments_backward = self._backward_pass(mel_specs, encoder_outputs, input_mask)
mel_specs, encoder_outputs, input_mask) outputs["alignments_backward"] = alignments_backward
outputs['alignments_backward'] = alignments_backward outputs["decoder_outputs_backward"] = decoder_outputs_backward
outputs['decoder_outputs_backward'] = decoder_outputs_backward
if self.double_decoder_consistency: if self.double_decoder_consistency:
decoder_outputs_backward, alignments_backward = self._coarse_decoder_pass( decoder_outputs_backward, alignments_backward = self._coarse_decoder_pass(
mel_specs, encoder_outputs, alignments, input_mask) mel_specs, encoder_outputs, alignments, input_mask
outputs['alignments_backward'] = alignments_backward )
outputs['decoder_outputs_backward'] = decoder_outputs_backward outputs["alignments_backward"] = alignments_backward
outputs.update({ outputs["decoder_outputs_backward"] = decoder_outputs_backward
'model_outputs': postnet_outputs, outputs.update(
'decoder_outputs': decoder_outputs, {
'alignments': alignments, "model_outputs": postnet_outputs,
'stop_tokens': stop_tokens "decoder_outputs": decoder_outputs,
}) "alignments": alignments,
"stop_tokens": stop_tokens,
}
)
return outputs return outputs
@torch.no_grad() @torch.no_grad()
def inference(self, def inference(self, text_input, cond_input=None):
text_input,
cond_input=None):
inputs = self.embedding(text_input) inputs = self.embedding(text_input)
encoder_outputs = self.encoder(inputs) encoder_outputs = self.encoder(inputs)
if self.gst and self.use_gst: if self.gst and self.use_gst:
# B x gst_dim # B x gst_dim
encoder_outputs = self.compute_gst(encoder_outputs, cond_input['style_mel'], encoder_outputs = self.compute_gst(encoder_outputs, cond_input["style_mel"], cond_input["x_vectors"])
cond_input['x_vectors'])
if self.num_speakers > 1: if self.num_speakers > 1:
if not self.embeddings_per_sample: if not self.embeddings_per_sample:
# B x 1 x speaker_embed_dim # B x 1 x speaker_embed_dim
speaker_embeddings = self.speaker_embedding(cond_input['speaker_ids'])[:, None] speaker_embeddings = self.speaker_embedding(cond_input["speaker_ids"])[:, None]
else: else:
# B x 1 x speaker_embed_dim # B x 1 x speaker_embed_dim
speaker_embeddings = torch.unsqueeze(cond_input['x_vectors'], 1) speaker_embeddings = torch.unsqueeze(cond_input["x_vectors"], 1)
encoder_outputs = self._concat_speaker_embedding( encoder_outputs = self._concat_speaker_embedding(encoder_outputs, speaker_embeddings)
encoder_outputs, speaker_embeddings) decoder_outputs, alignments, stop_tokens = self.decoder.inference(encoder_outputs)
decoder_outputs, alignments, stop_tokens = self.decoder.inference(
encoder_outputs)
postnet_outputs = self.postnet(decoder_outputs) postnet_outputs = self.postnet(decoder_outputs)
postnet_outputs = self.last_linear(postnet_outputs) postnet_outputs = self.last_linear(postnet_outputs)
decoder_outputs = decoder_outputs.transpose(1, 2) decoder_outputs = decoder_outputs.transpose(1, 2)
outputs = { outputs = {
'model_outputs': postnet_outputs, "model_outputs": postnet_outputs,
'decoder_outputs': decoder_outputs, "decoder_outputs": decoder_outputs,
'alignments': alignments, "alignments": alignments,
'stop_tokens': stop_tokens "stop_tokens": stop_tokens,
} }
return outputs return outputs
@ -301,64 +282,61 @@ class Tacotron(TacotronAbstract):
batch ([type]): [description] batch ([type]): [description]
criterion ([type]): [description] criterion ([type]): [description]
""" """
text_input = batch['text_input'] text_input = batch["text_input"]
text_lengths = batch['text_lengths'] text_lengths = batch["text_lengths"]
mel_input = batch['mel_input'] mel_input = batch["mel_input"]
mel_lengths = batch['mel_lengths'] mel_lengths = batch["mel_lengths"]
linear_input = batch['linear_input'] linear_input = batch["linear_input"]
stop_targets = batch['stop_targets'] stop_targets = batch["stop_targets"]
speaker_ids = batch['speaker_ids'] speaker_ids = batch["speaker_ids"]
x_vectors = batch['x_vectors'] x_vectors = batch["x_vectors"]
# forward pass model # forward pass model
outputs = self.forward(text_input, outputs = self.forward(
text_lengths, text_input,
mel_input, text_lengths,
mel_lengths, mel_input,
cond_input={ mel_lengths,
'speaker_ids': speaker_ids, cond_input={"speaker_ids": speaker_ids, "x_vectors": x_vectors},
'x_vectors': x_vectors )
})
# set the [alignment] lengths wrt reduction factor for guided attention # set the [alignment] lengths wrt reduction factor for guided attention
if mel_lengths.max() % self.decoder.r != 0: if mel_lengths.max() % self.decoder.r != 0:
alignment_lengths = ( alignment_lengths = (
mel_lengths + mel_lengths + (self.decoder.r - (mel_lengths.max() % self.decoder.r))
(self.decoder.r - ) // self.decoder.r
(mel_lengths.max() % self.decoder.r))) // self.decoder.r
else: else:
alignment_lengths = mel_lengths // self.decoder.r alignment_lengths = mel_lengths // self.decoder.r
cond_input = {'speaker_ids': speaker_ids, 'x_vectors': x_vectors} cond_input = {"speaker_ids": speaker_ids, "x_vectors": x_vectors}
outputs = self.forward(text_input, text_lengths, mel_input, outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, cond_input)
mel_lengths, cond_input)
# compute loss # compute loss
loss_dict = criterion( loss_dict = criterion(
outputs['model_outputs'], outputs["model_outputs"],
outputs['decoder_outputs'], outputs["decoder_outputs"],
mel_input, mel_input,
linear_input, linear_input,
outputs['stop_tokens'], outputs["stop_tokens"],
stop_targets, stop_targets,
mel_lengths, mel_lengths,
outputs['decoder_outputs_backward'], outputs["decoder_outputs_backward"],
outputs['alignments'], outputs["alignments"],
alignment_lengths, alignment_lengths,
outputs['alignments_backward'], outputs["alignments_backward"],
text_lengths, text_lengths,
) )
# compute alignment error (the lower the better ) # compute alignment error (the lower the better )
align_error = 1 - alignment_diagonal_score(outputs['alignments']) align_error = 1 - alignment_diagonal_score(outputs["alignments"])
loss_dict["align_error"] = align_error loss_dict["align_error"] = align_error
return outputs, loss_dict return outputs, loss_dict
def train_log(self, ap, batch, outputs): def train_log(self, ap, batch, outputs):
postnet_outputs = outputs['model_outputs'] postnet_outputs = outputs["model_outputs"]
alignments = outputs['alignments'] alignments = outputs["alignments"]
alignments_backward = outputs['alignments_backward'] alignments_backward = outputs["alignments_backward"]
mel_input = batch['mel_input'] mel_input = batch["mel_input"]
pred_spec = postnet_outputs[0].data.cpu().numpy() pred_spec = postnet_outputs[0].data.cpu().numpy()
gt_spec = mel_input[0].data.cpu().numpy() gt_spec = mel_input[0].data.cpu().numpy()
@ -371,8 +349,7 @@ class Tacotron(TacotronAbstract):
} }
if self.bidirectional_decoder or self.double_decoder_consistency: if self.bidirectional_decoder or self.double_decoder_consistency:
figures["alignment_backward"] = plot_alignment( figures["alignment_backward"] = plot_alignment(alignments_backward[0].data.cpu().numpy(), output_fig=False)
alignments_backward[0].data.cpu().numpy(), output_fig=False)
# Sample audio # Sample audio
train_audio = ap.inv_spectrogram(pred_spec.T) train_audio = ap.inv_spectrogram(pred_spec.T)

View File

@ -49,49 +49,70 @@ class Tacotron2(TacotronAbstract):
gradual_trainin (List): Gradual training schedule. If None or `[]`, no gradual training is used. gradual_trainin (List): Gradual training schedule. If None or `[]`, no gradual training is used.
Defaults to `[]`. Defaults to `[]`.
""" """
def __init__(self,
num_chars, def __init__(
num_speakers, self,
r, num_chars,
postnet_output_dim=80, num_speakers,
decoder_output_dim=80, r,
attn_type="original", postnet_output_dim=80,
attn_win=False, decoder_output_dim=80,
attn_norm="softmax", attn_type="original",
prenet_type="original", attn_win=False,
prenet_dropout=True, attn_norm="softmax",
prenet_dropout_at_inference=False, prenet_type="original",
forward_attn=False, prenet_dropout=True,
trans_agent=False, prenet_dropout_at_inference=False,
forward_attn_mask=False, forward_attn=False,
location_attn=True, trans_agent=False,
attn_K=5, forward_attn_mask=False,
separate_stopnet=True, location_attn=True,
bidirectional_decoder=False, attn_K=5,
double_decoder_consistency=False, separate_stopnet=True,
ddc_r=None, bidirectional_decoder=False,
encoder_in_features=512, double_decoder_consistency=False,
decoder_in_features=512, ddc_r=None,
speaker_embedding_dim=None, encoder_in_features=512,
use_gst=False, decoder_in_features=512,
gst=None, speaker_embedding_dim=None,
gradual_training=[]): use_gst=False,
super().__init__(num_chars, num_speakers, r, postnet_output_dim, gst=None,
decoder_output_dim, attn_type, attn_win, attn_norm, gradual_training=[],
prenet_type, prenet_dropout, ):
prenet_dropout_at_inference, forward_attn, super().__init__(
trans_agent, forward_attn_mask, location_attn, attn_K, num_chars,
separate_stopnet, bidirectional_decoder, num_speakers,
double_decoder_consistency, ddc_r, r,
encoder_in_features, decoder_in_features, postnet_output_dim,
speaker_embedding_dim, use_gst, gst, gradual_training) decoder_output_dim,
attn_type,
attn_win,
attn_norm,
prenet_type,
prenet_dropout,
prenet_dropout_at_inference,
forward_attn,
trans_agent,
forward_attn_mask,
location_attn,
attn_K,
separate_stopnet,
bidirectional_decoder,
double_decoder_consistency,
ddc_r,
encoder_in_features,
decoder_in_features,
speaker_embedding_dim,
use_gst,
gst,
gradual_training,
)
# speaker embedding layer # speaker embedding layer
if self.num_speakers > 1: if self.num_speakers > 1:
if not self.embeddings_per_sample: if not self.embeddings_per_sample:
speaker_embedding_dim = 512 speaker_embedding_dim = 512
self.speaker_embedding = nn.Embedding(self.num_speakers, self.speaker_embedding = nn.Embedding(self.num_speakers, speaker_embedding_dim)
speaker_embedding_dim)
self.speaker_embedding.weight.data.normal_(0, 0.3) self.speaker_embedding.weight.data.normal_(0, 0.3)
# speaker and gst embeddings is concat in decoder input # speaker and gst embeddings is concat in decoder input
@ -162,12 +183,7 @@ class Tacotron2(TacotronAbstract):
mel_outputs_postnet = mel_outputs_postnet.transpose(1, 2) mel_outputs_postnet = mel_outputs_postnet.transpose(1, 2)
return mel_outputs, mel_outputs_postnet, alignments return mel_outputs, mel_outputs_postnet, alignments
def forward(self, def forward(self, text, text_lengths, mel_specs=None, mel_lengths=None, cond_input=None):
text,
text_lengths,
mel_specs=None,
mel_lengths=None,
cond_input=None):
""" """
Shapes: Shapes:
text: [B, T_in] text: [B, T_in]
@ -176,10 +192,7 @@ class Tacotron2(TacotronAbstract):
mel_lengths: [B] mel_lengths: [B]
cond_input: 'speaker_ids': [B, 1] and 'x_vectors':[B, C] cond_input: 'speaker_ids': [B, 1] and 'x_vectors':[B, C]
""" """
outputs = { outputs = {"alignments_backward": None, "decoder_outputs_backward": None}
'alignments_backward': None,
'decoder_outputs_backward': None
}
# compute mask for padding # compute mask for padding
# B x T_in_max (boolean) # B x T_in_max (boolean)
input_mask, output_mask = self.compute_masks(text_lengths, mel_lengths) input_mask, output_mask = self.compute_masks(text_lengths, mel_lengths)
@ -189,55 +202,49 @@ class Tacotron2(TacotronAbstract):
encoder_outputs = self.encoder(embedded_inputs, text_lengths) encoder_outputs = self.encoder(embedded_inputs, text_lengths)
if self.gst and self.use_gst: if self.gst and self.use_gst:
# B x gst_dim # B x gst_dim
encoder_outputs = self.compute_gst(encoder_outputs, mel_specs, encoder_outputs = self.compute_gst(encoder_outputs, mel_specs, cond_input["x_vectors"])
cond_input['x_vectors'])
if self.num_speakers > 1: if self.num_speakers > 1:
if not self.embeddings_per_sample: if not self.embeddings_per_sample:
# B x 1 x speaker_embed_dim # B x 1 x speaker_embed_dim
speaker_embeddings = self.speaker_embedding(cond_input['speaker_ids'])[:, speaker_embeddings = self.speaker_embedding(cond_input["speaker_ids"])[:, None]
None]
else: else:
# B x 1 x speaker_embed_dim # B x 1 x speaker_embed_dim
speaker_embeddings = torch.unsqueeze(cond_input['x_vectors'], 1) speaker_embeddings = torch.unsqueeze(cond_input["x_vectors"], 1)
encoder_outputs = self._concat_speaker_embedding( encoder_outputs = self._concat_speaker_embedding(encoder_outputs, speaker_embeddings)
encoder_outputs, speaker_embeddings)
encoder_outputs = encoder_outputs * input_mask.unsqueeze(2).expand_as( encoder_outputs = encoder_outputs * input_mask.unsqueeze(2).expand_as(encoder_outputs)
encoder_outputs)
# B x mel_dim x T_out -- B x T_out//r x T_in -- B x T_out//r # B x mel_dim x T_out -- B x T_out//r x T_in -- B x T_out//r
decoder_outputs, alignments, stop_tokens = self.decoder( decoder_outputs, alignments, stop_tokens = self.decoder(encoder_outputs, mel_specs, input_mask)
encoder_outputs, mel_specs, input_mask)
# sequence masking # sequence masking
if mel_lengths is not None: if mel_lengths is not None:
decoder_outputs = decoder_outputs * output_mask.unsqueeze( decoder_outputs = decoder_outputs * output_mask.unsqueeze(1).expand_as(decoder_outputs)
1).expand_as(decoder_outputs)
# B x mel_dim x T_out # B x mel_dim x T_out
postnet_outputs = self.postnet(decoder_outputs) postnet_outputs = self.postnet(decoder_outputs)
postnet_outputs = decoder_outputs + postnet_outputs postnet_outputs = decoder_outputs + postnet_outputs
# sequence masking # sequence masking
if output_mask is not None: if output_mask is not None:
postnet_outputs = postnet_outputs * output_mask.unsqueeze( postnet_outputs = postnet_outputs * output_mask.unsqueeze(1).expand_as(postnet_outputs)
1).expand_as(postnet_outputs)
# B x T_out x mel_dim -- B x T_out x mel_dim -- B x T_out//r x T_in # B x T_out x mel_dim -- B x T_out x mel_dim -- B x T_out//r x T_in
decoder_outputs, postnet_outputs, alignments = self.shape_outputs( decoder_outputs, postnet_outputs, alignments = self.shape_outputs(decoder_outputs, postnet_outputs, alignments)
decoder_outputs, postnet_outputs, alignments)
if self.bidirectional_decoder: if self.bidirectional_decoder:
decoder_outputs_backward, alignments_backward = self._backward_pass( decoder_outputs_backward, alignments_backward = self._backward_pass(mel_specs, encoder_outputs, input_mask)
mel_specs, encoder_outputs, input_mask) outputs["alignments_backward"] = alignments_backward
outputs['alignments_backward'] = alignments_backward outputs["decoder_outputs_backward"] = decoder_outputs_backward
outputs['decoder_outputs_backward'] = decoder_outputs_backward
if self.double_decoder_consistency: if self.double_decoder_consistency:
decoder_outputs_backward, alignments_backward = self._coarse_decoder_pass( decoder_outputs_backward, alignments_backward = self._coarse_decoder_pass(
mel_specs, encoder_outputs, alignments, input_mask) mel_specs, encoder_outputs, alignments, input_mask
outputs['alignments_backward'] = alignments_backward )
outputs['decoder_outputs_backward'] = decoder_outputs_backward outputs["alignments_backward"] = alignments_backward
outputs.update({ outputs["decoder_outputs_backward"] = decoder_outputs_backward
'model_outputs': postnet_outputs, outputs.update(
'decoder_outputs': decoder_outputs, {
'alignments': alignments, "model_outputs": postnet_outputs,
'stop_tokens': stop_tokens "decoder_outputs": decoder_outputs,
}) "alignments": alignments,
"stop_tokens": stop_tokens,
}
)
return outputs return outputs
@torch.no_grad() @torch.no_grad()
@ -247,29 +254,25 @@ class Tacotron2(TacotronAbstract):
if self.gst and self.use_gst: if self.gst and self.use_gst:
# B x gst_dim # B x gst_dim
encoder_outputs = self.compute_gst(encoder_outputs, cond_input['style_mel'], encoder_outputs = self.compute_gst(encoder_outputs, cond_input["style_mel"], cond_input["x_vectors"])
cond_input['x_vectors'])
if self.num_speakers > 1: if self.num_speakers > 1:
if not self.embeddings_per_sample: if not self.embeddings_per_sample:
x_vector = self.speaker_embedding(cond_input['speaker_ids'])[:, None] x_vector = self.speaker_embedding(cond_input['speaker_ids'])[:, None]
x_vector = torch.unsqueeze(x_vector, 0).transpose(1, 2) x_vector = torch.unsqueeze(x_vector, 0).transpose(1, 2)
else: else:
x_vector = cond_input['x_vectors'] x_vector = cond_input["x_vectors"]
encoder_outputs = self._concat_speaker_embedding( encoder_outputs = self._concat_speaker_embedding(encoder_outputs, x_vector)
encoder_outputs, x_vector)
decoder_outputs, alignments, stop_tokens = self.decoder.inference( decoder_outputs, alignments, stop_tokens = self.decoder.inference(encoder_outputs)
encoder_outputs)
postnet_outputs = self.postnet(decoder_outputs) postnet_outputs = self.postnet(decoder_outputs)
postnet_outputs = decoder_outputs + postnet_outputs postnet_outputs = decoder_outputs + postnet_outputs
decoder_outputs, postnet_outputs, alignments = self.shape_outputs( decoder_outputs, postnet_outputs, alignments = self.shape_outputs(decoder_outputs, postnet_outputs, alignments)
decoder_outputs, postnet_outputs, alignments)
outputs = { outputs = {
'model_outputs': postnet_outputs, "model_outputs": postnet_outputs,
'decoder_outputs': decoder_outputs, "decoder_outputs": decoder_outputs,
'alignments': alignments, "alignments": alignments,
'stop_tokens': stop_tokens "stop_tokens": stop_tokens,
} }
return outputs return outputs
@ -280,64 +283,61 @@ class Tacotron2(TacotronAbstract):
batch ([type]): [description] batch ([type]): [description]
criterion ([type]): [description] criterion ([type]): [description]
""" """
text_input = batch['text_input'] text_input = batch["text_input"]
text_lengths = batch['text_lengths'] text_lengths = batch["text_lengths"]
mel_input = batch['mel_input'] mel_input = batch["mel_input"]
mel_lengths = batch['mel_lengths'] mel_lengths = batch["mel_lengths"]
linear_input = batch['linear_input'] linear_input = batch["linear_input"]
stop_targets = batch['stop_targets'] stop_targets = batch["stop_targets"]
speaker_ids = batch['speaker_ids'] speaker_ids = batch["speaker_ids"]
x_vectors = batch['x_vectors'] x_vectors = batch["x_vectors"]
# forward pass model # forward pass model
outputs = self.forward(text_input, outputs = self.forward(
text_lengths, text_input,
mel_input, text_lengths,
mel_lengths, mel_input,
cond_input={ mel_lengths,
'speaker_ids': speaker_ids, cond_input={"speaker_ids": speaker_ids, "x_vectors": x_vectors},
'x_vectors': x_vectors )
})
# set the [alignment] lengths wrt reduction factor for guided attention # set the [alignment] lengths wrt reduction factor for guided attention
if mel_lengths.max() % self.decoder.r != 0: if mel_lengths.max() % self.decoder.r != 0:
alignment_lengths = ( alignment_lengths = (
mel_lengths + mel_lengths + (self.decoder.r - (mel_lengths.max() % self.decoder.r))
(self.decoder.r - ) // self.decoder.r
(mel_lengths.max() % self.decoder.r))) // self.decoder.r
else: else:
alignment_lengths = mel_lengths // self.decoder.r alignment_lengths = mel_lengths // self.decoder.r
cond_input = {'speaker_ids': speaker_ids, 'x_vectors': x_vectors} cond_input = {"speaker_ids": speaker_ids, "x_vectors": x_vectors}
outputs = self.forward(text_input, text_lengths, mel_input, outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, cond_input)
mel_lengths, cond_input)
# compute loss # compute loss
loss_dict = criterion( loss_dict = criterion(
outputs['model_outputs'], outputs["model_outputs"],
outputs['decoder_outputs'], outputs["decoder_outputs"],
mel_input, mel_input,
linear_input, linear_input,
outputs['stop_tokens'], outputs["stop_tokens"],
stop_targets, stop_targets,
mel_lengths, mel_lengths,
outputs['decoder_outputs_backward'], outputs["decoder_outputs_backward"],
outputs['alignments'], outputs["alignments"],
alignment_lengths, alignment_lengths,
outputs['alignments_backward'], outputs["alignments_backward"],
text_lengths, text_lengths,
) )
# compute alignment error (the lower the better ) # compute alignment error (the lower the better )
align_error = 1 - alignment_diagonal_score(outputs['alignments']) align_error = 1 - alignment_diagonal_score(outputs["alignments"])
loss_dict["align_error"] = align_error loss_dict["align_error"] = align_error
return outputs, loss_dict return outputs, loss_dict
def train_log(self, ap, batch, outputs): def train_log(self, ap, batch, outputs):
postnet_outputs = outputs['model_outputs'] postnet_outputs = outputs["model_outputs"]
alignments = outputs['alignments'] alignments = outputs["alignments"]
alignments_backward = outputs['alignments_backward'] alignments_backward = outputs["alignments_backward"]
mel_input = batch['mel_input'] mel_input = batch["mel_input"]
pred_spec = postnet_outputs[0].data.cpu().numpy() pred_spec = postnet_outputs[0].data.cpu().numpy()
gt_spec = mel_input[0].data.cpu().numpy() gt_spec = mel_input[0].data.cpu().numpy()
@ -350,8 +350,7 @@ class Tacotron2(TacotronAbstract):
} }
if self.bidirectional_decoder or self.double_decoder_consistency: if self.bidirectional_decoder or self.double_decoder_consistency:
figures["alignment_backward"] = plot_alignment( figures["alignment_backward"] = plot_alignment(alignments_backward[0].data.cpu().numpy(), output_fig=False)
alignments_backward[0].data.cpu().numpy(), output_fig=False)
# Sample audio # Sample audio
train_audio = ap.inv_melspectrogram(pred_spec.T) train_audio = ap.inv_melspectrogram(pred_spec.T)

View File

@ -37,7 +37,7 @@ class TacotronAbstract(ABC, nn.Module):
speaker_embedding_dim=None, speaker_embedding_dim=None,
use_gst=False, use_gst=False,
gst=None, gst=None,
gradual_training=[] gradual_training=[],
): ):
"""Abstract Tacotron class""" """Abstract Tacotron class"""
super().__init__() super().__init__()

View File

@ -59,25 +59,16 @@ def numpy_to_tf(np_array, dtype):
def compute_style_mel(style_wav, ap, cuda=False): def compute_style_mel(style_wav, ap, cuda=False):
style_mel = torch.FloatTensor( style_mel = torch.FloatTensor(ap.melspectrogram(ap.load_wav(style_wav, sr=ap.sample_rate))).unsqueeze(0)
ap.melspectrogram(ap.load_wav(style_wav,
sr=ap.sample_rate))).unsqueeze(0)
if cuda: if cuda:
return style_mel.cuda() return style_mel.cuda()
return style_mel return style_mel
def run_model_torch(model, def run_model_torch(model, inputs, speaker_id=None, style_mel=None, x_vector=None):
inputs, outputs = model.inference(
speaker_id=None, inputs, cond_input={"speaker_ids": speaker_id, "x_vector": x_vector, "style_mel": style_mel}
style_mel=None, )
x_vector=None):
outputs = model.inference(inputs,
cond_input={
'speaker_ids': speaker_id,
'x_vector': x_vector,
'style_mel': style_mel
})
return outputs return outputs
@ -87,18 +78,15 @@ def run_model_tf(model, inputs, CONFIG, speaker_id=None, style_mel=None):
if speaker_id is not None: if speaker_id is not None:
raise NotImplementedError(" [!] Multi-Speaker not implemented for TF") raise NotImplementedError(" [!] Multi-Speaker not implemented for TF")
# TODO: handle multispeaker case # TODO: handle multispeaker case
decoder_output, postnet_output, alignments, stop_tokens = model( decoder_output, postnet_output, alignments, stop_tokens = model(inputs, training=False)
inputs, training=False)
return decoder_output, postnet_output, alignments, stop_tokens return decoder_output, postnet_output, alignments, stop_tokens
def run_model_tflite(model, inputs, CONFIG, speaker_id=None, style_mel=None): def run_model_tflite(model, inputs, CONFIG, speaker_id=None, style_mel=None):
if CONFIG.gst and style_mel is not None: if CONFIG.gst and style_mel is not None:
raise NotImplementedError( raise NotImplementedError(" [!] GST inference not implemented for TfLite")
" [!] GST inference not implemented for TfLite")
if speaker_id is not None: if speaker_id is not None:
raise NotImplementedError( raise NotImplementedError(" [!] Multi-Speaker not implemented for TfLite")
" [!] Multi-Speaker not implemented for TfLite")
# get input and output details # get input and output details
input_details = model.get_input_details() input_details = model.get_input_details()
output_details = model.get_output_details() output_details = model.get_output_details()
@ -132,7 +120,7 @@ def parse_outputs_tflite(postnet_output, decoder_output):
def trim_silence(wav, ap): def trim_silence(wav, ap):
return wav[:ap.find_endpoint(wav)] return wav[: ap.find_endpoint(wav)]
def inv_spectrogram(postnet_output, ap, CONFIG): def inv_spectrogram(postnet_output, ap, CONFIG):
@ -155,8 +143,7 @@ def speaker_id_to_torch(speaker_id, cuda=False):
def embedding_to_torch(x_vector, cuda=False): def embedding_to_torch(x_vector, cuda=False):
if x_vector is not None: if x_vector is not None:
x_vector = np.asarray(x_vector) x_vector = np.asarray(x_vector)
x_vector = torch.from_numpy(x_vector).unsqueeze(0).type( x_vector = torch.from_numpy(x_vector).unsqueeze(0).type(torch.FloatTensor)
torch.FloatTensor)
if cuda: if cuda:
return x_vector.cuda() return x_vector.cuda()
return x_vector return x_vector
@ -173,8 +160,7 @@ def apply_griffin_lim(inputs, input_lens, CONFIG, ap):
""" """
wavs = [] wavs = []
for idx, spec in enumerate(inputs): for idx, spec in enumerate(inputs):
wav_len = (input_lens[idx] * wav_len = (input_lens[idx] * ap.hop_length) - ap.hop_length # inverse librosa padding
ap.hop_length) - ap.hop_length # inverse librosa padding
wav = inv_spectrogram(spec, ap, CONFIG) wav = inv_spectrogram(spec, ap, CONFIG)
# assert len(wav) == wav_len, f" [!] wav lenght: {len(wav)} vs expected: {wav_len}" # assert len(wav) == wav_len, f" [!] wav lenght: {len(wav)} vs expected: {wav_len}"
wavs.append(wav[:wav_len]) wavs.append(wav[:wav_len])
@ -242,23 +228,21 @@ def synthesis(
text_inputs = tf.expand_dims(text_inputs, 0) text_inputs = tf.expand_dims(text_inputs, 0)
# synthesize voice # synthesize voice
if backend == "torch": if backend == "torch":
outputs = run_model_torch(model, outputs = run_model_torch(model, text_inputs, speaker_id, style_mel, x_vector=x_vector)
text_inputs, model_outputs = outputs["model_outputs"]
speaker_id,
style_mel,
x_vector=x_vector)
model_outputs = outputs['model_outputs']
model_outputs = model_outputs[0].data.cpu().numpy() model_outputs = model_outputs[0].data.cpu().numpy()
elif backend == "tf": elif backend == "tf":
decoder_output, postnet_output, alignments, stop_tokens = run_model_tf( decoder_output, postnet_output, alignments, stop_tokens = run_model_tf(
model, text_inputs, CONFIG, speaker_id, style_mel) model, text_inputs, CONFIG, speaker_id, style_mel
)
model_outputs, decoder_output, alignment, stop_tokens = parse_outputs_tf( model_outputs, decoder_output, alignment, stop_tokens = parse_outputs_tf(
postnet_output, decoder_output, alignments, stop_tokens) postnet_output, decoder_output, alignments, stop_tokens
)
elif backend == "tflite": elif backend == "tflite":
decoder_output, postnet_output, alignment, stop_tokens = run_model_tflite( decoder_output, postnet_output, alignment, stop_tokens = run_model_tflite(
model, text_inputs, CONFIG, speaker_id, style_mel) model, text_inputs, CONFIG, speaker_id, style_mel
model_outputs, decoder_output = parse_outputs_tflite( )
postnet_output, decoder_output) model_outputs, decoder_output = parse_outputs_tflite(postnet_output, decoder_output)
# convert outputs to numpy # convert outputs to numpy
# plot results # plot results
wav = None wav = None
@ -268,9 +252,9 @@ def synthesis(
if do_trim_silence: if do_trim_silence:
wav = trim_silence(wav, ap) wav = trim_silence(wav, ap)
return_dict = { return_dict = {
'wav': wav, "wav": wav,
'alignments': outputs['alignments'], "alignments": outputs["alignments"],
'model_outputs': model_outputs, "model_outputs": model_outputs,
'text_inputs': text_inputs "text_inputs": text_inputs,
} }
return return_dict return return_dict

View File

@ -30,16 +30,16 @@ def init_arguments(argv):
parser.add_argument( parser.add_argument(
"--continue_path", "--continue_path",
type=str, type=str,
help=("Training output folder to continue training. Used to continue " help=(
"a training. If it is used, 'config_path' is ignored."), "Training output folder to continue training. Used to continue "
"a training. If it is used, 'config_path' is ignored."
),
default="", default="",
required="--config_path" not in argv, required="--config_path" not in argv,
) )
parser.add_argument( parser.add_argument(
"--restore_path", "--restore_path", type=str, help="Model file to be restored. Use to finetune a model.", default=""
type=str, )
help="Model file to be restored. Use to finetune a model.",
default="")
parser.add_argument( parser.add_argument(
"--best_path", "--best_path",
type=str, type=str,
@ -49,23 +49,12 @@ def init_arguments(argv):
), ),
default="", default="",
) )
parser.add_argument("--config_path",
type=str,
help="Path to config file for training.",
required="--continue_path" not in argv)
parser.add_argument("--debug",
type=bool,
default=False,
help="Do not verify commit integrity to run training.")
parser.add_argument( parser.add_argument(
"--rank", "--config_path", type=str, help="Path to config file for training.", required="--continue_path" not in argv
type=int, )
default=0, parser.add_argument("--debug", type=bool, default=False, help="Do not verify commit integrity to run training.")
help="DISTRIBUTED: process rank for distributed training.") parser.add_argument("--rank", type=int, default=0, help="DISTRIBUTED: process rank for distributed training.")
parser.add_argument("--group_id", parser.add_argument("--group_id", type=str, default="", help="DISTRIBUTED: process group id.")
type=str,
default="",
help="DISTRIBUTED: process group id.")
return parser return parser
@ -160,8 +149,7 @@ def process_args(args):
print(" > Mixed precision mode is ON") print(" > Mixed precision mode is ON")
experiment_path = args.continue_path experiment_path = args.continue_path
if not experiment_path: if not experiment_path:
experiment_path = create_experiment_folder(config.output_path, experiment_path = create_experiment_folder(config.output_path, config.run_name, args.debug)
config.run_name, args.debug)
audio_path = os.path.join(experiment_path, "test_audios") audio_path = os.path.join(experiment_path, "test_audios")
# setup rank 0 process in distributed training # setup rank 0 process in distributed training
tb_logger = None tb_logger = None
@ -182,8 +170,7 @@ def process_args(args):
os.chmod(experiment_path, 0o775) os.chmod(experiment_path, 0o775)
tb_logger = TensorboardLogger(experiment_path, model_name=config.model) tb_logger = TensorboardLogger(experiment_path, model_name=config.model)
# write model desc to tensorboard # write model desc to tensorboard
tb_logger.tb_add_text("model-config", f"<pre>{config.to_json()}</pre>", tb_logger.tb_add_text("model-config", f"<pre>{config.to_json()}</pre>", 0)
0)
c_logger = ConsoleLogger() c_logger = ConsoleLogger()
return config, experiment_path, audio_path, c_logger, tb_logger return config, experiment_path, audio_path, c_logger, tb_logger

View File

@ -234,8 +234,8 @@ class Synthesizer(object):
use_griffin_lim=use_gl, use_griffin_lim=use_gl,
x_vector=speaker_embedding, x_vector=speaker_embedding,
) )
waveform = outputs['wav'] waveform = outputs["wav"]
mel_postnet_spec = outputs['model_outputs'] mel_postnet_spec = outputs["model_outputs"]
if not use_gl: if not use_gl:
# denormalize tts output based on tts audio config # denormalize tts output based on tts audio config
mel_postnet_spec = self.ap.denormalize(mel_postnet_spec.T).T mel_postnet_spec = self.ap.denormalize(mel_postnet_spec.T).T

View File

@ -44,8 +44,6 @@ run_cli(command_train)
continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime) continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
# restore the model and continue training for one more epoch # restore the model and continue training for one more epoch
command_train = ( command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
)
run_cli(command_train) run_cli(command_train)
shutil.rmtree(continue_path) shutil.rmtree(continue_path)

View File

@ -46,8 +46,6 @@ run_cli(command_train)
continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime) continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
# restore the model and continue training for one more epoch # restore the model and continue training for one more epoch
command_train = ( command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
)
run_cli(command_train) run_cli(command_train)
shutil.rmtree(continue_path) shutil.rmtree(continue_path)

View File

@ -45,8 +45,6 @@ run_cli(command_train)
continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime) continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
# restore the model and continue training for one more epoch # restore the model and continue training for one more epoch
command_train = ( command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
)
run_cli(command_train) run_cli(command_train)
shutil.rmtree(continue_path) shutil.rmtree(continue_path)

View File

@ -45,8 +45,6 @@ run_cli(command_train)
continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime) continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
# restore the model and continue training for one more epoch # restore the model and continue training for one more epoch
command_train = ( command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
)
run_cli(command_train) run_cli(command_train)
shutil.rmtree(continue_path) shutil.rmtree(continue_path)

View File

@ -44,8 +44,6 @@ run_cli(command_train)
continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime) continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
# restore the model and continue training for one more epoch # restore the model and continue training for one more epoch
command_train = ( command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
)
run_cli(command_train) run_cli(command_train)
shutil.rmtree(continue_path) shutil.rmtree(continue_path)

View File

@ -21,7 +21,6 @@ config = MelganConfig(
print_step=1, print_step=1,
discriminator_model_params={"base_channels": 16, "max_channels": 256, "downsample_factors": [4, 4, 4]}, discriminator_model_params={"base_channels": 16, "max_channels": 256, "downsample_factors": [4, 4, 4]},
print_eval=True, print_eval=True,
discriminator_model_params={"base_channels": 16, "max_channels": 256, "downsample_factors": [4, 4, 4]},
data_path="tests/data/ljspeech", data_path="tests/data/ljspeech",
output_path=output_path, output_path=output_path,
) )