mirror of https://github.com/coqui-ai/TTS.git
bug fix
This commit is contained in:
parent
332abb1de6
commit
f434781ab3
|
@ -14,7 +14,7 @@
|
|||
"epochs": 2000,
|
||||
"lr": 0.00001875,
|
||||
"warmup_steps": 4000,
|
||||
"batch_size": 1,
|
||||
"batch_size": 2,
|
||||
"eval_batch_size": 32,
|
||||
"r": 5,
|
||||
|
||||
|
|
24
train.py
24
train.py
|
@ -177,9 +177,10 @@ def train(model, criterion, data_loader, optimizer, epoch):
|
|||
tb.add_audio('SampleAudio', audio_signal, current_step,
|
||||
sample_rate=c.sample_rate)
|
||||
except:
|
||||
print("\n > Error at audio signal on TB!!")
|
||||
print(audio_signal.max())
|
||||
print(audio_signal.min())
|
||||
# print("\n > Error at audio signal on TB!!")
|
||||
# print(audio_signal.max())
|
||||
# print(audio_signal.min())
|
||||
pass
|
||||
|
||||
|
||||
avg_linear_loss /= (num_iter + 1)
|
||||
|
@ -197,12 +198,12 @@ def train(model, criterion, data_loader, optimizer, epoch):
|
|||
|
||||
|
||||
def evaluate(model, criterion, data_loader, current_step):
|
||||
model = model.train()
|
||||
model = model.eval()
|
||||
epoch_time = 0
|
||||
|
||||
print(" | > Validation")
|
||||
n_priority_freq = int(3000 / (c.sample_rate * 0.5) * c.num_freq)
|
||||
progbar = Progbar(len(data_loader.dataset) / c.batch_size)
|
||||
progbar = Progbar(len(data_loader.dataset) / c.eval_batch_size)
|
||||
|
||||
avg_linear_loss = 0
|
||||
avg_mel_loss = 0
|
||||
|
@ -271,9 +272,10 @@ def evaluate(model, criterion, data_loader, current_step):
|
|||
tb.add_audio('ValSampleAudio', audio_signal, current_step,
|
||||
sample_rate=c.sample_rate)
|
||||
except:
|
||||
print(" | > Error at audio signal on TB!!")
|
||||
print(audio_signal.max())
|
||||
print(audio_signal.min())
|
||||
# print(" | > Error at audio signal on TB!!")
|
||||
# print(audio_signal.max())
|
||||
# print(audio_signal.min())
|
||||
pass
|
||||
|
||||
# compute average losses
|
||||
avg_linear_loss /= (num_iter + 1)
|
||||
|
@ -307,9 +309,9 @@ def main(args):
|
|||
)
|
||||
|
||||
train_loader = DataLoader(train_dataset, batch_size=c.batch_size,
|
||||
shuffle=False, collate_fn=train_dataset.collate_fn,
|
||||
drop_last=False, num_workers=c.num_loader_workers,
|
||||
pin_memory=True)
|
||||
shuffle=False, collate_fn=train_dataset.collate_fn,
|
||||
drop_last=False, num_workers=c.num_loader_workers,
|
||||
pin_memory=True)
|
||||
|
||||
val_dataset = LJSpeechDataset(os.path.join(c.data_path, 'metadata_val.csv'),
|
||||
os.path.join(c.data_path, 'wavs'),
|
||||
|
|
Loading…
Reference in New Issue