diff --git a/utils/generic_utils.py b/utils/generic_utils.py index 1c7dd5e4..5a811907 100644 --- a/utils/generic_utils.py +++ b/utils/generic_utils.py @@ -197,14 +197,19 @@ class KeepAverage(): self.iters[name] = init_iter def update_value(self, name, value, weighted_avg=False): - if weighted_avg: - self.avg_values[name] = 0.99 * self.avg_values[name] + 0.01 * value - self.iters[name] += 1 + if name not in self.avg_values: + # add value if not exist before + self.add_value(name, init_val=value) else: - self.avg_values[name] = self.avg_values[name] * \ - self.iters[name] + value - self.iters[name] += 1 - self.avg_values[name] /= self.iters[name] + # else update existing value + if weighted_avg: + self.avg_values[name] = 0.99 * self.avg_values[name] + 0.01 * value + self.iters[name] += 1 + else: + self.avg_values[name] = self.avg_values[name] * \ + self.iters[name] + value + self.iters[name] += 1 + self.avg_values[name] /= self.iters[name] def add_values(self, name_dict): for key, value in name_dict.items():