Skip to content

Commit

Permalink
Return predicted spot intensities
Browse files Browse the repository at this point in the history
  • Loading branch information
AlbertDominguez committed Mar 17, 2024
1 parent 9d9269a commit 35bac9f
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 3 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ model = Spotiflow.from_pretrained("general")
# model = Spotiflow.from_folder("./mymodel")

# Predict
points, details = model.predict(img) # points contains the coordinates of the detected spots, the attributes 'heatmap' and 'flow' of `details` contains the predicted full resolution heatmap and the prediction of the stereographic flow respectively (access them by `details.heatmap` or `details.flow`).
points, details = model.predict(img) # points contains the coordinates of the detected spots, the attributes 'heatmap' and 'flow' of `details` contain the predicted full resolution heatmap and the prediction of the stereographic flow respectively (access them by `details.heatmap` or `details.flow`). Retrieved spot intensities are found in `details.intens`.
```

### Napari plugin
Expand Down
15 changes: 13 additions & 2 deletions spotiflow/model/spotiflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from ..data import SpotsDataset
from ..utils import (
bilinear_interp_points,
center_crop,
center_pad,
filter_shape,
Expand Down Expand Up @@ -590,7 +591,7 @@ def predict(
device (Optional[Union[torch.device, Literal["auto", "cpu", "cuda", "mps"]]], optional): computing device to use. If None, will infer from model location. If "auto", will infer from available hardware. Defaults to None.
Returns:
Tuple[np.ndarray, SimpleNamespace]: Tuple of (points, details). Points are the coordinates of the spots. Details is a namespace containing the spot-wise probabilities, the heatmap and the 2D flow field.
Tuple[np.ndarray, SimpleNamespace]: Tuple of (points, details). Points are the coordinates of the spots. Details is a namespace containing the spot-wise probabilities (`prob`), the heatmap (`heatmap`), the stereographic flow (`flow`), the 2D local offset vector field (`subpix`) and the spot intensities (`intens`).
"""

if subpix is False:
Expand Down Expand Up @@ -804,7 +805,17 @@ def predict(
_subpix = None
flow = None

details = SimpleNamespace(prob=probs, heatmap=y, subpix=_subpix, flow=flow)
# Retrieve intensity of the spots
if subpix_radius < 0: # no need to interpolate if subpixel precision is not used
intens = img[tuple(pts.astype(int).T)]
else:
try:
intens = bilinear_interp_points(img, pts)
except Exception as _:
log.warn("Bilinear interpolation failed to retrive spot intensities. Will use nearest neighbour interpolation instead.")
intens = img[tuple(pts.round().astype(int).T)]

details = SimpleNamespace(prob=probs, heatmap=y, subpix=_subpix, flow=flow, intens=intens)
return pts, details

def predict_dataset(
Expand Down
41 changes: 41 additions & 0 deletions spotiflow/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,3 +380,44 @@ def read_npz_dataset(fname: Union[Path, str]) -> Tuple[np.ndarray, ...]:
else:
raise ValueError(f"Unexpected key {key} in .npz file {fname}")
return ret_data

def bilinear_interp_points(img: np.ndarray, pts: np.ndarray, eps: float=1e-9) -> np.ndarray:
""" Return the bilinearly interpolated iamge intensities at each (subpixel) location.
Args:
img (np.ndarray): image in YX or YXC format.
pts (np.ndarray): spot locations to interpolate the intensities from. Array shape should be (N,2).
eps (float, optional): will clip spot locations to SHAPE-eps to avoid numerical issues at image border. Defaults to 1e-9.
Returns:
np.ndarray: array of shape (N,C) containing intensities for each spot
"""
assert img.ndim in (2,3), "Expected YX or YXC image for interpolating intensities."
assert pts.shape[1] == 2, "Point coordinates to be interpolated should be an (N,2) array"

if img.ndim == 2:
img = img[..., None]

if pts.shape[0] == 0:
return np.zeros((0, img.shape[-1]), dtype=img.dtype)
ys, xs = pts[:, 0], pts[:, 1]

# Avoid out of bounds coordinates
ys.clip(0, img.shape[0]-1-eps, out=ys)
xs.clip(0, img.shape[1]-1-eps, out=xs)

pys = np.floor(ys).astype(int)
pxs = np.floor(xs).astype(int)

# Differences to floored coordinates
dys = ys-pys
dxs = xs-pxs
wxs, wys = 1.-dxs, 1.-dys

# Interpolate
weights = np.multiply(img[pys, pxs, :].T , wxs*wys).T
weights += np.multiply(img[pys, pxs+1, :].T , dxs*wys).T
weights += np.multiply(img[pys+1, pxs, :].T , wxs*dys).T
weights += np.multiply(img[pys+1, pxs+1, :].T , dxs*dys).T
return weights

0 comments on commit 35bac9f

Please sign in to comment.