diff --git a/test/test_simplify.py b/test/test_simplify.py index dcd3e06ba..d398681e9 100755 --- a/test/test_simplify.py +++ b/test/test_simplify.py @@ -32,7 +32,7 @@ triangle, ) from ufl.algorithms import compute_form_data -from ufl.core.multiindex import FixedIndex, MultiIndex +from ufl.core.multiindex import FixedIndex, MultiIndex, indices from ufl.finiteelement import FiniteElement from ufl.indexed import Indexed from ufl.pullback import identity_pullback @@ -193,3 +193,53 @@ def test_nested_indexed(self): multiindex = MultiIndex((FixedIndex(0),)) assert Indexed(expr, multiindex) is expr[0] assert Indexed(expr, multiindex) is comps[1] + + +def test_repeated_indexing(self): + # Test that an Indexed with repeated indices does not contract indices + shape = (2, 2) + element = FiniteElement("Lagrange", triangle, 1, shape, identity_pullback, H1) + domain = Mesh(FiniteElement("Lagrange", triangle, 1, (2,), identity_pullback, H1)) + space = FunctionSpace(domain, element) + x = Coefficient(space) + C = as_tensor([x, x]) + + fi = FixedIndex(0) + i, = indices(1) + ii = MultiIndex((fi, i, i)) + expr = Indexed(C, ii) + assert i.count() in expr.ufl_free_indices + assert isinstance(expr, Indexed) + B, jj = expr.ufl_operands + assert B is x + assert tuple(jj) == tuple(ii[1:]) + + +def test_untangle_indexed_component_tensor(self): + shape = (2, 2, 2, 2) + element = FiniteElement("Lagrange", triangle, 1, shape, identity_pullback, H1) + domain = Mesh(FiniteElement("Lagrange", triangle, 1, (2,), identity_pullback, H1)) + space = FunctionSpace(domain, element) + C = Coefficient(space) + + r = len(shape) + kk = indices(r) + + # Untangle as_tensor(C[kk], kk) -> C + B = as_tensor(Indexed(C, MultiIndex(kk)), kk) + assert B is C + + # Untangle as_tensor(C[kk], jj)[ii] -> C[kk] + jj = kk[2:] + A = as_tensor(Indexed(C, MultiIndex(kk)), jj) + assert A is not C + + ii = kk + expr = Indexed(A, MultiIndex(ii)) + assert isinstance(expr, Indexed) + B, ll = expr.ufl_operands + assert B is C + + rep = dict(zip(jj, ii)) + expected = tuple(rep.get(k, k) for k in kk) + assert tuple(ll) == expected