Skip to content

Commit

Permalink
Merge pull request #10 from MalikAly/factor
Browse files Browse the repository at this point in the history
Add scalar factor support and tests for MAD and ZSCORE
  • Loading branch information
KulikDM authored Nov 21, 2024
2 parents 45863b7 + 0389df2 commit 323ce2e
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 13 deletions.
5 changes: 3 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ repos:
name: Format docstrings

- repo: https://github.com/asottile/pyupgrade
rev: v3.18.0
rev: v3.19.0
hooks:
- id: pyupgrade
args: [--py38-plus]
Expand All @@ -42,8 +42,9 @@ repos:
name: Sort imports

- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.7.0
rev: v0.7.2
hooks:
- id: ruff
args: [--exit-non-zero-on-fix, --fix, --line-length=180]
exclude: "\\.ipynb$"
name: Lint code
9 changes: 7 additions & 2 deletions pythresh/test/test_mad.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import sys
import unittest
from itertools import product
from os.path import dirname as up

# noinspection PyProtectedMember
Expand Down Expand Up @@ -41,11 +42,15 @@ def setUp(self):

self.all_scores = [scores, multiple_scores]

self.thres = MAD()
self.factors = [0.5, 1, 2]

def test_prediction_labels(self):

for scores in self.all_scores:
params = product(self.all_scores, self.factors)

for scores, factor in params:

self.thres = MAD(factor=factor)

pred_labels = self.thres.eval(scores)
assert (self.thres.thresh_ is not None)
Expand Down
9 changes: 7 additions & 2 deletions pythresh/test/test_zscore.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import sys
import unittest
from itertools import product
from os.path import dirname as up

# noinspection PyProtectedMember
Expand Down Expand Up @@ -41,11 +42,15 @@ def setUp(self):

self.all_scores = [scores, multiple_scores]

self.thres = ZSCORE()
self.factors = [0.5, 1, 2]

def test_prediction_labels(self):

for scores in self.all_scores:
params = product(self.all_scores, self.factors)

for scores, factor in params:

self.thres = ZSCORE(factor=factor)

pred_labels = self.thres.eval(scores)
assert (self.thres.thresh_ is not None)
Expand Down
10 changes: 7 additions & 3 deletions pythresh/thresholds/mad.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ class MAD(BaseThresholder):
Parameters
----------
factor : int, optional (default=1)
The factor to multiply the MAD by to set the threshold.
The default is 1.
random_state : int, optional (default=1234)
Random seed for the random number generators of the thresholders. Can also
be set to None.
Expand Down Expand Up @@ -48,8 +51,9 @@ class MAD(BaseThresholder):
"""

def __init__(self, random_state=1234):
def __init__(self, factor=1, random_state=1234):

self.factor = factor
self.random_state = random_state

def eval(self, decision):
Expand Down Expand Up @@ -78,8 +82,8 @@ def eval(self, decision):

# Set limit
mean = np.mean(decision)
limit = mean + \
stats.median_abs_deviation(decision, scale=np.std(decision))
mad = stats.median_abs_deviation(decision, scale=np.std(decision))
limit = mean + self.factor * mad

self.thresh_ = limit

Expand Down
11 changes: 7 additions & 4 deletions pythresh/thresholds/zscore.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ class ZSCORE(BaseThresholder):
Parameters
----------
factor : int, optional (default=1)
The factor to multiply the zscore by to set the threshold.
The default is 1.
random_state : int, optional (default=1234)
Random seed for the random number generators of the thresholders. Can also
be set to None.
Expand Down Expand Up @@ -43,8 +45,9 @@ class ZSCORE(BaseThresholder):
"""

def __init__(self, random_state=1234):
def __init__(self, factor=1, random_state=1234):

self.factor = factor
self.random_state = random_state

def eval(self, decision):
Expand Down Expand Up @@ -74,9 +77,9 @@ def eval(self, decision):
# Get the zscore of the decision scores
zscore = stats.zscore(decision)

# Set the limit to where the zscore is 1
# Set the limit to where the zscore is greater than the factor
labels = np.zeros(len(decision), dtype=int)
mask = np.where(zscore >= 1.0)
mask = np.where(zscore >= self.factor)
labels[mask] = 1

self.thresh_ = np.min(labels[labels == 1])
Expand Down

0 comments on commit 323ce2e

Please sign in to comment.