diff --git a/.gitignore b/.gitignore index 673d01da..951d3132 100644 --- a/.gitignore +++ b/.gitignore @@ -155,4 +155,5 @@ deps.json speakers.json internal/* *_pitch.npy -*_phoneme.npy \ No newline at end of file +*_phoneme.npy +wandb \ No newline at end of file diff --git a/.pylintrc b/.pylintrc index 7293f5ad..6e9f953e 100644 --- a/.pylintrc +++ b/.pylintrc @@ -64,6 +64,11 @@ disable=missing-docstring, too-many-public-methods, too-many-lines, bare-except, + ## for avoiding weird p3.6 CI linter error + ## TODO: see later if we can remove this + assigning-non-slot, + unsupported-assignment-operation, + ## end line-too-long, fixme, wrong-import-order, @@ -73,6 +78,7 @@ disable=missing-docstring, invalid-name, too-many-instance-attributes, arguments-differ, + arguments-renamed, no-name-in-module, no-member, unsubscriptable-object, diff --git a/README.md b/README.md index a53f391b..9b448a75 100644 --- a/README.md +++ b/README.md @@ -102,7 +102,7 @@ You can also help us implement more models. ## Install TTS 🐸TTS is tested on Ubuntu 18.04 with **python >= 3.6, < 3.9**. -If you are only interested in [synthesizing speech](https://github.com/coqui-ai/TTS/tree/dev#example-synthesizing-speech-on-terminal-using-the-released-models) with the released 🐸TTS models, installing from PyPI is the easiest option. +If you are only interested in [synthesizing speech](https://tts.readthedocs.io/en/latest/inference.html) with the released 🐸TTS models, installing from PyPI is the easiest option. ```bash pip install TTS diff --git a/TTS/VERSION b/TTS/VERSION index 341cf11f..7dff5b89 100644 --- a/TTS/VERSION +++ b/TTS/VERSION @@ -1 +1 @@ -0.2.0 \ No newline at end of file +0.2.1 \ No newline at end of file diff --git a/TTS/__init__.py b/TTS/__init__.py index 5162d4ec..eaf05db1 100644 --- a/TTS/__init__.py +++ b/TTS/__init__.py @@ -1,6 +1,6 @@ import os -with open(os.path.join(os.path.dirname(__file__), "VERSION")) as f: +with open(os.path.join(os.path.dirname(__file__), "VERSION"), "r", encoding="utf-8") as f: version = f.read().strip() __version__ = version diff --git a/TTS/bin/compute_attention_masks.py b/TTS/bin/compute_attention_masks.py index 88d60d7d..3a5c067e 100644 --- a/TTS/bin/compute_attention_masks.py +++ b/TTS/bin/compute_attention_masks.py @@ -97,7 +97,7 @@ Example run: enable_eos_bos=C.enable_eos_bos_chars, ) - dataset.sort_items() + dataset.sort_and_filter_items(C.get("sort_by_audio_len", default=False)) loader = DataLoader( dataset, batch_size=args.batch_size, @@ -158,7 +158,7 @@ Example run: # ourput metafile metafile = os.path.join(args.data_path, "metadata_attn_mask.txt") - with open(metafile, "w") as f: + with open(metafile, "w", encoding="utf-8") as f: for p in file_paths: f.write(f"{p[0]}|{p[1]}\n") print(f" >> Metafile created: {metafile}") diff --git a/TTS/bin/distribute.py b/TTS/bin/distribute.py index e05747d0..06d5f388 100644 --- a/TTS/bin/distribute.py +++ b/TTS/bin/distribute.py @@ -32,6 +32,7 @@ def main(): command.append("--restore_path={}".format(args.restore_path)) command.append("--config_path={}".format(args.config_path)) command.append("--group_id=group_{}".format(group_id)) + command.append("--use_ddp=true") command += unargs command.append("") @@ -42,7 +43,7 @@ def main(): my_env["PYTHON_EGG_CACHE"] = "/tmp/tmp{}".format(i) command[-1] = "--rank={}".format(i) # prevent stdout for processes with rank != 0 - stdout = None if i == 0 else open(os.devnull, "w") + stdout = None p = subprocess.Popen(["python3"] + command, stdout=stdout, env=my_env) # pylint: disable=consider-using-with processes.append(p) print(command) diff --git a/TTS/bin/extract_tts_spectrograms.py b/TTS/bin/extract_tts_spectrograms.py index debe5933..6ec99fac 100755 --- a/TTS/bin/extract_tts_spectrograms.py +++ b/TTS/bin/extract_tts_spectrograms.py @@ -46,7 +46,7 @@ def setup_loader(ap, r, verbose=False): if c.use_phonemes and c.compute_input_seq_cache: # precompute phonemes to have a better estimate of sequence lengths. dataset.compute_input_seq(c.num_loader_workers) - dataset.sort_items() + dataset.sort_and_filter_items(c.get("sort_by_audio_len", default=False)) loader = DataLoader( dataset, @@ -215,7 +215,7 @@ def extract_spectrograms( wav = ap.inv_melspectrogram(mel) ap.save_wav(wav, wav_gl_path) - with open(os.path.join(output_path, metada_name), "w") as f: + with open(os.path.join(output_path, metada_name), "w", encoding="utf-8") as f: for data in export_metadata: f.write(f"{data[0]}|{data[1]+'.npy'}\n") diff --git a/TTS/config/shared_configs.py b/TTS/config/shared_configs.py index caa169d9..af054346 100644 --- a/TTS/config/shared_configs.py +++ b/TTS/config/shared_configs.py @@ -190,7 +190,7 @@ class BaseTrainingConfig(Coqpit): Name of the model that is used in the training. run_name (str): - Name of the experiment. This prefixes the output folder name. + Name of the experiment. This prefixes the output folder name. Defaults to `coqui_tts`. run_description (str): Short description of the experiment. @@ -272,7 +272,7 @@ class BaseTrainingConfig(Coqpit): """ model: str = None - run_name: str = "" + run_name: str = "coqui_tts" run_description: str = "" # training params epochs: int = 10000 diff --git a/TTS/model.py b/TTS/model.py index aefb925e..cfd1ec62 100644 --- a/TTS/model.py +++ b/TTS/model.py @@ -23,35 +23,31 @@ class BaseModel(nn.Module, ABC): """ @abstractmethod - def forward(self, text: torch.Tensor, aux_input={}, **kwargs) -> Dict: + def forward(self, input: torch.Tensor, *args, aux_input={}, **kwargs) -> Dict: """Forward pass for the model mainly used in training. - You can be flexible here and use different number of arguments and argument names since it is mostly used by - `train_step()` in training whitout exposing it to the out of the class. + You can be flexible here and use different number of arguments and argument names since it is intended to be + used by `train_step()` without exposing it out of the model. Args: - text (torch.Tensor): Input text character sequence ids. + input (torch.Tensor): Input tensor. aux_input (Dict): Auxiliary model inputs like embeddings, durations or any other sorts of inputs. - for the model. Returns: - Dict: model outputs. This must include an item keyed `model_outputs` as the final artifact of the model. + Dict: Model outputs. Main model output must be named as "model_outputs". """ outputs_dict = {"model_outputs": None} ... return outputs_dict @abstractmethod - def inference(self, text: torch.Tensor, aux_input={}) -> Dict: + def inference(self, input: torch.Tensor, aux_input={}) -> Dict: """Forward pass for inference. - After the model is trained this is the only function that connects the model the out world. - - This function must only take a `text` input and a dictionary that has all the other model specific inputs. We don't use `*kwargs` since it is problematic with the TorchScript API. Args: - text (torch.Tensor): [description] + input (torch.Tensor): [description] aux_input (Dict): Auxiliary inputs like speaker embeddings, durations etc. Returns: diff --git a/TTS/speaker_encoder/losses.py b/TTS/speaker_encoder/losses.py index ac7e62bf..8ba917b7 100644 --- a/TTS/speaker_encoder/losses.py +++ b/TTS/speaker_encoder/losses.py @@ -1,6 +1,6 @@ import torch -import torch.nn as nn import torch.nn.functional as F +from torch import nn # adapted from https://github.com/cvqluu/GE2E-Loss diff --git a/TTS/speaker_encoder/models/resnet.py b/TTS/speaker_encoder/models/resnet.py index f121631b..fcc850d7 100644 --- a/TTS/speaker_encoder/models/resnet.py +++ b/TTS/speaker_encoder/models/resnet.py @@ -1,6 +1,6 @@ import numpy as np import torch -import torch.nn as nn +from torch import nn from TTS.utils.io import load_fsspec diff --git a/TTS/speaker_encoder/speaker_encoder_config.py b/TTS/speaker_encoder/speaker_encoder_config.py index e830a0f5..8212acc7 100644 --- a/TTS/speaker_encoder/speaker_encoder_config.py +++ b/TTS/speaker_encoder/speaker_encoder_config.py @@ -1,5 +1,5 @@ from dataclasses import asdict, dataclass, field -from typing import List +from typing import Dict, List from coqpit import MISSING @@ -14,7 +14,7 @@ class SpeakerEncoderConfig(BaseTrainingConfig): audio: BaseAudioConfig = field(default_factory=BaseAudioConfig) datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()]) # model params - model_params: dict = field( + model_params: Dict = field( default_factory=lambda: { "model_name": "lstm", "input_dim": 80, @@ -25,9 +25,9 @@ class SpeakerEncoderConfig(BaseTrainingConfig): } ) - audio_augmentation: dict = field(default_factory=lambda: {}) + audio_augmentation: Dict = field(default_factory=lambda: {}) - storage: dict = field( + storage: Dict = field( default_factory=lambda: { "sample_from_storage_p": 0.66, # the probability with which we'll sample from the DataSet in-memory storage "storage_size": 15, # the size of the in-memory storage with respect to a single batch diff --git a/TTS/speaker_encoder/utils/prepare_voxceleb.py b/TTS/speaker_encoder/utils/prepare_voxceleb.py index 05a65bea..b93baf9e 100644 --- a/TTS/speaker_encoder/utils/prepare_voxceleb.py +++ b/TTS/speaker_encoder/utils/prepare_voxceleb.py @@ -94,7 +94,8 @@ def download_and_extract(directory, subset, urls): extract_path = zip_filepath.strip(".zip") # check zip file md5sum - md5 = hashlib.md5(open(zip_filepath, "rb").read()).hexdigest() + with open(zip_filepath, "rb") as f_zip: + md5 = hashlib.md5(f_zip.read()).hexdigest() if md5 != MD5SUM[subset]: raise ValueError("md5sum of %s mismatch" % zip_filepath) diff --git a/TTS/trainer.py b/TTS/trainer.py index 4267f120..68b45fe2 100644 --- a/TTS/trainer.py +++ b/TTS/trainer.py @@ -1,7 +1,6 @@ # -*- coding: utf-8 -*- import importlib -import logging import multiprocessing import os import platform @@ -16,6 +15,7 @@ from urllib.parse import urlparse import fsspec import torch +import torch.distributed as dist from coqpit import Coqpit from torch import nn from torch.nn.parallel import DistributedDataParallel as DDP_th @@ -38,7 +38,7 @@ from TTS.utils.generic_utils import ( to_cuda, ) from TTS.utils.io import copy_model_files, load_fsspec, save_best_model, save_checkpoint -from TTS.utils.logging import ConsoleLogger, TensorboardLogger, WandbLogger, init_logger +from TTS.utils.logging import ConsoleLogger, TensorboardLogger, WandbLogger, init_dashboard_logger from TTS.utils.trainer_utils import get_optimizer, get_scheduler, is_apex_available, setup_torch_training_env from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data from TTS.vocoder.models import setup_model as setup_vocoder_model @@ -77,12 +77,16 @@ class TrainingArgs(Coqpit): best_path: str = field( default="", metadata={ - "help": "Best model file to be used for extracting best loss. If not specified, the latest best model in continue path is used" + "help": "Best model file to be used for extracting the best loss. If not specified, the latest best model in continue path is used" }, ) config_path: str = field(default="", metadata={"help": "Path to the configuration file."}) rank: int = field(default=0, metadata={"help": "Process rank in distributed training."}) group_id: str = field(default="", metadata={"help": "Process group id in distributed training."}) + use_ddp: bool = field( + default=False, + metadata={"help": "Use DDP in distributed training. It is to set in `distribute.py`. Do not set manually."}, + ) class Trainer: @@ -144,6 +148,7 @@ class Trainer: >>> trainer.fit() TODO: + - Wrap model for not calling .module in DDP. - Accumulate gradients b/w batches. - Deepspeed integration - Profiler integration. @@ -151,29 +156,33 @@ class Trainer: - TPU training """ - # set and initialize Pytorch runtime - self.use_cuda, self.num_gpus = setup_torch_training_env(True, cudnn_benchmark) if config is None: # parse config from console arguments config, output_path, _, c_logger, dashboard_logger = process_args(args) - self.output_path = output_path self.args = args self.config = config + self.output_path = output_path self.config.output_log_path = output_path + + # setup logging + log_file = os.path.join(self.output_path, f"trainer_{args.rank}_log.txt") + self._setup_logger_config(log_file) + + # set and initialize Pytorch runtime + self.use_cuda, self.num_gpus = setup_torch_training_env(True, cudnn_benchmark, args.use_ddp) + # init loggers self.c_logger = ConsoleLogger() if c_logger is None else c_logger self.dashboard_logger = dashboard_logger - if self.dashboard_logger is None: - self.dashboard_logger = init_logger(config) + # only allow dashboard logging for the main process in DDP mode + if self.dashboard_logger is None and args.rank == 0: + self.dashboard_logger = init_dashboard_logger(config) if not self.config.log_model_step: self.config.log_model_step = self.config.save_step - log_file = os.path.join(self.output_path, f"trainer_{args.rank}_log.txt") - self._setup_logger_config(log_file) - self.total_steps_done = 0 self.epochs_done = 0 self.restore_step = 0 @@ -247,10 +256,10 @@ class Trainer: if self.use_apex: self.scaler = None self.model, self.optimizer = amp.initialize(self.model, self.optimizer, opt_level="O1") - if isinstance(self.optimizer, list): - self.scaler = [torch.cuda.amp.GradScaler()] * len(self.optimizer) - else: - self.scaler = torch.cuda.amp.GradScaler() + # if isinstance(self.optimizer, list): + # self.scaler = [torch.cuda.amp.GradScaler()] * len(self.optimizer) + # else: + self.scaler = torch.cuda.amp.GradScaler() else: self.scaler = None @@ -319,14 +328,14 @@ class Trainer: return obj print(" > Restoring from %s ..." % os.path.basename(restore_path)) - checkpoint = load_fsspec(restore_path) + checkpoint = load_fsspec(restore_path, map_location="cpu") try: print(" > Restoring Model...") model.load_state_dict(checkpoint["model"]) print(" > Restoring Optimizer...") optimizer = _restore_list_objs(checkpoint["optimizer"], optimizer) if "scaler" in checkpoint and self.use_amp_scaler and checkpoint["scaler"]: - print(" > Restoring AMP Scaler...") + print(" > Restoring Scaler...") scaler = _restore_list_objs(checkpoint["scaler"], scaler) except (KeyError, RuntimeError): print(" > Partial model initialization...") @@ -346,10 +355,11 @@ class Trainer: " > Model restored from step %d" % checkpoint["step"], ) restore_step = checkpoint["step"] + torch.cuda.empty_cache() return model, optimizer, scaler, restore_step - @staticmethod def _get_loader( + self, model: nn.Module, config: Coqpit, ap: AudioProcessor, @@ -358,8 +368,14 @@ class Trainer: verbose: bool, num_gpus: int, ) -> DataLoader: - if hasattr(model, "get_data_loader"): - loader = model.get_data_loader(config, ap, is_eval, data_items, verbose, num_gpus) + if num_gpus > 1: + if hasattr(model.module, "get_data_loader"): + loader = model.module.get_data_loader( + config, ap, is_eval, data_items, verbose, num_gpus, self.args.rank + ) + else: + if hasattr(model, "get_data_loader"): + loader = model.get_data_loader(config, ap, is_eval, data_items, verbose, num_gpus) return loader def get_train_dataloader(self, ap: AudioProcessor, data_items: List, verbose: bool) -> DataLoader: @@ -387,12 +403,28 @@ class Trainer: Returns: Dict: Formatted batch. """ - batch = self.model.format_batch(batch) + if self.num_gpus > 1: + batch = self.model.module.format_batch(batch) + else: + batch = self.model.format_batch(batch) if self.use_cuda: for k, v in batch.items(): batch[k] = to_cuda(v) return batch + @staticmethod + def master_params(optimizer: torch.optim.Optimizer): + """Generator over parameters owned by the optimizer. + + Used to select parameters used by the optimizer for gradient clipping. + + Args: + optimizer: Target optimizer. + """ + for group in optimizer.param_groups: + for p in group["params"]: + yield p + @staticmethod def _model_train_step( batch: Dict, model: nn.Module, criterion: nn.Module, optimizer_idx: int = None @@ -450,6 +482,8 @@ class Trainer: step_start_time = time.time() # zero-out optimizer optimizer.zero_grad() + + # forward pass and loss computation with torch.cuda.amp.autocast(enabled=config.mixed_precision): if optimizer_idx is not None: outputs, loss_dict = self._model_train_step(batch, model, criterion, optimizer_idx=optimizer_idx) @@ -461,9 +495,9 @@ class Trainer: step_time = time.time() - step_start_time return None, {}, step_time - # check nan loss - if torch.isnan(loss_dict["loss"]).any(): - raise RuntimeError(f" > Detected NaN loss - {loss_dict}.") + # # check nan loss + # if torch.isnan(loss_dict["loss"]).any(): + # raise RuntimeError(f" > NaN loss detected - {loss_dict}") # set gradient clipping threshold if "grad_clip" in config and config.grad_clip is not None: @@ -481,6 +515,8 @@ class Trainer: update_lr_scheduler = True if self.use_amp_scaler: if self.use_apex: + # TODO: verify AMP use for GAN training in TTS + # https://nvidia.github.io/apex/advanced.html?highlight=accumulate#backward-passes-with-multiple-optimizers with amp.scale_loss(loss_dict["loss"], optimizer) as scaled_loss: scaled_loss.backward() grad_norm = torch.nn.utils.clip_grad_norm_( @@ -491,7 +527,9 @@ class Trainer: scaler.scale(loss_dict["loss"]).backward() if grad_clip > 0: scaler.unscale_(optimizer) - grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip, error_if_nonfinite=False) + grad_norm = torch.nn.utils.clip_grad_norm_( + self.master_params(optimizer), grad_clip, error_if_nonfinite=False + ) # pytorch skips the step when the norm is 0. So ignore the norm value when it is NaN if torch.isnan(grad_norm) or torch.isinf(grad_norm): grad_norm = 0 @@ -571,7 +609,8 @@ class Trainer: total_step_time = 0 for idx, optimizer in enumerate(self.optimizer): criterion = self.criterion - scaler = self.scaler[idx] if self.use_amp_scaler else None + # scaler = self.scaler[idx] if self.use_amp_scaler else None + scaler = self.scaler scheduler = self.scheduler[idx] outputs, loss_dict_new, step_time = self._optimize( batch, self.model, optimizer, scaler, criterion, scheduler, self.config, idx @@ -592,13 +631,13 @@ class Trainer: outputs = outputs_per_optimizer # update avg runtime stats - keep_avg_update = dict() + keep_avg_update = {} keep_avg_update["avg_loader_time"] = loader_time keep_avg_update["avg_step_time"] = step_time self.keep_avg_train.update_values(keep_avg_update) # update avg loss stats - update_eval_values = dict() + update_eval_values = {} for key, value in loss_dict.items(): update_eval_values["avg_" + key] = value self.keep_avg_train.update_values(update_eval_values) @@ -662,9 +701,10 @@ class Trainer: if audios is not None: self.dashboard_logger.train_audios(self.total_steps_done, audios, self.ap.sample_rate) + self.dashboard_logger.flush() + self.total_steps_done += 1 self.callbacks.on_train_step_end() - self.dashboard_logger.flush() return outputs, loss_dict def train_epoch(self) -> None: @@ -674,16 +714,20 @@ class Trainer: self.data_train, verbose=True, ) - self.model.train() + if self.num_gpus > 1: + self.model.module.train() + else: + self.model.train() epoch_start_time = time.time() if self.use_cuda: batch_num_steps = int(len(self.train_loader.dataset) / (self.config.batch_size * self.num_gpus)) else: batch_num_steps = int(len(self.train_loader.dataset) / self.config.batch_size) self.c_logger.print_train_start() + loader_start_time = time.time() for cur_step, batch in enumerate(self.train_loader): - loader_start_time = time.time() _, _ = self.train_step(batch, batch_num_steps, cur_step, loader_start_time) + loader_start_time = time.time() epoch_time = time.time() - epoch_start_time # Plot self.epochs_done Stats if self.args.rank == 0: @@ -753,7 +797,7 @@ class Trainer: loss_dict = self._detach_loss_dict(loss_dict) # update avg stats - update_eval_values = dict() + update_eval_values = {} for key, value in loss_dict.items(): update_eval_values["avg_" + key] = value self.keep_avg_eval.update_values(update_eval_values) @@ -784,6 +828,7 @@ class Trainer: loader_time = time.time() - loader_start_time self.keep_avg_eval.update_values({"avg_loader_time": loader_time}) outputs, _ = self.eval_step(batch, cur_step) + loader_start_time = time.time() # plot epoch stats, artifacts and figures if self.args.rank == 0: figures, audios = None, None @@ -800,7 +845,7 @@ class Trainer: def test_run(self) -> None: """Run test and log the results. Test run must be defined by the model. Model must return figures and audios to be logged by the Tensorboard.""" - if hasattr(self.model, "test_run"): + if hasattr(self.model, "test_run") or (self.num_gpus > 1 and hasattr(self.model.module, "test_run")): if self.eval_loader is None: self.eval_loader = self.get_eval_dataloader( self.ap, @@ -810,27 +855,43 @@ class Trainer: if hasattr(self.eval_loader.dataset, "load_test_samples"): samples = self.eval_loader.dataset.load_test_samples(1) - figures, audios = self.model.test_run(self.ap, samples, None) + if self.num_gpus > 1: + figures, audios = self.model.module.test_run(self.ap, samples, None) + else: + figures, audios = self.model.test_run(self.ap, samples, None) else: - figures, audios = self.model.test_run(self.ap) + if self.num_gpus > 1: + figures, audios = self.model.module.test_run(self.ap) + else: + figures, audios = self.model.test_run(self.ap) self.dashboard_logger.test_audios(self.total_steps_done, audios, self.config.audio["sample_rate"]) self.dashboard_logger.test_figures(self.total_steps_done, figures) - def _fit(self) -> None: - """🏃 train -> evaluate -> test for the number of epochs.""" + def _restore_best_loss(self): + """Restore the best loss from the args.best_path if provided else + from the model (`args.restore_path` or `args.continue_path`) used for resuming the training""" if self.restore_step != 0 or self.args.best_path: print(" > Restoring best loss from " f"{os.path.basename(self.args.best_path)} ...") - self.best_loss = load_fsspec(self.args.restore_path, map_location="cpu")["model_loss"] + ch = load_fsspec(self.args.restore_path, map_location="cpu") + if "model_loss" in ch: + self.best_loss = ch["model_loss"] print(f" > Starting with loaded last best loss {self.best_loss}.") + def _fit(self) -> None: + """🏃 train -> evaluate -> test for the number of epochs.""" + self._restore_best_loss() + self.total_steps_done = self.restore_step for epoch in range(0, self.config.epochs): + if self.num_gpus > 1: + # let all processes sync up before starting with a new epoch of training + dist.barrier() self.callbacks.on_epoch_start() self.keep_avg_train = KeepAverage() self.keep_avg_eval = KeepAverage() if self.config.run_eval else None self.epochs_done = epoch - self.c_logger.print_epoch_start(epoch, self.config.epochs) + self.c_logger.print_epoch_start(epoch, self.config.epochs, self.output_path) self.train_epoch() if self.config.run_eval: self.eval_epoch() @@ -839,20 +900,26 @@ class Trainer: self.c_logger.print_epoch_end( epoch, self.keep_avg_eval.avg_values if self.config.run_eval else self.keep_avg_train.avg_values ) - self.save_best_model() + if self.args.rank in [None, 0]: + self.save_best_model() self.callbacks.on_epoch_end() def fit(self) -> None: """Where the ✨️magic✨️ happens...""" try: self._fit() - self.dashboard_logger.finish() + if self.args.rank == 0: + self.dashboard_logger.finish() except KeyboardInterrupt: self.callbacks.on_keyboard_interrupt() # if the output folder is empty remove the run. remove_experiment_folder(self.output_path) + # clear the DDP processes + if self.num_gpus > 1: + dist.destroy_process_group() # finish the wandb run and sync data - self.dashboard_logger.finish() + if self.args.rank == 0: + self.dashboard_logger.finish() # stop without error signal try: sys.exit(0) @@ -902,18 +969,30 @@ class Trainer: keep_after=self.config.keep_after, ) - @staticmethod - def _setup_logger_config(log_file: str) -> None: - handlers = [logging.StreamHandler()] + def _setup_logger_config(self, log_file: str) -> None: + """Write log strings to a file and print logs to the terminal. + TODO: Causes formatting issues in pdb debugging.""" - # Only add a log file if the output location is local due to poor - # support for writing logs to file-like objects. - parsed_url = urlparse(log_file) - if not parsed_url.scheme or parsed_url.scheme == "file": - schemeless_path = os.path.join(parsed_url.netloc, parsed_url.path) - handlers.append(logging.FileHandler(schemeless_path)) + class Logger(object): + def __init__(self, print_to_terminal=True): + self.print_to_terminal = print_to_terminal + self.terminal = sys.stdout + self.log_file = log_file - logging.basicConfig(level=logging.INFO, format="", handlers=handlers) + def write(self, message): + if self.print_to_terminal: + self.terminal.write(message) + with open(self.log_file, "a", encoding="utf-8") as f: + f.write(message) + + def flush(self): + # this flush method is needed for python 3 compatibility. + # this handles the flush command by doing nothing. + # you might want to specify some extra behavior here. + pass + + # don't let processes rank > 0 write to the terminal + sys.stdout = Logger(self.args.rank == 0) @staticmethod def _is_apex_available() -> bool: @@ -1109,8 +1188,6 @@ def process_args(args, config=None): config = register_config(config_base.model)() # override values from command-line args config.parse_known_args(coqpit_overrides, relaxed_parser=True) - if config.mixed_precision: - print(" > Mixed precision mode is ON") experiment_path = args.continue_path if not experiment_path: experiment_path = get_experiment_folder_path(config.output_path, config.run_name) @@ -1130,8 +1207,7 @@ def process_args(args, config=None): used_characters = parse_symbols() new_fields["characters"] = used_characters copy_model_files(config, experiment_path, new_fields) - - dashboard_logger = init_logger(config) + dashboard_logger = init_dashboard_logger(config) c_logger = ConsoleLogger() return config, experiment_path, audio_path, c_logger, dashboard_logger diff --git a/TTS/tts/configs/vits_config.py b/TTS/tts/configs/vits_config.py index df7dee26..58fc66ee 100644 --- a/TTS/tts/configs/vits_config.py +++ b/TTS/tts/configs/vits_config.py @@ -96,7 +96,7 @@ class VitsConfig(BaseTTSConfig): model_args: VitsArgs = field(default_factory=VitsArgs) # optimizer - grad_clip: List[float] = field(default_factory=lambda: [5, 5]) + grad_clip: List[float] = field(default_factory=lambda: [1000, 1000]) lr_gen: float = 0.0002 lr_disc: float = 0.0002 lr_scheduler_gen: str = "ExponentialLR" @@ -113,14 +113,16 @@ class VitsConfig(BaseTTSConfig): gen_loss_alpha: float = 1.0 feat_loss_alpha: float = 1.0 mel_loss_alpha: float = 45.0 + dur_loss_alpha: float = 1.0 # data loader params return_wav: bool = True compute_linear_spec: bool = True # overrides - min_seq_len: int = 13 - max_seq_len: int = 500 + sort_by_audio_len: bool = True + min_seq_len: int = 0 + max_seq_len: int = 500000 r: int = 1 # DO NOT CHANGE add_blank: bool = True diff --git a/TTS/tts/datasets/TTSDataset.py b/TTS/tts/datasets/TTSDataset.py index 89326c9c..5d38243e 100644 --- a/TTS/tts/datasets/TTSDataset.py +++ b/TTS/tts/datasets/TTSDataset.py @@ -69,7 +69,7 @@ class TTSDataset(Dataset): batch. Set 0 to disable. Defaults to 0. min_seq_len (int): Minimum input sequence length to be processed - by the loader. Filter out input sequences that are shorter than this. Some models have a + by sort_inputs`. Filter out input sequences that are shorter than this. Some models have a minimum input length due to its architecture. Defaults to 0. max_seq_len (int): Maximum input sequence length. Filter out input sequences that are longer than this. @@ -302,10 +302,23 @@ class TTSDataset(Dataset): for idx, p in enumerate(phonemes): self.items[idx][0] = p - def sort_items(self): - r"""Sort instances based on text length in ascending order""" - lengths = np.array([len(ins[0]) for ins in self.items]) + def sort_and_filter_items(self, by_audio_len=False): + r"""Sort `items` based on text length or audio length in ascending order. Filter out samples out or the length + range. + Args: + by_audio_len (bool): if True, sort by audio length else by text length. + """ + # compute the target sequence length + if by_audio_len: + lengths = [] + for item in self.items: + lengths.append(os.path.getsize(item[1])) + lengths = np.array(lengths) + else: + lengths = np.array([len(ins[0]) for ins in self.items]) + + # sort items based on the sequence length in ascending order idxs = np.argsort(lengths) new_items = [] ignored = [] @@ -315,7 +328,10 @@ class TTSDataset(Dataset): ignored.append(idx) else: new_items.append(self.items[idx]) + # shuffle batch groups + # create batches with similar length items + # the larger the `batch_group_size`, the higher the length variety in a batch. if self.batch_group_size > 0: for i in range(len(new_items) // self.batch_group_size): offset = i * self.batch_group_size @@ -325,6 +341,7 @@ class TTSDataset(Dataset): new_items[offset:end_offset] = temp_items self.items = new_items + # logging if self.verbose: print(" | > Max length sequence: {}".format(np.max(lengths))) print(" | > Min length sequence: {}".format(np.min(lengths))) diff --git a/TTS/tts/datasets/__init__.py b/TTS/tts/datasets/__init__.py index cbae78a7..a2520751 100644 --- a/TTS/tts/datasets/__init__.py +++ b/TTS/tts/datasets/__init__.py @@ -66,7 +66,7 @@ def load_meta_data(datasets, eval_split=True): def load_attention_mask_meta_data(metafile_path): """Load meta data file created by compute_attention_masks.py""" - with open(metafile_path, "r") as f: + with open(metafile_path, "r", encoding="utf-8") as f: lines = f.readlines() meta_data = [] diff --git a/TTS/tts/datasets/formatters.py b/TTS/tts/datasets/formatters.py index c057c51e..eee407a8 100644 --- a/TTS/tts/datasets/formatters.py +++ b/TTS/tts/datasets/formatters.py @@ -19,7 +19,7 @@ def tweb(root_path, meta_file): txt_file = os.path.join(root_path, meta_file) items = [] speaker_name = "tweb" - with open(txt_file, "r") as ttf: + with open(txt_file, "r", encoding="utf-8") as ttf: for line in ttf: cols = line.split("\t") wav_file = os.path.join(root_path, cols[0] + ".wav") @@ -33,7 +33,7 @@ def mozilla(root_path, meta_file): txt_file = os.path.join(root_path, meta_file) items = [] speaker_name = "mozilla" - with open(txt_file, "r") as ttf: + with open(txt_file, "r", encoding="utf-8") as ttf: for line in ttf: cols = line.split("|") wav_file = cols[1].strip() @@ -77,7 +77,7 @@ def mailabs(root_path, meta_files=None): continue speaker_name = speaker_name_match.group("speaker_name") print(" | > {}".format(csv_file)) - with open(txt_file, "r") as ttf: + with open(txt_file, "r", encoding="utf-8") as ttf: for line in ttf: cols = line.split("|") if meta_files is None: @@ -102,7 +102,7 @@ def ljspeech(root_path, meta_file): for line in ttf: cols = line.split("|") wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav") - text = cols[1] + text = cols[2] items.append([text, wav_file, speaker_name]) return items @@ -116,7 +116,7 @@ def ljspeech_test(root_path, meta_file): for idx, line in enumerate(ttf): cols = line.split("|") wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav") - text = cols[1] + text = cols[2] items.append([text, wav_file, f"ljspeech-{idx}"]) return items @@ -158,7 +158,7 @@ def css10(root_path, meta_file): txt_file = os.path.join(root_path, meta_file) items = [] speaker_name = "ljspeech" - with open(txt_file, "r") as ttf: + with open(txt_file, "r", encoding="utf-8") as ttf: for line in ttf: cols = line.split("|") wav_file = os.path.join(root_path, cols[0]) @@ -172,7 +172,7 @@ def nancy(root_path, meta_file): txt_file = os.path.join(root_path, meta_file) items = [] speaker_name = "nancy" - with open(txt_file, "r") as ttf: + with open(txt_file, "r", encoding="utf-8") as ttf: for line in ttf: utt_id = line.split()[1] text = line[line.find('"') + 1 : line.rfind('"') - 1] @@ -185,7 +185,7 @@ def common_voice(root_path, meta_file): """Normalize the common voice meta data file to TTS format.""" txt_file = os.path.join(root_path, meta_file) items = [] - with open(txt_file, "r") as ttf: + with open(txt_file, "r", encoding="utf-8") as ttf: for line in ttf: if line.startswith("client_id"): continue @@ -208,7 +208,7 @@ def libri_tts(root_path, meta_files=None): for meta_file in meta_files: _meta_file = os.path.basename(meta_file).split(".")[0] - with open(meta_file, "r") as ttf: + with open(meta_file, "r", encoding="utf-8") as ttf: for line in ttf: cols = line.split("\t") file_name = cols[0] @@ -245,7 +245,7 @@ def brspeech(root_path, meta_file): """BRSpeech 3.0 beta""" txt_file = os.path.join(root_path, meta_file) items = [] - with open(txt_file, "r") as ttf: + with open(txt_file, "r", encoding="utf-8") as ttf: for line in ttf: if line.startswith("wav_filename"): continue @@ -268,7 +268,7 @@ def vctk(root_path, meta_files=None, wavs_path="wav48"): if isinstance(test_speakers, list): # if is list ignore this speakers ids if speaker_id in test_speakers: continue - with open(meta_file) as file_text: + with open(meta_file, "r", encoding="utf-8") as file_text: text = file_text.readlines()[0] wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + ".wav") items.append([text, wav_file, "VCTK_" + speaker_id]) @@ -295,7 +295,7 @@ def vctk_slim(root_path, meta_files=None, wavs_path="wav48"): def mls(root_path, meta_files=None): """http://www.openslr.org/94/""" items = [] - with open(os.path.join(root_path, meta_files), "r") as meta: + with open(os.path.join(root_path, meta_files), "r", encoding="utf-8") as meta: for line in meta: file, text = line.split("\t") text = text[:-1] @@ -329,7 +329,7 @@ def _voxcel_x(root_path, meta_file, voxcel_idx): # if not exists meta file, crawl recursively for 'wav' files if meta_file is not None: - with open(str(meta_file), "r") as f: + with open(str(meta_file), "r", encoding="utf-8") as f: return [x.strip().split("|") for x in f.readlines()] elif not cache_to.exists(): @@ -346,12 +346,12 @@ def _voxcel_x(root_path, meta_file, voxcel_idx): text = None # VoxCel does not provide transciptions, and they are not needed for training the SE meta_data.append(f"{text}|{path}|voxcel{voxcel_idx}_{speaker_id}\n") cnt += 1 - with open(str(cache_to), "w") as f: + with open(str(cache_to), "w", encoding="utf-8") as f: f.write("".join(meta_data)) if cnt < expected_count: raise ValueError(f"Found too few instances for Voxceleb. Should be around {expected_count}, is: {cnt}") - with open(str(cache_to), "r") as f: + with open(str(cache_to), "r", encoding="utf-8") as f: return [x.strip().split("|") for x in f.readlines()] @@ -367,7 +367,7 @@ def baker(root_path: str, meta_file: str) -> List[List[str]]: txt_file = os.path.join(root_path, meta_file) items = [] speaker_name = "baker" - with open(txt_file, "r") as ttf: + with open(txt_file, "r", encoding="utf-8") as ttf: for line in ttf: wav_name, text = line.rstrip("\n").split("|") wav_path = os.path.join(root_path, "clips_22", wav_name) @@ -380,7 +380,7 @@ def kokoro(root_path, meta_file): txt_file = os.path.join(root_path, meta_file) items = [] speaker_name = "kokoro" - with open(txt_file, "r") as ttf: + with open(txt_file, "r", encoding="utf-8") as ttf: for line in ttf: cols = line.split("|") wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav") diff --git a/TTS/tts/layers/generic/transformer.py b/TTS/tts/layers/generic/transformer.py index 75631e0e..9e6b69ac 100644 --- a/TTS/tts/layers/generic/transformer.py +++ b/TTS/tts/layers/generic/transformer.py @@ -1,6 +1,6 @@ import torch -import torch.nn as nn import torch.nn.functional as F +from torch import nn class FFTransformer(nn.Module): diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py index 171b0217..0ce4ada9 100644 --- a/TTS/tts/layers/losses.py +++ b/TTS/tts/layers/losses.py @@ -524,6 +524,7 @@ class VitsGeneratorLoss(nn.Module): self.kl_loss_alpha = c.kl_loss_alpha self.gen_loss_alpha = c.gen_loss_alpha self.feat_loss_alpha = c.feat_loss_alpha + self.dur_loss_alpha = c.dur_loss_alpha self.mel_loss_alpha = c.mel_loss_alpha self.stft = TorchSTFT( c.audio.fft_size, @@ -590,10 +591,11 @@ class VitsGeneratorLoss(nn.Module): scores_disc_fake, feats_disc_fake, feats_disc_real, + loss_duration, ): """ Shapes: - - wavefrom: :math:`[B, 1, T]` + - waveform : :math:`[B, 1, T]` - waveform_hat: :math:`[B, 1, T]` - z_p: :math:`[B, C, T]` - logs_q: :math:`[B, C, T]` @@ -615,12 +617,14 @@ class VitsGeneratorLoss(nn.Module): loss_gen = self.generator_loss(scores_disc_fake)[0] * self.gen_loss_alpha loss_kl = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask.unsqueeze(1)) * self.kl_loss_alpha loss_mel = torch.nn.functional.l1_loss(mel, mel_hat) * self.mel_loss_alpha - loss = loss_kl + loss_feat + loss_mel + loss_gen + loss_duration = torch.sum(loss_duration.float()) * self.dur_loss_alpha + loss = loss_kl + loss_feat + loss_mel + loss_gen + loss_duration # pass losses to the dict return_dict["loss_gen"] = loss_gen return_dict["loss_kl"] = loss_kl return_dict["loss_feat"] = loss_feat return_dict["loss_mel"] = loss_mel + return_dict["loss_duration"] = loss_duration return_dict["loss"] = loss return return_dict @@ -651,7 +655,6 @@ class VitsDiscriminatorLoss(nn.Module): return_dict = {} loss_disc, _, _ = self.discriminator_loss(scores_disc_real, scores_disc_fake) return_dict["loss_disc"] = loss_disc * self.disc_loss_alpha - loss = loss + loss_disc - return_dict["loss_disc"] = loss_disc + loss = loss + return_dict["loss_disc"] return_dict["loss"] = loss return return_dict diff --git a/TTS/tts/layers/tacotron/gst_layers.py b/TTS/tts/layers/tacotron/gst_layers.py index 0d3ed039..01a81e0b 100644 --- a/TTS/tts/layers/tacotron/gst_layers.py +++ b/TTS/tts/layers/tacotron/gst_layers.py @@ -1,6 +1,6 @@ import torch -import torch.nn as nn import torch.nn.functional as F +from torch import nn class GST(nn.Module): diff --git a/TTS/tts/layers/tacotron/tacotron.py b/TTS/tts/layers/tacotron/tacotron.py index 47b5ea7e..bddaf449 100644 --- a/TTS/tts/layers/tacotron/tacotron.py +++ b/TTS/tts/layers/tacotron/tacotron.py @@ -388,8 +388,8 @@ class Decoder(nn.Module): decoder_input = self.project_to_decoder_in(torch.cat((self.attention_rnn_hidden, self.context_vec), -1)) # Pass through the decoder RNNs - for idx in range(len(self.decoder_rnns)): - self.decoder_rnn_hiddens[idx] = self.decoder_rnns[idx](decoder_input, self.decoder_rnn_hiddens[idx]) + for idx, decoder_rnn in enumerate(self.decoder_rnns): + self.decoder_rnn_hiddens[idx] = decoder_rnn(decoder_input, self.decoder_rnn_hiddens[idx]) # Residual connection decoder_input = self.decoder_rnn_hiddens[idx] + decoder_input decoder_output = decoder_input diff --git a/TTS/tts/layers/vits/discriminator.py b/TTS/tts/layers/vits/discriminator.py index 650c9b61..e9d54713 100644 --- a/TTS/tts/layers/vits/discriminator.py +++ b/TTS/tts/layers/vits/discriminator.py @@ -2,7 +2,7 @@ import torch from torch import nn from torch.nn.modules.conv import Conv1d -from TTS.vocoder.models.hifigan_discriminator import MultiPeriodDiscriminator +from TTS.vocoder.models.hifigan_discriminator import DiscriminatorP, MultiPeriodDiscriminator class DiscriminatorS(torch.nn.Module): @@ -60,18 +60,32 @@ class VitsDiscriminator(nn.Module): def __init__(self, use_spectral_norm=False): super().__init__() - self.sd = DiscriminatorS(use_spectral_norm=use_spectral_norm) - self.mpd = MultiPeriodDiscriminator(use_spectral_norm=use_spectral_norm) + periods = [2, 3, 5, 7, 11] - def forward(self, x): + self.nets = nn.ModuleList() + self.nets.append(DiscriminatorS(use_spectral_norm=use_spectral_norm)) + self.nets.extend([DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods]) + + def forward(self, x, x_hat=None): """ Args: - x (Tensor): input waveform. + x (Tensor): ground truth waveform. + x_hat (Tensor): predicted waveform. Returns: List[Tensor]: discriminator scores. List[List[Tensor]]: list of list of features from each layers of each discriminator. """ - scores, feats = self.mpd(x) - score_sd, feats_sd = self.sd(x) - return scores + [score_sd], feats + [feats_sd] + x_scores = [] + x_hat_scores = [] if x_hat is not None else None + x_feats = [] + x_hat_feats = [] if x_hat is not None else None + for net in self.nets: + x_score, x_feat = net(x) + x_scores.append(x_score) + x_feats.append(x_feat) + if x_hat is not None: + x_hat_score, x_hat_feat = net(x_hat) + x_hat_scores.append(x_hat_score) + x_hat_feats.append(x_hat_feat) + return x_scores, x_feats, x_hat_scores, x_hat_feats diff --git a/TTS/tts/layers/vits/stochastic_duration_predictor.py b/TTS/tts/layers/vits/stochastic_duration_predictor.py index ae1edebb..53f7ca7c 100644 --- a/TTS/tts/layers/vits/stochastic_duration_predictor.py +++ b/TTS/tts/layers/vits/stochastic_duration_predictor.py @@ -228,7 +228,7 @@ class StochasticDurationPredictor(nn.Module): h = self.post_pre(dr) h = self.post_convs(h, x_mask) h = self.post_proj(h) * x_mask - noise = torch.rand(dr.size(0), 2, dr.size(2)).to(device=x.device, dtype=x.dtype) * x_mask + noise = torch.randn(dr.size(0), 2, dr.size(2)).to(device=x.device, dtype=x.dtype) * x_mask z_q = noise # posterior encoder diff --git a/TTS/tts/models/align_tts.py b/TTS/tts/models/align_tts.py index fb2fa697..2aa84cb2 100644 --- a/TTS/tts/models/align_tts.py +++ b/TTS/tts/models/align_tts.py @@ -2,8 +2,8 @@ from dataclasses import dataclass, field from typing import Dict, Tuple import torch -import torch.nn as nn from coqpit import Coqpit +from torch import nn from TTS.tts.layers.align_tts.mdn import MDNBlock from TTS.tts.layers.feed_forward.decoder import Decoder diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py index cd4c33d0..922761cb 100644 --- a/TTS/tts/models/base_tts.py +++ b/TTS/tts/models/base_tts.py @@ -2,6 +2,7 @@ import os from typing import Dict, List, Tuple import torch +import torch.distributed as dist from coqpit import Coqpit from torch import nn from torch.utils.data import DataLoader @@ -164,7 +165,14 @@ class BaseTTS(BaseModel): } def get_data_loader( - self, config: Coqpit, ap: AudioProcessor, is_eval: bool, data_items: List, verbose: bool, num_gpus: int + self, + config: Coqpit, + ap: AudioProcessor, + is_eval: bool, + data_items: List, + verbose: bool, + num_gpus: int, + rank: int = None, ) -> "DataLoader": if is_eval and not config.run_eval: loader = None @@ -186,7 +194,7 @@ class BaseTTS(BaseModel): if hasattr(self, "make_symbols"): custom_symbols = self.make_symbols(self.config) - # init dataloader + # init dataset dataset = TTSDataset( outputs_per_step=config.r if "r" in config else 1, text_cleaner=config.text_cleaner, @@ -212,13 +220,15 @@ class BaseTTS(BaseModel): else None, ) - if config.use_phonemes and config.compute_input_seq_cache: + # pre-compute phonemes + if config.use_phonemes and config.compute_input_seq_cache and rank in [None, 0]: if hasattr(self, "eval_data_items") and is_eval: dataset.items = self.eval_data_items elif hasattr(self, "train_data_items") and not is_eval: dataset.items = self.train_data_items else: - # precompute phonemes to have a better estimate of sequence lengths. + # precompute phonemes for precise estimate of sequence lengths. + # otherwise `dataset.sort_items()` uses raw text lengths dataset.compute_input_seq(config.num_loader_workers) # TODO: find a more efficient solution @@ -228,9 +238,17 @@ class BaseTTS(BaseModel): else: self.train_data_items = dataset.items - dataset.sort_items() + # halt DDP processes for the main process to finish computing the phoneme cache + if num_gpus > 1: + dist.barrier() + # sort input sequences from short to long + dataset.sort_and_filter_items(config.get("sort_by_audio_len", default=False)) + + # sampler for DDP sampler = DistributedSampler(dataset) if num_gpus > 1 else None + + # init dataloader loader = DataLoader( dataset, batch_size=config.eval_batch_size if is_eval else config.batch_size, diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index 0e91c1f3..72c67df2 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -1,4 +1,6 @@ +import math from dataclasses import dataclass, field +from itertools import chain from typing import Dict, List, Tuple import torch @@ -11,8 +13,6 @@ from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path from TTS.tts.layers.vits.discriminator import VitsDiscriminator from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor - -# from TTS.tts.layers.vits.sdp import StochasticDurationPredictor from TTS.tts.models.base_tts import BaseTTS from TTS.tts.utils.data import sequence_mask from TTS.tts.utils.speakers import get_speaker_manager @@ -119,7 +119,7 @@ class VitsArgs(Coqpit): upsample_kernel_sizes_decoder (List[int]): Kernel sizes for each upsampling layer of the decoder network. Defaults to `[16, 16, 4, 4]`. - use_sdp (int): + use_sdp (bool): Use Stochastic Duration Predictor. Defaults to True. noise_scale (float): @@ -128,7 +128,7 @@ class VitsArgs(Coqpit): inference_noise_scale (float): Noise scale used for the sample noise tensor in inference. Defaults to 0.667. - length_scale (int): + length_scale (float): Scale factor for the predicted duration values. Smaller values result faster speech. Defaults to 1. noise_scale_dp (float): @@ -176,26 +176,26 @@ class VitsArgs(Coqpit): num_heads_text_encoder: int = 2 num_layers_text_encoder: int = 6 kernel_size_text_encoder: int = 3 - dropout_p_text_encoder: int = 0.1 - dropout_p_duration_predictor: int = 0.1 + dropout_p_text_encoder: float = 0.1 + dropout_p_duration_predictor: float = 0.5 kernel_size_posterior_encoder: int = 5 dilation_rate_posterior_encoder: int = 1 num_layers_posterior_encoder: int = 16 kernel_size_flow: int = 5 dilation_rate_flow: int = 1 num_layers_flow: int = 4 - resblock_type_decoder: int = "1" + resblock_type_decoder: str = "1" resblock_kernel_sizes_decoder: List[int] = field(default_factory=lambda: [3, 7, 11]) resblock_dilation_sizes_decoder: List[List[int]] = field(default_factory=lambda: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]) upsample_rates_decoder: List[int] = field(default_factory=lambda: [8, 8, 2, 2]) upsample_initial_channel_decoder: int = 512 upsample_kernel_sizes_decoder: List[int] = field(default_factory=lambda: [16, 16, 4, 4]) - use_sdp: int = True + use_sdp: bool = True noise_scale: float = 1.0 inference_noise_scale: float = 0.667 - length_scale: int = 1 + length_scale: float = 1 noise_scale_dp: float = 1.0 - inference_noise_scale_dp: float = 0.8 + inference_noise_scale_dp: float = 1.0 max_inference_len: int = None init_discriminator: bool = True use_spectral_norm_disriminator: bool = False @@ -419,24 +419,23 @@ class Vits(BaseTTS): attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) with torch.no_grad(): o_scale = torch.exp(-2 * logs_p) - # logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1]).unsqueeze(-1) # [b, t, 1] + logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1]).unsqueeze(-1) # [b, t, 1] logp2 = torch.einsum("klm, kln -> kmn", [o_scale, -0.5 * (z_p ** 2)]) logp3 = torch.einsum("klm, kln -> kmn", [m_p * o_scale, z_p]) - # logp4 = torch.sum(-0.5 * (m_p ** 2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1] - logp = logp2 + logp3 + logp4 = torch.sum(-0.5 * (m_p ** 2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1] + logp = logp2 + logp3 + logp1 + logp4 attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach() # duration predictor attn_durations = attn.sum(3) if self.args.use_sdp: - nll_duration = self.duration_predictor( + loss_duration = self.duration_predictor( x.detach() if self.args.detach_dp_input else x, x_mask, attn_durations, g=g.detach() if self.args.detach_dp_input and g is not None else g, ) - nll_duration = torch.sum(nll_duration.float() / torch.sum(x_mask)) - outputs["nll_duration"] = nll_duration + loss_duration = loss_duration / torch.sum(x_mask) else: attn_log_durations = torch.log(attn_durations + 1e-6) * x_mask log_durations = self.duration_predictor( @@ -445,7 +444,7 @@ class Vits(BaseTTS): g=g.detach() if self.args.detach_dp_input and g is not None else g, ) loss_duration = torch.sum((log_durations - attn_log_durations) ** 2, [1, 2]) / torch.sum(x_mask) - outputs["loss_duration"] = loss_duration + outputs["loss_duration"] = loss_duration # expand prior m_p = torch.einsum("klmn, kjm -> kjn", [attn, m_p]) @@ -563,8 +562,9 @@ class Vits(BaseTTS): outputs["waveform_seg"] = wav_seg # compute discriminator scores and features - outputs["scores_disc_fake"], outputs["feats_disc_fake"] = self.disc(outputs["model_outputs"]) - _, outputs["feats_disc_real"] = self.disc(wav_seg) + outputs["scores_disc_fake"], outputs["feats_disc_fake"], _, outputs["feats_disc_real"] = self.disc( + outputs["model_outputs"], wav_seg + ) # compute losses with autocast(enabled=False): # use float32 for the criterion @@ -579,23 +579,17 @@ class Vits(BaseTTS): scores_disc_fake=outputs["scores_disc_fake"], feats_disc_fake=outputs["feats_disc_fake"], feats_disc_real=outputs["feats_disc_real"], + loss_duration=outputs["loss_duration"], ) - # handle the duration loss - if self.args.use_sdp: - loss_dict["nll_duration"] = outputs["nll_duration"] - loss_dict["loss"] += outputs["nll_duration"] - else: - loss_dict["loss_duration"] = outputs["loss_duration"] - loss_dict["loss"] += outputs["nll_duration"] - elif optimizer_idx == 1: # discriminator pass outputs = {} # compute scores and features - outputs["scores_disc_fake"], outputs["feats_disc_fake"] = self.disc(self.y_disc_cache.detach()) - outputs["scores_disc_real"], outputs["feats_disc_real"] = self.disc(self.wav_seg_disc_cache) + outputs["scores_disc_fake"], _, outputs["scores_disc_real"], _ = self.disc( + self.y_disc_cache.detach(), self.wav_seg_disc_cache + ) # compute loss with autocast(enabled=False): # use float32 for the criterion @@ -686,14 +680,21 @@ class Vits(BaseTTS): Returns: List: optimizers. """ - self.disc.requires_grad_(False) - gen_parameters = filter(lambda p: p.requires_grad, self.parameters()) - self.disc.requires_grad_(True) - optimizer1 = get_optimizer( + gen_parameters = chain( + self.text_encoder.parameters(), + self.posterior_encoder.parameters(), + self.flow.parameters(), + self.duration_predictor.parameters(), + self.waveform_decoder.parameters(), + ) + # add the speaker embedding layer + if hasattr(self, "emb_g"): + gen_parameters = chain(gen_parameters, self.emb_g) + optimizer0 = get_optimizer( self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, parameters=gen_parameters ) - optimizer2 = get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr_disc, self.disc) - return [optimizer1, optimizer2] + optimizer1 = get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr_disc, self.disc) + return [optimizer0, optimizer1] def get_lr(self) -> List: """Set the initial learning rates for each optimizer. @@ -712,9 +713,9 @@ class Vits(BaseTTS): Returns: List: Schedulers, one for each optimizer. """ - scheduler1 = get_scheduler(self.config.lr_scheduler_gen, self.config.lr_scheduler_gen_params, optimizer[0]) - scheduler2 = get_scheduler(self.config.lr_scheduler_disc, self.config.lr_scheduler_disc_params, optimizer[1]) - return [scheduler1, scheduler2] + scheduler0 = get_scheduler(self.config.lr_scheduler_gen, self.config.lr_scheduler_gen_params, optimizer[0]) + scheduler1 = get_scheduler(self.config.lr_scheduler_disc, self.config.lr_scheduler_disc_params, optimizer[1]) + return [scheduler0, scheduler1] def get_criterion(self): """Get criterions for each optimizer. The index in the output list matches the optimizer idx used in diff --git a/TTS/tts/utils/text/__init__.py b/TTS/tts/utils/text/__init__.py index d4345b64..20712f1d 100644 --- a/TTS/tts/utils/text/__init__.py +++ b/TTS/tts/utils/text/__init__.py @@ -225,9 +225,10 @@ def sequence_to_text(sequence: List, tp: Dict = None, add_blank=False, custom_sy if custom_symbols is not None: _symbols = custom_symbols + _id_to_symbol = {i: s for i, s in enumerate(_symbols)} elif tp: _symbols, _ = make_symbols(**tp) - _id_to_symbol = {i: s for i, s in enumerate(_symbols)} + _id_to_symbol = {i: s for i, s in enumerate(_symbols)} result = "" for symbol_id in sequence: diff --git a/TTS/utils/audio.py b/TTS/utils/audio.py index 40d82365..0a343fbf 100644 --- a/TTS/utils/audio.py +++ b/TTS/utils/audio.py @@ -241,6 +241,7 @@ class AudioProcessor(object): self.sample_rate = sample_rate self.resample = resample self.num_mels = num_mels + self.log_func = log_func self.min_level_db = min_level_db or 0 self.frame_shift_ms = frame_shift_ms self.frame_length_ms = frame_length_ms diff --git a/TTS/utils/distribute.py b/TTS/utils/distribute.py index 1c6b0e1c..a51ef766 100644 --- a/TTS/utils/distribute.py +++ b/TTS/utils/distribute.py @@ -1,8 +1,6 @@ # edited from https://github.com/fastai/imagenet-fast/blob/master/imagenet_nv/distributed.py import torch import torch.distributed as dist -from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors -from torch.autograd import Variable def reduce_tensor(tensor, num_gpus): @@ -20,46 +18,3 @@ def init_distributed(rank, num_gpus, group_name, dist_backend, dist_url): # Initialize distributed communication dist.init_process_group(dist_backend, init_method=dist_url, world_size=num_gpus, rank=rank, group_name=group_name) - - -def apply_gradient_allreduce(module): - - # sync model parameters - for p in module.state_dict().values(): - if not torch.is_tensor(p): - continue - dist.broadcast(p, 0) - - def allreduce_params(): - if module.needs_reduction: - module.needs_reduction = False - # bucketing params based on value types - buckets = {} - for param in module.parameters(): - if param.requires_grad and param.grad is not None: - tp = type(param.data) - if tp not in buckets: - buckets[tp] = [] - buckets[tp].append(param) - for tp in buckets: - bucket = buckets[tp] - grads = [param.grad.data for param in bucket] - coalesced = _flatten_dense_tensors(grads) - dist.all_reduce(coalesced, op=dist.reduce_op.SUM) - coalesced /= dist.get_world_size() - for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): - buf.copy_(synced) - - for param in list(module.parameters()): - - def allreduce_hook(*_): - Variable._execution_engine.queue_callback(allreduce_params) # pylint: disable=protected-access - - if param.requires_grad: - param.register_hook(allreduce_hook) - - def set_needs_reduction(self, *_): - self.needs_reduction = True - - module.register_forward_hook(set_needs_reduction) - return module diff --git a/TTS/utils/generic_utils.py b/TTS/utils/generic_utils.py index 287143e5..6504cca6 100644 --- a/TTS/utils/generic_utils.py +++ b/TTS/utils/generic_utils.py @@ -53,7 +53,6 @@ def get_commit_hash(): # Not copying .git folder into docker container except (subprocess.CalledProcessError, FileNotFoundError): commit = "0000000" - print(" > Git Hash: {}".format(commit)) return commit @@ -62,7 +61,6 @@ def get_experiment_folder_path(root_path, model_name): date_str = datetime.datetime.now().strftime("%B-%d-%Y_%I+%M%p") commit_hash = get_commit_hash() output_folder = os.path.join(root_path, model_name + "-" + date_str + "-" + commit_hash) - print(" > Experiment folder: {}".format(output_folder)) return output_folder diff --git a/TTS/utils/io.py b/TTS/utils/io.py index f634b023..dd4ffd60 100644 --- a/TTS/utils/io.py +++ b/TTS/utils/io.py @@ -3,7 +3,7 @@ import json import os import pickle as pickle_tts import shutil -from typing import Any +from typing import Any, Callable, Dict, Union import fsspec import torch @@ -53,18 +53,23 @@ def copy_model_files(config: Coqpit, out_path, new_fields): shutil.copyfileobj(source_file, target_file) -def load_fsspec(path: str, **kwargs) -> Any: +def load_fsspec( + path: str, + map_location: Union[str, Callable, torch.device, Dict[Union[str, torch.device], Union[str, torch.device]]] = None, + **kwargs, +) -> Any: """Like torch.load but can load from other locations (e.g. s3:// , gs://). Args: path: Any path or url supported by fsspec. + map_location: torch.device or str. **kwargs: Keyword arguments forwarded to torch.load. Returns: Object stored in path. """ with fsspec.open(path, "rb") as f: - return torch.load(f, **kwargs) + return torch.load(f, map_location=map_location, **kwargs) def load_checkpoint(model, checkpoint_path, use_cuda=False, eval=False): # pylint: disable=redefined-builtin diff --git a/TTS/utils/logging/__init__.py b/TTS/utils/logging/__init__.py index 4b92221f..43fbf6f1 100644 --- a/TTS/utils/logging/__init__.py +++ b/TTS/utils/logging/__init__.py @@ -3,7 +3,7 @@ from TTS.utils.logging.tensorboard_logger import TensorboardLogger from TTS.utils.logging.wandb_logger import WandbLogger -def init_logger(config): +def init_dashboard_logger(config): if config.dashboard_logger == "tensorboard": dashboard_logger = TensorboardLogger(config.output_log_path, model_name=config.model) diff --git a/TTS/utils/logging/console_logger.py b/TTS/utils/logging/console_logger.py index c5fbe8b4..0c1aa862 100644 --- a/TTS/utils/logging/console_logger.py +++ b/TTS/utils/logging/console_logger.py @@ -29,11 +29,13 @@ class ConsoleLogger: now = datetime.datetime.now() return now.strftime("%Y-%m-%d %H:%M:%S") - def print_epoch_start(self, epoch, max_epoch): + def print_epoch_start(self, epoch, max_epoch, output_path=None): print( "\n{}{} > EPOCH: {}/{}{}".format(tcolors.UNDERLINE, tcolors.BOLD, epoch, max_epoch, tcolors.ENDC), flush=True, ) + if output_path is not None: + print(f" --> {output_path}") def print_train_start(self): print(f"\n{tcolors.BOLD} > TRAINING ({self.get_time()}) {tcolors.ENDC}") diff --git a/TTS/utils/manage.py b/TTS/utils/manage.py index 1d61d392..4a45fb2d 100644 --- a/TTS/utils/manage.py +++ b/TTS/utils/manage.py @@ -110,7 +110,7 @@ class ModelManager(object): os.makedirs(output_path, exist_ok=True) print(f" > Downloading model to {output_path}") output_stats_path = os.path.join(output_path, "scale_stats.npy") - output_speakers_path = os.path.join(output_path, "speakers.json") + # download files to the output path if self._check_dict_key(model_item, "github_rls_url"): # download from github release @@ -122,22 +122,52 @@ class ModelManager(object): if self._check_dict_key(model_item, "stats_file"): self._download_gdrive_file(model_item["stats_file"], output_stats_path) - # update the scale_path.npy file path in the model config.json - if self._check_dict_key(model_item, "stats_file") or os.path.exists(output_stats_path): - # set scale stats path in config.json - config_path = output_config_path - config = load_config(config_path) - config.audio.stats_path = output_stats_path - config.save_json(config_path) - # update the speakers.json file path in the model config.json to the current path - if os.path.exists(output_speakers_path): - # set scale stats path in config.json - config_path = output_config_path - config = load_config(config_path) - config.d_vector_file = output_speakers_path - config.save_json(config_path) + # update paths in the config.json + self._update_paths(output_path, output_config_path) return output_model_path, output_config_path, model_item + def _update_paths(self, output_path: str, config_path: str) -> None: + """Update paths for certain files in config.json after download. + + Args: + output_path (str): local path the model is downloaded to. + config_path (str): local config.json path. + """ + output_stats_path = os.path.join(output_path, "scale_stats.npy") + output_d_vector_file_path = os.path.join(output_path, "speakers.json") + output_speaker_ids_file_path = os.path.join(output_path, "speaker_ids.json") + + # update the scale_path.npy file path in the model config.json + self._update_path("audio.stats_path", output_stats_path, config_path) + + # update the speakers.json file path in the model config.json to the current path + self._update_path("d_vector_file", output_d_vector_file_path, config_path) + self._update_path("model_args.d_vector_file", output_d_vector_file_path, config_path) + + # update the speaker_ids.json file path in the model config.json to the current path + self._update_path("speakers_file", output_speaker_ids_file_path, config_path) + self._update_path("model_args.speakers_file", output_speaker_ids_file_path, config_path) + + @staticmethod + def _update_path(field_name, new_path, config_path): + """Update the path in the model config.json for the current environment after download""" + if os.path.exists(new_path): + config = load_config(config_path) + field_names = field_name.split(".") + if len(field_names) > 1: + # field name points to a sub-level field + sub_conf = config + for fd in field_names[:-1]: + if fd in sub_conf: + sub_conf = sub_conf[fd] + else: + return + sub_conf[field_names[-1]] = new_path + else: + # field name points to a top-level field + config[field_name] = new_path + config.save_json(config_path) + def _download_gdrive_file(self, gdrive_idx, output): """Download files from GDrive using their file ids""" gdown.download(f"{self.url_prefix}{gdrive_idx}", output=output, quiet=False) diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py index 98711d17..531523a4 100644 --- a/TTS/utils/synthesizer.py +++ b/TTS/utils/synthesizer.py @@ -251,7 +251,7 @@ class Synthesizer(object): d_vector=speaker_embedding, ) waveform = outputs["wav"] - mel_postnet_spec = outputs["outputs"]["model_outputs"][0].detach().numpy() + mel_postnet_spec = outputs["outputs"]["model_outputs"][0].detach().cpu().numpy() if not use_gl: # denormalize tts output based on tts audio config mel_postnet_spec = self.ap.denormalize(mel_postnet_spec.T).T diff --git a/TTS/utils/trainer_utils.py b/TTS/utils/trainer_utils.py index 577f1a8d..005114d1 100644 --- a/TTS/utils/trainer_utils.py +++ b/TTS/utils/trainer_utils.py @@ -1,5 +1,5 @@ import importlib -from typing import Dict, List +from typing import Dict, List, Tuple import torch @@ -10,9 +10,20 @@ def is_apex_available(): return importlib.util.find_spec("apex") is not None -def setup_torch_training_env(cudnn_enable, cudnn_benchmark): +def setup_torch_training_env(cudnn_enable: bool, cudnn_benchmark: bool, use_ddp: bool = False) -> Tuple[bool, int]: + """Setup PyTorch environment for training. + + Args: + cudnn_enable (bool): Enable/disable CUDNN. + cudnn_benchmark (bool): Enable/disable CUDNN benchmarking. Better to set to False if input sequence length is + variable between batches. + use_ddp (bool): DDP flag. True if DDP is enabled, False otherwise. + + Returns: + Tuple[bool, int]: is cuda on or off and number of GPUs in the environment. + """ num_gpus = torch.cuda.device_count() - if num_gpus > 1: + if num_gpus > 1 and not use_ddp: raise RuntimeError( f" [!] {num_gpus} active GPUs. Define the target GPU by `CUDA_VISIBLE_DEVICES`. For multi-gpu training use `TTS/bin/distribute.py`." ) diff --git a/TTS/vocoder/configs/univnet_config.py b/TTS/vocoder/configs/univnet_config.py index 85662831..67f324cf 100644 --- a/TTS/vocoder/configs/univnet_config.py +++ b/TTS/vocoder/configs/univnet_config.py @@ -1,4 +1,5 @@ from dataclasses import dataclass, field +from typing import Dict from TTS.vocoder.configs.shared_configs import BaseGANVocoderConfig @@ -95,7 +96,7 @@ class UnivnetConfig(BaseGANVocoderConfig): # model specific params discriminator_model: str = "univnet_discriminator" generator_model: str = "univnet_generator" - generator_model_params: dict = field( + generator_model_params: Dict = field( default_factory=lambda: { "in_channels": 64, "out_channels": 1, @@ -120,7 +121,7 @@ class UnivnetConfig(BaseGANVocoderConfig): # loss weights - overrides stft_loss_weight: float = 2.5 - stft_loss_params: dict = field( + stft_loss_params: Dict = field( default_factory=lambda: { "n_ffts": [1024, 2048, 512], "hop_lengths": [120, 240, 50], @@ -132,7 +133,7 @@ class UnivnetConfig(BaseGANVocoderConfig): hinge_G_loss_weight: float = 0 feat_match_loss_weight: float = 0 l1_spec_loss_weight: float = 0 - l1_spec_loss_params: dict = field( + l1_spec_loss_params: Dict = field( default_factory=lambda: { "use_mel": True, "sample_rate": 22050, @@ -152,7 +153,7 @@ class UnivnetConfig(BaseGANVocoderConfig): # lr_scheduler_gen_params: dict = field(default_factory=lambda: {"gamma": 0.999, "last_epoch": -1}) lr_scheduler_disc: str = None # one of the schedulers from https:#pytorch.org/docs/stable/optim.html # lr_scheduler_disc_params: dict = field(default_factory=lambda: {"gamma": 0.999, "last_epoch": -1}) - optimizer_params: dict = field(default_factory=lambda: {"betas": [0.5, 0.9], "weight_decay": 0.0}) + optimizer_params: Dict = field(default_factory=lambda: {"betas": [0.5, 0.9], "weight_decay": 0.0}) steps_to_start_discriminator: int = 200000 def __post_init__(self): diff --git a/TTS/vocoder/layers/wavegrad.py b/TTS/vocoder/layers/wavegrad.py index 83cd4233..24b905f9 100644 --- a/TTS/vocoder/layers/wavegrad.py +++ b/TTS/vocoder/layers/wavegrad.py @@ -1,6 +1,6 @@ import torch -import torch.nn as nn import torch.nn.functional as F +from torch import nn from torch.nn.utils import weight_norm diff --git a/TTS/vocoder/models/hifigan_generator.py b/TTS/vocoder/models/hifigan_generator.py index a1e16150..4ce743b3 100644 --- a/TTS/vocoder/models/hifigan_generator.py +++ b/TTS/vocoder/models/hifigan_generator.py @@ -1,8 +1,8 @@ # adopted from https://github.com/jik876/hifi-gan/blob/master/models.py import torch -import torch.nn as nn -import torch.nn.functional as F +from torch import nn from torch.nn import Conv1d, ConvTranspose1d +from torch.nn import functional as F from torch.nn.utils import remove_weight_norm, weight_norm from TTS.utils.io import load_fsspec diff --git a/TTS/vocoder/models/melgan_multiscale_discriminator.py b/TTS/vocoder/models/melgan_multiscale_discriminator.py index 33e0a688..b4909f37 100644 --- a/TTS/vocoder/models/melgan_multiscale_discriminator.py +++ b/TTS/vocoder/models/melgan_multiscale_discriminator.py @@ -40,8 +40,8 @@ class MelganMultiscaleDiscriminator(nn.Module): ) def forward(self, x): - scores = list() - feats = list() + scores = [] + feats = [] for disc in self.discriminators: score, feat = disc(x) scores.append(score) diff --git a/TTS/vocoder/models/univnet_discriminator.py b/TTS/vocoder/models/univnet_discriminator.py index d99b2760..d6b0e5d5 100644 --- a/TTS/vocoder/models/univnet_discriminator.py +++ b/TTS/vocoder/models/univnet_discriminator.py @@ -1,6 +1,6 @@ import torch -import torch.nn as nn import torch.nn.functional as F +from torch import nn from torch.nn.utils import spectral_norm, weight_norm from TTS.utils.audio import TorchSTFT diff --git a/TTS/vocoder/models/wavegrad.py b/TTS/vocoder/models/wavegrad.py index 5dc878d7..8d95a063 100644 --- a/TTS/vocoder/models/wavegrad.py +++ b/TTS/vocoder/models/wavegrad.py @@ -9,12 +9,12 @@ from torch.nn.utils import weight_norm from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler -from TTS.model import BaseModel from TTS.utils.audio import AudioProcessor from TTS.utils.io import load_fsspec from TTS.utils.trainer_utils import get_optimizer, get_scheduler from TTS.vocoder.datasets import WaveGradDataset from TTS.vocoder.layers.wavegrad import Conv1d, DBlock, FiLM, UBlock +from TTS.vocoder.models.base_vocoder import BaseVocoder from TTS.vocoder.utils.generic_utils import plot_results @@ -33,7 +33,7 @@ class WavegradArgs(Coqpit): ) -class Wavegrad(BaseModel): +class Wavegrad(BaseVocoder): """🐸 🌊 WaveGrad 🌊 model. Paper - https://arxiv.org/abs/2009.00713 @@ -257,14 +257,18 @@ class Wavegrad(BaseModel): loss = criterion(noise, noise_hat) return {"model_output": noise_hat}, {"loss": loss} - def train_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]: + def train_log( # pylint: disable=no-self-use + self, ap: AudioProcessor, batch: Dict, outputs: Dict # pylint: disable=unused-argument + ) -> Tuple[Dict, np.ndarray]: return None, None @torch.no_grad() def eval_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]: return self.train_step(batch, criterion) - def eval_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]: + def eval_log( # pylint: disable=no-self-use + self, ap: AudioProcessor, batch: Dict, outputs: Dict # pylint: disable=unused-argument + ) -> Tuple[Dict, np.ndarray]: return None, None def test_run(self, ap: AudioProcessor, samples: List[Dict], ouputs: Dict): # pylint: disable=unused-argument @@ -291,7 +295,8 @@ class Wavegrad(BaseModel): def get_scheduler(self, optimizer): return get_scheduler(self.config.lr_scheduler, self.config.lr_scheduler_params, optimizer) - def get_criterion(self): + @staticmethod + def get_criterion(): return torch.nn.L1Loss() @staticmethod diff --git a/TTS/vocoder/models/wavernn.py b/TTS/vocoder/models/wavernn.py index 8a968019..9b0d6837 100644 --- a/TTS/vocoder/models/wavernn.py +++ b/TTS/vocoder/models/wavernn.py @@ -5,9 +5,9 @@ from typing import Dict, List, Tuple import numpy as np import torch -import torch.nn as nn import torch.nn.functional as F from coqpit import Coqpit +from torch import nn from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler diff --git a/notebooks/dataset_analysis/CheckSpectrograms.ipynb b/notebooks/dataset_analysis/CheckSpectrograms.ipynb index c0cd0aa6..74ca51ab 100644 --- a/notebooks/dataset_analysis/CheckSpectrograms.ipynb +++ b/notebooks/dataset_analysis/CheckSpectrograms.ipynb @@ -3,28 +3,24 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "Collapsed": "false" - }, - "outputs": [], "source": [ "%matplotlib inline\n", "\n", "from TTS.utils.audio import AudioProcessor\n", "from TTS.tts.utils.visual import plot_spectrogram\n", - "from TTS.utils.io import load_config\n", + "from TTS.config import load_config\n", "\n", "import IPython.display as ipd\n", "import glob" - ] + ], + "outputs": [], + "metadata": { + "Collapsed": "false" + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "Collapsed": "false" - }, - "outputs": [], "source": [ "config_path = \"/home/erogol/gdrive/Projects/TTS/recipes/ljspeech/align_tts/config_transformer2.json\"\n", "data_path = \"/home/erogol/gdrive/Datasets/LJSpeech-1.1/\"\n", @@ -39,28 +35,28 @@ "\n", "print(\"File list, by index:\")\n", "dict(enumerate(file_paths))" - ] + ], + "outputs": [], + "metadata": { + "Collapsed": "false" + } }, { "cell_type": "markdown", - "metadata": { - "Collapsed": "false" - }, "source": [ "### Setup Audio Processor\n", "Play with the AP parameters until you find a good fit with the synthesis speech below.\n", "\n", "The default values are loaded from your config.json file, so you only need to\n", "uncomment and modify values below that you'd like to tune." - ] + ], + "metadata": { + "Collapsed": "false" + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "Collapsed": "false" - }, - "outputs": [], "source": [ "tune_params={\n", "# 'audio_processor': 'audio',\n", @@ -95,54 +91,54 @@ "tuned_config.update(tune_params)\n", "\n", "AP = AudioProcessor(**tuned_config);" - ] + ], + "outputs": [], + "metadata": { + "Collapsed": "false" + } }, { "cell_type": "markdown", - "metadata": { - "Collapsed": "false" - }, "source": [ "### Check audio loading " - ] + ], + "metadata": { + "Collapsed": "false" + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "Collapsed": "false" - }, - "outputs": [], "source": [ "wav = AP.load_wav(SAMPLE_FILE_PATH)\n", "ipd.Audio(data=wav, rate=AP.sample_rate) " - ] + ], + "outputs": [], + "metadata": { + "Collapsed": "false" + } }, { "cell_type": "markdown", - "metadata": { - "Collapsed": "false" - }, "source": [ "### Generate Mel-Spectrogram and Re-synthesis with GL" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "AP.power = 1.5" - ] - }, - { - "cell_type": "code", - "execution_count": null, + ], "metadata": { "Collapsed": "false" - }, + } + }, + { + "cell_type": "code", + "execution_count": null, + "source": [ + "AP.power = 1.5" + ], "outputs": [], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": null, "source": [ "mel = AP.melspectrogram(wav)\n", "print(\"Max:\", mel.max())\n", @@ -152,24 +148,24 @@ "\n", "wav_gen = AP.inv_melspectrogram(mel)\n", "ipd.Audio(wav_gen, rate=AP.sample_rate)" - ] + ], + "outputs": [], + "metadata": { + "Collapsed": "false" + } }, { "cell_type": "markdown", - "metadata": { - "Collapsed": "false" - }, "source": [ "### Generate Linear-Spectrogram and Re-synthesis with GL" - ] + ], + "metadata": { + "Collapsed": "false" + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "Collapsed": "false" - }, - "outputs": [], "source": [ "spec = AP.spectrogram(wav)\n", "print(\"Max:\", spec.max())\n", @@ -179,26 +175,26 @@ "\n", "wav_gen = AP.inv_spectrogram(spec)\n", "ipd.Audio(wav_gen, rate=AP.sample_rate)" - ] + ], + "outputs": [], + "metadata": { + "Collapsed": "false" + } }, { "cell_type": "markdown", - "metadata": { - "Collapsed": "false" - }, "source": [ "### Compare values for a certain parameter\n", "\n", "Optimize your parameters by comparing different values per parameter at a time." - ] + ], + "metadata": { + "Collapsed": "false" + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "Collapsed": "false" - }, - "outputs": [], "source": [ "from librosa import display\n", "from matplotlib import pylab as plt\n", @@ -238,36 +234,39 @@ " val = values[idx]\n", " print(\" > {} = {}\".format(attribute, val))\n", " IPython.display.display(IPython.display.Audio(wav_gen, rate=AP.sample_rate))" - ] + ], + "outputs": [], + "metadata": { + "Collapsed": "false" + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "Collapsed": "false" - }, - "outputs": [], "source": [ "compare_values(\"preemphasis\", [0, 0.5, 0.97, 0.98, 0.99])" - ] + ], + "outputs": [], + "metadata": { + "Collapsed": "false" + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "Collapsed": "false" - }, - "outputs": [], "source": [ "compare_values(\"ref_level_db\", [2, 5, 10, 15, 20, 25, 30, 35, 40, 1000])" - ] + ], + "outputs": [], + "metadata": { + "Collapsed": "false" + } } ], "metadata": { "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" + "name": "python3", + "display_name": "Python 3.8.5 64-bit ('torch': conda)" }, "language_info": { "codemirror_mode": { @@ -280,8 +279,11 @@ "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.5" + }, + "interpreter": { + "hash": "27648abe09795c3a768a281b31f7524fcf66a207e733f8ecda3a4e1fd1059fb0" } }, "nbformat": 4, "nbformat_minor": 4 -} +} \ No newline at end of file diff --git a/notebooks/dataset_analysis/analyze.py b/notebooks/dataset_analysis/analyze.py index 6c6bc582..9ba42fb9 100644 --- a/notebooks/dataset_analysis/analyze.py +++ b/notebooks/dataset_analysis/analyze.py @@ -43,7 +43,7 @@ def process_meta_data(path): meta_data = {} # load meta data - with open(path, "r") as f: + with open(path, "r", encoding="utf-8") as f: data = csv.reader(f, delimiter="|") for row in data: frames = int(row[2]) @@ -92,7 +92,7 @@ def save_training(file_path, meta_data): rows.append(d["row"] + "\n") random.shuffle(rows) - with open(file_path, "w+") as f: + with open(file_path, "w+", encoding="utf-8") as f: for row in rows: f.write(row) @@ -156,7 +156,7 @@ def plot_phonemes(train_path, cmu_dict_path, save_path): phonemes = {} - with open(train_path, "r") as f: + with open(train_path, "r", encoding="utf-8") as f: data = csv.reader(f, delimiter="|") phonemes["None"] = 0 for row in data: @@ -174,9 +174,9 @@ def plot_phonemes(train_path, cmu_dict_path, save_path): phonemes["None"] += 1 x, y = [], [] - for key in phonemes: - x.append(key) - y.append(phonemes[key]) + for k, v in phonemes.items(): + x.append(k) + y.append(v) plt.figure() plt.rcParams["figure.figsize"] = (50, 20) diff --git a/recipes/ljspeech/vits_tts/train_vits.py b/recipes/ljspeech/vits_tts/train_vits.py index 45e9d429..7cf52f89 100644 --- a/recipes/ljspeech/vits_tts/train_vits.py +++ b/recipes/ljspeech/vits_tts/train_vits.py @@ -29,7 +29,7 @@ config = VitsConfig( run_name="vits_ljspeech", batch_size=48, eval_batch_size=16, - batch_group_size=0, + batch_group_size=5, num_loader_workers=4, num_eval_loader_workers=4, run_eval=True, @@ -43,7 +43,7 @@ config = VitsConfig( print_step=25, print_eval=True, mixed_precision=True, - max_seq_len=5000, + max_seq_len=500000, output_path=output_path, datasets=[dataset_config], ) diff --git a/requirements.dev.txt b/requirements.dev.txt index afb5ebe6..c995f9e6 100644 --- a/requirements.dev.txt +++ b/requirements.dev.txt @@ -2,4 +2,4 @@ black coverage isort nose -pylint==2.8.3 +pylint==2.10.2 diff --git a/tests/data_tests/test_loader.py b/tests/data_tests/test_loader.py index 10067094..717b2e0f 100644 --- a/tests/data_tests/test_loader.py +++ b/tests/data_tests/test_loader.py @@ -124,7 +124,7 @@ class TestTTSDataset(unittest.TestCase): avg_length = mel_lengths.numpy().mean() assert avg_length >= last_length - dataloader.dataset.sort_items() + dataloader.dataset.sort_and_filter_items() is_items_reordered = False for idx, item in enumerate(dataloader.dataset.items): if item != frames[idx]: