diff --git a/src/tests/catwalk_tests/test_protected_groups_generators.py b/src/tests/catwalk_tests/test_protected_groups_generators.py index 398dc5bcd..50c507017 100644 --- a/src/tests/catwalk_tests/test_protected_groups_generators.py +++ b/src/tests/catwalk_tests/test_protected_groups_generators.py @@ -10,12 +10,12 @@ def create_demographics_table(db_engine, data): db_engine.execute( """drop table if exists demographics; - create table demographics (person_id int, event_date date, race text, sex text) + create table demographics (person_id int, event_date date, race text, sex text, age_bucket int) """ ) for event in data: db_engine.execute( - "insert into demographics values (%s, %s, %s, %s)", event + "insert into demographics values (%s, %s, %s, %s, %s)", event ) @@ -31,15 +31,15 @@ def create_cohort_table(db_engine, data): def default_demographics(): return [ - (1, datetime(2015, 12, 30), 'aa', 'male'), - (1, datetime(2016, 2, 1), 'aa', 'male'), - (1, datetime(2016, 3, 1), 'aa', 'female'), - (2, datetime(2015, 12, 30), 'wh', 'male'), - (2, datetime(2016, 3, 1), 'wh', 'male'), - (3, datetime(2015, 12, 30), 'aa', 'male'), - (3, datetime(2016, 3, 1), 'aa', 'male'), - (5, datetime(2016, 2, 1), 'wh', 'female'), - (5, datetime(2016, 3, 1), 'wh', 'female'), + (1, datetime(2015, 12, 30), 'aa', 'male', 1), + (1, datetime(2016, 2, 1), 'aa', 'male', 1), + (1, datetime(2016, 3, 1), 'aa', 'female', 1), + (2, datetime(2015, 12, 30), 'wh', 'male', 3), + (2, datetime(2016, 3, 1), 'wh', 'male', 3), + (3, datetime(2015, 12, 30), 'aa', 'male', 1), + (3, datetime(2016, 3, 1), 'aa', 'male', 1), + (5, datetime(2016, 2, 1), 'wh', 'female', 2), + (5, datetime(2016, 3, 1), 'wh', 'female', 2), ] @@ -65,26 +65,26 @@ def default_cohort(): def assert_data(table_generator): expected_output = [ - (1, date(2016, 1, 1), 'aa', 'male', 'abcdef'), - (1, date(2016, 3, 1), 'aa', 'male', 'abcdef'), - (1, date(2016, 4, 1), 'aa', 'female', 'abcdef'), - (2, date(2016, 1, 1), 'wh', 'male', 'abcdef'), - (2, date(2016, 3, 1), 'wh', 'male', 'abcdef'), - (2, date(2016, 4, 1), 'wh', 'male', 'abcdef'), - (3, date(2016, 1, 1), 'aa', 'male', 'abcdef'), - (3, date(2016, 3, 1), 'aa', 'male', 'abcdef'), - (3, date(2016, 4, 1), 'aa', 'male', 'abcdef'), - (4, date(2016, 1, 1), None, None, 'abcdef'), - (4, date(2016, 3, 1), None, None, 'abcdef'), - (4, date(2016, 4, 1), None, None, 'abcdef'), - (5, date(2016, 1, 1), None, None, 'abcdef'), - (5, date(2016, 3, 1), 'wh', 'female', 'abcdef'), - (5, date(2016, 4, 1), 'wh', 'female', 'abcdef'), + (1, date(2016, 1, 1), 'aa', 'male', '1', 'abcdef'), + (1, date(2016, 3, 1), 'aa', 'male', '1', 'abcdef'), + (1, date(2016, 4, 1), 'aa', 'female', '1', 'abcdef'), + (2, date(2016, 1, 1), 'wh', 'male', '3', 'abcdef'), + (2, date(2016, 3, 1), 'wh', 'male', '3', 'abcdef'), + (2, date(2016, 4, 1), 'wh', 'male', '3', 'abcdef'), + (3, date(2016, 1, 1), 'aa', 'male', '1', 'abcdef'), + (3, date(2016, 3, 1), 'aa', 'male', '1', 'abcdef'), + (3, date(2016, 4, 1), 'aa', 'male', '1', 'abcdef'), + (4, date(2016, 1, 1), None, None, None, 'abcdef'), + (4, date(2016, 3, 1), None, None, None, 'abcdef'), + (4, date(2016, 4, 1), None, None, None, 'abcdef'), + (5, date(2016, 1, 1), None, None, None, 'abcdef'), + (5, date(2016, 3, 1), 'wh', 'female', '2', 'abcdef'), + (5, date(2016, 4, 1), 'wh', 'female', '2', 'abcdef'), ] results = list( table_generator.db_engine.execute( f""" - select entity_id, as_of_date, race, sex, cohort_hash + select entity_id, as_of_date, race, sex, age_bucket, cohort_hash from {table_generator.protected_groups_table_name} order by entity_id, as_of_date """ @@ -102,7 +102,7 @@ def test_protected_groups_generator_replace(): create_cohort_table(engine, cohort_data) table_generator = ProtectedGroupsGenerator( from_obj="demographics", - attribute_columns=['race', 'sex'], + attribute_columns=['race', 'sex', 'age_bucket'], entity_id_column="person_id", knowledge_date_column="event_date", db_engine=engine, @@ -138,7 +138,7 @@ def test_protected_groups_generator_noreplace(): create_cohort_table(engine, cohort_data) table_generator = ProtectedGroupsGenerator( from_obj="demographics", - attribute_columns=['race', 'sex'], + attribute_columns=['race', 'sex', 'age_bucket'], entity_id_column="person_id", knowledge_date_column="event_date", db_engine=engine, @@ -164,3 +164,40 @@ def test_protected_groups_generator_noreplace(): ) table_generator.generate.assert_not_called() assert_data(table_generator) + + +def test_as_dataframe(): + attribute_columns = ['race', 'sex', 'age_bucket'] + demographics_data = default_demographics() + cohort_data = default_cohort() + with testing.postgresql.Postgresql() as postgresql: + engine = create_engine(postgresql.url()) + create_demographics_table(engine, demographics_data) + create_cohort_table(engine, cohort_data) + table_generator = ProtectedGroupsGenerator( + from_obj="demographics", + attribute_columns=attribute_columns, + entity_id_column="person_id", + knowledge_date_column="event_date", + db_engine=engine, + protected_groups_table_name="protected_groups_abcdef", + replace=True + ) + as_of_dates = [ + datetime(2016, 1, 1), + datetime(2016, 3, 1), + datetime(2016, 4, 1), + ] + table_generator.generate_all_dates( + as_of_dates, + cohort_table_name='cohort_abcdef', + cohort_hash='abcdef' + ) + protected_df = table_generator.as_dataframe( + as_of_dates, + cohort_hash='abcdef' + ) + assert(protected_df.shape == (15, 3)) + assert(set(attribute_columns).issubset(protected_df.columns)) + for attr_col in attribute_columns: + assert(protected_df[attr_col].dtype == 'object') \ No newline at end of file diff --git a/src/triage/component/catwalk/protected_groups_generators.py b/src/triage/component/catwalk/protected_groups_generators.py index 02d0c8fcc..68a3b6b34 100644 --- a/src/triage/component/catwalk/protected_groups_generators.py +++ b/src/triage/component/catwalk/protected_groups_generators.py @@ -158,5 +158,6 @@ def as_dataframe(self, as_of_dates, cohort_hash): parse_dates=["as_of_date"], index_col=MatrixStore.indices, ) + protected_df[self.attribute_columns] = protected_df[self.attribute_columns].astype(str) del protected_df['cohort_hash'] return protected_df