mirror of https://github.com/coqui-ai/TTS.git
rename to
This commit is contained in:
parent
4575b70826
commit
f077a356e0
|
@ -136,14 +136,14 @@ def inference(
|
|||
speaker_c = d_vectors
|
||||
|
||||
outputs = model.inference_with_MAS(
|
||||
text_input, text_lengths, mel_input, mel_lengths, cond_input={"d_vectors": speaker_c}
|
||||
text_input, text_lengths, mel_input, mel_lengths, aux_input={"d_vectors": speaker_c}
|
||||
)
|
||||
model_output = outputs["model_outputs"]
|
||||
model_output = model_output.transpose(1, 2).detach().cpu().numpy()
|
||||
|
||||
elif "tacotron" in model_name:
|
||||
cond_input = {"speaker_ids": speaker_ids, "d_vectors": d_vectors}
|
||||
outputs = model(text_input, text_lengths, mel_input, mel_lengths, cond_input)
|
||||
aux_input = {"speaker_ids": speaker_ids, "d_vectors": d_vectors}
|
||||
outputs = model(text_input, text_lengths, mel_input, mel_lengths, aux_input)
|
||||
postnet_outputs = outputs["model_outputs"]
|
||||
# normalize tacotron output
|
||||
if model_name == "tacotron":
|
||||
|
|
|
@ -573,7 +573,7 @@ class TrainerTTS:
|
|||
test_audios = {}
|
||||
test_figures = {}
|
||||
test_sentences = self.config.test_sentences
|
||||
cond_inputs = self._get_cond_inputs()
|
||||
aux_inputs = self._get_aux_inputs()
|
||||
for idx, sen in enumerate(test_sentences):
|
||||
wav, alignment, model_outputs, _ = synthesis(
|
||||
self.model,
|
||||
|
@ -581,9 +581,9 @@ class TrainerTTS:
|
|||
self.config,
|
||||
self.use_cuda,
|
||||
self.ap,
|
||||
speaker_id=cond_inputs["speaker_id"],
|
||||
d_vector=cond_inputs["d_vector"],
|
||||
style_wav=cond_inputs["style_wav"],
|
||||
speaker_id=aux_inputs["speaker_id"],
|
||||
d_vector=aux_inputs["d_vector"],
|
||||
style_wav=aux_inputs["style_wav"],
|
||||
enable_eos_bos_chars=self.config.enable_eos_bos_chars,
|
||||
use_griffin_lim=True,
|
||||
do_trim_silence=False,
|
||||
|
@ -600,7 +600,7 @@ class TrainerTTS:
|
|||
self.tb_logger.tb_test_audios(self.total_steps_done, test_audios, self.config.audio["sample_rate"])
|
||||
self.tb_logger.tb_test_figures(self.total_steps_done, test_figures)
|
||||
|
||||
def _get_cond_inputs(self) -> Dict:
|
||||
def _get_aux_inputs(self) -> Dict:
|
||||
# setup speaker_id
|
||||
speaker_id = 0 if self.config.use_speaker_embedding else None
|
||||
# setup d_vector
|
||||
|
@ -620,8 +620,8 @@ class TrainerTTS:
|
|||
print("WARNING: You don't provided a gst style wav, for this reason we use a zero tensor!")
|
||||
for i in range(self.config.gst["gst_num_style_tokens"]):
|
||||
style_wav[str(i)] = 0
|
||||
cond_inputs = {"speaker_id": speaker_id, "style_wav": style_wav, "d_vector": d_vector}
|
||||
return cond_inputs
|
||||
aux_inputs = {"speaker_id": speaker_id, "style_wav": style_wav, "d_vector": d_vector}
|
||||
return aux_inputs
|
||||
|
||||
def fit(self) -> None:
|
||||
if self.restore_step != 0 or self.args.best_path:
|
||||
|
|
|
@ -212,7 +212,7 @@ class AlignTTS(nn.Module):
|
|||
return dr_mas, mu, log_sigma, logp
|
||||
|
||||
def forward(
|
||||
self, x, x_lengths, y, y_lengths, cond_input={"d_vectors": None}, phase=None
|
||||
self, x, x_lengths, y, y_lengths, aux_input={"d_vectors": None}, phase=None
|
||||
): # pylint: disable=unused-argument
|
||||
"""
|
||||
Shapes:
|
||||
|
@ -223,7 +223,7 @@ class AlignTTS(nn.Module):
|
|||
g: [B, C]
|
||||
"""
|
||||
y = y.transpose(1, 2)
|
||||
g = cond_input["d_vectors"] if "d_vectors" in cond_input else None
|
||||
g = aux_input["d_vectors"] if "d_vectors" in aux_input else None
|
||||
o_de, o_dr_log, dr_mas_log, attn, mu, log_sigma, logp = None, None, None, None, None, None, None
|
||||
if phase == 0:
|
||||
# train encoder and MDN
|
||||
|
@ -267,14 +267,14 @@ class AlignTTS(nn.Module):
|
|||
return outputs
|
||||
|
||||
@torch.no_grad()
|
||||
def inference(self, x, cond_input={"d_vectors": None}): # pylint: disable=unused-argument
|
||||
def inference(self, x, aux_input={"d_vectors": None}): # pylint: disable=unused-argument
|
||||
"""
|
||||
Shapes:
|
||||
x: [B, T_max]
|
||||
x_lengths: [B]
|
||||
g: [B, C]
|
||||
"""
|
||||
g = cond_input["d_vectors"] if "d_vectors" in cond_input else None
|
||||
g = aux_input["d_vectors"] if "d_vectors" in aux_input else None
|
||||
x_lengths = torch.tensor(x.shape[1:2]).to(x.device)
|
||||
# pad input to prevent dropping the last word
|
||||
# x = torch.nn.functional.pad(x, pad=(0, 5), mode='constant', value=0)
|
||||
|
@ -296,8 +296,8 @@ class AlignTTS(nn.Module):
|
|||
d_vectors = batch["d_vectors"]
|
||||
speaker_ids = batch["speaker_ids"]
|
||||
|
||||
cond_input = {"d_vectors": d_vectors, "speaker_ids": speaker_ids}
|
||||
outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, cond_input, self.phase)
|
||||
aux_input = {"d_vectors": d_vectors, "speaker_ids": speaker_ids}
|
||||
outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, aux_input, self.phase)
|
||||
loss_dict = criterion(
|
||||
outputs["logp"],
|
||||
outputs["model_outputs"],
|
||||
|
|
|
@ -144,7 +144,7 @@ class GlowTTS(nn.Module):
|
|||
return y_mean, y_log_scale, o_attn_dur
|
||||
|
||||
def forward(
|
||||
self, x, x_lengths, y, y_lengths=None, cond_input={"d_vectors": None}
|
||||
self, x, x_lengths, y, y_lengths=None, aux_input={"d_vectors": None}
|
||||
): # pylint: disable=dangerous-default-value
|
||||
"""
|
||||
Shapes:
|
||||
|
@ -157,7 +157,7 @@ class GlowTTS(nn.Module):
|
|||
y = y.transpose(1, 2)
|
||||
y_max_length = y.size(2)
|
||||
# norm speaker embeddings
|
||||
g = cond_input["d_vectors"] if cond_input is not None and "d_vectors" in cond_input else None
|
||||
g = aux_input["d_vectors"] if aux_input is not None and "d_vectors" in aux_input else None
|
||||
if g is not None:
|
||||
if self.d_vector_dim:
|
||||
g = F.normalize(g).unsqueeze(-1)
|
||||
|
@ -197,7 +197,7 @@ class GlowTTS(nn.Module):
|
|||
|
||||
@torch.no_grad()
|
||||
def inference_with_MAS(
|
||||
self, x, x_lengths, y=None, y_lengths=None, cond_input={"d_vectors": None}
|
||||
self, x, x_lengths, y=None, y_lengths=None, aux_input={"d_vectors": None}
|
||||
): # pylint: disable=dangerous-default-value
|
||||
"""
|
||||
It's similar to the teacher forcing in Tacotron.
|
||||
|
@ -212,7 +212,7 @@ class GlowTTS(nn.Module):
|
|||
y = y.transpose(1, 2)
|
||||
y_max_length = y.size(2)
|
||||
# norm speaker embeddings
|
||||
g = cond_input["d_vectors"] if cond_input is not None and "d_vectors" in cond_input else None
|
||||
g = aux_input["d_vectors"] if aux_input is not None and "d_vectors" in aux_input else None
|
||||
if g is not None:
|
||||
if self.external_d_vector_dim:
|
||||
g = F.normalize(g).unsqueeze(-1)
|
||||
|
@ -258,7 +258,7 @@ class GlowTTS(nn.Module):
|
|||
|
||||
@torch.no_grad()
|
||||
def decoder_inference(
|
||||
self, y, y_lengths=None, cond_input={"d_vectors": None}
|
||||
self, y, y_lengths=None, aux_input={"d_vectors": None}
|
||||
): # pylint: disable=dangerous-default-value
|
||||
"""
|
||||
Shapes:
|
||||
|
@ -268,7 +268,7 @@ class GlowTTS(nn.Module):
|
|||
"""
|
||||
y = y.transpose(1, 2)
|
||||
y_max_length = y.size(2)
|
||||
g = cond_input["d_vectors"] if cond_input is not None and "d_vectors" in cond_input else None
|
||||
g = aux_input["d_vectors"] if aux_input is not None and "d_vectors" in aux_input else None
|
||||
# norm speaker embeddings
|
||||
if g is not None:
|
||||
if self.external_d_vector_dim:
|
||||
|
@ -290,11 +290,9 @@ class GlowTTS(nn.Module):
|
|||
return outputs
|
||||
|
||||
@torch.no_grad()
|
||||
def inference(
|
||||
self, x, cond_input={"x_lengths": None, "d_vectors": None}
|
||||
): # pylint: disable=dangerous-default-value
|
||||
x_lengths = cond_input["x_lengths"]
|
||||
g = cond_input["d_vectors"] if cond_input is not None and "d_vectors" in cond_input else None
|
||||
def inference(self, x, aux_input={"x_lengths": None, "d_vectors": None}): # pylint: disable=dangerous-default-value
|
||||
x_lengths = aux_input["x_lengths"]
|
||||
g = aux_input["d_vectors"] if aux_input is not None and "d_vectors" in aux_input else None
|
||||
if g is not None:
|
||||
if self.d_vector_dim:
|
||||
g = F.normalize(g).unsqueeze(-1)
|
||||
|
@ -343,7 +341,7 @@ class GlowTTS(nn.Module):
|
|||
mel_lengths = batch["mel_lengths"]
|
||||
d_vectors = batch["d_vectors"]
|
||||
|
||||
outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, cond_input={"d_vectors": d_vectors})
|
||||
outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, aux_input={"d_vectors": d_vectors})
|
||||
|
||||
loss_dict = criterion(
|
||||
outputs["model_outputs"],
|
||||
|
|
|
@ -157,7 +157,7 @@ class SpeedySpeech(nn.Module):
|
|||
return o_de, attn.transpose(1, 2)
|
||||
|
||||
def forward(
|
||||
self, x, x_lengths, y_lengths, dr, cond_input={"d_vectors": None, "speaker_ids": None}
|
||||
self, x, x_lengths, y_lengths, dr, aux_input={"d_vectors": None, "speaker_ids": None}
|
||||
): # pylint: disable=unused-argument
|
||||
"""
|
||||
TODO: speaker embedding for speaker_ids
|
||||
|
@ -168,21 +168,21 @@ class SpeedySpeech(nn.Module):
|
|||
dr: [B, T_max]
|
||||
g: [B, C]
|
||||
"""
|
||||
g = cond_input["d_vectors"] if "d_vectors" in cond_input else None
|
||||
g = aux_input["d_vectors"] if "d_vectors" in aux_input else None
|
||||
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
|
||||
o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask)
|
||||
o_de, attn = self._forward_decoder(o_en, o_en_dp, dr, x_mask, y_lengths, g=g)
|
||||
outputs = {"model_outputs": o_de.transpose(1, 2), "durations_log": o_dr_log.squeeze(1), "alignments": attn}
|
||||
return outputs
|
||||
|
||||
def inference(self, x, cond_input={"d_vectors": None, "speaker_ids": None}): # pylint: disable=unused-argument
|
||||
def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None}): # pylint: disable=unused-argument
|
||||
"""
|
||||
Shapes:
|
||||
x: [B, T_max]
|
||||
x_lengths: [B]
|
||||
g: [B, C]
|
||||
"""
|
||||
g = cond_input["d_vectors"] if "d_vectors" in cond_input else None
|
||||
g = aux_input["d_vectors"] if "d_vectors" in aux_input else None
|
||||
x_lengths = torch.tensor(x.shape[1:2]).to(x.device)
|
||||
# input sequence should be greated than the max convolution size
|
||||
inference_padding = 5
|
||||
|
@ -208,8 +208,8 @@ class SpeedySpeech(nn.Module):
|
|||
speaker_ids = batch["speaker_ids"]
|
||||
durations = batch["durations"]
|
||||
|
||||
cond_input = {"d_vectors": d_vectors, "speaker_ids": speaker_ids}
|
||||
outputs = self.forward(text_input, text_lengths, mel_lengths, durations, cond_input)
|
||||
aux_input = {"d_vectors": d_vectors, "speaker_ids": speaker_ids}
|
||||
outputs = self.forward(text_input, text_lengths, mel_lengths, durations, aux_input)
|
||||
|
||||
# compute loss
|
||||
loss_dict = criterion(
|
||||
|
|
|
@ -186,14 +186,14 @@ class Tacotron(TacotronAbstract):
|
|||
max_decoder_steps,
|
||||
)
|
||||
|
||||
def forward(self, text, text_lengths, mel_specs=None, mel_lengths=None, cond_input=None):
|
||||
def forward(self, text, text_lengths, mel_specs=None, mel_lengths=None, aux_input=None):
|
||||
"""
|
||||
Shapes:
|
||||
text: [B, T_in]
|
||||
text_lengths: [B]
|
||||
mel_specs: [B, T_out, C]
|
||||
mel_lengths: [B]
|
||||
cond_input: 'speaker_ids': [B, 1] and 'd_vectors':[B, C]
|
||||
aux_input: 'speaker_ids': [B, 1] and 'd_vectors':[B, C]
|
||||
"""
|
||||
outputs = {"alignments_backward": None, "decoder_outputs_backward": None}
|
||||
inputs = self.embedding(text)
|
||||
|
@ -205,15 +205,15 @@ class Tacotron(TacotronAbstract):
|
|||
# global style token
|
||||
if self.gst and self.use_gst:
|
||||
# B x gst_dim
|
||||
encoder_outputs = self.compute_gst(encoder_outputs, mel_specs, cond_input["d_vectors"])
|
||||
encoder_outputs = self.compute_gst(encoder_outputs, mel_specs, aux_input["d_vectors"])
|
||||
# speaker embedding
|
||||
if self.num_speakers > 1:
|
||||
if not self.use_d_vectors:
|
||||
# B x 1 x speaker_embed_dim
|
||||
embedded_speakers = self.speaker_embedding(cond_input["speaker_ids"])[:, None]
|
||||
embedded_speakers = self.speaker_embedding(aux_input["speaker_ids"])[:, None]
|
||||
else:
|
||||
# B x 1 x speaker_embed_dim
|
||||
embedded_speakers = torch.unsqueeze(cond_input["d_vectors"], 1)
|
||||
embedded_speakers = torch.unsqueeze(aux_input["d_vectors"], 1)
|
||||
encoder_outputs = self._concat_speaker_embedding(encoder_outputs, embedded_speakers)
|
||||
# decoder_outputs: B x decoder_in_features x T_out
|
||||
# alignments: B x T_in x encoder_in_features
|
||||
|
@ -252,17 +252,17 @@ class Tacotron(TacotronAbstract):
|
|||
return outputs
|
||||
|
||||
@torch.no_grad()
|
||||
def inference(self, text_input, cond_input=None):
|
||||
cond_input = self._format_cond_input(cond_input)
|
||||
def inference(self, text_input, aux_input=None):
|
||||
aux_input = self._format_aux_input(aux_input)
|
||||
inputs = self.embedding(text_input)
|
||||
encoder_outputs = self.encoder(inputs)
|
||||
if self.gst and self.use_gst:
|
||||
# B x gst_dim
|
||||
encoder_outputs = self.compute_gst(encoder_outputs, cond_input["style_mel"], cond_input["d_vectors"])
|
||||
encoder_outputs = self.compute_gst(encoder_outputs, aux_input["style_mel"], aux_input["d_vectors"])
|
||||
if self.num_speakers > 1:
|
||||
if not self.use_d_vectors:
|
||||
# B x 1 x speaker_embed_dim
|
||||
embedded_speakers = self.speaker_embedding(cond_input["speaker_ids"])
|
||||
embedded_speakers = self.speaker_embedding(aux_input["speaker_ids"])
|
||||
# reshape embedded_speakers
|
||||
if embedded_speakers.ndim == 1:
|
||||
embedded_speakers = embedded_speakers[None, None, :]
|
||||
|
@ -270,7 +270,7 @@ class Tacotron(TacotronAbstract):
|
|||
embedded_speakers = embedded_speakers[None, :]
|
||||
else:
|
||||
# B x 1 x speaker_embed_dim
|
||||
embedded_speakers = torch.unsqueeze(cond_input["d_vectors"], 1)
|
||||
embedded_speakers = torch.unsqueeze(aux_input["d_vectors"], 1)
|
||||
encoder_outputs = self._concat_speaker_embedding(encoder_outputs, embedded_speakers)
|
||||
decoder_outputs, alignments, stop_tokens = self.decoder.inference(encoder_outputs)
|
||||
postnet_outputs = self.postnet(decoder_outputs)
|
||||
|
@ -306,7 +306,7 @@ class Tacotron(TacotronAbstract):
|
|||
text_lengths,
|
||||
mel_input,
|
||||
mel_lengths,
|
||||
cond_input={"speaker_ids": speaker_ids, "d_vectors": d_vectors},
|
||||
aux_input={"speaker_ids": speaker_ids, "d_vectors": d_vectors},
|
||||
)
|
||||
|
||||
# set the [alignment] lengths wrt reduction factor for guided attention
|
||||
|
@ -317,8 +317,8 @@ class Tacotron(TacotronAbstract):
|
|||
else:
|
||||
alignment_lengths = mel_lengths // self.decoder.r
|
||||
|
||||
cond_input = {"speaker_ids": speaker_ids, "d_vectors": d_vectors}
|
||||
outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, cond_input)
|
||||
aux_input = {"speaker_ids": speaker_ids, "d_vectors": d_vectors}
|
||||
outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, aux_input)
|
||||
|
||||
# compute loss
|
||||
loss_dict = criterion(
|
||||
|
|
|
@ -186,16 +186,16 @@ class Tacotron2(TacotronAbstract):
|
|||
mel_outputs_postnet = mel_outputs_postnet.transpose(1, 2)
|
||||
return mel_outputs, mel_outputs_postnet, alignments
|
||||
|
||||
def forward(self, text, text_lengths, mel_specs=None, mel_lengths=None, cond_input=None):
|
||||
def forward(self, text, text_lengths, mel_specs=None, mel_lengths=None, aux_input=None):
|
||||
"""
|
||||
Shapes:
|
||||
text: [B, T_in]
|
||||
text_lengths: [B]
|
||||
mel_specs: [B, T_out, C]
|
||||
mel_lengths: [B]
|
||||
cond_input: 'speaker_ids': [B, 1] and 'd_vectors':[B, C]
|
||||
aux_input: 'speaker_ids': [B, 1] and 'd_vectors':[B, C]
|
||||
"""
|
||||
cond_input = self._format_cond_input(cond_input)
|
||||
aux_input = self._format_aux_input(aux_input)
|
||||
outputs = {"alignments_backward": None, "decoder_outputs_backward": None}
|
||||
# compute mask for padding
|
||||
# B x T_in_max (boolean)
|
||||
|
@ -206,14 +206,14 @@ class Tacotron2(TacotronAbstract):
|
|||
encoder_outputs = self.encoder(embedded_inputs, text_lengths)
|
||||
if self.gst and self.use_gst:
|
||||
# B x gst_dim
|
||||
encoder_outputs = self.compute_gst(encoder_outputs, mel_specs, cond_input["d_vectors"])
|
||||
encoder_outputs = self.compute_gst(encoder_outputs, mel_specs, aux_input["d_vectors"])
|
||||
if self.num_speakers > 1:
|
||||
if not self.use_d_vectors:
|
||||
# B x 1 x speaker_embed_dim
|
||||
embedded_speakers = self.speaker_embedding(cond_input["speaker_ids"])[:, None]
|
||||
embedded_speakers = self.speaker_embedding(aux_input["speaker_ids"])[:, None]
|
||||
else:
|
||||
# B x 1 x speaker_embed_dim
|
||||
embedded_speakers = torch.unsqueeze(cond_input["d_vectors"], 1)
|
||||
embedded_speakers = torch.unsqueeze(aux_input["d_vectors"], 1)
|
||||
encoder_outputs = self._concat_speaker_embedding(encoder_outputs, embedded_speakers)
|
||||
|
||||
encoder_outputs = encoder_outputs * input_mask.unsqueeze(2).expand_as(encoder_outputs)
|
||||
|
@ -252,24 +252,24 @@ class Tacotron2(TacotronAbstract):
|
|||
return outputs
|
||||
|
||||
@torch.no_grad()
|
||||
def inference(self, text, cond_input=None):
|
||||
cond_input = self._format_cond_input(cond_input)
|
||||
def inference(self, text, aux_input=None):
|
||||
aux_input = self._format_aux_input(aux_input)
|
||||
embedded_inputs = self.embedding(text).transpose(1, 2)
|
||||
encoder_outputs = self.encoder.inference(embedded_inputs)
|
||||
|
||||
if self.gst and self.use_gst:
|
||||
# B x gst_dim
|
||||
encoder_outputs = self.compute_gst(encoder_outputs, cond_input["style_mel"], cond_input["d_vectors"])
|
||||
encoder_outputs = self.compute_gst(encoder_outputs, aux_input["style_mel"], aux_input["d_vectors"])
|
||||
if self.num_speakers > 1:
|
||||
if not self.use_d_vectors:
|
||||
embedded_speakers = self.speaker_embedding(cond_input["speaker_ids"])[None]
|
||||
embedded_speakers = self.speaker_embedding(aux_input["speaker_ids"])[None]
|
||||
# reshape embedded_speakers
|
||||
if embedded_speakers.ndim == 1:
|
||||
embedded_speakers = embedded_speakers[None, None, :]
|
||||
elif embedded_speakers.ndim == 2:
|
||||
embedded_speakers = embedded_speakers[None, :]
|
||||
else:
|
||||
embedded_speakers = cond_input["d_vectors"]
|
||||
embedded_speakers = aux_input["d_vectors"]
|
||||
|
||||
encoder_outputs = self._concat_speaker_embedding(encoder_outputs, embedded_speakers)
|
||||
|
||||
|
@ -307,7 +307,7 @@ class Tacotron2(TacotronAbstract):
|
|||
text_lengths,
|
||||
mel_input,
|
||||
mel_lengths,
|
||||
cond_input={"speaker_ids": speaker_ids, "d_vectors": d_vectors},
|
||||
aux_input={"speaker_ids": speaker_ids, "d_vectors": d_vectors},
|
||||
)
|
||||
|
||||
# set the [alignment] lengths wrt reduction factor for guided attention
|
||||
|
@ -318,8 +318,8 @@ class Tacotron2(TacotronAbstract):
|
|||
else:
|
||||
alignment_lengths = mel_lengths // self.decoder.r
|
||||
|
||||
cond_input = {"speaker_ids": speaker_ids, "d_vectors": d_vectors}
|
||||
outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, cond_input)
|
||||
aux_input = {"speaker_ids": speaker_ids, "d_vectors": d_vectors}
|
||||
outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, aux_input)
|
||||
|
||||
# compute loss
|
||||
loss_dict = criterion(
|
||||
|
|
|
@ -6,7 +6,7 @@ import torch
|
|||
from torch import nn
|
||||
|
||||
from TTS.tts.utils.data import sequence_mask
|
||||
from TTS.utils.generic_utils import format_cond_input
|
||||
from TTS.utils.generic_utils import format_aux_input
|
||||
from TTS.utils.training import gradual_training_scheduler
|
||||
|
||||
|
||||
|
@ -97,8 +97,8 @@ class TacotronAbstract(ABC, nn.Module):
|
|||
self.coarse_decoder = None
|
||||
|
||||
@staticmethod
|
||||
def _format_cond_input(cond_input: Dict) -> Dict:
|
||||
return format_cond_input({"d_vectors": None, "speaker_ids": None}, cond_input)
|
||||
def _format_aux_input(aux_input: Dict) -> Dict:
|
||||
return format_aux_input({"d_vectors": None, "speaker_ids": None}, aux_input)
|
||||
|
||||
#############################
|
||||
# INIT FUNCTIONS
|
||||
|
|
|
@ -92,7 +92,7 @@ def run_model_torch(
|
|||
_func = model.inference
|
||||
outputs = _func(
|
||||
inputs,
|
||||
cond_input={
|
||||
aux_input={
|
||||
"x_lengths": input_lengths,
|
||||
"speaker_ids": speaker_id,
|
||||
"d_vectors": d_vector,
|
||||
|
|
|
@ -136,7 +136,7 @@ def set_init_dict(model_dict, checkpoint_state, c):
|
|||
return model_dict
|
||||
|
||||
|
||||
def format_cond_input(def_args: Dict, kwargs: Dict) -> Dict:
|
||||
def format_aux_input(def_args: Dict, kwargs: Dict) -> Dict:
|
||||
"""Format kwargs to hande auxilary inputs to models.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -16,8 +16,8 @@ tail -n 812 $RUN_DIR/$CORPUS/metadata_shuf.csv > $RUN_DIR/$CORPUS/metadata_val.c
|
|||
python TTS/bin/compute_statistics.py $RUN_DIR/tacotron2-DDC.json $RUN_DIR/scale_stats.npy --data_path $RUN_DIR/$CORPUS/wavs/
|
||||
# training ....
|
||||
# change the GPU id if needed
|
||||
CUDA_VISIBLE_DEVICES="0" python TTS/bin/train_tacotron.py --config_path $RUN_DIR/tacotron2-DDC.json \
|
||||
--coqpit.output_path $RUN_DIR \
|
||||
--coqpit.datasets.0.path $RUN_DIR/$CORPUS \
|
||||
--coqpit.audio.stats_path $RUN_DIR/scale_stats.npy \
|
||||
--coqpit.phoneme_cache_path $RUN_DIR/phoneme_cache \
|
||||
CUDA_VISIBLE_DEVICES="0" python TTS/bin/train_tts.py --config_path $RUN_DIR/tacotron2-DDC.json \
|
||||
--coqpit.output_path $RUN_DIR \
|
||||
--coqpit.datasets.0.path $RUN_DIR/$CORPUS \
|
||||
--coqpit.audio.stats_path $RUN_DIR/scale_stats.npy \
|
||||
--coqpit.phoneme_cache_path $RUN_DIR/phoneme_cache \
|
|
@ -16,7 +16,7 @@ rm LJSpeech-1.1.tar.bz2
|
|||
python TTS/bin/compute_statistics.py $RUN_DIR/tacotron2-DDC.json $RUN_DIR/scale_stats.npy --data_path $RUN_DIR/LJSpeech-1.1/wavs/
|
||||
# training ....
|
||||
# change the GPU id if needed
|
||||
CUDA_VISIBLE_DEVICES="0" python TTS/bin/train_tacotron.py --config_path $RUN_DIR/tacotron2-DDC.json \
|
||||
--coqpit.output_path $RUN_DIR \
|
||||
--coqpit.datasets.0.path $RUN_DIR/LJSpeech-1.1/ \
|
||||
--coqpit.audio.stats_path $RUN_DIR/scale_stats.npy \
|
||||
CUDA_VISIBLE_DEVICES="0" python TTS/bin/train_tts.py --config_path $RUN_DIR/tacotron2-DDC.json \
|
||||
--coqpit.output_path $RUN_DIR \
|
||||
--coqpit.datasets.0.path $RUN_DIR/LJSpeech-1.1/ \
|
||||
--coqpit.audio.stats_path $RUN_DIR/scale_stats.npy \
|
|
@ -57,7 +57,7 @@ def test_speedy_speech():
|
|||
# with speaker embedding
|
||||
model = SpeedySpeech(num_chars, out_channels=80, hidden_channels=128, num_speakers=10, c_in_channels=256).to(device)
|
||||
model.forward(
|
||||
x_dummy, x_lengths, y_lengths, durations, cond_input={"d_vectors": torch.randint(0, 10, (B,)).to(device)}
|
||||
x_dummy, x_lengths, y_lengths, durations, aux_input={"d_vectors": torch.randint(0, 10, (B,)).to(device)}
|
||||
)
|
||||
o_de = outputs["model_outputs"]
|
||||
attn = outputs["alignments"]
|
||||
|
@ -71,7 +71,7 @@ def test_speedy_speech():
|
|||
model = SpeedySpeech(
|
||||
num_chars, out_channels=80, hidden_channels=128, num_speakers=10, external_c=True, c_in_channels=256
|
||||
).to(device)
|
||||
model.forward(x_dummy, x_lengths, y_lengths, durations, cond_input={"d_vectors": torch.rand((B, 256)).to(device)})
|
||||
model.forward(x_dummy, x_lengths, y_lengths, durations, aux_input={"d_vectors": torch.rand((B, 256)).to(device)})
|
||||
o_de = outputs["model_outputs"]
|
||||
attn = outputs["alignments"]
|
||||
o_dr = outputs["durations_log"]
|
||||
|
|
|
@ -53,7 +53,7 @@ class TacotronTrainTest(unittest.TestCase):
|
|||
optimizer = optim.Adam(model.parameters(), lr=c.lr)
|
||||
for i in range(5):
|
||||
outputs = model.forward(
|
||||
input_dummy, input_lengths, mel_spec, mel_lengths, cond_input={"speaker_ids": speaker_ids}
|
||||
input_dummy, input_lengths, mel_spec, mel_lengths, aux_input={"speaker_ids": speaker_ids}
|
||||
)
|
||||
assert torch.sigmoid(outputs["stop_tokens"]).data.max() <= 1.0
|
||||
assert torch.sigmoid(outputs["stop_tokens"]).data.min() >= 0.0
|
||||
|
@ -105,7 +105,7 @@ class MultiSpeakeTacotronTrainTest(unittest.TestCase):
|
|||
optimizer = optim.Adam(model.parameters(), lr=c.lr)
|
||||
for i in range(5):
|
||||
outputs = model.forward(
|
||||
input_dummy, input_lengths, mel_spec, mel_lengths, cond_input={"d_vectors": speaker_ids}
|
||||
input_dummy, input_lengths, mel_spec, mel_lengths, aux_input={"d_vectors": speaker_ids}
|
||||
)
|
||||
assert torch.sigmoid(outputs["stop_tokens"]).data.max() <= 1.0
|
||||
assert torch.sigmoid(outputs["stop_tokens"]).data.min() >= 0.0
|
||||
|
@ -158,7 +158,7 @@ class TacotronGSTTrainTest(unittest.TestCase):
|
|||
optimizer = optim.Adam(model.parameters(), lr=c.lr)
|
||||
for i in range(10):
|
||||
outputs = model.forward(
|
||||
input_dummy, input_lengths, mel_spec, mel_lengths, cond_input={"speaker_ids": speaker_ids}
|
||||
input_dummy, input_lengths, mel_spec, mel_lengths, aux_input={"speaker_ids": speaker_ids}
|
||||
)
|
||||
assert torch.sigmoid(outputs["stop_tokens"]).data.max() <= 1.0
|
||||
assert torch.sigmoid(outputs["stop_tokens"]).data.min() >= 0.0
|
||||
|
@ -214,7 +214,7 @@ class TacotronGSTTrainTest(unittest.TestCase):
|
|||
optimizer = optim.Adam(model.parameters(), lr=c.lr)
|
||||
for i in range(10):
|
||||
outputs = model.forward(
|
||||
input_dummy, input_lengths, mel_spec, mel_lengths, cond_input={"speaker_ids": speaker_ids}
|
||||
input_dummy, input_lengths, mel_spec, mel_lengths, aux_input={"speaker_ids": speaker_ids}
|
||||
)
|
||||
assert torch.sigmoid(outputs["stop_tokens"]).data.max() <= 1.0
|
||||
assert torch.sigmoid(outputs["stop_tokens"]).data.min() >= 0.0
|
||||
|
@ -269,7 +269,7 @@ class SCGSTMultiSpeakeTacotronTrainTest(unittest.TestCase):
|
|||
optimizer = optim.Adam(model.parameters(), lr=c.lr)
|
||||
for i in range(5):
|
||||
outputs = model.forward(
|
||||
input_dummy, input_lengths, mel_spec, mel_lengths, cond_input={"d_vectors": speaker_embeddings}
|
||||
input_dummy, input_lengths, mel_spec, mel_lengths, aux_input={"d_vectors": speaker_embeddings}
|
||||
)
|
||||
assert torch.sigmoid(outputs["stop_tokens"]).data.max() <= 1.0
|
||||
assert torch.sigmoid(outputs["stop_tokens"]).data.min() >= 0.0
|
||||
|
|
|
@ -69,7 +69,7 @@ class TacotronTrainTest(unittest.TestCase):
|
|||
optimizer = optim.Adam(model.parameters(), lr=c.lr)
|
||||
for _ in range(5):
|
||||
outputs = model.forward(
|
||||
input_dummy, input_lengths, mel_spec, mel_lengths, cond_input={"speaker_ids": speaker_ids}
|
||||
input_dummy, input_lengths, mel_spec, mel_lengths, aux_input={"speaker_ids": speaker_ids}
|
||||
)
|
||||
optimizer.zero_grad()
|
||||
loss = criterion(outputs["decoder_outputs"], mel_spec, mel_lengths)
|
||||
|
@ -130,7 +130,7 @@ class MultiSpeakeTacotronTrainTest(unittest.TestCase):
|
|||
optimizer = optim.Adam(model.parameters(), lr=c.lr)
|
||||
for _ in range(5):
|
||||
outputs = model.forward(
|
||||
input_dummy, input_lengths, mel_spec, mel_lengths, cond_input={"d_vectors": speaker_embeddings}
|
||||
input_dummy, input_lengths, mel_spec, mel_lengths, aux_input={"d_vectors": speaker_embeddings}
|
||||
)
|
||||
optimizer.zero_grad()
|
||||
loss = criterion(outputs["decoder_outputs"], mel_spec, mel_lengths)
|
||||
|
@ -194,7 +194,7 @@ class TacotronGSTTrainTest(unittest.TestCase):
|
|||
optimizer = optim.Adam(model.parameters(), lr=c.lr)
|
||||
for _ in range(10):
|
||||
outputs = model.forward(
|
||||
input_dummy, input_lengths, mel_spec, mel_lengths, cond_input={"speaker_ids": speaker_ids}
|
||||
input_dummy, input_lengths, mel_spec, mel_lengths, aux_input={"speaker_ids": speaker_ids}
|
||||
)
|
||||
optimizer.zero_grad()
|
||||
loss = criterion(outputs["decoder_outputs"], mel_spec, mel_lengths)
|
||||
|
@ -257,7 +257,7 @@ class TacotronGSTTrainTest(unittest.TestCase):
|
|||
optimizer = optim.Adam(model.parameters(), lr=c.lr)
|
||||
for _ in range(10):
|
||||
outputs = model.forward(
|
||||
input_dummy, input_lengths, mel_spec, mel_lengths, cond_input={"speaker_ids": speaker_ids}
|
||||
input_dummy, input_lengths, mel_spec, mel_lengths, aux_input={"speaker_ids": speaker_ids}
|
||||
)
|
||||
optimizer.zero_grad()
|
||||
loss = criterion(outputs["decoder_outputs"], mel_spec, mel_lengths)
|
||||
|
@ -319,7 +319,7 @@ class SCGSTMultiSpeakeTacotronTrainTest(unittest.TestCase):
|
|||
optimizer = optim.Adam(model.parameters(), lr=c.lr)
|
||||
for _ in range(5):
|
||||
outputs = model.forward(
|
||||
input_dummy, input_lengths, mel_spec, mel_lengths, cond_input={"d_vectors": speaker_embeddings}
|
||||
input_dummy, input_lengths, mel_spec, mel_lengths, aux_input={"d_vectors": speaker_embeddings}
|
||||
)
|
||||
optimizer.zero_grad()
|
||||
loss = criterion(outputs["decoder_outputs"], mel_spec, mel_lengths)
|
||||
|
|
Loading…
Reference in New Issue