Skip to content

Commit

Permalink
added mps, added ark-analysis as a dep
Browse files Browse the repository at this point in the history
  • Loading branch information
srivarra committed Mar 4, 2024
1 parent 3d0ae97 commit 6a97c05
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 5 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ urls.Documentation = "https://Nimbus-Inference.readthedocs.io/"
urls.Source = "https://github.com/angelolab/Nimbus-Inference"
urls.Home-page = "https://github.com/angelolab/Nimbus-Inference"
dependencies = [
"ark-analysis",
"torch==2.2.0",
"torchvision==0.17.0",
"alpineer",
Expand Down
20 changes: 16 additions & 4 deletions src/nimbus_inference/nimbus.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from alpineer import io_utils
from alpineer import io_utils, misc_utils
from skimage.util.shape import view_as_windows
import nimbus_inference
from nimbus_inference.utils import (
Expand Down Expand Up @@ -81,7 +81,7 @@ class Nimbus(nn.Module):
def __init__(
self, fov_paths, segmentation_naming_convention, output_dir, save_predictions=True,
include_channels=[], half_resolution=True, batch_size=4, test_time_aug=True,
input_shape=[1024, 1024], suffix=".tiff",
input_shape=[1024, 1024], suffix=".tiff", device="auto",
):
"""Initializes a Nimbus Application.
Args:
Expand All @@ -96,6 +96,8 @@ def __init__(
test_time_aug (bool): Whether to use test time augmentation.
input_shape (list): Shape of input images.
suffix (str): Suffix of images to load.
device (str): Device to run model on, either "auto" (either "mps" or "cuda"
, with "cpu" as a fallback), "cpu", "cuda", or "mps". Defaults to "auto".
"""
super(Nimbus, self).__init__()
self.fov_paths = fov_paths
Expand All @@ -111,7 +113,17 @@ def __init__(
self.suffix = suffix
if self.output_dir != "":
os.makedirs(self.output_dir, exist_ok=True)
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if device == "auto":
if torch.backends.mps.is_available():
self.device = torch.device("mps")
elif torch.cuda.is_available():
self.device = torch.device("cuda")
else:
self.device = torch.device("cpu")
else:
misc_utils.verify_in_list(device=[device], valid_devices=["cpu", "cuda", "mps"])
self.device = torch.device(device)

def check_inputs(self):
"""check inputs for Nimbus model"""
Expand Down Expand Up @@ -313,7 +325,7 @@ def _tile_input(self, image, tile_size, output_shape, pad_mode="reflect"):
image = np.pad(image, ((0, 0), (0, 0), (pad_h0, pad_h1), (pad_w0, pad_w1)), mode=pad_mode)
b, c = image.shape[:2]
# tile image
view = np.squeeze(
view = np.squeeze(
view_as_windows(image, [b, c] + list(tile_size), step=[b, c] + list(output_shape)),
axis=(0,1)
)
Expand Down
2 changes: 1 addition & 1 deletion templates/1_Nimbus_Predict.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": null,
"id": "76225704",
"metadata": {
"scrolled": true
Expand Down

0 comments on commit 6a97c05

Please sign in to comment.