diff --git a/TTS/bin/compute_attention_masks.py b/TTS/bin/compute_attention_masks.py
index 0a4337da..e14ff433 100644
--- a/TTS/bin/compute_attention_masks.py
+++ b/TTS/bin/compute_attention_masks.py
@@ -8,7 +8,7 @@ import torch
 from torch.utils.data import DataLoader
 from tqdm import tqdm
 
-from TTS.tts.datasets.TTSDataset import MyDataset
+from TTS.tts.datasets.TTSDataset import TTSDataset
 from TTS.tts.utils.generic_utils import setup_model
 from TTS.tts.utils.io import load_checkpoint
 from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
@@ -83,7 +83,7 @@ Example run:
     preprocessor = importlib.import_module("TTS.tts.datasets.preprocess")
     preprocessor = getattr(preprocessor, args.dataset)
     meta_data = preprocessor(args.data_path, args.dataset_metafile)
-    dataset = MyDataset(
+    dataset = TTSDataset(
         model.decoder.r,
         C.text_cleaner,
         compute_linear_spec=False,
diff --git a/TTS/bin/extract_tts_spectrograms.py b/TTS/bin/extract_tts_spectrograms.py
index ace7464a..e8814a11 100755
--- a/TTS/bin/extract_tts_spectrograms.py
+++ b/TTS/bin/extract_tts_spectrograms.py
@@ -11,7 +11,7 @@ from tqdm import tqdm
 
 from TTS.config import load_config
 from TTS.tts.datasets.preprocess import load_meta_data
-from TTS.tts.datasets.TTSDataset import MyDataset
+from TTS.tts.datasets.TTSDataset import TTSDataset
 from TTS.tts.utils.generic_utils import setup_model
 from TTS.tts.utils.speakers import parse_speakers
 from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
@@ -22,7 +22,7 @@ use_cuda = torch.cuda.is_available()
 
 
 def setup_loader(ap, r, verbose=False):
-    dataset = MyDataset(
+    dataset = TTSDataset(
         r,
         c.text_cleaner,
         compute_linear_spec=False,
diff --git a/TTS/bin/train_align_tts.py b/TTS/bin/train_align_tts.py
index 7e3921b0..f5658dd2 100644
--- a/TTS/bin/train_align_tts.py
+++ b/TTS/bin/train_align_tts.py
@@ -14,7 +14,7 @@ from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
 
 from TTS.tts.datasets.preprocess import load_meta_data
-from TTS.tts.datasets.TTSDataset import MyDataset
+from TTS.tts.datasets.TTSDataset import TTSDataset
 from TTS.tts.layers.losses import AlignTTSLoss
 from TTS.tts.utils.generic_utils import setup_model
 from TTS.tts.utils.io import save_best_model, save_checkpoint
@@ -38,7 +38,7 @@ def setup_loader(ap, r, is_val=False, verbose=False):
     if is_val and not config.run_eval:
         loader = None
     else:
-        dataset = MyDataset(
+        dataset = TTSDataset(
             r,
             config.text_cleaner,
             compute_linear_spec=False,
diff --git a/TTS/bin/train_glow_tts.py b/TTS/bin/train_glow_tts.py
index e93a4e8a..50e95a2b 100644
--- a/TTS/bin/train_glow_tts.py
+++ b/TTS/bin/train_glow_tts.py
@@ -15,7 +15,7 @@ from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
 
 from TTS.tts.datasets.preprocess import load_meta_data
-from TTS.tts.datasets.TTSDataset import MyDataset
+from TTS.tts.datasets.TTSDataset import TTSDataset
 from TTS.tts.layers.losses import GlowTTSLoss
 from TTS.tts.utils.generic_utils import setup_model
 from TTS.tts.utils.io import save_best_model, save_checkpoint
@@ -38,7 +38,7 @@ def setup_loader(ap, r, is_val=False, verbose=False):
     if is_val and not config.run_eval:
         loader = None
     else:
-        dataset = MyDataset(
+        dataset = TTSDataset(
             r,
             config.text_cleaner,
             compute_linear_spec=False,
diff --git a/TTS/bin/train_speedy_speech.py b/TTS/bin/train_speedy_speech.py
index 2fba3df1..4ab0c899 100644
--- a/TTS/bin/train_speedy_speech.py
+++ b/TTS/bin/train_speedy_speech.py
@@ -16,7 +16,7 @@ from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
 
 from TTS.tts.datasets.preprocess import load_meta_data
-from TTS.tts.datasets.TTSDataset import MyDataset
+from TTS.tts.datasets.TTSDataset import TTSDataset
 from TTS.tts.layers.losses import SpeedySpeechLoss
 from TTS.tts.utils.generic_utils import setup_model
 from TTS.tts.utils.io import save_best_model, save_checkpoint
@@ -39,7 +39,7 @@ def setup_loader(ap, r, is_val=False, verbose=False):
     if is_val and not config.run_eval:
         loader = None
     else:
-        dataset = MyDataset(
+        dataset = TTSDataset(
             r,
             config.text_cleaner,
             compute_linear_spec=False,
diff --git a/TTS/bin/train_tacotron.py b/TTS/bin/train_tacotron.py
index 9685d0d7..098a8d3f 100755
--- a/TTS/bin/train_tacotron.py
+++ b/TTS/bin/train_tacotron.py
@@ -12,7 +12,7 @@ import torch
 from torch.utils.data import DataLoader
 
 from TTS.tts.datasets.preprocess import load_meta_data
-from TTS.tts.datasets.TTSDataset import MyDataset
+from TTS.tts.datasets.TTSDataset import TTSDataset
 from TTS.tts.layers.losses import TacotronLoss
 from TTS.tts.utils.generic_utils import setup_model
 from TTS.tts.utils.io import save_best_model, save_checkpoint
@@ -43,7 +43,7 @@ def setup_loader(ap, r, is_val=False, verbose=False, dataset=None):
         loader = None
     else:
         if dataset is None:
-            dataset = MyDataset(
+            dataset = TTSDataset(
                 r,
                 config.text_cleaner,
                 compute_linear_spec=config.model.lower() == "tacotron",
diff --git a/TTS/tts/datasets/TTSDataset.py b/TTS/tts/datasets/TTSDataset.py
index 4ca93232..cbb0a593 100644
--- a/TTS/tts/datasets/TTSDataset.py
+++ b/TTS/tts/datasets/TTSDataset.py
@@ -12,7 +12,7 @@ from TTS.tts.utils.data import prepare_data, prepare_stop_target, prepare_tensor
 from TTS.tts.utils.text import pad_with_eos_bos, phoneme_to_sequence, text_to_sequence
 
 
-class MyDataset(Dataset):
+class TTSDataset(Dataset):
     def __init__(
         self,
         outputs_per_step,
@@ -117,12 +117,12 @@ class MyDataset(Dataset):
         try:
             phonemes = np.load(cache_path)
         except FileNotFoundError:
-            phonemes = MyDataset._generate_and_cache_phoneme_sequence(
+            phonemes = TTSDataset._generate_and_cache_phoneme_sequence(
                 text, cache_path, cleaners, language, tp, add_blank
             )
         except (ValueError, IOError):
             print(" [!] failed loading phonemes for {}. " "Recomputing.".format(wav_file))
-            phonemes = MyDataset._generate_and_cache_phoneme_sequence(
+            phonemes = TTSDataset._generate_and_cache_phoneme_sequence(
                 text, cache_path, cleaners, language, tp, add_blank
             )
         if enable_eos_bos:
@@ -190,7 +190,7 @@ class MyDataset(Dataset):
         item = args[0]
         func_args = args[1]
         text, wav_file, *_ = item
-        phonemes = MyDataset._load_or_generate_phoneme_sequence(wav_file, text, *func_args)
+        phonemes = TTSDataset._load_or_generate_phoneme_sequence(wav_file, text, *func_args)
         return phonemes
 
     def compute_input_seq(self, num_workers=0):
@@ -225,7 +225,7 @@ class MyDataset(Dataset):
             with Pool(num_workers) as p:
                 phonemes = list(
                     tqdm.tqdm(
-                        p.imap(MyDataset._phoneme_worker, [[item, func_args] for item in self.items]),
+                        p.imap(TTSDataset._phoneme_worker, [[item, func_args] for item in self.items]),
                         total=len(self.items),
                     )
                 )
diff --git a/notebooks/ExtractTTSpectrogram.ipynb b/notebooks/ExtractTTSpectrogram.ipynb
index dc35e86f..bdc7c955 100644
--- a/notebooks/ExtractTTSpectrogram.ipynb
+++ b/notebooks/ExtractTTSpectrogram.ipynb
@@ -22,7 +22,7 @@
     "import numpy as np\n",
     "from tqdm import tqdm as tqdm\n",
     "from torch.utils.data import DataLoader\n",
-    "from TTS.tts.datasets.TTSDataset import MyDataset\n",
+    "from TTS.tts.datasets.TTSDataset import TTSDataset\n",
     "from TTS.tts.layers.losses import L1LossMasked\n",
     "from TTS.utils.audio import AudioProcessor\n",
     "from TTS.utils.io import load_config\n",
@@ -112,7 +112,7 @@
     "preprocessor = importlib.import_module('TTS.tts.datasets.preprocess')\n",
     "preprocessor = getattr(preprocessor, DATASET.lower())\n",
     "meta_data = preprocessor(DATA_PATH,METADATA_FILE)\n",
-    "dataset = MyDataset(checkpoint['r'], C.text_cleaner, False, ap, meta_data,tp=C.characters if 'characters' in C.keys() else None, use_phonemes=C.use_phonemes, phoneme_cache_path=C.phoneme_cache_path, enable_eos_bos=C.enable_eos_bos_chars)\n",
+    "dataset = TTSDataset(checkpoint['r'], C.text_cleaner, False, ap, meta_data,tp=C.characters if 'characters' in C.keys() else None, use_phonemes=C.use_phonemes, phoneme_cache_path=C.phoneme_cache_path, enable_eos_bos=C.enable_eos_bos_chars)\n",
     "loader = DataLoader(dataset, batch_size=BATCH_SIZE, num_workers=4, collate_fn=dataset.collate_fn, shuffle=False, drop_last=False)"
    ]
   },
diff --git a/tests/data_tests/test_loader.py b/tests/data_tests/test_loader.py
index e2dba37a..053da516 100644
--- a/tests/data_tests/test_loader.py
+++ b/tests/data_tests/test_loader.py
@@ -38,7 +38,7 @@ class TestTTSDataset(unittest.TestCase):
 
     def _create_dataloader(self, batch_size, r, bgs):
         items = ljspeech(c.data_path, "metadata.csv")
-        dataset = TTSDataset.MyDataset(
+        dataset = TTSDataset.TTSDataset(
             r,
             c.text_cleaner,
             compute_linear_spec=True,