mirror of https://github.com/coqui-ai/TTS.git
Fix pylint issues
This commit is contained in:
parent
ac9416fb86
commit
3fbbebd74d
|
@ -82,8 +82,8 @@ class VitsConfig(BaseTTSConfig):
|
||||||
add_blank (bool):
|
add_blank (bool):
|
||||||
If true, a blank token is added in between every character. Defaults to `True`.
|
If true, a blank token is added in between every character. Defaults to `True`.
|
||||||
|
|
||||||
test_sentences (List[str]):
|
test_sentences (List[List]):
|
||||||
List of sentences to be used for testing.
|
List of sentences with speaker and language information to be used for testing.
|
||||||
|
|
||||||
Note:
|
Note:
|
||||||
Check :class:`TTS.tts.configs.shared_configs.BaseTTSConfig` for the inherited parameters.
|
Check :class:`TTS.tts.configs.shared_configs.BaseTTSConfig` for the inherited parameters.
|
||||||
|
|
|
@ -740,7 +740,7 @@ class Vits(BaseTTS):
|
||||||
test_audios["{}-audio".format(idx)] = wav
|
test_audios["{}-audio".format(idx)] = wav
|
||||||
test_figures["{}-alignment".format(idx)] = plot_alignment(alignment.T, output_fig=False)
|
test_figures["{}-alignment".format(idx)] = plot_alignment(alignment.T, output_fig=False)
|
||||||
except: # pylint: disable=bare-except
|
except: # pylint: disable=bare-except
|
||||||
print(" !! Error creating Test Sentence -", idx)
|
print(" !! Error creating Test Sentence -", idx)
|
||||||
return test_figures, test_audios
|
return test_figures, test_audios
|
||||||
|
|
||||||
def get_optimizer(self) -> List:
|
def get_optimizer(self) -> List:
|
||||||
|
@ -837,5 +837,3 @@ class Vits(BaseTTS):
|
||||||
if eval:
|
if eval:
|
||||||
self.eval()
|
self.eval()
|
||||||
assert not self.training
|
assert not self.training
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -20,7 +20,6 @@ class LanguageManager:
|
||||||
>>> manager = LanguageManager(language_id_file_path=language_id_file_path)
|
>>> manager = LanguageManager(language_id_file_path=language_id_file_path)
|
||||||
>>> language_id_mapper = manager.language_ids
|
>>> language_id_mapper = manager.language_ids
|
||||||
"""
|
"""
|
||||||
num_languages: int = 0
|
|
||||||
language_id_mapping: Dict = {}
|
language_id_mapping: Dict = {}
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -77,7 +76,6 @@ class LanguageManager:
|
||||||
file_path (str): Path to the target json file.
|
file_path (str): Path to the target json file.
|
||||||
"""
|
"""
|
||||||
self.language_id_mapping = self._load_json(file_path)
|
self.language_id_mapping = self._load_json(file_path)
|
||||||
self.num_languages = len(self.language_id_mapping)
|
|
||||||
|
|
||||||
def save_language_ids_to_file(self, file_path: str) -> None:
|
def save_language_ids_to_file(self, file_path: str) -> None:
|
||||||
"""Save language IDs to a json file.
|
"""Save language IDs to a json file.
|
||||||
|
@ -99,7 +97,7 @@ def _set_file_path(path):
|
||||||
return path_continue
|
return path_continue
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_language_manager(c: Coqpit, data: List = None, restore_path: str = None, out_path: str = None) -> LanguageManager:
|
def get_language_manager(c: Coqpit, data: List = None, restore_path: str = None) -> LanguageManager:
|
||||||
"""Initiate a `LanguageManager` instance by the provided config.
|
"""Initiate a `LanguageManager` instance by the provided config.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -135,4 +133,4 @@ def get_language_weighted_sampler(items: list):
|
||||||
language_count = np.array([len(np.where(language_names == l)[0]) for l in unique_language_names])
|
language_count = np.array([len(np.where(language_names == l)[0]) for l in unique_language_names])
|
||||||
weight_language = 1. / language_count
|
weight_language = 1. / language_count
|
||||||
dataset_samples_weight = torch.from_numpy(np.array([weight_language[l] for l in language_ids])).double()
|
dataset_samples_weight = torch.from_numpy(np.array([weight_language[l] for l in language_ids])).double()
|
||||||
return WeightedRandomSampler(dataset_samples_weight, len(dataset_samples_weight))
|
return WeightedRandomSampler(dataset_samples_weight, len(dataset_samples_weight))
|
||||||
|
|
|
@ -142,4 +142,4 @@ def multilingual_cleaners(text):
|
||||||
text = replace_symbols(text, lang=None)
|
text = replace_symbols(text, lang=None)
|
||||||
text = remove_aux_symbols(text)
|
text = remove_aux_symbols(text)
|
||||||
text = collapse_whitespace(text)
|
text = collapse_whitespace(text)
|
||||||
return text
|
return text
|
||||||
|
|
|
@ -38,6 +38,11 @@ class TestTTSDataset(unittest.TestCase):
|
||||||
|
|
||||||
def _create_dataloader(self, batch_size, r, bgs):
|
def _create_dataloader(self, batch_size, r, bgs):
|
||||||
items = ljspeech(c.data_path, "metadata.csv")
|
items = ljspeech(c.data_path, "metadata.csv")
|
||||||
|
|
||||||
|
# add a default language because now the TTSDataset expect a language
|
||||||
|
language = ""
|
||||||
|
items = [[*item, language] for item in items]
|
||||||
|
|
||||||
dataset = TTSDataset(
|
dataset = TTSDataset(
|
||||||
r,
|
r,
|
||||||
c.text_cleaner,
|
c.text_cleaner,
|
||||||
|
|
Loading…
Reference in New Issue