Skip to content

Commit

Permalink
explain communication rois
Browse files Browse the repository at this point in the history
  • Loading branch information
csinva committed Nov 12, 2024
1 parent bd8d52f commit 010f500
Show file tree
Hide file tree
Showing 2 changed files with 210 additions and 90 deletions.
244 changes: 172 additions & 72 deletions notebooks_stories/0_voxel_select/roi_custom_drive.ipynb

Large diffs are not rendered by default.

56 changes: 38 additions & 18 deletions notebooks_stories/0_voxel_select/run_openai_calls.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,51 @@
import imodelsx.llm
import pandas as pd
import joblib
import json


prompt_template = '''Here is a list of phrases:
{s}
What is a common theme among these phrases (especially the top ones)? Return only a concise phrase.'''

subject = 'S02'
suffix_setting = '_filt=0.15'
# suffix_setting = '_rj_custom'


d_selected = pd.read_pickle(
f'top_clusters_by_pfc_cluster_{subject}{suffix_setting}.pkl')
# d_selected = pd.read_pickle(
# f'top_clusters_by_pfc_cluster_{subject}{suffix_setting}.pkl')
# gpt4 = imodelsx.llm.get_llm('gpt-4')
# explanations = []
# print(d_selected.shape)

# for i, row in d_selected.iterrows():
# top_ngrams = row['top_ngrams']
# s = '- ' + '\n- '.join(top_ngrams[:60])
# prompt = prompt_template.format(s=s)
# # print(prompt)
# explanation = gpt4(prompt, use_cache=True)
# if explanation is None:
# explanation = '<FAILED FOR CONTENT MODERATION>'
# # explanation = gpt4(prompt, use_cache=False)
# explanations.append(explanation)
# print(explanations)
# joblib.dump(
# explanations, f'explanations_by_pfc_cluster_{subject}{suffix_setting}.jbl')


explanations = {}
top_ngrams_df = pd.read_pickle(
f'top_ngrams_custom_communication_{subject}.pkl')
gpt4 = imodelsx.llm.get_llm('gpt-4')
explanations = []
print(d_selected.shape)
prompt_template = '''Here is a list of phrases:
{s}
for k in top_ngrams_df.columns:

What is a common theme among these phrases (especially the top ones)? Return only a concise phrase.'''
for i, row in d_selected.iterrows():
top_ngrams = row['top_ngrams']
s = '- ' + '\n- '.join(top_ngrams[:60])
s = '- ' + '\n- '.join(top_ngrams_df[k].iloc[:100])
prompt = prompt_template.format(s=s)
# print(prompt)
explanation = gpt4(prompt, use_cache=True)
if explanation is None:
explanation = '<FAILED FOR CONTENT MODERATION>'
# explanation = gpt4(prompt, use_cache=False)
explanations.append(explanation)
if not k in explanations:
# print(prompt)
explanations[k] = gpt4(prompt)
print(explanations)
joblib.dump(
explanations, f'explanations_by_pfc_cluster_{subject}{suffix_setting}.jbl')
json.dump(explanations, open(
f'explanations_by_roi_communication_{subject}.json', 'w'), indent=4)

0 comments on commit 010f500

Please sign in to comment.