mirror of https://github.com/coqui-ai/TTS.git
Bugfix in DDC: DDC now works on Tacotron1
commit e265810e8c
parent b750452782
@@ -1,5 +1,5 @@
 {
-    "model": "Tacotron2",
+    "model": "Tacotron",
     "run_name": "ljspeech-ddc-bn",
     "run_description": "tacotron2 with ddc and batch-normalization",
@@ -114,7 +114,7 @@
     "tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.

     // DATA LOADING
-    "text_cleaner": "phoneme_cleaners",
+    "text_cleaner": "portuguese_cleaners",
     "enable_eos_bos_chars": false, // enable/disable beginning of sentence and end of sentence chars.
     "num_loader_workers": 4, // number of training data loader processes. Don't set it too big. 4-8 are good values.
     "num_val_loader_workers": 4, // number of evaluation data loader processes.
@@ -123,15 +123,15 @@
     "max_seq_len": 153, // DATASET-RELATED: maximum text length

     // PATHS
-    "output_path": "/home/erogol/Models/LJSpeech/",
+    "output_path": "../../Mozilla-TTS/vctk-test/",

     // PHONEMES
-    "phoneme_cache_path": "/media/erogol/data_ssd2/mozilla_us_phonemes_3", // phoneme computation is slow, therefore, it caches results in the given folder.
+    "phoneme_cache_path": "../../Mozilla-TTS/vctk-test/", // phoneme computation is slow, therefore, it caches results in the given folder.
     "use_phonemes": true, // use phonemes instead of raw characters. It is suggested for better pronounciation.
     "phoneme_language": "en-us", // depending on your target language, pick one from https://github.com/bootphon/phonemizer#languages

     // MULTI-SPEAKER and GST
-    "use_speaker_embedding": false, // use speaker embedding to enable multi-speaker learning.
+    "use_speaker_embedding": true, // use speaker embedding to enable multi-speaker learning.
     "use_gst": true, // use global style tokens
     "gst": { // gst parameter if gst is enabled
         "gst_style_input": null, // Condition the style input either on a
@@ -147,9 +147,9 @@
     "datasets": // List of datasets. They all merged and they get different speaker_ids.
         [
             {
-                "name": "ljspeech",
-                "path": "/home/erogol/Data/LJSpeech-1.1/",
-                "meta_file_train": "metadata.csv",
+                "name": "vctk",
+                "path": "../../../datasets/VCTK-Corpus-removed-silence/",
+                "meta_file_train": ["p225", "p234", "p238", "p245", "p248", "p261", "p294", "p302", "p326", "p335", "p347"], // for vtck if list, ignore speakers id in list for train, its useful for test cloning with new speakers
                 "meta_file_val": null
             }
         ]
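Note: the `meta_file_train` entry above may now be a list of speaker ids to hold out of training (useful for testing voice cloning on unseen speakers). A minimal sketch of that idea follows; `split_items` and the `(text, wav_path, speaker_id)` item layout are illustrative assumptions, not the repository's loader API.

# Hypothetical illustration: hold out the listed VCTK speakers from training.
EXCLUDED_SPEAKERS = ["p225", "p234", "p238", "p245", "p248", "p261",
                     "p294", "p302", "p326", "p335", "p347"]

def split_items(items, excluded=EXCLUDED_SPEAKERS):
    """items: iterable of (text, wav_path, speaker_id) tuples (assumed layout)."""
    train = [it for it in items if it[2] not in excluded]
    held_out = [it for it in items if it[2] in excluded]
    return train, held_out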
@@ -70,7 +70,7 @@ class MyDataset(Dataset):
         self.sort_items()

     def load_wav(self, filename):
-        audio = self.ap.load_wav(filename)
+        audio = self.ap.load_wav(filename, sr=self.sample_rate)
         return audio

     @staticmethod
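Note: the new `sr` argument presumably lets `load_wav` enforce the dataset sample rate when reading audio. A minimal, assumed sketch of the same idea using librosa (not the repository's `AudioProcessor`):

import librosa

def load_wav(filename, sr=None):
    # librosa resamples to `sr` when it is given; with sr=None the native rate is kept
    audio, _ = librosa.load(filename, sr=sr)
    return audio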
@@ -303,7 +303,7 @@ class Decoder(nn.Module):
         self.separate_stopnet = separate_stopnet
         self.query_dim = 256
         # memory -> |Prenet| -> processed_memory
-        prenet_dim = frame_channels * self.memory_size + speaker_embedding_dim if self.use_memory_queue else frame_channels + speaker_embedding_dim
+        prenet_dim = memory_dim * self.memory_size if self.use_memory_queue else memory_dim
         self.prenet = Prenet(
             prenet_dim,
             prenet_type,
@@ -429,7 +429,7 @@ class Decoder(nn.Module):
         # assert new_memory.shape[-1] == self.r * self.frame_channels
         self.memory_input = new_memory[:, self.frame_channels * (self.r - 1):]

-    def forward(self, inputs, memory, mask, speaker_embeddings=None):
+    def forward(self, inputs, memory, mask):
         """
         Args:
             inputs: Encoder outputs.
@@ -454,8 +454,7 @@ class Decoder(nn.Module):
             if t > 0:
                 new_memory = memory[t - 1]
                 self._update_memory_input(new_memory)
-                if speaker_embeddings is not None:
-                    self.memory_input = torch.cat([self.memory_input, speaker_embeddings], dim=-1)
             output, stop_token, attention = self.decode(inputs, mask)
             outputs += [output]
             attentions += [attention]
@@ -300,7 +300,7 @@ class Decoder(nn.Module):
         decoder_output = decoder_output[:, :self.r * self.frame_channels]
         return decoder_output, self.attention.attention_weights, stop_token

-    def forward(self, inputs, memories, mask, speaker_embeddings=None):
+    def forward(self, inputs, memories, mask):
         r"""Train Decoder with teacher forcing.
         Args:
             inputs: Encoder outputs.
@@ -318,8 +318,6 @@ class Decoder(nn.Module):
         memories = self._reshape_memory(memories)
         memories = torch.cat((memory, memories), dim=0)
         memories = self._update_memory(memories)
-        if speaker_embeddings is not None:
-            memories = torch.cat([memories, speaker_embeddings], dim=-1)
         memories = self.prenet(memories)

         self._init_states(inputs, mask=mask)
@@ -6,7 +6,6 @@ from mozilla_voice_tts.tts.layers.gst_layers import GST
 from mozilla_voice_tts.tts.layers.tacotron import Decoder, Encoder, PostCBHG
 from mozilla_voice_tts.tts.models.tacotron_abstract import TacotronAbstract


 class Tacotron(TacotronAbstract):
     def __init__(self,
                  num_chars,
@@ -41,10 +40,19 @@ class Tacotron(TacotronAbstract):
                          location_attn, attn_K, separate_stopnet,
                          bidirectional_decoder, double_decoder_consistency,
                          ddc_r, gst)
-        decoder_in_features = 512 if num_speakers > 1 else 256
-        encoder_in_features = 512 if num_speakers > 1 else 256
+        # init layer dims
+        decoder_in_features = 256
+        encoder_in_features = 256
         speaker_embedding_dim = 256
         proj_speaker_dim = 80 if num_speakers > 1 else 0

+        if num_speakers > 1:
+            decoder_in_features = decoder_in_features + speaker_embedding_dim # add speaker embedding dim
+        if self.gst:
+            decoder_in_features = decoder_in_features + gst_embedding_dim # add gst embedding dim
+
         # base model layers
         self.embedding = nn.Embedding(num_chars, 256, padding_idx=0)
         self.embedding.weight.data.normal_(0, 0.3)
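Note: a worked example of the dimension bookkeeping above. With more than one speaker and GST enabled, the decoder input grows by the speaker embedding and the GST embedding; the 512 used for `gst_embedding_dim` is an assumed config value, not taken from this diff.

decoder_in_features = 256
speaker_embedding_dim = 256
gst_embedding_dim = 512          # assumed config value
num_speakers, use_gst = 2, True

if num_speakers > 1:
    decoder_in_features += speaker_embedding_dim   # 256 + 256 = 512
if use_gst:
    decoder_in_features += gst_embedding_dim       # 512 + 512 = 1024
print(decoder_in_features)                         # 1024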
@@ -98,10 +106,6 @@ class Tacotron(TacotronAbstract):
         # B x speaker_embed_dim
         if speaker_ids is not None:
             self.compute_speaker_embedding(speaker_ids)
-        if self.num_speakers > 1:
-            # B x T_in x embed_dim + speaker_embed_dim
-            inputs = self._concat_speaker_embedding(inputs,
-                                                    self.speaker_embeddings)
         # B x T_in x encoder_in_features
         encoder_outputs = self.encoder(inputs)
         # sequence masking
@@ -117,8 +121,7 @@ class Tacotron(TacotronAbstract):
         # alignments: B x T_in x encoder_in_features
         # stop_tokens: B x T_in
         decoder_outputs, alignments, stop_tokens = self.decoder(
-            encoder_outputs, mel_specs, input_mask,
-            self.speaker_embeddings_projected)
+            encoder_outputs, mel_specs, input_mask)
         # sequence masking
         if output_mask is not None:
             decoder_outputs = decoder_outputs * output_mask.unsqueeze(1).expand_as(decoder_outputs)
@@ -145,9 +148,6 @@ class Tacotron(TacotronAbstract):
         self._init_states()
         if speaker_ids is not None:
             self.compute_speaker_embedding(speaker_ids)
-        if self.num_speakers > 1:
-            inputs = self._concat_speaker_embedding(inputs,
-                                                    self.speaker_embeddings)
         encoder_outputs = self.encoder(inputs)
         if self.gst and style_mel is not None:
             encoder_outputs = self.compute_gst(encoder_outputs, style_mel)
@@ -5,7 +5,6 @@ from mozilla_voice_tts.tts.layers.gst_layers import GST
 from mozilla_voice_tts.tts.layers.tacotron2 import Decoder, Encoder, Postnet
 from mozilla_voice_tts.tts.models.tacotron_abstract import TacotronAbstract


 # TODO: match function arguments with tacotron
 class Tacotron2(TacotronAbstract):
     def __init__(self,
|
@ -86,24 +85,6 @@ class Tacotron2(TacotronAbstract):
|
||||||
mel_outputs_postnet = mel_outputs_postnet.transpose(1, 2)
|
mel_outputs_postnet = mel_outputs_postnet.transpose(1, 2)
|
||||||
return mel_outputs, mel_outputs_postnet, alignments
|
return mel_outputs, mel_outputs_postnet, alignments
|
||||||
|
|
||||||
def compute_gst(self, inputs, style_input):
|
|
||||||
""" Compute global style token """
|
|
||||||
device = inputs.device
|
|
||||||
if isinstance(style_input, dict):
|
|
||||||
query = torch.zeros(1, 1, self.gst_embedding_dim//2).to(device)
|
|
||||||
_GST = torch.tanh(self.gst_layer.style_token_layer.style_tokens)
|
|
||||||
gst_outputs = torch.zeros(1, 1, self.gst_embedding_dim).to(device)
|
|
||||||
for k_token, v_amplifier in style_input.items():
|
|
||||||
key = _GST[int(k_token)].unsqueeze(0).expand(1, -1, -1)
|
|
||||||
gst_outputs_att = self.gst_layer.style_token_layer.attention(query, key)
|
|
||||||
gst_outputs = gst_outputs + gst_outputs_att * v_amplifier
|
|
||||||
elif style_input is None:
|
|
||||||
gst_outputs = torch.zeros(1, 1, self.gst_embedding_dim).to(device)
|
|
||||||
else:
|
|
||||||
gst_outputs = self.gst_layer(style_input) # pylint: disable=not-callable
|
|
||||||
embedded_gst = gst_outputs.repeat(1, inputs.size(1), 1)
|
|
||||||
return inputs, embedded_gst
|
|
||||||
|
|
||||||
def forward(self, text, text_lengths, mel_specs=None, mel_lengths=None, speaker_ids=None):
|
def forward(self, text, text_lengths, mel_specs=None, mel_lengths=None, speaker_ids=None):
|
||||||
# compute mask for padding
|
# compute mask for padding
|
||||||
# B x T_in_max (boolean)
|
# B x T_in_max (boolean)
|
||||||
|
@ -113,20 +94,13 @@ class Tacotron2(TacotronAbstract):
|
||||||
# B x T_in_max x D_en
|
# B x T_in_max x D_en
|
||||||
encoder_outputs = self.encoder(embedded_inputs, text_lengths)
|
encoder_outputs = self.encoder(embedded_inputs, text_lengths)
|
||||||
|
|
||||||
|
if self.gst:
|
||||||
|
# B x gst_dim
|
||||||
|
encoder_outputs = self.compute_gst(encoder_outputs, mel_specs)
|
||||||
|
|
||||||
if self.num_speakers > 1:
|
if self.num_speakers > 1:
|
||||||
embedded_speakers = self.speaker_embedding(speaker_ids)[:, None]
|
embedded_speakers = self.speaker_embedding(speaker_ids)[:, None]
|
||||||
embedded_speakers = embedded_speakers.repeat(1, encoder_outputs.size(1), 1)
|
encoder_outputs = self._concat_speaker_embedding(encoder_outputs, embedded_speakers)
|
||||||
if hasattr(self, 'gst'):
|
|
||||||
# B x gst_dim
|
|
||||||
encoder_outputs, embedded_gst = self.compute_gst(encoder_outputs, mel_specs)
|
|
||||||
encoder_outputs = torch.cat([encoder_outputs, embedded_gst, embedded_speakers], dim=-1)
|
|
||||||
else:
|
|
||||||
encoder_outputs = torch.cat([encoder_outputs, embedded_speakers], dim=-1)
|
|
||||||
else:
|
|
||||||
if hasattr(self, 'gst'):
|
|
||||||
# B x gst_dim
|
|
||||||
encoder_outputs, embedded_gst = self.compute_gst(encoder_outputs, mel_specs)
|
|
||||||
encoder_outputs = torch.cat([encoder_outputs, embedded_gst], dim=-1)
|
|
||||||
|
|
||||||
encoder_outputs = encoder_outputs * input_mask.unsqueeze(2).expand_as(encoder_outputs)
|
encoder_outputs = encoder_outputs * input_mask.unsqueeze(2).expand_as(encoder_outputs)
|
||||||
|
|
||||||
|
@ -163,15 +137,14 @@ class Tacotron2(TacotronAbstract):
|
||||||
embedded_speakers = embedded_speakers.repeat(1, encoder_outputs.size(1), 1)
|
embedded_speakers = embedded_speakers.repeat(1, encoder_outputs.size(1), 1)
|
||||||
if hasattr(self, 'gst'):
|
if hasattr(self, 'gst'):
|
||||||
# B x gst_dim
|
# B x gst_dim
|
||||||
encoder_outputs, embedded_gst = self.compute_gst(encoder_outputs, style_mel)
|
encoder_outputs = self.compute_gst(encoder_outputs, style_mel)
|
||||||
encoder_outputs = torch.cat([encoder_outputs, embedded_gst, embedded_speakers], dim=-1)
|
encoder_outputs = torch.cat([encoder_outputs, embedded_speakers], dim=-1)
|
||||||
else:
|
else:
|
||||||
encoder_outputs = torch.cat([encoder_outputs, embedded_speakers], dim=-1)
|
encoder_outputs = torch.cat([encoder_outputs, embedded_speakers], dim=-1)
|
||||||
else:
|
else:
|
||||||
if hasattr(self, 'gst'):
|
if hasattr(self, 'gst'):
|
||||||
# B x gst_dim
|
# B x gst_dim
|
||||||
encoder_outputs, embedded_gst = self.compute_gst(encoder_outputs, style_mel)
|
encoder_outputs = self.compute_gst(encoder_outputs, style_mel)
|
||||||
encoder_outputs = torch.cat([encoder_outputs, embedded_gst], dim=-1)
|
|
||||||
|
|
||||||
decoder_outputs, alignments, stop_tokens = self.decoder.inference(
|
decoder_outputs, alignments, stop_tokens = self.decoder.inference(
|
||||||
encoder_outputs)
|
encoder_outputs)
|
||||||
|
@ -193,15 +166,13 @@ class Tacotron2(TacotronAbstract):
|
||||||
embedded_speakers = embedded_speakers.repeat(1, encoder_outputs.size(1), 1)
|
embedded_speakers = embedded_speakers.repeat(1, encoder_outputs.size(1), 1)
|
||||||
if hasattr(self, 'gst'):
|
if hasattr(self, 'gst'):
|
||||||
# B x gst_dim
|
# B x gst_dim
|
||||||
encoder_outputs, embedded_gst = self.compute_gst(encoder_outputs, style_mel)
|
encoder_outputs = self.compute_gst(encoder_outputs, style_mel)
|
||||||
encoder_outputs = torch.cat([encoder_outputs, embedded_gst, embedded_speakers], dim=-1)
|
|
||||||
else:
|
else:
|
||||||
encoder_outputs = torch.cat([encoder_outputs, embedded_speakers], dim=-1)
|
encoder_outputs = torch.cat([encoder_outputs, embedded_speakers], dim=-1)
|
||||||
else:
|
else:
|
||||||
if hasattr(self, 'gst'):
|
if hasattr(self, 'gst'):
|
||||||
# B x gst_dim
|
# B x gst_dim
|
||||||
encoder_outputs, embedded_gst = self.compute_gst(encoder_outputs, style_mel)
|
encoder_outputs = self.compute_gst(encoder_outputs, style_mel)
|
||||||
encoder_outputs = torch.cat([encoder_outputs, embedded_gst], dim=-1)
|
|
||||||
|
|
||||||
mel_outputs, alignments, stop_tokens = self.decoder.inference_truncated(
|
mel_outputs, alignments, stop_tokens = self.decoder.inference_truncated(
|
||||||
encoder_outputs)
|
encoder_outputs)
|
||||||
|
|
|
@@ -164,11 +164,22 @@ class TacotronAbstract(ABC, nn.Module):
         self.speaker_embeddings_projected = self.speaker_project_mel(
             self.speaker_embeddings).squeeze(1)

-    def compute_gst(self, inputs, mel_specs):
+    def compute_gst(self, inputs, style_input):
         """ Compute global style token """
-        # pylint: disable=not-callable
-        gst_outputs = self.gst_layer(mel_specs)
-        inputs = self._add_speaker_embedding(inputs, gst_outputs)
+        device = inputs.device
+        if isinstance(style_input, dict):
+            query = torch.zeros(1, 1, self.gst_embedding_dim//2).to(device)
+            _GST = torch.tanh(self.gst_layer.style_token_layer.style_tokens)
+            gst_outputs = torch.zeros(1, 1, self.gst_embedding_dim).to(device)
+            for k_token, v_amplifier in style_input.items():
+                key = _GST[int(k_token)].unsqueeze(0).expand(1, -1, -1)
+                gst_outputs_att = self.gst_layer.style_token_layer.attention(query, key)
+                gst_outputs = gst_outputs + gst_outputs_att * v_amplifier
+        elif style_input is None:
+            gst_outputs = torch.zeros(1, 1, self.gst_embedding_dim).to(device)
+        else:
+            gst_outputs = self.gst_layer(style_input) # pylint: disable=not-callable
+        inputs = self._concat_speaker_embedding(inputs, gst_outputs)
         return inputs

     @staticmethod
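Note: with `compute_gst` moved into `TacotronAbstract`, `style_input` can be a dict mapping a style-token index to an amplifier weight, `None`, or a reference mel. A standalone sketch of the dict branch only, with made-up tensors and a simple tiling stand-in for the attention call; it is not the repository's GST module.

import torch

gst_embedding_dim = 256                                      # assumed
style_tokens = torch.randn(10, 1, gst_embedding_dim // 2)    # assumed token table

style_input = {"3": 0.5, "7": -0.2}                          # token index -> amplifier
gst_outputs = torch.zeros(1, 1, gst_embedding_dim)
for k_token, v_amplifier in style_input.items():
    key = torch.tanh(style_tokens[int(k_token)]).unsqueeze(0)  # 1 x 1 x dim/2
    # the real code runs multi-head attention here; tiling stands in for it
    gst_outputs = gst_outputs + key.repeat(1, 1, 2) * v_amplifier
print(gst_outputs.shape)                                     # torch.Size([1, 1, 256])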
@@ -257,10 +257,16 @@ def check_config(c):
     check_argument('gst_num_heads', c['gst'], restricted=True, val_type=int, min_val=1)
     check_argument('gst_style_tokens', c['gst'], restricted=True, val_type=int, min_val=1)

+    check_argument('gst', c, restricted=True, val_type=dict)
+    check_argument('gst_style_input', c['gst'], restricted=True, val_type=[str, dict])
+    check_argument('gst_embedding_dim', c['gst'], restricted=True, val_type=int, min_val=0, max_val=1000)
+    check_argument('gst_num_heads', c['gst'], restricted=True, val_type=int, min_val=2, max_val=10)
+    check_argument('gst_style_tokens', c['gst'], restricted=True, val_type=int, min_val=1, max_val=1000)
+
     # datasets - checking only the first entry
     check_argument('datasets', c, restricted=True, val_type=list)
     for dataset_entry in c['datasets']:
         check_argument('name', dataset_entry, restricted=True, val_type=str)
         check_argument('path', dataset_entry, restricted=True, val_type=str)
-        check_argument('meta_file_train', dataset_entry, restricted=True, val_type=str)
+        check_argument('meta_file_train', dataset_entry, restricted=True, val_type=[str, list])
         check_argument('meta_file_val', dataset_entry, restricted=True, val_type=str)
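Note: for reference, a config fragment that would satisfy the gst- and dataset-related checks added above; all values are illustrative, not defaults taken from the repository.

c = {
    "gst": {
        "gst_style_input": None,     # None, a string, or a dict all pass the check above
        "gst_embedding_dim": 512,
        "gst_num_heads": 4,
        "gst_style_tokens": 10,
    },
    "datasets": [{
        "name": "vctk",
        "path": "../../../datasets/VCTK-Corpus-removed-silence/",
        "meta_file_train": ["p225", "p234"],   # a plain "metadata.csv" string also passes now
        "meta_file_val": None,
    }],
}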
@@ -107,7 +107,6 @@ def basic_turkish_cleaners(text):
     text = collapse_whitespace(text)
     return text

-
 def english_cleaners(text):
     '''Pipeline for English text, including number and abbreviation expansion.'''
     text = convert_to_ascii(text)
@@ -146,5 +146,12 @@ def check_argument(name, c, enum_list=None, max_val=None, min_val=None, restrict
             assert c[name] >= min_val, f' [!] {name} is smaller than min value {min_val}'
         if enum_list:
             assert c[name].lower() in enum_list, f' [!] {name} is not a valid value'
-        if val_type:
+        if isinstance(val_type, list):
+            valid_types = val_type
+            is_valid = False
+            for typ in val_type:
+                if isinstance(c[name], typ):
+                    is_valid = True
+            assert is_valid or c[name] is None, f' [!] {name} has wrong type - {type(c[name])} vs {val_type}'
+        elif val_type:
             assert isinstance(c[name], val_type) or c[name] is None, f' [!] {name} has wrong type - {type(c[name])} vs {val_type}'
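Note: the list form of `val_type` accepts a value matching any of the listed types (or `None`). A standalone restatement of that rule for illustration; this is not the repository's function itself.

def accepts(value, val_type):
    if isinstance(val_type, list):
        return value is None or any(isinstance(value, t) for t in val_type)
    return value is None or isinstance(value, val_type)

assert accepts(["p225", "p234"], [str, list])   # a speaker list passes
assert accepts("metadata.csv", [str, list])     # a plain string still passes
assert not accepts(3, [str, list])              # other types are rejected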