diff --git a/docs/reference.rst b/docs/reference.rst
index b9f0a4dad..646918508 100644
--- a/docs/reference.rst
+++ b/docs/reference.rst
@@ -26,6 +26,7 @@ These operations are *horizontal* in the sense that [..]
 .. autofunction:: gather
 .. autofunction:: scatter
 .. autofunction:: scatter_reduce
+.. autofunction:: scatter_inc
 .. autofunction:: ravel
 .. autofunction:: unravel
 .. autofunction:: slice
diff --git a/drjit/router.py b/drjit/router.py
index 04a552cb7..db73dd2b1 100644
--- a/drjit/router.py
+++ b/drjit/router.py
@@ -973,6 +973,106 @@ def scatter_reduce(op, target, value, index, active=True):
     return value.scatter_reduce_(op, target, index, active)
 
 
+def scatter_inc(target, index, active=True):
+    '''
+    Atomically increment a value within an unsigned 32-bit integer array
+    and return the value prior to the update.
+
+    This operation works just like :py:func:`drjit.scatter_reduce()` applied
+    to 32-bit unsigned integer operands, but with a fixed ``value=1``
+    argument and reduction ``op=ReduceOp.Add``.
+
+    The main difference is that this variant additionally returns the *old*
+    value of the target array prior to the atomic update, in contrast to the
+    more general scatter-reduction, which just returns ``None``. The
+    operation also supports masking; the return value of masked-off
+    (inactive) entries is undefined. Both ``target`` and ``index`` parameters
+    must be 1D unsigned 32-bit arrays.
+
+    This operation is a building block for stream compaction: threads can
+    scatter-increment a global counter to request a spot in an array and
+    then write their result there. The recipe for this looks as follows:
+
+    .. code-block:: python
+
+        ctr = UInt32(0)                      # Counter
+        active = dr.ones(Bool, len(data_1))  # .. or a more complex condition
+
+        my_index = dr.scatter_inc(target=ctr, index=UInt32(0), active=active)
+
+        dr.scatter(
+            target=data_compact_1,
+            value=data_1,
+            index=my_index,
+            active=active
+        )
+
+        dr.scatter(
+            target=data_compact_2,
+            value=data_2,
+            index=my_index,
+            active=active
+        )
+
+    When following this approach, be sure to provide the same mask value to
+    the :py:func:`drjit.scatter_inc()` and subsequent
+    :py:func:`drjit.scatter()` operations.
+
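+    A complete version of this recipe might look as follows. This is only a
+    sketch: it assumes the ``drjit.llvm`` backend and made-up ``data_1`` and
+    ``data_2`` arrays; any other JIT backend works analogously.
+
+    .. code-block:: python
+
+        import drjit as dr
+        from drjit.llvm import UInt32, Float
+
+        data_1 = Float(1, 2, 3, 4)
+        data_2 = Float(5, 6, 7, 8)
+        active = data_1 > 2.5                # Keep entries with data_1 > 2.5
+
+        ctr = UInt32(0)
+        my_index = dr.scatter_inc(target=ctr, index=UInt32(0), active=active)
+
+        data_compact_1 = dr.zeros(Float, len(data_1))
+        data_compact_2 = dr.zeros(Float, len(data_2))
+        dr.scatter(target=data_compact_1, value=data_1,
+                   index=my_index, active=active)
+        dr.scatter(target=data_compact_2, value=data_2,
+                   index=my_index, active=active)
+
+        dr.eval(data_compact_1, data_compact_2)  # Everything in one kernel
+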
+    The function :py:func:`drjit.scatter_inc()` exhibits the following
+    unusual behavior compared to regular Dr.Jit operations: the return value
+    references the instantaneous state during a potentially large sequence
+    of atomic operations. This instantaneous state is not reproducible in
+    later kernel evaluations, and Dr.Jit will therefore raise an exception
+    when the computed index is reused in another kernel. In essence, the
+    variable is "consumed" by the process of evaluation.
+
+    .. code-block:: python
+
+        my_index = dr.scatter_inc(target=ctr, index=UInt32(0), active=active)
+        dr.scatter(
+            target=data_compact_1,
+            value=data_1,
+            index=my_index,
+            active=active
+        )
+
+        dr.eval(data_compact_1)  # Run Kernel #1
+
+        dr.scatter(
+            target=data_compact_2,
+            value=data_2,
+            index=my_index,  # <-- oops, reusing my_index in another kernel.
+            active=active    #     This raises an exception.
+        )
+
+    To get the above code to work, you will need to evaluate ``my_index`` at
+    the same time to materialize it into a stored (and therefore trivially
+    reproducible) representation. For this, ensure that the size of the
+    ``active`` mask matches ``len(data_*)`` and that it is not the trivial
+    ``True`` default mask (otherwise, the evaluated ``my_index`` will be
+    scalar).
+
+    .. code-block:: python
+
+        dr.eval(data_compact_1, my_index)
+
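+    With this change in place, the problematic sequence from above works as
+    expected (again only a sketch; the variables are as in the earlier
+    listings):
+
+    .. code-block:: python
+
+        my_index = dr.scatter_inc(target=ctr, index=UInt32(0), active=active)
+        dr.scatter(target=data_compact_1, value=data_1,
+                   index=my_index, active=active)
+
+        dr.eval(data_compact_1, my_index)  # Run Kernel #1, store my_index
+
+        dr.scatter(target=data_compact_2, value=data_2,
+                   index=my_index,  # OK: my_index was materialized above
+                   active=active)
+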
+    Such multi-stage evaluation is potentially inefficient and may defeat
+    the purpose of performing stream compaction in the first place. In
+    general, prefer keeping all scatter operations involving the computed
+    index in the same kernel, in which case this issue does not arise.
+
+    The implementation of :py:func:`drjit.scatter_inc()` performs a local
+    reduction first, followed by a single atomic write per SIMD packet/warp.
+    This is done to reduce contention from a potentially very large number
+    of atomic operations targeting the same memory address. Fully masked
+    updates do not generate any memory traffic.
+    '''
+
+    if not _dr.is_jit_v(target) or target.Type != _dr.VarType.UInt32 or \
+       type(index) is not type(target):
+        raise Exception('scatter_inc(): invalid input types!')
+
+    return type(target).scatter_inc_(target, index, active)
+
 
 def ravel(array, order='A'):
     '''
@@ -3800,7 +3900,7 @@ def hypot(a, b):
 
 
 def prefix_sum(value, exclusive=True):
-    '''
+    r'''
     Compute an exclusive or inclusive prefix sum of the 1D input array.
 
     By default, the function returns an output array :math:`\mathbf{y}` of the
diff --git a/ext/drjit-core b/ext/drjit-core
index a8f95ab9b..1cb66e8f3 160000
--- a/ext/drjit-core
+++ b/ext/drjit-core
@@ -1 +1 @@
-Subproject commit a8f95ab9bd48bda6a531e5a133c753497734c0e4
+Subproject commit 1cb66e8f385f2ee1de0c0041de71fa0a22972118
diff --git a/include/drjit/array_router.h b/include/drjit/array_router.h
index 35880c973..07c79a6d3 100644
--- a/include/drjit/array_router.h
+++ b/include/drjit/array_router.h
@@ -1173,20 +1173,31 @@ void scatter_reduce(ReduceOp op, Target &&target, const Value &value,
 }
 
 template <typename Target, typename Value, typename Index>
-void scatter_reduce_kahan(Target &&target_1, Target &&target_2,
+void scatter_reduce_kahan(Target &target_1, Target &target_2,
                           const Value &value, const Index &index,
                           const mask_t<Value> &mask = true) {
     static_assert(
         is_jit_v<Target> && is_jit_v<Value> && is_jit_v<Index> &&
-        array_depth_v<Target> == array_depth_v<Value> &&
-        array_depth_v<Value> == 1,
+        array_depth_v<Target> == 1 &&
+        array_depth_v<Value> == 1 &&
+        array_depth_v<Index> == 1,
         "Only flat JIT arrays are supported at the moment");
 
     value.scatter_reduce_kahan_(target_1, target_2, index, mask);
 }
 
+template <typename Index>
+Index scatter_inc(Index &target,
+                  const Index &index,
+                  const mask_t<Index> &mask = true) {
+    static_assert(is_jit_v<Index> && array_depth_v<Index> == 1,
+                  "Only flat JIT arrays are supported at the moment");
+
+    return Index::scatter_inc_(target, index, mask);
+}
+
 template <typename T, typename TargetType>
 decltype(auto) migrate(const T &value, TargetType target) {
     static_assert(std::is_enum_v<TargetType>);
diff --git a/include/drjit/autodiff.h b/include/drjit/autodiff.h
index cb028ff16..f41b3a11a 100644
--- a/include/drjit/autodiff.h
+++ b/include/drjit/autodiff.h
@@ -1413,7 +1413,7 @@ struct DiffArray : ArrayBase<value_t<Type_>, is_mask_v<Type_>, DiffArray<Type_>>
                    const MaskType &mask = true) const {
         if constexpr (std::is_scalar_v<Type>) {
             (void) dst_1; (void) dst_2; (void) offset; (void) mask;
-            drjit_raise("Array scatter_reduce operation not supported for scalar array type.");
+            drjit_raise("Array scatter_reduce_kahan operation not supported for scalar array type.");
         } else {
             scatter_reduce_kahan(dst_1.m_value, dst_2.m_value, m_value,
                                  offset.m_value, mask.m_value);
@@ -1430,6 +1430,16 @@ struct DiffArray : ArrayBase<value_t<Type_>, is_mask_v<Type_>, DiffArray<Type_>>
         }
     }
 
+    static DiffArray scatter_inc_(DiffArray &dst, const DiffArray &offset,
+                                  const MaskType &mask) {
+        if constexpr (std::is_scalar_v<Type>) {
+            (void) dst; (void) offset; (void) mask;
+            drjit_raise("Array scatter_inc operation not supported for scalar array type.");
+        } else {
+            return Type::scatter_inc_(dst.m_value, offset.m_value, mask.m_value);
+        }
+    }
+
     template <bool Permute>
     static DiffArray gather_(const void *src, const IndexType &offset,
                              const MaskType &mask = true) {
diff --git a/include/drjit/jit.h b/include/drjit/jit.h
index df67ce53c..4680e69db 100644
--- a/include/drjit/jit.h
+++ b/include/drjit/jit.h
@@ -510,6 +510,14 @@ struct JitArray : ArrayBase<Value_, is_mask_v<Value_>, Derived_> {
                         m_index, index.index(), mask.index());
     }
 
+    template <typename Mask>
+    static Derived scatter_inc_(Derived &dst, const Derived &index,
+                                const Mask &mask) {
+        static_assert(
+            std::is_same_v<detached_t<Mask>, detached_t<mask_t<Derived>>>);
+        return steal(jit_var_scatter_inc(dst.index_ptr(), index.index(),
+                                         mask.index()));
+    }
+
     //! @}
     // -----------------------------------------------------------------------
diff --git a/src/python/bind.h b/src/python/bind.h
index bf4527694..d7d9174ab 100644
--- a/src/python/bind.h
+++ b/src/python/bind.h
@@ -346,6 +346,12 @@ auto bind_full(py::class_<Array> &cls, bool /* scalar_mode */ = false) {
                 dr::scatter_reduce(op, target, value, index, mask);
             }, "op"_a, "target"_a.noconvert(), "index"_a, "mask"_a);
     }
+    if constexpr (std::is_same_v<Value, uint32_t> && dr::is_jit_v<Array>) {
+        cls.def_static("scatter_inc_",
+                       [](Array &target, const Array &index, const Mask &mask) {
+                           return dr::scatter_inc(target, index, mask);
+                       });
+    }
 }
 
 if constexpr (dr::is_jit_v<Array>) {
diff --git a/tests/python/test_reduce.py b/tests/python/test_reduce.py
index 8ce346a1a..24a4fb33e 100644
--- a/tests/python/test_reduce.py
+++ b/tests/python/test_reduce.py
@@ -232,3 +232,19 @@ def test08_prefix_sum(m):
     assert dr.prefix_sum(t(1, 2, 3)) == t(0, 1, 3)
     assert dr.prefix_sum(t(1, 2, 3), exclusive=False) == t(1, 3, 6)
     assert dr.cumsum(t(1, 2, 3)) == t(1, 3, 6)
+
+def test09_scatter_inc(m):
+    try:
+        import numpy as np
+    except ImportError:
+        pytest.skip('NumPy is not installed')
+    n = 10000
+    counter = m.UInt32(0)
+    index = dr.arange(m.UInt32, n)
+    offset = dr.scatter_inc(counter, m.UInt32(0))  # One unique slot per thread
+
+    out = dr.zeros(m.UInt32, n)
+    dr.scatter(out, offset, index)  # Record each thread's reserved slot
+    out_np = np.array(out)
+    out_np.sort()
+    assert np.all(out_np == np.arange(n))  # Every slot occurs exactly once
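+
+def test10_scatter_inc_masked(m):
+    # Editorial sketch (not part of the change set above): exercises the
+    # masked code path of dr.scatter_inc(). Inactive lanes must not increment
+    # the counter, so only the even entries receive a compacted slot.
+    try:
+        import numpy as np
+    except ImportError:
+        pytest.skip('NumPy is not installed')
+    n = 1000
+    counter = m.UInt32(0)
+    index = dr.arange(m.UInt32, n)
+    active = dr.eq(index & 1, 0)  # Keep even entries only
+    offset = dr.scatter_inc(counter, m.UInt32(0), active)
+
+    out = dr.zeros(m.UInt32, n // 2)
+    dr.scatter(out, index, offset, active)  # Compact the even entries
+    out_np = np.array(out)
+    out_np.sort()
+    assert np.all(out_np == np.arange(0, n, 2))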