diff --git a/_nx_arangodb/__init__.py b/_nx_arangodb/__init__.py index 616e961..9f6c4d4 100644 --- a/_nx_arangodb/__init__.py +++ b/_nx_arangodb/__init__.py @@ -74,15 +74,7 @@ def get_info(): for key in info_keys: del d[key] - d["default_config"] = { - "host": None, - "username": None, - "password": None, - "db_name": None, - "read_parallelism": None, - "read_batch_size": None, - "use_gpu": True, - } + d["default_config"] = {"use_gpu": True} return d diff --git a/nx_arangodb/classes/dict/adj.py b/nx_arangodb/classes/dict/adj.py index b1b2ddc..2549024 100644 --- a/nx_arangodb/classes/dict/adj.py +++ b/nx_arangodb/classes/dict/adj.py @@ -105,6 +105,8 @@ def adjlist_outer_dict_factory( db: StandardDatabase, graph: Graph, default_node_type: str, + read_parallelism: int, + read_batch_size: int, edge_type_key: str, edge_type_func: Callable[[str, str], str], graph_type: str, @@ -115,6 +117,8 @@ def adjlist_outer_dict_factory( db, graph, default_node_type, + read_parallelism, + read_batch_size, edge_type_key, edge_type_func, graph_type, @@ -1467,6 +1471,8 @@ def __init__( db: StandardDatabase, graph: Graph, default_node_type: str, + read_parallelism: int, + read_batch_size: int, edge_type_key: str, edge_type_func: Callable[[str, str], str], graph_type: str, @@ -1489,6 +1495,8 @@ def __init__( self.edge_type_key = edge_type_key self.edge_type_func = edge_type_func self.default_node_type = default_node_type + self.read_parallelism = read_parallelism + self.read_batch_size = read_batch_size self.adjlist_inner_dict_factory = adjlist_inner_dict_factory( db, graph, @@ -1853,6 +1861,8 @@ def _fetch_all(self) -> None: is_directed=True, is_multigraph=self.is_multigraph, symmetrize_edges_if_directed=self.symmetrize_edges_if_directed, + read_parallelism=self.read_parallelism, + read_batch_size=self.read_batch_size, ) # Even if the Graph is undirected, diff --git a/nx_arangodb/classes/dict/node.py b/nx_arangodb/classes/dict/node.py index 872b158..b68697e 100644 --- a/nx_arangodb/classes/dict/node.py +++ b/nx_arangodb/classes/dict/node.py @@ -40,10 +40,20 @@ def node_dict_factory( - db: StandardDatabase, graph: Graph, default_node_type: str + db: StandardDatabase, + graph: Graph, + default_node_type: str, + read_parallelism: int, + read_batch_size: int, ) -> Callable[..., NodeDict]: """Factory function for creating a NodeDict.""" - return lambda: NodeDict(db, graph, default_node_type) + return lambda: NodeDict( + db, + graph, + default_node_type, + read_parallelism, + read_batch_size, + ) def node_attr_dict_factory( @@ -262,6 +272,8 @@ def __init__( db: StandardDatabase, graph: Graph, default_node_type: str, + read_parallelism: int, + read_batch_size: int, *args: Any, **kwargs: Any, ): @@ -271,6 +283,9 @@ def __init__( self.db = db self.graph = graph self.default_node_type = default_node_type + self.read_parallelism = read_parallelism + self.read_batch_size = read_batch_size + self.node_attr_dict_factory = node_attr_dict_factory(self.db, self.graph) self.FETCHED_ALL_DATA = False @@ -472,6 +487,8 @@ def _fetch_all(self): is_directed=False, # not used is_multigraph=False, # not used symmetrize_edges_if_directed=False, # not used + read_parallelism=self.read_parallelism, + read_batch_size=self.read_batch_size, ) for node_id, node_data in node_dict.items(): diff --git a/nx_arangodb/classes/function.py b/nx_arangodb/classes/function.py index 491c0cd..8943a73 100644 --- a/nx_arangodb/classes/function.py +++ b/nx_arangodb/classes/function.py @@ -47,6 +47,8 @@ def get_arangodb_graph( is_directed: bool, is_multigraph: bool, symmetrize_edges_if_directed: bool, + read_parallelism: int, + read_batch_size: int, ) -> Tuple[ NodeDict, GraphAdjDict | DiGraphAdjDict | MultiGraphAdjDict | MultiDiGraphAdjDict, @@ -142,11 +144,9 @@ def get_arangodb_graph( if not load_adj_dict and not load_coo: metagraph["edgeCollections"] = {} - config = nx.config.backends.arangodb - assert config.db_name - assert config.host - assert config.username - assert config.password + hosts = adb_graph._conn._hosts + db_name = adb_graph._conn._db_name + username, password = adb_graph._conn._auth ( node_dict, @@ -157,11 +157,11 @@ def get_arangodb_graph( vertex_ids_to_index, edge_values, ) = NetworkXLoader.load_into_networkx( - config.db_name, + database=db_name, metagraph=metagraph, - hosts=[config.host], - username=config.username, - password=config.password, + hosts=hosts, + username=username, + password=password, load_adj_dict=load_adj_dict, load_coo=load_coo, load_all_vertex_attributes=load_all_vertex_attributes, @@ -169,8 +169,8 @@ def get_arangodb_graph( is_directed=is_directed, is_multigraph=is_multigraph, symmetrize_edges_if_directed=symmetrize_edges_if_directed, - parallelism=config.read_parallelism, - batch_size=config.read_batch_size, + parallelism=read_parallelism, + batch_size=read_batch_size, ) return ( diff --git a/nx_arangodb/classes/graph.py b/nx_arangodb/classes/graph.py index 336cc3b..ed7cbda 100644 --- a/nx_arangodb/classes/graph.py +++ b/nx_arangodb/classes/graph.py @@ -214,11 +214,13 @@ def __init__( self.use_nxcg_cache = True self.nxcg_graph = None + self.edge_type_key = edge_type_key + self.read_parallelism = read_parallelism + self.read_batch_size = read_batch_size + # Does not apply to undirected graphs self.symmetrize_edges = symmetrize_edges - self.edge_type_key = edge_type_key - # TODO: Consider this # if not self.__graph_name: # if incoming_graph_data is not None: @@ -227,8 +229,8 @@ def __init__( self._loaded_incoming_graph_data = False if self.graph_exists_in_db: - self._set_factory_methods() - self.__set_arangodb_backend_config(read_parallelism, read_batch_size) + self._set_factory_methods(read_parallelism, read_batch_size) + self.__set_arangodb_backend_config() if overwrite_graph: logger.info("Overwriting graph...") @@ -284,7 +286,7 @@ def __init__( # Init helper methods # ####################### - def _set_factory_methods(self) -> None: + def _set_factory_methods(self, read_parallelism: int, read_batch_size: int) -> None: """Set the factory methods for the graph, _node, and _adj dictionaries. The ArangoDB CRUD operations are handled by the modified dictionaries. @@ -299,39 +301,29 @@ def _set_factory_methods(self) -> None: """ base_args = (self.db, self.adb_graph) + node_args = (*base_args, self.default_node_type) - adj_args = ( - *node_args, - self.edge_type_key, - self.edge_type_func, - self.__class__.__name__, + node_args_with_read = (*node_args, read_parallelism, read_batch_size) + + adj_args = (self.edge_type_key, self.edge_type_func, self.__class__.__name__) + adj_inner_args = (*node_args, *adj_args) + adj_outer_args = ( + *node_args_with_read, + *adj_args, + self.symmetrize_edges, ) self.graph_attr_dict_factory = graph_dict_factory(*base_args) - self.node_dict_factory = node_dict_factory(*node_args) + self.node_dict_factory = node_dict_factory(*node_args_with_read) self.node_attr_dict_factory = node_attr_dict_factory(*base_args) self.edge_attr_dict_factory = edge_attr_dict_factory(*base_args) - self.adjlist_inner_dict_factory = adjlist_inner_dict_factory(*adj_args) - self.adjlist_outer_dict_factory = adjlist_outer_dict_factory( - *adj_args, self.symmetrize_edges - ) - - def __set_arangodb_backend_config( - self, read_parallelism: int, read_batch_size: int - ) -> None: - if not all([self._host, self._username, self._password, self._db_name]): - m = "Must set all environment variables to use the ArangoDB Backend with an existing graph" # noqa: E501 - raise OSError(m) + self.adjlist_inner_dict_factory = adjlist_inner_dict_factory(*adj_inner_args) + self.adjlist_outer_dict_factory = adjlist_outer_dict_factory(*adj_outer_args) + def __set_arangodb_backend_config(self) -> None: config = nx.config.backends.arangodb - config.host = self._host - config.username = self._username - config.password = self._password - config.db_name = self._db_name - config.read_parallelism = read_parallelism - config.read_batch_size = read_batch_size config.use_gpu = True # Only used by default if nx-cugraph is available def __set_edge_collections_attributes(self, attributes: set[str] | None) -> None: @@ -345,7 +337,7 @@ def __set_edge_collections_attributes(self, attributes: set[str] | None) -> None self._edge_collections_attributes.add("_id") def __set_db(self, db: Any = None) -> None: - self._host = os.getenv("DATABASE_HOST") + self._hosts = os.getenv("DATABASE_HOST", "").split(",") self._username = os.getenv("DATABASE_USERNAME") self._password = os.getenv("DATABASE_PASSWORD") self._db_name = os.getenv("DATABASE_NAME") @@ -355,17 +347,20 @@ def __set_db(self, db: Any = None) -> None: m = "arango.database.StandardDatabase" raise TypeError(m) - db.version() + db.version() # make sure the connection is valid self.__db = db + self._db_name = db.name + self._hosts = db._conn._hosts + self._username, self._password = db._conn._auth return - if not all([self._host, self._username, self._password, self._db_name]): + if not all([self._hosts, self._username, self._password, self._db_name]): m = "Database environment variables not set. Can't connect to the database" logger.warning(m) self.__db = None return - self.__db = ArangoClient(hosts=self._host, request_timeout=None).db( + self.__db = ArangoClient(hosts=self._hosts, request_timeout=None).db( self._db_name, self._username, self._password, verify=True ) diff --git a/nx_arangodb/classes/multigraph.py b/nx_arangodb/classes/multigraph.py index c494d34..7d7db59 100644 --- a/nx_arangodb/classes/multigraph.py +++ b/nx_arangodb/classes/multigraph.py @@ -229,8 +229,8 @@ def __init__( # Init helper methods # ####################### - def _set_factory_methods(self) -> None: - super()._set_factory_methods() + def _set_factory_methods(self, read_parallelism: int, read_batch_size: int) -> None: + super()._set_factory_methods(read_parallelism, read_batch_size) self.edge_key_dict_factory = edge_key_dict_factory( self.db, self.adb_graph, diff --git a/nx_arangodb/convert.py b/nx_arangodb/convert.py index 8eda47b..33df773 100644 --- a/nx_arangodb/convert.py +++ b/nx_arangodb/convert.py @@ -256,6 +256,8 @@ def nxadb_to_nx(G: nxadb.Graph) -> nx.Graph: is_directed=G.is_directed(), is_multigraph=G.is_multigraph(), symmetrize_edges_if_directed=G.symmetrize_edges if G.is_directed() else False, + read_parallelism=G.read_parallelism, + read_batch_size=G.read_batch_size, ) logger.info(f"Graph '{G.adb_graph.name}' load took {time.time() - start_time}s") @@ -337,6 +339,8 @@ def nxadb_to_nxcg(G: nxadb.Graph, as_directed: bool = False) -> nxcg.Graph: symmetrize_edges_if_directed=( G.symmetrize_edges if G.is_directed() else False ), + read_parallelism=G.read_parallelism, + read_batch_size=G.read_batch_size, ) logger.info(f"Graph '{G.adb_graph.name}' load took {time.time() - start_time}s")