format, fix memory consumption, load data dictionary from GitHub repo
alistairewj committed Jun 17, 2024
1 parent c55eae8 commit 5123cec
Showing 1 changed file with 75 additions and 33 deletions.
108 changes: 75 additions & 33 deletions src/b2aiprep/app/pages/5_Search.py
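The headline change is memory use: rather than tokenizing and embedding the entire data dictionary in a single forward pass, the page now encodes it in small batches and concatenates the results. Below is a minimal, self-contained sketch of that batching pattern; the model name, batch size of 10, and function names mirror the diff that follows, but treat this as an illustration rather than the committed file verbatim.

import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

def mean_pooling(model_output, attention_mask):
    # Average token embeddings, ignoring padding positions via the attention mask.
    token_embeddings = model_output[0]
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)

def embed_sentences(text_list, batch_size=10):
    # Encode the text in small batches so peak memory stays bounded.
    embeddings = []
    for i in range(0, len(text_list), batch_size):
        batch = text_list[i : i + batch_size]
        encoded = tokenizer(batch, padding=True, truncation=True, return_tensors="pt")
        with torch.no_grad():
            output = model(**encoded)
        pooled = mean_pooling(output, encoded["attention_mask"])
        embeddings.append(F.normalize(pooled, p=2, dim=1))
    return torch.cat(embeddings, dim=0)

Wrapping the forward pass in torch.no_grad() keeps autograd from retaining activations, which, together with the batching, is what bounds peak memory during corpus embedding.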
@@ -6,60 +6,105 @@
from scipy.spatial import distance as ssd
import numpy as np
import streamlit as st
import typing as t

data_path = "bridge2ai-Voice/bridge2ai-voice-corpus-1//b2ai-voice-corpus-1-dictionary.csv"
rcdict = pd.read_csv(data_path)


#Mean Pooling - Take attention mask into account for correct averaging
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0] #First element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
    """Mean pool the model output over the tokens factoring in attention."""
    token_embeddings = model_output[
        0
    ]  # First element of model_output contains all token embeddings
    input_mask_expanded = (
        attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    )
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
        input_mask_expanded.sum(1), min=1e-9
    )


@st.cache_data
def load_model(model_path='models/sentence-transformers/all-MiniLM-L6-v2'):

def load_model(model_path="sentence-transformers/all-MiniLM-L6-v2"):
    # Load model from HuggingFace Hub
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModel.from_pretrained(model_path)

    return tokenizer, model

tokenizer, model = load_model()

def embed_sentences(text_list):
    tokenizer, model = load_model()

    # Tokenize sentences
    encoded_input = tokenizer(text_list, padding=True, truncation=True, return_tensors='pt')

def embed_sentences(text_list):
    # Compute token embeddings
    with torch.no_grad():
        model_output = model(**encoded_input)

    # Perform pooling
    sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])

    # Normalize embeddings
    sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
    # Batch this for memory efficiency
    batch_size = 10
    sentence_embeddings = []
    for i in range(0, len(text_list), batch_size):
        batch = text_list[i:i+batch_size]
        # Tokenize sentences
        encoded_input = tokenizer(batch, padding=True, truncation=True, return_tensors="pt")
        with torch.no_grad():
            model_output = model(**encoded_input)
        batch_embeddings = mean_pooling(model_output, encoded_input["attention_mask"])
        batch_embeddings = F.normalize(batch_embeddings, p=2, dim=1)
        sentence_embeddings.append(batch_embeddings)

    sentence_embeddings = torch.cat(sentence_embeddings, dim=0)
    return sentence_embeddings


@st.cache_data
def embed_corpus(c):
    return embed_sentences(c)

@st.cache_data
def download_and_load_data_dictionary() -> pd.DataFrame:
"""Load the data dictionary from the public release GitHub repo.
https://github.com/eipm/bridge2ai-redcap/
Returns:
pd.DataFrame: The data dictionary as a DataFrame.
"""
data_dictionary_url = (
"https://raw.githubusercontent.com/eipm/bridge2ai-redcap/main/data/bridge2ai_voice_project_data_dictionary.csv"
)
return pd.read_csv(data_dictionary_url)

rcdict = download_and_load_data_dictionary()


def extract_descriptions(df: pd.DataFrame) -> t.Tuple[t.List[str], t.List[str]]:
"""Extract the descriptions from the data dictionary."""

corpus = rcdict['Field Label'].values.tolist()
field_ids = rcdict['Variable / Field Name'].values.tolist()
    # There are a number of fields in the data dictionary which do not have a label
    # For example, page_1, page2, page3... of the eConsent
    # We will only consider fields with a label, as the ones missing label
    # are not useful for our search.
    idx = df["Field Label"].notnull()
    corpus = df.loc[idx, "Field Label"].values.tolist()
    field_ids = df.loc[idx, "Variable / Field Name"].values.tolist()
    return corpus, field_ids

corpus, field_ids = extract_descriptions(rcdict)

# TODO: This is a very memory hungry operation, spikes with ~40 GB of memory.
corpus_as_vector = embed_corpus(corpus)

search_string = st.text_input('Search string', 'age')
st.markdown(
"""
# Search the data dictionary
The following text box allows you to semantically search the data dictionary.
You can use it to find the name for fields collected in the study.
"""
)
search_string = st.text_input("Search string", "age")

search_embedding = embed_sentences([search_string,])
search_embedding = embed_sentences(
    [
        search_string,
    ]
)

# Compute cosine similarity scores for the search string to all other sentences
sims = []
@@ -76,19 +121,16 @@ def embed_corpus(c):
col1, col2 = st.columns(2)

with col1:

    cutoff = st.number_input("Cutoff", 0.0, 1.0, 0.3)

    plt.plot(sims)
    plt.title("Cosine similarity")
    st.pyplot(plt)


with col2:

    sentences_to_show = sentences_sorted[sims > cutoff].tolist()
    field_ids_to_show = field_ids_sorted[sims > cutoff].tolist()

    final_df = pd.DataFrame({'field_ids': field_ids_to_show, 'field_desc': sentences_to_show})

    st.table(final_df)
    final_df = pd.DataFrame(
        {"field_ids": field_ids_to_show, "field_desc": sentences_to_show}
    )
    st.table(final_df)
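The hunk collapsed above ("@@ -76,19 +121,16 @@") covers the step that scores every corpus entry against the query embedding and sorts the fields before the cutoff filter shown here. Those lines are not expanded in this view, so the following is an assumed reconstruction of that step, using only the scipy (ssd) and numpy (np) imports already at the top of the file; variable names match the visible code, but the details are illustrative.

# Hypothetical sketch of the collapsed similarity/sorting step (not shown in this diff view).
sims = []
for row in corpus_as_vector:
    # Cosine similarity = 1 - cosine distance.
    sims.append(1 - ssd.cosine(search_embedding[0].numpy(), row.numpy()))
sims = np.array(sims)

# Order fields from most to least similar to the query.
order = np.argsort(sims)[::-1]
sims = sims[order]
sentences_sorted = np.array(corpus)[order]
field_ids_sorted = np.array(field_ids)[order]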
