mirror of https://github.com/coqui-ai/TTS.git
linter fix
This commit is contained in:
parent
609d8efa69
commit
d45d963dc1
29
train.py
29
train.py
|
@ -190,7 +190,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
|
||||||
grad_norm, _ = check_update(model, c.grad_clip)
|
grad_norm, _ = check_update(model, c.grad_clip)
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
# compute alignment score
|
# compute alignment score
|
||||||
align_score = alignment_diagonal_score(alignments)
|
align_score = alignment_diagonal_score(alignments)
|
||||||
keep_avg.update_value('avg_align_score', align_score)
|
keep_avg.update_value('avg_align_score', align_score)
|
||||||
|
|
||||||
|
@ -281,7 +281,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
|
||||||
"AvgPostnetLoss:{:.5f} AvgDecoderLoss:{:.5f} "
|
"AvgPostnetLoss:{:.5f} AvgDecoderLoss:{:.5f} "
|
||||||
"AvgStopLoss:{:.5f} EpochTime:{:.2f} "
|
"AvgStopLoss:{:.5f} EpochTime:{:.2f} "
|
||||||
"AvgStepTime:{:.2f} AvgLoaderTime:{:.2f}".format(global_step, keep_avg['avg_postnet_loss'], keep_avg['avg_decoder_loss'],
|
"AvgStepTime:{:.2f} AvgLoaderTime:{:.2f}".format(global_step, keep_avg['avg_postnet_loss'], keep_avg['avg_decoder_loss'],
|
||||||
keep_avg['avg_stop_loss'], keep_avg['avg_align_score'],
|
keep_avg['avg_stop_loss'], keep_avg['avg_align_score'],
|
||||||
epoch_time, keep_avg['avg_step_time'], keep_avg['avg_loader_time']),
|
epoch_time, keep_avg['avg_step_time'], keep_avg['avg_loader_time']),
|
||||||
flush=True)
|
flush=True)
|
||||||
|
|
||||||
|
@ -305,11 +305,11 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
|
||||||
speaker_mapping = load_speaker_mapping(OUT_PATH)
|
speaker_mapping = load_speaker_mapping(OUT_PATH)
|
||||||
model.eval()
|
model.eval()
|
||||||
epoch_time = 0
|
epoch_time = 0
|
||||||
eval_values_dict = {'avg_postnet_loss' : 0,
|
eval_values_dict = {'avg_postnet_loss': 0,
|
||||||
'avg_decoder_loss' : 0,
|
'avg_decoder_loss': 0,
|
||||||
'avg_stop_loss' : 0,
|
'avg_stop_loss': 0,
|
||||||
'avg_align_score': 0}
|
'avg_align_score': 0}
|
||||||
keep_avg = KeepAverage()
|
keep_avg = KeepAverage()
|
||||||
keep_avg.add_values(eval_values_dict)
|
keep_avg.add_values(eval_values_dict)
|
||||||
print("\n > Validation")
|
print("\n > Validation")
|
||||||
if c.test_sentences_file is None:
|
if c.test_sentences_file is None:
|
||||||
|
@ -401,18 +401,19 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
|
||||||
if c.stopnet:
|
if c.stopnet:
|
||||||
stop_loss = reduce_tensor(stop_loss.data, num_gpus)
|
stop_loss = reduce_tensor(stop_loss.data, num_gpus)
|
||||||
|
|
||||||
keep_avg.update_values({'avg_postnet_loss' : float(postnet_loss.item()),
|
keep_avg.update_values({'avg_postnet_loss': float(postnet_loss.item()),
|
||||||
'avg_decoder_loss' : float(decoder_loss.item()),
|
'avg_decoder_loss': float(decoder_loss.item()),
|
||||||
'avg_stop_loss' : float(stop_loss.item())})
|
'avg_stop_loss': float(stop_loss.item())})
|
||||||
|
|
||||||
if num_iter % c.print_step == 0:
|
if num_iter % c.print_step == 0:
|
||||||
print(
|
print(
|
||||||
" | > TotalLoss: {:.5f} PostnetLoss: {:.5f} - {:.5f} DecoderLoss:{:.5f} - {:.5f} "
|
" | > TotalLoss: {:.5f} PostnetLoss: {:.5f} - {:.5f} DecoderLoss:{:.5f} - {:.5f} "
|
||||||
"StopLoss: {:.5f} - {:.5f} AlignScore: {:.4f} : {:.4f}".format(loss.item(),
|
"StopLoss: {:.5f} - {:.5f} AlignScore: {:.4f} : {:.4f}".format(
|
||||||
postnet_loss.item(), keep_avg['avg_postnet_loss'],
|
loss.item(),
|
||||||
decoder_loss.item(), keep_avg['avg_decoder_loss'],
|
postnet_loss.item(), keep_avg['avg_postnet_loss'],
|
||||||
stop_loss.item(), keep_avg['avg_stop_loss'],
|
decoder_loss.item(), keep_avg['avg_decoder_loss'],
|
||||||
align_score.item(), keep_avg['avg_align_score']),
|
stop_loss.item(), keep_avg['avg_stop_loss'],
|
||||||
|
align_score.item(), keep_avg['avg_align_score']),
|
||||||
flush=True)
|
flush=True)
|
||||||
|
|
||||||
if args.rank == 0:
|
if args.rank == 0:
|
||||||
|
|
|
@ -31,7 +31,8 @@ def load_config(config_path):
|
||||||
def get_git_branch():
|
def get_git_branch():
|
||||||
try:
|
try:
|
||||||
out = subprocess.check_output(["git", "branch"]).decode("utf8")
|
out = subprocess.check_output(["git", "branch"]).decode("utf8")
|
||||||
current = next(line for line in out.split("\n") if line.startswith("*"))
|
current = next(line for line in out.split(
|
||||||
|
"\n") if line.startswith("*"))
|
||||||
current.replace("* ", "")
|
current.replace("* ", "")
|
||||||
except subprocess.CalledProcessError:
|
except subprocess.CalledProcessError:
|
||||||
current = "inside_docker"
|
current = "inside_docker"
|
||||||
|
@ -298,7 +299,7 @@ def split_dataset(items):
|
||||||
# most stupid code ever -- Fix it !
|
# most stupid code ever -- Fix it !
|
||||||
while len(items_eval) < eval_split_size:
|
while len(items_eval) < eval_split_size:
|
||||||
speakers = [item[-1] for item in items]
|
speakers = [item[-1] for item in items]
|
||||||
speaker_counter = Counter(speakers)
|
speaker_counter = Counter(speakers)
|
||||||
item_idx = np.random.randint(0, len(items))
|
item_idx = np.random.randint(0, len(items))
|
||||||
if speaker_counter[items[item_idx][-1]] > 1:
|
if speaker_counter[items[item_idx][-1]] > 1:
|
||||||
items_eval.append(items[item_idx])
|
items_eval.append(items[item_idx])
|
||||||
|
@ -323,20 +324,21 @@ class KeepAverage():
|
||||||
|
|
||||||
def __getitem__(self, key):
|
def __getitem__(self, key):
|
||||||
return self.avg_values[key]
|
return self.avg_values[key]
|
||||||
|
|
||||||
def add_value(self, name, init_val=0, init_iter=0):
|
def add_value(self, name, init_val=0, init_iter=0):
|
||||||
self.avg_values[name] = init_val
|
self.avg_values[name] = init_val
|
||||||
self.iters[name] = init_iter
|
self.iters[name] = init_iter
|
||||||
|
|
||||||
def update_value(self, name, value, weighted_avg=False):
|
def update_value(self, name, value, weighted_avg=False):
|
||||||
if weighted_avg:
|
if weighted_avg:
|
||||||
self.avg_values[name] = 0.99 * self.avg_values[name] + 0.01 * value
|
self.avg_values[name] = 0.99 * self.avg_values[name] + 0.01 * value
|
||||||
self.iters[name] += 1
|
self.iters[name] += 1
|
||||||
else:
|
else:
|
||||||
self.avg_values[name] = self.avg_values[name] * self.iters[name] + value
|
self.avg_values[name] = self.avg_values[name] * \
|
||||||
|
self.iters[name] + value
|
||||||
self.iters[name] += 1
|
self.iters[name] += 1
|
||||||
self.avg_values[name] /= self.iters[name]
|
self.avg_values[name] /= self.iters[name]
|
||||||
|
|
||||||
def add_values(self, name_dict):
|
def add_values(self, name_dict):
|
||||||
for key, value in name_dict.items():
|
for key, value in name_dict.items():
|
||||||
self.add_value(key, init_val=value)
|
self.add_value(key, init_val=value)
|
||||||
|
@ -344,4 +346,3 @@ class KeepAverage():
|
||||||
def update_values(self, value_dict):
|
def update_values(self, value_dict):
|
||||||
for key, value in value_dict.items():
|
for key, value in value_dict.items():
|
||||||
self.update_value(key, value)
|
self.update_value(key, value)
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,3 @@
|
||||||
import torch
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
|
|
||||||
def alignment_diagonal_score(alignments):
|
def alignment_diagonal_score(alignments):
|
||||||
"""
|
"""
|
||||||
|
@ -12,8 +9,3 @@ def alignment_diagonal_score(alignments):
|
||||||
alignments : batch x decoder_steps x encoder_steps
|
alignments : batch x decoder_steps x encoder_steps
|
||||||
"""
|
"""
|
||||||
return alignments.max(dim=1)[0].mean(dim=1).mean(dim=0)
|
return alignments.max(dim=1)[0].mean(dim=1).mean(dim=0)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue