Skip to content

Commit

Permalink
Ran flake8
Browse files Browse the repository at this point in the history
  • Loading branch information
vianamp committed Apr 6, 2024
1 parent 5b45578 commit 3bb46e6
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 28 deletions.
36 changes: 19 additions & 17 deletions aicscytoparam/alignment/generic_2d_shape.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import os
import numpy as np
import matplotlib.pyplot as plt
from scipy import spatial as spspatial
from skimage import transform as sktrans
from skimage import measure as skmeasure


class Generic2DShape():
"""
Generic class for 2D shapes
Expand All @@ -21,17 +21,18 @@ def _compute_contour(self):
self.cx = cx - cx.mean()
self.cy = cy - cy.mean()
return

def show(self, ax=None):
if ax is None:
fig, ax = plt.subplots()
ax.plot(self.cx, self.cy)
if ax is None:
plt.show()

def find_angle_that_minimizes_countour_distance(self, cx, cy):
"""
Find the angle that minimizes the distance between the shape and the contour (cx, cy)
Find the angle that minimizes the distance between the
shape and the contour (cx, cy)
Parameters
----------
cx: np.ndarray
Expand All @@ -55,7 +56,7 @@ def find_angle_that_minimizes_countour_distance(self, cx, cy):
dist_min = D.min(axis=0).mean() + D.min(axis=1).mean()
dists.append(dist_min)
return np.argmin(dists), np.min(dists)

@staticmethod
def rotate_contour(cx, cy, theta):
"""
Expand All @@ -75,8 +76,8 @@ def rotate_contour(cx, cy, theta):
cyrot: np.ndarray
y coordinates of the rotated contour
"""
cxrot = cx*np.cos(np.deg2rad(theta)) - cy*np.sin(np.deg2rad(theta))
cyrot = cx*np.sin(np.deg2rad(theta)) + cy*np.cos(np.deg2rad(theta))
cxrot = cx * np.cos(np.deg2rad(theta)) - cy * np.sin(np.deg2rad(theta))
cyrot = cx * np.sin(np.deg2rad(theta)) + cy * np.cos(np.deg2rad(theta))
return cxrot, cyrot

@staticmethod
Expand All @@ -99,43 +100,44 @@ def get_contour_from_3d_image(image, pad=5, center=True):
y coordinates of the contour
"""
mip = image.max(axis=0)
y, x = np.where(mip>0)
mip = np.pad(mip, ((pad,pad), (pad,pad)))
cont = skmeasure.find_contours(mip>0)[0]
y, x = np.where(mip > 0)
mip = np.pad(mip, ((pad, pad), (pad, pad)))
cont = skmeasure.find_contours(mip > 0)[0]
cx, cy = cont[:, 1], cont[:, 0]
if center:
cx = cx - cx.mean()
cy = cy - cy.mean()
return (cx, cy)


class ElongatedHexagonalShape(Generic2DShape):
"""
Elongated hexagonal shape
"""
def __init__(self, base, elongation, pad=5):
self._pad = pad
self._base = base
self._height = int(self._base/np.sqrt(2))
self._height = int(self._base / np.sqrt(2))
self._elongation_factor = elongation
self._create()
self._compute_contour()

def _create(self):
"""
Create the elongated hexagonal shape
"""
pad = self._pad
triangle = np.tril(np.ones((self._height, self._base)))
triangle = sktrans.rotate(triangle, angle=-15, center=(0,0), order=0)
triangle = sktrans.rotate(triangle, angle=-15, center=(0, 0), order=0)
rectangle = np.ones((self._height, self._base))
for _ in range(self._elongation_factor):
rectangle = np.concatenate([rectangle, rectangle[:,:1]], axis=1)
upper_half = np.concatenate([triangle[:,::-1],rectangle,triangle], axis=1)
rectangle = np.concatenate([rectangle, rectangle[:, :1]], axis=1)
upper_half = np.concatenate([triangle[:, ::-1], rectangle, triangle], axis=1)
hexagon = np.concatenate([upper_half, upper_half[::-1]], axis=0)
hexagon = np.pad(hexagon, ((pad,pad), (pad,pad)))
hexagon = np.pad(hexagon, ((pad, pad), (pad, pad)))
self._polygon = hexagon
return

@staticmethod
def get_default_parameters_as_dict(elongation=8, base_ini=24, base_end=64):
params = []
Expand Down
23 changes: 12 additions & 11 deletions aicscytoparam/alignment/shape_library_2d.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import numpy as np
import matplotlib.pyplot as plt

from aicscytoparam.alignment.generic_2d_shape import Generic2DShape


class ShapeLibrary2D():
"""
Define a library of 2D shapes
"""
def __init__(self):
pass

def set_base_shape(self, polygon):
"""
Set the base shape for the library
Expand All @@ -19,7 +19,7 @@ def set_base_shape(self, polygon):
base shape for the library
"""
self._polygon = polygon

def set_parameters_range(self, params_dict):
"""
Set the parameters range for the library
Expand Down Expand Up @@ -57,7 +57,7 @@ def find_best_match(self, cx, cy):
idx = np.argmin(dists)
return idx, self._params[idx], angles[idx]

def display(self, xlim=[-150, 150], ylim=[-50,50], contours_to_match=None):
def display(self, xlim=[-150, 150], ylim=[-50, 50], contours_to_match=None):
"""
Display the shapes in the library
Parameters
Expand All @@ -70,18 +70,19 @@ def display(self, xlim=[-150, 150], ylim=[-50,50], contours_to_match=None):
list of tuples with the contours to match
"""
n = int(np.sqrt(len(self._params)))
fig, axs = plt.subplots(n, n, figsize=(3*n,1*n))
fig, axs = plt.subplots(n, n, figsize=(3 * n, 1 * n))
for pid, p in enumerate(self._params):
j, i = pid // n, pid % n
poly = self._polygon(**p)
axs[pid//n, pid%n].plot(poly.cx, poly.cy, lw=7, color="k", alpha=0.2)
axs[pid//n, pid%n].axis("off")
axs[pid//n, pid%n].set_aspect("equal")
axs[pid//n, pid%n].set_xlim(xlim[0], xlim[1])
axs[pid//n, pid%n].set_ylim(ylim[0], ylim[1])
axs[j, i].plot(poly.cx, poly.cy, lw=7, color="k", alpha=0.2)
axs[j, i].axis("off")
axs[j, i].set_aspect("equal")
axs[j, i].set_xlim(xlim[0], xlim[1])
axs[j, i].set_ylim(ylim[0], ylim[1])
if contours_to_match is not None:
for (cx, cy) in contours_to_match:
pid, p, angle = self.find_best_match(cx, cy)
cxrot, cyrot = Generic2DShape.rotate_contour(cx, cy, angle)
axs[pid//n, pid%n].plot(cxrot, cyrot, color="magenta")
axs[j, i].plot(cxrot, cyrot, color="magenta")
plt.tight_layout()
plt.show()

0 comments on commit 3bb46e6

Please sign in to comment.