From 2935b9cc9afc52139845b8e7baf210478c62b42c Mon Sep 17 00:00:00 2001
From: Mohak
Date: Sun, 12 Jan 2025 22:51:28 +0530
Subject: [PATCH 1/5] added segmentation model to the codebase

---
 .streamlit/secrets.example.toml |   3 +
 requirements.txt                | Bin 1920 -> 2018 bytes
 .../segmentAnythingModel.py     | 166 ++++++++++++++++++
 3 files changed, 169 insertions(+)
 create mode 100644 src/apps/pages/models/ObjectDetectionModels/segmentAnythingModel.py

diff --git a/.streamlit/secrets.example.toml b/.streamlit/secrets.example.toml
index 21ed4154..f8fe2178 100644
--- a/.streamlit/secrets.example.toml
+++ b/.streamlit/secrets.example.toml
@@ -41,3 +41,6 @@ MOVIES_LIST="1MlnqIHqbcuPaogV9LzotmR0Eyelju4TV"
 [objectDetectionModel]
 YOLO11S="1z9X2yvEaTThV6JBVkPRMLakh1O5KM4cl"
 YOLOV8S="1s3abY6DUaIE0w54F9Bq-HbcTkiruvbZn"
+
+[segmentAnythingModel]
+MODEL_PATH="facebook/sam-vit-base"
diff --git a/requirements.txt b/requirements.txt
index 2ba475d894e45b4d47c214fb5aa9be28a4b064f0..8972e1a4988708eaf980c5e52b50f7d0f01ab17f 100644
GIT binary patch
delta 68
zcmZqRf5g9`fn7J9A(NqmL7O21NR}`ZFcdRbG3Wy!L_C9`lp%>hk0F^MpCOk)e{v$D
QIHUgL`)radybN3n02~nxCIA2c

delta 7
OcmaFF-@w12fgJ!0Vgk4T

diff --git a/src/apps/pages/models/ObjectDetectionModels/segmentAnythingModel.py b/src/apps/pages/models/ObjectDetectionModels/segmentAnythingModel.py
new file mode 100644
index 00000000..e97f635f
--- /dev/null
+++ b/src/apps/pages/models/ObjectDetectionModels/segmentAnythingModel.py
@@ -0,0 +1,166 @@
+import streamlit as st
+import tensorflow as tf
+from transformers import TFSamModel, SamProcessor
+from PIL import Image
+import numpy as np
+import matplotlib.pyplot as plt
+
+# Enable numpy behavior
+tf.experimental.numpy.experimental_enable_numpy_behavior()
+
+@st.cache_resource
+def load_model():
+    """Load the SAM model and processor."""
+    model = TFSamModel.from_pretrained("facebook/sam-vit-base")
+    processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
+    return model, processor
+
+def generate_random_colors(n):
+    """Generate n distinct colors."""
+    colors = []
+    for i in range(n):
+        # Generate bright, distinct colors
+        hue = i / n
+        saturation = 0.7 + np.random.rand() * 0.3
+        value = 0.7 + np.random.rand() * 0.3
+
+        # Convert HSV to RGB
+        h = hue * 6
+        c = value * saturation
+        x = c * (1 - abs(h % 2 - 1))
+        m = value - c
+
+        if h < 1:
+            rgb = (c, x, 0)
+        elif h < 2:
+            rgb = (x, c, 0)
+        elif h < 3:
+            rgb = (0, c, x)
+        elif h < 4:
+            rgb = (0, x, c)
+        elif h < 5:
+            rgb = (x, 0, c)
+        else:
+            rgb = (c, 0, x)
+
+        colors.append(np.array([(r + m) * 0.7 for r in rgb] + [0.5]))  # Add alpha value
+    return colors
+
+def show_masks_on_image(raw_image, masks, scores):
+    """Display all masks overlaid on the same image with different colors."""
+    plt.clf()
+
+    # Convert tensors to numpy arrays
+    if isinstance(masks, tf.Tensor):
+        masks = masks.numpy()
+    if isinstance(scores, tf.Tensor):
+        scores = scores.numpy()
+
+    masks = np.squeeze(masks)
+    scores = np.squeeze(scores)
+
+    # Handle single mask case
+    if len(masks.shape) == 2:
+        masks = np.expand_dims(masks, 0)
+        scores = np.expand_dims(scores, 0)
+
+    # Create figure
+    fig, ax = plt.subplots(figsize=(12, 8))
+    ax.imshow(raw_image)
+
+    # Generate distinct colors for each mask
+    colors = generate_random_colors(len(masks))
+
+    # Overlay each mask with a different color
+    for i, (mask, score, color) in enumerate(zip(masks, scores, colors)):
+        # Show masked image
+        mask_image = np.zeros((mask.shape[0], mask.shape[1], 4))
+        mask_image[:, :, 3] = mask * color[3]  # Alpha channel
+        for j in range(3):  # RGB channels
+            mask_image[:, :, j] = mask * color[j]
+
+        ax.imshow(mask_image)
+
+        # Add label with score
+        label = f"Object {i+1} (Score: {float(score):.2f})"
+        # Find center of mass of the mask for label placement
+        y_indices, x_indices = np.where(mask > 0.5)
+        if len(x_indices) > 0 and len(y_indices) > 0:
+            center_x = np.mean(x_indices)
+            center_y = np.mean(y_indices)
+            ax.text(center_x, center_y, label,
+                    color='white', fontsize=8,
+                    bbox=dict(facecolor='black', alpha=0.5),
+                    ha='center', va='center')
+
+    ax.axis('off')
+    plt.tight_layout()
+    return fig
+
+def segmentAnythingModel():
+    st.title("Segment Anything Model (SAM)")
+    st.write("""
+    Upload an image to automatically segment all objects in the scene.
+    Each object will be highlighted with a different color.
+    """)
+
+    # Load model at the start
+    with st.spinner("Loading model..."):
+        try:
+            model, processor = load_model()
+            st.success("Model loaded successfully!")
+        except Exception as e:
+            st.error(f"Error loading model: {str(e)}")
+            return
+
+    # File uploader
+    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
+
+    if uploaded_file is not None:
+        try:
+            # Display original image
+            raw_image = Image.open(uploaded_file).convert("RGB")
+            st.image(raw_image, caption="Original Image", use_column_width=True)
+
+            # Generate grid of points across the image for complete segmentation
+            height, width = raw_image.size[1], raw_image.size[0]
+            grid_size = 50  # Adjust this value to control segmentation density
+            x_points = np.linspace(0, width, num=grid_size)
+            y_points = np.linspace(0, height, num=grid_size)
+
+            input_points = [[[x, y] for x in x_points[::4] for y in y_points[::4]]]
+
+            if st.button("Generate Segmentation"):
+                with st.spinner("Generating segmentation..."):
+                    try:
+                        # Process the image with grid points
+                        inputs = processor(raw_image, input_points=input_points, return_tensors="tf")
+
+                        # Run model inference
+                        outputs = model(**inputs)
+
+                        # Post-process masks
+                        masks = processor.image_processor.post_process_masks(
+                            outputs.pred_masks,
+                            inputs["original_sizes"],
+                            inputs["reshaped_input_sizes"],
+                            return_tensors="tf",
+                        )
+
+                        # Create and display figure
+                        fig = show_masks_on_image(raw_image, masks, outputs.iou_scores)
+                        st.pyplot(fig)
+                        plt.close(fig)
+
+                        st.success("Segmentation completed successfully!")
+
+                    except Exception as e:
+                        st.error(f"Error during segmentation: {str(e)}")
+                        st.write("Full error details:", e)
+
+        except Exception as e:
+            st.error(f"An error occurred: {str(e)}")
+            st.info("Please try again with a different image.")
+
+if __name__ == "__main__":
+    segmentAnythingModel()
\ No newline at end of file
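
Note on generate_random_colors in the patch above: the hand-rolled HSV-to-RGB conversion duplicates what the standard library already provides. A behavior-equivalent sketch using colorsys (part of the Python standard library, so it adds no dependency; hsv_to_rgb returns the final RGB components directly, which makes the chroma/offset arithmetic unnecessary):

    import colorsys
    import numpy as np

    def generate_random_colors(n):
        """Generate n visually distinct RGBA colors by spacing hues evenly."""
        colors = []
        for i in range(n):
            hue = i / n
            saturation = 0.7 + np.random.rand() * 0.3
            value = 0.7 + np.random.rand() * 0.3
            # hsv_to_rgb performs the same piecewise chroma math as the patch
            r, g, b = colorsys.hsv_to_rgb(hue, saturation, value)
            colors.append(np.array([r * 0.7, g * 0.7, b * 0.7, 0.5]))  # dimmed RGB plus fixed alpha
        return colors
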
From bd42e7e44a5c090894c04112be89e917417b01e5 Mon Sep 17 00:00:00 2001
From: Mohak
Date: Sun, 12 Jan 2025 23:06:58 +0530
Subject: [PATCH 2/5] modified various methodologies: everything, point, and box prompts

---
 .../segmentAnythingModel.py | 155 +++++++++++-------
 1 file changed, 100 insertions(+), 55 deletions(-)

diff --git a/src/apps/pages/models/ObjectDetectionModels/segmentAnythingModel.py b/src/apps/pages/models/ObjectDetectionModels/segmentAnythingModel.py
index e97f635f..255166af 100644
--- a/src/apps/pages/models/ObjectDetectionModels/segmentAnythingModel.py
+++ b/src/apps/pages/models/ObjectDetectionModels/segmentAnythingModel.py
@@ -19,12 +19,10 @@ def generate_random_colors(n):
     """Generate n distinct colors."""
     colors = []
     for i in range(n):
-        # Generate bright, distinct colors
         hue = i / n
         saturation = 0.7 + np.random.rand() * 0.3
         value = 0.7 + np.random.rand() * 0.3
 
-        # Convert HSV to RGB
         h = hue * 6
         c = value * saturation
         x = c * (1 - abs(h % 2 - 1))
         m = value - c
@@ -43,14 +41,13 @@ def generate_random_colors(n):
         else:
             rgb = (c, 0, x)
 
-        colors.append(np.array([(r + m) * 0.7 for r in rgb] + [0.5]))  # Add alpha value
+        colors.append(np.array([(r + m) * 0.7 for r in rgb] + [0.5]))
     return colors
 
 def show_masks_on_image(raw_image, masks, scores):
     """Display all masks overlaid on the same image with different colors."""
     plt.clf()
 
-    # Convert tensors to numpy arrays
     if isinstance(masks, tf.Tensor):
         masks = masks.numpy()
     if isinstance(scores, tf.Tensor):
         scores = scores.numpy()
@@ -59,36 +56,28 @@ def show_masks_on_image(raw_image, masks, scores):
     masks = np.squeeze(masks)
     scores = np.squeeze(scores)
 
-    # Handle single mask case
     if len(masks.shape) == 2:
         masks = np.expand_dims(masks, 0)
         scores = np.expand_dims(scores, 0)
 
-    # Create figure
     fig, ax = plt.subplots(figsize=(12, 8))
     ax.imshow(raw_image)
 
-    # Generate distinct colors for each mask
     colors = generate_random_colors(len(masks))
 
-    # Overlay each mask with a different color
     for i, (mask, score, color) in enumerate(zip(masks, scores, colors)):
-        # Show masked image
         mask_image = np.zeros((mask.shape[0], mask.shape[1], 4))
-        mask_image[:, :, 3] = mask * color[3]  # Alpha channel
-        for j in range(3):  # RGB channels
+        mask_image[:, :, 3] = mask * color[3]
+        for j in range(3):
            mask_image[:, :, j] = mask * color[j]
 
         ax.imshow(mask_image)
 
-        # Add label with score
-        label = f"Object {i+1} (Score: {float(score):.2f})"
-        # Find center of mass of the mask for label placement
         y_indices, x_indices = np.where(mask > 0.5)
         if len(x_indices) > 0 and len(y_indices) > 0:
             center_x = np.mean(x_indices)
             center_y = np.mean(y_indices)
-            ax.text(center_x, center_y, label,
+            ax.text(center_x, center_y, f"Score: {float(score):.2f}",
                     color='white', fontsize=8,
                     bbox=dict(facecolor='black', alpha=0.5),
                     ha='center', va='center')
@@ -97,14 +86,61 @@ def show_masks_on_image(raw_image, masks, scores):
     plt.tight_layout()
     return fig
 
+def process_and_show_masks(raw_image, model_outputs, processor_inputs, processor):
+    """Process model outputs and display segmentation masks."""
+    masks = processor.image_processor.post_process_masks(
+        model_outputs.pred_masks,
+        processor_inputs["original_sizes"],
+        processor_inputs["reshaped_input_sizes"],
+        return_tensors="tf",
+    )
+    fig = show_masks_on_image(raw_image, masks, model_outputs.iou_scores)
+    st.pyplot(fig)
+    plt.close(fig)
+
+def segment_everything(raw_image, model, processor):
+    """Segment all objects in the image using a grid of points."""
+    height, width = raw_image.size[1], raw_image.size[0]
+    grid_size = 32
+    x_points = np.linspace(0, width, num=grid_size)
+    y_points = np.linspace(0, height, num=grid_size)
+    input_points = [[[x, y] for x in x_points[::4] for y in y_points[::4]]]
+
+    inputs = processor(raw_image, input_points=input_points, return_tensors="tf")
+    outputs = model(**inputs)
+    return outputs, inputs
+
+def segment_with_box(raw_image, box_coords, model, processor):
+    """Segment objects within the specified box."""
+    input_boxes = [[[box_coords[0], box_coords[1], box_coords[2], box_coords[3]]]]
+    inputs = processor(raw_image, input_boxes=input_boxes, return_tensors="tf")
+    outputs = model(**inputs)
+    return outputs, inputs
+
+def segment_with_point(raw_image, point_coords, model, processor):
+    """Segment objects at the specified point."""
+    input_points = [[[point_coords[0], point_coords[1]]]]
+    inputs = processor(raw_image, input_points=input_points, return_tensors="tf")
+    outputs = model(**inputs)
+    return outputs, inputs
+
+def segment_with_text(raw_image, text_prompt, model, processor):
+    """Segment objects matching the text description."""
+    # NOTE: TFSamModel has no text encoder, so text_prompt is currently unused
+    # and this call runs the model without any prompt guidance.
+    inputs = processor(raw_image, return_tensors="tf")
+    outputs = model(**inputs)
+    return outputs, inputs
+
 def segmentAnythingModel():
-    st.title("Segment Anything Model (SAM)")
+    st.title("Advanced Segment Anything Model (SAM)")
     st.write("""
-    Upload an image to automatically segment all objects in the scene.
-    Each object will be highlighted with a different color.
+    Choose a segmentation mode and see the results!
+    - Segment Everything: Detects and segments all objects in the image
+    - Box Prompt: Draw a box around the area you want to segment
+    - Point Prompt: Click a point to segment objects at that location
+    - Text Prompt: Describe what you want to segment
     """)
 
-    # Load model at the start
+    # Load model
     with st.spinner("Loading model..."):
         try:
             model, processor = load_model()
@@ -118,49 +154,58 @@ def segmentAnythingModel():
 
     if uploaded_file is not None:
         try:
-            # Display original image
             raw_image = Image.open(uploaded_file).convert("RGB")
             st.image(raw_image, caption="Original Image", use_column_width=True)
 
-            # Generate grid of points across the image for complete segmentation
-            height, width = raw_image.size[1], raw_image.size[0]
-            grid_size = 50  # Adjust this value to control segmentation density
-            x_points = np.linspace(0, width, num=grid_size)
-            y_points = np.linspace(0, height, num=grid_size)
-
-            input_points = [[[x, y] for x in x_points[::4] for y in y_points[::4]]]
-
-            if st.button("Generate Segmentation"):
-                with st.spinner("Generating segmentation..."):
-                    try:
-                        # Process the image with grid points
-                        inputs = processor(raw_image, input_points=input_points, return_tensors="tf")
-
-                        # Run model inference
-                        outputs = model(**inputs)
-
-                        # Post-process masks
-                        masks = processor.image_processor.post_process_masks(
-                            outputs.pred_masks,
-                            inputs["original_sizes"],
-                            inputs["reshaped_input_sizes"],
-                            return_tensors="tf",
-                        )
-
-                        # Create and display figure
-                        fig = show_masks_on_image(raw_image, masks, outputs.iou_scores)
-                        st.pyplot(fig)
-                        plt.close(fig)
-
-                        st.success("Segmentation completed successfully!")
+            # Segmentation mode selection
+            mode = st.selectbox(
+                "Select Segmentation Mode",
+                ["Segment Everything", "Box Prompt", "Point Prompt", "Text Prompt"]
+            )
+
+            if mode == "Segment Everything":
+                if st.button("Generate Complete Segmentation"):
+                    with st.spinner("Segmenting all objects..."):
+                        outputs, inputs = segment_everything(raw_image, model, processor)
+                        process_and_show_masks(raw_image, outputs, inputs, processor)
 
-                    except Exception as e:
-                        st.error(f"Error during segmentation: {str(e)}")
-                        st.write("Full error details:", e)
+            elif mode == "Box Prompt":
+                st.write("Enter box coordinates:")
+                col1, col2 = st.columns(2)
+                with col1:
+                    x_min = st.number_input("X min", 0, raw_image.size[0])
+                    y_min = st.number_input("Y min", 0, raw_image.size[1])
+                with col2:
+                    x_max = st.number_input("X max", x_min, raw_image.size[0])
+                    y_max = st.number_input("Y max", y_min, raw_image.size[1])
+
+                if st.button("Segment with Box"):
+                    with st.spinner("Segmenting selected area..."):
+                        outputs, inputs = segment_with_box(raw_image, [x_min, y_min, x_max, y_max], model, processor)
+                        process_and_show_masks(raw_image, outputs, inputs, processor)
+
+            elif mode == "Point Prompt":
+                col1, col2 = st.columns(2)
+                with col1:
+                    x_coord = st.number_input("X coordinate", 0, raw_image.size[0])
+                with col2:
+                    y_coord = st.number_input("Y coordinate", 0, raw_image.size[1])
+
+                if st.button("Segment at Point"):
+                    with st.spinner("Segmenting at point..."):
+                        outputs, inputs = segment_with_point(raw_image, [x_coord, y_coord], model, processor)
+                        process_and_show_masks(raw_image, outputs, inputs, processor)
+
+            elif mode == "Text Prompt":
+                text_prompt = st.text_input("Describe what you want to segment")
+                if st.button("Segment with Text") and text_prompt:
+                    with st.spinner("Segmenting based on description..."):
+                        outputs, inputs = segment_with_text(raw_image, text_prompt, model, processor)
+                        process_and_show_masks(raw_image, outputs, inputs, processor)
 
         except Exception as e:
             st.error(f"An error occurred: {str(e)}")
-            st.info("Please try again with a different image.")
+            st.info("Please try again with a different input.")
 
 if __name__ == "__main__":
     segmentAnythingModel()
\ No newline at end of file
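
A usability note on the Point Prompt mode added above: typing pixel coordinates into number inputs is awkward for users. A click-to-segment sketch, assuming the third-party streamlit-image-coordinates package (not in this PR's requirements; the function name and the {"x": ..., "y": ...} return shape follow that package's README and are not guaranteed here):

    # pip install streamlit-image-coordinates  (third-party, not in requirements.txt)
    from streamlit_image_coordinates import streamlit_image_coordinates

    # Renders the image and returns {"x": ..., "y": ...} for the last click, or None.
    coords = streamlit_image_coordinates(raw_image)
    if coords is not None:
        # Click coordinates are relative to the rendered size and may need
        # rescaling to the original resolution before being passed to SAM.
        outputs, inputs = segment_with_point(raw_image, [coords["x"], coords["y"]], model, processor)
        process_and_show_masks(raw_image, outputs, inputs, processor)
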
From 75f9961eb13c8b7db53fbf919dd5b4c26dedad7f Mon Sep 17 00:00:00 2001
From: Mohak
Date: Mon, 13 Jan 2025 18:24:46 +0530
Subject: [PATCH 3/5] resolved issue

---
 requirements.txt            | Bin 2018 -> 2014 bytes
 .../segmentAnythingModel.py |   3 ---
 2 files changed, 3 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index 8972e1a4988708eaf980c5e52b50f7d0f01ab17f..a80047e2538522d48b34ec00f92658f87f2424d2 100644
GIT binary patch
delta 7
OcmaFFe~*8|J$3*N^aFGN

delta 12
Tcmcb||A>FXJ$4pe1}+8wAdCZ#

diff --git a/src/apps/pages/models/ObjectDetectionModels/segmentAnythingModel.py b/src/apps/pages/models/ObjectDetectionModels/segmentAnythingModel.py
index 255166af..1f345620 100644
--- a/src/apps/pages/models/ObjectDetectionModels/segmentAnythingModel.py
+++ b/src/apps/pages/models/ObjectDetectionModels/segmentAnythingModel.py
@@ -206,6 +206,3 @@ def segmentAnythingModel():
         except Exception as e:
             st.error(f"An error occurred: {str(e)}")
             st.info("Please try again with a different input.")
-
-if __name__ == "__main__":
-    segmentAnythingModel()
\ No newline at end of file
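
A note on the Text Prompt mode that survives this cleanup: segment_with_text never uses its text_prompt argument, because facebook/sam-vit-base only understands point and box prompts. If genuinely text-driven masks are wanted later, one option is CLIPSeg; the sketch below is illustrative only and assumes the PyTorch CIDAS/clipseg-rd64-refined checkpoint from transformers, which is not among this PR's dependencies:

    import torch
    from PIL import Image
    from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation

    processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
    model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")

    image = Image.open("example.jpg").convert("RGB")  # hypothetical input image
    inputs = processor(text=["a dog"], images=[image], return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    # Low-resolution mask logits; apply sigmoid and resize to the original
    # image size before thresholding into a binary mask.
    mask = torch.sigmoid(outputs.logits)

Another route is a text-to-box detector (for example Grounding DINO), whose boxes could feed the existing segment_with_box path.
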
From 6a81a852b4c9428bf3ca8fcff52f426ac9f62065 Mon Sep 17 00:00:00 2001
From: Mohak
Date: Tue, 14 Jan 2025 13:44:19 +0530
Subject: [PATCH 4/5] discard changes as requested

---
 requirements.txt | Bin 2014 -> 1916 bytes
 1 file changed, 0 insertions(+), 0 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index a80047e2538522d48b34ec00f92658f87f2424d2..ef5b47b9fcf7732a2952791ae6cafc6a1e22e4c7 100644
GIT binary patch
delta 7
Ocmcb||A%iw4Lbl1rUNSg

delta 67
zcmeyvcaMKV4Z9&P0~bR&LncEBgEm73kSt*+U?^s=V$cUdh

From: Mohak
Date: Fri, 17 Jan 2025 00:36:10 +0530
Subject: [PATCH 5/5] added dog classifier

---
 .streamlit/secrets.example.toml    |   3 +
 .../DogBreedClassificationModel.py | 145 ++++++++++++++++++
 2 files changed, 148 insertions(+)
 create mode 100644 src/apps/pages/models/ObjectDetectionModels/DogBreedClassificationModel.py

diff --git a/.streamlit/secrets.example.toml b/.streamlit/secrets.example.toml
index f8fe2178..daf648c1 100644
--- a/.streamlit/secrets.example.toml
+++ b/.streamlit/secrets.example.toml
@@ -44,3 +44,6 @@ YOLOV8S="1s3abY6DUaIE0w54F9Bq-HbcTkiruvbZn"
 
 [segmentAnythingModel]
 MODEL_PATH="facebook/sam-vit-base"
+
+[DogBreedClassificationModel]
+MODEL="1O0vm50puB3FpWWpS53mB1_FazCRJXjvp"
\ No newline at end of file
diff --git a/src/apps/pages/models/ObjectDetectionModels/DogBreedClassificationModel.py b/src/apps/pages/models/ObjectDetectionModels/DogBreedClassificationModel.py
new file mode 100644
index 00000000..8163a6d1
--- /dev/null
+++ b/src/apps/pages/models/ObjectDetectionModels/DogBreedClassificationModel.py
@@ -0,0 +1,145 @@
+import tensorflow as tf
+import numpy as np
+import streamlit as st
+from PIL import Image
+import gdown
+import os
+
+@st.cache_resource
+def load_model():
+    """Recreate the model architecture and load weights"""
+
+    try:
+        gdown.download(f"https://drive.google.com/uc?id={st.secrets['DogBreedClassificationModel']['MODEL']}", 'Model.h5', quiet=False)
+        # Create base model
+        base_model = tf.keras.applications.MobileNetV2(
+            input_shape=(224, 224, 3),
+            include_top=False,
+            weights='imagenet'
+        )
+        base_model.trainable = False
+
+        # Recreate the model architecture
+        model = tf.keras.Sequential([
+            base_model,
+            tf.keras.layers.Flatten(),
+            tf.keras.layers.Dropout(0.5),
+            tf.keras.layers.Dense(120, activation='softmax',
+                                  kernel_regularizer=tf.keras.regularizers.l2(0.01))
+        ])
+
+        # Load weights
+        try:
+            model.load_weights("Model.h5")
+        except Exception:
+            # Try loading as a TensorFlow checkpoint
+            model.load_weights(tf.train.latest_checkpoint("./"))
+
+        return model
+    except Exception as e:
+        st.error(f"Error loading model: {e}")
+        return None
+
+def preprocess_image(img):
+    """Preprocess image for MobileNetV2"""
+    # Convert to RGB if not already
+    img = img.convert('RGB')
+    # Resize to MobileNetV2 expected size
+    img = img.resize((224, 224))
+    # Convert to array and expand dimensions
+    img_array = tf.keras.preprocessing.image.img_to_array(img)
+    img_array = np.expand_dims(img_array, axis=0)
+    # Preprocess for MobileNetV2
+    img_array = tf.keras.applications.mobilenet_v2.preprocess_input(img_array)
+    return img_array
+
+def upload_dog_img():
+    img = st.file_uploader("Upload a dog image", type=["jpg","jpeg","png"])
+    if img is not None:
+        try:
+            st.image(img, width=300, caption="Uploaded Image")
+            img_data = Image.open(img)
+            processed_img = preprocess_image(img_data)
+            return processed_img
+        except Exception as e:
+            st.error(f"Error processing image: {e}")
+            return None
+    else:
+        st.warning("Please upload an image")
+        return None
+
+def predict_breed(x, breed_labels):
+    """
+    Predict dog breed from preprocessed image
+    breed_labels should be a list of breed names corresponding to model output indices
+    """
+    model = load_model()
+    if model is None:
+        st.error("Failed to load model")
+        return None, 0
+
+    try:
+        predictions = model.predict(x)
+        predicted_idx = np.argmax(predictions[0])
+        confidence = predictions[0][predicted_idx]
+        return breed_labels[predicted_idx], confidence
+    except Exception as e:
+        st.error(f"Error making prediction: {e}")
+        return None, 0
+
+def DogBreedClassificationModel():
+    st.title("Dog Breed Classifier")
+    st.write("Upload a photo of a dog to identify its breed")
+
+    # Define the 120 dog breed labels
+    breed_labels = ['affenpinscher', 'afghan_hound', 'african_hunting_dog', 'airedale',
+                    'american_staffordshire_terrier', 'appenzeller',
+                    'australian_terrier', 'basenji', 'basset', 'beagle',
+                    'bedlington_terrier', 'bernese_mountain_dog',
+                    'black-and-tan_coonhound', 'blenheim_spaniel', 'bloodhound',
+                    'bluetick', 'border_collie', 'border_terrier', 'borzoi',
+                    'boston_bull', 'bouvier_des_flandres', 'boxer',
+                    'brabancon_griffon', 'briard', 'brittany_spaniel', 'bull_mastiff',
+                    'cairn', 'cardigan', 'chesapeake_bay_retriever', 'chihuahua',
+                    'chow', 'clumber', 'cocker_spaniel', 'collie',
+                    'curly-coated_retriever', 'dandie_dinmont', 'dhole', 'dingo',
+                    'doberman', 'english_foxhound', 'english_setter',
+                    'english_springer', 'entlebucher', 'eskimo_dog',
+                    'flat-coated_retriever', 'french_bulldog', 'german_shepherd',
+                    'german_short-haired_pointer', 'giant_schnauzer',
+                    'golden_retriever', 'gordon_setter', 'great_dane',
+                    'great_pyrenees', 'greater_swiss_mountain_dog', 'groenendael',
+                    'ibizan_hound', 'irish_setter', 'irish_terrier',
+                    'irish_water_spaniel', 'irish_wolfhound', 'italian_greyhound',
+                    'japanese_spaniel', 'keeshond', 'kelpie', 'kerry_blue_terrier',
+                    'komondor', 'kuvasz', 'labrador_retriever', 'lakeland_terrier',
+                    'leonberg', 'lhasa', 'malamute', 'malinois', 'maltese_dog',
+                    'mexican_hairless', 'miniature_pinscher', 'miniature_poodle',
+                    'miniature_schnauzer', 'newfoundland', 'norfolk_terrier',
+                    'norwegian_elkhound', 'norwich_terrier', 'old_english_sheepdog',
+                    'otterhound', 'papillon', 'pekinese', 'pembroke', 'pomeranian',
+                    'pug', 'redbone', 'rhodesian_ridgeback', 'rottweiler',
+                    'saint_bernard', 'saluki', 'samoyed', 'schipperke',
+                    'scotch_terrier', 'scottish_deerhound', 'sealyham_terrier',
+                    'shetland_sheepdog', 'shih-tzu', 'siberian_husky', 'silky_terrier',
+                    'soft-coated_wheaten_terrier', 'staffordshire_bullterrier',
+                    'standard_poodle', 'standard_schnauzer', 'sussex_spaniel',
+                    'tibetan_mastiff', 'tibetan_terrier', 'toy_poodle', 'toy_terrier',
+                    'vizsla', 'walker_hound', 'weimaraner', 'welsh_springer_spaniel',
+                    'west_highland_white_terrier', 'whippet',
+                    'wire-haired_fox_terrier', 'yorkshire_terrier']
+
+    input_img = upload_dog_img()
+
+    if input_img is not None:
+        if st.button("Predict Breed"):
+            with st.spinner('Making prediction...'):
+                breed, confidence = predict_breed(input_img, breed_labels)
+                if breed is not None:
+                    try:
+                        user = st.session_state["user"].split(',')
+                        name = user[2] + " " + user[3]
+                    except Exception:
+                        name = "there"
+                    st.success(f'Hi {name},\nThis appears to be a {breed} with {confidence:.2%} confidence! 🐕',
+                               icon="🎉")
\ No newline at end of file
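
A possible follow-up for the classifier above (a hypothetical extension, not part of this PR; predict_top_k is a new name, while load_model and breed_labels refer to the definitions in the patch): reporting the top few candidates instead of a single argmax makes low-confidence predictions easier to judge:

    import numpy as np

    def predict_top_k(x, breed_labels, k=3):
        """Return the k most likely breeds with their softmax confidences."""
        model = load_model()
        if model is None:
            return []
        predictions = model.predict(x)[0]
        top_idx = np.argsort(predictions)[::-1][:k]  # indices sorted by descending probability
        return [(breed_labels[i], float(predictions[i])) for i in top_idx]

The UI could then render each (breed, confidence) pair with st.write, falling back to the single best guess when the top score clearly dominates.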