
Commit b516153

Merge branch 'main' of https://github.com/EliHei2/segger_dev into main
2 parents 874e970 + fb58fc1

File tree

2 files changed: +107 -135 lines

src/segger/models/segger_model.py (+10 -8)

@@ -38,20 +38,20 @@ def __init__(
 
         # First GATv2Conv layer
         self.conv_first = GATv2Conv((-1, -1), hidden_channels, heads=heads, add_self_loops=False)
-        # self.lin_first = Linear(-1, hidden_channels * heads)
+        self.lin_first = Linear(-1, hidden_channels * heads)
 
         # Middle GATv2Conv layers
         self.num_mid_layers = num_mid_layers
         if num_mid_layers > 0:
            self.conv_mid_layers = torch.nn.ModuleList()
-           # self.lin_mid_layers = torch.nn.ModuleList()
+           self.lin_mid_layers = torch.nn.ModuleList()
            for _ in range(num_mid_layers):
                self.conv_mid_layers.append(GATv2Conv((-1, -1), hidden_channels, heads=heads, add_self_loops=False))
-               # self.lin_mid_layers.append(Linear(-1, hidden_channels * heads))
+               self.lin_mid_layers.append(Linear(-1, hidden_channels * heads))
 
         # Last GATv2Conv layer
         self.conv_last = GATv2Conv((-1, -1), out_channels, heads=heads, add_self_loops=False)
-        # self.lin_last = Linear(-1, out_channels * heads)
+        self.lin_last = Linear(-1, out_channels * heads)
 
     def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
         """
@@ -70,17 +70,19 @@ def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
         x = self.tx_embedding(((x.sum(1) * is_one_dim).int())) * is_one_dim + self.lin0(x.float()) * (1 - is_one_dim)
         # First layer
         x = x.relu()
-        x = self.conv_first(x, edge_index)  # + self.lin_first(x)
+        x = self.conv_first(x, edge_index) + self.lin_first(x)
         x = x.relu()
 
         # Middle layers
         if self.num_mid_layers > 0:
-            for conv_mid in self.conv_mid_layers:
-                x = conv_mid(x, edge_index)  # + lin_mid(x)
+            for i in range(self.num_mid_layers):
+                conv_mid = self.conv_mid_layers[i]
+                lin_mid = self.lin_mid_layers[i]
+                x = conv_mid(x, edge_index) + lin_mid(x)
                 x = x.relu()
 
         # Last layer
-        x = self.conv_last(x, edge_index)  # + self.lin_last(x)
+        x = self.conv_last(x, edge_index) + self.lin_last(x)
 
         return x
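In plain terms, this change activates the previously commented-out linear skip connections: each GATv2Conv layer's attention output is now summed with a lazily-sized Linear projection of the same input. Below is a minimal, self-contained sketch of that pattern; the node count and feature width are toy values, not segger's real dimensions.

import torch
from torch_geometric.nn import GATv2Conv, Linear

hidden_channels, heads = 32, 4

# Attention branch plus a lazily-initialized linear skip branch; the Linear's output
# width matches the concatenated attention heads so the two outputs can be summed.
conv = GATv2Conv((-1, -1), hidden_channels, heads=heads, add_self_loops=False)
lin = Linear(-1, hidden_channels * heads)

x = torch.randn(10, 16)                     # 10 nodes, 16 input features (made up)
edge_index = torch.randint(0, 10, (2, 40))  # random edges, illustration only

out = conv(x, edge_index) + lin(x)          # attention output plus linear skip
print(out.shape)                            # torch.Size([10, 128]) = hidden_channels * heads

Both branches use lazy input sizes (-1), so their weights are created on the first forward pass, which is why the model can keep accepting inputs of different widths.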

src/segger/prediction/predict_parquet.py (+97 -127)

@@ -489,34 +489,6 @@ def segment(
     val_dataloader = dm.val_dataloader()
     test_dataloader = dm.test_dataloader()
 
-    # # Initialize Dask DataFrame for assignments
-    # output_ddf = None
-
-    # @dask.delayed
-    # def process_batch(batch, gpu_id):
-    #     # Assume you're using CuPy, and you need to use a specific GPU
-    #     predict_batch(
-    #         model,
-    #         batch,
-    #         score_cut,
-    #         receptive_field,
-    #         use_cc=use_cc,
-    #         knn_method=knn_method,
-    #         edge_index_save_path=edge_index_save_path,
-    #         output_ddf_save_path=output_ddf_save_path,
-    #         gpu_id=gpu_id
-    #     )
-
-    # delayed_tasks = [process_batch(batch, gpu_ids[i % len(gpu_ids)]) for i, batch in enumerate(dm.train)]
-    # # pqdm(delayed_tasks, n_jobs=len(gpu_ids), argument_type='delayed', progress_bar=True)
-    # # dask.compute(*delayed_tasks)
-    # # delayed_tasks = [process_batch(batch, gpu_ids[i % len(gpu_ids)]) for i, batch in enumerate(batches)]
-
-    # # Use tqdm for progress bar
-    # with ProgressBar():
-    #     # Execute the delayed tasks with a Dask compute call
-    #     dask.compute(*delayed_tasks)
-
     # Loop through the data loaders (train, val, and test)
     for loader_name, loader in zip(
         ["Train", "Validation", "Test"], [train_dataloader, val_dataloader, test_dataloader]
@@ -544,124 +516,122 @@ def segment(
         elapsed_time = time() - step_start_time
         print(f"Batch processing completed in {elapsed_time:.2f} seconds.")
 
-    # Load the full saved segmentation results
-    seg_final_dd = dd.read_parquet(output_ddf_save_path)
-    seg_final_dd = seg_final_dd.set_index("transcript_id", sorted=False)
+    seg_final_dd = pd.read_parquet(output_ddf_save_path)
+    seg_final_dd = seg_final_dd.set_index("transcript_id")
 
     step_start_time = time()
     if verbose:
         print(f"Applying max score selection logic...")
 
     # Step 1: Find max bound indices (bound == 1) and max unbound indices (bound == 0)
     max_bound_idx = seg_final_dd[seg_final_dd["bound"] == 1].groupby("transcript_id")["score"].idxmax()
     max_unbound_idx = seg_final_dd[seg_final_dd["bound"] == 0].groupby("transcript_id")["score"].idxmax()
 
     # Step 2: Combine indices, prioritizing bound=1 scores
-    final_idx = max_bound_idx.combine_first(max_unbound_idx).compute()
-    print(final_idx)
+    final_idx = max_bound_idx.combine_first(max_unbound_idx)
 
     # Step 3: Use the computed final_idx to select the best assignments
     # Make sure you are using the divisions and set the index correctly before loc
-    # seg_final_dd = seg_final_dd.set_index('transcript_id', sorted=True)
-    seg_final_filtered = seg_final_dd.loc[final_idx].compute()
+    seg_final_filtered = seg_final_dd.loc[final_idx]
 
     if verbose:
         elapsed_time = time() - step_start_time
         print(f"Max score selection completed in {elapsed_time:.2f} seconds.")
 
     # Step 3: Load the transcripts DataFrame and merge results
 
     if verbose:
         print(f"Loading transcripts from {transcript_file}...")
 
-    transcripts_df = dd.read_parquet(transcript_file)
+    transcripts_df = pd.read_parquet(transcript_file)
     transcripts_df["transcript_id"] = transcripts_df["transcript_id"].astype(str)
 
     step_start_time = time()
     if verbose:
         print(f"Merging segmentation results with transcripts...")
 
     # Outer merge to include all transcripts, even those without assigned cell ids
     transcripts_df_filtered = transcripts_df.merge(seg_final_filtered, on="transcript_id", how="outer")
 
     if verbose:
         elapsed_time = time() - step_start_time
-        print(f"Merging segmentation results with transcripts completed in {elapsed_time:.2f} seconds.")
-
-    # Step 4: Handle unassigned transcripts using connected components (if use_cc=True)
-    if use_cc:
-
-        step_start_time = time()
-        if verbose:
-            print(f"Computing connected components for unassigned transcripts...")
-        # Load edge indices from saved Parquet
-        edge_index_dd = dd.read_parquet(edge_index_save_path)
-
-        # Step 2: Get unique transcript_ids from edge_index_dd and their positional indices
-        transcript_ids_in_edges = dd.concat([edge_index_dd["source"], edge_index_dd["target"]]).unique().compute()
-
-        # Create a lookup table with unique indices
-        lookup_table = pd.Series(data=range(len(transcript_ids_in_edges)), index=transcript_ids_in_edges).to_dict()
-
-        # Map source and target to positional indices
-        edge_index_dd["index_source"] = edge_index_dd["source"].map(lookup_table)
-        edge_index_dd["index_target"] = edge_index_dd["target"].map(lookup_table)
-        # Step 3: Compute connected components for transcripts involved in edges
-        source_indices = np.asarray(edge_index_dd["index_source"].compute())
-        target_indices = np.asarray(edge_index_dd["index_target"].compute())
-        data_cp = np.ones(len(source_indices), dtype=cp.float32)
-
-        # Create the sparse COO matrix
-        coo_cp_matrix = scipy_coo_matrix(
-            (data_cp, (source_indices, target_indices)),
-            shape=(len(transcript_ids_in_edges), len(transcript_ids_in_edges)),
-        )
-
-        # Use CuPy's connected components algorithm to compute components
-        n, comps = cc(coo_cp_matrix, directed=True, connection="weak")
-
-        # Step 4: Map back the component labels to the original transcript_ids
-        comp_labels = pd.Series(comps, index=transcript_ids_in_edges)
-        # Step 5: Handle only unassigned transcripts in transcripts_df_filtered
-        unassigned_mask = transcripts_df_filtered["segger_cell_id"].isna()
-
-        unassigned_transcripts_df = transcripts_df_filtered.loc[unassigned_mask, ["transcript_id"]]
-
-        # Step 6: Map component labels only to unassigned transcript_ids
-        new_segger_cell_ids = unassigned_transcripts_df["transcript_id"].map(comp_labels)
-
-        # Step 7: Create a DataFrame with updated 'segger_cell_id' for unassigned transcripts
-        unassigned_transcripts_df = unassigned_transcripts_df.assign(segger_cell_id=new_segger_cell_ids)
-
-        # Step 8: Merge this DataFrame back into the original to update only the unassigned segger_cell_id
-        # We perform a left join so that only the rows in unassigned_transcripts_df are updated
-        # transcripts_df_filtered = transcripts_df_filtered.drop(columns='segger_cell_id')
-
-        # Merging the updates back to the original DataFrame
-        transcripts_df_filtered = transcripts_df_filtered.merge(
-            unassigned_transcripts_df[["transcript_id", "segger_cell_id"]],
-            on="transcript_id",
-            how="left",  # Perform a left join to only update the unassigned rows
-            suffixes=("", "_new"),  # Suffix for new column to avoid overwriting
-        )
-
-        # Step 9: Fill missing segger_cell_id values with the updated values from the merge
-        transcripts_df_filtered["segger_cell_id"] = transcripts_df_filtered["segger_cell_id"].fillna(
-            transcripts_df_filtered["segger_cell_id_new"]
-        )
-
-        # Step 10: Clean up by dropping the temporary 'segger_cell_id_new' column
-        transcripts_df_filtered = transcripts_df_filtered.drop(columns=["segger_cell_id_new"])
-
-        # Fill the NaN values in segger_cell_id with the already existing (assigned) values
-        # transcripts_df_filtered['segger_cell_id'] = transcripts_df_filtered['segger_cell_id'].fillna(transcripts_df_filtered['segger_cell_id_target'])
-
-        # Drop any temporary columns used during the merge
-        # transcripts_df_filtered = transcripts_df_filtered.drop(columns=['segger_cell_id_target'])
+        print(f"Merged segmentation results with transcripts in {elapsed_time:.2f} seconds.")
+
+    step_start_time = time()
+    if verbose:
+        print(f"Computing connected components for unassigned transcripts...")
+    # Load edge indices from saved Parquet
+    edge_index_dd = pd.read_parquet(edge_index_save_path)
+
+    # Step 2: Get unique transcript_ids from edge_index_dd and their positional indices
+    transcript_ids_in_edges = pd.concat([edge_index_dd["source"], edge_index_dd["target"]]).unique()
+
+    # Create a lookup table with unique indices
+    lookup_table = pd.Series(data=range(len(transcript_ids_in_edges)), index=transcript_ids_in_edges).to_dict()
+
+    # Map source and target to positional indices
+    edge_index_dd["index_source"] = edge_index_dd["source"].map(lookup_table)
+    edge_index_dd["index_target"] = edge_index_dd["target"].map(lookup_table)
+    # Step 3: Compute connected components for transcripts involved in edges
+    source_indices = np.asarray(edge_index_dd["index_source"])
+    target_indices = np.asarray(edge_index_dd["index_target"])
+    data_cp = np.ones(len(source_indices), dtype=np.float32)
+
+    # Create the sparse COO matrix
+    coo_cp_matrix = scipy_coo_matrix(
+        (data_cp, (source_indices, target_indices)),
+        shape=(len(transcript_ids_in_edges), len(transcript_ids_in_edges)),
+    )
+
+    # Use CuPy's connected components algorithm to compute components
+    n, comps = cc(coo_cp_matrix, directed=True, connection="strong")
+    if verbose:
+        elapsed_time = time() - step_start_time
+        print(f"Computed connected components for unassigned transcripts in {elapsed_time:.2f} seconds.")
+
+    step_start_time = time()
+    if verbose:
+        print(f"The rest...")
+    # # Step 4: Map back the component labels to the original transcript_ids
+
+    def _get_id():
+        """Generate a random Xenium-style ID."""
+        return "".join(np.random.choice(list("abcdefghijklmnopqrstuvwxyz"), 8)) + "-nx"
+
+    new_ids = np.array([_get_id() for _ in range(n)])
+    comp_labels = new_ids[comps]
+    comp_labels = pd.Series(comp_labels, index=transcript_ids_in_edges)
+    # Step 5: Handle only unassigned transcripts in transcripts_df_filtered
+    unassigned_mask = transcripts_df_filtered["segger_cell_id"].isna()
+
+    unassigned_transcripts_df = transcripts_df_filtered.loc[unassigned_mask, ["transcript_id"]]
+
+    # Step 6: Map component labels only to unassigned transcript_ids
+    new_segger_cell_ids = unassigned_transcripts_df["transcript_id"].map(comp_labels)
+
+    # Step 7: Create a DataFrame with updated 'segger_cell_id' for unassigned transcripts
+    unassigned_transcripts_df = unassigned_transcripts_df.assign(segger_cell_id=new_segger_cell_ids)
+
+    # Step 8: Merge this DataFrame back into the original to update only the unassigned segger_cell_id
+
+    # Merging the updates back to the original DataFrame
+    transcripts_df_filtered = transcripts_df_filtered.merge(
+        unassigned_transcripts_df[["transcript_id", "segger_cell_id"]],
+        on="transcript_id",
+        how="left",  # Perform a left join to only update the unassigned rows
+        suffixes=("", "_new"),  # Suffix for new column to avoid overwriting
+    )
+
+    # Step 9: Fill missing segger_cell_id values with the updated values from the merge
+    transcripts_df_filtered["segger_cell_id"] = transcripts_df_filtered["segger_cell_id"].fillna(
+        transcripts_df_filtered["segger_cell_id_new"]
+    )
 
-    if verbose:
-        elapsed_time = time() - step_start_time
-        print(f"Connected components computed in {elapsed_time:.2f} seconds.")
+    transcripts_df_filtered = transcripts_df_filtered.drop(columns=["segger_cell_id_new"])
+
+    if verbose:
+        elapsed_time = time() - step_start_time
+        print(f"The rest computed in {elapsed_time:.2f} seconds.")
 
     # Step 5: Save the merged results based on options
 
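For orientation, the pandas rewrite in this hunk comes down to two ideas: per transcript, keep the highest-scoring assignment while preferring bound (bound == 1) rows via combine_first, and then group transcripts into new cells by running connected components over the saved edge list. The toy version below is self-contained; the values, edge list, and generated ids are invented for illustration and are not segger's real data, and it uses connection="weak" so its one-directional toy edges merge into a single component (the code above passes "strong").

import numpy as np
import pandas as pd
from scipy.sparse import coo_matrix
from scipy.sparse.csgraph import connected_components

# 1) Max-score selection: best bound assignment per transcript, else best unbound one.
seg = pd.DataFrame(
    {
        "transcript_id": ["t1", "t1", "t2", "t2"],
        "segger_cell_id": ["cell_a", "cell_b", "cell_a", "cell_c"],
        "score": [0.9, 0.7, 0.4, 0.8],
        "bound": [0, 1, 0, 0],
    }
)
max_bound_idx = seg[seg["bound"] == 1].groupby("transcript_id")["score"].idxmax()
max_unbound_idx = seg[seg["bound"] == 0].groupby("transcript_id")["score"].idxmax()
final_idx = max_bound_idx.combine_first(max_unbound_idx)  # t1 -> its bound row, t2 -> best unbound row
best = seg.loc[final_idx.astype(int)]  # combine_first upcasts to float, so cast back for .loc

# 2) Connected components over the transcript-transcript edge list: linked transcripts
#    share a component label, which then becomes their new cell id.
edges = pd.DataFrame({"source": ["t1", "t2"], "target": ["t2", "t3"]})
ids = pd.concat([edges["source"], edges["target"]]).unique()
lookup = pd.Series(range(len(ids)), index=ids).to_dict()
row = np.asarray(edges["source"].map(lookup))
col = np.asarray(edges["target"].map(lookup))
adj = coo_matrix((np.ones(len(row), dtype=np.float32), (row, col)), shape=(len(ids), len(ids)))
n, comps = connected_components(adj, directed=True, connection="weak")

# 3) One fresh Xenium-style id per component, mirroring the _get_id() helper above.
rng = np.random.default_rng(0)
new_ids = np.array(["".join(rng.choice(list("abcdefghijklmnopqrstuvwxyz"), 8)) + "-nx" for _ in range(n)])
print(pd.Series(new_ids[comps], index=ids))  # every transcript in a component gets the same id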
@@ -670,14 +640,14 @@ def segment(
         step_start_time = time()
         print(f"Saving transcirpts.parquet...")
         transcripts_save_path = save_dir / "segger_transcripts.parquet"
-        transcripts_df_filtered = transcripts_df_filtered.repartition(npartitions=100)
+        # transcripts_df_filtered = transcripts_df_filtered.repartition(npartitions=100)
         transcripts_df_filtered.to_parquet(
            transcripts_save_path,
            engine="pyarrow",  # PyArrow is faster and recommended
            compression="snappy",  # Use snappy compression for speed
-           write_index=False,  # Skip writing index if not needed
-           append=False,  # Set to True if you're appending to an existing Parquet file
-           overwrite=True,
+           # write_index=False,  # Skip writing index if not needed
+           # append=False,  # Set to True if you're appending to an existing Parquet file
+           # overwrite=True,
         )  # Dask handles Parquet well
         if verbose:
             elapsed_time = time() - step_start_time
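Because transcripts_df_filtered is now a plain pandas DataFrame, the Dask-only keyword arguments (write_index, append, overwrite) are left commented out above and repartitioning no longer applies; pandas' own equivalent of skipping the index is index=False. A minimal sketch with an illustrative file name and made-up rows:

import pandas as pd

df = pd.DataFrame({"transcript_id": ["t1", "t2"], "segger_cell_id": ["abcdefgh-nx", "ijklmnop-nx"]})
# Plain pandas to_parquet: pyarrow engine, snappy compression, no index column written.
df.to_parquet("segger_transcripts.parquet", engine="pyarrow", compression="snappy", index=False)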
@@ -688,7 +658,7 @@ def segment(
         step_start_time = time()
         print(f"Saving anndata object...")
         anndata_save_path = save_dir / "segger_adata.h5ad"
-        segger_adata = create_anndata(transcripts_df_filtered.compute(), **anndata_kwargs)  # Compute for AnnData
+        segger_adata = create_anndata(transcripts_df_filtered, **anndata_kwargs)  # Compute for AnnData
         segger_adata.write(anndata_save_path)
         if verbose:
             elapsed_time = time() - step_start_time
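create_anndata is segger's own helper and its implementation is not part of this diff; the only change here is that it now receives the pandas DataFrame directly instead of a computed Dask one. As a rough mental model only (column names such as feature_name are assumptions, not taken from this commit), the step conceptually turns per-transcript cell assignments into a cell-by-gene AnnData:

import anndata as ad
import pandas as pd

transcripts = pd.DataFrame(
    {
        "feature_name": ["GeneA", "GeneA", "GeneB"],  # assumed column name
        "segger_cell_id": ["abcdefgh-nx", "abcdefgh-nx", "ijklmnop-nx"],
    }
)
# Hypothetical reduction: count transcripts per (cell, gene) and wrap the matrix as AnnData.
counts = pd.crosstab(transcripts["segger_cell_id"], transcripts["feature_name"])
adata = ad.AnnData(counts)  # obs = cells, var = genes
adata.write("segger_adata.h5ad")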
