mirror of https://github.com/coqui-ai/TTS.git
Make style
parent fc8264d9d2
commit 424d04e4f6
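Most of the hunks below are mechanical reformatting rather than behavior changes. Assuming the repository's `make style` target runs black and isort (a common setup for this project, not shown in the diff itself), the two recurring rewrites are dropping the spaces around `=` in keyword arguments and wrapping long calls over multiple lines, as in this hypothetical sketch:

    # Hypothetical before/after, not taken from the diff below.
    def wav_to_mel_stub(y, n_fft, sample_rate, num_mels, hop_length, win_length):
        return (y, n_fft, sample_rate, num_mels, hop_length, win_length)  # stand-in body

    # Before: spaces around keyword "=", one long call on a single line.
    mel = wav_to_mel_stub(y = [0.0], n_fft = 1024, sample_rate = 22050, num_mels = 80, hop_length = 256, win_length = 1024)

    # After: no spaces around keyword "=", call wrapped with a trailing comma.
    mel = wav_to_mel_stub(
        y=[0.0],
        n_fft=1024,
        sample_rate=22050,
        num_mels=80,
        hop_length=256,
        win_length=1024,
    )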
@@ -7,7 +7,6 @@ import subprocess
import time
import torch
from trainer import TrainerArgs

@@ -8,7 +8,6 @@ import traceback
import torch
from torch.utils.data import DataLoader
from trainer.torch import NoamLR
from TTS.speaker_encoder.dataset import SpeakerEncoderDataset
@@ -1,5 +1,5 @@
from dataclasses import dataclass, field
import os
from dataclasses import dataclass, field
from trainer import Trainer, TrainerArgs

@@ -1,5 +1,5 @@
from dataclasses import dataclass, field
import os
from dataclasses import dataclass, field
from trainer import Trainer, TrainerArgs
TTS/model.py (16 changed lines)

@@ -5,11 +5,11 @@ import torch
from coqpit import Coqpit
from torch import nn
# pylint: skip-file
class BaseTrainerModel(ABC, nn.Module):
"""Abstract 🐸TTS class. Every new 🐸TTS model must inherit this.
"""
"""Abstract 🐸TTS class. Every new 🐸TTS model must inherit this."""
@staticmethod
@abstractmethod
@@ -63,7 +63,7 @@ class BaseTrainerModel(ABC, nn.Module):
"""
return batch
def format_batch_on_device(self, batch:Dict) -> Dict:
def format_batch_on_device(self, batch: Dict) -> Dict:
"""Format batch on device before sending it to the model.
If not implemented, model uses the batch as is.

@@ -124,7 +124,6 @@ class BaseTrainerModel(ABC, nn.Module):
"""The same as `train_log()`"""
...
@abstractmethod
def load_checkpoint(self, config: Coqpit, checkpoint_path: str, eval: bool = False, strict: bool = True) -> None:
"""Load a checkpoint and get ready for training or inference.
@@ -148,13 +147,8 @@ class BaseTrainerModel(ABC, nn.Module):
@abstractmethod
def get_data_loader(
self,
config: Coqpit,
assets: Dict,
is_eval: True,
data_items: List,
verbose: bool,
num_gpus: int):
self, config: Coqpit, assets: Dict, is_eval: True, data_items: List, verbose: bool, num_gpus: int
):
...
# def get_optimizer(self) -> Union["Optimizer", List["Optimizer"]]:
@@ -1,16 +1,15 @@
from asyncio.log import logger
from dataclasses import dataclass, field
import os
from dataclasses import dataclass, field
from coqpit import Coqpit
from trainer import TrainerArgs
from trainer.logging import logger_factory
from trainer.logging.console_logger import ConsoleLogger
from TTS.config import load_config, register_config
from trainer import TrainerArgs
from TTS.tts.utils.text.characters import parse_symbols
from TTS.utils.generic_utils import get_experiment_folder_path, get_git_branch
from TTS.utils.io import copy_model_files
from trainer.logging import logger_factory
from trainer.logging.console_logger import ConsoleLogger
from TTS.utils.trainer_utils import get_last_checkpoint
@@ -769,5 +769,3 @@ class F0Dataset:
print("\n")
print(f"{indent}> F0Dataset ")
print(f"{indent}| > Number of instances : {len(self.samples)}")

@@ -672,7 +672,9 @@ class VitsDiscriminatorLoss(nn.Module):
def forward(self, scores_disc_real, scores_disc_fake):
loss = 0.0
return_dict = {}
loss_disc, loss_disc_real, _ = self.discriminator_loss(scores_real=scores_disc_real, scores_fake=scores_disc_fake)
loss_disc, loss_disc_real, _ = self.discriminator_loss(
scores_real=scores_disc_real, scores_fake=scores_disc_fake
)
return_dict["loss_disc"] = loss_disc * self.disc_loss_alpha
loss = loss + return_dict["loss_disc"]
return_dict["loss"] = loss
@@ -26,8 +26,12 @@ class BaseTTS(BaseTrainerModel):
"""
def __init__(
self, config: Coqpit, ap: "AudioProcessor", tokenizer: "TTSTokenizer", speaker_manager: SpeakerManager = None,
language_manager: LanguageManager = None
self,
config: Coqpit,
ap: "AudioProcessor",
tokenizer: "TTSTokenizer",
speaker_manager: SpeakerManager = None,
language_manager: LanguageManager = None,
):
super().__init__()
self.config = config
@@ -530,7 +530,8 @@ class GlowTTS(BaseTTS):
self.store_inverse()
assert not self.training
def get_criterion(self):
@staticmethod
def get_criterion():
from TTS.tts.layers.losses import GlowTTSLoss # pylint: disable=import-outside-toplevel
return GlowTTSLoss()
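The GlowTTS hunk above turns `get_criterion` into a `@staticmethod` because the method never touches `self`. A minimal sketch of that pattern (illustrative only; `nn.MSELoss` stands in for the real `GlowTTSLoss`):

    from torch import nn

    class Before:
        def get_criterion(self):  # instance method, but `self` is unused
            return nn.MSELoss()

    class After:
        @staticmethod
        def get_criterion():  # no implicit `self`; also callable as After.get_criterion()
            return nn.MSELoss()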
@@ -1,9 +1,8 @@
import collections
import math
import os
from dataclasses import dataclass, field, replace
from itertools import chain
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, List, Tuple, Union
import torch
import torch.distributed as dist

@@ -25,7 +24,7 @@ from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDuration
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, segment, sequence_mask
from TTS.tts.utils.languages import LanguageManager, get_language_weighted_sampler
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager, get_speaker_weighted_sampler
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_weighted_sampler
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text.characters import BaseCharacters, _characters, _pad, _phonemes, _punctuations
from TTS.tts.utils.text.tokenizer import TTSTokenizer

@@ -38,6 +37,7 @@ from TTS.vocoder.utils.generic_utils import plot_results
# IO / Feature extraction
##############################
# pylint: disable=global-statement
hann_window = {}
mel_basis = {}
@@ -200,7 +200,7 @@ class VitsDataset(TTSDataset):
text, wav_file, speaker_name, language_name, _ = _parse_sample(item)
raw_text = text
wav, sr = load_audio(wav_file)
wav, _ = load_audio(wav_file)
wav_filename = os.path.basename(wav_file)
token_ids = self.get_token_ids(idx, text)

@@ -538,12 +538,14 @@ class Vits(BaseTTS):
>>> model = Vits(config)
"""
def __init__(self,
def __init__(
self,
config: Coqpit,
ap: "AudioProcessor" = None,
tokenizer: "TTSTokenizer" = None,
speaker_manager: SpeakerManager = None,
language_manager: LanguageManager = None,):
language_manager: LanguageManager = None,
):
super().__init__(config, ap, tokenizer, speaker_manager, language_manager)
@@ -673,9 +675,9 @@ class Vits(BaseTTS):
)
# pylint: disable=W0101,W0105
self.audio_transform = torchaudio.transforms.Resample(
orig_freq=self.config.audio.sample_rate,
new_freq=self.speaker_manager.speaker_encoder.audio_config["sample_rate"],
)
orig_freq=self.config.audio.sample_rate,
new_freq=self.speaker_manager.speaker_encoder.audio_config["sample_rate"],
)
def _init_speaker_embedding(self):
# pylint: disable=attribute-defined-outside-init
@@ -777,9 +779,9 @@ class Vits(BaseTTS):
with torch.no_grad():
o_scale = torch.exp(-2 * logs_p)
logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1]).unsqueeze(-1) # [b, t, 1]
logp2 = torch.einsum("klm, kln -> kmn", [o_scale, -0.5 * (z_p ** 2)])
logp2 = torch.einsum("klm, kln -> kmn", [o_scale, -0.5 * (z_p**2)])
logp3 = torch.einsum("klm, kln -> kmn", [m_p * o_scale, z_p])
logp4 = torch.sum(-0.5 * (m_p ** 2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1]
logp4 = torch.sum(-0.5 * (m_p**2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1]
logp = logp2 + logp3 + logp1 + logp4
attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach() # [b, 1, t, t']
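For reference, and reading these tensors the usual VITS way (an interpretation, not stated in the diff), the four `logp` terms together evaluate the diagonal-Gaussian log-likelihood that `maximum_path` aligns over, with \sigma_p = \exp(\mathrm{logs\_p}):

    \log \mathcal{N}(z_p \mid m_p, \sigma_p^2) = \sum_{c} \left( -\tfrac{1}{2}\log(2\pi) - \log\sigma_p - \tfrac{(z_p - m_p)^2}{2\sigma_p^2} \right)

Expanding the square recovers exactly `logp1 + logp2 + logp3 + logp4`, since `o_scale = torch.exp(-2 * logs_p)` is 1/\sigma_p^2; the hunk itself only normalizes `z_p ** 2` and `m_p ** 2` to black's `z_p**2` spelling.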
@@ -806,7 +808,7 @@ class Vits(BaseTTS):
outputs["loss_duration"] = loss_duration
return outputs, attn
def forward(
def forward(  # pylint: disable=dangerous-default-value
self,
x: torch.tensor,
x_lengths: torch.tensor,

@@ -886,7 +888,7 @@ class Vits(BaseTTS):
waveform,
slice_ids * self.config.audio.hop_length,
self.args.spec_segment_size * self.config.audio.hop_length,
pad_short = True
pad_short=True,
)
if self.args.use_speaker_encoder_as_loss and self.speaker_manager.speaker_encoder is not None:
@@ -929,7 +931,9 @@ class Vits(BaseTTS):
return aux_input["x_lengths"]
return torch.tensor(x.shape[1:2]).to(x.device)
def inference(self, x, aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None}):
def inference(
self, x, aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None}
):  # pylint: disable=dangerous-default-value
"""
Note:
To run in batch mode, provide `x_lengths` else model assumes that the batch size is 1.
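The `# pylint: disable=dangerous-default-value` added above silences pylint's warning about the mutable `aux_input={...}` default, which is evaluated once and shared across calls. A general-purpose sketch of the usual alternative (not the VITS code, just the conventional `None`-sentinel pattern):

    def inference(x, aux_input=None):
        # Build a fresh dict per call instead of sharing one dict object
        # across every call of the function.
        if aux_input is None:
            aux_input = {"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None}
        return x, aux_input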
@@ -1023,7 +1027,6 @@ class Vits(BaseTTS):
o_hat = self.waveform_decoder(z_hat * y_mask, g=g_tgt)
return o_hat, y_mask, (z, z_p, z_hat)
def train_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]:
"""Perform a single training step. Run the model forward pass and compute losses.

@@ -1062,7 +1065,7 @@ class Vits(BaseTTS):
)
# cache tensors for the generator pass
self.model_outputs_cache = outputs
self.model_outputs_cache = outputs  # pylint: disable=attribute-defined-outside-init
# compute scores and features
scores_disc_fake, _, scores_disc_real, _ = self.disc(
@@ -1082,14 +1085,16 @@ class Vits(BaseTTS):
# compute melspec segment
with autocast(enabled=False):
mel_slice = segment(mel.float(), self.model_outputs_cache["slice_ids"], self.spec_segment_size, pad_short=True)
mel_slice = segment(
mel.float(), self.model_outputs_cache["slice_ids"], self.spec_segment_size, pad_short=True
)
mel_slice_hat = wav_to_mel(
y = self.model_outputs_cache["model_outputs"].float(),
n_fft = self.config.audio.fft_size,
sample_rate = self.config.audio.sample_rate,
num_mels = self.config.audio.num_mels,
hop_length = self.config.audio.hop_length,
win_length = self.config.audio.win_length,
y=self.model_outputs_cache["model_outputs"].float(),
n_fft=self.config.audio.fft_size,
sample_rate=self.config.audio.sample_rate,
num_mels=self.config.audio.num_mels,
hop_length=self.config.audio.hop_length,
win_length=self.config.audio.win_length,
fmin=self.config.audio.mel_fmin,
fmax=self.config.audio.mel_fmax,
center=False,

@@ -1097,7 +1102,7 @@ class Vits(BaseTTS):
# compute discriminator scores and features
scores_disc_fake, feats_disc_fake, _, feats_disc_real = self.disc(
self.model_outputs_cache["model_outputs"], self.model_outputs_cache["waveform_seg"]
self.model_outputs_cache["model_outputs"], self.model_outputs_cache["waveform_seg"]
)
# compute losses
@@ -1105,18 +1110,18 @@ class Vits(BaseTTS):
loss_dict = criterion[optimizer_idx](
mel_slice_hat=mel_slice.float(),
mel_slice=mel_slice_hat.float(),
z_p= self.model_outputs_cache["z_p"].float(),
logs_q= self.model_outputs_cache["logs_q"].float(),
m_p= self.model_outputs_cache["m_p"].float(),
logs_p= self.model_outputs_cache["logs_p"].float(),
z_p=self.model_outputs_cache["z_p"].float(),
logs_q=self.model_outputs_cache["logs_q"].float(),
m_p=self.model_outputs_cache["m_p"].float(),
logs_p=self.model_outputs_cache["logs_p"].float(),
z_len=mel_lens,
scores_disc_fake= scores_disc_fake,
feats_disc_fake= feats_disc_fake,
feats_disc_real= feats_disc_real,
loss_duration= self.model_outputs_cache["loss_duration"],
scores_disc_fake=scores_disc_fake,
feats_disc_fake=feats_disc_fake,
feats_disc_real=feats_disc_real,
loss_duration=self.model_outputs_cache["loss_duration"],
use_speaker_encoder_as_loss=self.args.use_speaker_encoder_as_loss,
gt_spk_emb= self.model_outputs_cache["gt_spk_emb"],
syn_spk_emb= self.model_outputs_cache["syn_spk_emb"],
gt_spk_emb=self.model_outputs_cache["gt_spk_emb"],
syn_spk_emb=self.model_outputs_cache["syn_spk_emb"],
)
return self.model_outputs_cache, loss_dict
@@ -1248,7 +1253,9 @@ class Vits(BaseTTS):
test_figures["{}-alignment".format(idx)] = plot_alignment(alignment.T, output_fig=False)
return {"figures": test_figures, "audios": test_audios}
def test_log(self, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None:
def test_log(
self, outputs: dict, logger: "Logger", assets: dict, steps: int  # pylint: disable=unused-argument
) -> None:
logger.test_audios(steps, outputs["audios"], self.ap.sample_rate)
logger.test_figures(steps, outputs["figures"])

@@ -1273,7 +1280,11 @@ class Vits(BaseTTS):
d_vectors = torch.FloatTensor(d_vectors)
# get language ids from language names
if self.language_manager is not None and self.language_manager.language_id_mapping and self.args.use_language_embedding:
if (
self.language_manager is not None
and self.language_manager.language_id_mapping
and self.args.use_language_embedding
):
language_ids = [self.language_manager.language_id_mapping[ln] for ln in batch["language_names"]]
if language_ids is not None:
@@ -1289,16 +1300,14 @@ class Vits(BaseTTS):
ac = self.config.audio
# compute spectrograms
batch["spec"] = wav_to_spec(
batch["waveform"], ac.fft_size, ac.hop_length, ac.win_length, center=False
)
batch["spec"] = wav_to_spec(batch["waveform"], ac.fft_size, ac.hop_length, ac.win_length, center=False)
batch["mel"] = spec_to_mel(
spec = batch["spec"],
n_fft = ac.fft_size,
num_mels = ac.num_mels,
sample_rate = ac.sample_rate,
fmin = ac.mel_fmin,
fmax = ac.mel_fmax,
spec=batch["spec"],
n_fft=ac.fft_size,
num_mels=ac.num_mels,
sample_rate=ac.sample_rate,
fmin=ac.mel_fmin,
fmax=ac.mel_fmax,
)
assert batch["spec"].shape[2] == batch["mel"].shape[2], f"{batch['spec'].shape[2]}, {batch['mel'].shape[2]}"
@@ -1325,27 +1334,6 @@ class Vits(BaseTTS):
if is_eval and not config.run_eval:
loader = None
else:
# setup multi-speaker attributes
speaker_id_mapping = None
d_vector_mapping = None
if hasattr(self, "speaker_manager") and self.speaker_manager is not None:
if hasattr(config, "model_args"):
speaker_id_mapping = (
self.speaker_manager.speaker_ids if config.model_args.use_speaker_embedding else None
)
d_vector_mapping = self.speaker_manager.d_vectors if config.model_args.use_d_vector_file else None
config.use_d_vector_file = config.model_args.use_d_vector_file
else:
speaker_id_mapping = self.speaker_manager.speaker_ids if config.use_speaker_embedding else None
d_vector_mapping = self.speaker_manager.d_vectors if config.use_d_vector_file else None
# setup multi-lingual attributes
language_id_mapping = None
if hasattr(self, "language_manager"):
language_id_mapping = (
self.language_manager.language_id_mapping if self.args.use_language_embedding else None
)
# init dataloader
dataset = VitsDataset(
samples=samples,

@@ -1495,6 +1483,7 @@ class Vits(BaseTTS):
language_manager = LanguageManager.init_from_config(config)
return Vits(new_config, ap, tokenizer, speaker_manager, language_manager)
##################################
# VITS CHARACTERS
##################################
@@ -119,6 +119,7 @@ def rand_segments(
ret = segment(x, segment_indices, segment_size)
return ret, segment_indices
def average_over_durations(values, durs):
"""Average values over durations.

@@ -1,4 +1,3 @@
from abc import ABC
from dataclasses import replace
from typing import Dict

@@ -57,7 +57,7 @@ class Punctuation:
if not isinstance(value, six.string_types):
raise ValueError("[!] Punctuations must be of type str.")
self._puncs = "".join(list(dict.fromkeys(list(value)))) # remove duplicates without changing the oreder
self.puncs_regular_exp = re.compile(fr"(\s*[{re.escape(self._puncs)}]+\s*)+")
self.puncs_regular_exp = re.compile(rf"(\s*[{re.escape(self._puncs)}]+\s*)+")
def strip(self, text):
"""Remove all the punctuations by replacing with `space`.
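The `fr` to `rf` change in the Punctuation hunk above is another formatter normalization: both prefixes produce the same raw f-string, so the compiled regular expression is unchanged. A quick check with hypothetical values:

    import re

    puncs = "!?."
    a = fr"(\s*[{re.escape(puncs)}]+\s*)+"
    b = rf"(\s*[{re.escape(puncs)}]+\s*)+"
    assert a == b  # identical pattern text, identical behavior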
@@ -270,7 +270,7 @@ class Wavegrad(BaseVocoder):
) -> None:
pass
def test(self, assets: Dict, test_loader:"DataLoader", outputs=None): # pylint: disable=unused-argument
def test(self, assets: Dict, test_loader: "DataLoader", outputs=None): # pylint: disable=unused-argument
# setup noise schedule and inference
ap = assets["audio_processor"]
noise_schedule = self.config["test_noise_schedule"]

@@ -307,9 +307,7 @@ class Wavegrad(BaseVocoder):
y = y.unsqueeze(1)
return {"input": m, "waveform": y}
def get_data_loader(
self, config: Coqpit, assets: Dict, is_eval: True, samples: List, verbose: bool, num_gpus: int
):
def get_data_loader(self, config: Coqpit, assets: Dict, is_eval: True, samples: List, verbose: bool, num_gpus: int):
ap = assets["audio_processor"]
dataset = WaveGradDataset(
ap=ap,
@@ -69,8 +69,8 @@ config = VitsConfig(
print_eval=False,
mixed_precision=False,
sort_by_audio_len=True,
min_seq_len=32 * 256 * 4,
max_seq_len=160000,
min_audio_len=32 * 256 * 4,
max_audio_len=160000,
output_path=output_path,
datasets=dataset_config,
characters={
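The `min_seq_len`/`max_seq_len` to `min_audio_len`/`max_audio_len` rename in this recipe is a config-field change rather than pure formatting. Assuming the new fields are measured in raw audio samples and that the recipe runs at a 22050 Hz sample rate (both assumptions, not stated in the diff), the bounds work out to roughly 1.5 to 7.3 seconds of audio per training sample:

    # Hypothetical back-of-the-envelope check; 22050 Hz is assumed.
    sample_rate = 22050
    min_audio_len = 32 * 256 * 4   # 32768 samples
    max_audio_len = 160000
    print(min_audio_len / sample_rate)  # ~1.49 s
    print(max_audio_len / sample_rate)  # ~7.26 s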
@@ -4,6 +4,7 @@ import unittest
import torch
from torch import optim
from trainer.logging.tensorboard_logger import TensorboardLogger
from tests import get_tests_data_path, get_tests_input_path, get_tests_output_path
from TTS.tts.configs.glow_tts_config import GlowTTSConfig

@@ -11,7 +12,6 @@ from TTS.tts.layers.losses import GlowTTSLoss
from TTS.tts.models.glow_tts import GlowTTS
from TTS.tts.utils.speakers import SpeakerManager
from TTS.utils.audio import AudioProcessor
from trainer.logging.tensorboard_logger import TensorboardLogger
# pylint: disable=unused-variable
@@ -3,15 +3,14 @@ import os
import unittest
import torch
from TTS.tts.datasets.formatters import ljspeech
from trainer.logging.tensorboard_logger import TensorboardLogger
from tests import assertHasAttr, assertHasNotAttr, get_tests_data_path, get_tests_input_path, get_tests_output_path
from TTS.config import load_config
from TTS.speaker_encoder.utils.generic_utils import setup_speaker_encoder_model
from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.models.vits import Vits, VitsArgs, load_audio, amp_to_db, db_to_amp, wav_to_spec, wav_to_mel, spec_to_mel, VitsDataset
from TTS.tts.models.vits import Vits, VitsArgs, amp_to_db, db_to_amp, load_audio, spec_to_mel, wav_to_mel, wav_to_spec
from TTS.tts.utils.speakers import SpeakerManager
from trainer.logging.tensorboard_logger import TensorboardLogger
LANG_FILE = os.path.join(get_tests_input_path(), "language_ids.json")
SPEAKER_ENCODER_CONFIG = os.path.join(get_tests_input_path(), "test_speaker_encoder_config.json")

@@ -31,7 +30,17 @@ class TestVits(unittest.TestCase):
self.assertEqual(sr, 22050)
spec = wav_to_spec(wav, n_fft=1024, hop_length=512, win_length=1024, center=False)
mel = wav_to_mel(wav, n_fft=1024, num_mels=80, sample_rate=sr, hop_length=512, win_length=1024, fmin=0, fmax=8000, center=False)
mel = wav_to_mel(
wav,
n_fft=1024,
num_mels=80,
sample_rate=sr,
hop_length=512,
win_length=1024,
fmin=0,
fmax=8000,
center=False,
)
mel2 = spec_to_mel(spec, n_fft=1024, num_mels=80, sample_rate=sr, fmin=0, fmax=8000)
self.assertEqual((mel - mel2).abs().max(), 0)
@@ -45,7 +54,7 @@ class TestVits(unittest.TestCase):
def test_dataset(self):
"""TODO:"""
...
...
def test_init_multispeaker(self):
num_speakers = 10

@@ -164,7 +173,7 @@ class TestVits(unittest.TestCase):
num_speakers = 0
config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True)
config.model_args.spec_segment_size = 10
input_dummy, input_lengths, mel, spec, spec_lengths, waveform = self._create_inputs(config)
input_dummy, input_lengths, _, spec, spec_lengths, waveform = self._create_inputs(config)
model = Vits(config).to(device)
output_dict = model.forward(input_dummy, input_lengths, spec, spec_lengths, waveform)
self._check_forward_outputs(config, output_dict)

@@ -175,7 +184,7 @@ class TestVits(unittest.TestCase):
config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True)
config.model_args.spec_segment_size = 10
input_dummy, input_lengths, mel, spec, spec_lengths, waveform = self._create_inputs(config)
input_dummy, input_lengths, _, spec, spec_lengths, waveform = self._create_inputs(config)
speaker_ids = torch.randint(0, num_speakers, (8,)).long().to(device)
model = Vits(config).to(device)

@@ -196,7 +205,7 @@ class TestVits(unittest.TestCase):
config = VitsConfig(model_args=args)
model = Vits.init_from_config(config, verbose=False).to(device)
model.train()
input_dummy, input_lengths, mel, spec, spec_lengths, waveform = self._create_inputs(config, batch_size=batch_size)
input_dummy, input_lengths, _, spec, spec_lengths, waveform = self._create_inputs(config, batch_size=batch_size)
d_vectors = torch.randn(batch_size, 256).to(device)
output_dict = model.forward(
input_dummy, input_lengths, spec, spec_lengths, waveform, aux_input={"d_vectors": d_vectors}

@@ -211,7 +220,7 @@ class TestVits(unittest.TestCase):
args = VitsArgs(language_ids_file=LANG_FILE, use_language_embedding=True, spec_segment_size=10)
config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True, model_args=args)
input_dummy, input_lengths, mel, spec, spec_lengths, waveform = self._create_inputs(config, batch_size=batch_size)
input_dummy, input_lengths, _, spec, spec_lengths, waveform = self._create_inputs(config, batch_size=batch_size)
speaker_ids = torch.randint(0, num_speakers, (batch_size,)).long().to(device)
lang_ids = torch.randint(0, num_langs, (batch_size,)).long().to(device)

@@ -246,7 +255,7 @@ class TestVits(unittest.TestCase):
config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True, model_args=args)
config.audio.sample_rate = 16000
input_dummy, input_lengths, mel, spec, spec_lengths, waveform = self._create_inputs(config, batch_size=batch_size)
input_dummy, input_lengths, _, spec, spec_lengths, waveform = self._create_inputs(config, batch_size=batch_size)
speaker_ids = torch.randint(0, num_speakers, (batch_size,)).long().to(device)
lang_ids = torch.randint(0, num_langs, (batch_size,)).long().to(device)