mirror of https://github.com/coqui-ai/TTS.git

Bug fix on pre-compute F0

parent 6a573065f4
commit 6186da855f
@@ -157,7 +157,7 @@ class VitsConfig(BaseTTSConfig):
     d_vector_dim: int = None

     # dataset configs
-    compute_f0: bool = False
+    compute_pitch: bool = False
     f0_cache_path: str = None

     def __post_init__(self):
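The config flag is renamed from compute_f0 to compute_pitch. A minimal sketch of how the renamed field would be set in a training config; the import path follows upstream coqui-ai/TTS conventions and, like the surrounding values, is an assumption rather than part of this commit:

from TTS.tts.configs.vits_config import VitsConfig  # import path is an assumption

config = VitsConfig(
    compute_pitch=True,          # renamed in this commit; was `compute_f0`
    f0_cache_path="f0_cache/",   # precomputed pitch arrays are cached here
)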
@@ -207,7 +207,7 @@ def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fm
     spec = amp_to_db(spec)
     return spec

-def compute_f0(x: np.ndarray, sample_rate, hop_length, pitch_fmax=800.0) -> np.ndarray:
+def compute_pitch(x: np.ndarray, sample_rate, hop_length, pitch_fmax=800.0) -> np.ndarray:
    """Compute pitch (f0) of a waveform using the same parameters used for computing melspectrogram.

    Args:
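For context, the renamed helper extracts per-frame f0 aligned with the melspectrogram frames. A minimal sketch of a pYIN-based implementation matching this signature; the librosa call, the fmin bound, and the frame length are assumptions, not the body shipped in this commit:

import librosa
import numpy as np

def compute_pitch(x: np.ndarray, sample_rate, hop_length, pitch_fmax=800.0) -> np.ndarray:
    # pYIN returns NaN for unvoiced frames plus a voiced/unvoiced flag.
    f0, voiced_flag, _ = librosa.pyin(
        x.astype(np.float32),
        fmin=librosa.note_to_hz("C2"),  # assumed lower bound (~65 Hz)
        fmax=pitch_fmax,
        sr=sample_rate,
        frame_length=4 * hop_length,    # assumption: a window of 4 hops
        hop_length=hop_length,
    )
    f0[~voiced_flag] = 0.0  # zero out unvoiced frames so downstream code gets finite values
    return f0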
@@ -235,8 +235,8 @@ def compute_f0(x: np.ndarray, sample_rate, hop_length, pitch_fmax=800.0) -> np.ndarray:

 class VITSF0Dataset(F0Dataset):
     def __init__(self, config, *args, **kwargs):
+        self.audio_config = config.audio
         super().__init__(*args, **kwargs)
-        self.config = config

     def compute_or_load(self, wav_file):
         """
@@ -244,15 +244,15 @@ class VITSF0Dataset(F0Dataset):
         """
         pitch_file = self.create_pitch_file_path(wav_file, self.cache_path)
         if not os.path.exists(pitch_file):
-            pitch = self._compute_and_save_pitch(wav_file, pitch_file)
+            pitch = self._compute_and_save_pitch(wav_file, self.audio_config.sample_rate, self.audio_config.hop_length, pitch_file)
         else:
             pitch = np.load(pitch_file)
         return pitch.astype(np.float32)

-    def _compute_and_save_pitch(self, wav_file, pitch_file=None):
-        print(wav_file, pitch_file)
+    @staticmethod
+    def _compute_and_save_pitch(wav_file, sample_rate, hop_length, pitch_file=None):
         wav, _ = load_audio(wav_file)
-        pitch = compute_f0(wav.squeeze().numpy(), self.config.audio.sample_rate, self.config.audio.hop_length)
+        pitch = compute_pitch(wav.squeeze().numpy(), sample_rate, hop_length)
         if pitch_file:
             np.save(pitch_file, pitch)
         return pitch
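This hunk, together with the previous one, is the heart of the fix: F0Dataset.__init__ can kick off pitch precomputation (via precompute_num_workers), likely before the old self.config = config assignment had run, so _compute_and_save_pitch crashed when it reached for self.config.audio. Storing config.audio before calling super().__init__() and passing the audio parameters explicitly, with the helper now a @staticmethod that never touches self, removes the ordering hazard; the leftover debug print goes away too. A hypothetical direct call, with illustrative paths and common LJSpeech audio settings:

# Hypothetical usage; the paths and audio values are illustrative.
pitch = VITSF0Dataset._compute_and_save_pitch(
    wav_file="tests/data/ljspeech/wavs/LJ001-0001.wav",
    sample_rate=22050,   # e.g. LJSpeech
    hop_length=256,
    pitch_file="tests/data/ljspeech/f0_cache/LJ001-0001.npy",
)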
@@ -260,11 +260,13 @@ class VITSF0Dataset(F0Dataset):


 class VitsDataset(TTSDataset):
-    def __init__(self, model_args, config, *args, **kwargs):
+    def __init__(self, model_args, config, compute_pitch=False, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.pad_id = self.tokenizer.characters.pad_id
         self.model_args = model_args
+        self.compute_pitch = compute_pitch

+        if self.compute_pitch:
             self.f0_dataset = VITSF0Dataset(config,
                 samples=self.samples, ap=self.ap, cache_path=self.f0_cache_path, precompute_num_workers=self.precompute_num_workers
             )
@@ -284,7 +286,7 @@ class VitsDataset(TTSDataset):

         # get f0 values
         f0 = None
-        if self.compute_f0:
+        if self.compute_pitch:
             f0 = self.get_f0(idx)["f0"]

         # after phonemization the text length may change
@@ -358,7 +360,7 @@ class VitsDataset(TTSDataset):


         # format F0
-        if self.compute_f0:
+        if self.compute_pitch:
             pitch = prepare_data(batch["pitch"])
             pitch = torch.FloatTensor(pitch)[:, None, :].contiguous()  # B x 1 xT
         else:
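For reference, the collate step above turns a list of variable-length f0 tracks into one padded tensor with an explicit channel dimension. A self-contained sketch of the shaping, assuming prepare_data zero-pads each track to the longest item in the batch (how the upstream helper appears to behave):

import numpy as np
import torch

# Two dummy f0 tracks of different lengths stand in for batch["pitch"].
batch_pitch = [np.random.rand(120).astype(np.float32), np.random.rand(95).astype(np.float32)]

# Zero-pad to the longest track, mimicking prepare_data().
max_len = max(len(p) for p in batch_pitch)
padded = np.stack([np.pad(p, (0, max_len - len(p))) for p in batch_pitch])

pitch = torch.FloatTensor(padded)[:, None, :].contiguous()
print(pitch.shape)  # torch.Size([2, 1, 120]) -> B x 1 x T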
@@ -816,12 +818,6 @@ class Vits(BaseTTS):
                 self.args.pitch_predictor_dropout_p,
                 cond_channels=dp_cond_embedding_dim,
             )
-            self.pitch_emb = nn.Conv1d(
-                1,
-                self.args.hidden_channels,
-                kernel_size=self.args.pitch_embedding_kernel_size,
-                padding=int((self.args.pitch_embedding_kernel_size - 1) / 2),
-            )

         if self.args.use_prosody_encoder:
             if self.args.use_pros_enc_input_as_pros_emb:
@@ -1226,10 +1222,9 @@ class Vits(BaseTTS):
             x_mask,
             g=g_pp.detach() if self.args.detach_pp_input and g_pp is not None else g_pp
         )
-        print(o_pitch.shape, pitch.shape, dr.shape)
         avg_pitch = average_over_durations(pitch, dr.squeeze())
-        o_pitch_emb = self.pitch_emb(avg_pitch)
-        pitch_loss = torch.sum(torch.sum((o_pitch_emb - o_pitch) ** 2, [1, 2]) / torch.sum(x_mask))
+        pitch_loss = torch.sum(torch.sum((avg_pitch - o_pitch) ** 2, [1, 2]) / torch.sum(x_mask))
         return pitch_loss

     def forward_mas(self, outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g, lang_emb):
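The old loss compared the predicted pitch o_pitch against self.pitch_emb(avg_pitch), i.e. against a hidden-channel embedding of the target rather than the target itself; with the nn.Conv1d removed above, the predictor is now supervised directly in pitch space, where avg_pitch pools frame-level f0 to one value per input token using the durations. A self-contained sketch of the fixed loss; the tensor shapes are assumptions based on the B x 1 x T convention used in the collate step:

import torch

def masked_pitch_loss(avg_pitch: torch.Tensor, o_pitch: torch.Tensor, x_mask: torch.Tensor) -> torch.Tensor:
    # avg_pitch, o_pitch: [B, 1, T_tokens]; x_mask: [B, 1, T_tokens].
    # Squared error summed over channel and time, normalized by the number
    # of valid (unmasked) positions, then summed over the batch.
    return torch.sum(torch.sum((avg_pitch - o_pitch) ** 2, [1, 2]) / torch.sum(x_mask))

# Illustrative call with dummy tensors:
B, T = 2, 7
print(masked_pitch_loss(torch.rand(B, 1, T), torch.rand(B, 1, T), torch.ones(B, 1, T)))  # scalar loss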
@@ -1432,6 +1427,7 @@ class Vits(BaseTTS):

         outputs, attn = self.forward_mas(outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g=g_dp, lang_emb=lang_emb)

+        pitch_loss = None
         if self.args.use_pitch:
             pitch_loss = self.forward_pitch_predictor(x, x_mask, pitch, attn.sum(3), g_dp)
@@ -2249,7 +2245,7 @@ class Vits(BaseTTS):
             verbose=verbose,
             tokenizer=self.tokenizer,
             start_by_longest=config.start_by_longest,
-            compute_f0=config.get("compute_f0", False),
+            compute_pitch=config.get("compute_pitch", False),
             f0_cache_path=config.get("f0_cache_path", None),
         )

@@ -25,7 +25,7 @@ config = VitsConfig(
     epochs=1,
     print_step=1,
     print_eval=True,
-    compute_f0=True,
+    compute_pitch=True,
     f0_cache_path="tests/data/ljspeech/f0_cache/",
     test_sentences=[
         ["Be a voice, not an echo.", "ljspeech-1", "tests/data/ljspeech/wavs/LJ001-0001.wav", None, None],