Skip to content

Commit

Permalink
new: fully support parameterized db object
Browse files Browse the repository at this point in the history
  • Loading branch information
aMahanna committed Jan 3, 2025
1 parent 9f59085 commit ad38fc6
Show file tree
Hide file tree
Showing 7 changed files with 74 additions and 56 deletions.
10 changes: 1 addition & 9 deletions _nx_arangodb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,15 +74,7 @@ def get_info():
for key in info_keys:
del d[key]

d["default_config"] = {
"host": None,
"username": None,
"password": None,
"db_name": None,
"read_parallelism": None,
"read_batch_size": None,
"use_gpu": True,
}
d["default_config"] = {"use_gpu": True}

return d

Expand Down
10 changes: 10 additions & 0 deletions nx_arangodb/classes/dict/adj.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ def adjlist_outer_dict_factory(
db: StandardDatabase,
graph: Graph,
default_node_type: str,
read_parallelism: int,
read_batch_size: int,
edge_type_key: str,
edge_type_func: Callable[[str, str], str],
graph_type: str,
Expand All @@ -115,6 +117,8 @@ def adjlist_outer_dict_factory(
db,
graph,
default_node_type,
read_parallelism,
read_batch_size,
edge_type_key,
edge_type_func,
graph_type,
Expand Down Expand Up @@ -1467,6 +1471,8 @@ def __init__(
db: StandardDatabase,
graph: Graph,
default_node_type: str,
read_parallelism: int,
read_batch_size: int,
edge_type_key: str,
edge_type_func: Callable[[str, str], str],
graph_type: str,
Expand All @@ -1489,6 +1495,8 @@ def __init__(
self.edge_type_key = edge_type_key
self.edge_type_func = edge_type_func
self.default_node_type = default_node_type
self.read_parallelism = read_parallelism
self.read_batch_size = read_batch_size
self.adjlist_inner_dict_factory = adjlist_inner_dict_factory(
db,
graph,
Expand Down Expand Up @@ -1853,6 +1861,8 @@ def _fetch_all(self) -> None:
is_directed=True,
is_multigraph=self.is_multigraph,
symmetrize_edges_if_directed=self.symmetrize_edges_if_directed,
read_parallelism=self.read_parallelism,
read_batch_size=self.read_batch_size,
)

# Even if the Graph is undirected,
Expand Down
21 changes: 19 additions & 2 deletions nx_arangodb/classes/dict/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,20 @@


def node_dict_factory(
db: StandardDatabase, graph: Graph, default_node_type: str
db: StandardDatabase,
graph: Graph,
default_node_type: str,
read_parallelism: int,
read_batch_size: int,
) -> Callable[..., NodeDict]:
"""Factory function for creating a NodeDict."""
return lambda: NodeDict(db, graph, default_node_type)
return lambda: NodeDict(
db,
graph,
default_node_type,
read_parallelism,
read_batch_size,
)


def node_attr_dict_factory(
Expand Down Expand Up @@ -262,6 +272,8 @@ def __init__(
db: StandardDatabase,
graph: Graph,
default_node_type: str,
read_parallelism: int,
read_batch_size: int,
*args: Any,
**kwargs: Any,
):
Expand All @@ -271,6 +283,9 @@ def __init__(
self.db = db
self.graph = graph
self.default_node_type = default_node_type
self.read_parallelism = read_parallelism
self.read_batch_size = read_batch_size

self.node_attr_dict_factory = node_attr_dict_factory(self.db, self.graph)

self.FETCHED_ALL_DATA = False
Expand Down Expand Up @@ -472,6 +487,8 @@ def _fetch_all(self):
is_directed=False, # not used
is_multigraph=False, # not used
symmetrize_edges_if_directed=False, # not used
read_parallelism=self.read_parallelism,
read_batch_size=self.read_batch_size,
)

for node_id, node_data in node_dict.items():
Expand Down
22 changes: 11 additions & 11 deletions nx_arangodb/classes/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ def get_arangodb_graph(
is_directed: bool,
is_multigraph: bool,
symmetrize_edges_if_directed: bool,
read_parallelism: int,
read_batch_size: int,
) -> Tuple[
NodeDict,
GraphAdjDict | DiGraphAdjDict | MultiGraphAdjDict | MultiDiGraphAdjDict,
Expand Down Expand Up @@ -142,11 +144,9 @@ def get_arangodb_graph(
if not load_adj_dict and not load_coo:
metagraph["edgeCollections"] = {}

config = nx.config.backends.arangodb
assert config.db_name
assert config.host
assert config.username
assert config.password
hosts = adb_graph._conn._hosts
db_name = adb_graph._conn._db_name
username, password = adb_graph._conn._auth

(
node_dict,
Expand All @@ -157,20 +157,20 @@ def get_arangodb_graph(
vertex_ids_to_index,
edge_values,
) = NetworkXLoader.load_into_networkx(
config.db_name,
database=db_name,
metagraph=metagraph,
hosts=[config.host],
username=config.username,
password=config.password,
hosts=hosts,
username=username,
password=password,
load_adj_dict=load_adj_dict,
load_coo=load_coo,
load_all_vertex_attributes=load_all_vertex_attributes,
load_all_edge_attributes=load_all_edge_attributes,
is_directed=is_directed,
is_multigraph=is_multigraph,
symmetrize_edges_if_directed=symmetrize_edges_if_directed,
parallelism=config.read_parallelism,
batch_size=config.read_batch_size,
parallelism=read_parallelism,
batch_size=read_batch_size,
)

return (
Expand Down
59 changes: 27 additions & 32 deletions nx_arangodb/classes/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,11 +214,13 @@ def __init__(
self.use_nxcg_cache = True
self.nxcg_graph = None

self.edge_type_key = edge_type_key
self.read_parallelism = read_parallelism
self.read_batch_size = read_batch_size

# Does not apply to undirected graphs
self.symmetrize_edges = symmetrize_edges

self.edge_type_key = edge_type_key

# TODO: Consider this
# if not self.__graph_name:
# if incoming_graph_data is not None:
Expand All @@ -227,8 +229,8 @@ def __init__(

self._loaded_incoming_graph_data = False
if self.graph_exists_in_db:
self._set_factory_methods()
self.__set_arangodb_backend_config(read_parallelism, read_batch_size)
self._set_factory_methods(read_parallelism, read_batch_size)
self.__set_arangodb_backend_config()

if overwrite_graph:
logger.info("Overwriting graph...")
Expand Down Expand Up @@ -284,7 +286,7 @@ def __init__(
# Init helper methods #
#######################

def _set_factory_methods(self) -> None:
def _set_factory_methods(self, read_parallelism: int, read_batch_size: int) -> None:
"""Set the factory methods for the graph, _node, and _adj dictionaries.
The ArangoDB CRUD operations are handled by the modified dictionaries.
Expand All @@ -299,39 +301,29 @@ def _set_factory_methods(self) -> None:
"""

base_args = (self.db, self.adb_graph)

node_args = (*base_args, self.default_node_type)
adj_args = (
*node_args,
self.edge_type_key,
self.edge_type_func,
self.__class__.__name__,
node_args_with_read = (*node_args, read_parallelism, read_batch_size)

adj_args = (self.edge_type_key, self.edge_type_func, self.__class__.__name__)
adj_inner_args = (*node_args, *adj_args)
adj_outer_args = (
*node_args_with_read,
*adj_args,
self.symmetrize_edges,
)

self.graph_attr_dict_factory = graph_dict_factory(*base_args)

self.node_dict_factory = node_dict_factory(*node_args)
self.node_dict_factory = node_dict_factory(*node_args_with_read)
self.node_attr_dict_factory = node_attr_dict_factory(*base_args)

self.edge_attr_dict_factory = edge_attr_dict_factory(*base_args)
self.adjlist_inner_dict_factory = adjlist_inner_dict_factory(*adj_args)
self.adjlist_outer_dict_factory = adjlist_outer_dict_factory(
*adj_args, self.symmetrize_edges
)

def __set_arangodb_backend_config(
self, read_parallelism: int, read_batch_size: int
) -> None:
if not all([self._host, self._username, self._password, self._db_name]):
m = "Must set all environment variables to use the ArangoDB Backend with an existing graph" # noqa: E501
raise OSError(m)
self.adjlist_inner_dict_factory = adjlist_inner_dict_factory(*adj_inner_args)
self.adjlist_outer_dict_factory = adjlist_outer_dict_factory(*adj_outer_args)

def __set_arangodb_backend_config(self) -> None:
config = nx.config.backends.arangodb
config.host = self._host
config.username = self._username
config.password = self._password
config.db_name = self._db_name
config.read_parallelism = read_parallelism
config.read_batch_size = read_batch_size
config.use_gpu = True # Only used by default if nx-cugraph is available

def __set_edge_collections_attributes(self, attributes: set[str] | None) -> None:
Expand All @@ -345,7 +337,7 @@ def __set_edge_collections_attributes(self, attributes: set[str] | None) -> None
self._edge_collections_attributes.add("_id")

def __set_db(self, db: Any = None) -> None:
self._host = os.getenv("DATABASE_HOST")
self._hosts = os.getenv("DATABASE_HOST", "").split(",")
self._username = os.getenv("DATABASE_USERNAME")
self._password = os.getenv("DATABASE_PASSWORD")
self._db_name = os.getenv("DATABASE_NAME")
Expand All @@ -355,17 +347,20 @@ def __set_db(self, db: Any = None) -> None:
m = "arango.database.StandardDatabase"
raise TypeError(m)

db.version()
db.version() # make sure the connection is valid
self.__db = db
self._db_name = db.name
self._hosts = db._conn._hosts
self._username, self._password = db._conn._auth
return

if not all([self._host, self._username, self._password, self._db_name]):
if not all([self._hosts, self._username, self._password, self._db_name]):
m = "Database environment variables not set. Can't connect to the database"
logger.warning(m)
self.__db = None
return

self.__db = ArangoClient(hosts=self._host, request_timeout=None).db(
self.__db = ArangoClient(hosts=self._hosts, request_timeout=None).db(
self._db_name, self._username, self._password, verify=True
)

Expand Down
4 changes: 2 additions & 2 deletions nx_arangodb/classes/multigraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,8 +229,8 @@ def __init__(
# Init helper methods #
#######################

def _set_factory_methods(self) -> None:
super()._set_factory_methods()
def _set_factory_methods(self, read_parallelism: int, read_batch_size: int) -> None:
super()._set_factory_methods(read_parallelism, read_batch_size)
self.edge_key_dict_factory = edge_key_dict_factory(
self.db,
self.adb_graph,
Expand Down
4 changes: 4 additions & 0 deletions nx_arangodb/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,8 @@ def nxadb_to_nx(G: nxadb.Graph) -> nx.Graph:
is_directed=G.is_directed(),
is_multigraph=G.is_multigraph(),
symmetrize_edges_if_directed=G.symmetrize_edges if G.is_directed() else False,
read_parallelism=G.read_parallelism,
read_batch_size=G.read_batch_size,
)

logger.info(f"Graph '{G.adb_graph.name}' load took {time.time() - start_time}s")
Expand Down Expand Up @@ -337,6 +339,8 @@ def nxadb_to_nxcg(G: nxadb.Graph, as_directed: bool = False) -> nxcg.Graph:
symmetrize_edges_if_directed=(
G.symmetrize_edges if G.is_directed() else False
),
read_parallelism=G.read_parallelism,
read_batch_size=G.read_batch_size,
)

logger.info(f"Graph '{G.adb_graph.name}' load took {time.time() - start_time}s")
Expand Down

0 comments on commit ad38fc6

Please sign in to comment.