Skip to content

Commit

Permalink
Fusion: allow fusion of a single parallel_reduce with a sequence of p…
Browse files Browse the repository at this point in the history
…arallel_fors (#257)

* Fusion: allow fusion of a single parallel_reduce with a sequence of parallel_fors

* Fusion: add rmul operator overload to Future

* Fusion: add list of parsers when fusing kernels in order to properly retrieve entities

* Views: flush data on view write

* Runtime: retrieve different parsers when using manual fusion
  • Loading branch information
NaderAlAwar authored Feb 24, 2024
1 parent 7bfd51b commit 9e98896
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 27 deletions.
27 changes: 25 additions & 2 deletions pykokkos/core/fusion/fuse.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,28 @@

def fuse_workunit_kwargs_and_params(
workunit_trees: List[ast.AST],
kwargs: Dict[str, Any]
kwargs: Dict[str, Any],
operation: str
) -> Tuple[Dict[str, Any], List[ast.arg]]:
"""
Fuse the parameters and runtime arguments of a list of workunits and rename them as necessary
:param workunits_trees: the list of workunit trees (ASTs) being merged
:param kwargs: the keyword arguments passed to the call
:param operation: they type of parallel operation ("parallel_for", "parallel_reduce", or "parallel_scan")
:returns: a tuple of the fused kwargs and the combined inspected parameters
"""

if operation == "parallel_scan":
raise RuntimeError("parallel_scan not supported for fusion")

fused_kwargs: Dict[str, Any] = {}
fused_params: List[ast.arg] = []
fused_params.append(ast.arg(arg="fused_tid", annotation=int))

if operation == "parallel_reduce":
fused_params.append(ast.arg(arg="pk_fused_acc"))

view_ids: Set[int] = set()

for workunit_idx, tree in enumerate(workunit_trees):
Expand All @@ -30,7 +38,14 @@ def fuse_workunit_kwargs_and_params(
current_kwargs: Dict[str, Any] = kwargs[key]

current_params: List[ast.arg] = [p for p in tree.args.args]
for p in current_params[1:]: # Skip the thread ID
if operation == "parallel_reduce" and workunit_idx == len(workunit_trees) - 1:
# Skip the thread ID and the accumulator
current_params = current_params[2:]
else:
# Skip the thread ID
current_params = current_params[1:]

for p in current_params:
current_arg = current_kwargs[p.arg]
if "PK_FUSE_ARGS" in os.environ and id(current_arg) in view_ids:
continue
Expand Down Expand Up @@ -78,6 +93,8 @@ def fuse_arguments(all_args: List[ast.arguments], **kwargs) -> Tuple[ast.argumen
new_tid: str = "fused_tid"
fused_args = ast.arguments(args=[ast.arg(arg=new_tid, annotation=ast.Name(id='int', ctx=ast.Load()))])

new_acc: str = "pk_fused_acc"

# Map from view ID to fused name
fused_view_names: Dict[int, str] = {}

Expand All @@ -97,6 +114,12 @@ def fuse_arguments(all_args: List[ast.arguments], **kwargs) -> Tuple[ast.argumen
name_map[key] = new_tid
continue

# Account for accumulator
if old_name not in current_kwargs and arg_idx == 1:
name_map[key] = new_acc
fused_args.args.insert(1, ast.arg(arg=new_acc, annotation=arg.annotation))
continue

current_arg = current_kwargs[old_name]
if "PK_FUSE_ARGS" in os.environ and id(current_arg) in fused_view_names:
name_map[key] = fused_view_names[id(current_arg)]
Expand Down
4 changes: 4 additions & 0 deletions pykokkos/core/fusion/future.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ def __mul__(self, other):
self.flush_trace()
return self.value * other

def __rmul__(self, other):
self.flush_trace()
return other * self.value

def __truediv__(self, other):
self.flush_trace()
return self.value / other
Expand Down
28 changes: 20 additions & 8 deletions pykokkos/core/fusion/trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,15 +204,22 @@ def fuse_naive(self, operations: List[TracerOperation]) -> List[TracerOperation]
while len(operations) > 0:
op: TracerOperation = operations.pop()

if op.operation != "for":
if len(ops_to_fuse) > 0: # Fuse and add any outstanding operations
if op.operation == "for":
ops_to_fuse.append(op)
elif op.operation == "reduce":
if len(ops_to_fuse) == 0:
ops_to_fuse.append(op)
else:
ops_to_fuse.reverse()
fused_ops.append(self.fuse_operations(ops_to_fuse))
ops_to_fuse.clear()

# Add the current operation
fused_ops.append(op)
ops_to_fuse.append(op)
else:
ops_to_fuse.reverse()
fused_ops.append(self.fuse_operations(ops_to_fuse))
ops_to_fuse.clear()

ops_to_fuse.append(op)

# Fuse anything left over
Expand All @@ -238,17 +245,22 @@ def fuse_operations(self, operations: List[TracerOperation]) -> TracerOperation:
names: List[str] = []
policy: RangePolicy = operations[0].policy
workunits: List[Callable[..., None]] = []
operation: str = operations[0].operation
parser: Parser = operations[0].parser

# The last operation determines the type of the fused
# operation since it can be a reduce
operation: str = operations[-1].operation
future: Optional[Future] = operations[-1].future

parsers: List[Parser] = []
args: Dict[str, Dict[str, Any]] = {}
dependencies: Set[DataDependency] = set()

for index, op in enumerate(operations):
assert isinstance(op.policy, RangePolicy) and policy.begin == op.policy.begin and policy.end == op.policy.end
assert operation == op.operation == "for"

names.append(op.name if op.name is not None else op.workunit.__name__)
workunits.append(op.workunit)
parsers.append(op.parser)
args[f"args_{index}"] = op.args
dependencies.update(op.dependencies)

Expand All @@ -259,7 +271,7 @@ def fuse_operations(self, operations: List[TracerOperation]) -> TracerOperation:
# Avoid long names
fused_name = "_".join(names[:5]) + hashlib.md5(("".join(names)).encode()).hexdigest()

return TracerOperation(None, None, fused_name, policy, workunits, operation, parser, fused_name, args, dependencies)
return TracerOperation(None, future, fused_name, policy, workunits, operation, parsers, fused_name, args, dependencies)

def get_data_dependencies(self, kwargs: Dict[str, Any], AST: ast.FunctionDef) -> Tuple[Set[DataDependency], Dict[str, AccessMode]]:
"""
Expand Down
31 changes: 22 additions & 9 deletions pykokkos/core/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,18 @@ def run_workunit(
raise RuntimeError("ERROR: operation cannot be None for Debug")
return run_workunit_debug(policy, workunit, operation, initial_value, **kwargs)

metadata: EntityMetadata = get_metadata(workunit[0]) if isinstance(workunit, list) else get_metadata(workunit)
parser: Parser = self.compiler.get_parser(metadata.path)
metadata: EntityMetadata
parser: Union[Parser, List[Parser]]

if isinstance(workunit, list):
metadata = get_metadata(workunit[0])
parser = []
for this_workunit in workunit:
this_metadata = get_metadata(this_workunit)
parser.append(self.compiler.get_parser(this_metadata.path))
else:
metadata = get_metadata(workunit)
parser = self.compiler.get_parser(metadata.path)

if self.fusion_strategy is not None:
future = Future()
Expand All @@ -144,7 +154,7 @@ def execute_workunit(
policy: ExecutionPolicy,
workunit: Union[Callable[..., None], List[Callable[..., None]]],
operation: str,
parser: Parser,
parser: Union[Parser, List[Parser]],
**kwargs
) -> Optional[Union[float, int]]:
"""
Expand All @@ -169,7 +179,7 @@ def execute_workunit(
members: PyKokkosMembers = self.precompile_workunit(workunit, execution_space, updated_decorator, updated_types, types_signature, **kwargs)

module_setup: ModuleSetup = self.get_module_setup(workunit, execution_space, types_signature)
return self.execute(workunit, module_setup, members, execution_space, policy=policy, name=name, **kwargs)
return self.execute(workunit, module_setup, members, execution_space, policy=policy, name=name, operation=operation, **kwargs)

def flush_data(self, data: Union[Future, ViewType]) -> None:
"""
Expand Down Expand Up @@ -226,6 +236,7 @@ def execute(
space: ExecutionSpace,
policy: Optional[ExecutionPolicy] = None,
name: Optional[str] = None,
operation: Optional[str] = None,
**kwargs
) -> Optional[Union[float, int]]:
"""
Expand All @@ -237,7 +248,7 @@ def execute(
:param space: the execution space
:param policy: the execution policy for workunits
:param name: the name of the kernel
:param entity_trees: Optional parameter: List of ASTs of entities being fused - only provided when entity is a list
:param operation: the name of the operation "for", "reduce", or "scan"
:param kwargs: the keyword arguments passed to the workunit
:returns: the result of the operation (None for "for" and workloads)
"""
Expand All @@ -251,7 +262,7 @@ def execute(

module = self.import_module(module_setup.name, module_path)

args: Dict[str, Any] = self.get_arguments(entity, members, space, policy, **kwargs)
args: Dict[str, Any] = self.get_arguments(entity, members, space, policy, operation, **kwargs)
if name is None:
args["pk_kernel_name"] = ""
else:
Expand Down Expand Up @@ -293,6 +304,7 @@ def get_arguments(
members: PyKokkosMembers,
space: ExecutionSpace,
policy: Optional[ExecutionPolicy],
operation: Optional[str],
**kwargs
) -> Dict[str, Any]:
"""
Expand All @@ -302,6 +314,7 @@ def get_arguments(
:param members: a collection of PyKokkos related members
:param space: the execution space
:param policy: the execution policy of the operation
:param operation: the name of the operation "for", "reduce", or "scan"
:param kwargs: the keyword arguments passed to a workunit
"""

Expand All @@ -327,10 +340,10 @@ def get_arguments(
else:
is_fused: bool = isinstance(entity, list)
if is_fused:
parser = self.compiler.get_parser(get_metadata(entity[0]).path)
entity_trees = [parser.get_entity(get_metadata(this_entity).name).AST for this_entity in entity]
parsers = [self.compiler.get_parser(get_metadata(e).path) for e in entity]
entity_trees = [this_parser.get_entity(get_metadata(this_entity).name).AST for this_entity, this_parser in zip(entity, parsers)]

kwargs, _ = fuse_workunit_kwargs_and_params(entity_trees, kwargs)
kwargs, _ = fuse_workunit_kwargs_and_params(entity_trees, kwargs, f"parallel_{operation}")
entity_members = kwargs

args.update(self.get_fields(entity_members))
Expand Down
18 changes: 10 additions & 8 deletions pykokkos/core/type_inference/args_type_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,11 @@ def get_annotations(parallel_type: str, workunit_trees: Union[Tuple[Callable, as
param_list: List[ast.arg]

if isinstance(workunit_trees, list):
if parallel_type != "parallel_for":
raise RuntimeError("Can only do kernel fusion with parallel for")
if parallel_type == "parallel_scan":
raise RuntimeError("Cannot do kernel fusion with parallel scan")
workunit = [w for w, _ in workunit_trees]
trees = [t for _, t in workunit_trees]
passed_kwargs, param_list = fuse_workunit_kwargs_and_params(trees, passed_kwargs)
passed_kwargs, param_list = fuse_workunit_kwargs_and_params(trees, passed_kwargs, parallel_type)
else:
workunit, entity_AST = workunit_trees
param_list = [x for x in entity_AST.args.args]
Expand Down Expand Up @@ -115,11 +115,12 @@ def get_annotations(parallel_type: str, workunit_trees: Union[Tuple[Callable, as
return updated_types


def get_views_decorator(workunit_trees: List[Tuple[Callable, ast.AST]], passed_kwargs) -> UpdatedDecorator:
def get_views_decorator(parallel_type: str, workunit_trees: List[Tuple[Callable, ast.AST]], passed_kwargs) -> UpdatedDecorator:
'''
Extract the layout, space, trait information against view: will be used to construct decorator
specifiers
:param parallel_type: A string identifying the type of parallel dispatch ("parallel_for", "parallel_reduce" ...)
:param handled_args: Processed arguments passed to the dispatch
:param passed_kwargs: Keyword arguments passed to parallel dispatch (has views)
:returns: UpdatedDecorator object
Expand All @@ -128,7 +129,7 @@ def get_views_decorator(workunit_trees: List[Tuple[Callable, ast.AST]], passed_k
param_list: List[ast.arg]
if isinstance(workunit_trees, list):
trees = [t for _, t in workunit_trees]
passed_kwargs, param_list = fuse_workunit_kwargs_and_params(trees, passed_kwargs)
passed_kwargs, param_list = fuse_workunit_kwargs_and_params(trees, passed_kwargs, parallel_type)
param_list = [p.arg for p in param_list]
else:
_, entity_AST = workunit_trees
Expand Down Expand Up @@ -410,11 +411,12 @@ def get_type_info(

if not isinstance(workunit, list):
workunit = [workunit] # for easier transformations
parser = [parser]
list_passed = False

for this_workunit in workunit:
for this_workunit, this_parser in zip(workunit, parser):
this_metadata = get_metadata(this_workunit)
this_tree = parser.get_entity(this_metadata.name).AST
this_tree = this_parser.get_entity(this_metadata.name).AST
workunit_str = str(this_workunit)

if not isinstance(this_tree, ast.FunctionDef):
Expand Down Expand Up @@ -445,7 +447,7 @@ def get_type_info(

if is_standalone_workunit:
updated_types = get_annotations(f"parallel_{operation}", workunit_trees, policy, passed_kwargs)
updated_decorator = get_views_decorator(workunit_trees, passed_kwargs)
updated_decorator = get_views_decorator(f"parallel_{operation}", workunit_trees, passed_kwargs)
types_signature = get_types_signature(updated_types, updated_decorator, execution_space)

return updated_types, updated_decorator, types_signature
Expand Down
3 changes: 3 additions & 0 deletions pykokkos/interface/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,9 @@ def __setitem__(self, key: Union[int, TeamMember], value: Union[int, float]) ->
:param value: the new value at the index.
"""

if "PK_FUSION" in os.environ:
runtime_singleton.runtime.flush_data(self)

self.data[key] = value

def __bool__(self):
Expand Down

0 comments on commit 9e98896

Please sign in to comment.