@@ -489,34 +489,6 @@ def segment(
    val_dataloader = dm.val_dataloader()
    test_dataloader = dm.test_dataloader()

-    # # Initialize Dask DataFrame for assignments
-    # output_ddf = None
-
-    # @dask.delayed
-    # def process_batch(batch, gpu_id):
-    #     # Assume you're using CuPy, and you need to use a specific GPU
-    #     predict_batch(
-    #         model,
-    #         batch,
-    #         score_cut,
-    #         receptive_field,
-    #         use_cc=use_cc,
-    #         knn_method=knn_method,
-    #         edge_index_save_path=edge_index_save_path,
-    #         output_ddf_save_path=output_ddf_save_path,
-    #         gpu_id=gpu_id
-    #     )
-
-    # delayed_tasks = [process_batch(batch, gpu_ids[i % len(gpu_ids)]) for i, batch in enumerate(dm.train)]
-    # # pqdm(delayed_tasks, n_jobs=len(gpu_ids), argument_type='delayed', progress_bar=True)
-    # # dask.compute(*delayed_tasks)
-    # # delayed_tasks = [process_batch(batch, gpu_ids[i % len(gpu_ids)]) for i, batch in enumerate(batches)]
-
-    # # Use tqdm for progress bar
-    # with ProgressBar():
-    #     # Execute the delayed tasks with a Dask compute call
-    #     dask.compute(*delayed_tasks)
-
    # Loop through the data loaders (train, val, and test)
    for loader_name, loader in zip(
        ["Train", "Validation", "Test"], [train_dataloader, val_dataloader, test_dataloader]
@@ -544,124 +516,122 @@ def segment(
        elapsed_time = time() - step_start_time
        print(f"Batch processing completed in {elapsed_time:.2f} seconds.")

-    # Load the full saved segmentation results
-    seg_final_dd = dd.read_parquet(output_ddf_save_path)
-    seg_final_dd = seg_final_dd.set_index("transcript_id", sorted=False)
-
+    seg_final_dd = pd.read_parquet(output_ddf_save_path)
+    seg_final_dd = seg_final_dd.set_index("transcript_id")
+
    step_start_time = time()
    if verbose:
        print(f"Applying max score selection logic...")
-
+
    # Step 1: Find max bound indices (bound == 1) and max unbound indices (bound == 0)
    max_bound_idx = seg_final_dd[seg_final_dd["bound"] == 1].groupby("transcript_id")["score"].idxmax()
    max_unbound_idx = seg_final_dd[seg_final_dd["bound"] == 0].groupby("transcript_id")["score"].idxmax()
-
+
    # Step 2: Combine indices, prioritizing bound=1 scores
-    final_idx = max_bound_idx.combine_first(max_unbound_idx).compute()
-    print(final_idx)
-
+    final_idx = max_bound_idx.combine_first(max_unbound_idx)
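+    # combine_first keeps the best bound=1 row per transcript and falls back to the best bound=0 row otherwise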
+
    # Step 3: Use the computed final_idx to select the best assignments
    # Make sure you are using the divisions and set the index correctly before loc
-    # seg_final_dd = seg_final_dd.set_index('transcript_id', sorted=True)
-    seg_final_filtered = seg_final_dd.loc[final_idx].compute()
-
+    seg_final_filtered = seg_final_dd.loc[final_idx]
+
    if verbose:
        elapsed_time = time() - step_start_time
        print(f"Max score selection completed in {elapsed_time:.2f} seconds.")
-
+
    # Step 3: Load the transcripts DataFrame and merge results
-
+
    if verbose:
        print(f"Loading transcripts from {transcript_file}...")
-
-    transcripts_df = dd.read_parquet(transcript_file)
+
+    transcripts_df = pd.read_parquet(transcript_file)
    transcripts_df["transcript_id"] = transcripts_df["transcript_id"].astype(str)
-
+
    step_start_time = time()
    if verbose:
        print(f"Merging segmentation results with transcripts...")
-
+
    # Outer merge to include all transcripts, even those without assigned cell ids
    transcripts_df_filtered = transcripts_df.merge(seg_final_filtered, on="transcript_id", how="outer")
-
+
    if verbose:
        elapsed_time = time() - step_start_time
-        print(f"Merging segmentation results with transcripts completed in {elapsed_time:.2f} seconds.")
-
-    # Step 4: Handle unassigned transcripts using connected components (if use_cc=True)
-    if use_cc:
-
-        step_start_time = time()
-        if verbose:
-            print(f"Computing connected components for unassigned transcripts...")
-        # Load edge indices from saved Parquet
-        edge_index_dd = dd.read_parquet(edge_index_save_path)
-
-        # Step 2: Get unique transcript_ids from edge_index_dd and their positional indices
-        transcript_ids_in_edges = dd.concat([edge_index_dd["source"], edge_index_dd["target"]]).unique().compute()
-
-        # Create a lookup table with unique indices
-        lookup_table = pd.Series(data=range(len(transcript_ids_in_edges)), index=transcript_ids_in_edges).to_dict()
-
-        # Map source and target to positional indices
-        edge_index_dd["index_source"] = edge_index_dd["source"].map(lookup_table)
-        edge_index_dd["index_target"] = edge_index_dd["target"].map(lookup_table)
-        # Step 3: Compute connected components for transcripts involved in edges
-        source_indices = np.asarray(edge_index_dd["index_source"].compute())
-        target_indices = np.asarray(edge_index_dd["index_target"].compute())
-        data_cp = np.ones(len(source_indices), dtype=cp.float32)
-
-        # Create the sparse COO matrix
-        coo_cp_matrix = scipy_coo_matrix(
-            (data_cp, (source_indices, target_indices)),
-            shape=(len(transcript_ids_in_edges), len(transcript_ids_in_edges)),
-        )
-
-        # Use CuPy's connected components algorithm to compute components
-        n, comps = cc(coo_cp_matrix, directed=True, connection="weak")
-
-        # Step 4: Map back the component labels to the original transcript_ids
-        comp_labels = pd.Series(comps, index=transcript_ids_in_edges)
-        # Step 5: Handle only unassigned transcripts in transcripts_df_filtered
-        unassigned_mask = transcripts_df_filtered["segger_cell_id"].isna()
-
-        unassigned_transcripts_df = transcripts_df_filtered.loc[unassigned_mask, ["transcript_id"]]
-
-        # Step 6: Map component labels only to unassigned transcript_ids
-        new_segger_cell_ids = unassigned_transcripts_df["transcript_id"].map(comp_labels)
-
-        # Step 7: Create a DataFrame with updated 'segger_cell_id' for unassigned transcripts
-        unassigned_transcripts_df = unassigned_transcripts_df.assign(segger_cell_id=new_segger_cell_ids)
-
-        # Step 8: Merge this DataFrame back into the original to update only the unassigned segger_cell_id
-        # We perform a left join so that only the rows in unassigned_transcripts_df are updated
-        # transcripts_df_filtered = transcripts_df_filtered.drop(columns='segger_cell_id')
-
-        # Merging the updates back to the original DataFrame
-        transcripts_df_filtered = transcripts_df_filtered.merge(
-            unassigned_transcripts_df[["transcript_id", "segger_cell_id"]],
-            on="transcript_id",
-            how="left",  # Perform a left join to only update the unassigned rows
-            suffixes=("", "_new"),  # Suffix for new column to avoid overwriting
-        )
-
-        # Step 9: Fill missing segger_cell_id values with the updated values from the merge
-        transcripts_df_filtered["segger_cell_id"] = transcripts_df_filtered["segger_cell_id"].fillna(
-            transcripts_df_filtered["segger_cell_id_new"]
-        )
-
-        # Step 10: Clean up by dropping the temporary 'segger_cell_id_new' column
-        transcripts_df_filtered = transcripts_df_filtered.drop(columns=["segger_cell_id_new"])
-
-        # Fill the NaN values in segger_cell_id with the already existing (assigned) values
-        # transcripts_df_filtered['segger_cell_id'] = transcripts_df_filtered['segger_cell_id'].fillna(transcripts_df_filtered['segger_cell_id_target'])
-
-        # Drop any temporary columns used during the merge
-        # transcripts_df_filtered = transcripts_df_filtered.drop(columns=['segger_cell_id_target'])
+        print(f"Merged segmentation results with transcripts in {elapsed_time:.2f} seconds.")
+
+    step_start_time = time()
+    if verbose:
+        print(f"Computing connected components for unassigned transcripts...")
+    # Load edge indices from saved Parquet
+    edge_index_dd = pd.read_parquet(edge_index_save_path)
+
+    # Step 2: Get unique transcript_ids from edge_index_dd and their positional indices
+    transcript_ids_in_edges = pd.concat([edge_index_dd["source"], edge_index_dd["target"]]).unique()
+
+    # Create a lookup table with unique indices
+    lookup_table = pd.Series(data=range(len(transcript_ids_in_edges)), index=transcript_ids_in_edges).to_dict()
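+    # connected_components operates on positional node indices, so each transcript_id string is mapped to a position in 0..N-1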
+
+    # Map source and target to positional indices
+    edge_index_dd["index_source"] = edge_index_dd["source"].map(lookup_table)
+    edge_index_dd["index_target"] = edge_index_dd["target"].map(lookup_table)
+    # Step 3: Compute connected components for transcripts involved in edges
+    source_indices = np.asarray(edge_index_dd["index_source"])
+    target_indices = np.asarray(edge_index_dd["index_target"])
+    data_cp = np.ones(len(source_indices), dtype=np.float32)
+
+    # Create the sparse COO matrix
+    coo_cp_matrix = scipy_coo_matrix(
+        (data_cp, (source_indices, target_indices)),
+        shape=(len(transcript_ids_in_edges), len(transcript_ids_in_edges)),
+    )
+
+    # Use SciPy's connected components algorithm to compute components
+    n, comps = cc(coo_cp_matrix, directed=True, connection="strong")
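+    # n is the number of components found; comps[i] is the component label of the node at positional index i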
+    if verbose:
+        elapsed_time = time() - step_start_time
+        print(f"Computed connected components for unassigned transcripts in {elapsed_time:.2f} seconds.")
+
+    step_start_time = time()
+    if verbose:
+        print(f"Assigning component labels to unassigned transcripts...")
+    # Step 4: Map back the component labels to the original transcript_ids
+
+    def _get_id():
+        """Generate a random Xenium-style ID."""
+        return "".join(np.random.choice(list("abcdefghijklmnopqrstuvwxyz"), 8)) + "-nx"
+
+    new_ids = np.array([_get_id() for _ in range(n)])
+    comp_labels = new_ids[comps]
+    comp_labels = pd.Series(comp_labels, index=transcript_ids_in_edges)
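+    # comp_labels now maps every transcript_id that appears in an edge to its component's randomly generated cell ID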
+    # Step 5: Handle only unassigned transcripts in transcripts_df_filtered
+    unassigned_mask = transcripts_df_filtered["segger_cell_id"].isna()
+
+    unassigned_transcripts_df = transcripts_df_filtered.loc[unassigned_mask, ["transcript_id"]]
+
+    # Step 6: Map component labels only to unassigned transcript_ids
+    new_segger_cell_ids = unassigned_transcripts_df["transcript_id"].map(comp_labels)
+
+    # Step 7: Create a DataFrame with updated 'segger_cell_id' for unassigned transcripts
+    unassigned_transcripts_df = unassigned_transcripts_df.assign(segger_cell_id=new_segger_cell_ids)
+
+    # Step 8: Merge this DataFrame back into the original to update only the unassigned segger_cell_id
+
+    # Merging the updates back to the original DataFrame
+    transcripts_df_filtered = transcripts_df_filtered.merge(
+        unassigned_transcripts_df[["transcript_id", "segger_cell_id"]],
+        on="transcript_id",
+        how="left",  # Perform a left join to only update the unassigned rows
+        suffixes=("", "_new"),  # Suffix for new column to avoid overwriting
+    )
+
+    # Step 9: Fill missing segger_cell_id values with the updated values from the merge
+    transcripts_df_filtered["segger_cell_id"] = transcripts_df_filtered["segger_cell_id"].fillna(
+        transcripts_df_filtered["segger_cell_id_new"]
+    )

-        if verbose:
-            elapsed_time = time() - step_start_time
-            print(f"Connected components computed in {elapsed_time:.2f} seconds.")
+    transcripts_df_filtered = transcripts_df_filtered.drop(columns=["segger_cell_id_new"])
+
+    if verbose:
+        elapsed_time = time() - step_start_time
+        print(f"Component label assignment completed in {elapsed_time:.2f} seconds.")

    # Step 5: Save the merged results based on options

@@ -670,14 +640,14 @@ def segment(
        step_start_time = time()
        print(f"Saving transcripts.parquet...")
        transcripts_save_path = save_dir / "segger_transcripts.parquet"
-        transcripts_df_filtered = transcripts_df_filtered.repartition(npartitions=100)
+        # transcripts_df_filtered = transcripts_df_filtered.repartition(npartitions=100)
        transcripts_df_filtered.to_parquet(
            transcripts_save_path,
            engine="pyarrow",  # PyArrow is faster and recommended
            compression="snappy",  # Use snappy compression for speed
-            write_index=False,  # Skip writing index if not needed
-            append=False,  # Set to True if you're appending to an existing Parquet file
-            overwrite=True,
+            # write_index=False,  # Skip writing index if not needed
+            # append=False,  # Set to True if you're appending to an existing Parquet file
+            # overwrite=True,
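+            # (write_index/append/overwrite are Dask to_parquet options; pandas.DataFrame.to_parquet does not accept them)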
        )  # Dask handles Parquet well
        if verbose:
            elapsed_time = time() - step_start_time
@@ -688,7 +658,7 @@ def segment(
        step_start_time = time()
        print(f"Saving anndata object...")
        anndata_save_path = save_dir / "segger_adata.h5ad"
-        segger_adata = create_anndata(transcripts_df_filtered.compute(), **anndata_kwargs)  # Compute for AnnData
+        segger_adata = create_anndata(transcripts_df_filtered, **anndata_kwargs)  # Build AnnData from the in-memory DataFrame
        segger_adata.write(anndata_save_path)
        if verbose:
            elapsed_time = time() - step_start_time
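
Editor's note: the short sketch below is not part of the commit. It illustrates, on invented toy IDs and edges, how the connected-components fallback added above groups transcripts that share edges in the saved edge index. For readability it uses SciPy's connected_components with weak connectivity (the committed call passes connection="strong") and plain integer labels instead of the random Xenium-style IDs produced by _get_id().

import numpy as np
import pandas as pd
from scipy.sparse import coo_matrix
from scipy.sparse.csgraph import connected_components

# Toy edge list in the same "source"/"target" layout as the saved edge-index Parquet.
edges = pd.DataFrame({"source": ["t1", "t2", "t4"], "target": ["t2", "t3", "t5"]})

# Map transcript IDs to contiguous positional indices (the diff's lookup_table step).
ids = pd.concat([edges["source"], edges["target"]]).unique()
lookup = pd.Series(range(len(ids)), index=ids).to_dict()
src = edges["source"].map(lookup).to_numpy()
dst = edges["target"].map(lookup).to_numpy()

# Build the sparse adjacency matrix and label weakly connected components.
adj = coo_matrix((np.ones(len(src), dtype=np.float32), (src, dst)), shape=(len(ids), len(ids)))
n_comps, labels = connected_components(adj, directed=True, connection="weak")

# Every transcript in a component gets the same label; the commit replaces these
# integers with random Xenium-style IDs before assigning them to unassigned transcripts.
comp_labels = pd.Series(labels, index=ids)
print(n_comps)      # 2
print(comp_labels)  # t1, t2, t3 share one label; t4, t5 share the other

Running the sketch prints two components, {t1, t2, t3} and {t4, t5}, which is exactly the grouping the diff then maps onto transcripts whose segger_cell_id is still missing.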