Skip to content

Commit

Permalink
Remove effect annotations and the CPS effect runners.
Browse files Browse the repository at this point in the history
  • Loading branch information
dougalm committed Apr 22, 2024
1 parent 7a55417 commit a1d4215
Show file tree
Hide file tree
Showing 25 changed files with 394 additions and 1,710 deletions.
42 changes: 10 additions & 32 deletions src/lib/AbstractSyntax.hs
Original file line number Diff line number Diff line change
Expand Up @@ -165,15 +165,14 @@ aInstanceDef (CInstanceDef (WithSrc clNameId clName) args givens methods instNam
aDef :: CDef -> SyntaxM (SourceNameW, ULamExpr VoidS)
aDef (CDef name params optRhs optGivens body) = do
explicitParams <- explicitBindersOptAnn params
let rhsDefault = (ExplicitApp, Nothing, Nothing)
(expl, effs, resultTy) <- fromMaybeM optRhs rhsDefault \(expl, optEffs, resultTy) -> do
effs <- fromMaybeM optEffs UPure aEffects
let rhsDefault = (ExplicitApp, Nothing)
(expl, resultTy) <- fromMaybeM optRhs rhsDefault \(expl, resultTy) -> do
resultTy' <- expr resultTy
return (expl, Just effs, Just resultTy')
return (expl, Just resultTy')
implicitParams <- aOptGivens optGivens
let allParams = implicitParams >>> explicitParams
body' <- block body
return (name, ULamExpr allParams expl effs resultTy body')
return (name, ULamExpr allParams expl resultTy body')

stripParens :: GroupW -> GroupW
stripParens (WithSrcs _ _ (CParens [g])) = stripParens g
Expand Down Expand Up @@ -356,34 +355,14 @@ identifier ctx (WithSrcs sid _ g) = case g of
CLeaf (CIdentifier name) -> return $ WithSrc sid name
_ -> throw sid $ ExpectedIdentifier ctx

aEffects :: WithSrcs ([GroupW], Maybe GroupW) -> SyntaxM (UEffectRow VoidS)
aEffects (WithSrcs _ _ (effs, optEffTail)) = do
lhs <- mapM effect effs
rhs <- forM optEffTail \effTail ->
fromSourceNameW <$> identifier "effect row remainder variable" effTail
return $ UEffectRow (S.fromList lhs) rhs

effect :: GroupW -> SyntaxM (UEffect VoidS)
effect (WithSrcs grpSid _ grp) = case grp of
CParens [g] -> effect g
CJuxtapose True (Identifier "Read" ) (WithSrcs sid _ (CLeaf (CIdentifier h))) ->
return $ URWSEffect Reader $ fromSourceNameW (WithSrc sid h)
CJuxtapose True (Identifier "Accum") (WithSrcs sid _ (CLeaf (CIdentifier h))) ->
return $ URWSEffect Writer $ fromSourceNameW (WithSrc sid h)
CJuxtapose True (Identifier "State") (WithSrcs sid _ (CLeaf (CIdentifier h))) ->
return $ URWSEffect State $ fromSourceNameW (WithSrc sid h)
CLeaf (CIdentifier "Except") -> return UExceptionEffect
CLeaf (CIdentifier "IO" ) -> return UIOEffect
_ -> throw grpSid UnexpectedEffectForm

aMethod :: CSDeclW -> SyntaxM (Maybe (UMethodDef VoidS))
aMethod (WithSrcs _ _ CPass) = return Nothing
aMethod (WithSrcs sid _ d) = Just . WithSrcE sid <$> case d of
CDefDecl def -> do
(WithSrc nameSid name, lam) <- aDef def
return $ UMethodDef (SourceName nameSid name) lam
CLet (WithSrcs lhsSid _ (CLeaf (CIdentifier name))) rhs -> do
rhs' <- ULamExpr Empty ImplicitApp Nothing Nothing <$> block rhs
rhs' <- ULamExpr Empty ImplicitApp Nothing <$> block rhs
return $ UMethodDef (fromSourceNameW (WithSrc lhsSid name)) rhs'
_ -> throw sid UnexpectedMethodDef

Expand All @@ -407,7 +386,7 @@ blockDecls (WithSrcs sid _ (CBind b rhs):ds) = do
b' <- binderOptTy Explicit b
rhs' <- asExpr <$> block rhs
body <- block $ IndentedBlock sid ds -- Not really the right SrcId
let lam = ULam $ ULamExpr (UnaryNest b') ExplicitApp Nothing Nothing body
let lam = ULam $ ULamExpr (UnaryNest b') ExplicitApp Nothing body
return (Empty, WithSrcE sid $ extendAppRight rhs' (WithSrcE sid lam))
blockDecls (d:ds) = do
d' <- decl PlainLet d
Expand All @@ -428,13 +407,12 @@ expr (WithSrcs sid sids grp) = WithSrcE sid <$> case grp of
-- should be detected upstream, before calling expr.
CBrackets gs -> UTabCon <$> mapM expr gs
CGivens _ -> throw sid UnexpectedGivenClause
CArrow lhs effs rhs -> do
CArrow lhs rhs -> do
case lhs of
WithSrcs _ _ (CParens gs) -> do
bs <- aPiBinders gs
effs' <- fromMaybeM effs UPure aEffects
resultTy <- expr rhs
return $ UPi $ UPiExpr bs ExplicitApp effs' resultTy
return $ UPi $ UPiExpr bs ExplicitApp resultTy
WithSrcs lhsSid _ _ -> throw lhsSid ArgsShouldHaveParens
CDo b -> UDo <$> block b
CJuxtapose hasSpace lhs rhs -> case hasSpace of
Expand Down Expand Up @@ -476,7 +454,7 @@ expr (WithSrcs sid sids grp) = WithSrcE sid <$> case grp of
WithSrcs _ _ (CParens gs) -> do
bs <- aPiBinders gs
resultTy <- expr rhs
return $ UPi $ UPiExpr bs ImplicitApp UPure resultTy
return $ UPi $ UPiExpr bs ImplicitApp resultTy
WithSrcs lhsSid _ _ -> throw lhsSid ArgsShouldHaveParens
FatArrow -> do
lhs' <- tyOptPat lhs
Expand All @@ -501,7 +479,7 @@ expr (WithSrcs sid sids grp) = WithSrcE sid <$> case grp of
CLambda params body -> do
params' <- explicitBindersOptAnn $ WithSrcs sid [] $ map stripParens params
body' <- block body
return $ ULam $ ULamExpr params' ExplicitApp Nothing Nothing body'
return $ ULam $ ULamExpr params' ExplicitApp Nothing body'
CFor kind indices body -> do
let (dir, trailingUnit) = case kind of
KFor -> (Fwd, False)
Expand Down
180 changes: 16 additions & 164 deletions src/lib/Builder.hs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ import Control.Monad.Reader
import Control.Monad.Writer.Strict hiding (Alt)
import Control.Monad.State.Strict (MonadState (..), StateT (..), runStateT)
import qualified Data.Map.Strict as M
import Data.Foldable (fold)
import Data.Graph (graphFromEdges, topSort)
import Data.Foldable (fold)
import Foreign.Ptr

import qualified Unsafe.Coerce as TrulyUnsafe
Expand Down Expand Up @@ -709,28 +709,14 @@ buildCase' scrut resultTy indexedAltBody = case scrut of
blk <- buildBlock $ indexedAltBody i $ toAtom $ sink x
return $ blk `PairE` getEffects blk
return (Abs b' body, ignoreHoistFailure $ hoist b' eff')
return $ Case scrut alts $ EffTy (mconcat effs) resultTy
return $ Case scrut alts $ EffTy (fold effs) resultTy

buildCase :: (Emits n, ScopableBuilder r m)
=> Atom r n -> Type r n
-> (forall l. (Emits l, DExt n l) => Int -> Atom r l -> m l (Atom r l))
-> m n (Atom r n)
buildCase s r b = emit =<< buildCase' s r b

buildEffLam
:: ScopableBuilder r m
=> NameHint -> Type r n
-> (forall l. (Emits l, DExt n l) => AtomVar r l -> AtomVar r l -> m l (Atom r l))
-> m n (LamExpr r n)
buildEffLam hint ty body = do
withFreshBinder noHint (TyCon HeapType) \h -> do
let ty' = RefTy (toAtom $ binderVar h) (sink ty)
withFreshBinder hint ty' \b -> do
let ref = binderVar b
hVar <- sinkM $ binderVar h
body' <- buildBlock $ body (sink hVar) $ sink ref
return $ LamExpr (BinaryNest h b) body'

emitHof :: (Builder r m, Emits n) => Hof r n -> m n (Atom r n)
emitHof hof = mkTypedHof hof >>= emit

Expand Down Expand Up @@ -766,35 +752,6 @@ buildMap xs f = do
buildFor noHint Fwd (tabIxType t) \i ->
tabApp (sink xs) (toAtom i) >>= f

emitRunWriter
:: (Emits n, ScopableBuilder r m)
=> NameHint -> Type r n -> BaseMonoid r n
-> (forall l. (Emits l, DExt n l) => AtomVar r l -> AtomVar r l -> m l (Atom r l))
-> m n (Atom r n)
emitRunWriter hint accTy bm body = do
lam <- buildEffLam hint accTy \h ref -> body h ref
emitHof $ RunWriter Nothing bm lam

emitRunState
:: (Emits n, ScopableBuilder r m)
=> NameHint -> Atom r n
-> (forall l. (Emits l, DExt n l) => AtomVar r l -> AtomVar r l -> m l (Atom r l))
-> m n (Atom r n)
emitRunState hint initVal body = do
stateTy <- return $ getType initVal
lam <- buildEffLam hint stateTy \h ref -> body h ref
emitHof $ RunState Nothing initVal lam

emitRunReader
:: (Emits n, ScopableBuilder r m)
=> NameHint -> Atom r n
-> (forall l. (Emits l, DExt n l) => AtomVar r l -> AtomVar r l -> m l (Atom r l))
-> m n (Atom r n)
emitRunReader hint r body = do
rTy <- return $ getType r
lam <- buildEffLam hint rTy \h ref -> body h ref
emitHof $ RunReader r lam

emitSeq :: (Emits n, ScopableBuilder SimpIR m)
=> Direction -> IxType SimpIR n -> Atom SimpIR n -> LamExpr SimpIR n
-> m n (Atom SimpIR n)
Expand All @@ -806,8 +763,7 @@ mkSeq :: EnvReader m
=> Direction -> IxType SimpIR n -> Atom SimpIR n -> LamExpr SimpIR n
-> m n (DAMOp SimpIR n)
mkSeq d t x f = do
effTy <- functionEffs f
return $ Seq effTy d t x f
return $ Seq undefined d t x f

buildRememberDest :: (Emits n, ScopableBuilder SimpIR m)
=> NameHint -> SAtom n
Expand All @@ -816,8 +772,7 @@ buildRememberDest :: (Emits n, ScopableBuilder SimpIR m)
buildRememberDest hint dest cont = do
ty <- return $ getType dest
doit <- buildUnaryLamExpr hint ty cont
effs <- functionEffs doit
emit $ PrimOp $ DAMOp $ RememberDest effs dest doit
emit $ PrimOp $ DAMOp $ RememberDest undefined dest doit

-- === vector space (ish) type class ===

Expand Down Expand Up @@ -1040,15 +995,16 @@ mkBlock (Abs decls body) = do
return $ Block effTy block

blockEffTy :: (EnvReader m, IRRep r) => Block r n -> m n (EffTy r n)
blockEffTy block = liftEnvReaderM $ refreshAbs block \decls result -> do
effs <- declsEffects decls mempty
return $ ignoreHoistFailure $ hoist decls $ EffTy effs $ getType result
where
declsEffects :: IRRep r => Nest (Decl r) n l -> EffectRow r l -> EnvReaderM l (EffectRow r l)
declsEffects Empty !acc = return acc
declsEffects n@(Nest (Let _ (DeclBinding _ expr)) rest) !acc = withExtEvidence n do
expr' <- sinkM expr
declsEffects rest $ acc <> getEffects expr'
blockEffTy _ = undefined
-- blockEffTy block = liftEnvReaderM $ refreshAbs block \decls result -> do
-- effs <- declsEffects decls mempty
-- return $ ignoreHoistFailure $ hoist decls $ EffTy effs $ getType result
-- where
-- declsEffects :: IRRep r => Nest (Decl r) n l -> EffectRow r l -> EnvReaderM l (EffectRow r l)
-- declsEffects Empty !acc = return acc
-- declsEffects n@(Nest (Let _ (DeclBinding _ expr)) rest) !acc = withExtEvidence n do
-- expr' <- sinkM expr
-- declsEffects rest $ acc <> getEffects expr'

mkApp :: EnvReader m => CAtom n -> [CAtom n] -> m n (CExpr n)
mkApp f xs = do
Expand Down Expand Up @@ -1084,15 +1040,9 @@ mkInstanceDict instanceName args = do

mkCase :: (EnvReader m, IRRep r) => Atom r n -> Type r n -> [Alt r n] -> m n (Expr r n)
mkCase scrut resultTy alts = liftEnvReaderM do
eff' <- fold <$> forM alts \alt -> refreshAbs alt \b body -> do
return $ ignoreHoistFailure $ hoist b $ getEffects body
eff' <- undefined
return $ Case scrut alts (EffTy eff' resultTy)

mkCatchException :: EnvReader m => CExpr n -> m n (Hof CoreIR n)
mkCatchException body = do
resultTy <- makePreludeMaybeTy (getType body)
return $ CatchException resultTy body

app :: (CBuilder m, Emits n) => CAtom n -> CAtom n -> m n (CAtom n)
app x i = mkApp x [i] >>= emit

Expand Down Expand Up @@ -1134,9 +1084,7 @@ ptrOffset x i = emit $ MemOp $ PtrOffset x i
{-# INLINE ptrOffset #-}

unsafePtrLoad :: (Builder r m, Emits n) => Atom r n -> m n (Atom r n)
unsafePtrLoad x = do
body <- liftEmitBuilder $ buildBlock $ emit . MemOp . PtrLoad =<< sinkM x
emitHof $ RunIO body
unsafePtrLoad x = emit . MemOp . PtrLoad =<< sinkM x

mkIndexRef :: (EnvReader m, Fallible1 m, IRRep r) => Atom r n -> Atom r n -> m n (PrimOp r n)
mkIndexRef ref i = do
Expand Down Expand Up @@ -1198,102 +1146,6 @@ emitIf predicate resultTy trueCase falseCase = do
1 -> trueCase
_ -> error "should only have two cases"

emitMaybeCase :: (Emits n, ScopableBuilder r m)
=> Atom r n -> Type r n
-> (forall l. (Emits l, DExt n l) => m l (Atom r l))
-> (forall l. (Emits l, DExt n l) => Atom r l -> m l (Atom r l))
-> m n (Atom r n)
emitMaybeCase scrut resultTy nothingCase justCase = do
buildCase scrut resultTy \i v ->
case i of
0 -> nothingCase
1 -> justCase v
_ -> error "should be a binary scrutinee"

-- Maybe a -> a
fromJustE :: (Emits n, Builder r m) => Atom r n -> m n (Atom r n)
fromJustE x = liftEmitBuilder do
MaybeTy a <- return $ getType x
emitMaybeCase x a
(emit $ MiscOp $ ThrowError $ sink a)
(return)

-- Maybe a -> Bool
isJustE :: (Emits n, Builder r m) => Atom r n -> m n (Atom r n)
isJustE x = liftEmitBuilder $
emitMaybeCase x BoolTy (return FalseAtom) (\_ -> return TrueAtom)

-- Monoid a -> (n=>a) -> a
reduceE :: (Emits n, SBuilder m) => BaseMonoid SimpIR n -> SAtom n -> m n (SAtom n)
reduceE monoid xs = liftEmitBuilder do
TabPi tabPi <- return $ getTyCon xs
let a = assumeConst tabPi
getSnd =<< emitRunWriter noHint a monoid \_ ref ->
buildFor noHint Fwd (sink $ tabIxType tabPi) \i -> do
x <- tabApp (sink xs) (toAtom i)
emit $ PrimOp $ RefOp (sink $ toAtom ref) $ MExtend (sink monoid) x

andMonoid :: (EnvReader m, IRRep r) => m n (BaseMonoid r n)
andMonoid = liftM (BaseMonoid TrueAtom) $ liftBuilder $
buildBinaryLamExpr (noHint, BoolTy) (noHint, BoolTy) \x y ->
emit $ BinOp BAnd (sink $ toAtom x) (toAtom y)

-- (a-> {|eff} b) -> n=>a -> {|eff} (n=>b)
mapE :: (Emits n, ScopableBuilder SimpIR m)
=> (forall l. (Emits l, DExt n l) => SAtom l -> m l (SAtom l))
-> SAtom n -> m n (SAtom n)
mapE cont xs = do
TabPi tabPi <- return $ getTyCon xs
buildFor (getNameHint tabPi) Fwd (tabIxType tabPi) \i -> do
tabApp (sink xs) (toAtom i) >>= cont

-- (n:Type) ?-> (a:Type) ?-> (xs : n=>Maybe a) : Maybe (n => a) =
catMaybesE :: (Emits n, SBuilder m) => SAtom n -> m n (SAtom n)
catMaybesE maybes = do
TabTy d n (MaybeTy a) <- return $ getType maybes
justs <- liftEmitBuilder $ mapE isJustE maybes
monoid <- andMonoid
allJust <- reduceE monoid justs
liftEmitBuilder $ emitIf allJust (MaybeTy $ TabTy d n a)
(JustAtom (sink $ TabTy d n a) <$> mapE fromJustE (sink maybes))
(return (NothingAtom $ sink $ TabTy d n a))

emitWhile :: (Emits n, ScopableBuilder r m)
=> (forall l. (Emits l, DExt n l) => m l (Atom r l))
-> m n ()
emitWhile cont = do
body <- buildBlock cont
void $ emitHof $ While body

-- Dex implementation, for reference
-- def whileMaybe (eff:Effects) -> (body: Unit -> {|eff} (Maybe Word8)) : {|eff} Maybe Unit =
-- hadError = yieldState False \ref.
-- while do
-- ans = liftState ref body ()
-- case ans of
-- Nothing ->
-- ref := True
-- False
-- Just cond -> W8ToB cond
-- if hadError
-- then Nothing
-- else Just ()

runMaybeWhile :: (Emits n, ScopableBuilder r m)
=> (forall l. (Emits l, DExt n l) => m l (Atom r l))
-> m n (Atom r n)
runMaybeWhile body = do
hadError <- getSnd =<< emitRunState noHint FalseAtom \_ ref -> do
emitWhile do
ans <- body
emitMaybeCase ans Word8Ty
(emit (RefOp (sink $ toAtom ref) $ MPut TrueAtom) >> return FalseAtom)
(return)
return UnitVal
emitIf hadError (MaybeTy UnitTy)
(return $ NothingAtom UnitTy)
(return $ JustAtom UnitTy UnitVal)

-- === capturing closures with telescopes ===

type ReconAbs r e = Abs (ReconBinders r) e
Expand Down
Loading

0 comments on commit a1d4215

Please sign in to comment.