mirror of https://github.com/coqui-ai/TTS.git

Add Pitch Predictor conditioned on enc

parent 8f6c187848
commit 569decba64
@@ -69,6 +69,9 @@ def extract_aligments(
    for idx in range(tokens.shape[0]):
        wav_file_path = item_idx[idx]
        alignment = alignments[idx]
        spec_length = spec_lens[idx]
        token_length = token_lenghts[idx]
        alignment = alignment[:token_length, :spec_length]
        # set paths
        align_file_name = os.path.splitext(os.path.basename(wav_file_path))[0] + ".npy"
        os.makedirs(os.path.join(output_path, "alignments"), exist_ok=True)
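The cropped alignment is cached as an .npy file named after the wav basename under the alignments folder. A minimal sketch of writing and reading one cache entry, with made-up sizes and paths (nothing here is from the commit itself):

    import os
    import numpy as np

    output_path = "cache"                        # hypothetical cache root
    wav_file_path = "wavs/LJ001-0001.wav"
    alignment = np.eye(5, 12, dtype=np.float32)  # stand-in [n_tokens, n_frames] alignment

    align_file_name = os.path.splitext(os.path.basename(wav_file_path))[0] + ".npy"
    os.makedirs(os.path.join(output_path, "alignments"), exist_ok=True)
    np.save(os.path.join(output_path, "alignments", align_file_name), alignment)

    # the dataset later reloads the entry by the same basename
    loaded = np.load(os.path.join(output_path, "alignments", align_file_name))
    assert loaded.shape == (5, 12)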
@@ -40,6 +40,7 @@ class TextEncoder(nn.Module):
        language_emb_dim: int = None,
        emotion_emb_dim: int = None,
        prosody_emb_dim: int = None,
        pitch_dim: int = None,
    ):
        """Text Encoder for VITS model.

@@ -70,6 +71,9 @@ class TextEncoder(nn.Module):
        if prosody_emb_dim:
            hidden_channels += prosody_emb_dim

        if pitch_dim:
            hidden_channels += pitch_dim

        self.encoder = RelativePositionTransformer(
            in_channels=hidden_channels,
            out_channels=hidden_channels,

@@ -85,7 +89,7 @@ class TextEncoder(nn.Module):

        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, x, x_lengths, lang_emb=None, emo_emb=None, pros_emb=None):
    def forward(self, x, x_lengths, lang_emb=None, emo_emb=None, pros_emb=None, pitch_emb=None):
        """
        Shapes:
            - x: :math:`[B, T]`
@@ -105,6 +109,9 @@ class TextEncoder(nn.Module):
        if pros_emb is not None:
            x = torch.cat((x, pros_emb.transpose(2, 1).expand(x.size(0), x.size(1), -1)), dim=-1)

        if pitch_emb is not None:
            x = torch.cat((x, pitch_emb.transpose(2, 1)), dim=-1)

        x = torch.transpose(x, 1, -1)  # [b, h, t]
        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)  # [b, 1, t]
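Concatenating the pitch embedding widens the channel dimension the RelativePositionTransformer receives, which is why hidden_channels is increased by pitch_dim in __init__ above. A minimal shape check of that concatenation, with illustrative sizes (not from the commit):

    import torch

    B, T, hidden, pitch_dim = 2, 7, 192, 2
    x = torch.randn(B, T, hidden)             # token embeddings [B, T, C]
    pitch_emb = torch.randn(B, pitch_dim, T)  # pitch embedding  [B, pitch_dim, T]

    x = torch.cat((x, pitch_emb.transpose(2, 1)), dim=-1)
    print(x.shape)  # torch.Size([2, 7, 194]) -> matches hidden_channels + pitch_dim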
@@ -265,7 +265,9 @@ class VitsDataset(TTSDataset):
        self.pad_id = self.tokenizer.characters.pad_id
        self.model_args = model_args
        self.compute_pitch = compute_pitch

        self.use_precomputed_alignments = model_args.use_precomputed_alignments
        self.alignments_cache_path = model_args.alignments_cache_path

        if self.compute_pitch:
            self.f0_dataset = VITSF0Dataset(config,
                samples=self.samples, ap=self.ap, cache_path=self.f0_cache_path, precompute_num_workers=self.precompute_num_workers

@@ -289,6 +291,11 @@ class VitsDataset(TTSDataset):
        if self.compute_pitch:
            f0 = self.get_f0(idx)["f0"]

        alignments = None
        if self.use_precomputed_alignments:
            align_file = os.path.join(self.alignments_cache_path, os.path.splitext(wav_filename)[0] + ".npy")
            alignments = self.get_attn_mask(align_file)

        # after phonemization the text length may change
        # this is a shameful 🤭 hack to prevent longer phonemes
        # TODO: find a better fix

@@ -305,6 +312,8 @@ class VitsDataset(TTSDataset):
            "speaker_name": item["speaker_name"],
            "language_name": item["language"],
            "pitch": f0,
            "alignments": alignments,
        }

    @property
@@ -365,6 +374,18 @@ class VitsDataset(TTSDataset):
            pitch = torch.FloatTensor(pitch)[:, None, :].contiguous()  # B x 1 xT
        else:
            pitch = None

        padded_alignments = None
        if self.use_precomputed_alignments:
            alignments = batch["alignments"]
            max_len_1 = max((x.shape[0] for x in alignments))
            max_len_2 = max((x.shape[1] for x in alignments))
            padded_alignments = []
            for x in alignments:
                padded_alignment = np.pad(x, ((0, max_len_1 - x.shape[0]), (0, max_len_2 - x.shape[1])), mode="constant", constant_values=0)
                padded_alignments.append(padded_alignment)

            padded_alignments = torch.FloatTensor(np.stack(padded_alignments)).unsqueeze(1)

        return {
            "tokens": token_padded,
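Each sample's alignment has its own [n_tokens, n_frames] size, so the collate step zero-pads to the batch maxima before stacking, and unsqueeze(1) adds a channel axis. A small standalone check of that padding logic with toy arrays (not from the commit):

    import numpy as np
    import torch

    alignments = [np.ones((3, 5), dtype=np.float32), np.ones((4, 2), dtype=np.float32)]
    max_len_1 = max(x.shape[0] for x in alignments)
    max_len_2 = max(x.shape[1] for x in alignments)

    padded = [
        np.pad(x, ((0, max_len_1 - x.shape[0]), (0, max_len_2 - x.shape[1])), mode="constant", constant_values=0)
        for x in alignments
    ]
    padded_alignments = torch.FloatTensor(np.stack(padded)).unsqueeze(1)
    print(padded_alignments.shape)  # torch.Size([2, 1, 4, 5])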
@@ -378,6 +399,7 @@ class VitsDataset(TTSDataset):
            "language_names": batch["language_name"],
            "audio_files": batch["wav_file"],
            "raw_text": batch["raw_text"],
            "alignments": padded_alignments,
        }

@@ -385,7 +407,6 @@ class VitsDataset(TTSDataset):
# MODEL DEFINITION
##############################


@dataclass
class VitsArgs(Coqpit):
    """VITS model arguments.
@@ -664,6 +685,9 @@ class VitsArgs(Coqpit):
    pitch_predictor_dropout_p: float = 0.1
    pitch_embedding_kernel_size: int = 3
    detach_pp_input: bool = False
    use_precomputed_alignments: bool = False
    alignments_cache_path: str = ""
    pitch_embedding_dim: int = 0

    detach_dp_input: bool = True
    use_language_embedding: bool = False
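These are the switches the updated test at the bottom of this diff turns on. A hedged usage sketch, assuming the VitsArgs fields added above (values are the ones used in the test, import path as in upstream Coqui TTS):

    from TTS.tts.models.vits import VitsArgs

    model_args = VitsArgs(
        use_pitch=True,
        use_pitch_on_enc_input=True,      # feed the (averaged) pitch embedding into the text encoder
        pitch_embedding_dim=2,            # extra channels concatenated to the encoder input
        use_precomputed_alignments=True,  # external MAS alignments are required for this mode
        alignments_cache_path="tests/data/ljspeech/mas_alignments/alignments/",
    )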
@@ -751,6 +775,7 @@ class Vits(BaseTTS):
            language_emb_dim=self.embedded_language_dim,
            emotion_emb_dim=self.args.emotion_embedding_dim if not self.args.use_noise_scale_predictor else 0,
            prosody_emb_dim=self.args.prosody_embedding_dim if not self.args.use_noise_scale_predictor else 0,
            pitch_dim=self.args.pitch_embedding_dim if self.args.use_pitch and self.args.use_pitch_on_enc_input else 0,
        )

        self.posterior_encoder = PosteriorEncoder(

@@ -791,6 +816,9 @@ class Vits(BaseTTS):
        if self.args.use_prosody_encoder and not self.args.use_noise_scale_predictor:
            dp_extra_inp_dim += self.args.prosody_embedding_dim

        if self.args.use_pitch and self.args.use_pitch_on_enc_input:
            dp_extra_inp_dim += self.args.pitch_embedding_dim

        if self.args.use_sdp:
            self.duration_predictor = StochasticDurationPredictor(
                self.args.hidden_channels + dp_extra_inp_dim,
@@ -814,13 +842,13 @@ class Vits(BaseTTS):
        if self.args.use_pitch:
            if self.args.use_pitch_on_enc_input:
                self.pitch_predictor_vocab_emb = nn.Embedding(self.args.num_chars, self.args.hidden_channels)
            else:
                self.pitch_emb = nn.Conv1d(
                    1,
                    self.args.hidden_channels,
                    kernel_size=self.args.pitch_predictor_kernel_size,
                    padding=int((self.args.pitch_predictor_kernel_size - 1) / 2),
                )

            self.pitch_emb = nn.Conv1d(
                1,
                self.args.hidden_channels if not self.args.use_pitch_on_enc_input else self.args.pitch_embedding_dim,
                kernel_size=self.args.pitch_predictor_kernel_size,
                padding=int((self.args.pitch_predictor_kernel_size - 1) / 2),
            )
            self.pitch_predictor = DurationPredictor(
                self.args.hidden_channels,
                self.args.pitch_predictor_hidden_channels,
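After this change the Conv1d that embeds the token-averaged F0 outputs pitch_embedding_dim channels when the pitch is fed to the encoder input (instead of hidden_channels), matching the width added to the TextEncoder. A minimal shape sketch with illustrative sizes:

    import torch
    import torch.nn as nn

    pitch_embedding_dim, kernel_size, n_tokens = 2, 3, 7
    pitch_emb = nn.Conv1d(1, pitch_embedding_dim, kernel_size=kernel_size, padding=(kernel_size - 1) // 2)

    avg_pitch = torch.randn(2, 1, n_tokens)  # [B, 1, n_tokens] token-averaged F0
    print(pitch_emb(avg_pitch).shape)        # torch.Size([2, 2, 7])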
@@ -1241,17 +1269,16 @@ class Vits(BaseTTS):
        )

        pitch_loss = None
        gt_avg_pitch = None
        pred_avg_pitch_emb = None
        gt_avg_pitch_emb = None
        if pitch is not None:
            gt_avg_pitch = average_over_durations(pitch, dr.squeeze()).detach()
            pitch_loss = torch.sum(torch.sum((gt_avg_pitch - pred_avg_pitch) ** 2, [1, 2]) / torch.sum(x_mask))
            if not self.args.use_pitch_on_enc_input:
                gt_agv_pitch = self.pitch_emb(gt_avg_pitch)
                gt_avg_pitch_emb = self.pitch_emb(gt_avg_pitch)
        else:
            if not self.args.use_pitch_on_enc_input:
                pred_avg_pitch = self.pitch_emb(pred_avg_pitch)
                pred_avg_pitch_emb = self.pitch_emb(pred_avg_pitch)

        return pitch_loss, gt_agv_pitch, pred_avg_pitch
        return pitch_loss, gt_avg_pitch_emb, pred_avg_pitch_emb

    def forward_mas(self, outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g, lang_emb):
        # find the alignment path
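average_over_durations collapses the frame-level F0 track to one value per input token using the per-token durations; the pitch loss is then a squared error between this target and the predicted per-token pitch. A self-contained sketch of the averaging idea (not the library implementation):

    import torch

    def average_over_durations_sketch(values: torch.Tensor, durs: torch.Tensor) -> torch.Tensor:
        """values: [B, 1, T_frames] frame-level pitch; durs: [B, T_tokens] integer durations."""
        out = torch.zeros(values.shape[0], 1, durs.shape[1])
        for b in range(values.shape[0]):
            start = 0
            for t in range(durs.shape[1]):
                end = start + int(durs[b, t])
                if end > start:
                    out[b, 0, t] = values[b, 0, start:end].mean()
                start = end
        return out

    pitch = torch.arange(10.0).view(1, 1, 10)  # toy frame-level F0
    durs = torch.tensor([[3, 2, 5]])           # 3 tokens covering the 10 frames
    print(average_over_durations_sketch(pitch, durs))  # tensor([[[1.0000, 3.5000, 7.0000]]])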
@@ -1313,6 +1340,7 @@ class Vits(BaseTTS):
        y_lengths: torch.tensor,
        waveform: torch.tensor,
        pitch: torch.tensor,
        alignments: torch.tensor,
        aux_input={
            "d_vectors": None,
            "speaker_ids": None,
@@ -1389,6 +1417,21 @@ class Vits(BaseTTS):
        if self.args.use_speaker_embedding or self.args.use_d_vector_file:
            g = F.normalize(self.speaker_embedding_squeezer(g.squeeze(-1))).unsqueeze(-1)

        # duration predictor
        g_dp = g if self.args.condition_dp_on_speaker else None
        if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings):
            if g_dp is None:
                g_dp = eg
            else:
                g_dp = torch.cat([g_dp, eg], dim=1)  # [b, h1+h2, 1]

        pitch_loss = None
        gt_avg_pitch_emb = None
        if self.args.use_pitch and self.args.use_pitch_on_enc_input:
            if alignments is None:
                raise RuntimeError(" [!] For condition the pitch on the Text Encoder you need to provide external alignments !")
            pitch_loss, gt_avg_pitch_emb, _ = self.forward_pitch_predictor(x, x_lengths, pitch, alignments.sum(3), g_dp)

        # posterior encoder
        z, m_q, logs_q, y_mask = self.posterior_encoder(y, y_lengths, g=g)
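alignments.sum(3) reduces the padded [B, 1, n_tokens, n_frames] hard alignment to per-token durations, which is what forward_pitch_predictor needs to average the frame-level pitch. A toy check of that reduction (values are illustrative):

    import torch

    # hypothetical hard alignment: token 0 -> 2 frames, token 1 -> 1 frame, token 2 -> 3 frames
    align = torch.tensor([[[[1, 1, 0, 0, 0, 0],
                            [0, 0, 1, 0, 0, 0],
                            [0, 0, 0, 1, 1, 1]]]], dtype=torch.float32)  # [1, 1, 3, 6]
    print(align.sum(3))  # tensor([[[2., 1., 3.]]]) -> per-token durations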
@@ -1418,13 +1461,14 @@ class Vits(BaseTTS):
            _, l_pros_speaker = self.speaker_reversal_classifier(pros_emb.transpose(1, 2), sid, x_mask=None)
            if self.args.use_prosody_enc_emo_classifier:
                _, l_pros_emotion = self.pros_enc_emotion_classifier(pros_emb.transpose(1, 2), eid, x_mask=None)
        x_input = x

        x, m_p, logs_p, x_mask = self.text_encoder(
            x,
            x_lengths,
            lang_emb=lang_emb,
            emo_emb=eg if not self.args.use_noise_scale_predictor else None,
            pros_emb=pros_emb if not self.args.use_noise_scale_predictor else None,
            pitch_emb=gt_avg_pitch_emb if self.args.use_pitch and self.args.use_pitch_on_enc_input else None,
        )

        # reversal speaker loss to force the encoder to be speaker identity free

@@ -1437,14 +1481,6 @@ class Vits(BaseTTS):
        if self.args.use_text_enc_emo_classifier:
            _, l_text_emotion = self.emo_text_enc_classifier(m_p.transpose(1, 2), eid, x_mask=x_mask)

        # duration predictor
        g_dp = g if self.args.condition_dp_on_speaker else None
        if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings):
            if g_dp is None:
                g_dp = eg
            else:
                g_dp = torch.cat([g_dp, eg], dim=1)  # [b, h1+h2, 1]

        if self.args.use_prosody_encoder:
            if g_dp is None:
                g_dp = pros_emb
@@ -1453,7 +1489,6 @@ class Vits(BaseTTS):

        outputs, attn = self.forward_mas(outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g=g_dp, lang_emb=lang_emb)

        pitch_loss = None
        if self.args.use_pitch and not self.args.use_pitch_on_enc_input:
            pitch_loss, gt_avg_pitch_emb, _ = self.forward_pitch_predictor(m_p, x_lengths, pitch, attn.sum(3), g_dp)
            m_p = m_p + gt_avg_pitch_emb
@@ -1631,14 +1666,6 @@ class Vits(BaseTTS):

            pros_emb = pros_emb.transpose(1, 2)

        x, m_p, logs_p, x_mask = self.text_encoder(
            x,
            x_lengths,
            lang_emb=lang_emb,
            emo_emb=eg if not self.args.use_noise_scale_predictor else None,
            pros_emb=pros_emb if not self.args.use_noise_scale_predictor else None,
        )

        # duration predictor
        g_dp = g if self.args.condition_dp_on_speaker else None
        if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings):
@@ -1647,6 +1674,19 @@ class Vits(BaseTTS):
            else:
                g_dp = torch.cat([g_dp, eg], dim=1)  # [b, h1+h2, 1]

        pred_avg_pitch_emb = None
        if self.args.use_pitch and self.args.use_pitch_on_enc_input:
            _, _, pred_avg_pitch_emb = self.forward_pitch_predictor(x, x_lengths, g_pp=g_dp)

        x, m_p, logs_p, x_mask = self.text_encoder(
            x,
            x_lengths,
            lang_emb=lang_emb,
            emo_emb=eg if not self.args.use_noise_scale_predictor else None,
            pros_emb=pros_emb if not self.args.use_noise_scale_predictor else None,
            pitch_emb=pred_avg_pitch_emb if self.args.use_pitch and self.args.use_pitch_on_enc_input else None,
        )

        if self.args.use_prosody_encoder:
            if g_dp is None:
                g_dp = pros_emb
@@ -1673,7 +1713,6 @@ class Vits(BaseTTS):
        attn_mask = x_mask * y_mask.transpose(1, 2)  # [B, 1, T_enc] * [B, T_dec, 1]
        attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1).transpose(1, 2))

        pred_avg_pitch_emb = None
        if self.args.use_pitch and not self.args.use_pitch_on_enc_input:
            _, _, pred_avg_pitch_emb = self.forward_pitch_predictor(m_p, x_lengths, g_pp=g_dp)
            m_p = m_p + pred_avg_pitch_emb
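When the pitch is not fed to the encoder input, the predicted per-token pitch embedding is instead added residually to the prior mean m_p; both are [B, hidden_channels, n_tokens], so a plain elementwise add suffices. A toy shape check with illustrative sizes:

    import torch

    B, C, T = 2, 192, 7
    m_p = torch.randn(B, C, T)                 # prior means from the text encoder
    pred_avg_pitch_emb = torch.randn(B, C, T)  # pitch_emb Conv1d output (hidden_channels wide in this mode)
    m_p = m_p + pred_avg_pitch_emb
    print(m_p.shape)  # torch.Size([2, 192, 7])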
@@ -1819,6 +1858,7 @@ class Vits(BaseTTS):
        emotion_ids = batch["emotion_ids"]
        waveform = batch["waveform"]
        pitch = batch["pitch"]
        alignments = batch["alignments"]

        # generator pass
        outputs = self.forward(

@@ -1828,6 +1868,7 @@ class Vits(BaseTTS):
            spec_lens,
            waveform,
            pitch,
            alignments,
            aux_input={
                "d_vectors": d_vectors,
                "speaker_ids": speaker_ids,
Binary files not shown (8 files).
@@ -28,7 +28,7 @@ config = VitsConfig(
    compute_pitch=True,
    f0_cache_path="tests/data/ljspeech/f0_cache/",
    test_sentences=[
        ["Be a voice, not an echo.", "ljspeech-1", "tests/data/ljspeech/wavs/LJ001-0001.wav", None, None],
        ["Be a voice, not an echo.", "ljspeech-1", None, None, None],
    ],
)
# set audio config
@@ -42,17 +42,18 @@ config.model_args.d_vector_file = "tests/data/ljspeech/speakers.json"
config.model_args.speaker_embedding_channels = 128
config.model_args.d_vector_dim = 128

# prosody embedding
config.model_args.use_prosody_encoder = True
config.model_args.prosody_embedding_dim = 64

config.model_args.use_precomputed_alignments = True
config.model_args.alignments_cache_path = "tests/data/ljspeech/mas_alignments/alignments/"

# pitch predictor
config.model_args.use_pitch = True
config.model_args.use_pitch_on_enc_input = True
config.model_args.pitch_embedding_dim = 2
config.model_args.condition_dp_on_speaker = True


config.save_json(config_path)

# train the model for one epoch
command_train = (
    f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
@@ -74,11 +75,9 @@ continue_config_path = os.path.join(continue_path, "config.json")
continue_restore_path, _ = get_last_checkpoint(continue_path)
out_wav_path = os.path.join(get_tests_output_path(), "output.wav")
speaker_id = "ljspeech-1"
style_wav_path = "tests/data/ljspeech/wavs/LJ001-0001.wav"
continue_speakers_path = os.path.join(continue_path, "speakers.json")


inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --speaker_idx {speaker_id} --speakers_file_path {continue_speakers_path} --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path} --gst_style {style_wav_path}"
inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --speaker_idx {speaker_id} --speakers_file_path {continue_speakers_path} --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path} "
run_cli(inference_command)

# restore the model and continue training for one more epoch