Skip to content

Commit 6adf12e

Browse files
committed
Revert model changes and remove print statements outside of verbose
1 parent b516153 commit 6adf12e

File tree

2 files changed

+8
-9
lines changed

2 files changed

+8
-9
lines changed

src/segger/models/segger_model.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -38,20 +38,20 @@ def __init__(
3838

3939
# First GATv2Conv layer
4040
self.conv_first = GATv2Conv((-1, -1), hidden_channels, heads=heads, add_self_loops=False)
41-
self.lin_first = Linear(-1, hidden_channels * heads)
41+
#self.lin_first = Linear(-1, hidden_channels * heads)
4242

4343
# Middle GATv2Conv layers
4444
self.num_mid_layers = num_mid_layers
4545
if num_mid_layers > 0:
4646
self.conv_mid_layers = torch.nn.ModuleList()
47-
self.lin_mid_layers = torch.nn.ModuleList()
47+
#self.lin_mid_layers = torch.nn.ModuleList()
4848
for _ in range(num_mid_layers):
4949
self.conv_mid_layers.append(GATv2Conv((-1, -1), hidden_channels, heads=heads, add_self_loops=False))
50-
self.lin_mid_layers.append(Linear(-1, hidden_channels * heads))
50+
#self.lin_mid_layers.append(Linear(-1, hidden_channels * heads))
5151

5252
# Last GATv2Conv layer
5353
self.conv_last = GATv2Conv((-1, -1), out_channels, heads=heads, add_self_loops=False)
54-
self.lin_last = Linear(-1, out_channels * heads)
54+
#self.lin_last = Linear(-1, out_channels * heads)
5555

5656
def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
5757
"""
@@ -70,19 +70,19 @@ def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
7070
x = self.tx_embedding(((x.sum(1) * is_one_dim).int())) * is_one_dim + self.lin0(x.float()) * (1 - is_one_dim)
7171
# First layer
7272
x = x.relu()
73-
x = self.conv_first(x, edge_index) + self.lin_first(x)
73+
x = self.conv_first(x, edge_index) # + self.lin_first(x)
7474
x = x.relu()
7575

7676
# Middle layers
7777
if self.num_mid_layers > 0:
7878
for i in range(self.num_mid_layers):
7979
conv_mid = self.conv_mid_layers[i]
80-
lin_mid = self.lin_mid_layers[i]
81-
x = conv_mid(x, edge_index) + lin_mid(x)
80+
#lin_mid = self.lin_mid_layers[i]
81+
x = conv_mid(x, edge_index) # + lin_mid(x)
8282
x = x.relu()
8383

8484
# Last layer
85-
x = self.conv_last(x, edge_index) + self.lin_last(x)
85+
x = self.conv_last(x, edge_index) # + self.lin_last(x)
8686

8787
return x
8888

src/segger/prediction/predict_parquet.py

-1
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,6 @@ def _get_id():
304304
"""Generate a random Xenium-style ID."""
305305
return "".join(np.random.choice(list("abcdefghijklmnopqrstuvwxyz"), 8)) + "-nx"
306306

307-
print(gpu_id)
308307
with cp.cuda.Device(gpu_id):
309308
# Move the batch to the specified GPU
310309
batch = batch.to(f"cuda:{gpu_id}")

0 commit comments

Comments
 (0)