bug fixes

This commit is contained in:
Eren Golge 2018-12-18 12:58:09 +01:00
parent 4826e7db9c
commit f5c972dee6
1 changed files with 4 additions and 5 deletions

View File

@ -20,6 +20,7 @@ from utils.generic_utils import (
from utils.visual import plot_alignment, plot_spectrogram
from models.tacotron import Tacotron
from layers.losses import L1LossMasked
from datasets.TTSDataset import MyDataset
from utils.audio import AudioProcessor
from utils.synthesis import synthesis
from utils.logger import Logger
@ -44,8 +45,8 @@ def setup_loader(is_val=False):
ap=ap,
batch_group_size=0 if is_val else 8 * c.batch_size,
min_seq_len=0 if is_val else c.min_seq_len,
max_seq_len=float("inf") if is_val else c.max_seq_len
cached=False if c.dataset ~= "tts_cache" else True)
max_seq_len=float("inf") if is_val else c.max_seq_len,
cached=False if c.dataset != "tts_cache" else True)
loader = DataLoader(
dataset,
batch_size=c.eval_batch_size if is_val else c.batch_size,
@ -444,7 +445,7 @@ if __name__ == '__main__':
default=False,
help='Do not verify commit integrity to run training.')
parser.add_argument(
'--data_path', type=str, help='Defines the data path. It overwrites config.json.', default='')
'--data_path', type=str, default='', help='Defines the data path. It overwrites config.json.')
args = parser.parse_args()
# setup output paths and read configs
@ -467,8 +468,6 @@ if __name__ == '__main__':
# Conditional imports
preprocessor = importlib.import_module('datasets.preprocess')
preprocessor = getattr(preprocessor, c.dataset.lower())
MyDataset = importlib.import_module('datasets.' + c.data_loader)
MyDataset = getattr(MyDataset, "MyDataset")
audio = importlib.import_module('utils.' + c.audio['audio_processor'])
AudioProcessor = getattr(audio, 'AudioProcessor')