Bug fix in F0 pre-computation

This commit is contained in:
Edresson Casanova 2022-05-19 13:48:02 +00:00
parent d94b8bac02
commit bdefc43d96
4 changed files with 29 additions and 28 deletions

View File

@ -151,7 +151,7 @@ class VitsConfig(BaseTTSConfig):
d_vector_dim: int = None d_vector_dim: int = None
# dataset configs # dataset configs
compute_f0: bool = False compute_pitch: bool = False
f0_cache_path: str = None f0_cache_path: str = None
def __post_init__(self): def __post_init__(self):

View File

@ -646,6 +646,7 @@ class VitsGeneratorLoss(nn.Module):
if loss_spk_reversal_classifier is not None: if loss_spk_reversal_classifier is not None:
loss += loss_spk_reversal_classifier loss += loss_spk_reversal_classifier
return_dict["loss_spk_reversal_classifier"] = loss_spk_reversal_classifier return_dict["loss_spk_reversal_classifier"] = loss_spk_reversal_classifier
if pitch_loss is not None: if pitch_loss is not None:
pitch_loss = pitch_loss * self.pitch_loss_alpha pitch_loss = pitch_loss * self.pitch_loss_alpha
loss += pitch_loss loss += pitch_loss

View File

@ -189,7 +189,7 @@ def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fm
spec = amp_to_db(spec) spec = amp_to_db(spec)
return spec return spec
def compute_f0(x: np.ndarray, sample_rate, hop_length, pitch_fmax=800.0) -> np.ndarray: def compute_pitch(x: np.ndarray, sample_rate, hop_length, pitch_fmax=800.0) -> np.ndarray:
"""Compute pitch (f0) of a waveform using the same parameters used for computing melspectrogram. """Compute pitch (f0) of a waveform using the same parameters used for computing melspectrogram.
Args: Args:
@ -217,8 +217,8 @@ def compute_f0(x: np.ndarray, sample_rate, hop_length, pitch_fmax=800.0) -> np.n
class VITSF0Dataset(F0Dataset): class VITSF0Dataset(F0Dataset):
def __init__(self, config, *args, **kwargs): def __init__(self, config, *args, **kwargs):
self.audio_config = config.audio
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.config = config
def compute_or_load(self, wav_file): def compute_or_load(self, wav_file):
""" """
@ -226,15 +226,15 @@ class VITSF0Dataset(F0Dataset):
""" """
pitch_file = self.create_pitch_file_path(wav_file, self.cache_path) pitch_file = self.create_pitch_file_path(wav_file, self.cache_path)
if not os.path.exists(pitch_file): if not os.path.exists(pitch_file):
pitch = self._compute_and_save_pitch(wav_file, pitch_file) pitch = self._compute_and_save_pitch(wav_file, self.audio_config.sample_rate, self.audio_config.hop_length, pitch_file)
else: else:
pitch = np.load(pitch_file) pitch = np.load(pitch_file)
return pitch.astype(np.float32) return pitch.astype(np.float32)
def _compute_and_save_pitch(self, wav_file, pitch_file=None): @staticmethod
print(wav_file, pitch_file) def _compute_and_save_pitch(wav_file, sample_rate, hop_length, pitch_file=None):
wav, _ = load_audio(wav_file) wav, _ = load_audio(wav_file)
pitch = compute_f0(wav.squeeze().numpy(), self.config.audio.sample_rate, self.config.audio.hop_length) pitch = compute_pitch(wav.squeeze().numpy(), sample_rate, hop_length)
if pitch_file: if pitch_file:
np.save(pitch_file, pitch) np.save(pitch_file, pitch)
return pitch return pitch
@ -242,11 +242,14 @@ class VITSF0Dataset(F0Dataset):
class VitsDataset(TTSDataset): class VitsDataset(TTSDataset):
def __init__(self, config, *args, **kwargs): def __init__(self, config, compute_pitch=False, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.pad_id = self.tokenizer.characters.pad_id self.pad_id = self.tokenizer.characters.pad_id
self.compute_pitch = compute_pitch
self.f0_dataset = VITSF0Dataset(config, if self.compute_pitch:
self.f0_dataset = VITSF0Dataset(config,
samples=self.samples, ap=self.ap, cache_path=self.f0_cache_path, precompute_num_workers=self.precompute_num_workers samples=self.samples, ap=self.ap, cache_path=self.f0_cache_path, precompute_num_workers=self.precompute_num_workers
) )
@ -261,7 +264,7 @@ class VitsDataset(TTSDataset):
# get f0 values # get f0 values
f0 = None f0 = None
if self.compute_f0: if self.compute_pitch:
f0 = self.get_f0(idx)["f0"] f0 = self.get_f0(idx)["f0"]
# after phonemization the text length may change # after phonemization the text length may change
@ -335,7 +338,7 @@ class VitsDataset(TTSDataset):
# format F0 # format F0
if self.compute_f0: if self.compute_pitch:
pitch = prepare_data(batch["pitch"]) pitch = prepare_data(batch["pitch"])
pitch = torch.FloatTensor(pitch)[:, None, :].contiguous() # B x 1 xT pitch = torch.FloatTensor(pitch)[:, None, :].contiguous() # B x 1 xT
else: else:
@ -592,6 +595,7 @@ class VitsArgs(Coqpit):
prosody_embedding_dim: int = 0 prosody_embedding_dim: int = 0
prosody_encoder_num_heads: int = 1 prosody_encoder_num_heads: int = 1
prosody_encoder_num_tokens: int = 5 prosody_encoder_num_tokens: int = 5
use_prosody_enc_spk_reversal_classifier: bool = True
# Pitch predictor # Pitch predictor
use_pitch: bool = False use_pitch: bool = False
@ -739,12 +743,6 @@ class Vits(BaseTTS):
self.args.pitch_predictor_dropout_p, self.args.pitch_predictor_dropout_p,
cond_channels=dp_cond_embedding_dim, cond_channels=dp_cond_embedding_dim,
) )
self.pitch_emb = nn.Conv1d(
1,
self.args.hidden_channels,
kernel_size=self.args.pitch_embedding_kernel_size,
padding=int((self.args.pitch_embedding_kernel_size - 1) / 2),
)
if self.args.use_prosody_encoder: if self.args.use_prosody_encoder:
self.prosody_encoder = GST( self.prosody_encoder = GST(
@ -753,11 +751,12 @@ class Vits(BaseTTS):
num_style_tokens=self.args.prosody_encoder_num_tokens, num_style_tokens=self.args.prosody_encoder_num_tokens,
gst_embedding_dim=self.args.prosody_embedding_dim, gst_embedding_dim=self.args.prosody_embedding_dim,
) )
self.speaker_reversal_classifier = ReversalClassifier( if self.args.use_prosody_enc_spk_reversal_classifier:
in_channels=self.args.prosody_embedding_dim, self.speaker_reversal_classifier = ReversalClassifier(
out_channels=self.num_speakers, in_channels=self.args.prosody_embedding_dim,
hidden_channels=256, out_channels=self.num_speakers,
) hidden_channels=256,
)
self.waveform_decoder = HifiganGenerator( self.waveform_decoder = HifiganGenerator(
self.args.hidden_channels, self.args.hidden_channels,
@ -1020,10 +1019,9 @@ class Vits(BaseTTS):
x_mask, x_mask,
g=g_pp.detach() if self.args.detach_pp_input and g_pp is not None else g_pp g=g_pp.detach() if self.args.detach_pp_input and g_pp is not None else g_pp
) )
print(o_pitch.shape, pitch.shape, dr.shape)
avg_pitch = average_over_durations(pitch, dr.squeeze()) avg_pitch = average_over_durations(pitch, dr.squeeze())
o_pitch_emb = self.pitch_emb(avg_pitch) pitch_loss = torch.sum(torch.sum((avg_pitch - o_pitch) ** 2, [1, 2]) / torch.sum(x_mask))
pitch_loss = torch.sum(torch.sum((o_pitch_emb - o_pitch) ** 2, [1, 2]) / torch.sum(x_mask))
return pitch_loss return pitch_loss
def forward_mas(self, outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g, lang_emb): def forward_mas(self, outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g, lang_emb):
@ -1137,7 +1135,8 @@ class Vits(BaseTTS):
l_pros_speaker = None l_pros_speaker = None
if self.args.use_prosody_encoder: if self.args.use_prosody_encoder:
pros_emb = self.prosody_encoder(z).transpose(1, 2) pros_emb = self.prosody_encoder(z).transpose(1, 2)
_, l_pros_speaker = self.speaker_reversal_classifier(pros_emb.transpose(1, 2), sid, x_mask=None) if self.args.use_prosody_enc_spk_reversal_classifier:
_, l_pros_speaker = self.speaker_reversal_classifier(pros_emb.transpose(1, 2), sid, x_mask=None)
x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb, emo_emb=eg, pros_emb=pros_emb) x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb, emo_emb=eg, pros_emb=pros_emb)
@ -1160,6 +1159,7 @@ class Vits(BaseTTS):
outputs, attn = self.forward_mas(outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g=g_dp, lang_emb=lang_emb) outputs, attn = self.forward_mas(outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g=g_dp, lang_emb=lang_emb)
pitch_loss = None
if self.args.use_pitch: if self.args.use_pitch:
pitch_loss = self.forward_pitch_predictor(x, x_mask, pitch, attn.sum(3), g_dp) pitch_loss = self.forward_pitch_predictor(x, x_mask, pitch, attn.sum(3), g_dp)
@ -1781,7 +1781,7 @@ class Vits(BaseTTS):
verbose=verbose, verbose=verbose,
tokenizer=self.tokenizer, tokenizer=self.tokenizer,
start_by_longest=config.start_by_longest, start_by_longest=config.start_by_longest,
compute_f0=config.get("compute_f0", False), compute_pitch=config.get("compute_pitch", False),
f0_cache_path=config.get("f0_cache_path", None), f0_cache_path=config.get("f0_cache_path", None),
) )

View File

@ -25,7 +25,7 @@ config = VitsConfig(
epochs=1, epochs=1,
print_step=1, print_step=1,
print_eval=True, print_eval=True,
compute_f0=True, compute_pitch=True,
f0_cache_path="tests/data/ljspeech/f0_cache/", f0_cache_path="tests/data/ljspeech/f0_cache/",
test_sentences=[ test_sentences=[
["Be a voice, not an echo.", "ljspeech-1", "tests/data/ljspeech/wavs/LJ001-0001.wav", None, None], ["Be a voice, not an echo.", "ljspeech-1", "tests/data/ljspeech/wavs/LJ001-0001.wav", None, None],