Commit 5eeb82d: merged with main

anushka255 committed Nov 14, 2024
2 parents cb6bafd + 2794160
Showing 20 changed files with 589 additions and 214 deletions.
.github/workflows/test_pinned_deps.yml (5 additions, 0 deletions)

@@ -23,6 +23,11 @@ jobs:
         shell: bash -l {0}
 
     steps:
+      - name: Set environment variables
+        # needed for testing the napari plugin
+        if: matrix.os == 'ubuntu-latest'
+        run: echo "QT_QPA_PLATFORM=offscreen" >> $GITHUB_ENV
+
       - uses: actions/checkout@main
 
       - uses: conda-incubator/setup-miniconda@v3
.github/workflows/test_unpinned_deps.yml (5 additions, 0 deletions)

@@ -18,6 +18,11 @@ jobs:
     runs-on: ${{ matrix.os }}
 
     steps:
+      - name: Set environment variables
+        # needed for testing the napari plugin
+        if: matrix.os == 'ubuntu-latest'
+        run: echo "QT_QPA_PLATFORM=offscreen" >> $GITHUB_ENV
+
       - uses: actions/checkout@v4
 
       - uses: actions/setup-python@v5
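Both workflow files gain the same step: on Ubuntu runners Qt needs a display to initialize, so QT_QPA_PLATFORM=offscreen lets the napari plugin tests run headlessly. A rough local equivalent (an illustrative sketch, not part of this diff; the test selection is invented):

# Force Qt to render offscreen before anything imports Qt, then run the
# plugin tests; mirrors what the CI step above does via $GITHUB_ENV.
import os
os.environ.setdefault("QT_QPA_PLATFORM", "offscreen")

import pytest
pytest.main(["tests", "-k", "napari"])  # hypothetical test selection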
pyproject.toml (7 additions, 3 deletions)

@@ -3,7 +3,7 @@ requires = ["setuptools>=61.0", "setuptools_scm>=6.2"]
 build-backend = "setuptools.build_meta"
 
 [project]
-name = "paste"
+name = "paste3"
 authors = [
     {name="Max Land", email="[email protected]"}
 ]
@@ -14,6 +14,7 @@ classifiers = [
     "Programming Language :: Python :: 3",
     "License :: OSI Approved :: BSD License",
     "Operating System :: OS Independent",
+    "Framework :: Napari",
 ]
 
 dependencies = [
@@ -26,7 +27,8 @@ dependencies = [
     "IPython",
     "statsmodels",
     "torch",
-    "torchnmf"
+    "torchnmf",
+    "pooch"
 ]
 dynamic = ["version"]
 
@@ -42,7 +44,9 @@ dev = [
     "coveralls",
     "ruff",
     "pre-commit",
-    "napari"
+    "napari",
+    "pytest-qt",
+    "PyQt5"
 ]
 
 docs = [
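The dev extras now include pytest-qt and PyQt5, which is what the new CI environment variable supports. A minimal sketch of a pytest-qt test (the widget and assertion are illustrative, not from this repository):

# pytest-qt injects the `qtbot` fixture; PyQt5 provides the widgets.
from PyQt5.QtWidgets import QLabel

def test_label_renders(qtbot):
    label = QLabel("paste3")
    qtbot.addWidget(label)  # registers the widget for cleanup after the test
    assert label.text() == "paste3"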
requirements.txt (4 additions, 0 deletions)

@@ -32,6 +32,7 @@ MarkupSafe==2.1.5
 matplotlib==3.9.2
 matplotlib-inline==0.1.7
 mpmath==1.3.0
+napari==0.5.4
 natsort==8.4.0
 networkx==3.3
 nodeenv==1.9.1
@@ -45,6 +46,7 @@ pexpect==4.9.0
 pillow==10.4.0
 platformdirs==4.3.6
 pluggy==1.5.0
+pooch==1.8.2
 POT==0.9.5
 pre-commit==3.8.0
 prompt_toolkit==3.0.48
@@ -55,9 +57,11 @@ pynndescent==0.5.13
 pyparsing==3.1.4
 pyproject_hooks==1.2.0
 pytest==8.3.3
+pytest-qt==4.4.0
 pytest-xdist==3.6.1
 python-dateutil==2.9.0.post0
 pytz==2024.2
+PyQt5==5.15.11
 PyYAML==6.0.2
 requests==2.32.3
 ruff==0.6.8
scripts/workflow.py (0 additions, 34 deletions)

This file was deleted.
src/paste3/experimental.py → src/paste3/dataset.py (renamed; 62 additions, 27 deletions)

@@ -1,4 +1,5 @@
 import logging
+import os.path
 from functools import cached_property
 from glob import glob
 from pathlib import Path
@@ -33,11 +34,11 @@ def __init__(self, filepath: Path | None = None, adata: AnnData | None = None):
     def __str__(self):
         if self.filepath is not None:
             return Path(self.filepath).stem
-        return "Slice with adata: " + str(self.adata)
+        return "Slice with adata: " + str(self.adata).split("\n")[0]
 
     @cached_property
     def adata(self):
-        return self._adata or sc.read_h5ad(self.filepath)
+        return self._adata or sc.read_h5ad(str(self.filepath))
 
     @cached_property
     def obs(self):
@@ -53,12 +54,8 @@ def obs(self):
     def get_obs_values(self, which, coordinates=None):
         assert which in self.obs.columns, f"Unknown column: {which}"
         if coordinates is None:
-            assert (
-                self.has_spatial_data
-            ), "Slice has no `spatial` obsm. Don't know which coordinates to query"
-            coordinates = self.adata.obsm["spatial"]
-
-        return self.obs.loc[coordinates.tolist()][which].tolist()
+            coordinates = self.obs.index.values
+        return self.obs.loc[coordinates][which].tolist()
 
     def set_obs_values(self, which, values):
         self.obs[which] = values
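get_obs_values no longer requires a `spatial` obsm: with no coordinates given, it now falls back to all observation indices. A hypothetical call (the file name is invented):

# Sketch of the updated Slice.get_obs_values; "slice_0.h5ad" is invented.
from pathlib import Path
from paste3.dataset import Slice

slice_ = Slice(Path("slice_0.h5ad"))
clusters = slice_.get_obs_values("original_clusters")  # defaults to every obs index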
@@ -86,20 +83,33 @@ def cluster(
 class AlignmentDataset:
     def __init__(
         self,
+        file_paths: list[Path] | None = None,
         glob_pattern: str | None = None,
         slices: list[Slice] | None = None,
         max_slices: int | None = None,
+        name: str | None = None,
     ):
         if slices is not None:
             self.slices = slices[:max_slices]
-        else:
+        elif glob_pattern is not None:
             self.slices = [
                 Slice(filepath)
                 for filepath in sorted(glob(glob_pattern))[:max_slices]  # noqa: PTH207
             ]
+        else:
+            self.slices = [Slice(filepath) for filepath in file_paths[:max_slices]]
+
+        if name is not None:
+            self.name = name
+        else:
+            # Take common prefix of slice names, but remove the word "slice"
+            # and any trailing underscores
+            name = os.path.commonprefix([str(slice_) for slice_ in self])
+            name = name.replace("slice", "").rstrip("_")
+            self.name = name
 
     def __str__(self):
-        return f"Data with {len(self.slices)} slices"
+        return self.name
 
     def __iter__(self):
         return iter(self.slices)
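AlignmentDataset can now be built three ways (explicit file_paths, a glob_pattern, or pre-built slices), and it derives a name from the common prefix of its slice names when none is supplied. A sketch under invented paths:

# Sketch of the three construction modes; all paths are invented.
from pathlib import Path
from paste3.dataset import AlignmentDataset

from_glob = AlignmentDataset(glob_pattern="data/slice_*.h5ad")
from_paths = AlignmentDataset(
    file_paths=[Path("data/slice_0.h5ad"), Path("data/slice_1.h5ad")]
)
named = AlignmentDataset(glob_pattern="data/slice_*.h5ad", name="my_dataset")
print(str(from_glob))  # common prefix of slice names, minus "slice" and trailing "_"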
@@ -118,7 +128,7 @@ def align(
         self,
         center_align: bool = False,
         pis: np.ndarray | None = None,
-        overlap_fraction: float | None = None,
+        overlap_fraction: float | list[float] | None = None,
         max_iters: int = 1000,
     ):
         if center_align:
@@ -136,31 +146,47 @@
                 overlap_fraction=overlap_fraction, pis=pis, max_iters=max_iters
             )
 
-    def find_pis(self, overlap_fraction: float, max_iters: int = 1000):
+    def find_pis(self, overlap_fraction: float | list[float], max_iters: int = 1000):
+        # If multiple overlap_fraction values are specified
+        # ensure that they are |slices| - 1 in length
+        try:
+            iter(overlap_fraction)
+        except TypeError:
+            overlap_fraction = [overlap_fraction] * (len(self) - 1)
+        assert (
+            len(overlap_fraction) == len(self) - 1
+        ), "Either specify a single overlap_fraction or one for each pair of slices"
+
         pis = []
         for i in range(len(self) - 1):
             logger.info(f"Finding Pi for slices {i} and {i+1}")
-            pis.append(
-                pairwise_align(
-                    self.slices[i].adata,
-                    self.slices[i + 1].adata,
-                    overlap_fraction=overlap_fraction,
-                    numItermax=max_iters,
-                    maxIter=max_iters,
-                )
-            )
+            pi, _ = pairwise_align(
+                self.slices[i].adata,
+                self.slices[i + 1].adata,
+                overlap_fraction=overlap_fraction[i],
+                numItermax=max_iters,
+                maxIter=max_iters,
+            )
+            pis.append(pi)
         return pis
 
     def pairwise_align(
         self,
-        overlap_fraction: float,
+        overlap_fraction: float | list[float],
         pis: list[np.ndarray] | None = None,
         max_iters: int = 1000,
     ):
         if pis is None:
             pis = self.find_pis(overlap_fraction=overlap_fraction, max_iters=max_iters)
-        new_slices = stack_slices_pairwise(self.slices_adata, pis)
-        return AlignmentDataset(slices=[Slice(adata=s) for s in new_slices])
+        new_slices, rotation_angles, translations = stack_slices_pairwise(
+            self.slices_adata, pis
+        )
+        aligned_dataset = AlignmentDataset(
+            slices=[Slice(adata=s) for s in new_slices],
+            name=self.name + "_pairwise_aligned",
+        )
+
+        return aligned_dataset, rotation_angles, translations
 
     def find_center_slice(
         self,
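find_pis and pairwise_align now take either one overlap fraction or one per adjacent slice pair, and pairwise_align returns rotation angles and translations alongside the aligned dataset. A sketch, assuming `dataset` is a three-slice AlignmentDataset:

# Sketch: per-pair overlap fractions (len(dataset) - 1 of them) and the
# new three-value return introduced by this commit.
aligned, rotation_angles, translations = dataset.pairwise_align(
    overlap_fraction=[0.7, 0.8],
)
print(str(aligned))  # the dataset name with a "_pairwise_aligned" suffix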
@@ -210,16 +236,25 @@ def center_align(
         center_slice, pis = self.find_center_slice(initial_slice=initial_slice)
 
         logger.info("Stacking slices around center slice")
-        _, new_slices = stack_slices_center(
-            center_slice=center_slice.adata, slices=self.slices_adata, pis=pis
+        new_center, new_slices, rotation_angles, translations = stack_slices_center(
+            center_slice=center_slice.adata,
+            slices=self.slices_adata,
+            pis=pis,
         )
-        return AlignmentDataset(slices=[Slice(adata=s) for s in new_slices])
+        aligned_dataset = AlignmentDataset(
+            slices=[Slice(adata=s) for s in new_slices],
+            name=self.name + "_center_aligned",
+        )
 
-    def all_points(self) -> np.ndarray:
+        return aligned_dataset, rotation_angles, translations
+
+    def all_points(self, translation: np.ndarray | None = None) -> np.ndarray:
         layers = []
         for i, slice in enumerate(self.slices):
             adata = slice.adata
             points = adata.obsm["spatial"]
+            if translation is not None:
+                points = points + translation
             layer_data = np.pad(
                 points, pad_width=((0, 0), (1, 0)), mode="constant", constant_values=i
             )
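center_align now returns rotation angles and translations as well, and all_points can shift each slice's spatial coordinates by a translation before stacking them with a leading slice-index column. A sketch, assuming `dataset` is an AlignmentDataset (the translation vector is invented):

# Sketch of the updated center_align / all_points interplay.
import numpy as np

aligned, rotation_angles, translations = dataset.center_align()
points = aligned.all_points(translation=np.array([10.0, 0.0]))
# np.pad above prepends the slice index, so each row reads (i, x, y).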
src/paste3/napari/__init__.py (7 additions, 1 deletion)

@@ -1,6 +1,12 @@
 __version__ = "0.0.1"
 
 from ._reader import napari_get_reader
+from ._sample_data import make_sample_data
 from ._widget import CenterAlignContainer, PairwiseAlignContainer
 
-__all__ = ("napari_get_reader", "CenterAlignContainer", "PairwiseAlignContainer")
+__all__ = (
+    "make_sample_data",
+    "napari_get_reader",
+    "CenterAlignContainer",
+    "PairwiseAlignContainer",
+)
src/paste3/napari/_commands.py (0 additions, 5 deletions)

This file was deleted.
src/paste3/napari/_reader.py (9 additions, 52 deletions)

@@ -1,14 +1,7 @@
-"""
-This module is an example of a barebones numpy reader plugin for napari.
-
-It implements the Reader specification, but your plugin may choose to
-implement multiple readers or even other plugin contributions. see:
-https://napari.org/stable/plugins/guides.html?#readers
-"""
-
-import seaborn as sns
+from pathlib import Path
 
-from paste3.experimental import AlignmentDataset, Slice
+from paste3.dataset import AlignmentDataset, Slice
+from paste3.napari._widget import init_widget
 
 
 def napari_get_reader(path):
@@ -32,7 +25,7 @@ def napari_get_reader(path):
         path = path[0]
 
     # if we know we cannot read the file, we immediately return None.
-    if not path.endswith(".h5ad"):
+    if not str(path).endswith(".h5ad"):
         return None
 
     # otherwise we return the *function* that can read ``path``.
@@ -62,48 +55,12 @@ def reader_function(path):
         default to layer_type=="image" if not provided
     """
     # handle both a string and a list of strings
-    paths = [path] if isinstance(path, str) else path
+    paths = [path] if isinstance(path, str | Path) else path
     slices = [Slice(filepath) for filepath in paths]
    dataset = AlignmentDataset(slices=slices)
 
-    face_color_cycle = sns.color_palette("Paired", 20)
-
-    layer_data = []
-    all_clusters = []
-    for slice in dataset.slices:
-        points = slice.adata.obsm["spatial"]
-        clusters = slice.get_obs_values("original_clusters")
-        all_clusters.extend(clusters)
-
-        layer_data.append(
-            (
-                points,
-                {
-                    "features": {"cluster": clusters},
-                    "face_color": "cluster",
-                    "face_color_cycle": face_color_cycle,
-                    "size": 1,
-                    "metadata": {"slice": slice},
-                    "name": f"{slice}",
-                },
-                "points",
-            )
-        )
-
-    layer_data.append(
-        (
-            dataset.all_points(),
-            {
-                "features": {"cluster": all_clusters},
-                "face_color": "cluster",
-                "face_color_cycle": face_color_cycle,
-                "ndim": 3,
-                "size": 1,
-                "scale": [3, 1, 1],
-                "name": "paste3_volume",
-            },
-            "points",
-        )
-    )
+    init_widget(alignment_dataset=dataset)
 
-    return layer_data
+    # We let the initialized widget handle the rest of the logic
+    # and add the layers to the viewer
+    return [(None,)]
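The reader no longer assembles napari layers itself: it builds an AlignmentDataset, hands it to init_widget, and returns the [(None,)] sentinel so the widget adds layers instead. A sketch of the reader contract (the path is invented):

# Sketch of napari's reader hook resolving an .h5ad path; the path is invented.
from paste3.napari import napari_get_reader

reader = napari_get_reader("data/slice_0.h5ad")
if reader is not None:  # None would mean this plugin cannot read the path
    layers = reader("data/slice_0.h5ad")  # returns [(None,)]; widget takes over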