mirror of https://github.com/coqui-ai/TTS.git
Remove redundant dataset import
This commit is contained in:
parent
78ad7021c5
commit
dbfc489775
59
train.py
59
train.py
|
@ -23,7 +23,6 @@ from utils.generic_utils import (Progbar, remove_experiment_folder,
|
||||||
count_parameters, check_update, get_commit_hash)
|
count_parameters, check_update, get_commit_hash)
|
||||||
from utils.model import get_param_size
|
from utils.model import get_param_size
|
||||||
from utils.visual import plot_alignment, plot_spectrogram
|
from utils.visual import plot_alignment, plot_spectrogram
|
||||||
from datasets.LJSpeech import LJSpeechDataset
|
|
||||||
from models.tacotron import Tacotron
|
from models.tacotron import Tacotron
|
||||||
from layers.losses import L1LossMasked
|
from layers.losses import L1LossMasked
|
||||||
|
|
||||||
|
@ -296,41 +295,41 @@ def main(args):
|
||||||
Dataset = getattr(mod, c.dataset+"Dataset")
|
Dataset = getattr(mod, c.dataset+"Dataset")
|
||||||
|
|
||||||
# Setup the dataset
|
# Setup the dataset
|
||||||
train_dataset = LJSpeechDataset(os.path.join(c.data_path, c.meta_file_train),
|
train_dataset = Dataset(os.path.join(c.data_path, c.meta_file_train),
|
||||||
os.path.join(c.data_path, 'wavs'),
|
os.path.join(c.data_path, 'wavs'),
|
||||||
c.r,
|
c.r,
|
||||||
c.sample_rate,
|
c.sample_rate,
|
||||||
c.text_cleaner,
|
c.text_cleaner,
|
||||||
c.num_mels,
|
c.num_mels,
|
||||||
c.min_level_db,
|
c.min_level_db,
|
||||||
c.frame_shift_ms,
|
c.frame_shift_ms,
|
||||||
c.frame_length_ms,
|
c.frame_length_ms,
|
||||||
c.preemphasis,
|
c.preemphasis,
|
||||||
c.ref_level_db,
|
c.ref_level_db,
|
||||||
c.num_freq,
|
c.num_freq,
|
||||||
c.power,
|
c.power,
|
||||||
min_seq_len=c.min_seq_len
|
min_seq_len=c.min_seq_len
|
||||||
)
|
)
|
||||||
|
|
||||||
train_loader = DataLoader(train_dataset, batch_size=c.batch_size,
|
train_loader = DataLoader(train_dataset, batch_size=c.batch_size,
|
||||||
shuffle=False, collate_fn=train_dataset.collate_fn,
|
shuffle=False, collate_fn=train_dataset.collate_fn,
|
||||||
drop_last=False, num_workers=c.num_loader_workers,
|
drop_last=False, num_workers=c.num_loader_workers,
|
||||||
pin_memory=True)
|
pin_memory=True)
|
||||||
|
|
||||||
val_dataset = LJSpeechDataset(os.path.join(c.data_path, c.meta_file_val),
|
val_dataset = Dataset(os.path.join(c.data_path, c.meta_file_val),
|
||||||
os.path.join(c.data_path, 'wavs'),
|
os.path.join(c.data_path, 'wavs'),
|
||||||
c.r,
|
c.r,
|
||||||
c.sample_rate,
|
c.sample_rate,
|
||||||
c.text_cleaner,
|
c.text_cleaner,
|
||||||
c.num_mels,
|
c.num_mels,
|
||||||
c.min_level_db,
|
c.min_level_db,
|
||||||
c.frame_shift_ms,
|
c.frame_shift_ms,
|
||||||
c.frame_length_ms,
|
c.frame_length_ms,
|
||||||
c.preemphasis,
|
c.preemphasis,
|
||||||
c.ref_level_db,
|
c.ref_level_db,
|
||||||
c.num_freq,
|
c.num_freq,
|
||||||
c.power
|
c.power
|
||||||
)
|
)
|
||||||
|
|
||||||
val_loader = DataLoader(val_dataset, batch_size=c.eval_batch_size,
|
val_loader = DataLoader(val_dataset, batch_size=c.eval_batch_size,
|
||||||
shuffle=False, collate_fn=val_dataset.collate_fn,
|
shuffle=False, collate_fn=val_dataset.collate_fn,
|
||||||
|
|
Loading…
Reference in New Issue