From 957f7dcbc5ec70e5aa634a2ca7d810ec1c95b25c Mon Sep 17 00:00:00 2001
From: erogol
Date: Thu, 20 Feb 2020 12:24:54 +0100
Subject: [PATCH 01/11] padding idx for embedding layer

---
 models/tacotron.py     |  2 +-
 models/tacotron2.py    |  2 +-
 utils/text/cleaners.py | 17 +++++++++++++++++
 3 files changed, 19 insertions(+), 2 deletions(-)

diff --git a/models/tacotron.py b/models/tacotron.py
index 04ecd573..fba82b1b 100644
--- a/models/tacotron.py
+++ b/models/tacotron.py
@@ -39,7 +39,7 @@ class Tacotron(nn.Module):
         encoder_dim = 512 if num_speakers > 1 else 256
         proj_speaker_dim = 80 if num_speakers > 1 else 0
         # embedding layer
-        self.embedding = nn.Embedding(num_chars, 256)
+        self.embedding = nn.Embedding(num_chars, 256, padding_idx=0)
         self.embedding.weight.data.normal_(0, 0.3)
         # boilerplate model
         self.encoder = Encoder(encoder_dim)
diff --git a/models/tacotron2.py b/models/tacotron2.py
index 3a3863de..d530774a 100644
--- a/models/tacotron2.py
+++ b/models/tacotron2.py
@@ -35,7 +35,7 @@ class Tacotron2(nn.Module):
         encoder_dim = 512 if num_speakers > 1 else 512
         proj_speaker_dim = 80 if num_speakers > 1 else 0
         # embedding layer
-        self.embedding = nn.Embedding(num_chars, 512)
+        self.embedding = nn.Embedding(num_chars, 512, padding_idx=0)
         std = sqrt(2.0 / (num_chars + 512))
         val = sqrt(3.0) * std  # uniform bounds for std
         self.embedding.weight.data.uniform_(-val, val)
diff --git a/utils/text/cleaners.py b/utils/text/cleaners.py
index 581633a2..962b3c31 100644
--- a/utils/text/cleaners.py
+++ b/utils/text/cleaners.py
@@ -63,6 +63,19 @@ def convert_to_ascii(text):
     return unidecode(text)
 
 
+def remove_aux_symbols(text):
+    text = re.sub(r'[\<\>\(\)\[\]\"\']+', '', text)
+    return text
+
+
+def replace_symbols(text):
+    text = text.replace(';', ',')
+    text = text.replace('-', ' ')
+    text = text.replace(':', ' ')
+    text = text.replace('&', 'and')
+    return text
+
+
 def basic_cleaners(text):
     '''Basic pipeline that lowercases and collapses whitespace without transliteration.'''
     text = lowercase(text)
@@ -84,6 +97,8 @@ def english_cleaners(text):
     text = lowercase(text)
     text = expand_numbers(text)
     text = expand_abbreviations(text)
+    text = replace_symbols(text)
+    text = remove_aux_symbols(text)
     text = collapse_whitespace(text)
     return text
 
@@ -93,5 +108,7 @@ def phoneme_cleaners(text):
     text = convert_to_ascii(text)
     text = expand_numbers(text)
     text = expand_abbreviations(text)
+    text = replace_symbols(text)
+    text = remove_aux_symbols(text)
     text = collapse_whitespace(text)
     return text

From aef12f0c216e1fa69506634c1b7b06c119dc1751 Mon Sep 17 00:00:00 2001
From: erogol
Date: Thu, 20 Feb 2020 15:52:38 +0100
Subject: [PATCH 02/11] fix test iteration

---
 train.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/train.py b/train.py
index 7bfb8751..4cf366e3 100644
--- a/train.py
+++ b/train.py
@@ -493,7 +493,12 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
                     use_cuda,
                     ap,
                     speaker_id=speaker_id,
-                    style_wav=style_wav)
+                    style_wav=style_wav,
+                    truncated=False,
+                    enable_eos_bos_chars=c.enable_eos_bos_chars,  #pylint: disable=unused-argument
+                    use_griffin_lim=True,
+                    do_trim_silence=False)
+
                 file_path = os.path.join(AUDIO_PATH, str(global_step))
                 os.makedirs(file_path, exist_ok=True)
                 file_path = os.path.join(file_path,
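Note on PATCH 01: padding_idx=0 reserves embedding row 0 for the pad token and freezes it, so padded positions stop leaking gradient into a trainable vector. One caveat: the normal_() / uniform_() init calls that follow overwrite the zero vector that padding_idx sets up, so the pad row ends up non-zero (though still excluded from gradient updates). A minimal sketch to verify the behavior; the sizes below are illustrative, not taken from the models:

    import torch
    import torch.nn as nn

    emb = nn.Embedding(10, 4, padding_idx=0)
    emb.weight.data.normal_(0, 0.3)        # caveat: this also re-randomizes the pad row
    with torch.no_grad():
        emb.weight[0].zero_()              # restore the zero vector for index 0

    out = emb(torch.tensor([[0, 3, 5]]))   # batch containing a pad token
    out.sum().backward()
    print(emb.weight.grad[0])              # all zeros: the pad row receives no gradient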
From bc6764a5c7c58463d2a879f54ca0cc390ca69070 Mon Sep 17 00:00:00 2001
From: erogol
Date: Fri, 21 Feb 2020 14:57:10 +0100
Subject: [PATCH 03/11] bug fix in server

---
 server/synthesizer.py |  4 ++--
 utils/synthesis.py    | 18 ++++++++++++++++++
 2 files changed, 20 insertions(+), 2 deletions(-)

diff --git a/server/synthesizer.py b/server/synthesizer.py
index 347bef21..afc083aa 100644
--- a/server/synthesizer.py
+++ b/server/synthesizer.py
@@ -94,8 +94,8 @@ class Synthesizer(object):
                 sample_rate=self.ap.sample_rate,
             ).cuda()
 
-            check = torch.load(model_file)
-            self.wavernn.load_state_dict(check['model'], map_location="cpu")
+            check = torch.load(model_file, map_location="cpu")
+            self.wavernn.load_state_dict(check['model'])
             if use_cuda:
                 self.wavernn.cuda()
             self.wavernn.eval()
diff --git a/utils/synthesis.py b/utils/synthesis.py
index 79a17c78..b4512dc6 100644
--- a/utils/synthesis.py
+++ b/utils/synthesis.py
@@ -69,6 +69,24 @@ def id_to_torch(speaker_id):
     return speaker_id
 
 
+# TODO: perform GL with pytorch for batching
+def apply_griffin_lim(inputs, input_lens, CONFIG, ap):
+    '''Apply Griffin-Lim to each sample, iterating through the first dimension.
+    Args:
+        inputs (Tensor or np.Array): Features to be converted by GL. First dimension is the batch size.
+        input_lens (Tensor or np.Array): 1D array of sample lengths.
+        CONFIG (Dict): TTS config.
+        ap (AudioProcessor): TTS audio processor.
+    '''
+    wavs = []
+    for idx, spec in enumerate(inputs):
+        wav_len = (input_lens[idx] * ap.hop_length) - ap.hop_length  # inverse librosa padding
+        wav = inv_spectrogram(spec, ap, CONFIG)
+        # assert len(wav) == wav_len, f" [!] wav length: {len(wav)} vs expected: {wav_len}"
+        wavs.append(wav[:wav_len])
+    return wavs
+
+
 def synthesis(model,
               text,
               CONFIG,

From 6abae742b7b8ea6d0bdcf2ba827a3b3aa0dace76 Mon Sep 17 00:00:00 2001
From: erogol
Date: Sun, 23 Feb 2020 03:04:27 +0100
Subject: [PATCH 04/11] updating cleaners

---
 utils/text/cleaners.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/utils/text/cleaners.py b/utils/text/cleaners.py
index 962b3c31..e6b611b4 100644
--- a/utils/text/cleaners.py
+++ b/utils/text/cleaners.py
@@ -64,7 +64,7 @@ def convert_to_ascii(text):
 
 
 def remove_aux_symbols(text):
-    text = re.sub(r'[\<\>\(\)\[\]\"\']+', '', text)
+    text = re.sub(r'[\<\>\(\)\[\]\"]+', '', text)
     return text
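Note on PATCH 03: the old code passed map_location to load_state_dict, which does not accept that argument; it belongs to torch.load, where it controls where the checkpoint tensors are deserialized. A sketch of the corrected pattern (the path, model, and use_cuda names here are placeholders for the synthesizer state above):

    import torch

    checkpoint = torch.load("wavernn_checkpoint.pth", map_location="cpu")  # deserialize on CPU first
    model.load_state_dict(checkpoint["model"])                             # then restore the weights
    if use_cuda:
        model.cuda()
    model.eval()

The new apply_griffin_lim helper also trims each inverted waveform to (input_lens[idx] * hop_length) - hop_length samples, undoing the padding librosa adds around the STFT frames.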
From 28eb3abfd6b19648eec95588d940f0b9f66cf1d1 Mon Sep 17 00:00:00 2001
From: erogol
Date: Wed, 12 Feb 2020 23:54:33 +0100
Subject: [PATCH 05/11] setting stft parameters with constants

---
 config.json    | 12 +++++-------
 utils/audio.py |  9 ++++++++-
 2 files changed, 13 insertions(+), 8 deletions(-)

diff --git a/config.json b/config.json
index c1a8158d..094bb2c6 100644
--- a/config.json
+++ b/config.json
@@ -9,8 +9,8 @@
     "num_mels": 80,           // size of the mel spec frame.
     "num_freq": 1025,         // number of stft frequency levels. Size of the linear spectrogram frame.
     "sample_rate": 22050,     // DATASET-RELATED: wav sample-rate. If different than the original data, it is resampled.
-    "frame_length_ms": 50.0,  // stft window length in ms.
-    "frame_shift_ms": 12.5,   // stft window hop length in ms.
+    "win_length": 1024,       // stft window length in samples.
+    "hop_length": 256,        // stft window hop length in samples.
     "preemphasis": 0.98,      // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no pre-emphasis.
     "min_level_db": -100,     // normalization range
     "ref_level_db": 20,       // reference level db, theoretically 20db is the sound of air.
@@ -63,7 +63,7 @@
     "prenet_dropout": true,       // enable/disable dropout at prenet.

     // ATTENTION
-    "attention_type": "graves",   // 'original' or 'graves'
+    "attention_type": "original", // 'original' or 'graves'
     "attention_heads": 4,         // number of attention heads (only for 'graves')
     "attention_norm": "sigmoid",  // softmax or sigmoid. Suggested to use softmax for Tacotron2 and sigmoid for Tacotron.
     "windowing": false,           // Enables attention windowing. Used only in eval mode.
@@ -93,8 +93,7 @@
     "max_seq_len": 150,           // DATASET-RELATED: maximum text length

     // PATHS
-    // "output_path": "/data5/rw/pit/keep/",  // DATASET-RELATED: output path for all training outputs.
-    "output_path": "/home/erogol/Models/LJSpeech/",
+    "output_path": "/data4/rw/home/Trainings/",

     // PHONEMES
     "phoneme_cache_path": "mozilla_us_phonemes",  // phoneme computation is slow, therefore, it caches results in the given folder.
@@ -111,8 +110,7 @@
     [
         {
             "name": "ljspeech",
-            "path": "/home/erogol/Data/LJSpeech-1.1/",
-            // "path": "/home/erogol/Data/LJSpeech-1.1",
+            "path": "/root/LJSpeech-1.1/",
             "meta_file_train": "metadata.csv",
             "meta_file_val": null
         }
diff --git a/utils/audio.py b/utils/audio.py
index 7b2c4834..771e6a43 100644
--- a/utils/audio.py
+++ b/utils/audio.py
@@ -12,6 +12,8 @@ class AudioProcessor(object):
                  min_level_db=None,
                  frame_shift_ms=None,
                  frame_length_ms=None,
+                 hop_length=None,
+                 win_length=None,
                  ref_level_db=None,
                  num_freq=None,
                  power=None,
@@ -49,7 +51,12 @@ class AudioProcessor(object):
         self.do_trim_silence = do_trim_silence
         self.trim_db = trim_db
         self.sound_norm = sound_norm
-        self.n_fft, self.hop_length, self.win_length = self._stft_parameters()
+        if hop_length is None:
+            self.n_fft, self.hop_length, self.win_length = self._stft_parameters()
+        else:
+            self.hop_length = hop_length
+            self.win_length = win_length
+            self.n_fft = (self.num_freq - 1) * 2
         assert min_level_db != 0.0, " [!] min_level_db is 0"
         members = vars(self)
         for key, value in members.items():

From 79a36843703ebcdf448d20bf5ef1e44fec483a57 Mon Sep 17 00:00:00 2001
From: root
Date: Fri, 14 Feb 2020 18:00:34 +0100
Subject: [PATCH 06/11] config update

---
 config.json | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/config.json b/config.json
index 094bb2c6..8653d92f 100644
--- a/config.json
+++ b/config.json
@@ -1,7 +1,7 @@
 {
     "model": "Tacotron2",  // one of the models in models/
-    "run_name": "ljspeech-gravesv2",
-    "run_description": "tacotron2 wuth graves attention",
+    "run_name": "ljspeech-stf_params",
+    "run_description": "tacotron2 cosntant stf parameters",

     // AUDIO PARAMETERS
     "audio":{
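Note on PATCH 05: the config moves from millisecond-based frame_length_ms / frame_shift_ms to fixed sample counts, and the new constants are deliberately powers of two rather than exact conversions of the old values. A sketch of the arithmetic at the configured 22050 Hz sample rate:

    sample_rate = 22050
    num_freq = 1025

    # old style: window sizes derived from milliseconds
    int(50.0 / 1000.0 * sample_rate)   # 1102 samples for a 50 ms window
    int(12.5 / 1000.0 * sample_rate)   # 275 samples for a 12.5 ms hop

    # new style: fixed power-of-two constants (~46.4 ms window, ~11.6 ms hop)
    win_length = 1024
    hop_length = 256
    n_fft = (num_freq - 1) * 2         # 2048, matching the AudioProcessor fallback

Power-of-two FFT sizes are generally faster, and pinning the sample counts keeps the extracted features identical regardless of rounding in the ms-to-sample conversion.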
"loss_masking": true, // enable / disable loss masking against the sequence padding. "grad_accum": 2, // if N > 1, enable gradient accumulation for N iterations. It is useful for low memory GPUs. @@ -93,10 +93,10 @@ "max_seq_len": 150, // DATASET-RELATED: maximum text length // PATHS - "output_path": "/data4/rw/home/Trainings/", + "output_path": "/home/erogol/Models/", // PHONEMES - "phoneme_cache_path": "mozilla_us_phonemes", // phoneme computation is slow, therefore, it caches results in the given folder. + "phoneme_cache_path": "mozilla_us_phonemes_2_1", // phoneme computation is slow, therefore, it caches results in the given folder. "use_phonemes": true, // use phonemes instead of raw characters. It is suggested for better pronounciation. "phoneme_language": "en-us", // depending on your target language, pick one from https://github.com/bootphon/phonemizer#languages @@ -110,7 +110,7 @@ [ { "name": "ljspeech", - "path": "/root/LJSpeech-1.1/", + "path": "/home/erogol/Data/LJSpeech-1.1/", "meta_file_train": "metadata.csv", "meta_file_val": null } diff --git a/notebooks/Benchmark-PWGAN.ipynb b/notebooks/Benchmark-PWGAN.ipynb index 430d329f..f17f71f4 100644 --- a/notebooks/Benchmark-PWGAN.ipynb +++ b/notebooks/Benchmark-PWGAN.ipynb @@ -85,7 +85,10 @@ " if use_cuda:\n", " waveform = waveform.cpu()\n", " waveform = waveform.numpy()\n", - " print(\" > Run-time: {}\".format(time.time() - t_1))\n", + " rtf = (time.time() - t_1) / (len(waveform) / ap.sample_rate)\n", + " print(waveform.shape)\n", + " print(\" > Run-time: {}\".format(time.time() - t_1))\n", + " print(\" > Real-time factor: {}\".format(rtf))\n", " if figures: \n", " visualize(alignment, mel_postnet_spec, stop_tokens, text, ap.hop_length, CONFIG, ap._denormalize(mel_spec)) \n", " IPython.display.display(Audio(waveform, rate=CONFIG.audio['sample_rate'], normalize=False)) \n", From e6504cc9a44fb4d9587b6387c6aa28fa40296f18 Mon Sep 17 00:00:00 2001 From: erogol Date: Thu, 20 Feb 2020 12:48:45 +0100 Subject: [PATCH 08/11] config update, check arguments update and enable alternative arguments --- config.json | 7 +++---- utils/generic_utils.py | 8 +++++--- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/config.json b/config.json index a7ed04a3..a113836f 100644 --- a/config.json +++ b/config.json @@ -1,6 +1,6 @@ { "model": "Tacotron2", // one of the model in models/ - "run_name": "ljspeech-stf_params", + "run_name": "ljspeech-stft_params", "run_description": "tacotron2 cosntant stf parameters", // AUDIO PARAMETERS @@ -36,12 +36,11 @@ "reinit_layers": [], // give a list of layer names to restore from the given checkpoint. If not defined, it reloads all heuristically matching layers. // TRAINING - "batch_size": 2, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'. + "batch_size": 32, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'. "eval_batch_size":16, "r": 7, // Number of decoder frames to predict per iteration. Set the initial values if gradual training is enabled. - "gradual_training": [[0, 7, 64], [2000, 5, 64], [35000, 3, 32], [70000, 2, 32], [140000, 1, 32]], //set gradual training steps [first_step, r, batch_size]. If it is null, gradual training is disabled. For Tacotron, you might need to reduce the 'batch_size' as you proceeed. 
+ "gradual_training": [[0, 7, 64], [1, 5, 64], [50000, 3, 32], [130000, 2, 32], [290000, 1, 32]], //set gradual training steps [first_step, r, batch_size]. If it is null, gradual training is disabled. For Tacotron, you might need to reduce the 'batch_size' as you proceeed. "loss_masking": true, // enable / disable loss masking against the sequence padding. - "grad_accum": 2, // if N > 1, enable gradient accumulation for N iterations. It is useful for low memory GPUs. // VALIDATION "run_eval": true, diff --git a/utils/generic_utils.py b/utils/generic_utils.py index a8de5bbb..d728eeb9 100644 --- a/utils/generic_utils.py +++ b/utils/generic_utils.py @@ -391,7 +391,9 @@ class KeepAverage(): self.update_value(key, value) -def _check_argument(name, c, enum_list=None, max_val=None, min_val=None, restricted=False, val_type=None): +def _check_argument(name, c, enum_list=None, max_val=None, min_val=None, restricted=False, val_type=None, alternative=None): + if alternative in c.keys() and c[alternative] is not None: + return if restricted: assert name in c.keys(), f' [!] {name} not defined in config.json' if name in c.keys(): @@ -417,8 +419,8 @@ def check_config(c): _check_argument('num_mels', c['audio'], restricted=True, val_type=int, min_val=10, max_val=2056) _check_argument('num_freq', c['audio'], restricted=True, val_type=int, min_val=128, max_val=4058) _check_argument('sample_rate', c['audio'], restricted=True, val_type=int, min_val=512, max_val=100000) - _check_argument('frame_length_ms', c['audio'], restricted=True, val_type=float, min_val=10, max_val=1000) - _check_argument('frame_shift_ms', c['audio'], restricted=True, val_type=float, min_val=1, max_val=1000) + _check_argument('frame_length_ms', c['audio'], restricted=True, val_type=float, min_val=10, max_val=1000, alternative='win_length') + _check_argument('frame_shift_ms', c['audio'], restricted=True, val_type=float, min_val=1, max_val=1000, alternative='hop_length') _check_argument('preemphasis', c['audio'], restricted=True, val_type=float, min_val=0, max_val=1) _check_argument('min_level_db', c['audio'], restricted=True, val_type=int, min_val=-1000, max_val=10) _check_argument('ref_level_db', c['audio'], restricted=True, val_type=int, min_val=0, max_val=1000) From a5d919cec8ddc99874b783d077bb9d273efbcf3c Mon Sep 17 00:00:00 2001 From: erogol Date: Tue, 25 Feb 2020 11:52:26 +0100 Subject: [PATCH 09/11] use 0 padding for stop tokens, which seems not working after padding_idx --- utils/data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/data.py b/utils/data.py index 87343ec1..a8b87cb5 100644 --- a/utils/data.py +++ b/utils/data.py @@ -14,7 +14,7 @@ def prepare_data(inputs): def _pad_tensor(x, length): - _pad = 0 + _pad = 0. 
From a5d919cec8ddc99874b783d077bb9d273efbcf3c Mon Sep 17 00:00:00 2001
From: erogol
Date: Tue, 25 Feb 2020 11:52:26 +0100
Subject: [PATCH 09/11] use 0 padding for stop tokens, which seems not to work after padding_idx

---
 utils/data.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/utils/data.py b/utils/data.py
index 87343ec1..a8b87cb5 100644
--- a/utils/data.py
+++ b/utils/data.py
@@ -14,7 +14,7 @@ def prepare_data(inputs):
 
 
 def _pad_tensor(x, length):
-    _pad = 0
+    _pad = 0.
     assert x.ndim == 2
     x = np.pad(
         x, [[0, 0], [0, length - x.shape[1]]],

From a68012aec2dc5fac2a0f05d1aa1ba5abac06d5b2 Mon Sep 17 00:00:00 2001
From: erogol
Date: Tue, 25 Feb 2020 14:17:20 +0100
Subject: [PATCH 10/11] BCE masked loss and padding stop_tokens with 0s not 1s

---
 layers/losses.py | 29 +++++++++++++++++++++++++++++
 train.py         |  8 ++++----
 utils/data.py    |  2 +-
 3 files changed, 34 insertions(+), 5 deletions(-)

diff --git a/layers/losses.py b/layers/losses.py
index 176e2f09..7e5671b2 100644
--- a/layers/losses.py
+++ b/layers/losses.py
@@ -96,3 +96,32 @@ class AttentionEntropyLoss(nn.Module):
         entropy = torch.distributions.Categorical(probs=align).entropy()
         loss = (entropy / np.log(align.shape[1])).mean()
         return loss
+
+
+class BCELossMasked(nn.Module):
+
+    def __init__(self, pos_weight):
+        super(BCELossMasked, self).__init__()
+        self.pos_weight = pos_weight
+
+    def forward(self, x, target, length):
+        """
+        Args:
+            x: A Variable containing a FloatTensor of size
+                (batch, max_len) which contains the
+                unnormalized probability for each class.
+            target: A Variable containing a LongTensor of size
+                (batch, max_len) which contains the index of the true
+                class for each corresponding step.
+            length: A Variable containing a LongTensor of size (batch,)
+                which contains the length of each data in a batch.
+        Returns:
+            loss: An average loss value in range [0, 1] masked by the length.
+        """
+        # mask: (batch, max_len, 1)
+        target.requires_grad = False
+        mask = sequence_mask(sequence_length=length, max_len=target.size(1)).float()
+        loss = functional.binary_cross_entropy_with_logits(
+            x * mask, target * mask, pos_weight=self.pos_weight, reduction='sum')
+        loss = loss / mask.sum()
+        return loss
diff --git a/train.py b/train.py
index 4cf366e3..1397b310 100644
--- a/train.py
+++ b/train.py
@@ -13,7 +13,7 @@ from torch.utils.data import DataLoader
 from TTS.datasets.TTSDataset import MyDataset
 from distribute import (DistributedSampler, apply_gradient_allreduce,
                         init_distributed, reduce_tensor)
-from TTS.layers.losses import L1LossMasked, MSELossMasked
+from TTS.layers.losses import L1LossMasked, MSELossMasked, BCELossMasked
 from TTS.utils.audio import AudioProcessor
 from TTS.utils.generic_utils import (
     NoamLR, check_update, count_parameters, create_experiment_folder,
@@ -167,7 +167,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
 
         # loss computation
         stop_loss = criterion_st(stop_tokens,
-                                 stop_targets) if c.stopnet else torch.zeros(1)
+                                 stop_targets, mel_lengths) if c.stopnet else torch.zeros(1)
         if c.loss_masking:
             decoder_loss = criterion(decoder_output, mel_input, mel_lengths)
             if c.model in ["Tacotron", "TacotronGST"]:
@@ -365,7 +365,7 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
 
             # loss computation
             stop_loss = criterion_st(
-                stop_tokens, stop_targets) if c.stopnet else torch.zeros(1)
+                stop_tokens, stop_targets, mel_lengths) if c.stopnet else torch.zeros(1)
             if c.loss_masking:
                 decoder_loss = criterion(decoder_output, mel_input,
                                          mel_lengths)
@@ -571,7 +571,7 @@ def main(args):  # pylint: disable=redefined-outer-name
     else:
         criterion = nn.L1Loss() if c.model in ["Tacotron", "TacotronGST"
                                                ] else nn.MSELoss()
-    criterion_st = nn.BCEWithLogitsLoss(
+    criterion_st = BCELossMasked(
         pos_weight=torch.tensor(10)) if c.stopnet else None
 
     if args.restore_path:
diff --git a/utils/data.py b/utils/data.py
index a8b87cb5..f2d7538a 100644
--- a/utils/data.py
+++ b/utils/data.py
@@ -31,7 +31,7 @@ def prepare_tensor(inputs, out_steps):
 
 
 def _pad_stop_target(x, length):
-    _pad = 1.
+    _pad = 0.
     assert x.ndim == 1
     return np.pad(
         x, (0, length - x.shape[0]), mode='constant', constant_values=_pad)
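Note on PATCHES 09-10: stop targets are now padded with 0 instead of 1, and BCELossMasked drops the padded frames from the stop-token loss using the true mel lengths, so padded steps no longer get labeled as "stop". A usage sketch with illustrative shapes (import path as in the train.py hunk above):

    import torch
    from TTS.layers.losses import BCELossMasked

    criterion_st = BCELossMasked(pos_weight=torch.tensor(10.))
    logits = torch.randn(4, 50)                    # (batch, max_len) stopnet outputs
    lengths = torch.tensor([50, 45, 40, 35])       # true frame count per sample
    targets = torch.zeros(4, 50)                   # 0-padded stop targets
    for i, l in enumerate(lengths):
        targets[i, l - 1] = 1.                     # mark the final valid frame
    loss = criterion_st(logits, targets, lengths)  # padded steps are masked out

One subtlety: since the implementation zeroes the logits rather than excluding them, each masked position still adds log(2) to the summed loss before normalization; this is a constant offset, not a gradient source, because the zeroed mask also zeroes the gradient at those positions.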
From c825b63d4c6ee51ccf93c85af3bd26a6dd66fa33 Mon Sep 17 00:00:00 2001
From: erogol
Date: Mon, 2 Mar 2020 12:53:55 +0100
Subject: [PATCH 11/11] better version control for phonemizer

---
 datasets/preprocess.py | 2 +-
 utils/text/__init__.py | 5 +++--
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/datasets/preprocess.py b/datasets/preprocess.py
index 64efc665..029922d3 100644
--- a/datasets/preprocess.py
+++ b/datasets/preprocess.py
@@ -84,7 +84,7 @@ def mozilla_de(root_path, meta_file):
         for line in ttf:
             cols = line.strip().split('|')
             wav_file = cols[0].strip()
-            text = cols[1].strip()
+            text = cols[1].strip()
             folder_name = f"BATCH_{wav_file.split('_')[0]}_FINAL"
             wav_file = os.path.join(root_path, folder_name, wav_file)
             items.append([text, wav_file, speaker_name])
diff --git a/utils/text/__init__.py b/utils/text/__init__.py
index 0e6684d2..0fb47952 100644
--- a/utils/text/__init__.py
+++ b/utils/text/__init__.py
@@ -1,6 +1,7 @@
 # -*- coding: utf-8 -*-
 
 import re
+from packaging import version
 import phonemizer
 from phonemizer.phonemize import phonemize
 from TTS.utils.text import cleaners
@@ -28,7 +29,7 @@ def text2phone(text, language):
     seperator = phonemizer.separator.Separator(' |', '', '|')
     #try:
     punctuations = re.findall(PHONEME_PUNCTUATION_PATTERN, text)
-    if float(phonemizer.__version__) < 2.1:
+    if version.parse(phonemizer.__version__) < version.parse('2.1'):
         ph = phonemize(text, separator=seperator, strip=False, njobs=1, backend='espeak', language=language)
         ph = ph[:-1].strip()  # skip the last empty character
         # phonemizer does not tackle punctuations. Here we do.
@@ -42,7 +43,7 @@ def text2phone(text, language):
     else:
         for punct in punctuations:
             ph = ph.replace('| |\n', '|'+punct+'| |', 1)
-    elif float(phonemizer.__version__) == 2.1:
+    elif version.parse(phonemizer.__version__) >= version.parse('2.1'):
         ph = phonemize(text, separator=seperator, strip=False, njobs=1, backend='espeak', language=language, preserve_punctuation=True)
         # this is a simple fix for phonemizer.
         # https://github.com/bootphon/phonemizer/issues/32
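Note on PATCH 11: float(phonemizer.__version__) breaks as soon as the version string has more than one dot (e.g. "2.1.1"), which is what this patch guards against; packaging.version compares release segments properly. A quick illustration:

    from packaging import version

    float("2.1")                                    # 2.1 -- happens to work
    # float("2.1.1")                                # ValueError: could not convert string to float
    version.parse("2.1.1") >= version.parse("2.1")  # True
    version.parse("2.10") > version.parse("2.9")    # True; float() would order these wrongly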