Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RestrictedFunctionSpace: Added in changes to Dataset / Set for use in RestrictedFunctionSpace in firedrake #716

Merged
merged 11 commits into from
Apr 26, 2024
12 changes: 6 additions & 6 deletions pyop2/types/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def lgmap(self):
indices for this :class:`DataSet`.
"""
lgmap = PETSc.LGMap()
if self.comm.size == 1:
if self.comm.size == 1 and self.halo is None:
lgmap.create(indices=np.arange(self.size, dtype=dtypes.IntType),
bsize=self.cdim, comm=self.comm)
else:
Expand Down Expand Up @@ -183,7 +183,7 @@ def local_ises(self):
def layout_vec(self):
"""A PETSc Vec compatible with the dof layout of this DataSet."""
vec = PETSc.Vec().create(comm=self.comm)
size = (self.size * self.cdim, None)
size = ((self.size - self.set.constrained_size) * self.cdim, None)
vec.setSizes(size, bsize=self.cdim)
vec.setUp()
return vec
Expand Down Expand Up @@ -449,8 +449,8 @@ def lgmap(self):
indices for this :class:`MixedDataSet`.
"""
lgmap = PETSc.LGMap()
if self.comm.size == 1:
size = sum(s.size * s.cdim for s in self)
if self.comm.size == 1 and self.halo is None:
size = sum((s.size - s.constrained_size) * s.cdim for s in self)
lgmap.create(indices=np.arange(size, dtype=dtypes.IntType),
bsize=1, comm=self.comm)
return lgmap
Expand Down Expand Up @@ -479,7 +479,7 @@ def lgmap(self):
# current field offset.
idx_size = sum(s.total_size*s.cdim for s in self)
indices = np.full(idx_size, -1, dtype=dtypes.IntType)
owned_sz = np.array([sum(s.size * s.cdim for s in self)],
owned_sz = np.array([sum((s.size - s.constrained_size) * s.cdim for s in self)],
dtype=dtypes.IntType)
field_offset = np.empty_like(owned_sz)
self.comm.Scan(owned_sz, field_offset)
Expand All @@ -493,7 +493,7 @@ def lgmap(self):
current_offsets = np.zeros(self.comm.size + 1, dtype=dtypes.IntType)
for s in self:
idx = indices[start:start + s.total_size * s.cdim]
owned_sz[0] = s.size * s.cdim
owned_sz[0] = (s.size - s.set.constrained_size) * s.cdim
self.comm.Scan(owned_sz, field_offset)
self.comm.Allgather(field_offset, current_offsets[1:])
# Find the ranks each entry in the l2g belongs to
Expand Down
13 changes: 12 additions & 1 deletion pyop2/types/set.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def _wrapper_cache_key_(self):

@utils.validate_type(('size', (numbers.Integral, tuple, list, np.ndarray), ex.SizeTypeError),
('name', str, ex.NameTypeError))
def __init__(self, size, name=None, halo=None, comm=None):
def __init__(self, size, name=None, halo=None, comm=None, constrained_size=0):
self.comm = mpi.internal_comm(comm, self)
if isinstance(size, numbers.Integral):
size = [size] * 3
Expand All @@ -75,6 +75,8 @@ def __init__(self, size, name=None, halo=None, comm=None):
self._name = name or "set_#x%x" % id(self)
self._halo = halo
self._partition_size = 1024
self._constrained_size = constrained_size

# A cache of objects built on top of this set
self._cache = {}

Expand All @@ -88,6 +90,10 @@ def core_size(self):
"""Core set size. Owned elements not touching halo elements."""
return self._sizes[Set._CORE_SIZE]

@utils.cached_property
def constrained_size(self):
return self._constrained_size

@utils.cached_property
def size(self):
"""Set size, owned elements."""
Expand Down Expand Up @@ -588,6 +594,11 @@ def core_size(self):
"""Core set size. Owned elements not touching halo elements."""
return sum(s.core_size for s in self._sets)

@utils.cached_property
def constrained_size(self):
"""Set size, owned constrained elements."""
return sum(s.constrained_size for s in self._sets)

@utils.cached_property
def size(self):
"""Set size, owned elements."""
Expand Down
Loading