Skip to content

Commit

Permalink
Fix gradio app issues (#23)
Browse files Browse the repository at this point in the history
This PR includes the following improvements/fixes:
- Set the example prompts based on modality
- Add a disclaimer
- Add a statement saying “Please ask a question about the current image
or 2D slice”
- Add 1 more image example (CT)

---------

Signed-off-by: Mingxin Zheng <[email protected]>
Co-authored-by: Holger Roth <[email protected]>
  • Loading branch information
mingxin-zheng and holgerroth authored Oct 20, 2024
1 parent 8ba90d6 commit 96e0766
Showing 1 changed file with 41 additions and 32 deletions.
73 changes: 41 additions & 32 deletions demo/gradio_monai_vila2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
# Sample images dictionary
IMAGES_URLS = {
"CT Sample 1": "https://developer.download.nvidia.com/assets/Clara/monai/samples/liver_0.nii.gz",
"CT Sample 2": "https://developer.download.nvidia.com/assets/Clara/monai/samples/ct_sample.nii.gz",
"Chest X-ray Sample 1": "https://developer.download.nvidia.com/assets/Clara/monai/samples/cxr_ce3d3d98-bf5170fa-8e962da1-97422442-6653c48a_v1.jpg",
"Chest X-ray Sample 2": "https://developer.download.nvidia.com/assets/Clara/monai/samples/cxr_fcb77615-ceca521c-c8e4d028-0d294832-b97b7d77_v1.jpg",
"Chest X-ray Sample 3": "https://developer.download.nvidia.com/assets/Clara/monai/samples/cxr_6cbf5aa1-71de2d2b-96f6b460-24227d6e-6e7a7e1d_v1.jpg",
Expand All @@ -71,20 +72,21 @@

SYS_PROMPT = None # set when the script initializes

EXAMPLE_PROMPTS = [
"Segment the visceral structures in the current image.",
"Can you identify any liver masses or tumors?",
"Segment the entire image.",
"What abnormalities are seen in this image?",
"Is there evidence of edema in this image?",
"Is there evidence of any abnormalities in this image?",
"What is the total number of [condition/abnormality] present in this image?",
"Is there pneumothorax?",
"What type is the lung opacity?",
"which view is this image taken?",
"Is there evidence of cardiomegaly in this image?",
"Is the atelectasis located on the left side or right side?",
"What level is the cardiomegaly?",
EXAMPLE_PROMPTS_3D = [
["Segment the visceral structures in the current image."],
["Can you identify any liver masses or tumors?"],
["Segment the entire image."],
]

EXAMPLE_PROMPTS_2D = [
["What abnormalities are seen in this image?"],
["Is there evidence of edema in this image?"],
["Is there pneumothorax?"],
["What type is the lung opacity?"],
["Which view is this image taken?"],
["Is there evidence of cardiomegaly in this image?"],
["Is the atelectasis located on the left side or right side?"],
["What level is the cardiomegaly?"],
]

HTML_PLACEHOLDER = "<br>".join([""] * 15)
Expand All @@ -94,7 +96,7 @@
CACHED_IMAGES = {}

TITLE = """
<div style="text-align: center; max-width: 650px; margin: 0 auto;">
<div style="text-align: center; max-width: 800px; margin: 0 auto;">
<p>
<img src="https://raw.githubusercontent.com/Project-MONAI/MONAI/dev/docs/images/MONAI-logo-color.png" alt="project monai" style="width: 50%; min-width: 500px; max-width: 800px; margin: auto; display: block;">
</p>
Expand All @@ -111,7 +113,8 @@
</h1>
</div>
<p style="margin-bottom: 10px; font-size: 94%">
Placeholder text for the description of the tool.
VILA-M3 is a vision-language model for medical applications that interprets medical images and text prompts to generate relevant responses.
Disclaimer: AI models generate responses and outputs based on complex algorithms and machine learning techniques, and those responses or outputs may be inaccurate, harmful, biased or indecent. By testing this model, you assume the risk of any harm caused by any response or output of the model. This model is for research purposes and not for clinical usage.
</p>
</div>
Expand Down Expand Up @@ -274,6 +277,8 @@ def __init__(self, source="local", model_path="", conv_mode=""):
model_path, self.model_name
)
logger.info(f"Model {self.model_name} loaded successfully. Context length: {self.context_len}")
elif source == "huggingface":
pass
else:
raise NotImplementedError(f"Source {source} is not supported.")

Expand Down Expand Up @@ -474,8 +479,6 @@ def input_image(image, sv: SessionVariables):
"""Update the session variables with the input image data URL if it's inputted by the user"""
logger.debug(f"Received user input image")
# TODO: support user uploaded images
sv.image_url = image_to_data_url(image)
sv.interactive = True
return image, sv


Expand All @@ -486,7 +489,7 @@ def update_image_selection(selected_image, sv: SessionVariables, slice_index_htm
img_file = CACHED_IMAGES.get(sv.image_url, None)

if sv.image_url is None:
return None, sv, slice_index_html
return None, sv, slice_index_html, [[""]]

if sv.temp_working_dir is None:
sv.temp_working_dir = tempfile.mkdtemp()
Expand All @@ -512,13 +515,19 @@ def update_image_selection(selected_image, sv: SessionVariables, slice_index_htm
image_filename=image_filename,
)
compose({"image": img_file})
return os.path.join(sv.temp_working_dir, image_filename), sv, f"Slice Index: {sv.slice_index}"
return (
os.path.join(sv.temp_working_dir, image_filename),
sv,
f"Slice Index: {sv.slice_index}",
gr.Dataset(samples=EXAMPLE_PROMPTS_3D),
)

sv.slice_index = None
return (
img_file,
sv,
"Slice Index: N/A for 2D images, clicking prev/next will not change the image.",
gr.Dataset(samples=EXAMPLE_PROMPTS_2D),
)


Expand Down Expand Up @@ -652,8 +661,16 @@ def create_demo(source, model_path, conv_mode, server_port):

with gr.Row():
with gr.Column():
image_dropdown = gr.Dropdown(label="Select an image", choices=["Please select .."] + list(IMAGES_URLS.keys()))
image_input = gr.Image(label="Image", sources=[], placeholder="Please select an image from the dropdown list.")
image_dropdown = gr.Dropdown(label="Select an image", choices=list(IMAGES_URLS.keys()))
with gr.Accordion("3D image panel", open=False):
slice_index_html = gr.HTML("Slice Index: N/A")
with gr.Row():
prev10_btn = gr.Button("<<")
prev01_btn = gr.Button("<")
next01_btn = gr.Button(">")
next10_btn = gr.Button(">>")

with gr.Accordion("View Parameters", open=False):
temperature_slider = gr.Slider(
label="Temperature", minimum=0.0, maximum=1.0, step=0.01, value=0.0, interactive=True
Expand All @@ -665,14 +682,6 @@ def create_demo(source, model_path, conv_mode, server_port):
label="Max Tokens", minimum=1, maximum=1024, step=1, value=1024, interactive=True
)

with gr.Accordion("3D image panel", open=False):
slice_index_html = gr.HTML("Slice Index: N/A")
with gr.Row():
prev10_btn = gr.Button("<<")
prev01_btn = gr.Button("<")
next01_btn = gr.Button(">")
next10_btn = gr.Button(">>")

with gr.Accordion("System Prompt and Message", open=False):
sys_prompt_text = gr.Textbox(
label="System Prompt",
Expand All @@ -694,10 +703,10 @@ def create_demo(source, model_path, conv_mode, server_port):
clear_btn = gr.Button("Clear Conversation")
with gr.Row(variant="compact"):
prompt_edit = gr.Textbox(
label="Enter your prompt here", container=False, placeholder="Enter your prompt here", scale=2
label="TextPrompt", container=False, placeholder="Please ask a question about the current image or 2D slice", scale=2
)
submit_btn = gr.Button("Submit", scale=0)
gr.Examples(EXAMPLE_PROMPTS, prompt_edit)
examples = gr.Examples([[""]], prompt_edit)

# Process image and clear it immediately by returning None
submit_btn.click(
Expand All @@ -716,7 +725,7 @@ def create_demo(source, model_path, conv_mode, server_port):
image_dropdown.change(
fn=update_image_selection,
inputs=[image_dropdown, sv, slice_index_html],
outputs=[image_input, sv, slice_index_html],
outputs=[image_input, sv, slice_index_html, examples.dataset],
)
prev10_btn.click(
fn=update_image_prev_10,
Expand Down

0 comments on commit 96e0766

Please sign in to comment.