Skip to content

Commit

Permalink
Tracer: account for subviews when getting dependencies
Browse files Browse the repository at this point in the history
  • Loading branch information
NaderAlAwar committed Jul 28, 2024
1 parent 9e3a3d8 commit 0fce5ae
Showing 1 changed file with 13 additions and 1 deletion.
14 changes: 13 additions & 1 deletion pykokkos/core/fusion/trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union

from pykokkos.core.parsers import Parser, PyKokkosEntity
from pykokkos.interface import ExecutionPolicy, RangePolicy, ViewType
from pykokkos.interface import ExecutionPolicy, RangePolicy, Subview, ViewType

from .access_modes import AccessIndex, AccessMode, get_view_access_modes, get_view_write_indices_and_modes
from .future import Future
Expand Down Expand Up @@ -160,6 +160,9 @@ def get_operations(self, data: Union[Future, ViewType]) -> List[TracerOperation]
:returns: the list of operations to be executed
"""

if isinstance(data, Subview):
data = data.base_view

version: int = self.data_version.get(id(data), 0)
dependency = DataDependency(None, id(data), version)

Expand Down Expand Up @@ -474,6 +477,9 @@ def get_data_dependencies(self, kwargs: Dict[str, Any], AST: ast.FunctionDef) ->

# First pass to get the Future dependencies and record all the views
for arg, value in kwargs.items():
if isinstance(value, Subview):
value = value.base_view

if isinstance(value, Future):
version: int = self.data_version.get(id(value), 0)
dependency = DataDependency(arg, id(value), version)
Expand All @@ -487,6 +493,9 @@ def get_data_dependencies(self, kwargs: Dict[str, Any], AST: ast.FunctionDef) ->

# Second pass to check if the views are dependencies
for arg, value in kwargs.items():
if isinstance(value, Subview):
value = value.base_view

if isinstance(value, ViewType) and access_modes[arg] in {AccessMode.Read, AccessMode.ReadWrite}:
version: int = self.data_version.get(id(value), 0)
dependency = DataDependency(arg, id(value), version)
Expand Down Expand Up @@ -514,6 +523,9 @@ def update_output_data_operations(
"""

for arg, value in kwargs.items():
if isinstance(value, Subview):
value = value.base_view

if isinstance(value, ViewType) and access_modes[arg] in {AccessMode.Write, AccessMode.ReadWrite}:
version: int = self.data_version.get(id(value), 0)
self.data_version[id(value)] = version + 1
Expand Down

0 comments on commit 0fce5ae

Please sign in to comment.