Make style (#1405)

Eren Gölge 2022-03-16 12:13:55 +01:00 committed by GitHub
parent 690c96ed28
commit 0870a4faa2
21 changed files with 184 additions and 139 deletions

View File

@@ -35,7 +35,7 @@ def main():
command += unargs
command.append("")
# run processes
# run a process per GPU
processes = []
for i in range(num_gpus):
my_env = os.environ.copy()

View File

@@ -1,17 +1,18 @@
import argparse
import torch
from argparse import RawTextHelpFormatter
import torch
from tqdm import tqdm
from TTS.config import load_config
from TTS.tts.datasets import load_tts_samples
from TTS.tts.utils.speakers import SpeakerManager
def compute_encoder_accuracy(dataset_items, encoder_manager):
class_name_key = encoder_manager.speaker_encoder_config.class_name_key
map_classid_to_classname = getattr(encoder_manager.speaker_encoder_config, 'map_classid_to_classname', None)
map_classid_to_classname = getattr(encoder_manager.speaker_encoder_config, "map_classid_to_classname", None)
class_acc_dict = {}
@@ -43,11 +44,11 @@ def compute_encoder_accuracy(dataset_items, encoder_manager):
acc_avg = 0
for key, values in class_acc_dict.items():
acc = sum(values)/len(values)
acc = sum(values) / len(values)
print("Class", key, "Accuracy:", acc)
acc_avg += acc
print("Average Accuracy:", acc_avg/len(class_acc_dict))
print("Average Accuracy:", acc_avg / len(class_acc_dict))
if __name__ == "__main__":

View File

@@ -210,7 +210,13 @@ If you don't specify any models, then it uses LJSpeech based English model.
args = parser.parse_args()
# print the description if none of text, list_models, list_speaker_idxs, list_language_idxs or reference_wav is set
if not args.text and not args.list_models and not args.list_speaker_idxs and not args.list_language_idxs and not args.reference_wav:
if (
not args.text
and not args.list_models
and not args.list_speaker_idxs
and not args.list_language_idxs
and not args.reference_wav
):
parser.parse_args(["-h"])
# load model manager
@@ -296,7 +302,14 @@ If you don't specify any models, then it uses LJSpeech based English model.
print(" > Text: {}".format(args.text))
# kick it
wav = synthesizer.tts(args.text, args.speaker_idx, args.language_idx, args.speaker_wav, reference_wav=args.reference_wav, reference_speaker_name=args.reference_speaker_idx)
wav = synthesizer.tts(
args.text,
args.speaker_idx,
args.language_idx,
args.speaker_wav,
reference_wav=args.reference_wav,
reference_speaker_name=args.reference_speaker_idx,
)
# save the results
print(" > Saving output to {}".format(args.out_path))

View File

@@ -9,6 +9,7 @@ import traceback
import torch
from torch.utils.data import DataLoader
from trainer.torch import NoamLR
from trainer.trainer_utils import get_optimizer
from TTS.encoder.dataset import EncoderDataset
from TTS.encoder.utils.generic_utils import save_best_model, save_checkpoint, setup_speaker_encoder_model
@@ -19,7 +20,6 @@ from TTS.tts.datasets import load_tts_samples
from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import count_parameters, remove_experiment_folder
from TTS.utils.io import copy_model_files
from trainer.trainer_utils import get_optimizer
from TTS.utils.training import check_update
torch.backends.cudnn.enabled = True
@@ -52,16 +52,21 @@ def setup_loader(ap: AudioProcessor, is_val: bool = False, verbose: bool = False
sampler = PerfectBatchSampler(
dataset.items,
classes,
batch_size=num_classes_in_batch*num_utter_per_class, # total batch size
batch_size=num_classes_in_batch * num_utter_per_class, # total batch size
num_classes_in_batch=num_classes_in_batch,
num_gpus=1,
shuffle=not is_val,
drop_last=True)
drop_last=True,
)
if len(classes) < num_classes_in_batch:
if is_val:
raise RuntimeError(f"config.eval_num_classes_in_batch ({num_classes_in_batch}) need to be <= {len(classes)} (Number total of Classes in the Eval dataset) !")
raise RuntimeError(f"config.num_classes_in_batch ({num_classes_in_batch}) need to be <= {len(classes)} (Number total of Classes in the Train dataset) !")
raise RuntimeError(
f"config.eval_num_classes_in_batch ({num_classes_in_batch}) need to be <= {len(classes)} (Number total of Classes in the Eval dataset) !"
)
raise RuntimeError(
f"config.num_classes_in_batch ({num_classes_in_batch}) need to be <= {len(classes)} (Number total of Classes in the Train dataset) !"
)
# set the classes to avoid getting the wrong class_id when the numbers of training and eval classes are not equal
if is_val:
@@ -76,6 +81,7 @@ def setup_loader(ap: AudioProcessor, is_val: bool = False, verbose: bool = False
return loader, classes, dataset.get_map_classid_to_classname()
def evaluation(model, criterion, data_loader, global_step):
eval_loss = 0
for _, data in enumerate(data_loader):
@@ -84,8 +90,12 @@ def evaluation(model, criterion, data_loader, global_step):
inputs, labels = data
# group samples of each class in the batch. the perfect sampler produces [3,2,1,3,2,1]; we need [3,3,2,2,1,1] (see the sketch after this file's diff)
labels = torch.transpose(labels.view(c.eval_num_utter_per_class, c.eval_num_classes_in_batch), 0, 1).reshape(labels.shape)
inputs = torch.transpose(inputs.view(c.eval_num_utter_per_class, c.eval_num_classes_in_batch, -1), 0, 1).reshape(inputs.shape)
labels = torch.transpose(
labels.view(c.eval_num_utter_per_class, c.eval_num_classes_in_batch), 0, 1
).reshape(labels.shape)
inputs = torch.transpose(
inputs.view(c.eval_num_utter_per_class, c.eval_num_classes_in_batch, -1), 0, 1
).reshape(inputs.shape)
# dispatch data to GPU
if use_cuda:
@@ -96,20 +106,23 @@ def evaluation(model, criterion, data_loader, global_step):
outputs = model(inputs)
# loss computation
loss = criterion(outputs.view(c.eval_num_classes_in_batch, outputs.shape[0] // c.eval_num_classes_in_batch, -1), labels)
loss = criterion(
outputs.view(c.eval_num_classes_in_batch, outputs.shape[0] // c.eval_num_classes_in_batch, -1), labels
)
eval_loss += loss.item()
eval_avg_loss = eval_loss/len(data_loader)
eval_avg_loss = eval_loss / len(data_loader)
# save stats
dashboard_logger.eval_stats(global_step, {"loss": eval_avg_loss})
# plot the last batch in the evaluation
figures = {
"UMAP Plot": plot_embeddings(outputs.detach().cpu().numpy(), c.num_classes_in_batch),
"UMAP Plot": plot_embeddings(outputs.detach().cpu().numpy(), c.num_classes_in_batch),
}
dashboard_logger.eval_figures(global_step, figures)
return eval_avg_loss
def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader, global_step):
model.train()
best_loss = float("inf")
@@ -124,8 +137,12 @@ def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader,
# setup input data
inputs, labels = data
# group samples of each class in the batch. the perfect sampler produces [3,2,1,3,2,1]; we need [3,3,2,2,1,1]
labels = torch.transpose(labels.view(c.num_utter_per_class, c.num_classes_in_batch), 0, 1).reshape(labels.shape)
inputs = torch.transpose(inputs.view(c.num_utter_per_class, c.num_classes_in_batch, -1), 0, 1).reshape(inputs.shape)
labels = torch.transpose(labels.view(c.num_utter_per_class, c.num_classes_in_batch), 0, 1).reshape(
labels.shape
)
inputs = torch.transpose(inputs.view(c.num_utter_per_class, c.num_classes_in_batch, -1), 0, 1).reshape(
inputs.shape
)
# ToDo: move it to a unit test
# labels_converted = torch.transpose(labels.view(c.num_utter_per_class, c.num_classes_in_batch), 0, 1).reshape(labels.shape)
# inputs_converted = torch.transpose(inputs.view(c.num_utter_per_class, c.num_classes_in_batch, -1), 0, 1).reshape(inputs.shape)
@@ -157,7 +174,9 @@ def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader,
outputs = model(inputs)
# loss computation
loss = criterion(outputs.view(c.num_classes_in_batch, outputs.shape[0] // c.num_classes_in_batch, -1), labels)
loss = criterion(
outputs.view(c.num_classes_in_batch, outputs.shape[0] // c.num_classes_in_batch, -1), labels
)
loss.backward()
grad_norm, _ = check_update(model, c.grad_clip)
optimizer.step()
@@ -211,7 +230,7 @@ def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader,
print(
">>> Epoch:{} AvgLoss: {:.5f} GradNorm:{:.5f} "
"EpochTime:{:.2f} AvGLoaderTime:{:.2f} ".format(
epoch, tot_loss/len(data_loader), grad_norm, epoch_time, avg_loader_time
epoch, tot_loss / len(data_loader), grad_norm, epoch_time, avg_loader_time
),
flush=True,
)
@@ -222,10 +241,8 @@ def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader,
print("\n\n")
print("--> EVAL PERFORMANCE")
print(
" | > Epoch:{} AvgLoss: {:.5f} ".format(
epoch, eval_loss
),
flush=True,
" | > Epoch:{} AvgLoss: {:.5f} ".format(epoch, eval_loss),
flush=True,
)
# save the best checkpoint
best_loss = save_best_model(model, optimizer, criterion, eval_loss, best_loss, OUT_PATH, global_step, epoch)
@@ -262,7 +279,9 @@ def main(args): # pylint: disable=redefined-outer-name
copy_model_files(c, OUT_PATH)
if args.restore_path:
criterion, args.restore_step = model.load_checkpoint(c, args.restore_path, eval=False, use_cuda=use_cuda, criterion=criterion)
criterion, args.restore_step = model.load_checkpoint(
c, args.restore_path, eval=False, use_cuda=use_cuda, criterion=criterion
)
print(" > Model restored from step %d" % args.restore_step, flush=True)
else:
args.restore_step = 0
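
For context on the regrouping that evaluation() and train() perform above, here is a minimal sketch (not part of this commit) of the view/transpose/reshape trick. The values of num_utter_per_class, num_classes_in_batch and the label tensor are made-up toy numbers chosen to match the comment; only torch itself is assumed.

import torch

num_utter_per_class = 2   # toy stand-in for c.num_utter_per_class
num_classes_in_batch = 3  # toy stand-in for c.num_classes_in_batch

# PerfectBatchSampler yields class-interleaved batches: one utterance per class, repeated
labels = torch.tensor([3, 2, 1, 3, 2, 1])

# view to [utterances, classes], transpose to [classes, utterances], flatten back
grouped = torch.transpose(
    labels.view(num_utter_per_class, num_classes_in_batch), 0, 1
).reshape(labels.shape)

print(grouped)  # tensor([3, 3, 2, 2, 1, 1]) -- same-class utterances are now adjacent

This class-grouped ordering is what the criterion call expects when it views the embeddings as [num_classes_in_batch, utterances_per_class, -1].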

View File

@@ -33,10 +33,7 @@ class BaseEncoderConfig(BaseTrainingConfig):
grad_clip: float = 3.0
lr: float = 0.0001
optimizer: str = "radam"
optimizer_params: Dict = field(default_factory=lambda: {
"betas": [0.9, 0.999],
"weight_decay": 0
})
optimizer_params: Dict = field(default_factory=lambda: {"betas": [0.9, 0.999], "weight_decay": 0})
lr_decay: bool = False
warmup_steps: int = 4000

View File

@@ -5,6 +5,7 @@ from torch.utils.data import Dataset
from TTS.encoder.utils.generic_utils import AugmentWAV
class EncoderDataset(Dataset):
def __init__(
self,
@@ -57,7 +58,6 @@ class EncoderDataset(Dataset):
print(f" | > Num Classes: {len(self.classes)}")
print(f" | > Classes: {self.classes}")
def load_wav(self, filename):
audio = self.ap.load_wav(filename, sr=self.ap.sample_rate)
return audio
@@ -75,9 +75,7 @@ class EncoderDataset(Dataset):
]
# skip classes with fewer than self.num_utter_per_class samples
class_to_utters = {
k: v for (k, v) in class_to_utters.items() if len(v) >= self.num_utter_per_class
}
class_to_utters = {k: v for (k, v) in class_to_utters.items() if len(v) >= self.num_utter_per_class}
classes = list(class_to_utters.keys())
classes.sort()
@@ -105,11 +103,11 @@ class EncoderDataset(Dataset):
def get_class_list(self):
return self.classes
def set_classes(self, classes):
self.classes = classes
self.classname_to_classid = {key: i for i, key in enumerate(self.classes)}
def get_map_classid_to_classname(self):
return dict((c_id, c_n) for c_n, c_id in self.classname_to_classid.items())

View File

@@ -195,6 +195,7 @@ class SoftmaxLoss(nn.Module):
class_id = torch.argmax(activations)
return class_id
class SoftmaxAngleProtoLoss(nn.Module):
"""
Implementation of the Softmax AnglePrototypical loss as defined in https://arxiv.org/abs/2009.14153

View File

@@ -1,12 +1,13 @@
import numpy as np
import torch
import torchaudio
import numpy as np
from coqpit import Coqpit
from torch import nn
from TTS.utils.io import load_fsspec
from TTS.encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss
from TTS.utils.generic_utils import set_init_dict
from coqpit import Coqpit
from TTS.utils.io import load_fsspec
class PreEmphasis(nn.Module):
def __init__(self, coefficient=0.97):
@@ -20,6 +21,7 @@ class PreEmphasis(nn.Module):
x = torch.nn.functional.pad(x.unsqueeze(1), (1, 0), "reflect")
return torch.nn.functional.conv1d(x, self.filter).squeeze(1)
class BaseEncoder(nn.Module):
"""Base `encoder` class. Every new `encoder` model must inherit this.
@@ -32,31 +34,31 @@ class BaseEncoder(nn.Module):
def get_torch_mel_spectrogram_class(self, audio_config):
return torch.nn.Sequential(
PreEmphasis(audio_config["preemphasis"]),
# TorchSTFT(
# n_fft=audio_config["fft_size"],
# hop_length=audio_config["hop_length"],
# win_length=audio_config["win_length"],
# sample_rate=audio_config["sample_rate"],
# window="hamming_window",
# mel_fmin=0.0,
# mel_fmax=None,
# use_htk=True,
# do_amp_to_db=False,
# n_mels=audio_config["num_mels"],
# power=2.0,
# use_mel=True,
# mel_norm=None,
# )
torchaudio.transforms.MelSpectrogram(
sample_rate=audio_config["sample_rate"],
n_fft=audio_config["fft_size"],
win_length=audio_config["win_length"],
hop_length=audio_config["hop_length"],
window_fn=torch.hamming_window,
n_mels=audio_config["num_mels"],
)
)
PreEmphasis(audio_config["preemphasis"]),
# TorchSTFT(
# n_fft=audio_config["fft_size"],
# hop_length=audio_config["hop_length"],
# win_length=audio_config["win_length"],
# sample_rate=audio_config["sample_rate"],
# window="hamming_window",
# mel_fmin=0.0,
# mel_fmax=None,
# use_htk=True,
# do_amp_to_db=False,
# n_mels=audio_config["num_mels"],
# power=2.0,
# use_mel=True,
# mel_norm=None,
# )
torchaudio.transforms.MelSpectrogram(
sample_rate=audio_config["sample_rate"],
n_fft=audio_config["fft_size"],
win_length=audio_config["win_length"],
hop_length=audio_config["hop_length"],
window_fn=torch.hamming_window,
n_mels=audio_config["num_mels"],
),
)
@torch.no_grad()
def inference(self, x, l2_norm=True):
@@ -104,7 +106,9 @@ class BaseEncoder(nn.Module):
raise Exception("The %s not is a loss supported" % c.loss)
return criterion
def load_checkpoint(self, config: Coqpit, checkpoint_path: str, eval: bool = False, use_cuda: bool = False, criterion=None):
def load_checkpoint(
self, config: Coqpit, checkpoint_path: str, eval: bool = False, use_cuda: bool = False, criterion=None
):
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
try:
self.load_state_dict(state["model"])
@@ -127,7 +131,12 @@ class BaseEncoder(nn.Module):
print(" > Criterion load ignored because of:", error)
# instantiate and load the criterion for the encoder classifier at inference time
if eval and criterion is None and "criterion" in state and getattr(config, 'map_classid_to_classname', None) is not None:
if (
eval
and criterion is None
and "criterion" in state
and getattr(config, "map_classid_to_classname", None) is not None
):
criterion = self.get_criterion(config, len(config.map_classid_to_classname))
criterion.load_state_dict(state["criterion"])

View File

@@ -4,6 +4,7 @@ from torch import nn
# from TTS.utils.audio import TorchSTFT
from TTS.encoder.models.base_encoder import BaseEncoder
class SELayer(nn.Module):
def __init__(self, channel, reduction=8):
super(SELayer, self).__init__()

View File

@@ -1,4 +1,5 @@
import random
from torch.utils.data.sampler import Sampler, SubsetRandomSampler
@@ -34,10 +35,21 @@ class PerfectBatchSampler(Sampler):
drop_last (bool): if True, drops last incomplete batch.
"""
def __init__(self, dataset_items, classes, batch_size, num_classes_in_batch, num_gpus=1, shuffle=True, drop_last=False, label_key="class_name"):
def __init__(
self,
dataset_items,
classes,
batch_size,
num_classes_in_batch,
num_gpus=1,
shuffle=True,
drop_last=False,
label_key="class_name",
):
super().__init__(dataset_items)
assert batch_size % (num_classes_in_batch * num_gpus) == 0, (
'Batch size must be divisible by number of classes times the number of data parallel devices (if enabled).')
assert (
batch_size % (num_classes_in_batch * num_gpus) == 0
), "Batch size must be divisible by number of classes times the number of data parallel devices (if enabled)."
label_indices = {}
for idx, item in enumerate(dataset_items):
@@ -93,7 +105,7 @@ class PerfectBatchSampler(Sampler):
if groups % self._dp_devices == 0:
yield batch
else:
batch = batch[:(groups // self._dp_devices) * self._dp_devices * self._num_classes_in_batch]
batch = batch[: (groups // self._dp_devices) * self._dp_devices * self._num_classes_in_batch]
if len(batch) > 0:
yield batch
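
A quick worked example (made-up numbers, not part of this commit) of the trailing-batch trim in the hunk above: when drop_last is False and the leftover class groups do not divide evenly across the data-parallel devices, the batch is cut down to whole groups per device.

num_classes_in_batch = 2
dp_devices = 2

batch = list(range(10))                       # 10 leftover samples -> 5 class groups of 2
groups = len(batch) // num_classes_in_batch   # 5 groups, not divisible by 2 devices
batch = batch[: (groups // dp_devices) * dp_devices * num_classes_in_batch]
print(len(batch))  # 8 -> 4 groups, i.e. 2 whole groups per device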

View File

@@ -7,15 +7,15 @@ import torch.distributed as dist
from coqpit import Coqpit
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data.sampler import WeightedRandomSampler
from trainer.torch import DistributedSampler, DistributedSamplerWrapper
from TTS.model import BaseTrainerModel
from TTS.tts.datasets.dataset import TTSDataset
from TTS.tts.utils.languages import LanguageManager, get_language_balancer_weights
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager, get_speaker_balancer_weights
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_balancer_weights, get_speaker_manager
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from torch.utils.data.sampler import WeightedRandomSampler
# pylint: skip-file
@@ -258,7 +258,7 @@ class BaseTTS(BaseTrainerModel):
# sampler for DDP
if sampler is None:
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
else: # If a sampler is already defined use this sampler and DDP sampler together
else: # If a sampler is already defined use this sampler and DDP sampler together
sampler = DistributedSamplerWrapper(sampler) if num_gpus > 1 else sampler
return sampler

View File

@@ -994,8 +994,11 @@ class Vits(BaseTTS):
outputs = {"model_outputs": o, "alignments": attn.squeeze(1), "z": z, "z_p": z_p, "m_p": m_p, "logs_p": logs_p}
return outputs
@torch.no_grad()
def inference_voice_conversion(self, reference_wav, speaker_id=None, d_vector=None, reference_speaker_id=None, reference_d_vector=None):
def inference_voice_conversion(
self, reference_wav, speaker_id=None, d_vector=None, reference_speaker_id=None, reference_d_vector=None
):
"""Inference for voice conversion
Args:
@@ -1006,7 +1009,13 @@ class Vits(BaseTTS):
reference_d_vector (Tensor): d_vector embedding of the reference_wav speaker. Tensor of shape `[B, C]`
"""
# compute spectrograms
y = wav_to_spec(reference_wav, self.config.audio.fft_size, self.config.audio.hop_length, self.config.audio.win_length, center=False).transpose(1, 2)
y = wav_to_spec(
reference_wav,
self.config.audio.fft_size,
self.config.audio.hop_length,
self.config.audio.win_length,
center=False,
).transpose(1, 2)
y_lengths = torch.tensor([y.size(-1)]).to(y.device)
speaker_cond_src = reference_speaker_id if reference_speaker_id is not None else reference_d_vector
speaker_cond_tgt = speaker_id if speaker_id is not None else d_vector

View File

@@ -269,7 +269,9 @@ class SpeakerManager:
"""
self.speaker_encoder_config = load_config(config_path)
self.speaker_encoder = setup_speaker_encoder_model(self.speaker_encoder_config)
self.speaker_encoder_criterion = self.speaker_encoder.load_checkpoint(self.speaker_encoder_config, model_path, eval=True, use_cuda=self.use_cuda)
self.speaker_encoder_criterion = self.speaker_encoder.load_checkpoint(
self.speaker_encoder_config, model_path, eval=True, use_cuda=self.use_cuda
)
self.speaker_encoder_ap = AudioProcessor(**self.speaker_encoder_config.audio)
def compute_d_vector_from_clip(self, wav_file: Union[str, List[str]]) -> list:

View File

@@ -206,6 +206,7 @@ def synthesis(
}
return return_dict
def transfer_voice(
model,
CONFIG,
@@ -269,12 +270,7 @@ def transfer_voice(
_func = model.module.inference_voice_conversion
else:
_func = model.inference_voice_conversion
model_outputs = _func(
reference_wav,
speaker_id,
d_vector,
reference_speaker_id,
reference_d_vector)
model_outputs = _func(reference_wav, speaker_id, d_vector, reference_speaker_id, reference_d_vector)
# convert outputs to numpy
# plot results

View File

@@ -119,7 +119,7 @@ class Synthesizer(object):
if use_cuda:
self.tts_model.cuda()
if self.encoder_checkpoint and hasattr(self.tts_model, "speaker_manager"):
if self.encoder_checkpoint and hasattr(self.tts_model, "speaker_manager"):
self.tts_model.speaker_manager.init_speaker_encoder(self.encoder_checkpoint, self.encoder_config)
def _set_speaker_encoder_paths_from_tts_config(self):
@@ -199,8 +199,8 @@ class Synthesizer(object):
if not text and not reference_wav:
raise ValueError(
"You need to define either `text` (for sythesis) or a `reference_wav` (for voice conversion) to use the Coqui TTS API."
)
"You need to define either `text` (for sythesis) or a `reference_wav` (for voice conversion) to use the Coqui TTS API."
)
if text:
sens = self.split_into_sentences(text)
@@ -214,7 +214,9 @@ class Synthesizer(object):
if speaker_name and isinstance(speaker_name, str):
if self.tts_config.use_d_vector_file:
# get the average speaker embedding from the saved d_vectors.
speaker_embedding = self.tts_model.speaker_manager.get_mean_d_vector(speaker_name, num_samples=None, randomize=False)
speaker_embedding = self.tts_model.speaker_manager.get_mean_d_vector(
speaker_name, num_samples=None, randomize=False
)
speaker_embedding = np.array(speaker_embedding)[None, :] # [1 x embedding_dim]
else:
# get speaker idx from the speaker name
@@ -315,25 +317,31 @@ class Synthesizer(object):
if reference_speaker_name and isinstance(reference_speaker_name, str):
if self.tts_config.use_d_vector_file:
# get the speaker embedding from the saved d_vectors.
reference_speaker_embedding = self.tts_model.speaker_manager.get_d_vectors_by_speaker(reference_speaker_name)[0]
reference_speaker_embedding = np.array(reference_speaker_embedding)[None, :] # [1 x embedding_dim]
reference_speaker_embedding = self.tts_model.speaker_manager.get_d_vectors_by_speaker(
reference_speaker_name
)[0]
reference_speaker_embedding = np.array(reference_speaker_embedding)[
None, :
] # [1 x embedding_dim]
else:
# get speaker idx from the speaker name
reference_speaker_id = self.tts_model.speaker_manager.speaker_ids[reference_speaker_name]
else:
reference_speaker_embedding = self.tts_model.speaker_manager.compute_d_vector_from_clip(reference_wav)
reference_speaker_embedding = self.tts_model.speaker_manager.compute_d_vector_from_clip(
reference_wav
)
outputs = transfer_voice(
model=self.tts_model,
CONFIG=self.tts_config,
use_cuda=self.use_cuda,
reference_wav=reference_wav,
speaker_id=speaker_id,
d_vector=speaker_embedding,
use_griffin_lim=use_gl,
reference_speaker_id=reference_speaker_id,
reference_d_vector=reference_speaker_embedding
)
model=self.tts_model,
CONFIG=self.tts_config,
use_cuda=self.use_cuda,
reference_wav=reference_wav,
speaker_id=speaker_id,
d_vector=speaker_embedding,
use_griffin_lim=use_gl,
reference_speaker_id=reference_speaker_id,
reference_d_vector=reference_speaker_embedding,
)
waveform = outputs
if not use_gl:
mel_postnet_spec = outputs[0].detach().cpu().numpy()

View File

@@ -41,11 +41,6 @@ model = GAN(config, ap)
# init the trainer and 🚀
trainer = Trainer(
TrainerArgs(),
config,
output_path,
model=model,
train_samples=train_samples,
eval_samples=eval_samples
TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
)
trainer.fit()

View File

@@ -41,11 +41,6 @@ model = GAN(config, ap)
# init the trainer and 🚀
trainer = Trainer(
TrainerArgs(),
config,
output_path,
model=model,
train_samples=train_samples,
eval_samples=eval_samples
TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
)
trainer.fit()

View File

@@ -84,11 +84,6 @@ model = Tacotron2(config, ap, tokenizer, speaker_manager=None)
# init the trainer and 🚀
trainer = Trainer(
TrainerArgs(),
config,
output_path,
model=model,
train_samples=train_samples,
eval_samples=eval_samples
TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
)
trainer.fit()

View File

@@ -40,11 +40,6 @@ model = GAN(config, ap)
# init the trainer and 🚀
trainer = Trainer(
TrainerArgs(),
config,
output_path,
model=model,
train_samples=train_samples,
eval_samples=eval_samples
TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
)
trainer.fit()

View File

@@ -6,12 +6,11 @@ from trainer import Trainer, TrainerArgs
from TTS.config.shared_configs import BaseAudioConfig
from TTS.tts.configs.shared_configs import BaseDatasetConfig
from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.models.vits import CharactersConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.models.vits import Vits, VitsArgs
from TTS.tts.models.vits import CharactersConfig, Vits, VitsArgs
from TTS.tts.utils.languages import LanguageManager
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.utils.audio import AudioProcessor
output_path = os.path.dirname(os.path.abspath(__file__))
@@ -131,11 +130,6 @@ model = Vits(config, ap, tokenizer, speaker_manager, language_manager)
# init the trainer and 🚀
trainer = Trainer(
TrainerArgs(),
config,
output_path,
model=model,
train_samples=train_samples,
eval_samples=eval_samples
TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
)
trainer.fit()

View File

@@ -1,14 +1,13 @@
import functools
import unittest
import torch
from TTS.config.shared_configs import BaseDatasetConfig
from TTS.encoder.utils.samplers import PerfectBatchSampler
from TTS.tts.datasets import load_tts_samples
from TTS.tts.utils.languages import get_language_balancer_weights
from TTS.tts.utils.speakers import get_speaker_balancer_weights
from TTS.encoder.utils.samplers import PerfectBatchSampler
# Fixing random state to avoid random fails
torch.manual_seed(0)
@@ -60,7 +59,9 @@ class TestSamplers(unittest.TestCase):
assert not is_balanced(en, pt), "Random sampler is supposed to be unbalanced"
def test_language_weighted_random_sampler(self): # pylint: disable=no-self-use
weighted_sampler = torch.utils.data.sampler.WeightedRandomSampler(get_language_balancer_weights(train_samples), len(train_samples))
weighted_sampler = torch.utils.data.sampler.WeightedRandomSampler(
get_language_balancer_weights(train_samples), len(train_samples)
)
ids = functools.reduce(lambda a, b: a + b, [list(weighted_sampler) for i in range(100)])
en, pt = 0, 0
for index in ids:
@@ -73,7 +74,9 @@ class TestSamplers(unittest.TestCase):
def test_speaker_weighted_random_sampler(self): # pylint: disable=no-self-use
weighted_sampler = torch.utils.data.sampler.WeightedRandomSampler(get_speaker_balancer_weights(train_samples), len(train_samples))
weighted_sampler = torch.utils.data.sampler.WeightedRandomSampler(
get_speaker_balancer_weights(train_samples), len(train_samples)
)
ids = functools.reduce(lambda a, b: a + b, [list(weighted_sampler) for i in range(100)])
spk1, spk2 = 0, 0
for index in ids:
@@ -92,11 +95,12 @@ class TestSamplers(unittest.TestCase):
sampler = PerfectBatchSampler(
train_samples,
classes,
batch_size=2 * 3, # total batch size
batch_size=2 * 3, # total batch size
num_classes_in_batch=2,
label_key="speaker_name",
shuffle=False,
drop_last=True)
drop_last=True,
)
batches = functools.reduce(lambda a, b: a + b, [list(sampler) for i in range(100)])
for batch in batches:
spk1, spk2 = 0, 0
@@ -116,11 +120,12 @@ class TestSamplers(unittest.TestCase):
sampler = PerfectBatchSampler(
train_samples,
classes,
batch_size=2 * 3, # total batch size
batch_size=2 * 3, # total batch size
num_classes_in_batch=2,
label_key="speaker_name",
shuffle=True,
drop_last=False)
drop_last=False,
)
batches = functools.reduce(lambda a, b: a + b, [list(sampler) for i in range(100)])
for batch in batches:
spk1, spk2 = 0, 0