Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added API logic #29

Open
wants to merge 8 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,5 @@ docs/*
.idea
.idea/*
venv/*
kubernetes
kubernetes/*
7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,13 @@ setting is presented [here](https://onlinelibrary.wiley.com/doi/10.1002/pst.2352
- Provides visualization tools for analysis.
- Supports simulation of datasets for testing and demonstration purposes.

## Running Streamlit Demos

1. Install the pip packages using `pip install -r environments/requirements.txt`.
2. Run the FastAPI server using `python api.py`.
3. Run the Streamlit server using `streamlit run bin/main.py`.


## Limitations

At the moment, `pybalance` only implements matching routines. Support for weighting
Expand Down
52 changes: 52 additions & 0 deletions api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from typing import Any, Dict

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from bin.main import match, generate_data
import pandas as pd
import json

# FastAPI application exposing pybalance's data generation and matching
# over HTTP; consumed by the Streamlit front end (see bin/main.py, which
# posts to http://localhost:8000).
app = FastAPI(
    title="PyBalance API",
)

# Add CORS middleware
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive; acceptable for a local demo, but tighten before deploying.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allows all origins
    allow_credentials=True,
    allow_methods=["*"],  # Allows all methods
    allow_headers=["*"],  # Allows all headers
)


class GenerateDataRequest(BaseModel):
    """Request body for POST /generate_data."""

    # Number of patients in the pool (by convention, larger) population.
    n_pool: int
    # Number of patients in the target (by convention, smaller) population.
    n_target: int


class MatchRequest(BaseModel):
    """Request body for POST /match."""

    # Serialized dataset — presumably the output of MatchingData.to_dict();
    # TODO: replace with a properly typed model.
    matching_data: Dict[str, Any]
    # Balance objective name passed through to the matcher.
    objective: str
    # Maximum number of matcher iterations (defaults to 100).
    max_iter: int = Field(100)


@app.post("/generate_data")
async def generate_data_endpoint(request: GenerateDataRequest):
    """Generate a simulated dataset and return it as a plain dict."""
    return generate_data(request.n_pool, request.n_target)


@app.post("/match")
async def match_endpoint(request: MatchRequest):
    """Run propensity-score matching on the posted dataset.

    Returns the post-match dataset (a serialized MatchingData dict) under
    the "post_matching_data" key.
    """
    # Fix: removed a leftover copy-paste debug print that spammed the log
    # with the repeated token "post_matching_datapost_matching_data…".
    post_matching_data = match(
        request.matching_data, request.objective, request.max_iter
    )
    return {"post_matching_data": post_matching_data}


if __name__ == "__main__":
    import uvicorn

    # Local development server; the Streamlit app expects this API at
    # http://localhost:8000 (see bin/main.py).
    uvicorn.run(app, host="0.0.0.0", port=8000)
92 changes: 71 additions & 21 deletions bin/main.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
import sys
import os

# Add the repository root to sys.path so 'pybalance' is importable when bin/main.py is run directly
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from typing import Any, Dict

import streamlit as st
import pandas as pd
import seaborn as sns
Expand All @@ -10,6 +18,8 @@
plot_per_feature_loss,
)
from pybalance.utils import BALANCE_CALCULATORS, split_target_pool, MatchingData
import requests
import json

OBJECTIVES = list(BALANCE_CALCULATORS.keys())
OBJECTIVES.remove("base")
Expand All @@ -24,13 +34,15 @@
placeholder = st.empty()


def generate_data():
print("Generating data!")
def generate_data(n_pool: int, n_target: int):
print("Inside Generating data...")
seed = 45
n_pool, n_target = st.session_state["n_pool"], st.session_state["n_target"]
# n_pool, n_target = st.session_state["n_pool"], st.session_state["n_target"]
matching_data = generate_toy_dataset(n_pool, n_target, seed)
st.session_state["matching_data"] = matching_data
st.session_state["first_run"] = False

# st.session_state["first_run"] = False
print(matching_data.head(5))
return matching_data.to_dict()


def load_data():
Expand Down Expand Up @@ -59,32 +71,29 @@ def load_data():
st.session_state["matching_data"] = matching_data


def match():

def match(payload: Dict, objective: str, max_iter: int = 100):
print("Inside Matching data...")
matching_data_recreated = MatchingData.from_dict(payload)
# Create an instance of PropensityScoreMatcher
max_iter = st.session_state.get("max_iter", 100)
method = "greedy"
objective = st.session_state.get("objective")
matching_data = st.session_state.get("matching_data").copy()
matcher = PropensityScoreMatcher(
matching_data, objective, None, max_iter, time_limit, method
)

matcher = PropensityScoreMatcher(matching_data_recreated, objective, None, max_iter, 10, method)
print(f"matcher {matcher}")
# Call the match() method
post_matching_data = matcher.match()
post_matching_data.data.loc[:, "population"] = (
post_matching_data["population"] + " (postmatch)"
post_matching_data["population"] + " (postmatch)"
abhishek-ch marked this conversation as resolved.
Show resolved Hide resolved
)
st.session_state["post_matching_data"] = post_matching_data
print(f"post_matching_data {post_matching_data}")
return post_matching_data.to_dict()


def load_front_page():

st.markdown("<h5>Generate a simulated dataset</h5>", unsafe_allow_html=True)

col1, col2 = st.columns(2)
with col1:
st.number_input(
n_pool = st.number_input(
"Pool size",
min_value=1,
step=1000,
Expand All @@ -93,15 +102,34 @@ def load_front_page():
help="Number of patients in the pool (by convention, larger) population",
)
with col2:
st.number_input(
n_target = st.number_input(
"Target size",
min_value=1,
step=100,
value=1000,
key="n_target",
help="Number of patients in the target (by convention, smaller) population",
)
st.button("Generate", on_click=generate_data)

if st.button("Generate"):
# Prepare the payload
payload = {
"n_pool": n_pool,
"n_target": n_target
}

# Call the FastAPI endpoint
response = requests.post("http://localhost:8000/generate_data", json=payload)
abhishek-ch marked this conversation as resolved.
Show resolved Hide resolved

if response.status_code == 200:
matching_data_json = response.json()
st.session_state["matching_data"] = matching_data_json
# matching_instance = MatchingData.from_json(matching_data_json)
st.session_state["first_run"] = False
st.success("Data generated successfully")
st.rerun() # Force the script to rerun
else:
st.error("Failed to generate data")

st.write("---")
st.markdown("<h5>Upload your own data</h5>", unsafe_allow_html=True)
Expand Down Expand Up @@ -135,6 +163,7 @@ def load_front_page():
with placeholder.container():

matching_data = st.session_state.get("matching_data").copy()
matching_data = MatchingData.from_dict(matching_data)
target, pool = split_target_pool(matching_data)

# Create a sidebar for inputting parameters
Expand Down Expand Up @@ -185,7 +214,28 @@ def load_front_page():
hue_order = list(matching_data.populations)

# Create a button to trigger the match() method
st.sidebar.button("Match", on_click=match)
if st.sidebar.button("Match"):
if "matching_data" in st.session_state:
matching_data_str = st.session_state["matching_data"]
payload = {
"matching_data": matching_data_str,
"objective": st.session_state.get("objective"),
"max_iter": st.session_state.get("max_iter")
}

# Call the FastAPI endpoint
response = requests.post("http://localhost:8000/match", json=payload)
abhishek-ch marked this conversation as resolved.
Show resolved Hide resolved

if response.status_code == 200:
post_matching_data_json = response.json().get("post_matching_data")
post_matching_instance = MatchingData.from_json(post_matching_data_json)
st.session_state["post_matching_instance"] = post_matching_instance
st.success("Data matched successfully")
else:
st.error("Failed to match data")
else:
st.error("No matching data found. Generate data first.")
# st.sidebar.button("Match", on_click=match)

balance_calculator = BALANCE_CALCULATORS[objective](matching_data)
st.sidebar.write(balance_calculator.__doc__)
Expand Down Expand Up @@ -227,7 +277,7 @@ def load_front_page():
with tab2:
plot_vars = []
for i, col in enumerate(
st.columns(len(matching_data.headers["categoric"]))
st.columns(len(matching_data.headers["categoric"]))
):
with col:
col_name = matching_data.headers["categoric"][i]
Expand Down
4 changes: 3 additions & 1 deletion environments/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,6 @@ fsspec
s3fs

# streamlit installation
streamlit
streamlit
fastapi==0.111.0
uvicorn[standard]==0.30.1
Empty file added kubernetes/deployment.yml
Empty file.
2 changes: 2 additions & 0 deletions pybalance/propensity/matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ def __init__(

self.matching_data = matching_data.copy()
self.target, self.pool = split_target_pool(matching_data)
print(f"self.target: {self.target}")
print(f"self.pool: {self.pool}")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please remove print statements

if isinstance(objective, str):
self.balance_calculator = BalanceCalculator(self.matching_data, objective)
self.objective = objective
Expand Down
37 changes: 36 additions & 1 deletion pybalance/utils/matching_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from dataclasses import dataclass
import pandas as pd
import logging
import json

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -32,6 +33,17 @@ def __getitem__(self, key):
key
]

def to_dict(self):
    """Serialize the header lists to a plain dict."""
    return {"categoric": self.categoric, "numeric": self.numeric}

@classmethod
def from_dict(cls, data_dict):
    """Reconstruct a Headers instance from ``to_dict()`` output."""
    return cls(
        categoric=data_dict["categoric"],
        numeric=data_dict["numeric"],
    )



def infer_matching_headers(
data: pd.DataFrame,
Expand Down Expand Up @@ -366,6 +378,29 @@ def describe(
n = self.describe_numeric(aggregations, quantiles)
return pd.concat([c, n])

def to_dict(self):
    """Serialize this MatchingData to a plain dict.

    NOTE(review): ``headers`` is emitted as-is; if it holds a Headers
    object rather than a dict, the result is not directly JSON-serializable
    — confirm against the constructor.
    """
    serialized = {
        "data": self.data.to_dict(orient="dict"),
        "headers": self.headers,
        "population_col": self.population_col,
    }
    return serialized

@classmethod
def from_dict(cls, data_dict):
    """Rebuild a MatchingData from the dict produced by ``to_dict()``.

    NOTE(review): a pandas to_dict/from_dict round trip may not preserve
    index order or dtypes exactly — confirm if that matters to callers.
    """
    frame = pd.DataFrame.from_dict(data_dict.get("data"))
    headers = data_dict.get("headers", {})
    population_col = data_dict["population_col"]
    return cls(frame, headers, population_col)

def to_json(self):
    """Serialize this MatchingData to a JSON string (via ``to_dict()``)."""
    as_dict = self.to_dict()
    return json.dumps(as_dict)

@classmethod
def from_json(cls, json_str):
    """Deserialize a MatchingData from a JSON string.

    Fix: despite its name, the original never parsed JSON (the
    ``json.loads`` call was commented out), so passing an actual JSON
    string would fail inside ``from_dict``. We now parse strings, while
    still accepting an already-parsed dict for backward compatibility —
    callers in bin/main.py pass ``response.json()`` output, which is a
    dict, not a string.
    """
    data_dict = json.loads(json_str) if isinstance(json_str, str) else json_str
    return cls.from_dict(data_dict)


def split_target_pool(
matching_data: MatchingData,
Expand All @@ -378,7 +413,6 @@ def split_target_pool(
explicitly provided, the routine will attempt to infer their names,
assuming that the target population is the smaller population.
"""

if isinstance(target_name, str) and isinstance(pool_name, str):
target = matching_data.get_population(target_name)
pool = matching_data.get_population(pool_name)
Expand Down Expand Up @@ -416,3 +450,4 @@ def split_target_pool(
target = _pool

return target, pool

abhishek-ch marked this conversation as resolved.
Show resolved Hide resolved