mirror of https://github.com/coqui-ai/TTS.git
fix: correctly strip/restore initial punctuation (#3336)
* refactor(punctuation): remove orphan code for handling lone punctuation

  The case of lone punctuation is already handled at the top of restore(). The removed `if` statement would never be reached and would in fact raise an AttributeError, because the `_punc_index` named tuple doesn't have the attribute `mark`.

* refactor(punctuation): remove unused argument

* fix(punctuation): correctly handle initial punctuation

  Stripping and restoring initial punctuation didn't work correctly because the string splitting caused an additional empty string to be inserted in the text list (because `".A".split(".")` => `["", "A"]`). Now, an initial empty string is skipped and relevant test cases are added.

Fixes #3333
This commit is contained in:
parent
93283385e0
commit
39321d02be
|
@ -15,7 +15,6 @@ class PuncPosition(Enum):
|
|||
BEGIN = 0
|
||||
END = 1
|
||||
MIDDLE = 2
|
||||
ALONE = 3
|
||||
|
||||
|
||||
class Punctuation:
|
||||
|
@ -92,7 +91,7 @@ class Punctuation:
|
|||
return [text], []
|
||||
# the text is only punctuations
|
||||
if len(matches) == 1 and matches[0].group() == text:
|
||||
return [], [_PUNC_IDX(text, PuncPosition.ALONE)]
|
||||
return [], [_PUNC_IDX(text, PuncPosition.BEGIN)]
|
||||
# build a punctuation map to be used later to restore punctuations
|
||||
puncs = []
|
||||
for match in matches:
|
||||
|
@ -107,11 +106,14 @@ class Punctuation:
|
|||
for idx, punc in enumerate(puncs):
|
||||
split = text.split(punc.punc)
|
||||
prefix, suffix = split[0], punc.punc.join(split[1:])
|
||||
text = suffix
|
||||
if prefix == "":
|
||||
# We don't want to insert an empty string in case of initial punctuation
|
||||
continue
|
||||
splitted_text.append(prefix)
|
||||
# if the text does not end with a punctuation, add it to the last item
|
||||
if idx == len(puncs) - 1 and len(suffix) > 0:
|
||||
splitted_text.append(suffix)
|
||||
text = suffix
|
||||
return splitted_text, puncs
|
||||
|
||||
@classmethod
|
||||
|
@ -127,10 +129,10 @@ class Punctuation:
|
|||
['This is', 'example'], ['.', '!'] -> "This is. example!"
|
||||
|
||||
"""
|
||||
return cls._restore(text, puncs, 0)
|
||||
return cls._restore(text, puncs)
|
||||
|
||||
@classmethod
|
||||
def _restore(cls, text, puncs, num): # pylint: disable=too-many-return-statements
|
||||
def _restore(cls, text, puncs): # pylint: disable=too-many-return-statements
|
||||
"""Auxiliary method for Punctuation.restore()"""
|
||||
if not puncs:
|
||||
return text
|
||||
|
@ -142,21 +144,18 @@ class Punctuation:
|
|||
current = puncs[0]
|
||||
|
||||
if current.position == PuncPosition.BEGIN:
|
||||
return cls._restore([current.punc + text[0]] + text[1:], puncs[1:], num)
|
||||
return cls._restore([current.punc + text[0]] + text[1:], puncs[1:])
|
||||
|
||||
if current.position == PuncPosition.END:
|
||||
return [text[0] + current.punc] + cls._restore(text[1:], puncs[1:], num + 1)
|
||||
|
||||
if current.position == PuncPosition.ALONE:
|
||||
return [current.mark] + cls._restore(text, puncs[1:], num + 1)
|
||||
return [text[0] + current.punc] + cls._restore(text[1:], puncs[1:])
|
||||
|
||||
# POSITION == MIDDLE
|
||||
if len(text) == 1: # pragma: nocover
|
||||
# a corner case where the final part of an intermediate
|
||||
# mark (I) has not been phonemized
|
||||
return cls._restore([text[0] + current.punc], puncs[1:], num)
|
||||
return cls._restore([text[0] + current.punc], puncs[1:])
|
||||
|
||||
return cls._restore([text[0] + current.punc + text[1]] + text[2:], puncs[1:], num)
|
||||
return cls._restore([text[0] + current.punc + text[1]] + text[2:], puncs[1:])
|
||||
|
||||
|
||||
# if __name__ == "__main__":
|
||||
|
|
|
@ -11,6 +11,11 @@ class PunctuationTest(unittest.TestCase):
|
|||
("This, is my text ... to be striped !! from text", "This is my text to be striped from text"),
|
||||
("This, is my text ... to be striped from text?", "This is my text to be striped from text"),
|
||||
("This, is my text to be striped from text", "This is my text to be striped from text"),
|
||||
(".", ""),
|
||||
(" . ", ""),
|
||||
("!!! Attention !!!", "Attention"),
|
||||
("!!! Attention !!! This is just a ... test.", "Attention This is just a test"),
|
||||
("!!! Attention! This is just a ... test.", "Attention This is just a test"),
|
||||
]
|
||||
|
||||
def test_get_set_puncs(self):
|
||||
|
|
Loading…
Reference in New Issue