Add Pitch Predictor conditioned on enc

This commit is contained in:
Edresson Casanova 2022-06-17 17:39:17 -03:00
parent 8f6c187848
commit 569decba64
12 changed files with 93 additions and 43 deletions

View File

@@ -69,6 +69,9 @@ def extract_aligments(
     for idx in range(tokens.shape[0]):
         wav_file_path = item_idx[idx]
         alignment = alignments[idx]
+        spec_length = spec_lens[idx]
+        token_length = token_lenghts[idx]
+        alignment = alignment[:token_length, :spec_length]
         # set paths
         align_file_name = os.path.splitext(os.path.basename(wav_file_path))[0] + ".npy"
         os.makedirs(os.path.join(output_path, "alignments"), exist_ok=True)
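For context, the hunk above crops each alignment matrix to its true token/spectrogram lengths and caches it as a per-utterance `.npy` file keyed by the wav basename. A minimal standalone sketch of that step (the helper name and synthetic shapes are illustrative, not part of this commit):

```python
# Sketch only: crop a padded alignment matrix to its real lengths and cache it,
# keyed by the wav file's basename, mirroring the logic added above.
import os
import numpy as np

def save_cropped_alignment(alignment, token_length, spec_length, wav_file_path, output_path):
    alignment = alignment[:token_length, :spec_length]  # [token_len, spec_len]
    align_file_name = os.path.splitext(os.path.basename(wav_file_path))[0] + ".npy"
    align_dir = os.path.join(output_path, "alignments")
    os.makedirs(align_dir, exist_ok=True)
    np.save(os.path.join(align_dir, align_file_name), alignment)
    return os.path.join(align_dir, align_file_name)

# e.g. save_cropped_alignment(np.zeros((10, 50)), 8, 42, "LJ001-0001.wav", "cache/")
```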

View File

@@ -40,6 +40,7 @@ class TextEncoder(nn.Module):
         language_emb_dim: int = None,
         emotion_emb_dim: int = None,
         prosody_emb_dim: int = None,
+        pitch_dim: int = None,
     ):
         """Text Encoder for VITS model.

@@ -70,6 +71,9 @@ class TextEncoder(nn.Module):
         if prosody_emb_dim:
             hidden_channels += prosody_emb_dim

+        if pitch_dim:
+            hidden_channels += pitch_dim
+
         self.encoder = RelativePositionTransformer(
             in_channels=hidden_channels,
             out_channels=hidden_channels,

@@ -85,7 +89,7 @@ class TextEncoder(nn.Module):
         self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

-    def forward(self, x, x_lengths, lang_emb=None, emo_emb=None, pros_emb=None):
+    def forward(self, x, x_lengths, lang_emb=None, emo_emb=None, pros_emb=None, pitch_emb=None):
         """
         Shapes:
             - x: :math:`[B, T]`

@@ -105,6 +109,9 @@ class TextEncoder(nn.Module):
         if pros_emb is not None:
             x = torch.cat((x, pros_emb.transpose(2, 1).expand(x.size(0), x.size(1), -1)), dim=-1)

+        if pitch_emb is not None:
+            x = torch.cat((x, pitch_emb.transpose(2, 1)), dim=-1)
+
         x = torch.transpose(x, 1, -1)  # [b, h, t]
         x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)  # [b, 1, t]
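The two changes above work together: `__init__` widens `hidden_channels` by `pitch_dim`, and `forward` concatenates the pitch embedding to the token features along the last dimension. A small shape check (illustrative values, not the model's real modules):

```python
# Shape sketch for the concatenation added above (values are illustrative).
import torch

B, T, H, P = 2, 13, 192, 4          # batch, tokens, hidden_channels, pitch_dim
x = torch.randn(B, T, H)            # token features [b, t, h]
pitch_emb = torch.randn(B, P, T)    # pitch embedding [b, pitch_dim, t], channel-first like a Conv1d output

x = torch.cat((x, pitch_emb.transpose(2, 1)), dim=-1)  # -> [b, t, h + pitch_dim]
assert x.shape == (B, T, H + P)
# The encoder is built with in_channels = hidden_channels + pitch_dim, so the widened
# tensor matches the RelativePositionTransformer input.
```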

View File

@@ -265,6 +265,8 @@ class VitsDataset(TTSDataset):
         self.pad_id = self.tokenizer.characters.pad_id
         self.model_args = model_args
         self.compute_pitch = compute_pitch
+        self.use_precomputed_alignments = model_args.use_precomputed_alignments
+        self.alignments_cache_path = model_args.alignments_cache_path

         if self.compute_pitch:
             self.f0_dataset = VITSF0Dataset(config,

@@ -289,6 +291,11 @@ class VitsDataset(TTSDataset):
         if self.compute_pitch:
             f0 = self.get_f0(idx)["f0"]

+        alignments = None
+        if self.use_precomputed_alignments:
+            align_file = os.path.join(self.alignments_cache_path, os.path.splitext(wav_filename)[0] + ".npy")
+            alignments = self.get_attn_mask(align_file)
+
         # after phonemization the text length may change
         # this is a shameful 🤭 hack to prevent longer phonemes
         # TODO: find a better fix
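The lookup convention is the same one used by the extraction script: one `.npy` file per utterance, named after the wav basename, inside `alignments_cache_path`. A sketch of that lookup as a standalone helper (hypothetical function, not part of the diff):

```python
# Hypothetical helper illustrating the cache lookup used above.
import os
import numpy as np
import torch

def load_precomputed_alignment(alignments_cache_path, wav_filename):
    align_file = os.path.join(alignments_cache_path, os.path.splitext(wav_filename)[0] + ".npy")
    alignment = np.load(align_file)               # [token_len, spec_len]
    return torch.from_numpy(alignment).float()
```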
@@ -305,6 +312,8 @@ class VitsDataset(TTSDataset):
             "speaker_name": item["speaker_name"],
             "language_name": item["language"],
             "pitch": f0,
+            "alignments": alignments,
         }

     @property
@@ -366,6 +375,18 @@ class VitsDataset(TTSDataset):
         else:
             pitch = None

+        padded_alignments = None
+        if self.use_precomputed_alignments:
+            alignments = batch["alignments"]
+            max_len_1 = max((x.shape[0] for x in alignments))
+            max_len_2 = max((x.shape[1] for x in alignments))
+            padded_alignments = []
+            for x in alignments:
+                padded_alignment = np.pad(x, ((0, max_len_1 - x.shape[0]), (0, max_len_2 - x.shape[1])), mode="constant", constant_values=0)
+                padded_alignments.append(padded_alignment)
+
+            padded_alignments = torch.FloatTensor(np.stack(padded_alignments)).unsqueeze(1)
+
         return {
             "tokens": token_padded,
             "token_lens": token_lens,

@@ -378,6 +399,7 @@ class VitsDataset(TTSDataset):
             "language_names": batch["language_name"],
             "audio_files": batch["wav_file"],
             "raw_text": batch["raw_text"],
+            "alignments": padded_alignments,
         }
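Because each cached alignment has its own `[token_len, spec_len]` shape, the collate step zero-pads all matrices to the batch maxima and stacks them into a single `[B, 1, max_token_len, max_spec_len]` tensor. A self-contained sketch of that padding logic:

```python
# Sketch of the alignment collation added above (standalone, synthetic inputs).
import numpy as np
import torch

def collate_alignments(alignments):
    max_len_1 = max(x.shape[0] for x in alignments)   # max token length
    max_len_2 = max(x.shape[1] for x in alignments)   # max spectrogram length
    padded = [
        np.pad(x, ((0, max_len_1 - x.shape[0]), (0, max_len_2 - x.shape[1])),
               mode="constant", constant_values=0)
        for x in alignments
    ]
    return torch.FloatTensor(np.stack(padded)).unsqueeze(1)  # [B, 1, max_tok, max_spec]

# collate_alignments([np.ones((8, 40)), np.ones((11, 55))]).shape -> torch.Size([2, 1, 11, 55])
```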
@@ -385,7 +407,6 @@ class VitsDataset(TTSDataset):
 # MODEL DEFINITION
 ##############################

-
 @dataclass
 class VitsArgs(Coqpit):
     """VITS model arguments.

@@ -664,6 +685,9 @@ class VitsArgs(Coqpit):
     pitch_predictor_dropout_p: float = 0.1
     pitch_embedding_kernel_size: int = 3
     detach_pp_input: bool = False
+    use_precomputed_alignments: bool = False
+    alignments_cache_path: str = ""
+    pitch_embedding_dim: int = 0
     detach_dp_input: bool = True
     use_language_embedding: bool = False
@@ -751,6 +775,7 @@ class Vits(BaseTTS):
             language_emb_dim=self.embedded_language_dim,
             emotion_emb_dim=self.args.emotion_embedding_dim if not self.args.use_noise_scale_predictor else 0,
             prosody_emb_dim=self.args.prosody_embedding_dim if not self.args.use_noise_scale_predictor else 0,
+            pitch_dim=self.args.pitch_embedding_dim if self.args.use_pitch and self.args.use_pitch_on_enc_input else 0,
         )

         self.posterior_encoder = PosteriorEncoder(
@@ -791,6 +816,9 @@ class Vits(BaseTTS):
         if self.args.use_prosody_encoder and not self.args.use_noise_scale_predictor:
             dp_extra_inp_dim += self.args.prosody_embedding_dim

+        if self.args.use_pitch and self.args.use_pitch_on_enc_input:
+            dp_extra_inp_dim += self.args.pitch_embedding_dim
+
         if self.args.use_sdp:
             self.duration_predictor = StochasticDurationPredictor(
                 self.args.hidden_channels + dp_extra_inp_dim,
@@ -814,10 +842,10 @@ class Vits(BaseTTS):
         if self.args.use_pitch:
             if self.args.use_pitch_on_enc_input:
                 self.pitch_predictor_vocab_emb = nn.Embedding(self.args.num_chars, self.args.hidden_channels)
-            else:
-                self.pitch_emb = nn.Conv1d(
-                    1,
-                    self.args.hidden_channels,
-                    kernel_size=self.args.pitch_predictor_kernel_size,
-                    padding=int((self.args.pitch_predictor_kernel_size - 1) / 2),
-                )
+
+            self.pitch_emb = nn.Conv1d(
+                1,
+                self.args.hidden_channels if not self.args.use_pitch_on_enc_input else self.args.pitch_embedding_dim,
+                kernel_size=self.args.pitch_predictor_kernel_size,
+                padding=int((self.args.pitch_predictor_kernel_size - 1) / 2),
+            )
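With `use_pitch_on_enc_input`, the averaged pitch track is projected to the small `pitch_embedding_dim` that gets concatenated to the encoder input; otherwise it is projected to `hidden_channels` as before. A quick sketch of the projection (example dimensions only):

```python
# Sketch: Conv1d projection of a 1-channel average-pitch track, as configured above.
import torch
from torch import nn

pitch_embedding_dim = 2
kernel_size = 3
pitch_emb = nn.Conv1d(1, pitch_embedding_dim, kernel_size=kernel_size, padding=(kernel_size - 1) // 2)

avg_pitch = torch.randn(2, 1, 13)   # [B, 1, T_tokens] average pitch per token
emb = pitch_emb(avg_pitch)          # [B, pitch_embedding_dim, T_tokens]
assert emb.shape == (2, pitch_embedding_dim, 13)
```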
@@ -1241,17 +1269,16 @@ class Vits(BaseTTS):
         )

         pitch_loss = None
-        gt_avg_pitch = None
+        pred_avg_pitch_emb = None
+        gt_avg_pitch_emb = None
         if pitch is not None:
             gt_avg_pitch = average_over_durations(pitch, dr.squeeze()).detach()
             pitch_loss = torch.sum(torch.sum((gt_avg_pitch - pred_avg_pitch) ** 2, [1, 2]) / torch.sum(x_mask))
-            if not self.args.use_pitch_on_enc_input:
-                gt_agv_pitch = self.pitch_emb(gt_avg_pitch)
+            gt_avg_pitch_emb = self.pitch_emb(gt_avg_pitch)
         else:
-            if not self.args.use_pitch_on_enc_input:
-                pred_avg_pitch = self.pitch_emb(pred_avg_pitch)
+            pred_avg_pitch_emb = self.pitch_emb(pred_avg_pitch)

-        return pitch_loss, gt_agv_pitch, pred_avg_pitch
+        return pitch_loss, gt_avg_pitch_emb, pred_avg_pitch_emb

     def forward_mas(self, outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g, lang_emb):
         # find the alignment path
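The pitch target is the frame-level F0 collapsed to one value per token using the durations, and the loss is a mask-normalized MSE against the predictor output. A reference-style sketch of what "average over durations" means here (not the library's `average_over_durations` implementation):

```python
# Illustrative re-implementation of the averaging step (not the TTS library code).
import torch

def average_over_durations_ref(pitch, durations):
    # pitch:     [B, 1, T_frames] frame-level F0
    # durations: [B, T_tokens]    frames assigned to each token (sums to T_frames)
    outs = []
    for b in range(pitch.shape[0]):
        start, token_means = 0, []
        for d in durations[b].long().tolist():
            seg = pitch[b, 0, start:start + d]
            token_means.append(seg.mean() if d > 0 else pitch.new_zeros(()))
            start += d
        outs.append(torch.stack(token_means))
    return torch.stack(outs).unsqueeze(1)  # [B, 1, T_tokens]

# The loss above is then sum((gt_avg_pitch - pred_avg_pitch) ** 2) / sum(x_mask).
```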
@@ -1313,6 +1340,7 @@ class Vits(BaseTTS):
         y_lengths: torch.tensor,
         waveform: torch.tensor,
         pitch: torch.tensor,
+        alignments: torch.tensor,
         aux_input={
             "d_vectors": None,
             "speaker_ids": None,

@@ -1389,6 +1417,21 @@ class Vits(BaseTTS):
         if self.args.use_speaker_embedding or self.args.use_d_vector_file:
             g = F.normalize(self.speaker_embedding_squeezer(g.squeeze(-1))).unsqueeze(-1)

+        # duration predictor
+        g_dp = g if self.args.condition_dp_on_speaker else None
+        if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings):
+            if g_dp is None:
+                g_dp = eg
+            else:
+                g_dp = torch.cat([g_dp, eg], dim=1)  # [b, h1+h2, 1]
+
+        pitch_loss = None
+        gt_avg_pitch_emb = None
+        if self.args.use_pitch and self.args.use_pitch_on_enc_input:
+            if alignments is None:
+                raise RuntimeError(" [!] To condition the pitch on the text encoder input you need to provide external alignments!")
+            pitch_loss, gt_avg_pitch_emb, _ = self.forward_pitch_predictor(x, x_lengths, pitch, alignments.sum(3), g_dp)
+
         # posterior encoder
         z, m_q, logs_q, y_mask = self.posterior_encoder(y, y_lengths, g=g)
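The cached alignment arrives as a hard attention map of shape `[B, 1, T_tokens, T_frames]`, so summing over the last axis yields per-token durations, which is what `alignments.sum(3)` passes to the pitch predictor. A tiny sanity check of that convention (synthetic values):

```python
# Sketch: hard alignment map -> per-token durations via a sum over the frame axis.
import torch

B, T_tokens, T_frames = 1, 4, 10
attn = torch.zeros(B, 1, T_tokens, T_frames)
# token 0 -> frames 0-2, token 1 -> frames 3-4, token 2 -> frames 5-8, token 3 -> frame 9
for t, (s, e) in enumerate([(0, 3), (3, 5), (5, 9), (9, 10)]):
    attn[0, 0, t, s:e] = 1.0

durations = attn.sum(3)   # [B, 1, T_tokens]
print(durations)          # tensor([[[3., 2., 4., 1.]]])
```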
@@ -1418,13 +1461,14 @@ class Vits(BaseTTS):
                 _, l_pros_speaker = self.speaker_reversal_classifier(pros_emb.transpose(1, 2), sid, x_mask=None)
             if self.args.use_prosody_enc_emo_classifier:
                 _, l_pros_emotion = self.pros_enc_emotion_classifier(pros_emb.transpose(1, 2), eid, x_mask=None)

+        x_input = x
         x, m_p, logs_p, x_mask = self.text_encoder(
             x,
             x_lengths,
             lang_emb=lang_emb,
             emo_emb=eg if not self.args.use_noise_scale_predictor else None,
             pros_emb=pros_emb if not self.args.use_noise_scale_predictor else None,
+            pitch_emb=gt_avg_pitch_emb if self.args.use_pitch and self.args.use_pitch_on_enc_input else None,
         )

         # reversal speaker loss to force the encoder to be speaker identity free
@@ -1437,14 +1481,6 @@ class Vits(BaseTTS):
         if self.args.use_text_enc_emo_classifier:
             _, l_text_emotion = self.emo_text_enc_classifier(m_p.transpose(1, 2), eid, x_mask=x_mask)

-        # duration predictor
-        g_dp = g if self.args.condition_dp_on_speaker else None
-        if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings):
-            if g_dp is None:
-                g_dp = eg
-            else:
-                g_dp = torch.cat([g_dp, eg], dim=1)  # [b, h1+h2, 1]
-
         if self.args.use_prosody_encoder:
             if g_dp is None:
                 g_dp = pros_emb

@@ -1453,7 +1489,6 @@ class Vits(BaseTTS):
         outputs, attn = self.forward_mas(outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g=g_dp, lang_emb=lang_emb)

-        pitch_loss = None
         if self.args.use_pitch and not self.args.use_pitch_on_enc_input:
             pitch_loss, gt_avg_pitch_emb, _ = self.forward_pitch_predictor(m_p, x_lengths, pitch, attn.sum(3), g_dp)
             m_p = m_p + gt_avg_pitch_emb
@@ -1631,14 +1666,6 @@ class Vits(BaseTTS):
             pros_emb = pros_emb.transpose(1, 2)

-        x, m_p, logs_p, x_mask = self.text_encoder(
-            x,
-            x_lengths,
-            lang_emb=lang_emb,
-            emo_emb=eg if not self.args.use_noise_scale_predictor else None,
-            pros_emb=pros_emb if not self.args.use_noise_scale_predictor else None,
-        )
-
         # duration predictor
         g_dp = g if self.args.condition_dp_on_speaker else None
         if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings):

@@ -1647,6 +1674,19 @@ class Vits(BaseTTS):
             else:
                 g_dp = torch.cat([g_dp, eg], dim=1)  # [b, h1+h2, 1]

+        pred_avg_pitch_emb = None
+        if self.args.use_pitch and self.args.use_pitch_on_enc_input:
+            _, _, pred_avg_pitch_emb = self.forward_pitch_predictor(x, x_lengths, g_pp=g_dp)
+
+        x, m_p, logs_p, x_mask = self.text_encoder(
+            x,
+            x_lengths,
+            lang_emb=lang_emb,
+            emo_emb=eg if not self.args.use_noise_scale_predictor else None,
+            pros_emb=pros_emb if not self.args.use_noise_scale_predictor else None,
+            pitch_emb=pred_avg_pitch_emb if self.args.use_pitch and self.args.use_pitch_on_enc_input else None,
+        )
+
         if self.args.use_prosody_encoder:
             if g_dp is None:
                 g_dp = pros_emb
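At inference time the ordering mirrors training: when the pitch conditions the encoder input, the predictor runs on the token embeddings first and its output is fed into the text encoder. A condensed control-flow sketch with stand-in stubs (not the real modules):

```python
# Control-flow sketch only; pitch_predictor/text_encoder are stand-in stubs.
import torch

use_pitch, use_pitch_on_enc_input = True, True

def pitch_predictor(x):                      # stub for forward_pitch_predictor(...)
    return torch.randn(x.shape[0], 2, x.shape[2])

def text_encoder(x, pitch_emb=None):         # stub for self.text_encoder(...)
    return torch.cat((x, pitch_emb), dim=1) if pitch_emb is not None else x

x = torch.randn(1, 192, 13)                  # token features [B, H, T]
pred_avg_pitch_emb = None
if use_pitch and use_pitch_on_enc_input:
    pred_avg_pitch_emb = pitch_predictor(x)  # predict average pitch before encoding
x = text_encoder(x, pitch_emb=pred_avg_pitch_emb)
```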
@@ -1673,7 +1713,6 @@ class Vits(BaseTTS):
         attn_mask = x_mask * y_mask.transpose(1, 2)  # [B, 1, T_enc] * [B, T_dec, 1]
         attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1).transpose(1, 2))

-        pred_avg_pitch_emb = None
         if self.args.use_pitch and not self.args.use_pitch_on_enc_input:
             _, _, pred_avg_pitch_emb = self.forward_pitch_predictor(m_p, x_lengths, g_pp=g_dp)
             m_p = m_p + pred_avg_pitch_emb
@@ -1819,6 +1858,7 @@ class Vits(BaseTTS):
         emotion_ids = batch["emotion_ids"]
         waveform = batch["waveform"]
         pitch = batch["pitch"]
+        alignments = batch["alignments"]

         # generator pass
         outputs = self.forward(

@@ -1828,6 +1868,7 @@ class Vits(BaseTTS):
             spec_lens,
             waveform,
             pitch,
+            alignments,
             aux_input={
                 "d_vectors": d_vectors,
                 "speaker_ids": speaker_ids,

View File

@@ -28,7 +28,7 @@ config = VitsConfig(
     compute_pitch=True,
     f0_cache_path="tests/data/ljspeech/f0_cache/",
     test_sentences=[
-        ["Be a voice, not an echo.", "ljspeech-1", "tests/data/ljspeech/wavs/LJ001-0001.wav", None, None],
+        ["Be a voice, not an echo.", "ljspeech-1", None, None, None],
     ],
 )

 # set audio config

@@ -42,17 +42,18 @@ config.model_args.d_vector_file = "tests/data/ljspeech/speakers.json"
 config.model_args.speaker_embedding_channels = 128
 config.model_args.d_vector_dim = 128

-# prosody embedding
-config.model_args.use_prosody_encoder = True
-config.model_args.prosody_embedding_dim = 64
+config.model_args.use_precomputed_alignments = True
+config.model_args.alignments_cache_path = "tests/data/ljspeech/mas_alignments/alignments/"

 # pitch predictor
 config.model_args.use_pitch = True
+config.model_args.use_pitch_on_enc_input = True
+config.model_args.pitch_embedding_dim = 2
 config.model_args.condition_dp_on_speaker = True

 config.save_json(config_path)

 # train the model for one epoch
 command_train = (
     f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "

@@ -74,11 +75,9 @@ continue_config_path = os.path.join(continue_path, "config.json")
 continue_restore_path, _ = get_last_checkpoint(continue_path)
 out_wav_path = os.path.join(get_tests_output_path(), "output.wav")
 speaker_id = "ljspeech-1"
-style_wav_path = "tests/data/ljspeech/wavs/LJ001-0001.wav"
 continue_speakers_path = os.path.join(continue_path, "speakers.json")

-inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --speaker_idx {speaker_id} --speakers_file_path {continue_speakers_path} --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path} --gst_style {style_wav_path}"
+inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --speaker_idx {speaker_id} --speakers_file_path {continue_speakers_path} --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path} "
 run_cli(inference_command)

 # restore the model and continue training for one more epoch