Merge pull request #800 from coqui-ai/forward_tts

Forward TTS implementation
This commit is contained in:
Eren Gölge 2021-09-13 09:25:56 +02:00 committed by GitHub
commit aed9a32d52
54 changed files with 24517 additions and 1050 deletions

View File

@ -1,5 +1,5 @@
.DEFAULT_GOAL := help
.PHONY: test system-deps dev-deps deps style lint install help
.PHONY: test system-deps dev-deps deps style lint install help docs
help:
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
@ -45,3 +45,6 @@ deps: ## install 🐸 requirements.
install: ## install 🐸 TTS for development.
pip install -e .[all]
docs: ## build the docs
$(MAKE) -C docs clean && $(MAKE) -C docs html

View File

@ -73,6 +73,7 @@ Underlined "TTS*" and "Judy*" are 🐸TTS models
- Speedy-Speech: [paper](https://arxiv.org/abs/2008.03802)
- Align-TTS: [paper](https://arxiv.org/abs/2003.01950)
- FastPitch: [paper](https://arxiv.org/pdf/2006.06873.pdf)
- FastSpeech: [paper](https://arxiv.org/abs/1905.09263)
### End-to-End Models
- VITS: [paper](https://arxiv.org/pdf/2106.06103)

View File

@ -47,15 +47,6 @@
"license": "MPL",
"contact": "egolge@coqui.com"
},
"speedy-speech-wn": {
"description": "Speedy Speech model with wavenet decoder.",
"github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.1.0/tts_models--en--ljspeech--speedy-speech-wn.zip",
"default_vocoder": "vocoder_models/en/ljspeech/multiband-melgan",
"commit": "77b6145",
"author": "Eren Gölge @erogol",
"license": "MPL",
"contact": "egolge@coqui.com"
},
"vits": {
"description": "VITS is an End2End TTS model trained on LJSpeech dataset with phonemes.",
"github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.2.0/tts_models--en--ljspeech--vits.zip",

View File

@ -16,7 +16,6 @@ from TTS.tts.models import setup_model
from TTS.tts.utils.speakers import get_speaker_manager
from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import count_parameters
from TTS.utils.io import load_fsspec
use_cuda = torch.cuda.is_available()
@ -77,14 +76,14 @@ def set_filename(wav_path, out_path):
def format_data(data):
# setup input data
text_input = data['text']
text_lengths = data['text_lengths']
mel_input = data['mel']
mel_lengths = data['mel_lengths']
item_idx = data['item_idxs']
d_vectors = data['d_vectors']
speaker_ids = data['speaker_ids']
attn_mask = data['attns']
text_input = data["text"]
text_lengths = data["text_lengths"]
mel_input = data["mel"]
mel_lengths = data["mel_lengths"]
item_idx = data["item_idxs"]
d_vectors = data["d_vectors"]
speaker_ids = data["speaker_ids"]
attn_mask = data["attns"]
avg_text_length = torch.mean(text_lengths.float())
avg_spec_length = torch.mean(mel_lengths.float())
@ -133,7 +132,11 @@ def inference(
elif d_vectors is not None:
speaker_c = d_vectors
outputs = model.inference_with_MAS(
text_input, text_lengths, mel_input, mel_lengths, aux_input={"d_vectors": speaker_c, "speaker_ids": speaker_ids}
text_input,
text_lengths,
mel_input,
mel_lengths,
aux_input={"d_vectors": speaker_c, "speaker_ids": speaker_ids},
)
model_output = outputs["model_outputs"]
model_output = model_output.transpose(1, 2).detach().cpu().numpy()
@ -239,8 +242,7 @@ def main(args): # pylint: disable=redefined-outer-name
model = setup_model(c)
# restore model
checkpoint = load_fsspec(args.checkpoint_path, map_location="cpu")
model.load_state_dict(checkpoint["model"])
model.load_checkpoint(c, args.checkpoint_path, eval=True)
if use_cuda:
model.cuda()

View File

@ -275,7 +275,8 @@ class Trainer:
if self.args.continue_path:
if isinstance(self.scheduler, list):
for scheduler in self.scheduler:
scheduler.last_epoch = self.restore_step
if scheduler is not None:
scheduler.last_epoch = self.restore_step
else:
self.scheduler.last_epoch = self.restore_step
@ -662,6 +663,7 @@ class Trainer:
lrs = {"current_lr": current_lr}
# log run-time stats
loss_dict.update(lrs)
loss_dict.update(
{
"step_time": round(step_time, 4),
@ -1125,7 +1127,7 @@ def get_last_checkpoint(path: str) -> Tuple[str, str]:
last_model_num = model_num
last_model = file_name
# if there is not checkpoint found above
# if there is no checkpoint found above
# find the checkpoint with the latest
# modification date.
key_file_names = [fn for fn in file_names if key in fn]
@ -1144,7 +1146,7 @@ def get_last_checkpoint(path: str) -> Tuple[str, str]:
last_models["checkpoint"] = last_models["best_model"]
elif "best_model" not in last_models: # no best model
# this shouldn't happen, but let's handle it just in case
last_models["best_model"] = None
last_models["best_model"] = last_models["checkpoint"]
# finally check if last best model is more recent than checkpoint
elif last_model_nums["best_model"] > last_model_nums["checkpoint"]:
last_models["checkpoint"] = last_models["best_model"]
@ -1180,7 +1182,6 @@ def process_args(args, config=None):
args.restore_path, best_model = get_last_checkpoint(args.continue_path)
if not args.best_path:
args.best_path = best_model
# init config if not already defined
if config is None:
if args.config_path:

View File

@ -2,12 +2,12 @@ from dataclasses import dataclass, field
from typing import List
from TTS.tts.configs.shared_configs import BaseTTSConfig
from TTS.tts.models.fast_pitch import FastPitchArgs
from TTS.tts.models.forward_tts import ForwardTTSArgs
@dataclass
class FastPitchConfig(BaseTTSConfig):
"""Defines parameters for Speedy Speech (feed-forward encoder-decoder) based models.
"""Configure `ForwardTTS` as FastPitch model.
Example:
@ -18,6 +18,10 @@ class FastPitchConfig(BaseTTSConfig):
model (str):
Model name used for selecting the right model at initialization. Defaults to `fast_pitch`.
base_model (str):
Name of the base model being configured as this model so that 🐸 TTS knows it needs to instantiate
the base model rather than searching for the `model` implementation. Defaults to `forward_tts`.
model_args (Coqpit):
Model class arguments. Check `ForwardTTSArgs` for more details. Defaults to `ForwardTTSArgs()`.
@ -36,22 +40,43 @@ class FastPitchConfig(BaseTTSConfig):
d_vector_file (str):
Path to the file including pre-computed speaker embeddings. Defaults to None.
noam_schedule (bool):
enable / disable the use of Noam LR scheduler. Defaults to False.
d_vector_dim (int):
Dimension of the external speaker embeddings. Defaults to 0.
warmup_steps (int):
Number of warm-up steps for the Noam scheduler. Defaults 4000.
optimizer (str):
Name of the model optimizer. Defaults to `Adam`.
optimizer_params (dict):
Arguments of the model optimizer. Defaults to `{"betas": [0.9, 0.998], "weight_decay": 1e-6}`.
lr_scheduler (str):
Name of the learning rate scheduler. Defaults to `Noam`.
lr_scheduler_params (dict):
Arguments of the learning rate scheduler. Defaults to `{"warmup_steps": 4000}`.
lr (float):
Initial learning rate. Defaults to `1e-3`.
grad_clip (float):
Gradient norm clipping value. Defaults to `5.0`.
spec_loss_type (str):
Type of the spectrogram loss. Check `ForwardTTSLoss` for possible values. Defaults to `mse`.
duration_loss_type (str):
Type of the duration loss. Check `ForwardTTSLoss` for possible values. Defaults to `mse`.
use_ssim_loss (bool):
Enable/disable the use of SSIM (Structural Similarity) loss. Defaults to True.
wd (float):
Weight decay coefficient. Defaults to `1e-7`.
ssim_loss_alpha (float):
Weight for the SSIM loss. If set 0, disables the SSIM loss. Defaults to 1.0.
huber_loss_alpha (float):
dur_loss_alpha (float):
Weight for the duration predictor's loss. If set 0, disables the duration loss. Defaults to 1.0.
spec_loss_alpha (float):
@ -74,8 +99,10 @@ class FastPitchConfig(BaseTTSConfig):
"""
model: str = "fast_pitch"
base_model: str = "forward_tts"
# model specific params
model_args: FastPitchArgs = field(default_factory=FastPitchArgs)
model_args: ForwardTTSArgs = ForwardTTSArgs()
# multi-speaker settings
use_speaker_embedding: bool = False
@ -92,11 +119,13 @@ class FastPitchConfig(BaseTTSConfig):
grad_clip: float = 5.0
# loss params
spec_loss_type: str = "mse"
duration_loss_type: str = "mse"
use_ssim_loss: bool = True
ssim_loss_alpha: float = 1.0
dur_loss_alpha: float = 1.0
spec_loss_alpha: float = 1.0
pitch_loss_alpha: float = 1.0
dur_loss_alpha: float = 1.0
aligner_loss_alpha: float = 1.0
binary_align_loss_alpha: float = 1.0
binary_align_loss_start_step: int = 20000
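
For orientation, a minimal usage sketch (not part of the diff; values mirror the defaults shown above): `FastPitchConfig` no longer configures a standalone FastPitch class but parameterizes the shared `ForwardTTS` implementation through `base_model` and `ForwardTTSArgs`.

from TTS.tts.configs import FastPitchConfig
from TTS.tts.models.forward_tts import ForwardTTSArgs

# Sketch: FastPitch is now just a ForwardTTS configuration with the pitch predictor enabled.
config = FastPitchConfig(
    model_args=ForwardTTSArgs(use_pitch=True),  # pitch predictor on, as in FastPitch
    spec_loss_type="mse",
    duration_loss_type="mse",
)
assert config.base_model == "forward_tts"  # routes setup_model() to the ForwardTTS class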

View File

@ -0,0 +1,151 @@
from dataclasses import dataclass, field
from typing import List
from TTS.tts.configs.shared_configs import BaseTTSConfig
from TTS.tts.models.forward_tts import ForwardTTSArgs
@dataclass
class FastSpeechConfig(BaseTTSConfig):
"""Configure `ForwardTTS` as FastSpeech model.
Example:
>>> from TTS.tts.configs import FastSpeechConfig
>>> config = FastSpeechConfig()
Args:
model (str):
Model name used for selecting the right model at initialization. Defaults to `fast_speech`.
base_model (str):
Name of the base model being configured as this model so that 🐸 TTS knows it needs to instantiate
the base model rather than searching for the `model` implementation. Defaults to `forward_tts`.
model_args (Coqpit):
Model class arguments. Check `ForwardTTSArgs` for more details. Defaults to `ForwardTTSArgs()`.
data_dep_init_steps (int):
Number of steps used for computing normalization parameters at the beginning of the training. GlowTTS uses
Activation Normalization that pre-computes normalization stats at the beginning and uses the same values
for the rest. Defaults to 10.
use_speaker_embedding (bool):
enable / disable using speaker embeddings for multi-speaker models. If set True, the model is
in the multi-speaker mode. Defaults to False.
use_d_vector_file (bool):
enable /disable using external speaker embeddings in place of the learned embeddings. Defaults to False.
d_vector_file (str):
Path to the file including pre-computed speaker embeddings. Defaults to None.
d_vector_dim (int):
Dimension of the external speaker embeddings. Defaults to 0.
optimizer (str):
Name of the model optimizer. Defaults to `Adam`.
optimizer_params (dict):
Arguments of the model optimizer. Defaults to `{"betas": [0.9, 0.998], "weight_decay": 1e-6}`.
lr_scheduler (str):
Name of the learning rate scheduler. Defaults to `Noam`.
lr_scheduler_params (dict):
Arguments of the learning rate scheduler. Defaults to `{"warmup_steps": 4000}`.
lr (float):
Initial learning rate. Defaults to `1e-4`.
grad_clip (float):
Gradient norm clipping value. Defaults to `5.0`.
spec_loss_type (str):
Type of the spectrogram loss. Check `ForwardTTSLoss` for possible values. Defaults to `mse`.
duration_loss_type (str):
Type of the duration loss. Check `ForwardTTSLoss` for possible values. Defaults to `mse`.
use_ssim_loss (bool):
Enable/disable the use of SSIM (Structural Similarity) loss. Defaults to True.
wd (float):
Weight decay coefficient. Defaults to `1e-7`.
ssim_loss_alpha (float):
Weight for the SSIM loss. If set 0, disables the SSIM loss. Defaults to 1.0.
dur_loss_alpha (float):
Weight for the duration predictor's loss. If set 0, disables the duration loss. Defaults to 1.0.
spec_loss_alpha (float):
Weight for the spectrogram loss. If set 0, disables the spectrogram loss. Defaults to 1.0.
pitch_loss_alpha (float):
Weight for the pitch predictor's loss. If set 0, disables the pitch predictor. Defaults to 0.0.
binary_align_loss_alpha (float):
Weight for the binary alignment loss. If set 0, disables the binary alignment loss. Defaults to 1.0.
binary_align_loss_start_step (int):
Start binary alignment loss after this many steps. Defaults to 20000.
min_seq_len (int):
Minimum input sequence length to be used at training.
max_seq_len (int):
Maximum input sequence length to be used at training. Larger values result in more VRAM usage.
"""
model: str = "fast_speech"
base_model: str = "forward_tts"
# model specific params
model_args: ForwardTTSArgs = ForwardTTSArgs(use_pitch=False)
# multi-speaker settings
use_speaker_embedding: bool = False
use_d_vector_file: bool = False
d_vector_file: str = None
d_vector_dim: int = 0
# optimizer parameters
optimizer: str = "Adam"
optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.998], "weight_decay": 1e-6})
lr_scheduler: str = "NoamLR"
lr_scheduler_params: dict = field(default_factory=lambda: {"warmup_steps": 4000})
lr: float = 1e-4
grad_clip: float = 5.0
# loss params
spec_loss_type: str = "mse"
duration_loss_type: str = "mse"
use_ssim_loss: bool = True
ssim_loss_alpha: float = 1.0
dur_loss_alpha: float = 1.0
spec_loss_alpha: float = 1.0
pitch_loss_alpha: float = 0.0
aligner_loss_alpha: float = 1.0
binary_align_loss_alpha: float = 1.0
binary_align_loss_start_step: int = 20000
# overrides
min_seq_len: int = 13
max_seq_len: int = 200
r: int = 1 # DO NOT CHANGE
# dataset configs
compute_f0: bool = True
f0_cache_path: str = None
# testing
test_sentences: List[str] = field(
default_factory=lambda: [
"It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
"Be a voice, not an echo.",
"I'm sorry Dave. I'm afraid I can't do that.",
"This cake is great. It's so delicious and moist.",
"Prior to November 22, 1963.",
]
)
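
A short, hedged sketch of the new config (values taken from the defaults above): FastSpeech reuses the same `ForwardTTS` base model as FastPitch and differs only by disabling the pitch predictor.

from TTS.tts.configs import FastSpeechConfig

# Sketch: FastSpeech = ForwardTTS without the pitch predictor.
config = FastSpeechConfig()
assert config.base_model == "forward_tts"
assert config.model_args.use_pitch is False   # no pitch predictor
assert config.model_args.use_aligner is True  # durations are still learned by the aligner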

View File

@ -2,81 +2,160 @@ from dataclasses import dataclass, field
from typing import List
from TTS.tts.configs.shared_configs import BaseTTSConfig
from TTS.tts.models.speedy_speech import SpeedySpeechArgs
from TTS.tts.models.forward_tts import ForwardTTSArgs
@dataclass
class SpeedySpeechConfig(BaseTTSConfig):
"""Defines parameters for Speedy Speech (feed-forward encoder-decoder) based models.
"""Configure `ForwardTTS` as SpeedySpeech model.
Example:
>>> from TTS.tts.configs import SpeedySpeechConfig
>>> config = SpeedySpeechConfig()
Args:
Args:
model (str):
Model name used for selecting the right model at initialization. Defaults to `speedy_speech`.
base_model (str):
Name of the base model being configured as this model so that 🐸 TTS knows it needs to instantiate
the base model rather than searching for the `model` implementation. Defaults to `forward_tts`.
model_args (Coqpit):
Model class arguments. Check `SpeedySpeechArgs` for more details. Defaults to `SpeedySpeechArgs()`.
Model class arguments. Check `ForwardTTSArgs` for more details. Defaults to `ForwardTTSArgs()`.
data_dep_init_steps (int):
Number of steps used for computing normalization parameters at the beginning of the training. GlowTTS uses
Activation Normalization that pre-computes normalization stats at the beginning and uses the same values
for the rest. Defaults to 10.
use_speaker_embedding (bool):
enable / disable using speaker embeddings for multi-speaker models. If set True, the model is
in the multi-speaker mode. Defaults to False.
use_d_vector_file (bool):
enable /disable using external speaker embeddings in place of the learned embeddings. Defaults to False.
d_vector_file (str):
Path to the file including pre-computed speaker embeddings. Defaults to None.
noam_schedule (bool):
enable / disable the use of Noam LR scheduler. Defaults to False.
warmup_steps (int):
Number of warm-up steps for the Noam scheduler. Defaults 4000.
d_vector_dim (int):
Dimension of the external speaker embeddings. Defaults to 0.
optimizer (str):
Name of the model optimizer. Defaults to `Adam`.
optimizer_params (dict):
Arguments of the model optimizer. Defaults to `{"betas": [0.9, 0.998], "weight_decay": 1e-6}`.
lr_scheduler (str):
Name of the learning rate scheduler. Defaults to `Noam`.
lr_scheduler_params (dict):
Arguments of the learning rate scheduler. Defaults to `{"warmup_steps": 4000}`.
lr (float):
Initial learning rate. Defaults to `1e-4`.
grad_clip (float):
Gradient norm clipping value. Defaults to `5.0`.
spec_loss_type (str):
Type of the spectrogram loss. Check `ForwardTTSLoss` for possible values. Defaults to `l1`.
duration_loss_type (str):
Type of the duration loss. Check `ForwardTTSLoss` for possible values. Defaults to `huber`.
use_ssim_loss (bool):
Enable/disable the use of SSIM (Structural Similarity) loss. Defaults to False.
wd (float):
Weight decay coefficient. Defaults to `1e-7`.
ssim_alpha (float):
Weight for the SSIM loss. If set <= 0, disables the SSIM loss. Defaults to 1.0.
huber_alpha (float):
Weight for the duration predictor's loss. Defaults to 1.0.
l1_alpha (float):
Weight for the L1 spectrogram loss. If set <= 0, disables the L1 loss. Defaults to 1.0.
ssim_loss_alpha (float):
Weight for the SSIM loss. If set 0, disables the SSIM loss. Defaults to 1.0.
dur_loss_alpha (float):
Weight for the duration predictor's loss. If set 0, disables the huber loss. Defaults to 1.0.
spec_loss_alpha (float):
Weight for the L1 spectrogram loss. If set 0, disables the L1 loss. Defaults to 1.0.
binary_align_loss_alpha (float):
Weight for the binary alignment loss. If set 0, disables the binary alignment loss. Defaults to 0.3.
binary_align_loss_start_step (int):
Start binary alignment loss after this many steps. Defaults to 50000.
min_seq_len (int):
Minimum input sequence length to be used at training.
max_seq_len (int):
Maximum input sequence length to be used at training. Larger values result in more VRAM usage.
"""
model: str = "speedy_speech"
# model specific params
model_args: SpeedySpeechArgs = field(default_factory=SpeedySpeechArgs)
base_model: str = "forward_tts"
# set model args as SpeedySpeech
model_args: ForwardTTSArgs = ForwardTTSArgs(
use_pitch=False,
encoder_type="residual_conv_bn",
encoder_params={
"kernel_size": 4,
"dilations": 4 * [1, 2, 4] + [1],
"num_conv_blocks": 2,
"num_res_blocks": 13,
},
decoder_type="residual_conv_bn",
decoder_params={
"kernel_size": 4,
"dilations": 4 * [1, 2, 4, 8] + [1],
"num_conv_blocks": 2,
"num_res_blocks": 17,
},
out_channels=80,
hidden_channels=128,
num_speakers=0,
positional_encoding=True,
detach_duration_predictor=True
)
# multi-speaker settings
use_speaker_embedding: bool = False
use_d_vector_file: bool = False
d_vector_file: str = None
d_vector_dim: int = 0
# optimizer parameters
optimizer: str = "RAdam"
optimizer: str = "Adam"
optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.998], "weight_decay": 1e-6})
lr_scheduler: str = None
lr_scheduler_params: dict = None
lr_scheduler: str = "NoamLR"
lr_scheduler_params: dict = field(default_factory=lambda: {"warmup_steps": 4000})
lr: float = 1e-4
grad_clip: float = 5.0
# loss params
ssim_alpha: float = 1.0
huber_alpha: float = 1.0
l1_alpha: float = 1.0
spec_loss_type: str = "l1"
duration_loss_type: str = "huber"
use_ssim_loss: bool = False
ssim_loss_alpha: float = 1.0
dur_loss_alpha: float = 1.0
spec_loss_alpha: float = 1.0
aligner_loss_alpha: float = 1.0
binary_align_loss_alpha: float = 0.3
binary_align_loss_start_step: int = 50000
# overrides
min_seq_len: int = 13
max_seq_len: int = 200
r: int = 1 # DO NOT CHANGE
# dataset configs
compute_f0: bool = False
f0_cache_path: str = None
# testing
test_sentences: List[str] = field(
default_factory=lambda: [

View File

@ -1,15 +1 @@
from TTS.tts.layers.losses import *
def setup_loss(config):
if config.model.lower() in ["tacotron", "tacotron2"]:
model = TacotronLoss(config)
elif config.model.lower() == "glow_tts":
model = GlowTTSLoss()
elif config.model.lower() == "speedy_speech":
model = SpeedySpeechLoss(config)
elif config.model.lower() == "align_tts":
model = AlignTTSLoss(config)
else:
raise ValueError(f" [!] loss for model {config.model.lower()} cannot be found.")
return model

View File

@ -70,7 +70,9 @@ class FFTransformerBlock(nn.Module):
class FFTDurationPredictor:
def __init__(self, in_channels, hidden_channels, num_heads, num_layers, dropout_p=0.1, cond_channels=None): # pylint: disable=unused-argument
def __init__(
self, in_channels, hidden_channels, num_heads, num_layers, dropout_p=0.1, cond_channels=None
): # pylint: disable=unused-argument
self.fft = FFTransformerBlock(in_channels, num_heads, hidden_channels, num_layers, dropout_p)
self.proj = nn.Linear(in_channels, 1)

View File

@ -9,7 +9,7 @@ from TTS.tts.layers.generic.time_depth_sep_conv import TimeDepthSeparableConvBlo
from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
from TTS.tts.layers.glow_tts.glow import ResidualConv1dLayerNormBlock
from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer
from TTS.tts.utils.data import sequence_mask
from TTS.tts.utils.helpers import sequence_mask
class Encoder(nn.Module):

View File

@ -1,106 +0,0 @@
import numpy as np
import torch
from torch.nn import functional as F
from TTS.tts.utils.data import sequence_mask
try:
# TODO: fix pypi cython installation problem.
from TTS.tts.layers.glow_tts.monotonic_align.core import maximum_path_c
CYTHON = True
except ModuleNotFoundError:
CYTHON = False
def convert_pad_shape(pad_shape):
l = pad_shape[::-1]
pad_shape = [item for sublist in l for item in sublist]
return pad_shape
def generate_path(duration, mask):
"""
Shapes:
- duration: :math:`[B, T_en]`
- mask: :math:'[B, T_en, T_de]`
- path: :math:`[B, T_en, T_de]`
"""
device = duration.device
b, t_x, t_y = mask.shape
cum_duration = torch.cumsum(duration, 1)
path = torch.zeros(b, t_x, t_y, dtype=mask.dtype).to(device=device)
cum_duration_flat = cum_duration.view(b * t_x)
path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
path = path.view(b, t_x, t_y)
path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
path = path * mask
return path
def maximum_path(value, mask):
if CYTHON:
return maximum_path_cython(value, mask)
return maximum_path_numpy(value, mask)
def maximum_path_cython(value, mask):
"""Cython optimised version.
Shapes:
- value: :math:`[B, T_en, T_de]`
- mask: :math:`[B, T_en, T_de]`
"""
value = value * mask
device = value.device
dtype = value.dtype
value = value.data.cpu().numpy().astype(np.float32)
path = np.zeros_like(value).astype(np.int32)
mask = mask.data.cpu().numpy()
t_x_max = mask.sum(1)[:, 0].astype(np.int32)
t_y_max = mask.sum(2)[:, 0].astype(np.int32)
maximum_path_c(path, value, t_x_max, t_y_max)
return torch.from_numpy(path).to(device=device, dtype=dtype)
def maximum_path_numpy(value, mask, max_neg_val=None):
"""
Monotonic alignment search algorithm
Numpy-friendly version. It's about 4 times faster than torch version.
value: [b, t_x, t_y]
mask: [b, t_x, t_y]
"""
if max_neg_val is None:
max_neg_val = -np.inf # Patch for Sphinx complaint
value = value * mask
device = value.device
dtype = value.dtype
value = value.cpu().detach().numpy()
mask = mask.cpu().detach().numpy().astype(np.bool)
b, t_x, t_y = value.shape
direction = np.zeros(value.shape, dtype=np.int64)
v = np.zeros((b, t_x), dtype=np.float32)
x_range = np.arange(t_x, dtype=np.float32).reshape(1, -1)
for j in range(t_y):
v0 = np.pad(v, [[0, 0], [1, 0]], mode="constant", constant_values=max_neg_val)[:, :-1]
v1 = v
max_mask = v1 >= v0
v_max = np.where(max_mask, v1, v0)
direction[:, :, j] = max_mask
index_mask = x_range <= j
v = np.where(index_mask, v_max + value[:, :, j], max_neg_val)
direction = np.where(mask, direction, 1)
path = np.zeros(value.shape, dtype=np.float32)
index = mask[:, :, 0].sum(1).astype(np.int64) - 1
index_range = np.arange(b)
for j in reversed(range(t_y)):
path[index_range, index, j] = 1
index = index + direction[index_range, index, j] - 1
path = path * mask.astype(np.float32)
path = torch.from_numpy(path).to(device=device, dtype=dtype)
return path
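
The monotonic-alignment helpers deleted here move to `TTS.tts.utils.helpers` (per the import changes elsewhere in this commit). A small worked example of `generate_path`, assuming the relocated helper keeps the behavior of the code removed above; the tensor values are illustrative:

import torch
from TTS.tts.utils.helpers import generate_path  # new location per this commit

duration = torch.tensor([[1, 3, 2, 1]])  # frames per input token, [B, T_en]
mask = torch.ones(1, 4, 7)               # [B, T_en, T_de] with T_de = duration.sum()
path = generate_path(duration, mask)
# path[0] maps each decoder frame to exactly one encoder token:
# [[1, 0, 0, 0, 0, 0, 0],
#  [0, 1, 1, 1, 0, 0, 0],
#  [0, 0, 0, 0, 1, 1, 0],
#  [0, 0, 0, 0, 0, 0, 1]]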

View File

@ -6,7 +6,7 @@ from coqpit import Coqpit
from torch import nn
from torch.nn import functional
from TTS.tts.utils.data import sequence_mask
from TTS.tts.utils.helpers import sequence_mask
from TTS.tts.utils.ssim import ssim
from TTS.utils.audio import TorchSTFT
@ -236,10 +236,40 @@ class Huber(nn.Module):
y: B x T
length: B
"""
mask = sequence_mask(sequence_length=length, max_len=y.size(1)).float()
mask = sequence_mask(sequence_length=length, max_len=y.size(1)).unsqueeze(2).float()
return torch.nn.functional.smooth_l1_loss(x * mask, y * mask, reduction="sum") / mask.sum()
class ForwardSumLoss(nn.Module):
def __init__(self, blank_logprob=-1):
super().__init__()
self.log_softmax = torch.nn.LogSoftmax(dim=3)
self.ctc_loss = torch.nn.CTCLoss(zero_infinity=True)
self.blank_logprob = blank_logprob
def forward(self, attn_logprob, in_lens, out_lens):
key_lens = in_lens
query_lens = out_lens
attn_logprob_padded = torch.nn.functional.pad(input=attn_logprob, pad=(1, 0), value=self.blank_logprob)
total_loss = 0.0
for bid in range(attn_logprob.shape[0]):
target_seq = torch.arange(1, key_lens[bid] + 1).unsqueeze(0)
curr_logprob = attn_logprob_padded[bid].permute(1, 0, 2)[: query_lens[bid], :, : key_lens[bid] + 1]
curr_logprob = self.log_softmax(curr_logprob[None])[0]
loss = self.ctc_loss(
curr_logprob,
target_seq,
input_lengths=query_lens[bid : bid + 1],
target_lengths=key_lens[bid : bid + 1],
)
total_loss = total_loss + loss
total_loss = total_loss / attn_logprob.shape[0]
return total_loss
########################
# MODEL LOSS LAYERS
########################
@ -413,25 +443,6 @@ class GlowTTSLoss(torch.nn.Module):
return return_dict
class SpeedySpeechLoss(nn.Module):
def __init__(self, c):
super().__init__()
self.l1 = L1LossMasked(False)
self.ssim = SSIMLoss()
self.huber = Huber()
self.ssim_alpha = c.ssim_alpha
self.huber_alpha = c.huber_alpha
self.l1_alpha = c.l1_alpha
def forward(self, decoder_output, decoder_target, decoder_output_lens, dur_output, dur_target, input_lens):
l1_loss = self.l1(decoder_output, decoder_target, decoder_output_lens)
ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens)
huber_loss = self.huber(dur_output, dur_target, input_lens)
loss = self.l1_alpha * l1_loss + self.ssim_alpha * ssim_loss + self.huber_alpha * huber_loss
return {"loss": loss, "loss_l1": l1_loss, "loss_ssim": ssim_loss, "loss_dur": huber_loss}
def mse_loss_custom(x, y):
"""MSE loss using the torch back-end without reduction.
It uses less VRAM than the raw code"""
@ -660,51 +671,41 @@ class VitsDiscriminatorLoss(nn.Module):
return return_dict
class ForwardSumLoss(nn.Module):
def __init__(self, blank_logprob=-1):
super().__init__()
self.log_softmax = torch.nn.LogSoftmax(dim=3)
self.ctc_loss = torch.nn.CTCLoss(zero_infinity=True)
self.blank_logprob = blank_logprob
class ForwardTTSLoss(nn.Module):
"""Generic configurable ForwardTTS loss."""
def forward(self, attn_logprob, in_lens, out_lens):
key_lens = in_lens
query_lens = out_lens
attn_logprob_padded = torch.nn.functional.pad(input=attn_logprob, pad=(1, 0), value=self.blank_logprob)
total_loss = 0.0
for bid in range(attn_logprob.shape[0]):
target_seq = torch.arange(1, key_lens[bid] + 1).unsqueeze(0)
curr_logprob = attn_logprob_padded[bid].permute(1, 0, 2)[: query_lens[bid], :, : key_lens[bid] + 1]
curr_logprob = self.log_softmax(curr_logprob[None])[0]
loss = self.ctc_loss(
curr_logprob,
target_seq,
input_lengths=query_lens[bid : bid + 1],
target_lengths=key_lens[bid : bid + 1],
)
total_loss = total_loss + loss
total_loss = total_loss / attn_logprob.shape[0]
return total_loss
class FastPitchLoss(nn.Module):
def __init__(self, c):
super().__init__()
self.spec_loss = MSELossMasked(False)
self.ssim = SSIMLoss()
self.dur_loss = MSELossMasked(False)
self.pitch_loss = MSELossMasked(False)
if c.spec_loss_type == "mse":
self.spec_loss = MSELossMasked(False)
elif c.spec_loss_type == "l1":
self.spec_loss = L1LossMasked(False)
else:
raise ValueError(" [!] Unknown spec_loss_type {}".format(c.spec_loss_type))
if c.duration_loss_type == "mse":
self.dur_loss = MSELossMasked(False)
elif c.duration_loss_type == "l1":
self.dur_loss = L1LossMasked(False)
elif c.duration_loss_type == "huber":
self.dur_loss = Huber()
else:
raise ValueError(" [!] Unknown duration_loss_type {}".format(c.duration_loss_type))
if c.model_args.use_aligner:
self.aligner_loss = ForwardSumLoss()
self.aligner_loss_alpha = c.aligner_loss_alpha
if c.model_args.use_pitch:
self.pitch_loss = MSELossMasked(False)
self.pitch_loss_alpha = c.pitch_loss_alpha
if c.use_ssim_loss:
self.ssim = SSIMLoss() if c.use_ssim_loss else None
self.ssim_loss_alpha = c.ssim_loss_alpha
self.spec_loss_alpha = c.spec_loss_alpha
self.ssim_loss_alpha = c.ssim_loss_alpha
self.dur_loss_alpha = c.dur_loss_alpha
self.pitch_loss_alpha = c.pitch_loss_alpha
self.aligner_loss_alpha = c.aligner_loss_alpha
self.binary_alignment_loss_alpha = c.binary_align_loss_alpha
@staticmethod
@ -731,7 +732,7 @@ class FastPitchLoss(nn.Module):
):
loss = 0
return_dict = {}
if self.ssim_loss_alpha > 0:
if hasattr(self, "ssim_loss") and self.ssim_loss_alpha > 0:
ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens)
loss = loss + self.ssim_loss_alpha * ssim_loss
return_dict["loss_ssim"] = self.ssim_loss_alpha * ssim_loss
@ -747,12 +748,12 @@ class FastPitchLoss(nn.Module):
loss = loss + self.dur_loss_alpha * dur_loss
return_dict["loss_dur"] = self.dur_loss_alpha * dur_loss
if self.pitch_loss_alpha > 0:
if hasattr(self, "pitch_loss") and self.pitch_loss_alpha > 0:
pitch_loss = self.pitch_loss(pitch_output.transpose(1, 2), pitch_target.transpose(1, 2), input_lens)
loss = loss + self.pitch_loss_alpha * pitch_loss
return_dict["loss_pitch"] = self.pitch_loss_alpha * pitch_loss
if self.aligner_loss_alpha > 0:
if hasattr(self, "aligner_loss") and self.aligner_loss_alpha > 0:
aligner_loss = self.aligner_loss(alignment_logprob, input_lens, decoder_output_lens)
loss = loss + self.aligner_loss_alpha * aligner_loss
return_dict["loss_aligner"] = self.aligner_loss_alpha * aligner_loss

View File

@ -5,7 +5,7 @@ from torch import nn
from TTS.tts.layers.glow_tts.glow import WN
from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer
from TTS.tts.utils.data import sequence_mask
from TTS.tts.utils.helpers import sequence_mask
LRELU_SLOPE = 0.1

View File

@ -4,7 +4,11 @@ from TTS.utils.generic_utils import find_module
def setup_model(config):
print(" > Using model: {}".format(config.model))
MyModel = find_module("TTS.tts.models", config.model.lower())
# fetch the right model implementation.
if "base_model" in config and config["base_model"] is not None:
MyModel = find_module("TTS.tts.models", config.base_model.lower())
else:
MyModel = find_module("TTS.tts.models", config.model.lower())
# define set of characters used by the model
if config.characters is not None:
# set characters from config
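
To illustrate the dispatch above (an assumed usage sketch, not part of the diff): a config carrying `base_model` resolves to that module, so all three configs in this PR load the single `forward_tts` implementation while keeping their own `model` name.

from TTS.tts.configs import FastPitchConfig
from TTS.tts.models import setup_model

config = FastPitchConfig()   # model="fast_pitch", base_model="forward_tts"
model = setup_model(config)  # find_module("TTS.tts.models", "forward_tts") -> ForwardTTS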

View File

@ -10,9 +10,8 @@ from TTS.tts.layers.feed_forward.decoder import Decoder
from TTS.tts.layers.feed_forward.duration_predictor import DurationPredictor
from TTS.tts.layers.feed_forward.encoder import Encoder
from TTS.tts.layers.generic.pos_encoding import PositionalEncoding
from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.data import sequence_mask
from TTS.tts.utils.helpers import generate_path, maximum_path, sequence_mask
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.audio import AudioProcessor
from TTS.utils.io import load_fsspec
@ -168,7 +167,12 @@ class AlignTTS(BaseTTS):
return dr_mas.squeeze(1), log_p
@staticmethod
def convert_dr_to_align(dr, x_mask, y_mask):
def generate_attn(dr, x_mask, y_mask=None):
# compute decode mask from the durations
if y_mask is None:
y_lengths = dr.sum(1).long()
y_lengths[y_lengths < 1] = 1
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(dr.dtype)
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
attn = generate_path(dr, attn_mask.squeeze(1)).to(dr.dtype)
return attn
@ -187,7 +191,7 @@ class AlignTTS(BaseTTS):
[0, 1, 1, 1, 0, 0, 0],
[1, 0, 0, 0, 0, 0, 0]]
"""
attn = self.convert_dr_to_align(dr, x_mask, y_mask)
attn = self.generate_attn(dr, x_mask, y_mask)
o_en_ex = torch.matmul(attn.squeeze(1).transpose(1, 2), en.transpose(1, 2)).transpose(1, 2)
return o_en_ex, attn
@ -275,7 +279,7 @@ class AlignTTS(BaseTTS):
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
dr_mas, mu, log_sigma, logp = self._forward_mdn(o_en, y, y_lengths, x_mask)
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en_dp.dtype)
attn = self.convert_dr_to_align(dr_mas, x_mask, y_mask)
attn = self.generate_attn(dr_mas, x_mask, y_mask)
elif phase == 1:
# train decoder
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)

View File

@ -9,7 +9,7 @@ from torch import nn
from TTS.tts.layers.losses import TacotronLoss
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.data import sequence_mask
from TTS.tts.utils.helpers import sequence_mask
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager
from TTS.tts.utils.text import make_symbols
from TTS.utils.generic_utils import format_aux_input
@ -115,12 +115,19 @@ class BaseTacotron(BaseTTS):
): # pylint: disable=unused-argument, redefined-builtin
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"])
# TODO: set r in run-time by taking it from the new config
if "r" in state:
# set r from the state (for compatibility with older checkpoints)
self.decoder.set_r(state["r"])
else:
elif "config" in state:
# set r from config used at training time (for inference)
self.decoder.set_r(state["config"]["r"])
else:
# set r from the new config (for new-models)
self.decoder.set_r(config.r)
if eval:
self.eval()
print(f" > Model's reduction rate `r` is set to: {self.decoder.r}")
assert not self.training
def get_criterion(self) -> nn.Module:

View File

@ -11,16 +11,15 @@ from TTS.tts.layers.feed_forward.encoder import Encoder
from TTS.tts.layers.generic.aligner import AlignmentNetwork
from TTS.tts.layers.generic.pos_encoding import PositionalEncoding
from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.data import sequence_mask
from TTS.tts.utils.helpers import average_over_durations, generate_path, maximum_path, sequence_mask
from TTS.tts.utils.visual import plot_alignment, plot_pitch, plot_spectrogram
from TTS.utils.audio import AudioProcessor
@dataclass
class FastPitchArgs(Coqpit):
"""Fast Pitch Model arguments.
class ForwardTTSArgs(Coqpit):
"""ForwardTTS Model arguments.
Args:
@ -36,6 +35,14 @@ class FastPitchArgs(Coqpit):
num_speakers (int):
Number of speakers for the speaker embedding layer. Defaults to 0.
use_aligner (bool):
Whether to use the aligner network to learn the text-to-speech alignment or to use pre-computed durations.
If set False, durations should be computed by `TTS/bin/compute_attention_masks.py` and the path to the
pre-computed durations must be provided to `config.datasets[0].meta_file_attn_mask`. Defaults to True.
use_pitch (bool):
Use pitch predictor to learn the pitch. Defaults to True.
duration_predictor_hidden_channels (int):
Number of hidden channels in the duration predictor. Defaults to 256.
@ -93,21 +100,21 @@ class FastPitchArgs(Coqpit):
max_duration (int):
Maximum duration accepted by the model. Defaults to 75.
use_aligner (bool):
Use aligner network to learn the text to speech alignment. Defaults to True.
"""
num_chars: int = None
out_channels: int = 80
hidden_channels: int = 384
num_speakers: int = 0
duration_predictor_hidden_channels: int = 256
duration_predictor_kernel_size: int = 3
duration_predictor_dropout_p: float = 0.1
use_aligner: bool = True
use_pitch: bool = True
pitch_predictor_hidden_channels: int = 256
pitch_predictor_kernel_size: int = 3
pitch_predictor_dropout_p: float = 0.1
pitch_embedding_kernel_size: int = 3
duration_predictor_hidden_channels: int = 256
duration_predictor_kernel_size: int = 3
duration_predictor_dropout_p: float = 0.1
positional_encoding: bool = True
poisitonal_encoding_use_scale: bool = True
length_scale: int = 1
@ -123,32 +130,32 @@ class FastPitchArgs(Coqpit):
d_vector_dim: int = 0
detach_duration_predictor: bool = False
max_duration: int = 75
use_aligner: bool = True
class FastPitch(BaseTTS):
"""FastPitch model. Very similart to SpeedySpeech model but with pitch prediction.
class ForwardTTS(BaseTTS):
"""General forward TTS model implementation that uses an encoder-decoder architecture with an optional alignment
network and a pitch predictor.
Paper::
https://arxiv.org/abs/2006.06873
If the alignment network is used, the model learns the text-to-speech alignment
from the data instead of using pre-computed durations.
Paper abstract::
We present FastPitch, a fully-parallel text-to-speech model based on FastSpeech, conditioned on fundamental
frequency contours. The model predicts pitch contours during inference. By altering these predictions,
the generated speech can be more expressive, better match the semantic of the utterance, and in the end
more engaging to the listener. Uniformly increasing or decreasing pitch with FastPitch generates speech
that resembles the voluntary modulation of voice. Conditioning on frequency contours improves the overall
quality of synthesized speech, making it comparable to state-of-the-art. It does not introduce an overhead,
and FastPitch retains the favorable, fully-parallel Transformer architecture, with over 900x real-time
factor for mel-spectrogram synthesis of a typical utterance."
If the pitch predictor is used, the model also predicts the average pitch value for each input character, as in the FastPitch model.
`ForwardTTS` can be configured as one of these architectures:
- FastPitch
- SpeedySpeech
- FastSpeech
- TODO: FastSpeech2 (requires average speech energy predictor)
Args:
config (Coqpit): Model coqpit class.
Examples:
>>> from TTS.tts.models.fast_pitch import FastPitch, FastPitchArgs
>>> config = FastPitchArgs()
>>> model = FastPitch(config)
>>> from TTS.tts.models.forward_tts import ForwardTTS, ForwardTTSArgs
>>> config = ForwardTTSArgs()
>>> model = ForwardTTS(config)
"""
# pylint: disable=dangerous-default-value
@ -157,24 +164,25 @@ class FastPitch(BaseTTS):
super().__init__()
# don't use isinstance, to avoid recursive imports
if config.__class__.__name__ == "FastPitchConfig":
if "Config" in config.__class__.__name__:
if "characters" in config:
# loading from a *Config
_, self.config, num_chars = self.get_characters(config)
config.model_args.num_chars = num_chars
self.args = self.config.model_args
else:
# loading from FastPitchArgs
# loading from ForwardTTSArgs
self.config = config
self.args = config.model_args
elif isinstance(config, FastPitchArgs):
elif isinstance(config, ForwardTTSArgs):
self.args = config
self.config = config
else:
raise ValueError("config must be either a VitsConfig or Vitsself.args")
raise ValueError("config must be either a *Config or ForwardTTSArgs")
self.max_duration = self.args.max_duration
self.use_aligner = self.args.use_aligner
self.use_pitch = self.args.use_pitch
self.use_binary_alignment_loss = False
self.length_scale = (
@ -208,19 +216,19 @@ class FastPitch(BaseTTS):
self.args.duration_predictor_dropout_p,
)
self.pitch_predictor = DurationPredictor(
self.args.hidden_channels + self.args.d_vector_dim,
self.args.pitch_predictor_hidden_channels,
self.args.pitch_predictor_kernel_size,
self.args.pitch_predictor_dropout_p,
)
self.pitch_emb = nn.Conv1d(
1,
self.args.hidden_channels,
kernel_size=self.args.pitch_embedding_kernel_size,
padding=int((self.args.pitch_embedding_kernel_size - 1) / 2),
)
if self.args.use_pitch:
self.pitch_predictor = DurationPredictor(
self.args.hidden_channels + self.args.d_vector_dim,
self.args.pitch_predictor_hidden_channels,
self.args.pitch_predictor_kernel_size,
self.args.pitch_predictor_dropout_p,
)
self.pitch_emb = nn.Conv1d(
1,
self.args.hidden_channels,
kernel_size=self.args.pitch_embedding_kernel_size,
padding=int((self.args.pitch_embedding_kernel_size - 1) / 2),
)
if self.args.num_speakers > 1 and not self.args.use_d_vector:
# speaker embedding layer
@ -257,18 +265,22 @@ class FastPitch(BaseTTS):
"""Generate attention alignment map from durations and
expand encoder outputs
Shapes
Shapes:
- en: :math:`(B, D_{en}, T_{en})`
- dr: :math:`(B, T_{en})`
- x_mask: :math:`(B, T_{en})`
- y_mask: :math:`(B, T_{de})`
Examples:
- encoder output: :math:`[a,b,c,d]`
- durations: :math:`[1, 3, 2, 1]`
Examples::
- expanded: :math:`[a, b, b, b, c, c, d]`
- attention map: :math:`[[0, 0, 0, 0, 0, 0, 1], [0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 1, 0, 0, 0], [1, 0, 0, 0, 0, 0, 0]]`
encoder output: [a,b,c,d]
durations: [1, 3, 2, 1]
expanded: [a, b, b, b, c, c, d]
attention map: [[0, 0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 1, 1, 0],
[0, 1, 1, 1, 0, 0, 0],
[1, 0, 0, 0, 0, 0, 0]]
"""
attn = self.generate_attn(dr, x_mask, y_mask)
o_en_ex = torch.matmul(attn.squeeze(1).transpose(1, 2).to(en.dtype), en.transpose(1, 2)).transpose(1, 2)
@ -416,7 +428,7 @@ class FastPitch(BaseTTS):
"""
o_pitch = self.pitch_predictor(o_en, x_mask)
if pitch is not None:
avg_pitch = average_pitch(pitch, dr)
avg_pitch = average_over_durations(pitch, dr)
o_pitch_emb = self.pitch_emb(avg_pitch)
return o_pitch_emb, o_pitch, avg_pitch
o_pitch_emb = self.pitch_emb(o_pitch)
@ -471,7 +483,7 @@ class FastPitch(BaseTTS):
y: torch.FloatTensor = None,
dr: torch.IntTensor = None,
pitch: torch.FloatTensor = None,
aux_input: Dict = {"d_vectors": 0, "speaker_ids": None}, # pylint: disable=unused-argument
aux_input: Dict = {"d_vectors": None, "speaker_ids": None}, # pylint: disable=unused-argument
) -> Dict:
"""Model's forward pass.
@ -479,10 +491,10 @@ class FastPitch(BaseTTS):
x (torch.LongTensor): Input character sequences.
x_lengths (torch.LongTensor): Input sequence lengths.
y_lengths (torch.LongTensor): Output sequence lengths. Defaults to None.
y (torch.FloatTensor): Spectrogram frames. Defaults to None.
dr (torch.IntTensor): Character durations over the spectrogram frames. Defaults to None.
pitch (torch.FloatTensor): Pitch values for each spectrogram frame. Defaults to None.
aux_input (Dict): Auxiliary model inputs. Defaults to `{"d_vectors": 0, "speaker_ids": None}`.
y (torch.FloatTensor): Spectrogram frames. Only used when the alignment network is on. Defaults to None.
dr (torch.IntTensor): Character durations over the spectrogram frames. Only used when the alignment network is off. Defaults to None.
pitch (torch.FloatTensor): Pitch values for each spectrogram frame. Only used when the pitch predictor is on. Defaults to None.
aux_input (Dict): Auxiliary model inputs for multi-speaker training. Defaults to `{"d_vectors": None, "speaker_ids": None}`.
Shapes:
- x: :math:`[B, T_max]`
@ -495,8 +507,8 @@ class FastPitch(BaseTTS):
"""
g = aux_input["d_vectors"] if "d_vectors" in aux_input else None
# compute sequence masks
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(y.dtype)
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(y.dtype)
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).float()
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).float()
# encoder pass
o_en, o_en_dp, x_mask, g, x_emb = self._forward_encoder(x, x_mask, g)
# duration predictor pass
@ -507,27 +519,36 @@ class FastPitch(BaseTTS):
o_dr = torch.clamp(torch.exp(o_dr_log) - 1, 0, self.max_duration)
# generate attn mask from predicted durations
o_attn = self.generate_attn(o_dr.squeeze(1), x_mask)
# aligner pass
# aligner
o_alignment_dur = None
alignment_soft = None
alignment_logprob = None
alignment_mas = None
if self.use_aligner:
o_alignment_dur, alignment_soft, alignment_logprob, alignment_mas = self._forward_aligner(
x_emb, y, x_mask, y_mask
)
alignment_soft = alignment_soft.transpose(1, 2)
alignment_mas = alignment_mas.transpose(1, 2)
dr = o_alignment_dur
# pitch predictor pass
o_pitch_emb, o_pitch, avg_pitch = self._forward_pitch_predictor(o_en_dp, x_mask, pitch, dr)
o_en = o_en + o_pitch_emb
o_pitch = None
avg_pitch = None
if self.args.use_pitch:
o_pitch_emb, o_pitch, avg_pitch = self._forward_pitch_predictor(o_en_dp, x_mask, pitch, dr)
o_en = o_en + o_pitch_emb
# decoder pass
o_de, attn = self._forward_decoder(o_en, dr, x_mask, y_lengths, g=g)
outputs = {
"model_outputs": o_de,
"durations_log": o_dr_log.squeeze(1),
"durations": o_dr.squeeze(1),
"attn_durations": o_attn, # for visualization
"model_outputs": o_de, # [B, T, C]
"durations_log": o_dr_log.squeeze(1), # [B, T]
"durations": o_dr.squeeze(1), # [B, T]
"attn_durations": o_attn, # for visualization [B, T_en, T_de']
"pitch_avg": o_pitch,
"pitch_avg_gt": avg_pitch,
"alignments": attn,
"alignment_soft": alignment_soft.transpose(1, 2),
"alignment_mas": alignment_mas.transpose(1, 2),
"alignments": attn, # [B, T_de, T_en]
"alignment_soft": alignment_soft,
"alignment_mas": alignment_mas,
"o_alignment_dur": o_alignment_dur,
"alignment_logprob": alignment_logprob,
"x_mask": x_mask,
@ -558,8 +579,10 @@ class FastPitch(BaseTTS):
o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1)
y_lengths = o_dr.sum(1)
# pitch predictor pass
o_pitch_emb, o_pitch = self._forward_pitch_predictor(o_en_dp, x_mask)
o_en = o_en + o_pitch_emb
o_pitch = None
if self.args.use_pitch:
o_pitch_emb, o_pitch = self._forward_pitch_predictor(o_en_dp, x_mask)
o_en = o_en + o_pitch_emb
# decoder pass
o_de, attn = self._forward_decoder(o_en, o_dr, x_mask, y_lengths, g=g)
outputs = {
@ -575,7 +598,7 @@ class FastPitch(BaseTTS):
text_lengths = batch["text_lengths"]
mel_input = batch["mel_input"]
mel_lengths = batch["mel_lengths"]
pitch = batch["pitch"]
pitch = batch["pitch"] if self.args.use_pitch else None
d_vectors = batch["d_vectors"]
speaker_ids = batch["speaker_ids"]
durations = batch["durations"]
@ -597,10 +620,10 @@ class FastPitch(BaseTTS):
decoder_output_lens=mel_lengths,
dur_output=outputs["durations_log"],
dur_target=durations,
pitch_output=outputs["pitch_avg"],
pitch_target=outputs["pitch_avg_gt"],
pitch_output=outputs["pitch_avg"] if self.use_pitch else None,
pitch_target=outputs["pitch_avg_gt"] if self.use_pitch else None,
input_lens=text_lengths,
alignment_logprob=outputs["alignment_logprob"],
alignment_logprob=outputs["alignment_logprob"] if self.use_aligner else None,
alignment_soft=outputs["alignment_soft"] if self.use_binary_alignment_loss else None,
alignment_hard=outputs["alignment_mas"] if self.use_binary_alignment_loss else None,
)
@ -615,28 +638,33 @@ class FastPitch(BaseTTS):
model_outputs = outputs["model_outputs"]
alignments = outputs["alignments"]
mel_input = batch["mel_input"]
pitch = batch["pitch"]
pitch_avg_expanded, _ = self.expand_encoder_outputs(
outputs["pitch_avg"], outputs["durations"], outputs["x_mask"], outputs["y_mask"]
)
pred_spec = model_outputs[0].data.cpu().numpy()
gt_spec = mel_input[0].data.cpu().numpy()
align_img = alignments[0].data.cpu().numpy()
pitch = pitch[0, 0].data.cpu().numpy()
# TODO: denormalize before plotting
pitch = abs(pitch)
pitch_avg_expanded = abs(pitch_avg_expanded[0, 0]).data.cpu().numpy()
figures = {
"prediction": plot_spectrogram(pred_spec, ap, output_fig=False),
"ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False),
"alignment": plot_alignment(align_img, output_fig=False),
"pitch_ground_truth": plot_pitch(pitch, gt_spec, ap, output_fig=False),
"pitch_avg_predicted": plot_pitch(pitch_avg_expanded, pred_spec, ap, output_fig=False),
}
# plot pitch figures
if self.args.use_pitch:
pitch = batch["pitch"]
pitch_avg_expanded, _ = self.expand_encoder_outputs(
outputs["pitch_avg"], outputs["durations"], outputs["x_mask"], outputs["y_mask"]
)
pitch = pitch[0, 0].data.cpu().numpy()
# TODO: denormalize before plotting
pitch = abs(pitch)
pitch_avg_expanded = abs(pitch_avg_expanded[0, 0]).data.cpu().numpy()
pitch_figures = {
"pitch_ground_truth": plot_pitch(pitch, gt_spec, ap, output_fig=False),
"pitch_avg_predicted": plot_pitch(pitch_avg_expanded, pred_spec, ap, output_fig=False),
}
figures.update(pitch_figures)
# plot the attention mask computed from the predicted durations
if "attn_durations" in outputs:
alignments_hat = outputs["attn_durations"][0].data.cpu().numpy()
@ -662,36 +690,11 @@ class FastPitch(BaseTTS):
assert not self.training
def get_criterion(self):
from TTS.tts.layers.losses import FastPitchLoss # pylint: disable=import-outside-toplevel
from TTS.tts.layers.losses import ForwardTTSLoss # pylint: disable=import-outside-toplevel
return FastPitchLoss(self.config)
return ForwardTTSLoss(self.config)
def on_train_step_start(self, trainer):
"""Enable binary alignment loss when needed"""
if trainer.total_steps_done > self.config.binary_align_loss_start_step:
self.use_binary_alignment_loss = True
def average_pitch(pitch, durs):
"""Compute the average pitch value for each input character based on the durations.
Shapes:
- pitch: :math:`[B, 1, T_de]`
- durs: :math:`[B, T_en]`
"""
durs_cums_ends = torch.cumsum(durs, dim=1).long()
durs_cums_starts = torch.nn.functional.pad(durs_cums_ends[:, :-1], (1, 0))
pitch_nonzero_cums = torch.nn.functional.pad(torch.cumsum(pitch != 0.0, dim=2), (1, 0))
pitch_cums = torch.nn.functional.pad(torch.cumsum(pitch, dim=2), (1, 0))
bs, l = durs_cums_ends.size()
n_formants = pitch.size(1)
dcs = durs_cums_starts[:, None, :].expand(bs, n_formants, l)
dce = durs_cums_ends[:, None, :].expand(bs, n_formants, l)
pitch_sums = (torch.gather(pitch_cums, 2, dce) - torch.gather(pitch_cums, 2, dcs)).float()
pitch_nelems = (torch.gather(pitch_nonzero_cums, 2, dce) - torch.gather(pitch_nonzero_cums, 2, dcs)).float()
pitch_avg = torch.where(pitch_nelems == 0.0, pitch_nelems, pitch_sums / pitch_nelems)
return pitch_avg
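
The `average_pitch` helper removed above is replaced by `average_over_durations` from `TTS.tts.utils.helpers` (see the import and call-site changes earlier in this file). A small worked example, assuming the relocated helper keeps this behavior; the values are illustrative:

import torch
from TTS.tts.utils.helpers import average_over_durations  # new location per this commit

pitch = torch.tensor([[[100.0, 110.0, 0.0, 200.0, 210.0, 220.0]]])  # [B=1, 1, T_de=6]
durs = torch.tensor([[2, 1, 3]])                                    # [B=1, T_en=3]
avg = average_over_durations(pitch, durs)
# avg -> [[[105.0, 0.0, 210.0]]]: the mean over the voiced (non-zero) frames of each
# character, and 0.0 where all of a character's frames are unvoiced.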

View File

@ -7,9 +7,8 @@ from torch.nn import functional as F
from TTS.tts.configs import GlowTTSConfig
from TTS.tts.layers.glow_tts.decoder import Decoder
from TTS.tts.layers.glow_tts.encoder import Encoder
from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.data import sequence_mask
from TTS.tts.utils.helpers import generate_path, maximum_path, sequence_mask
from TTS.tts.utils.speakers import get_speaker_manager
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
@ -133,7 +132,7 @@ class GlowTTS(BaseTTS):
return y_mean, y_log_scale, o_attn_dur
def forward(
self, x, x_lengths, y, y_lengths=None, aux_input={"d_vectors": None, 'speaker_ids':None}
self, x, x_lengths, y, y_lengths=None, aux_input={"d_vectors": None, "speaker_ids": None}
): # pylint: disable=dangerous-default-value
"""
Shapes:
@ -185,7 +184,7 @@ class GlowTTS(BaseTTS):
@torch.no_grad()
def inference_with_MAS(
self, x, x_lengths, y=None, y_lengths=None, aux_input={"d_vectors": None, 'speaker_ids':None}
self, x, x_lengths, y=None, y_lengths=None, aux_input={"d_vectors": None, "speaker_ids": None}
): # pylint: disable=dangerous-default-value
"""
It's similar to the teacher forcing in Tacotron.
@ -246,7 +245,7 @@ class GlowTTS(BaseTTS):
@torch.no_grad()
def decoder_inference(
self, y, y_lengths=None, aux_input={"d_vectors": None, 'speaker_ids':None}
self, y, y_lengths=None, aux_input={"d_vectors": None, "speaker_ids": None}
): # pylint: disable=dangerous-default-value
"""
Shapes:
@ -278,7 +277,9 @@ class GlowTTS(BaseTTS):
return outputs
@torch.no_grad()
def inference(self, x, aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids":None}): # pylint: disable=dangerous-default-value
def inference(
self, x, aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None}
): # pylint: disable=dangerous-default-value
x_lengths = aux_input["x_lengths"]
g = aux_input["d_vectors"] if aux_input is not None and "d_vectors" in aux_input else None
@ -331,7 +332,13 @@ class GlowTTS(BaseTTS):
d_vectors = batch["d_vectors"]
speaker_ids = batch["speaker_ids"]
outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, aux_input={"d_vectors": d_vectors, "speaker_ids":speaker_ids})
outputs = self.forward(
text_input,
text_lengths,
mel_input,
mel_lengths,
aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids},
)
loss_dict = criterion(
outputs["model_outputs"],

View File

@ -1,320 +0,0 @@
from dataclasses import dataclass, field
import torch
from coqpit import Coqpit
from torch import nn
from TTS.tts.layers.feed_forward.decoder import Decoder
from TTS.tts.layers.feed_forward.duration_predictor import DurationPredictor
from TTS.tts.layers.feed_forward.encoder import Encoder
from TTS.tts.layers.generic.pos_encoding import PositionalEncoding
from TTS.tts.layers.glow_tts.monotonic_align import generate_path
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.data import sequence_mask
from TTS.tts.utils.measures import alignment_diagonal_score
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.audio import AudioProcessor
from TTS.utils.io import load_fsspec
@dataclass
class SpeedySpeechArgs(Coqpit):
"""
Args:
num_chars (int): number of unique input to characters
out_channels (int): number of output tensor channels. It is equal to the expected spectrogram size.
hidden_channels (int): number of channels in all the model layers.
positional_encoding (bool, optional): enable/disable Positional encoding on encoder outputs. Defaults to True.
length_scale (int, optional): coefficient to set the speech speed. <1 slower, >1 faster. Defaults to 1.
encoder_type (str, optional): set the encoder type. Defaults to 'residual_conv_bn'.
encoder_params (dict, optional): set encoder parameters depending on 'encoder_type'. Defaults to { "kernel_size": 4, "dilations": 4 * [1, 2, 4] + [1], "num_conv_blocks": 2, "num_res_blocks": 13 }.
decoder_type (str, optional): decoder type. Defaults to 'residual_conv_bn'.
decoder_params (dict, optional): set decoder parameters depending on 'decoder_type'. Defaults to { "kernel_size": 4, "dilations": 4 * [1, 2, 4, 8] + [1], "num_conv_blocks": 2, "num_res_blocks": 17 }.
num_speakers (int, optional): number of speakers for multi-speaker training. Defaults to 0.
use_d_vector (bool, optional): enable external speaker embeddings. Defaults to False.
d_vector_dim (int, optional): number of channels in speaker embedding vectors. Defaults to 0.
"""
num_chars: int = None
out_channels: int = 80
hidden_channels: int = 128
num_speakers: int = 0
positional_encoding: bool = True
length_scale: int = 1
encoder_type: str = "residual_conv_bn"
encoder_params: dict = field(
default_factory=lambda: {
"kernel_size": 4,
"dilations": 4 * [1, 2, 4] + [1],
"num_conv_blocks": 2,
"num_res_blocks": 13,
}
)
decoder_type: str = "residual_conv_bn"
decoder_params: dict = field(
default_factory=lambda: {
"kernel_size": 4,
"dilations": 4 * [1, 2, 4, 8] + [1],
"num_conv_blocks": 2,
"num_res_blocks": 17,
}
)
use_d_vector: bool = False
d_vector_dim: int = 0
class SpeedySpeech(BaseTTS):
"""Speedy Speech model
https://arxiv.org/abs/2008.03802
Encoder -> DurationPredictor -> Decoder
Paper abstract:
While recent neural sequence-to-sequence models have greatly improved the quality of speech
synthesis, there has not been a system capable of fast training, fast inference and high-quality audio synthesis
at the same time. We propose a student-teacher network capable of high-quality faster-than-real-time spectrogram
synthesis, with low requirements on computational resources and fast training time. We show that self-attention
layers are not necessary for generation of high quality audio. We utilize simple convolutional blocks with
residual connections in both student and teacher networks and use only a single attention layer in the teacher
model. Coupled with a MelGAN vocoder, our model's voice quality was rated significantly higher than Tacotron 2.
Our model can be efficiently trained on a single GPU and can run in real time even on a CPU. We provide both
our source code and audio samples in our GitHub repository.
Notes:
The vanilla model is able to achieve a reasonable performance with only
~3M model parameters and convolutional layers.
This model requires precomputed phoneme durations to train a duration predictor. At inference
it only uses the duration predictor to compute durations and expand encoder outputs respectively.
You can also mix and match different encoder and decoder networks beyond the paper.
Check `SpeedySpeechArgs` for arguments.
"""
# pylint: disable=dangerous-default-value
def __init__(self, config: Coqpit):
super().__init__()
self.config = config
if "characters" in config:
_, self.config, self.num_chars = self.get_characters(config)
self.length_scale = (
float(config.model_args.length_scale)
if isinstance(config.model_args.length_scale, int)
else config.model_args.length_scale
)
self.emb = nn.Embedding(self.num_chars, config.model_args.hidden_channels)
self.encoder = Encoder(
config.model_args.hidden_channels,
config.model_args.hidden_channels,
config.model_args.encoder_type,
config.model_args.encoder_params,
config.model_args.d_vector_dim,
)
if config.model_args.positional_encoding:
self.pos_encoder = PositionalEncoding(config.model_args.hidden_channels)
self.decoder = Decoder(
config.model_args.out_channels,
config.model_args.hidden_channels,
config.model_args.decoder_type,
config.model_args.decoder_params,
)
self.duration_predictor = DurationPredictor(config.model_args.hidden_channels + config.model_args.d_vector_dim)
if config.model_args.num_speakers > 1 and not config.model_args.use_d_vector:
# speaker embedding layer
self.emb_g = nn.Embedding(config.model_args.num_speakers, config.model_args.d_vector_dim)
nn.init.uniform_(self.emb_g.weight, -0.1, 0.1)
if config.model_args.d_vector_dim > 0 and config.model_args.d_vector_dim != config.model_args.hidden_channels:
self.proj_g = nn.Conv1d(config.model_args.d_vector_dim, config.model_args.hidden_channels, 1)
@staticmethod
def expand_encoder_outputs(en, dr, x_mask, y_mask):
"""Generate attention alignment map from durations and
expand encoder outputs
Example:
encoder output: [a,b,c,d]
durations: [1, 3, 2, 1]
expanded: [a, b, b, b, c, c, d]
attention map: [[0, 0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 1, 1, 0],
[0, 1, 1, 1, 0, 0, 0],
[1, 0, 0, 0, 0, 0, 0]]
"""
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
attn = generate_path(dr, attn_mask.squeeze(1)).to(en.dtype)
o_en_ex = torch.matmul(attn.squeeze(1).transpose(1, 2), en.transpose(1, 2)).transpose(1, 2)
return o_en_ex, attn
def format_durations(self, o_dr_log, x_mask):
o_dr = (torch.exp(o_dr_log) - 1) * x_mask * self.length_scale
o_dr[o_dr < 1] = 1.0
o_dr = torch.round(o_dr)
return o_dr
@staticmethod
def _concat_speaker_embedding(o_en, g):
g_exp = g.expand(-1, -1, o_en.size(-1)) # [B, C, T_en]
o_en = torch.cat([o_en, g_exp], 1)
return o_en
def _sum_speaker_embedding(self, x, g):
# project g to decoder dim.
if hasattr(self, "proj_g"):
g = self.proj_g(g)
return x + g
def _forward_encoder(self, x, x_lengths, g=None):
if hasattr(self, "emb_g"):
g = nn.functional.normalize(self.emb_g(g)) # [B, C, 1]
if g is not None:
g = g.unsqueeze(-1)
# [B, T, C]
x_emb = self.emb(x)
# [B, C, T]
x_emb = torch.transpose(x_emb, 1, -1)
# compute sequence masks
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype)
# encoder pass
o_en = self.encoder(x_emb, x_mask)
# speaker conditioning for duration predictor
if g is not None:
o_en_dp = self._concat_speaker_embedding(o_en, g)
else:
o_en_dp = o_en
return o_en, o_en_dp, x_mask, g
def _forward_decoder(self, o_en, o_en_dp, dr, x_mask, y_lengths, g):
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en_dp.dtype)
# expand o_en with durations
o_en_ex, attn = self.expand_encoder_outputs(o_en, dr, x_mask, y_mask)
# positional encoding
if hasattr(self, "pos_encoder"):
o_en_ex = self.pos_encoder(o_en_ex, y_mask)
# speaker embedding
if g is not None:
o_en_ex = self._sum_speaker_embedding(o_en_ex, g)
# decoder pass
o_de = self.decoder(o_en_ex, y_mask, g=g)
return o_de, attn.transpose(1, 2)
def forward(
self, x, x_lengths, y_lengths, dr, aux_input={"d_vectors": None, "speaker_ids": None}
): # pylint: disable=unused-argument
"""
TODO: speaker embedding for speaker_ids
Shapes:
x: [B, T_max]
x_lengths: [B]
y_lengths: [B]
dr: [B, T_max]
g: [B, C]
"""
g = aux_input["d_vectors"] if "d_vectors" in aux_input else None
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask)
o_de, attn = self._forward_decoder(o_en, o_en_dp, dr, x_mask, y_lengths, g=g)
outputs = {"model_outputs": o_de.transpose(1, 2), "durations_log": o_dr_log.squeeze(1), "alignments": attn}
return outputs
@torch.no_grad()
def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None}): # pylint: disable=unused-argument
"""
Shapes:
x: [B, T_max]
x_lengths: [B]
g: [B, C]
"""
g = aux_input["d_vectors"] if "d_vectors" in aux_input else None
x_lengths = torch.tensor(x.shape[1:2]).to(x.device)
# input sequence should be greater than the max convolution size
inference_padding = 5
if x.shape[1] < 13:
inference_padding += 13 - x.shape[1]
# pad input to prevent dropping the last word
x = torch.nn.functional.pad(x, pad=(0, inference_padding), mode="constant", value=0)
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
# duration predictor pass
o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask)
o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1)
y_lengths = o_dr.sum(1)
o_de, attn = self._forward_decoder(o_en, o_en_dp, o_dr, x_mask, y_lengths, g=g)
outputs = {"model_outputs": o_de.transpose(1, 2), "alignments": attn, "durations_log": None}
return outputs
def train_step(self, batch: dict, criterion: nn.Module):
text_input = batch["text_input"]
text_lengths = batch["text_lengths"]
mel_input = batch["mel_input"]
mel_lengths = batch["mel_lengths"]
d_vectors = batch["d_vectors"]
speaker_ids = batch["speaker_ids"]
durations = batch["durations"]
aux_input = {"d_vectors": d_vectors, "speaker_ids": speaker_ids}
outputs = self.forward(text_input, text_lengths, mel_lengths, durations, aux_input)
# compute loss
loss_dict = criterion(
outputs["model_outputs"],
mel_input,
mel_lengths,
outputs["durations_log"],
torch.log(1 + durations),
text_lengths,
)
# compute alignment error (the lower, the better)
align_error = 1 - alignment_diagonal_score(outputs["alignments"], binary=True)
loss_dict["align_error"] = align_error
return outputs, loss_dict
def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict): # pylint: disable=no-self-use
model_outputs = outputs["model_outputs"]
alignments = outputs["alignments"]
mel_input = batch["mel_input"]
pred_spec = model_outputs[0].data.cpu().numpy()
gt_spec = mel_input[0].data.cpu().numpy()
align_img = alignments[0].data.cpu().numpy()
figures = {
"prediction": plot_spectrogram(pred_spec, ap, output_fig=False),
"ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False),
"alignment": plot_alignment(align_img, output_fig=False),
}
# Sample audio
train_audio = ap.inv_melspectrogram(pred_spec.T)
return figures, {"audio": train_audio}
def eval_step(self, batch: dict, criterion: nn.Module):
return self.train_step(batch, criterion)
def eval_log(self, ap: AudioProcessor, batch: dict, outputs: dict):
return self.train_log(ap, batch, outputs)
def load_checkpoint(
self, config, checkpoint_path, eval=False
): # pylint: disable=unused-argument, redefined-builtin
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"])
if eval:
self.eval()
assert not self.training
def get_criterion(self):
from TTS.tts.layers.losses import SpeedySpeechLoss # pylint: disable=import-outside-toplevel
return SpeedySpeechLoss(self.config)
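A minimal usage sketch of the class above, mirroring the SpeedySpeech unit test removed later in this diff (the tensor shapes are illustrative):
```python
import torch
from TTS.tts.configs import SpeedySpeechConfig
from TTS.tts.models.speedy_speech import SpeedySpeech, SpeedySpeechArgs

config = SpeedySpeechConfig(model_args=SpeedySpeechArgs(num_chars=7, out_channels=80, hidden_channels=128))
model = SpeedySpeech(config)

x = torch.randint(0, 7, (8, 37))                       # token ids [B, T_en]
x_lengths = torch.full((8,), 37, dtype=torch.long)     # input lengths [B]
durations = torch.full((8, 37), 2, dtype=torch.long)   # precomputed per-token durations [B, T_en]
y_lengths = durations.sum(1)                           # expected spectrogram lengths [B]

outputs = model(x, x_lengths, y_lengths, durations)
print(outputs["model_outputs"].shape)                  # [B, T_de, 80]
```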

View File

@ -9,12 +9,11 @@ from torch import nn
from torch.cuda.amp.autocast_mode import autocast
from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path
from TTS.tts.layers.vits.discriminator import VitsDiscriminator
from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder
from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.data import sequence_mask
from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, segment, sequence_mask
from TTS.tts.utils.speakers import get_speaker_manager
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.visual import plot_alignment
@ -24,28 +23,6 @@ from TTS.vocoder.models.hifigan_generator import HifiganGenerator
from TTS.vocoder.utils.generic_utils import plot_results
def segment(x: torch.tensor, segment_indices: torch.tensor, segment_size=4):
"""Segment each sample in a batch based on the provided segment indices"""
segments = torch.zeros_like(x[:, :, :segment_size])
for i in range(x.size(0)):
index_start = segment_indices[i]
index_end = index_start + segment_size
segments[i] = x[i, :, index_start:index_end]
return segments
def rand_segment(x: torch.tensor, x_lengths: torch.tensor = None, segment_size=4):
"""Create random segments based on the input lengths."""
B, _, T = x.size()
if x_lengths is None:
x_lengths = T
max_idxs = x_lengths - segment_size + 1
assert all(max_idxs > 0), " [!] At least one sample is shorter than the segment size."
segment_indices = (torch.rand([B]).type_as(x) * max_idxs).long()
ret = segment(x, segment_indices, segment_size)
return ret, segment_indices
@dataclass
class VitsArgs(Coqpit):
"""VITS model arguments.
@ -451,7 +428,7 @@ class Vits(BaseTTS):
logs_p = torch.einsum("klmn, kjm -> kjn", [attn, logs_p])
# select a random feature segment for the waveform decoder
z_slice, slice_ids = rand_segment(z, y_lengths, self.spec_segment_size)
z_slice, slice_ids = rand_segments(z, y_lengths, self.spec_segment_size)
o = self.waveform_decoder(z_slice, g=g)
outputs.update(
{

View File

@ -1,5 +1,4 @@
import numpy as np
import torch
def _pad_data(x, length):
@ -52,35 +51,3 @@ def prepare_stop_target(inputs, out_steps):
def pad_per_step(inputs, pad_len):
return np.pad(inputs, [[0, 0], [0, 0], [0, pad_len]], mode="constant", constant_values=0.0)
# pylint: disable=attribute-defined-outside-init
class StandardScaler:
def set_stats(self, mean, scale):
self.mean_ = mean
self.scale_ = scale
def reset_stats(self):
delattr(self, "mean_")
delattr(self, "scale_")
def transform(self, X):
X = np.asarray(X)
X -= self.mean_
X /= self.scale_
return X
def inverse_transform(self, X):
X = np.asarray(X)
X *= self.scale_
X += self.mean_
return X
# from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1
def sequence_mask(sequence_length, max_len=None):
if max_len is None:
max_len = sequence_length.data.max()
seq_range = torch.arange(max_len, dtype=sequence_length.dtype, device=sequence_length.device)
# B x T_max
return seq_range.unsqueeze(0) < sequence_length.unsqueeze(1)

213
TTS/tts/utils/helpers.py Normal file
View File

@ -0,0 +1,213 @@
import numpy as np
import torch
from torch.nn import functional as F
try:
from TTS.tts.utils.monotonic_align.core import maximum_path_c
CYTHON = True
except ModuleNotFoundError:
CYTHON = False
class StandardScaler:
"""StandardScaler for mean-scale normalization with the given mean and scale values."""
def __init__(self, mean: np.ndarray = None, scale: np.ndarray = None) -> None:
self.mean_ = mean
self.scale_ = scale
def set_stats(self, mean, scale):
self.mean_ = mean
self.scale_ = scale
def reset_stats(self):
delattr(self, "mean_")
delattr(self, "scale_")
def transform(self, X):
X = np.asarray(X)
X -= self.mean_
X /= self.scale_
return X
def inverse_transform(self, X):
X = np.asarray(X)
X *= self.scale_
X += self.mean_
return X
# from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1
def sequence_mask(sequence_length, max_len=None):
"""Create a sequence mask for filtering padding in a sequence tensor.
Args:
sequence_length (torch.tensor): Sequence lengths.
max_len (int, Optional): Maximum sequence length. Defaults to None.
Shapes:
- mask: :math:`[B, T_max]`
"""
if max_len is None:
max_len = sequence_length.data.max()
seq_range = torch.arange(max_len, dtype=sequence_length.dtype, device=sequence_length.device)
# B x T_max
mask = seq_range.unsqueeze(0) < sequence_length.unsqueeze(1)
return mask
def segment(x: torch.tensor, segment_indices: torch.tensor, segment_size=4):
"""Segment each sample in a batch based on the provided segment indices
Args:
x (torch.tensor): Input tensor.
segment_indices (torch.tensor): Segment indices.
segment_size (int): Expected output segment size.
"""
segments = torch.zeros_like(x[:, :, :segment_size])
for i in range(x.size(0)):
index_start = segment_indices[i]
index_end = index_start + segment_size
segments[i] = x[i, :, index_start:index_end]
return segments
def rand_segments(x: torch.tensor, x_lengths: torch.tensor = None, segment_size=4):
"""Create random segments based on the input lengths.
Args:
x (torch.tensor): Input tensor.
x_lengths (torch.tensor): Input lengths.
segment_size (int): Expected output segment size.
Shapes:
- x: :math:`[B, C, T]`
- x_lengths: :math:`[B]`
"""
B, _, T = x.size()
if x_lengths is None:
x_lengths = T
max_idxs = x_lengths - segment_size + 1
assert all(max_idxs > 0), " [!] At least one sample is shorter than the segment size."
segment_indices = (torch.rand([B]).type_as(x) * max_idxs).long()
ret = segment(x, segment_indices, segment_size)
return ret, segment_indices
def average_over_durations(values, durs):
"""Average values over durations.
Shapes:
- values: :math:`[B, 1, T_de]`
- durs: :math:`[B, T_en]`
- avg: :math:`[B, 1, T_en]`
"""
durs_cums_ends = torch.cumsum(durs, dim=1).long()
durs_cums_starts = torch.nn.functional.pad(durs_cums_ends[:, :-1], (1, 0))
values_nonzero_cums = torch.nn.functional.pad(torch.cumsum(values != 0.0, dim=2), (1, 0))
values_cums = torch.nn.functional.pad(torch.cumsum(values, dim=2), (1, 0))
bs, l = durs_cums_ends.size()
n_formants = values.size(1)
dcs = durs_cums_starts[:, None, :].expand(bs, n_formants, l)
dce = durs_cums_ends[:, None, :].expand(bs, n_formants, l)
values_sums = (torch.gather(values_cums, 2, dce) - torch.gather(values_cums, 2, dcs)).float()
values_nelems = (torch.gather(values_nonzero_cums, 2, dce) - torch.gather(values_nonzero_cums, 2, dcs)).float()
avg = torch.where(values_nelems == 0.0, values_nelems, values_sums / values_nelems)
return avg
def convert_pad_shape(pad_shape):
l = pad_shape[::-1]
pad_shape = [item for sublist in l for item in sublist]
return pad_shape
def generate_path(duration, mask):
"""
Shapes:
- duration: :math:`[B, T_en]`
- mask: :math:`[B, T_en, T_de]`
- path: :math:`[B, T_en, T_de]`
"""
device = duration.device
b, t_x, t_y = mask.shape
cum_duration = torch.cumsum(duration, 1)
path = torch.zeros(b, t_x, t_y, dtype=mask.dtype).to(device=device)
cum_duration_flat = cum_duration.view(b * t_x)
path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
path = path.view(b, t_x, t_y)
path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
path = path * mask
return path
def maximum_path(value, mask):
if CYTHON:
return maximum_path_cython(value, mask)
return maximum_path_numpy(value, mask)
def maximum_path_cython(value, mask):
"""Cython optimised version.
Shapes:
- value: :math:`[B, T_en, T_de]`
- mask: :math:`[B, T_en, T_de]`
"""
value = value * mask
device = value.device
dtype = value.dtype
value = value.data.cpu().numpy().astype(np.float32)
path = np.zeros_like(value).astype(np.int32)
mask = mask.data.cpu().numpy()
t_x_max = mask.sum(1)[:, 0].astype(np.int32)
t_y_max = mask.sum(2)[:, 0].astype(np.int32)
maximum_path_c(path, value, t_x_max, t_y_max)
return torch.from_numpy(path).to(device=device, dtype=dtype)
def maximum_path_numpy(value, mask, max_neg_val=None):
"""
Monotonic alignment search algorithm
Numpy-friendly version. It's about 4 times faster than torch version.
value: [b, t_x, t_y]
mask: [b, t_x, t_y]
"""
if max_neg_val is None:
max_neg_val = -np.inf # Patch for Sphinx complaint
value = value * mask
device = value.device
dtype = value.dtype
value = value.cpu().detach().numpy()
mask = mask.cpu().detach().numpy().astype(np.bool)
b, t_x, t_y = value.shape
direction = np.zeros(value.shape, dtype=np.int64)
v = np.zeros((b, t_x), dtype=np.float32)
x_range = np.arange(t_x, dtype=np.float32).reshape(1, -1)
for j in range(t_y):
v0 = np.pad(v, [[0, 0], [1, 0]], mode="constant", constant_values=max_neg_val)[:, :-1]
v1 = v
max_mask = v1 >= v0
v_max = np.where(max_mask, v1, v0)
direction[:, :, j] = max_mask
index_mask = x_range <= j
v = np.where(index_mask, v_max + value[:, :, j], max_neg_val)
direction = np.where(mask, direction, 1)
path = np.zeros(value.shape, dtype=np.float32)
index = mask[:, :, 0].sum(1).astype(np.int64) - 1
index_range = np.arange(b)
for j in reversed(range(t_y)):
path[index_range, index, j] = 1
index = index + direction[index_range, index, j] - 1
path = path * mask.astype(np.float32)
path = torch.from_numpy(path).to(device=device, dtype=dtype)
return path

File diff suppressed because it is too large

View File

@ -101,6 +101,7 @@ def visualize(
figsize=(8, 24),
output_fig=False,
):
"""Intended to be used in Notebooks."""
if decoder_output is not None:
num_plot = 4

View File

@ -9,7 +9,7 @@ import soundfile as sf
import torch
from torch import nn
from TTS.tts.utils.data import StandardScaler
from TTS.tts.utils.helpers import StandardScaler
class TorchSTFT(nn.Module): # pylint: disable=abstract-method
@ -608,6 +608,9 @@ class AudioProcessor(object):
angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
S_complex = np.abs(S).astype(np.complex)
y = self._istft(S_complex * angles)
if not np.isfinite(y).all():
print(" [!] Waveform is not finite everywhere. Skipping the GL.")
return np.array([0.0])
for _ in range(self.griffin_lim_iters):
angles = np.exp(1j * np.angle(self._stft(y)))
y = self._istft(S_complex * angles)

View File

@ -7,7 +7,7 @@ We tried to collect common issues and questions we receive about 🐸TTS. It is
- If you feel like it's a bug to be fixed, then prefer Github issues with the same level of scrutiny.
## What are the requirements of a good 🐸TTS dataset?
* https://github.com/coqui-ai/TTS/wiki/What-makes-a-good-TTS-dataset
* {ref}`See this page <what_makes_a_good_dataset>`
## How should I choose the right model?
- First, train Tacotron. It is smaller and faster to experiment with. If it performs poorly, try Tacotron2.

115
docs/source/finetuning.md Normal file
View File

@ -0,0 +1,115 @@
# Fine-tuning a 🐸 TTS model
## Fine-tuning
Fine-tuning takes a pre-trained model and retrains it to improve its performance on a different task or dataset.
In 🐸TTS we provide different pre-trained models in different languages, each with its own pros and cons. You can take one of
them and fine-tune it for your own dataset. This helps you in two main ways:
1. Faster learning
Since a pre-trained model has already learned features that are relevant for the task, it converges faster on
a new dataset. This reduces the cost of training and lets you experiment faster.
2. Better results with small datasets
Deep learning models are data hungry and give better performance with more data. However, it is not always
possible to have this abundance, especially for a specific domain. For instance, the LJSpeech dataset, which we used to release most of
our English models, is almost 24 hours long, and collecting this amount of data with the help of a voice talent takes weeks.
Fine-tuning comes to the rescue in this case. You can take one of our pre-trained models, fine-tune it on your own
speech dataset and achieve reasonable results with only a couple of hours of data in the worst case.
However, note that fine-tuning does not guarantee great results. The model performance still depends on the
{ref}`dataset quality <what_makes_a_good_dataset>` and the hyper-parameters you choose for fine-tuning. Therefore,
it still demands a bit of tinkering.
## Steps to fine-tune a 🐸 TTS model
1. Set up your dataset.
You need to format your target dataset in a certain way so that the 🐸TTS data loader is able to load it for
training. Please see {ref}`this page <formatting_your_dataset>` for more information about formatting.
2. Choose the model you want to fine-tune.
You can list the available models in the terminal as
```bash
tts --list-models
```
The command above lists the models in the naming format ```<model_type>/<language>/<dataset>/<model_name>```.
Or you can manually check the `.model.json` file in the project directory.
You should choose the model based on your requirements. Some models are faster, while others produce better speech quality.
A quick way to evaluate a model is to run it on the hardware you want to use and see how it performs. For
simple testing, you can use the `tts` command in the terminal. For more info see {ref}`here <synthesizing_speech>`.
3. Download the model.
You can download a model with the `tts` command. If you run `tts` with a particular model, it is downloaded automatically
and the model path is printed on the terminal.
```bash
tts --model_name tts_models/es/mai/tacotron2-DDC --text "Ola."
> Downloading model to /home/ubuntu/.local/share/tts/tts_models--en--ljspeech--glow-tts
...
```
In the example above, we call the Spanish Tacotron model, and the sample output shows the path where
the model is downloaded.
4. Set up the model config for fine-tuning.
You need to change certain fields in the model config. You have 3 options for updating the configuration.
1. Edit the fields in the ```config.json``` file if you want to use ```TTS/bin/train_tts.py``` to train the model.
2. Edit the fields in one of the training scripts in the ```recipes``` directory if you want to use python.
3. Use the command-line arguments to override the fields like ```--coqpit.lr 0.00001``` to change the learning rate.
Some of the important fields are as follows:
- `datasets` field: This is set to the dataset you want to fine-tune the model on.
- `run_name` field: This is the name of the run. This is used to name the output directory and the entry in the
logging dashboard.
- `output_path` field: This is the path where the fine-tuned model is saved.
- `lr` field: You may need to use a smaller learning rate for fine-tuning so that large update steps do not impair the
features learned by the pre-trained model.
- `audio` fields: Different datasets have different audio characteristics. You must check the current audio parameters and
make sure that the values reflect your dataset. For instance, your dataset might have a different audio sampling rate.
Apart from the fields above, you should check the whole configuration file and make sure that the values are correct for
your dataset and training.
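For example, a minimal sketch of editing these fields programmatically (the paths are illustrative; you can achieve the same result by editing ```config.json``` by hand):
```python
from TTS.config import load_config

# path printed by `tts` when the pre-trained model was downloaded (illustrative)
config = load_config("/home/ubuntu/.local/share/tts/tts_models--en--ljspeech--glow-tts/config.json")
config.run_name = "glow-tts-finetune"               # used for the output folder and the dashboard entry
config.output_path = "finetune_output/"             # where the fine-tuned model is saved
config.lr = 1e-5                                    # smaller learning rate for fine-tuning
config.datasets[0].path = "/path/to/your/dataset/"  # point the dataset config to your own data
config.save_json("finetune_config.json")            # pass this file to --config_path below
```
You can then pass the saved config to ```TTS/bin/train_tts.py``` with ```--config_path``` together with ```--restore_path``` as shown in the next step.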
5. Start fine-tuning.
Whether you use one of the training scripts under the ```recipes``` folder or ```train_tts.py``` to start
your training, you should use the ```--restore_path``` flag to specify the path to the pre-trained model.
```bash
CUDA_VISIBLE_DEVICES="0" python recipes/ljspeech/glow_tts/train_glowtts.py \
--restore_path /home/ubuntu/.local/share/tts/tts_models--en--ljspeech--glow-tts
```
```bash
CUDA_VISIBLE_DEVICES="0" python TTS/bin/train_tts.py \
--config_path /home/ubuntu/.local/share/tts/tts_models--en--ljspeech--glow-tts/config.json \
--restore_path /home/ubuntu/.local/share/tts/tts_models--en--ljspeech--glow-tts
```
As stated above, you can also use command-line arguments to change the model configuration.
```bash
CUDA_VISIBLE_DEVICES="0" python recipes/ljspeech/glow_tts/train_glowtts.py \
--restore_path /home/ubuntu/.local/share/tts/tts_models--en--ljspeech--glow-tts
--coqpit.run_name "glow-tts-finetune" \
--coqpit.lr 0.00001
```

View File

@ -1,3 +1,4 @@
(formatting_your_dataset)=
# Formatting Your Dataset
For training a TTS model, you need a dataset with speech recordings and transcriptions. The speech must be divided into audio clips and each clip needs transcription.
@ -18,15 +19,15 @@ Let's assume you created the audio clips and their transcription. You can collec
You can either create separate transcription files for each clip or create a text file that maps each audio clip to its transcription. In this file, each line must be delimited by a special character separating the audio file name from the transcription. Make sure that the delimiter is not used in the transcription text.
We recommend the following format delimited by `|`.
We recommend the following format delimited by `||`.
```
# metadata.txt
audio1.wav | This is my sentence.
audio2.wav | This is maybe my sentence.
audio3.wav | This is certainly my sentence.
audio4.wav | Let this be your sentence.
audio1.wav || This is my sentence.
audio2.wav || This is maybe my sentence.
audio3.wav || This is certainly my sentence.
audio4.wav || Let this be your sentence.
...
```

View File

@ -22,6 +22,7 @@
inference
implementing_a_new_model
training_a_model
finetuning
configuration
formatting_your_dataset
what_makes_a_good_dataset
@ -45,7 +46,7 @@
models/glow_tts.md
models/vits.md
models/fast_pitch.md
models/forward_tts.md
.. toctree::
:maxdepth: 2

View File

@ -1,28 +0,0 @@
# FastPitch
FastPitch is a feed-forward encoder-decoder TTS model. It computes mel-spectrogram from the given input character sequence.
It uses a duration predictor network to predict the duration of each input character in the output sequence. In the original paper, they use a pre-trained Tacotron model to generate the labels for the duration predictor. In this implementation, you can also use an aligner network to learn the durations from the data and train the duration predictor in parallel. Original FastPitch model uses FeedForwardTransformer networks for both encoder and decoder. But in this implementation, you have the freedom to choose different encoder and decoder networks by just changing the relevant fields in the model configuration. Please see `FastPitchArgs` and `FastPitchConfig` below for more details.
## Important resources & papers
- FastPitch: https://arxiv.org/abs/2006.06873
- FastSpeech: https://arxiv.org/pdf/1905.09263
- Aligner Network: https://arxiv.org/abs/2108.10447
- What is Pitch: https://www.britannica.com/topic/pitch-speech
## FastPitchConfig
```{eval-rst}
.. autoclass:: TTS.tts.configs.fast_pitch_config.FastPitchConfig
:members:
```
## FastPitchArgs
```{eval-rst}
.. autoclass:: TTS.tts.models.fast_pitch.FastPitchArgs
:members:
```
## FastPitch Model
```{eval-rst}
.. autoclass:: TTS.tts.models.fast_pitch.FastPitch
:members:
```

View File

@ -0,0 +1,65 @@
# Forward TTS model(s)
A general feed-forward TTS model implementation that can be configured into different architectures by setting different
encoder and decoder networks. It can be trained with either pre-computed durations (from a pre-trained Tacotron model) or
an alignment network that learns the text-to-audio alignment from the input data.
Currently we provide the following pre-configured architectures:
- **FastSpeech:**
It is a feed-forward TTS model that uses Feed Forward Transformer (FFT) modules as the encoder and decoder.
- **FastPitch:**
It uses the same FastSpeech architecture, conditioned on fundamental frequency (f0) contours, with the
promise of more expressive speech.
- **SpeedySpeech:**
It uses Residual Convolution layers instead of Transformers, which leads to a more compute-friendly model.
- **FastSpeech2 (TODO):**
Similar to FastPitch, but it also uses spectral energy values as an additional feature.
## Important resources & papers
- FastPitch: https://arxiv.org/abs/2006.06873
- SpeedySpeech: https://arxiv.org/abs/2008.03802
- FastSpeech: https://arxiv.org/pdf/1905.09263
- FastSpeech2: https://arxiv.org/abs/2006.04558
- Aligner Network: https://arxiv.org/abs/2108.10447
- What is Pitch: https://www.britannica.com/topic/pitch-speech
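A minimal usage sketch of the shared `ForwardTTS` model, mirroring the unit tests added in this PR (the tensor sizes are arbitrary):
```python
import torch
from TTS.tts.models.forward_tts import ForwardTTS, ForwardTTSArgs

# vanilla feed-forward model: no pitch predictor, durations are given externally
model = ForwardTTS(ForwardTTSArgs(num_chars=10, use_pitch=False, use_aligner=False))

x = torch.randint(0, 10, (2, 21))         # token ids [B, T_en]
x_lengths = torch.tensor([21, 21])        # input lengths [B]
durations = torch.randint(1, 4, (2, 21))  # per-token durations [B, T_en]
y_lengths = durations.sum(1)              # target spectrogram lengths [B]

outputs = model.forward(x, x_lengths, y_lengths, dr=durations)
print(outputs["model_outputs"].shape)     # [B, T_de, 80] where T_de = durations.sum(1).max()
```
Setting `use_pitch=True` enables the pitch predictor used by FastPitch, and `use_aligner=True` adds the aligner network that learns the durations from the target spectrograms during training.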
## ForwardTTSArgs
```{eval-rst}
.. autoclass:: TTS.tts.models.forward_tts.ForwardTTSArgs
:members:
```
## ForwardTTS Model
```{eval-rst}
.. autoclass:: TTS.tts.models.forward_tts.ForwardTTS
:members:
```
## FastPitchConfig
```{eval-rst}
.. autoclass:: TTS.tts.configs.fast_pitch_config.FastPitchConfig
:members:
```
## SpeedySpeechConfig
```{eval-rst}
.. autoclass:: TTS.tts.configs.speedy_speech_config.SpeedySpeechConfig
:members:
```
## FastSpeechConfig
```{eval-rst}
.. autoclass:: TTS.tts.configs.fast_speech_config.FastSpeechConfig
:members:
```

View File

@ -54,7 +54,7 @@
4. Run the training.
You need to call the python training script.
You need to run the training script.
```bash
$ CUDA_VISIBLE_DEVICES="0" python train_glowtts.py
@ -63,7 +63,7 @@
Notice that you set the GPU you want to use on your system by setting `CUDA_VISIBLE_DEVICES` environment variable.
To see available GPUs on your system, you can use `nvidia-smi` command on the terminal.
If you like to run a multi-gpu training
If you would like to run multi-gpu training using the DDP back-end,
```bash
$ CUDA_VISIBLE_DEVICES="0, 1, 2" python TTS/bin/distribute.py --script <path_to_your_script>/train_glowtts.py

View File

@ -1,3 +1,4 @@
(what_makes_a_good_dataset)=
# What makes a good TTS dataset
## What Makes a Good Dataset

View File

@ -2,16 +2,14 @@
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This is a notebook to generate mel-spectrograms from a TTS model to be used for WaveRNN training."
]
"This is a notebook to generate mel-spectrograms from a TTS model to be used in a Vocoder training."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
@ -25,22 +23,23 @@
"from TTS.tts.datasets.TTSDataset import TTSDataset\n",
"from TTS.tts.layers.losses import L1LossMasked\n",
"from TTS.utils.audio import AudioProcessor\n",
"from TTS.utils.io import load_config\n",
"from TTS.config import load_config\n",
"from TTS.tts.utils.visual import plot_spectrogram\n",
"from TTS.tts.utils.generic_utils import setup_model, sequence_mask\n",
"from TTS.tts.utils.helpers import sequence_mask\n",
"from TTS.tts.models import setup_model\n",
"from TTS.tts.utils.text.symbols import make_symbols, symbols, phonemes\n",
"\n",
"%matplotlib inline\n",
"\n",
"import os\n",
"os.environ['CUDA_VISIBLE_DEVICES']='0'"
]
"os.environ['CUDA_VISIBLE_DEVICES']='2'"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def set_filename(wav_path, out_path):\n",
" wav_file = os.path.basename(wav_path)\n",
@ -52,20 +51,20 @@
" mel_path = os.path.join(out_path, \"mel\", file_name)\n",
" wav_path = os.path.join(out_path, \"wav_gl\", file_name)\n",
" return file_name, wavq_path, mel_path, wav_path"
]
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"OUT_PATH = \"/home/erogol/gdrive/Datasets/non-binary-voice-files/tacotron-DCA\"\n",
"DATA_PATH = \"/home/erogol/gdrive/Datasets/non-binary-voice-files/\"\n",
"DATASET = \"sam_accenture\"\n",
"METADATA_FILE = \"recording_script.xml\"\n",
"CONFIG_PATH = \"/home/erogol/gdrive/Trainings/sam/ljspeech-dcattn-April-03-2021_05+02-2344379/config.json\"\n",
"MODEL_FILE = \"/home/erogol/gdrive/Trainings/sam/ljspeech-dcattn-April-03-2021_05+02-2344379/best_model.pth.tar\"\n",
"OUT_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/specs2/\"\n",
"DATA_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/\"\n",
"DATASET = \"ljspeech\"\n",
"METADATA_FILE = \"metadata.csv\"\n",
"CONFIG_PATH = \"/home/ubuntu/.local/share/tts/tts_models--en--ljspeech--tacotron2-DDC_ph/config.json\"\n",
"MODEL_FILE = \"/home/ubuntu/.local/share/tts/tts_models--en--ljspeech--tacotron2-DDC_ph/model_file.pth.tar\"\n",
"BATCH_SIZE = 32\n",
"\n",
"QUANTIZED_WAV = False\n",
@ -78,56 +77,63 @@
"C = load_config(CONFIG_PATH)\n",
"C.audio['do_trim_silence'] = False # IMPORTANT!!!!!!!!!!!!!!! disable to align mel specs with the wav files\n",
"ap = AudioProcessor(bits=QUANTIZE_BIT, **C.audio)"
]
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(C['r'])\n",
"# if the vocabulary was passed, replace the default\n",
"if 'characters' in C.keys():\n",
"if 'characters' in C and C['characters']:\n",
" symbols, phonemes = make_symbols(**C.characters)\n",
"\n",
"# load the model\n",
"num_chars = len(phonemes) if C.use_phonemes else len(symbols)\n",
"# TODO: multiple speaker\n",
"model = setup_model(num_chars, num_speakers=0, c=C)\n",
"checkpoint = torch.load(MODEL_FILE)\n",
"model.load_state_dict(checkpoint['model'])\n",
"print(checkpoint['step'])\n",
"model.eval()\n",
"model.decoder.set_r(checkpoint['r'])\n",
"if use_cuda:\n",
" model = model.cuda()"
]
"model = setup_model(C)\n",
"model.load_checkpoint(C, MODEL_FILE, eval=True)"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"preprocessor = importlib.import_module('TTS.tts.datasets.preprocess')\n",
"preprocessor = importlib.import_module(\"TTS.tts.datasets.formatters\")\n",
"preprocessor = getattr(preprocessor, DATASET.lower())\n",
"meta_data = preprocessor(DATA_PATH,METADATA_FILE)\n",
"dataset = TTSDataset(checkpoint['r'], C.text_cleaner, False, ap, meta_data,characters=c.characters if 'characters' in C.keys() else None, use_phonemes=C.use_phonemes, phoneme_cache_path=C.phoneme_cache_path, enable_eos_bos=C.enable_eos_bos_chars)\n",
"loader = DataLoader(dataset, batch_size=BATCH_SIZE, num_workers=4, collate_fn=dataset.collate_fn, shuffle=False, drop_last=False)"
]
"meta_data = preprocessor(DATA_PATH, METADATA_FILE)\n",
"dataset = TTSDataset(\n",
" checkpoint[\"config\"][\"r\"],\n",
" C.text_cleaner,\n",
" False,\n",
" ap,\n",
" meta_data,\n",
" characters=C.get('characters', None),\n",
" use_phonemes=C.use_phonemes,\n",
" phoneme_cache_path=C.phoneme_cache_path,\n",
" enable_eos_bos=C.enable_eos_bos_chars,\n",
")\n",
"loader = DataLoader(\n",
" dataset, batch_size=BATCH_SIZE, num_workers=4, collate_fn=dataset.collate_fn, shuffle=False, drop_last=False\n",
")\n"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Generate model outputs "
]
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pickle\n",
"\n",
@ -206,42 +212,42 @@
"\n",
" print(np.mean(losses))\n",
" print(np.mean(postnet_losses))"
]
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# for pwgan\n",
"with open(os.path.join(OUT_PATH, \"metadata.txt\"), \"w\") as f:\n",
" for data in metadata:\n",
" f.write(f\"{data[0]}|{data[1]+'.npy'}\\n\")"
]
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Sanity Check"
]
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"idx = 1\n",
"ap.melspectrogram(ap.load_wav(item_idx[idx])).shape"
]
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import soundfile as sf\n",
"wav, sr = sf.read(item_idx[idx])\n",
@ -249,46 +255,46 @@
"mel_decoder = mel_outputs[idx][:mel_lengths[idx], :].detach().cpu().numpy()\n",
"mel_truth = ap.melspectrogram(wav)\n",
"print(mel_truth.shape)"
]
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# plot posnet output\n",
"print(mel_postnet[:mel_lengths[idx], :].shape)\n",
"plot_spectrogram(mel_postnet, ap)"
]
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# plot decoder output\n",
"print(mel_decoder.shape)\n",
"plot_spectrogram(mel_decoder, ap)"
]
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# plot GT specgrogram\n",
"print(mel_truth.shape)\n",
"plot_spectrogram(mel_truth.T, ap)"
]
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# postnet, decoder diff\n",
"from matplotlib import pylab as plt\n",
@ -297,13 +303,13 @@
"plt.imshow(abs(mel_diff[:mel_lengths[idx],:]).T,aspect=\"auto\", origin=\"lower\");\n",
"plt.colorbar()\n",
"plt.tight_layout()"
]
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# PLOT GT SPECTROGRAM diff\n",
"from matplotlib import pylab as plt\n",
@ -312,13 +318,13 @@
"plt.imshow(abs(mel_diff2).T,aspect=\"auto\", origin=\"lower\");\n",
"plt.colorbar()\n",
"plt.tight_layout()"
]
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# PLOT GT SPECTROGRAM diff\n",
"from matplotlib import pylab as plt\n",
@ -328,21 +334,22 @@
"plt.imshow(abs(mel_diff2).T,aspect=\"auto\", origin=\"lower\");\n",
"plt.colorbar()\n",
"plt.tight_layout()"
]
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"source": [],
"outputs": [],
"source": []
"metadata": {}
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
"name": "python3",
"display_name": "Python 3.9.7 64-bit ('base': conda)"
},
"language_info": {
"codemirror_mode": {
@ -354,7 +361,10 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
"version": "3.9.7"
},
"interpreter": {
"hash": "822ce188d9bce5372c4adbb11364eeb49293228c2224eb55307f4664778e7f56"
}
},
"nbformat": 4,

View File

@ -0,0 +1,68 @@
import os
from TTS.config import BaseAudioConfig, BaseDatasetConfig
from TTS.trainer import Trainer, TrainingArgs, init_training
from TTS.tts.configs import SpeedySpeechConfig
from TTS.utils.manage import ModelManager
output_path = os.path.dirname(os.path.abspath(__file__))
# init configs
dataset_config = BaseDatasetConfig(
name="ljspeech",
meta_file_train="metadata.csv",
# meta_file_attn_mask=os.path.join(output_path, "../LJSpeech-1.1/metadata_attn_mask.txt"),
path=os.path.join(output_path, "../LJSpeech-1.1/"),
)
audio_config = BaseAudioConfig(
sample_rate=22050,
do_trim_silence=True,
trim_db=60.0,
signal_norm=False,
mel_fmin=0.0,
mel_fmax=8000,
spec_gain=1.0,
log_func="np.log",
ref_level_db=20,
preemphasis=0.0,
)
config = SpeedySpeechConfig(
run_name="speedy_speech_ljspeech",
audio=audio_config,
batch_size=32,
eval_batch_size=16,
num_loader_workers=4,
num_eval_loader_workers=4,
compute_input_seq_cache=True,
run_eval=True,
test_delay_epochs=-1,
epochs=1000,
text_cleaner="english_cleaners",
use_phonemes=True,
use_espeak_phonemes=False,
phoneme_language="en-us",
phoneme_cache_path=os.path.join(output_path, "phoneme_cache"),
print_step=50,
print_eval=False,
mixed_precision=False,
sort_by_audio_len=True,
max_seq_len=500000,
output_path=output_path,
datasets=[dataset_config],
)
# compute alignments
if not config.model_args.use_aligner:
manager = ModelManager()
model_path, config_path, _ = manager.download_model("tts_models/en/ljspeech/tacotron2-DCA")
# TODO: make compute_attention python callable
os.system(
f"python TTS/bin/compute_attention_masks.py --model_path {model_path} --config_path {config_path} --dataset ljspeech --dataset_metafile metadata.csv --data_path ./recipes/ljspeech/LJSpeech-1.1/ --use_cuda true"
)
# train the model
args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config)
trainer = Trainer(args, config, output_path, c_logger, tb_logger)
trainer.fit()

View File

@ -54,8 +54,8 @@ with open("README.md", "r", encoding="utf-8") as readme_file:
exts = [
Extension(
name="TTS.tts.layers.glow_tts.monotonic_align.core",
sources=["TTS/tts/layers/glow_tts/monotonic_align/core.pyx"],
name="TTS.tts.utils.monotonic_align.core",
sources=["TTS/tts/utils/monotonic_align/core.pyx"],
)
]
setup(

View File

@ -7,8 +7,8 @@ from TTS.utils.generic_utils import get_cuda
def get_device_id():
use_cuda, _ = get_cuda()
if use_cuda:
if 'CUDA_VISIBLE_DEVICES' in os.environ and os.environ['CUDA_VISIBLE_DEVICES'] != "":
GPU_ID = os.environ['CUDA_VISIBLE_DEVICES'].split(',')[0]
if "CUDA_VISIBLE_DEVICES" in os.environ and os.environ["CUDA_VISIBLE_DEVICES"] != "":
GPU_ID = os.environ["CUDA_VISIBLE_DEVICES"].split(",")[0]
else:
GPU_ID = "0"
else:

View File

@ -68,15 +68,15 @@ class TestTTSDataset(unittest.TestCase):
for i, data in enumerate(dataloader):
if i == self.max_loader_iter:
break
text_input = data['text']
text_lengths = data['text_lengths']
speaker_name = data['speaker_names']
linear_input = data['linear']
mel_input = data['mel']
mel_lengths = data['mel_lengths']
stop_target = data['stop_targets']
item_idx = data['item_idxs']
wavs = data['waveform']
text_input = data["text"]
text_lengths = data["text_lengths"]
speaker_name = data["speaker_names"]
linear_input = data["linear"]
mel_input = data["mel"]
mel_lengths = data["mel_lengths"]
stop_target = data["stop_targets"]
item_idx = data["item_idxs"]
wavs = data["waveform"]
neg_values = text_input[text_input < 0]
check_count = len(neg_values)
@ -113,14 +113,14 @@ class TestTTSDataset(unittest.TestCase):
for i, data in enumerate(dataloader):
if i == self.max_loader_iter:
break
text_input = data['text']
text_lengths = data['text_lengths']
speaker_name = data['speaker_names']
linear_input = data['linear']
mel_input = data['mel']
mel_lengths = data['mel_lengths']
stop_target = data['stop_targets']
item_idx = data['item_idxs']
text_input = data["text"]
text_lengths = data["text_lengths"]
speaker_name = data["speaker_names"]
linear_input = data["linear"]
mel_input = data["mel"]
mel_lengths = data["mel_lengths"]
stop_target = data["stop_targets"]
item_idx = data["item_idxs"]
avg_length = mel_lengths.numpy().mean()
assert avg_length >= last_length
@ -139,14 +139,14 @@ class TestTTSDataset(unittest.TestCase):
for i, data in enumerate(dataloader):
if i == self.max_loader_iter:
break
text_input = data['text']
text_lengths = data['text_lengths']
speaker_name = data['speaker_names']
linear_input = data['linear']
mel_input = data['mel']
mel_lengths = data['mel_lengths']
stop_target = data['stop_targets']
item_idx = data['item_idxs']
text_input = data["text"]
text_lengths = data["text_lengths"]
speaker_name = data["speaker_names"]
linear_input = data["linear"]
mel_input = data["mel"]
mel_lengths = data["mel_lengths"]
stop_target = data["stop_targets"]
item_idx = data["item_idxs"]
# check mel_spec consistency
wav = np.asarray(self.ap.load_wav(item_idx[0]), dtype=np.float32)
@ -188,14 +188,14 @@ class TestTTSDataset(unittest.TestCase):
for i, data in enumerate(dataloader):
if i == self.max_loader_iter:
break
text_input = data['text']
text_lengths = data['text_lengths']
speaker_name = data['speaker_names']
linear_input = data['linear']
mel_input = data['mel']
mel_lengths = data['mel_lengths']
stop_target = data['stop_targets']
item_idx = data['item_idxs']
text_input = data["text"]
text_lengths = data["text_lengths"]
speaker_name = data["speaker_names"]
linear_input = data["linear"]
mel_input = data["mel"]
mel_lengths = data["mel_lengths"]
stop_target = data["stop_targets"]
item_idx = data["item_idxs"]
if mel_lengths[0] > mel_lengths[1]:
idx = 0

View File

@ -11,11 +11,10 @@ def test_synthesize():
# single speaker model
run_cli(f'tts --text "This is an example." --out_path "{output_path}"')
run_cli(
"tts --model_name tts_models/en/ljspeech/speedy-speech-wn "
f'--text "This is an example." --out_path "{output_path}"'
"tts --model_name tts_models/en/ljspeech/glow-tts " f'--text "This is an example." --out_path "{output_path}"'
)
run_cli(
"tts --model_name tts_models/en/ljspeech/speedy-speech-wn "
"tts --model_name tts_models/en/ljspeech/glow-tts "
"--vocoder_name vocoder_models/en/ljspeech/multiband-melgan "
f'--text "This is an example." --out_path "{output_path}"'
)

View File

@ -1,47 +0,0 @@
import unittest
import torch as T
from TTS.tts.models.fast_pitch import FastPitch, FastPitchArgs, average_pitch
# pylint: disable=unused-variable
class AveragePitchTests(unittest.TestCase):
def test_in_out(self): # pylint: disable=no-self-use
pitch = T.rand(1, 1, 128)
durations = T.randint(1, 5, (1, 21))
coeff = 128.0 / durations.sum()
durations = T.round(durations * coeff)
diff = 128.0 - durations.sum()
durations[0, -1] += diff
durations = durations.long()
pitch_avg = average_pitch(pitch, durations)
index = 0
for idx, dur in enumerate(durations[0]):
assert abs(pitch_avg[0, 0, idx] - pitch[0, 0, index : index + dur.item()].mean()) < 1e-5
index += dur
def expand_encoder_outputs_test():
model = FastPitch(FastPitchArgs(num_chars=10))
inputs = T.rand(2, 5, 57)
durations = T.randint(1, 4, (2, 57))
x_mask = T.ones(2, 1, 57)
y_mask = T.ones(2, 1, durations.sum(1).max())
expanded, _ = model.expand_encoder_outputs(inputs, durations, x_mask, y_mask)
for b in range(durations.shape[0]):
index = 0
for idx, dur in enumerate(durations[b]):
diff = (
expanded[b, :, index : index + dur.item()]
- inputs[b, :, idx].repeat(dur.item()).view(expanded[b, :, index : index + dur.item()].shape)
).sum()
assert abs(diff) < 1e-6, diff
index += dur

View File

@ -0,0 +1,68 @@
import glob
import os
import shutil
from tests import get_device_id, get_tests_output_path, run_cli
from TTS.config.shared_configs import BaseAudioConfig
from TTS.tts.configs import FastPitchConfig
config_path = os.path.join(get_tests_output_path(), "test_fast_pitch_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs")
audio_config = BaseAudioConfig(
sample_rate=22050,
do_trim_silence=True,
trim_db=60.0,
signal_norm=False,
mel_fmin=0.0,
mel_fmax=8000,
spec_gain=1.0,
log_func="np.log",
ref_level_db=20,
preemphasis=0.0,
)
config = FastPitchConfig(
audio=audio_config,
batch_size=8,
eval_batch_size=8,
num_loader_workers=0,
num_eval_loader_workers=0,
text_cleaner="english_cleaners",
use_phonemes=True,
phoneme_language="en-us",
phoneme_cache_path="tests/data/ljspeech/phoneme_cache/",
f0_cache_path="tests/data/ljspeech/f0_cache/",
run_eval=True,
test_delay_epochs=-1,
epochs=1,
print_step=1,
print_eval=True,
test_sentences=[
"Be a voice, not an echo.",
],
)
config.audio.do_trim_silence = True
config.audio.trim_db = 60
config.save_json(config_path)
# train the model for one epoch
command_train = (
f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
f"--coqpit.output_path {output_path} "
"--coqpit.datasets.0.name ljspeech "
"--coqpit.datasets.0.meta_file_train metadata.csv "
"--coqpit.datasets.0.meta_file_val metadata.csv "
"--coqpit.datasets.0.path tests/data/ljspeech "
"--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt "
"--coqpit.test_delay_epochs 0"
)
run_cli(command_train)
# Find latest folder
continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
# restore the model and continue training for one more epoch
command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
run_cli(command_train)
shutil.rmtree(continue_path)

View File

@ -2,7 +2,7 @@ import torch
from TTS.tts.layers.feed_forward.decoder import Decoder
from TTS.tts.layers.feed_forward.encoder import Encoder
from TTS.tts.utils.data import sequence_mask
from TTS.tts.utils.helpers import sequence_mask
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

View File

@ -0,0 +1,147 @@
import torch as T
from TTS.tts.models.forward_tts import ForwardTTS, ForwardTTSArgs
from TTS.tts.utils.helpers import sequence_mask
# pylint: disable=unused-variable
def expand_encoder_outputs_test():
model = ForwardTTS(ForwardTTSArgs(num_chars=10))
inputs = T.rand(2, 5, 57)
durations = T.randint(1, 4, (2, 57))
x_mask = T.ones(2, 1, 57)
y_mask = T.ones(2, 1, durations.sum(1).max())
expanded, _ = model.expand_encoder_outputs(inputs, durations, x_mask, y_mask)
for b in range(durations.shape[0]):
index = 0
for idx, dur in enumerate(durations[b]):
diff = (
expanded[b, :, index : index + dur.item()]
- inputs[b, :, idx].repeat(dur.item()).view(expanded[b, :, index : index + dur.item()].shape)
).sum()
assert abs(diff) < 1e-6, diff
index += dur
def model_input_output_test():
"""Assert the output shapes of the model in different modes"""
# VANILLA MODEL
model = ForwardTTS(ForwardTTSArgs(num_chars=10, use_pitch=False, use_aligner=False))
x = T.randint(0, 10, (2, 21))
x_lengths = T.randint(10, 22, (2,))
x_lengths[-1] = 21
x_mask = sequence_mask(x_lengths).unsqueeze(1).long()
durations = T.randint(1, 4, (2, 21))
durations = durations * x_mask.squeeze(1)
y_lengths = durations.sum(1)
y_mask = sequence_mask(y_lengths).unsqueeze(1).long()
outputs = model.forward(x, x_lengths, y_lengths, dr=durations)
assert outputs["model_outputs"].shape == (2, durations.sum(1).max(), 80)
assert outputs["durations_log"].shape == (2, 21)
assert outputs["durations"].shape == (2, 21)
assert outputs["alignments"].shape == (2, durations.sum(1).max(), 21)
assert (outputs["x_mask"] - x_mask).sum() == 0.0
assert (outputs["y_mask"] - y_mask).sum() == 0.0
assert outputs["alignment_soft"] is None
assert outputs["alignment_mas"] is None
assert outputs["alignment_logprob"] is None
assert outputs["o_alignment_dur"] is None
assert outputs["pitch_avg"] is None
assert outputs["pitch_avg_gt"] is None
# USE PITCH
model = ForwardTTS(ForwardTTSArgs(num_chars=10, use_pitch=True, use_aligner=False))
x = T.randint(0, 10, (2, 21))
x_lengths = T.randint(10, 22, (2,))
x_lengths[-1] = 21
x_mask = sequence_mask(x_lengths).unsqueeze(1).long()
durations = T.randint(1, 4, (2, 21))
durations = durations * x_mask.squeeze(1)
y_lengths = durations.sum(1)
y_mask = sequence_mask(y_lengths).unsqueeze(1).long()
pitch = T.rand(2, 1, y_lengths.max())
outputs = model.forward(x, x_lengths, y_lengths, dr=durations, pitch=pitch)
assert outputs["model_outputs"].shape == (2, durations.sum(1).max(), 80)
assert outputs["durations_log"].shape == (2, 21)
assert outputs["durations"].shape == (2, 21)
assert outputs["alignments"].shape == (2, durations.sum(1).max(), 21)
assert (outputs["x_mask"] - x_mask).sum() == 0.0
assert (outputs["y_mask"] - y_mask).sum() == 0.0
assert outputs["pitch_avg"].shape == (2, 1, 21)
assert outputs["pitch_avg_gt"].shape == (2, 1, 21)
assert outputs["alignment_soft"] is None
assert outputs["alignment_mas"] is None
assert outputs["alignment_logprob"] is None
assert outputs["o_alignment_dur"] is None
# USE ALIGNER NETWORK
model = ForwardTTS(ForwardTTSArgs(num_chars=10, use_pitch=False, use_aligner=True))
x = T.randint(0, 10, (2, 21))
x_lengths = T.randint(10, 22, (2,))
x_lengths[-1] = 21
x_mask = sequence_mask(x_lengths).unsqueeze(1).long()
durations = T.randint(1, 4, (2, 21))
durations = durations * x_mask.squeeze(1)
y_lengths = durations.sum(1)
y_mask = sequence_mask(y_lengths).unsqueeze(1).long()
y = T.rand(2, y_lengths.max(), 80)
outputs = model.forward(x, x_lengths, y_lengths, dr=durations, y=y)
assert outputs["model_outputs"].shape == (2, durations.sum(1).max(), 80)
assert outputs["durations_log"].shape == (2, 21)
assert outputs["durations"].shape == (2, 21)
assert outputs["alignments"].shape == (2, durations.sum(1).max(), 21)
assert (outputs["x_mask"] - x_mask).sum() == 0.0
assert (outputs["y_mask"] - y_mask).sum() == 0.0
assert outputs["alignment_soft"].shape == (2, durations.sum(1).max(), 21)
assert outputs["alignment_mas"].shape == (2, durations.sum(1).max(), 21)
assert outputs["alignment_logprob"].shape == (2, 1, durations.sum(1).max(), 21)
assert outputs["o_alignment_dur"].shape == (2, 21)
assert outputs["pitch_avg"] is None
assert outputs["pitch_avg_gt"] is None
# USE ALIGNER NETWORK AND PITCH
model = ForwardTTS(ForwardTTSArgs(num_chars=10, use_pitch=True, use_aligner=True))
x = T.randint(0, 10, (2, 21))
x_lengths = T.randint(10, 22, (2,))
x_lengths[-1] = 21
x_mask = sequence_mask(x_lengths).unsqueeze(1).long()
durations = T.randint(1, 4, (2, 21))
durations = durations * x_mask.squeeze(1)
y_lengths = durations.sum(1)
y_mask = sequence_mask(y_lengths).unsqueeze(1).long()
y = T.rand(2, y_lengths.max(), 80)
pitch = T.rand(2, 1, y_lengths.max())
outputs = model.forward(x, x_lengths, y_lengths, dr=durations, pitch=pitch, y=y)
assert outputs["model_outputs"].shape == (2, durations.sum(1).max(), 80)
assert outputs["durations_log"].shape == (2, 21)
assert outputs["durations"].shape == (2, 21)
assert outputs["alignments"].shape == (2, durations.sum(1).max(), 21)
assert (outputs["x_mask"] - x_mask).sum() == 0.0
assert (outputs["y_mask"] - y_mask).sum() == 0.0
assert outputs["alignment_soft"].shape == (2, durations.sum(1).max(), 21)
assert outputs["alignment_mas"].shape == (2, durations.sum(1).max(), 21)
assert outputs["alignment_logprob"].shape == (2, 1, durations.sum(1).max(), 21)
assert outputs["o_alignment_dur"].shape == (2, 21)
assert outputs["pitch_avg"].shape == (2, 1, 21)
assert outputs["pitch_avg_gt"].shape == (2, 1, 21)

View File

@ -0,0 +1,60 @@
import torch as T
from TTS.tts.utils.helpers import average_over_durations, generate_path, segment, sequence_mask
def average_over_durations_test(): # pylint: disable=no-self-use
pitch = T.rand(1, 1, 128)
durations = T.randint(1, 5, (1, 21))
coeff = 128.0 / durations.sum()
durations = T.floor(durations * coeff)
diff = 128.0 - durations.sum()
durations[0, -1] += diff
durations = durations.long()
pitch_avg = average_over_durations(pitch, durations)
index = 0
for idx, dur in enumerate(durations[0]):
assert abs(pitch_avg[0, 0, idx] - pitch[0, 0, index : index + dur.item()].mean()) < 1e-5
index += dur
def seqeunce_mask_test():
lengths = T.randint(10, 15, (8,))
mask = sequence_mask(lengths)
for i in range(8):
l = lengths[i].item()
assert mask[i, :l].sum() == l
assert mask[i, l:].sum() == 0
def segment_test():
x = T.range(0, 11)
x = x.repeat(8, 1).unsqueeze(1)
segment_ids = T.randint(0, 7, (8,))
segments = segment(x, segment_ids, segment_size=4)
for idx, start_indx in enumerate(segment_ids):
assert x[idx, :, start_indx : start_indx + 4].sum() == segments[idx, :, :].sum()
def generate_path_test():
durations = T.randint(1, 4, (10, 21))
x_length = T.randint(18, 22, (10,))
x_mask = sequence_mask(x_length).unsqueeze(1).long()
durations = durations * x_mask.squeeze(1)
y_length = durations.sum(1)
y_mask = sequence_mask(y_length).unsqueeze(1).long()
attn_mask = (T.unsqueeze(x_mask, -1) * T.unsqueeze(y_mask, 2)).squeeze(1).long()
print(attn_mask.shape)
path = generate_path(durations, attn_mask)
assert path.shape == (10, 21, durations.sum(1).max().item())
for b in range(durations.shape[0]):
current_idx = 0
for t in range(durations.shape[1]):
assert all(path[b, t, current_idx : current_idx + durations[b, t].item()] == 1.0)
assert all(path[b, t, :current_idx] == 0.0)
assert all(path[b, t, current_idx + durations[b, t].item() :] == 0.0)
current_idx += durations[b, t].item()

View File

@ -1,96 +0,0 @@
import torch
from TTS.tts.configs import SpeedySpeechConfig
from TTS.tts.layers.feed_forward.duration_predictor import DurationPredictor
from TTS.tts.models.speedy_speech import SpeedySpeech, SpeedySpeechArgs
from TTS.tts.utils.data import sequence_mask
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
def test_duration_predictor():
input_dummy = torch.rand(8, 128, 27).to(device)
input_lengths = torch.randint(20, 27, (8,)).long().to(device)
input_lengths[-1] = 27
x_mask = torch.unsqueeze(sequence_mask(input_lengths, input_dummy.size(2)), 1).to(device)
layer = DurationPredictor(hidden_channels=128).to(device)
output = layer(input_dummy, x_mask)
assert list(output.shape) == [8, 1, 27]
def test_speedy_speech():
num_chars = 7
B = 8
T_en = 37
T_de = 74
x_dummy = torch.randint(0, 7, (B, T_en)).long().to(device)
x_lengths = torch.randint(31, T_en, (B,)).long().to(device)
x_lengths[-1] = T_en
# set durations. max total duration should be equal to T_de
durations = torch.randint(1, 4, (B, T_en))
durations = durations * (T_de / durations.sum(1)).unsqueeze(1)
durations = durations.to(torch.long).to(device)
max_dur = durations.sum(1).max()
durations[:, 0] += T_de - max_dur if T_de > max_dur else 0
y_lengths = durations.sum(1)
config = SpeedySpeechConfig(model_args=SpeedySpeechArgs(num_chars=num_chars, out_channels=80, hidden_channels=128))
model = SpeedySpeech(config)
if use_cuda:
model.cuda()
# forward pass
outputs = model(x_dummy, x_lengths, y_lengths, durations)
o_de = outputs["model_outputs"]
attn = outputs["alignments"]
o_dr = outputs["durations_log"]
assert list(o_de.shape) == [B, T_de, 80], f"{list(o_de.shape)}"
assert list(attn.shape) == [B, T_de, T_en]
assert list(o_dr.shape) == [B, T_en]
# with speaker embedding
config = SpeedySpeechConfig(
model_args=SpeedySpeechArgs(
num_chars=num_chars, out_channels=80, hidden_channels=128, num_speakers=80, d_vector_dim=256
)
)
model = SpeedySpeech(config).to(device)
model.forward(
x_dummy, x_lengths, y_lengths, durations, aux_input={"d_vectors": torch.randint(0, 10, (B,)).to(device)}
)
o_de = outputs["model_outputs"]
attn = outputs["alignments"]
o_dr = outputs["durations_log"]
assert list(o_de.shape) == [B, T_de, 80], f"{list(o_de.shape)}"
assert list(attn.shape) == [B, T_de, T_en]
assert list(o_dr.shape) == [B, T_en]
# with speaker external embedding
config = SpeedySpeechConfig(
model_args=SpeedySpeechArgs(
num_chars=num_chars,
out_channels=80,
hidden_channels=128,
num_speakers=10,
use_d_vector=True,
d_vector_dim=256,
)
)
model = SpeedySpeech(config).to(device)
model.forward(x_dummy, x_lengths, y_lengths, durations, aux_input={"d_vectors": torch.rand((B, 256)).to(device)})
o_de = outputs["model_outputs"]
attn = outputs["alignments"]
o_dr = outputs["durations_log"]
assert list(o_de.shape) == [B, T_de, 80], f"{list(o_de.shape)}"
assert list(attn.shape) == [B, T_de, T_en]
assert list(o_dr.shape) == [B, T_en]

View File

@ -4,14 +4,12 @@ import shutil
from tests import get_device_id, get_tests_output_path, run_cli
from TTS.tts.configs import SpeedySpeechConfig
from TTS.tts.models.speedy_speech import SpeedySpeechArgs
config_path = os.path.join(get_tests_output_path(), "test_speedy_speech_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs")
config = SpeedySpeechConfig(
model_args=SpeedySpeechArgs(num_chars=50, out_channels=80, hidden_channels=128, num_speakers=0),
batch_size=8,
eval_batch_size=8,
num_loader_workers=0,

View File

@ -38,6 +38,7 @@ class TacotronTFTrainTest(unittest.TestCase):
mel_spec = tf.convert_to_tensor(mel_spec.cpu().numpy())
return chars_seq, chars_seq_lengths, mel_spec, mel_postnet_spec, mel_lengths, stop_targets, speaker_ids
@unittest.skipIf(use_cuda, " [!] Skip Test: TfLite conversion does not work on GPU.")
def test_train_step(self):
"""test forward pass"""
(
@ -70,6 +71,7 @@ class TacotronTFTrainTest(unittest.TestCase):
# inference pass
output = model(chars_seq, training=False)
@unittest.skipIf(use_cuda, " [!] Skip Test: TfLite conversion does not work on GPU.")
def test_forward_attention(
self,
):
@ -103,6 +105,7 @@ class TacotronTFTrainTest(unittest.TestCase):
# inference pass
output = model(chars_seq, training=False)
@unittest.skipIf(use_cuda, " [!] Skip Test: TfLite conversion does not work on GPU.")
def test_tflite_conversion(
self,
): # pylint:disable=no-self-use

View File

@ -4,7 +4,7 @@ import torch as T
from TTS.tts.layers.losses import L1LossMasked, SSIMLoss
from TTS.tts.layers.tacotron.tacotron import CBHG, Decoder, Encoder, Prenet
from TTS.tts.utils.data import sequence_mask
from TTS.tts.utils.helpers import sequence_mask
# pylint: disable=unused-variable

View File

@ -1,9 +1,15 @@
import unittest
import numpy as np
import tensorflow as tf
import torch
from TTS.vocoder.tf.models.melgan_generator import MelganGenerator
use_cuda = torch.cuda.is_available()
@unittest.skipIf(use_cuda, " [!] Skip Test: Loosy TF support.")
def test_melgan_generator():
hop_length = 256
model = MelganGenerator()

View File

@ -1,7 +1,9 @@
import os
import unittest
import soundfile as sf
import tensorflow as tf
import torch
from librosa.core import load
from tests import get_tests_input_path, get_tests_output_path, get_tests_path
@ -9,8 +11,10 @@ from TTS.vocoder.tf.layers.pqmf import PQMF
TESTS_PATH = get_tests_path()
WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav")
use_cuda = torch.cuda.is_available()
@unittest.skipIf(use_cuda, " [!] Skip Test: Loosy TF support.")
def test_pqmf():
w, sr = load(WAV_FILE)