diff --git a/FIAT/restricted.py b/FIAT/restricted.py index 2dd9abf35..b985f14b4 100644 --- a/FIAT/restricted.py +++ b/FIAT/restricted.py @@ -11,14 +11,14 @@ class RestrictedElement(CiarletElement): """Restrict given element to specified list of dofs.""" - def __init__(self, element, indices=None, restriction_domain=None): + def __init__(self, element, indices=None, restriction_domain=None, take_closure=True): '''For sake of argument, indices overrides restriction_domain''' if not (indices or restriction_domain): raise RuntimeError("Either indices or restriction_domain must be passed in") if not indices: - indices = _get_indices(element, restriction_domain) + indices = _get_indices(element, restriction_domain, take_closure) if isinstance(indices, str): raise RuntimeError("variable 'indices' was a string; did you forget to use a keyword?") @@ -70,7 +70,7 @@ def _key(x): return sorted(mapping.items(), key=_key) -def _get_indices(element, restriction_domain): +def _get_indices(element, restriction_domain, take_closure): "Restriction domain can be 'interior', 'vertex', 'edge', 'face' or 'facet'" if restriction_domain == "interior": @@ -91,9 +91,10 @@ def _get_indices(element, restriction_domain): is_prodcell = isinstance(max(element.entity_dofs().keys()), tuple) + ldim = 0 if take_closure else dim entity_dofs = element.entity_dofs() indices = [] - for d in range(dim + 1): + for d in range(ldim, dim + 1): if is_prodcell: for a in range(d + 1): b = d - a