Make style

Eren Gölge 2022-02-20 12:37:27 +01:00
parent fc8264d9d2
commit 424d04e4f6
18 changed files with 103 additions and 111 deletions
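The hunks below are almost entirely mechanical formatting changes. For orientation, here is a minimal sketch of the kind of automation a `make style` target typically wraps, black for layout plus isort for import ordering; the target directories and flags are assumptions, not read from the repo's Makefile:

```python
# Hypothetical "make style" helper: format with black, then order imports
# with isort. Assumes both tools are installed.
import subprocess

TARGET_DIRS = ["TTS", "tests", "recipes"]  # assumed; check the real Makefile

def make_style(check_only: bool = False) -> None:
    """Apply (or, with check_only=True, just verify) the project style."""
    black = ["black"] + (["--check"] if check_only else []) + TARGET_DIRS
    isort = ["isort"] + (["--check-only"] if check_only else []) + TARGET_DIRS
    subprocess.run(black, check=True)  # wrapping, trailing commas, kwarg spacing
    subprocess.run(isort, check=True)  # import grouping and alphabetical order

if __name__ == "__main__":
    make_style(check_only=True)  # exits non-zero if formatting has drifted
```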

View File

@@ -7,7 +7,6 @@ import subprocess
import time
import torch
from trainer import TrainerArgs

View File

@@ -8,7 +8,6 @@ import traceback
import torch
from torch.utils.data import DataLoader
from trainer.torch import NoamLR
from TTS.speaker_encoder.dataset import SpeakerEncoderDataset

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass, field
import os
from dataclasses import dataclass, field
from trainer import Trainer, TrainerArgs

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass, field
import os
from dataclasses import dataclass, field
from trainer import Trainer, TrainerArgs

View File

@@ -5,11 +5,11 @@ import torch
from coqpit import Coqpit
from torch import nn
# pylint: skip-file
class BaseTrainerModel(ABC, nn.Module):
"""Abstract 🐸TTS class. Every new 🐸TTS model must inherit this.
"""
"""Abstract 🐸TTS class. Every new 🐸TTS model must inherit this."""
@staticmethod
@abstractmethod
@@ -63,7 +63,7 @@ class BaseTrainerModel(ABC, nn.Module):
"""
return batch
def format_batch_on_device(self, batch:Dict) -> Dict:
def format_batch_on_device(self, batch: Dict) -> Dict:
"""Format batch on device before sending it to the model.
If not implemented, model uses the batch as is.
@@ -124,7 +124,6 @@ class BaseTrainerModel(ABC, nn.Module):
"""The same as `train_log()`"""
...
@abstractmethod
def load_checkpoint(self, config: Coqpit, checkpoint_path: str, eval: bool = False, strict: bool = True) -> None:
"""Load a checkpoint and get ready for training or inference.
@@ -148,13 +147,8 @@ class BaseTrainerModel(ABC, nn.Module):
@abstractmethod
def get_data_loader(
self,
config: Coqpit,
assets: Dict,
is_eval: True,
data_items: List,
verbose: bool,
num_gpus: int):
self, config: Coqpit, assets: Dict, is_eval: True, data_items: List, verbose: bool, num_gpus: int
):
...
# def get_optimizer(self) -> Union["Optimizer", List["Optimizer"]]:
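The reformatted `format_batch_on_device` above is a hook subclasses override to prepare a batch once it is already on the target device. A minimal, hypothetical sketch of the pattern (class name, batch keys, and mask logic are illustrative, not from the repo):

```python
from typing import Dict

import torch

class DemoModel:  # stands in for a concrete BaseTrainerModel subclass
    def format_batch_on_device(self, batch: Dict) -> Dict:
        # Build a padding mask on whatever device the batch already lives
        # on, so DataLoader workers never have to touch the GPU.
        lengths = batch["lengths"]
        max_len = batch["waveform"].shape[1]
        positions = torch.arange(max_len, device=lengths.device)
        batch["mask"] = positions[None, :] < lengths[:, None]
        return batch

batch = {"waveform": torch.randn(2, 100), "lengths": torch.tensor([100, 60])}
print(DemoModel().format_batch_on_device(batch)["mask"].shape)  # (2, 100)
```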

View File

@@ -1,16 +1,15 @@
from asyncio.log import logger
from dataclasses import dataclass, field
import os
from dataclasses import dataclass, field
from coqpit import Coqpit
from trainer import TrainerArgs
from trainer.logging import logger_factory
from trainer.logging.console_logger import ConsoleLogger
from TTS.config import load_config, register_config
from trainer import TrainerArgs
from TTS.tts.utils.text.characters import parse_symbols
from TTS.utils.generic_utils import get_experiment_folder_path, get_git_branch
from TTS.utils.io import copy_model_files
from trainer.logging import logger_factory
from trainer.logging.console_logger import ConsoleLogger
from TTS.utils.trainer_utils import get_last_checkpoint

View File

@@ -769,5 +769,3 @@ class F0Dataset:
print("\n")
print(f"{indent}> F0Dataset ")
print(f"{indent}| > Number of instances : {len(self.samples)}")

View File

@@ -672,7 +672,9 @@ class VitsDiscriminatorLoss(nn.Module):
def forward(self, scores_disc_real, scores_disc_fake):
loss = 0.0
return_dict = {}
loss_disc, loss_disc_real, _ = self.discriminator_loss(scores_real=scores_disc_real, scores_fake=scores_disc_fake)
loss_disc, loss_disc_real, _ = self.discriminator_loss(
scores_real=scores_disc_real, scores_fake=scores_disc_fake
)
return_dict["loss_disc"] = loss_disc * self.disc_loss_alpha
loss = loss + return_dict["loss_disc"]
return_dict["loss"] = loss

View File

@@ -26,8 +26,12 @@ class BaseTTS(BaseTrainerModel):
"""
def __init__(
self, config: Coqpit, ap: "AudioProcessor", tokenizer: "TTSTokenizer", speaker_manager: SpeakerManager = None,
language_manager: LanguageManager = None
self,
config: Coqpit,
ap: "AudioProcessor",
tokenizer: "TTSTokenizer",
speaker_manager: SpeakerManager = None,
language_manager: LanguageManager = None,
):
super().__init__()
self.config = config

View File

@@ -530,7 +530,8 @@ class GlowTTS(BaseTTS):
self.store_inverse()
assert not self.training
def get_criterion(self):
@staticmethod
def get_criterion():
from TTS.tts.layers.losses import GlowTTSLoss # pylint: disable=import-outside-toplevel
return GlowTTSLoss()
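Turning `get_criterion` into a `@staticmethod` is the usual fix when a factory never reads `self`. An illustrative (non-repo) example of the pattern, including the deferred import that avoids a circular import at module load time:

```python
import torch

class DemoModel:
    @staticmethod
    def get_criterion():
        # Deferred import, as in the hunk above: importing the losses
        # module at the top level would create an import cycle.
        from torch import nn  # pylint: disable=import-outside-toplevel
        return nn.MSELoss()

criterion = DemoModel.get_criterion()  # callable without an instance
print(criterion(torch.zeros(2), torch.ones(2)))  # tensor(1.)
```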

View File

@@ -1,9 +1,8 @@
import collections
import math
import os
from dataclasses import dataclass, field, replace
from itertools import chain
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, List, Tuple, Union
import torch
import torch.distributed as dist
@@ -25,7 +24,7 @@ from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, segment, sequence_mask
from TTS.tts.utils.languages import LanguageManager, get_language_weighted_sampler
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager, get_speaker_weighted_sampler
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_weighted_sampler
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text.characters import BaseCharacters, _characters, _pad, _phonemes, _punctuations
from TTS.tts.utils.text.tokenizer import TTSTokenizer
@@ -38,6 +37,7 @@ from TTS.vocoder.utils.generic_utils import plot_results
# IO / Feature extraction
##############################
# pylint: disable=global-statement
hann_window = {}
mel_basis = {}
@@ -200,7 +200,7 @@ class VitsDataset(TTSDataset):
text, wav_file, speaker_name, language_name, _ = _parse_sample(item)
raw_text = text
wav, sr = load_audio(wav_file)
wav, _ = load_audio(wav_file)
wav_filename = os.path.basename(wav_file)
token_ids = self.get_token_ids(idx, text)
@@ -538,12 +538,14 @@ class Vits(BaseTTS):
>>> model = Vits(config)
"""
def __init__(self,
def __init__(
self,
config: Coqpit,
ap: "AudioProcessor" = None,
tokenizer: "TTSTokenizer" = None,
speaker_manager: SpeakerManager = None,
language_manager: LanguageManager = None,):
language_manager: LanguageManager = None,
):
super().__init__(config, ap, tokenizer, speaker_manager, language_manager)
@@ -673,9 +675,9 @@ class Vits(BaseTTS):
)
# pylint: disable=W0101,W0105
self.audio_transform = torchaudio.transforms.Resample(
orig_freq=self.config.audio.sample_rate,
new_freq=self.speaker_manager.speaker_encoder.audio_config["sample_rate"],
)
orig_freq=self.config.audio.sample_rate,
new_freq=self.speaker_manager.speaker_encoder.audio_config["sample_rate"],
)
def _init_speaker_embedding(self):
# pylint: disable=attribute-defined-outside-init
@@ -777,9 +779,9 @@ class Vits(BaseTTS):
with torch.no_grad():
o_scale = torch.exp(-2 * logs_p)
logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1]).unsqueeze(-1) # [b, t, 1]
logp2 = torch.einsum("klm, kln -> kmn", [o_scale, -0.5 * (z_p ** 2)])
logp2 = torch.einsum("klm, kln -> kmn", [o_scale, -0.5 * (z_p**2)])
logp3 = torch.einsum("klm, kln -> kmn", [m_p * o_scale, z_p])
logp4 = torch.sum(-0.5 * (m_p ** 2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1]
logp4 = torch.sum(-0.5 * (m_p**2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1]
logp = logp2 + logp3 + logp1 + logp4
attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach() # [b, 1, t, t']
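The `z_p ** 2` to `z_p**2` edits in this hunk come from black 22.x, which hugs the power operator when both operands are simple. Illustrative only:

```python
x = 3.0
y = x**2          # simple operands: black removes the spaces
z = (x + 1) ** 2  # compound operand: black keeps them
print(y, z)       # 9.0 16.0
```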
@@ -806,7 +808,7 @@ class Vits(BaseTTS):
outputs["loss_duration"] = loss_duration
return outputs, attn
def forward(
def forward( # pylint: disable=dangerous-default-value
self,
x: torch.tensor,
x_lengths: torch.tensor,
@@ -886,7 +888,7 @@ class Vits(BaseTTS):
waveform,
slice_ids * self.config.audio.hop_length,
self.args.spec_segment_size * self.config.audio.hop_length,
pad_short = True
pad_short=True,
)
if self.args.use_speaker_encoder_as_loss and self.speaker_manager.speaker_encoder is not None:
@@ -929,7 +931,9 @@ class Vits(BaseTTS):
return aux_input["x_lengths"]
return torch.tensor(x.shape[1:2]).to(x.device)
def inference(self, x, aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None}):
def inference(
self, x, aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None}
): # pylint: disable=dangerous-default-value
"""
Note:
To run in batch mode, provide `x_lengths` else model assumes that the batch size is 1.
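The `dangerous-default-value` disables added to `forward` and `inference` point at a real Python gotcha: a default value is evaluated once, at function definition, so a mutated default dict leaks state across calls. A self-contained illustration (not repo code); the dict default here is presumably safe because `aux_input` is only read, never mutated:

```python
def bad(x, cache={}):  # one dict shared by every call
    cache[x] = True
    return len(cache)

print(bad("a"), bad("b"))  # 1 2  (the dict persisted between calls)

def good(x, cache=None):  # the safe idiom
    cache = {} if cache is None else cache
    cache[x] = True
    return len(cache)

print(good("a"), good("b"))  # 1 1  (fresh dict per call)
```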
@@ -1023,7 +1027,6 @@ class Vits(BaseTTS):
o_hat = self.waveform_decoder(z_hat * y_mask, g=g_tgt)
return o_hat, y_mask, (z, z_p, z_hat)
def train_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]:
"""Perform a single training step. Run the model forward pass and compute losses.
@@ -1062,7 +1065,7 @@ class Vits(BaseTTS):
)
# cache tensors for the generator pass
self.model_outputs_cache = outputs
self.model_outputs_cache = outputs # pylint: disable=attribute-defined-outside-init
# compute scores and features
scores_disc_fake, _, scores_disc_real, _ = self.disc(
@@ -1082,14 +1085,16 @@ class Vits(BaseTTS):
# compute melspec segment
with autocast(enabled=False):
mel_slice = segment(mel.float(), self.model_outputs_cache["slice_ids"], self.spec_segment_size, pad_short=True)
mel_slice = segment(
mel.float(), self.model_outputs_cache["slice_ids"], self.spec_segment_size, pad_short=True
)
mel_slice_hat = wav_to_mel(
y = self.model_outputs_cache["model_outputs"].float(),
n_fft = self.config.audio.fft_size,
sample_rate = self.config.audio.sample_rate,
num_mels = self.config.audio.num_mels,
hop_length = self.config.audio.hop_length,
win_length = self.config.audio.win_length,
y=self.model_outputs_cache["model_outputs"].float(),
n_fft=self.config.audio.fft_size,
sample_rate=self.config.audio.sample_rate,
num_mels=self.config.audio.num_mels,
hop_length=self.config.audio.hop_length,
win_length=self.config.audio.win_length,
fmin=self.config.audio.mel_fmin,
fmax=self.config.audio.mel_fmax,
center=False,
@@ -1097,7 +1102,7 @@ class Vits(BaseTTS):
# compute discriminator scores and features
scores_disc_fake, feats_disc_fake, _, feats_disc_real = self.disc(
self.model_outputs_cache["model_outputs"], self.model_outputs_cache["waveform_seg"]
self.model_outputs_cache["model_outputs"], self.model_outputs_cache["waveform_seg"]
)
# compute losses
@@ -1105,18 +1110,18 @@ class Vits(BaseTTS):
loss_dict = criterion[optimizer_idx](
mel_slice_hat=mel_slice.float(),
mel_slice=mel_slice_hat.float(),
z_p= self.model_outputs_cache["z_p"].float(),
logs_q= self.model_outputs_cache["logs_q"].float(),
m_p= self.model_outputs_cache["m_p"].float(),
logs_p= self.model_outputs_cache["logs_p"].float(),
z_p=self.model_outputs_cache["z_p"].float(),
logs_q=self.model_outputs_cache["logs_q"].float(),
m_p=self.model_outputs_cache["m_p"].float(),
logs_p=self.model_outputs_cache["logs_p"].float(),
z_len=mel_lens,
scores_disc_fake= scores_disc_fake,
feats_disc_fake= feats_disc_fake,
feats_disc_real= feats_disc_real,
loss_duration= self.model_outputs_cache["loss_duration"],
scores_disc_fake=scores_disc_fake,
feats_disc_fake=feats_disc_fake,
feats_disc_real=feats_disc_real,
loss_duration=self.model_outputs_cache["loss_duration"],
use_speaker_encoder_as_loss=self.args.use_speaker_encoder_as_loss,
gt_spk_emb= self.model_outputs_cache["gt_spk_emb"],
syn_spk_emb= self.model_outputs_cache["syn_spk_emb"],
gt_spk_emb=self.model_outputs_cache["gt_spk_emb"],
syn_spk_emb=self.model_outputs_cache["syn_spk_emb"],
)
return self.model_outputs_cache, loss_dict
@@ -1248,7 +1253,9 @@ class Vits(BaseTTS):
test_figures["{}-alignment".format(idx)] = plot_alignment(alignment.T, output_fig=False)
return {"figures": test_figures, "audios": test_audios}
def test_log(self, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None:
def test_log(
self, outputs: dict, logger: "Logger", assets: dict, steps: int # pylint: disable=unused-argument
) -> None:
logger.test_audios(steps, outputs["audios"], self.ap.sample_rate)
logger.test_figures(steps, outputs["figures"])
@@ -1273,7 +1280,11 @@ class Vits(BaseTTS):
d_vectors = torch.FloatTensor(d_vectors)
# get language ids from language names
if self.language_manager is not None and self.language_manager.language_id_mapping and self.args.use_language_embedding:
if (
self.language_manager is not None
and self.language_manager.language_id_mapping
and self.args.use_language_embedding
):
language_ids = [self.language_manager.language_id_mapping[ln] for ln in batch["language_names"]]
if language_ids is not None:
@@ -1289,16 +1300,14 @@ class Vits(BaseTTS):
ac = self.config.audio
# compute spectrograms
batch["spec"] = wav_to_spec(
batch["waveform"], ac.fft_size, ac.hop_length, ac.win_length, center=False
)
batch["spec"] = wav_to_spec(batch["waveform"], ac.fft_size, ac.hop_length, ac.win_length, center=False)
batch["mel"] = spec_to_mel(
spec = batch["spec"],
n_fft = ac.fft_size,
num_mels = ac.num_mels,
sample_rate = ac.sample_rate,
fmin = ac.mel_fmin,
fmax = ac.mel_fmax,
spec=batch["spec"],
n_fft=ac.fft_size,
num_mels=ac.num_mels,
sample_rate=ac.sample_rate,
fmin=ac.mel_fmin,
fmax=ac.mel_fmax,
)
assert batch["spec"].shape[2] == batch["mel"].shape[2], f"{batch['spec'].shape[2]}, {batch['mel'].shape[2]}"
@@ -1325,27 +1334,6 @@ class Vits(BaseTTS):
if is_eval and not config.run_eval:
loader = None
else:
# setup multi-speaker attributes
speaker_id_mapping = None
d_vector_mapping = None
if hasattr(self, "speaker_manager") and self.speaker_manager is not None:
if hasattr(config, "model_args"):
speaker_id_mapping = (
self.speaker_manager.speaker_ids if config.model_args.use_speaker_embedding else None
)
d_vector_mapping = self.speaker_manager.d_vectors if config.model_args.use_d_vector_file else None
config.use_d_vector_file = config.model_args.use_d_vector_file
else:
speaker_id_mapping = self.speaker_manager.speaker_ids if config.use_speaker_embedding else None
d_vector_mapping = self.speaker_manager.d_vectors if config.use_d_vector_file else None
# setup multi-lingual attributes
language_id_mapping = None
if hasattr(self, "language_manager"):
language_id_mapping = (
self.language_manager.language_id_mapping if self.args.use_language_embedding else None
)
# init dataloader
dataset = VitsDataset(
samples=samples,
@@ -1495,6 +1483,7 @@ class Vits(BaseTTS):
language_manager = LanguageManager.init_from_config(config)
return Vits(new_config, ap, tokenizer, speaker_manager, language_manager)
##################################
# VITS CHARACTERS
##################################

View File

@@ -119,6 +119,7 @@ def rand_segments(
ret = segment(x, segment_indices, segment_size)
return ret, segment_indices
def average_over_durations(values, durs):
"""Average values over durations.

View File

@@ -1,4 +1,3 @@
from abc import ABC
from dataclasses import replace
from typing import Dict

View File

@@ -57,7 +57,7 @@ class Punctuation:
if not isinstance(value, six.string_types):
raise ValueError("[!] Punctuations must be of type str.")
self._puncs = "".join(list(dict.fromkeys(list(value)))) # remove duplicates without changing the oreder
self.puncs_regular_exp = re.compile(fr"(\s*[{re.escape(self._puncs)}]+\s*)+")
self.puncs_regular_exp = re.compile(rf"(\s*[{re.escape(self._puncs)}]+\s*)+")
def strip(self, text):
"""Remove all the punctuations by replacing with `space`.

View File

@@ -270,7 +270,7 @@ class Wavegrad(BaseVocoder):
) -> None:
pass
def test(self, assets: Dict, test_loader:"DataLoader", outputs=None): # pylint: disable=unused-argument
def test(self, assets: Dict, test_loader: "DataLoader", outputs=None): # pylint: disable=unused-argument
# setup noise schedule and inference
ap = assets["audio_processor"]
noise_schedule = self.config["test_noise_schedule"]
@@ -307,9 +307,7 @@ class Wavegrad(BaseVocoder):
y = y.unsqueeze(1)
return {"input": m, "waveform": y}
def get_data_loader(
self, config: Coqpit, assets: Dict, is_eval: True, samples: List, verbose: bool, num_gpus: int
):
def get_data_loader(self, config: Coqpit, assets: Dict, is_eval: True, samples: List, verbose: bool, num_gpus: int):
ap = assets["audio_processor"]
dataset = WaveGradDataset(
ap=ap,

View File

@@ -69,8 +69,8 @@ config = VitsConfig(
print_eval=False,
mixed_precision=False,
sort_by_audio_len=True,
min_seq_len=32 * 256 * 4,
max_seq_len=160000,
min_audio_len=32 * 256 * 4,
max_audio_len=160000,
output_path=output_path,
datasets=dataset_config,
characters={
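The recipe change above tracks a config-field rename: length filtering is now expressed in audio samples rather than sequence length. A hedged sketch, assuming a TTS version where `min_audio_len`/`max_audio_len` exist (values copied from the hunk, other fields omitted):

```python
from TTS.tts.configs.vits_config import VitsConfig

config = VitsConfig(
    sort_by_audio_len=True,
    min_audio_len=32 * 256 * 4,  # replaces the removed min_seq_len
    max_audio_len=160000,        # replaces the removed max_seq_len
)
```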

View File

@@ -4,6 +4,7 @@ import unittest
import torch
from torch import optim
from trainer.logging.tensorboard_logger import TensorboardLogger
from tests import get_tests_data_path, get_tests_input_path, get_tests_output_path
from TTS.tts.configs.glow_tts_config import GlowTTSConfig
@@ -11,7 +12,6 @@ from TTS.tts.layers.losses import GlowTTSLoss
from TTS.tts.models.glow_tts import GlowTTS
from TTS.tts.utils.speakers import SpeakerManager
from TTS.utils.audio import AudioProcessor
from trainer.logging.tensorboard_logger import TensorboardLogger
# pylint: disable=unused-variable

View File

@@ -3,15 +3,14 @@ import os
import unittest
import torch
from TTS.tts.datasets.formatters import ljspeech
from trainer.logging.tensorboard_logger import TensorboardLogger
from tests import assertHasAttr, assertHasNotAttr, get_tests_data_path, get_tests_input_path, get_tests_output_path
from TTS.config import load_config
from TTS.speaker_encoder.utils.generic_utils import setup_speaker_encoder_model
from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.models.vits import Vits, VitsArgs, load_audio, amp_to_db, db_to_amp, wav_to_spec, wav_to_mel, spec_to_mel, VitsDataset
from TTS.tts.models.vits import Vits, VitsArgs, amp_to_db, db_to_amp, load_audio, spec_to_mel, wav_to_mel, wav_to_spec
from TTS.tts.utils.speakers import SpeakerManager
from trainer.logging.tensorboard_logger import TensorboardLogger
LANG_FILE = os.path.join(get_tests_input_path(), "language_ids.json")
SPEAKER_ENCODER_CONFIG = os.path.join(get_tests_input_path(), "test_speaker_encoder_config.json")
@@ -31,7 +30,17 @@ class TestVits(unittest.TestCase):
self.assertEqual(sr, 22050)
spec = wav_to_spec(wav, n_fft=1024, hop_length=512, win_length=1024, center=False)
mel = wav_to_mel(wav, n_fft=1024, num_mels=80, sample_rate=sr, hop_length=512, win_length=1024, fmin=0, fmax=8000, center=False)
mel = wav_to_mel(
wav,
n_fft=1024,
num_mels=80,
sample_rate=sr,
hop_length=512,
win_length=1024,
fmin=0,
fmax=8000,
center=False,
)
mel2 = spec_to_mel(spec, n_fft=1024, num_mels=80, sample_rate=sr, fmin=0, fmax=8000)
self.assertEqual((mel - mel2).abs().max(), 0)
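The equality asserted above holds because a mel spectrogram is a fixed linear map (the mel filterbank) applied to the magnitude spectrogram, so computing it from the waveform or from a precomputed spectrogram is the same composition. A generic, self-contained sketch with demo names and a random stand-in filterbank; the repo's functions differ in detail:

```python
import torch

def wav_to_spec_demo(wav, n_fft=1024, hop=512):
    window = torch.hann_window(n_fft)
    stft = torch.stft(wav, n_fft, hop, window=window, center=False, return_complex=True)
    return stft.abs()

def spec_to_mel_demo(spec, mel_fb):
    return mel_fb @ spec  # linear map over frequency bins

def wav_to_mel_demo(wav, mel_fb, n_fft=1024, hop=512):
    return spec_to_mel_demo(wav_to_spec_demo(wav, n_fft, hop), mel_fb)

mel_fb = torch.rand(80, 1024 // 2 + 1)  # stand-in for a real filterbank
wav = torch.randn(1, 22050)
mel = wav_to_mel_demo(wav, mel_fb)
mel2 = spec_to_mel_demo(wav_to_spec_demo(wav), mel_fb)
print((mel - mel2).abs().max())  # tensor(0.)
```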
@@ -45,7 +54,7 @@ class TestVits(unittest.TestCase):
def test_dataset(self):
"""TODO:"""
...
...
def test_init_multispeaker(self):
num_speakers = 10
@@ -164,7 +173,7 @@ class TestVits(unittest.TestCase):
num_speakers = 0
config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True)
config.model_args.spec_segment_size = 10
input_dummy, input_lengths, mel, spec, spec_lengths, waveform = self._create_inputs(config)
input_dummy, input_lengths, _, spec, spec_lengths, waveform = self._create_inputs(config)
model = Vits(config).to(device)
output_dict = model.forward(input_dummy, input_lengths, spec, spec_lengths, waveform)
self._check_forward_outputs(config, output_dict)
@@ -175,7 +184,7 @@ class TestVits(unittest.TestCase):
config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True)
config.model_args.spec_segment_size = 10
input_dummy, input_lengths, mel, spec, spec_lengths, waveform = self._create_inputs(config)
input_dummy, input_lengths, _, spec, spec_lengths, waveform = self._create_inputs(config)
speaker_ids = torch.randint(0, num_speakers, (8,)).long().to(device)
model = Vits(config).to(device)
@@ -196,7 +205,7 @@ class TestVits(unittest.TestCase):
config = VitsConfig(model_args=args)
model = Vits.init_from_config(config, verbose=False).to(device)
model.train()
input_dummy, input_lengths, mel, spec, spec_lengths, waveform = self._create_inputs(config, batch_size=batch_size)
input_dummy, input_lengths, _, spec, spec_lengths, waveform = self._create_inputs(config, batch_size=batch_size)
d_vectors = torch.randn(batch_size, 256).to(device)
output_dict = model.forward(
input_dummy, input_lengths, spec, spec_lengths, waveform, aux_input={"d_vectors": d_vectors}
@@ -211,7 +220,7 @@ class TestVits(unittest.TestCase):
args = VitsArgs(language_ids_file=LANG_FILE, use_language_embedding=True, spec_segment_size=10)
config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True, model_args=args)
input_dummy, input_lengths, mel, spec, spec_lengths, waveform = self._create_inputs(config, batch_size=batch_size)
input_dummy, input_lengths, _, spec, spec_lengths, waveform = self._create_inputs(config, batch_size=batch_size)
speaker_ids = torch.randint(0, num_speakers, (batch_size,)).long().to(device)
lang_ids = torch.randint(0, num_langs, (batch_size,)).long().to(device)
@@ -246,7 +255,7 @@ class TestVits(unittest.TestCase):
config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True, model_args=args)
config.audio.sample_rate = 16000
input_dummy, input_lengths, mel, spec, spec_lengths, waveform = self._create_inputs(config, batch_size=batch_size)
input_dummy, input_lengths, _, spec, spec_lengths, waveform = self._create_inputs(config, batch_size=batch_size)
speaker_ids = torch.randint(0, num_speakers, (batch_size,)).long().to(device)
lang_ids = torch.randint(0, num_langs, (batch_size,)).long().to(device)