mirror of https://github.com/coqui-ai/TTS.git
make style
This commit is contained in:
parent
88d8a94a10
commit
aefa71155c
|
@ -153,15 +153,9 @@ def inference(
|
|||
model_output = model_output.transpose(1, 2).detach().cpu().numpy()
|
||||
|
||||
elif "tacotron" in model_name:
|
||||
cond_input = {'speaker_ids': speaker_ids, 'x_vectors': speaker_embeddings}
|
||||
outputs = model(
|
||||
text_input,
|
||||
text_lengths,
|
||||
mel_input,
|
||||
mel_lengths,
|
||||
cond_input
|
||||
)
|
||||
postnet_outputs = outputs['model_outputs']
|
||||
cond_input = {"speaker_ids": speaker_ids, "x_vectors": speaker_embeddings}
|
||||
outputs = model(text_input, text_lengths, mel_input, mel_lengths, cond_input)
|
||||
postnet_outputs = outputs["model_outputs"]
|
||||
# normalize tacotron output
|
||||
if model_name == "tacotron":
|
||||
mel_specs = []
|
||||
|
|
|
@ -8,13 +8,8 @@ from TTS.trainer import TrainerTTS
|
|||
|
||||
def main():
|
||||
# try:
|
||||
args, config, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = init_training(
|
||||
sys.argv)
|
||||
trainer = TrainerTTS(args,
|
||||
config,
|
||||
c_logger,
|
||||
tb_logger,
|
||||
output_path=OUT_PATH)
|
||||
args, config, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = init_training(sys.argv)
|
||||
trainer = TrainerTTS(args, config, c_logger, tb_logger, output_path=OUT_PATH)
|
||||
trainer.fit()
|
||||
# except KeyboardInterrupt:
|
||||
# remove_experiment_folder(OUT_PATH)
|
||||
|
|
140
TTS/trainer.py
140
TTS/trainer.py
|
@ -84,7 +84,7 @@ class TrainerTTS:
|
|||
self.best_loss = float("inf")
|
||||
self.train_loader = None
|
||||
self.eval_loader = None
|
||||
self.output_audio_path = os.path.join(output_path, 'test_audios')
|
||||
self.output_audio_path = os.path.join(output_path, "test_audios")
|
||||
|
||||
self.keep_avg_train = None
|
||||
self.keep_avg_eval = None
|
||||
|
@ -138,8 +138,8 @@ class TrainerTTS:
|
|||
|
||||
if self.args.restore_path:
|
||||
self.model, self.optimizer, self.scaler, self.restore_step = self.restore_model(
|
||||
self.config, args.restore_path, self.model, self.optimizer,
|
||||
self.scaler)
|
||||
self.config, args.restore_path, self.model, self.optimizer, self.scaler
|
||||
)
|
||||
|
||||
# setup scheduler
|
||||
self.scheduler = self.get_scheduler(self.config, self.optimizer)
|
||||
|
@ -207,6 +207,7 @@ class TrainerTTS:
|
|||
return None
|
||||
if lr_scheduler.lower() == "noamlr":
|
||||
from TTS.utils.training import NoamLR
|
||||
|
||||
scheduler = NoamLR
|
||||
else:
|
||||
scheduler = getattr(torch.optim, lr_scheduler)
|
||||
|
@ -261,8 +262,7 @@ class TrainerTTS:
|
|||
ap=ap,
|
||||
tp=self.config.characters,
|
||||
add_blank=self.config["add_blank"],
|
||||
batch_group_size=0 if is_eval else
|
||||
self.config.batch_group_size * self.config.batch_size,
|
||||
batch_group_size=0 if is_eval else self.config.batch_group_size * self.config.batch_size,
|
||||
min_seq_len=self.config.min_seq_len,
|
||||
max_seq_len=self.config.max_seq_len,
|
||||
phoneme_cache_path=self.config.phoneme_cache_path,
|
||||
|
@ -272,8 +272,8 @@ class TrainerTTS:
|
|||
use_noise_augment=not is_eval,
|
||||
verbose=verbose,
|
||||
speaker_mapping=speaker_mapping
|
||||
if self.config.use_speaker_embedding
|
||||
and self.config.use_external_speaker_embedding_file else None,
|
||||
if self.config.use_speaker_embedding and self.config.use_external_speaker_embedding_file
|
||||
else None,
|
||||
)
|
||||
|
||||
if self.config.use_phonemes and self.config.compute_input_seq_cache:
|
||||
|
@ -281,18 +281,15 @@ class TrainerTTS:
|
|||
dataset.compute_input_seq(self.config.num_loader_workers)
|
||||
dataset.sort_items()
|
||||
|
||||
sampler = DistributedSampler(
|
||||
dataset) if self.num_gpus > 1 else None
|
||||
sampler = DistributedSampler(dataset) if self.num_gpus > 1 else None
|
||||
loader = DataLoader(
|
||||
dataset,
|
||||
batch_size=self.config.eval_batch_size
|
||||
if is_eval else self.config.batch_size,
|
||||
batch_size=self.config.eval_batch_size if is_eval else self.config.batch_size,
|
||||
shuffle=False,
|
||||
collate_fn=dataset.collate_fn,
|
||||
drop_last=False,
|
||||
sampler=sampler,
|
||||
num_workers=self.config.num_val_loader_workers
|
||||
if is_eval else self.config.num_loader_workers,
|
||||
num_workers=self.config.num_val_loader_workers if is_eval else self.config.num_loader_workers,
|
||||
pin_memory=False,
|
||||
)
|
||||
return loader
|
||||
|
@ -314,8 +311,7 @@ class TrainerTTS:
|
|||
text_input = batch[0]
|
||||
text_lengths = batch[1]
|
||||
speaker_names = batch[2]
|
||||
linear_input = batch[3] if self.config.model.lower() in ["tacotron"
|
||||
] else None
|
||||
linear_input = batch[3] if self.config.model.lower() in ["tacotron"] else None
|
||||
mel_input = batch[4]
|
||||
mel_lengths = batch[5]
|
||||
stop_targets = batch[6]
|
||||
|
@ -331,10 +327,7 @@ class TrainerTTS:
|
|||
speaker_embeddings = batch[8]
|
||||
speaker_ids = None
|
||||
else:
|
||||
speaker_ids = [
|
||||
self.speaker_manager.speaker_ids[speaker_name]
|
||||
for speaker_name in speaker_names
|
||||
]
|
||||
speaker_ids = [self.speaker_manager.speaker_ids[speaker_name] for speaker_name in speaker_names]
|
||||
speaker_ids = torch.LongTensor(speaker_ids)
|
||||
speaker_embeddings = None
|
||||
else:
|
||||
|
@ -346,7 +339,7 @@ class TrainerTTS:
|
|||
durations = torch.zeros(attn_mask.shape[0], attn_mask.shape[2])
|
||||
for idx, am in enumerate(attn_mask):
|
||||
# compute raw durations
|
||||
c_idxs = am[:, :text_lengths[idx], :mel_lengths[idx]].max(1)[1]
|
||||
c_idxs = am[:, : text_lengths[idx], : mel_lengths[idx]].max(1)[1]
|
||||
# c_idxs, counts = torch.unique_consecutive(c_idxs, return_counts=True)
|
||||
c_idxs, counts = torch.unique(c_idxs, return_counts=True)
|
||||
dur = torch.ones([text_lengths[idx]]).to(counts.dtype)
|
||||
|
@ -359,14 +352,11 @@ class TrainerTTS:
|
|||
assert (
|
||||
dur.sum() == mel_lengths[idx]
|
||||
), f" [!] total duration {dur.sum()} vs spectrogram length {mel_lengths[idx]}"
|
||||
durations[idx, :text_lengths[idx]] = dur
|
||||
durations[idx, : text_lengths[idx]] = dur
|
||||
|
||||
# set stop targets view, we predict a single stop token per iteration.
|
||||
stop_targets = stop_targets.view(text_input.shape[0],
|
||||
stop_targets.size(1) // self.config.r,
|
||||
-1)
|
||||
stop_targets = (stop_targets.sum(2) >
|
||||
0.0).unsqueeze(2).float().squeeze(2)
|
||||
stop_targets = stop_targets.view(text_input.shape[0], stop_targets.size(1) // self.config.r, -1)
|
||||
stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze(2)
|
||||
|
||||
# dispatch batch to GPU
|
||||
if self.use_cuda:
|
||||
|
@ -374,15 +364,10 @@ class TrainerTTS:
|
|||
text_lengths = text_lengths.cuda(non_blocking=True)
|
||||
mel_input = mel_input.cuda(non_blocking=True)
|
||||
mel_lengths = mel_lengths.cuda(non_blocking=True)
|
||||
linear_input = linear_input.cuda(
|
||||
non_blocking=True) if self.config.model.lower() in [
|
||||
"tacotron"
|
||||
] else None
|
||||
linear_input = linear_input.cuda(non_blocking=True) if self.config.model.lower() in ["tacotron"] else None
|
||||
stop_targets = stop_targets.cuda(non_blocking=True)
|
||||
attn_mask = attn_mask.cuda(
|
||||
non_blocking=True) if attn_mask is not None else None
|
||||
durations = durations.cuda(
|
||||
non_blocking=True) if attn_mask is not None else None
|
||||
attn_mask = attn_mask.cuda(non_blocking=True) if attn_mask is not None else None
|
||||
durations = durations.cuda(non_blocking=True) if attn_mask is not None else None
|
||||
if speaker_ids is not None:
|
||||
speaker_ids = speaker_ids.cuda(non_blocking=True)
|
||||
if speaker_embeddings is not None:
|
||||
|
@ -401,7 +386,7 @@ class TrainerTTS:
|
|||
"x_vectors": speaker_embeddings,
|
||||
"max_text_length": max_text_length,
|
||||
"max_spec_length": max_spec_length,
|
||||
"item_idx": item_idx
|
||||
"item_idx": item_idx,
|
||||
}
|
||||
|
||||
def train_step(self, batch: Dict, batch_n_steps: int, step: int,
|
||||
|
@ -421,25 +406,20 @@ class TrainerTTS:
|
|||
|
||||
# check nan loss
|
||||
if torch.isnan(loss_dict["loss"]).any():
|
||||
raise RuntimeError(
|
||||
f"Detected NaN loss at step {self.total_steps_done}.")
|
||||
raise RuntimeError(f"Detected NaN loss at step {self.total_steps_done}.")
|
||||
|
||||
# optimizer step
|
||||
if self.config.mixed_precision:
|
||||
# model optimizer step in mixed precision mode
|
||||
self.scaler.scale(loss_dict["loss"]).backward()
|
||||
self.scaler.unscale_(self.optimizer)
|
||||
grad_norm, _ = check_update(self.model,
|
||||
self.config.grad_clip,
|
||||
ignore_stopnet=True)
|
||||
grad_norm, _ = check_update(self.model, self.config.grad_clip, ignore_stopnet=True)
|
||||
self.scaler.step(self.optimizer)
|
||||
self.scaler.update()
|
||||
else:
|
||||
# main model optimizer step
|
||||
loss_dict["loss"].backward()
|
||||
grad_norm, _ = check_update(self.model,
|
||||
self.config.grad_clip,
|
||||
ignore_stopnet=True)
|
||||
grad_norm, _ = check_update(self.model, self.config.grad_clip, ignore_stopnet=True)
|
||||
self.optimizer.step()
|
||||
|
||||
step_time = time.time() - step_start_time
|
||||
|
@ -469,17 +449,15 @@ class TrainerTTS:
|
|||
current_lr = self.optimizer.param_groups[0]["lr"]
|
||||
if self.total_steps_done % self.config.print_step == 0:
|
||||
log_dict = {
|
||||
"max_spec_length": [batch["max_spec_length"],
|
||||
1], # value, precision
|
||||
"max_spec_length": [batch["max_spec_length"], 1], # value, precision
|
||||
"max_text_length": [batch["max_text_length"], 1],
|
||||
"step_time": [step_time, 4],
|
||||
"loader_time": [loader_time, 2],
|
||||
"current_lr": current_lr,
|
||||
}
|
||||
self.c_logger.print_train_step(batch_n_steps, step,
|
||||
self.total_steps_done, log_dict,
|
||||
loss_dict,
|
||||
self.keep_avg_train.avg_values)
|
||||
self.c_logger.print_train_step(
|
||||
batch_n_steps, step, self.total_steps_done, log_dict, loss_dict, self.keep_avg_train.avg_values
|
||||
)
|
||||
|
||||
if self.args.rank == 0:
|
||||
# Plot Training Iter Stats
|
||||
|
@ -491,8 +469,7 @@ class TrainerTTS:
|
|||
"step_time": step_time,
|
||||
}
|
||||
iter_stats.update(loss_dict)
|
||||
self.tb_logger.tb_train_step_stats(self.total_steps_done,
|
||||
iter_stats)
|
||||
self.tb_logger.tb_train_step_stats(self.total_steps_done, iter_stats)
|
||||
|
||||
if self.total_steps_done % self.config.save_step == 0:
|
||||
if self.config.checkpoint:
|
||||
|
@ -506,15 +483,12 @@ class TrainerTTS:
|
|||
self.output_path,
|
||||
model_loss=loss_dict["loss"],
|
||||
characters=self.model_characters,
|
||||
scaler=self.scaler.state_dict()
|
||||
if self.config.mixed_precision else None,
|
||||
scaler=self.scaler.state_dict() if self.config.mixed_precision else None,
|
||||
)
|
||||
# training visualizations
|
||||
figures, audios = self.model.train_log(self.ap, batch, outputs)
|
||||
self.tb_logger.tb_train_figures(self.total_steps_done, figures)
|
||||
self.tb_logger.tb_train_audios(self.total_steps_done,
|
||||
{"TrainAudio": audios},
|
||||
self.ap.sample_rate)
|
||||
self.tb_logger.tb_train_audios(self.total_steps_done, {"TrainAudio": audios}, self.ap.sample_rate)
|
||||
self.total_steps_done += 1
|
||||
self.on_train_step_end()
|
||||
return outputs, loss_dict
|
||||
|
@ -523,35 +497,28 @@ class TrainerTTS:
|
|||
self.model.train()
|
||||
epoch_start_time = time.time()
|
||||
if self.use_cuda:
|
||||
batch_num_steps = int(
|
||||
len(self.train_loader.dataset) /
|
||||
(self.config.batch_size * self.num_gpus))
|
||||
batch_num_steps = int(len(self.train_loader.dataset) / (self.config.batch_size * self.num_gpus))
|
||||
else:
|
||||
batch_num_steps = int(
|
||||
len(self.train_loader.dataset) / self.config.batch_size)
|
||||
batch_num_steps = int(len(self.train_loader.dataset) / self.config.batch_size)
|
||||
self.c_logger.print_train_start()
|
||||
loader_start_time = time.time()
|
||||
for cur_step, batch in enumerate(self.train_loader):
|
||||
_, _ = self.train_step(batch, batch_num_steps, cur_step,
|
||||
loader_start_time)
|
||||
_, _ = self.train_step(batch, batch_num_steps, cur_step, loader_start_time)
|
||||
epoch_time = time.time() - epoch_start_time
|
||||
# Plot self.epochs_done Stats
|
||||
if self.args.rank == 0:
|
||||
epoch_stats = {"epoch_time": epoch_time}
|
||||
epoch_stats.update(self.keep_avg_train.avg_values)
|
||||
self.tb_logger.tb_train_epoch_stats(self.total_steps_done,
|
||||
epoch_stats)
|
||||
self.tb_logger.tb_train_epoch_stats(self.total_steps_done, epoch_stats)
|
||||
if self.config.tb_model_param_stats:
|
||||
self.tb_logger.tb_model_weights(self.model,
|
||||
self.total_steps_done)
|
||||
self.tb_logger.tb_model_weights(self.model, self.total_steps_done)
|
||||
|
||||
def eval_step(self, batch: Dict, step: int) -> Tuple[Dict, Dict]:
|
||||
with torch.no_grad():
|
||||
step_start_time = time.time()
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=self.config.mixed_precision):
|
||||
outputs, loss_dict = self.model.eval_step(
|
||||
batch, self.criterion)
|
||||
outputs, loss_dict = self.model.eval_step(batch, self.criterion)
|
||||
|
||||
step_time = time.time() - step_start_time
|
||||
|
||||
|
@ -572,8 +539,7 @@ class TrainerTTS:
|
|||
self.keep_avg_eval.update_values(update_eval_values)
|
||||
|
||||
if self.config.print_eval:
|
||||
self.c_logger.print_eval_step(step, loss_dict,
|
||||
self.keep_avg_eval.avg_values)
|
||||
self.c_logger.print_eval_step(step, loss_dict, self.keep_avg_eval.avg_values)
|
||||
return outputs, loss_dict
|
||||
|
||||
def eval_epoch(self) -> None:
|
||||
|
@ -585,15 +551,13 @@ class TrainerTTS:
|
|||
# format data
|
||||
batch = self.format_batch(batch)
|
||||
loader_time = time.time() - loader_start_time
|
||||
self.keep_avg_eval.update_values({'avg_loader_time': loader_time})
|
||||
self.keep_avg_eval.update_values({"avg_loader_time": loader_time})
|
||||
outputs, _ = self.eval_step(batch, cur_step)
|
||||
# Plot epoch stats and samples from the last batch.
|
||||
if self.args.rank == 0:
|
||||
figures, eval_audios = self.model.eval_log(self.ap, batch, outputs)
|
||||
self.tb_logger.tb_eval_figures(self.total_steps_done, figures)
|
||||
self.tb_logger.tb_eval_audios(self.total_steps_done,
|
||||
{"EvalAudio": eval_audios},
|
||||
self.ap.sample_rate)
|
||||
self.tb_logger.tb_eval_audios(self.total_steps_done, {"EvalAudio": eval_audios}, self.ap.sample_rate)
|
||||
|
||||
def test_run(self, ) -> None:
|
||||
print(" | > Synthesizing test sentences.")
|
||||
|
@ -608,9 +572,9 @@ class TrainerTTS:
|
|||
self.config,
|
||||
self.use_cuda,
|
||||
self.ap,
|
||||
speaker_id=cond_inputs['speaker_id'],
|
||||
x_vector=cond_inputs['x_vector'],
|
||||
style_wav=cond_inputs['style_wav'],
|
||||
speaker_id=cond_inputs["speaker_id"],
|
||||
x_vector=cond_inputs["x_vector"],
|
||||
style_wav=cond_inputs["style_wav"],
|
||||
enable_eos_bos_chars=self.config.enable_eos_bos_chars,
|
||||
use_griffin_lim=True,
|
||||
do_trim_silence=False,
|
||||
|
@ -623,10 +587,8 @@ class TrainerTTS:
|
|||
"TestSentence_{}.wav".format(idx))
|
||||
self.ap.save_wav(wav, file_path)
|
||||
test_audios["{}-audio".format(idx)] = wav
|
||||
test_figures["{}-prediction".format(idx)] = plot_spectrogram(
|
||||
model_outputs, self.ap, output_fig=False)
|
||||
test_figures["{}-alignment".format(idx)] = plot_alignment(
|
||||
alignment, output_fig=False)
|
||||
test_figures["{}-prediction".format(idx)] = plot_spectrogram(model_outputs, self.ap, output_fig=False)
|
||||
test_figures["{}-alignment".format(idx)] = plot_alignment(alignment, output_fig=False)
|
||||
|
||||
self.tb_logger.tb_test_audios(self.total_steps_done, test_audios,
|
||||
self.config.audio["sample_rate"])
|
||||
|
@ -641,11 +603,11 @@ class TrainerTTS:
|
|||
if self.config.use_external_speaker_embedding_file
|
||||
and self.config.use_speaker_embedding else None)
|
||||
# setup style_mel
|
||||
if self.config.has('gst_style_input'):
|
||||
if self.config.has("gst_style_input"):
|
||||
style_wav = self.config.gst_style_input
|
||||
else:
|
||||
style_wav = None
|
||||
if style_wav is None and 'use_gst' in self.config and self.config.use_gst:
|
||||
if style_wav is None and "use_gst" in self.config and self.config.use_gst:
|
||||
# inicialize GST with zero dict.
|
||||
style_wav = {}
|
||||
print(
|
||||
|
@ -688,8 +650,7 @@ class TrainerTTS:
|
|||
for epoch in range(0, self.config.epochs):
|
||||
self.on_epoch_start()
|
||||
self.keep_avg_train = KeepAverage()
|
||||
self.keep_avg_eval = KeepAverage(
|
||||
) if self.config.run_eval else None
|
||||
self.keep_avg_eval = KeepAverage() if self.config.run_eval else None
|
||||
self.epochs_done = epoch
|
||||
self.c_logger.print_epoch_start(epoch, self.config.epochs)
|
||||
self.train_epoch()
|
||||
|
@ -698,8 +659,8 @@ class TrainerTTS:
|
|||
if epoch >= self.config.test_delay_epochs:
|
||||
self.test_run()
|
||||
self.c_logger.print_epoch_end(
|
||||
epoch, self.keep_avg_eval.avg_values
|
||||
if self.config.run_eval else self.keep_avg_train.avg_values)
|
||||
epoch, self.keep_avg_eval.avg_values if self.config.run_eval else self.keep_avg_train.avg_values
|
||||
)
|
||||
self.save_best_model()
|
||||
self.on_epoch_end()
|
||||
|
||||
|
@ -717,8 +678,7 @@ class TrainerTTS:
|
|||
self.model_characters,
|
||||
keep_all_best=self.config.keep_all_best,
|
||||
keep_after=self.config.keep_after,
|
||||
scaler=self.scaler.state_dict()
|
||||
if self.config.mixed_precision else None,
|
||||
scaler=self.scaler.state_dict() if self.config.mixed_precision else None,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
|
|
@ -93,7 +93,7 @@ class AlignTTSConfig(BaseTTSConfig):
|
|||
|
||||
# optimizer parameters
|
||||
optimizer: str = "Adam"
|
||||
optimizer_params: dict = field(default_factory=lambda: {'betas': [0.9, 0.998], 'weight_decay': 1e-6})
|
||||
optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.998], "weight_decay": 1e-6})
|
||||
lr_scheduler: str = None
|
||||
lr_scheduler_params: dict = None
|
||||
lr: float = 1e-4
|
||||
|
@ -104,12 +104,13 @@ class AlignTTSConfig(BaseTTSConfig):
|
|||
max_seq_len: int = 200
|
||||
r: int = 1
|
||||
|
||||
# testing
|
||||
test_sentences: List[str] = field(default_factory=lambda:[
|
||||
"It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
|
||||
"Be a voice, not an echo.",
|
||||
"I'm sorry Dave. I'm afraid I can't do that.",
|
||||
"This cake is great. It's so delicious and moist.",
|
||||
"Prior to November 22, 1963."
|
||||
])
|
||||
|
||||
# testing
|
||||
test_sentences: List[str] = field(
|
||||
default_factory=lambda: [
|
||||
"It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
|
||||
"Be a voice, not an echo.",
|
||||
"I'm sorry Dave. I'm afraid I can't do that.",
|
||||
"This cake is great. It's so delicious and moist.",
|
||||
"Prior to November 22, 1963.",
|
||||
]
|
||||
)
|
||||
|
|
|
@ -91,9 +91,9 @@ class GlowTTSConfig(BaseTTSConfig):
|
|||
|
||||
# optimizer parameters
|
||||
optimizer: str = "RAdam"
|
||||
optimizer_params: dict = field(default_factory=lambda: {'betas': [0.9, 0.998], 'weight_decay': 1e-6})
|
||||
optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.998], "weight_decay": 1e-6})
|
||||
lr_scheduler: str = "NoamLR"
|
||||
lr_scheduler_params: dict = field(default_factory=lambda:{"warmup_steps": 4000})
|
||||
lr_scheduler_params: dict = field(default_factory=lambda: {"warmup_steps": 4000})
|
||||
grad_clip: float = 5.0
|
||||
lr: float = 1e-3
|
||||
|
||||
|
|
|
@ -171,7 +171,7 @@ class BaseTTSConfig(BaseTrainingConfig):
|
|||
optimizer: str = MISSING
|
||||
optimizer_params: dict = MISSING
|
||||
# scheduler
|
||||
lr_scheduler: str = ''
|
||||
lr_scheduler: str = ""
|
||||
lr_scheduler_params: dict = field(default_factory=lambda: {})
|
||||
# testing
|
||||
test_sentences: List[str] = field(default_factory=lambda:[])
|
||||
test_sentences: List[str] = field(default_factory=lambda: [])
|
||||
|
|
|
@ -101,7 +101,7 @@ class SpeedySpeechConfig(BaseTTSConfig):
|
|||
|
||||
# optimizer parameters
|
||||
optimizer: str = "RAdam"
|
||||
optimizer_params: dict = field(default_factory=lambda: {'betas': [0.9, 0.998], 'weight_decay': 1e-6})
|
||||
optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.998], "weight_decay": 1e-6})
|
||||
lr_scheduler: str = None
|
||||
lr_scheduler_params: dict = None
|
||||
lr: float = 1e-4
|
||||
|
@ -118,10 +118,12 @@ class SpeedySpeechConfig(BaseTTSConfig):
|
|||
r: int = 1 # DO NOT CHANGE
|
||||
|
||||
# testing
|
||||
test_sentences: List[str] = field(default_factory=lambda:[
|
||||
"It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
|
||||
"Be a voice, not an echo.",
|
||||
"I'm sorry Dave. I'm afraid I can't do that.",
|
||||
"This cake is great. It's so delicious and moist.",
|
||||
"Prior to November 22, 1963."
|
||||
])
|
||||
test_sentences: List[str] = field(
|
||||
default_factory=lambda: [
|
||||
"It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
|
||||
"Be a voice, not an echo.",
|
||||
"I'm sorry Dave. I'm afraid I can't do that.",
|
||||
"This cake is great. It's so delicious and moist.",
|
||||
"Prior to November 22, 1963.",
|
||||
]
|
||||
)
|
||||
|
|
|
@ -160,9 +160,9 @@ class TacotronConfig(BaseTTSConfig):
|
|||
|
||||
# optimizer parameters
|
||||
optimizer: str = "RAdam"
|
||||
optimizer_params: dict = field(default_factory=lambda: {'betas': [0.9, 0.998], 'weight_decay': 1e-6})
|
||||
optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.998], "weight_decay": 1e-6})
|
||||
lr_scheduler: str = "NoamLR"
|
||||
lr_scheduler_params: dict = field(default_factory=lambda:{"warmup_steps": 4000})
|
||||
lr_scheduler_params: dict = field(default_factory=lambda: {"warmup_steps": 4000})
|
||||
lr: float = 1e-4
|
||||
grad_clip: float = 5.0
|
||||
seq_len_norm: bool = False
|
||||
|
@ -178,13 +178,15 @@ class TacotronConfig(BaseTTSConfig):
|
|||
ga_alpha: float = 5.0
|
||||
|
||||
# testing
|
||||
test_sentences: List[str] = field(default_factory=lambda:[
|
||||
"It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
|
||||
"Be a voice, not an echo.",
|
||||
"I'm sorry Dave. I'm afraid I can't do that.",
|
||||
"This cake is great. It's so delicious and moist.",
|
||||
"Prior to November 22, 1963."
|
||||
])
|
||||
test_sentences: List[str] = field(
|
||||
default_factory=lambda: [
|
||||
"It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
|
||||
"Be a voice, not an echo.",
|
||||
"I'm sorry Dave. I'm afraid I can't do that.",
|
||||
"This cake is great. It's so delicious and moist.",
|
||||
"Prior to November 22, 1963.",
|
||||
]
|
||||
)
|
||||
|
||||
def check_values(self):
|
||||
if self.gradual_training:
|
||||
|
|
|
@ -44,22 +44,18 @@ def load_meta_data(datasets, eval_split=True):
|
|||
preprocessor = _get_preprocessor_by_name(name)
|
||||
# load train set
|
||||
meta_data_train = preprocessor(root_path, meta_file_train)
|
||||
print(
|
||||
f" | > Found {len(meta_data_train)} files in {Path(root_path).resolve()}"
|
||||
)
|
||||
print(f" | > Found {len(meta_data_train)} files in {Path(root_path).resolve()}")
|
||||
# load evaluation split if set
|
||||
if eval_split:
|
||||
if meta_file_val:
|
||||
meta_data_eval = preprocessor(root_path, meta_file_val)
|
||||
else:
|
||||
meta_data_eval, meta_data_train = split_dataset(
|
||||
meta_data_train)
|
||||
meta_data_eval, meta_data_train = split_dataset(meta_data_train)
|
||||
meta_data_eval_all += meta_data_eval
|
||||
meta_data_train_all += meta_data_train
|
||||
# load attention masks for duration predictor training
|
||||
if dataset.meta_file_attn_mask:
|
||||
meta_data = dict(
|
||||
load_attention_mask_meta_data(dataset["meta_file_attn_mask"]))
|
||||
meta_data = dict(load_attention_mask_meta_data(dataset["meta_file_attn_mask"]))
|
||||
for idx, ins in enumerate(meta_data_train_all):
|
||||
attn_file = meta_data[ins[1]].strip()
|
||||
meta_data_train_all[idx].append(attn_file)
|
||||
|
|
|
@ -506,5 +506,10 @@ class AlignTTSLoss(nn.Module):
|
|||
spec_loss = self.spec_loss(decoder_output, decoder_target, decoder_output_lens)
|
||||
ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens)
|
||||
dur_loss = self.dur_loss(dur_output.unsqueeze(2), dur_target.unsqueeze(2), input_lens)
|
||||
loss = self.spec_loss_alpha * spec_loss + self.ssim_alpha * ssim_loss + self.dur_loss_alpha * dur_loss + self.mdn_alpha * mdn_loss
|
||||
loss = (
|
||||
self.spec_loss_alpha * spec_loss
|
||||
+ self.ssim_alpha * ssim_loss
|
||||
+ self.dur_loss_alpha * dur_loss
|
||||
+ self.mdn_alpha * mdn_loss
|
||||
)
|
||||
return {"loss": loss, "loss_l1": spec_loss, "loss_ssim": ssim_loss, "loss_dur": dur_loss, "mdn_loss": mdn_loss}
|
||||
|
|
|
@ -72,19 +72,9 @@ class AlignTTS(nn.Module):
|
|||
hidden_channels=256,
|
||||
hidden_channels_dp=256,
|
||||
encoder_type="fftransformer",
|
||||
encoder_params={
|
||||
"hidden_channels_ffn": 1024,
|
||||
"num_heads": 2,
|
||||
"num_layers": 6,
|
||||
"dropout_p": 0.1
|
||||
},
|
||||
encoder_params={"hidden_channels_ffn": 1024, "num_heads": 2, "num_layers": 6, "dropout_p": 0.1},
|
||||
decoder_type="fftransformer",
|
||||
decoder_params={
|
||||
"hidden_channels_ffn": 1024,
|
||||
"num_heads": 2,
|
||||
"num_layers": 6,
|
||||
"dropout_p": 0.1
|
||||
},
|
||||
decoder_params={"hidden_channels_ffn": 1024, "num_heads": 2, "num_layers": 6, "dropout_p": 0.1},
|
||||
length_scale=1,
|
||||
num_speakers=0,
|
||||
external_c=False,
|
||||
|
@ -93,14 +83,11 @@ class AlignTTS(nn.Module):
|
|||
|
||||
super().__init__()
|
||||
self.phase = -1
|
||||
self.length_scale = float(length_scale) if isinstance(
|
||||
length_scale, int) else length_scale
|
||||
self.length_scale = float(length_scale) if isinstance(length_scale, int) else length_scale
|
||||
self.emb = nn.Embedding(num_chars, hidden_channels)
|
||||
self.pos_encoder = PositionalEncoding(hidden_channels)
|
||||
self.encoder = Encoder(hidden_channels, hidden_channels, encoder_type,
|
||||
encoder_params, c_in_channels)
|
||||
self.decoder = Decoder(out_channels, hidden_channels, decoder_type,
|
||||
decoder_params)
|
||||
self.encoder = Encoder(hidden_channels, hidden_channels, encoder_type, encoder_params, c_in_channels)
|
||||
self.decoder = Decoder(out_channels, hidden_channels, decoder_type, decoder_params)
|
||||
self.duration_predictor = DurationPredictor(hidden_channels_dp)
|
||||
|
||||
self.mod_layer = nn.Conv1d(hidden_channels, hidden_channels, 1)
|
||||
|
@ -121,9 +108,9 @@ class AlignTTS(nn.Module):
|
|||
mu = mu.transpose(1, 2).unsqueeze(2) # [B, T2, 1, D]
|
||||
log_sigma = log_sigma.transpose(1, 2).unsqueeze(2) # [B, T2, 1, D]
|
||||
expanded_y, expanded_mu = torch.broadcast_tensors(y, mu)
|
||||
exponential = -0.5 * torch.mean(torch._C._nn.mse_loss(
|
||||
expanded_y, expanded_mu, 0) / torch.pow(log_sigma.exp(), 2),
|
||||
dim=-1) # B, L, T
|
||||
exponential = -0.5 * torch.mean(
|
||||
torch._C._nn.mse_loss(expanded_y, expanded_mu, 0) / torch.pow(log_sigma.exp(), 2), dim=-1
|
||||
) # B, L, T
|
||||
logp = exponential - 0.5 * log_sigma.mean(dim=-1)
|
||||
return logp
|
||||
|
||||
|
@ -157,9 +144,7 @@ class AlignTTS(nn.Module):
|
|||
[1, 0, 0, 0, 0, 0, 0]]
|
||||
"""
|
||||
attn = self.convert_dr_to_align(dr, x_mask, y_mask)
|
||||
o_en_ex = torch.matmul(
|
||||
attn.squeeze(1).transpose(1, 2), en.transpose(1,
|
||||
2)).transpose(1, 2)
|
||||
o_en_ex = torch.matmul(attn.squeeze(1).transpose(1, 2), en.transpose(1, 2)).transpose(1, 2)
|
||||
return o_en_ex, attn
|
||||
|
||||
def format_durations(self, o_dr_log, x_mask):
|
||||
|
@ -193,8 +178,7 @@ class AlignTTS(nn.Module):
|
|||
x_emb = torch.transpose(x_emb, 1, -1)
|
||||
|
||||
# compute sequence masks
|
||||
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]),
|
||||
1).to(x.dtype)
|
||||
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype)
|
||||
|
||||
# encoder pass
|
||||
o_en = self.encoder(x_emb, x_mask)
|
||||
|
@ -207,8 +191,7 @@ class AlignTTS(nn.Module):
|
|||
return o_en, o_en_dp, x_mask, g
|
||||
|
||||
def _forward_decoder(self, o_en, o_en_dp, dr, x_mask, y_lengths, g):
|
||||
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None),
|
||||
1).to(o_en_dp.dtype)
|
||||
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en_dp.dtype)
|
||||
# expand o_en with durations
|
||||
o_en_ex, attn = self.expand_encoder_outputs(o_en, dr, x_mask, y_mask)
|
||||
# positional encoding
|
||||
|
@ -224,13 +207,13 @@ class AlignTTS(nn.Module):
|
|||
def _forward_mdn(self, o_en, y, y_lengths, x_mask):
|
||||
# MAS potentials and alignment
|
||||
mu, log_sigma = self.mdn_block(o_en)
|
||||
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None),
|
||||
1).to(o_en.dtype)
|
||||
dr_mas, logp = self.compute_align_path(mu, log_sigma, y, x_mask,
|
||||
y_mask)
|
||||
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en.dtype)
|
||||
dr_mas, logp = self.compute_align_path(mu, log_sigma, y, x_mask, y_mask)
|
||||
return dr_mas, mu, log_sigma, logp
|
||||
|
||||
def forward(self, x, x_lengths, y, y_lengths, cond_input={"x_vectors": None}, phase=None): # pylint: disable=unused-argument
|
||||
def forward(
|
||||
self, x, x_lengths, y, y_lengths, cond_input={"x_vectors": None}, phase=None
|
||||
): # pylint: disable=unused-argument
|
||||
"""
|
||||
Shapes:
|
||||
x: [B, T_max]
|
||||
|
@ -240,83 +223,58 @@ class AlignTTS(nn.Module):
|
|||
g: [B, C]
|
||||
"""
|
||||
y = y.transpose(1, 2)
|
||||
g = cond_input['x_vectors'] if 'x_vectors' in cond_input else None
|
||||
g = cond_input["x_vectors"] if "x_vectors" in cond_input else None
|
||||
o_de, o_dr_log, dr_mas_log, attn, mu, log_sigma, logp = None, None, None, None, None, None, None
|
||||
if phase == 0:
|
||||
# train encoder and MDN
|
||||
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
|
||||
dr_mas, mu, log_sigma, logp = self._forward_mdn(
|
||||
o_en, y, y_lengths, x_mask)
|
||||
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None),
|
||||
1).to(o_en_dp.dtype)
|
||||
dr_mas, mu, log_sigma, logp = self._forward_mdn(o_en, y, y_lengths, x_mask)
|
||||
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en_dp.dtype)
|
||||
attn = self.convert_dr_to_align(dr_mas, x_mask, y_mask)
|
||||
elif phase == 1:
|
||||
# train decoder
|
||||
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
|
||||
dr_mas, _, _, _ = self._forward_mdn(o_en, y, y_lengths, x_mask)
|
||||
o_de, attn = self._forward_decoder(o_en.detach(),
|
||||
o_en_dp.detach(),
|
||||
dr_mas.detach(),
|
||||
x_mask,
|
||||
y_lengths,
|
||||
g=g)
|
||||
o_de, attn = self._forward_decoder(o_en.detach(), o_en_dp.detach(), dr_mas.detach(), x_mask, y_lengths, g=g)
|
||||
elif phase == 2:
|
||||
# train the whole except duration predictor
|
||||
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
|
||||
dr_mas, mu, log_sigma, logp = self._forward_mdn(
|
||||
o_en, y, y_lengths, x_mask)
|
||||
o_de, attn = self._forward_decoder(o_en,
|
||||
o_en_dp,
|
||||
dr_mas,
|
||||
x_mask,
|
||||
y_lengths,
|
||||
g=g)
|
||||
dr_mas, mu, log_sigma, logp = self._forward_mdn(o_en, y, y_lengths, x_mask)
|
||||
o_de, attn = self._forward_decoder(o_en, o_en_dp, dr_mas, x_mask, y_lengths, g=g)
|
||||
elif phase == 3:
|
||||
# train duration predictor
|
||||
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
|
||||
o_dr_log = self.duration_predictor(x, x_mask)
|
||||
dr_mas, mu, log_sigma, logp = self._forward_mdn(
|
||||
o_en, y, y_lengths, x_mask)
|
||||
o_de, attn = self._forward_decoder(o_en,
|
||||
o_en_dp,
|
||||
dr_mas,
|
||||
x_mask,
|
||||
y_lengths,
|
||||
g=g)
|
||||
dr_mas, mu, log_sigma, logp = self._forward_mdn(o_en, y, y_lengths, x_mask)
|
||||
o_de, attn = self._forward_decoder(o_en, o_en_dp, dr_mas, x_mask, y_lengths, g=g)
|
||||
o_dr_log = o_dr_log.squeeze(1)
|
||||
else:
|
||||
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
|
||||
o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask)
|
||||
dr_mas, mu, log_sigma, logp = self._forward_mdn(
|
||||
o_en, y, y_lengths, x_mask)
|
||||
o_de, attn = self._forward_decoder(o_en,
|
||||
o_en_dp,
|
||||
dr_mas,
|
||||
x_mask,
|
||||
y_lengths,
|
||||
g=g)
|
||||
dr_mas, mu, log_sigma, logp = self._forward_mdn(o_en, y, y_lengths, x_mask)
|
||||
o_de, attn = self._forward_decoder(o_en, o_en_dp, dr_mas, x_mask, y_lengths, g=g)
|
||||
o_dr_log = o_dr_log.squeeze(1)
|
||||
dr_mas_log = torch.log(dr_mas + 1).squeeze(1)
|
||||
outputs = {
|
||||
'model_outputs': o_de.transpose(1, 2),
|
||||
'alignments': attn,
|
||||
'durations_log': o_dr_log,
|
||||
'durations_mas_log': dr_mas_log,
|
||||
'mu': mu,
|
||||
'log_sigma': log_sigma,
|
||||
'logp': logp
|
||||
"model_outputs": o_de.transpose(1, 2),
|
||||
"alignments": attn,
|
||||
"durations_log": o_dr_log,
|
||||
"durations_mas_log": dr_mas_log,
|
||||
"mu": mu,
|
||||
"log_sigma": log_sigma,
|
||||
"logp": logp,
|
||||
}
|
||||
return outputs
|
||||
|
||||
@torch.no_grad()
|
||||
def inference(self, x, cond_input={'x_vectors': None}): # pylint: disable=unused-argument
|
||||
def inference(self, x, cond_input={"x_vectors": None}): # pylint: disable=unused-argument
|
||||
"""
|
||||
Shapes:
|
||||
x: [B, T_max]
|
||||
x_lengths: [B]
|
||||
g: [B, C]
|
||||
"""
|
||||
g = cond_input['x_vectors'] if 'x_vectors' in cond_input else None
|
||||
g = cond_input["x_vectors"] if "x_vectors" in cond_input else None
|
||||
x_lengths = torch.tensor(x.shape[1:2]).to(x.device)
|
||||
# pad input to prevent dropping the last word
|
||||
# x = torch.nn.functional.pad(x, pad=(0, 5), mode='constant', value=0)
|
||||
|
@ -326,46 +284,40 @@ class AlignTTS(nn.Module):
|
|||
# duration predictor pass
|
||||
o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1)
|
||||
y_lengths = o_dr.sum(1)
|
||||
o_de, attn = self._forward_decoder(o_en,
|
||||
o_en_dp,
|
||||
o_dr,
|
||||
x_mask,
|
||||
y_lengths,
|
||||
g=g)
|
||||
outputs = {'model_outputs': o_de.transpose(1, 2), 'alignments': attn}
|
||||
o_de, attn = self._forward_decoder(o_en, o_en_dp, o_dr, x_mask, y_lengths, g=g)
|
||||
outputs = {"model_outputs": o_de.transpose(1, 2), "alignments": attn}
|
||||
return outputs
|
||||
|
||||
def train_step(self, batch: dict, criterion: nn.Module):
|
||||
text_input = batch['text_input']
|
||||
text_lengths = batch['text_lengths']
|
||||
mel_input = batch['mel_input']
|
||||
mel_lengths = batch['mel_lengths']
|
||||
x_vectors = batch['x_vectors']
|
||||
speaker_ids = batch['speaker_ids']
|
||||
text_input = batch["text_input"]
|
||||
text_lengths = batch["text_lengths"]
|
||||
mel_input = batch["mel_input"]
|
||||
mel_lengths = batch["mel_lengths"]
|
||||
x_vectors = batch["x_vectors"]
|
||||
speaker_ids = batch["speaker_ids"]
|
||||
|
||||
cond_input = {'x_vectors': x_vectors, 'speaker_ids': speaker_ids}
|
||||
cond_input = {"x_vectors": x_vectors, "speaker_ids": speaker_ids}
|
||||
outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, cond_input, self.phase)
|
||||
loss_dict = criterion(
|
||||
outputs['logp'],
|
||||
outputs['model_outputs'],
|
||||
mel_input,
|
||||
mel_lengths,
|
||||
outputs['durations_log'],
|
||||
outputs['durations_mas_log'],
|
||||
text_lengths,
|
||||
phase=self.phase,
|
||||
)
|
||||
outputs["logp"],
|
||||
outputs["model_outputs"],
|
||||
mel_input,
|
||||
mel_lengths,
|
||||
outputs["durations_log"],
|
||||
outputs["durations_mas_log"],
|
||||
text_lengths,
|
||||
phase=self.phase,
|
||||
)
|
||||
|
||||
# compute alignment error (the lower the better )
|
||||
align_error = 1 - alignment_diagonal_score(outputs['alignments'],
|
||||
binary=True)
|
||||
align_error = 1 - alignment_diagonal_score(outputs["alignments"], binary=True)
|
||||
loss_dict["align_error"] = align_error
|
||||
return outputs, loss_dict
|
||||
|
||||
def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict):
|
||||
model_outputs = outputs['model_outputs']
|
||||
alignments = outputs['alignments']
|
||||
mel_input = batch['mel_input']
|
||||
model_outputs = outputs["model_outputs"]
|
||||
alignments = outputs["alignments"]
|
||||
mel_input = batch["mel_input"]
|
||||
|
||||
pred_spec = model_outputs[0].data.cpu().numpy()
|
||||
gt_spec = mel_input[0].data.cpu().numpy()
|
||||
|
@ -387,7 +339,9 @@ class AlignTTS(nn.Module):
|
|||
def eval_log(self, ap: AudioProcessor, batch: dict, outputs: dict):
|
||||
return self.train_log(ap, batch, outputs)
|
||||
|
||||
def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin
|
||||
def load_checkpoint(
|
||||
self, config, checkpoint_path, eval=False
|
||||
): # pylint: disable=unused-argument, redefined-builtin
|
||||
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
|
||||
self.load_state_dict(state["model"])
|
||||
if eval:
|
||||
|
|
|
@ -38,6 +38,7 @@ class GlowTTS(nn.Module):
|
|||
encoder_params (dict): encoder module parameters.
|
||||
speaker_embedding_dim (int): channels of external speaker embedding vectors.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_chars,
|
||||
|
@ -132,17 +133,17 @@ class GlowTTS(nn.Module):
|
|||
@staticmethod
|
||||
def compute_outputs(attn, o_mean, o_log_scale, x_mask):
|
||||
# compute final values with the computed alignment
|
||||
y_mean = torch.matmul(
|
||||
attn.squeeze(1).transpose(1, 2), o_mean.transpose(1, 2)).transpose(
|
||||
1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
|
||||
y_log_scale = torch.matmul(
|
||||
attn.squeeze(1).transpose(1, 2), o_log_scale.transpose(
|
||||
1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
|
||||
y_mean = torch.matmul(attn.squeeze(1).transpose(1, 2), o_mean.transpose(1, 2)).transpose(
|
||||
1, 2
|
||||
) # [b, t', t], [b, t, d] -> [b, d, t']
|
||||
y_log_scale = torch.matmul(attn.squeeze(1).transpose(1, 2), o_log_scale.transpose(1, 2)).transpose(
|
||||
1, 2
|
||||
) # [b, t', t], [b, t, d] -> [b, d, t']
|
||||
# compute total duration with adjustment
|
||||
o_attn_dur = torch.log(1 + torch.sum(attn, -1)) * x_mask
|
||||
return y_mean, y_log_scale, o_attn_dur
|
||||
|
||||
def forward(self, x, x_lengths, y, y_lengths=None, cond_input={'x_vectors':None}):
|
||||
def forward(self, x, x_lengths, y, y_lengths=None, cond_input={"x_vectors": None}):
|
||||
"""
|
||||
Shapes:
|
||||
x: [B, T]
|
||||
|
@ -154,7 +155,7 @@ class GlowTTS(nn.Module):
|
|||
y_max_length = y.size(2)
|
||||
y = y.transpose(1, 2)
|
||||
# norm speaker embeddings
|
||||
g = cond_input['x_vectors']
|
||||
g = cond_input["x_vectors"]
|
||||
if g is not None:
|
||||
if self.speaker_embedding_dim:
|
||||
g = F.normalize(g).unsqueeze(-1)
|
||||
|
@ -162,54 +163,38 @@ class GlowTTS(nn.Module):
|
|||
g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h, 1]
|
||||
|
||||
# embedding pass
|
||||
o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x,
|
||||
x_lengths,
|
||||
g=g)
|
||||
o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g)
|
||||
# drop redisual frames wrt num_squeeze and set y_lengths.
|
||||
y, y_lengths, y_max_length, attn = self.preprocess(
|
||||
y, y_lengths, y_max_length, None)
|
||||
y, y_lengths, y_max_length, attn = self.preprocess(y, y_lengths, y_max_length, None)
|
||||
# create masks
|
||||
y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length),
|
||||
1).to(x_mask.dtype)
|
||||
y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(x_mask.dtype)
|
||||
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
|
||||
# decoder pass
|
||||
z, logdet = self.decoder(y, y_mask, g=g, reverse=False)
|
||||
# find the alignment path
|
||||
with torch.no_grad():
|
||||
o_scale = torch.exp(-2 * o_log_scale)
|
||||
logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale,
|
||||
[1]).unsqueeze(-1) # [b, t, 1]
|
||||
logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 *
|
||||
(z**2)) # [b, t, d] x [b, d, t'] = [b, t, t']
|
||||
logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2),
|
||||
z) # [b, t, d] x [b, d, t'] = [b, t, t']
|
||||
logp4 = torch.sum(-0.5 * (o_mean**2) * o_scale,
|
||||
[1]).unsqueeze(-1) # [b, t, 1]
|
||||
logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale, [1]).unsqueeze(-1) # [b, t, 1]
|
||||
logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * (z ** 2)) # [b, t, d] x [b, d, t'] = [b, t, t']
|
||||
logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2), z) # [b, t, d] x [b, d, t'] = [b, t, t']
|
||||
logp4 = torch.sum(-0.5 * (o_mean ** 2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1]
|
||||
logp = logp1 + logp2 + logp3 + logp4 # [b, t, t']
|
||||
attn = maximum_path(logp,
|
||||
attn_mask.squeeze(1)).unsqueeze(1).detach()
|
||||
y_mean, y_log_scale, o_attn_dur = self.compute_outputs(
|
||||
attn, o_mean, o_log_scale, x_mask)
|
||||
attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()
|
||||
y_mean, y_log_scale, o_attn_dur = self.compute_outputs(attn, o_mean, o_log_scale, x_mask)
|
||||
attn = attn.squeeze(1).permute(0, 2, 1)
|
||||
outputs = {
|
||||
'model_outputs': z,
|
||||
'logdet': logdet,
|
||||
'y_mean': y_mean,
|
||||
'y_log_scale': y_log_scale,
|
||||
'alignments': attn,
|
||||
'durations_log': o_dur_log,
|
||||
'total_durations_log': o_attn_dur
|
||||
"model_outputs": z,
|
||||
"logdet": logdet,
|
||||
"y_mean": y_mean,
|
||||
"y_log_scale": y_log_scale,
|
||||
"alignments": attn,
|
||||
"durations_log": o_dur_log,
|
||||
"total_durations_log": o_attn_dur,
|
||||
}
|
||||
return outputs
|
||||
|
||||
@torch.no_grad()
|
||||
def inference_with_MAS(self,
|
||||
x,
|
||||
x_lengths,
|
||||
y=None,
|
||||
y_lengths=None,
|
||||
attn=None,
|
||||
g=None):
|
||||
def inference_with_MAS(self, x, x_lengths, y=None, y_lengths=None, attn=None, g=None):
|
||||
"""
|
||||
It's similar to the teacher forcing in Tacotron.
|
||||
It was proposed in: https://arxiv.org/abs/2104.05557
|
||||
|
@ -229,33 +214,24 @@ class GlowTTS(nn.Module):
|
|||
g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h, 1]
|
||||
|
||||
# embedding pass
|
||||
o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x,
|
||||
x_lengths,
|
||||
g=g)
|
||||
o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g)
|
||||
# drop redisual frames wrt num_squeeze and set y_lengths.
|
||||
y, y_lengths, y_max_length, attn = self.preprocess(
|
||||
y, y_lengths, y_max_length, None)
|
||||
y, y_lengths, y_max_length, attn = self.preprocess(y, y_lengths, y_max_length, None)
|
||||
# create masks
|
||||
y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length),
|
||||
1).to(x_mask.dtype)
|
||||
y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(x_mask.dtype)
|
||||
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
|
||||
# decoder pass
|
||||
z, logdet = self.decoder(y, y_mask, g=g, reverse=False)
|
||||
# find the alignment path between z and encoder output
|
||||
o_scale = torch.exp(-2 * o_log_scale)
|
||||
logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale,
|
||||
[1]).unsqueeze(-1) # [b, t, 1]
|
||||
logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 *
|
||||
(z**2)) # [b, t, d] x [b, d, t'] = [b, t, t']
|
||||
logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2),
|
||||
z) # [b, t, d] x [b, d, t'] = [b, t, t']
|
||||
logp4 = torch.sum(-0.5 * (o_mean**2) * o_scale,
|
||||
[1]).unsqueeze(-1) # [b, t, 1]
|
||||
logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale, [1]).unsqueeze(-1) # [b, t, 1]
|
||||
logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * (z ** 2)) # [b, t, d] x [b, d, t'] = [b, t, t']
|
||||
logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2), z) # [b, t, d] x [b, d, t'] = [b, t, t']
|
||||
logp4 = torch.sum(-0.5 * (o_mean ** 2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1]
|
||||
logp = logp1 + logp2 + logp3 + logp4 # [b, t, t']
|
||||
attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()
|
||||
|
||||
y_mean, y_log_scale, o_attn_dur = self.compute_outputs(
|
||||
attn, o_mean, o_log_scale, x_mask)
|
||||
y_mean, y_log_scale, o_attn_dur = self.compute_outputs(attn, o_mean, o_log_scale, x_mask)
|
||||
attn = attn.squeeze(1).permute(0, 2, 1)
|
||||
|
||||
# get predited aligned distribution
|
||||
|
@ -264,13 +240,13 @@ class GlowTTS(nn.Module):
|
|||
# reverse the decoder and predict using the aligned distribution
|
||||
y, logdet = self.decoder(z, y_mask, g=g, reverse=True)
|
||||
outputs = {
|
||||
'model_outputs': y,
|
||||
'logdet': logdet,
|
||||
'y_mean': y_mean,
|
||||
'y_log_scale': y_log_scale,
|
||||
'alignments': attn,
|
||||
'durations_log': o_dur_log,
|
||||
'total_durations_log': o_attn_dur
|
||||
"model_outputs": y,
|
||||
"logdet": logdet,
|
||||
"y_mean": y_mean,
|
||||
"y_log_scale": y_log_scale,
|
||||
"alignments": attn,
|
||||
"durations_log": o_dur_log,
|
||||
"total_durations_log": o_attn_dur,
|
||||
}
|
||||
return outputs
|
||||
|
||||
|
@ -290,8 +266,7 @@ class GlowTTS(nn.Module):
|
|||
else:
|
||||
g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h, 1]
|
||||
|
||||
y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length),
|
||||
1).to(y.dtype)
|
||||
y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(y.dtype)
|
||||
|
||||
# decoder pass
|
||||
z, logdet = self.decoder(y, y_mask, g=g, reverse=False)
|
||||
|
@ -310,37 +285,31 @@ class GlowTTS(nn.Module):
|
|||
g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h]
|
||||
|
||||
# embedding pass
|
||||
o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x,
|
||||
x_lengths,
|
||||
g=g)
|
||||
o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g)
|
||||
# compute output durations
|
||||
w = (torch.exp(o_dur_log) - 1) * x_mask * self.length_scale
|
||||
w_ceil = torch.ceil(w)
|
||||
y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
|
||||
y_max_length = None
|
||||
# compute masks
|
||||
y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length),
|
||||
1).to(x_mask.dtype)
|
||||
y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(x_mask.dtype)
|
||||
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
|
||||
# compute attention mask
|
||||
attn = generate_path(w_ceil.squeeze(1),
|
||||
attn_mask.squeeze(1)).unsqueeze(1)
|
||||
y_mean, y_log_scale, o_attn_dur = self.compute_outputs(
|
||||
attn, o_mean, o_log_scale, x_mask)
|
||||
attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1)
|
||||
y_mean, y_log_scale, o_attn_dur = self.compute_outputs(attn, o_mean, o_log_scale, x_mask)
|
||||
|
||||
z = (y_mean + torch.exp(y_log_scale) * torch.randn_like(y_mean) *
|
||||
self.inference_noise_scale) * y_mask
|
||||
z = (y_mean + torch.exp(y_log_scale) * torch.randn_like(y_mean) * self.inference_noise_scale) * y_mask
|
||||
# decoder pass
|
||||
y, logdet = self.decoder(z, y_mask, g=g, reverse=True)
|
||||
attn = attn.squeeze(1).permute(0, 2, 1)
|
||||
outputs = {
|
||||
'model_outputs': y,
|
||||
'logdet': logdet,
|
||||
'y_mean': y_mean,
|
||||
'y_log_scale': y_log_scale,
|
||||
'alignments': attn,
|
||||
'durations_log': o_dur_log,
|
||||
'total_durations_log': o_attn_dur
|
||||
"model_outputs": y,
|
||||
"logdet": logdet,
|
||||
"y_mean": y_mean,
|
||||
"y_log_scale": y_log_scale,
|
||||
"alignments": attn,
|
||||
"durations_log": o_dur_log,
|
||||
"total_durations_log": o_attn_dur,
|
||||
}
|
||||
return outputs
|
||||
|
||||
|
@ -351,32 +320,34 @@ class GlowTTS(nn.Module):
|
|||
batch (dict): [description]
|
||||
criterion (nn.Module): [description]
|
||||
"""
|
||||
text_input = batch['text_input']
|
||||
text_lengths = batch['text_lengths']
|
||||
mel_input = batch['mel_input']
|
||||
mel_lengths = batch['mel_lengths']
|
||||
x_vectors = batch['x_vectors']
|
||||
text_input = batch["text_input"]
|
||||
text_lengths = batch["text_lengths"]
|
||||
mel_input = batch["mel_input"]
|
||||
mel_lengths = batch["mel_lengths"]
|
||||
x_vectors = batch["x_vectors"]
|
||||
|
||||
outputs = self.forward(text_input,
|
||||
text_lengths,
|
||||
mel_input,
|
||||
mel_lengths,
|
||||
cond_input={"x_vectors": x_vectors})
|
||||
outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, cond_input={"x_vectors": x_vectors})
|
||||
|
||||
loss_dict = criterion(outputs['model_outputs'], outputs['y_mean'],
|
||||
outputs['y_log_scale'], outputs['logdet'],
|
||||
mel_lengths, outputs['durations_log'],
|
||||
outputs['total_durations_log'], text_lengths)
|
||||
loss_dict = criterion(
|
||||
outputs["model_outputs"],
|
||||
outputs["y_mean"],
|
||||
outputs["y_log_scale"],
|
||||
outputs["logdet"],
|
||||
mel_lengths,
|
||||
outputs["durations_log"],
|
||||
outputs["total_durations_log"],
|
||||
text_lengths,
|
||||
)
|
||||
|
||||
# compute alignment error (the lower the better )
|
||||
align_error = 1 - alignment_diagonal_score(outputs['alignments'], binary=True)
|
||||
# compute alignment error (the lower the better )
|
||||
align_error = 1 - alignment_diagonal_score(outputs["alignments"], binary=True)
|
||||
loss_dict["align_error"] = align_error
|
||||
return outputs, loss_dict
|
||||
|
||||
def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict):
|
||||
model_outputs = outputs['model_outputs']
|
||||
alignments = outputs['alignments']
|
||||
mel_input = batch['mel_input']
|
||||
model_outputs = outputs["model_outputs"]
|
||||
alignments = outputs["alignments"]
|
||||
mel_input = batch["mel_input"]
|
||||
|
||||
pred_spec = model_outputs[0].data.cpu().numpy()
|
||||
gt_spec = mel_input[0].data.cpu().numpy()
|
||||
|
@ -400,8 +371,7 @@ class GlowTTS(nn.Module):
|
|||
|
||||
def preprocess(self, y, y_lengths, y_max_length, attn=None):
|
||||
if y_max_length is not None:
|
||||
y_max_length = (y_max_length //
|
||||
self.num_squeeze) * self.num_squeeze
|
||||
y_max_length = (y_max_length // self.num_squeeze) * self.num_squeeze
|
||||
y = y[:, :, :y_max_length]
|
||||
if attn is not None:
|
||||
attn = attn[:, :, :, :y_max_length]
|
||||
|
@ -411,7 +381,9 @@ class GlowTTS(nn.Module):
|
|||
def store_inverse(self):
|
||||
self.decoder.store_inverse()
|
||||
|
||||
def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin
|
||||
def load_checkpoint(
|
||||
self, config, checkpoint_path, eval=False
|
||||
): # pylint: disable=unused-argument, redefined-builtin
|
||||
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
|
||||
self.load_state_dict(state["model"])
|
||||
if eval:
|
||||
|
|
|
@ -49,12 +49,7 @@ class SpeedySpeech(nn.Module):
|
|||
positional_encoding=True,
|
||||
length_scale=1,
|
||||
encoder_type="residual_conv_bn",
|
||||
encoder_params={
|
||||
"kernel_size": 4,
|
||||
"dilations": 4 * [1, 2, 4] + [1],
|
||||
"num_conv_blocks": 2,
|
||||
"num_res_blocks": 13
|
||||
},
|
||||
encoder_params={"kernel_size": 4, "dilations": 4 * [1, 2, 4] + [1], "num_conv_blocks": 2, "num_res_blocks": 13},
|
||||
decoder_type="residual_conv_bn",
|
||||
decoder_params={
|
||||
"kernel_size": 4,
|
||||
|
@ -68,17 +63,13 @@ class SpeedySpeech(nn.Module):
|
|||
):
|
||||
|
||||
super().__init__()
|
||||
self.length_scale = float(length_scale) if isinstance(
|
||||
length_scale, int) else length_scale
|
||||
self.length_scale = float(length_scale) if isinstance(length_scale, int) else length_scale
|
||||
self.emb = nn.Embedding(num_chars, hidden_channels)
|
||||
self.encoder = Encoder(hidden_channels, hidden_channels, encoder_type,
|
||||
encoder_params, c_in_channels)
|
||||
self.encoder = Encoder(hidden_channels, hidden_channels, encoder_type, encoder_params, c_in_channels)
|
||||
if positional_encoding:
|
||||
self.pos_encoder = PositionalEncoding(hidden_channels)
|
||||
self.decoder = Decoder(out_channels, hidden_channels, decoder_type,
|
||||
decoder_params)
|
||||
self.duration_predictor = DurationPredictor(hidden_channels +
|
||||
c_in_channels)
|
||||
self.decoder = Decoder(out_channels, hidden_channels, decoder_type, decoder_params)
|
||||
self.duration_predictor = DurationPredictor(hidden_channels + c_in_channels)
|
||||
|
||||
if num_speakers > 1 and not external_c:
|
||||
# speaker embedding layer
|
||||
|
@ -105,9 +96,7 @@ class SpeedySpeech(nn.Module):
|
|||
"""
|
||||
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
|
||||
attn = generate_path(dr, attn_mask.squeeze(1)).to(en.dtype)
|
||||
o_en_ex = torch.matmul(
|
||||
attn.squeeze(1).transpose(1, 2), en.transpose(1,
|
||||
2)).transpose(1, 2)
|
||||
o_en_ex = torch.matmul(attn.squeeze(1).transpose(1, 2), en.transpose(1, 2)).transpose(1, 2)
|
||||
return o_en_ex, attn
|
||||
|
||||
def format_durations(self, o_dr_log, x_mask):
|
||||
|
@ -141,8 +130,7 @@ class SpeedySpeech(nn.Module):
|
|||
x_emb = torch.transpose(x_emb, 1, -1)
|
||||
|
||||
# compute sequence masks
|
||||
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]),
|
||||
1).to(x.dtype)
|
||||
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype)
|
||||
|
||||
# encoder pass
|
||||
o_en = self.encoder(x_emb, x_mask)
|
||||
|
@ -155,8 +143,7 @@ class SpeedySpeech(nn.Module):
|
|||
return o_en, o_en_dp, x_mask, g
|
||||
|
||||
def _forward_decoder(self, o_en, o_en_dp, dr, x_mask, y_lengths, g):
|
||||
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None),
|
||||
1).to(o_en_dp.dtype)
|
||||
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en_dp.dtype)
|
||||
# expand o_en with durations
|
||||
o_en_ex, attn = self.expand_encoder_outputs(o_en, dr, x_mask, y_mask)
|
||||
# positional encoding
|
||||
|
@ -169,15 +156,9 @@ class SpeedySpeech(nn.Module):
|
|||
o_de = self.decoder(o_en_ex, y_mask, g=g)
|
||||
return o_de, attn.transpose(1, 2)
|
||||
|
||||
def forward(self,
|
||||
x,
|
||||
x_lengths,
|
||||
y_lengths,
|
||||
dr,
|
||||
cond_input={
|
||||
'x_vectors': None,
|
||||
'speaker_ids': None
|
||||
}): # pylint: disable=unused-argument
|
||||
def forward(
|
||||
self, x, x_lengths, y_lengths, dr, cond_input={"x_vectors": None, "speaker_ids": None}
|
||||
): # pylint: disable=unused-argument
|
||||
"""
|
||||
TODO: speaker embedding for speaker_ids
|
||||
Shapes:
|
||||
|
@ -187,91 +168,68 @@ class SpeedySpeech(nn.Module):
|
|||
dr: [B, T_max]
|
||||
g: [B, C]
|
||||
"""
|
||||
g = cond_input['x_vectors'] if 'x_vectors' in cond_input else None
|
||||
g = cond_input["x_vectors"] if "x_vectors" in cond_input else None
|
||||
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
|
||||
o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask)
|
||||
o_de, attn = self._forward_decoder(o_en,
|
||||
o_en_dp,
|
||||
dr,
|
||||
x_mask,
|
||||
y_lengths,
|
||||
g=g)
|
||||
outputs = {
|
||||
'model_outputs': o_de.transpose(1, 2),
|
||||
'durations_log': o_dr_log.squeeze(1),
|
||||
'alignments': attn
|
||||
}
|
||||
o_de, attn = self._forward_decoder(o_en, o_en_dp, dr, x_mask, y_lengths, g=g)
|
||||
outputs = {"model_outputs": o_de.transpose(1, 2), "durations_log": o_dr_log.squeeze(1), "alignments": attn}
|
||||
return outputs
|
||||
|
||||
def inference(self,
|
||||
x,
|
||||
cond_input={
|
||||
'x_vectors': None,
|
||||
'speaker_ids': None
|
||||
}): # pylint: disable=unused-argument
|
||||
def inference(self, x, cond_input={"x_vectors": None, "speaker_ids": None}): # pylint: disable=unused-argument
|
||||
"""
|
||||
Shapes:
|
||||
x: [B, T_max]
|
||||
x_lengths: [B]
|
||||
g: [B, C]
|
||||
"""
|
||||
g = cond_input['x_vectors'] if 'x_vectors' in cond_input else None
|
||||
g = cond_input["x_vectors"] if "x_vectors" in cond_input else None
|
||||
x_lengths = torch.tensor(x.shape[1:2]).to(x.device)
|
||||
# input sequence should be greated than the max convolution size
|
||||
inference_padding = 5
|
||||
if x.shape[1] < 13:
|
||||
inference_padding += 13 - x.shape[1]
|
||||
# pad input to prevent dropping the last word
|
||||
x = torch.nn.functional.pad(x,
|
||||
pad=(0, inference_padding),
|
||||
mode="constant",
|
||||
value=0)
|
||||
x = torch.nn.functional.pad(x, pad=(0, inference_padding), mode="constant", value=0)
|
||||
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
|
||||
# duration predictor pass
|
||||
o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask)
|
||||
o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1)
|
||||
y_lengths = o_dr.sum(1)
|
||||
o_de, attn = self._forward_decoder(o_en,
|
||||
o_en_dp,
|
||||
o_dr,
|
||||
x_mask,
|
||||
y_lengths,
|
||||
g=g)
|
||||
outputs = {
|
||||
'model_outputs': o_de.transpose(1, 2),
|
||||
'alignments': attn,
|
||||
'durations_log': None
|
||||
}
|
||||
o_de, attn = self._forward_decoder(o_en, o_en_dp, o_dr, x_mask, y_lengths, g=g)
|
||||
outputs = {"model_outputs": o_de.transpose(1, 2), "alignments": attn, "durations_log": None}
|
||||
return outputs
|
||||
|
||||
def train_step(self, batch: dict, criterion: nn.Module):
|
||||
text_input = batch['text_input']
|
||||
text_lengths = batch['text_lengths']
|
||||
mel_input = batch['mel_input']
|
||||
mel_lengths = batch['mel_lengths']
|
||||
x_vectors = batch['x_vectors']
|
||||
speaker_ids = batch['speaker_ids']
|
||||
durations = batch['durations']
|
||||
text_input = batch["text_input"]
|
||||
text_lengths = batch["text_lengths"]
|
||||
mel_input = batch["mel_input"]
|
||||
mel_lengths = batch["mel_lengths"]
|
||||
x_vectors = batch["x_vectors"]
|
||||
speaker_ids = batch["speaker_ids"]
|
||||
durations = batch["durations"]
|
||||
|
||||
cond_input = {'x_vectors': x_vectors, 'speaker_ids': speaker_ids}
|
||||
outputs = self.forward(text_input, text_lengths, mel_lengths,
|
||||
durations, cond_input)
|
||||
cond_input = {"x_vectors": x_vectors, "speaker_ids": speaker_ids}
|
||||
outputs = self.forward(text_input, text_lengths, mel_lengths, durations, cond_input)
|
||||
|
||||
# compute loss
|
||||
loss_dict = criterion(outputs['model_outputs'], mel_input,
|
||||
mel_lengths, outputs['durations_log'],
|
||||
torch.log(1 + durations), text_lengths)
|
||||
loss_dict = criterion(
|
||||
outputs["model_outputs"],
|
||||
mel_input,
|
||||
mel_lengths,
|
||||
outputs["durations_log"],
|
||||
torch.log(1 + durations),
|
||||
text_lengths,
|
||||
)
|
||||
|
||||
# compute alignment error (the lower the better )
|
||||
align_error = 1 - alignment_diagonal_score(outputs['alignments'],
|
||||
binary=True)
|
||||
align_error = 1 - alignment_diagonal_score(outputs["alignments"], binary=True)
|
||||
loss_dict["align_error"] = align_error
|
||||
return outputs, loss_dict
|
||||
|
||||
def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict):
|
||||
model_outputs = outputs['model_outputs']
|
||||
alignments = outputs['alignments']
|
||||
mel_input = batch['mel_input']
|
||||
model_outputs = outputs["model_outputs"]
|
||||
alignments = outputs["alignments"]
|
||||
mel_input = batch["mel_input"]
|
||||
|
||||
pred_spec = model_outputs[0].data.cpu().numpy()
|
||||
gt_spec = mel_input[0].data.cpu().numpy()
|
||||
|
@ -293,7 +251,9 @@ class SpeedySpeech(nn.Module):
|
|||
def eval_log(self, ap: AudioProcessor, batch: dict, outputs: dict):
|
||||
return self.train_log(ap, batch, outputs)
|
||||
|
||||
def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin
|
||||
def load_checkpoint(
|
||||
self, config, checkpoint_path, eval=False
|
||||
): # pylint: disable=unused-argument, redefined-builtin
|
||||
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
|
||||
self.load_state_dict(state["model"])
|
||||
if eval:
|
||||
|
|
|
@ -50,6 +50,7 @@ class Tacotron(TacotronAbstract):
|
|||
gradual_trainin (List): Gradual training schedule. If None or `[]`, no gradual training is used.
|
||||
Defaults to `[]`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_chars,
|
||||
|
@ -78,7 +79,7 @@ class Tacotron(TacotronAbstract):
|
|||
use_gst=False,
|
||||
gst=None,
|
||||
memory_size=5,
|
||||
gradual_training=[]
|
||||
gradual_training=[],
|
||||
):
|
||||
super().__init__(
|
||||
num_chars,
|
||||
|
@ -106,15 +107,14 @@ class Tacotron(TacotronAbstract):
|
|||
speaker_embedding_dim,
|
||||
use_gst,
|
||||
gst,
|
||||
gradual_training
|
||||
gradual_training,
|
||||
)
|
||||
|
||||
# speaker embedding layers
|
||||
if self.num_speakers > 1:
|
||||
if not self.embeddings_per_sample:
|
||||
speaker_embedding_dim = 256
|
||||
self.speaker_embedding = nn.Embedding(self.num_speakers,
|
||||
speaker_embedding_dim)
|
||||
self.speaker_embedding = nn.Embedding(self.num_speakers, speaker_embedding_dim)
|
||||
self.speaker_embedding.weight.data.normal_(0, 0.3)
|
||||
|
||||
# speaker and gst embeddings is concat in decoder input
|
||||
|
@ -145,8 +145,7 @@ class Tacotron(TacotronAbstract):
|
|||
separate_stopnet,
|
||||
)
|
||||
self.postnet = PostCBHG(decoder_output_dim)
|
||||
self.last_linear = nn.Linear(self.postnet.cbhg.gru_features * 2,
|
||||
postnet_output_dim)
|
||||
self.last_linear = nn.Linear(self.postnet.cbhg.gru_features * 2, postnet_output_dim)
|
||||
|
||||
# setup prenet dropout
|
||||
self.decoder.prenet.dropout_at_inference = prenet_dropout_at_inference
|
||||
|
@ -183,12 +182,7 @@ class Tacotron(TacotronAbstract):
|
|||
separate_stopnet,
|
||||
)
|
||||
|
||||
def forward(self,
|
||||
text,
|
||||
text_lengths,
|
||||
mel_specs=None,
|
||||
mel_lengths=None,
|
||||
cond_input=None):
|
||||
def forward(self, text, text_lengths, mel_specs=None, mel_lengths=None, cond_input=None):
|
||||
"""
|
||||
Shapes:
|
||||
text: [B, T_in]
|
||||
|
@ -197,100 +191,87 @@ class Tacotron(TacotronAbstract):
|
|||
mel_lengths: [B]
|
||||
cond_input: 'speaker_ids': [B, 1] and 'x_vectors':[B, C]
|
||||
"""
|
||||
outputs = {
|
||||
'alignments_backward': None,
|
||||
'decoder_outputs_backward': None
|
||||
}
|
||||
outputs = {"alignments_backward": None, "decoder_outputs_backward": None}
|
||||
input_mask, output_mask = self.compute_masks(text_lengths, mel_lengths)
|
||||
# B x T_in x embed_dim
|
||||
inputs = self.embedding(text)
|
||||
# B x T_in x encoder_in_features
|
||||
encoder_outputs = self.encoder(inputs)
|
||||
# sequence masking
|
||||
encoder_outputs = encoder_outputs * input_mask.unsqueeze(2).expand_as(
|
||||
encoder_outputs)
|
||||
encoder_outputs = encoder_outputs * input_mask.unsqueeze(2).expand_as(encoder_outputs)
|
||||
# global style token
|
||||
if self.gst and self.use_gst:
|
||||
# B x gst_dim
|
||||
encoder_outputs = self.compute_gst(encoder_outputs, mel_specs,
|
||||
cond_input['x_vectors'])
|
||||
encoder_outputs = self.compute_gst(encoder_outputs, mel_specs, cond_input["x_vectors"])
|
||||
# speaker embedding
|
||||
if self.num_speakers > 1:
|
||||
if not self.embeddings_per_sample:
|
||||
# B x 1 x speaker_embed_dim
|
||||
speaker_embeddings = self.speaker_embedding(cond_input['speaker_ids'])[:,
|
||||
None]
|
||||
speaker_embeddings = self.speaker_embedding(cond_input["speaker_ids"])[:, None]
|
||||
else:
|
||||
# B x 1 x speaker_embed_dim
|
||||
speaker_embeddings = torch.unsqueeze(cond_input['x_vectors'], 1)
|
||||
encoder_outputs = self._concat_speaker_embedding(
|
||||
encoder_outputs, speaker_embeddings)
|
||||
speaker_embeddings = torch.unsqueeze(cond_input["x_vectors"], 1)
|
||||
encoder_outputs = self._concat_speaker_embedding(encoder_outputs, speaker_embeddings)
|
||||
# decoder_outputs: B x decoder_in_features x T_out
|
||||
# alignments: B x T_in x encoder_in_features
|
||||
# stop_tokens: B x T_in
|
||||
decoder_outputs, alignments, stop_tokens = self.decoder(
|
||||
encoder_outputs, mel_specs, input_mask)
|
||||
decoder_outputs, alignments, stop_tokens = self.decoder(encoder_outputs, mel_specs, input_mask)
|
||||
# sequence masking
|
||||
if output_mask is not None:
|
||||
decoder_outputs = decoder_outputs * output_mask.unsqueeze(
|
||||
1).expand_as(decoder_outputs)
|
||||
decoder_outputs = decoder_outputs * output_mask.unsqueeze(1).expand_as(decoder_outputs)
|
||||
# B x T_out x decoder_in_features
|
||||
postnet_outputs = self.postnet(decoder_outputs)
|
||||
# sequence masking
|
||||
if output_mask is not None:
|
||||
postnet_outputs = postnet_outputs * output_mask.unsqueeze(
|
||||
2).expand_as(postnet_outputs)
|
||||
postnet_outputs = postnet_outputs * output_mask.unsqueeze(2).expand_as(postnet_outputs)
|
||||
# B x T_out x posnet_dim
|
||||
postnet_outputs = self.last_linear(postnet_outputs)
|
||||
# B x T_out x decoder_in_features
|
||||
decoder_outputs = decoder_outputs.transpose(1, 2).contiguous()
|
||||
if self.bidirectional_decoder:
|
||||
decoder_outputs_backward, alignments_backward = self._backward_pass(
|
||||
mel_specs, encoder_outputs, input_mask)
|
||||
outputs['alignments_backward'] = alignments_backward
|
||||
outputs['decoder_outputs_backward'] = decoder_outputs_backward
|
||||
decoder_outputs_backward, alignments_backward = self._backward_pass(mel_specs, encoder_outputs, input_mask)
|
||||
outputs["alignments_backward"] = alignments_backward
|
||||
outputs["decoder_outputs_backward"] = decoder_outputs_backward
|
||||
if self.double_decoder_consistency:
|
||||
decoder_outputs_backward, alignments_backward = self._coarse_decoder_pass(
|
||||
mel_specs, encoder_outputs, alignments, input_mask)
|
||||
outputs['alignments_backward'] = alignments_backward
|
||||
outputs['decoder_outputs_backward'] = decoder_outputs_backward
|
||||
outputs.update({
|
||||
'model_outputs': postnet_outputs,
|
||||
'decoder_outputs': decoder_outputs,
|
||||
'alignments': alignments,
|
||||
'stop_tokens': stop_tokens
|
||||
})
|
||||
mel_specs, encoder_outputs, alignments, input_mask
|
||||
)
|
||||
outputs["alignments_backward"] = alignments_backward
|
||||
outputs["decoder_outputs_backward"] = decoder_outputs_backward
|
||||
outputs.update(
|
||||
{
|
||||
"model_outputs": postnet_outputs,
|
||||
"decoder_outputs": decoder_outputs,
|
||||
"alignments": alignments,
|
||||
"stop_tokens": stop_tokens,
|
||||
}
|
||||
)
|
||||
return outputs
|
||||
|
||||
@torch.no_grad()
|
||||
def inference(self,
|
||||
text_input,
|
||||
cond_input=None):
|
||||
def inference(self, text_input, cond_input=None):
|
||||
inputs = self.embedding(text_input)
|
||||
encoder_outputs = self.encoder(inputs)
|
||||
if self.gst and self.use_gst:
|
||||
# B x gst_dim
|
||||
encoder_outputs = self.compute_gst(encoder_outputs, cond_input['style_mel'],
|
||||
cond_input['x_vectors'])
|
||||
encoder_outputs = self.compute_gst(encoder_outputs, cond_input["style_mel"], cond_input["x_vectors"])
|
||||
if self.num_speakers > 1:
|
||||
if not self.embeddings_per_sample:
|
||||
# B x 1 x speaker_embed_dim
|
||||
speaker_embeddings = self.speaker_embedding(cond_input['speaker_ids'])[:, None]
|
||||
speaker_embeddings = self.speaker_embedding(cond_input["speaker_ids"])[:, None]
|
||||
else:
|
||||
# B x 1 x speaker_embed_dim
|
||||
speaker_embeddings = torch.unsqueeze(cond_input['x_vectors'], 1)
|
||||
encoder_outputs = self._concat_speaker_embedding(
|
||||
encoder_outputs, speaker_embeddings)
|
||||
decoder_outputs, alignments, stop_tokens = self.decoder.inference(
|
||||
encoder_outputs)
|
||||
speaker_embeddings = torch.unsqueeze(cond_input["x_vectors"], 1)
|
||||
encoder_outputs = self._concat_speaker_embedding(encoder_outputs, speaker_embeddings)
|
||||
decoder_outputs, alignments, stop_tokens = self.decoder.inference(encoder_outputs)
|
||||
postnet_outputs = self.postnet(decoder_outputs)
|
||||
postnet_outputs = self.last_linear(postnet_outputs)
|
||||
decoder_outputs = decoder_outputs.transpose(1, 2)
|
||||
outputs = {
|
||||
'model_outputs': postnet_outputs,
|
||||
'decoder_outputs': decoder_outputs,
|
||||
'alignments': alignments,
|
||||
'stop_tokens': stop_tokens
|
||||
"model_outputs": postnet_outputs,
|
||||
"decoder_outputs": decoder_outputs,
|
||||
"alignments": alignments,
|
||||
"stop_tokens": stop_tokens,
|
||||
}
|
||||
return outputs
|
||||
|
||||
|
@ -301,64 +282,61 @@ class Tacotron(TacotronAbstract):
|
|||
batch ([type]): [description]
|
||||
criterion ([type]): [description]
|
||||
"""
|
||||
text_input = batch['text_input']
|
||||
text_lengths = batch['text_lengths']
|
||||
mel_input = batch['mel_input']
|
||||
mel_lengths = batch['mel_lengths']
|
||||
linear_input = batch['linear_input']
|
||||
stop_targets = batch['stop_targets']
|
||||
speaker_ids = batch['speaker_ids']
|
||||
x_vectors = batch['x_vectors']
|
||||
text_input = batch["text_input"]
|
||||
text_lengths = batch["text_lengths"]
|
||||
mel_input = batch["mel_input"]
|
||||
mel_lengths = batch["mel_lengths"]
|
||||
linear_input = batch["linear_input"]
|
||||
stop_targets = batch["stop_targets"]
|
||||
speaker_ids = batch["speaker_ids"]
|
||||
x_vectors = batch["x_vectors"]
|
||||
|
||||
# forward pass model
|
||||
outputs = self.forward(text_input,
|
||||
text_lengths,
|
||||
mel_input,
|
||||
mel_lengths,
|
||||
cond_input={
|
||||
'speaker_ids': speaker_ids,
|
||||
'x_vectors': x_vectors
|
||||
})
|
||||
outputs = self.forward(
|
||||
text_input,
|
||||
text_lengths,
|
||||
mel_input,
|
||||
mel_lengths,
|
||||
cond_input={"speaker_ids": speaker_ids, "x_vectors": x_vectors},
|
||||
)
|
||||
|
||||
# set the [alignment] lengths wrt reduction factor for guided attention
|
||||
if mel_lengths.max() % self.decoder.r != 0:
|
||||
alignment_lengths = (
|
||||
mel_lengths +
|
||||
(self.decoder.r -
|
||||
(mel_lengths.max() % self.decoder.r))) // self.decoder.r
|
||||
mel_lengths + (self.decoder.r - (mel_lengths.max() % self.decoder.r))
|
||||
) // self.decoder.r
|
||||
else:
|
||||
alignment_lengths = mel_lengths // self.decoder.r
|
||||
|
||||
cond_input = {'speaker_ids': speaker_ids, 'x_vectors': x_vectors}
|
||||
outputs = self.forward(text_input, text_lengths, mel_input,
|
||||
mel_lengths, cond_input)
|
||||
cond_input = {"speaker_ids": speaker_ids, "x_vectors": x_vectors}
|
||||
outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, cond_input)
|
||||
|
||||
# compute loss
|
||||
loss_dict = criterion(
|
||||
outputs['model_outputs'],
|
||||
outputs['decoder_outputs'],
|
||||
outputs["model_outputs"],
|
||||
outputs["decoder_outputs"],
|
||||
mel_input,
|
||||
linear_input,
|
||||
outputs['stop_tokens'],
|
||||
outputs["stop_tokens"],
|
||||
stop_targets,
|
||||
mel_lengths,
|
||||
outputs['decoder_outputs_backward'],
|
||||
outputs['alignments'],
|
||||
outputs["decoder_outputs_backward"],
|
||||
outputs["alignments"],
|
||||
alignment_lengths,
|
||||
outputs['alignments_backward'],
|
||||
outputs["alignments_backward"],
|
||||
text_lengths,
|
||||
)
|
||||
|
||||
# compute alignment error (the lower the better )
|
||||
align_error = 1 - alignment_diagonal_score(outputs['alignments'])
|
||||
align_error = 1 - alignment_diagonal_score(outputs["alignments"])
|
||||
loss_dict["align_error"] = align_error
|
||||
return outputs, loss_dict
|
||||
|
||||
def train_log(self, ap, batch, outputs):
|
||||
postnet_outputs = outputs['model_outputs']
|
||||
alignments = outputs['alignments']
|
||||
alignments_backward = outputs['alignments_backward']
|
||||
mel_input = batch['mel_input']
|
||||
postnet_outputs = outputs["model_outputs"]
|
||||
alignments = outputs["alignments"]
|
||||
alignments_backward = outputs["alignments_backward"]
|
||||
mel_input = batch["mel_input"]
|
||||
|
||||
pred_spec = postnet_outputs[0].data.cpu().numpy()
|
||||
gt_spec = mel_input[0].data.cpu().numpy()
|
||||
|
@ -371,8 +349,7 @@ class Tacotron(TacotronAbstract):
|
|||
}
|
||||
|
||||
if self.bidirectional_decoder or self.double_decoder_consistency:
|
||||
figures["alignment_backward"] = plot_alignment(
|
||||
alignments_backward[0].data.cpu().numpy(), output_fig=False)
|
||||
figures["alignment_backward"] = plot_alignment(alignments_backward[0].data.cpu().numpy(), output_fig=False)
|
||||
|
||||
# Sample audio
|
||||
train_audio = ap.inv_spectrogram(pred_spec.T)
|
||||
|
@ -382,4 +359,4 @@ class Tacotron(TacotronAbstract):
|
|||
return self.train_step(batch, criterion)
|
||||
|
||||
def eval_log(self, ap, batch, outputs):
|
||||
return self.train_log(ap, batch, outputs)
|
||||
return self.train_log(ap, batch, outputs)
|
||||
|
|
|
@ -49,49 +49,70 @@ class Tacotron2(TacotronAbstract):
|
|||
gradual_trainin (List): Gradual training schedule. If None or `[]`, no gradual training is used.
|
||||
Defaults to `[]`.
|
||||
"""
|
||||
def __init__(self,
|
||||
num_chars,
|
||||
num_speakers,
|
||||
r,
|
||||
postnet_output_dim=80,
|
||||
decoder_output_dim=80,
|
||||
attn_type="original",
|
||||
attn_win=False,
|
||||
attn_norm="softmax",
|
||||
prenet_type="original",
|
||||
prenet_dropout=True,
|
||||
prenet_dropout_at_inference=False,
|
||||
forward_attn=False,
|
||||
trans_agent=False,
|
||||
forward_attn_mask=False,
|
||||
location_attn=True,
|
||||
attn_K=5,
|
||||
separate_stopnet=True,
|
||||
bidirectional_decoder=False,
|
||||
double_decoder_consistency=False,
|
||||
ddc_r=None,
|
||||
encoder_in_features=512,
|
||||
decoder_in_features=512,
|
||||
speaker_embedding_dim=None,
|
||||
use_gst=False,
|
||||
gst=None,
|
||||
gradual_training=[]):
|
||||
super().__init__(num_chars, num_speakers, r, postnet_output_dim,
|
||||
decoder_output_dim, attn_type, attn_win, attn_norm,
|
||||
prenet_type, prenet_dropout,
|
||||
prenet_dropout_at_inference, forward_attn,
|
||||
trans_agent, forward_attn_mask, location_attn, attn_K,
|
||||
separate_stopnet, bidirectional_decoder,
|
||||
double_decoder_consistency, ddc_r,
|
||||
encoder_in_features, decoder_in_features,
|
||||
speaker_embedding_dim, use_gst, gst, gradual_training)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_chars,
|
||||
num_speakers,
|
||||
r,
|
||||
postnet_output_dim=80,
|
||||
decoder_output_dim=80,
|
||||
attn_type="original",
|
||||
attn_win=False,
|
||||
attn_norm="softmax",
|
||||
prenet_type="original",
|
||||
prenet_dropout=True,
|
||||
prenet_dropout_at_inference=False,
|
||||
forward_attn=False,
|
||||
trans_agent=False,
|
||||
forward_attn_mask=False,
|
||||
location_attn=True,
|
||||
attn_K=5,
|
||||
separate_stopnet=True,
|
||||
bidirectional_decoder=False,
|
||||
double_decoder_consistency=False,
|
||||
ddc_r=None,
|
||||
encoder_in_features=512,
|
||||
decoder_in_features=512,
|
||||
speaker_embedding_dim=None,
|
||||
use_gst=False,
|
||||
gst=None,
|
||||
gradual_training=[],
|
||||
):
|
||||
super().__init__(
|
||||
num_chars,
|
||||
num_speakers,
|
||||
r,
|
||||
postnet_output_dim,
|
||||
decoder_output_dim,
|
||||
attn_type,
|
||||
attn_win,
|
||||
attn_norm,
|
||||
prenet_type,
|
||||
prenet_dropout,
|
||||
prenet_dropout_at_inference,
|
||||
forward_attn,
|
||||
trans_agent,
|
||||
forward_attn_mask,
|
||||
location_attn,
|
||||
attn_K,
|
||||
separate_stopnet,
|
||||
bidirectional_decoder,
|
||||
double_decoder_consistency,
|
||||
ddc_r,
|
||||
encoder_in_features,
|
||||
decoder_in_features,
|
||||
speaker_embedding_dim,
|
||||
use_gst,
|
||||
gst,
|
||||
gradual_training,
|
||||
)
|
||||
|
||||
# speaker embedding layer
|
||||
if self.num_speakers > 1:
|
||||
if not self.embeddings_per_sample:
|
||||
speaker_embedding_dim = 512
|
||||
self.speaker_embedding = nn.Embedding(self.num_speakers,
|
||||
speaker_embedding_dim)
|
||||
self.speaker_embedding = nn.Embedding(self.num_speakers, speaker_embedding_dim)
|
||||
self.speaker_embedding.weight.data.normal_(0, 0.3)
|
||||
|
||||
# speaker and gst embeddings is concat in decoder input
|
||||
|
@ -162,12 +183,7 @@ class Tacotron2(TacotronAbstract):
|
|||
mel_outputs_postnet = mel_outputs_postnet.transpose(1, 2)
|
||||
return mel_outputs, mel_outputs_postnet, alignments
|
||||
|
||||
def forward(self,
|
||||
text,
|
||||
text_lengths,
|
||||
mel_specs=None,
|
||||
mel_lengths=None,
|
||||
cond_input=None):
|
||||
def forward(self, text, text_lengths, mel_specs=None, mel_lengths=None, cond_input=None):
|
||||
"""
|
||||
Shapes:
|
||||
text: [B, T_in]
|
||||
|
@ -176,10 +192,7 @@ class Tacotron2(TacotronAbstract):
|
|||
mel_lengths: [B]
|
||||
cond_input: 'speaker_ids': [B, 1] and 'x_vectors':[B, C]
|
||||
"""
|
||||
outputs = {
|
||||
'alignments_backward': None,
|
||||
'decoder_outputs_backward': None
|
||||
}
|
||||
outputs = {"alignments_backward": None, "decoder_outputs_backward": None}
|
||||
# compute mask for padding
|
||||
# B x T_in_max (boolean)
|
||||
input_mask, output_mask = self.compute_masks(text_lengths, mel_lengths)
|
||||
|
@ -189,55 +202,49 @@ class Tacotron2(TacotronAbstract):
|
|||
encoder_outputs = self.encoder(embedded_inputs, text_lengths)
|
||||
if self.gst and self.use_gst:
|
||||
# B x gst_dim
|
||||
encoder_outputs = self.compute_gst(encoder_outputs, mel_specs,
|
||||
cond_input['x_vectors'])
|
||||
encoder_outputs = self.compute_gst(encoder_outputs, mel_specs, cond_input["x_vectors"])
|
||||
if self.num_speakers > 1:
|
||||
if not self.embeddings_per_sample:
|
||||
# B x 1 x speaker_embed_dim
|
||||
speaker_embeddings = self.speaker_embedding(cond_input['speaker_ids'])[:,
|
||||
None]
|
||||
speaker_embeddings = self.speaker_embedding(cond_input["speaker_ids"])[:, None]
|
||||
else:
|
||||
# B x 1 x speaker_embed_dim
|
||||
speaker_embeddings = torch.unsqueeze(cond_input['x_vectors'], 1)
|
||||
encoder_outputs = self._concat_speaker_embedding(
|
||||
encoder_outputs, speaker_embeddings)
|
||||
speaker_embeddings = torch.unsqueeze(cond_input["x_vectors"], 1)
|
||||
encoder_outputs = self._concat_speaker_embedding(encoder_outputs, speaker_embeddings)
|
||||
|
||||
encoder_outputs = encoder_outputs * input_mask.unsqueeze(2).expand_as(
|
||||
encoder_outputs)
|
||||
encoder_outputs = encoder_outputs * input_mask.unsqueeze(2).expand_as(encoder_outputs)
|
||||
|
||||
# B x mel_dim x T_out -- B x T_out//r x T_in -- B x T_out//r
|
||||
decoder_outputs, alignments, stop_tokens = self.decoder(
|
||||
encoder_outputs, mel_specs, input_mask)
|
||||
decoder_outputs, alignments, stop_tokens = self.decoder(encoder_outputs, mel_specs, input_mask)
|
||||
# sequence masking
|
||||
if mel_lengths is not None:
|
||||
decoder_outputs = decoder_outputs * output_mask.unsqueeze(
|
||||
1).expand_as(decoder_outputs)
|
||||
decoder_outputs = decoder_outputs * output_mask.unsqueeze(1).expand_as(decoder_outputs)
|
||||
# B x mel_dim x T_out
|
||||
postnet_outputs = self.postnet(decoder_outputs)
|
||||
postnet_outputs = decoder_outputs + postnet_outputs
|
||||
# sequence masking
|
||||
if output_mask is not None:
|
||||
postnet_outputs = postnet_outputs * output_mask.unsqueeze(
|
||||
1).expand_as(postnet_outputs)
|
||||
postnet_outputs = postnet_outputs * output_mask.unsqueeze(1).expand_as(postnet_outputs)
|
||||
# B x T_out x mel_dim -- B x T_out x mel_dim -- B x T_out//r x T_in
|
||||
decoder_outputs, postnet_outputs, alignments = self.shape_outputs(
|
||||
decoder_outputs, postnet_outputs, alignments)
|
||||
decoder_outputs, postnet_outputs, alignments = self.shape_outputs(decoder_outputs, postnet_outputs, alignments)
|
||||
if self.bidirectional_decoder:
|
||||
decoder_outputs_backward, alignments_backward = self._backward_pass(
|
||||
mel_specs, encoder_outputs, input_mask)
|
||||
outputs['alignments_backward'] = alignments_backward
|
||||
outputs['decoder_outputs_backward'] = decoder_outputs_backward
|
||||
decoder_outputs_backward, alignments_backward = self._backward_pass(mel_specs, encoder_outputs, input_mask)
|
||||
outputs["alignments_backward"] = alignments_backward
|
||||
outputs["decoder_outputs_backward"] = decoder_outputs_backward
|
||||
if self.double_decoder_consistency:
|
||||
decoder_outputs_backward, alignments_backward = self._coarse_decoder_pass(
|
||||
mel_specs, encoder_outputs, alignments, input_mask)
|
||||
outputs['alignments_backward'] = alignments_backward
|
||||
outputs['decoder_outputs_backward'] = decoder_outputs_backward
|
||||
outputs.update({
|
||||
'model_outputs': postnet_outputs,
|
||||
'decoder_outputs': decoder_outputs,
|
||||
'alignments': alignments,
|
||||
'stop_tokens': stop_tokens
|
||||
})
|
||||
mel_specs, encoder_outputs, alignments, input_mask
|
||||
)
|
||||
outputs["alignments_backward"] = alignments_backward
|
||||
outputs["decoder_outputs_backward"] = decoder_outputs_backward
|
||||
outputs.update(
|
||||
{
|
||||
"model_outputs": postnet_outputs,
|
||||
"decoder_outputs": decoder_outputs,
|
||||
"alignments": alignments,
|
||||
"stop_tokens": stop_tokens,
|
||||
}
|
||||
)
|
||||
return outputs
|
||||
|
||||
@torch.no_grad()
|
||||
|
@ -247,29 +254,25 @@ class Tacotron2(TacotronAbstract):
|
|||
|
||||
if self.gst and self.use_gst:
|
||||
# B x gst_dim
|
||||
encoder_outputs = self.compute_gst(encoder_outputs, cond_input['style_mel'],
|
||||
cond_input['x_vectors'])
|
||||
encoder_outputs = self.compute_gst(encoder_outputs, cond_input["style_mel"], cond_input["x_vectors"])
|
||||
if self.num_speakers > 1:
|
||||
if not self.embeddings_per_sample:
|
||||
x_vector = self.speaker_embedding(cond_input['speaker_ids'])[:, None]
|
||||
x_vector = torch.unsqueeze(x_vector, 0).transpose(1, 2)
|
||||
else:
|
||||
x_vector = cond_input['x_vectors']
|
||||
x_vector = cond_input["x_vectors"]
|
||||
|
||||
encoder_outputs = self._concat_speaker_embedding(
|
||||
encoder_outputs, x_vector)
|
||||
encoder_outputs = self._concat_speaker_embedding(encoder_outputs, x_vector)
|
||||
|
||||
decoder_outputs, alignments, stop_tokens = self.decoder.inference(
|
||||
encoder_outputs)
|
||||
decoder_outputs, alignments, stop_tokens = self.decoder.inference(encoder_outputs)
|
||||
postnet_outputs = self.postnet(decoder_outputs)
|
||||
postnet_outputs = decoder_outputs + postnet_outputs
|
||||
decoder_outputs, postnet_outputs, alignments = self.shape_outputs(
|
||||
decoder_outputs, postnet_outputs, alignments)
|
||||
decoder_outputs, postnet_outputs, alignments = self.shape_outputs(decoder_outputs, postnet_outputs, alignments)
|
||||
outputs = {
|
||||
'model_outputs': postnet_outputs,
|
||||
'decoder_outputs': decoder_outputs,
|
||||
'alignments': alignments,
|
||||
'stop_tokens': stop_tokens
|
||||
"model_outputs": postnet_outputs,
|
||||
"decoder_outputs": decoder_outputs,
|
||||
"alignments": alignments,
|
||||
"stop_tokens": stop_tokens,
|
||||
}
|
||||
return outputs
|
||||
|
||||
|
@ -280,64 +283,61 @@ class Tacotron2(TacotronAbstract):
|
|||
batch ([type]): [description]
|
||||
criterion ([type]): [description]
|
||||
"""
|
||||
text_input = batch['text_input']
|
||||
text_lengths = batch['text_lengths']
|
||||
mel_input = batch['mel_input']
|
||||
mel_lengths = batch['mel_lengths']
|
||||
linear_input = batch['linear_input']
|
||||
stop_targets = batch['stop_targets']
|
||||
speaker_ids = batch['speaker_ids']
|
||||
x_vectors = batch['x_vectors']
|
||||
text_input = batch["text_input"]
|
||||
text_lengths = batch["text_lengths"]
|
||||
mel_input = batch["mel_input"]
|
||||
mel_lengths = batch["mel_lengths"]
|
||||
linear_input = batch["linear_input"]
|
||||
stop_targets = batch["stop_targets"]
|
||||
speaker_ids = batch["speaker_ids"]
|
||||
x_vectors = batch["x_vectors"]
|
||||
|
||||
# forward pass model
|
||||
outputs = self.forward(text_input,
|
||||
text_lengths,
|
||||
mel_input,
|
||||
mel_lengths,
|
||||
cond_input={
|
||||
'speaker_ids': speaker_ids,
|
||||
'x_vectors': x_vectors
|
||||
})
|
||||
outputs = self.forward(
|
||||
text_input,
|
||||
text_lengths,
|
||||
mel_input,
|
||||
mel_lengths,
|
||||
cond_input={"speaker_ids": speaker_ids, "x_vectors": x_vectors},
|
||||
)
|
||||
|
||||
# set the [alignment] lengths wrt reduction factor for guided attention
|
||||
if mel_lengths.max() % self.decoder.r != 0:
|
||||
alignment_lengths = (
|
||||
mel_lengths +
|
||||
(self.decoder.r -
|
||||
(mel_lengths.max() % self.decoder.r))) // self.decoder.r
|
||||
mel_lengths + (self.decoder.r - (mel_lengths.max() % self.decoder.r))
|
||||
) // self.decoder.r
|
||||
else:
|
||||
alignment_lengths = mel_lengths // self.decoder.r
|
||||
|
||||
cond_input = {'speaker_ids': speaker_ids, 'x_vectors': x_vectors}
|
||||
outputs = self.forward(text_input, text_lengths, mel_input,
|
||||
mel_lengths, cond_input)
|
||||
cond_input = {"speaker_ids": speaker_ids, "x_vectors": x_vectors}
|
||||
outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, cond_input)
|
||||
|
||||
# compute loss
|
||||
loss_dict = criterion(
|
||||
outputs['model_outputs'],
|
||||
outputs['decoder_outputs'],
|
||||
outputs["model_outputs"],
|
||||
outputs["decoder_outputs"],
|
||||
mel_input,
|
||||
linear_input,
|
||||
outputs['stop_tokens'],
|
||||
outputs["stop_tokens"],
|
||||
stop_targets,
|
||||
mel_lengths,
|
||||
outputs['decoder_outputs_backward'],
|
||||
outputs['alignments'],
|
||||
outputs["decoder_outputs_backward"],
|
||||
outputs["alignments"],
|
||||
alignment_lengths,
|
||||
outputs['alignments_backward'],
|
||||
outputs["alignments_backward"],
|
||||
text_lengths,
|
||||
)
|
||||
|
||||
# compute alignment error (the lower the better )
|
||||
align_error = 1 - alignment_diagonal_score(outputs['alignments'])
|
||||
align_error = 1 - alignment_diagonal_score(outputs["alignments"])
|
||||
loss_dict["align_error"] = align_error
|
||||
return outputs, loss_dict
|
||||
|
||||
def train_log(self, ap, batch, outputs):
|
||||
postnet_outputs = outputs['model_outputs']
|
||||
alignments = outputs['alignments']
|
||||
alignments_backward = outputs['alignments_backward']
|
||||
mel_input = batch['mel_input']
|
||||
postnet_outputs = outputs["model_outputs"]
|
||||
alignments = outputs["alignments"]
|
||||
alignments_backward = outputs["alignments_backward"]
|
||||
mel_input = batch["mel_input"]
|
||||
|
||||
pred_spec = postnet_outputs[0].data.cpu().numpy()
|
||||
gt_spec = mel_input[0].data.cpu().numpy()
|
||||
|
@ -350,8 +350,7 @@ class Tacotron2(TacotronAbstract):
|
|||
}
|
||||
|
||||
if self.bidirectional_decoder or self.double_decoder_consistency:
|
||||
figures["alignment_backward"] = plot_alignment(
|
||||
alignments_backward[0].data.cpu().numpy(), output_fig=False)
|
||||
figures["alignment_backward"] = plot_alignment(alignments_backward[0].data.cpu().numpy(), output_fig=False)
|
||||
|
||||
# Sample audio
|
||||
train_audio = ap.inv_melspectrogram(pred_spec.T)
|
||||
|
|
|
@ -37,7 +37,7 @@ class TacotronAbstract(ABC, nn.Module):
|
|||
speaker_embedding_dim=None,
|
||||
use_gst=False,
|
||||
gst=None,
|
||||
gradual_training=[]
|
||||
gradual_training=[],
|
||||
):
|
||||
"""Abstract Tacotron class"""
|
||||
super().__init__()
|
||||
|
|
|
@ -58,25 +58,16 @@ def numpy_to_tf(np_array, dtype):
|
|||
|
||||
|
||||
def compute_style_mel(style_wav, ap, cuda=False):
|
||||
style_mel = torch.FloatTensor(
|
||||
ap.melspectrogram(ap.load_wav(style_wav,
|
||||
sr=ap.sample_rate))).unsqueeze(0)
|
||||
style_mel = torch.FloatTensor(ap.melspectrogram(ap.load_wav(style_wav, sr=ap.sample_rate))).unsqueeze(0)
|
||||
if cuda:
|
||||
return style_mel.cuda()
|
||||
return style_mel
|
||||
|
||||
|
||||
def run_model_torch(model,
|
||||
inputs,
|
||||
speaker_id=None,
|
||||
style_mel=None,
|
||||
x_vector=None):
|
||||
outputs = model.inference(inputs,
|
||||
cond_input={
|
||||
'speaker_ids': speaker_id,
|
||||
'x_vector': x_vector,
|
||||
'style_mel': style_mel
|
||||
})
|
||||
def run_model_torch(model, inputs, speaker_id=None, style_mel=None, x_vector=None):
|
||||
outputs = model.inference(
|
||||
inputs, cond_input={"speaker_ids": speaker_id, "x_vector": x_vector, "style_mel": style_mel}
|
||||
)
|
||||
return outputs
|
||||
|
||||
|
||||
|
@ -86,18 +77,15 @@ def run_model_tf(model, inputs, CONFIG, speaker_id=None, style_mel=None):
|
|||
if speaker_id is not None:
|
||||
raise NotImplementedError(" [!] Multi-Speaker not implemented for TF")
|
||||
# TODO: handle multispeaker case
|
||||
decoder_output, postnet_output, alignments, stop_tokens = model(
|
||||
inputs, training=False)
|
||||
decoder_output, postnet_output, alignments, stop_tokens = model(inputs, training=False)
|
||||
return decoder_output, postnet_output, alignments, stop_tokens
|
||||
|
||||
|
||||
def run_model_tflite(model, inputs, CONFIG, speaker_id=None, style_mel=None):
|
||||
if CONFIG.gst and style_mel is not None:
|
||||
raise NotImplementedError(
|
||||
" [!] GST inference not implemented for TfLite")
|
||||
raise NotImplementedError(" [!] GST inference not implemented for TfLite")
|
||||
if speaker_id is not None:
|
||||
raise NotImplementedError(
|
||||
" [!] Multi-Speaker not implemented for TfLite")
|
||||
raise NotImplementedError(" [!] Multi-Speaker not implemented for TfLite")
|
||||
# get input and output details
|
||||
input_details = model.get_input_details()
|
||||
output_details = model.get_output_details()
|
||||
|
@ -131,7 +119,7 @@ def parse_outputs_tflite(postnet_output, decoder_output):
|
|||
|
||||
|
||||
def trim_silence(wav, ap):
|
||||
return wav[:ap.find_endpoint(wav)]
|
||||
return wav[: ap.find_endpoint(wav)]
|
||||
|
||||
|
||||
def inv_spectrogram(postnet_output, ap, CONFIG):
|
||||
|
@ -154,8 +142,7 @@ def speaker_id_to_torch(speaker_id, cuda=False):
|
|||
def embedding_to_torch(x_vector, cuda=False):
|
||||
if x_vector is not None:
|
||||
x_vector = np.asarray(x_vector)
|
||||
x_vector = torch.from_numpy(x_vector).unsqueeze(0).type(
|
||||
torch.FloatTensor)
|
||||
x_vector = torch.from_numpy(x_vector).unsqueeze(0).type(torch.FloatTensor)
|
||||
if cuda:
|
||||
return x_vector.cuda()
|
||||
return x_vector
|
||||
|
@ -172,8 +159,7 @@ def apply_griffin_lim(inputs, input_lens, CONFIG, ap):
|
|||
"""
|
||||
wavs = []
|
||||
for idx, spec in enumerate(inputs):
|
||||
wav_len = (input_lens[idx] *
|
||||
ap.hop_length) - ap.hop_length # inverse librosa padding
|
||||
wav_len = (input_lens[idx] * ap.hop_length) - ap.hop_length # inverse librosa padding
|
||||
wav = inv_spectrogram(spec, ap, CONFIG)
|
||||
# assert len(wav) == wav_len, f" [!] wav lenght: {len(wav)} vs expected: {wav_len}"
|
||||
wavs.append(wav[:wav_len])
|
||||
|
@ -241,23 +227,21 @@ def synthesis(
|
|||
text_inputs = tf.expand_dims(text_inputs, 0)
|
||||
# synthesize voice
|
||||
if backend == "torch":
|
||||
outputs = run_model_torch(model,
|
||||
text_inputs,
|
||||
speaker_id,
|
||||
style_mel,
|
||||
x_vector=x_vector)
|
||||
model_outputs = outputs['model_outputs']
|
||||
outputs = run_model_torch(model, text_inputs, speaker_id, style_mel, x_vector=x_vector)
|
||||
model_outputs = outputs["model_outputs"]
|
||||
model_outputs = model_outputs[0].data.cpu().numpy()
|
||||
elif backend == "tf":
|
||||
decoder_output, postnet_output, alignments, stop_tokens = run_model_tf(
|
||||
model, text_inputs, CONFIG, speaker_id, style_mel)
|
||||
model, text_inputs, CONFIG, speaker_id, style_mel
|
||||
)
|
||||
model_outputs, decoder_output, alignment, stop_tokens = parse_outputs_tf(
|
||||
postnet_output, decoder_output, alignments, stop_tokens)
|
||||
postnet_output, decoder_output, alignments, stop_tokens
|
||||
)
|
||||
elif backend == "tflite":
|
||||
decoder_output, postnet_output, alignment, stop_tokens = run_model_tflite(
|
||||
model, text_inputs, CONFIG, speaker_id, style_mel)
|
||||
model_outputs, decoder_output = parse_outputs_tflite(
|
||||
postnet_output, decoder_output)
|
||||
model, text_inputs, CONFIG, speaker_id, style_mel
|
||||
)
|
||||
model_outputs, decoder_output = parse_outputs_tflite(postnet_output, decoder_output)
|
||||
# convert outputs to numpy
|
||||
# plot results
|
||||
wav = None
|
||||
|
@ -267,9 +251,9 @@ def synthesis(
|
|||
if do_trim_silence:
|
||||
wav = trim_silence(wav, ap)
|
||||
return_dict = {
|
||||
'wav': wav,
|
||||
'alignments': outputs['alignments'],
|
||||
'model_outputs': model_outputs,
|
||||
'text_inputs': text_inputs
|
||||
"wav": wav,
|
||||
"alignments": outputs["alignments"],
|
||||
"model_outputs": model_outputs,
|
||||
"text_inputs": text_inputs,
|
||||
}
|
||||
return return_dict
|
||||
|
|
|
@ -30,16 +30,16 @@ def init_arguments(argv):
|
|||
parser.add_argument(
|
||||
"--continue_path",
|
||||
type=str,
|
||||
help=("Training output folder to continue training. Used to continue "
|
||||
"a training. If it is used, 'config_path' is ignored."),
|
||||
help=(
|
||||
"Training output folder to continue training. Used to continue "
|
||||
"a training. If it is used, 'config_path' is ignored."
|
||||
),
|
||||
default="",
|
||||
required="--config_path" not in argv,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--restore_path",
|
||||
type=str,
|
||||
help="Model file to be restored. Use to finetune a model.",
|
||||
default="")
|
||||
"--restore_path", type=str, help="Model file to be restored. Use to finetune a model.", default=""
|
||||
)
|
||||
parser.add_argument(
|
||||
"--best_path",
|
||||
type=str,
|
||||
|
@ -49,23 +49,12 @@ def init_arguments(argv):
|
|||
),
|
||||
default="",
|
||||
)
|
||||
parser.add_argument("--config_path",
|
||||
type=str,
|
||||
help="Path to config file for training.",
|
||||
required="--continue_path" not in argv)
|
||||
parser.add_argument("--debug",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="Do not verify commit integrity to run training.")
|
||||
parser.add_argument(
|
||||
"--rank",
|
||||
type=int,
|
||||
default=0,
|
||||
help="DISTRIBUTED: process rank for distributed training.")
|
||||
parser.add_argument("--group_id",
|
||||
type=str,
|
||||
default="",
|
||||
help="DISTRIBUTED: process group id.")
|
||||
"--config_path", type=str, help="Path to config file for training.", required="--continue_path" not in argv
|
||||
)
|
||||
parser.add_argument("--debug", type=bool, default=False, help="Do not verify commit integrity to run training.")
|
||||
parser.add_argument("--rank", type=int, default=0, help="DISTRIBUTED: process rank for distributed training.")
|
||||
parser.add_argument("--group_id", type=str, default="", help="DISTRIBUTED: process group id.")
|
||||
|
||||
return parser
|
||||
|
||||
|
@ -160,8 +149,7 @@ def process_args(args):
|
|||
print(" > Mixed precision mode is ON")
|
||||
experiment_path = args.continue_path
|
||||
if not experiment_path:
|
||||
experiment_path = create_experiment_folder(config.output_path,
|
||||
config.run_name, args.debug)
|
||||
experiment_path = create_experiment_folder(config.output_path, config.run_name, args.debug)
|
||||
audio_path = os.path.join(experiment_path, "test_audios")
|
||||
# setup rank 0 process in distributed training
|
||||
tb_logger = None
|
||||
|
@ -182,8 +170,7 @@ def process_args(args):
|
|||
os.chmod(experiment_path, 0o775)
|
||||
tb_logger = TensorboardLogger(experiment_path, model_name=config.model)
|
||||
# write model desc to tensorboard
|
||||
tb_logger.tb_add_text("model-config", f"<pre>{config.to_json()}</pre>",
|
||||
0)
|
||||
tb_logger.tb_add_text("model-config", f"<pre>{config.to_json()}</pre>", 0)
|
||||
c_logger = ConsoleLogger()
|
||||
return config, experiment_path, audio_path, c_logger, tb_logger
|
||||
|
||||
|
|
|
@ -234,8 +234,8 @@ class Synthesizer(object):
|
|||
use_griffin_lim=use_gl,
|
||||
x_vector=speaker_embedding,
|
||||
)
|
||||
waveform = outputs['wav']
|
||||
mel_postnet_spec = outputs['model_outputs']
|
||||
waveform = outputs["wav"]
|
||||
mel_postnet_spec = outputs["model_outputs"]
|
||||
if not use_gl:
|
||||
# denormalize tts output based on tts audio config
|
||||
mel_postnet_spec = self.ap.denormalize(mel_postnet_spec.T).T
|
||||
|
|
|
@ -44,8 +44,6 @@ run_cli(command_train)
|
|||
continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
|
||||
|
||||
# restore the model and continue training for one more epoch
|
||||
command_train = (
|
||||
f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
|
||||
)
|
||||
command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
|
||||
run_cli(command_train)
|
||||
shutil.rmtree(continue_path)
|
||||
|
|
|
@ -45,8 +45,6 @@ run_cli(command_train)
|
|||
continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
|
||||
|
||||
# restore the model and continue training for one more epoch
|
||||
command_train = (
|
||||
f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
|
||||
)
|
||||
command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
|
||||
run_cli(command_train)
|
||||
shutil.rmtree(continue_path)
|
||||
|
|
|
@ -45,8 +45,6 @@ run_cli(command_train)
|
|||
continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
|
||||
|
||||
# restore the model and continue training for one more epoch
|
||||
command_train = (
|
||||
f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
|
||||
)
|
||||
command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
|
||||
run_cli(command_train)
|
||||
shutil.rmtree(continue_path)
|
||||
|
|
|
@ -45,8 +45,6 @@ run_cli(command_train)
|
|||
continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
|
||||
|
||||
# restore the model and continue training for one more epoch
|
||||
command_train = (
|
||||
f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
|
||||
)
|
||||
command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
|
||||
run_cli(command_train)
|
||||
shutil.rmtree(continue_path)
|
||||
|
|
|
@ -44,8 +44,6 @@ run_cli(command_train)
|
|||
continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
|
||||
|
||||
# restore the model and continue training for one more epoch
|
||||
command_train = (
|
||||
f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
|
||||
)
|
||||
command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
|
||||
run_cli(command_train)
|
||||
shutil.rmtree(continue_path)
|
||||
|
|
|
@ -21,7 +21,6 @@ config = MelganConfig(
|
|||
print_step=1,
|
||||
discriminator_model_params={"base_channels": 16, "max_channels": 256, "downsample_factors": [4, 4, 4]},
|
||||
print_eval=True,
|
||||
discriminator_model_params={"base_channels": 16, "max_channels": 256, "downsample_factors": [4, 4, 4]},
|
||||
data_path="tests/data/ljspeech",
|
||||
output_path=output_path,
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue