diff --git a/dowhy/gcm/stats.py b/dowhy/gcm/stats.py index 2a074eb708..ea43fad86b 100644 --- a/dowhy/gcm/stats.py +++ b/dowhy/gcm/stats.py @@ -25,6 +25,8 @@ def merge_p_values_average(p_values: Union[np.ndarray, List[float]], randomizati """ if len(p_values) == 0: raise ValueError("Given list of p-values is empty!") + if len(p_values) == 1: + return p_values[0] if np.all(np.isnan(p_values)): return float(np.nan) diff --git a/tests/gcm/test_stats.py b/tests/gcm/test_stats.py index 6b7597f27b..3221192af1 100644 --- a/tests/gcm/test_stats.py +++ b/tests/gcm/test_stats.py @@ -73,6 +73,7 @@ def test_given_invalid_inputs_when_merge_p_values_quantile_then_raises_error(): def test_when_merge_p_values_average_without_randomization_then_returns_expected_results(): assert merge_p_values_average([0]) == 0 assert merge_p_values_average([1]) == 1 + assert merge_p_values_average([0.3]) == 0.3 assert merge_p_values_average([0, 1]) == approx(1.0) assert merge_p_values_average([0, 0, 1]) == 0 assert merge_p_values_average([0, 0.5, 0.5, np.nan, 1, np.nan]) == approx(1.0) @@ -83,6 +84,7 @@ def test_when_merge_p_values_average_without_randomization_then_returns_expected def test_when_merge_p_values_average_with_randomization_then_returns_expected_results(): assert merge_p_values_average([0], randomization=True) == 0 assert merge_p_values_average([1], randomization=True) == 1 + assert merge_p_values_average([0.3], randomization=True) == 0.3 assert merge_p_values_average([0, 1], randomization=True) == approx(0.0, abs=0.01) assert merge_p_values_average([0, 0, 1], randomization=True) == approx(0.0, abs=0.01) assert merge_p_values_average([0, np.nan, 0, np.nan, 1, 1], randomization=True) == approx(0.0, abs=0.01)