Fix Punctuation

This commit is contained in:
Eren Gölge 2021-11-19 10:39:21 +01:00
parent ff7c385838
commit d8bdeb8b8f
3 changed files with 32 additions and 29 deletions

View File

@ -74,36 +74,36 @@ phonemes = create_phonemes(_phonemes, _punctuations, _pad, _eos, _bos, _blank)
class BaseCharacters:
"""🐸BaseCharacters class
Every new character class should inherit from this.
Every new character class should inherit from this.
Characters are oredered as follows ```[PAD, EOS, BOS, BLANK, CHARACTERS, PUNCTUATIONS]```.
Characters are oredered as follows ```[PAD, EOS, BOS, BLANK, CHARACTERS, PUNCTUATIONS]```.
If you need a custom order, you need to define inherit from this class and override the ```_create_vocab``` method.
If you need a custom order, you need to define inherit from this class and override the ```_create_vocab``` method.
Args:
characters (str):
Main set of characters to be used in the vocabulary.
Args:
characters (str):
Main set of characters to be used in the vocabulary.
punctuations (str):
Characters to be treated as punctuation.
punctuations (str):
Characters to be treated as punctuation.
pad (str):
Special padding character that would be ignored by the model.
pad (str):
Special padding character that would be ignored by the model.
eos (str):
End of the sentence character.
eos (str):
End of the sentence character.
bos (str):
Beginning of the sentence character.
bos (str):
Beginning of the sentence character.
blank (str):
Optional character used between characters by some models for better prosody.
blank (str):
Optional character used between characters by some models for better prosody.
is_unique (bool):
Remove duplicates from the provided characters. Defaults to True.
el
is_sorted (bool):
Sort the characters in alphabetical order. Only applies to `self.characters`. Defaults to True.
is_unique (bool):
Remove duplicates from the provided characters. Defaults to True.
el
is_sorted (bool):
Sort the characters in alphabetical order. Only applies to `self.characters`. Defaults to True.
"""
def __init__(
@ -196,10 +196,10 @@ el
if self.is_sorted:
_vocab = sorted(_vocab)
_vocab = list(_vocab)
_vocab = [self._blank] + _vocab if self._blank is not None and len(self._blank) > 0 else _vocab
_vocab = [self._bos] + _vocab if self._bos is not None and len(self._bos) > 0 else _vocab
_vocab = [self._eos] + _vocab if self._eos is not None and len(self._eos) > 0 else _vocab
_vocab = [self._pad] + _vocab if self._pad is not None and len(self._pad) > 0 else _vocab
_vocab = [self._blank] + _vocab if self._blank is not None and len(self._blank) > 0 else _vocab
_vocab = [self._bos] + _vocab if self._bos is not None and len(self._bos) > 0 else _vocab
_vocab = [self._eos] + _vocab if self._eos is not None and len(self._eos) > 0 else _vocab
_vocab = [self._pad] + _vocab if self._pad is not None and len(self._pad) > 0 else _vocab
self._vocab = _vocab + list(self._punctuations)
self._char_to_id = {char: idx for idx, char in enumerate(self.vocab)}
self._id_to_char = {idx: char for idx, char in enumerate(self.vocab)}
@ -214,7 +214,7 @@ el
def id_to_char(self, idx: int) -> str:
return self._id_to_char[idx]
def print_log(self, level:int=0):
def print_log(self, level: int = 0):
"""
Prints the vocabulary in a nice format.
"""

View File

@ -91,10 +91,13 @@ class Punctuation:
puncs.append(_PUNC_IDX(match.group(), position))
# convert str text to a List[str], each item is separated by a punctuation
splitted_text = []
for punc in puncs:
for idx, punc in enumerate(puncs):
split = text.split(punc.punc)
prefix, suffix = split[0], punc.punc.join(split[1:])
splitted_text.append(prefix)
# if the text does not end with a punctuation, add it to the last item
if idx == len(puncs) - 1 and len(suffix) > 0:
splitted_text.append(suffix)
text = suffix
return splitted_text, puncs
@ -126,7 +129,7 @@ class Punctuation:
current = puncs[0]
if current.position == PuncPosition.BEGIN:
return cls._restore([current.mark + text[0]] + text[1:], puncs[1:], num)
return cls._restore([current.punc + text[0]] + text[1:], puncs[1:], num)
if current.position == PuncPosition.END:
return [text[0] + current.punc] + cls._restore(text[1:], puncs[1:], num + 1)

View File

@ -1,8 +1,8 @@
from typing import Callable, Dict, List, Union
from TTS.tts.utils.text import cleaners
from TTS.tts.utils.text.phonemizers import DEF_LANG_TO_PHONEMIZER, get_phonemizer_by_name
from TTS.tts.utils.text.characters import Graphemes, IPAPhonemes
from TTS.tts.utils.text.phonemizers import DEF_LANG_TO_PHONEMIZER, get_phonemizer_by_name
class TTSTokenizer: