rename MyDataset -> TTSDataset

This commit is contained in:
Eren Gölge 2021-05-20 18:22:52 +02:00
parent d245b5d48f
commit 0f284841d1
9 changed files with 20 additions and 20 deletions

View File

@@ -8,7 +8,7 @@ import torch
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from tqdm import tqdm from tqdm import tqdm
from TTS.tts.datasets.TTSDataset import MyDataset from TTS.tts.datasets.TTSDataset import TTSDataset
from TTS.tts.utils.generic_utils import setup_model from TTS.tts.utils.generic_utils import setup_model
from TTS.tts.utils.io import load_checkpoint from TTS.tts.utils.io import load_checkpoint
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
@@ -83,7 +83,7 @@ Example run:
preprocessor = importlib.import_module("TTS.tts.datasets.preprocess") preprocessor = importlib.import_module("TTS.tts.datasets.preprocess")
preprocessor = getattr(preprocessor, args.dataset) preprocessor = getattr(preprocessor, args.dataset)
meta_data = preprocessor(args.data_path, args.dataset_metafile) meta_data = preprocessor(args.data_path, args.dataset_metafile)
dataset = MyDataset( dataset = TTSDataset(
model.decoder.r, model.decoder.r,
C.text_cleaner, C.text_cleaner,
compute_linear_spec=False, compute_linear_spec=False,

View File

@@ -11,7 +11,7 @@ from tqdm import tqdm
from TTS.config import load_config from TTS.config import load_config
from TTS.tts.datasets.preprocess import load_meta_data from TTS.tts.datasets.preprocess import load_meta_data
from TTS.tts.datasets.TTSDataset import MyDataset from TTS.tts.datasets.TTSDataset import TTSDataset
from TTS.tts.utils.generic_utils import setup_model from TTS.tts.utils.generic_utils import setup_model
from TTS.tts.utils.speakers import parse_speakers from TTS.tts.utils.speakers import parse_speakers
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
@@ -22,7 +22,7 @@ use_cuda = torch.cuda.is_available()
def setup_loader(ap, r, verbose=False): def setup_loader(ap, r, verbose=False):
dataset = MyDataset( dataset = TTSDataset(
r, r,
c.text_cleaner, c.text_cleaner,
compute_linear_spec=False, compute_linear_spec=False,

View File

@@ -14,7 +14,7 @@ from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
from TTS.tts.datasets.preprocess import load_meta_data from TTS.tts.datasets.preprocess import load_meta_data
from TTS.tts.datasets.TTSDataset import MyDataset from TTS.tts.datasets.TTSDataset import TTSDataset
from TTS.tts.layers.losses import AlignTTSLoss from TTS.tts.layers.losses import AlignTTSLoss
from TTS.tts.utils.generic_utils import setup_model from TTS.tts.utils.generic_utils import setup_model
from TTS.tts.utils.io import save_best_model, save_checkpoint from TTS.tts.utils.io import save_best_model, save_checkpoint
@@ -38,7 +38,7 @@ def setup_loader(ap, r, is_val=False, verbose=False):
if is_val and not config.run_eval: if is_val and not config.run_eval:
loader = None loader = None
else: else:
dataset = MyDataset( dataset = TTSDataset(
r, r,
config.text_cleaner, config.text_cleaner,
compute_linear_spec=False, compute_linear_spec=False,

View File

@@ -15,7 +15,7 @@ from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
from TTS.tts.datasets.preprocess import load_meta_data from TTS.tts.datasets.preprocess import load_meta_data
from TTS.tts.datasets.TTSDataset import MyDataset from TTS.tts.datasets.TTSDataset import TTSDataset
from TTS.tts.layers.losses import GlowTTSLoss from TTS.tts.layers.losses import GlowTTSLoss
from TTS.tts.utils.generic_utils import setup_model from TTS.tts.utils.generic_utils import setup_model
from TTS.tts.utils.io import save_best_model, save_checkpoint from TTS.tts.utils.io import save_best_model, save_checkpoint
@@ -38,7 +38,7 @@ def setup_loader(ap, r, is_val=False, verbose=False):
if is_val and not config.run_eval: if is_val and not config.run_eval:
loader = None loader = None
else: else:
dataset = MyDataset( dataset = TTSDataset(
r, r,
config.text_cleaner, config.text_cleaner,
compute_linear_spec=False, compute_linear_spec=False,

View File

@@ -16,7 +16,7 @@ from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
from TTS.tts.datasets.preprocess import load_meta_data from TTS.tts.datasets.preprocess import load_meta_data
from TTS.tts.datasets.TTSDataset import MyDataset from TTS.tts.datasets.TTSDataset import TTSDataset
from TTS.tts.layers.losses import SpeedySpeechLoss from TTS.tts.layers.losses import SpeedySpeechLoss
from TTS.tts.utils.generic_utils import setup_model from TTS.tts.utils.generic_utils import setup_model
from TTS.tts.utils.io import save_best_model, save_checkpoint from TTS.tts.utils.io import save_best_model, save_checkpoint
@@ -39,7 +39,7 @@ def setup_loader(ap, r, is_val=False, verbose=False):
if is_val and not config.run_eval: if is_val and not config.run_eval:
loader = None loader = None
else: else:
dataset = MyDataset( dataset = TTSDataset(
r, r,
config.text_cleaner, config.text_cleaner,
compute_linear_spec=False, compute_linear_spec=False,

View File

@@ -12,7 +12,7 @@ import torch
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from TTS.tts.datasets.preprocess import load_meta_data from TTS.tts.datasets.preprocess import load_meta_data
from TTS.tts.datasets.TTSDataset import MyDataset from TTS.tts.datasets.TTSDataset import TTSDataset
from TTS.tts.layers.losses import TacotronLoss from TTS.tts.layers.losses import TacotronLoss
from TTS.tts.utils.generic_utils import setup_model from TTS.tts.utils.generic_utils import setup_model
from TTS.tts.utils.io import save_best_model, save_checkpoint from TTS.tts.utils.io import save_best_model, save_checkpoint
@@ -43,7 +43,7 @@ def setup_loader(ap, r, is_val=False, verbose=False, dataset=None):
loader = None loader = None
else: else:
if dataset is None: if dataset is None:
dataset = MyDataset( dataset = TTSDataset(
r, r,
config.text_cleaner, config.text_cleaner,
compute_linear_spec=config.model.lower() == "tacotron", compute_linear_spec=config.model.lower() == "tacotron",

View File

@@ -12,7 +12,7 @@ from TTS.tts.utils.data import prepare_data, prepare_stop_target, prepare_tensor
from TTS.tts.utils.text import pad_with_eos_bos, phoneme_to_sequence, text_to_sequence from TTS.tts.utils.text import pad_with_eos_bos, phoneme_to_sequence, text_to_sequence
class MyDataset(Dataset): class TTSDataset(Dataset):
def __init__( def __init__(
self, self,
outputs_per_step, outputs_per_step,
@@ -117,12 +117,12 @@ class MyDataset(Dataset):
try: try:
phonemes = np.load(cache_path) phonemes = np.load(cache_path)
except FileNotFoundError: except FileNotFoundError:
phonemes = MyDataset._generate_and_cache_phoneme_sequence( phonemes = TTSDataset._generate_and_cache_phoneme_sequence(
text, cache_path, cleaners, language, tp, add_blank text, cache_path, cleaners, language, tp, add_blank
) )
except (ValueError, IOError): except (ValueError, IOError):
print(" [!] failed loading phonemes for {}. " "Recomputing.".format(wav_file)) print(" [!] failed loading phonemes for {}. " "Recomputing.".format(wav_file))
phonemes = MyDataset._generate_and_cache_phoneme_sequence( phonemes = TTSDataset._generate_and_cache_phoneme_sequence(
text, cache_path, cleaners, language, tp, add_blank text, cache_path, cleaners, language, tp, add_blank
) )
if enable_eos_bos: if enable_eos_bos:
@@ -190,7 +190,7 @@ class MyDataset(Dataset):
item = args[0] item = args[0]
func_args = args[1] func_args = args[1]
text, wav_file, *_ = item text, wav_file, *_ = item
phonemes = MyDataset._load_or_generate_phoneme_sequence(wav_file, text, *func_args) phonemes = TTSDataset._load_or_generate_phoneme_sequence(wav_file, text, *func_args)
return phonemes return phonemes
def compute_input_seq(self, num_workers=0): def compute_input_seq(self, num_workers=0):
@@ -225,7 +225,7 @@ class MyDataset(Dataset):
with Pool(num_workers) as p: with Pool(num_workers) as p:
phonemes = list( phonemes = list(
tqdm.tqdm( tqdm.tqdm(
p.imap(MyDataset._phoneme_worker, [[item, func_args] for item in self.items]), p.imap(TTSDataset._phoneme_worker, [[item, func_args] for item in self.items]),
total=len(self.items), total=len(self.items),
) )
) )

View File

@@ -22,7 +22,7 @@
"import numpy as np\n", "import numpy as np\n",
"from tqdm import tqdm as tqdm\n", "from tqdm import tqdm as tqdm\n",
"from torch.utils.data import DataLoader\n", "from torch.utils.data import DataLoader\n",
"from TTS.tts.datasets.TTSDataset import MyDataset\n", "from TTS.tts.datasets.TTSDataset import TTSDataset\n",
"from TTS.tts.layers.losses import L1LossMasked\n", "from TTS.tts.layers.losses import L1LossMasked\n",
"from TTS.utils.audio import AudioProcessor\n", "from TTS.utils.audio import AudioProcessor\n",
"from TTS.utils.io import load_config\n", "from TTS.utils.io import load_config\n",
@@ -112,7 +112,7 @@
"preprocessor = importlib.import_module('TTS.tts.datasets.preprocess')\n", "preprocessor = importlib.import_module('TTS.tts.datasets.preprocess')\n",
"preprocessor = getattr(preprocessor, DATASET.lower())\n", "preprocessor = getattr(preprocessor, DATASET.lower())\n",
"meta_data = preprocessor(DATA_PATH,METADATA_FILE)\n", "meta_data = preprocessor(DATA_PATH,METADATA_FILE)\n",
"dataset = MyDataset(checkpoint['r'], C.text_cleaner, False, ap, meta_data,tp=C.characters if 'characters' in C.keys() else None, use_phonemes=C.use_phonemes, phoneme_cache_path=C.phoneme_cache_path, enable_eos_bos=C.enable_eos_bos_chars)\n", "dataset = TTSDataset(checkpoint['r'], C.text_cleaner, False, ap, meta_data,tp=C.characters if 'characters' in C.keys() else None, use_phonemes=C.use_phonemes, phoneme_cache_path=C.phoneme_cache_path, enable_eos_bos=C.enable_eos_bos_chars)\n",
"loader = DataLoader(dataset, batch_size=BATCH_SIZE, num_workers=4, collate_fn=dataset.collate_fn, shuffle=False, drop_last=False)" "loader = DataLoader(dataset, batch_size=BATCH_SIZE, num_workers=4, collate_fn=dataset.collate_fn, shuffle=False, drop_last=False)"
] ]
}, },

View File

@@ -38,7 +38,7 @@ class TestTTSDataset(unittest.TestCase):
def _create_dataloader(self, batch_size, r, bgs): def _create_dataloader(self, batch_size, r, bgs):
items = ljspeech(c.data_path, "metadata.csv") items = ljspeech(c.data_path, "metadata.csv")
dataset = TTSDataset.MyDataset( dataset = TTSDataset.TTSDataset(
r, r,
c.text_cleaner, c.text_cleaner,
compute_linear_spec=True, compute_linear_spec=True,