diff --git a/TTS/bin/train_encoder.py b/TTS/bin/train_encoder.py
index 06aa41af..af3e6ec4 100644
--- a/TTS/bin/train_encoder.py
+++ b/TTS/bin/train_encoder.py
@@ -11,15 +11,14 @@
 from torch.utils.data import DataLoader
 from trainer.torch import NoamLR

 from TTS.encoder.dataset import EncoderDataset
-from TTS.encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss
 from TTS.encoder.utils.generic_utils import save_best_model, save_checkpoint, setup_speaker_encoder_model
 from TTS.encoder.utils.samplers import PerfectBatchSampler
 from TTS.encoder.utils.training import init_training
 from TTS.encoder.utils.visual import plot_embeddings
 from TTS.tts.datasets import load_tts_samples
 from TTS.utils.audio import AudioProcessor
-from TTS.utils.generic_utils import count_parameters, remove_experiment_folder, set_init_dict
-from TTS.utils.io import load_fsspec, copy_model_files
+from TTS.utils.generic_utils import count_parameters, remove_experiment_folder
+from TTS.utils.io import copy_model_files
 from trainer.trainer_utils import get_optimizer
 from TTS.utils.training import check_update
@@ -254,40 +253,17 @@ def main(args):  # pylint: disable=redefined-outer-name
         eval_data_loader, _, _ = setup_loader(ap, is_val=True, verbose=True)
     else:
         eval_data_loader = None

-    num_classes = len(train_classes)
-    if c.loss == "ge2e":
-        criterion = GE2ELoss(loss_method="softmax")
-    elif c.loss == "angleproto":
-        criterion = AngleProtoLoss()
-    elif c.loss == "softmaxproto":
-        criterion = SoftmaxAngleProtoLoss(c.model_params["proj_dim"], num_classes)
-        if c.model == "emotion_encoder":
-            # update config with the class map
-            c.map_classid_to_classname = map_classid_to_classname
-            copy_model_files(c, OUT_PATH)
-    else:
-        raise Exception("The %s not is a loss supported" % c.loss)
+    num_classes = len(train_classes)
+    criterion = model.get_criterion(c, num_classes)
+
+    if c.loss == "softmaxproto" and c.model != "speaker_encoder":
+        c.map_classid_to_classname = map_classid_to_classname
+        copy_model_files(c, OUT_PATH)

     if args.restore_path:
-        checkpoint = load_fsspec(args.restore_path)
-        try:
-            model.load_state_dict(checkpoint["model"])
-
-            if "criterion" in checkpoint:
-                criterion.load_state_dict(checkpoint["criterion"])
-
-        except (KeyError, RuntimeError):
-            print(" > Partial model initialization.")
-            model_dict = model.state_dict()
-            model_dict = set_init_dict(model_dict, checkpoint["model"], c)
-            model.load_state_dict(model_dict)
-            del model_dict
-            for group in optimizer.param_groups:
-                group["lr"] = c.lr
-
-        print(" > Model restored from step %d" % checkpoint["step"], flush=True)
-        args.restore_step = checkpoint["step"]
+        criterion, args.restore_step = model.load_checkpoint(c, args.restore_path, eval=False, use_cuda=use_cuda, criterion=criterion)
+        print(" > Model restored from step %d" % args.restore_step, flush=True)
     else:
         args.restore_step = 0
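
For reference, a minimal sketch of how the restore path above is intended to be driven, assuming a model that inherits the new BaseEncoder and a config with `loss` / `model_params` fields; `restore_encoder` is a hypothetical helper for illustration only, not part of this patch:

import torch

def restore_encoder(model, config, num_classes, restore_path=None):
    """Hypothetical helper mirroring the flow introduced in train_encoder.py above."""
    # Criterion selection now lives on the model (ge2e | angleproto | softmaxproto).
    criterion = model.get_criterion(config, num_classes)
    if restore_path:
        # With eval=False, BaseEncoder.load_checkpoint returns (criterion, step)
        # so training can resume from the stored step.
        criterion, restore_step = model.load_checkpoint(
            config, restore_path, eval=False, use_cuda=torch.cuda.is_available(), criterion=criterion
        )
    else:
        restore_step = 0
    return criterion, restore_step
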
diff --git a/TTS/encoder/models/base_encoder.py b/TTS/encoder/models/base_encoder.py
new file mode 100644
index 00000000..282a1655
--- /dev/null
+++ b/TTS/encoder/models/base_encoder.py
@@ -0,0 +1,141 @@
+import torch
+import torchaudio
+import numpy as np
+from torch import nn
+
+from TTS.utils.io import load_fsspec
+from TTS.encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss
+from TTS.utils.generic_utils import set_init_dict
+from coqpit import Coqpit
+
+class PreEmphasis(nn.Module):
+    def __init__(self, coefficient=0.97):
+        super().__init__()
+        self.coefficient = coefficient
+        self.register_buffer("filter", torch.FloatTensor([-self.coefficient, 1.0]).unsqueeze(0).unsqueeze(0))
+
+    def forward(self, x):
+        assert len(x.size()) == 2
+
+        x = torch.nn.functional.pad(x.unsqueeze(1), (1, 0), "reflect")
+        return torch.nn.functional.conv1d(x, self.filter).squeeze(1)
+
+class BaseEncoder(nn.Module):
+    """Base `encoder` class. Every new `encoder` model must inherit this.
+
+    It defines common `encoder` specific functions.
+    """
+
+    # pylint: disable=W0102
+    def __init__(self):
+        super(BaseEncoder, self).__init__()
+
+    def get_torch_mel_spectrogram_class(self, audio_config):
+        return torch.nn.Sequential(
+            PreEmphasis(audio_config["preemphasis"]),
+            # TorchSTFT(
+            #     n_fft=audio_config["fft_size"],
+            #     hop_length=audio_config["hop_length"],
+            #     win_length=audio_config["win_length"],
+            #     sample_rate=audio_config["sample_rate"],
+            #     window="hamming_window",
+            #     mel_fmin=0.0,
+            #     mel_fmax=None,
+            #     use_htk=True,
+            #     do_amp_to_db=False,
+            #     n_mels=audio_config["num_mels"],
+            #     power=2.0,
+            #     use_mel=True,
+            #     mel_norm=None,
+            # )
+            torchaudio.transforms.MelSpectrogram(
+                sample_rate=audio_config["sample_rate"],
+                n_fft=audio_config["fft_size"],
+                win_length=audio_config["win_length"],
+                hop_length=audio_config["hop_length"],
+                window_fn=torch.hamming_window,
+                n_mels=audio_config["num_mels"],
+            )
+        )
+
+    @torch.no_grad()
+    def inference(self, x, l2_norm=False):
+        return self.forward(x, l2_norm)
+
+    @torch.no_grad()
+    def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True, l2_norm=True):
+        """
+        Generate embeddings for a batch of utterances
+        x: 1xTxD
+        """
+        # map to the waveform size
+        if self.use_torch_spec:
+            num_frames = num_frames * self.audio_config["hop_length"]
+
+        max_len = x.shape[1]
+
+        if max_len < num_frames:
+            num_frames = max_len
+
+        offsets = np.linspace(0, max_len - num_frames, num=num_eval)
+
+        frames_batch = []
+        for offset in offsets:
+            offset = int(offset)
+            end_offset = int(offset + num_frames)
+            frames = x[:, offset:end_offset]
+            frames_batch.append(frames)
+
+        frames_batch = torch.cat(frames_batch, dim=0)
+        embeddings = self.inference(frames_batch, l2_norm=l2_norm)
+
+        if return_mean:
+            embeddings = torch.mean(embeddings, dim=0, keepdim=True)
+        return embeddings
+
+    def get_criterion(self, c: Coqpit, num_classes=None):
+        if c.loss == "ge2e":
+            criterion = GE2ELoss(loss_method="softmax")
+        elif c.loss == "angleproto":
+            criterion = AngleProtoLoss()
+        elif c.loss == "softmaxproto":
+            criterion = SoftmaxAngleProtoLoss(c.model_params["proj_dim"], num_classes)
+        else:
+            raise Exception("The %s not is a loss supported" % c.loss)
+        return criterion
+
+    def load_checkpoint(self, config: Coqpit, checkpoint_path: str, eval: bool = False, use_cuda: bool = False, criterion=None):
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
+        try:
+            self.load_state_dict(state["model"])
+        except (KeyError, RuntimeError) as error:
+            # If eval raise the error
+            if eval:
+                raise error
+
+            print(" > Partial model initialization.")
+            model_dict = self.state_dict()
+            model_dict = set_init_dict(model_dict, state["model"], config)
+            self.load_state_dict(model_dict)
+            del model_dict
+
+        # load the criterion for restore_path
+        if criterion is not None and "criterion" in state:
+            criterion.load_state_dict(state["criterion"])
+        # instance and load the criterion for the encoder classifier in inference time
+        if eval and criterion is None and "criterion" in state and config.map_classid_to_classname is not None:
+            criterion = self.get_criterion(config, len(config.map_classid_to_classname))
+            criterion.load_state_dict(state["criterion"])
+
+        if use_cuda:
+            self.cuda()
+            if criterion is not None:
+                criterion = criterion.cuda()
+
+        if eval:
+            self.eval()
+            assert not self.training
+
+        if not eval:
+            return criterion, state["step"]
+        return criterion
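
The windowing in compute_embedding is the least obvious part of the new base class: with `use_torch_spec` enabled, `num_frames` is first rescaled from spectrogram frames to waveform samples, then `num_eval` evenly spaced windows are embedded and (optionally) mean-pooled. A standalone sketch of just that offset logic; `window_offsets` is a hypothetical helper, not part of this patch:

import numpy as np

def window_offsets(num_samples, num_frames=250, num_eval=10, hop_length=256):
    """Hypothetical helper: reproduces the window placement used by compute_embedding."""
    # use_torch_spec=True: map the frame budget to waveform samples.
    num_frames = min(num_frames * hop_length, num_samples)
    # num_eval evenly spaced windows; their embeddings are averaged when return_mean=True.
    starts = np.linspace(0, num_samples - num_frames, num=num_eval)
    return [(int(s), int(s) + num_frames) for s in starts]

# A 5 s clip at 16 kHz (80000 samples) with hop_length=256 gives ten 64000-sample
# windows whose start offsets step evenly from 0 to 16000.
print(window_offsets(80000))
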
diff --git a/TTS/encoder/models/lstm.py b/TTS/encoder/models/lstm.py
index 6144a9b4..51852b5b 100644
--- a/TTS/encoder/models/lstm.py
+++ b/TTS/encoder/models/lstm.py
@@ -1,10 +1,7 @@
-import numpy as np
 import torch
-import torchaudio
 from torch import nn

-from TTS.encoder.models.resnet import PreEmphasis
-from TTS.utils.io import load_fsspec
+from TTS.encoder.models.base_encoder import BaseEncoder


 class LSTMWithProjection(nn.Module):
@@ -34,7 +31,7 @@ class LSTMWithoutProjection(nn.Module):
         return self.relu(self.linear(hidden[-1]))


-class LSTMSpeakerEncoder(nn.Module):
+class LSTMSpeakerEncoder(BaseEncoder):
     def __init__(
         self,
         input_dim,
@@ -64,32 +61,7 @@ class LSTMSpeakerEncoder(nn.Module):
         self.instancenorm = nn.InstanceNorm1d(input_dim)

         if self.use_torch_spec:
-            self.torch_spec = torch.nn.Sequential(
-                PreEmphasis(audio_config["preemphasis"]),
-                # TorchSTFT(
-                #     n_fft=audio_config["fft_size"],
-                #     hop_length=audio_config["hop_length"],
-                #     win_length=audio_config["win_length"],
-                #     sample_rate=audio_config["sample_rate"],
-                #     window="hamming_window",
-                #     mel_fmin=0.0,
-                #     mel_fmax=None,
-                #     use_htk=True,
-                #     do_amp_to_db=False,
-                #     n_mels=audio_config["num_mels"],
-                #     power=2.0,
-                #     use_mel=True,
-                #     mel_norm=None,
-                # )
-                torchaudio.transforms.MelSpectrogram(
-                    sample_rate=audio_config["sample_rate"],
-                    n_fft=audio_config["fft_size"],
-                    win_length=audio_config["win_length"],
-                    hop_length=audio_config["hop_length"],
-                    window_fn=torch.hamming_window,
-                    n_mels=audio_config["num_mels"],
-                ),
-            )
+            self.torch_spec = self.get_torch_mel_spectrogram_class(audio_config)
         else:
             self.torch_spec = None
@@ -125,75 +97,3 @@ class LSTMSpeakerEncoder(nn.Module):
         if l2_norm:
             d = torch.nn.functional.normalize(d, p=2, dim=1)
         return d
-
-    @torch.no_grad()
-    def inference(self, x, l2_norm=True):
-        d = self.forward(x, l2_norm=l2_norm)
-        return d
-
-    def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True):
-        """
-        Generate embeddings for a batch of utterances
-        x: 1xTxD
-        """
-        max_len = x.shape[1]
-
-        if max_len < num_frames:
-            num_frames = max_len
-
-        offsets = np.linspace(0, max_len - num_frames, num=num_eval)
-
-        frames_batch = []
-        for offset in offsets:
-            offset = int(offset)
-            end_offset = int(offset + num_frames)
-            frames = x[:, offset:end_offset]
-            frames_batch.append(frames)
-
-        frames_batch = torch.cat(frames_batch, dim=0)
-        embeddings = self.inference(frames_batch)
-
-        if return_mean:
-            embeddings = torch.mean(embeddings, dim=0, keepdim=True)
-
-        return embeddings
-
-    def batch_compute_embedding(self, x, seq_lens, num_frames=160, overlap=0.5):
-        """
-        Generate embeddings for a batch of utterances
-        x: BxTxD
-        """
-        num_overlap = num_frames * overlap
-        max_len = x.shape[1]
-        embed = None
-        num_iters = seq_lens / (num_frames - num_overlap)
-        cur_iter = 0
-        for offset in range(0, max_len, num_frames - num_overlap):
-            cur_iter += 1
-            end_offset = min(x.shape[1], offset + num_frames)
-            frames = x[:, offset:end_offset]
-            if embed is None:
-                embed = self.inference(frames)
-            else:
-                embed[cur_iter <= num_iters, :] += self.inference(frames[cur_iter <= num_iters, :, :])
-        return embed / num_iters
-
-    # pylint: disable=unused-argument, redefined-builtin
-    def load_checkpoint(self, config: dict, checkpoint_path: str, eval: bool = False, use_cuda: bool = False):
-        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
-        self.load_state_dict(state["model"])
-        # load the criterion for emotion classification
-        if "criterion" in state and config.loss == "softmaxproto" and config.model == "emotion_encoder" and config.map_classid_to_classname is not None:
-            criterion = SoftmaxAngleProtoLoss(config.model_params["proj_dim"], len(config.map_classid_to_classname.keys()))
-            criterion.load_state_dict(state["criterion"])
-        else:
-            criterion = None
-
-        if use_cuda:
-            self.cuda()
-            if criterion is not None:
-                criterion = criterion.cuda()
-        if eval:
-            self.eval()
-            assert not self.training
-        return criterion
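
With `inference`, `compute_embedding`, `get_criterion`, `load_checkpoint`, and the mel frontend now inherited from BaseEncoder, an encoder variant only has to define its network and `forward`. A rough sketch under that assumption; `MyEncoder` and its arguments are illustrative and not part of this patch:

import torch
from torch import nn

from TTS.encoder.models.base_encoder import BaseEncoder


class MyEncoder(BaseEncoder):
    """Illustrative subclass, not part of this patch."""

    def __init__(self, input_dim, proj_dim=256, use_torch_spec=False, audio_config=None):
        super().__init__()
        self.use_torch_spec = use_torch_spec
        self.audio_config = audio_config
        # Optional raw-audio frontend shared via BaseEncoder.
        self.torch_spec = self.get_torch_mel_spectrogram_class(audio_config) if use_torch_spec else None
        self.layers = nn.Sequential(nn.Linear(input_dim, proj_dim), nn.ReLU(), nn.Linear(proj_dim, proj_dim))

    def forward(self, x, l2_norm=True):
        if self.use_torch_spec:
            x = self.torch_spec(x).transpose(1, 2)  # B x samples -> B x T x n_mels
        x = self.layers(x).mean(dim=1)  # simple temporal average pooling
        if l2_norm:
            x = torch.nn.functional.normalize(x, p=2, dim=1)
        return x
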
diff --git a/TTS/encoder/models/resnet.py b/TTS/encoder/models/resnet.py
index 65da2ea1..c4ba9537 100644
--- a/TTS/encoder/models/resnet.py
+++ b/TTS/encoder/models/resnet.py
@@ -1,24 +1,8 @@
-import numpy as np
 import torch
-import torchaudio
 from torch import nn

 # from TTS.utils.audio import TorchSTFT
-from TTS.utils.io import load_fsspec
-from TTS.encoder.losses import SoftmaxAngleProtoLoss
-
-class PreEmphasis(nn.Module):
-    def __init__(self, coefficient=0.97):
-        super().__init__()
-        self.coefficient = coefficient
-        self.register_buffer("filter", torch.FloatTensor([-self.coefficient, 1.0]).unsqueeze(0).unsqueeze(0))
-
-    def forward(self, x):
-        assert len(x.size()) == 2
-
-        x = torch.nn.functional.pad(x.unsqueeze(1), (1, 0), "reflect")
-        return torch.nn.functional.conv1d(x, self.filter).squeeze(1)
-
+from TTS.encoder.models.base_encoder import BaseEncoder

 class SELayer(nn.Module):
     def __init__(self, channel, reduction=8):
@@ -71,7 +55,7 @@ class SEBasicBlock(nn.Module):
         return out


-class ResNetSpeakerEncoder(nn.Module):
+class ResNetSpeakerEncoder(BaseEncoder):
     """Implementation of the model H/ASP without batch normalization in speaker embedding. This model was proposed in: https://arxiv.org/abs/2009.14153
     Adapted from: https://github.com/clovaai/voxceleb_trainer
     """
@@ -110,32 +94,7 @@ class ResNetSpeakerEncoder(nn.Module):
         self.instancenorm = nn.InstanceNorm1d(input_dim)

         if self.use_torch_spec:
-            self.torch_spec = torch.nn.Sequential(
-                PreEmphasis(audio_config["preemphasis"]),
-                # TorchSTFT(
-                #     n_fft=audio_config["fft_size"],
-                #     hop_length=audio_config["hop_length"],
-                #     win_length=audio_config["win_length"],
-                #     sample_rate=audio_config["sample_rate"],
-                #     window="hamming_window",
-                #     mel_fmin=0.0,
-                #     mel_fmax=None,
-                #     use_htk=True,
-                #     do_amp_to_db=False,
-                #     n_mels=audio_config["num_mels"],
-                #     power=2.0,
-                #     use_mel=True,
-                #     mel_norm=None,
-                # )
-                torchaudio.transforms.MelSpectrogram(
-                    sample_rate=audio_config["sample_rate"],
-                    n_fft=audio_config["fft_size"],
-                    win_length=audio_config["win_length"],
-                    hop_length=audio_config["hop_length"],
-                    window_fn=torch.hamming_window,
-                    n_mels=audio_config["num_mels"],
-                ),
-            )
+            self.torch_spec = self.get_torch_mel_spectrogram_class(audio_config)
         else:
             self.torch_spec = None
@@ -238,57 +197,3 @@ class ResNetSpeakerEncoder(nn.Module):
         if l2_norm:
             x = torch.nn.functional.normalize(x, p=2, dim=1)
         return x
-
-    @torch.no_grad()
-    def inference(self, x, l2_norm=False):
-        return self.forward(x, l2_norm)
-
-    @torch.no_grad()
-    def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True, l2_norm=True):
-        """
-        Generate embeddings for a batch of utterances
-        x: 1xTxD
-        """
-        # map to the waveform size
-        if self.use_torch_spec:
-            num_frames = num_frames * self.audio_config["hop_length"]
-
-        max_len = x.shape[1]
-
-        if max_len < num_frames:
-            num_frames = max_len
-
-        offsets = np.linspace(0, max_len - num_frames, num=num_eval)
-
-        frames_batch = []
-        for offset in offsets:
-            offset = int(offset)
-            end_offset = int(offset + num_frames)
-            frames = x[:, offset:end_offset]
-            frames_batch.append(frames)
-
-        frames_batch = torch.cat(frames_batch, dim=0)
-        embeddings = self.inference(frames_batch, l2_norm=l2_norm)
-
-        if return_mean:
-            embeddings = torch.mean(embeddings, dim=0, keepdim=True)
-        return embeddings
-
-    def load_checkpoint(self, config: dict, checkpoint_path: str, eval: bool = False, use_cuda: bool = False):
-        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
-        self.load_state_dict(state["model"])
-        # load the criterion for emotion classification
-        if "criterion" in state and config.loss == "softmaxproto" and config.model == "emotion_encoder" and config.map_classid_to_classname is not None:
-            criterion = SoftmaxAngleProtoLoss(config.model_params["proj_dim"], len(config.map_classid_to_classname.keys()))
-            criterion.load_state_dict(state["criterion"])
-        else:
-            criterion = None
-
-        if use_cuda:
-            self.cuda()
-            if criterion is not None:
-                criterion = criterion.cuda()
-        if eval:
-            self.eval()
-            assert not self.training
-        return criterion
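
One last note on the PreEmphasis module that now lives in base_encoder.py (and is removed from resnet.py above): the reflect padding plus conv1d with kernel [-coefficient, 1.0] is the usual first-order pre-emphasis filter y[t] = x[t] - coefficient * x[t-1], applied per batch item. A small self-check, assuming only the import path added by this patch:

import torch

from TTS.encoder.models.base_encoder import PreEmphasis

x = torch.randn(2, 16000)  # batch of raw waveforms (B x samples)
y = PreEmphasis(coefficient=0.97)(x)

# The same filter written out directly; reflect padding supplies x[-1] = x[1].
padded = torch.nn.functional.pad(x.unsqueeze(1), (1, 0), "reflect").squeeze(1)
manual = padded[:, 1:] - 0.97 * padded[:, :-1]
assert torch.allclose(y, manual, atol=1e-6)
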