Skip to content

Commit

Permalink
Refactor DEC transformation (#2668)
Browse files Browse the repository at this point in the history
The previous code was a big mess where we partioned arguments
into shared and non-shared and then filtered the case-tree
depending on whether they were part of the shared arguments
or not. But then with the normalisation of type arguments,
the second filter did not work properly. This then resulted in
shared arguments becoming part of the tuple in the alternatives
of the case-expression for the non-shared arguments.

The new code is also more robust in the sense that shared and
non-shared arguments no longer need to be partioned (shared
occur left-most, non-shared occur right-most). They can now
be interleaved. The old code would also generate bad Core if
ever type and term arguments occured interleaved, this is no
longer the case for the new code.

Fixes #2628
  • Loading branch information
christiaanb authored Aug 28, 2024
1 parent f946617 commit 5927123
Show file tree
Hide file tree
Showing 5 changed files with 297 additions and 109 deletions.
1 change: 1 addition & 0 deletions changelog/2024-02-23T16_56_56+01_00_fix2628
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
FIXED: Clash no longer errors out in the netlist generation stage when a polymorphic function is applied to type X in one alternative of a case-statement and applied to a newtype wrapper of type X in a different alternative. See [#2828](https://github.com/clash-lang/clash-compiler/issues/2628)
234 changes: 125 additions & 109 deletions clash-lib/src/Clash/Normalize/Transformations/DEC.hs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ module Clash.Normalize.Transformations.DEC
) where

import Control.Concurrent.Supply (splitSupply)
#if !MIN_VERSION_base(4,18,0)
import Control.Applicative (liftA2)
#endif
import Control.Lens ((^.), _1)
import qualified Control.Lens as Lens
import qualified Control.Monad as Monad
Expand All @@ -56,6 +59,7 @@ import qualified Data.Map.Strict as Map
import qualified Data.Maybe as Maybe
import Data.Monoid (All(..))
import qualified Data.Text as Text
import Data.Text.Extra (showt)
import GHC.Stack (HasCallStack)
import qualified Language.Haskell.TH as TH

Expand All @@ -72,21 +76,22 @@ import Constants (mAX_TUPLE_SIZE)
#endif

-- internal
import Clash.Core.DataCon (DataCon)
import Clash.Core.DataCon (DataCon)
import Clash.Core.Evaluator.Types (whnf')
import Clash.Core.FreeVars
(termFreeVars', typeFreeVars', localVarsDoNotOccurIn)
import Clash.Core.HasType
import Clash.Core.Literal (Literal(..))
import Clash.Core.Name (nameOcc)
import Clash.Core.Name (OccName, nameOcc)
import Clash.Core.Pretty (showPpr)
import Clash.Core.Term
( Alt, LetBinding, Pat(..), PrimInfo(..), Term(..), TickInfo(..)
, collectArgs, collectArgsTicks, mkApps, mkTicks, patIds, stripTicks)
import Clash.Core.TyCon (TyConMap, TyConName, tyConDataCons)
import Clash.Core.Type
(Type, TypeView (..), isPolyFunTy, mkTyConApp, splitFunForallTy, tyView)
import Clash.Core.Util (mkInternalVar, mkSelectorCase, sccLetBindings)
import Clash.Core.Var (isGlobalId, isLocalId, varName)
import Clash.Core.Var (Id, isGlobalId, isLocalId, varName)
import Clash.Core.VarEnv
( InScopeSet, elemInScopeSet, extendInScopeSet, extendInScopeSetList
, notElemInScopeSet, unionInScope)
Expand Down Expand Up @@ -138,6 +143,24 @@ import qualified GHC.Prim
-- B -> f_out
-- C -> h x
-- @
--
-- Though that's a lie. It actually converts it into:
--
-- @
-- let f_tupIn = case x of {A -> (3,y); B -> (x,x)}
-- f_arg0 = case f_tupIn of (l,_) -> l
-- f_arg1 = case f_tupIn of (_,r) -> r
-- f_out = f f_arg0 f_arg1
-- in case x of
-- A -> f_out
-- B -> f_out
-- C -> h x
-- @
--
-- In order to share the expression that's in the subject of the case expression,
-- and to share the /decoder/ circuit that logic synthesis will create to map the
-- bits of the subject expression to the bits needed to make the selection in the
-- multiplexer.
disjointExpressionConsolidation :: HasCallStack => NormRewrite
disjointExpressionConsolidation ctx@(TransformContext isCtx _) e@(Case _scrut _ty _alts@(_:_:_)) = do
-- Collect all (the applications of) global binders (and certain primitives)
Expand All @@ -150,11 +173,13 @@ disjointExpressionConsolidation ctx@(TransformContext isCtx _) e@(Case _scrut _t
else do
-- For every to-lift expression create (the generalization of):
--
-- let fargs = case x of {A -> (3,y); B -> (x,x)}
-- in f (fst fargs) (snd fargs)
-- let f_tupIn = case x of {A -> (3,y); B -> (x,x)}
-- f_arg0 = case f_tupIn of (l,_) -> l
-- f_arg1 = case f_tupIn of (_,r) -> r
-- in f f_arg0 f_arg0
--
-- the let-expression is not created when `f` has only one (selectable)
-- argument
-- if an argument is non-representable, the case-expression is inlined,
-- and no let-binding will be created for it.
--
-- NB: mkDisJointGroup needs the context InScopeSet, isCtx, to determine
-- whether expressions reference variables from the context, or
Expand Down Expand Up @@ -190,11 +215,8 @@ disjointExpressionConsolidation ctx@(TransformContext isCtx _) e@(Case _scrut _t
-- Make the let-binder for the lifted expressions
mkFunOut tcm isN ((fun,_),(eLifted,_)) = do
let ty = inferCoreTypeOf tcm eLifted
nm = case collectArgs fun of
(Var v,_) -> nameOcc (varName v)
(Prim p,_) -> primName p
_ -> "complex_expression_"
nm1 = last (Text.splitOn "." nm) `Text.append` "Out"
nm = decFunName fun
nm1 = nm `Text.append` "_out"
nm2 <- mkInternalVar isN nm1 ty
return (extendInScopeSet isN nm2,nm2)

Expand Down Expand Up @@ -249,12 +271,26 @@ disjointExpressionConsolidation ctx@(TransformContext isCtx _) e@(Case _scrut _t
disjointExpressionConsolidation _ e = return e
{-# SCC disjointExpressionConsolidation #-}

decFunName :: Term -> OccName
decFunName fun = last . Text.splitOn "." $ case collectArgs fun of
(Var v, _) -> nameOcc (varName v)
(Prim p, _) -> primName p
_ -> "complex_expression"

data CaseTree a
= Leaf a
| LB [LetBinding] (CaseTree a)
| Branch Term [(Pat,CaseTree a)]
deriving (Eq,Show,Functor,Foldable)

instance Applicative CaseTree where
pure = Leaf
liftA2 f (Leaf a) (Leaf b) = Leaf (f a b)
liftA2 f (LB lb c1) (LB _ c2) = LB lb (liftA2 f c1 c2)
liftA2 f (Branch scrut alts1) (Branch _ alts2) =
Branch scrut (zipWith (\(p1,a1) (_,a2) -> (p1,liftA2 f a1 a2)) alts1 alts2)
liftA2 _ _ _ = error "CaseTree.liftA2: internal error, this should not happen."

-- | Test if a 'CaseTree' collected from an expression indicates that
-- application of a global binder is disjoint: occur in separate branches of a
-- case-expression.
Expand All @@ -269,18 +305,6 @@ isDisjoint ct = go ct
go (Branch _ [(_,x)]) = go x
go b@(Branch _ (_:_:_)) = allEqual (map Either.rights (Foldable.toList b))

-- Remove empty branches from a 'CaseTree'
removeEmpty :: Eq a => CaseTree [a] -> CaseTree [a]
removeEmpty l@(Leaf _) = l
removeEmpty (LB lb ct) =
case removeEmpty ct of
Leaf [] -> Leaf []
ct' -> LB lb ct'
removeEmpty (Branch s bs) =
case filter ((/= (Leaf [])) . snd) (map (second removeEmpty) bs) of
[] -> Leaf []
bs' -> Branch s bs'

-- | Test if all elements in a list are equal to each other.
allEqual :: Eq a => [a] -> Bool
allEqual [] = True
Expand Down Expand Up @@ -464,90 +488,94 @@ collectGlobalsLbs is0 substitution seen lbs = do
-- function-position\", return a let-expression: where the let-binding holds
-- a case-expression selecting between the distinct arguments of the case-tree,
-- and the body is an application of the term applied to the shared arguments of
-- the case tree, and projections of let-binding corresponding to the distinct
-- argument positions.
-- the case tree, and variable references to the created let-bindings.
--
-- case-expressions whose type would be non-representable are not let-bound,
-- but occur directly in the argument position of the application in the body
-- of the let-expression.
mkDisjointGroup
:: InScopeSet
-- ^ Variables in scope at the very top of the case-tree, i.e., the original
-- expression
-> (Term,([Term],CaseTree [(Either Term Type)]))
-> (Term,([Term],CaseTree [Either Term Type]))
-- ^ Case-tree of arguments belonging to the applied term.
-> NormalizeSession (Term,[Term])
mkDisjointGroup inScope (fun,(seen,cs)) = do
tcm <- Lens.view tcCache
let argss = Foldable.toList cs
argssT = zip [0..] (List.transpose argss)
(sharedT,distinctT) = List.partition (areShared tcm inScope . fmap (first stripTicks) . snd) argssT
-- TODO: find a better solution than "maybe undefined fst . uncons"
shared = map (second (maybe (error "impossible") fst . List.uncons)) sharedT
distinct = map (Either.lefts) (List.transpose (map snd distinctT))
cs' = fmap (zip [0..]) cs
cs'' = removeEmpty
$ fmap (Either.lefts . map snd)
(if null shared
then cs'
else fmap (filter (`notElem` shared)) cs')
(distinctCaseM,distinctProjections) <- case distinct of
-- only shared arguments: do nothing.
[] -> return (Nothing,[])
-- Create selectors and projections
(uc:_) -> do
let argTys = map (inferCoreTypeOf tcm) uc
disJointSelProj inScope argTys cs''
let newArgs = mkDJArgs 0 shared distinctProjections
case distinctCaseM of
Just lb -> return (Letrec [lb] (mkApps fun newArgs), seen)
Nothing -> return (mkApps fun newArgs, seen)

-- | Create a single selector for all the representable distinct arguments by
-- selecting between tuples. This selector is only ('Just') created when the
-- number of representable uncommmon arguments is larger than one, otherwise it
-- is not ('Nothing').
--
-- It also returns:
--
-- * For all the non-representable distinct arguments: a selector
-- * For all the representable distinct arguments: a projection out of the tuple
-- created by the larger selector. If this larger selector does not exist, a
-- single selector is created for the single representable distinct argument.
let funName = decFunName fun
argLen = case Foldable.toList cs of
[] -> error "mkDisjointGroup: no disjoint groups"
l:_ -> length l
csT :: [CaseTree (Either Term Type)] -- "Transposed" 'CaseTree [Either Term Type]'
csT = map (\i -> fmap (!!i) cs) [0..(argLen-1)] -- sequenceA does the wrong thing
(lbs,newArgs) <- List.mapAccumRM (\lbs (c,pos) -> do
let cL = Foldable.toList c
case (cL, areShared tcm inScope (fmap (first stripTicks) cL)) of
(Right ty:_, True) ->
return (lbs,Right ty)
(Right _:_, False) ->
error ("mkDisjointGroup: non-equal type arguments: " <>
showPpr (Either.rights cL))
(Left tm:_, True) ->
return (lbs,Left tm)
(Left tm:_, False) -> do
let ty = inferCoreTypeOf tcm tm
let err = error ("mkDisjointGroup: mixed type and term arguments: " <> show cL)
(lbM,arg) <- disJointSelProj inScope ty (Either.fromLeft err <$> c) funName pos
case lbM of
Just lb -> return (lb:lbs, Left arg)
_ -> return (lbs, Left arg)
([], _) ->
error "mkDisjointGroup: no arguments"
) [] (zip csT [0..])
let funApp = mkApps fun newArgs
tupTcm <- Lens.view tupleTcCache
case lbs of
[] ->
return (funApp, seen)
[(v,(ty,ct))] -> do
let e = genCase tcm tupTcm ty [ty] (fmap (:[]) ct)
return (Letrec [(v,e)] funApp, seen)
_ -> do
let (vs,zs) = unzip lbs
csL :: [CaseTree Term]
(tys,csL) = unzip zs
csLT :: CaseTree [Term]
csLT = fmap ($ []) (foldr1 (liftA2 (.)) (fmap (fmap (:)) csL))
bigTupTy = mkBigTupTy tcm tupTcm tys
ct = genCase tcm tupTcm bigTupTy tys csLT
tupIn <- mkInternalVar inScope (funName <> "_tupIn") bigTupTy
projections <-
Monad.zipWithM (\v n ->
(v,) <$> mkBigTupSelector inScope tcm tupTcm (Var tupIn) tys n)
vs [0..]
return (Letrec ((tupIn,ct):projections) funApp, seen)

-- | Create a selector for the case-tree of the argument. If the argument is
-- representable create a let-binding for the created selector, and return
-- a variable reference to this let-binding. If the argument is not representable
-- return the selector directly.
disJointSelProj
:: InScopeSet
-> [Type]
-- ^ Types of the arguments
-> CaseTree [Term]
-- The case-tree of arguments
-> NormalizeSession (Maybe LetBinding,[Term])
disJointSelProj _ _ (Leaf []) = return (Nothing,[])
disJointSelProj inScope argTys cs = do
tcm <- Lens.view tcCache
-> Type
-- ^ Types of the argument
-> CaseTree Term
-- ^ The case-tree of argument
-> OccName
-- ^ Name of the lifted function
-> Word
-- ^ Position of the argument
-> NormalizeSession (Maybe (Id, (Type, CaseTree Term)),Term)
disJointSelProj inScope argTy cs funName argN = do
tcm <- Lens.view tcCache
tupTcm <- Lens.view tupleTcCache
let maxIndex = length argTys - 1
css = map (\i -> fmap ((:[]) . (!!i)) cs) [0..maxIndex]
(untran,tran) <- List.partitionM (isUntranslatableType False . snd) (zip [0..] argTys)
let untranCs = map (css!!) (map fst untran)
untranSels = zipWith (\(_,ty) cs' -> genCase tcm tupTcm ty [ty] cs')
untran untranCs
(lbM,projs) <- case tran of
[] -> return (Nothing,[])
[(i,ty)] -> return (Nothing,[genCase tcm tupTcm ty [ty] (css!!i)])
tys -> do
let m = length tys
(tyIxs,tys') = unzip tys
tupTy = mkBigTupTy tcm tupTcm tys'
cs' = fmap (\es -> map (es !!) tyIxs) cs
djCase = genCase tcm tupTcm tupTy tys' cs'
scrutId <- mkInternalVar inScope "tupIn" tupTy
projections <- mapM (mkBigTupSelector inScope tcm tupTcm (Var scrutId) tys') [0..m-1]
return (Just (scrutId,djCase),projections)
let selProjs = tranOrUnTran 0 (zip (map fst untran) untranSels) projs

return (lbM,selProjs)
where
tranOrUnTran _ [] projs = projs
tranOrUnTran _ sels [] = map snd sels
tranOrUnTran n ((ut,s):uts) (p:projs)
| n == ut = s : tranOrUnTran (n+1) uts (p:projs)
| otherwise = p : tranOrUnTran (n+1) ((ut,s):uts) projs
let sel = genCase tcm tupTcm argTy [argTy] (fmap (:[]) cs)
untran <- isUntranslatableType False argTy
case untran of
True -> return (Nothing, sel)
False -> do
argId <- mkInternalVar inScope (funName <> "_arg" <> showt argN) argTy
return (Just (argId,(argTy,cs)), Var argId)

-- | Arguments are shared between invocations if:
--
Expand Down Expand Up @@ -579,18 +607,6 @@ areShared tcm inScope xs@(x:_) = noFV1 && (isProof x || allEqual xs)
_ -> False
isProof _ = False

-- | Create a list of arguments given a map of positions to common arguments,
-- and a list of arguments
mkDJArgs :: Int -- ^ Current position
-> [(Int,Either Term Type)] -- ^ map from position to common argument
-> [Term] -- ^ (projections for) distinct arguments
-> [Either Term Type]
mkDJArgs _ cms [] = map snd cms
mkDJArgs _ [] uncms = map Left uncms
mkDJArgs n ((m,x):cms) (y:uncms)
| n == m = x : mkDJArgs (n+1) cms (y:uncms)
| otherwise = Left y : mkDJArgs (n+1) ((m,x):cms) uncms

-- | Create a case-expression that selects between the distinct arguments given
-- a case-tree
genCase :: TyConMap
Expand Down
14 changes: 14 additions & 0 deletions clash-lib/src/Data/List/Extra.hs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
module Data.List.Extra
( partitionM
, mapAccumLM
, mapAccumRM
, iterateNM
, (<:>)
, indexMaybe
Expand Down Expand Up @@ -46,6 +47,19 @@ mapAccumLM f acc (x:xs) = do
(acc'',ys) <- mapAccumLM f acc' xs
return (acc'',y:ys)

-- | Monadic version of 'Data.List.mapAccumR'
mapAccumRM
:: Monad m
=> (acc -> x -> m (acc,y))
-> acc
-> [x]
-> m (acc,[y])
mapAccumRM _ acc [] = return (acc,[])
mapAccumRM f acc (x:xs) = do
(acc1,ys) <- mapAccumRM f acc xs
(acc2,y) <- f acc1 x
return (acc2,y:ys)

-- | Monadic version of 'iterate'. A carbon copy ('iterateM') would not
-- terminate, hence the first argument.
iterateNM
Expand Down
1 change: 1 addition & 0 deletions tests/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -633,6 +633,7 @@ runClashTest = defaultMain $ clashTestRoot
, runTest "T2593" def{hdlSim=[]}
, runTest "T2623CaseConFVs" def{hdlLoad=[],hdlSim=[],hdlTargets=[VHDL]}
, runTest "T2781" def{hdlLoad=[],hdlSim=[],hdlTargets=[VHDL]}
, runTest "T2628" def{hdlTargets=[VHDL], buildTargets=BuildSpecific ["TACacheServerStep"], hdlSim=[]}
] <>
if compiledWith == Cabal then
-- This tests fails without environment files present, which are only
Expand Down
Loading

0 comments on commit 5927123

Please sign in to comment.