Merge pull request #49 from idiap/vc-refactors

VC-related refactors and fixes
This commit is contained in:
Enno Hermann 2024-06-26 14:01:21 +01:00 committed by GitHub
commit ff2cd5c97d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 76 additions and 184 deletions

View File

@ -161,9 +161,6 @@ def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader,
loader_time = time.time() - end_time
global_step += 1
# setup lr
if c.lr_decay:
scheduler.step()
optimizer.zero_grad()
# dispatch data to GPU
@ -182,6 +179,10 @@ def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader,
grad_norm, _ = check_update(model, c.grad_clip)
optimizer.step()
# setup lr
if c.lr_decay:
scheduler.step()
step_time = time.time() - start_time
epoch_time += step_time

View File

@ -55,7 +55,7 @@ class EncoderDataset(Dataset):
logger.info(" | Number of instances: %d", len(self.items))
logger.info(" | Sequence length: %d", self.seq_len)
logger.info(" | Number of classes: %d", len(self.classes))
logger.info(" | Classes: %d", self.classes)
logger.info(" | Classes: %s", self.classes)
def load_wav(self, filename):
audio = self.ap.load_wav(filename, sr=self.ap.sample_rate)

View File

@ -1,5 +1,6 @@
import os
from abc import abstractmethod
from typing import Dict
from typing import Any, Union
import torch
from coqpit import Coqpit
@ -16,7 +17,7 @@ class BaseTrainerModel(TrainerModel):
@staticmethod
@abstractmethod
def init_from_config(config: Coqpit):
def init_from_config(config: Coqpit) -> "BaseTrainerModel":
"""Init the model and all its attributes from the given config.
Override this depending on your model.
@ -24,7 +25,7 @@ class BaseTrainerModel(TrainerModel):
...
@abstractmethod
def inference(self, input: torch.Tensor, aux_input={}) -> Dict:
def inference(self, input: torch.Tensor, aux_input: dict[str, Any] = {}) -> dict[str, Any]:
"""Forward pass for inference.
It must return a dictionary with the main model output and all the auxiliary outputs. The key ```model_outputs```
@ -45,13 +46,18 @@ class BaseTrainerModel(TrainerModel):
@abstractmethod
def load_checkpoint(
self, config: Coqpit, checkpoint_path: str, eval: bool = False, strict: bool = True, cache=False
self,
config: Coqpit,
checkpoint_path: Union[str, os.PathLike[Any]],
eval: bool = False,
strict: bool = True,
cache: bool = False,
) -> None:
"""Load a model checkpoint gile and get ready for training or inference.
"""Load a model checkpoint file and get ready for training or inference.
Args:
config (Coqpit): Model configuration.
checkpoint_path (str): Path to the model checkpoint file.
checkpoint_path (str | os.PathLike): Path to the model checkpoint file.
eval (bool, optional): If true, init model for inference else for training. Defaults to False.
strict (bool, optional): Match all checkpoint keys to model's keys. Defaults to True.
cache (bool, optional): If True, cache the file locally for subsequent calls. It is cached under `get_user_data_dir()/tts_cache`. Defaults to False.

View File

@ -5,6 +5,7 @@ from torch import nn
from torch.nn import functional as F
from TTS.tts.layers.generic.normalization import LayerNorm, LayerNorm2
from TTS.tts.utils.helpers import convert_pad_shape
class RelativePositionMultiHeadAttention(nn.Module):
@ -300,7 +301,7 @@ class FeedForwardNetwork(nn.Module):
pad_l = self.kernel_size - 1
pad_r = 0
padding = [[0, 0], [0, 0], [pad_l, pad_r]]
x = F.pad(x, self._pad_shape(padding))
x = F.pad(x, convert_pad_shape(padding))
return x
def _same_padding(self, x):
@ -309,15 +310,9 @@ class FeedForwardNetwork(nn.Module):
pad_l = (self.kernel_size - 1) // 2
pad_r = self.kernel_size // 2
padding = [[0, 0], [0, 0], [pad_l, pad_r]]
x = F.pad(x, self._pad_shape(padding))
x = F.pad(x, convert_pad_shape(padding))
return x
@staticmethod
def _pad_shape(padding):
l = padding[::-1]
pad_shape = [item for sublist in l for item in sublist]
return pad_shape
class RelativePositionTransformer(nn.Module):
"""Transformer with Relative Potional Encoding.

View File

@ -255,7 +255,7 @@ class GuidedAttentionLoss(torch.nn.Module):
@staticmethod
def _make_ga_mask(ilen, olen, sigma):
grid_x, grid_y = torch.meshgrid(torch.arange(olen).to(olen), torch.arange(ilen).to(ilen))
grid_x, grid_y = torch.meshgrid(torch.arange(olen).to(olen), torch.arange(ilen).to(ilen), indexing="ij")
grid_x, grid_y = grid_x.float(), grid_y.float()
return 1.0 - torch.exp(-((grid_y / ilen - grid_x / olen) ** 2) / (2 * (sigma**2)))

View File

@ -128,7 +128,8 @@ class NeuralHMM(nn.Module):
# Get mean, std and transition vector from decoder for this timestep
# Note: Gradient checkpointing currently doesn't works with multiple gpus inside a loop
if self.use_grad_checkpointing and self.training:
mean, std, transition_vector = checkpoint(self.output_net, h_memory, inputs)
# TODO: use_reentrant=False is recommended
mean, std, transition_vector = checkpoint(self.output_net, h_memory, inputs, use_reentrant=True)
else:
mean, std, transition_vector = self.output_net(h_memory, inputs)

View File

@ -10,22 +10,6 @@ from TTS.tts.utils.helpers import sequence_mask
LRELU_SLOPE = 0.1
def convert_pad_shape(pad_shape):
l = pad_shape[::-1]
pad_shape = [item for sublist in l for item in sublist]
return pad_shape
def init_weights(m, mean=0.0, std=0.01):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
m.weight.data.normal_(mean, std)
def get_padding(kernel_size, dilation=1):
return int((kernel_size * dilation - dilation) / 2)
class TextEncoder(nn.Module):
def __init__(
self,

View File

@ -9,16 +9,13 @@ from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils.parametrize import remove_parametrizations
from TTS.utils.io import load_fsspec
from TTS.vocoder.models.hifigan_generator import get_padding
logger = logging.getLogger(__name__)
LRELU_SLOPE = 0.1
def get_padding(k, d):
return int((k * d - d) / 2)
class ResBlock1(torch.nn.Module):
"""Residual Block Type 1. It has 3 convolutional layers in each convolutional block.

View File

@ -144,7 +144,7 @@ class BaseTTS(BaseTrainerModel):
if speaker_name is None:
d_vector = self.speaker_manager.get_random_embedding()
else:
d_vector = self.speaker_manager.get_d_vector_by_name(speaker_name)
d_vector = self.speaker_manager.get_mean_embedding(speaker_name)
elif config.use_speaker_embedding:
if speaker_name is None:
speaker_id = self.speaker_manager.get_random_id()

View File

@ -88,12 +88,6 @@ def pad(input_ele: List[torch.Tensor], max_len: int) -> torch.Tensor:
return out_padded
def init_weights(m: nn.Module, mean: float = 0.0, std: float = 0.01):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
m.weight.data.normal_(mean, std)
def stride_lens(lens: torch.Tensor, stride: int = 2) -> torch.Tensor:
return torch.ceil(lens / stride).int()

View File

@ -145,10 +145,9 @@ def average_over_durations(values, durs):
return avg
def convert_pad_shape(pad_shape):
def convert_pad_shape(pad_shape: list[list]) -> list:
l = pad_shape[::-1]
pad_shape = [item for sublist in l for item in sublist]
return pad_shape
return [item for sublist in l for item in sublist]
def generate_path(duration, mask):

View File

@ -1,7 +1,7 @@
import logging
import os
import random
from typing import Dict, List, Tuple, Union
from typing import Any, Optional, Union
import torch
import torch.distributed as dist
@ -10,6 +10,7 @@ from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data.sampler import WeightedRandomSampler
from trainer.torch import DistributedSampler, DistributedSamplerWrapper
from trainer.trainer import Trainer
from TTS.model import BaseTrainerModel
from TTS.tts.datasets.dataset import TTSDataset
@ -18,6 +19,7 @@ from TTS.tts.utils.languages import LanguageManager, get_language_balancer_weigh
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_balancer_weights
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.audio.processor import AudioProcessor
# pylint: skip-file
@ -35,10 +37,10 @@ class BaseVC(BaseTrainerModel):
def __init__(
self,
config: Coqpit,
ap: "AudioProcessor",
speaker_manager: SpeakerManager = None,
language_manager: LanguageManager = None,
):
ap: AudioProcessor,
speaker_manager: Optional[SpeakerManager] = None,
language_manager: Optional[LanguageManager] = None,
) -> None:
super().__init__()
self.config = config
self.ap = ap
@ -46,7 +48,7 @@ class BaseVC(BaseTrainerModel):
self.language_manager = language_manager
self._set_model_args(config)
def _set_model_args(self, config: Coqpit):
def _set_model_args(self, config: Coqpit) -> None:
"""Setup model args based on the config type (`ModelConfig` or `ModelArgs`).
`ModelArgs` has all the fields reuqired to initialize the model architecture.
@ -67,7 +69,7 @@ class BaseVC(BaseTrainerModel):
else:
raise ValueError("config must be either a *Config or *Args")
def init_multispeaker(self, config: Coqpit, data: List = None):
def init_multispeaker(self, config: Coqpit, data: Optional[list[Any]] = None) -> None:
"""Initialize a speaker embedding layer if needen and define expected embedding channel size for defining
`in_channels` size of the connected layers.
@ -100,11 +102,11 @@ class BaseVC(BaseTrainerModel):
self.speaker_embedding = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
self.speaker_embedding.weight.data.normal_(0, 0.3)
def get_aux_input(self, **kwargs) -> Dict:
def get_aux_input(self, **kwargs: Any) -> dict[str, Any]:
"""Prepare and return `aux_input` used by `forward()`"""
return {"speaker_id": None, "style_wav": None, "d_vector": None, "language_id": None}
def get_aux_input_from_test_sentences(self, sentence_info):
def get_aux_input_from_test_sentences(self, sentence_info: Union[str, list[str]]) -> dict[str, Any]:
if hasattr(self.config, "model_args"):
config = self.config.model_args
else:
@ -132,7 +134,7 @@ class BaseVC(BaseTrainerModel):
if speaker_name is None:
d_vector = self.speaker_manager.get_random_embedding()
else:
d_vector = self.speaker_manager.get_d_vector_by_name(speaker_name)
d_vector = self.speaker_manager.get_mean_embedding(speaker_name)
elif config.use_speaker_embedding:
if speaker_name is None:
speaker_id = self.speaker_manager.get_random_id()
@ -151,16 +153,16 @@ class BaseVC(BaseTrainerModel):
"language_id": language_id,
}
def format_batch(self, batch: Dict) -> Dict:
def format_batch(self, batch: dict[str, Any]) -> dict[str, Any]:
"""Generic batch formatting for `VCDataset`.
You must override this if you use a custom dataset.
Args:
batch (Dict): [description]
batch (dict): [description]
Returns:
Dict: [description]
dict: [description]
"""
# setup input batch
text_input = batch["token_id"]
@ -230,7 +232,7 @@ class BaseVC(BaseTrainerModel):
"audio_unique_names": batch["audio_unique_names"],
}
def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus=1):
def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus: int = 1):
weights = None
data_items = dataset.samples
@ -271,12 +273,12 @@ class BaseVC(BaseTrainerModel):
def get_data_loader(
self,
config: Coqpit,
assets: Dict,
assets: dict,
is_eval: bool,
samples: Union[List[Dict], List[List]],
samples: Union[list[dict], list[list]],
verbose: bool,
num_gpus: int,
rank: int = None,
rank: Optional[int] = None,
) -> "DataLoader":
if is_eval and not config.run_eval:
loader = None
@ -352,9 +354,9 @@ class BaseVC(BaseTrainerModel):
def _get_test_aux_input(
self,
) -> Dict:
) -> dict[str, Any]:
d_vector = None
if self.config.use_d_vector_file:
if self.speaker_manager is not None and self.config.use_d_vector_file:
d_vector = [self.speaker_manager.embeddings[name]["embedding"] for name in self.speaker_manager.embeddings]
d_vector = (random.sample(sorted(d_vector), 1),)
@ -369,7 +371,7 @@ class BaseVC(BaseTrainerModel):
}
return aux_inputs
def test_run(self, assets: Dict) -> Tuple[Dict, Dict]:
def test_run(self, assets: dict) -> tuple[dict, dict]:
"""Generic test run for `vc` models used by `Trainer`.
You can override this for a different behaviour.
@ -378,7 +380,7 @@ class BaseVC(BaseTrainerModel):
assets (dict): A dict of training assets. For `vc` models, it must include `{'audio_processor': ap}`.
Returns:
Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
tuple[dict, dict]: Test figures and audios to be projected to Tensorboard.
"""
logger.info("Synthesizing test sentences.")
test_audios = {}
@ -409,7 +411,7 @@ class BaseVC(BaseTrainerModel):
)
return test_figures, test_audios
def on_init_start(self, trainer):
def on_init_start(self, trainer: Trainer) -> None:
"""Save the speaker.pth and language_ids.json at the beginning of the training. Also update both paths."""
if self.speaker_manager is not None:
output_path = os.path.join(trainer.output_path, "speakers.pth")

View File

@ -14,14 +14,16 @@ from torch.nn.utils.parametrize import remove_parametrizations
import TTS.vc.modules.freevc.commons as commons
import TTS.vc.modules.freevc.modules as modules
from TTS.tts.utils.helpers import sequence_mask
from TTS.tts.utils.speakers import SpeakerManager
from TTS.utils.io import load_fsspec
from TTS.vc.configs.freevc_config import FreeVCConfig
from TTS.vc.models.base_vc import BaseVC
from TTS.vc.modules.freevc.commons import get_padding, init_weights
from TTS.vc.modules.freevc.commons import init_weights
from TTS.vc.modules.freevc.mel_processing import mel_spectrogram_torch
from TTS.vc.modules.freevc.speaker_encoder.speaker_encoder import SpeakerEncoder as SpeakerEncoderEx
from TTS.vc.modules.freevc.wavlm import get_wavlm
from TTS.vocoder.models.hifigan_generator import get_padding
logger = logging.getLogger(__name__)
@ -80,7 +82,7 @@ class Encoder(nn.Module):
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
def forward(self, x, x_lengths, g=None):
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
x = self.pre(x) * x_mask
x = self.enc(x, x_mask, g=g)
stats = self.proj(x) * x_mask

View File

@ -3,23 +3,15 @@ import math
import torch
from torch.nn import functional as F
from TTS.tts.utils.helpers import convert_pad_shape, sequence_mask
def init_weights(m, mean=0.0, std=0.01):
def init_weights(m: torch.nn.Module, mean: float = 0.0, std: float = 0.01) -> None:
classname = m.__class__.__name__
if classname.find("Conv") != -1:
m.weight.data.normal_(mean, std)
def get_padding(kernel_size, dilation=1):
return int((kernel_size * dilation - dilation) / 2)
def convert_pad_shape(pad_shape):
l = pad_shape[::-1]
pad_shape = [item for sublist in l for item in sublist]
return pad_shape
def intersperse(lst, item):
result = [item] * (len(lst) * 2 + 1)
result[1::2] = lst
@ -119,20 +111,11 @@ def shift_1d(x):
return x
def sequence_mask(length, max_length=None):
if max_length is None:
max_length = length.max()
x = torch.arange(max_length, dtype=length.dtype, device=length.device)
return x.unsqueeze(0) < length.unsqueeze(1)
def generate_path(duration, mask):
"""
duration: [b, 1, t_x]
mask: [b, 1, t_y, t_x]
"""
device = duration.device
b, _, t_y, t_x = mask.shape
cum_duration = torch.cumsum(duration, -1)

View File

@ -6,26 +6,13 @@ from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils.parametrize import remove_parametrizations
import TTS.vc.modules.freevc.commons as commons
from TTS.vc.modules.freevc.commons import get_padding, init_weights
from TTS.tts.layers.generic.normalization import LayerNorm2
from TTS.vc.modules.freevc.commons import init_weights
from TTS.vocoder.models.hifigan_generator import get_padding
LRELU_SLOPE = 0.1
class LayerNorm(nn.Module):
def __init__(self, channels, eps=1e-5):
super().__init__()
self.channels = channels
self.eps = eps
self.gamma = nn.Parameter(torch.ones(channels))
self.beta = nn.Parameter(torch.zeros(channels))
def forward(self, x):
x = x.transpose(1, -1)
x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
return x.transpose(1, -1)
class ConvReluNorm(nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
super().__init__()
@ -40,11 +27,11 @@ class ConvReluNorm(nn.Module):
self.conv_layers = nn.ModuleList()
self.norm_layers = nn.ModuleList()
self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
self.norm_layers.append(LayerNorm(hidden_channels))
self.norm_layers.append(LayerNorm2(hidden_channels))
self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
for _ in range(n_layers - 1):
self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
self.norm_layers.append(LayerNorm(hidden_channels))
self.norm_layers.append(LayerNorm2(hidden_channels))
self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
self.proj.weight.data.zero_()
self.proj.bias.data.zero_()
@ -59,48 +46,6 @@ class ConvReluNorm(nn.Module):
return x * x_mask
class DDSConv(nn.Module):
"""
Dialted and Depth-Separable Convolution
"""
def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
super().__init__()
self.channels = channels
self.kernel_size = kernel_size
self.n_layers = n_layers
self.p_dropout = p_dropout
self.drop = nn.Dropout(p_dropout)
self.convs_sep = nn.ModuleList()
self.convs_1x1 = nn.ModuleList()
self.norms_1 = nn.ModuleList()
self.norms_2 = nn.ModuleList()
for i in range(n_layers):
dilation = kernel_size**i
padding = (kernel_size * dilation - dilation) // 2
self.convs_sep.append(
nn.Conv1d(channels, channels, kernel_size, groups=channels, dilation=dilation, padding=padding)
)
self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
self.norms_1.append(LayerNorm(channels))
self.norms_2.append(LayerNorm(channels))
def forward(self, x, x_mask, g=None):
if g is not None:
x = x + g
for i in range(self.n_layers):
y = self.convs_sep[i](x * x_mask)
y = self.norms_1[i](y)
y = F.gelu(y)
y = self.convs_1x1[i](y)
y = self.norms_2[i](y)
y = F.gelu(y)
y = self.drop(y)
x = x + y
return x * x_mask
class WN(torch.nn.Module):
def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0):
super(WN, self).__init__()
@ -317,24 +262,6 @@ class Flip(nn.Module):
return x
class ElementwiseAffine(nn.Module):
def __init__(self, channels):
super().__init__()
self.channels = channels
self.m = nn.Parameter(torch.zeros(channels, 1))
self.logs = nn.Parameter(torch.zeros(channels, 1))
def forward(self, x, x_mask, reverse=False, **kwargs):
if not reverse:
y = self.m + torch.exp(self.logs) * x
y = y * x_mask
logdet = torch.sum(self.logs * x_mask, [1, 2])
return y, logdet
else:
x = (x - self.m) * torch.exp(-self.logs) * x_mask
return x
class ResidualCouplingLayer(nn.Module):
def __init__(
self,

View File

@ -3,6 +3,8 @@ import torch
from torch import nn
from torch.nn import functional as F
from TTS.vocoder.models.hifigan_generator import get_padding
LRELU_SLOPE = 0.1
@ -29,7 +31,6 @@ class DiscriminatorP(torch.nn.Module):
def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
super().__init__()
self.period = period
get_padding = lambda k, d: int((k * d - d) / 2)
norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.parametrizations.weight_norm
self.convs = nn.ModuleList(
[

View File

@ -15,8 +15,8 @@ logger = logging.getLogger(__name__)
LRELU_SLOPE = 0.1
def get_padding(k, d):
return int((k * d - d) / 2)
def get_padding(kernel_size: int, dilation: int = 1) -> int:
return int((kernel_size * dilation - dilation) / 2)
class ResBlock1(torch.nn.Module):

View File

@ -3,7 +3,7 @@ import torch as T
from TTS.tts.utils.helpers import average_over_durations, generate_path, rand_segments, segment, sequence_mask
def average_over_durations_test(): # pylint: disable=no-self-use
def test_average_over_durations(): # pylint: disable=no-self-use
pitch = T.rand(1, 1, 128)
durations = T.randint(1, 5, (1, 21))
@ -21,7 +21,7 @@ def average_over_durations_test(): # pylint: disable=no-self-use
index += dur
def seqeunce_mask_test():
def test_sequence_mask():
lengths = T.randint(10, 15, (8,))
mask = sequence_mask(lengths)
for i in range(8):
@ -30,8 +30,8 @@ def seqeunce_mask_test():
assert mask[i, l:].sum() == 0
def segment_test():
x = T.range(0, 11)
def test_segment():
x = T.arange(0, 12)
x = x.repeat(8, 1).unsqueeze(1)
segment_ids = T.randint(0, 7, (8,))
@ -50,11 +50,11 @@ def segment_test():
assert x[idx, :, start_indx : start_indx + 10].sum() == segments[idx, :, :].sum()
def rand_segments_test():
def test_rand_segments():
x = T.rand(2, 3, 4)
x_lens = T.randint(3, 4, (2,))
segments, seg_idxs = rand_segments(x, x_lens, segment_size=3)
assert segments.shape == (2, 3, 3)
segments, seg_idxs = rand_segments(x, x_lens, segment_size=2)
assert segments.shape == (2, 3, 2)
assert all(seg_idxs >= 0), seg_idxs
try:
segments, _ = rand_segments(x, x_lens, segment_size=5)
@ -68,10 +68,10 @@ def rand_segments_test():
assert all(x_lens_back == x_lens)
def generate_path_test():
def test_generate_path():
durations = T.randint(1, 4, (10, 21))
x_length = T.randint(18, 22, (10,))
x_mask = sequence_mask(x_length).unsqueeze(1).long()
x_mask = sequence_mask(x_length, max_len=21).unsqueeze(1).long()
durations = durations * x_mask.squeeze(1)
y_length = durations.sum(1)
y_mask = sequence_mask(y_length).unsqueeze(1).long()