# mirror of https://github.com/coqui-ai/TTS.git
import logging

import torch
from trainer.io import load_fsspec

from TTS.encoder.models.resnet import ResNetSpeakerEncoder
from TTS.vocoder.models.hifigan_generator import HifiganGenerator

logger = logging.getLogger(__name__)


class HifiDecoder(torch.nn.Module):
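    """HiFi-GAN based decoder that turns GPT latents into a waveform.

    The latent frames are linearly interpolated to the vocoder frame rate and
    target sample rate, then rendered to audio by a HiFi-GAN generator conditioned
    on a speaker d-vector (optionally at every upsampling layer). A ResNet speaker
    encoder is bundled so the d-vector can be computed from 16 kHz reference audio.
    """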
    def __init__(
        self,
        input_sample_rate=22050,
        output_sample_rate=24000,
        output_hop_length=256,
        ar_mel_length_compression=1024,
        decoder_input_dim=1024,
        resblock_type_decoder="1",
        resblock_dilation_sizes_decoder=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        resblock_kernel_sizes_decoder=[3, 7, 11],
        upsample_rates_decoder=[8, 8, 2, 2],
        upsample_initial_channel_decoder=512,
        upsample_kernel_sizes_decoder=[16, 16, 4, 4],
        d_vector_dim=512,
        cond_d_vector_in_each_upsampling_layer=True,
        speaker_encoder_audio_config={
            "fft_size": 512,
            "win_length": 400,
            "hop_length": 160,
            "sample_rate": 16000,
            "preemphasis": 0.97,
            "num_mels": 64,
        },
    ):
        super().__init__()
        self.input_sample_rate = input_sample_rate
        self.output_sample_rate = output_sample_rate
        self.output_hop_length = output_hop_length
        self.ar_mel_length_compression = ar_mel_length_compression
        self.speaker_encoder_audio_config = speaker_encoder_audio_config
        # HiFi-GAN generator rendering a mono waveform (out_channels=1) from the GPT
        # latent channels, conditioned on a d-vector of size `d_vector_dim`.
        self.waveform_decoder = HifiganGenerator(
            decoder_input_dim,
            1,
            resblock_type_decoder,
            resblock_dilation_sizes_decoder,
            resblock_kernel_sizes_decoder,
            upsample_kernel_sizes_decoder,
            upsample_initial_channel_decoder,
            upsample_rates_decoder,
            inference_padding=0,
            cond_channels=d_vector_dim,
            conv_pre_weight_norm=False,
            conv_post_weight_norm=False,
            conv_post_bias=False,
            cond_in_each_up_layer=cond_d_vector_in_each_upsampling_layer,
        )
        # ResNet speaker encoder used to compute speaker embeddings from reference audio.
        self.speaker_encoder = ResNetSpeakerEncoder(
            input_dim=64,
            proj_dim=512,
            log_input=True,
            use_torch_spec=True,
            audio_config=speaker_encoder_audio_config,
        )

    @property
    def device(self):
        return next(self.parameters()).device

    def forward(self, latents, g=None):
        """
        Args:
            latents (Tensor): feature input tensor (GPT latent).
            g (Tensor): global conditioning input tensor (speaker embedding).

        Returns:
            Tensor: output waveform.

        Shapes:
            latents: [B, T, C]
            Tensor: [B, 1, T]
        """
        z = torch.nn.functional.interpolate(
            latents.transpose(1, 2),
            scale_factor=[self.ar_mel_length_compression / self.output_hop_length],
            mode="linear",
        ).squeeze(1)
        # stretch in time so the waveform comes out at output_sample_rate
        # instead of input_sample_rate (22050 -> 24000 Hz with the defaults)
        if self.output_sample_rate != self.input_sample_rate:
            z = torch.nn.functional.interpolate(
                z,
                scale_factor=[self.output_sample_rate / self.input_sample_rate],
                mode="linear",
            ).squeeze(0)
        o = self.waveform_decoder(z, g=g)
        return o

    @torch.no_grad()
    def inference(self, c, g):
        """
        Args:
            c (Tensor): feature input tensor (GPT latent).
            g (Tensor): global conditioning input tensor (speaker embedding).

        Returns:
            Tensor: output waveform.

        Shapes:
            c: [B, T, C]
            Tensor: [B, 1, T]
        """
        return self.forward(c, g=g)

    def load_checkpoint(self, checkpoint_path, eval=False):  # pylint: disable=unused-argument, redefined-builtin
        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
        # keep only the weights that belong to this module (waveform decoder and
        # speaker encoder); drop every other key in the checkpoint
        state = state["model"]
        states_keys = list(state.keys())
        for key in states_keys:
            if "waveform_decoder." not in key and "speaker_encoder." not in key:
                del state[key]

        self.load_state_dict(state)
        if eval:
            self.eval()
            assert not self.training
            self.waveform_decoder.remove_weight_norm()
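

if __name__ == "__main__":
    # Minimal smoke-test sketch: the shapes are assumptions based on the constructor
    # defaults (frames of size `decoder_input_dim`, a 512-dim speaker d-vector), and
    # the weights are random, so the output is noise; call load_checkpoint() with a
    # trained checkpoint to get real audio.
    logging.basicConfig(level=logging.INFO)
    decoder = HifiDecoder()
    decoder.eval()
    dummy_latents = torch.randn(1, 32, 1024)  # [B, T, decoder_input_dim] fake GPT latents
    dummy_d_vector = torch.randn(1, 512, 1)  # [B, d_vector_dim, 1] fake speaker embedding
    waveform = decoder.inference(dummy_latents, dummy_d_vector)
    logger.info("output waveform shape: %s", tuple(waveform.shape))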