Add BaseEncoder Class

Edresson Casanova 2022-03-10 15:10:26 -03:00
parent a9208e9edd
commit 50305215b3
4 changed files with 157 additions and 235 deletions

View File: TTS/bin/train_encoder.py

@@ -11,15 +11,14 @@ from torch.utils.data import DataLoader
from trainer.torch import NoamLR
from TTS.encoder.dataset import EncoderDataset
from TTS.encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss
from TTS.encoder.utils.generic_utils import save_best_model, save_checkpoint, setup_speaker_encoder_model
from TTS.encoder.utils.samplers import PerfectBatchSampler
from TTS.encoder.utils.training import init_training
from TTS.encoder.utils.visual import plot_embeddings
from TTS.tts.datasets import load_tts_samples
from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import count_parameters, remove_experiment_folder, set_init_dict
from TTS.utils.io import load_fsspec, copy_model_files
from TTS.utils.generic_utils import count_parameters, remove_experiment_folder
from TTS.utils.io import copy_model_files
from trainer.trainer_utils import get_optimizer
from TTS.utils.training import check_update
@@ -254,40 +253,17 @@ def main(args): # pylint: disable=redefined-outer-name
eval_data_loader, _, _ = setup_loader(ap, is_val=True, verbose=True)
else:
eval_data_loader = None
num_classes = len(train_classes)
if c.loss == "ge2e":
criterion = GE2ELoss(loss_method="softmax")
elif c.loss == "angleproto":
criterion = AngleProtoLoss()
elif c.loss == "softmaxproto":
criterion = SoftmaxAngleProtoLoss(c.model_params["proj_dim"], num_classes)
if c.model == "emotion_encoder":
# update config with the class map
num_classes = len(train_classes)
criterion = model.get_criterion(c, num_classes)
if c.loss == "softmaxproto" and c.model != "speaker_encoder":
c.map_classid_to_classname = map_classid_to_classname
copy_model_files(c, OUT_PATH)
else:
raise Exception("The %s loss is not supported" % c.loss)
if args.restore_path:
checkpoint = load_fsspec(args.restore_path)
try:
model.load_state_dict(checkpoint["model"])
if "criterion" in checkpoint:
criterion.load_state_dict(checkpoint["criterion"])
except (KeyError, RuntimeError):
print(" > Partial model initialization.")
model_dict = model.state_dict()
model_dict = set_init_dict(model_dict, checkpoint["model"], c)
model.load_state_dict(model_dict)
del model_dict
for group in optimizer.param_groups:
group["lr"] = c.lr
print(" > Model restored from step %d" % checkpoint["step"], flush=True)
args.restore_step = checkpoint["step"]
criterion, args.restore_step = model.load_checkpoint(c, args.restore_path, eval=False, use_cuda=use_cuda, criterion=criterion)
print(" > Model restored from step %d" % args.restore_step, flush=True)
else:
args.restore_step = 0
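
For reference, after this change the training script's criterion and restore logic collapses to the following (all names exactly as in the diff above); criterion construction and checkpoint loading are now delegated to the model:

```python
num_classes = len(train_classes)
criterion = model.get_criterion(c, num_classes)

# For classifier encoders, persist the class map alongside the config.
if c.loss == "softmaxproto" and c.model != "speaker_encoder":
    c.map_classid_to_classname = map_classid_to_classname
    copy_model_files(c, OUT_PATH)

if args.restore_path:
    criterion, args.restore_step = model.load_checkpoint(
        c, args.restore_path, eval=False, use_cuda=use_cuda, criterion=criterion
    )
    print(" > Model restored from step %d" % args.restore_step, flush=True)
else:
    args.restore_step = 0
```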

View File: TTS/encoder/models/base_encoder.py (new file)

@@ -0,0 +1,141 @@
import torch
import torchaudio
import numpy as np
from torch import nn
from TTS.utils.io import load_fsspec
from TTS.encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss
from TTS.utils.generic_utils import set_init_dict
from coqpit import Coqpit
class PreEmphasis(nn.Module):
def __init__(self, coefficient=0.97):
super().__init__()
self.coefficient = coefficient
self.register_buffer("filter", torch.FloatTensor([-self.coefficient, 1.0]).unsqueeze(0).unsqueeze(0))
def forward(self, x):
assert len(x.size()) == 2
x = torch.nn.functional.pad(x.unsqueeze(1), (1, 0), "reflect")
return torch.nn.functional.conv1d(x, self.filter).squeeze(1)
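
For intuition, the `conv1d` above implements the first-order high-pass filter y[t] = x[t] - coefficient * x[t-1], with the reflect padding supplying the sample to the left of t = 0. A minimal sanity check, assuming the `PreEmphasis` class above is in scope:

```python
import torch

pre = PreEmphasis(coefficient=0.97)
x = torch.randn(2, 16000)  # (batch, samples)
y = pre(x)

# Explicit difference equation; reflect padding mirrors x[:, 1] to the left,
# so the first output sample is x[:, 0] - 0.97 * x[:, 1].
manual = x.clone()
manual[:, 1:] = x[:, 1:] - 0.97 * x[:, :-1]
manual[:, 0] = x[:, 0] - 0.97 * x[:, 1]
print(torch.allclose(y, manual, atol=1e-6))  # True
```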
class BaseEncoder(nn.Module):
"""Base `encoder` class. Every new `encoder` model must inherit this.
It defines common `encoder`-specific functions.
"""
# pylint: disable=W0102
def __init__(self):
super(BaseEncoder, self).__init__()
def get_torch_mel_spectrogram_class(self, audio_config):
return torch.nn.Sequential(
PreEmphasis(audio_config["preemphasis"]),
# TorchSTFT(
# n_fft=audio_config["fft_size"],
# hop_length=audio_config["hop_length"],
# win_length=audio_config["win_length"],
# sample_rate=audio_config["sample_rate"],
# window="hamming_window",
# mel_fmin=0.0,
# mel_fmax=None,
# use_htk=True,
# do_amp_to_db=False,
# n_mels=audio_config["num_mels"],
# power=2.0,
# use_mel=True,
# mel_norm=None,
# )
torchaudio.transforms.MelSpectrogram(
sample_rate=audio_config["sample_rate"],
n_fft=audio_config["fft_size"],
win_length=audio_config["win_length"],
hop_length=audio_config["hop_length"],
window_fn=torch.hamming_window,
n_mels=audio_config["num_mels"],
)
)
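
A short usage sketch of this frontend; the `audio_config` values are illustrative assumptions (only the keys appear in the method above), and `BaseEncoder` is assumed to be in scope:

```python
import torch

# Illustrative settings; only the dictionary keys come from the code above.
audio_config = {
    "preemphasis": 0.97,
    "sample_rate": 16000,
    "fft_size": 512,
    "win_length": 400,
    "hop_length": 160,
    "num_mels": 64,
}

frontend = BaseEncoder().get_torch_mel_spectrogram_class(audio_config)
wav = torch.randn(2, 16000)  # (batch, samples): two one-second clips
mel = frontend(wav)          # (batch, num_mels, frames)
print(mel.shape)             # torch.Size([2, 64, 101]): 16000 // 160 + 1 frames
```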
@torch.no_grad()
def inference(self, x, l2_norm=False):
return self.forward(x, l2_norm)
@torch.no_grad()
def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True, l2_norm=True):
"""
Generate an embedding for a single utterance by batching `num_eval`
evenly spaced windows of `num_frames` each; with `return_mean` the
window embeddings are averaged into one vector.
x: 1xTxD spectrogram, or a 1xT waveform when `use_torch_spec` is enabled
"""
# map to the waveform size
if self.use_torch_spec:
num_frames = num_frames * self.audio_config["hop_length"]
max_len = x.shape[1]
if max_len < num_frames:
num_frames = max_len
offsets = np.linspace(0, max_len - num_frames, num=num_eval)
frames_batch = []
for offset in offsets:
offset = int(offset)
end_offset = int(offset + num_frames)
frames = x[:, offset:end_offset]
frames_batch.append(frames)
frames_batch = torch.cat(frames_batch, dim=0)
embeddings = self.inference(frames_batch, l2_norm=l2_norm)
if return_mean:
embeddings = torch.mean(embeddings, dim=0, keepdim=True)
return embeddings
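
The window placement is easiest to see with concrete numbers. With `use_torch_spec` enabled, `num_frames` is first mapped to waveform samples, and `np.linspace` then spreads the `num_eval` window start offsets evenly over the utterance (the hop length, sample rate, and duration below are illustrative):

```python
import numpy as np

num_frames = 250 * 160  # 250 mel frames * hop_length -> 40000 samples per window
max_len = 3 * 16000     # a 3-second, 16 kHz utterance
offsets = np.linspace(0, max_len - num_frames, num=10)
print(offsets.astype(int))
# [   0  888 1777 2666 3555 4444 5333 6222 7111 8000]
```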
def get_criterion(self, c: Coqpit, num_classes=None):
if c.loss == "ge2e":
criterion = GE2ELoss(loss_method="softmax")
elif c.loss == "angleproto":
criterion = AngleProtoLoss()
elif c.loss == "softmaxproto":
criterion = SoftmaxAngleProtoLoss(c.model_params["proj_dim"], num_classes)
else:
raise Exception("The %s loss is not supported" % c.loss)
return criterion
def load_checkpoint(self, config: Coqpit, checkpoint_path: str, eval: bool = False, use_cuda: bool = False, criterion=None):
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
try:
self.load_state_dict(state["model"])
except (KeyError, RuntimeError) as error:
# In eval mode, re-raise instead of falling back to partial init
if eval:
raise error
print(" > Partial model initialization.")
model_dict = self.state_dict()
model_dict = set_init_dict(model_dict, state["model"], config)
self.load_state_dict(model_dict)
del model_dict
# load the criterion for restore_path
if criterion is not None and "criterion" in state:
criterion.load_state_dict(state["criterion"])
# instantiate and load the criterion for the encoder classifier at inference time
if eval and criterion is None and "criterion" in state and config.map_classid_to_classname is not None:
criterion = self.get_criterion(config, len(config.map_classid_to_classname))
criterion.load_state_dict(state["criterion"])
if use_cuda:
self.cuda()
if criterion is not None:
criterion = criterion.cuda()
if eval:
self.eval()
assert not self.training
if not eval:
return criterion, state["step"]
return criterion
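
Two hedged usage sketches for the consolidated `load_checkpoint`; `model`, `config`, and the checkpoint paths are assumptions, not part of the commit:

```python
# Resuming training: criterion state and the step to resume from
# come back from the checkpoint.
criterion, restore_step = model.load_checkpoint(
    config, "checkpoint_1000.pth.tar", eval=False, use_cuda=True, criterion=criterion
)

# Inference with a classifier encoder: eval=True re-raises on a partial state
# match, rebuilds the criterion from config.map_classid_to_classname,
# and puts the model in eval mode.
criterion = model.load_checkpoint(config, "best_model.pth.tar", eval=True)
```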

View File: TTS/encoder/models/lstm.py

@@ -1,10 +1,7 @@
import numpy as np
import torch
import torchaudio
from torch import nn
from TTS.encoder.models.resnet import PreEmphasis
from TTS.utils.io import load_fsspec
from TTS.encoder.models.base_encoder import BaseEncoder
class LSTMWithProjection(nn.Module):
@@ -34,7 +31,7 @@ class LSTMWithoutProjection(nn.Module):
return self.relu(self.linear(hidden[-1]))
class LSTMSpeakerEncoder(nn.Module):
class LSTMSpeakerEncoder(BaseEncoder):
def __init__(
self,
input_dim,
@@ -64,32 +61,7 @@ class LSTMSpeakerEncoder(nn.Module):
self.instancenorm = nn.InstanceNorm1d(input_dim)
if self.use_torch_spec:
self.torch_spec = torch.nn.Sequential(
PreEmphasis(audio_config["preemphasis"]),
# TorchSTFT(
# n_fft=audio_config["fft_size"],
# hop_length=audio_config["hop_length"],
# win_length=audio_config["win_length"],
# sample_rate=audio_config["sample_rate"],
# window="hamming_window",
# mel_fmin=0.0,
# mel_fmax=None,
# use_htk=True,
# do_amp_to_db=False,
# n_mels=audio_config["num_mels"],
# power=2.0,
# use_mel=True,
# mel_norm=None,
# )
torchaudio.transforms.MelSpectrogram(
sample_rate=audio_config["sample_rate"],
n_fft=audio_config["fft_size"],
win_length=audio_config["win_length"],
hop_length=audio_config["hop_length"],
window_fn=torch.hamming_window,
n_mels=audio_config["num_mels"],
),
)
self.torch_spec = self.get_torch_mel_spectrogram_class(audio_config)
else:
self.torch_spec = None
@@ -125,75 +97,3 @@ class LSTMSpeakerEncoder(nn.Module):
if l2_norm:
d = torch.nn.functional.normalize(d, p=2, dim=1)
return d
@torch.no_grad()
def inference(self, x, l2_norm=True):
d = self.forward(x, l2_norm=l2_norm)
return d
def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True):
"""
Generate embeddings for a batch of utterances
x: 1xTxD
"""
max_len = x.shape[1]
if max_len < num_frames:
num_frames = max_len
offsets = np.linspace(0, max_len - num_frames, num=num_eval)
frames_batch = []
for offset in offsets:
offset = int(offset)
end_offset = int(offset + num_frames)
frames = x[:, offset:end_offset]
frames_batch.append(frames)
frames_batch = torch.cat(frames_batch, dim=0)
embeddings = self.inference(frames_batch)
if return_mean:
embeddings = torch.mean(embeddings, dim=0, keepdim=True)
return embeddings
def batch_compute_embedding(self, x, seq_lens, num_frames=160, overlap=0.5):
"""
Generate embeddings for a batch of utterances
x: BxTxD
"""
num_overlap = num_frames * overlap
max_len = x.shape[1]
embed = None
num_iters = seq_lens / (num_frames - num_overlap)
cur_iter = 0
for offset in range(0, max_len, num_frames - num_overlap):
cur_iter += 1
end_offset = min(x.shape[1], offset + num_frames)
frames = x[:, offset:end_offset]
if embed is None:
embed = self.inference(frames)
else:
embed[cur_iter <= num_iters, :] += self.inference(frames[cur_iter <= num_iters, :, :])
return embed / num_iters
# pylint: disable=unused-argument, redefined-builtin
def load_checkpoint(self, config: dict, checkpoint_path: str, eval: bool = False, use_cuda: bool = False):
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"])
# load the criterion for emotion classification
if "criterion" in state and config.loss == "softmaxproto" and config.model == "emotion_encoder" and config.map_classid_to_classname is not None:
criterion = SoftmaxAngleProtoLoss(config.model_params["proj_dim"], len(config.map_classid_to_classname.keys()))
criterion.load_state_dict(state["criterion"])
else:
criterion = None
if use_cuda:
self.cuda()
if criterion is not None:
criterion = criterion.cuda()
if eval:
self.eval()
assert not self.training
return criterion

View File: TTS/encoder/models/resnet.py

@@ -1,24 +1,8 @@
import numpy as np
import torch
import torchaudio
from torch import nn
# from TTS.utils.audio import TorchSTFT
from TTS.utils.io import load_fsspec
from TTS.encoder.losses import SoftmaxAngleProtoLoss
class PreEmphasis(nn.Module):
def __init__(self, coefficient=0.97):
super().__init__()
self.coefficient = coefficient
self.register_buffer("filter", torch.FloatTensor([-self.coefficient, 1.0]).unsqueeze(0).unsqueeze(0))
def forward(self, x):
assert len(x.size()) == 2
x = torch.nn.functional.pad(x.unsqueeze(1), (1, 0), "reflect")
return torch.nn.functional.conv1d(x, self.filter).squeeze(1)
from TTS.encoder.models.base_encoder import BaseEncoder
class SELayer(nn.Module):
def __init__(self, channel, reduction=8):
@@ -71,7 +55,7 @@ class SEBasicBlock(nn.Module):
return out
class ResNetSpeakerEncoder(nn.Module):
class ResNetSpeakerEncoder(BaseEncoder):
"""Implementation of the model H/ASP without batch normalization in speaker embedding. This model was proposed in: https://arxiv.org/abs/2009.14153
Adapted from: https://github.com/clovaai/voxceleb_trainer
"""
@@ -110,32 +94,7 @@ class ResNetSpeakerEncoder(nn.Module):
self.instancenorm = nn.InstanceNorm1d(input_dim)
if self.use_torch_spec:
self.torch_spec = torch.nn.Sequential(
PreEmphasis(audio_config["preemphasis"]),
# TorchSTFT(
# n_fft=audio_config["fft_size"],
# hop_length=audio_config["hop_length"],
# win_length=audio_config["win_length"],
# sample_rate=audio_config["sample_rate"],
# window="hamming_window",
# mel_fmin=0.0,
# mel_fmax=None,
# use_htk=True,
# do_amp_to_db=False,
# n_mels=audio_config["num_mels"],
# power=2.0,
# use_mel=True,
# mel_norm=None,
# )
torchaudio.transforms.MelSpectrogram(
sample_rate=audio_config["sample_rate"],
n_fft=audio_config["fft_size"],
win_length=audio_config["win_length"],
hop_length=audio_config["hop_length"],
window_fn=torch.hamming_window,
n_mels=audio_config["num_mels"],
),
)
self.torch_spec = self.get_torch_mel_spectrogram_class(audio_config)
else:
self.torch_spec = None
@@ -238,57 +197,3 @@ class ResNetSpeakerEncoder(nn.Module):
if l2_norm:
x = torch.nn.functional.normalize(x, p=2, dim=1)
return x
@torch.no_grad()
def inference(self, x, l2_norm=False):
return self.forward(x, l2_norm)
@torch.no_grad()
def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True, l2_norm=True):
"""
Generate embeddings for a batch of utterances
x: 1xTxD
"""
# map to the waveform size
if self.use_torch_spec:
num_frames = num_frames * self.audio_config["hop_length"]
max_len = x.shape[1]
if max_len < num_frames:
num_frames = max_len
offsets = np.linspace(0, max_len - num_frames, num=num_eval)
frames_batch = []
for offset in offsets:
offset = int(offset)
end_offset = int(offset + num_frames)
frames = x[:, offset:end_offset]
frames_batch.append(frames)
frames_batch = torch.cat(frames_batch, dim=0)
embeddings = self.inference(frames_batch, l2_norm=l2_norm)
if return_mean:
embeddings = torch.mean(embeddings, dim=0, keepdim=True)
return embeddings
def load_checkpoint(self, config: dict, checkpoint_path: str, eval: bool = False, use_cuda: bool = False):
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"])
# load the criterion for emotion classification
if "criterion" in state and config.loss == "softmaxproto" and config.model == "emotion_encoder" and config.map_classid_to_classname is not None:
criterion = SoftmaxAngleProtoLoss(config.model_params["proj_dim"], len(config.map_classid_to_classname.keys()))
criterion.load_state_dict(state["criterion"])
else:
criterion = None
if use_cuda:
self.cuda()
if criterion is not None:
criterion = criterion.cuda()
if eval:
self.eval()
assert not self.training
return criterion
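
Taken together, the refactor means a new encoder only has to define its network and forward pass; the spectrogram frontend, embedding extraction, criterion selection, and checkpoint I/O all come from `BaseEncoder`. A minimal sketch (`TinySpeakerEncoder` is hypothetical, not part of this commit):

```python
import torch
from torch import nn
from TTS.encoder.models.base_encoder import BaseEncoder

class TinySpeakerEncoder(BaseEncoder):
    """Hypothetical subclass: only the network and forward() are model-specific."""

    def __init__(self, input_dim, proj_dim, audio_config, use_torch_spec=False):
        super().__init__()
        self.use_torch_spec = use_torch_spec
        self.audio_config = audio_config
        # Shared mel frontend inherited from BaseEncoder.
        self.torch_spec = self.get_torch_mel_spectrogram_class(audio_config) if use_torch_spec else None
        self.net = nn.Sequential(nn.Linear(input_dim, proj_dim), nn.ReLU())

    def forward(self, x, l2_norm=True):
        if self.use_torch_spec:
            x = self.torch_spec(x).transpose(1, 2)  # (B, T) -> (B, frames, num_mels)
        d = self.net(x).mean(dim=1)                 # average-pool over time
        if l2_norm:
            d = torch.nn.functional.normalize(d, p=2, dim=1)
        return d

# inference(), compute_embedding(), get_criterion(), and load_checkpoint()
# are inherited from BaseEncoder unchanged.
```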