mirror of https://github.com/coqui-ai/TTS.git
Merge pull request #49 from idiap/vc-refactors
VC-related refactors and fixes
This commit is contained in:
commit ff2cd5c97d
@@ -161,9 +161,6 @@ def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader,
         loader_time = time.time() - end_time
         global_step += 1
 
-        # setup lr
-        if c.lr_decay:
-            scheduler.step()
         optimizer.zero_grad()
 
         # dispatch data to GPU
@@ -182,6 +179,10 @@ def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader,
         grad_norm, _ = check_update(model, c.grad_clip)
         optimizer.step()
 
+        # setup lr
+        if c.lr_decay:
+            scheduler.step()
+
         step_time = time.time() - start_time
         epoch_time += step_time
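Note: these two hunks move the scheduler step from before the optimizer update to after it. Since PyTorch 1.1, `scheduler.step()` is expected to run after `optimizer.step()`; the old order skips the first value of the schedule and triggers a UserWarning. A minimal sketch of the corrected pattern (illustrative model and schedule, not the trainer's actual code):

    import torch

    model = torch.nn.Linear(10, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

    for step in range(100):
        optimizer.zero_grad()
        loss = model(torch.randn(4, 10)).sum()
        loss.backward()
        optimizer.step()
        scheduler.step()  # step the schedule only after the optimizer update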
@@ -55,7 +55,7 @@ class EncoderDataset(Dataset):
             logger.info(" | Number of instances: %d", len(self.items))
             logger.info(" | Sequence length: %d", self.seq_len)
             logger.info(" | Number of classes: %d", len(self.classes))
-            logger.info(" | Classes: %d", self.classes)
+            logger.info(" | Classes: %s", self.classes)
 
     def load_wav(self, filename):
         audio = self.ap.load_wav(filename, sr=self.ap.sample_rate)
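Note: `self.classes` is a list, so the old `%d` placeholder raises `TypeError: %d format: a number is required` once the record is rendered; `%s` is the right specifier. With lazy `%`-style logging, arguments are only interpolated when the record is actually emitted:

    import logging

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    classes = ["speaker_a", "speaker_b"]
    logger.info(" | Number of classes: %d", len(classes))  # an int works with %d
    logger.info(" | Classes: %s", classes)  # a list needs %s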
TTS/model.py (18 lines changed)
@@ -1,5 +1,6 @@
+import os
 from abc import abstractmethod
-from typing import Dict
+from typing import Any, Union
 
 import torch
 from coqpit import Coqpit
@@ -16,7 +17,7 @@ class BaseTrainerModel(TrainerModel):
 
     @staticmethod
     @abstractmethod
-    def init_from_config(config: Coqpit):
+    def init_from_config(config: Coqpit) -> "BaseTrainerModel":
         """Init the model and all its attributes from the given config.
 
         Override this depending on your model.
@@ -24,7 +25,7 @@ class BaseTrainerModel(TrainerModel):
         ...
 
     @abstractmethod
-    def inference(self, input: torch.Tensor, aux_input={}) -> Dict:
+    def inference(self, input: torch.Tensor, aux_input: dict[str, Any] = {}) -> dict[str, Any]:
         """Forward pass for inference.
 
         It must return a dictionary with the main model output and all the auxiliary outputs. The key ```model_outputs```
@@ -45,13 +46,18 @@ class BaseTrainerModel(TrainerModel):
 
     @abstractmethod
     def load_checkpoint(
-        self, config: Coqpit, checkpoint_path: str, eval: bool = False, strict: bool = True, cache=False
+        self,
+        config: Coqpit,
+        checkpoint_path: Union[str, os.PathLike[Any]],
+        eval: bool = False,
+        strict: bool = True,
+        cache: bool = False,
     ) -> None:
-        """Load a model checkpoint gile and get ready for training or inference.
+        """Load a model checkpoint file and get ready for training or inference.
 
         Args:
             config (Coqpit): Model configuration.
-            checkpoint_path (str): Path to the model checkpoint file.
+            checkpoint_path (str | os.PathLike): Path to the model checkpoint file.
             eval (bool, optional): If true, init model for inference else for training. Defaults to False.
             strict (bool, optional): Match all checkpoint keys to model's keys. Defaults to True.
             cache (bool, optional): If True, cache the file locally for subsequent calls. It is cached under `get_user_data_dir()/tts_cache`. Defaults to False.
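Note: widening `checkpoint_path` to `Union[str, os.PathLike[Any]]` lets callers pass `pathlib.Path` objects as well as plain strings. A minimal sketch of why this matters (the `resolve_checkpoint` helper is illustrative, not part of the PR):

    import os
    from pathlib import Path
    from typing import Any, Union

    def resolve_checkpoint(path: Union[str, os.PathLike[Any]]) -> str:
        # os.fspath accepts both spellings and normalizes to str
        return os.fspath(path)

    print(resolve_checkpoint("best_model.pth"))
    print(resolve_checkpoint(Path("runs") / "best_model.pth"))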
@@ -5,6 +5,7 @@ from torch import nn
 from torch.nn import functional as F
 
 from TTS.tts.layers.generic.normalization import LayerNorm, LayerNorm2
+from TTS.tts.utils.helpers import convert_pad_shape
 
 
 class RelativePositionMultiHeadAttention(nn.Module):
@@ -300,7 +301,7 @@ class FeedForwardNetwork(nn.Module):
         pad_l = self.kernel_size - 1
         pad_r = 0
         padding = [[0, 0], [0, 0], [pad_l, pad_r]]
-        x = F.pad(x, self._pad_shape(padding))
+        x = F.pad(x, convert_pad_shape(padding))
         return x
 
     def _same_padding(self, x):
@@ -309,15 +310,9 @@ class FeedForwardNetwork(nn.Module):
         pad_l = (self.kernel_size - 1) // 2
         pad_r = self.kernel_size // 2
         padding = [[0, 0], [0, 0], [pad_l, pad_r]]
-        x = F.pad(x, self._pad_shape(padding))
+        x = F.pad(x, convert_pad_shape(padding))
         return x
 
-    @staticmethod
-    def _pad_shape(padding):
-        l = padding[::-1]
-        pad_shape = [item for sublist in l for item in sublist]
-        return pad_shape
-
 
 class RelativePositionTransformer(nn.Module):
     """Transformer with Relative Potional Encoding.
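Note: `FeedForwardNetwork._pad_shape` was a copy of the shared helper `TTS.tts.utils.helpers.convert_pad_shape`, which flattens a nested per-dimension pad spec into the flat, last-dimension-first list that `F.pad` expects; the duplicate is dropped in favor of the import added above. Reproduced for illustration:

    import torch
    import torch.nn.functional as F

    def convert_pad_shape(pad_shape):
        # [[0, 0], [0, 0], [pad_l, pad_r]] -> [pad_l, pad_r, 0, 0, 0, 0]
        l = pad_shape[::-1]
        return [item for sublist in l for item in sublist]

    x = torch.zeros(1, 2, 5)
    padding = [[0, 0], [0, 0], [2, 0]]  # causal padding on the time axis
    print(F.pad(x, convert_pad_shape(padding)).shape)  # torch.Size([1, 2, 7])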
@@ -255,7 +255,7 @@ class GuidedAttentionLoss(torch.nn.Module):
 
     @staticmethod
     def _make_ga_mask(ilen, olen, sigma):
-        grid_x, grid_y = torch.meshgrid(torch.arange(olen).to(olen), torch.arange(ilen).to(ilen))
+        grid_x, grid_y = torch.meshgrid(torch.arange(olen).to(olen), torch.arange(ilen).to(ilen), indexing="ij")
         grid_x, grid_y = grid_x.float(), grid_y.float()
         return 1.0 - torch.exp(-((grid_y / ilen - grid_x / olen) ** 2) / (2 * (sigma**2)))
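Note: since PyTorch 1.10, `torch.meshgrid` emits a UserWarning unless the `indexing` argument is passed explicitly. `indexing="ij"` matches the previous default, so the guided-attention mask is unchanged:

    import torch

    # explicit indexing keeps the old "ij" behaviour and silences the warning
    grid_x, grid_y = torch.meshgrid(torch.arange(3), torch.arange(4), indexing="ij")
    print(grid_x.shape, grid_y.shape)  # torch.Size([3, 4]) torch.Size([3, 4])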
@@ -128,7 +128,8 @@ class NeuralHMM(nn.Module):
             # Get mean, std and transition vector from decoder for this timestep
             # Note: Gradient checkpointing currently doesn't works with multiple gpus inside a loop
             if self.use_grad_checkpointing and self.training:
-                mean, std, transition_vector = checkpoint(self.output_net, h_memory, inputs)
+                # TODO: use_reentrant=False is recommended
+                mean, std, transition_vector = checkpoint(self.output_net, h_memory, inputs, use_reentrant=True)
             else:
                 mean, std, transition_vector = self.output_net(h_memory, inputs)
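Note: recent PyTorch releases want callers of `torch.utils.checkpoint.checkpoint` to pass `use_reentrant` explicitly. `use_reentrant=True` preserves the legacy behaviour this loop relied on; `use_reentrant=False` is the variant PyTorch recommends, as the TODO records. A minimal sketch:

    import torch
    from torch.utils.checkpoint import checkpoint

    net = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
    x = torch.randn(2, 8, requires_grad=True)

    # passing use_reentrant explicitly avoids the deprecation warning
    y = checkpoint(net, x, use_reentrant=False)
    y.sum().backward()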
@@ -10,22 +10,6 @@ from TTS.tts.utils.helpers import sequence_mask
 LRELU_SLOPE = 0.1
 
 
-def convert_pad_shape(pad_shape):
-    l = pad_shape[::-1]
-    pad_shape = [item for sublist in l for item in sublist]
-    return pad_shape
-
-
-def init_weights(m, mean=0.0, std=0.01):
-    classname = m.__class__.__name__
-    if classname.find("Conv") != -1:
-        m.weight.data.normal_(mean, std)
-
-
-def get_padding(kernel_size, dilation=1):
-    return int((kernel_size * dilation - dilation) / 2)
-
-
 class TextEncoder(nn.Module):
     def __init__(
         self,
@@ -9,16 +9,13 @@ from torch.nn.utils.parametrizations import weight_norm
 from torch.nn.utils.parametrize import remove_parametrizations
 
 from TTS.utils.io import load_fsspec
+from TTS.vocoder.models.hifigan_generator import get_padding
 
 logger = logging.getLogger(__name__)
 
 LRELU_SLOPE = 0.1
 
 
-def get_padding(k, d):
-    return int((k * d - d) / 2)
-
-
 class ResBlock1(torch.nn.Module):
     """Residual Block Type 1. It has 3 convolutional layers in each convolutional block.
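Note: several modules carried private copies of the HiFi-GAN "same"-padding helper; this PR consolidates them on the one in `TTS.vocoder.models.hifigan_generator`. For a stride-1 dilated convolution, `(kernel_size * dilation - dilation) / 2` is exactly the padding that preserves the input length:

    import torch
    from torch import nn

    def get_padding(kernel_size: int, dilation: int = 1) -> int:
        return int((kernel_size * dilation - dilation) / 2)

    conv = nn.Conv1d(1, 1, kernel_size=5, dilation=2, padding=get_padding(5, 2))
    x = torch.randn(1, 1, 100)
    print(conv(x).shape)  # torch.Size([1, 1, 100]) -- length preserved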
@@ -144,7 +144,7 @@ class BaseTTS(BaseTrainerModel):
             if speaker_name is None:
                 d_vector = self.speaker_manager.get_random_embedding()
             else:
-                d_vector = self.speaker_manager.get_d_vector_by_name(speaker_name)
+                d_vector = self.speaker_manager.get_mean_embedding(speaker_name)
         elif config.use_speaker_embedding:
             if speaker_name is None:
                 speaker_id = self.speaker_manager.get_random_id()
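Note: the replaced call targeted a `SpeakerManager` method that no longer exists, so this branch would have raised `AttributeError`; `get_mean_embedding(speaker_name)` instead averages the stored d-vectors for the requested speaker. A toy illustration of the intended behaviour (hypothetical data and helper, not the manager's real internals):

    embeddings = {"vctk_p225": [[0.1, 0.2], [0.3, 0.4]]}

    def get_mean_embedding(name):
        vecs = embeddings[name]
        return [sum(col) / len(col) for col in zip(*vecs)]

    print(get_mean_embedding("vctk_p225"))  # ~[0.2, 0.3]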
@@ -88,12 +88,6 @@ def pad(input_ele: List[torch.Tensor], max_len: int) -> torch.Tensor:
     return out_padded
 
 
-def init_weights(m: nn.Module, mean: float = 0.0, std: float = 0.01):
-    classname = m.__class__.__name__
-    if classname.find("Conv") != -1:
-        m.weight.data.normal_(mean, std)
-
-
 def stride_lens(lens: torch.Tensor, stride: int = 2) -> torch.Tensor:
     return torch.ceil(lens / stride).int()
@@ -145,10 +145,9 @@ def average_over_durations(values, durs):
     return avg
 
 
-def convert_pad_shape(pad_shape):
+def convert_pad_shape(pad_shape: list[list]) -> list:
     l = pad_shape[::-1]
-    pad_shape = [item for sublist in l for item in sublist]
-    return pad_shape
+    return [item for sublist in l for item in sublist]
 
 
 def generate_path(duration, mask):
@@ -1,7 +1,7 @@
 import logging
 import os
 import random
-from typing import Dict, List, Tuple, Union
+from typing import Any, Optional, Union
 
 import torch
 import torch.distributed as dist
@@ -10,6 +10,7 @@ from torch import nn
 from torch.utils.data import DataLoader
 from torch.utils.data.sampler import WeightedRandomSampler
 from trainer.torch import DistributedSampler, DistributedSamplerWrapper
+from trainer.trainer import Trainer
 
 from TTS.model import BaseTrainerModel
 from TTS.tts.datasets.dataset import TTSDataset
@@ -18,6 +19,7 @@ from TTS.tts.utils.languages import LanguageManager, get_language_balancer_weights
 from TTS.tts.utils.speakers import SpeakerManager, get_speaker_balancer_weights
 from TTS.tts.utils.synthesis import synthesis
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
+from TTS.utils.audio.processor import AudioProcessor
 
 # pylint: skip-file
@@ -35,10 +37,10 @@ class BaseVC(BaseTrainerModel):
     def __init__(
         self,
         config: Coqpit,
-        ap: "AudioProcessor",
-        speaker_manager: SpeakerManager = None,
-        language_manager: LanguageManager = None,
-    ):
+        ap: AudioProcessor,
+        speaker_manager: Optional[SpeakerManager] = None,
+        language_manager: Optional[LanguageManager] = None,
+    ) -> None:
         super().__init__()
         self.config = config
         self.ap = ap
@@ -46,7 +48,7 @@ class BaseVC(BaseTrainerModel):
         self.language_manager = language_manager
         self._set_model_args(config)
 
-    def _set_model_args(self, config: Coqpit):
+    def _set_model_args(self, config: Coqpit) -> None:
         """Setup model args based on the config type (`ModelConfig` or `ModelArgs`).
 
         `ModelArgs` has all the fields reuqired to initialize the model architecture.
@@ -67,7 +69,7 @@ class BaseVC(BaseTrainerModel):
         else:
             raise ValueError("config must be either a *Config or *Args")
 
-    def init_multispeaker(self, config: Coqpit, data: List = None):
+    def init_multispeaker(self, config: Coqpit, data: Optional[list[Any]] = None) -> None:
         """Initialize a speaker embedding layer if needen and define expected embedding channel size for defining
         `in_channels` size of the connected layers.
@@ -100,11 +102,11 @@ class BaseVC(BaseTrainerModel):
             self.speaker_embedding = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
             self.speaker_embedding.weight.data.normal_(0, 0.3)
 
-    def get_aux_input(self, **kwargs) -> Dict:
+    def get_aux_input(self, **kwargs: Any) -> dict[str, Any]:
         """Prepare and return `aux_input` used by `forward()`"""
         return {"speaker_id": None, "style_wav": None, "d_vector": None, "language_id": None}
 
-    def get_aux_input_from_test_sentences(self, sentence_info):
+    def get_aux_input_from_test_sentences(self, sentence_info: Union[str, list[str]]) -> dict[str, Any]:
         if hasattr(self.config, "model_args"):
             config = self.config.model_args
         else:
@@ -132,7 +134,7 @@ class BaseVC(BaseTrainerModel):
             if speaker_name is None:
                 d_vector = self.speaker_manager.get_random_embedding()
             else:
-                d_vector = self.speaker_manager.get_d_vector_by_name(speaker_name)
+                d_vector = self.speaker_manager.get_mean_embedding(speaker_name)
         elif config.use_speaker_embedding:
             if speaker_name is None:
                 speaker_id = self.speaker_manager.get_random_id()
@@ -151,16 +153,16 @@ class BaseVC(BaseTrainerModel):
             "language_id": language_id,
         }
 
-    def format_batch(self, batch: Dict) -> Dict:
+    def format_batch(self, batch: dict[str, Any]) -> dict[str, Any]:
         """Generic batch formatting for `VCDataset`.
 
         You must override this if you use a custom dataset.
 
         Args:
-            batch (Dict): [description]
+            batch (dict): [description]
 
         Returns:
-            Dict: [description]
+            dict: [description]
         """
         # setup input batch
         text_input = batch["token_id"]
@@ -230,7 +232,7 @@ class BaseVC(BaseTrainerModel):
             "audio_unique_names": batch["audio_unique_names"],
         }
 
-    def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus=1):
+    def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus: int = 1):
         weights = None
         data_items = dataset.samples
@@ -271,12 +273,12 @@ class BaseVC(BaseTrainerModel):
     def get_data_loader(
         self,
         config: Coqpit,
-        assets: Dict,
+        assets: dict,
         is_eval: bool,
-        samples: Union[List[Dict], List[List]],
+        samples: Union[list[dict], list[list]],
         verbose: bool,
         num_gpus: int,
-        rank: int = None,
+        rank: Optional[int] = None,
     ) -> "DataLoader":
         if is_eval and not config.run_eval:
             loader = None
@@ -352,9 +354,9 @@ class BaseVC(BaseTrainerModel):
 
     def _get_test_aux_input(
         self,
-    ) -> Dict:
+    ) -> dict[str, Any]:
         d_vector = None
-        if self.config.use_d_vector_file:
+        if self.speaker_manager is not None and self.config.use_d_vector_file:
             d_vector = [self.speaker_manager.embeddings[name]["embedding"] for name in self.speaker_manager.embeddings]
             d_vector = (random.sample(sorted(d_vector), 1),)
 
@@ -369,7 +371,7 @@ class BaseVC(BaseTrainerModel):
         }
         return aux_inputs
 
-    def test_run(self, assets: Dict) -> Tuple[Dict, Dict]:
+    def test_run(self, assets: dict) -> tuple[dict, dict]:
         """Generic test run for `vc` models used by `Trainer`.
 
         You can override this for a different behaviour.
@@ -378,7 +380,7 @@ class BaseVC(BaseTrainerModel):
             assets (dict): A dict of training assets. For `vc` models, it must include `{'audio_processor': ap}`.
 
         Returns:
-            Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
+            tuple[dict, dict]: Test figures and audios to be projected to Tensorboard.
         """
         logger.info("Synthesizing test sentences.")
         test_audios = {}
@@ -409,7 +411,7 @@ class BaseVC(BaseTrainerModel):
         )
         return test_figures, test_audios
 
-    def on_init_start(self, trainer):
+    def on_init_start(self, trainer: Trainer) -> None:
         """Save the speaker.pth and language_ids.json at the beginning of the training. Also update both paths."""
         if self.speaker_manager is not None:
             output_path = os.path.join(trainer.output_path, "speakers.pth")
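Note: throughout `base_vc.py` the typing is modernized: `typing.Dict`/`List`/`Tuple` become the builtin generics `dict`/`list`/`tuple` (PEP 585, Python >= 3.9), and parameters defaulting to `None` gain an explicit `Optional[...]`, since implicit Optional is deprecated under PEP 484. In the same style:

    from typing import Any, Optional

    # builtin generics plus explicit Optional (names here are illustrative)
    def get_sampler_config(rank: Optional[int] = None) -> dict[str, Any]:
        return {"rank": rank}

    print(get_sampler_config())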
@@ -14,14 +14,16 @@ from torch.nn.utils.parametrize import remove_parametrizations
 
 import TTS.vc.modules.freevc.commons as commons
 import TTS.vc.modules.freevc.modules as modules
+from TTS.tts.utils.helpers import sequence_mask
 from TTS.tts.utils.speakers import SpeakerManager
 from TTS.utils.io import load_fsspec
 from TTS.vc.configs.freevc_config import FreeVCConfig
 from TTS.vc.models.base_vc import BaseVC
-from TTS.vc.modules.freevc.commons import get_padding, init_weights
+from TTS.vc.modules.freevc.commons import init_weights
 from TTS.vc.modules.freevc.mel_processing import mel_spectrogram_torch
 from TTS.vc.modules.freevc.speaker_encoder.speaker_encoder import SpeakerEncoder as SpeakerEncoderEx
 from TTS.vc.modules.freevc.wavlm import get_wavlm
+from TTS.vocoder.models.hifigan_generator import get_padding
 
 logger = logging.getLogger(__name__)
@@ -80,7 +82,7 @@ class Encoder(nn.Module):
         self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
 
     def forward(self, x, x_lengths, g=None):
-        x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
+        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
         x = self.pre(x) * x_mask
         x = self.enc(x, x_mask, g=g)
         stats = self.proj(x) * x_mask
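Note: FreeVC's `Encoder` now uses the shared `TTS.tts.utils.helpers.sequence_mask` instead of the local `commons.sequence_mask`, whose removal appears further down. Both build a boolean mask of valid timesteps per sequence; the removed duplicate, reproduced for illustration:

    import torch

    def sequence_mask(length, max_length=None):
        if max_length is None:
            max_length = length.max()
        x = torch.arange(max_length, dtype=length.dtype, device=length.device)
        return x.unsqueeze(0) < length.unsqueeze(1)

    print(sequence_mask(torch.tensor([2, 4])))
    # tensor([[ True,  True, False, False],
    #         [ True,  True,  True,  True]])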
@@ -3,23 +3,15 @@ import math
 import torch
 from torch.nn import functional as F
 
+from TTS.tts.utils.helpers import convert_pad_shape, sequence_mask
+
 
-def init_weights(m, mean=0.0, std=0.01):
+def init_weights(m: torch.nn.Module, mean: float = 0.0, std: float = 0.01) -> None:
     classname = m.__class__.__name__
     if classname.find("Conv") != -1:
         m.weight.data.normal_(mean, std)
 
 
-def get_padding(kernel_size, dilation=1):
-    return int((kernel_size * dilation - dilation) / 2)
-
-
-def convert_pad_shape(pad_shape):
-    l = pad_shape[::-1]
-    pad_shape = [item for sublist in l for item in sublist]
-    return pad_shape
-
-
 def intersperse(lst, item):
     result = [item] * (len(lst) * 2 + 1)
     result[1::2] = lst
@@ -119,20 +111,11 @@ def shift_1d(x):
     return x
 
 
-def sequence_mask(length, max_length=None):
-    if max_length is None:
-        max_length = length.max()
-    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
-    return x.unsqueeze(0) < length.unsqueeze(1)
-
-
 def generate_path(duration, mask):
     """
     duration: [b, 1, t_x]
     mask: [b, 1, t_y, t_x]
     """
     device = duration.device
 
     b, _, t_y, t_x = mask.shape
     cum_duration = torch.cumsum(duration, -1)
@@ -6,26 +6,13 @@ from torch.nn.utils.parametrizations import weight_norm
 from torch.nn.utils.parametrize import remove_parametrizations
 
 import TTS.vc.modules.freevc.commons as commons
-from TTS.vc.modules.freevc.commons import get_padding, init_weights
+from TTS.tts.layers.generic.normalization import LayerNorm2
+from TTS.vc.modules.freevc.commons import init_weights
+from TTS.vocoder.models.hifigan_generator import get_padding
 
 LRELU_SLOPE = 0.1
 
 
-class LayerNorm(nn.Module):
-    def __init__(self, channels, eps=1e-5):
-        super().__init__()
-        self.channels = channels
-        self.eps = eps
-
-        self.gamma = nn.Parameter(torch.ones(channels))
-        self.beta = nn.Parameter(torch.zeros(channels))
-
-    def forward(self, x):
-        x = x.transpose(1, -1)
-        x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
-        return x.transpose(1, -1)
-
-
 class ConvReluNorm(nn.Module):
     def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
         super().__init__()
@@ -40,11 +27,11 @@ class ConvReluNorm(nn.Module):
         self.conv_layers = nn.ModuleList()
         self.norm_layers = nn.ModuleList()
         self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
-        self.norm_layers.append(LayerNorm(hidden_channels))
+        self.norm_layers.append(LayerNorm2(hidden_channels))
         self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
         for _ in range(n_layers - 1):
             self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
-            self.norm_layers.append(LayerNorm(hidden_channels))
+            self.norm_layers.append(LayerNorm2(hidden_channels))
         self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
         self.proj.weight.data.zero_()
         self.proj.bias.data.zero_()
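Note: the FreeVC-local `LayerNorm` removed above normalizes the channel axis of a `(batch, channels, time)` tensor by transposing around `F.layer_norm`; `ConvReluNorm` now uses the generic `LayerNorm2`, assumed here to implement the same channel-wise normalization. The removed behaviour, reproduced for illustration:

    import torch
    from torch import nn
    from torch.nn import functional as F

    class ChannelLayerNorm(nn.Module):  # equivalent of the removed class
        def __init__(self, channels, eps=1e-5):
            super().__init__()
            self.channels = channels
            self.eps = eps
            self.gamma = nn.Parameter(torch.ones(channels))
            self.beta = nn.Parameter(torch.zeros(channels))

        def forward(self, x):
            x = x.transpose(1, -1)
            x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
            return x.transpose(1, -1)

    print(ChannelLayerNorm(4)(torch.randn(2, 4, 10)).shape)  # torch.Size([2, 4, 10])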
@@ -59,48 +46,6 @@ class ConvReluNorm(nn.Module):
         return x * x_mask
 
 
-class DDSConv(nn.Module):
-    """
-    Dialted and Depth-Separable Convolution
-    """
-
-    def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
-        super().__init__()
-        self.channels = channels
-        self.kernel_size = kernel_size
-        self.n_layers = n_layers
-        self.p_dropout = p_dropout
-
-        self.drop = nn.Dropout(p_dropout)
-        self.convs_sep = nn.ModuleList()
-        self.convs_1x1 = nn.ModuleList()
-        self.norms_1 = nn.ModuleList()
-        self.norms_2 = nn.ModuleList()
-        for i in range(n_layers):
-            dilation = kernel_size**i
-            padding = (kernel_size * dilation - dilation) // 2
-            self.convs_sep.append(
-                nn.Conv1d(channels, channels, kernel_size, groups=channels, dilation=dilation, padding=padding)
-            )
-            self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
-            self.norms_1.append(LayerNorm(channels))
-            self.norms_2.append(LayerNorm(channels))
-
-    def forward(self, x, x_mask, g=None):
-        if g is not None:
-            x = x + g
-        for i in range(self.n_layers):
-            y = self.convs_sep[i](x * x_mask)
-            y = self.norms_1[i](y)
-            y = F.gelu(y)
-            y = self.convs_1x1[i](y)
-            y = self.norms_2[i](y)
-            y = F.gelu(y)
-            y = self.drop(y)
-            x = x + y
-        return x * x_mask
-
-
 class WN(torch.nn.Module):
     def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0):
         super(WN, self).__init__()
@@ -317,24 +262,6 @@ class Flip(nn.Module):
         return x
 
 
-class ElementwiseAffine(nn.Module):
-    def __init__(self, channels):
-        super().__init__()
-        self.channels = channels
-        self.m = nn.Parameter(torch.zeros(channels, 1))
-        self.logs = nn.Parameter(torch.zeros(channels, 1))
-
-    def forward(self, x, x_mask, reverse=False, **kwargs):
-        if not reverse:
-            y = self.m + torch.exp(self.logs) * x
-            y = y * x_mask
-            logdet = torch.sum(self.logs * x_mask, [1, 2])
-            return y, logdet
-        else:
-            x = (x - self.m) * torch.exp(-self.logs) * x_mask
-            return x
-
-
 class ResidualCouplingLayer(nn.Module):
     def __init__(
         self,
@@ -3,6 +3,8 @@ import torch
 from torch import nn
 from torch.nn import functional as F
 
+from TTS.vocoder.models.hifigan_generator import get_padding
+
 LRELU_SLOPE = 0.1
 
 
@@ -29,7 +31,6 @@ class DiscriminatorP(torch.nn.Module):
     def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
         super().__init__()
         self.period = period
-        get_padding = lambda k, d: int((k * d - d) / 2)
         norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.parametrizations.weight_norm
         self.convs = nn.ModuleList(
             [
@@ -15,8 +15,8 @@ logger = logging.getLogger(__name__)
 LRELU_SLOPE = 0.1
 
 
-def get_padding(k, d):
-    return int((k * d - d) / 2)
+def get_padding(kernel_size: int, dilation: int = 1) -> int:
+    return int((kernel_size * dilation - dilation) / 2)
 
 
 class ResBlock1(torch.nn.Module):
@@ -3,7 +3,7 @@ import torch as T
 from TTS.tts.utils.helpers import average_over_durations, generate_path, rand_segments, segment, sequence_mask
 
 
-def average_over_durations_test():  # pylint: disable=no-self-use
+def test_average_over_durations():  # pylint: disable=no-self-use
     pitch = T.rand(1, 1, 128)
 
     durations = T.randint(1, 5, (1, 21))
@@ -21,7 +21,7 @@ def average_over_durations_test():  # pylint: disable=no-self-use
         index += dur
 
 
-def seqeunce_mask_test():
+def test_sequence_mask():
     lengths = T.randint(10, 15, (8,))
     mask = sequence_mask(lengths)
     for i in range(8):
@@ -30,8 +30,8 @@ def seqeunce_mask_test():
         assert mask[i, l:].sum() == 0
 
 
-def segment_test():
-    x = T.range(0, 11)
+def test_segment():
+    x = T.arange(0, 12)
     x = x.repeat(8, 1).unsqueeze(1)
     segment_ids = T.randint(0, 7, (8,))
 
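Note: `torch.range` is deprecated and, unlike Python's `range`, end-inclusive: `T.range(0, 11)` yields the 12 floats 0.0..11.0. `T.arange(0, 12)` produces the same 12 values, end-exclusive and with integer dtype:

    import torch

    x = torch.arange(0, 12)
    print(x.shape, x.dtype)  # torch.Size([12]) torch.int64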
@@ -50,11 +50,11 @@ def segment_test():
         assert x[idx, :, start_indx : start_indx + 10].sum() == segments[idx, :, :].sum()
 
 
-def rand_segments_test():
+def test_rand_segments():
     x = T.rand(2, 3, 4)
     x_lens = T.randint(3, 4, (2,))
     segments, seg_idxs = rand_segments(x, x_lens, segment_size=3)
     assert segments.shape == (2, 3, 3)
     segments, seg_idxs = rand_segments(x, x_lens, segment_size=2)
     assert segments.shape == (2, 3, 2)
     assert all(seg_idxs >= 0), seg_idxs
     try:
         segments, _ = rand_segments(x, x_lens, segment_size=5)
@@ -68,10 +68,10 @@ def rand_segments_test():
     assert all(x_lens_back == x_lens)
 
 
-def generate_path_test():
+def test_generate_path():
     durations = T.randint(1, 4, (10, 21))
     x_length = T.randint(18, 22, (10,))
-    x_mask = sequence_mask(x_length).unsqueeze(1).long()
+    x_mask = sequence_mask(x_length, max_len=21).unsqueeze(1).long()
     durations = durations * x_mask.squeeze(1)
     y_length = durations.sum(1)
     y_mask = sequence_mask(y_length).unsqueeze(1).long()
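Note: pytest only collects functions matching its default `test_*` pattern, so the old `*_test` names were never executed; the renames make these checks actually run. Passing `max_len=21` also pins the mask width to the 21 duration columns, which would otherwise mismatch whenever every sampled length falls below 21:

    # pytest's default collection pattern (python_functions = "test_*")
    # matches test_generate_path but not the old generate_path_test
    def test_generate_path():
        assert True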