Add Pitch Predictor conditioned on enc

Edresson Casanova 2022-06-17 17:39:17 -03:00
parent 8f6c187848
commit 569decba64
12 changed files with 93 additions and 43 deletions

View File

@ -69,6 +69,9 @@ def extract_aligments(
for idx in range(tokens.shape[0]):
wav_file_path = item_idx[idx]
alignment = alignments[idx]
spec_length = spec_lens[idx]
token_length = token_lenghts[idx]
alignment = alignment[:token_length, :spec_length]
# set paths
align_file_name = os.path.splitext(os.path.basename(wav_file_path))[0] + ".npy"
os.makedirs(os.path.join(output_path, "alignments"), exist_ok=True)

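The hunk above sets up the per-utterance output path but is cut off before the save call. A minimal sketch of the missing step, assuming the cropped alignment is simply written out as a NumPy file (function name and signature are illustrative, not the script's actual code):

import os
import numpy as np

# Sketch (assumption): persist one cropped alignment per wav, mirroring the
# path construction shown in the hunk above. `alignment` has shape [T_tokens, T_spec].
def save_alignment(output_path: str, wav_file_path: str, alignment: np.ndarray) -> None:
    align_file_name = os.path.splitext(os.path.basename(wav_file_path))[0] + ".npy"
    align_dir = os.path.join(output_path, "alignments")
    os.makedirs(align_dir, exist_ok=True)
    np.save(os.path.join(align_dir, align_file_name), alignment)
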
View File

@ -40,6 +40,7 @@ class TextEncoder(nn.Module):
language_emb_dim: int = None,
emotion_emb_dim: int = None,
prosody_emb_dim: int = None,
pitch_dim: int = None,
):
"""Text Encoder for VITS model.
@ -70,6 +71,9 @@ class TextEncoder(nn.Module):
if prosody_emb_dim:
hidden_channels += prosody_emb_dim
if pitch_dim:
hidden_channels += pitch_dim
self.encoder = RelativePositionTransformer(
in_channels=hidden_channels,
out_channels=hidden_channels,
@ -85,7 +89,7 @@ class TextEncoder(nn.Module):
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
def forward(self, x, x_lengths, lang_emb=None, emo_emb=None, pros_emb=None):
def forward(self, x, x_lengths, lang_emb=None, emo_emb=None, pros_emb=None, pitch_emb=None):
"""
Shapes:
- x: :math:`[B, T]`
@ -105,6 +109,9 @@ class TextEncoder(nn.Module):
if pros_emb is not None:
x = torch.cat((x, pros_emb.transpose(2, 1).expand(x.size(0), x.size(1), -1)), dim=-1)
if pitch_emb is not None:
x = torch.cat((x, pitch_emb.transpose(2, 1)), dim=-1)
x = torch.transpose(x, 1, -1) # [b, h, t]
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) # [b, 1, t]

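In the TextEncoder change above, the per-token pitch embedding arrives as a [B, pitch_dim, T] tensor, is transposed to [B, T, pitch_dim], and is concatenated to the token embeddings along the channel axis; this is why hidden_channels is widened by pitch_dim before the RelativePositionTransformer. A minimal shape sketch of that concatenation (toy sizes, illustrative only):

import torch

B, T, hidden, pitch_dim = 4, 13, 192, 2
x = torch.randn(B, T, hidden)             # token embeddings [B, T, H]
pitch_emb = torch.randn(B, pitch_dim, T)  # pitch embedding  [B, C_pitch, T]

x = torch.cat((x, pitch_emb.transpose(2, 1)), dim=-1)  # [B, T, H + C_pitch]
x = torch.transpose(x, 1, -1)                          # [B, H + C_pitch, T]
assert x.shape == (B, hidden + pitch_dim, T)
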
View File

@ -265,7 +265,9 @@ class VitsDataset(TTSDataset):
self.pad_id = self.tokenizer.characters.pad_id
self.model_args = model_args
self.compute_pitch = compute_pitch
self.use_precomputed_alignments = model_args.use_precomputed_alignments
self.alignments_cache_path = model_args.alignments_cache_path
if self.compute_pitch:
self.f0_dataset = VITSF0Dataset(config,
samples=self.samples, ap=self.ap, cache_path=self.f0_cache_path, precompute_num_workers=self.precompute_num_workers
@ -289,6 +291,11 @@ class VitsDataset(TTSDataset):
if self.compute_pitch:
f0 = self.get_f0(idx)["f0"]
alignments = None
if self.use_precomputed_alignments:
align_file = os.path.join(self.alignments_cache_path, os.path.splitext(wav_filename)[0] + ".npy")
alignments = self.get_attn_mask(align_file)
# after phonemization the text length may change
# this is a shameful 🤭 hack to prevent longer phonemes
# TODO: find a better fix
@ -305,6 +312,8 @@ class VitsDataset(TTSDataset):
"speaker_name": item["speaker_name"],
"language_name": item["language"],
"pitch": f0,
"alignments": alignments,
}
@property
@ -365,6 +374,18 @@ class VitsDataset(TTSDataset):
pitch = torch.FloatTensor(pitch)[:, None, :].contiguous() # B x 1 xT
else:
pitch = None
padded_alignments = None
if self.use_precomputed_alignments:
alignments = batch["alignments"]
max_len_1 = max((x.shape[0] for x in alignments))
max_len_2 = max((x.shape[1] for x in alignments))
padded_alignments = []
for x in alignments:
padded_alignment = np.pad(x, ((0, max_len_1 - x.shape[0]), (0, max_len_2 - x.shape[1])), mode="constant", constant_values=0)
padded_alignments.append(padded_alignment)
padded_alignments = torch.FloatTensor(np.stack(padded_alignments)).unsqueeze(1)
return {
"tokens": token_padded,
@ -378,6 +399,7 @@ class VitsDataset(TTSDataset):
"language_names": batch["language_name"],
"audio_files": batch["wav_file"],
"raw_text": batch["raw_text"],
"alignments": padded_alignments,
}
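
The collate change above pads each precomputed alignment (a hard, monotonic [T_tokens, T_spec] matrix) to the largest token and spectrogram lengths in the batch and stacks the results into a [B, 1, max_T_tokens, max_T_spec] tensor. Summing that tensor over the spectrogram axis recovers the per-token durations that the model later passes to the pitch predictor (alignments.sum(3) in the forward pass below). A toy illustration:

import numpy as np
import torch

alignments = [np.eye(3, 5), np.eye(4, 7)]  # toy per-utterance alignment matrices
max_t = max(a.shape[0] for a in alignments)
max_s = max(a.shape[1] for a in alignments)
padded = [np.pad(a, ((0, max_t - a.shape[0]), (0, max_s - a.shape[1]))) for a in alignments]
batch = torch.FloatTensor(np.stack(padded)).unsqueeze(1)  # [B, 1, max_T_tokens, max_T_spec]

durations = batch.sum(3)  # [B, 1, max_T_tokens]: number of frames assigned to each token
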
@ -385,7 +407,6 @@ class VitsDataset(TTSDataset):
# MODEL DEFINITION
##############################
@dataclass
class VitsArgs(Coqpit):
"""VITS model arguments.
@ -664,6 +685,9 @@ class VitsArgs(Coqpit):
pitch_predictor_dropout_p: float = 0.1
pitch_embedding_kernel_size: int = 3
detach_pp_input: bool = False
use_precomputed_alignments: bool = False
alignments_cache_path: str = ""
pitch_embedding_dim: int = 0
detach_dp_input: bool = True
use_language_embedding: bool = False
@ -751,6 +775,7 @@ class Vits(BaseTTS):
language_emb_dim=self.embedded_language_dim,
emotion_emb_dim=self.args.emotion_embedding_dim if not self.args.use_noise_scale_predictor else 0,
prosody_emb_dim=self.args.prosody_embedding_dim if not self.args.use_noise_scale_predictor else 0,
pitch_dim=self.args.pitch_embedding_dim if self.args.use_pitch and self.args.use_pitch_on_enc_input else 0,
)
self.posterior_encoder = PosteriorEncoder(
@ -791,6 +816,9 @@ class Vits(BaseTTS):
if self.args.use_prosody_encoder and not self.args.use_noise_scale_predictor:
dp_extra_inp_dim += self.args.prosody_embedding_dim
if self.args.use_pitch and self.args.use_pitch_on_enc_input:
dp_extra_inp_dim += self.args.pitch_embedding_dim
if self.args.use_sdp:
self.duration_predictor = StochasticDurationPredictor(
self.args.hidden_channels + dp_extra_inp_dim,
@ -814,13 +842,13 @@ class Vits(BaseTTS):
if self.args.use_pitch:
if self.args.use_pitch_on_enc_input:
self.pitch_predictor_vocab_emb = nn.Embedding(self.args.num_chars, self.args.hidden_channels)
else:
self.pitch_emb = nn.Conv1d(
1,
self.args.hidden_channels,
kernel_size=self.args.pitch_predictor_kernel_size,
padding=int((self.args.pitch_predictor_kernel_size - 1) / 2),
)
self.pitch_emb = nn.Conv1d(
1,
self.args.hidden_channels if not self.args.use_pitch_on_enc_input else self.args.pitch_embedding_dim,
kernel_size=self.args.pitch_predictor_kernel_size,
padding=int((self.args.pitch_predictor_kernel_size - 1) / 2),
)
self.pitch_predictor = DurationPredictor(
self.args.hidden_channels,
self.args.pitch_predictor_hidden_channels,
@ -1241,17 +1269,16 @@ class Vits(BaseTTS):
)
pitch_loss = None
gt_avg_pitch = None
pred_avg_pitch_emb = None
gt_avg_pitch_emb = None
if pitch is not None:
gt_avg_pitch = average_over_durations(pitch, dr.squeeze()).detach()
pitch_loss = torch.sum(torch.sum((gt_avg_pitch - pred_avg_pitch) ** 2, [1, 2]) / torch.sum(x_mask))
if not self.args.use_pitch_on_enc_input:
gt_agv_pitch = self.pitch_emb(gt_avg_pitch)
gt_avg_pitch_emb = self.pitch_emb(gt_avg_pitch)
else:
if not self.args.use_pitch_on_enc_input:
pred_avg_pitch = self.pitch_emb(pred_avg_pitch)
pred_avg_pitch_emb = self.pitch_emb(pred_avg_pitch)
return pitch_loss, gt_agv_pitch, pred_avg_pitch
return pitch_loss, gt_avg_pitch_emb, pred_avg_pitch_emb
def forward_mas(self, outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g, lang_emb):
# find the alignment path
@ -1313,6 +1340,7 @@ class Vits(BaseTTS):
y_lengths: torch.tensor,
waveform: torch.tensor,
pitch: torch.tensor,
alignments: torch.tensor,
aux_input={
"d_vectors": None,
"speaker_ids": None,
@ -1389,6 +1417,21 @@ class Vits(BaseTTS):
if self.args.use_speaker_embedding or self.args.use_d_vector_file:
g = F.normalize(self.speaker_embedding_squeezer(g.squeeze(-1))).unsqueeze(-1)
# duration predictor
g_dp = g if self.args.condition_dp_on_speaker else None
if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings):
if g_dp is None:
g_dp = eg
else:
g_dp = torch.cat([g_dp, eg], dim=1) # [b, h1+h2, 1]
pitch_loss = None
gt_avg_pitch_emb = None
if self.args.use_pitch and self.args.use_pitch_on_enc_input:
if alignments is None:
raise RuntimeError(" [!] To condition the pitch on the Text Encoder you need to provide external alignments!")
pitch_loss, gt_avg_pitch_emb, _ = self.forward_pitch_predictor(x, x_lengths, pitch, alignments.sum(3), g_dp)
# posterior encoder
z, m_q, logs_q, y_mask = self.posterior_encoder(y, y_lengths, g=g)
@ -1418,13 +1461,14 @@ class Vits(BaseTTS):
_, l_pros_speaker = self.speaker_reversal_classifier(pros_emb.transpose(1, 2), sid, x_mask=None)
if self.args.use_prosody_enc_emo_classifier:
_, l_pros_emotion = self.pros_enc_emotion_classifier(pros_emb.transpose(1, 2), eid, x_mask=None)
x_input = x
x, m_p, logs_p, x_mask = self.text_encoder(
x,
x_lengths,
lang_emb=lang_emb,
emo_emb=eg if not self.args.use_noise_scale_predictor else None,
pros_emb=pros_emb if not self.args.use_noise_scale_predictor else None,
pitch_emb=gt_avg_pitch_emb if self.args.use_pitch and self.args.use_pitch_on_enc_input else None,
)
# reversal speaker loss to force the encoder to be speaker identity free
@ -1437,14 +1481,6 @@ class Vits(BaseTTS):
if self.args.use_text_enc_emo_classifier:
_, l_text_emotion = self.emo_text_enc_classifier(m_p.transpose(1, 2), eid, x_mask=x_mask)
# duration predictor
g_dp = g if self.args.condition_dp_on_speaker else None
if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings):
if g_dp is None:
g_dp = eg
else:
g_dp = torch.cat([g_dp, eg], dim=1) # [b, h1+h2, 1]
if self.args.use_prosody_encoder:
if g_dp is None:
g_dp = pros_emb
@ -1453,7 +1489,6 @@ class Vits(BaseTTS):
outputs, attn = self.forward_mas(outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g=g_dp, lang_emb=lang_emb)
pitch_loss = None
if self.args.use_pitch and not self.args.use_pitch_on_enc_input:
pitch_loss, gt_avg_pitch_emb, _ = self.forward_pitch_predictor(m_p, x_lengths, pitch, attn.sum(3), g_dp)
m_p = m_p + gt_avg_pitch_emb
@ -1631,14 +1666,6 @@ class Vits(BaseTTS):
pros_emb = pros_emb.transpose(1, 2)
x, m_p, logs_p, x_mask = self.text_encoder(
x,
x_lengths,
lang_emb=lang_emb,
emo_emb=eg if not self.args.use_noise_scale_predictor else None,
pros_emb=pros_emb if not self.args.use_noise_scale_predictor else None,
)
# duration predictor
g_dp = g if self.args.condition_dp_on_speaker else None
if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings):
@ -1647,6 +1674,19 @@ class Vits(BaseTTS):
else:
g_dp = torch.cat([g_dp, eg], dim=1) # [b, h1+h2, 1]
pred_avg_pitch_emb = None
if self.args.use_pitch and self.args.use_pitch_on_enc_input:
_, _, pred_avg_pitch_emb = self.forward_pitch_predictor(x, x_lengths, g_pp=g_dp)
x, m_p, logs_p, x_mask = self.text_encoder(
x,
x_lengths,
lang_emb=lang_emb,
emo_emb=eg if not self.args.use_noise_scale_predictor else None,
pros_emb=pros_emb if not self.args.use_noise_scale_predictor else None,
pitch_emb=pred_avg_pitch_emb if self.args.use_pitch and self.args.use_pitch_on_enc_input else None,
)
if self.args.use_prosody_encoder:
if g_dp is None:
g_dp = pros_emb
@ -1673,7 +1713,6 @@ class Vits(BaseTTS):
attn_mask = x_mask * y_mask.transpose(1, 2) # [B, 1, T_enc] * [B, T_dec, 1]
attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1).transpose(1, 2))
pred_avg_pitch_emb = None
if self.args.use_pitch and not self.args.use_pitch_on_enc_input:
_, _, pred_avg_pitch_emb = self.forward_pitch_predictor(m_p, x_lengths, g_pp=g_dp)
m_p = m_p + pred_avg_pitch_emb
@ -1819,6 +1858,7 @@ class Vits(BaseTTS):
emotion_ids = batch["emotion_ids"]
waveform = batch["waveform"]
pitch = batch["pitch"]
alignments = batch["alignments"]
# generator pass
outputs = self.forward(
@ -1828,6 +1868,7 @@ class Vits(BaseTTS):
spec_lens,
waveform,
pitch,
alignments,
aux_input={
"d_vectors": d_vectors,
"speaker_ids": speaker_ids,

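Taken together, the model changes above wire the pitch predictor into the text encoder input: during training the ground-truth per-token average pitch is computed from the external alignments (alignments.sum(3) gives the per-token durations), embedded by pitch_emb, and concatenated into the encoder input; at inference the predictor's own output is embedded instead. The core reduction is the "average pitch over durations" step, sketched below with toy numbers (this is only the semantics, not the library's average_over_durations implementation):

import torch

pitch = torch.tensor([[[100., 110., 120., 200., 210., 0.]]])  # frame-level F0, [B=1, 1, T_frames]
durs = torch.tensor([[3, 2, 1]])                              # frames per token, sums to T_frames

ends = torch.cumsum(durs, dim=1)
starts = ends - durs
gt_avg_pitch = torch.stack(
    [pitch[0, 0, s:e].mean() for s, e in zip(starts[0], ends[0])]
).view(1, 1, -1)                                              # one averaged pitch value per token
print(gt_avg_pitch)  # tensor([[[110., 205., 0.]]])
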
View File

@ -28,7 +28,7 @@ config = VitsConfig(
compute_pitch=True,
f0_cache_path="tests/data/ljspeech/f0_cache/",
test_sentences=[
["Be a voice, not an echo.", "ljspeech-1", "tests/data/ljspeech/wavs/LJ001-0001.wav", None, None],
["Be a voice, not an echo.", "ljspeech-1", None, None, None],
],
)
# set audio config
@ -42,17 +42,18 @@ config.model_args.d_vector_file = "tests/data/ljspeech/speakers.json"
config.model_args.speaker_embedding_channels = 128
config.model_args.d_vector_dim = 128
# prosody embedding
config.model_args.use_prosody_encoder = True
config.model_args.prosody_embedding_dim = 64
config.model_args.use_precomputed_alignments = True
config.model_args.alignments_cache_path = "tests/data/ljspeech/mas_alignments/alignments/"
# pitch predictor
config.model_args.use_pitch = True
config.model_args.use_pitch_on_enc_input = True
config.model_args.pitch_embedding_dim = 2
config.model_args.condition_dp_on_speaker = True
config.save_json(config_path)
# train the model for one epoch
command_train = (
f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
@ -74,11 +75,9 @@ continue_config_path = os.path.join(continue_path, "config.json")
continue_restore_path, _ = get_last_checkpoint(continue_path)
out_wav_path = os.path.join(get_tests_output_path(), "output.wav")
speaker_id = "ljspeech-1"
style_wav_path = "tests/data/ljspeech/wavs/LJ001-0001.wav"
continue_speakers_path = os.path.join(continue_path, "speakers.json")
inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --speaker_idx {speaker_id} --speakers_file_path {continue_speakers_path} --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path} --gst_style {style_wav_path}"
inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --speaker_idx {speaker_id} --speakers_file_path {continue_speakers_path} --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path} "
run_cli(inference_command)
# restore the model and continue training for one more epoch