Skip to content

Commit

Permalink
composed map: add permute method (#723)
Browse files Browse the repository at this point in the history

---------

Co-authored-by: Connor Ward <[email protected]>
  • Loading branch information
ksagiyam and connorjward authored Jun 18, 2024
1 parent af813e9 commit 5f18075
Showing 1 changed file with 10 additions and 12 deletions.
22 changes: 10 additions & 12 deletions pyop2/codegen/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,10 @@ def shape(self):
def dtype(self):
return self.values.dtype

def indexed(self, multiindex, layer=None, permute=lambda x: x):
def _permute(self, x):
return x

def indexed(self, multiindex, layer=None):
n, i, f = multiindex
if layer is not None and self.offset is not None:
# For extruded mesh, prefetch the indirections for each map, so that they don't
Expand All @@ -84,7 +87,7 @@ def indexed(self, multiindex, layer=None, permute=lambda x: x):
base_key = None
if base_key not in self.prefetch:
j = Index()
base = Indexed(self.values, (n, permute(j)))
base = Indexed(self.values, (n, self._permute(j)))
self.prefetch[base_key] = Materialise(PackInst(), base, MultiIndex(j))

base = self.prefetch[base_key]
Expand Down Expand Up @@ -122,17 +125,17 @@ def indexed(self, multiindex, layer=None, permute=lambda x: x):
return Indexed(self.prefetch[key], (f, i)), (f, i)
else:
assert f.extent == 1 or f.extent is None
base = Indexed(self.values, (n, permute(i)))
base = Indexed(self.values, (n, self._permute(i)))
return base, (f, i)

def indexed_vector(self, n, shape, layer=None, permute=lambda x: x):
def indexed_vector(self, n, shape, layer=None):
shape = self.shape[1:] + shape
if self.interior_horizontal:
shape = (2, ) + shape
else:
shape = (1, ) + shape
f, i, j = (Index(e) for e in shape)
base, (f, i) = self.indexed((n, i, f), layer=layer, permute=permute)
base, (f, i) = self.indexed((n, i, f), layer=layer)
init = Sum(Product(base, Literal(numpy.int32(j.extent))), j)
pack = Materialise(PackInst(), init, MultiIndex(f, i, j))
multiindex = tuple(Index(e) for e in pack.shape)
Expand Down Expand Up @@ -168,13 +171,8 @@ def __init__(self, map_, permutation):
self.offset_quotient = map_.offset_quotient
self.permutation = NamedLiteral(permutation, parent=self.values, suffix=f"permutation{count}")

def indexed(self, multiindex, layer=None):
permute = lambda x: Indexed(self.permutation, (x,))
return super().indexed(multiindex, layer=layer, permute=permute)

def indexed_vector(self, n, shape, layer=None):
permute = lambda x: Indexed(self.permutation, (x,))
return super().indexed_vector(n, shape, layer=layer, permute=permute)
def _permute(self, x):
return Indexed(self.permutation, (x,))


class CMap(Map):
Expand Down

0 comments on commit 5f18075

Please sign in to comment.