Merge pull request #49 from idiap/vc-refactors

VC-related refactors and fixes
Enno Hermann 2024-06-26 14:01:21 +01:00 committed by GitHub
commit ff2cd5c97d
18 changed files with 76 additions and 184 deletions

View File

@@ -161,9 +161,6 @@ def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader,
         loader_time = time.time() - end_time
         global_step += 1

-        # setup lr
-        if c.lr_decay:
-            scheduler.step()
         optimizer.zero_grad()

         # dispatch data to GPU
@@ -182,6 +179,10 @@ def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader,
         grad_norm, _ = check_update(model, c.grad_clip)
         optimizer.step()

+        # setup lr
+        if c.lr_decay:
+            scheduler.step()
+
         step_time = time.time() - start_time
         epoch_time += step_time
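
Note: since PyTorch 1.1 the scheduler step belongs after the optimizer step; calling it earlier skips the first scheduled learning rate and triggers a UserWarning. A minimal, self-contained sketch of the ordering this hunk enforces (illustrative names, not the project's train script):

```python
import torch

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)

for _ in range(3):
    optimizer.zero_grad()
    loss = model(torch.randn(4, 10)).sum()
    loss.backward()
    optimizer.step()
    scheduler.step()  # after optimizer.step(), mirroring the refactored loop
```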

View File

@@ -55,7 +55,7 @@ class EncoderDataset(Dataset):
             logger.info(" | Number of instances: %d", len(self.items))
             logger.info(" | Sequence length: %d", self.seq_len)
             logger.info(" | Number of classes: %d", len(self.classes))
-            logger.info(" | Classes: %d", self.classes)
+            logger.info(" | Classes: %s", self.classes)

     def load_wav(self, filename):
         audio = self.ap.load_wav(filename, sr=self.ap.sample_rate)
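
Note: `%d` only formats numbers; when the list of classes is passed, the logging module reports a formatting error and drops the record, so `%s` is the right placeholder here. A small illustrative sketch:

```python
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

classes = ["speaker_a", "speaker_b"]
logger.info(" | Number of classes: %d", len(classes))  # %d expects a number
logger.info(" | Classes: %s", classes)                 # %s renders the list itself
# logger.info(" | Classes: %d", classes) would not crash the program, but logging
# prints "--- Logging error ---" to stderr and the message is lost.
```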

View File

@@ -1,5 +1,6 @@
+import os
 from abc import abstractmethod
-from typing import Dict
+from typing import Any, Union

 import torch
 from coqpit import Coqpit
@@ -16,7 +17,7 @@ class BaseTrainerModel(TrainerModel):

     @staticmethod
     @abstractmethod
-    def init_from_config(config: Coqpit):
+    def init_from_config(config: Coqpit) -> "BaseTrainerModel":
         """Init the model and all its attributes from the given config.

         Override this depending on your model.
@@ -24,7 +25,7 @@ class BaseTrainerModel(TrainerModel):
         ...

     @abstractmethod
-    def inference(self, input: torch.Tensor, aux_input={}) -> Dict:
+    def inference(self, input: torch.Tensor, aux_input: dict[str, Any] = {}) -> dict[str, Any]:
         """Forward pass for inference.

         It must return a dictionary with the main model output and all the auxiliary outputs. The key ```model_outputs```
@@ -45,13 +46,18 @@ class BaseTrainerModel(TrainerModel):
     @abstractmethod
     def load_checkpoint(
-        self, config: Coqpit, checkpoint_path: str, eval: bool = False, strict: bool = True, cache=False
+        self,
+        config: Coqpit,
+        checkpoint_path: Union[str, os.PathLike[Any]],
+        eval: bool = False,
+        strict: bool = True,
+        cache: bool = False,
     ) -> None:
-        """Load a model checkpoint gile and get ready for training or inference.
+        """Load a model checkpoint file and get ready for training or inference.

         Args:
             config (Coqpit): Model configuration.
-            checkpoint_path (str): Path to the model checkpoint file.
+            checkpoint_path (str | os.PathLike): Path to the model checkpoint file.
             eval (bool, optional): If true, init model for inference else for training. Defaults to False.
             strict (bool, optional): Match all checkpoint keys to model's keys. Defaults to True.
             cache (bool, optional): If True, cache the file locally for subsequent calls. It is cached under `get_user_data_dir()/tts_cache`. Defaults to False.

View File

@@ -5,6 +5,7 @@ from torch import nn
 from torch.nn import functional as F

 from TTS.tts.layers.generic.normalization import LayerNorm, LayerNorm2
+from TTS.tts.utils.helpers import convert_pad_shape


 class RelativePositionMultiHeadAttention(nn.Module):
@@ -300,7 +301,7 @@ class FeedForwardNetwork(nn.Module):
         pad_l = self.kernel_size - 1
         pad_r = 0
         padding = [[0, 0], [0, 0], [pad_l, pad_r]]
-        x = F.pad(x, self._pad_shape(padding))
+        x = F.pad(x, convert_pad_shape(padding))
         return x

     def _same_padding(self, x):
@@ -309,15 +310,9 @@ class FeedForwardNetwork(nn.Module):
         pad_l = (self.kernel_size - 1) // 2
         pad_r = self.kernel_size // 2
         padding = [[0, 0], [0, 0], [pad_l, pad_r]]
-        x = F.pad(x, self._pad_shape(padding))
+        x = F.pad(x, convert_pad_shape(padding))
         return x

-    @staticmethod
-    def _pad_shape(padding):
-        l = padding[::-1]
-        pad_shape = [item for sublist in l for item in sublist]
-        return pad_shape
-

 class RelativePositionTransformer(nn.Module):
     """Transformer with Relative Potional Encoding.

View File

@@ -255,7 +255,7 @@ class GuidedAttentionLoss(torch.nn.Module):
     @staticmethod
     def _make_ga_mask(ilen, olen, sigma):
-        grid_x, grid_y = torch.meshgrid(torch.arange(olen).to(olen), torch.arange(ilen).to(ilen))
+        grid_x, grid_y = torch.meshgrid(torch.arange(olen).to(olen), torch.arange(ilen).to(ilen), indexing="ij")
         grid_x, grid_y = grid_x.float(), grid_y.float()
         return 1.0 - torch.exp(-((grid_y / ilen - grid_x / olen) ** 2) / (2 * (sigma**2)))
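
Note: recent PyTorch versions warn when `torch.meshgrid` is called without `indexing`; passing `indexing="ij"` keeps the historical matrix-style behavior, so the guided-attention mask is unchanged. Roughly:

```python
import torch

olen, ilen = 4, 3
grid_x, grid_y = torch.meshgrid(torch.arange(olen), torch.arange(ilen), indexing="ij")
print(grid_x.shape, grid_y.shape)  # torch.Size([4, 3]) torch.Size([4, 3]), same as the old default
# indexing="xy" would instead return transposed [3, 4] grids and change the mask.
```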

View File

@@ -128,7 +128,8 @@ class NeuralHMM(nn.Module):
             # Get mean, std and transition vector from decoder for this timestep
             # Note: Gradient checkpointing currently doesn't works with multiple gpus inside a loop
             if self.use_grad_checkpointing and self.training:
-                mean, std, transition_vector = checkpoint(self.output_net, h_memory, inputs)
+                # TODO: use_reentrant=False is recommended
+                mean, std, transition_vector = checkpoint(self.output_net, h_memory, inputs, use_reentrant=True)
             else:
                 mean, std, transition_vector = self.output_net(h_memory, inputs)
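
Note: `use_reentrant=True` only makes the legacy default explicit (newer PyTorch warns when the flag is omitted); the TODO points at the non-reentrant variant PyTorch now recommends. A small standalone usage sketch (not the HMM decoder itself):

```python
import torch
from torch.utils.checkpoint import checkpoint

net = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU(), torch.nn.Linear(8, 8))
x = torch.randn(4, 8, requires_grad=True)

# True keeps the legacy reentrant behavior used in this diff;
# False is the variant referenced by the TODO above.
y = checkpoint(net, x, use_reentrant=True)
y.sum().backward()
```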

View File

@@ -10,22 +10,6 @@ from TTS.tts.utils.helpers import sequence_mask

 LRELU_SLOPE = 0.1


-def convert_pad_shape(pad_shape):
-    l = pad_shape[::-1]
-    pad_shape = [item for sublist in l for item in sublist]
-    return pad_shape
-
-
-def init_weights(m, mean=0.0, std=0.01):
-    classname = m.__class__.__name__
-    if classname.find("Conv") != -1:
-        m.weight.data.normal_(mean, std)
-
-
-def get_padding(kernel_size, dilation=1):
-    return int((kernel_size * dilation - dilation) / 2)
-
-
 class TextEncoder(nn.Module):
     def __init__(
         self,

View File

@@ -9,16 +9,13 @@ from torch.nn.utils.parametrizations import weight_norm
 from torch.nn.utils.parametrize import remove_parametrizations

 from TTS.utils.io import load_fsspec
+from TTS.vocoder.models.hifigan_generator import get_padding

 logger = logging.getLogger(__name__)

 LRELU_SLOPE = 0.1


-def get_padding(k, d):
-    return int((k * d - d) / 2)
-
-
 class ResBlock1(torch.nn.Module):
     """Residual Block Type 1. It has 3 convolutional layers in each convolutional block.

View File

@@ -144,7 +144,7 @@ class BaseTTS(BaseTrainerModel):
             if speaker_name is None:
                 d_vector = self.speaker_manager.get_random_embedding()
             else:
-                d_vector = self.speaker_manager.get_d_vector_by_name(speaker_name)
+                d_vector = self.speaker_manager.get_mean_embedding(speaker_name)
         elif config.use_speaker_embedding:
             if speaker_name is None:
                 speaker_id = self.speaker_manager.get_random_id()

View File

@@ -88,12 +88,6 @@ def pad(input_ele: List[torch.Tensor], max_len: int) -> torch.Tensor:
     return out_padded


-def init_weights(m: nn.Module, mean: float = 0.0, std: float = 0.01):
-    classname = m.__class__.__name__
-    if classname.find("Conv") != -1:
-        m.weight.data.normal_(mean, std)
-
-
 def stride_lens(lens: torch.Tensor, stride: int = 2) -> torch.Tensor:
     return torch.ceil(lens / stride).int()

View File

@@ -145,10 +145,9 @@ def average_over_durations(values, durs):
     return avg


-def convert_pad_shape(pad_shape):
+def convert_pad_shape(pad_shape: list[list]) -> list:
     l = pad_shape[::-1]
-    pad_shape = [item for sublist in l for item in sublist]
-    return pad_shape
+    return [item for sublist in l for item in sublist]


 def generate_path(duration, mask):

View File

@@ -1,7 +1,7 @@
 import logging
 import os
 import random
-from typing import Dict, List, Tuple, Union
+from typing import Any, Optional, Union

 import torch
 import torch.distributed as dist
@@ -10,6 +10,7 @@ from torch import nn
 from torch.utils.data import DataLoader
 from torch.utils.data.sampler import WeightedRandomSampler
 from trainer.torch import DistributedSampler, DistributedSamplerWrapper
+from trainer.trainer import Trainer

 from TTS.model import BaseTrainerModel
 from TTS.tts.datasets.dataset import TTSDataset
@@ -18,6 +19,7 @@ from TTS.tts.utils.languages import LanguageManager, get_language_balancer_weigh
 from TTS.tts.utils.speakers import SpeakerManager, get_speaker_balancer_weights
 from TTS.tts.utils.synthesis import synthesis
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
+from TTS.utils.audio.processor import AudioProcessor

 # pylint: skip-file
@@ -35,10 +37,10 @@ class BaseVC(BaseTrainerModel):
     def __init__(
         self,
         config: Coqpit,
-        ap: "AudioProcessor",
-        speaker_manager: SpeakerManager = None,
-        language_manager: LanguageManager = None,
-    ):
+        ap: AudioProcessor,
+        speaker_manager: Optional[SpeakerManager] = None,
+        language_manager: Optional[LanguageManager] = None,
+    ) -> None:
         super().__init__()
         self.config = config
         self.ap = ap
@@ -46,7 +48,7 @@ class BaseVC(BaseTrainerModel):
         self.language_manager = language_manager
         self._set_model_args(config)

-    def _set_model_args(self, config: Coqpit):
+    def _set_model_args(self, config: Coqpit) -> None:
         """Setup model args based on the config type (`ModelConfig` or `ModelArgs`).

         `ModelArgs` has all the fields reuqired to initialize the model architecture.
@@ -67,7 +69,7 @@ class BaseVC(BaseTrainerModel):
         else:
             raise ValueError("config must be either a *Config or *Args")

-    def init_multispeaker(self, config: Coqpit, data: List = None):
+    def init_multispeaker(self, config: Coqpit, data: Optional[list[Any]] = None) -> None:
         """Initialize a speaker embedding layer if needen and define expected embedding channel size for defining
         `in_channels` size of the connected layers.
@@ -100,11 +102,11 @@ class BaseVC(BaseTrainerModel):
             self.speaker_embedding = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
             self.speaker_embedding.weight.data.normal_(0, 0.3)

-    def get_aux_input(self, **kwargs) -> Dict:
+    def get_aux_input(self, **kwargs: Any) -> dict[str, Any]:
         """Prepare and return `aux_input` used by `forward()`"""
         return {"speaker_id": None, "style_wav": None, "d_vector": None, "language_id": None}

-    def get_aux_input_from_test_sentences(self, sentence_info):
+    def get_aux_input_from_test_sentences(self, sentence_info: Union[str, list[str]]) -> dict[str, Any]:
         if hasattr(self.config, "model_args"):
             config = self.config.model_args
         else:
@@ -132,7 +134,7 @@ class BaseVC(BaseTrainerModel):
             if speaker_name is None:
                 d_vector = self.speaker_manager.get_random_embedding()
             else:
-                d_vector = self.speaker_manager.get_d_vector_by_name(speaker_name)
+                d_vector = self.speaker_manager.get_mean_embedding(speaker_name)
         elif config.use_speaker_embedding:
             if speaker_name is None:
                 speaker_id = self.speaker_manager.get_random_id()
@@ -151,16 +153,16 @@ class BaseVC(BaseTrainerModel):
             "language_id": language_id,
         }

-    def format_batch(self, batch: Dict) -> Dict:
+    def format_batch(self, batch: dict[str, Any]) -> dict[str, Any]:
         """Generic batch formatting for `VCDataset`.

         You must override this if you use a custom dataset.

         Args:
-            batch (Dict): [description]
+            batch (dict): [description]

         Returns:
-            Dict: [description]
+            dict: [description]
         """
         # setup input batch
         text_input = batch["token_id"]
@@ -230,7 +232,7 @@ class BaseVC(BaseTrainerModel):
             "audio_unique_names": batch["audio_unique_names"],
         }

-    def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus=1):
+    def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus: int = 1):
         weights = None
         data_items = dataset.samples
@@ -271,12 +273,12 @@ class BaseVC(BaseTrainerModel):
     def get_data_loader(
         self,
         config: Coqpit,
-        assets: Dict,
+        assets: dict,
         is_eval: bool,
-        samples: Union[List[Dict], List[List]],
+        samples: Union[list[dict], list[list]],
         verbose: bool,
         num_gpus: int,
-        rank: int = None,
+        rank: Optional[int] = None,
     ) -> "DataLoader":
         if is_eval and not config.run_eval:
             loader = None
@@ -352,9 +354,9 @@ class BaseVC(BaseTrainerModel):
     def _get_test_aux_input(
         self,
-    ) -> Dict:
+    ) -> dict[str, Any]:
         d_vector = None
-        if self.config.use_d_vector_file:
+        if self.speaker_manager is not None and self.config.use_d_vector_file:
             d_vector = [self.speaker_manager.embeddings[name]["embedding"] for name in self.speaker_manager.embeddings]
             d_vector = (random.sample(sorted(d_vector), 1),)
@@ -369,7 +371,7 @@ class BaseVC(BaseTrainerModel):
         }
         return aux_inputs

-    def test_run(self, assets: Dict) -> Tuple[Dict, Dict]:
+    def test_run(self, assets: dict) -> tuple[dict, dict]:
         """Generic test run for `vc` models used by `Trainer`.

         You can override this for a different behaviour.
@@ -378,7 +380,7 @@ class BaseVC(BaseTrainerModel):
             assets (dict): A dict of training assets. For `vc` models, it must include `{'audio_processor': ap}`.

         Returns:
-            Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
+            tuple[dict, dict]: Test figures and audios to be projected to Tensorboard.
         """
         logger.info("Synthesizing test sentences.")
         test_audios = {}
@@ -409,7 +411,7 @@ class BaseVC(BaseTrainerModel):
         )
         return test_figures, test_audios

-    def on_init_start(self, trainer):
+    def on_init_start(self, trainer: Trainer) -> None:
         """Save the speaker.pth and language_ids.json at the beginning of the training. Also update both paths."""
         if self.speaker_manager is not None:
             output_path = os.path.join(trainer.output_path, "speakers.pth")

View File

@@ -14,14 +14,16 @@ from torch.nn.utils.parametrize import remove_parametrizations
 import TTS.vc.modules.freevc.commons as commons
 import TTS.vc.modules.freevc.modules as modules
+from TTS.tts.utils.helpers import sequence_mask
 from TTS.tts.utils.speakers import SpeakerManager
 from TTS.utils.io import load_fsspec
 from TTS.vc.configs.freevc_config import FreeVCConfig
 from TTS.vc.models.base_vc import BaseVC
-from TTS.vc.modules.freevc.commons import get_padding, init_weights
+from TTS.vc.modules.freevc.commons import init_weights
 from TTS.vc.modules.freevc.mel_processing import mel_spectrogram_torch
 from TTS.vc.modules.freevc.speaker_encoder.speaker_encoder import SpeakerEncoder as SpeakerEncoderEx
 from TTS.vc.modules.freevc.wavlm import get_wavlm
+from TTS.vocoder.models.hifigan_generator import get_padding

 logger = logging.getLogger(__name__)
@@ -80,7 +82,7 @@ class Encoder(nn.Module):
         self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

     def forward(self, x, x_lengths, g=None):
-        x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
+        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
         x = self.pre(x) * x_mask
         x = self.enc(x, x_mask, g=g)
         stats = self.proj(x) * x_mask

View File

@@ -3,23 +3,15 @@ import math
 import torch
 from torch.nn import functional as F

+from TTS.tts.utils.helpers import convert_pad_shape, sequence_mask
+

-def init_weights(m, mean=0.0, std=0.01):
+def init_weights(m: torch.nn.Module, mean: float = 0.0, std: float = 0.01) -> None:
     classname = m.__class__.__name__
     if classname.find("Conv") != -1:
         m.weight.data.normal_(mean, std)


-def get_padding(kernel_size, dilation=1):
-    return int((kernel_size * dilation - dilation) / 2)
-
-
-def convert_pad_shape(pad_shape):
-    l = pad_shape[::-1]
-    pad_shape = [item for sublist in l for item in sublist]
-    return pad_shape
-
-
 def intersperse(lst, item):
     result = [item] * (len(lst) * 2 + 1)
     result[1::2] = lst
@@ -119,20 +111,11 @@ def shift_1d(x):
     return x


-def sequence_mask(length, max_length=None):
-    if max_length is None:
-        max_length = length.max()
-    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
-    return x.unsqueeze(0) < length.unsqueeze(1)
-
-
 def generate_path(duration, mask):
     """
     duration: [b, 1, t_x]
     mask: [b, 1, t_y, t_x]
     """
-    device = duration.device
     b, _, t_y, t_x = mask.shape
     cum_duration = torch.cumsum(duration, -1)

View File

@@ -6,26 +6,13 @@ from torch.nn.utils.parametrizations import weight_norm
 from torch.nn.utils.parametrize import remove_parametrizations

 import TTS.vc.modules.freevc.commons as commons
-from TTS.vc.modules.freevc.commons import get_padding, init_weights
+from TTS.tts.layers.generic.normalization import LayerNorm2
+from TTS.vc.modules.freevc.commons import init_weights
+from TTS.vocoder.models.hifigan_generator import get_padding

 LRELU_SLOPE = 0.1


-class LayerNorm(nn.Module):
-    def __init__(self, channels, eps=1e-5):
-        super().__init__()
-        self.channels = channels
-        self.eps = eps
-        self.gamma = nn.Parameter(torch.ones(channels))
-        self.beta = nn.Parameter(torch.zeros(channels))
-
-    def forward(self, x):
-        x = x.transpose(1, -1)
-        x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
-        return x.transpose(1, -1)
-
-
 class ConvReluNorm(nn.Module):
     def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
         super().__init__()
@@ -40,11 +27,11 @@ class ConvReluNorm(nn.Module):
         self.conv_layers = nn.ModuleList()
         self.norm_layers = nn.ModuleList()
         self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
-        self.norm_layers.append(LayerNorm(hidden_channels))
+        self.norm_layers.append(LayerNorm2(hidden_channels))
         self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
         for _ in range(n_layers - 1):
             self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
-            self.norm_layers.append(LayerNorm(hidden_channels))
+            self.norm_layers.append(LayerNorm2(hidden_channels))
         self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
         self.proj.weight.data.zero_()
         self.proj.bias.data.zero_()
@@ -59,48 +46,6 @@ class ConvReluNorm(nn.Module):
         return x * x_mask


-class DDSConv(nn.Module):
-    """
-    Dialted and Depth-Separable Convolution
-    """
-
-    def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
-        super().__init__()
-        self.channels = channels
-        self.kernel_size = kernel_size
-        self.n_layers = n_layers
-        self.p_dropout = p_dropout
-        self.drop = nn.Dropout(p_dropout)
-        self.convs_sep = nn.ModuleList()
-        self.convs_1x1 = nn.ModuleList()
-        self.norms_1 = nn.ModuleList()
-        self.norms_2 = nn.ModuleList()
-        for i in range(n_layers):
-            dilation = kernel_size**i
-            padding = (kernel_size * dilation - dilation) // 2
-            self.convs_sep.append(
-                nn.Conv1d(channels, channels, kernel_size, groups=channels, dilation=dilation, padding=padding)
-            )
-            self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
-            self.norms_1.append(LayerNorm(channels))
-            self.norms_2.append(LayerNorm(channels))
-
-    def forward(self, x, x_mask, g=None):
-        if g is not None:
-            x = x + g
-        for i in range(self.n_layers):
-            y = self.convs_sep[i](x * x_mask)
-            y = self.norms_1[i](y)
-            y = F.gelu(y)
-            y = self.convs_1x1[i](y)
-            y = self.norms_2[i](y)
-            y = F.gelu(y)
-            y = self.drop(y)
-            x = x + y
-        return x * x_mask
-
-
 class WN(torch.nn.Module):
     def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0):
         super(WN, self).__init__()
@@ -317,24 +262,6 @@ class Flip(nn.Module):
         return x


-class ElementwiseAffine(nn.Module):
-    def __init__(self, channels):
-        super().__init__()
-        self.channels = channels
-        self.m = nn.Parameter(torch.zeros(channels, 1))
-        self.logs = nn.Parameter(torch.zeros(channels, 1))
-
-    def forward(self, x, x_mask, reverse=False, **kwargs):
-        if not reverse:
-            y = self.m + torch.exp(self.logs) * x
-            y = y * x_mask
-            logdet = torch.sum(self.logs * x_mask, [1, 2])
-            return y, logdet
-        else:
-            x = (x - self.m) * torch.exp(-self.logs) * x_mask
-            return x
-
-
 class ResidualCouplingLayer(nn.Module):
     def __init__(
         self,

View File

@@ -3,6 +3,8 @@ import torch
 from torch import nn
 from torch.nn import functional as F

+from TTS.vocoder.models.hifigan_generator import get_padding
+
 LRELU_SLOPE = 0.1
@@ -29,7 +31,6 @@ class DiscriminatorP(torch.nn.Module):
     def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
         super().__init__()
         self.period = period
-        get_padding = lambda k, d: int((k * d - d) / 2)
         norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.parametrizations.weight_norm
         self.convs = nn.ModuleList(
             [

View File

@@ -15,8 +15,8 @@ logger = logging.getLogger(__name__)
 LRELU_SLOPE = 0.1


-def get_padding(k, d):
-    return int((k * d - d) / 2)
+def get_padding(kernel_size: int, dilation: int = 1) -> int:
+    return int((kernel_size * dilation - dilation) / 2)


 class ResBlock1(torch.nn.Module):
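
Note: with the renamed arguments it is clearer that this returns the "same" padding for a stride-1, possibly dilated convolution with an odd kernel, which is why the duplicated copies elsewhere can simply import it. For instance:

```python
def get_padding(kernel_size: int, dilation: int = 1) -> int:
    return int((kernel_size * dilation - dilation) / 2)

# "same" padding for stride-1 convolutions with odd kernels:
print(get_padding(3))     # 1
print(get_padding(5))     # 2
print(get_padding(3, 2))  # 2 -> effective kernel size 5, so pad 2 on each side
```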

View File

@@ -3,7 +3,7 @@ import torch as T
 from TTS.tts.utils.helpers import average_over_durations, generate_path, rand_segments, segment, sequence_mask


-def average_over_durations_test():  # pylint: disable=no-self-use
+def test_average_over_durations():  # pylint: disable=no-self-use
     pitch = T.rand(1, 1, 128)

     durations = T.randint(1, 5, (1, 21))
@@ -21,7 +21,7 @@ def average_over_durations_test():  # pylint: disable=no-self-use
         index += dur


-def seqeunce_mask_test():
+def test_sequence_mask():
     lengths = T.randint(10, 15, (8,))
     mask = sequence_mask(lengths)
     for i in range(8):
@@ -30,8 +30,8 @@ def seqeunce_mask_test():
         assert mask[i, l:].sum() == 0


-def segment_test():
-    x = T.range(0, 11)
+def test_segment():
+    x = T.arange(0, 12)
     x = x.repeat(8, 1).unsqueeze(1)
     segment_ids = T.randint(0, 7, (8,))
@@ -50,11 +50,11 @@ def segment_test():
     assert x[idx, :, start_indx : start_indx + 10].sum() == segments[idx, :, :].sum()


-def rand_segments_test():
+def test_rand_segments():
     x = T.rand(2, 3, 4)
     x_lens = T.randint(3, 4, (2,))
-    segments, seg_idxs = rand_segments(x, x_lens, segment_size=3)
-    assert segments.shape == (2, 3, 3)
+    segments, seg_idxs = rand_segments(x, x_lens, segment_size=2)
+    assert segments.shape == (2, 3, 2)
     assert all(seg_idxs >= 0), seg_idxs
     try:
         segments, _ = rand_segments(x, x_lens, segment_size=5)
@@ -68,10 +68,10 @@ def rand_segments_test():
     assert all(x_lens_back == x_lens)


-def generate_path_test():
+def test_generate_path():
     durations = T.randint(1, 4, (10, 21))
     x_length = T.randint(18, 22, (10,))
-    x_mask = sequence_mask(x_length).unsqueeze(1).long()
+    x_mask = sequence_mask(x_length, max_len=21).unsqueeze(1).long()
     durations = durations * x_mask.squeeze(1)
     y_length = durations.sum(1)
     y_mask = sequence_mask(y_length).unsqueeze(1).long()
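
Note: the renamed functions pick up the `test_` prefix so pytest actually collects them, `T.arange(0, 12)` replaces the deprecated `T.range(0, 11)` with the same 12 values, and pinning `max_len=21` makes the mask width match the `(10, 21)` durations tensor even when every sampled length is shorter. An illustrative re-implementation of the masking semantics the test relies on (the real helper lives in `TTS.tts.utils.helpers`):

```python
import torch

def sequence_mask(lengths, max_len=None):
    # sketch of the assumed semantics: True for positions below each length
    if max_len is None:
        max_len = int(lengths.max())
    ids = torch.arange(max_len, device=lengths.device)
    return ids.unsqueeze(0) < lengths.unsqueeze(1)

lengths = torch.tensor([18, 20])
print(sequence_mask(lengths).shape)              # torch.Size([2, 20]) - width follows the batch max
print(sequence_mask(lengths, max_len=21).shape)  # torch.Size([2, 21]) - width matches the durations
```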