Make style

Eren Gölge 2022-02-20 12:37:27 +01:00
parent fc8264d9d2
commit 424d04e4f6
18 changed files with 103 additions and 111 deletions

View File

@@ -7,7 +7,6 @@ import subprocess
 import time
 import torch
 from trainer import TrainerArgs

View File

@@ -8,7 +8,6 @@ import traceback
 import torch
 from torch.utils.data import DataLoader
 from trainer.torch import NoamLR
 from TTS.speaker_encoder.dataset import SpeakerEncoderDataset

View File

@@ -1,5 +1,5 @@
-from dataclasses import dataclass, field
 import os
+from dataclasses import dataclass, field
 from trainer import Trainer, TrainerArgs

View File

@@ -1,5 +1,5 @@
-from dataclasses import dataclass, field
 import os
+from dataclasses import dataclass, field
 from trainer import Trainer, TrainerArgs

View File

@@ -5,11 +5,11 @@ import torch
 from coqpit import Coqpit
 from torch import nn
+# pylint: skip-file
 class BaseTrainerModel(ABC, nn.Module):
-    """Abstract 🐸TTS class. Every new 🐸TTS model must inherit this.
-    """
+    """Abstract 🐸TTS class. Every new 🐸TTS model must inherit this."""
     @staticmethod
     @abstractmethod
@@ -63,7 +63,7 @@ class BaseTrainerModel(ABC, nn.Module):
         """
         return batch
-    def format_batch_on_device(self, batch:Dict) -> Dict:
+    def format_batch_on_device(self, batch: Dict) -> Dict:
         """Format batch on device before sending it to the model.
         If not implemented, model uses the batch as is.
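For context, `format_batch_on_device` is the hook documented just above: if a model does not override it, the batch is used as is. A minimal sketch of an override follows, with a hypothetical subclass, tensor names, and import path that are not part of this commit:

from typing import Dict

import torch

from TTS.model import BaseTrainerModel  # import path assumed


class MyTTSModel(BaseTrainerModel):
    def format_batch_on_device(self, batch: Dict) -> Dict:
        # hypothetical example: derive an extra feature on the GPU right before the forward pass
        batch["log_mel"] = torch.log(batch["mel"].clamp(min=1e-5))
        return batch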
@@ -124,7 +124,6 @@ class BaseTrainerModel(ABC, nn.Module):
         """The same as `train_log()`"""
         ...
     @abstractmethod
     def load_checkpoint(self, config: Coqpit, checkpoint_path: str, eval: bool = False, strict: bool = True) -> None:
         """Load a checkpoint and get ready for training or inference.
@@ -148,13 +147,8 @@ class BaseTrainerModel(ABC, nn.Module):
     @abstractmethod
     def get_data_loader(
-        self,
-        config: Coqpit,
-        assets: Dict,
-        is_eval: True,
-        data_items: List,
-        verbose: bool,
-        num_gpus: int):
+        self, config: Coqpit, assets: Dict, is_eval: True, data_items: List, verbose: bool, num_gpus: int
+    ):
         ...
     # def get_optimizer(self) -> Union["Optimizer", List["Optimizer"]]:

View File

@@ -1,16 +1,15 @@
-from asyncio.log import logger
-from dataclasses import dataclass, field
 import os
+from dataclasses import dataclass, field
 from coqpit import Coqpit
+from trainer import TrainerArgs
+from trainer.logging import logger_factory
+from trainer.logging.console_logger import ConsoleLogger
 from TTS.config import load_config, register_config
-from trainer import TrainerArgs
 from TTS.tts.utils.text.characters import parse_symbols
 from TTS.utils.generic_utils import get_experiment_folder_path, get_git_branch
 from TTS.utils.io import copy_model_files
-from trainer.logging import logger_factory
-from trainer.logging.console_logger import ConsoleLogger
 from TTS.utils.trainer_utils import get_last_checkpoint
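The reshuffle above follows the usual isort-style grouping that a style pass enforces: standard-library imports first, third-party packages next, first-party `TTS` imports last, each group sorted alphabetically. A generic sketch of that ordering, for illustration only (not taken from this file):

import os  # standard library
from dataclasses import dataclass, field

import torch  # third-party packages
from trainer import TrainerArgs

from TTS.config import load_config  # first-party (project) imports come last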

View File

@@ -769,5 +769,3 @@ class F0Dataset:
         print("\n")
         print(f"{indent}> F0Dataset ")
         print(f"{indent}| > Number of instances : {len(self.samples)}")

View File

@@ -672,7 +672,9 @@ class VitsDiscriminatorLoss(nn.Module):
     def forward(self, scores_disc_real, scores_disc_fake):
         loss = 0.0
         return_dict = {}
-        loss_disc, loss_disc_real, _ = self.discriminator_loss(scores_real=scores_disc_real, scores_fake=scores_disc_fake)
+        loss_disc, loss_disc_real, _ = self.discriminator_loss(
+            scores_real=scores_disc_real, scores_fake=scores_disc_fake
+        )
         return_dict["loss_disc"] = loss_disc * self.disc_loss_alpha
         loss = loss + return_dict["loss_disc"]
         return_dict["loss"] = loss

View File

@@ -26,8 +26,12 @@ class BaseTTS(BaseTrainerModel):
     """
     def __init__(
-        self, config: Coqpit, ap: "AudioProcessor", tokenizer: "TTSTokenizer", speaker_manager: SpeakerManager = None,
-        language_manager: LanguageManager = None
+        self,
+        config: Coqpit,
+        ap: "AudioProcessor",
+        tokenizer: "TTSTokenizer",
+        speaker_manager: SpeakerManager = None,
+        language_manager: LanguageManager = None,
     ):
         super().__init__()
         self.config = config

View File

@@ -530,7 +530,8 @@ class GlowTTS(BaseTTS):
         self.store_inverse()
         assert not self.training
-    def get_criterion(self):
+    @staticmethod
+    def get_criterion():
         from TTS.tts.layers.losses import GlowTTSLoss  # pylint: disable=import-outside-toplevel
         return GlowTTSLoss()
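With `get_criterion` now a `@staticmethod`, it no longer takes `self`, so the criterion can be built straight from the class. A small usage sketch (assuming the package is importable; not part of the diff):

from TTS.tts.models.glow_tts import GlowTTS

criterion = GlowTTS.get_criterion()  # returns a GlowTTSLoss instance, no model object needed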

View File

@@ -1,9 +1,8 @@
-import collections
 import math
 import os
 from dataclasses import dataclass, field, replace
 from itertools import chain
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Tuple, Union
 import torch
 import torch.distributed as dist
@@ -25,7 +24,7 @@ from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDuration
 from TTS.tts.models.base_tts import BaseTTS
 from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, segment, sequence_mask
 from TTS.tts.utils.languages import LanguageManager, get_language_weighted_sampler
-from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager, get_speaker_weighted_sampler
+from TTS.tts.utils.speakers import SpeakerManager, get_speaker_weighted_sampler
 from TTS.tts.utils.synthesis import synthesis
 from TTS.tts.utils.text.characters import BaseCharacters, _characters, _pad, _phonemes, _punctuations
 from TTS.tts.utils.text.tokenizer import TTSTokenizer
@@ -38,6 +37,7 @@ from TTS.vocoder.utils.generic_utils import plot_results
 # IO / Feature extraction
 ##############################
+# pylint: disable=global-statement
 hann_window = {}
 mel_basis = {}
@@ -200,7 +200,7 @@ class VitsDataset(TTSDataset):
         text, wav_file, speaker_name, language_name, _ = _parse_sample(item)
         raw_text = text
-        wav, sr = load_audio(wav_file)
+        wav, _ = load_audio(wav_file)
         wav_filename = os.path.basename(wav_file)
         token_ids = self.get_token_ids(idx, text)
@@ -538,12 +538,14 @@ class Vits(BaseTTS):
         >>> model = Vits(config)
     """
-    def __init__(self,
+    def __init__(
+        self,
         config: Coqpit,
         ap: "AudioProcessor" = None,
         tokenizer: "TTSTokenizer" = None,
         speaker_manager: SpeakerManager = None,
-        language_manager: LanguageManager = None,):
+        language_manager: LanguageManager = None,
+    ):
         super().__init__(config, ap, tokenizer, speaker_manager, language_manager)
@@ -777,9 +779,9 @@ class Vits(BaseTTS):
         with torch.no_grad():
             o_scale = torch.exp(-2 * logs_p)
             logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1]).unsqueeze(-1)  # [b, t, 1]
-            logp2 = torch.einsum("klm, kln -> kmn", [o_scale, -0.5 * (z_p ** 2)])
+            logp2 = torch.einsum("klm, kln -> kmn", [o_scale, -0.5 * (z_p**2)])
             logp3 = torch.einsum("klm, kln -> kmn", [m_p * o_scale, z_p])
-            logp4 = torch.sum(-0.5 * (m_p ** 2) * o_scale, [1]).unsqueeze(-1)  # [b, t, 1]
+            logp4 = torch.sum(-0.5 * (m_p**2) * o_scale, [1]).unsqueeze(-1)  # [b, t, 1]
             logp = logp2 + logp3 + logp1 + logp4
             attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()  # [b, 1, t, t']
@@ -806,7 +808,7 @@ class Vits(BaseTTS):
         outputs["loss_duration"] = loss_duration
         return outputs, attn
-    def forward(
+    def forward(  # pylint: disable=dangerous-default-value
         self,
         x: torch.tensor,
         x_lengths: torch.tensor,
@@ -886,7 +888,7 @@ class Vits(BaseTTS):
             waveform,
             slice_ids * self.config.audio.hop_length,
             self.args.spec_segment_size * self.config.audio.hop_length,
-            pad_short = True
+            pad_short=True,
         )
         if self.args.use_speaker_encoder_as_loss and self.speaker_manager.speaker_encoder is not None:
@@ -929,7 +931,9 @@ class Vits(BaseTTS):
             return aux_input["x_lengths"]
         return torch.tensor(x.shape[1:2]).to(x.device)
-    def inference(self, x, aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None}):
+    def inference(
+        self, x, aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None}
+    ):  # pylint: disable=dangerous-default-value
         """
         Note:
             To run in batch mode, provide `x_lengths` else model assumes that the batch size is 1.
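The `dangerous-default-value` suppressions on `forward()` and `inference()` silence pylint's warning about the mutable `aux_input={...}` default that the code keeps on purpose. As a reminder of why pylint flags the pattern, here is a generic illustration (plain Python, unrelated to the TTS code) of the pitfall and the usual `None`-default alternative:

def risky(items=[]):
    # the default list is created once and shared across all calls, so state leaks between them
    items.append(1)
    return items


def safe(items=None):
    # a fresh list is created on every call
    items = [] if items is None else items
    items.append(1)
    return items


print(risky())  # [1]
print(risky())  # [1, 1]  <- remembers the previous call
print(safe())   # [1]
print(safe())   # [1]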
@@ -1023,7 +1027,6 @@ class Vits(BaseTTS):
         o_hat = self.waveform_decoder(z_hat * y_mask, g=g_tgt)
         return o_hat, y_mask, (z, z_p, z_hat)
     def train_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]:
         """Perform a single training step. Run the model forward pass and compute losses.
@@ -1062,7 +1065,7 @@ class Vits(BaseTTS):
             )
             # cache tensors for the generator pass
-            self.model_outputs_cache = outputs
+            self.model_outputs_cache = outputs  # pylint: disable=attribute-defined-outside-init
             # compute scores and features
             scores_disc_fake, _, scores_disc_real, _ = self.disc(
@@ -1082,14 +1085,16 @@ class Vits(BaseTTS):
             # compute melspec segment
             with autocast(enabled=False):
-                mel_slice = segment(mel.float(), self.model_outputs_cache["slice_ids"], self.spec_segment_size, pad_short=True)
+                mel_slice = segment(
+                    mel.float(), self.model_outputs_cache["slice_ids"], self.spec_segment_size, pad_short=True
+                )
                 mel_slice_hat = wav_to_mel(
-                    y = self.model_outputs_cache["model_outputs"].float(),
-                    n_fft = self.config.audio.fft_size,
-                    sample_rate = self.config.audio.sample_rate,
-                    num_mels = self.config.audio.num_mels,
-                    hop_length = self.config.audio.hop_length,
-                    win_length = self.config.audio.win_length,
+                    y=self.model_outputs_cache["model_outputs"].float(),
+                    n_fft=self.config.audio.fft_size,
+                    sample_rate=self.config.audio.sample_rate,
+                    num_mels=self.config.audio.num_mels,
+                    hop_length=self.config.audio.hop_length,
+                    win_length=self.config.audio.win_length,
                     fmin=self.config.audio.mel_fmin,
                     fmax=self.config.audio.mel_fmax,
                     center=False,
@@ -1105,18 +1110,18 @@ class Vits(BaseTTS):
                 loss_dict = criterion[optimizer_idx](
                     mel_slice_hat=mel_slice.float(),
                     mel_slice=mel_slice_hat.float(),
-                    z_p= self.model_outputs_cache["z_p"].float(),
-                    logs_q= self.model_outputs_cache["logs_q"].float(),
-                    m_p= self.model_outputs_cache["m_p"].float(),
-                    logs_p= self.model_outputs_cache["logs_p"].float(),
+                    z_p=self.model_outputs_cache["z_p"].float(),
+                    logs_q=self.model_outputs_cache["logs_q"].float(),
+                    m_p=self.model_outputs_cache["m_p"].float(),
+                    logs_p=self.model_outputs_cache["logs_p"].float(),
                     z_len=mel_lens,
-                    scores_disc_fake= scores_disc_fake,
-                    feats_disc_fake= feats_disc_fake,
-                    feats_disc_real= feats_disc_real,
-                    loss_duration= self.model_outputs_cache["loss_duration"],
+                    scores_disc_fake=scores_disc_fake,
+                    feats_disc_fake=feats_disc_fake,
+                    feats_disc_real=feats_disc_real,
+                    loss_duration=self.model_outputs_cache["loss_duration"],
                     use_speaker_encoder_as_loss=self.args.use_speaker_encoder_as_loss,
-                    gt_spk_emb= self.model_outputs_cache["gt_spk_emb"],
-                    syn_spk_emb= self.model_outputs_cache["syn_spk_emb"],
+                    gt_spk_emb=self.model_outputs_cache["gt_spk_emb"],
+                    syn_spk_emb=self.model_outputs_cache["syn_spk_emb"],
                 )
             return self.model_outputs_cache, loss_dict
@@ -1248,7 +1253,9 @@ class Vits(BaseTTS):
             test_figures["{}-alignment".format(idx)] = plot_alignment(alignment.T, output_fig=False)
         return {"figures": test_figures, "audios": test_audios}
-    def test_log(self, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None:
+    def test_log(
+        self, outputs: dict, logger: "Logger", assets: dict, steps: int  # pylint: disable=unused-argument
+    ) -> None:
         logger.test_audios(steps, outputs["audios"], self.ap.sample_rate)
         logger.test_figures(steps, outputs["figures"])
@@ -1273,7 +1280,11 @@ class Vits(BaseTTS):
             d_vectors = torch.FloatTensor(d_vectors)
         # get language ids from language names
-        if self.language_manager is not None and self.language_manager.language_id_mapping and self.args.use_language_embedding:
+        if (
+            self.language_manager is not None
+            and self.language_manager.language_id_mapping
+            and self.args.use_language_embedding
+        ):
             language_ids = [self.language_manager.language_id_mapping[ln] for ln in batch["language_names"]]
         if language_ids is not None:
@@ -1289,16 +1300,14 @@ class Vits(BaseTTS):
         ac = self.config.audio
         # compute spectrograms
-        batch["spec"] = wav_to_spec(
-            batch["waveform"], ac.fft_size, ac.hop_length, ac.win_length, center=False
-        )
+        batch["spec"] = wav_to_spec(batch["waveform"], ac.fft_size, ac.hop_length, ac.win_length, center=False)
         batch["mel"] = spec_to_mel(
-            spec = batch["spec"],
-            n_fft = ac.fft_size,
-            num_mels = ac.num_mels,
-            sample_rate = ac.sample_rate,
-            fmin = ac.mel_fmin,
-            fmax = ac.mel_fmax,
+            spec=batch["spec"],
+            n_fft=ac.fft_size,
+            num_mels=ac.num_mels,
+            sample_rate=ac.sample_rate,
+            fmin=ac.mel_fmin,
+            fmax=ac.mel_fmax,
         )
         assert batch["spec"].shape[2] == batch["mel"].shape[2], f"{batch['spec'].shape[2]}, {batch['mel'].shape[2]}"
@@ -1325,27 +1334,6 @@ class Vits(BaseTTS):
         if is_eval and not config.run_eval:
             loader = None
         else:
-            # setup multi-speaker attributes
-            speaker_id_mapping = None
-            d_vector_mapping = None
-            if hasattr(self, "speaker_manager") and self.speaker_manager is not None:
-                if hasattr(config, "model_args"):
-                    speaker_id_mapping = (
-                        self.speaker_manager.speaker_ids if config.model_args.use_speaker_embedding else None
-                    )
-                    d_vector_mapping = self.speaker_manager.d_vectors if config.model_args.use_d_vector_file else None
-                    config.use_d_vector_file = config.model_args.use_d_vector_file
-                else:
-                    speaker_id_mapping = self.speaker_manager.speaker_ids if config.use_speaker_embedding else None
-                    d_vector_mapping = self.speaker_manager.d_vectors if config.use_d_vector_file else None
-            # setup multi-lingual attributes
-            language_id_mapping = None
-            if hasattr(self, "language_manager"):
-                language_id_mapping = (
-                    self.language_manager.language_id_mapping if self.args.use_language_embedding else None
-                )
             # init dataloader
             dataset = VitsDataset(
                 samples=samples,
@@ -1495,6 +1483,7 @@ class Vits(BaseTTS):
         language_manager = LanguageManager.init_from_config(config)
         return Vits(new_config, ap, tokenizer, speaker_manager, language_manager)
 ##################################
 # VITS CHARACTERS
 ##################################

View File

@@ -119,6 +119,7 @@ def rand_segments(
     ret = segment(x, segment_indices, segment_size)
     return ret, segment_indices
 def average_over_durations(values, durs):
     """Average values over durations.

View File

@@ -1,4 +1,3 @@
-from abc import ABC
 from dataclasses import replace
 from typing import Dict

View File

@@ -57,7 +57,7 @@ class Punctuation:
         if not isinstance(value, six.string_types):
             raise ValueError("[!] Punctuations must be of type str.")
         self._puncs = "".join(list(dict.fromkeys(list(value))))  # remove duplicates without changing the oreder
-        self.puncs_regular_exp = re.compile(fr"(\s*[{re.escape(self._puncs)}]+\s*)+")
+        self.puncs_regular_exp = re.compile(rf"(\s*[{re.escape(self._puncs)}]+\s*)+")
     def strip(self, text):
         """Remove all the punctuations by replacing with `space`.

View File

@@ -270,7 +270,7 @@ class Wavegrad(BaseVocoder):
     ) -> None:
         pass
-    def test(self, assets: Dict, test_loader:"DataLoader", outputs=None):  # pylint: disable=unused-argument
+    def test(self, assets: Dict, test_loader: "DataLoader", outputs=None):  # pylint: disable=unused-argument
         # setup noise schedule and inference
         ap = assets["audio_processor"]
         noise_schedule = self.config["test_noise_schedule"]
@@ -307,9 +307,7 @@ class Wavegrad(BaseVocoder):
         y = y.unsqueeze(1)
         return {"input": m, "waveform": y}
-    def get_data_loader(
-        self, config: Coqpit, assets: Dict, is_eval: True, samples: List, verbose: bool, num_gpus: int
-    ):
+    def get_data_loader(self, config: Coqpit, assets: Dict, is_eval: True, samples: List, verbose: bool, num_gpus: int):
         ap = assets["audio_processor"]
         dataset = WaveGradDataset(
             ap=ap,

View File

@@ -69,8 +69,8 @@ config = VitsConfig(
     print_eval=False,
     mixed_precision=False,
     sort_by_audio_len=True,
-    min_seq_len=32 * 256 * 4,
-    max_seq_len=160000,
+    min_audio_len=32 * 256 * 4,
+    max_audio_len=160000,
     output_path=output_path,
     datasets=dataset_config,
     characters={
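Both limits are expressed in audio samples, so the clip-length bounds they impose depend on the dataset's sample rate, which is not set in this hunk. A quick conversion sketch (the sample rates below are examples, not taken from the recipe):

min_audio_len = 32 * 256 * 4  # 32768 samples
max_audio_len = 160000  # samples
for sample_rate in (16000, 22050):
    print(sample_rate, min_audio_len / sample_rate, max_audio_len / sample_rate)
# 16000 Hz -> ~2.05 s to 10.0 s
# 22050 Hz -> ~1.49 s to ~7.26 s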

View File

@@ -4,6 +4,7 @@ import unittest
 import torch
 from torch import optim
+from trainer.logging.tensorboard_logger import TensorboardLogger
 from tests import get_tests_data_path, get_tests_input_path, get_tests_output_path
 from TTS.tts.configs.glow_tts_config import GlowTTSConfig
@@ -11,7 +12,6 @@ from TTS.tts.layers.losses import GlowTTSLoss
 from TTS.tts.models.glow_tts import GlowTTS
 from TTS.tts.utils.speakers import SpeakerManager
 from TTS.utils.audio import AudioProcessor
-from trainer.logging.tensorboard_logger import TensorboardLogger
 # pylint: disable=unused-variable

View File

@@ -3,15 +3,14 @@ import os
 import unittest
 import torch
-from TTS.tts.datasets.formatters import ljspeech
+from trainer.logging.tensorboard_logger import TensorboardLogger
 from tests import assertHasAttr, assertHasNotAttr, get_tests_data_path, get_tests_input_path, get_tests_output_path
 from TTS.config import load_config
 from TTS.speaker_encoder.utils.generic_utils import setup_speaker_encoder_model
 from TTS.tts.configs.vits_config import VitsConfig
-from TTS.tts.models.vits import Vits, VitsArgs, load_audio, amp_to_db, db_to_amp, wav_to_spec, wav_to_mel, spec_to_mel, VitsDataset
+from TTS.tts.models.vits import Vits, VitsArgs, amp_to_db, db_to_amp, load_audio, spec_to_mel, wav_to_mel, wav_to_spec
 from TTS.tts.utils.speakers import SpeakerManager
-from trainer.logging.tensorboard_logger import TensorboardLogger
 LANG_FILE = os.path.join(get_tests_input_path(), "language_ids.json")
 SPEAKER_ENCODER_CONFIG = os.path.join(get_tests_input_path(), "test_speaker_encoder_config.json")
@@ -31,7 +30,17 @@ class TestVits(unittest.TestCase):
         self.assertEqual(sr, 22050)
         spec = wav_to_spec(wav, n_fft=1024, hop_length=512, win_length=1024, center=False)
-        mel = wav_to_mel(wav, n_fft=1024, num_mels=80, sample_rate=sr, hop_length=512, win_length=1024, fmin=0, fmax=8000, center=False)
+        mel = wav_to_mel(
+            wav,
+            n_fft=1024,
+            num_mels=80,
+            sample_rate=sr,
+            hop_length=512,
+            win_length=1024,
+            fmin=0,
+            fmax=8000,
+            center=False,
+        )
         mel2 = spec_to_mel(spec, n_fft=1024, num_mels=80, sample_rate=sr, fmin=0, fmax=8000)
         self.assertEqual((mel - mel2).abs().max(), 0)
@@ -164,7 +173,7 @@ class TestVits(unittest.TestCase):
         num_speakers = 0
         config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True)
         config.model_args.spec_segment_size = 10
-        input_dummy, input_lengths, mel, spec, spec_lengths, waveform = self._create_inputs(config)
+        input_dummy, input_lengths, _, spec, spec_lengths, waveform = self._create_inputs(config)
         model = Vits(config).to(device)
         output_dict = model.forward(input_dummy, input_lengths, spec, spec_lengths, waveform)
         self._check_forward_outputs(config, output_dict)
@@ -175,7 +184,7 @@ class TestVits(unittest.TestCase):
         config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True)
         config.model_args.spec_segment_size = 10
-        input_dummy, input_lengths, mel, spec, spec_lengths, waveform = self._create_inputs(config)
+        input_dummy, input_lengths, _, spec, spec_lengths, waveform = self._create_inputs(config)
         speaker_ids = torch.randint(0, num_speakers, (8,)).long().to(device)
         model = Vits(config).to(device)
@@ -196,7 +205,7 @@ class TestVits(unittest.TestCase):
         config = VitsConfig(model_args=args)
         model = Vits.init_from_config(config, verbose=False).to(device)
         model.train()
-        input_dummy, input_lengths, mel, spec, spec_lengths, waveform = self._create_inputs(config, batch_size=batch_size)
+        input_dummy, input_lengths, _, spec, spec_lengths, waveform = self._create_inputs(config, batch_size=batch_size)
         d_vectors = torch.randn(batch_size, 256).to(device)
         output_dict = model.forward(
             input_dummy, input_lengths, spec, spec_lengths, waveform, aux_input={"d_vectors": d_vectors}
@@ -211,7 +220,7 @@ class TestVits(unittest.TestCase):
         args = VitsArgs(language_ids_file=LANG_FILE, use_language_embedding=True, spec_segment_size=10)
         config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True, model_args=args)
-        input_dummy, input_lengths, mel, spec, spec_lengths, waveform = self._create_inputs(config, batch_size=batch_size)
+        input_dummy, input_lengths, _, spec, spec_lengths, waveform = self._create_inputs(config, batch_size=batch_size)
         speaker_ids = torch.randint(0, num_speakers, (batch_size,)).long().to(device)
         lang_ids = torch.randint(0, num_langs, (batch_size,)).long().to(device)
@@ -246,7 +255,7 @@ class TestVits(unittest.TestCase):
         config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True, model_args=args)
         config.audio.sample_rate = 16000
-        input_dummy, input_lengths, mel, spec, spec_lengths, waveform = self._create_inputs(config, batch_size=batch_size)
+        input_dummy, input_lengths, _, spec, spec_lengths, waveform = self._create_inputs(config, batch_size=batch_size)
         speaker_ids = torch.randint(0, num_speakers, (batch_size,)).long().to(device)
         lang_ids = torch.randint(0, num_langs, (batch_size,)).long().to(device)