From 7cbb64bfd1d2d99ec606aee008250486f7c74567 Mon Sep 17 00:00:00 2001 From: Jacob Sznajdman Date: Thu, 22 Jun 2023 12:11:14 +0200 Subject: [PATCH 01/14] Implement client endpoints for gnn/graph sage training --- graphdatascience/endpoints.py | 3 ++- graphdatascience/gnn/__init__.py | 0 graphdatascience/gnn/gnn_endpoints.py | 17 ++++++++++++++++ graphdatascience/gnn/gnn_nc_runner.py | 21 ++++++++++++++++++++ graphdatascience/ignored_server_endpoints.py | 1 + 5 files changed, 41 insertions(+), 1 deletion(-) create mode 100644 graphdatascience/gnn/__init__.py create mode 100644 graphdatascience/gnn/gnn_endpoints.py create mode 100644 graphdatascience/gnn/gnn_nc_runner.py diff --git a/graphdatascience/endpoints.py b/graphdatascience/endpoints.py index 4abd44247..8df5e5e30 100644 --- a/graphdatascience/endpoints.py +++ b/graphdatascience/endpoints.py @@ -1,5 +1,6 @@ from .algo.single_mode_algo_endpoints import SingleModeAlgoEndpoints from .call_builder import IndirectAlphaCallBuilder, IndirectBetaCallBuilder +from .gnn.gnn_endpoints import GnnEndpoints from .graph.graph_endpoints import ( GraphAlphaEndpoints, GraphBetaEndpoints, @@ -32,7 +33,7 @@ """ -class DirectEndpoints(DirectSystemEndpoints, DirectUtilEndpoints, GraphEndpoints, PipelineEndpoints, ModelEndpoints): +class DirectEndpoints(DirectSystemEndpoints, DirectUtilEndpoints, GraphEndpoints, PipelineEndpoints, ModelEndpoints, GnnEndpoints): def __init__(self, query_runner: QueryRunner, namespace: str, server_version: ServerVersion): super().__init__(query_runner, namespace, server_version) diff --git a/graphdatascience/gnn/__init__.py b/graphdatascience/gnn/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/graphdatascience/gnn/gnn_endpoints.py b/graphdatascience/gnn/gnn_endpoints.py new file mode 100644 index 000000000..e140eb8fc --- /dev/null +++ b/graphdatascience/gnn/gnn_endpoints.py @@ -0,0 +1,17 @@ +from .gnn_nc_runner import GNNNodeClassificationRunner +from ..caller_base import CallerBase +from ..error.illegal_attr_checker import IllegalAttrChecker +from ..error.uncallable_namespace import UncallableNamespace + +class GNNRunner(UncallableNamespace, IllegalAttrChecker): + @property + def nodeClassification(self) -> GNNNodeClassificationRunner: + return GNNNodeClassificationRunner(self._query_runner, f"{self._namespace}.nodeClassification", self._server_version) + +class GnnEndpoints(CallerBase): + @property + def gnn(self) -> GNNRunner: + return GNNRunner(self._query_runner, f"{self._namespace}.gnn", self._server_version) + + + diff --git a/graphdatascience/gnn/gnn_nc_runner.py b/graphdatascience/gnn/gnn_nc_runner.py new file mode 100644 index 000000000..d898176f0 --- /dev/null +++ b/graphdatascience/gnn/gnn_nc_runner.py @@ -0,0 +1,21 @@ +from typing import Any, List + +from ..error.illegal_attr_checker import IllegalAttrChecker +from ..error.uncallable_namespace import UncallableNamespace +import json + + +class GNNNodeClassificationRunner(UncallableNamespace, IllegalAttrChecker): + def train(self, graph_name: str, model_name: str, feature_properties: List[str], target_property: str, + target_node_label: str = None, node_labels: List[str] = None) -> "Series[Any]": + configMap = { + "featureProperties": feature_properties, + "targetProperty": target_property, + } + node_properties = feature_properties + [target_property] + if target_node_label: + configMap["targetNodeLabel"] = target_node_label + mlTrainingConfig = json.dumps(configMap) + # TODO query avaiable node labels + node_labels = ["Paper"] if not 
node_labels else node_labels + self._query_runner.run_query(f"CALL gds.upload.graph('{graph_name}', {{mlTrainingConfig: '{mlTrainingConfig}', modelName: '{model_name}', nodeLabels: {node_labels}, nodeProperties: {node_properties}}})") diff --git a/graphdatascience/ignored_server_endpoints.py b/graphdatascience/ignored_server_endpoints.py index 89ad9f0b2..d103a90c4 100644 --- a/graphdatascience/ignored_server_endpoints.py +++ b/graphdatascience/ignored_server_endpoints.py @@ -47,6 +47,7 @@ "gds.alpha.pipeline.nodeRegression.predict.stream", "gds.alpha.pipeline.nodeRegression.selectFeatures", "gds.alpha.pipeline.nodeRegression.train", + "gds.gnn.nc", "gds.similarity.cosine", "gds.similarity.euclidean", "gds.similarity.euclideanDistance", From 5117eea82fcaebc89605740a8d23dc2241b076a2 Mon Sep 17 00:00:00 2001 From: Mats Rydberg Date: Thu, 6 Jul 2023 17:41:52 +0200 Subject: [PATCH 02/14] Add predict endpoint to GNN NC runner Co-authored-by: Jacob Sznajdman Co-authored-by: Olga Razvenskaia --- graphdatascience/gnn/gnn_nc_runner.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/graphdatascience/gnn/gnn_nc_runner.py b/graphdatascience/gnn/gnn_nc_runner.py index d898176f0..64dbb44b1 100644 --- a/graphdatascience/gnn/gnn_nc_runner.py +++ b/graphdatascience/gnn/gnn_nc_runner.py @@ -11,11 +11,26 @@ def train(self, graph_name: str, model_name: str, feature_properties: List[str], configMap = { "featureProperties": feature_properties, "targetProperty": target_property, + "job_type": "train", } node_properties = feature_properties + [target_property] if target_node_label: configMap["targetNodeLabel"] = target_node_label mlTrainingConfig = json.dumps(configMap) - # TODO query avaiable node labels + # TODO query available node labels node_labels = ["Paper"] if not node_labels else node_labels self._query_runner.run_query(f"CALL gds.upload.graph('{graph_name}', {{mlTrainingConfig: '{mlTrainingConfig}', modelName: '{model_name}', nodeLabels: {node_labels}, nodeProperties: {node_properties}}})") + + + def predict(self, graph_name: str, model_name: str, feature_properties: List[str], target_node_label: str = None, node_labels: List[str] = None) -> "Series[Any]": + configMap = { + "featureProperties": feature_properties, + "job_type": "predict", + } + if target_node_label: + configMap["targetNodeLabel"] = target_node_label + mlTrainingConfig = json.dumps(configMap) + # TODO query available node labels + node_labels = ["Paper"] if not node_labels else node_labels + self._query_runner.run_query( + f"CALL gds.upload.graph('{graph_name}', {{mlTrainingConfig: '{mlTrainingConfig}', modelName: '{model_name}', nodeLabels: {node_labels}, nodeProperties: {feature_properties}}})") From 1e8237e19d4adb0d54b514b6e841fedea13627f6 Mon Sep 17 00:00:00 2001 From: Mats Rydberg Date: Thu, 6 Jul 2023 18:00:12 +0200 Subject: [PATCH 03/14] Add notebook illustrating usage of new GNN stuff --- examples/python-runtime.ipynb | 83 +++++++++++++++++++++++++++++++++++ 1 file changed, 83 insertions(+) create mode 100644 examples/python-runtime.ipynb diff --git a/examples/python-runtime.ipynb b/examples/python-runtime.ipynb new file mode 100644 index 000000000..d0a44265f --- /dev/null +++ b/examples/python-runtime.ipynb @@ -0,0 +1,83 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "DBID = \"beefbeef\"\n", + "ENVIRONMENT = \"\"\n", + "PASSWORD = \"\"\n", + "\n", + "from graphdatascience import 
GraphDataScience\n", + "\n", + "gds = GraphDataScience(\n", + " f\"neo4j+s://{DBID}-{ENVIRONMENT}.databases.neo4j-dev.io/\", auth=(\"neo4j\", PASSWORD)\n", + ")\n", + "gds.set_database(\"neo4j\")\n", + "\n", + "gds.gnn.nodeClassification.train(\"cora\", \"model\", [\"features\"], \"subject\", node_labels=[\"Paper\"])\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "try:\n", + " gds.graph.load_cora()\n", + "except:\n", + " pass\n" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "gds.gnn.nodeClassification.train(\"cora\", \"model\", [\"features\"], \"subject\", node_labels=[\"Paper\"])\n" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "gds.gnn.nodeClassification.predict(\"cora\", \"model\", [\"features\"], \"subject\", node_labels=[\"Paper\"])" + ], + "metadata": { + "collapsed": false + } + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} From 95ecaa5ffa5aa3c1eba05570bd384ad2165e473b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Florentin=20D=C3=B6rre?= Date: Tue, 11 Jul 2023 11:10:07 +0200 Subject: [PATCH 04/14] Inject Arrow credentials to UploadGraph endpoint Co-authored-by: Brian Shi Co-authored-by: Olga Razvenskaia --- examples/python-runtime.ipynb | 48 +++++-------------- graphdatascience/endpoints.py | 4 +- graphdatascience/gnn/gnn_endpoints.py | 11 +++-- graphdatascience/gnn/gnn_nc_runner.py | 13 ++++- .../query_runner/arrow_query_runner.py | 31 +++++++++++- 5 files changed, 62 insertions(+), 45 deletions(-) diff --git a/examples/python-runtime.ipynb b/examples/python-runtime.ipynb index d0a44265f..ca7d4c2f5 100644 --- a/examples/python-runtime.ipynb +++ b/examples/python-runtime.ipynb @@ -3,9 +3,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "collapsed": true - }, + "metadata": {}, "outputs": [], "source": [ "DBID = \"beefbeef\"\n", @@ -14,68 +12,46 @@ "\n", "from graphdatascience import GraphDataScience\n", "\n", - "gds = GraphDataScience(\n", - " f\"neo4j+s://{DBID}-{ENVIRONMENT}.databases.neo4j-dev.io/\", auth=(\"neo4j\", PASSWORD)\n", - ")\n", + "gds = GraphDataScience(f\"neo4j+s://{DBID}-{ENVIRONMENT}.databases.neo4j-dev.io/\", auth=(\"neo4j\", PASSWORD))\n", "gds.set_database(\"neo4j\")\n", "\n", - "gds.gnn.nodeClassification.train(\"cora\", \"model\", [\"features\"], \"subject\", node_labels=[\"Paper\"])\n" + "gds.gnn.nodeClassification.train(\"cora\", \"model\", [\"features\"], \"subject\", node_labels=[\"Paper\"])" ] }, { "cell_type": "code", "execution_count": null, + "metadata": {}, "outputs": [], "source": [ "try:\n", " gds.graph.load_cora()\n", "except:\n", - " pass\n" - ], - "metadata": { - "collapsed": false - } + " pass" + ] }, { "cell_type": "code", "execution_count": null, + "metadata": {}, "outputs": [], "source": [ - "gds.gnn.nodeClassification.train(\"cora\", \"model\", [\"features\"], \"subject\", node_labels=[\"Paper\"])\n" - ], - "metadata": { - "collapsed": false - } + "gds.gnn.nodeClassification.train(\"cora\", \"model\", [\"features\"], 
\"subject\", node_labels=[\"Paper\"])" + ] }, { "cell_type": "code", "execution_count": null, + "metadata": {}, "outputs": [], "source": [ "gds.gnn.nodeClassification.predict(\"cora\", \"model\", [\"features\"], \"subject\", node_labels=[\"Paper\"])" - ], - "metadata": { - "collapsed": false - } + ] } ], "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 2 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython2", - "version": "2.7.6" + "name": "python" } }, "nbformat": 4, diff --git a/graphdatascience/endpoints.py b/graphdatascience/endpoints.py index 8df5e5e30..e91c1702b 100644 --- a/graphdatascience/endpoints.py +++ b/graphdatascience/endpoints.py @@ -33,7 +33,9 @@ """ -class DirectEndpoints(DirectSystemEndpoints, DirectUtilEndpoints, GraphEndpoints, PipelineEndpoints, ModelEndpoints, GnnEndpoints): +class DirectEndpoints( + DirectSystemEndpoints, DirectUtilEndpoints, GraphEndpoints, PipelineEndpoints, ModelEndpoints, GnnEndpoints +): def __init__(self, query_runner: QueryRunner, namespace: str, server_version: ServerVersion): super().__init__(query_runner, namespace, server_version) diff --git a/graphdatascience/gnn/gnn_endpoints.py b/graphdatascience/gnn/gnn_endpoints.py index e140eb8fc..ba1b7b2b7 100644 --- a/graphdatascience/gnn/gnn_endpoints.py +++ b/graphdatascience/gnn/gnn_endpoints.py @@ -1,17 +1,18 @@ -from .gnn_nc_runner import GNNNodeClassificationRunner from ..caller_base import CallerBase from ..error.illegal_attr_checker import IllegalAttrChecker from ..error.uncallable_namespace import UncallableNamespace +from .gnn_nc_runner import GNNNodeClassificationRunner + class GNNRunner(UncallableNamespace, IllegalAttrChecker): @property def nodeClassification(self) -> GNNNodeClassificationRunner: - return GNNNodeClassificationRunner(self._query_runner, f"{self._namespace}.nodeClassification", self._server_version) + return GNNNodeClassificationRunner( + self._query_runner, f"{self._namespace}.nodeClassification", self._server_version + ) + class GnnEndpoints(CallerBase): @property def gnn(self) -> GNNRunner: return GNNRunner(self._query_runner, f"{self._namespace}.gnn", self._server_version) - - - diff --git a/graphdatascience/gnn/gnn_nc_runner.py b/graphdatascience/gnn/gnn_nc_runner.py index 64dbb44b1..54545c020 100644 --- a/graphdatascience/gnn/gnn_nc_runner.py +++ b/graphdatascience/gnn/gnn_nc_runner.py @@ -13,13 +13,24 @@ def train(self, graph_name: str, model_name: str, feature_properties: List[str], "targetProperty": target_property, "job_type": "train", } + node_properties = feature_properties + [target_property] if target_node_label: configMap["targetNodeLabel"] = target_node_label mlTrainingConfig = json.dumps(configMap) # TODO query available node labels node_labels = ["Paper"] if not node_labels else node_labels - self._query_runner.run_query(f"CALL gds.upload.graph('{graph_name}', {{mlTrainingConfig: '{mlTrainingConfig}', modelName: '{model_name}', nodeLabels: {node_labels}, nodeProperties: {node_properties}}})") + + # token and uri will be injected by arrow_query_runner + self._query_runner.run_query( + f"CALL gds.upload.graph($graph_name, $config)", + params={"graph_name": graph_name, "config": { + "mlTrainingConfig": mlTrainingConfig, + "modelName": model_name, + "nodeLabels": node_labels, + "nodeProperties": node_properties + }} + ) def 
predict(self, graph_name: str, model_name: str, feature_properties: List[str], target_node_label: str = None, node_labels: List[str] = None) -> "Series[Any]": diff --git a/graphdatascience/query_runner/arrow_query_runner.py b/graphdatascience/query_runner/arrow_query_runner.py index cf648879a..acb1b1325 100644 --- a/graphdatascience/query_runner/arrow_query_runner.py +++ b/graphdatascience/query_runner/arrow_query_runner.py @@ -29,6 +29,9 @@ def __init__( ): self._fallback_query_runner = fallback_query_runner self._server_version = server_version + # FIXME handle version were tls cert is given + self._auth = auth + self._uri = uri host, port_string = uri.split(":") @@ -39,8 +42,9 @@ def __init__( ) client_options: Dict[str, Any] = {"disable_server_verification": disable_server_verification} + self._auth_factory = AuthFactory(auth) if auth: - client_options["middleware"] = [AuthFactory(auth)] + client_options["middleware"] = [self._auth_factory] if tls_root_certs: client_options["tls_root_certs"] = tls_root_certs @@ -129,6 +133,11 @@ def run_query( endpoint = "gds.beta.graph.relationships.stream" return self._run_arrow_property_get(graph_name, endpoint, {"relationship_types": relationship_types}) + elif "gds.upload.graph" in query: + # inject parameters + params["config"]["token"] = self._get_or_request_token() + params["config"]["arrowEndpoint"] = self._uri + print(params) return self._fallback_query_runner.run_query(query, params, database, custom_error) @@ -183,11 +192,19 @@ def create_graph_constructor( return ArrowGraphConstructor( database, graph_name, self._flight_client, concurrency, undirected_relationship_types ) + + def _get_or_request_token(self) -> str: + print("get or request token") + self._flight_client.authenticate_basic_token(self._auth[0], self._auth[1]) + return self._auth_factory.token() class AuthFactory(ClientMiddlewareFactory): # type: ignore def __init__(self, auth: Tuple[str, str], *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) + print("init auth factory") + + self._auth = auth self._token: Optional[str] = None self._token_timestamp = 0 @@ -196,6 +213,7 @@ def start_call(self, info: Any) -> "AuthMiddleware": return AuthMiddleware(self) def token(self) -> Optional[str]: + print(f"current token {self._token} at {self._token_timestamp}") # check whether the token is older than 10 minutes. If so, reset it. if self._token and int(time.time()) - self._token_timestamp > 600: self._token = None @@ -206,6 +224,8 @@ def set_token(self, token: str) -> None: self._token = token self._token_timestamp = int(time.time()) + print(f"set token {self._token} time_stamp: {self._token_timestamp}") + @property def auth(self) -> Tuple[str, str]: return self._auth @@ -217,14 +237,21 @@ def __init__(self, factory: AuthFactory, *args: Any, **kwargs: Any) -> None: self._factory = factory def received_headers(self, headers: Dict[str, Any]) -> None: - auth_header: str = headers.get("Authorization", None) + auth_header: str = headers.get("authorization", None) if not auth_header: return + # authenticate_basic_token() returns a list. 
+ # TODO We should take the first Bearer element here + if isinstance(auth_header, list): + auth_header = auth_header[0] + [auth_type, token] = auth_header.split(" ", 1) if auth_type == "Bearer": self._factory.set_token(token) def sending_headers(self) -> Dict[str, str]: + print("sending headers") + token = self._factory.token() if not token: username, password = self._factory.auth From f1977a30b37f992748010a8ea2e86d3b499d0063 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Florentin=20D=C3=B6rre?= Date: Tue, 11 Jul 2023 17:28:49 +0200 Subject: [PATCH 05/14] Remove prints --- graphdatascience/query_runner/arrow_query_runner.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/graphdatascience/query_runner/arrow_query_runner.py b/graphdatascience/query_runner/arrow_query_runner.py index acb1b1325..b13565858 100644 --- a/graphdatascience/query_runner/arrow_query_runner.py +++ b/graphdatascience/query_runner/arrow_query_runner.py @@ -137,7 +137,6 @@ def run_query( # inject parameters params["config"]["token"] = self._get_or_request_token() params["config"]["arrowEndpoint"] = self._uri - print(params) return self._fallback_query_runner.run_query(query, params, database, custom_error) @@ -194,7 +193,6 @@ def create_graph_constructor( ) def _get_or_request_token(self) -> str: - print("get or request token") self._flight_client.authenticate_basic_token(self._auth[0], self._auth[1]) return self._auth_factory.token() @@ -202,9 +200,6 @@ def _get_or_request_token(self) -> str: class AuthFactory(ClientMiddlewareFactory): # type: ignore def __init__(self, auth: Tuple[str, str], *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) - print("init auth factory") - - self._auth = auth self._token: Optional[str] = None self._token_timestamp = 0 @@ -213,7 +208,6 @@ def start_call(self, info: Any) -> "AuthMiddleware": return AuthMiddleware(self) def token(self) -> Optional[str]: - print(f"current token {self._token} at {self._token_timestamp}") # check whether the token is older than 10 minutes. If so, reset it. 
if self._token and int(time.time()) - self._token_timestamp > 600: self._token = None @@ -224,8 +218,6 @@ def set_token(self, token: str) -> None: self._token = token self._token_timestamp = int(time.time()) - print(f"set token {self._token} time_stamp: {self._token_timestamp}") - @property def auth(self) -> Tuple[str, str]: return self._auth @@ -250,8 +242,6 @@ def received_headers(self, headers: Dict[str, Any]) -> None: self._factory.set_token(token) def sending_headers(self) -> Dict[str, str]: - print("sending headers") - token = self._factory.token() if not token: username, password = self._factory.auth From ae2af59363f15071be143fd0185f17c3da964564 Mon Sep 17 00:00:00 2001 From: Brian Shi Date: Wed, 12 Jul 2023 17:10:01 +0100 Subject: [PATCH 06/14] Add all configs to CRD MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Florentin Dörre Co-authored-by: Olga Razvenskaia --- graphdatascience/gnn/gnn_nc_runner.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/graphdatascience/gnn/gnn_nc_runner.py b/graphdatascience/gnn/gnn_nc_runner.py index 54545c020..8eccd6c88 100644 --- a/graphdatascience/gnn/gnn_nc_runner.py +++ b/graphdatascience/gnn/gnn_nc_runner.py @@ -12,23 +12,22 @@ def train(self, graph_name: str, model_name: str, feature_properties: List[str], "featureProperties": feature_properties, "targetProperty": target_property, "job_type": "train", + "nodeProperties": feature_properties + [target_property] } - node_properties = feature_properties + [target_property] if target_node_label: configMap["targetNodeLabel"] = target_node_label + if node_labels: + configMap["nodeLabels"] = node_labels + mlTrainingConfig = json.dumps(configMap) - # TODO query available node labels - node_labels = ["Paper"] if not node_labels else node_labels # token and uri will be injected by arrow_query_runner self._query_runner.run_query( f"CALL gds.upload.graph($graph_name, $config)", params={"graph_name": graph_name, "config": { "mlTrainingConfig": mlTrainingConfig, - "modelName": model_name, - "nodeLabels": node_labels, - "nodeProperties": node_properties + "modelName": model_name }} ) From 8a681b422d5124b614c06931af0a7e352ad91261 Mon Sep 17 00:00:00 2001 From: Brian Shi Date: Mon, 17 Jul 2023 11:35:25 +0100 Subject: [PATCH 07/14] Cleanup nc_runner MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Florentin Dörre Co-authored-by: Olga Razvenskaia --- graphdatascience/gnn/gnn_nc_runner.py | 64 ++++++++++++------- .../query_runner/arrow_query_runner.py | 4 +- 2 files changed, 44 insertions(+), 24 deletions(-) diff --git a/graphdatascience/gnn/gnn_nc_runner.py b/graphdatascience/gnn/gnn_nc_runner.py index 8eccd6c88..adb49fb07 100644 --- a/graphdatascience/gnn/gnn_nc_runner.py +++ b/graphdatascience/gnn/gnn_nc_runner.py @@ -1,46 +1,66 @@ +import json from typing import Any, List from ..error.illegal_attr_checker import IllegalAttrChecker from ..error.uncallable_namespace import UncallableNamespace -import json class GNNNodeClassificationRunner(UncallableNamespace, IllegalAttrChecker): - def train(self, graph_name: str, model_name: str, feature_properties: List[str], target_property: str, - target_node_label: str = None, node_labels: List[str] = None) -> "Series[Any]": - configMap = { + def train( + self, + graph_name: str, + model_name: str, + feature_properties: List[str], + target_property: str, + target_node_label: str = None, + node_labels: List[str] = None, + ) -> 
"Series[Any]": # noqa: F821 + mlConfigMap = { "featureProperties": feature_properties, "targetProperty": target_property, "job_type": "train", - "nodeProperties": feature_properties + [target_property] + "nodeProperties": feature_properties + [target_property], } if target_node_label: - configMap["targetNodeLabel"] = target_node_label + mlConfigMap["targetNodeLabel"] = target_node_label if node_labels: - configMap["nodeLabels"] = node_labels + mlConfigMap["nodeLabels"] = node_labels - mlTrainingConfig = json.dumps(configMap) + mlTrainingConfig = json.dumps(mlConfigMap) # token and uri will be injected by arrow_query_runner self._query_runner.run_query( - f"CALL gds.upload.graph($graph_name, $config)", - params={"graph_name": graph_name, "config": { - "mlTrainingConfig": mlTrainingConfig, - "modelName": model_name - }} - ) - + "CALL gds.upload.graph($graph_name, $config)", + params={ + "graph_name": graph_name, + "config": {"mlTrainingConfig": mlTrainingConfig, "modelName": model_name}, + }, + ) - def predict(self, graph_name: str, model_name: str, feature_properties: List[str], target_node_label: str = None, node_labels: List[str] = None) -> "Series[Any]": - configMap = { + def predict( + self, + graph_name: str, + model_name: str, + feature_properties: List[str], + target_node_label: str = None, + node_labels: List[str] = None, + ) -> "Series[Any]": # noqa: F821 + mlConfigMap = { "featureProperties": feature_properties, "job_type": "predict", + "nodeProperties": feature_properties, } if target_node_label: - configMap["targetNodeLabel"] = target_node_label - mlTrainingConfig = json.dumps(configMap) - # TODO query available node labels - node_labels = ["Paper"] if not node_labels else node_labels + mlConfigMap["targetNodeLabel"] = target_node_label + if node_labels: + mlConfigMap["nodeLabels"] = node_labels + + mlTrainingConfig = json.dumps(mlConfigMap) self._query_runner.run_query( - f"CALL gds.upload.graph('{graph_name}', {{mlTrainingConfig: '{mlTrainingConfig}', modelName: '{model_name}', nodeLabels: {node_labels}, nodeProperties: {feature_properties}}})") + "CALL gds.upload.graph($graph_name, $config)", + params={ + "graph_name": graph_name, + "config": {"mlTrainingConfig": mlTrainingConfig, "modelName": model_name}, + }, + ) # type: ignore diff --git a/graphdatascience/query_runner/arrow_query_runner.py b/graphdatascience/query_runner/arrow_query_runner.py index b13565858..eab64398c 100644 --- a/graphdatascience/query_runner/arrow_query_runner.py +++ b/graphdatascience/query_runner/arrow_query_runner.py @@ -191,7 +191,7 @@ def create_graph_constructor( return ArrowGraphConstructor( database, graph_name, self._flight_client, concurrency, undirected_relationship_types ) - + def _get_or_request_token(self) -> str: self._flight_client.authenticate_basic_token(self._auth[0], self._auth[1]) return self._auth_factory.token() @@ -232,7 +232,7 @@ def received_headers(self, headers: Dict[str, Any]) -> None: auth_header: str = headers.get("authorization", None) if not auth_header: return - # authenticate_basic_token() returns a list. + # authenticate_basic_token() returns a list. 
# TODO We should take the first Bearer element here if isinstance(auth_header, list): auth_header = auth_header[0] From 84b8cef6677fae3b7b6f8236e50ce82e4b30f19e Mon Sep 17 00:00:00 2001 From: Brian Shi Date: Tue, 18 Jul 2023 16:03:44 +0100 Subject: [PATCH 08/14] Parse remote ML configs --- graphdatascience/gnn/gnn_nc_runner.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/graphdatascience/gnn/gnn_nc_runner.py b/graphdatascience/gnn/gnn_nc_runner.py index adb49fb07..a18f4499d 100644 --- a/graphdatascience/gnn/gnn_nc_runner.py +++ b/graphdatascience/gnn/gnn_nc_runner.py @@ -12,6 +12,7 @@ def train( model_name: str, feature_properties: List[str], target_property: str, + relationship_types: List[str], target_node_label: str = None, node_labels: List[str] = None, ) -> "Series[Any]": # noqa: F821 @@ -20,6 +21,7 @@ def train( "targetProperty": target_property, "job_type": "train", "nodeProperties": feature_properties + [target_property], + "relationshipTypes": relationship_types } if target_node_label: @@ -31,10 +33,9 @@ def train( # token and uri will be injected by arrow_query_runner self._query_runner.run_query( - "CALL gds.upload.graph($graph_name, $config)", + "CALL gds.upload.graph($config)", params={ - "graph_name": graph_name, - "config": {"mlTrainingConfig": mlTrainingConfig, "modelName": model_name}, + "config": {"mlTrainingConfig": mlTrainingConfig, "graphName": graph_name, "modelName": model_name}, }, ) @@ -43,6 +44,7 @@ def predict( graph_name: str, model_name: str, feature_properties: List[str], + relationship_types: List[str], target_node_label: str = None, node_labels: List[str] = None, ) -> "Series[Any]": # noqa: F821 @@ -50,6 +52,7 @@ def predict( "featureProperties": feature_properties, "job_type": "predict", "nodeProperties": feature_properties, + "relationshipTypes": relationship_types } if target_node_label: mlConfigMap["targetNodeLabel"] = target_node_label @@ -58,9 +61,8 @@ def predict( mlTrainingConfig = json.dumps(mlConfigMap) self._query_runner.run_query( - "CALL gds.upload.graph($graph_name, $config)", + "CALL gds.upload.graph($config)", params={ - "graph_name": graph_name, - "config": {"mlTrainingConfig": mlTrainingConfig, "modelName": model_name}, + "config": {"mlTrainingConfig": mlTrainingConfig, "graphName": graph_name, "modelName": model_name}, }, ) # type: ignore From 82087b7829b5dfad122108784ccc90cb7c102742 Mon Sep 17 00:00:00 2001 From: Brian Shi Date: Thu, 20 Jul 2023 15:30:11 +0100 Subject: [PATCH 09/14] Add mutateProperty MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Florentin Dörre Co-authored-by: Jacob Sznajdman Co-authored-by: Olga Razvenskaia --- graphdatascience/gnn/gnn_nc_runner.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/graphdatascience/gnn/gnn_nc_runner.py b/graphdatascience/gnn/gnn_nc_runner.py index a18f4499d..261443a87 100644 --- a/graphdatascience/gnn/gnn_nc_runner.py +++ b/graphdatascience/gnn/gnn_nc_runner.py @@ -45,6 +45,7 @@ def predict( model_name: str, feature_properties: List[str], relationship_types: List[str], + mutateProperty: str, target_node_label: str = None, node_labels: List[str] = None, ) -> "Series[Any]": # noqa: F821 @@ -52,7 +53,8 @@ def predict( "featureProperties": feature_properties, "job_type": "predict", "nodeProperties": feature_properties, - "relationshipTypes": relationship_types + "relationshipTypes": relationship_types, + "mutateProperty": mutateProperty } if target_node_label: 
mlConfigMap["targetNodeLabel"] = target_node_label From 8e822bed26c57e4ecbb0b27d50418bef49b2c358 Mon Sep 17 00:00:00 2001 From: Olga Razvenskaia Date: Mon, 31 Jul 2023 16:12:38 +0100 Subject: [PATCH 10/14] Short form of predict call Co-authored-by: Jacob Sznajdman --- graphdatascience/gnn/gnn_nc_runner.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/graphdatascience/gnn/gnn_nc_runner.py b/graphdatascience/gnn/gnn_nc_runner.py index 261443a87..90ee4b56b 100644 --- a/graphdatascience/gnn/gnn_nc_runner.py +++ b/graphdatascience/gnn/gnn_nc_runner.py @@ -43,23 +43,15 @@ def predict( self, graph_name: str, model_name: str, - feature_properties: List[str], - relationship_types: List[str], mutateProperty: str, - target_node_label: str = None, - node_labels: List[str] = None, + predictedProbabilityProperty: str = None, ) -> "Series[Any]": # noqa: F821 mlConfigMap = { - "featureProperties": feature_properties, "job_type": "predict", - "nodeProperties": feature_properties, - "relationshipTypes": relationship_types, "mutateProperty": mutateProperty } - if target_node_label: - mlConfigMap["targetNodeLabel"] = target_node_label - if node_labels: - mlConfigMap["nodeLabels"] = node_labels + if predictedProbabilityProperty: + mlConfigMap["predictedProbabilityProperty"] = predictedProbabilityProperty mlTrainingConfig = json.dumps(mlConfigMap) self._query_runner.run_query( From 74d563975a55a85252b7be3ee6a07887bd171846 Mon Sep 17 00:00:00 2001 From: Jacob Sznajdman Date: Thu, 27 Jul 2023 17:56:08 +0200 Subject: [PATCH 11/14] Expose graphsage training configuration Co-authored-by: Olga Razvenskaia --- graphdatascience/gnn/gnn_nc_runner.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/graphdatascience/gnn/gnn_nc_runner.py b/graphdatascience/gnn/gnn_nc_runner.py index 90ee4b56b..37f7125d9 100644 --- a/graphdatascience/gnn/gnn_nc_runner.py +++ b/graphdatascience/gnn/gnn_nc_runner.py @@ -6,6 +6,21 @@ class GNNNodeClassificationRunner(UncallableNamespace, IllegalAttrChecker): + def make_graph_sage_config(self, graph_sage_config): + GRAPH_SAGE_DEFAULT_CONFIG = {"layer_config": {}, "num_neighbors": [25, 10], "dropout": 0.5, + "hidden_channels": 256} + final_sage_config = GRAPH_SAGE_DEFAULT_CONFIG + if graph_sage_config: + bad_keys = [] + for key in graph_sage_config: + if key not in GRAPH_SAGE_DEFAULT_CONFIG: + bad_keys.append(key) + if len(bad_keys) > 0: + raise Exception(f"Argument graph_sage_config contains invalid keys {', '.join(bad_keys)}.") + + final_sage_config.update(graph_sage_config) + return final_sage_config + def train( self, graph_name: str, @@ -15,13 +30,15 @@ def train( relationship_types: List[str], target_node_label: str = None, node_labels: List[str] = None, + graph_sage_config = None ) -> "Series[Any]": # noqa: F821 mlConfigMap = { "featureProperties": feature_properties, "targetProperty": target_property, "job_type": "train", "nodeProperties": feature_properties + [target_property], - "relationshipTypes": relationship_types + "relationshipTypes": relationship_types, + "graph_sage_config": self.make_graph_sage_config(graph_sage_config) } if target_node_label: From fa544887ea83a68bb5005411e2b1dc4d43983be3 Mon Sep 17 00:00:00 2001 From: Olga Razvenskaia Date: Tue, 1 Aug 2023 11:19:15 +0100 Subject: [PATCH 12/14] Add learning rate Co-authored-by: Jacob Sznajdman --- graphdatascience/gnn/gnn_nc_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphdatascience/gnn/gnn_nc_runner.py 
b/graphdatascience/gnn/gnn_nc_runner.py index 37f7125d9..27aec8d63 100644 --- a/graphdatascience/gnn/gnn_nc_runner.py +++ b/graphdatascience/gnn/gnn_nc_runner.py @@ -8,7 +8,7 @@ class GNNNodeClassificationRunner(UncallableNamespace, IllegalAttrChecker): def make_graph_sage_config(self, graph_sage_config): GRAPH_SAGE_DEFAULT_CONFIG = {"layer_config": {}, "num_neighbors": [25, 10], "dropout": 0.5, - "hidden_channels": 256} + "hidden_channels": 256, "learning_rate": 0.003} final_sage_config = GRAPH_SAGE_DEFAULT_CONFIG if graph_sage_config: bad_keys = [] From e6849590b7b3100585bc1262a93b98bd178569e0 Mon Sep 17 00:00:00 2001 From: Brian Shi Date: Tue, 1 Aug 2023 16:15:09 +0100 Subject: [PATCH 13/14] Update python-runtime notebook --- examples/python-runtime.ipynb | 167 ++++++++++++++++++++++++++++++---- 1 file changed, 148 insertions(+), 19 deletions(-) diff --git a/examples/python-runtime.ipynb b/examples/python-runtime.ipynb index ca7d4c2f5..da2118da3 100644 --- a/examples/python-runtime.ipynb +++ b/examples/python-runtime.ipynb @@ -2,27 +2,68 @@ "cells": [ { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ - "DBID = \"beefbeef\"\n", - "ENVIRONMENT = \"\"\n", - "PASSWORD = \"\"\n", - "\n", - "from graphdatascience import GraphDataScience\n", - "\n", - "gds = GraphDataScience(f\"neo4j+s://{DBID}-{ENVIRONMENT}.databases.neo4j-dev.io/\", auth=(\"neo4j\", PASSWORD))\n", - "gds.set_database(\"neo4j\")\n", - "\n", - "gds.gnn.nodeClassification.train(\"cora\", \"model\", [\"features\"], \"subject\", node_labels=[\"Paper\"])" + "from graphdatascience import GraphDataScience" ] }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, + "execution_count": 2, + "outputs": [], + "source": [ + "ENVIRONMENT = \"mlruntimedev\"\n", + "DBID = \"e6ba1b5c\"\n", + "PASSWORD = \"l4Co2Qa5GseW0sMropCvJo17laf6ZCq9vuAhiJrVW2c\"" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 3, "outputs": [], + "source": [ + "gds = GraphDataScience(f\"neo4j+s://{DBID}-{ENVIRONMENT}.databases.neo4j-dev.io/\", auth=(\"neo4j\", PASSWORD))\n", + "gds.set_database(\"neo4j\")" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": "Uploading Nodes: 0%| | 0/2708 [00:00\n\n\n \n \n \n \n \n \n \n \n \n \n \n \n
gds.remoteml.getTrainResult('model2')
0    {'test_acc_mean': 0.8589511513710022, 'test_ac...
\n" + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "train_result" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ - "gds.gnn.nodeClassification.predict(\"cora\", \"model\", [\"features\"], \"subject\", node_labels=[\"Paper\"])" + "predict_result = gds.gnn.nodeClassification.predict(\"cora\", \"myModel\", \"myPredictions\")" ] + }, + { + "cell_type": "code", + "execution_count": 9, + "outputs": [], + "source": [ + "cora = gds.graph.get('cora')" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 11, + "outputs": [], + "source": [ + "predictions = gds.graph.nodeProperties.stream(cora, node_properties=[\"features\", \"myPredictions\"], separate_property_columns=True)" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 12, + "outputs": [ + { + "data": { + "text/plain": " nodeId features \\\n0 31336 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n1 1061127 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, ... \n2 1106406 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n3 13195 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n4 37879 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n... ... ... \n2703 1128975 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n2704 1128977 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n2705 1128978 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n2706 117328 [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n2707 24043 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n\n model2Predictions \n0 0 \n1 1 \n2 2 \n3 2 \n4 3 \n... ... \n2703 5 \n2704 5 \n2705 5 \n2706 6 \n2707 0 \n\n[2708 rows x 3 columns]", + "text/html": "
[HTML rendering of the same DataFrame: 2708 rows × 3 columns (nodeId, features, model2Predictions)]
" + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "predictions" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [], + "metadata": { + "collapsed": false + } } ], "metadata": { From 8e83471b5ac599fd7af854c44b0c02259256e3aa Mon Sep 17 00:00:00 2001 From: Jacob Sznajdman Date: Fri, 4 Aug 2023 12:52:07 +0200 Subject: [PATCH 14/14] WIP --- graphdatascience/gnn/gnn_nc_runner.py | 59 +++++++++++++++++++++++---- 1 file changed, 52 insertions(+), 7 deletions(-) diff --git a/graphdatascience/gnn/gnn_nc_runner.py b/graphdatascience/gnn/gnn_nc_runner.py index 27aec8d63..6e65e2337 100644 --- a/graphdatascience/gnn/gnn_nc_runner.py +++ b/graphdatascience/gnn/gnn_nc_runner.py @@ -1,5 +1,6 @@ import json from typing import Any, List +import time from ..error.illegal_attr_checker import IllegalAttrChecker from ..error.uncallable_namespace import UncallableNamespace @@ -21,6 +22,21 @@ def make_graph_sage_config(self, graph_sage_config): final_sage_config.update(graph_sage_config) return final_sage_config + def get_logs(self, job_id: str, offset=0) -> "Series[Any]": # noqa: F821 + return self._query_runner.run_query( + "RETURN gds.remoteml.getLogs($job_id, $offset)", + params={ + "job_id": job_id, + "offset": offset + }).squeeze() + + def get_train_result(self, model_name: str) -> "Series[Any]": # noqa: F821 + return self._query_runner.run_query( + "RETURN gds.remoteml.getTrainResult($model_name)", + params={ + "model_name": model_name + }).squeeze() + def train( self, graph_name: str, @@ -30,7 +46,8 @@ def train( relationship_types: List[str], target_node_label: str = None, node_labels: List[str] = None, - graph_sage_config = None + graph_sage_config = None, + logging_interval: int = 5 ) -> "Series[Any]": # noqa: F821 mlConfigMap = { "featureProperties": feature_properties, @@ -49,12 +66,28 @@ def train( mlTrainingConfig = json.dumps(mlConfigMap) # token and uri will be injected by arrow_query_runner - self._query_runner.run_query( - "CALL gds.upload.graph($config)", + job_id = self._query_runner.run_query( + "CALL gds.upload.graph($config) YIELD jobId", params={ "config": {"mlTrainingConfig": mlTrainingConfig, "graphName": graph_name, "modelName": model_name}, }, - ) + ).jobId[0] + + received_logs = 0 + training_done = False + while not training_done: + for log in self.get_logs(job_id, offset=received_logs): + print(log) + received_logs += 1 + try: + self.get_train_result(model_name) + training_done = True + except Exception: + time.sleep(logging_interval) + + return job_id + + def predict( self, @@ -62,6 +95,7 @@ def predict( model_name: str, mutateProperty: str, predictedProbabilityProperty: str = None, + logging_interval = 5 ) -> "Series[Any]": # noqa: F821 mlConfigMap = { "job_type": "predict", @@ -71,9 +105,20 @@ def predict( mlConfigMap["predictedProbabilityProperty"] = predictedProbabilityProperty mlTrainingConfig = json.dumps(mlConfigMap) - self._query_runner.run_query( - "CALL gds.upload.graph($config)", + job_id = self._query_runner.run_query( + "CALL gds.upload.graph($config) YIELD jobId", params={ "config": {"mlTrainingConfig": mlTrainingConfig, "graphName": graph_name, "modelName": model_name}, }, - ) # type: ignore + ).jobId[0] + + received_logs = 0 + prediction_done = False + while not prediction_done: + for log in self.get_logs(job_id, offset=received_logs): + print(log) + received_logs += 1 + if log == "Prediction job completed": + prediction_done 
= True + if not prediction_done: + time.sleep(logging_interval)
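
For reference, the end-to-end client flow that this series converges on can be exercised roughly as follows. This is a sketch assembled from the example notebook updated above: the connection URI and password are placeholders, "CITES" is assumed to be the relationship type created by the CORA loader, and the graph_sage_config override is purely illustrative.

from graphdatascience import GraphDataScience

# Placeholder connection details; substitute a real database id, environment and password.
gds = GraphDataScience(
    "neo4j+s://<dbid>-<environment>.databases.neo4j-dev.io/", auth=("neo4j", "<password>")
)
gds.set_database("neo4j")

# Project the CORA example graph unless it already exists.
try:
    gds.graph.load_cora()
except Exception:
    pass

# Train a remote GraphSAGE node-classification model on the projected graph.
gds.gnn.nodeClassification.train(
    "cora",
    "myModel",
    ["features"],                    # feature properties
    "subject",                       # target property
    ["CITES"],                       # relationship types (assumed CORA schema)
    target_node_label="Paper",
    node_labels=["Paper"],
    graph_sage_config={"hidden_channels": 256},  # optional override of the defaults
)
train_result = gds.gnn.nodeClassification.get_train_result("myModel")

# Write class predictions back to the in-memory graph, then stream them out.
gds.gnn.nodeClassification.predict("cora", "myModel", "myPredictions")
cora = gds.graph.get("cora")
predictions = gds.graph.nodeProperties.stream(
    cora, node_properties=["features", "myPredictions"], separate_property_columns=True
)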