Skip to content

Commit da059c5

Browse files
authored
Minor refactoring of KS test. (#41)
1 parent 2152b5b commit da059c5

File tree

4 files changed

+70
-55
lines changed

4 files changed

+70
-55
lines changed

src/datajudge/constraints/stats.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
import math
22
import warnings
3-
from typing import Optional, Tuple, Union
3+
from typing import Optional, Tuple
44

55
import sqlalchemy as sa
6-
from sqlalchemy.sql import Selectable
76

87
from .. import db_access
98
from ..db_access import DataReference
@@ -71,15 +70,20 @@ def c(alpha: float):
7170
@staticmethod
7271
def calculate_statistic(
7372
engine,
74-
table1_def: Tuple[Union[Selectable, str], str],
75-
table2_def: Tuple[Union[Selectable, str], str],
73+
ref1: DataReference,
74+
ref2: DataReference,
7675
) -> Tuple[float, Optional[float], int, int]:
7776

7877
# retrieve test statistic d, as well as sample sizes m and n
79-
d_statistic, n_samples, m_samples = db_access.get_ks_2sample(
80-
engine, table1=table1_def, table2=table2_def
78+
d_statistic = db_access.get_ks_2sample(
79+
engine,
80+
ref1,
81+
ref2,
8182
)
8283

84+
n_samples, _ = db_access.get_row_count(engine, ref1)
85+
m_samples, _ = db_access.get_row_count(engine, ref2)
86+
8387
# calculate approximate p-value
8488
p_value = KolmogorovSmirnov2Sample.approximate_p_value(
8589
d_statistic, n_samples, m_samples
@@ -90,13 +94,11 @@ def calculate_statistic(
9094
def test(self, engine: sa.engine.Engine) -> TestResult:
9195

9296
# get query selections and column names for target columns
93-
selection1 = self.ref.data_source.get_clause(engine)
94-
column1 = self.ref.get_column(engine)
95-
selection2 = self.ref2.data_source.get_clause(engine)
96-
column2 = self.ref2.get_column(engine)
9797

9898
d_statistic, p_value, n_samples, m_samples = self.calculate_statistic(
99-
engine, (selection1, column1), (selection2, column2)
99+
engine,
100+
self.ref,
101+
self.ref2,
100102
)
101103

102104
# calculate test acceptance

src/datajudge/db_access.py

Lines changed: 12 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -905,34 +905,28 @@ def get_column_array_agg(
905905

906906

907907
def get_ks_2sample(
908-
engine: sa.engine.Engine, table1: tuple, table2: tuple
909-
) -> tuple[float, int, int]:
908+
engine: sa.engine.Engine,
909+
ref1: DataReference,
910+
ref2: DataReference,
911+
) -> float:
910912
"""
911913
Runs the query for the two-sample Kolmogorov-Smirnov test and returns the test statistic d.
912914
"""
913-
914-
# make sure we have a string representation here
915-
table1_selection, col1 = str(table1[0]), str(table1[1])
916-
table2_selection, col2 = str(table2[0]), str(table2[1])
917-
918-
if is_mssql(engine): # "tempdb.dbo".table_name -> tempdb.dbo.table_name
919-
table1_selection = table1_selection.replace('"', "")
920-
table2_selection = table2_selection.replace('"', "")
921-
922-
# for RawQueryDataSource this could be a whole subquery and will therefore need to be wrapped
923-
if "SELECT" in table1_selection:
924-
table1_selection = f"({table1_selection})"
925-
table2_selection = f"({table2_selection})"
915+
# For mssql: "tempdb.dbo".table_name -> tempdb.dbo.table_name
916+
table1_str = str(ref1.data_source.get_clause(engine)).replace('"', "")
917+
col1 = ref1.get_column(engine)
918+
table2_str = str(ref2.data_source.get_clause(engine)).replace('"', "")
919+
col2 = ref2.get_column(engine)
926920

927921
# for a more extensive explanation, see:
928922
# https://github.com/Quantco/datajudge/pull/28#issuecomment-1165587929
929923
ks_query_string = f"""
930924
WITH
931925
tab1 AS ( -- Step 0: Prepare data source and value column
932-
SELECT {col1} as val FROM {table1_selection}
926+
SELECT {col1} as val FROM {table1_str}
933927
),
934928
tab2 AS (
935-
SELECT {col2} as val FROM {table2_selection}
929+
SELECT {col2} as val FROM {table2_str}
936930
),
937931
tab1_cdf AS ( -- Step 1: Calculate the CDF over the value column
938932
SELECT val, cume_dist() over (order by val) as cdf
@@ -980,14 +974,7 @@ def get_ks_2sample(
980974
"""
981975

982976
d_statistic = engine.execute(ks_query_string).scalar()
983-
n_samples = engine.execute(
984-
f"SELECT COUNT(*) FROM {table1_selection} as n_table"
985-
).scalar()
986-
m_samples = engine.execute(
987-
f"SELECT COUNT(*) FROM {table2_selection} as m_table"
988-
).scalar()
989-
990-
return d_statistic, n_samples, m_samples
977+
return d_statistic
991978

992979

993980
def get_regex_violations(engine, ref, aggregated, regex, n_counterexamples):

tests/integration/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -665,7 +665,7 @@ def groupby_aggregation_table_incorrect(engine, metadata):
665665
@pytest.fixture(scope="module")
666666
def random_normal_table(engine, metadata):
667667
"""
668-
Table containing 10_000 randomly distributed values with mean = 0 and std.dev = 1.
668+
Table with normally distributed values of varying means and sd 1.
669669
"""
670670
table_name = "random_normal_table"
671671
columns = [

tests/integration/test_integration.py

Lines changed: 44 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1897,25 +1897,59 @@ def test_diff_average_between():
18971897
@pytest.mark.parametrize(
18981898
"data",
18991899
[
1900-
(identity, "col_int", "col_int", None, 1.0),
1901-
(identity, "col_int", "col_int", Condition("col_int >= 3"), 1.0),
1900+
(identity, "col_int", "col_int", None, None, 1.0),
1901+
(
1902+
identity,
1903+
"col_int",
1904+
"col_int",
1905+
Condition("col_int >= 3"),
1906+
Condition("col_int >= 3"),
1907+
1.0,
1908+
),
19021909
],
19031910
)
19041911
def test_ks_2sample_constraint_perfect_between(engine, int_table1, data):
19051912
"""
19061913
Test Kolmogorov-Smirnov for the same column -> p-value should be perfect 1.0.
19071914
"""
1908-
(operation, col_1, col_2, condition, significance_level) = data
1915+
(operation, col_1, col_2, condition1, condition2, significance_level) = data
19091916
req = requirements.BetweenRequirement.from_tables(*int_table1, *int_table1)
19101917
req.add_ks_2sample_constraint(
19111918
column1=col_1,
19121919
column2=col_2,
1913-
condition1=condition,
1914-
condition2=condition,
1920+
condition1=condition1,
1921+
condition2=condition2,
19151922
significance_level=significance_level,
19161923
)
1924+
test_result = req[0].test(engine)
1925+
assert operation(test_result.outcome), test_result.failure_message
19171926

1918-
assert operation(req[0].test(engine).outcome)
1927+
1928+
# TODO: Enable this test once the bug is fixed.
1929+
@pytest.mark.skip(reason="This is a known bug and unintended behaviour.")
1930+
@pytest.mark.parametrize(
1931+
"data",
1932+
[
1933+
(negation, "col_int", "col_int", None, Condition("col_int >= 10"), 1.0),
1934+
],
1935+
)
1936+
def test_ks_2sample_constraint_perfect_between_different_condition(
1937+
engine, int_table1, data
1938+
):
1939+
"""
1940+
Test Kolmogorov-Smirnov for the same column -> p-value should be perfect 1.0.
1941+
"""
1942+
(operation, col_1, col_2, condition1, condition2, significance_level) = data
1943+
req = requirements.BetweenRequirement.from_tables(*int_table1, *int_table1)
1944+
req.add_ks_2sample_constraint(
1945+
column1=col_1,
1946+
column2=col_2,
1947+
condition1=condition1,
1948+
condition2=condition2,
1949+
significance_level=significance_level,
1950+
)
1951+
test_result = req[0].test(engine)
1952+
assert operation(test_result.outcome), test_result.failure_message
19191953

19201954

19211955
@pytest.mark.parametrize(
@@ -1933,8 +1967,8 @@ def test_ks_2sample_constraint_wrong_between(
19331967
req.add_ks_2sample_constraint(
19341968
column1=col_1, column2=col_2, significance_level=min_p_value
19351969
)
1936-
1937-
assert operation(req[0].test(engine).outcome)
1970+
test_result = req[0].test(engine)
1971+
assert operation(test_result.outcome), test_result.failure_message
19381972

19391973

19401974
@pytest.mark.parametrize(
@@ -1964,7 +1998,7 @@ def test_ks_2sample_random(engine, random_normal_table, configuration):
19641998
column1=col_1, column2=col_2, significance_level=min_p_value
19651999
)
19662000
test_result = req[0].test(engine)
1967-
assert operation(test_result.outcome)
2001+
assert operation(test_result.outcome), test_result.failure_message
19682002

19692003

19702004
@pytest.mark.parametrize(
@@ -1983,20 +2017,12 @@ def test_ks_2sample_implementation(engine, random_normal_table, configuration):
19832017
ref = DataReference(tds, columns=[col_1])
19842018
ref2 = DataReference(tds, columns=[col_2])
19852019

1986-
# retrieve table selections from data references
1987-
selection1 = ref.data_source.get_clause(engine)
1988-
column1 = ref.get_column(engine)
1989-
selection2 = ref2.data_source.get_clause(engine)
1990-
column2 = ref2.get_column(engine)
1991-
19922020
(
19932021
d_statistic,
19942022
p_value,
19952023
n_samples,
19962024
m_samples,
1997-
) = KolmogorovSmirnov2Sample.calculate_statistic(
1998-
engine, (selection1, column1), (selection2, column2)
1999-
)
2025+
) = KolmogorovSmirnov2Sample.calculate_statistic(engine, ref, ref2)
20002026

20012027
assert (
20022028
abs(d_statistic - expected_d) <= 1e-10

0 commit comments

Comments
 (0)