diff --git a/pykokkos/core/fusion/trace.py b/pykokkos/core/fusion/trace.py index 8172503a..7ce8c6ae 100644 --- a/pykokkos/core/fusion/trace.py +++ b/pykokkos/core/fusion/trace.py @@ -4,7 +4,7 @@ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union from pykokkos.core.parsers import Parser, PyKokkosEntity -from pykokkos.interface import ExecutionPolicy, RangePolicy, ViewType +from pykokkos.interface import ExecutionPolicy, RangePolicy, Subview, ViewType from .access_modes import AccessIndex, AccessMode, get_view_access_modes, get_view_write_indices_and_modes from .future import Future @@ -160,6 +160,9 @@ def get_operations(self, data: Union[Future, ViewType]) -> List[TracerOperation] :returns: the list of operations to be executed """ + if isinstance(data, Subview): + data = data.base_view + version: int = self.data_version.get(id(data), 0) dependency = DataDependency(None, id(data), version) @@ -474,6 +477,9 @@ def get_data_dependencies(self, kwargs: Dict[str, Any], AST: ast.FunctionDef) -> # First pass to get the Future dependencies and record all the views for arg, value in kwargs.items(): + if isinstance(value, Subview): + value = value.base_view + if isinstance(value, Future): version: int = self.data_version.get(id(value), 0) dependency = DataDependency(arg, id(value), version) @@ -487,6 +493,9 @@ def get_data_dependencies(self, kwargs: Dict[str, Any], AST: ast.FunctionDef) -> # Second pass to check if the views are dependencies for arg, value in kwargs.items(): + if isinstance(value, Subview): + value = value.base_view + if isinstance(value, ViewType) and access_modes[arg] in {AccessMode.Read, AccessMode.ReadWrite}: version: int = self.data_version.get(id(value), 0) dependency = DataDependency(arg, id(value), version) @@ -514,6 +523,9 @@ def update_output_data_operations( """ for arg, value in kwargs.items(): + if isinstance(value, Subview): + value = value.base_view + if isinstance(value, ViewType) and access_modes[arg] in {AccessMode.Write, AccessMode.ReadWrite}: version: int = self.data_version.get(id(value), 0) self.data_version[id(value)] = version + 1