Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
pbrubeck committed Jan 17, 2025
1 parent cc239f8 commit 9e802c2
Showing 1 changed file with 51 additions and 1 deletion.
52 changes: 51 additions & 1 deletion test/test_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
triangle,
)
from ufl.algorithms import compute_form_data
from ufl.core.multiindex import FixedIndex, MultiIndex
from ufl.core.multiindex import FixedIndex, MultiIndex, indices
from ufl.finiteelement import FiniteElement
from ufl.indexed import Indexed
from ufl.pullback import identity_pullback
Expand Down Expand Up @@ -193,3 +193,53 @@ def test_nested_indexed(self):
multiindex = MultiIndex((FixedIndex(0),))
assert Indexed(expr, multiindex) is expr[0]
assert Indexed(expr, multiindex) is comps[1]


def test_repeated_indexing(self):
# Test that an Indexed with repeated indices does not contract indices
shape = (2, 2)
element = FiniteElement("Lagrange", triangle, 1, shape, identity_pullback, H1)
domain = Mesh(FiniteElement("Lagrange", triangle, 1, (2,), identity_pullback, H1))
space = FunctionSpace(domain, element)
x = Coefficient(space)
C = as_tensor([x, x])

fi = FixedIndex(0)
i, = indices(1)
ii = MultiIndex((fi, i, i))
expr = Indexed(C, ii)
assert i.count() in expr.ufl_free_indices
assert isinstance(expr, Indexed)
B, jj = expr.ufl_operands
assert B is x
assert tuple(jj) == tuple(ii[1:])


def test_untangle_indexed_component_tensor(self):
shape = (2, 2, 2, 2)
element = FiniteElement("Lagrange", triangle, 1, shape, identity_pullback, H1)
domain = Mesh(FiniteElement("Lagrange", triangle, 1, (2,), identity_pullback, H1))
space = FunctionSpace(domain, element)
C = Coefficient(space)

r = len(shape)
kk = indices(r)

# Untangle as_tensor(C[kk], kk) -> C
B = as_tensor(Indexed(C, MultiIndex(kk)), kk)
assert B is C

# Untangle as_tensor(C[kk], jj)[ii] -> C[kk]
jj = kk[2:]
A = as_tensor(Indexed(C, MultiIndex(kk)), jj)
assert A is not C

ii = kk
expr = Indexed(A, MultiIndex(ii))
assert isinstance(expr, Indexed)
B, ll = expr.ufl_operands
assert B is C

rep = dict(zip(jj, ii))
expected = tuple(rep.get(k, k) for k in kk)
assert tuple(ll) == expected

0 comments on commit 9e802c2

Please sign in to comment.