Merge branch 'fix_distribute' into dev

Eren Gölge 2021-08-30 16:31:07 +00:00
commit d16da949a5
47 changed files with 487 additions and 341 deletions

.gitignore (vendored)
View File

@ -155,4 +155,5 @@ deps.json
speakers.json speakers.json
internal/* internal/*
*_pitch.npy *_pitch.npy
*_phoneme.npy *_phoneme.npy
wandb

View File

@ -64,6 +64,11 @@ disable=missing-docstring,
too-many-public-methods, too-many-public-methods,
too-many-lines, too-many-lines,
bare-except, bare-except,
## for avoiding weird p3.6 CI linter error
## TODO: see later if we can remove this
assigning-non-slot,
unsupported-assignment-operation,
## end
line-too-long, line-too-long,
fixme, fixme,
wrong-import-order, wrong-import-order,
@ -73,6 +78,7 @@ disable=missing-docstring,
invalid-name, invalid-name,
too-many-instance-attributes, too-many-instance-attributes,
arguments-differ, arguments-differ,
arguments-renamed,
no-name-in-module, no-name-in-module,
no-member, no-member,
unsubscriptable-object, unsubscriptable-object,

View File

@ -102,7 +102,7 @@ You can also help us implement more models.
## Install TTS ## Install TTS
🐸TTS is tested on Ubuntu 18.04 with **python >= 3.6, < 3.9**. 🐸TTS is tested on Ubuntu 18.04 with **python >= 3.6, < 3.9**.
If you are only interested in [synthesizing speech](https://github.com/coqui-ai/TTS/tree/dev#example-synthesizing-speech-on-terminal-using-the-released-models) with the released 🐸TTS models, installing from PyPI is the easiest option. If you are only interested in [synthesizing speech](https://tts.readthedocs.io/en/latest/inference.html) with the released 🐸TTS models, installing from PyPI is the easiest option.
```bash ```bash
pip install TTS pip install TTS

View File

@ -1,6 +1,6 @@
import os import os
with open(os.path.join(os.path.dirname(__file__), "VERSION")) as f: with open(os.path.join(os.path.dirname(__file__), "VERSION"), "r", encoding="utf-8") as f:
version = f.read().strip() version = f.read().strip()
__version__ = version __version__ = version

View File

@ -97,7 +97,7 @@ Example run:
enable_eos_bos=C.enable_eos_bos_chars, enable_eos_bos=C.enable_eos_bos_chars,
) )
dataset.sort_items() dataset.sort_and_filter_items(C.get("sort_by_audio_len", default=False))
loader = DataLoader( loader = DataLoader(
dataset, dataset,
batch_size=args.batch_size, batch_size=args.batch_size,
@ -158,7 +158,7 @@ Example run:
# ourput metafile # ourput metafile
metafile = os.path.join(args.data_path, "metadata_attn_mask.txt") metafile = os.path.join(args.data_path, "metadata_attn_mask.txt")
with open(metafile, "w") as f: with open(metafile, "w", encoding="utf-8") as f:
for p in file_paths: for p in file_paths:
f.write(f"{p[0]}|{p[1]}\n") f.write(f"{p[0]}|{p[1]}\n")
print(f" >> Metafile created: {metafile}") print(f" >> Metafile created: {metafile}")

View File

@ -32,6 +32,7 @@ def main():
command.append("--restore_path={}".format(args.restore_path)) command.append("--restore_path={}".format(args.restore_path))
command.append("--config_path={}".format(args.config_path)) command.append("--config_path={}".format(args.config_path))
command.append("--group_id=group_{}".format(group_id)) command.append("--group_id=group_{}".format(group_id))
command.append("--use_ddp=true")
command += unargs command += unargs
command.append("") command.append("")
@ -42,7 +43,7 @@ def main():
my_env["PYTHON_EGG_CACHE"] = "/tmp/tmp{}".format(i) my_env["PYTHON_EGG_CACHE"] = "/tmp/tmp{}".format(i)
command[-1] = "--rank={}".format(i) command[-1] = "--rank={}".format(i)
# prevent stdout for processes with rank != 0 # prevent stdout for processes with rank != 0
stdout = None if i == 0 else open(os.devnull, "w") stdout = None
p = subprocess.Popen(["python3"] + command, stdout=stdout, env=my_env) # pylint: disable=consider-using-with p = subprocess.Popen(["python3"] + command, stdout=stdout, env=my_env) # pylint: disable=consider-using-with
processes.append(p) processes.append(p)
print(command) print(command)
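With this change every spawned worker keeps its stdout (so its per-rank log file stays complete) and is told explicitly to run in DDP mode via `--use_ddp=true`. A minimal sketch of this one-process-per-GPU launch pattern; the script name and extra flags are illustrative assumptions, not taken from the diff:

```python
# Hedged sketch of a per-GPU launcher in the spirit of distribute.py.
import os
import subprocess
import torch

def launch(script="train.py", extra_args=None):
    num_gpus = torch.cuda.device_count()
    group_id = f"group_{os.getpid()}"
    processes = []
    for rank in range(num_gpus):
        env = os.environ.copy()
        cmd = ["python3", script, f"--group_id={group_id}",
               "--use_ddp=true", f"--rank={rank}"] + (extra_args or [])
        # every rank keeps stdout so its trainer_<rank>_log.txt stays complete
        processes.append(subprocess.Popen(cmd, env=env))
    for p in processes:
        p.wait()

if __name__ == "__main__":
    launch()
```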

View File

@ -46,7 +46,7 @@ def setup_loader(ap, r, verbose=False):
if c.use_phonemes and c.compute_input_seq_cache: if c.use_phonemes and c.compute_input_seq_cache:
# precompute phonemes to have a better estimate of sequence lengths. # precompute phonemes to have a better estimate of sequence lengths.
dataset.compute_input_seq(c.num_loader_workers) dataset.compute_input_seq(c.num_loader_workers)
dataset.sort_items() dataset.sort_and_filter_items(c.get("sort_by_audio_len", default=False))
loader = DataLoader( loader = DataLoader(
dataset, dataset,
@ -215,7 +215,7 @@ def extract_spectrograms(
wav = ap.inv_melspectrogram(mel) wav = ap.inv_melspectrogram(mel)
ap.save_wav(wav, wav_gl_path) ap.save_wav(wav, wav_gl_path)
with open(os.path.join(output_path, metada_name), "w") as f: with open(os.path.join(output_path, metada_name), "w", encoding="utf-8") as f:
for data in export_metadata: for data in export_metadata:
f.write(f"{data[0]}|{data[1]+'.npy'}\n") f.write(f"{data[0]}|{data[1]+'.npy'}\n")

View File

@ -190,7 +190,7 @@ class BaseTrainingConfig(Coqpit):
Name of the model that is used in the training. Name of the model that is used in the training.
run_name (str): run_name (str):
Name of the experiment. This prefixes the output folder name. Name of the experiment. This prefixes the output folder name. Defaults to `coqui_tts`.
run_description (str): run_description (str):
Short description of the experiment. Short description of the experiment.
@ -272,7 +272,7 @@ class BaseTrainingConfig(Coqpit):
""" """
model: str = None model: str = None
run_name: str = "" run_name: str = "coqui_tts"
run_description: str = "" run_description: str = ""
# training params # training params
epochs: int = 10000 epochs: int = 10000

View File

@ -23,35 +23,31 @@ class BaseModel(nn.Module, ABC):
""" """
@abstractmethod @abstractmethod
def forward(self, text: torch.Tensor, aux_input={}, **kwargs) -> Dict: def forward(self, input: torch.Tensor, *args, aux_input={}, **kwargs) -> Dict:
"""Forward pass for the model mainly used in training. """Forward pass for the model mainly used in training.
You can be flexible here and use different number of arguments and argument names since it is mostly used by You can be flexible here and use different number of arguments and argument names since it is intended to be
`train_step()` in training whitout exposing it to the out of the class. used by `train_step()` without exposing it out of the model.
Args: Args:
text (torch.Tensor): Input text character sequence ids. input (torch.Tensor): Input tensor.
aux_input (Dict): Auxiliary model inputs like embeddings, durations or any other sorts of inputs. aux_input (Dict): Auxiliary model inputs like embeddings, durations or any other sorts of inputs.
for the model.
Returns: Returns:
Dict: model outputs. This must include an item keyed `model_outputs` as the final artifact of the model. Dict: Model outputs. Main model output must be named as "model_outputs".
""" """
outputs_dict = {"model_outputs": None} outputs_dict = {"model_outputs": None}
... ...
return outputs_dict return outputs_dict
@abstractmethod @abstractmethod
def inference(self, text: torch.Tensor, aux_input={}) -> Dict: def inference(self, input: torch.Tensor, aux_input={}) -> Dict:
"""Forward pass for inference. """Forward pass for inference.
After the model is trained this is the only function that connects the model the out world.
This function must only take a `text` input and a dictionary that has all the other model specific inputs.
We don't use `*kwargs` since it is problematic with the TorchScript API. We don't use `*kwargs` since it is problematic with the TorchScript API.
Args: Args:
text (torch.Tensor): [description] input (torch.Tensor): [description]
aux_input (Dict): Auxiliary inputs like speaker embeddings, durations etc. aux_input (Dict): Auxiliary inputs like speaker embeddings, durations etc.
Returns: Returns:
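The renamed signatures above define the abstract contract only; a toy model satisfying it could look like the following (purely an illustration, not part of the diff):

```python
# Hedged sketch of a concrete model implementing the abstract API above.
from typing import Dict

import torch
from torch import nn

class ToyModel(nn.Module):
    def __init__(self, in_dim: int = 80, out_dim: int = 80):
        super().__init__()
        self.layer = nn.Linear(in_dim, out_dim)

    def forward(self, input: torch.Tensor, *args, aux_input={}, **kwargs) -> Dict:
        # training entry point; the returned dict must carry "model_outputs"
        return {"model_outputs": self.layer(input)}

    def inference(self, input: torch.Tensor, aux_input={}) -> Dict:
        # inference entry point; no **kwargs to stay TorchScript friendly
        return {"model_outputs": self.layer(input)}
```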

View File

@ -1,6 +1,6 @@
import torch import torch
import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch import nn
# adapted from https://github.com/cvqluu/GE2E-Loss # adapted from https://github.com/cvqluu/GE2E-Loss

View File

@ -1,6 +1,6 @@
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn from torch import nn
from TTS.utils.io import load_fsspec from TTS.utils.io import load_fsspec

View File

@ -1,5 +1,5 @@
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field
from typing import List from typing import Dict, List
from coqpit import MISSING from coqpit import MISSING
@ -14,7 +14,7 @@ class SpeakerEncoderConfig(BaseTrainingConfig):
audio: BaseAudioConfig = field(default_factory=BaseAudioConfig) audio: BaseAudioConfig = field(default_factory=BaseAudioConfig)
datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()]) datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()])
# model params # model params
model_params: dict = field( model_params: Dict = field(
default_factory=lambda: { default_factory=lambda: {
"model_name": "lstm", "model_name": "lstm",
"input_dim": 80, "input_dim": 80,
@ -25,9 +25,9 @@ class SpeakerEncoderConfig(BaseTrainingConfig):
} }
) )
audio_augmentation: dict = field(default_factory=lambda: {}) audio_augmentation: Dict = field(default_factory=lambda: {})
storage: dict = field( storage: Dict = field(
default_factory=lambda: { default_factory=lambda: {
"sample_from_storage_p": 0.66, # the probability with which we'll sample from the DataSet in-memory storage "sample_from_storage_p": 0.66, # the probability with which we'll sample from the DataSet in-memory storage
"storage_size": 15, # the size of the in-memory storage with respect to a single batch "storage_size": 15, # the size of the in-memory storage with respect to a single batch

View File

@ -94,7 +94,8 @@ def download_and_extract(directory, subset, urls):
extract_path = zip_filepath.strip(".zip") extract_path = zip_filepath.strip(".zip")
# check zip file md5sum # check zip file md5sum
md5 = hashlib.md5(open(zip_filepath, "rb").read()).hexdigest() with open(zip_filepath, "rb") as f_zip:
md5 = hashlib.md5(f_zip.read()).hexdigest()
if md5 != MD5SUM[subset]: if md5 != MD5SUM[subset]:
raise ValueError("md5sum of %s mismatch" % zip_filepath) raise ValueError("md5sum of %s mismatch" % zip_filepath)
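The change above only makes sure the file handle is closed; for very large archives the digest can also be computed in chunks. A small sketch of that variant (an assumed helper, not present in the diff):

```python
# Hedged sketch: chunked md5 so a multi-GB zip never has to fit in memory.
import hashlib

def md5sum(path: str, chunk_size: int = 1 << 20) -> str:
    md5 = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            md5.update(chunk)
    return md5.hexdigest()
```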

View File

@ -1,7 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import importlib import importlib
import logging
import multiprocessing import multiprocessing
import os import os
import platform import platform
@ -16,6 +15,7 @@ from urllib.parse import urlparse
import fsspec import fsspec
import torch import torch
import torch.distributed as dist
from coqpit import Coqpit from coqpit import Coqpit
from torch import nn from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP_th from torch.nn.parallel import DistributedDataParallel as DDP_th
@ -38,7 +38,7 @@ from TTS.utils.generic_utils import (
to_cuda, to_cuda,
) )
from TTS.utils.io import copy_model_files, load_fsspec, save_best_model, save_checkpoint from TTS.utils.io import copy_model_files, load_fsspec, save_best_model, save_checkpoint
from TTS.utils.logging import ConsoleLogger, TensorboardLogger, WandbLogger, init_logger from TTS.utils.logging import ConsoleLogger, TensorboardLogger, WandbLogger, init_dashboard_logger
from TTS.utils.trainer_utils import get_optimizer, get_scheduler, is_apex_available, setup_torch_training_env from TTS.utils.trainer_utils import get_optimizer, get_scheduler, is_apex_available, setup_torch_training_env
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
from TTS.vocoder.models import setup_model as setup_vocoder_model from TTS.vocoder.models import setup_model as setup_vocoder_model
@ -77,12 +77,16 @@ class TrainingArgs(Coqpit):
best_path: str = field( best_path: str = field(
default="", default="",
metadata={ metadata={
"help": "Best model file to be used for extracting best loss. If not specified, the latest best model in continue path is used" "help": "Best model file to be used for extracting the best loss. If not specified, the latest best model in continue path is used"
}, },
) )
config_path: str = field(default="", metadata={"help": "Path to the configuration file."}) config_path: str = field(default="", metadata={"help": "Path to the configuration file."})
rank: int = field(default=0, metadata={"help": "Process rank in distributed training."}) rank: int = field(default=0, metadata={"help": "Process rank in distributed training."})
group_id: str = field(default="", metadata={"help": "Process group id in distributed training."}) group_id: str = field(default="", metadata={"help": "Process group id in distributed training."})
use_ddp: bool = field(
default=False,
metadata={"help": "Use DDP in distributed training. It is to set in `distribute.py`. Do not set manually."},
)
class Trainer: class Trainer:
@ -144,6 +148,7 @@ class Trainer:
>>> trainer.fit() >>> trainer.fit()
TODO: TODO:
- Wrap model for not calling .module in DDP.
- Accumulate gradients b/w batches. - Accumulate gradients b/w batches.
- Deepspeed integration - Deepspeed integration
- Profiler integration. - Profiler integration.
@ -151,29 +156,33 @@ class Trainer:
- TPU training - TPU training
""" """
# set and initialize Pytorch runtime
self.use_cuda, self.num_gpus = setup_torch_training_env(True, cudnn_benchmark)
if config is None: if config is None:
# parse config from console arguments # parse config from console arguments
config, output_path, _, c_logger, dashboard_logger = process_args(args) config, output_path, _, c_logger, dashboard_logger = process_args(args)
self.output_path = output_path
self.args = args self.args = args
self.config = config self.config = config
self.output_path = output_path
self.config.output_log_path = output_path self.config.output_log_path = output_path
# setup logging
log_file = os.path.join(self.output_path, f"trainer_{args.rank}_log.txt")
self._setup_logger_config(log_file)
# set and initialize Pytorch runtime
self.use_cuda, self.num_gpus = setup_torch_training_env(True, cudnn_benchmark, args.use_ddp)
# init loggers # init loggers
self.c_logger = ConsoleLogger() if c_logger is None else c_logger self.c_logger = ConsoleLogger() if c_logger is None else c_logger
self.dashboard_logger = dashboard_logger self.dashboard_logger = dashboard_logger
if self.dashboard_logger is None: # only allow dashboard logging for the main process in DDP mode
self.dashboard_logger = init_logger(config) if self.dashboard_logger is None and args.rank == 0:
self.dashboard_logger = init_dashboard_logger(config)
if not self.config.log_model_step: if not self.config.log_model_step:
self.config.log_model_step = self.config.save_step self.config.log_model_step = self.config.save_step
log_file = os.path.join(self.output_path, f"trainer_{args.rank}_log.txt")
self._setup_logger_config(log_file)
self.total_steps_done = 0 self.total_steps_done = 0
self.epochs_done = 0 self.epochs_done = 0
self.restore_step = 0 self.restore_step = 0
@ -247,10 +256,10 @@ class Trainer:
if self.use_apex: if self.use_apex:
self.scaler = None self.scaler = None
self.model, self.optimizer = amp.initialize(self.model, self.optimizer, opt_level="O1") self.model, self.optimizer = amp.initialize(self.model, self.optimizer, opt_level="O1")
if isinstance(self.optimizer, list): # if isinstance(self.optimizer, list):
self.scaler = [torch.cuda.amp.GradScaler()] * len(self.optimizer) # self.scaler = [torch.cuda.amp.GradScaler()] * len(self.optimizer)
else: # else:
self.scaler = torch.cuda.amp.GradScaler() self.scaler = torch.cuda.amp.GradScaler()
else: else:
self.scaler = None self.scaler = None
@ -319,14 +328,14 @@ class Trainer:
return obj return obj
print(" > Restoring from %s ..." % os.path.basename(restore_path)) print(" > Restoring from %s ..." % os.path.basename(restore_path))
checkpoint = load_fsspec(restore_path) checkpoint = load_fsspec(restore_path, map_location="cpu")
try: try:
print(" > Restoring Model...") print(" > Restoring Model...")
model.load_state_dict(checkpoint["model"]) model.load_state_dict(checkpoint["model"])
print(" > Restoring Optimizer...") print(" > Restoring Optimizer...")
optimizer = _restore_list_objs(checkpoint["optimizer"], optimizer) optimizer = _restore_list_objs(checkpoint["optimizer"], optimizer)
if "scaler" in checkpoint and self.use_amp_scaler and checkpoint["scaler"]: if "scaler" in checkpoint and self.use_amp_scaler and checkpoint["scaler"]:
print(" > Restoring AMP Scaler...") print(" > Restoring Scaler...")
scaler = _restore_list_objs(checkpoint["scaler"], scaler) scaler = _restore_list_objs(checkpoint["scaler"], scaler)
except (KeyError, RuntimeError): except (KeyError, RuntimeError):
print(" > Partial model initialization...") print(" > Partial model initialization...")
@ -346,10 +355,11 @@ class Trainer:
" > Model restored from step %d" % checkpoint["step"], " > Model restored from step %d" % checkpoint["step"],
) )
restore_step = checkpoint["step"] restore_step = checkpoint["step"]
torch.cuda.empty_cache()
return model, optimizer, scaler, restore_step return model, optimizer, scaler, restore_step
@staticmethod
def _get_loader( def _get_loader(
self,
model: nn.Module, model: nn.Module,
config: Coqpit, config: Coqpit,
ap: AudioProcessor, ap: AudioProcessor,
@ -358,8 +368,14 @@ class Trainer:
verbose: bool, verbose: bool,
num_gpus: int, num_gpus: int,
) -> DataLoader: ) -> DataLoader:
if hasattr(model, "get_data_loader"): if num_gpus > 1:
loader = model.get_data_loader(config, ap, is_eval, data_items, verbose, num_gpus) if hasattr(model.module, "get_data_loader"):
loader = model.module.get_data_loader(
config, ap, is_eval, data_items, verbose, num_gpus, self.args.rank
)
else:
if hasattr(model, "get_data_loader"):
loader = model.get_data_loader(config, ap, is_eval, data_items, verbose, num_gpus)
return loader return loader
def get_train_dataloader(self, ap: AudioProcessor, data_items: List, verbose: bool) -> DataLoader: def get_train_dataloader(self, ap: AudioProcessor, data_items: List, verbose: bool) -> DataLoader:
@ -387,12 +403,28 @@ class Trainer:
Returns: Returns:
Dict: Formatted batch. Dict: Formatted batch.
""" """
batch = self.model.format_batch(batch) if self.num_gpus > 1:
batch = self.model.module.format_batch(batch)
else:
batch = self.model.format_batch(batch)
if self.use_cuda: if self.use_cuda:
for k, v in batch.items(): for k, v in batch.items():
batch[k] = to_cuda(v) batch[k] = to_cuda(v)
return batch return batch
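The `if self.num_gpus > 1` branching above (and in `_get_loader`) is the usual way to reach the underlying model inside a `DistributedDataParallel` wrapper. A tiny helper expressing the same idea; the name is illustrative and not part of the trainer:

```python
# Hedged sketch of hiding the DDP ".module" indirection behind one helper.
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP

def unwrap_model(model: nn.Module) -> nn.Module:
    """Return the underlying module whether or not it is wrapped in DDP."""
    return model.module if isinstance(model, DDP) else model

# usage: unwrap_model(self.model).format_batch(batch)
```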
@staticmethod
def master_params(optimizer: torch.optim.Optimizer):
"""Generator over parameters owned by the optimizer.
Used to select parameters used by the optimizer for gradient clipping.
Args:
optimizer: Target optimizer.
"""
for group in optimizer.param_groups:
for p in group["params"]:
yield p
@staticmethod @staticmethod
def _model_train_step( def _model_train_step(
batch: Dict, model: nn.Module, criterion: nn.Module, optimizer_idx: int = None batch: Dict, model: nn.Module, criterion: nn.Module, optimizer_idx: int = None
@ -450,6 +482,8 @@ class Trainer:
step_start_time = time.time() step_start_time = time.time()
# zero-out optimizer # zero-out optimizer
optimizer.zero_grad() optimizer.zero_grad()
# forward pass and loss computation
with torch.cuda.amp.autocast(enabled=config.mixed_precision): with torch.cuda.amp.autocast(enabled=config.mixed_precision):
if optimizer_idx is not None: if optimizer_idx is not None:
outputs, loss_dict = self._model_train_step(batch, model, criterion, optimizer_idx=optimizer_idx) outputs, loss_dict = self._model_train_step(batch, model, criterion, optimizer_idx=optimizer_idx)
@ -461,9 +495,9 @@ class Trainer:
step_time = time.time() - step_start_time step_time = time.time() - step_start_time
return None, {}, step_time return None, {}, step_time
# check nan loss # # check nan loss
if torch.isnan(loss_dict["loss"]).any(): # if torch.isnan(loss_dict["loss"]).any():
raise RuntimeError(f" > Detected NaN loss - {loss_dict}.") # raise RuntimeError(f" > NaN loss detected - {loss_dict}")
# set gradient clipping threshold # set gradient clipping threshold
if "grad_clip" in config and config.grad_clip is not None: if "grad_clip" in config and config.grad_clip is not None:
@ -481,6 +515,8 @@ class Trainer:
update_lr_scheduler = True update_lr_scheduler = True
if self.use_amp_scaler: if self.use_amp_scaler:
if self.use_apex: if self.use_apex:
# TODO: verify AMP use for GAN training in TTS
# https://nvidia.github.io/apex/advanced.html?highlight=accumulate#backward-passes-with-multiple-optimizers
with amp.scale_loss(loss_dict["loss"], optimizer) as scaled_loss: with amp.scale_loss(loss_dict["loss"], optimizer) as scaled_loss:
scaled_loss.backward() scaled_loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_( grad_norm = torch.nn.utils.clip_grad_norm_(
@ -491,7 +527,9 @@ class Trainer:
scaler.scale(loss_dict["loss"]).backward() scaler.scale(loss_dict["loss"]).backward()
if grad_clip > 0: if grad_clip > 0:
scaler.unscale_(optimizer) scaler.unscale_(optimizer)
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip, error_if_nonfinite=False) grad_norm = torch.nn.utils.clip_grad_norm_(
self.master_params(optimizer), grad_clip, error_if_nonfinite=False
)
# pytorch skips the step when the norm is 0. So ignore the norm value when it is NaN # pytorch skips the step when the norm is 0. So ignore the norm value when it is NaN
if torch.isnan(grad_norm) or torch.isinf(grad_norm): if torch.isnan(grad_norm) or torch.isinf(grad_norm):
grad_norm = 0 grad_norm = 0
@ -571,7 +609,8 @@ class Trainer:
total_step_time = 0 total_step_time = 0
for idx, optimizer in enumerate(self.optimizer): for idx, optimizer in enumerate(self.optimizer):
criterion = self.criterion criterion = self.criterion
scaler = self.scaler[idx] if self.use_amp_scaler else None # scaler = self.scaler[idx] if self.use_amp_scaler else None
scaler = self.scaler
scheduler = self.scheduler[idx] scheduler = self.scheduler[idx]
outputs, loss_dict_new, step_time = self._optimize( outputs, loss_dict_new, step_time = self._optimize(
batch, self.model, optimizer, scaler, criterion, scheduler, self.config, idx batch, self.model, optimizer, scaler, criterion, scheduler, self.config, idx
@ -592,13 +631,13 @@ class Trainer:
outputs = outputs_per_optimizer outputs = outputs_per_optimizer
# update avg runtime stats # update avg runtime stats
keep_avg_update = dict() keep_avg_update = {}
keep_avg_update["avg_loader_time"] = loader_time keep_avg_update["avg_loader_time"] = loader_time
keep_avg_update["avg_step_time"] = step_time keep_avg_update["avg_step_time"] = step_time
self.keep_avg_train.update_values(keep_avg_update) self.keep_avg_train.update_values(keep_avg_update)
# update avg loss stats # update avg loss stats
update_eval_values = dict() update_eval_values = {}
for key, value in loss_dict.items(): for key, value in loss_dict.items():
update_eval_values["avg_" + key] = value update_eval_values["avg_" + key] = value
self.keep_avg_train.update_values(update_eval_values) self.keep_avg_train.update_values(update_eval_values)
@ -662,9 +701,10 @@ class Trainer:
if audios is not None: if audios is not None:
self.dashboard_logger.train_audios(self.total_steps_done, audios, self.ap.sample_rate) self.dashboard_logger.train_audios(self.total_steps_done, audios, self.ap.sample_rate)
self.dashboard_logger.flush()
self.total_steps_done += 1 self.total_steps_done += 1
self.callbacks.on_train_step_end() self.callbacks.on_train_step_end()
self.dashboard_logger.flush()
return outputs, loss_dict return outputs, loss_dict
def train_epoch(self) -> None: def train_epoch(self) -> None:
@ -674,16 +714,20 @@ class Trainer:
self.data_train, self.data_train,
verbose=True, verbose=True,
) )
self.model.train() if self.num_gpus > 1:
self.model.module.train()
else:
self.model.train()
epoch_start_time = time.time() epoch_start_time = time.time()
if self.use_cuda: if self.use_cuda:
batch_num_steps = int(len(self.train_loader.dataset) / (self.config.batch_size * self.num_gpus)) batch_num_steps = int(len(self.train_loader.dataset) / (self.config.batch_size * self.num_gpus))
else: else:
batch_num_steps = int(len(self.train_loader.dataset) / self.config.batch_size) batch_num_steps = int(len(self.train_loader.dataset) / self.config.batch_size)
self.c_logger.print_train_start() self.c_logger.print_train_start()
loader_start_time = time.time()
for cur_step, batch in enumerate(self.train_loader): for cur_step, batch in enumerate(self.train_loader):
loader_start_time = time.time()
_, _ = self.train_step(batch, batch_num_steps, cur_step, loader_start_time) _, _ = self.train_step(batch, batch_num_steps, cur_step, loader_start_time)
loader_start_time = time.time()
epoch_time = time.time() - epoch_start_time epoch_time = time.time() - epoch_start_time
# Plot self.epochs_done Stats # Plot self.epochs_done Stats
if self.args.rank == 0: if self.args.rank == 0:
@ -753,7 +797,7 @@ class Trainer:
loss_dict = self._detach_loss_dict(loss_dict) loss_dict = self._detach_loss_dict(loss_dict)
# update avg stats # update avg stats
update_eval_values = dict() update_eval_values = {}
for key, value in loss_dict.items(): for key, value in loss_dict.items():
update_eval_values["avg_" + key] = value update_eval_values["avg_" + key] = value
self.keep_avg_eval.update_values(update_eval_values) self.keep_avg_eval.update_values(update_eval_values)
@ -784,6 +828,7 @@ class Trainer:
loader_time = time.time() - loader_start_time loader_time = time.time() - loader_start_time
self.keep_avg_eval.update_values({"avg_loader_time": loader_time}) self.keep_avg_eval.update_values({"avg_loader_time": loader_time})
outputs, _ = self.eval_step(batch, cur_step) outputs, _ = self.eval_step(batch, cur_step)
loader_start_time = time.time()
# plot epoch stats, artifacts and figures # plot epoch stats, artifacts and figures
if self.args.rank == 0: if self.args.rank == 0:
figures, audios = None, None figures, audios = None, None
@ -800,7 +845,7 @@ class Trainer:
def test_run(self) -> None: def test_run(self) -> None:
"""Run test and log the results. Test run must be defined by the model. """Run test and log the results. Test run must be defined by the model.
Model must return figures and audios to be logged by the Tensorboard.""" Model must return figures and audios to be logged by the Tensorboard."""
if hasattr(self.model, "test_run"): if hasattr(self.model, "test_run") or (self.num_gpus > 1 and hasattr(self.model.module, "test_run")):
if self.eval_loader is None: if self.eval_loader is None:
self.eval_loader = self.get_eval_dataloader( self.eval_loader = self.get_eval_dataloader(
self.ap, self.ap,
@ -810,27 +855,43 @@ class Trainer:
if hasattr(self.eval_loader.dataset, "load_test_samples"): if hasattr(self.eval_loader.dataset, "load_test_samples"):
samples = self.eval_loader.dataset.load_test_samples(1) samples = self.eval_loader.dataset.load_test_samples(1)
figures, audios = self.model.test_run(self.ap, samples, None) if self.num_gpus > 1:
figures, audios = self.model.module.test_run(self.ap, samples, None)
else:
figures, audios = self.model.test_run(self.ap, samples, None)
else: else:
figures, audios = self.model.test_run(self.ap) if self.num_gpus > 1:
figures, audios = self.model.module.test_run(self.ap)
else:
figures, audios = self.model.test_run(self.ap)
self.dashboard_logger.test_audios(self.total_steps_done, audios, self.config.audio["sample_rate"]) self.dashboard_logger.test_audios(self.total_steps_done, audios, self.config.audio["sample_rate"])
self.dashboard_logger.test_figures(self.total_steps_done, figures) self.dashboard_logger.test_figures(self.total_steps_done, figures)
def _fit(self) -> None: def _restore_best_loss(self):
"""🏃 train -> evaluate -> test for the number of epochs.""" """Restore the best loss from the args.best_path if provided else
from the model (`args.restore_path` or `args.continue_path`) used for resuming the training"""
if self.restore_step != 0 or self.args.best_path: if self.restore_step != 0 or self.args.best_path:
print(" > Restoring best loss from " f"{os.path.basename(self.args.best_path)} ...") print(" > Restoring best loss from " f"{os.path.basename(self.args.best_path)} ...")
self.best_loss = load_fsspec(self.args.restore_path, map_location="cpu")["model_loss"] ch = load_fsspec(self.args.restore_path, map_location="cpu")
if "model_loss" in ch:
self.best_loss = ch["model_loss"]
print(f" > Starting with loaded last best loss {self.best_loss}.") print(f" > Starting with loaded last best loss {self.best_loss}.")
def _fit(self) -> None:
"""🏃 train -> evaluate -> test for the number of epochs."""
self._restore_best_loss()
self.total_steps_done = self.restore_step self.total_steps_done = self.restore_step
for epoch in range(0, self.config.epochs): for epoch in range(0, self.config.epochs):
if self.num_gpus > 1:
# let all processes sync up before starting with a new epoch of training
dist.barrier()
self.callbacks.on_epoch_start() self.callbacks.on_epoch_start()
self.keep_avg_train = KeepAverage() self.keep_avg_train = KeepAverage()
self.keep_avg_eval = KeepAverage() if self.config.run_eval else None self.keep_avg_eval = KeepAverage() if self.config.run_eval else None
self.epochs_done = epoch self.epochs_done = epoch
self.c_logger.print_epoch_start(epoch, self.config.epochs) self.c_logger.print_epoch_start(epoch, self.config.epochs, self.output_path)
self.train_epoch() self.train_epoch()
if self.config.run_eval: if self.config.run_eval:
self.eval_epoch() self.eval_epoch()
@ -839,20 +900,26 @@ class Trainer:
self.c_logger.print_epoch_end( self.c_logger.print_epoch_end(
epoch, self.keep_avg_eval.avg_values if self.config.run_eval else self.keep_avg_train.avg_values epoch, self.keep_avg_eval.avg_values if self.config.run_eval else self.keep_avg_train.avg_values
) )
self.save_best_model() if self.args.rank in [None, 0]:
self.save_best_model()
self.callbacks.on_epoch_end() self.callbacks.on_epoch_end()
def fit(self) -> None: def fit(self) -> None:
"""Where the ✨magic✨ happens...""" """Where the ✨magic✨ happens..."""
try: try:
self._fit() self._fit()
self.dashboard_logger.finish() if self.args.rank == 0:
self.dashboard_logger.finish()
except KeyboardInterrupt: except KeyboardInterrupt:
self.callbacks.on_keyboard_interrupt() self.callbacks.on_keyboard_interrupt()
# if the output folder is empty remove the run. # if the output folder is empty remove the run.
remove_experiment_folder(self.output_path) remove_experiment_folder(self.output_path)
# clear the DDP processes
if self.num_gpus > 1:
dist.destroy_process_group()
# finish the wandb run and sync data # finish the wandb run and sync data
self.dashboard_logger.finish() if self.args.rank == 0:
self.dashboard_logger.finish()
# stop without error signal # stop without error signal
try: try:
sys.exit(0) sys.exit(0)
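The per-epoch `dist.barrier()`, the rank-0-only dashboard logging and the final `dist.destroy_process_group()` follow the standard DDP lifecycle; compressed into one sketch (process-group setup details are assumed, they are not part of this diff):

```python
# Hedged sketch of the DDP lifecycle the trainer relies on.
import torch.distributed as dist

def run_ddp_training(rank: int, world_size: int, train_one_epoch, num_epochs: int):
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    try:
        for epoch in range(num_epochs):
            dist.barrier()  # let all ranks sync up before a new epoch
            train_one_epoch()
            if rank == 0:
                print(f" > epoch {epoch} done")  # rank 0 only: logs, checkpoints
    finally:
        dist.destroy_process_group()  # clear the DDP processes
```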
@ -902,18 +969,30 @@ class Trainer:
keep_after=self.config.keep_after, keep_after=self.config.keep_after,
) )
@staticmethod def _setup_logger_config(self, log_file: str) -> None:
def _setup_logger_config(log_file: str) -> None: """Write log strings to a file and print logs to the terminal.
handlers = [logging.StreamHandler()] TODO: Causes formatting issues in pdb debugging."""
# Only add a log file if the output location is local due to poor class Logger(object):
# support for writing logs to file-like objects. def __init__(self, print_to_terminal=True):
parsed_url = urlparse(log_file) self.print_to_terminal = print_to_terminal
if not parsed_url.scheme or parsed_url.scheme == "file": self.terminal = sys.stdout
schemeless_path = os.path.join(parsed_url.netloc, parsed_url.path) self.log_file = log_file
handlers.append(logging.FileHandler(schemeless_path))
logging.basicConfig(level=logging.INFO, format="", handlers=handlers) def write(self, message):
if self.print_to_terminal:
self.terminal.write(message)
with open(self.log_file, "a", encoding="utf-8") as f:
f.write(message)
def flush(self):
# this flush method is needed for python 3 compatibility.
# this handles the flush command by doing nothing.
# you might want to specify some extra behavior here.
pass
# don't let processes rank > 0 write to the terminal
sys.stdout = Logger(self.args.rank == 0)
@staticmethod @staticmethod
def _is_apex_available() -> bool: def _is_apex_available() -> bool:
@ -1109,8 +1188,6 @@ def process_args(args, config=None):
config = register_config(config_base.model)() config = register_config(config_base.model)()
# override values from command-line args # override values from command-line args
config.parse_known_args(coqpit_overrides, relaxed_parser=True) config.parse_known_args(coqpit_overrides, relaxed_parser=True)
if config.mixed_precision:
print(" > Mixed precision mode is ON")
experiment_path = args.continue_path experiment_path = args.continue_path
if not experiment_path: if not experiment_path:
experiment_path = get_experiment_folder_path(config.output_path, config.run_name) experiment_path = get_experiment_folder_path(config.output_path, config.run_name)
@ -1130,8 +1207,7 @@ def process_args(args, config=None):
used_characters = parse_symbols() used_characters = parse_symbols()
new_fields["characters"] = used_characters new_fields["characters"] = used_characters
copy_model_files(config, experiment_path, new_fields) copy_model_files(config, experiment_path, new_fields)
dashboard_logger = init_dashboard_logger(config)
dashboard_logger = init_logger(config)
c_logger = ConsoleLogger() c_logger = ConsoleLogger()
return config, experiment_path, audio_path, c_logger, dashboard_logger return config, experiment_path, audio_path, c_logger, dashboard_logger

View File

@ -96,7 +96,7 @@ class VitsConfig(BaseTTSConfig):
model_args: VitsArgs = field(default_factory=VitsArgs) model_args: VitsArgs = field(default_factory=VitsArgs)
# optimizer # optimizer
grad_clip: List[float] = field(default_factory=lambda: [5, 5]) grad_clip: List[float] = field(default_factory=lambda: [1000, 1000])
lr_gen: float = 0.0002 lr_gen: float = 0.0002
lr_disc: float = 0.0002 lr_disc: float = 0.0002
lr_scheduler_gen: str = "ExponentialLR" lr_scheduler_gen: str = "ExponentialLR"
@ -113,14 +113,16 @@ class VitsConfig(BaseTTSConfig):
gen_loss_alpha: float = 1.0 gen_loss_alpha: float = 1.0
feat_loss_alpha: float = 1.0 feat_loss_alpha: float = 1.0
mel_loss_alpha: float = 45.0 mel_loss_alpha: float = 45.0
dur_loss_alpha: float = 1.0
# data loader params # data loader params
return_wav: bool = True return_wav: bool = True
compute_linear_spec: bool = True compute_linear_spec: bool = True
# overrides # overrides
min_seq_len: int = 13 sort_by_audio_len: bool = True
max_seq_len: int = 500 min_seq_len: int = 0
max_seq_len: int = 500000
r: int = 1 # DO NOT CHANGE r: int = 1 # DO NOT CHANGE
add_blank: bool = True add_blank: bool = True
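The new defaults above (audio-length sorting, a wide-open length range and a much looser gradient clip) can still be overridden per run. A usage sketch; the import path and the concrete values are assumptions:

```python
# Hedged usage sketch; import location and values are illustrative only.
from TTS.tts.configs import VitsConfig  # assumed import path

config = VitsConfig(
    sort_by_audio_len=True,      # sort training samples by on-disk audio size
    min_seq_len=0,
    max_seq_len=500_000,
    grad_clip=[1000.0, 1000.0],  # [generator, discriminator]
)
```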

View File

@ -69,7 +69,7 @@ class TTSDataset(Dataset):
batch. Set 0 to disable. Defaults to 0. batch. Set 0 to disable. Defaults to 0.
min_seq_len (int): Minimum input sequence length to be processed min_seq_len (int): Minimum input sequence length to be processed
by the loader. Filter out input sequences that are shorter than this. Some models have a by sort_inputs`. Filter out input sequences that are shorter than this. Some models have a
minimum input length due to its architecture. Defaults to 0. minimum input length due to its architecture. Defaults to 0.
max_seq_len (int): Maximum input sequence length. Filter out input sequences that are longer than this. max_seq_len (int): Maximum input sequence length. Filter out input sequences that are longer than this.
@ -302,10 +302,23 @@ class TTSDataset(Dataset):
for idx, p in enumerate(phonemes): for idx, p in enumerate(phonemes):
self.items[idx][0] = p self.items[idx][0] = p
def sort_items(self): def sort_and_filter_items(self, by_audio_len=False):
r"""Sort instances based on text length in ascending order""" r"""Sort `items` based on text length or audio length in ascending order. Filter out samples out or the length
lengths = np.array([len(ins[0]) for ins in self.items]) range.
Args:
by_audio_len (bool): if True, sort by audio length else by text length.
"""
# compute the target sequence length
if by_audio_len:
lengths = []
for item in self.items:
lengths.append(os.path.getsize(item[1]))
lengths = np.array(lengths)
else:
lengths = np.array([len(ins[0]) for ins in self.items])
# sort items based on the sequence length in ascending order
idxs = np.argsort(lengths) idxs = np.argsort(lengths)
new_items = [] new_items = []
ignored = [] ignored = []
@ -315,7 +328,10 @@ class TTSDataset(Dataset):
ignored.append(idx) ignored.append(idx)
else: else:
new_items.append(self.items[idx]) new_items.append(self.items[idx])
# shuffle batch groups # shuffle batch groups
# create batches with similar length items
# the larger the `batch_group_size`, the higher the length variety in a batch.
if self.batch_group_size > 0: if self.batch_group_size > 0:
for i in range(len(new_items) // self.batch_group_size): for i in range(len(new_items) // self.batch_group_size):
offset = i * self.batch_group_size offset = i * self.batch_group_size
@ -325,6 +341,7 @@ class TTSDataset(Dataset):
new_items[offset:end_offset] = temp_items new_items[offset:end_offset] = temp_items
self.items = new_items self.items = new_items
# logging
if self.verbose: if self.verbose:
print(" | > Max length sequence: {}".format(np.max(lengths))) print(" | > Max length sequence: {}".format(np.max(lengths)))
print(" | > Min length sequence: {}".format(np.min(lengths))) print(" | > Min length sequence: {}".format(np.min(lengths)))

View File

@ -66,7 +66,7 @@ def load_meta_data(datasets, eval_split=True):
def load_attention_mask_meta_data(metafile_path): def load_attention_mask_meta_data(metafile_path):
"""Load meta data file created by compute_attention_masks.py""" """Load meta data file created by compute_attention_masks.py"""
with open(metafile_path, "r") as f: with open(metafile_path, "r", encoding="utf-8") as f:
lines = f.readlines() lines = f.readlines()
meta_data = [] meta_data = []

View File

@ -19,7 +19,7 @@ def tweb(root_path, meta_file):
txt_file = os.path.join(root_path, meta_file) txt_file = os.path.join(root_path, meta_file)
items = [] items = []
speaker_name = "tweb" speaker_name = "tweb"
with open(txt_file, "r") as ttf: with open(txt_file, "r", encoding="utf-8") as ttf:
for line in ttf: for line in ttf:
cols = line.split("\t") cols = line.split("\t")
wav_file = os.path.join(root_path, cols[0] + ".wav") wav_file = os.path.join(root_path, cols[0] + ".wav")
@ -33,7 +33,7 @@ def mozilla(root_path, meta_file):
txt_file = os.path.join(root_path, meta_file) txt_file = os.path.join(root_path, meta_file)
items = [] items = []
speaker_name = "mozilla" speaker_name = "mozilla"
with open(txt_file, "r") as ttf: with open(txt_file, "r", encoding="utf-8") as ttf:
for line in ttf: for line in ttf:
cols = line.split("|") cols = line.split("|")
wav_file = cols[1].strip() wav_file = cols[1].strip()
@ -77,7 +77,7 @@ def mailabs(root_path, meta_files=None):
continue continue
speaker_name = speaker_name_match.group("speaker_name") speaker_name = speaker_name_match.group("speaker_name")
print(" | > {}".format(csv_file)) print(" | > {}".format(csv_file))
with open(txt_file, "r") as ttf: with open(txt_file, "r", encoding="utf-8") as ttf:
for line in ttf: for line in ttf:
cols = line.split("|") cols = line.split("|")
if meta_files is None: if meta_files is None:
@ -102,7 +102,7 @@ def ljspeech(root_path, meta_file):
for line in ttf: for line in ttf:
cols = line.split("|") cols = line.split("|")
wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav") wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
text = cols[1] text = cols[2]
items.append([text, wav_file, speaker_name]) items.append([text, wav_file, speaker_name])
return items return items
@ -116,7 +116,7 @@ def ljspeech_test(root_path, meta_file):
for idx, line in enumerate(ttf): for idx, line in enumerate(ttf):
cols = line.split("|") cols = line.split("|")
wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav") wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
text = cols[1] text = cols[2]
items.append([text, wav_file, f"ljspeech-{idx}"]) items.append([text, wav_file, f"ljspeech-{idx}"])
return items return items
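The `cols[1]` to `cols[2]` switch in both loaders above picks the normalized transcription column of the LJSpeech metadata instead of the raw text. A self-contained formatter in the same shape, written here as an illustration rather than copied from the repository:

```python
# Hedged sketch of an LJSpeech-style formatter using the normalized text column.
import os

def ljspeech_formatter(root_path, meta_file):
    items = []
    txt_file = os.path.join(root_path, meta_file)
    with open(txt_file, "r", encoding="utf-8") as ttf:
        for line in ttf:
            cols = line.split("|")
            wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
            text = cols[2]  # normalized transcription, not the raw cols[1]
            items.append([text, wav_file, "ljspeech"])
    return items
```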
@ -158,7 +158,7 @@ def css10(root_path, meta_file):
txt_file = os.path.join(root_path, meta_file) txt_file = os.path.join(root_path, meta_file)
items = [] items = []
speaker_name = "ljspeech" speaker_name = "ljspeech"
with open(txt_file, "r") as ttf: with open(txt_file, "r", encoding="utf-8") as ttf:
for line in ttf: for line in ttf:
cols = line.split("|") cols = line.split("|")
wav_file = os.path.join(root_path, cols[0]) wav_file = os.path.join(root_path, cols[0])
@ -172,7 +172,7 @@ def nancy(root_path, meta_file):
txt_file = os.path.join(root_path, meta_file) txt_file = os.path.join(root_path, meta_file)
items = [] items = []
speaker_name = "nancy" speaker_name = "nancy"
with open(txt_file, "r") as ttf: with open(txt_file, "r", encoding="utf-8") as ttf:
for line in ttf: for line in ttf:
utt_id = line.split()[1] utt_id = line.split()[1]
text = line[line.find('"') + 1 : line.rfind('"') - 1] text = line[line.find('"') + 1 : line.rfind('"') - 1]
@ -185,7 +185,7 @@ def common_voice(root_path, meta_file):
"""Normalize the common voice meta data file to TTS format.""" """Normalize the common voice meta data file to TTS format."""
txt_file = os.path.join(root_path, meta_file) txt_file = os.path.join(root_path, meta_file)
items = [] items = []
with open(txt_file, "r") as ttf: with open(txt_file, "r", encoding="utf-8") as ttf:
for line in ttf: for line in ttf:
if line.startswith("client_id"): if line.startswith("client_id"):
continue continue
@ -208,7 +208,7 @@ def libri_tts(root_path, meta_files=None):
for meta_file in meta_files: for meta_file in meta_files:
_meta_file = os.path.basename(meta_file).split(".")[0] _meta_file = os.path.basename(meta_file).split(".")[0]
with open(meta_file, "r") as ttf: with open(meta_file, "r", encoding="utf-8") as ttf:
for line in ttf: for line in ttf:
cols = line.split("\t") cols = line.split("\t")
file_name = cols[0] file_name = cols[0]
@ -245,7 +245,7 @@ def brspeech(root_path, meta_file):
"""BRSpeech 3.0 beta""" """BRSpeech 3.0 beta"""
txt_file = os.path.join(root_path, meta_file) txt_file = os.path.join(root_path, meta_file)
items = [] items = []
with open(txt_file, "r") as ttf: with open(txt_file, "r", encoding="utf-8") as ttf:
for line in ttf: for line in ttf:
if line.startswith("wav_filename"): if line.startswith("wav_filename"):
continue continue
@ -268,7 +268,7 @@ def vctk(root_path, meta_files=None, wavs_path="wav48"):
if isinstance(test_speakers, list): # if is list ignore this speakers ids if isinstance(test_speakers, list): # if is list ignore this speakers ids
if speaker_id in test_speakers: if speaker_id in test_speakers:
continue continue
with open(meta_file) as file_text: with open(meta_file, "r", encoding="utf-8") as file_text:
text = file_text.readlines()[0] text = file_text.readlines()[0]
wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + ".wav") wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + ".wav")
items.append([text, wav_file, "VCTK_" + speaker_id]) items.append([text, wav_file, "VCTK_" + speaker_id])
@ -295,7 +295,7 @@ def vctk_slim(root_path, meta_files=None, wavs_path="wav48"):
def mls(root_path, meta_files=None): def mls(root_path, meta_files=None):
"""http://www.openslr.org/94/""" """http://www.openslr.org/94/"""
items = [] items = []
with open(os.path.join(root_path, meta_files), "r") as meta: with open(os.path.join(root_path, meta_files), "r", encoding="utf-8") as meta:
for line in meta: for line in meta:
file, text = line.split("\t") file, text = line.split("\t")
text = text[:-1] text = text[:-1]
@ -329,7 +329,7 @@ def _voxcel_x(root_path, meta_file, voxcel_idx):
# if not exists meta file, crawl recursively for 'wav' files # if not exists meta file, crawl recursively for 'wav' files
if meta_file is not None: if meta_file is not None:
with open(str(meta_file), "r") as f: with open(str(meta_file), "r", encoding="utf-8") as f:
return [x.strip().split("|") for x in f.readlines()] return [x.strip().split("|") for x in f.readlines()]
elif not cache_to.exists(): elif not cache_to.exists():
@ -346,12 +346,12 @@ def _voxcel_x(root_path, meta_file, voxcel_idx):
text = None # VoxCel does not provide transciptions, and they are not needed for training the SE text = None # VoxCel does not provide transciptions, and they are not needed for training the SE
meta_data.append(f"{text}|{path}|voxcel{voxcel_idx}_{speaker_id}\n") meta_data.append(f"{text}|{path}|voxcel{voxcel_idx}_{speaker_id}\n")
cnt += 1 cnt += 1
with open(str(cache_to), "w") as f: with open(str(cache_to), "w", encoding="utf-8") as f:
f.write("".join(meta_data)) f.write("".join(meta_data))
if cnt < expected_count: if cnt < expected_count:
raise ValueError(f"Found too few instances for Voxceleb. Should be around {expected_count}, is: {cnt}") raise ValueError(f"Found too few instances for Voxceleb. Should be around {expected_count}, is: {cnt}")
with open(str(cache_to), "r") as f: with open(str(cache_to), "r", encoding="utf-8") as f:
return [x.strip().split("|") for x in f.readlines()] return [x.strip().split("|") for x in f.readlines()]
@ -367,7 +367,7 @@ def baker(root_path: str, meta_file: str) -> List[List[str]]:
txt_file = os.path.join(root_path, meta_file) txt_file = os.path.join(root_path, meta_file)
items = [] items = []
speaker_name = "baker" speaker_name = "baker"
with open(txt_file, "r") as ttf: with open(txt_file, "r", encoding="utf-8") as ttf:
for line in ttf: for line in ttf:
wav_name, text = line.rstrip("\n").split("|") wav_name, text = line.rstrip("\n").split("|")
wav_path = os.path.join(root_path, "clips_22", wav_name) wav_path = os.path.join(root_path, "clips_22", wav_name)
@ -380,7 +380,7 @@ def kokoro(root_path, meta_file):
txt_file = os.path.join(root_path, meta_file) txt_file = os.path.join(root_path, meta_file)
items = [] items = []
speaker_name = "kokoro" speaker_name = "kokoro"
with open(txt_file, "r") as ttf: with open(txt_file, "r", encoding="utf-8") as ttf:
for line in ttf: for line in ttf:
cols = line.split("|") cols = line.split("|")
wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav") wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")

View File

@ -1,6 +1,6 @@
import torch import torch
import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch import nn
class FFTransformer(nn.Module): class FFTransformer(nn.Module):

View File

@ -524,6 +524,7 @@ class VitsGeneratorLoss(nn.Module):
self.kl_loss_alpha = c.kl_loss_alpha self.kl_loss_alpha = c.kl_loss_alpha
self.gen_loss_alpha = c.gen_loss_alpha self.gen_loss_alpha = c.gen_loss_alpha
self.feat_loss_alpha = c.feat_loss_alpha self.feat_loss_alpha = c.feat_loss_alpha
self.dur_loss_alpha = c.dur_loss_alpha
self.mel_loss_alpha = c.mel_loss_alpha self.mel_loss_alpha = c.mel_loss_alpha
self.stft = TorchSTFT( self.stft = TorchSTFT(
c.audio.fft_size, c.audio.fft_size,
@ -590,10 +591,11 @@ class VitsGeneratorLoss(nn.Module):
scores_disc_fake, scores_disc_fake,
feats_disc_fake, feats_disc_fake,
feats_disc_real, feats_disc_real,
loss_duration,
): ):
""" """
Shapes: Shapes:
- wavefrom: :math:`[B, 1, T]` - waveform : :math:`[B, 1, T]`
- waveform_hat: :math:`[B, 1, T]` - waveform_hat: :math:`[B, 1, T]`
- z_p: :math:`[B, C, T]` - z_p: :math:`[B, C, T]`
- logs_q: :math:`[B, C, T]` - logs_q: :math:`[B, C, T]`
@ -615,12 +617,14 @@ class VitsGeneratorLoss(nn.Module):
loss_gen = self.generator_loss(scores_disc_fake)[0] * self.gen_loss_alpha loss_gen = self.generator_loss(scores_disc_fake)[0] * self.gen_loss_alpha
loss_kl = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask.unsqueeze(1)) * self.kl_loss_alpha loss_kl = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask.unsqueeze(1)) * self.kl_loss_alpha
loss_mel = torch.nn.functional.l1_loss(mel, mel_hat) * self.mel_loss_alpha loss_mel = torch.nn.functional.l1_loss(mel, mel_hat) * self.mel_loss_alpha
loss = loss_kl + loss_feat + loss_mel + loss_gen loss_duration = torch.sum(loss_duration.float()) * self.dur_loss_alpha
loss = loss_kl + loss_feat + loss_mel + loss_gen + loss_duration
# pass losses to the dict # pass losses to the dict
return_dict["loss_gen"] = loss_gen return_dict["loss_gen"] = loss_gen
return_dict["loss_kl"] = loss_kl return_dict["loss_kl"] = loss_kl
return_dict["loss_feat"] = loss_feat return_dict["loss_feat"] = loss_feat
return_dict["loss_mel"] = loss_mel return_dict["loss_mel"] = loss_mel
return_dict["loss_duration"] = loss_duration
return_dict["loss"] = loss return_dict["loss"] = loss
return return_dict return return_dict
@ -651,7 +655,6 @@ class VitsDiscriminatorLoss(nn.Module):
return_dict = {} return_dict = {}
loss_disc, _, _ = self.discriminator_loss(scores_disc_real, scores_disc_fake) loss_disc, _, _ = self.discriminator_loss(scores_disc_real, scores_disc_fake)
return_dict["loss_disc"] = loss_disc * self.disc_loss_alpha return_dict["loss_disc"] = loss_disc * self.disc_loss_alpha
loss = loss + loss_disc loss = loss + return_dict["loss_disc"]
return_dict["loss_disc"] = loss_disc
return_dict["loss"] = loss return_dict["loss"] = loss
return return_dict return return_dict
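With `dur_loss_alpha` in place, the weighted generator objective is simply the sum of the five terms. A condensed sketch of how they combine, with the individual losses assumed to be computed as in the class above:

```python
# Hedged sketch of assembling the VITS generator loss after this change.
import torch

def total_generator_loss(loss_kl, loss_feat, loss_mel, loss_gen,
                         loss_duration, dur_loss_alpha=1.0):
    loss_duration = torch.sum(loss_duration.float()) * dur_loss_alpha
    return loss_kl + loss_feat + loss_mel + loss_gen + loss_duration
```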

View File

@ -1,6 +1,6 @@
import torch import torch
import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch import nn
class GST(nn.Module): class GST(nn.Module):

View File

@ -388,8 +388,8 @@ class Decoder(nn.Module):
decoder_input = self.project_to_decoder_in(torch.cat((self.attention_rnn_hidden, self.context_vec), -1)) decoder_input = self.project_to_decoder_in(torch.cat((self.attention_rnn_hidden, self.context_vec), -1))
# Pass through the decoder RNNs # Pass through the decoder RNNs
for idx in range(len(self.decoder_rnns)): for idx, decoder_rnn in enumerate(self.decoder_rnns):
self.decoder_rnn_hiddens[idx] = self.decoder_rnns[idx](decoder_input, self.decoder_rnn_hiddens[idx]) self.decoder_rnn_hiddens[idx] = decoder_rnn(decoder_input, self.decoder_rnn_hiddens[idx])
# Residual connection # Residual connection
decoder_input = self.decoder_rnn_hiddens[idx] + decoder_input decoder_input = self.decoder_rnn_hiddens[idx] + decoder_input
decoder_output = decoder_input decoder_output = decoder_input

View File

@ -2,7 +2,7 @@ import torch
from torch import nn from torch import nn
from torch.nn.modules.conv import Conv1d from torch.nn.modules.conv import Conv1d
from TTS.vocoder.models.hifigan_discriminator import MultiPeriodDiscriminator from TTS.vocoder.models.hifigan_discriminator import DiscriminatorP, MultiPeriodDiscriminator
class DiscriminatorS(torch.nn.Module): class DiscriminatorS(torch.nn.Module):
@ -60,18 +60,32 @@ class VitsDiscriminator(nn.Module):
def __init__(self, use_spectral_norm=False): def __init__(self, use_spectral_norm=False):
super().__init__() super().__init__()
self.sd = DiscriminatorS(use_spectral_norm=use_spectral_norm) periods = [2, 3, 5, 7, 11]
self.mpd = MultiPeriodDiscriminator(use_spectral_norm=use_spectral_norm)
def forward(self, x): self.nets = nn.ModuleList()
self.nets.append(DiscriminatorS(use_spectral_norm=use_spectral_norm))
self.nets.extend([DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods])
def forward(self, x, x_hat=None):
""" """
Args: Args:
x (Tensor): input waveform. x (Tensor): ground truth waveform.
x_hat (Tensor): predicted waveform.
Returns: Returns:
List[Tensor]: discriminator scores. List[Tensor]: discriminator scores.
List[List[Tensor]]: list of list of features from each layers of each discriminator. List[List[Tensor]]: list of list of features from each layers of each discriminator.
""" """
scores, feats = self.mpd(x) x_scores = []
score_sd, feats_sd = self.sd(x) x_hat_scores = [] if x_hat is not None else None
return scores + [score_sd], feats + [feats_sd] x_feats = []
x_hat_feats = [] if x_hat is not None else None
for net in self.nets:
x_score, x_feat = net(x)
x_scores.append(x_score)
x_feats.append(x_feat)
if x_hat is not None:
x_hat_score, x_hat_feat = net(x_hat)
x_hat_scores.append(x_hat_score)
x_hat_feats.append(x_hat_feat)
return x_scores, x_feats, x_hat_scores, x_hat_feats

View File

@ -228,7 +228,7 @@ class StochasticDurationPredictor(nn.Module):
h = self.post_pre(dr) h = self.post_pre(dr)
h = self.post_convs(h, x_mask) h = self.post_convs(h, x_mask)
h = self.post_proj(h) * x_mask h = self.post_proj(h) * x_mask
noise = torch.rand(dr.size(0), 2, dr.size(2)).to(device=x.device, dtype=x.dtype) * x_mask noise = torch.randn(dr.size(0), 2, dr.size(2)).to(device=x.device, dtype=x.dtype) * x_mask
z_q = noise z_q = noise
# posterior encoder # posterior encoder

View File

@ -2,8 +2,8 @@ from dataclasses import dataclass, field
from typing import Dict, Tuple from typing import Dict, Tuple
import torch import torch
import torch.nn as nn
from coqpit import Coqpit from coqpit import Coqpit
from torch import nn
from TTS.tts.layers.align_tts.mdn import MDNBlock from TTS.tts.layers.align_tts.mdn import MDNBlock
from TTS.tts.layers.feed_forward.decoder import Decoder from TTS.tts.layers.feed_forward.decoder import Decoder

View File

@ -2,6 +2,7 @@ import os
from typing import Dict, List, Tuple from typing import Dict, List, Tuple
import torch import torch
import torch.distributed as dist
from coqpit import Coqpit from coqpit import Coqpit
from torch import nn from torch import nn
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
@ -164,7 +165,14 @@ class BaseTTS(BaseModel):
} }
def get_data_loader( def get_data_loader(
self, config: Coqpit, ap: AudioProcessor, is_eval: bool, data_items: List, verbose: bool, num_gpus: int self,
config: Coqpit,
ap: AudioProcessor,
is_eval: bool,
data_items: List,
verbose: bool,
num_gpus: int,
rank: int = None,
) -> "DataLoader": ) -> "DataLoader":
if is_eval and not config.run_eval: if is_eval and not config.run_eval:
loader = None loader = None
@ -186,7 +194,7 @@ class BaseTTS(BaseModel):
if hasattr(self, "make_symbols"): if hasattr(self, "make_symbols"):
custom_symbols = self.make_symbols(self.config) custom_symbols = self.make_symbols(self.config)
# init dataloader # init dataset
dataset = TTSDataset( dataset = TTSDataset(
outputs_per_step=config.r if "r" in config else 1, outputs_per_step=config.r if "r" in config else 1,
text_cleaner=config.text_cleaner, text_cleaner=config.text_cleaner,
@ -212,13 +220,15 @@ class BaseTTS(BaseModel):
else None, else None,
) )
if config.use_phonemes and config.compute_input_seq_cache: # pre-compute phonemes
if config.use_phonemes and config.compute_input_seq_cache and rank in [None, 0]:
if hasattr(self, "eval_data_items") and is_eval: if hasattr(self, "eval_data_items") and is_eval:
dataset.items = self.eval_data_items dataset.items = self.eval_data_items
elif hasattr(self, "train_data_items") and not is_eval: elif hasattr(self, "train_data_items") and not is_eval:
dataset.items = self.train_data_items dataset.items = self.train_data_items
else: else:
# precompute phonemes to have a better estimate of sequence lengths. # precompute phonemes for precise estimate of sequence lengths.
# otherwise `dataset.sort_items()` uses raw text lengths
dataset.compute_input_seq(config.num_loader_workers) dataset.compute_input_seq(config.num_loader_workers)
# TODO: find a more efficient solution # TODO: find a more efficient solution
@ -228,9 +238,17 @@ class BaseTTS(BaseModel):
else: else:
self.train_data_items = dataset.items self.train_data_items = dataset.items
dataset.sort_items() # halt DDP processes for the main process to finish computing the phoneme cache
if num_gpus > 1:
dist.barrier()
# sort input sequences from short to long
dataset.sort_and_filter_items(config.get("sort_by_audio_len", default=False))
# sampler for DDP
sampler = DistributedSampler(dataset) if num_gpus > 1 else None sampler = DistributedSampler(dataset) if num_gpus > 1 else None
# init dataloader
loader = DataLoader( loader = DataLoader(
dataset, dataset,
batch_size=config.eval_batch_size if is_eval else config.batch_size, batch_size=config.eval_batch_size if is_eval else config.batch_size,
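The phoneme cache is now built only by rank 0 while the other DDP workers wait at a barrier before the dataset is sorted and wrapped in a `DistributedSampler`. The pattern in isolation (everything except `dist.barrier()` is an illustrative name):

```python
# Hedged sketch of the "rank 0 builds the cache, other ranks wait" pattern.
import torch.distributed as dist

def build_cache_once(rank, num_gpus, compute_cache):
    if rank in (None, 0):
        compute_cache()   # e.g. precompute phonemes for accurate length sorting
    if num_gpus > 1:
        dist.barrier()    # release the other ranks once the cache exists
```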

View File

@ -1,4 +1,6 @@
import math
from dataclasses import dataclass, field from dataclasses import dataclass, field
from itertools import chain
from typing import Dict, List, Tuple from typing import Dict, List, Tuple
import torch import torch
@ -11,8 +13,6 @@ from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path
from TTS.tts.layers.vits.discriminator import VitsDiscriminator from TTS.tts.layers.vits.discriminator import VitsDiscriminator
from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder
from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
# from TTS.tts.layers.vits.sdp import StochasticDurationPredictor
from TTS.tts.models.base_tts import BaseTTS from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.data import sequence_mask from TTS.tts.utils.data import sequence_mask
from TTS.tts.utils.speakers import get_speaker_manager from TTS.tts.utils.speakers import get_speaker_manager
@ -177,7 +177,7 @@ class VitsArgs(Coqpit):
num_layers_text_encoder: int = 6 num_layers_text_encoder: int = 6
kernel_size_text_encoder: int = 3 kernel_size_text_encoder: int = 3
dropout_p_text_encoder: int = 0.1 dropout_p_text_encoder: int = 0.1
dropout_p_duration_predictor: int = 0.1 dropout_p_duration_predictor: int = 0.5
kernel_size_posterior_encoder: int = 5 kernel_size_posterior_encoder: int = 5
dilation_rate_posterior_encoder: int = 1 dilation_rate_posterior_encoder: int = 1
num_layers_posterior_encoder: int = 16 num_layers_posterior_encoder: int = 16
@ -195,7 +195,7 @@ class VitsArgs(Coqpit):
inference_noise_scale: float = 0.667 inference_noise_scale: float = 0.667
length_scale: int = 1 length_scale: int = 1
noise_scale_dp: float = 1.0 noise_scale_dp: float = 1.0
inference_noise_scale_dp: float = 0.8 inference_noise_scale_dp: float = 1.0
max_inference_len: int = None max_inference_len: int = None
init_discriminator: bool = True init_discriminator: bool = True
use_spectral_norm_disriminator: bool = False use_spectral_norm_disriminator: bool = False
@ -419,24 +419,23 @@ class Vits(BaseTTS):
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
with torch.no_grad(): with torch.no_grad():
o_scale = torch.exp(-2 * logs_p) o_scale = torch.exp(-2 * logs_p)
# logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1]).unsqueeze(-1) # [b, t, 1] logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1]).unsqueeze(-1) # [b, t, 1]
logp2 = torch.einsum("klm, kln -> kmn", [o_scale, -0.5 * (z_p ** 2)]) logp2 = torch.einsum("klm, kln -> kmn", [o_scale, -0.5 * (z_p ** 2)])
logp3 = torch.einsum("klm, kln -> kmn", [m_p * o_scale, z_p]) logp3 = torch.einsum("klm, kln -> kmn", [m_p * o_scale, z_p])
# logp4 = torch.sum(-0.5 * (m_p ** 2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1] logp4 = torch.sum(-0.5 * (m_p ** 2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1]
logp = logp2 + logp3 logp = logp2 + logp3 + logp1 + logp4
attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach() attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()
# duration predictor # duration predictor
attn_durations = attn.sum(3) attn_durations = attn.sum(3)
if self.args.use_sdp: if self.args.use_sdp:
nll_duration = self.duration_predictor( loss_duration = self.duration_predictor(
x.detach() if self.args.detach_dp_input else x, x.detach() if self.args.detach_dp_input else x,
x_mask, x_mask,
attn_durations, attn_durations,
g=g.detach() if self.args.detach_dp_input and g is not None else g, g=g.detach() if self.args.detach_dp_input and g is not None else g,
) )
nll_duration = torch.sum(nll_duration.float() / torch.sum(x_mask)) loss_duration = loss_duration / torch.sum(x_mask)
outputs["nll_duration"] = nll_duration
else: else:
attn_log_durations = torch.log(attn_durations + 1e-6) * x_mask attn_log_durations = torch.log(attn_durations + 1e-6) * x_mask
log_durations = self.duration_predictor( log_durations = self.duration_predictor(
@ -445,7 +444,7 @@ class Vits(BaseTTS):
g=g.detach() if self.args.detach_dp_input and g is not None else g, g=g.detach() if self.args.detach_dp_input and g is not None else g,
) )
loss_duration = torch.sum((log_durations - attn_log_durations) ** 2, [1, 2]) / torch.sum(x_mask) loss_duration = torch.sum((log_durations - attn_log_durations) ** 2, [1, 2]) / torch.sum(x_mask)
outputs["loss_duration"] = loss_duration outputs["loss_duration"] = loss_duration
# expand prior # expand prior
m_p = torch.einsum("klmn, kjm -> kjn", [attn, m_p]) m_p = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
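Note on the hunk above: the monotonic alignment search now keeps all four terms of the Gaussian log-likelihood (previously only the two einsum cross-terms were summed), and the stochastic duration predictor output is treated directly as `loss_duration`. A sketch of the full likelihood as computed above, with shapes inferred from the surrounding code (`m_p`, `logs_p`: `[b, c, t_text]`; `z_p`: `[b, c, t_spec]`):

```python
import math
import torch

def mas_log_likelihood(z_p, m_p, logs_p):
    # log N(z_p; m_p, exp(logs_p)) expanded so every (text, spec) frame pair can be
    # scored with two einsums; returns a [b, t_text, t_spec] score map for maximum_path().
    o_scale = torch.exp(-2 * logs_p)                                              # 1 / sigma^2
    logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1]).unsqueeze(-1)   # -log(sqrt(2*pi) * sigma)
    logp2 = torch.einsum("klm, kln -> kmn", [o_scale, -0.5 * (z_p ** 2)])         # -z^2 / (2 * sigma^2)
    logp3 = torch.einsum("klm, kln -> kmn", [m_p * o_scale, z_p])                 #  mu * z / sigma^2
    logp4 = torch.sum(-0.5 * (m_p ** 2) * o_scale, [1]).unsqueeze(-1)             # -mu^2 / (2 * sigma^2)
    return logp1 + logp2 + logp3 + logp4
```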
@ -563,8 +562,9 @@ class Vits(BaseTTS):
outputs["waveform_seg"] = wav_seg outputs["waveform_seg"] = wav_seg
# compute discriminator scores and features # compute discriminator scores and features
outputs["scores_disc_fake"], outputs["feats_disc_fake"] = self.disc(outputs["model_outputs"]) outputs["scores_disc_fake"], outputs["feats_disc_fake"], _, outputs["feats_disc_real"] = self.disc(
_, outputs["feats_disc_real"] = self.disc(wav_seg) outputs["model_outputs"], wav_seg
)
# compute losses # compute losses
with autocast(enabled=False): # use float32 for the criterion with autocast(enabled=False): # use float32 for the criterion
@ -579,23 +579,17 @@ class Vits(BaseTTS):
scores_disc_fake=outputs["scores_disc_fake"], scores_disc_fake=outputs["scores_disc_fake"],
feats_disc_fake=outputs["feats_disc_fake"], feats_disc_fake=outputs["feats_disc_fake"],
feats_disc_real=outputs["feats_disc_real"], feats_disc_real=outputs["feats_disc_real"],
loss_duration=outputs["loss_duration"],
) )
# handle the duration loss
if self.args.use_sdp:
loss_dict["nll_duration"] = outputs["nll_duration"]
loss_dict["loss"] += outputs["nll_duration"]
else:
loss_dict["loss_duration"] = outputs["loss_duration"]
loss_dict["loss"] += outputs["nll_duration"]
elif optimizer_idx == 1: elif optimizer_idx == 1:
# discriminator pass # discriminator pass
outputs = {} outputs = {}
# compute scores and features # compute scores and features
outputs["scores_disc_fake"], outputs["feats_disc_fake"] = self.disc(self.y_disc_cache.detach()) outputs["scores_disc_fake"], _, outputs["scores_disc_real"], _ = self.disc(
outputs["scores_disc_real"], outputs["feats_disc_real"] = self.disc(self.wav_seg_disc_cache) self.y_disc_cache.detach(), self.wav_seg_disc_cache
)
# compute loss # compute loss
with autocast(enabled=False): # use float32 for the criterion with autocast(enabled=False): # use float32 for the criterion
@ -686,14 +680,21 @@ class Vits(BaseTTS):
Returns: Returns:
List: optimizers. List: optimizers.
""" """
self.disc.requires_grad_(False) gen_parameters = chain(
gen_parameters = filter(lambda p: p.requires_grad, self.parameters()) self.text_encoder.parameters(),
self.disc.requires_grad_(True) self.posterior_encoder.parameters(),
optimizer1 = get_optimizer( self.flow.parameters(),
self.duration_predictor.parameters(),
self.waveform_decoder.parameters(),
)
# add the speaker embedding layer
if hasattr(self, "emb_g"):
gen_parameters = chain(gen_parameters, self.emb_g)
optimizer0 = get_optimizer(
self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, parameters=gen_parameters self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, parameters=gen_parameters
) )
optimizer2 = get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr_disc, self.disc) optimizer1 = get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr_disc, self.disc)
return [optimizer1, optimizer2] return [optimizer0, optimizer1]
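Note on `get_optimizer` above: the old approach of toggling `requires_grad` on the discriminator and filtering `self.parameters()` is replaced by an explicit `itertools.chain` over the generator sub-modules, so the two optimizers never share parameters. A sketch of the same split under assumed hyper-parameters (learning rates and betas are illustrative; the sub-module names come from the hunk, and the speaker embedding is added here via `.parameters()`):

```python
from itertools import chain

import torch

def make_vits_optimizers(model, lr_gen=2e-4, lr_disc=2e-4):
    # Generator side: every sub-module except the discriminator, listed explicitly.
    gen_parameters = chain(
        model.text_encoder.parameters(),
        model.posterior_encoder.parameters(),
        model.flow.parameters(),
        model.duration_predictor.parameters(),
        model.waveform_decoder.parameters(),
    )
    if hasattr(model, "emb_g"):  # optional speaker embedding in multi-speaker setups
        gen_parameters = chain(gen_parameters, model.emb_g.parameters())
    optimizer0 = torch.optim.AdamW(gen_parameters, lr=lr_gen, betas=(0.8, 0.99))            # generator
    optimizer1 = torch.optim.AdamW(model.disc.parameters(), lr=lr_disc, betas=(0.8, 0.99))  # discriminator
    return [optimizer0, optimizer1]  # index order matches optimizer_idx in train_step
```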
def get_lr(self) -> List: def get_lr(self) -> List:
"""Set the initial learning rates for each optimizer. """Set the initial learning rates for each optimizer.
@ -712,9 +713,9 @@ class Vits(BaseTTS):
Returns: Returns:
List: Schedulers, one for each optimizer. List: Schedulers, one for each optimizer.
""" """
scheduler1 = get_scheduler(self.config.lr_scheduler_gen, self.config.lr_scheduler_gen_params, optimizer[0]) scheduler0 = get_scheduler(self.config.lr_scheduler_gen, self.config.lr_scheduler_gen_params, optimizer[0])
scheduler2 = get_scheduler(self.config.lr_scheduler_disc, self.config.lr_scheduler_disc_params, optimizer[1]) scheduler1 = get_scheduler(self.config.lr_scheduler_disc, self.config.lr_scheduler_disc_params, optimizer[1])
return [scheduler1, scheduler2] return [scheduler0, scheduler1]
def get_criterion(self): def get_criterion(self):
"""Get criterions for each optimizer. The index in the output list matches the optimizer idx used in """Get criterions for each optimizer. The index in the output list matches the optimizer idx used in

View File

@ -225,9 +225,10 @@ def sequence_to_text(sequence: List, tp: Dict = None, add_blank=False, custom_sy
if custom_symbols is not None: if custom_symbols is not None:
_symbols = custom_symbols _symbols = custom_symbols
_id_to_symbol = {i: s for i, s in enumerate(_symbols)}
elif tp: elif tp:
_symbols, _ = make_symbols(**tp) _symbols, _ = make_symbols(**tp)
_id_to_symbol = {i: s for i, s in enumerate(_symbols)} _id_to_symbol = {i: s for i, s in enumerate(_symbols)}
result = "" result = ""
for symbol_id in sequence: for symbol_id in sequence:

View File

@ -1,8 +1,6 @@
# edited from https://github.com/fastai/imagenet-fast/blob/master/imagenet_nv/distributed.py # edited from https://github.com/fastai/imagenet-fast/blob/master/imagenet_nv/distributed.py
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from torch.autograd import Variable
def reduce_tensor(tensor, num_gpus): def reduce_tensor(tensor, num_gpus):
@ -20,46 +18,3 @@ def init_distributed(rank, num_gpus, group_name, dist_backend, dist_url):
# Initialize distributed communication # Initialize distributed communication
dist.init_process_group(dist_backend, init_method=dist_url, world_size=num_gpus, rank=rank, group_name=group_name) dist.init_process_group(dist_backend, init_method=dist_url, world_size=num_gpus, rank=rank, group_name=group_name)
def apply_gradient_allreduce(module):
# sync model parameters
for p in module.state_dict().values():
if not torch.is_tensor(p):
continue
dist.broadcast(p, 0)
def allreduce_params():
if module.needs_reduction:
module.needs_reduction = False
# bucketing params based on value types
buckets = {}
for param in module.parameters():
if param.requires_grad and param.grad is not None:
tp = type(param.data)
if tp not in buckets:
buckets[tp] = []
buckets[tp].append(param)
for tp in buckets:
bucket = buckets[tp]
grads = [param.grad.data for param in bucket]
coalesced = _flatten_dense_tensors(grads)
dist.all_reduce(coalesced, op=dist.reduce_op.SUM)
coalesced /= dist.get_world_size()
for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
buf.copy_(synced)
for param in list(module.parameters()):
def allreduce_hook(*_):
Variable._execution_engine.queue_callback(allreduce_params) # pylint: disable=protected-access
if param.requires_grad:
param.register_hook(allreduce_hook)
def set_needs_reduction(self, *_):
self.needs_reduction = True
module.register_forward_hook(set_needs_reduction)
return module
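Note on the removal above: the hand-rolled `apply_gradient_allreduce` hooks (manual bucketing with `_flatten_dense_tensors` and explicit `all_reduce`) are dropped in favour of PyTorch's native DDP wrapper, which registers its own gradient hooks. A minimal sketch of the replacement, assuming `init_distributed` has already set up the process group:

```python
import torch
from torch.nn.parallel import DistributedDataParallel as DDP

def wrap_model_for_ddp(model: torch.nn.Module, rank: int) -> torch.nn.Module:
    # DDP buckets and all-reduces gradients internally during backward(),
    # replacing the manual flatten / all_reduce / copy bookkeeping removed above.
    torch.cuda.set_device(rank)
    model = model.to(rank)
    return DDP(model, device_ids=[rank], output_device=rank)
```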

View File

@ -53,7 +53,6 @@ def get_commit_hash():
# Not copying .git folder into docker container # Not copying .git folder into docker container
except (subprocess.CalledProcessError, FileNotFoundError): except (subprocess.CalledProcessError, FileNotFoundError):
commit = "0000000" commit = "0000000"
print(" > Git Hash: {}".format(commit))
return commit return commit
@ -62,7 +61,6 @@ def get_experiment_folder_path(root_path, model_name):
date_str = datetime.datetime.now().strftime("%B-%d-%Y_%I+%M%p") date_str = datetime.datetime.now().strftime("%B-%d-%Y_%I+%M%p")
commit_hash = get_commit_hash() commit_hash = get_commit_hash()
output_folder = os.path.join(root_path, model_name + "-" + date_str + "-" + commit_hash) output_folder = os.path.join(root_path, model_name + "-" + date_str + "-" + commit_hash)
print(" > Experiment folder: {}".format(output_folder))
return output_folder return output_folder

View File

@ -3,7 +3,7 @@ import json
import os import os
import pickle as pickle_tts import pickle as pickle_tts
import shutil import shutil
from typing import Any from typing import Any, Callable, Dict, Union
import fsspec import fsspec
import torch import torch
@ -53,18 +53,23 @@ def copy_model_files(config: Coqpit, out_path, new_fields):
shutil.copyfileobj(source_file, target_file) shutil.copyfileobj(source_file, target_file)
def load_fsspec(path: str, **kwargs) -> Any: def load_fsspec(
path: str,
map_location: Union[str, Callable, torch.device, Dict[Union[str, torch.device], Union[str, torch.device]]] = None,
**kwargs,
) -> Any:
"""Like torch.load but can load from other locations (e.g. s3:// , gs://). """Like torch.load but can load from other locations (e.g. s3:// , gs://).
Args: Args:
path: Any path or url supported by fsspec. path: Any path or url supported by fsspec.
map_location: torch.device or str.
**kwargs: Keyword arguments forwarded to torch.load. **kwargs: Keyword arguments forwarded to torch.load.
Returns: Returns:
Object stored in path. Object stored in path.
""" """
with fsspec.open(path, "rb") as f: with fsspec.open(path, "rb") as f:
return torch.load(f, **kwargs) return torch.load(f, map_location=map_location, **kwargs)
def load_checkpoint(model, checkpoint_path, use_cuda=False, eval=False): # pylint: disable=redefined-builtin def load_checkpoint(model, checkpoint_path, use_cuda=False, eval=False): # pylint: disable=redefined-builtin
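Note on `load_fsspec` above: the new `map_location` argument is forwarded to `torch.load`, so checkpoints stored on remote filesystems can be deserialized onto a chosen device. A short usage sketch (bucket paths are placeholders; the helper shown above lives in `TTS/utils/io.py`):

```python
import torch
from TTS.utils.io import load_fsspec

# Load a remote checkpoint straight onto CPU, e.g. for inspection on a machine without a GPU.
state = load_fsspec("s3://my-bucket/checkpoints/best_model.pth", map_location="cpu")

# Or map it onto a specific device object.
state = load_fsspec("gs://my-bucket/checkpoints/best_model.pth", map_location=torch.device("cuda:0"))
```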

View File

@ -3,7 +3,7 @@ from TTS.utils.logging.tensorboard_logger import TensorboardLogger
from TTS.utils.logging.wandb_logger import WandbLogger from TTS.utils.logging.wandb_logger import WandbLogger
def init_logger(config): def init_dashboard_logger(config):
if config.dashboard_logger == "tensorboard": if config.dashboard_logger == "tensorboard":
dashboard_logger = TensorboardLogger(config.output_log_path, model_name=config.model) dashboard_logger = TensorboardLogger(config.output_log_path, model_name=config.model)

View File

@ -29,11 +29,13 @@ class ConsoleLogger:
now = datetime.datetime.now() now = datetime.datetime.now()
return now.strftime("%Y-%m-%d %H:%M:%S") return now.strftime("%Y-%m-%d %H:%M:%S")
def print_epoch_start(self, epoch, max_epoch): def print_epoch_start(self, epoch, max_epoch, output_path=None):
print( print(
"\n{}{} > EPOCH: {}/{}{}".format(tcolors.UNDERLINE, tcolors.BOLD, epoch, max_epoch, tcolors.ENDC), "\n{}{} > EPOCH: {}/{}{}".format(tcolors.UNDERLINE, tcolors.BOLD, epoch, max_epoch, tcolors.ENDC),
flush=True, flush=True,
) )
if output_path is not None:
print(f" --> {output_path}")
def print_train_start(self): def print_train_start(self):
print(f"\n{tcolors.BOLD} > TRAINING ({self.get_time()}) {tcolors.ENDC}") print(f"\n{tcolors.BOLD} > TRAINING ({self.get_time()}) {tcolors.ENDC}")

View File

@ -110,7 +110,7 @@ class ModelManager(object):
os.makedirs(output_path, exist_ok=True) os.makedirs(output_path, exist_ok=True)
print(f" > Downloading model to {output_path}") print(f" > Downloading model to {output_path}")
output_stats_path = os.path.join(output_path, "scale_stats.npy") output_stats_path = os.path.join(output_path, "scale_stats.npy")
output_speakers_path = os.path.join(output_path, "speakers.json")
# download files to the output path # download files to the output path
if self._check_dict_key(model_item, "github_rls_url"): if self._check_dict_key(model_item, "github_rls_url"):
# download from github release # download from github release
@ -122,22 +122,52 @@ class ModelManager(object):
if self._check_dict_key(model_item, "stats_file"): if self._check_dict_key(model_item, "stats_file"):
self._download_gdrive_file(model_item["stats_file"], output_stats_path) self._download_gdrive_file(model_item["stats_file"], output_stats_path)
# update the scale_path.npy file path in the model config.json # update paths in the config.json
if self._check_dict_key(model_item, "stats_file") or os.path.exists(output_stats_path): self._update_paths(output_path, output_config_path)
# set scale stats path in config.json
config_path = output_config_path
config = load_config(config_path)
config.audio.stats_path = output_stats_path
config.save_json(config_path)
# update the speakers.json file path in the model config.json to the current path
if os.path.exists(output_speakers_path):
# set scale stats path in config.json
config_path = output_config_path
config = load_config(config_path)
config.d_vector_file = output_speakers_path
config.save_json(config_path)
return output_model_path, output_config_path, model_item return output_model_path, output_config_path, model_item
def _update_paths(self, output_path: str, config_path: str) -> None:
"""Update paths for certain files in config.json after download.
Args:
output_path (str): local path the model is downloaded to.
config_path (str): local config.json path.
"""
output_stats_path = os.path.join(output_path, "scale_stats.npy")
output_d_vector_file_path = os.path.join(output_path, "speakers.json")
output_speaker_ids_file_path = os.path.join(output_path, "speaker_ids.json")
# update the scale_path.npy file path in the model config.json
self._update_path("audio.stats_path", output_stats_path, config_path)
# update the speakers.json file path in the model config.json to the current path
self._update_path("d_vector_file", output_d_vector_file_path, config_path)
self._update_path("model_args.d_vector_file", output_d_vector_file_path, config_path)
# update the speaker_ids.json file path in the model config.json to the current path
self._update_path("speakers_file", output_speaker_ids_file_path, config_path)
self._update_path("model_args.speakers_file", output_speaker_ids_file_path, config_path)
@staticmethod
def _update_path(field_name, new_path, config_path):
"""Update the path in the model config.json for the current environment after download"""
if os.path.exists(new_path):
config = load_config(config_path)
field_names = field_name.split(".")
if len(field_names) > 1:
# field name points to a sub-level field
sub_conf = config
for fd in field_names[:-1]:
if fd in sub_conf:
sub_conf = sub_conf[fd]
else:
return
sub_conf[field_names[-1]] = new_path
else:
# field name points to a top-level field
config[field_name] = new_path
config.save_json(config_path)
def _download_gdrive_file(self, gdrive_idx, output): def _download_gdrive_file(self, gdrive_idx, output):
"""Download files from GDrive using their file ids""" """Download files from GDrive using their file ids"""
gdown.download(f"{self.url_prefix}{gdrive_idx}", output=output, quiet=False) gdown.download(f"{self.url_prefix}{gdrive_idx}", output=output, quiet=False)
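Note on `_update_path` above: it walks a dotted field name (`audio.stats_path`, `model_args.d_vector_file`, ...) into the loaded config and rewrites the leaf only when the downloaded file actually exists. A small standalone sketch of the same dotted-path walk on a plain dict:

```python
def update_nested(config: dict, field_name: str, new_path: str) -> None:
    # Walk "a.b.c" down into the config and set the leaf; silently skip missing branches.
    keys = field_name.split(".")
    sub = config
    for key in keys[:-1]:
        if key not in sub:
            return
        sub = sub[key]
    sub[keys[-1]] = new_path

cfg = {"audio": {"stats_path": None}, "model_args": {"d_vector_file": None}}
update_nested(cfg, "model_args.d_vector_file", "/models/vits/speakers.json")
print(cfg["model_args"]["d_vector_file"])  # /models/vits/speakers.json
```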

View File

@ -1,5 +1,5 @@
import importlib import importlib
from typing import Dict, List from typing import Dict, List, Tuple
import torch import torch
@ -10,9 +10,20 @@ def is_apex_available():
return importlib.util.find_spec("apex") is not None return importlib.util.find_spec("apex") is not None
def setup_torch_training_env(cudnn_enable, cudnn_benchmark): def setup_torch_training_env(cudnn_enable: bool, cudnn_benchmark: bool, use_ddp: bool = False) -> Tuple[bool, int]:
"""Setup PyTorch environment for training.
Args:
cudnn_enable (bool): Enable/disable CUDNN.
cudnn_benchmark (bool): Enable/disable CUDNN benchmarking. Better to set to False if input sequence length is
variable between batches.
use_ddp (bool): DDP flag. True if DDP is enabled, False otherwise.
Returns:
Tuple[bool, int]: is cuda on or off and number of GPUs in the environment.
"""
num_gpus = torch.cuda.device_count() num_gpus = torch.cuda.device_count()
if num_gpus > 1: if num_gpus > 1 and not use_ddp:
raise RuntimeError( raise RuntimeError(
f" [!] {num_gpus} active GPUs. Define the target GPU by `CUDA_VISIBLE_DEVICES`. For multi-gpu training use `TTS/bin/distribute.py`." f" [!] {num_gpus} active GPUs. Define the target GPU by `CUDA_VISIBLE_DEVICES`. For multi-gpu training use `TTS/bin/distribute.py`."
) )
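Note on `setup_torch_training_env` above: the new `use_ddp` flag relaxes the multi-GPU guard when training is launched through `TTS/bin/distribute.py`, where every worker legitimately sees all GPUs. A short usage sketch (the import path is an assumption; return values follow the docstring above):

```python
from TTS.utils.trainer_utils import setup_torch_training_env  # module path assumed

# Single-process training: keep the guard so multiple visible GPUs raise early.
use_cuda, num_gpus = setup_torch_training_env(cudnn_enable=True, cudnn_benchmark=False)

# Launched per-rank by the distribute script: skip the guard.
use_cuda, num_gpus = setup_torch_training_env(cudnn_enable=True, cudnn_benchmark=False, use_ddp=True)
```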

View File

@ -1,4 +1,5 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Dict
from TTS.vocoder.configs.shared_configs import BaseGANVocoderConfig from TTS.vocoder.configs.shared_configs import BaseGANVocoderConfig
@ -95,7 +96,7 @@ class UnivnetConfig(BaseGANVocoderConfig):
# model specific params # model specific params
discriminator_model: str = "univnet_discriminator" discriminator_model: str = "univnet_discriminator"
generator_model: str = "univnet_generator" generator_model: str = "univnet_generator"
generator_model_params: dict = field( generator_model_params: Dict = field(
default_factory=lambda: { default_factory=lambda: {
"in_channels": 64, "in_channels": 64,
"out_channels": 1, "out_channels": 1,
@ -120,7 +121,7 @@ class UnivnetConfig(BaseGANVocoderConfig):
# loss weights - overrides # loss weights - overrides
stft_loss_weight: float = 2.5 stft_loss_weight: float = 2.5
stft_loss_params: dict = field( stft_loss_params: Dict = field(
default_factory=lambda: { default_factory=lambda: {
"n_ffts": [1024, 2048, 512], "n_ffts": [1024, 2048, 512],
"hop_lengths": [120, 240, 50], "hop_lengths": [120, 240, 50],
@ -132,7 +133,7 @@ class UnivnetConfig(BaseGANVocoderConfig):
hinge_G_loss_weight: float = 0 hinge_G_loss_weight: float = 0
feat_match_loss_weight: float = 0 feat_match_loss_weight: float = 0
l1_spec_loss_weight: float = 0 l1_spec_loss_weight: float = 0
l1_spec_loss_params: dict = field( l1_spec_loss_params: Dict = field(
default_factory=lambda: { default_factory=lambda: {
"use_mel": True, "use_mel": True,
"sample_rate": 22050, "sample_rate": 22050,
@ -152,7 +153,7 @@ class UnivnetConfig(BaseGANVocoderConfig):
# lr_scheduler_gen_params: dict = field(default_factory=lambda: {"gamma": 0.999, "last_epoch": -1}) # lr_scheduler_gen_params: dict = field(default_factory=lambda: {"gamma": 0.999, "last_epoch": -1})
lr_scheduler_disc: str = None # one of the schedulers from https:#pytorch.org/docs/stable/optim.html lr_scheduler_disc: str = None # one of the schedulers from https:#pytorch.org/docs/stable/optim.html
# lr_scheduler_disc_params: dict = field(default_factory=lambda: {"gamma": 0.999, "last_epoch": -1}) # lr_scheduler_disc_params: dict = field(default_factory=lambda: {"gamma": 0.999, "last_epoch": -1})
optimizer_params: dict = field(default_factory=lambda: {"betas": [0.5, 0.9], "weight_decay": 0.0}) optimizer_params: Dict = field(default_factory=lambda: {"betas": [0.5, 0.9], "weight_decay": 0.0})
steps_to_start_discriminator: int = 200000 steps_to_start_discriminator: int = 200000
def __post_init__(self): def __post_init__(self):

View File

@ -1,6 +1,6 @@
import torch import torch
import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch import nn
from torch.nn.utils import weight_norm from torch.nn.utils import weight_norm

View File

@ -1,8 +1,8 @@
# adopted from https://github.com/jik876/hifi-gan/blob/master/models.py # adopted from https://github.com/jik876/hifi-gan/blob/master/models.py
import torch import torch
import torch.nn as nn from torch import nn
import torch.nn.functional as F
from torch.nn import Conv1d, ConvTranspose1d from torch.nn import Conv1d, ConvTranspose1d
from torch.nn import functional as F
from torch.nn.utils import remove_weight_norm, weight_norm from torch.nn.utils import remove_weight_norm, weight_norm
from TTS.utils.io import load_fsspec from TTS.utils.io import load_fsspec

View File

@ -40,8 +40,8 @@ class MelganMultiscaleDiscriminator(nn.Module):
) )
def forward(self, x): def forward(self, x):
scores = list() scores = []
feats = list() feats = []
for disc in self.discriminators: for disc in self.discriminators:
score, feat = disc(x) score, feat = disc(x)
scores.append(score) scores.append(score)

View File

@ -1,6 +1,6 @@
import torch import torch
import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch import nn
from torch.nn.utils import spectral_norm, weight_norm from torch.nn.utils import spectral_norm, weight_norm
from TTS.utils.audio import TorchSTFT from TTS.utils.audio import TorchSTFT

View File

@ -9,12 +9,12 @@ from torch.nn.utils import weight_norm
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
from TTS.model import BaseModel
from TTS.utils.audio import AudioProcessor from TTS.utils.audio import AudioProcessor
from TTS.utils.io import load_fsspec from TTS.utils.io import load_fsspec
from TTS.utils.trainer_utils import get_optimizer, get_scheduler from TTS.utils.trainer_utils import get_optimizer, get_scheduler
from TTS.vocoder.datasets import WaveGradDataset from TTS.vocoder.datasets import WaveGradDataset
from TTS.vocoder.layers.wavegrad import Conv1d, DBlock, FiLM, UBlock from TTS.vocoder.layers.wavegrad import Conv1d, DBlock, FiLM, UBlock
from TTS.vocoder.models.base_vocoder import BaseVocoder
from TTS.vocoder.utils.generic_utils import plot_results from TTS.vocoder.utils.generic_utils import plot_results
@ -33,7 +33,7 @@ class WavegradArgs(Coqpit):
) )
class Wavegrad(BaseModel): class Wavegrad(BaseVocoder):
"""🐸 🌊 WaveGrad 🌊 model. """🐸 🌊 WaveGrad 🌊 model.
Paper - https://arxiv.org/abs/2009.00713 Paper - https://arxiv.org/abs/2009.00713
@ -257,14 +257,18 @@ class Wavegrad(BaseModel):
loss = criterion(noise, noise_hat) loss = criterion(noise, noise_hat)
return {"model_output": noise_hat}, {"loss": loss} return {"model_output": noise_hat}, {"loss": loss}
def train_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]: def train_log( # pylint: disable=no-self-use
self, ap: AudioProcessor, batch: Dict, outputs: Dict # pylint: disable=unused-argument
) -> Tuple[Dict, np.ndarray]:
return None, None return None, None
@torch.no_grad() @torch.no_grad()
def eval_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]: def eval_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]:
return self.train_step(batch, criterion) return self.train_step(batch, criterion)
def eval_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]: def eval_log( # pylint: disable=no-self-use
self, ap: AudioProcessor, batch: Dict, outputs: Dict # pylint: disable=unused-argument
) -> Tuple[Dict, np.ndarray]:
return None, None return None, None
def test_run(self, ap: AudioProcessor, samples: List[Dict], ouputs: Dict): # pylint: disable=unused-argument def test_run(self, ap: AudioProcessor, samples: List[Dict], ouputs: Dict): # pylint: disable=unused-argument
@ -291,7 +295,8 @@ class Wavegrad(BaseModel):
def get_scheduler(self, optimizer): def get_scheduler(self, optimizer):
return get_scheduler(self.config.lr_scheduler, self.config.lr_scheduler_params, optimizer) return get_scheduler(self.config.lr_scheduler, self.config.lr_scheduler_params, optimizer)
def get_criterion(self): @staticmethod
def get_criterion():
return torch.nn.L1Loss() return torch.nn.L1Loss()
@staticmethod @staticmethod

View File

@ -5,9 +5,9 @@ from typing import Dict, List, Tuple
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from coqpit import Coqpit from coqpit import Coqpit
from torch import nn
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler

View File

@ -3,28 +3,24 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [ "source": [
"%matplotlib inline\n", "%matplotlib inline\n",
"\n", "\n",
"from TTS.utils.audio import AudioProcessor\n", "from TTS.utils.audio import AudioProcessor\n",
"from TTS.tts.utils.visual import plot_spectrogram\n", "from TTS.tts.utils.visual import plot_spectrogram\n",
"from TTS.utils.io import load_config\n", "from TTS.config import load_config\n",
"\n", "\n",
"import IPython.display as ipd\n", "import IPython.display as ipd\n",
"import glob" "import glob"
] ],
"outputs": [],
"metadata": {
"Collapsed": "false"
}
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [ "source": [
"config_path = \"/home/erogol/gdrive/Projects/TTS/recipes/ljspeech/align_tts/config_transformer2.json\"\n", "config_path = \"/home/erogol/gdrive/Projects/TTS/recipes/ljspeech/align_tts/config_transformer2.json\"\n",
"data_path = \"/home/erogol/gdrive/Datasets/LJSpeech-1.1/\"\n", "data_path = \"/home/erogol/gdrive/Datasets/LJSpeech-1.1/\"\n",
@ -39,28 +35,28 @@
"\n", "\n",
"print(\"File list, by index:\")\n", "print(\"File list, by index:\")\n",
"dict(enumerate(file_paths))" "dict(enumerate(file_paths))"
] ],
"outputs": [],
"metadata": {
"Collapsed": "false"
}
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {
"Collapsed": "false"
},
"source": [ "source": [
"### Setup Audio Processor\n", "### Setup Audio Processor\n",
"Play with the AP parameters until you find a good fit with the synthesis speech below.\n", "Play with the AP parameters until you find a good fit with the synthesis speech below.\n",
"\n", "\n",
"The default values are loaded from your config.json file, so you only need to\n", "The default values are loaded from your config.json file, so you only need to\n",
"uncomment and modify values below that you'd like to tune." "uncomment and modify values below that you'd like to tune."
] ],
"metadata": {
"Collapsed": "false"
}
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [ "source": [
"tune_params={\n", "tune_params={\n",
"# 'audio_processor': 'audio',\n", "# 'audio_processor': 'audio',\n",
@ -95,54 +91,54 @@
"tuned_config.update(tune_params)\n", "tuned_config.update(tune_params)\n",
"\n", "\n",
"AP = AudioProcessor(**tuned_config);" "AP = AudioProcessor(**tuned_config);"
] ],
"outputs": [],
"metadata": {
"Collapsed": "false"
}
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {
"Collapsed": "false"
},
"source": [ "source": [
"### Check audio loading " "### Check audio loading "
] ],
"metadata": {
"Collapsed": "false"
}
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [ "source": [
"wav = AP.load_wav(SAMPLE_FILE_PATH)\n", "wav = AP.load_wav(SAMPLE_FILE_PATH)\n",
"ipd.Audio(data=wav, rate=AP.sample_rate) " "ipd.Audio(data=wav, rate=AP.sample_rate) "
] ],
"outputs": [],
"metadata": {
"Collapsed": "false"
}
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {
"Collapsed": "false"
},
"source": [ "source": [
"### Generate Mel-Spectrogram and Re-synthesis with GL" "### Generate Mel-Spectrogram and Re-synthesis with GL"
] ],
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"AP.power = 1.5"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": { "metadata": {
"Collapsed": "false" "Collapsed": "false"
}, }
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"AP.power = 1.5"
],
"outputs": [], "outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [ "source": [
"mel = AP.melspectrogram(wav)\n", "mel = AP.melspectrogram(wav)\n",
"print(\"Max:\", mel.max())\n", "print(\"Max:\", mel.max())\n",
@ -152,24 +148,24 @@
"\n", "\n",
"wav_gen = AP.inv_melspectrogram(mel)\n", "wav_gen = AP.inv_melspectrogram(mel)\n",
"ipd.Audio(wav_gen, rate=AP.sample_rate)" "ipd.Audio(wav_gen, rate=AP.sample_rate)"
] ],
"outputs": [],
"metadata": {
"Collapsed": "false"
}
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {
"Collapsed": "false"
},
"source": [ "source": [
"### Generate Linear-Spectrogram and Re-synthesis with GL" "### Generate Linear-Spectrogram and Re-synthesis with GL"
] ],
"metadata": {
"Collapsed": "false"
}
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [ "source": [
"spec = AP.spectrogram(wav)\n", "spec = AP.spectrogram(wav)\n",
"print(\"Max:\", spec.max())\n", "print(\"Max:\", spec.max())\n",
@ -179,26 +175,26 @@
"\n", "\n",
"wav_gen = AP.inv_spectrogram(spec)\n", "wav_gen = AP.inv_spectrogram(spec)\n",
"ipd.Audio(wav_gen, rate=AP.sample_rate)" "ipd.Audio(wav_gen, rate=AP.sample_rate)"
] ],
"outputs": [],
"metadata": {
"Collapsed": "false"
}
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {
"Collapsed": "false"
},
"source": [ "source": [
"### Compare values for a certain parameter\n", "### Compare values for a certain parameter\n",
"\n", "\n",
"Optimize your parameters by comparing different values per parameter at a time." "Optimize your parameters by comparing different values per parameter at a time."
] ],
"metadata": {
"Collapsed": "false"
}
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [ "source": [
"from librosa import display\n", "from librosa import display\n",
"from matplotlib import pylab as plt\n", "from matplotlib import pylab as plt\n",
@ -238,36 +234,39 @@
" val = values[idx]\n", " val = values[idx]\n",
" print(\" > {} = {}\".format(attribute, val))\n", " print(\" > {} = {}\".format(attribute, val))\n",
" IPython.display.display(IPython.display.Audio(wav_gen, rate=AP.sample_rate))" " IPython.display.display(IPython.display.Audio(wav_gen, rate=AP.sample_rate))"
] ],
"outputs": [],
"metadata": {
"Collapsed": "false"
}
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [ "source": [
"compare_values(\"preemphasis\", [0, 0.5, 0.97, 0.98, 0.99])" "compare_values(\"preemphasis\", [0, 0.5, 0.97, 0.98, 0.99])"
] ],
"outputs": [],
"metadata": {
"Collapsed": "false"
}
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [ "source": [
"compare_values(\"ref_level_db\", [2, 5, 10, 15, 20, 25, 30, 35, 40, 1000])" "compare_values(\"ref_level_db\", [2, 5, 10, 15, 20, 25, 30, 35, 40, 1000])"
] ],
"outputs": [],
"metadata": {
"Collapsed": "false"
}
} }
], ],
"metadata": { "metadata": {
"kernelspec": { "kernelspec": {
"display_name": "Python 3", "name": "python3",
"language": "python", "display_name": "Python 3.8.5 64-bit ('torch': conda)"
"name": "python3"
}, },
"language_info": { "language_info": {
"codemirror_mode": { "codemirror_mode": {
@ -280,8 +279,11 @@
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.8.5" "version": "3.8.5"
},
"interpreter": {
"hash": "27648abe09795c3a768a281b31f7524fcf66a207e733f8ecda3a4e1fd1059fb0"
} }
}, },
"nbformat": 4, "nbformat": 4,
"nbformat_minor": 4 "nbformat_minor": 4
} }

View File

@ -43,7 +43,7 @@ def process_meta_data(path):
meta_data = {} meta_data = {}
# load meta data # load meta data
with open(path, "r") as f: with open(path, "r", encoding="utf-8") as f:
data = csv.reader(f, delimiter="|") data = csv.reader(f, delimiter="|")
for row in data: for row in data:
frames = int(row[2]) frames = int(row[2])
@ -92,7 +92,7 @@ def save_training(file_path, meta_data):
rows.append(d["row"] + "\n") rows.append(d["row"] + "\n")
random.shuffle(rows) random.shuffle(rows)
with open(file_path, "w+") as f: with open(file_path, "w+", encoding="utf-8") as f:
for row in rows: for row in rows:
f.write(row) f.write(row)
@ -156,7 +156,7 @@ def plot_phonemes(train_path, cmu_dict_path, save_path):
phonemes = {} phonemes = {}
with open(train_path, "r") as f: with open(train_path, "r", encoding="utf-8") as f:
data = csv.reader(f, delimiter="|") data = csv.reader(f, delimiter="|")
phonemes["None"] = 0 phonemes["None"] = 0
for row in data: for row in data:
@ -174,9 +174,9 @@ def plot_phonemes(train_path, cmu_dict_path, save_path):
phonemes["None"] += 1 phonemes["None"] += 1
x, y = [], [] x, y = [], []
for key in phonemes: for k, v in phonemes.items():
x.append(key) x.append(k)
y.append(phonemes[key]) y.append(v)
plt.figure() plt.figure()
plt.rcParams["figure.figsize"] = (50, 20) plt.rcParams["figure.figsize"] = (50, 20)

View File

@ -29,7 +29,7 @@ config = VitsConfig(
run_name="vits_ljspeech", run_name="vits_ljspeech",
batch_size=48, batch_size=48,
eval_batch_size=16, eval_batch_size=16,
batch_group_size=0, batch_group_size=5,
num_loader_workers=4, num_loader_workers=4,
num_eval_loader_workers=4, num_eval_loader_workers=4,
run_eval=True, run_eval=True,
@ -43,7 +43,7 @@ config = VitsConfig(
print_step=25, print_step=25,
print_eval=True, print_eval=True,
mixed_precision=True, mixed_precision=True,
max_seq_len=5000, max_seq_len=500000,
output_path=output_path, output_path=output_path,
datasets=[dataset_config], datasets=[dataset_config],
) )
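Note on the recipe change above: `batch_group_size=5` shuffles within length-sorted buckets, and `max_seq_len` jumps from 5000 to 500000, presumably because the length check for this model operates on audio samples rather than spectrogram frames. A quick sanity check under that assumption (22050 Hz, LJSpeech's sample rate):

```python
# Hypothetical back-of-the-envelope check of the new limit.
sample_rate = 22050          # LJSpeech
max_seq_len = 500000         # value set in the recipe above
print(max_seq_len / sample_rate)  # ~22.7 s, comfortably above LJSpeech's ~10 s longest clips
```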

View File

@ -2,4 +2,4 @@ black
coverage coverage
isort isort
nose nose
pylint==2.8.3 pylint==2.10.2

View File

@ -124,7 +124,7 @@ class TestTTSDataset(unittest.TestCase):
avg_length = mel_lengths.numpy().mean() avg_length = mel_lengths.numpy().mean()
assert avg_length >= last_length assert avg_length >= last_length
dataloader.dataset.sort_items() dataloader.dataset.sort_and_filter_items()
is_items_reordered = False is_items_reordered = False
for idx, item in enumerate(dataloader.dataset.items): for idx, item in enumerate(dataloader.dataset.items):
if item != frames[idx]: if item != frames[idx]: