Swap out networkx for rustworkx #147

Open
4 of 14 tasks
bentaculum opened this issue Apr 11, 2024 · 4 comments
Labels
enhancement (New feature or request), help wanted (Extra attention is needed), performance (Changes that impact runtime and performance)

Comments

@bentaculum
Contributor

Description

Hi all,

I would like to get started on swapping the backend from networkx to rustworkx (https://www.rustworkx.org/index.html), which we discussed a while ago.
The reason is that, for example, on basic 2D datasets from the Cell Tracking Challenge (PhC-C2DL-PSC, ~70k nodes), calculating CTCMetrics currently takes more than 1 minute.

This seems to be due, among other things, to basic attribute lookups on the graph while matching GT and predicted nodes.

I anticipate the following challenges:

  • rustworkx is NOT a drop-in replacement for networkx, and not even a subset of networkx. We currently use networkx-only convenience functions like from_pandas_edgelist() in graph construction. This therefore requires dealing with a different API.
  • It is not certain that there will be speedups for traccuracy. The rustworkx benchmarks I find are promising but limited to certain high-level algorithms.
  • rustworkx requires node ids to be int. We currently use strings of format segmentation-ID_time.
  • rustworkx is still on a 0.x (pre-1.0) release and may introduce breaking changes in the future.
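Without from_pandas_edgelist(), graph construction means doing the ID bookkeeping ourselves. A rough sketch, not traccuracy's actual loader: CTC track rows in the `L B E P` format of man_track.txt (track label, begin frame, end frame, parent label, with 0 meaning no parent) are expanded directly into string node IDs in the current `segmentation-ID_time` style plus an edge list over integer positions. The function name `tracks_to_int_graph` is hypothetical. A list of node payloads and a list of `(parent, child)` index pairs are what rustworkx's `PyDiGraph.add_nodes_from()` and `add_edges_from_no_data()` take, but a real converter should use the indices returned by `add_nodes_from()` rather than assume they equal list positions.

```python
def tracks_to_int_graph(tracks):
    """Hypothetical helper: expand CTC track rows (label, begin, end, parent)
    into integer-indexed nodes and edges.

    Returns (node_ids, edges): node_ids[i] is the string ID "label_t" for
    integer index i, and edges is a list of (src_index, dst_index) pairs.
    """
    node_ids = []
    index_of = {}  # string ID -> integer index
    for label, begin, end, _parent in tracks:
        for t in range(begin, end + 1):
            nid = f"{label}_{t}"
            index_of[nid] = len(node_ids)
            node_ids.append(nid)

    end_frame = {label: end for label, _begin, end, _parent in tracks}
    edges = []
    for label, begin, end, parent in tracks:
        # Continuation edges between consecutive frames of the same track
        for t in range(begin, end):
            edges.append((index_of[f"{label}_{t}"], index_of[f"{label}_{t + 1}"]))
        # Parent edge from the parent's last frame to this track's first frame
        if parent != 0:
            src = index_of[f"{parent}_{end_frame[parent]}"]
            edges.append((src, index_of[f"{label}_{begin}"]))
    return node_ids, edges
```

For example, a track 1 spanning frames 0 to 2 that divides into tracks 2 and 3 at frame 3 yields six nodes and five edges. Keeping the string IDs as payloads means they can still be handed back to the user.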

Has anyone else already given this some more thought and would like to pitch in, either with comments or by coding together?

Cheers

Ben

Topics

What types of changes are you suggesting? Put an x in the boxes that apply.

  • New feature or enhancement
  • Documentation update
  • Tests and benchmarks
  • Maintenance (e.g. dependencies, CI, releases, etc.)

Which topics does your change affect? Put an x in the boxes that apply.

  • Loaders
  • Matchers
  • Track Errors
  • Metrics
  • Core functionality (e.g. TrackingGraph, run_metrics, cli, etc.)

Priority

  • This is an essential feature
  • Nice to have
  • Future idea

Are you interested in contributing?

  • Yes! 🎉
  • No 🙁
@DragaDoncila
Collaborator

Thanks for opening the conversation @bentaculum!

> rustworkx is NOT a drop-in replacement for networkx, and not even a subset of networkx. We currently use networkx-only convenience functions like from_pandas_edgelist() in graph construction. This therefore requires dealing with a different API.

I think this is ok. The convenience of networkx is fine, but for this package I'm more concerned about speed than about the convenience of the internals. I do think from the perspective of the user we should be taking networkx objects and returning networkx objects, casting/converting to rustworkx internally only (once on load, once on return). And of course, for use with CTC data, it should be completely transparent to the user via our loader.
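A minimal sketch of the "convert once on load, once on return" boundary, using a hypothetical `IdBridge` class (not part of traccuracy or rustworkx): user-facing node IDs get contiguous integer indices on the way in, and integer results such as matched node pairs are translated back on the way out. No rustworkx calls are made here; in a real backend the integer side would be the indices rustworkx assigns when nodes are added.

```python
class IdBridge:
    """Hypothetical two-way map between user-facing node IDs and internal int indices."""

    def __init__(self, node_ids):
        self.to_ext = list(node_ids)  # internal index -> user-facing ID
        self.to_int = {nid: i for i, nid in enumerate(self.to_ext)}
        if len(self.to_int) != len(self.to_ext):
            raise ValueError("node IDs must be unique")

    def edges_in(self, edges):
        """Translate user-ID edges to internal index edges (done once, on load)."""
        return [(self.to_int[u], self.to_int[v]) for u, v in edges]

    def ids_out(self, indices):
        """Translate internal indices back to user-facing IDs (done once, on return)."""
        return [self.to_ext[i] for i in indices]


bridge = IdBridge(["1_0", "1_1", "2_1"])
internal_edges = bridge.edges_in([("1_0", "1_1"), ("1_0", "2_1")])  # [(0, 1), (0, 2)]
```

Keeping the translation in one object means everything between load and return can work purely on integers, which is what a rustworkx backend would need.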

> It is not certain that there will be speedups for traccuracy. The rustworkx benchmarks I find are promising but limited to certain high-level algorithms.

I suppose we will need to check! Have you profiled the metrics computation by the way? I'd be interested to see a profile - maybe there's still simple things we can do on our end.

> rustworkx requires node ids to be int. We currently use strings of format segmentation-ID_time.

I have no strong preference for keeping the string IDs (they kinda annoy me actually, because all my other graphs have int IDs), but I understand the initial driver of wanting them to be meaningful for a user.
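If we do move to int IDs, one option for keeping them meaningful (only a sketch, not an agreed design) is to pack the segmentation label and frame into a single integer, `t * stride + label`, with `stride` larger than any label. The helper names and the default `stride` of 1,000,000 are assumptions. rustworkx would still assign its own node indices, so these IDs would live in the node payloads or in the user-facing networkx graph rather than serve as rustworkx indices.

```python
def encode_node_id(label, t, stride=1_000_000):
    """Hypothetical helper: pack (segmentation label, frame) into one int.

    Requires 0 <= label < stride so that decoding is unambiguous.
    """
    if not 0 <= label < stride:
        raise ValueError(f"label {label} out of range for stride {stride}")
    return t * stride + label


def decode_node_id(node_id, stride=1_000_000):
    """Hypothetical helper: recover (label, frame) from an encoded ID."""
    t, label = divmod(node_id, stride)
    return label, t
```

For example, `encode_node_id(12, 3)` gives `3000012`, and `decode_node_id(3000012)` gives back `(12, 3)`, so the ID stays readable much like the current `12_3` string.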

@bentaculum
Contributor Author

> I suppose we will need to check! Have you profiled the metrics computation by the way? I'd be interested to see a profile - maybe there's still simple things we can do on our end.

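Fresh profile using IPython's %lprun magic (line_profiler). Somehow it runs faster than yesterday...
Most of the time is spent building the label_to_id mappings from the graphs: the two nx.get_node_attributes() calls alone account for about 69% of the time in _compute_mapping (34.3% each).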

```python
# IPython session: %load_ext, %time and %lprun are IPython magics
%load_ext line_profiler

from traccuracy import run_metrics
from traccuracy.loaders import load_ctc_data
from traccuracy.metrics import CTCMetrics
from traccuracy.matchers import CTCMatcher

gt_data = load_ctc_data(
    '/Users/gallusse/data/celltracking/ctc/PhC-C2DL-PSC/train/01_GT/TRA',
    '/Users/gallusse/data/celltracking/ctc/PhC-C2DL-PSC/train/01_GT/TRA/man_track.txt'
)
# The "prediction" loads the same GT files, so GT is evaluated against itself
pred_data = load_ctc_data(
    '/Users/gallusse/data/celltracking/ctc/PhC-C2DL-PSC/train/01_GT/TRA',
    '/Users/gallusse/data/celltracking/ctc/PhC-C2DL-PSC/train/01_GT/TRA/man_track.txt'
)

# CPU and wall-clock timing of the full metrics run
%time ctc_results = run_metrics(gt_data=gt_data, pred_data=pred_data, matcher=CTCMatcher(), metrics=[CTCMetrics()])

# Line-by-line profile of the CTC matcher, written to profile.txt
%lprun -f CTCMatcher._compute_mapping -T profile.txt ctc_results = run_metrics(gt_data=gt_data, pred_data=pred_data, matcher=CTCMatcher(), metrics=[CTCMetrics()])
```
Matching frames: 100%|█████████████████████████████████████████████████| 300/300 [00:14<00:00, 20.52it/s]
INFO:traccuracy.matchers._base:Matched 71403 out of 71403 ground truth nodes.
INFO:traccuracy.matchers._base:Matched 71403 out of 71403 predicted nodes.
Evaluating nodes: 100%|████████████████████████████████████████| 71403/71403 [00:00<00:00, 594404.37it/s]
Evaluating FP edges: 100%|█████████████████████████████████████| 71201/71201 [00:00<00:00, 919657.80it/s]
Evaluating FN edges: 100%|████████████████████████████████████| 71201/71201 [00:00<00:00, 1002065.74it/s]
CPU times: user 16.5 s, sys: 895 ms, total: 17.4 s
Wall time: 17.6 s
Timer unit: 1e-09 s

Total time: 31.7674 s
File: /Users/gallusse/code/traccuracy/src/traccuracy/matchers/_ctc.py
Function: _compute_mapping at line 29

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    29                                               def _compute_mapping(self, gt_graph: TrackingGraph, pred_graph: TrackingGraph):
    30                                                   """Run ctc matching
    31                                           
    32                                                   Args:
    33                                                       gt_graph (TrackingGraph): Tracking graph object for the gt
    34                                                       pred_graph (TrackingGraph): Tracking graph object for the pred
    35                                           
    36                                                   Returns:
    37                                                       traccuracy.matchers.Matched: Matched data object containing the CTC mapping
    38                                           
    39                                                   Raises:
    40                                                       ValueError: if GT and pred segmentations are None or are not the same shape
    41                                                   """
    42         1       1000.0   1000.0      0.0          gt = gt_graph
    43         1          0.0      0.0      0.0          pred = pred_graph
    44         1       1000.0   1000.0      0.0          gt_label_key = gt_graph.label_key
    45         1          0.0      0.0      0.0          pred_label_key = pred_graph.label_key
    46         1       1000.0   1000.0      0.0          G_gt, mask_gt = gt, gt.segmentation
    47         1          0.0      0.0      0.0          G_pred, mask_pred = pred, pred.segmentation
    48                                           
    49         1          0.0      0.0      0.0          if mask_gt is None or mask_pred is None:
    50                                                       raise ValueError("Segmentation is None, cannot perform matching")
    51                                           
    52         1       4000.0   4000.0      0.0          if mask_gt.shape != mask_pred.shape:
    53                                                       raise ValueError("Segmentation shapes must match between gt and pred")
    54                                           
    55         1          0.0      0.0      0.0          mapping = []
    56                                                   # Get overlaps for each frame
    57       302   41774000.0 138324.5      0.1          for i, t in enumerate(
    58         2     291000.0 145500.0      0.0              tqdm(
    59         1       2000.0   2000.0      0.0                  range(gt.start_frame, gt.end_frame),
    60         1          0.0      0.0      0.0                  desc="Matching frames",
    61                                                       )
    62                                                   ):
    63       300     250000.0    833.3      0.0              gt_frame = mask_gt[i]
    64       300      91000.0    303.3      0.0              pred_frame = mask_pred[i]
    65       300     327000.0   1090.0      0.0              gt_frame_nodes = gt.nodes_by_frame[t]
    66       300     244000.0    813.3      0.0              pred_frame_nodes = pred.nodes_by_frame[t]
    67                                           
    68                                                       # get the labels for this frame
    69       600 2552632000.0    4e+06      8.0              gt_labels = dict(
    70       600     424000.0    706.7      0.0                  filter(
    71       300     147000.0    490.0      0.0                      lambda item: item[0] in gt_frame_nodes,
    72       300        1e+10    4e+07     34.3                      nx.get_node_attributes(G_gt.graph, gt_label_key).items(),
    73                                                           )
    74                                                       )
    75       300    9992000.0  33306.7      0.0              gt_label_to_id = {v: k for k, v in gt_labels.items()}
    76                                           
    77       600 2465360000.0    4e+06      7.8              pred_labels = dict(
    78       600     470000.0    783.3      0.0                  filter(
    79       300     182000.0    606.7      0.0                      lambda item: item[0] in pred_frame_nodes,
    80       300        1e+10    4e+07     34.3                      nx.get_node_attributes(G_pred.graph, pred_label_key).items(),
    81                                                           )
    82                                                       )
    83       300    9263000.0  30876.7      0.0              pred_label_to_id = {v: k for k, v in pred_labels.items()}
    84                                           
    85       300      48000.0    160.0      0.0              (
    86       300      84000.0    280.0      0.0                  overlapping_gt_labels,
    87       300      65000.0    216.7      0.0                  overlapping_pred_labels,
    88       300     586000.0   1953.3      0.0                  intersection,
    89       300 4715872000.0    2e+07     14.8              ) = get_labels_with_overlap(gt_frame, pred_frame)
    90                                           
    91     72093    6224000.0     86.3      0.0              for i in range(len(overlapping_gt_labels)):
    92     71793    9997000.0    139.2      0.0                  gt_label = overlapping_gt_labels[i]
    93     71793    8246000.0    114.9      0.0                  pred_label = overlapping_pred_labels[i]
    94                                                           # CTC metrics only match comp IDs to a single GT ID if there is majority overlap
    95     71793    9167000.0    127.7      0.0                  if intersection[i] > 0.5:
    96    142806   16758000.0    117.3      0.1                      mapping.append(
    97     71403  116304000.0   1628.8      0.4                          (gt_label_to_id[gt_label], pred_label_to_id[pred_label])
    98                                                               )
    99                                           
   100         1       2000.0   2000.0      0.0          return Matched(gt_graph, pred_graph, mapping)
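Lines 72 and 80 of the profile call nx.get_node_attributes() on the whole graph once per frame and then filter the result down to that frame's nodes, so each graph is scanned 300 times (about 300 × 71,403 ≈ 21.4 million node visits per graph). Below is a sketch of one way around this, not necessarily what #148 does: build every frame's label_to_id dict in a single pass over the nodes and look it up inside the frame loop. The attribute keys `"t"` and `"segmentation_id"` are assumptions about how frame and label are stored on the nodes, and `label_to_id_by_frame` is a hypothetical name; `node_attrs` is any iterable of `(node_id, attr_dict)` pairs, such as networkx's `G.nodes(data=True)`.

```python
from collections import defaultdict


def label_to_id_by_frame(node_attrs, frame_key="t", label_key="segmentation_id"):
    """Build {frame: {label: node_id}} in one pass over all nodes.

    Replaces a full-graph scan per frame (O(N * T)) with a single O(N) pass.
    """
    by_frame = defaultdict(dict)
    for node_id, attrs in node_attrs:
        # Mirror nx.get_node_attributes, which skips nodes lacking the label key
        if label_key in attrs:
            by_frame[attrs[frame_key]][attrs[label_key]] = node_id
    return by_frame


nodes = [
    ("1_0", {"t": 0, "segmentation_id": 1}),
    ("2_0", {"t": 0, "segmentation_id": 2}),
    ("1_1", {"t": 1, "segmentation_id": 1}),
]
gt_lookup = label_to_id_by_frame(nodes)
# Inside the per-frame loop: gt_label_to_id = gt_lookup.get(t, {})
```

Using `.get(t, {})` in the loop avoids creating empty entries in the defaultdict for frames without nodes.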

@bentaculum
Contributor Author

For now, I coded up a simple speed-up that doesn't swap the backend: #148.

bentaculum added the help wanted and performance labels on Apr 12, 2024
@bentaculum
Contributor Author

> I do think from the perspective of the user we should be taking networkx objects and returning networkx objects, casting/converting to rustworkx internally only (once on load, once on return). And of course, for use with CTC data, it should be completely transparent to the user via our loader.

I like this, agreed.
