mirror of https://github.com/coqui-ai/TTS.git
Add BaseEncoder Class
This commit is contained in:
parent a9208e9edd, commit 50305215b3
TTS/bin/train_encoder.py
@@ -11,15 +11,14 @@ from torch.utils.data import DataLoader
 from trainer.torch import NoamLR
 
 from TTS.encoder.dataset import EncoderDataset
 from TTS.encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss
 from TTS.encoder.utils.generic_utils import save_best_model, save_checkpoint, setup_speaker_encoder_model
 from TTS.encoder.utils.samplers import PerfectBatchSampler
 from TTS.encoder.utils.training import init_training
 from TTS.encoder.utils.visual import plot_embeddings
 from TTS.tts.datasets import load_tts_samples
 from TTS.utils.audio import AudioProcessor
-from TTS.utils.generic_utils import count_parameters, remove_experiment_folder, set_init_dict
-from TTS.utils.io import load_fsspec, copy_model_files
+from TTS.utils.generic_utils import count_parameters, remove_experiment_folder
+from TTS.utils.io import copy_model_files
 from trainer.trainer_utils import get_optimizer
 from TTS.utils.training import check_update
 
@@ -254,40 +253,17 @@ def main(args):  # pylint: disable=redefined-outer-name
         eval_data_loader, _, _ = setup_loader(ap, is_val=True, verbose=True)
     else:
         eval_data_loader = None
-    num_classes = len(train_classes)
 
-    if c.loss == "ge2e":
-        criterion = GE2ELoss(loss_method="softmax")
-    elif c.loss == "angleproto":
-        criterion = AngleProtoLoss()
-    elif c.loss == "softmaxproto":
-        criterion = SoftmaxAngleProtoLoss(c.model_params["proj_dim"], num_classes)
-        if c.model == "emotion_encoder":
-            # update config with the class map
+    num_classes = len(train_classes)
+    criterion = model.get_criterion(c, num_classes)
+
+    if c.loss == "softmaxproto" and c.model != "speaker_encoder":
         c.map_classid_to_classname = map_classid_to_classname
         copy_model_files(c, OUT_PATH)
-    else:
-        raise Exception("The %s not is a loss supported" % c.loss)
 
     if args.restore_path:
-        checkpoint = load_fsspec(args.restore_path)
-        try:
-            model.load_state_dict(checkpoint["model"])
-
-            if "criterion" in checkpoint:
-                criterion.load_state_dict(checkpoint["criterion"])
-
-        except (KeyError, RuntimeError):
-            print(" > Partial model initialization.")
-            model_dict = model.state_dict()
-            model_dict = set_init_dict(model_dict, checkpoint["model"], c)
-            model.load_state_dict(model_dict)
-            del model_dict
-            for group in optimizer.param_groups:
-                group["lr"] = c.lr
-
-        print(" > Model restored from step %d" % checkpoint["step"], flush=True)
-        args.restore_step = checkpoint["step"]
+        criterion, args.restore_step = model.load_checkpoint(c, args.restore_path, eval=False, use_cuda=use_cuda, criterion=criterion)
+        print(" > Model restored from step %d" % args.restore_step, flush=True)
     else:
         args.restore_step = 0
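Taken together, the hunks above move loss construction and checkpoint restoring out of the training script and into the model. A minimal sketch of the resulting flow; the `restore_or_init` wrapper, its arguments, and the defaults are illustrative and not part of this commit — only `get_criterion` and `load_checkpoint` come from the diff:

# Illustrative wrapper around the new BaseEncoder API; not part of the commit.
def restore_or_init(model, c, restore_path=None, use_cuda=False, num_classes=10):
    # The model now resolves its own loss from the config
    # (ge2e / angleproto / softmaxproto).
    criterion = model.get_criterion(c, num_classes)
    if restore_path:
        # With eval=False, load_checkpoint also returns the step to resume from
        # and falls back to partial initialization on mismatched checkpoints.
        criterion, restore_step = model.load_checkpoint(
            c, restore_path, eval=False, use_cuda=use_cuda, criterion=criterion
        )
    else:
        restore_step = 0
    return criterion, restore_step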
TTS/encoder/models/base_encoder.py (new file)
@@ -0,0 +1,141 @@
+import torch
+import torchaudio
+import numpy as np
+from torch import nn
+
+from TTS.utils.io import load_fsspec
+from TTS.encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss
+from TTS.utils.generic_utils import set_init_dict
+from coqpit import Coqpit
+
+
+class PreEmphasis(nn.Module):
+    def __init__(self, coefficient=0.97):
+        super().__init__()
+        self.coefficient = coefficient
+        self.register_buffer("filter", torch.FloatTensor([-self.coefficient, 1.0]).unsqueeze(0).unsqueeze(0))
+
+    def forward(self, x):
+        assert len(x.size()) == 2
+
+        x = torch.nn.functional.pad(x.unsqueeze(1), (1, 0), "reflect")
+        return torch.nn.functional.conv1d(x, self.filter).squeeze(1)
+
+
+class BaseEncoder(nn.Module):
+    """Base `encoder` class. Every new `encoder` model must inherit this.
+
+    It defines common `encoder` specific functions.
+    """
+
+    # pylint: disable=W0102
+    def __init__(self):
+        super(BaseEncoder, self).__init__()
+
+    def get_torch_mel_spectrogram_class(self, audio_config):
+        return torch.nn.Sequential(
+            PreEmphasis(audio_config["preemphasis"]),
+            # TorchSTFT(
+            #     n_fft=audio_config["fft_size"],
+            #     hop_length=audio_config["hop_length"],
+            #     win_length=audio_config["win_length"],
+            #     sample_rate=audio_config["sample_rate"],
+            #     window="hamming_window",
+            #     mel_fmin=0.0,
+            #     mel_fmax=None,
+            #     use_htk=True,
+            #     do_amp_to_db=False,
+            #     n_mels=audio_config["num_mels"],
+            #     power=2.0,
+            #     use_mel=True,
+            #     mel_norm=None,
+            # )
+            torchaudio.transforms.MelSpectrogram(
+                sample_rate=audio_config["sample_rate"],
+                n_fft=audio_config["fft_size"],
+                win_length=audio_config["win_length"],
+                hop_length=audio_config["hop_length"],
+                window_fn=torch.hamming_window,
+                n_mels=audio_config["num_mels"],
+            )
+        )
+
+    @torch.no_grad()
+    def inference(self, x, l2_norm=False):
+        return self.forward(x, l2_norm)
+
+    @torch.no_grad()
+    def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True, l2_norm=True):
+        """
+        Generate embeddings for a batch of utterances
+        x: 1xTxD
+        """
+        # map to the waveform size
+        if self.use_torch_spec:
+            num_frames = num_frames * self.audio_config["hop_length"]
+
+        max_len = x.shape[1]
+
+        if max_len < num_frames:
+            num_frames = max_len
+
+        offsets = np.linspace(0, max_len - num_frames, num=num_eval)
+
+        frames_batch = []
+        for offset in offsets:
+            offset = int(offset)
+            end_offset = int(offset + num_frames)
+            frames = x[:, offset:end_offset]
+            frames_batch.append(frames)
+
+        frames_batch = torch.cat(frames_batch, dim=0)
+        embeddings = self.inference(frames_batch, l2_norm=l2_norm)
+
+        if return_mean:
+            embeddings = torch.mean(embeddings, dim=0, keepdim=True)
+        return embeddings
+
+    def get_criterion(self, c: Coqpit, num_classes=None):
+        if c.loss == "ge2e":
+            criterion = GE2ELoss(loss_method="softmax")
+        elif c.loss == "angleproto":
+            criterion = AngleProtoLoss()
+        elif c.loss == "softmaxproto":
+            criterion = SoftmaxAngleProtoLoss(c.model_params["proj_dim"], num_classes)
+        else:
+            raise Exception("The %s not is a loss supported" % c.loss)
+        return criterion
+
+    def load_checkpoint(self, config: Coqpit, checkpoint_path: str, eval: bool = False, use_cuda: bool = False, criterion=None):
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
+        try:
+            self.load_state_dict(state["model"])
+        except (KeyError, RuntimeError) as error:
+            # If eval raise the error
+            if eval:
+                raise error
+
+            print(" > Partial model initialization.")
+            model_dict = self.state_dict()
+            model_dict = set_init_dict(model_dict, state["model"], c)
+            self.load_state_dict(model_dict)
+            del model_dict
+
+        # load the criterion for restore_path
+        if criterion is not None and "criterion" in state:
+            criterion.load_state_dict(state["criterion"])
+
+        # instance and load the criterion for the encoder classifier in inference time
+        if eval and criterion is None and "criterion" in state and config.map_classid_to_classname is not None:
+            criterion = self.get_criterion(config, len(config.map_classid_to_classname))
+            criterion.load_state_dict(state["criterion"])
+
+        if use_cuda:
+            self.cuda()
+            if criterion is not None:
+                criterion = criterion.cuda()
+
+        if eval:
+            self.eval()
+            assert not self.training
+
+        if not eval:
+            return criterion, state["step"]
+        return criterion
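For reference, the `PreEmphasis` module above implements the standard first-order high-pass y[t] = x[t] - 0.97 * x[t-1] as a conv1d, with a left reflect-pad supplying the missing x[-1] sample. A standalone check with toy values (assumes only `torch`; not part of the commit):

import torch

# PreEmphasis boosts high frequencies before the mel spectrogram.
coefficient = 0.97
x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])  # (batch, time)

filt = torch.tensor([[[-coefficient, 1.0]]])  # same buffer as PreEmphasis
padded = torch.nn.functional.pad(x.unsqueeze(1), (1, 0), "reflect")
y = torch.nn.functional.conv1d(padded, filt).squeeze(1)

# For t >= 1: y[t] = x[t] - 0.97 * x[t-1]; t = 0 uses the reflected sample x[1].
expected = torch.tensor([[1.0 - 0.97 * 2.0, 2.0 - 0.97 * 1.0, 3.0 - 0.97 * 2.0, 4.0 - 0.97 * 3.0]])
assert torch.allclose(y, expected)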
TTS/encoder/models/lstm.py
@@ -1,10 +1,7 @@
-import numpy as np
 import torch
-import torchaudio
 from torch import nn
 
-from TTS.encoder.models.resnet import PreEmphasis
-from TTS.utils.io import load_fsspec
+from TTS.encoder.models.base_encoder import BaseEncoder
 
 
 class LSTMWithProjection(nn.Module):
@@ -34,7 +31,7 @@ class LSTMWithoutProjection(nn.Module):
         return self.relu(self.linear(hidden[-1]))
 
 
-class LSTMSpeakerEncoder(nn.Module):
+class LSTMSpeakerEncoder(BaseEncoder):
     def __init__(
         self,
         input_dim,
@@ -64,32 +61,7 @@ class LSTMSpeakerEncoder(nn.Module):
         self.instancenorm = nn.InstanceNorm1d(input_dim)
 
         if self.use_torch_spec:
-            self.torch_spec = torch.nn.Sequential(
-                PreEmphasis(audio_config["preemphasis"]),
-                # TorchSTFT(
-                #     n_fft=audio_config["fft_size"],
-                #     hop_length=audio_config["hop_length"],
-                #     win_length=audio_config["win_length"],
-                #     sample_rate=audio_config["sample_rate"],
-                #     window="hamming_window",
-                #     mel_fmin=0.0,
-                #     mel_fmax=None,
-                #     use_htk=True,
-                #     do_amp_to_db=False,
-                #     n_mels=audio_config["num_mels"],
-                #     power=2.0,
-                #     use_mel=True,
-                #     mel_norm=None,
-                # )
-                torchaudio.transforms.MelSpectrogram(
-                    sample_rate=audio_config["sample_rate"],
-                    n_fft=audio_config["fft_size"],
-                    win_length=audio_config["win_length"],
-                    hop_length=audio_config["hop_length"],
-                    window_fn=torch.hamming_window,
-                    n_mels=audio_config["num_mels"],
-                ),
-            )
+            self.torch_spec = self.get_torch_mel_spectrogram_class(audio_config)
         else:
             self.torch_spec = None
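With this hunk the LSTM encoder no longer builds its own `PreEmphasis` + `MelSpectrogram` pipeline; any subclass gets the frontend from the shared helper. A hypothetical minimal subclass showing the intended pattern (`ToyEncoder` and its layers are made up for illustration, not part of the commit):

import torch
from torch import nn

from TTS.encoder.models.base_encoder import BaseEncoder


class ToyEncoder(BaseEncoder):
    def __init__(self, proj_dim=256, audio_config=None, use_torch_spec=False):
        super().__init__()
        self.use_torch_spec = use_torch_spec
        self.audio_config = audio_config
        # The shared frontend replaces the per-model Sequential removed above.
        self.torch_spec = self.get_torch_mel_spectrogram_class(audio_config) if use_torch_spec else None
        self.proj = nn.LazyLinear(proj_dim)

    def forward(self, x, l2_norm=True):
        if self.use_torch_spec:
            x = self.torch_spec(x)  # waveform (B, T) -> mel (B, n_mels, frames)
        x = self.proj(torch.flatten(x, 1).float())
        return nn.functional.normalize(x, p=2, dim=1) if l2_norm else x

`inference()`, `compute_embedding()`, `get_criterion()`, and `load_checkpoint()` then come for free from `BaseEncoder`.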
@@ -125,75 +97,3 @@ class LSTMSpeakerEncoder(nn.Module):
         if l2_norm:
             d = torch.nn.functional.normalize(d, p=2, dim=1)
         return d
-
-    @torch.no_grad()
-    def inference(self, x, l2_norm=True):
-        d = self.forward(x, l2_norm=l2_norm)
-        return d
-
-    def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True):
-        """
-        Generate embeddings for a batch of utterances
-        x: 1xTxD
-        """
-        max_len = x.shape[1]
-
-        if max_len < num_frames:
-            num_frames = max_len
-
-        offsets = np.linspace(0, max_len - num_frames, num=num_eval)
-
-        frames_batch = []
-        for offset in offsets:
-            offset = int(offset)
-            end_offset = int(offset + num_frames)
-            frames = x[:, offset:end_offset]
-            frames_batch.append(frames)
-
-        frames_batch = torch.cat(frames_batch, dim=0)
-        embeddings = self.inference(frames_batch)
-
-        if return_mean:
-            embeddings = torch.mean(embeddings, dim=0, keepdim=True)
-
-        return embeddings
-
-    def batch_compute_embedding(self, x, seq_lens, num_frames=160, overlap=0.5):
-        """
-        Generate embeddings for a batch of utterances
-        x: BxTxD
-        """
-        num_overlap = num_frames * overlap
-        max_len = x.shape[1]
-        embed = None
-        num_iters = seq_lens / (num_frames - num_overlap)
-        cur_iter = 0
-        for offset in range(0, max_len, num_frames - num_overlap):
-            cur_iter += 1
-            end_offset = min(x.shape[1], offset + num_frames)
-            frames = x[:, offset:end_offset]
-            if embed is None:
-                embed = self.inference(frames)
-            else:
-                embed[cur_iter <= num_iters, :] += self.inference(frames[cur_iter <= num_iters, :, :])
-        return embed / num_iters
-
-    # pylint: disable=unused-argument, redefined-builtin
-    def load_checkpoint(self, config: dict, checkpoint_path: str, eval: bool = False, use_cuda: bool = False):
-        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
-        self.load_state_dict(state["model"])
-        # load the criterion for emotion classification
-        if "criterion" in state and config.loss == "softmaxproto" and config.model == "emotion_encoder" and config.map_classid_to_classname is not None:
-            criterion = SoftmaxAngleProtoLoss(config.model_params["proj_dim"], len(config.map_classid_to_classname.keys()))
-            criterion.load_state_dict(state["criterion"])
-        else:
-            criterion = None
-
-        if use_cuda:
-            self.cuda()
-            if criterion is not None:
-                criterion = criterion.cuda()
-        if eval:
-            self.eval()
-            assert not self.training
-        return criterion
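The methods deleted here survive as the `BaseEncoder` versions shown earlier. How the inherited `compute_embedding()` picks its windows, as a standalone illustration with toy numbers (only `numpy` assumed; not part of the commit):

import numpy as np

max_len, num_frames, num_eval = 1000, 250, 10  # utterance length / window length / windows
num_frames = min(num_frames, max_len)
# num_eval evenly spaced start offsets; windows overlap on short utterances.
offsets = np.linspace(0, max_len - num_frames, num=num_eval)
windows = [(int(o), int(o) + num_frames) for o in offsets]
print(windows)  # (0, 250), (83, 333), ..., (750, 1000); the embeddings are averaged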
TTS/encoder/models/resnet.py
@@ -1,24 +1,8 @@
 import numpy as np
 import torch
 import torchaudio
 from torch import nn
 
-# from TTS.utils.audio import TorchSTFT
-from TTS.utils.io import load_fsspec
-from TTS.encoder.losses import SoftmaxAngleProtoLoss
-
-
-class PreEmphasis(nn.Module):
-    def __init__(self, coefficient=0.97):
-        super().__init__()
-        self.coefficient = coefficient
-        self.register_buffer("filter", torch.FloatTensor([-self.coefficient, 1.0]).unsqueeze(0).unsqueeze(0))
-
-    def forward(self, x):
-        assert len(x.size()) == 2
-
-        x = torch.nn.functional.pad(x.unsqueeze(1), (1, 0), "reflect")
-        return torch.nn.functional.conv1d(x, self.filter).squeeze(1)
+from TTS.encoder.models.base_encoder import BaseEncoder
 
 
 class SELayer(nn.Module):
     def __init__(self, channel, reduction=8):
@@ -71,7 +55,7 @@ class SEBasicBlock(nn.Module):
         return out
 
 
-class ResNetSpeakerEncoder(nn.Module):
+class ResNetSpeakerEncoder(BaseEncoder):
     """Implementation of the model H/ASP without batch normalization in speaker embedding. This model was proposed in: https://arxiv.org/abs/2009.14153
     Adapted from: https://github.com/clovaai/voxceleb_trainer
     """
@@ -110,32 +94,7 @@ class ResNetSpeakerEncoder(nn.Module):
         self.instancenorm = nn.InstanceNorm1d(input_dim)
 
         if self.use_torch_spec:
-            self.torch_spec = torch.nn.Sequential(
-                PreEmphasis(audio_config["preemphasis"]),
-                # TorchSTFT(
-                #     n_fft=audio_config["fft_size"],
-                #     hop_length=audio_config["hop_length"],
-                #     win_length=audio_config["win_length"],
-                #     sample_rate=audio_config["sample_rate"],
-                #     window="hamming_window",
-                #     mel_fmin=0.0,
-                #     mel_fmax=None,
-                #     use_htk=True,
-                #     do_amp_to_db=False,
-                #     n_mels=audio_config["num_mels"],
-                #     power=2.0,
-                #     use_mel=True,
-                #     mel_norm=None,
-                # )
-                torchaudio.transforms.MelSpectrogram(
-                    sample_rate=audio_config["sample_rate"],
-                    n_fft=audio_config["fft_size"],
-                    win_length=audio_config["win_length"],
-                    hop_length=audio_config["hop_length"],
-                    window_fn=torch.hamming_window,
-                    n_mels=audio_config["num_mels"],
-                ),
-            )
+            self.torch_spec = self.get_torch_mel_spectrogram_class(audio_config)
         else:
             self.torch_spec = None
@@ -238,57 +197,3 @@ class ResNetSpeakerEncoder(nn.Module):
         if l2_norm:
             x = torch.nn.functional.normalize(x, p=2, dim=1)
         return x
-
-    @torch.no_grad()
-    def inference(self, x, l2_norm=False):
-        return self.forward(x, l2_norm)
-
-    @torch.no_grad()
-    def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True, l2_norm=True):
-        """
-        Generate embeddings for a batch of utterances
-        x: 1xTxD
-        """
-        # map to the waveform size
-        if self.use_torch_spec:
-            num_frames = num_frames * self.audio_config["hop_length"]
-
-        max_len = x.shape[1]
-
-        if max_len < num_frames:
-            num_frames = max_len
-
-        offsets = np.linspace(0, max_len - num_frames, num=num_eval)
-
-        frames_batch = []
-        for offset in offsets:
-            offset = int(offset)
-            end_offset = int(offset + num_frames)
-            frames = x[:, offset:end_offset]
-            frames_batch.append(frames)
-
-        frames_batch = torch.cat(frames_batch, dim=0)
-        embeddings = self.inference(frames_batch, l2_norm=l2_norm)
-
-        if return_mean:
-            embeddings = torch.mean(embeddings, dim=0, keepdim=True)
-        return embeddings
-
-    def load_checkpoint(self, config: dict, checkpoint_path: str, eval: bool = False, use_cuda: bool = False):
-        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
-        self.load_state_dict(state["model"])
-        # load the criterion for emotion classification
-        if "criterion" in state and config.loss == "softmaxproto" and config.model == "emotion_encoder" and config.map_classid_to_classname is not None:
-            criterion = SoftmaxAngleProtoLoss(config.model_params["proj_dim"], len(config.map_classid_to_classname.keys()))
-            criterion.load_state_dict(state["criterion"])
-        else:
-            criterion = None
-
-        if use_cuda:
-            self.cuda()
-            if criterion is not None:
-                criterion = criterion.cuda()
-        if eval:
-            self.eval()
-            assert not self.training
-        return criterion
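Both encoders now share the same eval-time restore path. A hedged usage sketch: the empty config and the checkpoint path are placeholders, and the `ResNetSpeakerEncoder` constructor defaults are assumed from the repo rather than shown in this diff:

import torch
from coqpit import Coqpit

from TTS.encoder.models.resnet import ResNetSpeakerEncoder

model = ResNetSpeakerEncoder(input_dim=64, proj_dim=512)
# eval=True: strict weight loading (errors raise instead of partial init),
# the model is put in eval mode, and only the optional classifier criterion
# is returned (no training step).
criterion = model.load_checkpoint(Coqpit(), "best_model.pth", eval=True, use_cuda=False)
assert not model.training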