refactored link processing code and added add_links function
JoOkuma committed Oct 18, 2024
1 parent 23e41db commit b2c2718
Showing 4 changed files with 164 additions and 45 deletions.
2 changes: 1 addition & 1 deletion ultrack/core/interactive.py
@@ -264,7 +264,7 @@ def add_new_node(
    node = Node.from_mask(
        time=time,
        mask=mask,
-        bbox=bbox,
+        bbox=np.asarray(bbox),
    )
    if node.area == 0:
        raise ValueError("Node area is zero. Something went wrong.")
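The commit does not state the motivation for the cast, but a plausible one is visible in the linking code below, which shifts the bounding box in place with `node.bbox[:n_dim] += shift`; that vectorized update needs an ndarray, not an immutable sequence. A tiny illustration:

```python
import numpy as np

shift = np.array([1, 2])

bbox = np.asarray((0, 0, 10, 10))
bbox[:2] += shift  # in-place vectorized update works on an ndarray
print(bbox)        # [ 1  2 10 10]

bbox_tuple = (0, 0, 10, 10)
try:
    bbox_tuple[:2] += shift
except TypeError as err:
    print(err)  # tuples are immutable, so the slice assignment fails
```
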
1 change: 1 addition & 0 deletions ultrack/core/linking/__init__.py
@@ -0,0 +1 @@
+from ultrack.core.linking.processing import add_links
178 changes: 143 additions & 35 deletions ultrack/core/linking/processing.py
@@ -48,6 +48,61 @@ def _compute_features(
    ]


+def color_filtering_mask(
+    time: int,
+    current_nodes: List[Node],
+    next_nodes: List[Node],
+    images: Sequence[ArrayLike],
+    neighbors: ArrayLike,
+    z_score_threshold: float,
+) -> ArrayLike:
+    """
+    Filtering by color z-score.
+
+    Parameters
+    ----------
+    time : int
+        Current time.
+    current_nodes : List[Node]
+        List of source nodes.
+    next_nodes : List[Node]
+        List of target nodes.
+    images : Sequence[ArrayLike]
+        Sequence of images to extract color features for filtering.
+    neighbors : ArrayLike
+        Neighbor indices (current/source) for each target (next) node.
+    z_score_threshold : float
+        Z-score threshold for color filtering.
+
+    Returns
+    -------
+    ArrayLike
+        Boolean mask of neighboring nodes within the color z-score threshold.
+    """
+    LOG.info(f"computing filtering by color z-score from t={time}")
+    (current_features,) = _compute_features(
+        time, current_nodes, images, [Node.intensity_mean]
+    )
+    # inserting dummy value for missing neighbors
+    current_features = np.append(
+        current_features,
+        np.zeros((1, current_features.shape[1])),
+        axis=0,
+    )
+    next_features, next_features_std = _compute_features(
+        time + 1, next_nodes, images, [Node.intensity_mean, Node.intensity_std]
+    )
+    LOG.info(
+        f"Features Std. Dev. range {next_features_std.min()} {next_features_std.max()}"
+    )
+    next_features_std[next_features_std <= 1e-6] = 1.0
+    difference = next_features[:, None, ...] - current_features[neighbors]
+    difference /= next_features_std[:, None, ...]
+    filtered_by_color = np.abs(difference).max(axis=-1) <= z_score_threshold
+    return filtered_by_color
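
The extracted helper isolates the color gate used during linking: each candidate pair's source mean intensity is compared to the target's mean, scaled by the target's per-channel standard deviation, and the pair survives only if its worst channel stays inside the threshold. Below is a minimal NumPy sketch of the same arithmetic on made-up feature arrays; the appended zero row plays the same role as the dummy value above, absorbing the out-of-range index that `KDTree.query` reports for missing neighbors.

```python
import numpy as np

# toy features: 3 source nodes x 2 channels, 2 target nodes x 2 channels
current_features = np.array([[10.0, 5.0], [12.0, 6.0], [30.0, 9.0]])
next_features = np.array([[11.0, 5.5], [29.0, 8.0]])
next_features_std = np.array([[2.0, 1.0], [2.0, 1.0]])

# dummy row: KDTree.query flags a missing neighbor with index == n_sources (3)
current_features = np.append(
    current_features, np.zeros((1, current_features.shape[1])), axis=0
)

# neighbors[i] holds candidate source indices for target i
neighbors = np.array([[0, 1, 3], [2, 3, 3]])

z_score_threshold = 1.0
difference = next_features[:, None, :] - current_features[neighbors]
difference /= next_features_std[:, None, :]
mask = np.abs(difference).max(axis=-1) <= z_score_threshold
print(mask)  # [[ True  True False] [ True False False]] -- dummy slots rejected
```
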


@curry
def _process(
    time: int,
@@ -91,71 +146,86 @@ def _process(
    next_nodes = [row[0] for row in query]
    next_shift = np.asarray([row[1:] for row in query])

-    current_pos = np.asarray([n.centroid for n in current_nodes])
-    next_pos = np.asarray([n.centroid for n in next_nodes], dtype=np.float32)
+    compute_spatial_neighbors(
+        time,
+        config,
+        current_nodes,
+        next_nodes,
+        next_shift,
+        scale=scale,
+        table_name=LinkDB.__tablename__,
+        db_path=db_path,
+        images=images,
+        write_lock=write_lock,
+    )

-    n_dim = next_pos.shape[1]
-    next_shift = next_shift[:, -n_dim:]  # matching positions dimensions
-    next_pos += next_shift

+def compute_spatial_neighbors(
+    time: int,
+    config: LinkingConfig,
+    source_nodes: List[Node],
+    target_nodes: List[Node],
+    target_shift: ArrayLike,
+    scale: Optional[Sequence[float]],
+    table_name: str,
+    db_path: str,
+    images: Sequence[ArrayLike],
+    write_lock: Optional[fasteners.InterProcessLock] = None,
+) -> pd.DataFrame:
+
+    source_pos = np.asarray([n.centroid for n in source_nodes])
+    target_pos = np.asarray([n.centroid for n in target_nodes], dtype=np.float32)
+
+    n_dim = target_pos.shape[1]
+    target_shift = target_shift[:, -n_dim:]  # matching positions dimensions
+    target_pos += target_shift

    if scale is not None:
        min_n_dim = min(n_dim, len(scale))
        scale = scale[-min_n_dim:]
-        current_pos = current_pos[..., -min_n_dim:] * scale
-        next_pos = next_pos[..., -min_n_dim:] * scale
+        source_pos = source_pos[..., -min_n_dim:] * scale
+        target_pos = target_pos[..., -min_n_dim:] * scale

    # finds neighboring nodes within the radius
    # and connects the pairs with the highest edge weight
-    current_kdtree = KDTree(current_pos)
+    current_kdtree = KDTree(source_pos)

    distances, neighbors = current_kdtree.query(
-        next_pos,
+        target_pos,
        # twice as many as needed because we select the nearest with the highest edge weight
        k=2 * config.max_neighbors,
        distance_upper_bound=config.max_distance,
    )

    if len(images) > 0:
-        LOG.info(f"computing filtering by color z-score from t={time}")
-        (current_features,) = _compute_features(
-            time, current_nodes, images, [Node.intensity_mean]
-        )
-        # inserting dummy value for missing neighbors
-        current_features = np.append(
-            current_features,
-            np.zeros((1, current_features.shape[1])),
-            axis=0,
-        )
-        next_features, next_features_std = _compute_features(
-            time + 1, next_nodes, images, [Node.intensity_mean, Node.intensity_std]
+        filtered_by_color = color_filtering_mask(
+            time,
+            source_nodes,
+            target_nodes,
+            images,
+            neighbors,
+            config.z_score_threshold,
        )
-        LOG.info(
-            f"Features Std. Dev. range {next_features_std.min()} {next_features_std.max()}"
-        )
-        next_features_std[next_features_std <= 1e-6] = 1.0
-        difference = next_features[:, None, ...] - current_features[neighbors]
-        difference /= next_features_std[:, None, ...]
-        filtered_by_color = np.abs(difference).max(axis=-1) <= config.z_score_threshold
    else:
        filtered_by_color = np.ones_like(neighbors, dtype=bool)

-    int_next_shift = np.round(next_shift).astype(int)
+    int_next_shift = np.round(target_shift).astype(int)
    # NOTE: moving bbox with shift, MUST be after `feature computation`
-    for node, shift in zip(next_nodes, int_next_shift):
+    for node, shift in zip(target_nodes, int_next_shift):
        node.bbox[:n_dim] += shift
        node.bbox[-n_dim:] += shift

    distance_w = config.distance_weight
    links = []

-    for i, node in enumerate(next_nodes):
+    for i, node in enumerate(target_nodes):
        valid = (~np.isinf(distances[i])) & filtered_by_color[i]
        valid_neighbors = neighbors[i, valid]
        neigh_distances = distances[i, valid]

        neighborhood = []
        for neigh_idx, neigh_dist in zip(valid_neighbors, neigh_distances):
-            neigh = current_nodes[neigh_idx]
+            neigh = source_nodes[neigh_idx]
            edge_weight = node.IoU(neigh) - distance_w * neigh_dist
            # using dist as a tie-breaker
            neighborhood.append(
@@ -176,13 +246,14 @@ def _process(

    with write_lock if write_lock is not None else nullcontext():
        LOG.info(f"Pushing links from time {time} to {db_path}")
        connect_args = {"timeout": 45} if write_lock is not None else {}
        engine = sqla.create_engine(
            db_path, hide_parameters=True, connect_args=connect_args
        )
        with engine.begin() as conn:
-            df.to_sql(
-                name=LinkDB.__tablename__, con=conn, if_exists="append", index=False
-            )
+            df.to_sql(name=table_name, con=conn, if_exists="append", index=False)

+    return df
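
Two `scipy.spatial.KDTree.query` conventions drive the logic above: asking for `k=2 * config.max_neighbors` over-fetches candidates so they can be re-ranked by edge weight instead of raw distance, and any slot beyond `distance_upper_bound` comes back with an infinite distance and an index equal to the number of source points, which is exactly the slot the dummy feature row absorbs. A small self-contained illustration with toy coordinates:

```python
import numpy as np
from scipy.spatial import KDTree

source_pos = np.array([[0.0, 0.0], [1.0, 0.0], [50.0, 50.0]])
target_pos = np.array([[0.5, 0.0], [49.0, 50.0]])

tree = KDTree(source_pos)
distances, neighbors = tree.query(target_pos, k=2, distance_upper_bound=5.0)

print(distances)  # [[0.5 0.5] [1. inf]] -- unmatched slots are infinite
print(neighbors)  # [[0 1] [2 3]]        -- 3 == len(source_pos), i.e. "missing"
valid = ~np.isinf(distances)  # mirrors the `valid` mask built in `_process`
```
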


def link(
@@ -230,3 +301,40 @@ def link(
    multiprocessing_apply(
        process, time_points, config.linking_config.n_workers, desc="Linking nodes."
    )
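
`_process` is decorated with `@curry`, so the caller can bind every argument except the time point and hand the resulting one-argument callable to `multiprocessing_apply`, which maps it over time points in worker processes. A rough sketch of the pattern, assuming `toolz.curry` and a stand-in payload:

```python
from toolz import curry

@curry
def process(time: int, scale: float) -> float:
    # stand-in for the real per-time-point linking work
    return time * scale

# bind everything except `time`; the result is a one-argument callable
bound = process(scale=2.0)

# map over the time points (ultrack fans the same mapping out across workers)
results = list(map(bound, range(4)))
print(results)  # [0.0, 2.0, 4.0, 6.0]
```
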


+def add_links(
+    config: MainConfig,
+    sources: ArrayLike,
+    targets: ArrayLike,
+    weights: ArrayLike,
+) -> None:
+    """
+    Adds user-defined links to the database.
+
+    Parameters
+    ----------
+    config : MainConfig
+        Configuration parameters.
+    sources : ArrayLike
+        Source (t) node ids.
+    targets : ArrayLike
+        Target (t + 1) node ids.
+    weights : ArrayLike
+        Link weights; the higher the weight, the more likely the link.
+    """
+    df = pd.DataFrame(
+        {
+            "source_id": np.asarray(sources, dtype=int),
+            "target_id": np.asarray(targets, dtype=int),
+            "weight": weights,
+        }
+    )
+
+    engine = sqla.create_engine(
+        config.data_config.database_path,
+        hide_parameters=True,
+    )
+
+    with engine.begin() as conn:
+        df.to_sql(name=LinkDB.__tablename__, con=conn, if_exists="append", index=False)
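
With the new `ultrack.core.linking` re-export, user-curated edges can be appended to the `LinkDB` table without running the spatial search at all. A hedged usage sketch; the node ids and weights here are hypothetical and must refer to existing nodes at consecutive time points in an already-initialized database:

```python
import numpy as np
from ultrack.config import MainConfig
from ultrack.core.linking import add_links

config = MainConfig()  # assumes the project database was created by `segment`

# hypothetical ids: sources live at time t, targets at time t + 1
sources = np.array([101, 102])
targets = np.array([201, 202])
weights = np.array([0.9, 0.75])  # higher weight => more likely to be selected

add_links(config=config, sources=sources, targets=targets, weights=weights)
```
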
28 changes: 19 additions & 9 deletions ultrack/core/tracker.py
@@ -16,7 +16,7 @@
    tracks_layer_to_trackmate,
    tracks_to_zarr,
)
-from ultrack.core.linking.processing import link
+from ultrack.core.linking.processing import add_links, link
from ultrack.core.main import track
from ultrack.core.segmentation.processing import segment
from ultrack.core.solve.processing import solve
@@ -71,32 +71,31 @@ def __init__(self, config: MainConfig) -> None:
    @rename_argument("edges", "contours")
    def segment(self, foreground: ArrayLike, contours: ArrayLike, **kwargs) -> None:
        segment(foreground=foreground, contours=contours, config=self.config, **kwargs)
-        self.status = TrackerStatus.SEGMENTED
+        self.status &= ~TrackerStatus.NOT_COMPUTED
+        self.status |= TrackerStatus.SEGMENTED

    @functools.wraps(add_flow)
    def add_flow(self, vector_field: ArrayLike) -> None:
-        if TrackerStatus.SEGMENTED not in self.status:
-            raise ValueError("You must call `segment` before calling `add_flow`.")
+        self._assert_segmented("add_flow")
        add_flow(config=self.config, vector_field=vector_field)

    @functools.wraps(link)
    def link(self, *args, **kwargs) -> None:
-        if TrackerStatus.SEGMENTED not in self.status:
-            raise ValueError("You must call `segment` before calling `link`.")
+        self._assert_segmented("link")
        link(config=self.config, *args, **kwargs)
-        self.status = TrackerStatus.LINKED
+        self.status |= TrackerStatus.LINKED

    @functools.wraps(solve)
    def solve(self, *args, **kwargs) -> None:
        if TrackerStatus.LINKED not in self.status:
            raise ValueError("You must call `segment` & `link` before calling `solve`.")
        solve(config=self.config, *args, **kwargs)
-        self.status = TrackerStatus.SOLVED
+        self.status |= TrackerStatus.SOLVED

    @functools.wraps(track)
    def track(self, *args, **kwargs) -> None:
        track(config=self.config, *args, **kwargs)
-        self.status = TrackerStatus.SOLVED
+        self.status |= TrackerStatus.SOLVED
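
The status updates above switch from overwriting the tracker state to composing it: `&= ~NOT_COMPUTED` clears the initial flag, while `|=` accumulates completed stages, so `SEGMENTED` stays set after `link` runs and the membership checks keep passing. A minimal sketch of those semantics, assuming `TrackerStatus` is an `enum.Flag` along these lines:

```python
from enum import Flag, auto

class TrackerStatus(Flag):
    NOT_COMPUTED = auto()
    SEGMENTED = auto()
    LINKED = auto()
    SOLVED = auto()

status = TrackerStatus.NOT_COMPUTED
status &= ~TrackerStatus.NOT_COMPUTED  # clear the initial flag
status |= TrackerStatus.SEGMENTED      # record each completed stage
status |= TrackerStatus.LINKED

assert TrackerStatus.SEGMENTED in status  # still true after linking
assert TrackerStatus.NOT_COMPUTED not in status
```
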

    def _assert_solved(self) -> None:
        """Raise an error if the tracking is not solved."""
Expand All @@ -106,6 +105,11 @@ def _assert_solved(self) -> None:
"called `segment` &a `link` & `solve` or `track`."
)

+    def _assert_segmented(self, method_name: str) -> None:
+        """Raise an error if segmentation is not done."""
+        if TrackerStatus.SEGMENTED not in self.status:
+            raise ValueError(f"You must call `segment` before calling `{method_name}`.")

    @functools.wraps(tracks_layer_to_networkx)
    def to_networkx(
        self, *, tracks_df: Optional[pd.DataFrame] = None, **kwargs
Expand Down Expand Up @@ -155,6 +159,12 @@ def export_by_extension(self, filename: str, overwrite: bool = False) -> None:
        self._assert_solved()
        export_tracks_by_extension(self.config, filename, overwrite=overwrite)

+    @functools.wraps(add_links)
+    def add_links(self, **kwargs) -> None:
+        self._assert_segmented("add_links")
+        add_links(config=self.config, **kwargs)
+        self.status |= TrackerStatus.LINKED

    @functools.wraps(add_nodes_prob)
    def add_nodes_prob(
        self,
