Skip to content

Commit

Permalink
move applyrestriction
Browse files Browse the repository at this point in the history
  • Loading branch information
ksagiyam committed Nov 13, 2024
1 parent f7212dd commit 58a53f1
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 9 deletions.
15 changes: 13 additions & 2 deletions ufl/algorithms/apply_restrictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@
#
# SPDX-License-Identifier: LGPL-3.0-or-later

from ufl.algorithms.formdata import FormData
from ufl.algorithms.map_integrands import map_integrand_dags
from ufl.classes import Restricted
from ufl.corealg.map_dag import map_expr_dag
from ufl.corealg.multifunction import MultiFunction
from ufl.domain import extract_unique_domain, MixedMesh
from ufl.form import Form
from ufl.measure import integral_type_to_measure_name
from ufl.sobolevspace import H1
from ufl.classes import ReferenceGrad, ReferenceValue
Expand Down Expand Up @@ -190,8 +192,17 @@ def apply_restrictions(expression, assume_single_integral_type=True):
# ``exterior_facet`` and the latter ``interior_facet``.
integral_types = None
rules = RestrictionPropagator(assume_single_integral_type=assume_single_integral_type)
return map_integrand_dags(rules, expression,
only_integral_type=integral_types)
if isinstance(expression, Form):
return map_integrand_dags(rules, expression, only_integral_type=integral_types)
elif isinstance(expression, FormData):
for integral_data in expression.integral_data:
integral_data.integrals = tuple(
map_integrand_dags(rules, integral, only_integral_type=integral_types)
for integral in integral_data.integrals
)
return expression
else:
raise NotImplementedError(f"Unable to handle {type(expression)}")


class DefaultRestrictionApplier(MultiFunction):
Expand Down
14 changes: 7 additions & 7 deletions ufl/algorithms/compute_form_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,13 +335,6 @@ def compute_form_data(
# Apply '?' restrictions in general multi-domain problems
form = apply_default_restrictions(form, assume_single_integral_type=have_single_domain)

# Propagate restrictions to terminals
if do_apply_restrictions:
if do_assume_single_integral_type:
form = apply_restrictions(form)
else:
form = apply_restrictions(form, assume_single_integral_type=have_single_domain)

# If in real mode, remove any complex nodes introduced during form processing.
if not complex_mode:
form = remove_complex_nodes(form)
Expand All @@ -350,6 +343,13 @@ def compute_form_data(
# Most of the heavy lifting is done above in group_form_integrals.
self.integral_data = build_integral_data(form.integrals())

# Propagate restrictions to terminals
if do_apply_restrictions:
if do_assume_single_integral_type:
apply_restrictions(self)
else:
apply_restrictions(self, assume_single_integral_type=have_single_domain)

# --- Create replacements for arguments and coefficients

# Figure out which form coefficients each integral should enable
Expand Down

0 comments on commit 58a53f1

Please sign in to comment.