mirror of https://github.com/coqui-ai/TTS.git
Add upsampling_init and upsampling_z methods
This commit is contained in:
parent
b3e2c58398
commit
ce7138d9d4
|
@ -566,6 +566,7 @@ class Vits(BaseTTS):
|
|||
|
||||
self.init_multispeaker(config)
|
||||
self.init_multilingual(config)
|
||||
self.init_upsampling()
|
||||
|
||||
self.length_scale = self.args.length_scale
|
||||
self.noise_scale = self.args.noise_scale
|
||||
|
@ -648,12 +649,6 @@ class Vits(BaseTTS):
|
|||
use_spectral_norm=self.args.use_spectral_norm_disriminator,
|
||||
)
|
||||
|
||||
if self.args.encoder_sample_rate:
|
||||
self.interpolate_factor = self.config.audio["sample_rate"] / self.args.encoder_sample_rate
|
||||
self.audio_resampler = torchaudio.transforms.Resample(
|
||||
orig_freq=self.config.audio["sample_rate"], new_freq=self.args.encoder_sample_rate
|
||||
)
|
||||
|
||||
def init_multispeaker(self, config: Coqpit):
|
||||
"""Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer
|
||||
or with external `d_vectors` computed from a speaker encoder model.
|
||||
|
@ -734,6 +729,16 @@ class Vits(BaseTTS):
|
|||
else:
|
||||
self.embedded_language_dim = 0
|
||||
|
||||
def init_upsampling(self):
|
||||
"""
|
||||
Initialize upsampling modules of a model.
|
||||
"""
|
||||
if self.args.encoder_sample_rate:
|
||||
self.interpolate_factor = self.config.audio["sample_rate"] / self.args.encoder_sample_rate
|
||||
self.audio_resampler = torchaudio.transforms.Resample(
|
||||
orig_freq=self.config.audio["sample_rate"], new_freq=self.args.encoder_sample_rate
|
||||
)
|
||||
|
||||
def get_aux_input(self, aux_input: Dict):
|
||||
sid, g, lid = self._set_cond_input(aux_input)
|
||||
return {"speaker_ids": sid, "style_wav": None, "d_vectors": g, "language_ids": lid}
|
||||
|
@ -831,6 +836,25 @@ class Vits(BaseTTS):
|
|||
outputs["loss_duration"] = loss_duration
|
||||
return outputs, attn
|
||||
|
||||
def upsampling_z(self, z, slice_ids=None, y_lengths=None, y_mask=None):
|
||||
spec_segment_size = self.spec_segment_size
|
||||
if self.args.encoder_sample_rate:
|
||||
# recompute the slices and spec_segment_size if needed
|
||||
slice_ids = slice_ids * int(self.interpolate_factor) if slice_ids is not None else slice_ids
|
||||
spec_segment_size = spec_segment_size * int(self.interpolate_factor)
|
||||
# interpolate z if needed
|
||||
if self.args.interpolate_z:
|
||||
z = torch.nn.functional.interpolate(
|
||||
z.unsqueeze(0), scale_factor=[1, self.interpolate_factor], mode="nearest"
|
||||
).squeeze(0)
|
||||
# recompute the mask if needed
|
||||
if y_lengths is not None and y_mask is not None:
|
||||
y_mask = (
|
||||
sequence_mask(y_lengths * self.interpolate_factor, None).to(y_mask.dtype).unsqueeze(1)
|
||||
) # [B, 1, T_dec_resampled]
|
||||
|
||||
return z, spec_segment_size, slice_ids, y_mask
|
||||
|
||||
def forward( # pylint: disable=dangerous-default-value
|
||||
self,
|
||||
x: torch.tensor,
|
||||
|
@ -906,16 +930,8 @@ class Vits(BaseTTS):
|
|||
# select a random feature segment for the waveform decoder
|
||||
z_slice, slice_ids = rand_segments(z, y_lengths, self.spec_segment_size, let_short_samples=True, pad_short=True)
|
||||
|
||||
if self.args.encoder_sample_rate:
|
||||
slice_ids = slice_ids * int(self.interpolate_factor)
|
||||
spec_segment_size = self.spec_segment_size * int(self.interpolate_factor)
|
||||
if self.args.interpolate_z:
|
||||
z_slice = z_slice.unsqueeze(0) # pylint: disable=not-callable
|
||||
z_slice = torch.nn.functional.interpolate(
|
||||
z_slice, scale_factor=[1, self.interpolate_factor], mode="nearest"
|
||||
).squeeze(0)
|
||||
else:
|
||||
spec_segment_size = self.spec_segment_size
|
||||
# interpolate z if needed
|
||||
z_slice, spec_segment_size, slice_ids, _ = self.upsampling_z(z_slice, slice_ids=slice_ids)
|
||||
|
||||
o = self.waveform_decoder(z_slice, g=g)
|
||||
|
||||
|
@ -1029,12 +1045,8 @@ class Vits(BaseTTS):
|
|||
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * self.inference_noise_scale
|
||||
z = self.flow(z_p, y_mask, g=g, reverse=True)
|
||||
|
||||
if self.args.encoder_sample_rate and self.args.interpolate_z:
|
||||
z = z.unsqueeze(0) # pylint: disable=not-callable
|
||||
z = torch.nn.functional.interpolate(z, scale_factor=[1, self.interpolate_factor], mode="nearest").squeeze(0)
|
||||
y_mask = (
|
||||
sequence_mask(y_lengths * self.interpolate_factor, None).to(y_mask.dtype).unsqueeze(1)
|
||||
) # [B, 1, T_dec_resampled]
|
||||
# upsampling if needed
|
||||
z, _, _, y_mask = self.upsampling_z(z, y_lengths=y_lengths, y_mask=y_mask)
|
||||
|
||||
o = self.waveform_decoder((z * y_mask)[:, :, : self.max_inference_len], g=g)
|
||||
|
||||
|
|
Loading…
Reference in New Issue