From cbedefac0ad93ca841e2b7691d9db84c8c7a24c4 Mon Sep 17 00:00:00 2001
From: Marko Toplak
Date: Tue, 18 Jul 2023 13:04:02 +0200
Subject: [PATCH] SubarrayNorms and SubarrayImpute get __eq__ and __hash__

---
Review notes (text between the "---" line and the diffstat is ignored by
"git am" and does not become part of the commit message):

* Purpose: SubarrayNorms and SubarrayImpute gain value-based __eq__ and
  __hash__. The hash is computed lazily, cached in self._hash, left out of
  the pickled state by __getstate__ and reset to None by __setstate__.
* __eq__ compares the arrays with np.array_equal instead of
  np.all(a == b). array_equal returns a plain bool and treats arrays of
  different shapes as unequal; "==" broadcasts (a length-1 array could
  compare equal to a longer one) and raises for incompatible shapes on
  recent NumPy releases.
* __eq__ uses a parenthesised expression instead of backslash
  continuations; the number of lines is unchanged.
* __getstate__ removes the cached hash with dict.pop(key, None), so an
  instance that has no "_hash" attribute can still be pickled or copied.
* The number of added lines per hunk is unchanged, so the hunk headers and
  the diffstat below are still accurate.
* This patch was restored from a copy whose line breaks had been collapsed;
  the indentation of context lines is reconstructed and should be verified
  against the target tree before applying.
* The commit id on the first line and the blob ids on the "index" lines
  describe the originally submitted revision, not the edited content.

 Orange/preprocess/normalize.py  | 23 +++++++++++++++++++++++
 Orange/preprocess/preprocess.py | 22 ++++++++++++++++++++++
 2 files changed, 45 insertions(+)

diff --git a/Orange/preprocess/normalize.py b/Orange/preprocess/normalize.py
index 5847c45354e..1296a93cd24 100644
--- a/Orange/preprocess/normalize.py
+++ b/Orange/preprocess/normalize.py
@@ -16,6 +16,7 @@ def __init__(self, source_vars, offsets, factors):
         self.source_vars = tuple(source_vars)
         self.offsets = np.array(offsets)
         self.factors = np.array(factors)
+        self._hash = None  # filled in lazily by __hash__

     def __call__(self, data, cols):
         X = data.transform(Domain(self.source_vars[cols])).X
@@ -29,6 +30,28 @@ def __call__(self, data, cols):
         else:
             return (X-offsets.reshape(1, -1)) * (factors.reshape(1, -1))

+    def __eq__(self, other):
+        if self is other:
+            return True
+        return (type(self) is type(other)
+                and self.source_vars == other.source_vars
+                and np.array_equal(self.offsets, other.offsets)
+                and np.array_equal(self.factors, other.factors))
+
+    def __setstate__(self, state):
+        self.__dict__.update(state)
+        self._hash = None  # recomputed on demand after unpickling
+
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        state.pop("_hash", None)  # derived cache: keep it out of pickles
+        return state
+
+    def __hash__(self):
+        if self._hash is None:
+            self._hash = hash((self.source_vars, tuple(self.offsets), tuple(self.factors)))
+        return self._hash
+

 def compress_norm_to_subarray(domain):
     source_vars = []
diff --git a/Orange/preprocess/preprocess.py b/Orange/preprocess/preprocess.py
index a6195ee2d44..44f7cc6f6f2 100644
--- a/Orange/preprocess/preprocess.py
+++ b/Orange/preprocess/preprocess.py
@@ -155,6 +155,7 @@ class SubarrayImpute:
     def __init__(self, source_vars, vals):
         self.source_vars = tuple(source_vars)
         self.vals = np.array(vals)
+        self._hash = None  # filled in lazily by __hash__

     def __call__(self, data, cols):
         X = data.transform(Orange.data.Domain(self.source_vars[cols])).X
@@ -172,6 +173,27 @@ def __call__(self, data, cols):
         else:
             return np.where(np.isnan(X), vals.reshape(1, -1), X)

+    def __eq__(self, other):
+        if self is other:
+            return True
+        return (type(self) is type(other)
+                and self.source_vars == other.source_vars
+                and np.array_equal(self.vals, other.vals))
+
+    def __setstate__(self, state):
+        self.__dict__.update(state)
+        self._hash = None  # recomputed on demand after unpickling
+
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        state.pop("_hash", None)  # derived cache: keep it out of pickles
+        return state
+
+    def __hash__(self):
+        if self._hash is None:
+            self._hash = hash((self.source_vars, tuple(self.vals)))
+        return self._hash
+

 def compress_replace_unknowns_to_subarray(domain):
     source_vars = []