mirror of https://github.com/coqui-ai/TTS.git
commit 4f6153965e

config.json
@@ -1,7 +1,7 @@
 {
     "model": "Tacotron2",          // one of the models in models/
-    "run_name": "ljspeech-graves",
-    "run_description": "tacotron2 with graves attention",
+    "run_name": "ljspeech-stft_params",
+    "run_description": "tacotron2 with constant stft parameters",

     // AUDIO PARAMETERS
     "audio":{
@@ -9,8 +9,10 @@
     "num_mels": 80,          // size of the mel spec frame.
     "num_freq": 1025,        // number of stft frequency levels. Size of the linear spectrogram frame.
     "sample_rate": 22050,    // DATASET-RELATED: wav sample-rate. If different than the original data, it is resampled.
-    "frame_length_ms": 50,   // stft window length in ms.
-    "frame_shift_ms": 12.5,  // stft window hop-length in ms.
+    "win_length": 1024,      // stft window length in samples.
+    "hop_length": 256,       // stft window hop-length in samples.
+    "frame_length_ms": null, // stft window length in ms. If null, 'win_length' is used.
+    "frame_shift_ms": null,  // stft window hop-length in ms. If null, 'hop_length' is used.
     "preemphasis": 0.98,     // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no pre-emphasis.
     "min_level_db": -100,    // normalization range
     "ref_level_db": 20,      // reference level db, theoretically 20db is the sound of air.
@@ -19,13 +21,26 @@
     // Normalization parameters
     "signal_norm": true,     // normalize the spec values in range [0, 1]
     "symmetric_norm": true,  // move normalization to range [-1, 1]
-    "max_norm": 4,           // scale normalization to range [-max_norm, max_norm] or [0, max_norm]
+    "max_norm": 4.0,         // scale normalization to range [-max_norm, max_norm] or [0, max_norm]
     "clip_norm": true,       // clip normalized values into the range.
     "mel_fmin": 0.0,         // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!!
     "mel_fmax": 8000.0,      // maximum freq level for mel-spec. Tune for dataset!!
-    "do_trim_silence": true  // enable trimming of silence of audio as you load it. LJSpeech (false), TWEB (false), Nancy (true)
+    "do_trim_silence": true, // enable trimming of silence of audio as you load it. LJSpeech (false), TWEB (false), Nancy (true)
+    "trim_db": 60            // threshold for trimming silence. Set this according to your dataset.
     },

+    // VOCABULARY PARAMETERS
+    // if custom character set is not defined,
+    // default set in symbols.py is used
+    "characters":{
+        "pad": "_",
+        "eos": "~",
+        "bos": "^",
+        "characters": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!'(),-.:;? ",
+        "punctuations":"!'(),-.:;? ",
+        "phonemes":"iyɨʉɯuɪʏʊeøɘəɵɤoɛœɜɞʌɔæɐaɶɑɒᵻʘɓǀɗǃʄǂɠǁʛpbtdʈɖcɟkɡqɢʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðszʃʒʂʐçʝxɣχʁħʕhɦɬɮʋɹɻjɰlɭʎʟˈˌːˑʍwɥʜʢʡɕʑɺɧɚ˞ɫ"
+    },
+
     // DISTRIBUTED TRAINING
     "distributed":{
         "backend": "nccl",
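The new `characters` block lets a model carry its own vocabulary instead of the default set in `symbols.py`; later hunks consume it via `make_symbols(**CONFIG.characters)`. A minimal sketch of that expansion, assuming the conventional ordering (pad/eos/bos first, so the pad symbol lands at index 0 — the real implementation lives in `TTS.utils.text.symbols` and may order entries differently):

```python
# Sketch: expand a config "characters" block into vocabulary lists.
def make_symbols_sketch(characters, phonemes, punctuations, pad, eos, bos):
    symbols = [pad, eos, bos] + list(characters)                   # graphemes
    phoneme_symbols = [pad, eos, bos] + sorted(set(phonemes)) + list(punctuations)
    return symbols, phoneme_symbols

symbols, phoneme_symbols = make_symbols_sketch(
    characters="ABab!', ", phonemes="iyu", punctuations="!', ",
    pad="_", eos="~", bos="^")
assert symbols.index("_") == 0   # pad at index 0, matching the padding_idx=0 changes below
```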
@@ -48,11 +63,12 @@

     // OPTIMIZER
     "noam_schedule": false,        // use noam warmup and lr schedule.
-    "grad_clip": 1,                // upper limit for gradients for clipping.
+    "grad_clip": 1.0,              // upper limit for gradients for clipping.
     "epochs": 1000,                // total number of epochs to train.
     "lr": 0.0001,                  // Initial learning rate. If Noam decay is active, maximum learning rate.
     "wd": 0.000001,                // Weight decay weight.
     "warmup_steps": 4000,          // Noam decay steps to increase the learning rate from 0 to "lr"
+    "seq_len_norm": false,         // Normalize each sample loss by its length to alleviate imbalanced datasets. Use it if your dataset is small or has a skewed distribution of sequence lengths.

     // TACOTRON PRENET
     "memory_size": -1,             // ONLY TACOTRON - size of the memory queue used for storing last decoder predictions for auto-regression. If < 0, the memory queue is disabled and the decoder only uses the last prediction frame.
@@ -61,13 +77,13 @@

     // ATTENTION
     "attention_type": "original",    // 'original' or 'graves'
-    "attention_heads": 5,            // number of attention heads (only for 'graves')
+    "attention_heads": 4,            // number of attention heads (only for 'graves')
     "attention_norm": "sigmoid",     // softmax or sigmoid. Suggested to use softmax for Tacotron2 and sigmoid for Tacotron.
     "windowing": false,              // Enables attention windowing. Used only in eval mode.
     "use_forward_attn": false,       // if it uses forward attention. In general, it aligns faster.
     "forward_attn_mask": false,      // Additional masking forcing monotonicity only in eval mode.
     "transition_agent": false,       // enable/disable transition agent of forward attention.
-    "location_attn": true,           // enable/disable location sensitive attention. It is enabled for TACOTRON by default.
+    "location_attn": false,          // enable/disable location sensitive attention. It is enabled for TACOTRON by default.
     "bidirectional_decoder": false,  // use https://arxiv.org/abs/1907.09006. Use it, if attention does not work well with your dataset.

     // STOPNET
@@ -90,11 +106,10 @@
     "max_seq_len": 150,        // DATASET-RELATED: maximum text length

     // PATHS
-    "output_path": "/data5/rw/pit/keep/",        // DATASET-RELATED: output path for all training outputs.
-    // "output_path": "/media/erogol/data_ssd/Models/runs/",
+    "output_path": "/data4/rw/home/Trainings/",  // DATASET-RELATED: output path for all training outputs.

     // PHONEMES
-    "phoneme_cache_path": "mozilla_us_phonemes",      // phoneme computation is slow, therefore, it caches results in the given folder.
+    "phoneme_cache_path": "mozilla_us_phonemes_2_1",  // phoneme computation is slow, therefore, it caches results in the given folder.
     "use_phonemes": true,         // use phonemes instead of raw characters. It is suggested for better pronunciation.
     "phoneme_language": "en-us",  // depending on your target language, pick one from https://github.com/bootphon/phonemizer#languages
@@ -108,10 +123,9 @@
     [
         {
             "name": "ljspeech",
-            "path": "/data5/ro/shared/data/keithito/LJSpeech-1.1/",
-            // "path": "/home/erogol/Data/LJSpeech-1.1",
-            "meta_file_train": "metadata_train.csv",
-            "meta_file_val": "metadata_val.csv"
+            "path": "/root/LJSpeech-1.1/",
+            "meta_file_train": "metadata.csv",
+            "meta_file_val": null
         }
     ]
@@ -15,6 +15,7 @@ class MyDataset(Dataset):
                  text_cleaner,
                  ap,
                  meta_data,
+                 tp=None,
                  batch_group_size=0,
                  min_seq_len=0,
                  max_seq_len=float("inf"),
@@ -49,6 +50,7 @@ class MyDataset(Dataset):
         self.min_seq_len = min_seq_len
         self.max_seq_len = max_seq_len
         self.ap = ap
+        self.tp = tp
         self.use_phonemes = use_phonemes
         self.phoneme_cache_path = phoneme_cache_path
         self.phoneme_language = phoneme_language
@@ -75,13 +77,13 @@ class MyDataset(Dataset):

     def _generate_and_cache_phoneme_sequence(self, text, cache_path):
         """generate a phoneme sequence from text.

         since the usage is for subsequent caching, we never add bos and
         eos chars here. Instead we add those dynamically later; based on the
         config option."""
         phonemes = phoneme_to_sequence(text, [self.cleaners],
                                        language=self.phoneme_language,
-                                       enable_eos_bos=False)
+                                       enable_eos_bos=False,
+                                       tp=self.tp)
         phonemes = np.asarray(phonemes, dtype=np.int32)
         np.save(cache_path, phonemes)
         return phonemes
@@ -101,7 +103,7 @@ class MyDataset(Dataset):
             phonemes = self._generate_and_cache_phoneme_sequence(text,
                                                                  cache_path)
         if self.enable_eos_bos:
-            phonemes = pad_with_eos_bos(phonemes)
+            phonemes = pad_with_eos_bos(phonemes, tp=self.tp)
         phonemes = np.asarray(phonemes, dtype=np.int32)
         return phonemes

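`pad_with_eos_bos` now receives `tp` so the bos/eos ids come from the custom character set when one is configured. A rough sketch of what the padding itself does (the real helper lives in `TTS.utils.text`; the ids below assume the default table where "~"=1 and "^"=2):

```python
import numpy as np

def pad_with_eos_bos_sketch(phoneme_ids, bos_id=2, eos_id=1):
    # wrap a cached phoneme id sequence with begin/end-of-sentence ids
    return np.concatenate([[bos_id], phoneme_ids, [eos_id]])

seq = np.array([17, 42, 5], dtype=np.int32)
print(pad_with_eos_bos_sketch(seq))   # [ 2 17 42  5  1]
```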
@@ -113,7 +115,7 @@ class MyDataset(Dataset):
             text = self._load_or_generate_phoneme_sequence(wav_file, text)
         else:
             text = np.asarray(
-                text_to_sequence(text, [self.cleaners]), dtype=np.int32)
+                text_to_sequence(text, [self.cleaners], tp=self.tp), dtype=np.int32)

         assert text.size > 0, self.items[idx][1]
         assert wav.size > 0, self.items[idx][1]
@@ -193,7 +195,7 @@ class MyDataset(Dataset):
             mel = [self.ap.melspectrogram(w).astype('float32') for w in wav]
             linear = [self.ap.spectrogram(w).astype('float32') for w in wav]

-            mel_lengths = [m.shape[1] for m in mel]
+            mel_lengths = [m.shape[1] for m in mel]

             # compute 'stop token' targets
             stop_targets = [
@@ -60,22 +60,6 @@ def tweb(root_path, meta_file):
     # return {'text': texts, 'wavs': wavs}


-def mozilla_old(root_path, meta_file):
-    """Normalizes Mozilla meta data files to TTS format"""
-    txt_file = os.path.join(root_path, meta_file)
-    items = []
-    speaker_name = "mozilla_old"
-    with open(txt_file, 'r') as ttf:
-        for line in ttf:
-            cols = line.split('|')
-            batch_no = int(cols[1].strip().split("_")[0])
-            wav_folder = "batch{}".format(batch_no)
-            wav_file = os.path.join(root_path, wav_folder, "wavs_no_processing", cols[1].strip())
-            text = cols[0].strip()
-            items.append([text, wav_file, speaker_name])
-    return items
-
-
 def mozilla(root_path, meta_file):
     """Normalizes Mozilla meta data files to TTS format"""
     txt_file = os.path.join(root_path, meta_file)
@@ -91,6 +75,22 @@ def mozilla(root_path, meta_file):
     return items


+def mozilla_de(root_path, meta_file):
+    """Normalizes Mozilla meta data files to TTS format"""
+    txt_file = os.path.join(root_path, meta_file)
+    items = []
+    speaker_name = "mozilla"
+    with open(txt_file, 'r', encoding="ISO 8859-1") as ttf:
+        for line in ttf:
+            cols = line.strip().split('|')
+            wav_file = cols[0].strip()
+            text = cols[1].strip()
+            folder_name = f"BATCH_{wav_file.split('_')[0]}_FINAL"
+            wav_file = os.path.join(root_path, folder_name, wav_file)
+            items.append([text, wav_file, speaker_name])
+    return items
+
+
 def mailabs(root_path, meta_files=None):
     """Normalizes M-AI-Labs meta data files to TTS format"""
     speaker_regex = re.compile("by_book/(male|female)/(?P<speaker_name>[^/]+)/")
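`mozilla_de` assumes Latin-1 encoded metadata with `wav_file|text` columns and wavs grouped into `BATCH_<n>_FINAL` folders. A tiny illustration of the parsing (the metadata row and paths are made up):

```python
import os

line = "4_0001.wav|Hallo Welt."           # hypothetical metadata row
wav_file, text = (c.strip() for c in line.strip().split('|'))
folder_name = f"BATCH_{wav_file.split('_')[0]}_FINAL"
print(os.path.join("/data/mozilla_de", folder_name, wav_file), "->", text)
# /data/mozilla_de/BATCH_4_FINAL/4_0001.wav -> Hallo Welt.
```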
@@ -110,6 +110,86 @@ class LocationLayer(nn.Module):
         return processed_attention


+class GravesAttention(nn.Module):
+    """ Discretized Graves attention:
+    - https://arxiv.org/abs/1910.10288
+    - https://arxiv.org/pdf/1906.01083.pdf
+    """
+    COEF = 0.3989422917366028  # numpy.sqrt(1/(2*numpy.pi))
+
+    def __init__(self, query_dim, K):
+        super(GravesAttention, self).__init__()
+        self._mask_value = 1e-8
+        self.K = K
+        # self.attention_alignment = 0.05
+        self.eps = 1e-5
+        self.J = None
+        self.N_a = nn.Sequential(
+            nn.Linear(query_dim, query_dim, bias=True),
+            nn.ReLU(),
+            nn.Linear(query_dim, 3*K, bias=True))
+        self.attention_weights = None
+        self.mu_prev = None
+        self.init_layers()
+
+    def init_layers(self):
+        torch.nn.init.constant_(self.N_a[2].bias[(2*self.K):(3*self.K)], 1.)  # bias mean
+        torch.nn.init.constant_(self.N_a[2].bias[self.K:(2*self.K)], 10)  # bias std
+
+    def init_states(self, inputs):
+        if self.J is None or inputs.shape[1]+1 > self.J.shape[-1]:
+            self.J = torch.arange(0, inputs.shape[1]+2).to(inputs.device) + 0.5
+        self.attention_weights = torch.zeros(inputs.shape[0], inputs.shape[1]).to(inputs.device)
+        self.mu_prev = torch.zeros(inputs.shape[0], self.K).to(inputs.device)
+
+    # pylint: disable=R0201
+    # pylint: disable=unused-argument
+    def preprocess_inputs(self, inputs):
+        return None
+
+    def forward(self, query, inputs, processed_inputs, mask):
+        """
+        shapes:
+            query: B x D_attention_rnn
+            inputs: B x T_in x D_encoder
+            processed_inputs: place_holder
+            mask: B x T_in
+        """
+        gbk_t = self.N_a(query)
+        gbk_t = gbk_t.view(gbk_t.size(0), -1, self.K)
+
+        # attention model parameters
+        # each B x K
+        g_t = gbk_t[:, 0, :]
+        b_t = gbk_t[:, 1, :]
+        k_t = gbk_t[:, 2, :]
+
+        # attention GMM parameters
+        sig_t = torch.nn.functional.softplus(b_t) + self.eps
+
+        mu_t = self.mu_prev + torch.nn.functional.softplus(k_t)
+        g_t = torch.softmax(g_t, dim=-1) + self.eps
+
+        j = self.J[:inputs.size(1)+1]
+
+        # attention weights
+        phi_t = g_t.unsqueeze(-1) * (1 / (1 + torch.sigmoid((mu_t.unsqueeze(-1) - j) / sig_t.unsqueeze(-1))))
+
+        # discretize attention weights
+        alpha_t = torch.sum(phi_t, 1)
+        alpha_t = alpha_t[:, 1:] - alpha_t[:, :-1]
+        alpha_t[alpha_t == 0] = 1e-8
+
+        # apply masking
+        if mask is not None:
+            alpha_t.data.masked_fill_(~mask, self._mask_value)
+
+        context = torch.bmm(alpha_t.unsqueeze(1), inputs).squeeze(1)
+        self.attention_weights = alpha_t
+        self.mu_prev = mu_t
+        return context
+
+
 class OriginalAttention(nn.Module):
     """Following the methods proposed here:
     - https://arxiv.org/abs/1712.05884
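The rewritten class discretizes the mixture: it evaluates a monotone CDF-like term at half-integer grid points (`self.J = ... + 0.5`) and differences consecutive values, so each encoder token receives exactly the probability mass falling in its unit interval. A self-contained sketch of that idea using a plain logistic CDF (not the exact squashing used above):

```python
import torch

B, K, T = 1, 2, 6
mu = torch.tensor([[1.8, 3.2]])           # component means, B x K
sig = torch.tensor([[0.6, 0.9]])          # component scales, B x K
g = torch.softmax(torch.zeros(B, K), -1)  # mixture weights

j = torch.arange(0, T + 1).float() + 0.5  # grid points between tokens
cdf = torch.sigmoid((j - mu.unsqueeze(-1)) / sig.unsqueeze(-1))  # B x K x (T+1)
phi = (g.unsqueeze(-1) * cdf).sum(1)      # mixture CDF, B x (T+1)
alpha = phi[:, 1:] - phi[:, :-1]          # per-token attention mass, B x T

assert (alpha >= 0).all()   # monotone CDF -> nonnegative weights
print(alpha.sum())          # < 1 only by the tail mass outside [0, T]
```

Compared with the removed version below, which evaluated the Gaussian PDF directly at integer positions, the discretized form guarantees the weights form a proper partition of the CDF's mass over the input tokens.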
@@ -289,82 +369,6 @@ class OriginalAttention(nn.Module):
         return context


-class GravesAttention(nn.Module):
-    """ Graves attention as described here:
-    - https://arxiv.org/abs/1910.10288
-    """
-    COEF = 0.3989422917366028  # numpy.sqrt(1/(2*numpy.pi))
-
-    def __init__(self, query_dim, K):
-        super(GravesAttention, self).__init__()
-        self._mask_value = 0.0
-        self.K = K
-        # self.attention_alignment = 0.05
-        self.eps = 1e-5
-        self.J = None
-        self.N_a = nn.Sequential(
-            nn.Linear(query_dim, query_dim, bias=True),
-            nn.ReLU(),
-            nn.Linear(query_dim, 3*K, bias=True))
-        self.attention_weights = None
-        self.mu_prev = None
-        self.init_layers()
-
-    def init_layers(self):
-        torch.nn.init.constant_(self.N_a[2].bias[(2*self.K):(3*self.K)], 1.)
-        torch.nn.init.constant_(self.N_a[2].bias[self.K:(2*self.K)], 10)
-
-    def init_states(self, inputs):
-        if self.J is None or inputs.shape[1] > self.J.shape[-1]:
-            self.J = torch.arange(0, inputs.shape[1]).to(inputs.device)
-        self.attention_weights = torch.zeros(inputs.shape[0], inputs.shape[1]).to(inputs.device)
-        self.mu_prev = torch.zeros(inputs.shape[0], self.K).to(inputs.device)
-
-    # pylint: disable=R0201
-    # pylint: disable=unused-argument
-    def preprocess_inputs(self, inputs):
-        return None
-
-    def forward(self, query, inputs, processed_inputs, mask):
-        """
-        shapes:
-            query: B x D_attention_rnn
-            inputs: B x T_in x D_encoder
-            processed_inputs: place_holder
-            mask: B x T_in
-        """
-        gbk_t = self.N_a(query)
-        gbk_t = gbk_t.view(gbk_t.size(0), -1, self.K)
-
-        # attention model parameters
-        # each B x K
-        g_t = gbk_t[:, 0, :]
-        b_t = gbk_t[:, 1, :]
-        k_t = gbk_t[:, 2, :]
-
-        # attention GMM parameters
-        sig_t = torch.nn.functional.softplus(b_t) + self.eps
-
-        mu_t = self.mu_prev + torch.nn.functional.softplus(k_t)
-        g_t = torch.softmax(g_t, dim=-1) / sig_t + self.eps
-
-        # each B x K x T_in
-        j = self.J[:inputs.size(1)]
-
-        # attention weights
-        phi_t = g_t.unsqueeze(-1) * torch.exp(-0.5 * (mu_t.unsqueeze(-1) - j)**2 / (sig_t.unsqueeze(-1)**2))
-        alpha_t = self.COEF * torch.sum(phi_t, 1)
-
-        # apply masking
-        if mask is not None:
-            alpha_t.data.masked_fill_(~mask, self._mask_value)
-
-        context = torch.bmm(alpha_t.unsqueeze(1), inputs).squeeze(1)
-        self.attention_weights = alpha_t
-        self.mu_prev = mu_t
-        return context
-
-
 def init_attn(attn_type, query_dim, embedding_dim, attention_dim,
               location_attention, attention_location_n_filters,
               attention_location_kernel_size, windowing, norm, forward_attn,
@@ -6,6 +6,11 @@ from TTS.utils.generic_utils import sequence_mask


 class L1LossMasked(nn.Module):

+    def __init__(self, seq_len_norm):
+        super(L1LossMasked, self).__init__()
+        self.seq_len_norm = seq_len_norm
+
     def forward(self, x, target, length):
         """
         Args:
@@ -24,14 +29,27 @@ class L1LossMasked(nn.Module):
         target.requires_grad = False
         mask = sequence_mask(
             sequence_length=length, max_len=target.size(1)).unsqueeze(2).float()
-        mask = mask.expand_as(x)
-        loss = functional.l1_loss(
-            x * mask, target * mask, reduction="sum")
-        loss = loss / mask.sum()
+        if self.seq_len_norm:
+            norm_w = mask / mask.sum(dim=1, keepdim=True)
+            out_weights = norm_w.div(target.shape[0] * target.shape[2])
+            mask = mask.expand_as(x)
+            loss = functional.l1_loss(
+                x * mask, target * mask, reduction='none')
+            loss = loss.mul(out_weights.to(loss.device)).sum()
+        else:
+            mask = mask.expand_as(x)
+            loss = functional.l1_loss(
+                x * mask, target * mask, reduction='sum')
+            loss = loss / mask.sum()
         return loss


 class MSELossMasked(nn.Module):

+    def __init__(self, seq_len_norm):
+        super(MSELossMasked, self).__init__()
+        self.seq_len_norm = seq_len_norm
+
     def forward(self, x, target, length):
         """
         Args:
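A quick numeric check of what the `seq_len_norm` branch changes: every sample contributes equally regardless of its length, instead of long sequences dominating the masked mean. Self-contained sketch, with `length_mask` standing in for `sequence_mask`:

```python
import torch
import torch.nn.functional as F

def length_mask(lengths, max_len):
    # stands in for sequence_mask: (B, T, 1), 1.0 inside each sequence
    ar = torch.arange(max_len).unsqueeze(0)
    return (ar < lengths.unsqueeze(1)).unsqueeze(2).float()

B, T, D = 2, 4, 3
target = torch.ones(B, T, D)
x = target.clone()
x[0] = 0.0                        # sample 0 (short, len 2): |error| = 1 per frame
length = torch.tensor([2, 4])
mask = length_mask(length, T)

# seq_len_norm=True: weights sum to 1 over time per sample, then average
norm_w = mask / mask.sum(dim=1, keepdim=True)
out_w = norm_w.div(target.shape[0] * target.shape[2]).expand_as(x)
loss_norm = F.l1_loss(x * mask, target * mask, reduction='none').mul(out_w).sum()

# seq_len_norm=False: masked sum followed by one global mean
m = mask.expand_as(x)
loss_plain = F.l1_loss(x * m, target * m, reduction='sum') / m.sum()

print(float(loss_norm), float(loss_plain))  # 0.5 vs ~0.33: the short sample
# is no longer diluted by the long one
```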
@@ -50,10 +68,18 @@ class MSELossMasked(nn.Module):
         target.requires_grad = False
         mask = sequence_mask(
             sequence_length=length, max_len=target.size(1)).unsqueeze(2).float()
-        mask = mask.expand_as(x)
-        loss = functional.mse_loss(
-            x * mask, target * mask, reduction="sum")
-        loss = loss / mask.sum()
+        if self.seq_len_norm:
+            norm_w = mask / mask.sum(dim=1, keepdim=True)
+            out_weights = norm_w.div(target.shape[0] * target.shape[2])
+            mask = mask.expand_as(x)
+            loss = functional.mse_loss(
+                x * mask, target * mask, reduction='none')
+            loss = loss.mul(out_weights.to(loss.device)).sum()
+        else:
+            mask = mask.expand_as(x)
+            loss = functional.mse_loss(
+                x * mask, target * mask, reduction='sum')
+            loss = loss / mask.sum()
         return loss
@@ -70,3 +96,32 @@ class AttentionEntropyLoss(nn.Module):
         entropy = torch.distributions.Categorical(probs=align).entropy()
         loss = (entropy / np.log(align.shape[1])).mean()
         return loss
+
+
+class BCELossMasked(nn.Module):
+
+    def __init__(self, pos_weight):
+        super(BCELossMasked, self).__init__()
+        self.pos_weight = pos_weight
+
+    def forward(self, x, target, length):
+        """
+        Args:
+            x: A Variable containing a FloatTensor of size
+                (batch, max_len) which contains the
+                unnormalized probability for each class.
+            target: A Variable containing a LongTensor of size
+                (batch, max_len) which contains the index of the true
+                class for each corresponding step.
+            length: A Variable containing a LongTensor of size (batch,)
+                which contains the length of each data in a batch.
+        Returns:
+            loss: An average loss value in range [0, 1] masked by the length.
+        """
+        # mask: (batch, max_len, 1)
+        target.requires_grad = False
+        mask = sequence_mask(sequence_length=length, max_len=target.size(1)).float()
+        loss = functional.binary_cross_entropy_with_logits(
+            x * mask, target * mask, pos_weight=self.pos_weight, reduction='sum')
+        loss = loss / mask.sum()
+        return loss
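`pos_weight` biases the stopnet toward the rare positive ("stop now") frames. A small illustration of the effect with `binary_cross_entropy_with_logits` (values are arbitrary):

```python
import torch
import torch.nn.functional as F

# One positive stop frame among many negatives; pos_weight scales only the
# positive term, so missing the stop frame is penalized harder.
logits = torch.tensor([-2.0, -2.0, -2.0, -1.0])  # model unsure at the stop frame
target = torch.tensor([0.0, 0.0, 0.0, 1.0])

plain = F.binary_cross_entropy_with_logits(logits, target, reduction='sum')
weighted = F.binary_cross_entropy_with_logits(
    logits, target, pos_weight=torch.tensor(10.0), reduction='sum')
print(float(plain), float(weighted))  # weighted >> plain: the positive frame dominates
```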
@@ -64,7 +64,6 @@ class Encoder(nn.Module):
     def forward(self, x, input_lengths):
         x = self.convolutions(x)
         x = x.transpose(1, 2)
-        input_lengths = input_lengths.cpu().numpy()
         x = nn.utils.rnn.pack_padded_sequence(x,
                                               input_lengths,
                                               batch_first=True)
@@ -290,7 +289,7 @@ class Decoder(nn.Module):
             stop_tokens += [stop_token]
             alignments += [alignment]

-            if stop_token > 0.7:
+            if stop_token > 0.7 and t > inputs.shape[0] / 2:
                 break
             if len(outputs) == self.max_decoder_steps:
                 print("   | > Decoder stopped with 'max_decoder_steps")
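The extra `t > inputs.shape[0] / 2` guard keeps a single spuriously confident stop token from ending inference too early: both conditions must now hold before decoding breaks. Illustrative sketch of the guarded check:

```python
min_steps = 20                              # stand-in for inputs.shape[0] / 2
stop_tokens = [0.9, 0.1, 0.8, 0.95]         # hypothetical stopnet outputs per step
for t, stop_token in enumerate(stop_tokens):
    if stop_token > 0.7 and t > min_steps:  # the new, guarded condition
        break                               # a lone early spike no longer stops decoding
```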
@@ -39,7 +39,7 @@ class Tacotron(nn.Module):
         encoder_dim = 512 if num_speakers > 1 else 256
         proj_speaker_dim = 80 if num_speakers > 1 else 0
         # embedding layer
-        self.embedding = nn.Embedding(num_chars, 256)
+        self.embedding = nn.Embedding(num_chars, 256, padding_idx=0)
         self.embedding.weight.data.normal_(0, 0.3)
         # boilerplate model
         self.encoder = Encoder(encoder_dim)
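`padding_idx=0` zeroes the pad embedding at construction and, more importantly, blocks gradient updates to that row (the `normal_` init on the next line overwrites the weights, but the gradient masking persists). Minimal check of the PyTorch semantics:

```python
import torch
import torch.nn as nn

emb = nn.Embedding(10, 4, padding_idx=0)
print(emb.weight[0])                    # zeros at construction

emb(torch.tensor([0, 3])).sum().backward()
print(emb.weight.grad[0].abs().sum())   # tensor(0.): no gradient reaches the pad row
```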
@@ -132,6 +132,7 @@ class Tacotron(nn.Module):
             return decoder_outputs, postnet_outputs, alignments, stop_tokens, decoder_outputs_backward, alignments_backward
         return decoder_outputs, postnet_outputs, alignments, stop_tokens

+    @torch.no_grad()
     def inference(self, characters, speaker_ids=None, style_mel=None):
         inputs = self.embedding(characters)
         self._init_states()
@@ -35,7 +35,7 @@ class Tacotron2(nn.Module):
         encoder_dim = 512 if num_speakers > 1 else 512
         proj_speaker_dim = 80 if num_speakers > 1 else 0
         # embedding layer
-        self.embedding = nn.Embedding(num_chars, 512)
+        self.embedding = nn.Embedding(num_chars, 512, padding_idx=0)
         std = sqrt(2.0 / (num_chars + 512))
         val = sqrt(3.0) * std  # uniform bounds for std
         self.embedding.weight.data.uniform_(-val, val)
@@ -82,6 +82,7 @@ class Tacotron2(nn.Module):
             return decoder_outputs, postnet_outputs, alignments, stop_tokens, decoder_outputs_backward, alignments_backward
         return decoder_outputs, postnet_outputs, alignments, stop_tokens

+    @torch.no_grad()
     def inference(self, text, speaker_ids=None):
         embedded_inputs = self.embedding(text).transpose(1, 2)
         encoder_outputs = self.encoder.inference(embedded_inputs)
@@ -0,0 +1,585 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "This is to test TTS models with benchmark sentences for speech synthesis.\n",
+    "\n",
+    "Before running this script please DON'T FORGET: \n",
+    "- to set file paths.\n",
+    "- to download related model files from TTS and PWGAN.\n",
+    "- download or clone related repos, linked below.\n",
+    "- setup the repositories. ```python setup.py install```\n",
+    "- to checkout right commit versions (given next to the model) of TTS and PWGAN.\n",
+    "- to set the right paths in the cell below.\n",
+    "\n",
+    "Repositories:\n",
+    "- TTS: https://github.com/mozilla/TTS\n",
+    "- PWGAN: https://github.com/erogol/ParallelWaveGAN"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "%load_ext autoreload\n",
+    "%autoreload 2\n",
+    "import os\n",
+    "import sys\n",
+    "import io\n",
+    "import torch \n",
+    "import time\n",
+    "import json\n",
+    "import yaml\n",
+    "import numpy as np\n",
+    "from collections import OrderedDict\n",
+    "import matplotlib.pyplot as plt\n",
+    "plt.rcParams[\"figure.figsize\"] = (16,5)\n",
+    "\n",
+    "import librosa\n",
+    "import librosa.display\n",
+    "\n",
+    "from TTS.models.tacotron import Tacotron \n",
+    "from TTS.layers import *\n",
+    "from TTS.utils.data import *\n",
+    "from TTS.utils.audio import AudioProcessor\n",
+    "from TTS.utils.generic_utils import load_config, setup_model\n",
+    "from TTS.utils.text import text_to_sequence\n",
+    "from TTS.utils.synthesis import synthesis\n",
+    "from TTS.utils.visual import visualize\n",
+    "\n",
+    "import IPython\n",
+    "from IPython.display import Audio\n",
+    "\n",
+    "# you may need to change this depending on your system\n",
+    "os.environ['CUDA_VISIBLE_DEVICES']='1'\n",
+    "%matplotlib inline"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def tts(model, text, CONFIG, use_cuda, ap, use_gl, figures=True):\n",
+    "    t_1 = time.time()\n",
+    "    waveform, alignment, mel_spec, mel_postnet_spec, stop_tokens = synthesis(model, text, CONFIG, use_cuda, ap, speaker_id, False, CONFIG.enable_eos_bos_chars)\n",
+    "    if CONFIG.model == \"Tacotron\" and not use_gl:\n",
+    "        # correct the normalization differences b/w TTS and the Vocoder.\n",
+    "        mel_postnet_spec = ap.out_linear_to_mel(mel_postnet_spec.T).T\n",
+    "        mel_postnet_spec = ap._denormalize(mel_postnet_spec)\n",
+    "#         mel_postnet_spec = np.pad(mel_postnet_spec, pad_width=((2, 2), (0, 0)))\n",
+    "    print(mel_postnet_spec.shape)\n",
+    "    print(\"max- \", mel_postnet_spec.max(), \" -- min- \", mel_postnet_spec.min())\n",
+    "    if not use_gl:\n",
+    "        waveform = vocoder_model.inference(torch.FloatTensor(ap_vocoder._normalize(mel_postnet_spec).T).unsqueeze(0), hop_size=ap_vocoder.hop_length)\n",
+    "#         waveform = waveform / abs(waveform).max() * 0.9\n",
+    "    if use_cuda:\n",
+    "        waveform = waveform.cpu()\n",
+    "    waveform = waveform.numpy()\n",
+    "    rtf = (time.time() - t_1) / (len(waveform) / ap.sample_rate)\n",
+    "    print(waveform.shape)\n",
+    "    print(\" > Run-time: {}\".format(time.time() - t_1))\n",
+    "    print(\" > Real-time factor: {}\".format(rtf))\n",
+    "    if figures: \n",
+    "        visualize(alignment, mel_postnet_spec, stop_tokens, text, ap.hop_length, CONFIG, ap._denormalize(mel_spec)) \n",
+    "    IPython.display.display(Audio(waveform, rate=CONFIG.audio['sample_rate'], normalize=False)) \n",
+    "    os.makedirs(OUT_FOLDER, exist_ok=True)\n",
+    "    file_name = text.replace(\" \", \"_\").replace(\".\",\"\") + \".wav\"\n",
+    "    out_path = os.path.join(OUT_FOLDER, file_name)\n",
+    "    ap.save_wav(waveform, out_path)\n",
+    "    return alignment, mel_postnet_spec, stop_tokens, waveform"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Set constants\n",
+    "ROOT_PATH = '/home/erogol/Models/LJSpeech/ljspeech-bn-December-23-2019_08+34AM-ffea133/'\n",
+    "MODEL_PATH = ROOT_PATH + '/checkpoint_670000.pth.tar'\n",
+    "CONFIG_PATH = ROOT_PATH + '/config.json'\n",
+    "OUT_FOLDER = '/home/erogol/Dropbox/AudioSamples/benchmark_samples/'\n",
+    "CONFIG = load_config(CONFIG_PATH)\n",
+    "VOCODER_MODEL_PATH = \"/home/erogol/Models/LJSpeech/pwgan-ljspeech/checkpoint-400000steps.pkl\"\n",
+    "VOCODER_CONFIG_PATH = \"/home/erogol/Models/LJSpeech/pwgan-ljspeech/config.yml\"\n",
+    "\n",
+    "# load PWGAN config\n",
+    "with open(VOCODER_CONFIG_PATH) as f:\n",
+    "    VOCODER_CONFIG = yaml.load(f, Loader=yaml.Loader)\n",
+    "\n",
+    "# Run FLAGs\n",
+    "use_cuda = False\n",
+    "# Set some config fields manually for testing\n",
+    "CONFIG.windowing = True\n",
+    "CONFIG.use_forward_attn = True \n",
+    "# Set the vocoder\n",
+    "use_gl = False # use GL if True\n",
+    "batched_wavernn = True # use batched wavernn inference if True"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# LOAD TTS MODEL\n",
+    "from TTS.utils.text.symbols import make_symbols, symbols, phonemes\n",
+    "\n",
+    "# multi speaker \n",
+    "if CONFIG.use_speaker_embedding:\n",
+    "    speakers = json.load(open(f\"{ROOT_PATH}/speakers.json\", 'r'))\n",
+    "    speakers_idx_to_id = {v: k for k, v in speakers.items()}\n",
+    "else:\n",
+    "    speakers = []\n",
+    "    speaker_id = None\n",
+    "\n",
+    "# if the vocabulary was passed, replace the default\n",
+    "if 'characters' in CONFIG.keys():\n",
+    "    symbols, phonemes = make_symbols(**CONFIG.characters)\n",
+    "\n",
+    "# load the model\n",
+    "num_chars = len(phonemes) if CONFIG.use_phonemes else len(symbols)\n",
+    "model = setup_model(num_chars, len(speakers), CONFIG)\n",
+    "\n",
+    "# load the audio processor\n",
+    "ap = AudioProcessor(**CONFIG.audio) \n",
+    "\n",
+    "# load model state\n",
+    "cp = torch.load(MODEL_PATH, map_location=torch.device('cpu'))\n",
+    "\n",
+    "# load the model\n",
+    "model.load_state_dict(cp['model'])\n",
+    "if use_cuda:\n",
+    "    model.cuda()\n",
+    "model.eval()\n",
+    "print(cp['step'])\n",
+    "print(cp['r'])\n",
+    "\n",
+    "# set model stepsize\n",
+    "if 'r' in cp:\n",
+    "    model.decoder.set_r(cp['r'])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# LOAD VOCODER MODEL (PWGAN)\n",
+    "if use_gl == False:\n",
+    "    from parallel_wavegan.models import ParallelWaveGANGenerator\n",
+    "    from parallel_wavegan.utils.audio import AudioProcessor as AudioProcessorVocoder\n",
+    "\n",
+    "    vocoder_model = ParallelWaveGANGenerator(**VOCODER_CONFIG[\"generator_params\"])\n",
+    "    vocoder_model.load_state_dict(torch.load(VOCODER_MODEL_PATH, map_location=\"cpu\")[\"model\"][\"generator\"])\n",
+    "    vocoder_model.remove_weight_norm()\n",
+    "    ap_vocoder = AudioProcessorVocoder(**VOCODER_CONFIG['audio']) \n",
+    "    if use_cuda:\n",
+    "        vocoder_model.cuda()\n",
+    "    vocoder_model.eval();"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Comparison with https://mycroft.ai/blog/available-voices/"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "model.eval()\n",
+    "model.decoder.max_decoder_steps = 2000\n",
+    "model.decoder.prenet.eval()\n",
+    "speaker_id = None\n",
+    "sentence = '''A breeding jennet, lusty, young, and proud,'''\n",
+    "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "sentence = \"Bill got in the habit of asking himself “Is that thought true?” and if he wasn’t absolutely certain it was, he just let it go.\"\n",
+    "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### https://espnet.github.io/icassp2020-tts/"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "sentence = \"The Commission also recommends\"\n",
+    "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "sentence = \"As a result of these studies, the planning document submitted by the Secretary of the Treasury to the Bureau of the Budget on August thirty-one.\"\n",
+    "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "sentence = \"The FBI now transmits information on all defectors, a category which would, of course, have included Oswald.\"\n",
+    "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "sentence = \"they seem unduly restrictive in continuing to require some manifestation of animus against a Government official.\"\n",
+    "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "sentence = \"and each agency given clear understanding of the assistance which the Secret Service expects.\"\n",
+    "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Other examples"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "sentence = \"Be a voice, not an echo.\" # 'echo' is not in training set. \n",
+    "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "sentence = \"The human voice is the most perfect instrument of all.\"\n",
+    "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "sentence = \"I'm sorry Dave. I'm afraid I can't do that.\"\n",
+    "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "sentence = \"This cake is great. It's so delicious and moist.\"\n",
+    "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Comparison with https://keithito.github.io/audio-samples/"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "sentence = \"Generative adversarial network or variational auto-encoder.\"\n",
+    "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "sentence = \"Scientists at the CERN laboratory say they have discovered a new particle.\"\n",
+    "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "sentence = \"Here’s a way to measure the acute emotional intelligence that has never gone out of style.\"\n",
+    "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "sentence = \"President Trump met with other leaders at the Group of 20 conference.\"\n",
+    "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "sentence = \"The buses aren't the problem, they actually provide a solution.\"\n",
+    "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Comparison with https://google.github.io/tacotron/publications/tacotron/index.html"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "sentence = \"Generative adversarial network or variational auto-encoder.\"\n",
+    "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "sentence = \"Basilar membrane and otolaryngology are not auto-correlations.\"\n",
+    "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "sentence = \" He has read the whole thing.\"\n",
+    "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "sentence = \"He reads books.\"\n",
+    "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "sentence = \"Thisss isrealy awhsome.\"\n",
+    "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "sentence = \"This is your internet browser, Firefox.\"\n",
+    "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "sentence = \"This is your internet browser Firefox.\"\n",
+    "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "sentence = \"The quick brown fox jumps over the lazy dog.\"\n",
+    "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "sentence = \"Does the quick brown fox jump over the lazy dog?\"\n",
+    "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "sentence = \"Eren, how are you?\"\n",
+    "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Hard Sentences"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "sentence = \"Encouraged, he started with a minute a day.\"\n",
+    "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "sentence = \"His meditation consisted of “body scanning” which involved focusing his mind and energy on each section of the body from head to toe .\"\n",
+    "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "sentence = \"Recent research at Harvard has shown meditating for as little as 8 weeks can actually increase the grey matter in the parts of the brain responsible for emotional regulation and learning . \"\n",
+    "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "sentence = \"If he decided to watch TV he really watched it.\"\n",
+    "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "sentence = \"Often we try to bring about change through sheer effort and we put all of our energy into a new initiative .\"\n",
+    "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# for twb dataset\n",
+    "sentence = \"In our preparation for Easter, God in his providence offers us each year the season of Lent as a sacramental sign of our conversion.\"\n",
+    "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.7.4"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
@@ -65,7 +65,7 @@
     "from TTS.utils.text import text_to_sequence\n",
     "from TTS.utils.synthesis import synthesis\n",
     "from TTS.utils.visual import visualize\n",
-    "from TTS.utils.text.symbols import symbols, phonemes\n",
+    "from TTS.utils.text.symbols import make_symbols, symbols, phonemes\n",
     "\n",
     "import IPython\n",
     "from IPython.display import Audio\n",
@@ -149,6 +149,10 @@
     "    speakers = []\n",
     "    speaker_id = None\n",
     "\n",
+    "# if the vocabulary was passed, replace the default\n",
+    "if 'characters' in CONFIG.keys():\n",
+    "    symbols, phonemes = make_symbols(**CONFIG.characters)\n",
+    "\n",
     "# load the model\n",
     "num_chars = len(phonemes) if CONFIG.use_phonemes else len(symbols)\n",
     "model = setup_model(num_chars, len(speakers), CONFIG)\n",
@@ -37,7 +37,7 @@
     "from TTS.utils.audio import AudioProcessor\n",
     "from TTS.utils.visual import plot_spectrogram\n",
     "from TTS.utils.generic_utils import load_config, setup_model, sequence_mask\n",
-    "from TTS.utils.text.symbols import symbols, phonemes\n",
+    "from TTS.utils.text.symbols import make_symbols, symbols, phonemes\n",
     "\n",
     "%matplotlib inline\n",
     "\n",
@@ -94,6 +94,10 @@
    "metadata": {},
    "outputs": [],
    "source": [
+    "# if the vocabulary was passed, replace the default\n",
+    "if 'characters' in C.keys():\n",
+    "    symbols, phonemes = make_symbols(**C.characters)\n",
+    "\n",
     "# load the model\n",
     "num_chars = len(phonemes) if C.use_phonemes else len(symbols)\n",
     "# TODO: multiple speaker\n",
@@ -116,7 +120,7 @@
     "preprocessor = importlib.import_module('TTS.datasets.preprocess')\n",
     "preprocessor = getattr(preprocessor, DATASET.lower())\n",
     "meta_data = preprocessor(DATA_PATH,METADATA_FILE)\n",
-    "dataset = MyDataset(checkpoint['r'], C.text_cleaner, ap, meta_data, use_phonemes=C.use_phonemes, phoneme_cache_path=C.phoneme_cache_path, enable_eos_bos=C.enable_eos_bos_chars)\n",
+    "dataset = MyDataset(checkpoint['r'], C.text_cleaner, ap, meta_data, tp=C.characters if 'characters' in C.keys() else None, use_phonemes=C.use_phonemes, phoneme_cache_path=C.phoneme_cache_path, enable_eos_bos=C.enable_eos_bos_chars)\n",
     "loader = DataLoader(dataset, batch_size=BATCH_SIZE, num_workers=4, collate_fn=dataset.collate_fn, shuffle=False, drop_last=False)"
    ]
   },
@@ -100,7 +100,7 @@
    "outputs": [],
    "source": [
     "# LOAD TTS MODEL\n",
-    "from TTS.utils.text.symbols import symbols, phonemes\n",
+    "from TTS.utils.text.symbols import make_symbols, symbols, phonemes\n",
     "\n",
     "# multi speaker \n",
     "if CONFIG.use_speaker_embedding:\n",
@@ -110,6 +110,10 @@
     "    speakers = []\n",
     "    speaker_id = None\n",
     "\n",
+    "# if the vocabulary was passed, replace the default\n",
+    "if 'characters' in CONFIG.keys():\n",
+    "    symbols, phonemes = make_symbols(**CONFIG.characters)\n",
+    "\n",
     "# load the model\n",
     "num_chars = len(phonemes) if CONFIG.use_phonemes else len(symbols)\n",
     "model = setup_model(num_chars, len(speakers), CONFIG)\n",
@@ -6,6 +6,10 @@ Instructions below are based on a Ubuntu 18.04 machine, but it should be simple

 #### Development server:

+##### Using server.py
+If you have the environment set already for TTS, then you can directly call ```server.py```.
+
+##### Using .whl
 1. apt-get install -y espeak libsndfile1 python3-venv
 2. python3 -m venv /tmp/venv
 3. source /tmp/venv/bin/activate
@@ -14,30 +14,52 @@ def create_argparser():
     parser.add_argument('--tts_checkpoint', type=str, help='path to TTS checkpoint file')
     parser.add_argument('--tts_config', type=str, help='path to TTS config.json file')
     parser.add_argument('--tts_speakers', type=str, help='path to JSON file containing speaker ids, if speaker ids are used in the model')
-    parser.add_argument('--wavernn_lib_path', type=str, help='path to WaveRNN project folder to be imported. If this is not passed, model uses Griffin-Lim for synthesis.')
-    parser.add_argument('--wavernn_file', type=str, help='path to WaveRNN checkpoint file.')
-    parser.add_argument('--wavernn_config', type=str, help='path to WaveRNN config file.')
+    parser.add_argument('--wavernn_lib_path', type=str, default=None, help='path to WaveRNN project folder to be imported. If this is not passed, model uses Griffin-Lim for synthesis.')
+    parser.add_argument('--wavernn_file', type=str, default=None, help='path to WaveRNN checkpoint file.')
+    parser.add_argument('--wavernn_config', type=str, default=None, help='path to WaveRNN config file.')
     parser.add_argument('--is_wavernn_batched', type=convert_boolean, default=False, help='true to use batched WaveRNN.')
+    parser.add_argument('--pwgan_lib_path', type=str, default=None, help='path to ParallelWaveGAN project folder to be imported. If this is not passed, model uses Griffin-Lim for synthesis.')
+    parser.add_argument('--pwgan_file', type=str, default=None, help='path to ParallelWaveGAN checkpoint file.')
+    parser.add_argument('--pwgan_config', type=str, default=None, help='path to ParallelWaveGAN config file.')
     parser.add_argument('--port', type=int, default=5002, help='port to listen on.')
     parser.add_argument('--use_cuda', type=convert_boolean, default=False, help='true to use CUDA.')
     parser.add_argument('--debug', type=convert_boolean, default=False, help='true to enable Flask debug mode.')
     return parser


-config = None
-synthesizer = None
-
-embedded_model_folder = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'model')
-checkpoint_file = os.path.join(embedded_model_folder, 'checkpoint.pth.tar')
-config_file = os.path.join(embedded_model_folder, 'config.json')
+embedded_models_folder = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'model')

-if os.path.isfile(checkpoint_file) and os.path.isfile(config_file):
-    # Use default config with embedded model files
-    config = create_argparser().parse_args([])
-    config.tts_checkpoint = checkpoint_file
-    config.tts_config = config_file
-    synthesizer = Synthesizer(config)
+embedded_tts_folder = os.path.join(embedded_models_folder, 'tts')
+tts_checkpoint_file = os.path.join(embedded_tts_folder, 'checkpoint.pth.tar')
+tts_config_file = os.path.join(embedded_tts_folder, 'config.json')
+
+embedded_wavernn_folder = os.path.join(embedded_models_folder, 'wavernn')
+wavernn_checkpoint_file = os.path.join(embedded_wavernn_folder, 'checkpoint.pth.tar')
+wavernn_config_file = os.path.join(embedded_wavernn_folder, 'config.json')
+
+embedded_pwgan_folder = os.path.join(embedded_models_folder, 'pwgan')
+pwgan_checkpoint_file = os.path.join(embedded_pwgan_folder, 'checkpoint.pkl')
+pwgan_config_file = os.path.join(embedded_pwgan_folder, 'config.yml')
+
+args = create_argparser().parse_args()
+
+# If these were not specified in the CLI args, use default values with embedded model files
+if not args.tts_checkpoint and os.path.isfile(tts_checkpoint_file):
+    args.tts_checkpoint = tts_checkpoint_file
+if not args.tts_config and os.path.isfile(tts_config_file):
+    args.tts_config = tts_config_file
+if not args.wavernn_file and os.path.isfile(wavernn_checkpoint_file):
+    args.wavernn_file = wavernn_checkpoint_file
+if not args.wavernn_config and os.path.isfile(wavernn_config_file):
+    args.wavernn_config = wavernn_config_file
+if not args.pwgan_file and os.path.isfile(pwgan_checkpoint_file):
+    args.pwgan_file = pwgan_checkpoint_file
+if not args.pwgan_config and os.path.isfile(pwgan_config_file):
+    args.pwgan_config = pwgan_config_file
+
+synthesizer = Synthesizer(args)

 app = Flask(__name__)
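Path resolution is now per-flag: any value given on the command line wins, otherwise the server falls back to files bundled under `model/`. The same pattern condensed into a runnable sketch (paths are illustrative):

```python
import argparse
import os

parser = argparse.ArgumentParser()
parser.add_argument('--tts_checkpoint', type=str, default=None)
args = parser.parse_args([])   # simulate launching with no CLI flags

# per-flag fallback: a CLI value wins, else use the embedded file if present
embedded = os.path.join('model', 'tts', 'checkpoint.pth.tar')
if not args.tts_checkpoint and os.path.isfile(embedded):
    args.tts_checkpoint = embedded
print(args.tts_checkpoint)
```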
@@ -55,11 +77,4 @@ def tts():


 if __name__ == '__main__':
-    args = create_argparser().parse_args()
-
-    # Setup synthesizer from CLI args if they're specified or no embedded model
-    # is present.
-    if not config or not synthesizer or args.tts_checkpoint or args.tts_config:
-        synthesizer = Synthesizer(args)
-
-    app.run(debug=config.debug, host='0.0.0.0', port=config.port)
+    app.run(debug=args.debug, host='0.0.0.0', port=args.port)
@@ -1,17 +1,20 @@
 import io
 import os
+import re
+import sys

 import numpy as np
 import torch
-import sys
+import yaml

 from TTS.utils.audio import AudioProcessor
 from TTS.utils.generic_utils import load_config, setup_model
-from TTS.utils.text import phonemes, symbols
 from TTS.utils.speakers import load_speaker_mapping
 # pylint: disable=unused-wildcard-import
 # pylint: disable=wildcard-import
 from TTS.utils.synthesis import *

-import re
+from TTS.utils.text import make_symbols, phonemes, symbols

 alphabets = r"([A-Za-z])"
 prefixes = r"(Mr|St|Mrs|Ms|Dr)[.]"
 suffixes = r"(Inc|Ltd|Jr|Sr|Co)"
@@ -23,6 +26,7 @@ websites = r"[.](com|net|org|io|gov)"
 class Synthesizer(object):
     def __init__(self, config):
         self.wavernn = None
+        self.pwgan = None
         self.config = config
         self.use_cuda = self.config.use_cuda
         if self.use_cuda:
@@ -30,28 +34,38 @@ class Synthesizer(object):
         self.load_tts(self.config.tts_checkpoint, self.config.tts_config,
                       self.config.use_cuda)
         if self.config.wavernn_lib_path:
-            self.load_wavernn(self.config.wavernn_lib_path, self.config.wavernn_path,
-                              self.config.wavernn_file, self.config.wavernn_config,
-                              self.config.use_cuda)
+            self.load_wavernn(self.config.wavernn_lib_path, self.config.wavernn_file,
+                              self.config.wavernn_config, self.config.use_cuda)
+        if self.config.pwgan_lib_path:
+            self.load_pwgan(self.config.pwgan_lib_path, self.config.pwgan_file,
+                            self.config.pwgan_config, self.config.use_cuda)

     def load_tts(self, tts_checkpoint, tts_config, use_cuda):
+        # pylint: disable=global-statement
+        global symbols, phonemes
+
         print(" > Loading TTS model ...")
         print(" | > model config: ", tts_config)
         print(" | > checkpoint file: ", tts_checkpoint)
+
         self.tts_config = load_config(tts_config)
         self.use_phonemes = self.tts_config.use_phonemes
         self.ap = AudioProcessor(**self.tts_config.audio)
+
+        if 'characters' in self.tts_config.keys():
+            symbols, phonemes = make_symbols(**self.tts_config.characters)
+
         if self.use_phonemes:
             self.input_size = len(phonemes)
         else:
             self.input_size = len(symbols)
-        # load speakers
+        # TODO: fix this for multi-speaker model - load speakers
         if self.config.tts_speakers is not None:
-            self.tts_speakers = load_speaker_mapping(os.path.join(model_path, self.config.tts_speakers))
+            self.tts_speakers = load_speaker_mapping(self.config.tts_speakers)
             num_speakers = len(self.tts_speakers)
         else:
             num_speakers = 0
-        self.tts_model = setup_model(self.input_size, num_speakers=num_speakers, c=self.tts_config)
+        self.tts_model = setup_model(self.input_size, num_speakers=num_speakers, c=self.tts_config)
         # load model state
         cp = torch.load(tts_checkpoint, map_location=torch.device('cpu'))
         # load the model
@@ -63,16 +77,17 @@ class Synthesizer(object):
         if 'r' in cp:
             self.tts_model.decoder.set_r(cp['r'])

-    def load_wavernn(self, lib_path, model_path, model_file, model_config, use_cuda):
+    def load_wavernn(self, lib_path, model_file, model_config, use_cuda):
         # TODO: set a function in wavernn code base for model setup and call it here.
-        sys.path.append(lib_path) # set this if TTS is not installed globally
+        sys.path.append(lib_path) # set this if WaveRNN is not installed globally
+        #pylint: disable=import-outside-toplevel
         from WaveRNN.models.wavernn import Model
-        wavernn_config = os.path.join(model_path, model_config)
-        model_file = os.path.join(model_path, model_file)
         print(" > Loading WaveRNN model ...")
-        print(" | > model config: ", wavernn_config)
+        print(" | > model config: ", model_config)
         print(" | > model file: ", model_file)
-        self.wavernn_config = load_config(wavernn_config)
+        self.wavernn_config = load_config(model_config)
         # This is the default architecture we use for our models.
         # You might need to update it
         self.wavernn = Model(
             rnn_dims=512,
             fc_dims=512,
@@ -80,7 +95,7 @@ class Synthesizer(object):
             mulaw=self.wavernn_config.mulaw,
             pad=self.wavernn_config.pad,
             use_aux_net=self.wavernn_config.use_aux_net,
-            use_upsample_net = self.wavernn_config.use_upsample_net,
+            use_upsample_net=self.wavernn_config.use_upsample_net,
             upsample_factors=self.wavernn_config.upsample_factors,
             feat_dims=80,
             compute_dims=128,
@ -90,19 +105,36 @@ class Synthesizer(object):
|
|||
sample_rate=self.ap.sample_rate,
|
||||
).cuda()
|
||||
|
||||
check = torch.load(model_file)
|
||||
check = torch.load(model_file, map_location="cpu")
|
||||
self.wavernn.load_state_dict(check['model'])
|
||||
if use_cuda:
|
||||
self.wavernn.cuda()
|
||||
self.wavernn.eval()
|
||||
|
||||
def load_pwgan(self, lib_path, model_file, model_config, use_cuda):
|
||||
sys.path.append(lib_path) # set this if ParallelWaveGAN is not installed globally
|
||||
#pylint: disable=import-outside-toplevel
|
||||
from parallel_wavegan.models import ParallelWaveGANGenerator
|
||||
print(" > Loading PWGAN model ...")
|
||||
print(" | > model config: ", model_config)
|
||||
print(" | > model file: ", model_file)
|
||||
with open(model_config) as f:
|
||||
self.pwgan_config = yaml.load(f, Loader=yaml.Loader)
|
||||
self.pwgan = ParallelWaveGANGenerator(**self.pwgan_config["generator_params"])
|
||||
self.pwgan.load_state_dict(torch.load(model_file, map_location="cpu")["model"]["generator"])
|
||||
self.pwgan.remove_weight_norm()
|
||||
if use_cuda:
|
||||
self.pwgan.cuda()
|
||||
self.pwgan.eval()
|
||||
|
||||
def save_wav(self, wav, path):
|
||||
# wav *= 32767 / max(1e-8, np.max(np.abs(wav)))
|
||||
wav = np.array(wav)
|
||||
self.ap.save_wav(wav, path)
|
||||
|
||||
def split_into_sentences(self, text):
|
||||
text = " " + text + " "
|
||||
@staticmethod
|
||||
def split_into_sentences(text):
|
||||
text = " " + text + " <stop>"
|
||||
text = text.replace("\n", " ")
|
||||
text = re.sub(prefixes, "\\1<prd>", text)
|
||||
text = re.sub(websites, "<prd>\\1", text)
|
||||
|
@ -129,15 +161,13 @@ class Synthesizer(object):
|
|||
text = text.replace("<prd>", ".")
|
||||
sentences = text.split("<stop>")
|
||||
sentences = sentences[:-1]
|
||||
sentences = [s.strip() for s in sentences]
|
||||
sentences = list(filter(None, [s.strip() for s in sentences])) # remove empty sentences
|
||||
return sentences
|
||||
|
||||
def tts(self, text):
|
||||
wavs = []
|
||||
sens = self.split_into_sentences(text)
|
||||
print(sens)
|
||||
if not sens:
|
||||
sens = [text+'.']
|
||||
for sen in sens:
|
||||
# preprocess the given text
|
||||
inputs = text_to_seqvec(sen, self.tts_config, self.use_cuda)
|
||||
|
@ -148,9 +178,16 @@ class Synthesizer(object):
|
|||
postnet_output, decoder_output, _ = parse_outputs(
|
||||
postnet_output, decoder_output, alignments)
|
||||
|
||||
if self.wavernn:
|
||||
postnet_output = postnet_output[0].data.cpu().numpy()
|
||||
wav = self.wavernn.generate(torch.FloatTensor(postnet_output.T).unsqueeze(0).cuda(), batched=self.config.is_wavernn_batched, target=11000, overlap=550)
|
||||
if self.pwgan:
|
||||
vocoder_input = torch.FloatTensor(postnet_output.T).unsqueeze(0)
|
||||
if self.use_cuda:
|
||||
vocoder_input.cuda()
|
||||
wav = self.pwgan.inference(vocoder_input, hop_size=self.ap.hop_length)
|
||||
elif self.wavernn:
|
||||
vocoder_input = torch.FloatTensor(postnet_output.T).unsqueeze(0)
|
||||
if self.use_cuda:
|
||||
vocoder_input.cuda()
|
||||
wav = self.wavernn.generate(vocoder_input, batched=self.config.is_wavernn_batched, target=11000, overlap=550)
|
||||
else:
|
||||
wav = inv_spectrogram(postnet_output, self.ap, self.tts_config)
|
||||
# trim silence
|
||||
|
|
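The rewritten tts() path above establishes a vocoder fallback order: ParallelWaveGAN if its lib path is configured, otherwise WaveRNN, otherwise Griffin-Lim. A minimal, self-contained sketch of that selection logic; the class and config dict are illustrative stand-ins, not the repo's Synthesizer:

    class VocoderChooser:
        def __init__(self, config):
            # truthy lib paths enable a vocoder, mirroring the checks above
            self.pwgan = config.get("pwgan_lib_path")
            self.wavernn = config.get("wavernn_lib_path")

        def pick(self):
            if self.pwgan:
                return "pwgan"
            if self.wavernn:
                return "wavernn"
            return "griffin-lim"

    chooser = VocoderChooser({"pwgan_lib_path": None,
                              "wavernn_lib_path": "/opt/WaveRNN"})
    print(chooser.pick())  # -> "wavernn"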
setup.py (7 changed lines)

@@ -61,10 +61,11 @@ package_data = ['server/templates/*']
 if 'bdist_wheel' in unknown_args and args.checkpoint and args.model_config:
     print('Embedding model in wheel file...')
     model_dir = os.path.join('server', 'model')
     os.makedirs(model_dir, exist_ok=True)
-    embedded_checkpoint_path = os.path.join(model_dir, 'checkpoint.pth.tar')
+    tts_dir = os.path.join(model_dir, 'tts')
+    os.makedirs(tts_dir, exist_ok=True)
+    embedded_checkpoint_path = os.path.join(tts_dir, 'checkpoint.pth.tar')
     shutil.copy(args.checkpoint, embedded_checkpoint_path)
-    embedded_config_path = os.path.join(model_dir, 'config.json')
+    embedded_config_path = os.path.join(tts_dir, 'config.json')
     shutil.copy(args.model_config, embedded_config_path)
     package_data.extend([embedded_checkpoint_path, embedded_config_path])
@@ -1,3 +1,4 @@
+# pylint: disable=redefined-outer-name, unused-argument
 import os
 import time
 import argparse

@@ -7,7 +8,7 @@ import string
 from TTS.utils.synthesis import synthesis
 from TTS.utils.generic_utils import load_config, setup_model
-from TTS.utils.text.symbols import symbols, phonemes
+from TTS.utils.text.symbols import make_symbols, symbols, phonemes
 from TTS.utils.audio import AudioProcessor

@@ -47,6 +48,8 @@ def tts(model,

 if __name__ == "__main__":
+
+    global symbols, phonemes

     parser = argparse.ArgumentParser()
     parser.add_argument('text', type=str, help='Text to generate speech.')
     parser.add_argument('config_path',

@@ -104,6 +107,10 @@ if __name__ == "__main__":
     # load the audio processor
     ap = AudioProcessor(**C.audio)

+    # if the vocabulary was passed, replace the default
+    if 'characters' in C.keys():
+        symbols, phonemes = make_symbols(**C.characters)
+
     # load speakers
     if args.speakers_json != '':
         speakers = json.load(open(args.speakers_json, 'r'))
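The same vocabulary-override idiom now appears across the CLI script, the server and the tests: if the config carries a `characters` block, rebuild the global symbol tables before sizing the embedding. A condensed sketch, assuming the TTS package is importable; the config dict below is made up:

    from TTS.utils.text.symbols import make_symbols, symbols, phonemes

    C = {"characters": {"pad": "_", "eos": "~", "bos": "^",
                        "characters": "abcdefghijklmnopqrstuvwxyz ",
                        "punctuations": "!'(),-.:;? ",
                        "phonemes": "iyu"}}
    if 'characters' in C.keys():
        symbols, phonemes = make_symbols(**C['characters'])
    num_chars = len(phonemes)  # or len(symbols) when use_phonemes is False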
@@ -3,9 +3,11 @@
     "tts_config":"dummy_model_config.json", // tts config.json file
     "tts_speakers": null,       // json file listing speaker ids. null if no speaker embedding.
     "wavernn_lib_path": null,   // Rootpath to wavernn project folder to be imported. If this is null, model uses GL for speech synthesis.
-    "wavernn_path": null,       // wavernn model root path
     "wavernn_file": null,       // wavernn checkpoint file name
     "wavernn_config": null,     // wavernn config file
+    "pwgan_lib_path": null,
+    "pwgan_file": null,
+    "pwgan_config": null,
     "is_wavernn_batched":true,
     "port": 5002,
     "use_cuda": false,
@@ -19,6 +19,16 @@
     "mel_fmax": 7600,        // maximum freq level for mel-spec. Tune for dataset!!
     "do_trim_silence": false
   },
+
+  "characters":{
+    "pad": "_",
+    "eos": "~",
+    "bos": "^",
+    "characters": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!'(),-.:;? ",
+    "punctuations":"!'(),-.:;? ",
+    "phonemes":"iyɨʉɯuɪʏʊeøɘəɵɤoɛœɜɞʌɔæɐaɶɑɒᵻʘɓǀɗǃʄǂɠǁʛpbtdʈɖcɟkɡqɢʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðszʃʒʂʐçʝxɣχʁħʕhɦɬɮʋɹɻjɰlɭʎʟˈˌːˑʍwɥʜʢʡɕʑɺɧɚ˞ɫ"
+  },
+
   "hidden_size": 128,
   "embedding_size": 256,
   "text_cleaner": "english_cleaners",
@@ -5,13 +5,19 @@ import torch as T

 from TTS.server.synthesizer import Synthesizer
 from TTS.tests import get_tests_input_path, get_tests_output_path
-from TTS.utils.text.symbols import phonemes, symbols
+from TTS.utils.text.symbols import make_symbols, phonemes, symbols
 from TTS.utils.generic_utils import load_config, save_checkpoint, setup_model


 class DemoServerTest(unittest.TestCase):
+    # pylint: disable=R0201
     def _create_random_model(self):
+        # pylint: disable=global-statement
+        global symbols, phonemes
         config = load_config(os.path.join(get_tests_output_path(), 'dummy_model_config.json'))
+        if 'characters' in config.keys():
+            symbols, phonemes = make_symbols(**config.characters)
+
         num_chars = len(phonemes) if config.use_phonemes else len(symbols)
         model = setup_model(num_chars, 0, config)
         output_path = os.path.join(get_tests_output_path())
@@ -131,7 +131,7 @@ class L1LossMaskedTests(unittest.TestCase):
         dummy_target = T.zeros(4, 8, 128).float()
         dummy_length = (T.ones(4) * 8).long()
         output = layer(dummy_input, dummy_target, dummy_length)
-        assert output.item() == 1.0, "1.0 vs {}".format(output.data[0])
+        assert output.item() == 1.0, "1.0 vs {}".format(output.item())

         # test if padded values of input makes any difference
         dummy_input = T.ones(4, 8, 128).float()

@@ -140,7 +140,7 @@ class L1LossMaskedTests(unittest.TestCase):
         mask = (
             (sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
         output = layer(dummy_input + mask, dummy_target, dummy_length)
-        assert output.item() == 1.0, "1.0 vs {}".format(output.data[0])
+        assert output.item() == 1.0, "1.0 vs {}".format(output.item())

         dummy_input = T.rand(4, 8, 128).float()
         dummy_target = dummy_input.detach()

@@ -148,4 +148,37 @@ class L1LossMaskedTests(unittest.TestCase):
         mask = (
             (sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
         output = layer(dummy_input + mask, dummy_target, dummy_length)
-        assert output.item() == 0, "0 vs {}".format(output.data[0])
+        assert output.item() == 0, "0 vs {}".format(output.item())
+
+        # seq_len_norm = True
+        # test input == target
+        layer = L1LossMasked(seq_len_norm=True)
+        dummy_input = T.ones(4, 8, 128).float()
+        dummy_target = T.ones(4, 8, 128).float()
+        dummy_length = (T.ones(4) * 8).long()
+        output = layer(dummy_input, dummy_target, dummy_length)
+        assert output.item() == 0.0
+
+        # test input != target
+        dummy_input = T.ones(4, 8, 128).float()
+        dummy_target = T.zeros(4, 8, 128).float()
+        dummy_length = (T.ones(4) * 8).long()
+        output = layer(dummy_input, dummy_target, dummy_length)
+        assert output.item() == 1.0, "1.0 vs {}".format(output.item())
+
+        # test if padded values of input makes any difference
+        dummy_input = T.ones(4, 8, 128).float()
+        dummy_target = T.zeros(4, 8, 128).float()
+        dummy_length = (T.arange(5, 9)).long()
+        mask = (
+            (sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
+        output = layer(dummy_input + mask, dummy_target, dummy_length)
+        assert abs(output.item() - 1.0) < 1e-5, "1.0 vs {}".format(output.item())
+
+        dummy_input = T.rand(4, 8, 128).float()
+        dummy_target = dummy_input.detach()
+        dummy_length = (T.arange(5, 9)).long()
+        mask = (
+            (sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
+        output = layer(dummy_input + mask, dummy_target, dummy_length)
+        assert output.item() == 0, "0 vs {}".format(output.item())
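A compact sketch of the masked L1 semantics these new assertions pin down: padded frames are excluded through a sequence mask, and with seq_len_norm each sequence contributes equally regardless of its length. This is a stand-in written to satisfy the tests above, not the repo's exact L1LossMasked implementation:

    import torch

    def sequence_mask(lengths, max_len):
        return torch.arange(max_len)[None, :] < lengths[:, None]

    def l1_masked(x, target, lengths, seq_len_norm=False):
        mask = sequence_mask(lengths, target.shape[1]).unsqueeze(2).float()
        loss = torch.nn.functional.l1_loss(x * mask, target * mask, reduction='none')
        if seq_len_norm:
            # normalize per sequence, then average over the batch
            per_seq = loss.sum(dim=(1, 2)) / (lengths.float() * target.shape[2])
            return per_seq.mean()
        return loss.sum() / (mask.sum() * target.shape[2])

    x = torch.ones(4, 8, 128)
    t = torch.zeros(4, 8, 128)
    l = (torch.ones(4) * 8).long()
    print(l1_masked(x, t, l))                     # tensor(1.)
    print(l1_masked(x, t, l, seq_len_norm=True))  # tensor(1.)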
@@ -37,7 +37,8 @@ class TestTTSDataset(unittest.TestCase):
             r,
             c.text_cleaner,
             ap=self.ap,
             meta_data=items,
+            tp=c.characters if 'characters' in c.keys() else None,
             batch_group_size=bgs,
             min_seq_len=c.min_seq_len,
             max_seq_len=float("inf"),

@@ -137,9 +138,7 @@ class TestTTSDataset(unittest.TestCase):
             # NOTE: Below needs to check == 0 but due to an unknown reason
             # there is a slight difference between two matrices.
             # TODO: Check this assert cond more in detail.
-            assert abs((abs(mel.T)
-                        - abs(mel_dl)
-                        ).sum()) < 1e-5, (abs(mel.T) - abs(mel_dl)).sum()
+            assert abs(mel.T - mel_dl).max() < 1e-5, abs(mel.T - mel_dl).max()

             # check mel-spec correctness
             mel_spec = mel_input[0].cpu().numpy()
@@ -11,7 +11,7 @@ source /tmp/venv/bin/activate
 pip install --quiet --upgrade pip setuptools wheel

 rm -f dist/*.whl
-python setup.py bdist_wheel --checkpoint tests/outputs/checkpoint_10.pth.tar --model_config tests/outputs/dummy_model_config.json
+python setup.py --quiet bdist_wheel --checkpoint tests/outputs/checkpoint_10.pth.tar --model_config tests/outputs/dummy_model_config.json
 pip install --quiet dist/TTS*.whl

 python -m TTS.server.server &
@@ -1,7 +1,14 @@
+import os
+# pylint: disable=unused-wildcard-import
+# pylint: disable=wildcard-import
+# pylint: disable=unused-import
 import unittest
 import torch as T

 from TTS.utils.text import *
+from TTS.tests import get_tests_path
+from TTS.utils.generic_utils import load_config
+
+TESTS_PATH = get_tests_path()
+conf = load_config(os.path.join(TESTS_PATH, 'test_config.json'))

 def test_phoneme_to_sequence():
     text = "Recent research at Harvard has shown meditating for as little as 8 weeks can actually increase, the grey matter in the parts of the brain responsible for emotional regulation and learning!"

@@ -9,67 +16,80 @@ def test_phoneme_to_sequence():
     lang = "en-us"
     sequence = phoneme_to_sequence(text, text_cleaner, lang)
     text_hat = sequence_to_phoneme(sequence)
+    sequence_with_params = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters)
+    text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters)
     gt = "ɹiːsənt ɹɪsɜːtʃ æt hɑːɹvɚd hɐz ʃoʊn mɛdᵻteɪɾɪŋ fɔːɹ æz lɪɾəl æz eɪt wiːks kæn æktʃuːəli ɪnkɹiːs, ðə ɡɹeɪ mæɾɚɹ ɪnðə pɑːɹts ʌvðə bɹeɪn ɹɪspɑːnsəbəl fɔːɹ ɪmoʊʃənəl ɹɛɡjuːleɪʃən ænd lɜːnɪŋ!"
-    assert text_hat == gt
+    assert text_hat == text_hat_with_params == gt

     # multiple punctuations
     text = "Be a voice, not an! echo?"
     sequence = phoneme_to_sequence(text, text_cleaner, lang)
     text_hat = sequence_to_phoneme(sequence)
+    sequence_with_params = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters)
+    text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters)
     gt = "biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ?"
     print(text_hat)
     print(len(sequence))
-    assert text_hat == gt
+    assert text_hat == text_hat_with_params == gt

     # not ending with punctuation
     text = "Be a voice, not an! echo"
     sequence = phoneme_to_sequence(text, text_cleaner, lang)
     text_hat = sequence_to_phoneme(sequence)
+    sequence_with_params = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters)
+    text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters)
     gt = "biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ"
     print(text_hat)
     print(len(sequence))
-    assert text_hat == gt
+    assert text_hat == text_hat_with_params == gt

     # original
     text = "Be a voice, not an echo!"
     sequence = phoneme_to_sequence(text, text_cleaner, lang)
     text_hat = sequence_to_phoneme(sequence)
+    sequence_with_params = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters)
+    text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters)
     gt = "biː ɐ vɔɪs, nɑːt ɐn ɛkoʊ!"
     print(text_hat)
     print(len(sequence))
-    assert text_hat == gt
+    assert text_hat == text_hat_with_params == gt

     # extra space after the sentence
     text = "Be a voice, not an! echo. "
     sequence = phoneme_to_sequence(text, text_cleaner, lang)
     text_hat = sequence_to_phoneme(sequence)
+    sequence_with_params = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters)
+    text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters)
    gt = "biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ."
     print(text_hat)
     print(len(sequence))
-    assert text_hat == gt
+    assert text_hat == text_hat_with_params == gt

     # extra space after the sentence
     text = "Be a voice, not an! echo. "
     sequence = phoneme_to_sequence(text, text_cleaner, lang, True)
     text_hat = sequence_to_phoneme(sequence)
+    sequence_with_params = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters)
+    text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters)
     gt = "^biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ.~"
     print(text_hat)
     print(len(sequence))
-    assert text_hat == gt
+    assert text_hat == text_hat_with_params == gt

     # padding char
     text = "_Be a _voice, not an! echo_"
     sequence = phoneme_to_sequence(text, text_cleaner, lang)
     text_hat = sequence_to_phoneme(sequence)
+    sequence_with_params = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters)
+    text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters)
     gt = "biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ"
     print(text_hat)
     print(len(sequence))
-    assert text_hat == gt
+    assert text_hat == text_hat_with_params == gt

 def test_text2phone():
     text = "Recent research at Harvard has shown meditating for as little as 8 weeks can actually increase, the grey matter in the parts of the brain responsible for emotional regulation and learning!"
-    gt = "ɹ|iː|s|ə|n|t| |ɹ|ɪ|s|ɜː|tʃ| |æ|t| |h|ɑːɹ|v|ɚ|d| |h|ɐ|z| |ʃ|oʊ|n| |m|ɛ|d|ᵻ|t|eɪ|ɾ|ɪ|ŋ| |f|ɔː|ɹ| |æ|z| |l|ɪ|ɾ|əl| |æ|z| |eɪ|t| |w|iː|k|s| |k|æ|n| |æ|k|tʃ|uː|əl|i|| |ɪ|n|k|ɹ|iː|s|,| |ð|ə| |ɡ|ɹ|eɪ| |m|æ|ɾ|ɚ|ɹ| |ɪ|n|ð|ə| |p|ɑːɹ|t|s| |ʌ|v|ð|ə| |b|ɹ|eɪ|n| |ɹ|ɪ|s|p|ɑː|n|s|ə|b|əl| |f|ɔː|ɹ| |ɪ|m|oʊ|ʃ|ə|n|əl| |ɹ|ɛ|ɡ|j|uː|l|eɪ|ʃ|ə|n||| |æ|n|d| |l|ɜː|n|ɪ|ŋ|!"
+    gt = "ɹ|iː|s|ə|n|t| |ɹ|ɪ|s|ɜː|tʃ| |æ|t| |h|ɑːɹ|v|ɚ|d| |h|ɐ|z| |ʃ|oʊ|n| |m|ɛ|d|ᵻ|t|eɪ|ɾ|ɪ|ŋ| |f|ɔː|ɹ| |æ|z| |l|ɪ|ɾ|əl| |æ|z| |eɪ|t| |w|iː|k|s| |k|æ|n| |æ|k|tʃ|uː|əl|i| |ɪ|n|k|ɹ|iː|s|,| |ð|ə| |ɡ|ɹ|eɪ| |m|æ|ɾ|ɚ|ɹ| |ɪ|n|ð|ə| |p|ɑːɹ|t|s| |ʌ|v|ð|ə| |b|ɹ|eɪ|n| |ɹ|ɪ|s|p|ɑː|n|s|ə|b|əl| |f|ɔː|ɹ| |ɪ|m|oʊ|ʃ|ə|n|əl| |ɹ|ɛ|ɡ|j|uː|l|eɪ|ʃ|ə|n| |æ|n|d| |l|ɜː|n|ɪ|ŋ|!"
     lang = "en-us"
-    phonemes = text2phone(text, lang)
-    assert gt == phonemes
+    ph = text2phone(text, lang)
+    assert gt == ph, f"\n{phonemes} \n vs \n{gt}"
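What the with_params additions assert, boiled down: decoding with the default tables and with the test config's `characters` block must agree, which holds because that block mirrors the default vocabulary. The snippet below needs the TTS package plus an espeak backend, and it assumes phoneme_cleaners is the cleaner in use, so treat it as illustrative:

    import os
    from TTS.tests import get_tests_path
    from TTS.utils.generic_utils import load_config
    from TTS.utils.text import phoneme_to_sequence, sequence_to_phoneme

    conf = load_config(os.path.join(get_tests_path(), 'test_config.json'))
    seq = phoneme_to_sequence("Be a voice, not an echo!", ["phoneme_cleaners"], "en-us")
    assert sequence_to_phoneme(seq) == sequence_to_phoneme(seq, tp=conf.characters)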
train.py (252 changed lines)

@@ -13,19 +13,19 @@ from torch.utils.data import DataLoader
 from TTS.datasets.TTSDataset import MyDataset
 from distribute import (DistributedSampler, apply_gradient_allreduce,
                         init_distributed, reduce_tensor)
-from TTS.layers.losses import L1LossMasked, MSELossMasked
+from TTS.layers.losses import L1LossMasked, MSELossMasked, BCELossMasked
 from TTS.utils.audio import AudioProcessor
 from TTS.utils.generic_utils import (
     NoamLR, check_update, count_parameters, create_experiment_folder,
     get_git_branch, load_config, remove_experiment_folder, save_best_model,
     save_checkpoint, adam_weight_decay, set_init_dict, copy_config_file,
     setup_model, gradual_training_scheduler, KeepAverage,
-    set_weight_decay)
+    set_weight_decay, check_config)
 from TTS.utils.logger import Logger
 from TTS.utils.speakers import load_speaker_mapping, save_speaker_mapping, \
     get_speakers
 from TTS.utils.synthesis import synthesis
-from TTS.utils.text.symbols import phonemes, symbols
+from TTS.utils.text.symbols import make_symbols, phonemes, symbols
 from TTS.utils.visual import plot_alignment, plot_spectrogram
 from TTS.datasets.preprocess import load_meta_data
 from TTS.utils.radam import RAdam

@@ -49,6 +49,7 @@ def setup_loader(ap, r, is_val=False, verbose=False):
         c.text_cleaner,
         meta_data=meta_data_eval if is_val else meta_data_train,
         ap=ap,
+        tp=c.characters if 'characters' in c.keys() else None,
         batch_group_size=0 if is_val else c.batch_group_size *
         c.batch_size,
         min_seq_len=c.min_seq_len,

@@ -167,7 +168,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,

         # loss computation
         stop_loss = criterion_st(stop_tokens,
-                                 stop_targets) if c.stopnet else torch.zeros(1)
+                                 stop_targets, mel_lengths) if c.stopnet else torch.zeros(1)
         if c.loss_masking:
             decoder_loss = criterion(decoder_output, mel_input, mel_lengths)
             if c.model in ["Tacotron", "TacotronGST"]:

@@ -327,6 +328,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
     return keep_avg['avg_postnet_loss'], global_step


+@torch.no_grad()
 def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
     data_loader = setup_loader(ap, model.decoder.r, is_val=True)
     if c.use_speaker_embedding:

@@ -346,125 +348,124 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
             keep_avg.add_values(eval_values_dict)
     print("\n > Validation")

-    with torch.no_grad():
-        if data_loader is not None:
-            for num_iter, data in enumerate(data_loader):
-                start_time = time.time()
-
-                # format data
-                text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, _, _ = format_data(data)
-                assert mel_input.shape[1] % model.decoder.r == 0
-
-                # forward pass model
-                if c.bidirectional_decoder:
-                    decoder_output, postnet_output, alignments, stop_tokens, decoder_backward_output, alignments_backward = model(
-                        text_input, text_lengths, mel_input, speaker_ids=speaker_ids)
-                else:
-                    decoder_output, postnet_output, alignments, stop_tokens = model(
-                        text_input, text_lengths, mel_input, speaker_ids=speaker_ids)
-
-                # loss computation
-                stop_loss = criterion_st(
-                    stop_tokens, stop_targets) if c.stopnet else torch.zeros(1)
-                if c.loss_masking:
-                    decoder_loss = criterion(decoder_output, mel_input,
-                                             mel_lengths)
-                    if c.model in ["Tacotron", "TacotronGST"]:
-                        postnet_loss = criterion(postnet_output, linear_input,
-                                                 mel_lengths)
-                    else:
-                        postnet_loss = criterion(postnet_output, mel_input,
-                                                 mel_lengths)
-                else:
-                    decoder_loss = criterion(decoder_output, mel_input)
-                    if c.model in ["Tacotron", "TacotronGST"]:
-                        postnet_loss = criterion(postnet_output, linear_input)
-                    else:
-                        postnet_loss = criterion(postnet_output, mel_input)
-                loss = decoder_loss + postnet_loss + stop_loss
-
-                # backward decoder loss
-                if c.bidirectional_decoder:
-                    if c.loss_masking:
-                        decoder_backward_loss = criterion(torch.flip(decoder_backward_output, dims=(1, )), mel_input, mel_lengths)
-                    else:
-                        decoder_backward_loss = criterion(torch.flip(decoder_backward_output, dims=(1, )), mel_input)
-                    decoder_c_loss = torch.nn.functional.l1_loss(torch.flip(decoder_backward_output, dims=(1, )), decoder_output)
-                    loss += decoder_backward_loss + decoder_c_loss
-                    keep_avg.update_values({'avg_decoder_b_loss': decoder_backward_loss.item(), 'avg_decoder_c_loss': decoder_c_loss.item()})
-
-                step_time = time.time() - start_time
-                epoch_time += step_time
-
-                # compute alignment score
-                align_score = alignment_diagonal_score(alignments)
-                keep_avg.update_value('avg_align_score', align_score)
-
-                # aggregate losses from processes
-                if num_gpus > 1:
-                    postnet_loss = reduce_tensor(postnet_loss.data, num_gpus)
-                    decoder_loss = reduce_tensor(decoder_loss.data, num_gpus)
-                    if c.stopnet:
-                        stop_loss = reduce_tensor(stop_loss.data, num_gpus)
-
-                keep_avg.update_values({
-                    'avg_postnet_loss':
-                    float(postnet_loss.item()),
-                    'avg_decoder_loss':
-                    float(decoder_loss.item()),
-                    'avg_stop_loss':
-                    float(stop_loss.item()),
-                })
-
-                if num_iter % c.print_step == 0:
-                    print(
-                        " | > TotalLoss: {:.5f} PostnetLoss: {:.5f} - {:.5f} DecoderLoss:{:.5f} - {:.5f} "
-                        "StopLoss: {:.5f} - {:.5f} AlignScore: {:.4f} : {:.4f}"
-                        .format(loss.item(), postnet_loss.item(),
                                keep_avg['avg_postnet_loss'],
-                                decoder_loss.item(),
-                                keep_avg['avg_decoder_loss'], stop_loss.item(),
-                                keep_avg['avg_stop_loss'], align_score,
-                                keep_avg['avg_align_score']),
-                        flush=True)
-
-        if args.rank == 0:
-            # Diagnostic visualizations
-            idx = np.random.randint(mel_input.shape[0])
-            const_spec = postnet_output[idx].data.cpu().numpy()
-            gt_spec = linear_input[idx].data.cpu().numpy() if c.model in [
-                "Tacotron", "TacotronGST"
-            ] else mel_input[idx].data.cpu().numpy()
-            align_img = alignments[idx].data.cpu().numpy()
-
-            eval_figures = {
-                "prediction": plot_spectrogram(const_spec, ap),
-                "ground_truth": plot_spectrogram(gt_spec, ap),
-                "alignment": plot_alignment(align_img)
-            }
-
-            # Sample audio
-            if c.model in ["Tacotron", "TacotronGST"]:
-                eval_audio = ap.inv_spectrogram(const_spec.T)
-            else:
-                eval_audio = ap.inv_mel_spectrogram(const_spec.T)
-            tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio},
-                                     c.audio["sample_rate"])
-
-            # Plot Validation Stats
-            epoch_stats = {
-                "loss_postnet": keep_avg['avg_postnet_loss'],
-                "loss_decoder": keep_avg['avg_decoder_loss'],
-                "stop_loss": keep_avg['avg_stop_loss'],
-                "alignment_score": keep_avg['avg_align_score']
-            }
-
-            if c.bidirectional_decoder:
-                epoch_stats['loss_decoder_backward'] = keep_avg['avg_decoder_b_loss']
-                align_b_img = alignments_backward[idx].data.cpu().numpy()
-                eval_figures['alignment_backward'] = plot_alignment(align_b_img)
-            tb_logger.tb_eval_stats(global_step, epoch_stats)
-            tb_logger.tb_eval_figures(global_step, eval_figures)
+    if data_loader is not None:
+        for num_iter, data in enumerate(data_loader):
+            start_time = time.time()
+
+            # format data
+            text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, _, _ = format_data(data)
+            assert mel_input.shape[1] % model.decoder.r == 0
+
+            # forward pass model
+            if c.bidirectional_decoder:
+                decoder_output, postnet_output, alignments, stop_tokens, decoder_backward_output, alignments_backward = model(
+                    text_input, text_lengths, mel_input, speaker_ids=speaker_ids)
+            else:
+                decoder_output, postnet_output, alignments, stop_tokens = model(
+                    text_input, text_lengths, mel_input, speaker_ids=speaker_ids)
+
+            # loss computation
+            stop_loss = criterion_st(
+                stop_tokens, stop_targets, mel_lengths) if c.stopnet else torch.zeros(1)
+            if c.loss_masking:
+                decoder_loss = criterion(decoder_output, mel_input,
+                                         mel_lengths)
+                if c.model in ["Tacotron", "TacotronGST"]:
+                    postnet_loss = criterion(postnet_output, linear_input,
+                                             mel_lengths)
+                else:
+                    postnet_loss = criterion(postnet_output, mel_input,
+                                             mel_lengths)
+            else:
+                decoder_loss = criterion(decoder_output, mel_input)
+                if c.model in ["Tacotron", "TacotronGST"]:
+                    postnet_loss = criterion(postnet_output, linear_input)
+                else:
+                    postnet_loss = criterion(postnet_output, mel_input)
+            loss = decoder_loss + postnet_loss + stop_loss
+
+            # backward decoder loss
+            if c.bidirectional_decoder:
+                if c.loss_masking:
+                    decoder_backward_loss = criterion(torch.flip(decoder_backward_output, dims=(1, )), mel_input, mel_lengths)
+                else:
+                    decoder_backward_loss = criterion(torch.flip(decoder_backward_output, dims=(1, )), mel_input)
+                decoder_c_loss = torch.nn.functional.l1_loss(torch.flip(decoder_backward_output, dims=(1, )), decoder_output)
+                loss += decoder_backward_loss + decoder_c_loss
+                keep_avg.update_values({'avg_decoder_b_loss': decoder_backward_loss.item(), 'avg_decoder_c_loss': decoder_c_loss.item()})
+
+            step_time = time.time() - start_time
+            epoch_time += step_time
+
+            # compute alignment score
+            align_score = alignment_diagonal_score(alignments)
+            keep_avg.update_value('avg_align_score', align_score)
+
+            # aggregate losses from processes
+            if num_gpus > 1:
+                postnet_loss = reduce_tensor(postnet_loss.data, num_gpus)
+                decoder_loss = reduce_tensor(decoder_loss.data, num_gpus)
+                if c.stopnet:
+                    stop_loss = reduce_tensor(stop_loss.data, num_gpus)
+
+            keep_avg.update_values({
+                'avg_postnet_loss':
+                float(postnet_loss.item()),
+                'avg_decoder_loss':
+                float(decoder_loss.item()),
+                'avg_stop_loss':
+                float(stop_loss.item()),
+            })
+
+            if num_iter % c.print_step == 0:
+                print(
+                    " | > TotalLoss: {:.5f} PostnetLoss: {:.5f} - {:.5f} DecoderLoss:{:.5f} - {:.5f} "
+                    "StopLoss: {:.5f} - {:.5f} AlignScore: {:.4f} : {:.4f}"
+                    .format(loss.item(), postnet_loss.item(),
+                            keep_avg['avg_postnet_loss'],
+                            decoder_loss.item(),
+                            keep_avg['avg_decoder_loss'], stop_loss.item(),
+                            keep_avg['avg_stop_loss'], align_score,
+                            keep_avg['avg_align_score']),
+                    flush=True)
+
+    if args.rank == 0:
+        # Diagnostic visualizations
+        idx = np.random.randint(mel_input.shape[0])
+        const_spec = postnet_output[idx].data.cpu().numpy()
+        gt_spec = linear_input[idx].data.cpu().numpy() if c.model in [
+            "Tacotron", "TacotronGST"
+        ] else mel_input[idx].data.cpu().numpy()
+        align_img = alignments[idx].data.cpu().numpy()
+
+        eval_figures = {
+            "prediction": plot_spectrogram(const_spec, ap),
+            "ground_truth": plot_spectrogram(gt_spec, ap),
+            "alignment": plot_alignment(align_img)
+        }
+
+        # Sample audio
+        if c.model in ["Tacotron", "TacotronGST"]:
+            eval_audio = ap.inv_spectrogram(const_spec.T)
+        else:
+            eval_audio = ap.inv_mel_spectrogram(const_spec.T)
+        tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio},
+                                 c.audio["sample_rate"])
+
+        # Plot Validation Stats
+        epoch_stats = {
+            "loss_postnet": keep_avg['avg_postnet_loss'],
+            "loss_decoder": keep_avg['avg_decoder_loss'],
+            "stop_loss": keep_avg['avg_stop_loss'],
+            "alignment_score": keep_avg['avg_align_score']
+        }
+
+        if c.bidirectional_decoder:
+            epoch_stats['loss_decoder_backward'] = keep_avg['avg_decoder_b_loss']
+            align_b_img = alignments_backward[idx].data.cpu().numpy()
+            eval_figures['alignment_backward'] = plot_alignment(align_b_img)
+        tb_logger.tb_eval_stats(global_step, epoch_stats)
+        tb_logger.tb_eval_figures(global_step, eval_figures)

     if args.rank == 0 and epoch > c.test_delay_epochs:
         if c.test_sentences_file is None:

@@ -493,7 +494,12 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
                     use_cuda,
                     ap,
                     speaker_id=speaker_id,
-                    style_wav=style_wav)
+                    style_wav=style_wav,
+                    truncated=False,
+                    enable_eos_bos_chars=c.enable_eos_bos_chars, #pylint: disable=unused-argument
+                    use_griffin_lim=True,
+                    do_trim_silence=False)

                 file_path = os.path.join(AUDIO_PATH, str(global_step))
                 os.makedirs(file_path, exist_ok=True)
                 file_path = os.path.join(file_path,

@@ -515,9 +521,12 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):

 # FIXME: move args definition/parsing inside of main?
 def main(args):  # pylint: disable=redefined-outer-name
-    global meta_data_train, meta_data_eval
+    # pylint: disable=global-variable-undefined
+    global meta_data_train, meta_data_eval, symbols, phonemes
     # Audio processor
     ap = AudioProcessor(**c.audio)
+    if 'characters' in c.keys():
+        symbols, phonemes = make_symbols(**c.characters)

     # DISTRUBUTED
     if num_gpus > 1:

@@ -561,12 +570,12 @@ def main(args):  # pylint: disable=redefined-outer-name
         optimizer_st = None

     if c.loss_masking:
-        criterion = L1LossMasked() if c.model in ["Tacotron", "TacotronGST"
-                                                  ] else MSELossMasked()
+        criterion = L1LossMasked(c.seq_len_norm) if c.model in ["Tacotron", "TacotronGST"
+                                                                ] else MSELossMasked(c.seq_len_norm)
     else:
         criterion = nn.L1Loss() if c.model in ["Tacotron", "TacotronGST"
                                                ] else nn.MSELoss()
-    criterion_st = nn.BCEWithLogitsLoss(
+    criterion_st = BCELossMasked(
         pos_weight=torch.tensor(10)) if c.stopnet else None

     if args.restore_path:

@@ -687,6 +696,7 @@ if __name__ == '__main__':

     # setup output paths and read configs
     c = load_config(args.config_path)
+    check_config(c)
     _ = os.path.dirname(os.path.realpath(__file__))

     OUT_PATH = args.continue_path
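Why criterion_st now receives mel_lengths: a masked BCE keeps padded decoder steps out of the stop-token loss instead of training on them. A hedged stand-in for BCELossMasked (the repo's class may differ in its reduction and weighting details):

    import torch

    def bce_masked(logits, targets, lengths, pos_weight=10.0):
        # mask out steps beyond each sequence's true length
        mask = (torch.arange(logits.shape[1])[None, :] < lengths[:, None]).float()
        loss = torch.nn.functional.binary_cross_entropy_with_logits(
            logits, targets, pos_weight=torch.tensor(pos_weight), reduction='none')
        return (loss * mask).sum() / mask.sum()

    logits = torch.randn(2, 6)
    stops = (torch.arange(6)[None, :] >= torch.tensor([[3], [5]])).float()
    print(bce_masked(logits, stops, torch.tensor([4, 6])))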
@@ -12,6 +12,8 @@ class AudioProcessor(object):
                  min_level_db=None,
                  frame_shift_ms=None,
                  frame_length_ms=None,
+                 hop_length=None,
+                 win_length=None,
                  ref_level_db=None,
                  num_freq=None,
                  power=None,

@@ -24,6 +26,7 @@ class AudioProcessor(object):
                  clip_norm=True,
                  griffin_lim_iters=None,
                  do_trim_silence=False,
+                 trim_db=60,
                  sound_norm=False,
                  **_):

@@ -46,8 +49,14 @@ class AudioProcessor(object):
         self.max_norm = 1.0 if max_norm is None else float(max_norm)
         self.clip_norm = clip_norm
         self.do_trim_silence = do_trim_silence
+        self.trim_db = trim_db
         self.sound_norm = sound_norm
-        self.n_fft, self.hop_length, self.win_length = self._stft_parameters()
+        if hop_length is None:
+            self.n_fft, self.hop_length, self.win_length = self._stft_parameters()
+        else:
+            self.hop_length = hop_length
+            self.win_length = win_length
+            self.n_fft = (self.num_freq - 1) * 2
         assert min_level_db != 0.0, " [!] min_level_db is 0"
         members = vars(self)
         for key, value in members.items():

@@ -66,12 +75,11 @@ class AudioProcessor(object):
         return np.maximum(1e-10, np.dot(inv_mel_basis, mel_spec))

     def _build_mel_basis(self, ):
-        n_fft = (self.num_freq - 1) * 2
         if self.mel_fmax is not None:
             assert self.mel_fmax <= self.sample_rate // 2
         return librosa.filters.mel(
             self.sample_rate,
-            n_fft,
+            self.n_fft,
             n_mels=self.num_mels,
             fmin=self.mel_fmin,
             fmax=self.mel_fmax)

@@ -197,6 +205,7 @@ class AudioProcessor(object):
             n_fft=self.n_fft,
             hop_length=self.hop_length,
             win_length=self.win_length,
+            pad_mode='constant'
         )

     def _istft(self, y):

@@ -217,7 +226,7 @@ class AudioProcessor(object):
             margin = int(self.sample_rate * 0.01)
             wav = wav[margin:-margin]
         return librosa.effects.trim(
-            wav, top_db=60, frame_length=self.win_length, hop_length=self.hop_length)[0]
+            wav, top_db=self.trim_db, frame_length=self.win_length, hop_length=self.hop_length)[0]

     @staticmethod
     def mulaw_encode(wav, qc):
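The constructor now accepts sample-exact hop_length/win_length and only falls back to the ms-based computation when they are absent. The numbers below use the config values at the top of this commit; the ms-to-samples formula is the conventional one and is an assumption here, since _stft_parameters() itself is not shown in the diff:

    sample_rate, num_freq = 22050, 1025

    # new style: direct sample counts from config.json
    hop_length, win_length = 256, 1024
    n_fft = (num_freq - 1) * 2  # 2048

    # old style: milliseconds converted to samples
    frame_shift_ms, frame_length_ms = 12.5, 50
    hop_from_ms = int(frame_shift_ms / 1000.0 * sample_rate)   # 275
    win_from_ms = int(frame_length_ms / 1000.0 * sample_rate)  # 1102
    print(n_fft, hop_length, win_length, hop_from_ms, win_from_ms)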
@@ -14,7 +14,7 @@ def prepare_data(inputs):


 def _pad_tensor(x, length):
-    _pad = 0
+    _pad = 0.
     assert x.ndim == 2
     x = np.pad(
         x, [[0, 0], [0, length - x.shape[1]]],

@@ -31,7 +31,7 @@ def prepare_tensor(inputs, out_steps):


 def _pad_stop_target(x, length):
-    _pad = 1.
+    _pad = 0.
     assert x.ndim == 1
     return np.pad(
         x, (0, length - x.shape[0]), mode='constant', constant_values=_pad)
@@ -389,3 +389,133 @@ class KeepAverage():
     def update_values(self, value_dict):
         for key, value in value_dict.items():
             self.update_value(key, value)
+
+
+def _check_argument(name, c, enum_list=None, max_val=None, min_val=None, restricted=False, val_type=None, alternative=None):
+    if alternative in c.keys() and c[alternative] is not None:
+        return
+    if restricted:
+        assert name in c.keys(), f' [!] {name} not defined in config.json'
+    if name in c.keys():
+        if max_val:
+            assert c[name] <= max_val, f' [!] {name} is larger than max value {max_val}'
+        if min_val:
+            assert c[name] >= min_val, f' [!] {name} is smaller than min value {min_val}'
+        if enum_list:
+            assert c[name].lower() in enum_list, f' [!] {name} is not a valid value'
+        if val_type:
+            assert isinstance(c[name], val_type) or c[name] is None, f' [!] {name} has wrong type - {type(c[name])} vs {val_type}'
+
+
+def check_config(c):
+    _check_argument('model', c, enum_list=['tacotron', 'tacotron2'], restricted=True, val_type=str)
+    _check_argument('run_name', c, restricted=True, val_type=str)
+    _check_argument('run_description', c, val_type=str)
+
+    # AUDIO
+    _check_argument('audio', c, restricted=True, val_type=dict)
+
+    # audio processing parameters
+    _check_argument('num_mels', c['audio'], restricted=True, val_type=int, min_val=10, max_val=2056)
+    _check_argument('num_freq', c['audio'], restricted=True, val_type=int, min_val=128, max_val=4058)
+    _check_argument('sample_rate', c['audio'], restricted=True, val_type=int, min_val=512, max_val=100000)
+    _check_argument('frame_length_ms', c['audio'], restricted=True, val_type=float, min_val=10, max_val=1000, alternative='win_length')
+    _check_argument('frame_shift_ms', c['audio'], restricted=True, val_type=float, min_val=1, max_val=1000, alternative='hop_length')
+    _check_argument('preemphasis', c['audio'], restricted=True, val_type=float, min_val=0, max_val=1)
+    _check_argument('min_level_db', c['audio'], restricted=True, val_type=int, min_val=-1000, max_val=10)
+    _check_argument('ref_level_db', c['audio'], restricted=True, val_type=int, min_val=0, max_val=1000)
+    _check_argument('power', c['audio'], restricted=True, val_type=float, min_val=1, max_val=5)
+    _check_argument('griffin_lim_iters', c['audio'], restricted=True, val_type=int, min_val=10, max_val=1000)
+
+    # vocabulary parameters
+    _check_argument('characters', c, restricted=False, val_type=dict)
+    _check_argument('pad', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str)
+    _check_argument('eos', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str)
+    _check_argument('bos', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str)
+    _check_argument('characters', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str)
+    _check_argument('phonemes', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str)
+    _check_argument('punctuations', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str)
+
+    # normalization parameters
+    _check_argument('signal_norm', c['audio'], restricted=True, val_type=bool)
+    _check_argument('symmetric_norm', c['audio'], restricted=True, val_type=bool)
+    _check_argument('max_norm', c['audio'], restricted=True, val_type=float, min_val=0.1, max_val=1000)
+    _check_argument('clip_norm', c['audio'], restricted=True, val_type=bool)
+    _check_argument('mel_fmin', c['audio'], restricted=True, val_type=float, min_val=0.0, max_val=1000)
+    _check_argument('mel_fmax', c['audio'], restricted=True, val_type=float, min_val=500.0)
+    _check_argument('do_trim_silence', c['audio'], restricted=True, val_type=bool)
+    _check_argument('trim_db', c['audio'], restricted=True, val_type=int)
+
+    # training parameters
+    _check_argument('batch_size', c, restricted=True, val_type=int, min_val=1)
+    _check_argument('eval_batch_size', c, restricted=True, val_type=int, min_val=1)
+    _check_argument('r', c, restricted=True, val_type=int, min_val=1)
+    _check_argument('gradual_training', c, restricted=False, val_type=list)
+    _check_argument('loss_masking', c, restricted=True, val_type=bool)
+    # _check_argument('grad_accum', c, restricted=True, val_type=int, min_val=1, max_val=100)
+
+    # validation parameters
+    _check_argument('run_eval', c, restricted=True, val_type=bool)
+    _check_argument('test_delay_epochs', c, restricted=True, val_type=int, min_val=0)
+    _check_argument('test_sentences_file', c, restricted=False, val_type=str)
+
+    # optimizer
+    _check_argument('noam_schedule', c, restricted=False, val_type=bool)
+    _check_argument('grad_clip', c, restricted=True, val_type=float, min_val=0.0)
+    _check_argument('epochs', c, restricted=True, val_type=int, min_val=1)
+    _check_argument('lr', c, restricted=True, val_type=float, min_val=0)
+    _check_argument('wd', c, restricted=True, val_type=float, min_val=0)
+    _check_argument('warmup_steps', c, restricted=True, val_type=int, min_val=0)
+    _check_argument('seq_len_norm', c, restricted=True, val_type=bool)
+
+    # tacotron prenet
+    _check_argument('memory_size', c, restricted=True, val_type=int, min_val=-1)
+    _check_argument('prenet_type', c, restricted=True, val_type=str, enum_list=['original', 'bn'])
+    _check_argument('prenet_dropout', c, restricted=True, val_type=bool)
+
+    # attention
+    _check_argument('attention_type', c, restricted=True, val_type=str, enum_list=['graves', 'original'])
+    _check_argument('attention_heads', c, restricted=True, val_type=int)
+    _check_argument('attention_norm', c, restricted=True, val_type=str, enum_list=['sigmoid', 'softmax'])
+    _check_argument('windowing', c, restricted=True, val_type=bool)
+    _check_argument('use_forward_attn', c, restricted=True, val_type=bool)
+    _check_argument('forward_attn_mask', c, restricted=True, val_type=bool)
+    _check_argument('transition_agent', c, restricted=True, val_type=bool)
+    _check_argument('transition_agent', c, restricted=True, val_type=bool)
+    _check_argument('location_attn', c, restricted=True, val_type=bool)
+    _check_argument('bidirectional_decoder', c, restricted=True, val_type=bool)
+
+    # stopnet
+    _check_argument('stopnet', c, restricted=True, val_type=bool)
+    _check_argument('separate_stopnet', c, restricted=True, val_type=bool)
+
+    # tensorboard
+    _check_argument('print_step', c, restricted=True, val_type=int, min_val=1)
+    _check_argument('save_step', c, restricted=True, val_type=int, min_val=1)
+    _check_argument('checkpoint', c, restricted=True, val_type=bool)
+    _check_argument('tb_model_param_stats', c, restricted=True, val_type=bool)
+
+    # dataloading
+    _check_argument('text_cleaner', c, restricted=True, val_type=str, enum_list=['english_cleaners', 'phoneme_cleaners', 'transliteration_cleaners', 'basic_cleaners'])
+    _check_argument('enable_eos_bos_chars', c, restricted=True, val_type=bool)
+    _check_argument('num_loader_workers', c, restricted=True, val_type=int, min_val=0)
+    _check_argument('num_val_loader_workers', c, restricted=True, val_type=int, min_val=0)
+    _check_argument('batch_group_size', c, restricted=True, val_type=int, min_val=0)
+    _check_argument('min_seq_len', c, restricted=True, val_type=int, min_val=0)
+    _check_argument('max_seq_len', c, restricted=True, val_type=int, min_val=10)
+
+    # paths
+    _check_argument('output_path', c, restricted=True, val_type=str)
+
+    # multi-speaker gst
+    _check_argument('use_speaker_embedding', c, restricted=True, val_type=bool)
+    _check_argument('style_wav_for_test', c, restricted=True, val_type=str)
+    _check_argument('use_gst', c, restricted=True, val_type=bool)
+
+    # datasets - checking only the first entry
+    _check_argument('datasets', c, restricted=True, val_type=list)
+    for dataset_entry in c['datasets']:
+        _check_argument('name', dataset_entry, restricted=True, val_type=str)
+        _check_argument('path', dataset_entry, restricted=True, val_type=str)
+        _check_argument('meta_file_train', dataset_entry, restricted=True, val_type=str)
+        _check_argument('meta_file_val', dataset_entry, restricted=True, val_type=str)
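Usage sketch for the validator above; the config fragment is invented. Note how `alternative` lets a ms-based key stay null when its sample-based counterpart is set, matching the frame_shift_ms/hop_length pairing in the audio config at the top of this commit:

    c_audio = {'sample_rate': 22050, 'hop_length': 256, 'frame_shift_ms': None}
    _check_argument('sample_rate', c_audio, restricted=True, val_type=int,
                    min_val=512, max_val=100000)  # passes
    _check_argument('frame_shift_ms', c_audio, restricted=True, val_type=float,
                    min_val=1, max_val=1000, alternative='hop_length')  # returns early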
@@ -9,10 +9,11 @@ def text_to_seqvec(text, CONFIG, use_cuda):
     if CONFIG.use_phonemes:
         seq = np.asarray(
             phoneme_to_sequence(text, text_cleaner, CONFIG.phoneme_language,
-                                CONFIG.enable_eos_bos_chars),
+                                CONFIG.enable_eos_bos_chars,
+                                tp=CONFIG.characters if 'characters' in CONFIG.keys() else None),
             dtype=np.int32)
     else:
-        seq = np.asarray(text_to_sequence(text, text_cleaner), dtype=np.int32)
+        seq = np.asarray(text_to_sequence(text, text_cleaner, tp=CONFIG.characters if 'characters' in CONFIG.keys() else None), dtype=np.int32)
     # torch tensor
     chars_var = torch.from_numpy(seq).unsqueeze(0)
     if use_cuda:

@@ -69,6 +70,24 @@ def id_to_torch(speaker_id):
     return speaker_id


+# TODO: perform GL with pytorch for batching
+def apply_griffin_lim(inputs, input_lens, CONFIG, ap):
+    '''Apply griffin-lim to each sample iterating throught the first dimension.
+    Args:
+        inputs (Tensor or np.Array): Features to be converted by GL. First dimension is the batch size.
+        input_lens (Tensor or np.Array): 1D array of sample lengths.
+        CONFIG (Dict): TTS config.
+        ap (AudioProcessor): TTS audio processor.
+    '''
+    wavs = []
+    for idx, spec in enumerate(inputs):
+        wav_len = (input_lens[idx] * ap.hop_length) - ap.hop_length  # inverse librosa padding
+        wav = inv_spectrogram(spec, ap, CONFIG)
+        # assert len(wav) == wav_len, f" [!] wav lenght: {len(wav)} vs expected: {wav_len}"
+        wavs.append(wav[:wav_len])
+    return wavs
+
+
 def synthesis(model,
               text,
               CONFIG,
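A shape-level sketch of apply_griffin_lim()'s contract, using stubs so it runs standalone: each of the B spectrograms is vocoded independently and then cut to its own length. The identity "vocoder" below stands in for inv_spectrogram; all names and shapes are assumptions for illustration:

    import numpy as np

    def fake_inv_spectrogram(spec, hop):
        return np.zeros(spec.shape[0] * hop)  # one hop of audio per frame

    hop = 256
    inputs = np.zeros((2, 40, 80))   # (batch, frames, mels)
    input_lens = np.array([30, 40])
    wavs = []
    for idx, spec in enumerate(inputs):
        wav_len = (input_lens[idx] * hop) - hop  # inverse librosa padding
        wavs.append(fake_inv_spectrogram(spec, hop)[:wav_len])
    print([len(w) for w in wavs])    # [7424, 9984]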
@@ -1,18 +1,19 @@
 # -*- coding: utf-8 -*-

 import re
+from packaging import version
 import phonemizer
 from phonemizer.phonemize import phonemize
 from TTS.utils.text import cleaners
-from TTS.utils.text.symbols import symbols, phonemes, _phoneme_punctuations, _bos, \
+from TTS.utils.text.symbols import make_symbols, symbols, phonemes, _phoneme_punctuations, _bos, \
     _eos

 # Mappings from symbol to numeric ID and vice versa:
-_SYMBOL_TO_ID = {s: i for i, s in enumerate(symbols)}
-_ID_TO_SYMBOL = {i: s for i, s in enumerate(symbols)}
+_symbol_to_id = {s: i for i, s in enumerate(symbols)}
+_id_to_symbol = {i: s for i, s in enumerate(symbols)}

-_PHONEMES_TO_ID = {s: i for i, s in enumerate(phonemes)}
-_ID_TO_PHONEMES = {i: s for i, s in enumerate(phonemes)}
+_phonemes_to_id = {s: i for i, s in enumerate(phonemes)}
+_id_to_phonemes = {i: s for i, s in enumerate(phonemes)}

 # Regular expression matching text enclosed in curly braces:
 _CURLY_RE = re.compile(r'(.*?)\{(.+?)\}(.*)')

@@ -28,29 +29,53 @@ def text2phone(text, language):
     seperator = phonemizer.separator.Separator(' |', '', '|')
     #try:
     punctuations = re.findall(PHONEME_PUNCTUATION_PATTERN, text)
-    ph = phonemize(text, separator=seperator, strip=False, njobs=1, backend='espeak', language=language)
-    ph = ph[:-1].strip() # skip the last empty character
-    # Replace \n with matching punctuations.
-    if punctuations:
-        # if text ends with a punctuation.
-        if text[-1] == punctuations[-1]:
-            for punct in punctuations[:-1]:
-                ph = ph.replace('| |\n', '|'+punct+'| |', 1)
-            try:
-                ph = ph + punctuations[-1]
-            except:
-                print(text)
-        else:
-            for punct in punctuations:
-                ph = ph.replace('| |\n', '|'+punct+'| |', 1)
+    if version.parse(phonemizer.__version__) < version.parse('2.1'):
+        ph = phonemize(text, separator=seperator, strip=False, njobs=1, backend='espeak', language=language)
+        ph = ph[:-1].strip() # skip the last empty character
+        # phonemizer does not tackle punctuations. Here we do.
+        # Replace \n with matching punctuations.
+        if punctuations:
+            # if text ends with a punctuation.
+            if text[-1] == punctuations[-1]:
+                for punct in punctuations[:-1]:
+                    ph = ph.replace('| |\n', '|'+punct+'| |', 1)
+                ph = ph + punctuations[-1]
+            else:
+                for punct in punctuations:
+                    ph = ph.replace('| |\n', '|'+punct+'| |', 1)
+    elif version.parse(phonemizer.__version__) >= version.parse('2.1'):
+        ph = phonemize(text, separator=seperator, strip=False, njobs=1, backend='espeak', language=language, preserve_punctuation=True)
+        # this is a simple fix for phonemizer.
+        # https://github.com/bootphon/phonemizer/issues/32
+        if punctuations:
+            for punctuation in punctuations:
+                ph = ph.replace(f"| |{punctuation} ", f"|{punctuation}| |").replace(f"| |{punctuation}", f"|{punctuation}| |")
+            ph = ph[:-3]
+    else:
+        raise RuntimeError(" [!] Use 'phonemizer' version 2.1 or older.")

     return ph


-def pad_with_eos_bos(phoneme_sequence):
-    return [_PHONEMES_TO_ID[_bos]] + list(phoneme_sequence) + [_PHONEMES_TO_ID[_eos]]
+def pad_with_eos_bos(phoneme_sequence, tp=None):
+    # pylint: disable=global-statement
+    global _phonemes_to_id, _bos, _eos
+    if tp:
+        _bos = tp['bos']
+        _eos = tp['eos']
+        _, _phonemes = make_symbols(**tp)
+        _phonemes_to_id = {s: i for i, s in enumerate(_phonemes)}
+
+    return [_phonemes_to_id[_bos]] + list(phoneme_sequence) + [_phonemes_to_id[_eos]]


-def phoneme_to_sequence(text, cleaner_names, language, enable_eos_bos=False):
+def phoneme_to_sequence(text, cleaner_names, language, enable_eos_bos=False, tp=None):
+    # pylint: disable=global-statement
+    global _phonemes_to_id
+    if tp:
+        _, _phonemes = make_symbols(**tp)
+        _phonemes_to_id = {s: i for i, s in enumerate(_phonemes)}
+
     sequence = []
     text = text.replace(":", "")
     clean_text = _clean_text(text, cleaner_names)

@@ -62,21 +87,27 @@ def phoneme_to_sequence(text, cleaner_names, language, enable_eos_bos=False):
         sequence += _phoneme_to_sequence(phoneme)
     # Append EOS char
     if enable_eos_bos:
-        sequence = pad_with_eos_bos(sequence)
+        sequence = pad_with_eos_bos(sequence, tp=tp)
     return sequence


-def sequence_to_phoneme(sequence):
+def sequence_to_phoneme(sequence, tp=None):
+    # pylint: disable=global-statement
     '''Converts a sequence of IDs back to a string'''
+    global _id_to_phonemes
     result = ''
+    if tp:
+        _, _phonemes = make_symbols(**tp)
+        _id_to_phonemes = {i: s for i, s in enumerate(_phonemes)}
+
     for symbol_id in sequence:
-        if symbol_id in _ID_TO_PHONEMES:
-            s = _ID_TO_PHONEMES[symbol_id]
+        if symbol_id in _id_to_phonemes:
+            s = _id_to_phonemes[symbol_id]
             result += s
     return result.replace('}{', ' ')


-def text_to_sequence(text, cleaner_names):
+def text_to_sequence(text, cleaner_names, tp=None):
     '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.

     The text can optionally have ARPAbet sequences enclosed in curly braces embedded

@@ -89,6 +120,12 @@ def text_to_sequence(text, cleaner_names):
     Returns:
         List of integers corresponding to the symbols in the text
     '''
+    # pylint: disable=global-statement
+    global _symbol_to_id
+    if tp:
+        _symbols, _ = make_symbols(**tp)
+        _symbol_to_id = {s: i for i, s in enumerate(_symbols)}
+
     sequence = []
     # Check for curly braces and treat their contents as ARPAbet:
     while text:

@@ -103,12 +140,18 @@ def text_to_sequence(text, cleaner_names):
     return sequence


-def sequence_to_text(sequence):
+def sequence_to_text(sequence, tp=None):
     '''Converts a sequence of IDs back to a string'''
+    # pylint: disable=global-statement
+    global _id_to_symbol
+    if tp:
+        _symbols, _ = make_symbols(**tp)
+        _id_to_symbol = {i: s for i, s in enumerate(_symbols)}
+
     result = ''
     for symbol_id in sequence:
-        if symbol_id in _ID_TO_SYMBOL:
-            s = _ID_TO_SYMBOL[symbol_id]
+        if symbol_id in _id_to_symbol:
+            s = _id_to_symbol[symbol_id]
             # Enclose ARPAbet back in curly braces:
             if len(s) > 1 and s[0] == '@':
                 s = '{%s}' % s[1:]

@@ -126,11 +169,11 @@ def _clean_text(text, cleaner_names):


 def _symbols_to_sequence(syms):
-    return [_SYMBOL_TO_ID[s] for s in syms if _should_keep_symbol(s)]
+    return [_symbol_to_id[s] for s in syms if _should_keep_symbol(s)]


 def _phoneme_to_sequence(phons):
-    return [_PHONEMES_TO_ID[s] for s in list(phons) if _should_keep_phoneme(s)]
+    return [_phonemes_to_id[s] for s in list(phons) if _should_keep_phoneme(s)]


 def _arpabet_to_sequence(text):

@@ -138,8 +181,8 @@ def _arpabet_to_sequence(text):

 def _should_keep_symbol(s):
-    return s in _SYMBOL_TO_ID and s not in ['~', '^', '_']
+    return s in _symbol_to_id and s not in ['~', '^', '_']


 def _should_keep_phoneme(p):
-    return p in _PHONEMES_TO_ID and p not in ['~', '^', '_']
+    return p in _phonemes_to_id and p not in ['~', '^', '_']
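The `tp` hook in practice: a custom characters block rebuilds the module's symbol table on the fly, so IDs follow the user vocabulary instead of the defaults. A hedged round-trip sketch, assuming the TTS package is importable; the vocabulary below is invented:

    from TTS.utils.text import text_to_sequence, sequence_to_text

    tp = {"pad": "_", "eos": "~", "bos": "^",
          "characters": "abcdefghijklmnopqrstuvwxyz ",
          "punctuations": "!'(),-.:;? ",
          "phonemes": "iyu"}
    seq = text_to_sequence("hello world", ["basic_cleaners"], tp=tp)
    print(sequence_to_text(seq, tp=tp))  # -> "hello world"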
@@ -63,6 +63,19 @@ def convert_to_ascii(text):
     return unidecode(text)


+def remove_aux_symbols(text):
+    text = re.sub(r'[\<\>\(\)\[\]\"]+', '', text)
+    return text
+
+
+def replace_symbols(text):
+    text = text.replace(';', ',')
+    text = text.replace('-', ' ')
+    text = text.replace(':', ' ')
+    text = text.replace('&', 'and')
+    return text
+
+
 def basic_cleaners(text):
     '''Basic pipeline that lowercases and collapses whitespace without transliteration.'''
     text = lowercase(text)

@@ -84,6 +97,8 @@ def english_cleaners(text):
     text = lowercase(text)
     text = expand_numbers(text)
     text = expand_abbreviations(text)
+    text = replace_symbols(text)
+    text = remove_aux_symbols(text)
     text = collapse_whitespace(text)
     return text

@@ -93,5 +108,7 @@ def phoneme_cleaners(text):
     text = convert_to_ascii(text)
     text = expand_numbers(text)
     text = expand_abbreviations(text)
+    text = replace_symbols(text)
+    text = remove_aux_symbols(text)
     text = collapse_whitespace(text)
     return text
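What the two new cleaner passes do, on a made-up sentence; the functions are copied from the hunk above so this runs standalone:

    import re

    def replace_symbols(text):
        text = text.replace(';', ',')
        text = text.replace('-', ' ')
        text = text.replace(':', ' ')
        text = text.replace('&', 'and')
        return text

    def remove_aux_symbols(text):
        return re.sub(r'[\<\>\(\)\[\]\"]+', '', text)

    print(remove_aux_symbols(replace_symbols('rock & roll; "go-to" (really)')))
    # -> 'rock and roll, go to really'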
@@ -5,6 +5,18 @@ Defines the set of symbols used in text input to the model.
 The default is a set of ASCII characters that works well for English or text that has been run
 through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details.
 '''
+def make_symbols(characters, phonemes, punctuations='!\'(),-.:;? ', pad='_', eos='~', bos='^'):# pylint: disable=redefined-outer-name
+    ''' Function to create symbols and phonemes '''
+    _phonemes_sorted = sorted(list(phonemes))
+
+    # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
+    _arpabet = ['@' + s for s in _phonemes_sorted]
+
+    # Export all symbols:
+    _symbols = [pad, eos, bos] + list(characters) + _arpabet
+    _phonemes = [pad, eos, bos] + list(_phonemes_sorted) + list(punctuations)
+
+    return _symbols, _phonemes
+
 _pad = '_'
 _eos = '~'

@@ -20,14 +32,9 @@ _pulmonic_consonants = 'pbtdʈɖcɟkɡqɢʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðsz
 _suprasegmentals = 'ˈˌːˑ'
 _other_symbols = 'ʍwɥʜʢʡɕʑɺɧ'
 _diacrilics = 'ɚ˞ɫ'
-_phonemes = sorted(list(_vowels + _non_pulmonic_consonants + _pulmonic_consonants + _suprasegmentals + _other_symbols + _diacrilics))
+_phonemes = _vowels + _non_pulmonic_consonants + _pulmonic_consonants + _suprasegmentals + _other_symbols + _diacrilics

-# Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
-_arpabet = ['@' + s for s in _phonemes]
-
-# Export all symbols:
-symbols = [_pad, _eos, _bos] + list(_characters) + _arpabet
-phonemes = [_pad, _eos, _bos] + list(_phonemes) + list(_punctuations)
+symbols, phonemes = make_symbols(_characters, _phonemes, _punctuations, _pad, _eos, _bos)

 # Generate ALIEN language
 # from random import shuffle
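make_symbols() from the hunk above, exercised on a tiny vocabulary to show the ordering contract: pad/eos/bos lead both tables, '@'-prefixed phoneme aliases end the symbol list, and punctuation ends the phoneme list. The function body is verbatim from the diff; only the toy inputs are invented:

    def make_symbols(characters, phonemes, punctuations='!\'(),-.:;? ', pad='_', eos='~', bos='^'):
        _phonemes_sorted = sorted(list(phonemes))
        _arpabet = ['@' + s for s in _phonemes_sorted]
        _symbols = [pad, eos, bos] + list(characters) + _arpabet
        _phonemes = [pad, eos, bos] + list(_phonemes_sorted) + list(punctuations)
        return _symbols, _phonemes

    syms, phons = make_symbols('ab', 'ba', punctuations='!?')
    print(syms)   # ['_', '~', '^', 'a', 'b', '@a', '@b']
    print(phons)  # ['_', '~', '^', 'a', 'b', '!', '?']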
@@ -54,9 +54,10 @@ def visualize(alignment, spectrogram_postnet, stop_tokens, text, hop_length, CON
     plt.xlabel("Decoder timestamp", fontsize=label_fontsize)
     plt.ylabel("Encoder timestamp", fontsize=label_fontsize)
     if CONFIG.use_phonemes:
-        seq = phoneme_to_sequence(text, [CONFIG.text_cleaner], CONFIG.phoneme_language, CONFIG.enable_eos_bos_chars)
-        text = sequence_to_phoneme(seq)
+        seq = phoneme_to_sequence(text, [CONFIG.text_cleaner], CONFIG.phoneme_language, CONFIG.enable_eos_bos_chars, tp=CONFIG.characters if 'characters' in CONFIG.keys() else None)
+        text = sequence_to_phoneme(seq, tp=CONFIG.characters if 'characters' in CONFIG.keys() else None)
+        print(text)

     plt.yticks(range(len(text)), list(text))
     plt.colorbar()