From 5480a342640c7e310597ad71b3071a92ba0c3248 Mon Sep 17 00:00:00 2001 From: Christiaan Baaij Date: Wed, 14 Feb 2024 16:13:52 +0100 Subject: [PATCH] Refactor DEC transformation Create multiple selectors, one for each non-shared argument. The previous code was a big mess where we partioned arguments into shared and non-shared and then filtered the case-tree depending on whether they were part of the shared arguments or not. But then with the normalisation of type arguments, the second filter did not work properly. This then resulted in shared arguments becoming part of the tuple in the alternatives of the case-expression for the non-shared arguments. The new code is also more robust in the sense that shared and non-shared arguments no longer need to be partioned (shared occur left-most, non-shared occur right-most). They can now be interleaved. The old code would also generate bad Core if ever type and term arguments occured interleaved, this is no longer the case for the new code. Fixes #2628 --- .../Clash/Normalize/Transformations/DEC.hs | 264 +++++------------- tests/Main.hs | 1 + tests/shouldwork/Issues/T2628.hs | 156 +++++++++++ 3 files changed, 229 insertions(+), 192 deletions(-) create mode 100644 tests/shouldwork/Issues/T2628.hs diff --git a/clash-lib/src/Clash/Normalize/Transformations/DEC.hs b/clash-lib/src/Clash/Normalize/Transformations/DEC.hs index 61af7a5c2f..d3025d105c 100644 --- a/clash-lib/src/Clash/Normalize/Transformations/DEC.hs +++ b/clash-lib/src/Clash/Normalize/Transformations/DEC.hs @@ -1,6 +1,6 @@ {-| Copyright : (C) 2015-2016, University of Twente, - 2021-2022, QBayLogic B.V. + 2021-2024, QBayLogic B.V. 2022, LumiGuide Fietsdetectie B.V. License : BSD2 (see the file LICENSE) Maintainer : QBayLogic B.V. @@ -46,8 +46,6 @@ import Data.Coerce (coerce) import qualified Data.Either as Either import qualified Data.Foldable as Foldable import qualified Data.Graph as Graph -import Data.IntMap.Strict (IntMap) -import qualified Data.IntMap.Strict as IntMap import qualified Data.IntSet as IntSet import qualified Data.List as List import qualified Data.List.Extra as List @@ -57,45 +55,32 @@ import Data.Monoid (All(..)) import qualified Data.Text as Text import GHC.Stack (HasCallStack) -#if MIN_VERSION_ghc(9,6,0) -import GHC.Core.Make (chunkify, mkChunkified) -#else -import GHC.Hs.Utils (chunkify, mkChunkified) -#endif - -#if MIN_VERSION_ghc(9,0,0) -import GHC.Settings.Constants (mAX_TUPLE_SIZE) -#else -import Constants (mAX_TUPLE_SIZE) -#endif - -- internal -import Clash.Core.DataCon (DataCon) import Clash.Core.Evaluator.Types (whnf') import Clash.Core.FreeVars (termFreeVars', typeFreeVars', localVarsDoNotOccurIn) import Clash.Core.HasType import Clash.Core.Literal (Literal(..)) import Clash.Core.Name (nameOcc) +import Clash.Core.Pretty (showPpr) import Clash.Core.Term ( Alt, LetBinding, Pat(..), PrimInfo(..), Term(..), TickInfo(..) , collectArgs, collectArgsTicks, mkApps, mkTicks, patIds, stripTicks) -import Clash.Core.TyCon (TyConMap, TyConName, tyConDataCons) +import Clash.Core.TyCon (TyConMap) import Clash.Core.Type - (Type, TypeView (..), isPolyFunTy, mkTyConApp, splitFunForallTy, tyView) -import Clash.Core.Util (mkInternalVar, mkSelectorCase, sccLetBindings) + (Type, TypeView (..), isPolyFunTy, splitFunForallTy, tyView) +import Clash.Core.Util (mkInternalVar, sccLetBindings) import Clash.Core.Var (isGlobalId, isLocalId, varName) import Clash.Core.VarEnv ( InScopeSet, elemInScopeSet, extendInScopeSet, extendInScopeSetList , notElemInScopeSet, unionInScope) -import qualified Clash.Data.UniqMap as UniqMap import Clash.Normalize.Transformations.Letrec (deadCode) import Clash.Normalize.Types (NormRewrite, NormalizeSession) import Clash.Rewrite.Combinators (bottomupR) import Clash.Rewrite.Types import Clash.Rewrite.Util (changed, isUntranslatableType) import Clash.Rewrite.WorkFree (isConstant) -import Clash.Util (MonadUnique, curLoc) +import Clash.Util (curLoc) -- | This transformation lifts applications of global binders out of -- alternatives of case-statements. @@ -132,11 +117,12 @@ disjointExpressionConsolidation ctx@(TransformContext isCtx _) e@(Case _scrut _t else do -- For every to-lift expression create (the generalization of): -- - -- let fargs = case x of {A -> (3,y); B -> (x,x)} - -- in f (fst fargs) (snd fargs) + -- let djArg0 = case x of {A -> 3; B -> x} + -- djArg1 = case x of {A -> y; B -> x} + -- in f djArg0 djArg1 -- - -- the let-expression is not created when `f` has only one (selectable) - -- argument + -- if an argument is non-representable, the case-expression is inlined, + -- and no let-binding will be created for it. -- -- NB: mkDisJointGroup needs the context InScopeSet, isCtx, to determine -- whether expressions reference variables from the context, or @@ -251,18 +237,6 @@ isDisjoint ct = go ct go (Branch _ [(_,x)]) = go x go b@(Branch _ (_:_:_)) = allEqual (map Either.rights (Foldable.toList b)) --- Remove empty branches from a 'CaseTree' -removeEmpty :: Eq a => CaseTree [a] -> CaseTree [a] -removeEmpty l@(Leaf _) = l -removeEmpty (LB lb ct) = - case removeEmpty ct of - Leaf [] -> Leaf [] - ct' -> LB lb ct' -removeEmpty (Branch s bs) = - case filter ((/= (Leaf [])) . snd) (map (second removeEmpty) bs) of - [] -> Leaf [] - bs' -> Branch s bs' - -- | Test if all elements in a list are equal to each other. allEqual :: Eq a => [a] -> Bool allEqual [] = True @@ -464,8 +438,11 @@ collectGlobalsLbs is0 substitution seen lbs = do -- function-position\", return a let-expression: where the let-binding holds -- a case-expression selecting between the distinct arguments of the case-tree, -- and the body is an application of the term applied to the shared arguments of --- the case tree, and projections of let-binding corresponding to the distinct --- argument positions. +-- the case tree, and variable references to the created let-bindings. +-- +-- case-expressions whose type would be non-representable are not let-bound, +-- but occur directly in the argument position of the application in the body +-- of the let-expression. mkDisjointGroup :: InScopeSet -- ^ Variables in scope at the very top of the case-tree, i.e., the original @@ -475,79 +452,59 @@ mkDisjointGroup -> NormalizeSession (Term,[Term]) mkDisjointGroup inScope (fun,(seen,cs)) = do tcm <- Lens.view tcCache - let argss = Foldable.toList cs - argssT = zip [0..] (List.transpose argss) - (sharedT,distinctT) = List.partition (areShared tcm inScope . fmap (first stripTicks) . snd) argssT - -- TODO: find a better solution than "maybe undefined fst . uncons" - shared = map (second (maybe (error "impossible") fst . List.uncons)) sharedT - distinct = map (Either.lefts) (List.transpose (map snd distinctT)) - cs' = fmap (zip [0..]) cs - cs'' = removeEmpty - $ fmap (Either.lefts . map snd) - (if null shared - then cs' - else fmap (filter (`notElem` shared)) cs') - (distinctCaseM,distinctProjections) <- case distinct of - -- only shared arguments: do nothing. - [] -> return (Nothing,[]) - -- Create selectors and projections - (uc:_) -> do - let argTys = map (inferCoreTypeOf tcm) uc - disJointSelProj inScope argTys cs'' - let newArgs = mkDJArgs 0 shared distinctProjections - case distinctCaseM of - Just lb -> return (Letrec [lb] (mkApps fun newArgs), seen) - Nothing -> return (mkApps fun newArgs, seen) - --- | Create a single selector for all the representable distinct arguments by --- selecting between tuples. This selector is only ('Just') created when the --- number of representable uncommmon arguments is larger than one, otherwise it --- is not ('Nothing'). --- --- It also returns: --- --- * For all the non-representable distinct arguments: a selector --- * For all the representable distinct arguments: a projection out of the tuple --- created by the larger selector. If this larger selector does not exist, a --- single selector is created for the single representable distinct argument. + let argLen = case Foldable.toList cs of + [] -> error ($curLoc <> "mkDisjointGroup: no disjoint groups") + l:_ -> length l + csT :: [CaseTree (Either Term Type)] + csT = map (\i -> fmap (!!i) cs) [0..(argLen-1)] + (lbs,newArgs) <- List.mapAccumLM (\lbs c -> do + let cL :: [Either Term Type] + cL = Foldable.toList c + case (cL, areShared tcm inScope (fmap (first stripTicks) cL)) of + (Right ty:_, True) -> + return (lbs,Right ty) + (Right _:_, False) -> + error ($curLoc <> "mkDisjointGroup: non-equal type arguments: " <> + showPpr (Either.rights cL)) + (Left tm:_, True) -> + return (lbs,Left tm) + (Left tm:_, False) -> do + let ty = inferCoreTypeOf tcm tm + let err = error $ + $curLoc <> + "mkDisjointGroup: mixed type and term arguments: " <> + show cL + (lbM,arg) <- disJointSelProj inScope ty (Either.fromLeft err <$> c) + case lbM of + Just lb -> return (lb:lbs,Left arg) + _ -> return (lbs,Left arg) + ([], _) -> + error ($curLoc ++ "mkDisjointGroup: no arguments") + ) [] csT + let funApp = mkApps fun newArgs + case lbs of + [] -> return (funApp, seen) + _ -> return (Letrec lbs funApp, seen) + +-- | Create a selector for the case-tree of the argument. If the argument is +-- representable create a let-binding for the created selector, and return +-- a variable reference to this let-binding. If the argument is not representable +-- return the selector directly. disJointSelProj :: InScopeSet - -> [Type] - -- ^ Types of the arguments - -> CaseTree [Term] - -- The case-tree of arguments - -> NormalizeSession (Maybe LetBinding,[Term]) -disJointSelProj _ _ (Leaf []) = return (Nothing,[]) -disJointSelProj inScope argTys cs = do - tcm <- Lens.view tcCache - tupTcm <- Lens.view tupleTcCache - let maxIndex = length argTys - 1 - css = map (\i -> fmap ((:[]) . (!!i)) cs) [0..maxIndex] - (untran,tran) <- List.partitionM (isUntranslatableType False . snd) (zip [0..] argTys) - let untranCs = map (css!!) (map fst untran) - untranSels = zipWith (\(_,ty) cs' -> genCase tcm tupTcm ty [ty] cs') - untran untranCs - (lbM,projs) <- case tran of - [] -> return (Nothing,[]) - [(i,ty)] -> return (Nothing,[genCase tcm tupTcm ty [ty] (css!!i)]) - tys -> do - let m = length tys - (tyIxs,tys') = unzip tys - tupTy = mkBigTupTy tcm tupTcm tys' - cs' = fmap (\es -> map (es !!) tyIxs) cs - djCase = genCase tcm tupTcm tupTy tys' cs' - scrutId <- mkInternalVar inScope "tupIn" tupTy - projections <- mapM (mkBigTupSelector inScope tcm tupTcm (Var scrutId) tys') [0..m-1] - return (Just (scrutId,djCase),projections) - let selProjs = tranOrUnTran 0 (zip (map fst untran) untranSels) projs - - return (lbM,selProjs) - where - tranOrUnTran _ [] projs = projs - tranOrUnTran _ sels [] = map snd sels - tranOrUnTran n ((ut,s):uts) (p:projs) - | n == ut = s : tranOrUnTran (n+1) uts (p:projs) - | otherwise = p : tranOrUnTran (n+1) ((ut,s):uts) projs + -> Type + -- ^ Types of the argument + -> CaseTree Term + -- The case-tree of argument + -> NormalizeSession (Maybe LetBinding,Term) +disJointSelProj inScope argTy cs = do + let sel = genCase argTy cs + untran <- isUntranslatableType False argTy + case untran of + True -> return (Nothing, sel) + False -> do + argId <- mkInternalVar inScope "djArg" argTy + return (Just (argId,sel), Var argId) -- | Arguments are shared between invocations if: -- @@ -579,30 +536,15 @@ areShared tcm inScope xs@(x:_) = noFV1 && (isProof x || allEqual xs) _ -> False isProof _ = False --- | Create a list of arguments given a map of positions to common arguments, --- and a list of arguments -mkDJArgs :: Int -- ^ Current position - -> [(Int,Either Term Type)] -- ^ map from position to common argument - -> [Term] -- ^ (projections for) distinct arguments - -> [Either Term Type] -mkDJArgs _ cms [] = map snd cms -mkDJArgs _ [] uncms = map Left uncms -mkDJArgs n ((m,x):cms) (y:uncms) - | n == m = x : mkDJArgs (n+1) cms (y:uncms) - | otherwise = Left y : mkDJArgs (n+1) ((m,x):cms) uncms - -- | Create a case-expression that selects between the distinct arguments given -- a case-tree -genCase :: TyConMap - -> IntMap TyConName - -> Type -- ^ Type of the alternatives - -> [Type] -- ^ Types of the arguments - -> CaseTree [Term] -- ^ CaseTree of arguments +genCase :: Type -- ^ Types of the arguments + -> CaseTree Term -- ^ CaseTree of arguments -> Term -genCase tcm tupTcm ty argTys = go +genCase ty = go where - go (Leaf tms) = - mkBigTupTm tcm tupTcm (List.zipEqual argTys tms) + go (Leaf tm) = + tm go (LB lb ct) = Letrec lb (go ct) @@ -617,68 +559,6 @@ genCase tcm tupTcm ty argTys = go go (Branch scrut pats) = Case scrut ty (map (second go) pats) --- | Lookup the TyConName and DataCon for a tuple of size n -findTup :: TyConMap -> IntMap TyConName -> Int -> (TyConName,DataCon) -findTup tcm tupTcm n = - Maybe.fromMaybe (error ("Cannot build " <> show n <> "-tuble")) $ do - tupTcNm <- IntMap.lookup n tupTcm - tupTc <- UniqMap.lookup tupTcNm tcm - tupDc <- Maybe.listToMaybe (tyConDataCons tupTc) - return (tupTcNm,tupDc) - -mkBigTupTm :: TyConMap -> IntMap TyConName -> [(Type,Term)] -> Term -mkBigTupTm tcm tupTcm args = snd $ mkBigTup tcm tupTcm args - -mkSmallTup,mkBigTup :: TyConMap -> IntMap TyConName -> [(Type,Term)] -> (Type,Term) -mkSmallTup _ _ [] = error $ $curLoc ++ "mkSmallTup: Can't create 0-tuple" -mkSmallTup _ _ [(ty,tm)] = (ty,tm) -mkSmallTup tcm tupTcm args = (ty,tm) - where - (argTys,tms) = unzip args - (tupTcNm,tupDc) = findTup tcm tupTcm (length args) - tm = mkApps (Data tupDc) (map Right argTys ++ map Left tms) - ty = mkTyConApp tupTcNm argTys - -mkBigTup tcm tupTcm = mkChunkified (mkSmallTup tcm tupTcm) - -mkSmallTupTy,mkBigTupTy - :: TyConMap - -> IntMap TyConName - -> [Type] - -> Type -mkSmallTupTy _ _ [] = error $ $curLoc ++ "mkSmallTupTy: Can't create 0-tuple" -mkSmallTupTy _ _ [ty] = ty -mkSmallTupTy tcm tupTcm tys = mkTyConApp tupTcNm tys - where - m = length tys - (tupTcNm,_) = findTup tcm tupTcm m - -mkBigTupTy tcm tupTcm = mkChunkified (mkSmallTupTy tcm tupTcm) - -mkSmallTupSelector,mkBigTupSelector - :: MonadUnique m - => InScopeSet - -> TyConMap - -> IntMap TyConName - -> Term - -> [Type] - -> Int - -> m Term -mkSmallTupSelector _ _ _ scrut [_] 0 = return scrut -mkSmallTupSelector _ _ _ _ [_] n = error $ $curLoc ++ "mkSmallTupSelector called with one type, but to select " ++ show n -mkSmallTupSelector inScope tcm _ scrut _ n = mkSelectorCase ($curLoc ++ "mkSmallTupSelector") inScope tcm scrut 1 n - -mkBigTupSelector inScope tcm tupTcm scrut tys n = go (chunkify tys) - where - go [_] = mkSmallTupSelector inScope tcm tupTcm scrut tys n - go tyss = do - let (nOuter,nInner) = divMod n mAX_TUPLE_SIZE - tyss' = map (mkSmallTupTy tcm tupTcm) tyss - outer <- mkSmallTupSelector inScope tcm tupTcm scrut tyss' nOuter - inner <- mkSmallTupSelector inScope tcm tupTcm outer (tyss List.!! nOuter) nInner - return inner - - -- | Determine if a term in a function position is interesting to lift out of -- of a case-expression. -- diff --git a/tests/Main.hs b/tests/Main.hs index b56019fe9a..66a3f42fb5 100755 --- a/tests/Main.hs +++ b/tests/Main.hs @@ -788,6 +788,7 @@ runClashTest = defaultMain $ clashTestRoot , outputTest "T2510" def{hdlTargets=[VHDL], clashFlags=["-DNOINLINE=OPAQUE"]} #endif , outputTest "T2542" def{hdlTargets=[VHDL]} + , runTest "T2628" def{hdlTargets=[VHDL], buildTargets=BuildSpecific ["TACacheServerStep"], hdlSim=[]} ] <> if compiledWith == Cabal then -- This tests fails without environment files present, which are only diff --git a/tests/shouldwork/Issues/T2628.hs b/tests/shouldwork/Issues/T2628.hs new file mode 100644 index 0000000000..2d6a001bb5 --- /dev/null +++ b/tests/shouldwork/Issues/T2628.hs @@ -0,0 +1,156 @@ +module T2628 where + +import Clash.Prelude + +-- idx cacheline entries are Just(tag,Just addr) to translate idx++tag->addr +-- and Just(tag,Nothing) for invalidated idx++tag entry +-- and Nothing for no entry there +type CacheLine m tag addr -- 2^m tags per line, 2^n lines + = Vec (2^m) (Maybe(tag,Maybe addr)) + +{-# ANN tacache_server_step32 + (Synthesize { t_name = "TACacheServerStep" + , t_inputs = [ PortName "dx" -- user B + , PortName "d_x" -- tlb C + , PortName "dw" -- tlb D + , PortName "out2" -- cache B + , PortName "out3" -- cache C + ] + , t_output = PortProduct "" + [ PortName "win1" -- cache A1 + , PortName "win2" -- cache A2 + ] + }) #-} + +{-# NOINLINE tacache_server_step32 #-} +tacache_server_step32 = tacache_server_step' + where + tacache_server_step' + :: forall (m::Nat) (n::Nat) (p::Nat) (q::Nat) + cxdr addr idx tag cacheline + . ( KnownNat q, KnownNat n, KnownNat m, KnownNat p + , n <= p + , cxdr ~ Signed p + , addr ~ Signed q + , idx ~ Signed n + , tag ~ Signed (p-n) + , cacheline ~ CacheLine m tag addr + , p ~ 132 + , q ~ 32 + , n ~ 6 + , m ~ 0 + ) + -- SNat n -- 2^n lines + -- SNat m -- of 2^m entries each + => ( Maybe cxdr -- input frnt invalidate addr req to server + , Maybe cxdr -- input back/weak invalidate req to server + , Maybe (cxdr,addr) -- input back/weak write req to server + , Maybe (idx,cacheline) + , Maybe (idx,cacheline) + ) + -> ( Maybe(idx,cacheline) + , Maybe(idx,cacheline) + ) + tacache_server_step' = tacache_server_step (SNat::SNat n) (SNat::SNat m) + +tacache_server_step + :: forall (m::Nat) (n::Nat) (p::Nat) (q::Nat) + cxdr addr idx tag cacheline + . ( KnownNat q, KnownNat n, KnownNat m, KnownNat p + , n <= p + , cxdr ~ Signed p + , addr ~ Signed q + , idx ~ Signed n + , tag ~ Signed (p-n) + , cacheline ~ CacheLine m tag addr +-- , p ~ 132 +-- , q ~ 32 + ) + => SNat n -- 2^n lines + -> SNat m -- of 2^m entries each + -> ( Maybe cxdr -- input frnt invalidate addr req to server + , Maybe cxdr -- input back/weak invalidate req to server + , Maybe (cxdr,addr) -- input back/weak write req to server + , Maybe (idx,cacheline) + , Maybe (idx,cacheline) + ) + -> ( Maybe(idx,cacheline) + , Maybe(idx,cacheline) + ) +tacache_server_step n m (dx,d_x,dw,out1,out2) = (win1,win2) + + where + -- outs1 and outs2 are prev state + -- (may need to write two lines in one cycle) + win1,win2 :: Maybe(idx,CacheLine m tag addr) + (win1,win2) = + case (dx, d_x, dw, out1, out2) of + + -- !!! FIX for HDL from here on, replace (v,_) = with v = fst $ !!! -- + + (Just x1,Just x2,Nothing,Just (idx1,v1),Just (idx2,v2)) -> + let (idx2',tag2) = tacache_split_cxdr x2 + in + if 1 /= idx2' then + ( Just(1,v1) + , Just(idx2',v2) + ) + else + let (v1',_) = tazcache_line_inval_step v1 2 -- HERE + (v2',_) = tazcache_line_weak_inval_step v1' tag2 -- HERE + in ( Just(idx2',v2') + , Nothing + ) + + -- !!! FIX for HDL from here, as above, and make cases top level fns !!! --- + + (Nothing,Just x,Nothing,_,Just (idx,v)) -> + let (v',_) = tazcache_line_weak_inval_step v 4 -- HERE + in ( Nothing + , Just(3,v') + ) + + _ -> (Nothing,Nothing) + + -------------------- DUMMY NOINLINE support ----------------------- + +-- split incoming addr for translation into a cacheline index and tag +{-# NOINLINE tacache_split_cxdr #-} +tacache_split_cxdr + :: forall (n::Nat) (p::Nat) tag cxdr idx f + . ( KnownNat n, KnownNat p + , Resize f -- might as well be just Signed + , n <= p, (n + (p-n)) ~ p, ((p-n) + n) ~ p + , BitPack cxdr, p ~ BitSize cxdr, cxdr ~ f p + , BitPack idx, n ~ BitSize idx, idx ~ f n + , BitPack tag, (p-n) ~ BitSize tag, tag ~ f (p-n) + ) + => cxdr + -> (idx,tag) +tacache_split_cxdr x = (unpack 5, unpack 6) + + ------------------ DUMMY NOINLINE cacheline ops --------------------- + +-- remove element with matching tag from cacheline, report position +{-# NOINLINE tazcache_line_inval_step #-} +tazcache_line_inval_step :: + ( KnownNat m, KnownNat p_n, KnownNat q + , BitPack tag, p_n ~ BitSize tag, Eq tag + , BitPack addr, q ~ BitSize addr + ) + => CacheLine m tag addr + -> tag + -> (CacheLine m tag addr, Maybe(Index(2^m))) +tazcache_line_inval_step v tag = (v,Nothing) + +-- add placeholder invalidated entry to cacheline, replace entry if was there +{-# NOINLINE tazcache_line_weak_inval_step #-} +tazcache_line_weak_inval_step :: + ( KnownNat m, KnownNat p_n, KnownNat q + , BitPack tag, p_n ~ BitSize tag, Eq tag + , BitPack addr, q ~ BitSize addr + ) + => CacheLine m tag addr + -> tag + -> (CacheLine m tag addr, Maybe(Index(2^m))) +tazcache_line_weak_inval_step v tag = (v,Nothing)