Merge pull request #777 from coqui-ai/dev

v0.2.1
This commit is contained in:
Eren Gölge 2021-08-31 12:38:28 +02:00 committed by GitHub
commit 5793dca71d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
50 changed files with 496 additions and 349 deletions

3
.gitignore vendored
View File

@ -155,4 +155,5 @@ deps.json
speakers.json
internal/*
*_pitch.npy
*_phoneme.npy
*_phoneme.npy
wandb

View File

@ -64,6 +64,11 @@ disable=missing-docstring,
too-many-public-methods,
too-many-lines,
bare-except,
## for avoiding weird p3.6 CI linter error
## TODO: see later if we can remove this
assigning-non-slot,
unsupported-assignment-operation,
## end
line-too-long,
fixme,
wrong-import-order,
@ -73,6 +78,7 @@ disable=missing-docstring,
invalid-name,
too-many-instance-attributes,
arguments-differ,
arguments-renamed,
no-name-in-module,
no-member,
unsubscriptable-object,

View File

@ -102,7 +102,7 @@ You can also help us implement more models.
## Install TTS
🐸TTS is tested on Ubuntu 18.04 with **python >= 3.6, < 3.9**.
If you are only interested in [synthesizing speech](https://github.com/coqui-ai/TTS/tree/dev#example-synthesizing-speech-on-terminal-using-the-released-models) with the released 🐸TTS models, installing from PyPI is the easiest option.
If you are only interested in [synthesizing speech](https://tts.readthedocs.io/en/latest/inference.html) with the released 🐸TTS models, installing from PyPI is the easiest option.
```bash
pip install TTS

View File

@ -1 +1 @@
0.2.0
0.2.1

View File

@ -1,6 +1,6 @@
import os
with open(os.path.join(os.path.dirname(__file__), "VERSION")) as f:
with open(os.path.join(os.path.dirname(__file__), "VERSION"), "r", encoding="utf-8") as f:
version = f.read().strip()
__version__ = version

View File

@ -97,7 +97,7 @@ Example run:
enable_eos_bos=C.enable_eos_bos_chars,
)
dataset.sort_items()
dataset.sort_and_filter_items(C.get("sort_by_audio_len", default=False))
loader = DataLoader(
dataset,
batch_size=args.batch_size,
@ -158,7 +158,7 @@ Example run:
# ourput metafile
metafile = os.path.join(args.data_path, "metadata_attn_mask.txt")
with open(metafile, "w") as f:
with open(metafile, "w", encoding="utf-8") as f:
for p in file_paths:
f.write(f"{p[0]}|{p[1]}\n")
print(f" >> Metafile created: {metafile}")

View File

@ -32,6 +32,7 @@ def main():
command.append("--restore_path={}".format(args.restore_path))
command.append("--config_path={}".format(args.config_path))
command.append("--group_id=group_{}".format(group_id))
command.append("--use_ddp=true")
command += unargs
command.append("")
@ -42,7 +43,7 @@ def main():
my_env["PYTHON_EGG_CACHE"] = "/tmp/tmp{}".format(i)
command[-1] = "--rank={}".format(i)
# prevent stdout for processes with rank != 0
stdout = None if i == 0 else open(os.devnull, "w")
stdout = None
p = subprocess.Popen(["python3"] + command, stdout=stdout, env=my_env) # pylint: disable=consider-using-with
processes.append(p)
print(command)

View File

@ -46,7 +46,7 @@ def setup_loader(ap, r, verbose=False):
if c.use_phonemes and c.compute_input_seq_cache:
# precompute phonemes to have a better estimate of sequence lengths.
dataset.compute_input_seq(c.num_loader_workers)
dataset.sort_items()
dataset.sort_and_filter_items(c.get("sort_by_audio_len", default=False))
loader = DataLoader(
dataset,
@ -215,7 +215,7 @@ def extract_spectrograms(
wav = ap.inv_melspectrogram(mel)
ap.save_wav(wav, wav_gl_path)
with open(os.path.join(output_path, metada_name), "w") as f:
with open(os.path.join(output_path, metada_name), "w", encoding="utf-8") as f:
for data in export_metadata:
f.write(f"{data[0]}|{data[1]+'.npy'}\n")

View File

@ -190,7 +190,7 @@ class BaseTrainingConfig(Coqpit):
Name of the model that is used in the training.
run_name (str):
Name of the experiment. This prefixes the output folder name.
Name of the experiment. This prefixes the output folder name. Defaults to `coqui_tts`.
run_description (str):
Short description of the experiment.
@ -272,7 +272,7 @@ class BaseTrainingConfig(Coqpit):
"""
model: str = None
run_name: str = ""
run_name: str = "coqui_tts"
run_description: str = ""
# training params
epochs: int = 10000

View File

@ -23,35 +23,31 @@ class BaseModel(nn.Module, ABC):
"""
@abstractmethod
def forward(self, text: torch.Tensor, aux_input={}, **kwargs) -> Dict:
def forward(self, input: torch.Tensor, *args, aux_input={}, **kwargs) -> Dict:
"""Forward pass for the model mainly used in training.
You can be flexible here and use different number of arguments and argument names since it is mostly used by
`train_step()` in training whitout exposing it to the out of the class.
You can be flexible here and use different number of arguments and argument names since it is intended to be
used by `train_step()` without exposing it out of the model.
Args:
text (torch.Tensor): Input text character sequence ids.
input (torch.Tensor): Input tensor.
aux_input (Dict): Auxiliary model inputs like embeddings, durations or any other sorts of inputs.
for the model.
Returns:
Dict: model outputs. This must include an item keyed `model_outputs` as the final artifact of the model.
Dict: Model outputs. Main model output must be named as "model_outputs".
"""
outputs_dict = {"model_outputs": None}
...
return outputs_dict
@abstractmethod
def inference(self, text: torch.Tensor, aux_input={}) -> Dict:
def inference(self, input: torch.Tensor, aux_input={}) -> Dict:
"""Forward pass for inference.
After the model is trained this is the only function that connects the model the out world.
This function must only take a `text` input and a dictionary that has all the other model specific inputs.
We don't use `*kwargs` since it is problematic with the TorchScript API.
Args:
text (torch.Tensor): [description]
input (torch.Tensor): [description]
aux_input (Dict): Auxiliary inputs like speaker embeddings, durations etc.
Returns:

View File

@ -1,6 +1,6 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import nn
# adapted from https://github.com/cvqluu/GE2E-Loss

View File

@ -1,6 +1,6 @@
import numpy as np
import torch
import torch.nn as nn
from torch import nn
from TTS.utils.io import load_fsspec

View File

@ -1,5 +1,5 @@
from dataclasses import asdict, dataclass, field
from typing import List
from typing import Dict, List
from coqpit import MISSING
@ -14,7 +14,7 @@ class SpeakerEncoderConfig(BaseTrainingConfig):
audio: BaseAudioConfig = field(default_factory=BaseAudioConfig)
datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()])
# model params
model_params: dict = field(
model_params: Dict = field(
default_factory=lambda: {
"model_name": "lstm",
"input_dim": 80,
@ -25,9 +25,9 @@ class SpeakerEncoderConfig(BaseTrainingConfig):
}
)
audio_augmentation: dict = field(default_factory=lambda: {})
audio_augmentation: Dict = field(default_factory=lambda: {})
storage: dict = field(
storage: Dict = field(
default_factory=lambda: {
"sample_from_storage_p": 0.66, # the probability with which we'll sample from the DataSet in-memory storage
"storage_size": 15, # the size of the in-memory storage with respect to a single batch

View File

@ -94,7 +94,8 @@ def download_and_extract(directory, subset, urls):
extract_path = zip_filepath.strip(".zip")
# check zip file md5sum
md5 = hashlib.md5(open(zip_filepath, "rb").read()).hexdigest()
with open(zip_filepath, "rb") as f_zip:
md5 = hashlib.md5(f_zip.read()).hexdigest()
if md5 != MD5SUM[subset]:
raise ValueError("md5sum of %s mismatch" % zip_filepath)

View File

@ -1,7 +1,6 @@
# -*- coding: utf-8 -*-
import importlib
import logging
import multiprocessing
import os
import platform
@ -16,6 +15,7 @@ from urllib.parse import urlparse
import fsspec
import torch
import torch.distributed as dist
from coqpit import Coqpit
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP_th
@ -38,7 +38,7 @@ from TTS.utils.generic_utils import (
to_cuda,
)
from TTS.utils.io import copy_model_files, load_fsspec, save_best_model, save_checkpoint
from TTS.utils.logging import ConsoleLogger, TensorboardLogger, WandbLogger, init_logger
from TTS.utils.logging import ConsoleLogger, TensorboardLogger, WandbLogger, init_dashboard_logger
from TTS.utils.trainer_utils import get_optimizer, get_scheduler, is_apex_available, setup_torch_training_env
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
from TTS.vocoder.models import setup_model as setup_vocoder_model
@ -77,12 +77,16 @@ class TrainingArgs(Coqpit):
best_path: str = field(
default="",
metadata={
"help": "Best model file to be used for extracting best loss. If not specified, the latest best model in continue path is used"
"help": "Best model file to be used for extracting the best loss. If not specified, the latest best model in continue path is used"
},
)
config_path: str = field(default="", metadata={"help": "Path to the configuration file."})
rank: int = field(default=0, metadata={"help": "Process rank in distributed training."})
group_id: str = field(default="", metadata={"help": "Process group id in distributed training."})
use_ddp: bool = field(
default=False,
metadata={"help": "Use DDP in distributed training. It is to set in `distribute.py`. Do not set manually."},
)
class Trainer:
@ -144,6 +148,7 @@ class Trainer:
>>> trainer.fit()
TODO:
- Wrap model for not calling .module in DDP.
- Accumulate gradients b/w batches.
- Deepspeed integration
- Profiler integration.
@ -151,29 +156,33 @@ class Trainer:
- TPU training
"""
# set and initialize Pytorch runtime
self.use_cuda, self.num_gpus = setup_torch_training_env(True, cudnn_benchmark)
if config is None:
# parse config from console arguments
config, output_path, _, c_logger, dashboard_logger = process_args(args)
self.output_path = output_path
self.args = args
self.config = config
self.output_path = output_path
self.config.output_log_path = output_path
# setup logging
log_file = os.path.join(self.output_path, f"trainer_{args.rank}_log.txt")
self._setup_logger_config(log_file)
# set and initialize Pytorch runtime
self.use_cuda, self.num_gpus = setup_torch_training_env(True, cudnn_benchmark, args.use_ddp)
# init loggers
self.c_logger = ConsoleLogger() if c_logger is None else c_logger
self.dashboard_logger = dashboard_logger
if self.dashboard_logger is None:
self.dashboard_logger = init_logger(config)
# only allow dashboard logging for the main process in DDP mode
if self.dashboard_logger is None and args.rank == 0:
self.dashboard_logger = init_dashboard_logger(config)
if not self.config.log_model_step:
self.config.log_model_step = self.config.save_step
log_file = os.path.join(self.output_path, f"trainer_{args.rank}_log.txt")
self._setup_logger_config(log_file)
self.total_steps_done = 0
self.epochs_done = 0
self.restore_step = 0
@ -247,10 +256,10 @@ class Trainer:
if self.use_apex:
self.scaler = None
self.model, self.optimizer = amp.initialize(self.model, self.optimizer, opt_level="O1")
if isinstance(self.optimizer, list):
self.scaler = [torch.cuda.amp.GradScaler()] * len(self.optimizer)
else:
self.scaler = torch.cuda.amp.GradScaler()
# if isinstance(self.optimizer, list):
# self.scaler = [torch.cuda.amp.GradScaler()] * len(self.optimizer)
# else:
self.scaler = torch.cuda.amp.GradScaler()
else:
self.scaler = None
@ -319,14 +328,14 @@ class Trainer:
return obj
print(" > Restoring from %s ..." % os.path.basename(restore_path))
checkpoint = load_fsspec(restore_path)
checkpoint = load_fsspec(restore_path, map_location="cpu")
try:
print(" > Restoring Model...")
model.load_state_dict(checkpoint["model"])
print(" > Restoring Optimizer...")
optimizer = _restore_list_objs(checkpoint["optimizer"], optimizer)
if "scaler" in checkpoint and self.use_amp_scaler and checkpoint["scaler"]:
print(" > Restoring AMP Scaler...")
print(" > Restoring Scaler...")
scaler = _restore_list_objs(checkpoint["scaler"], scaler)
except (KeyError, RuntimeError):
print(" > Partial model initialization...")
@ -346,10 +355,11 @@ class Trainer:
" > Model restored from step %d" % checkpoint["step"],
)
restore_step = checkpoint["step"]
torch.cuda.empty_cache()
return model, optimizer, scaler, restore_step
@staticmethod
def _get_loader(
self,
model: nn.Module,
config: Coqpit,
ap: AudioProcessor,
@ -358,8 +368,14 @@ class Trainer:
verbose: bool,
num_gpus: int,
) -> DataLoader:
if hasattr(model, "get_data_loader"):
loader = model.get_data_loader(config, ap, is_eval, data_items, verbose, num_gpus)
if num_gpus > 1:
if hasattr(model.module, "get_data_loader"):
loader = model.module.get_data_loader(
config, ap, is_eval, data_items, verbose, num_gpus, self.args.rank
)
else:
if hasattr(model, "get_data_loader"):
loader = model.get_data_loader(config, ap, is_eval, data_items, verbose, num_gpus)
return loader
def get_train_dataloader(self, ap: AudioProcessor, data_items: List, verbose: bool) -> DataLoader:
@ -387,12 +403,28 @@ class Trainer:
Returns:
Dict: Formatted batch.
"""
batch = self.model.format_batch(batch)
if self.num_gpus > 1:
batch = self.model.module.format_batch(batch)
else:
batch = self.model.format_batch(batch)
if self.use_cuda:
for k, v in batch.items():
batch[k] = to_cuda(v)
return batch
@staticmethod
def master_params(optimizer: torch.optim.Optimizer):
"""Generator over parameters owned by the optimizer.
Used to select parameters used by the optimizer for gradient clipping.
Args:
optimizer: Target optimizer.
"""
for group in optimizer.param_groups:
for p in group["params"]:
yield p
@staticmethod
def _model_train_step(
batch: Dict, model: nn.Module, criterion: nn.Module, optimizer_idx: int = None
@ -450,6 +482,8 @@ class Trainer:
step_start_time = time.time()
# zero-out optimizer
optimizer.zero_grad()
# forward pass and loss computation
with torch.cuda.amp.autocast(enabled=config.mixed_precision):
if optimizer_idx is not None:
outputs, loss_dict = self._model_train_step(batch, model, criterion, optimizer_idx=optimizer_idx)
@ -461,9 +495,9 @@ class Trainer:
step_time = time.time() - step_start_time
return None, {}, step_time
# check nan loss
if torch.isnan(loss_dict["loss"]).any():
raise RuntimeError(f" > Detected NaN loss - {loss_dict}.")
# # check nan loss
# if torch.isnan(loss_dict["loss"]).any():
# raise RuntimeError(f" > NaN loss detected - {loss_dict}")
# set gradient clipping threshold
if "grad_clip" in config and config.grad_clip is not None:
@ -481,6 +515,8 @@ class Trainer:
update_lr_scheduler = True
if self.use_amp_scaler:
if self.use_apex:
# TODO: verify AMP use for GAN training in TTS
# https://nvidia.github.io/apex/advanced.html?highlight=accumulate#backward-passes-with-multiple-optimizers
with amp.scale_loss(loss_dict["loss"], optimizer) as scaled_loss:
scaled_loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(
@ -491,7 +527,9 @@ class Trainer:
scaler.scale(loss_dict["loss"]).backward()
if grad_clip > 0:
scaler.unscale_(optimizer)
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip, error_if_nonfinite=False)
grad_norm = torch.nn.utils.clip_grad_norm_(
self.master_params(optimizer), grad_clip, error_if_nonfinite=False
)
# pytorch skips the step when the norm is 0. So ignore the norm value when it is NaN
if torch.isnan(grad_norm) or torch.isinf(grad_norm):
grad_norm = 0
@ -571,7 +609,8 @@ class Trainer:
total_step_time = 0
for idx, optimizer in enumerate(self.optimizer):
criterion = self.criterion
scaler = self.scaler[idx] if self.use_amp_scaler else None
# scaler = self.scaler[idx] if self.use_amp_scaler else None
scaler = self.scaler
scheduler = self.scheduler[idx]
outputs, loss_dict_new, step_time = self._optimize(
batch, self.model, optimizer, scaler, criterion, scheduler, self.config, idx
@ -592,13 +631,13 @@ class Trainer:
outputs = outputs_per_optimizer
# update avg runtime stats
keep_avg_update = dict()
keep_avg_update = {}
keep_avg_update["avg_loader_time"] = loader_time
keep_avg_update["avg_step_time"] = step_time
self.keep_avg_train.update_values(keep_avg_update)
# update avg loss stats
update_eval_values = dict()
update_eval_values = {}
for key, value in loss_dict.items():
update_eval_values["avg_" + key] = value
self.keep_avg_train.update_values(update_eval_values)
@ -662,9 +701,10 @@ class Trainer:
if audios is not None:
self.dashboard_logger.train_audios(self.total_steps_done, audios, self.ap.sample_rate)
self.dashboard_logger.flush()
self.total_steps_done += 1
self.callbacks.on_train_step_end()
self.dashboard_logger.flush()
return outputs, loss_dict
def train_epoch(self) -> None:
@ -674,16 +714,20 @@ class Trainer:
self.data_train,
verbose=True,
)
self.model.train()
if self.num_gpus > 1:
self.model.module.train()
else:
self.model.train()
epoch_start_time = time.time()
if self.use_cuda:
batch_num_steps = int(len(self.train_loader.dataset) / (self.config.batch_size * self.num_gpus))
else:
batch_num_steps = int(len(self.train_loader.dataset) / self.config.batch_size)
self.c_logger.print_train_start()
loader_start_time = time.time()
for cur_step, batch in enumerate(self.train_loader):
loader_start_time = time.time()
_, _ = self.train_step(batch, batch_num_steps, cur_step, loader_start_time)
loader_start_time = time.time()
epoch_time = time.time() - epoch_start_time
# Plot self.epochs_done Stats
if self.args.rank == 0:
@ -753,7 +797,7 @@ class Trainer:
loss_dict = self._detach_loss_dict(loss_dict)
# update avg stats
update_eval_values = dict()
update_eval_values = {}
for key, value in loss_dict.items():
update_eval_values["avg_" + key] = value
self.keep_avg_eval.update_values(update_eval_values)
@ -784,6 +828,7 @@ class Trainer:
loader_time = time.time() - loader_start_time
self.keep_avg_eval.update_values({"avg_loader_time": loader_time})
outputs, _ = self.eval_step(batch, cur_step)
loader_start_time = time.time()
# plot epoch stats, artifacts and figures
if self.args.rank == 0:
figures, audios = None, None
@ -800,7 +845,7 @@ class Trainer:
def test_run(self) -> None:
"""Run test and log the results. Test run must be defined by the model.
Model must return figures and audios to be logged by the Tensorboard."""
if hasattr(self.model, "test_run"):
if hasattr(self.model, "test_run") or (self.num_gpus > 1 and hasattr(self.model.module, "test_run")):
if self.eval_loader is None:
self.eval_loader = self.get_eval_dataloader(
self.ap,
@ -810,27 +855,43 @@ class Trainer:
if hasattr(self.eval_loader.dataset, "load_test_samples"):
samples = self.eval_loader.dataset.load_test_samples(1)
figures, audios = self.model.test_run(self.ap, samples, None)
if self.num_gpus > 1:
figures, audios = self.model.module.test_run(self.ap, samples, None)
else:
figures, audios = self.model.test_run(self.ap, samples, None)
else:
figures, audios = self.model.test_run(self.ap)
if self.num_gpus > 1:
figures, audios = self.model.module.test_run(self.ap)
else:
figures, audios = self.model.test_run(self.ap)
self.dashboard_logger.test_audios(self.total_steps_done, audios, self.config.audio["sample_rate"])
self.dashboard_logger.test_figures(self.total_steps_done, figures)
def _fit(self) -> None:
"""🏃 train -> evaluate -> test for the number of epochs."""
def _restore_best_loss(self):
"""Restore the best loss from the args.best_path if provided else
from the model (`args.restore_path` or `args.continue_path`) used for resuming the training"""
if self.restore_step != 0 or self.args.best_path:
print(" > Restoring best loss from " f"{os.path.basename(self.args.best_path)} ...")
self.best_loss = load_fsspec(self.args.restore_path, map_location="cpu")["model_loss"]
ch = load_fsspec(self.args.restore_path, map_location="cpu")
if "model_loss" in ch:
self.best_loss = ch["model_loss"]
print(f" > Starting with loaded last best loss {self.best_loss}.")
def _fit(self) -> None:
"""🏃 train -> evaluate -> test for the number of epochs."""
self._restore_best_loss()
self.total_steps_done = self.restore_step
for epoch in range(0, self.config.epochs):
if self.num_gpus > 1:
# let all processes sync up before starting with a new epoch of training
dist.barrier()
self.callbacks.on_epoch_start()
self.keep_avg_train = KeepAverage()
self.keep_avg_eval = KeepAverage() if self.config.run_eval else None
self.epochs_done = epoch
self.c_logger.print_epoch_start(epoch, self.config.epochs)
self.c_logger.print_epoch_start(epoch, self.config.epochs, self.output_path)
self.train_epoch()
if self.config.run_eval:
self.eval_epoch()
@ -839,20 +900,26 @@ class Trainer:
self.c_logger.print_epoch_end(
epoch, self.keep_avg_eval.avg_values if self.config.run_eval else self.keep_avg_train.avg_values
)
self.save_best_model()
if self.args.rank in [None, 0]:
self.save_best_model()
self.callbacks.on_epoch_end()
def fit(self) -> None:
"""Where the ✨magic✨ happens..."""
try:
self._fit()
self.dashboard_logger.finish()
if self.args.rank == 0:
self.dashboard_logger.finish()
except KeyboardInterrupt:
self.callbacks.on_keyboard_interrupt()
# if the output folder is empty remove the run.
remove_experiment_folder(self.output_path)
# clear the DDP processes
if self.num_gpus > 1:
dist.destroy_process_group()
# finish the wandb run and sync data
self.dashboard_logger.finish()
if self.args.rank == 0:
self.dashboard_logger.finish()
# stop without error signal
try:
sys.exit(0)
@ -902,18 +969,30 @@ class Trainer:
keep_after=self.config.keep_after,
)
@staticmethod
def _setup_logger_config(log_file: str) -> None:
handlers = [logging.StreamHandler()]
def _setup_logger_config(self, log_file: str) -> None:
"""Write log strings to a file and print logs to the terminal.
TODO: Causes formatting issues in pdb debugging."""
# Only add a log file if the output location is local due to poor
# support for writing logs to file-like objects.
parsed_url = urlparse(log_file)
if not parsed_url.scheme or parsed_url.scheme == "file":
schemeless_path = os.path.join(parsed_url.netloc, parsed_url.path)
handlers.append(logging.FileHandler(schemeless_path))
class Logger(object):
def __init__(self, print_to_terminal=True):
self.print_to_terminal = print_to_terminal
self.terminal = sys.stdout
self.log_file = log_file
logging.basicConfig(level=logging.INFO, format="", handlers=handlers)
def write(self, message):
if self.print_to_terminal:
self.terminal.write(message)
with open(self.log_file, "a", encoding="utf-8") as f:
f.write(message)
def flush(self):
# this flush method is needed for python 3 compatibility.
# this handles the flush command by doing nothing.
# you might want to specify some extra behavior here.
pass
# don't let processes rank > 0 write to the terminal
sys.stdout = Logger(self.args.rank == 0)
@staticmethod
def _is_apex_available() -> bool:
@ -1109,8 +1188,6 @@ def process_args(args, config=None):
config = register_config(config_base.model)()
# override values from command-line args
config.parse_known_args(coqpit_overrides, relaxed_parser=True)
if config.mixed_precision:
print(" > Mixed precision mode is ON")
experiment_path = args.continue_path
if not experiment_path:
experiment_path = get_experiment_folder_path(config.output_path, config.run_name)
@ -1130,8 +1207,7 @@ def process_args(args, config=None):
used_characters = parse_symbols()
new_fields["characters"] = used_characters
copy_model_files(config, experiment_path, new_fields)
dashboard_logger = init_logger(config)
dashboard_logger = init_dashboard_logger(config)
c_logger = ConsoleLogger()
return config, experiment_path, audio_path, c_logger, dashboard_logger

View File

@ -96,7 +96,7 @@ class VitsConfig(BaseTTSConfig):
model_args: VitsArgs = field(default_factory=VitsArgs)
# optimizer
grad_clip: List[float] = field(default_factory=lambda: [5, 5])
grad_clip: List[float] = field(default_factory=lambda: [1000, 1000])
lr_gen: float = 0.0002
lr_disc: float = 0.0002
lr_scheduler_gen: str = "ExponentialLR"
@ -113,14 +113,16 @@ class VitsConfig(BaseTTSConfig):
gen_loss_alpha: float = 1.0
feat_loss_alpha: float = 1.0
mel_loss_alpha: float = 45.0
dur_loss_alpha: float = 1.0
# data loader params
return_wav: bool = True
compute_linear_spec: bool = True
# overrides
min_seq_len: int = 13
max_seq_len: int = 500
sort_by_audio_len: bool = True
min_seq_len: int = 0
max_seq_len: int = 500000
r: int = 1 # DO NOT CHANGE
add_blank: bool = True

View File

@ -69,7 +69,7 @@ class TTSDataset(Dataset):
batch. Set 0 to disable. Defaults to 0.
min_seq_len (int): Minimum input sequence length to be processed
by the loader. Filter out input sequences that are shorter than this. Some models have a
by sort_inputs`. Filter out input sequences that are shorter than this. Some models have a
minimum input length due to its architecture. Defaults to 0.
max_seq_len (int): Maximum input sequence length. Filter out input sequences that are longer than this.
@ -302,10 +302,23 @@ class TTSDataset(Dataset):
for idx, p in enumerate(phonemes):
self.items[idx][0] = p
def sort_items(self):
r"""Sort instances based on text length in ascending order"""
lengths = np.array([len(ins[0]) for ins in self.items])
def sort_and_filter_items(self, by_audio_len=False):
r"""Sort `items` based on text length or audio length in ascending order. Filter out samples out or the length
range.
Args:
by_audio_len (bool): if True, sort by audio length else by text length.
"""
# compute the target sequence length
if by_audio_len:
lengths = []
for item in self.items:
lengths.append(os.path.getsize(item[1]))
lengths = np.array(lengths)
else:
lengths = np.array([len(ins[0]) for ins in self.items])
# sort items based on the sequence length in ascending order
idxs = np.argsort(lengths)
new_items = []
ignored = []
@ -315,7 +328,10 @@ class TTSDataset(Dataset):
ignored.append(idx)
else:
new_items.append(self.items[idx])
# shuffle batch groups
# create batches with similar length items
# the larger the `batch_group_size`, the higher the length variety in a batch.
if self.batch_group_size > 0:
for i in range(len(new_items) // self.batch_group_size):
offset = i * self.batch_group_size
@ -325,6 +341,7 @@ class TTSDataset(Dataset):
new_items[offset:end_offset] = temp_items
self.items = new_items
# logging
if self.verbose:
print(" | > Max length sequence: {}".format(np.max(lengths)))
print(" | > Min length sequence: {}".format(np.min(lengths)))

View File

@ -66,7 +66,7 @@ def load_meta_data(datasets, eval_split=True):
def load_attention_mask_meta_data(metafile_path):
"""Load meta data file created by compute_attention_masks.py"""
with open(metafile_path, "r") as f:
with open(metafile_path, "r", encoding="utf-8") as f:
lines = f.readlines()
meta_data = []

View File

@ -19,7 +19,7 @@ def tweb(root_path, meta_file):
txt_file = os.path.join(root_path, meta_file)
items = []
speaker_name = "tweb"
with open(txt_file, "r") as ttf:
with open(txt_file, "r", encoding="utf-8") as ttf:
for line in ttf:
cols = line.split("\t")
wav_file = os.path.join(root_path, cols[0] + ".wav")
@ -33,7 +33,7 @@ def mozilla(root_path, meta_file):
txt_file = os.path.join(root_path, meta_file)
items = []
speaker_name = "mozilla"
with open(txt_file, "r") as ttf:
with open(txt_file, "r", encoding="utf-8") as ttf:
for line in ttf:
cols = line.split("|")
wav_file = cols[1].strip()
@ -77,7 +77,7 @@ def mailabs(root_path, meta_files=None):
continue
speaker_name = speaker_name_match.group("speaker_name")
print(" | > {}".format(csv_file))
with open(txt_file, "r") as ttf:
with open(txt_file, "r", encoding="utf-8") as ttf:
for line in ttf:
cols = line.split("|")
if meta_files is None:
@ -102,7 +102,7 @@ def ljspeech(root_path, meta_file):
for line in ttf:
cols = line.split("|")
wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
text = cols[1]
text = cols[2]
items.append([text, wav_file, speaker_name])
return items
@ -116,7 +116,7 @@ def ljspeech_test(root_path, meta_file):
for idx, line in enumerate(ttf):
cols = line.split("|")
wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
text = cols[1]
text = cols[2]
items.append([text, wav_file, f"ljspeech-{idx}"])
return items
@ -158,7 +158,7 @@ def css10(root_path, meta_file):
txt_file = os.path.join(root_path, meta_file)
items = []
speaker_name = "ljspeech"
with open(txt_file, "r") as ttf:
with open(txt_file, "r", encoding="utf-8") as ttf:
for line in ttf:
cols = line.split("|")
wav_file = os.path.join(root_path, cols[0])
@ -172,7 +172,7 @@ def nancy(root_path, meta_file):
txt_file = os.path.join(root_path, meta_file)
items = []
speaker_name = "nancy"
with open(txt_file, "r") as ttf:
with open(txt_file, "r", encoding="utf-8") as ttf:
for line in ttf:
utt_id = line.split()[1]
text = line[line.find('"') + 1 : line.rfind('"') - 1]
@ -185,7 +185,7 @@ def common_voice(root_path, meta_file):
"""Normalize the common voice meta data file to TTS format."""
txt_file = os.path.join(root_path, meta_file)
items = []
with open(txt_file, "r") as ttf:
with open(txt_file, "r", encoding="utf-8") as ttf:
for line in ttf:
if line.startswith("client_id"):
continue
@ -208,7 +208,7 @@ def libri_tts(root_path, meta_files=None):
for meta_file in meta_files:
_meta_file = os.path.basename(meta_file).split(".")[0]
with open(meta_file, "r") as ttf:
with open(meta_file, "r", encoding="utf-8") as ttf:
for line in ttf:
cols = line.split("\t")
file_name = cols[0]
@ -245,7 +245,7 @@ def brspeech(root_path, meta_file):
"""BRSpeech 3.0 beta"""
txt_file = os.path.join(root_path, meta_file)
items = []
with open(txt_file, "r") as ttf:
with open(txt_file, "r", encoding="utf-8") as ttf:
for line in ttf:
if line.startswith("wav_filename"):
continue
@ -268,7 +268,7 @@ def vctk(root_path, meta_files=None, wavs_path="wav48"):
if isinstance(test_speakers, list): # if is list ignore this speakers ids
if speaker_id in test_speakers:
continue
with open(meta_file) as file_text:
with open(meta_file, "r", encoding="utf-8") as file_text:
text = file_text.readlines()[0]
wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + ".wav")
items.append([text, wav_file, "VCTK_" + speaker_id])
@ -295,7 +295,7 @@ def vctk_slim(root_path, meta_files=None, wavs_path="wav48"):
def mls(root_path, meta_files=None):
"""http://www.openslr.org/94/"""
items = []
with open(os.path.join(root_path, meta_files), "r") as meta:
with open(os.path.join(root_path, meta_files), "r", encoding="utf-8") as meta:
for line in meta:
file, text = line.split("\t")
text = text[:-1]
@ -329,7 +329,7 @@ def _voxcel_x(root_path, meta_file, voxcel_idx):
# if not exists meta file, crawl recursively for 'wav' files
if meta_file is not None:
with open(str(meta_file), "r") as f:
with open(str(meta_file), "r", encoding="utf-8") as f:
return [x.strip().split("|") for x in f.readlines()]
elif not cache_to.exists():
@ -346,12 +346,12 @@ def _voxcel_x(root_path, meta_file, voxcel_idx):
text = None # VoxCel does not provide transciptions, and they are not needed for training the SE
meta_data.append(f"{text}|{path}|voxcel{voxcel_idx}_{speaker_id}\n")
cnt += 1
with open(str(cache_to), "w") as f:
with open(str(cache_to), "w", encoding="utf-8") as f:
f.write("".join(meta_data))
if cnt < expected_count:
raise ValueError(f"Found too few instances for Voxceleb. Should be around {expected_count}, is: {cnt}")
with open(str(cache_to), "r") as f:
with open(str(cache_to), "r", encoding="utf-8") as f:
return [x.strip().split("|") for x in f.readlines()]
@ -367,7 +367,7 @@ def baker(root_path: str, meta_file: str) -> List[List[str]]:
txt_file = os.path.join(root_path, meta_file)
items = []
speaker_name = "baker"
with open(txt_file, "r") as ttf:
with open(txt_file, "r", encoding="utf-8") as ttf:
for line in ttf:
wav_name, text = line.rstrip("\n").split("|")
wav_path = os.path.join(root_path, "clips_22", wav_name)
@ -380,7 +380,7 @@ def kokoro(root_path, meta_file):
txt_file = os.path.join(root_path, meta_file)
items = []
speaker_name = "kokoro"
with open(txt_file, "r") as ttf:
with open(txt_file, "r", encoding="utf-8") as ttf:
for line in ttf:
cols = line.split("|")
wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")

View File

@ -1,6 +1,6 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import nn
class FFTransformer(nn.Module):

View File

@ -524,6 +524,7 @@ class VitsGeneratorLoss(nn.Module):
self.kl_loss_alpha = c.kl_loss_alpha
self.gen_loss_alpha = c.gen_loss_alpha
self.feat_loss_alpha = c.feat_loss_alpha
self.dur_loss_alpha = c.dur_loss_alpha
self.mel_loss_alpha = c.mel_loss_alpha
self.stft = TorchSTFT(
c.audio.fft_size,
@ -590,10 +591,11 @@ class VitsGeneratorLoss(nn.Module):
scores_disc_fake,
feats_disc_fake,
feats_disc_real,
loss_duration,
):
"""
Shapes:
- wavefrom: :math:`[B, 1, T]`
- waveform : :math:`[B, 1, T]`
- waveform_hat: :math:`[B, 1, T]`
- z_p: :math:`[B, C, T]`
- logs_q: :math:`[B, C, T]`
@ -615,12 +617,14 @@ class VitsGeneratorLoss(nn.Module):
loss_gen = self.generator_loss(scores_disc_fake)[0] * self.gen_loss_alpha
loss_kl = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask.unsqueeze(1)) * self.kl_loss_alpha
loss_mel = torch.nn.functional.l1_loss(mel, mel_hat) * self.mel_loss_alpha
loss = loss_kl + loss_feat + loss_mel + loss_gen
loss_duration = torch.sum(loss_duration.float()) * self.dur_loss_alpha
loss = loss_kl + loss_feat + loss_mel + loss_gen + loss_duration
# pass losses to the dict
return_dict["loss_gen"] = loss_gen
return_dict["loss_kl"] = loss_kl
return_dict["loss_feat"] = loss_feat
return_dict["loss_mel"] = loss_mel
return_dict["loss_duration"] = loss_duration
return_dict["loss"] = loss
return return_dict
@ -651,7 +655,6 @@ class VitsDiscriminatorLoss(nn.Module):
return_dict = {}
loss_disc, _, _ = self.discriminator_loss(scores_disc_real, scores_disc_fake)
return_dict["loss_disc"] = loss_disc * self.disc_loss_alpha
loss = loss + loss_disc
return_dict["loss_disc"] = loss_disc
loss = loss + return_dict["loss_disc"]
return_dict["loss"] = loss
return return_dict

View File

@ -1,6 +1,6 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import nn
class GST(nn.Module):

View File

@ -388,8 +388,8 @@ class Decoder(nn.Module):
decoder_input = self.project_to_decoder_in(torch.cat((self.attention_rnn_hidden, self.context_vec), -1))
# Pass through the decoder RNNs
for idx in range(len(self.decoder_rnns)):
self.decoder_rnn_hiddens[idx] = self.decoder_rnns[idx](decoder_input, self.decoder_rnn_hiddens[idx])
for idx, decoder_rnn in enumerate(self.decoder_rnns):
self.decoder_rnn_hiddens[idx] = decoder_rnn(decoder_input, self.decoder_rnn_hiddens[idx])
# Residual connection
decoder_input = self.decoder_rnn_hiddens[idx] + decoder_input
decoder_output = decoder_input

View File

@ -2,7 +2,7 @@ import torch
from torch import nn
from torch.nn.modules.conv import Conv1d
from TTS.vocoder.models.hifigan_discriminator import MultiPeriodDiscriminator
from TTS.vocoder.models.hifigan_discriminator import DiscriminatorP, MultiPeriodDiscriminator
class DiscriminatorS(torch.nn.Module):
@ -60,18 +60,32 @@ class VitsDiscriminator(nn.Module):
def __init__(self, use_spectral_norm=False):
super().__init__()
self.sd = DiscriminatorS(use_spectral_norm=use_spectral_norm)
self.mpd = MultiPeriodDiscriminator(use_spectral_norm=use_spectral_norm)
periods = [2, 3, 5, 7, 11]
def forward(self, x):
self.nets = nn.ModuleList()
self.nets.append(DiscriminatorS(use_spectral_norm=use_spectral_norm))
self.nets.extend([DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods])
def forward(self, x, x_hat=None):
"""
Args:
x (Tensor): input waveform.
x (Tensor): ground truth waveform.
x_hat (Tensor): predicted waveform.
Returns:
List[Tensor]: discriminator scores.
List[List[Tensor]]: list of list of features from each layers of each discriminator.
"""
scores, feats = self.mpd(x)
score_sd, feats_sd = self.sd(x)
return scores + [score_sd], feats + [feats_sd]
x_scores = []
x_hat_scores = [] if x_hat is not None else None
x_feats = []
x_hat_feats = [] if x_hat is not None else None
for net in self.nets:
x_score, x_feat = net(x)
x_scores.append(x_score)
x_feats.append(x_feat)
if x_hat is not None:
x_hat_score, x_hat_feat = net(x_hat)
x_hat_scores.append(x_hat_score)
x_hat_feats.append(x_hat_feat)
return x_scores, x_feats, x_hat_scores, x_hat_feats

View File

@ -228,7 +228,7 @@ class StochasticDurationPredictor(nn.Module):
h = self.post_pre(dr)
h = self.post_convs(h, x_mask)
h = self.post_proj(h) * x_mask
noise = torch.rand(dr.size(0), 2, dr.size(2)).to(device=x.device, dtype=x.dtype) * x_mask
noise = torch.randn(dr.size(0), 2, dr.size(2)).to(device=x.device, dtype=x.dtype) * x_mask
z_q = noise
# posterior encoder

View File

@ -2,8 +2,8 @@ from dataclasses import dataclass, field
from typing import Dict, Tuple
import torch
import torch.nn as nn
from coqpit import Coqpit
from torch import nn
from TTS.tts.layers.align_tts.mdn import MDNBlock
from TTS.tts.layers.feed_forward.decoder import Decoder

View File

@ -2,6 +2,7 @@ import os
from typing import Dict, List, Tuple
import torch
import torch.distributed as dist
from coqpit import Coqpit
from torch import nn
from torch.utils.data import DataLoader
@ -164,7 +165,14 @@ class BaseTTS(BaseModel):
}
def get_data_loader(
self, config: Coqpit, ap: AudioProcessor, is_eval: bool, data_items: List, verbose: bool, num_gpus: int
self,
config: Coqpit,
ap: AudioProcessor,
is_eval: bool,
data_items: List,
verbose: bool,
num_gpus: int,
rank: int = None,
) -> "DataLoader":
if is_eval and not config.run_eval:
loader = None
@ -186,7 +194,7 @@ class BaseTTS(BaseModel):
if hasattr(self, "make_symbols"):
custom_symbols = self.make_symbols(self.config)
# init dataloader
# init dataset
dataset = TTSDataset(
outputs_per_step=config.r if "r" in config else 1,
text_cleaner=config.text_cleaner,
@ -212,13 +220,15 @@ class BaseTTS(BaseModel):
else None,
)
if config.use_phonemes and config.compute_input_seq_cache:
# pre-compute phonemes
if config.use_phonemes and config.compute_input_seq_cache and rank in [None, 0]:
if hasattr(self, "eval_data_items") and is_eval:
dataset.items = self.eval_data_items
elif hasattr(self, "train_data_items") and not is_eval:
dataset.items = self.train_data_items
else:
# precompute phonemes to have a better estimate of sequence lengths.
# precompute phonemes for precise estimate of sequence lengths.
# otherwise `dataset.sort_items()` uses raw text lengths
dataset.compute_input_seq(config.num_loader_workers)
# TODO: find a more efficient solution
@ -228,9 +238,17 @@ class BaseTTS(BaseModel):
else:
self.train_data_items = dataset.items
dataset.sort_items()
# halt DDP processes for the main process to finish computing the phoneme cache
if num_gpus > 1:
dist.barrier()
# sort input sequences from short to long
dataset.sort_and_filter_items(config.get("sort_by_audio_len", default=False))
# sampler for DDP
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
# init dataloader
loader = DataLoader(
dataset,
batch_size=config.eval_batch_size if is_eval else config.batch_size,

View File

@ -1,4 +1,6 @@
import math
from dataclasses import dataclass, field
from itertools import chain
from typing import Dict, List, Tuple
import torch
@ -11,8 +13,6 @@ from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path
from TTS.tts.layers.vits.discriminator import VitsDiscriminator
from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder
from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
# from TTS.tts.layers.vits.sdp import StochasticDurationPredictor
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.data import sequence_mask
from TTS.tts.utils.speakers import get_speaker_manager
@ -119,7 +119,7 @@ class VitsArgs(Coqpit):
upsample_kernel_sizes_decoder (List[int]):
Kernel sizes for each upsampling layer of the decoder network. Defaults to `[16, 16, 4, 4]`.
use_sdp (int):
use_sdp (bool):
Use Stochastic Duration Predictor. Defaults to True.
noise_scale (float):
@ -128,7 +128,7 @@ class VitsArgs(Coqpit):
inference_noise_scale (float):
Noise scale used for the sample noise tensor in inference. Defaults to 0.667.
length_scale (int):
length_scale (float):
Scale factor for the predicted duration values. Smaller values result faster speech. Defaults to 1.
noise_scale_dp (float):
@ -176,26 +176,26 @@ class VitsArgs(Coqpit):
num_heads_text_encoder: int = 2
num_layers_text_encoder: int = 6
kernel_size_text_encoder: int = 3
dropout_p_text_encoder: int = 0.1
dropout_p_duration_predictor: int = 0.1
dropout_p_text_encoder: float = 0.1
dropout_p_duration_predictor: float = 0.5
kernel_size_posterior_encoder: int = 5
dilation_rate_posterior_encoder: int = 1
num_layers_posterior_encoder: int = 16
kernel_size_flow: int = 5
dilation_rate_flow: int = 1
num_layers_flow: int = 4
resblock_type_decoder: int = "1"
resblock_type_decoder: str = "1"
resblock_kernel_sizes_decoder: List[int] = field(default_factory=lambda: [3, 7, 11])
resblock_dilation_sizes_decoder: List[List[int]] = field(default_factory=lambda: [[1, 3, 5], [1, 3, 5], [1, 3, 5]])
upsample_rates_decoder: List[int] = field(default_factory=lambda: [8, 8, 2, 2])
upsample_initial_channel_decoder: int = 512
upsample_kernel_sizes_decoder: List[int] = field(default_factory=lambda: [16, 16, 4, 4])
use_sdp: int = True
use_sdp: bool = True
noise_scale: float = 1.0
inference_noise_scale: float = 0.667
length_scale: int = 1
length_scale: float = 1
noise_scale_dp: float = 1.0
inference_noise_scale_dp: float = 0.8
inference_noise_scale_dp: float = 1.0
max_inference_len: int = None
init_discriminator: bool = True
use_spectral_norm_disriminator: bool = False
@ -419,24 +419,23 @@ class Vits(BaseTTS):
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
with torch.no_grad():
o_scale = torch.exp(-2 * logs_p)
# logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1]).unsqueeze(-1) # [b, t, 1]
logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1]).unsqueeze(-1) # [b, t, 1]
logp2 = torch.einsum("klm, kln -> kmn", [o_scale, -0.5 * (z_p ** 2)])
logp3 = torch.einsum("klm, kln -> kmn", [m_p * o_scale, z_p])
# logp4 = torch.sum(-0.5 * (m_p ** 2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1]
logp = logp2 + logp3
logp4 = torch.sum(-0.5 * (m_p ** 2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1]
logp = logp2 + logp3 + logp1 + logp4
attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()
# duration predictor
attn_durations = attn.sum(3)
if self.args.use_sdp:
nll_duration = self.duration_predictor(
loss_duration = self.duration_predictor(
x.detach() if self.args.detach_dp_input else x,
x_mask,
attn_durations,
g=g.detach() if self.args.detach_dp_input and g is not None else g,
)
nll_duration = torch.sum(nll_duration.float() / torch.sum(x_mask))
outputs["nll_duration"] = nll_duration
loss_duration = loss_duration / torch.sum(x_mask)
else:
attn_log_durations = torch.log(attn_durations + 1e-6) * x_mask
log_durations = self.duration_predictor(
@ -445,7 +444,7 @@ class Vits(BaseTTS):
g=g.detach() if self.args.detach_dp_input and g is not None else g,
)
loss_duration = torch.sum((log_durations - attn_log_durations) ** 2, [1, 2]) / torch.sum(x_mask)
outputs["loss_duration"] = loss_duration
outputs["loss_duration"] = loss_duration
# expand prior
m_p = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
@ -563,8 +562,9 @@ class Vits(BaseTTS):
outputs["waveform_seg"] = wav_seg
# compute discriminator scores and features
outputs["scores_disc_fake"], outputs["feats_disc_fake"] = self.disc(outputs["model_outputs"])
_, outputs["feats_disc_real"] = self.disc(wav_seg)
outputs["scores_disc_fake"], outputs["feats_disc_fake"], _, outputs["feats_disc_real"] = self.disc(
outputs["model_outputs"], wav_seg
)
# compute losses
with autocast(enabled=False): # use float32 for the criterion
@ -579,23 +579,17 @@ class Vits(BaseTTS):
scores_disc_fake=outputs["scores_disc_fake"],
feats_disc_fake=outputs["feats_disc_fake"],
feats_disc_real=outputs["feats_disc_real"],
loss_duration=outputs["loss_duration"],
)
# handle the duration loss
if self.args.use_sdp:
loss_dict["nll_duration"] = outputs["nll_duration"]
loss_dict["loss"] += outputs["nll_duration"]
else:
loss_dict["loss_duration"] = outputs["loss_duration"]
loss_dict["loss"] += outputs["nll_duration"]
elif optimizer_idx == 1:
# discriminator pass
outputs = {}
# compute scores and features
outputs["scores_disc_fake"], outputs["feats_disc_fake"] = self.disc(self.y_disc_cache.detach())
outputs["scores_disc_real"], outputs["feats_disc_real"] = self.disc(self.wav_seg_disc_cache)
outputs["scores_disc_fake"], _, outputs["scores_disc_real"], _ = self.disc(
self.y_disc_cache.detach(), self.wav_seg_disc_cache
)
# compute loss
with autocast(enabled=False): # use float32 for the criterion
@ -686,14 +680,21 @@ class Vits(BaseTTS):
Returns:
List: optimizers.
"""
self.disc.requires_grad_(False)
gen_parameters = filter(lambda p: p.requires_grad, self.parameters())
self.disc.requires_grad_(True)
optimizer1 = get_optimizer(
gen_parameters = chain(
self.text_encoder.parameters(),
self.posterior_encoder.parameters(),
self.flow.parameters(),
self.duration_predictor.parameters(),
self.waveform_decoder.parameters(),
)
# add the speaker embedding layer
if hasattr(self, "emb_g"):
gen_parameters = chain(gen_parameters, self.emb_g)
optimizer0 = get_optimizer(
self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, parameters=gen_parameters
)
optimizer2 = get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr_disc, self.disc)
return [optimizer1, optimizer2]
optimizer1 = get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr_disc, self.disc)
return [optimizer0, optimizer1]
def get_lr(self) -> List:
"""Set the initial learning rates for each optimizer.
@ -712,9 +713,9 @@ class Vits(BaseTTS):
Returns:
List: Schedulers, one for each optimizer.
"""
scheduler1 = get_scheduler(self.config.lr_scheduler_gen, self.config.lr_scheduler_gen_params, optimizer[0])
scheduler2 = get_scheduler(self.config.lr_scheduler_disc, self.config.lr_scheduler_disc_params, optimizer[1])
return [scheduler1, scheduler2]
scheduler0 = get_scheduler(self.config.lr_scheduler_gen, self.config.lr_scheduler_gen_params, optimizer[0])
scheduler1 = get_scheduler(self.config.lr_scheduler_disc, self.config.lr_scheduler_disc_params, optimizer[1])
return [scheduler0, scheduler1]
def get_criterion(self):
"""Get criterions for each optimizer. The index in the output list matches the optimizer idx used in

View File

@ -225,9 +225,10 @@ def sequence_to_text(sequence: List, tp: Dict = None, add_blank=False, custom_sy
if custom_symbols is not None:
_symbols = custom_symbols
_id_to_symbol = {i: s for i, s in enumerate(_symbols)}
elif tp:
_symbols, _ = make_symbols(**tp)
_id_to_symbol = {i: s for i, s in enumerate(_symbols)}
_id_to_symbol = {i: s for i, s in enumerate(_symbols)}
result = ""
for symbol_id in sequence:

View File

@ -241,6 +241,7 @@ class AudioProcessor(object):
self.sample_rate = sample_rate
self.resample = resample
self.num_mels = num_mels
self.log_func = log_func
self.min_level_db = min_level_db or 0
self.frame_shift_ms = frame_shift_ms
self.frame_length_ms = frame_length_ms

View File

@ -1,8 +1,6 @@
# edited from https://github.com/fastai/imagenet-fast/blob/master/imagenet_nv/distributed.py
import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from torch.autograd import Variable
def reduce_tensor(tensor, num_gpus):
@ -20,46 +18,3 @@ def init_distributed(rank, num_gpus, group_name, dist_backend, dist_url):
# Initialize distributed communication
dist.init_process_group(dist_backend, init_method=dist_url, world_size=num_gpus, rank=rank, group_name=group_name)
def apply_gradient_allreduce(module):
# sync model parameters
for p in module.state_dict().values():
if not torch.is_tensor(p):
continue
dist.broadcast(p, 0)
def allreduce_params():
if module.needs_reduction:
module.needs_reduction = False
# bucketing params based on value types
buckets = {}
for param in module.parameters():
if param.requires_grad and param.grad is not None:
tp = type(param.data)
if tp not in buckets:
buckets[tp] = []
buckets[tp].append(param)
for tp in buckets:
bucket = buckets[tp]
grads = [param.grad.data for param in bucket]
coalesced = _flatten_dense_tensors(grads)
dist.all_reduce(coalesced, op=dist.reduce_op.SUM)
coalesced /= dist.get_world_size()
for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
buf.copy_(synced)
for param in list(module.parameters()):
def allreduce_hook(*_):
Variable._execution_engine.queue_callback(allreduce_params) # pylint: disable=protected-access
if param.requires_grad:
param.register_hook(allreduce_hook)
def set_needs_reduction(self, *_):
self.needs_reduction = True
module.register_forward_hook(set_needs_reduction)
return module

View File

@ -53,7 +53,6 @@ def get_commit_hash():
# Not copying .git folder into docker container
except (subprocess.CalledProcessError, FileNotFoundError):
commit = "0000000"
print(" > Git Hash: {}".format(commit))
return commit
@ -62,7 +61,6 @@ def get_experiment_folder_path(root_path, model_name):
date_str = datetime.datetime.now().strftime("%B-%d-%Y_%I+%M%p")
commit_hash = get_commit_hash()
output_folder = os.path.join(root_path, model_name + "-" + date_str + "-" + commit_hash)
print(" > Experiment folder: {}".format(output_folder))
return output_folder

View File

@ -3,7 +3,7 @@ import json
import os
import pickle as pickle_tts
import shutil
from typing import Any
from typing import Any, Callable, Dict, Union
import fsspec
import torch
@ -53,18 +53,23 @@ def copy_model_files(config: Coqpit, out_path, new_fields):
shutil.copyfileobj(source_file, target_file)
def load_fsspec(path: str, **kwargs) -> Any:
def load_fsspec(
path: str,
map_location: Union[str, Callable, torch.device, Dict[Union[str, torch.device], Union[str, torch.device]]] = None,
**kwargs,
) -> Any:
"""Like torch.load but can load from other locations (e.g. s3:// , gs://).
Args:
path: Any path or url supported by fsspec.
map_location: torch.device or str.
**kwargs: Keyword arguments forwarded to torch.load.
Returns:
Object stored in path.
"""
with fsspec.open(path, "rb") as f:
return torch.load(f, **kwargs)
return torch.load(f, map_location=map_location, **kwargs)
def load_checkpoint(model, checkpoint_path, use_cuda=False, eval=False): # pylint: disable=redefined-builtin

View File

@ -3,7 +3,7 @@ from TTS.utils.logging.tensorboard_logger import TensorboardLogger
from TTS.utils.logging.wandb_logger import WandbLogger
def init_logger(config):
def init_dashboard_logger(config):
if config.dashboard_logger == "tensorboard":
dashboard_logger = TensorboardLogger(config.output_log_path, model_name=config.model)

View File

@ -29,11 +29,13 @@ class ConsoleLogger:
now = datetime.datetime.now()
return now.strftime("%Y-%m-%d %H:%M:%S")
def print_epoch_start(self, epoch, max_epoch):
def print_epoch_start(self, epoch, max_epoch, output_path=None):
print(
"\n{}{} > EPOCH: {}/{}{}".format(tcolors.UNDERLINE, tcolors.BOLD, epoch, max_epoch, tcolors.ENDC),
flush=True,
)
if output_path is not None:
print(f" --> {output_path}")
def print_train_start(self):
print(f"\n{tcolors.BOLD} > TRAINING ({self.get_time()}) {tcolors.ENDC}")

View File

@ -110,7 +110,7 @@ class ModelManager(object):
os.makedirs(output_path, exist_ok=True)
print(f" > Downloading model to {output_path}")
output_stats_path = os.path.join(output_path, "scale_stats.npy")
output_speakers_path = os.path.join(output_path, "speakers.json")
# download files to the output path
if self._check_dict_key(model_item, "github_rls_url"):
# download from github release
@ -122,22 +122,52 @@ class ModelManager(object):
if self._check_dict_key(model_item, "stats_file"):
self._download_gdrive_file(model_item["stats_file"], output_stats_path)
# update the scale_path.npy file path in the model config.json
if self._check_dict_key(model_item, "stats_file") or os.path.exists(output_stats_path):
# set scale stats path in config.json
config_path = output_config_path
config = load_config(config_path)
config.audio.stats_path = output_stats_path
config.save_json(config_path)
# update the speakers.json file path in the model config.json to the current path
if os.path.exists(output_speakers_path):
# set scale stats path in config.json
config_path = output_config_path
config = load_config(config_path)
config.d_vector_file = output_speakers_path
config.save_json(config_path)
# update paths in the config.json
self._update_paths(output_path, output_config_path)
return output_model_path, output_config_path, model_item
def _update_paths(self, output_path: str, config_path: str) -> None:
"""Update paths for certain files in config.json after download.
Args:
output_path (str): local path the model is downloaded to.
config_path (str): local config.json path.
"""
output_stats_path = os.path.join(output_path, "scale_stats.npy")
output_d_vector_file_path = os.path.join(output_path, "speakers.json")
output_speaker_ids_file_path = os.path.join(output_path, "speaker_ids.json")
# update the scale_path.npy file path in the model config.json
self._update_path("audio.stats_path", output_stats_path, config_path)
# update the speakers.json file path in the model config.json to the current path
self._update_path("d_vector_file", output_d_vector_file_path, config_path)
self._update_path("model_args.d_vector_file", output_d_vector_file_path, config_path)
# update the speaker_ids.json file path in the model config.json to the current path
self._update_path("speakers_file", output_speaker_ids_file_path, config_path)
self._update_path("model_args.speakers_file", output_speaker_ids_file_path, config_path)
@staticmethod
def _update_path(field_name, new_path, config_path):
"""Update the path in the model config.json for the current environment after download"""
if os.path.exists(new_path):
config = load_config(config_path)
field_names = field_name.split(".")
if len(field_names) > 1:
# field name points to a sub-level field
sub_conf = config
for fd in field_names[:-1]:
if fd in sub_conf:
sub_conf = sub_conf[fd]
else:
return
sub_conf[field_names[-1]] = new_path
else:
# field name points to a top-level field
config[field_name] = new_path
config.save_json(config_path)
def _download_gdrive_file(self, gdrive_idx, output):
"""Download files from GDrive using their file ids"""
gdown.download(f"{self.url_prefix}{gdrive_idx}", output=output, quiet=False)

View File

@ -251,7 +251,7 @@ class Synthesizer(object):
d_vector=speaker_embedding,
)
waveform = outputs["wav"]
mel_postnet_spec = outputs["outputs"]["model_outputs"][0].detach().numpy()
mel_postnet_spec = outputs["outputs"]["model_outputs"][0].detach().cpu().numpy()
if not use_gl:
# denormalize tts output based on tts audio config
mel_postnet_spec = self.ap.denormalize(mel_postnet_spec.T).T

View File

@ -1,5 +1,5 @@
import importlib
from typing import Dict, List
from typing import Dict, List, Tuple
import torch
@ -10,9 +10,20 @@ def is_apex_available():
return importlib.util.find_spec("apex") is not None
def setup_torch_training_env(cudnn_enable, cudnn_benchmark):
def setup_torch_training_env(cudnn_enable: bool, cudnn_benchmark: bool, use_ddp: bool = False) -> Tuple[bool, int]:
"""Setup PyTorch environment for training.
Args:
cudnn_enable (bool): Enable/disable CUDNN.
cudnn_benchmark (bool): Enable/disable CUDNN benchmarking. Better to set to False if input sequence length is
variable between batches.
use_ddp (bool): DDP flag. True if DDP is enabled, False otherwise.
Returns:
Tuple[bool, int]: is cuda on or off and number of GPUs in the environment.
"""
num_gpus = torch.cuda.device_count()
if num_gpus > 1:
if num_gpus > 1 and not use_ddp:
raise RuntimeError(
f" [!] {num_gpus} active GPUs. Define the target GPU by `CUDA_VISIBLE_DEVICES`. For multi-gpu training use `TTS/bin/distribute.py`."
)

View File

@ -1,4 +1,5 @@
from dataclasses import dataclass, field
from typing import Dict
from TTS.vocoder.configs.shared_configs import BaseGANVocoderConfig
@ -95,7 +96,7 @@ class UnivnetConfig(BaseGANVocoderConfig):
# model specific params
discriminator_model: str = "univnet_discriminator"
generator_model: str = "univnet_generator"
generator_model_params: dict = field(
generator_model_params: Dict = field(
default_factory=lambda: {
"in_channels": 64,
"out_channels": 1,
@ -120,7 +121,7 @@ class UnivnetConfig(BaseGANVocoderConfig):
# loss weights - overrides
stft_loss_weight: float = 2.5
stft_loss_params: dict = field(
stft_loss_params: Dict = field(
default_factory=lambda: {
"n_ffts": [1024, 2048, 512],
"hop_lengths": [120, 240, 50],
@ -132,7 +133,7 @@ class UnivnetConfig(BaseGANVocoderConfig):
hinge_G_loss_weight: float = 0
feat_match_loss_weight: float = 0
l1_spec_loss_weight: float = 0
l1_spec_loss_params: dict = field(
l1_spec_loss_params: Dict = field(
default_factory=lambda: {
"use_mel": True,
"sample_rate": 22050,
@ -152,7 +153,7 @@ class UnivnetConfig(BaseGANVocoderConfig):
# lr_scheduler_gen_params: dict = field(default_factory=lambda: {"gamma": 0.999, "last_epoch": -1})
lr_scheduler_disc: str = None # one of the schedulers from https:#pytorch.org/docs/stable/optim.html
# lr_scheduler_disc_params: dict = field(default_factory=lambda: {"gamma": 0.999, "last_epoch": -1})
optimizer_params: dict = field(default_factory=lambda: {"betas": [0.5, 0.9], "weight_decay": 0.0})
optimizer_params: Dict = field(default_factory=lambda: {"betas": [0.5, 0.9], "weight_decay": 0.0})
steps_to_start_discriminator: int = 200000
def __post_init__(self):

View File

@ -1,6 +1,6 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import nn
from torch.nn.utils import weight_norm

View File

@ -1,8 +1,8 @@
# adopted from https://github.com/jik876/hifi-gan/blob/master/models.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import nn
from torch.nn import Conv1d, ConvTranspose1d
from torch.nn import functional as F
from torch.nn.utils import remove_weight_norm, weight_norm
from TTS.utils.io import load_fsspec

View File

@ -40,8 +40,8 @@ class MelganMultiscaleDiscriminator(nn.Module):
)
def forward(self, x):
scores = list()
feats = list()
scores = []
feats = []
for disc in self.discriminators:
score, feat = disc(x)
scores.append(score)

View File

@ -1,6 +1,6 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import nn
from torch.nn.utils import spectral_norm, weight_norm
from TTS.utils.audio import TorchSTFT

View File

@ -9,12 +9,12 @@ from torch.nn.utils import weight_norm
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from TTS.model import BaseModel
from TTS.utils.audio import AudioProcessor
from TTS.utils.io import load_fsspec
from TTS.utils.trainer_utils import get_optimizer, get_scheduler
from TTS.vocoder.datasets import WaveGradDataset
from TTS.vocoder.layers.wavegrad import Conv1d, DBlock, FiLM, UBlock
from TTS.vocoder.models.base_vocoder import BaseVocoder
from TTS.vocoder.utils.generic_utils import plot_results
@ -33,7 +33,7 @@ class WavegradArgs(Coqpit):
)
class Wavegrad(BaseModel):
class Wavegrad(BaseVocoder):
"""🐸 🌊 WaveGrad 🌊 model.
Paper - https://arxiv.org/abs/2009.00713
@ -257,14 +257,18 @@ class Wavegrad(BaseModel):
loss = criterion(noise, noise_hat)
return {"model_output": noise_hat}, {"loss": loss}
def train_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]:
def train_log( # pylint: disable=no-self-use
self, ap: AudioProcessor, batch: Dict, outputs: Dict # pylint: disable=unused-argument
) -> Tuple[Dict, np.ndarray]:
return None, None
@torch.no_grad()
def eval_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]:
return self.train_step(batch, criterion)
def eval_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]:
def eval_log( # pylint: disable=no-self-use
self, ap: AudioProcessor, batch: Dict, outputs: Dict # pylint: disable=unused-argument
) -> Tuple[Dict, np.ndarray]:
return None, None
def test_run(self, ap: AudioProcessor, samples: List[Dict], ouputs: Dict): # pylint: disable=unused-argument
@ -291,7 +295,8 @@ class Wavegrad(BaseModel):
def get_scheduler(self, optimizer):
return get_scheduler(self.config.lr_scheduler, self.config.lr_scheduler_params, optimizer)
def get_criterion(self):
@staticmethod
def get_criterion():
return torch.nn.L1Loss()
@staticmethod

View File

@ -5,9 +5,9 @@ from typing import Dict, List, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from coqpit import Coqpit
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

View File

@ -3,28 +3,24 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [
"%matplotlib inline\n",
"\n",
"from TTS.utils.audio import AudioProcessor\n",
"from TTS.tts.utils.visual import plot_spectrogram\n",
"from TTS.utils.io import load_config\n",
"from TTS.config import load_config\n",
"\n",
"import IPython.display as ipd\n",
"import glob"
]
],
"outputs": [],
"metadata": {
"Collapsed": "false"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [
"config_path = \"/home/erogol/gdrive/Projects/TTS/recipes/ljspeech/align_tts/config_transformer2.json\"\n",
"data_path = \"/home/erogol/gdrive/Datasets/LJSpeech-1.1/\"\n",
@ -39,28 +35,28 @@
"\n",
"print(\"File list, by index:\")\n",
"dict(enumerate(file_paths))"
]
],
"outputs": [],
"metadata": {
"Collapsed": "false"
}
},
{
"cell_type": "markdown",
"metadata": {
"Collapsed": "false"
},
"source": [
"### Setup Audio Processor\n",
"Play with the AP parameters until you find a good fit with the synthesis speech below.\n",
"\n",
"The default values are loaded from your config.json file, so you only need to\n",
"uncomment and modify values below that you'd like to tune."
]
],
"metadata": {
"Collapsed": "false"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [
"tune_params={\n",
"# 'audio_processor': 'audio',\n",
@ -95,54 +91,54 @@
"tuned_config.update(tune_params)\n",
"\n",
"AP = AudioProcessor(**tuned_config);"
]
],
"outputs": [],
"metadata": {
"Collapsed": "false"
}
},
{
"cell_type": "markdown",
"metadata": {
"Collapsed": "false"
},
"source": [
"### Check audio loading "
]
],
"metadata": {
"Collapsed": "false"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [
"wav = AP.load_wav(SAMPLE_FILE_PATH)\n",
"ipd.Audio(data=wav, rate=AP.sample_rate) "
]
],
"outputs": [],
"metadata": {
"Collapsed": "false"
}
},
{
"cell_type": "markdown",
"metadata": {
"Collapsed": "false"
},
"source": [
"### Generate Mel-Spectrogram and Re-synthesis with GL"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"AP.power = 1.5"
]
},
{
"cell_type": "code",
"execution_count": null,
],
"metadata": {
"Collapsed": "false"
},
}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"AP.power = 1.5"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"mel = AP.melspectrogram(wav)\n",
"print(\"Max:\", mel.max())\n",
@ -152,24 +148,24 @@
"\n",
"wav_gen = AP.inv_melspectrogram(mel)\n",
"ipd.Audio(wav_gen, rate=AP.sample_rate)"
]
],
"outputs": [],
"metadata": {
"Collapsed": "false"
}
},
{
"cell_type": "markdown",
"metadata": {
"Collapsed": "false"
},
"source": [
"### Generate Linear-Spectrogram and Re-synthesis with GL"
]
],
"metadata": {
"Collapsed": "false"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [
"spec = AP.spectrogram(wav)\n",
"print(\"Max:\", spec.max())\n",
@ -179,26 +175,26 @@
"\n",
"wav_gen = AP.inv_spectrogram(spec)\n",
"ipd.Audio(wav_gen, rate=AP.sample_rate)"
]
],
"outputs": [],
"metadata": {
"Collapsed": "false"
}
},
{
"cell_type": "markdown",
"metadata": {
"Collapsed": "false"
},
"source": [
"### Compare values for a certain parameter\n",
"\n",
"Optimize your parameters by comparing different values per parameter at a time."
]
],
"metadata": {
"Collapsed": "false"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [
"from librosa import display\n",
"from matplotlib import pylab as plt\n",
@ -238,36 +234,39 @@
" val = values[idx]\n",
" print(\" > {} = {}\".format(attribute, val))\n",
" IPython.display.display(IPython.display.Audio(wav_gen, rate=AP.sample_rate))"
]
],
"outputs": [],
"metadata": {
"Collapsed": "false"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [
"compare_values(\"preemphasis\", [0, 0.5, 0.97, 0.98, 0.99])"
]
],
"outputs": [],
"metadata": {
"Collapsed": "false"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [
"compare_values(\"ref_level_db\", [2, 5, 10, 15, 20, 25, 30, 35, 40, 1000])"
]
],
"outputs": [],
"metadata": {
"Collapsed": "false"
}
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
"name": "python3",
"display_name": "Python 3.8.5 64-bit ('torch': conda)"
},
"language_info": {
"codemirror_mode": {
@ -280,8 +279,11 @@
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
},
"interpreter": {
"hash": "27648abe09795c3a768a281b31f7524fcf66a207e733f8ecda3a4e1fd1059fb0"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
}

View File

@ -43,7 +43,7 @@ def process_meta_data(path):
meta_data = {}
# load meta data
with open(path, "r") as f:
with open(path, "r", encoding="utf-8") as f:
data = csv.reader(f, delimiter="|")
for row in data:
frames = int(row[2])
@ -92,7 +92,7 @@ def save_training(file_path, meta_data):
rows.append(d["row"] + "\n")
random.shuffle(rows)
with open(file_path, "w+") as f:
with open(file_path, "w+", encoding="utf-8") as f:
for row in rows:
f.write(row)
@ -156,7 +156,7 @@ def plot_phonemes(train_path, cmu_dict_path, save_path):
phonemes = {}
with open(train_path, "r") as f:
with open(train_path, "r", encoding="utf-8") as f:
data = csv.reader(f, delimiter="|")
phonemes["None"] = 0
for row in data:
@ -174,9 +174,9 @@ def plot_phonemes(train_path, cmu_dict_path, save_path):
phonemes["None"] += 1
x, y = [], []
for key in phonemes:
x.append(key)
y.append(phonemes[key])
for k, v in phonemes.items():
x.append(k)
y.append(v)
plt.figure()
plt.rcParams["figure.figsize"] = (50, 20)

View File

@ -29,7 +29,7 @@ config = VitsConfig(
run_name="vits_ljspeech",
batch_size=48,
eval_batch_size=16,
batch_group_size=0,
batch_group_size=5,
num_loader_workers=4,
num_eval_loader_workers=4,
run_eval=True,
@ -43,7 +43,7 @@ config = VitsConfig(
print_step=25,
print_eval=True,
mixed_precision=True,
max_seq_len=5000,
max_seq_len=500000,
output_path=output_path,
datasets=[dataset_config],
)

View File

@ -2,4 +2,4 @@ black
coverage
isort
nose
pylint==2.8.3
pylint==2.10.2

View File

@ -124,7 +124,7 @@ class TestTTSDataset(unittest.TestCase):
avg_length = mel_lengths.numpy().mean()
assert avg_length >= last_length
dataloader.dataset.sort_items()
dataloader.dataset.sort_and_filter_items()
is_items_reordered = False
for idx, item in enumerate(dataloader.dataset.items):
if item != frames[idx]: