From ed3410791890f912cddac475130b05ec346671a9 Mon Sep 17 00:00:00 2001 From: Anthony Mahanna Date: Fri, 3 Jan 2025 17:52:18 -0500 Subject: [PATCH] new: support `use_gpu` algorithm parameter --- README.md | 7 +++++-- doc/algorithms/index.rst | 7 +++++-- doc/nx_arangodb.ipynb | 6 +----- nx_arangodb/interface.py | 4 +++- tests/test.py | 11 +++++++++-- 5 files changed, 23 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index b7cc7fd..b0aa5da 100644 --- a/README.md +++ b/README.md @@ -166,13 +166,16 @@ import nx_arangodb as nxadb G = nxadb.Graph(name="MyGraph") +# Option 1: Use Global Config nx.config.backends.arangodb.use_gpu = False - nx.pagerank(G) nx.betweenness_centrality(G) # ... - nx.config.backends.arangodb.use_gpu = True + +# Option 2: Use Local Config +nx.pagerank(G, use_gpu=False) +nx.betweenness_centrality(G, use_gpu=False) ```

diff --git a/doc/algorithms/index.rst b/doc/algorithms/index.rst index 9adf6f7..d041b60 100644 --- a/doc/algorithms/index.rst +++ b/doc/algorithms/index.rst @@ -43,14 +43,17 @@ You can also force-run algorithms on CPU even if ``nx-cugraph`` is installed: G = nxadb.Graph(name="MyGraph") + # Option 1: Use Global Config nx.config.backends.arangodb.use_gpu = False - nx.pagerank(G) nx.betweenness_centrality(G) # ... - nx.config.backends.arangodb.use_gpu = True + # Option 2: Use Local Config + nx.pagerank(G, use_gpu=False) + nx.betweenness_centrality(G, use_gpu=False) + .. image:: ../_static/dispatch.png :align: center diff --git a/doc/nx_arangodb.ipynb b/doc/nx_arangodb.ipynb index 003524e..71249b5 100644 --- a/doc/nx_arangodb.ipynb +++ b/doc/nx_arangodb.ipynb @@ -236,9 +236,7 @@ "outputs": [], "source": [ "# 5. Run an algorithm (CPU)\n", - "nx.config.backends.arangodb.use_gpu = False # Optional\n", - "\n", - "res = nx.pagerank(G)" + "res = nx.pagerank(G, use_gpu=False)" ] }, { @@ -357,8 +355,6 @@ "source": [ "# 4. Run an algorithm (GPU)\n", "# See *Package Installation* to install nx-cugraph ^\n", - "nx.config.backends.arangodb.use_gpu = True\n", - "\n", "res = nx.pagerank(G)" ] }, diff --git a/nx_arangodb/interface.py b/nx_arangodb/interface.py index c756a32..ef2d2a6 100644 --- a/nx_arangodb/interface.py +++ b/nx_arangodb/interface.py @@ -63,7 +63,9 @@ def _auto_func(func_name: str, /, *args: Any, **kwargs: Any) -> Any: dfunc = _registered_algorithms[func_name] backend_priority: list[str] = [] - if nxadb.convert.GPU_AVAILABLE and nx.config.backends.arangodb.use_gpu: + + use_gpu = bool(kwargs.pop("use_gpu", nx.config.backends.arangodb.use_gpu)) + if nxadb.convert.GPU_AVAILABLE and use_gpu: backend_priority.append("cugraph") for backend in backend_priority: diff --git a/tests/test.py b/tests/test.py index db1e7e8..7395771 100644 --- a/tests/test.py +++ b/tests/test.py @@ -447,7 +447,12 @@ def test_gpu_pagerank(graph_cls: type[nxadb.Graph]) -> None: assert gpu_cached_time < gpu_no_cache_time assert_pagerank(res_gpu_cached, res_gpu_no_cache, 10) - # 4. CPU + # 4. CPU (with use_gpu=False) + start_cpu_force_no_gpu = time.time() + res_cpu_force_no_gpu = nx.pagerank(graph, use_gpu=False) + cpu_force_no_gpu_time = time.time() - start_cpu_force_no_gpu + + # 5. CPU assert graph.nxcg_graph is not None graph.clear_nxcg_cache() assert graph.nxcg_graph is None @@ -456,12 +461,14 @@ def test_gpu_pagerank(graph_cls: type[nxadb.Graph]) -> None: start_cpu = time.time() res_cpu = nx.pagerank(graph) cpu_time = time.time() - start_cpu + assert_pagerank(res_cpu, res_cpu_force_no_gpu, 10) assert graph.nxcg_graph is None - m = "GPU execution should be faster than CPU execution" assert gpu_time < cpu_time, m + assert gpu_time < cpu_force_no_gpu_time, m assert gpu_no_cache_time < cpu_time, m + assert gpu_no_cache_time < cpu_force_no_gpu_time, m assert_pagerank(res_gpu_no_cache, res_cpu, 10)