diff --git a/pyop2/types/dataset.py b/pyop2/types/dataset.py index 8d3ba0472..3b4f4bfd8 100644 --- a/pyop2/types/dataset.py +++ b/pyop2/types/dataset.py @@ -113,7 +113,7 @@ def lgmap(self): indices for this :class:`DataSet`. """ lgmap = PETSc.LGMap() - if self.comm.size == 1: + if self.comm.size == 1 and self.halo is None: lgmap.create(indices=np.arange(self.size, dtype=dtypes.IntType), bsize=self.cdim, comm=self.comm) else: @@ -183,7 +183,7 @@ def local_ises(self): def layout_vec(self): """A PETSc Vec compatible with the dof layout of this DataSet.""" vec = PETSc.Vec().create(comm=self.comm) - size = (self.size * self.cdim, None) + size = ((self.size - self.set.constrained_size) * self.cdim, None) vec.setSizes(size, bsize=self.cdim) vec.setUp() return vec @@ -449,8 +449,8 @@ def lgmap(self): indices for this :class:`MixedDataSet`. """ lgmap = PETSc.LGMap() - if self.comm.size == 1: - size = sum(s.size * s.cdim for s in self) + if self.comm.size == 1 and self.halo is None: + size = sum((s.size - s.constrained_size) * s.cdim for s in self) lgmap.create(indices=np.arange(size, dtype=dtypes.IntType), bsize=1, comm=self.comm) return lgmap @@ -479,7 +479,7 @@ def lgmap(self): # current field offset. idx_size = sum(s.total_size*s.cdim for s in self) indices = np.full(idx_size, -1, dtype=dtypes.IntType) - owned_sz = np.array([sum(s.size * s.cdim for s in self)], + owned_sz = np.array([sum((s.size - s.constrained_size) * s.cdim for s in self)], dtype=dtypes.IntType) field_offset = np.empty_like(owned_sz) self.comm.Scan(owned_sz, field_offset) @@ -493,7 +493,7 @@ def lgmap(self): current_offsets = np.zeros(self.comm.size + 1, dtype=dtypes.IntType) for s in self: idx = indices[start:start + s.total_size * s.cdim] - owned_sz[0] = s.size * s.cdim + owned_sz[0] = (s.size - s.set.constrained_size) * s.cdim self.comm.Scan(owned_sz, field_offset) self.comm.Allgather(field_offset, current_offsets[1:]) # Find the ranks each entry in the l2g belongs to diff --git a/pyop2/types/set.py b/pyop2/types/set.py index 32fb01844..f10c93404 100644 --- a/pyop2/types/set.py +++ b/pyop2/types/set.py @@ -64,7 +64,7 @@ def _wrapper_cache_key_(self): @utils.validate_type(('size', (numbers.Integral, tuple, list, np.ndarray), ex.SizeTypeError), ('name', str, ex.NameTypeError)) - def __init__(self, size, name=None, halo=None, comm=None): + def __init__(self, size, name=None, halo=None, comm=None, constrained_size=0): self.comm = mpi.internal_comm(comm, self) if isinstance(size, numbers.Integral): size = [size] * 3 @@ -75,6 +75,8 @@ def __init__(self, size, name=None, halo=None, comm=None): self._name = name or "set_#x%x" % id(self) self._halo = halo self._partition_size = 1024 + self._constrained_size = constrained_size + # A cache of objects built on top of this set self._cache = {} @@ -88,6 +90,10 @@ def core_size(self): """Core set size. Owned elements not touching halo elements.""" return self._sizes[Set._CORE_SIZE] + @utils.cached_property + def constrained_size(self): + return self._constrained_size + @utils.cached_property def size(self): """Set size, owned elements.""" @@ -588,6 +594,11 @@ def core_size(self): """Core set size. Owned elements not touching halo elements.""" return sum(s.core_size for s in self._sets) + @utils.cached_property + def constrained_size(self): + """Set size, owned constrained elements.""" + return sum(s.constrained_size for s in self._sets) + @utils.cached_property def size(self): """Set size, owned elements."""