Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable more debugging options #43

Merged
merged 2 commits into from
Oct 28, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 25 additions & 9 deletions m3/demo/gradio_m3.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,7 @@ def __init__(self):
self.idx_range = (None, None)
self.interactive = False
self.sys_msgs_to_hide = []
self.modality_prompt = "Auto"


def new_session_variables(**kwargs):
Expand Down Expand Up @@ -426,7 +427,10 @@ def process_prompt(self, prompt, sv, chat_history):
if sv.temp_working_dir is None:
sv.temp_working_dir = tempfile.mkdtemp()

modality = get_modality(sv.image_url, text=prompt)
if sv.modality_prompt == "Auto":
modality = get_modality(sv.image_url, text=prompt)
else:
modality = sv.modality_prompt
mod_msg = f"This is a {modality} image.\n" if modality != "Unknown" else ""

img_file = CACHED_IMAGES.get(sv.image_url, None)
Expand Down Expand Up @@ -604,7 +608,7 @@ def clear_one_conv(sv: SessionVariables):
- history_text
- history_text_full
- sys_prompt_text
- sys_message_text
- model_cards_text
If some of the parameters need to stay persistent in the session, they should be modified in the `clear_all_convs` function.
"""
logger.debug(f"Clearing the parameters of one conversation")
Expand All @@ -625,7 +629,7 @@ def clear_all_convs(sv: SessionVariables):
if sv.temp_working_dir is not None:
rmtree(sv.temp_working_dir)
new_sv = new_session_variables()
# Order of output: prompt_edit, chat_history, history_text, history_text_full, sys_prompt_text, sys_message_text
# Order of output: prompt_edit, chat_history, history_text, history_text_full, sys_prompt_text, model_cards_text
return (
new_sv,
"Enter your prompt here",
Expand All @@ -634,6 +638,7 @@ def clear_all_convs(sv: SessionVariables):
HTML_PLACEHOLDER,
new_sv.sys_prompt,
new_sv.sys_msg,
new_sv.modality_prompt,
)


Expand Down Expand Up @@ -666,11 +671,17 @@ def update_sys_prompt(sys_prompt, sv):


def update_sys_message(sys_message, sv):
"""Update the system message"""
logger.debug(f"Updating the system message")
"""Update the model cards"""
logger.debug(f"Updating the model cards")
sv.sys_msg = sys_message
return sv

def update_modality_prompt(modality_prompt, sv):
"""Update the modality prompt"""
logger.debug(f"Updating the modality prompt")
sv.modality_prompt = modality_prompt
return sv


def download_file():
"""Download the file."""
Expand Down Expand Up @@ -713,11 +724,15 @@ def create_demo(source, model_path, conv_mode, server_port):
value=sv.value.sys_prompt,
lines=4,
)
sys_message_text = gr.Textbox(
label="System Message",
model_cards_text = gr.Textbox(
label="Model Cards",
value=sv.value.sys_msg,
lines=10,
)
modality_prompt_dropdown = gr.Dropdown(
label="Select Modality",
choices=["Auto", "CT", "MRI", "X-ray", "Unknown"],
)

with gr.Column():
with gr.Tab("In front of the scene"):
Expand Down Expand Up @@ -764,12 +779,13 @@ def create_demo(source, model_path, conv_mode, server_port):
top_p_slider.change(fn=update_top_p, inputs=[top_p_slider, sv], outputs=[sv])
max_tokens_slider.change(fn=update_max_tokens, inputs=[max_tokens_slider, sv], outputs=[sv])
sys_prompt_text.change(fn=update_sys_prompt, inputs=[sys_prompt_text, sv], outputs=[sv])
sys_message_text.change(fn=update_sys_message, inputs=[sys_message_text, sv], outputs=[sv])
model_cards_text.change(fn=update_sys_message, inputs=[model_cards_text, sv], outputs=[sv])
modality_prompt_dropdown.change(fn=update_modality_prompt, inputs=[modality_prompt_dropdown, sv], outputs=[sv])
# Reset button
clear_btn.click(
fn=clear_all_convs,
inputs=[sv],
outputs=[sv, prompt_edit, chat_history, history_text, history_text_full, sys_prompt_text, sys_message_text],
outputs=[sv, prompt_edit, chat_history, history_text, history_text_full, sys_prompt_text, model_cards_text, modality_prompt_dropdown],
)

# States
Expand Down
Loading