mirror of https://github.com/coqui-ai/TTS.git
bug fixes on train.py
This commit is contained in:
parent
96b48c003a
commit
17b65d5cde
14
train.py
14
train.py
|
@ -43,7 +43,9 @@ def setup_loader(is_val=False):
|
||||||
preprocessor=preprocessor,
|
preprocessor=preprocessor,
|
||||||
ap=ap,
|
ap=ap,
|
||||||
batch_group_size=0 if is_val else 8 * c.batch_size,
|
batch_group_size=0 if is_val else 8 * c.batch_size,
|
||||||
min_seq_len=0 if is_val else c.min_seq_len)
|
min_seq_len=0 if is_val else c.min_seq_len,
|
||||||
|
max_seq_len=float("inf") if is_val else c.max_seq_len
|
||||||
|
cached=False if c.dataset ~= "tts_cache" else True)
|
||||||
loader = DataLoader(
|
loader = DataLoader(
|
||||||
dataset,
|
dataset,
|
||||||
batch_size=c.eval_batch_size if is_val else c.batch_size,
|
batch_size=c.eval_batch_size if is_val else c.batch_size,
|
||||||
|
@ -164,8 +166,8 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st,
|
||||||
grad_norm, grad_norm_st, avg_text_length, avg_spec_length, step_time, current_lr),
|
grad_norm, grad_norm_st, avg_text_length, avg_spec_length, step_time, current_lr),
|
||||||
flush=True)
|
flush=True)
|
||||||
|
|
||||||
avg_linear_loss += linear_loss.item()
|
avg_linear_loss += float(linear_loss.item())
|
||||||
avg_mel_loss += mel_loss.item()
|
avg_mel_loss += float(mel_loss.item())
|
||||||
avg_stop_loss += stop_loss.item()
|
avg_stop_loss += stop_loss.item()
|
||||||
avg_step_time += step_time
|
avg_step_time += step_time
|
||||||
|
|
||||||
|
@ -198,7 +200,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st,
|
||||||
# Sample audio
|
# Sample audio
|
||||||
tb_logger.tb_train_audios(current_step,
|
tb_logger.tb_train_audios(current_step,
|
||||||
{'TrainAudio': ap.inv_spectrogram(const_spec.T)},
|
{'TrainAudio': ap.inv_spectrogram(const_spec.T)},
|
||||||
c.sample_rate)
|
c.audio["sample_rate"])
|
||||||
|
|
||||||
avg_linear_loss /= (num_iter + 1)
|
avg_linear_loss /= (num_iter + 1)
|
||||||
avg_mel_loss /= (num_iter + 1)
|
avg_mel_loss /= (num_iter + 1)
|
||||||
|
@ -295,8 +297,8 @@ def evaluate(model, criterion, criterion_st, ap, current_step):
|
||||||
stop_loss.item()),
|
stop_loss.item()),
|
||||||
flush=True)
|
flush=True)
|
||||||
|
|
||||||
avg_linear_loss += linear_loss.item()
|
avg_linear_loss += float(linear_loss.item())
|
||||||
avg_mel_loss += mel_loss.item()
|
avg_mel_loss += float(mel_loss.item())
|
||||||
avg_stop_loss += stop_loss.item()
|
avg_stop_loss += stop_loss.item()
|
||||||
|
|
||||||
# Diagnostic visualizations
|
# Diagnostic visualizations
|
||||||
|
|
Loading…
Reference in New Issue