import logging

import torch
from trainer.io import load_fsspec

from TTS.encoder.models.resnet import ResNetSpeakerEncoder
from TTS.vocoder.models.hifigan_generator import HifiganGenerator

logger = logging.getLogger(__name__)


class HifiDecoder(torch.nn.Module):
    def __init__(
        self,
        input_sample_rate=22050,
        output_sample_rate=24000,
        output_hop_length=256,
        ar_mel_length_compression=1024,
        decoder_input_dim=1024,
        resblock_type_decoder="1",
        resblock_dilation_sizes_decoder=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        resblock_kernel_sizes_decoder=[3, 7, 11],
        upsample_rates_decoder=[8, 8, 2, 2],
        upsample_initial_channel_decoder=512,
        upsample_kernel_sizes_decoder=[16, 16, 4, 4],
        d_vector_dim=512,
        cond_d_vector_in_each_upsampling_layer=True,
        speaker_encoder_audio_config={
            "fft_size": 512,
            "win_length": 400,
            "hop_length": 160,
            "sample_rate": 16000,
            "preemphasis": 0.97,
            "num_mels": 64,
        },
    ):
        super().__init__()
        self.input_sample_rate = input_sample_rate
        self.output_sample_rate = output_sample_rate
        self.output_hop_length = output_hop_length
        self.ar_mel_length_compression = ar_mel_length_compression
        self.speaker_encoder_audio_config = speaker_encoder_audio_config
        # HiFi-GAN generator that maps the (upsampled) GPT latents to waveform samples,
        # conditioned on a d-vector of size `d_vector_dim`.
        self.waveform_decoder = HifiganGenerator(
            decoder_input_dim,
            1,
            resblock_type_decoder,
            resblock_dilation_sizes_decoder,
            resblock_kernel_sizes_decoder,
            upsample_kernel_sizes_decoder,
            upsample_initial_channel_decoder,
            upsample_rates_decoder,
            inference_padding=0,
            cond_channels=d_vector_dim,
            conv_pre_weight_norm=False,
            conv_post_weight_norm=False,
            conv_post_bias=False,
            cond_in_each_up_layer=cond_d_vector_in_each_upsampling_layer,
        )
        # Speaker encoder used to extract the conditioning d-vector from reference audio.
        self.speaker_encoder = ResNetSpeakerEncoder(
            input_dim=64,
            proj_dim=512,
            log_input=True,
            use_torch_spec=True,
            audio_config=speaker_encoder_audio_config,
        )

    @property
    def device(self):
        return next(self.parameters()).device

    def forward(self, latents, g=None):
        """
        Args:
            latents (Tensor): feature input tensor (GPT latent).
            g (Tensor): global conditioning input tensor.

        Returns:
            Tensor: output waveform.

        Shapes:
            latents: [B, T, C]
            Tensor: [B, 1, T]
        """
        # Stretch the latent frames in time so that one frame corresponds to
        # `output_hop_length` samples instead of `ar_mel_length_compression` samples.
        z = torch.nn.functional.interpolate(
            latents.transpose(1, 2),
            scale_factor=[self.ar_mel_length_compression / self.output_hop_length],
            mode="linear",
        ).squeeze(1)
        # resample to the output sample rate
        if self.output_sample_rate != self.input_sample_rate:
            z = torch.nn.functional.interpolate(
                z,
                scale_factor=[self.output_sample_rate / self.input_sample_rate],
                mode="linear",
            ).squeeze(0)
        o = self.waveform_decoder(z, g=g)
        return o

    @torch.no_grad()
    def inference(self, c, g):
        """
        Args:
            c (Tensor): feature input tensor (GPT latent).
            g (Tensor): global conditioning input tensor.

        Returns:
            Tensor: output waveform.

        Shapes:
            c: [B, T, C]
            Tensor: [B, 1, T]
        """
        return self.forward(c, g=g)

    def load_checkpoint(self, checkpoint_path, eval=False):  # pylint: disable=unused-argument, redefined-builtin
        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
        # keep only the waveform decoder and speaker encoder weights
        state = state["model"]
        states_keys = list(state.keys())
        for key in states_keys:
            if "waveform_decoder." not in key and "speaker_encoder." not in key:
                del state[key]

        self.load_state_dict(state)
        if eval:
            self.eval()
            assert not self.training
            self.waveform_decoder.remove_weight_norm()
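

# --- Usage sketch (not part of the original module) --------------------------
# A minimal, hedged example of how HifiDecoder might be driven: random GPT
# latents shaped [B, T, 1024] and a 512-dim speaker conditioning tensor shaped
# [B, d_vector_dim, 1], mirroring the generator's `cond_channels`. The tensors
# below are placeholders, and the commented checkpoint path is hypothetical.
if __name__ == "__main__":
    decoder = HifiDecoder()
    # decoder.load_checkpoint("checkpoint.pth", eval=True)  # hypothetical path
    decoder.eval()
    latents = torch.randn(1, 50, 1024)  # [B, T, C] GPT latent frames
    g = torch.randn(1, 512, 1)  # [B, d_vector_dim, 1] speaker conditioning
    wav = decoder.inference(latents, g)  # [B, 1, T_samples] at output_sample_rate
    print("output waveform shape:", tuple(wav.shape))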