diff --git a/changelog/2024-02-23T16_56_56+01_00_fix2628 b/changelog/2024-02-23T16_56_56+01_00_fix2628 new file mode 100644 index 0000000000..7a13e33627 --- /dev/null +++ b/changelog/2024-02-23T16_56_56+01_00_fix2628 @@ -0,0 +1 @@ +FIXED: Clash no longer errors out in the netlist generation stage when a polymorphic function is applied to type X in one alternative of a case-statement and applied to a newtype wrapper of type X in a different alternative. See [#2828](https://github.com/clash-lang/clash-compiler/issues/2628) diff --git a/clash-lib/src/Clash/Normalize/Transformations/DEC.hs b/clash-lib/src/Clash/Normalize/Transformations/DEC.hs index aaa0c30cb3..a19ebce99d 100644 --- a/clash-lib/src/Clash/Normalize/Transformations/DEC.hs +++ b/clash-lib/src/Clash/Normalize/Transformations/DEC.hs @@ -38,6 +38,9 @@ module Clash.Normalize.Transformations.DEC ) where import Control.Concurrent.Supply (splitSupply) +#if !MIN_VERSION_base(4,18,0) +import Control.Applicative (liftA2) +#endif import Control.Lens ((^.), _1) import qualified Control.Lens as Lens import qualified Control.Monad as Monad @@ -56,6 +59,7 @@ import qualified Data.Map.Strict as Map import qualified Data.Maybe as Maybe import Data.Monoid (All(..)) import qualified Data.Text as Text +import Data.Text.Extra (showt) import GHC.Stack (HasCallStack) import qualified Language.Haskell.TH as TH @@ -72,13 +76,14 @@ import Constants (mAX_TUPLE_SIZE) #endif -- internal -import Clash.Core.DataCon (DataCon) +import Clash.Core.DataCon (DataCon) import Clash.Core.Evaluator.Types (whnf') import Clash.Core.FreeVars (termFreeVars', typeFreeVars', localVarsDoNotOccurIn) import Clash.Core.HasType import Clash.Core.Literal (Literal(..)) -import Clash.Core.Name (nameOcc) +import Clash.Core.Name (OccName, nameOcc) +import Clash.Core.Pretty (showPpr) import Clash.Core.Term ( Alt, LetBinding, Pat(..), PrimInfo(..), Term(..), TickInfo(..) , collectArgs, collectArgsTicks, mkApps, mkTicks, patIds, stripTicks) @@ -86,7 +91,7 @@ import Clash.Core.TyCon (TyConMap, TyConName, tyConDataCons) import Clash.Core.Type (Type, TypeView (..), isPolyFunTy, mkTyConApp, splitFunForallTy, tyView) import Clash.Core.Util (mkInternalVar, mkSelectorCase, sccLetBindings) -import Clash.Core.Var (isGlobalId, isLocalId, varName) +import Clash.Core.Var (Id, isGlobalId, isLocalId, varName) import Clash.Core.VarEnv ( InScopeSet, elemInScopeSet, extendInScopeSet, extendInScopeSetList , notElemInScopeSet, unionInScope) @@ -138,6 +143,24 @@ import qualified GHC.Prim -- B -> f_out -- C -> h x -- @ +-- +-- Though that's a lie. It actually converts it into: +-- +-- @ +-- let f_tupIn = case x of {A -> (3,y); B -> (x,x)} +-- f_arg0 = case f_tupIn of (l,_) -> l +-- f_arg1 = case f_tupIn of (_,r) -> r +-- f_out = f f_arg0 f_arg1 +-- in case x of +-- A -> f_out +-- B -> f_out +-- C -> h x +-- @ +-- +-- In order to share the expression that's in the subject of the case expression, +-- and to share the /decoder/ circuit that logic synthesis will create to map the +-- bits of the subject expression to the bits needed to make the selection in the +-- multiplexer. disjointExpressionConsolidation :: HasCallStack => NormRewrite disjointExpressionConsolidation ctx@(TransformContext isCtx _) e@(Case _scrut _ty _alts@(_:_:_)) = do -- Collect all (the applications of) global binders (and certain primitives) @@ -150,11 +173,13 @@ disjointExpressionConsolidation ctx@(TransformContext isCtx _) e@(Case _scrut _t else do -- For every to-lift expression create (the generalization of): -- - -- let fargs = case x of {A -> (3,y); B -> (x,x)} - -- in f (fst fargs) (snd fargs) + -- let f_tupIn = case x of {A -> (3,y); B -> (x,x)} + -- f_arg0 = case f_tupIn of (l,_) -> l + -- f_arg1 = case f_tupIn of (_,r) -> r + -- in f f_arg0 f_arg0 -- - -- the let-expression is not created when `f` has only one (selectable) - -- argument + -- if an argument is non-representable, the case-expression is inlined, + -- and no let-binding will be created for it. -- -- NB: mkDisJointGroup needs the context InScopeSet, isCtx, to determine -- whether expressions reference variables from the context, or @@ -190,11 +215,8 @@ disjointExpressionConsolidation ctx@(TransformContext isCtx _) e@(Case _scrut _t -- Make the let-binder for the lifted expressions mkFunOut tcm isN ((fun,_),(eLifted,_)) = do let ty = inferCoreTypeOf tcm eLifted - nm = case collectArgs fun of - (Var v,_) -> nameOcc (varName v) - (Prim p,_) -> primName p - _ -> "complex_expression_" - nm1 = last (Text.splitOn "." nm) `Text.append` "Out" + nm = decFunName fun + nm1 = nm `Text.append` "_out" nm2 <- mkInternalVar isN nm1 ty return (extendInScopeSet isN nm2,nm2) @@ -249,12 +271,26 @@ disjointExpressionConsolidation ctx@(TransformContext isCtx _) e@(Case _scrut _t disjointExpressionConsolidation _ e = return e {-# SCC disjointExpressionConsolidation #-} +decFunName :: Term -> OccName +decFunName fun = last . Text.splitOn "." $ case collectArgs fun of + (Var v, _) -> nameOcc (varName v) + (Prim p, _) -> primName p + _ -> "complex_expression" + data CaseTree a = Leaf a | LB [LetBinding] (CaseTree a) | Branch Term [(Pat,CaseTree a)] deriving (Eq,Show,Functor,Foldable) +instance Applicative CaseTree where + pure = Leaf + liftA2 f (Leaf a) (Leaf b) = Leaf (f a b) + liftA2 f (LB lb c1) (LB _ c2) = LB lb (liftA2 f c1 c2) + liftA2 f (Branch scrut alts1) (Branch _ alts2) = + Branch scrut (zipWith (\(p1,a1) (_,a2) -> (p1,liftA2 f a1 a2)) alts1 alts2) + liftA2 _ _ _ = error "CaseTree.liftA2: internal error, this should not happen." + -- | Test if a 'CaseTree' collected from an expression indicates that -- application of a global binder is disjoint: occur in separate branches of a -- case-expression. @@ -269,18 +305,6 @@ isDisjoint ct = go ct go (Branch _ [(_,x)]) = go x go b@(Branch _ (_:_:_)) = allEqual (map Either.rights (Foldable.toList b)) --- Remove empty branches from a 'CaseTree' -removeEmpty :: Eq a => CaseTree [a] -> CaseTree [a] -removeEmpty l@(Leaf _) = l -removeEmpty (LB lb ct) = - case removeEmpty ct of - Leaf [] -> Leaf [] - ct' -> LB lb ct' -removeEmpty (Branch s bs) = - case filter ((/= (Leaf [])) . snd) (map (second removeEmpty) bs) of - [] -> Leaf [] - bs' -> Branch s bs' - -- | Test if all elements in a list are equal to each other. allEqual :: Eq a => [a] -> Bool allEqual [] = True @@ -464,90 +488,94 @@ collectGlobalsLbs is0 substitution seen lbs = do -- function-position\", return a let-expression: where the let-binding holds -- a case-expression selecting between the distinct arguments of the case-tree, -- and the body is an application of the term applied to the shared arguments of --- the case tree, and projections of let-binding corresponding to the distinct --- argument positions. +-- the case tree, and variable references to the created let-bindings. +-- +-- case-expressions whose type would be non-representable are not let-bound, +-- but occur directly in the argument position of the application in the body +-- of the let-expression. mkDisjointGroup :: InScopeSet -- ^ Variables in scope at the very top of the case-tree, i.e., the original -- expression - -> (Term,([Term],CaseTree [(Either Term Type)])) + -> (Term,([Term],CaseTree [Either Term Type])) -- ^ Case-tree of arguments belonging to the applied term. -> NormalizeSession (Term,[Term]) mkDisjointGroup inScope (fun,(seen,cs)) = do tcm <- Lens.view tcCache - let argss = Foldable.toList cs - argssT = zip [0..] (List.transpose argss) - (sharedT,distinctT) = List.partition (areShared tcm inScope . fmap (first stripTicks) . snd) argssT - -- TODO: find a better solution than "maybe undefined fst . uncons" - shared = map (second (maybe (error "impossible") fst . List.uncons)) sharedT - distinct = map (Either.lefts) (List.transpose (map snd distinctT)) - cs' = fmap (zip [0..]) cs - cs'' = removeEmpty - $ fmap (Either.lefts . map snd) - (if null shared - then cs' - else fmap (filter (`notElem` shared)) cs') - (distinctCaseM,distinctProjections) <- case distinct of - -- only shared arguments: do nothing. - [] -> return (Nothing,[]) - -- Create selectors and projections - (uc:_) -> do - let argTys = map (inferCoreTypeOf tcm) uc - disJointSelProj inScope argTys cs'' - let newArgs = mkDJArgs 0 shared distinctProjections - case distinctCaseM of - Just lb -> return (Letrec [lb] (mkApps fun newArgs), seen) - Nothing -> return (mkApps fun newArgs, seen) - --- | Create a single selector for all the representable distinct arguments by --- selecting between tuples. This selector is only ('Just') created when the --- number of representable uncommmon arguments is larger than one, otherwise it --- is not ('Nothing'). --- --- It also returns: --- --- * For all the non-representable distinct arguments: a selector --- * For all the representable distinct arguments: a projection out of the tuple --- created by the larger selector. If this larger selector does not exist, a --- single selector is created for the single representable distinct argument. + let funName = decFunName fun + argLen = case Foldable.toList cs of + [] -> error "mkDisjointGroup: no disjoint groups" + l:_ -> length l + csT :: [CaseTree (Either Term Type)] -- "Transposed" 'CaseTree [Either Term Type]' + csT = map (\i -> fmap (!!i) cs) [0..(argLen-1)] -- sequenceA does the wrong thing + (lbs,newArgs) <- List.mapAccumRM (\lbs (c,pos) -> do + let cL = Foldable.toList c + case (cL, areShared tcm inScope (fmap (first stripTicks) cL)) of + (Right ty:_, True) -> + return (lbs,Right ty) + (Right _:_, False) -> + error ("mkDisjointGroup: non-equal type arguments: " <> + showPpr (Either.rights cL)) + (Left tm:_, True) -> + return (lbs,Left tm) + (Left tm:_, False) -> do + let ty = inferCoreTypeOf tcm tm + let err = error ("mkDisjointGroup: mixed type and term arguments: " <> show cL) + (lbM,arg) <- disJointSelProj inScope ty (Either.fromLeft err <$> c) funName pos + case lbM of + Just lb -> return (lb:lbs, Left arg) + _ -> return (lbs, Left arg) + ([], _) -> + error "mkDisjointGroup: no arguments" + ) [] (zip csT [0..]) + let funApp = mkApps fun newArgs + tupTcm <- Lens.view tupleTcCache + case lbs of + [] -> + return (funApp, seen) + [(v,(ty,ct))] -> do + let e = genCase tcm tupTcm ty [ty] (fmap (:[]) ct) + return (Letrec [(v,e)] funApp, seen) + _ -> do + let (vs,zs) = unzip lbs + csL :: [CaseTree Term] + (tys,csL) = unzip zs + csLT :: CaseTree [Term] + csLT = fmap ($ []) (foldr1 (liftA2 (.)) (fmap (fmap (:)) csL)) + bigTupTy = mkBigTupTy tcm tupTcm tys + ct = genCase tcm tupTcm bigTupTy tys csLT + tupIn <- mkInternalVar inScope (funName <> "_tupIn") bigTupTy + projections <- + Monad.zipWithM (\v n -> + (v,) <$> mkBigTupSelector inScope tcm tupTcm (Var tupIn) tys n) + vs [0..] + return (Letrec ((tupIn,ct):projections) funApp, seen) + +-- | Create a selector for the case-tree of the argument. If the argument is +-- representable create a let-binding for the created selector, and return +-- a variable reference to this let-binding. If the argument is not representable +-- return the selector directly. disJointSelProj :: InScopeSet - -> [Type] - -- ^ Types of the arguments - -> CaseTree [Term] - -- The case-tree of arguments - -> NormalizeSession (Maybe LetBinding,[Term]) -disJointSelProj _ _ (Leaf []) = return (Nothing,[]) -disJointSelProj inScope argTys cs = do - tcm <- Lens.view tcCache + -> Type + -- ^ Types of the argument + -> CaseTree Term + -- ^ The case-tree of argument + -> OccName + -- ^ Name of the lifted function + -> Word + -- ^ Position of the argument + -> NormalizeSession (Maybe (Id, (Type, CaseTree Term)),Term) +disJointSelProj inScope argTy cs funName argN = do + tcm <- Lens.view tcCache tupTcm <- Lens.view tupleTcCache - let maxIndex = length argTys - 1 - css = map (\i -> fmap ((:[]) . (!!i)) cs) [0..maxIndex] - (untran,tran) <- List.partitionM (isUntranslatableType False . snd) (zip [0..] argTys) - let untranCs = map (css!!) (map fst untran) - untranSels = zipWith (\(_,ty) cs' -> genCase tcm tupTcm ty [ty] cs') - untran untranCs - (lbM,projs) <- case tran of - [] -> return (Nothing,[]) - [(i,ty)] -> return (Nothing,[genCase tcm tupTcm ty [ty] (css!!i)]) - tys -> do - let m = length tys - (tyIxs,tys') = unzip tys - tupTy = mkBigTupTy tcm tupTcm tys' - cs' = fmap (\es -> map (es !!) tyIxs) cs - djCase = genCase tcm tupTcm tupTy tys' cs' - scrutId <- mkInternalVar inScope "tupIn" tupTy - projections <- mapM (mkBigTupSelector inScope tcm tupTcm (Var scrutId) tys') [0..m-1] - return (Just (scrutId,djCase),projections) - let selProjs = tranOrUnTran 0 (zip (map fst untran) untranSels) projs - - return (lbM,selProjs) - where - tranOrUnTran _ [] projs = projs - tranOrUnTran _ sels [] = map snd sels - tranOrUnTran n ((ut,s):uts) (p:projs) - | n == ut = s : tranOrUnTran (n+1) uts (p:projs) - | otherwise = p : tranOrUnTran (n+1) ((ut,s):uts) projs + let sel = genCase tcm tupTcm argTy [argTy] (fmap (:[]) cs) + untran <- isUntranslatableType False argTy + case untran of + True -> return (Nothing, sel) + False -> do + argId <- mkInternalVar inScope (funName <> "_arg" <> showt argN) argTy + return (Just (argId,(argTy,cs)), Var argId) -- | Arguments are shared between invocations if: -- @@ -579,18 +607,6 @@ areShared tcm inScope xs@(x:_) = noFV1 && (isProof x || allEqual xs) _ -> False isProof _ = False --- | Create a list of arguments given a map of positions to common arguments, --- and a list of arguments -mkDJArgs :: Int -- ^ Current position - -> [(Int,Either Term Type)] -- ^ map from position to common argument - -> [Term] -- ^ (projections for) distinct arguments - -> [Either Term Type] -mkDJArgs _ cms [] = map snd cms -mkDJArgs _ [] uncms = map Left uncms -mkDJArgs n ((m,x):cms) (y:uncms) - | n == m = x : mkDJArgs (n+1) cms (y:uncms) - | otherwise = Left y : mkDJArgs (n+1) ((m,x):cms) uncms - -- | Create a case-expression that selects between the distinct arguments given -- a case-tree genCase :: TyConMap diff --git a/clash-lib/src/Data/List/Extra.hs b/clash-lib/src/Data/List/Extra.hs index 011047005f..ca731578fd 100644 --- a/clash-lib/src/Data/List/Extra.hs +++ b/clash-lib/src/Data/List/Extra.hs @@ -4,6 +4,7 @@ module Data.List.Extra ( partitionM , mapAccumLM + , mapAccumRM , iterateNM , (<:>) , indexMaybe @@ -46,6 +47,19 @@ mapAccumLM f acc (x:xs) = do (acc'',ys) <- mapAccumLM f acc' xs return (acc'',y:ys) +-- | Monadic version of 'Data.List.mapAccumR' +mapAccumRM + :: Monad m + => (acc -> x -> m (acc,y)) + -> acc + -> [x] + -> m (acc,[y]) +mapAccumRM _ acc [] = return (acc,[]) +mapAccumRM f acc (x:xs) = do + (acc1,ys) <- mapAccumRM f acc xs + (acc2,y) <- f acc1 x + return (acc2,y:ys) + -- | Monadic version of 'iterate'. A carbon copy ('iterateM') would not -- terminate, hence the first argument. iterateNM diff --git a/tests/Main.hs b/tests/Main.hs index 85db66e5f4..03c5e15186 100755 --- a/tests/Main.hs +++ b/tests/Main.hs @@ -633,6 +633,7 @@ runClashTest = defaultMain $ clashTestRoot , runTest "T2593" def{hdlSim=[]} , runTest "T2623CaseConFVs" def{hdlLoad=[],hdlSim=[],hdlTargets=[VHDL]} , runTest "T2781" def{hdlLoad=[],hdlSim=[],hdlTargets=[VHDL]} + , runTest "T2628" def{hdlTargets=[VHDL], buildTargets=BuildSpecific ["TACacheServerStep"], hdlSim=[]} ] <> if compiledWith == Cabal then -- This tests fails without environment files present, which are only diff --git a/tests/shouldwork/Issues/T2628.hs b/tests/shouldwork/Issues/T2628.hs new file mode 100644 index 0000000000..2d6a001bb5 --- /dev/null +++ b/tests/shouldwork/Issues/T2628.hs @@ -0,0 +1,156 @@ +module T2628 where + +import Clash.Prelude + +-- idx cacheline entries are Just(tag,Just addr) to translate idx++tag->addr +-- and Just(tag,Nothing) for invalidated idx++tag entry +-- and Nothing for no entry there +type CacheLine m tag addr -- 2^m tags per line, 2^n lines + = Vec (2^m) (Maybe(tag,Maybe addr)) + +{-# ANN tacache_server_step32 + (Synthesize { t_name = "TACacheServerStep" + , t_inputs = [ PortName "dx" -- user B + , PortName "d_x" -- tlb C + , PortName "dw" -- tlb D + , PortName "out2" -- cache B + , PortName "out3" -- cache C + ] + , t_output = PortProduct "" + [ PortName "win1" -- cache A1 + , PortName "win2" -- cache A2 + ] + }) #-} + +{-# NOINLINE tacache_server_step32 #-} +tacache_server_step32 = tacache_server_step' + where + tacache_server_step' + :: forall (m::Nat) (n::Nat) (p::Nat) (q::Nat) + cxdr addr idx tag cacheline + . ( KnownNat q, KnownNat n, KnownNat m, KnownNat p + , n <= p + , cxdr ~ Signed p + , addr ~ Signed q + , idx ~ Signed n + , tag ~ Signed (p-n) + , cacheline ~ CacheLine m tag addr + , p ~ 132 + , q ~ 32 + , n ~ 6 + , m ~ 0 + ) + -- SNat n -- 2^n lines + -- SNat m -- of 2^m entries each + => ( Maybe cxdr -- input frnt invalidate addr req to server + , Maybe cxdr -- input back/weak invalidate req to server + , Maybe (cxdr,addr) -- input back/weak write req to server + , Maybe (idx,cacheline) + , Maybe (idx,cacheline) + ) + -> ( Maybe(idx,cacheline) + , Maybe(idx,cacheline) + ) + tacache_server_step' = tacache_server_step (SNat::SNat n) (SNat::SNat m) + +tacache_server_step + :: forall (m::Nat) (n::Nat) (p::Nat) (q::Nat) + cxdr addr idx tag cacheline + . ( KnownNat q, KnownNat n, KnownNat m, KnownNat p + , n <= p + , cxdr ~ Signed p + , addr ~ Signed q + , idx ~ Signed n + , tag ~ Signed (p-n) + , cacheline ~ CacheLine m tag addr +-- , p ~ 132 +-- , q ~ 32 + ) + => SNat n -- 2^n lines + -> SNat m -- of 2^m entries each + -> ( Maybe cxdr -- input frnt invalidate addr req to server + , Maybe cxdr -- input back/weak invalidate req to server + , Maybe (cxdr,addr) -- input back/weak write req to server + , Maybe (idx,cacheline) + , Maybe (idx,cacheline) + ) + -> ( Maybe(idx,cacheline) + , Maybe(idx,cacheline) + ) +tacache_server_step n m (dx,d_x,dw,out1,out2) = (win1,win2) + + where + -- outs1 and outs2 are prev state + -- (may need to write two lines in one cycle) + win1,win2 :: Maybe(idx,CacheLine m tag addr) + (win1,win2) = + case (dx, d_x, dw, out1, out2) of + + -- !!! FIX for HDL from here on, replace (v,_) = with v = fst $ !!! -- + + (Just x1,Just x2,Nothing,Just (idx1,v1),Just (idx2,v2)) -> + let (idx2',tag2) = tacache_split_cxdr x2 + in + if 1 /= idx2' then + ( Just(1,v1) + , Just(idx2',v2) + ) + else + let (v1',_) = tazcache_line_inval_step v1 2 -- HERE + (v2',_) = tazcache_line_weak_inval_step v1' tag2 -- HERE + in ( Just(idx2',v2') + , Nothing + ) + + -- !!! FIX for HDL from here, as above, and make cases top level fns !!! --- + + (Nothing,Just x,Nothing,_,Just (idx,v)) -> + let (v',_) = tazcache_line_weak_inval_step v 4 -- HERE + in ( Nothing + , Just(3,v') + ) + + _ -> (Nothing,Nothing) + + -------------------- DUMMY NOINLINE support ----------------------- + +-- split incoming addr for translation into a cacheline index and tag +{-# NOINLINE tacache_split_cxdr #-} +tacache_split_cxdr + :: forall (n::Nat) (p::Nat) tag cxdr idx f + . ( KnownNat n, KnownNat p + , Resize f -- might as well be just Signed + , n <= p, (n + (p-n)) ~ p, ((p-n) + n) ~ p + , BitPack cxdr, p ~ BitSize cxdr, cxdr ~ f p + , BitPack idx, n ~ BitSize idx, idx ~ f n + , BitPack tag, (p-n) ~ BitSize tag, tag ~ f (p-n) + ) + => cxdr + -> (idx,tag) +tacache_split_cxdr x = (unpack 5, unpack 6) + + ------------------ DUMMY NOINLINE cacheline ops --------------------- + +-- remove element with matching tag from cacheline, report position +{-# NOINLINE tazcache_line_inval_step #-} +tazcache_line_inval_step :: + ( KnownNat m, KnownNat p_n, KnownNat q + , BitPack tag, p_n ~ BitSize tag, Eq tag + , BitPack addr, q ~ BitSize addr + ) + => CacheLine m tag addr + -> tag + -> (CacheLine m tag addr, Maybe(Index(2^m))) +tazcache_line_inval_step v tag = (v,Nothing) + +-- add placeholder invalidated entry to cacheline, replace entry if was there +{-# NOINLINE tazcache_line_weak_inval_step #-} +tazcache_line_weak_inval_step :: + ( KnownNat m, KnownNat p_n, KnownNat q + , BitPack tag, p_n ~ BitSize tag, Eq tag + , BitPack addr, q ~ BitSize addr + ) + => CacheLine m tag addr + -> tag + -> (CacheLine m tag addr, Maybe(Index(2^m))) +tazcache_line_weak_inval_step v tag = (v,Nothing)