From 1b6d3ebd3317f1d4017104b5208cbc8210780e0e Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Wed, 20 Nov 2024 18:39:55 +0100 Subject: [PATCH] refactor(xtts): remove duplicate hifigan generator --- TTS/tts/layers/xtts/hifigan_decoder.py | 609 +----------------------- TTS/vocoder/models/hifigan_generator.py | 13 + 2 files changed, 15 insertions(+), 607 deletions(-) diff --git a/TTS/tts/layers/xtts/hifigan_decoder.py b/TTS/tts/layers/xtts/hifigan_decoder.py index 5ef0030b..2e6ac01a 100644 --- a/TTS/tts/layers/xtts/hifigan_decoder.py +++ b/TTS/tts/layers/xtts/hifigan_decoder.py @@ -1,618 +1,13 @@ import logging import torch -import torchaudio -from torch import nn -from torch.nn import Conv1d, ConvTranspose1d -from torch.nn import functional as F -from torch.nn.utils.parametrizations import weight_norm -from torch.nn.utils.parametrize import remove_parametrizations from trainer.io import load_fsspec -from TTS.utils.generic_utils import is_pytorch_at_least_2_4 -from TTS.vocoder.models.hifigan_generator import get_padding +from TTS.encoder.models.resnet import ResNetSpeakerEncoder +from TTS.vocoder.models.hifigan_generator import HifiganGenerator logger = logging.getLogger(__name__) -LRELU_SLOPE = 0.1 - - -class ResBlock1(torch.nn.Module): - """Residual Block Type 1. It has 3 convolutional layers in each convolutional block. - - Network:: - - x -> lrelu -> conv1_1 -> conv1_2 -> conv1_3 -> z -> lrelu -> conv2_1 -> conv2_2 -> conv2_3 -> o -> + -> o - |--------------------------------------------------------------------------------------------------| - - - Args: - channels (int): number of hidden channels for the convolutional layers. - kernel_size (int): size of the convolution filter in each layer. - dilations (list): list of dilation value for each conv layer in a block. - """ - - def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): - super().__init__() - self.convs1 = nn.ModuleList( - [ - weight_norm( - Conv1d( - channels, - channels, - kernel_size, - 1, - dilation=dilation[0], - padding=get_padding(kernel_size, dilation[0]), - ) - ), - weight_norm( - Conv1d( - channels, - channels, - kernel_size, - 1, - dilation=dilation[1], - padding=get_padding(kernel_size, dilation[1]), - ) - ), - weight_norm( - Conv1d( - channels, - channels, - kernel_size, - 1, - dilation=dilation[2], - padding=get_padding(kernel_size, dilation[2]), - ) - ), - ] - ) - - self.convs2 = nn.ModuleList( - [ - weight_norm( - Conv1d( - channels, - channels, - kernel_size, - 1, - dilation=1, - padding=get_padding(kernel_size, 1), - ) - ), - weight_norm( - Conv1d( - channels, - channels, - kernel_size, - 1, - dilation=1, - padding=get_padding(kernel_size, 1), - ) - ), - weight_norm( - Conv1d( - channels, - channels, - kernel_size, - 1, - dilation=1, - padding=get_padding(kernel_size, 1), - ) - ), - ] - ) - - def forward(self, x): - """ - Args: - x (Tensor): input tensor. - Returns: - Tensor: output tensor. - Shapes: - x: [B, C, T] - """ - for c1, c2 in zip(self.convs1, self.convs2): - xt = F.leaky_relu(x, LRELU_SLOPE) - xt = c1(xt) - xt = F.leaky_relu(xt, LRELU_SLOPE) - xt = c2(xt) - x = xt + x - return x - - def remove_weight_norm(self): - for l in self.convs1: - remove_parametrizations(l, "weight") - for l in self.convs2: - remove_parametrizations(l, "weight") - - -class ResBlock2(torch.nn.Module): - """Residual Block Type 2. It has 1 convolutional layers in each convolutional block. - - Network:: - - x -> lrelu -> conv1-> -> z -> lrelu -> conv2-> o -> + -> o - |---------------------------------------------------| - - - Args: - channels (int): number of hidden channels for the convolutional layers. - kernel_size (int): size of the convolution filter in each layer. - dilations (list): list of dilation value for each conv layer in a block. - """ - - def __init__(self, channels, kernel_size=3, dilation=(1, 3)): - super().__init__() - self.convs = nn.ModuleList( - [ - weight_norm( - Conv1d( - channels, - channels, - kernel_size, - 1, - dilation=dilation[0], - padding=get_padding(kernel_size, dilation[0]), - ) - ), - weight_norm( - Conv1d( - channels, - channels, - kernel_size, - 1, - dilation=dilation[1], - padding=get_padding(kernel_size, dilation[1]), - ) - ), - ] - ) - - def forward(self, x): - for c in self.convs: - xt = F.leaky_relu(x, LRELU_SLOPE) - xt = c(xt) - x = xt + x - return x - - def remove_weight_norm(self): - for l in self.convs: - remove_parametrizations(l, "weight") - - -class HifiganGenerator(torch.nn.Module): - def __init__( - self, - in_channels, - out_channels, - resblock_type, - resblock_dilation_sizes, - resblock_kernel_sizes, - upsample_kernel_sizes, - upsample_initial_channel, - upsample_factors, - inference_padding=5, - cond_channels=0, - conv_pre_weight_norm=True, - conv_post_weight_norm=True, - conv_post_bias=True, - cond_in_each_up_layer=False, - ): - r"""HiFiGAN Generator with Multi-Receptive Field Fusion (MRF) - - Network: - x -> lrelu -> upsampling_layer -> resblock1_k1x1 -> z1 -> + -> z_sum / #resblocks -> lrelu -> conv_post_7x1 -> tanh -> o - .. -> zI ---| - resblockN_kNx1 -> zN ---' - - Args: - in_channels (int): number of input tensor channels. - out_channels (int): number of output tensor channels. - resblock_type (str): type of the `ResBlock`. '1' or '2'. - resblock_dilation_sizes (List[List[int]]): list of dilation values in each layer of a `ResBlock`. - resblock_kernel_sizes (List[int]): list of kernel sizes for each `ResBlock`. - upsample_kernel_sizes (List[int]): list of kernel sizes for each transposed convolution. - upsample_initial_channel (int): number of channels for the first upsampling layer. This is divided by 2 - for each consecutive upsampling layer. - upsample_factors (List[int]): upsampling factors (stride) for each upsampling layer. - inference_padding (int): constant padding applied to the input at inference time. Defaults to 5. - """ - super().__init__() - self.inference_padding = inference_padding - self.num_kernels = len(resblock_kernel_sizes) - self.num_upsamples = len(upsample_factors) - self.cond_in_each_up_layer = cond_in_each_up_layer - - # initial upsampling layers - self.conv_pre = weight_norm(Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3)) - resblock = ResBlock1 if resblock_type == "1" else ResBlock2 - # upsampling layers - self.ups = nn.ModuleList() - for i, (u, k) in enumerate(zip(upsample_factors, upsample_kernel_sizes)): - self.ups.append( - weight_norm( - ConvTranspose1d( - upsample_initial_channel // (2**i), - upsample_initial_channel // (2 ** (i + 1)), - k, - u, - padding=(k - u) // 2, - ) - ) - ) - # MRF blocks - self.resblocks = nn.ModuleList() - for i in range(len(self.ups)): - ch = upsample_initial_channel // (2 ** (i + 1)) - for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): - self.resblocks.append(resblock(ch, k, d)) - # post convolution layer - self.conv_post = weight_norm(Conv1d(ch, out_channels, 7, 1, padding=3, bias=conv_post_bias)) - if cond_channels > 0: - self.cond_layer = nn.Conv1d(cond_channels, upsample_initial_channel, 1) - - if not conv_pre_weight_norm: - remove_parametrizations(self.conv_pre, "weight") - - if not conv_post_weight_norm: - remove_parametrizations(self.conv_post, "weight") - - if self.cond_in_each_up_layer: - self.conds = nn.ModuleList() - for i in range(len(self.ups)): - ch = upsample_initial_channel // (2 ** (i + 1)) - self.conds.append(nn.Conv1d(cond_channels, ch, 1)) - - def forward(self, x, g=None): - """ - Args: - x (Tensor): feature input tensor. - g (Tensor): global conditioning input tensor. - - Returns: - Tensor: output waveform. - - Shapes: - x: [B, C, T] - Tensor: [B, 1, T] - """ - o = self.conv_pre(x) - if hasattr(self, "cond_layer"): - o = o + self.cond_layer(g) - for i in range(self.num_upsamples): - o = F.leaky_relu(o, LRELU_SLOPE) - o = self.ups[i](o) - - if self.cond_in_each_up_layer: - o = o + self.conds[i](g) - - z_sum = None - for j in range(self.num_kernels): - if z_sum is None: - z_sum = self.resblocks[i * self.num_kernels + j](o) - else: - z_sum += self.resblocks[i * self.num_kernels + j](o) - o = z_sum / self.num_kernels - o = F.leaky_relu(o) - o = self.conv_post(o) - o = torch.tanh(o) - return o - - @torch.no_grad() - def inference(self, c): - """ - Args: - x (Tensor): conditioning input tensor. - - Returns: - Tensor: output waveform. - - Shapes: - x: [B, C, T] - Tensor: [B, 1, T] - """ - c = c.to(self.conv_pre.weight.device) - c = torch.nn.functional.pad(c, (self.inference_padding, self.inference_padding), "replicate") - return self.forward(c) - - def remove_weight_norm(self): - logger.info("Removing weight norm...") - for l in self.ups: - remove_parametrizations(l, "weight") - for l in self.resblocks: - l.remove_weight_norm() - remove_parametrizations(self.conv_pre, "weight") - remove_parametrizations(self.conv_post, "weight") - - def load_checkpoint( - self, config, checkpoint_path, eval=False, cache=False - ): # pylint: disable=unused-argument, redefined-builtin - state = torch.load(checkpoint_path, map_location=torch.device("cpu"), weights_only=is_pytorch_at_least_2_4()) - self.load_state_dict(state["model"]) - if eval: - self.eval() - assert not self.training - self.remove_weight_norm() - - -class SELayer(nn.Module): - def __init__(self, channel, reduction=8): - super(SELayer, self).__init__() - self.avg_pool = nn.AdaptiveAvgPool2d(1) - self.fc = nn.Sequential( - nn.Linear(channel, channel // reduction), - nn.ReLU(inplace=True), - nn.Linear(channel // reduction, channel), - nn.Sigmoid(), - ) - - def forward(self, x): - b, c, _, _ = x.size() - y = self.avg_pool(x).view(b, c) - y = self.fc(y).view(b, c, 1, 1) - return x * y - - -class SEBasicBlock(nn.Module): - expansion = 1 - - def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=8): - super(SEBasicBlock, self).__init__() - self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False) - self.bn1 = nn.BatchNorm2d(planes) - self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False) - self.bn2 = nn.BatchNorm2d(planes) - self.relu = nn.ReLU(inplace=True) - self.se = SELayer(planes, reduction) - self.downsample = downsample - self.stride = stride - - def forward(self, x): - residual = x - - out = self.conv1(x) - out = self.relu(out) - out = self.bn1(out) - - out = self.conv2(out) - out = self.bn2(out) - out = self.se(out) - - if self.downsample is not None: - residual = self.downsample(x) - - out += residual - out = self.relu(out) - return out - - -def set_init_dict(model_dict, checkpoint_state, c): - # Partial initialization: if there is a mismatch with new and old layer, it is skipped. - for k, v in checkpoint_state.items(): - if k not in model_dict: - logger.warning("Layer missing in the model definition: %s", k) - # 1. filter out unnecessary keys - pretrained_dict = {k: v for k, v in checkpoint_state.items() if k in model_dict} - # 2. filter out different size layers - pretrained_dict = {k: v for k, v in pretrained_dict.items() if v.numel() == model_dict[k].numel()} - # 3. skip reinit layers - if c.has("reinit_layers") and c.reinit_layers is not None: - for reinit_layer_name in c.reinit_layers: - pretrained_dict = {k: v for k, v in pretrained_dict.items() if reinit_layer_name not in k} - # 4. overwrite entries in the existing state dict - model_dict.update(pretrained_dict) - logger.info("%d / %d layers are restored.", len(pretrained_dict), len(model_dict)) - return model_dict - - -class PreEmphasis(nn.Module): - def __init__(self, coefficient=0.97): - super().__init__() - self.coefficient = coefficient - self.register_buffer("filter", torch.FloatTensor([-self.coefficient, 1.0]).unsqueeze(0).unsqueeze(0)) - - def forward(self, x): - assert len(x.size()) == 2 - - x = torch.nn.functional.pad(x.unsqueeze(1), (1, 0), "reflect") - return torch.nn.functional.conv1d(x, self.filter).squeeze(1) - - -class ResNetSpeakerEncoder(nn.Module): - """This is copied from 🐸TTS to remove it from the dependencies.""" - - # pylint: disable=W0102 - def __init__( - self, - input_dim=64, - proj_dim=512, - layers=[3, 4, 6, 3], - num_filters=[32, 64, 128, 256], - encoder_type="ASP", - log_input=False, - use_torch_spec=False, - audio_config=None, - ): - super(ResNetSpeakerEncoder, self).__init__() - - self.encoder_type = encoder_type - self.input_dim = input_dim - self.log_input = log_input - self.use_torch_spec = use_torch_spec - self.audio_config = audio_config - self.proj_dim = proj_dim - - self.conv1 = nn.Conv2d(1, num_filters[0], kernel_size=3, stride=1, padding=1) - self.relu = nn.ReLU(inplace=True) - self.bn1 = nn.BatchNorm2d(num_filters[0]) - - self.inplanes = num_filters[0] - self.layer1 = self.create_layer(SEBasicBlock, num_filters[0], layers[0]) - self.layer2 = self.create_layer(SEBasicBlock, num_filters[1], layers[1], stride=(2, 2)) - self.layer3 = self.create_layer(SEBasicBlock, num_filters[2], layers[2], stride=(2, 2)) - self.layer4 = self.create_layer(SEBasicBlock, num_filters[3], layers[3], stride=(2, 2)) - - self.instancenorm = nn.InstanceNorm1d(input_dim) - - if self.use_torch_spec: - self.torch_spec = torch.nn.Sequential( - PreEmphasis(audio_config["preemphasis"]), - torchaudio.transforms.MelSpectrogram( - sample_rate=audio_config["sample_rate"], - n_fft=audio_config["fft_size"], - win_length=audio_config["win_length"], - hop_length=audio_config["hop_length"], - window_fn=torch.hamming_window, - n_mels=audio_config["num_mels"], - ), - ) - - else: - self.torch_spec = None - - outmap_size = int(self.input_dim / 8) - - self.attention = nn.Sequential( - nn.Conv1d(num_filters[3] * outmap_size, 128, kernel_size=1), - nn.ReLU(), - nn.BatchNorm1d(128), - nn.Conv1d(128, num_filters[3] * outmap_size, kernel_size=1), - nn.Softmax(dim=2), - ) - - if self.encoder_type == "SAP": - out_dim = num_filters[3] * outmap_size - elif self.encoder_type == "ASP": - out_dim = num_filters[3] * outmap_size * 2 - else: - raise ValueError("Undefined encoder") - - self.fc = nn.Linear(out_dim, proj_dim) - - self._init_layers() - - def _init_layers(self): - for m in self.modules(): - if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") - elif isinstance(m, nn.BatchNorm2d): - nn.init.constant_(m.weight, 1) - nn.init.constant_(m.bias, 0) - - def create_layer(self, block, planes, blocks, stride=1): - downsample = None - if stride != 1 or self.inplanes != planes * block.expansion: - downsample = nn.Sequential( - nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), - nn.BatchNorm2d(planes * block.expansion), - ) - - layers = [] - layers.append(block(self.inplanes, planes, stride, downsample)) - self.inplanes = planes * block.expansion - for _ in range(1, blocks): - layers.append(block(self.inplanes, planes)) - - return nn.Sequential(*layers) - - # pylint: disable=R0201 - def new_parameter(self, *size): - out = nn.Parameter(torch.FloatTensor(*size)) - nn.init.xavier_normal_(out) - return out - - def forward(self, x, l2_norm=False): - """Forward pass of the model. - - Args: - x (Tensor): Raw waveform signal or spectrogram frames. If input is a waveform, `torch_spec` must be `True` - to compute the spectrogram on-the-fly. - l2_norm (bool): Whether to L2-normalize the outputs. - - Shapes: - - x: :math:`(N, 1, T_{in})` or :math:`(N, D_{spec}, T_{in})` - """ - x.squeeze_(1) - # if you torch spec compute it otherwise use the mel spec computed by the AP - if self.use_torch_spec: - x = self.torch_spec(x) - - if self.log_input: - x = (x + 1e-6).log() - x = self.instancenorm(x).unsqueeze(1) - - x = self.conv1(x) - x = self.relu(x) - x = self.bn1(x) - - x = self.layer1(x) - x = self.layer2(x) - x = self.layer3(x) - x = self.layer4(x) - - x = x.reshape(x.size()[0], -1, x.size()[-1]) - - w = self.attention(x) - - if self.encoder_type == "SAP": - x = torch.sum(x * w, dim=2) - elif self.encoder_type == "ASP": - mu = torch.sum(x * w, dim=2) - sg = torch.sqrt((torch.sum((x**2) * w, dim=2) - mu**2).clamp(min=1e-5)) - x = torch.cat((mu, sg), 1) - - x = x.view(x.size()[0], -1) - x = self.fc(x) - - if l2_norm: - x = torch.nn.functional.normalize(x, p=2, dim=1) - return x - - def load_checkpoint( - self, - checkpoint_path: str, - eval: bool = False, - use_cuda: bool = False, - criterion=None, - cache=False, - ): - state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache) - try: - self.load_state_dict(state["model"]) - logger.info("Model fully restored.") - except (KeyError, RuntimeError) as error: - # If eval raise the error - if eval: - raise error - - logger.info("Partial model initialization.") - model_dict = self.state_dict() - model_dict = set_init_dict(model_dict, state["model"]) - self.load_state_dict(model_dict) - del model_dict - - # load the criterion for restore_path - if criterion is not None and "criterion" in state: - try: - criterion.load_state_dict(state["criterion"]) - except (KeyError, RuntimeError) as error: - logger.exception("Criterion load ignored because of: %s", error) - - if use_cuda: - self.cuda() - if criterion is not None: - criterion = criterion.cuda() - - if eval: - self.eval() - assert not self.training - - if not eval: - return criterion, state["step"] - return criterion - class HifiDecoder(torch.nn.Module): def __init__( diff --git a/TTS/vocoder/models/hifigan_generator.py b/TTS/vocoder/models/hifigan_generator.py index afdd59a8..8273d020 100644 --- a/TTS/vocoder/models/hifigan_generator.py +++ b/TTS/vocoder/models/hifigan_generator.py @@ -178,6 +178,7 @@ class HifiganGenerator(torch.nn.Module): conv_pre_weight_norm=True, conv_post_weight_norm=True, conv_post_bias=True, + cond_in_each_up_layer=False, ): r"""HiFiGAN Generator with Multi-Receptive Field Fusion (MRF) @@ -202,6 +203,8 @@ class HifiganGenerator(torch.nn.Module): self.inference_padding = inference_padding self.num_kernels = len(resblock_kernel_sizes) self.num_upsamples = len(upsample_factors) + self.cond_in_each_up_layer = cond_in_each_up_layer + # initial upsampling layers self.conv_pre = weight_norm(Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3)) resblock = ResBlock1 if resblock_type == "1" else ResBlock2 @@ -236,6 +239,12 @@ class HifiganGenerator(torch.nn.Module): if not conv_post_weight_norm: remove_parametrizations(self.conv_post, "weight") + if self.cond_in_each_up_layer: + self.conds = nn.ModuleList() + for i in range(len(self.ups)): + ch = upsample_initial_channel // (2 ** (i + 1)) + self.conds.append(nn.Conv1d(cond_channels, ch, 1)) + def forward(self, x, g=None): """ Args: @@ -255,6 +264,10 @@ class HifiganGenerator(torch.nn.Module): for i in range(self.num_upsamples): o = F.leaky_relu(o, LRELU_SLOPE) o = self.ups[i](o) + + if self.cond_in_each_up_layer: + o = o + self.conds[i](g) + z_sum = None for j in range(self.num_kernels): if z_sum is None: