mirror of https://github.com/coqui-ai/TTS.git
padding idx for embedding layer
This commit is contained in:
parent
e7ef4e9050
commit
957f7dcbc5
|
@ -39,7 +39,7 @@ class Tacotron(nn.Module):
|
||||||
encoder_dim = 512 if num_speakers > 1 else 256
|
encoder_dim = 512 if num_speakers > 1 else 256
|
||||||
proj_speaker_dim = 80 if num_speakers > 1 else 0
|
proj_speaker_dim = 80 if num_speakers > 1 else 0
|
||||||
# embedding layer
|
# embedding layer
|
||||||
self.embedding = nn.Embedding(num_chars, 256)
|
self.embedding = nn.Embedding(num_chars, 256, padding_idx=0)
|
||||||
self.embedding.weight.data.normal_(0, 0.3)
|
self.embedding.weight.data.normal_(0, 0.3)
|
||||||
# boilerplate model
|
# boilerplate model
|
||||||
self.encoder = Encoder(encoder_dim)
|
self.encoder = Encoder(encoder_dim)
|
||||||
|
|
|
@ -35,7 +35,7 @@ class Tacotron2(nn.Module):
|
||||||
encoder_dim = 512 if num_speakers > 1 else 512
|
encoder_dim = 512 if num_speakers > 1 else 512
|
||||||
proj_speaker_dim = 80 if num_speakers > 1 else 0
|
proj_speaker_dim = 80 if num_speakers > 1 else 0
|
||||||
# embedding layer
|
# embedding layer
|
||||||
self.embedding = nn.Embedding(num_chars, 512)
|
self.embedding = nn.Embedding(num_chars, 512, padding_idx=0)
|
||||||
std = sqrt(2.0 / (num_chars + 512))
|
std = sqrt(2.0 / (num_chars + 512))
|
||||||
val = sqrt(3.0) * std # uniform bounds for std
|
val = sqrt(3.0) * std # uniform bounds for std
|
||||||
self.embedding.weight.data.uniform_(-val, val)
|
self.embedding.weight.data.uniform_(-val, val)
|
||||||
|
|
|
@ -63,6 +63,19 @@ def convert_to_ascii(text):
|
||||||
return unidecode(text)
|
return unidecode(text)
|
||||||
|
|
||||||
|
|
||||||
|
def remove_aux_symbols(text):
    '''Strip auxiliary symbols (angle brackets, parentheses, square brackets and quotes) from text.'''
    # One or more consecutive auxiliary characters are removed as a single run.
    aux_symbols = r'[\<\>\(\)\[\]\"\']+'
    return re.sub(aux_symbols, '', text)
|
||||||
|
|
||||||
|
|
||||||
|
def replace_symbols(text):
    '''Replace symbols with punctuation or words that are easier to pronounce.

    ';' becomes ',', '-' and ':' become a single space, and '&' becomes
    the word "and" surrounded by spaces.
    '''
    text = text.replace(';', ',')
    text = text.replace('-', ' ')
    text = text.replace(':', ' ')
    # Pad "and" with spaces so that "rock&roll" becomes "rock and roll"
    # instead of the fused "rockandroll". Any doubled spaces this creates
    # (e.g. for "a & b") are expected to be collapsed by the whitespace
    # cleanup that the cleaner pipelines run afterwards.
    text = text.replace('&', ' and ')
    return text
|
||||||
|
|
||||||
|
|
||||||
def basic_cleaners(text):
|
def basic_cleaners(text):
|
||||||
'''Basic pipeline that lowercases and collapses whitespace without transliteration.'''
|
'''Basic pipeline that lowercases and collapses whitespace without transliteration.'''
|
||||||
text = lowercase(text)
|
text = lowercase(text)
|
||||||
|
@ -84,6 +97,8 @@ def english_cleaners(text):
|
||||||
text = lowercase(text)
|
text = lowercase(text)
|
||||||
text = expand_numbers(text)
|
text = expand_numbers(text)
|
||||||
text = expand_abbreviations(text)
|
text = expand_abbreviations(text)
|
||||||
|
text = replace_symbols(text)
|
||||||
|
text = remove_aux_symbols(text)
|
||||||
text = collapse_whitespace(text)
|
text = collapse_whitespace(text)
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
@ -93,5 +108,7 @@ def phoneme_cleaners(text):
|
||||||
text = convert_to_ascii(text)
|
text = convert_to_ascii(text)
|
||||||
text = expand_numbers(text)
|
text = expand_numbers(text)
|
||||||
text = expand_abbreviations(text)
|
text = expand_abbreviations(text)
|
||||||
|
text = replace_symbols(text)
|
||||||
|
text = remove_aux_symbols(text)
|
||||||
text = collapse_whitespace(text)
|
text = collapse_whitespace(text)
|
||||||
return text
|
return text
|
||||||
|
|
Loading…
Reference in New Issue