mirror of https://github.com/coqui-ai/TTS.git
Add Pitch Predictor conditioned on m_p
This commit is contained in:
parent 6186da855f
commit 8f6c187848
@@ -24,7 +24,7 @@ def extract_aligments(
     data_loader, model, output_path, use_cuda=True
 ):
     model.eval()
     export_metadata = []
     with torch.no_grad():
         for _, batch in tqdm(enumerate(data_loader), total=len(data_loader)):

             batch = model.format_batch(batch)
@@ -46,6 +46,7 @@ def extract_aligments(
             emotion_ids = batch["emotion_ids"]
             waveform = batch["waveform"]
             item_idx = batch["audio_files"]
+            pitch = batch["pitch"]
             # generator pass
             outputs = model.forward(
                 tokens,
@@ -53,6 +54,7 @@ def extract_aligments(
                 spec,
                 spec_lens,
                 waveform,
+                pitch,
                 aux_input={
                     "d_vectors": d_vectors,
                     "speaker_ids": speaker_ids,
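The new `pitch` batch entry is the frame-level f0 track that `forward_pitch_predictor` (below) averages per input token; its documented shape is (B, 1, T_de). The real pipeline computes and caches it on the dataset side (see compute_pitch and f0_cache_path in the recipe at the end of this diff). Purely as an illustration, a rough standalone approximation with librosa's pyin might look like this; the file path is taken from the test recipe and the frame parameters are assumptions:

    import librosa
    import numpy as np

    # Hypothetical stand-in for the cached f0 the data loader returns.
    wav, sr = librosa.load("tests/data/ljspeech/wavs/LJ001-0001.wav", sr=22050)
    f0, _, _ = librosa.pyin(
        wav,
        fmin=librosa.note_to_hz("C2"),
        fmax=librosa.note_to_hz("C7"),
        sr=sr,
        frame_length=1024,
        hop_length=256,
    )
    f0 = np.nan_to_num(f0)     # unvoiced frames (NaN) -> 0
    pitch = f0[None, None, :]  # [1, 1, T_de], matching the documented shape
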
@@ -657,6 +657,7 @@ class VitsArgs(Coqpit):
     use_avg_feature_on_latent_discriminator: bool = False

     # Pitch predictor
+    use_pitch_on_enc_input: bool = False
     use_pitch: bool = False
     pitch_predictor_hidden_channels: int = 256
     pitch_predictor_kernel_size: int = 3
@@ -811,12 +812,22 @@ class Vits(BaseTTS):
         )

         if self.args.use_pitch:
+            if self.args.use_pitch_on_enc_input:
+                self.pitch_predictor_vocab_emb = nn.Embedding(self.args.num_chars, self.args.hidden_channels)
+            else:
+                self.pitch_emb = nn.Conv1d(
+                    1,
+                    self.args.hidden_channels,
+                    kernel_size=self.args.pitch_predictor_kernel_size,
+                    padding=int((self.args.pitch_predictor_kernel_size - 1) / 2),
+                )
             self.pitch_predictor = DurationPredictor(
                 self.args.hidden_channels + self.args.emotion_embedding_dim + self.args.prosody_embedding_dim,
                 self.args.hidden_channels,
                 self.args.pitch_predictor_hidden_channels,
                 self.args.pitch_predictor_kernel_size,
                 self.args.pitch_predictor_dropout_p,
                 cond_channels=dp_cond_embedding_dim,
                 language_emb_dim=self.embedded_language_dim,
             )

         if self.args.use_prosody_encoder:
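`pitch_emb` lifts the single-channel token-averaged pitch track into the hidden size so it can be added to m_p, and the `(kernel_size - 1) / 2` padding keeps the token axis unchanged for odd kernels. A quick shape check; hidden_channels=192 is an assumed value (the usual VITS default), not something this diff pins down:

    import torch
    from torch import nn

    hidden_channels, kernel_size = 192, 3
    pitch_emb = nn.Conv1d(
        1, hidden_channels, kernel_size=kernel_size, padding=(kernel_size - 1) // 2
    )  # "same"-length output for odd kernel sizes

    avg_pitch = torch.randn(2, 1, 37)  # [B, 1, T_en]: one averaged value per token
    out = pitch_emb(avg_pitch)         # -> [B, hidden_channels, T_en]
    assert out.shape == (2, hidden_channels, 37)  # now summable with m_p
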
@@ -1190,7 +1201,7 @@ class Vits(BaseTTS):
     def forward_pitch_predictor(
         self,
         o_en: torch.FloatTensor,
-        x_mask: torch.IntTensor,
+        x_lengths: torch.IntTensor,
         pitch: torch.FloatTensor = None,
         dr: torch.IntTensor = None,
         g_pp: torch.IntTensor = None,
@@ -1217,15 +1228,30 @@ class Vits(BaseTTS):
         - pitch: :math:`(B, 1, T_{de})`
         - dr: :math:`(B, T_{en})`
         """
-        o_pitch = self.pitch_predictor(
+        if self.args.use_pitch_on_enc_input:
+            o_en = self.pitch_predictor_vocab_emb(o_en)
+            o_en = torch.transpose(o_en, 1, -1)  # [b, h, t]
+
+        x_mask = torch.unsqueeze(sequence_mask(x_lengths, o_en.size(2)), 1).to(o_en.dtype)  # [b, 1, t]
+
+        pred_avg_pitch = self.pitch_predictor(
             o_en,
             x_mask,
             g=g_pp.detach() if self.args.detach_pp_input and g_pp is not None else g_pp
         )

-        avg_pitch = average_over_durations(pitch, dr.squeeze())
-        pitch_loss = torch.sum(torch.sum((avg_pitch - o_pitch) ** 2, [1, 2]) / torch.sum(x_mask))
-        return pitch_loss
+        pitch_loss = None
+        gt_avg_pitch = None
+        if pitch is not None:
+            gt_avg_pitch = average_over_durations(pitch, dr.squeeze()).detach()
+            pitch_loss = torch.sum(torch.sum((gt_avg_pitch - pred_avg_pitch) ** 2, [1, 2]) / torch.sum(x_mask))
+            if not self.args.use_pitch_on_enc_input:
+                gt_avg_pitch = self.pitch_emb(gt_avg_pitch)
+        else:
+            if not self.args.use_pitch_on_enc_input:
+                pred_avg_pitch = self.pitch_emb(pred_avg_pitch)
+
+        return pitch_loss, gt_avg_pitch, pred_avg_pitch

     def forward_mas(self, outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g, lang_emb):
         # find the alignment path
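The loss is a masked squared error between the predictor output and the ground-truth pitch averaged over each token's duration, as in FastPitch; the ground-truth average is detached, so the target itself carries no gradient. A self-contained sketch of what `average_over_durations` computes, reconstructed from its usage here rather than copied from the library source:

    import torch
    import torch.nn.functional as F

    def average_over_durations_sketch(values, durs):
        """values: [B, 1, T_de] frame-level pitch; durs: [B, T_en] frames per token.
        Returns [B, 1, T_en]: the mean of the frames each token spans."""
        ends = torch.cumsum(durs, dim=1).long()      # [B, T_en] cumulative frame ends
        starts = F.pad(ends[:, :-1], (1, 0))         # [B, T_en] cumulative frame starts
        cums = F.pad(values, (1, 0)).cumsum(dim=2)   # [B, 1, T_de + 1] prefix sums
        sums = cums.gather(2, ends[:, None, :]) - cums.gather(2, starts[:, None, :])
        durs = durs[:, None, :].float()
        return torch.where(durs > 0, sums / durs, sums)  # avoid division by zero

    pitch = torch.tensor([[[1.0, 1.0, 3.0, 3.0, 3.0]]])  # 5 frames
    durs = torch.tensor([[2, 3]])                        # token 1 -> 2 frames, token 2 -> 3
    print(average_over_durations_sketch(pitch, durs))    # tensor([[[1., 3.]]])
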
@@ -1392,7 +1418,7 @@ class Vits(BaseTTS):
             _, l_pros_speaker = self.speaker_reversal_classifier(pros_emb.transpose(1, 2), sid, x_mask=None)
         if self.args.use_prosody_enc_emo_classifier:
             _, l_pros_emotion = self.pros_enc_emotion_classifier(pros_emb.transpose(1, 2), eid, x_mask=None)

+        x_input = x
         x, m_p, logs_p, x_mask = self.text_encoder(
             x,
             x_lengths,
@@ -1428,8 +1454,9 @@ class Vits(BaseTTS):
         outputs, attn = self.forward_mas(outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g=g_dp, lang_emb=lang_emb)

         pitch_loss = None
-        if self.args.use_pitch:
-            pitch_loss = self.forward_pitch_predictor(x, x_mask, pitch, attn.sum(3), g_dp)
+        if self.args.use_pitch and not self.args.use_pitch_on_enc_input:
+            pitch_loss, gt_avg_pitch_emb, _ = self.forward_pitch_predictor(m_p, x_lengths, pitch, attn.sum(3), g_dp)
+            m_p = m_p + gt_avg_pitch_emb

         # expand prior
         m_p_expanded = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
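So during training the prior mean is shifted by the embedding of the ground-truth averaged pitch (teacher forcing), while the inference hunk below applies the same shift using the predictor's own output. A schematic of the two paths with stand-in modules; the Conv1d heads here are placeholders, not the actual `pitch_emb` or DurationPredictor:

    import torch
    from torch import nn

    H, T = 192, 37
    m_p = torch.randn(1, H, T)                 # prior mean from the text encoder
    pitch_emb = nn.Conv1d(1, H, 3, padding=1)  # stand-in for self.pitch_emb
    predictor = nn.Conv1d(H, 1, 3, padding=1)  # stand-in for self.pitch_predictor

    # Training: condition on the embedded ground-truth average pitch.
    gt_avg_pitch = torch.randn(1, 1, T)        # [B, 1, T_en]
    m_p_train = m_p + pitch_emb(gt_avg_pitch)

    # Inference: no ground truth, so embed the predictor's own estimate.
    m_p_infer = m_p + pitch_emb(predictor(m_p))
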
@@ -1646,6 +1673,12 @@ class Vits(BaseTTS):
         attn_mask = x_mask * y_mask.transpose(1, 2)  # [B, 1, T_enc] * [B, T_dec, 1]
         attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1).transpose(1, 2))

+        pred_avg_pitch_emb = None
+        if self.args.use_pitch and not self.args.use_pitch_on_enc_input:
+            _, _, pred_avg_pitch_emb = self.forward_pitch_predictor(m_p, x_lengths, g_pp=g_dp)
+            m_p = m_p + pred_avg_pitch_emb
+
+
         m_p = torch.matmul(attn.transpose(1, 2), m_p.transpose(1, 2)).transpose(1, 2)
         logs_p = torch.matmul(attn.transpose(1, 2), logs_p.transpose(1, 2)).transpose(1, 2)
@@ -1683,6 +1716,7 @@ class Vits(BaseTTS):
             "m_p": m_p,
             "logs_p": logs_p,
             "y_mask": y_mask,
+            "pitch": pred_avg_pitch_emb,
         }
         return outputs

@@ -1693,7 +1727,7 @@ class Vits(BaseTTS):
                     " [!] Style reference need to have sampling rate equal to {self.config.audio.sample_rate} !!"
                 )
             y = wav_to_spec(
-                style_wav,
+                style_wav.unsqueeze(1),
                 self.config.audio.fft_size,
                 self.config.audio.hop_length,
                 self.config.audio.win_length,
@@ -25,6 +25,8 @@ config = VitsConfig(
     epochs=1,
     print_step=1,
     print_eval=True,
+    compute_pitch=True,
+    f0_cache_path="tests/data/ljspeech/f0_cache/",
     test_sentences=[
         ["Be a voice, not an echo.", "ljspeech-1", "tests/data/ljspeech/wavs/LJ001-0001.wav", None, None, "ljspeech-2"],
     ],
@@ -57,12 +59,15 @@ config.model_args.use_latent_discriminator = True
 config.model_args.use_noise_scale_predictor = False
 config.model_args.condition_pros_enc_on_speaker = True

-config.model_args.use_pros_enc_input_as_pros_emb = True
-config.model_args.use_prosody_embedding_squeezer = True
-config.model_args.prosody_embedding_squeezer_input_dim = 192
+config.model_args.use_pros_enc_input_as_pros_emb = False
+config.model_args.use_prosody_embedding_squeezer = False
+config.model_args.prosody_embedding_squeezer_input_dim = 0

+# pitch predictor
+config.model_args.use_pitch = True
+config.model_args.use_pitch_on_enc_input = False
+config.model_args.condition_dp_on_speaker = False

 # enable end2end loss
 config.model_args.use_end2end_loss = False

 config.mixed_precision = False
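Put together, the knobs a config needs for the new predictor in its m_p-conditioned mode are the ones this recipe sets. A condensed sketch, assuming this branch's VitsConfig where compute_pitch and f0_cache_path drive the loader-side f0 extraction:

    from TTS.tts.configs.vits_config import VitsConfig

    config = VitsConfig(
        compute_pitch=True,                             # extract f0 for every sample
        f0_cache_path="tests/data/ljspeech/f0_cache/",  # and cache it on disk
    )
    config.model_args.use_pitch = True                # enable the pitch predictor
    config.model_args.use_pitch_on_enc_input = False  # condition on m_p, not token embeddings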