Skip to content

Commit

Permalink
ruff check .
Browse files Browse the repository at this point in the history
  • Loading branch information
ksagiyam committed Jan 16, 2025
1 parent cce3154 commit 20495df
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 44 deletions.
40 changes: 20 additions & 20 deletions test/test_mixed_function_space_with_mixed_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,30 +152,30 @@ def test_mixed_function_space_with_mixed_mesh_raise():
# Make sure that all mixed functions are split when applying default restrictions.
form = div(g1('+')) * div(f1('-')) * dS1
with pytest.raises(RuntimeError) as e_info:
fd = compute_form_data(form,
do_apply_function_pullbacks=True,
do_apply_integral_scaling=True,
do_apply_geometry_lowering=True,
preserve_geometry_types=(CellVolume, FacetArea),
do_apply_restrictions=True,
do_estimate_degrees=True,
do_split_coefficients=(f,),
do_assume_single_integral_type=False,
complex_mode=False)
_ = compute_form_data(form,
do_apply_function_pullbacks=True,
do_apply_integral_scaling=True,
do_apply_geometry_lowering=True,
preserve_geometry_types=(CellVolume, FacetArea),
do_apply_restrictions=True,
do_estimate_degrees=True,
do_split_coefficients=(f,),
do_assume_single_integral_type=False,
complex_mode=False)
assert e_info.match("Not expecting a terminal object on a mixed mesh at this stage")
# Make sure that g1 is restricted as f1.
form = div(g1) * div(f1('-')) * dS1
with pytest.raises(ValueError) as e_info:
fd = compute_form_data(form,
do_apply_function_pullbacks=True,
do_apply_integral_scaling=True,
do_apply_geometry_lowering=True,
preserve_geometry_types=(CellVolume, FacetArea),
do_apply_restrictions=True,
do_estimate_degrees=True,
do_split_coefficients=(f, g),
do_assume_single_integral_type=False,
complex_mode=False)
_ = compute_form_data(form,
do_apply_function_pullbacks=True,
do_apply_integral_scaling=True,
do_apply_geometry_lowering=True,
preserve_geometry_types=(CellVolume, FacetArea),
do_apply_restrictions=True,
do_estimate_degrees=True,
do_split_coefficients=(f, g),
do_assume_single_integral_type=False,
complex_mode=False)
assert e_info.match("Discontinuous type Coefficient must be restricted.")


Expand Down
51 changes: 42 additions & 9 deletions ufl/algorithms/apply_coefficient_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
This module contains classes and functions to split coefficients defined on mixed function spaces.
"""

import numpy
import numpy as np

from ufl import indices
from ufl.checks import is_cellwise_constant
Expand All @@ -30,14 +30,22 @@


class CoefficientSplitter(MultiFunction):
"""Split mixed coefficients into the components."""

def __init__(self, coefficient_split):
def __init__(self, coefficient_split: dict):
"""Initialise.
Args:
coefficient_split: map from coefficients to the components.
"""
MultiFunction.__init__(self)
self._coefficient_split = coefficient_split

expr = MultiFunction.reuse_if_untouched

def modified_terminal(self, o):
"""Handle modified terminals."""
restriction = None
local_derivatives = 0
reference_value = False
Expand All @@ -56,7 +64,9 @@ def modified_terminal(self, o):
restriction = t._side
t, = t.ufl_operands
elif t._ufl_terminal_modifiers_:
raise ValueError(f"Missing handler for terminal modifier type {type(t)}, object is {t!r}.")
raise ValueError(
f"Missing handler for terminal modifier type {type(t)}, object is {t!r}."
)
else:
raise ValueError(f"Unexpected type {type(t)} object {t!r}.")
if not isinstance(t, Coefficient):
Expand Down Expand Up @@ -92,7 +102,7 @@ def modified_terminal(self, o):
elif restriction is not None:
raise RuntimeError(f"Got unknown restriction: {restriction}")
# Collect components of the subcoefficient
for alpha in numpy.ndindex(subcoeff.ufl_element().reference_value_shape):
for alpha in np.ndindex(subcoeff.ufl_element().reference_value_shape):
# New modified terminal: component[alpha + beta]
components.append(c[alpha + beta])
# Repack derivative indices to shape
Expand All @@ -107,12 +117,12 @@ def modified_terminal(self, o):


def apply_coefficient_split(expr, coefficient_split):
"""Split mixed coefficients, so mixed elements need not be
implemented.
"""Split mixed coefficients, so mixed elements need not be implemented.
:arg split: A :py:class:`dict` mapping each mixed coefficient to a
sequence of subcoefficients. If None, calling this
function is a no-op.
"""
if coefficient_split is None:
return expr
Expand All @@ -121,8 +131,15 @@ def apply_coefficient_split(expr, coefficient_split):


class FixedIndexRemover(MultiFunction):
"""Handle FixedIndex."""

def __init__(self, fimap: dict):
"""Initialise.
Args:
fimap: map for index replacements.
def __init__(self, fimap):
"""
MultiFunction.__init__(self)
self.fimap = fimap
self._object_cache = {}
Expand All @@ -131,6 +148,7 @@ def __init__(self, fimap):

@memoized_handler
def zero(self, o):
"""Handle Zero."""
free_indices = []
index_dimensions = []
for i, d in zip(o.ufl_free_indices, o.ufl_index_dimensions):
Expand All @@ -142,10 +160,15 @@ def zero(self, o):
else:
free_indices.append(i)
index_dimensions.append(d)
return Zero(shape=o.ufl_shape, free_indices=tuple(free_indices), index_dimensions=tuple(index_dimensions))
return Zero(
shape=o.ufl_shape,
free_indices=tuple(free_indices),
index_dimensions=tuple(index_dimensions)
)

@memoized_handler
def list_tensor(self, o):
"""Handle ListTensor."""
cc = []
for o1 in o.ufl_operands:
comp = map_expr_dag(self, o1)
Expand All @@ -154,28 +177,37 @@ def list_tensor(self, o):

@memoized_handler
def multi_index(self, o):
"""Handle MultiIndex."""
return MultiIndex(tuple(self.fimap.get(i, i) for i in o.indices()))


class IndexRemover(MultiFunction):
"""Remove Indexed."""

def __init__(self):
"""Initialise."""
MultiFunction.__init__(self)
self._object_cache = {}

expr = MultiFunction.reuse_if_untouched

@memoized_handler
def _zero_simplify(self, o):
"""Apply simplification for Zero()."""
operand, = o.ufl_operands
operand = map_expr_dag(self, operand)
if isinstance(operand, Zero):
return Zero(shape=o.ufl_shape, free_indices=o.ufl_free_indices, index_dimensions=o.ufl_index_dimensions)
return Zero(
shape=o.ufl_shape,
free_indices=o.ufl_free_indices,
index_dimensions=o.ufl_index_dimensions
)
else:
return o._ufl_expr_reconstruct_(operand)

@memoized_handler
def indexed(self, o):
"""Simplify indexed ComponentTensor and ListTensor."""
o1, i1 = o.ufl_operands
if isinstance(o1, ComponentTensor):
o2, i2 = o1.ufl_operands
Expand Down Expand Up @@ -203,6 +235,7 @@ def indexed(self, o):


def remove_component_and_list_tensors(o):
"""Remove component and list tensors."""
if isinstance(o, Form):
integrals = []
for integral in o.integrals():
Expand Down
58 changes: 45 additions & 13 deletions ufl/algorithms/apply_restrictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,9 @@ def _require_restriction(self, o):
if self.default_restriction is not None:
domain = extract_unique_domain(o, expand_mixed_mesh=False)
if isinstance(domain, MeshSequence):
raise RuntimeError(f"Not expecting a terminal object on a mixed mesh at this stage: found {o!r}")
raise RuntimeError(
f"Not expecting a terminal object on a mixed mesh at this stage: found {o!r}"
)
if isinstance(self.default_restriction, dict):
r = self.default_restriction[domain]
else:
Expand All @@ -90,7 +92,9 @@ def _default_restricted(self, o):
if self.default_restriction is not None:
domain = extract_unique_domain(o, expand_mixed_mesh=False)
if isinstance(domain, MeshSequence):
raise RuntimeError(f"Not expecting a terminal object on a mixed mesh at this stage: found {o!r}")
raise RuntimeError(
f"Not expecting a terminal object on a mixed mesh at this stage: found {o!r}"
)
if isinstance(self.default_restriction, dict):
if domain not in self.default_restriction:
raise RuntimeError(f"Integral type on {domain} not known")
Expand All @@ -115,15 +119,19 @@ def _opposite(self, o):
if isinstance(self.default_restriction, dict):
domain = extract_unique_domain(o, expand_mixed_mesh=False)
if isinstance(domain, MeshSequence):
raise RuntimeError(f"Not expecting a terminal object on a mixed mesh at this stage: found {o!r}")
raise RuntimeError(
f"Not expecting a terminal object on a mixed mesh at this stage: found {o!r}"
)
if domain not in self.default_restriction:
raise RuntimeError(f"Integral type on {domain} not known")
r = self.default_restriction[domain]
else:
r = self.default_restriction
if r is None:
if self.current_restriction is not None:
raise ValueError(f"Expecting current_restriction None: got {self.current_restriction}")
raise ValueError(
f"Expecting current_restriction None: got {self.current_restriction}"
)
return o
else:
if self.current_restriction is None:
Expand Down Expand Up @@ -239,7 +247,11 @@ def facet_normal(self, o):
return self._require_restriction(o)


def apply_restrictions(expression, assume_single_integral_type=True, domain_integral_type_map=None):
def apply_restrictions(
expression,
assume_single_integral_type=True,
domain_integral_type_map=None
):
"""Propagate restriction nodes to wrap differential terminals directly."""
if assume_single_integral_type:
# Hnadle the conventional single-domain case.
Expand Down Expand Up @@ -271,7 +283,10 @@ def apply_restrictions(expression, assume_single_integral_type=True, domain_inte
# the integral type of a given function; e.g., the former can be
# ``exterior_facet`` and the latter ``interior_facet``.
integral_types = None
rules = RestrictionPropagator(assume_single_integral_type=assume_single_integral_type, default_restriction=default_restriction)
rules = RestrictionPropagator(
assume_single_integral_type=assume_single_integral_type,
default_restriction=default_restriction
)
if isinstance(expression, FormData):
for integral_data in expression.integral_data:
integral_data.integrals = tuple(
Expand All @@ -288,9 +303,15 @@ class DomainRestrictionMapMaker(MultiFunction):
Inspect the DAG and collect domain-restrictions map.
This must be done per integral_data.
"""
def __init__(self, domain_restriction_map: dict):
"""Initialise.
Args:
domain_restriction_map: map from domains to the restrictions.
def __init__(self, domain_restriction_map):
"""
MultiFunction.__init__(self)
self._domain_restriction_map = domain_restriction_map

Expand All @@ -315,12 +336,16 @@ def _modifier(self, o):
restriction = t._side
t, = t.ufl_operands
elif t._ufl_terminal_modifiers_:
raise ValueError(f"Missing handler for terminal modifier type {type(t)}, object is {t!r}.")
raise ValueError(
f"Missing handler for terminal modifier type {type(t)}, object is {t!r}."
)
else:
raise ValueError(f"Unexpected type {type(t)} object {t!r}.")
domain = extract_unique_domain(t, expand_mixed_mesh=False)
if isinstance(domain, MeshSequence):
raise RuntimeError(f"Not expecting a terminal object on a mixed mesh at this stage: found {t!r}")
raise RuntimeError(
f"Not expecting a terminal object on a mixed mesh at this stage: found {t!r}"
)
if domain is not None:
if domain not in self._domain_restriction_map:
self._domain_restriction_map[domain] = set()
Expand All @@ -347,6 +372,7 @@ def make_domain_restriction_map(integral_data):


def make_domain_integral_type_map(integral_data):
"""Make a map from domains to the integral types."""
domain_restriction_map = make_domain_restriction_map(integral_data)
integration_domain = integral_data.domain
integration_type = integral_data.integral_type
Expand All @@ -361,15 +387,21 @@ def make_domain_integral_type_map(integral_data):
elif integration_type in ["exterior_facet", "interior_facet"]:
domain_integral_type_dict[d] = "exterior_facet"
else:
raise NotImplementedError(f"Not implemented for integration type {integration_type}")
raise NotImplementedError(
f"Not implemented for integration type {integration_type}"
)
else:
raise NotImplementedError("Not implemented for meshes of multiple topological dimensions")
raise NotImplementedError(
"Not implemented for meshes of multiple topological dimensions"
)
else:
raise RuntimeError(f"Found inconsistent restrictions {rs} for domain {d}")
if integration_domain in domain_integral_type_dict:
if domain_integral_type_dict[integration_domain] != integration_type:
raise RuntimeError(f"""Found inconsistent integral types for the integration domain ({integration_domain}) :
{domain_integral_type_dict[integration_domain]} != {integration_type}""")
raise RuntimeError(f"""
Found inconsistent integral types for the integration domain ({integration_domain}):
{domain_integral_type_dict[integration_domain]} != {integration_type}
""")
else:
domain_integral_type_dict[integration_domain] = integration_type
return domain_integral_type_dict
Expand Down
6 changes: 4 additions & 2 deletions ufl/algorithms/compute_form_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,8 +493,10 @@ def compute_form_data(
else:
# Inspect the form and apply default restrictions.
if do_split_coefficients is None:
raise ValueError("""Need to pass 'do_split_coefficients=tuple_of_coefficients_to_splilt'
for general multi-domain problems""")
raise ValueError("""
Need to pass 'do_split_coefficients=tuple_of_coefficients_to_splilt'
for general multi-domain problems
""")
for itg_data in self.integral_data:
# Must have split coefficients and removed component/list tensors.
itg_data.domain_integral_type_map = make_domain_integral_type_map(itg_data)
Expand Down

0 comments on commit 20495df

Please sign in to comment.