From d45d963dc11c37cafb4b86c73d81a18114342e2c Mon Sep 17 00:00:00 2001 From: Eren Golge Date: Wed, 11 Sep 2019 10:39:59 +0200 Subject: [PATCH] linter fix --- train.py | 29 +++++++++++++++-------------- utils/generic_utils.py | 15 ++++++++------- utils/measures.py | 8 -------- 3 files changed, 23 insertions(+), 29 deletions(-) diff --git a/train.py b/train.py index 1100c1f3..13444c82 100644 --- a/train.py +++ b/train.py @@ -190,7 +190,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler, grad_norm, _ = check_update(model, c.grad_clip) optimizer.step() - # compute alignment score + # compute alignment score align_score = alignment_diagonal_score(alignments) keep_avg.update_value('avg_align_score', align_score) @@ -281,7 +281,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler, "AvgPostnetLoss:{:.5f} AvgDecoderLoss:{:.5f} " "AvgStopLoss:{:.5f} EpochTime:{:.2f} " "AvgStepTime:{:.2f} AvgLoaderTime:{:.2f}".format(global_step, keep_avg['avg_postnet_loss'], keep_avg['avg_decoder_loss'], - keep_avg['avg_stop_loss'], keep_avg['avg_align_score'], + keep_avg['avg_stop_loss'], keep_avg['avg_align_score'], epoch_time, keep_avg['avg_step_time'], keep_avg['avg_loader_time']), flush=True) @@ -305,11 +305,11 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch): speaker_mapping = load_speaker_mapping(OUT_PATH) model.eval() epoch_time = 0 - eval_values_dict = {'avg_postnet_loss' : 0, - 'avg_decoder_loss' : 0, - 'avg_stop_loss' : 0, + eval_values_dict = {'avg_postnet_loss': 0, + 'avg_decoder_loss': 0, + 'avg_stop_loss': 0, 'avg_align_score': 0} - keep_avg = KeepAverage() + keep_avg = KeepAverage() keep_avg.add_values(eval_values_dict) print("\n > Validation") if c.test_sentences_file is None: @@ -401,18 +401,19 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch): if c.stopnet: stop_loss = reduce_tensor(stop_loss.data, num_gpus) - keep_avg.update_values({'avg_postnet_loss' : float(postnet_loss.item()), - 'avg_decoder_loss' : float(decoder_loss.item()), - 'avg_stop_loss' : float(stop_loss.item())}) + keep_avg.update_values({'avg_postnet_loss': float(postnet_loss.item()), + 'avg_decoder_loss': float(decoder_loss.item()), + 'avg_stop_loss': float(stop_loss.item())}) if num_iter % c.print_step == 0: print( " | > TotalLoss: {:.5f} PostnetLoss: {:.5f} - {:.5f} DecoderLoss:{:.5f} - {:.5f} " - "StopLoss: {:.5f} - {:.5f} AlignScore: {:.4f} : {:.4f}".format(loss.item(), - postnet_loss.item(), keep_avg['avg_postnet_loss'], - decoder_loss.item(), keep_avg['avg_decoder_loss'], - stop_loss.item(), keep_avg['avg_stop_loss'], - align_score.item(), keep_avg['avg_align_score']), + "StopLoss: {:.5f} - {:.5f} AlignScore: {:.4f} : {:.4f}".format( + loss.item(), + postnet_loss.item(), keep_avg['avg_postnet_loss'], + decoder_loss.item(), keep_avg['avg_decoder_loss'], + stop_loss.item(), keep_avg['avg_stop_loss'], + align_score.item(), keep_avg['avg_align_score']), flush=True) if args.rank == 0: diff --git a/utils/generic_utils.py b/utils/generic_utils.py index d72ffdd5..1053d221 100644 --- a/utils/generic_utils.py +++ b/utils/generic_utils.py @@ -31,7 +31,8 @@ def load_config(config_path): def get_git_branch(): try: out = subprocess.check_output(["git", "branch"]).decode("utf8") - current = next(line for line in out.split("\n") if line.startswith("*")) + current = next(line for line in out.split( + "\n") if line.startswith("*")) current.replace("* ", "") except subprocess.CalledProcessError: current = "inside_docker" @@ -298,7 +299,7 @@ def split_dataset(items): # most stupid code ever -- Fix it ! while len(items_eval) < eval_split_size: speakers = [item[-1] for item in items] - speaker_counter = Counter(speakers) + speaker_counter = Counter(speakers) item_idx = np.random.randint(0, len(items)) if speaker_counter[items[item_idx][-1]] > 1: items_eval.append(items[item_idx]) @@ -323,20 +324,21 @@ class KeepAverage(): def __getitem__(self, key): return self.avg_values[key] - + def add_value(self, name, init_val=0, init_iter=0): self.avg_values[name] = init_val self.iters[name] = init_iter - + def update_value(self, name, value, weighted_avg=False): if weighted_avg: self.avg_values[name] = 0.99 * self.avg_values[name] + 0.01 * value self.iters[name] += 1 else: - self.avg_values[name] = self.avg_values[name] * self.iters[name] + value + self.avg_values[name] = self.avg_values[name] * \ + self.iters[name] + value self.iters[name] += 1 self.avg_values[name] /= self.iters[name] - + def add_values(self, name_dict): for key, value in name_dict.items(): self.add_value(key, init_val=value) @@ -344,4 +346,3 @@ class KeepAverage(): def update_values(self, value_dict): for key, value in value_dict.items(): self.update_value(key, value) - diff --git a/utils/measures.py b/utils/measures.py index 21652cf0..21b61298 100644 --- a/utils/measures.py +++ b/utils/measures.py @@ -1,6 +1,3 @@ -import torch -import numpy as np - def alignment_diagonal_score(alignments): """ @@ -12,8 +9,3 @@ def alignment_diagonal_score(alignments): alignments : batch x decoder_steps x encoder_steps """ return alignments.max(dim=1)[0].mean(dim=1).mean(dim=0) - - - - -