Skip to content

Commit

Permalink
Merge pull request #307 from firedrakeproject/mscroggs/gdim
Browse files Browse the repository at this point in the history
Update UFL element interface
  • Loading branch information
pbrubeck authored Nov 15, 2024
2 parents 8c1c4c0 + 947d74f commit e06308d
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 8 deletions.
2 changes: 1 addition & 1 deletion tests/test_interpolation_factorisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def test_sum_factorisation_scalar_tensor(mesh, element):
source = element(degree - 1)
target = element(degree)
tensor_flops = flop_count(mesh, source, target)
expect = numpy.prod(target.value_shape)
expect = FunctionSpace(mesh, target).value_size
if isinstance(target, FiniteElement):
scalar_flops = tensor_flops
else:
Expand Down
4 changes: 2 additions & 2 deletions tests/test_tsfc_204.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ def test_physically_mapped_facet():
V = FiniteElement("P", mesh.ufl_cell(), 1)
R = FiniteElement("P", mesh.ufl_cell(), 1)
Vv = VectorElement(BrokenElement(V))
Qhat = VectorElement(BrokenElement(V[facet]))
Vhat = VectorElement(V[facet])
Qhat = VectorElement(BrokenElement(V[facet]), dim=2)
Vhat = VectorElement(V[facet], dim=2)
Z = FunctionSpace(mesh, MixedElement(U, Vv, Qhat, Vhat, R))

z = Coefficient(Z)
Expand Down
2 changes: 1 addition & 1 deletion tsfc/kernel_interface/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,7 @@ def prepare_coefficient(coefficient, name, interior_facet=False):

if coefficient.ufl_element().family() == 'Real':
# Constant
value_size = coefficient.ufl_element().value_size
value_size = coefficient.ufl_function_space().value_size
expression = gem.reshape(gem.Variable(name, (value_size,)),
coefficient.ufl_shape)
return expression
Expand Down
9 changes: 5 additions & 4 deletions tsfc/ufl_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,8 +412,9 @@ def apply_mapping(expression, element, domain):
mesh = domain
if domain is not None and mesh != domain:
raise NotImplementedError("Multiple domains not supported")
if expression.ufl_shape != element.value_shape:
raise ValueError(f"Mismatching shapes, got {expression.ufl_shape}, expected {element.value_shape}")
pvs = element.pullback.physical_value_shape(element, mesh)
if expression.ufl_shape != pvs:
raise ValueError(f"Mismatching shapes, got {expression.ufl_shape}, expected {pvs}")
mapping = element.mapping().lower()
if mapping == "identity":
rexpression = expression
Expand Down Expand Up @@ -451,7 +452,7 @@ def apply_mapping(expression, element, domain):
sub_elem = element.sub_elements[0]
shape = expression.ufl_shape
flat = ufl.as_vector([expression[i] for i in numpy.ndindex(shape)])
vs = sub_elem.value_shape
vs = sub_elem.pullback.physical_value_shape(sub_elem, mesh)
rvs = sub_elem.reference_value_shape
seen = set()
rpieces = []
Expand All @@ -472,7 +473,7 @@ def apply_mapping(expression, element, domain):
# And reshape
rexpression = as_tensor(numpy.asarray(rpieces).reshape(element.reference_value_shape))
else:
raise NotImplementedError(f"Don't know how to handle mapping type {mapping} for expression of rank {element.value_shape}")
raise NotImplementedError(f"Don't know how to handle mapping type {mapping} for expression of rank {ufl.FunctionSpace(mesh, element).value_shape}")
if rexpression.ufl_shape != element.reference_value_shape:
raise ValueError(f"Mismatching reference shapes, got {rexpression.ufl_shape} expected {element.reference_value_shape}")
return rexpression
Expand Down

0 comments on commit e06308d

Please sign in to comment.