diff --git a/TTS/bin/train_encoder.py b/TTS/bin/train_encoder.py
index c0292743..49b450cf 100644
--- a/TTS/bin/train_encoder.py
+++ b/TTS/bin/train_encoder.py
@@ -161,9 +161,6 @@ def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader,
         loader_time = time.time() - end_time
         global_step += 1

-        # setup lr
-        if c.lr_decay:
-            scheduler.step()
         optimizer.zero_grad()

         # dispatch data to GPU
@@ -182,6 +179,10 @@ def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader,
         grad_norm, _ = check_update(model, c.grad_clip)
         optimizer.step()

+        # setup lr
+        if c.lr_decay:
+            scheduler.step()
+
         step_time = time.time() - start_time
         epoch_time += step_time
diff --git a/TTS/encoder/dataset.py b/TTS/encoder/dataset.py
index 81385c6c..bb780e3c 100644
--- a/TTS/encoder/dataset.py
+++ b/TTS/encoder/dataset.py
@@ -55,7 +55,7 @@ class EncoderDataset(Dataset):
             logger.info(" | Number of instances: %d", len(self.items))
             logger.info(" | Sequence length: %d", self.seq_len)
             logger.info(" | Number of classes: %d", len(self.classes))
-            logger.info(" | Classes: %d", self.classes)
+            logger.info(" | Classes: %s", self.classes)

     def load_wav(self, filename):
         audio = self.ap.load_wav(filename, sr=self.ap.sample_rate)
diff --git a/TTS/model.py b/TTS/model.py
index ae6be7b4..01dd515d 100644
--- a/TTS/model.py
+++ b/TTS/model.py
@@ -1,5 +1,6 @@
+import os
 from abc import abstractmethod
-from typing import Dict
+from typing import Any, Union

 import torch
 from coqpit import Coqpit
@@ -16,7 +17,7 @@ class BaseTrainerModel(TrainerModel):

     @staticmethod
     @abstractmethod
-    def init_from_config(config: Coqpit):
+    def init_from_config(config: Coqpit) -> "BaseTrainerModel":
         """Init the model and all its attributes from the given config.

         Override this depending on your model.
         """
@@ -24,7 +25,7 @@ class BaseTrainerModel(TrainerModel):
         ...

     @abstractmethod
-    def inference(self, input: torch.Tensor, aux_input={}) -> Dict:
+    def inference(self, input: torch.Tensor, aux_input: dict[str, Any] = {}) -> dict[str, Any]:
         """Forward pass for inference.

         It must return a dictionary with the main model output and all the auxiliary outputs. The key ```model_outputs```
@@ -45,13 +46,18 @@ class BaseTrainerModel(TrainerModel):

     @abstractmethod
     def load_checkpoint(
-        self, config: Coqpit, checkpoint_path: str, eval: bool = False, strict: bool = True, cache=False
+        self,
+        config: Coqpit,
+        checkpoint_path: Union[str, os.PathLike[Any]],
+        eval: bool = False,
+        strict: bool = True,
+        cache: bool = False,
     ) -> None:
-        """Load a model checkpoint gile and get ready for training or inference.
+        """Load a model checkpoint file and get ready for training or inference.

         Args:
             config (Coqpit): Model configuration.
-            checkpoint_path (str): Path to the model checkpoint file.
+            checkpoint_path (str | os.PathLike): Path to the model checkpoint file.
             eval (bool, optional): If true, init model for inference else for training. Defaults to False.
             strict (bool, optional): Match all checkpoint keys to model's keys. Defaults to True.
             cache (bool, optional): If True, cache the file locally for subsequent calls. It is cached under `get_user_data_dir()/tts_cache`. Defaults to False.
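Note on the train_encoder.py hunks above: PyTorch (>= 1.1) expects `optimizer.step()` to run before `lr_scheduler.step()`; the old order skipped the first value of the schedule and triggered a UserWarning. A minimal sketch of the corrected ordering (model, data, and hyperparameters here are made up for illustration):

    import torch

    model = torch.nn.Linear(10, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

    for _ in range(3):
        optimizer.zero_grad()
        loss = model(torch.randn(4, 10)).sum()
        loss.backward()
        optimizer.step()
        # Step the scheduler only after the optimizer has stepped,
        # mirroring the reordering in the train_encoder.py hunk.
        scheduler.step()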
diff --git a/TTS/tts/layers/glow_tts/transformer.py b/TTS/tts/layers/glow_tts/transformer.py
index 02688d61..c97d070a 100644
--- a/TTS/tts/layers/glow_tts/transformer.py
+++ b/TTS/tts/layers/glow_tts/transformer.py
@@ -5,6 +5,7 @@ from torch import nn
 from torch.nn import functional as F

 from TTS.tts.layers.generic.normalization import LayerNorm, LayerNorm2
+from TTS.tts.utils.helpers import convert_pad_shape


 class RelativePositionMultiHeadAttention(nn.Module):
@@ -300,7 +301,7 @@ class FeedForwardNetwork(nn.Module):
         pad_l = self.kernel_size - 1
         pad_r = 0
         padding = [[0, 0], [0, 0], [pad_l, pad_r]]
-        x = F.pad(x, self._pad_shape(padding))
+        x = F.pad(x, convert_pad_shape(padding))
         return x

     def _same_padding(self, x):
@@ -309,15 +310,9 @@ class FeedForwardNetwork(nn.Module):
         pad_l = (self.kernel_size - 1) // 2
         pad_r = self.kernel_size // 2
         padding = [[0, 0], [0, 0], [pad_l, pad_r]]
-        x = F.pad(x, self._pad_shape(padding))
+        x = F.pad(x, convert_pad_shape(padding))
         return x

-    @staticmethod
-    def _pad_shape(padding):
-        l = padding[::-1]
-        pad_shape = [item for sublist in l for item in sublist]
-        return pad_shape
-

 class RelativePositionTransformer(nn.Module):
     """Transformer with Relative Potional Encoding.
diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py
index cd6cd0ae..5ebed81d 100644
--- a/TTS/tts/layers/losses.py
+++ b/TTS/tts/layers/losses.py
@@ -255,7 +255,7 @@ class GuidedAttentionLoss(torch.nn.Module):

     @staticmethod
     def _make_ga_mask(ilen, olen, sigma):
-        grid_x, grid_y = torch.meshgrid(torch.arange(olen).to(olen), torch.arange(ilen).to(ilen))
+        grid_x, grid_y = torch.meshgrid(torch.arange(olen).to(olen), torch.arange(ilen).to(ilen), indexing="ij")
         grid_x, grid_y = grid_x.float(), grid_y.float()
         return 1.0 - torch.exp(-((grid_y / ilen - grid_x / olen) ** 2) / (2 * (sigma**2)))
diff --git a/TTS/tts/layers/overflow/neural_hmm.py b/TTS/tts/layers/overflow/neural_hmm.py
index 0631ba98..a12becef 100644
--- a/TTS/tts/layers/overflow/neural_hmm.py
+++ b/TTS/tts/layers/overflow/neural_hmm.py
@@ -128,7 +128,8 @@ class NeuralHMM(nn.Module):
             # Get mean, std and transition vector from decoder for this timestep
             # Note: Gradient checkpointing currently doesn't works with multiple gpus inside a loop
             if self.use_grad_checkpointing and self.training:
-                mean, std, transition_vector = checkpoint(self.output_net, h_memory, inputs)
+                # TODO: use_reentrant=False is recommended
+                mean, std, transition_vector = checkpoint(self.output_net, h_memory, inputs, use_reentrant=True)
             else:
                 mean, std, transition_vector = self.output_net(h_memory, inputs)
diff --git a/TTS/tts/layers/vits/networks.py b/TTS/tts/layers/vits/networks.py
index f97b584f..50ed1024 100644
--- a/TTS/tts/layers/vits/networks.py
+++ b/TTS/tts/layers/vits/networks.py
@@ -10,22 +10,6 @@ from TTS.tts.utils.helpers import sequence_mask
 LRELU_SLOPE = 0.1


-def convert_pad_shape(pad_shape):
-    l = pad_shape[::-1]
-    pad_shape = [item for sublist in l for item in sublist]
-    return pad_shape
-
-
-def init_weights(m, mean=0.0, std=0.01):
-    classname = m.__class__.__name__
-    if classname.find("Conv") != -1:
-        m.weight.data.normal_(mean, std)
-
-
-def get_padding(kernel_size, dilation=1):
-    return int((kernel_size * dilation - dilation) / 2)
-
-
 class TextEncoder(nn.Module):
     def __init__(
         self,
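Context for the losses.py change: recent PyTorch versions warn when `torch.meshgrid` is called without an `indexing` argument, and `indexing="ij"` pins the historical default, so `_make_ga_mask` keeps its behaviour. A quick illustration (sizes arbitrary):

    import torch

    ilen, olen = 3, 4
    grid_x, grid_y = torch.meshgrid(torch.arange(olen), torch.arange(ilen), indexing="ij")
    # "ij" indexing: the first output varies along dim 0, exactly as the old default did.
    assert grid_x.shape == (olen, ilen)
    assert torch.equal(grid_x[:, 0], torch.arange(olen))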
diff --git a/TTS/tts/layers/xtts/hifigan_decoder.py b/TTS/tts/layers/xtts/hifigan_decoder.py
index 42f64e68..9160529b 100644
--- a/TTS/tts/layers/xtts/hifigan_decoder.py
+++ b/TTS/tts/layers/xtts/hifigan_decoder.py
@@ -9,16 +9,13 @@
 from torch.nn.utils.parametrizations import weight_norm
 from torch.nn.utils.parametrize import remove_parametrizations

 from TTS.utils.io import load_fsspec
+from TTS.vocoder.models.hifigan_generator import get_padding

 logger = logging.getLogger(__name__)

 LRELU_SLOPE = 0.1


-def get_padding(k, d):
-    return int((k * d - d) / 2)
-
-
 class ResBlock1(torch.nn.Module):
     """Residual Block Type 1. It has 3 convolutional layers in each convolutional block.
diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py
index 7fbc2a3a..ccb023ce 100644
--- a/TTS/tts/models/base_tts.py
+++ b/TTS/tts/models/base_tts.py
@@ -144,7 +144,7 @@ class BaseTTS(BaseTrainerModel):
                 if speaker_name is None:
                     d_vector = self.speaker_manager.get_random_embedding()
                 else:
-                    d_vector = self.speaker_manager.get_d_vector_by_name(speaker_name)
+                    d_vector = self.speaker_manager.get_mean_embedding(speaker_name)
             elif config.use_speaker_embedding:
                 if speaker_name is None:
                     speaker_id = self.speaker_manager.get_random_id()
diff --git a/TTS/tts/models/delightful_tts.py b/TTS/tts/models/delightful_tts.py
index ed318923..4230fcc3 100644
--- a/TTS/tts/models/delightful_tts.py
+++ b/TTS/tts/models/delightful_tts.py
@@ -88,12 +88,6 @@ def pad(input_ele: List[torch.Tensor], max_len: int) -> torch.Tensor:
     return out_padded


-def init_weights(m: nn.Module, mean: float = 0.0, std: float = 0.01):
-    classname = m.__class__.__name__
-    if classname.find("Conv") != -1:
-        m.weight.data.normal_(mean, std)
-
-
 def stride_lens(lens: torch.Tensor, stride: int = 2) -> torch.Tensor:
     return torch.ceil(lens / stride).int()
diff --git a/TTS/tts/utils/helpers.py b/TTS/tts/utils/helpers.py
index 7b37201f..7429d0fc 100644
--- a/TTS/tts/utils/helpers.py
+++ b/TTS/tts/utils/helpers.py
@@ -145,10 +145,9 @@ def average_over_durations(values, durs):
     return avg


-def convert_pad_shape(pad_shape):
+def convert_pad_shape(pad_shape: list[list]) -> list:
     l = pad_shape[::-1]
-    pad_shape = [item for sublist in l for item in sublist]
-    return pad_shape
+    return [item for sublist in l for item in sublist]


 def generate_path(duration, mask):
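The duplicated `convert_pad_shape` helpers are collapsed into the single implementation in TTS/tts/utils/helpers.py above. It reverses the per-dimension pad spec and flattens it into the last-dimension-first list that `F.pad` expects. A small sanity check (tensor shape arbitrary):

    import torch
    import torch.nn.functional as F

    from TTS.tts.utils.helpers import convert_pad_shape

    padding = [[0, 0], [0, 0], [2, 0]]  # per-dimension [left, right] pads
    assert convert_pad_shape(padding) == [2, 0, 0, 0, 0, 0]

    x = torch.zeros(1, 3, 5)
    # Only the last dimension is padded (2 on the left): 5 -> 7.
    assert F.pad(x, convert_pad_shape(padding)).shape == (1, 3, 7)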
diff --git a/TTS/vc/models/base_vc.py b/TTS/vc/models/base_vc.py
index c387157f..22ffd009 100644
--- a/TTS/vc/models/base_vc.py
+++ b/TTS/vc/models/base_vc.py
@@ -1,7 +1,7 @@
 import logging
 import os
 import random
-from typing import Dict, List, Tuple, Union
+from typing import Any, Optional, Union

 import torch
 import torch.distributed as dist
@@ -10,6 +10,7 @@ from torch import nn
 from torch.utils.data import DataLoader
 from torch.utils.data.sampler import WeightedRandomSampler
 from trainer.torch import DistributedSampler, DistributedSamplerWrapper
+from trainer.trainer import Trainer

 from TTS.model import BaseTrainerModel
 from TTS.tts.datasets.dataset import TTSDataset
@@ -18,6 +19,7 @@ from TTS.tts.utils.languages import LanguageManager, get_language_balancer_weigh
 from TTS.tts.utils.speakers import SpeakerManager, get_speaker_balancer_weights
 from TTS.tts.utils.synthesis import synthesis
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
+from TTS.utils.audio.processor import AudioProcessor

 # pylint: skip-file
@@ -35,10 +37,10 @@ class BaseVC(BaseTrainerModel):
     def __init__(
         self,
         config: Coqpit,
-        ap: "AudioProcessor",
-        speaker_manager: SpeakerManager = None,
-        language_manager: LanguageManager = None,
-    ):
+        ap: AudioProcessor,
+        speaker_manager: Optional[SpeakerManager] = None,
+        language_manager: Optional[LanguageManager] = None,
+    ) -> None:
         super().__init__()
         self.config = config
         self.ap = ap
@@ -46,7 +48,7 @@ class BaseVC(BaseTrainerModel):
         self.language_manager = language_manager
         self._set_model_args(config)

-    def _set_model_args(self, config: Coqpit):
+    def _set_model_args(self, config: Coqpit) -> None:
         """Setup model args based on the config type (`ModelConfig` or `ModelArgs`).

         `ModelArgs` has all the fields reuqired to initialize the model architecture.
@@ -67,7 +69,7 @@ class BaseVC(BaseTrainerModel):
         else:
             raise ValueError("config must be either a *Config or *Args")

-    def init_multispeaker(self, config: Coqpit, data: List = None):
+    def init_multispeaker(self, config: Coqpit, data: Optional[list[Any]] = None) -> None:
         """Initialize a speaker embedding layer if needen and define expected embedding channel size for defining
         `in_channels` size of the connected layers.
@@ -100,11 +102,11 @@ class BaseVC(BaseTrainerModel):
             self.speaker_embedding = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
             self.speaker_embedding.weight.data.normal_(0, 0.3)

-    def get_aux_input(self, **kwargs) -> Dict:
+    def get_aux_input(self, **kwargs: Any) -> dict[str, Any]:
         """Prepare and return `aux_input` used by `forward()`"""
         return {"speaker_id": None, "style_wav": None, "d_vector": None, "language_id": None}

-    def get_aux_input_from_test_sentences(self, sentence_info):
+    def get_aux_input_from_test_sentences(self, sentence_info: Union[str, list[str]]) -> dict[str, Any]:
         if hasattr(self.config, "model_args"):
             config = self.config.model_args
         else:
@@ -132,7 +134,7 @@ class BaseVC(BaseTrainerModel):
                 if speaker_name is None:
                     d_vector = self.speaker_manager.get_random_embedding()
                 else:
-                    d_vector = self.speaker_manager.get_d_vector_by_name(speaker_name)
+                    d_vector = self.speaker_manager.get_mean_embedding(speaker_name)
             elif config.use_speaker_embedding:
                 if speaker_name is None:
                     speaker_id = self.speaker_manager.get_random_id()
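The `get_d_vector_by_name` calls in base_tts.py and base_vc.py are replaced by `get_mean_embedding`, which returns the averaged d-vector stored for a speaker name. A hypothetical usage sketch — the file path and speaker name are placeholders, not part of this diff:

    from TTS.tts.utils.speakers import SpeakerManager

    manager = SpeakerManager(d_vectors_file_path="speakers.pth")  # placeholder path
    d_vector = manager.get_mean_embedding("speaker_a")  # mean over that speaker's embeddings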
@@ -151,16 +153,16 @@
             "language_id": language_id,
         }

-    def format_batch(self, batch: Dict) -> Dict:
+    def format_batch(self, batch: dict[str, Any]) -> dict[str, Any]:
         """Generic batch formatting for `VCDataset`.

         You must override this if you use a custom dataset.

         Args:
-            batch (Dict): [description]
+            batch (dict): [description]

         Returns:
-            Dict: [description]
+            dict: [description]
         """
         # setup input batch
         text_input = batch["token_id"]
@@ -230,7 +232,7 @@
             "audio_unique_names": batch["audio_unique_names"],
         }

-    def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus=1):
+    def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus: int = 1):
         weights = None
         data_items = dataset.samples

@@ -271,12 +273,12 @@
     def get_data_loader(
         self,
         config: Coqpit,
-        assets: Dict,
+        assets: dict,
         is_eval: bool,
-        samples: Union[List[Dict], List[List]],
+        samples: Union[list[dict], list[list]],
         verbose: bool,
         num_gpus: int,
-        rank: int = None,
+        rank: Optional[int] = None,
     ) -> "DataLoader":
         if is_eval and not config.run_eval:
             loader = None
@@ -352,9 +354,9 @@
     def _get_test_aux_input(
         self,
-    ) -> Dict:
+    ) -> dict[str, Any]:
         d_vector = None
-        if self.config.use_d_vector_file:
+        if self.speaker_manager is not None and self.config.use_d_vector_file:
             d_vector = [self.speaker_manager.embeddings[name]["embedding"] for name in self.speaker_manager.embeddings]
             d_vector = (random.sample(sorted(d_vector), 1),)

@@ -369,7 +371,7 @@
         }
         return aux_inputs

-    def test_run(self, assets: Dict) -> Tuple[Dict, Dict]:
+    def test_run(self, assets: dict) -> tuple[dict, dict]:
         """Generic test run for `vc` models used by `Trainer`.

         You can override this for a different behaviour.
@@ -378,7 +380,7 @@
             assets (dict): A dict of training assets. For `vc` models, it must include `{'audio_processor': ap}`.

         Returns:
-            Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
+            tuple[dict, dict]: Test figures and audios to be projected to Tensorboard.
         """
         logger.info("Synthesizing test sentences.")
         test_audios = {}
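Several signatures above also replace the implicit-Optional pattern (`rank: int = None`) with an explicit `Optional[int]`; PEP 484 deprecates the implicit form and strict type checkers reject it. A minimal illustration (the function is hypothetical):

    from typing import Optional


    def pick_rank(rank: Optional[int] = None) -> int:
        # The Optional annotation makes the "may be None" contract explicit;
        # `rank: int = None` is rejected under strict mypy settings.
        return 0 if rank is None else rank


    assert pick_rank() == 0
    assert pick_rank(3) == 3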
@@ -409,7 +411,7 @@
             )
         return test_figures, test_audios

-    def on_init_start(self, trainer):
+    def on_init_start(self, trainer: Trainer) -> None:
         """Save the speaker.pth and language_ids.json at the beginning of the training. Also update both paths."""
         if self.speaker_manager is not None:
             output_path = os.path.join(trainer.output_path, "speakers.pth")
diff --git a/TTS/vc/models/freevc.py b/TTS/vc/models/freevc.py
index ec7cc0e0..7746572f 100644
--- a/TTS/vc/models/freevc.py
+++ b/TTS/vc/models/freevc.py
@@ -14,14 +14,16 @@ from torch.nn.utils.parametrize import remove_parametrizations

 import TTS.vc.modules.freevc.commons as commons
 import TTS.vc.modules.freevc.modules as modules
+from TTS.tts.utils.helpers import sequence_mask
 from TTS.tts.utils.speakers import SpeakerManager
 from TTS.utils.io import load_fsspec
 from TTS.vc.configs.freevc_config import FreeVCConfig
 from TTS.vc.models.base_vc import BaseVC
-from TTS.vc.modules.freevc.commons import get_padding, init_weights
+from TTS.vc.modules.freevc.commons import init_weights
 from TTS.vc.modules.freevc.mel_processing import mel_spectrogram_torch
 from TTS.vc.modules.freevc.speaker_encoder.speaker_encoder import SpeakerEncoder as SpeakerEncoderEx
 from TTS.vc.modules.freevc.wavlm import get_wavlm
+from TTS.vocoder.models.hifigan_generator import get_padding

 logger = logging.getLogger(__name__)
@@ -80,7 +82,7 @@ class Encoder(nn.Module):
         self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

     def forward(self, x, x_lengths, g=None):
-        x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
+        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
         x = self.pre(x) * x_mask
         x = self.enc(x, x_mask, g=g)
         stats = self.proj(x) * x_mask
diff --git a/TTS/vc/modules/freevc/commons.py b/TTS/vc/modules/freevc/commons.py
index e5fb13c1..feea7f34 100644
--- a/TTS/vc/modules/freevc/commons.py
+++ b/TTS/vc/modules/freevc/commons.py
@@ -3,23 +3,15 @@ import math
 import torch
 from torch.nn import functional as F

+from TTS.tts.utils.helpers import convert_pad_shape, sequence_mask

-def init_weights(m, mean=0.0, std=0.01):
+
+def init_weights(m: torch.nn.Module, mean: float = 0.0, std: float = 0.01) -> None:
     classname = m.__class__.__name__
     if classname.find("Conv") != -1:
         m.weight.data.normal_(mean, std)


-def get_padding(kernel_size, dilation=1):
-    return int((kernel_size * dilation - dilation) / 2)
-
-
-def convert_pad_shape(pad_shape):
-    l = pad_shape[::-1]
-    pad_shape = [item for sublist in l for item in sublist]
-    return pad_shape
-
-
 def intersperse(lst, item):
     result = [item] * (len(lst) * 2 + 1)
     result[1::2] = lst
@@ -119,20 +111,11 @@ def shift_1d(x):
     return x


-def sequence_mask(length, max_length=None):
-    if max_length is None:
-        max_length = length.max()
-    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
-    return x.unsqueeze(0) < length.unsqueeze(1)
-
-
 def generate_path(duration, mask):
     """
     duration: [b, 1, t_x]
     mask: [b, 1, t_y, t_x]
     """
-    device = duration.device
-
     b, _, t_y, t_x = mask.shape

     cum_duration = torch.cumsum(duration, -1)
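commons.py now reuses `sequence_mask` from TTS.tts.utils.helpers instead of its local copy. Both build the mask by comparing an `arange` row against a lengths column, so behaviour should be unchanged; a quick equivalence check (lengths arbitrary, old implementation reproduced only for comparison):

    import torch

    from TTS.tts.utils.helpers import sequence_mask


    def old_sequence_mask(length, max_length=None):
        # The implementation removed from TTS/vc/modules/freevc/commons.py.
        if max_length is None:
            max_length = length.max()
        x = torch.arange(max_length, dtype=length.dtype, device=length.device)
        return x.unsqueeze(0) < length.unsqueeze(1)


    lengths = torch.tensor([2, 5, 3])
    assert torch.equal(sequence_mask(lengths), old_sequence_mask(lengths))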
diff --git a/TTS/vc/modules/freevc/modules.py b/TTS/vc/modules/freevc/modules.py
index 9bb54990..722444a3 100644
--- a/TTS/vc/modules/freevc/modules.py
+++ b/TTS/vc/modules/freevc/modules.py
@@ -6,26 +6,13 @@ from torch.nn.utils.parametrizations import weight_norm
 from torch.nn.utils.parametrize import remove_parametrizations

 import TTS.vc.modules.freevc.commons as commons
-from TTS.vc.modules.freevc.commons import get_padding, init_weights
+from TTS.tts.layers.generic.normalization import LayerNorm2
+from TTS.vc.modules.freevc.commons import init_weights
+from TTS.vocoder.models.hifigan_generator import get_padding

 LRELU_SLOPE = 0.1


-class LayerNorm(nn.Module):
-    def __init__(self, channels, eps=1e-5):
-        super().__init__()
-        self.channels = channels
-        self.eps = eps
-
-        self.gamma = nn.Parameter(torch.ones(channels))
-        self.beta = nn.Parameter(torch.zeros(channels))
-
-    def forward(self, x):
-        x = x.transpose(1, -1)
-        x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
-        return x.transpose(1, -1)
-
-
 class ConvReluNorm(nn.Module):
     def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
         super().__init__()
@@ -40,11 +27,11 @@ class ConvReluNorm(nn.Module):
         self.conv_layers = nn.ModuleList()
         self.norm_layers = nn.ModuleList()
         self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
-        self.norm_layers.append(LayerNorm(hidden_channels))
+        self.norm_layers.append(LayerNorm2(hidden_channels))
         self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
         for _ in range(n_layers - 1):
             self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
-            self.norm_layers.append(LayerNorm(hidden_channels))
+            self.norm_layers.append(LayerNorm2(hidden_channels))
         self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
         self.proj.weight.data.zero_()
         self.proj.bias.data.zero_()
@@ -59,48 +46,6 @@ class ConvReluNorm(nn.Module):
         return x * x_mask


-class DDSConv(nn.Module):
-    """
-    Dialted and Depth-Separable Convolution
-    """
-
-    def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
-        super().__init__()
-        self.channels = channels
-        self.kernel_size = kernel_size
-        self.n_layers = n_layers
-        self.p_dropout = p_dropout
-
-        self.drop = nn.Dropout(p_dropout)
-        self.convs_sep = nn.ModuleList()
-        self.convs_1x1 = nn.ModuleList()
-        self.norms_1 = nn.ModuleList()
-        self.norms_2 = nn.ModuleList()
-        for i in range(n_layers):
-            dilation = kernel_size**i
-            padding = (kernel_size * dilation - dilation) // 2
-            self.convs_sep.append(
-                nn.Conv1d(channels, channels, kernel_size, groups=channels, dilation=dilation, padding=padding)
-            )
-            self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
-            self.norms_1.append(LayerNorm(channels))
-            self.norms_2.append(LayerNorm(channels))
-
-    def forward(self, x, x_mask, g=None):
-        if g is not None:
-            x = x + g
-        for i in range(self.n_layers):
-            y = self.convs_sep[i](x * x_mask)
-            y = self.norms_1[i](y)
-            y = F.gelu(y)
-            y = self.convs_1x1[i](y)
-            y = self.norms_2[i](y)
-            y = F.gelu(y)
-            y = self.drop(y)
-            x = x + y
-        return x * x_mask
-
-
 class WN(torch.nn.Module):
     def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0):
         super(WN, self).__init__()
@@ -317,24 +262,6 @@ class Flip(nn.Module):
         return x


-class ElementwiseAffine(nn.Module):
-    def __init__(self, channels):
-        super().__init__()
-        self.channels = channels
-        self.m = nn.Parameter(torch.zeros(channels, 1))
-        self.logs = nn.Parameter(torch.zeros(channels, 1))
-
-    def forward(self, x, x_mask, reverse=False, **kwargs):
-        if not reverse:
-            y = self.m + torch.exp(self.logs) * x
-            y = y * x_mask
-            logdet = torch.sum(self.logs * x_mask, [1, 2])
-            return y, logdet
-        else:
-            x = (x - self.m) * torch.exp(-self.logs) * x_mask
-            return x
-
-
 class ResidualCouplingLayer(nn.Module):
     def __init__(
         self,
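With the per-file copies removed, `get_padding` is defined once, in hifigan_generator.py. The formula `(kernel_size * dilation - dilation) / 2` gives "same" padding for odd kernel sizes; a quick sanity check:

    from TTS.vocoder.models.hifigan_generator import get_padding

    assert get_padding(3) == 1  # plain 3-tap conv, dilation 1
    assert get_padding(5, dilation=2) == 4
    assert get_padding(7, dilation=3) == 9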
diff --git a/TTS/vocoder/models/hifigan_discriminator.py b/TTS/vocoder/models/hifigan_discriminator.py
index 7447a5fb..1cbc6ab3 100644
--- a/TTS/vocoder/models/hifigan_discriminator.py
+++ b/TTS/vocoder/models/hifigan_discriminator.py
@@ -3,6 +3,8 @@ import torch
 from torch import nn
 from torch.nn import functional as F

+from TTS.vocoder.models.hifigan_generator import get_padding
+
 LRELU_SLOPE = 0.1

@@ -29,7 +31,6 @@ class DiscriminatorP(torch.nn.Module):
     def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
         super().__init__()
         self.period = period
-        get_padding = lambda k, d: int((k * d - d) / 2)
         norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.parametrizations.weight_norm
         self.convs = nn.ModuleList(
             [
diff --git a/TTS/vocoder/models/hifigan_generator.py b/TTS/vocoder/models/hifigan_generator.py
index b9561f6f..083ce344 100644
--- a/TTS/vocoder/models/hifigan_generator.py
+++ b/TTS/vocoder/models/hifigan_generator.py
@@ -15,8 +15,8 @@ logger = logging.getLogger(__name__)
 LRELU_SLOPE = 0.1


-def get_padding(k, d):
-    return int((k * d - d) / 2)
+def get_padding(kernel_size: int, dilation: int = 1) -> int:
+    return int((kernel_size * dilation - dilation) / 2)


 class ResBlock1(torch.nn.Module):
diff --git a/tests/tts_tests/test_helpers.py b/tests/tts_tests/test_helpers.py
index 23bb440a..d07efa36 100644
--- a/tests/tts_tests/test_helpers.py
+++ b/tests/tts_tests/test_helpers.py
@@ -3,7 +3,7 @@ import torch as T
 from TTS.tts.utils.helpers import average_over_durations, generate_path, rand_segments, segment, sequence_mask


-def average_over_durations_test():  # pylint: disable=no-self-use
+def test_average_over_durations():  # pylint: disable=no-self-use
     pitch = T.rand(1, 1, 128)

     durations = T.randint(1, 5, (1, 21))
@@ -21,7 +21,7 @@ def average_over_durations_test():  # pylint: disable=no-self-use
         index += dur


-def seqeunce_mask_test():
+def test_sequence_mask():
     lengths = T.randint(10, 15, (8,))
     mask = sequence_mask(lengths)
     for i in range(8):
@@ -30,8 +30,8 @@ def seqeunce_mask_test():
         assert mask[i, l:].sum() == 0


-def segment_test():
-    x = T.range(0, 11)
+def test_segment():
+    x = T.arange(0, 12)
     x = x.repeat(8, 1).unsqueeze(1)
     segment_ids = T.randint(0, 7, (8,))

@@ -50,11 +50,11 @@ def segment_test():
         assert x[idx, :, start_indx : start_indx + 10].sum() == segments[idx, :, :].sum()


-def rand_segments_test():
+def test_rand_segments():
     x = T.rand(2, 3, 4)
     x_lens = T.randint(3, 4, (2,))
-    segments, seg_idxs = rand_segments(x, x_lens, segment_size=3)
-    assert segments.shape == (2, 3, 3)
+    segments, seg_idxs = rand_segments(x, x_lens, segment_size=2)
+    assert segments.shape == (2, 3, 2)
     assert all(seg_idxs >= 0), seg_idxs
     try:
         segments, _ = rand_segments(x, x_lens, segment_size=5)
@@ -68,10 +68,10 @@
     assert all(x_lens_back == x_lens)


-def generate_path_test():
+def test_generate_path():
     durations = T.randint(1, 4, (10, 21))
     x_length = T.randint(18, 22, (10,))
-    x_mask = sequence_mask(x_length).unsqueeze(1).long()
+    x_mask = sequence_mask(x_length, max_len=21).unsqueeze(1).long()
     durations = durations * x_mask.squeeze(1)
     y_length = durations.sum(1)
     y_mask = sequence_mask(y_length).unsqueeze(1).long()
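A closing note on the test changes: pytest only collects functions matching its default `test_*` pattern, so the old `*_test` names were silently skipped; renaming them puts these checks back into the suite. The `test_segment` fix also swaps the long-deprecated `torch.range` (inclusive end) for `torch.arange` (exclusive end):

    import torch as T

    # T.range(0, 11) yielded 12 elements and is deprecated;
    # T.arange(0, 12) is the exact replacement.
    x = T.arange(0, 12)
    assert x.shape == (12,)
    assert x[0] == 0 and x[-1] == 11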