This commit is contained in:
Eren Golge 2018-03-09 15:37:58 -08:00
parent 332abb1de6
commit f434781ab3
2 changed files with 14 additions and 12 deletions

View File

@ -14,7 +14,7 @@
"epochs": 2000, "epochs": 2000,
"lr": 0.00001875, "lr": 0.00001875,
"warmup_steps": 4000, "warmup_steps": 4000,
"batch_size": 1, "batch_size": 2,
"eval_batch_size": 32, "eval_batch_size": 32,
"r": 5, "r": 5,

View File

@ -177,9 +177,10 @@ def train(model, criterion, data_loader, optimizer, epoch):
tb.add_audio('SampleAudio', audio_signal, current_step, tb.add_audio('SampleAudio', audio_signal, current_step,
sample_rate=c.sample_rate) sample_rate=c.sample_rate)
except: except:
print("\n > Error at audio signal on TB!!") # print("\n > Error at audio signal on TB!!")
print(audio_signal.max()) # print(audio_signal.max())
print(audio_signal.min()) # print(audio_signal.min())
pass
avg_linear_loss /= (num_iter + 1) avg_linear_loss /= (num_iter + 1)
@ -197,12 +198,12 @@ def train(model, criterion, data_loader, optimizer, epoch):
def evaluate(model, criterion, data_loader, current_step): def evaluate(model, criterion, data_loader, current_step):
model = model.train() model = model.eval()
epoch_time = 0 epoch_time = 0
print(" | > Validation") print(" | > Validation")
n_priority_freq = int(3000 / (c.sample_rate * 0.5) * c.num_freq) n_priority_freq = int(3000 / (c.sample_rate * 0.5) * c.num_freq)
progbar = Progbar(len(data_loader.dataset) / c.batch_size) progbar = Progbar(len(data_loader.dataset) / c.eval_batch_size)
avg_linear_loss = 0 avg_linear_loss = 0
avg_mel_loss = 0 avg_mel_loss = 0
@ -271,9 +272,10 @@ def evaluate(model, criterion, data_loader, current_step):
tb.add_audio('ValSampleAudio', audio_signal, current_step, tb.add_audio('ValSampleAudio', audio_signal, current_step,
sample_rate=c.sample_rate) sample_rate=c.sample_rate)
except: except:
print(" | > Error at audio signal on TB!!") # print(" | > Error at audio signal on TB!!")
print(audio_signal.max()) # print(audio_signal.max())
print(audio_signal.min()) # print(audio_signal.min())
pass
# compute average losses # compute average losses
avg_linear_loss /= (num_iter + 1) avg_linear_loss /= (num_iter + 1)