Skip to content

Commit

Permalink
Ensure attribute columns are strings in protected_df (closes #875) (#876
Browse files Browse the repository at this point in the history
)

* ensure attribute columns are str type in protected_df

* add unit test for as_dataframe
  • Loading branch information
shaycrk authored Dec 7, 2021
1 parent b4ff916 commit a97c270
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 29 deletions.
95 changes: 66 additions & 29 deletions src/tests/catwalk_tests/test_protected_groups_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@
def create_demographics_table(db_engine, data):
db_engine.execute(
"""drop table if exists demographics;
create table demographics (person_id int, event_date date, race text, sex text)
create table demographics (person_id int, event_date date, race text, sex text, age_bucket int)
"""
)
for event in data:
db_engine.execute(
"insert into demographics values (%s, %s, %s, %s)", event
"insert into demographics values (%s, %s, %s, %s, %s)", event
)


Expand All @@ -31,15 +31,15 @@ def create_cohort_table(db_engine, data):

def default_demographics():
return [
(1, datetime(2015, 12, 30), 'aa', 'male'),
(1, datetime(2016, 2, 1), 'aa', 'male'),
(1, datetime(2016, 3, 1), 'aa', 'female'),
(2, datetime(2015, 12, 30), 'wh', 'male'),
(2, datetime(2016, 3, 1), 'wh', 'male'),
(3, datetime(2015, 12, 30), 'aa', 'male'),
(3, datetime(2016, 3, 1), 'aa', 'male'),
(5, datetime(2016, 2, 1), 'wh', 'female'),
(5, datetime(2016, 3, 1), 'wh', 'female'),
(1, datetime(2015, 12, 30), 'aa', 'male', 1),
(1, datetime(2016, 2, 1), 'aa', 'male', 1),
(1, datetime(2016, 3, 1), 'aa', 'female', 1),
(2, datetime(2015, 12, 30), 'wh', 'male', 3),
(2, datetime(2016, 3, 1), 'wh', 'male', 3),
(3, datetime(2015, 12, 30), 'aa', 'male', 1),
(3, datetime(2016, 3, 1), 'aa', 'male', 1),
(5, datetime(2016, 2, 1), 'wh', 'female', 2),
(5, datetime(2016, 3, 1), 'wh', 'female', 2),
]


Expand All @@ -65,26 +65,26 @@ def default_cohort():

def assert_data(table_generator):
expected_output = [
(1, date(2016, 1, 1), 'aa', 'male', 'abcdef'),
(1, date(2016, 3, 1), 'aa', 'male', 'abcdef'),
(1, date(2016, 4, 1), 'aa', 'female', 'abcdef'),
(2, date(2016, 1, 1), 'wh', 'male', 'abcdef'),
(2, date(2016, 3, 1), 'wh', 'male', 'abcdef'),
(2, date(2016, 4, 1), 'wh', 'male', 'abcdef'),
(3, date(2016, 1, 1), 'aa', 'male', 'abcdef'),
(3, date(2016, 3, 1), 'aa', 'male', 'abcdef'),
(3, date(2016, 4, 1), 'aa', 'male', 'abcdef'),
(4, date(2016, 1, 1), None, None, 'abcdef'),
(4, date(2016, 3, 1), None, None, 'abcdef'),
(4, date(2016, 4, 1), None, None, 'abcdef'),
(5, date(2016, 1, 1), None, None, 'abcdef'),
(5, date(2016, 3, 1), 'wh', 'female', 'abcdef'),
(5, date(2016, 4, 1), 'wh', 'female', 'abcdef'),
(1, date(2016, 1, 1), 'aa', 'male', '1', 'abcdef'),
(1, date(2016, 3, 1), 'aa', 'male', '1', 'abcdef'),
(1, date(2016, 4, 1), 'aa', 'female', '1', 'abcdef'),
(2, date(2016, 1, 1), 'wh', 'male', '3', 'abcdef'),
(2, date(2016, 3, 1), 'wh', 'male', '3', 'abcdef'),
(2, date(2016, 4, 1), 'wh', 'male', '3', 'abcdef'),
(3, date(2016, 1, 1), 'aa', 'male', '1', 'abcdef'),
(3, date(2016, 3, 1), 'aa', 'male', '1', 'abcdef'),
(3, date(2016, 4, 1), 'aa', 'male', '1', 'abcdef'),
(4, date(2016, 1, 1), None, None, None, 'abcdef'),
(4, date(2016, 3, 1), None, None, None, 'abcdef'),
(4, date(2016, 4, 1), None, None, None, 'abcdef'),
(5, date(2016, 1, 1), None, None, None, 'abcdef'),
(5, date(2016, 3, 1), 'wh', 'female', '2', 'abcdef'),
(5, date(2016, 4, 1), 'wh', 'female', '2', 'abcdef'),
]
results = list(
table_generator.db_engine.execute(
f"""
select entity_id, as_of_date, race, sex, cohort_hash
select entity_id, as_of_date, race, sex, age_bucket, cohort_hash
from {table_generator.protected_groups_table_name}
order by entity_id, as_of_date
"""
Expand All @@ -102,7 +102,7 @@ def test_protected_groups_generator_replace():
create_cohort_table(engine, cohort_data)
table_generator = ProtectedGroupsGenerator(
from_obj="demographics",
attribute_columns=['race', 'sex'],
attribute_columns=['race', 'sex', 'age_bucket'],
entity_id_column="person_id",
knowledge_date_column="event_date",
db_engine=engine,
Expand Down Expand Up @@ -138,7 +138,7 @@ def test_protected_groups_generator_noreplace():
create_cohort_table(engine, cohort_data)
table_generator = ProtectedGroupsGenerator(
from_obj="demographics",
attribute_columns=['race', 'sex'],
attribute_columns=['race', 'sex', 'age_bucket'],
entity_id_column="person_id",
knowledge_date_column="event_date",
db_engine=engine,
Expand All @@ -164,3 +164,40 @@ def test_protected_groups_generator_noreplace():
)
table_generator.generate.assert_not_called()
assert_data(table_generator)


def test_as_dataframe():
attribute_columns = ['race', 'sex', 'age_bucket']
demographics_data = default_demographics()
cohort_data = default_cohort()
with testing.postgresql.Postgresql() as postgresql:
engine = create_engine(postgresql.url())
create_demographics_table(engine, demographics_data)
create_cohort_table(engine, cohort_data)
table_generator = ProtectedGroupsGenerator(
from_obj="demographics",
attribute_columns=attribute_columns,
entity_id_column="person_id",
knowledge_date_column="event_date",
db_engine=engine,
protected_groups_table_name="protected_groups_abcdef",
replace=True
)
as_of_dates = [
datetime(2016, 1, 1),
datetime(2016, 3, 1),
datetime(2016, 4, 1),
]
table_generator.generate_all_dates(
as_of_dates,
cohort_table_name='cohort_abcdef',
cohort_hash='abcdef'
)
protected_df = table_generator.as_dataframe(
as_of_dates,
cohort_hash='abcdef'
)
assert(protected_df.shape == (15, 3))
assert(set(attribute_columns).issubset(protected_df.columns))
for attr_col in attribute_columns:
assert(protected_df[attr_col].dtype == 'object')
Original file line number Diff line number Diff line change
Expand Up @@ -158,5 +158,6 @@ def as_dataframe(self, as_of_dates, cohort_hash):
parse_dates=["as_of_date"],
index_col=MatrixStore.indices,
)
protected_df[self.attribute_columns] = protected_df[self.attribute_columns].astype(str)
del protected_df['cohort_hash']
return protected_df

0 comments on commit a97c270

Please sign in to comment.