Mirror of https://github.com/coqui-ai/TTS.git

commit b500338faa (parent 469d2e620a)

    make style
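Editor's note: the changes below are a pure formatting pass with no functional changes. Judging from the diff itself, the `make style` target runs an auto-formatter over the code base (black-style rules, presumably wired up in the repo's Makefile; that wiring is an assumption, not shown here): single-quoted strings become double-quoted, manually wrapped calls are collapsed onto one line up to a longer line limit, and trailing commas are normalized. A minimal illustrative before/after, not taken from the repo:

    # before `make style`
    cond_input = {'speaker_ids': speaker_ids,
                  'x_vectors': speaker_embeddings}

    # after `make style`
    cond_input = {"speaker_ids": speaker_ids, "x_vectors": speaker_embeddings}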
@@ -153,15 +153,9 @@ def inference(
         model_output = model_output.transpose(1, 2).detach().cpu().numpy()

     elif "tacotron" in model_name:
-        cond_input = {'speaker_ids': speaker_ids, 'x_vectors': speaker_embeddings}
-        outputs = model(
-            text_input,
-            text_lengths,
-            mel_input,
-            mel_lengths,
-            cond_input
-        )
-        postnet_outputs = outputs['model_outputs']
+        cond_input = {"speaker_ids": speaker_ids, "x_vectors": speaker_embeddings}
+        outputs = model(text_input, text_lengths, mel_input, mel_lengths, cond_input)
+        postnet_outputs = outputs["model_outputs"]
         # normalize tacotron output
         if model_name == "tacotron":
             mel_specs = []
@@ -8,13 +8,8 @@ from TTS.trainer import TrainerTTS

 def main():
     # try:
-    args, config, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = init_training(
-        sys.argv)
-    trainer = TrainerTTS(args,
-                         config,
-                         c_logger,
-                         tb_logger,
-                         output_path=OUT_PATH)
+    args, config, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = init_training(sys.argv)
+    trainer = TrainerTTS(args, config, c_logger, tb_logger, output_path=OUT_PATH)
     trainer.fit()
     # except KeyboardInterrupt:
     #     remove_experiment_folder(OUT_PATH)
TTS/trainer.py (140 lines changed)
@@ -84,7 +84,7 @@ class TrainerTTS:
         self.best_loss = float("inf")
         self.train_loader = None
         self.eval_loader = None
-        self.output_audio_path = os.path.join(output_path, 'test_audios')
+        self.output_audio_path = os.path.join(output_path, "test_audios")

         self.keep_avg_train = None
         self.keep_avg_eval = None
@@ -138,8 +138,8 @@ class TrainerTTS:

         if self.args.restore_path:
             self.model, self.optimizer, self.scaler, self.restore_step = self.restore_model(
-                self.config, args.restore_path, self.model, self.optimizer,
-                self.scaler)
+                self.config, args.restore_path, self.model, self.optimizer, self.scaler
+            )

         # setup scheduler
         self.scheduler = self.get_scheduler(self.config, self.optimizer)
@@ -207,6 +207,7 @@ class TrainerTTS:
             return None
         if lr_scheduler.lower() == "noamlr":
             from TTS.utils.training import NoamLR
+
             scheduler = NoamLR
         else:
             scheduler = getattr(torch.optim, lr_scheduler)
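Editor's note: `NoamLR` is selected here as a warmup scheduler configured only by `warmup_steps` (see the config hunks further down). Assuming it implements the standard Noam/Transformer schedule, which is an assumption about `TTS.utils.training.NoamLR` rather than something this diff shows, the learning rate at step $s$ would be

$$\mathrm{lr}(s) = \mathrm{lr}_{\mathrm{base}} \cdot w^{0.5} \cdot \min\left(s \cdot w^{-1.5},\; s^{-0.5}\right), \qquad w = \texttt{warmup\_steps},$$

i.e. roughly linear warmup for the first $w$ steps followed by inverse-square-root decay.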
@@ -261,8 +262,7 @@ class TrainerTTS:
             ap=ap,
             tp=self.config.characters,
             add_blank=self.config["add_blank"],
-            batch_group_size=0 if is_eval else
-            self.config.batch_group_size * self.config.batch_size,
+            batch_group_size=0 if is_eval else self.config.batch_group_size * self.config.batch_size,
             min_seq_len=self.config.min_seq_len,
             max_seq_len=self.config.max_seq_len,
             phoneme_cache_path=self.config.phoneme_cache_path,
@@ -272,8 +272,8 @@ class TrainerTTS:
             use_noise_augment=not is_eval,
             verbose=verbose,
             speaker_mapping=speaker_mapping
-            if self.config.use_speaker_embedding
-            and self.config.use_external_speaker_embedding_file else None,
+            if self.config.use_speaker_embedding and self.config.use_external_speaker_embedding_file
+            else None,
         )

         if self.config.use_phonemes and self.config.compute_input_seq_cache:
@@ -281,18 +281,15 @@ class TrainerTTS:
                 dataset.compute_input_seq(self.config.num_loader_workers)
             dataset.sort_items()

-        sampler = DistributedSampler(
-            dataset) if self.num_gpus > 1 else None
+        sampler = DistributedSampler(dataset) if self.num_gpus > 1 else None
         loader = DataLoader(
             dataset,
-            batch_size=self.config.eval_batch_size
-            if is_eval else self.config.batch_size,
+            batch_size=self.config.eval_batch_size if is_eval else self.config.batch_size,
             shuffle=False,
             collate_fn=dataset.collate_fn,
             drop_last=False,
             sampler=sampler,
-            num_workers=self.config.num_val_loader_workers
-            if is_eval else self.config.num_loader_workers,
+            num_workers=self.config.num_val_loader_workers if is_eval else self.config.num_loader_workers,
             pin_memory=False,
         )
         return loader
@@ -314,8 +311,7 @@ class TrainerTTS:
         text_input = batch[0]
         text_lengths = batch[1]
         speaker_names = batch[2]
-        linear_input = batch[3] if self.config.model.lower() in ["tacotron"
-                                                                 ] else None
+        linear_input = batch[3] if self.config.model.lower() in ["tacotron"] else None
         mel_input = batch[4]
         mel_lengths = batch[5]
         stop_targets = batch[6]
@@ -331,10 +327,7 @@ class TrainerTTS:
                 speaker_embeddings = batch[8]
                 speaker_ids = None
             else:
-                speaker_ids = [
-                    self.speaker_manager.speaker_ids[speaker_name]
-                    for speaker_name in speaker_names
-                ]
+                speaker_ids = [self.speaker_manager.speaker_ids[speaker_name] for speaker_name in speaker_names]
                 speaker_ids = torch.LongTensor(speaker_ids)
                 speaker_embeddings = None
         else:
@@ -346,7 +339,7 @@ class TrainerTTS:
             durations = torch.zeros(attn_mask.shape[0], attn_mask.shape[2])
             for idx, am in enumerate(attn_mask):
                 # compute raw durations
-                c_idxs = am[:, :text_lengths[idx], :mel_lengths[idx]].max(1)[1]
+                c_idxs = am[:, : text_lengths[idx], : mel_lengths[idx]].max(1)[1]
                 # c_idxs, counts = torch.unique_consecutive(c_idxs, return_counts=True)
                 c_idxs, counts = torch.unique(c_idxs, return_counts=True)
                 dur = torch.ones([text_lengths[idx]]).to(counts.dtype)
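Editor's note: the loop above turns a precomputed attention mask into per-character durations by arg-maxing over characters for every spectrogram frame and counting how often each character wins. A standalone toy version of the same computation (the tensor is made up for illustration):

    import torch

    # toy alignment for one sample: [n_chars, n_frames], 1.0 where a frame attends to a character
    am = torch.tensor([[1.0, 1.0, 0.0, 0.0, 0.0],
                       [0.0, 0.0, 1.0, 0.0, 0.0],
                       [0.0, 0.0, 0.0, 1.0, 1.0]])
    c_idxs = am.max(0)[1]                      # winning character per frame -> tensor([0, 0, 1, 2, 2])
    chars, counts = torch.unique(c_idxs, return_counts=True)
    dur = torch.ones(am.shape[0], dtype=counts.dtype)
    dur[chars] = counts                        # frames per character -> tensor([2, 1, 2])
    assert dur.sum() == am.shape[1]            # mirrors the assertion in the next hunk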
@@ -359,14 +352,11 @@ class TrainerTTS:
                 assert (
                     dur.sum() == mel_lengths[idx]
                 ), f" [!] total duration {dur.sum()} vs spectrogram length {mel_lengths[idx]}"
-                durations[idx, :text_lengths[idx]] = dur
+                durations[idx, : text_lengths[idx]] = dur

         # set stop targets view, we predict a single stop token per iteration.
-        stop_targets = stop_targets.view(text_input.shape[0],
-                                         stop_targets.size(1) // self.config.r,
-                                         -1)
-        stop_targets = (stop_targets.sum(2) >
-                        0.0).unsqueeze(2).float().squeeze(2)
+        stop_targets = stop_targets.view(text_input.shape[0], stop_targets.size(1) // self.config.r, -1)
+        stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze(2)

         # dispatch batch to GPU
         if self.use_cuda:
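Editor's note: the two `stop_targets` lines above fold per-frame stop flags into one flag per decoder step of `r` frames (the reduction factor). A toy illustration with made-up shapes:

    import torch

    r = 2
    stop_targets = torch.tensor([[0.0, 0.0, 0.0, 0.0, 1.0, 1.0]])       # [B, T_mel]; 1.0 marks frames past the end
    stop_targets = stop_targets.view(1, stop_targets.size(1) // r, -1)  # [B, T_mel // r, r]
    stop_targets = (stop_targets.sum(2) > 0.0).float()                  # [B, T_mel // r] -> tensor([[0., 0., 1.]])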
@@ -374,15 +364,10 @@ class TrainerTTS:
             text_lengths = text_lengths.cuda(non_blocking=True)
             mel_input = mel_input.cuda(non_blocking=True)
             mel_lengths = mel_lengths.cuda(non_blocking=True)
-            linear_input = linear_input.cuda(
-                non_blocking=True) if self.config.model.lower() in [
-                    "tacotron"
-                ] else None
+            linear_input = linear_input.cuda(non_blocking=True) if self.config.model.lower() in ["tacotron"] else None
             stop_targets = stop_targets.cuda(non_blocking=True)
-            attn_mask = attn_mask.cuda(
-                non_blocking=True) if attn_mask is not None else None
-            durations = durations.cuda(
-                non_blocking=True) if attn_mask is not None else None
+            attn_mask = attn_mask.cuda(non_blocking=True) if attn_mask is not None else None
+            durations = durations.cuda(non_blocking=True) if attn_mask is not None else None
             if speaker_ids is not None:
                 speaker_ids = speaker_ids.cuda(non_blocking=True)
             if speaker_embeddings is not None:
@@ -401,7 +386,7 @@ class TrainerTTS:
             "x_vectors": speaker_embeddings,
             "max_text_length": max_text_length,
             "max_spec_length": max_spec_length,
-            "item_idx": item_idx
+            "item_idx": item_idx,
         }

     def train_step(self, batch: Dict, batch_n_steps: int, step: int,
@@ -421,25 +406,20 @@ class TrainerTTS:

         # check nan loss
         if torch.isnan(loss_dict["loss"]).any():
-            raise RuntimeError(
-                f"Detected NaN loss at step {self.total_steps_done}.")
+            raise RuntimeError(f"Detected NaN loss at step {self.total_steps_done}.")

         # optimizer step
         if self.config.mixed_precision:
             # model optimizer step in mixed precision mode
             self.scaler.scale(loss_dict["loss"]).backward()
             self.scaler.unscale_(self.optimizer)
-            grad_norm, _ = check_update(self.model,
-                                        self.config.grad_clip,
-                                        ignore_stopnet=True)
+            grad_norm, _ = check_update(self.model, self.config.grad_clip, ignore_stopnet=True)
             self.scaler.step(self.optimizer)
             self.scaler.update()
         else:
             # main model optimizer step
             loss_dict["loss"].backward()
-            grad_norm, _ = check_update(self.model,
-                                        self.config.grad_clip,
-                                        ignore_stopnet=True)
+            grad_norm, _ = check_update(self.model, self.config.grad_clip, ignore_stopnet=True)
             self.optimizer.step()

         step_time = time.time() - step_start_time
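Editor's note: the mixed-precision branch above follows the standard torch.cuda.amp recipe: gradients are unscaled before clipping so the clip threshold applies to real gradient magnitudes, and `scaler.step`/`scaler.update` skip the update and shrink the loss scale when inf/NaN gradients are detected. A generic, self-contained sketch of that pattern, with `clip_grad_norm_` standing in for the project's `check_update` and a toy model in place of the TTS model:

    import torch
    import torch.nn as nn

    model = nn.Linear(16, 1).cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.MSELoss()
    scaler = torch.cuda.amp.GradScaler()

    for _ in range(10):  # toy loop over random batches
        x, y = torch.randn(8, 16, device="cuda"), torch.randn(8, 1, device="cuda")
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            loss = criterion(model(x), y)
        scaler.scale(loss).backward()              # backward on the scaled loss
        scaler.unscale_(optimizer)                 # unscale so clipping sees true gradient norms
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
        scaler.step(optimizer)                     # skipped automatically if grads contain inf/NaN
        scaler.update()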
@ -469,17 +449,15 @@ class TrainerTTS:
|
||||||
current_lr = self.optimizer.param_groups[0]["lr"]
|
current_lr = self.optimizer.param_groups[0]["lr"]
|
||||||
if self.total_steps_done % self.config.print_step == 0:
|
if self.total_steps_done % self.config.print_step == 0:
|
||||||
log_dict = {
|
log_dict = {
|
||||||
"max_spec_length": [batch["max_spec_length"],
|
"max_spec_length": [batch["max_spec_length"], 1], # value, precision
|
||||||
1], # value, precision
|
|
||||||
"max_text_length": [batch["max_text_length"], 1],
|
"max_text_length": [batch["max_text_length"], 1],
|
||||||
"step_time": [step_time, 4],
|
"step_time": [step_time, 4],
|
||||||
"loader_time": [loader_time, 2],
|
"loader_time": [loader_time, 2],
|
||||||
"current_lr": current_lr,
|
"current_lr": current_lr,
|
||||||
}
|
}
|
||||||
self.c_logger.print_train_step(batch_n_steps, step,
|
self.c_logger.print_train_step(
|
||||||
self.total_steps_done, log_dict,
|
batch_n_steps, step, self.total_steps_done, log_dict, loss_dict, self.keep_avg_train.avg_values
|
||||||
loss_dict,
|
)
|
||||||
self.keep_avg_train.avg_values)
|
|
||||||
|
|
||||||
if self.args.rank == 0:
|
if self.args.rank == 0:
|
||||||
# Plot Training Iter Stats
|
# Plot Training Iter Stats
|
||||||
|
@ -491,8 +469,7 @@ class TrainerTTS:
|
||||||
"step_time": step_time,
|
"step_time": step_time,
|
||||||
}
|
}
|
||||||
iter_stats.update(loss_dict)
|
iter_stats.update(loss_dict)
|
||||||
self.tb_logger.tb_train_step_stats(self.total_steps_done,
|
self.tb_logger.tb_train_step_stats(self.total_steps_done, iter_stats)
|
||||||
iter_stats)
|
|
||||||
|
|
||||||
if self.total_steps_done % self.config.save_step == 0:
|
if self.total_steps_done % self.config.save_step == 0:
|
||||||
if self.config.checkpoint:
|
if self.config.checkpoint:
|
||||||
|
@ -506,15 +483,12 @@ class TrainerTTS:
|
||||||
self.output_path,
|
self.output_path,
|
||||||
model_loss=loss_dict["loss"],
|
model_loss=loss_dict["loss"],
|
||||||
characters=self.model_characters,
|
characters=self.model_characters,
|
||||||
scaler=self.scaler.state_dict()
|
scaler=self.scaler.state_dict() if self.config.mixed_precision else None,
|
||||||
if self.config.mixed_precision else None,
|
|
||||||
)
|
)
|
||||||
# training visualizations
|
# training visualizations
|
||||||
figures, audios = self.model.train_log(self.ap, batch, outputs)
|
figures, audios = self.model.train_log(self.ap, batch, outputs)
|
||||||
self.tb_logger.tb_train_figures(self.total_steps_done, figures)
|
self.tb_logger.tb_train_figures(self.total_steps_done, figures)
|
||||||
self.tb_logger.tb_train_audios(self.total_steps_done,
|
self.tb_logger.tb_train_audios(self.total_steps_done, {"TrainAudio": audios}, self.ap.sample_rate)
|
||||||
{"TrainAudio": audios},
|
|
||||||
self.ap.sample_rate)
|
|
||||||
self.total_steps_done += 1
|
self.total_steps_done += 1
|
||||||
self.on_train_step_end()
|
self.on_train_step_end()
|
||||||
return outputs, loss_dict
|
return outputs, loss_dict
|
||||||
|
@ -523,35 +497,28 @@ class TrainerTTS:
|
||||||
self.model.train()
|
self.model.train()
|
||||||
epoch_start_time = time.time()
|
epoch_start_time = time.time()
|
||||||
if self.use_cuda:
|
if self.use_cuda:
|
||||||
batch_num_steps = int(
|
batch_num_steps = int(len(self.train_loader.dataset) / (self.config.batch_size * self.num_gpus))
|
||||||
len(self.train_loader.dataset) /
|
|
||||||
(self.config.batch_size * self.num_gpus))
|
|
||||||
else:
|
else:
|
||||||
batch_num_steps = int(
|
batch_num_steps = int(len(self.train_loader.dataset) / self.config.batch_size)
|
||||||
len(self.train_loader.dataset) / self.config.batch_size)
|
|
||||||
self.c_logger.print_train_start()
|
self.c_logger.print_train_start()
|
||||||
loader_start_time = time.time()
|
loader_start_time = time.time()
|
||||||
for cur_step, batch in enumerate(self.train_loader):
|
for cur_step, batch in enumerate(self.train_loader):
|
||||||
_, _ = self.train_step(batch, batch_num_steps, cur_step,
|
_, _ = self.train_step(batch, batch_num_steps, cur_step, loader_start_time)
|
||||||
loader_start_time)
|
|
||||||
epoch_time = time.time() - epoch_start_time
|
epoch_time = time.time() - epoch_start_time
|
||||||
# Plot self.epochs_done Stats
|
# Plot self.epochs_done Stats
|
||||||
if self.args.rank == 0:
|
if self.args.rank == 0:
|
||||||
epoch_stats = {"epoch_time": epoch_time}
|
epoch_stats = {"epoch_time": epoch_time}
|
||||||
epoch_stats.update(self.keep_avg_train.avg_values)
|
epoch_stats.update(self.keep_avg_train.avg_values)
|
||||||
self.tb_logger.tb_train_epoch_stats(self.total_steps_done,
|
self.tb_logger.tb_train_epoch_stats(self.total_steps_done, epoch_stats)
|
||||||
epoch_stats)
|
|
||||||
if self.config.tb_model_param_stats:
|
if self.config.tb_model_param_stats:
|
||||||
self.tb_logger.tb_model_weights(self.model,
|
self.tb_logger.tb_model_weights(self.model, self.total_steps_done)
|
||||||
self.total_steps_done)
|
|
||||||
|
|
||||||
def eval_step(self, batch: Dict, step: int) -> Tuple[Dict, Dict]:
|
def eval_step(self, batch: Dict, step: int) -> Tuple[Dict, Dict]:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
step_start_time = time.time()
|
step_start_time = time.time()
|
||||||
|
|
||||||
with torch.cuda.amp.autocast(enabled=self.config.mixed_precision):
|
with torch.cuda.amp.autocast(enabled=self.config.mixed_precision):
|
||||||
outputs, loss_dict = self.model.eval_step(
|
outputs, loss_dict = self.model.eval_step(batch, self.criterion)
|
||||||
batch, self.criterion)
|
|
||||||
|
|
||||||
step_time = time.time() - step_start_time
|
step_time = time.time() - step_start_time
|
||||||
|
|
||||||
|
@ -572,8 +539,7 @@ class TrainerTTS:
|
||||||
self.keep_avg_eval.update_values(update_eval_values)
|
self.keep_avg_eval.update_values(update_eval_values)
|
||||||
|
|
||||||
if self.config.print_eval:
|
if self.config.print_eval:
|
||||||
self.c_logger.print_eval_step(step, loss_dict,
|
self.c_logger.print_eval_step(step, loss_dict, self.keep_avg_eval.avg_values)
|
||||||
self.keep_avg_eval.avg_values)
|
|
||||||
return outputs, loss_dict
|
return outputs, loss_dict
|
||||||
|
|
||||||
def eval_epoch(self) -> None:
|
def eval_epoch(self) -> None:
|
||||||
|
@ -585,15 +551,13 @@ class TrainerTTS:
|
||||||
# format data
|
# format data
|
||||||
batch = self.format_batch(batch)
|
batch = self.format_batch(batch)
|
||||||
loader_time = time.time() - loader_start_time
|
loader_time = time.time() - loader_start_time
|
||||||
self.keep_avg_eval.update_values({'avg_loader_time': loader_time})
|
self.keep_avg_eval.update_values({"avg_loader_time": loader_time})
|
||||||
outputs, _ = self.eval_step(batch, cur_step)
|
outputs, _ = self.eval_step(batch, cur_step)
|
||||||
# Plot epoch stats and samples from the last batch.
|
# Plot epoch stats and samples from the last batch.
|
||||||
if self.args.rank == 0:
|
if self.args.rank == 0:
|
||||||
figures, eval_audios = self.model.eval_log(self.ap, batch, outputs)
|
figures, eval_audios = self.model.eval_log(self.ap, batch, outputs)
|
||||||
self.tb_logger.tb_eval_figures(self.total_steps_done, figures)
|
self.tb_logger.tb_eval_figures(self.total_steps_done, figures)
|
||||||
self.tb_logger.tb_eval_audios(self.total_steps_done,
|
self.tb_logger.tb_eval_audios(self.total_steps_done, {"EvalAudio": eval_audios}, self.ap.sample_rate)
|
||||||
{"EvalAudio": eval_audios},
|
|
||||||
self.ap.sample_rate)
|
|
||||||
|
|
||||||
def test_run(self, ) -> None:
|
def test_run(self, ) -> None:
|
||||||
print(" | > Synthesizing test sentences.")
|
print(" | > Synthesizing test sentences.")
|
||||||
|
@ -608,9 +572,9 @@ class TrainerTTS:
|
||||||
self.config,
|
self.config,
|
||||||
self.use_cuda,
|
self.use_cuda,
|
||||||
self.ap,
|
self.ap,
|
||||||
speaker_id=cond_inputs['speaker_id'],
|
speaker_id=cond_inputs["speaker_id"],
|
||||||
x_vector=cond_inputs['x_vector'],
|
x_vector=cond_inputs["x_vector"],
|
||||||
style_wav=cond_inputs['style_wav'],
|
style_wav=cond_inputs["style_wav"],
|
||||||
enable_eos_bos_chars=self.config.enable_eos_bos_chars,
|
enable_eos_bos_chars=self.config.enable_eos_bos_chars,
|
||||||
use_griffin_lim=True,
|
use_griffin_lim=True,
|
||||||
do_trim_silence=False,
|
do_trim_silence=False,
|
||||||
|
@ -623,10 +587,8 @@ class TrainerTTS:
|
||||||
"TestSentence_{}.wav".format(idx))
|
"TestSentence_{}.wav".format(idx))
|
||||||
self.ap.save_wav(wav, file_path)
|
self.ap.save_wav(wav, file_path)
|
||||||
test_audios["{}-audio".format(idx)] = wav
|
test_audios["{}-audio".format(idx)] = wav
|
||||||
test_figures["{}-prediction".format(idx)] = plot_spectrogram(
|
test_figures["{}-prediction".format(idx)] = plot_spectrogram(model_outputs, self.ap, output_fig=False)
|
||||||
model_outputs, self.ap, output_fig=False)
|
test_figures["{}-alignment".format(idx)] = plot_alignment(alignment, output_fig=False)
|
||||||
test_figures["{}-alignment".format(idx)] = plot_alignment(
|
|
||||||
alignment, output_fig=False)
|
|
||||||
|
|
||||||
self.tb_logger.tb_test_audios(self.total_steps_done, test_audios,
|
self.tb_logger.tb_test_audios(self.total_steps_done, test_audios,
|
||||||
self.config.audio["sample_rate"])
|
self.config.audio["sample_rate"])
|
||||||
|
@ -641,11 +603,11 @@ class TrainerTTS:
|
||||||
if self.config.use_external_speaker_embedding_file
|
if self.config.use_external_speaker_embedding_file
|
||||||
and self.config.use_speaker_embedding else None)
|
and self.config.use_speaker_embedding else None)
|
||||||
# setup style_mel
|
# setup style_mel
|
||||||
if self.config.has('gst_style_input'):
|
if self.config.has("gst_style_input"):
|
||||||
style_wav = self.config.gst_style_input
|
style_wav = self.config.gst_style_input
|
||||||
else:
|
else:
|
||||||
style_wav = None
|
style_wav = None
|
||||||
if style_wav is None and 'use_gst' in self.config and self.config.use_gst:
|
if style_wav is None and "use_gst" in self.config and self.config.use_gst:
|
||||||
# inicialize GST with zero dict.
|
# inicialize GST with zero dict.
|
||||||
style_wav = {}
|
style_wav = {}
|
||||||
print(
|
print(
|
||||||
|
@ -688,8 +650,7 @@ class TrainerTTS:
|
||||||
for epoch in range(0, self.config.epochs):
|
for epoch in range(0, self.config.epochs):
|
||||||
self.on_epoch_start()
|
self.on_epoch_start()
|
||||||
self.keep_avg_train = KeepAverage()
|
self.keep_avg_train = KeepAverage()
|
||||||
self.keep_avg_eval = KeepAverage(
|
self.keep_avg_eval = KeepAverage() if self.config.run_eval else None
|
||||||
) if self.config.run_eval else None
|
|
||||||
self.epochs_done = epoch
|
self.epochs_done = epoch
|
||||||
self.c_logger.print_epoch_start(epoch, self.config.epochs)
|
self.c_logger.print_epoch_start(epoch, self.config.epochs)
|
||||||
self.train_epoch()
|
self.train_epoch()
|
||||||
|
@ -698,8 +659,8 @@ class TrainerTTS:
|
||||||
if epoch >= self.config.test_delay_epochs:
|
if epoch >= self.config.test_delay_epochs:
|
||||||
self.test_run()
|
self.test_run()
|
||||||
self.c_logger.print_epoch_end(
|
self.c_logger.print_epoch_end(
|
||||||
epoch, self.keep_avg_eval.avg_values
|
epoch, self.keep_avg_eval.avg_values if self.config.run_eval else self.keep_avg_train.avg_values
|
||||||
if self.config.run_eval else self.keep_avg_train.avg_values)
|
)
|
||||||
self.save_best_model()
|
self.save_best_model()
|
||||||
self.on_epoch_end()
|
self.on_epoch_end()
|
||||||
|
|
||||||
|
@ -717,8 +678,7 @@ class TrainerTTS:
|
||||||
self.model_characters,
|
self.model_characters,
|
||||||
keep_all_best=self.config.keep_all_best,
|
keep_all_best=self.config.keep_all_best,
|
||||||
keep_after=self.config.keep_after,
|
keep_after=self.config.keep_after,
|
||||||
scaler=self.scaler.state_dict()
|
scaler=self.scaler.state_dict() if self.config.mixed_precision else None,
|
||||||
if self.config.mixed_precision else None,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
|
@@ -93,7 +93,7 @@ class AlignTTSConfig(BaseTTSConfig):

     # optimizer parameters
     optimizer: str = "Adam"
-    optimizer_params: dict = field(default_factory=lambda: {'betas': [0.9, 0.998], 'weight_decay': 1e-6})
+    optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.998], "weight_decay": 1e-6})
     lr_scheduler: str = None
     lr_scheduler_params: dict = None
     lr: float = 1e-4
@@ -104,12 +104,13 @@ class AlignTTSConfig(BaseTTSConfig):
     max_seq_len: int = 200
     r: int = 1

     # testing
-    test_sentences: List[str] = field(default_factory=lambda:[
-        "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
-        "Be a voice, not an echo.",
-        "I'm sorry Dave. I'm afraid I can't do that.",
-        "This cake is great. It's so delicious and moist.",
-        "Prior to November 22, 1963."
-    ])
+    test_sentences: List[str] = field(
+        default_factory=lambda: [
+            "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
+            "Be a voice, not an echo.",
+            "I'm sorry Dave. I'm afraid I can't do that.",
+            "This cake is great. It's so delicious and moist.",
+            "Prior to November 22, 1963.",
+        ]
+    )

@@ -91,9 +91,9 @@ class GlowTTSConfig(BaseTTSConfig):

     # optimizer parameters
     optimizer: str = "RAdam"
-    optimizer_params: dict = field(default_factory=lambda: {'betas': [0.9, 0.998], 'weight_decay': 1e-6})
+    optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.998], "weight_decay": 1e-6})
     lr_scheduler: str = "NoamLR"
-    lr_scheduler_params: dict = field(default_factory=lambda:{"warmup_steps": 4000})
+    lr_scheduler_params: dict = field(default_factory=lambda: {"warmup_steps": 4000})
     grad_clip: float = 5.0
     lr: float = 1e-3

@@ -174,7 +174,7 @@ class BaseTTSConfig(BaseTrainingConfig):
     optimizer: str = MISSING
     optimizer_params: dict = MISSING
     # scheduler
-    lr_scheduler: str = ''
+    lr_scheduler: str = ""
     lr_scheduler_params: dict = field(default_factory=lambda: {})
     # testing
-    test_sentences: List[str] = field(default_factory=lambda:[])
+    test_sentences: List[str] = field(default_factory=lambda: [])
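Editor's note: `optimizer`, `optimizer_params`, `lr_scheduler` and `lr_scheduler_params` are consumed by the trainer when it builds the optimizer and scheduler (see the `get_scheduler` hunk in TTS/trainer.py above). A hedged sketch of how the `optimizer_params` dict is typically expanded; the construction below is illustrative, not the trainer's exact code:

    import torch
    import torch.nn as nn

    # toy module plus the default values from AlignTTSConfig above
    model = nn.Linear(10, 10)
    lr = 1e-4
    optimizer_params = {"betas": [0.9, 0.998], "weight_decay": 1e-6}

    # the params dict is unpacked into keyword arguments of the optimizer constructor
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, **optimizer_params)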
@@ -101,7 +101,7 @@ class SpeedySpeechConfig(BaseTTSConfig):

     # optimizer parameters
     optimizer: str = "RAdam"
-    optimizer_params: dict = field(default_factory=lambda: {'betas': [0.9, 0.998], 'weight_decay': 1e-6})
+    optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.998], "weight_decay": 1e-6})
     lr_scheduler: str = None
     lr_scheduler_params: dict = None
     lr: float = 1e-4
@ -118,10 +118,12 @@ class SpeedySpeechConfig(BaseTTSConfig):
|
||||||
r: int = 1 # DO NOT CHANGE
|
r: int = 1 # DO NOT CHANGE
|
||||||
|
|
||||||
# testing
|
# testing
|
||||||
test_sentences: List[str] = field(default_factory=lambda:[
|
test_sentences: List[str] = field(
|
||||||
"It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
|
default_factory=lambda: [
|
||||||
"Be a voice, not an echo.",
|
"It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
|
||||||
"I'm sorry Dave. I'm afraid I can't do that.",
|
"Be a voice, not an echo.",
|
||||||
"This cake is great. It's so delicious and moist.",
|
"I'm sorry Dave. I'm afraid I can't do that.",
|
||||||
"Prior to November 22, 1963."
|
"This cake is great. It's so delicious and moist.",
|
||||||
])
|
"Prior to November 22, 1963.",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
|
@@ -160,9 +160,9 @@ class TacotronConfig(BaseTTSConfig):

     # optimizer parameters
     optimizer: str = "RAdam"
-    optimizer_params: dict = field(default_factory=lambda: {'betas': [0.9, 0.998], 'weight_decay': 1e-6})
+    optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.998], "weight_decay": 1e-6})
     lr_scheduler: str = "NoamLR"
-    lr_scheduler_params: dict = field(default_factory=lambda:{"warmup_steps": 4000})
+    lr_scheduler_params: dict = field(default_factory=lambda: {"warmup_steps": 4000})
     lr: float = 1e-4
     grad_clip: float = 5.0
     seq_len_norm: bool = False
@ -178,13 +178,15 @@ class TacotronConfig(BaseTTSConfig):
|
||||||
ga_alpha: float = 5.0
|
ga_alpha: float = 5.0
|
||||||
|
|
||||||
# testing
|
# testing
|
||||||
test_sentences: List[str] = field(default_factory=lambda:[
|
test_sentences: List[str] = field(
|
||||||
"It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
|
default_factory=lambda: [
|
||||||
"Be a voice, not an echo.",
|
"It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
|
||||||
"I'm sorry Dave. I'm afraid I can't do that.",
|
"Be a voice, not an echo.",
|
||||||
"This cake is great. It's so delicious and moist.",
|
"I'm sorry Dave. I'm afraid I can't do that.",
|
||||||
"Prior to November 22, 1963."
|
"This cake is great. It's so delicious and moist.",
|
||||||
])
|
"Prior to November 22, 1963.",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
def check_values(self):
|
def check_values(self):
|
||||||
if self.gradual_training:
|
if self.gradual_training:
|
||||||
|
|
|
@@ -44,22 +44,18 @@ def load_meta_data(datasets, eval_split=True):
         preprocessor = _get_preprocessor_by_name(name)
         # load train set
         meta_data_train = preprocessor(root_path, meta_file_train)
-        print(
-            f" | > Found {len(meta_data_train)} files in {Path(root_path).resolve()}"
-        )
+        print(f" | > Found {len(meta_data_train)} files in {Path(root_path).resolve()}")
         # load evaluation split if set
         if eval_split:
             if meta_file_val:
                 meta_data_eval = preprocessor(root_path, meta_file_val)
             else:
-                meta_data_eval, meta_data_train = split_dataset(
-                    meta_data_train)
+                meta_data_eval, meta_data_train = split_dataset(meta_data_train)
             meta_data_eval_all += meta_data_eval
         meta_data_train_all += meta_data_train
         # load attention masks for duration predictor training
         if dataset.meta_file_attn_mask:
-            meta_data = dict(
-                load_attention_mask_meta_data(dataset["meta_file_attn_mask"]))
+            meta_data = dict(load_attention_mask_meta_data(dataset["meta_file_attn_mask"]))
             for idx, ins in enumerate(meta_data_train_all):
                 attn_file = meta_data[ins[1]].strip()
                 meta_data_train_all[idx].append(attn_file)
@@ -506,5 +506,10 @@ class AlignTTSLoss(nn.Module):
         spec_loss = self.spec_loss(decoder_output, decoder_target, decoder_output_lens)
         ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens)
         dur_loss = self.dur_loss(dur_output.unsqueeze(2), dur_target.unsqueeze(2), input_lens)
-        loss = self.spec_loss_alpha * spec_loss + self.ssim_alpha * ssim_loss + self.dur_loss_alpha * dur_loss + self.mdn_alpha * mdn_loss
+        loss = (
+            self.spec_loss_alpha * spec_loss
+            + self.ssim_alpha * ssim_loss
+            + self.dur_loss_alpha * dur_loss
+            + self.mdn_alpha * mdn_loss
+        )
         return {"loss": loss, "loss_l1": spec_loss, "loss_ssim": ssim_loss, "loss_dur": dur_loss, "mdn_loss": mdn_loss}
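Editor's note: written out, the AlignTTS objective being reformatted above is a weighted sum of its four terms, with the weights taken from the loss module's `*_alpha` attributes:

$$\mathcal{L} = \alpha_{\mathrm{spec}}\,\mathcal{L}_{\mathrm{spec}} + \alpha_{\mathrm{ssim}}\,\mathcal{L}_{\mathrm{ssim}} + \alpha_{\mathrm{dur}}\,\mathcal{L}_{\mathrm{dur}} + \alpha_{\mathrm{mdn}}\,\mathcal{L}_{\mathrm{mdn}}$$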
@ -72,19 +72,9 @@ class AlignTTS(nn.Module):
|
||||||
hidden_channels=256,
|
hidden_channels=256,
|
||||||
hidden_channels_dp=256,
|
hidden_channels_dp=256,
|
||||||
encoder_type="fftransformer",
|
encoder_type="fftransformer",
|
||||||
encoder_params={
|
encoder_params={"hidden_channels_ffn": 1024, "num_heads": 2, "num_layers": 6, "dropout_p": 0.1},
|
||||||
"hidden_channels_ffn": 1024,
|
|
||||||
"num_heads": 2,
|
|
||||||
"num_layers": 6,
|
|
||||||
"dropout_p": 0.1
|
|
||||||
},
|
|
||||||
decoder_type="fftransformer",
|
decoder_type="fftransformer",
|
||||||
decoder_params={
|
decoder_params={"hidden_channels_ffn": 1024, "num_heads": 2, "num_layers": 6, "dropout_p": 0.1},
|
||||||
"hidden_channels_ffn": 1024,
|
|
||||||
"num_heads": 2,
|
|
||||||
"num_layers": 6,
|
|
||||||
"dropout_p": 0.1
|
|
||||||
},
|
|
||||||
length_scale=1,
|
length_scale=1,
|
||||||
num_speakers=0,
|
num_speakers=0,
|
||||||
external_c=False,
|
external_c=False,
|
||||||
|
@ -93,14 +83,11 @@ class AlignTTS(nn.Module):
|
||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.phase = -1
|
self.phase = -1
|
||||||
self.length_scale = float(length_scale) if isinstance(
|
self.length_scale = float(length_scale) if isinstance(length_scale, int) else length_scale
|
||||||
length_scale, int) else length_scale
|
|
||||||
self.emb = nn.Embedding(num_chars, hidden_channels)
|
self.emb = nn.Embedding(num_chars, hidden_channels)
|
||||||
self.pos_encoder = PositionalEncoding(hidden_channels)
|
self.pos_encoder = PositionalEncoding(hidden_channels)
|
||||||
self.encoder = Encoder(hidden_channels, hidden_channels, encoder_type,
|
self.encoder = Encoder(hidden_channels, hidden_channels, encoder_type, encoder_params, c_in_channels)
|
||||||
encoder_params, c_in_channels)
|
self.decoder = Decoder(out_channels, hidden_channels, decoder_type, decoder_params)
|
||||||
self.decoder = Decoder(out_channels, hidden_channels, decoder_type,
|
|
||||||
decoder_params)
|
|
||||||
self.duration_predictor = DurationPredictor(hidden_channels_dp)
|
self.duration_predictor = DurationPredictor(hidden_channels_dp)
|
||||||
|
|
||||||
self.mod_layer = nn.Conv1d(hidden_channels, hidden_channels, 1)
|
self.mod_layer = nn.Conv1d(hidden_channels, hidden_channels, 1)
|
||||||
|
@@ -121,9 +108,9 @@ class AlignTTS(nn.Module):
         mu = mu.transpose(1, 2).unsqueeze(2)  # [B, T2, 1, D]
         log_sigma = log_sigma.transpose(1, 2).unsqueeze(2)  # [B, T2, 1, D]
         expanded_y, expanded_mu = torch.broadcast_tensors(y, mu)
-        exponential = -0.5 * torch.mean(torch._C._nn.mse_loss(
-            expanded_y, expanded_mu, 0) / torch.pow(log_sigma.exp(), 2),
-                                        dim=-1)  # B, L, T
+        exponential = -0.5 * torch.mean(
+            torch._C._nn.mse_loss(expanded_y, expanded_mu, 0) / torch.pow(log_sigma.exp(), 2), dim=-1
+        )  # B, L, T
         logp = exponential - 0.5 * log_sigma.mean(dim=-1)
         return logp

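Editor's note: with the unreduced `mse_loss` producing the element-wise squared error $(y_d-\mu_d)^2$, the reformatted expression above evaluates, for each (text position, frame) pair,

$$\texttt{logp} = -\frac{1}{D}\sum_{d=1}^{D}\left[\frac{(y_d-\mu_d)^2}{2\,\sigma_d^{2}} + \frac{1}{2}\log\sigma_d\right],$$

which plays the role of a diagonal-Gaussian log-likelihood score (additive constants dropped, channel sums replaced by means) that the mixture-density block feeds into the monotonic alignment search.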
@ -157,9 +144,7 @@ class AlignTTS(nn.Module):
|
||||||
[1, 0, 0, 0, 0, 0, 0]]
|
[1, 0, 0, 0, 0, 0, 0]]
|
||||||
"""
|
"""
|
||||||
attn = self.convert_dr_to_align(dr, x_mask, y_mask)
|
attn = self.convert_dr_to_align(dr, x_mask, y_mask)
|
||||||
o_en_ex = torch.matmul(
|
o_en_ex = torch.matmul(attn.squeeze(1).transpose(1, 2), en.transpose(1, 2)).transpose(1, 2)
|
||||||
attn.squeeze(1).transpose(1, 2), en.transpose(1,
|
|
||||||
2)).transpose(1, 2)
|
|
||||||
return o_en_ex, attn
|
return o_en_ex, attn
|
||||||
|
|
||||||
def format_durations(self, o_dr_log, x_mask):
|
def format_durations(self, o_dr_log, x_mask):
|
||||||
|
@ -193,8 +178,7 @@ class AlignTTS(nn.Module):
|
||||||
x_emb = torch.transpose(x_emb, 1, -1)
|
x_emb = torch.transpose(x_emb, 1, -1)
|
||||||
|
|
||||||
# compute sequence masks
|
# compute sequence masks
|
||||||
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]),
|
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype)
|
||||||
1).to(x.dtype)
|
|
||||||
|
|
||||||
# encoder pass
|
# encoder pass
|
||||||
o_en = self.encoder(x_emb, x_mask)
|
o_en = self.encoder(x_emb, x_mask)
|
||||||
|
@ -207,8 +191,7 @@ class AlignTTS(nn.Module):
|
||||||
return o_en, o_en_dp, x_mask, g
|
return o_en, o_en_dp, x_mask, g
|
||||||
|
|
||||||
def _forward_decoder(self, o_en, o_en_dp, dr, x_mask, y_lengths, g):
|
def _forward_decoder(self, o_en, o_en_dp, dr, x_mask, y_lengths, g):
|
||||||
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None),
|
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en_dp.dtype)
|
||||||
1).to(o_en_dp.dtype)
|
|
||||||
# expand o_en with durations
|
# expand o_en with durations
|
||||||
o_en_ex, attn = self.expand_encoder_outputs(o_en, dr, x_mask, y_mask)
|
o_en_ex, attn = self.expand_encoder_outputs(o_en, dr, x_mask, y_mask)
|
||||||
# positional encoding
|
# positional encoding
|
||||||
|
@ -224,13 +207,13 @@ class AlignTTS(nn.Module):
|
||||||
def _forward_mdn(self, o_en, y, y_lengths, x_mask):
|
def _forward_mdn(self, o_en, y, y_lengths, x_mask):
|
||||||
# MAS potentials and alignment
|
# MAS potentials and alignment
|
||||||
mu, log_sigma = self.mdn_block(o_en)
|
mu, log_sigma = self.mdn_block(o_en)
|
||||||
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None),
|
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en.dtype)
|
||||||
1).to(o_en.dtype)
|
dr_mas, logp = self.compute_align_path(mu, log_sigma, y, x_mask, y_mask)
|
||||||
dr_mas, logp = self.compute_align_path(mu, log_sigma, y, x_mask,
|
|
||||||
y_mask)
|
|
||||||
return dr_mas, mu, log_sigma, logp
|
return dr_mas, mu, log_sigma, logp
|
||||||
|
|
||||||
def forward(self, x, x_lengths, y, y_lengths, cond_input={"x_vectors": None}, phase=None): # pylint: disable=unused-argument
|
def forward(
|
||||||
|
self, x, x_lengths, y, y_lengths, cond_input={"x_vectors": None}, phase=None
|
||||||
|
): # pylint: disable=unused-argument
|
||||||
"""
|
"""
|
||||||
Shapes:
|
Shapes:
|
||||||
x: [B, T_max]
|
x: [B, T_max]
|
||||||
|
@ -240,83 +223,58 @@ class AlignTTS(nn.Module):
|
||||||
g: [B, C]
|
g: [B, C]
|
||||||
"""
|
"""
|
||||||
y = y.transpose(1, 2)
|
y = y.transpose(1, 2)
|
||||||
g = cond_input['x_vectors'] if 'x_vectors' in cond_input else None
|
g = cond_input["x_vectors"] if "x_vectors" in cond_input else None
|
||||||
o_de, o_dr_log, dr_mas_log, attn, mu, log_sigma, logp = None, None, None, None, None, None, None
|
o_de, o_dr_log, dr_mas_log, attn, mu, log_sigma, logp = None, None, None, None, None, None, None
|
||||||
if phase == 0:
|
if phase == 0:
|
||||||
# train encoder and MDN
|
# train encoder and MDN
|
||||||
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
|
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
|
||||||
dr_mas, mu, log_sigma, logp = self._forward_mdn(
|
dr_mas, mu, log_sigma, logp = self._forward_mdn(o_en, y, y_lengths, x_mask)
|
||||||
o_en, y, y_lengths, x_mask)
|
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en_dp.dtype)
|
||||||
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None),
|
|
||||||
1).to(o_en_dp.dtype)
|
|
||||||
attn = self.convert_dr_to_align(dr_mas, x_mask, y_mask)
|
attn = self.convert_dr_to_align(dr_mas, x_mask, y_mask)
|
||||||
elif phase == 1:
|
elif phase == 1:
|
||||||
# train decoder
|
# train decoder
|
||||||
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
|
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
|
||||||
dr_mas, _, _, _ = self._forward_mdn(o_en, y, y_lengths, x_mask)
|
dr_mas, _, _, _ = self._forward_mdn(o_en, y, y_lengths, x_mask)
|
||||||
o_de, attn = self._forward_decoder(o_en.detach(),
|
o_de, attn = self._forward_decoder(o_en.detach(), o_en_dp.detach(), dr_mas.detach(), x_mask, y_lengths, g=g)
|
||||||
o_en_dp.detach(),
|
|
||||||
dr_mas.detach(),
|
|
||||||
x_mask,
|
|
||||||
y_lengths,
|
|
||||||
g=g)
|
|
||||||
elif phase == 2:
|
elif phase == 2:
|
||||||
# train the whole except duration predictor
|
# train the whole except duration predictor
|
||||||
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
|
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
|
||||||
dr_mas, mu, log_sigma, logp = self._forward_mdn(
|
dr_mas, mu, log_sigma, logp = self._forward_mdn(o_en, y, y_lengths, x_mask)
|
||||||
o_en, y, y_lengths, x_mask)
|
o_de, attn = self._forward_decoder(o_en, o_en_dp, dr_mas, x_mask, y_lengths, g=g)
|
||||||
o_de, attn = self._forward_decoder(o_en,
|
|
||||||
o_en_dp,
|
|
||||||
dr_mas,
|
|
||||||
x_mask,
|
|
||||||
y_lengths,
|
|
||||||
g=g)
|
|
||||||
elif phase == 3:
|
elif phase == 3:
|
||||||
# train duration predictor
|
# train duration predictor
|
||||||
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
|
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
|
||||||
o_dr_log = self.duration_predictor(x, x_mask)
|
o_dr_log = self.duration_predictor(x, x_mask)
|
||||||
dr_mas, mu, log_sigma, logp = self._forward_mdn(
|
dr_mas, mu, log_sigma, logp = self._forward_mdn(o_en, y, y_lengths, x_mask)
|
||||||
o_en, y, y_lengths, x_mask)
|
o_de, attn = self._forward_decoder(o_en, o_en_dp, dr_mas, x_mask, y_lengths, g=g)
|
||||||
o_de, attn = self._forward_decoder(o_en,
|
|
||||||
o_en_dp,
|
|
||||||
dr_mas,
|
|
||||||
x_mask,
|
|
||||||
y_lengths,
|
|
||||||
g=g)
|
|
||||||
o_dr_log = o_dr_log.squeeze(1)
|
o_dr_log = o_dr_log.squeeze(1)
|
||||||
else:
|
else:
|
||||||
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
|
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
|
||||||
o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask)
|
o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask)
|
||||||
dr_mas, mu, log_sigma, logp = self._forward_mdn(
|
dr_mas, mu, log_sigma, logp = self._forward_mdn(o_en, y, y_lengths, x_mask)
|
||||||
o_en, y, y_lengths, x_mask)
|
o_de, attn = self._forward_decoder(o_en, o_en_dp, dr_mas, x_mask, y_lengths, g=g)
|
||||||
o_de, attn = self._forward_decoder(o_en,
|
|
||||||
o_en_dp,
|
|
||||||
dr_mas,
|
|
||||||
x_mask,
|
|
||||||
y_lengths,
|
|
||||||
g=g)
|
|
||||||
o_dr_log = o_dr_log.squeeze(1)
|
o_dr_log = o_dr_log.squeeze(1)
|
||||||
dr_mas_log = torch.log(dr_mas + 1).squeeze(1)
|
dr_mas_log = torch.log(dr_mas + 1).squeeze(1)
|
||||||
outputs = {
|
outputs = {
|
||||||
'model_outputs': o_de.transpose(1, 2),
|
"model_outputs": o_de.transpose(1, 2),
|
||||||
'alignments': attn,
|
"alignments": attn,
|
||||||
'durations_log': o_dr_log,
|
"durations_log": o_dr_log,
|
||||||
'durations_mas_log': dr_mas_log,
|
"durations_mas_log": dr_mas_log,
|
||||||
'mu': mu,
|
"mu": mu,
|
||||||
'log_sigma': log_sigma,
|
"log_sigma": log_sigma,
|
||||||
'logp': logp
|
"logp": logp,
|
||||||
}
|
}
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def inference(self, x, cond_input={'x_vectors': None}): # pylint: disable=unused-argument
|
def inference(self, x, cond_input={"x_vectors": None}): # pylint: disable=unused-argument
|
||||||
"""
|
"""
|
||||||
Shapes:
|
Shapes:
|
||||||
x: [B, T_max]
|
x: [B, T_max]
|
||||||
x_lengths: [B]
|
x_lengths: [B]
|
||||||
g: [B, C]
|
g: [B, C]
|
||||||
"""
|
"""
|
||||||
g = cond_input['x_vectors'] if 'x_vectors' in cond_input else None
|
g = cond_input["x_vectors"] if "x_vectors" in cond_input else None
|
||||||
x_lengths = torch.tensor(x.shape[1:2]).to(x.device)
|
x_lengths = torch.tensor(x.shape[1:2]).to(x.device)
|
||||||
# pad input to prevent dropping the last word
|
# pad input to prevent dropping the last word
|
||||||
# x = torch.nn.functional.pad(x, pad=(0, 5), mode='constant', value=0)
|
# x = torch.nn.functional.pad(x, pad=(0, 5), mode='constant', value=0)
|
||||||
|
@ -326,46 +284,40 @@ class AlignTTS(nn.Module):
|
||||||
# duration predictor pass
|
# duration predictor pass
|
||||||
o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1)
|
o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1)
|
||||||
y_lengths = o_dr.sum(1)
|
y_lengths = o_dr.sum(1)
|
||||||
o_de, attn = self._forward_decoder(o_en,
|
o_de, attn = self._forward_decoder(o_en, o_en_dp, o_dr, x_mask, y_lengths, g=g)
|
||||||
o_en_dp,
|
outputs = {"model_outputs": o_de.transpose(1, 2), "alignments": attn}
|
||||||
o_dr,
|
|
||||||
x_mask,
|
|
||||||
y_lengths,
|
|
||||||
g=g)
|
|
||||||
outputs = {'model_outputs': o_de.transpose(1, 2), 'alignments': attn}
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
def train_step(self, batch: dict, criterion: nn.Module):
|
def train_step(self, batch: dict, criterion: nn.Module):
|
||||||
text_input = batch['text_input']
|
text_input = batch["text_input"]
|
||||||
text_lengths = batch['text_lengths']
|
text_lengths = batch["text_lengths"]
|
||||||
mel_input = batch['mel_input']
|
mel_input = batch["mel_input"]
|
||||||
mel_lengths = batch['mel_lengths']
|
mel_lengths = batch["mel_lengths"]
|
||||||
x_vectors = batch['x_vectors']
|
x_vectors = batch["x_vectors"]
|
||||||
speaker_ids = batch['speaker_ids']
|
speaker_ids = batch["speaker_ids"]
|
||||||
|
|
||||||
cond_input = {'x_vectors': x_vectors, 'speaker_ids': speaker_ids}
|
cond_input = {"x_vectors": x_vectors, "speaker_ids": speaker_ids}
|
||||||
outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, cond_input, self.phase)
|
outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, cond_input, self.phase)
|
||||||
loss_dict = criterion(
|
loss_dict = criterion(
|
||||||
outputs['logp'],
|
outputs["logp"],
|
||||||
outputs['model_outputs'],
|
outputs["model_outputs"],
|
||||||
mel_input,
|
mel_input,
|
||||||
mel_lengths,
|
mel_lengths,
|
||||||
outputs['durations_log'],
|
outputs["durations_log"],
|
||||||
outputs['durations_mas_log'],
|
outputs["durations_mas_log"],
|
||||||
text_lengths,
|
text_lengths,
|
||||||
phase=self.phase,
|
phase=self.phase,
|
||||||
)
|
)
|
||||||
|
|
||||||
# compute alignment error (the lower the better )
|
# compute alignment error (the lower the better )
|
||||||
align_error = 1 - alignment_diagonal_score(outputs['alignments'],
|
align_error = 1 - alignment_diagonal_score(outputs["alignments"], binary=True)
|
||||||
binary=True)
|
|
||||||
loss_dict["align_error"] = align_error
|
loss_dict["align_error"] = align_error
|
||||||
return outputs, loss_dict
|
return outputs, loss_dict
|
||||||
|
|
||||||
def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict):
|
def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict):
|
||||||
model_outputs = outputs['model_outputs']
|
model_outputs = outputs["model_outputs"]
|
||||||
alignments = outputs['alignments']
|
alignments = outputs["alignments"]
|
||||||
mel_input = batch['mel_input']
|
mel_input = batch["mel_input"]
|
||||||
|
|
||||||
pred_spec = model_outputs[0].data.cpu().numpy()
|
pred_spec = model_outputs[0].data.cpu().numpy()
|
||||||
gt_spec = mel_input[0].data.cpu().numpy()
|
gt_spec = mel_input[0].data.cpu().numpy()
|
||||||
|
@ -387,7 +339,9 @@ class AlignTTS(nn.Module):
|
||||||
def eval_log(self, ap: AudioProcessor, batch: dict, outputs: dict):
|
def eval_log(self, ap: AudioProcessor, batch: dict, outputs: dict):
|
||||||
return self.train_log(ap, batch, outputs)
|
return self.train_log(ap, batch, outputs)
|
||||||
|
|
||||||
def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin
|
def load_checkpoint(
|
||||||
|
self, config, checkpoint_path, eval=False
|
||||||
|
): # pylint: disable=unused-argument, redefined-builtin
|
||||||
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
|
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
|
||||||
self.load_state_dict(state["model"])
|
self.load_state_dict(state["model"])
|
||||||
if eval:
|
if eval:
|
||||||
|
|
|
@ -38,6 +38,7 @@ class GlowTTS(nn.Module):
|
||||||
encoder_params (dict): encoder module parameters.
|
encoder_params (dict): encoder module parameters.
|
||||||
speaker_embedding_dim (int): channels of external speaker embedding vectors.
|
speaker_embedding_dim (int): channels of external speaker embedding vectors.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
num_chars,
|
num_chars,
|
||||||
|
@ -132,17 +133,17 @@ class GlowTTS(nn.Module):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def compute_outputs(attn, o_mean, o_log_scale, x_mask):
|
def compute_outputs(attn, o_mean, o_log_scale, x_mask):
|
||||||
# compute final values with the computed alignment
|
# compute final values with the computed alignment
|
||||||
y_mean = torch.matmul(
|
y_mean = torch.matmul(attn.squeeze(1).transpose(1, 2), o_mean.transpose(1, 2)).transpose(
|
||||||
attn.squeeze(1).transpose(1, 2), o_mean.transpose(1, 2)).transpose(
|
1, 2
|
||||||
1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
|
) # [b, t', t], [b, t, d] -> [b, d, t']
|
||||||
y_log_scale = torch.matmul(
|
y_log_scale = torch.matmul(attn.squeeze(1).transpose(1, 2), o_log_scale.transpose(1, 2)).transpose(
|
||||||
attn.squeeze(1).transpose(1, 2), o_log_scale.transpose(
|
1, 2
|
||||||
1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
|
) # [b, t', t], [b, t, d] -> [b, d, t']
|
||||||
# compute total duration with adjustment
|
# compute total duration with adjustment
|
||||||
o_attn_dur = torch.log(1 + torch.sum(attn, -1)) * x_mask
|
o_attn_dur = torch.log(1 + torch.sum(attn, -1)) * x_mask
|
||||||
return y_mean, y_log_scale, o_attn_dur
|
return y_mean, y_log_scale, o_attn_dur
|
||||||
|
|
||||||
def forward(self, x, x_lengths, y, y_lengths=None, cond_input={'x_vectors':None}):
|
def forward(self, x, x_lengths, y, y_lengths=None, cond_input={"x_vectors": None}):
|
||||||
"""
|
"""
|
||||||
Shapes:
|
Shapes:
|
||||||
x: [B, T]
|
x: [B, T]
|
||||||
|
@ -154,7 +155,7 @@ class GlowTTS(nn.Module):
|
||||||
y_max_length = y.size(2)
|
y_max_length = y.size(2)
|
||||||
y = y.transpose(1, 2)
|
y = y.transpose(1, 2)
|
||||||
# norm speaker embeddings
|
# norm speaker embeddings
|
||||||
g = cond_input['x_vectors']
|
g = cond_input["x_vectors"]
|
||||||
if g is not None:
|
if g is not None:
|
||||||
if self.speaker_embedding_dim:
|
if self.speaker_embedding_dim:
|
||||||
g = F.normalize(g).unsqueeze(-1)
|
g = F.normalize(g).unsqueeze(-1)
|
||||||
|
@ -162,54 +163,38 @@ class GlowTTS(nn.Module):
|
||||||
g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h, 1]
|
g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h, 1]
|
||||||
|
|
||||||
# embedding pass
|
# embedding pass
|
||||||
o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x,
|
o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g)
|
||||||
x_lengths,
|
|
||||||
g=g)
|
|
||||||
# drop redisual frames wrt num_squeeze and set y_lengths.
|
# drop redisual frames wrt num_squeeze and set y_lengths.
|
||||||
y, y_lengths, y_max_length, attn = self.preprocess(
|
y, y_lengths, y_max_length, attn = self.preprocess(y, y_lengths, y_max_length, None)
|
||||||
y, y_lengths, y_max_length, None)
|
|
||||||
# create masks
|
# create masks
|
||||||
y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length),
|
y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(x_mask.dtype)
|
||||||
1).to(x_mask.dtype)
|
|
||||||
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
|
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
|
||||||
# decoder pass
|
# decoder pass
|
||||||
z, logdet = self.decoder(y, y_mask, g=g, reverse=False)
|
z, logdet = self.decoder(y, y_mask, g=g, reverse=False)
|
||||||
# find the alignment path
|
# find the alignment path
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
o_scale = torch.exp(-2 * o_log_scale)
|
o_scale = torch.exp(-2 * o_log_scale)
|
||||||
logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale,
|
logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale, [1]).unsqueeze(-1) # [b, t, 1]
|
||||||
[1]).unsqueeze(-1) # [b, t, 1]
|
logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * (z ** 2)) # [b, t, d] x [b, d, t'] = [b, t, t']
|
||||||
logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 *
|
logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2), z) # [b, t, d] x [b, d, t'] = [b, t, t']
|
||||||
(z**2)) # [b, t, d] x [b, d, t'] = [b, t, t']
|
logp4 = torch.sum(-0.5 * (o_mean ** 2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1]
|
||||||
logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2),
|
|
||||||
z) # [b, t, d] x [b, d, t'] = [b, t, t']
|
|
||||||
logp4 = torch.sum(-0.5 * (o_mean**2) * o_scale,
|
|
||||||
[1]).unsqueeze(-1) # [b, t, 1]
|
|
||||||
logp = logp1 + logp2 + logp3 + logp4 # [b, t, t']
|
logp = logp1 + logp2 + logp3 + logp4 # [b, t, t']
|
||||||
attn = maximum_path(logp,
|
attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()
|
||||||
attn_mask.squeeze(1)).unsqueeze(1).detach()
|
y_mean, y_log_scale, o_attn_dur = self.compute_outputs(attn, o_mean, o_log_scale, x_mask)
|
||||||
y_mean, y_log_scale, o_attn_dur = self.compute_outputs(
|
|
||||||
attn, o_mean, o_log_scale, x_mask)
|
|
||||||
attn = attn.squeeze(1).permute(0, 2, 1)
|
attn = attn.squeeze(1).permute(0, 2, 1)
|
||||||
outputs = {
|
outputs = {
|
||||||
'model_outputs': z,
|
"model_outputs": z,
|
||||||
'logdet': logdet,
|
"logdet": logdet,
|
||||||
'y_mean': y_mean,
|
"y_mean": y_mean,
|
||||||
'y_log_scale': y_log_scale,
|
"y_log_scale": y_log_scale,
|
||||||
'alignments': attn,
|
"alignments": attn,
|
||||||
'durations_log': o_dur_log,
|
"durations_log": o_dur_log,
|
||||||
'total_durations_log': o_attn_dur
|
"total_durations_log": o_attn_dur,
|
||||||
}
|
}
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def inference_with_MAS(self,
|
def inference_with_MAS(self, x, x_lengths, y=None, y_lengths=None, attn=None, g=None):
|
||||||
x,
|
|
||||||
x_lengths,
|
|
||||||
y=None,
|
|
||||||
y_lengths=None,
|
|
||||||
attn=None,
|
|
||||||
g=None):
|
|
||||||
"""
|
"""
|
||||||
It's similar to the teacher forcing in Tacotron.
|
It's similar to the teacher forcing in Tacotron.
|
||||||
It was proposed in: https://arxiv.org/abs/2104.05557
|
It was proposed in: https://arxiv.org/abs/2104.05557
|
||||||
|
@@ -229,33 +214,24 @@ class GlowTTS(nn.Module):
g = F.normalize(self.emb_g(g)).unsqueeze(-1)  # [b, h, 1]

# embedding pass
o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g)
# drop redisual frames wrt num_squeeze and set y_lengths.
y, y_lengths, y_max_length, attn = self.preprocess(y, y_lengths, y_max_length, None)
# create masks
y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(x_mask.dtype)
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
# decoder pass
z, logdet = self.decoder(y, y_mask, g=g, reverse=False)
# find the alignment path between z and encoder output
o_scale = torch.exp(-2 * o_log_scale)
logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale, [1]).unsqueeze(-1)  # [b, t, 1]
logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * (z ** 2))  # [b, t, d] x [b, d, t'] = [b, t, t']
logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2), z)  # [b, t, d] x [b, d, t'] = [b, t, t']
logp4 = torch.sum(-0.5 * (o_mean ** 2) * o_scale, [1]).unsqueeze(-1)  # [b, t, 1]
logp = logp1 + logp2 + logp3 + logp4  # [b, t, t']
attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()

y_mean, y_log_scale, o_attn_dur = self.compute_outputs(attn, o_mean, o_log_scale, x_mask)
attn = attn.squeeze(1).permute(0, 2, 1)

# get predited aligned distribution

@@ -264,13 +240,13 @@ class GlowTTS(nn.Module):
# reverse the decoder and predict using the aligned distribution
y, logdet = self.decoder(z, y_mask, g=g, reverse=True)
outputs = {
"model_outputs": y,
"logdet": logdet,
"y_mean": y_mean,
"y_log_scale": y_log_scale,
"alignments": attn,
"durations_log": o_dur_log,
"total_durations_log": o_attn_dur,
}
return outputs

@@ -290,8 +266,7 @@ class GlowTTS(nn.Module):
else:
g = F.normalize(self.emb_g(g)).unsqueeze(-1)  # [b, h, 1]

y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(y.dtype)

# decoder pass
z, logdet = self.decoder(y, y_mask, g=g, reverse=False)
@@ -310,37 +285,31 @@ class GlowTTS(nn.Module):
g = F.normalize(self.emb_g(g)).unsqueeze(-1)  # [b, h]

# embedding pass
o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g)
# compute output durations
w = (torch.exp(o_dur_log) - 1) * x_mask * self.length_scale
w_ceil = torch.ceil(w)
y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
y_max_length = None
# compute masks
y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(x_mask.dtype)
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
# compute attention mask
attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1)
y_mean, y_log_scale, o_attn_dur = self.compute_outputs(attn, o_mean, o_log_scale, x_mask)

z = (y_mean + torch.exp(y_log_scale) * torch.randn_like(y_mean) * self.inference_noise_scale) * y_mask
# decoder pass
y, logdet = self.decoder(z, y_mask, g=g, reverse=True)
attn = attn.squeeze(1).permute(0, 2, 1)
outputs = {
"model_outputs": y,
"logdet": logdet,
"y_mean": y_mean,
"y_log_scale": y_log_scale,
"alignments": attn,
"durations_log": o_dur_log,
"total_durations_log": o_attn_dur,
}
return outputs
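In the inference path above the predicted durations w_ceil are turned into a hard monotonic alignment by generate_path before z is sampled. A small sketch of that expansion on a single sentence (the helper below is written for illustration and is only assumed to match what generate_path produces):

    import torch

    def durations_to_alignment(durations: torch.Tensor, max_frames: int) -> torch.Tensor:
        """durations: [T_text] integer frames per token -> [T_text, max_frames] 0/1 matrix."""
        ends = torch.cumsum(durations, dim=0)    # cumulative end frame of each token
        starts = ends - durations                # start frame of each token
        frames = torch.arange(max_frames)
        return ((frames >= starts.unsqueeze(1)) & (frames < ends.unsqueeze(1))).float()

    w_ceil = torch.tensor([2, 1, 3])             # e.g. ceil of predicted durations
    attn = durations_to_alignment(w_ceil, int(w_ceil.sum()))
    # attn[i, j] == 1 exactly when output frame j is generated from input token i

The inference_noise_scale factor on the Gaussian noise then acts as a sampling temperature for the reversed flow.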
@@ -351,32 +320,34 @@ class GlowTTS(nn.Module):
batch (dict): [description]
criterion (nn.Module): [description]
"""
text_input = batch["text_input"]
text_lengths = batch["text_lengths"]
mel_input = batch["mel_input"]
mel_lengths = batch["mel_lengths"]
x_vectors = batch["x_vectors"]

outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, cond_input={"x_vectors": x_vectors})

loss_dict = criterion(
outputs["model_outputs"],
outputs["y_mean"],
outputs["y_log_scale"],
outputs["logdet"],
mel_lengths,
outputs["durations_log"],
outputs["total_durations_log"],
text_lengths,
)

# compute alignment error (the lower the better )
align_error = 1 - alignment_diagonal_score(outputs["alignments"], binary=True)
loss_dict["align_error"] = align_error
return outputs, loss_dict

def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict):
model_outputs = outputs["model_outputs"]
alignments = outputs["alignments"]
mel_input = batch["mel_input"]

pred_spec = model_outputs[0].data.cpu().numpy()
gt_spec = mel_input[0].data.cpu().numpy()
@@ -400,8 +371,7 @@ class GlowTTS(nn.Module):

def preprocess(self, y, y_lengths, y_max_length, attn=None):
if y_max_length is not None:
y_max_length = (y_max_length // self.num_squeeze) * self.num_squeeze
y = y[:, :, :y_max_length]
if attn is not None:
attn = attn[:, :, :, :y_max_length]

@@ -411,7 +381,9 @@ class GlowTTS(nn.Module):
def store_inverse(self):
self.decoder.store_inverse()

def load_checkpoint(
self, config, checkpoint_path, eval=False
):  # pylint: disable=unused-argument, redefined-builtin
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"])
if eval:
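A toy example (numbers invented) of the trimming done in preprocess above: mel lengths are floored to a multiple of num_squeeze so the flow's squeeze step can fold the time axis evenly.

    num_squeeze = 2
    y_max_length = 157
    y_max_length = (y_max_length // num_squeeze) * num_squeeze  # 156, one trailing frame dropped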
@@ -49,12 +49,7 @@ class SpeedySpeech(nn.Module):
positional_encoding=True,
length_scale=1,
encoder_type="residual_conv_bn",
encoder_params={"kernel_size": 4, "dilations": 4 * [1, 2, 4] + [1], "num_conv_blocks": 2, "num_res_blocks": 13},
decoder_type="residual_conv_bn",
decoder_params={
"kernel_size": 4,

@@ -68,17 +63,13 @@ class SpeedySpeech(nn.Module):
):

super().__init__()
self.length_scale = float(length_scale) if isinstance(length_scale, int) else length_scale
self.emb = nn.Embedding(num_chars, hidden_channels)
self.encoder = Encoder(hidden_channels, hidden_channels, encoder_type, encoder_params, c_in_channels)
if positional_encoding:
self.pos_encoder = PositionalEncoding(hidden_channels)
self.decoder = Decoder(out_channels, hidden_channels, decoder_type, decoder_params)
self.duration_predictor = DurationPredictor(hidden_channels + c_in_channels)

if num_speakers > 1 and not external_c:
# speaker embedding layer
@@ -105,9 +96,7 @@ class SpeedySpeech(nn.Module):
"""
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
attn = generate_path(dr, attn_mask.squeeze(1)).to(en.dtype)
o_en_ex = torch.matmul(attn.squeeze(1).transpose(1, 2), en.transpose(1, 2)).transpose(1, 2)
return o_en_ex, attn
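The matmul in expand_encoder_outputs above is a dense way of repeating each encoder frame for as many output frames as its duration says. A minimal check of that equivalence (toy tensors, illustration only, not the project's generate_path):

    import torch

    b, c, t_text = 1, 4, 3
    en = torch.randn(b, c, t_text)                 # encoder outputs [B, C, T_text]
    dr = torch.tensor([[2, 1, 3]])                 # durations per token [B, T_text]
    t_mel = int(dr.sum())

    # hard duration alignment [B, T_text, T_mel]
    ends = torch.cumsum(dr, 1)
    starts = ends - dr
    frames = torch.arange(t_mel)
    attn = ((frames >= starts.unsqueeze(-1)) & (frames < ends.unsqueeze(-1))).float()

    o_en_ex = torch.matmul(attn.transpose(1, 2), en.transpose(1, 2)).transpose(1, 2)  # [B, C, T_mel]
    assert torch.allclose(o_en_ex, torch.repeat_interleave(en, dr[0], dim=2))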
def format_durations(self, o_dr_log, x_mask):

@@ -141,8 +130,7 @@ class SpeedySpeech(nn.Module):
x_emb = torch.transpose(x_emb, 1, -1)

# compute sequence masks
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype)

# encoder pass
o_en = self.encoder(x_emb, x_mask)

@@ -155,8 +143,7 @@ class SpeedySpeech(nn.Module):
return o_en, o_en_dp, x_mask, g

def _forward_decoder(self, o_en, o_en_dp, dr, x_mask, y_lengths, g):
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en_dp.dtype)
# expand o_en with durations
o_en_ex, attn = self.expand_encoder_outputs(o_en, dr, x_mask, y_mask)
# positional encoding
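Both x_mask and y_mask above come from sequence_mask. A common formulation of that helper (assumed here to match the project's behaviour; not copied from the repo):

    import torch

    def sequence_mask(lengths: torch.Tensor, max_len: int = None) -> torch.Tensor:
        if max_len is None:
            max_len = int(lengths.max())
        positions = torch.arange(max_len, device=lengths.device)
        return positions.unsqueeze(0) < lengths.unsqueeze(1)   # [B, max_len] boolean

    mask = sequence_mask(torch.tensor([3, 5]), 5)
    # tensor([[ True,  True,  True, False, False],
    #         [ True,  True,  True,  True,  True]])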
@@ -169,15 +156,9 @@ class SpeedySpeech(nn.Module):
o_de = self.decoder(o_en_ex, y_mask, g=g)
return o_de, attn.transpose(1, 2)

def forward(
self, x, x_lengths, y_lengths, dr, cond_input={"x_vectors": None, "speaker_ids": None}
):  # pylint: disable=unused-argument
"""
TODO: speaker embedding for speaker_ids
Shapes:

@@ -187,91 +168,68 @@ class SpeedySpeech(nn.Module):
dr: [B, T_max]
g: [B, C]
"""
g = cond_input["x_vectors"] if "x_vectors" in cond_input else None
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask)
o_de, attn = self._forward_decoder(o_en, o_en_dp, dr, x_mask, y_lengths, g=g)
outputs = {"model_outputs": o_de.transpose(1, 2), "durations_log": o_dr_log.squeeze(1), "alignments": attn}
return outputs
def inference(self, x, cond_input={"x_vectors": None, "speaker_ids": None}):  # pylint: disable=unused-argument
"""
Shapes:
x: [B, T_max]
x_lengths: [B]
g: [B, C]
"""
g = cond_input["x_vectors"] if "x_vectors" in cond_input else None
x_lengths = torch.tensor(x.shape[1:2]).to(x.device)
# input sequence should be greated than the max convolution size
inference_padding = 5
if x.shape[1] < 13:
inference_padding += 13 - x.shape[1]
# pad input to prevent dropping the last word
x = torch.nn.functional.pad(x, pad=(0, inference_padding), mode="constant", value=0)
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
# duration predictor pass
o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask)
o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1)
y_lengths = o_dr.sum(1)
o_de, attn = self._forward_decoder(o_en, o_en_dp, o_dr, x_mask, y_lengths, g=g)
outputs = {"model_outputs": o_de.transpose(1, 2), "alignments": attn, "durations_log": None}
return outputs
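A toy check (numbers invented) of the padding rule in inference above: inputs shorter than the largest convolution span (13, per the check above) are right-padded so the trailing tokens are not lost to the convolutions, on top of the fixed 5-token padding.

    import torch
    import torch.nn.functional as F

    x = torch.randint(0, 32, (1, 9))          # only 9 tokens
    inference_padding = 5
    if x.shape[1] < 13:
        inference_padding += 13 - x.shape[1]  # 5 + 4 = 9
    x = F.pad(x, pad=(0, inference_padding), mode="constant", value=0)
    print(x.shape)                            # torch.Size([1, 18])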
def train_step(self, batch: dict, criterion: nn.Module):
text_input = batch["text_input"]
text_lengths = batch["text_lengths"]
mel_input = batch["mel_input"]
mel_lengths = batch["mel_lengths"]
x_vectors = batch["x_vectors"]
speaker_ids = batch["speaker_ids"]
durations = batch["durations"]

cond_input = {"x_vectors": x_vectors, "speaker_ids": speaker_ids}
outputs = self.forward(text_input, text_lengths, mel_lengths, durations, cond_input)

# compute loss
loss_dict = criterion(
outputs["model_outputs"],
mel_input,
mel_lengths,
outputs["durations_log"],
torch.log(1 + durations),
text_lengths,
)

# compute alignment error (the lower the better )
align_error = 1 - alignment_diagonal_score(outputs["alignments"], binary=True)
loss_dict["align_error"] = align_error
return outputs, loss_dict

def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict):
model_outputs = outputs["model_outputs"]
alignments = outputs["alignments"]
mel_input = batch["mel_input"]

pred_spec = model_outputs[0].data.cpu().numpy()
gt_spec = mel_input[0].data.cpu().numpy()

@@ -293,7 +251,9 @@ class SpeedySpeech(nn.Module):
def eval_log(self, ap: AudioProcessor, batch: dict, outputs: dict):
return self.train_log(ap, batch, outputs)

def load_checkpoint(
self, config, checkpoint_path, eval=False
):  # pylint: disable=unused-argument, redefined-builtin
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"])
if eval:
@@ -50,6 +50,7 @@ class Tacotron(TacotronAbstract):
gradual_trainin (List): Gradual training schedule. If None or `[]`, no gradual training is used.
Defaults to `[]`.
"""

def __init__(
self,
num_chars,

@@ -78,7 +79,7 @@ class Tacotron(TacotronAbstract):
use_gst=False,
gst=None,
memory_size=5,
gradual_training=[],
):
super().__init__(
num_chars,

@@ -106,15 +107,14 @@ class Tacotron(TacotronAbstract):
speaker_embedding_dim,
use_gst,
gst,
gradual_training,
)

# speaker embedding layers
if self.num_speakers > 1:
if not self.embeddings_per_sample:
speaker_embedding_dim = 256
self.speaker_embedding = nn.Embedding(self.num_speakers, speaker_embedding_dim)
self.speaker_embedding.weight.data.normal_(0, 0.3)

# speaker and gst embeddings is concat in decoder input

@@ -145,8 +145,7 @@ class Tacotron(TacotronAbstract):
separate_stopnet,
)
self.postnet = PostCBHG(decoder_output_dim)
self.last_linear = nn.Linear(self.postnet.cbhg.gru_features * 2, postnet_output_dim)

# setup prenet dropout
self.decoder.prenet.dropout_at_inference = prenet_dropout_at_inference

@@ -183,12 +182,7 @@ class Tacotron(TacotronAbstract):
separate_stopnet,
)

def forward(self, text, text_lengths, mel_specs=None, mel_lengths=None, cond_input=None):
"""
Shapes:
text: [B, T_in]
@@ -197,100 +191,87 @@ class Tacotron(TacotronAbstract):
mel_lengths: [B]
cond_input: 'speaker_ids': [B, 1] and 'x_vectors':[B, C]
"""
outputs = {"alignments_backward": None, "decoder_outputs_backward": None}
input_mask, output_mask = self.compute_masks(text_lengths, mel_lengths)
# B x T_in x embed_dim
inputs = self.embedding(text)
# B x T_in x encoder_in_features
encoder_outputs = self.encoder(inputs)
# sequence masking
encoder_outputs = encoder_outputs * input_mask.unsqueeze(2).expand_as(encoder_outputs)
# global style token
if self.gst and self.use_gst:
# B x gst_dim
encoder_outputs = self.compute_gst(encoder_outputs, mel_specs, cond_input["x_vectors"])
# speaker embedding
if self.num_speakers > 1:
if not self.embeddings_per_sample:
# B x 1 x speaker_embed_dim
speaker_embeddings = self.speaker_embedding(cond_input["speaker_ids"])[:, None]
else:
# B x 1 x speaker_embed_dim
speaker_embeddings = torch.unsqueeze(cond_input["x_vectors"], 1)
encoder_outputs = self._concat_speaker_embedding(encoder_outputs, speaker_embeddings)
# decoder_outputs: B x decoder_in_features x T_out
# alignments: B x T_in x encoder_in_features
# stop_tokens: B x T_in
decoder_outputs, alignments, stop_tokens = self.decoder(encoder_outputs, mel_specs, input_mask)
# sequence masking
if output_mask is not None:
decoder_outputs = decoder_outputs * output_mask.unsqueeze(1).expand_as(decoder_outputs)
# B x T_out x decoder_in_features
postnet_outputs = self.postnet(decoder_outputs)
# sequence masking
if output_mask is not None:
postnet_outputs = postnet_outputs * output_mask.unsqueeze(2).expand_as(postnet_outputs)
# B x T_out x posnet_dim
postnet_outputs = self.last_linear(postnet_outputs)
# B x T_out x decoder_in_features
decoder_outputs = decoder_outputs.transpose(1, 2).contiguous()
if self.bidirectional_decoder:
decoder_outputs_backward, alignments_backward = self._backward_pass(mel_specs, encoder_outputs, input_mask)
outputs["alignments_backward"] = alignments_backward
outputs["decoder_outputs_backward"] = decoder_outputs_backward
if self.double_decoder_consistency:
decoder_outputs_backward, alignments_backward = self._coarse_decoder_pass(
mel_specs, encoder_outputs, alignments, input_mask
)
outputs["alignments_backward"] = alignments_backward
outputs["decoder_outputs_backward"] = decoder_outputs_backward
outputs.update(
{
"model_outputs": postnet_outputs,
"decoder_outputs": decoder_outputs,
"alignments": alignments,
"stop_tokens": stop_tokens,
}
)
return outputs

@torch.no_grad()
def inference(self, text_input, cond_input=None):
inputs = self.embedding(text_input)
encoder_outputs = self.encoder(inputs)
if self.gst and self.use_gst:
# B x gst_dim
encoder_outputs = self.compute_gst(encoder_outputs, cond_input["style_mel"], cond_input["x_vectors"])
if self.num_speakers > 1:
if not self.embeddings_per_sample:
# B x 1 x speaker_embed_dim
speaker_embeddings = self.speaker_embedding(cond_input["speaker_ids"])[:, None]
else:
# B x 1 x speaker_embed_dim
speaker_embeddings = torch.unsqueeze(cond_input["x_vectors"], 1)
encoder_outputs = self._concat_speaker_embedding(encoder_outputs, speaker_embeddings)
decoder_outputs, alignments, stop_tokens = self.decoder.inference(encoder_outputs)
postnet_outputs = self.postnet(decoder_outputs)
postnet_outputs = self.last_linear(postnet_outputs)
decoder_outputs = decoder_outputs.transpose(1, 2)
outputs = {
"model_outputs": postnet_outputs,
"decoder_outputs": decoder_outputs,
"alignments": alignments,
"stop_tokens": stop_tokens,
}
return outputs
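Both the forward and inference paths above hand a [B, 1, D] speaker vector to _concat_speaker_embedding. A sketch of what that implies (an assumption based on the surrounding shapes, not the project's own helper): the embedding is tiled across time and concatenated on the feature axis.

    import torch

    def concat_speaker_embedding(encoder_outputs, speaker_embeddings):
        # encoder_outputs: [B, T_in, D_enc], speaker_embeddings: [B, 1, D_spk]
        speaker_embeddings = speaker_embeddings.expand(-1, encoder_outputs.size(1), -1)
        return torch.cat([encoder_outputs, speaker_embeddings], dim=-1)

    enc = torch.randn(2, 7, 256)
    spk = torch.randn(2, 1, 512)
    print(concat_speaker_embedding(enc, spk).shape)  # torch.Size([2, 7, 768])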
@@ -301,64 +282,61 @@ class Tacotron(TacotronAbstract):
batch ([type]): [description]
criterion ([type]): [description]
"""
text_input = batch["text_input"]
text_lengths = batch["text_lengths"]
mel_input = batch["mel_input"]
mel_lengths = batch["mel_lengths"]
linear_input = batch["linear_input"]
stop_targets = batch["stop_targets"]
speaker_ids = batch["speaker_ids"]
x_vectors = batch["x_vectors"]

# forward pass model
outputs = self.forward(
text_input,
text_lengths,
mel_input,
mel_lengths,
cond_input={"speaker_ids": speaker_ids, "x_vectors": x_vectors},
)

# set the [alignment] lengths wrt reduction factor for guided attention
if mel_lengths.max() % self.decoder.r != 0:
alignment_lengths = (
mel_lengths + (self.decoder.r - (mel_lengths.max() % self.decoder.r))
) // self.decoder.r
else:
alignment_lengths = mel_lengths // self.decoder.r
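A worked toy example (numbers invented) of the rounding above: every mel length is shifted by the padding that makes the longest one a multiple of the reduction factor r, then divided by r to get the alignment length in decoder steps.

    import torch

    mel_lengths = torch.tensor([95, 100])
    r = 7
    if mel_lengths.max() % r != 0:
        alignment_lengths = (mel_lengths + (r - (mel_lengths.max() % r))) // r
    else:
        alignment_lengths = mel_lengths // r
    print(alignment_lengths)  # tensor([14, 15]): (95 + 5) // 7 and (100 + 5) // 7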
cond_input = {"speaker_ids": speaker_ids, "x_vectors": x_vectors}
outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, cond_input)

# compute loss
loss_dict = criterion(
outputs["model_outputs"],
outputs["decoder_outputs"],
mel_input,
linear_input,
outputs["stop_tokens"],
stop_targets,
mel_lengths,
outputs["decoder_outputs_backward"],
outputs["alignments"],
alignment_lengths,
outputs["alignments_backward"],
text_lengths,
)

# compute alignment error (the lower the better )
align_error = 1 - alignment_diagonal_score(outputs["alignments"])
loss_dict["align_error"] = align_error
return outputs, loss_dict

def train_log(self, ap, batch, outputs):
postnet_outputs = outputs["model_outputs"]
alignments = outputs["alignments"]
alignments_backward = outputs["alignments_backward"]
mel_input = batch["mel_input"]

pred_spec = postnet_outputs[0].data.cpu().numpy()
gt_spec = mel_input[0].data.cpu().numpy()

@@ -371,8 +349,7 @@ class Tacotron(TacotronAbstract):
}

if self.bidirectional_decoder or self.double_decoder_consistency:
figures["alignment_backward"] = plot_alignment(alignments_backward[0].data.cpu().numpy(), output_fig=False)

# Sample audio
train_audio = ap.inv_spectrogram(pred_spec.T)
@@ -49,49 +49,70 @@ class Tacotron2(TacotronAbstract):
gradual_trainin (List): Gradual training schedule. If None or `[]`, no gradual training is used.
Defaults to `[]`.
"""

def __init__(
self,
num_chars,
num_speakers,
r,
postnet_output_dim=80,
decoder_output_dim=80,
attn_type="original",
attn_win=False,
attn_norm="softmax",
prenet_type="original",
prenet_dropout=True,
prenet_dropout_at_inference=False,
forward_attn=False,
trans_agent=False,
forward_attn_mask=False,
location_attn=True,
attn_K=5,
separate_stopnet=True,
bidirectional_decoder=False,
double_decoder_consistency=False,
ddc_r=None,
encoder_in_features=512,
decoder_in_features=512,
speaker_embedding_dim=None,
use_gst=False,
gst=None,
gradual_training=[],
):
super().__init__(
num_chars,
num_speakers,
r,
postnet_output_dim,
decoder_output_dim,
attn_type,
attn_win,
attn_norm,
prenet_type,
prenet_dropout,
prenet_dropout_at_inference,
forward_attn,
trans_agent,
forward_attn_mask,
location_attn,
attn_K,
separate_stopnet,
bidirectional_decoder,
double_decoder_consistency,
ddc_r,
encoder_in_features,
decoder_in_features,
speaker_embedding_dim,
use_gst,
gst,
gradual_training,
)

# speaker embedding layer
if self.num_speakers > 1:
if not self.embeddings_per_sample:
speaker_embedding_dim = 512
self.speaker_embedding = nn.Embedding(self.num_speakers, speaker_embedding_dim)
self.speaker_embedding.weight.data.normal_(0, 0.3)

# speaker and gst embeddings is concat in decoder input

@@ -162,12 +183,7 @@ class Tacotron2(TacotronAbstract):
mel_outputs_postnet = mel_outputs_postnet.transpose(1, 2)
return mel_outputs, mel_outputs_postnet, alignments

def forward(self, text, text_lengths, mel_specs=None, mel_lengths=None, cond_input=None):
"""
Shapes:
text: [B, T_in]

@@ -176,10 +192,7 @@ class Tacotron2(TacotronAbstract):
mel_lengths: [B]
cond_input: 'speaker_ids': [B, 1] and 'x_vectors':[B, C]
"""
outputs = {"alignments_backward": None, "decoder_outputs_backward": None}
# compute mask for padding
# B x T_in_max (boolean)
input_mask, output_mask = self.compute_masks(text_lengths, mel_lengths)
@@ -189,55 +202,49 @@ class Tacotron2(TacotronAbstract):
encoder_outputs = self.encoder(embedded_inputs, text_lengths)
if self.gst and self.use_gst:
# B x gst_dim
encoder_outputs = self.compute_gst(encoder_outputs, mel_specs, cond_input["x_vectors"])
if self.num_speakers > 1:
if not self.embeddings_per_sample:
# B x 1 x speaker_embed_dim
speaker_embeddings = self.speaker_embedding(cond_input["speaker_ids"])[:, None]
else:
# B x 1 x speaker_embed_dim
speaker_embeddings = torch.unsqueeze(cond_input["x_vectors"], 1)
encoder_outputs = self._concat_speaker_embedding(encoder_outputs, speaker_embeddings)

encoder_outputs = encoder_outputs * input_mask.unsqueeze(2).expand_as(encoder_outputs)

# B x mel_dim x T_out -- B x T_out//r x T_in -- B x T_out//r
decoder_outputs, alignments, stop_tokens = self.decoder(encoder_outputs, mel_specs, input_mask)
# sequence masking
if mel_lengths is not None:
decoder_outputs = decoder_outputs * output_mask.unsqueeze(1).expand_as(decoder_outputs)
# B x mel_dim x T_out
postnet_outputs = self.postnet(decoder_outputs)
postnet_outputs = decoder_outputs + postnet_outputs
# sequence masking
if output_mask is not None:
postnet_outputs = postnet_outputs * output_mask.unsqueeze(1).expand_as(postnet_outputs)
# B x T_out x mel_dim -- B x T_out x mel_dim -- B x T_out//r x T_in
decoder_outputs, postnet_outputs, alignments = self.shape_outputs(decoder_outputs, postnet_outputs, alignments)
if self.bidirectional_decoder:
decoder_outputs_backward, alignments_backward = self._backward_pass(mel_specs, encoder_outputs, input_mask)
outputs["alignments_backward"] = alignments_backward
outputs["decoder_outputs_backward"] = decoder_outputs_backward
if self.double_decoder_consistency:
decoder_outputs_backward, alignments_backward = self._coarse_decoder_pass(
mel_specs, encoder_outputs, alignments, input_mask
)
outputs["alignments_backward"] = alignments_backward
outputs["decoder_outputs_backward"] = decoder_outputs_backward
outputs.update(
{
"model_outputs": postnet_outputs,
"decoder_outputs": decoder_outputs,
"alignments": alignments,
"stop_tokens": stop_tokens,
}
)
return outputs

@torch.no_grad()

@@ -247,29 +254,25 @@ class Tacotron2(TacotronAbstract):

if self.gst and self.use_gst:
# B x gst_dim
encoder_outputs = self.compute_gst(encoder_outputs, cond_input["style_mel"], cond_input["x_vectors"])
if self.num_speakers > 1:
if not self.embeddings_per_sample:
x_vector = self.speaker_embedding(cond_input['speaker_ids'])[:, None]
x_vector = torch.unsqueeze(x_vector, 0).transpose(1, 2)
else:
x_vector = cond_input["x_vectors"]

encoder_outputs = self._concat_speaker_embedding(encoder_outputs, x_vector)

decoder_outputs, alignments, stop_tokens = self.decoder.inference(encoder_outputs)
postnet_outputs = self.postnet(decoder_outputs)
postnet_outputs = decoder_outputs + postnet_outputs
decoder_outputs, postnet_outputs, alignments = self.shape_outputs(decoder_outputs, postnet_outputs, alignments)
outputs = {
"model_outputs": postnet_outputs,
"decoder_outputs": decoder_outputs,
"alignments": alignments,
"stop_tokens": stop_tokens,
}
return outputs
@@ -280,64 +283,61 @@ class Tacotron2(TacotronAbstract):
batch ([type]): [description]
criterion ([type]): [description]
"""
text_input = batch["text_input"]
text_lengths = batch["text_lengths"]
mel_input = batch["mel_input"]
mel_lengths = batch["mel_lengths"]
linear_input = batch["linear_input"]
stop_targets = batch["stop_targets"]
speaker_ids = batch["speaker_ids"]
x_vectors = batch["x_vectors"]

# forward pass model
outputs = self.forward(
text_input,
text_lengths,
mel_input,
mel_lengths,
cond_input={"speaker_ids": speaker_ids, "x_vectors": x_vectors},
)

# set the [alignment] lengths wrt reduction factor for guided attention
if mel_lengths.max() % self.decoder.r != 0:
alignment_lengths = (
mel_lengths + (self.decoder.r - (mel_lengths.max() % self.decoder.r))
) // self.decoder.r
else:
alignment_lengths = mel_lengths // self.decoder.r

cond_input = {"speaker_ids": speaker_ids, "x_vectors": x_vectors}
outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, cond_input)

# compute loss
loss_dict = criterion(
outputs["model_outputs"],
outputs["decoder_outputs"],
mel_input,
linear_input,
outputs["stop_tokens"],
stop_targets,
mel_lengths,
outputs["decoder_outputs_backward"],
outputs["alignments"],
alignment_lengths,
outputs["alignments_backward"],
text_lengths,
)

# compute alignment error (the lower the better )
align_error = 1 - alignment_diagonal_score(outputs["alignments"])
loss_dict["align_error"] = align_error
return outputs, loss_dict

def train_log(self, ap, batch, outputs):
postnet_outputs = outputs["model_outputs"]
alignments = outputs["alignments"]
alignments_backward = outputs["alignments_backward"]
mel_input = batch["mel_input"]

pred_spec = postnet_outputs[0].data.cpu().numpy()
gt_spec = mel_input[0].data.cpu().numpy()

@@ -350,8 +350,7 @@ class Tacotron2(TacotronAbstract):
}

if self.bidirectional_decoder or self.double_decoder_consistency:
figures["alignment_backward"] = plot_alignment(alignments_backward[0].data.cpu().numpy(), output_fig=False)

# Sample audio
train_audio = ap.inv_melspectrogram(pred_spec.T)
@@ -37,7 +37,7 @@ class TacotronAbstract(ABC, nn.Module):
speaker_embedding_dim=None,
use_gst=False,
gst=None,
gradual_training=[],
):
"""Abstract Tacotron class"""
super().__init__()
@@ -59,25 +59,16 @@ def numpy_to_tf(np_array, dtype):

def compute_style_mel(style_wav, ap, cuda=False):
style_mel = torch.FloatTensor(ap.melspectrogram(ap.load_wav(style_wav, sr=ap.sample_rate))).unsqueeze(0)
if cuda:
return style_mel.cuda()
return style_mel


def run_model_torch(model, inputs, speaker_id=None, style_mel=None, x_vector=None):
outputs = model.inference(
inputs, cond_input={"speaker_ids": speaker_id, "x_vector": x_vector, "style_mel": style_mel}
)
return outputs
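run_model_torch above now funnels everything through a single cond_input dictionary. A minimal, self-contained sketch of that contract (FakeModel and all values are invented purely for illustration; real models from this repo expect proper text and speaker tensors):

    import torch

    class FakeModel:
        def inference(self, inputs, cond_input=None):
            cond_input = cond_input or {}
            # the keys built by run_model_torch, any of which may be None
            assert {"speaker_ids", "x_vector", "style_mel"} <= set(cond_input)
            return {"model_outputs": torch.zeros(inputs.shape[0], 10, 80)}

    outputs = FakeModel().inference(
        torch.randint(0, 100, (1, 42)),
        cond_input={"speaker_ids": None, "x_vector": torch.randn(1, 512), "style_mel": None},
    )
    print(outputs["model_outputs"].shape)  # torch.Size([1, 10, 80])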
@@ -87,18 +78,15 @@ def run_model_tf(model, inputs, CONFIG, speaker_id=None, style_mel=None):
     if speaker_id is not None:
         raise NotImplementedError(" [!] Multi-Speaker not implemented for TF")
     # TODO: handle multispeaker case
-    decoder_output, postnet_output, alignments, stop_tokens = model(
-        inputs, training=False)
+    decoder_output, postnet_output, alignments, stop_tokens = model(inputs, training=False)
     return decoder_output, postnet_output, alignments, stop_tokens


 def run_model_tflite(model, inputs, CONFIG, speaker_id=None, style_mel=None):
     if CONFIG.gst and style_mel is not None:
-        raise NotImplementedError(
-            " [!] GST inference not implemented for TfLite")
+        raise NotImplementedError(" [!] GST inference not implemented for TfLite")
     if speaker_id is not None:
-        raise NotImplementedError(
-            " [!] Multi-Speaker not implemented for TfLite")
+        raise NotImplementedError(" [!] Multi-Speaker not implemented for TfLite")
     # get input and output details
     input_details = model.get_input_details()
     output_details = model.get_output_details()
@@ -132,7 +120,7 @@ def parse_outputs_tflite(postnet_output, decoder_output):


 def trim_silence(wav, ap):
-    return wav[:ap.find_endpoint(wav)]
+    return wav[: ap.find_endpoint(wav)]


 def inv_spectrogram(postnet_output, ap, CONFIG):
@@ -155,8 +143,7 @@ def speaker_id_to_torch(speaker_id, cuda=False):
 def embedding_to_torch(x_vector, cuda=False):
     if x_vector is not None:
         x_vector = np.asarray(x_vector)
-        x_vector = torch.from_numpy(x_vector).unsqueeze(0).type(
-            torch.FloatTensor)
+        x_vector = torch.from_numpy(x_vector).unsqueeze(0).type(torch.FloatTensor)
     if cuda:
         return x_vector.cuda()
     return x_vector
@@ -173,8 +160,7 @@ def apply_griffin_lim(inputs, input_lens, CONFIG, ap):
     """
     wavs = []
     for idx, spec in enumerate(inputs):
-        wav_len = (input_lens[idx] *
-                   ap.hop_length) - ap.hop_length  # inverse librosa padding
+        wav_len = (input_lens[idx] * ap.hop_length) - ap.hop_length  # inverse librosa padding
         wav = inv_spectrogram(spec, ap, CONFIG)
         # assert len(wav) == wav_len, f" [!] wav lenght: {len(wav)} vs expected: {wav_len}"
         wavs.append(wav[:wav_len])
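As a quick check on the inverse-padding arithmetic above (the numbers are invented for illustration): with ap.hop_length = 256 and input_lens[idx] = 200 spectrogram frames, wav_len = 200 * 256 - 256 = 50944 samples, i.e. one hop shorter than the naive frames-times-hop estimate, matching the "inverse librosa padding" comment.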
@@ -242,23 +228,21 @@ def synthesis(
         text_inputs = tf.expand_dims(text_inputs, 0)
     # synthesize voice
     if backend == "torch":
-        outputs = run_model_torch(model,
-                                  text_inputs,
-                                  speaker_id,
-                                  style_mel,
-                                  x_vector=x_vector)
-        model_outputs = outputs['model_outputs']
+        outputs = run_model_torch(model, text_inputs, speaker_id, style_mel, x_vector=x_vector)
+        model_outputs = outputs["model_outputs"]
         model_outputs = model_outputs[0].data.cpu().numpy()
     elif backend == "tf":
         decoder_output, postnet_output, alignments, stop_tokens = run_model_tf(
-            model, text_inputs, CONFIG, speaker_id, style_mel)
+            model, text_inputs, CONFIG, speaker_id, style_mel
+        )
         model_outputs, decoder_output, alignment, stop_tokens = parse_outputs_tf(
-            postnet_output, decoder_output, alignments, stop_tokens)
+            postnet_output, decoder_output, alignments, stop_tokens
+        )
     elif backend == "tflite":
         decoder_output, postnet_output, alignment, stop_tokens = run_model_tflite(
-            model, text_inputs, CONFIG, speaker_id, style_mel)
-        model_outputs, decoder_output = parse_outputs_tflite(
-            postnet_output, decoder_output)
+            model, text_inputs, CONFIG, speaker_id, style_mel
+        )
+        model_outputs, decoder_output = parse_outputs_tflite(postnet_output, decoder_output)
     # convert outputs to numpy
     # plot results
     wav = None
@@ -268,9 +252,9 @@ def synthesis(
         if do_trim_silence:
             wav = trim_silence(wav, ap)
     return_dict = {
-        'wav': wav,
-        'alignments': outputs['alignments'],
-        'model_outputs': model_outputs,
-        'text_inputs': text_inputs
+        "wav": wav,
+        "alignments": outputs["alignments"],
+        "model_outputs": model_outputs,
+        "text_inputs": text_inputs,
     }
     return return_dict
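The dict returned by `synthesis()` keeps the same four keys after the quote-style change; a typical caller just unpacks it. A minimal sketch (the `outputs` dict here is a fabricated stand-in with those keys, since producing a real one requires a model and audio processor):

import numpy as np

# Fabricated stand-in mirroring the keys of return_dict above; all values are dummies.
outputs = {
    "wav": np.zeros(22050, dtype=np.float32),
    "alignments": None,
    "model_outputs": np.zeros((10, 80), dtype=np.float32),
    "text_inputs": None,
}
waveform = outputs["wav"]          # audio samples (populated when Griffin-Lim is used)
mel = outputs["model_outputs"]     # spectrogram frames for an external vocoder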
@@ -30,16 +30,16 @@ def init_arguments(argv):
     parser.add_argument(
         "--continue_path",
         type=str,
-        help=("Training output folder to continue training. Used to continue "
-              "a training. If it is used, 'config_path' is ignored."),
+        help=(
+            "Training output folder to continue training. Used to continue "
+            "a training. If it is used, 'config_path' is ignored."
+        ),
         default="",
         required="--config_path" not in argv,
     )
     parser.add_argument(
-        "--restore_path",
-        type=str,
-        help="Model file to be restored. Use to finetune a model.",
-        default="")
+        "--restore_path", type=str, help="Model file to be restored. Use to finetune a model.", default=""
+    )
     parser.add_argument(
         "--best_path",
         type=str,
@@ -49,23 +49,12 @@ def init_arguments(argv):
         ),
         default="",
     )
-    parser.add_argument("--config_path",
-                        type=str,
-                        help="Path to config file for training.",
-                        required="--continue_path" not in argv)
-    parser.add_argument("--debug",
-                        type=bool,
-                        default=False,
-                        help="Do not verify commit integrity to run training.")
     parser.add_argument(
-        "--rank",
-        type=int,
-        default=0,
-        help="DISTRIBUTED: process rank for distributed training.")
-    parser.add_argument("--group_id",
-                        type=str,
-                        default="",
-                        help="DISTRIBUTED: process group id.")
+        "--config_path", type=str, help="Path to config file for training.", required="--continue_path" not in argv
+    )
+    parser.add_argument("--debug", type=bool, default=False, help="Do not verify commit integrity to run training.")
+    parser.add_argument("--rank", type=int, default=0, help="DISTRIBUTED: process rank for distributed training.")
+    parser.add_argument("--group_id", type=str, default="", help="DISTRIBUTED: process group id.")

     return parser

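The collapsed single-line `add_argument` calls behave the same as the old multi-line ones. A small self-contained sketch that mirrors the flags shown above with a standalone parser (this is not the project's `init_arguments`, and the sample argv values are invented):

import argparse

# Standalone parser mirroring the flags in the hunk above (illustrative only).
parser = argparse.ArgumentParser()
parser.add_argument("--config_path", type=str, help="Path to config file for training.")
parser.add_argument("--debug", type=bool, default=False, help="Do not verify commit integrity to run training.")
parser.add_argument("--rank", type=int, default=0, help="DISTRIBUTED: process rank for distributed training.")
parser.add_argument("--group_id", type=str, default="", help="DISTRIBUTED: process group id.")

args = parser.parse_args(["--config_path", "config.json", "--rank", "1"])
print(args.rank)  # -> 1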
@@ -160,8 +149,7 @@ def process_args(args):
         print(" > Mixed precision mode is ON")
     experiment_path = args.continue_path
     if not experiment_path:
-        experiment_path = create_experiment_folder(config.output_path,
-                                                   config.run_name, args.debug)
+        experiment_path = create_experiment_folder(config.output_path, config.run_name, args.debug)
     audio_path = os.path.join(experiment_path, "test_audios")
     # setup rank 0 process in distributed training
     tb_logger = None
@@ -182,8 +170,7 @@ def process_args(args):
         os.chmod(experiment_path, 0o775)
         tb_logger = TensorboardLogger(experiment_path, model_name=config.model)
         # write model desc to tensorboard
-        tb_logger.tb_add_text("model-config", f"<pre>{config.to_json()}</pre>",
-                              0)
+        tb_logger.tb_add_text("model-config", f"<pre>{config.to_json()}</pre>", 0)
     c_logger = ConsoleLogger()
     return config, experiment_path, audio_path, c_logger, tb_logger

@@ -234,8 +234,8 @@ class Synthesizer(object):
                 use_griffin_lim=use_gl,
                 x_vector=speaker_embedding,
             )
-            waveform = outputs['wav']
-            mel_postnet_spec = outputs['model_outputs']
+            waveform = outputs["wav"]
+            mel_postnet_spec = outputs["model_outputs"]
             if not use_gl:
                 # denormalize tts output based on tts audio config
                 mel_postnet_spec = self.ap.denormalize(mel_postnet_spec.T).T
@@ -44,8 +44,6 @@ run_cli(command_train)
 continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)

 # restore the model and continue training for one more epoch
-command_train = (
-    f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
-)
+command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
 run_cli(command_train)
 shutil.rmtree(continue_path)
@@ -46,8 +46,6 @@ run_cli(command_train)
 continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)

 # restore the model and continue training for one more epoch
-command_train = (
-    f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
-)
+command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
 run_cli(command_train)
 shutil.rmtree(continue_path)
@@ -45,8 +45,6 @@ run_cli(command_train)
 continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)

 # restore the model and continue training for one more epoch
-command_train = (
-    f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
-)
+command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
 run_cli(command_train)
 shutil.rmtree(continue_path)
@@ -45,8 +45,6 @@ run_cli(command_train)
 continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)

 # restore the model and continue training for one more epoch
-command_train = (
-    f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
-)
+command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
 run_cli(command_train)
 shutil.rmtree(continue_path)
@@ -44,8 +44,6 @@ run_cli(command_train)
 continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)

 # restore the model and continue training for one more epoch
-command_train = (
-    f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
-)
+command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
 run_cli(command_train)
 shutil.rmtree(continue_path)
@@ -21,7 +21,6 @@ config = MelganConfig(
     print_step=1,
     discriminator_model_params={"base_channels": 16, "max_channels": 256, "downsample_factors": [4, 4, 4]},
     print_eval=True,
-    discriminator_model_params={"base_channels": 16, "max_channels": 256, "downsample_factors": [4, 4, 4]},
     data_path="tests/data/ljspeech",
     output_path=output_path,
 )