mirror of https://github.com/coqui-ai/TTS.git
* Fix #813
* Update glow_tts recipe
* Fix glow-tts test
* Linter fix
* Run data dep init only in training
This commit is contained in:
parent f563415052
commit 2766dd1d6e
@@ -1,4 +1,5 @@
 from dataclasses import dataclass, field
+from typing import List

 from TTS.tts.configs.shared_configs import BaseTTSConfig
@@ -167,3 +168,14 @@ class GlowTTSConfig(BaseTTSConfig):
     min_seq_len: int = 3
     max_seq_len: int = 500
     r: int = 1  # DO NOT CHANGE - TODO: make this immutable once coqpit implements it.
+
+    # testing
+    test_sentences: List[str] = field(
+        default_factory=lambda: [
+            "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
+            "Be a voice, not an echo.",
+            "I'm sorry Dave. I'm afraid I can't do that.",
+            "This cake is great. It's so delicious and moist.",
+            "Prior to November 22, 1963.",
+        ]
+    )
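The field(default_factory=...) pattern in the hunk above is how a dataclass takes a mutable default: a bare list literal as a class-level default raises ValueError at class-definition time. A minimal standalone sketch of the idiom, with illustrative names that are not from the repo:

from dataclasses import dataclass, field
from typing import List


@dataclass
class ExampleConfig:
    # `sentences: List[str] = []` would be rejected by @dataclass;
    # default_factory builds a fresh list for every instance instead
    sentences: List[str] = field(default_factory=lambda: ["Be a voice, not an echo."])


print(ExampleConfig().sentences)  # ['Be a voice, not an echo.']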
@@ -119,7 +119,7 @@ class SpeedySpeechConfig(BaseTTSConfig):
         hidden_channels=128,
         num_speakers=0,
         positional_encoding=True,
-        detach_duration_predictor=True
+        detach_duration_predictor=True,
     )

     # multi-speaker settings
@@ -165,9 +165,9 @@ class Encoder(nn.Module):
         # set duration predictor input
         if g is not None:
             g_exp = g.expand(-1, -1, x.size(-1))
-            x_dp = torch.cat([torch.detach(x), g_exp], 1)
+            x_dp = torch.cat([x.detach(), g_exp], 1)
         else:
-            x_dp = torch.detach(x)
+            x_dp = x.detach()
         # final projection layer
         x_m = self.proj_m(x) * x_mask
         if not self.mean_only:
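x.detach() is simply the method form of torch.detach(x); both return a tensor that shares storage but is cut out of the autograd graph, which is what keeps the duration-predictor branch from pushing gradients back into the text encoder. A tiny sketch of the behaviour (illustrative, not repo code):

import torch

x = torch.randn(3, requires_grad=True)
detached = (x.detach() * 2).sum()  # no path back to x
normal = (x * 2).sum()             # gradients flow as usual

normal.backward()
print(x.grad)                  # tensor([2., 2., 2.])
print(detached.requires_grad)  # False - the detached branch is never tracked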
@@ -427,11 +427,11 @@ class GlowTTSLoss(torch.nn.Module):
         return_dict = {}
         # flow loss - neg log likelihood
         pz = torch.sum(scales) + 0.5 * torch.sum(torch.exp(-2 * scales) * (z - means) ** 2)
-        log_mle = self.constant_factor + (pz - torch.sum(log_det)) / (torch.sum(y_lengths) * z.shape[1])
+        log_mle = self.constant_factor + (pz - torch.sum(log_det)) / (torch.sum(y_lengths) * z.shape[2])
         # duration loss - MSE
-        # loss_dur = torch.sum((o_dur_log - o_attn_dur)**2) / torch.sum(x_lengths)
+        loss_dur = torch.sum((o_dur_log - o_attn_dur) ** 2) / torch.sum(x_lengths)
         # duration loss - huber loss
-        loss_dur = torch.nn.functional.smooth_l1_loss(o_dur_log, o_attn_dur, reduction="sum") / torch.sum(x_lengths)
+        # loss_dur = torch.nn.functional.smooth_l1_loss(o_dur_log, o_attn_dur, reduction="sum") / torch.sum(x_lengths)
         return_dict["loss"] = log_mle + loss_dur
         return_dict["log_mle"] = log_mle
         return_dict["loss_dur"] = loss_dur
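Written out, the flow loss above is the negative log-likelihood of the latent z under a diagonal Gaussian with mean mu (`means`) and log standard deviation log sigma (`scales`), corrected by the flow's log-determinant and averaged per latent scalar. A sketch of the math, under the assumptions that constant_factor is 0.5*log(2*pi) and that z reaches the criterion shaped [B, T, C], so z.shape[2] is the channel count C and sum(y_lengths) the total frame count:

    \mathcal{L}_{\text{mle}} = \tfrac{1}{2}\log(2\pi)
        + \frac{\sum \log\sigma \;+\; \tfrac{1}{2}\sum e^{-2\log\sigma}(z-\mu)^2 \;-\; \sum \log\lvert\det J\rvert}{C \cdot \sum_b T_b}

Under that reading, the z.shape[1] -> z.shape[2] change makes the normalizer the number of latent scalars (total frames times mel channels) rather than the padded time dimension, and the duration term switches from the Huber loss to the MSE on log durations.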
@@ -2,6 +2,7 @@ import math

 import torch
 from torch import nn
+from torch.cuda.amp.autocast_mode import autocast
 from torch.nn import functional as F

 from TTS.tts.configs import GlowTTSConfig
@@ -68,6 +69,8 @@ class GlowTTS(BaseTTS):
         # TODO: make this adjustable
         self.c_in_channels = 256

+        self.run_data_dep_init = config.data_dep_init_steps > 0
+
         self.encoder = Encoder(
             self.num_chars,
             out_channels=self.out_channels,
@@ -131,6 +134,18 @@ class GlowTTS(BaseTTS):
         o_attn_dur = torch.log(1 + torch.sum(attn, -1)) * x_mask
         return y_mean, y_log_scale, o_attn_dur

+    def unlock_act_norm_layers(self):
+        """Unlock activation normalization layers for data-dependent initialization."""
+        for f in self.decoder.flows:
+            if getattr(f, "set_ddi", False):
+                f.set_ddi(True)
+
+    def lock_act_norm_layers(self):
+        """Lock activation normalization layers."""
+        for f in self.decoder.flows:
+            if getattr(f, "set_ddi", False):
+                f.set_ddi(False)
+
     def forward(
         self, x, x_lengths, y, y_lengths=None, aux_input={"d_vectors": None, "speaker_ids": None}
     ):  # pylint: disable=dangerous-default-value
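For context: set_ddi toggles data-dependent initialization (DDI) on the decoder's activation-normalization flows, the Glow trick of initializing each ActNorm layer's scale and bias from the statistics of the first real batch so activations start out roughly zero-mean and unit-variance. The repo's layer is not shown in this diff; the following is only a minimal sketch of the idea, with invented names:

import torch
from torch import nn


class ActNormSketch(nn.Module):
    """Per-channel affine flow whose parameters can be (re-)initialized from data."""

    def __init__(self, channels):
        super().__init__()
        self.logs = nn.Parameter(torch.zeros(1, channels, 1))  # log scale
        self.bias = nn.Parameter(torch.zeros(1, channels, 1))
        self.ddi = False  # when True, the next forward pass estimates logs/bias from the batch

    def set_ddi(self, ddi):
        self.ddi = ddi

    def forward(self, x):  # x: [B, C, T]
        if self.ddi:
            with torch.no_grad():
                mean = x.mean(dim=(0, 2), keepdim=True)
                std = x.std(dim=(0, 2), keepdim=True) + 1e-6
                self.logs.data.copy_(-torch.log(std))  # scale = 1 / std
                self.bias.data.copy_(-mean / std)      # shift to zero mean
            self.ddi = False
        return x * torch.exp(self.logs) + self.bias    # ~N(0, 1) right after DDI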
@@ -142,6 +157,7 @@ class GlowTTS(BaseTTS):
             - y_lengths::math:`B`
             - g: :math:`[B, C] or B`
         """
+        # [B, T, C] -> [B, C, T]
         y = y.transpose(1, 2)
         y_max_length = y.size(2)
         # norm speaker embeddings
@@ -157,6 +173,7 @@ class GlowTTS(BaseTTS):
         y, y_lengths, y_max_length, attn = self.preprocess(y, y_lengths, y_max_length, None)
         # create masks
         y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(x_mask.dtype)
+        # [B, 1, T_en, T_de]
         attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
         # decoder pass
         z, logdet = self.decoder(y, y_mask, g=g, reverse=False)
@@ -172,7 +189,7 @@ class GlowTTS(BaseTTS):
         y_mean, y_log_scale, o_attn_dur = self.compute_outputs(attn, o_mean, o_log_scale, x_mask)
         attn = attn.squeeze(1).permute(0, 2, 1)
         outputs = {
-            "model_outputs": z.transpose(1, 2),
+            "z": z.transpose(1, 2),
             "logdet": logdet,
             "y_mean": y_mean.transpose(1, 2),
             "y_log_scale": y_log_scale.transpose(1, 2),
@@ -319,7 +336,8 @@ class GlowTTS(BaseTTS):
         return outputs

     def train_step(self, batch: dict, criterion: nn.Module):
-        """Perform a single training step by fetching the right set if samples from the batch.
+        """A single training step: forward pass and loss computation. Runs data-dependent initialization for the
+        first `config.data_dep_init_steps` steps.

         Args:
             batch (dict): [description]
@@ -332,6 +350,22 @@ class GlowTTS(BaseTTS):
         d_vectors = batch["d_vectors"]
         speaker_ids = batch["speaker_ids"]

+        if self.run_data_dep_init and self.training:
+            # compute data-dependent initialization of activation norm layers
+            self.unlock_act_norm_layers()
+            with torch.no_grad():
+                _ = self.forward(
+                    text_input,
+                    text_lengths,
+                    mel_input,
+                    mel_lengths,
+                    aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids},
+                )
+            outputs = None
+            loss_dict = None
+            self.lock_act_norm_layers()
+        else:
+            # normal training step
             outputs = self.forward(
                 text_input,
                 text_lengths,
@@ -340,23 +374,33 @@ class GlowTTS(BaseTTS):
                 aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids},
             )

+            with autocast(enabled=False):  # avoid mixed_precision in criterion
                 loss_dict = criterion(
-                    outputs["model_outputs"],
-                    outputs["y_mean"],
-                    outputs["y_log_scale"],
-                    outputs["logdet"],
+                    outputs["z"].float(),
+                    outputs["y_mean"].float(),
+                    outputs["y_log_scale"].float(),
+                    outputs["logdet"].float(),
                     mel_lengths,
-                    outputs["durations_log"],
-                    outputs["total_durations_log"],
+                    outputs["durations_log"].float(),
+                    outputs["total_durations_log"].float(),
                     text_lengths,
                 )

         return outputs, loss_dict

     def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict):  # pylint: disable=no-self-use
-        model_outputs = outputs["model_outputs"]
         alignments = outputs["alignments"]
+        text_input = batch["text_input"]
+        text_lengths = batch["text_lengths"]
         mel_input = batch["mel_input"]
+        d_vectors = batch["d_vectors"]
+        speaker_ids = batch["speaker_ids"]
+
+        # model runs reverse flow to predict spectrograms
+        pred_outputs = self.inference(
+            text_input[:1],
+            aux_input={"x_lengths": text_lengths[:1], "d_vectors": d_vectors, "speaker_ids": speaker_ids},
+        )
+        model_outputs = pred_outputs["model_outputs"]
+
         pred_spec = model_outputs[0].data.cpu().numpy()
         gt_spec = mel_input[0].data.cpu().numpy()
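Disabling autocast around the criterion is the standard mixed-precision pattern: the forward pass runs in half precision, but likelihood-style losses with exp/log terms are computed in float32 to avoid overflow and underflow, hence the .float() casts on every tensor fed to the loss. A generic sketch of that pattern outside the trainer (illustrative model and optimizer, assumes a CUDA device is available):

import torch
from torch.cuda.amp import GradScaler, autocast

model = torch.nn.Linear(80, 80).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scaler = GradScaler()

x = torch.randn(8, 80, device="cuda")
target = torch.randn(8, 80, device="cuda")

with autocast(enabled=True):       # half-precision forward pass
    pred = model(x)
with autocast(enabled=False):      # loss math in full precision
    loss = torch.nn.functional.mse_loss(pred.float(), target)

scaler.scale(loss).backward()      # scale the loss so fp16 grads do not underflow
scaler.step(optimizer)
scaler.update()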
@@ -393,6 +437,9 @@ class GlowTTS(BaseTTS):
         test_figures = {}
         test_sentences = self.config.test_sentences
         aux_inputs = self.get_aux_input()
+        if len(test_sentences) == 0:
+            print(" | [!] No test sentences provided.")
+        else:
             for idx, sen in enumerate(test_sentences):
                 outputs = synthesis(
                     self,
@@ -441,3 +488,7 @@ class GlowTTS(BaseTTS):
         from TTS.tts.layers.losses import GlowTTSLoss  # pylint: disable=import-outside-toplevel

         return GlowTTSLoss()
+
+    def on_train_step_start(self, trainer):
+        """Decide on every training step whether to enable or disable data-dependent initialization."""
+        self.run_data_dep_init = trainer.total_steps_done < self.data_dep_init_steps
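on_train_step_start is a model-side hook the trainer is expected to call before every optimization step; together with the run_data_dep_init flag it keeps data-dependent initialization limited to the first data_dep_init_steps steps, and only while the model is in training mode. A toy sketch of how such a hook plugs into a loop (hypothetical trainer, not the repo's Trainer class):

class ToyTrainer:
    def __init__(self, model, criterion):
        self.model = model
        self.criterion = criterion
        self.total_steps_done = 0

    def fit(self, batches):
        self.model.train()
        for batch in batches:
            # let the model adapt per-step behaviour (e.g. switch DDI on/off)
            if hasattr(self.model, "on_train_step_start"):
                self.model.on_train_step_start(self)
            outputs, loss_dict = self.model.train_step(batch, self.criterion)
            if loss_dict is not None:
                # backward pass and optimizer step would go here;
                # DDI steps return None and skip optimization
                pass
            self.total_steps_done += 1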
@@ -15,13 +15,13 @@ config = GlowTTSConfig(
    run_eval=True,
    test_delay_epochs=-1,
    epochs=1000,
-    text_cleaner="english_cleaners",
-    use_phonemes=False,
+    text_cleaner="phoneme_cleaners",
+    use_phonemes=True,
    phoneme_language="en-us",
    phoneme_cache_path=os.path.join(output_path, "phoneme_cache"),
    print_step=25,
-    print_eval=True,
-    mixed_precision=False,
+    print_eval=False,
+    mixed_precision=True,
    output_path=output_path,
    datasets=[dataset_config],
 )
@@ -63,7 +63,7 @@ class GlowTTSTrainTest(unittest.TestCase):
            optimizer.zero_grad()
            outputs = model.forward(input_dummy, input_lengths, mel_spec, mel_lengths, None)
            loss_dict = criterion(
-                outputs["model_outputs"],
+                outputs["z"],
                outputs["y_mean"],
                outputs["y_log_scale"],
                outputs["logdet"],