Skip to content

Commit

Permalink
Overall improvements to UX and added support for new file formats in …
Browse files Browse the repository at this point in the history
…the result download section.
  • Loading branch information
LSeu-Open authored Oct 14, 2024
1 parent 10c523d commit cbc3529
Showing 1 changed file with 77 additions and 25 deletions.
102 changes: 77 additions & 25 deletions Cellpose_gradio.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import gradio as gr
import numpy as np
import os
from cellpose import models, utils
import matplotlib.pyplot as plt
from cellpose import models, utils
from PIL import Image
from datetime import datetime
import json
from werkzeug.utils import secure_filename
import json
import os
import io

########################################################
# 1. Utility functions
Expand Down Expand Up @@ -121,7 +122,7 @@ def save_masks(image, masks):
overlay = utils.masks_to_outlines(masks)
Image.fromarray(overlay).save(outlines_filename)
outlines_path = os.path.abspath(outlines_filename)

return [npy_path, png_path, outlines_path]
return None

Expand Down Expand Up @@ -227,6 +228,8 @@ def process_and_display(image, model_type, diameter, flow_threshold, display_cha
3. Generates a figure displaying the segmentation results.
4. Counts the number of cells in the segmented image.
5. Saves the segmentation masks to files.
6. Saves a high-quality version of the plot as a PNG file.
7. Saves the plot as an SVG file.
Args:
image (numpy.ndarray): Input image for segmentation.
Expand All @@ -242,8 +245,9 @@ def process_and_display(image, model_type, diameter, flow_threshold, display_cha
Returns:
tuple: Contains the following elements:
- fig (matplotlib.figure.Figure): Figure object with segmentation results.
- mask_files (list): Paths to saved mask files.
- mask_files (list): Paths to saved mask files and high-quality plot.
- cell_count (int): Number of cells detected.
- settings_summary (str): Summary of the settings used for segmentation.
- gr.update: Gradio update object to hide/show error alerts.
"""
try:
Expand All @@ -261,22 +265,51 @@ def process_and_display(image, model_type, diameter, flow_threshold, display_cha
progress(0.1, desc="Segmentation starting...")
channels = [seg_channel1, seg_channel2]

# Segment the image
progress(0.25, desc="Segmentation in progress...")
masks = segment_image(image, model_type, channels=channels, diameter=diameter, flow_threshold=flow_threshold)
progress(0.75, desc="Segmentation complete. Generating results...")

# Display the results
fig = display_results(image, masks, display_channel=display_channel, cmap=cmap)
cell_count = count_cells(masks)
mask_files = save_masks(image, masks)

# Save the figure as a high-quality PNG
buf = io.BytesIO()
fig.savefig(buf, format='png', dpi=300, bbox_inches='tight')
buf.seek(0)

# Save the plot to a file.png
plot_filename = os.path.join("Outputs", f"Result_figure_{datetime.now().strftime('%Y%m%d_%H%M%S')}.png")
with open(plot_filename, 'wb') as f:
f.write(buf.getvalue())

# Add the PNG plot to mask_files
mask_files.append(plot_filename)

# Save the figure as an SVG
svg_buf = io.BytesIO()
fig.savefig(svg_buf, format='svg', bbox_inches='tight')
svg_buf.seek(0)

# Save the plot as an SVG
svg_filename = os.path.join("Outputs", f"Result_figure_{datetime.now().strftime('%Y%m%d_%H%M%S')}.svg")
with open(svg_filename, 'wb') as f:
f.write(svg_buf.getvalue())

# Add the SVG plot to mask_files
mask_files.append(svg_filename)

progress(1.0, desc="Process complete!")
return fig, mask_files, cell_count, gr.update(visible=False)
settings_summary = f"Model: {model_type}, Diameter: {diameter}, Flow Threshold: {flow_threshold}, Display: {display_channel}, Seg Ch1: {seg_channel1}, Seg Ch2: {seg_channel2}, Colormap: {cmap}"
return fig, mask_files, cell_count, settings_summary, gr.update(visible=False)
gr.Error("No image provided.")
return None, None, None, gr.update(visible=True)
return None, None, None, None, gr.update(visible=True)

except Exception as e:
gr.Error(str(e)) # Show error message
return None, None, None, gr.update(visible=True)
gr.Error(str(e))
return None, None, None, None, gr.update(visible=True)

def update_channel_visibility(channel_config):
"""
Expand Down Expand Up @@ -456,6 +489,18 @@ def load_settings(profile_name):
background-color: #fdedd6;
color: white;
}
.custom-settings-summary {
font-weight: bold;
font-size: 10px;
}
.custom-settings-summary textarea {
font-weight: bold !important;
font-size: 10px !important;
}
.custom-settings-summary:hover {
background-color: #fdedd6;
color: white;
}
"""

custom_theme = gr.themes.Soft(primary_hue="orange", secondary_hue="orange", font=["Arial", "sans-serif"])
Expand All @@ -482,7 +527,9 @@ def load_settings(profile_name):
profile_name = gr.Textbox(label="Save Profile", placeholder="Enter profile name", info="Name your profile to **save the current settings**.", elem_classes=["custom-component"])
save_btn = gr.Button("Save Profile", scale=1, size="sm", elem_classes=["custom-button"])
with gr.Column(scale=1):
load_profile = gr.Dropdown(label="Load Profile", choices=list_profiles(), scale=1, info="Select a profile to **load its settings**.", elem_classes=["custom-dropdown"])
profiles = list_profiles()
default_value = profiles[0] if profiles else None
load_profile = gr.Dropdown(label="Load Profile", choices=profiles, scale=1, info="Select a profile to **load its settings**.", value=default_value, allow_custom_value=True, elem_classes=["custom-dropdown"])
load_btn = gr.Button("Load Profile", scale=1, size="sm", elem_classes=["custom-button"])

# Model type, diameter, flow threshold
Expand Down Expand Up @@ -527,26 +574,31 @@ def load_settings(profile_name):
)

process_btn = gr.Button("Run Segmentation", scale=2, elem_classes=["custom-button"])

# Output plot + progress animation
with gr.Row():
output_plot = gr.Plot(label="Segmentation Results")
progress_output = gr.Textbox(label="Progress", interactive=False, visible=False)

# Output files and cell count
with gr.Row():
cell_count_output = gr.Number(label="Number of cells detected", scale=1, elem_classes=["custom-component"])
output_files = gr.File(label="Download Results (date and time is in the filename)", file_count="multiple", scale=1, elem_classes=["custom-component"])

# buttons logic
progress_output = gr.Textbox(label="Progress", interactive=False, visible=True)

# Output components (initially hidden)
with gr.Row(visible=False) as output_row:
output_plot = gr.Plot(label="Segmentation Results")

with gr.Row(visible=False) as results_row:
with gr.Column(scale=1):
cell_count_output = gr.Number(label="Number of cells detected", elem_classes=["custom-component"])
settings_output = gr.Textbox(label="Settings Summary", elem_classes=["custom-settings-summary"])
with gr.Column(scale=2):
output_files = gr.File(label="Download Results", file_count="multiple", scale=1, height=500, elem_classes=["custom-component"])

# Update the process_and_display_wrapper function
def process_and_display_wrapper(*args):
results = process_and_display(*args)
return [gr.update(visible=True), gr.update(visible=True)] + list(results)

# Run segmentation button
process_btn.click(
fn=process_and_display,
fn=process_and_display_wrapper,
inputs=[input_image, model_type, diameter, flow_threshold, display_channel, seg_channel1, seg_channel2, cmap],
outputs=[output_plot, output_files, cell_count_output, progress_output]
outputs=[output_row, results_row, output_plot, output_files, cell_count_output, settings_output, progress_output],
)

# Save settings button
save_btn.click(
save_settings,
Expand Down

0 comments on commit cbc3529

Please sign in to comment.