make style

Eren Gölge 2021-05-27 17:25:00 +02:00
parent 469d2e620a
commit b500338faa
25 changed files with 524 additions and 747 deletions
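
The hunks below are pure formatting changes: single-quoted strings become double-quoted, and calls that were wrapped across several lines are joined onto one line wherever they fit, which is characteristic of the black formatter. As a minimal sketch, assuming the repository's "make style" target runs black with a 120-character line limit (the Makefile and tool configuration are not part of this diff), the rewrite in the first hunk can be reproduced like this:

import black

# Source text in the pre-commit style: single quotes and a wrapped call.
OLD_SRC = '''\
cond_input = {'speaker_ids': speaker_ids, 'x_vectors': speaker_embeddings}
outputs = model(
    text_input,
    text_lengths,
    mel_input,
    mel_lengths,
    cond_input
)
postnet_outputs = outputs['model_outputs']
'''

# black normalizes string quotes to double quotes and joins the call onto a
# single line because it fits within the limit; line_length=120 is an assumption.
print(black.format_str(OLD_SRC, mode=black.FileMode(line_length=120)))

Running black --check over the tree (or the project's own style target) is the usual way to verify that the whole repository stays in this format.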

View File

@ -153,15 +153,9 @@ def inference(
model_output = model_output.transpose(1, 2).detach().cpu().numpy()
elif "tacotron" in model_name:
cond_input = {'speaker_ids': speaker_ids, 'x_vectors': speaker_embeddings}
outputs = model(
text_input,
text_lengths,
mel_input,
mel_lengths,
cond_input
)
postnet_outputs = outputs['model_outputs']
cond_input = {"speaker_ids": speaker_ids, "x_vectors": speaker_embeddings}
outputs = model(text_input, text_lengths, mel_input, mel_lengths, cond_input)
postnet_outputs = outputs["model_outputs"]
# normalize tacotron output
if model_name == "tacotron":
mel_specs = []

View File

@ -8,13 +8,8 @@ from TTS.trainer import TrainerTTS
def main():
# try:
args, config, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = init_training(
sys.argv)
trainer = TrainerTTS(args,
config,
c_logger,
tb_logger,
output_path=OUT_PATH)
args, config, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = init_training(sys.argv)
trainer = TrainerTTS(args, config, c_logger, tb_logger, output_path=OUT_PATH)
trainer.fit()
# except KeyboardInterrupt:
# remove_experiment_folder(OUT_PATH)

View File

@ -84,7 +84,7 @@ class TrainerTTS:
self.best_loss = float("inf")
self.train_loader = None
self.eval_loader = None
self.output_audio_path = os.path.join(output_path, 'test_audios')
self.output_audio_path = os.path.join(output_path, "test_audios")
self.keep_avg_train = None
self.keep_avg_eval = None
@ -138,8 +138,8 @@ class TrainerTTS:
if self.args.restore_path:
self.model, self.optimizer, self.scaler, self.restore_step = self.restore_model(
self.config, args.restore_path, self.model, self.optimizer,
self.scaler)
self.config, args.restore_path, self.model, self.optimizer, self.scaler
)
# setup scheduler
self.scheduler = self.get_scheduler(self.config, self.optimizer)
@ -207,6 +207,7 @@ class TrainerTTS:
return None
if lr_scheduler.lower() == "noamlr":
from TTS.utils.training import NoamLR
scheduler = NoamLR
else:
scheduler = getattr(torch.optim, lr_scheduler)
@ -261,8 +262,7 @@ class TrainerTTS:
ap=ap,
tp=self.config.characters,
add_blank=self.config["add_blank"],
batch_group_size=0 if is_eval else
self.config.batch_group_size * self.config.batch_size,
batch_group_size=0 if is_eval else self.config.batch_group_size * self.config.batch_size,
min_seq_len=self.config.min_seq_len,
max_seq_len=self.config.max_seq_len,
phoneme_cache_path=self.config.phoneme_cache_path,
@ -272,8 +272,8 @@ class TrainerTTS:
use_noise_augment=not is_eval,
verbose=verbose,
speaker_mapping=speaker_mapping
if self.config.use_speaker_embedding
and self.config.use_external_speaker_embedding_file else None,
if self.config.use_speaker_embedding and self.config.use_external_speaker_embedding_file
else None,
)
if self.config.use_phonemes and self.config.compute_input_seq_cache:
@ -281,18 +281,15 @@ class TrainerTTS:
dataset.compute_input_seq(self.config.num_loader_workers)
dataset.sort_items()
sampler = DistributedSampler(
dataset) if self.num_gpus > 1 else None
sampler = DistributedSampler(dataset) if self.num_gpus > 1 else None
loader = DataLoader(
dataset,
batch_size=self.config.eval_batch_size
if is_eval else self.config.batch_size,
batch_size=self.config.eval_batch_size if is_eval else self.config.batch_size,
shuffle=False,
collate_fn=dataset.collate_fn,
drop_last=False,
sampler=sampler,
num_workers=self.config.num_val_loader_workers
if is_eval else self.config.num_loader_workers,
num_workers=self.config.num_val_loader_workers if is_eval else self.config.num_loader_workers,
pin_memory=False,
)
return loader
@ -314,8 +311,7 @@ class TrainerTTS:
text_input = batch[0]
text_lengths = batch[1]
speaker_names = batch[2]
linear_input = batch[3] if self.config.model.lower() in ["tacotron"
] else None
linear_input = batch[3] if self.config.model.lower() in ["tacotron"] else None
mel_input = batch[4]
mel_lengths = batch[5]
stop_targets = batch[6]
@ -331,10 +327,7 @@ class TrainerTTS:
speaker_embeddings = batch[8]
speaker_ids = None
else:
speaker_ids = [
self.speaker_manager.speaker_ids[speaker_name]
for speaker_name in speaker_names
]
speaker_ids = [self.speaker_manager.speaker_ids[speaker_name] for speaker_name in speaker_names]
speaker_ids = torch.LongTensor(speaker_ids)
speaker_embeddings = None
else:
@ -346,7 +339,7 @@ class TrainerTTS:
durations = torch.zeros(attn_mask.shape[0], attn_mask.shape[2])
for idx, am in enumerate(attn_mask):
# compute raw durations
c_idxs = am[:, :text_lengths[idx], :mel_lengths[idx]].max(1)[1]
c_idxs = am[:, : text_lengths[idx], : mel_lengths[idx]].max(1)[1]
# c_idxs, counts = torch.unique_consecutive(c_idxs, return_counts=True)
c_idxs, counts = torch.unique(c_idxs, return_counts=True)
dur = torch.ones([text_lengths[idx]]).to(counts.dtype)
@ -359,14 +352,11 @@ class TrainerTTS:
assert (
dur.sum() == mel_lengths[idx]
), f" [!] total duration {dur.sum()} vs spectrogram length {mel_lengths[idx]}"
durations[idx, :text_lengths[idx]] = dur
durations[idx, : text_lengths[idx]] = dur
# set stop targets view, we predict a single stop token per iteration.
stop_targets = stop_targets.view(text_input.shape[0],
stop_targets.size(1) // self.config.r,
-1)
stop_targets = (stop_targets.sum(2) >
0.0).unsqueeze(2).float().squeeze(2)
stop_targets = stop_targets.view(text_input.shape[0], stop_targets.size(1) // self.config.r, -1)
stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze(2)
# dispatch batch to GPU
if self.use_cuda:
@ -374,15 +364,10 @@ class TrainerTTS:
text_lengths = text_lengths.cuda(non_blocking=True)
mel_input = mel_input.cuda(non_blocking=True)
mel_lengths = mel_lengths.cuda(non_blocking=True)
linear_input = linear_input.cuda(
non_blocking=True) if self.config.model.lower() in [
"tacotron"
] else None
linear_input = linear_input.cuda(non_blocking=True) if self.config.model.lower() in ["tacotron"] else None
stop_targets = stop_targets.cuda(non_blocking=True)
attn_mask = attn_mask.cuda(
non_blocking=True) if attn_mask is not None else None
durations = durations.cuda(
non_blocking=True) if attn_mask is not None else None
attn_mask = attn_mask.cuda(non_blocking=True) if attn_mask is not None else None
durations = durations.cuda(non_blocking=True) if attn_mask is not None else None
if speaker_ids is not None:
speaker_ids = speaker_ids.cuda(non_blocking=True)
if speaker_embeddings is not None:
@ -401,7 +386,7 @@ class TrainerTTS:
"x_vectors": speaker_embeddings,
"max_text_length": max_text_length,
"max_spec_length": max_spec_length,
"item_idx": item_idx
"item_idx": item_idx,
}
def train_step(self, batch: Dict, batch_n_steps: int, step: int,
@ -421,25 +406,20 @@ class TrainerTTS:
# check nan loss
if torch.isnan(loss_dict["loss"]).any():
raise RuntimeError(
f"Detected NaN loss at step {self.total_steps_done}.")
raise RuntimeError(f"Detected NaN loss at step {self.total_steps_done}.")
# optimizer step
if self.config.mixed_precision:
# model optimizer step in mixed precision mode
self.scaler.scale(loss_dict["loss"]).backward()
self.scaler.unscale_(self.optimizer)
grad_norm, _ = check_update(self.model,
self.config.grad_clip,
ignore_stopnet=True)
grad_norm, _ = check_update(self.model, self.config.grad_clip, ignore_stopnet=True)
self.scaler.step(self.optimizer)
self.scaler.update()
else:
# main model optimizer step
loss_dict["loss"].backward()
grad_norm, _ = check_update(self.model,
self.config.grad_clip,
ignore_stopnet=True)
grad_norm, _ = check_update(self.model, self.config.grad_clip, ignore_stopnet=True)
self.optimizer.step()
step_time = time.time() - step_start_time
@ -469,17 +449,15 @@ class TrainerTTS:
current_lr = self.optimizer.param_groups[0]["lr"]
if self.total_steps_done % self.config.print_step == 0:
log_dict = {
"max_spec_length": [batch["max_spec_length"],
1], # value, precision
"max_spec_length": [batch["max_spec_length"], 1], # value, precision
"max_text_length": [batch["max_text_length"], 1],
"step_time": [step_time, 4],
"loader_time": [loader_time, 2],
"current_lr": current_lr,
}
self.c_logger.print_train_step(batch_n_steps, step,
self.total_steps_done, log_dict,
loss_dict,
self.keep_avg_train.avg_values)
self.c_logger.print_train_step(
batch_n_steps, step, self.total_steps_done, log_dict, loss_dict, self.keep_avg_train.avg_values
)
if self.args.rank == 0:
# Plot Training Iter Stats
@ -491,8 +469,7 @@ class TrainerTTS:
"step_time": step_time,
}
iter_stats.update(loss_dict)
self.tb_logger.tb_train_step_stats(self.total_steps_done,
iter_stats)
self.tb_logger.tb_train_step_stats(self.total_steps_done, iter_stats)
if self.total_steps_done % self.config.save_step == 0:
if self.config.checkpoint:
@ -506,15 +483,12 @@ class TrainerTTS:
self.output_path,
model_loss=loss_dict["loss"],
characters=self.model_characters,
scaler=self.scaler.state_dict()
if self.config.mixed_precision else None,
scaler=self.scaler.state_dict() if self.config.mixed_precision else None,
)
# training visualizations
figures, audios = self.model.train_log(self.ap, batch, outputs)
self.tb_logger.tb_train_figures(self.total_steps_done, figures)
self.tb_logger.tb_train_audios(self.total_steps_done,
{"TrainAudio": audios},
self.ap.sample_rate)
self.tb_logger.tb_train_audios(self.total_steps_done, {"TrainAudio": audios}, self.ap.sample_rate)
self.total_steps_done += 1
self.on_train_step_end()
return outputs, loss_dict
@ -523,35 +497,28 @@ class TrainerTTS:
self.model.train()
epoch_start_time = time.time()
if self.use_cuda:
batch_num_steps = int(
len(self.train_loader.dataset) /
(self.config.batch_size * self.num_gpus))
batch_num_steps = int(len(self.train_loader.dataset) / (self.config.batch_size * self.num_gpus))
else:
batch_num_steps = int(
len(self.train_loader.dataset) / self.config.batch_size)
batch_num_steps = int(len(self.train_loader.dataset) / self.config.batch_size)
self.c_logger.print_train_start()
loader_start_time = time.time()
for cur_step, batch in enumerate(self.train_loader):
_, _ = self.train_step(batch, batch_num_steps, cur_step,
loader_start_time)
_, _ = self.train_step(batch, batch_num_steps, cur_step, loader_start_time)
epoch_time = time.time() - epoch_start_time
# Plot self.epochs_done Stats
if self.args.rank == 0:
epoch_stats = {"epoch_time": epoch_time}
epoch_stats.update(self.keep_avg_train.avg_values)
self.tb_logger.tb_train_epoch_stats(self.total_steps_done,
epoch_stats)
self.tb_logger.tb_train_epoch_stats(self.total_steps_done, epoch_stats)
if self.config.tb_model_param_stats:
self.tb_logger.tb_model_weights(self.model,
self.total_steps_done)
self.tb_logger.tb_model_weights(self.model, self.total_steps_done)
def eval_step(self, batch: Dict, step: int) -> Tuple[Dict, Dict]:
with torch.no_grad():
step_start_time = time.time()
with torch.cuda.amp.autocast(enabled=self.config.mixed_precision):
outputs, loss_dict = self.model.eval_step(
batch, self.criterion)
outputs, loss_dict = self.model.eval_step(batch, self.criterion)
step_time = time.time() - step_start_time
@ -572,8 +539,7 @@ class TrainerTTS:
self.keep_avg_eval.update_values(update_eval_values)
if self.config.print_eval:
self.c_logger.print_eval_step(step, loss_dict,
self.keep_avg_eval.avg_values)
self.c_logger.print_eval_step(step, loss_dict, self.keep_avg_eval.avg_values)
return outputs, loss_dict
def eval_epoch(self) -> None:
@ -585,15 +551,13 @@ class TrainerTTS:
# format data
batch = self.format_batch(batch)
loader_time = time.time() - loader_start_time
self.keep_avg_eval.update_values({'avg_loader_time': loader_time})
self.keep_avg_eval.update_values({"avg_loader_time": loader_time})
outputs, _ = self.eval_step(batch, cur_step)
# Plot epoch stats and samples from the last batch.
if self.args.rank == 0:
figures, eval_audios = self.model.eval_log(self.ap, batch, outputs)
self.tb_logger.tb_eval_figures(self.total_steps_done, figures)
self.tb_logger.tb_eval_audios(self.total_steps_done,
{"EvalAudio": eval_audios},
self.ap.sample_rate)
self.tb_logger.tb_eval_audios(self.total_steps_done, {"EvalAudio": eval_audios}, self.ap.sample_rate)
def test_run(self, ) -> None:
print(" | > Synthesizing test sentences.")
@ -608,9 +572,9 @@ class TrainerTTS:
self.config,
self.use_cuda,
self.ap,
speaker_id=cond_inputs['speaker_id'],
x_vector=cond_inputs['x_vector'],
style_wav=cond_inputs['style_wav'],
speaker_id=cond_inputs["speaker_id"],
x_vector=cond_inputs["x_vector"],
style_wav=cond_inputs["style_wav"],
enable_eos_bos_chars=self.config.enable_eos_bos_chars,
use_griffin_lim=True,
do_trim_silence=False,
@ -623,10 +587,8 @@ class TrainerTTS:
"TestSentence_{}.wav".format(idx))
self.ap.save_wav(wav, file_path)
test_audios["{}-audio".format(idx)] = wav
test_figures["{}-prediction".format(idx)] = plot_spectrogram(
model_outputs, self.ap, output_fig=False)
test_figures["{}-alignment".format(idx)] = plot_alignment(
alignment, output_fig=False)
test_figures["{}-prediction".format(idx)] = plot_spectrogram(model_outputs, self.ap, output_fig=False)
test_figures["{}-alignment".format(idx)] = plot_alignment(alignment, output_fig=False)
self.tb_logger.tb_test_audios(self.total_steps_done, test_audios,
self.config.audio["sample_rate"])
@ -641,11 +603,11 @@ class TrainerTTS:
if self.config.use_external_speaker_embedding_file
and self.config.use_speaker_embedding else None)
# setup style_mel
if self.config.has('gst_style_input'):
if self.config.has("gst_style_input"):
style_wav = self.config.gst_style_input
else:
style_wav = None
if style_wav is None and 'use_gst' in self.config and self.config.use_gst:
if style_wav is None and "use_gst" in self.config and self.config.use_gst:
# inicialize GST with zero dict.
style_wav = {}
print(
@ -688,8 +650,7 @@ class TrainerTTS:
for epoch in range(0, self.config.epochs):
self.on_epoch_start()
self.keep_avg_train = KeepAverage()
self.keep_avg_eval = KeepAverage(
) if self.config.run_eval else None
self.keep_avg_eval = KeepAverage() if self.config.run_eval else None
self.epochs_done = epoch
self.c_logger.print_epoch_start(epoch, self.config.epochs)
self.train_epoch()
@ -698,8 +659,8 @@ class TrainerTTS:
if epoch >= self.config.test_delay_epochs:
self.test_run()
self.c_logger.print_epoch_end(
epoch, self.keep_avg_eval.avg_values
if self.config.run_eval else self.keep_avg_train.avg_values)
epoch, self.keep_avg_eval.avg_values if self.config.run_eval else self.keep_avg_train.avg_values
)
self.save_best_model()
self.on_epoch_end()
@ -717,8 +678,7 @@ class TrainerTTS:
self.model_characters,
keep_all_best=self.config.keep_all_best,
keep_after=self.config.keep_after,
scaler=self.scaler.state_dict()
if self.config.mixed_precision else None,
scaler=self.scaler.state_dict() if self.config.mixed_precision else None,
)
@staticmethod

View File

@ -93,7 +93,7 @@ class AlignTTSConfig(BaseTTSConfig):
# optimizer parameters
optimizer: str = "Adam"
optimizer_params: dict = field(default_factory=lambda: {'betas': [0.9, 0.998], 'weight_decay': 1e-6})
optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.998], "weight_decay": 1e-6})
lr_scheduler: str = None
lr_scheduler_params: dict = None
lr: float = 1e-4
@ -104,12 +104,13 @@ class AlignTTSConfig(BaseTTSConfig):
max_seq_len: int = 200
r: int = 1
# testing
test_sentences: List[str] = field(default_factory=lambda:[
"It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
"Be a voice, not an echo.",
"I'm sorry Dave. I'm afraid I can't do that.",
"This cake is great. It's so delicious and moist.",
"Prior to November 22, 1963."
])
# testing
test_sentences: List[str] = field(
default_factory=lambda: [
"It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
"Be a voice, not an echo.",
"I'm sorry Dave. I'm afraid I can't do that.",
"This cake is great. It's so delicious and moist.",
"Prior to November 22, 1963.",
]
)

View File

@ -91,9 +91,9 @@ class GlowTTSConfig(BaseTTSConfig):
# optimizer parameters
optimizer: str = "RAdam"
optimizer_params: dict = field(default_factory=lambda: {'betas': [0.9, 0.998], 'weight_decay': 1e-6})
optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.998], "weight_decay": 1e-6})
lr_scheduler: str = "NoamLR"
lr_scheduler_params: dict = field(default_factory=lambda:{"warmup_steps": 4000})
lr_scheduler_params: dict = field(default_factory=lambda: {"warmup_steps": 4000})
grad_clip: float = 5.0
lr: float = 1e-3

View File

@ -174,7 +174,7 @@ class BaseTTSConfig(BaseTrainingConfig):
optimizer: str = MISSING
optimizer_params: dict = MISSING
# scheduler
lr_scheduler: str = ''
lr_scheduler: str = ""
lr_scheduler_params: dict = field(default_factory=lambda: {})
# testing
test_sentences: List[str] = field(default_factory=lambda:[])
test_sentences: List[str] = field(default_factory=lambda: [])

View File

@ -101,7 +101,7 @@ class SpeedySpeechConfig(BaseTTSConfig):
# optimizer parameters
optimizer: str = "RAdam"
optimizer_params: dict = field(default_factory=lambda: {'betas': [0.9, 0.998], 'weight_decay': 1e-6})
optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.998], "weight_decay": 1e-6})
lr_scheduler: str = None
lr_scheduler_params: dict = None
lr: float = 1e-4
@ -118,10 +118,12 @@ class SpeedySpeechConfig(BaseTTSConfig):
r: int = 1 # DO NOT CHANGE
# testing
test_sentences: List[str] = field(default_factory=lambda:[
"It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
"Be a voice, not an echo.",
"I'm sorry Dave. I'm afraid I can't do that.",
"This cake is great. It's so delicious and moist.",
"Prior to November 22, 1963."
])
test_sentences: List[str] = field(
default_factory=lambda: [
"It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
"Be a voice, not an echo.",
"I'm sorry Dave. I'm afraid I can't do that.",
"This cake is great. It's so delicious and moist.",
"Prior to November 22, 1963.",
]
)

View File

@ -160,9 +160,9 @@ class TacotronConfig(BaseTTSConfig):
# optimizer parameters
optimizer: str = "RAdam"
optimizer_params: dict = field(default_factory=lambda: {'betas': [0.9, 0.998], 'weight_decay': 1e-6})
optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.998], "weight_decay": 1e-6})
lr_scheduler: str = "NoamLR"
lr_scheduler_params: dict = field(default_factory=lambda:{"warmup_steps": 4000})
lr_scheduler_params: dict = field(default_factory=lambda: {"warmup_steps": 4000})
lr: float = 1e-4
grad_clip: float = 5.0
seq_len_norm: bool = False
@ -178,13 +178,15 @@ class TacotronConfig(BaseTTSConfig):
ga_alpha: float = 5.0
# testing
test_sentences: List[str] = field(default_factory=lambda:[
"It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
"Be a voice, not an echo.",
"I'm sorry Dave. I'm afraid I can't do that.",
"This cake is great. It's so delicious and moist.",
"Prior to November 22, 1963."
])
test_sentences: List[str] = field(
default_factory=lambda: [
"It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
"Be a voice, not an echo.",
"I'm sorry Dave. I'm afraid I can't do that.",
"This cake is great. It's so delicious and moist.",
"Prior to November 22, 1963.",
]
)
def check_values(self):
if self.gradual_training:

View File

@ -44,22 +44,18 @@ def load_meta_data(datasets, eval_split=True):
preprocessor = _get_preprocessor_by_name(name)
# load train set
meta_data_train = preprocessor(root_path, meta_file_train)
print(
f" | > Found {len(meta_data_train)} files in {Path(root_path).resolve()}"
)
print(f" | > Found {len(meta_data_train)} files in {Path(root_path).resolve()}")
# load evaluation split if set
if eval_split:
if meta_file_val:
meta_data_eval = preprocessor(root_path, meta_file_val)
else:
meta_data_eval, meta_data_train = split_dataset(
meta_data_train)
meta_data_eval, meta_data_train = split_dataset(meta_data_train)
meta_data_eval_all += meta_data_eval
meta_data_train_all += meta_data_train
# load attention masks for duration predictor training
if dataset.meta_file_attn_mask:
meta_data = dict(
load_attention_mask_meta_data(dataset["meta_file_attn_mask"]))
meta_data = dict(load_attention_mask_meta_data(dataset["meta_file_attn_mask"]))
for idx, ins in enumerate(meta_data_train_all):
attn_file = meta_data[ins[1]].strip()
meta_data_train_all[idx].append(attn_file)

View File

@ -506,5 +506,10 @@ class AlignTTSLoss(nn.Module):
spec_loss = self.spec_loss(decoder_output, decoder_target, decoder_output_lens)
ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens)
dur_loss = self.dur_loss(dur_output.unsqueeze(2), dur_target.unsqueeze(2), input_lens)
loss = self.spec_loss_alpha * spec_loss + self.ssim_alpha * ssim_loss + self.dur_loss_alpha * dur_loss + self.mdn_alpha * mdn_loss
loss = (
self.spec_loss_alpha * spec_loss
+ self.ssim_alpha * ssim_loss
+ self.dur_loss_alpha * dur_loss
+ self.mdn_alpha * mdn_loss
)
return {"loss": loss, "loss_l1": spec_loss, "loss_ssim": ssim_loss, "loss_dur": dur_loss, "mdn_loss": mdn_loss}

View File

@ -72,19 +72,9 @@ class AlignTTS(nn.Module):
hidden_channels=256,
hidden_channels_dp=256,
encoder_type="fftransformer",
encoder_params={
"hidden_channels_ffn": 1024,
"num_heads": 2,
"num_layers": 6,
"dropout_p": 0.1
},
encoder_params={"hidden_channels_ffn": 1024, "num_heads": 2, "num_layers": 6, "dropout_p": 0.1},
decoder_type="fftransformer",
decoder_params={
"hidden_channels_ffn": 1024,
"num_heads": 2,
"num_layers": 6,
"dropout_p": 0.1
},
decoder_params={"hidden_channels_ffn": 1024, "num_heads": 2, "num_layers": 6, "dropout_p": 0.1},
length_scale=1,
num_speakers=0,
external_c=False,
@ -93,14 +83,11 @@ class AlignTTS(nn.Module):
super().__init__()
self.phase = -1
self.length_scale = float(length_scale) if isinstance(
length_scale, int) else length_scale
self.length_scale = float(length_scale) if isinstance(length_scale, int) else length_scale
self.emb = nn.Embedding(num_chars, hidden_channels)
self.pos_encoder = PositionalEncoding(hidden_channels)
self.encoder = Encoder(hidden_channels, hidden_channels, encoder_type,
encoder_params, c_in_channels)
self.decoder = Decoder(out_channels, hidden_channels, decoder_type,
decoder_params)
self.encoder = Encoder(hidden_channels, hidden_channels, encoder_type, encoder_params, c_in_channels)
self.decoder = Decoder(out_channels, hidden_channels, decoder_type, decoder_params)
self.duration_predictor = DurationPredictor(hidden_channels_dp)
self.mod_layer = nn.Conv1d(hidden_channels, hidden_channels, 1)
@ -121,9 +108,9 @@ class AlignTTS(nn.Module):
mu = mu.transpose(1, 2).unsqueeze(2) # [B, T2, 1, D]
log_sigma = log_sigma.transpose(1, 2).unsqueeze(2) # [B, T2, 1, D]
expanded_y, expanded_mu = torch.broadcast_tensors(y, mu)
exponential = -0.5 * torch.mean(torch._C._nn.mse_loss(
expanded_y, expanded_mu, 0) / torch.pow(log_sigma.exp(), 2),
dim=-1) # B, L, T
exponential = -0.5 * torch.mean(
torch._C._nn.mse_loss(expanded_y, expanded_mu, 0) / torch.pow(log_sigma.exp(), 2), dim=-1
) # B, L, T
logp = exponential - 0.5 * log_sigma.mean(dim=-1)
return logp
@ -157,9 +144,7 @@ class AlignTTS(nn.Module):
[1, 0, 0, 0, 0, 0, 0]]
"""
attn = self.convert_dr_to_align(dr, x_mask, y_mask)
o_en_ex = torch.matmul(
attn.squeeze(1).transpose(1, 2), en.transpose(1,
2)).transpose(1, 2)
o_en_ex = torch.matmul(attn.squeeze(1).transpose(1, 2), en.transpose(1, 2)).transpose(1, 2)
return o_en_ex, attn
def format_durations(self, o_dr_log, x_mask):
@ -193,8 +178,7 @@ class AlignTTS(nn.Module):
x_emb = torch.transpose(x_emb, 1, -1)
# compute sequence masks
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]),
1).to(x.dtype)
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype)
# encoder pass
o_en = self.encoder(x_emb, x_mask)
@ -207,8 +191,7 @@ class AlignTTS(nn.Module):
return o_en, o_en_dp, x_mask, g
def _forward_decoder(self, o_en, o_en_dp, dr, x_mask, y_lengths, g):
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None),
1).to(o_en_dp.dtype)
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en_dp.dtype)
# expand o_en with durations
o_en_ex, attn = self.expand_encoder_outputs(o_en, dr, x_mask, y_mask)
# positional encoding
@ -224,13 +207,13 @@ class AlignTTS(nn.Module):
def _forward_mdn(self, o_en, y, y_lengths, x_mask):
# MAS potentials and alignment
mu, log_sigma = self.mdn_block(o_en)
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None),
1).to(o_en.dtype)
dr_mas, logp = self.compute_align_path(mu, log_sigma, y, x_mask,
y_mask)
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en.dtype)
dr_mas, logp = self.compute_align_path(mu, log_sigma, y, x_mask, y_mask)
return dr_mas, mu, log_sigma, logp
def forward(self, x, x_lengths, y, y_lengths, cond_input={"x_vectors": None}, phase=None): # pylint: disable=unused-argument
def forward(
self, x, x_lengths, y, y_lengths, cond_input={"x_vectors": None}, phase=None
): # pylint: disable=unused-argument
"""
Shapes:
x: [B, T_max]
@ -240,83 +223,58 @@ class AlignTTS(nn.Module):
g: [B, C]
"""
y = y.transpose(1, 2)
g = cond_input['x_vectors'] if 'x_vectors' in cond_input else None
g = cond_input["x_vectors"] if "x_vectors" in cond_input else None
o_de, o_dr_log, dr_mas_log, attn, mu, log_sigma, logp = None, None, None, None, None, None, None
if phase == 0:
# train encoder and MDN
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
dr_mas, mu, log_sigma, logp = self._forward_mdn(
o_en, y, y_lengths, x_mask)
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None),
1).to(o_en_dp.dtype)
dr_mas, mu, log_sigma, logp = self._forward_mdn(o_en, y, y_lengths, x_mask)
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en_dp.dtype)
attn = self.convert_dr_to_align(dr_mas, x_mask, y_mask)
elif phase == 1:
# train decoder
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
dr_mas, _, _, _ = self._forward_mdn(o_en, y, y_lengths, x_mask)
o_de, attn = self._forward_decoder(o_en.detach(),
o_en_dp.detach(),
dr_mas.detach(),
x_mask,
y_lengths,
g=g)
o_de, attn = self._forward_decoder(o_en.detach(), o_en_dp.detach(), dr_mas.detach(), x_mask, y_lengths, g=g)
elif phase == 2:
# train the whole except duration predictor
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
dr_mas, mu, log_sigma, logp = self._forward_mdn(
o_en, y, y_lengths, x_mask)
o_de, attn = self._forward_decoder(o_en,
o_en_dp,
dr_mas,
x_mask,
y_lengths,
g=g)
dr_mas, mu, log_sigma, logp = self._forward_mdn(o_en, y, y_lengths, x_mask)
o_de, attn = self._forward_decoder(o_en, o_en_dp, dr_mas, x_mask, y_lengths, g=g)
elif phase == 3:
# train duration predictor
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
o_dr_log = self.duration_predictor(x, x_mask)
dr_mas, mu, log_sigma, logp = self._forward_mdn(
o_en, y, y_lengths, x_mask)
o_de, attn = self._forward_decoder(o_en,
o_en_dp,
dr_mas,
x_mask,
y_lengths,
g=g)
dr_mas, mu, log_sigma, logp = self._forward_mdn(o_en, y, y_lengths, x_mask)
o_de, attn = self._forward_decoder(o_en, o_en_dp, dr_mas, x_mask, y_lengths, g=g)
o_dr_log = o_dr_log.squeeze(1)
else:
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask)
dr_mas, mu, log_sigma, logp = self._forward_mdn(
o_en, y, y_lengths, x_mask)
o_de, attn = self._forward_decoder(o_en,
o_en_dp,
dr_mas,
x_mask,
y_lengths,
g=g)
dr_mas, mu, log_sigma, logp = self._forward_mdn(o_en, y, y_lengths, x_mask)
o_de, attn = self._forward_decoder(o_en, o_en_dp, dr_mas, x_mask, y_lengths, g=g)
o_dr_log = o_dr_log.squeeze(1)
dr_mas_log = torch.log(dr_mas + 1).squeeze(1)
outputs = {
'model_outputs': o_de.transpose(1, 2),
'alignments': attn,
'durations_log': o_dr_log,
'durations_mas_log': dr_mas_log,
'mu': mu,
'log_sigma': log_sigma,
'logp': logp
"model_outputs": o_de.transpose(1, 2),
"alignments": attn,
"durations_log": o_dr_log,
"durations_mas_log": dr_mas_log,
"mu": mu,
"log_sigma": log_sigma,
"logp": logp,
}
return outputs
@torch.no_grad()
def inference(self, x, cond_input={'x_vectors': None}): # pylint: disable=unused-argument
def inference(self, x, cond_input={"x_vectors": None}): # pylint: disable=unused-argument
"""
Shapes:
x: [B, T_max]
x_lengths: [B]
g: [B, C]
"""
g = cond_input['x_vectors'] if 'x_vectors' in cond_input else None
g = cond_input["x_vectors"] if "x_vectors" in cond_input else None
x_lengths = torch.tensor(x.shape[1:2]).to(x.device)
# pad input to prevent dropping the last word
# x = torch.nn.functional.pad(x, pad=(0, 5), mode='constant', value=0)
@ -326,46 +284,40 @@ class AlignTTS(nn.Module):
# duration predictor pass
o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1)
y_lengths = o_dr.sum(1)
o_de, attn = self._forward_decoder(o_en,
o_en_dp,
o_dr,
x_mask,
y_lengths,
g=g)
outputs = {'model_outputs': o_de.transpose(1, 2), 'alignments': attn}
o_de, attn = self._forward_decoder(o_en, o_en_dp, o_dr, x_mask, y_lengths, g=g)
outputs = {"model_outputs": o_de.transpose(1, 2), "alignments": attn}
return outputs
def train_step(self, batch: dict, criterion: nn.Module):
text_input = batch['text_input']
text_lengths = batch['text_lengths']
mel_input = batch['mel_input']
mel_lengths = batch['mel_lengths']
x_vectors = batch['x_vectors']
speaker_ids = batch['speaker_ids']
text_input = batch["text_input"]
text_lengths = batch["text_lengths"]
mel_input = batch["mel_input"]
mel_lengths = batch["mel_lengths"]
x_vectors = batch["x_vectors"]
speaker_ids = batch["speaker_ids"]
cond_input = {'x_vectors': x_vectors, 'speaker_ids': speaker_ids}
cond_input = {"x_vectors": x_vectors, "speaker_ids": speaker_ids}
outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, cond_input, self.phase)
loss_dict = criterion(
outputs['logp'],
outputs['model_outputs'],
mel_input,
mel_lengths,
outputs['durations_log'],
outputs['durations_mas_log'],
text_lengths,
phase=self.phase,
)
outputs["logp"],
outputs["model_outputs"],
mel_input,
mel_lengths,
outputs["durations_log"],
outputs["durations_mas_log"],
text_lengths,
phase=self.phase,
)
# compute alignment error (the lower the better )
align_error = 1 - alignment_diagonal_score(outputs['alignments'],
binary=True)
align_error = 1 - alignment_diagonal_score(outputs["alignments"], binary=True)
loss_dict["align_error"] = align_error
return outputs, loss_dict
def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict):
model_outputs = outputs['model_outputs']
alignments = outputs['alignments']
mel_input = batch['mel_input']
model_outputs = outputs["model_outputs"]
alignments = outputs["alignments"]
mel_input = batch["mel_input"]
pred_spec = model_outputs[0].data.cpu().numpy()
gt_spec = mel_input[0].data.cpu().numpy()
@ -387,7 +339,9 @@ class AlignTTS(nn.Module):
def eval_log(self, ap: AudioProcessor, batch: dict, outputs: dict):
return self.train_log(ap, batch, outputs)
def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin
def load_checkpoint(
self, config, checkpoint_path, eval=False
): # pylint: disable=unused-argument, redefined-builtin
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"])
if eval:

View File

@ -38,6 +38,7 @@ class GlowTTS(nn.Module):
encoder_params (dict): encoder module parameters.
speaker_embedding_dim (int): channels of external speaker embedding vectors.
"""
def __init__(
self,
num_chars,
@ -132,17 +133,17 @@ class GlowTTS(nn.Module):
@staticmethod
def compute_outputs(attn, o_mean, o_log_scale, x_mask):
# compute final values with the computed alignment
y_mean = torch.matmul(
attn.squeeze(1).transpose(1, 2), o_mean.transpose(1, 2)).transpose(
1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
y_log_scale = torch.matmul(
attn.squeeze(1).transpose(1, 2), o_log_scale.transpose(
1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
y_mean = torch.matmul(attn.squeeze(1).transpose(1, 2), o_mean.transpose(1, 2)).transpose(
1, 2
) # [b, t', t], [b, t, d] -> [b, d, t']
y_log_scale = torch.matmul(attn.squeeze(1).transpose(1, 2), o_log_scale.transpose(1, 2)).transpose(
1, 2
) # [b, t', t], [b, t, d] -> [b, d, t']
# compute total duration with adjustment
o_attn_dur = torch.log(1 + torch.sum(attn, -1)) * x_mask
return y_mean, y_log_scale, o_attn_dur
def forward(self, x, x_lengths, y, y_lengths=None, cond_input={'x_vectors':None}):
def forward(self, x, x_lengths, y, y_lengths=None, cond_input={"x_vectors": None}):
"""
Shapes:
x: [B, T]
@ -154,7 +155,7 @@ class GlowTTS(nn.Module):
y_max_length = y.size(2)
y = y.transpose(1, 2)
# norm speaker embeddings
g = cond_input['x_vectors']
g = cond_input["x_vectors"]
if g is not None:
if self.speaker_embedding_dim:
g = F.normalize(g).unsqueeze(-1)
@ -162,54 +163,38 @@ class GlowTTS(nn.Module):
g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h, 1]
# embedding pass
o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x,
x_lengths,
g=g)
o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g)
# drop redisual frames wrt num_squeeze and set y_lengths.
y, y_lengths, y_max_length, attn = self.preprocess(
y, y_lengths, y_max_length, None)
y, y_lengths, y_max_length, attn = self.preprocess(y, y_lengths, y_max_length, None)
# create masks
y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length),
1).to(x_mask.dtype)
y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(x_mask.dtype)
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
# decoder pass
z, logdet = self.decoder(y, y_mask, g=g, reverse=False)
# find the alignment path
with torch.no_grad():
o_scale = torch.exp(-2 * o_log_scale)
logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale,
[1]).unsqueeze(-1) # [b, t, 1]
logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 *
(z**2)) # [b, t, d] x [b, d, t'] = [b, t, t']
logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2),
z) # [b, t, d] x [b, d, t'] = [b, t, t']
logp4 = torch.sum(-0.5 * (o_mean**2) * o_scale,
[1]).unsqueeze(-1) # [b, t, 1]
logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale, [1]).unsqueeze(-1) # [b, t, 1]
logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * (z ** 2)) # [b, t, d] x [b, d, t'] = [b, t, t']
logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2), z) # [b, t, d] x [b, d, t'] = [b, t, t']
logp4 = torch.sum(-0.5 * (o_mean ** 2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1]
logp = logp1 + logp2 + logp3 + logp4 # [b, t, t']
attn = maximum_path(logp,
attn_mask.squeeze(1)).unsqueeze(1).detach()
y_mean, y_log_scale, o_attn_dur = self.compute_outputs(
attn, o_mean, o_log_scale, x_mask)
attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()
y_mean, y_log_scale, o_attn_dur = self.compute_outputs(attn, o_mean, o_log_scale, x_mask)
attn = attn.squeeze(1).permute(0, 2, 1)
outputs = {
'model_outputs': z,
'logdet': logdet,
'y_mean': y_mean,
'y_log_scale': y_log_scale,
'alignments': attn,
'durations_log': o_dur_log,
'total_durations_log': o_attn_dur
"model_outputs": z,
"logdet": logdet,
"y_mean": y_mean,
"y_log_scale": y_log_scale,
"alignments": attn,
"durations_log": o_dur_log,
"total_durations_log": o_attn_dur,
}
return outputs
@torch.no_grad()
def inference_with_MAS(self,
x,
x_lengths,
y=None,
y_lengths=None,
attn=None,
g=None):
def inference_with_MAS(self, x, x_lengths, y=None, y_lengths=None, attn=None, g=None):
"""
It's similar to the teacher forcing in Tacotron.
It was proposed in: https://arxiv.org/abs/2104.05557
@ -229,33 +214,24 @@ class GlowTTS(nn.Module):
g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h, 1]
# embedding pass
o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x,
x_lengths,
g=g)
o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g)
# drop redisual frames wrt num_squeeze and set y_lengths.
y, y_lengths, y_max_length, attn = self.preprocess(
y, y_lengths, y_max_length, None)
y, y_lengths, y_max_length, attn = self.preprocess(y, y_lengths, y_max_length, None)
# create masks
y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length),
1).to(x_mask.dtype)
y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(x_mask.dtype)
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
# decoder pass
z, logdet = self.decoder(y, y_mask, g=g, reverse=False)
# find the alignment path between z and encoder output
o_scale = torch.exp(-2 * o_log_scale)
logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale,
[1]).unsqueeze(-1) # [b, t, 1]
logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 *
(z**2)) # [b, t, d] x [b, d, t'] = [b, t, t']
logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2),
z) # [b, t, d] x [b, d, t'] = [b, t, t']
logp4 = torch.sum(-0.5 * (o_mean**2) * o_scale,
[1]).unsqueeze(-1) # [b, t, 1]
logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale, [1]).unsqueeze(-1) # [b, t, 1]
logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * (z ** 2)) # [b, t, d] x [b, d, t'] = [b, t, t']
logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2), z) # [b, t, d] x [b, d, t'] = [b, t, t']
logp4 = torch.sum(-0.5 * (o_mean ** 2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1]
logp = logp1 + logp2 + logp3 + logp4 # [b, t, t']
attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()
y_mean, y_log_scale, o_attn_dur = self.compute_outputs(
attn, o_mean, o_log_scale, x_mask)
y_mean, y_log_scale, o_attn_dur = self.compute_outputs(attn, o_mean, o_log_scale, x_mask)
attn = attn.squeeze(1).permute(0, 2, 1)
# get predited aligned distribution
@ -264,13 +240,13 @@ class GlowTTS(nn.Module):
# reverse the decoder and predict using the aligned distribution
y, logdet = self.decoder(z, y_mask, g=g, reverse=True)
outputs = {
'model_outputs': y,
'logdet': logdet,
'y_mean': y_mean,
'y_log_scale': y_log_scale,
'alignments': attn,
'durations_log': o_dur_log,
'total_durations_log': o_attn_dur
"model_outputs": y,
"logdet": logdet,
"y_mean": y_mean,
"y_log_scale": y_log_scale,
"alignments": attn,
"durations_log": o_dur_log,
"total_durations_log": o_attn_dur,
}
return outputs
@ -290,8 +266,7 @@ class GlowTTS(nn.Module):
else:
g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h, 1]
y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length),
1).to(y.dtype)
y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(y.dtype)
# decoder pass
z, logdet = self.decoder(y, y_mask, g=g, reverse=False)
@ -310,37 +285,31 @@ class GlowTTS(nn.Module):
g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h]
# embedding pass
o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x,
x_lengths,
g=g)
o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g)
# compute output durations
w = (torch.exp(o_dur_log) - 1) * x_mask * self.length_scale
w_ceil = torch.ceil(w)
y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
y_max_length = None
# compute masks
y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length),
1).to(x_mask.dtype)
y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(x_mask.dtype)
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
# compute attention mask
attn = generate_path(w_ceil.squeeze(1),
attn_mask.squeeze(1)).unsqueeze(1)
y_mean, y_log_scale, o_attn_dur = self.compute_outputs(
attn, o_mean, o_log_scale, x_mask)
attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1)
y_mean, y_log_scale, o_attn_dur = self.compute_outputs(attn, o_mean, o_log_scale, x_mask)
z = (y_mean + torch.exp(y_log_scale) * torch.randn_like(y_mean) *
self.inference_noise_scale) * y_mask
z = (y_mean + torch.exp(y_log_scale) * torch.randn_like(y_mean) * self.inference_noise_scale) * y_mask
# decoder pass
y, logdet = self.decoder(z, y_mask, g=g, reverse=True)
attn = attn.squeeze(1).permute(0, 2, 1)
outputs = {
'model_outputs': y,
'logdet': logdet,
'y_mean': y_mean,
'y_log_scale': y_log_scale,
'alignments': attn,
'durations_log': o_dur_log,
'total_durations_log': o_attn_dur
"model_outputs": y,
"logdet": logdet,
"y_mean": y_mean,
"y_log_scale": y_log_scale,
"alignments": attn,
"durations_log": o_dur_log,
"total_durations_log": o_attn_dur,
}
return outputs
@ -351,32 +320,34 @@ class GlowTTS(nn.Module):
batch (dict): [description]
criterion (nn.Module): [description]
"""
text_input = batch['text_input']
text_lengths = batch['text_lengths']
mel_input = batch['mel_input']
mel_lengths = batch['mel_lengths']
x_vectors = batch['x_vectors']
text_input = batch["text_input"]
text_lengths = batch["text_lengths"]
mel_input = batch["mel_input"]
mel_lengths = batch["mel_lengths"]
x_vectors = batch["x_vectors"]
outputs = self.forward(text_input,
text_lengths,
mel_input,
mel_lengths,
cond_input={"x_vectors": x_vectors})
outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, cond_input={"x_vectors": x_vectors})
loss_dict = criterion(outputs['model_outputs'], outputs['y_mean'],
outputs['y_log_scale'], outputs['logdet'],
mel_lengths, outputs['durations_log'],
outputs['total_durations_log'], text_lengths)
loss_dict = criterion(
outputs["model_outputs"],
outputs["y_mean"],
outputs["y_log_scale"],
outputs["logdet"],
mel_lengths,
outputs["durations_log"],
outputs["total_durations_log"],
text_lengths,
)
# compute alignment error (the lower the better )
align_error = 1 - alignment_diagonal_score(outputs['alignments'], binary=True)
# compute alignment error (the lower the better )
align_error = 1 - alignment_diagonal_score(outputs["alignments"], binary=True)
loss_dict["align_error"] = align_error
return outputs, loss_dict
def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict):
model_outputs = outputs['model_outputs']
alignments = outputs['alignments']
mel_input = batch['mel_input']
model_outputs = outputs["model_outputs"]
alignments = outputs["alignments"]
mel_input = batch["mel_input"]
pred_spec = model_outputs[0].data.cpu().numpy()
gt_spec = mel_input[0].data.cpu().numpy()
@ -400,8 +371,7 @@ class GlowTTS(nn.Module):
def preprocess(self, y, y_lengths, y_max_length, attn=None):
if y_max_length is not None:
y_max_length = (y_max_length //
self.num_squeeze) * self.num_squeeze
y_max_length = (y_max_length // self.num_squeeze) * self.num_squeeze
y = y[:, :, :y_max_length]
if attn is not None:
attn = attn[:, :, :, :y_max_length]
@ -411,7 +381,9 @@ class GlowTTS(nn.Module):
def store_inverse(self):
self.decoder.store_inverse()
def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin
def load_checkpoint(
self, config, checkpoint_path, eval=False
): # pylint: disable=unused-argument, redefined-builtin
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"])
if eval:

View File

@ -49,12 +49,7 @@ class SpeedySpeech(nn.Module):
positional_encoding=True,
length_scale=1,
encoder_type="residual_conv_bn",
encoder_params={
"kernel_size": 4,
"dilations": 4 * [1, 2, 4] + [1],
"num_conv_blocks": 2,
"num_res_blocks": 13
},
encoder_params={"kernel_size": 4, "dilations": 4 * [1, 2, 4] + [1], "num_conv_blocks": 2, "num_res_blocks": 13},
decoder_type="residual_conv_bn",
decoder_params={
"kernel_size": 4,
@ -68,17 +63,13 @@ class SpeedySpeech(nn.Module):
):
super().__init__()
self.length_scale = float(length_scale) if isinstance(
length_scale, int) else length_scale
self.length_scale = float(length_scale) if isinstance(length_scale, int) else length_scale
self.emb = nn.Embedding(num_chars, hidden_channels)
self.encoder = Encoder(hidden_channels, hidden_channels, encoder_type,
encoder_params, c_in_channels)
self.encoder = Encoder(hidden_channels, hidden_channels, encoder_type, encoder_params, c_in_channels)
if positional_encoding:
self.pos_encoder = PositionalEncoding(hidden_channels)
self.decoder = Decoder(out_channels, hidden_channels, decoder_type,
decoder_params)
self.duration_predictor = DurationPredictor(hidden_channels +
c_in_channels)
self.decoder = Decoder(out_channels, hidden_channels, decoder_type, decoder_params)
self.duration_predictor = DurationPredictor(hidden_channels + c_in_channels)
if num_speakers > 1 and not external_c:
# speaker embedding layer
@ -105,9 +96,7 @@ class SpeedySpeech(nn.Module):
"""
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
attn = generate_path(dr, attn_mask.squeeze(1)).to(en.dtype)
o_en_ex = torch.matmul(
attn.squeeze(1).transpose(1, 2), en.transpose(1,
2)).transpose(1, 2)
o_en_ex = torch.matmul(attn.squeeze(1).transpose(1, 2), en.transpose(1, 2)).transpose(1, 2)
return o_en_ex, attn
def format_durations(self, o_dr_log, x_mask):
@ -141,8 +130,7 @@ class SpeedySpeech(nn.Module):
x_emb = torch.transpose(x_emb, 1, -1)
# compute sequence masks
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]),
1).to(x.dtype)
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype)
# encoder pass
o_en = self.encoder(x_emb, x_mask)
@ -155,8 +143,7 @@ class SpeedySpeech(nn.Module):
return o_en, o_en_dp, x_mask, g
def _forward_decoder(self, o_en, o_en_dp, dr, x_mask, y_lengths, g):
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None),
1).to(o_en_dp.dtype)
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en_dp.dtype)
# expand o_en with durations
o_en_ex, attn = self.expand_encoder_outputs(o_en, dr, x_mask, y_mask)
# positional encoding
@ -169,15 +156,9 @@ class SpeedySpeech(nn.Module):
o_de = self.decoder(o_en_ex, y_mask, g=g)
return o_de, attn.transpose(1, 2)
def forward(self,
x,
x_lengths,
y_lengths,
dr,
cond_input={
'x_vectors': None,
'speaker_ids': None
}): # pylint: disable=unused-argument
def forward(
self, x, x_lengths, y_lengths, dr, cond_input={"x_vectors": None, "speaker_ids": None}
): # pylint: disable=unused-argument
"""
TODO: speaker embedding for speaker_ids
Shapes:
@ -187,91 +168,68 @@ class SpeedySpeech(nn.Module):
dr: [B, T_max]
g: [B, C]
"""
g = cond_input['x_vectors'] if 'x_vectors' in cond_input else None
g = cond_input["x_vectors"] if "x_vectors" in cond_input else None
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask)
o_de, attn = self._forward_decoder(o_en,
o_en_dp,
dr,
x_mask,
y_lengths,
g=g)
outputs = {
'model_outputs': o_de.transpose(1, 2),
'durations_log': o_dr_log.squeeze(1),
'alignments': attn
}
o_de, attn = self._forward_decoder(o_en, o_en_dp, dr, x_mask, y_lengths, g=g)
outputs = {"model_outputs": o_de.transpose(1, 2), "durations_log": o_dr_log.squeeze(1), "alignments": attn}
return outputs
def inference(self,
x,
cond_input={
'x_vectors': None,
'speaker_ids': None
}): # pylint: disable=unused-argument
def inference(self, x, cond_input={"x_vectors": None, "speaker_ids": None}): # pylint: disable=unused-argument
"""
Shapes:
x: [B, T_max]
x_lengths: [B]
g: [B, C]
"""
g = cond_input['x_vectors'] if 'x_vectors' in cond_input else None
g = cond_input["x_vectors"] if "x_vectors" in cond_input else None
x_lengths = torch.tensor(x.shape[1:2]).to(x.device)
# input sequence should be greated than the max convolution size
inference_padding = 5
if x.shape[1] < 13:
inference_padding += 13 - x.shape[1]
# pad input to prevent dropping the last word
x = torch.nn.functional.pad(x,
pad=(0, inference_padding),
mode="constant",
value=0)
x = torch.nn.functional.pad(x, pad=(0, inference_padding), mode="constant", value=0)
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
# duration predictor pass
o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask)
o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1)
y_lengths = o_dr.sum(1)
o_de, attn = self._forward_decoder(o_en,
o_en_dp,
o_dr,
x_mask,
y_lengths,
g=g)
outputs = {
'model_outputs': o_de.transpose(1, 2),
'alignments': attn,
'durations_log': None
}
o_de, attn = self._forward_decoder(o_en, o_en_dp, o_dr, x_mask, y_lengths, g=g)
outputs = {"model_outputs": o_de.transpose(1, 2), "alignments": attn, "durations_log": None}
return outputs
def train_step(self, batch: dict, criterion: nn.Module):
text_input = batch['text_input']
text_lengths = batch['text_lengths']
mel_input = batch['mel_input']
mel_lengths = batch['mel_lengths']
x_vectors = batch['x_vectors']
speaker_ids = batch['speaker_ids']
durations = batch['durations']
text_input = batch["text_input"]
text_lengths = batch["text_lengths"]
mel_input = batch["mel_input"]
mel_lengths = batch["mel_lengths"]
x_vectors = batch["x_vectors"]
speaker_ids = batch["speaker_ids"]
durations = batch["durations"]
cond_input = {'x_vectors': x_vectors, 'speaker_ids': speaker_ids}
outputs = self.forward(text_input, text_lengths, mel_lengths,
durations, cond_input)
cond_input = {"x_vectors": x_vectors, "speaker_ids": speaker_ids}
outputs = self.forward(text_input, text_lengths, mel_lengths, durations, cond_input)
# compute loss
loss_dict = criterion(outputs['model_outputs'], mel_input,
mel_lengths, outputs['durations_log'],
torch.log(1 + durations), text_lengths)
loss_dict = criterion(
outputs["model_outputs"],
mel_input,
mel_lengths,
outputs["durations_log"],
torch.log(1 + durations),
text_lengths,
)
# compute alignment error (the lower the better )
align_error = 1 - alignment_diagonal_score(outputs['alignments'],
binary=True)
align_error = 1 - alignment_diagonal_score(outputs["alignments"], binary=True)
loss_dict["align_error"] = align_error
return outputs, loss_dict
def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict):
model_outputs = outputs['model_outputs']
alignments = outputs['alignments']
mel_input = batch['mel_input']
model_outputs = outputs["model_outputs"]
alignments = outputs["alignments"]
mel_input = batch["mel_input"]
pred_spec = model_outputs[0].data.cpu().numpy()
gt_spec = mel_input[0].data.cpu().numpy()
@ -293,7 +251,9 @@ class SpeedySpeech(nn.Module):
def eval_log(self, ap: AudioProcessor, batch: dict, outputs: dict):
return self.train_log(ap, batch, outputs)
def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin
def load_checkpoint(
self, config, checkpoint_path, eval=False
): # pylint: disable=unused-argument, redefined-builtin
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"])
if eval:

View File

@ -50,6 +50,7 @@ class Tacotron(TacotronAbstract):
gradual_trainin (List): Gradual training schedule. If None or `[]`, no gradual training is used.
Defaults to `[]`.
"""
def __init__(
self,
num_chars,
@ -78,7 +79,7 @@ class Tacotron(TacotronAbstract):
use_gst=False,
gst=None,
memory_size=5,
gradual_training=[]
gradual_training=[],
):
super().__init__(
num_chars,
@ -106,15 +107,14 @@ class Tacotron(TacotronAbstract):
speaker_embedding_dim,
use_gst,
gst,
gradual_training
gradual_training,
)
# speaker embedding layers
if self.num_speakers > 1:
if not self.embeddings_per_sample:
speaker_embedding_dim = 256
self.speaker_embedding = nn.Embedding(self.num_speakers,
speaker_embedding_dim)
self.speaker_embedding = nn.Embedding(self.num_speakers, speaker_embedding_dim)
self.speaker_embedding.weight.data.normal_(0, 0.3)
# speaker and gst embeddings is concat in decoder input
@ -145,8 +145,7 @@ class Tacotron(TacotronAbstract):
separate_stopnet,
)
self.postnet = PostCBHG(decoder_output_dim)
self.last_linear = nn.Linear(self.postnet.cbhg.gru_features * 2,
postnet_output_dim)
self.last_linear = nn.Linear(self.postnet.cbhg.gru_features * 2, postnet_output_dim)
# setup prenet dropout
self.decoder.prenet.dropout_at_inference = prenet_dropout_at_inference
@ -183,12 +182,7 @@ class Tacotron(TacotronAbstract):
separate_stopnet,
)
def forward(self,
text,
text_lengths,
mel_specs=None,
mel_lengths=None,
cond_input=None):
def forward(self, text, text_lengths, mel_specs=None, mel_lengths=None, cond_input=None):
"""
Shapes:
text: [B, T_in]
@ -197,100 +191,87 @@ class Tacotron(TacotronAbstract):
mel_lengths: [B]
cond_input: 'speaker_ids': [B, 1] and 'x_vectors':[B, C]
"""
outputs = {
'alignments_backward': None,
'decoder_outputs_backward': None
}
outputs = {"alignments_backward": None, "decoder_outputs_backward": None}
input_mask, output_mask = self.compute_masks(text_lengths, mel_lengths)
# B x T_in x embed_dim
inputs = self.embedding(text)
# B x T_in x encoder_in_features
encoder_outputs = self.encoder(inputs)
# sequence masking
encoder_outputs = encoder_outputs * input_mask.unsqueeze(2).expand_as(
encoder_outputs)
encoder_outputs = encoder_outputs * input_mask.unsqueeze(2).expand_as(encoder_outputs)
# global style token
if self.gst and self.use_gst:
# B x gst_dim
encoder_outputs = self.compute_gst(encoder_outputs, mel_specs,
cond_input['x_vectors'])
encoder_outputs = self.compute_gst(encoder_outputs, mel_specs, cond_input["x_vectors"])
# speaker embedding
if self.num_speakers > 1:
if not self.embeddings_per_sample:
# B x 1 x speaker_embed_dim
speaker_embeddings = self.speaker_embedding(cond_input['speaker_ids'])[:,
None]
speaker_embeddings = self.speaker_embedding(cond_input["speaker_ids"])[:, None]
else:
# B x 1 x speaker_embed_dim
speaker_embeddings = torch.unsqueeze(cond_input['x_vectors'], 1)
encoder_outputs = self._concat_speaker_embedding(
encoder_outputs, speaker_embeddings)
speaker_embeddings = torch.unsqueeze(cond_input["x_vectors"], 1)
encoder_outputs = self._concat_speaker_embedding(encoder_outputs, speaker_embeddings)
# decoder_outputs: B x decoder_in_features x T_out
# alignments: B x T_in x encoder_in_features
# stop_tokens: B x T_in
decoder_outputs, alignments, stop_tokens = self.decoder(
encoder_outputs, mel_specs, input_mask)
decoder_outputs, alignments, stop_tokens = self.decoder(encoder_outputs, mel_specs, input_mask)
# sequence masking
if output_mask is not None:
decoder_outputs = decoder_outputs * output_mask.unsqueeze(
1).expand_as(decoder_outputs)
decoder_outputs = decoder_outputs * output_mask.unsqueeze(1).expand_as(decoder_outputs)
# B x T_out x decoder_in_features
postnet_outputs = self.postnet(decoder_outputs)
# sequence masking
if output_mask is not None:
postnet_outputs = postnet_outputs * output_mask.unsqueeze(
2).expand_as(postnet_outputs)
postnet_outputs = postnet_outputs * output_mask.unsqueeze(2).expand_as(postnet_outputs)
# B x T_out x posnet_dim
postnet_outputs = self.last_linear(postnet_outputs)
# B x T_out x decoder_in_features
decoder_outputs = decoder_outputs.transpose(1, 2).contiguous()
if self.bidirectional_decoder:
decoder_outputs_backward, alignments_backward = self._backward_pass(
mel_specs, encoder_outputs, input_mask)
outputs['alignments_backward'] = alignments_backward
outputs['decoder_outputs_backward'] = decoder_outputs_backward
decoder_outputs_backward, alignments_backward = self._backward_pass(mel_specs, encoder_outputs, input_mask)
outputs["alignments_backward"] = alignments_backward
outputs["decoder_outputs_backward"] = decoder_outputs_backward
if self.double_decoder_consistency:
decoder_outputs_backward, alignments_backward = self._coarse_decoder_pass(
mel_specs, encoder_outputs, alignments, input_mask)
outputs['alignments_backward'] = alignments_backward
outputs['decoder_outputs_backward'] = decoder_outputs_backward
outputs.update({
'model_outputs': postnet_outputs,
'decoder_outputs': decoder_outputs,
'alignments': alignments,
'stop_tokens': stop_tokens
})
mel_specs, encoder_outputs, alignments, input_mask
)
outputs["alignments_backward"] = alignments_backward
outputs["decoder_outputs_backward"] = decoder_outputs_backward
outputs.update(
{
"model_outputs": postnet_outputs,
"decoder_outputs": decoder_outputs,
"alignments": alignments,
"stop_tokens": stop_tokens,
}
)
return outputs
@torch.no_grad()
def inference(self,
text_input,
cond_input=None):
def inference(self, text_input, cond_input=None):
inputs = self.embedding(text_input)
encoder_outputs = self.encoder(inputs)
if self.gst and self.use_gst:
# B x gst_dim
encoder_outputs = self.compute_gst(encoder_outputs, cond_input['style_mel'],
cond_input['x_vectors'])
encoder_outputs = self.compute_gst(encoder_outputs, cond_input["style_mel"], cond_input["x_vectors"])
if self.num_speakers > 1:
if not self.embeddings_per_sample:
# B x 1 x speaker_embed_dim
speaker_embeddings = self.speaker_embedding(cond_input['speaker_ids'])[:, None]
speaker_embeddings = self.speaker_embedding(cond_input["speaker_ids"])[:, None]
else:
# B x 1 x speaker_embed_dim
speaker_embeddings = torch.unsqueeze(cond_input['x_vectors'], 1)
encoder_outputs = self._concat_speaker_embedding(
encoder_outputs, speaker_embeddings)
decoder_outputs, alignments, stop_tokens = self.decoder.inference(
encoder_outputs)
speaker_embeddings = torch.unsqueeze(cond_input["x_vectors"], 1)
encoder_outputs = self._concat_speaker_embedding(encoder_outputs, speaker_embeddings)
decoder_outputs, alignments, stop_tokens = self.decoder.inference(encoder_outputs)
postnet_outputs = self.postnet(decoder_outputs)
postnet_outputs = self.last_linear(postnet_outputs)
decoder_outputs = decoder_outputs.transpose(1, 2)
outputs = {
"model_outputs": postnet_outputs,
"decoder_outputs": decoder_outputs,
"alignments": alignments,
"stop_tokens": stop_tokens,
}
return outputs
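# Usage sketch (not part of this diff; tensors are hypothetical). For a single-speaker model
# without GST, `cond_input` is never read, so it can be left as None:
#   text_ids = torch.LongTensor([[1, 2, 3, 4]])   # [B, T_in] character ids
#   outputs = model.inference(text_ids)
#   mel = outputs["model_outputs"]                 # [B, T_out, output feature dim]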
@@ -301,64 +282,61 @@ class Tacotron(TacotronAbstract):
batch (Dict): a dictionary of batch tensors (text inputs and lengths, mel targets and lengths, stop targets, speaker info).
criterion (nn.Module): the loss module used to compute the training losses.
"""
text_input = batch["text_input"]
text_lengths = batch["text_lengths"]
mel_input = batch["mel_input"]
mel_lengths = batch["mel_lengths"]
linear_input = batch["linear_input"]
stop_targets = batch["stop_targets"]
speaker_ids = batch["speaker_ids"]
x_vectors = batch["x_vectors"]
# forward pass model
outputs = self.forward(
text_input,
text_lengths,
mel_input,
mel_lengths,
cond_input={"speaker_ids": speaker_ids, "x_vectors": x_vectors},
)
# set the [alignment] lengths wrt reduction factor for guided attention
if mel_lengths.max() % self.decoder.r != 0:
alignment_lengths = (
mel_lengths + (self.decoder.r - (mel_lengths.max() % self.decoder.r))
) // self.decoder.r
else:
alignment_lengths = mel_lengths // self.decoder.r
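# Worked example: with reduction factor r=2 and a longest target of 101 frames, 101 % 2 != 0,
# so lengths are padded up to the next multiple of r and (101 + 1) // 2 = 51 alignment steps result.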
cond_input = {"speaker_ids": speaker_ids, "x_vectors": x_vectors}
outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, cond_input)
# compute loss
loss_dict = criterion(
outputs["model_outputs"],
outputs["decoder_outputs"],
mel_input,
linear_input,
outputs["stop_tokens"],
stop_targets,
mel_lengths,
outputs["decoder_outputs_backward"],
outputs["alignments"],
alignment_lengths,
outputs["alignments_backward"],
text_lengths,
)
# compute alignment error (the lower, the better)
align_error = 1 - alignment_diagonal_score(outputs["alignments"])
loss_dict["align_error"] = align_error
return outputs, loss_dict
def train_log(self, ap, batch, outputs):
postnet_outputs = outputs["model_outputs"]
alignments = outputs["alignments"]
alignments_backward = outputs["alignments_backward"]
mel_input = batch["mel_input"]
pred_spec = postnet_outputs[0].data.cpu().numpy()
gt_spec = mel_input[0].data.cpu().numpy()
@@ -371,8 +349,7 @@ class Tacotron(TacotronAbstract):
}
if self.bidirectional_decoder or self.double_decoder_consistency:
figures["alignment_backward"] = plot_alignment(
alignments_backward[0].data.cpu().numpy(), output_fig=False)
figures["alignment_backward"] = plot_alignment(alignments_backward[0].data.cpu().numpy(), output_fig=False)
# Sample audio
train_audio = ap.inv_spectrogram(pred_spec.T)
@@ -382,4 +359,4 @@ class Tacotron(TacotronAbstract):
return self.train_step(batch, criterion)
def eval_log(self, ap, batch, outputs):
return self.train_log(ap, batch, outputs)

View File

@@ -49,49 +49,70 @@ class Tacotron2(TacotronAbstract):
gradual_training (List): Gradual training schedule. If None or `[]`, no gradual training is used.
Defaults to `[]`.
"""
def __init__(
self,
num_chars,
num_speakers,
r,
postnet_output_dim=80,
decoder_output_dim=80,
attn_type="original",
attn_win=False,
attn_norm="softmax",
prenet_type="original",
prenet_dropout=True,
prenet_dropout_at_inference=False,
forward_attn=False,
trans_agent=False,
forward_attn_mask=False,
location_attn=True,
attn_K=5,
separate_stopnet=True,
bidirectional_decoder=False,
double_decoder_consistency=False,
ddc_r=None,
encoder_in_features=512,
decoder_in_features=512,
speaker_embedding_dim=None,
use_gst=False,
gst=None,
gradual_training=[],
):
super().__init__(
num_chars,
num_speakers,
r,
postnet_output_dim,
decoder_output_dim,
attn_type,
attn_win,
attn_norm,
prenet_type,
prenet_dropout,
prenet_dropout_at_inference,
forward_attn,
trans_agent,
forward_attn_mask,
location_attn,
attn_K,
separate_stopnet,
bidirectional_decoder,
double_decoder_consistency,
ddc_r,
encoder_in_features,
decoder_in_features,
speaker_embedding_dim,
use_gst,
gst,
gradual_training,
)
# speaker embedding layer
if self.num_speakers > 1:
if not self.embeddings_per_sample:
speaker_embedding_dim = 512
self.speaker_embedding = nn.Embedding(self.num_speakers, speaker_embedding_dim)
self.speaker_embedding.weight.data.normal_(0, 0.3)
# speaker and GST embeddings are concatenated with the decoder input
@@ -162,12 +183,7 @@ class Tacotron2(TacotronAbstract):
mel_outputs_postnet = mel_outputs_postnet.transpose(1, 2)
return mel_outputs, mel_outputs_postnet, alignments
def forward(self, text, text_lengths, mel_specs=None, mel_lengths=None, cond_input=None):
"""
Shapes:
text: [B, T_in]
@@ -176,10 +192,7 @@ class Tacotron2(TacotronAbstract):
mel_lengths: [B]
cond_input: 'speaker_ids': [B, 1] and 'x_vectors':[B, C]
"""
outputs = {"alignments_backward": None, "decoder_outputs_backward": None}
# compute mask for padding
# B x T_in_max (boolean)
input_mask, output_mask = self.compute_masks(text_lengths, mel_lengths)
@@ -189,55 +202,49 @@ class Tacotron2(TacotronAbstract):
encoder_outputs = self.encoder(embedded_inputs, text_lengths)
if self.gst and self.use_gst:
# B x gst_dim
encoder_outputs = self.compute_gst(encoder_outputs, mel_specs, cond_input["x_vectors"])
if self.num_speakers > 1:
if not self.embeddings_per_sample:
# B x 1 x speaker_embed_dim
speaker_embeddings = self.speaker_embedding(cond_input["speaker_ids"])[:, None]
else:
# B x 1 x speaker_embed_dim
speaker_embeddings = torch.unsqueeze(cond_input["x_vectors"], 1)
encoder_outputs = self._concat_speaker_embedding(encoder_outputs, speaker_embeddings)
encoder_outputs = encoder_outputs * input_mask.unsqueeze(2).expand_as(encoder_outputs)
# B x mel_dim x T_out -- B x T_out//r x T_in -- B x T_out//r
decoder_outputs, alignments, stop_tokens = self.decoder(encoder_outputs, mel_specs, input_mask)
# sequence masking
if mel_lengths is not None:
decoder_outputs = decoder_outputs * output_mask.unsqueeze(1).expand_as(decoder_outputs)
# B x mel_dim x T_out
postnet_outputs = self.postnet(decoder_outputs)
postnet_outputs = decoder_outputs + postnet_outputs
# sequence masking
if output_mask is not None:
postnet_outputs = postnet_outputs * output_mask.unsqueeze(1).expand_as(postnet_outputs)
# B x T_out x mel_dim -- B x T_out x mel_dim -- B x T_out//r x T_in
decoder_outputs, postnet_outputs, alignments = self.shape_outputs(decoder_outputs, postnet_outputs, alignments)
if self.bidirectional_decoder:
decoder_outputs_backward, alignments_backward = self._backward_pass(mel_specs, encoder_outputs, input_mask)
outputs["alignments_backward"] = alignments_backward
outputs["decoder_outputs_backward"] = decoder_outputs_backward
if self.double_decoder_consistency:
decoder_outputs_backward, alignments_backward = self._coarse_decoder_pass(
mel_specs, encoder_outputs, alignments, input_mask
)
outputs["alignments_backward"] = alignments_backward
outputs["decoder_outputs_backward"] = decoder_outputs_backward
outputs.update(
{
"model_outputs": postnet_outputs,
"decoder_outputs": decoder_outputs,
"alignments": alignments,
"stop_tokens": stop_tokens,
}
)
return outputs
@torch.no_grad()
@@ -247,29 +254,25 @@ class Tacotron2(TacotronAbstract):
if self.gst and self.use_gst:
# B x gst_dim
encoder_outputs = self.compute_gst(encoder_outputs, cond_input["style_mel"], cond_input["x_vectors"])
if self.num_speakers > 1:
if not self.embeddings_per_sample:
x_vector = self.speaker_embedding(cond_input['speaker_ids'])[:, None]
x_vector = torch.unsqueeze(x_vector, 0).transpose(1, 2)
else:
x_vector = cond_input["x_vectors"]
encoder_outputs = self._concat_speaker_embedding(encoder_outputs, x_vector)
decoder_outputs, alignments, stop_tokens = self.decoder.inference(encoder_outputs)
postnet_outputs = self.postnet(decoder_outputs)
postnet_outputs = decoder_outputs + postnet_outputs
decoder_outputs, postnet_outputs, alignments = self.shape_outputs(decoder_outputs, postnet_outputs, alignments)
outputs = {
"model_outputs": postnet_outputs,
"decoder_outputs": decoder_outputs,
"alignments": alignments,
"stop_tokens": stop_tokens,
}
return outputs
@@ -280,64 +283,61 @@ class Tacotron2(TacotronAbstract):
batch (Dict): a dictionary of batch tensors (text inputs and lengths, mel targets and lengths, stop targets, speaker info).
criterion (nn.Module): the loss module used to compute the training losses.
"""
text_input = batch["text_input"]
text_lengths = batch["text_lengths"]
mel_input = batch["mel_input"]
mel_lengths = batch["mel_lengths"]
linear_input = batch["linear_input"]
stop_targets = batch["stop_targets"]
speaker_ids = batch["speaker_ids"]
x_vectors = batch["x_vectors"]
# forward pass model
outputs = self.forward(
text_input,
text_lengths,
mel_input,
mel_lengths,
cond_input={"speaker_ids": speaker_ids, "x_vectors": x_vectors},
)
# set the [alignment] lengths wrt reduction factor for guided attention
if mel_lengths.max() % self.decoder.r != 0:
alignment_lengths = (
mel_lengths + (self.decoder.r - (mel_lengths.max() % self.decoder.r))
) // self.decoder.r
else:
alignment_lengths = mel_lengths // self.decoder.r
cond_input = {"speaker_ids": speaker_ids, "x_vectors": x_vectors}
outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, cond_input)
# compute loss
loss_dict = criterion(
outputs["model_outputs"],
outputs["decoder_outputs"],
mel_input,
linear_input,
outputs["stop_tokens"],
stop_targets,
mel_lengths,
outputs["decoder_outputs_backward"],
outputs["alignments"],
alignment_lengths,
outputs["alignments_backward"],
text_lengths,
)
# compute alignment error (the lower, the better)
align_error = 1 - alignment_diagonal_score(outputs["alignments"])
loss_dict["align_error"] = align_error
return outputs, loss_dict
def train_log(self, ap, batch, outputs):
postnet_outputs = outputs["model_outputs"]
alignments = outputs["alignments"]
alignments_backward = outputs["alignments_backward"]
mel_input = batch["mel_input"]
pred_spec = postnet_outputs[0].data.cpu().numpy()
gt_spec = mel_input[0].data.cpu().numpy()
@@ -350,8 +350,7 @@ class Tacotron2(TacotronAbstract):
}
if self.bidirectional_decoder or self.double_decoder_consistency:
figures["alignment_backward"] = plot_alignment(
alignments_backward[0].data.cpu().numpy(), output_fig=False)
figures["alignment_backward"] = plot_alignment(alignments_backward[0].data.cpu().numpy(), output_fig=False)
# Sample audio
train_audio = ap.inv_melspectrogram(pred_spec.T)

View File

@@ -37,7 +37,7 @@ class TacotronAbstract(ABC, nn.Module):
speaker_embedding_dim=None,
use_gst=False,
gst=None,
gradual_training=[],
):
"""Abstract Tacotron class"""
super().__init__()

View File

@@ -59,25 +59,16 @@ def numpy_to_tf(np_array, dtype):
def compute_style_mel(style_wav, ap, cuda=False):
style_mel = torch.FloatTensor(ap.melspectrogram(ap.load_wav(style_wav, sr=ap.sample_rate))).unsqueeze(0)
if cuda:
return style_mel.cuda()
return style_mel
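# Usage sketch (hypothetical path; `ap` is an initialized AudioProcessor, and ap.melspectrogram
# is assumed to return a [num_mels, T] array, so the result is a FloatTensor shaped [1, num_mels, T]):
#   style_mel = compute_style_mel("reference.wav", ap)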
def run_model_torch(model, inputs, speaker_id=None, style_mel=None, x_vector=None):
outputs = model.inference(
inputs, cond_input={"speaker_ids": speaker_id, "x_vector": x_vector, "style_mel": style_mel}
)
return outputs
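# Usage sketch (hypothetical preprocessing; the keys mirror the cond_input dict built above):
#   text_ids = torch.LongTensor(sequence).unsqueeze(0)   # `sequence` is a precomputed list of symbol ids
#   outputs = run_model_torch(model, text_ids, x_vector=x_vector)
#   mel = outputs["model_outputs"]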
@@ -87,18 +78,15 @@ def run_model_tf(model, inputs, CONFIG, speaker_id=None, style_mel=None):
if speaker_id is not None:
raise NotImplementedError(" [!] Multi-Speaker not implemented for TF")
# TODO: handle multispeaker case
decoder_output, postnet_output, alignments, stop_tokens = model(inputs, training=False)
return decoder_output, postnet_output, alignments, stop_tokens
def run_model_tflite(model, inputs, CONFIG, speaker_id=None, style_mel=None):
if CONFIG.gst and style_mel is not None:
raise NotImplementedError(" [!] GST inference not implemented for TfLite")
if speaker_id is not None:
raise NotImplementedError(" [!] Multi-Speaker not implemented for TfLite")
# get input and output details
input_details = model.get_input_details()
output_details = model.get_output_details()
@@ -132,7 +120,7 @@ def parse_outputs_tflite(postnet_output, decoder_output):
def trim_silence(wav, ap):
return wav[: ap.find_endpoint(wav)]
def inv_spectrogram(postnet_output, ap, CONFIG):
@@ -155,8 +143,7 @@ def speaker_id_to_torch(speaker_id, cuda=False):
def embedding_to_torch(x_vector, cuda=False):
if x_vector is not None:
x_vector = np.asarray(x_vector)
x_vector = torch.from_numpy(x_vector).unsqueeze(0).type(torch.FloatTensor)
if cuda:
return x_vector.cuda()
return x_vector
@@ -173,8 +160,7 @@ def apply_griffin_lim(inputs, input_lens, CONFIG, ap):
"""
wavs = []
for idx, spec in enumerate(inputs):
wav_len = (input_lens[idx] * ap.hop_length) - ap.hop_length # inverse librosa padding
wav = inv_spectrogram(spec, ap, CONFIG)
# assert len(wav) == wav_len, f" [!] wav length: {len(wav)} vs expected: {wav_len}"
wavs.append(wav[:wav_len])
@@ -242,23 +228,21 @@ def synthesis(
text_inputs = tf.expand_dims(text_inputs, 0)
# synthesize voice
if backend == "torch":
outputs = run_model_torch(model, text_inputs, speaker_id, style_mel, x_vector=x_vector)
model_outputs = outputs["model_outputs"]
model_outputs = model_outputs[0].data.cpu().numpy()
elif backend == "tf":
decoder_output, postnet_output, alignments, stop_tokens = run_model_tf(
model, text_inputs, CONFIG, speaker_id, style_mel
)
model_outputs, decoder_output, alignment, stop_tokens = parse_outputs_tf(
postnet_output, decoder_output, alignments, stop_tokens
)
elif backend == "tflite":
decoder_output, postnet_output, alignment, stop_tokens = run_model_tflite(
model, text_inputs, CONFIG, speaker_id, style_mel
)
model_outputs, decoder_output = parse_outputs_tflite(postnet_output, decoder_output)
# convert outputs to numpy
# plot results
wav = None
@@ -268,9 +252,9 @@ def synthesis(
if do_trim_silence:
wav = trim_silence(wav, ap)
return_dict = {
"wav": wav,
"alignments": outputs["alignments"],
"model_outputs": model_outputs,
"text_inputs": text_inputs,
}
return return_dict
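# Downstream use (sketch, assuming `ap` exposes a save_wav(wav, path) helper and "wav" was filled
# by the Griffin-Lim branch above; otherwise "model_outputs" can be passed to a vocoder):
#   ap.save_wav(return_dict["wav"], "output.wav")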

View File

@@ -30,16 +30,16 @@ def init_arguments(argv):
parser.add_argument(
"--continue_path",
type=str,
help=("Training output folder to continue training. Used to continue "
"a training. If it is used, 'config_path' is ignored."),
help=(
"Training output folder to continue training. Used to continue "
"a training. If it is used, 'config_path' is ignored."
),
default="",
required="--config_path" not in argv,
)
parser.add_argument(
"--restore_path",
type=str,
help="Model file to be restored. Use to finetune a model.",
default="")
"--restore_path", type=str, help="Model file to be restored. Use to finetune a model.", default=""
)
parser.add_argument(
"--best_path",
type=str,
@@ -49,23 +49,12 @@ def init_arguments(argv):
),
default="",
)
parser.add_argument("--config_path",
type=str,
help="Path to config file for training.",
required="--continue_path" not in argv)
parser.add_argument("--debug",
type=bool,
default=False,
help="Do not verify commit integrity to run training.")
parser.add_argument(
"--rank",
type=int,
default=0,
help="DISTRIBUTED: process rank for distributed training.")
parser.add_argument("--group_id",
type=str,
default="",
help="DISTRIBUTED: process group id.")
"--config_path", type=str, help="Path to config file for training.", required="--continue_path" not in argv
)
parser.add_argument("--debug", type=bool, default=False, help="Do not verify commit integrity to run training.")
parser.add_argument("--rank", type=int, default=0, help="DISTRIBUTED: process rank for distributed training.")
parser.add_argument("--group_id", type=str, default="", help="DISTRIBUTED: process group id.")
return parser
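# Example invocations (sketch; paths are placeholders):
#   python TTS/bin/train_tts.py --config_path config.json
#   python TTS/bin/train_tts.py --continue_path /path/to/previous/run/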
@@ -160,8 +149,7 @@ def process_args(args):
print(" > Mixed precision mode is ON")
experiment_path = args.continue_path
if not experiment_path:
experiment_path = create_experiment_folder(config.output_path, config.run_name, args.debug)
audio_path = os.path.join(experiment_path, "test_audios")
# setup rank 0 process in distributed training
tb_logger = None
@@ -182,8 +170,7 @@ def process_args(args):
os.chmod(experiment_path, 0o775)
tb_logger = TensorboardLogger(experiment_path, model_name=config.model)
# write model desc to tensorboard
tb_logger.tb_add_text("model-config", f"<pre>{config.to_json()}</pre>",
0)
tb_logger.tb_add_text("model-config", f"<pre>{config.to_json()}</pre>", 0)
c_logger = ConsoleLogger()
return config, experiment_path, audio_path, c_logger, tb_logger

View File

@@ -234,8 +234,8 @@ class Synthesizer(object):
use_griffin_lim=use_gl,
x_vector=speaker_embedding,
)
waveform = outputs["wav"]
mel_postnet_spec = outputs["model_outputs"]
if not use_gl:
# denormalize tts output based on tts audio config
mel_postnet_spec = self.ap.denormalize(mel_postnet_spec.T).T

View File

@@ -44,8 +44,6 @@ run_cli(command_train)
continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
# restore the model and continue training for one more epoch
command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
run_cli(command_train)
shutil.rmtree(continue_path)

View File

@@ -46,8 +46,6 @@ run_cli(command_train)
continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
# restore the model and continue training for one more epoch
command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
run_cli(command_train)
shutil.rmtree(continue_path)

View File

@@ -45,8 +45,6 @@ run_cli(command_train)
continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
# restore the model and continue training for one more epoch
command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
run_cli(command_train)
shutil.rmtree(continue_path)

View File

@@ -45,8 +45,6 @@ run_cli(command_train)
continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
# restore the model and continue training for one more epoch
command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
run_cli(command_train)
shutil.rmtree(continue_path)

View File

@@ -44,8 +44,6 @@ run_cli(command_train)
continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
# restore the model and continue training for one more epoch
command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
run_cli(command_train)
shutil.rmtree(continue_path)

View File

@@ -21,7 +21,6 @@ config = MelganConfig(
print_step=1,
discriminator_model_params={"base_channels": 16, "max_channels": 256, "downsample_factors": [4, 4, 4]},
print_eval=True,
data_path="tests/data/ljspeech",
output_path=output_path,
)