
Commit e08267a

Implement Kolmogorov Smirnov Test in SQL-only (#28)
## Change

This PR implements the Kolmogorov-Smirnov test in pure SQL, which is then run directly on the database.

## Commits

* first integration: ks-test in database functionality
* integrate sql query with data refs
* formatting
* formatting
* refactoring: call `test` directly to access sql-result of KS test
* fix row count retrieval
* fix acceptance level domain error
* fix alpha adjustment
* fix type hints for python<3.10
* update sql query for postgres: all tables need to have an alias assigned to them
* fix: typo
* update query for mssql server
* add check for column names
* alternative way of getting table name, incl. hot fix for mssql quotation marks in table reference
* don't accept zero alphas since in practice they don't make much sense
* update variable naming and doc-strings
* update data retrieval
* include query nesting brackets
* better formatting for understandability
* better formatting for understandability
* update query for better readability with more WITH statements
* new option of passing values to the TestResult to compare these
* separate implementation testing from use case testing
* make independent of numpy
* update tests: new distributions, no scipy and numpy dependency, random numbers generated from seed for reproducibility
* update comment
* optional accuracy through scipy
* refactoring, clean up and formatting
* update comment and type hints
* update type hints for older python versions
* fix type hint: Tuple instead of tuple
* update changelog and include comment about scipy calculation
1 parent 5ddc874 commit e08267a

File tree

6 files changed: +305 −42 lines


CHANGELOG.rst (+7)

```diff
@@ -7,6 +7,13 @@
 Changelog
 =========
 
+1.1.1 - 2022.06.30
+------------------
+
+**New: SQL implementation for KS-test**
+
+- The Kolmogorov-Smirnov test is now implemented in pure SQL, shifting the computation to the database engine and improving performance substantially.
+
 1.1.0 - 2022.06.01
 ------------------
```
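For reference, the quantity the new SQL computes is the classical two-sample Kolmogorov-Smirnov statistic, the largest vertical gap between the two empirical CDFs:

```latex
% Two-sample KS statistic over the empirical CDFs F_{1,n} and F_{2,m}
D_{n,m} = \sup_x \left| F_{1,n}(x) - F_{2,m}(x) \right|
```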

src/datajudge/constraints/stats.py (+92 −30)

```diff
@@ -1,10 +1,13 @@
-from typing import Any, Collection, Optional, Tuple
+import math
+import warnings
+from typing import Optional, Tuple, Union
 
 import sqlalchemy as sa
+from sqlalchemy.sql import Selectable
 
 from .. import db_access
 from ..db_access import DataReference
-from .base import Constraint, OptionalSelections
+from .base import Constraint, TestResult
 
 
 class KolmogorovSmirnov2Sample(Constraint):
@@ -15,43 +18,102 @@ def __init__(
         super().__init__(ref, ref2=ref2)
 
     @staticmethod
-    def calculate_2sample_ks_test(data: Collection, data2: Collection) -> float:
+    def approximate_p_value(
+        d: float, n_samples: int, m_samples: int
+    ) -> Optional[float]:
         """
-        For two given lists of values calculates the Kolmogorov-Smirnov test.
-        Read more here: https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.kstest.html
+        Calculates the approximate p-value according to
+        'A procedure to find exact critical values of Kolmogorov-Smirnov Test', Silvia Fachinetti, 2009
+
+        Note: For environments with `scipy` installed, this method will return a quasi-exact p-value.
         """
+
+        # the approximation does not work for small sample sizes
+        samples = min(n_samples, m_samples)
+        if samples < 35:
+            warnings.warn(
+                "Approximating the p-value is not accurate enough for sample size < 35"
+            )
+            return None
+
+        # if scipy is installed, accurately calculate the p-value using the full distribution
         try:
-            from scipy.stats import ks_2samp
-        except ModuleNotFoundError:
-            raise ModuleNotFoundError(
-                "Calculating the Kolmogorov-Smirnov test relies on scipy."
-                "Therefore, please install scipy before using this test."
+            from scipy.stats.distributions import kstwo
+
+            approx_p = kstwo.sf(
+                d, round((n_samples * m_samples) / (n_samples + m_samples))
             )
+        except ModuleNotFoundError:
+            d_alpha = d * math.sqrt(samples)
+            approx_p = 2 * math.exp(-(d_alpha**2))
+
+        # clamp value to [0, 1]
+        return 1.0 if approx_p > 1.0 else 0.0 if approx_p < 0.0 else approx_p
+
+    @staticmethod
+    def check_acceptance(
+        d_statistic: float, n_samples: int, m_samples: int, accepted_level: float
+    ) -> bool:
+        """
+        For a given test statistic, d, and the respective sample sizes `n` and `m`, this function
+        checks whether the null hypothesis can be rejected for an accepted significance level.
+
+        For more information, check out the `Wikipedia entry <https://w.wiki/5May>`_.
+        """
 
-        # Currently, the calculation will be performed locally through scipy
-        # In future versions, an implementation where either the database engine
-        # (1) calculates the CDF
-        # or even (2) calculates the KS test
-        # can be expected
-        statistic, p_value = ks_2samp(data, data2)
+        def c(alpha: float):
+            return math.sqrt(-math.log(alpha / 2.0 + 1e-10) * 0.5)
 
-        return p_value
+        return d_statistic <= c(accepted_level) * math.sqrt(
+            (n_samples + m_samples) / (n_samples * m_samples)
+        )
+
+    @staticmethod
+    def calculate_statistic(
+        engine,
+        table1_def: Tuple[Union[Selectable, str], str],
+        table2_def: Tuple[Union[Selectable, str], str],
+    ) -> Tuple[float, Optional[float], int, int]:
+
+        # retrieve test statistic d, as well as sample sizes n and m
+        d_statistic, n_samples, m_samples = db_access.get_ks_2sample(
+            engine, table1=table1_def, table2=table2_def
+        )
+
+        # calculate approximate p-value
+        p_value = KolmogorovSmirnov2Sample.approximate_p_value(
+            d_statistic, n_samples, m_samples
+        )
 
-    def retrieve(
-        self, engine: sa.engine.Engine, ref: DataReference
-    ) -> Tuple[Any, OptionalSelections]:
-        return db_access.get_column(engine, ref)
+        return d_statistic, p_value, n_samples, m_samples
 
-    def compare(
-        self, value_factual: Any, value_target: Any
-    ) -> Tuple[bool, Optional[str]]:
+    def test(self, engine: sa.engine.Engine) -> TestResult:
+
+        # get query selections and column names for target columns
+        selection1 = self.ref.data_source.get_clause(engine)
+        column1 = self.ref.get_column(engine)
+        selection2 = self.ref2.data_source.get_clause(engine)
+        column2 = self.ref2.get_column(engine)
+
+        d_statistic, p_value, n_samples, m_samples = self.calculate_statistic(
+            engine, (selection1, column1), (selection2, column2)
+        )
+
+        # calculate test acceptance
+        result = self.check_acceptance(
+            d_statistic, n_samples, m_samples, self.significance_level
+        )
 
-        p_value = self.calculate_2sample_ks_test(value_factual, value_target)
-        result = p_value >= self.significance_level
         assertion_text = (
-            f"2-Sample Kolmogorov-Smirnov between {self.ref.get_string()} and {self.target_prefix}"
-            f"has p-value {p_value} < {self.significance_level}"
-            f"{self.condition_string}"
+            f"Null hypothesis (H0) for the 2-sample Kolmogorov-Smirnov test was rejected, i.e., "
+            f"the two samples ({self.ref.get_string()} and {self.target_prefix})"
+            f" do not originate from the same distribution. "
+            f"The test results are d={d_statistic}"
         )
+        if p_value is not None:
+            assertion_text += f" and {p_value=}"
+
+        if not result:
+            return TestResult.failure(assertion_text)
 
-        return result, assertion_text
+        return TestResult.success()
```
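The two helpers above boil down to a few lines of arithmetic. As a standalone sanity check, here is a minimal sketch of the asymptotic branch (the path taken when `scipy` is not installed) and the acceptance criterion, with hypothetical numbers; it drops the `1e-10` guard against `log(0)` that the production code keeps:

```python
import math

def approx_p_value(d: float, n: int, m: int) -> float:
    # asymptotic two-sample KS p-value; only reliable for min(n, m) >= 35
    d_alpha = d * math.sqrt(min(n, m))
    p = 2 * math.exp(-(d_alpha**2))
    return min(max(p, 0.0), 1.0)  # clamp to [0, 1]

def accept_h0(d: float, n: int, m: int, alpha: float) -> bool:
    # retain H0 iff d <= c(alpha) * sqrt((n + m) / (n * m))
    c_alpha = math.sqrt(-math.log(alpha / 2.0) * 0.5)
    return d <= c_alpha * math.sqrt((n + m) / (n * m))

# hypothetical values: d = 0.04 with two samples of 10_000 rows each
print(approx_p_value(0.04, 10_000, 10_000))   # ~2.3e-07
print(accept_h0(0.04, 10_000, 10_000, 0.05))  # False, critical value is ~0.0192
```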

src/datajudge/db_access.py (+86)

```diff
@@ -902,3 +902,89 @@ def get_column_array_agg(
         for t in result
     ]
     return result, selections
+
+
+def get_ks_2sample(
+    engine: sa.engine.Engine, table1: tuple, table2: tuple
+) -> tuple[float, int, int]:
+    """
+    Runs the query for the two-sample Kolmogorov-Smirnov test and returns the
+    test statistic d together with both sample sizes.
+    """
+
+    # make sure we have a string representation here
+    table1_selection, col1 = str(table1[0]), str(table1[1])
+    table2_selection, col2 = str(table2[0]), str(table2[1])
+
+    if is_mssql(engine):  # "tempdb.dbo".table_name -> tempdb.dbo.table_name
+        table1_selection = table1_selection.replace('"', "")
+        table2_selection = table2_selection.replace('"', "")
+
+    # for RawQueryDataSource this could be a whole subquery and will therefore need to be wrapped
+    if "SELECT" in table1_selection:
+        table1_selection = f"({table1_selection})"
+        table2_selection = f"({table2_selection})"
+
+    # for a more extensive explanation, see:
+    # https://github.com/Quantco/datajudge/pull/28#issuecomment-1165587929
+    ks_query_string = f"""
+    WITH
+    tab1 AS ( -- Step 0: Prepare data source and value column
+        SELECT {col1} as val FROM {table1_selection}
+    ),
+    tab2 AS (
+        SELECT {col2} as val FROM {table2_selection}
+    ),
+    tab1_cdf AS ( -- Step 1: Calculate the CDF over the value column
+        SELECT val, cume_dist() over (order by val) as cdf
+        FROM tab1
+    ),
+    tab2_cdf AS (
+        SELECT val, cume_dist() over (order by val) as cdf
+        FROM tab2
+    ),
+    tab1_grouped AS ( -- Step 2: Remove unnecessary values, s.t. we have (x, cdf(x)) rows only
+        SELECT val, MAX(cdf) as cdf
+        FROM tab1_cdf
+        GROUP BY val
+    ),
+    tab2_grouped AS (
+        SELECT val, MAX(cdf) as cdf
+        FROM tab2_cdf
+        GROUP BY val
+    ),
+    joined_cdf AS ( -- Step 3: Combine the CDFs
+        SELECT coalesce(tab1_grouped.val, tab2_grouped.val) as v, tab1_grouped.cdf as cdf1, tab2_grouped.cdf as cdf2
+        FROM tab1_grouped FULL OUTER JOIN tab2_grouped ON tab1_grouped.val = tab2_grouped.val
+    ),
+    -- Step 4: Create a grouper id based on the value count; this is just a helper for forward-filling
+    grouped_cdf AS (
+        SELECT v,
+               COUNT(cdf1) over (order by v) as _grp1,
+               cdf1,
+               COUNT(cdf2) over (order by v) as _grp2,
+               cdf2
+        FROM joined_cdf
+    ),
+    -- Step 5: Forward-filling: select the first non-null value per group (defined in the previous step)
+    filled_cdf AS (
+        SELECT v,
+               first_value(cdf1) over (partition by _grp1 order by v) as cdf1_filled,
+               first_value(cdf2) over (partition by _grp2 order by v) as cdf2_filled
+        FROM grouped_cdf
+    ),
+    -- Step 6: Replace NULL values (at the beginning) with 0 to calculate the difference
+    replaced_nulls AS (
+        SELECT coalesce(cdf1_filled, 0) as cdf1, coalesce(cdf2_filled, 0) as cdf2
+        FROM filled_cdf
+    )
+    -- Step 7: Calculate the final statistic: the maximum distance
+    SELECT MAX(ABS(cdf1 - cdf2)) FROM replaced_nulls;
+    """
+
+    d_statistic = engine.execute(ks_query_string).scalar()
+    n_samples = engine.execute(
+        f"SELECT COUNT(*) FROM {table1_selection} as n_table"
+    ).scalar()
+    m_samples = engine.execute(
+        f"SELECT COUNT(*) FROM {table2_selection} as m_table"
+    ).scalar()
+
+    return d_statistic, n_samples, m_samples
```
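Because the forward-filling logic in Steps 4-6 is easy to get wrong, a pure-Python reference that mirrors the query's steps is useful as a cross-check. The sketch below is illustrative only and not part of the repository:

```python
import random

def ks_d_reference(sample1: list, sample2: list) -> float:
    """Mirror of the SQL query's steps, for cross-checking its result."""
    n, m = len(sample1), len(sample2)
    # Steps 1-2: per-sample CDF at each distinct value (cume_dist + GROUP BY)
    cdf1 = {v: sum(x <= v for x in sample1) / n for v in set(sample1)}
    cdf2 = {v: sum(x <= v for x in sample2) / m for v in set(sample2)}
    # Steps 3-6: full outer join on values, forward-fill gaps, NULL -> 0
    d = last1 = last2 = 0.0
    for v in sorted(set(sample1) | set(sample2)):
        last1 = cdf1.get(v, last1)
        last2 = cdf2.get(v, last2)
        # Step 7: keep the maximum absolute distance between the CDFs
        d = max(d, abs(last1 - last2))
    return d

random.seed(0)
a = [random.gauss(0, 1) for _ in range(1_000)]
b = [random.gauss(1, 1) for _ in range(1_000)]
print(ks_d_reference(a, b))  # large d: the two samples clearly differ
```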

src/datajudge/requirements.py (+7 −2)

```diff
@@ -1268,9 +1268,14 @@ def add_ks_2sample_constraint(
     The significance_level must be a value between 0.0 and 1.0.
     """
 
-    if significance_level < 0.0 or significance_level > 1.0:
+    if not column1 or not column2:
         raise ValueError(
-            "The requested significance level has to be between 0.0 and 1.0. Default is 0.05."
+            "Column names have to be given for this test's functionality."
+        )
+
+    if significance_level <= 0.0 or significance_level > 1.0:
+        raise ValueError(
+            "The requested significance level has to be in `(0.0, 1.0]`. Default is 0.05."
         )
 
     ref = DataReference(self.data_source, [column1], condition=condition1)
```
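For completeness, end-to-end usage might look roughly like the following sketch; the `from_tables` argument order is an assumption inferred from the fixtures below, not taken from this diff:

```python
from datajudge import BetweenRequirement

# hypothetical database, schema, and table names
req = BetweenRequirement.from_tables(
    "db", "schema", "table_a",  # first data source
    "db", "schema", "table_b",  # second data source
)
# both column names are required; significance_level must lie in (0.0, 1.0]
req.add_ks_2sample_constraint(
    column1="value", column2="value", significance_level=0.05
)
```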

tests/integration/conftest.py (+32)

```diff
@@ -1,6 +1,7 @@
 import datetime
 import itertools
 import os
+import random
 import urllib.parse
 
 import pytest
@@ -661,6 +662,37 @@ def groupby_aggregation_table_incorrect(engine, metadata):
     return TEST_DB_NAME, SCHEMA, table_name
 
 
+@pytest.fixture(scope="module")
+def random_normal_table(engine, metadata):
+    """
+    Table with four columns of 10_000 normally distributed values each: std. dev. 1, means 0, 0.05, 0.2, and 1.
+    """
+    table_name = "random_normal_table"
+    columns = [
+        sa.Column("value_0_1", sa.Float()),
+        sa.Column("value_005_1", sa.Float()),
+        sa.Column("value_02_1", sa.Float()),
+        sa.Column("value_1_1", sa.Float()),
+    ]
+    row_size = 10_000
+    random.seed(0)
+    rand1 = [random.gauss(0, 1) for _ in range(row_size)]
+    rand2 = [random.gauss(0.05, 1) for _ in range(row_size)]
+    rand3 = [random.gauss(0.2, 1) for _ in range(row_size)]
+    rand4 = [random.gauss(1, 1) for _ in range(row_size)]
+    data = [
+        {
+            "value_0_1": rand1[idx],
+            "value_005_1": rand2[idx],
+            "value_02_1": rand3[idx],
+            "value_1_1": rand4[idx],
+        }
+        for idx in range(row_size)
+    ]
+    _handle_table(engine, metadata, table_name, columns, data)
+    return TEST_DB_NAME, SCHEMA, table_name
+
+
 @pytest.fixture(scope="module")
 def capitalization_table(engine, metadata):
     table_name = "capitalization"
```
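A test consuming this fixture could then look roughly like the sketch below; the `engine` fixture wiring and the `TestResult.outcome` attribute are assumptions, not shown in this diff:

```python
from datajudge import BetweenRequirement

def test_ks_2sample_rejects_shifted_mean(engine, random_normal_table):
    db_name, schema, table = random_normal_table
    # compare the table against itself on two differently shifted columns
    req = BetweenRequirement.from_tables(
        db_name, schema, table, db_name, schema, table
    )
    # mean 0 vs. mean 1 at std. dev. 1: H0 should clearly be rejected
    req.add_ks_2sample_constraint(
        column1="value_0_1", column2="value_1_1", significance_level=0.05
    )
    assert not req[0].test(engine).outcome  # assumed TestResult API
```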
