diff --git a/TTS/tts/utils/text/punctuation.py b/TTS/tts/utils/text/punctuation.py index 36e1f194..414ac253 100644 --- a/TTS/tts/utils/text/punctuation.py +++ b/TTS/tts/utils/text/punctuation.py @@ -19,12 +19,25 @@ class PuncPosition(Enum): class Punctuation: - """Handle punctuations characters in text. + """Handle punctuations in text. Just strip punctuations from text or strip and restore them later. Args: puncs (str): The punctuations to be processed. Defaults to `_DEF_PUNCS`. + + Example: + >>> punc = Punctuation() + >>> punc.strip("This is. example !") + 'This is example' + + >>> text_striped, punc_map = punc.strip_to_restore("This is. example !") + >>> ' '.join(text_striped) + 'This is example' + + >>> text_restored = punc.restore(text_striped, punc_map) + >>> text_restored[0] + 'This is. example !' """ def __init__(self, puncs: str = _DEF_PUNCS): @@ -43,7 +56,7 @@ class Punctuation: def puncs(self, value): if not isinstance(value, six.string_types): raise ValueError("[!] Punctuations must be of type str.") - self._puncs = "".join(set(value)) + self._puncs = "".join(list(dict.fromkeys(list(value)))) # remove duplicates without changing the oreder self.puncs_regular_exp = re.compile(fr"(\s*[{re.escape(self._puncs)}]+\s*)+") def strip(self, text): @@ -56,7 +69,7 @@ class Punctuation: "This is. example !" -> "This is example " """ - return re.sub(self.puncs_regular_exp, " ", text).strip() + return re.sub(self.puncs_regular_exp, " ", text).rstrip().lstrip() def strip_to_restore(self, text): """Remove punctuations from text to restore them later. diff --git a/tests/text_tests/test_punctuation.py b/tests/text_tests/test_punctuation.py new file mode 100644 index 00000000..f349bc50 --- /dev/null +++ b/tests/text_tests/test_punctuation.py @@ -0,0 +1,30 @@ +import unittest +from TTS.tts.utils.text.punctuation import Punctuation, _DEF_PUNCS + +class PunctuationTest(unittest.TestCase): + def setUp(self): + self.punctuation = Punctuation() + self.test_texts = [("This, is my text ... to be striped !! from text?", "This is my text to be striped from text"), + ("This, is my text ... to be striped !! from text", "This is my text to be striped from text"), + ("This, is my text ... to be striped from text?", "This is my text to be striped from text"), + ("This, is my text to be striped from text", "This is my text to be striped from text") + ] + + def test_get_set_puncs(self): + self.punctuation.puncs = "-=" + self.assertEqual(self.punctuation.puncs, "-=") + + self.punctuation.puncs = _DEF_PUNCS + self.assertEqual(self.punctuation.puncs, _DEF_PUNCS) + + def test_strip_punc(self): + for text, gt in self.test_texts: + text_striped = self.punctuation.strip(text) + self.assertEqual(text_striped, gt) + + def test_strip_restore(self): + for text, gt in self.test_texts: + text_striped, puncs_map = self.punctuation.strip_to_restore(text) + text_restored = self.punctuation.restore(text_striped, puncs_map) + self.assertEqual(' '.join(text_striped), gt) + self.assertEqual(text_restored[0], text)