Skip to content

Commit

Permalink
Passthrough params (#708)
Browse files Browse the repository at this point in the history
Pass objects to local kernels without packing and unpacking.

---------

Co-authored-by: Connor Ward <[email protected]>
  • Loading branch information
pbrubeck and connorjward authored Jan 24, 2024
1 parent ad0c430 commit e0a4d3a
Show file tree
Hide file tree
Showing 7 changed files with 170 additions and 17 deletions.
40 changes: 29 additions & 11 deletions pyop2/codegen/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@
from functools import reduce

import numpy
from loopy.types import OpaqueType
from pyop2.global_kernel import (GlobalKernelArg, DatKernelArg, MixedDatKernelArg,
MatKernelArg, MixedMatKernelArg, PermutedMapKernelArg, ComposedMapKernelArg)
MatKernelArg, MixedMatKernelArg, PermutedMapKernelArg, ComposedMapKernelArg, PassthroughKernelArg)
from pyop2.codegen.representation import (Accumulate, Argument, Comparison, Conditional,
DummyInstruction, Extent, FixedIndex,
FunctionCall, Index, Indexed,
Expand All @@ -16,16 +15,13 @@
PreUnpackInst, Product, RuntimeIndex,
Sum, Symbol, UnpackInst, Variable,
When, Zero)
from pyop2.datatypes import IntType
from pyop2.datatypes import IntType, OpaqueType
from pyop2.op2 import (ALL, INC, MAX, MIN, ON_BOTTOM, ON_INTERIOR_FACETS,
ON_TOP, READ, RW, WRITE)
from pyop2.utils import cached_property


class PetscMat(OpaqueType):
    """Opaque type whose name is fixed to ``"Mat"``."""

    def __init__(self):
        # The type name is not configurable: it is always "Mat".
        type_name = "Mat"
        super().__init__(name=type_name)
MatType = OpaqueType("Mat")


def _Remainder(a, b):
Expand Down Expand Up @@ -226,6 +222,23 @@ def emit_unpack_instruction(self, *, loop_indices=None):
"""Either yield an instruction, or else return an empty tuple (to indicate no instruction)"""


class PassthroughPack(Pack):
    """Pack that forwards ``outer`` to the kernel and emits no instructions."""

    def __init__(self, outer):
        # The object that is handed to the kernel unchanged.
        self.outer = outer

    def kernel_arg(self, loop_indices=None):
        """Return the wrapped ``outer`` object; ``loop_indices`` is ignored."""
        return self.outer

    def pack(self, loop_indices=None):
        """Do nothing and return ``None``; ``loop_indices`` is ignored."""
        return None

    def emit_pack_instruction(self, **kwargs):
        """Return an empty tuple, i.e. no pack instruction."""
        return ()

    def emit_unpack_instruction(self, **kwargs):
        """Return an empty tuple, i.e. no unpack instruction."""
        return ()


class GlobalPack(Pack):

def __init__(self, outer, access, init_with_zero=False):
Expand Down Expand Up @@ -813,7 +826,12 @@ def add_argument(self, arg):
dtype = local_arg.dtype
interior_horizontal = self.iteration_region == ON_INTERIOR_FACETS

if isinstance(arg, GlobalKernelArg):
if isinstance(arg, PassthroughKernelArg):
argument = Argument((), dtype, pfx="arg")
pack = PassthroughPack(argument)
self.arguments.append(argument)

elif isinstance(arg, GlobalKernelArg):
argument = Argument(arg.dim, dtype, pfx="glob")

pack = GlobalPack(argument, access,
Expand Down Expand Up @@ -856,7 +874,7 @@ def add_argument(self, arg):
pack = MixedDatPack(packs, access, dtype,
interior_horizontal=interior_horizontal)
elif isinstance(arg, MatKernelArg):
argument = Argument((), PetscMat(), pfx="mat")
argument = Argument((), MatType, pfx="mat")
maps = tuple(self._add_map(m, arg.unroll)
for m in arg.maps)
pack = arg.pack(argument, access, maps,
Expand All @@ -866,7 +884,7 @@ def add_argument(self, arg):
elif isinstance(arg, MixedMatKernelArg):
packs = []
for a in arg:
argument = Argument((), PetscMat(), pfx="mat")
argument = Argument((), MatType, pfx="mat")
maps = tuple(self._add_map(m, a.unroll)
for m in a.maps)

Expand Down Expand Up @@ -949,7 +967,7 @@ def kernel_call(self):
args = self.kernel_args
access = tuple(self.loopy_argument_accesses)
# assuming every index is free index
free_indices = set(itertools.chain.from_iterable(arg.multiindex for arg in args))
free_indices = set(itertools.chain.from_iterable(arg.multiindex for arg in args if isinstance(arg, Indexed)))
# remove runtime index
free_indices = tuple(i for i in free_indices if isinstance(i, Index))
if self.pass_layer_to_kernel:
Expand Down
2 changes: 0 additions & 2 deletions pyop2/codegen/representation.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,8 +352,6 @@ def __new__(cls, aggregate, multiindex):
for index, extent in zip(multiindex, aggregate.shape):
if isinstance(index, Index):
index.set_extent(extent)
if not multiindex:
return aggregate

self = super().__new__(cls)
self.children = (aggregate, multiindex)
Expand Down
8 changes: 8 additions & 0 deletions pyop2/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,11 @@ def dtype_limits(dtype):
except ValueError as e:
raise ValueError("Unable to determine numeric limits from %s" % dtype) from e
return info.min, info.max


class OpaqueType(lp.types.OpaqueType):
    """Loopy opaque type built from a positional name and shown by that name."""

    def __init__(self, name):
        """Forward ``name`` to the parent constructor as a keyword argument."""
        super().__init__(name=name)

    def __repr__(self):
        """Return the type's ``name`` attribute."""
        return self.name
10 changes: 10 additions & 0 deletions pyop2/global_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,16 @@ def pack(self):
return DatPack


@dataclass(frozen=True)
class PassthroughKernelArg:
    """Kernel argument that is handed to the local kernel as-is.

    The class carries no state. It is declared as a frozen dataclass, like
    :class:`MixedMatKernelArg`, so that instances compare equal and hash by
    value instead of by identity.
    """

    @property
    def cache_key(self):
        """Return the class itself, which serves as the cache key."""
        return type(self)

    @property
    def maps(self):
        """Return an empty tuple: no maps are attached to this argument."""
        return ()


@dataclass(frozen=True)
class MixedMatKernelArg:
"""Class representing a :class:`pyop2.types.MixedDat` being passed to the kernel.
Expand Down
3 changes: 2 additions & 1 deletion pyop2/op2.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import atexit

from pyop2.configuration import configuration
from pyop2.datatypes import OpaqueType # noqa: F401
from pyop2.logger import debug, info, warning, error, critical, set_log_level
from pyop2.mpi import MPI, COMM_WORLD, collective

Expand All @@ -52,7 +53,7 @@
from pyop2.global_kernel import (GlobalKernelArg, DatKernelArg, MixedDatKernelArg, # noqa: F401
MatKernelArg, MixedMatKernelArg, MapKernelArg, GlobalKernel)
from pyop2.parloop import (GlobalParloopArg, DatParloopArg, MixedDatParloopArg, # noqa: F401
MatParloopArg, MixedMatParloopArg, Parloop, parloop, par_loop)
MatParloopArg, MixedMatParloopArg, PassthroughArg, Parloop, parloop, par_loop)
from pyop2.parloop import (GlobalLegacyArg, DatLegacyArg, MixedDatLegacyArg, # noqa: F401
MatLegacyArg, MixedMatLegacyArg, LegacyParloop, ParLoop)

Expand Down
88 changes: 85 additions & 3 deletions pyop2/parloop.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from pyop2.datatypes import as_numpy_dtype
from pyop2.exceptions import KernelTypeError, MapValueError, SetTypeError
from pyop2.global_kernel import (GlobalKernelArg, DatKernelArg, MixedDatKernelArg,
MatKernelArg, MixedMatKernelArg, GlobalKernel)
MatKernelArg, MixedMatKernelArg, PassthroughKernelArg, GlobalKernel)
from pyop2.local_kernel import LocalKernel, CStringLocalKernel, LoopyLocalKernel
from pyop2.types import (Access, Global, AbstractDat, Dat, DatView, MixedDat, Mat, Set,
MixedSet, ExtrudedSet, Subset, Map, ComposedMap, MixedMap)
Expand All @@ -39,6 +39,10 @@ class GlobalParloopArg(ParloopArg):

data: Global

@property
def _kernel_args_(self):
    """Return the ``_kernel_args_`` of the wrapped ``data`` object."""
    wrapped = self.data
    return wrapped._kernel_args_

@property
def map_kernel_args(self):
return ()
Expand All @@ -59,6 +63,10 @@ def __post_init__(self):
if self.map_ is not None:
self.check_map(self.map_)

@property
def _kernel_args_(self):
    """Return the ``_kernel_args_`` of the wrapped ``data`` object."""
    wrapped = self.data
    return wrapped._kernel_args_

@property
def map_kernel_args(self):
return self.map_._kernel_args_ if self.map_ else ()
Expand All @@ -81,6 +89,10 @@ class MixedDatParloopArg(ParloopArg):
def __post_init__(self):
self.check_map(self.map_)

@property
def _kernel_args_(self):
    """Return the ``_kernel_args_`` of the wrapped ``data`` object."""
    wrapped = self.data
    return wrapped._kernel_args_

@property
def map_kernel_args(self):
return self.map_._kernel_args_ if self.map_ else ()
Expand All @@ -102,6 +114,10 @@ def __post_init__(self):
for m in self.maps:
self.check_map(m)

@property
def _kernel_args_(self):
    """Return the ``_kernel_args_`` of the wrapped ``data`` object."""
    wrapped = self.data
    return wrapped._kernel_args_

@property
def map_kernel_args(self):
rmap, cmap = self.maps
Expand All @@ -120,12 +136,34 @@ def __post_init__(self):
for m in self.maps:
self.check_map(m)

@property
def _kernel_args_(self):
    """Return the ``_kernel_args_`` of the wrapped ``data`` object."""
    wrapped = self.data
    return wrapped._kernel_args_

@property
def map_kernel_args(self):
    """Return the flattened row/column map kernel-argument pairs.

    Every kernel argument of the row map is paired with every kernel
    argument of the column map, and the pairs are concatenated into one
    flat tuple.
    """
    row_map, col_map = self.maps
    pairs = itertools.product(row_map._kernel_args_, col_map._kernel_args_)
    return tuple(itertools.chain.from_iterable(pairs))


@dataclass
class PassthroughParloopArg(ParloopArg):
    """Parloop argument whose ``data`` is forwarded without any maps."""

    # a pointer
    data: int

    @property
    def maps(self):
        """Return an empty tuple: this argument has no maps."""
        return ()

    @property
    def map_kernel_args(self):
        """Return an empty tuple: no map kernel arguments are contributed."""
        return ()

    @property
    def _kernel_args_(self):
        """Return ``data`` wrapped in a one-element tuple."""
        pointer = self.data
        return (pointer,)


class Parloop:
"""A parallel loop invocation.
Expand Down Expand Up @@ -167,7 +205,7 @@ def arglist(self):
"""Prepare the argument list for calling generated code."""
arglist = self.iterset._kernel_args_
for d in self.arguments:
arglist += d.data._kernel_args_
arglist += d._kernel_args_

# Collect an ordered set of maps (ignore duplicates)
maps = {m: None for d in self.arguments for m in d.map_kernel_args}
Expand Down Expand Up @@ -224,6 +262,8 @@ def __call__(self):
def increment_dat_version(self):
"""Increment dat versions of :class:`DataCarrier`s in the arguments."""
for lk_arg, gk_arg, pl_arg in self.zipped_arguments:
if isinstance(pl_arg, PassthroughParloopArg):
continue
assert isinstance(pl_arg.data, DataCarrier)
if lk_arg.access is not Access.READ:
if pl_arg.data in self.reduced_globals:
Expand Down Expand Up @@ -520,6 +560,10 @@ class GlobalLegacyArg(LegacyArg):
data: Global
access: Access

@property
def dtype(self):
    """Return the ``dtype`` of the wrapped ``data`` object."""
    carrier = self.data
    return carrier.dtype

@property
def global_kernel_arg(self):
return GlobalKernelArg(self.data.dim)
Expand All @@ -537,6 +581,10 @@ class DatLegacyArg(LegacyArg):
map_: Optional[Map]
access: Access

@property
def dtype(self):
    """Return the ``dtype`` of the wrapped ``data`` object."""
    carrier = self.data
    return carrier.dtype

@property
def global_kernel_arg(self):
map_arg = self.map_._global_kernel_arg if self.map_ is not None else None
Expand All @@ -556,6 +604,10 @@ class MixedDatLegacyArg(LegacyArg):
map_: MixedMap
access: Access

@property
def dtype(self):
    """Return the ``dtype`` of the wrapped ``data`` object."""
    carrier = self.data
    return carrier.dtype

@property
def global_kernel_arg(self):
args = []
Expand All @@ -579,6 +631,10 @@ class MatLegacyArg(LegacyArg):
lgmaps: Optional[Tuple[Any, Any]] = None
needs_unrolling: Optional[bool] = False

@property
def dtype(self):
    """Return the ``dtype`` of the wrapped ``data`` object."""
    carrier = self.data
    return carrier.dtype

@property
def global_kernel_arg(self):
map_args = [m._global_kernel_arg for m in self.maps]
Expand All @@ -599,6 +655,10 @@ class MixedMatLegacyArg(LegacyArg):
lgmaps: Tuple[Any] = None
needs_unrolling: Optional[bool] = False

@property
def dtype(self):
    """Return the ``dtype`` of the wrapped ``data`` object."""
    carrier = self.data
    return carrier.dtype

@property
def global_kernel_arg(self):
nrows, ncols = self.data.sparsity.shape
Expand All @@ -618,6 +678,28 @@ def parloop_arg(self):
return MixedMatParloopArg(self.data, tuple(self.maps), self.lgmaps)


@dataclass
class PassthroughArg(LegacyArg):
    """Argument that is simply passed to the local kernel without packing.

    :param dtype: The datatype of the argument. This is needed for code generation.
    :param data: A pointer to the data.
    """

    dtype: Any
    data: int

    # What the local kernel does with this argument is unknown, so the
    # access is declared as read-write.
    access = Access.RW

    @property
    def global_kernel_arg(self):
        """Return a new :class:`PassthroughKernelArg`."""
        kernel_arg = PassthroughKernelArg()
        return kernel_arg

    @property
    def parloop_arg(self):
        """Return a :class:`PassthroughParloopArg` wrapping ``data``."""
        pointer = self.data
        return PassthroughParloopArg(pointer)


def ParLoop(*args, **kwargs):
    """Forward all arguments to :func:`LegacyParloop` and return its result."""
    legacy_loop = LegacyParloop(*args, **kwargs)
    return legacy_loop

Expand All @@ -641,7 +723,7 @@ def LegacyParloop(local_knl, iterset, *args, **kwargs):
# finish building the local kernel
local_knl.accesses = tuple(a.access for a in args)
if isinstance(local_knl, CStringLocalKernel):
local_knl.dtypes = tuple(a.data.dtype for a in args)
local_knl.dtypes = tuple(a.dtype for a in args)

global_knl_args = tuple(a.global_kernel_arg for a in args)
extruded = iterset._extruded
Expand Down
36 changes: 36 additions & 0 deletions test/unit/test_direct_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

import pytest
import numpy as np
from petsc4py import PETSc

from pyop2 import op2
from pyop2.exceptions import MapValueError
Expand Down Expand Up @@ -249,6 +250,41 @@ def test_kernel_cplusplus(self, delems):

assert (y.data == 10.5).all()

# Check that a raw PETSc Mat handle wrapped in op2.PassthroughArg reaches the
# C kernel, which then calls MatSetValues on it directly.
def test_passthrough_mat(self):
niters = 10
iterset = op2.Set(niters)

# The kernel adds (ADD_VALUES) the 3x3 block of values 1..9 at rows and
# columns {0, 2, 4} of the matrix it is given.
c_kernel = """
static void mat_inc(Mat mat) {
PetscScalar values[] = {1, 2, 3, 4, 5, 6, 7, 8, 9};
PetscInt idxs[] = {0, 2, 4};
MatSetValues(mat, 3, idxs, 3, idxs, values, ADD_VALUES);
}
"""
kernel = op2.Kernel(c_kernel, "mat_inc")

# create a tiny 5x5 sparse matrix
petsc_mat = PETSc.Mat().create()
petsc_mat.setSizes(5)
petsc_mat.setUp()
# Insert explicit zeros at the entries the kernel writes to
# (presumably to fix the nonzero pattern before the loop - TODO confirm).
petsc_mat.setValues([0, 2, 4], [0, 2, 4], np.zeros((3, 3), dtype=PETSc.ScalarType))
petsc_mat.assemble()

# Pass the raw handle together with an opaque "Mat" type as the dtype.
arg = op2.PassthroughArg(op2.OpaqueType("Mat"), petsc_mat.handle)
op2.par_loop(kernel, iterset, arg)
petsc_mat.assemble()

# The loop runs over a set of niters = 10 elements and every touched entry
# is expected to hold 10 times the value the kernel adds.
assert np.allclose(
petsc_mat.getValues(range(5), range(5)),
[
[10, 0, 20, 0, 30],
[0]*5,
[40, 0, 50, 0, 60],
[0]*5,
[70, 0, 80, 0, 90],
]
)


if __name__ == '__main__':
import os
Expand Down

0 comments on commit e0a4d3a

Please sign in to comment.