Add Pitch Predictor conditioned on m_p

This commit is contained in:
Edresson Casanova 2022-06-17 17:45:10 +00:00
parent 6186da855f
commit 8f6c187848
3 changed files with 98 additions and 57 deletions

View File

@ -24,54 +24,56 @@ def extract_aligments(
data_loader, model, output_path, use_cuda=True
):
model.eval()
export_metadata = []
for _, batch in tqdm(enumerate(data_loader), total=len(data_loader)):
with torch.no_grad():
for _, batch in tqdm(enumerate(data_loader), total=len(data_loader)):
batch = model.format_batch(batch)
if use_cuda:
for k, v in batch.items():
batch[k] = to_cuda(v)
batch = model.format_batch(batch)
if use_cuda:
for k, v in batch.items():
batch[k] = to_cuda(v)
batch = model.format_batch_on_device(batch)
batch = model.format_batch_on_device(batch)
spec_lens = batch["spec_lens"]
tokens = batch["tokens"]
token_lenghts = batch["token_lens"]
spec = batch["spec"]
spec_lens = batch["spec_lens"]
tokens = batch["tokens"]
token_lenghts = batch["token_lens"]
spec = batch["spec"]
d_vectors = batch["d_vectors"]
speaker_ids = batch["speaker_ids"]
language_ids = batch["language_ids"]
emotion_embeddings = batch["emotion_embeddings"]
emotion_ids = batch["emotion_ids"]
waveform = batch["waveform"]
item_idx = batch["audio_files"]
# generator pass
outputs = model.forward(
tokens,
token_lenghts,
spec,
spec_lens,
waveform,
aux_input={
"d_vectors": d_vectors,
"speaker_ids": speaker_ids,
"language_ids": language_ids,
"emotion_embeddings": emotion_embeddings,
"emotion_ids": emotion_ids,
},
)
d_vectors = batch["d_vectors"]
speaker_ids = batch["speaker_ids"]
language_ids = batch["language_ids"]
emotion_embeddings = batch["emotion_embeddings"]
emotion_ids = batch["emotion_ids"]
waveform = batch["waveform"]
item_idx = batch["audio_files"]
pitch = batch["pitch"]
# generator pass
outputs = model.forward(
tokens,
token_lenghts,
spec,
spec_lens,
waveform,
pitch,
aux_input={
"d_vectors": d_vectors,
"speaker_ids": speaker_ids,
"language_ids": language_ids,
"emotion_embeddings": emotion_embeddings,
"emotion_ids": emotion_ids,
},
)
alignments = outputs["alignments"].detach().cpu().numpy()
alignments = outputs["alignments"].detach().cpu().numpy()
for idx in range(tokens.shape[0]):
wav_file_path = item_idx[idx]
alignment = alignments[idx]
# set paths
align_file_name = os.path.splitext(os.path.basename(wav_file_path))[0] + ".npy"
os.makedirs(os.path.join(output_path, "alignments"), exist_ok=True)
align_file_path = os.path.join(output_path, "alignments", align_file_name)
np.save(align_file_path, alignment)
for idx in range(tokens.shape[0]):
wav_file_path = item_idx[idx]
alignment = alignments[idx]
# set paths
align_file_name = os.path.splitext(os.path.basename(wav_file_path))[0] + ".npy"
os.makedirs(os.path.join(output_path, "alignments"), exist_ok=True)
align_file_path = os.path.join(output_path, "alignments", align_file_name)
np.save(align_file_path, alignment)

View File

@ -657,6 +657,7 @@ class VitsArgs(Coqpit):
use_avg_feature_on_latent_discriminator: bool = False
# Pitch predictor
use_pitch_on_enc_input: bool = False
use_pitch: bool = False
pitch_predictor_hidden_channels: int = 256
pitch_predictor_kernel_size: int = 3
@ -811,12 +812,22 @@ class Vits(BaseTTS):
)
if self.args.use_pitch:
if self.args.use_pitch_on_enc_input:
self.pitch_predictor_vocab_emb = nn.Embedding(self.args.num_chars, self.args.hidden_channels)
else:
self.pitch_emb = nn.Conv1d(
1,
self.args.hidden_channels,
kernel_size=self.args.pitch_predictor_kernel_size,
padding=int((self.args.pitch_predictor_kernel_size - 1) / 2),
)
self.pitch_predictor = DurationPredictor(
self.args.hidden_channels + self.args.emotion_embedding_dim + self.args.prosody_embedding_dim,
self.args.hidden_channels,
self.args.pitch_predictor_hidden_channels,
self.args.pitch_predictor_kernel_size,
self.args.pitch_predictor_dropout_p,
cond_channels=dp_cond_embedding_dim,
language_emb_dim=self.embedded_language_dim,
)
if self.args.use_prosody_encoder:
@ -1190,7 +1201,7 @@ class Vits(BaseTTS):
def forward_pitch_predictor(
self,
o_en: torch.FloatTensor,
x_mask: torch.IntTensor,
x_lengths: torch.IntTensor,
pitch: torch.FloatTensor = None,
dr: torch.IntTensor = None,
g_pp: torch.IntTensor = None,
@ -1217,15 +1228,30 @@ class Vits(BaseTTS):
- pitch: :math:`(B, 1, T_{de})`
- dr: :math:`(B, T_{en})`
"""
o_pitch = self.pitch_predictor(
if self.args.use_pitch_on_enc_input:
o_en = self.pitch_predictor_vocab_emb(o_en)
o_en = torch.transpose(o_en, 1, -1) # [b, h, t]
x_mask = torch.unsqueeze(sequence_mask(x_lengths, o_en.size(2)), 1).to(o_en.dtype) # [b, 1, t]
pred_avg_pitch = self.pitch_predictor(
o_en,
x_mask,
g=g_pp.detach() if self.args.detach_pp_input and g_pp is not None else g_pp
)
avg_pitch = average_over_durations(pitch, dr.squeeze())
pitch_loss = torch.sum(torch.sum((avg_pitch - o_pitch) ** 2, [1, 2]) / torch.sum(x_mask))
return pitch_loss
pitch_loss = None
gt_avg_pitch = None
if pitch is not None:
gt_avg_pitch = average_over_durations(pitch, dr.squeeze()).detach()
pitch_loss = torch.sum(torch.sum((gt_avg_pitch - pred_avg_pitch) ** 2, [1, 2]) / torch.sum(x_mask))
if not self.args.use_pitch_on_enc_input:
gt_agv_pitch = self.pitch_emb(gt_avg_pitch)
else:
if not self.args.use_pitch_on_enc_input:
pred_avg_pitch = self.pitch_emb(pred_avg_pitch)
return pitch_loss, gt_agv_pitch, pred_avg_pitch
def forward_mas(self, outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g, lang_emb):
# find the alignment path
@ -1392,7 +1418,7 @@ class Vits(BaseTTS):
_, l_pros_speaker = self.speaker_reversal_classifier(pros_emb.transpose(1, 2), sid, x_mask=None)
if self.args.use_prosody_enc_emo_classifier:
_, l_pros_emotion = self.pros_enc_emotion_classifier(pros_emb.transpose(1, 2), eid, x_mask=None)
x_input = x
x, m_p, logs_p, x_mask = self.text_encoder(
x,
x_lengths,
@ -1428,8 +1454,9 @@ class Vits(BaseTTS):
outputs, attn = self.forward_mas(outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g=g_dp, lang_emb=lang_emb)
pitch_loss = None
if self.args.use_pitch:
pitch_loss = self.forward_pitch_predictor(x, x_mask, pitch, attn.sum(3), g_dp)
if self.args.use_pitch and not self.args.use_pitch_on_enc_input:
pitch_loss, gt_avg_pitch_emb, _ = self.forward_pitch_predictor(m_p, x_lengths, pitch, attn.sum(3), g_dp)
m_p = m_p + gt_avg_pitch_emb
# expand prior
m_p_expanded = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
@ -1646,6 +1673,12 @@ class Vits(BaseTTS):
attn_mask = x_mask * y_mask.transpose(1, 2) # [B, 1, T_enc] * [B, T_dec, 1]
attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1).transpose(1, 2))
pred_avg_pitch_emb = None
if self.args.use_pitch and not self.args.use_pitch_on_enc_input:
_, _, pred_avg_pitch_emb = self.forward_pitch_predictor(m_p, x_lengths, g_pp=g_dp)
m_p = m_p + pred_avg_pitch_emb
m_p = torch.matmul(attn.transpose(1, 2), m_p.transpose(1, 2)).transpose(1, 2)
logs_p = torch.matmul(attn.transpose(1, 2), logs_p.transpose(1, 2)).transpose(1, 2)
@ -1683,6 +1716,7 @@ class Vits(BaseTTS):
"m_p": m_p,
"logs_p": logs_p,
"y_mask": y_mask,
"pitch": pred_avg_pitch_emb,
}
return outputs
@ -1693,7 +1727,7 @@ class Vits(BaseTTS):
" [!] Style reference need to have sampling rate equal to {self.config.audio.sample_rate} !!"
)
y = wav_to_spec(
style_wav,
style_wav.unsqueeze(1),
self.config.audio.fft_size,
self.config.audio.hop_length,
self.config.audio.win_length,

View File

@ -25,6 +25,8 @@ config = VitsConfig(
epochs=1,
print_step=1,
print_eval=True,
compute_pitch=True,
f0_cache_path="tests/data/ljspeech/f0_cache/",
test_sentences=[
["Be a voice, not an echo.", "ljspeech-1", "tests/data/ljspeech/wavs/LJ001-0001.wav", None, None, "ljspeech-2"],
],
@ -57,12 +59,15 @@ config.model_args.use_latent_discriminator = True
config.model_args.use_noise_scale_predictor = False
config.model_args.condition_pros_enc_on_speaker = True
config.model_args.use_pros_enc_input_as_pros_emb = True
config.model_args.use_prosody_embedding_squeezer = True
config.model_args.prosody_embedding_squeezer_input_dim = 192
config.model_args.use_pros_enc_input_as_pros_emb = False
config.model_args.use_prosody_embedding_squeezer = False
config.model_args.prosody_embedding_squeezer_input_dim = 0
# pitch predictor
config.model_args.use_pitch = True
config.model_args.use_pitch_on_enc_input = False
config.model_args.condition_dp_on_speaker = False
# enable end2end loss
config.model_args.use_end2end_loss = False
config.mixed_precision = False