Skip to content

Aura-python-runtime: Read remote logs to make client responsive #453

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 15 commits into
base: main
Choose a base branch
from
Draft
188 changes: 188 additions & 0 deletions examples/python-runtime.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from graphdatascience import GraphDataScience"
]
},
{
"cell_type": "code",
"execution_count": 2,
"outputs": [],
"source": [
"ENVIRONMENT = \"mlruntimedev\"\n",
"DBID = \"e6ba1b5c\"\n",
"PASSWORD = \"l4Co2Qa5GseW0sMropCvJo17laf6ZCq9vuAhiJrVW2c\""
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 3,
"outputs": [],
"source": [
"gds = GraphDataScience(f\"neo4j+s://{DBID}-{ENVIRONMENT}.databases.neo4j-dev.io/\", auth=(\"neo4j\", PASSWORD))\n",
"gds.set_database(\"neo4j\")"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": "Uploading Nodes: 0%| | 0/2708 [00:00<?, ?Records/s]",
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "2402df2e03544cc4aa950b31cdfc0b47"
}
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": "Uploading Relationships: 0%| | 0/5429 [00:00<?, ?Records/s]",
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "9df2026182824543a38c5d30a04a3ea0"
}
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"try:\n",
" gds.graph.load_cora()\n",
"except:\n",
" pass"
]
},
{
"cell_type": "code",
"execution_count": 5,
"outputs": [],
"source": [
"train_response = gds.gnn.nodeClassification.train(\n",
" \"cora\", \"myModel\", [\"features\"], \"subject\", [\"CITES\"], target_node_label=\"Paper\", node_labels=[\"Paper\"]\n",
")"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 6,
"outputs": [],
"source": [
"train_result = gds.run_cypher(\"RETURN gds.remoteml.getTrainResult('myModel')\");"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 7,
"outputs": [
{
"data": {
"text/plain": " gds.remoteml.getTrainResult('model2')\n0 {'test_acc_mean': 0.8589511513710022, 'test_ac...",
"text/html": "<div>\n<style scoped>\n .dataframe tbody tr th:only-of-type {\n vertical-align: middle;\n }\n\n .dataframe tbody tr th {\n vertical-align: top;\n }\n\n .dataframe thead th {\n text-align: right;\n }\n</style>\n<table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: right;\">\n <th></th>\n <th>gds.remoteml.getTrainResult('model2')</th>\n </tr>\n </thead>\n <tbody>\n <tr>\n <th>0</th>\n <td>{'test_acc_mean': 0.8589511513710022, 'test_ac...</td>\n </tr>\n </tbody>\n</table>\n</div>"
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_result"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"predict_result = gds.gnn.nodeClassification.predict(\"cora\", \"myModel\", \"myPredictions\")"
]
},
{
"cell_type": "code",
"execution_count": 9,
"outputs": [],
"source": [
"cora = gds.graph.get('cora')"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 11,
"outputs": [],
"source": [
"predictions = gds.graph.nodeProperties.stream(cora, node_properties=[\"features\", \"myPredictions\"], separate_property_columns=True)"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 12,
"outputs": [
{
"data": {
"text/plain": " nodeId features \\\n0 31336 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n1 1061127 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, ... \n2 1106406 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n3 13195 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n4 37879 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n... ... ... \n2703 1128975 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n2704 1128977 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n2705 1128978 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n2706 117328 [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n2707 24043 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n\n model2Predictions \n0 0 \n1 1 \n2 2 \n3 2 \n4 3 \n... ... \n2703 5 \n2704 5 \n2705 5 \n2706 6 \n2707 0 \n\n[2708 rows x 3 columns]",
"text/html": "<div>\n<style scoped>\n .dataframe tbody tr th:only-of-type {\n vertical-align: middle;\n }\n\n .dataframe tbody tr th {\n vertical-align: top;\n }\n\n .dataframe thead th {\n text-align: right;\n }\n</style>\n<table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: right;\">\n <th></th>\n <th>nodeId</th>\n <th>features</th>\n <th>model2Predictions</th>\n </tr>\n </thead>\n <tbody>\n <tr>\n <th>0</th>\n <td>31336</td>\n <td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...</td>\n <td>0</td>\n </tr>\n <tr>\n <th>1</th>\n <td>1061127</td>\n <td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, ...</td>\n <td>1</td>\n </tr>\n <tr>\n <th>2</th>\n <td>1106406</td>\n <td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...</td>\n <td>2</td>\n </tr>\n <tr>\n <th>3</th>\n <td>13195</td>\n <td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...</td>\n <td>2</td>\n </tr>\n <tr>\n <th>4</th>\n <td>37879</td>\n <td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...</td>\n <td>3</td>\n </tr>\n <tr>\n <th>...</th>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n </tr>\n <tr>\n <th>2703</th>\n <td>1128975</td>\n <td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...</td>\n <td>5</td>\n </tr>\n <tr>\n <th>2704</th>\n <td>1128977</td>\n <td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...</td>\n <td>5</td>\n </tr>\n <tr>\n <th>2705</th>\n <td>1128978</td>\n <td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...</td>\n <td>5</td>\n </tr>\n <tr>\n <th>2706</th>\n <td>117328</td>\n <td>[0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...</td>\n <td>6</td>\n </tr>\n <tr>\n <th>2707</th>\n <td>24043</td>\n <td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...</td>\n <td>0</td>\n </tr>\n </tbody>\n</table>\n<p>2708 rows × 3 columns</p>\n</div>"
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"predictions"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [],
"metadata": {
"collapsed": false
}
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
5 changes: 4 additions & 1 deletion graphdatascience/endpoints.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .algo.single_mode_algo_endpoints import SingleModeAlgoEndpoints
from .call_builder import IndirectAlphaCallBuilder, IndirectBetaCallBuilder
from .gnn.gnn_endpoints import GnnEndpoints
from .graph.graph_endpoints import (
GraphAlphaEndpoints,
GraphBetaEndpoints,
Expand Down Expand Up @@ -32,7 +33,9 @@
"""


class DirectEndpoints(DirectSystemEndpoints, DirectUtilEndpoints, GraphEndpoints, PipelineEndpoints, ModelEndpoints):
class DirectEndpoints(
DirectSystemEndpoints, DirectUtilEndpoints, GraphEndpoints, PipelineEndpoints, ModelEndpoints, GnnEndpoints
):
def __init__(self, query_runner: QueryRunner, namespace: str, server_version: ServerVersion):
super().__init__(query_runner, namespace, server_version)

Expand Down
Empty file.
18 changes: 18 additions & 0 deletions graphdatascience/gnn/gnn_endpoints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from ..caller_base import CallerBase
from ..error.illegal_attr_checker import IllegalAttrChecker
from ..error.uncallable_namespace import UncallableNamespace
from .gnn_nc_runner import GNNNodeClassificationRunner


class GNNRunner(UncallableNamespace, IllegalAttrChecker):
@property
def nodeClassification(self) -> GNNNodeClassificationRunner:
return GNNNodeClassificationRunner(
self._query_runner, f"{self._namespace}.nodeClassification", self._server_version
)


class GnnEndpoints(CallerBase):
@property
def gnn(self) -> GNNRunner:
return GNNRunner(self._query_runner, f"{self._namespace}.gnn", self._server_version)
124 changes: 124 additions & 0 deletions graphdatascience/gnn/gnn_nc_runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
import json
from typing import Any, List
import time

from ..error.illegal_attr_checker import IllegalAttrChecker
from ..error.uncallable_namespace import UncallableNamespace


class GNNNodeClassificationRunner(UncallableNamespace, IllegalAttrChecker):
def make_graph_sage_config(self, graph_sage_config):
GRAPH_SAGE_DEFAULT_CONFIG = {"layer_config": {}, "num_neighbors": [25, 10], "dropout": 0.5,
"hidden_channels": 256, "learning_rate": 0.003}
final_sage_config = GRAPH_SAGE_DEFAULT_CONFIG
if graph_sage_config:
bad_keys = []
for key in graph_sage_config:
if key not in GRAPH_SAGE_DEFAULT_CONFIG:
bad_keys.append(key)
if len(bad_keys) > 0:
raise Exception(f"Argument graph_sage_config contains invalid keys {', '.join(bad_keys)}.")

final_sage_config.update(graph_sage_config)
return final_sage_config

def get_logs(self, job_id: str, offset=0) -> "Series[Any]": # noqa: F821
return self._query_runner.run_query(
"RETURN gds.remoteml.getLogs($job_id, $offset)",
params={
"job_id": job_id,
"offset": offset
}).squeeze()

def get_train_result(self, model_name: str) -> "Series[Any]": # noqa: F821
return self._query_runner.run_query(
"RETURN gds.remoteml.getTrainResult($model_name)",
params={
"model_name": model_name
}).squeeze()

def train(
self,
graph_name: str,
model_name: str,
feature_properties: List[str],
target_property: str,
relationship_types: List[str],
target_node_label: str = None,
node_labels: List[str] = None,
graph_sage_config = None,
logging_interval: int = 5
) -> "Series[Any]": # noqa: F821
mlConfigMap = {
"featureProperties": feature_properties,
"targetProperty": target_property,
"job_type": "train",
"nodeProperties": feature_properties + [target_property],
"relationshipTypes": relationship_types,
"graph_sage_config": self.make_graph_sage_config(graph_sage_config)
}

if target_node_label:
mlConfigMap["targetNodeLabel"] = target_node_label
if node_labels:
mlConfigMap["nodeLabels"] = node_labels

mlTrainingConfig = json.dumps(mlConfigMap)

# token and uri will be injected by arrow_query_runner
job_id = self._query_runner.run_query(
"CALL gds.upload.graph($config) YIELD jobId",
params={
"config": {"mlTrainingConfig": mlTrainingConfig, "graphName": graph_name, "modelName": model_name},
},
).jobId[0]

received_logs = 0
training_done = False
while not training_done:
for log in self.get_logs(job_id, offset=received_logs):
print(log)
received_logs += 1
try:
self.get_train_result(model_name)
training_done = True
except Exception:
time.sleep(logging_interval)

return job_id



def predict(
self,
graph_name: str,
model_name: str,
mutateProperty: str,
predictedProbabilityProperty: str = None,
logging_interval = 5
) -> "Series[Any]": # noqa: F821
mlConfigMap = {
"job_type": "predict",
"mutateProperty": mutateProperty
}
if predictedProbabilityProperty:
mlConfigMap["predictedProbabilityProperty"] = predictedProbabilityProperty

mlTrainingConfig = json.dumps(mlConfigMap)
job_id = self._query_runner.run_query(
"CALL gds.upload.graph($config) YIELD jobId",
params={
"config": {"mlTrainingConfig": mlTrainingConfig, "graphName": graph_name, "modelName": model_name},
},
).jobId[0]

received_logs = 0
prediction_done = False
while not prediction_done:
for log in self.get_logs(job_id, offset=received_logs):
print(log)
received_logs += 1
if log == "Prediction job completed":
prediction_done = True
if not prediction_done:
time.sleep(logging_interval)
1 change: 1 addition & 0 deletions graphdatascience/ignored_server_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
"gds.alpha.pipeline.nodeRegression.predict.stream",
"gds.alpha.pipeline.nodeRegression.selectFeatures",
"gds.alpha.pipeline.nodeRegression.train",
"gds.gnn.nc",
"gds.similarity.cosine",
"gds.similarity.euclidean",
"gds.similarity.euclideanDistance",
Expand Down
Loading