class AccessIndex(Enum):
    """
    The kind of index expression used to access one dimension of a view.

    The numeric values are significant: WriteIndicesVisitor keeps the
    member with the larger value when a dimension is accessed more than
    once, and the fusion safety check compares values against TID.
    """

    # NOTE(review): no code visible here ever assigns this member
    Empty = 0
    # The index is a literal constant
    Constant = 1
    # The index is the thread ID itself, or the variable of a loop whose
    # iterator call takes the thread ID as a direct argument
    TID = 2
    # The index is an expression that mentions the thread ID
    TIDFunc = 3
    # The index is the variable of any other enclosing loop
    Iter = 4
    # Anything else, including views passed to a function call
    All = 5
class WriteIndicesVisitor(ast.NodeVisitor):
    """
    Record, for every dimension of every view argument, the widest kind
    of index used to access it and whether it is read, written or both.

    The result is collected in ``access_indices``, keyed by
    ``(view name, dimension)``.
    """

    def __init__(self, tid_name: str, view_args: Dict[str, int]):
        """
        :param tid_name: the name of the thread ID argument of the workunit
        :param view_args: map from view argument name to the rank of that view
        """

        self.tid_name = tid_name
        self.view_args = view_args

        # Map from each view (str) + dimension (int) to the widest
        # AccessIndex seen so far and the combined AccessMode
        self.access_indices: Dict[Tuple[str, int], Tuple[AccessIndex, AccessMode]] = {}
        # Stack of enclosing loop variables. The flag records whether
        # the thread ID is a direct argument of the loop's iterator call
        self.current_iters: List[Tuple[str, bool]] = []

    def visit_For(self, node: ast.For) -> None:
        """
        Track the loop variable(s) while visiting the loop body

        :param node: the for loop being visited
        """

        # The iterator expression itself may access views (loop bounds)
        self.visit(node.iter)

        # Only a call (e.g. range) has arguments to inspect
        is_tid_iter: bool = isinstance(node.iter, ast.Call) and any(
            isinstance(arg, ast.Name) and arg.id == self.tid_name
            for arg in node.iter.args
        )

        # A plain name is the common case; a tuple target contributes
        # every name it binds
        loop_vars: List[str] = [n.id for n in ast.walk(node.target) if isinstance(n, ast.Name)]
        for name in loop_vars:
            self.current_iters.append((name, is_tid_iter))

        for statement in node.body:
            self.visit(statement)

        # Pop exactly what was pushed (a no-op when nothing was)
        del self.current_iters[len(self.current_iters) - len(loop_vars):]

        for statement in node.orelse:
            self.visit(statement)

    def visit_Call(self, node: ast.Call) -> None:
        """
        Mark every view handed to a call as fully read and written

        :param node: the call being visited
        """

        # Treat function calls like a black box
        passed: List[ast.expr] = list(node.args) + [keyword.value for keyword in node.keywords]
        for arg in passed:
            if isinstance(arg, ast.Name) and arg.id in self.view_args:
                for dim in range(self.view_args[arg.id]):
                    self.access_indices[(arg.id, dim)] = (AccessIndex.All, AccessMode.ReadWrite)

        # Argument expressions such as f(view[i]) still contain accesses
        self.generic_visit(node)

    def visit_Subscript(self, node: ast.Subscript) -> None:
        """
        Record the index kind and access mode of a (possibly nested)
        view subscript

        :param node: the outermost subscript of the access
        """

        # Peel nested subscripts (view[i][j]) so that slices[d] is the
        # index expression of dimension d
        slices: List[ast.expr] = []
        base: ast.expr = node
        while isinstance(base, ast.Subscript):
            slices.insert(0, base.slice)
            base = base.value

        # Index expressions can themselves access views (view[other[i]])
        for index_node in slices:
            self.visit(index_node)

        if not isinstance(base, ast.Name):
            # Attribute bases are type annotations; anything else (e.g.
            # the result of a call) is not a named view either
            self.visit(base)
            return

        view_name: str = base.id
        if view_name not in self.view_args:
            return

        # The load/store context lives on the outermost subscript. The
        # base name is always loaded, even when the element is stored.
        new_mode: AccessMode
        if isinstance(node.ctx, ast.Load):
            new_mode = AccessMode.Read
        elif isinstance(getattr(node, "parent", None), ast.AugAssign):
            # view[i] += x both reads and writes the element
            new_mode = AccessMode.ReadWrite
        else:
            new_mode = AccessMode.Write

        for dim, index_node in enumerate(slices):
            new_index: AccessIndex = self._classify_index(index_node)

            existing: Optional[Tuple[AccessIndex, AccessMode]] = self.access_indices.get((view_name, dim))
            if existing is None:
                self.access_indices[(view_name, dim)] = (new_index, new_mode)
                continue

            existing_index, existing_mode = existing

            # Keep whichever index has the higher enum value
            index_to_set = new_index if new_index.value > existing_index.value else existing_index
            # Any disagreement between the modes means both happen
            mode_to_set = existing_mode if existing_mode is new_mode else AccessMode.ReadWrite

            self.access_indices[(view_name, dim)] = (index_to_set, mode_to_set)

    def _classify_index(self, index_node: ast.expr) -> AccessIndex:
        """
        Classify a single index expression

        :param index_node: the expression used to index one dimension
        :returns: the kind of index it represents
        """

        if isinstance(index_node, ast.Constant):
            return AccessIndex.Constant

        if isinstance(index_node, ast.Name):
            name: str = index_node.id
            if name == self.tid_name:
                return AccessIndex.TID
            if (name, True) in self.current_iters:
                return AccessIndex.TID
            if (name, False) in self.current_iters:
                return AccessIndex.Iter
            return AccessIndex.All

        # Look for the thread ID as an actual name in the expression
        # rather than as a substring of its source text
        mentions_tid: bool = any(
            isinstance(n, ast.Name) and n.id == self.tid_name
            for n in ast.walk(index_node)
        )
        return AccessIndex.TIDFunc if mentions_tid else AccessIndex.All


def get_view_write_indices_and_modes(AST: ast.FunctionDef, view_args: Dict[str, int]) -> Dict[Tuple[str, int], Tuple[AccessIndex, AccessMode]]:
    """
    Get information from the AST needed for fusion safety

    :param AST: the AST of the workunit; its first argument is taken
        to be the thread ID
    :param view_args: map from view name to the rank of that view
    :returns: map from (view name, dimension) to the kind of index and
        the access mode used for that dimension
    """

    AST = add_parent_refs(AST)

    tid_name: str = AST.args.args[0].arg
    visitor = WriteIndicesVisitor(tid_name, view_args)
    visitor.visit(AST)

    return visitor.access_indices
def get_safety_info(self, kwargs: Dict[str, Any], AST: ast.FunctionDef) -> Dict[Tuple[int, int], Tuple[AccessIndex, AccessMode]]:
    """
    Get the view access indices needed to check for safety

    :param kwargs: the keyword arguments passed to the workunit
    :param AST: the AST of the input workunit
    :returns: map from (view object id, dimension) to the kind of
        index and the access mode used for that dimension
    """

    # Map from view name to the object id
    view_ids: Dict[str, int] = {}
    # Map from view name to the rank
    view_name_and_rank: Dict[str, int] = {}

    for arg, value in kwargs.items():
        if isinstance(value, ViewType):
            view_ids[arg] = id(value)
            view_name_and_rank[arg] = value.rank()

    # Map from view name (str) + dimension (int) to the type of
    # access to that view's dimension
    write_indices: Dict[Tuple[str, int], Tuple[AccessIndex, AccessMode]] = get_view_write_indices_and_modes(AST, view_name_and_rank)

    # Now need to convert view name to view ID
    safety_info: Dict[Tuple[int, int], Tuple[AccessIndex, AccessMode]] = {}
    for (name, dim), access in write_indices.items():
        key: Tuple[int, int] = (view_ids[name], dim)
        if key in safety_info:
            # The same view object was passed under two argument names:
            # combine both accesses instead of keeping only the last one
            access = self.fuse_safety_info({key: safety_info[key]}, {key: access})[key]

        safety_info[key] = access

    return safety_info


def is_safe_to_fuse(self, current: List[TracerOperation], current_views: Set[ViewType], current_safety_info: Dict[Tuple[int, int], Tuple[AccessIndex, AccessMode]], next: TracerOperation, next_views: Set[ViewType]) -> bool:
    """
    Check whether the next operation is safe to fuse with the
    current operations

    :param current: the current list of tracer operations (not inspected)
    :param current_views: the combined set of views used by each operation to be fused
    :param current_safety_info: the combined access information of the current operations
    :param next: the next potential operation to be added
    :param next_views: the set of views in the operation to be added
    :returns: whether the next operation can be added
    """

    common_views = current_views.intersection(next_views)
    next_safety_info = next.access_indices

    for view in common_views:
        for dim in range(view.rank()):
            key: Tuple[int, int] = (id(view), dim)

            # Without access information for a shared view nothing can
            # be proven, so refuse to fuse rather than abort
            if key not in current_safety_info or key not in next_safety_info:
                return False

            current_access_index, current_access_mode = current_safety_info[key]
            next_access_index, next_access_mode = next_safety_info[key]

            # Two reads never conflict
            if current_access_mode == AccessMode.Read and next_access_mode == AccessMode.Read:
                continue

            # NOTE(review): Constant has a lower value than TID, so a
            # write to view[0] combined with an access to view[tid]
            # passes this check -- confirm that this is intended
            if current_access_index.value > AccessIndex.TID.value or next_access_index.value > AccessIndex.TID.value:
                return False

    return True


def get_operation_views(self, operation: TracerOperation) -> Set[ViewType]:
    """
    Get all views from a TracerOperation's arguments

    :param operation: the input tracer operation
    :returns: the set of views used in that operation
    """

    return {value for value in operation.args.values() if isinstance(value, ViewType)}
Dict[Tuple[int, int], AccessIndex] = {} if len(operations) == 0: return [] @@ -215,11 +300,16 @@ def fuse_naive(self, operations: List[TracerOperation]) -> List[TracerOperation] while len(operations) > 0: op: TracerOperation = operations.pop() + op_views: Set[ViewType] = self.get_operation_views(op) + if not isinstance(op.policy, RangePolicy): if len(ops_to_fuse) > 0: ops_to_fuse.reverse() - fused_ops.append(self.fuse_operations(ops_to_fuse)) + fused_ops.append(self.fuse_operations(ops_to_fuse, fused_safety_info)) ops_to_fuse.clear() + ops_to_fuse_views.clear() + fused_safety_info = {} + fused_range = None # Can't fuse team policies now fused_ops.append(op) @@ -229,47 +319,108 @@ def fuse_naive(self, operations: List[TracerOperation]) -> List[TracerOperation] if fused_range is None: fused_range = current_range - if fused_range != current_range: + # Cannot fuse the incoming op with the current ops. Fuse + # everything in ops_to_fuse. + if fused_range != current_range or not self.is_safe_to_fuse(ops_to_fuse, ops_to_fuse_views, fused_safety_info, op, op_views): ops_to_fuse.reverse() - fused_ops.append(self.fuse_operations(ops_to_fuse)) + fused_ops.append(self.fuse_operations(ops_to_fuse, fused_safety_info)) ops_to_fuse.clear() + ops_to_fuse_views.clear() ops_to_fuse.append(op) + ops_to_fuse_views.update(op_views) + fused_safety_info = op.access_indices fused_range = current_range continue if op.operation == "for": ops_to_fuse.append(op) + ops_to_fuse_views.update(op_views) + fused_safety_info = self.fuse_safety_info(fused_safety_info, op.access_indices) + elif op.operation == "reduce": if len(ops_to_fuse) == 0: ops_to_fuse.append(op) + ops_to_fuse_views.update(op_views) + fused_safety_info = self.fuse_safety_info(fused_safety_info, op.access_indices) + else: ops_to_fuse.reverse() - fused_ops.append(self.fuse_operations(ops_to_fuse)) + fused_ops.append(self.fuse_operations(ops_to_fuse, fused_safety_info)) ops_to_fuse.clear() + ops_to_fuse_views.clear() 
def fuse_safety_info(self, info_0: Dict[Tuple[int, int], Tuple[AccessIndex, AccessMode]], info_1: Dict[Tuple[int, int], Tuple[AccessIndex, AccessMode]]) -> Dict[Tuple[int, int], Tuple[AccessIndex, AccessMode]]:
    """
    Merge the per-dimension access information of two operations

    An entry found in only one input is carried over unchanged. For an
    entry found in both, the index with the larger enum value is kept
    (the one from info_0 on a tie) and the mode is kept only if both
    agree, otherwise it becomes ReadWrite.

    :param info_0: the safety info of the first op
    :param info_1: the safety info of the second op
    :returns: a new dict holding the merged safety info
    """

    merged: Dict[Tuple[int, int], Tuple[AccessIndex, AccessMode]] = {}

    for key, first in info_0.items():
        if key in info_1:
            second_index, second_mode = info_1[key]
            first_index, first_mode = first

            wider_index: AccessIndex = second_index if second_index.value > first_index.value else first_index
            combined_mode: AccessMode = second_mode if second_mode == first_mode else AccessMode.ReadWrite

            merged[key] = (wider_index, combined_mode)
        else:
            merged[key] = first

    # Whatever only the second op touches; shared keys are already set
    for key, second in info_1.items():
        merged.setdefault(key, second)

    return merged
fused_safety_info: Dict[Tuple[int, int], AccessIndex]) -> TracerOperation: """ Fuse a list of TracerOperations into one :param operations: the TracerOperations to be fused + :param fused_safety_info: the fused safety information :returns: the fused operation """ @@ -288,6 +439,7 @@ def fuse_operations(self, operations: List[TracerOperation]) -> TracerOperation: parsers: List[Parser] = [] args: Dict[str, Dict[str, Any]] = {} dependencies: Set[DataDependency] = set() + safety_info: Dict[Tuple[str, int], Tuple[AccessIndex, AccessMode]] = {} for index, op in enumerate(operations): assert isinstance(op.policy, RangePolicy) and policy.begin == op.policy.begin and policy.end == op.policy.end @@ -297,6 +449,7 @@ def fuse_operations(self, operations: List[TracerOperation]) -> TracerOperation: parsers.append(op.parser) args[f"args_{index}"] = op.args dependencies.update(op.dependencies) + safety_info = self.fuse_safety_info(safety_info, op.access_indices) fused_name: str if len(names) < 5: @@ -305,7 +458,7 @@ def fuse_operations(self, operations: List[TracerOperation]) -> TracerOperation: # Avoid long names fused_name = "_".join(names[:5]) + hashlib.md5(("".join(names)).encode()).hexdigest() - return TracerOperation(None, future, fused_name, policy, workunits, operation, parsers, fused_name, args, dependencies) + return TracerOperation(None, future, fused_name, policy, workunits, operation, parsers, fused_name, args, dependencies, fused_safety_info) def get_data_dependencies(self, kwargs: Dict[str, Any], AST: ast.FunctionDef) -> Tuple[Set[DataDependency], Dict[str, AccessMode]]: """