# coqui-tts/TTS/tts/utils/generic_utils.py
# (source listing: 363 lines, 17 KiB, Python)

import importlib
import re
from collections import Counter
import numpy as np
import torch
from TTS.utils.generic_utils import check_argument
def split_dataset(items):
    """Split ``items`` into an evaluation subset and a training subset.

    The evaluation split holds 1% of the samples, capped at 500. ``items`` is
    shuffled in place after reseeding NumPy's global RNG with 0, so the split is
    reproducible (and the global RNG state is changed as a side effect).

    When the data has more than one speaker (speaker taken from ``item[-1]``),
    evaluation samples are drawn at random. A sample is only moved to the
    evaluation split while its speaker still has more than one sample left, so
    every speaker keeps at least one training sample.

    Args:
        items (list): dataset entries whose last element is the speaker name.
            The list is shuffled and, in the multi-speaker case, has the
            evaluation samples deleted from it in place.

    Returns:
        tuple: ``(items_eval, items_train)``.

    Raises:
        AssertionError: if the dataset has fewer than 100 samples.
        ValueError: in the multi-speaker case, if the evaluation split cannot be
            filled without taking the last sample of some speaker.
    """
    speakers = [item[-1] for item in items]
    is_multi_speaker = len(set(speakers)) > 1
    eval_split_size = min(500, int(len(items) * 0.01))
    assert eval_split_size > 0, " [!] You do not have enough samples to train. You need at least 100 samples."
    np.random.seed(0)
    np.random.shuffle(items)
    if is_multi_speaker:
        # Shuffling does not change the multiset of speakers, so the counts
        # computed before the shuffle are still valid.
        speaker_counter = Counter(speakers)
        # Each speaker can give away all but one of its samples. Without this
        # check the rejection-sampling loop below never terminates when too few
        # samples are available (e.g. every speaker has exactly one sample).
        available = sum(count - 1 for count in speaker_counter.values())
        if available < eval_split_size:
            raise ValueError(
                " [!] Not enough samples to build the evaluation split without "
                "removing the last sample of a speaker."
            )
        items_eval = []
        while len(items_eval) < eval_split_size:
            item_idx = np.random.randint(0, len(items))
            speaker_to_be_removed = items[item_idx][-1]
            if speaker_counter[speaker_to_be_removed] > 1:
                items_eval.append(items[item_idx])
                speaker_counter[speaker_to_be_removed] -= 1
                del items[item_idx]
        return items_eval, items
    return items[:eval_split_size], items[eval_split_size:]
# from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1
def sequence_mask(sequence_length, max_len=None):
    """Build a boolean padding mask from a batch of sequence lengths.

    Args:
        sequence_length (torch.Tensor): 1-D tensor holding one length per batch item.
        max_len (int, optional): size of the time axis. Defaults to the largest
            value in ``sequence_length``.

    Returns:
        torch.Tensor: bool tensor of shape ``[B, T_max]`` where ``mask[b, t]`` is
        True exactly when ``t < sequence_length[b]``.
    """
    limit = sequence_length.data.max() if max_len is None else max_len
    positions = torch.arange(limit, dtype=sequence_length.dtype, device=sequence_length.device)
    # broadcast [1, T_max] against [B, 1] -> [B, T_max]
    return positions[None, :] < sequence_length[:, None]
def to_camel(text):
    """Turn a snake_case model name into its CamelCase class name.

    The first character is upper-cased and the rest lower-cased (``str.capitalize``).
    Each underscore not at the very start is dropped and the following letter
    upper-cased. Finally every "Tts" is spelled "TTS", e.g. ``"glow_tts"`` ->
    ``"GlowTTS"``.
    """

    def _upper_after_underscore(match):
        return match.group(1).upper()

    camel = re.sub(r"(?!^)_([a-zA-Z])", _upper_after_underscore, text.capitalize())
    return camel.replace("Tts", "TTS")
def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
    """Instantiate the TTS model selected by ``c.model``.

    The model class is loaded from the module ``TTS.tts.models.<c.model lower-cased>``.
    The class name is ``to_camel(c.model)``.

    Args:
        num_chars (int): size of the input symbol set. One extra symbol is added
            when ``c.add_blank`` is enabled.
        num_speakers (int): number of speakers, forwarded to the Tacotron and
            GlowTTS models.
        c (dict-like): training config, read both through attributes
            (``c.model``) and items (``c["hidden_channels"]``).
        speaker_embedding_dim (int, optional): external speaker-embedding size,
            forwarded to the Tacotron and GlowTTS models.

    Returns:
        The constructed model instance.

    Raises:
        ValueError: if ``c.model`` is not one of the models handled here.
    """
    print(" > Using model: {}".format(c.model))
    model_name = c.model.lower()
    model_module = importlib.import_module("TTS.tts.models." + model_name)
    MyModel = getattr(model_module, to_camel(c.model))
    # ``add_blank`` is added as a number: an enabled (True) flag adds one input symbol.
    total_chars = num_chars + getattr(c, "add_blank", False)
    if model_name in ("tacotron", "tacotron2"):
        # Constructor arguments shared by Tacotron and Tacotron2.
        tacotron_kwargs = {
            "num_chars": total_chars,
            "num_speakers": num_speakers,
            "r": c.r,
            "decoder_output_dim": c.audio["num_mels"],
            "gst": c.use_gst,
            "gst_embedding_dim": c.gst["gst_embedding_dim"],
            "gst_num_heads": c.gst["gst_num_heads"],
            "gst_style_tokens": c.gst["gst_style_tokens"],
            "gst_use_speaker_embedding": c.gst["gst_use_speaker_embedding"],
            "attn_type": c.attention_type,
            "attn_win": c.windowing,
            "attn_norm": c.attention_norm,
            "prenet_type": c.prenet_type,
            "prenet_dropout": c.prenet_dropout,
            "forward_attn": c.use_forward_attn,
            "trans_agent": c.transition_agent,
            "forward_attn_mask": c.forward_attn_mask,
            "location_attn": c.location_attn,
            "attn_K": c.attention_heads,
            "separate_stopnet": c.separate_stopnet,
            "bidirectional_decoder": c.bidirectional_decoder,
            "double_decoder_consistency": c.double_decoder_consistency,
            "ddc_r": c.ddc_r,
            "speaker_embedding_dim": speaker_embedding_dim,
        }
        if model_name == "tacotron":
            # Tacotron's postnet outputs fft_size / 2 + 1 frequency bins instead of
            # mel bands, and it is the only variant that takes ``memory_size``.
            model = MyModel(
                postnet_output_dim=int(c.audio["fft_size"] / 2 + 1),
                memory_size=c.memory_size,
                **tacotron_kwargs,
            )
        else:
            model = MyModel(postnet_output_dim=c.audio["num_mels"], **tacotron_kwargs)
    elif model_name == "glow_tts":
        model = MyModel(
            num_chars=total_chars,
            hidden_channels_enc=c["hidden_channels_encoder"],
            hidden_channels_dec=c["hidden_channels_decoder"],
            hidden_channels_dp=c["hidden_channels_duration_predictor"],
            out_channels=c.audio["num_mels"],
            encoder_type=c.encoder_type,
            encoder_params=c.encoder_params,
            use_encoder_prenet=c["use_encoder_prenet"],
            num_flow_blocks_dec=12,
            kernel_size_dec=5,
            dilation_rate=1,
            num_block_layers=4,
            dropout_p_dec=0.05,
            num_speakers=num_speakers,
            c_in_channels=0,
            num_splits=4,
            num_squeeze=2,
            sigmoid_scale=False,
            mean_only=True,
            external_speaker_embedding_dim=speaker_embedding_dim,
        )
    elif model_name == "speedy_speech":
        model = MyModel(
            num_chars=total_chars,
            out_channels=c.audio["num_mels"],
            hidden_channels=c["hidden_channels"],
            positional_encoding=c["positional_encoding"],
            encoder_type=c["encoder_type"],
            encoder_params=c["encoder_params"],
            decoder_type=c["decoder_type"],
            decoder_params=c["decoder_params"],
            c_in_channels=0,
        )
    elif model_name == "align_tts":
        model = MyModel(
            num_chars=total_chars,
            out_channels=c.audio["num_mels"],
            hidden_channels=c["hidden_channels"],
            hidden_channels_dp=c["hidden_channels_dp"],
            encoder_type=c["encoder_type"],
            encoder_params=c["encoder_params"],
            decoder_type=c["decoder_type"],
            decoder_params=c["decoder_params"],
            c_in_channels=0,
        )
    else:
        # Previously an unhandled model fell through to ``return model`` and
        # raised UnboundLocalError; fail with an explicit message instead.
        raise ValueError(" [!] Model {} is not supported by setup_model.".format(c.model))
    return model
def is_tacotron(c):
    """Return True when the configured model belongs to the Tacotron family.

    The test is case-insensitive and matches any model name containing
    "tacotron", so both "tacotron" and "tacotron2" qualify.
    """
    model_name = c["model"].lower()
    return "tacotron" in model_name
def check_config_tts(c):
    """Validate a TTS training config, field by field, via ``check_argument``.

    Fields passed with ``restricted`` set to a truthy value are required. Model-specific
    fields are only required for the models that use them. For speedy_speech and
    align_tts these are the fields ``setup_model`` reads.

    Args:
        c (dict-like): parsed training config; must support ``keys()`` and item access.

    Raises:
        Whatever ``check_argument`` raises on a missing or invalid field — presumably
        AssertionError; confirm in ``TTS.utils.generic_utils``.
    """
    check_argument(
        "model",
        c,
        enum_list=["tacotron", "tacotron2", "glow_tts", "speedy_speech", "align_tts"],
        restricted=True,
        val_type=str,
    )
    model_name = c["model"].lower()
    check_argument("run_name", c, restricted=True, val_type=str)
    check_argument("run_description", c, val_type=str)

    # AUDIO
    check_argument("audio", c, restricted=True, val_type=dict)
    audio = c["audio"]

    # audio processing parameters
    check_argument("num_mels", audio, restricted=True, val_type=int, min_val=10, max_val=2056)
    check_argument("fft_size", audio, restricted=True, val_type=int, min_val=128, max_val=4058)
    check_argument("sample_rate", audio, restricted=True, val_type=int, min_val=512, max_val=100000)
    check_argument(
        "frame_length_ms",
        audio,
        restricted=True,
        val_type=float,
        min_val=10,
        max_val=1000,
        alternative="win_length",
    )
    check_argument(
        "frame_shift_ms", audio, restricted=True, val_type=float, min_val=1, max_val=1000, alternative="hop_length"
    )
    check_argument("preemphasis", audio, restricted=True, val_type=float, min_val=0, max_val=1)
    check_argument("min_level_db", audio, restricted=True, val_type=int, min_val=-1000, max_val=10)
    check_argument("ref_level_db", audio, restricted=True, val_type=int, min_val=0, max_val=1000)
    check_argument("power", audio, restricted=True, val_type=float, min_val=1, max_val=5)
    check_argument("griffin_lim_iters", audio, restricted=True, val_type=int, min_val=10, max_val=1000)

    # vocabulary parameters - the symbol fields are only required when a custom
    # "characters" block is present in the config.
    check_argument("characters", c, restricted=False, val_type=dict)
    has_characters = "characters" in c.keys()
    characters = c["characters"] if has_characters else {}
    check_argument("pad", characters, restricted=has_characters, val_type=str)
    check_argument("eos", characters, restricted=has_characters, val_type=str)
    check_argument("bos", characters, restricted=has_characters, val_type=str)
    check_argument("characters", characters, restricted=has_characters, val_type=str)
    # membership test first, so a config without "use_phonemes" does not raise KeyError
    use_phonemes = "use_phonemes" in c.keys() and c["use_phonemes"]
    check_argument("phonemes", characters, restricted=has_characters and use_phonemes, val_type=str)
    check_argument("punctuations", characters, restricted=has_characters, val_type=str)

    # normalization parameters
    check_argument("signal_norm", audio, restricted=True, val_type=bool)
    check_argument("symmetric_norm", audio, restricted=True, val_type=bool)
    check_argument("max_norm", audio, restricted=True, val_type=float, min_val=0.1, max_val=1000)
    check_argument("clip_norm", audio, restricted=True, val_type=bool)
    check_argument("mel_fmin", audio, restricted=True, val_type=float, min_val=0.0, max_val=1000)
    check_argument("mel_fmax", audio, restricted=True, val_type=float, min_val=500.0)
    check_argument("spec_gain", audio, restricted=True, val_type=[int, float], min_val=1, max_val=100)
    check_argument("do_trim_silence", audio, restricted=True, val_type=bool)
    check_argument("trim_db", audio, restricted=True, val_type=int)

    # training parameters
    check_argument("batch_size", c, restricted=True, val_type=int, min_val=1)
    check_argument("eval_batch_size", c, restricted=True, val_type=int, min_val=1)
    check_argument("r", c, restricted=True, val_type=int, min_val=1)
    check_argument("gradual_training", c, restricted=False, val_type=list)
    check_argument("mixed_precision", c, restricted=False, val_type=bool)

    # loss parameters
    check_argument("loss_masking", c, restricted=True, val_type=bool)
    if model_name in ["tacotron", "tacotron2"]:
        check_argument("decoder_loss_alpha", c, restricted=True, val_type=float, min_val=0)
        check_argument("postnet_loss_alpha", c, restricted=True, val_type=float, min_val=0)
        check_argument("postnet_diff_spec_alpha", c, restricted=True, val_type=float, min_val=0)
        check_argument("decoder_diff_spec_alpha", c, restricted=True, val_type=float, min_val=0)
        check_argument("decoder_ssim_alpha", c, restricted=True, val_type=float, min_val=0)
        check_argument("postnet_ssim_alpha", c, restricted=True, val_type=float, min_val=0)
        check_argument("ga_alpha", c, restricted=True, val_type=float, min_val=0)
    if model_name in ["speedy_speech", "align_tts"]:
        # Loss weights may be written as ints (e.g. 1) as well as floats, as with spec_gain above.
        check_argument("ssim_alpha", c, restricted=True, val_type=[int, float], min_val=0)
        # NOTE(review): l1/huber weights are only required for speedy_speech; align_tts configs
        # are assumed to use a different set of loss weights - confirm against the loss classes.
        check_argument("l1_alpha", c, restricted=model_name == "speedy_speech", val_type=[int, float], min_val=0)
        check_argument("huber_alpha", c, restricted=model_name == "speedy_speech", val_type=[int, float], min_val=0)

    # validation parameters
    check_argument("run_eval", c, restricted=True, val_type=bool)
    check_argument("test_delay_epochs", c, restricted=True, val_type=int, min_val=0)
    check_argument("test_sentences_file", c, restricted=False, val_type=str)

    # optimizer
    check_argument("noam_schedule", c, restricted=False, val_type=bool)
    check_argument("grad_clip", c, restricted=True, val_type=float, min_val=0.0)
    check_argument("epochs", c, restricted=True, val_type=int, min_val=1)
    check_argument("lr", c, restricted=True, val_type=float, min_val=0)
    check_argument("wd", c, restricted=is_tacotron(c), val_type=float, min_val=0)
    check_argument("warmup_steps", c, restricted=True, val_type=int, min_val=0)
    check_argument("seq_len_norm", c, restricted=is_tacotron(c), val_type=bool)

    # tacotron prenet
    check_argument("memory_size", c, restricted=is_tacotron(c), val_type=int, min_val=-1)
    check_argument("prenet_type", c, restricted=is_tacotron(c), val_type=str, enum_list=["original", "bn"])
    check_argument("prenet_dropout", c, restricted=is_tacotron(c), val_type=bool)

    # attention
    check_argument(
        "attention_type",
        c,
        restricted=is_tacotron(c),
        val_type=str,
        enum_list=["graves", "original", "dynamic_convolution"],
    )
    check_argument("attention_heads", c, restricted=is_tacotron(c), val_type=int)
    check_argument("attention_norm", c, restricted=is_tacotron(c), val_type=str, enum_list=["sigmoid", "softmax"])
    check_argument("windowing", c, restricted=is_tacotron(c), val_type=bool)
    check_argument("use_forward_attn", c, restricted=is_tacotron(c), val_type=bool)
    check_argument("forward_attn_mask", c, restricted=is_tacotron(c), val_type=bool)
    check_argument("transition_agent", c, restricted=is_tacotron(c), val_type=bool)
    check_argument("location_attn", c, restricted=is_tacotron(c), val_type=bool)
    check_argument("bidirectional_decoder", c, restricted=is_tacotron(c), val_type=bool)
    check_argument("double_decoder_consistency", c, restricted=is_tacotron(c), val_type=bool)
    check_argument("ddc_r", c, restricted="double_decoder_consistency" in c.keys(), min_val=1, max_val=7, val_type=int)

    if model_name in ["tacotron", "tacotron2"]:
        # stopnet
        check_argument("stopnet", c, restricted=True, val_type=bool)
        check_argument("separate_stopnet", c, restricted=True, val_type=bool)

    # Model parameters for non-tacotron models: the fields setup_model reads for them.
    if model_name in ["speedy_speech", "align_tts"]:
        check_argument("hidden_channels", c, restricted=True, val_type=int, min_val=1)
        check_argument("encoder_type", c, restricted=True, val_type=str)
        check_argument("encoder_params", c, restricted=True, val_type=dict)
        check_argument("decoder_type", c, restricted=True, val_type=str)
        check_argument("decoder_params", c, restricted=True, val_type=dict)
    if model_name == "speedy_speech":
        check_argument("positional_encoding", c, restricted=True, val_type=bool)
    if model_name == "align_tts":
        check_argument("hidden_channels_dp", c, restricted=True, val_type=int, min_val=1)

    # GlowTTS parameters
    check_argument("encoder_type", c, restricted=not is_tacotron(c), val_type=str)

    # tensorboard
    check_argument("print_step", c, restricted=True, val_type=int, min_val=1)
    check_argument("tb_plot_step", c, restricted=True, val_type=int, min_val=1)
    check_argument("save_step", c, restricted=True, val_type=int, min_val=1)
    check_argument("checkpoint", c, restricted=True, val_type=bool)
    check_argument("tb_model_param_stats", c, restricted=True, val_type=bool)

    # dataloading
    # pylint: disable=import-outside-toplevel
    from TTS.tts.utils.text import cleaners

    check_argument("text_cleaner", c, restricted=True, val_type=str, enum_list=dir(cleaners))
    check_argument("enable_eos_bos_chars", c, restricted=True, val_type=bool)
    check_argument("num_loader_workers", c, restricted=True, val_type=int, min_val=0)
    check_argument("num_val_loader_workers", c, restricted=True, val_type=int, min_val=0)
    check_argument("batch_group_size", c, restricted=True, val_type=int, min_val=0)
    check_argument("min_seq_len", c, restricted=True, val_type=int, min_val=0)
    check_argument("max_seq_len", c, restricted=True, val_type=int, min_val=10)
    check_argument("compute_input_seq_cache", c, restricted=True, val_type=bool)

    # paths
    check_argument("output_path", c, restricted=True, val_type=str)

    # multi-speaker
    check_argument("use_speaker_embedding", c, restricted=True, val_type=bool)
    check_argument("use_external_speaker_embedding_file", c, restricted=c["use_speaker_embedding"], val_type=bool)
    # That flag is optional when speaker embeddings are off, so test membership before reading it.
    use_external_file = (
        "use_external_speaker_embedding_file" in c.keys() and c["use_external_speaker_embedding_file"]
    )
    check_argument("external_speaker_embedding_file", c, restricted=use_external_file, val_type=str)

    # gst - validate the switch before reading it, then the gst block only when enabled
    if model_name in ["tacotron", "tacotron2"]:
        check_argument("use_gst", c, restricted=True, val_type=bool)
        if c["use_gst"]:
            check_argument("gst", c, restricted=True, val_type=dict)
            gst = c["gst"]
            check_argument("gst_style_input", gst, restricted=True, val_type=[str, dict])
            check_argument("gst_embedding_dim", gst, restricted=True, val_type=int, min_val=0, max_val=1000)
            check_argument("gst_use_speaker_embedding", gst, restricted=True, val_type=bool)
            check_argument("gst_num_heads", gst, restricted=True, val_type=int, min_val=2, max_val=10)
            check_argument("gst_style_tokens", gst, restricted=True, val_type=int, min_val=1, max_val=1000)

    # datasets - every entry is checked
    check_argument("datasets", c, restricted=True, val_type=list)
    for dataset_entry in c["datasets"]:
        check_argument("name", dataset_entry, restricted=True, val_type=str)
        check_argument("path", dataset_entry, restricted=True, val_type=str)
        check_argument("meta_file_train", dataset_entry, restricted=True, val_type=[str, list])
        check_argument("meta_file_val", dataset_entry, restricted=True, val_type=str)