mirror of https://github.com/coqui-ai/TTS.git
bug fixes
This commit is contained in:
parent
3c8eb51713
commit
7cdeef1b5c
9
train.py
9
train.py
|
@ -20,6 +20,7 @@ from utils.generic_utils import (
|
||||||
from utils.visual import plot_alignment, plot_spectrogram
|
from utils.visual import plot_alignment, plot_spectrogram
|
||||||
from models.tacotron import Tacotron
|
from models.tacotron import Tacotron
|
||||||
from layers.losses import L1LossMasked
|
from layers.losses import L1LossMasked
|
||||||
|
from datasets.TTSDataset import MyDataset
|
||||||
from utils.audio import AudioProcessor
|
from utils.audio import AudioProcessor
|
||||||
from utils.synthesis import synthesis
|
from utils.synthesis import synthesis
|
||||||
from utils.logger import Logger
|
from utils.logger import Logger
|
||||||
|
@ -44,8 +45,8 @@ def setup_loader(is_val=False):
|
||||||
ap=ap,
|
ap=ap,
|
||||||
batch_group_size=0 if is_val else 8 * c.batch_size,
|
batch_group_size=0 if is_val else 8 * c.batch_size,
|
||||||
min_seq_len=0 if is_val else c.min_seq_len,
|
min_seq_len=0 if is_val else c.min_seq_len,
|
||||||
max_seq_len=float("inf") if is_val else c.max_seq_len
|
max_seq_len=float("inf") if is_val else c.max_seq_len,
|
||||||
cached=False if c.dataset ~= "tts_cache" else True)
|
cached=False if c.dataset != "tts_cache" else True)
|
||||||
loader = DataLoader(
|
loader = DataLoader(
|
||||||
dataset,
|
dataset,
|
||||||
batch_size=c.eval_batch_size if is_val else c.batch_size,
|
batch_size=c.eval_batch_size if is_val else c.batch_size,
|
||||||
|
@ -444,7 +445,7 @@ if __name__ == '__main__':
|
||||||
default=False,
|
default=False,
|
||||||
help='Do not verify commit integrity to run training.')
|
help='Do not verify commit integrity to run training.')
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--data_path', type=str, default='', default='Defines the data path. It overwrites config.json.')
|
'--data_path', type=str, default='', help='Defines the data path. It overwrites config.json.')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# setup output paths and read configs
|
# setup output paths and read configs
|
||||||
|
@ -467,8 +468,6 @@ if __name__ == '__main__':
|
||||||
# Conditional imports
|
# Conditional imports
|
||||||
preprocessor = importlib.import_module('datasets.preprocess')
|
preprocessor = importlib.import_module('datasets.preprocess')
|
||||||
preprocessor = getattr(preprocessor, c.dataset.lower())
|
preprocessor = getattr(preprocessor, c.dataset.lower())
|
||||||
MyDataset = importlib.import_module('datasets.' + c.data_loader)
|
|
||||||
MyDataset = getattr(MyDataset, "MyDataset")
|
|
||||||
audio = importlib.import_module('utils.' + c.audio['audio_processor'])
|
audio = importlib.import_module('utils.' + c.audio['audio_processor'])
|
||||||
AudioProcessor = getattr(audio, 'AudioProcessor')
|
AudioProcessor = getattr(audio, 'AudioProcessor')
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue