Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GSProcessing] Fix to make graph_name optional again #1050

Merged
merged 2 commits into from
Oct 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 21 additions & 12 deletions graphstorm-processing/graphstorm_processing/distributed_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
import tempfile
import time
from collections.abc import Mapping
from typing import Any, Dict
from typing import Any, Dict, Optional

import boto3
import botocore
Expand Down Expand Up @@ -106,8 +106,8 @@ class ExecutorConfig:
The filesystem type, can be LOCAL or S3
add_reverse_edges : bool
Whether to create reverse edges for each edge type.
graph_name: str
The name of the graph being processed
graph_name: str, optional
The name of the graph being processed. If not provided we use part of the input_prefix.
do_repartition: bool
Whether to apply repartitioning to the graph on the Spark leader.
"""
Expand All @@ -121,7 +121,7 @@ class ExecutorConfig:
config_filename: str
filesystem_type: FilesystemType
add_reverse_edges: bool
graph_name: str
graph_name: Optional[str]
do_repartition: bool


Expand All @@ -135,7 +135,7 @@ class GSProcessingArguments:
num_output_files: int
add_reverse_edges: bool
log_level: str
graph_name: str
graph_name: Optional[str]
do_repartition: bool


Expand All @@ -162,7 +162,14 @@ def __init__(
self.filesystem_type = executor_config.filesystem_type
self.execution_env = executor_config.execution_env
self.add_reverse_edges = executor_config.add_reverse_edges
self.graph_name = executor_config.graph_name
# We use the data location as the graph name if a name is not provided
if executor_config.graph_name:
self.graph_name = executor_config.graph_name
else:
derived_name = s3_utils.s3_path_remove_trailing(self.input_prefix).split("/")[-1]
logging.warning("Setting graph name derived from input path: %s", derived_name)
self.graph_name = derived_name
check_graph_name(self.graph_name)
self.repartition_on_leader = executor_config.do_repartition
# Input config dict using GSProcessing schema
self.gsp_config_dict = {}
Expand Down Expand Up @@ -541,11 +548,14 @@ def parse_args() -> argparse.Namespace:
parser.add_argument(
"--graph-name",
type=str,
help="Name for the graph being processed."
"The graph name must adhere to the Python "
"identifier naming rules with the exception "
"that hyphens (-) are permitted and the name "
"can start with numbers",
help=(
"Name for the graph being processed."
"The graph name must adhere to the Python "
"identifier naming rules with the exception "
"that hyphens (-) are permitted and the name "
"can start with numbers. If not provided, we will use the last "
"section of the input prefix path."
),
required=False,
default=None,
)
Expand Down Expand Up @@ -604,7 +614,6 @@ def main():
level=gsprocessing_args.log_level,
format="[GSPROCESSING] %(asctime)s %(levelname)-8s %(message)s",
)
check_graph_name(gsprocessing_args.graph_name)

# Determine execution environment
if os.path.exists("/opt/ml/config/processingjobconfig.json"):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -324,11 +324,6 @@ def process_and_write_graph_data(
self.timers["process_edge_data"] = perf_counter() - edges_start_time
metadata_dict["edge_data"] = edge_data_dict
metadata_dict["edges"] = edge_structure_dict
# We use the data location as the graph name, can also take from user?
# TODO: Fix this, take from config?
metadata_dict["graph_name"] = (
self.graph_name if self.graph_name else self.input_prefix.split("/")[-1]
)

# Ensure output dict has the correct order of keys
for edge_type in metadata_dict["edge_type"]:
Expand Down Expand Up @@ -447,6 +442,8 @@ def _initialize_metadata_dict(
metadata_dict["edge_type"] = edge_types
metadata_dict["node_type"] = sorted(node_type_set)

metadata_dict["graph_name"] = self.graph_name

return metadata_dict

def _finalize_graphinfo_dict(self, metadata_dict: Dict) -> Dict:
Expand Down
56 changes: 38 additions & 18 deletions graphstorm-processing/tests/test_dist_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,9 @@ def user_state_categorical_precomp_file_fixture():
os.remove(precomp_file)


def test_dist_executor_run_with_precomputed(tempdir: str, user_state_categorical_precomp_file):
"""Test run function with local data"""
@pytest.fixture(name="executor_configuration")
def executor_config_fixture(tempdir: str):
"""Create a re-usable ExecutorConfig"""
input_path = os.path.join(_ROOT, "resources/small_heterogeneous_graph")
executor_configuration = ExecutorConfig(
local_config_path=input_path,
Expand All @@ -79,6 +80,15 @@ def test_dist_executor_run_with_precomputed(tempdir: str, user_state_categorical
do_repartition=True,
)

yield executor_configuration


def test_dist_executor_run_with_precomputed(
tempdir: str,
user_state_categorical_precomp_file: str,
executor_configuration: ExecutorConfig,
):
"""Test run function with local data"""
original_precomp_file = user_state_categorical_precomp_file

with open(original_precomp_file, "r", encoding="utf-8") as f:
Expand Down Expand Up @@ -106,23 +116,8 @@ def test_dist_executor_run_with_precomputed(tempdir: str, user_state_categorical
# TODO: Verify other metadata files that verify_integ_test_output doesn't check for


def test_merge_input_and_transform_dicts(tempdir: str):
def test_merge_input_and_transform_dicts(executor_configuration: ExecutorConfig):
"""Test the _merge_config_with_transformations function with hardcoded json data"""
input_path = os.path.join(_ROOT, "resources/small_heterogeneous_graph")
executor_configuration = ExecutorConfig(
local_config_path=input_path,
local_metadata_output_path=tempdir,
input_prefix=input_path,
output_prefix=tempdir,
num_output_files=-1,
config_filename="gsprocessing-config.json",
execution_env=ExecutionEnv.LOCAL,
filesystem_type=FilesystemType.LOCAL,
add_reverse_edges=True,
graph_name="small_heterogeneous_graph",
do_repartition=True,
)

dist_executor = DistributedExecutor(executor_configuration)

pre_comp_transormations = {
Expand All @@ -148,3 +143,28 @@ def test_merge_input_and_transform_dicts(tempdir: str):
if "state" == feature["column"]:
transform_for_feature = feature["precomputed_transformation"]
assert transform_for_feature["transformation_name"] == "categorical"


def test_dist_executor_graph_name(executor_configuration: ExecutorConfig):
"""Test cases for graph name"""

# Ensure we can set a valid graph name
executor_configuration.graph_name = "2024-a_valid_name"
dist_executor = DistributedExecutor(executor_configuration)
assert dist_executor.graph_name == "2024-a_valid_name"

# Ensure default value is used when graph_name is not provided
thvasilo marked this conversation as resolved.
Show resolved Hide resolved
executor_configuration.graph_name = None
dist_executor = DistributedExecutor(executor_configuration)
assert dist_executor.graph_name == "small_heterogeneous_graph"

# Ensure we raise when invalid graph name is provided
with pytest.raises(AssertionError):
executor_configuration.graph_name = "graph.name"
dist_executor = DistributedExecutor(executor_configuration)

# Ensure a valid default graph name is parsed when the input ends in '/'
executor_configuration.graph_name = None
executor_configuration.input_prefix = executor_configuration.input_prefix + "/"
dist_executor = DistributedExecutor(executor_configuration)
assert dist_executor.graph_name == "small_heterogeneous_graph"
Loading