From 6a97c05b05a26e5e9ab08998c81f2dd4a799d218 Mon Sep 17 00:00:00 2001
From: Sricharan Reddy Varra
Date: Mon, 4 Mar 2024 15:53:18 -0800
Subject: [PATCH] added mps, added ark-analysis as a dep

---
 pyproject.toml                   |  1 +
 src/nimbus_inference/nimbus.py   | 20 ++++++++++++++++----
 templates/1_Nimbus_Predict.ipynb |  2 +-
 3 files changed, 18 insertions(+), 5 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 9642cc8..b5202d9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -19,6 +19,7 @@ urls.Documentation = "https://Nimbus-Inference.readthedocs.io/"
 urls.Source = "https://github.com/angelolab/Nimbus-Inference"
 urls.Home-page = "https://github.com/angelolab/Nimbus-Inference"
 dependencies = [
+    "ark-analysis",
     "torch==2.2.0",
     "torchvision==0.17.0",
     "alpineer",
diff --git a/src/nimbus_inference/nimbus.py b/src/nimbus_inference/nimbus.py
index 56616b2..4850e8c 100644
--- a/src/nimbus_inference/nimbus.py
+++ b/src/nimbus_inference/nimbus.py
@@ -1,4 +1,4 @@
-from alpineer import io_utils
+from alpineer import io_utils, misc_utils
 from skimage.util.shape import view_as_windows
 import nimbus_inference
 from nimbus_inference.utils import (
@@ -81,7 +81,7 @@ class Nimbus(nn.Module):
     def __init__(
         self, fov_paths, segmentation_naming_convention, output_dir, save_predictions=True,
         include_channels=[], half_resolution=True, batch_size=4, test_time_aug=True,
-        input_shape=[1024, 1024], suffix=".tiff",
+        input_shape=[1024, 1024], suffix=".tiff", device="auto",
     ):
         """Initializes a Nimbus Application.
         Args:
@@ -96,6 +96,8 @@ def __init__(
             test_time_aug (bool): Whether to use test time augmentation.
             input_shape (list): Shape of input images.
             suffix (str): Suffix of images to load.
+            device (str): Device to run model on, either "auto" (either "mps" or "cuda"
+                , with "cpu" as a fallback), "cpu", "cuda", or "mps". Defaults to "auto".
         """
         super(Nimbus, self).__init__()
         self.fov_paths = fov_paths
@@ -111,7 +113,17 @@ def __init__(
         self.suffix = suffix
         if self.output_dir != "":
             os.makedirs(self.output_dir, exist_ok=True)
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+        if device == "auto":
+            if torch.backends.mps.is_available():
+                self.device = torch.device("mps")
+            elif torch.cuda.is_available():
+                self.device = torch.device("cuda")
+            else:
+                self.device = torch.device("cpu")
+        else:
+            misc_utils.verify_in_list(device=[device], valid_devices=["cpu", "cuda", "mps"])
+            self.device = torch.device(device)
 
     def check_inputs(self):
         """check inputs for Nimbus model"""
@@ -313,7 +325,7 @@ def _tile_input(self, image, tile_size, output_shape, pad_mode="reflect"):
         image = np.pad(image, ((0, 0), (0, 0), (pad_h0, pad_h1), (pad_w0, pad_w1)), mode=pad_mode)
         b, c = image.shape[:2]
         # tile image
-        view = np.squeeze( 
+        view = np.squeeze(
             view_as_windows(image, [b, c] + list(tile_size), step=[b, c] + list(output_shape)),
             axis=(0,1)
         )
diff --git a/templates/1_Nimbus_Predict.ipynb b/templates/1_Nimbus_Predict.ipynb
index 5c5694d..dbdd8e6 100644
--- a/templates/1_Nimbus_Predict.ipynb
+++ b/templates/1_Nimbus_Predict.ipynb
@@ -238,7 +238,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 9,
+   "execution_count": null,
    "id": "76225704",
    "metadata": {
     "scrolled": true
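
Below is a minimal usage sketch of the new `device` argument, assuming the constructor signature shown in the hunks above. It is not part of the patch: the FOV folders, output directory, and segmentation naming helper are hypothetical placeholders; only the `device` keyword, its fallback order, and the verify_in_list validation come from the diff itself.

    # Illustrative sketch only (not part of the patch). Paths and the
    # segmentation naming helper are hypothetical placeholders.
    import os
    from nimbus_inference.nimbus import Nimbus

    def segmentation_naming_convention(fov_path):
        # Hypothetical: map a FOV folder to its segmentation mask file.
        return os.path.join(fov_path, "segmentation_labels.tiff")

    nimbus = Nimbus(
        fov_paths=["data/fov0", "data/fov1"],       # placeholder FOV folders
        segmentation_naming_convention=segmentation_naming_convention,
        output_dir="nimbus_output",                 # placeholder output folder
        device="auto",  # tries "mps", then "cuda", then falls back to "cpu"
    )

    # Passing an explicit device instead is validated against ["cpu", "cuda", "mps"]
    # via misc_utils.verify_in_list before torch.device(device) is constructed:
    # nimbus_cpu = Nimbus(..., device="cpu")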