Skip to content

Commit

Permalink
[CI] fix device configure when run on GPU (dmlc#5154)
Browse files Browse the repository at this point in the history
  • Loading branch information
Rhett-Ying authored Jan 13, 2023
1 parent cdfd1e3 commit f65bd2d
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 9 deletions.
8 changes: 4 additions & 4 deletions benchmarks/benchmarks/api/bench_edge_subgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from .. import utils


@utils.skip_if_gpu()
@utils.benchmark("time")
@utils.parametrize("graph_name", ["livejournal", "reddit"])
@utils.parametrize("format", ["coo"])
Expand All @@ -20,15 +19,16 @@ def track_time(graph_name, format, seed_egdes_num):
graph = graph.to(device)

seed_edges = np.random.randint(0, graph.num_edges(), seed_egdes_num)
seed_edges = torch.from_numpy(seed_edges).to(device)

# dry run
for i in range(3):
dgl.edge_subgraph(graph, seed_edges)

# timing

num_iters = 50
with utils.Timer() as t:
for i in range(3):
for i in range(num_iters):
dgl.edge_subgraph(graph, seed_edges)

return t.elapsed_secs / 3
return t.elapsed_secs / num_iters
6 changes: 4 additions & 2 deletions benchmarks/benchmarks/api/bench_in_subgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,16 @@ def track_time(graph_name, format, seed_nodes_num):
graph = graph.to(device)

seed_nodes = np.random.randint(0, graph.num_nodes(), seed_nodes_num)
seed_nodes = torch.from_numpy(seed_nodes).to(device)

# dry run
for i in range(3):
dgl.in_subgraph(graph, seed_nodes)

# timing
num_iters = 50
with utils.Timer() as t:
for i in range(3):
for i in range(num_iters):
dgl.in_subgraph(graph, seed_nodes)

return t.elapsed_secs / 3
return t.elapsed_secs / num_iters
6 changes: 4 additions & 2 deletions benchmarks/benchmarks/api/bench_node_subgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,16 @@ def track_time(graph_name, format, seed_nodes_num):
graph = graph.to(device)

seed_nodes = np.random.randint(0, graph.num_nodes(), seed_nodes_num)
seed_nodes = torch.from_numpy(seed_nodes).to(device)

# dry run
for i in range(3):
dgl.node_subgraph(graph, seed_nodes)

# timing
num_iters = 50
with utils.Timer() as t:
for i in range(3):
for i in range(num_iters):
dgl.node_subgraph(graph, seed_nodes)

return t.elapsed_secs / 3
return t.elapsed_secs / num_iters
3 changes: 2 additions & 1 deletion benchmarks/benchmarks/api/bench_sample_neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@
@utils.parametrize("fanout", [5, 20, 40])
def track_time(graph_name, format, seed_nodes_num, fanout):
device = utils.get_bench_device()
graph = utils.get_graph(graph_name, format)
graph = utils.get_graph(graph_name, format).to(device)

edge_dir = "in"
seed_nodes = np.random.randint(0, graph.num_nodes(), seed_nodes_num)
seed_nodes = torch.from_numpy(seed_nodes).to(device)

# dry run
for i in range(3):
Expand Down

0 comments on commit f65bd2d

Please sign in to comment.