Make style (#1405)

Eren Gölge 2022-03-16 12:13:55 +01:00 committed by GitHub
parent 690c96ed28
commit 0870a4faa2
21 changed files with 184 additions and 139 deletions

View File

@@ -35,7 +35,7 @@ def main():
     command += unargs
     command.append("")
-    # run processes
+    # run a process per GPU
     processes = []
     for i in range(num_gpus):
         my_env = os.environ.copy()
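For context on this hunk: the script copies the environment and launches one training process per visible GPU. A minimal sketch of that pattern, with a hypothetical script path and arguments (the real TTS/bin/distribute.py builds its command and environment from the arguments it was invoked with):

import os
import subprocess

import torch


def launch_one_process_per_gpu(script="train.py", extra_args=None):
    """Spawn one child process per visible GPU, pinning each to a single device."""
    extra_args = extra_args or []
    processes = []
    for gpu_id in range(torch.cuda.device_count()):
        my_env = os.environ.copy()
        # pin the child to one GPU (illustrative; distribute.py sets up its own command/env)
        my_env["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
        processes.append(subprocess.Popen(["python", script] + extra_args, env=my_env))
    for p in processes:
        p.wait()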

View File

@@ -1,17 +1,18 @@
 import argparse
-import torch
 from argparse import RawTextHelpFormatter
+
+import torch
 from tqdm import tqdm

 from TTS.config import load_config
 from TTS.tts.datasets import load_tts_samples
 from TTS.tts.utils.speakers import SpeakerManager


 def compute_encoder_accuracy(dataset_items, encoder_manager):

     class_name_key = encoder_manager.speaker_encoder_config.class_name_key
-    map_classid_to_classname = getattr(encoder_manager.speaker_encoder_config, 'map_classid_to_classname', None)
+    map_classid_to_classname = getattr(encoder_manager.speaker_encoder_config, "map_classid_to_classname", None)

     class_acc_dict = {}

View File

@@ -210,7 +210,13 @@ If you don't specify any models, then it uses LJSpeech based English model.
     args = parser.parse_args()

     # print the description if either text or list_models is not set
-    if not args.text and not args.list_models and not args.list_speaker_idxs and not args.list_language_idxs and not args.reference_wav:
+    if (
+        not args.text
+        and not args.list_models
+        and not args.list_speaker_idxs
+        and not args.list_language_idxs
+        and not args.reference_wav
+    ):
         parser.parse_args(["-h"])

     # load model manager
@@ -296,7 +302,14 @@ If you don't specify any models, then it uses LJSpeech based English model.
     print(" > Text: {}".format(args.text))

     # kick it
-    wav = synthesizer.tts(args.text, args.speaker_idx, args.language_idx, args.speaker_wav, reference_wav=args.reference_wav, reference_speaker_name=args.reference_speaker_idx)
+    wav = synthesizer.tts(
+        args.text,
+        args.speaker_idx,
+        args.language_idx,
+        args.speaker_wav,
+        reference_wav=args.reference_wav,
+        reference_speaker_name=args.reference_speaker_idx,
+    )

     # save the results
     print(" > Saving output to {}".format(args.out_path))

View File

@@ -9,6 +9,7 @@ import traceback
 import torch
 from torch.utils.data import DataLoader
 from trainer.torch import NoamLR
+from trainer.trainer_utils import get_optimizer

 from TTS.encoder.dataset import EncoderDataset
 from TTS.encoder.utils.generic_utils import save_best_model, save_checkpoint, setup_speaker_encoder_model
@@ -19,7 +20,6 @@ from TTS.tts.datasets import load_tts_samples
 from TTS.utils.audio import AudioProcessor
 from TTS.utils.generic_utils import count_parameters, remove_experiment_folder
 from TTS.utils.io import copy_model_files
-from trainer.trainer_utils import get_optimizer
 from TTS.utils.training import check_update

 torch.backends.cudnn.enabled = True
@@ -56,12 +56,17 @@ def setup_loader(ap: AudioProcessor, is_val: bool = False, verbose: bool = False
         num_classes_in_batch=num_classes_in_batch,
         num_gpus=1,
         shuffle=not is_val,
-        drop_last=True)
+        drop_last=True,
+    )

     if len(classes) < num_classes_in_batch:
         if is_val:
-            raise RuntimeError(f"config.eval_num_classes_in_batch ({num_classes_in_batch}) need to be <= {len(classes)} (Number total of Classes in the Eval dataset) !")
-        raise RuntimeError(f"config.num_classes_in_batch ({num_classes_in_batch}) need to be <= {len(classes)} (Number total of Classes in the Train dataset) !")
+            raise RuntimeError(
+                f"config.eval_num_classes_in_batch ({num_classes_in_batch}) need to be <= {len(classes)} (Number total of Classes in the Eval dataset) !"
+            )
+        raise RuntimeError(
+            f"config.num_classes_in_batch ({num_classes_in_batch}) need to be <= {len(classes)} (Number total of Classes in the Train dataset) !"
+        )

     # set the classes to avoid getting a wrong class_id when the numbers of training and eval classes differ
     if is_val:
@@ -76,6 +81,7 @@ def setup_loader(ap: AudioProcessor, is_val: bool = False, verbose: bool = False
     return loader, classes, dataset.get_map_classid_to_classname()


 def evaluation(model, criterion, data_loader, global_step):
     eval_loss = 0
     for _, data in enumerate(data_loader):
@@ -84,8 +90,12 @@ def evaluation(model, criterion, data_loader, global_step):
         inputs, labels = data

         # group samples of each class in the batch. perfect sampler produces [3,2,1,3,2,1] we need [3,3,2,2,1,1]
-        labels = torch.transpose(labels.view(c.eval_num_utter_per_class, c.eval_num_classes_in_batch), 0, 1).reshape(labels.shape)
-        inputs = torch.transpose(inputs.view(c.eval_num_utter_per_class, c.eval_num_classes_in_batch, -1), 0, 1).reshape(inputs.shape)
+        labels = torch.transpose(
+            labels.view(c.eval_num_utter_per_class, c.eval_num_classes_in_batch), 0, 1
+        ).reshape(labels.shape)
+        inputs = torch.transpose(
+            inputs.view(c.eval_num_utter_per_class, c.eval_num_classes_in_batch, -1), 0, 1
+        ).reshape(inputs.shape)

         # dispatch data to GPU
         if use_cuda:
@@ -96,7 +106,9 @@ def evaluation(model, criterion, data_loader, global_step):
         outputs = model(inputs)

         # loss computation
-        loss = criterion(outputs.view(c.eval_num_classes_in_batch, outputs.shape[0] // c.eval_num_classes_in_batch, -1), labels)
+        loss = criterion(
+            outputs.view(c.eval_num_classes_in_batch, outputs.shape[0] // c.eval_num_classes_in_batch, -1), labels
+        )
         eval_loss += loss.item()
@@ -110,6 +122,7 @@ def evaluation(model, criterion, data_loader, global_step):
         dashboard_logger.eval_figures(global_step, figures)

     return eval_avg_loss


 def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader, global_step):
     model.train()
     best_loss = float("inf")
@@ -124,8 +137,12 @@ def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader,
             # setup input data
             inputs, labels = data

             # group samples of each class in the batch. perfect sampler produces [3,2,1,3,2,1] we need [3,3,2,2,1,1]
-            labels = torch.transpose(labels.view(c.num_utter_per_class, c.num_classes_in_batch), 0, 1).reshape(labels.shape)
-            inputs = torch.transpose(inputs.view(c.num_utter_per_class, c.num_classes_in_batch, -1), 0, 1).reshape(inputs.shape)
+            labels = torch.transpose(labels.view(c.num_utter_per_class, c.num_classes_in_batch), 0, 1).reshape(
+                labels.shape
+            )
+            inputs = torch.transpose(inputs.view(c.num_utter_per_class, c.num_classes_in_batch, -1), 0, 1).reshape(
+                inputs.shape
+            )

             # ToDo: move it to a unit test
             # labels_converted = torch.transpose(labels.view(c.num_utter_per_class, c.num_classes_in_batch), 0, 1).reshape(labels.shape)
             # inputs_converted = torch.transpose(inputs.view(c.num_utter_per_class, c.num_classes_in_batch, -1), 0, 1).reshape(inputs.shape)
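The view/transpose/reshape idiom that black re-wrapped above is what regroups an interleaved batch so utterances of the same class become adjacent. A quick standalone check of the trick on a toy label tensor (2 utterances per class, 3 classes per batch):

import torch

num_utter_per_class, num_classes_in_batch = 2, 3
# the perfect sampler yields classes interleaved: [3, 2, 1, 3, 2, 1]
labels = torch.tensor([3, 2, 1, 3, 2, 1])
regrouped = torch.transpose(labels.view(num_utter_per_class, num_classes_in_batch), 0, 1).reshape(labels.shape)
print(regrouped)  # tensor([3, 3, 2, 2, 1, 1])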
@@ -157,7 +174,9 @@ def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader,
             outputs = model(inputs)

             # loss computation
-            loss = criterion(outputs.view(c.num_classes_in_batch, outputs.shape[0] // c.num_classes_in_batch, -1), labels)
+            loss = criterion(
+                outputs.view(c.num_classes_in_batch, outputs.shape[0] // c.num_classes_in_batch, -1), labels
+            )
             loss.backward()
             grad_norm, _ = check_update(model, c.grad_clip)
             optimizer.step()
@@ -222,9 +241,7 @@ def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader,
             print("\n\n")
             print("--> EVAL PERFORMANCE")
             print(
-                " | > Epoch:{} AvgLoss: {:.5f} ".format(
-                    epoch, eval_loss
-                ),
+                " | > Epoch:{} AvgLoss: {:.5f} ".format(epoch, eval_loss),
                 flush=True,
             )
             # save the best checkpoint
@@ -262,7 +279,9 @@ def main(args):  # pylint: disable=redefined-outer-name
     copy_model_files(c, OUT_PATH)

     if args.restore_path:
-        criterion, args.restore_step = model.load_checkpoint(c, args.restore_path, eval=False, use_cuda=use_cuda, criterion=criterion)
+        criterion, args.restore_step = model.load_checkpoint(
+            c, args.restore_path, eval=False, use_cuda=use_cuda, criterion=criterion
+        )
         print(" > Model restored from step %d" % args.restore_step, flush=True)
     else:
         args.restore_step = 0

View File

@@ -33,10 +33,7 @@ class BaseEncoderConfig(BaseTrainingConfig):
     grad_clip: float = 3.0
     lr: float = 0.0001
     optimizer: str = "radam"
-    optimizer_params: Dict = field(default_factory=lambda: {
-        "betas": [0.9, 0.999],
-        "weight_decay": 0
-    })
+    optimizer_params: Dict = field(default_factory=lambda: {"betas": [0.9, 0.999], "weight_decay": 0})
     lr_decay: bool = False
     warmup_steps: int = 4000

View File

@@ -5,6 +5,7 @@ from torch.utils.data import Dataset
 from TTS.encoder.utils.generic_utils import AugmentWAV

+
 class EncoderDataset(Dataset):
     def __init__(
         self,
@@ -57,7 +58,6 @@ class EncoderDataset(Dataset):
         print(f" | > Num Classes: {len(self.classes)}")
         print(f" | > Classes: {self.classes}")

-
     def load_wav(self, filename):
         audio = self.ap.load_wav(filename, sr=self.ap.sample_rate)
         return audio
@@ -75,9 +75,7 @@ class EncoderDataset(Dataset):
         ]

         # keep only classes that have at least self.num_utter_per_class samples
-        class_to_utters = {
-            k: v for (k, v) in class_to_utters.items() if len(v) >= self.num_utter_per_class
-        }
+        class_to_utters = {k: v for (k, v) in class_to_utters.items() if len(v) >= self.num_utter_per_class}

         classes = list(class_to_utters.keys())
         classes.sort()
@@ -105,11 +103,11 @@ class EncoderDataset(Dataset):
     def get_class_list(self):
         return self.classes

     def set_classes(self, classes):
         self.classes = classes
         self.classname_to_classid = {key: i for i, key in enumerate(self.classes)}

     def get_map_classid_to_classname(self):
         return dict((c_id, c_n) for c_n, c_id in self.classname_to_classid.items())

View File

@@ -195,6 +195,7 @@ class SoftmaxLoss(nn.Module):
         class_id = torch.argmax(activations)
         return class_id

+
 class SoftmaxAngleProtoLoss(nn.Module):
     """
     Implementation of the Softmax AnglePrototypical loss as defined in https://arxiv.org/abs/2009.14153

View File

@@ -1,12 +1,13 @@
+import numpy as np
 import torch
 import torchaudio
-import numpy as np
+from coqpit import Coqpit
 from torch import nn
-from TTS.utils.io import load_fsspec
+
 from TTS.encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss
 from TTS.utils.generic_utils import set_init_dict
-from coqpit import Coqpit
+from TTS.utils.io import load_fsspec


 class PreEmphasis(nn.Module):
     def __init__(self, coefficient=0.97):
@@ -20,6 +21,7 @@ class PreEmphasis(nn.Module):
         x = torch.nn.functional.pad(x.unsqueeze(1), (1, 0), "reflect")
         return torch.nn.functional.conv1d(x, self.filter).squeeze(1)

+
 class BaseEncoder(nn.Module):
     """Base `encoder` class. Every new `encoder` model must inherit this.
@@ -55,7 +57,7 @@ class BaseEncoder(nn.Module):
                 hop_length=audio_config["hop_length"],
                 window_fn=torch.hamming_window,
                 n_mels=audio_config["num_mels"],
-            )
+            ),
         )

     @torch.no_grad()
@@ -104,7 +106,9 @@ class BaseEncoder(nn.Module):
             raise Exception("The %s not is a loss supported" % c.loss)
         return criterion

-    def load_checkpoint(self, config: Coqpit, checkpoint_path: str, eval: bool = False, use_cuda: bool = False, criterion=None):
+    def load_checkpoint(
+        self, config: Coqpit, checkpoint_path: str, eval: bool = False, use_cuda: bool = False, criterion=None
+    ):
         state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
         try:
             self.load_state_dict(state["model"])
@@ -127,7 +131,12 @@ class BaseEncoder(nn.Module):
                 print(" > Criterion load ignored because of:", error)

         # instance and load the criterion for the encoder classifier in inference time
-        if eval and criterion is None and "criterion" in state and getattr(config, 'map_classid_to_classname', None) is not None:
+        if (
+            eval
+            and criterion is None
+            and "criterion" in state
+            and getattr(config, "map_classid_to_classname", None) is not None
+        ):
             criterion = self.get_criterion(config, len(config.map_classid_to_classname))
             criterion.load_state_dict(state["criterion"])

View File

@@ -4,6 +4,7 @@ from torch import nn
 # from TTS.utils.audio import TorchSTFT
 from TTS.encoder.models.base_encoder import BaseEncoder

+
 class SELayer(nn.Module):
     def __init__(self, channel, reduction=8):
         super(SELayer, self).__init__()

View File

@@ -1,4 +1,5 @@
 import random
+
 from torch.utils.data.sampler import Sampler, SubsetRandomSampler
@@ -34,10 +35,21 @@ class PerfectBatchSampler(Sampler):
         drop_last (bool): if True, drops last incomplete batch.
     """

-    def __init__(self, dataset_items, classes, batch_size, num_classes_in_batch, num_gpus=1, shuffle=True, drop_last=False, label_key="class_name"):
+    def __init__(
+        self,
+        dataset_items,
+        classes,
+        batch_size,
+        num_classes_in_batch,
+        num_gpus=1,
+        shuffle=True,
+        drop_last=False,
+        label_key="class_name",
+    ):
         super().__init__(dataset_items)

-        assert batch_size % (num_classes_in_batch * num_gpus) == 0, (
-            'Batch size must be divisible by number of classes times the number of data parallel devices (if enabled).')
+        assert (
+            batch_size % (num_classes_in_batch * num_gpus) == 0
+        ), "Batch size must be divisible by number of classes times the number of data parallel devices (if enabled)."

         label_indices = {}
         for idx, item in enumerate(dataset_items):
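The re-wrapped constructor and assertion above spell out the sampler's contract: batch_size must be divisible by num_classes_in_batch * num_gpus so every batch carries the same number of utterances per class. A minimal usage sketch under that constraint (the item dicts are made up for illustration; only keys matching label_key matter):

from TTS.encoder.utils.samplers import PerfectBatchSampler

# toy dataset: 4 utterances for each of 2 speakers
items = [{"class_name": spk, "wav": f"{spk}_{i}.wav"} for spk in ("spk1", "spk2") for i in range(4)]

sampler = PerfectBatchSampler(
    dataset_items=items,
    classes=["spk1", "spk2"],
    batch_size=4,  # divisible by num_classes_in_batch * num_gpus
    num_classes_in_batch=2,
    num_gpus=1,
    shuffle=False,
    drop_last=True,
    label_key="class_name",
)
for batch in sampler:
    print(batch)  # each batch holds indices drawn evenly from both classes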

View File

@@ -7,15 +7,15 @@ import torch.distributed as dist
 from coqpit import Coqpit
 from torch import nn
 from torch.utils.data import DataLoader
+from torch.utils.data.sampler import WeightedRandomSampler
 from trainer.torch import DistributedSampler, DistributedSamplerWrapper

 from TTS.model import BaseTrainerModel
 from TTS.tts.datasets.dataset import TTSDataset
 from TTS.tts.utils.languages import LanguageManager, get_language_balancer_weights
-from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager, get_speaker_balancer_weights
+from TTS.tts.utils.speakers import SpeakerManager, get_speaker_balancer_weights, get_speaker_manager
 from TTS.tts.utils.synthesis import synthesis
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
-from torch.utils.data.sampler import WeightedRandomSampler

 # pylint: skip-file

View File

@@ -994,8 +994,11 @@ class Vits(BaseTTS):
         outputs = {"model_outputs": o, "alignments": attn.squeeze(1), "z": z, "z_p": z_p, "m_p": m_p, "logs_p": logs_p}
         return outputs

     @torch.no_grad()
-    def inference_voice_conversion(self, reference_wav, speaker_id=None, d_vector=None, reference_speaker_id=None, reference_d_vector=None):
+    def inference_voice_conversion(
+        self, reference_wav, speaker_id=None, d_vector=None, reference_speaker_id=None, reference_d_vector=None
+    ):
         """Inference for voice conversion

         Args:
@@ -1006,7 +1009,13 @@ class Vits(BaseTTS):
             reference_d_vector (Tensor): d_vector embedding of the reference_wav speaker. Tensor of shape `[B, C]`
         """
         # compute spectrograms
-        y = wav_to_spec(reference_wav, self.config.audio.fft_size, self.config.audio.hop_length, self.config.audio.win_length, center=False).transpose(1, 2)
+        y = wav_to_spec(
+            reference_wav,
+            self.config.audio.fft_size,
+            self.config.audio.hop_length,
+            self.config.audio.win_length,
+            center=False,
+        ).transpose(1, 2)
         y_lengths = torch.tensor([y.size(-1)]).to(y.device)
         speaker_cond_src = reference_speaker_id if reference_speaker_id is not None else reference_d_vector
         speaker_cond_tgt = speaker_id if speaker_id is not None else d_vector

View File

@@ -269,7 +269,9 @@ class SpeakerManager:
         """
         self.speaker_encoder_config = load_config(config_path)
         self.speaker_encoder = setup_speaker_encoder_model(self.speaker_encoder_config)
-        self.speaker_encoder_criterion = self.speaker_encoder.load_checkpoint(self.speaker_encoder_config, model_path, eval=True, use_cuda=self.use_cuda)
+        self.speaker_encoder_criterion = self.speaker_encoder.load_checkpoint(
+            self.speaker_encoder_config, model_path, eval=True, use_cuda=self.use_cuda
+        )
         self.speaker_encoder_ap = AudioProcessor(**self.speaker_encoder_config.audio)

     def compute_d_vector_from_clip(self, wav_file: Union[str, List[str]]) -> list:

View File

@@ -206,6 +206,7 @@ def synthesis(
     }
     return return_dict

+
 def transfer_voice(
     model,
     CONFIG,
@@ -269,12 +270,7 @@ def transfer_voice(
         _func = model.module.inference_voice_conversion
     else:
         _func = model.inference_voice_conversion
-    model_outputs = _func(
-        reference_wav,
-        speaker_id,
-        d_vector,
-        reference_speaker_id,
-        reference_d_vector)
+    model_outputs = _func(reference_wav, speaker_id, d_vector, reference_speaker_id, reference_d_vector)

     # convert outputs to numpy
     # plot results

View File

@@ -214,7 +214,9 @@ class Synthesizer(object):
         if speaker_name and isinstance(speaker_name, str):
             if self.tts_config.use_d_vector_file:
                 # get the average speaker embedding from the saved d_vectors.
-                speaker_embedding = self.tts_model.speaker_manager.get_mean_d_vector(speaker_name, num_samples=None, randomize=False)
+                speaker_embedding = self.tts_model.speaker_manager.get_mean_d_vector(
+                    speaker_name, num_samples=None, randomize=False
+                )
                 speaker_embedding = np.array(speaker_embedding)[None, :]  # [1 x embedding_dim]
             else:
                 # get speaker idx from the speaker name
@@ -315,13 +317,19 @@ class Synthesizer(object):
             if reference_speaker_name and isinstance(reference_speaker_name, str):
                 if self.tts_config.use_d_vector_file:
                     # get the speaker embedding from the saved d_vectors.
-                    reference_speaker_embedding = self.tts_model.speaker_manager.get_d_vectors_by_speaker(reference_speaker_name)[0]
-                    reference_speaker_embedding = np.array(reference_speaker_embedding)[None, :]  # [1 x embedding_dim]
+                    reference_speaker_embedding = self.tts_model.speaker_manager.get_d_vectors_by_speaker(
+                        reference_speaker_name
+                    )[0]
+                    reference_speaker_embedding = np.array(reference_speaker_embedding)[
+                        None, :
+                    ]  # [1 x embedding_dim]
                 else:
                     # get speaker idx from the speaker name
                     reference_speaker_id = self.tts_model.speaker_manager.speaker_ids[reference_speaker_name]
             else:
-                reference_speaker_embedding = self.tts_model.speaker_manager.compute_d_vector_from_clip(reference_wav)
+                reference_speaker_embedding = self.tts_model.speaker_manager.compute_d_vector_from_clip(
+                    reference_wav
+                )

             outputs = transfer_voice(
                 model=self.tts_model,
@@ -332,7 +340,7 @@ class Synthesizer(object):
                 d_vector=speaker_embedding,
                 use_griffin_lim=use_gl,
                 reference_speaker_id=reference_speaker_id,
-                reference_d_vector=reference_speaker_embedding
+                reference_d_vector=reference_speaker_embedding,
             )
             waveform = outputs
             if not use_gl:

View File

@@ -41,11 +41,6 @@ model = GAN(config, ap)

 # init the trainer and 🚀
 trainer = Trainer(
-    TrainerArgs(),
-    config,
-    output_path,
-    model=model,
-    train_samples=train_samples,
-    eval_samples=eval_samples
+    TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
 )
 trainer.fit()

View File

@@ -41,11 +41,6 @@ model = GAN(config, ap)

 # init the trainer and 🚀
 trainer = Trainer(
-    TrainerArgs(),
-    config,
-    output_path,
-    model=model,
-    train_samples=train_samples,
-    eval_samples=eval_samples
+    TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
 )
 trainer.fit()

View File

@@ -84,11 +84,6 @@ model = Tacotron2(config, ap, tokenizer, speaker_manager=None)

 # init the trainer and 🚀
 trainer = Trainer(
-    TrainerArgs(),
-    config,
-    output_path,
-    model=model,
-    train_samples=train_samples,
-    eval_samples=eval_samples
+    TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
 )
 trainer.fit()

View File

@@ -40,11 +40,6 @@ model = GAN(config, ap)

 # init the trainer and 🚀
 trainer = Trainer(
-    TrainerArgs(),
-    config,
-    output_path,
-    model=model,
-    train_samples=train_samples,
-    eval_samples=eval_samples
+    TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
 )
 trainer.fit()

View File

@@ -6,12 +6,11 @@ from trainer import Trainer, TrainerArgs
 from TTS.config.shared_configs import BaseAudioConfig
 from TTS.tts.configs.shared_configs import BaseDatasetConfig
 from TTS.tts.configs.vits_config import VitsConfig
-from TTS.tts.models.vits import CharactersConfig
 from TTS.tts.datasets import load_tts_samples
-from TTS.tts.models.vits import Vits, VitsArgs
+from TTS.tts.models.vits import CharactersConfig, Vits, VitsArgs
 from TTS.tts.utils.languages import LanguageManager
-from TTS.tts.utils.text.tokenizer import TTSTokenizer
 from TTS.tts.utils.speakers import SpeakerManager
+from TTS.tts.utils.text.tokenizer import TTSTokenizer
 from TTS.utils.audio import AudioProcessor

 output_path = os.path.dirname(os.path.abspath(__file__))
@@ -131,11 +130,6 @@ model = Vits(config, ap, tokenizer, speaker_manager, language_manager)

 # init the trainer and 🚀
 trainer = Trainer(
-    TrainerArgs(),
-    config,
-    output_path,
-    model=model,
-    train_samples=train_samples,
-    eval_samples=eval_samples
+    TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
 )
 trainer.fit()

View File

@@ -1,14 +1,13 @@
 import functools
 import unittest

 import torch

 from TTS.config.shared_configs import BaseDatasetConfig
+from TTS.encoder.utils.samplers import PerfectBatchSampler
 from TTS.tts.datasets import load_tts_samples
 from TTS.tts.utils.languages import get_language_balancer_weights
 from TTS.tts.utils.speakers import get_speaker_balancer_weights
-from TTS.encoder.utils.samplers import PerfectBatchSampler

 # Fixing random state to avoid random fails
 torch.manual_seed(0)
@@ -60,7 +59,9 @@ class TestSamplers(unittest.TestCase):
         assert not is_balanced(en, pt), "Random sampler is supposed to be unbalanced"

     def test_language_weighted_random_sampler(self):  # pylint: disable=no-self-use
-        weighted_sampler = torch.utils.data.sampler.WeightedRandomSampler(get_language_balancer_weights(train_samples), len(train_samples))
+        weighted_sampler = torch.utils.data.sampler.WeightedRandomSampler(
+            get_language_balancer_weights(train_samples), len(train_samples)
+        )
         ids = functools.reduce(lambda a, b: a + b, [list(weighted_sampler) for i in range(100)])
         en, pt = 0, 0
         for index in ids:
@@ -73,7 +74,9 @@ class TestSamplers(unittest.TestCase):
     def test_speaker_weighted_random_sampler(self):  # pylint: disable=no-self-use
-        weighted_sampler = torch.utils.data.sampler.WeightedRandomSampler(get_speaker_balancer_weights(train_samples), len(train_samples))
+        weighted_sampler = torch.utils.data.sampler.WeightedRandomSampler(
+            get_speaker_balancer_weights(train_samples), len(train_samples)
+        )
         ids = functools.reduce(lambda a, b: a + b, [list(weighted_sampler) for i in range(100)])
         spk1, spk2 = 0, 0
         for index in ids:
@@ -96,7 +99,8 @@ class TestSamplers(unittest.TestCase):
             num_classes_in_batch=2,
             label_key="speaker_name",
             shuffle=False,
-            drop_last=True)
+            drop_last=True,
+        )
         batchs = functools.reduce(lambda a, b: a + b, [list(sampler) for i in range(100)])
         for batch in batchs:
             spk1, spk2 = 0, 0
@@ -120,7 +124,8 @@ class TestSamplers(unittest.TestCase):
             num_classes_in_batch=2,
             label_key="speaker_name",
             shuffle=True,
-            drop_last=False)
+            drop_last=False,
+        )
         batchs = functools.reduce(lambda a, b: a + b, [list(sampler) for i in range(100)])
         for batch in batchs:
             spk1, spk2 = 0, 0
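The weighted-sampler tests above rely on balancer helpers that weight each sample by the inverse frequency of its language or speaker before handing the weights to torch's WeightedRandomSampler. A self-contained sketch of that balancing idea, independent of the TTS helpers (the sample dicts are illustrative):

from collections import Counter

import torch

# toy, imbalanced sample list: 6 English items vs 2 Portuguese items
samples = [{"language": "en"}] * 6 + [{"language": "pt"}] * 2

# inverse-frequency weight per sample, roughly what the balancer helpers compute
counts = Counter(s["language"] for s in samples)
weights = torch.tensor([1.0 / counts[s["language"]] for s in samples], dtype=torch.double)

sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, num_samples=len(samples))
drawn = Counter(samples[i]["language"] for i in sampler)
print(drawn)  # "en" and "pt" drawn in roughly equal proportion in expectation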