Skip to content

Commit

Permalink
Fully parallelise multi-dimensional scatters. (#2037)
Browse files Browse the repository at this point in the history
Closes #2035.
  • Loading branch information
athas authored Oct 31, 2023
1 parent 22d7dc2 commit 660631b
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 0 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.

* `futhark autotune` how supports `hip` backend.

* Better parallelisation of `scatter` when the target is
multidimensional (#2035).

### Removed

### Changed
Expand Down
48 changes: 48 additions & 0 deletions src/Futhark/Pass/ExtractKernels.hs
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,54 @@ transformStm path (Let pat _ (Op (Stream w arrs nes fold_fun))) = do
types <- asksScope scopeForSOACs
transformStms path . stmsToList . snd
=<< runBuilderT (sequentialStreamWholeArray pat w nes fold_fun arrs) types
--
-- When we are scattering into a multidimensional array, we want to
-- fully parallelise, such that we do not have threads writing
-- potentially large rows. We do this by fissioning the scatter into a
-- map part and a scatter part, where the former is flattened as
-- usual, and the latter has a thread per primitive element to be
-- written.
--
-- TODO: this could be slightly smarter. If we are dealing with a
-- horizontally fused Scatter that targets both single- and
-- multi-dimensional arrays, we could handle the former in the map
-- stage. This would save us from having to store all the intermediate
-- results to memory. Troels suspects such cases are very rare, but
-- they may appear some day.
transformStm path (Let pat aux (Op (Scatter w arrs lam as)))
| not $ all primType $ lambdaReturnType lam = do
-- Produce map stage.
map_pat <- fmap Pat $ forM (lambdaReturnType lam) $ \t ->
PatElem <$> newVName "scatter_tmp" <*> pure (t `arrayOfRow` w)
map_stms <- onMap path $ MapLoop map_pat aux w lam arrs

-- Now do the scatters.
runBuilder_ $ do
addStms map_stms
zipWithM_ doScatter (patElems pat) $ groupScatterResults as $ patNames map_pat
where
-- Generate code for a scatter where each thread writes only a scalar.
doScatter res_pe (scatter_space, arr, is_vs) = do
kernel_i <- newVName "write_i"
val_t <- stripArray (shapeRank scatter_space) <$> lookupType arr
val_is <- replicateM (arrayRank val_t) (newVName "val_i")
(kret, kstms) <- collectStms $ do
is_vs' <- forM is_vs $ \(is, v) -> do
v' <- letSubExp (baseString v <> "_elem") $ BasicOp $ Index v $ Slice $ map (DimFix . Var) $ kernel_i : val_is
is' <- forM is $ \i' ->
letSubExp (baseString i' <> "_i") $ BasicOp $ Index i' $ Slice [DimFix $ Var kernel_i]
pure (Slice $ map DimFix $ is' <> map Var val_is, v')
pure $ WriteReturns mempty (scatter_space <> arrayShape val_t) arr is_vs'
(kernel, stms) <-
mapKernel
segThreadCapped
((kernel_i, w) : zip val_is (arrayDims val_t))
mempty
[Prim $ elemType val_t]
(KernelBody () kstms [kret])
addStms stms
letBind (Pat [res_pe]) $ Op $ SegOp kernel
--
transformStm _ (Let pat (StmAux cs _ _) (Op (Scatter w ivs lam as))) = runBuilder_ $ do
let lam' = soacsLambdaToGPU lam
write_i <- newVName "write_i"
Expand Down
7 changes: 7 additions & 0 deletions tests/distribution/scatter1.fut
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
-- Scattering where elements are themselves arrays - see #2035.
-- ==
-- input { [[1,2,3],[4,5,6]] [1i64,0i64,-1i64] [[9,8,7],[6,5,4],[9,8,7]] }
-- output { [[6,5,4],[9,8,7]] }

entry main (xss: *[][]i32) (is: []i64) (ys: [][]i32) =
scatter xss is ys
9 changes: 9 additions & 0 deletions tests/distribution/scatter2.fut
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
-- Scattering where elements are themselves arrays - see #2035. This
-- one also has a map part.
-- ==
-- input { [[1,2,3],[4,5,6]] [1i64,0i64,-1i64] [[9,8,7],[6,5,4],[4,5,6]] }
-- output { [[8,7,6],[11,10,9]] }
-- structure gpu { SegMap 2 }

entry main (xss: *[][]i32) (is: []i64) (ys: [][]i32) =
scatter xss is (map (map (+2)) ys)

0 comments on commit 660631b

Please sign in to comment.