From e71a3772021d47bbf6a9c4f1e38420ba316f7e8c Mon Sep 17 00:00:00 2001 From: erogol Date: Sat, 30 May 2020 14:05:19 +0200 Subject: [PATCH] add a non-existing value to avg values --- utils/generic_utils.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/utils/generic_utils.py b/utils/generic_utils.py index 1c7dd5e4..5a811907 100644 --- a/utils/generic_utils.py +++ b/utils/generic_utils.py @@ -197,14 +197,19 @@ class KeepAverage(): self.iters[name] = init_iter def update_value(self, name, value, weighted_avg=False): - if weighted_avg: - self.avg_values[name] = 0.99 * self.avg_values[name] + 0.01 * value - self.iters[name] += 1 + if name not in self.avg_values: + # add value if not exist before + self.add_value(name, init_val=value) else: - self.avg_values[name] = self.avg_values[name] * \ - self.iters[name] + value - self.iters[name] += 1 - self.avg_values[name] /= self.iters[name] + # else update existing value + if weighted_avg: + self.avg_values[name] = 0.99 * self.avg_values[name] + 0.01 * value + self.iters[name] += 1 + else: + self.avg_values[name] = self.avg_values[name] * \ + self.iters[name] + value + self.iters[name] += 1 + self.avg_values[name] /= self.iters[name] def add_values(self, name_dict): for key, value in name_dict.items():