Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

generalise VariableIndex and FlexiblyIndexed #317

Merged
merged 1 commit into from
Nov 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
325 changes: 268 additions & 57 deletions gem/gem.py

Large diffs are not rendered by default.

47 changes: 40 additions & 7 deletions gem/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
expression DAG languages."""

import collections
import gem


class Node(object):
Expand Down Expand Up @@ -99,8 +100,23 @@ def get_hash(self):
return hash((type(self),) + self._cons_args(self.children))


def _make_traversal_children(node):
if isinstance(node, (gem.Indexed, gem.FlexiblyIndexed)):
# Include child nodes hidden in index expressions.
return node.children + node.indirect_children
else:
return node.children


def pre_traversal(expression_dags):
"""Pre-order traversal of the nodes of expression DAGs."""
"""Pre-order traversal of the nodes of expression DAGs.

Notes
-----
This function also walks through nodes in index expressions
(e.g., `VariableIndex`s); see ``_make_traversal_children()``.

"""
seen = set()
lifo = []
# Some roots might be same, but they must be visited only once.
Expand All @@ -114,14 +130,23 @@ def pre_traversal(expression_dags):
while lifo:
node = lifo.pop()
yield node
for child in reversed(node.children):
children = _make_traversal_children(node)
for child in reversed(children):
if child not in seen:
seen.add(child)
lifo.append(child)


def post_traversal(expression_dags):
"""Post-order traversal of the nodes of expression DAGs."""
"""Post-order traversal of the nodes of expression DAGs.

Notes
-----
This function also walks through nodes in index expressions
(e.g., `VariableIndex`s); see ``_make_traversal_children()``.


"""
seen = set()
lifo = []
# Some roots might be same, but they must be visited only once.
Expand All @@ -130,13 +155,13 @@ def post_traversal(expression_dags):
for root in expression_dags:
if root not in seen:
seen.add(root)
lifo.append((root, list(root.children)))
lifo.append((root, list(_make_traversal_children(root))))

while lifo:
node, deps = lifo[-1]
for i, dep in enumerate(deps):
if dep is not None and dep not in seen:
lifo.append((dep, list(dep.children)))
lifo.append((dep, list(_make_traversal_children(dep))))
deps[i] = None
break
else:
Expand All @@ -150,10 +175,18 @@ def post_traversal(expression_dags):


def collect_refcount(expression_dags):
"""Collects reference counts for a multi-root expression DAG."""
"""Collects reference counts for a multi-root expression DAG.

Notes
-----
This function also collects reference counts of nodes
in index expressions (e.g., `VariableIndex`s); see
``_make_traversal_children()``.

"""
result = collections.Counter(expression_dags)
for node in traversal(expression_dags):
result.update(node.children)
result.update(_make_traversal_children(node))
return result


Expand Down
27 changes: 21 additions & 6 deletions gem/optimise.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from collections import OrderedDict, defaultdict
from functools import singledispatch, partial, reduce
from itertools import combinations, permutations, zip_longest
from numbers import Integral

import numpy

Expand Down Expand Up @@ -95,11 +96,19 @@ def replace_indices(node, self, subst):
replace_indices.register(Node)(reuse_if_untouched_arg)


def _replace_indices_atomic(i, self, subst):
if isinstance(i, VariableIndex):
new_expr = self(i.expression, subst)
return i if new_expr == i.expression else VariableIndex(new_expr)
else:
substitute = dict(subst)
return substitute.get(i, i)


@replace_indices.register(Delta)
def replace_indices_delta(node, self, subst):
substitute = dict(subst)
i = substitute.get(node.i, node.i)
j = substitute.get(node.j, node.j)
i = _replace_indices_atomic(node.i, self, subst)
j = _replace_indices_atomic(node.j, self, subst)
if i == node.i and j == node.j:
return node
else:
Expand All @@ -110,7 +119,9 @@ def replace_indices_delta(node, self, subst):
def replace_indices_indexed(node, self, subst):
child, = node.children
substitute = dict(subst)
multiindex = tuple(substitute.get(i, i) for i in node.multiindex)
multiindex = []
for i in node.multiindex:
multiindex.append(_replace_indices_atomic(i, self, subst))
if isinstance(child, ComponentTensor):
# Indexing into ComponentTensor
# Inline ComponentTensor and augment the substitution rules
Expand All @@ -130,9 +141,11 @@ def replace_indices_flexiblyindexed(node, self, subst):
child, = node.children
assert not child.free_indices

substitute = dict(subst)
dim2idxs = tuple(
(offset, tuple((substitute.get(i, i), s) for i, s in idxs))
(
offset if isinstance(offset, Integral) else _replace_indices_atomic(offset, self, subst),
tuple((_replace_indices_atomic(i, self, subst), s if isinstance(s, Integral) else self(s, subst)) for i, s in idxs)
)
for offset, idxs in node.dim2idxs
)

Expand All @@ -145,6 +158,8 @@ def replace_indices_flexiblyindexed(node, self, subst):
def filtered_replace_indices(node, self, subst):
"""Wrapper for :func:`replace_indices`. At each call removes
substitution rules that do not apply."""
if any(isinstance(k, VariableIndex) for k, _ in subst):
raise NotImplementedError("Can not replace VariableIndex (will need inverse)")
filtered_subst = tuple((k, v) for k, v in subst if k in node.free_indices)
return replace_indices(node, self, filtered_subst)

Expand Down
9 changes: 7 additions & 2 deletions gem/scheduling.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import collections
import functools
import itertools

from gem import gem, impero
from gem.node import collect_refcount
Expand Down Expand Up @@ -116,8 +117,12 @@ def handle(ops, push, decref, node):
elif isinstance(node, gem.Zero): # should rarely happen
assert not node.shape
elif isinstance(node, (gem.Indexed, gem.FlexiblyIndexed)):
# Indexing always inlined
decref(node.children[0])
if node.indirect_children:
# Do not inline;
# Index expression can be involved if it contains VariableIndex.
ops.append(impero.Evaluate(node))
for child in itertools.chain(node.children, node.indirect_children):
decref(child)
elif isinstance(node, gem.IndexSum):
ops.append(impero.Noop(node))
push(impero.Accumulate(node))
Expand Down
3 changes: 0 additions & 3 deletions tsfc/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,6 @@ def compile_integral(integral_data, form_data, prefix, parameters, interface, *,
:arg log: bool if the Kernel should be profiled with Log events
:returns: a kernel constructed by the kernel interface
"""
if integral_data.domain.ufl_cell().cellname() == "hexahedron" and \
integral_data.integral_type == "interior_facet":
raise NotImplementedError("interior facet integration in hex meshes not currently supported")
parameters = preprocess_parameters(parameters)
if interface is None:
interface = firedrake_interface_loopy.KernelBuilder
Expand Down
54 changes: 51 additions & 3 deletions tsfc/fem.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
import gem
import numpy
import ufl
from FIAT.reference_element import UFCSimplex, make_affine_mapping
from FIAT.orientation_utils import Orientation as FIATOrientation
from FIAT.reference_element import UFCHexahedron, UFCSimplex, make_affine_mapping
from FIAT.reference_element import TensorProductCell
from finat.physically_mapped import (NeedsCoordinateMappingElement,
PhysicalGeometry)
from finat.point_set import PointSet, PointSingleton
Expand Down Expand Up @@ -108,6 +110,10 @@ def translator(self):
# NOTE: reference cycle!
return Translator(self)

@cached_property
def use_canonical_quadrature_point_ordering(self):
return isinstance(self.fiat_cell, UFCHexahedron) and self.integral_type in ['exterior_facet', 'interior_facet']


class CoordinateMapping(PhysicalGeometry):
"""Callback class that provides physical geometry to FInAT elements.
Expand Down Expand Up @@ -266,10 +272,13 @@ class PointSetContext(ContextBase):
'weight_expr',
)

@cached_property
def integration_cell(self):
return self.fiat_cell.construct_subelement(self.integration_dim)

@cached_property
def quadrature_rule(self):
integration_cell = self.fiat_cell.construct_subelement(self.integration_dim)
return make_quadrature(integration_cell, self.quadrature_degree)
return make_quadrature(self.integration_cell, self.quadrature_degree)

@cached_property
def point_set(self):
Expand Down Expand Up @@ -629,6 +638,11 @@ def callback(entity_id):
# lives on after ditching FFC and switching to FInAT.
return ffc_rounding(square, ctx.epsilon)
table = ctx.entity_selector(callback, mt.restriction)
if ctx.use_canonical_quadrature_point_ordering:
quad_multiindex = ctx.quadrature_rule.point_set.indices
quad_multiindex_permuted = _make_quad_multiindex_permuted(mt, ctx)
mapper = gem.node.MemoizerArg(gem.optimise.filtered_replace_indices)
table = mapper(table, tuple(zip(quad_multiindex, quad_multiindex_permuted)))
return gem.ComponentTensor(gem.Indexed(table, argument_multiindex + sigma), sigma)


Expand Down Expand Up @@ -698,9 +712,43 @@ def take_singleton(xs):
for node in traversal((result,))
if isinstance(node, gem.Literal)):
result = gem.optimise.aggressive_unroll(result)

if ctx.use_canonical_quadrature_point_ordering:
quad_multiindex = ctx.quadrature_rule.point_set.indices
quad_multiindex_permuted = _make_quad_multiindex_permuted(mt, ctx)
mapper = gem.node.MemoizerArg(gem.optimise.filtered_replace_indices)
result = mapper(result, tuple(zip(quad_multiindex, quad_multiindex_permuted)))
return result


def _make_quad_multiindex_permuted(mt, ctx):
quad_rule = ctx.quadrature_rule
# Note that each quad index here represents quad points on a physical
# cell axis, but the table is indexed by indices representing the points
# on each reference cell axis, so we need to apply permutation based on the orientation.
cell = quad_rule.ref_el
quad_multiindex = quad_rule.point_set.indices
if isinstance(cell, TensorProductCell):
for comp in set(cell.cells):
extents = set(q.extent for c, q in zip(cell.cells, quad_multiindex) if c == comp)
if len(extents) != 1:
raise ValueError("Must have the same number of quadrature points in each symmetric axis")
quad_multiindex_permuted = []
o = ctx.entity_orientation(mt.restriction)
if not isinstance(o, FIATOrientation):
raise ValueError(f"Expecting an instance of FIATOrientation : got {o}")
eo = cell.extract_extrinsic_orientation(o)
eo_perm_map = gem.Literal(quad_rule.extrinsic_orientation_permutation_map, dtype=gem.uint_type)
for ref_axis in range(len(quad_multiindex)):
io = cell.extract_intrinsic_orientation(o, ref_axis)
io_perm_map = gem.Literal(quad_rule.intrinsic_orientation_permutation_map_tuple[ref_axis], dtype=gem.uint_type)
# Effectively swap axes if needed.
ref_index = tuple((phys_index, gem.Indexed(eo_perm_map, (eo, ref_axis, phys_axis))) for phys_axis, phys_index in enumerate(quad_multiindex))
quad_index_permuted = gem.VariableIndex(gem.FlexiblyIndexed(io_perm_map, ((0, ((io, 1), )), (0, ref_index))))
quad_multiindex_permuted.append(quad_index_permuted)
return tuple(quad_multiindex_permuted)


def compile_ufl(expression, context, interior_facet=False, point_sum=False):
"""Translate a UFL expression to GEM.

Expand Down
8 changes: 8 additions & 0 deletions tsfc/kernel_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,11 @@ class ExteriorFacetKernelArg(KernelArg):

class InteriorFacetKernelArg(KernelArg):
...


class ExteriorFacetOrientationKernelArg(KernelArg):
...


class InteriorFacetOrientationKernelArg(KernelArg):
...
4 changes: 4 additions & 0 deletions tsfc/kernel_interface/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ def cell_size(self, restriction):
def entity_number(self, restriction):
"""Facet or vertex number as a GEM index."""

@abstractmethod
def entity_orientation(self, restriction):
"""Entity orientation as a GEM index."""

@abstractmethod
def create_element(self, element, **kwargs):
"""Create a FInAT element (suitable for tabulating with) given
Expand Down
5 changes: 5 additions & 0 deletions tsfc/kernel_interface/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,11 @@ def entity_number(self, restriction):
# Assume self._entity_number dict is set up at this point.
return self._entity_number[restriction]

def entity_orientation(self, restriction):
"""Facet orientation as a GEM index."""
# Assume self._entity_orientation dict is set up at this point.
return self._entity_orientation[restriction]

def apply_glue(self, prepare=None, finalise=None):
"""Append glue code for operations that are not handled in the
GEM abstraction.
Expand Down
22 changes: 21 additions & 1 deletion tsfc/kernel_interface/firedrake_loopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import loopy as lp

from tsfc import kernel_args
from tsfc import kernel_args, fem
from tsfc.finatinterface import create_element
from tsfc.kernel_interface.common import KernelBuilderBase as _KernelBuilderBase, KernelBuilderMixin, get_index_names, check_requirements, prepare_coefficient, prepare_arguments, prepare_constant
from tsfc.loopy import generate as generate_loopy
Expand Down Expand Up @@ -259,14 +259,26 @@ def __init__(self, integral_data_info, scalar_type,
if integral_type in ['exterior_facet', 'exterior_facet_vert']:
facet = gem.Variable('facet', (1,))
self._entity_number = {None: gem.VariableIndex(gem.Indexed(facet, (0,)))}
facet_orientation = gem.Variable('facet_orientation', (1,), dtype=gem.uint_type)
self._entity_orientation = {None: gem.OrientationVariableIndex(gem.Indexed(facet_orientation, (0,)))}
elif integral_type in ['interior_facet', 'interior_facet_vert']:
facet = gem.Variable('facet', (2,))
self._entity_number = {
'+': gem.VariableIndex(gem.Indexed(facet, (0,))),
'-': gem.VariableIndex(gem.Indexed(facet, (1,)))
}
facet_orientation = gem.Variable('facet_orientation', (2,), dtype=gem.uint_type)
self._entity_orientation = {
'+': gem.OrientationVariableIndex(gem.Indexed(facet_orientation, (0,))),
'-': gem.OrientationVariableIndex(gem.Indexed(facet_orientation, (1,)))
}
elif integral_type == 'interior_facet_horiz':
self._entity_number = {'+': 1, '-': 0}
facet_orientation = gem.Variable('facet_orientation', (1,), dtype=gem.uint_type) # base mesh entity orientation
self._entity_orientation = {
'+': gem.OrientationVariableIndex(gem.Indexed(facet_orientation, (0,))),
'-': gem.OrientationVariableIndex(gem.Indexed(facet_orientation, (0,)))
}

self.set_arguments(integral_data_info.arguments)
self.integral_data_info = integral_data_info
Expand Down Expand Up @@ -406,6 +418,14 @@ def construct_kernel(self, name, ctx, log=False):
elif info.integral_type in ["interior_facet", "interior_facet_vert"]:
int_loopy_arg = lp.GlobalArg("facet", numpy.uint32, shape=(2,))
args.append(kernel_args.InteriorFacetKernelArg(int_loopy_arg))
# Will generalise this in the submesh PR.
if fem.PointSetContext(**self.fem_config()).use_canonical_quadrature_point_ordering:
if info.integral_type == "exterior_facet":
ext_ornt_loopy_arg = lp.GlobalArg("facet_orientation", gem.uint_type, shape=(1,))
args.append(kernel_args.ExteriorFacetOrientationKernelArg(ext_ornt_loopy_arg))
elif info.integral_type == "interior_facet":
int_ornt_loopy_arg = lp.GlobalArg("facet_orientation", gem.uint_type, shape=(2,))
args.append(kernel_args.InteriorFacetOrientationKernelArg(int_ornt_loopy_arg))
for name_, shape in tabulations:
tab_loopy_arg = lp.GlobalArg(name_, dtype=self.scalar_type, shape=shape)
args.append(kernel_args.TabulationKernelArg(tab_loopy_arg))
Expand Down
Loading
Loading