Mirror of https://github.com/coqui-ai/TTS.git, commit 5793dca71d
@@ -155,4 +155,5 @@ deps.json
 speakers.json
 internal/*
 *_pitch.npy
 *_phoneme.npy
+wandb
@@ -64,6 +64,11 @@ disable=missing-docstring,
         too-many-public-methods,
         too-many-lines,
         bare-except,
+        ## for avoiding weird p3.6 CI linter error
+        ## TODO: see later if we can remove this
+        assigning-non-slot,
+        unsupported-assignment-operation,
+        ## end
         line-too-long,
         fixme,
         wrong-import-order,

@@ -73,6 +78,7 @@ disable=missing-docstring,
         invalid-name,
         too-many-instance-attributes,
         arguments-differ,
+        arguments-renamed,
         no-name-in-module,
         no-member,
         unsubscriptable-object,
@@ -102,7 +102,7 @@ You can also help us implement more models.
 ## Install TTS
 🐸TTS is tested on Ubuntu 18.04 with **python >= 3.6, < 3.9**.
 
-If you are only interested in [synthesizing speech](https://github.com/coqui-ai/TTS/tree/dev#example-synthesizing-speech-on-terminal-using-the-released-models) with the released 🐸TTS models, installing from PyPI is the easiest option.
+If you are only interested in [synthesizing speech](https://tts.readthedocs.io/en/latest/inference.html) with the released 🐸TTS models, installing from PyPI is the easiest option.
 
 ```bash
 pip install TTS
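For context, the released models can be exercised right after the `pip install TTS` step shown above. A minimal sketch in Python, assuming the `tts` console entry point that the PyPI package installs and its `--text`/`--out_path` flags; the text and output name are illustrative only.

```python
# Hedged sketch: drive the assumed `tts` CLI from Python after `pip install TTS`.
import subprocess

subprocess.run(
    ["tts", "--text", "Hello from Coqui TTS.", "--out_path", "hello.wav"],
    check=True,  # raise if synthesis fails
)
```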
@@ -1 +1 @@
-0.2.0
+0.2.1
@@ -1,6 +1,6 @@
 import os
 
-with open(os.path.join(os.path.dirname(__file__), "VERSION")) as f:
+with open(os.path.join(os.path.dirname(__file__), "VERSION"), "r", encoding="utf-8") as f:
     version = f.read().strip()
 
 __version__ = version
@@ -97,7 +97,7 @@ Example run:
        enable_eos_bos=C.enable_eos_bos_chars,
    )
 
-   dataset.sort_items()
+   dataset.sort_and_filter_items(C.get("sort_by_audio_len", default=False))
    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,

@@ -158,7 +158,7 @@ Example run:
    # ourput metafile
    metafile = os.path.join(args.data_path, "metadata_attn_mask.txt")
 
-   with open(metafile, "w") as f:
+   with open(metafile, "w", encoding="utf-8") as f:
        for p in file_paths:
            f.write(f"{p[0]}|{p[1]}\n")
    print(f" >> Metafile created: {metafile}")
@@ -32,6 +32,7 @@ def main():
    command.append("--restore_path={}".format(args.restore_path))
    command.append("--config_path={}".format(args.config_path))
    command.append("--group_id=group_{}".format(group_id))
+   command.append("--use_ddp=true")
    command += unargs
    command.append("")

@@ -42,7 +43,7 @@ def main():
        my_env["PYTHON_EGG_CACHE"] = "/tmp/tmp{}".format(i)
        command[-1] = "--rank={}".format(i)
        # prevent stdout for processes with rank != 0
-       stdout = None if i == 0 else open(os.devnull, "w")
+       stdout = None
        p = subprocess.Popen(["python3"] + command, stdout=stdout, env=my_env)  # pylint: disable=consider-using-with
        processes.append(p)
        print(command)
@@ -46,7 +46,7 @@ def setup_loader(ap, r, verbose=False):
        if c.use_phonemes and c.compute_input_seq_cache:
            # precompute phonemes to have a better estimate of sequence lengths.
            dataset.compute_input_seq(c.num_loader_workers)
-       dataset.sort_items()
+       dataset.sort_and_filter_items(c.get("sort_by_audio_len", default=False))
 
        loader = DataLoader(
            dataset,

@@ -215,7 +215,7 @@ def extract_spectrograms(
        wav = ap.inv_melspectrogram(mel)
        ap.save_wav(wav, wav_gl_path)
 
-   with open(os.path.join(output_path, metada_name), "w") as f:
+   with open(os.path.join(output_path, metada_name), "w", encoding="utf-8") as f:
        for data in export_metadata:
            f.write(f"{data[0]}|{data[1]+'.npy'}\n")
@@ -190,7 +190,7 @@ class BaseTrainingConfig(Coqpit):
            Name of the model that is used in the training.
 
        run_name (str):
-           Name of the experiment. This prefixes the output folder name.
+           Name of the experiment. This prefixes the output folder name. Defaults to `coqui_tts`.
 
        run_description (str):
            Short description of the experiment.

@@ -272,7 +272,7 @@ class BaseTrainingConfig(Coqpit):
    """
 
    model: str = None
-   run_name: str = ""
+   run_name: str = "coqui_tts"
    run_description: str = ""
    # training params
    epochs: int = 10000
TTS/model.py
@@ -23,35 +23,31 @@ class BaseModel(nn.Module, ABC):
    """
 
    @abstractmethod
-   def forward(self, text: torch.Tensor, aux_input={}, **kwargs) -> Dict:
+   def forward(self, input: torch.Tensor, *args, aux_input={}, **kwargs) -> Dict:
        """Forward pass for the model mainly used in training.
 
-       You can be flexible here and use different number of arguments and argument names since it is mostly used by
-       `train_step()` in training whitout exposing it to the out of the class.
+       You can be flexible here and use different number of arguments and argument names since it is intended to be
+       used by `train_step()` without exposing it out of the model.
 
        Args:
-           text (torch.Tensor): Input text character sequence ids.
+           input (torch.Tensor): Input tensor.
            aux_input (Dict): Auxiliary model inputs like embeddings, durations or any other sorts of inputs.
-               for the model.
 
        Returns:
-           Dict: model outputs. This must include an item keyed `model_outputs` as the final artifact of the model.
+           Dict: Model outputs. Main model output must be named as "model_outputs".
        """
        outputs_dict = {"model_outputs": None}
        ...
        return outputs_dict
 
    @abstractmethod
-   def inference(self, text: torch.Tensor, aux_input={}) -> Dict:
+   def inference(self, input: torch.Tensor, aux_input={}) -> Dict:
        """Forward pass for inference.
 
        After the model is trained this is the only function that connects the model the out world.
 
        This function must only take a `text` input and a dictionary that has all the other model specific inputs.
        We don't use `*kwargs` since it is problematic with the TorchScript API.
 
        Args:
-           text (torch.Tensor): [description]
+           input (torch.Tensor): [description]
            aux_input (Dict): Auxiliary inputs like speaker embeddings, durations etc.
 
        Returns:
@@ -1,6 +1,6 @@
 import torch
-import torch.nn as nn
 import torch.nn.functional as F
+from torch import nn
 
 
 # adapted from https://github.com/cvqluu/GE2E-Loss

@@ -1,6 +1,6 @@
 import numpy as np
 import torch
-import torch.nn as nn
+from torch import nn
 
 from TTS.utils.io import load_fsspec
@@ -1,5 +1,5 @@
 from dataclasses import asdict, dataclass, field
-from typing import List
+from typing import Dict, List
 
 from coqpit import MISSING

@@ -14,7 +14,7 @@ class SpeakerEncoderConfig(BaseTrainingConfig):
    audio: BaseAudioConfig = field(default_factory=BaseAudioConfig)
    datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()])
    # model params
-   model_params: dict = field(
+   model_params: Dict = field(
        default_factory=lambda: {
            "model_name": "lstm",
            "input_dim": 80,

@@ -25,9 +25,9 @@ class SpeakerEncoderConfig(BaseTrainingConfig):
        }
    )
 
-   audio_augmentation: dict = field(default_factory=lambda: {})
+   audio_augmentation: Dict = field(default_factory=lambda: {})
 
-   storage: dict = field(
+   storage: Dict = field(
        default_factory=lambda: {
            "sample_from_storage_p": 0.66,  # the probability with which we'll sample from the DataSet in-memory storage
            "storage_size": 15,  # the size of the in-memory storage with respect to a single batch
@@ -94,7 +94,8 @@ def download_and_extract(directory, subset, urls):
        extract_path = zip_filepath.strip(".zip")
 
        # check zip file md5sum
-       md5 = hashlib.md5(open(zip_filepath, "rb").read()).hexdigest()
+       with open(zip_filepath, "rb") as f_zip:
+           md5 = hashlib.md5(f_zip.read()).hexdigest()
        if md5 != MD5SUM[subset]:
            raise ValueError("md5sum of %s mismatch" % zip_filepath)
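The hunk above only swaps the bare `open(...).read()` for a context manager before hashing. A related sketch under the same idea, hashing the archive in fixed-size chunks so large zips never have to fit in memory; the file name and expected digest are placeholders.

```python
import hashlib

def md5_of(path: str, chunk_size: int = 1 << 20) -> str:
    """Compute the MD5 of a file without loading it fully into memory."""
    md5 = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            md5.update(chunk)
    return md5.hexdigest()

# usage (placeholder values):
# if md5_of("subset.zip") != "expected-md5-digest":
#     raise ValueError("md5sum of subset.zip mismatch")
```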
TTS/trainer.py
@@ -1,7 +1,6 @@
 # -*- coding: utf-8 -*-
 
 import importlib
-import logging
 import multiprocessing
 import os
 import platform

@@ -16,6 +15,7 @@ from urllib.parse import urlparse
 
 import fsspec
 import torch
+import torch.distributed as dist
 from coqpit import Coqpit
 from torch import nn
 from torch.nn.parallel import DistributedDataParallel as DDP_th

@@ -38,7 +38,7 @@ from TTS.utils.generic_utils import (
    to_cuda,
 )
 from TTS.utils.io import copy_model_files, load_fsspec, save_best_model, save_checkpoint
-from TTS.utils.logging import ConsoleLogger, TensorboardLogger, WandbLogger, init_logger
+from TTS.utils.logging import ConsoleLogger, TensorboardLogger, WandbLogger, init_dashboard_logger
 from TTS.utils.trainer_utils import get_optimizer, get_scheduler, is_apex_available, setup_torch_training_env
 from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
 from TTS.vocoder.models import setup_model as setup_vocoder_model
@@ -77,12 +77,16 @@ class TrainingArgs(Coqpit):
    best_path: str = field(
        default="",
        metadata={
-           "help": "Best model file to be used for extracting best loss. If not specified, the latest best model in continue path is used"
+           "help": "Best model file to be used for extracting the best loss. If not specified, the latest best model in continue path is used"
        },
    )
    config_path: str = field(default="", metadata={"help": "Path to the configuration file."})
    rank: int = field(default=0, metadata={"help": "Process rank in distributed training."})
    group_id: str = field(default="", metadata={"help": "Process group id in distributed training."})
+   use_ddp: bool = field(
+       default=False,
+       metadata={"help": "Use DDP in distributed training. It is to set in `distribute.py`. Do not set manually."},
+   )
 
 
 class Trainer:

@@ -144,6 +148,7 @@ class Trainer:
            >>> trainer.fit()
 
    TODO:
+       - Wrap model for not calling .module in DDP.
        - Accumulate gradients b/w batches.
        - Deepspeed integration
        - Profiler integration.
@@ -151,29 +156,33 @@ class Trainer:
            - TPU training
        """
 
-       # set and initialize Pytorch runtime
-       self.use_cuda, self.num_gpus = setup_torch_training_env(True, cudnn_benchmark)
        if config is None:
            # parse config from console arguments
            config, output_path, _, c_logger, dashboard_logger = process_args(args)
 
-       self.output_path = output_path
        self.args = args
        self.config = config
+       self.output_path = output_path
        self.config.output_log_path = output_path
+
+       # setup logging
+       log_file = os.path.join(self.output_path, f"trainer_{args.rank}_log.txt")
+       self._setup_logger_config(log_file)
+
+       # set and initialize Pytorch runtime
+       self.use_cuda, self.num_gpus = setup_torch_training_env(True, cudnn_benchmark, args.use_ddp)
 
        # init loggers
        self.c_logger = ConsoleLogger() if c_logger is None else c_logger
        self.dashboard_logger = dashboard_logger
 
-       if self.dashboard_logger is None:
-           self.dashboard_logger = init_logger(config)
+       # only allow dashboard logging for the main process in DDP mode
+       if self.dashboard_logger is None and args.rank == 0:
+           self.dashboard_logger = init_dashboard_logger(config)
 
        if not self.config.log_model_step:
            self.config.log_model_step = self.config.save_step
 
-       log_file = os.path.join(self.output_path, f"trainer_{args.rank}_log.txt")
-       self._setup_logger_config(log_file)
-
        self.total_steps_done = 0
        self.epochs_done = 0
        self.restore_step = 0
@@ -247,10 +256,10 @@ class Trainer:
        if self.use_apex:
            self.scaler = None
            self.model, self.optimizer = amp.initialize(self.model, self.optimizer, opt_level="O1")
-       if isinstance(self.optimizer, list):
-           self.scaler = [torch.cuda.amp.GradScaler()] * len(self.optimizer)
-       else:
-           self.scaler = torch.cuda.amp.GradScaler()
+       # if isinstance(self.optimizer, list):
+       #     self.scaler = [torch.cuda.amp.GradScaler()] * len(self.optimizer)
+       # else:
+       self.scaler = torch.cuda.amp.GradScaler()
        else:
            self.scaler = None

@@ -319,14 +328,14 @@ class Trainer:
            return obj
 
        print(" > Restoring from %s ..." % os.path.basename(restore_path))
-       checkpoint = load_fsspec(restore_path)
+       checkpoint = load_fsspec(restore_path, map_location="cpu")
        try:
            print(" > Restoring Model...")
            model.load_state_dict(checkpoint["model"])
            print(" > Restoring Optimizer...")
            optimizer = _restore_list_objs(checkpoint["optimizer"], optimizer)
            if "scaler" in checkpoint and self.use_amp_scaler and checkpoint["scaler"]:
-               print(" > Restoring AMP Scaler...")
+               print(" > Restoring Scaler...")
                scaler = _restore_list_objs(checkpoint["scaler"], scaler)
        except (KeyError, RuntimeError):
            print(" > Partial model initialization...")
@@ -346,10 +355,11 @@ class Trainer:
            " > Model restored from step %d" % checkpoint["step"],
        )
        restore_step = checkpoint["step"]
+       torch.cuda.empty_cache()
        return model, optimizer, scaler, restore_step
 
-   @staticmethod
    def _get_loader(
+       self,
        model: nn.Module,
        config: Coqpit,
        ap: AudioProcessor,

@@ -358,8 +368,14 @@ class Trainer:
        verbose: bool,
        num_gpus: int,
    ) -> DataLoader:
-       if hasattr(model, "get_data_loader"):
-           loader = model.get_data_loader(config, ap, is_eval, data_items, verbose, num_gpus)
+       if num_gpus > 1:
+           if hasattr(model.module, "get_data_loader"):
+               loader = model.module.get_data_loader(
+                   config, ap, is_eval, data_items, verbose, num_gpus, self.args.rank
+               )
+       else:
+           if hasattr(model, "get_data_loader"):
+               loader = model.get_data_loader(config, ap, is_eval, data_items, verbose, num_gpus)
        return loader
 
    def get_train_dataloader(self, ap: AudioProcessor, data_items: List, verbose: bool) -> DataLoader:
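The branching above exists because DistributedDataParallel wraps the model and custom methods such as `get_data_loader` then live on `model.module`. A small sketch of the same idea written once with `getattr`; this is a generic pattern, not the trainer's actual code.

```python
from torch import nn

def unwrap_model(model: nn.Module) -> nn.Module:
    # DDP (and DataParallel) expose the wrapped model as `.module`.
    return getattr(model, "module", model)

# usage: works whether or not the model is wrapped
# loader = unwrap_model(model).get_data_loader(config, ap, is_eval, data_items, verbose, num_gpus)
```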
@@ -387,12 +403,28 @@ class Trainer:
        Returns:
            Dict: Formatted batch.
        """
-       batch = self.model.format_batch(batch)
+       if self.num_gpus > 1:
+           batch = self.model.module.format_batch(batch)
+       else:
+           batch = self.model.format_batch(batch)
        if self.use_cuda:
            for k, v in batch.items():
                batch[k] = to_cuda(v)
        return batch
 
+   @staticmethod
+   def master_params(optimizer: torch.optim.Optimizer):
+       """Generator over parameters owned by the optimizer.
+
+       Used to select parameters used by the optimizer for gradient clipping.
+
+       Args:
+           optimizer: Target optimizer.
+       """
+       for group in optimizer.param_groups:
+           for p in group["params"]:
+               yield p
+
    @staticmethod
    def _model_train_step(
        batch: Dict, model: nn.Module, criterion: nn.Module, optimizer_idx: int = None
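`master_params` yields only the tensors a given optimizer owns, so gradient clipping can be scoped per optimizer instead of clipping `model.parameters()` globally, which matters for the two-optimizer GAN-style training used by VITS further below. A minimal usage sketch with a toy model:

```python
import torch

def master_params(optimizer: torch.optim.Optimizer):
    # iterate over every parameter registered in the optimizer's param groups
    for group in optimizer.param_groups:
        for p in group["params"]:
            yield p

model = torch.nn.Linear(8, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss = model(torch.randn(4, 8)).sum()
loss.backward()
# clip only the parameters owned by this optimizer
grad_norm = torch.nn.utils.clip_grad_norm_(master_params(optimizer), max_norm=1.0)
optimizer.step()
```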
@@ -450,6 +482,8 @@ class Trainer:
        step_start_time = time.time()
+       # zero-out optimizer
+       optimizer.zero_grad()
 
        # forward pass and loss computation
        with torch.cuda.amp.autocast(enabled=config.mixed_precision):
            if optimizer_idx is not None:
                outputs, loss_dict = self._model_train_step(batch, model, criterion, optimizer_idx=optimizer_idx)

@@ -461,9 +495,9 @@ class Trainer:
            step_time = time.time() - step_start_time
            return None, {}, step_time
 
-       # check nan loss
-       if torch.isnan(loss_dict["loss"]).any():
-           raise RuntimeError(f" > Detected NaN loss - {loss_dict}.")
+       # # check nan loss
+       # if torch.isnan(loss_dict["loss"]).any():
+       #     raise RuntimeError(f" > NaN loss detected - {loss_dict}")
 
        # set gradient clipping threshold
        if "grad_clip" in config and config.grad_clip is not None:

@@ -481,6 +515,8 @@ class Trainer:
        update_lr_scheduler = True
        if self.use_amp_scaler:
            if self.use_apex:
+               # TODO: verify AMP use for GAN training in TTS
+               # https://nvidia.github.io/apex/advanced.html?highlight=accumulate#backward-passes-with-multiple-optimizers
                with amp.scale_loss(loss_dict["loss"], optimizer) as scaled_loss:
                    scaled_loss.backward()
                grad_norm = torch.nn.utils.clip_grad_norm_(
@@ -491,7 +527,9 @@ class Trainer:
                scaler.scale(loss_dict["loss"]).backward()
                if grad_clip > 0:
                    scaler.unscale_(optimizer)
-                   grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip, error_if_nonfinite=False)
+                   grad_norm = torch.nn.utils.clip_grad_norm_(
+                       self.master_params(optimizer), grad_clip, error_if_nonfinite=False
+                   )
                    # pytorch skips the step when the norm is 0. So ignore the norm value when it is NaN
                    if torch.isnan(grad_norm) or torch.isinf(grad_norm):
                        grad_norm = 0

@@ -571,7 +609,8 @@ class Trainer:
        total_step_time = 0
        for idx, optimizer in enumerate(self.optimizer):
            criterion = self.criterion
-           scaler = self.scaler[idx] if self.use_amp_scaler else None
+           # scaler = self.scaler[idx] if self.use_amp_scaler else None
+           scaler = self.scaler
            scheduler = self.scheduler[idx]
            outputs, loss_dict_new, step_time = self._optimize(
                batch, self.model, optimizer, scaler, criterion, scheduler, self.config, idx
@@ -592,13 +631,13 @@ class Trainer:
            outputs = outputs_per_optimizer
 
        # update avg runtime stats
-       keep_avg_update = dict()
+       keep_avg_update = {}
        keep_avg_update["avg_loader_time"] = loader_time
        keep_avg_update["avg_step_time"] = step_time
        self.keep_avg_train.update_values(keep_avg_update)
 
        # update avg loss stats
-       update_eval_values = dict()
+       update_eval_values = {}
        for key, value in loss_dict.items():
            update_eval_values["avg_" + key] = value
        self.keep_avg_train.update_values(update_eval_values)

@@ -662,9 +701,10 @@ class Trainer:
            if audios is not None:
                self.dashboard_logger.train_audios(self.total_steps_done, audios, self.ap.sample_rate)
 
-           self.dashboard_logger.flush()
-
        self.total_steps_done += 1
        self.callbacks.on_train_step_end()
+       self.dashboard_logger.flush()
        return outputs, loss_dict
 
    def train_epoch(self) -> None:
@@ -674,16 +714,20 @@ class Trainer:
            self.data_train,
            verbose=True,
        )
-       self.model.train()
+       if self.num_gpus > 1:
+           self.model.module.train()
+       else:
+           self.model.train()
        epoch_start_time = time.time()
        if self.use_cuda:
            batch_num_steps = int(len(self.train_loader.dataset) / (self.config.batch_size * self.num_gpus))
        else:
            batch_num_steps = int(len(self.train_loader.dataset) / self.config.batch_size)
        self.c_logger.print_train_start()
        loader_start_time = time.time()
        for cur_step, batch in enumerate(self.train_loader):
-           loader_start_time = time.time()
            _, _ = self.train_step(batch, batch_num_steps, cur_step, loader_start_time)
+           loader_start_time = time.time()
        epoch_time = time.time() - epoch_start_time
        # Plot self.epochs_done Stats
        if self.args.rank == 0:
@@ -753,7 +797,7 @@ class Trainer:
            loss_dict = self._detach_loss_dict(loss_dict)
 
            # update avg stats
-           update_eval_values = dict()
+           update_eval_values = {}
            for key, value in loss_dict.items():
                update_eval_values["avg_" + key] = value
            self.keep_avg_eval.update_values(update_eval_values)

@@ -784,6 +828,7 @@ class Trainer:
                loader_time = time.time() - loader_start_time
                self.keep_avg_eval.update_values({"avg_loader_time": loader_time})
                outputs, _ = self.eval_step(batch, cur_step)
+               loader_start_time = time.time()
        # plot epoch stats, artifacts and figures
        if self.args.rank == 0:
            figures, audios = None, None

@@ -800,7 +845,7 @@ class Trainer:
    def test_run(self) -> None:
        """Run test and log the results. Test run must be defined by the model.
        Model must return figures and audios to be logged by the Tensorboard."""
-       if hasattr(self.model, "test_run"):
+       if hasattr(self.model, "test_run") or (self.num_gpus > 1 and hasattr(self.model.module, "test_run")):
            if self.eval_loader is None:
                self.eval_loader = self.get_eval_dataloader(
                    self.ap,
@@ -810,27 +855,43 @@ class Trainer:
 
            if hasattr(self.eval_loader.dataset, "load_test_samples"):
                samples = self.eval_loader.dataset.load_test_samples(1)
-               figures, audios = self.model.test_run(self.ap, samples, None)
+               if self.num_gpus > 1:
+                   figures, audios = self.model.module.test_run(self.ap, samples, None)
+               else:
+                   figures, audios = self.model.test_run(self.ap, samples, None)
            else:
-               figures, audios = self.model.test_run(self.ap)
+               if self.num_gpus > 1:
+                   figures, audios = self.model.module.test_run(self.ap)
+               else:
+                   figures, audios = self.model.test_run(self.ap)
            self.dashboard_logger.test_audios(self.total_steps_done, audios, self.config.audio["sample_rate"])
            self.dashboard_logger.test_figures(self.total_steps_done, figures)
 
-   def _fit(self) -> None:
-       """🏃 train -> evaluate -> test for the number of epochs."""
+   def _restore_best_loss(self):
+       """Restore the best loss from the args.best_path if provided else
+       from the model (`args.restore_path` or `args.continue_path`) used for resuming the training"""
        if self.restore_step != 0 or self.args.best_path:
            print(" > Restoring best loss from " f"{os.path.basename(self.args.best_path)} ...")
-           self.best_loss = load_fsspec(self.args.restore_path, map_location="cpu")["model_loss"]
+           ch = load_fsspec(self.args.restore_path, map_location="cpu")
+           if "model_loss" in ch:
+               self.best_loss = ch["model_loss"]
+           print(f" > Starting with loaded last best loss {self.best_loss}.")
+
+   def _fit(self) -> None:
+       """🏃 train -> evaluate -> test for the number of epochs."""
+       self._restore_best_loss()
+
        self.total_steps_done = self.restore_step
 
        for epoch in range(0, self.config.epochs):
+           if self.num_gpus > 1:
+               # let all processes sync up before starting with a new epoch of training
+               dist.barrier()
            self.callbacks.on_epoch_start()
            self.keep_avg_train = KeepAverage()
            self.keep_avg_eval = KeepAverage() if self.config.run_eval else None
            self.epochs_done = epoch
-           self.c_logger.print_epoch_start(epoch, self.config.epochs)
+           self.c_logger.print_epoch_start(epoch, self.config.epochs, self.output_path)
            self.train_epoch()
            if self.config.run_eval:
                self.eval_epoch()
@@ -839,20 +900,26 @@ class Trainer:
            self.c_logger.print_epoch_end(
                epoch, self.keep_avg_eval.avg_values if self.config.run_eval else self.keep_avg_train.avg_values
            )
-           self.save_best_model()
+           if self.args.rank in [None, 0]:
+               self.save_best_model()
            self.callbacks.on_epoch_end()
 
    def fit(self) -> None:
        """Where the ✨️magic✨️ happens..."""
        try:
            self._fit()
-           self.dashboard_logger.finish()
+           if self.args.rank == 0:
+               self.dashboard_logger.finish()
        except KeyboardInterrupt:
            self.callbacks.on_keyboard_interrupt()
            # if the output folder is empty remove the run.
            remove_experiment_folder(self.output_path)
+           # clear the DDP processes
+           if self.num_gpus > 1:
+               dist.destroy_process_group()
            # finish the wandb run and sync data
-           self.dashboard_logger.finish()
+           if self.args.rank == 0:
+               self.dashboard_logger.finish()
            # stop without error signal
            try:
                sys.exit(0)
@@ -902,18 +969,30 @@ class Trainer:
            keep_after=self.config.keep_after,
        )
 
-   @staticmethod
-   def _setup_logger_config(log_file: str) -> None:
-       handlers = [logging.StreamHandler()]
-
-       # Only add a log file if the output location is local due to poor
-       # support for writing logs to file-like objects.
-       parsed_url = urlparse(log_file)
-       if not parsed_url.scheme or parsed_url.scheme == "file":
-           schemeless_path = os.path.join(parsed_url.netloc, parsed_url.path)
-           handlers.append(logging.FileHandler(schemeless_path))
-
-       logging.basicConfig(level=logging.INFO, format="", handlers=handlers)
+   def _setup_logger_config(self, log_file: str) -> None:
+       """Write log strings to a file and print logs to the terminal.
+       TODO: Causes formatting issues in pdb debugging."""
+
+       class Logger(object):
+           def __init__(self, print_to_terminal=True):
+               self.print_to_terminal = print_to_terminal
+               self.terminal = sys.stdout
+               self.log_file = log_file
+
+           def write(self, message):
+               if self.print_to_terminal:
+                   self.terminal.write(message)
+               with open(self.log_file, "a", encoding="utf-8") as f:
+                   f.write(message)
+
+           def flush(self):
+               # this flush method is needed for python 3 compatibility.
+               # this handles the flush command by doing nothing.
+               # you might want to specify some extra behavior here.
+               pass
+
+       # don't let processes rank > 0 write to the terminal
+       sys.stdout = Logger(self.args.rank == 0)
 
    @staticmethod
    def _is_apex_available() -> bool:
@@ -1109,8 +1188,6 @@ def process_args(args, config=None):
        config = register_config(config_base.model)()
-       # override values from command-line args
-       config.parse_known_args(coqpit_overrides, relaxed_parser=True)
    if config.mixed_precision:
        print(" > Mixed precision mode is ON")
    experiment_path = args.continue_path
    if not experiment_path:
        experiment_path = get_experiment_folder_path(config.output_path, config.run_name)

@@ -1130,8 +1207,7 @@ def process_args(args, config=None):
        used_characters = parse_symbols()
        new_fields["characters"] = used_characters
    copy_model_files(config, experiment_path, new_fields)
-
-   dashboard_logger = init_logger(config)
+   dashboard_logger = init_dashboard_logger(config)
    c_logger = ConsoleLogger()
    return config, experiment_path, audio_path, c_logger, dashboard_logger
@@ -96,7 +96,7 @@ class VitsConfig(BaseTTSConfig):
    model_args: VitsArgs = field(default_factory=VitsArgs)
 
    # optimizer
-   grad_clip: List[float] = field(default_factory=lambda: [5, 5])
+   grad_clip: List[float] = field(default_factory=lambda: [1000, 1000])
    lr_gen: float = 0.0002
    lr_disc: float = 0.0002
    lr_scheduler_gen: str = "ExponentialLR"

@@ -113,14 +113,16 @@ class VitsConfig(BaseTTSConfig):
    gen_loss_alpha: float = 1.0
    feat_loss_alpha: float = 1.0
    mel_loss_alpha: float = 45.0
+   dur_loss_alpha: float = 1.0
 
    # data loader params
    return_wav: bool = True
    compute_linear_spec: bool = True
 
    # overrides
-   min_seq_len: int = 13
-   max_seq_len: int = 500
+   sort_by_audio_len: bool = True
+   min_seq_len: int = 0
+   max_seq_len: int = 500000
    r: int = 1  # DO NOT CHANGE
    add_blank: bool = True
@@ -69,7 +69,7 @@ class TTSDataset(Dataset):
            batch. Set 0 to disable. Defaults to 0.
 
        min_seq_len (int): Minimum input sequence length to be processed
-           by the loader. Filter out input sequences that are shorter than this. Some models have a
+           by sort_inputs`. Filter out input sequences that are shorter than this. Some models have a
            minimum input length due to its architecture. Defaults to 0.
 
        max_seq_len (int): Maximum input sequence length. Filter out input sequences that are longer than this.

@@ -302,10 +302,23 @@ class TTSDataset(Dataset):
        for idx, p in enumerate(phonemes):
            self.items[idx][0] = p
 
-   def sort_items(self):
-       r"""Sort instances based on text length in ascending order"""
-       lengths = np.array([len(ins[0]) for ins in self.items])
+   def sort_and_filter_items(self, by_audio_len=False):
+       r"""Sort `items` based on text length or audio length in ascending order. Filter out samples out or the length
+       range.
+
+       Args:
+           by_audio_len (bool): if True, sort by audio length else by text length.
+       """
+       # compute the target sequence length
+       if by_audio_len:
+           lengths = []
+           for item in self.items:
+               lengths.append(os.path.getsize(item[1]))
+           lengths = np.array(lengths)
+       else:
+           lengths = np.array([len(ins[0]) for ins in self.items])
+
+       # sort items based on the sequence length in ascending order
        idxs = np.argsort(lengths)
        new_items = []
        ignored = []

@@ -315,7 +328,10 @@ class TTSDataset(Dataset):
                ignored.append(idx)
            else:
                new_items.append(self.items[idx])
 
-       # shuffle batch groups
+       # create batches with similar length items
+       # the larger the `batch_group_size`, the higher the length variety in a batch.
        if self.batch_group_size > 0:
            for i in range(len(new_items) // self.batch_group_size):
                offset = i * self.batch_group_size

@@ -325,6 +341,7 @@ class TTSDataset(Dataset):
                new_items[offset:end_offset] = temp_items
        self.items = new_items
 
+       # logging
        if self.verbose:
            print(" | > Max length sequence: {}".format(np.max(lengths)))
            print(" | > Min length sequence: {}".format(np.min(lengths)))
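The new `sort_and_filter_items` picks its sort key per sample: the size of the wav file on disk when `by_audio_len` is enabled, otherwise the raw text length. A toy sketch of the two keys, assuming `items` rows shaped `[text, wav_path, speaker]` as in the dataset above; the file paths are placeholders and only the text-length branch is actually exercised here.

```python
import os
import numpy as np

items = [["a longer piece of text", "clips/b.wav", "spk"], ["short", "clips/a.wav", "spk"]]

def lengths_for(items, by_audio_len=False):
    if by_audio_len:
        # audio length proxied by the wav file size in bytes
        return np.array([os.path.getsize(item[1]) for item in items])
    return np.array([len(item[0]) for item in items])

order = np.argsort(lengths_for(items))  # ascending, shortest samples first
```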
@@ -66,7 +66,7 @@ def load_meta_data(datasets, eval_split=True):
 
 def load_attention_mask_meta_data(metafile_path):
    """Load meta data file created by compute_attention_masks.py"""
-   with open(metafile_path, "r") as f:
+   with open(metafile_path, "r", encoding="utf-8") as f:
        lines = f.readlines()
 
    meta_data = []

@@ -19,7 +19,7 @@ def tweb(root_path, meta_file):
    txt_file = os.path.join(root_path, meta_file)
    items = []
    speaker_name = "tweb"
-   with open(txt_file, "r") as ttf:
+   with open(txt_file, "r", encoding="utf-8") as ttf:
        for line in ttf:
            cols = line.split("\t")
            wav_file = os.path.join(root_path, cols[0] + ".wav")

@@ -33,7 +33,7 @@ def mozilla(root_path, meta_file):
    txt_file = os.path.join(root_path, meta_file)
    items = []
    speaker_name = "mozilla"
-   with open(txt_file, "r") as ttf:
+   with open(txt_file, "r", encoding="utf-8") as ttf:
        for line in ttf:
            cols = line.split("|")
            wav_file = cols[1].strip()

@@ -77,7 +77,7 @@ def mailabs(root_path, meta_files=None):
            continue
        speaker_name = speaker_name_match.group("speaker_name")
        print(" | > {}".format(csv_file))
-       with open(txt_file, "r") as ttf:
+       with open(txt_file, "r", encoding="utf-8") as ttf:
            for line in ttf:
                cols = line.split("|")
                if meta_files is None:

@@ -102,7 +102,7 @@ def ljspeech(root_path, meta_file):
        for line in ttf:
            cols = line.split("|")
            wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
-           text = cols[1]
+           text = cols[2]
            items.append([text, wav_file, speaker_name])
    return items

@@ -116,7 +116,7 @@ def ljspeech_test(root_path, meta_file):
        for idx, line in enumerate(ttf):
            cols = line.split("|")
            wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
-           text = cols[1]
+           text = cols[2]
            items.append([text, wav_file, f"ljspeech-{idx}"])
    return items

@@ -158,7 +158,7 @@ def css10(root_path, meta_file):
    txt_file = os.path.join(root_path, meta_file)
    items = []
    speaker_name = "ljspeech"
-   with open(txt_file, "r") as ttf:
+   with open(txt_file, "r", encoding="utf-8") as ttf:
        for line in ttf:
            cols = line.split("|")
            wav_file = os.path.join(root_path, cols[0])

@@ -172,7 +172,7 @@ def nancy(root_path, meta_file):
    txt_file = os.path.join(root_path, meta_file)
    items = []
    speaker_name = "nancy"
-   with open(txt_file, "r") as ttf:
+   with open(txt_file, "r", encoding="utf-8") as ttf:
        for line in ttf:
            utt_id = line.split()[1]
            text = line[line.find('"') + 1 : line.rfind('"') - 1]

@@ -185,7 +185,7 @@ def common_voice(root_path, meta_file):
    """Normalize the common voice meta data file to TTS format."""
    txt_file = os.path.join(root_path, meta_file)
    items = []
-   with open(txt_file, "r") as ttf:
+   with open(txt_file, "r", encoding="utf-8") as ttf:
        for line in ttf:
            if line.startswith("client_id"):
                continue

@@ -208,7 +208,7 @@ def libri_tts(root_path, meta_files=None):
 
    for meta_file in meta_files:
        _meta_file = os.path.basename(meta_file).split(".")[0]
-       with open(meta_file, "r") as ttf:
+       with open(meta_file, "r", encoding="utf-8") as ttf:
            for line in ttf:
                cols = line.split("\t")
                file_name = cols[0]

@@ -245,7 +245,7 @@ def brspeech(root_path, meta_file):
    """BRSpeech 3.0 beta"""
    txt_file = os.path.join(root_path, meta_file)
    items = []
-   with open(txt_file, "r") as ttf:
+   with open(txt_file, "r", encoding="utf-8") as ttf:
        for line in ttf:
            if line.startswith("wav_filename"):
                continue

@@ -268,7 +268,7 @@ def vctk(root_path, meta_files=None, wavs_path="wav48"):
        if isinstance(test_speakers, list):  # if is list ignore this speakers ids
            if speaker_id in test_speakers:
                continue
-       with open(meta_file) as file_text:
+       with open(meta_file, "r", encoding="utf-8") as file_text:
            text = file_text.readlines()[0]
        wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + ".wav")
        items.append([text, wav_file, "VCTK_" + speaker_id])

@@ -295,7 +295,7 @@ def vctk_slim(root_path, meta_files=None, wavs_path="wav48"):
 def mls(root_path, meta_files=None):
    """http://www.openslr.org/94/"""
    items = []
-   with open(os.path.join(root_path, meta_files), "r") as meta:
+   with open(os.path.join(root_path, meta_files), "r", encoding="utf-8") as meta:
        for line in meta:
            file, text = line.split("\t")
            text = text[:-1]

@@ -329,7 +329,7 @@ def _voxcel_x(root_path, meta_file, voxcel_idx):
 
    # if not exists meta file, crawl recursively for 'wav' files
    if meta_file is not None:
-       with open(str(meta_file), "r") as f:
+       with open(str(meta_file), "r", encoding="utf-8") as f:
            return [x.strip().split("|") for x in f.readlines()]
 
    elif not cache_to.exists():

@@ -346,12 +346,12 @@ def _voxcel_x(root_path, meta_file, voxcel_idx):
                text = None  # VoxCel does not provide transciptions, and they are not needed for training the SE
                meta_data.append(f"{text}|{path}|voxcel{voxcel_idx}_{speaker_id}\n")
                cnt += 1
-       with open(str(cache_to), "w") as f:
+       with open(str(cache_to), "w", encoding="utf-8") as f:
            f.write("".join(meta_data))
        if cnt < expected_count:
            raise ValueError(f"Found too few instances for Voxceleb. Should be around {expected_count}, is: {cnt}")
 
-   with open(str(cache_to), "r") as f:
+   with open(str(cache_to), "r", encoding="utf-8") as f:
        return [x.strip().split("|") for x in f.readlines()]

@@ -367,7 +367,7 @@ def baker(root_path: str, meta_file: str) -> List[List[str]]:
    txt_file = os.path.join(root_path, meta_file)
    items = []
    speaker_name = "baker"
-   with open(txt_file, "r") as ttf:
+   with open(txt_file, "r", encoding="utf-8") as ttf:
        for line in ttf:
            wav_name, text = line.rstrip("\n").split("|")
            wav_path = os.path.join(root_path, "clips_22", wav_name)

@@ -380,7 +380,7 @@ def kokoro(root_path, meta_file):
    txt_file = os.path.join(root_path, meta_file)
    items = []
    speaker_name = "kokoro"
-   with open(txt_file, "r") as ttf:
+   with open(txt_file, "r", encoding="utf-8") as ttf:
        for line in ttf:
            cols = line.split("|")
            wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
@@ -1,6 +1,6 @@
 import torch
-import torch.nn as nn
 import torch.nn.functional as F
+from torch import nn
 
 
 class FFTransformer(nn.Module):
@@ -524,6 +524,7 @@ class VitsGeneratorLoss(nn.Module):
        self.kl_loss_alpha = c.kl_loss_alpha
        self.gen_loss_alpha = c.gen_loss_alpha
        self.feat_loss_alpha = c.feat_loss_alpha
+       self.dur_loss_alpha = c.dur_loss_alpha
        self.mel_loss_alpha = c.mel_loss_alpha
        self.stft = TorchSTFT(
            c.audio.fft_size,

@@ -590,10 +591,11 @@ class VitsGeneratorLoss(nn.Module):
        scores_disc_fake,
        feats_disc_fake,
        feats_disc_real,
+       loss_duration,
    ):
        """
        Shapes:
-           - wavefrom: :math:`[B, 1, T]`
+           - waveform : :math:`[B, 1, T]`
            - waveform_hat: :math:`[B, 1, T]`
            - z_p: :math:`[B, C, T]`
            - logs_q: :math:`[B, C, T]`

@@ -615,12 +617,14 @@ class VitsGeneratorLoss(nn.Module):
        loss_gen = self.generator_loss(scores_disc_fake)[0] * self.gen_loss_alpha
        loss_kl = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask.unsqueeze(1)) * self.kl_loss_alpha
        loss_mel = torch.nn.functional.l1_loss(mel, mel_hat) * self.mel_loss_alpha
-       loss = loss_kl + loss_feat + loss_mel + loss_gen
+       loss_duration = torch.sum(loss_duration.float()) * self.dur_loss_alpha
+       loss = loss_kl + loss_feat + loss_mel + loss_gen + loss_duration
        # pass losses to the dict
        return_dict["loss_gen"] = loss_gen
        return_dict["loss_kl"] = loss_kl
        return_dict["loss_feat"] = loss_feat
        return_dict["loss_mel"] = loss_mel
+       return_dict["loss_duration"] = loss_duration
        return_dict["loss"] = loss
        return return_dict

@@ -651,7 +655,6 @@ class VitsDiscriminatorLoss(nn.Module):
        return_dict = {}
        loss_disc, _, _ = self.discriminator_loss(scores_disc_real, scores_disc_fake)
-       return_dict["loss_disc"] = loss_disc * self.disc_loss_alpha
-       loss = loss + loss_disc
+       return_dict["loss_disc"] = loss_disc
+       loss = loss + return_dict["loss_disc"]
        return_dict["loss"] = loss
        return return_dict
@@ -1,6 +1,6 @@
 import torch
-import torch.nn as nn
 import torch.nn.functional as F
+from torch import nn
 
 
 class GST(nn.Module):
@@ -388,8 +388,8 @@ class Decoder(nn.Module):
        decoder_input = self.project_to_decoder_in(torch.cat((self.attention_rnn_hidden, self.context_vec), -1))
 
        # Pass through the decoder RNNs
-       for idx in range(len(self.decoder_rnns)):
-           self.decoder_rnn_hiddens[idx] = self.decoder_rnns[idx](decoder_input, self.decoder_rnn_hiddens[idx])
+       for idx, decoder_rnn in enumerate(self.decoder_rnns):
+           self.decoder_rnn_hiddens[idx] = decoder_rnn(decoder_input, self.decoder_rnn_hiddens[idx])
            # Residual connection
            decoder_input = self.decoder_rnn_hiddens[idx] + decoder_input
        decoder_output = decoder_input
@@ -2,7 +2,7 @@ import torch
 from torch import nn
 from torch.nn.modules.conv import Conv1d
 
-from TTS.vocoder.models.hifigan_discriminator import MultiPeriodDiscriminator
+from TTS.vocoder.models.hifigan_discriminator import DiscriminatorP, MultiPeriodDiscriminator
 
 
 class DiscriminatorS(torch.nn.Module):
@@ -60,18 +60,32 @@ class VitsDiscriminator(nn.Module):
 
    def __init__(self, use_spectral_norm=False):
        super().__init__()
-       self.sd = DiscriminatorS(use_spectral_norm=use_spectral_norm)
-       self.mpd = MultiPeriodDiscriminator(use_spectral_norm=use_spectral_norm)
+       periods = [2, 3, 5, 7, 11]
 
-   def forward(self, x):
+       self.nets = nn.ModuleList()
+       self.nets.append(DiscriminatorS(use_spectral_norm=use_spectral_norm))
+       self.nets.extend([DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods])
+
+   def forward(self, x, x_hat=None):
        """
        Args:
-           x (Tensor): input waveform.
+           x (Tensor): ground truth waveform.
+           x_hat (Tensor): predicted waveform.
 
        Returns:
            List[Tensor]: discriminator scores.
            List[List[Tensor]]: list of list of features from each layers of each discriminator.
        """
-       scores, feats = self.mpd(x)
-       score_sd, feats_sd = self.sd(x)
-       return scores + [score_sd], feats + [feats_sd]
+       x_scores = []
+       x_hat_scores = [] if x_hat is not None else None
+       x_feats = []
+       x_hat_feats = [] if x_hat is not None else None
+       for net in self.nets:
+           x_score, x_feat = net(x)
+           x_scores.append(x_score)
+           x_feats.append(x_feat)
+           if x_hat is not None:
+               x_hat_score, x_hat_feat = net(x_hat)
+               x_hat_scores.append(x_hat_score)
+               x_hat_feats.append(x_hat_feat)
+       return x_scores, x_feats, x_hat_scores, x_hat_feats
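With the rewritten `forward`, a single call scores both the real and the generated waveform and returns four lists: scores and per-layer features for each input. A call-shape sketch, assuming the class is importable from the path the VITS model uses below; tensor sizes are illustrative only.

```python
import torch
from TTS.tts.layers.vits.discriminator import VitsDiscriminator

disc = VitsDiscriminator()
wav_real = torch.randn(2, 1, 8192)  # [B, 1, T] ground-truth segment
wav_fake = torch.randn(2, 1, 8192)  # [B, 1, T] generated segment

real_scores, real_feats, fake_scores, fake_feats = disc(wav_real, wav_fake)
# one entry per sub-discriminator: DiscriminatorS plus one DiscriminatorP per period
print(len(real_scores), len(fake_scores))
```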
@@ -228,7 +228,7 @@ class StochasticDurationPredictor(nn.Module):
            h = self.post_pre(dr)
            h = self.post_convs(h, x_mask)
            h = self.post_proj(h) * x_mask
-           noise = torch.rand(dr.size(0), 2, dr.size(2)).to(device=x.device, dtype=x.dtype) * x_mask
+           noise = torch.randn(dr.size(0), 2, dr.size(2)).to(device=x.device, dtype=x.dtype) * x_mask
            z_q = noise
 
            # posterior encoder
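The change above swaps `torch.rand` (uniform on [0, 1)) for `torch.randn` (standard normal), which is the base noise a flow-based duration predictor assumes. A one-line contrast, purely illustrative:

```python
import torch

u = torch.rand(3, 2, 5)   # uniform samples in [0, 1)
g = torch.randn(3, 2, 5)  # samples from N(0, 1): zero mean, unit variance
print(u.mean().item(), g.mean().item())
```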
@@ -2,8 +2,8 @@ from dataclasses import dataclass, field
 from typing import Dict, Tuple
 
 import torch
-import torch.nn as nn
 from coqpit import Coqpit
+from torch import nn
 
 from TTS.tts.layers.align_tts.mdn import MDNBlock
 from TTS.tts.layers.feed_forward.decoder import Decoder
@@ -2,6 +2,7 @@ import os
 from typing import Dict, List, Tuple
 
 import torch
+import torch.distributed as dist
 from coqpit import Coqpit
 from torch import nn
 from torch.utils.data import DataLoader

@@ -164,7 +165,14 @@ class BaseTTS(BaseModel):
        }
 
    def get_data_loader(
-       self, config: Coqpit, ap: AudioProcessor, is_eval: bool, data_items: List, verbose: bool, num_gpus: int
+       self,
+       config: Coqpit,
+       ap: AudioProcessor,
+       is_eval: bool,
+       data_items: List,
+       verbose: bool,
+       num_gpus: int,
+       rank: int = None,
    ) -> "DataLoader":
        if is_eval and not config.run_eval:
            loader = None

@@ -186,7 +194,7 @@ class BaseTTS(BaseModel):
            if hasattr(self, "make_symbols"):
                custom_symbols = self.make_symbols(self.config)
 
-           # init dataloader
+           # init dataset
            dataset = TTSDataset(
                outputs_per_step=config.r if "r" in config else 1,
                text_cleaner=config.text_cleaner,

@@ -212,13 +220,15 @@ class BaseTTS(BaseModel):
                else None,
            )
 
-           if config.use_phonemes and config.compute_input_seq_cache:
+           # pre-compute phonemes
+           if config.use_phonemes and config.compute_input_seq_cache and rank in [None, 0]:
                if hasattr(self, "eval_data_items") and is_eval:
                    dataset.items = self.eval_data_items
                elif hasattr(self, "train_data_items") and not is_eval:
                    dataset.items = self.train_data_items
                else:
-                   # precompute phonemes to have a better estimate of sequence lengths.
+                   # precompute phonemes for precise estimate of sequence lengths.
+                   # otherwise `dataset.sort_items()` uses raw text lengths
                    dataset.compute_input_seq(config.num_loader_workers)
 
                    # TODO: find a more efficient solution

@@ -228,9 +238,17 @@ class BaseTTS(BaseModel):
                else:
                    self.train_data_items = dataset.items
 
-           dataset.sort_items()
+           # halt DDP processes for the main process to finish computing the phoneme cache
+           if num_gpus > 1:
+               dist.barrier()
+
+           # sort input sequences from short to long
+           dataset.sort_and_filter_items(config.get("sort_by_audio_len", default=False))
+
+           # sampler for DDP
            sampler = DistributedSampler(dataset) if num_gpus > 1 else None
+
+           # init dataloader
            loader = DataLoader(
                dataset,
                batch_size=config.eval_batch_size if is_eval else config.batch_size,
@@ -1,4 +1,6 @@
+import math
 from dataclasses import dataclass, field
+from itertools import chain
 from typing import Dict, List, Tuple
 
 import torch

@@ -11,8 +13,6 @@ from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path
 from TTS.tts.layers.vits.discriminator import VitsDiscriminator
 from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder
 from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
-
-# from TTS.tts.layers.vits.sdp import StochasticDurationPredictor
 from TTS.tts.models.base_tts import BaseTTS
 from TTS.tts.utils.data import sequence_mask
 from TTS.tts.utils.speakers import get_speaker_manager
@@ -119,7 +119,7 @@ class VitsArgs(Coqpit):
        upsample_kernel_sizes_decoder (List[int]):
            Kernel sizes for each upsampling layer of the decoder network. Defaults to `[16, 16, 4, 4]`.
 
-       use_sdp (int):
+       use_sdp (bool):
            Use Stochastic Duration Predictor. Defaults to True.
 
        noise_scale (float):

@@ -128,7 +128,7 @@ class VitsArgs(Coqpit):
        inference_noise_scale (float):
            Noise scale used for the sample noise tensor in inference. Defaults to 0.667.
 
-       length_scale (int):
+       length_scale (float):
            Scale factor for the predicted duration values. Smaller values result faster speech. Defaults to 1.
 
        noise_scale_dp (float):

@@ -176,26 +176,26 @@ class VitsArgs(Coqpit):
    num_heads_text_encoder: int = 2
    num_layers_text_encoder: int = 6
    kernel_size_text_encoder: int = 3
-   dropout_p_text_encoder: int = 0.1
-   dropout_p_duration_predictor: int = 0.1
+   dropout_p_text_encoder: float = 0.1
+   dropout_p_duration_predictor: float = 0.5
    kernel_size_posterior_encoder: int = 5
    dilation_rate_posterior_encoder: int = 1
    num_layers_posterior_encoder: int = 16
    kernel_size_flow: int = 5
    dilation_rate_flow: int = 1
    num_layers_flow: int = 4
-   resblock_type_decoder: int = "1"
+   resblock_type_decoder: str = "1"
    resblock_kernel_sizes_decoder: List[int] = field(default_factory=lambda: [3, 7, 11])
    resblock_dilation_sizes_decoder: List[List[int]] = field(default_factory=lambda: [[1, 3, 5], [1, 3, 5], [1, 3, 5]])
    upsample_rates_decoder: List[int] = field(default_factory=lambda: [8, 8, 2, 2])
    upsample_initial_channel_decoder: int = 512
    upsample_kernel_sizes_decoder: List[int] = field(default_factory=lambda: [16, 16, 4, 4])
-   use_sdp: int = True
+   use_sdp: bool = True
    noise_scale: float = 1.0
    inference_noise_scale: float = 0.667
-   length_scale: int = 1
+   length_scale: float = 1
    noise_scale_dp: float = 1.0
-   inference_noise_scale_dp: float = 0.8
+   inference_noise_scale_dp: float = 1.0
    max_inference_len: int = None
    init_discriminator: bool = True
    use_spectral_norm_disriminator: bool = False
@@ -419,24 +419,23 @@ class Vits(BaseTTS):
        attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
        with torch.no_grad():
            o_scale = torch.exp(-2 * logs_p)
-           # logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1]).unsqueeze(-1) # [b, t, 1]
+           logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1]).unsqueeze(-1)  # [b, t, 1]
            logp2 = torch.einsum("klm, kln -> kmn", [o_scale, -0.5 * (z_p ** 2)])
            logp3 = torch.einsum("klm, kln -> kmn", [m_p * o_scale, z_p])
-           # logp4 = torch.sum(-0.5 * (m_p ** 2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1]
-           logp = logp2 + logp3
+           logp4 = torch.sum(-0.5 * (m_p ** 2) * o_scale, [1]).unsqueeze(-1)  # [b, t, 1]
+           logp = logp2 + logp3 + logp1 + logp4
            attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()
 
        # duration predictor
        attn_durations = attn.sum(3)
        if self.args.use_sdp:
-           nll_duration = self.duration_predictor(
+           loss_duration = self.duration_predictor(
                x.detach() if self.args.detach_dp_input else x,
                x_mask,
                attn_durations,
                g=g.detach() if self.args.detach_dp_input and g is not None else g,
            )
-           nll_duration = torch.sum(nll_duration.float() / torch.sum(x_mask))
-           outputs["nll_duration"] = nll_duration
+           loss_duration = loss_duration / torch.sum(x_mask)
        else:
            attn_log_durations = torch.log(attn_durations + 1e-6) * x_mask
            log_durations = self.duration_predictor(
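The un-commented `logp1`/`logp4` terms complete the diagonal-Gaussian log-likelihood used by monotonic alignment search: log N(z; m, s) = -0.5*log(2π) - log s - 0.5*(z - m)²/s², and the four `logp*` tensors are that expression expanded so it can be batched with `einsum`. A small numerical check of the decomposition on random tensors; shapes are illustrative only.

```python
import math
import torch

B, C, T_text, T_spec = 1, 4, 6, 9
m_p = torch.randn(B, C, T_text)
logs_p = torch.randn(B, C, T_text) * 0.1
z_p = torch.randn(B, C, T_spec)

o_scale = torch.exp(-2 * logs_p)  # 1 / s^2
logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1]).unsqueeze(-1)
logp2 = torch.einsum("klm, kln -> kmn", [o_scale, -0.5 * (z_p ** 2)])
logp3 = torch.einsum("klm, kln -> kmn", [m_p * o_scale, z_p])
logp4 = torch.sum(-0.5 * (m_p ** 2) * o_scale, [1]).unsqueeze(-1)
logp = logp1 + logp2 + logp3 + logp4  # [B, T_text, T_spec]

# direct evaluation of the same log-density for comparison
direct = (-0.5 * math.log(2 * math.pi) - logs_p.unsqueeze(-1)
          - 0.5 * (z_p.unsqueeze(2) - m_p.unsqueeze(-1)) ** 2 * o_scale.unsqueeze(-1)).sum(1)
print(torch.allclose(logp, direct, atol=1e-5))
```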
@@ -445,7 +444,7 @@ class Vits(BaseTTS):
                g=g.detach() if self.args.detach_dp_input and g is not None else g,
            )
            loss_duration = torch.sum((log_durations - attn_log_durations) ** 2, [1, 2]) / torch.sum(x_mask)
-           outputs["loss_duration"] = loss_duration
+       outputs["loss_duration"] = loss_duration
 
        # expand prior
        m_p = torch.einsum("klmn, kjm -> kjn", [attn, m_p])

@@ -563,8 +562,9 @@ class Vits(BaseTTS):
            outputs["waveform_seg"] = wav_seg
 
            # compute discriminator scores and features
-           outputs["scores_disc_fake"], outputs["feats_disc_fake"] = self.disc(outputs["model_outputs"])
-           _, outputs["feats_disc_real"] = self.disc(wav_seg)
+           outputs["scores_disc_fake"], outputs["feats_disc_fake"], _, outputs["feats_disc_real"] = self.disc(
+               outputs["model_outputs"], wav_seg
+           )
 
            # compute losses
            with autocast(enabled=False):  # use float32 for the criterion
@@ -579,23 +579,17 @@ class Vits(BaseTTS):
                    scores_disc_fake=outputs["scores_disc_fake"],
                    feats_disc_fake=outputs["feats_disc_fake"],
                    feats_disc_real=outputs["feats_disc_real"],
+                   loss_duration=outputs["loss_duration"],
                )
 
-           # handle the duration loss
-           if self.args.use_sdp:
-               loss_dict["nll_duration"] = outputs["nll_duration"]
-               loss_dict["loss"] += outputs["nll_duration"]
-           else:
-               loss_dict["loss_duration"] = outputs["loss_duration"]
-               loss_dict["loss"] += outputs["nll_duration"]
-
        elif optimizer_idx == 1:
            # discriminator pass
            outputs = {}
 
            # compute scores and features
-           outputs["scores_disc_fake"], outputs["feats_disc_fake"] = self.disc(self.y_disc_cache.detach())
-           outputs["scores_disc_real"], outputs["feats_disc_real"] = self.disc(self.wav_seg_disc_cache)
+           outputs["scores_disc_fake"], _, outputs["scores_disc_real"], _ = self.disc(
+               self.y_disc_cache.detach(), self.wav_seg_disc_cache
+           )
 
            # compute loss
            with autocast(enabled=False):  # use float32 for the criterion
@@ -686,14 +680,21 @@ class Vits(BaseTTS):
        Returns:
            List: optimizers.
        """
-       self.disc.requires_grad_(False)
-       gen_parameters = filter(lambda p: p.requires_grad, self.parameters())
-       self.disc.requires_grad_(True)
-       optimizer1 = get_optimizer(
+       gen_parameters = chain(
+           self.text_encoder.parameters(),
+           self.posterior_encoder.parameters(),
+           self.flow.parameters(),
+           self.duration_predictor.parameters(),
+           self.waveform_decoder.parameters(),
+       )
+       # add the speaker embedding layer
+       if hasattr(self, "emb_g"):
+           gen_parameters = chain(gen_parameters, self.emb_g)
+       optimizer0 = get_optimizer(
            self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, parameters=gen_parameters
        )
-       optimizer2 = get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr_disc, self.disc)
-       return [optimizer1, optimizer2]
+       optimizer1 = get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr_disc, self.disc)
+       return [optimizer0, optimizer1]
 
    def get_lr(self) -> List:
        """Set the initial learning rates for each optimizer.

@@ -712,9 +713,9 @@ class Vits(BaseTTS):
        Returns:
            List: Schedulers, one for each optimizer.
        """
-       scheduler1 = get_scheduler(self.config.lr_scheduler_gen, self.config.lr_scheduler_gen_params, optimizer[0])
-       scheduler2 = get_scheduler(self.config.lr_scheduler_disc, self.config.lr_scheduler_disc_params, optimizer[1])
-       return [scheduler1, scheduler2]
+       scheduler0 = get_scheduler(self.config.lr_scheduler_gen, self.config.lr_scheduler_gen_params, optimizer[0])
+       scheduler1 = get_scheduler(self.config.lr_scheduler_disc, self.config.lr_scheduler_disc_params, optimizer[1])
+       return [scheduler0, scheduler1]
 
    def get_criterion(self):
        """Get criterions for each optimizer. The index in the output list matches the optimizer idx used in
@@ -225,9 +225,10 @@ def sequence_to_text(sequence: List, tp: Dict = None, add_blank=False, custom_sy
 
    if custom_symbols is not None:
        _symbols = custom_symbols
+       _id_to_symbol = {i: s for i, s in enumerate(_symbols)}
    elif tp:
        _symbols, _ = make_symbols(**tp)
-   _id_to_symbol = {i: s for i, s in enumerate(_symbols)}
+       _id_to_symbol = {i: s for i, s in enumerate(_symbols)}
 
    result = ""
    for symbol_id in sequence:
@@ -241,6 +241,7 @@ class AudioProcessor(object):
        self.sample_rate = sample_rate
        self.resample = resample
        self.num_mels = num_mels
+       self.log_func = log_func
        self.min_level_db = min_level_db or 0
        self.frame_shift_ms = frame_shift_ms
        self.frame_length_ms = frame_length_ms
@@ -1,8 +1,6 @@
 # edited from https://github.com/fastai/imagenet-fast/blob/master/imagenet_nv/distributed.py
 import torch
 import torch.distributed as dist
-from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
-from torch.autograd import Variable
 
 
 def reduce_tensor(tensor, num_gpus):
@@ -20,46 +18,3 @@ def init_distributed(rank, num_gpus, group_name, dist_backend, dist_url):
 
    # Initialize distributed communication
    dist.init_process_group(dist_backend, init_method=dist_url, world_size=num_gpus, rank=rank, group_name=group_name)
-
-
-def apply_gradient_allreduce(module):
-
-    # sync model parameters
-    for p in module.state_dict().values():
-        if not torch.is_tensor(p):
-            continue
-        dist.broadcast(p, 0)
-
-    def allreduce_params():
-        if module.needs_reduction:
-            module.needs_reduction = False
-            # bucketing params based on value types
-            buckets = {}
-            for param in module.parameters():
-                if param.requires_grad and param.grad is not None:
-                    tp = type(param.data)
-                    if tp not in buckets:
-                        buckets[tp] = []
-                    buckets[tp].append(param)
-            for tp in buckets:
-                bucket = buckets[tp]
-                grads = [param.grad.data for param in bucket]
-                coalesced = _flatten_dense_tensors(grads)
-                dist.all_reduce(coalesced, op=dist.reduce_op.SUM)
-                coalesced /= dist.get_world_size()
-                for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
-                    buf.copy_(synced)
-
-    for param in list(module.parameters()):
-
-        def allreduce_hook(*_):
-            Variable._execution_engine.queue_callback(allreduce_params)  # pylint: disable=protected-access
-
-        if param.requires_grad:
-            param.register_hook(allreduce_hook)
-
-    def set_needs_reduction(self, *_):
-        self.needs_reduction = True
-
-    module.register_forward_hook(set_needs_reduction)
-    return module

@ -53,7 +53,6 @@ def get_commit_hash():
# Not copying .git folder into docker container
except (subprocess.CalledProcessError, FileNotFoundError):
commit = "0000000"
print(" > Git Hash: {}".format(commit))
return commit

@ -62,7 +61,6 @@ def get_experiment_folder_path(root_path, model_name):
date_str = datetime.datetime.now().strftime("%B-%d-%Y_%I+%M%p")
commit_hash = get_commit_hash()
output_folder = os.path.join(root_path, model_name + "-" + date_str + "-" + commit_hash)
print(" > Experiment folder: {}".format(output_folder))
return output_folder

@ -3,7 +3,7 @@ import json
import os
import pickle as pickle_tts
import shutil
from typing import Any
from typing import Any, Callable, Dict, Union

import fsspec
import torch

@ -53,18 +53,23 @@ def copy_model_files(config: Coqpit, out_path, new_fields):
shutil.copyfileobj(source_file, target_file)

def load_fsspec(path: str, **kwargs) -> Any:
def load_fsspec(
path: str,
map_location: Union[str, Callable, torch.device, Dict[Union[str, torch.device], Union[str, torch.device]]] = None,
**kwargs,
) -> Any:
"""Like torch.load but can load from other locations (e.g. s3:// , gs://).

Args:
path: Any path or url supported by fsspec.
map_location: torch.device or str.
**kwargs: Keyword arguments forwarded to torch.load.

Returns:
Object stored in path.
"""
with fsspec.open(path, "rb") as f:
return torch.load(f, **kwargs)
return torch.load(f, map_location=map_location, **kwargs)

def load_checkpoint(model, checkpoint_path, use_cuda=False, eval=False):  # pylint: disable=redefined-builtin
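A minimal usage sketch of the extended `load_fsspec` signature shown above; the remote checkpoint path is a placeholder:

```python
from TTS.utils.io import load_fsspec

# Load a checkpoint from any fsspec-supported location and map tensors to CPU.
# "s3://my-bucket/best_model.pth" is a hypothetical path used only for illustration.
checkpoint = load_fsspec("s3://my-bucket/best_model.pth", map_location="cpu")
```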

@ -3,7 +3,7 @@ from TTS.utils.logging.tensorboard_logger import TensorboardLogger
from TTS.utils.logging.wandb_logger import WandbLogger

def init_logger(config):
def init_dashboard_logger(config):
if config.dashboard_logger == "tensorboard":
dashboard_logger = TensorboardLogger(config.output_log_path, model_name=config.model)

@ -29,11 +29,13 @@ class ConsoleLogger:
now = datetime.datetime.now()
return now.strftime("%Y-%m-%d %H:%M:%S")

def print_epoch_start(self, epoch, max_epoch):
def print_epoch_start(self, epoch, max_epoch, output_path=None):
print(
"\n{}{} > EPOCH: {}/{}{}".format(tcolors.UNDERLINE, tcolors.BOLD, epoch, max_epoch, tcolors.ENDC),
flush=True,
)
if output_path is not None:
print(f" --> {output_path}")

def print_train_start(self):
print(f"\n{tcolors.BOLD} > TRAINING ({self.get_time()}) {tcolors.ENDC}")

@ -110,7 +110,7 @@ class ModelManager(object):
os.makedirs(output_path, exist_ok=True)
print(f" > Downloading model to {output_path}")
output_stats_path = os.path.join(output_path, "scale_stats.npy")
output_speakers_path = os.path.join(output_path, "speakers.json")

# download files to the output path
if self._check_dict_key(model_item, "github_rls_url"):
# download from github release

@ -122,22 +122,52 @@ class ModelManager(object):
if self._check_dict_key(model_item, "stats_file"):
self._download_gdrive_file(model_item["stats_file"], output_stats_path)

# update the scale_path.npy file path in the model config.json
if self._check_dict_key(model_item, "stats_file") or os.path.exists(output_stats_path):
# set scale stats path in config.json
config_path = output_config_path
config = load_config(config_path)
config.audio.stats_path = output_stats_path
config.save_json(config_path)
# update the speakers.json file path in the model config.json to the current path
if os.path.exists(output_speakers_path):
# set scale stats path in config.json
config_path = output_config_path
config = load_config(config_path)
config.d_vector_file = output_speakers_path
config.save_json(config_path)
# update paths in the config.json
self._update_paths(output_path, output_config_path)
return output_model_path, output_config_path, model_item

def _update_paths(self, output_path: str, config_path: str) -> None:
"""Update paths for certain files in config.json after download.

Args:
output_path (str): local path the model is downloaded to.
config_path (str): local config.json path.
"""
output_stats_path = os.path.join(output_path, "scale_stats.npy")
output_d_vector_file_path = os.path.join(output_path, "speakers.json")
output_speaker_ids_file_path = os.path.join(output_path, "speaker_ids.json")

# update the scale_path.npy file path in the model config.json
self._update_path("audio.stats_path", output_stats_path, config_path)

# update the speakers.json file path in the model config.json to the current path
self._update_path("d_vector_file", output_d_vector_file_path, config_path)
self._update_path("model_args.d_vector_file", output_d_vector_file_path, config_path)

# update the speaker_ids.json file path in the model config.json to the current path
self._update_path("speakers_file", output_speaker_ids_file_path, config_path)
self._update_path("model_args.speakers_file", output_speaker_ids_file_path, config_path)

@staticmethod
def _update_path(field_name, new_path, config_path):
"""Update the path in the model config.json for the current environment after download"""
if os.path.exists(new_path):
config = load_config(config_path)
field_names = field_name.split(".")
if len(field_names) > 1:
# field name points to a sub-level field
sub_conf = config
for fd in field_names[:-1]:
if fd in sub_conf:
sub_conf = sub_conf[fd]
else:
return
sub_conf[field_names[-1]] = new_path
else:
# field name points to a top-level field
config[field_name] = new_path
config.save_json(config_path)

def _download_gdrive_file(self, gdrive_idx, output):
"""Download files from GDrive using their file ids"""
gdown.download(f"{self.url_prefix}{gdrive_idx}", output=output, quiet=False)
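For illustration, the dotted-key logic in `_update_path` amounts to the following sketch, with a plain dict standing in for the Coqpit config and a hypothetical path:

```python
def set_by_dotted_key(config, field_name, new_path):
    # Walk e.g. "model_args.d_vector_file" down to the nested entry and overwrite it.
    keys = field_name.split(".")
    sub = config
    for key in keys[:-1]:
        if key not in sub:
            return  # the model config does not define this field; leave it untouched
        sub = sub[key]
    sub[keys[-1]] = new_path

cfg = {"model_args": {"d_vector_file": None}}
set_by_dotted_key(cfg, "model_args.d_vector_file", "/tmp/tts_model/speakers.json")
```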

@ -251,7 +251,7 @@ class Synthesizer(object):
d_vector=speaker_embedding,
)
waveform = outputs["wav"]
mel_postnet_spec = outputs["outputs"]["model_outputs"][0].detach().numpy()
mel_postnet_spec = outputs["outputs"]["model_outputs"][0].detach().cpu().numpy()
if not use_gl:
# denormalize tts output based on tts audio config
mel_postnet_spec = self.ap.denormalize(mel_postnet_spec.T).T
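The added `.cpu()` call matters when synthesis runs on GPU, since `.numpy()` only works on host-memory tensors. A standalone illustration:

```python
import torch

# .numpy() raises a TypeError on CUDA tensors, so copy to host memory first.
device = "cuda" if torch.cuda.is_available() else "cpu"
spec = torch.randn(80, 120, device=device)
spec_np = spec.detach().cpu().numpy()
```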

@ -1,5 +1,5 @@
import importlib
from typing import Dict, List
from typing import Dict, List, Tuple

import torch

@ -10,9 +10,20 @@ def is_apex_available():
return importlib.util.find_spec("apex") is not None

def setup_torch_training_env(cudnn_enable, cudnn_benchmark):
def setup_torch_training_env(cudnn_enable: bool, cudnn_benchmark: bool, use_ddp: bool = False) -> Tuple[bool, int]:
"""Setup PyTorch environment for training.

Args:
cudnn_enable (bool): Enable/disable CUDNN.
cudnn_benchmark (bool): Enable/disable CUDNN benchmarking. Better to set to False if input sequence length is
variable between batches.
use_ddp (bool): DDP flag. True if DDP is enabled, False otherwise.

Returns:
Tuple[bool, int]: is cuda on or off and number of GPUs in the environment.
"""
num_gpus = torch.cuda.device_count()
if num_gpus > 1:
if num_gpus > 1 and not use_ddp:
raise RuntimeError(
f" [!] {num_gpus} active GPUs. Define the target GPU by `CUDA_VISIBLE_DEVICES`. For multi-gpu training use `TTS/bin/distribute.py`."
)
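A hedged usage sketch of the new `use_ddp` flag; the import path assumes the function lives in `TTS.utils.trainer_utils`, which the surrounding imports suggest but this hunk does not confirm:

```python
from TTS.utils.trainer_utils import setup_torch_training_env  # assumed module path

# With use_ddp=True, having several visible GPUs no longer raises; each DDP
# worker is still expected to pin itself to the device chosen by the launcher.
use_cuda, num_gpus = setup_torch_training_env(cudnn_enable=True, cudnn_benchmark=False, use_ddp=True)
```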

@ -1,4 +1,5 @@
from dataclasses import dataclass, field
from typing import Dict

from TTS.vocoder.configs.shared_configs import BaseGANVocoderConfig

@ -95,7 +96,7 @@ class UnivnetConfig(BaseGANVocoderConfig):
# model specific params
discriminator_model: str = "univnet_discriminator"
generator_model: str = "univnet_generator"
generator_model_params: dict = field(
generator_model_params: Dict = field(
default_factory=lambda: {
"in_channels": 64,
"out_channels": 1,

@ -120,7 +121,7 @@ class UnivnetConfig(BaseGANVocoderConfig):

# loss weights - overrides
stft_loss_weight: float = 2.5
stft_loss_params: dict = field(
stft_loss_params: Dict = field(
default_factory=lambda: {
"n_ffts": [1024, 2048, 512],
"hop_lengths": [120, 240, 50],

@ -132,7 +133,7 @@ class UnivnetConfig(BaseGANVocoderConfig):
hinge_G_loss_weight: float = 0
feat_match_loss_weight: float = 0
l1_spec_loss_weight: float = 0
l1_spec_loss_params: dict = field(
l1_spec_loss_params: Dict = field(
default_factory=lambda: {
"use_mel": True,
"sample_rate": 22050,

@ -152,7 +153,7 @@ class UnivnetConfig(BaseGANVocoderConfig):
# lr_scheduler_gen_params: dict = field(default_factory=lambda: {"gamma": 0.999, "last_epoch": -1})
lr_scheduler_disc: str = None  # one of the schedulers from https:#pytorch.org/docs/stable/optim.html
# lr_scheduler_disc_params: dict = field(default_factory=lambda: {"gamma": 0.999, "last_epoch": -1})
optimizer_params: dict = field(default_factory=lambda: {"betas": [0.5, 0.9], "weight_decay": 0.0})
optimizer_params: Dict = field(default_factory=lambda: {"betas": [0.5, 0.9], "weight_decay": 0.0})
steps_to_start_discriminator: int = 200000

def __post_init__(self):

@ -1,6 +1,6 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import nn
from torch.nn.utils import weight_norm

@ -1,8 +1,8 @@
# adopted from https://github.com/jik876/hifi-gan/blob/master/models.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import nn
from torch.nn import Conv1d, ConvTranspose1d
from torch.nn import functional as F
from torch.nn.utils import remove_weight_norm, weight_norm

from TTS.utils.io import load_fsspec

@ -40,8 +40,8 @@ class MelganMultiscaleDiscriminator(nn.Module):
)

def forward(self, x):
scores = list()
feats = list()
scores = []
feats = []
for disc in self.discriminators:
score, feat = disc(x)
scores.append(score)

@ -1,6 +1,6 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import nn
from torch.nn.utils import spectral_norm, weight_norm

from TTS.utils.audio import TorchSTFT

@ -9,12 +9,12 @@ from torch.nn.utils import weight_norm
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from TTS.model import BaseModel
from TTS.utils.audio import AudioProcessor
from TTS.utils.io import load_fsspec
from TTS.utils.trainer_utils import get_optimizer, get_scheduler
from TTS.vocoder.datasets import WaveGradDataset
from TTS.vocoder.layers.wavegrad import Conv1d, DBlock, FiLM, UBlock
from TTS.vocoder.models.base_vocoder import BaseVocoder
from TTS.vocoder.utils.generic_utils import plot_results

@ -33,7 +33,7 @@ class WavegradArgs(Coqpit):
)

class Wavegrad(BaseModel):
class Wavegrad(BaseVocoder):
"""🐸 🌊 WaveGrad 🌊 model.
Paper - https://arxiv.org/abs/2009.00713

@ -257,14 +257,18 @@ class Wavegrad(BaseModel):
loss = criterion(noise, noise_hat)
return {"model_output": noise_hat}, {"loss": loss}

def train_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]:
def train_log(  # pylint: disable=no-self-use
self, ap: AudioProcessor, batch: Dict, outputs: Dict  # pylint: disable=unused-argument
) -> Tuple[Dict, np.ndarray]:
return None, None

@torch.no_grad()
def eval_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]:
return self.train_step(batch, criterion)

def eval_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]:
def eval_log(  # pylint: disable=no-self-use
self, ap: AudioProcessor, batch: Dict, outputs: Dict  # pylint: disable=unused-argument
) -> Tuple[Dict, np.ndarray]:
return None, None

def test_run(self, ap: AudioProcessor, samples: List[Dict], ouputs: Dict):  # pylint: disable=unused-argument

@ -291,7 +295,8 @@ class Wavegrad(BaseModel):
def get_scheduler(self, optimizer):
return get_scheduler(self.config.lr_scheduler, self.config.lr_scheduler_params, optimizer)

def get_criterion(self):
@staticmethod
def get_criterion():
return torch.nn.L1Loss()

@staticmethod

@ -5,9 +5,9 @@ from typing import Dict, List, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from coqpit import Coqpit
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

@ -3,28 +3,24 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [
"%matplotlib inline\n",
"\n",
"from TTS.utils.audio import AudioProcessor\n",
"from TTS.tts.utils.visual import plot_spectrogram\n",
"from TTS.utils.io import load_config\n",
"from TTS.config import load_config\n",
"\n",
"import IPython.display as ipd\n",
"import glob"
]
],
"outputs": [],
"metadata": {
"Collapsed": "false"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [
"config_path = \"/home/erogol/gdrive/Projects/TTS/recipes/ljspeech/align_tts/config_transformer2.json\"\n",
"data_path = \"/home/erogol/gdrive/Datasets/LJSpeech-1.1/\"\n",

@ -39,28 +35,28 @@
"\n",
"print(\"File list, by index:\")\n",
"dict(enumerate(file_paths))"
]
],
"outputs": [],
"metadata": {
"Collapsed": "false"
}
},
{
"cell_type": "markdown",
"metadata": {
"Collapsed": "false"
},
"source": [
"### Setup Audio Processor\n",
"Play with the AP parameters until you find a good fit with the synthesis speech below.\n",
"\n",
"The default values are loaded from your config.json file, so you only need to\n",
"uncomment and modify values below that you'd like to tune."
]
],
"metadata": {
"Collapsed": "false"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [
"tune_params={\n",
"# 'audio_processor': 'audio',\n",

@ -95,54 +91,54 @@
"tuned_config.update(tune_params)\n",
"\n",
"AP = AudioProcessor(**tuned_config);"
]
],
"outputs": [],
"metadata": {
"Collapsed": "false"
}
},
{
"cell_type": "markdown",
"metadata": {
"Collapsed": "false"
},
"source": [
"### Check audio loading "
]
],
"metadata": {
"Collapsed": "false"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [
"wav = AP.load_wav(SAMPLE_FILE_PATH)\n",
"ipd.Audio(data=wav, rate=AP.sample_rate) "
]
],
"outputs": [],
"metadata": {
"Collapsed": "false"
}
},
{
"cell_type": "markdown",
"metadata": {
"Collapsed": "false"
},
"source": [
"### Generate Mel-Spectrogram and Re-synthesis with GL"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"AP.power = 1.5"
]
},
{
"cell_type": "code",
"execution_count": null,
],
"metadata": {
"Collapsed": "false"
},
}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"AP.power = 1.5"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"mel = AP.melspectrogram(wav)\n",
"print(\"Max:\", mel.max())\n",

@ -152,24 +148,24 @@
"\n",
"wav_gen = AP.inv_melspectrogram(mel)\n",
"ipd.Audio(wav_gen, rate=AP.sample_rate)"
]
],
"outputs": [],
"metadata": {
"Collapsed": "false"
}
},
{
"cell_type": "markdown",
"metadata": {
"Collapsed": "false"
},
"source": [
"### Generate Linear-Spectrogram and Re-synthesis with GL"
]
],
"metadata": {
"Collapsed": "false"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [
"spec = AP.spectrogram(wav)\n",
"print(\"Max:\", spec.max())\n",

@ -179,26 +175,26 @@
"\n",
"wav_gen = AP.inv_spectrogram(spec)\n",
"ipd.Audio(wav_gen, rate=AP.sample_rate)"
]
],
"outputs": [],
"metadata": {
"Collapsed": "false"
}
},
{
"cell_type": "markdown",
"metadata": {
"Collapsed": "false"
},
"source": [
"### Compare values for a certain parameter\n",
"\n",
"Optimize your parameters by comparing different values per parameter at a time."
]
],
"metadata": {
"Collapsed": "false"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [
"from librosa import display\n",
"from matplotlib import pylab as plt\n",

@ -238,36 +234,39 @@
"    val = values[idx]\n",
"    print(\" > {} = {}\".format(attribute, val))\n",
"    IPython.display.display(IPython.display.Audio(wav_gen, rate=AP.sample_rate))"
]
],
"outputs": [],
"metadata": {
"Collapsed": "false"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [
"compare_values(\"preemphasis\", [0, 0.5, 0.97, 0.98, 0.99])"
]
],
"outputs": [],
"metadata": {
"Collapsed": "false"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [
"compare_values(\"ref_level_db\", [2, 5, 10, 15, 20, 25, 30, 35, 40, 1000])"
]
],
"outputs": [],
"metadata": {
"Collapsed": "false"
}
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
"name": "python3",
"display_name": "Python 3.8.5 64-bit ('torch': conda)"
},
"language_info": {
"codemirror_mode": {

@ -280,8 +279,11 @@
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
},
"interpreter": {
"hash": "27648abe09795c3a768a281b31f7524fcf66a207e733f8ecda3a4e1fd1059fb0"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
}

@ -43,7 +43,7 @@ def process_meta_data(path):
meta_data = {}

# load meta data
with open(path, "r") as f:
with open(path, "r", encoding="utf-8") as f:
data = csv.reader(f, delimiter="|")
for row in data:
frames = int(row[2])

@ -92,7 +92,7 @@ def save_training(file_path, meta_data):
rows.append(d["row"] + "\n")

random.shuffle(rows)
with open(file_path, "w+") as f:
with open(file_path, "w+", encoding="utf-8") as f:
for row in rows:
f.write(row)

@ -156,7 +156,7 @@ def plot_phonemes(train_path, cmu_dict_path, save_path):

phonemes = {}

with open(train_path, "r") as f:
with open(train_path, "r", encoding="utf-8") as f:
data = csv.reader(f, delimiter="|")
phonemes["None"] = 0
for row in data:

@ -174,9 +174,9 @@ def plot_phonemes(train_path, cmu_dict_path, save_path):
phonemes["None"] += 1

x, y = [], []
for key in phonemes:
x.append(key)
y.append(phonemes[key])
for k, v in phonemes.items():
x.append(k)
y.append(v)

plt.figure()
plt.rcParams["figure.figsize"] = (50, 20)

@ -29,7 +29,7 @@ config = VitsConfig(
run_name="vits_ljspeech",
batch_size=48,
eval_batch_size=16,
batch_group_size=0,
batch_group_size=5,
num_loader_workers=4,
num_eval_loader_workers=4,
run_eval=True,

@ -43,7 +43,7 @@ config = VitsConfig(
print_step=25,
print_eval=True,
mixed_precision=True,
max_seq_len=5000,
max_seq_len=500000,
output_path=output_path,
datasets=[dataset_config],
)

@ -2,4 +2,4 @@ black
coverage
isort
nose
pylint==2.8.3
pylint==2.10.2

@ -124,7 +124,7 @@ class TestTTSDataset(unittest.TestCase):

avg_length = mel_lengths.numpy().mean()
assert avg_length >= last_length
dataloader.dataset.sort_items()
dataloader.dataset.sort_and_filter_items()
is_items_reordered = False
for idx, item in enumerate(dataloader.dataset.items):
if item != frames[idx]: