diff --git a/Cellpose_gradio.py b/Cellpose_gradio.py index 2adf9a9..398d48b 100644 --- a/Cellpose_gradio.py +++ b/Cellpose_gradio.py @@ -1,12 +1,13 @@ import gradio as gr import numpy as np -import os -from cellpose import models, utils import matplotlib.pyplot as plt +from cellpose import models, utils from PIL import Image from datetime import datetime -import json from werkzeug.utils import secure_filename +import json +import os +import io ######################################################## # 1. Utility functions @@ -121,7 +122,7 @@ def save_masks(image, masks): overlay = utils.masks_to_outlines(masks) Image.fromarray(overlay).save(outlines_filename) outlines_path = os.path.abspath(outlines_filename) - + return [npy_path, png_path, outlines_path] return None @@ -227,6 +228,8 @@ def process_and_display(image, model_type, diameter, flow_threshold, display_cha 3. Generates a figure displaying the segmentation results. 4. Counts the number of cells in the segmented image. 5. Saves the segmentation masks to files. + 6. Saves a high-quality version of the plot as a PNG file. + 7. Saves the plot as an SVG file. Args: image (numpy.ndarray): Input image for segmentation. @@ -242,8 +245,9 @@ def process_and_display(image, model_type, diameter, flow_threshold, display_cha Returns: tuple: Contains the following elements: - fig (matplotlib.figure.Figure): Figure object with segmentation results. - - mask_files (list): Paths to saved mask files. + - mask_files (list): Paths to saved mask files and high-quality plot. - cell_count (int): Number of cells detected. + - settings_summary (str): Summary of the settings used for segmentation. - gr.update: Gradio update object to hide/show error alerts. """ try: @@ -261,22 +265,51 @@ def process_and_display(image, model_type, diameter, flow_threshold, display_cha progress(0.1, desc="Segmentation starting...") channels = [seg_channel1, seg_channel2] + # Segment the image progress(0.25, desc="Segmentation in progress...") masks = segment_image(image, model_type, channels=channels, diameter=diameter, flow_threshold=flow_threshold) progress(0.75, desc="Segmentation complete. Generating results...") + # Display the results fig = display_results(image, masks, display_channel=display_channel, cmap=cmap) cell_count = count_cells(masks) mask_files = save_masks(image, masks) + # Save the figure as a high-quality PNG + buf = io.BytesIO() + fig.savefig(buf, format='png', dpi=300, bbox_inches='tight') + buf.seek(0) + + # Save the plot to a file.png + plot_filename = os.path.join("Outputs", f"Result_figure_{datetime.now().strftime('%Y%m%d_%H%M%S')}.png") + with open(plot_filename, 'wb') as f: + f.write(buf.getvalue()) + + # Add the PNG plot to mask_files + mask_files.append(plot_filename) + + # Save the figure as an SVG + svg_buf = io.BytesIO() + fig.savefig(svg_buf, format='svg', bbox_inches='tight') + svg_buf.seek(0) + + # Save the plot as an SVG + svg_filename = os.path.join("Outputs", f"Result_figure_{datetime.now().strftime('%Y%m%d_%H%M%S')}.svg") + with open(svg_filename, 'wb') as f: + f.write(svg_buf.getvalue()) + + # Add the SVG plot to mask_files + mask_files.append(svg_filename) + progress(1.0, desc="Process complete!") - return fig, mask_files, cell_count, gr.update(visible=False) + settings_summary = f"Model: {model_type}, Diameter: {diameter}, Flow Threshold: {flow_threshold}, Display: {display_channel}, Seg Ch1: {seg_channel1}, Seg Ch2: {seg_channel2}, Colormap: {cmap}" + return fig, mask_files, cell_count, settings_summary, gr.update(visible=False) gr.Error("No image provided.") - return None, None, None, gr.update(visible=True) + return None, None, None, None, gr.update(visible=True) except Exception as e: - gr.Error(str(e)) # Show error message - return None, None, None, gr.update(visible=True) + gr.Error(str(e)) + return None, None, None, None, gr.update(visible=True) def update_channel_visibility(channel_config): """ @@ -456,6 +489,18 @@ def load_settings(profile_name): background-color: #fdedd6; color: white; } +.custom-settings-summary { + font-weight: bold; + font-size: 10px; +} +.custom-settings-summary textarea { + font-weight: bold !important; + font-size: 10px !important; +} +.custom-settings-summary:hover { + background-color: #fdedd6; + color: white; +} """ custom_theme = gr.themes.Soft(primary_hue="orange", secondary_hue="orange", font=["Arial", "sans-serif"]) @@ -482,7 +527,9 @@ def load_settings(profile_name): profile_name = gr.Textbox(label="Save Profile", placeholder="Enter profile name", info="Name your profile to **save the current settings**.", elem_classes=["custom-component"]) save_btn = gr.Button("Save Profile", scale=1, size="sm", elem_classes=["custom-button"]) with gr.Column(scale=1): - load_profile = gr.Dropdown(label="Load Profile", choices=list_profiles(), scale=1, info="Select a profile to **load its settings**.", elem_classes=["custom-dropdown"]) + profiles = list_profiles() + default_value = profiles[0] if profiles else None + load_profile = gr.Dropdown(label="Load Profile", choices=profiles, scale=1, info="Select a profile to **load its settings**.", value=default_value, allow_custom_value=True, elem_classes=["custom-dropdown"]) load_btn = gr.Button("Load Profile", scale=1, size="sm", elem_classes=["custom-button"]) # Model type, diameter, flow threshold @@ -527,26 +574,31 @@ def load_settings(profile_name): ) process_btn = gr.Button("Run Segmentation", scale=2, elem_classes=["custom-button"]) - - # Output plot + progress animation - with gr.Row(): - output_plot = gr.Plot(label="Segmentation Results") - progress_output = gr.Textbox(label="Progress", interactive=False, visible=False) - - # Output files and cell count - with gr.Row(): - cell_count_output = gr.Number(label="Number of cells detected", scale=1, elem_classes=["custom-component"]) - output_files = gr.File(label="Download Results (date and time is in the filename)", file_count="multiple", scale=1, elem_classes=["custom-component"]) - - # buttons logic + progress_output = gr.Textbox(label="Progress", interactive=False, visible=True) + # Output components (initially hidden) + with gr.Row(visible=False) as output_row: + output_plot = gr.Plot(label="Segmentation Results") + + with gr.Row(visible=False) as results_row: + with gr.Column(scale=1): + cell_count_output = gr.Number(label="Number of cells detected", elem_classes=["custom-component"]) + settings_output = gr.Textbox(label="Settings Summary", elem_classes=["custom-settings-summary"]) + with gr.Column(scale=2): + output_files = gr.File(label="Download Results", file_count="multiple", scale=1, height=500, elem_classes=["custom-component"]) + + # Update the process_and_display_wrapper function + def process_and_display_wrapper(*args): + results = process_and_display(*args) + return [gr.update(visible=True), gr.update(visible=True)] + list(results) + # Run segmentation button process_btn.click( - fn=process_and_display, + fn=process_and_display_wrapper, inputs=[input_image, model_type, diameter, flow_threshold, display_channel, seg_channel1, seg_channel2, cmap], - outputs=[output_plot, output_files, cell_count_output, progress_output] + outputs=[output_row, results_row, output_plot, output_files, cell_count_output, settings_output, progress_output], ) - + # Save settings button save_btn.click( save_settings,