-
Notifications
You must be signed in to change notification settings - Fork 60
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
11 changed files
with
164 additions
and
17 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
""" | ||
Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
Single-instance range partition assignment
""" | ||
import os | ||
import logging | ||
import json | ||
from typing import List | ||
|
||
import numpy as np | ||
import pyarrow as pa | ||
import pyarrow.csv as pa_csv | ||
|
||
from .random_partition import LocalPartitionAlgorithm | ||
|
||
class RangePartitionAlgorithm(LocalPartitionAlgorithm):
    """
    Single-instance range partitioning algorithm.

    The partition algorithm accepts the intermediate output from GraphStorm
    gs-processing which matches the requirements of the DGL distributed
    partitioning pipeline. It sequentially assigns nodes to partitions
    and outputs the node assignment results and partition
    metadata file to the provided output directory.

    Parameters
    ----------
    metadata: dict
        DGL "Chunked graph data" JSON, as defined in
        https://docs.dgl.ai/guide/distributed-preprocessing.html#specification
    """

    def _assign_partitions(self, num_partitions: int, partition_dir: str) -> None:
        """Write one ``<ntype>.txt`` file per node type under ``partition_dir``.

        Each file holds one partition id per line, in node-id order, with
        contiguous (range) assignment: the first chunk of nodes goes to
        partition 0, the next to partition 1, and so on.

        Parameters
        ----------
        num_partitions: int
            Number of partitions to split each node type into.
        partition_dir: str
            Existing directory into which the per-type assignment files
            are written.
        """
        num_nodes_per_type = self.metadata_dict["num_nodes_per_type"]  # type: List[int]
        ntypes = self.metadata_dict["node_type"]  # type: List[str]

        # Note: This assumes that the order of node_type is the same as the
        # order of num_nodes_per_type
        for ntype, num_nodes_for_type in zip(ntypes, num_nodes_per_type):
            logging.debug("Generating range partition for node type %s", ntype)
            ntype_output_path = os.path.join(partition_dir, f"{ntype}.txt")

            # Smallest unsigned dtype that can hold the largest partition id.
            # Generalizes the previous uint8/uint16 choice, which silently
            # overflowed for more than 65536 partitions; for <= 65536
            # partitions the selected dtype is unchanged.
            partition_dtype = np.min_scalar_type(max(num_partitions - 1, 0))

            # np.array_split yields num_partitions contiguous views into the
            # backing array; the first (n % num_partitions) parts get one
            # extra node each. Writing idx into each view fills the backing
            # array with a non-decreasing sequence of partition ids.
            assigned_parts = np.array_split(
                np.empty(num_nodes_for_type, dtype=partition_dtype),
                num_partitions)

            for idx, assigned_part in enumerate(assigned_parts):
                assigned_part[:] = idx

            # Emit a headerless single-column CSV, one partition id per line,
            # as expected by the DGL distributed partitioning pipeline.
            arrow_partitions = pa.Table.from_arrays(
                [np.concatenate(assigned_parts)],
                names=["partition_id"])
            options = pa_csv.WriteOptions(include_header=False, delimiter=' ')
            pa_csv.write_csv(arrow_partitions, ntype_output_path, write_options=options)

    def _create_metadata(self, num_partitions: int, partition_dir: str) -> None:
        """Write ``partition_meta.json`` describing this partition assignment.

        Parameters
        ----------
        num_partitions: int
            Number of partitions that were created.
        partition_dir: str
            Directory in which the metadata file is written.
        """
        # TODO: DGL currently restricts the names we can give in the metadata, will
        # fix once https://github.com/dmlc/dgl/pull/7361 is merged into a release,
        # hence "random" rather than "range" below.
        partition_meta = {
            "algo_name": "random",
            "num_parts": num_partitions,
            "version": "1.0.0"
        }
        partition_meta_filepath = os.path.join(partition_dir, "partition_meta.json")
        with open(partition_meta_filepath, "w", encoding='utf-8') as metafile:
            json.dump(partition_meta, metafile)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
""" | ||
Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
""" | ||
|
||
import os | ||
from tempfile import TemporaryDirectory | ||
|
||
from graphstorm.gpartition import RangePartitionAlgorithm | ||
|
||
from conftest import simple_test_partition | ||
|
||
|
||
def test_create_range_partition(chunked_metadata_dict):
    """Smoke-test range partitioning against the shared partition contract."""
    partitioner = RangePartitionAlgorithm(chunked_metadata_dict)
    # TODO: DGL only supports random and metis as a name downstream,
    # so the reported algorithm name is expected to be "random" for now.
    simple_test_partition(partitioner, "random", chunked_metadata_dict)
|
||
|
||
def test_range_partition_ordered(chunked_metadata_dict):
    """Partition ids emitted for every node type must be non-decreasing.

    Range partitioning assigns contiguous id ranges, so the per-node
    partition ids, read in node-id order, are already sorted.
    """
    with TemporaryDirectory() as tmpdir:
        num_parts = 8
        range_partitioner = RangePartitionAlgorithm(chunked_metadata_dict)
        range_partitioner.create_partitions(num_parts, tmpdir)
        # Iterate node types directly; the previous enumerate() index
        # was unused.
        for node_type in chunked_metadata_dict["node_type"]:
            assignment_path = os.path.join(tmpdir, f"{node_type}.txt")
            with open(assignment_path, "r", encoding="utf-8") as f:
                ntype_partitions = [int(x) for x in f.read().splitlines()]
            # Ensure the partition assignments are in increasing order
            assert sorted(ntype_partitions) == ntype_partitions