Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Passthrough params #708

Merged
merged 15 commits into from
Jan 24, 2024
41 changes: 28 additions & 13 deletions pyop2/codegen/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@
from functools import reduce

import numpy
from loopy.types import OpaqueType
from pyop2.global_kernel import (GlobalKernelArg, DatKernelArg, MixedDatKernelArg,
MatKernelArg, MixedMatKernelArg, PermutedMapKernelArg, ComposedMapKernelArg)
MatKernelArg, MixedMatKernelArg, PermutedMapKernelArg, ComposedMapKernelArg, PassthroughKernelArg)
from pyop2.codegen.representation import (Accumulate, Argument, Comparison, Conditional,
DummyInstruction, Extent, FixedIndex,
FunctionCall, Index, Indexed,
Expand All @@ -16,18 +15,12 @@
PreUnpackInst, Product, RuntimeIndex,
Sum, Symbol, UnpackInst, Variable,
When, Zero)
from pyop2.datatypes import IntType
from pyop2.datatypes import IntType, PetscMatType
from pyop2.op2 import (ALL, INC, MAX, MIN, ON_BOTTOM, ON_INTERIOR_FACETS,
ON_TOP, READ, RW, WRITE)
from pyop2.utils import cached_property


class PetscMat(OpaqueType):

def __init__(self):
super().__init__(name="Mat")


def _Remainder(a, b):
# ad hoc replacement of Remainder()
# Replace this with Remainder(a, b) once it gets fixed.
Expand Down Expand Up @@ -226,6 +219,23 @@ def emit_unpack_instruction(self, *, loop_indices=None):
"""Either yield an instruction, or else return an empty tuple (to indicate no instruction)"""


class PassthroughPack(Pack):
def __init__(self, outer):
self.outer = outer

def kernel_arg(self, loop_indices=None):
return self.outer

def pack(self, loop_indices=None):
pass

def emit_pack_instruction(self, **kwargs):
return ()

def emit_unpack_instruction(self, **kwargs):
return ()


class GlobalPack(Pack):

def __init__(self, outer, access, init_with_zero=False):
Expand Down Expand Up @@ -813,7 +823,12 @@ def add_argument(self, arg):
dtype = local_arg.dtype
interior_horizontal = self.iteration_region == ON_INTERIOR_FACETS

if isinstance(arg, GlobalKernelArg):
if isinstance(arg, PassthroughKernelArg):
argument = Argument((), dtype, pfx="arg")
pack = PassthroughPack(argument)
self.arguments.append(argument)

elif isinstance(arg, GlobalKernelArg):
argument = Argument(arg.dim, dtype, pfx="glob")

pack = GlobalPack(argument, access,
Expand Down Expand Up @@ -856,7 +871,7 @@ def add_argument(self, arg):
pack = MixedDatPack(packs, access, dtype,
interior_horizontal=interior_horizontal)
elif isinstance(arg, MatKernelArg):
argument = Argument((), PetscMat(), pfx="mat")
argument = Argument((), PetscMatType(), pfx="mat")
maps = tuple(self._add_map(m, arg.unroll)
for m in arg.maps)
pack = arg.pack(argument, access, maps,
Expand All @@ -866,7 +881,7 @@ def add_argument(self, arg):
elif isinstance(arg, MixedMatKernelArg):
packs = []
for a in arg:
argument = Argument((), PetscMat(), pfx="mat")
argument = Argument((), PetscMatType(), pfx="mat")
maps = tuple(self._add_map(m, a.unroll)
for m in a.maps)

Expand Down Expand Up @@ -949,7 +964,7 @@ def kernel_call(self):
args = self.kernel_args
access = tuple(self.loopy_argument_accesses)
# assuming every index is free index
free_indices = set(itertools.chain.from_iterable(arg.multiindex for arg in args))
free_indices = set(itertools.chain.from_iterable(arg.multiindex for arg in args if isinstance(arg, Indexed)))
# remove runtime index
free_indices = tuple(i for i in free_indices if isinstance(i, Index))
if self.pass_layer_to_kernel:
Expand Down
2 changes: 0 additions & 2 deletions pyop2/codegen/representation.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,8 +352,6 @@ def __new__(cls, aggregate, multiindex):
for index, extent in zip(multiindex, aggregate.shape):
if isinstance(index, Index):
index.set_extent(extent)
if not multiindex:
return aggregate

self = super().__new__(cls)
self.children = (aggregate, multiindex)
Expand Down
8 changes: 8 additions & 0 deletions pyop2/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,11 @@ def dtype_limits(dtype):
except ValueError as e:
raise ValueError("Unable to determine numeric limits from %s" % dtype) from e
return info.min, info.max


class PetscMatType(lp.types.OpaqueType):
def __init__(self):
super().__init__(name="Mat")

def __repr__(self):
return type(self).__name__
10 changes: 10 additions & 0 deletions pyop2/global_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,16 @@ def pack(self):
return DatPack


class PassthroughKernelArg:
@property
def cache_key(self):
return type(self)

@property
def maps(self):
return ()


@dataclass(frozen=True)
class MixedMatKernelArg:
"""Class representing a :class:`pyop2.types.MixedDat` being passed to the kernel.
Expand Down
3 changes: 2 additions & 1 deletion pyop2/op2.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import atexit

from pyop2.configuration import configuration
from pyop2.datatypes import PetscMatType # noqa: F401
from pyop2.logger import debug, info, warning, error, critical, set_log_level
from pyop2.mpi import MPI, COMM_WORLD, collective

Expand All @@ -52,7 +53,7 @@
from pyop2.global_kernel import (GlobalKernelArg, DatKernelArg, MixedDatKernelArg, # noqa: F401
MatKernelArg, MixedMatKernelArg, MapKernelArg, GlobalKernel)
from pyop2.parloop import (GlobalParloopArg, DatParloopArg, MixedDatParloopArg, # noqa: F401
MatParloopArg, MixedMatParloopArg, Parloop, parloop, par_loop)
MatParloopArg, MixedMatParloopArg, PassthroughArg, Parloop, parloop, par_loop)
from pyop2.parloop import (GlobalLegacyArg, DatLegacyArg, MixedDatLegacyArg, # noqa: F401
MatLegacyArg, MixedMatLegacyArg, LegacyParloop, ParLoop)

Expand Down
86 changes: 83 additions & 3 deletions pyop2/parloop.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from pyop2.datatypes import as_numpy_dtype
from pyop2.exceptions import KernelTypeError, MapValueError, SetTypeError
from pyop2.global_kernel import (GlobalKernelArg, DatKernelArg, MixedDatKernelArg,
MatKernelArg, MixedMatKernelArg, GlobalKernel)
MatKernelArg, MixedMatKernelArg, PassthroughKernelArg, GlobalKernel)
from pyop2.local_kernel import LocalKernel, CStringLocalKernel, LoopyLocalKernel
from pyop2.types import (Access, Global, AbstractDat, Dat, DatView, MixedDat, Mat, Set,
MixedSet, ExtrudedSet, Subset, Map, ComposedMap, MixedMap)
Expand All @@ -38,6 +38,10 @@ class GlobalParloopArg(ParloopArg):

data: Global

@property
def _kernel_args_(self):
return self.data._kernel_args_

@property
def map_kernel_args(self):
return ()
Expand All @@ -58,6 +62,10 @@ def __post_init__(self):
if self.map_ is not None:
self.check_map(self.map_)

@property
def _kernel_args_(self):
return self.data._kernel_args_

@property
def map_kernel_args(self):
return self.map_._kernel_args_ if self.map_ else ()
Expand All @@ -80,6 +88,10 @@ class MixedDatParloopArg(ParloopArg):
def __post_init__(self):
self.check_map(self.map_)

@property
def _kernel_args_(self):
return self.data._kernel_args_

@property
def map_kernel_args(self):
return self.map_._kernel_args_ if self.map_ else ()
Expand All @@ -101,6 +113,10 @@ def __post_init__(self):
for m in self.maps:
self.check_map(m)

@property
def _kernel_args_(self):
return self.data._kernel_args_

@property
def map_kernel_args(self):
rmap, cmap = self.maps
Expand All @@ -119,12 +135,34 @@ def __post_init__(self):
for m in self.maps:
self.check_map(m)

@property
def _kernel_args_(self):
return self.data._kernel_args_

@property
def map_kernel_args(self):
rmap, cmap = self.maps
return tuple(itertools.chain(*itertools.product(rmap._kernel_args_, cmap._kernel_args_)))


@dataclass
class PassthroughParloopArg(ParloopArg):
# a pointer
data: int

@property
def _kernel_args_(self):
return (self.data,)

@property
def map_kernel_args(self):
return ()

@property
def maps(self):
return ()


class Parloop:
"""A parallel loop invocation.

Expand Down Expand Up @@ -170,7 +208,7 @@ def arglist(self):
"""Prepare the argument list for calling generated code."""
arglist = self.iterset._kernel_args_
for d in self.arguments:
arglist += d.data._kernel_args_
arglist += d._kernel_args_

# Collect an ordered set of maps (ignore duplicates)
maps = {m: None for d in self.arguments for m in d.map_kernel_args}
Expand Down Expand Up @@ -504,6 +542,10 @@ class GlobalLegacyArg(LegacyArg):
data: Global
access: Access

@property
def dtype(self):
return self.data.dtype

@property
def global_kernel_arg(self):
return GlobalKernelArg(self.data.dim)
Expand All @@ -521,6 +563,10 @@ class DatLegacyArg(LegacyArg):
map_: Optional[Map]
access: Access

@property
def dtype(self):
return self.data.dtype

@property
def global_kernel_arg(self):
map_arg = self.map_._global_kernel_arg if self.map_ is not None else None
Expand All @@ -540,6 +586,10 @@ class MixedDatLegacyArg(LegacyArg):
map_: MixedMap
access: Access

@property
def dtype(self):
return self.data.dtype

@property
def global_kernel_arg(self):
args = []
Expand All @@ -563,6 +613,10 @@ class MatLegacyArg(LegacyArg):
lgmaps: Optional[Tuple[Any, Any]] = None
needs_unrolling: Optional[bool] = False

@property
def dtype(self):
return self.data.dtype

@property
def global_kernel_arg(self):
map_args = [m._global_kernel_arg for m in self.maps]
Expand All @@ -583,6 +637,10 @@ class MixedMatLegacyArg(LegacyArg):
lgmaps: Tuple[Any] = None
needs_unrolling: Optional[bool] = False

@property
def dtype(self):
return self.data.dtype

@property
def global_kernel_arg(self):
nrows, ncols = self.data.sparsity.shape
Expand All @@ -602,6 +660,28 @@ def parloop_arg(self):
return MixedMatParloopArg(self.data, tuple(self.maps), self.lgmaps)


@dataclass
class PassthroughArg(LegacyArg):
"""Argument that is simply passed to the local kernel without packing.

:param dtype: The datatype of the argument. This is needed for code generation.
:param data: A pointer to the data.
"""
# We don't know what the local kernel is doing with this argument
access = Access.RW

dtype: Any
data: int

@property
def global_kernel_arg(self):
return PassthroughKernelArg()

@property
def parloop_arg(self):
return PassthroughParloopArg(self.data)


def ParLoop(*args, **kwargs):
return LegacyParloop(*args, **kwargs)

Expand All @@ -625,7 +705,7 @@ def LegacyParloop(local_knl, iterset, *args, **kwargs):
# finish building the local kernel
local_knl.accesses = tuple(a.access for a in args)
if isinstance(local_knl, CStringLocalKernel):
local_knl.dtypes = tuple(a.data.dtype for a in args)
local_knl.dtypes = tuple(a.dtype for a in args)

global_knl_args = tuple(a.global_kernel_arg for a in args)
extruded = iterset._extruded
Expand Down
36 changes: 36 additions & 0 deletions test/unit/test_direct_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

import pytest
import numpy as np
from petsc4py import PETSc

from pyop2 import op2
from pyop2.exceptions import MapValueError
Expand Down Expand Up @@ -249,6 +250,41 @@ def test_kernel_cplusplus(self, delems):

assert (y.data == 10.5).all()

def test_passthrough_mat(self):
niters = 10
iterset = op2.Set(niters)

c_kernel = """
static void mat_inc(Mat mat) {
PetscScalar values[] = {1, 2, 3, 4, 5, 6, 7, 8, 9};
PetscInt idxs[] = {0, 2, 4};
MatSetValues(mat, 3, idxs, 3, idxs, values, ADD_VALUES);
}
"""
kernel = op2.Kernel(c_kernel, "mat_inc")

# create a tiny 5x5 sparse matrix
petsc_mat = PETSc.Mat().create()
petsc_mat.setSizes(5)
petsc_mat.setUp()
petsc_mat.setValues([0, 2, 4], [0, 2, 4], np.zeros((3, 3), dtype=PETSc.ScalarType))
petsc_mat.assemble()

arg = op2.PassthroughArg(op2.PetscMatType(), petsc_mat.handle)
op2.par_loop(kernel, iterset, arg)
petsc_mat.assemble()

assert np.allclose(
petsc_mat.getValues(range(5), range(5)),
[
[10, 0, 20, 0, 30],
[0]*5,
[40, 0, 50, 0, 60],
[0]*5,
[70, 0, 80, 0, 90],
]
)


if __name__ == '__main__':
import os
Expand Down
Loading