Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed Oct 18, 2023
1 parent 1b60dc0 commit 8fcc86c
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 18 deletions.
15 changes: 6 additions & 9 deletions benchmark/sampler/hetero_neighbor.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def test_hetero_neighbor(dataset, **kwargs):

colptr_dict, row_dict = dataset
num_nodes_dict = {k[-1]: v.size(0) - 1 for k, v in colptr_dict.items()}
num_edges_dict = {k[-1]: v.size(0) for k, v in row_dict.items()}
num_edges_dict = {k: v.size(0) for k, v in row_dict.items()}

if args.temporal:
# generate random timestamps
Expand All @@ -62,15 +62,12 @@ def test_hetero_neighbor(dataset, **kwargs):
else:
node_time_dict = None

edge_weight_dict = None
if args.biased:
ones = torch.ones(num_edges_dict['paper']).view(-1, 1)
zeros = torch.zeros(num_edges_dict['paper']).view(-1, 1)
edge_weights_dict = {
k: torch.cat([ones, zeros], -1).view(-1)
for k in row_dict.keys()
edge_weight_dict = {
edge_type: torch.rand(num_edges)
for edge_type, num_edges in num_edges_dict.items()
}
else:
edge_weights_dict = None

if args.shuffle:
node_perm = torch.randperm(num_nodes_dict['paper'])
Expand Down Expand Up @@ -98,7 +95,7 @@ def test_hetero_neighbor(dataset, **kwargs):
num_neighbors_dict,
node_time_dict,
seed_time_dict=None,
edge_weight_dict=edge_weights_dict,
edge_weight_dict=edge_weight_dict,
csc=True,
replace=False,
directed=True,
Expand Down
15 changes: 6 additions & 9 deletions benchmark/sampler/neighbor.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ def test_neighbor(dataset, **kwargs):
raise ValueError(
"Temporal sampling needs to create disjoint subgraphs")

(rowptr, col) = dataset
num_nodes = dataset[0].size(0) - 1
num_edges = col.size(0)
rowptr, col = dataset
num_nodes = rowptr.numel() - 1
num_edges = col.numel()

if 'dgl' in args.libraries:
import dgl
Expand All @@ -61,12 +61,9 @@ def test_neighbor(dataset, **kwargs):
else:
node_time = None

edge_weight = None
if args.biased:
ones = torch.ones(num_edges).view(-1, 1)
zeros = torch.zeros(num_edges).view(-1, 1)
edge_weights = torch.cat([ones, zeros], -1).view(-1)
else:
edge_weights = None
edge_weight = torch.rand(num_edges)

if args.shuffle:
node_perm = torch.randperm(num_nodes)
Expand All @@ -91,7 +88,7 @@ def test_neighbor(dataset, **kwargs):
num_neighbors,
time=node_time,
seed_time=None,
edge_weight=edge_weights,
edge_weight=edge_weight,
replace=args.replace,
directed=args.directed,
disjoint=args.disjoint,
Expand Down

0 comments on commit 8fcc86c

Please sign in to comment.