feat(learning): feature_store & graph_store V1 #4237

Open · wants to merge 23 commits into base: main · showing changes from 2 commits
2 changes: 1 addition & 1 deletion learning_engine/graphlearn-for-pytorch
3 changes: 3 additions & 0 deletions python/graphscope/__init__.py
@@ -52,6 +52,9 @@
from graphscope.framework.graph_builder import load_from
from graphscope.version import __version__


from graphscope.client.session import PyG_remote_backend

__doc__ = """
GraphScope - A unified distributed graph computing platform
=====================================================================
108 changes: 108 additions & 0 deletions python/graphscope/client/session.py
@@ -1381,6 +1381,83 @@ def graphlearn_torch(
        self._learning_instance_dict[graph.vineyard_id] = g
        graph._attach_learning_instance(g)
        return g

    def PyG_remote_backend(
        self,
        graph,
        edges,
        edge_weights=None,
        node_features=None,
        edge_features=None,
        node_labels=None,
        edge_dir="out",
        random_node_split=None,
        num_clients=1,
        manifest_path=None,
        client_folder_path="./",
    ):
        from graphscope.learning.gl_torch_graph import GLTorchGraph
        from graphscope.learning.utils import fill_params_in_yaml
        from graphscope.learning.utils import read_folder_files_content
        from graphscope.learning.GSFeatureStore import GSFeatureStore
        from graphscope.learning.GSGraphStore import GSGraphStore

        handle = {
            "vineyard_socket": self._engine_config["vineyard_socket"],
            "vineyard_id": graph.vineyard_id,
            "fragments": graph.fragments,
            "num_servers": len(graph.fragments),
            "num_clients": num_clients,
        }
        manifest_params = {
            "NUM_CLIENT_NODES": handle["num_clients"],
            "NUM_SERVER_NODES": handle["num_servers"],
            "NUM_WORKER_REPLICAS": handle["num_clients"] - 1,
        }
        if manifest_path is not None:
            handle["manifest"] = fill_params_in_yaml(manifest_path, manifest_params)
        if client_folder_path is not None:
            handle["client_content"] = read_folder_files_content(client_folder_path)

        handle = base64.b64encode(
            json.dumps(handle).encode("utf-8", errors="ignore")
        ).decode("utf-8", errors="ignore")
        config = {
            "edges": edges,
            "edge_weights": edge_weights,
            "node_features": node_features,
            "edge_features": edge_features,
            "node_labels": node_labels,
            "edge_dir": edge_dir,
            "random_node_split": random_node_split,
        }
        GLTorchGraph.check_params(graph.schema, config)
        config = GLTorchGraph.transform_config(config)
        config = base64.b64encode(
            json.dumps(config).encode("utf-8", errors="ignore")
        ).decode("utf-8", errors="ignore")
        handle, config, endpoints = self._grpc_client.create_learning_instance(
            graph.vineyard_id,
            handle,
            config,
            message_pb2.LearningBackend.GRAPHLEARN_TORCH,
        )

        feature_store = GSFeatureStore(
            handle=handle, config=config, endpoints=endpoints, graph=graph
        )
        graph_store = GSGraphStore(
            handle=handle, config=config, endpoints=endpoints, graph=graph
        )

        learning_instance = (feature_store, graph_store)
        self._learning_instance_dict[graph.vineyard_id] = learning_instance
        graph._attach_learning_instance(learning_instance)
        return feature_store, graph_store

    def nx(self):
        if not self.eager():
@@ -1700,3 +1777,34 @@ def graphlearn_torch(
        manifest_path,
        client_folder_path,
    )  # pylint: disable=protected-access

def PyG_remote_backend(
    graph,
    edges,
    edge_weights=None,
    node_features=None,
    edge_features=None,
    node_labels=None,
    edge_dir="out",
    random_node_split=None,
    num_clients=1,
    manifest_path=None,
    client_folder_path="./",
):
    assert graph is not None, "graph cannot be None"
    assert graph._session is not None, "The graph object is invalid"
    return graph._session.PyG_remote_backend(
        graph,
        edges,
        edge_weights,
        node_features,
        edge_features,
        node_labels,
        edge_dir,
        random_node_split,
        num_clients,
        manifest_path,
        client_folder_path,
    )
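
For reviewers, a minimal usage sketch of the new entry point. The dataset, the edge triplet, and the feature/label column names below are assumptions borrowed from GraphScope's ogbn_mag examples, not taken from this PR; only the PyG_remote_backend signature and its (feature_store, graph_store) return value come from the diff above.

# Hypothetical usage sketch; dataset and column names are assumed.
import graphscope
from graphscope.dataset import load_ogbn_mag
from torch_geometric.loader import NeighborLoader

sess = graphscope.session(cluster_type="hosts")
g = load_ogbn_mag(sess)

# The new entry point returns the (GSFeatureStore, GSGraphStore) pair.
feature_store, graph_store = graphscope.PyG_remote_backend(
    g,
    edges=[("paper", "cites", "paper")],
    node_features={"paper": [f"feat_{i}" for i in range(128)]},
    node_labels={"paper": "label"},
    edge_dir="out",
    num_clients=1,
)

# PyG accepts a (FeatureStore, GraphStore) tuple as a remote graph backend.
loader = NeighborLoader(
    data=(feature_store, graph_store),
    num_neighbors=[10, 5],
    batch_size=128,
    input_nodes="paper",
)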
148 changes: 148 additions & 0 deletions python/graphscope/learning/GSFeatureStore.py
@@ -0,0 +1,148 @@
import base64
import json
from multiprocessing.reduction import ForkingPickler
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

import torch
from torch import Tensor
from torch_geometric.data import FeatureStore
from torch_geometric.data import TensorAttr
from torch_geometric.typing import FeatureTensorType

from graphscope.learning.graphlearn_torch.data import DeviceGroup
from graphscope.learning.graphlearn_torch.data import Feature

KeyType = Tuple[Optional[str], Optional[str]]


class GSFeatureStore(FeatureStore):
    def __init__(self, endpoints, handle=None, config=None, graph=None) -> None:
        super().__init__()
        # self.store: Dict[KeyType, Tuple[Tensor, Tensor]] = {}
        self.handle = handle
        self.config = config

        # Default to None so the accessors below are safe when no config is given.
        self.edge_features = None
        self.node_features = None
        self.node_labels = None
        self.edges = None
        if config is not None:
            config = json.loads(
                base64.b64decode(config.encode("utf-8", errors="ignore")).decode(
                    "utf-8", errors="ignore"
                )
            )
            self.edge_features = config["edge_features"]
            self.node_features = config["node_features"]
            self.node_labels = config["node_labels"]
            self.edges = config["edges"]

        self.endpoints = endpoints

    @staticmethod
    def key(attr: TensorAttr) -> KeyType:
        return (attr.group_name, attr.attr_name)

    def _put_tensor(self, tensor: FeatureTensorType, attr: TensorAttr) -> bool:
        r"""To be implemented by :class:`GSFeatureStore`."""
        raise NotImplementedError

    def _get_tensor(self, attr: TensorAttr) -> Optional[Tensor]:
        r"""To be implemented by :class:`GSFeatureStore`."""
        raise NotImplementedError

    def _remove_tensor(self, attr: TensorAttr) -> bool:
        r"""To be implemented by :class:`GSFeatureStore`."""
        raise NotImplementedError

    def _get_tensor_size(self, attr: TensorAttr) -> Optional[Tuple[int, ...]]:
        if self.node_features is not None:
            node_tensor = self.node_features.get(attr.group_name)
            if node_tensor is not None:
                return (len(node_tensor),)
        if self.edge_features is not None:
            edge_tensor = self.edge_features.get(attr.group_name)
            if edge_tensor is not None:
                return (len(edge_tensor),)
        return None

Review thread on _get_tensor_size:

Collaborator: How do these methods work: _get_tensor_size, get_all_tensor_attrs, _build_features?

Author: TensorAttr stores the attributes that uniquely identify a vertex/edge feature. _get_tensor_size returns the length of the vertex feature corresponding to a TensorAttr. get_all_tensor_attrs retrieves all TensorAttrs that exist in the FeatureStore. _build_features is what I referred to in dist_data earlier; it is unused here and I forgot to delete it.

(A usage sketch of these accessors follows the file listing below.)

    def get_all_tensor_attrs(self) -> List[TensorAttr]:
        tensor_attrs = []
        if self.node_features is not None:
            for node_type, node_features in self.node_features.items():
                for idx, node_feature in enumerate(node_features):
                    tensor_attrs.append(
                        TensorAttr(node_type, node_feature, torch.tensor([idx]))
                    )
        if self.edge_features is not None:
            for edge_type, edge_features in self.edge_features.items():
                for idx, edge_feature in enumerate(edge_features):
                    tensor_attrs.append(
                        TensorAttr(edge_type, edge_feature, torch.tensor([idx]))
                    )
        return tensor_attrs

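    # NOTE (from the review thread above): this helper is carried over from
    # dist_data, is unused in GSFeatureStore, and is slated for removal.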
    def _build_features(
        self,
        feature_data,
        id2idx,
        split_ratio: Union[float, Dict[str, float]] = 0.0,
        device_group_list: Optional[List[DeviceGroup]] = None,
        device: Optional[int] = None,
        with_gpu: bool = True,
        dtype: Optional[torch.dtype] = None,
    ):
        r"""Build `Feature`s for node/edge feature data."""
        if feature_data is not None:
            if isinstance(feature_data, dict):
                # Heterogeneous graph: one Feature per graph type.
                if not isinstance(split_ratio, dict):
                    split_ratio = {
                        graph_type: float(split_ratio)
                        for graph_type in feature_data.keys()
                    }

                if id2idx is not None:
                    assert isinstance(id2idx, dict)
                else:
                    id2idx = {}

                features = {}
                for graph_type, feat in feature_data.items():
                    features[graph_type] = Feature(
                        feat,
                        id2idx.get(graph_type, None),
                        split_ratio.get(graph_type, 0.0),
                        device_group_list,
                        device,
                        with_gpu,
                        dtype if dtype is not None else feat.dtype,
                    )
            else:
                # Homogeneous graph: a single Feature.
                features = Feature(
                    feature_data,
                    id2idx,
                    float(split_ratio),
                    device_group_list,
                    device,
                    with_gpu,
                    dtype if dtype is not None else feature_data.dtype,
                )
        else:
            features = None

        return features

    @classmethod
    def from_ipc_handle(cls, ipc_handle):
        return cls(*ipc_handle)

    def share_ipc(self):
        ipc_handle = (list(self.endpoints), self.handle, self.config)
        return ipc_handle


## Pickling Registration

def rebuild_featurestore(ipc_handle):
    return GSFeatureStore.from_ipc_handle(ipc_handle)


def reduce_featurestore(store: GSFeatureStore):
    ipc_handle = store.share_ipc()
    return (rebuild_featurestore, (ipc_handle,))


# Let ForkingPickler move GSFeatureStore across worker processes by sharing
# its lightweight IPC handle instead of pickling the store itself.
ForkingPickler.register(GSFeatureStore, reduce_featurestore)
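
As noted in the review thread above, a short sketch of how PyG code would exercise these accessors. It continues from the PyG_remote_backend example earlier (feature_store is the GSFeatureStore returned there), and the "paper"/"feat_0" names are the same assumed columns; get_tensor_size is torch_geometric's public wrapper around the _get_tensor_size implemented here.

# Hypothetical sketch, reusing `feature_store` from the earlier example.
from torch_geometric.data import TensorAttr

# Enumerate every feature the store exposes as (group_name, attr_name, index).
for attr in feature_store.get_all_tensor_attrs():
    print(attr.group_name, attr.attr_name, attr.index)

# Size of one vertex feature group; with the assumed 128-column config this
# returns (128,), the length of the feature list for "paper".
size = feature_store.get_tensor_size(TensorAttr("paper", "feat_0"))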