mirror of https://github.com/coqui-ai/TTS.git
Fix up

commit 2d29e8219d
parent 8e915b70e0
@@ -89,7 +89,10 @@ class ForwardTTSE2eF0Dataset(F0Dataset):
     def _compute_and_save_pitch(audio_config, wav_file, pitch_file=None):
         wav, _ = load_audio(wav_file)
         f0 = compute_f0(
-            x=wav.numpy()[0], sample_rate=audio_config.sample_rate, hop_length=audio_config.hop_length, pitch_fmax=audio_config.pitch_fmax
+            x=wav.numpy()[0],
+            sample_rate=audio_config.sample_rate,
+            hop_length=audio_config.hop_length,
+            pitch_fmax=audio_config.pitch_fmax,
         )
         # skip the last F0 value to align with the spectrogram
         if wav.shape[1] % audio_config.hop_length != 0:
@@ -104,7 +107,9 @@ class ForwardTTSE2eF0Dataset(F0Dataset):
         """
         pitch_file = self.create_pitch_file_path(wav_file, self.cache_path)
         if not os.path.exists(pitch_file):
-            pitch = self._compute_and_save_pitch(audio_config=self.audio_config, wav_file=wav_file, pitch_file=pitch_file)
+            pitch = self._compute_and_save_pitch(
+                audio_config=self.audio_config, wav_file=wav_file, pitch_file=pitch_file
+            )
         else:
             pitch = np.load(pitch_file)
         return pitch.astype(np.float32)
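The hunk above keeps the compute-or-load caching scheme: the F0 track is extracted once per wav file, saved to disk, and reloaded on later epochs. A minimal standalone sketch of that pattern, with a made-up cache file name and a generic compute_fn callback standing in for the real pitch extractor:

import os

import numpy as np


def compute_or_load_pitch(wav_file, cache_path, compute_fn):
    # hypothetical cache naming; the real code builds the path via create_pitch_file_path()
    pitch_file = os.path.join(cache_path, os.path.basename(wav_file) + "_pitch.npy")
    if not os.path.exists(pitch_file):
        pitch = compute_fn(wav_file)  # expensive: run the F0 extractor once
        np.save(pitch_file, pitch)    # cache for subsequent epochs
    else:
        pitch = np.load(pitch_file)   # cheap: reuse the cached track
    return pitch.astype(np.float32)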
@@ -300,6 +305,9 @@ class ForwardTTSE2eArgs(ForwardTTSArgs):
     upsample_rates_decoder: List[int] = field(default_factory=lambda: [8, 8, 2, 2])
     upsample_initial_channel_decoder: int = 512
     upsample_kernel_sizes_decoder: List[int] = field(default_factory=lambda: [16, 16, 4, 4])
+    # discriminator
+    upsampling_rates_discriminator: List[int] = field(default_factory=lambda: [4, 4, 4, 4])
+    periods_discriminator: List[int] = field(default_factory=lambda: [2, 3, 5, 7, 11])
     # multi-speaker params
     use_speaker_embedding: bool = False
     num_speakers: int = 0
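Both new fields go through field(default_factory=...) because mutable list defaults cannot be shared across dataclass instances. A self-contained sketch of the same pattern with an invented class name (ForwardTTSE2eArgs itself extends ForwardTTSArgs in the repo):

from dataclasses import dataclass, field
from typing import List


@dataclass
class DiscriminatorArgsSketch:
    # default_factory builds a fresh list per instance instead of sharing one
    upsampling_rates_discriminator: List[int] = field(default_factory=lambda: [4, 4, 4, 4])
    periods_discriminator: List[int] = field(default_factory=lambda: [2, 3, 5, 7, 11])


args = DiscriminatorArgsSketch(periods_discriminator=[2, 3, 5, 7])  # override at construction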
@@ -359,7 +367,18 @@ class ForwardTTSE2e(BaseTTSE2E):

         # use Vits Discriminator for limiting VRAM use
         if self.args.init_discriminator:
-            self.disc = VitsDiscriminator(use_spectral_norm=self.args.use_spectral_norm_discriminator)
+            self.disc = VitsDiscriminator(
+                use_spectral_norm=self.args.use_spectral_norm_discriminator,
+                periods=self.args.periods_discriminator,
+                upsampling_rates=self.args.upsampling_rates_discriminator,
+            )
+
+    # def check_model_args(self):
+    #     upsample_rate = torch.prod(torch.as_tensor(self.args.upsample_rates_decoder)).item()
+    #     if s
+    #     assert (
+    #         upsample_rate == self.config.audio.hop_length
+    #     ), f" [!] Product of upsample rates must be equal to the hop length - {upsample_rate} vs {config.audio.hop_length}"

     def init_multispeaker(self, config: Coqpit):
         """Init for multi-speaker training.
@@ -440,9 +459,10 @@ class ForwardTTSE2e(BaseTTSE2E):
             let_short_samples=True,
             pad_short=True,
         )
+
         vocoder_output = self.waveform_decoder(
             x=o_en_ex_slices.detach() if self.args.detach_vocoder_input else o_en_ex_slices,
-            g=encoder_outputs["g"],
+            g=encoder_outputs["spk_emb"],
         )
         wav_seg = segment(
             waveform,
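In the hunk above, o_en_ex_slices.detach() optionally cuts the gradient path from the vocoder loss back into the encoder when detach_vocoder_input is set. A toy illustration of the conditional detach, with a made-up tensor and flag:

import torch

detach_vocoder_input = True
encoder_out = torch.randn(2, 192, 32, requires_grad=True)  # stand-in for o_en_ex_slices

vocoder_in = encoder_out.detach() if detach_vocoder_input else encoder_out
print(vocoder_in.requires_grad)  # False when detached: vocoder gradients stop here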
@@ -461,7 +481,7 @@ class ForwardTTSE2e(BaseTTSE2E):
     def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None}):
         encoder_outputs = self.encoder_model.inference(x=x, aux_input=aux_input, skip_decoder=True)
         o_en_ex = encoder_outputs["o_en_ex"]
-        vocoder_output = self.waveform_decoder(x=o_en_ex, g=encoder_outputs["g"])
+        vocoder_output = self.waveform_decoder(x=o_en_ex, g=encoder_outputs["spk_emb"])
         model_outputs = {**encoder_outputs}
         model_outputs["model_outputs"] = vocoder_output
         return model_outputs
@@ -860,9 +880,11 @@ class ForwardTTSE2e(BaseTTSE2E):
             center=False,
         )

-        assert (
-            batch["pitch"].shape[2] == batch["mel_input"].shape[2]
-        ), f"{batch['pitch'].shape[2]}, {batch['mel'].shape[2]}"
+        # TODO: Align pitch properly
+        # assert (
+        #     batch["pitch"].shape[2] == batch["mel_input"].shape[2]
+        # ), f"{batch['pitch'].shape[2]}, {batch['mel_input'].shape[2]}"
+        batch["pitch"] = batch["pitch"][:, :, : batch["mel_input"].shape[2]]
         batch["mel_lengths"] = (batch["mel_input"].shape[2] * batch["waveform_rel_lens"]).int()

         # zero the padding frames
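The replaced assertion becomes a length clamp: rather than requiring the pitch track and the mel spectrogram to have the same number of frames, the pitch tensor is trimmed to the mel frame count. A toy illustration with made-up tensor sizes:

import torch

pitch = torch.randn(2, 1, 101)       # (batch, 1, pitch_frames): one frame too long
mel = torch.randn(2, 80, 100)        # (batch, n_mels, mel_frames)
pitch = pitch[:, :, : mel.shape[2]]  # trim to the mel frame count -> (2, 1, 100)
assert pitch.shape[2] == mel.shape[2]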
@@ -990,9 +1012,7 @@ class ForwardTTSE2e(BaseTTSE2E):
         # language_manager = LanguageManager.init_from_config(config)
         return ForwardTTSE2e(config=new_config, tokenizer=tokenizer, speaker_manager=speaker_manager)

-    def load_checkpoint(
-        self, config, checkpoint_path, eval=False
-    ):
+    def load_checkpoint(self, config, checkpoint_path, eval=False):
         """Load model from a checkpoint created by the 👟"""
         # pylint: disable=unused-argument, redefined-builtin
         state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
@@ -1003,11 +1023,7 @@ class ForwardTTSE2e(BaseTTSE2E):

     def get_state_dict(self):
         """Custom state dict of the model with all the necessary components for inference."""
-        save_state = {
-            "config": self.config.to_dict(),
-            "args": self.args.to_dict(),
-            "model": self.state_dict
-        }
+        save_state = {"config": self.config.to_dict(), "args": self.args.to_dict(), "model": self.state_dict}

         if hasattr(self, "emb_g"):
             save_state["speaker_ids"] = self.speaker_manager.speaker_ids