Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement loop fusion optimization inside kernels #253

Merged
merged 27 commits into from
Feb 19, 2024
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
76d7fbe
Fusion: remove redundant str() call
NaderAlAwar Jan 25, 2024
6af0b2a
Fusion: move visitors to a separate file to reuse them in loop fusion
NaderAlAwar Jan 25, 2024
8fd27c3
Translators: add index in body to each AST node
NaderAlAwar Jan 25, 2024
a1b3633
Optimizations: add the loop fusion optimization
NaderAlAwar Jan 25, 2024
778dff7
Compiler: add env variable to do loop fusion
NaderAlAwar Jan 25, 2024
c266301
Translators: add idx in parent body attribute to all nodes
NaderAlAwar Jan 27, 2024
74f5cdd
Optimizations: account for fused_* additions from fused kernels when …
NaderAlAwar Jan 27, 2024
dae698b
PyKokkosVisitors: add pk.cpp_auto as an allowed type (maps to auto in…
NaderAlAwar Jan 31, 2024
f8060eb
StaticTranslator: fix idx_in_loop to work for all AST nodes with lists
NaderAlAwar Jan 31, 2024
61ab4c7
Optimizations: move common functionality to util.py in preparation fo…
NaderAlAwar Jan 31, 2024
c10f7f4
Optimizations: add memory_ops_fuse optimization
NaderAlAwar Jan 31, 2024
a8928a8
Compiler: restrict optimizations to workunit and fused styles only
NaderAlAwar Jan 31, 2024
9eb72b4
test_loop_fusion.py: Added basic and manual tests for loop fusion
HannanNaeem Jan 30, 2024
3563f68
run_loop_fusion_test.py: added basic automation and comparison
HannanNaeem Jan 31, 2024
e7295e0
cleaned up and refactored tests to work with pytest
HannanNaeem Jan 31, 2024
5ec2bc0
Added simple neg dist test
HannanNaeem Feb 1, 2024
ee61112
Optimizations: track constant variables used in memory accesses to fa…
NaderAlAwar Feb 2, 2024
6dad848
Fusion: implent rsub() for future and add check if data does not depe…
NaderAlAwar Feb 12, 2024
c103ea5
mypy: add missing ignore
NaderAlAwar Feb 13, 2024
53a2f44
mypy: add missing ignore
NaderAlAwar Feb 13, 2024
edc53f0
StaticTranslator: fix issue with assigning parent_accessor to string …
NaderAlAwar Feb 13, 2024
89186d5
Fusion: overload more operators for Future
NaderAlAwar Feb 16, 2024
3c733be
Fusion: truncate long kernel names to avoid errors
NaderAlAwar Feb 16, 2024
faa9bc0
ModuleSetup: truncate long file names in fused kernels to avoid filen…
NaderAlAwar Feb 16, 2024
95ccc6e
WorkunitVisitor: fix issue with getting accumulator argument in MDRan…
NaderAlAwar Feb 16, 2024
76cc291
Tests: fix formatting and spelling
NaderAlAwar Feb 19, 2024
65cbeeb
Optimizations: don't fuse loops with print statements
NaderAlAwar Feb 19, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions pykokkos/core/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing import Dict, List, Optional, Tuple

from pykokkos.core.fusion import fuse_workunits
from pykokkos.core.optimizations import loop_fuse, memory_ops_fuse
from pykokkos.core.parsers import Parser, PyKokkosEntity, PyKokkosStyles
from pykokkos.core.translators import PyKokkosMembers, StaticTranslator
from pykokkos.core.type_inference import UpdatedTypes, UpdatedDecorator
Expand Down Expand Up @@ -223,6 +224,11 @@ def compile_entity(
bindings: List[str]
cast: List[str]

if entity.style in {PyKokkosStyles.workunit, PyKokkosStyles.fused}:
if "PK_LOOP_FUSE" in os.environ:
loop_fuse(entity.AST)
if "PK_MEM_FUSE" in os.environ:
memory_ops_fuse(entity.AST, entity.pk_import)
functor, bindings, cast = translator.translate(entity, classtypes)

t_end: float = time.perf_counter() - t_start
Expand Down
47 changes: 1 addition & 46 deletions pykokkos/core/fusion/fuse.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,52 +2,7 @@
import os
from typing import Any, Dict, List, Set, Tuple, Union

def get_node_name(node: Union[ast.Attribute, ast.Name]) -> str:
    """
    Return the identifier held by a Name or Attribute node

    Duplicated from visitors_util.py because importing it from there
    would create a circular import.

    :param node: the AST node to inspect
    :returns: node.attr for an Attribute node, node.id otherwise
    """

    return node.attr if isinstance(node, ast.Attribute) else node.id


class DeclarationsVisitor(ast.NodeVisitor):
    """
    Collect the names of all annotated variable declarations

    After visiting a tree, the collected names are available in the
    declarations set.
    """

    def __init__(self) -> None:
        # Names gathered so far; filled in as nodes are visited
        self.declarations: Set[str] = set()

    def visit_AnnAssign(self, node: ast.AnnAssign) -> Any:
        # An annotated assignment declares its target
        declared_name: str = get_node_name(node.target)
        self.declarations.add(declared_name)


class VariableRenamer(ast.NodeTransformer):
    """
    Renames variables in a fused ast according to a map

    The map is keyed by (original name, workunit index). Names with no
    entry for this renamer's workunit index are left unchanged.
    """

    def __init__(self, name_map: Dict[Tuple[str, int], str], workunit_idx: int):
        """
        :param name_map: maps (original name, workunit index) to the new name
        :param workunit_idx: the workunit index used to build lookup keys
        """

        self.name_map = name_map
        self.workunit_idx = workunit_idx

    def visit_Name(self, node: ast.Name) -> Any:
        """
        Rename a Name node in place

        :param node: the Name node being visited
        :returns: the same node, as ast.NodeTransformer expects
        """

        key = (node.id, self.workunit_idx)
        # If the name is not mapped, keep the original name
        node.id = self.name_map.get(key, node.id)
        return node

    def visit_keyword(self, node: ast.keyword) -> Any:
        """
        Rename the argument name of a keyword argument in place

        :param node: the keyword node being visited
        :returns: the same node, as ast.NodeTransformer expects
        """

        # ast.keyword stores its name in "arg" (it has no "id" field);
        # arg is None for a "**kwargs" expansion, which is never mapped
        key = (node.arg, self.workunit_idx)
        # If the name is not mapped, keep the original name
        node.arg = self.name_map.get(key, node.arg)
        # Overriding this visitor stops the default traversal, so descend
        # explicitly to rename any names used in the keyword's value
        self.generic_visit(node)
        return node
from .util import DeclarationsVisitor, VariableRenamer


def fuse_workunit_kwargs_and_params(
Expand Down
2 changes: 1 addition & 1 deletion pykokkos/core/fusion/future.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __str__(self):
return str(self.value)

def __repr__(self) -> str:
return str(f"Future(value={self.value})")
return f"Future(value={self.value})"

def flush_trace(self) -> None:
runtime_singleton.runtime.flush_data(self)
Expand Down
54 changes: 54 additions & 0 deletions pykokkos/core/fusion/util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import ast
from typing import Dict, Set, Tuple, Union


def get_node_name(node: Union[ast.Attribute, ast.Name]) -> str:
    """
    Get the identifier stored in a Name or Attribute node

    Copied from visitors_util.py, since importing it from there would
    cause a circular import.

    :param node: the AST node to read the name from
    :returns: the attr of an Attribute node, or the id of any other node
    """

    if isinstance(node, ast.Attribute):
        return node.attr

    return node.id


class DeclarationsVisitor(ast.NodeVisitor):
    """
    Collect the names of all variable declarations

    Both annotated assignment targets and for-loop targets count as
    declarations; the names end up in the declarations set.
    """

    def __init__(self) -> None:
        # Names gathered so far; filled in as nodes are visited
        self.declarations: Set[str] = set()

    def visit_AnnAssign(self, node: ast.AnnAssign) -> None:
        # An annotated assignment declares its target
        declared_name: str = get_node_name(node.target)
        self.declarations.add(declared_name)

    def visit_For(self, node: ast.For) -> None:
        # The loop variable is itself a declaration
        loop_variable: str = get_node_name(node.target)
        self.declarations.add(loop_variable)

        # Only the loop body is traversed for further declarations
        for statement in node.body:
            self.visit(statement)

class VariableRenamer(ast.NodeTransformer):
    """
    Renames variables in a fused ast according to a map

    The map is keyed by (original name, workunit index). Names with no
    entry for this renamer's workunit index are left unchanged.
    """

    def __init__(self, name_map: Dict[Tuple[str, int], str], workunit_idx: int):
        """
        :param name_map: maps (original name, workunit index) to the new name
        :param workunit_idx: the workunit index used to build lookup keys
        """

        self.name_map = name_map
        self.workunit_idx = workunit_idx

    def visit_Name(self, node: ast.Name) -> ast.Name:
        """
        Rename a Name node in place

        :param node: the Name node being visited
        :returns: the same node, as ast.NodeTransformer expects
        """

        key = (node.id, self.workunit_idx)
        # If the name is not mapped, keep the original name
        node.id = self.name_map.get(key, node.id)
        return node

    def visit_keyword(self, node: ast.keyword) -> ast.keyword:
        """
        Rename the argument name of a keyword argument in place

        :param node: the keyword node being visited
        :returns: the same node, as ast.NodeTransformer expects
        """

        # ast.keyword stores its name in "arg" (it has no "id" field);
        # arg is None for a "**kwargs" expansion, which is never mapped
        key = (node.arg, self.workunit_idx)
        # If the name is not mapped, keep the original name
        node.arg = self.name_map.get(key, node.arg)
        # Overriding this visitor stops the default traversal, so descend
        # explicitly to rename any names used in the keyword's value
        self.generic_visit(node)
        return node
2 changes: 2 additions & 0 deletions pykokkos/core/optimizations/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .loop_fuse import loop_fuse
from .memory_ops_fuse import memory_ops_fuse
Loading
Loading