mirror of https://github.com/coqui-ai/TTS.git
Add upsampling_init and upsampling_z methods
parent b3e2c58398
commit ce7138d9d4
@@ -566,6 +566,7 @@ class Vits(BaseTTS):
         self.init_multispeaker(config)
         self.init_multilingual(config)
+        self.init_upsampling()
 
         self.length_scale = self.args.length_scale
         self.noise_scale = self.args.noise_scale
@@ -648,12 +649,6 @@ class Vits(BaseTTS):
                 use_spectral_norm=self.args.use_spectral_norm_disriminator,
             )
 
-        if self.args.encoder_sample_rate:
-            self.interpolate_factor = self.config.audio["sample_rate"] / self.args.encoder_sample_rate
-            self.audio_resampler = torchaudio.transforms.Resample(
-                orig_freq=self.config.audio["sample_rate"], new_freq=self.args.encoder_sample_rate
-            )
-
     def init_multispeaker(self, config: Coqpit):
         """Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer
         or with external `d_vectors` computed from a speaker encoder model.
@@ -734,6 +729,16 @@ class Vits(BaseTTS):
         else:
             self.embedded_language_dim = 0
 
+    def init_upsampling(self):
+        """
+        Initialize upsampling modules of a model.
+        """
+        if self.args.encoder_sample_rate:
+            self.interpolate_factor = self.config.audio["sample_rate"] / self.args.encoder_sample_rate
+            self.audio_resampler = torchaudio.transforms.Resample(
+                orig_freq=self.config.audio["sample_rate"], new_freq=self.args.encoder_sample_rate
+            )
+
     def get_aux_input(self, aux_input: Dict):
         sid, g, lid = self._set_cond_input(aux_input)
         return {"speaker_ids": sid, "style_wav": None, "d_vectors": g, "language_ids": lid}
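For orientation, `init_upsampling` only sets state when `encoder_sample_rate` is configured: `interpolate_factor` is the ratio of the output audio rate to the encoder rate, and `audio_resampler` downsamples training audio to the encoder rate. A minimal standalone sketch of the same arithmetic, using hypothetical rates (44100 Hz output, 22050 Hz encoder) that are not taken from this commit:

import torch
import torchaudio

# Hypothetical config values, for illustration only.
sample_rate = 44100          # model output rate, i.e. config.audio["sample_rate"]
encoder_sample_rate = 22050  # rate the encoder / spectrogram side runs at

interpolate_factor = sample_rate / encoder_sample_rate  # -> 2.0
resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=encoder_sample_rate)

wav = torch.randn(1, sample_rate)  # one second of audio at the output rate
wav_enc = resampler(wav)           # -> shape [1, 22050], audio at the encoder rate
print(interpolate_factor, wav_enc.shape)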
@@ -831,6 +836,25 @@ class Vits(BaseTTS):
         outputs["loss_duration"] = loss_duration
         return outputs, attn
 
+    def upsampling_z(self, z, slice_ids=None, y_lengths=None, y_mask=None):
+        spec_segment_size = self.spec_segment_size
+        if self.args.encoder_sample_rate:
+            # recompute the slices and spec_segment_size if needed
+            slice_ids = slice_ids * int(self.interpolate_factor) if slice_ids is not None else slice_ids
+            spec_segment_size = spec_segment_size * int(self.interpolate_factor)
+            # interpolate z if needed
+            if self.args.interpolate_z:
+                z = torch.nn.functional.interpolate(
+                    z.unsqueeze(0), scale_factor=[1, self.interpolate_factor], mode="nearest"
+                ).squeeze(0)
+            # recompute the mask if needed
+            if y_lengths is not None and y_mask is not None:
+                y_mask = (
+                    sequence_mask(y_lengths * self.interpolate_factor, None).to(y_mask.dtype).unsqueeze(1)
+                )  # [B, 1, T_dec_resampled]
+
+        return z, spec_segment_size, slice_ids, y_mask
+
     def forward(  # pylint: disable=dangerous-default-value
         self,
         x: torch.tensor,
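The `unsqueeze`/`interpolate`/`squeeze` sequence in `upsampling_z` works because `torch.nn.functional.interpolate` scales only the trailing dimensions of a 4D input: `unsqueeze(0)` turns the `[B, C, T]` latent into `[1, B, C, T]`, `scale_factor=[1, interpolate_factor]` leaves the channel axis untouched while stretching time, and `squeeze(0)` restores the original rank. A self-contained sketch with made-up shapes:

import torch
import torch.nn.functional as F

interpolate_factor = 2.0
z = torch.randn(4, 192, 32)  # [B, C, T]: hypothetical latent at the encoder rate

# [B, C, T] -> [1, B, C, T], so (C, T) are the "spatial" dims interpolate sees;
# nearest mode with scale_factor=[1, 2.0] repeats every time step twice.
z_up = F.interpolate(z.unsqueeze(0), scale_factor=[1, interpolate_factor], mode="nearest").squeeze(0)

print(z_up.shape)  # torch.Size([4, 192, 64])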
@@ -906,16 +930,8 @@ class Vits(BaseTTS):
         # select a random feature segment for the waveform decoder
         z_slice, slice_ids = rand_segments(z, y_lengths, self.spec_segment_size, let_short_samples=True, pad_short=True)
 
-        if self.args.encoder_sample_rate:
-            slice_ids = slice_ids * int(self.interpolate_factor)
-            spec_segment_size = self.spec_segment_size * int(self.interpolate_factor)
-            if self.args.interpolate_z:
-                z_slice = z_slice.unsqueeze(0)  # pylint: disable=not-callable
-                z_slice = torch.nn.functional.interpolate(
-                    z_slice, scale_factor=[1, self.interpolate_factor], mode="nearest"
-                ).squeeze(0)
-        else:
-            spec_segment_size = self.spec_segment_size
+        # interpolate z if needed
+        z_slice, spec_segment_size, slice_ids, _ = self.upsampling_z(z_slice, slice_ids=slice_ids)
 
         o = self.waveform_decoder(z_slice, g=g)
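The index arithmetic this refactor delegates to `upsampling_z`: a random segment that starts at frame `i` of the encoder-rate latent starts at frame `i * interpolate_factor` after upsampling, and the decoder segment length grows by the same factor, so the sliced latent stays aligned with the target waveform. A toy illustration with hypothetical numbers:

import torch

interpolate_factor = 2   # e.g. int(44100 / 22050)
spec_segment_size = 32   # segment length at the encoder rate

slice_ids = torch.tensor([5, 17, 40])  # hypothetical per-sample segment starts

# After 2x nearest-neighbor upsampling, frame i maps to frame 2 * i, and the
# waveform decoder must consume a segment twice as long.
slice_ids_up = slice_ids * interpolate_factor                  # tensor([10, 34, 80])
spec_segment_size_up = spec_segment_size * interpolate_factor  # 64
print(slice_ids_up, spec_segment_size_up)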
@@ -1029,12 +1045,8 @@ class Vits(BaseTTS):
         z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * self.inference_noise_scale
         z = self.flow(z_p, y_mask, g=g, reverse=True)
 
-        if self.args.encoder_sample_rate and self.args.interpolate_z:
-            z = z.unsqueeze(0)  # pylint: disable=not-callable
-            z = torch.nn.functional.interpolate(z, scale_factor=[1, self.interpolate_factor], mode="nearest").squeeze(0)
-            y_mask = (
-                sequence_mask(y_lengths * self.interpolate_factor, None).to(y_mask.dtype).unsqueeze(1)
-            )  # [B, 1, T_dec_resampled]
-
+        # upsampling if needed
+        z, _, _, y_mask = self.upsampling_z(z, y_lengths=y_lengths, y_mask=y_mask)
 
         o = self.waveform_decoder((z * y_mask)[:, :, : self.max_inference_len], g=g)
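At inference time the decoder mask has to cover the upsampled length, which is why `upsampling_z` rebuilds `y_mask` from `y_lengths * interpolate_factor`. A rough sketch of the mask shape it produces; `sequence_mask` is reimplemented inline here so the snippet is self-contained (the commit imports it from `TTS.tts.utils.helpers`, which behaves the same way for this purpose):

import torch

def sequence_mask(lengths, max_len=None):
    # Minimal stand-in for TTS.tts.utils.helpers.sequence_mask, for illustration.
    if max_len is None:
        max_len = int(lengths.max())
    ids = torch.arange(max_len, device=lengths.device)
    return ids.unsqueeze(0) < lengths.unsqueeze(1)  # [B, max_len] boolean mask

interpolate_factor = 2.0
y_lengths = torch.tensor([30, 50])  # hypothetical decoder lengths at the encoder rate

# Upsampled mask: [B, 1, T_dec_resampled], matching the interpolated z.
y_mask = sequence_mask(y_lengths * interpolate_factor).float().unsqueeze(1)
print(y_mask.shape)  # torch.Size([2, 1, 100])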