Skip to content

Commit

Permalink
SubarrayNorms and SubarrayImpute get __eq__ and __hash__
Browse files Browse the repository at this point in the history
  • Loading branch information
markotoplak committed Jul 26, 2023
1 parent e1115b8 commit cbedefa
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 0 deletions.
23 changes: 23 additions & 0 deletions Orange/preprocess/normalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def __init__(self, source_vars, offsets, factors):
self.source_vars = tuple(source_vars)
self.offsets = np.array(offsets)
self.factors = np.array(factors)
self._hash = None

def __call__(self, data, cols):
X = data.transform(Domain(self.source_vars[cols])).X
Expand All @@ -29,6 +30,28 @@ def __call__(self, data, cols):
else:
return (X-offsets.reshape(1, -1)) * (factors.reshape(1, -1))

def __eq__(self, other):
if self is other:
return True
return type(self) is type(other) \
and self.source_vars == other.source_vars \
and np.all(self.offsets == other.offsets) \
and np.all(self.factors == other.factors)

def __setstate__(self, state):
self.__dict__.update(state)
self._hash = None

def __getstate__(self):
state = self.__dict__.copy()
del state["_hash"]
return state

def __hash__(self):
if self._hash is None:
self._hash = hash((self.source_vars, tuple(self.offsets), tuple(self.factors)))
return self._hash


def compress_norm_to_subarray(domain):
source_vars = []
Expand Down
22 changes: 22 additions & 0 deletions Orange/preprocess/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ class SubarrayImpute:
def __init__(self, source_vars, vals):
self.source_vars = tuple(source_vars)
self.vals = np.array(vals)
self._hash = None

def __call__(self, data, cols):
X = data.transform(Orange.data.Domain(self.source_vars[cols])).X
Expand All @@ -172,6 +173,27 @@ def __call__(self, data, cols):
else:
return np.where(np.isnan(X), vals.reshape(1, -1), X)

def __eq__(self, other):
if self is other:
return True
return type(self) is type(other) \
and self.source_vars == other.source_vars \
and np.all(self.vals == other.vals)

def __setstate__(self, state):
self.__dict__.update(state)
self._hash = None

def __getstate__(self):
state = self.__dict__.copy()
del state["_hash"]
return state

def __hash__(self):
if self._hash is None:
self._hash = hash((self.source_vars, tuple(self.vals)))
return self._hash


def compress_replace_unknowns_to_subarray(domain):
source_vars = []
Expand Down

0 comments on commit cbedefa

Please sign in to comment.