forked from FEniCS/fiat
-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #32 from firedrakeproject/ksagiyam/fix_serendipity
Ksagiyam/fix serendipity
- Loading branch information
Showing
4 changed files
with
47 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,6 +6,8 @@ | |
# | ||
# Modified by David A. Ham ([email protected]), 2019 | ||
|
||
import numbers | ||
import sympy | ||
from sympy import symbols, legendre, Array, diff, lambdify | ||
import numpy as np | ||
from FIAT.finite_element import FiniteElement | ||
|
@@ -30,6 +32,24 @@ def tr(n): | |
return int((n-3)*(n-2)/2) | ||
|
||
|
||
def _replace_numbers_with_symbols(polynomials): | ||
# Replace numbers with symbols to work around issue with numpy>=1.24.1; | ||
# see https://github.com/firedrakeproject/fiat/pull/32. | ||
extra_vars = {} # map from numbers to symbols | ||
polynomials_list = [] | ||
for poly in polynomials.tolist(): | ||
if isinstance(poly, numbers.Real): | ||
if poly not in extra_vars: | ||
extra_vars[poly] = symbols('num_' + str(len(extra_vars))) | ||
polynomials_list.append(extra_vars[poly]) | ||
elif isinstance(poly, sympy.core.Expr): | ||
polynomials_list.append(poly) | ||
else: | ||
raise TypeError(f"Unexpected type: {type(poly)}") | ||
polynomials = Array(polynomials_list) | ||
return polynomials, extra_vars | ||
|
||
|
||
class Serendipity(FiniteElement): | ||
|
||
def __new__(cls, ref_el, degree): | ||
|
@@ -104,8 +124,10 @@ def __init__(self, ref_el, degree): | |
super(Serendipity, self).__init__(ref_el=ref_el, dual=None, order=degree, formdegree=formdegree) | ||
|
||
self.basis = {(0,)*dim: Array(s_list)} | ||
self.basis_callable = {(0,)*dim: lambdify(variables[:dim], Array(s_list), | ||
modules="numpy", dummify=True)} | ||
polynomials, extra_vars = _replace_numbers_with_symbols(Array(s_list)) | ||
self.basis_callable = {(0,)*dim: [lambdify(variables[:dim], polynomials, | ||
modules="numpy", dummify=True), | ||
extra_vars]} | ||
topology = ref_el.get_topology() | ||
unflattening_map = compute_unflattening_map(topology) | ||
unflattened_entity_ids = {} | ||
|
@@ -161,16 +183,26 @@ def tabulate(self, order, points, entity=None): | |
alphas = mis(dim, o) | ||
for alpha in alphas: | ||
try: | ||
callable = self.basis_callable[alpha] | ||
callable, extra_vars = self.basis_callable[alpha] | ||
except KeyError: | ||
polynomials = diff(self.basis[(0,)*dim], *zip(variables, alpha)) | ||
callable = lambdify(variables[:dim], polynomials, modules="numpy", dummify=True) | ||
polynomials, extra_vars = _replace_numbers_with_symbols(polynomials) | ||
callable = lambdify(variables[:dim] + tuple(extra_vars.values()), polynomials, modules="numpy", dummify=True) | ||
self.basis[alpha] = polynomials | ||
self.basis_callable[alpha] = callable | ||
tabulation = callable(*(points[:, i] for i in range(pointdim))) | ||
T = np.asarray([np.broadcast_to(tab, (npoints, )) | ||
for tab in tabulation]) | ||
phivals[alpha] = T | ||
self.basis_callable[alpha] = [callable, extra_vars] | ||
# Can no longer make a numpy array from objects of inhomogeneous shape | ||
# (unless we specify `dtype==object`); | ||
# see https://github.com/firedrakeproject/fiat/pull/32. | ||
# | ||
# Casting `key`s to float() is needed, otherwise we somehow get the following error: | ||
# | ||
# E TypeError: unsupported type for persistent hash keying: <class 'complex'> | ||
# | ||
# ../../lib/python3.8/site-packages/pytools/persistent_dict.py:243: TypeError | ||
# | ||
# `key`s have been checked to be numbers.Real. | ||
extra_arrays = [np.ones((npoints, ), dtype=points.dtype) * float(key) for key in extra_vars] | ||
phivals[alpha] = callable(*([points[:, i] for i in range(pointdim)] + extra_arrays)) | ||
return phivals | ||
|
||
def entity_dofs(self): | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters