mirror of https://github.com/coqui-ai/TTS.git
rename MyDataset -> TTSDataset
This commit is contained in:
parent
d245b5d48f
commit
0f284841d1
|
@ -8,7 +8,7 @@ import torch
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from TTS.tts.datasets.TTSDataset import MyDataset
|
from TTS.tts.datasets.TTSDataset import TTSDataset
|
||||||
from TTS.tts.utils.generic_utils import setup_model
|
from TTS.tts.utils.generic_utils import setup_model
|
||||||
from TTS.tts.utils.io import load_checkpoint
|
from TTS.tts.utils.io import load_checkpoint
|
||||||
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
|
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
|
||||||
|
@ -83,7 +83,7 @@ Example run:
|
||||||
preprocessor = importlib.import_module("TTS.tts.datasets.preprocess")
|
preprocessor = importlib.import_module("TTS.tts.datasets.preprocess")
|
||||||
preprocessor = getattr(preprocessor, args.dataset)
|
preprocessor = getattr(preprocessor, args.dataset)
|
||||||
meta_data = preprocessor(args.data_path, args.dataset_metafile)
|
meta_data = preprocessor(args.data_path, args.dataset_metafile)
|
||||||
dataset = MyDataset(
|
dataset = TTSDataset(
|
||||||
model.decoder.r,
|
model.decoder.r,
|
||||||
C.text_cleaner,
|
C.text_cleaner,
|
||||||
compute_linear_spec=False,
|
compute_linear_spec=False,
|
||||||
|
|
|
@ -11,7 +11,7 @@ from tqdm import tqdm
|
||||||
|
|
||||||
from TTS.config import load_config
|
from TTS.config import load_config
|
||||||
from TTS.tts.datasets.preprocess import load_meta_data
|
from TTS.tts.datasets.preprocess import load_meta_data
|
||||||
from TTS.tts.datasets.TTSDataset import MyDataset
|
from TTS.tts.datasets.TTSDataset import TTSDataset
|
||||||
from TTS.tts.utils.generic_utils import setup_model
|
from TTS.tts.utils.generic_utils import setup_model
|
||||||
from TTS.tts.utils.speakers import parse_speakers
|
from TTS.tts.utils.speakers import parse_speakers
|
||||||
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
|
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
|
||||||
|
@ -22,7 +22,7 @@ use_cuda = torch.cuda.is_available()
|
||||||
|
|
||||||
|
|
||||||
def setup_loader(ap, r, verbose=False):
|
def setup_loader(ap, r, verbose=False):
|
||||||
dataset = MyDataset(
|
dataset = TTSDataset(
|
||||||
r,
|
r,
|
||||||
c.text_cleaner,
|
c.text_cleaner,
|
||||||
compute_linear_spec=False,
|
compute_linear_spec=False,
|
||||||
|
|
|
@ -14,7 +14,7 @@ from torch.utils.data import DataLoader
|
||||||
from torch.utils.data.distributed import DistributedSampler
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
|
|
||||||
from TTS.tts.datasets.preprocess import load_meta_data
|
from TTS.tts.datasets.preprocess import load_meta_data
|
||||||
from TTS.tts.datasets.TTSDataset import MyDataset
|
from TTS.tts.datasets.TTSDataset import TTSDataset
|
||||||
from TTS.tts.layers.losses import AlignTTSLoss
|
from TTS.tts.layers.losses import AlignTTSLoss
|
||||||
from TTS.tts.utils.generic_utils import setup_model
|
from TTS.tts.utils.generic_utils import setup_model
|
||||||
from TTS.tts.utils.io import save_best_model, save_checkpoint
|
from TTS.tts.utils.io import save_best_model, save_checkpoint
|
||||||
|
@ -38,7 +38,7 @@ def setup_loader(ap, r, is_val=False, verbose=False):
|
||||||
if is_val and not config.run_eval:
|
if is_val and not config.run_eval:
|
||||||
loader = None
|
loader = None
|
||||||
else:
|
else:
|
||||||
dataset = MyDataset(
|
dataset = TTSDataset(
|
||||||
r,
|
r,
|
||||||
config.text_cleaner,
|
config.text_cleaner,
|
||||||
compute_linear_spec=False,
|
compute_linear_spec=False,
|
||||||
|
|
|
@ -15,7 +15,7 @@ from torch.utils.data import DataLoader
|
||||||
from torch.utils.data.distributed import DistributedSampler
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
|
|
||||||
from TTS.tts.datasets.preprocess import load_meta_data
|
from TTS.tts.datasets.preprocess import load_meta_data
|
||||||
from TTS.tts.datasets.TTSDataset import MyDataset
|
from TTS.tts.datasets.TTSDataset import TTSDataset
|
||||||
from TTS.tts.layers.losses import GlowTTSLoss
|
from TTS.tts.layers.losses import GlowTTSLoss
|
||||||
from TTS.tts.utils.generic_utils import setup_model
|
from TTS.tts.utils.generic_utils import setup_model
|
||||||
from TTS.tts.utils.io import save_best_model, save_checkpoint
|
from TTS.tts.utils.io import save_best_model, save_checkpoint
|
||||||
|
@ -38,7 +38,7 @@ def setup_loader(ap, r, is_val=False, verbose=False):
|
||||||
if is_val and not config.run_eval:
|
if is_val and not config.run_eval:
|
||||||
loader = None
|
loader = None
|
||||||
else:
|
else:
|
||||||
dataset = MyDataset(
|
dataset = TTSDataset(
|
||||||
r,
|
r,
|
||||||
config.text_cleaner,
|
config.text_cleaner,
|
||||||
compute_linear_spec=False,
|
compute_linear_spec=False,
|
||||||
|
|
|
@ -16,7 +16,7 @@ from torch.utils.data import DataLoader
|
||||||
from torch.utils.data.distributed import DistributedSampler
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
|
|
||||||
from TTS.tts.datasets.preprocess import load_meta_data
|
from TTS.tts.datasets.preprocess import load_meta_data
|
||||||
from TTS.tts.datasets.TTSDataset import MyDataset
|
from TTS.tts.datasets.TTSDataset import TTSDataset
|
||||||
from TTS.tts.layers.losses import SpeedySpeechLoss
|
from TTS.tts.layers.losses import SpeedySpeechLoss
|
||||||
from TTS.tts.utils.generic_utils import setup_model
|
from TTS.tts.utils.generic_utils import setup_model
|
||||||
from TTS.tts.utils.io import save_best_model, save_checkpoint
|
from TTS.tts.utils.io import save_best_model, save_checkpoint
|
||||||
|
@ -39,7 +39,7 @@ def setup_loader(ap, r, is_val=False, verbose=False):
|
||||||
if is_val and not config.run_eval:
|
if is_val and not config.run_eval:
|
||||||
loader = None
|
loader = None
|
||||||
else:
|
else:
|
||||||
dataset = MyDataset(
|
dataset = TTSDataset(
|
||||||
r,
|
r,
|
||||||
config.text_cleaner,
|
config.text_cleaner,
|
||||||
compute_linear_spec=False,
|
compute_linear_spec=False,
|
||||||
|
|
|
@ -12,7 +12,7 @@ import torch
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from TTS.tts.datasets.preprocess import load_meta_data
|
from TTS.tts.datasets.preprocess import load_meta_data
|
||||||
from TTS.tts.datasets.TTSDataset import MyDataset
|
from TTS.tts.datasets.TTSDataset import TTSDataset
|
||||||
from TTS.tts.layers.losses import TacotronLoss
|
from TTS.tts.layers.losses import TacotronLoss
|
||||||
from TTS.tts.utils.generic_utils import setup_model
|
from TTS.tts.utils.generic_utils import setup_model
|
||||||
from TTS.tts.utils.io import save_best_model, save_checkpoint
|
from TTS.tts.utils.io import save_best_model, save_checkpoint
|
||||||
|
@ -43,7 +43,7 @@ def setup_loader(ap, r, is_val=False, verbose=False, dataset=None):
|
||||||
loader = None
|
loader = None
|
||||||
else:
|
else:
|
||||||
if dataset is None:
|
if dataset is None:
|
||||||
dataset = MyDataset(
|
dataset = TTSDataset(
|
||||||
r,
|
r,
|
||||||
config.text_cleaner,
|
config.text_cleaner,
|
||||||
compute_linear_spec=config.model.lower() == "tacotron",
|
compute_linear_spec=config.model.lower() == "tacotron",
|
||||||
|
|
|
@ -12,7 +12,7 @@ from TTS.tts.utils.data import prepare_data, prepare_stop_target, prepare_tensor
|
||||||
from TTS.tts.utils.text import pad_with_eos_bos, phoneme_to_sequence, text_to_sequence
|
from TTS.tts.utils.text import pad_with_eos_bos, phoneme_to_sequence, text_to_sequence
|
||||||
|
|
||||||
|
|
||||||
class MyDataset(Dataset):
|
class TTSDataset(Dataset):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
outputs_per_step,
|
outputs_per_step,
|
||||||
|
@ -117,12 +117,12 @@ class MyDataset(Dataset):
|
||||||
try:
|
try:
|
||||||
phonemes = np.load(cache_path)
|
phonemes = np.load(cache_path)
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
phonemes = MyDataset._generate_and_cache_phoneme_sequence(
|
phonemes = TTSDataset._generate_and_cache_phoneme_sequence(
|
||||||
text, cache_path, cleaners, language, tp, add_blank
|
text, cache_path, cleaners, language, tp, add_blank
|
||||||
)
|
)
|
||||||
except (ValueError, IOError):
|
except (ValueError, IOError):
|
||||||
print(" [!] failed loading phonemes for {}. " "Recomputing.".format(wav_file))
|
print(" [!] failed loading phonemes for {}. " "Recomputing.".format(wav_file))
|
||||||
phonemes = MyDataset._generate_and_cache_phoneme_sequence(
|
phonemes = TTSDataset._generate_and_cache_phoneme_sequence(
|
||||||
text, cache_path, cleaners, language, tp, add_blank
|
text, cache_path, cleaners, language, tp, add_blank
|
||||||
)
|
)
|
||||||
if enable_eos_bos:
|
if enable_eos_bos:
|
||||||
|
@ -190,7 +190,7 @@ class MyDataset(Dataset):
|
||||||
item = args[0]
|
item = args[0]
|
||||||
func_args = args[1]
|
func_args = args[1]
|
||||||
text, wav_file, *_ = item
|
text, wav_file, *_ = item
|
||||||
phonemes = MyDataset._load_or_generate_phoneme_sequence(wav_file, text, *func_args)
|
phonemes = TTSDataset._load_or_generate_phoneme_sequence(wav_file, text, *func_args)
|
||||||
return phonemes
|
return phonemes
|
||||||
|
|
||||||
def compute_input_seq(self, num_workers=0):
|
def compute_input_seq(self, num_workers=0):
|
||||||
|
@ -225,7 +225,7 @@ class MyDataset(Dataset):
|
||||||
with Pool(num_workers) as p:
|
with Pool(num_workers) as p:
|
||||||
phonemes = list(
|
phonemes = list(
|
||||||
tqdm.tqdm(
|
tqdm.tqdm(
|
||||||
p.imap(MyDataset._phoneme_worker, [[item, func_args] for item in self.items]),
|
p.imap(TTSDataset._phoneme_worker, [[item, func_args] for item in self.items]),
|
||||||
total=len(self.items),
|
total=len(self.items),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
|
@ -22,7 +22,7 @@
|
||||||
"import numpy as np\n",
|
"import numpy as np\n",
|
||||||
"from tqdm import tqdm as tqdm\n",
|
"from tqdm import tqdm as tqdm\n",
|
||||||
"from torch.utils.data import DataLoader\n",
|
"from torch.utils.data import DataLoader\n",
|
||||||
"from TTS.tts.datasets.TTSDataset import MyDataset\n",
|
"from TTS.tts.datasets.TTSDataset import TTSDataset\n",
|
||||||
"from TTS.tts.layers.losses import L1LossMasked\n",
|
"from TTS.tts.layers.losses import L1LossMasked\n",
|
||||||
"from TTS.utils.audio import AudioProcessor\n",
|
"from TTS.utils.audio import AudioProcessor\n",
|
||||||
"from TTS.utils.io import load_config\n",
|
"from TTS.utils.io import load_config\n",
|
||||||
|
@ -112,7 +112,7 @@
|
||||||
"preprocessor = importlib.import_module('TTS.tts.datasets.preprocess')\n",
|
"preprocessor = importlib.import_module('TTS.tts.datasets.preprocess')\n",
|
||||||
"preprocessor = getattr(preprocessor, DATASET.lower())\n",
|
"preprocessor = getattr(preprocessor, DATASET.lower())\n",
|
||||||
"meta_data = preprocessor(DATA_PATH,METADATA_FILE)\n",
|
"meta_data = preprocessor(DATA_PATH,METADATA_FILE)\n",
|
||||||
"dataset = MyDataset(checkpoint['r'], C.text_cleaner, False, ap, meta_data,tp=C.characters if 'characters' in C.keys() else None, use_phonemes=C.use_phonemes, phoneme_cache_path=C.phoneme_cache_path, enable_eos_bos=C.enable_eos_bos_chars)\n",
|
"dataset = TTSDataset(checkpoint['r'], C.text_cleaner, False, ap, meta_data,tp=C.characters if 'characters' in C.keys() else None, use_phonemes=C.use_phonemes, phoneme_cache_path=C.phoneme_cache_path, enable_eos_bos=C.enable_eos_bos_chars)\n",
|
||||||
"loader = DataLoader(dataset, batch_size=BATCH_SIZE, num_workers=4, collate_fn=dataset.collate_fn, shuffle=False, drop_last=False)"
|
"loader = DataLoader(dataset, batch_size=BATCH_SIZE, num_workers=4, collate_fn=dataset.collate_fn, shuffle=False, drop_last=False)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
|
|
@ -38,7 +38,7 @@ class TestTTSDataset(unittest.TestCase):
|
||||||
|
|
||||||
def _create_dataloader(self, batch_size, r, bgs):
|
def _create_dataloader(self, batch_size, r, bgs):
|
||||||
items = ljspeech(c.data_path, "metadata.csv")
|
items = ljspeech(c.data_path, "metadata.csv")
|
||||||
dataset = TTSDataset.MyDataset(
|
dataset = TTSDataset.TTSDataset(
|
||||||
r,
|
r,
|
||||||
c.text_cleaner,
|
c.text_cleaner,
|
||||||
compute_linear_spec=True,
|
compute_linear_spec=True,
|
||||||
|
|
Loading…
Reference in New Issue