diff --git a/ultrack/core/interactive.py b/ultrack/core/interactive.py index 18a8ad8..cabd657 100644 --- a/ultrack/core/interactive.py +++ b/ultrack/core/interactive.py @@ -264,7 +264,7 @@ def add_new_node( node = Node.from_mask( time=time, mask=mask, - bbox=bbox, + bbox=np.asarray(bbox), ) if node.area == 0: raise ValueError("Node area is zero. Something went wrong.") diff --git a/ultrack/core/linking/__init__.py b/ultrack/core/linking/__init__.py index e69de29..6623968 100644 --- a/ultrack/core/linking/__init__.py +++ b/ultrack/core/linking/__init__.py @@ -0,0 +1 @@ +from ultrack.core.linking.processing import add_links diff --git a/ultrack/core/linking/processing.py b/ultrack/core/linking/processing.py index f02a944..1941d8e 100644 --- a/ultrack/core/linking/processing.py +++ b/ultrack/core/linking/processing.py @@ -48,6 +48,61 @@ def _compute_features( ] +def color_filtering_mask( + time: int, + current_nodes: List[Node], + next_nodes: List[Node], + images: Sequence[ArrayLike], + neighbors: ArrayLike, + z_score_threshold: float, +) -> ArrayLike: + """ + Filtering by color z-score. + + Parameters + ---------- + time : int + Current time. + current_nodes : List[Node] + List of source nodes. + next_nodes : List[Node] + List of target nodes. + images : Sequence[ArrayLike] + Sequence of images to extract color features for filtering. + neighbors : ArrayLike + Neighbors indices (current/source) for each target (next) node. + z_score_threshold : float + Z-score threshold for color filtering. + + Returns + ------- + ArrayLike + Boolean mask of neighboring nodes within color z-score threshold. + + """ + LOG.info(f"computing filtering by color z-score from t={time}") + (current_features,) = _compute_features( + time, current_nodes, images, [Node.intensity_mean] + ) + # inserting dummy value for missing neighbors + current_features = np.append( + current_features, + np.zeros((1, current_features.shape[1])), + axis=0, + ) + next_features, next_features_std = _compute_features( + time + 1, next_nodes, images, [Node.intensity_mean, Node.intensity_std] + ) + LOG.info( + f"Features Std. Dev. range {next_features_std.min()} {next_features_std.max()}" + ) + next_features_std[next_features_std <= 1e-6] = 1.0 + difference = next_features[:, None, ...] - current_features[neighbors] + difference /= next_features_std[:, None, ...] + filtered_by_color = np.abs(difference).max(axis=-1) <= z_score_threshold + return filtered_by_color + + @curry def _process( time: int, @@ -91,71 +146,86 @@ def _process( next_nodes = [row[0] for row in query] next_shift = np.asarray([row[1:] for row in query]) - current_pos = np.asarray([n.centroid for n in current_nodes]) - next_pos = np.asarray([n.centroid for n in next_nodes], dtype=np.float32) + compute_spatial_neighbors( + time, + config, + current_nodes, + next_nodes, + next_shift, + scale=scale, + table_name=LinkDB.__tablename__, + db_path=db_path, + images=images, + write_lock=write_lock, + ) - n_dim = next_pos.shape[1] - next_shift = next_shift[:, -n_dim:] # matching positions dimensions - next_pos += next_shift + +def compute_spatial_neighbors( + time: int, + config: LinkingConfig, + source_nodes: List[Node], + target_nodes: List[Node], + target_shift: ArrayLike, + scale: Optional[Sequence[float]], + table_name: str, + db_path: str, + images: Sequence[ArrayLike], + write_lock: Optional[fasteners.InterProcessLock] = None, +) -> pd.DataFrame: + + source_pos = np.asarray([n.centroid for n in source_nodes]) + target_pos = np.asarray([n.centroid for n in target_nodes], dtype=np.float32) + + n_dim = target_pos.shape[1] + target_shift = target_shift[:, -n_dim:] # matching positions dimensions + target_pos += target_shift if scale is not None: min_n_dim = min(n_dim, len(scale)) scale = scale[-min_n_dim:] - current_pos = current_pos[..., -min_n_dim:] * scale - next_pos = next_pos[..., -min_n_dim:] * scale + source_pos = source_pos[..., -min_n_dim:] * scale + target_pos = target_pos[..., -min_n_dim:] * scale # finds neighbors nodes within the radius # and connect the pairs with highest edge weight - current_kdtree = KDTree(current_pos) + current_kdtree = KDTree(source_pos) distances, neighbors = current_kdtree.query( - next_pos, + target_pos, # twice as expected because we select the nearest with highest edge weight k=2 * config.max_neighbors, distance_upper_bound=config.max_distance, ) if len(images) > 0: - LOG.info(f"computing filtering by color z-score from t={time}") - (current_features,) = _compute_features( - time, current_nodes, images, [Node.intensity_mean] - ) - # inserting dummy value for missing neighbors - current_features = np.append( - current_features, - np.zeros((1, current_features.shape[1])), - axis=0, - ) - next_features, next_features_std = _compute_features( - time + 1, next_nodes, images, [Node.intensity_mean, Node.intensity_std] + filtered_by_color = color_filtering_mask( + time, + source_nodes, + target_nodes, + images, + neighbors, + config.z_score_threshold, ) - LOG.info( - f"Features Std. Dev. range {next_features_std.min()} {next_features_std.max()}" - ) - next_features_std[next_features_std <= 1e-6] = 1.0 - difference = next_features[:, None, ...] - current_features[neighbors] - difference /= next_features_std[:, None, ...] - filtered_by_color = np.abs(difference).max(axis=-1) <= config.z_score_threshold else: filtered_by_color = np.ones_like(neighbors, dtype=bool) - int_next_shift = np.round(next_shift).astype(int) + int_next_shift = np.round(target_shift).astype(int) # NOTE: moving bbox with shift, MUST be after `feature computation` - for node, shift in zip(next_nodes, int_next_shift): + for node, shift in zip(target_nodes, int_next_shift): node.bbox[:n_dim] += shift node.bbox[-n_dim:] += shift distance_w = config.distance_weight links = [] - for i, node in enumerate(next_nodes): + for i, node in enumerate(target_nodes): valid = (~np.isinf(distances[i])) & filtered_by_color[i] valid_neighbors = neighbors[i, valid] neigh_distances = distances[i, valid] neighborhood = [] for neigh_idx, neigh_dist in zip(valid_neighbors, neigh_distances): - neigh = current_nodes[neigh_idx] + neigh = source_nodes[neigh_idx] edge_weight = node.IoU(neigh) - distance_w * neigh_dist # using dist as a tie-breaker neighborhood.append( @@ -176,13 +246,14 @@ def _process( with write_lock if write_lock is not None else nullcontext(): LOG.info(f"Pushing links from time {time} to {db_path}") + connect_args = {"timeout": 45} if write_lock is not None else {} engine = sqla.create_engine( db_path, hide_parameters=True, connect_args=connect_args ) with engine.begin() as conn: - df.to_sql( - name=LinkDB.__tablename__, con=conn, if_exists="append", index=False - ) + df.to_sql(name=table_name, con=conn, if_exists="append", index=False) + + return df def link( @@ -230,3 +301,40 @@ def link( multiprocessing_apply( process, time_points, config.linking_config.n_workers, desc="Linking nodes." ) + + +def add_links( + config: MainConfig, + sources: ArrayLike, + targets: ArrayLike, + weights: ArrayLike, +) -> None: + """ + Adds user-defined links to the database. + + Parameters + ---------- + config : MainConfig + Configuration parameters. + sources : ArrayLike + Sources (t) node id. + targets : ArrayLike + Targets (t + 1) node id. + weights : ArrayLike + Link weights, the higher the weight the more likely the link. + """ + df = pd.DataFrame( + { + "source_id": np.asarray(sources, dtype=int), + "target_id": np.asarray(targets, dtype=int), + "weight": weights, + } + ) + + engine = sqla.create_engine( + config.data_config.database_path, + hide_parameters=True, + ) + + with engine.begin() as conn: + df.to_sql(name=LinkDB.__tablename__, con=conn, if_exists="append", index=False) diff --git a/ultrack/core/tracker.py b/ultrack/core/tracker.py index c7cdf77..5a3a3df 100644 --- a/ultrack/core/tracker.py +++ b/ultrack/core/tracker.py @@ -16,7 +16,7 @@ tracks_layer_to_trackmate, tracks_to_zarr, ) -from ultrack.core.linking.processing import link +from ultrack.core.linking.processing import add_links, link from ultrack.core.main import track from ultrack.core.segmentation.processing import get_nodes_features, segment from ultrack.core.solve.processing import solve @@ -165,6 +165,12 @@ def get_nodes_features(self, **kwargs) -> pd.DataFrame: nodes_features_df = get_nodes_features(self.config, **kwargs) return nodes_features_df + @functools.wraps(add_links) + def add_links(self, **kwargs) -> None: + self._assert_segmented("add_links") + add_links(config=self.config, **kwargs) + self.status |= TrackerStatus.LINKED + @functools.wraps(add_nodes_prob) def add_nodes_prob( self,