From eca1063cc8d7cad2eb7a6849c975c67f240f5539 Mon Sep 17 00:00:00 2001
From: Alistair Johnson
Date: Mon, 17 Jun 2024 14:37:09 -0400
Subject: [PATCH] add arg-based BIDS loading, fix up search and demographics pages

---
 src/b2aiprep/app/Dashboard.py                 |  8 +-
 src/b2aiprep/app/pages/1_Demographics.py      | 99 +++++++++++++++++--
 .../app/pages/2_Subject_Questionnaires.py     | 34 +++----
 src/b2aiprep/app/pages/3_Audio.py             | 10 +-
 src/b2aiprep/app/pages/4_Validate.py          |  9 +-
 src/b2aiprep/app/pages/5_Search.py            | 35 ++++---
 6 files changed, 136 insertions(+), 59 deletions(-)

diff --git a/src/b2aiprep/app/Dashboard.py b/src/b2aiprep/app/Dashboard.py
index c446d5d..6969756 100644
--- a/src/b2aiprep/app/Dashboard.py
+++ b/src/b2aiprep/app/Dashboard.py
@@ -9,8 +9,6 @@
 from b2aiprep.dataset import VBAIDataset
 
-
-
 def parse_args(args):
     parser = argparse.ArgumentParser('Dashboard for audio data in BIDS format.')
     parser.add_argument('bids_dir', help='Folder with the BIDS data', default='output')
@@ -20,7 +18,11 @@ def parse_args(args):
 bids_dir = Path(args.bids_dir).resolve()
 if not bids_dir.exists():
     raise ValueError(f"Folder {bids_dir} does not exist.")
-dataset = VBAIDataset(bids_dir)
+
+if 'bids_dir' not in st.session_state:
+    st.session_state.bids_dir = bids_dir.as_posix()
+
+dataset = VBAIDataset(st.session_state.bids_dir)
 
 st.set_page_config(
     page_title="b2ai voice",
diff --git a/src/b2aiprep/app/pages/1_Demographics.py b/src/b2aiprep/app/pages/1_Demographics.py
index 7a1d23a..4b5d06a 100644
--- a/src/b2aiprep/app/pages/1_Demographics.py
+++ b/src/b2aiprep/app/pages/1_Demographics.py
@@ -20,16 +20,99 @@ def get_bids_data():
 schema_name = 'qgenericdemographicsschema'
 dataset = get_bids_data()
 df = dataset.load_and_pivot_questionnaire(schema_name)
-st.write(df)
-st.markdown("# Subsets")
-st.write(df.groupby('gender_identity').size().sort_values(ascending=False))
+# st.markdown("## Age Distribution")
+# st.write(df['age'].describe())
+
+
+st.markdown("## Gender Identity")
+gender_counts = df['gender_identity'].value_counts()
+st.bar_chart(gender_counts)
+
+
+st.markdown("## Sexual Orientation")
+orientation_counts = df['sexual_orientation'].value_counts()
+st.bar_chart(orientation_counts)
+
+
+st.markdown("## Race")
+race_columns = [col for col in df.columns if 'race___' in col]
+race_counts = df[race_columns].sum()
+st.bar_chart(race_counts)
+
+st.markdown("## Ethnicity")
+ethnicity_counts = df['ethnicity'].value_counts()
+st.bar_chart(ethnicity_counts)
+
+st.markdown("## Marital Status")
+marital_status_columns = [col for col in df.columns if 'marital_status___' in col]
+marital_status_counts = df[marital_status_columns].sum()
+st.bar_chart(marital_status_counts)
+
+
+st.markdown("## Employment Status")
+employ_status_columns = [col for col in df.columns if 'employ_status___' in col]
+employ_status_counts = df[employ_status_columns].sum()
+st.bar_chart(employ_status_counts)
+
+# we need to do some harmonization of the USA / CA salaries
+st.markdown("## Household Income")
 income = df[['household_income_usa', 'household_income_ca']].copy()
-income['household_income'] = None
-idx = income['household_income_usa'].notnull()
-income.loc[idx, 'household_income'] = 'USD ' + income.loc[idx, 'household_income_usa']
+income_cols = list(income.columns)
+
+# get the upper bound of each income bracket, or, if the answer is a single
+# value rather than a range, that value
+for col in income_cols:
+    # need to extract the *last* instance of this dollar pattern
+    income[f'{col}_lower'] = income[col].str.extract(r'\$(\d+,\d+)\s*$')
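+    # e.g. a bracket such as "$50,000 to $74,999" (an illustrative value, not
+    # necessarily the questionnaire's exact wording) yields the string "74,999"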
+    income[f'{col}_lower'] = income[f'{col}_lower'].str.replace(',', '')
+    income[f'{col}_lower'] = pd.to_numeric(income[f'{col}_lower'], errors='coerce')
+
+    # now create an integer which is higher if the value is higher
+    income[f'{col}_seq_num'] = income[f'{col}_lower'].rank(ascending=True, method='dense')
+    income[f'{col}_seq_num'] = income[f'{col}_seq_num'].fillna(-1).astype(int)
+
+    idxNan = income[col].str.contains('Prefer not to answer').fillna(False)
+    income.loc[idxNan, f'{col}_seq_num'] = 0
+income['seq_num'] = income[['household_income_usa_seq_num', 'household_income_ca_seq_num']].max(axis=1)
+
+# get our look-up dict for each country
+income_lookups = {}
+for col in income_cols:
+    income_lookups[col] = income[
+        [col, f'{col}_seq_num']
+    ].drop_duplicates().set_index(f'{col}_seq_num').to_dict()[col]
+
+income['country'] = 'Missing'
+idx = income['household_income_usa'].notnull()
+income.loc[idx, 'country'] = 'USA'
 idx = income['household_income_ca'].notnull()
-income.loc[idx, 'household_income'] = 'CAD ' + income.loc[idx, 'household_income_ca']
-st.write(income.groupby('household_income').size().sort_values(ascending=False))
\ No newline at end of file
+income.loc[idx, 'country'] = 'Canada'
+
+income_grouped = pd.crosstab(income['seq_num'], income['country'])
+# as it turns out, both countries have the same values for income brackets,
+# so we can just use one of the mapping tables (the last `col` from the loop)
+n_missing = (income['seq_num'] == -1).sum()
+income_grouped.index = income_grouped.index.map(income_lookups[col])
+income_grouped = income_grouped[['USA', 'Canada']]
+income_grouped.index.name = 'Household Income (CAD or USD)'
+# st.write(income_grouped)
+
+# grouped bar chart
+income_grouped = income_grouped.reset_index()
+income_grouped = income_grouped.melt(id_vars='Household Income (CAD or USD)', var_name='Country', value_name='Count')
+chart = (
+    alt.Chart(income_grouped)
+    .mark_bar()
+    .encode(
+        x=alt.X('Household Income (CAD or USD):O', axis=alt.Axis(title='Income')),
+        y=alt.Y('Count:Q', axis=alt.Axis(title='Count')),
+        color='Country:N',
+        tooltip=['Household Income (CAD or USD)', 'Count', 'Country']
+    )
+)
+st.altair_chart(chart, use_container_width=True)
+st.write(f"{n_missing} subjects are missing a household income.")
+
+
+st.markdown("## Full dataframe")
+st.write(df)
diff --git a/src/b2aiprep/app/pages/2_Subject_Questionnaires.py b/src/b2aiprep/app/pages/2_Subject_Questionnaires.py
index 6196059..e57d189 100644
--- a/src/b2aiprep/app/pages/2_Subject_Questionnaires.py
+++ b/src/b2aiprep/app/pages/2_Subject_Questionnaires.py
@@ -2,18 +2,17 @@
 import pandas as pd
 import altair as alt
 
-from b2aiprep.dataset import VBAIDataset
 from b2aiprep.constants import GENERAL_QUESTIONNAIRES
+from b2aiprep.dataset import VBAIDataset
 
 st.set_page_config(page_title="Subject questionnaires", page_icon="📊")
 
+
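+# cache the pivoted questionnaire dataframe; the leading underscore on `_dataset`
+# tells st.cache_data not to hash that argument (the dataset object itself is
+# not something Streamlit knows how to hash)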
 @st.cache_data
-def get_bids_data():
-    # TODO: allow user to specify input folder input
-    dataset = VBAIDataset('output')
-    return dataset
+def get_questionnaire_dataframe(_dataset: VBAIDataset, questionnaire_name: str):
+    return _dataset.load_and_pivot_questionnaire(questionnaire_name)
 
-dataset = get_bids_data()
+dataset = VBAIDataset(st.session_state.bids_dir)
 
 # st.markdown("# Disease prevalence")
@@ -23,12 +22,11 @@ def get_bids_data():
 
 st.markdown("# Subject questionnaires")
 
-
 questionnaire_name = st.selectbox(
     'Which questionnaire would you like to review?',
     GENERAL_QUESTIONNAIRES
 )
-df = dataset.load_and_pivot_questionnaire(questionnaire_name)
+df = get_questionnaire_dataframe(dataset, questionnaire_name)
 
 # there are always a lot of binary questions,
 # so we have a bar chart which includes all checked/unchecked or yes/no questions
@@ -61,23 +59,21 @@ def get_bids_data():
     grp['percentage'] = (grp['count'] * 100).round(1)
     grp.index.name = 'linkId'
 
-    # create bar chart in altair
-    # add a tick for each bar to make sure the labels are on the plot for linkId
-    chart = (
-        alt.Chart(grp)
-        .mark_bar()
-        .encode(
-            x=alt.X('linkId:O', axis=alt.Axis(title='Question', tickCount=grp.shape[0])),
-            y=alt.Y('percentage:Q', axis=alt.Axis(title='Percentage')),
-            tooltip=['linkId', 'percentage']
-        )
+    # create horizontal bar chart
+    chart = alt.Chart(grp).mark_bar().encode(
+        x='percentage:Q',
+        y=alt.Y('linkId:N', sort='-x'),
+        tooltip=['count', 'percentage']
+    ).properties(
+        width=600,
+        height=400
     )
     st.altair_chart(chart, use_container_width=True)
 
 st.markdown("# Bar plot")
 
-questions = df.columns.tolist()
+questions = df.columns.tolist()[1:]
 question_name = st.selectbox(
     'Which question would you like to display?',
     questions
diff --git a/src/b2aiprep/app/pages/3_Audio.py b/src/b2aiprep/app/pages/3_Audio.py
index 3da7767..c733282 100644
--- a/src/b2aiprep/app/pages/3_Audio.py
+++ b/src/b2aiprep/app/pages/3_Audio.py
@@ -9,17 +9,13 @@
 from b2aiprep.dataset import VBAIDataset
 from b2aiprep.process import Audio, specgram
 
-st.set_page_config(page_title="Audio", page_icon="📊")
+dataset = VBAIDataset(st.session_state.bids_dir)
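+# nb: st.session_state.bids_dir is seeded by Dashboard.py, so the app should be
+# launched from the Dashboard entrypoint rather than by opening this page first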
""" ) search_string = st.text_input("Search string", "age") @@ -113,24 +117,25 @@ def extract_descriptions(df: pd.DataFrame) -> t.Tuple[t.List[str], t.List[str]]: # Sort sentences by similarity score in descending order (the most similar ones are first) sorted_index = np.argsort(sims)[::-1] - -sentences_sorted = np.array(corpus)[sorted_index] field_ids_sorted = np.array(field_ids)[sorted_index] sims = np.array(sims)[sorted_index] -col1, col2 = st.columns(2) +final_df = rcdict.copy() +final_df = final_df.loc[final_df["Variable / Field Name"].isin(field_ids_sorted)] -with col1: - cutoff = st.number_input("Cutoff", 0.0, 1.0, 0.3) - plt.plot(sims) - plt.title("Cosine similarity") - st.pyplot(plt) +# map similarity into the dataframe +sim_mapper = {field_id: sim for field_id, sim in zip(field_ids_sorted, sims)} +final_df['similarity'] = final_df["Variable / Field Name"].map(sim_mapper) +cols_reordered = ["similarity"] + [c for c in final_df.columns if c != "similarity"] +final_df = final_df[cols_reordered] +final_df = final_df.sort_values("similarity", ascending=False) +cutoff = st.number_input("Cutoff (controls relevance of results)", 0.0, 1.0, 0.3) -with col2: - sentences_to_show = sentences_sorted[sims > cutoff].tolist() - field_ids_to_show = field_ids_sorted[sims > cutoff].tolist() - final_df = pd.DataFrame( - {"field_ids": field_ids_to_show, "field_desc": sentences_to_show} - ) - st.table(final_df) +# only show up to the cutoff +idx = final_df["similarity"] > cutoff +st.write(final_df.loc[idx]) + +plt.plot(sims) +plt.title("Cosine similarity") +st.pyplot(plt)