Skip to content

Commit

Permalink
Indexed: only initialise new instances
Browse files Browse the repository at this point in the history
  • Loading branch information
pbrubeck committed Jan 15, 2025
1 parent cd1d21b commit df1ddf3
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 4 deletions.
19 changes: 19 additions & 0 deletions test/test_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@
triangle,
)
from ufl.algorithms import compute_form_data
from ufl.core.multiindex import FixedIndex, MultiIndex
from ufl.finiteelement import FiniteElement
from ufl.indexed import Indexed
from ufl.pullback import identity_pullback
from ufl.sobolevspace import H1

Expand Down Expand Up @@ -174,3 +176,20 @@ def test_tensor_from_indexed(self, shape):
space = FunctionSpace(domain, element)
f = Coefficient(space)
assert as_tensor(reshape([f[i] for i in ndindex(f.ufl_shape)], f.ufl_shape).tolist()) is f


def test_nested_indexed(self):
# Test that a nested Indexed expression simplifies to the existing Indexed object
shape = (2,)
element = FiniteElement("Lagrange", triangle, 1, shape, identity_pullback, H1)
domain = Mesh(FiniteElement("Lagrange", triangle, 1, (2,), identity_pullback, H1))
space = FunctionSpace(domain, element)
f = Coefficient(space)

comps = tuple(f[i] for i in range(2))
assert all(isinstance(c, Indexed) for c in comps)
expr = as_tensor(list(reversed(comps)))

multiindex = MultiIndex((FixedIndex(0),))
assert Indexed(expr, multiindex) is expr[0]
assert Indexed(expr, multiindex) is comps[1]
13 changes: 9 additions & 4 deletions ufl/indexed.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class Indexed(Operator):
"""Indexed expression."""

__slots__ = (
"_initialised",
"ufl_free_indices",
"ufl_index_dimensions",
)
Expand All @@ -46,13 +47,16 @@ def __new__(cls, expression, multiindex):

try:
# Simplify indexed ListTensor
c = expression[multiindex]
return Indexed(*c.ufl_operands) if isinstance(c, Indexed) else c
return expression[multiindex]
except ValueError:
return Operator.__new__(cls)
self = Operator.__new__(cls)
self._initialised = False
return self

def __init__(self, expression, multiindex):
"""Initialise."""
if self._initialised:
return
# Store operands
Operator.__init__(self, (expression, multiindex))

Expand Down Expand Up @@ -81,7 +85,7 @@ def __init__(self, expression, multiindex):
efi = expression.ufl_free_indices
efid = expression.ufl_index_dimensions
fi = list(zip(efi, efid))
for pos, ind in enumerate(multiindex._indices):
for pos, ind in enumerate(multiindex):
if isinstance(ind, Index):
fi.append((ind.count(), shape[pos]))
fi = unique_sorted_indices(sorted(fi))
Expand All @@ -93,6 +97,7 @@ def __init__(self, expression, multiindex):
# Cache free index and dimensions
self.ufl_free_indices = fi
self.ufl_index_dimensions = fid
self._initialised = True

ufl_shape = ()

Expand Down

0 comments on commit df1ddf3

Please sign in to comment.