From b936cc6fab478b586f0023eade706e64759f6bbd Mon Sep 17 00:00:00 2001 From: Google Health Date: Tue, 21 Mar 2023 16:37:19 -0700 Subject: [PATCH] Internal change PiperOrigin-RevId: 518415186 --- .../CONTRIBUTING.md | 1 + .../README.md | 10 + .../cluster_utils.py | 352 +++++ .../data_utils.py | 72 + .../demo.ipynb | 1398 +++++++++++++++++ .../requirements.txt | 5 + 6 files changed, 1838 insertions(+) create mode 100644 colorectal_lymph_node_metastasis_prediction/CONTRIBUTING.md create mode 100644 colorectal_lymph_node_metastasis_prediction/README.md create mode 100644 colorectal_lymph_node_metastasis_prediction/cluster_utils.py create mode 100644 colorectal_lymph_node_metastasis_prediction/data_utils.py create mode 100644 colorectal_lymph_node_metastasis_prediction/demo.ipynb create mode 100644 colorectal_lymph_node_metastasis_prediction/requirements.txt diff --git a/colorectal_lymph_node_metastasis_prediction/CONTRIBUTING.md b/colorectal_lymph_node_metastasis_prediction/CONTRIBUTING.md new file mode 100644 index 00000000..6d6c95b5 --- /dev/null +++ b/colorectal_lymph_node_metastasis_prediction/CONTRIBUTING.md @@ -0,0 +1 @@ +We are not accepting contributions for this project. diff --git a/colorectal_lymph_node_metastasis_prediction/README.md b/colorectal_lymph_node_metastasis_prediction/README.md new file mode 100644 index 00000000..72d9d242 --- /dev/null +++ b/colorectal_lymph_node_metastasis_prediction/README.md @@ -0,0 +1,10 @@ +# Lymph node metastasis prediction: machine-learned feature generation/selection and model evaluation + +This repo contains the code needed to generate and select cluster-based +machine-learned features while controlling for baseline features as described +in "Predicting lymph node metastasis from primary tumor histology and clinicopathologic factors in colorectal cancer using deep learning." For an example of how this +code may be used, see demo.ipynb. 
+ +NOTE: the content of this research code repository (i) is not intended to be a +medical device; and (ii) is not intended for clinical use of any kind, including +but not limited to diagnosis or prognosis. \ No newline at end of file diff --git a/colorectal_lymph_node_metastasis_prediction/cluster_utils.py b/colorectal_lymph_node_metastasis_prediction/cluster_utils.py new file mode 100644 index 00000000..8f9109fa --- /dev/null +++ b/colorectal_lymph_node_metastasis_prediction/cluster_utils.py @@ -0,0 +1,352 @@ +"""Utils needed to generate machine-learned features via clustering. + +For example use, see demo.ipynb. +""" + +import numpy as np +import pandas as pd +import scipy +import sklearn +from sklearn.cluster import MiniBatchKMeans +import statsmodels.api as sm + + +_CASE_ID = 'case_id' + + +def train_k_means_model(embedding_dict, k, batch_size=10000): + """Generate KMeans models with k clusters. + + Args: + embedding_dict: dict mapping case id: embeddings. Embeddings have shape + [num_patches, emb_dimensions]. + k: number of clusters. + batch_size: size of batch to use for MiniBatchKMeans training. + + Returns: + Trained kmeans model. + """ + x = [] + for case_id in embedding_dict: + x.append(embedding_dict[case_id]) + x = np.concatenate(x) + print(f'Embeddings shape: {x.shape}') + + return MiniBatchKMeans( + n_clusters=k, random_state=0, batch_size=batch_size).fit(x) + + +def get_cluster_quantitation_df(embedding_dict, model): + """Get case-level cluster quantitation vectors. + + Computes the fraction of patches for each case that belong to each cluster. + + Args: + embedding_dict: dict mapping case id: embeddings. Embeddings have shape + [num_patches, emb_dimensions]. + model: trained kmeans model. + + Returns: + pd.DataFrame of cluster quantitation vectors. 
+ """ + cq = {} + for case_id, embeddings in embedding_dict.items(): + cluster_distances = model.transform(embeddings) + cq[case_id] = _distances_to_cluster_quantitation(cluster_distances) + + df = pd.DataFrame.from_dict(cq, orient='index') + cols = list(df.columns) + df[_CASE_ID] = df.index + df = df[[_CASE_ID] + cols] + return df + + +def select_top_clusters( + df_train, + df_valid, + label_col, + baseline_cols, + cluster_cols, + n): + """Select top clusters and return these clusters with respective AUCs. + + The set of n `cluster_cols` that lead to the greatest gain in AUC over + `baseline_cols` on `df_valid` are chosen via forward stepwise selection. + + Args: + df_train: pd.Dataframe with training data. + df_valid: pd.Dataframe with validation data. + label_col: column to use for labels. + baseline_cols: a list of column names in `df` corresponding to baseline + features. + cluster_cols: a list of column names in `df` corresponding to cluster + quantitation features. + n: number of clusters to select. + + Returns: + pd.DataFrame of cluster ids and AUCs. + """ + cluster_cols = cluster_cols.copy() + selected_cluster_cols = [] + results = [] + + for i in range(n): + cluster_id, auc = _select_next_cluster( + df_train=df_train, + df_valid=df_valid, + label_col=label_col, + baseline_cols=baseline_cols, + selected_cluster_cols=selected_cluster_cols, + candidate_cluster_cols=cluster_cols) + selected_cluster_cols.append(cluster_id) + cluster_cols.remove(cluster_id) + results.append({'order': i, 'cluster_id': cluster_id, 'auc': auc}) + return pd.DataFrame(results) + + +def likelihood_ratio_test( + df, + label_col, + baseline_cols, + cluster_cols): + """Likelihood ratio test for significance of `cluster_cols`. + + Likelihood ratio test comparing the full model fit on the combination of + `baseline_cols` and `cluster_cols` and the null model fit only on + `baseline_cols`. + + Args: + df: pd.Dataframe with data on which to fit logistic regression models. 
+ label_col: column to use for labels. + baseline_cols: a list of column names in `df` corresponding to baseline + features. + cluster_cols: a list of column names in `df` corresponding to cluster + quantitation features. + + Returns: + p-value of likelihood ratio test between alternative hypothesis (with + clusters) model and null hypothesis (without clusters) model. + """ + lr_alt, _, _ = train_eval_lr( + df_train=df, + df_valid=df, + label_col=label_col, + baseline_cols=baseline_cols, + cluster_cols=cluster_cols) + lr_null, _, _ = train_eval_lr( + df_train=df, + df_valid=df, + label_col=label_col, + baseline_cols=baseline_cols, + cluster_cols=[]) + + def lrt(ll_alt, ll_null, k): + test_stat = 2 * (ll_alt - ll_null) + return scipy.stats.chi2.sf(test_stat, k) + + return lrt(lr_alt.llf, lr_null.llf, len(cluster_cols)) + + +def get_odds_ratios_p_values(df, label_col, baseline_cols, cluster_cols): + """Train LR model and return odds ratios and p-values of model parameters. + + Args: + df: pd.Dataframe with data on which to fit logistic regression models. + label_col: column to use for labels. + baseline_cols: a list of column names in `df` corresponding to baseline + features. + cluster_cols: a list of column names in `df` corresponding to cluster + quantitation features. + + Returns: + pd.Dataframe containing model parameters with their odds ratios and + p-values. 
+ """ + lr, _, _ = train_eval_lr( + df_train=df, + df_valid=df, + label_col=label_col, + baseline_cols=baseline_cols, + cluster_cols=cluster_cols) + + point = np.exp(lr.params).apply(lambda x: f'{x:.2f}') + lower = np.exp(lr.conf_int()[0]).apply(lambda x: f'{x:.2f}') + upper = np.exp(lr.conf_int()[1]).apply(lambda x: f'{x:.2f}') + or_ci = point + ' ' + '[' + lower + ', ' + upper + ']' + + def p_value_to_str(p, num_digits=3): + min_value = 10**(-num_digits) + if p < min_value: + return f'<{min_value}' + pattern = f'%.{str(num_digits)}f' + return pattern % p + + p = lr.pvalues.apply(p_value_to_str) + return pd.DataFrame({'OR': or_ci, 'p': p}) + + +def get_eval_aucs(df_train, df_valid, label_col, baseline_cols, cluster_cols): + """Evaluate models' predictive performance. + + Trains logistic regression models with baseline_cols, cluster_cols, and both + sets of features and computes AUC on separate validation dataset. + + Args: + df_train: pd.Dataframe with training data. + df_valid: pd.Dataframe with validation data (e.g., `validation` or `test`). + label_col: column to use for labels. + baseline_cols: a list of column names in `df_train` and `df_valid` + corresponding to baseline features. + + cluster_cols: a list of column names in `df_train` and `df_valid` + corresponding to cluster quantitation features. + + + Returns: + pd.Dataframe containing AUCs. 
+ """ + _, _, auc_baseline = train_eval_lr( + df_train=df_train, + df_valid=df_valid, + label_col=label_col, + baseline_cols=baseline_cols, + cluster_cols=[]) + _, _, auc_cluster = train_eval_lr( + df_train=df_train, + df_valid=df_valid, + label_col=label_col, + baseline_cols=[], + cluster_cols=cluster_cols) + _, _, auc_all = train_eval_lr( + df_train=df_train, + df_valid=df_valid, + label_col=label_col, + baseline_cols=baseline_cols, + cluster_cols=cluster_cols) + return pd.DataFrame({ + 'Baseline features only': [auc_baseline], + 'Cluster features only': [auc_cluster], + 'Baseline + cluster features': [auc_all]}, index=['AUC']) + + +def train_eval_lr(df_train, df_valid, label_col, + baseline_cols, cluster_cols): + """Train and evaluate LR. + + Train logistic regression model on `df_train` using specified `baseline_cols` + and `cluster_cols`, then evaluate performance on `df_valid`. + + Args: + df_train: pd.Dataframe with training data. + df_valid: pd.Dataframe with validation data. + label_col: column to use for labels. + baseline_cols: a list of column names in `df` corresponding to baseline + features. + cluster_cols: a list of column names in `df` corresponding to cluster + quantitation features. + + Returns: + tuple: (LR model, validation set predictions, AUC evaluated on `df_valid`). + """ + x_train, y_train = _get_lr_data( + df_train, + label_col, + baseline_cols=baseline_cols, + cluster_cols=cluster_cols) + x_valid, y_valid = _get_lr_data( + df_valid, + label_col, + baseline_cols=baseline_cols, + cluster_cols=cluster_cols) + lr = train_lr(x_train, y_train) + y_hat = lr.predict(x_valid) + auc = sklearn.metrics.roc_auc_score(y_valid, y_hat) + return lr, y_hat, auc + + +def train_lr(x, y): + """Returns trained logistic regression model.""" + return sm.Logit(y, x).fit(disp=0) + + +def _distances_to_cluster_quantitation(cluster_distances): + """Converts distances to cluster centroids to cluster quantitation vector. 
+ + Args: + cluster_distances: ndarray with shape [num_patches, k]. + + Returns: + ndarray with shape [k] reflecting percent of patches assigned to each + cluster in the case. + """ + if len(cluster_distances.shape) != 2: + raise ValueError('Expect cluster distances to be of rank 2') + k = cluster_distances.shape[1] + min_distances = cluster_distances.min(axis=1, keepdims=True) + min_distances = np.tile(min_distances, [1, k]) + cluster_quants = np.mean((cluster_distances == min_distances), axis=0) + assert cluster_quants.shape == (k,) + return cluster_quants + + +def _get_lr_data(df, label_col, baseline_cols=None, cluster_cols=None): + """Get X and y np.arrays for fitting logistic regression. + + Args: + df: pd.Dataframe with labels, baseline features, and cluster features. + label_col: name of column in `df` to use for labels. + baseline_cols: a list of column names in `df` corresponding to baseline + features. + cluster_cols: a list of column names in `df` corresponding to cluster + quantitation features. + + Returns: + Tuple of (features, labels). + """ + x = df[baseline_cols + cluster_cols] + y = df[label_col] + return x, y + + +def _select_next_cluster( + df_train, + df_valid, + label_col, + baseline_cols, + selected_cluster_cols, + candidate_cluster_cols): + """Select next best candidate cluster. + + Selects the next cluster from `candidate_cluster_cols` that gives the best AUC + on `df_valid` when added to a logistic regression model trained on `df_train` + using `baseline_cols` and `selected_cluster_cols` as features. + + Args: + df_train: pd.Dataframe with training data. + df_valid: pd.Dataframe with validation data. + label_col: column to use for labels. + baseline_cols: a list of column names in `df` corresponding to baseline + features. + selected_cluster_cols: a list of column names in `df` corresponding to + candidate cluster quantitation features that have already been added to + the model. 
+ candidate_cluster_cols: a list of column names in `df` corresponding to + cluster quantitation features that are candidates for being added to the + model. + + Returns: + Tuple: (next best cluster, AUC obtained when using this cluster). + """ + aucs = {} + AUC_INDEX = 2 # pylint: disable=invalid-name + for cluster_col in candidate_cluster_cols: + assert cluster_col not in selected_cluster_cols + cluster_cols = selected_cluster_cols + [cluster_col] + aucs[cluster_col] = train_eval_lr( + df_train=df_train, + df_valid=df_valid, + label_col=label_col, + baseline_cols=baseline_cols, + cluster_cols=cluster_cols)[AUC_INDEX] + top_cluster_id = max(aucs, key=aucs.get) + return top_cluster_id, aucs[top_cluster_id] diff --git a/colorectal_lymph_node_metastasis_prediction/data_utils.py b/colorectal_lymph_node_metastasis_prediction/data_utils.py new file mode 100644 index 00000000..2f9169d6 --- /dev/null +++ b/colorectal_lymph_node_metastasis_prediction/data_utils.py @@ -0,0 +1,72 @@ +"""Data-processing utils needed to prep data for feature generation and selection. + +For example use, see demo.ipynb. +""" + +import pandas as pd + + +def bin_age(age, start_cutoff=60, stop_cutoff=80, increment=10): + """Categorize age by bins. + + Args: + age: age to bin. + start_cutoff: first cutoff to use forming bins. + stop_cutoff: last cutoff to use forming bins. + increment: difference in age between bins. + + Returns: + string representing binned age. + """ + if pd.isnull([age]): + return age + age = float(age) + + last_cutoff = 0 + for age_cutoff in range(start_cutoff, stop_cutoff + increment, increment): + if age < age_cutoff: + return f'{last_cutoff}-{age_cutoff-1}' + last_cutoff = age_cutoff + return f'>={stop_cutoff}' + + +def prep_features(df, feature_cols, label_cols): + """Returns a pd.DataFrame suitable for modeling. + + 1) Select only desired `feature_cols`. + 2) Remove rows with nan values. + 3) Convert categorical features to dummy variables. 
+ 4) Remove constant `feature_cols`. + + Args: + df: pd.DataFrame containing `feature_cols`, `label_cols`. + feature_cols: a list of columns in `df` containing regression features. + label_cols: a list of column names in `df` containing labels. These columns + are not coded as dummies if they are categorical. + + Returns: + tuple of (dataframe, expanded feature cols). + """ + # Select subset of required columns. + df = df.copy()[label_cols + feature_cols] + # Remove rows with missing values. + n_rows = df.shape[0] + df = df.dropna() + delta = n_rows - df.shape[0] + if delta > 0: + print('Dropped %d rows due to missing values.' % delta) + + # Convert categorical cols to dummy vars. + df_labels = df[label_cols] + df = pd.get_dummies(df[feature_cols], drop_first=True) + expanded_feature_cols = list(df.columns) + + # Remove constant feature columns + df = df.loc[:, (df != df.iloc[0]).any()] + delta = set(expanded_feature_cols) - set(df.columns) + if delta: + print('Dropped %s constant feature columns.' % list(delta)) + + # Add labels to regression + df = pd.concat([df_labels, df], axis=1) + return df, expanded_feature_cols diff --git a/colorectal_lymph_node_metastasis_prediction/demo.ipynb b/colorectal_lymph_node_metastasis_prediction/demo.ipynb new file mode 100644 index 00000000..56bb3522 --- /dev/null +++ b/colorectal_lymph_node_metastasis_prediction/demo.ipynb @@ -0,0 +1,1398 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "50qgDOVeYjtW" + }, + "source": [ + "# Sample notebook demonstrating cluster generation, selection, and evaluation\n", + "\n", + "This notebook demonstrates how to utilize the provided open-source code on sample data to generate cluster-derived machine learned features while controlling for baseline features, and how to evaluate the models derived from utilizing these features." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "executionInfo": { + "elapsed": 16, + "status": "ok", + "timestamp": 1676676402129, + "user": { + "displayName": "", + "userId": "" + }, + "user_tz": 480 + }, + "id": "-XFqsoF5Ytl1" + }, + "outputs": [], + "source": [ + "import os\n", + "import math\n", + "\n", + "import numpy as np\n", + "import pandas as pd\n", + "import sklearn\n", + "\n", + "import cluster_utils\n", + "import data_utils" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "9TfvU63njn1u" + }, + "source": [ + "# Constants" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "executionInfo": { + "elapsed": 8, + "status": "ok", + "timestamp": 1676676405449, + "user": { + "displayName": "", + "userId": "" + }, + "user_tz": 480 + }, + "id": "o1jdFxobuqzy" + }, + "outputs": [], + "source": [ + "N_CASES = 1000\n", + "EMBEDDING_DIM = 4\n", + "PATCHES_PER_CASE = 1000\n", + "K = 200\n", + "CLUSTER_COLS = list(range(K))\n", + "N_TOP_CLUSTERS = 5\n", + "LABEL_COL = 'lnm'" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "dlAQt3Agjp1r" + }, + "source": [ + "# Data Generation" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "g2hpkmqRtqpx" + }, + "source": [ + "## Metadata" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "executionInfo": { + "elapsed": 85, + "status": "ok", + "timestamp": 1676676405603, + "user": { + "displayName": "", + "userId": "" + }, + "user_tz": 480 + }, + "id": "nEMxU0M7tqxy" + }, + "outputs": [], + "source": [ + "# Generate random metadata\n", + "rs = np.random.RandomState(0)\n", + "\n", + "BASELINE_COLS = [\n", + " 'age_bins',\n", + " 'sex',\n", + " 't_stage',\n", + " 'grade',\n", + " 'venous_inv',\n", + " 'lymphovascular_inv'\n", + "]\n", + "\n", + "df = pd.DataFrame({\n", + " 'case_id': [f'case_{i}' for i in range(N_CASES)],\n", + " 'split': rs.choice(['train', 'validation', 'test'], p=[0.6, 0.2, 0.2], 
size=N_CASES),\n", + " 'age': rs.randint(low=50, high=90, size=N_CASES),\n", + " 'sex': rs.choice(['M', 'W'], size=N_CASES),\n", + " 't_stage': rs.choice(['T3', 'T4'], size=N_CASES),\n", + " 'grade': rs.choice(['G1', 'G2', 'G3'], size=N_CASES),\n", + " 'venous_inv': rs.choice([0, 1], size=N_CASES),\n", + " 'lymphovascular_inv': rs.choice([0, 1], size=N_CASES),\n", + "})\n", + "\n", + "df['age_bins'] = df['age'].apply(data_utils.bin_age)\n", + "df, BASELINE_COLS = data_utils.prep_features(df, BASELINE_COLS, ['case_id', 'split'])" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "colab": { + "height": 268 + }, + "executionInfo": { + "elapsed": 173, + "status": "ok", + "timestamp": 1676676405848, + "user": { + "displayName": "", + "userId": "" + }, + "user_tz": 480 + }, + "id": "Ul1IzLX1uoxa", + "outputId": "584b5a86-8522-44d1-ac12-9510803843f8" + }, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \u003cdiv id=\"df-e1db2a29-cd47-4094-963c-fc128f81aae2\"\u003e\n", + " \u003cdiv class=\"colab-df-container\"\u003e\n", + " \u003cdiv\u003e\n", + "\u003cstyle scoped\u003e\n", + " .dataframe tbody tr th:only-of-type {\n", + " vertical-align: middle;\n", + " }\n", + "\n", + " .dataframe tbody tr th {\n", + " vertical-align: top;\n", + " }\n", + "\n", + " .dataframe thead th {\n", + " text-align: right;\n", + " }\n", + "\u003c/style\u003e\n", + "\u003ctable border=\"1\" class=\"dataframe\"\u003e\n", + " \u003cthead\u003e\n", + " \u003ctr style=\"text-align: right;\"\u003e\n", + " \u003cth\u003e\u003c/th\u003e\n", + " \u003cth\u003ecase_id\u003c/th\u003e\n", + " \u003cth\u003esplit\u003c/th\u003e\n", + " \u003cth\u003evenous_inv\u003c/th\u003e\n", + " \u003cth\u003elymphovascular_inv\u003c/th\u003e\n", + " \u003cth\u003eage_bins_60-69\u003c/th\u003e\n", + " \u003cth\u003eage_bins_70-79\u003c/th\u003e\n", + " \u003cth\u003eage_bins_\u0026gt;=80\u003c/th\u003e\n", + " \u003cth\u003esex_W\u003c/th\u003e\n", + " 
\u003cth\u003et_stage_T4\u003c/th\u003e\n", + " \u003cth\u003egrade_G2\u003c/th\u003e\n", + " \u003cth\u003egrade_G3\u003c/th\u003e\n", + " \u003c/tr\u003e\n", + " \u003c/thead\u003e\n", + " \u003ctbody\u003e\n", + " \u003ctr\u003e\n", + " \u003cth\u003e0\u003c/th\u003e\n", + " \u003ctd\u003ecase_0\u003c/td\u003e\n", + " \u003ctd\u003etrain\u003c/td\u003e\n", + " \u003ctd\u003e1\u003c/td\u003e\n", + " \u003ctd\u003e0\u003c/td\u003e\n", + " \u003ctd\u003e1\u003c/td\u003e\n", + " \u003ctd\u003e0\u003c/td\u003e\n", + " \u003ctd\u003e0\u003c/td\u003e\n", + " \u003ctd\u003e0\u003c/td\u003e\n", + " \u003ctd\u003e0\u003c/td\u003e\n", + " \u003ctd\u003e0\u003c/td\u003e\n", + " \u003ctd\u003e1\u003c/td\u003e\n", + " \u003c/tr\u003e\n", + " \u003ctr\u003e\n", + " \u003cth\u003e1\u003c/th\u003e\n", + " \u003ctd\u003ecase_1\u003c/td\u003e\n", + " \u003ctd\u003evalidation\u003c/td\u003e\n", + " \u003ctd\u003e0\u003c/td\u003e\n", + " \u003ctd\u003e1\u003c/td\u003e\n", + " \u003ctd\u003e0\u003c/td\u003e\n", + " \u003ctd\u003e1\u003c/td\u003e\n", + " \u003ctd\u003e0\u003c/td\u003e\n", + " \u003ctd\u003e0\u003c/td\u003e\n", + " \u003ctd\u003e0\u003c/td\u003e\n", + " \u003ctd\u003e0\u003c/td\u003e\n", + " \u003ctd\u003e0\u003c/td\u003e\n", + " \u003c/tr\u003e\n", + " \u003ctr\u003e\n", + " \u003cth\u003e2\u003c/th\u003e\n", + " \u003ctd\u003ecase_2\u003c/td\u003e\n", + " \u003ctd\u003evalidation\u003c/td\u003e\n", + " \u003ctd\u003e1\u003c/td\u003e\n", + " \u003ctd\u003e0\u003c/td\u003e\n", + " \u003ctd\u003e0\u003c/td\u003e\n", + " \u003ctd\u003e0\u003c/td\u003e\n", + " \u003ctd\u003e0\u003c/td\u003e\n", + " \u003ctd\u003e1\u003c/td\u003e\n", + " \u003ctd\u003e1\u003c/td\u003e\n", + " \u003ctd\u003e0\u003c/td\u003e\n", + " \u003ctd\u003e1\u003c/td\u003e\n", + " \u003c/tr\u003e\n", + " \u003ctr\u003e\n", + " \u003cth\u003e3\u003c/th\u003e\n", + " \u003ctd\u003ecase_3\u003c/td\u003e\n", + " \u003ctd\u003etrain\u003c/td\u003e\n", + " \u003ctd\u003e1\u003c/td\u003e\n", + " 
\u003ctd\u003e1\u003c/td\u003e\n", + " \u003ctd\u003e0\u003c/td\u003e\n", + " \u003ctd\u003e1\u003c/td\u003e\n", + " \u003ctd\u003e0\u003c/td\u003e\n", + " \u003ctd\u003e0\u003c/td\u003e\n", + " \u003ctd\u003e0\u003c/td\u003e\n", + " \u003ctd\u003e1\u003c/td\u003e\n", + " \u003ctd\u003e0\u003c/td\u003e\n", + " \u003c/tr\u003e\n", + " \u003ctr\u003e\n", + " \u003cth\u003e4\u003c/th\u003e\n", + " \u003ctd\u003ecase_4\u003c/td\u003e\n", + " \u003ctd\u003etrain\u003c/td\u003e\n", + " \u003ctd\u003e1\u003c/td\u003e\n", + " \u003ctd\u003e0\u003c/td\u003e\n", + " \u003ctd\u003e1\u003c/td\u003e\n", + " \u003ctd\u003e0\u003c/td\u003e\n", + " \u003ctd\u003e0\u003c/td\u003e\n", + " \u003ctd\u003e0\u003c/td\u003e\n", + " \u003ctd\u003e1\u003c/td\u003e\n", + " \u003ctd\u003e0\u003c/td\u003e\n", + " \u003ctd\u003e1\u003c/td\u003e\n", + " \u003c/tr\u003e\n", + " \u003c/tbody\u003e\n", + "\u003c/table\u003e\n", + "\u003c/div\u003e\n", + " \u003cbutton class=\"colab-df-convert\" onclick=\"convertToInteractive('df-e1db2a29-cd47-4094-963c-fc128f81aae2')\"\n", + " title=\"Convert this dataframe to an interactive table.\"\n", + " style=\"display:none;\"\u003e\n", + " \n", + " \u003csvg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n", + " width=\"24px\"\u003e\n", + " \u003cpath d=\"M0 0h24v24H0V0z\" fill=\"none\"/\u003e\n", + " \u003cpath d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/\u003e\u003cpath d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/\u003e\n", + " \u003c/svg\u003e\n", + " \u003c/button\u003e\n", + " \n", + " \u003cstyle\u003e\n", + " .colab-df-container {\n", + " 
display:flex;\n", + " flex-wrap:wrap;\n", + " gap: 12px;\n", + " }\n", + "\n", + " .colab-df-convert {\n", + " background-color: #E8F0FE;\n", + " border: none;\n", + " border-radius: 50%;\n", + " cursor: pointer;\n", + " display: none;\n", + " fill: #1967D2;\n", + " height: 32px;\n", + " padding: 0 0 0 0;\n", + " width: 32px;\n", + " }\n", + "\n", + " .colab-df-convert:hover {\n", + " background-color: #E2EBFA;\n", + " box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n", + " fill: #174EA6;\n", + " }\n", + "\n", + " [theme=dark] .colab-df-convert {\n", + " background-color: #3B4455;\n", + " fill: #D2E3FC;\n", + " }\n", + "\n", + " [theme=dark] .colab-df-convert:hover {\n", + " background-color: #434B5C;\n", + " box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n", + " filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n", + " fill: #FFFFFF;\n", + " }\n", + " \u003c/style\u003e\n", + "\n", + " \u003cscript\u003e\n", + " const buttonEl =\n", + " document.querySelector('#df-e1db2a29-cd47-4094-963c-fc128f81aae2 button.colab-df-convert');\n", + " buttonEl.style.display =\n", + " google.colab.kernel.accessAllowed ? 'block' : 'none';\n", + "\n", + " async function convertToInteractive(key) {\n", + " const element = document.querySelector('#df-e1db2a29-cd47-4094-963c-fc128f81aae2');\n", + " const dataTable =\n", + " await google.colab.kernel.invokeFunction('convertToInteractive',\n", + " [key], {});\n", + " if (!dataTable) return;\n", + "\n", + " const docLinkHtml = 'Like what you see? 
Visit the ' +\n", + " '\u003ca target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb\u003edata table notebook\u003c/a\u003e'\n", + " + ' to learn more about interactive tables.';\n", + " element.innerHTML = '';\n", + " dataTable['output_type'] = 'display_data';\n", + " await google.colab.output.renderOutput(dataTable, element);\n", + " const docLink = document.createElement('div');\n", + " docLink.innerHTML = docLinkHtml;\n", + " element.appendChild(docLink);\n", + " }\n", + " \u003c/script\u003e\n", + " \u003c/div\u003e\n", + " \u003c/div\u003e\n", + " " + ], + "text/plain": [ + " case_id split venous_inv lymphovascular_inv ... sex_W t_stage_T4 grade_G2 grade_G3\n", + "0 case_0 train 1 0 ... 0 0 0 1\n", + "1 case_1 validation 0 1 ... 0 0 0 0\n", + "2 case_2 validation 1 0 ... 1 1 0 1\n", + "3 case_3 train 1 1 ... 0 0 1 0\n", + "4 case_4 train 1 0 ... 0 1 0 1\n", + "\n", + "[5 rows x 11 columns]" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "O9cNTPq2tq3G" + }, + "source": [ + "## Embeddings" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "executionInfo": { + "elapsed": 93, + "status": "ok", + "timestamp": 1676676406055, + "user": { + "displayName": "", + "userId": "" + }, + "user_tz": 480 + }, + "id": "s7V0mnfztq9V" + }, + "outputs": [], + "source": [ + "rs = np.random.RandomState(0)\n", + "embeddings = {}\n", + "for i in range(N_CASES):\n", + " embeddings[f'case_{i}'] = rs.uniform(size=[PATCHES_PER_CASE, EMBEDDING_DIM])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "OKTfQvSotrCl" + }, + "source": [ + "## Cluster Quantitations" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "executionInfo": { + "elapsed": 5863, + "status": "ok", + "timestamp": 1676676411988, + "user": { + "displayName": "", + "userId": "" + }, + 
"user_tz": 480 + }, + "id": "b_qyiGj7trJI", + "outputId": "f96c1b6c-3424-4428-aa88-d96d7eb7a8c8" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Embeddings shape: (610000, 4)\n" + ] + } + ], + "source": [ + "# Fit kmeans model on train embeddings\n", + "train_ids = set(df[df['split'] == 'train']['case_id'])\n", + "train_embeddings = {case_id: emb for case_id, emb in embeddings.items() if case_id in train_ids}\n", + "kmeans_model = cluster_utils.train_k_means_model(train_embeddings, K)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "executionInfo": { + "elapsed": 6240, + "status": "ok", + "timestamp": 1676676418365, + "user": { + "displayName": "", + "userId": "" + }, + "user_tz": 480 + }, + "id": "nlLbTAwCvKYg" + }, + "outputs": [], + "source": [ + "# Compute cluster quantitations\n", + "df_cq = cluster_utils.get_cluster_quantitation_df(embeddings, kmeans_model)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "executionInfo": { + "elapsed": 3568, + "status": "ok", + "timestamp": 1676676422016, + "user": { + "displayName": "", + "userId": "" + }, + "user_tz": 480 + }, + "id": "0r6PQa0gvNja" + }, + "outputs": [], + "source": [ + "# Standardize cluster quantitations\n", + "df_cq_train = cluster_utils.get_cluster_quantitation_df(train_embeddings, kmeans_model)\n", + "scaler = sklearn.preprocessing.StandardScaler().fit(df_cq_train[CLUSTER_COLS])\n", + "df_cq[CLUSTER_COLS] = scaler.transform(df_cq[CLUSTER_COLS])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "fE9Lw7mptrk5" + }, + "source": [ + "## Labels" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "executionInfo": { + "elapsed": 32, + "status": "ok", + "timestamp": 1676676422113, + "user": { + "displayName": "", + "userId": "" + }, + "user_tz": 480 + }, + "id": "iAPCnKWYwhEl" + }, + "outputs": [], + "source": [ + "# Simulate lables and associations with baseline 
features and cluster quantitations\n", + "def get_label(r):\n", + " rs = np.random.RandomState(int(r['case_id'][-1]))\n", + "\n", + " logit = 0\n", + " logit += 0.5 * r['venous_inv']\n", + " logit += -0.5 * r['age_bins_60-69']\n", + " logit += -1.0 * r['age_bins_70-79']\n", + " logit += -1.0 * r['age_bins_\u003e=80']\n", + " logit += 0.1 * r['sex_W']\n", + " logit += 0.5 * r['t_stage_T4']\n", + " logit += 0.1 * r['grade_G2']\n", + " logit += 0.5 * r['grade_G3']\n", + " logit += 1.0 * r['lymphovascular_inv']\n", + "\n", + " # Clusters of interest\n", + " logit += 1.0 * r[0]\n", + " logit += -1.0 * r[1]\n", + " logit += 1.0 * r[2]\n", + " logit += -1.0 * r[3]\n", + " logit += 1.0 * r[4]\n", + " prob = 1 / (1 + math.exp(-logit))\n", + " return rs.binomial(1, prob)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "executionInfo": { + "elapsed": 549, + "status": "ok", + "timestamp": 1676676422741, + "user": { + "displayName": "", + "userId": "" + }, + "user_tz": 480 + }, + "id": "Ge1AQsGPwhup", + "outputId": "f768db33-a32a-474d-9cf9-2cea3f4a884c" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "1 575\n", + "0 425\n", + "Name: lnm, dtype: int64" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df = df.merge(df_cq, on='case_id')\n", + "df[LABEL_COL] = df.apply(get_label, axis=1)\n", + "df[LABEL_COL].value_counts()" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "executionInfo": { + "elapsed": 8, + "status": "ok", + "timestamp": 1676676422820, + "user": { + "displayName": "", + "userId": "" + }, + "user_tz": 480 + }, + "id": "hGEwooBkwz7Z", + "outputId": "6192de8f-0c5c-49de-f40f-283824b9db58" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "0.6263427109974424" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Cluster of interest\n", + 
"sklearn.metrics.roc_auc_score(df[LABEL_COL], df[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": { + "executionInfo": { + "elapsed": 12, + "status": "ok", + "timestamp": 1676676422909, + "user": { + "displayName": "", + "userId": "" + }, + "user_tz": 480 + }, + "id": "26pZOL3aw5hm", + "outputId": "3b519f37-7dba-4ee0-af7f-46126eb0db3e" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "0.5419846547314578" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Non-informative cluster\n", + "sklearn.metrics.roc_auc_score(df[LABEL_COL], df[5])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "FvtZNjAYTVcd" + }, + "source": [ + "# Select top clusters" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": { + "executionInfo": { + "elapsed": 42, + "status": "ok", + "timestamp": 1676676423021, + "user": { + "displayName": "", + "userId": "" + }, + "user_tz": 480 + }, + "id": "kzHaP3cZxFPE" + }, + "outputs": [], + "source": [ + "df_train = df.query(\"split=='train'\")\n", + "df_valid = df.query(\"split=='validation'\")\n", + "df_test = df.query(\"split=='test'\")" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": { + "colab": { + "height": 206 + }, + "executionInfo": { + "elapsed": 10021, + "status": "ok", + "timestamp": 1676676433116, + "user": { + "displayName": "", + "userId": "" + }, + "user_tz": 480 + }, + "id": "VU8z5Yj6eKJT", + "outputId": "21955c59-e188-45bc-c61a-7c0e42d2d79e" + }, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \u003cdiv id=\"df-9a2b64dd-2ab9-49eb-aba1-7169284549f1\"\u003e\n", + " \u003cdiv class=\"colab-df-container\"\u003e\n", + " \u003cdiv\u003e\n", + "\u003cstyle scoped\u003e\n", + " .dataframe tbody tr th:only-of-type {\n", + " vertical-align: middle;\n", + " }\n", + "\n", + " .dataframe tbody tr th {\n", + " vertical-align: top;\n", + " }\n", + "\n", + " .dataframe thead th 
{\n", + " text-align: right;\n", + " }\n", + "\u003c/style\u003e\n", + "\u003ctable border=\"1\" class=\"dataframe\"\u003e\n", + " \u003cthead\u003e\n", + " \u003ctr style=\"text-align: right;\"\u003e\n", + " \u003cth\u003e\u003c/th\u003e\n", + " \u003cth\u003eorder\u003c/th\u003e\n", + " \u003cth\u003ecluster_id\u003c/th\u003e\n", + " \u003cth\u003eauc\u003c/th\u003e\n", + " \u003c/tr\u003e\n", + " \u003c/thead\u003e\n", + " \u003ctbody\u003e\n", + " \u003ctr\u003e\n", + " \u003cth\u003e0\u003c/th\u003e\n", + " \u003ctd\u003e0\u003c/td\u003e\n", + " \u003ctd\u003e4\u003c/td\u003e\n", + " \u003ctd\u003e0.682931\u003c/td\u003e\n", + " \u003c/tr\u003e\n", + " \u003ctr\u003e\n", + " \u003cth\u003e1\u003c/th\u003e\n", + " \u003ctd\u003e1\u003c/td\u003e\n", + " \u003ctd\u003e3\u003c/td\u003e\n", + " \u003ctd\u003e0.757471\u003c/td\u003e\n", + " \u003c/tr\u003e\n", + " \u003ctr\u003e\n", + " \u003cth\u003e2\u003c/th\u003e\n", + " \u003ctd\u003e2\u003c/td\u003e\n", + " \u003ctd\u003e1\u003c/td\u003e\n", + " \u003ctd\u003e0.802241\u003c/td\u003e\n", + " \u003c/tr\u003e\n", + " \u003ctr\u003e\n", + " \u003cth\u003e3\u003c/th\u003e\n", + " \u003ctd\u003e3\u003c/td\u003e\n", + " \u003ctd\u003e0\u003c/td\u003e\n", + " \u003ctd\u003e0.827816\u003c/td\u003e\n", + " \u003c/tr\u003e\n", + " \u003ctr\u003e\n", + " \u003cth\u003e4\u003c/th\u003e\n", + " \u003ctd\u003e4\u003c/td\u003e\n", + " \u003ctd\u003e2\u003c/td\u003e\n", + " \u003ctd\u003e0.841609\u003c/td\u003e\n", + " \u003c/tr\u003e\n", + " \u003c/tbody\u003e\n", + "\u003c/table\u003e\n", + "\u003c/div\u003e\n", + " \u003cbutton class=\"colab-df-convert\" onclick=\"convertToInteractive('df-9a2b64dd-2ab9-49eb-aba1-7169284549f1')\"\n", + " title=\"Convert this dataframe to an interactive table.\"\n", + " style=\"display:none;\"\u003e\n", + " \n", + " \u003csvg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n", + " width=\"24px\"\u003e\n", + " \u003cpath d=\"M0 0h24v24H0V0z\" fill=\"none\"/\u003e\n", 
+ " \u003cpath d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/\u003e\u003cpath d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/\u003e\n", + " \u003c/svg\u003e\n", + " \u003c/button\u003e\n", + " \n", + " \u003cstyle\u003e\n", + " .colab-df-container {\n", + " display:flex;\n", + " flex-wrap:wrap;\n", + " gap: 12px;\n", + " }\n", + "\n", + " .colab-df-convert {\n", + " background-color: #E8F0FE;\n", + " border: none;\n", + " border-radius: 50%;\n", + " cursor: pointer;\n", + " display: none;\n", + " fill: #1967D2;\n", + " height: 32px;\n", + " padding: 0 0 0 0;\n", + " width: 32px;\n", + " }\n", + "\n", + " .colab-df-convert:hover {\n", + " background-color: #E2EBFA;\n", + " box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n", + " fill: #174EA6;\n", + " }\n", + "\n", + " [theme=dark] .colab-df-convert {\n", + " background-color: #3B4455;\n", + " fill: #D2E3FC;\n", + " }\n", + "\n", + " [theme=dark] .colab-df-convert:hover {\n", + " background-color: #434B5C;\n", + " box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n", + " filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n", + " fill: #FFFFFF;\n", + " }\n", + " \u003c/style\u003e\n", + "\n", + " \u003cscript\u003e\n", + " const buttonEl =\n", + " document.querySelector('#df-9a2b64dd-2ab9-49eb-aba1-7169284549f1 button.colab-df-convert');\n", + " buttonEl.style.display =\n", + " google.colab.kernel.accessAllowed ? 
'block' : 'none';\n", + "\n", + " async function convertToInteractive(key) {\n", + " const element = document.querySelector('#df-9a2b64dd-2ab9-49eb-aba1-7169284549f1');\n", + " const dataTable =\n", + " await google.colab.kernel.invokeFunction('convertToInteractive',\n", + " [key], {});\n", + " if (!dataTable) return;\n", + "\n", + " const docLinkHtml = 'Like what you see? Visit the ' +\n", + " '\u003ca target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb\u003edata table notebook\u003c/a\u003e'\n", + " + ' to learn more about interactive tables.';\n", + " element.innerHTML = '';\n", + " dataTable['output_type'] = 'display_data';\n", + " await google.colab.output.renderOutput(dataTable, element);\n", + " const docLink = document.createElement('div');\n", + " docLink.innerHTML = docLinkHtml;\n", + " element.appendChild(docLink);\n", + " }\n", + " \u003c/script\u003e\n", + " \u003c/div\u003e\n", + " \u003c/div\u003e\n", + " " + ], + "text/plain": [ + " order cluster_id auc\n", + "0 0 4 0.682931\n", + "1 1 3 0.757471\n", + "2 2 1 0.802241\n", + "3 3 0 0.827816\n", + "4 4 2 0.841609" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df_cluster = cluster_utils.select_top_clusters(\n", + " df_train=df_train,\n", + " df_valid=df_valid,\n", + " label_col=LABEL_COL,\n", + " baseline_cols=BASELINE_COLS,\n", + " cluster_cols=CLUSTER_COLS,\n", + " n=N_TOP_CLUSTERS\n", + ")\n", + "df_cluster" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": { + "executionInfo": { + "elapsed": 10, + "status": "ok", + "timestamp": 1676676433242, + "user": { + "displayName": "", + "userId": "" + }, + "user_tz": 480 + }, + "id": "8Rr6CDToftvY", + "outputId": "cda4a492-2a1a-479b-bb55-37f50066612c" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "[4, 3, 1, 0, 2]" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + 
"TOP_CLUSTERS = list(df_cluster['cluster_id'])\n", + "TOP_CLUSTERS" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "I0sYgHnJlKrc" + }, + "source": [ + "# Eval top clusters" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "bNe46v8plO9T" + }, + "source": [ + "### Analysis" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "l_FJ6MJrqc6t" + }, + "source": [ + "#### Likelihood ratio test" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": { + "executionInfo": { + "elapsed": 28, + "status": "ok", + "timestamp": 1676676433340, + "user": { + "displayName": "", + "userId": "" + }, + "user_tz": 480 + }, + "id": "7O457bzsqieb", + "outputId": "a77e8d92-2bd1-463e-94bf-beb53fb8cfbe" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "3.326469144250546e-12" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "p_value = cluster_utils.likelihood_ratio_test(\n", + " df=df_test,\n", + " label_col=LABEL_COL,\n", + " baseline_cols=BASELINE_COLS,\n", + " cluster_cols=TOP_CLUSTERS\n", + ")\n", + "p_value" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "k2yqY3d3qfqr" + }, + "source": [ + "#### Evaluate multivariate odds ratios" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": { + "colab": { + "height": 488 + }, + "executionInfo": { + "elapsed": 41, + "status": "ok", + "timestamp": 1676676433450, + "user": { + "displayName": "", + "userId": "" + }, + "user_tz": 480 + }, + "id": "vvYfk87fwE8s", + "outputId": "5dfc86a9-2560-4d9f-b730-ca3b215676c0" + }, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \u003cdiv id=\"df-4429b190-7e25-4d02-84ed-65a57f386855\"\u003e\n", + " \u003cdiv class=\"colab-df-container\"\u003e\n", + " \u003cdiv\u003e\n", + "\u003cstyle scoped\u003e\n", + " .dataframe tbody tr th:only-of-type {\n", + " vertical-align: middle;\n", + " }\n", + "\n", + " .dataframe tbody tr th {\n", + " 
vertical-align: top;\n", + " }\n", + "\n", + " .dataframe thead th {\n", + " text-align: right;\n", + " }\n", + "\u003c/style\u003e\n", + "\u003ctable border=\"1\" class=\"dataframe\"\u003e\n", + " \u003cthead\u003e\n", + " \u003ctr style=\"text-align: right;\"\u003e\n", + " \u003cth\u003e\u003c/th\u003e\n", + " \u003cth\u003eOR\u003c/th\u003e\n", + " \u003cth\u003ep\u003c/th\u003e\n", + " \u003c/tr\u003e\n", + " \u003c/thead\u003e\n", + " \u003ctbody\u003e\n", + " \u003ctr\u003e\n", + " \u003cth\u003evenous_inv\u003c/th\u003e\n", + " \u003ctd\u003e1.19 [0.59, 2.39]\u003c/td\u003e\n", + " \u003ctd\u003e0.622\u003c/td\u003e\n", + " \u003c/tr\u003e\n", + " \u003ctr\u003e\n", + " \u003cth\u003elymphovascular_inv\u003c/th\u003e\n", + " \u003ctd\u003e3.45 [1.69, 7.02]\u003c/td\u003e\n", + " \u003ctd\u003e\u0026lt;0.001\u003c/td\u003e\n", + " \u003c/tr\u003e\n", + " \u003ctr\u003e\n", + " \u003cth\u003eage_bins_60-69\u003c/th\u003e\n", + " \u003ctd\u003e0.67 [0.27, 1.64]\u003c/td\u003e\n", + " \u003ctd\u003e0.379\u003c/td\u003e\n", + " \u003c/tr\u003e\n", + " \u003ctr\u003e\n", + " \u003cth\u003eage_bins_70-79\u003c/th\u003e\n", + " \u003ctd\u003e0.42 [0.17, 1.06]\u003c/td\u003e\n", + " \u003ctd\u003e0.067\u003c/td\u003e\n", + " \u003c/tr\u003e\n", + " \u003ctr\u003e\n", + " \u003cth\u003eage_bins_\u0026gt;=80\u003c/th\u003e\n", + " \u003ctd\u003e0.36 [0.13, 1.00]\u003c/td\u003e\n", + " \u003ctd\u003e0.049\u003c/td\u003e\n", + " \u003c/tr\u003e\n", + " \u003ctr\u003e\n", + " \u003cth\u003esex_W\u003c/th\u003e\n", + " \u003ctd\u003e0.85 [0.42, 1.71]\u003c/td\u003e\n", + " \u003ctd\u003e0.647\u003c/td\u003e\n", + " \u003c/tr\u003e\n", + " \u003ctr\u003e\n", + " \u003cth\u003et_stage_T4\u003c/th\u003e\n", + " \u003ctd\u003e0.94 [0.47, 1.88]\u003c/td\u003e\n", + " \u003ctd\u003e0.864\u003c/td\u003e\n", + " \u003c/tr\u003e\n", + " \u003ctr\u003e\n", + " \u003cth\u003egrade_G2\u003c/th\u003e\n", + " \u003ctd\u003e1.42 [0.62, 3.26]\u003c/td\u003e\n", + " 
\u003ctd\u003e0.410\u003c/td\u003e\n", + " \u003c/tr\u003e\n", + " \u003ctr\u003e\n", + " \u003cth\u003egrade_G3\u003c/th\u003e\n", + " \u003ctd\u003e2.07 [0.93, 4.62]\u003c/td\u003e\n", + " \u003ctd\u003e0.076\u003c/td\u003e\n", + " \u003c/tr\u003e\n", + " \u003ctr\u003e\n", + " \u003cth\u003e4\u003c/th\u003e\n", + " \u003ctd\u003e2.42 [1.59, 3.68]\u003c/td\u003e\n", + " \u003ctd\u003e\u0026lt;0.001\u003c/td\u003e\n", + " \u003c/tr\u003e\n", + " \u003ctr\u003e\n", + " \u003cth\u003e3\u003c/th\u003e\n", + " \u003ctd\u003e0.53 [0.36, 0.78]\u003c/td\u003e\n", + " \u003ctd\u003e0.001\u003c/td\u003e\n", + " \u003c/tr\u003e\n", + " \u003ctr\u003e\n", + " \u003cth\u003e1\u003c/th\u003e\n", + " \u003ctd\u003e0.47 [0.31, 0.72]\u003c/td\u003e\n", + " \u003ctd\u003e\u0026lt;0.001\u003c/td\u003e\n", + " \u003c/tr\u003e\n", + " \u003ctr\u003e\n", + " \u003cth\u003e0\u003c/th\u003e\n", + " \u003ctd\u003e1.50 [1.03, 2.18]\u003c/td\u003e\n", + " \u003ctd\u003e0.037\u003c/td\u003e\n", + " \u003c/tr\u003e\n", + " \u003ctr\u003e\n", + " \u003cth\u003e2\u003c/th\u003e\n", + " \u003ctd\u003e2.49 [1.61, 3.86]\u003c/td\u003e\n", + " \u003ctd\u003e\u0026lt;0.001\u003c/td\u003e\n", + " \u003c/tr\u003e\n", + " \u003c/tbody\u003e\n", + "\u003c/table\u003e\n", + "\u003c/div\u003e\n", + " \u003cbutton class=\"colab-df-convert\" onclick=\"convertToInteractive('df-4429b190-7e25-4d02-84ed-65a57f386855')\"\n", + " title=\"Convert this dataframe to an interactive table.\"\n", + " style=\"display:none;\"\u003e\n", + " \n", + " \u003csvg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n", + " width=\"24px\"\u003e\n", + " \u003cpath d=\"M0 0h24v24H0V0z\" fill=\"none\"/\u003e\n", + " \u003cpath d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/\u003e\u003cpath d=\"M17.41 
7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/\u003e\n", + " \u003c/svg\u003e\n", + " \u003c/button\u003e\n", + " \n", + " \u003cstyle\u003e\n", + " .colab-df-container {\n", + " display:flex;\n", + " flex-wrap:wrap;\n", + " gap: 12px;\n", + " }\n", + "\n", + " .colab-df-convert {\n", + " background-color: #E8F0FE;\n", + " border: none;\n", + " border-radius: 50%;\n", + " cursor: pointer;\n", + " display: none;\n", + " fill: #1967D2;\n", + " height: 32px;\n", + " padding: 0 0 0 0;\n", + " width: 32px;\n", + " }\n", + "\n", + " .colab-df-convert:hover {\n", + " background-color: #E2EBFA;\n", + " box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n", + " fill: #174EA6;\n", + " }\n", + "\n", + " [theme=dark] .colab-df-convert {\n", + " background-color: #3B4455;\n", + " fill: #D2E3FC;\n", + " }\n", + "\n", + " [theme=dark] .colab-df-convert:hover {\n", + " background-color: #434B5C;\n", + " box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n", + " filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n", + " fill: #FFFFFF;\n", + " }\n", + " \u003c/style\u003e\n", + "\n", + " \u003cscript\u003e\n", + " const buttonEl =\n", + " document.querySelector('#df-4429b190-7e25-4d02-84ed-65a57f386855 button.colab-df-convert');\n", + " buttonEl.style.display =\n", + " google.colab.kernel.accessAllowed ? 'block' : 'none';\n", + "\n", + " async function convertToInteractive(key) {\n", + " const element = document.querySelector('#df-4429b190-7e25-4d02-84ed-65a57f386855');\n", + " const dataTable =\n", + " await google.colab.kernel.invokeFunction('convertToInteractive',\n", + " [key], {});\n", + " if (!dataTable) return;\n", + "\n", + " const docLinkHtml = 'Like what you see? 
Visit the ' +\n", + " '\u003ca target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb\u003edata table notebook\u003c/a\u003e'\n", + " + ' to learn more about interactive tables.';\n", + " element.innerHTML = '';\n", + " dataTable['output_type'] = 'display_data';\n", + " await google.colab.output.renderOutput(dataTable, element);\n", + " const docLink = document.createElement('div');\n", + " docLink.innerHTML = docLinkHtml;\n", + " element.appendChild(docLink);\n", + " }\n", + " \u003c/script\u003e\n", + " \u003c/div\u003e\n", + " \u003c/div\u003e\n", + " " + ], + "text/plain": [ + " OR p\n", + "venous_inv 1.19 [0.59, 2.39] 0.622\n", + "lymphovascular_inv 3.45 [1.69, 7.02] \u003c0.001\n", + "age_bins_60-69 0.67 [0.27, 1.64] 0.379\n", + "age_bins_70-79 0.42 [0.17, 1.06] 0.067\n", + "age_bins_\u003e=80 0.36 [0.13, 1.00] 0.049\n", + "sex_W 0.85 [0.42, 1.71] 0.647\n", + "t_stage_T4 0.94 [0.47, 1.88] 0.864\n", + "grade_G2 1.42 [0.62, 3.26] 0.410\n", + "grade_G3 2.07 [0.93, 4.62] 0.076\n", + "4 2.42 [1.59, 3.68] \u003c0.001\n", + "3 0.53 [0.36, 0.78] 0.001\n", + "1 0.47 [0.31, 0.72] \u003c0.001\n", + "0 1.50 [1.03, 2.18] 0.037\n", + "2 2.49 [1.61, 3.86] \u003c0.001" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "odds_ratios = cluster_utils.get_odds_ratios_p_values(\n", + " df=df_test,\n", + " label_col=LABEL_COL,\n", + " baseline_cols=BASELINE_COLS,\n", + " cluster_cols=TOP_CLUSTERS\n", + ")\n", + "odds_ratios" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "p_YlD5Hrqa9P" + }, + "source": [ + "### Evaluate predictive performance of model" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": { + "colab": { + "height": 81 + }, + "executionInfo": { + "elapsed": 52, + "status": "ok", + "timestamp": 1676676433617, + "user": { + "displayName": "", + "userId": "" + }, + "user_tz": 480 + }, + "id": "P9C1eYOGqcAl", + "outputId": 
"b89c3c9a-e4d5-400a-b6fa-ac6ad18c3ea1" + }, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \u003cdiv id=\"df-fbb1dd54-79f3-4239-ab2f-2233c0fc06da\"\u003e\n", + " \u003cdiv class=\"colab-df-container\"\u003e\n", + " \u003cdiv\u003e\n", + "\u003cstyle scoped\u003e\n", + " .dataframe tbody tr th:only-of-type {\n", + " vertical-align: middle;\n", + " }\n", + "\n", + " .dataframe tbody tr th {\n", + " vertical-align: top;\n", + " }\n", + "\n", + " .dataframe thead th {\n", + " text-align: right;\n", + " }\n", + "\u003c/style\u003e\n", + "\u003ctable border=\"1\" class=\"dataframe\"\u003e\n", + " \u003cthead\u003e\n", + " \u003ctr style=\"text-align: right;\"\u003e\n", + " \u003cth\u003e\u003c/th\u003e\n", + " \u003cth\u003eBaseline features only\u003c/th\u003e\n", + " \u003cth\u003eCluster features only\u003c/th\u003e\n", + " \u003cth\u003eBaseline + cluster features\u003c/th\u003e\n", + " \u003c/tr\u003e\n", + " \u003c/thead\u003e\n", + " \u003ctbody\u003e\n", + " \u003ctr\u003e\n", + " \u003cth\u003eAUC\u003c/th\u003e\n", + " \u003ctd\u003e0.618361\u003c/td\u003e\n", + " \u003ctd\u003e0.785176\u003c/td\u003e\n", + " \u003ctd\u003e0.813417\u003c/td\u003e\n", + " \u003c/tr\u003e\n", + " \u003c/tbody\u003e\n", + "\u003c/table\u003e\n", + "\u003c/div\u003e\n", + " \u003cbutton class=\"colab-df-convert\" onclick=\"convertToInteractive('df-fbb1dd54-79f3-4239-ab2f-2233c0fc06da')\"\n", + " title=\"Convert this dataframe to an interactive table.\"\n", + " style=\"display:none;\"\u003e\n", + " \n", + " \u003csvg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n", + " width=\"24px\"\u003e\n", + " \u003cpath d=\"M0 0h24v24H0V0z\" fill=\"none\"/\u003e\n", + " \u003cpath d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/\u003e\u003cpath d=\"M17.41 
7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/\u003e\n", + " \u003c/svg\u003e\n", + " \u003c/button\u003e\n", + " \n", + " \u003cstyle\u003e\n", + " .colab-df-container {\n", + " display:flex;\n", + " flex-wrap:wrap;\n", + " gap: 12px;\n", + " }\n", + "\n", + " .colab-df-convert {\n", + " background-color: #E8F0FE;\n", + " border: none;\n", + " border-radius: 50%;\n", + " cursor: pointer;\n", + " display: none;\n", + " fill: #1967D2;\n", + " height: 32px;\n", + " padding: 0 0 0 0;\n", + " width: 32px;\n", + " }\n", + "\n", + " .colab-df-convert:hover {\n", + " background-color: #E2EBFA;\n", + " box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n", + " fill: #174EA6;\n", + " }\n", + "\n", + " [theme=dark] .colab-df-convert {\n", + " background-color: #3B4455;\n", + " fill: #D2E3FC;\n", + " }\n", + "\n", + " [theme=dark] .colab-df-convert:hover {\n", + " background-color: #434B5C;\n", + " box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n", + " filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n", + " fill: #FFFFFF;\n", + " }\n", + " \u003c/style\u003e\n", + "\n", + " \u003cscript\u003e\n", + " const buttonEl =\n", + " document.querySelector('#df-fbb1dd54-79f3-4239-ab2f-2233c0fc06da button.colab-df-convert');\n", + " buttonEl.style.display =\n", + " google.colab.kernel.accessAllowed ? 'block' : 'none';\n", + "\n", + " async function convertToInteractive(key) {\n", + " const element = document.querySelector('#df-fbb1dd54-79f3-4239-ab2f-2233c0fc06da');\n", + " const dataTable =\n", + " await google.colab.kernel.invokeFunction('convertToInteractive',\n", + " [key], {});\n", + " if (!dataTable) return;\n", + "\n", + " const docLinkHtml = 'Like what you see? 
Visit the ' +\n", + " '\u003ca target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb\u003edata table notebook\u003c/a\u003e'\n", + " + ' to learn more about interactive tables.';\n", + " element.innerHTML = '';\n", + " dataTable['output_type'] = 'display_data';\n", + " await google.colab.output.renderOutput(dataTable, element);\n", + " const docLink = document.createElement('div');\n", + " docLink.innerHTML = docLinkHtml;\n", + " element.appendChild(docLink);\n", + " }\n", + " \u003c/script\u003e\n", + " \u003c/div\u003e\n", + " \u003c/div\u003e\n", + " " + ], + "text/plain": [ + " Baseline features only Cluster features only Baseline + cluster features\n", + "AUC 0.618361 0.785176 0.813417" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "aucs = cluster_utils.get_eval_aucs(\n", + " df_train=df_train,\n", + " df_valid=df_test,\n", + " label_col=LABEL_COL,\n", + " baseline_cols=BASELINE_COLS,\n", + " cluster_cols=TOP_CLUSTERS\n", + ")\n", + "aucs" + ] + } + ], + "metadata": { + "colab": { + "last_runtime": { + "build_target": "", + "kind": "local" + }, + "name": "demo.ipynb", + "provenance": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/colorectal_lymph_node_metastasis_prediction/requirements.txt b/colorectal_lymph_node_metastasis_prediction/requirements.txt new file mode 100644 index 00000000..827f015a --- /dev/null +++ b/colorectal_lymph_node_metastasis_prediction/requirements.txt @@ -0,0 +1,5 @@ +numpy>=1.19.5 +pandas>=1.0.5 +scikit-learn>=0.23.2 +scipy>=1.8.0 +statsmodels>=0.12.2 \ No newline at end of file