diff --git a/utils/data.py b/utils/data.py index f2d7538a..a83325cb 100644 --- a/utils/data.py +++ b/utils/data.py @@ -50,3 +50,28 @@ def pad_per_step(inputs, pad_len): inputs, [[0, 0], [0, 0], [0, pad_len]], mode='constant', constant_values=0.0) + + +# pylint: disable=attribute-defined-outside-init +class StandardScaler(): + + def set_stats(self, mean, scale): + self.mean_ = mean + self.scale_ = scale + + def reset_stats(self): + delattr(self, 'mean_') + delattr(self, 'scale_') + + def transform(self, X): + X = np.asarray(X) + X -= self.mean_ + X /= self.scale_ + return X + + def inverse_transform(self, X): + X = np.asarray(X) + X *= self.scale_ + X += self.mean_ + return X +