diff --git a/benchmarks/benchmarks/api/bench_edge_subgraph.py b/benchmarks/benchmarks/api/bench_edge_subgraph.py
index 0975799ed845..f68395f4cf47 100644
--- a/benchmarks/benchmarks/api/bench_edge_subgraph.py
+++ b/benchmarks/benchmarks/api/bench_edge_subgraph.py
@@ -9,7 +9,6 @@
 from .. import utils
 
 
-@utils.skip_if_gpu()
 @utils.benchmark("time")
 @utils.parametrize("graph_name", ["livejournal", "reddit"])
 @utils.parametrize("format", ["coo"])
@@ -20,15 +19,16 @@ def track_time(graph_name, format, seed_egdes_num):
     graph = graph.to(device)
 
     seed_edges = np.random.randint(0, graph.num_edges(), seed_egdes_num)
+    seed_edges = torch.from_numpy(seed_edges).to(device)
 
     # dry run
     for i in range(3):
         dgl.edge_subgraph(graph, seed_edges)
 
     # timing
-
+    num_iters = 50
     with utils.Timer() as t:
-        for i in range(3):
+        for i in range(num_iters):
             dgl.edge_subgraph(graph, seed_edges)
 
-    return t.elapsed_secs / 3
+    return t.elapsed_secs / num_iters
diff --git a/benchmarks/benchmarks/api/bench_in_subgraph.py b/benchmarks/benchmarks/api/bench_in_subgraph.py
index cfc3df9bb4da..6e84498886da 100644
--- a/benchmarks/benchmarks/api/bench_in_subgraph.py
+++ b/benchmarks/benchmarks/api/bench_in_subgraph.py
@@ -19,14 +19,16 @@ def track_time(graph_name, format, seed_nodes_num):
     graph = graph.to(device)
 
     seed_nodes = np.random.randint(0, graph.num_nodes(), seed_nodes_num)
+    seed_nodes = torch.from_numpy(seed_nodes).to(device)
 
     # dry run
     for i in range(3):
         dgl.in_subgraph(graph, seed_nodes)
 
     # timing
+    num_iters = 50
     with utils.Timer() as t:
-        for i in range(3):
+        for i in range(num_iters):
             dgl.in_subgraph(graph, seed_nodes)
 
-    return t.elapsed_secs / 3
+    return t.elapsed_secs / num_iters
diff --git a/benchmarks/benchmarks/api/bench_node_subgraph.py b/benchmarks/benchmarks/api/bench_node_subgraph.py
index 9a87d3254ef2..d547c0202052 100644
--- a/benchmarks/benchmarks/api/bench_node_subgraph.py
+++ b/benchmarks/benchmarks/api/bench_node_subgraph.py
@@ -19,14 +19,16 @@ def track_time(graph_name, format, seed_nodes_num):
     graph = graph.to(device)
 
     seed_nodes = np.random.randint(0, graph.num_nodes(), seed_nodes_num)
+    seed_nodes = torch.from_numpy(seed_nodes).to(device)
 
     # dry run
     for i in range(3):
         dgl.node_subgraph(graph, seed_nodes)
 
     # timing
+    num_iters = 50
     with utils.Timer() as t:
-        for i in range(3):
+        for i in range(num_iters):
             dgl.node_subgraph(graph, seed_nodes)
 
-    return t.elapsed_secs / 3
+    return t.elapsed_secs / num_iters
diff --git a/benchmarks/benchmarks/api/bench_sample_neighbors.py b/benchmarks/benchmarks/api/bench_sample_neighbors.py
index 39610df4bfa4..bb3c28afc60e 100644
--- a/benchmarks/benchmarks/api/bench_sample_neighbors.py
+++ b/benchmarks/benchmarks/api/bench_sample_neighbors.py
@@ -17,10 +17,11 @@
 @utils.parametrize("fanout", [5, 20, 40])
 def track_time(graph_name, format, seed_nodes_num, fanout):
     device = utils.get_bench_device()
-    graph = utils.get_graph(graph_name, format)
+    graph = utils.get_graph(graph_name, format).to(device)
     edge_dir = "in"
 
     seed_nodes = np.random.randint(0, graph.num_nodes(), seed_nodes_num)
+    seed_nodes = torch.from_numpy(seed_nodes).to(device)
 
     # dry run
     for i in range(3):