From 39321d02befe17ad49194c0d42f8020d8fb8a856 Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Thu, 30 Nov 2023 13:03:16 +0100 Subject: [PATCH] fix: correctly strip/restore initial punctuation (#3336) * refactor(punctuation): remove orphan code for handling lone punctuation The case of lone punctuation is already handled at the top of restore(). The removed if statement would never be called and would in fact raise an AttributeError because the _punc_index named tuple doesn't have the attribute `mark`. * refactor(punctuation): remove unused argument * fix(punctuation): correctly handle initial punctuation Stripping and restoring initial punctuation didn't work correctly because the string-splitting caused an additional empty string to be inserted in the text list (because `".A".split(".")` => `["", "A"]`). Now, an initial empty string is skipped and relevant test cases are added. Fixes #3333 --- TTS/tts/utils/text/punctuation.py | 23 +++++++++++------------ tests/text_tests/test_punctuation.py | 5 +++++ 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/TTS/tts/utils/text/punctuation.py b/TTS/tts/utils/text/punctuation.py index 8d199cc5..36c467d0 100644 --- a/TTS/tts/utils/text/punctuation.py +++ b/TTS/tts/utils/text/punctuation.py @@ -15,7 +15,6 @@ class PuncPosition(Enum): BEGIN = 0 END = 1 MIDDLE = 2 - ALONE = 3 class Punctuation: @@ -92,7 +91,7 @@ class Punctuation: return [text], [] # the text is only punctuations if len(matches) == 1 and matches[0].group() == text: - return [], [_PUNC_IDX(text, PuncPosition.ALONE)] + return [], [_PUNC_IDX(text, PuncPosition.BEGIN)] # build a punctuation map to be used later to restore punctuations puncs = [] for match in matches: @@ -107,11 +106,14 @@ class Punctuation: for idx, punc in enumerate(puncs): split = text.split(punc.punc) prefix, suffix = split[0], punc.punc.join(split[1:]) + text = suffix + if prefix == "": + # We don't want to insert an empty string in case of initial punctuation + continue splitted_text.append(prefix) # if the text does not end with a punctuation, add it to the last item if idx == len(puncs) - 1 and len(suffix) > 0: splitted_text.append(suffix) - text = suffix return splitted_text, puncs @classmethod @@ -127,10 +129,10 @@ class Punctuation: ['This is', 'example'], ['.', '!'] -> "This is. example!" """ - return cls._restore(text, puncs, 0) + return cls._restore(text, puncs) @classmethod - def _restore(cls, text, puncs, num): # pylint: disable=too-many-return-statements + def _restore(cls, text, puncs): # pylint: disable=too-many-return-statements """Auxiliary method for Punctuation.restore()""" if not puncs: return text @@ -142,21 +144,18 @@ class Punctuation: current = puncs[0] if current.position == PuncPosition.BEGIN: - return cls._restore([current.punc + text[0]] + text[1:], puncs[1:], num) + return cls._restore([current.punc + text[0]] + text[1:], puncs[1:]) if current.position == PuncPosition.END: - return [text[0] + current.punc] + cls._restore(text[1:], puncs[1:], num + 1) - - if current.position == PuncPosition.ALONE: - return [current.mark] + cls._restore(text, puncs[1:], num + 1) + return [text[0] + current.punc] + cls._restore(text[1:], puncs[1:]) # POSITION == MIDDLE if len(text) == 1: # pragma: nocover # a corner case where the final part of an intermediate # mark (I) has not been phonemized - return cls._restore([text[0] + current.punc], puncs[1:], num) + return cls._restore([text[0] + current.punc], puncs[1:]) - return cls._restore([text[0] + current.punc + text[1]] + text[2:], puncs[1:], num) + return cls._restore([text[0] + current.punc + text[1]] + text[2:], puncs[1:]) # if __name__ == "__main__": diff --git a/tests/text_tests/test_punctuation.py b/tests/text_tests/test_punctuation.py index 141c10e4..bb7b11ed 100644 --- a/tests/text_tests/test_punctuation.py +++ b/tests/text_tests/test_punctuation.py @@ -11,6 +11,11 @@ class PunctuationTest(unittest.TestCase): ("This, is my text ... to be striped !! from text", "This is my text to be striped from text"), ("This, is my text ... to be striped from text?", "This is my text to be striped from text"), ("This, is my text to be striped from text", "This is my text to be striped from text"), + (".", ""), + (" . ", ""), + ("!!! Attention !!!", "Attention"), + ("!!! Attention !!! This is just a ... test.", "Attention This is just a test"), + ("!!! Attention! This is just a ... test.", "Attention This is just a test"), ] def test_get_set_puncs(self):