mirror of https://github.com/coqui-ai/TTS.git
Make style (#1405)
This commit is contained in:
parent 690c96ed28
commit 0870a4faa2
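The edits in this commit are mechanical style changes: long calls are wrapped, single quotes become double quotes, trailing commas are added, and imports are re-sorted. This is the kind of output produced by black and isort; as an assumption (the Makefile target itself is not part of this diff, and the 120-character limit is inferred from where lines are wrapped), the repository's `make style` step would run something along the lines of:

black --line-length 120 .
isort .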
@@ -35,7 +35,7 @@ def main():
     command += unargs
     command.append("")

-    # run processes
+    # run a processes per GPU
     processes = []
     for i in range(num_gpus):
         my_env = os.environ.copy()
@@ -1,17 +1,18 @@
 import argparse
-import torch
 from argparse import RawTextHelpFormatter

+import torch
 from tqdm import tqdm

 from TTS.config import load_config
 from TTS.tts.datasets import load_tts_samples
 from TTS.tts.utils.speakers import SpeakerManager


 def compute_encoder_accuracy(dataset_items, encoder_manager):

     class_name_key = encoder_manager.speaker_encoder_config.class_name_key
-    map_classid_to_classname = getattr(encoder_manager.speaker_encoder_config, 'map_classid_to_classname', None)
+    map_classid_to_classname = getattr(encoder_manager.speaker_encoder_config, "map_classid_to_classname", None)

     class_acc_dict = {}

@@ -210,7 +210,13 @@ If you don't specify any models, then it uses LJSpeech based English model.
     args = parser.parse_args()

     # print the description if either text or list_models is not set
-    if not args.text and not args.list_models and not args.list_speaker_idxs and not args.list_language_idxs and not args.reference_wav:
+    if (
+        not args.text
+        and not args.list_models
+        and not args.list_speaker_idxs
+        and not args.list_language_idxs
+        and not args.reference_wav
+    ):
         parser.parse_args(["-h"])

     # load model manager
@@ -296,7 +302,14 @@ If you don't specify any models, then it uses LJSpeech based English model.
     print(" > Text: {}".format(args.text))

     # kick it
-    wav = synthesizer.tts(args.text, args.speaker_idx, args.language_idx, args.speaker_wav, reference_wav=args.reference_wav, reference_speaker_name=args.reference_speaker_idx)
+    wav = synthesizer.tts(
+        args.text,
+        args.speaker_idx,
+        args.language_idx,
+        args.speaker_wav,
+        reference_wav=args.reference_wav,
+        reference_speaker_name=args.reference_speaker_idx,
+    )

     # save the results
     print(" > Saving output to {}".format(args.out_path))
@@ -9,6 +9,7 @@ import traceback
 import torch
 from torch.utils.data import DataLoader
 from trainer.torch import NoamLR
+from trainer.trainer_utils import get_optimizer

 from TTS.encoder.dataset import EncoderDataset
 from TTS.encoder.utils.generic_utils import save_best_model, save_checkpoint, setup_speaker_encoder_model
@@ -19,7 +20,6 @@ from TTS.tts.datasets import load_tts_samples
 from TTS.utils.audio import AudioProcessor
 from TTS.utils.generic_utils import count_parameters, remove_experiment_folder
 from TTS.utils.io import copy_model_files
-from trainer.trainer_utils import get_optimizer
 from TTS.utils.training import check_update

 torch.backends.cudnn.enabled = True
@@ -56,12 +56,17 @@ def setup_loader(ap: AudioProcessor, is_val: bool = False, verbose: bool = False
         num_classes_in_batch=num_classes_in_batch,
         num_gpus=1,
         shuffle=not is_val,
-        drop_last=True)
+        drop_last=True,
+    )

     if len(classes) < num_classes_in_batch:
         if is_val:
-            raise RuntimeError(f"config.eval_num_classes_in_batch ({num_classes_in_batch}) need to be <= {len(classes)} (Number total of Classes in the Eval dataset) !")
-        raise RuntimeError(f"config.num_classes_in_batch ({num_classes_in_batch}) need to be <= {len(classes)} (Number total of Classes in the Train dataset) !")
+            raise RuntimeError(
+                f"config.eval_num_classes_in_batch ({num_classes_in_batch}) need to be <= {len(classes)} (Number total of Classes in the Eval dataset) !"
+            )
+        raise RuntimeError(
+            f"config.num_classes_in_batch ({num_classes_in_batch}) need to be <= {len(classes)} (Number total of Classes in the Train dataset) !"
+        )

     # set the classes to avoid get wrong class_id when the number of training and eval classes are not equal
     if is_val:
@@ -76,6 +81,7 @@ def setup_loader(ap: AudioProcessor, is_val: bool = False, verbose: bool = False

     return loader, classes, dataset.get_map_classid_to_classname()

+
 def evaluation(model, criterion, data_loader, global_step):
     eval_loss = 0
     for _, data in enumerate(data_loader):
|
||||||
inputs, labels = data
|
inputs, labels = data
|
||||||
|
|
||||||
# agroup samples of each class in the batch. perfect sampler produces [3,2,1,3,2,1] we need [3,3,2,2,1,1]
|
# agroup samples of each class in the batch. perfect sampler produces [3,2,1,3,2,1] we need [3,3,2,2,1,1]
|
||||||
labels = torch.transpose(labels.view(c.eval_num_utter_per_class, c.eval_num_classes_in_batch), 0, 1).reshape(labels.shape)
|
labels = torch.transpose(
|
||||||
inputs = torch.transpose(inputs.view(c.eval_num_utter_per_class, c.eval_num_classes_in_batch, -1), 0, 1).reshape(inputs.shape)
|
labels.view(c.eval_num_utter_per_class, c.eval_num_classes_in_batch), 0, 1
|
||||||
|
).reshape(labels.shape)
|
||||||
|
inputs = torch.transpose(
|
||||||
|
inputs.view(c.eval_num_utter_per_class, c.eval_num_classes_in_batch, -1), 0, 1
|
||||||
|
).reshape(inputs.shape)
|
||||||
|
|
||||||
# dispatch data to GPU
|
# dispatch data to GPU
|
||||||
if use_cuda:
|
if use_cuda:
|
||||||
|
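The two wrapped statements above only change formatting; the regrouping they perform can be illustrated with a standalone sketch (toy values standing in for the eval_num_utter_per_class and eval_num_classes_in_batch settings, not code from this commit):

import torch

# perfect sampler order: one utterance per class per round -> [3, 2, 1, 3, 2, 1]
labels = torch.tensor([3, 2, 1, 3, 2, 1])
num_utter_per_class, num_classes_in_batch = 2, 3
# view as (utterances, classes), swap the axes, and flatten so each class's utterances sit together
regrouped = torch.transpose(labels.view(num_utter_per_class, num_classes_in_batch), 0, 1).reshape(labels.shape)
print(regrouped)  # tensor([3, 3, 2, 2, 1, 1])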
@@ -96,7 +106,9 @@ def evaluation(model, criterion, data_loader, global_step):
         outputs = model(inputs)

         # loss computation
-        loss = criterion(outputs.view(c.eval_num_classes_in_batch, outputs.shape[0] // c.eval_num_classes_in_batch, -1), labels)
+        loss = criterion(
+            outputs.view(c.eval_num_classes_in_batch, outputs.shape[0] // c.eval_num_classes_in_batch, -1), labels
+        )

         eval_loss += loss.item()

@@ -110,6 +122,7 @@ def evaluation(model, criterion, data_loader, global_step):
         dashboard_logger.eval_figures(global_step, figures)
     return eval_avg_loss

+
 def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader, global_step):
     model.train()
     best_loss = float("inf")
@@ -124,8 +137,12 @@ def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader,
            # setup input data
            inputs, labels = data
            # agroup samples of each class in the batch. perfect sampler produces [3,2,1,3,2,1] we need [3,3,2,2,1,1]
-            labels = torch.transpose(labels.view(c.num_utter_per_class, c.num_classes_in_batch), 0, 1).reshape(labels.shape)
-            inputs = torch.transpose(inputs.view(c.num_utter_per_class, c.num_classes_in_batch, -1), 0, 1).reshape(inputs.shape)
+            labels = torch.transpose(labels.view(c.num_utter_per_class, c.num_classes_in_batch), 0, 1).reshape(
+                labels.shape
+            )
+            inputs = torch.transpose(inputs.view(c.num_utter_per_class, c.num_classes_in_batch, -1), 0, 1).reshape(
+                inputs.shape
+            )
            # ToDo: move it to a unit test
            # labels_converted = torch.transpose(labels.view(c.num_utter_per_class, c.num_classes_in_batch), 0, 1).reshape(labels.shape)
            # inputs_converted = torch.transpose(inputs.view(c.num_utter_per_class, c.num_classes_in_batch, -1), 0, 1).reshape(inputs.shape)
|
@ -157,7 +174,9 @@ def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader,
|
||||||
outputs = model(inputs)
|
outputs = model(inputs)
|
||||||
|
|
||||||
# loss computation
|
# loss computation
|
||||||
loss = criterion(outputs.view(c.num_classes_in_batch, outputs.shape[0] // c.num_classes_in_batch, -1), labels)
|
loss = criterion(
|
||||||
|
outputs.view(c.num_classes_in_batch, outputs.shape[0] // c.num_classes_in_batch, -1), labels
|
||||||
|
)
|
||||||
loss.backward()
|
loss.backward()
|
||||||
grad_norm, _ = check_update(model, c.grad_clip)
|
grad_norm, _ = check_update(model, c.grad_clip)
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
@@ -222,9 +241,7 @@ def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader,
            print("\n\n")
            print("--> EVAL PERFORMANCE")
            print(
-                " | > Epoch:{} AvgLoss: {:.5f} ".format(
-                    epoch, eval_loss
-                ),
+                " | > Epoch:{} AvgLoss: {:.5f} ".format(epoch, eval_loss),
                flush=True,
            )
            # save the best checkpoint
|
||||||
copy_model_files(c, OUT_PATH)
|
copy_model_files(c, OUT_PATH)
|
||||||
|
|
||||||
if args.restore_path:
|
if args.restore_path:
|
||||||
criterion, args.restore_step = model.load_checkpoint(c, args.restore_path, eval=False, use_cuda=use_cuda, criterion=criterion)
|
criterion, args.restore_step = model.load_checkpoint(
|
||||||
|
c, args.restore_path, eval=False, use_cuda=use_cuda, criterion=criterion
|
||||||
|
)
|
||||||
print(" > Model restored from step %d" % args.restore_step, flush=True)
|
print(" > Model restored from step %d" % args.restore_step, flush=True)
|
||||||
else:
|
else:
|
||||||
args.restore_step = 0
|
args.restore_step = 0
|
||||||
|
|
|
@@ -33,10 +33,7 @@ class BaseEncoderConfig(BaseTrainingConfig):
     grad_clip: float = 3.0
     lr: float = 0.0001
     optimizer: str = "radam"
-    optimizer_params: Dict = field(default_factory=lambda: {
-        "betas": [0.9, 0.999],
-        "weight_decay": 0
-    })
+    optimizer_params: Dict = field(default_factory=lambda: {"betas": [0.9, 0.999], "weight_decay": 0})
     lr_decay: bool = False
     warmup_steps: int = 4000

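For context on the collapsed optimizer_params default above: dataclass fields with mutable defaults have to go through field(default_factory=...) so each config instance gets its own dict. A minimal standalone sketch (hypothetical ExampleConfig, not a class from this repository):

from dataclasses import dataclass, field
from typing import Dict


@dataclass
class ExampleConfig:
    # each instance gets a fresh dict; a plain dict literal here would raise ValueError at class definition time
    optimizer_params: Dict = field(default_factory=lambda: {"betas": [0.9, 0.999], "weight_decay": 0})


cfg = ExampleConfig()
print(cfg.optimizer_params["weight_decay"])  # 0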
@@ -5,6 +5,7 @@ from torch.utils.data import Dataset

 from TTS.encoder.utils.generic_utils import AugmentWAV

+
 class EncoderDataset(Dataset):
     def __init__(
         self,
@@ -57,7 +58,6 @@ class EncoderDataset(Dataset):
            print(f" | > Num Classes: {len(self.classes)}")
            print(f" | > Classes: {self.classes}")

-
     def load_wav(self, filename):
         audio = self.ap.load_wav(filename, sr=self.ap.sample_rate)
         return audio
@@ -75,9 +75,7 @@ class EncoderDataset(Dataset):
         ]

         # skip classes with number of samples >= self.num_utter_per_class
-        class_to_utters = {
-            k: v for (k, v) in class_to_utters.items() if len(v) >= self.num_utter_per_class
-        }
+        class_to_utters = {k: v for (k, v) in class_to_utters.items() if len(v) >= self.num_utter_per_class}

         classes = list(class_to_utters.keys())
         classes.sort()
@@ -105,11 +103,11 @@ class EncoderDataset(Dataset):

     def get_class_list(self):
         return self.classes

     def set_classes(self, classes):
         self.classes = classes
         self.classname_to_classid = {key: i for i, key in enumerate(self.classes)}

     def get_map_classid_to_classname(self):
         return dict((c_id, c_n) for c_n, c_id in self.classname_to_classid.items())

@@ -195,6 +195,7 @@ class SoftmaxLoss(nn.Module):
         class_id = torch.argmax(activations)
         return class_id

+
 class SoftmaxAngleProtoLoss(nn.Module):
     """
     Implementation of the Softmax AnglePrototypical loss as defined in https://arxiv.org/abs/2009.14153
@@ -1,12 +1,13 @@
+import numpy as np
 import torch
 import torchaudio
-import numpy as np
+from coqpit import Coqpit
 from torch import nn

-from TTS.utils.io import load_fsspec
 from TTS.encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss
 from TTS.utils.generic_utils import set_init_dict
-from coqpit import Coqpit
+from TTS.utils.io import load_fsspec


 class PreEmphasis(nn.Module):
     def __init__(self, coefficient=0.97):
@@ -20,6 +21,7 @@ class PreEmphasis(nn.Module):
         x = torch.nn.functional.pad(x.unsqueeze(1), (1, 0), "reflect")
         return torch.nn.functional.conv1d(x, self.filter).squeeze(1)

+
 class BaseEncoder(nn.Module):
     """Base `encoder` class. Every new `encoder` model must inherit this.

|
||||||
hop_length=audio_config["hop_length"],
|
hop_length=audio_config["hop_length"],
|
||||||
window_fn=torch.hamming_window,
|
window_fn=torch.hamming_window,
|
||||||
n_mels=audio_config["num_mels"],
|
n_mels=audio_config["num_mels"],
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
|
@@ -104,7 +106,9 @@ class BaseEncoder(nn.Module):
            raise Exception("The %s not is a loss supported" % c.loss)
        return criterion

-    def load_checkpoint(self, config: Coqpit, checkpoint_path: str, eval: bool = False, use_cuda: bool = False, criterion=None):
+    def load_checkpoint(
+        self, config: Coqpit, checkpoint_path: str, eval: bool = False, use_cuda: bool = False, criterion=None
+    ):
        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
        try:
            self.load_state_dict(state["model"])
@@ -127,7 +131,12 @@ class BaseEncoder(nn.Module):
            print(" > Criterion load ignored because of:", error)

        # instance and load the criterion for the encoder classifier in inference time
-        if eval and criterion is None and "criterion" in state and getattr(config, 'map_classid_to_classname', None) is not None:
+        if (
+            eval
+            and criterion is None
+            and "criterion" in state
+            and getattr(config, "map_classid_to_classname", None) is not None
+        ):
            criterion = self.get_criterion(config, len(config.map_classid_to_classname))
            criterion.load_state_dict(state["criterion"])

@@ -4,6 +4,7 @@ from torch import nn
 # from TTS.utils.audio import TorchSTFT
 from TTS.encoder.models.base_encoder import BaseEncoder

+
 class SELayer(nn.Module):
     def __init__(self, channel, reduction=8):
         super(SELayer, self).__init__()
@@ -1,4 +1,5 @@
 import random
+
 from torch.utils.data.sampler import Sampler, SubsetRandomSampler


|
@ -34,10 +35,21 @@ class PerfectBatchSampler(Sampler):
|
||||||
drop_last (bool): if True, drops last incomplete batch.
|
drop_last (bool): if True, drops last incomplete batch.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, dataset_items, classes, batch_size, num_classes_in_batch, num_gpus=1, shuffle=True, drop_last=False, label_key="class_name"):
|
def __init__(
|
||||||
|
self,
|
||||||
|
dataset_items,
|
||||||
|
classes,
|
||||||
|
batch_size,
|
||||||
|
num_classes_in_batch,
|
||||||
|
num_gpus=1,
|
||||||
|
shuffle=True,
|
||||||
|
drop_last=False,
|
||||||
|
label_key="class_name",
|
||||||
|
):
|
||||||
super().__init__(dataset_items)
|
super().__init__(dataset_items)
|
||||||
assert batch_size % (num_classes_in_batch * num_gpus) == 0, (
|
assert (
|
||||||
'Batch size must be divisible by number of classes times the number of data parallel devices (if enabled).')
|
batch_size % (num_classes_in_batch * num_gpus) == 0
|
||||||
|
), "Batch size must be divisible by number of classes times the number of data parallel devices (if enabled)."
|
||||||
|
|
||||||
label_indices = {}
|
label_indices = {}
|
||||||
for idx, item in enumerate(dataset_items):
|
for idx, item in enumerate(dataset_items):
|
||||||
|
|
|
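To make the reformatted assertion above concrete, here is a small arithmetic sketch of the constraint it enforces (toy numbers, not taken from the commit):

# with 2 classes per batch on a single GPU, batch_size must be a multiple of 2 * 1
num_classes_in_batch, num_gpus = 2, 1
for batch_size in (4, 6, 7):
    ok = batch_size % (num_classes_in_batch * num_gpus) == 0
    print(batch_size, "ok" if ok else "would trip the assert")  # 4 ok, 6 ok, 7 would trip the assert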
@@ -7,15 +7,15 @@ import torch.distributed as dist
 from coqpit import Coqpit
 from torch import nn
 from torch.utils.data import DataLoader
+from torch.utils.data.sampler import WeightedRandomSampler
 from trainer.torch import DistributedSampler, DistributedSamplerWrapper

 from TTS.model import BaseTrainerModel
 from TTS.tts.datasets.dataset import TTSDataset
 from TTS.tts.utils.languages import LanguageManager, get_language_balancer_weights
-from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager, get_speaker_balancer_weights
+from TTS.tts.utils.speakers import SpeakerManager, get_speaker_balancer_weights, get_speaker_manager
 from TTS.tts.utils.synthesis import synthesis
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
-from torch.utils.data.sampler import WeightedRandomSampler

 # pylint: skip-file

@@ -994,8 +994,11 @@ class Vits(BaseTTS):

        outputs = {"model_outputs": o, "alignments": attn.squeeze(1), "z": z, "z_p": z_p, "m_p": m_p, "logs_p": logs_p}
        return outputs

    @torch.no_grad()
-    def inference_voice_conversion(self, reference_wav, speaker_id=None, d_vector=None, reference_speaker_id=None, reference_d_vector=None):
+    def inference_voice_conversion(
+        self, reference_wav, speaker_id=None, d_vector=None, reference_speaker_id=None, reference_d_vector=None
+    ):
        """Inference for voice conversion

        Args:
@@ -1006,7 +1009,13 @@ class Vits(BaseTTS):
            reference_d_vector (Tensor): d_vector embedding of the reference_wav speaker. Tensor of shape `[B, C]`
        """
        # compute spectrograms
-        y = wav_to_spec(reference_wav, self.config.audio.fft_size, self.config.audio.hop_length, self.config.audio.win_length, center=False).transpose(1, 2)
+        y = wav_to_spec(
+            reference_wav,
+            self.config.audio.fft_size,
+            self.config.audio.hop_length,
+            self.config.audio.win_length,
+            center=False,
+        ).transpose(1, 2)
        y_lengths = torch.tensor([y.size(-1)]).to(y.device)
        speaker_cond_src = reference_speaker_id if reference_speaker_id is not None else reference_d_vector
        speaker_cond_tgt = speaker_id if speaker_id is not None else d_vector
@@ -269,7 +269,9 @@ class SpeakerManager:
        """
        self.speaker_encoder_config = load_config(config_path)
        self.speaker_encoder = setup_speaker_encoder_model(self.speaker_encoder_config)
-        self.speaker_encoder_criterion = self.speaker_encoder.load_checkpoint(self.speaker_encoder_config, model_path, eval=True, use_cuda=self.use_cuda)
+        self.speaker_encoder_criterion = self.speaker_encoder.load_checkpoint(
+            self.speaker_encoder_config, model_path, eval=True, use_cuda=self.use_cuda
+        )
        self.speaker_encoder_ap = AudioProcessor(**self.speaker_encoder_config.audio)

    def compute_d_vector_from_clip(self, wav_file: Union[str, List[str]]) -> list:
@@ -206,6 +206,7 @@ def synthesis(
     }
     return return_dict

+
 def transfer_voice(
     model,
     CONFIG,
@@ -269,12 +270,7 @@ def transfer_voice(
        _func = model.module.inference_voice_conversion
    else:
        _func = model.inference_voice_conversion
-    model_outputs = _func(
-        reference_wav,
-        speaker_id,
-        d_vector,
-        reference_speaker_id,
-        reference_d_vector)
+    model_outputs = _func(reference_wav, speaker_id, d_vector, reference_speaker_id, reference_d_vector)

    # convert outputs to numpy
    # plot results
@@ -214,7 +214,9 @@ class Synthesizer(object):
        if speaker_name and isinstance(speaker_name, str):
            if self.tts_config.use_d_vector_file:
                # get the average speaker embedding from the saved d_vectors.
-                speaker_embedding = self.tts_model.speaker_manager.get_mean_d_vector(speaker_name, num_samples=None, randomize=False)
+                speaker_embedding = self.tts_model.speaker_manager.get_mean_d_vector(
+                    speaker_name, num_samples=None, randomize=False
+                )
                speaker_embedding = np.array(speaker_embedding)[None, :]  # [1 x embedding_dim]
            else:
                # get speaker idx from the speaker name
@@ -315,13 +317,19 @@ class Synthesizer(object):
            if reference_speaker_name and isinstance(reference_speaker_name, str):
                if self.tts_config.use_d_vector_file:
                    # get the speaker embedding from the saved d_vectors.
-                    reference_speaker_embedding = self.tts_model.speaker_manager.get_d_vectors_by_speaker(reference_speaker_name)[0]
-                    reference_speaker_embedding = np.array(reference_speaker_embedding)[None, :]  # [1 x embedding_dim]
+                    reference_speaker_embedding = self.tts_model.speaker_manager.get_d_vectors_by_speaker(
+                        reference_speaker_name
+                    )[0]
+                    reference_speaker_embedding = np.array(reference_speaker_embedding)[
+                        None, :
+                    ]  # [1 x embedding_dim]
                else:
                    # get speaker idx from the speaker name
                    reference_speaker_id = self.tts_model.speaker_manager.speaker_ids[reference_speaker_name]
            else:
-                reference_speaker_embedding = self.tts_model.speaker_manager.compute_d_vector_from_clip(reference_wav)
+                reference_speaker_embedding = self.tts_model.speaker_manager.compute_d_vector_from_clip(
+                    reference_wav
+                )

            outputs = transfer_voice(
                model=self.tts_model,
@@ -332,7 +340,7 @@ class Synthesizer(object):
                d_vector=speaker_embedding,
                use_griffin_lim=use_gl,
                reference_speaker_id=reference_speaker_id,
-                reference_d_vector=reference_speaker_embedding
+                reference_d_vector=reference_speaker_embedding,
            )
            waveform = outputs
            if not use_gl:
@@ -41,11 +41,6 @@ model = GAN(config, ap)

 # init the trainer and 🚀
 trainer = Trainer(
-    TrainerArgs(),
-    config,
-    output_path,
-    model=model,
-    train_samples=train_samples,
-    eval_samples=eval_samples
+    TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
 )
 trainer.fit()
@@ -41,11 +41,6 @@ model = GAN(config, ap)

 # init the trainer and 🚀
 trainer = Trainer(
-    TrainerArgs(),
-    config,
-    output_path,
-    model=model,
-    train_samples=train_samples,
-    eval_samples=eval_samples
+    TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
 )
 trainer.fit()
@@ -84,11 +84,6 @@ model = Tacotron2(config, ap, tokenizer, speaker_manager=None)

 # init the trainer and 🚀
 trainer = Trainer(
-    TrainerArgs(),
-    config,
-    output_path,
-    model=model,
-    train_samples=train_samples,
-    eval_samples=eval_samples
+    TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
 )
 trainer.fit()
@@ -40,11 +40,6 @@ model = GAN(config, ap)

 # init the trainer and 🚀
 trainer = Trainer(
-    TrainerArgs(),
-    config,
-    output_path,
-    model=model,
-    train_samples=train_samples,
-    eval_samples=eval_samples
+    TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
 )
 trainer.fit()
@@ -6,12 +6,11 @@ from trainer import Trainer, TrainerArgs
 from TTS.config.shared_configs import BaseAudioConfig
 from TTS.tts.configs.shared_configs import BaseDatasetConfig
 from TTS.tts.configs.vits_config import VitsConfig
-from TTS.tts.models.vits import CharactersConfig
 from TTS.tts.datasets import load_tts_samples
-from TTS.tts.models.vits import Vits, VitsArgs
+from TTS.tts.models.vits import CharactersConfig, Vits, VitsArgs
 from TTS.tts.utils.languages import LanguageManager
-from TTS.tts.utils.text.tokenizer import TTSTokenizer
 from TTS.tts.utils.speakers import SpeakerManager
+from TTS.tts.utils.text.tokenizer import TTSTokenizer
 from TTS.utils.audio import AudioProcessor

 output_path = os.path.dirname(os.path.abspath(__file__))
@@ -131,11 +130,6 @@ model = Vits(config, ap, tokenizer, speaker_manager, language_manager)

 # init the trainer and 🚀
 trainer = Trainer(
-    TrainerArgs(),
-    config,
-    output_path,
-    model=model,
-    train_samples=train_samples,
-    eval_samples=eval_samples
+    TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
 )
 trainer.fit()
@@ -1,14 +1,13 @@
 import functools
-
 import unittest

 import torch

 from TTS.config.shared_configs import BaseDatasetConfig
+from TTS.encoder.utils.samplers import PerfectBatchSampler
 from TTS.tts.datasets import load_tts_samples
 from TTS.tts.utils.languages import get_language_balancer_weights
 from TTS.tts.utils.speakers import get_speaker_balancer_weights
-from TTS.encoder.utils.samplers import PerfectBatchSampler

 # Fixing random state to avoid random fails
 torch.manual_seed(0)
@@ -60,7 +59,9 @@ class TestSamplers(unittest.TestCase):
        assert not is_balanced(en, pt), "Random sampler is supposed to be unbalanced"

    def test_language_weighted_random_sampler(self):  # pylint: disable=no-self-use
-        weighted_sampler = torch.utils.data.sampler.WeightedRandomSampler(get_language_balancer_weights(train_samples), len(train_samples))
+        weighted_sampler = torch.utils.data.sampler.WeightedRandomSampler(
+            get_language_balancer_weights(train_samples), len(train_samples)
+        )
        ids = functools.reduce(lambda a, b: a + b, [list(weighted_sampler) for i in range(100)])
        en, pt = 0, 0
        for index in ids:
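As background for the weighted-sampler tests above and below: WeightedRandomSampler draws dataset indices with probability proportional to the supplied weights, which is what makes the balancing checks possible. A small standalone sketch with toy weights (not part of the test suite):

import torch

# 4 samples of one class and 1 sample of another; weight each sample by 1 / class count
weights = [1 / 4, 1 / 4, 1 / 4, 1 / 4, 1 / 1]
sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, num_samples=1000, replacement=True)
counts = torch.bincount(torch.tensor(list(sampler)), minlength=5)
print(counts)  # indices 0-3 combined are drawn about as often as index 4 alone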
@@ -73,7 +74,9 @@ class TestSamplers(unittest.TestCase):

    def test_speaker_weighted_random_sampler(self):  # pylint: disable=no-self-use

-        weighted_sampler = torch.utils.data.sampler.WeightedRandomSampler(get_speaker_balancer_weights(train_samples), len(train_samples))
+        weighted_sampler = torch.utils.data.sampler.WeightedRandomSampler(
+            get_speaker_balancer_weights(train_samples), len(train_samples)
+        )
        ids = functools.reduce(lambda a, b: a + b, [list(weighted_sampler) for i in range(100)])
        spk1, spk2 = 0, 0
        for index in ids:
@@ -96,7 +99,8 @@ class TestSamplers(unittest.TestCase):
            num_classes_in_batch=2,
            label_key="speaker_name",
            shuffle=False,
-            drop_last=True)
+            drop_last=True,
+        )
        batchs = functools.reduce(lambda a, b: a + b, [list(sampler) for i in range(100)])
        for batch in batchs:
            spk1, spk2 = 0, 0
@@ -120,7 +124,8 @@ class TestSamplers(unittest.TestCase):
            num_classes_in_batch=2,
            label_key="speaker_name",
            shuffle=True,
-            drop_last=False)
+            drop_last=False,
+        )
        batchs = functools.reduce(lambda a, b: a + b, [list(sampler) for i in range(100)])
        for batch in batchs:
            spk1, spk2 = 0, 0