Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make a BinderAndDecls data type, at first just a wrapper around Binder. #1323

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 32 additions & 34 deletions src/lib/Builder.hs
Original file line number Diff line number Diff line change
Expand Up @@ -649,7 +649,7 @@ buildAbs hint binding cont = do

typesFromNonDepBinderNest
:: (EnvReader m, Fallible1 m, IRRep r)
=> Nest (Binder r) n l -> m n [Type r n]
=> Binders r n l -> m n [Type r n]
typesFromNonDepBinderNest Empty = return []
typesFromNonDepBinderNest (Nest b rest) = do
Abs rest' UnitE <- return $ assumeConst $ Abs (UnaryNest b) $ Abs rest UnitE
Expand All @@ -662,7 +662,7 @@ buildUnaryLamExpr
-> (forall l. (Emits l, Distinct l, DExt n l) => AtomVar r l -> m l (Atom r l))
-> m n (LamExpr r n)
buildUnaryLamExpr hint ty cont = do
bs <- withFreshBinder hint ty \b -> return $ EmptyAbs (UnaryNest b)
bs <- withFreshBinder hint ty \b -> return $ EmptyAbs (UnaryNest (PlainBD b))
buildLamExpr bs \[v] -> cont v

buildBinaryLamExpr
Expand All @@ -672,21 +672,21 @@ buildBinaryLamExpr
-> m n (LamExpr r n)
buildBinaryLamExpr (h1,t1) (h2,t2) cont = do
bs <- withFreshBinder h1 t1 \b1 -> withFreshBinder h2 (sink t2) \b2 ->
return $ EmptyAbs $ BinaryNest b1 b2
return $ EmptyAbs $ BinaryNest (PlainBD b1) (PlainBD b2)
buildLamExpr bs \[v1, v2] -> cont v1 v2

buildLamExpr
:: ScopableBuilder r m
=> (EmptyAbs (Nest (Binder r)) n)
=> (Abs (Binders r) any n)
-> (forall l. (Emits l, Distinct l, DExt n l) => [AtomVar r l] -> m l (Atom r l))
-> m n (LamExpr r n)
buildLamExpr (Abs bs UnitE) cont = case bs of
buildLamExpr (Abs bs _) cont = case bs of
Empty -> LamExpr Empty <$> buildBlock (cont [])
Nest b rest -> do
Abs b' (LamExpr bs' body') <- buildAbs (getNameHint b) (binderType b) \v -> do
rest' <- applySubst (b@>SubstVal (Var v)) $ EmptyAbs rest
rest' <- instantiate (Abs (UnaryNest b) (EmptyAbs rest)) [Var v]
buildLamExpr rest' \vs -> cont $ sink v : vs
return $ LamExpr (Nest b' bs') body'
return $ LamExpr (Nest (PlainBD b') bs') body'

buildTopLamFromPi
:: ScopableBuilder r m
Expand Down Expand Up @@ -765,7 +765,7 @@ buildEffLam hint ty body = do
let ref = binderVar b
hVar <- sinkM $ binderVar h
body' <- buildBlock $ body (sink hVar) $ sink ref
return $ LamExpr (BinaryNest h b) body'
return $ LamExpr (BinaryNest (PlainBD h) (PlainBD b)) body'

buildForAnn
:: (Emits n, ScopableBuilder r m)
Expand All @@ -776,7 +776,7 @@ buildForAnn hint ann (IxType iTy ixDict) body = do
lam <- withFreshBinder hint iTy \b -> do
let v = binderVar b
body' <- buildBlock $ body $ sink v
return $ LamExpr (UnaryNest b) body'
return $ UnaryLamExpr b body'
emitHof $ For ann (IxType iTy ixDict) lam

buildFor :: (Emits n, ScopableBuilder r m)
Expand Down Expand Up @@ -862,7 +862,7 @@ zeroAt ty = liftEmitBuilder $ go ty where
BaseTy bt -> return $ Con $ Lit $ zeroLit bt
ProdTy tys -> ProdVal <$> mapM go tys
TabPi tabPi -> buildFor (getNameHint tabPi) Fwd (tabIxType tabPi) \i ->
go =<< instantiate (sink tabPi) [Var i]
go =<< instantiate tabPi [Var i]
_ -> unreachable
zeroLit bt = case bt of
Scalar Float64Type -> Float64Lit 0.0
Expand Down Expand Up @@ -1336,17 +1336,15 @@ runMaybeWhile body = do

type ReconAbs r e = Abs (ReconBinders r) e

data ReconBinders r n l = ReconBinders
(TelescopeType (AtomNameC r) (Type r) n)
(Nest (NameBinder (AtomNameC r)) n l)
data ReconBinders r n l = ReconBinders (TelescopeType r n) (Nest (NameBinder (AtomNameC r)) n l)

data TelescopeType c e n =
DepTelescope (TelescopeType c e n) (Abs (BinderP c e) (TelescopeType c e) n)
| ProdTelescope [e n]
data TelescopeType r n =
DepTelescope (TelescopeType r n) (Abs (BinderAndDecls r) (TelescopeType r) n)
| ProdTelescope [Type r n]

instance IRRep r => GenericB (ReconBinders r) where
type RepB (ReconBinders r) =
PairB (LiftB (TelescopeType (AtomNameC r) (Type r)))
PairB (LiftB (TelescopeType r))
(Nest (NameBinder (AtomNameC r)))
fromB (ReconBinders x y) = PairB (LiftB x) y
{-# INLINE fromB #-}
Expand All @@ -1365,21 +1363,21 @@ instance IRRep r => ProvesExt (ReconBinders r)
instance IRRep r => BindsNames (ReconBinders r)
instance IRRep r => HoistableB (ReconBinders r)

instance GenericE (TelescopeType c e) where
type RepE (TelescopeType c e) = EitherE
(PairE (TelescopeType c e) (Abs (BinderP c e) (TelescopeType c e)))
(ListE e)
instance GenericE (TelescopeType r) where
type RepE (TelescopeType r) = EitherE
(PairE (TelescopeType r) (Abs (BinderAndDecls r) (TelescopeType r)))
(ListE (Type r))
fromE (DepTelescope lhs ab) = LeftE (PairE lhs ab)
fromE (ProdTelescope tys) = RightE (ListE tys)
{-# INLINE fromE #-}
toE (LeftE (PairE lhs ab)) = DepTelescope lhs ab
toE (RightE (ListE tys)) = ProdTelescope tys
{-# INLINE toE #-}

instance (Color c, SinkableE e) => SinkableE (TelescopeType c e)
instance (Color c, SinkableE e, RenameE e) => RenameE (TelescopeType c e)
instance (Color c, ToBinding e c, SubstE AtomSubstVal e) => SubstE AtomSubstVal (TelescopeType c e)
instance (Color c, HoistableE e) => HoistableE (TelescopeType c e)
instance IRRep r => SinkableE (TelescopeType r)
instance IRRep r => RenameE (TelescopeType r)
instance IRRep r => SubstE AtomSubstVal (TelescopeType r)
instance IRRep r => HoistableE (TelescopeType r)

telescopicCapture
:: (EnvReader m, HoistableE e, HoistableB b, IRRep r)
Expand All @@ -1405,40 +1403,40 @@ applyReconAbs (Abs bs result) x = do
applySubst (bs @@> map SubstVal xs) result

buildTelescopeTy
:: (EnvReader m, EnvExtender m, Color c, HoistableE e)
=> [AnnVar c e n] -> m n (TelescopeType c e n)
:: (EnvReader m, EnvExtender m, IRRep r)
=> [AnnVar (AtomNameC r) (Type r) n] -> m n (TelescopeType r n)
buildTelescopeTy [] = return (ProdTelescope [])
buildTelescopeTy ((v,ty):xs) = do
rhs <- buildTelescopeTy xs
Abs b rhs' <- return $ abstractFreeVar v rhs
case hoist b rhs' of
HoistSuccess rhs'' -> return $ prependTelescopeTy ty rhs''
HoistFailure _ -> return $ DepTelescope (ProdTelescope []) (Abs (b:>ty) rhs')
HoistFailure _ -> return $ DepTelescope (ProdTelescope []) (Abs (BD (b:>ty)) rhs')

prependTelescopeTy :: e n -> TelescopeType c e n -> TelescopeType c e n
prependTelescopeTy :: Type r n -> TelescopeType r n -> TelescopeType r n
prependTelescopeTy x = \case
DepTelescope lhs rhs -> DepTelescope (prependTelescopeTy x lhs) rhs
ProdTelescope xs -> ProdTelescope (x:xs)

buildTelescopeVal
:: (EnvReader m, IRRep r) => [Atom r n]
-> TelescopeType (AtomNameC r) (Type r) n -> m n (Atom r n)
-> TelescopeType r n -> m n (Atom r n)
buildTelescopeVal xsTop tyTop = fst <$> go tyTop xsTop where
go :: (EnvReader m, IRRep r)
=> TelescopeType (AtomNameC r) (Type r) n -> [Atom r n]
=> TelescopeType r n -> [Atom r n]
-> m n (Atom r n, [Atom r n])
go ty rest = case ty of
ProdTelescope tys -> do
(xs, rest') <- return $ splitAt (length tys) rest
return (ProdVal xs, rest')
DepTelescope ty1 (Abs b ty2) -> do
(x1, ~(xDep : rest')) <- go ty1 rest
ty2' <- applySubst (b@>SubstVal xDep) ty2
ty2' <- instantiate (Abs b ty2) [xDep]
(x2, rest'') <- go ty2' rest'
let depPairTy = DepPairType ExplicitDepPair b (telescopeTypeType ty2)
return (PairVal x1 (DepPair xDep x2 depPairTy), rest'')

telescopeTypeType :: TelescopeType (AtomNameC r) (Type r) n -> Type r n
telescopeTypeType :: TelescopeType r n -> Type r n
telescopeTypeType (ProdTelescope tys) = ProdTy tys
telescopeTypeType (DepTelescope lhs (Abs b rhs)) = do
let lhs' = telescopeTypeType lhs
Expand All @@ -1450,7 +1448,7 @@ unpackTelescope
=> ReconBinders r l1 l2 -> Atom r n -> m n [Atom r n]
unpackTelescope (ReconBinders tyTop _) xTop = go tyTop xTop where
go :: (Fallible1 m, EnvReader m, IRRep r)
=> TelescopeType c e l-> Atom r n -> m n [Atom r n]
=> TelescopeType r l-> Atom r n -> m n [Atom r n]
go ty x = case ty of
ProdTelescope _ -> getUnpacked x
DepTelescope ty1 (Abs _ ty2) -> do
Expand Down
112 changes: 72 additions & 40 deletions src/lib/CheapReduction.hs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ module CheapReduction
, liftSimpFun, makeStructRepVal, NonAtomRenamer (..), Visitor (..), VisitGeneric (..)
, visitAtomPartial, visitTypePartial, visitAtomDefault, visitTypeDefault, Visitor2
, visitBinders, visitPiDefault, visitAlt, toAtomVar, instantiate, withInstantiated
, bindersToVars, bindersToAtoms, instantiateNames, withInstantiatedNames, assumeConst)
, instantiateNames, withInstantiatedNames, assumeConst, tryAsConst
, extendSubstBD, arity)
where

import Control.Applicative
Expand Down Expand Up @@ -218,12 +219,9 @@ instance IRRep r => CheaplyReducibleE r (Type r) (Type r) where
-- means that we will follow the full call chain, so it's really expensive!
-- TODO: we don't collect the dict holes here, so there's a danger of
-- dropping them if they turn out to be phantom.
TabPi (TabPiType d (b:>t) resultTy) -> do
t' <- cheapReduceE t
TabPi (TabPiType d b resultTy) -> do
d' <- cheapReduceE d
withFreshBinder (getNameHint b) t' \b' -> do
resultTy' <- extendSubst (b@>Rename (binderName b')) $ cheapReduceE resultTy
return $ TabPi $ TabPiType d' b' resultTy'
cheapReduceBinder b \b' -> TabPi <$> TabPiType d' b' <$> cheapReduceE resultTy
-- We traverse the Atom constructors that might contain lambda expressions
-- explicitly, to make sure that we can skip normalizing free vars inside those.
NewtypeTyCon (Fin n) -> NewtypeTyCon . Fin <$> cheapReduceE n
Expand All @@ -234,6 +232,16 @@ instance IRRep r => CheaplyReducibleE r (Type r) (Type r) where
a' <- substM a
dropSubst $ traverseNames cheapReduceName a'

cheapReduceBinder
:: IRRep r
=> BinderAndDecls r i i'
-> (forall o'. DExt o o' => BinderAndDecls r o o' -> CheapReducerM r i' o' a)
-> CheapReducerM r i o a
cheapReduceBinder (BD (b:>ty)) cont = do
ty' <- cheapReduceE ty
withFreshBinder (getNameHint b) ty' \b' -> do
extendSubst (b@>Rename (binderName b')) $ cont (BD b')

cheapReduceDictExpr :: CType o -> DictExpr i -> CheapReducerM CoreIR i o (CAtom o)
cheapReduceDictExpr resultTy d = case d of
SuperclassProj child superclassIx -> do
Expand Down Expand Up @@ -401,9 +409,9 @@ liftSimpAtom ty simpAtom = case simpAtom of
(BaseTy _ , Con (Lit v)) -> return $ Con $ Lit v
(ProdTy tys, Con (ProdCon xs)) -> Con . ProdCon <$> zipWithM rec tys xs
(SumTy tys, Con (SumCon _ i x)) -> Con . SumCon tys i <$> rec (tys!!i) x
(DepPairTy dpt@(DepPairType _ (b:>t1) t2), DepPair x1 x2 _) -> do
x1' <- rec t1 x1
t2' <- applySubst (b@>SubstVal x1') t2
(DepPairTy dpt, DepPair x1 x2 _) -> do
x1' <- rec (depPairLeftTy dpt) x1
t2' <- instantiate dpt [x1']
x2' <- rec t2' x2
return $ DepPair x1' x2' dpt
_ -> error $ "can't lift " <> pprint simpAtom <> " to " <> pprint ty'
Expand All @@ -426,8 +434,8 @@ confuseGHC = getDistinct
-- them. Maybe a common set of low-level type-querying utils that both
-- CheapReduction and QueryType import?

depPairLeftTy :: DepPairType r n -> Type r n
depPairLeftTy (DepPairType _ (_:>ty) _) = ty
depPairLeftTy :: IRRep r => DepPairType r n -> Type r n
depPairLeftTy (DepPairType _ b _) = binderType b
{-# INLINE depPairLeftTy #-}

unwrapNewtypeType :: EnvReader m => NewtypeTyCon n -> m n (NewtypeCon n, Type CoreIR n)
Expand Down Expand Up @@ -463,43 +471,75 @@ wrapNewtypesData [] x = x
wrapNewtypesData (c:cs) x = NewtypeCon c $ wrapNewtypesData cs x

instantiateTyConDef :: EnvReader m => TyConDef n -> TyConParams n -> m n (DataConDefs n)
instantiateTyConDef (TyConDef _ _ bs conDefs) (TyConParams _ xs) = do
applySubst (bs @@> (SubstVal <$> xs)) conDefs
instantiateTyConDef tyConDef (TyConParams _ xs) = instantiate tyConDef xs
{-# INLINE instantiateTyConDef #-}

assumeConst
:: (IRRep r, HoistableE body, SinkableE body, ToBindersAbs e body r) => e n -> body n
assumeConst e = case toAbs e of Abs bs body -> ignoreHoistFailure $ hoist bs body

arity :: (IRRep r, ToBindersAbs e body r) => e n -> Int
arity e = case toAbs e of Abs bs _ -> nestLength bs

tryAsConst
:: (IRRep r, HoistableE body, SinkableE body, ToBindersAbs e body r) => e n -> Maybe (body n)
tryAsConst e =
case toAbs e of
Abs bs body -> case hoist bs body of
HoistFailure _ -> Nothing
HoistSuccess e' -> Just e'

instantiate
:: (EnvReader m, IRRep r, SubstE (SubstVal Atom) body, SinkableE body, ToBindersAbs e body r)
=> e n -> [Atom r n] -> m n (body n)
instantiate e xs = case toAbs e of
Abs bs body -> applySubst (bs @@> (SubstVal <$> xs)) body
:: (EnvReader m, IRRep r, SubstE (SubstVal Atom) body, SinkableE body, SinkableE e,
ToBindersAbs e body r, Ext h n)
=> e h -> [Atom r n] -> m n (body n)
instantiate e xs = do
Abs bs body <- sinkM $ toAbs e
let bs' = fmapNest (\(BD b) -> b) bs
applySubst (bs' @@> (SubstVal <$> xs)) body
{-# INLINE instantiate #-}

-- "lazy" subst-extending version of `instantiate`
withInstantiated
:: (SubstReader AtomSubstVal m, IRRep r, SubstE (SubstVal Atom) body, SinkableE body, ToBindersAbs e body r)
=> e i -> [Atom r o]
-> (forall i'. body i' -> m i' o a)
-> m i o a
withInstantiated e xs cont = case toAbs e of
Abs bs body -> extendSubst (bs @@> (SubstVal <$> xs)) $ cont body
withInstantiated e xs cont = do
Abs bs body <- return $ toAbs e
let bs' = fmapNest (\(BD b) -> b) bs
extendSubst (bs' @@> (SubstVal <$> xs)) $ cont body

instantiateNames
:: (EnvReader m, IRRep r, RenameE body, SinkableE body, ToBindersAbs e body r)
=> e n -> [AtomName r n] -> m n (body n)
instantiateNames e vs = case toAbs e of
Abs bs body -> applyRename (bs @@> vs) body
:: (EnvReader m, IRRep r, RenameE body, SinkableE body, ToBindersAbs e body r, Ext h n)
=> e h -> [AtomName r n] -> m n (body n)
instantiateNames e vs = do
Abs bs body <- sinkM $ toAbs e
let bs' = fmapNest (\(BD b) -> b) bs
applyRename (bs' @@> vs) body

-- "lazy" subst-extending version of `instantiateNames`
withInstantiatedNames
:: (SubstReader Name m, IRRep r, RenameE body, SinkableE body, ToBindersAbs e body r)
=> e i -> [AtomName r o]
-> (forall i'. body i' -> m i' o a)
-> m i o a
withInstantiatedNames e vs cont = case toAbs e of
Abs bs body -> extendRenamer (bs @@> vs) $ cont body
withInstantiatedNames e vs cont = do
Abs bs body <- return $ toAbs e
let bs' = fmapNest (\(BD b) -> b) bs
extendRenamer (bs' @@> vs) $ cont body

extendSubstBD
:: forall v m b r i i' o a
. (SubstReader v m, ToBinders b r, IRRep r)
=> b i i' -> [v (AtomNameC r) o] -> m i' o a -> m i o a
extendSubstBD bsTop xsTop contTop = go (toBinders bsTop) xsTop contTop
where
go :: Binders r ii ii' -> [v (AtomNameC r) o] -> m ii' o a -> m ii o a
go Empty [] cont = cont
go (Nest (BD b) bs) (x:xs) cont = extendSubst (b@>x) $ go bs xs cont
go _ _ _ = error "zip error"
{-# INLINE extendSubstBD #-}

-- Returns a representation type (type of an TypeCon-typed Newtype payload)
-- given a list of instantiated DataConDefs.
Expand Down Expand Up @@ -549,8 +589,8 @@ visitBlock b = visitGeneric (LamExpr Empty b) >>= \case

visitAlt :: Visitor m r i o => Alt r i -> m (Alt r o)
visitAlt (Abs b body) = do
visitGeneric (LamExpr (UnaryNest b) body) >>= \case
LamExpr (UnaryNest b') body' -> return $ Abs b' body'
visitGeneric (UnaryLamExpr b body) >>= \case
UnaryLamExpr b' body' -> return $ Abs b' body'
_ -> error "not an alt"

traverseOpTerm
Expand Down Expand Up @@ -585,16 +625,16 @@ visitPiDefault (PiType bs effty) = do

visitBinders
:: (Visitor2 m r, IRRep r, FromName v, AtomSubstReader v m, EnvExtender2 m)
=> Nest (Binder r) i i'
-> (forall o'. DExt o o' => Nest (Binder r) o o' -> m i' o' a)
=> Binders r i i'
-> (forall o'. DExt o o' => Binders r o o' -> m i' o' a)
-> m i o a
visitBinders Empty cont = getDistinct >>= \Distinct -> cont Empty
visitBinders (Nest (b:>ty) bs) cont = do
visitBinders (Nest (BD (b:>ty)) bs) cont = do
ty' <- visitType ty
withFreshBinder (getNameHint b) ty' \b' -> do
extendRenamer (b@>binderName b') do
visitBinders bs \bs' ->
cont $ Nest b' bs'
cont $ Nest (BD b') bs'

-- XXX: This doesn't handle the `Var`, `ProjectElt`, `SimpInCore` cases. These
-- should be handled explicitly beforehand. TODO: split out these cases under a
Expand Down Expand Up @@ -807,15 +847,6 @@ toAtomVar v = do
ty <- getType <$> lookupAtomName v
return $ AtomVar v ty

bindersToVars :: (EnvReader m, IRRep r) => Nest (Binder r) n' n -> m n [AtomVar r n]
bindersToVars bs = do
withExtEvidence bs do
Distinct <- getDistinct
mapM toAtomVar $ nestToNames bs

bindersToAtoms :: (EnvReader m, IRRep r) => Nest (Binder r) n' n -> m n [Atom r n]
bindersToAtoms bs = liftM (Var <$>) $ bindersToVars bs

newtype SubstVisitor i o a = SubstVisitor { runSubstVisitor :: Reader (Env o, Subst AtomSubstVal i o) a }
deriving (Functor, Applicative, Monad, MonadReader (Env o, Subst AtomSubstVal i o))

Expand Down Expand Up @@ -919,6 +950,7 @@ instance IRRep r => SubstE AtomSubstVal (DepPairType r)
instance SubstE AtomSubstVal SolverBinding
instance IRRep r => SubstE AtomSubstVal (DeclBinding r)
instance IRRep r => SubstB AtomSubstVal (Decl r)
instance IRRep r => SubstB AtomSubstVal (BinderAndDecls r)
instance SubstE AtomSubstVal NewtypeTyCon
instance SubstE AtomSubstVal NewtypeCon
instance IRRep r => SubstE AtomSubstVal (IxDict r)
Expand Down
Loading