Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
dougalm committed Jun 22, 2023
1 parent e446273 commit a27faea
Show file tree
Hide file tree
Showing 21 changed files with 586 additions and 695 deletions.
2 changes: 1 addition & 1 deletion src/lib/Algebra.hs
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ type BlockTraverserM i o a = SubstReaderT PolySubstVal (MaybeT1 (BuilderM SimpIR
blockAsPoly
:: (EnvExtender m, EnvReader m)
=> Block SimpIR n -> m n (Maybe (Polynomial n))
blockAsPoly (Block _ decls result) =
blockAsPoly (Abs decls result) =
liftBuilder $ runMaybeT1 $ runSubstReaderT idSubst $ blockAsPolyRec decls result

blockAsPolyRec :: Nest (Decl SimpIR) i i' -> Atom SimpIR i' -> BlockTraverserM i o (Polynomial o)
Expand Down
91 changes: 24 additions & 67 deletions src/lib/Builder.hs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import Control.Monad.Writer.Strict hiding (Alt)
import Control.Monad.State.Strict (MonadState (..), StateT (..), runStateT)
import qualified Data.Map.Strict as M
import Data.Graph (graphFromEdges, topSort)
import Data.Text.Prettyprint.Doc (Pretty (..), group, line, nest)
import Data.Text.Prettyprint.Doc (Pretty (..))
import Foreign.Ptr

import qualified Unsafe.Coerce as TrulyUnsafe
Expand All @@ -28,7 +28,6 @@ import IRVariants
import MTL1
import Subst
import Name
import PPrint (prettyBlock)
import QueryType
import Types.Core
import Types.Imp
Expand Down Expand Up @@ -88,11 +87,11 @@ emitUnOp op x = emitOp $ UnOp op x
{-# INLINE emitUnOp #-}

emitBlock :: (Builder r m, Emits n) => Block r n -> m n (Atom r n)
emitBlock (Block _ decls result) = emitDecls decls result
emitBlock = emitDecls

emitDecls :: (Builder r m, Emits n, RenameE e, SinkableE e)
=> Nest (Decl r) n l -> e l -> m n (e n)
emitDecls decls result = runSubstReaderT idSubst $ emitDecls' decls result
=> WithDecls r e n -> m n (e n)
emitDecls (Abs decls result) = runSubstReaderT idSubst $ emitDecls' decls result

emitDecls' :: (Builder r m, Emits o, RenameE e, SinkableE e)
=> Nest (Decl r) i i' -> e i' -> SubstReaderT Name m i o (e o)
Expand Down Expand Up @@ -278,10 +277,9 @@ emitTopLet hint letAnn expr = do
v <- emitBinding hint $ AtomNameBinding $ LetBound (DeclBinding letAnn expr)
return $ AtomVar v ty

emitTopFunBinding :: (Mut n, TopBuilder m) => NameHint -> TopFunDef n -> LamExpr SimpIR n -> m n (TopFunName n)
emitTopFunBinding :: (Mut n, TopBuilder m) => NameHint -> TopFunDef n -> STopLam n -> m n (TopFunName n)
emitTopFunBinding hint def f = do
ty <- return $ getLamExprType f
emitBinding hint $ TopFunBinding $ DexTopFun def ty f Waiting
emitBinding hint $ TopFunBinding $ DexTopFun def f Waiting

emitSourceMap :: TopBuilder m => SourceMap n -> m n ()
emitSourceMap sm = emitLocalModuleEnv $ mempty {envSourceMap = sm}
Expand Down Expand Up @@ -334,7 +332,7 @@ extendLinearizationCache s fs =

queryObjCache :: EnvReader m => TopFunName n -> m n (Maybe (FunObjCodeName n))
queryObjCache v = lookupEnv v >>= \case
TopFunBinding (DexTopFun _ _ _ (Finished impl)) -> return $ Just $ topFunObjCode impl
TopFunBinding (DexTopFun _ _ (Finished impl)) -> return $ Just $ topFunObjCode impl
_ -> return Nothing

emitObjFile :: (Mut n, TopBuilder m) => CFunction n -> m n (FunObjCodeName n)
Expand Down Expand Up @@ -468,7 +466,7 @@ liftEmitBuilder cont = do
Distinct <- getDistinct
let (result, decls, _) = runHardFail $ unsafeRunInplaceT (runBuilderT' cont) env emptyOutFrag
Emits <- fabricateEmitsEvidenceM
emitDecls (unsafeCoerceB $ unRNest decls) result
emitDecls $ Abs (unsafeCoerceB $ unRNest decls) result

instance (IRRep r, Fallible m) => ScopableBuilder r (BuilderT r m) where
buildScoped cont = BuilderT do
Expand Down Expand Up @@ -601,56 +599,14 @@ buildBlock
:: ScopableBuilder r m
=> (forall l. (Emits l, DExt n l) => m l (Atom r l))
-> m n (Block r n)
buildBlock cont = buildScoped (cont >>= withType) >>= computeAbsEffects >>= absToBlock

withType :: ((EnvReader m, IRRep r), HasType r e) => e l -> m l ((e `PairE` Type r) l)
withType e = do
ty <- {-# SCC blockTypeNormalization #-} cheapNormalize $ getType e
return $ e `PairE` ty
{-# INLINE withType #-}

makeBlock :: IRRep r => Nest (Decl r) n l -> EffectRow r l -> Atom r l -> Type r l -> Block r n
makeBlock decls effs atom ty = Block (BlockAnn (EffTy effs' ty')) decls atom where
ty' = ignoreHoistFailure $ hoist decls ty
effs' = ignoreHoistFailure $ hoist decls effs
{-# INLINE makeBlock #-}

absToBlockInferringTypes :: (EnvReader m, IRRep r) => Abs (Nest (Decl r)) (Atom r) n -> m n (Block r n)
absToBlockInferringTypes ab = liftEnvReaderM do
abWithEffs <- computeAbsEffects ab
refreshAbs abWithEffs \decls (effs `PairE` result) -> do
ty <- cheapNormalize $ getType result
return $ ignoreExcept $
absToBlock $ Abs decls (effs `PairE` (result `PairE` ty))
{-# INLINE absToBlockInferringTypes #-}

absToBlock
:: (Fallible m, IRRep r)
=> Abs (Nest (Decl r)) (EffectRow r `PairE` (Atom r `PairE` Type r)) n -> m (Block r n)
absToBlock (Abs decls (effs `PairE` (result `PairE` ty))) = do
let msg = "Block:" <> nest 1 (prettyBlock decls result) <> line
<> group ("Of type:" <> nest 2 (line <> pretty ty)) <> line
<> group ("With effects:" <> nest 2 (line <> pretty effs))
ty' <- liftHoistExcept' (docAsStr msg) $ hoist decls ty
effs' <- liftHoistExcept' (docAsStr msg) $ hoist decls effs
return $ Block (BlockAnn (EffTy effs' ty')) decls result
{-# INLINE absToBlock #-}

makeBlockFromDecls :: (EnvReader m, IRRep r) => Abs (Nest (Decl r)) (Atom r) n -> m n (Block r n)
makeBlockFromDecls (Abs Empty result) = return $ AtomicBlock result
makeBlockFromDecls ab = liftEnvReaderM $ refreshAbs ab \decls result -> do
ty <- return $ getType result
effs <- declNestEffects decls
PairE ty' effs' <- return $ ignoreHoistFailure $ hoist decls $ PairE ty effs
return $ Block (BlockAnn (EffTy effs' ty')) decls result
{-# INLINE makeBlockFromDecls #-}
buildBlock = buildScoped

coreLamExpr :: EnvReader m => AppExplicitness
-> Abs (Nest (WithExpl CBinder)) (PairE (EffectRow CoreIR) CBlock) n
-> m n (CoreLamExpr n)
coreLamExpr appExpl ab = liftEnvReaderM do
refreshAbs ab \bs' (PairE effs' body') -> do
resultTy <- return $ getType body'
EffTy _ resultTy <- blockEffTy body'
let bs'' = fmapNest withoutExpl bs'
return $ CoreLamExpr (CorePiType appExpl bs' (EffTy effs' resultTy)) (LamExpr bs'' body')

Expand Down Expand Up @@ -789,7 +745,7 @@ buildCase' scrut resultTy indexedAltBody = do
(alts, effs) <- unzip <$> forM (enumerate altBinderTys) \(i, bTy) -> do
(Abs b' (body `PairE` eff')) <- buildAbs noHint bTy \x -> do
blk <- buildBlock $ indexedAltBody i $ Var $ sink x
eff <- return $ getEffects blk
EffTy eff _ <- blockEffTy blk
return $ blk `PairE` eff
return (Abs b' body, ignoreHoistFailure $ hoist b' eff')
return $ Case scrut alts $ EffTy (mconcat effs) resultTy
Expand Down Expand Up @@ -1157,9 +1113,15 @@ mkDictAtom d = do
ty <- typeOfDictExpr d
return $ DictCon ty d

mkCase :: EnvReader m => Atom r n -> Type r n -> [Alt r n] -> m n (Expr r n)
mkCase = undefined
-- eff' <- foldMapM (pure . getEffects) alts'
-- void $ emitExpr $ Case (sink scrut') alts' (EffTy eff' UnitTy)

mkCatchException :: EnvReader m => CBlock n -> m n (Hof CoreIR n)
mkCatchException body = do
resultTy <- makePreludeMaybeTy $ getType body
EffTy _ bodyTy <- blockEffTy body
resultTy <- makePreludeMaybeTy bodyTy
return $ CatchException resultTy body

app :: (CBuilder m, Emits n) => CAtom n -> CAtom n -> m n (CAtom n)
Expand All @@ -1177,7 +1139,7 @@ naryTopAppInlined :: (Builder SimpIR m, Emits n) => TopFunName n -> [SAtom n] ->
naryTopAppInlined f xs = do
TopFunBinding f' <- lookupEnv f
case f' of
DexTopFun _ _ (LamExpr bs body) _ ->
DexTopFun _ (TopLam _ (LamExpr bs body)) _ ->
applySubst (bs@@>(SubstVal<$>xs)) body >>= emitBlock
_ -> naryTopApp f xs
{-# INLINE naryTopAppInlined #-}
Expand Down Expand Up @@ -1551,15 +1513,10 @@ type ExprVisitorNoEmits2 m r = forall i o. ExprVisitorNoEmits (m i o) r i o
visitLamNoEmits
:: (ExprVisitorNoEmits2 m r, IRRep r, AtomSubstReader v m, EnvExtender2 m)
=> LamExpr r i -> m i o (LamExpr r o)
visitLamNoEmits (LamExpr bs body) =
visitBinders bs \bs' -> LamExpr bs' <$> visitBlockNoEmits body

visitBlockNoEmits
:: (ExprVisitorNoEmits2 m r, IRRep r, AtomSubstReader v m, EnvExtender2 m)
=> Block r i -> m i o (Block r o)
visitBlockNoEmits (Block _ decls result) =
absToBlockInferringTypes =<< visitDeclsNoEmits decls \decls' -> do
Abs decls' <$> visitAtom result
visitLamNoEmits (LamExpr bs (Abs decls result)) =
visitBinders bs \bs' -> LamExpr bs' <$>
visitDeclsNoEmits decls \decls' -> Abs decls' <$> do
visitAtom result

visitDeclsNoEmits
:: (ExprVisitorNoEmits2 m r, IRRep r, AtomSubstReader v m, EnvExtender2 m)
Expand Down Expand Up @@ -1602,7 +1559,7 @@ visitLamEmits (LamExpr bs body) = visitBinders bs \bs' -> LamExpr bs' <$>
visitBlockEmits
:: (ExprVisitorEmits2 m r, SubstReader AtomSubstVal m, EnvExtender2 m, IRRep r, Emits o)
=> Block r i -> m i o (Atom r o)
visitBlockEmits (Block _ decls result) = visitDeclsEmits decls $ visitAtom result
visitBlockEmits (Abs decls result) = visitDeclsEmits decls $ visitAtom result

visitDeclsEmits
:: (ExprVisitorEmits2 m r, SubstReader AtomSubstVal m, EnvExtender2 m, IRRep r, Emits o)
Expand Down
29 changes: 13 additions & 16 deletions src/lib/CheapReduction.hs
Original file line number Diff line number Diff line change
Expand Up @@ -261,9 +261,6 @@ instance CheaplyReducibleE CoreIR TyConParams TyConParams where
instance (CheaplyReducibleE r e e', NiceE r e') => CheaplyReducibleE r (Abs (Nest (Decl r)) e) e' where
cheapReduceE (Abs decls result) = cheapReduceWithDeclsB decls $ cheapReduceE result

instance (CheaplyReducibleE r (Atom r) e', NiceE r e') => CheaplyReducibleE r (Block r) e' where
cheapReduceE (Block _ decls result) = cheapReduceE $ Abs decls result

instance IRRep r => CheaplyReducibleE r (Expr r) (Atom r) where
cheapReduceE expr = confuseGHC >>= \_ -> case expr of
Atom atom -> cheapReduceE atom
Expand Down Expand Up @@ -517,10 +514,10 @@ instance VisitGeneric (Type r) r where visitGeneric = visitType
instance VisitGeneric (LamExpr r) r where visitGeneric = visitLam
instance VisitGeneric (PiType r) r where visitGeneric = visitPi

instance VisitGeneric (Block r) r where
visitGeneric b = visitGeneric (LamExpr Empty b) >>= \case
LamExpr Empty b' -> return b'
_ -> error "not a block"
visitBlock :: Visitor m r i o => Block r i -> m (Block r o)
visitBlock b = visitGeneric (LamExpr Empty b) >>= \case
LamExpr Empty b' -> return b'
_ -> error "not a block"

visitAlt :: Visitor m r i o => Alt r i -> m (Alt r o)
visitAlt (Abs b body) = do
Expand Down Expand Up @@ -623,9 +620,10 @@ instance IRRep r => VisitGeneric (Expr r) r where
return $ Case x' alts' $ EffTy effs' t'
where
altEffects :: Alt r n -> EffectRow r n
altEffects (Abs bs (Block ann _ _)) = case ann of
NoBlockAnn -> Pure
BlockAnn (EffTy effs _) -> ignoreHoistFailure $ hoist bs effs
altEffects = undefined
-- altEffects (Abs bs (Block ann _ _)) = case ann of
-- NoBlockAnn -> Pure
-- BlockAnn (EffTy effs _) -> ignoreHoistFailure $ hoist bs effs
Atom x -> Atom <$> visitGeneric x
TabCon Nothing t xs -> TabCon Nothing <$> visitGeneric t <*> mapM visitGeneric xs
TabCon (Just (WhenIRE d)) t xs -> TabCon <$> (Just . WhenIRE <$> visitGeneric d) <*> visitGeneric t <*> mapM visitGeneric xs
Expand Down Expand Up @@ -654,10 +652,10 @@ instance IRRep r => VisitGeneric (Hof r) r where
RunReader x body -> RunReader <$> visitGeneric x <*> visitGeneric body
RunWriter dest bm body -> RunWriter <$> mapM visitGeneric dest <*> visitGeneric bm <*> visitGeneric body
RunState dest s body -> RunState <$> mapM visitGeneric dest <*> visitGeneric s <*> visitGeneric body
While b -> While <$> visitGeneric b
RunIO b -> RunIO <$> visitGeneric b
RunInit b -> RunInit <$> visitGeneric b
CatchException t b -> CatchException <$> visitType t <*> visitGeneric b
While b -> While <$> visitBlock b
RunIO b -> RunIO <$> visitBlock b
RunInit b -> RunInit <$> visitBlock b
CatchException t b -> CatchException <$> visitType t <*> visitBlock b
Linearize lam x -> Linearize <$> visitGeneric lam <*> visitGeneric x
Transpose lam x -> Transpose <$> visitGeneric lam <*> visitGeneric x

Expand All @@ -674,7 +672,7 @@ instance IRRep r => VisitGeneric (DAMOp r) r where

instance VisitGeneric UserEffectOp CoreIR where
visitGeneric = \case
Handle name xs body -> Handle <$> renameN name <*> mapM visitGeneric xs <*> visitGeneric body
Handle name xs body -> Handle <$> renameN name <*> mapM visitGeneric xs <*> visitBlock body
Resume t x -> Resume <$> visitGeneric t <*> visitGeneric x
Perform x i -> Perform <$> visitGeneric x <*> pure i

Expand Down Expand Up @@ -889,7 +887,6 @@ instance IRRep r => SubstE AtomSubstVal (PrimOp r)
instance IRRep r => SubstE AtomSubstVal (RefOp r)
instance IRRep r => SubstE AtomSubstVal (EffTy r)
instance IRRep r => SubstE AtomSubstVal (Expr r)
instance IRRep r => SubstE AtomSubstVal (Block r)
instance IRRep r => SubstE AtomSubstVal (GenericOpRep const r)
instance SubstE AtomSubstVal InstanceBody
instance SubstE AtomSubstVal DictType
Expand Down
Loading

0 comments on commit a27faea

Please sign in to comment.