mirror of https://github.com/coqui-ai/TTS.git
Rename TTS_part_sample_rate to encoder_sample_rate
This commit is contained in:
parent 3f3efe88bb
commit b3e2c58398
@@ -455,14 +455,14 @@ class VitsArgs(Coqpit):
         freeze_waveform_decoder (bool):
             Freeze the waveform decoder weigths during training. Defaults to False.

-        TTS_part_sample_rate (int):
+        encoder_sample_rate (int):
             If not None this sample rate will be used for training the Posterior Encoder,
             flow, text_encoder and duration predictor. The decoder part (vocoder) will be
             trained with the `config.audio.sample_rate`. Defaults to None.

         interpolate_z (bool):
-            If `TTS_part_sample_rate` not None and this parameter True the nearest interpolation
-            will be used to upsampling the latent variable z with the sampling rate `TTS_part_sample_rate`
+            If `encoder_sample_rate` not None and this parameter True the nearest interpolation
+            will be used to upsampling the latent variable z with the sampling rate `encoder_sample_rate`
             to the `config.audio.sample_rate`. If it is False you will need to add extra
             `upsample_rates_decoder` to match the shape. Defaults to True.
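For readers following the rename: the two options documented above are set through VitsArgs (or config.model_args). Below is a minimal sketch of enabling the two-rate setup, assuming the usual VitsConfig/VitsArgs import paths and a 22050 Hz decoder rate; only encoder_sample_rate, interpolate_z, and config.audio.sample_rate come from this diff, everything else is illustrative.

# Sketch only: not part of this commit; import paths and concrete rates are assumptions.
from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.models.vits import VitsArgs

model_args = VitsArgs(
    encoder_sample_rate=11025,  # rate for posterior encoder, flow, text encoder, duration predictor
    interpolate_z=True,         # stretch z by nearest interpolation up to the decoder rate
)
config = VitsConfig(model_args=model_args)
config.audio.sample_rate = 22050  # the waveform decoder (vocoder) trains at this rate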
@@ -521,7 +521,7 @@ class VitsArgs(Coqpit):
     freeze_PE: bool = False
     freeze_flow_decoder: bool = False
     freeze_waveform_decoder: bool = False
-    TTS_part_sample_rate: int = None
+    encoder_sample_rate: int = None
     interpolate_z: bool = True
@@ -648,10 +648,10 @@ class Vits(BaseTTS):
             use_spectral_norm=self.args.use_spectral_norm_disriminator,
         )

-        if self.args.TTS_part_sample_rate:
-            self.interpolate_factor = self.config.audio["sample_rate"] / self.args.TTS_part_sample_rate
+        if self.args.encoder_sample_rate:
+            self.interpolate_factor = self.config.audio["sample_rate"] / self.args.encoder_sample_rate
             self.audio_resampler = torchaudio.transforms.Resample(
-                orig_freq=self.config.audio["sample_rate"], new_freq=self.args.TTS_part_sample_rate
+                orig_freq=self.config.audio["sample_rate"], new_freq=self.args.encoder_sample_rate
             )

     def init_multispeaker(self, config: Coqpit):
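A standalone illustration of what the renamed branch computes, assuming a 22050 Hz decoder rate and an 11025 Hz encoder rate as in the tests at the bottom of this commit; this is a sketch, not code from the repository.

import torch
import torchaudio

sample_rate = 22050          # config.audio["sample_rate"], the decoder (vocoder) rate
encoder_sample_rate = 11025  # args.encoder_sample_rate, the posterior-encoder rate
interpolate_factor = sample_rate / encoder_sample_rate  # 2.0

# The resampler feeds the posterior encoder with the lower-rate waveform.
resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=encoder_sample_rate)
wav = torch.randn(1, sample_rate)   # one second of audio at the decoder rate
wav_enc = resampler(wav)            # ~11025 samples used for spectrogram extraction
print(interpolate_factor, wav_enc.shape)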
@@ -906,7 +906,7 @@ class Vits(BaseTTS):
         # select a random feature segment for the waveform decoder
         z_slice, slice_ids = rand_segments(z, y_lengths, self.spec_segment_size, let_short_samples=True, pad_short=True)

-        if self.args.TTS_part_sample_rate:
+        if self.args.encoder_sample_rate:
             slice_ids = slice_ids * int(self.interpolate_factor)
             spec_segment_size = self.spec_segment_size * int(self.interpolate_factor)
             if self.args.interpolate_z:
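The index scaling above exists because the random segment is picked in encoder-rate frames but must address decoder-rate features. A toy sketch of the arithmetic (numbers are illustrative, not from the repository):

import torch

interpolate_factor = 2.0
spec_segment_size = 32                 # segment length in encoder-rate frames
slice_ids = torch.tensor([5, 17, 0])   # random segment starts in encoder-rate frames

# Re-index to the decoder-rate timeline before slicing the upsampled features.
slice_ids_full = slice_ids * int(interpolate_factor)                  # tensor([10, 34, 0])
spec_segment_size_full = spec_segment_size * int(interpolate_factor)  # 64
print(slice_ids_full.tolist(), spec_segment_size_full)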
@@ -1029,7 +1029,7 @@ class Vits(BaseTTS):
         z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * self.inference_noise_scale
         z = self.flow(z_p, y_mask, g=g, reverse=True)

-        if self.args.TTS_part_sample_rate and self.args.interpolate_z:
+        if self.args.encoder_sample_rate and self.args.interpolate_z:
             z = z.unsqueeze(0)  # pylint: disable=not-callable
             z = torch.nn.functional.interpolate(z, scale_factor=[1, self.interpolate_factor], mode="nearest").squeeze(0)
             y_mask = (
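At inference the same idea applies to the latent z itself: it is produced at the encoder rate and stretched to the decoder rate by nearest-neighbour interpolation. A self-contained sketch with made-up shapes:

import torch

interpolate_factor = 2.0
z = torch.randn(4, 192, 50)  # [batch, channels, frames] at the encoder rate

z_up = torch.nn.functional.interpolate(
    z.unsqueeze(0),                        # add a dim so channels/frames become the spatial axes
    scale_factor=[1, interpolate_factor],  # keep channels, stretch the time axis
    mode="nearest",
).squeeze(0)
print(z_up.shape)  # torch.Size([4, 192, 100]) -> decoder-rate frames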
@@ -1155,7 +1155,7 @@ class Vits(BaseTTS):
         # compute melspec segment
         with autocast(enabled=False):

-            if self.args.TTS_part_sample_rate:
+            if self.args.encoder_sample_rate:
                 spec_segment_size = self.spec_segment_size * int(self.interpolate_factor)
             else:
                 spec_segment_size = self.spec_segment_size
@@ -1370,7 +1370,7 @@ class Vits(BaseTTS):
         """Compute spectrograms on the device."""
         ac = self.config.audio

-        if self.args.TTS_part_sample_rate:
+        if self.args.encoder_sample_rate:
             wav = self.audio_resampler(batch["waveform"])
         else:
             wav = batch["waveform"]
@@ -1378,7 +1378,7 @@ class Vits(BaseTTS):
         # compute spectrograms
         batch["spec"] = wav_to_spec(wav, ac.fft_size, ac.hop_length, ac.win_length, center=False)

-        if self.args.TTS_part_sample_rate:
+        if self.args.encoder_sample_rate:
             # recompute spec with high sampling rate to the loss
             spec_mel = wav_to_spec(batch["waveform"], ac.fft_size, ac.hop_length, ac.win_length, center=False)
         else:
@@ -1393,14 +1393,14 @@ class Vits(BaseTTS):
             fmax=ac.mel_fmax,
         )

-        if not self.args.TTS_part_sample_rate:
+        if not self.args.encoder_sample_rate:
             assert batch["spec"].shape[2] == batch["mel"].shape[2], f"{batch['spec'].shape[2]}, {batch['mel'].shape[2]}"

         # compute spectrogram frame lengths
         batch["spec_lens"] = (batch["spec"].shape[2] * batch["waveform_rel_lens"]).int()
         batch["mel_lens"] = (batch["mel"].shape[2] * batch["waveform_rel_lens"]).int()

-        if not self.args.TTS_part_sample_rate:
+        if not self.args.encoder_sample_rate:
             assert (batch["spec_lens"] - batch["mel_lens"]).sum() == 0

         # zero the padding frames
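The two asserts are skipped in the two-rate case because batch["spec"] now comes from the resampled waveform while batch["mel"] still comes from the original one, so their frame counts differ by roughly the interpolation factor. Rough frame-count arithmetic, assuming fft_size=1024 and hop_length=256 (typical VITS audio settings, not stated in this diff):

# Approximate STFT frame count with center=False: (num_samples - fft_size) // hop_length + 1
fft_size, hop_length = 1024, 256
samples_full = 22050             # 1 s at config.audio.sample_rate
samples_enc = samples_full // 2  # the same second resampled to 11025 Hz

mel_frames = (samples_full - fft_size) // hop_length + 1  # 83, from the original waveform
spec_frames = (samples_enc - fft_size) // hop_length + 1  # 40, from the resampled waveform
print(mel_frames, spec_frames)  # no longer equal, hence the guarded asserts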
@@ -1518,7 +1518,7 @@ class Vits(BaseTTS):
         # as it is probably easier for model distribution.
         state["model"] = {k: v for k, v in state["model"].items() if "speaker_encoder" not in k}

-        if self.args.TTS_part_sample_rate is not None and eval:
+        if self.args.encoder_sample_rate is not None and eval:
             # audio resampler is not used in inference time
             self.audio_resampler = None
@@ -1549,7 +1549,7 @@ class Vits(BaseTTS):
         from TTS.utils.audio import AudioProcessor

         upsample_rate = torch.prod(torch.as_tensor(config.model_args.upsample_rates_decoder)).item()
-        if not config.model_args.TTS_part_sample_rate:
+        if not config.model_args.encoder_sample_rate:
             assert (
                 upsample_rate == config.audio.hop_length
             ), f" [!] Product of upsample rates must be equal to the hop length - {upsample_rate} vs {config.audio.hop_length}"
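The product check is only enforced for the single-rate setup: with encoder_sample_rate set, the decoder may need to upsample by more than one hop length (compare the test configs below). A sketch of the quantity being checked, with hop_length assumed to be 256 as in the usual VITS audio config:

import torch

upsample_rates_decoder = [8, 8, 2, 2]
hop_length = 256  # assumed value, not stated in this diff

upsample_rate = torch.prod(torch.as_tensor(upsample_rates_decoder)).item()  # 256
assert upsample_rate == hop_length  # only asserted when encoder_sample_rate is unset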
@@ -42,7 +42,7 @@ config.model_args.d_vector_dim = 256

 # test upsample interpolation approach
-config.model_args.TTS_part_sample_rate = 11025
+config.model_args.encoder_sample_rate = 11025
 config.model_args.interpolate_z = True
 config.model_args.upsample_rates_decoder = [8, 8, 2, 2]
 config.model_args.periods_multi_period_discriminator = [2, 3, 5, 7]
@@ -42,7 +42,7 @@ config.model_args.d_vector_dim = 256

 # test upsample
-config.model_args.TTS_part_sample_rate = 11025
+config.model_args.encoder_sample_rate = 11025
 config.model_args.interpolate_z = False
 config.model_args.upsample_rates_decoder = [8, 8, 4, 2]
 config.model_args.periods_multi_period_discriminator = [2, 3, 5, 7]
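The two test configs differ only in where the extra factor of two is recovered, assuming the base test config uses 22050 Hz audio with hop_length 256 (an assumption; neither value is shown in this diff):

interpolate_factor = 22050 / 11025     # 2.0
with_interpolation = 8 * 8 * 2 * 2     # 256 == hop_length; z is stretched because interpolate_z=True
without_interpolation = 8 * 8 * 4 * 2  # 512 == hop_length * 2; the decoder does the extra upsampling
print(interpolate_factor, with_interpolation, without_interpolation)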