rename MyDataset -> TTSDataset

This commit is contained in:
Eren Gölge 2021-05-20 18:22:52 +02:00
parent d245b5d48f
commit 0f284841d1
9 changed files with 20 additions and 20 deletions

View File

@@ -8,7 +8,7 @@ import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from TTS.tts.datasets.TTSDataset import MyDataset
from TTS.tts.datasets.TTSDataset import TTSDataset
from TTS.tts.utils.generic_utils import setup_model
from TTS.tts.utils.io import load_checkpoint
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
@@ -83,7 +83,7 @@ Example run:
preprocessor = importlib.import_module("TTS.tts.datasets.preprocess")
preprocessor = getattr(preprocessor, args.dataset)
meta_data = preprocessor(args.data_path, args.dataset_metafile)
dataset = MyDataset(
dataset = TTSDataset(
model.decoder.r,
C.text_cleaner,
compute_linear_spec=False,

View File

@@ -11,7 +11,7 @@ from tqdm import tqdm
from TTS.config import load_config
from TTS.tts.datasets.preprocess import load_meta_data
from TTS.tts.datasets.TTSDataset import MyDataset
from TTS.tts.datasets.TTSDataset import TTSDataset
from TTS.tts.utils.generic_utils import setup_model
from TTS.tts.utils.speakers import parse_speakers
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
@@ -22,7 +22,7 @@ use_cuda = torch.cuda.is_available()
def setup_loader(ap, r, verbose=False):
dataset = MyDataset(
dataset = TTSDataset(
r,
c.text_cleaner,
compute_linear_spec=False,

View File

@@ -14,7 +14,7 @@ from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from TTS.tts.datasets.preprocess import load_meta_data
from TTS.tts.datasets.TTSDataset import MyDataset
from TTS.tts.datasets.TTSDataset import TTSDataset
from TTS.tts.layers.losses import AlignTTSLoss
from TTS.tts.utils.generic_utils import setup_model
from TTS.tts.utils.io import save_best_model, save_checkpoint
@@ -38,7 +38,7 @@ def setup_loader(ap, r, is_val=False, verbose=False):
if is_val and not config.run_eval:
loader = None
else:
dataset = MyDataset(
dataset = TTSDataset(
r,
config.text_cleaner,
compute_linear_spec=False,

View File

@@ -15,7 +15,7 @@ from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from TTS.tts.datasets.preprocess import load_meta_data
from TTS.tts.datasets.TTSDataset import MyDataset
from TTS.tts.datasets.TTSDataset import TTSDataset
from TTS.tts.layers.losses import GlowTTSLoss
from TTS.tts.utils.generic_utils import setup_model
from TTS.tts.utils.io import save_best_model, save_checkpoint
@@ -38,7 +38,7 @@ def setup_loader(ap, r, is_val=False, verbose=False):
if is_val and not config.run_eval:
loader = None
else:
dataset = MyDataset(
dataset = TTSDataset(
r,
config.text_cleaner,
compute_linear_spec=False,

View File

@@ -16,7 +16,7 @@ from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from TTS.tts.datasets.preprocess import load_meta_data
from TTS.tts.datasets.TTSDataset import MyDataset
from TTS.tts.datasets.TTSDataset import TTSDataset
from TTS.tts.layers.losses import SpeedySpeechLoss
from TTS.tts.utils.generic_utils import setup_model
from TTS.tts.utils.io import save_best_model, save_checkpoint
@@ -39,7 +39,7 @@ def setup_loader(ap, r, is_val=False, verbose=False):
if is_val and not config.run_eval:
loader = None
else:
dataset = MyDataset(
dataset = TTSDataset(
r,
config.text_cleaner,
compute_linear_spec=False,

View File

@@ -12,7 +12,7 @@ import torch
from torch.utils.data import DataLoader
from TTS.tts.datasets.preprocess import load_meta_data
from TTS.tts.datasets.TTSDataset import MyDataset
from TTS.tts.datasets.TTSDataset import TTSDataset
from TTS.tts.layers.losses import TacotronLoss
from TTS.tts.utils.generic_utils import setup_model
from TTS.tts.utils.io import save_best_model, save_checkpoint
@@ -43,7 +43,7 @@ def setup_loader(ap, r, is_val=False, verbose=False, dataset=None):
loader = None
else:
if dataset is None:
dataset = MyDataset(
dataset = TTSDataset(
r,
config.text_cleaner,
compute_linear_spec=config.model.lower() == "tacotron",

View File

@@ -12,7 +12,7 @@ from TTS.tts.utils.data import prepare_data, prepare_stop_target, prepare_tensor
from TTS.tts.utils.text import pad_with_eos_bos, phoneme_to_sequence, text_to_sequence
class MyDataset(Dataset):
class TTSDataset(Dataset):
def __init__(
self,
outputs_per_step,
@@ -117,12 +117,12 @@ class MyDataset(Dataset):
try:
phonemes = np.load(cache_path)
except FileNotFoundError:
phonemes = MyDataset._generate_and_cache_phoneme_sequence(
phonemes = TTSDataset._generate_and_cache_phoneme_sequence(
text, cache_path, cleaners, language, tp, add_blank
)
except (ValueError, IOError):
print(" [!] failed loading phonemes for {}. " "Recomputing.".format(wav_file))
phonemes = MyDataset._generate_and_cache_phoneme_sequence(
phonemes = TTSDataset._generate_and_cache_phoneme_sequence(
text, cache_path, cleaners, language, tp, add_blank
)
if enable_eos_bos:
@@ -190,7 +190,7 @@ class MyDataset(Dataset):
item = args[0]
func_args = args[1]
text, wav_file, *_ = item
phonemes = MyDataset._load_or_generate_phoneme_sequence(wav_file, text, *func_args)
phonemes = TTSDataset._load_or_generate_phoneme_sequence(wav_file, text, *func_args)
return phonemes
def compute_input_seq(self, num_workers=0):
@@ -225,7 +225,7 @@ class MyDataset(Dataset):
with Pool(num_workers) as p:
phonemes = list(
tqdm.tqdm(
p.imap(MyDataset._phoneme_worker, [[item, func_args] for item in self.items]),
p.imap(TTSDataset._phoneme_worker, [[item, func_args] for item in self.items]),
total=len(self.items),
)
)

View File

@@ -22,7 +22,7 @@
"import numpy as np\n",
"from tqdm import tqdm as tqdm\n",
"from torch.utils.data import DataLoader\n",
"from TTS.tts.datasets.TTSDataset import MyDataset\n",
"from TTS.tts.datasets.TTSDataset import TTSDataset\n",
"from TTS.tts.layers.losses import L1LossMasked\n",
"from TTS.utils.audio import AudioProcessor\n",
"from TTS.utils.io import load_config\n",
@@ -112,7 +112,7 @@
"preprocessor = importlib.import_module('TTS.tts.datasets.preprocess')\n",
"preprocessor = getattr(preprocessor, DATASET.lower())\n",
"meta_data = preprocessor(DATA_PATH,METADATA_FILE)\n",
"dataset = MyDataset(checkpoint['r'], C.text_cleaner, False, ap, meta_data,tp=C.characters if 'characters' in C.keys() else None, use_phonemes=C.use_phonemes, phoneme_cache_path=C.phoneme_cache_path, enable_eos_bos=C.enable_eos_bos_chars)\n",
"dataset = TTSDataset(checkpoint['r'], C.text_cleaner, False, ap, meta_data,tp=C.characters if 'characters' in C.keys() else None, use_phonemes=C.use_phonemes, phoneme_cache_path=C.phoneme_cache_path, enable_eos_bos=C.enable_eos_bos_chars)\n",
"loader = DataLoader(dataset, batch_size=BATCH_SIZE, num_workers=4, collate_fn=dataset.collate_fn, shuffle=False, drop_last=False)"
]
},

View File

@@ -38,7 +38,7 @@ class TestTTSDataset(unittest.TestCase):
def _create_dataloader(self, batch_size, r, bgs):
items = ljspeech(c.data_path, "metadata.csv")
dataset = TTSDataset.MyDataset(
dataset = TTSDataset.TTSDataset(
r,
c.text_cleaner,
compute_linear_spec=True,