Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement loop fusion optimization inside kernels #253

Merged
merged 27 commits into from
Feb 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
76d7fbe
Fusion: remove redundant str() call
NaderAlAwar Jan 25, 2024
6af0b2a
Fusion: move visitors to a separate file to reuse them in loop fusion
NaderAlAwar Jan 25, 2024
8fd27c3
Translators: add index in body to each AST node
NaderAlAwar Jan 25, 2024
a1b3633
Optimizations: add the loop fusion optimization
NaderAlAwar Jan 25, 2024
778dff7
Compiler: add env variable to do loop fusion
NaderAlAwar Jan 25, 2024
c266301
Translators: add idx in parent body attribute to all nodes
NaderAlAwar Jan 27, 2024
74f5cdd
Optimizations: account for fused_* additions from fused kernels when …
NaderAlAwar Jan 27, 2024
dae698b
PyKokkosVisitors: add pk.cpp_auto as an allowed type (maps to auto in…
NaderAlAwar Jan 31, 2024
f8060eb
StaticTranslator: fix idx_in_loop to work for all AST nodes with lists
NaderAlAwar Jan 31, 2024
61ab4c7
Optimizations: move common functionality to util.py in preparation fo…
NaderAlAwar Jan 31, 2024
c10f7f4
Optimizations: add memory_ops_fuse optimization
NaderAlAwar Jan 31, 2024
a8928a8
Compiler: restrict optimizations to workunit and fused styles only
NaderAlAwar Jan 31, 2024
9eb72b4
test_loop_fusion.py: Added basic and manual tests for loop fusion
HannanNaeem Jan 30, 2024
3563f68
run_loop_fusion_test.py: added basic automation and comparison
HannanNaeem Jan 31, 2024
e7295e0
cleaned up and refactored tests to work with pytest
HannanNaeem Jan 31, 2024
5ec2bc0
Added simple neg dist test
HannanNaeem Feb 1, 2024
ee61112
Optimizations: track constant variables used in memory accesses to fa…
NaderAlAwar Feb 2, 2024
6dad848
Fusion: implent rsub() for future and add check if data does not depe…
NaderAlAwar Feb 12, 2024
c103ea5
mypy: add missing ignore
NaderAlAwar Feb 13, 2024
53a2f44
mypy: add missing ignore
NaderAlAwar Feb 13, 2024
edc53f0
StaticTranslator: fix issue with assigning parent_accessor to string …
NaderAlAwar Feb 13, 2024
89186d5
Fusion: overload more operators for Future
NaderAlAwar Feb 16, 2024
3c733be
Fusion: truncate long kernel names to avoid errors
NaderAlAwar Feb 16, 2024
faa9bc0
ModuleSetup: truncate long file names in fused kernels to avoid filen…
NaderAlAwar Feb 16, 2024
95ccc6e
WorkunitVisitor: fix issue with getting accumulator argument in MDRan…
NaderAlAwar Feb 16, 2024
76cc291
Tests: fix formatting and spelling
NaderAlAwar Feb 19, 2024
65cbeeb
Optimizations: don't fuse loops with print statements
NaderAlAwar Feb 19, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,24 @@ ignore_errors = True
[mypy-pykokkos.interface.views]
ignore_errors = True

[mypy-pykokkos.core.optimizations.loop_fuse]
ignore_errors = True

[mypy-pykokkos.core.optimizations.memory_ops_fuse]
ignore_errors = True

[mypy-pykokkos.core.optimizations.util]
ignore_errors = True

[mypy-pykokkos.core.fusion.fuse]
ignore_errors = True

[mypy-pykokkos.core.fusion.trace]
ignore_errors = True

[mypy-pykokkos.core.fusion.util]
ignore_errors = True

[mypy-pykokkos.core.fusion.access_modes]
ignore_errors = True

Expand Down
6 changes: 6 additions & 0 deletions pykokkos/core/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing import Dict, List, Optional, Tuple

from pykokkos.core.fusion import fuse_workunits
from pykokkos.core.optimizations import loop_fuse, memory_ops_fuse
from pykokkos.core.parsers import Parser, PyKokkosEntity, PyKokkosStyles
from pykokkos.core.translators import PyKokkosMembers, StaticTranslator
from pykokkos.core.type_inference import UpdatedTypes, UpdatedDecorator
Expand Down Expand Up @@ -223,6 +224,11 @@ def compile_entity(
bindings: List[str]
cast: List[str]

if entity.style in {PyKokkosStyles.workunit, PyKokkosStyles.fused}:
if "PK_LOOP_FUSE" in os.environ:
loop_fuse(entity.AST)
if "PK_MEM_FUSE" in os.environ:
memory_ops_fuse(entity.AST, entity.pk_import)
functor, bindings, cast = translator.translate(entity, classtypes)

t_end: float = time.perf_counter() - t_start
Expand Down
47 changes: 1 addition & 46 deletions pykokkos/core/fusion/fuse.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,52 +2,7 @@
import os
from typing import Any, Dict, List, Set, Tuple, Union

def get_node_name(node: Union[ast.Attribute, ast.Name]) -> str:
"""
Copied from visitors_util.py due to circular import
"""

name: str
if isinstance(node, ast.Attribute):
name = node.attr
else:
name = node.id

return name


class DeclarationsVisitor(ast.NodeVisitor):
"""
Get all variable declarations
"""

def __init__(self) -> None:
self.declarations: Set[str] = set()

def visit_AnnAssign(self, node: ast.AnnAssign) -> Any:
self.declarations.add(get_node_name(node.target))


class VariableRenamer(ast.NodeTransformer):
"""
Renames variables in a fused ast according to a map
"""

def __init__(self, name_map: Dict[Tuple[str, int], str], workunit_idx: int):
self.name_map = name_map
self.workunit_idx = workunit_idx

def visit_Name(self, node: ast.Name) -> Any:
key = (node.id, self.workunit_idx)
# If the name is not mapped, keep the original name
node.id = self.name_map.get(key, node.id)
return node

def visit_keyword(self, node: ast.keyword) -> Any:
key = (node.id, self.workunit_idx)
# If the name is not mapped, keep the original name
node.arg = self.name_map.get(key, node.arg)
return node
from .util import DeclarationsVisitor, VariableRenamer


def fuse_workunit_kwargs_and_params(
Expand Down
25 changes: 24 additions & 1 deletion pykokkos/core/fusion/future.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ def __sub__(self, other):
self.flush_trace()
return self.value - other

def __rsub__(self, other):
self.flush_trace()
return other - self.value

def __mul__(self, other):
self.flush_trace()
return self.value * other
Expand All @@ -28,6 +32,10 @@ def __truediv__(self, other):
self.flush_trace()
return self.value / other

def __rtruediv__(self, other):
self.flush_trace()
return other / self.value

def __floordiv__(self, other):
self.flush_trace()
return self.value // other
Expand All @@ -36,8 +44,23 @@ def __str__(self):
self.flush_trace()
return str(self.value)

def __eq__(self, other):
self.flush_trace()
if isinstance(other, Future):
return self.value == other.value

return self.value == other

def __lt__(self, other):
self.flush_trace()
return self.value < other

def __gt__(self, other):
self.flush_trace()
return self.value > other

def __repr__(self) -> str:
return str(f"Future(value={self.value})")
return f"Future(value={self.value})"

def flush_trace(self) -> None:
runtime_singleton.runtime.flush_data(self)
Expand Down
12 changes: 11 additions & 1 deletion pykokkos/core/fusion/trace.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import ast
import hashlib
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union

Expand Down Expand Up @@ -129,6 +130,10 @@ def get_operations(self, data: Union[Future, ViewType]) -> List[TracerOperation]
version: int = self.data_version.get(id(data), 0)
dependency = DataDependency(None, id(data), version)

if dependency not in self.data_operation:
# The data does not depend on any prior operation
return []

operation: TracerOperation = self.data_operation[dependency]
if operation not in self.operations:
# This means that the dependency was already updated
Expand Down Expand Up @@ -247,7 +252,12 @@ def fuse_operations(self, operations: List[TracerOperation]) -> TracerOperation:
args[f"args_{index}"] = op.args
dependencies.update(op.dependencies)

fused_name: str = "_".join(names)
fused_name: str
if len(names) < 5:
fused_name = "_".join(names)
else:
# Avoid long names
fused_name = "_".join(names[:5]) + hashlib.md5(("".join(names)).encode()).hexdigest()

return TracerOperation(None, None, fused_name, policy, workunits, operation, parser, fused_name, args, dependencies)

Expand Down
54 changes: 54 additions & 0 deletions pykokkos/core/fusion/util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import ast
from typing import Dict, Set, Tuple, Union


def get_node_name(node: Union[ast.Attribute, ast.Name]) -> str:
"""
Copied from visitors_util.py due to circular import
"""

name: str
if isinstance(node, ast.Attribute):
name = node.attr
else:
name = node.id

return name


class DeclarationsVisitor(ast.NodeVisitor):
"""
Get all variable declarations
"""

def __init__(self) -> None:
self.declarations: Set[str] = set()

def visit_AnnAssign(self, node: ast.AnnAssign) -> None:
self.declarations.add(get_node_name(node.target))

def visit_For(self, node: ast.For) -> None:
self.declarations.add(get_node_name(node.target))
for n in node.body:
self.visit(n)

class VariableRenamer(ast.NodeTransformer):
"""
Renames variables in a fused ast according to a map
"""

def __init__(self, name_map: Dict[Tuple[str, int], str], workunit_idx: int):
self.name_map = name_map
self.workunit_idx = workunit_idx

def visit_Name(self, node: ast.Name) -> None:
key = (node.id, self.workunit_idx)
# If the name is not mapped, keep the original name
node.id = self.name_map.get(key, node.id)
return node

def visit_keyword(self, node: ast.keyword) -> None:
key = (node.id, self.workunit_idx)
# If the name is not mapped, keep the original name
node.arg = self.name_map.get(key, node.arg)
return node
10 changes: 9 additions & 1 deletion pykokkos/core/module_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,10 +163,18 @@ def get_entity_dir(self, main: Path, metadata: List[EntityMetadata]) -> Path:

entity_dir: str = ""

for m in metadata:
for m in metadata[:5]:
filename: str = m.path.split("/")[-1].split(".")[0]
entity_dir += f"{filename}_{m.name}"

remaining: str = ""
for m in metadata[5:]:
filename: str = m.path.split("/")[-1].split(".")[0]
remaining += f"{filename}_{m.name}"

if remaining != "":
entity_dir += hashlib.md5(("".join(remaining)).encode()).hexdigest()

return self.get_main_dir(main) / Path(entity_dir)

@staticmethod
Expand Down
2 changes: 2 additions & 0 deletions pykokkos/core/optimizations/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .loop_fuse import loop_fuse
from .memory_ops_fuse import memory_ops_fuse
Loading
Loading