bug fixes

This commit is contained in:
Eren Golge 2018-12-18 12:58:09 +01:00
parent 3c8eb51713
commit 7cdeef1b5c
1 changed files with 4 additions and 5 deletions

View File

@ -20,6 +20,7 @@ from utils.generic_utils import (
from utils.visual import plot_alignment, plot_spectrogram from utils.visual import plot_alignment, plot_spectrogram
from models.tacotron import Tacotron from models.tacotron import Tacotron
from layers.losses import L1LossMasked from layers.losses import L1LossMasked
from datasets.TTSDataset import MyDataset
from utils.audio import AudioProcessor from utils.audio import AudioProcessor
from utils.synthesis import synthesis from utils.synthesis import synthesis
from utils.logger import Logger from utils.logger import Logger
@ -44,8 +45,8 @@ def setup_loader(is_val=False):
ap=ap, ap=ap,
batch_group_size=0 if is_val else 8 * c.batch_size, batch_group_size=0 if is_val else 8 * c.batch_size,
min_seq_len=0 if is_val else c.min_seq_len, min_seq_len=0 if is_val else c.min_seq_len,
max_seq_len=float("inf") if is_val else c.max_seq_len max_seq_len=float("inf") if is_val else c.max_seq_len,
cached=False if c.dataset ~= "tts_cache" else True) cached=False if c.dataset != "tts_cache" else True)
loader = DataLoader( loader = DataLoader(
dataset, dataset,
batch_size=c.eval_batch_size if is_val else c.batch_size, batch_size=c.eval_batch_size if is_val else c.batch_size,
@ -444,7 +445,7 @@ if __name__ == '__main__':
default=False, default=False,
help='Do not verify commit integrity to run training.') help='Do not verify commit integrity to run training.')
parser.add_argument( parser.add_argument(
'--data_path', type=str, default='', default='Defines the data path. It overwrites config.json.') '--data_path', type=str, default='', help='Defines the data path. It overwrites config.json.')
args = parser.parse_args() args = parser.parse_args()
# setup output paths and read configs # setup output paths and read configs
@ -467,8 +468,6 @@ if __name__ == '__main__':
# Conditional imports # Conditional imports
preprocessor = importlib.import_module('datasets.preprocess') preprocessor = importlib.import_module('datasets.preprocess')
preprocessor = getattr(preprocessor, c.dataset.lower()) preprocessor = getattr(preprocessor, c.dataset.lower())
MyDataset = importlib.import_module('datasets.' + c.data_loader)
MyDataset = getattr(MyDataset, "MyDataset")
audio = importlib.import_module('utils.' + c.audio['audio_processor']) audio = importlib.import_module('utils.' + c.audio['audio_processor'])
AudioProcessor = getattr(audio, 'AudioProcessor') AudioProcessor = getattr(audio, 'AudioProcessor')