Skip to content

Commit ad38fc6

Browse files
committed
new: fully support parameterized db object
1 parent 9f59085 commit ad38fc6

File tree

7 files changed

+74
-56
lines changed

7 files changed

+74
-56
lines changed

_nx_arangodb/__init__.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -74,15 +74,7 @@ def get_info():
7474
for key in info_keys:
7575
del d[key]
7676

77-
d["default_config"] = {
78-
"host": None,
79-
"username": None,
80-
"password": None,
81-
"db_name": None,
82-
"read_parallelism": None,
83-
"read_batch_size": None,
84-
"use_gpu": True,
85-
}
77+
d["default_config"] = {"use_gpu": True}
8678

8779
return d
8880

nx_arangodb/classes/dict/adj.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,8 @@ def adjlist_outer_dict_factory(
105105
db: StandardDatabase,
106106
graph: Graph,
107107
default_node_type: str,
108+
read_parallelism: int,
109+
read_batch_size: int,
108110
edge_type_key: str,
109111
edge_type_func: Callable[[str, str], str],
110112
graph_type: str,
@@ -115,6 +117,8 @@ def adjlist_outer_dict_factory(
115117
db,
116118
graph,
117119
default_node_type,
120+
read_parallelism,
121+
read_batch_size,
118122
edge_type_key,
119123
edge_type_func,
120124
graph_type,
@@ -1467,6 +1471,8 @@ def __init__(
14671471
db: StandardDatabase,
14681472
graph: Graph,
14691473
default_node_type: str,
1474+
read_parallelism: int,
1475+
read_batch_size: int,
14701476
edge_type_key: str,
14711477
edge_type_func: Callable[[str, str], str],
14721478
graph_type: str,
@@ -1489,6 +1495,8 @@ def __init__(
14891495
self.edge_type_key = edge_type_key
14901496
self.edge_type_func = edge_type_func
14911497
self.default_node_type = default_node_type
1498+
self.read_parallelism = read_parallelism
1499+
self.read_batch_size = read_batch_size
14921500
self.adjlist_inner_dict_factory = adjlist_inner_dict_factory(
14931501
db,
14941502
graph,
@@ -1853,6 +1861,8 @@ def _fetch_all(self) -> None:
18531861
is_directed=True,
18541862
is_multigraph=self.is_multigraph,
18551863
symmetrize_edges_if_directed=self.symmetrize_edges_if_directed,
1864+
read_parallelism=self.read_parallelism,
1865+
read_batch_size=self.read_batch_size,
18561866
)
18571867

18581868
# Even if the Graph is undirected,

nx_arangodb/classes/dict/node.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,20 @@
4040

4141

4242
def node_dict_factory(
43-
db: StandardDatabase, graph: Graph, default_node_type: str
43+
db: StandardDatabase,
44+
graph: Graph,
45+
default_node_type: str,
46+
read_parallelism: int,
47+
read_batch_size: int,
4448
) -> Callable[..., NodeDict]:
4549
"""Factory function for creating a NodeDict."""
46-
return lambda: NodeDict(db, graph, default_node_type)
50+
return lambda: NodeDict(
51+
db,
52+
graph,
53+
default_node_type,
54+
read_parallelism,
55+
read_batch_size,
56+
)
4757

4858

4959
def node_attr_dict_factory(
@@ -262,6 +272,8 @@ def __init__(
262272
db: StandardDatabase,
263273
graph: Graph,
264274
default_node_type: str,
275+
read_parallelism: int,
276+
read_batch_size: int,
265277
*args: Any,
266278
**kwargs: Any,
267279
):
@@ -271,6 +283,9 @@ def __init__(
271283
self.db = db
272284
self.graph = graph
273285
self.default_node_type = default_node_type
286+
self.read_parallelism = read_parallelism
287+
self.read_batch_size = read_batch_size
288+
274289
self.node_attr_dict_factory = node_attr_dict_factory(self.db, self.graph)
275290

276291
self.FETCHED_ALL_DATA = False
@@ -472,6 +487,8 @@ def _fetch_all(self):
472487
is_directed=False, # not used
473488
is_multigraph=False, # not used
474489
symmetrize_edges_if_directed=False, # not used
490+
read_parallelism=self.read_parallelism,
491+
read_batch_size=self.read_batch_size,
475492
)
476493

477494
for node_id, node_data in node_dict.items():

nx_arangodb/classes/function.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ def get_arangodb_graph(
4747
is_directed: bool,
4848
is_multigraph: bool,
4949
symmetrize_edges_if_directed: bool,
50+
read_parallelism: int,
51+
read_batch_size: int,
5052
) -> Tuple[
5153
NodeDict,
5254
GraphAdjDict | DiGraphAdjDict | MultiGraphAdjDict | MultiDiGraphAdjDict,
@@ -142,11 +144,9 @@ def get_arangodb_graph(
142144
if not load_adj_dict and not load_coo:
143145
metagraph["edgeCollections"] = {}
144146

145-
config = nx.config.backends.arangodb
146-
assert config.db_name
147-
assert config.host
148-
assert config.username
149-
assert config.password
147+
hosts = adb_graph._conn._hosts
148+
db_name = adb_graph._conn._db_name
149+
username, password = adb_graph._conn._auth
150150

151151
(
152152
node_dict,
@@ -157,20 +157,20 @@ def get_arangodb_graph(
157157
vertex_ids_to_index,
158158
edge_values,
159159
) = NetworkXLoader.load_into_networkx(
160-
config.db_name,
160+
database=db_name,
161161
metagraph=metagraph,
162-
hosts=[config.host],
163-
username=config.username,
164-
password=config.password,
162+
hosts=hosts,
163+
username=username,
164+
password=password,
165165
load_adj_dict=load_adj_dict,
166166
load_coo=load_coo,
167167
load_all_vertex_attributes=load_all_vertex_attributes,
168168
load_all_edge_attributes=load_all_edge_attributes,
169169
is_directed=is_directed,
170170
is_multigraph=is_multigraph,
171171
symmetrize_edges_if_directed=symmetrize_edges_if_directed,
172-
parallelism=config.read_parallelism,
173-
batch_size=config.read_batch_size,
172+
parallelism=read_parallelism,
173+
batch_size=read_batch_size,
174174
)
175175

176176
return (

nx_arangodb/classes/graph.py

Lines changed: 27 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -214,11 +214,13 @@ def __init__(
214214
self.use_nxcg_cache = True
215215
self.nxcg_graph = None
216216

217+
self.edge_type_key = edge_type_key
218+
self.read_parallelism = read_parallelism
219+
self.read_batch_size = read_batch_size
220+
217221
# Does not apply to undirected graphs
218222
self.symmetrize_edges = symmetrize_edges
219223

220-
self.edge_type_key = edge_type_key
221-
222224
# TODO: Consider this
223225
# if not self.__graph_name:
224226
# if incoming_graph_data is not None:
@@ -227,8 +229,8 @@ def __init__(
227229

228230
self._loaded_incoming_graph_data = False
229231
if self.graph_exists_in_db:
230-
self._set_factory_methods()
231-
self.__set_arangodb_backend_config(read_parallelism, read_batch_size)
232+
self._set_factory_methods(read_parallelism, read_batch_size)
233+
self.__set_arangodb_backend_config()
232234

233235
if overwrite_graph:
234236
logger.info("Overwriting graph...")
@@ -284,7 +286,7 @@ def __init__(
284286
# Init helper methods #
285287
#######################
286288

287-
def _set_factory_methods(self) -> None:
289+
def _set_factory_methods(self, read_parallelism: int, read_batch_size: int) -> None:
288290
"""Set the factory methods for the graph, _node, and _adj dictionaries.
289291
290292
The ArangoDB CRUD operations are handled by the modified dictionaries.
@@ -299,39 +301,29 @@ def _set_factory_methods(self) -> None:
299301
"""
300302

301303
base_args = (self.db, self.adb_graph)
304+
302305
node_args = (*base_args, self.default_node_type)
303-
adj_args = (
304-
*node_args,
305-
self.edge_type_key,
306-
self.edge_type_func,
307-
self.__class__.__name__,
306+
node_args_with_read = (*node_args, read_parallelism, read_batch_size)
307+
308+
adj_args = (self.edge_type_key, self.edge_type_func, self.__class__.__name__)
309+
adj_inner_args = (*node_args, *adj_args)
310+
adj_outer_args = (
311+
*node_args_with_read,
312+
*adj_args,
313+
self.symmetrize_edges,
308314
)
309315

310316
self.graph_attr_dict_factory = graph_dict_factory(*base_args)
311317

312-
self.node_dict_factory = node_dict_factory(*node_args)
318+
self.node_dict_factory = node_dict_factory(*node_args_with_read)
313319
self.node_attr_dict_factory = node_attr_dict_factory(*base_args)
314320

315321
self.edge_attr_dict_factory = edge_attr_dict_factory(*base_args)
316-
self.adjlist_inner_dict_factory = adjlist_inner_dict_factory(*adj_args)
317-
self.adjlist_outer_dict_factory = adjlist_outer_dict_factory(
318-
*adj_args, self.symmetrize_edges
319-
)
320-
321-
def __set_arangodb_backend_config(
322-
self, read_parallelism: int, read_batch_size: int
323-
) -> None:
324-
if not all([self._host, self._username, self._password, self._db_name]):
325-
m = "Must set all environment variables to use the ArangoDB Backend with an existing graph" # noqa: E501
326-
raise OSError(m)
322+
self.adjlist_inner_dict_factory = adjlist_inner_dict_factory(*adj_inner_args)
323+
self.adjlist_outer_dict_factory = adjlist_outer_dict_factory(*adj_outer_args)
327324

325+
def __set_arangodb_backend_config(self) -> None:
328326
config = nx.config.backends.arangodb
329-
config.host = self._host
330-
config.username = self._username
331-
config.password = self._password
332-
config.db_name = self._db_name
333-
config.read_parallelism = read_parallelism
334-
config.read_batch_size = read_batch_size
335327
config.use_gpu = True # Only used by default if nx-cugraph is available
336328

337329
def __set_edge_collections_attributes(self, attributes: set[str] | None) -> None:
@@ -345,7 +337,7 @@ def __set_edge_collections_attributes(self, attributes: set[str] | None) -> None
345337
self._edge_collections_attributes.add("_id")
346338

347339
def __set_db(self, db: Any = None) -> None:
348-
self._host = os.getenv("DATABASE_HOST")
340+
self._hosts = os.getenv("DATABASE_HOST", "").split(",")
349341
self._username = os.getenv("DATABASE_USERNAME")
350342
self._password = os.getenv("DATABASE_PASSWORD")
351343
self._db_name = os.getenv("DATABASE_NAME")
@@ -355,17 +347,20 @@ def __set_db(self, db: Any = None) -> None:
355347
m = "arango.database.StandardDatabase"
356348
raise TypeError(m)
357349

358-
db.version()
350+
db.version() # make sure the connection is valid
359351
self.__db = db
352+
self._db_name = db.name
353+
self._hosts = db._conn._hosts
354+
self._username, self._password = db._conn._auth
360355
return
361356

362-
if not all([self._host, self._username, self._password, self._db_name]):
357+
if not all([self._hosts, self._username, self._password, self._db_name]):
363358
m = "Database environment variables not set. Can't connect to the database"
364359
logger.warning(m)
365360
self.__db = None
366361
return
367362

368-
self.__db = ArangoClient(hosts=self._host, request_timeout=None).db(
363+
self.__db = ArangoClient(hosts=self._hosts, request_timeout=None).db(
369364
self._db_name, self._username, self._password, verify=True
370365
)
371366

nx_arangodb/classes/multigraph.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -229,8 +229,8 @@ def __init__(
229229
# Init helper methods #
230230
#######################
231231

232-
def _set_factory_methods(self) -> None:
233-
super()._set_factory_methods()
232+
def _set_factory_methods(self, read_parallelism: int, read_batch_size: int) -> None:
233+
super()._set_factory_methods(read_parallelism, read_batch_size)
234234
self.edge_key_dict_factory = edge_key_dict_factory(
235235
self.db,
236236
self.adb_graph,

nx_arangodb/convert.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,8 @@ def nxadb_to_nx(G: nxadb.Graph) -> nx.Graph:
256256
is_directed=G.is_directed(),
257257
is_multigraph=G.is_multigraph(),
258258
symmetrize_edges_if_directed=G.symmetrize_edges if G.is_directed() else False,
259+
read_parallelism=G.read_parallelism,
260+
read_batch_size=G.read_batch_size,
259261
)
260262

261263
logger.info(f"Graph '{G.adb_graph.name}' load took {time.time() - start_time}s")
@@ -337,6 +339,8 @@ def nxadb_to_nxcg(G: nxadb.Graph, as_directed: bool = False) -> nxcg.Graph:
337339
symmetrize_edges_if_directed=(
338340
G.symmetrize_edges if G.is_directed() else False
339341
),
342+
read_parallelism=G.read_parallelism,
343+
read_batch_size=G.read_batch_size,
340344
)
341345

342346
logger.info(f"Graph '{G.adb_graph.name}' load took {time.time() - start_time}s")

0 commit comments

Comments
 (0)