Skip to content

Commit

Permalink
Merge branch 'public-main' into node-features-v2
Browse files Browse the repository at this point in the history
  • Loading branch information
JoOkuma committed Oct 19, 2024
2 parents b779da5 + 2b73f41 commit 1e8f88e
Show file tree
Hide file tree
Showing 4 changed files with 152 additions and 37 deletions.
2 changes: 1 addition & 1 deletion ultrack/core/interactive.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ def add_new_node(
node = Node.from_mask(
time=time,
mask=mask,
bbox=bbox,
bbox=np.asarray(bbox),
)
if node.area == 0:
raise ValueError("Node area is zero. Something went wrong.")
Expand Down
1 change: 1 addition & 0 deletions ultrack/core/linking/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from ultrack.core.linking.processing import add_links
178 changes: 143 additions & 35 deletions ultrack/core/linking/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,61 @@ def _compute_features(
]


def color_filtering_mask(
time: int,
current_nodes: List[Node],
next_nodes: List[Node],
images: Sequence[ArrayLike],
neighbors: ArrayLike,
z_score_threshold: float,
) -> ArrayLike:
"""
Filtering by color z-score.
Parameters
----------
time : int
Current time.
current_nodes : List[Node]
List of source nodes.
next_nodes : List[Node]
List of target nodes.
images : Sequence[ArrayLike]
Sequence of images to extract color features for filtering.
neighbors : ArrayLike
Neighbors indices (current/source) for each target (next) node.
z_score_threshold : float
Z-score threshold for color filtering.
Returns
-------
ArrayLike
Boolean mask of neighboring nodes within color z-score threshold.
"""
LOG.info(f"computing filtering by color z-score from t={time}")
(current_features,) = _compute_features(
time, current_nodes, images, [Node.intensity_mean]
)
# inserting dummy value for missing neighbors
current_features = np.append(
current_features,
np.zeros((1, current_features.shape[1])),
axis=0,
)
next_features, next_features_std = _compute_features(
time + 1, next_nodes, images, [Node.intensity_mean, Node.intensity_std]
)
LOG.info(
f"Features Std. Dev. range {next_features_std.min()} {next_features_std.max()}"
)
next_features_std[next_features_std <= 1e-6] = 1.0
difference = next_features[:, None, ...] - current_features[neighbors]
difference /= next_features_std[:, None, ...]
filtered_by_color = np.abs(difference).max(axis=-1) <= z_score_threshold
return filtered_by_color


@curry
def _process(
time: int,
Expand Down Expand Up @@ -91,71 +146,86 @@ def _process(
next_nodes = [row[0] for row in query]
next_shift = np.asarray([row[1:] for row in query])

current_pos = np.asarray([n.centroid for n in current_nodes])
next_pos = np.asarray([n.centroid for n in next_nodes], dtype=np.float32)
compute_spatial_neighbors(
time,
config,
current_nodes,
next_nodes,
next_shift,
scale=scale,
table_name=LinkDB.__tablename__,
db_path=db_path,
images=images,
write_lock=write_lock,
)

n_dim = next_pos.shape[1]
next_shift = next_shift[:, -n_dim:] # matching positions dimensions
next_pos += next_shift

def compute_spatial_neighbors(
time: int,
config: LinkingConfig,
source_nodes: List[Node],
target_nodes: List[Node],
target_shift: ArrayLike,
scale: Optional[Sequence[float]],
table_name: str,
db_path: str,
images: Sequence[ArrayLike],
write_lock: Optional[fasteners.InterProcessLock] = None,
) -> pd.DataFrame:

source_pos = np.asarray([n.centroid for n in source_nodes])
target_pos = np.asarray([n.centroid for n in target_nodes], dtype=np.float32)

n_dim = target_pos.shape[1]
target_shift = target_shift[:, -n_dim:] # matching positions dimensions
target_pos += target_shift

if scale is not None:
min_n_dim = min(n_dim, len(scale))
scale = scale[-min_n_dim:]
current_pos = current_pos[..., -min_n_dim:] * scale
next_pos = next_pos[..., -min_n_dim:] * scale
source_pos = source_pos[..., -min_n_dim:] * scale
target_pos = target_pos[..., -min_n_dim:] * scale

# finds neighbors nodes within the radius
# and connect the pairs with highest edge weight
current_kdtree = KDTree(current_pos)
current_kdtree = KDTree(source_pos)

distances, neighbors = current_kdtree.query(
next_pos,
target_pos,
# twice as expected because we select the nearest with highest edge weight
k=2 * config.max_neighbors,
distance_upper_bound=config.max_distance,
)

if len(images) > 0:
LOG.info(f"computing filtering by color z-score from t={time}")
(current_features,) = _compute_features(
time, current_nodes, images, [Node.intensity_mean]
)
# inserting dummy value for missing neighbors
current_features = np.append(
current_features,
np.zeros((1, current_features.shape[1])),
axis=0,
)
next_features, next_features_std = _compute_features(
time + 1, next_nodes, images, [Node.intensity_mean, Node.intensity_std]
filtered_by_color = color_filtering_mask(
time,
source_nodes,
target_nodes,
images,
neighbors,
config.z_score_threshold,
)
LOG.info(
f"Features Std. Dev. range {next_features_std.min()} {next_features_std.max()}"
)
next_features_std[next_features_std <= 1e-6] = 1.0
difference = next_features[:, None, ...] - current_features[neighbors]
difference /= next_features_std[:, None, ...]
filtered_by_color = np.abs(difference).max(axis=-1) <= config.z_score_threshold
else:
filtered_by_color = np.ones_like(neighbors, dtype=bool)

int_next_shift = np.round(next_shift).astype(int)
int_next_shift = np.round(target_shift).astype(int)
# NOTE: moving bbox with shift, MUST be after `feature computation`
for node, shift in zip(next_nodes, int_next_shift):
for node, shift in zip(target_nodes, int_next_shift):
node.bbox[:n_dim] += shift
node.bbox[-n_dim:] += shift

distance_w = config.distance_weight
links = []

for i, node in enumerate(next_nodes):
for i, node in enumerate(target_nodes):
valid = (~np.isinf(distances[i])) & filtered_by_color[i]
valid_neighbors = neighbors[i, valid]
neigh_distances = distances[i, valid]

neighborhood = []
for neigh_idx, neigh_dist in zip(valid_neighbors, neigh_distances):
neigh = current_nodes[neigh_idx]
neigh = source_nodes[neigh_idx]
edge_weight = node.IoU(neigh) - distance_w * neigh_dist
# using dist as a tie-breaker
neighborhood.append(
Expand All @@ -176,13 +246,14 @@ def _process(

with write_lock if write_lock is not None else nullcontext():
LOG.info(f"Pushing links from time {time} to {db_path}")
connect_args = {"timeout": 45} if write_lock is not None else {}
engine = sqla.create_engine(
db_path, hide_parameters=True, connect_args=connect_args
)
with engine.begin() as conn:
df.to_sql(
name=LinkDB.__tablename__, con=conn, if_exists="append", index=False
)
df.to_sql(name=table_name, con=conn, if_exists="append", index=False)

return df


def link(
Expand Down Expand Up @@ -230,3 +301,40 @@ def link(
multiprocessing_apply(
process, time_points, config.linking_config.n_workers, desc="Linking nodes."
)


def add_links(
config: MainConfig,
sources: ArrayLike,
targets: ArrayLike,
weights: ArrayLike,
) -> None:
"""
Adds user-defined links to the database.
Parameters
----------
config : MainConfig
Configuration parameters.
sources : ArrayLike
Sources (t) node id.
targets : ArrayLike
Targets (t + 1) node id.
weights : ArrayLike
Link weights, the higher the weight the more likely the link.
"""
df = pd.DataFrame(
{
"source_id": np.asarray(sources, dtype=int),
"target_id": np.asarray(targets, dtype=int),
"weight": weights,
}
)

engine = sqla.create_engine(
config.data_config.database_path,
hide_parameters=True,
)

with engine.begin() as conn:
df.to_sql(name=LinkDB.__tablename__, con=conn, if_exists="append", index=False)
8 changes: 7 additions & 1 deletion ultrack/core/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
tracks_layer_to_trackmate,
tracks_to_zarr,
)
from ultrack.core.linking.processing import link
from ultrack.core.linking.processing import add_links, link
from ultrack.core.main import track
from ultrack.core.segmentation.processing import get_nodes_features, segment
from ultrack.core.solve.processing import solve
Expand Down Expand Up @@ -165,6 +165,12 @@ def get_nodes_features(self, **kwargs) -> pd.DataFrame:
nodes_features_df = get_nodes_features(self.config, **kwargs)
return nodes_features_df

@functools.wraps(add_links)
def add_links(self, **kwargs) -> None:
self._assert_segmented("add_links")
add_links(config=self.config, **kwargs)
self.status |= TrackerStatus.LINKED

@functools.wraps(add_nodes_prob)
def add_nodes_prob(
self,
Expand Down

0 comments on commit 1e8f88e

Please sign in to comment.