From 660631b2dc8a20928de50e887ba538d751a6eb61 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Tue, 31 Oct 2023 17:08:11 +0100 Subject: [PATCH] Fully parallelise multi-dimensional scatters. (#2037) Closes #2035. --- CHANGELOG.md | 3 ++ src/Futhark/Pass/ExtractKernels.hs | 48 ++++++++++++++++++++++++++++++ tests/distribution/scatter1.fut | 7 +++++ tests/distribution/scatter2.fut | 9 ++++++ 4 files changed, 67 insertions(+) create mode 100644 tests/distribution/scatter1.fut create mode 100644 tests/distribution/scatter2.fut diff --git a/CHANGELOG.md b/CHANGELOG.md index 8779b35205..dff1ff47fb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,9 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. * `futhark autotune` how supports `hip` backend. +* Better parallelisation of `scatter` when the target is + multidimensional (#2035). + ### Removed ### Changed diff --git a/src/Futhark/Pass/ExtractKernels.hs b/src/Futhark/Pass/ExtractKernels.hs index c381605b87..22521512c4 100644 --- a/src/Futhark/Pass/ExtractKernels.hs +++ b/src/Futhark/Pass/ExtractKernels.hs @@ -448,6 +448,54 @@ transformStm path (Let pat _ (Op (Stream w arrs nes fold_fun))) = do types <- asksScope scopeForSOACs transformStms path . stmsToList . snd =<< runBuilderT (sequentialStreamWholeArray pat w nes fold_fun arrs) types +-- +-- When we are scattering into a multidimensional array, we want to +-- fully parallelise, such that we do not have threads writing +-- potentially large rows. We do this by fissioning the scatter into a +-- map part and a scatter part, where the former is flattened as +-- usual, and the latter has a thread per primitive element to be +-- written. +-- +-- TODO: this could be slightly smarter. If we are dealing with a +-- horizontally fused Scatter that targets both single- and +-- multi-dimensional arrays, we could handle the former in the map +-- stage. This would save us from having to store all the intermediate +-- results to memory. Troels suspects such cases are very rare, but +-- they may appear some day. +transformStm path (Let pat aux (Op (Scatter w arrs lam as))) + | not $ all primType $ lambdaReturnType lam = do + -- Produce map stage. + map_pat <- fmap Pat $ forM (lambdaReturnType lam) $ \t -> + PatElem <$> newVName "scatter_tmp" <*> pure (t `arrayOfRow` w) + map_stms <- onMap path $ MapLoop map_pat aux w lam arrs + + -- Now do the scatters. + runBuilder_ $ do + addStms map_stms + zipWithM_ doScatter (patElems pat) $ groupScatterResults as $ patNames map_pat + where + -- Generate code for a scatter where each thread writes only a scalar. + doScatter res_pe (scatter_space, arr, is_vs) = do + kernel_i <- newVName "write_i" + val_t <- stripArray (shapeRank scatter_space) <$> lookupType arr + val_is <- replicateM (arrayRank val_t) (newVName "val_i") + (kret, kstms) <- collectStms $ do + is_vs' <- forM is_vs $ \(is, v) -> do + v' <- letSubExp (baseString v <> "_elem") $ BasicOp $ Index v $ Slice $ map (DimFix . Var) $ kernel_i : val_is + is' <- forM is $ \i' -> + letSubExp (baseString i' <> "_i") $ BasicOp $ Index i' $ Slice [DimFix $ Var kernel_i] + pure (Slice $ map DimFix $ is' <> map Var val_is, v') + pure $ WriteReturns mempty (scatter_space <> arrayShape val_t) arr is_vs' + (kernel, stms) <- + mapKernel + segThreadCapped + ((kernel_i, w) : zip val_is (arrayDims val_t)) + mempty + [Prim $ elemType val_t] + (KernelBody () kstms [kret]) + addStms stms + letBind (Pat [res_pe]) $ Op $ SegOp kernel +-- transformStm _ (Let pat (StmAux cs _ _) (Op (Scatter w ivs lam as))) = runBuilder_ $ do let lam' = soacsLambdaToGPU lam write_i <- newVName "write_i" diff --git a/tests/distribution/scatter1.fut b/tests/distribution/scatter1.fut new file mode 100644 index 0000000000..f51b3f4226 --- /dev/null +++ b/tests/distribution/scatter1.fut @@ -0,0 +1,7 @@ +-- Scattering where elements are themselves arrays - see #2035. +-- == +-- input { [[1,2,3],[4,5,6]] [1i64,0i64,-1i64] [[9,8,7],[6,5,4],[9,8,7]] } +-- output { [[6,5,4],[9,8,7]] } + +entry main (xss: *[][]i32) (is: []i64) (ys: [][]i32) = + scatter xss is ys diff --git a/tests/distribution/scatter2.fut b/tests/distribution/scatter2.fut new file mode 100644 index 0000000000..c1afbb6e8a --- /dev/null +++ b/tests/distribution/scatter2.fut @@ -0,0 +1,9 @@ +-- Scattering where elements are themselves arrays - see #2035. This +-- one also has a map part. +-- == +-- input { [[1,2,3],[4,5,6]] [1i64,0i64,-1i64] [[9,8,7],[6,5,4],[4,5,6]] } +-- output { [[8,7,6],[11,10,9]] } +-- structure gpu { SegMap 2 } + +entry main (xss: *[][]i32) (is: []i64) (ys: [][]i32) = + scatter xss is (map (map (+2)) ys)