From df1ddf3b30f2a5c80c118058027115dd00b847e0 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Wed, 15 Jan 2025 12:32:37 +0000 Subject: [PATCH] Indexed: only initialise new instances --- test/test_simplify.py | 19 +++++++++++++++++++ ufl/indexed.py | 13 +++++++++---- 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/test/test_simplify.py b/test/test_simplify.py index 9eff3843b..dcd3e06ba 100755 --- a/test/test_simplify.py +++ b/test/test_simplify.py @@ -32,7 +32,9 @@ triangle, ) from ufl.algorithms import compute_form_data +from ufl.core.multiindex import FixedIndex, MultiIndex from ufl.finiteelement import FiniteElement +from ufl.indexed import Indexed from ufl.pullback import identity_pullback from ufl.sobolevspace import H1 @@ -174,3 +176,20 @@ def test_tensor_from_indexed(self, shape): space = FunctionSpace(domain, element) f = Coefficient(space) assert as_tensor(reshape([f[i] for i in ndindex(f.ufl_shape)], f.ufl_shape).tolist()) is f + + +def test_nested_indexed(self): + # Test that a nested Indexed expression simplifies to the existing Indexed object + shape = (2,) + element = FiniteElement("Lagrange", triangle, 1, shape, identity_pullback, H1) + domain = Mesh(FiniteElement("Lagrange", triangle, 1, (2,), identity_pullback, H1)) + space = FunctionSpace(domain, element) + f = Coefficient(space) + + comps = tuple(f[i] for i in range(2)) + assert all(isinstance(c, Indexed) for c in comps) + expr = as_tensor(list(reversed(comps))) + + multiindex = MultiIndex((FixedIndex(0),)) + assert Indexed(expr, multiindex) is expr[0] + assert Indexed(expr, multiindex) is comps[1] diff --git a/ufl/indexed.py b/ufl/indexed.py index f4433f223..338033413 100644 --- a/ufl/indexed.py +++ b/ufl/indexed.py @@ -20,6 +20,7 @@ class Indexed(Operator): """Indexed expression.""" __slots__ = ( + "_initialised", "ufl_free_indices", "ufl_index_dimensions", ) @@ -46,13 +47,16 @@ def __new__(cls, expression, multiindex): try: # Simplify indexed ListTensor - c = expression[multiindex] - return Indexed(*c.ufl_operands) if isinstance(c, Indexed) else c + return expression[multiindex] except ValueError: - return Operator.__new__(cls) + self = Operator.__new__(cls) + self._initialised = False + return self def __init__(self, expression, multiindex): """Initialise.""" + if self._initialised: + return # Store operands Operator.__init__(self, (expression, multiindex)) @@ -81,7 +85,7 @@ def __init__(self, expression, multiindex): efi = expression.ufl_free_indices efid = expression.ufl_index_dimensions fi = list(zip(efi, efid)) - for pos, ind in enumerate(multiindex._indices): + for pos, ind in enumerate(multiindex): if isinstance(ind, Index): fi.append((ind.count(), shape[pos])) fi = unique_sorted_indices(sorted(fi)) @@ -93,6 +97,7 @@ def __init__(self, expression, multiindex): # Cache free index and dimensions self.ufl_free_indices = fi self.ufl_index_dimensions = fid + self._initialised = True ufl_shape = ()