Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Fix gradio app issues #23

Merged
merged 4 commits into from
Oct 20, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 41 additions & 32 deletions demo/gradio_monai_vila2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
# Sample images dictionary
IMAGES_URLS = {
"CT Sample 1": "https://developer.download.nvidia.com/assets/Clara/monai/samples/liver_0.nii.gz",
"CT Sample 2": "https://developer.download.nvidia.com/assets/Clara/monai/samples/ct_sample.nii.gz",
"Chest X-ray Sample 1": "https://developer.download.nvidia.com/assets/Clara/monai/samples/cxr_ce3d3d98-bf5170fa-8e962da1-97422442-6653c48a_v1.jpg",
"Chest X-ray Sample 2": "https://developer.download.nvidia.com/assets/Clara/monai/samples/cxr_fcb77615-ceca521c-c8e4d028-0d294832-b97b7d77_v1.jpg",
"Chest X-ray Sample 3": "https://developer.download.nvidia.com/assets/Clara/monai/samples/cxr_6cbf5aa1-71de2d2b-96f6b460-24227d6e-6e7a7e1d_v1.jpg",
Expand All @@ -71,20 +72,21 @@

SYS_PROMPT = None # set when the script initializes

EXAMPLE_PROMPTS = [
"Segment the visceral structures in the current image.",
"Can you identify any liver masses or tumors?",
"Segment the entire image.",
"What abnormalities are seen in this image?",
"Is there evidence of edema in this image?",
"Is there evidence of any abnormalities in this image?",
"What is the total number of [condition/abnormality] present in this image?",
"Is there pneumothorax?",
"What type is the lung opacity?",
"which view is this image taken?",
"Is there evidence of cardiomegaly in this image?",
"Is the atelectasis located on the left side or right side?",
"What level is the cardiomegaly?",
# Example prompts handed to `gr.Dataset(samples=...)` by the branch of
# `update_image_selection` that reports a slice index (volumetric images).
# Every prompt is wrapped in its own one-element list, i.e. one row per prompt.
EXAMPLE_PROMPTS_3D = [
    [prompt]
    for prompt in (
        "Segment the visceral structures in the current image.",
        "Can you identify any liver masses or tumors?",
        "Segment the entire image.",
    )
]

# Example prompts handed to `gr.Dataset(samples=...)` by the branch of
# `update_image_selection` that reports "N/A for 2D images".
# Every prompt is wrapped in its own one-element list, i.e. one row per prompt.
EXAMPLE_PROMPTS_2D = [
    [prompt]
    for prompt in (
        "What abnormalities are seen in this image?",
        "Is there evidence of edema in this image?",
        "Is there pneumothorax?",
        "What type is the lung opacity?",
        "Which view is this image taken?",
        "Is there evidence of cardiomegaly in this image?",
        "Is the atelectasis located on the left side or right side?",
        "What level is the cardiomegaly?",
    )
]

# Fourteen consecutive "<br>" tags: the same string produced by joining
# fifteen empty strings with "<br>" as the separator.
HTML_PLACEHOLDER = "<br>" * 14
Expand All @@ -94,7 +96,7 @@
CACHED_IMAGES = {}

TITLE = """
<div style="text-align: center; max-width: 650px; margin: 0 auto;">
<div style="text-align: center; max-width: 800px; margin: 0 auto;">
<p>
<img src="https://raw.githubusercontent.com/Project-MONAI/MONAI/dev/docs/images/MONAI-logo-color.png" alt="project monai" style="width: 50%; min-width: 500px; max-width: 800px; margin: auto; display: block;">
</p>
Expand All @@ -111,7 +113,8 @@
</h1>
</div>
<p style="margin-bottom: 10px; font-size: 94%">
Placeholder text for the description of the tool.
VILA-M3 is a vision-language model for medical applications that interprets medical images and text prompts to generate relevant responses.
Disclaimer: AI models generate responses and outputs based on complex algorithms and machine learning techniques, and those responses or outputs may be inaccurate, harmful, biased or indecent. By testing this model, you assume the risk of any harm caused by any response or output of the model. This model is for research purposes and not for clinical usage.
</p>

</div>
Expand Down Expand Up @@ -274,6 +277,8 @@ def __init__(self, source="local", model_path="", conv_mode=""):
model_path, self.model_name
)
logger.info(f"Model {self.model_name} loaded successfully. Context length: {self.context_len}")
elif source == "huggingface":
pass
else:
raise NotImplementedError(f"Source {source} is not supported.")

Expand Down Expand Up @@ -474,8 +479,6 @@ def input_image(image, sv: SessionVariables):
"""Update the session variables with the input image data URL if it's inputted by the user"""
logger.debug(f"Received user input image")
# TODO: support user uploaded images
sv.image_url = image_to_data_url(image)
sv.interactive = True
return image, sv


Expand All @@ -486,7 +489,7 @@ def update_image_selection(selected_image, sv: SessionVariables, slice_index_htm
img_file = CACHED_IMAGES.get(sv.image_url, None)

if sv.image_url is None:
return None, sv, slice_index_html
return None, sv, slice_index_html, [[""]]

if sv.temp_working_dir is None:
sv.temp_working_dir = tempfile.mkdtemp()
Expand All @@ -512,13 +515,19 @@ def update_image_selection(selected_image, sv: SessionVariables, slice_index_htm
image_filename=image_filename,
)
compose({"image": img_file})
return os.path.join(sv.temp_working_dir, image_filename), sv, f"Slice Index: {sv.slice_index}"
return (
os.path.join(sv.temp_working_dir, image_filename),
sv,
f"Slice Index: {sv.slice_index}",
gr.Dataset(samples=EXAMPLE_PROMPTS_3D),
)

sv.slice_index = None
return (
img_file,
sv,
"Slice Index: N/A for 2D images, clicking prev/next will not change the image.",
gr.Dataset(samples=EXAMPLE_PROMPTS_2D),
)


Expand Down Expand Up @@ -652,8 +661,16 @@ def create_demo(source, model_path, conv_mode, server_port):

with gr.Row():
with gr.Column():
image_dropdown = gr.Dropdown(label="Select an image", choices=["Please select .."] + list(IMAGES_URLS.keys()))
image_input = gr.Image(label="Image", sources=[], placeholder="Please select an image from the dropdown list.")
image_dropdown = gr.Dropdown(label="Select an image", choices=list(IMAGES_URLS.keys()))
with gr.Accordion("3D image panel", open=False):
slice_index_html = gr.HTML("Slice Index: N/A")
with gr.Row():
prev10_btn = gr.Button("<<")
prev01_btn = gr.Button("<")
next01_btn = gr.Button(">")
next10_btn = gr.Button(">>")

with gr.Accordion("View Parameters", open=False):
temperature_slider = gr.Slider(
label="Temperature", minimum=0.0, maximum=1.0, step=0.01, value=0.0, interactive=True
Expand All @@ -665,14 +682,6 @@ def create_demo(source, model_path, conv_mode, server_port):
label="Max Tokens", minimum=1, maximum=1024, step=1, value=1024, interactive=True
)

with gr.Accordion("3D image panel", open=False):
slice_index_html = gr.HTML("Slice Index: N/A")
with gr.Row():
prev10_btn = gr.Button("<<")
prev01_btn = gr.Button("<")
next01_btn = gr.Button(">")
next10_btn = gr.Button(">>")

with gr.Accordion("System Prompt and Message", open=False):
sys_prompt_text = gr.Textbox(
label="System Prompt",
Expand All @@ -694,10 +703,10 @@ def create_demo(source, model_path, conv_mode, server_port):
clear_btn = gr.Button("Clear Conversation")
with gr.Row(variant="compact"):
prompt_edit = gr.Textbox(
label="Enter your prompt here", container=False, placeholder="Enter your prompt here", scale=2
label="TextPrompt", container=False, placeholder="Please ask a question about the current image or 2D slice", scale=2
)
submit_btn = gr.Button("Submit", scale=0)
gr.Examples(EXAMPLE_PROMPTS, prompt_edit)
examples = gr.Examples([[""]], prompt_edit)

# Process image and clear it immediately by returning None
submit_btn.click(
Expand All @@ -716,7 +725,7 @@ def create_demo(source, model_path, conv_mode, server_port):
image_dropdown.change(
fn=update_image_selection,
inputs=[image_dropdown, sv, slice_index_html],
outputs=[image_input, sv, slice_index_html],
outputs=[image_input, sv, slice_index_html, examples.dataset],
)
prev10_btn.click(
fn=update_image_prev_10,
Expand Down
Loading