From 957f7dcbc5ec70e5aa634a2ca7d810ec1c95b25c Mon Sep 17 00:00:00 2001 From: erogol Date: Thu, 20 Feb 2020 12:24:54 +0100 Subject: [PATCH] padding idx for embedding layer --- models/tacotron.py | 2 +- models/tacotron2.py | 2 +- utils/text/cleaners.py | 17 +++++++++++++++++ 3 files changed, 19 insertions(+), 2 deletions(-) diff --git a/models/tacotron.py b/models/tacotron.py index 04ecd573..fba82b1b 100644 --- a/models/tacotron.py +++ b/models/tacotron.py @@ -39,7 +39,7 @@ class Tacotron(nn.Module): encoder_dim = 512 if num_speakers > 1 else 256 proj_speaker_dim = 80 if num_speakers > 1 else 0 # embedding layer - self.embedding = nn.Embedding(num_chars, 256) + self.embedding = nn.Embedding(num_chars, 256, padding_idx=0) self.embedding.weight.data.normal_(0, 0.3) # boilerplate model self.encoder = Encoder(encoder_dim) diff --git a/models/tacotron2.py b/models/tacotron2.py index 3a3863de..d530774a 100644 --- a/models/tacotron2.py +++ b/models/tacotron2.py @@ -35,7 +35,7 @@ class Tacotron2(nn.Module): encoder_dim = 512 if num_speakers > 1 else 512 proj_speaker_dim = 80 if num_speakers > 1 else 0 # embedding layer - self.embedding = nn.Embedding(num_chars, 512) + self.embedding = nn.Embedding(num_chars, 512, padding_idx=0) std = sqrt(2.0 / (num_chars + 512)) val = sqrt(3.0) * std # uniform bounds for std self.embedding.weight.data.uniform_(-val, val) diff --git a/utils/text/cleaners.py b/utils/text/cleaners.py index 581633a2..962b3c31 100644 --- a/utils/text/cleaners.py +++ b/utils/text/cleaners.py @@ -63,6 +63,19 @@ def convert_to_ascii(text): return unidecode(text) +def remove_aux_symbols(text): + text = re.sub(r'[\<\>\(\)\[\]\"\']+', '', text) + return text + + +def replace_symbols(text): + text = text.replace(';', ',') + text = text.replace('-', ' ') + text = text.replace(':', ' ') + text = text.replace('&', 'and') + return text + + def basic_cleaners(text): '''Basic pipeline that lowercases and collapses whitespace without transliteration.''' text = lowercase(text) @@ -84,6 +97,8 @@ def english_cleaners(text): text = lowercase(text) text = expand_numbers(text) text = expand_abbreviations(text) + text = replace_symbols(text) + text = remove_aux_symbols(text) text = collapse_whitespace(text) return text @@ -93,5 +108,7 @@ def phoneme_cleaners(text): text = convert_to_ascii(text) text = expand_numbers(text) text = expand_abbreviations(text) + text = replace_symbols(text) + text = remove_aux_symbols(text) text = collapse_whitespace(text) return text