diff --git a/docs/source/scale/sagemaker.rst b/docs/source/scale/sagemaker.rst index 937a26b900..ffc7401156 100644 --- a/docs/source/scale/sagemaker.rst +++ b/docs/source/scale/sagemaker.rst @@ -237,7 +237,8 @@ data for distributed training. Running the above will take the dataset in chunked format from ``${DATASET_S3_PATH}`` as input and create a DistDGL graph with ``${NUM_PARTITIONS}`` under the output path, ``${OUTPUT_PATH}``. -Currently we only support ``random`` as the partitioning algorithm. +Currently we support ``random`` and ``range`` partition +assignment algorithms. Passing additional arguments to the SageMaker ````````````````````````````````````````````` diff --git a/python/graphstorm/gpartition/__init__.py b/python/graphstorm/gpartition/__init__.py index 2006e0fdeb..32e23ae66a 100644 --- a/python/graphstorm/gpartition/__init__.py +++ b/python/graphstorm/gpartition/__init__.py @@ -15,6 +15,8 @@ Modules for local graph partitioning. """ -from .random_partition import (RandomPartitionAlgorithm) -from .metis_partition import (ParMetisPartitionAlgorithm) -from .partition_config import (ParMETISConfig) +from .metis_partition import ParMetisPartitionAlgorithm +from .partition_algo_base import LocalPartitionAlgorithm +from .partition_config import ParMETISConfig +from .random_partition import RandomPartitionAlgorithm +from .range_partition import RangePartitionAlgorithm diff --git a/python/graphstorm/gpartition/dist_partition_graph.py b/python/graphstorm/gpartition/dist_partition_graph.py index a901f93873..cd5c44356c 100644 --- a/python/graphstorm/gpartition/dist_partition_graph.py +++ b/python/graphstorm/gpartition/dist_partition_graph.py @@ -24,13 +24,18 @@ import os import queue import time +import shutil import subprocess import sys from typing import Dict from threading import Thread -from graphstorm.gpartition import (ParMetisPartitionAlgorithm, ParMETISConfig, - RandomPartitionAlgorithm) +from graphstorm.gpartition import ( + ParMetisPartitionAlgorithm, + ParMETISConfig, + RandomPartitionAlgorithm, + RangePartitionAlgorithm, +) from graphstorm.utils import get_log_level @@ -122,6 +127,8 @@ def main(): partition_config = ParMETISConfig(args.ip_config, args.input_path, args.dgl_tool_path, args.metadata_filename) partitioner = ParMetisPartitionAlgorithm(metadata_dict, partition_config) + elif args.partition_algorithm == "range": + partitioner = RangePartitionAlgorithm(metadata_dict) else: raise RuntimeError(f"Unknown partition algorithm {args.part_algorithm}") @@ -133,7 +140,10 @@ def main(): part_assignment_dir) part_end = time.time() - logging.info("Partition assignment took %f sec", part_end - part_start) + logging.info("Partition assignment with algorithm '%s' took %f sec", + args.partition_algorithm, + part_end - part_start, + ) if args.do_dispatch: run_build_dglgraph( @@ -147,6 +157,18 @@ def main(): logging.info("DGL graph building took %f sec", part_end - time.time()) + # Copy raw_id_mappings to dist_graph if they exist in the input + raw_id_mappings_path = os.path.join(args.input_path, "raw_id_mappings") + + if os.path.exists(raw_id_mappings_path): + logging.info("Copying raw_id_mappings to dist_graph") + shutil.copytree( + raw_id_mappings_path, + os.path.join(output_path, 'dist_graph/raw_id_mappings'), + dirs_exist_ok=True, + ) + + logging.info('Partition assignment and DGL graph creation took %f seconds', time.time() - start) @@ -166,7 +188,7 @@ def parse_args() -> argparse.Namespace: argparser.add_argument("--dgl-tool-path", type=str, default="/root/dgl/tools", help="The path to dgl/tools") argparser.add_argument("--partition-algorithm", type=str, default="random", - choices=["random", "parmetis"], help="Partition algorithm to use.") + choices=["random", "parmetis", "range"], help="Partition algorithm to use.") argparser.add_argument("--ip-config", type=str, help=("A file storing a list of IPs, one line for " "each instance of the partition cluster.")) diff --git a/python/graphstorm/gpartition/partition_algo_base.py b/python/graphstorm/gpartition/partition_algo_base.py index 6c95cc5257..1db424996e 100644 --- a/python/graphstorm/gpartition/partition_algo_base.py +++ b/python/graphstorm/gpartition/partition_algo_base.py @@ -71,7 +71,8 @@ def create_partitions(self, num_partitions: int, partition_assignment_dir: str): @abstractmethod def _assign_partitions(self, num_partitions: int, partition_dir: str): """Assigns each node in the data to a partition from 0 to `num_partitions-1`, - and creates one "{ntype}>.txt" partition assignment file per node type. + and creates one "{ntype}.txt" partition assignment file per node type + under the ``partition_dir``. Parameters ---------- diff --git a/python/graphstorm/gpartition/random_partition.py b/python/graphstorm/gpartition/random_partition.py index 8c8e9e86f5..9f992d6f56 100644 --- a/python/graphstorm/gpartition/random_partition.py +++ b/python/graphstorm/gpartition/random_partition.py @@ -50,10 +50,16 @@ def _assign_partitions(self, num_partitions: int, partition_dir: str): logging.info("Generating random partition for node type %s", ntype) ntype_output = os.path.join(partition_dir, f"{ntype}.txt") - partition_assignment = np.random.randint(0, num_partitions, (num_nodes_for_type,)) + partition_dtype = np.uint8 if num_partitions <= 256 else np.uint16 + + partition_assignment = np.random.randint( + 0, + num_partitions, + (num_nodes_for_type,), + dtype=partition_dtype) arrow_partitions = pa.Table.from_arrays( - [pa.array(partition_assignment, type=pa.int64())], + [pa.array(partition_assignment)], names=["partition_id"]) options = pa_csv.WriteOptions(include_header=False, delimiter=' ') pa_csv.write_csv(arrow_partitions, ntype_output, write_options=options) diff --git a/python/graphstorm/gpartition/range_partition.py b/python/graphstorm/gpartition/range_partition.py new file mode 100644 index 0000000000..0443784215 --- /dev/null +++ b/python/graphstorm/gpartition/range_partition.py @@ -0,0 +1,81 @@ +""" + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + Single-instance random partition assignment +""" +import os +import logging +import json +from typing import List + +import numpy as np +import pyarrow as pa +import pyarrow.csv as pa_csv + +from .random_partition import LocalPartitionAlgorithm + +class RangePartitionAlgorithm(LocalPartitionAlgorithm): + """ + Single-instance range partitioning algorithm. + + The partition algorithm accepts the intermediate output from GraphStorm + gs-processing which matches the requirements of the DGL distributed + partitioning pipeline. It sequentially assigns nodes to partitions + and outputs the node assignment results and partition + metadata file to the provided output directory. + + + Parameters + ---------- + metadata: dict + DGL "Chunked graph data" JSON, as defined in + https://docs.dgl.ai/guide/distributed-preprocessing.html#specification + """ + def _assign_partitions(self, num_partitions: int, partition_dir: str): + num_nodes_per_type = self.metadata_dict["num_nodes_per_type"] # type: List[int] + ntypes = self.metadata_dict["node_type"] # type: List[str] + + # Note: This assumes that the order of node_type is the same as the order num_nodes_per_type + for ntype, num_nodes_for_type in zip(ntypes, num_nodes_per_type): + logging.debug("Generating range partition for node type %s", ntype) + ntype_output_path = os.path.join(partition_dir, f"{ntype}.txt") + + partition_dtype = np.uint8 if num_partitions <= 256 else np.uint16 + + assigned_parts = np.array_split( + np.empty(num_nodes_for_type, dtype=partition_dtype), + num_partitions) + + for idx, assigned_part in enumerate(assigned_parts): + assigned_part[:] = idx + + arrow_partitions = pa.Table.from_arrays( + [np.concatenate(assigned_parts)], + names=["partition_id"]) + options = pa_csv.WriteOptions(include_header=False, delimiter=' ') + pa_csv.write_csv(arrow_partitions, ntype_output_path, write_options=options) + + + def _create_metadata(self, num_partitions: int, partition_dir: str) -> None: + # TODO: DGL currently restricts the names we can give in the metadata, will + # fix once https://github.com/dmlc/dgl/pull/7361 is merged into a release + partition_meta = { + "algo_name": "random", + "num_parts": num_partitions, + "version": "1.0.0" + } + partition_meta_filepath = os.path.join(partition_dir, "partition_meta.json") + with open(partition_meta_filepath, "w", encoding='utf-8') as metafile: + json.dump(partition_meta, metafile) diff --git a/python/graphstorm/sagemaker/sagemaker_partition.py b/python/graphstorm/sagemaker/sagemaker_partition.py index 9eb1c3a31f..6976342c6a 100644 --- a/python/graphstorm/sagemaker/sagemaker_partition.py +++ b/python/graphstorm/sagemaker/sagemaker_partition.py @@ -24,15 +24,21 @@ import time import subprocess from threading import Thread, Event +from typing import List -import numpy as np import boto3 +import botocore +import numpy as np import sagemaker +from joblib import Parallel, delayed # pylint: disable=wrong-import-order from graphstorm.sagemaker import utils from .s3_utils import download_data_from_s3, upload_file_to_s3 -from .sm_partition_algorithm import (SageMakerRandomPartitioner, - SageMakerPartitionerConfig) +from .sm_partition_algorithm import ( + SageMakerPartitionerConfig, + SageMakerRandomPartitioner, + SageMakerRangePartitioner, + ) DGL_TOOL_PATH = "/root/dgl/tools" @@ -110,8 +116,8 @@ def launch_build_dglgraph( return thread -def download_graph(graph_data_s3, graph_config, world_size, - local_rank, local_path, sagemaker_session): +def parallel_download_graph(graph_data_s3, graph_config, world_size, + local_rank, local_path, region): """ download graph structure data Parameters @@ -126,77 +132,81 @@ def download_graph(graph_data_s3, graph_config, world_size, Path to store graph data local_path: str directory path under which the data will be downloaded - sagemaker_session: sagemaker.session.Session - sagemaker_session to run download + region: str + AWS region Return ------ local_path: str Local path to downloaded graph data """ + s3_client = boto3.client( + "s3", + config=botocore.config.Config(max_pool_connections=150), + region_name=region + ) + def download_file_locally(relative_filepath: str) -> None: + file_s3_path = os.path.join(graph_data_s3, relative_filepath.strip('./')) + logging.debug("Download %s from %s", + relative_filepath, file_s3_path) + local_dir = local_path \ + if len(relative_filepath.rpartition('/')) <= 1 else \ + os.path.join(local_path, relative_filepath.rpartition('/')[0]) + if not os.path.exists(local_dir): + os.makedirs(local_dir, exist_ok=True) + # TODO: If using joblib processes, we'll need to remove the + # SM session from the closure because it can't be pickled + bucket = file_s3_path.split('/')[2] + key = '/'.join(file_s3_path.split('/')[3:]) + s3_client.download_file( + bucket, + key, + os.path.join(local_dir, relative_filepath.rpartition('/')[2]) + ) + + def get_local_file_list(file_list: List[str]) -> List[str]: + """ + Get the list of local files from the list of remote files. + """ + # TODO: Use dgl.tools.distpartitioning.utils.generate_read_list + # once we move the code over from DGL + local_file_idxs = np.array_split(np.arange(len(file_list)), world_size) + local_read_list = [file_list[i] for i in local_file_idxs[local_rank]] + return local_read_list - # download edge info edges = graph_config["edges"] - for etype, edge_data in edges.items(): - if local_rank == 0: - logging.info("Downloading edge structure for edge type '%s'", etype) - edge_file_list = edge_data["data"] - read_list = np.array_split(np.arange(len(edge_file_list)), world_size) - for i, efile in enumerate(edge_file_list): - # TODO: Only download round-robin if ParMETIS will run, skip otherwise - # Download files both in round robin and sequential assignment - if i % world_size == local_rank or i in read_list[local_rank]: - efile = edge_file_list[i] - file_s3_path = os.path.join(graph_data_s3, efile.strip('./')) - logging.debug("Download %s from %s", - efile, file_s3_path) - local_dir = local_path \ - if len(efile.rpartition('/')) <= 1 else \ - os.path.join(local_path, efile.rpartition('/')[0]) - download_data_from_s3(file_s3_path, local_dir, - sagemaker_session=sagemaker_session) - - # download node feature + # We create a pool of workers and re-use it at every step + with Parallel(n_jobs=min(16, os.cpu_count() or 16), prefer='threads') as parallel: + # download edge structures + for etype, edge_data in edges.items(): + if local_rank == 0: + logging.info("Downloading edge structure for edge type '%s'", etype) + # TODO: If using ParMETIS also download edge structures in round-robin assignment + local_efile_list = get_local_file_list(edge_data["data"]) + parallel(delayed( + download_file_locally)(efile) for efile in local_efile_list) + + # download node features node_data = graph_config["node_data"] for ntype, ndata in node_data.items(): - for feat_name, feat_data in ndata.items(): + for nfeat_name, nfeat_data in ndata.items(): if local_rank == 0: logging.info("Downloading node feature '%s' of node type '%s'", - feat_name, ntype) - num_files = len(feat_data["data"]) - # TODO: Use dgl.tools.distpartitioning.utils.generate_read_list - # once we move the code over from DGL - read_list = np.array_split(np.arange(num_files), world_size) - for i in read_list[local_rank].tolist(): - nf_file = feat_data["data"][i] - file_s3_path = os.path.join(graph_data_s3, nf_file.strip('./')) - logging.debug("Download %s from %s", - nf_file, file_s3_path) - local_dir = local_path \ - if len(nf_file.rpartition('/')) <= 1 else \ - os.path.join(local_path, nf_file.rpartition('/')[0]) - download_data_from_s3(file_s3_path, local_dir, - sagemaker_session=sagemaker_session) - - # download edge feature + nfeat_name, ntype) + local_nfeature_list = get_local_file_list(nfeat_data["data"]) + parallel(delayed( + download_file_locally)(nfeature_file) for nfeature_file in local_nfeature_list) + + # download edge features edge_data = graph_config["edge_data"] for e_feat_type, edata in edge_data.items(): - for feat_name, feat_data in edata.items(): + for efeat_name, efeat_data in edata.items(): if local_rank == 0: - logging.info("Downloading edge feature '%s' of '%s'", - feat_name, e_feat_type) - num_files = len(feat_data["data"]) - read_list = np.array_split(np.arange(num_files), world_size) - for i in read_list[local_rank].tolist(): - ef_file = feat_data["data"][i] - file_s3_path = os.path.join(graph_data_s3, ef_file.strip('./')) - logging.debug("Download %s from %s", - ef_file, file_s3_path) - local_dir = local_path \ - if len(ef_file.rpartition('/')) <= 1 else \ - os.path.join(local_path, ef_file.rpartition('/')[0]) - download_data_from_s3(file_s3_path, local_dir, - sagemaker_session=sagemaker_session) + logging.info("Downloading edge feature '%s' of edge type '%s'", + efeat_name, e_feat_type) + local_efeature_list = get_local_file_list(efeat_data["data"]) + parallel(delayed( + download_file_locally)(efeature_file) for efeature_file in local_efeature_list) return local_path @@ -215,8 +225,13 @@ def run_partition(job_config: PartitionJobConfig): metadata_filename = job_config.metadata_filename skip_partitioning = job_config.skip_partitioning == 'true' - with open("/opt/ml/config/resourceconfig.json", "r", encoding="utf-8") as f: - sm_env = json.load(f) + # Get env from either processing job or training job + try: + with open("/opt/ml/config/resourceconfig.json", "r", encoding="utf-8") as f: + sm_env = json.load(f) + except FileNotFoundError: + sm_env = json.loads(os.environ['SM_TRAINING_ENV']) + hosts = sm_env['hosts'] current_host = sm_env['current_host'] world_size = len(hosts) @@ -291,13 +306,13 @@ def run_partition(job_config: PartitionJobConfig): logging.info("Downloading graph data from %s into %s", graph_data_s3, tmp_data_path) - graph_data_path = download_graph( + graph_data_path = parallel_download_graph( graph_data_s3, graph_config, world_size, host_rank, tmp_data_path, - sagemaker_session) + os.environ['AWS_REGION']) partition_config = SageMakerPartitionerConfig( metadata_file=meta_info_file, @@ -307,6 +322,8 @@ def run_partition(job_config: PartitionJobConfig): if job_config.partition_algorithm == 'random': sm_partitioner = SageMakerRandomPartitioner(partition_config) + elif job_config.partition_algorithm == 'range': + sm_partitioner = SageMakerRangePartitioner(partition_config) else: raise RuntimeError(f"Unknown partition algorithm: '{job_config.partition_algorithm}'", ) @@ -362,7 +379,7 @@ def data_dispatch_step(partition_dir): build_dglgraph_task.join() err_code = state_q.get() if err_code != 0: - raise RuntimeError("build dglgrah failed") + raise RuntimeError("build dglgraph failed") task_end = Event() thread = Thread(target=utils.keep_alive, @@ -387,4 +404,22 @@ def data_dispatch_step(partition_dir): upload_file_to_s3(s3_dglgraph_output, dglgraph_output, sagemaker_session) logging.info("Rank %s completed all tasks, exiting...", host_rank) + # Leader instance copies raw_id_mappings from input to dist_graph output on S3 + # using aws cli, usually much faster than using boto3 + if host_rank == 0: + raw_id_mappings_s3_path = os.path.join(graph_data_s3, "raw_id_mappings") + # Copy raw_id_mappings from input to dist_graph output on S3 + subprocess.call([ + "aws", "configure", "set", "default.s3.max_concurrent_requests", "150"]) + subprocess.check_call( + [ + "aws", + "s3", + "sync", + "--only-show-errors", + "--region", + os.environ["AWS_REGION"], + raw_id_mappings_s3_path, + f"{s3_dglgraph_output}/raw_id_mappings"]) + sock.close() diff --git a/python/graphstorm/sagemaker/sm_partition_algorithm.py b/python/graphstorm/sagemaker/sm_partition_algorithm.py index f935d7ae0f..d402452195 100644 --- a/python/graphstorm/sagemaker/sm_partition_algorithm.py +++ b/python/graphstorm/sagemaker/sm_partition_algorithm.py @@ -24,7 +24,7 @@ from typing import Tuple from sagemaker import Session -from graphstorm.gpartition import RandomPartitionAlgorithm +from graphstorm.gpartition import RandomPartitionAlgorithm, RangePartitionAlgorithm, LocalPartitionAlgorithm from .s3_utils import upload_file_to_s3 @@ -179,19 +179,16 @@ def _upload_results_to_s3(self, local_partition_directory: str, output_s3_path: S3 prefix to upload the partitioning results to. """ -class SageMakerRandomPartitioner(SageMakerPartitioner): # pylint: disable=too-few-public-methods - """ - Single-instance random partitioning algorithm running on SageMaker - """ - def _run_partitioning(self, num_partitions: int) -> str: - random_part = RandomPartitionAlgorithm(self.metadata) +class SageMakerSingleInstancePartitioner(SageMakerPartitioner): + local_partitioner: LocalPartitionAlgorithm + def _run_partitioning(self, num_partitions: int) -> str: part_assignment_dir = os.path.join(self.local_output_path, "partition_assignment") os.makedirs(part_assignment_dir, exist_ok=True) # Only the leader creates partition assignments if self.rank == 0: - random_part.create_partitions(num_partitions, part_assignment_dir) + self.local_partitioner.create_partitions(num_partitions, part_assignment_dir) return part_assignment_dir @@ -205,3 +202,21 @@ def _upload_results_to_s3(self, local_partition_directory: str, output_s3_path: else: # Workers do not hold any partitioning information locally pass + +class SageMakerRandomPartitioner(SageMakerSingleInstancePartitioner): # pylint: disable=too-few-public-methods + """ + Single-instance random partitioning algorithm running on SageMaker + """ + def __init__(self, partition_config: SageMakerPartitionerConfig): + super().__init__(partition_config) + + self.local_partitioner = RandomPartitionAlgorithm(self.metadata) + +class SageMakerRangePartitioner(SageMakerSingleInstancePartitioner): # pylint: disable=too-few-public-methods + """ + Single-instance range partitioning algorithm running on SageMaker + """ + def __init__(self, partition_config: SageMakerPartitionerConfig): + super().__init__(partition_config) + + self.local_partitioner = RangePartitionAlgorithm(self.metadata) diff --git a/sagemaker/launch/launch_partition.py b/sagemaker/launch/launch_partition.py index 4cd2f619ba..28491af7ee 100644 --- a/sagemaker/launch/launch_partition.py +++ b/sagemaker/launch/launch_partition.py @@ -105,7 +105,7 @@ def get_partition_parser(): help="File name of metadata config file for chunked format data") partition_args.add_argument("--partition-algorithm", type=str, default='random', - help="Partition algorithm to use.", choices=['random']) + help="Partition algorithm to use.", choices=['random', 'range']) partition_args.add_argument("--skip-partitioning", action='store_true', help="When set, we skip the partitioning step. " diff --git a/sagemaker/local/generate_sagemaker_docker_compose.py b/sagemaker/local/generate_sagemaker_docker_compose.py index f54e40b033..5c7ce49c0c 100644 --- a/sagemaker/local/generate_sagemaker_docker_compose.py +++ b/sagemaker/local/generate_sagemaker_docker_compose.py @@ -94,7 +94,7 @@ def get_parser(): help="Skip partitioning step and only do GSL object creation. Partition assignments " "need to exist under the /partitions location.") partition_parser.add_argument("--partition-algorithm", required=False, - default='random', choices=['random'], + default='random', choices=['random', 'range'], help="Partition algorithm to use.") partition_parser.add_argument("--metadata-filename", required=False, default="metadata.json", help="Metadata file that describes the files " diff --git a/sagemaker/run/partition_entry.py b/sagemaker/run/partition_entry.py index 23e09f6e0b..7f2ed4dc95 100644 --- a/sagemaker/run/partition_entry.py +++ b/sagemaker/run/partition_entry.py @@ -32,7 +32,7 @@ def partition_arg_parser(): parser.add_argument("--metadata-filename", type=str, default="metadata.json", help="file name of metadata config file") parser.add_argument("--partition-algorithm", type=str, default='random', - choices=['random'], + choices=['random', 'range'], help="Partition algorithm to use.") parser.add_argument("--skip-partitioning", type=str, default='false', choices=['true', 'false'], diff --git a/tests/unit-tests/gpartition/conftest.py b/tests/unit-tests/gpartition/conftest.py new file mode 100644 index 0000000000..6e522e3a5e --- /dev/null +++ b/tests/unit-tests/gpartition/conftest.py @@ -0,0 +1,73 @@ +""" + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" +import os +import json +from tempfile import TemporaryDirectory +from typing import Dict + +import pytest + +from graphstorm.gpartition import LocalPartitionAlgorithm + +@pytest.fixture(scope="module", name="chunked_metadata_dict") +def metadata_dict_fixture() -> Dict: + return { + "num_nodes_per_type": [10, 20], + "node_type": ["a", "b"], + } + +def simple_test_partition( + partition_algorithm: LocalPartitionAlgorithm, + algorithm_name: str, + chunked_metadata_dict: Dict): + """Ensures that the provided algorithm will create the correct number of partitions, + and that the partition assignment and metadata files are correctly created. + + Parameters + ---------- + partition_algorithm : LocalPartitionAlgorithm + An implementation of the LocalPartitionAlgorithm base class. + algorithm_name : str + The expected name for the partition algorithm in the metadata + chunked_metadata_dict : Dict + The metadata dictionary that was passed to the algorithm. + """ + + with TemporaryDirectory() as tmpdir: + num_parts = 4 + # This test function is designed to be used with the config + # provided by metadata_dict_fixture() + assert partition_algorithm.metadata_dict == chunked_metadata_dict + partition_algorithm.create_partitions(num_parts, tmpdir) + + assert os.path.exists(os.path.join(tmpdir, "a.txt")) + assert os.path.exists(os.path.join(tmpdir, "b.txt")) + assert os.path.exists(os.path.join(tmpdir, "partition_meta.json")) + + # Ensure contents of partition_meta.json are correct + with open(os.path.join(tmpdir, "partition_meta.json"), 'r', encoding="utf-8") as f: + part_meta = json.load(f) + assert part_meta["num_parts"] == num_parts + assert part_meta["algo_name"] == algorithm_name + + # Ensure contents of partition assignment files are correct + for i, node_type in enumerate(chunked_metadata_dict["node_type"]): + with open(os.path.join(tmpdir, f"{node_type}.txt"), "r", encoding="utf-8") as f: + node_partitions = f.read().splitlines() + assert len(node_partitions) == chunked_metadata_dict["num_nodes_per_type"][i] + for part_id in node_partitions: + assert part_id.isdigit() + assert int(part_id) < num_parts diff --git a/tests/unit-tests/gpartition/test_random_partition.py b/tests/unit-tests/gpartition/test_random_partition.py index d61a9f7251..5104f7337c 100644 --- a/tests/unit-tests/gpartition/test_random_partition.py +++ b/tests/unit-tests/gpartition/test_random_partition.py @@ -13,48 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. """ -import json -import os -from typing import Dict -from tempfile import TemporaryDirectory - -import pytest - from graphstorm.gpartition import RandomPartitionAlgorithm -@pytest.fixture(scope="module", name="chunked_metadata_dict") -def metadata_dict_fixture() -> Dict: - return { - "num_nodes_per_type": [10, 20], - "node_type": ["a", "b"], - } +from conftest import simple_test_partition def test_create_random_partition(chunked_metadata_dict): - with TemporaryDirectory() as tmpdir: - num_parts = 4 - rand_partitioner = RandomPartitionAlgorithm(chunked_metadata_dict) - rand_partitioner.create_partitions(num_parts, tmpdir) - - assert os.path.exists(os.path.join(tmpdir, "a.txt")) - assert os.path.exists(os.path.join(tmpdir, "b.txt")) - assert os.path.exists(os.path.join(tmpdir, "partition_meta.json")) - - # Ensure contents of partition_meta.json are correct - with open(os.path.join(tmpdir, "partition_meta.json"), 'r', encoding="utf-8") as f: - part_meta = json.load(f) - assert part_meta["num_parts"] == num_parts - assert part_meta["algo_name"] == "random" - - # Ensure contents of partition assignment files are correct - for i, node_type in enumerate(chunked_metadata_dict["node_type"]): - with open(os.path.join(tmpdir, f"{node_type}.txt"), "r", encoding="utf-8") as f: - node_partitions = f.read().splitlines() - assert len(node_partitions) == chunked_metadata_dict["num_nodes_per_type"][i] - for part_id in node_partitions: - assert part_id.isdigit() - assert int(part_id) < num_parts - - -if __name__ == '__main__': - test_create_random_partition(metadata_dict_fixture()) + rand_partitioner = RandomPartitionAlgorithm(chunked_metadata_dict) + simple_test_partition(rand_partitioner, "random", chunked_metadata_dict) diff --git a/tests/unit-tests/gpartition/test_range_partition.py b/tests/unit-tests/gpartition/test_range_partition.py new file mode 100644 index 0000000000..1d7f12108f --- /dev/null +++ b/tests/unit-tests/gpartition/test_range_partition.py @@ -0,0 +1,42 @@ +""" +Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import os +from tempfile import TemporaryDirectory + +from graphstorm.gpartition import RangePartitionAlgorithm + +from conftest import simple_test_partition + + +def test_create_range_partition(chunked_metadata_dict): + range_partitioner = RangePartitionAlgorithm(chunked_metadata_dict) + # TODO: DGL only supports random and metis as a name downstream + simple_test_partition(range_partitioner, "random", chunked_metadata_dict) + + +def test_range_partition_ordered(chunked_metadata_dict): + with TemporaryDirectory() as tmpdir: + num_parts = 8 + range_partitioner = RangePartitionAlgorithm(chunked_metadata_dict) + range_partitioner.create_partitions(num_parts, tmpdir) + for _, node_type in enumerate(chunked_metadata_dict["node_type"]): + with open( + os.path.join(tmpdir, f"{node_type}.txt"), "r", encoding="utf-8" + ) as f: + ntype_partitions = [int(x) for x in f.read().splitlines()] + # Ensure the partition assignments are in increasing order + assert sorted(ntype_partitions) == ntype_partitions