mirror of https://github.com/coqui-ai/TTS.git
Test punctuations
This commit is contained in:
parent
d8bdeb8b8f
commit
79a84410f2
|
@ -19,12 +19,25 @@ class PuncPosition(Enum):
|
|||
|
||||
|
||||
class Punctuation:
|
||||
"""Handle punctuations characters in text.
|
||||
"""Handle punctuations in text.
|
||||
|
||||
Just strip punctuations from text or strip and restore them later.
|
||||
|
||||
Args:
|
||||
puncs (str): The punctuations to be processed. Defaults to `_DEF_PUNCS`.
|
||||
|
||||
Example:
|
||||
>>> punc = Punctuation()
|
||||
>>> punc.strip("This is. example !")
|
||||
'This is example'
|
||||
|
||||
>>> text_striped, punc_map = punc.strip_to_restore("This is. example !")
|
||||
>>> ' '.join(text_striped)
|
||||
'This is example'
|
||||
|
||||
>>> text_restored = punc.restore(text_striped, punc_map)
|
||||
>>> text_restored[0]
|
||||
'This is. example !'
|
||||
"""
|
||||
|
||||
def __init__(self, puncs: str = _DEF_PUNCS):
|
||||
|
@ -43,7 +56,7 @@ class Punctuation:
|
|||
def puncs(self, value):
    """Set the punctuation characters and rebuild the matching regex.

    Args:
        value (str): Characters to treat as punctuation. Repeated
            characters are dropped, keeping the first occurrence of each.

    Raises:
        ValueError: If ``value`` is not a string.
    """
    if not isinstance(value, str):
        raise ValueError("[!] Punctuations must be of type str.")
    # dict.fromkeys removes duplicates while keeping first-seen order; a set
    # would scramble the characters, so the stored string could differ from
    # the one the caller passed in.
    self._puncs = "".join(dict.fromkeys(value))
    # Matches one or more runs of punctuation, each with optional whitespace
    # on either side, so a whole cluster like " ... " is treated as one match.
    self.puncs_regular_exp = re.compile(fr"(\s*[{re.escape(self._puncs)}]+\s*)+")
|
||||
|
||||
def strip(self, text):
|
||||
|
@ -56,7 +69,7 @@ class Punctuation:
|
|||
|
||||
"This is. example !" -> "This is example"
|
||||
"""
|
||||
return re.sub(self.puncs_regular_exp, " ", text).strip()
|
||||
return re.sub(self.puncs_regular_exp, " ", text).rstrip().lstrip()
|
||||
|
||||
def strip_to_restore(self, text):
|
||||
"""Remove punctuations from text to restore them later.
|
||||
|
|
|
@ -0,0 +1,30 @@
|
|||
import unittest
|
||||
from TTS.tts.utils.text.punctuation import Punctuation, _DEF_PUNCS
|
||||
|
||||
class PunctuationTest(unittest.TestCase):
    """Unit tests for the ``Punctuation`` text helper."""

    def setUp(self):
        """Create a default ``Punctuation`` and the shared test table."""
        self.punctuation = Punctuation()
        # Each entry is (input text, expected text with punctuation stripped).
        self.test_texts = [
            ("This, is my text ... to be striped !! from text?", "This is my text to be striped from text"),
            ("This, is my text ... to be striped !! from text", "This is my text to be striped from text"),
            ("This, is my text ... to be striped from text?", "This is my text to be striped from text"),
            ("This, is my text to be striped from text", "This is my text to be striped from text"),
        ]

    def test_get_set_puncs(self):
        """The ``puncs`` property returns exactly the string assigned to it."""
        self.punctuation.puncs = "-="
        self.assertEqual(self.punctuation.puncs, "-=")

        self.punctuation.puncs = _DEF_PUNCS
        self.assertEqual(self.punctuation.puncs, _DEF_PUNCS)

    def test_strip_punc(self):
        """``strip`` yields the expected punctuation-free text for each case."""
        for text, gt in self.test_texts:
            # subTest reports every failing row with its input instead of
            # stopping at the first mismatch.
            with self.subTest(text=text):
                self.assertEqual(self.punctuation.strip(text), gt)

    def test_strip_restore(self):
        """``strip_to_restore`` followed by ``restore`` reproduces the input."""
        for text, gt in self.test_texts:
            with self.subTest(text=text):
                text_striped, puncs_map = self.punctuation.strip_to_restore(text)
                text_restored = self.punctuation.restore(text_striped, puncs_map)
                self.assertEqual(" ".join(text_striped), gt)
                self.assertEqual(text_restored[0], text)
|
Loading…
Reference in New Issue