Skip to content

Commit

Permalink
[GSProcessing] Fix to make graph_name optional again
Browse files Browse the repository at this point in the history
  • Loading branch information
thvasilo committed Sep 30, 2024
1 parent 7af14b9 commit 620fde3
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 35 deletions.
31 changes: 19 additions & 12 deletions graphstorm-processing/graphstorm_processing/distributed_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
import tempfile
import time
from collections.abc import Mapping
from typing import Any, Dict
from typing import Any, Dict, Optional

import boto3
import botocore
Expand Down Expand Up @@ -106,8 +106,8 @@ class ExecutorConfig:
The filesystem type, can be LOCAL or S3
add_reverse_edges : bool
Whether to create reverse edges for each edge type.
graph_name: str
The name of the graph being processed
graph_name: str, optional
The name of the graph being processed. If not provided we use part of the input_prefix.
do_repartition: bool
Whether to apply repartitioning to the graph on the Spark leader.
"""
Expand All @@ -121,7 +121,7 @@ class ExecutorConfig:
config_filename: str
filesystem_type: FilesystemType
add_reverse_edges: bool
graph_name: str
graph_name: Optional[str]
do_repartition: bool


Expand All @@ -135,7 +135,7 @@ class GSProcessingArguments:
num_output_files: int
add_reverse_edges: bool
log_level: str
graph_name: str
graph_name: Optional[str]
do_repartition: bool


Expand All @@ -162,7 +162,12 @@ def __init__(
self.filesystem_type = executor_config.filesystem_type
self.execution_env = executor_config.execution_env
self.add_reverse_edges = executor_config.add_reverse_edges
self.graph_name = executor_config.graph_name
self.graph_name = (
executor_config.graph_name
if executor_config.graph_name
else s3_utils.s3_path_remove_trailing(executor_config.input_prefix).split("/")[-1]
)
check_graph_name(self.graph_name)
self.repartition_on_leader = executor_config.do_repartition
# Input config dict using GSProcessing schema
self.gsp_config_dict = {}
Expand Down Expand Up @@ -541,11 +546,14 @@ def parse_args() -> argparse.Namespace:
parser.add_argument(
"--graph-name",
type=str,
help="Name for the graph being processed."
"The graph name must adhere to the Python "
"identifier naming rules with the exception "
"that hyphens (-) are permitted and the name "
"can start with numbers",
help=(
"Name for the graph being processed."
"The graph name must adhere to the Python "
"identifier naming rules with the exception "
"that hyphens (-) are permitted and the name "
"can start with numbers. If not provided, we will use the last "
"section of the input prefix path."
),
required=False,
default=None,
)
Expand Down Expand Up @@ -604,7 +612,6 @@ def main():
level=gsprocessing_args.log_level,
format="[GSPROCESSING] %(asctime)s %(levelname)-8s %(message)s",
)
check_graph_name(gsprocessing_args.graph_name)

# Determine execution environment
if os.path.exists("/opt/ml/config/processingjobconfig.json"):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -324,11 +324,6 @@ def process_and_write_graph_data(
self.timers["process_edge_data"] = perf_counter() - edges_start_time
metadata_dict["edge_data"] = edge_data_dict
metadata_dict["edges"] = edge_structure_dict
# We use the data location as the graph name, can also take from user?
# TODO: Fix this, take from config?
metadata_dict["graph_name"] = (
self.graph_name if self.graph_name else self.input_prefix.split("/")[-1]
)

# Ensure output dict has the correct order of keys
for edge_type in metadata_dict["edge_type"]:
Expand Down Expand Up @@ -447,6 +442,10 @@ def _initialize_metadata_dict(
metadata_dict["edge_type"] = edge_types
metadata_dict["node_type"] = sorted(node_type_set)

# We use the data location as the graph name, can also take from user?
# TODO: Fix this, take from config?
metadata_dict["graph_name"] = self.graph_name

return metadata_dict

def _finalize_graphinfo_dict(self, metadata_dict: Dict) -> Dict:
Expand Down
51 changes: 33 additions & 18 deletions graphstorm-processing/tests/test_dist_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,9 @@ def user_state_categorical_precomp_file_fixture():
os.remove(precomp_file)


def test_dist_executor_run_with_precomputed(tempdir: str, user_state_categorical_precomp_file):
"""Test run function with local data"""
@pytest.fixture(name="executor_configuration")
def executor_config_fixture(tempdir: str):
"""Create a re-usable ExecutorConfig"""
input_path = os.path.join(_ROOT, "resources/small_heterogeneous_graph")
executor_configuration = ExecutorConfig(
local_config_path=input_path,
Expand All @@ -79,6 +80,15 @@ def test_dist_executor_run_with_precomputed(tempdir: str, user_state_categorical
do_repartition=True,
)

yield executor_configuration


def test_dist_executor_run_with_precomputed(
tempdir: str,
user_state_categorical_precomp_file: str,
executor_configuration: ExecutorConfig,
):
"""Test run function with local data"""
original_precomp_file = user_state_categorical_precomp_file

with open(original_precomp_file, "r", encoding="utf-8") as f:
Expand Down Expand Up @@ -106,23 +116,8 @@ def test_dist_executor_run_with_precomputed(tempdir: str, user_state_categorical
# TODO: Verify other metadata files that verify_integ_test_output doesn't check for


def test_merge_input_and_transform_dicts(tempdir: str):
def test_merge_input_and_transform_dicts(executor_configuration: ExecutorConfig):
"""Test the _merge_config_with_transformations function with hardcoded json data"""
input_path = os.path.join(_ROOT, "resources/small_heterogeneous_graph")
executor_configuration = ExecutorConfig(
local_config_path=input_path,
local_metadata_output_path=tempdir,
input_prefix=input_path,
output_prefix=tempdir,
num_output_files=-1,
config_filename="gsprocessing-config.json",
execution_env=ExecutionEnv.LOCAL,
filesystem_type=FilesystemType.LOCAL,
add_reverse_edges=True,
graph_name="small_heterogeneous_graph",
do_repartition=True,
)

dist_executor = DistributedExecutor(executor_configuration)

pre_comp_transormations = {
Expand All @@ -148,3 +143,23 @@ def test_merge_input_and_transform_dicts(tempdir: str):
if "state" == feature["column"]:
transform_for_feature = feature["precomputed_transformation"]
assert transform_for_feature["transformation_name"] == "categorical"


def test_dist_executor_graph_name(executor_configuration: ExecutorConfig):
"""Test cases for graph name"""

# Ensure default value is used when graph_name is not provided
executor_configuration.graph_name = None
dist_executor = DistributedExecutor(executor_configuration)
assert dist_executor.graph_name == "small_heterogeneous_graph"

# Ensure we raise when invalid graph name is provided
with pytest.raises(AssertionError):
executor_configuration.graph_name = "graph.name"
dist_executor = DistributedExecutor(executor_configuration)

# Ensure a valid default graph name is parsed when the input ends in '/'
executor_configuration.graph_name = None
executor_configuration.input_prefix = executor_configuration.input_prefix + "/"
dist_executor = DistributedExecutor(executor_configuration)
assert dist_executor.graph_name == "small_heterogeneous_graph"

0 comments on commit 620fde3

Please sign in to comment.