[GPartition/SageMaker] Improvements to partition for gpartition and SageMaker.

We use this commit as a base for range partitioning.
thvasilo committed Jun 17, 2024
1 parent 5199149 commit ee7cc86
Showing 6 changed files with 205 additions and 111 deletions.
1 change: 1 addition & 0 deletions python/graphstorm/gpartition/__init__.py
@@ -18,3 +18,4 @@
from .random_partition import (RandomPartitionAlgorithm)
from .metis_partition import (ParMetisPartitionAlgorithm)
from .partition_config import (ParMETISConfig)
from .partition_algo_base import LocalPartitionAlgorithm
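
LocalPartitionAlgorithm is now exported from the gpartition package, so new strategies (such as the range partitioning this commit is a base for) can be plugged in from outside the module. A minimal sketch of a custom subclass, assuming the base class dispatches to the same _assign_partitions(num_partitions, partition_dir) hook that RandomPartitionAlgorithm overrides below and that it writes partition_meta.json itself; the class name here is hypothetical:

import os

import numpy as np

from graphstorm.gpartition import LocalPartitionAlgorithm

class ContiguousRangePartition(LocalPartitionAlgorithm):
    """Hypothetical sketch: assign each node type to contiguous id blocks."""

    def _assign_partitions(self, num_partitions: int, partition_dir: str):
        ntypes = self.metadata_dict["node_type"]
        counts = self.metadata_dict["num_nodes_per_type"]
        for ntype, num_nodes in zip(ntypes, counts):
            block_size = -(-num_nodes // num_partitions)  # ceiling division
            assignment = np.repeat(np.arange(num_partitions), block_size)[:num_nodes]
            # One partition id per line, the format the tests below expect
            np.savetxt(os.path.join(partition_dir, f"{ntype}.txt"), assignment, fmt="%d")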
25 changes: 22 additions & 3 deletions python/graphstorm/gpartition/dist_partition_graph.py
@@ -24,13 +24,17 @@
import os
import queue
import time
import shutil
import subprocess
import sys
from typing import Dict
from threading import Thread

from graphstorm.gpartition import (ParMetisPartitionAlgorithm, ParMETISConfig,
RandomPartitionAlgorithm)
from graphstorm.gpartition import (
ParMetisPartitionAlgorithm,
ParMETISConfig,
RandomPartitionAlgorithm,
)
from graphstorm.utils import get_log_level


@@ -133,7 +137,10 @@ def main():
part_assignment_dir)

part_end = time.time()
logging.info("Partition assignment took %f sec", part_end - part_start)
logging.info("Partition assignment with algorithm '%s' took %f sec",
args.partition_algorithm,
part_end - part_start,
)

if args.do_dispatch:
run_build_dglgraph(
@@ -147,6 +154,18 @@

logging.info("DGL graph building took %f sec", part_end - time.time())

# Copy raw_id_mappings to dist_graph if they exist in the input
raw_id_mappings_path = os.path.join(args.input_path, "raw_id_mappings")

if os.path.exists(raw_id_mappings_path):
logging.info("Copying raw_id_mappings to dist_graph")
shutil.copytree(
raw_id_mappings_path,
os.path.join(output_path, 'dist_graph/raw_id_mappings'),
dirs_exist_ok=True,
)


logging.info('Partition assignment and DGL graph creation took %f seconds',
time.time() - start)

10 changes: 8 additions & 2 deletions python/graphstorm/gpartition/random_partition.py
@@ -50,10 +50,16 @@ def _assign_partitions(self, num_partitions: int, partition_dir: str):
logging.info("Generating random partition for node type %s", ntype)
ntype_output = os.path.join(partition_dir, f"{ntype}.txt")

partition_assignment = np.random.randint(0, num_partitions, (num_nodes_for_type,))
partition_dtype = np.uint8 if num_partitions <= 256 else np.uint16

partition_assignment = np.random.randint(
0,
num_partitions,
(num_nodes_for_type,),
dtype=partition_dtype)

arrow_partitions = pa.Table.from_arrays(
[pa.array(partition_assignment, type=pa.int64())],
[pa.array(partition_assignment)],
names=["partition_id"])
options = pa_csv.WriteOptions(include_header=False, delimiter=' ')
pa_csv.write_csv(arrow_partitions, ntype_output, write_options=options)
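
The new dtype selection shrinks the in-memory assignment array by up to 8x compared with NumPy's int64 default, at no loss of information: uint8 covers partition ids 0-255, and uint16 covers anything a realistic cluster would use. A small sketch of the saving:

import numpy as np

num_nodes, num_partitions = 10_000_000, 8
partition_dtype = np.uint8 if num_partitions <= 256 else np.uint16
assignment = np.random.randint(0, num_partitions, (num_nodes,), dtype=partition_dtype)
print(assignment.nbytes)  # 10 MB as uint8, vs. 80 MB for the old int64 default

Dropping type=pa.int64() from pa.array() likewise lets Arrow infer the narrow dtype from the NumPy array instead of widening it back to 64 bits.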
164 changes: 98 additions & 66 deletions python/graphstorm/sagemaker/sagemaker_partition.py
@@ -24,15 +24,20 @@
import time
import subprocess
from threading import Thread, Event
from typing import List

import numpy as np
import boto3
import botocore
import numpy as np
import sagemaker
from joblib import Parallel, delayed

from graphstorm.sagemaker import utils
from .s3_utils import download_data_from_s3, upload_file_to_s3
from .sm_partition_algorithm import (SageMakerRandomPartitioner,
SageMakerPartitionerConfig)
from .sm_partition_algorithm import (
SageMakerPartitionerConfig,
SageMakerRandomPartitioner,
)

DGL_TOOL_PATH = "/root/dgl/tools"

@@ -110,8 +115,8 @@ def launch_build_dglgraph(
return thread


def download_graph(graph_data_s3, graph_config, world_size,
local_rank, local_path, sagemaker_session):
def parallel_download_graph(graph_data_s3, graph_config, world_size,
local_rank, local_path, region):
""" download graph structure data
Parameters
@@ -126,77 +131,81 @@ def download_graph(graph_data_s3, graph_config, world_size,
Path to store graph data
local_path: str
directory path under which the data will be downloaded
sagemaker_session: sagemaker.session.Session
sagemaker_session to run download
region: str
AWS region
Return
------
local_path: str
Local path to downloaded graph data
"""
s3_client = boto3.client(
"s3",
config=botocore.config.Config(max_pool_connections=150),
region_name=region
)
def download_file_locally(relative_filepath: str) -> None:
file_s3_path = os.path.join(graph_data_s3, relative_filepath.strip('./'))
logging.debug("Download %s from %s",
relative_filepath, file_s3_path)
local_dir = local_path \
if len(relative_filepath.rpartition('/')) <= 1 else \
os.path.join(local_path, relative_filepath.rpartition('/')[0])
if not os.path.exists(local_dir):
os.makedirs(local_dir, exist_ok=True)
# TODO: If using joblib processes, we'll need to remove the
# SM session from the closure because it can't be pickled
bucket = file_s3_path.split('/')[2]
key = '/'.join(file_s3_path.split('/')[3:])
s3_client.download_file(
bucket,
key,
os.path.join(local_dir, relative_filepath.rpartition('/')[2])
)
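
The bucket/key split above relies on the fixed shape of an S3 URI: after str.split('/'), index 2 is the bucket and everything from index 3 onward is the key. For example:

file_s3_path = "s3://my-bucket/input/edges/part-00000.csv"
parts = file_s3_path.split('/')
# parts == ['s3:', '', 'my-bucket', 'input', 'edges', 'part-00000.csv']
bucket = parts[2]          # 'my-bucket'
key = '/'.join(parts[3:])  # 'input/edges/part-00000.csv'

As an aside, str.rpartition always returns a 3-tuple, so the len(...) <= 1 branch above can never be taken; flat filenames still resolve correctly because rpartition('/')[0] is then the empty string and os.path.join(local_path, '') only appends a trailing separator.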

def get_local_file_list(file_list: List[str]) -> List[str]:
"""
Get the list of local files from the list of remote files.
"""
# TODO: Use dgl.tools.distpartitioning.utils.generate_read_list
# once we move the code over from DGL
local_file_idxs = np.array_split(np.arange(len(file_list)), world_size)
local_read_list = [file_list[i] for i in local_file_idxs[local_rank]]
return local_read_list
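
Each rank thus receives a nearly equal, disjoint slice of the file list, and the ranks together cover every file exactly once. For example, 10 files across a world size of 3:

import numpy as np

file_list = [f"part-{i}.parquet" for i in range(10)]
world_size = 3
for rank in range(world_size):
    idxs = np.array_split(np.arange(len(file_list)), world_size)[rank]
    print(rank, [file_list[i] for i in idxs])
# 0 ['part-0.parquet', 'part-1.parquet', 'part-2.parquet', 'part-3.parquet']
# 1 ['part-4.parquet', 'part-5.parquet', 'part-6.parquet']
# 2 ['part-7.parquet', 'part-8.parquet', 'part-9.parquet']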

# download edge info
edges = graph_config["edges"]
for etype, edge_data in edges.items():
if local_rank == 0:
logging.info("Downloading edge structure for edge type '%s'", etype)
edge_file_list = edge_data["data"]
read_list = np.array_split(np.arange(len(edge_file_list)), world_size)
for i, efile in enumerate(edge_file_list):
# TODO: Only download round-robin if ParMETIS will run, skip otherwise
# Download files both in round robin and sequential assignment
if i % world_size == local_rank or i in read_list[local_rank]:
efile = edge_file_list[i]
file_s3_path = os.path.join(graph_data_s3, efile.strip('./'))
logging.debug("Download %s from %s",
efile, file_s3_path)
local_dir = local_path \
if len(efile.rpartition('/')) <= 1 else \
os.path.join(local_path, efile.rpartition('/')[0])
download_data_from_s3(file_s3_path, local_dir,
sagemaker_session=sagemaker_session)

# download node feature
# We create a pool of workers and re-use it at every step
with Parallel(n_jobs=min(16, os.cpu_count() or 16), prefer='threads') as parallel:
# download edge structures
for etype, edge_data in edges.items():
if local_rank == 0:
logging.info("Downloading edge structure for edge type '%s'", etype)
# TODO: If using ParMETIS also download edge structures in round-robin assignment
local_efile_list = get_local_file_list(edge_data["data"])
parallel(delayed(
download_file_locally)(efile) for efile in local_efile_list)

# download node features
node_data = graph_config["node_data"]
for ntype, ndata in node_data.items():
for feat_name, feat_data in ndata.items():
for nfeat_name, nfeat_data in ndata.items():
if local_rank == 0:
logging.info("Downloading node feature '%s' of node type '%s'",
feat_name, ntype)
num_files = len(feat_data["data"])
# TODO: Use dgl.tools.distpartitioning.utils.generate_read_list
# once we move the code over from DGL
read_list = np.array_split(np.arange(num_files), world_size)
for i in read_list[local_rank].tolist():
nf_file = feat_data["data"][i]
file_s3_path = os.path.join(graph_data_s3, nf_file.strip('./'))
logging.debug("Download %s from %s",
nf_file, file_s3_path)
local_dir = local_path \
if len(nf_file.rpartition('/')) <= 1 else \
os.path.join(local_path, nf_file.rpartition('/')[0])
download_data_from_s3(file_s3_path, local_dir,
sagemaker_session=sagemaker_session)

# download edge feature
nfeat_name, ntype)
local_nfeature_list = get_local_file_list(nfeat_data["data"])
parallel(delayed(
download_file_locally)(nfeature_file) for nfeature_file in local_nfeature_list)

# download edge features
edge_data = graph_config["edge_data"]
for e_feat_type, edata in edge_data.items():
for feat_name, feat_data in edata.items():
for efeat_name, efeat_data in edata.items():
if local_rank == 0:
logging.info("Downloading edge feature '%s' of '%s'",
feat_name, e_feat_type)
num_files = len(feat_data["data"])
read_list = np.array_split(np.arange(num_files), world_size)
for i in read_list[local_rank].tolist():
ef_file = feat_data["data"][i]
file_s3_path = os.path.join(graph_data_s3, ef_file.strip('./'))
logging.debug("Download %s from %s",
ef_file, file_s3_path)
local_dir = local_path \
if len(ef_file.rpartition('/')) <= 1 else \
os.path.join(local_path, ef_file.rpartition('/')[0])
download_data_from_s3(file_s3_path, local_dir,
sagemaker_session=sagemaker_session)
logging.info("Downloading edge feature '%s' of edge type '%s'",
efeat_name, e_feat_type)
local_efeature_list = get_local_file_list(efeat_data["data"])
parallel(delayed(
download_file_locally)(efeature_file) for efeature_file in local_efeature_list)

return local_path
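
Wrapping all three download phases in a single with Parallel(...) as parallel block reuses one worker pool across the calls instead of spawning a fresh one each time, and prefer='threads' suits this workload because S3 downloads are I/O-bound and release the GIL while waiting on the network. The pattern in isolation (fetch is a stand-in for download_file_locally):

import time
from joblib import Parallel, delayed

def fetch(name):
    time.sleep(0.1)  # stand-in for an S3 download
    return name

with Parallel(n_jobs=8, prefer="threads") as parallel:
    edges = parallel(delayed(fetch)(f) for f in ["e0.csv", "e1.csv"])
    nfeats = parallel(delayed(fetch)(f) for f in ["n0.npy"])  # same pool, no respawn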

@@ -215,8 +224,13 @@ def run_partition(job_config: PartitionJobConfig):
metadata_filename = job_config.metadata_filename
skip_partitioning = job_config.skip_partitioning == 'true'

with open("/opt/ml/config/resourceconfig.json", "r", encoding="utf-8") as f:
sm_env = json.load(f)
# Get env from either processing job or training job
try:
with open("/opt/ml/config/resourceconfig.json", "r", encoding="utf-8") as f:
sm_env = json.load(f)
except FileNotFoundError:
sm_env = json.loads(os.environ['SM_TRAINING_ENV'])

hosts = sm_env['hosts']
current_host = sm_env['current_host']
world_size = len(hosts)
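
Both sources describe the same cluster topology: SageMaker processing jobs mount /opt/ml/config/resourceconfig.json, while training jobs expose the SM_TRAINING_ENV environment variable as a JSON string. A sketch of the shape this code relies on (host names are illustrative, and SM_TRAINING_ENV carries additional keys beyond these):

sm_env = {
    "current_host": "algo-1",
    "hosts": ["algo-1", "algo-2", "algo-3", "algo-4"],
}
world_size = len(sm_env["hosts"])  # 4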
@@ -291,13 +305,13 @@ def run_partition(job_config: PartitionJobConfig):

logging.info("Downloading graph data from %s into %s",
graph_data_s3, tmp_data_path)
graph_data_path = download_graph(
graph_data_path = parallel_download_graph(
graph_data_s3,
graph_config,
world_size,
host_rank,
tmp_data_path,
sagemaker_session)
os.environ['AWS_REGION'])

partition_config = SageMakerPartitionerConfig(
metadata_file=meta_info_file,
@@ -362,7 +376,7 @@ def data_dispatch_step(partition_dir):
build_dglgraph_task.join()
err_code = state_q.get()
if err_code != 0:
raise RuntimeError("build dglgrah failed")
raise RuntimeError("build dglgraph failed")

task_end = Event()
thread = Thread(target=utils.keep_alive,
@@ -387,4 +401,22 @@ def data_dispatch_step(partition_dir):
upload_file_to_s3(s3_dglgraph_output, dglgraph_output, sagemaker_session)
logging.info("Rank %s completed all tasks, exiting...", host_rank)

# Leader instance copies raw_id_mappings from input to dist_graph output on S3
# using aws cli, usually much faster than using boto3
if host_rank == 0:
raw_id_mappings_s3_path = os.path.join(graph_data_s3, "raw_id_mappings")
# Copy raw_id_mappings from input to dist_graph output on S3
subprocess.call([
"aws", "configure", "set", "default.s3.max_concurrent_requests", "150"])
subprocess.check_call(
[
"aws",
"s3",
"sync",
"--only-show-errors",
"--region",
os.environ["AWS_REGION"],
raw_id_mappings_s3_path,
f"{s3_dglgraph_output}/raw_id_mappings"])

sock.close()
73 changes: 73 additions & 0 deletions tests/unit-tests/gpartition/conftest.py
@@ -0,0 +1,73 @@
"""
Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
import json
from tempfile import TemporaryDirectory
from typing import Dict

import pytest

from graphstorm.gpartition import LocalPartitionAlgorithm

@pytest.fixture(scope="module", name="chunked_metadata_dict")
def metadata_dict_fixture() -> Dict:
return {
"num_nodes_per_type": [10, 20],
"node_type": ["a", "b"],
}

def simple_test_partition(
partition_algorithm: LocalPartitionAlgorithm,
algorithm_name: str,
chunked_metadata_dict: Dict):
"""Ensures that the provided algorithm will create the correct number of partitions,
and that the partition assignment and metadata files are correctly created.
Parameters
----------
partition_algorithm : LocalPartitionAlgorithm
An implementation of the LocalPartitionAlgorithm base class.
algorithm_name : str
The expected name for the partition algorithm in the metadata
chunked_metadata_dict : Dict
The metadata dictionary that was passed to the algorithm.
"""

with TemporaryDirectory() as tmpdir:
num_parts = 4
# This test function is designed to be used with the config
# provided by metadata_dict_fixture()
assert partition_algorithm.metadata_dict == chunked_metadata_dict
partition_algorithm.create_partitions(num_parts, tmpdir)

assert os.path.exists(os.path.join(tmpdir, "a.txt"))
assert os.path.exists(os.path.join(tmpdir, "b.txt"))
assert os.path.exists(os.path.join(tmpdir, "partition_meta.json"))

# Ensure contents of partition_meta.json are correct
with open(os.path.join(tmpdir, "partition_meta.json"), 'r', encoding="utf-8") as f:
part_meta = json.load(f)
assert part_meta["num_parts"] == num_parts
assert part_meta["algo_name"] == algorithm_name

# Ensure contents of partition assignment files are correct
for i, node_type in enumerate(chunked_metadata_dict["node_type"]):
with open(os.path.join(tmpdir, f"{node_type}.txt"), "r", encoding="utf-8") as f:
node_partitions = f.read().splitlines()
assert len(node_partitions) == chunked_metadata_dict["num_nodes_per_type"][i]
for part_id in node_partitions:
assert part_id.isdigit()
assert int(part_id) < num_parts
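
A test module can then exercise any concrete algorithm through this helper. A sketch for the random partitioner, assuming its constructor takes the metadata dict and that it registers itself under the name "random" (both are assumptions about the current API):

from graphstorm.gpartition import RandomPartitionAlgorithm

from conftest import simple_test_partition  # hypothetical import path

def test_random_partition(chunked_metadata_dict):
    algo = RandomPartitionAlgorithm(chunked_metadata_dict)
    simple_test_partition(algo, "random", chunked_metadata_dict)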
