diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 88b2edb..b954e14 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -21,7 +21,7 @@ repos: name: Format docstrings - repo: https://github.com/asottile/pyupgrade - rev: v3.18.0 + rev: v3.19.0 hooks: - id: pyupgrade args: [--py38-plus] @@ -42,8 +42,9 @@ repos: name: Sort imports - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: v0.7.0 + rev: v0.7.2 hooks: - id: ruff args: [--exit-non-zero-on-fix, --fix, --line-length=180] + exclude: "\\.ipynb$" name: Lint code diff --git a/pythresh/test/test_mad.py b/pythresh/test/test_mad.py index c003443..1a0afd7 100644 --- a/pythresh/test/test_mad.py +++ b/pythresh/test/test_mad.py @@ -1,5 +1,6 @@ import sys import unittest +from itertools import product from os.path import dirname as up # noinspection PyProtectedMember @@ -41,11 +42,15 @@ def setUp(self): self.all_scores = [scores, multiple_scores] - self.thres = MAD() + self.factors = [0.5, 1, 2] def test_prediction_labels(self): - for scores in self.all_scores: + params = product(self.all_scores, self.factors) + + for scores, factor in params: + + self.thres = MAD(factor=factor) pred_labels = self.thres.eval(scores) assert (self.thres.thresh_ is not None) diff --git a/pythresh/test/test_zscore.py b/pythresh/test/test_zscore.py index ca7760e..c8be587 100644 --- a/pythresh/test/test_zscore.py +++ b/pythresh/test/test_zscore.py @@ -1,5 +1,6 @@ import sys import unittest +from itertools import product from os.path import dirname as up # noinspection PyProtectedMember @@ -41,11 +42,15 @@ def setUp(self): self.all_scores = [scores, multiple_scores] - self.thres = ZSCORE() + self.factors = [0.5, 1, 2] def test_prediction_labels(self): - for scores in self.all_scores: + params = product(self.all_scores, self.factors) + + for scores, factor in params: + + self.thres = ZSCORE(factor=factor) pred_labels = self.thres.eval(scores) assert (self.thres.thresh_ is not None) diff --git a/pythresh/thresholds/mad.py b/pythresh/thresholds/mad.py index 3a5c7e5..0f44b61 100644 --- a/pythresh/thresholds/mad.py +++ b/pythresh/thresholds/mad.py @@ -17,6 +17,9 @@ class MAD(BaseThresholder): Parameters ---------- + factor : int, optional (default=1) + The factor to multiply the MAD by to set the threshold. + The default is 1. random_state : int, optional (default=1234) Random seed for the random number generators of the thresholders. Can also be set to None. @@ -48,8 +51,9 @@ class MAD(BaseThresholder): """ - def __init__(self, random_state=1234): + def __init__(self, factor=1, random_state=1234): + self.factor = factor self.random_state = random_state def eval(self, decision): @@ -78,8 +82,8 @@ def eval(self, decision): # Set limit mean = np.mean(decision) - limit = mean + \ - stats.median_abs_deviation(decision, scale=np.std(decision)) + mad = stats.median_abs_deviation(decision, scale=np.std(decision)) + limit = mean + self.factor * mad self.thresh_ = limit diff --git a/pythresh/thresholds/zscore.py b/pythresh/thresholds/zscore.py index 41faaed..0cf3993 100644 --- a/pythresh/thresholds/zscore.py +++ b/pythresh/thresholds/zscore.py @@ -15,7 +15,9 @@ class ZSCORE(BaseThresholder): Parameters ---------- - + factor : int, optional (default=1) + The factor to multiply the zscore by to set the threshold. + The default is 1. random_state : int, optional (default=1234) Random seed for the random number generators of the thresholders. Can also be set to None. @@ -43,8 +45,9 @@ class ZSCORE(BaseThresholder): """ - def __init__(self, random_state=1234): + def __init__(self, factor=1, random_state=1234): + self.factor = factor self.random_state = random_state def eval(self, decision): @@ -74,9 +77,9 @@ def eval(self, decision): # Get the zscore of the decision scores zscore = stats.zscore(decision) - # Set the limit to where the zscore is 1 + # Set the limit to where the zscore is greater than the factor labels = np.zeros(len(decision), dtype=int) - mask = np.where(zscore >= 1.0) + mask = np.where(zscore >= self.factor) labels[mask] = 1 self.thresh_ = np.min(labels[labels == 1])