diff --git a/src/lib/AbstractSyntax.hs b/src/lib/AbstractSyntax.hs index bf65fda87..e2143fd9a 100644 --- a/src/lib/AbstractSyntax.hs +++ b/src/lib/AbstractSyntax.hs @@ -110,23 +110,23 @@ topDecl = dropSrc topDecl' where topDecl' (CSDecl ann d) = ULocalDecl <$> decl ann (WithSrc emptySrcPosCtx d) topDecl' (CData name tyConParams givens constructors) = do tyConParams' <- aExplicitParams tyConParams - givens' <- toNest <$> fromMaybeM givens [] aGivens + givens' <- aOptGivens givens constructors' <- forM constructors \(v, ps) -> do ps' <- toNest <$> mapM tyOptBinder ps return (v, ps') return $ UDataDefDecl - (UDataDef name (givens' >>> tyConParams') $ + (UDataDef name (catUOptAnnExplBinders givens' tyConParams') $ map (\(name', cons) -> (name', UDataDefTrail cons)) constructors') (fromString name) (toNest $ map (fromString . fst) constructors') topDecl' (CStruct name params givens fields defs) = do params' <- aExplicitParams params - givens' <- toNest <$> fromMaybeM givens [] aGivens + givens' <- aOptGivens givens fields' <- forM fields \(v, ty) -> (v,) <$> expr ty methods <- forM defs \(ann, d) -> do (methodName, lam) <- aDef d return (ann, methodName, Abs (UBindSource emptySrcPosCtx "self") lam) - return $ UStructDecl (fromString name) (UStructDef name (givens' >>> params') fields' methods) + return $ UStructDecl (fromString name) (UStructDef name (catUOptAnnExplBinders givens' params') fields' methods) topDecl' (CInterface name params methods) = do params' <- aExplicitParams params (methodNames, methodTys) <- unzip <$> forM methods \(methodName, ty) -> do @@ -153,7 +153,7 @@ aInstanceDef :: CInstanceDef -> SyntaxM (UTopDecl VoidS VoidS) aInstanceDef (CInstanceDef clName args givens methods instNameAndParams) = do let clName' = fromString clName args' <- mapM expr args - givens' <- toNest <$> fromMaybeM givens [] aGivens + givens' <- aOptGivens givens methods' <- catMaybes <$> mapM aMethod methods case instNameAndParams of Nothing -> return $ UInstance clName' givens' args' methods' NothingB ImplicitApp @@ -162,7 +162,7 @@ aInstanceDef (CInstanceDef clName args givens methods instNameAndParams) = do case optParams of Just params -> do params' <- aExplicitParams params - return $ UInstance clName' (givens' >>> params') args' methods' instName' ExplicitApp + return $ UInstance clName' (catUOptAnnExplBinders givens' params') args' methods' instName' ExplicitApp Nothing -> return $ UInstance clName' givens' args' methods' instName' ImplicitApp aDef :: CDef -> SyntaxM (SourceName, ULamExpr VoidS) @@ -173,19 +173,27 @@ aDef (CDef name params optRhs optGivens body) = do effs <- fromMaybeM optEffs UPure aEffects resultTy' <- expr resultTy return (expl, Just effs, Just resultTy') - implicitParams <- toNest <$> fromMaybeM optGivens [] aGivens - let allParams = implicitParams >>> explicitParams + implicitParams <- aOptGivens optGivens + let allParams = catUOptAnnExplBinders implicitParams explicitParams body' <- block body return (name, ULamExpr allParams expl effs resultTy body') +catUOptAnnExplBinders :: UOptAnnExplBinders n l -> UOptAnnExplBinders l l' -> UOptAnnExplBinders n l' +catUOptAnnExplBinders (expls, bs) (expls', bs') = (expls <> expls', bs >>> bs') + stripParens :: Group -> Group stripParens (WithSrc _ (CParens [g])) = stripParens g stripParens g = g -aExplicitParams :: ExplicitParams -> SyntaxM (Nest (WithExpl UOptAnnBinder) VoidS VoidS) +aExplicitParams :: ExplicitParams -> SyntaxM ([Explicitness], Nest UOptAnnBinder VoidS VoidS) aExplicitParams gs = generalBinders DataParam Explicit gs -aGivens :: GivenClause -> SyntaxM [WithExpl UOptAnnBinder VoidS VoidS] +aOptGivens :: Maybe GivenClause -> SyntaxM (UOptAnnExplBinders VoidS VoidS) +aOptGivens optGivens = do + (expls, implicitParams) <- unzip <$> fromMaybeM optGivens [] aGivens + return (expls, toNest implicitParams) + +aGivens :: GivenClause -> SyntaxM [(Explicitness, UOptAnnBinder VoidS VoidS)] aGivens (implicits, optConstraints) = do implicits' <- mapM (generalBinder DataParam (Inferred Nothing Unify)) implicits constraints <- fromMaybeM optConstraints [] \gs -> do @@ -194,23 +202,24 @@ aGivens (implicits, optConstraints) = do generalBinders :: ParamStyle -> Explicitness -> [Group] - -> SyntaxM (Nest (WithExpl UOptAnnBinder) VoidS VoidS) -generalBinders paramStyle expl params = toNest . concat <$> - forM params \case + -> SyntaxM ([Explicitness], Nest UOptAnnBinder VoidS VoidS) +generalBinders paramStyle expl params = do + (expls, bs) <- unzip . concat <$> forM params \case WithSrc _ (CGivens gs) -> aGivens gs p -> (:[]) <$> generalBinder paramStyle expl p + return (expls, toNest bs) generalBinder :: ParamStyle -> Explicitness -> Group - -> SyntaxM (WithExpl UOptAnnBinder VoidS VoidS) + -> SyntaxM (Explicitness, UOptAnnBinder VoidS VoidS) generalBinder paramStyle expl g = case expl of - Inferred _ (Synth _) -> WithExpl expl <$> tyOptBinder g + Inferred _ (Synth _) -> (expl,) <$> tyOptBinder g Inferred _ Unify -> do b <- binderOptTy g expl' <- return case b of UAnnBinder (UBindSource _ s) _ _ -> Inferred (Just s) Unify _ -> expl - return $ WithExpl expl' b - Explicit -> WithExpl expl <$> case paramStyle of + return (expl', b) + Explicit -> (expl,) <$> case paramStyle of TypeParam -> tyOptBinder g DataParam -> binderOptTy g @@ -338,7 +347,6 @@ effect (Binary JuxtaposeWithSpace (Identifier "State") (Identifier h)) = return $ URWSEffect State $ fromString h effect (Identifier "Except") = return UExceptionEffect effect (Identifier "IO") = return UIOEffect -effect (Identifier effName) = return $ UUserEffect (fromString effName) effect _ = throw SyntaxErr "Unexpected effect form; expected one of `Read h`, `Accum h`, `State h`, `Except`, `IO`, or the name of a user-defined effect." aMethod :: CSDecl -> SyntaxM (Maybe (UMethodDef VoidS)) @@ -348,7 +356,7 @@ aMethod (WithSrc src d) = Just . WithSrcE src <$> addSrcContext src case d of (name, lam) <- aDef def return $ UMethodDef (fromString name) lam CLet (WithSrc _ (CIdentifier name)) rhs -> do - rhs' <- ULamExpr Empty ImplicitApp Nothing Nothing <$> block rhs + rhs' <- ULamExpr ([], Empty) ImplicitApp Nothing Nothing <$> block rhs return $ UMethodDef (fromString name) rhs' _ -> throw SyntaxErr "Unexpected method definition. Expected `def` or `x = ...`." @@ -369,10 +377,10 @@ blockDecls [WithSrc src d] = addSrcContext src case d of CExpr g -> (Empty,) <$> expr g _ -> throw SyntaxErr "Block must end in expression" blockDecls (WithSrc pos (CBind b rhs):ds) = do - WithExpl _ b' <- generalBinder DataParam Explicit b + (_, b') <- generalBinder DataParam Explicit b rhs' <- asExpr <$> block rhs body <- block $ IndentedBlock ds - let lam = ULam $ ULamExpr (UnaryNest (WithExpl Explicit b')) ExplicitApp Nothing Nothing body + let lam = ULam $ ULamExpr ([Explicit], UnaryNest b') ExplicitApp Nothing Nothing body return (Empty, WithSrcE pos $ extendAppRight rhs' (ns lam)) blockDecls (d:ds) = do d' <- decl PlainLet d diff --git a/src/lib/Algebra.hs b/src/lib/Algebra.hs index b526eeace..65491714e 100644 --- a/src/lib/Algebra.hs +++ b/src/lib/Algebra.hs @@ -137,7 +137,7 @@ type BlockTraverserM i o a = SubstReaderT PolySubstVal (MaybeT1 (BuilderM SimpIR blockAsPoly :: (EnvExtender m, EnvReader m) => Block SimpIR n -> m n (Maybe (Polynomial n)) -blockAsPoly (Block _ decls result) = +blockAsPoly (Abs decls result) = liftBuilder $ runMaybeT1 $ runSubstReaderT idSubst $ blockAsPolyRec decls result blockAsPolyRec :: Nest (Decl SimpIR) i i' -> Atom SimpIR i' -> BlockTraverserM i o (Polynomial o) diff --git a/src/lib/Builder.hs b/src/lib/Builder.hs index dd4225b1e..5f1c372de 100644 --- a/src/lib/Builder.hs +++ b/src/lib/Builder.hs @@ -15,8 +15,9 @@ import Control.Monad.Reader import Control.Monad.Writer.Strict hiding (Alt) import Control.Monad.State.Strict (MonadState (..), StateT (..), runStateT) import qualified Data.Map.Strict as M +import Data.Foldable (fold) import Data.Graph (graphFromEdges, topSort) -import Data.Text.Prettyprint.Doc (Pretty (..), group, line, nest) +import Data.Text.Prettyprint.Doc (Pretty (..)) import Foreign.Ptr import qualified Unsafe.Coerce as TrulyUnsafe @@ -28,7 +29,6 @@ import IRVariants import MTL1 import Subst import Name -import PPrint (prettyBlock) import QueryType import Types.Core import Types.Imp @@ -88,11 +88,11 @@ emitUnOp op x = emitOp $ UnOp op x {-# INLINE emitUnOp #-} emitBlock :: (Builder r m, Emits n) => Block r n -> m n (Atom r n) -emitBlock (Block _ decls result) = emitDecls decls result +emitBlock = emitDecls emitDecls :: (Builder r m, Emits n, RenameE e, SinkableE e) - => Nest (Decl r) n l -> e l -> m n (e n) -emitDecls decls result = runSubstReaderT idSubst $ emitDecls' decls result + => WithDecls r e n -> m n (e n) +emitDecls (Abs decls result) = runSubstReaderT idSubst $ emitDecls' decls result emitDecls' :: (Builder r m, Emits o, RenameE e, SinkableE e) => Nest (Decl r) i i' -> e i' -> SubstReaderT Name m i o (e o) @@ -278,10 +278,9 @@ emitTopLet hint letAnn expr = do v <- emitBinding hint $ AtomNameBinding $ LetBound (DeclBinding letAnn expr) return $ AtomVar v ty -emitTopFunBinding :: (Mut n, TopBuilder m) => NameHint -> TopFunDef n -> LamExpr SimpIR n -> m n (TopFunName n) +emitTopFunBinding :: (Mut n, TopBuilder m) => NameHint -> TopFunDef n -> STopLam n -> m n (TopFunName n) emitTopFunBinding hint def f = do - ty <- return $ getLamExprType f - emitBinding hint $ TopFunBinding $ DexTopFun def ty f Waiting + emitBinding hint $ TopFunBinding $ DexTopFun def f Waiting emitSourceMap :: TopBuilder m => SourceMap n -> m n () emitSourceMap sm = emitLocalModuleEnv $ mempty {envSourceMap = sm} @@ -334,7 +333,7 @@ extendLinearizationCache s fs = queryObjCache :: EnvReader m => TopFunName n -> m n (Maybe (FunObjCodeName n)) queryObjCache v = lookupEnv v >>= \case - TopFunBinding (DexTopFun _ _ _ (Finished impl)) -> return $ Just $ topFunObjCode impl + TopFunBinding (DexTopFun _ _ (Finished impl)) -> return $ Just $ topFunObjCode impl _ -> return Nothing emitObjFile :: (Mut n, TopBuilder m) => CFunction n -> m n (FunObjCodeName n) @@ -468,7 +467,7 @@ liftEmitBuilder cont = do Distinct <- getDistinct let (result, decls, _) = runHardFail $ unsafeRunInplaceT (runBuilderT' cont) env emptyOutFrag Emits <- fabricateEmitsEvidenceM - emitDecls (unsafeCoerceB $ unRNest decls) result + emitDecls $ Abs (unsafeCoerceB $ unRNest decls) result instance (IRRep r, Fallible m) => ScopableBuilder r (BuilderT r m) where buildScoped cont = BuilderT do @@ -601,66 +600,15 @@ buildBlock :: ScopableBuilder r m => (forall l. (Emits l, DExt n l) => m l (Atom r l)) -> m n (Block r n) -buildBlock cont = buildScoped (cont >>= withType) >>= computeAbsEffects >>= absToBlock - -withType :: ((EnvReader m, IRRep r), HasType r e) => e l -> m l ((e `PairE` Type r) l) -withType e = do - ty <- {-# SCC blockTypeNormalization #-} cheapNormalize $ getType e - return $ e `PairE` ty -{-# INLINE withType #-} - -makeBlock :: IRRep r => Nest (Decl r) n l -> EffectRow r l -> Atom r l -> Type r l -> Block r n -makeBlock decls effs atom ty = Block (BlockAnn (EffTy effs' ty')) decls atom where - ty' = ignoreHoistFailure $ hoist decls ty - effs' = ignoreHoistFailure $ hoist decls effs -{-# INLINE makeBlock #-} - -absToBlockInferringTypes :: (EnvReader m, IRRep r) => Abs (Nest (Decl r)) (Atom r) n -> m n (Block r n) -absToBlockInferringTypes ab = liftEnvReaderM do - abWithEffs <- computeAbsEffects ab - refreshAbs abWithEffs \decls (effs `PairE` result) -> do - ty <- cheapNormalize $ getType result - return $ ignoreExcept $ - absToBlock $ Abs decls (effs `PairE` (result `PairE` ty)) -{-# INLINE absToBlockInferringTypes #-} - -absToBlock - :: (Fallible m, IRRep r) - => Abs (Nest (Decl r)) (EffectRow r `PairE` (Atom r `PairE` Type r)) n -> m (Block r n) -absToBlock (Abs decls (effs `PairE` (result `PairE` ty))) = do - let msg = "Block:" <> nest 1 (prettyBlock decls result) <> line - <> group ("Of type:" <> nest 2 (line <> pretty ty)) <> line - <> group ("With effects:" <> nest 2 (line <> pretty effs)) - ty' <- liftHoistExcept' (docAsStr msg) $ hoist decls ty - effs' <- liftHoistExcept' (docAsStr msg) $ hoist decls effs - return $ Block (BlockAnn (EffTy effs' ty')) decls result -{-# INLINE absToBlock #-} - -makeBlockFromDecls :: (EnvReader m, IRRep r) => Abs (Nest (Decl r)) (Atom r) n -> m n (Block r n) -makeBlockFromDecls (Abs Empty result) = return $ AtomicBlock result -makeBlockFromDecls ab = liftEnvReaderM $ refreshAbs ab \decls result -> do - ty <- return $ getType result - effs <- declNestEffects decls - PairE ty' effs' <- return $ ignoreHoistFailure $ hoist decls $ PairE ty effs - return $ Block (BlockAnn (EffTy effs' ty')) decls result -{-# INLINE makeBlockFromDecls #-} - -coreLamExpr :: EnvReader m => AppExplicitness - -> Abs (Nest (WithExpl CBinder)) (PairE (EffectRow CoreIR) CBlock) n - -> m n (CoreLamExpr n) -coreLamExpr appExpl ab = liftEnvReaderM do - refreshAbs ab \bs' (PairE effs' body') -> do - resultTy <- return $ getType body' - let bs'' = fmapNest withoutExpl bs' - return $ CoreLamExpr (CorePiType appExpl bs' (EffTy effs' resultTy)) (LamExpr bs'' body') +buildBlock = buildScoped buildCoreLam :: ScopableBuilder CoreIR m => CorePiType n -> (forall l. (Emits l, DExt n l) => [CAtomVar l] -> m l (CAtom l)) -> m n (CoreLamExpr n) -buildCoreLam piTy@(CorePiType _ bs _) cont = do - lam <- buildLamExpr (EmptyAbs $ fmapNest withoutExpl bs) cont +buildCoreLam piTy@(CorePiType _ _ bs _) cont = do + lam <- buildLamExpr (EmptyAbs bs) cont return $ CoreLamExpr piTy lam buildAbs @@ -736,12 +684,13 @@ buildLamExpr (Abs bs UnitE) cont = case bs of buildLamExpr rest' \vs -> cont $ sink v : vs return $ LamExpr (Nest b' bs') body' -buildLamExprFromPi +buildTopLamFromPi :: ScopableBuilder r m => PiType r n -> (forall l. (Emits l, Distinct l, DExt n l) => [AtomVar r l] -> m l (Atom r l)) - -> m n (LamExpr r n) -buildLamExprFromPi (PiType bs _) cont = buildLamExpr (EmptyAbs bs) cont + -> m n (TopLam r n) +buildTopLamFromPi piTy@(PiType bs _) cont = + TopLam False piTy <$> buildLamExpr (EmptyAbs bs) cont buildAlt :: ScopableBuilder r m @@ -789,7 +738,7 @@ buildCase' scrut resultTy indexedAltBody = do (alts, effs) <- unzip <$> forM (enumerate altBinderTys) \(i, bTy) -> do (Abs b' (body `PairE` eff')) <- buildAbs noHint bTy \x -> do blk <- buildBlock $ indexedAltBody i $ Var $ sink x - eff <- return $ getEffects blk + EffTy eff _ <- blockEffTy blk return $ blk `PairE` eff return (Abs b' body, ignoreHoistFailure $ hoist b' eff') return $ Case scrut alts $ EffTy (mconcat effs) resultTy @@ -837,8 +786,8 @@ buildMap :: (Emits n, ScopableBuilder r m) -> (forall l. (Emits l, DExt n l) => Atom r l -> m l (Atom r l)) -> m n (Atom r n) buildMap xs f = do - TabTy d (_:>t) _ <- return $ getType xs - buildFor noHint Fwd (IxType t d) \i -> + TabPi t <- return $ getType xs + buildFor noHint Fwd (tabIxType t) \i -> tabApp (sink xs) (Var i) >>= f unzipTab :: (Emits n, Builder r m) => Atom r n -> m n (Atom r n, Atom r n) @@ -908,8 +857,8 @@ zeroAt ty = liftEmitBuilder $ go ty where go = \case BaseTy bt -> return $ Con $ Lit $ zeroLit bt ProdTy tys -> ProdVal <$> mapM go tys - TabTy d (b:>t) bodyTy -> buildFor (getNameHint b) Fwd (IxType t d) \i -> - go =<< applySubst (b @> SubstVal (Var i)) bodyTy + TabPi tabPi -> buildFor (getNameHint tabPi) Fwd (tabIxType tabPi) \i -> + go =<< instantiateTabPiTy (sink tabPi) (Var i) _ -> unreachable zeroLit bt = case bt of Scalar Float64Type -> Float64Lit 0.0 @@ -953,8 +902,8 @@ tangentBaseMonoidFor ty = do addTangent :: (Emits n, SBuilder m) => SAtom n -> SAtom n -> m n (SAtom n) addTangent x y = do case getType x of - TabTy d (b:>t) _ -> - liftEmitBuilder $ buildFor (getNameHint b) Fwd (IxType t d) \i -> do + TabPi t -> + liftEmitBuilder $ buildFor (getNameHint t) Fwd (tabIxType t) \i -> do bindM2 addTangent (tabApp (sink x) (Var i)) (tabApp (sink y) (Var i)) TC con -> case con of BaseType (Scalar _) -> emitOp $ BinOp FAdd x y @@ -1125,7 +1074,7 @@ projectStructRef i x = do getStructProjections :: EnvReader m => Int -> CType n -> m n [Projection] getStructProjections i (NewtypeTyCon (UserADTType _ tyConName _)) = do - TyConDef _ _ ~(StructFields fields) <- lookupTyCon tyConName + TyConDef _ _ _ ~(StructFields fields) <- lookupTyCon tyConName return case fields of [_] | i == 0 -> [UnwrapNewtype] | otherwise -> error "bad index" @@ -1157,9 +1106,17 @@ mkDictAtom d = do ty <- typeOfDictExpr d return $ DictCon ty d +mkCase :: (EnvReader m, IRRep r) => Atom r n -> Type r n -> [Alt r n] -> m n (Expr r n) +mkCase scrut resultTy alts = liftEnvReaderM do + eff' <- fold <$> forM alts \alt -> refreshAbs alt \b body -> do + EffTy eff _ <- blockEffTy body + return $ ignoreHoistFailure $ hoist b eff + return $ Case scrut alts (EffTy eff' resultTy) + mkCatchException :: EnvReader m => CBlock n -> m n (Hof CoreIR n) mkCatchException body = do - resultTy <- makePreludeMaybeTy $ getType body + EffTy _ bodyTy <- blockEffTy body + resultTy <- makePreludeMaybeTy bodyTy return $ CatchException resultTy body app :: (CBuilder m, Emits n) => CAtom n -> CAtom n -> m n (CAtom n) @@ -1177,7 +1134,7 @@ naryTopAppInlined :: (Builder SimpIR m, Emits n) => TopFunName n -> [SAtom n] -> naryTopAppInlined f xs = do TopFunBinding f' <- lookupEnv f case f' of - DexTopFun _ _ (LamExpr bs body) _ -> + DexTopFun _ (TopLam _ _ (LamExpr bs body)) _ -> applySubst (bs@@>(SubstVal<$>xs)) body >>= emitBlock _ -> naryTopApp f xs {-# INLINE naryTopAppInlined #-} @@ -1237,7 +1194,7 @@ applyIxMethod dict method args = case dict of IxDictSpecialized _ d params -> do SpecializedDict _ maybeFs <- lookupSpecDict d Just fs <- return maybeFs - LamExpr bs body <- return $ fs !! fromEnum method + TopLam _ _ (LamExpr bs body) <- return $ fs !! fromEnum method emitBlock =<< applySubst (bs @@> fmap SubstVal (params ++ args)) body unsafeFromOrdinal :: (SBuilder m, Emits n) => IxType SimpIR n -> Atom SimpIR n -> m n (Atom SimpIR n) @@ -1551,15 +1508,10 @@ type ExprVisitorNoEmits2 m r = forall i o. ExprVisitorNoEmits (m i o) r i o visitLamNoEmits :: (ExprVisitorNoEmits2 m r, IRRep r, AtomSubstReader v m, EnvExtender2 m) => LamExpr r i -> m i o (LamExpr r o) -visitLamNoEmits (LamExpr bs body) = - visitBinders bs \bs' -> LamExpr bs' <$> visitBlockNoEmits body - -visitBlockNoEmits - :: (ExprVisitorNoEmits2 m r, IRRep r, AtomSubstReader v m, EnvExtender2 m) - => Block r i -> m i o (Block r o) -visitBlockNoEmits (Block _ decls result) = - absToBlockInferringTypes =<< visitDeclsNoEmits decls \decls' -> do - Abs decls' <$> visitAtom result +visitLamNoEmits (LamExpr bs (Abs decls result)) = + visitBinders bs \bs' -> LamExpr bs' <$> + visitDeclsNoEmits decls \decls' -> Abs decls' <$> do + visitAtom result visitDeclsNoEmits :: (ExprVisitorNoEmits2 m r, IRRep r, AtomSubstReader v m, EnvExtender2 m) @@ -1602,7 +1554,7 @@ visitLamEmits (LamExpr bs body) = visitBinders bs \bs' -> LamExpr bs' <$> visitBlockEmits :: (ExprVisitorEmits2 m r, SubstReader AtomSubstVal m, EnvExtender2 m, IRRep r, Emits o) => Block r i -> m i o (Atom r o) -visitBlockEmits (Block _ decls result) = visitDeclsEmits decls $ visitAtom result +visitBlockEmits (Abs decls result) = visitDeclsEmits decls $ visitAtom result visitDeclsEmits :: (ExprVisitorEmits2 m r, SubstReader AtomSubstVal m, EnvExtender2 m, IRRep r, Emits o) diff --git a/src/lib/CheapReduction.hs b/src/lib/CheapReduction.hs index fde08b7be..4c42bbda2 100644 --- a/src/lib/CheapReduction.hs +++ b/src/lib/CheapReduction.hs @@ -15,7 +15,8 @@ module CheapReduction , unwrapLeadingNewtypesType, wrapNewtypesData, liftSimpAtom, liftSimpType , liftSimpFun, makeStructRepVal, NonAtomRenamer (..), Visitor (..), VisitGeneric (..) , visitAtomPartial, visitTypePartial, visitAtomDefault, visitTypeDefault, Visitor2 - , visitBinders, visitPiDefault, visitAlt, toAtomVar) + , visitBinders, visitPiDefault, visitAlt, toAtomVar, instantiatePiTy, instantiateTabPiTy + , bindersToVars, bindersToAtoms) where import Control.Applicative @@ -239,7 +240,7 @@ cheapReduceDictExpr resultTy d = case d of cheapReduceE child >>= \case DictCon _ (InstanceDict instanceName args) -> dropSubst do args' <- mapM cheapReduceE args - InstanceDef _ bs _ body <- lookupInstanceDef instanceName + InstanceDef _ _ bs _ body <- lookupInstanceDef instanceName let InstanceBody superclasses _ = body applySubst (bs@@>(SubstVal <$> args')) (superclasses !! superclassIx) child' -> return $ DictCon resultTy $ SuperclassProj child' superclassIx @@ -261,9 +262,6 @@ instance CheaplyReducibleE CoreIR TyConParams TyConParams where instance (CheaplyReducibleE r e e', NiceE r e') => CheaplyReducibleE r (Abs (Nest (Decl r)) e) e' where cheapReduceE (Abs decls result) = cheapReduceWithDeclsB decls $ cheapReduceE result -instance (CheaplyReducibleE r (Atom r) e', NiceE r e') => CheaplyReducibleE r (Block r) e' where - cheapReduceE (Block _ decls result) = cheapReduceE $ Abs decls result - instance IRRep r => CheaplyReducibleE r (Expr r) (Atom r) where cheapReduceE expr = confuseGHC >>= \_ -> case expr of Atom atom -> cheapReduceE atom @@ -287,7 +285,7 @@ instance IRRep r => CheaplyReducibleE r (Expr r) (Atom r) where cheapReduceE dict >>= \case DictCon _ (InstanceDict instanceName args) -> dropSubst do args' <- mapM cheapReduceE args - InstanceDef _ bs _ (InstanceBody _ methods) <- lookupInstanceDef instanceName + InstanceDef _ _ bs _ (InstanceBody _ methods) <- lookupInstanceDef instanceName let method = methods !! i extendSubst (bs@@>(SubstVal <$> args')) do method' <- cheapReduceE method @@ -468,10 +466,18 @@ wrapNewtypesData [] x = x wrapNewtypesData (c:cs) x = NewtypeCon c $ wrapNewtypesData cs x instantiateTyConDef :: EnvReader m => TyConDef n -> TyConParams n -> m n (DataConDefs n) -instantiateTyConDef (TyConDef _ bs conDefs) (TyConParams _ xs) = do +instantiateTyConDef (TyConDef _ _ bs conDefs) (TyConParams _ xs) = do applySubst (bs @@> (SubstVal <$> xs)) conDefs {-# INLINE instantiateTyConDef #-} +instantiatePiTy :: (EnvReader m, IRRep r) => PiType r n -> [Atom r n] -> m n (EffTy r n) +instantiatePiTy (PiType bs effTy) xs = do + applySubst (bs @@> (SubstVal <$> xs)) effTy + +instantiateTabPiTy :: (EnvReader m, IRRep r) => TabPiType r n -> Atom r n -> m n (Type r n) +instantiateTabPiTy (TabPiType _ b resultTy) x = do + applySubst (b @> SubstVal x) resultTy + -- Returns a representation type (type of an TypeCon-typed Newtype payload) -- given a list of instantiated DataConDefs. dataDefRep :: DataConDefs n -> CType n @@ -485,7 +491,7 @@ dataDefRep (StructFields fields) = case map snd fields of makeStructRepVal :: (Fallible1 m, EnvReader m) => TyConName n -> [CAtom n] -> m n (CAtom n) makeStructRepVal tyConName args = do - TyConDef _ _ (StructFields fields) <- lookupTyCon tyConName + TyConDef _ _ _ (StructFields fields) <- lookupTyCon tyConName case fields of [_] -> case args of [arg] -> return arg @@ -517,10 +523,10 @@ instance VisitGeneric (Type r) r where visitGeneric = visitType instance VisitGeneric (LamExpr r) r where visitGeneric = visitLam instance VisitGeneric (PiType r) r where visitGeneric = visitPi -instance VisitGeneric (Block r) r where - visitGeneric b = visitGeneric (LamExpr Empty b) >>= \case - LamExpr Empty b' -> return b' - _ -> error "not a block" +visitBlock :: Visitor m r i o => Block r i -> m (Block r o) +visitBlock b = visitGeneric (LamExpr Empty b) >>= \case + LamExpr Empty b' -> return b' + _ -> error "not a block" visitAlt :: Visitor m r i o => Alt r i -> m (Alt r o) visitAlt (Abs b body) = do @@ -615,17 +621,11 @@ instance IRRep r => VisitGeneric (Expr r) r where -- TODO: should we reuse the original effects? Whether it's valid depends on -- the type-preservation requirements for a visitor. We should clarify what -- those are. - Case x alts (EffTy _ t) -> do + Case x alts effTy -> do x' <- visitGeneric x - t' <- visitGeneric t alts' <- mapM visitAlt alts - let effs' = foldMap altEffects alts' - return $ Case x' alts' $ EffTy effs' t' - where - altEffects :: Alt r n -> EffectRow r n - altEffects (Abs bs (Block ann _ _)) = case ann of - NoBlockAnn -> Pure - BlockAnn (EffTy effs _) -> ignoreHoistFailure $ hoist bs effs + effTy' <- visitGeneric effTy + return $ Case x' alts' effTy' Atom x -> Atom <$> visitGeneric x TabCon Nothing t xs -> TabCon Nothing <$> visitGeneric t <*> mapM visitGeneric xs TabCon (Just (WhenIRE d)) t xs -> TabCon <$> (Just . WhenIRE <$> visitGeneric d) <*> visitGeneric t <*> mapM visitGeneric xs @@ -642,7 +642,6 @@ instance IRRep r => VisitGeneric (PrimOp r) r where MiscOp op -> MiscOp <$> visitGeneric op Hof op -> Hof <$> visitGeneric op DAMOp op -> DAMOp <$> visitGeneric op - UserEffectOp op -> UserEffectOp <$> visitGeneric op RefOp r op -> RefOp <$> visitGeneric r <*> traverseOp op visitGeneric visitGeneric visitGeneric instance IRRep r => VisitGeneric (TypedHof r) r where @@ -654,10 +653,10 @@ instance IRRep r => VisitGeneric (Hof r) r where RunReader x body -> RunReader <$> visitGeneric x <*> visitGeneric body RunWriter dest bm body -> RunWriter <$> mapM visitGeneric dest <*> visitGeneric bm <*> visitGeneric body RunState dest s body -> RunState <$> mapM visitGeneric dest <*> visitGeneric s <*> visitGeneric body - While b -> While <$> visitGeneric b - RunIO b -> RunIO <$> visitGeneric b - RunInit b -> RunInit <$> visitGeneric b - CatchException t b -> CatchException <$> visitType t <*> visitGeneric b + While b -> While <$> visitBlock b + RunIO b -> RunIO <$> visitBlock b + RunInit b -> RunInit <$> visitBlock b + CatchException t b -> CatchException <$> visitType t <*> visitBlock b Linearize lam x -> Linearize <$> visitGeneric lam <*> visitGeneric x Transpose lam x -> Transpose <$> visitGeneric lam <*> visitGeneric x @@ -672,18 +671,11 @@ instance IRRep r => VisitGeneric (DAMOp r) r where Place x y -> Place <$> visitGeneric x <*> visitGeneric y Freeze x -> Freeze <$> visitGeneric x -instance VisitGeneric UserEffectOp CoreIR where - visitGeneric = \case - Handle name xs body -> Handle <$> renameN name <*> mapM visitGeneric xs <*> visitGeneric body - Resume t x -> Resume <$> visitGeneric t <*> visitGeneric x - Perform x i -> Perform <$> visitGeneric x <*> pure i - instance IRRep r => VisitGeneric (Effect r) r where visitGeneric = \case RWSEffect rws h -> RWSEffect rws <$> visitGeneric h ExceptionEffect -> pure ExceptionEffect IOEffect -> pure IOEffect - UserEffect name -> UserEffect <$> renameN name InitEffect -> pure InitEffect instance IRRep r => VisitGeneric (EffectRow r) r where @@ -737,11 +729,9 @@ instance VisitGeneric CoreLamExpr CoreIR where visitGeneric (CoreLamExpr t lam) = CoreLamExpr <$> visitGeneric t <*> visitGeneric lam instance VisitGeneric CorePiType CoreIR where - visitGeneric (CorePiType app bsExpl effty) = do - let (expls, bs) = unzipExpls bsExpl + visitGeneric (CorePiType app expl bs effty) = do PiType bs' effty' <- visitGeneric $ PiType bs effty - let bsExpl' = zipExpls expls bs' - return $ CorePiType app bsExpl' effty' + return $ CorePiType app expl bs' effty' instance IRRep r => VisitGeneric (TabPiType r) r where visitGeneric (TabPiType d b eltTy) = do @@ -798,6 +788,15 @@ toAtomVar v = do ty <- getType <$> lookupAtomName v return $ AtomVar v ty +bindersToVars :: (EnvReader m, IRRep r) => Nest (Binder r) n' n -> m n [AtomVar r n] +bindersToVars bs = do + withExtEvidence bs do + Distinct <- getDistinct + mapM toAtomVar $ nestToNames bs + +bindersToAtoms :: (EnvReader m, IRRep r) => Nest (Binder r) n' n -> m n [Atom r n] +bindersToAtoms bs = liftM (Var <$>) $ bindersToVars bs + newtype SubstVisitor i o a = SubstVisitor { runSubstVisitor :: Reader (Env o, Subst AtomSubstVal i o) a } deriving (Functor, Applicative, Monad, MonadReader (Env o, Subst AtomSubstVal i o)) @@ -876,7 +875,6 @@ instance IRRep r => SubstE AtomSubstVal (RepVal r) instance SubstE AtomSubstVal TyConParams instance SubstE AtomSubstVal DataConDef instance IRRep r => SubstE AtomSubstVal (BaseMonoid r) -instance SubstE AtomSubstVal UserEffectOp instance IRRep r => SubstE AtomSubstVal (DAMOp r) instance IRRep r => SubstE AtomSubstVal (TypedHof r) instance IRRep r => SubstE AtomSubstVal (Hof r) @@ -889,7 +887,6 @@ instance IRRep r => SubstE AtomSubstVal (PrimOp r) instance IRRep r => SubstE AtomSubstVal (RefOp r) instance IRRep r => SubstE AtomSubstVal (EffTy r) instance IRRep r => SubstE AtomSubstVal (Expr r) -instance IRRep r => SubstE AtomSubstVal (Block r) instance IRRep r => SubstE AtomSubstVal (GenericOpRep const r) instance SubstE AtomSubstVal InstanceBody instance SubstE AtomSubstVal DictType diff --git a/src/lib/CheckType.hs b/src/lib/CheckType.hs index 4aa800f17..47cf2df19 100644 --- a/src/lib/CheckType.hs +++ b/src/lib/CheckType.hs @@ -9,7 +9,7 @@ module CheckType ( CheckableE (..), CheckableB (..), checkTypes, checkTypesM, checkHasType, - checkExtends, tryGetType, isData, asFFIFunType, checkDestLam + checkExtends, tryGetType, isData, asFFIFunType, checkBlock ) where import Prelude hiding (id) @@ -52,13 +52,6 @@ checkHasType :: (EnvReader m, HasType r e) => e n -> Type r n -> m n (Except ()) checkHasType e ty = liftTyperT $ e |: ty {-# INLINE checkHasType #-} -checkDestLam :: (EnvReader m, Fallible1 m) => LamExpr SimpIR n -> m n () -checkDestLam lam = do - let allowedEffs = OneEffect InitEffect - PiType bs (EffTy effs _) <- return $ getDestLamExprType lam - let effs' = ignoreHoistFailure $ hoist bs effs - checkExtends allowedEffs effs' - -- === the type checking/querying monad === -- TODO: not clear why we need the explicit `Monad2` here since it should @@ -177,10 +170,23 @@ instance (CheckableB r b, CheckableE r e) => CheckableE r (Abs b e) where checkE (Abs b e) = checkB b \_ -> checkE e instance (IRRep r) => CheckableE r (LamExpr r) where - checkE (LamExpr bs body) = checkB bs \_ -> void $ getTypeE body + checkE (LamExpr bs body) = checkB bs \_ -> void $ checkBlock body -- === type checking core === +instance IRRep r => CheckableE r (TopLam r) where + checkE (TopLam _ piTy lam) = do + -- TODO: check destination-passing flag + checkE piTy + piTy' <- renameM piTy + piTy'' <- checkLamExpr lam + alphaEq piTy' piTy'' >>= \case + True -> return () + False -> throw TypeErr $ pprint piTy' ++ " != " ++ pprint piTy'' + +instance IRRep r => CheckableE r (PiType r) where + checkE piTy = void $ getTypeE piTy + instance IRRep r => CheckableE r (Atom r) where checkE atom = void $ getTypeE atom @@ -193,9 +199,6 @@ instance IRRep r => HasType r (AtomName r) where getType <$> lookupAtomName name' {-# INLINE getTypeE #-} -instance IRRep r => CheckableE r (Block r) where - checkE block = void $ getTypeE block - instance IRRep r => HasType r (Atom r) where getTypeE atom = case atom of Var name -> do @@ -253,7 +256,7 @@ instance IRRep r => HasType r (Type r) where TC tyCon -> typeCheckPrimTC tyCon DepPairTy ty -> getTypeE ty DictTy (DictType _ className params) -> do - ClassDef _ _ _ paramBs _ _ <- renameM className >>= lookupClassDef + ClassDef _ _ _ _ paramBs _ _ <- renameM className >>= lookupClassDef params' <- mapM renameM params checkArgTys paramBs params' return TyKind @@ -290,9 +293,6 @@ instance (ToBinding ann c, Color c, CheckableE r ann) => CheckableB r (BinderP c extendRenamer (b@>binderName b') $ cont b' -instance (BindsNames b, CheckableB r b) => CheckableB r (WithExpl b) where - checkB (WithExpl expl b) cont = checkB b \b' -> cont (WithExpl expl b') - typeCheckExpr :: (Typer m r, IRRep r) => EffectRow r o -> Expr r i -> m i o (Type r o) typeCheckExpr effs expr = addContext ("Checking expr:\n" ++ pprint expr) case expr of App (EffTy _ reqTy) f xs -> do @@ -315,7 +315,7 @@ typeCheckExpr effs expr = addContext ("Checking expr:\n" ++ pprint expr) case ex return resultTy' ApplyMethod (EffTy _ reqTy) dict i args -> do DictTy (DictType _ className params) <- getTypeE dict - ClassDef _ _ _ paramBs classBs methodTys <- lookupClassDef className + ClassDef _ _ _ _ paramBs classBs methodTys <- lookupClassDef className let methodTy = methodTys !! i superclassDicts <- getSuperclassDicts =<< renameM dict let subst = ( paramBs @@> map SubstVal params @@ -332,25 +332,6 @@ typeCheckExpr effs expr = addContext ("Checking expr:\n" ++ pprint expr) case ex HoistFailure _ -> forM_ xs checkE return ty' -instance IRRep r => HasType r (Block r) where - getTypeE = \case - Block NoBlockAnn Empty atom -> getTypeE atom - Block (BlockAnn (EffTy effs' reqTy)) decls result -> do - effs <- renameM effs' - reqTy' <- renameM reqTy - go effs reqTy' decls result - return reqTy' - Block _ _ _ -> error "impossible" - where - go :: Typer m r => EffectRow r o -> Type r o -> Nest (Decl r) i i' -> Atom r i' -> m i o () - go _ reqTy Empty result = result |: reqTy - go effs reqTy (Nest (Let b rhs@(DeclBinding _ expr)) decls) result = do - void $ typeCheckExpr effs expr - rhs' <- renameM rhs - withFreshBinder (getNameHint b) rhs' \(b':>_) -> do - extendRenamer (b@>binderName b') do - go (sink effs) (sink reqTy) decls result - instance CheckableE CoreIR TyConParams where checkE (TyConParams _ params) = mapM_ checkE params @@ -358,8 +339,8 @@ dictExprType :: Typer m CoreIR => DictExpr i -> m i o (CType o) dictExprType e = case e of InstanceDict instanceName args -> do instanceName' <- renameM instanceName - InstanceDef className bs params _ <- lookupInstanceDef instanceName' - ClassDef sourceName _ _ _ _ _ <- lookupClassDef className + InstanceDef className _ bs params _ <- lookupInstanceDef instanceName' + ClassDef sourceName _ _ _ _ _ _ <- lookupClassDef className args' <- mapM renameM args checkArgTys bs args' ListE params' <- applySubst (bs@@>(SubstVal<$>args')) (ListE params) @@ -369,7 +350,7 @@ dictExprType e = case e of checkApp Pure givenTy (toList args) SuperclassProj d i -> do DictTy (DictType _ className params) <- getTypeE d - ClassDef _ _ _ bs superclasses _ <- lookupClassDef className + ClassDef _ _ _ _ bs superclasses _ <- lookupClassDef className let scType = getSuperclassType REmpty superclasses i checkedApplyNaryAbs (Abs bs scType) params IxFin n -> do @@ -386,7 +367,14 @@ instance IRRep r => HasType r (DepPairType r) where return TyKind instance HasType CoreIR CorePiType where - getTypeE (CorePiType _ bs (EffTy eff resultTy)) = do + getTypeE (CorePiType _ _ bs (EffTy eff resultTy)) = do + checkB bs \_ -> do + void $ checkE eff + resultTy|:TyKind + return TyKind + +instance IRRep r => HasType r (PiType r) where + getTypeE (PiType bs (EffTy eff resultTy)) = do checkB bs \_ -> do void $ checkE eff resultTy|:TyKind @@ -416,14 +404,14 @@ checkAgainstGiven givenTy computedTy = do return givenTy' checkCoreLam :: Typer m CoreIR => CorePiType o -> LamExpr CoreIR i -> m i o () -checkCoreLam (CorePiType _ Empty (EffTy effs resultTy)) (LamExpr Empty body) = do +checkCoreLam (CorePiType _ _ Empty (EffTy effs resultTy)) (LamExpr Empty body) = do resultTy' <- checkBlockWithEffs effs body checkTypesEq resultTy resultTy' -checkCoreLam (CorePiType expl (Nest piB piBs) effTy) (LamExpr (Nest lamB lamBs) body) = do +checkCoreLam (CorePiType expl (_:expls) (Nest piB piBs) effTy) (LamExpr (Nest lamB lamBs) body) = do argTy <- renameM $ binderType lamB checkTypesEq (binderType piB) argTy withFreshBinder (getNameHint lamB) argTy \b -> do - piTy <- applyRename (piB@>binderName b) (CorePiType expl piBs effTy) + piTy <- applyRename (piB@>binderName b) (CorePiType expl expls piBs effTy) extendRenamer (lamB@>binderName b) do checkCoreLam piTy (LamExpr lamBs body) checkCoreLam _ _ = throw TypeErr "zip error" @@ -455,7 +443,7 @@ typeCheckNewtypeCon con x = case con of FinCon n -> n|:NatTy >> x|:NatTy >> renameM (Fin n) UserADTData _ d params -> do d' <- renameM d - def@(TyConDef sn _ _) <- lookupTyCon d' + def@(TyConDef sn _ _ _) <- lookupTyCon d' params' <- renameM params void $ checkedInstantiateTyConDef def params' return $ UserADTType sn d' params' @@ -490,7 +478,6 @@ typeCheckPrimOp effs op = case op of MiscOp x -> typeCheckMiscOp effs x MemOp x -> typeCheckMemOp effs x DAMOp op' -> typeCheckDAMOp effs op' - UserEffectOp op' -> typeCheckUserEffect op' RefOp ref m -> do TC (RefType h s) <- getTypeE ref case m of @@ -613,24 +600,6 @@ typeCheckVectorOp = \case unless (sbt == sbt') $ throw TypeErr "Scalar type mismatch" return $ RefTy heap ty' -typeCheckUserEffect :: Typer m CoreIR => UserEffectOp i -> m i o (CType o) -typeCheckUserEffect = \case - -- TODO(alex): check the argument - Resume retTy _argTy -> do - checkTypeE TyKind retTy - -- TODO(alex): actually check something here? this is a QueryType copy/paste - Handle hndName [] body -> do - hndName' <- renameM hndName - r <- getTypeE body - instantiateHandlerType hndName' r [] - -- TODO(alex): implement - Handle _ _ _ -> error "not implemented" - Perform eff i -> do - Eff (OneEffect (UserEffect effName)) <- return eff - EffectDef _ ops <- renameM effName >>= lookupEffectDef - let (_, EffectOpType _pol lamTy) = ops !! i - return lamTy - typeCheckPrimHof :: forall r m i o. (Typer m r, IRRep r) => EffectRow r o -> Hof r i -> m i o (Type r o) typeCheckPrimHof effs hof = addContext ("Checking HOF:\n" ++ pprint hof) case hof of For _ ixTy f -> do @@ -639,7 +608,7 @@ typeCheckPrimHof effs hof = addContext ("Checking HOF:\n" ++ pprint hof) case ho checkTypesEq t argTy return $ TabTy d (b:>t) eltTy While body -> do - condTy <- getTypeE body + condTy <- checkBlockWithEffs effs body checkTypesEq (BaseTy $ Scalar Word8Type) condTy return UnitTy Linearize f x -> do @@ -720,8 +689,7 @@ typeCheckDAMOp effs op = addContext ("Checking DAMOp:\n" ++ pprint op) case op o checkLamExpr :: (Typer m r, IRRep r) => LamExpr r i -> m i o (PiType r o) checkLamExpr (LamExpr bsTop body) = case bsTop of Empty -> do - resultTy <- getTypeE body - effs <- renameM $ getEffects body + EffTy effs resultTy <- checkBlock body return $ PiType Empty $ EffTy effs resultTy Nest (b:>ty) bs -> do ty' <- checkTypeE TyKind ty @@ -737,12 +705,31 @@ checkLamExprWithEffs allowedEffs lam = do checkExtends allowedEffs effs' return piTy -checkBlockWithEffs :: (Typer m r, IRRep r) => EffectRow r o -> Block r i -> m i o (Type r o) -checkBlockWithEffs allowedEffs block = do - ty <- getTypeE block - effs <- renameM $ getEffects block - checkExtends allowedEffs effs - return ty +checkBlockWithEffs :: forall i o r m. (Typer m r, IRRep r) => EffectRow r o -> Block r i -> m i o (Type r o) +checkBlockWithEffs allowedEffs (Abs decls result) = do + checkDecls allowedEffs decls \decls' -> do + resultTy <- getTypeE result + liftHoistExcept $ hoist decls' resultTy + +checkDecls + :: (Typer m r, IRRep r) + => EffectRow r o -> Decls r i i' + -> (forall o'. DExt o o' => Decls r o o' -> m i' o' a) + -> m i o a +checkDecls _ Empty cont = getDistinct >>= \Distinct -> cont Empty +checkDecls effs (Nest (Let b rhs@(DeclBinding _ expr)) decls) cont = do + void $ typeCheckExpr effs expr + rhs' <- renameM rhs + withFreshBinder (getNameHint b) rhs' \(b':>_) -> do + extendRenamer (b@>binderName b') do + let decl' = Let b' rhs' + checkDecls (sink effs) decls \decls' -> cont $ Nest decl' decls' + +checkBlock :: (Typer m r, IRRep r) => Block r i -> m i o (EffTy r o) +checkBlock block = do + EffTy effs _ <- blockEffTy =<< renameM block + ty <- checkBlockWithEffs effs block + return $ EffTy effs ty checkRWSAction :: (Typer m r, IRRep r) => EffectRow r o -> RWS -> LamExpr r i -> m i o (Type r o, Type r o) checkRWSAction effs rws f = do @@ -783,7 +770,7 @@ checkAlt resultTyReq bTyReq effs (Abs b body) = do checkApp :: (Typer m r, IRRep r) => EffectRow r o -> Type r o -> [Atom r i] -> m i o (Type r o) checkApp allowedEffs fTy xs = case fTy of - Pi (CorePiType _ bs effTy) -> do + Pi (CorePiType _ _ bs effTy) -> do xs' <- mapM renameM xs checkArgTys bs xs' let subst = bs @@> fmap SubstVal xs' @@ -801,11 +788,7 @@ checkTabApp ty (i:rest) = do resultTy' <- applySubst (b@>SubstVal i') resultTy checkTabApp resultTy' rest -checkArgTys - :: (Typer m r, SubstB AtomSubstVal b, BindsNames b, BindsOneAtomName r b, IRRep r) - => Nest b o o' - -> [Atom r o] - -> m i o () +checkArgTys :: (Typer m r, IRRep r) => Nest (Binder r) o o' -> [Atom r o] -> m i o () checkArgTys Empty [] = return () checkArgTys (Nest b bs) (x:xs) = do dropSubst $ x |: binderType b @@ -939,19 +922,18 @@ checkUnOp op x = do checkedInstantiateTyConDef :: (EnvReader m, Fallible1 m) => TyConDef n -> TyConParams n -> m n (DataConDefs n) -checkedInstantiateTyConDef (TyConDef _ bs cons) (TyConParams _ xs) = do +checkedInstantiateTyConDef (TyConDef _ _ bs cons) (TyConParams _ xs) = do checkedApplyNaryAbs (Abs bs cons) xs checkedApplyNaryAbs - :: forall b r e o m - . ( BindsOneAtomName r b, EnvReader m, Fallible1 m, SinkableE e - , SubstE AtomSubstVal e, IRRep r, SubstB AtomSubstVal b) - => Abs (Nest b) e o -> [Atom r o] -> m o (e o) + :: forall r e o m + . ( EnvReader m, Fallible1 m, SinkableE e , SubstE AtomSubstVal e, IRRep r) + => Abs (Nest (Binder r)) e o -> [Atom r o] -> m o (e o) checkedApplyNaryAbs (Abs bsTop e) xsTop = do go (EmptyAbs bsTop) xsTop applySubst (bsTop@@>(SubstVal<$>xsTop)) e where - go :: EmptyAbs (Nest b) o -> [Atom r o] -> m o () + go :: EmptyAbs (Nest (Binder r)) o -> [Atom r o] -> m o () go (Abs Empty UnitE) [] = return () go (Abs (Nest b bs) UnitE) (x:xs) = do checkAlphaEq (binderType b) (getType x) @@ -967,7 +949,6 @@ instance IRRep r => CheckableE r (EffectRow r) where RWSEffect _ v -> v |: TC HeapType ExceptionEffect -> return () IOEffect -> return () - UserEffect _ -> return () InitEffect -> return () case effTail of NoTail -> return () @@ -1006,7 +987,7 @@ asFFIFunType ty = return do return (impTy, piTy) checkFFIFunTypeM :: Fallible m => CorePiType n -> m IFunType -checkFFIFunTypeM (CorePiType appExpl (Nest b bs) effTy) = do +checkFFIFunTypeM (CorePiType appExpl (_:expls) (Nest b bs) effTy) = do argTy <- checkScalar $ binderType b case bs of Empty -> do @@ -1017,7 +998,7 @@ checkFFIFunTypeM (CorePiType appExpl (Nest b bs) effTy) = do _ -> FFIMultiResultCC return $ IFunType cc [argTy] resultTys Nest b' rest -> do - let naryPiRest = CorePiType appExpl (Nest b' rest) effTy + let naryPiRest = CorePiType appExpl expls (Nest b' rest) effTy IFunType cc argTys resultTys <- checkFFIFunTypeM naryPiRest return $ IFunType cc (argTy:argTys) resultTys checkFFIFunTypeM _ = error "expected at least one argument" diff --git a/src/lib/Core.hs b/src/lib/Core.hs index c63021fe3..f6fb57452 100644 --- a/src/lib/Core.hs +++ b/src/lib/Core.hs @@ -218,17 +218,13 @@ instance BindsEnv EnvFrag where toEnvFrag frag = frag {-# INLINE toEnvFrag #-} -instance BindsEnv b => BindsEnv (WithExpl b) where - toEnvFrag (WithExpl _ b) = toEnvFrag b - {-# INLINE toEnvFrag #-} - -instance BindsEnv RolePiBinder where - toEnvFrag (RolePiBinder _ b) = toEnvFrag b - {-# INLINE toEnvFrag #-} - instance BindsEnv (RecSubstFrag Binding) where toEnvFrag frag = EnvFrag frag +instance BindsEnv b => BindsEnv (WithAttrB a b) where + toEnvFrag (WithAttrB _ b) = toEnvFrag b + {-# INLINE toEnvFrag #-} + instance (BindsEnv b1, BindsEnv b2) => (BindsEnv (PairB b1 b2)) where toEnvFrag (PairB b1 b2) = do @@ -347,18 +343,6 @@ lookupInstanceTy :: EnvReader m => InstanceName n -> m n (CorePiType n) lookupInstanceTy name = lookupEnv name >>= \case InstanceBinding _ ty -> return ty {-# INLINE lookupInstanceTy #-} -lookupEffectDef :: EnvReader m => EffectName n -> m n (EffectDef n) -lookupEffectDef name = lookupEnv name >>= \case EffectBinding x -> return x -{-# INLINE lookupEffectDef #-} - -lookupEffectOpDef :: EnvReader m => EffectOpName n -> m n (EffectOpDef n) -lookupEffectOpDef name = lookupEnv name >>= \case EffectOpBinding x -> return x -{-# INLINE lookupEffectOpDef #-} - -lookupHandlerDef :: EnvReader m => HandlerName n -> m n (HandlerDef n) -lookupHandlerDef name = lookupEnv name >>= \case HandlerBinding x -> return x -{-# INLINE lookupHandlerDef #-} - lookupSourceMapPure :: SourceMap n -> SourceName -> [SourceNameDef n] lookupSourceMapPure (SourceMap m) v = M.findWithDefault [] v m {-# INLINE lookupSourceMapPure #-} @@ -423,8 +407,8 @@ getInstanceDicts name = do liftLamExpr :: (IRRep r, EnvReader m) => (forall l m2. EnvReader m2 => Block r l -> m2 l (Block r l)) - -> LamExpr r n -> m n (LamExpr r n) -liftLamExpr f (LamExpr bs body) = liftEnvReaderM $ + -> TopLam r n -> m n (TopLam r n) +liftLamExpr f (TopLam d ty (LamExpr bs body)) = liftM (TopLam d ty) $ liftEnvReaderM $ refreshAbs (Abs bs body) \bs' body' -> LamExpr bs' <$> f body' fromNaryForExpr :: IRRep r => Int -> Expr r n -> Maybe (Int, LamExpr r n) diff --git a/src/lib/Export.hs b/src/lib/Export.hs index c164d7bc6..f7ab3184d 100644 --- a/src/lib/Export.hs +++ b/src/lib/Export.hs @@ -52,8 +52,8 @@ prepareFunctionForExport cc f = do HoistFailure _ -> throw TypeErr $ "Types of exported functions have to be closed terms. Got: " ++ pprint naryPi HoistSuccess s -> return s - CoreLamExpr _ f' <- liftBuilder $ buildCoreLam naryPi \xs -> naryApp (sink f) (Var <$> xs) - fSimp <- simplifyTopFunction f' + f' <- liftBuilder $ buildCoreLam naryPi \xs -> naryApp (sink f) (Var <$> xs) + fSimp <- simplifyTopFunction $ coreLamToTopLam f' fImp <- compileTopLevelFun cc fSimp nativeFun <- toCFunction "userFunc" fImp >>= emitObjFile >>= loadObject return $ ExportNativeFunction nativeFun closedSig @@ -61,9 +61,8 @@ prepareFunctionForExport cc f = do {-# SCC prepareFunctionForExport #-} prepareSLamForExport :: (Mut n, Topper m) - => CallingConvention -> SLam n -> m n ExportNativeFunction -prepareSLamForExport cc f = do - let naryPi = getLamExprType f + => CallingConvention -> STopLam n -> m n ExportNativeFunction +prepareSLamForExport cc f@(TopLam _ naryPi _) = do sig <- liftExportSigM $ simpPiToExportSig cc naryPi closedSig <- case hoistToTop sig of HoistFailure _ -> @@ -101,11 +100,11 @@ liftExportSigM cont = do corePiToExportSig :: CallingConvention -> CorePiType i -> ExportSigM CoreIR i o (ExportedSignature o) -corePiToExportSig cc (CorePiType _ tbs (EffTy effs resultTy)) = do +corePiToExportSig cc (CorePiType _ expls tbs (EffTy effs resultTy)) = do case effs of Pure -> return () _ -> throw TypeErr "Only pure functions can be exported" - goArgs cc Empty [] tbs resultTy + goArgs cc Empty [] (zipAttrs expls tbs) resultTy simpPiToExportSig :: CallingConvention -> PiType SimpIR i -> ExportSigM SimpIR i o (ExportedSignature o) @@ -113,14 +112,14 @@ simpPiToExportSig cc (PiType bs (EffTy effs resultTy)) = do case effs of Pure -> return () _ -> throw TypeErr "Only pure functions can be exported" - bs' <- return $ fmapNest (\b -> WithExpl Explicit b) bs + bs' <- return $ fmapNest (\b -> WithAttrB Explicit b) bs goArgs cc Empty [] bs' resultTy goArgs :: (IRRep r) => CallingConvention -> Nest ExportArg o o' -> [CAtomName o'] - -> Nest (WithExpl (Binder r)) i i' + -> Nest (WithAttrB Explicitness (Binder r)) i i' -> Type r i' -> ExportSigM r i o' (ExportedSignature o) goArgs cc argSig argVs piBs piRes = case piBs of @@ -129,7 +128,7 @@ goArgs cc argSig argVs piBs piRes = case piBs of StandardCC -> (fromListE $ sink $ ListE argVs) ++ nestToList (sink . binderName) resSig XLACC -> [] _ -> error $ "calling convention not supported: " ++ show cc - Nest (WithExpl expl (b:>ty)) bs -> do + Nest (WithAttrB expl (b:>ty)) bs -> do ety <- toExportType ty withFreshBinder (getNameHint b) ety \(v:>_) -> extendSubst (b @> Rename (binderName v)) $ do @@ -176,8 +175,8 @@ parseTabTy = go [] NewtypeTyCon Nat -> return $ Just $ RectContArrayPtr IdxRepScalarBaseTy shape TabTy d (b:>ixty) a -> do maybeN <- case IxType ixty d of - (IxType (NewtypeTyCon (Fin n)) _) -> return $ Just n - (IxType _ (IxDictRawFin n)) -> return $ Just n + IxType (NewtypeTyCon (Fin n)) _ -> return $ Just n + IxType _ (IxDictRawFin n) -> return $ Just n _ -> return Nothing maybeDim <- case maybeN of Just (Var v) -> do diff --git a/src/lib/Generalize.hs b/src/lib/Generalize.hs index 78037c742..dacb584fb 100644 --- a/src/lib/Generalize.hs +++ b/src/lib/Generalize.hs @@ -20,6 +20,9 @@ import Subst import MTL1 import Types.Primitives +type RolePiBinder = WithAttrB RoleExpl CBinder +type RolePiBinders = Nest RolePiBinder + generalizeIxDict :: EnvReader m => Atom CoreIR n -> m n (Generalized CoreIR CAtom n) generalizeIxDict dict = liftGeneralizerM do dict' <- sinkM dict @@ -31,12 +34,12 @@ generalizeIxDict dict = liftGeneralizerM do generalizeArgs ::EnvReader m => CorePiType n -> [Atom CoreIR n] -> m n (Generalized CoreIR (ListE CAtom) n) generalizeArgs fTy argsTop = liftGeneralizerM $ runSubstReaderT idSubst do - PairE (CorePiType _ bs _) (ListE argsTop') <- sinkM $ PairE fTy (ListE argsTop) - ListE <$> go bs argsTop' + PairE (CorePiType _ expls bs _) (ListE argsTop') <- sinkM $ PairE fTy (ListE argsTop) + ListE <$> go (zipAttrs expls bs) argsTop' where - go :: Nest (WithExpl CBinder) i i' -> [Atom CoreIR n] + go :: Nest (WithAttrB Explicitness CBinder) i i' -> [Atom CoreIR n] -> SubstReaderT AtomSubstVal GeneralizerM i n [Atom CoreIR n] - go (Nest (WithExpl expl b) bs) (arg:args) = do + go (Nest (WithAttrB expl b) bs) (arg:args) = do ty' <- substM $ binderType b arg' <- case (ty', expl) of (TyKind, _) -> liftSubstReaderT case arg of @@ -172,7 +175,7 @@ traverseRoleBinders f allBinders allParams = go :: forall i i'. RolePiBinders i i' -> [Atom CoreIR n] -> SubstReaderT AtomSubstVal m i n [Atom CoreIR n] go Empty [] = return [] - go (Nest (RolePiBinder role b) bs) (param:params) = do + go (Nest (WithAttrB (role, _) b) bs) (param:params) = do ty' <- substM $ binderType b Distinct <- getDistinct param' <- liftSubstReaderT $ f role ty' param @@ -183,14 +186,14 @@ traverseRoleBinders f allBinders allParams = getDataDefRoleBinders :: EnvReader m => TyConName n -> m n (Abs RolePiBinders UnitE n) getDataDefRoleBinders def = do - TyConDef _ bs _ <- lookupTyCon def - return $ Abs bs UnitE + TyConDef _ attrs bs _ <- lookupTyCon def + return $ Abs (zipAttrs attrs bs) UnitE {-# INLINE getDataDefRoleBinders #-} getClassRoleBinders :: EnvReader m => ClassName n -> m n (Abs RolePiBinders UnitE n) getClassRoleBinders def = do - ClassDef _ _ _ bs _ _ <- lookupClassDef def - return $ Abs bs UnitE + ClassDef _ _ _ roleExpls bs _ _ <- lookupClassDef def + return $ Abs (zipAttrs roleExpls bs) UnitE {-# INLINE getClassRoleBinders #-} -- === instances === diff --git a/src/lib/Imp.hs b/src/lib/Imp.hs index b6d1d3958..bfb73537c 100644 --- a/src/lib/Imp.hs +++ b/src/lib/Imp.hs @@ -46,16 +46,14 @@ import Types.Imp import Types.Primitives import Util (forMFilter, Tree (..), zipTrees, enumerate) --- XXX: The LamExpr should be in destination-passing style, with its last --- argument a reference to the result. toImpFunction :: EnvReader m - => CallingConvention -> LamExpr SimpIR n -> m n (ImpFunction n) -toImpFunction cc lam = do - (LamExpr bsAndRefB body) <- return lam + => CallingConvention -> STopLam n -> m n (ImpFunction n) +toImpFunction cc (TopLam True destTy lam) = do + LamExpr bsAndRefB body <- return lam PairB bs destB <- case popNest bsAndRefB of Just bsAndRefB' -> return bsAndRefB' Nothing -> error "expected a trailing reference binder" - ty <- return $ getDestLamExprType lam + let ty = piTypeWithoutDest destTy impArgTys <- getNaryLamImpArgTypesWithCC cc ty liftImpM $ buildImpFunction cc (zip (repeat noHint) impArgTys) \vs -> do case cc of @@ -75,6 +73,7 @@ toImpFunction cc lam = do extendSubst (destB @> SubstVal (destToAtom (sink resultDest))) do void $ translateBlock body return [] +toImpFunction _ (TopLam False _ _) = error "expected a lambda in destination-passing form" getNaryLamImpArgTypesWithCC :: EnvReader m => CallingConvention -> PiType SimpIR n -> m n [BaseType] @@ -270,7 +269,7 @@ liftImpM cont = do translateBlock :: forall i o. Emits o => SBlock i -> SubstImpM i o (SAtom o) -translateBlock (Block _ decls result) = translateDeclNest decls $ substM result +translateBlock (Abs decls result) = translateDeclNest decls $ substM result translateDeclNestSubst :: Emits o => Subst AtomSubstVal l o @@ -296,7 +295,7 @@ translateExpr expr = confuseGHC >>= \_ -> case expr of f <- substM f' xs <- mapM substM xs' lookupTopFun f >>= \case - DexTopFun _ piTy _ _ -> emitCall piTy f $ toList xs + DexTopFun _ (TopLam _ piTy _) _ -> emitCall piTy f $ toList xs FFITopFun _ _ -> do scalarArgs <- liftM toList $ mapM fromScalarAtom xs results <- impCall f scalarArgs @@ -367,16 +366,16 @@ toImpRefOp refDest' m = do ans <- liftBuilderImp $ emitBlock (sink body') storeAtom accDest ans False -> case accTy of - TabTy d (b:>t) eltTy -> do - let ixTy = IxType t d + TabPi t -> do + let ixTy = tabIxType t n <- indexSetSizeImp ixTy emitLoop noHint Fwd n \i -> do idx <- unsafeFromOrdinalImp (sink ixTy) i xElt <- liftBuilderImp $ tabApp (sink x) (sink idx) yElt <- liftBuilderImp $ tabApp (sink y) (sink idx) - eltTy' <- applySubst (b@>SubstVal idx) eltTy + eltTy <- instantiateTabPiTy (sink t) idx ithDest <- indexDest (sink accDest) idx - liftMonoidCombine ithDest eltTy' (sink bc) xElt yElt + liftMonoidCombine ithDest eltTy (sink bc) xElt yElt _ -> error $ "Base monoid type mismatch: can't lift " ++ pprint baseTy ++ " to " ++ pprint accTy @@ -579,15 +578,15 @@ toImpTypedHof (TypedHof (EffTy _ resultTy') hof) = do alphaEq xTy accTy >>= \case True -> storeAtom accDest x False -> case accTy of - TabTy d (b:>t) eltTy -> do - let ixTy = IxType t d + TabPi t -> do + let ixTy = tabIxType t n <- indexSetSizeImp ixTy emitLoop noHint Fwd n \i -> do idx <- unsafeFromOrdinalImp (sink ixTy) i x' <- sinkM x - eltTy' <- applySubst (b@>SubstVal idx) eltTy + eltTy <- instantiateTabPiTy (sink t) idx ithDest <- indexDest (sink accDest) idx - liftMonoidEmpty ithDest eltTy' x' + liftMonoidEmpty ithDest eltTy x' _ -> error $ "Base monoid type mismatch: can't lift " ++ pprint xTy ++ " to " ++ pprint accTy @@ -1003,11 +1002,11 @@ buildGarbageVal ty = -- === Operations on dests === indexDest :: Emits n => Dest n -> SAtom n -> SubstImpM i n (Dest n) -indexDest (Dest destValTy@(TabTy d (b:>t) eltTy) tree) i = do - eltTy' <- applySubst (b@>SubstVal i) eltTy - ord <- ordinalImp (IxType t d) i - leafTys <- typeToTree destValTy - Dest eltTy' <$> forM (zipTrees leafTys tree) \(leafTy, ptr) -> do +indexDest (Dest (TabPi tabTy) tree) i = do + eltTy <- instantiateTabPiTy tabTy i + ord <- ordinalImp (tabIxType tabTy) i + leafTys <- typeToTree $ TabPi tabTy + Dest eltTy <$> forM (zipTrees leafTys tree) \(leafTy, ptr) -> do BufferType ixStruct _ <- return $ getRefBufferType leafTy offset <- computeOffsetImp ixStruct ord impOffset ptr offset @@ -1027,10 +1026,10 @@ indexRepValParam :: Emits n => SRepVal n -> SAtom n -> (SType n -> SType n) -> (IExpr n -> SubstImpM i n (IExpr n)) -> SubstImpM i n (SRepVal n) -indexRepValParam (RepVal tabTy@(TabPi (TabPiType d (b:>t) eltTy)) vals) i tyFunc func = do - eltTy' <- applySubst (b@>SubstVal i) eltTy - ord <- ordinalImp (IxType t d) i - leafTys <- typeToTree tabTy +indexRepValParam (RepVal (TabPi tabTy) vals) i tyFunc func = do + eltTy <- instantiateTabPiTy tabTy i + ord <- ordinalImp (tabIxType tabTy) i + leafTys <- typeToTree (TabPi tabTy) vals' <- forM (zipTrees leafTys vals) \(leafTy, ptr) -> do BufferPtr (BufferType ixStruct _) <- return $ getIExprInterpretation leafTy offset <- computeOffsetImp ixStruct ord @@ -1042,7 +1041,7 @@ indexRepValParam (RepVal tabTy@(TabPi (TabPiType d (b:>t) eltTy)) vals) i tyFunc _ -> func ptr' -- `func` may have changed the types of the `vals'`. The caller must also -- supply `tyFunc` to reflect that change in the SType. - return $ RepVal (tyFunc eltTy') vals' + return $ RepVal (tyFunc eltTy) vals' indexRepValParam _ _ _ _ = error "expected table type" {-# INLINE indexRepValParam #-} @@ -1161,12 +1160,9 @@ hoistDecls , BindsNames b, BindsEnv b, RenameB b, SinkableB b) => b n l -> SBlock l -> m n (Abs b SBlock n) hoistDecls b block = do - Abs hoistedDecls rest <- liftEnvReaderM $ - refreshAbs (Abs b block) \b' (Block _ decls result) -> + emitDecls =<< liftEnvReaderM do + refreshAbs (Abs b block) \b' (Abs decls result) -> hoistDeclsRec b' Empty decls result - ab <- emitDecls hoistedDecls rest - refreshAbs ab \b'' blockAbs' -> - Abs b'' <$> absToBlockInferringTypes blockAbs' {-# INLINE hoistDecls #-} hoistDeclsRec @@ -1409,30 +1405,27 @@ ordinalImp :: Emits n => IxType SimpIR n -> SAtom n -> SubstImpM i n (IExpr n) ordinalImp (IxType _ dict) i = fromScalarAtom =<< case dict of IxDictRawFin _ -> return i IxDictSpecialized _ d params -> do - SpecializedDict _ (Just fs) <- lookupSpecDict d - appSpecializedIxMethod (fs !! fromEnum Ordinal) (params ++ [i]) + appSpecializedIxMethod d Ordinal (params ++ [i]) unsafeFromOrdinalImp :: Emits n => IxType SimpIR n -> IExpr n -> SubstImpM i n (SAtom n) unsafeFromOrdinalImp (IxType _ dict) i = do let i' = toScalarAtom i case dict of IxDictRawFin _ -> return i' - IxDictSpecialized _ d params -> do - SpecializedDict _ (Just fs) <- lookupSpecDict d - appSpecializedIxMethod (fs !! fromEnum UnsafeFromOrdinal) (params ++ [i']) + IxDictSpecialized _ d params -> + appSpecializedIxMethod d UnsafeFromOrdinal (params ++ [i']) indexSetSizeImp :: Emits n => IxType SimpIR n -> SubstImpM i n (IExpr n) indexSetSizeImp (IxType _ dict) = do - ans <- case dict of + fromScalarAtom =<< case dict of IxDictRawFin n -> return n - IxDictSpecialized _ d params -> do - SpecializedDict _ (Just fs) <- lookupSpecDict d - appSpecializedIxMethod (fs !! fromEnum Size) (params ++ []) - fromScalarAtom ans - -appSpecializedIxMethod :: Emits n => LamExpr SimpIR n -> [SAtom n] -> SubstImpM i n (SAtom n) -appSpecializedIxMethod simpLam args = do - LamExpr bs body <- return simpLam + IxDictSpecialized _ d params -> + appSpecializedIxMethod d Size (params ++ []) + +appSpecializedIxMethod :: Emits n => SpecDictName n -> IxMethod -> [SAtom n] -> SubstImpM i n (SAtom n) +appSpecializedIxMethod d method args = do + SpecializedDict _ (Just fs) <- lookupSpecDict d + TopLam _ _ (LamExpr bs body) <- return $ fs !! fromEnum method dropSubst $ extendSubst (bs @@> map SubstVal args) $ translateBlock body -- === Abstracting link-time objects === @@ -1444,7 +1437,7 @@ abstractLinktimeObjects f = do let allVars = freeVarsE f (funVars, funTys) <- unzip <$> forMFilter (nameSetToList @TopFunNameC allVars) \v -> lookupTopFun v >>= \case - DexTopFun _ piTy _ _ -> do + DexTopFun _ (TopLam _ piTy _) _ -> do ty' <- getImpFunType StandardCC piTy return $ Just (v, ty') FFITopFun _ _ -> return Nothing @@ -1529,7 +1522,7 @@ impInstrTypes instr = case instr of DebugPrint _ _ -> return [] IQueryParallelism _ _ -> return [IIdxRepTy, IIdxRepTy] ICall f _ -> lookupTopFun f >>= \case - DexTopFun _ piTy _ _ -> do + DexTopFun _ (TopLam _ piTy _) _ -> do IFunType _ _ resultTys <- getImpFunType StandardCC piTy return resultTys FFITopFun _ (IFunType _ _ resultTys) -> return resultTys diff --git a/src/lib/ImpToLLVM.hs b/src/lib/ImpToLLVM.hs index 15474cd33..19905424b 100644 --- a/src/lib/ImpToLLVM.hs +++ b/src/lib/ImpToLLVM.hs @@ -501,7 +501,7 @@ compileInstr instr = case instr of return [] RenameOperandSubstVal v -> do lookupTopFun v >>= \case - DexTopFun _ _ _ _ -> error "Imp functions should be abstracted at this point" + DexTopFun _ _ _ -> error "Imp functions should be abstracted at this point" FFITopFun fname ty@(IFunType cc _ impResultTys) -> do let resultTys = map scalarTy impResultTys case cc of diff --git a/src/lib/Inference.hs b/src/lib/Inference.hs index de0d43854..907f375a9 100644 --- a/src/lib/Inference.hs +++ b/src/lib/Inference.hs @@ -10,7 +10,7 @@ module Inference ( inferTopUDecl, checkTopUType, inferTopUExpr - , trySynthTerm, generalizeDict + , trySynthTerm, generalizeDict, asTopBlock , synthTopE, UDeclInferenceResult (..)) where import Prelude hiding ((.), id) @@ -25,7 +25,7 @@ import Data.Foldable (toList, asum) import Data.Functor ((<&>)) import Data.List (sortOn) import Data.Maybe (fromJust, fromMaybe, catMaybes) -import Data.Text.Prettyprint.Doc (Pretty (..), (<+>), vcat) +import Data.Text.Prettyprint.Doc (Pretty (..), (<+>), vcat, group, line, nest) import Data.Word import qualified Data.HashMap.Strict as HM import qualified Data.Map.Strict as M @@ -47,7 +47,8 @@ import QueryType import Types.Core import Types.Primitives import Types.Source -import Util +import Util hiding (group) +import PPrint (prettyBlock) -- === Top-level interface === @@ -55,15 +56,15 @@ checkTopUType :: (Fallible1 m, EnvReader m) => UType n -> m n (CType n) checkTopUType ty = liftInfererM $ solveLocal $ withApplyDefaults $ checkUType ty {-# SCC checkTopUType #-} -inferTopUExpr :: (Fallible1 m, EnvReader m) => UExpr n -> m n (CBlock n) -inferTopUExpr e = liftInfererM do +inferTopUExpr :: (Fallible1 m, EnvReader m) => UExpr n -> m n (TopBlock CoreIR n) +inferTopUExpr e = asTopBlock =<< liftInfererM do solveLocal $ buildBlockInf $ withApplyDefaults $ inferSigma noHint e {-# SCC inferTopUExpr #-} data UDeclInferenceResult e n = UDeclResultDone (e n) -- used for UDataDefDecl, UInterface and UInstance - | UDeclResultBindName LetAnn (CBlock n) (Abs (UBinder (AtomNameC CoreIR)) e n) - | UDeclResultBindPattern NameHint (CBlock n) (ReconAbs CoreIR e n) + | UDeclResultBindName LetAnn (TopBlock CoreIR n) (Abs (UBinder (AtomNameC CoreIR)) e n) + | UDeclResultBindPattern NameHint (TopBlock CoreIR n) (ReconAbs CoreIR e n) inferTopUDecl :: (Mut n, Fallible1 m, TopBuilder m, SinkableE e, HoistableE e, RenameE e) => UTopDecl n l -> e l -> m n (UDeclInferenceResult e n) @@ -73,7 +74,7 @@ inferTopUDecl (UStructDecl tc def) result = do extendRenamer (tc@>sink tc') $ inferStructDef def def'' <- synthTyConDef def' updateTopEnv $ UpdateTyConDef tc' def'' - UStructDef _ paramBs _ methods <- return def + UStructDef _ (_, paramBs) _ methods <- return def forM_ methods \(letAnn, methodName, methodDef) -> do method <- liftInfererM $ solveLocal $ extendRenamer (tc@>sink tc') $ @@ -84,7 +85,7 @@ inferTopUDecl (UStructDecl tc def) result = do UDeclResultDone <$> applyRename (tc @> tc') result inferTopUDecl (UDataDefDecl def tc dcs) result = do tcDef <- liftInfererM $ solveLocal $ inferTyConDef def - tcDef'@(TyConDef _ _ (ADTCons dataCons)) <- synthTyConDef tcDef + tcDef'@(TyConDef _ _ _ (ADTCons dataCons)) <- synthTyConDef tcDef tc' <- emitBinding (getNameHint tcDef') $ TyConBinding (Just tcDef') (DotMethods mempty) dcs' <- forM (enumerate dataCons) \(i, dcDef) -> emitBinding (getNameHint dcDef) $ DataConBinding tc' i @@ -103,14 +104,16 @@ inferTopUDecl (UInterface paramBs methodTys className methodNames) result = do inferTopUDecl (UInstance className instanceBs params methods maybeName expl) result = do let (InternalName _ _ className') = className ab <- liftInfererM $ solveLocal do - withRoleUBinders instanceBs \_ -> do - ClassDef _ _ _ paramBinders _ _ <- lookupClassDef (sink className') - params' <- checkInstanceParams paramBinders params + withRoleUBinders instanceBs do + ClassDef _ _ _ roleExpls paramBinders _ _ <- lookupClassDef (sink className') + let expls = snd <$> roleExpls + params' <- checkInstanceParams expls paramBinders params className'' <- sinkM className' body <- checkInstanceBody className'' params' methods return (ListE params' `PairE` body) Abs bs' (ListE params' `PairE` body) <- return ab - let def = InstanceDef className' bs' params' body + let (roleExpls, bs'') = unzipAttrs bs' + let def = InstanceDef className' roleExpls bs'' params' body UDeclResultDone <$> case maybeName of RightB UnitB -> do void $ synthInstanceDefAndAddSynthCandidate def @@ -129,7 +132,8 @@ inferTopUDecl (ULocalDecl (WithSrcB src decl)) result = addSrcContext src case d WithSrcB _ (UPatBinder b) -> do block <- liftInfererM $ solveLocal $ buildBlockInf do checkMaybeAnnExpr (getNameHint b) tyAnn rhs <* applyDefaults - return $ UDeclResultBindName letAnn block (Abs b result) + topBlock <- asTopBlock block + return $ UDeclResultBindName letAnn topBlock (Abs b result) _ -> do PairE block recon <- liftInfererM $ solveLocal $ buildBlockInfWithRecon do val <- checkMaybeAnnExpr (getNameHint p) tyAnn rhs @@ -137,19 +141,24 @@ inferTopUDecl (ULocalDecl (WithSrcB src decl)) result = addSrcContext src case d bindLetPat p v do applyDefaults renameM result - return $ UDeclResultBindPattern (getNameHint p) block recon + topBlock <- asTopBlock block + return $ UDeclResultBindPattern (getNameHint p) topBlock recon inferTopUDecl (UEffectDecl _ _ _) _ = error "not implemented" inferTopUDecl (UHandlerDecl _ _ _ _ _ _ _) _ = error "not implemented" {-# SCC inferTopUDecl #-} +asTopBlock :: EnvReader m => CBlock n -> m n (TopBlock CoreIR n) +asTopBlock block = do + effTy <- blockEffTy block + return $ TopLam False (PiType Empty effTy) (LamExpr Empty block) + getInstanceType :: EnvReader m => InstanceDef n -> m n (CorePiType n) -getInstanceType (InstanceDef className bs params _) = liftEnvReaderM do +getInstanceType (InstanceDef className roleExpls bs params _) = liftEnvReaderM do refreshAbs (Abs bs (ListE params)) \bs' (ListE params') -> do className' <- sinkM className - ClassDef classSourceName _ _ _ _ _ <- lookupClassDef className' + ClassDef classSourceName _ _ _ _ _ _ <- lookupClassDef className' let dTy = DictTy $ DictType classSourceName className' params' - let bs'' = fmapNest (\(RolePiBinder _ b) -> b) bs' - return $ CorePiType ImplicitApp bs'' $ EffTy Pure dTy + return $ CorePiType ImplicitApp (snd <$> roleExpls) bs' $ EffTy Pure dTy -- === Inferer interface === @@ -170,19 +179,40 @@ class ( MonadFail1 m, Fallible1 m, Catchable1 m, CtxReader1 m, Builder CoreIR m => EmitsInf n => NameHint -> Explicitness -> CType n -> (forall l. (EmitsInf l, DExt n l) => CAtomVar l -> m l (e l)) - -> m n (Abs (WithExpl CBinder) e n) + -> m n (Abs CBinder e n) + +buildAbsInfWithExpl + :: (InfBuilder m, SinkableE e, HoistableE e, RenameE e, SubstE AtomSubstVal e) + => EmitsInf n + => NameHint -> Explicitness -> CType n + -> (forall l. (EmitsInf l, DExt n l) => CAtomVar l -> m l (e l)) + -> m n (Abs (WithExpl CBinder) e n) +buildAbsInfWithExpl hint expl ty cont = do + Abs b e <- buildAbsInf hint expl ty cont + return $ Abs (WithAttrB expl b) e + +buildNaryAbsInfWithExpl + :: (Inferer m, SinkableE e, HoistableE e, RenameE e, SubstE AtomSubstVal e, Inferer m) + => EmitsInf n + => [Explicitness] -> EmptyAbs (Nest CBinder) n + -> (forall l. (EmitsInf l, DExt n l) => [CAtomVar l] -> m i l (e l)) + -> m i n (Abs (Nest (WithExpl CBinder)) e n) +buildNaryAbsInfWithExpl expls bs cont = do + Abs bs' e <- buildNaryAbsInf expls bs cont + return $ Abs (zipAttrs expls bs') e buildNaryAbsInf :: (SinkableE e, HoistableE e, RenameE e, SubstE AtomSubstVal e, Inferer m) => EmitsInf n - => EmptyAbs (Nest (WithExpl CBinder)) n + => [Explicitness] -> EmptyAbs (Nest CBinder) n -> (forall l. (EmitsInf l, DExt n l) => [CAtomVar l] -> m i l (e l)) - -> m i n (Abs (Nest (WithExpl CBinder)) e n) -buildNaryAbsInf (Abs Empty UnitE) cont = getDistinct >>= \Distinct -> Abs Empty <$> cont [] -buildNaryAbsInf (Abs (Nest (WithExpl expl (b:>ty)) bs) UnitE) cont = + -> m i n (Abs (Nest CBinder) e n) +buildNaryAbsInf [] (Abs Empty UnitE) cont = getDistinct >>= \Distinct -> Abs Empty <$> cont [] +buildNaryAbsInf (expl:expls) (Abs (Nest (b:>ty) bs) UnitE) cont = prependAbs <$> buildAbsInf (getNameHint b) expl ty \v -> do bs' <- applyRename (b@>atomVarName v) (Abs bs UnitE) - buildNaryAbsInf bs' \vs -> cont (sink v:vs) + buildNaryAbsInf expls bs' \vs -> cont (sink v:vs) +buildNaryAbsInf _ _ _ = error "zip error" buildDeclsInf :: (SubstE AtomSubstVal e, RenameE e, Solver m, InfBuilder m) @@ -514,7 +544,7 @@ instance InfBuilder (InfererM i) where ++ "\n" ++ pprint infFrag Abs b e <- return ab ty' <- zonk ty - return $ Abs (WithExpl expl (b:>ty')) e + return $ Abs (b:>ty') e dceInfFrag :: (EnvReader m, EnvExtender m, Fallible1 m, RenameE e, HoistableE e) @@ -823,11 +853,12 @@ extendSynthCandidates (Inferred _ (Synth _)) v (Env topEnv (ModuleEnv a b scs)) extendSynthCandidates _ _ env = env {-# INLINE extendSynthCandidates #-} -extendSynthCandidatess :: Distinct n => RolePiBinders n' n -> Env n -> Env n -extendSynthCandidatess (Nest (RolePiBinder _ (WithExpl expl b)) rest) env = - extendSynthCandidatess rest env' - where env' = extendSynthCandidates expl (withExtEvidence rest $ sink $ binderName b) env -extendSynthCandidatess Empty env = env +extendSynthCandidatess :: Distinct n => [Explicitness] -> Nest CBinder n' n -> Env n -> Env n +extendSynthCandidatess (expl:expls) (Nest b bs) env = + extendSynthCandidatess expls bs env' + where env' = extendSynthCandidates expl (withExtEvidence bs $ sink $ binderName b) env +extendSynthCandidatess [] Empty env = env +extendSynthCandidatess _ _ _ = error "zip error" {-# INLINE extendSynthCandidatess #-} -- === actual inference pass === @@ -840,8 +871,8 @@ data RequiredTy (e::E) (n::S) = checkSigma :: EmitsBoth o => NameHint -> UExpr i -> CType o -> InfererM i o (CAtom o) checkSigma hint expr sTy = confuseGHC >>= \_ -> case sTy of - Pi piTy@(CorePiType _ bs _) -> do - if all (== Explicit) (nestToList getExpl bs) + Pi piTy@(CorePiType _ expls _ _) -> do + if all (== Explicit) expls then fallback else case expr of WithSrcE src (ULam lam) -> addSrcContext src $ Lam <$> checkULam lam piTy @@ -941,7 +972,8 @@ checkOrInferRho hint uExprWithSrc@(WithSrcE pos expr) reqTy = do -- TODO: check explicitness constraints ab <- withUBinders bs \_ -> EffTy <$> checkUEffRow effs <*> checkUType ty Abs bs' effTy' <- return ab - matchRequirement $ Type $ Pi $ CorePiType appExpl bs' effTy' + let (expls, bs'') = unzipAttrs bs' + matchRequirement $ Type $ Pi $ CorePiType appExpl expls bs'' effTy' UTabPi (UTabPiExpr (UAnnBinder b ann cs) ty) -> do unless (null cs) $ throw TypeErr "`=>` shouldn't have constraints" ann' <- asIxType =<< checkAnn (getSourceName b) ann @@ -1149,11 +1181,11 @@ getFieldDefs ty = case ty of instantiateSigma :: forall i o. EmitsBoth o => SigmaAtom o -> InfererM i o (CAtom o) instantiateSigma sigmaAtom = case getType sigmaAtom of - Pi piTy@(CorePiType ExplicitApp _ _) -> do + Pi piTy@(CorePiType ExplicitApp _ _ _) -> do Lam <$> etaExpandExplicits fDesc piTy \args -> applySigmaAtom (sink sigmaAtom) args - Pi (CorePiType ImplicitApp bs (EffTy _ resultTy)) -> do - args <- inferMixedArgs @UExpr fDesc (Abs bs resultTy) [] [] + Pi (CorePiType ImplicitApp expls bs (EffTy _ resultTy)) -> do + args <- inferMixedArgs @UExpr fDesc expls (Abs bs resultTy) [] [] applySigmaAtom sigmaAtom args DepPairTy (DepPairType ImplicitDepPair _ _) -> -- TODO: we should probably call instantiateSigma again here in case @@ -1190,53 +1222,55 @@ etaExpandExplicits :: EmitsInf o => SourceName -> CorePiType o -> (forall o'. (EmitsBoth o', DExt o o') => [CAtom o'] -> InfererM i o' (CAtom o')) -> InfererM i o (CoreLamExpr o) -etaExpandExplicits fSourceName (CorePiType _ bsTop (EffTy effs _)) contTop = do - ab <- go bsTop \xs -> do +etaExpandExplicits fSourceName (CorePiType _ explsTop bsTop (EffTy effs _)) contTop = do + Abs bs body <- go explsTop bsTop \xs -> do effs' <- applySubst (bsTop@@>(SubstVal<$>xs)) effs withAllowedEffects effs' do body <- buildBlockInf $ contTop $ sinkList xs return $ PairE effs' body - coreLamExpr ExplicitApp ab + let (expls, bs') = unzipAttrs bs + coreLamExpr ExplicitApp expls $ Abs bs' body where go :: (EmitsInf o, SinkableE e, RenameE e, SubstE AtomSubstVal e, HoistableE e ) - => Nest (WithExpl CBinder) o any + => [Explicitness] -> Nest CBinder o any -> (forall o'. (EmitsInf o', DExt o o') => [CAtom o'] -> InfererM i o' (e o')) -> InfererM i o (Abs (Nest (WithExpl CBinder)) e o) - go Empty cont = getDistinct >>= \Distinct -> Abs Empty <$> cont [] - go (Nest (WithExpl expl (b:>ty)) rest) cont = case expl of + go [] Empty cont = getDistinct >>= \Distinct -> Abs Empty <$> cont [] + go (expl:expls) (Nest (b:>ty) rest) cont = case expl of Explicit -> do - prependAbs <$> buildAbsInf (getNameHint b) expl ty \v -> do + prependAbs <$> buildAbsInfWithExpl (getNameHint b) expl ty \v -> do Abs rest' UnitE <- applyRename (b@>atomVarName v) $ Abs rest UnitE - go rest' \args -> cont (sink (Var v) : args) + go expls rest' \args -> cont (sink (Var v) : args) Inferred argSourceName infMech -> do arg <- getImplicitArg (fSourceName, fromMaybe "_" argSourceName) infMech ty Abs rest' UnitE <- applySubst (b@>SubstVal arg) $ Abs rest UnitE - go rest' \args -> cont (sink arg : args) + go expls rest' \args -> cont (sink arg : args) + go _ _ _ = error "zip error" buildLamInf :: EmitsInf o => CorePiType o -> (forall o' . (EmitsBoth o', DExt o o') => [(Explicitness, CAtom o')] -> CType o' -> InfererM i o' (CAtom o')) -> InfererM i o (CoreLamExpr o) -buildLamInf (CorePiType appExpl bsTop effTy) contTop = do - ab <- go bsTop \xs -> do +buildLamInf (CorePiType appExpl explsTop bsTop effTy) contTop = do + ab <- go explsTop bsTop \xs -> do let (expls, xs') = unzip xs EffTy effs' resultTy' <- applySubst (bsTop@@>(SubstVal<$>xs')) effTy withAllowedEffects effs' do body <- buildBlockInf $ contTop (zip expls $ sinkList xs') (sink resultTy') return $ PairE effs' body - coreLamExpr appExpl ab + coreLamExpr appExpl explsTop ab where go :: (EmitsInf o, HoistableE e, SinkableE e, SubstE AtomSubstVal e, RenameE e) - => Nest (WithExpl CBinder) o any - -> (forall o'. (EmitsInf o', DExt o o') - => [(Explicitness, CAtom o')] -> InfererM i o' (e o')) - -> InfererM i o (Abs (Nest (WithExpl CBinder)) e o) - go Empty cont = getDistinct >>= \Distinct -> Abs Empty <$> cont [] - go (Nest (WithExpl expl b) rest) cont = do + => [Explicitness] -> Nest CBinder o any + -> (forall o'. (EmitsInf o', DExt o o') => [(Explicitness, CAtom o')] -> InfererM i o' (e o')) + -> InfererM i o (Abs (Nest CBinder) e o) + go [] Empty cont = getDistinct >>= \Distinct -> Abs Empty <$> cont [] + go (expl:expls) (Nest b rest) cont = do prependAbs <$> buildAbsInf (getNameHint b) expl (binderType b) \v -> do Abs rest' UnitE <- applyRename (b@>atomVarName v) $ Abs rest UnitE - go rest' \args -> cont ((expl, sink (Var v)) : args) + go expls rest' \args -> cont $ (expl, sink $ Var v) : args + go _ _ _ = error "zip error" class ExplicitArg (e::E) where checkExplicitArg :: EmitsBoth o => IsDependent -> e i -> CType o -> InfererM i o (CAtom o) @@ -1266,14 +1300,14 @@ checkOrInferApp checkOrInferApp f' posArgs namedArgs reqTy = do f <- maybeInterpretPunsAsTyCons reqTy f' case getType f of - Pi (CorePiType appExpl bs effTy) -> case appExpl of + Pi (CorePiType appExpl expls bs effTy) -> case appExpl of ExplicitApp -> do - checkArity bs posArgs - args' <- inferMixedArgs fDesc (Abs bs effTy) posArgs namedArgs + checkArity expls posArgs + args' <- inferMixedArgs fDesc expls (Abs bs effTy) posArgs namedArgs applySigmaAtom f args' >>= matchRequirement ImplicitApp -> do -- TODO: should this already have been done by the time we get `f`? - implicitArgs <- inferMixedArgs @UExpr fDesc (Abs bs effTy) [] [] + implicitArgs <- inferMixedArgs @UExpr fDesc expls (Abs bs effTy) [] [] f'' <- SigmaAtom (Just fDesc) <$> applySigmaAtom f implicitArgs checkOrInferApp f'' posArgs namedArgs Infer >>= matchRequirement -- TODO: special-case error for when `fTy` can't possibly be a function @@ -1320,24 +1354,24 @@ applySigmaAtom (SigmaUVar _ _ f) args = case f of f'' <- toAtomVar f' emitExprWithEffects =<< mkApp (Var f'') args UTyConVar f' -> do - TyConDef sn bs _ <- lookupTyCon f' - let expls = nestToList (\(RolePiBinder _ (WithExpl expl _)) -> expl) bs + TyConDef sn roleExpls _ _ <- lookupTyCon f' + let expls = snd <$> roleExpls return $ Type $ NewtypeTyCon $ UserADTType sn f' (TyConParams expls args) UDataConVar v -> do (tyCon, i) <- lookupDataCon v applyDataCon tyCon i args UPunVar tc -> do - TyConDef sn _ _ <- lookupTyCon tc + TyConDef sn _ _ _ <- lookupTyCon tc -- interpret as a data constructor by default (params, dataArgs) <- splitParamPrefix tc args repVal <- makeStructRepVal tc dataArgs return $ NewtypeCon (UserADTData sn tc params) repVal UClassVar f' -> do - ClassDef sourceName _ _ _ _ _ <- lookupClassDef f' + ClassDef sourceName _ _ _ _ _ _ <- lookupClassDef f' return $ Type $ DictTy $ DictType sourceName f' args UMethodVar f' -> do MethodBinding className methodIdx <- lookupEnv f' - ClassDef _ _ _ paramBs _ _ <- lookupClassDef className + ClassDef _ _ _ _ paramBs _ _ <- lookupClassDef className let numParams = nestLength paramBs -- params aren't needed because they're already implied by the dict argument let (dictArg:args') = drop numParams args @@ -1349,14 +1383,14 @@ applySigmaAtom (SigmaPartialApp _ f prevArgs) args = splitParamPrefix :: EnvReader m => TyConName n -> [CAtom n] -> m n (TyConParams n, [CAtom n]) splitParamPrefix tc args = do - TyConDef _ paramBs _ <- lookupTyCon tc + TyConDef _ _ paramBs _ <- lookupTyCon tc let (paramArgs, dataArgs) = splitAt (nestLength paramBs) args params <- makeTyConParams tc paramArgs return (params, dataArgs) applyDataCon :: Emits o => TyConName o -> Int -> [CAtom o] -> InfererM i o (CAtom o) applyDataCon tc conIx topArgs = do - tyDef@(TyConDef sn _ _) <- lookupTyCon tc + tyDef@(TyConDef sn _ _ _) <- lookupTyCon tc (params, dataArgs) <- splitParamPrefix tc topArgs ADTCons conDefs <- instantiateTyConDef tyDef params DataConDef _ _ repTy _ <- return $ conDefs !! conIx @@ -1390,9 +1424,9 @@ emitExprWithEffects expr = do addEffects $ getEffects expr emitExpr expr -checkArity :: BindsNames b => Nest (WithExpl b) n l -> [a] -> InfererM i o () -checkArity bs args = do - let arity = length [() | Explicit <- nestToList (\(WithExpl expl _) -> expl) bs] +checkArity :: [Explicitness] -> [a] -> InfererM i o () +checkArity expls args = do + let arity = length [() | Explicit <- expls] let numArgs = length args when (numArgs /= arity) do throw TypeErr $ "Wrong number of positional arguments provided. Expected " ++ @@ -1402,24 +1436,25 @@ checkArity bs args = do inferMixedArgs :: forall arg i o e . (ExplicitArg arg, EmitsBoth o, SubstE (SubstVal Atom) e, SinkableE e, HoistableE e) - => SourceName - -> Abs (Nest (WithExpl CBinder)) e o -> [arg i] -> [(SourceName, arg i)] + => SourceName -> [Explicitness] + -> Abs (Nest CBinder) e o -> [arg i] -> [(SourceName, arg i)] -> InfererM i o [CAtom o] -inferMixedArgs fSourceName bsAbs posArgs namedArgs = do - checkNamedArgValidity bsAbs (map fst namedArgs) - liftM fst $ runStreamReaderT1 posArgs $ go bsAbs +inferMixedArgs fSourceName explsTop bsAbs posArgs namedArgs = do + checkNamedArgValidity explsTop (map fst namedArgs) + liftM fst $ runStreamReaderT1 posArgs $ go explsTop bsAbs where go :: (EmitsBoth o, SubstE (SubstVal Atom) e, SinkableE e, HoistableE e) - => Abs (Nest (WithExpl CBinder)) e o + => [Explicitness] -> Abs (Nest CBinder) e o -> StreamReaderT1 (arg i) (InfererM i) o [CAtom o] - go (Abs Empty _) = return [] - go (Abs (Nest (WithExpl expl b) bs) result) = do + go [] (Abs Empty _) = return [] + go (expl:expls) (Abs (Nest b bs) result) = do let rest = Abs bs result let isDependent = binderName b `isFreeIn` rest arg <- inferMixedArg isDependent (binderType b) expl arg' <- lift11 $ zonk arg rest' <- applySubst (b @> SubstVal arg') rest - (arg:) <$> go rest' + (arg:) <$> go expls rest' + go _ _ = error "zip error" inferMixedArg :: EmitsBoth o => IsDependent -> CType o -> Explicitness -> StreamReaderT1 (arg i) (InfererM i) o (CAtom o) @@ -1437,12 +1472,12 @@ inferMixedArgs fSourceName bsAbs posArgs namedArgs = do lookupNamedArg Nothing = Nothing lookupNamedArg (Just v) = lookup v namedArgs -checkNamedArgValidity :: (BindsNames b, Fallible m) => Abs (Nest (WithExpl b)) e any -> [SourceName] -> m () -checkNamedArgValidity (Abs bs _) offeredNames = do +checkNamedArgValidity :: Fallible m => [Explicitness] -> [SourceName] -> m () +checkNamedArgValidity expls offeredNames = do let explToMaybeName = \case Explicit -> Nothing Inferred v _ -> v - let acceptedNames = catMaybes $ nestToList (explToMaybeName . getExpl) bs + let acceptedNames = catMaybes $ map explToMaybeName expls let duplicates = repeated offeredNames when (not $ null duplicates) do throw TypeErr $ "Repeated names offered" ++ pprint duplicates @@ -1454,7 +1489,8 @@ checkNamedArgValidity (Abs bs _) offeredNames = do inferPrimArg :: EmitsBoth o => UExpr i -> InfererM i o (CAtom o) inferPrimArg x = do xBlock <- buildBlockInf $ inferRho noHint x - case getType xBlock of + EffTy _ ty <- blockEffTy xBlock + case ty of TyKind -> cheapReduce xBlock >>= \case Just reduced -> return reduced _ -> throw CompilerErr "Type args to primops must be reducible" @@ -1609,7 +1645,7 @@ buildSortedCase scrut alts resultTy = do scrutTy <- return $ getType scrut case scrutTy of TypeCon _ defName _ -> do - TyConDef _ _ (ADTCons cons) <- lookupTyCon defName + TyConDef _ _ _ (ADTCons cons) <- lookupTyCon defName case cons of [] -> error "case of void?" -- Single constructor ADTs are not sum types, so elide the case. @@ -1622,14 +1658,13 @@ buildSortedCase scrut alts resultTy = do -- TODO: cache this with the instance def (requires a recursive binding) instanceFun :: EnvReader m => InstanceName n -> AppExplicitness -> m n (CAtom n) -instanceFun instanceName expl = do - InstanceDef _ bs _ _ <- lookupInstanceDef instanceName +instanceFun instanceName appExpl = do + InstanceDef _ expls bs _ _ <- lookupInstanceDef instanceName ab <- liftEnvReaderM $ refreshAbs (Abs bs UnitE) \bs' UnitE -> do args <- mapM toAtomVar $ nestToNames bs' - let bs'' = fmapNest (\(RolePiBinder _ b) -> b) bs' result <- mkDictAtom $ InstanceDict (sink instanceName) (Var <$> args) - return $ Abs bs'' (PairE Pure (AtomicBlock result)) - Lam <$> coreLamExpr expl ab + return $ Abs bs' (PairE Pure (WithoutDecls result)) + Lam <$> coreLamExpr appExpl (snd<$>expls) ab checkMaybeAnnExpr :: EmitsBoth o => NameHint -> Maybe (UType i) -> UExpr i -> InfererM i o (CAtom o) @@ -1655,53 +1690,55 @@ inferRole ty = \case inferTyConDef :: EmitsInf o => UDataDef i -> InfererM i o (TyConDef o) inferTyConDef (UDataDef tyConName paramBs dataCons) = do Abs paramBs' dataCons' <- - withRoleUBinders paramBs \_ -> do + withRoleUBinders paramBs do ADTCons <$> mapM inferDataCon dataCons - return (TyConDef tyConName paramBs' dataCons') + let (roleExpls, paramBs'') = unzipAttrs paramBs' + return (TyConDef tyConName roleExpls paramBs'' dataCons') inferStructDef :: EmitsInf o => UStructDef i -> InfererM i o (TyConDef o) inferStructDef (UStructDef tyConName paramBs fields _) = do let (fieldNames, fieldTys) = unzip fields - Abs paramBs' dataConDefs <- withRoleUBinders paramBs \_ -> do + Abs paramBs' dataConDefs <- withRoleUBinders paramBs do tys <- mapM checkUType fieldTys return $ StructFields $ zip fieldNames tys - return $ TyConDef tyConName paramBs' dataConDefs + let (roleExpls, paramBs'') = unzipAttrs paramBs' + return $ TyConDef tyConName roleExpls paramBs'' dataConDefs inferDotMethod :: EmitsInf o => TyConName o - -> Abs (Nest (WithExpl UOptAnnBinder)) (Abs UAtomBinder ULamExpr) i + -> Abs (Nest UOptAnnBinder) (Abs UAtomBinder ULamExpr) i -> InfererM i o (CoreLamExpr o) inferDotMethod tc (Abs uparamBs (Abs selfB lam)) = do - TyConDef sn paramBs _ <- lookupTyCon tc - let paramBs' = fmapNest (\(RolePiBinder _ b) -> b) paramBs - ab <- buildNaryAbsInf (Abs paramBs' UnitE) \paramVs -> do - let expls = nestToList (\(WithExpl expl _) -> expl) paramBs' + TyConDef sn roleExpls paramBs _ <- lookupTyCon tc + let expls = snd <$> roleExpls + ab <- buildNaryAbsInfWithExpl expls (Abs paramBs UnitE) \paramVs -> do let paramVs' = catMaybes $ zip expls paramVs <&> \(expl, v) -> case expl of Inferred _ (Synth _) -> Nothing _ -> Just v extendRenamer (uparamBs @@> (atomVarName <$> paramVs')) do let selfTy = NewtypeTyCon $ UserADTType sn (sink tc) (TyConParams expls (Var <$> paramVs)) - buildAbsInf "self" Explicit selfTy \vSelf -> + buildAbsInfWithExpl "self" Explicit selfTy \vSelf -> extendRenamer (selfB @> atomVarName vSelf) $ inferULam lam Abs paramBs'' (Abs selfB' lam') <- return ab return $ prependCoreLamExpr (paramBs'' >>> UnaryNest selfB') lam' prependCoreLamExpr :: Nest (WithExpl CBinder) n l -> CoreLamExpr l -> CoreLamExpr n prependCoreLamExpr bs e = case e of - CoreLamExpr (CorePiType appExpl piBs effTy) (LamExpr lamBs body) -> do - let piType = CorePiType appExpl (bs >>> piBs) effTy - let lamExpr = LamExpr (fmapNest withoutExpl bs >>> lamBs) body + CoreLamExpr (CorePiType appExpl piExpls piBs effTy) (LamExpr lamBs body) -> do + let (expls, bs') = unzipAttrs bs + let piType = CorePiType appExpl (expls <> piExpls) (bs' >>> piBs) effTy + let lamExpr = LamExpr (fmapNest withoutAttr bs >>> lamBs) body CoreLamExpr piType lamExpr inferDataCon :: EmitsInf o => (SourceName, UDataDefTrail i) -> InfererM i o (DataConDef o) inferDataCon (sourceName, UDataDefTrail argBs) = do - let argBsExpls = addExpls Explicit argBs - Abs argBs' UnitE <- withUBinders argBsExpls \_ -> return UnitE - let argBs'' = Abs (fmapNest withoutExpl argBs') UnitE + let expls = nestToList (const Explicit) argBs + Abs argBs' UnitE <- withUBinders (expls, argBs) \_ -> return UnitE + let argBs'' = Abs (fmapNest withoutAttr argBs') UnitE let (repTy, projIdxs) = dataConRepTy argBs'' return $ DataConDef sourceName argBs'' repTy projIdxs -dataConRepTy :: EmptyAbs (Nest (Binder CoreIR)) n -> (CType n, [[Projection]]) +dataConRepTy :: EmptyAbs (Nest CBinder) n -> (CType n, [[Projection]]) dataConRepTy (Abs topBs UnitE) = case topBs of Empty -> (UnitTy, []) _ -> go [] [UnwrapNewtype] topBs @@ -1729,47 +1766,49 @@ dataConRepTy (Abs topBs UnitE) = case topBs of inferClassDef :: EmitsInf o => SourceName -> [SourceName] - -> Nest (WithExpl UOptAnnBinder) i i' + -> UOptAnnExplBinders i i' -> [UType i'] -> InfererM i o (ClassDef o) -inferClassDef className methodNames paramBs methods = do +inferClassDef className methodNames paramBs@(expls, paramBs') methods = do + let paramBsWithAttrBs = zipWithNest paramBs' expls \b expl -> WithAttrB expl b let paramNames = catMaybes $ nestToList - (\(WithExpl expl (UAnnBinder b _ _)) -> case expl of + (\(WithAttrB expl (UAnnBinder b _ _)) -> case expl of Inferred _ (Synth _) -> Nothing - _ -> Just $ Just $ getSourceName b) paramBs - ab <- withRoleUBinders paramBs \_ -> do + _ -> Just $ Just $ getSourceName b) paramBsWithAttrBs + ab <- withRoleUBinders paramBs do ListE <$> forM methods \m -> do checkUType m >>= \case Pi t -> return t - t -> return $ CorePiType ImplicitApp Empty (EffTy Pure t) + t -> return $ CorePiType ImplicitApp [] Empty (EffTy Pure t) Abs (PairB bs scs) (ListE mtys) <- identifySuperclasses ab - return $ ClassDef className methodNames paramNames bs scs mtys + let (roleExpls, bs') = unzipAttrs bs + return $ ClassDef className methodNames paramNames roleExpls bs' scs mtys --- TODO: this is just partitioning the binders. We could write a more general function like this: --- partitionBinders :: Nest b n l -> (forall n l. b i i' -> EitherB b1 b2 i i') --- -> Except (PairB (Nest b1) (Nest b2)) n l identifySuperclasses - :: RenameE e => Abs RolePiBinders e n - -> InfererM i n (Abs (PairB RolePiBinders (Nest CBinder)) e n) -identifySuperclasses ab = refreshAbs ab \bs e -> do - bs' <- partitionBinders bs \b@(RolePiBinder _ (WithExpl expl b')) -> case expl of - Explicit -> return $ LeftB b - Inferred _ Unify -> throw TypeErr "Interfaces can't have implicit parameters" - Inferred _ (Synth _) -> return $ RightB b' - return $ Abs bs' e + :: RenameE e => Abs (Nest (WithRoleExpl CBinder)) e n + -> InfererM i n (Abs (PairB (Nest (WithRoleExpl CBinder)) (Nest CBinder)) e n) +identifySuperclasses ab = do + refreshAbs ab \bs e -> do + bs' <- partitionBinders bs \b@(WithAttrB (_, expl) b') -> case expl of + Explicit -> return $ LeftB b + Inferred _ Unify -> throw TypeErr "Interfaces can't have implicit parameters" + Inferred _ (Synth _) -> return $ RightB b' + return $ Abs bs' e withUBinders :: (EmitsInf o, HasNamesE e, SubstE AtomSubstVal e, SinkableE e) - => Nest (WithExpl (UAnnBinder req)) i i' + => UAnnExplBinders req i i' -> (forall o'. (EmitsInf o', DExt o o') => [CAtomVar o'] -> InfererM i' o' (e o')) -> InfererM i o (Abs (Nest (WithExpl CBinder)) e o) withUBinders bs cont = case bs of - Empty -> getDistinct >>= \Distinct -> Abs Empty <$> cont [] - Nest (WithExpl expl (UAnnBinder b ann cs)) rest -> do + ([], Empty) -> getDistinct >>= \Distinct -> Abs Empty <$> cont [] + (expl:expls, Nest (UAnnBinder b ann cs) rest) -> do ann' <- checkAnn (getSourceName b) ann - prependAbs <$> buildAbsInf (getNameHint b) expl ann' \v -> + prependAbs <$> buildAbsInfWithExpl (getNameHint b) expl ann' \v -> concatAbs <$> withConstraintBinders cs v do - extendSubst (b@>sink (atomVarName v)) $ withUBinders rest \vs -> cont (sink v : vs) + extendSubst (b@>sink (atomVarName v)) $ withUBinders (expls, rest) \vs -> + cont (sink v : vs) + _ -> error "zip error" withConstraintBinders :: (EmitsInf o, HasNamesE e, SubstE AtomSubstVal e, RenameE e, SinkableE e) @@ -1782,24 +1821,26 @@ withConstraintBinders (c:cs) v cont = do Type dictTy <- withReducibleEmissions "Can't reduce interface constraint" do c' <- inferWithoutInstantiation c >>= zonk dropSubst $ checkOrInferApp c' [Var $ sink v] [] (Check TyKind) - prependAbs <$> buildAbsInf "d" (Inferred Nothing (Synth Full)) dictTy \_ -> + prependAbs <$> buildAbsInfWithExpl "d" (Inferred Nothing (Synth Full)) dictTy \_ -> withConstraintBinders cs (sink v) cont withRoleUBinders :: forall i i' o e req. (EmitsInf o, HasNamesE e, SubstE AtomSubstVal e, SinkableE e) - => Nest (WithExpl (UAnnBinder req)) i i' - -> (forall o'. (EmitsInf o', DExt o o') => [CAtomVar o'] -> InfererM i' o' (e o')) - -> InfererM i o (Abs RolePiBinders e o) -withRoleUBinders bs cont = case bs of - Empty -> getDistinct >>= \Distinct -> Abs Empty <$> cont [] - Nest (WithExpl expl (UAnnBinder b ann cs)) rest -> do + => UAnnExplBinders req i i' + -> (forall o'. (EmitsInf o', DExt o o') => InfererM i' o' (e o')) + -> InfererM i o (Abs (Nest (WithRoleExpl CBinder)) e o) +withRoleUBinders roleBs cont = case roleBs of + ([], Empty) -> getDistinct >>= \Distinct -> Abs Empty <$> cont + (expl:expls, Nest (UAnnBinder b ann cs) rest) -> do ann' <- checkAnn (getSourceName b) ann Abs b' (Abs bs' e) <- buildAbsInf (getNameHint b) expl ann' \v -> do Abs ds (Abs bs' e) <- withConstraintBinders cs v $ - extendSubst (b@>sink (atomVarName v)) $ withRoleUBinders rest \vs -> cont (sink v : vs) - return $ Abs (fmapNest (RolePiBinder DictParam) ds >>> bs') e + extendSubst (b@>sink (atomVarName v)) $ withRoleUBinders (expls, rest) cont + let ds' = fmapNest (\(WithAttrB expl' b') -> WithAttrB (DictParam, expl') b') ds + return $ Abs (ds' >>> bs') e role <- inferRole (binderType b') expl - return $ Abs (Nest (RolePiBinder role b') bs') e + return $ Abs (Nest (WithAttrB (role,expl) b') bs') e + _ -> error "zip error" inferULam :: EmitsInf o => ULamExpr i -> InfererM i o (CoreLamExpr o) inferULam (ULamExpr bs appExpl effs resultTy body) = do @@ -1814,12 +1855,13 @@ inferULam (ULamExpr bs appExpl effs resultTy body) = do checkSigma noHint result (sink resultTy'') return (PairE effs' body') Abs bs' (PairE effs' body') <- return ab + let (expls, bs'') = unzipAttrs bs' case appExpl of - ImplicitApp -> checkImplicitLamRestrictions bs' effs' + ImplicitApp -> checkImplicitLamRestrictions bs'' effs' ExplicitApp -> return () - coreLamExpr appExpl $ Abs bs' $ PairE effs' body' + coreLamExpr appExpl expls $ Abs bs'' $ PairE effs' body' -checkImplicitLamRestrictions :: Nest (WithExpl CBinder) o o' -> EffectRow CoreIR o' -> InfererM i o () +checkImplicitLamRestrictions :: Nest CBinder o o' -> EffectRow CoreIR o' -> InfererM i o () checkImplicitLamRestrictions _ _ = return () -- TODO checkUForExpr :: EmitsBoth o => UForExpr i -> TabPiType CoreIR o -> InfererM i o (LamExpr CoreIR o) @@ -1836,7 +1878,7 @@ checkUForExpr (UForExpr (UAnnBinder bFor ann cs) body) tabPi@(TabPiType _ bPi _) buildBlockInf do withBlockDecls body \result -> checkSigma noHint result $ sink resultTy' - return $ LamExpr (UnaryNest $ withoutExpl b) body' + return $ LamExpr (UnaryNest b) body' inferUForExpr :: EmitsBoth o => UForExpr i -> InfererM i o (LamExpr CoreIR o) inferUForExpr (UForExpr (UAnnBinder bFor ann cs) body) = do @@ -1846,15 +1888,15 @@ inferUForExpr (UForExpr (UAnnBinder bFor ann cs) body) = do extendRenamer (bFor@>atomVarName i) $ buildBlockInf $ withBlockDecls body \result -> checkOrInferRho noHint result Infer - return $ LamExpr (UnaryNest $ withoutExpl b) body' + return $ LamExpr (UnaryNest b) body' checkULam :: EmitsInf o => ULamExpr i -> CorePiType o -> InfererM i o (CoreLamExpr o) -checkULam (ULamExpr lamBs lamAppExpl lamEffs lamResultTy body) - (CorePiType piAppExpl piBs effTy) = do - checkArity piBs (nestToList (const ()) lamBs) +checkULam (ULamExpr (_, lamBs) lamAppExpl lamEffs lamResultTy body) + (CorePiType piAppExpl expls piBs effTy) = do + checkArity expls (nestToList (const ()) lamBs) when (piAppExpl /= lamAppExpl) $ throw TypeErr $ "Wrong arrow. Expected " ++ pprint piAppExpl ++ " got " ++ pprint lamAppExpl - ab <- checkLamBinders piBs lamBs \vs -> do + Abs explBs body' <- checkLamBinders expls piBs lamBs \vs -> do EffTy piEffs' piResultTy' <- applyRename (piBs@@>map atomVarName vs) effTy case lamResultTy of Nothing -> return () @@ -1868,47 +1910,44 @@ checkULam (ULamExpr lamBs lamAppExpl lamEffs lamResultTy body) withBlockDecls body \result -> checkSigma noHint result piResultTy'' return $ PairE piEffs' body' - coreLamExpr piAppExpl ab + let (expls', bs') = unzipAttrs explBs + coreLamExpr piAppExpl expls' $ Abs bs' body' checkLamBinders :: (EmitsInf o, SinkableE e, HoistableE e, SubstE AtomSubstVal e, RenameE e) - => Nest (WithExpl CBinder) o any - -> Nest (WithExpl UOptAnnBinder) i i' + => [Explicitness] -> Nest CBinder o any + -> Nest UOptAnnBinder i i' -> (forall o'. (EmitsInf o', DExt o o') => [CAtomVar o'] -> InfererM i' o' (e o')) -> InfererM i o (Abs (Nest (WithExpl CBinder)) e o) -checkLamBinders Empty Empty cont = getDistinct >>= \Distinct -> Abs Empty <$> cont [] -checkLamBinders (Nest (WithExpl piExpl (piB:>piAnn)) piBs) lamBs cont = do +checkLamBinders [] Empty Empty cont = getDistinct >>= \Distinct -> Abs Empty <$> cont [] +checkLamBinders (piExpl:piExpls) (Nest (piB:>piAnn) piBs) lamBs cont = do prependAbs <$> case piExpl of Inferred _ _ -> - buildAbsInf (getNameHint piB) piExpl piAnn \v -> do + buildAbsInfWithExpl (getNameHint piB) piExpl piAnn \v -> do Abs piBs' UnitE <- applyRename (piB@>atomVarName v) $ Abs piBs UnitE - checkLamBinders piBs' lamBs \vs -> + checkLamBinders piExpls piBs' lamBs \vs -> cont (sink v:vs) Explicit -> case lamBs of - Nest (WithExpl Explicit (UAnnBinder lamB ann cs)) lamBsRest -> do + Nest (UAnnBinder lamB ann cs) lamBsRest -> do case ann of UAnn lamAnn -> checkUType lamAnn >>= constrainTypesEq piAnn UNoAnn -> return () - buildAbsInf (getNameHint lamB) Explicit piAnn \v -> do + buildAbsInfWithExpl (getNameHint lamB) Explicit piAnn \v -> do concatAbs <$> withConstraintBinders cs v do Abs piBs' UnitE <- applyRename (piB@>sink (atomVarName v)) $ Abs piBs UnitE - extendRenamer (lamB@>sink (atomVarName v)) $ checkLamBinders piBs' lamBsRest \vs -> + extendRenamer (lamB@>sink (atomVarName v)) $ checkLamBinders piExpls piBs' lamBsRest \vs -> cont (sink v:vs) - Nest (WithExpl (Inferred _ _) _) _ -> - -- TODO(dougalm): I don't think this case is reachable, but if it is - -- then we can check for it in `checkULam` and fall back to `inferULam`. - error "shouldn't be able to check lambda terms with implicit binders" Empty -> error "zip error" -checkLamBinders _ _ _ = error "zip error" +checkLamBinders _ _ _ _ = error "zip error" -checkInstanceParams :: EmitsInf o => RolePiBinders o any -> [UExpr i] -> InfererM i o [CAtom o] -checkInstanceParams bsTop paramsTop = do - checkArity (fmapNest (\(RolePiBinder _ b) -> b) bsTop) paramsTop +checkInstanceParams :: EmitsInf o => [Explicitness] -> Nest CBinder o any -> [UExpr i] -> InfererM i o [CAtom o] +checkInstanceParams expls bsTop paramsTop = do + checkArity expls paramsTop go bsTop paramsTop where - go :: EmitsInf o => Nest RolePiBinder o any -> [UExpr i] -> InfererM i o [CAtom o] + go :: EmitsInf o => Nest CBinder o any -> [UExpr i] -> InfererM i o [CAtom o] go Empty [] = return [] - go (Nest (RolePiBinder _ (WithExpl _ (b:>ty))) bs) (x:xs) = do + go (Nest (b:>ty) bs) (x:xs) = do x' <- checkUParam ty x Abs bs' UnitE <- applySubst (b@>SubstVal x') $ Abs bs UnitE (x':) <$> go bs' xs @@ -1918,7 +1957,7 @@ checkInstanceBody :: EmitsInf o => ClassName o -> [CAtom o] -> [UMethodDef i] -> InfererM i o (InstanceBody o) checkInstanceBody className params methods = do - ClassDef _ methodNames _ paramBs scBs methodTys <- lookupClassDef className + ClassDef _ methodNames _ _ paramBs scBs methodTys <- lookupClassDef className Abs scBs' methodTys' <- applySubst (paramBs @@> (SubstVal <$> params)) $ Abs scBs $ ListE methodTys superclassTys <- superclassDictTys scBs' superclassDicts <- mapM (flip trySynthTerm Full) superclassTys @@ -1943,7 +1982,7 @@ checkMethodDef className methodTys (WithSrcE src m) = addSrcContext src do UMethodDef ~(InternalName _ sourceName v) rhs <- return m MethodBinding className' i <- renameM v >>= lookupEnv when (className /= className') do - ClassBinding (ClassDef classSourceName _ _ _ _ _) <- lookupEnv className + ClassBinding (ClassDef classSourceName _ _ _ _ _ _) <- lookupEnv className throw TypeErr $ pprint sourceName ++ " is not a method of " ++ pprint classSourceName (i,) <$> Lam <$> checkULam rhs (methodTys !! i) @@ -1966,7 +2005,6 @@ checkUEff eff = case eff of return $ RWSEffect rws (Var region') UExceptionEffect -> return ExceptionEffect UIOEffect -> return IOEffect - UUserEffect ~(SIInternalName _ name _ _) -> UserEffect <$> renameM name constrainVarTy :: EmitsInf o => CAtomVar o -> CType o -> InfererM i o () constrainVarTy v tyReq = do @@ -2000,12 +2038,12 @@ checkCasePat :: EmitsBoth o checkCasePat (WithSrcB pos pat) scrutineeTy cont = addSrcContext pos $ case pat of UPatCon ~(InternalName _ _ conName) ps -> do (dataDefName, con) <- renameM conName >>= lookupDataCon - TyConDef sourceName paramBs (ADTCons cons) <- lookupTyCon dataDefName + TyConDef sourceName roleExpls paramBs (ADTCons cons) <- lookupTyCon dataDefName DataConDef _ _ repTy idxs <- return $ cons !! con when (length idxs /= nestLength ps) $ throw TypeErr $ "Unexpected number of pattern binders. Expected " ++ show (length idxs) ++ " got " ++ show (nestLength ps) - (params, repTy') <- inferParams sourceName (Abs paramBs repTy) + (params, repTy') <- inferParams sourceName roleExpls (Abs paramBs repTy) constrainTypesEq scrutineeTy $ TypeCon sourceName dataDefName params buildAltInf repTy' \arg -> do args <- forM idxs \projs -> do @@ -2015,22 +2053,23 @@ checkCasePat (WithSrcB pos pat) scrutineeTy cont = addSrcContext pos $ case pat _ -> throw TypeErr $ "Case patterns must start with a data constructor or variant pattern" inferParams :: (EmitsBoth o, HasNamesE e, SinkableE e, SubstE AtomSubstVal e) - => SourceName -> Abs RolePiBinders e o -> InfererM i o (TyConParams o, e o) -inferParams sourceName (Abs paramBs bodyTop) = do - (params, e') <- go (Abs paramBs bodyTop) - let expls = nestToList (\(RolePiBinder _ (WithExpl expl _)) -> expl) paramBs + => SourceName -> [RoleExpl] -> Abs (Nest CBinder) e o -> InfererM i o (TyConParams o, e o) +inferParams sourceName roleExpls (Abs paramBs bodyTop) = do + let expls = snd <$> roleExpls + (params, e') <- go expls (Abs paramBs bodyTop) return (TyConParams expls params, e') where go :: (EmitsBoth o, HasNamesE e, SinkableE e, SubstE AtomSubstVal e) - => Abs (Nest RolePiBinder) e o -> InfererM i o ([CAtom o], e o) - go (Abs Empty body) = return ([], body) - go (Abs (Nest (RolePiBinder _ (WithExpl expl (b:>ty))) bs) body) = do + => [Explicitness] -> Abs (Nest CBinder) e o -> InfererM i o ([CAtom o], e o) + go [] (Abs Empty body) = return ([], body) + go (expl:expls) (Abs (Nest (b:>ty) bs) body) = do x <- case expl of Explicit -> Var <$> freshInferenceName (TypeInstantiationInfVar sourceName) ty Inferred argName infMech -> getImplicitArg (sourceName, fromMaybe "_" argName) infMech ty rest <- applySubst (b@>SubstVal x) $ Abs bs body - (params, body') <- go rest + (params, body') <- go expls rest return (x:params, body') + go _ _ = error "zip error" bindLetPats :: EmitsBoth o => Nest UPat i i' -> [CAtomVar o] -> InfererM i' o a -> InfererM i o a @@ -2061,13 +2100,13 @@ bindLetPat (WithSrcB pos pat) v cont = addSrcContext pos $ case pat of cont UPatCon ~(InternalName _ _ conName) ps -> do (dataDefName, _) <- lookupDataCon =<< renameM conName - TyConDef sourceName paramBs cons <- lookupTyCon dataDefName + TyConDef sourceName roleExpls paramBs cons <- lookupTyCon dataDefName case cons of ADTCons [DataConDef _ _ _ idxss] -> do when (length idxss /= nestLength ps) $ throw TypeErr $ "Unexpected number of pattern binders. Expected " ++ show (length idxss) ++ " got " ++ show (nestLength ps) - (params, UnitE) <- inferParams sourceName (Abs paramBs UnitE) + (params, UnitE) <- inferParams sourceName roleExpls (Abs paramBs UnitE) constrainVarTy v $ TypeCon sourceName dataDefName params x <- cheapNormalize =<< zonk (Var v) xs <- forM idxss \idxs -> normalizeNaryProj idxs x >>= emit . Atom @@ -2132,7 +2171,7 @@ inferTabCon hint xs reqTy = do withFreshBinder noHint finTy \b' -> do elemTy' <- applyRename (b@>binderName b') elemTy dTy <- DictTy <$> dataDictType elemTy' - return $ Pi $ CorePiType ImplicitApp (UnaryNest (WithExpl (Inferred Nothing Unify) b')) (EffTy Pure dTy) + return $ Pi $ CorePiType ImplicitApp [Inferred Nothing Unify] (UnaryNest b') (EffTy Pure dTy) liftM Var $ emitHinted hint $ TabCon (dataDictHole dTy) tabTy xs' -- Bool flag is just to tweak the reported error message @@ -2486,19 +2525,19 @@ unifyEq e1 e2 = guard =<< alphaEq e1 e2 {-# INLINE unifyEq #-} instance Unifiable CorePiType where - unifyZonked (CorePiType appExpl1 bsTop1 effTy1) - (CorePiType appExpl2 bsTop2 effTy2) = do + unifyZonked (CorePiType appExpl1 expls1 bsTop1 effTy1) + (CorePiType appExpl2 expls2 bsTop2 effTy2) = do unless (appExpl1 == appExpl2) empty + unless (expls1 == expls2) empty go (Abs bsTop1 effTy1) (Abs bsTop2 effTy2) where go :: EmitsInf n - => Abs (Nest (WithExpl CBinder)) (EffTy CoreIR) n - -> Abs (Nest (WithExpl CBinder)) (EffTy CoreIR) n + => Abs (Nest CBinder) (EffTy CoreIR) n + -> Abs (Nest CBinder) (EffTy CoreIR) n -> SolverM n () go (Abs Empty (EffTy e1 t1)) (Abs Empty (EffTy e2 t2)) = unify t1 t2 >> unify e1 e2 - go (Abs (Nest (WithExpl expl1 (b1:>t1)) bs1) rest1) - (Abs (Nest (WithExpl expl2 (b2:>t2)) bs2) rest2) = do - unless (expl1 == expl2) empty + go (Abs (Nest (b1:>t1) bs1) rest1) + (Abs (Nest (b2:>t2) bs2) rest2) = do unify t1 t2 v <- freshSkolemName t1 ab1 <- zonk =<< applySubst (b1@>SubstVal (Var v)) (Abs bs1 rest1) @@ -2577,13 +2616,9 @@ synthTopE block = do {-# SCC synthTopE #-} synthTyConDef :: (EnvReader m, Fallible1 m) => TyConDef n -> m n (TyConDef n) -synthTyConDef (TyConDef sn rbs body) = (liftExcept =<<) $ liftDictSynthTraverserM do - let bs = fmapNest (\(RolePiBinder _ b) -> b) rbs - let roles = nestToList (\(RolePiBinder role _) -> role) rbs - dsTraverseExplBinders bs \bs' -> do - body' <- dsTraverse body - let rbs' = zipWithNest bs' roles \b role -> RolePiBinder role b - return $ TyConDef sn rbs' body' +synthTyConDef (TyConDef sn roleExpls bs body) = (liftExcept =<<) $ liftDictSynthTraverserM do + dsTraverseExplBinders (snd <$> roleExpls) bs \bs' -> + TyConDef sn roleExpls bs' <$> dsTraverse body {-# SCC synthTyConDef #-} -- Given a simplified dict (an Atom of type `DictTy _` in the @@ -2616,8 +2651,8 @@ generalizeDictRec dict = do DictCon _ dict' <- cheapNormalize dict mkDictAtom =<< case dict' of InstanceDict instanceName args -> do - InstanceDef _ bs _ _ <- lookupInstanceDef instanceName - args' <- generalizeInstanceArgs bs args + InstanceDef _ roleExpls bs _ _ <- lookupInstanceDef instanceName + args' <- generalizeInstanceArgs roleExpls bs args return $ InstanceDict instanceName args' IxFin _ -> IxFin <$> Var <$> freshInferenceName MiscInfVar NatTy InstantiatedGiven _ _ -> notSimplifiedDict @@ -2625,9 +2660,9 @@ generalizeDictRec dict = do DataData ty -> DataData <$> TyVar <$> freshInferenceName MiscInfVar ty where notSimplifiedDict = error $ "Not a simplified dict: " ++ pprint dict -generalizeInstanceArgs :: EmitsInf n => RolePiBinders n l -> [CAtom n] -> SolverM n [CAtom n] -generalizeInstanceArgs Empty [] = return [] -generalizeInstanceArgs (Nest (RolePiBinder role (WithExpl _ (b:>ty))) bs) (arg:args) = do +generalizeInstanceArgs :: EmitsInf n => [RoleExpl] -> Nest CBinder n l -> [CAtom n] -> SolverM n [CAtom n] +generalizeInstanceArgs [] Empty [] = return [] +generalizeInstanceArgs ((role,_):expls) (Nest (b:>ty) bs) (arg:args) = do arg' <- case role of -- XXX: for `TypeParam` we can just emit a fresh inference name rather than -- traversing the whole type like we do in `Generalize.hs`. The reason is @@ -2638,21 +2673,21 @@ generalizeInstanceArgs (Nest (RolePiBinder role (WithExpl _ (b:>ty))) bs) (arg:a DictParam -> generalizeDictAndUnify ty arg DataParam -> Var <$> freshInferenceName MiscInfVar ty Abs bs' UnitE <- applySubst (b@>SubstVal arg') (Abs bs UnitE) - args' <- generalizeInstanceArgs bs' args + args' <- generalizeInstanceArgs expls bs' args return $ arg':args' -generalizeInstanceArgs _ _ = error "zip error" +generalizeInstanceArgs _ _ _ = error "zip error" synthInstanceDefAndAddSynthCandidate :: (Mut n, TopBuilder m, EnvReader m, Fallible1 m) => InstanceDef n -> m n (InstanceName n) -synthInstanceDefAndAddSynthCandidate def@(InstanceDef className bs params (InstanceBody superclasses _)) = do - let emptyDef = InstanceDef className bs params $ InstanceBody superclasses [] +synthInstanceDefAndAddSynthCandidate def@(InstanceDef className expls bs params (InstanceBody superclasses _)) = do + let emptyDef = InstanceDef className expls bs params $ InstanceBody superclasses [] instanceName <- emitInstanceDef emptyDef addInstanceSynthCandidate className instanceName synthInstanceDefRec instanceName def return instanceName emitInstanceDef :: (Mut n, TopBuilder m) => InstanceDef n -> m n (Name InstanceNameC n) -emitInstanceDef instanceDef@(InstanceDef className _ _ _) = do +emitInstanceDef instanceDef@(InstanceDef className _ _ _ _) = do ty <- getInstanceType instanceDef emitBinding (getNameHint className) $ InstanceBinding instanceDef ty @@ -2664,46 +2699,47 @@ pattern InstanceDefAbsBody :: [CAtom n] -> [CAtom n] -> [CAtom n] -> [CAtom n] pattern InstanceDefAbsBody params superclasses doneMethods todoMethods = ListE params `PairE` (ListE superclasses) `PairE` (ListE doneMethods) `PairE` (ListE todoMethods) -type InstanceDefAbsT = Abs (Nest RolePiBinder) InstanceDefAbsBodyT +type InstanceDefAbsT n = ([RoleExpl], Abs (Nest CBinder) InstanceDefAbsBodyT n) -pattern InstanceDefAbs :: Nest RolePiBinder h n -> [CAtom n] -> [CAtom n] -> [CAtom n] -> [CAtom n] +pattern InstanceDefAbs :: [RoleExpl] -> Nest CBinder h n -> [CAtom n] -> [CAtom n] -> [CAtom n] -> [CAtom n] -> InstanceDefAbsT h -pattern InstanceDefAbs bs params superclasses doneMethods todoMethods = - Abs bs (InstanceDefAbsBody params superclasses doneMethods todoMethods) +pattern InstanceDefAbs expls bs params superclasses doneMethods todoMethods = + (expls, Abs bs (InstanceDefAbsBody params superclasses doneMethods todoMethods)) synthInstanceDefRec :: (Mut n, TopBuilder m, EnvReader m, Fallible1 m) => InstanceName n -> InstanceDef n -> m n () -synthInstanceDefRec instanceName (InstanceDef className bs params (InstanceBody superclasses methods)) = do - let ab = InstanceDefAbs bs params superclasses [] methods +synthInstanceDefRec instanceName def = do + InstanceDef className roleExplsTop bs params (InstanceBody superclasses methods) <- return def + let ab = InstanceDefAbs roleExplsTop bs params superclasses [] methods recur ab className instanceName where recur :: (Mut n, TopBuilder m, EnvReader m, Fallible1 m) => InstanceDefAbsT n -> ClassName n -> InstanceName n -> m n () - recur (InstanceDefAbs _ _ _ _ []) _ _ = return () - recur ab cname iname = do - (def, ab') <- liftExceptEnvReaderM $ refreshAbs ab + recur (InstanceDefAbs _ _ _ _ _ []) _ _ = return () + recur (roleExpls, ab) cname iname = do + (def', ab') <- liftExceptEnvReaderM $ refreshAbs ab \bs' (InstanceDefAbsBody ps scs doneMethods (m:ms)) -> do EnvReaderT $ ReaderT \(Distinct, env) -> do - let env' = extendSynthCandidatess bs' env + let env' = extendSynthCandidatess (snd<$>roleExpls) bs' env flip runReaderT (Distinct, env') $ runEnvReaderT' do m' <- synthTopE m let doneMethods' = doneMethods ++ [m'] - let ab' = InstanceDefAbs bs' ps scs doneMethods' ms - let def = InstanceDef cname bs' ps $ InstanceBody scs doneMethods' - return (def, ab') - updateTopEnv $ UpdateInstanceDef iname def + let ab' = InstanceDefAbs roleExpls bs' ps scs doneMethods' ms + let def' = InstanceDef cname roleExpls bs' ps $ InstanceBody scs doneMethods' + return (def', ab') + updateTopEnv $ UpdateInstanceDef iname def' recur ab' cname iname synthInstanceDef :: (EnvReader m, Fallible1 m) => InstanceDef n -> m n (InstanceDef n) -synthInstanceDef (InstanceDef className bs params body) = do +synthInstanceDef (InstanceDef className expls bs params body) = do liftExceptEnvReaderM $ refreshAbs (Abs bs (ListE params `PairE` body)) \bs' (ListE params' `PairE` InstanceBody superclasses methods) -> do EnvReaderT $ ReaderT \(Distinct, env) -> do - let env' = extendSynthCandidatess bs' env + let env' = extendSynthCandidatess (snd<$>expls) bs' env flip runReaderT (Distinct, env') $ runEnvReaderT' do methods' <- mapM synthTopE methods - return $ InstanceDef className bs' params' $ InstanceBody superclasses methods' + return $ InstanceDef className expls bs' params' $ InstanceBody superclasses methods' -- main entrypoint to dictionary synthesizer trySynthTerm :: (Fallible1 m, EnvReader m) => CType n -> RequiredMethodAccess -> m n (SynthAtom n) @@ -2720,7 +2756,7 @@ trySynthTerm ty reqMethodAccess = do {-# SCC trySynthTerm #-} type SynthAtom = CAtom -type SynthPiType = Abs (Nest (WithExpl CBinder)) DictType +type SynthPiType n = ([Explicitness], Abs (Nest CBinder) DictType n) data SynthType n = SynthDictType (DictType n) | SynthPiType (SynthPiType n) @@ -2773,7 +2809,7 @@ getSynthType x = ignoreExcept $ typeAsSynthType (getType x) typeAsSynthType :: CType n -> Except (SynthType n) typeAsSynthType = \case DictTy dictTy -> return $ SynthDictType dictTy - Pi (CorePiType ImplicitApp bs (EffTy Pure (DictTy d))) -> return $ SynthPiType (Abs bs d) + Pi (CorePiType ImplicitApp expls bs (EffTy Pure (DictTy d))) -> return $ SynthPiType (expls, Abs bs d) ty -> Failure $ Errs [Err TypeErr mempty $ "Can't synthesize terms of type: " ++ pprint ty] {-# SCC typeAsSynthType #-} @@ -2819,11 +2855,11 @@ getSuperclassClosurePure env givens newGivens = synthTerm :: SynthType n -> RequiredMethodAccess -> SyntherM n (SynthAtom n) synthTerm targetTy reqMethodAccess = confuseGHC >>= \_ -> case targetTy of - SynthPiType ab -> do - ab' <- withGivenBinders ab \bs targetTy' -> do + SynthPiType (expls, ab) -> do + ab' <- withGivenBinders expls ab \bs targetTy' -> do Abs bs <$> synthTerm (SynthDictType targetTy') reqMethodAccess Abs bs synthExpr <- return ab' - liftM Lam $ coreLamExpr ImplicitApp $ Abs bs $ PairE Pure (AtomicBlock synthExpr) + liftM Lam $ coreLamExpr ImplicitApp expls $ Abs bs $ PairE Pure (WithoutDecls synthExpr) SynthDictType dictTy -> case dictTy of DictType "Ix" _ [Type (NewtypeTyCon (Fin n))] -> return $ DictCon (DictTy dictTy) $ IxFin n DictType "Data" _ [Type t] -> do @@ -2840,21 +2876,29 @@ synthTerm targetTy reqMethodAccess = confuseGHC >>= \_ -> case targetTy of _ -> return dict {-# SCC synthTerm #-} +coreLamExpr :: EnvReader m => AppExplicitness + -> [Explicitness] -> Abs (Nest CBinder) (PairE (EffectRow CoreIR) CBlock) n + -> m n (CoreLamExpr n) +coreLamExpr appExpl expls ab = liftEnvReaderM do + refreshAbs ab \bs' (PairE effs' body') -> do + EffTy _ resultTy <- blockEffTy body' + return $ CoreLamExpr (CorePiType appExpl expls bs' (EffTy effs' resultTy)) (LamExpr bs' body') + withGivenBinders - :: (SinkableE e, RenameE e) => Abs (Nest (WithExpl CBinder)) e n - -> (forall l. DExt n l => Nest (WithExpl CBinder) n l -> e l -> SyntherM l a) + :: (SinkableE e, RenameE e) => [Explicitness] -> Abs (Nest CBinder) e n + -> (forall l. DExt n l => Nest CBinder n l -> e l -> SyntherM l a) -> SyntherM n a -withGivenBinders (Abs bsTop e) contTop = - runSubstReaderT idSubst $ go bsTop \bsTop' -> do +withGivenBinders explsTop (Abs bsTop e) contTop = + runSubstReaderT idSubst $ go explsTop bsTop \bsTop' -> do e' <- renameM e liftSubstReaderT $ contTop bsTop' e' where - go :: Nest (WithExpl CBinder) i i' - -> (forall o'. DExt o o' => Nest (WithExpl CBinder) o o' -> SubstReaderT Name SyntherM i' o' a) + go :: [Explicitness] -> Nest CBinder i i' + -> (forall o'. DExt o o' => Nest CBinder o o' -> SubstReaderT Name SyntherM i' o' a) -> SubstReaderT Name SyntherM i o a - go bs cont = case bs of - Empty -> getDistinct >>= \Distinct -> cont Empty - Nest (WithExpl expl b) rest -> do + go expls bs cont = case (expls, bs) of + ([], Empty) -> getDistinct >>= \Distinct -> cont Empty + (expl:explsRest, Nest b rest) -> do argTy <- renameM $ binderType b withFreshBinder (getNameHint b) argTy \b' -> do givens <- case expl of @@ -2863,13 +2907,14 @@ withGivenBinders (Abs bsTop e) contTop = s <- getSubst liftSubstReaderT $ extendGivens givens $ runSubstReaderT (s <>> b@>binderName b') $ - go rest \rest' -> cont (Nest (WithExpl expl b') rest') + go explsRest rest \rest' -> cont (Nest b' rest') + _ -> error "zip error" isMethodAccessAllowedBy :: EnvReader m => RequiredMethodAccess -> InstanceName n -> m n Bool isMethodAccessAllowedBy access instanceName = do - InstanceDef className _ _ (InstanceBody _ methods) <- lookupInstanceDef instanceName + InstanceDef className _ _ _ (InstanceBody _ methods) <- lookupInstanceDef instanceName let numInstanceMethods = length methods - ClassDef _ _ _ _ _ methodTys <- lookupClassDef className + ClassDef _ _ _ _ _ _ methodTys <- lookupClassDef className let numClassMethods = length methodTys case access of Full -> return $ numClassMethods == numInstanceMethods @@ -2891,34 +2936,35 @@ synthDictFromInstance :: DictType n -> SyntherM n (SynthAtom n) synthDictFromInstance targetTy@(DictType _ targetClass _) = do instances <- getInstanceDicts targetClass asum $ instances <&> \candidate -> do - CorePiType _ bs (EffTy _ (DictTy candidateTy)) <- lookupInstanceTy candidate - args <- instantiateSynthArgs targetTy $ Abs bs candidateTy + CorePiType _ expls bs (EffTy _ (DictTy candidateTy)) <- lookupInstanceTy candidate + args <- instantiateSynthArgs targetTy (expls, Abs bs candidateTy) return $ DictCon (DictTy targetTy) $ InstanceDict candidate args instantiateSynthArgs :: DictType n -> SynthPiType n -> SyntherM n [CAtom n] -instantiateSynthArgs targetTop (Abs bsTop resultTyTop) = do +instantiateSynthArgs targetTop (explsTop, Abs bsTop resultTyTop) = do ListE args <- (liftExceptAlt =<<) $ liftSolverM $ solveLocal do - args <- runSubstReaderT idSubst $ go (sink targetTop) (sink $ Abs bsTop resultTyTop) + args <- runSubstReaderT idSubst $ go (sink targetTop) explsTop (sink $ Abs bsTop resultTyTop) zonk $ ListE args forM args \case DictHole _ argTy req -> liftExceptAlt (typeAsSynthType argTy) >>= flip synthTerm req arg -> return arg where go :: EmitsInf o - => DictType o -> Abs (Nest (WithExpl CBinder)) DictType i + => DictType o -> [Explicitness] -> Abs (Nest CBinder) DictType i -> SubstReaderT AtomSubstVal SolverM i o [CAtom o] - go target (Abs bs proposed) = case bs of - Empty -> do + go target allExpls (Abs bs proposed) = case (allExpls, bs) of + ([], Empty) -> do proposed' <- substM proposed liftSubstReaderT $ unify target proposed' return [] - Nest (WithExpl expl b) rest -> do + (expl:expls, Nest b rest) -> do argTy <- substM $ binderType b arg <- liftSubstReaderT case expl of Explicit -> error "instances shouldn't have explicit args" Inferred _ Unify -> Var <$> freshInferenceName MiscInfVar argTy Inferred _ (Synth req) -> return $ DictHole (AlwaysEqual emptySrcPosCtx) argTy req - liftM (arg:) $ extendSubst (b@>SubstVal arg) $ go target (Abs rest proposed) + liftM (arg:) $ extendSubst (b@>SubstVal arg) $ go target expls (Abs rest proposed) + _ -> error "zip error" synthDictForData :: forall n. DictType n -> SyntherM n (SynthAtom n) synthDictForData dictTy@(DictType "Data" dName [Type ty]) = case ty of @@ -2991,6 +3037,9 @@ instance ExprVisitorNoEmits (DictSynthTraverserM i o) CoreIR i o where class DictSynthTraversable (e::E) where dsTraverse :: e i -> DictSynthTraverserM i o (e o) +instance DictSynthTraversable (TopLam CoreIR) where + dsTraverse (TopLam d ty lam) = TopLam d <$> visitPiDefault ty <*> visitLamNoEmits lam + instance DictSynthTraversable CAtom where dsTraverse atom = case atom of DictHole (AlwaysEqual ctx) ty access -> do @@ -2999,12 +3048,11 @@ instance DictSynthTraversable CAtom where case ans of Failure errs -> put (LiftE errs) >> renameM atom Success d -> return d - Lam (CoreLamExpr piTy@(CorePiType _ bsPi _) (LamExpr bsLam body)) -> do + Lam (CoreLamExpr piTy@(CorePiType _ expls _ _) (LamExpr bsLam (Abs decls result))) -> do Pi piTy' <- dsTraverse $ Pi piTy - let (expls, _) = unzipExpls bsPi - lam' <- dsTraverseExplBinders (zipExpls expls bsLam) \bsLamExpl' -> do - let (_, bsLam') = unzipExpls bsLamExpl' - LamExpr bsLam' <$> dsTraverse body + lam' <- dsTraverseExplBinders expls bsLam \bsLam' -> do + visitDeclsNoEmits decls \decls' -> do + LamExpr bsLam' <$> Abs decls' <$> dsTraverse result return $ Lam $ CoreLamExpr piTy' lam' Var _ -> renameM atom SimpInCore _ -> renameM atom @@ -3013,28 +3061,27 @@ instance DictSynthTraversable CAtom where instance DictSynthTraversable CType where dsTraverse ty = case ty of - Pi (CorePiType appExpl bs (EffTy effs resultTy)) -> Pi <$> - dsTraverseExplBinders bs \bs' -> do - CorePiType appExpl bs' <$> (EffTy <$> renameM effs <*> dsTraverse resultTy) + Pi (CorePiType appExpl expls bs (EffTy effs resultTy)) -> Pi <$> + dsTraverseExplBinders expls bs \bs' -> do + CorePiType appExpl expls bs' <$> (EffTy <$> renameM effs <*> dsTraverse resultTy) TyVar _ -> renameM ty ProjectEltTy _ _ _ -> renameM ty _ -> visitTypePartial ty instance DictSynthTraversable DataConDefs where dsTraverse = visitGeneric -instance DictSynthTraversable (Block CoreIR) where - dsTraverse = visitBlockNoEmits dsTraverseExplBinders - :: Nest (WithExpl CBinder) i i' - -> (forall o'. DExt o o' => Nest (WithExpl CBinder) o o' -> DictSynthTraverserM i' o' a) + :: [Explicitness] -> Nest CBinder i i' + -> (forall o'. DExt o o' => Nest CBinder o o' -> DictSynthTraverserM i' o' a) -> DictSynthTraverserM i o a -dsTraverseExplBinders Empty cont = getDistinct >>= \Distinct -> cont Empty -dsTraverseExplBinders (Nest (WithExpl expl b) bs) cont = do +dsTraverseExplBinders [] Empty cont = getDistinct >>= \Distinct -> cont Empty +dsTraverseExplBinders (expl:expls) (Nest b bs) cont = do ty <- dsTraverse $ binderType b withFreshBinder (getNameHint b) ty \b' -> do let v = binderName b' extendSynthCandidatesDict expl v $ extendRenamer (b@>v) do - dsTraverseExplBinders bs \bs' -> cont $ Nest (WithExpl expl b') bs' + dsTraverseExplBinders expls bs \bs' -> cont $ Nest b' bs' +dsTraverseExplBinders _ _ _ = error "zip error" extendSynthCandidatesDict :: Explicitness -> CAtomName n -> DictSynthTraverserM i n a -> DictSynthTraverserM i n a extendSynthCandidatesDict c v cont = DictSynthTraverserM do @@ -3053,11 +3100,22 @@ extendSynthCandidatesDict c v cont = DictSynthTraverserM do -- the needs of inference, like adding `SubstE AtomSubstVal e` constraints in -- various places. +type WithExpl = WithAttrB Explicitness +type WithRoleExpl = WithAttrB RoleExpl + buildBlockInf :: EmitsInf n => (forall l. (EmitsBoth l, DExt n l) => InfererM i l (CAtom l)) -> InfererM i n (CBlock n) -buildBlockInf cont = buildDeclsInf (cont >>= withType) >>= computeAbsEffects >>= absToBlock +buildBlockInf cont = do + Abs decls (PairE result ty) <- buildDeclsInf do + ans <- cont + ty <- cheapNormalize $ getType ans + return $ PairE ans ty + let msg = "Block:" <> nest 1 (prettyBlock decls result) <> line + <> group ("Of type:" <> nest 2 (line <> pretty ty)) <> line + void $ liftHoistExcept' (docAsStr msg) $ hoist decls ty + return $ Abs decls result {-# INLINE buildBlockInf #-} buildBlockInfWithRecon @@ -3066,10 +3124,9 @@ buildBlockInfWithRecon -> InfererM i n (PairE CBlock (ReconAbs CoreIR e) n) buildBlockInfWithRecon cont = do ab <- buildDeclsInfUnzonked cont - (declsResult, recon) <- refreshAbs ab \decls result -> do + (block, recon) <- refreshAbs ab \decls result -> do (newResult, recon) <- telescopicCapture decls result return (Abs decls newResult, recon) - block <- makeBlockFromDecls declsResult return $ PairE block recon {-# INLINE buildBlockInfWithRecon #-} @@ -3079,10 +3136,8 @@ buildTabPiInf -> (forall l. (EmitsInf l, Ext n l) => CAtomVar l -> InfererM i l (CType l)) -> InfererM i n (TabPiType CoreIR n) buildTabPiInf hint (IxType t d) body = do - Abs (WithExpl _ (b:>_)) resultTy <- - buildAbsInf hint Explicit t \v -> - withoutEffects $ body v - return $ TabPiType d (b:>t) resultTy + Abs b resultTy <- buildAbsInf hint Explicit t \v -> withoutEffects $ body v + return $ TabPiType d b resultTy buildDepPairTyInf :: EmitsInf n @@ -3091,7 +3146,7 @@ buildDepPairTyInf -> InfererM i n (DepPairType CoreIR n) buildDepPairTyInf hint expl ty body = do Abs b resultTy <- buildAbsInf hint Explicit ty body - return $ DepPairType expl (withoutExpl b) resultTy + return $ DepPairType expl b resultTy buildAltInf :: EmitsInf n @@ -3099,11 +3154,10 @@ buildAltInf -> (forall l. (EmitsBoth l, Ext n l) => CAtomVar l -> InfererM i l (CAtom l)) -> InfererM i n (Alt CoreIR n) buildAltInf ty body = do - Abs b body' <- buildAbsInf noHint Explicit ty \v -> + buildAbsInf noHint Explicit ty \v -> buildBlockInf do Distinct <- getDistinct body $ sink v - return $ Abs (withoutExpl b) body' -- === EmitsInf predicate === @@ -3173,11 +3227,11 @@ instance BindsEnv InfOutFrag where toEnvFrag (InfOutFrag frag _ _) = toEnvFrag frag instance GenericE SynthType where - type RepE SynthType = EitherE2 DictType (Abs (Nest (WithExpl CBinder)) DictType) + type RepE SynthType = EitherE2 DictType (PairE (LiftE [Explicitness]) (Abs (Nest CBinder) DictType)) fromE (SynthDictType d) = Case0 d - fromE (SynthPiType t) = Case1 t + fromE (SynthPiType (expl, t)) = Case1 (PairE (LiftE expl) t) toE (Case0 d) = SynthDictType d - toE (Case1 t) = SynthPiType t + toE (Case1 (PairE (LiftE expl) t)) = SynthPiType (expl, t) toE _ = error "impossible" instance AlphaEqE SynthType diff --git a/src/lib/Inline.hs b/src/lib/Inline.hs index d250a8d4d..bcf5bdc64 100644 --- a/src/lib/Inline.hs +++ b/src/lib/Inline.hs @@ -22,14 +22,11 @@ import Types.Primitives -- === External API === -inlineBindings :: (EnvReader m) => SLam n -> m n (SLam n) -inlineBindings = liftLamExpr inlineBindingsBlock +inlineBindings :: (EnvReader m) => STopLam n -> m n (STopLam n) +inlineBindings = liftLamExpr \(Abs decls ans) -> liftInlineM $ + buildScoped $ inlineDecls decls $ inline Stop ans {-# INLINE inlineBindings #-} -inlineBindingsBlock :: (EnvReader m) => SBlock n -> m n (SBlock n) -inlineBindingsBlock blk = liftInlineM $ buildScopedAssumeNoDecls $ inline Stop blk -{-# SCC inlineBindingsBlock #-} - -- === Data Structure === data InlineExpr (r::IR) (o::S) where @@ -220,7 +217,7 @@ inlineDeclsSubst = \case ixDepthExpr _ = 0 ixDepthBlock :: Block SimpIR n -> Int ixDepthBlock (exprBlock -> (Just expr)) = ixDepthExpr expr - ixDepthBlock (AtomicBlock result) = ixDepthExpr $ Atom result + ixDepthBlock (Abs Empty result) = ixDepthExpr $ Atom result ixDepthBlock _ = 0 -- Should we decide to inline this binding wherever it appears, before we even @@ -316,9 +313,10 @@ instance Inlinable SType where inline ctx ty = visitTypePartial ty >>= reconstruct ctx instance Inlinable SLam where - inline ctx (LamExpr bs body) = do + inline ctx (LamExpr bs (Abs decls ans)) = do reconstruct ctx =<< withBinders bs \bs' -> do - LamExpr bs' <$> (buildScopedAssumeNoDecls $ inline Stop body) + (LamExpr bs' <$>) $ buildScoped $ + inlineDecls decls $ inline Stop ans withBinders :: Nest SBinder i i' @@ -337,18 +335,8 @@ instance Inlinable (PiType SimpIR) where effTy' <- buildScopedAssumeNoDecls $ inline Stop effTy return $ PiType bs' effTy' -instance Inlinable SBlock where - inline ctx (Block ann decls ans) = case (ann, decls) of - (NoBlockAnn, Empty) -> - (Block NoBlockAnn Empty <$> inline Stop ans) >>= reconstruct ctx - (NoBlockAnn, _) -> error "should be unreachable" - (BlockAnn effTy, _) -> do - (Abs decls' ans') <- buildScoped $ inlineDecls decls $ inline Stop ans - effTy' <- inline Stop effTy - reconstruct ctx $ Block (BlockAnn effTy') decls' ans' - inlineBlockEmits :: Emits o => Context SExpr e2 o -> SBlock i -> InlineM i o (e2 o) -inlineBlockEmits ctx (Block _ decls ans) = do +inlineBlockEmits ctx (Abs decls ans) = do inlineDecls decls $ inlineAtom ctx ans -- Still using InlineM because we may call back into inlining, and we wish to @@ -369,7 +357,7 @@ reconstructTabApp ctx expr [] = do reconstruct ctx expr reconstructTabApp ctx expr ixs = case fromNaryForExpr (length ixs) expr of - Just (bsCount, LamExpr bs (Block _ decls result)) -> do + Just (bsCount, LamExpr bs (Abs decls result)) -> do let (ixsPref, ixsRest) = splitAt bsCount ixs -- Note: There's a decision here. Is it ok to inline the atoms in -- `ixsPref` into the body `decls`? If so, should we pre-process them and diff --git a/src/lib/JAX/ToSimp.hs b/src/lib/JAX/ToSimp.hs index a3b012ea6..e2e183955 100644 --- a/src/lib/JAX/ToSimp.hs +++ b/src/lib/JAX/ToSimp.hs @@ -30,8 +30,8 @@ liftJaxSimpM :: (EnvReader m) => JaxSimpM n n (e n) -> m n (e n) liftJaxSimpM act = liftBuilder $ runSubstReaderT idSubst $ runJaxSimpM act {-# INLINE liftJaxSimpM #-} -simplifyClosedJaxpr :: ClosedJaxpr i -> JaxSimpM i o (LamExpr SimpIR o) -simplifyClosedJaxpr ClosedJaxpr{jaxpr, consts=[]} = simplifyJaxpr jaxpr +simplifyClosedJaxpr :: ClosedJaxpr i -> JaxSimpM i o (TopLam SimpIR o) +simplifyClosedJaxpr ClosedJaxpr{jaxpr, consts=[]} = asTopLam =<< simplifyJaxpr jaxpr simplifyClosedJaxpr _ = error "TODO Support consts" simplifyJaxpr :: Jaxpr i -> JaxSimpM i o (LamExpr SimpIR o) diff --git a/src/lib/Linearize.hs b/src/lib/Linearize.hs index 32944c5f7..3d6b91a3f 100644 --- a/src/lib/Linearize.hs +++ b/src/lib/Linearize.hs @@ -4,7 +4,7 @@ -- license that can be found in the LICENSE file or at -- https://developers.google.com/open-source/licenses/bsd -module Linearize (linearize, linearizeLam) where +module Linearize (linearize, linearizeTopLam) where import Control.Category ((>>>)) import Control.Monad.Reader @@ -269,10 +269,9 @@ linearizeBlockDefuncGeneral locals block = do WithTangent primalResult tangentFun <- linearizeBlock block lam <- tangentFunAsLambda tangentFun return $ PairE primalResult lam - (blockAbs, recon) <- refreshAbs (Abs decls result) \decls' (PairE primal lam) -> do + (block', recon) <- refreshAbs (Abs decls result) \decls' (PairE primal lam) -> do (primal', recon) <- capture (locals >>> toScopeFrag decls') primal lam return (Abs decls' primal', recon) - block' <- makeBlockFromDecls blockAbs return (block', recon) -- Inverse of tangentFunAsLambda. Should be used inside a returned tangent action. @@ -289,9 +288,9 @@ linearize :: Emits n => SLam n -> SAtom n -> DoubleBuilder SimpIR n (SAtom n, SL linearize f x = runPrimalMInit $ linearizeLambdaApp f x {-# SCC linearize #-} -linearizeLam :: SLam n -> [Active] -> DoubleBuilder SimpIR n (SLam n, SLam n) -linearizeLam (LamExpr bs body) actives = runPrimalMInit do - refreshBinders bs \bs' frag -> extendSubst frag do +linearizeTopLam :: STopLam n -> [Active] -> DoubleBuilder SimpIR n (STopLam n, STopLam n) +linearizeTopLam (TopLam False _ (LamExpr bs body)) actives = do + (primalFun, tangentFun) <- runPrimalMInit $ refreshBinders bs \bs' frag -> extendSubst frag do let allPrimals = nestToAtomVars bs' activeVs <- catMaybes <$> forM (zip actives allPrimals) \(active, v) -> case active of True -> return $ Just v @@ -312,6 +311,8 @@ linearizeLam (LamExpr bs body) actives = runPrimalMInit do emitBlock =<< applySubst substFrag tangentBody return $ LamExpr (bs' >>> BinaryNest bResidual bTangent) tangentBody' return (primalFun, tangentFun) + (,) <$> asTopLam primalFun <*> asTopLam tangentFun +linearizeTopLam (TopLam True _ _) _ = error "expected a non-destination-passing function" -- reify the tangent builder as a lambda linearizeLambdaApp :: Emits o => SLam i -> SAtom o -> PrimalM i o (SAtom o, SLam o) @@ -343,7 +344,7 @@ linearizeAtom atom = case atom of where emitZeroT = withZeroT $ renameM atom linearizeBlock :: Emits o => SBlock i -> LinM i o SAtom SAtom -linearizeBlock (Block _ decls result) = +linearizeBlock (Abs decls result) = linearizeDecls decls $ linearizeAtom result linearizeDecls :: Emits o => Nest SDecl i i' -> LinM i' o e1 e2 -> LinM i o e1 e2 @@ -624,7 +625,7 @@ linearizeHof hof = case hof of WithTangent sInit' sLin <- linearizeAtom sInit (lam', recon) <- linearizeEffectFun State lam (primalAux, sFinal) <- fromPair =<< emitHof (RunState Nothing sInit' lam') - referentTy <- return $ snd $ getTypeRWSAction lam' + referentTy <- snd <$> getTypeRWSAction lam' (primal, linLam) <- reconstruct primalAux recon return $ WithTangent (PairVal primal sFinal) do sLin' <- sLin @@ -639,7 +640,7 @@ linearizeHof hof = case hof of (lam', recon) <- linearizeEffectFun Writer lam (primalAux, wFinal) <- fromPair =<< emitHof (RunWriter Nothing bm' lam') (primal, linLam) <- reconstruct primalAux recon - referentTy <- return $ snd $ getTypeRWSAction lam' + referentTy <- snd <$> getTypeRWSAction lam' return $ WithTangent (PairVal primal wFinal) do bm'' <- sinkM bm' tt <- tangentType $ sink referentTy diff --git a/src/lib/Lower.hs b/src/lib/Lower.hs index 6c8f30df2..bce5b8050 100644 --- a/src/lib/Lower.hs +++ b/src/lib/Lower.hs @@ -7,8 +7,7 @@ {-# LANGUAGE UndecidableInstances #-} module Lower - ( lowerFullySequential, lowerFullySequentialNoDest - , DestLamExpr, DestBlock + ( lowerFullySequential, DestBlock ) where import Prelude hiding ((.)) @@ -29,7 +28,7 @@ import Subst import QueryType import Types.Core import Types.Primitives -import Util (enumerate, foldMapM) +import Util (enumerate) -- === For loop resolution === @@ -60,30 +59,34 @@ import Util (enumerate, foldMapM) -- destination to a sub-block or sub-expression, hence "desintation -- passing style"). -type DestLamExpr = SLam type DestBlock = Abs (SBinder) SBlock -lowerFullySequential :: EnvReader m => SLam n -> m n (DestLamExpr n) -lowerFullySequential (LamExpr bs body) = liftEnvReaderM $ do - refreshAbs (Abs bs body) \bs' body' -> do - Abs b body'' <- lowerFullySequentialBlock body' - return $ LamExpr (bs' >>> UnaryNest b) body'' - -lowerFullySequentialBlock :: EnvReader m => SBlock n -> m n (DestBlock n) -lowerFullySequentialBlock b = liftAtomSubstBuilder do - resultDestTy <- RawRefTy <$> substM (getType b) +lowerFullySequential :: EnvReader m => Bool -> STopLam n -> m n (STopLam n) +lowerFullySequential wantDestStyle (TopLam False piTy (LamExpr bs body)) = liftEnvReaderM $ do + lam <- case wantDestStyle of + True -> do + refreshAbs (Abs bs body) \bs' body' -> do + xs <- bindersToAtoms bs' + EffTy _ resultTy <- instantiatePiTy (sink piTy) xs + Abs b body'' <- lowerFullySequentialBlock resultTy body' + return $ LamExpr (bs' >>> UnaryNest b) body'' + False -> do + refreshAbs (Abs bs body) \bs' body' -> do + body'' <- lowerFullySequentialBlockNoDest body' + return $ LamExpr bs' body'' + piTy' <- getLamExprType lam + return $ TopLam wantDestStyle piTy' lam +lowerFullySequential _ (TopLam True _ _) = error "already in destination style" + +lowerFullySequentialBlock :: EnvReader m => SType n -> SBlock n -> m n (DestBlock n) +lowerFullySequentialBlock resultTy b = liftAtomSubstBuilder do + let resultDestTy = RawRefTy resultTy withFreshBinder (getNameHint @String "ans") resultDestTy \destBinder -> do Abs destBinder <$> buildBlock do let dest = Var $ sink $ binderVar destBinder lowerBlockWithDest dest b $> UnitVal {-# SCC lowerFullySequentialBlock #-} -lowerFullySequentialNoDest :: EnvReader m => SLam n -> m n (SLam n) -lowerFullySequentialNoDest (LamExpr bs body) = liftEnvReaderM $ do - refreshAbs (Abs bs body) \bs' body' -> do - body'' <- lowerFullySequentialBlockNoDest body' - return $ LamExpr bs' body'' - lowerFullySequentialBlockNoDest :: EnvReader m => SBlock n -> m n (SBlock n) lowerFullySequentialBlockNoDest b = liftAtomSubstBuilder $ buildBlock $ lowerBlock b {-# SCC lowerFullySequentialBlockNoDest #-} @@ -150,12 +153,12 @@ lowerFor _ _ _ _ _ = error "expected a unary lambda expression" lowerTabCon :: forall i o. Emits o => Maybe (Dest SimpIR o) -> SType i -> [SAtom i] -> LowerM i o (SExpr o) lowerTabCon maybeDest tabTy elems = do - tabTy'@(TabPi (TabPiType dict (_:>t) _)) <- substM tabTy + TabPi tabTy' <- substM tabTy dest <- case maybeDest of Just d -> return d - Nothing -> emitExpr $ PrimOp $ DAMOp $ AllocDest tabTy' + Nothing -> emitExpr $ PrimOp $ DAMOp $ AllocDest $ TabPi tabTy' Abs bord ufoBlock <- buildAbs noHint IdxRepTy \ord -> do - buildBlock $ unsafeFromOrdinal (sink $ IxType t dict) $ Var $ sink ord + buildBlock $ unsafeFromOrdinal (sink $ tabIxType tabTy') $ Var $ sink ord -- This is emitting a chain of RememberDest ops to force `dest` to be used -- linearly, and to force reads of the `Freeze dest'` result not to be -- reordered in front of the writes. @@ -190,8 +193,7 @@ lowerCase maybeDest scrut alts resultTy = do extendSubst (b @> Rename (atomVarName b')) $ buildBlock do lowerBlockWithDest (Var $ sink $ local_dest) body $> UnitVal - eff' <- foldMapM (pure . getEffects) alts' - void $ emitExpr $ Case (sink scrut') alts' (EffTy eff' UnitTy) + void $ mkCase (sink scrut') UnitTy alts' >>= emitExpr return UnitVal return $ PrimOp $ DAMOp $ Freeze dest' @@ -243,7 +245,7 @@ decomposeDest dest = \case _ -> return Nothing lowerBlockWithDest :: Emits o => Dest SimpIR o -> SBlock i -> LowerM i o (SAtom o) -lowerBlockWithDest dest (Block _ decls ans) = do +lowerBlockWithDest dest (Abs decls ans) = do decomposeDest dest ans >>= \case Nothing -> do ans' <- visitDeclsEmits decls $ visitAtom ans diff --git a/src/lib/Name.hs b/src/lib/Name.hs index ddb01ab8d..68d1ad2f8 100644 --- a/src/lib/Name.hs +++ b/src/lib/Name.hs @@ -521,6 +521,20 @@ data PairB (b1::B) (b2::B) (n::S) (l::S) where PairB :: b1 n l' -> b2 l' l -> PairB b1 b2 n l deriving instance (ShowB b1, ShowB b2) => Show (PairB b1 b2 n l) +data WithAttrB (a:: *) (b::B) (n::S) (l::S) = + WithAttrB {getAttr :: a , withoutAttr :: b n l } + deriving (Show, Generic) + +unzipAttrs :: Nest (WithAttrB a b) n l -> ([a], Nest b n l) +unzipAttrs Empty = ([], Empty) +unzipAttrs (Nest (WithAttrB a b) rest) = (a:as, Nest b bs) + where (as, bs) = unzipAttrs rest + +zipAttrs :: [a] -> Nest b n l -> Nest (WithAttrB a b) n l +zipAttrs [] Empty = Empty +zipAttrs (a:as) (Nest b bs) = Nest (WithAttrB a b) (zipAttrs as bs) +zipAttrs _ _ = error "zip error" + data EitherB (b1::B) (b2::B) (n::S) (l::S) = LeftB (b1 n l) | RightB (b2 n l) @@ -655,7 +669,7 @@ forNest :: Nest b i i' -> Nest b' i i' forNest n f = fmapNest f n -zipWithNest :: Nest b n l -> [a] +zipWithNest :: Nest b n l -> [a] -> (forall n1 n2. b n1 n2 -> a -> b' n1 n2) -> Nest b' n l zipWithNest Empty [] _ = Empty @@ -3195,6 +3209,35 @@ instance Monad HoistExcept where HoistSuccess x >>= f = f x {-# INLINE (>>=) #-} +instance (Store a, Store (b n l)) => Store (WithAttrB a b n l) + +instance (Eq a, AlphaEqB b) => AlphaEqB (WithAttrB a b) where + withAlphaEqB (WithAttrB a1 b1) (WithAttrB a2 b2) cont = do + unless (a1 == a2) zipErr + withAlphaEqB b1 b2 cont + +instance (Hashable a, AlphaHashableB b) => AlphaHashableB (WithAttrB a b) where + hashWithSaltB env salt (WithAttrB expl b) = do + let h = hashWithSalt salt expl + hashWithSaltB env h b + +instance BindsNames b => ProvesExt (WithAttrB a b) where +instance BindsNames b => BindsNames (WithAttrB a b) where + toScopeFrag (WithAttrB _ b) = toScopeFrag b + +instance (SinkableB b) => SinkableB (WithAttrB a b) where + sinkingProofB fresh (WithAttrB a b) cont = + sinkingProofB fresh b \fresh' b' -> + cont fresh' (WithAttrB a b') + +instance (BindsNames b, RenameB b) => RenameB (WithAttrB a b) where + renameB env (WithAttrB a b) cont = + renameB env b \env' b' -> + cont env' $ WithAttrB a b' + +instance HoistableB b => HoistableB (WithAttrB a b) where + freeVarsB (WithAttrB _ b) = freeVarsB b + -- === extra data structures === -- A map from names in some scope to values that do not contain names. This is diff --git a/src/lib/OccAnalysis.hs b/src/lib/OccAnalysis.hs index 11e1be2c3..bff364f47 100644 --- a/src/lib/OccAnalysis.hs +++ b/src/lib/OccAnalysis.hs @@ -28,12 +28,12 @@ import QueryType -- annotation holding a summary of how that binding is used. It also eliminates -- unused pure bindings as it goes, since it has all the needed information. -analyzeOccurrences :: EnvReader m => SLam n -> m n (SLam n) +analyzeOccurrences :: EnvReader m => STopLam n -> m n (STopLam n) analyzeOccurrences = liftLamExpr analyzeOccurrencesBlock {-# INLINE analyzeOccurrences #-} analyzeOccurrencesBlock :: EnvReader m => SBlock n -> m n (SBlock n) -analyzeOccurrencesBlock = liftOCCM . occ accessOnce +analyzeOccurrencesBlock = liftOCCM . occNest accessOnce {-# SCC analyzeOccurrencesBlock #-} -- === Overview === @@ -254,7 +254,7 @@ occTy ty = occ accessOnce ty instance HasOCC SLam where occ a (LamExpr bs body) = do lam@(LamExpr bs' _) <- refreshAbs (Abs bs body) \bs' body' -> - LamExpr bs' <$> occ (sink a) body' + LamExpr bs' <$> occNest (sink a) body' countFreeVarsAsOccurrencesB bs' return lam @@ -269,15 +269,6 @@ instance HasOCC (PiType SimpIR) where countFreeVarsAsOccurrencesB bs' return piTy -instance HasOCC SBlock where - occ a (Block ann decls ans) = case (ann, decls) of - (NoBlockAnn , Empty) -> Block NoBlockAnn Empty <$> occ a ans - (NoBlockAnn , _ ) -> error "should be unreachable" - (BlockAnn effTy, _ ) -> do - Abs decls' ans' <- occNest a decls ans - effTy' <- occ a effTy - return $ Block (BlockAnn effTy') decls' ans' - instance HasOCC (EffTy SimpIR) where occ _ (EffTy effs ty) = do ty' <- occTy ty @@ -288,17 +279,17 @@ data ElimResult (n::S) where ElimSuccess :: Abs (Nest SDecl) SAtom n -> ElimResult n ElimFailure :: SDecl n l -> UsageInfo -> Abs (Nest SDecl) SAtom l -> ElimResult n -occNest :: Access n -> Nest SDecl n l -> SAtom l +occNest :: Access n -> Abs (Nest SDecl) SAtom n -> OCCM n (Abs (Nest SDecl) SAtom n) -occNest a decls ans = case decls of +occNest a (Abs decls ans) = case decls of Empty -> Abs Empty <$> occ a ans Nest d@(Let _ binding) ds -> do isPureDecl <- return $ isPure binding dceAttempt <- refreshAbs (Abs d (Abs ds ans)) - \d'@(Let b' (DeclBinding _ expr')) (Abs ds' ans') -> do + \d'@(Let b' (DeclBinding _ expr')) rest -> do exprIx <- summaryExpr $ sink expr' extend b' exprIx do - below <- occNest (sink a) ds' ans' + below <- occNest (sink a) rest checkAllFreeVariablesMentioned below accessInfo <- getAccessInfo $ binderName d' let usage = usageInfo accessInfo @@ -387,7 +378,7 @@ occAlt acc scrut alt = do -- case statement in that event. scrutIx <- unknown $ sink scrut extend nb scrutIx do - body' <- occ (sink acc) body + body' <- occNest (sink acc) body return $ Abs b body' ty' <- occTy ty return $ Abs (b':>ty') body' @@ -407,12 +398,12 @@ instance HasOCC (Hof SimpIR) where ixDict' <- inlinedLater ixDict occWithBinder (Abs b body) \b' body' -> do extend b' (Occ.Var $ binderName b') do - (body'', bodyFV) <- isolated (occ accessOnce body') + (body'', bodyFV) <- isolated (occNest accessOnce body') modify (<> abstractFor b' bodyFV) return $ For ann ixDict' (UnaryLamExpr b' body'') For _ _ _ -> error "For body should be a unary lambda expression" While body -> While <$> do - (body', bodyFV) <- isolated (occ accessOnce body) + (body', bodyFV) <- isolated $ occNest accessOnce body modify (<> useManyTimes bodyFV) return body' RunReader ini bd -> do @@ -451,14 +442,15 @@ instance HasOCC (Hof SimpIR) where return $ RunState Nothing ini' bd' RunState (Just _) _ _ -> error "Expecting to do occurrence analysis before destination passing." - RunIO bd -> RunIO <$> occ a bd + RunIO bd -> RunIO <$> occNest a bd RunInit _ -> -- Though this is probably not too hard to implement. Presumably -- the lambda is one-shot. error "Expecting to do occurrence analysis before lowering." oneShot :: Access n -> [IxExpr n] -> LamExpr SimpIR n -> OCCM n (LamExpr SimpIR n) -oneShot acc [] (LamExpr Empty body) = LamExpr Empty <$> occ acc body +oneShot acc [] (LamExpr Empty body) = + LamExpr Empty <$> occNest acc body oneShot acc (ix:ixs) (LamExpr (Nest b bs) body) = do occWithBinder (Abs b (LamExpr bs body)) \b' restLam -> extend b' (sink ix) do diff --git a/src/lib/Optimize.hs b/src/lib/Optimize.hs index 262ef3ae0..e4331e484 100644 --- a/src/lib/Optimize.hs +++ b/src/lib/Optimize.hs @@ -31,7 +31,7 @@ import QueryType import Util (iota) import Err -optimize :: EnvReader m => SLam n -> m n (SLam n) +optimize :: EnvReader m => STopLam n -> m n (STopLam n) optimize = dceTop -- Clean up user code >=> unrollLoops >=> dceTop -- Clean up peephole-optimized code after unrolling @@ -208,7 +208,7 @@ peepholeExpr expr = case expr of -- === Loop unrolling === -unrollLoops :: EnvReader m => SLam n -> m n (SLam n) +unrollLoops :: EnvReader m => STopLam n -> m n (STopLam n) unrollLoops = liftLamExpr unrollLoopsBlock unrollLoopsBlock :: EnvReader m => SBlock n -> m n (SBlock n) @@ -240,7 +240,7 @@ ulBlock :: SBlock i -> ULM i o (SBlock o) ulBlock b = buildBlock $ visitBlockEmits b emitSubstBlock :: Emits o => SBlock i -> ULM i o (SAtom o) -emitSubstBlock (Block _ decls ans) = visitDeclsEmits decls $ visitAtom ans +emitSubstBlock (Abs decls ans) = visitDeclsEmits decls $ visitAtom ans -- TODO: Refine the cost accounting so that operations that will become -- constant-foldable after inlining don't count towards it. @@ -257,7 +257,7 @@ ulExpr expr = case expr of vals <- dropSubst $ forM (iota n) \i -> do extendSubst (b' @> SubstVal (IdxRepVal i)) $ emitSubstBlock block' inc $ fromIntegral n -- To account for the TabCon we emit below - case getLamExprType body' of + getLamExprType body' >>= \case PiType (UnaryNest (tb:>_)) (EffTy _ valTy) -> do let tabTy = TabPi $ TabPiType (IxDictRawFin (IdxRepVal n)) (tb:>IdxRepTy) valTy emitExpr $ TabCon Nothing tabTy vals @@ -305,7 +305,7 @@ hoistLoopInvariantBlock :: EnvReader m => SBlock n -> m n (SBlock n) hoistLoopInvariantBlock body = liftLICMM $ buildBlock $ visitBlockEmits body {-# SCC hoistLoopInvariantBlock #-} -hoistLoopInvariant :: EnvReader m => SLam n -> m n (SLam n) +hoistLoopInvariant :: EnvReader m => STopLam n -> m n (STopLam n) hoistLoopInvariant = liftLamExpr hoistLoopInvariantBlock {-# INLINE hoistLoopInvariant #-} @@ -317,11 +317,11 @@ licmExpr = \case let numCarriesOriginal = length dests' Abs hdecls destsAndBody <- visitBinders (UnaryNest b) \(UnaryNest b') -> do -- First, traverse the block, to allow any Hofs inside it to hoist their own decls. - Block _ decls ans <- buildBlock $ visitBlockEmits body + Abs decls ans <- buildBlock $ visitBlockEmits body -- Now, we process the decls and decide which ones to hoist. liftEnvReaderM $ runSubstReaderT idSubst $ seqLICM REmpty mempty (asNameBinder b') REmpty decls ans - PairE (ListE extraDests) ab <- emitDecls hdecls destsAndBody + PairE (ListE extraDests) ab <- emitDecls $ Abs hdecls destsAndBody extraDests' <- mapM toAtomVar extraDests -- Append the destinations of hoisted Allocs as loop carried values. let dests'' = ProdVal $ dests' ++ (Var <$> extraDests') @@ -334,19 +334,19 @@ licmExpr = \case (oldCarries, newCarries) <- splitAt numCarriesOriginal <$> getUnpacked allCarries let oldLoopBinderVal = PairVal oldIx (ProdVal oldCarries) let s = extraDestBs @@> map SubstVal newCarries <.> lb @> SubstVal oldLoopBinderVal - block <- applySubst s bodyAbs >>= makeBlockFromDecls + block <- applySubst s bodyAbs return $ UnaryLamExpr lb' block emitSeq dir ix' dests'' body' PrimOp (Hof (TypedHof _ (For dir ix (LamExpr (UnaryNest b) body)))) -> do ix' <- substM ix Abs hdecls destsAndBody <- visitBinders (UnaryNest b) \(UnaryNest b') -> do - Block _ decls ans <- buildBlock $ visitBlockEmits body + Abs decls ans <- buildBlock $ visitBlockEmits body liftEnvReaderM $ runSubstReaderT idSubst $ seqLICM REmpty mempty (asNameBinder b') REmpty decls ans - PairE (ListE []) (Abs lnb bodyAbs) <- emitDecls hdecls destsAndBody + PairE (ListE []) (Abs lnb bodyAbs) <- emitDecls $ Abs hdecls destsAndBody ixTy <- substM $ binderType b body' <- withFreshBinder noHint ixTy \i -> do - block <- applyRename (lnb@>binderName i) bodyAbs >>= makeBlockFromDecls + block <- applyRename (lnb@>binderName i) bodyAbs return $ UnaryLamExpr i block emitHof $ For dir ix' body' expr -> visitGeneric expr >>= emitExpr @@ -400,12 +400,12 @@ newtype DCEM n a = DCEM { runDCEM :: StateT1 FV EnvReaderM n a } deriving ( Functor, Applicative, Monad, EnvReader, ScopeReader , MonadState (FV n), EnvExtender) -dceTop :: EnvReader m => SLam n -> m n (SLam n) +dceTop :: EnvReader m => STopLam n -> m n (STopLam n) dceTop = liftLamExpr dceBlock {-# INLINE dceTop #-} dceBlock :: EnvReader m => SBlock n -> m n (SBlock n) -dceBlock b = liftEnvReaderM $ evalStateT1 (runDCEM $ dce b) mempty +dceBlock b = liftEnvReaderM $ evalStateT1 (runDCEM $ dceBlock' b) mempty {-# SCC dceBlock #-} class HasDCE (e::E) where @@ -432,32 +432,36 @@ instance HasDCE SAtom where instance HasDCE SType where dce = visitTypePartial instance HasDCE (PiType SimpIR) where dce (PiType bs effTy) = do - Abs bs' effTy' <- dce (Abs bs effTy) - return $ PiType bs' effTy' + dceBinders bs effTy \bs' effTy' -> PiType bs' <$> dce effTy' instance HasDCE (LamExpr SimpIR) where - dce (LamExpr bs e) = do - Abs bs' e' <- dce (Abs bs e) - return $ LamExpr bs' e' - -instance HasDCE SBlock where - dce (Block ann decls ans) = case (ann, decls) of - (NoBlockAnn , Empty) -> Block NoBlockAnn Empty <$> dce ans - (NoBlockAnn , _ ) -> error "should be unreachable" - (BlockAnn effTy, _ ) -> do - -- The free vars accumulated in the state of DCEM should correspond to - -- the free vars of the Abs of the block answer, by the decls traversed - -- so far. dceNest takes care to uphold this invariant, but we temporarily - -- reset the state to an empty map, just so that names from the surrounding - -- block don't end up influencing elimination decisions here. Note that we - -- restore the state (and accumulate free vars of the DCE'd block into it) - -- right after dceNest. - old <- get - put mempty - Abs decls' ans' <- dceNest decls ans - modify (<> old) - effTy' <- dce effTy - return $ Block (BlockAnn effTy') decls' ans' + dce (LamExpr bs e) = dceBinders bs e \bs' e' -> LamExpr bs' <$> dceBlock' e' + +dceBinders + :: (HoistableB b, BindsEnv b, RenameB b, RenameE e) + => b n l -> e l + -> (forall l'. b n l' -> e l' -> DCEM l' a) + -> DCEM n a +dceBinders b e cont = do + ans <- refreshAbs (Abs b e) \b' e' -> cont b' e' + modify (<>FV (freeVarsB b)) + return ans +{-# INLINE dceBinders #-} + +dceBlock' :: SBlock n -> DCEM n (SBlock n) +dceBlock' (Abs decls ans) = do + -- The free vars accumulated in the state of DCEM should correspond to + -- the free vars of the Abs of the block answer, by the decls traversed + -- so far. dceNest takes care to uphold this invariant, but we temporarily + -- reset the state to an empty map, just so that names from the surrounding + -- block don't end up influencing elimination decisions here. Note that we + -- restore the state (and accumulate free vars of the DCE'd block into it) + -- right after dceNest. + old <- get + put mempty + block <- dceNest decls ans + modify (<> old) + return block data CachedFVs e n = UnsafeCachedFVs { _cachedFVs :: (NameSet n), fromCachedFVs :: (e n) } instance HoistableE (CachedFVs e) where @@ -512,13 +516,6 @@ dceNest decls ans = case decls of modify (<>FV (freeVarsB b')) return $ Abs (Nest (Let b' decl'') bs'') ans'' -instance (BindsEnv b, RenameB b, HoistableB b, RenameE e, HasDCE e) => HasDCE (Abs b e) where - dce a = do - a'@(Abs b' _) <- refreshAbs a \b e -> Abs b <$> dce e - modify (<>FV (freeVarsB b')) - return a' - {-# INLINE dce #-} - instance HasDCE (EffectRow SimpIR) instance HasDCE (DeclBinding SimpIR) instance HasDCE (EffTy SimpIR) diff --git a/src/lib/PPrint.hs b/src/lib/PPrint.hs index 2121eff5d..23bc7ea60 100644 --- a/src/lib/PPrint.hs +++ b/src/lib/PPrint.hs @@ -131,9 +131,9 @@ pArg :: PrettyPrec a => a -> Doc ann pArg a = prettyPrec a ArgPrec instance IRRep r => Pretty (Block r n) where - pretty (Block _ decls expr) = prettyBlock decls expr + pretty (Abs decls expr) = prettyBlock decls expr instance IRRep r => PrettyPrec (Block r n) where - prettyPrec (Block _ decls expr) = atPrec LowestPrec $ prettyBlock decls expr + prettyPrec (Abs decls expr) = atPrec LowestPrec $ prettyBlock decls expr prettyBlock :: (IRRep r, PrettyPrec (e l)) => Nest (Decl r) n l -> e l -> Doc ann prettyBlock Empty expr = group $ line <> pLowest expr @@ -170,11 +170,6 @@ instance IRRep r => PrettyPrec (Expr r n) where prettyPrec (PrimOp op) = prettyPrec op prettyPrec (ApplyMethod _ d i xs) = atPrec AppPrec $ "applyMethod" <+> p d <+> p i <+> p xs -instance Pretty (UserEffectOp n) where pretty = prettyFromPrettyPrec -instance PrettyPrec (UserEffectOp n) where - prettyPrec (Handle v args body) = atPrec LowestPrec $ p v <+> p args <+> prettyLam "\\_." body - prettyPrec _ = error "not implemented" - prettyPrecCase :: IRRep r => Doc ann -> Atom r n -> [Alt r n] -> EffectRow r n -> DocPrec ann prettyPrecCase name e alts effs = atPrec LowestPrec $ name <+> pApp e <+> "of" <> @@ -301,26 +296,26 @@ forStr Fwd = "for" forStr Rev = "rof" instance Pretty (CorePiType n) where - pretty (CorePiType appExpl bs (EffTy eff resultTy)) = - prettyBindersWithExpl bs <+> p appExpl <> prettyEff <> p resultTy + pretty (CorePiType appExpl expls bs (EffTy eff resultTy)) = + prettyBindersWithExpl expls bs <+> p appExpl <> prettyEff <> p resultTy where prettyEff = case eff of Pure -> space _ -> space <> pretty eff <> space prettyBindersWithExpl :: forall b n l ann. PrettyB b - => Nest (WithExpl b) n l -> Doc ann -prettyBindersWithExpl bs = do - let groups = groupByExpl $ fromNest bs + => [Explicitness] -> Nest b n l -> Doc ann +prettyBindersWithExpl expls bs = do + let groups = groupByExpl $ zip expls (fromNest bs) let groups' = case groups of [] -> [(Explicit, [])] _ -> groups mconcat [withExplParens expl $ commaSep bsGroup | (expl, bsGroup) <- groups'] -groupByExpl :: [WithExpl b UnsafeS UnsafeS] -> [(Explicitness, [b UnsafeS UnsafeS])] +groupByExpl :: [(Explicitness, b UnsafeS UnsafeS)] -> [(Explicitness, [b UnsafeS UnsafeS])] groupByExpl [] = [] -groupByExpl (WithExpl expl b:bs) = do - let (matches, rest) = span (\(WithExpl expl' _) -> expl == expl') bs - let matches' = map withoutExpl matches +groupByExpl ((expl, b):bs) = do + let (matches, rest) = span (\(expl', _) -> expl == expl') bs + let matches' = map snd matches (expl, b:matches') : groupByExpl rest withExplParens :: Explicitness -> Doc ann -> Doc ann @@ -357,18 +352,6 @@ prettyLam :: Pretty a => Doc ann -> a -> Doc ann prettyLam binders body = group $ group (nest 4 $ binders) <> group (nest 2 $ p body) -_inlineLastDeclBlock :: IRRep r => Block r n -> Abs (Nest (Decl r)) (Expr r) n -_inlineLastDeclBlock (Block _ decls expr) = inlineLastDecl decls expr - -inlineLastDecl :: IRRep r => Nest (Decl r) n l -> Atom r l -> Abs (Nest (Decl r)) (Expr r) n -inlineLastDecl Empty result = Abs Empty $ Atom result -inlineLastDecl (Nest (Let b (DeclBinding _ expr)) Empty) (Var (AtomVar v _)) - | v == binderName b = Abs Empty expr -inlineLastDecl (Nest decl rest) result = - case inlineLastDecl rest result of - Abs decls' result' -> - Abs (Nest decl decls') result' - instance IRRep r => Pretty (EffectRow r n) where pretty (EffectRow effs t) = braces $ hsep (punctuate "," (map p (eSetToList effs))) <> p t @@ -383,7 +366,6 @@ instance IRRep r => Pretty (Effect r n) where RWSEffect rws h -> p rws <+> p h ExceptionEffect -> "Except" IOEffect -> "IO" - UserEffect name -> p name InitEffect -> "Init" instance Pretty (UEffect n) where @@ -391,7 +373,6 @@ instance Pretty (UEffect n) where URWSEffect rws h -> p rws <+> p h UExceptionEffect -> "Except" UIOEffect -> "IO" - UUserEffect name -> p name instance PrettyPrec (Name s n) where prettyPrec = atPrec ArgPrec . pretty @@ -435,10 +416,6 @@ instance Pretty (Binding c n) where FunObjCodeBinding _ -> "" ModuleBinding _ -> "" PtrBinding _ _ -> "" - -- TODO(alex): do something actually useful here - EffectBinding _ -> "" - HandlerBinding _ -> "" - EffectOpBinding _ -> "" SpecializedDictBinding _ -> "" ImpNameBinding ty -> "Imp name of type: " <+> p ty @@ -454,36 +431,29 @@ instance Pretty (TyConParams n) where pretty (TyConParams _ _) = undefined instance Pretty (TyConDef n) where - pretty (TyConDef name bs cons) = - "data" <+> p name <+> (p $ map (\(RolePiBinder _ b) -> b) $ fromNest bs) <> pretty cons + pretty (TyConDef name _ bs cons) = "data" <+> p name <+> p bs <> pretty cons instance Pretty (DataConDefs n) where pretty = undefined -instance Pretty (RolePiBinder n l) where - pretty (RolePiBinder _ b) = pretty b - instance Pretty (DataConDef n) where pretty (DataConDef name _ repTy _) = p name <+> ":" <+> p repTy instance Pretty (ClassDef n) where - pretty (ClassDef classSourceName methodNames _ params superclasses methodTys) = + pretty (ClassDef classSourceName methodNames _ _ params superclasses methodTys) = "Class:" <+> pretty classSourceName <+> pretty methodNames <> indented ( - line <> "parameter binders:" <+> prettyRolePiBinders params <> + line <> "parameter binders:" <+> pretty params <> line <> "superclasses:" <+> pretty superclasses <> line <> "methods:" <+> pretty methodTys) instance Pretty ParamRole where pretty r = p (show r) -prettyRolePiBinders :: RolePiBinders n l -> Doc ann -prettyRolePiBinders = undefined - instance Pretty (InstanceDef n) where - pretty (InstanceDef className bs params _) = - "Instance" <+> p className <+> prettyRolePiBinders bs <+> p params + pretty (InstanceDef className _ bs params _) = + "Instance" <+> p className <+> pretty bs <+> p params deriving instance (forall c n. Pretty (v c n)) => Pretty (RecSubst v o) @@ -652,14 +622,11 @@ instance Pretty FieldName' where instance Pretty (UAlt n) where pretty (UAlt pat body) = p pat <+> "->" <+> p body -instance PrettyB b => Pretty (WithExpl b n l) where - pretty (WithExpl _ b) = pretty b - instance Pretty (UTopDecl n l) where - pretty (UDataDefDecl (UDataDef nm bs dataCons) bTyCon bDataCons) = + pretty (UDataDefDecl (UDataDef nm (_, bs) dataCons) bTyCon bDataCons) = "data" <+> p bTyCon <+> p nm <+> spaced (fromNest bs) <+> "where" <> nest 2 (prettyLines (zip (toList $ fromNest bDataCons) dataCons)) - pretty (UStructDecl bTyCon (UStructDef nm bs fields defs)) = + pretty (UStructDecl bTyCon (UStructDef nm (_, bs) fields defs)) = "struct" <+> p bTyCon <+> p nm <+> spaced (fromNest bs) <+> "where" <> nest 2 (prettyLines fields <> prettyLines defs) pretty (UInterface params methodTys interfaceName methodNames) = @@ -785,14 +752,16 @@ instance Pretty (TopFunDef n) where instance Pretty (TopFun n) where pretty = \case - DexTopFun def ty simp lowering -> + DexTopFun def lam lowering -> "Top-level Function" <> hardline <+> "definition:" <+> pretty def - <> hardline <+> "type:" <+> pretty ty - <> hardline <+> "simplified:" <+> pretty simp + <> hardline <+> "lambda:" <+> pretty lam <> hardline <+> "lowering:" <+> pretty lowering FFITopFun f _ -> p f +instance IRRep r => Pretty (TopLam r n) where + pretty (TopLam _ _ lam) = pretty lam + instance Pretty a => Pretty (EvalStatus a) where pretty = \case Waiting -> "" @@ -927,7 +896,6 @@ instance IRRep r => PrettyPrec (PrimOp r n) where MemOp op -> prettyPrec op VectorOp op -> prettyPrec op DAMOp op -> prettyPrec op - UserEffectOp op -> prettyPrec op Hof (TypedHof _ hof) -> prettyPrec hof RefOp ref eff -> atPrec LowestPrec case eff of MAsk -> "ask" <+> pApp ref diff --git a/src/lib/QueryType.hs b/src/lib/QueryType.hs index cf6f98a26..f5952402b 100644 --- a/src/lib/QueryType.hs +++ b/src/lib/QueryType.hs @@ -9,6 +9,7 @@ module QueryType (module QueryType, module QueryTypePure, toAtomVar) where import Control.Category ((>>>)) import Control.Monad import Data.List (elemIndex) +import Data.Functor ((<&>)) import Types.Primitives import Types.Core @@ -47,16 +48,32 @@ caseAltsBinderTys ty = case ty of extendEffect :: IRRep r => Effect r n -> EffectRow r n -> EffectRow r n extendEffect eff (EffectRow effs t) = EffectRow (effs <> eSetSingleton eff) t -getDestLamExprType :: LamExpr SimpIR n -> PiType SimpIR n -getDestLamExprType (LamExpr bsRefB body) = +blockEffTy :: (EnvReader m, IRRep r) => Block r n -> m n (EffTy r n) +blockEffTy block = liftEnvReaderM $ refreshAbs block \decls result -> do + effs <- declsEffects decls mempty + return $ ignoreHoistFailure $ hoist decls $ EffTy effs $ getType result + where + declsEffects :: IRRep r => Nest (Decl r) n l -> EffectRow r l -> EnvReaderM l (EffectRow r l) + declsEffects Empty !acc = return acc + declsEffects n@(Nest (Let _ (DeclBinding _ expr)) rest) !acc = withExtEvidence n do + expr' <- sinkM expr + declsEffects rest $ acc <> getEffects expr' + +blockTy :: (EnvReader m, IRRep r) => Block r n -> m n (Type r n) +blockTy b = blockEffTy b <&> \(EffTy _ t) -> t + +piTypeWithoutDest :: PiType SimpIR n -> PiType SimpIR n +piTypeWithoutDest (PiType bsRefB _) = case popNest bsRefB of - Just (PairB bs (bDest:>RawRefTy ansTy)) -> do - let resultEffs = ignoreHoistFailure $ hoist bDest $ getEffects body - PiType bs $ EffTy resultEffs ansTy + Just (PairB bs (_:>RawRefTy ansTy)) -> do + PiType bs $ EffTy Pure ansTy -- XXX: we ignore the effects here _ -> error "expected trailing dest binder" +blockEff :: (EnvReader m, IRRep r) => Block r n -> m n (EffectRow r n) +blockEff b = blockEffTy b <&> \(EffTy eff _) -> eff + typeOfApp :: (IRRep r, EnvReader m) => Type r n -> [Atom r n] -> m n (Type r n) -typeOfApp (Pi (CorePiType _ bs (EffTy _ resultTy))) xs = do +typeOfApp (Pi (CorePiType _ _ bs (EffTy _ resultTy))) xs = do let subst = bs @@> fmap SubstVal xs applySubst subst resultTy typeOfApp _ _ = error "expected a pi type" @@ -76,14 +93,14 @@ typeOfApplyMethod d i args = do typeOfDictExpr :: EnvReader m => DictExpr n -> m n (CType n) typeOfDictExpr e = liftM ignoreExcept $ liftEnvReaderT $ case e of InstanceDict instanceName args -> do - InstanceDef className bs params _ <- lookupInstanceDef instanceName - ClassDef sourceName _ _ _ _ _ <- lookupClassDef className + InstanceDef className _ bs params _ <- lookupInstanceDef instanceName + ClassDef sourceName _ _ _ _ _ _ <- lookupClassDef className ListE params' <- applySubst (bs @@> map SubstVal args) $ ListE params return $ DictTy $ DictType sourceName className params' InstantiatedGiven given args -> typeOfApp (getType given) args SuperclassProj d i -> do DictTy (DictType _ className params) <- return $ getType d - ClassDef _ _ _ bs superclasses _ <- lookupClassDef className + ClassDef _ _ _ _ bs superclasses _ <- lookupClassDef className applySubst (bs @@> map SubstVal params) $ getSuperclassType REmpty superclasses i IxFin n -> liftM DictTy $ ixDictType $ NewtypeTyCon $ Fin n @@ -114,20 +131,21 @@ typeOfProjRef (TC (RefType h s)) p = do typeOfProjRef _ _ = error "expected a reference" appEffTy :: (IRRep r, EnvReader m) => Type r n -> [Atom r n] -> m n (EffTy r n) -appEffTy (Pi (CorePiType _ bs effTy)) xs = do +appEffTy (Pi (CorePiType _ _ bs effTy)) xs = do let subst = bs @@> fmap SubstVal xs applySubst subst effTy appEffTy t _ = error $ "expected a pi type, got: " ++ pprint t partialAppType :: (IRRep r, EnvReader m) => Type r n -> [Atom r n] -> m n (Type r n) -partialAppType (Pi (CorePiType expl bs effTy)) xs = do +partialAppType (Pi (CorePiType appExpl expls bs effTy)) xs = do + (_, expls2) <- return $ splitAt (length xs) expls PairB bs1 bs2 <- return $ splitNestAt (length xs) bs let subst = bs1 @@> fmap SubstVal xs - applySubst subst $ Pi $ CorePiType expl bs2 effTy + applySubst subst $ Pi $ CorePiType appExpl expls2 bs2 effTy partialAppType _ _ = error "expected a pi type" appEffects :: (EnvReader m, IRRep r) => Type r n -> [Atom r n] -> m n (EffectRow r n) -appEffects (Pi (CorePiType _ bs (EffTy effs _))) xs = do +appEffects (Pi (CorePiType _ _ bs (EffTy effs _))) xs = do let subst = bs @@> fmap SubstVal xs applySubst subst effs appEffects _ _ = error "expected a pi type" @@ -137,40 +155,42 @@ effTyOfHof hof = EffTy <$> hofEffects hof <*> typeOfHof hof typeOfHof :: (EnvReader m, IRRep r) => Hof r n -> m n (Type r n) typeOfHof = \case - For _ ixTy f -> case getLamExprType f of + For _ ixTy f -> getLamExprType f >>= \case PiType (UnaryNest b) (EffTy _ eltTy) -> return $ TabTy (ixTypeDict ixTy) b eltTy _ -> error "expected a unary pi type" While _ -> return UnitTy - Linearize f _ -> case getLamExprType f of + Linearize f _ -> getLamExprType f >>= \case PiType (UnaryNest (binder:>a)) (EffTy Pure b) -> do let b' = ignoreHoistFailure $ hoist binder b let fLinTy = Pi $ nonDepPiType [a] Pure b' return $ PairTy b' fLinTy _ -> error "expected a unary pi type" - Transpose f _ -> case getLamExprType f of + Transpose f _ -> getLamExprType f >>= \case PiType (UnaryNest (_:>a)) _ -> return a _ -> error "expected a unary pi type" - RunReader _ f -> return resultTy - where (resultTy, _) = getTypeRWSAction f - RunWriter _ _ f -> return $ uncurry PairTy $ getTypeRWSAction f - RunState _ _ f -> return $ PairTy resultTy stateTy - where (resultTy, stateTy) = getTypeRWSAction f - RunIO f -> return $ getType f - RunInit f -> return $ getType f + RunReader _ f -> do + (resultTy, _) <- getTypeRWSAction f + return resultTy + RunWriter _ _ f -> uncurry PairTy <$> getTypeRWSAction f + RunState _ _ f -> do + (resultTy, stateTy) <- getTypeRWSAction f + return $ PairTy resultTy stateTy + RunIO f -> blockTy f + RunInit f -> blockTy f CatchException ty _ -> return ty hofEffects :: (EnvReader m, IRRep r) => Hof r n -> m n (EffectRow r n) hofEffects = \case For _ _ f -> functionEffs f - While body -> return $ getEffects body + While body -> blockEff body Linearize _ _ -> return Pure -- Body has to be a pure function Transpose _ _ -> return Pure -- Body has to be a pure function RunReader _ f -> rwsFunEffects Reader f RunWriter d _ f -> maybeInit d <$> rwsFunEffects Writer f RunState d _ f -> maybeInit d <$> rwsFunEffects State f - RunIO f -> return $ deleteEff IOEffect $ getEffects f - RunInit f -> return $ deleteEff InitEffect $ getEffects f - CatchException _ f -> return $ deleteEff ExceptionEffect $ getEffects f + RunIO f -> deleteEff IOEffect <$> blockEff f + RunInit f -> deleteEff InitEffect <$> blockEff f + CatchException _ f -> deleteEff ExceptionEffect <$> blockEff f where maybeInit :: IRRep r => Maybe (Atom r i) -> (EffectRow r o -> EffectRow r o) maybeInit d = case d of Just _ -> (<>OneEffect InitEffect); Nothing -> id @@ -179,7 +199,7 @@ deleteEff eff (EffectRow effs t) = EffectRow (effs `eSetDifference` eSetSingleto getMethodIndex :: EnvReader m => ClassName n -> SourceName -> m n Int getMethodIndex className methodSourceName = do - ClassDef _ methodNames _ _ _ _ <- lookupClassDef className + ClassDef _ methodNames _ _ _ _ _ <- lookupClassDef className case elemIndex methodSourceName methodNames of Nothing -> error $ methodSourceName ++ " is not a method of " ++ pprint className Just i -> return i @@ -192,9 +212,8 @@ getUVarType = \case UDataConVar v -> getDataConNameType v UPunVar v -> getStructDataConType v UClassVar v -> do - ClassDef _ _ _ bs _ _ <- lookupClassDef v - let bs' = fmapNest (\(RolePiBinder _ b) -> b) bs - return $ Pi $ CorePiType ExplicitApp bs' $ EffTy Pure TyKind + ClassDef _ _ _ roleExpls bs _ _ <- lookupClassDef v + return $ Pi $ CorePiType ExplicitApp (map snd roleExpls) bs $ EffTy Pure TyKind UMethodVar v -> getMethodNameType v UEffectVar _ -> error "not implemented" UEffectOpVar _ -> error "not implemented" @@ -202,23 +221,22 @@ getUVarType = \case getMethodNameType :: EnvReader m => MethodName n -> m n (CType n) getMethodNameType v = liftEnvReaderM $ lookupEnv v >>= \case MethodBinding className i -> do - ClassDef _ _ paramNames paramBs scBinders methodTys <- lookupClassDef className - let paramBs' = zipWithNest paramBs paramNames \(RolePiBinder _ (WithExpl _ b)) paramName -> - WithExpl (Inferred paramName Unify) b - refreshAbs (Abs paramBs' $ Abs scBinders (methodTys !! i)) \paramBs'' (Abs scBinders' piTy) -> do - let params = Var <$> nestToAtomVars (fmapNest withoutExpl paramBs'') + ClassDef _ _ paramNames _ paramBs scBinders methodTys <- lookupClassDef className + refreshAbs (Abs paramBs $ Abs scBinders (methodTys !! i)) \paramBs' (Abs scBinders' piTy) -> do + let params = Var <$> nestToAtomVars paramBs' dictTy <- DictTy <$> dictType (sink className) params withFreshBinder noHint dictTy \dictB -> do scDicts <- getSuperclassDicts (Var $ binderVar dictB) piTy' <- applySubst (scBinders'@@>(SubstVal<$>scDicts)) piTy - CorePiType appExpl methodBs effTy <- return piTy' - let dictBs = UnaryNest $ WithExpl (Inferred Nothing (Synth $ Partial $ succ i)) dictB - return $ Pi $ CorePiType appExpl (paramBs'' >>> dictBs >>> methodBs) effTy + CorePiType appExpl methodExpls methodBs effTy <- return piTy' + let paramExpls = paramNames <&> \name -> Inferred name Unify + let expls = paramExpls <> [Inferred Nothing (Synth $ Partial $ succ i)] <> methodExpls + return $ Pi $ CorePiType appExpl expls (paramBs' >>> UnaryNest dictB >>> methodBs) effTy getMethodType :: EnvReader m => Dict n -> Int -> m n (CorePiType n) getMethodType dict i = do ~(DictTy (DictType _ className params)) <- return $ getType dict - ClassDef _ _ _ paramBs classBs methodTys <- lookupClassDef className + ClassDef _ _ _ _ paramBs classBs methodTys <- lookupClassDef className let methodTy = methodTys !! i superclassDicts <- getSuperclassDicts dict let subst = ( paramBs @@> map SubstVal params @@ -226,60 +244,56 @@ getMethodType dict i = do applySubst subst methodTy {-# INLINE getMethodType #-} - getTyConNameType :: EnvReader m => TyConName n -> m n (Type CoreIR n) getTyConNameType v = do - TyConDef _ bs _ <- lookupTyCon v + TyConDef _ expls bs _ <- lookupTyCon v case bs of Empty -> return TyKind - _ -> do - let bs' = fmapNest (\(RolePiBinder _ b) -> b) bs - return $ Pi $ CorePiType ExplicitApp bs' $ EffTy Pure TyKind + _ -> return $ Pi $ CorePiType ExplicitApp (snd <$> expls) bs $ EffTy Pure TyKind getDataConNameType :: EnvReader m => DataConName n -> m n (Type CoreIR n) getDataConNameType dataCon = liftEnvReaderM do (tyCon, i) <- lookupDataCon dataCon lookupTyCon tyCon >>= \case - tyConDef@(TyConDef tcSn paramBs ~(ADTCons dataCons)) -> - buildDataConType tyConDef \paramBs' paramVs params -> do + tyConDef@(TyConDef tcSn _ paramBs ~(ADTCons dataCons)) -> + buildDataConType tyConDef \expls paramBs' paramVs params -> do DataConDef _ ab _ _ <- applyRename (paramBs @@> paramVs) (dataCons !! i) refreshAbs ab \dataBs UnitE -> do let appExpl = case dataBs of Empty -> ImplicitApp _ -> ExplicitApp let resultTy = NewtypeTyCon $ UserADTType tcSn (sink tyCon) (sink params) - let dataBs' = fmapNest (WithExpl Explicit) dataBs - return $ Pi $ CorePiType appExpl (paramBs' >>> dataBs') (EffTy Pure resultTy) + let dataExpls = nestToList (const $ Explicit) dataBs + return $ Pi $ CorePiType appExpl (expls <> dataExpls) (paramBs' >>> dataBs) (EffTy Pure resultTy) getStructDataConType :: EnvReader m => TyConName n -> m n (CType n) getStructDataConType tyCon = liftEnvReaderM do - tyConDef@(TyConDef tcSn paramBs ~(StructFields fields)) <- lookupTyCon tyCon - buildDataConType tyConDef \paramBs' paramVs params -> do + tyConDef@(TyConDef tcSn _ paramBs ~(StructFields fields)) <- lookupTyCon tyCon + buildDataConType tyConDef \expls paramBs' paramVs params -> do fieldTys <- forM fields \(_, t) -> applyRename (paramBs @@> paramVs) t let resultTy = NewtypeTyCon $ UserADTType tcSn (sink tyCon) params Abs dataBs resultTy' <- return $ typesAsBinderNest fieldTys resultTy - let dataBs' = fmapNest (WithExpl Explicit) dataBs - return $ Pi $ CorePiType ExplicitApp (paramBs' >>> dataBs') (EffTy Pure resultTy') + let dataExpls = nestToList (const Explicit) dataBs + return $ Pi $ CorePiType ExplicitApp (expls <> dataExpls) (paramBs' >>> dataBs) (EffTy Pure resultTy') buildDataConType :: (EnvReader m, EnvExtender m) => TyConDef n - -> (forall l. DExt n l => Nest (WithExpl CBinder) n l -> [CAtomName l] -> TyConParams l -> m l a) + -> (forall l. DExt n l => [Explicitness] -> Nest CBinder n l -> [CAtomName l] -> TyConParams l -> m l a) -> m n a -buildDataConType (TyConDef _ bs _) cont = do - bs' <- return $ forNest bs \(RolePiBinder _ (WithExpl expl b)) -> case expl of - Explicit -> WithExpl (Inferred Nothing Unify) b - _ -> WithExpl expl b - refreshAbs (Abs bs' UnitE) \bs'' UnitE -> do - let expls = nestToList (\(RolePiBinder _ b) -> getExpl b) bs - let vs = nestToNames bs'' +buildDataConType (TyConDef _ roleExpls bs _) cont = do + let expls = snd <$> roleExpls + expls' <- forM expls \case + Explicit -> return $ Inferred Nothing Unify + expl -> return $ expl + refreshAbs (Abs bs UnitE) \bs' UnitE -> do + let vs = nestToNames bs' vs' <- mapM toAtomVar vs - cont bs'' vs $ TyConParams expls (Var <$> vs') + cont expls' bs' vs $ TyConParams expls (Var <$> vs') makeTyConParams :: EnvReader m => TyConName n -> [CAtom n] -> m n (TyConParams n) makeTyConParams tc params = do - TyConDef _ paramBs _ <- lookupTyCon tc - let expls = nestToList (\(RolePiBinder _ b) -> getExpl b) paramBs - return $ TyConParams expls params + TyConDef _ expls _ _ <- lookupTyCon tc + return $ TyConParams (map snd expls) params getDataClassName :: (Fallible1 m, EnvReader m) => m n (ClassName n) getDataClassName = lookupSourceMap "Data" >>= \case @@ -300,7 +314,7 @@ getIxClassName = lookupSourceMap "Ix" >>= \case dictType :: EnvReader m => ClassName n -> [CAtom n] -> m n (DictType n) dictType className params = do - ClassDef sourceName _ _ _ _ _ <- lookupClassDef className + ClassDef sourceName _ _ _ _ _ _ <- lookupClassDef className return $ DictType sourceName className params ixDictType :: (Fallible1 m, EnvReader m) => CType n -> m n (DictType n) @@ -316,58 +330,35 @@ makePreludeMaybeTy ty = do -- === computing effects === functionEffs :: (IRRep r, EnvReader m) => LamExpr r n -> m n (EffectRow r n) -functionEffs f = case getLamExprType f of +functionEffs f = getLamExprType f >>= \case PiType b (EffTy effs _) -> return $ ignoreHoistFailure $ hoist b effs rwsFunEffects :: (IRRep r, EnvReader m) => RWS -> LamExpr r n -> m n (EffectRow r n) -rwsFunEffects rws f = return case getLamExprType f of +rwsFunEffects rws f = getLamExprType f >>= \case PiType (BinaryNest h ref) et -> do let effs' = ignoreHoistFailure $ hoist ref (etEff et) let hVal = Var $ AtomVar (binderName h) (TC HeapType) let effs'' = deleteEff (RWSEffect rws hVal) effs' - ignoreHoistFailure $ hoist h effs'' + return $ ignoreHoistFailure $ hoist h effs'' _ -> error "Expected a binary function type" -getLamExprType :: IRRep r => LamExpr r n -> PiType r n -getLamExprType (LamExpr bs body) = PiType bs (EffTy (getEffects body) (getType body)) +getLamExprType :: (IRRep r, EnvReader m) => LamExpr r n -> m n (PiType r n) +getLamExprType (LamExpr bs body) = liftEnvReaderM $ + refreshAbs (Abs bs body) \bs' body' -> do + effTy <- blockEffTy body' + return $ PiType bs' effTy -getTypeRWSAction :: IRRep r => LamExpr r n -> (Type r n, Type r n) -getTypeRWSAction f = case getLamExprType f of +getTypeRWSAction :: (IRRep r, EnvReader m) => LamExpr r n -> m n (Type r n, Type r n) +getTypeRWSAction f = getLamExprType f >>= \case PiType (BinaryNest regionBinder refBinder) (EffTy _ resultTy) -> do case binderType refBinder of RefTy _ referentTy -> do let referentTy' = ignoreHoistFailure $ hoist regionBinder referentTy let resultTy' = ignoreHoistFailure $ hoist (PairB regionBinder refBinder) resultTy - (resultTy', referentTy') + return (resultTy', referentTy') _ -> error "expected a ref" _ -> error "expected a pi type" -computeAbsEffects :: (IRRep r, EnvExtender m, RenameE e) - => Abs (Nest (Decl r)) e n -> m n (Abs (Nest (Decl r)) (EffectRow r `PairE` e) n) -computeAbsEffects it = refreshAbs it \decls result -> do - effs <- declNestEffects decls - return $ Abs decls (effs `PairE` result) -{-# INLINE computeAbsEffects #-} - -declNestEffects :: (IRRep r, EnvReader m) => Nest (Decl r) n l -> m l (EffectRow r l) -declNestEffects decls = liftEnvReaderM $ declNestEffectsRec decls mempty -{-# INLINE declNestEffects #-} - -declNestEffectsRec :: IRRep r => Nest (Decl r) n l -> EffectRow r l -> EnvReaderM l (EffectRow r l) -declNestEffectsRec Empty !acc = return acc -declNestEffectsRec n@(Nest decl rest) !acc = withExtEvidence n do - expr <- sinkM $ declExpr decl - acc' <- sinkM $ acc <> (getEffects expr) - declNestEffectsRec rest acc' - where - declExpr :: Decl r n l -> Expr r n - declExpr (Let _ (DeclBinding _ expr)) = expr - -instantiateHandlerType :: EnvReader m => HandlerName n -> CType n -> [CAtom n] -> m n (CType n) -instantiateHandlerType hndName r args = do - HandlerDef _ rb bs _effs retTy _ _ <- lookupHandlerDef hndName - applySubst (rb @> (SubstVal (Type r)) <.> bs @@> (map SubstVal args)) retTy - getSuperclassDicts :: EnvReader m => CAtom n -> m n ([CAtom n]) getSuperclassDicts dict = do case getType dict of @@ -378,16 +369,21 @@ getSuperclassDicts dict = do getSuperclassTys :: EnvReader m => DictType n -> m n [CType n] getSuperclassTys (DictType _ className params) = do - ClassDef _ _ _ bs superclasses _ <- lookupClassDef className + ClassDef _ _ _ _ bs superclasses _ <- lookupClassDef className forM [0 .. nestLength superclasses - 1] \i -> do applySubst (bs @@> map SubstVal params) $ getSuperclassType REmpty superclasses i getTypeTopFun :: EnvReader m => TopFunName n -> m n (PiType SimpIR n) getTypeTopFun f = lookupTopFun f >>= \case - DexTopFun _ piTy _ _ -> return piTy + DexTopFun _ (TopLam _ piTy _) _ -> return piTy FFITopFun _ iTy -> liftIFunType iTy +asTopLam :: (EnvReader m, IRRep r) => LamExpr r n -> m n (TopLam r n) +asTopLam lam = do + piTy <- getLamExprType lam + return $ TopLam False piTy lam + liftIFunType :: (IRRep r, EnvReader m) => IFunType -> m n (PiType r n) liftIFunType (IFunType _ argTys resultTys) = liftEnvReaderM $ go argTys where go :: IRRep r => [BaseType] -> EnvReaderM n (PiType r n) diff --git a/src/lib/QueryTypePure.hs b/src/lib/QueryTypePure.hs index e71769003..9be267241 100644 --- a/src/lib/QueryTypePure.hs +++ b/src/lib/QueryTypePure.hs @@ -119,8 +119,8 @@ instance IRRep r => HasType r (Con r) where getSuperclassType :: RNest CBinder n l -> Nest CBinder l l' -> Int -> CType n getSuperclassType _ Empty = error "bad index" -getSuperclassType bsAbove (Nest b bs) = \case - 0 -> ignoreHoistFailure $ hoist bsAbove $ binderType b +getSuperclassType bsAbove (Nest b@(_:>t) bs) = \case + 0 -> ignoreHoistFailure $ hoist bsAbove t i -> getSuperclassType (RNest bsAbove b) bs (i-1) instance IRRep r => HasType r (Expr r) where @@ -144,12 +144,6 @@ instance IRRep r => HasType r (DAMOp r) where Seq _ _ _ cinit _ -> getType cinit RememberDest _ d _ -> getType d -instance HasType CoreIR UserEffectOp where - getType = \case - Handle _ _ _ -> undefined - Perform _ _ -> undefined - Resume retTy _ -> retTy - instance IRRep r => HasType r (PrimOp r) where getType primOp = case primOp of BinOp op x _ -> TC $ BaseType $ typeBinOp op $ getTypeBaseType x @@ -159,7 +153,6 @@ instance IRRep r => HasType r (PrimOp r) where MiscOp op -> getType op VectorOp op -> getType op DAMOp op -> getType op - UserEffectOp op -> getType op RefOp ref m -> case getType ref of TC (RefType _ s) -> case m of MGet -> s @@ -220,6 +213,9 @@ rawStrType = case newName "n" of rawFinTabType :: IRRep r => Atom r n -> Type r n -> Type r n rawFinTabType n eltTy = IxType IdxRepTy (IxDictRawFin n) ==> eltTy +tabIxType :: TabPiType r n -> IxType r n +tabIxType (TabPiType d (_:>t) _) = IxType t d + typesAsBinderNest :: (SinkableE e, HoistableE e, IRRep r) => [Type r n] -> e n -> Abs (Nest (Binder r)) e n @@ -228,14 +224,20 @@ typesAsBinderNest types body = toConstBinderNest types body nonDepPiType :: [CType n] -> EffectRow CoreIR n -> CType n -> CorePiType n nonDepPiType argTys eff resultTy = case typesAsBinderNest argTys (PairE eff resultTy) of Abs bs (PairE eff' resultTy') -> do - let bs' = fmapNest (WithExpl Explicit) bs - CorePiType ExplicitApp bs' $ EffTy eff' resultTy' + let expls = nestToList (const Explicit) bs + CorePiType ExplicitApp expls bs $ EffTy eff' resultTy' nonDepTabPiType :: IRRep r => IxType r n -> Type r n -> TabPiType r n nonDepTabPiType (IxType t d) resultTy = case toConstAbsPure resultTy of Abs b resultTy' -> TabPiType d (b:>t) resultTy' +corePiTypeToPiType :: CorePiType n -> PiType CoreIR n +corePiTypeToPiType (CorePiType _ _ bs effTy) = PiType bs effTy + +coreLamToTopLam :: CoreLamExpr n -> TopLam CoreIR n +coreLamToTopLam (CoreLamExpr ty f) = TopLam False (corePiTypeToPiType ty) f + (==>) :: IRRep r => IxType r n -> Type r n -> Type r n a ==> b = TabPi $ nonDepTabPiType a b @@ -253,11 +255,6 @@ ixTyFromDict ixDict = flip IxType ixDict $ case ixDict of IxDictRawFin _ -> IdxRepTy IxDictSpecialized n _ _ -> n -instance IRRep r => HasType r (Block r) where - getType (Block NoBlockAnn Empty result) = getType result - getType (Block (BlockAnn (EffTy _ ty)) _ _) = ty - getType _ = error "impossible" - -- === querying effects implementation === instance IRRep r => HasEffects (Expr r) r where @@ -309,7 +306,6 @@ instance IRRep r => HasEffects (PrimOp r) r where IndexRef _ _ -> Pure ProjRef _ _ -> Pure _ -> error "not a ref" - UserEffectOp _ -> undefined DAMOp op -> case op of Place _ _ -> OneEffect InitEffect Seq eff _ _ _ _ -> eff @@ -318,13 +314,3 @@ instance IRRep r => HasEffects (PrimOp r) r where Freeze _ -> Pure -- is this correct? Hof (TypedHof (EffTy eff _) _) -> eff {-# INLINE getEffects #-} - - -instance IRRep r => HasEffects (Block r) r where - getEffects (Block (BlockAnn (EffTy effs _)) _ _) = effs - getEffects (Block NoBlockAnn _ _) = Pure - {-# INLINE getEffects #-} - -instance IRRep r => HasEffects (Alt r) r where - getEffects (Abs bs body) = ignoreHoistFailure $ hoist bs (getEffects body) - {-# INLINE getEffects #-} diff --git a/src/lib/RuntimePrint.hs b/src/lib/RuntimePrint.hs index d1b17c792..4a4c2c6a5 100644 --- a/src/lib/RuntimePrint.hs +++ b/src/lib/RuntimePrint.hs @@ -169,8 +169,9 @@ withBuffer cont = do body <- buildBlock do cont $ sink $ Var $ binderVar b return UnitVal - let piBinders = BinaryNest (WithExpl (Inferred Nothing Unify) h) (WithExpl Explicit b) - let piTy = CorePiType ExplicitApp piBinders $ EffTy eff UnitTy + let binders = BinaryNest h b + let expls = [Inferred Nothing Unify, Explicit] + let piTy = CorePiType ExplicitApp expls binders $ EffTy eff UnitTy let lam = LamExpr (BinaryNest h b) body return $ Lam $ CoreLamExpr piTy lam applyPreludeFunction "with_stack_internal" [lam] @@ -184,8 +185,8 @@ bufferTy h = do extendBuffer :: (Emits n, CBuilder m) => CAtom n -> CAtom n -> m n () extendBuffer buf tab = do RefTy h _ <- return $ getType buf - TabTy d (_:>t) _ <- return $ getType tab - n <- applyIxMethodCore Size (IxType t d) [] + TabPi t <- return $ getType tab + n <- applyIxMethodCore Size (tabIxType t) [] void $ applyPreludeFunction "stack_extend_internal" [n, h, buf, tab] -- argument has type `Word8` @@ -236,8 +237,8 @@ forEachTabElt -> (forall l. (Emits l, DExt n l) => CAtom l -> CAtom l -> m l ()) -> m n () forEachTabElt tab cont = do - TabTy d (_:>t) _ <- return $ getType tab - let ixTy = IxType t d + TabPi t <- return $ getType tab + let ixTy = tabIxType t void $ buildFor "i" Fwd ixTy \i -> do x <- tabApp (sink tab) (Var i) i' <- applyIxMethodCore Ordinal (sink ixTy) [Var i] diff --git a/src/lib/Simplify.hs b/src/lib/Simplify.hs index b3167ef0b..a40b01fa5 100644 --- a/src/lib/Simplify.hs +++ b/src/lib/Simplify.hs @@ -7,8 +7,8 @@ {-# LANGUAGE UndecidableInstances #-} module Simplify - ( simplifyTopBlock, simplifyTopFunction, SimplifiedBlock (..), ReconstructAtom (..), applyReconTop, - linearizeTopFun) where + ( simplifyTopBlock, simplifyTopFunction, ReconstructAtom (..), applyReconTop, + linearizeTopFun, SimplifiedTopLam (..)) where import Control.Applicative import Control.Category ((>>>)) @@ -19,7 +19,7 @@ import Data.Text.Prettyprint.Doc (Pretty (..), hardline) import Builder import CheapReduction -import CheckType (CheckableE (..), isData) +import CheckType (CheckableE (..), isData, checkBlock) import Core import Err import Generalize @@ -106,9 +106,8 @@ tryAsDataAtom atom = do forceTabLam :: Emits n => TabLamExpr n -> SimplifyM i n (SAtom n) forceTabLam (PairE ixTy (Abs b ab)) = buildFor (getNameHint b) Fwd ixTy \v -> do - Abs decls result <- applyRename (b@>(atomVarName v)) ab - result' <- emitDecls decls result - toDataAtomIgnoreRecon result' + result <- applyRename (b@>(atomVarName v)) ab >>= emitDecls + toDataAtomIgnoreRecon result type NaryTabLamExpr = Abs (Nest SBinder) (Abs (Nest SDecl) CAtom) @@ -154,12 +153,12 @@ getRepType ty = go ty where x <- liftSimpAtom (sink l) (Var $ binderVar b') r' <- go =<< applySubst (b@>SubstVal x) r return $ DepPairTy $ DepPairType expl b' r' - TabPi (TabPiType d (b:>t) bodyTy) -> do - let ixTy = IxType t d + TabPi tabTy -> do + let ixTy = tabIxType tabTy IxType t' d' <- simplifyIxType ixTy - withFreshBinder (getNameHint b) t' \b' -> do + withFreshBinder (getNameHint tabTy) t' \b' -> do x <- liftSimpAtom (sink $ ixTypeType ixTy) (Var $ binderVar b') - bodyTy' <- go =<< applySubst (b@>SubstVal x) bodyTy + bodyTy' <- go =<< instantiateTabPiTy (sink tabTy) x return $ TabPi $ TabPiType d' b' bodyTy' NewtypeTyCon con -> do (_, ty') <- unwrapNewtypeType con @@ -246,18 +245,24 @@ instance ScopableBuilder SimpIR (SimplifyM i) where -- === Top-level API === +data SimplifiedTopLam n = SimplifiedTopLam (STopLam n) (ReconstructAtom n) data SimplifiedBlock n = SimplifiedBlock (SBlock n) (ReconstructAtom n) --- TODO: extend this to work on functions instead of blocks (with blocks still --- accessible as nullary functions) -simplifyTopBlock :: (TopBuilder m, Mut n) => Block CoreIR n -> m n (SimplifiedBlock n) -simplifyTopBlock block = liftSimplifyM $ buildSimplifiedBlock $ simplifyBlock block +simplifyTopBlock + :: (TopBuilder m, Mut n) => TopBlock CoreIR n -> m n (SimplifiedTopLam n) +simplifyTopBlock (TopLam _ _ (LamExpr Empty body)) = do + SimplifiedBlock block recon <- liftSimplifyM $ buildSimplifiedBlock $ simplifyBlock body + topLam <- asTopLam $ LamExpr Empty block + return $ SimplifiedTopLam topLam recon +simplifyTopBlock _ = error "not a block (nullary lambda)" {-# SCC simplifyTopBlock #-} -simplifyTopFunction :: (TopBuilder m, Mut n) => LamExpr CoreIR n -> m n (LamExpr SimpIR n) -simplifyTopFunction f = liftSimplifyM do - (lam, CoerceReconAbs) <- simplifyLam f - return lam +simplifyTopFunction :: (TopBuilder m, Mut n) => CTopLam n -> m n (STopLam n) +simplifyTopFunction (TopLam False _ f) = do + asTopLam =<< liftSimplifyM do + (lam, CoerceReconAbs) <- simplifyLam f + return lam +simplifyTopFunction _ = error "shouldn't be in destination-passing style already" {-# SCC simplifyTopFunction #-} applyReconTop :: (EnvReader m, Fallible1 m) => ReconstructAtom n -> SAtom n -> m n (CAtom n) @@ -276,12 +281,23 @@ instance HoistableE SimplifiedBlock instance CheckableE SimpIR SimplifiedBlock where checkE (SimplifiedBlock block _) = -- TODO: CheckableE instance for the recon too - checkE block + void $ checkBlock block instance Pretty (SimplifiedBlock n) where pretty (SimplifiedBlock block recon) = pretty block <> hardline <> pretty recon +instance SinkableE SimplifiedTopLam where + sinkingProofE = todoSinkableProof + +instance CheckableE SimpIR SimplifiedTopLam where + checkE (SimplifiedTopLam lam _) = do + -- TODO: CheckableE instance for the recon too + checkE lam + +instance Pretty (SimplifiedTopLam n) where + pretty (SimplifiedTopLam lam recon) = pretty lam <> hardline <> pretty recon + -- === All the bits of IR === simplifyDecls :: Emits o => Nest (Decl CoreIR) i i' -> SimplifyM i' o a -> SimplifyM i o a @@ -349,13 +365,6 @@ simplifyRefOp op ref = case op of ProjRef _ UnwrapNewtype -> return ref where emitRefOp op' = emitOp $ RefOp ref op' -caseComputingEffs - :: forall m n r. (MonadFail1 m, EnvReader m, IRRep r) - => Atom r n -> [Alt r n] -> Type r n -> m n (Expr r n) -caseComputingEffs scrut alts resultTy = do - return $ Case scrut alts (EffTy (foldMap getEffects alts) resultTy) -{-# INLINE caseComputingEffs #-} - defuncCaseCore :: Emits o => Atom CoreIR o -> Type CoreIR o -> (forall o'. (Emits o', DExt o o') => Int -> CAtom o' -> SimplifyM i o' (CAtom o')) @@ -396,17 +405,16 @@ defuncCase scrut resultTy cont = do alts' <- forM (enumerate altBinderTys) \(i, bTy) -> do buildAbs noHint bTy \x -> do buildBlock $ cont i (sink $ Var x) >>= toDataAtomIgnoreRecon - caseExpr <- caseComputingEffs scrut alts' resultTyData + caseExpr <- mkCase scrut resultTyData alts' emitExpr caseExpr >>= liftSimpAtom resultTy Nothing -> do split <- splitDataComponents resultTy - (alts', recons) <- unzip <$> forM (enumerate altBinderTys) \(i, bTy) -> do + (alts', closureTys, recons) <- unzip3 <$> forM (enumerate altBinderTys) \(i, bTy) -> do simplifyAlt split bTy $ cont i - closureTys <- mapM getAltNonDataTy alts' let closureSumTy = SumTy closureTys let newNonDataTy = nonDataTy split alts'' <- forM (enumerate alts') \(i, alt) -> injectAltResult closureTys i alt - caseExpr <- caseComputingEffs scrut alts'' (PairTy (dataTy split) closureSumTy) + caseExpr <- mkCase scrut (PairTy (dataTy split) closureSumTy) alts'' caseResult <- emitExpr $ caseExpr (dataVal, sumVal) <- fromPair caseResult reconAlts <- forM (zip closureTys recons) \(ty, recon) -> @@ -419,11 +427,11 @@ simplifyAlt :: SplitDataNonData n -> SType o -> (forall o'. (Emits o', DExt o o') => SAtom o' -> SimplifyM i o' (CAtom o')) - -> SimplifyM i o (Alt SimpIR o, ReconstructAtom o) -simplifyAlt split ty cont = fromPairE <$> do + -> SimplifyM i o (Alt SimpIR o, SType o, ReconstructAtom o) +simplifyAlt split ty cont = do withFreshBinder noHint ty \b -> do ab <- buildScoped $ cont $ sink $ Var $ binderVar b - (resultWithDecls, recon) <- refreshAbs ab \decls result -> do + (body, recon) <- refreshAbs ab \decls result -> do let locals = toScopeFrag b >>> toScopeFrag decls -- TODO: this might be too cautious. The type only needs to -- be hoistable above the decls. In principle it can still @@ -432,15 +440,9 @@ simplifyAlt split ty cont = fromPairE <$> do (resultData, resultNonData) <- toSplit split result (newResult, reconAbs) <- telescopicCapture locals resultNonData return (Abs decls (PairVal resultData newResult), LamRecon reconAbs) - block <- makeBlockFromDecls resultWithDecls - return $ PairE (Abs b block) recon - -getAltNonDataTy :: EnvReader m => Alt SimpIR n -> m n (SType n) -getAltNonDataTy (Abs bs body) = liftSubstEnvReaderM @AtomSubstVal do - substBinders bs \bs' -> do - ~(PairTy _ ty) <- substM $ getType body - -- Result types of simplified abs should be hoistable past binder - return $ ignoreHoistFailure $ hoist bs' ty + EffTy _ (PairTy _ nonDataType) <- blockEffTy body + let nonDataType' = ignoreHoistFailure $ hoist b nonDataType + return (Abs b body, nonDataType', recon) simplifyApp :: forall i o. Emits o => NameHint -> CType o -> CAtom i -> [CAtom o] -> SimplifyM i o (CAtom o) @@ -531,9 +533,9 @@ emitSpecialization s = do extendSpecializationCache s name return name -specializedFunCoreDefinition :: (Mut n, TopBuilder m) => SpecializationSpec n -> m n (LamExpr CoreIR n) +specializedFunCoreDefinition :: (Mut n, TopBuilder m) => SpecializationSpec n -> m n (TopLam CoreIR n) specializedFunCoreDefinition (AppSpecialization f (Abs bs staticArgs)) = do - liftBuilder $ buildLamExpr (EmptyAbs bs) \runtimeArgs -> do + (asTopLam =<<) $ liftBuilder $ buildLamExpr (EmptyAbs bs) \runtimeArgs -> do -- This avoids an infinite loop. Otherwise, in simplifyTopFunction, -- where we eta-expand and try to simplify `App f args`, we would see `f` as a -- "noinline" function and defer its simplification. @@ -547,12 +549,12 @@ simplifyTabApp f [] = return f simplifyTabApp f@(SimpInCore sic) xs = case sic of TabLam _ _ -> do case fromNaryTabLam (length xs) f of - Just (bsCount, Abs bs declsAtom) -> do + Just (bsCount, Abs bs block) -> do let (xsPref, xsRest) = splitAt bsCount xs xsPref' <- mapM toDataAtomIgnoreRecon xsPref - Abs decls atom <- applySubst (bs@@>(SubstVal <$> xsPref')) declsAtom - atom' <- emitDecls decls atom - simplifyTabApp atom' xsRest + block' <- applySubst (bs@@>(SubstVal <$> xsPref')) block + atom <- emitDecls block' + simplifyTabApp atom xsRest Nothing -> error "should never happen" ACase e alts ty -> dropSubst do resultTy <- typeOfTabApp ty xs @@ -604,10 +606,10 @@ requireIxDictCache dictAbs = do Nothing -> error "Couldn't hoist specialized dictionary" {-# INLINE requireIxDictCache #-} -simplifyDictMethod :: Mut n => AbsDict n -> IxMethod -> TopBuilderM n (SLam n) +simplifyDictMethod :: Mut n => AbsDict n -> IxMethod -> TopBuilderM n (TopLam SimpIR n) simplifyDictMethod absDict@(Abs bs dict) method = do ty <- liftEnvReaderM $ ixMethodType method absDict - lamExpr <- liftBuilder $ buildLamExprFromPi ty \allArgs -> do + lamExpr <- liftBuilder $ buildTopLamFromPi ty \allArgs -> do let (extraArgs, methodArgs) = splitAt (nestLength bs) allArgs dict' <- applyRename (bs @@> (atomVarName <$> extraArgs)) dict emitExpr =<< mkApplyMethod dict' (fromEnum method) (Var <$> methodArgs) @@ -616,8 +618,8 @@ simplifyDictMethod absDict@(Abs bs dict) method = do ixMethodType :: IxMethod -> AbsDict n -> EnvReaderM n (PiType CoreIR n) ixMethodType method absDict = do refreshAbs absDict \extraArgBs dict -> do - CorePiType _ methodArgs (EffTy _ resultTy) <- getMethodType dict (fromEnum method) - let allBs = extraArgBs >>> fmapNest withoutExpl methodArgs + CorePiType _ _ methodArgs (EffTy _ resultTy) <- getMethodType dict (fromEnum method) + let allBs = extraArgBs >>> methodArgs return $ PiType allBs (EffTy Pure resultTy) -- TODO: do we even need this, or is it just a glorified `SubstM`? @@ -716,19 +718,16 @@ buildSimplifiedBlock cont = do return $ RightE (dataResult `PairE` ansTy) case eitherResult of LeftE ans -> do - (declsResult, recon) <- refreshAbs (Abs decls ans) \decls' ans' -> do + (block, recon) <- refreshAbs (Abs decls ans) \decls' ans' -> do (newResult, reconAbs) <- telescopicCapture (toScopeFrag decls') ans' return (Abs decls' newResult, LamRecon reconAbs) - block <- makeBlockFromDecls declsResult return $ SimplifiedBlock block recon RightE (ans `PairE` ty) -> do - block <- makeBlockFromDecls $ Abs decls ans let ty' = ignoreHoistFailure $ hoist (toScopeFrag decls) ty - return $ SimplifiedBlock block (CoerceRecon ty') + return $ SimplifiedBlock (Abs decls ans) (CoerceRecon ty') simplifyOp :: Emits o => NameHint -> PrimOp CoreIR i -> SimplifyM i o (CAtom o) simplifyOp hint op = case op of - UserEffectOp _ -> error "not implemented" Hof (TypedHof (EffTy _ ty) hof) -> do ty' <- substM ty simplifyHof hint ty' hof @@ -789,7 +788,7 @@ applyDictMethod resultTy d i methodArgs = do cheapNormalize d >>= \case DictCon _ (InstanceDict instanceName instanceArgs) -> dropSubst do instanceArgs' <- mapM simplifyAtom instanceArgs - InstanceDef _ bsInstance _ body <- lookupInstanceDef instanceName + InstanceDef _ _ bsInstance _ body <- lookupInstanceDef instanceName let InstanceBody _ methods = body let method = methods !! i extendSubst (bsInstance @@> (SubstVal <$> instanceArgs')) do @@ -879,7 +878,9 @@ simplifyHof _hint resultTy = \case liftSimpAtom resultTy result CatchException _ body-> do SimplifiedBlock body' recon <- buildSimplifiedBlock $ simplifyBlock body - block <- liftBuilder $ runSubstReaderT idSubst $ buildBlock $ exceptToMaybeBlock body' + simplifiedResultTy <- blockTy body' + block <- liftBuilder $ runSubstReaderT idSubst $ buildBlock $ + exceptToMaybeBlock (sink simplifiedResultTy) body' result <- emitBlock block case recon of CoerceRecon ty -> do @@ -920,12 +921,12 @@ preludeNothingVal ty = do preludeMaybeNewtypeCon :: EnvReader m => CType n -> m n (NewtypeCon n) preludeMaybeNewtypeCon ty = do ~(Just (UTyConVar tyConName)) <- lookupSourceMap "Maybe" - TyConDef sn _ _ <- lookupTyCon tyConName + TyConDef sn _ _ _ <- lookupTyCon tyConName let params = TyConParams [Explicit] [Type ty] return $ UserADTData sn tyConName params simplifyBlock :: Emits o => Block CoreIR i -> SimplifyM i o (CAtom o) -simplifyBlock (Block _ decls result) = simplifyDecls decls $ simplifyAtom result +simplifyBlock (Abs decls result) = simplifyDecls decls $ simplifyAtom result -- === simplifying custom linearizations === @@ -941,10 +942,10 @@ linearizeTopFun spec = do linearizeTopFunNoCache :: (Mut n, TopBuilder m) => LinearizationSpec n -> m n (TopFunName n, TopFunName n) linearizeTopFunNoCache spec@(LinearizationSpec f actives) = do - TopFunBinding ~(DexTopFun _ _ lam _) <- lookupEnv f + TopFunBinding ~(DexTopFun _ lam _) <- lookupEnv f PairE fPrimal fTangent <- liftSimplifyM $ tryGetCustomRule (sink f) >>= \case Just (absParams, rule) -> simplifyCustomLinearization (sink absParams) actives (sink rule) - Nothing -> liftM toPairE $ liftDoubleBuilderToSimplifyM $ linearizeLam (sink lam) actives + Nothing -> liftM toPairE $ liftDoubleBuilderToSimplifyM $ linearizeTopLam (sink lam) actives fTangentT <- transposeTopFun fTangent fPrimal' <- emitTopFunBinding "primal" (LinearizationPrimal spec) fPrimal fTangent' <- emitTopFunBinding "tangent" (LinearizationTangent spec) fTangent @@ -956,7 +957,7 @@ tryGetCustomRule :: EnvReader m => TopFunName n -> m n (Maybe (Abstracted CoreIR tryGetCustomRule f' = do ~(TopFunBinding f) <- lookupEnv f' case f of - DexTopFun def _ _ _ -> case def of + DexTopFun def _ _ -> case def of Specialization (AppSpecialization fCore absParams) -> fmap (absParams,) <$> lookupCustomRules (atomVarName fCore) _ -> return Nothing @@ -969,10 +970,10 @@ type Linearized = Abs (Nest SBinder) -- primal args simplifyCustomLinearization :: Abstracted CoreIR (ListE CAtom) n -> [Active] -> AtomRules n - -> SimplifyM i n (PairE SLam SLam n) + -> SimplifyM i n (PairE STopLam STopLam n) simplifyCustomLinearization (Abs runtimeBs staticArgs) actives rule = do CustomLinearize nImplicit nExplicit zeros fCustom <- return rule - defuncLinearized =<< withSimplifiedBinders runtimeBs \runtimeBs' runtimeArgs -> do + linearized <- withSimplifiedBinders runtimeBs \runtimeBs' runtimeArgs -> do Abs runtimeBs' <$> buildScoped do ListE staticArgs' <- applySubst (runtimeBs @@> (SubstVal . sink <$> runtimeArgs)) staticArgs fCustom' <- sinkM fCustom @@ -996,14 +997,18 @@ simplifyCustomLinearization (Abs runtimeBs staticArgs) actives rule = do -- a custom linearization defined for a function on ADTs will -- not work. fLin' <- sinkM fLin - Pi (CorePiType _ bs _) <- return $ getType fLin' + Pi (CorePiType _ _ bs _) <- return $ getType fLin' let tangentCoreTys = fromNonDepNest bs tangentArgs' <- zipWithM liftSimpAtom tangentCoreTys tangentArgs resultTyTangent <- typeOfApp (getType fLin') tangentArgs' tangentResult <- dropSubst $ simplifyApp noHint resultTyTangent fLin' tangentArgs' toDataAtomIgnoreRecon tangentResult return $ PairE primalResult' fLin' - where + PairE primalFun tangentFun <- defuncLinearized linearized + primalFun' <- asTopLam primalFun + tangentFun' <- asTopLam tangentFun + return $ PairE primalFun' tangentFun' + where buildTangentArgs :: Emits n => SymbolicZeros -> [(SType n, Active)] -> [SAtom n] -> SimplifyM i n [SAtom n] buildTangentArgs _ [] [] = return [] buildTangentArgs zeros ((t, False):tys) activeArgs = do @@ -1020,7 +1025,7 @@ simplifyCustomLinearization (Abs runtimeBs staticArgs) actives rule = do return $ activeArg':rest buildTangentArgs _ _ _ = error "zip error" - fromNonDepNest :: (HoistableB b, BindsOneAtomName CoreIR b) => Nest b n l -> [CType n] + fromNonDepNest :: Nest CBinder n l -> [CType n] fromNonDepNest Empty = [] fromNonDepNest (Nest b bs) = case ignoreHoistFailure $ hoist b (Abs bs UnitE) of @@ -1038,7 +1043,7 @@ defuncLinearized ab = liftBuilder $ refreshAbs ab \bs ab' -> do return $ Abs (Nest rB tBs') UnitE residualsTangentsBs' <- return $ ignoreHoistFailure $ hoist decls residualsTangentsBs return (Abs decls (PairVal primalResult residuals), reconAbs, residualsTangentsBs') - primalFun <- LamExpr bs <$> makeBlockFromDecls declsAndResult + let primalFun = LamExpr bs declsAndResult LamExpr residualAndTangentBs tangentBody <- buildLamExpr residualsTangentsBs \(residuals:tangents) -> do LamExpr tangentBs' body <- applyReconAbs (sink reconAbs) (Var residuals) applyRename (tangentBs' @@> (atomVarName <$> tangents)) body >>= emitBlock @@ -1049,25 +1054,20 @@ defuncLinearized ab = liftBuilder $ refreshAbs ab \bs ab' -> do type HandlerM = SubstReaderT AtomSubstVal (BuilderM SimpIR) -exceptToMaybeBlock :: Emits o => SBlock i -> HandlerM i o (SAtom o) -exceptToMaybeBlock (Block (BlockAnn (EffTy _ ty)) decls result) = do - ty' <- substM ty - exceptToMaybeDecls ty' decls $ Atom result -exceptToMaybeBlock (Block NoBlockAnn Empty result) = exceptToMaybeExpr $ Atom result -exceptToMaybeBlock _ = error "impossible" - -exceptToMaybeDecls :: Emits o => SType o -> Nest SDecl i i' -> SExpr i' -> HandlerM i o (SAtom o) -exceptToMaybeDecls _ Empty result = exceptToMaybeExpr result -exceptToMaybeDecls resultTy (Nest (Let b (DeclBinding _ rhs)) decls) finalResult = do +exceptToMaybeBlock :: Emits o => SType o -> SBlock i -> HandlerM i o (SAtom o) +exceptToMaybeBlock ty (Abs Empty result) = do + result' <- substM result + return $ JustAtom ty result' +exceptToMaybeBlock resultTy (Abs (Nest (Let b (DeclBinding _ rhs)) decls) finalResult) = do maybeResult <- exceptToMaybeExpr rhs case maybeResult of -- This case is just an optimization (but an important one!) JustAtom _ x -> - extendSubst (b@> SubstVal x) $ exceptToMaybeDecls resultTy decls finalResult + extendSubst (b@> SubstVal x) $ exceptToMaybeBlock resultTy (Abs decls finalResult) _ -> emitMaybeCase maybeResult (MaybeTy resultTy) (return $ NothingAtom $ sink resultTy) (\v -> extendSubst (b@> SubstVal v) $ - exceptToMaybeDecls (sink resultTy) decls finalResult) + exceptToMaybeBlock (sink resultTy) (Abs decls finalResult)) exceptToMaybeExpr :: Emits o => SExpr i -> HandlerM i o (SAtom o) exceptToMaybeExpr expr = case expr of @@ -1076,15 +1076,19 @@ exceptToMaybeExpr expr = case expr of resultTy' <- substM $ MaybeTy resultTy buildCase e' resultTy' \i v -> do Abs b body <- return $ alts !! i - extendSubst (b @> SubstVal v) $ exceptToMaybeBlock body + extendSubst (b @> SubstVal v) do + blockResultTy <- blockTy =<< substM body -- TODO: avoid this by caching the type + exceptToMaybeBlock blockResultTy body Atom x -> do x' <- substM x let ty = getType x' return $ JustAtom ty x' PrimOp (Hof (TypedHof _ (For ann ixTy' (UnaryLamExpr b body)))) -> do ixTy <- substM ixTy' - maybes <- buildForAnn (getNameHint b) ann ixTy \i -> - extendSubst (b@>Rename (atomVarName i)) $ exceptToMaybeBlock body + maybes <- buildForAnn (getNameHint b) ann ixTy \i -> do + extendSubst (b@>Rename (atomVarName i)) do + blockResultTy <- blockTy =<< substM body -- TODO: avoid this by caching the type + exceptToMaybeBlock blockResultTy body catMaybesE maybes PrimOp (MiscOp (ThrowException _)) -> do ty <- substM $ getType expr @@ -1094,7 +1098,8 @@ exceptToMaybeExpr expr = case expr of BinaryLamExpr h ref body <- return lam result <- emitRunState noHint s' \h' ref' -> extendSubst (h @> Rename (atomVarName h') <.> ref @> Rename (atomVarName ref')) do - exceptToMaybeBlock body + blockResultTy <- blockTy =<< substM body -- TODO: avoid this by caching the type + exceptToMaybeBlock blockResultTy body (maybeAns, newState) <- fromPair result a <- substM $ getType expr emitMaybeCase maybeAns (MaybeTy a) @@ -1104,14 +1109,17 @@ exceptToMaybeExpr expr = case expr of monoid' <- substM monoid PairTy _ accumTy <- substM resultTy result <- emitRunWriter noHint accumTy monoid' \h' ref' -> - extendSubst (h @> Rename (atomVarName h') <.> ref @> Rename (atomVarName ref')) $ - exceptToMaybeBlock body + extendSubst (h @> Rename (atomVarName h') <.> ref @> Rename (atomVarName ref')) do + blockResultTy <- blockTy =<< substM body -- TODO: avoid this by caching the type + exceptToMaybeBlock blockResultTy body (maybeAns, accumResult) <- fromPair result a <- substM $ getType expr emitMaybeCase maybeAns (MaybeTy a) (return $ NothingAtom $ sink a) (\ans -> return $ JustAtom (sink a) $ PairVal ans (sink accumResult)) - PrimOp (Hof (TypedHof _ (While body))) -> runMaybeWhile $ exceptToMaybeBlock body + PrimOp (Hof (TypedHof _ (While body))) -> do + blockResultTy <- blockTy =<< substM body -- TODO: avoid this by caching the type + runMaybeWhile $ exceptToMaybeBlock (sink blockResultTy) body _ -> do expr' <- substM expr case hasExceptions expr' of diff --git a/src/lib/SourceRename.hs b/src/lib/SourceRename.hs index 9ea111ae3..3ee3b13b1 100644 --- a/src/lib/SourceRename.hs +++ b/src/lib/SourceRename.hs @@ -180,9 +180,9 @@ instance SourceRenamableE UExpr' where UVar v -> UVar <$> sourceRenameE v ULit l -> return $ ULit l ULam lam -> ULam <$> sourceRenameE lam - UPi (UPiExpr pats appExpl eff body) -> + UPi (UPiExpr (attrs, pats) appExpl eff body) -> sourceRenameB pats \pats' -> - UPi <$> (UPiExpr pats' <$> pure appExpl <*> sourceRenameE eff <*> sourceRenameE body) + UPi <$> (UPiExpr (attrs, pats') <$> pure appExpl <*> sourceRenameE eff <*> sourceRenameE body) UApp f xs ys -> UApp <$> sourceRenameE f <*> forM xs sourceRenameE <*> forM ys (\(name, y) -> (name,) <$> sourceRenameE y) @@ -224,7 +224,6 @@ instance SourceRenamableE UEffect where sourceRenameE (URWSEffect rws name) = URWSEffect rws <$> sourceRenameE name sourceRenameE UExceptionEffect = return UExceptionEffect sourceRenameE UIOEffect = return UIOEffect - sourceRenameE (UUserEffect name) = UUserEffect <$> sourceRenameE name instance SourceRenamableE a => SourceRenamableE (WithSrcE a) where sourceRenameE (WithSrcE pos e) = addSrcContext pos $ @@ -246,20 +245,20 @@ instance SourceRenamableB UTopDecl where sourceRenameUBinder UPunVar tyConName \tyConName' -> do structDef' <- sourceRenameE structDef cont $ UStructDecl tyConName' structDef' - UInterface paramBs methodTys className methodNames -> do + UInterface (attrs, paramBs) methodTys className methodNames -> do Abs paramBs' (ListE methodTys') <- sourceRenameB paramBs \paramBs' -> do methodTys' <- mapM sourceRenameE methodTys return $ Abs paramBs' $ ListE methodTys' sourceRenameUBinder UClassVar className \className' -> sourceRenameUBinderNest UMethodVar methodNames \methodNames' -> - cont $ UInterface paramBs' methodTys' className' methodNames' - UInstance className conditions params methodDefs instanceName expl -> do + cont $ UInterface (attrs, paramBs') methodTys' className' methodNames' + UInstance className (roleExpls, conditions) params methodDefs instanceName expl -> do className' <- sourceRenameE className Abs conditions' (PairE (ListE params') (ListE methodDefs')) <- sourceRenameE $ Abs conditions (PairE (ListE params) $ ListE methodDefs) sourceRenameB instanceName \instanceName' -> - cont $ UInstance className' conditions' params' methodDefs' instanceName' expl + cont $ UInstance className' (roleExpls, conditions') params' methodDefs' instanceName' expl UEffectDecl opTypes effName opNames -> do opTypes' <- mapM (\(UEffectOpType p ty) -> (UEffectOpType p) <$> sourceRenameE ty) opTypes sourceRenameUBinder UEffectVar effName \effName' -> @@ -278,8 +277,8 @@ instance SourceRenamableB UDecl' where UPass -> cont UPass instance SourceRenamableE ULamExpr where - sourceRenameE (ULamExpr args expl effs resultTy body) = - sourceRenameB args \args' -> ULamExpr args' + sourceRenameE (ULamExpr (expls, args) expl effs resultTy body) = + sourceRenameB args \args' -> ULamExpr (expls, args') <$> pure expl <*> mapM sourceRenameE effs <*> mapM sourceRenameE resultTy @@ -305,9 +304,6 @@ instance (SourceRenamableB b1, SourceRenamableB b2) => SourceRenamableB (PairB b sourceRenameB b2 \b2' -> cont $ PairB b1' b2' -instance SourceRenamableB b => SourceRenamableB (WithExpl b) where - sourceRenameB (WithExpl x b) cont = sourceRenameB b \b' -> cont $ WithExpl x b' - sourceRenameUBinderNest :: (Color c, Renamer m, Distinct o) => (forall l. Name c l -> UVar l) @@ -340,15 +336,15 @@ sourceRenameUBinder asUVar ubinder cont = case ubinder of UIgnore -> cont UIgnore instance SourceRenamableE UDataDef where - sourceRenameE (UDataDef tyConName paramBs dataCons) = do + sourceRenameE (UDataDef tyConName (expls, paramBs) dataCons) = do sourceRenameB paramBs \paramBs' -> do dataCons' <- forM dataCons \(dataConName, argBs) -> do argBs' <- sourceRenameE argBs return (dataConName, argBs') - return $ UDataDef tyConName paramBs' dataCons' + return $ UDataDef tyConName (expls, paramBs') dataCons' instance SourceRenamableE UStructDef where - sourceRenameE (UStructDef tyConName paramBs fields methods) = do + sourceRenameE (UStructDef tyConName (expls, paramBs) fields methods) = do sourceRenameB paramBs \paramBs' -> do fields' <- forM fields \(fieldName, ty) -> do ty' <- sourceRenameE ty @@ -356,7 +352,7 @@ instance SourceRenamableE UStructDef where methods' <- forM methods \(ann, methodName, lam) -> do lam' <- sourceRenameE lam return (ann, methodName, lam') - return $ UStructDef tyConName paramBs' fields' methods' + return $ UStructDef tyConName (expls, paramBs') fields' methods' instance SourceRenamableE UDataDefTrail where sourceRenameE (UDataDefTrail args) = sourceRenameB args \args' -> diff --git a/src/lib/Subst.hs b/src/lib/Subst.hs index 5e29db46a..5b13ef624 100644 --- a/src/lib/Subst.hs +++ b/src/lib/Subst.hs @@ -18,7 +18,6 @@ import Control.Monad.State.Strict import Name import IRVariants import Types.Core -import Types.Primitives import Core import qualified RawName as R import Err @@ -444,20 +443,16 @@ instance (BindsNames b, SubstB v b, SinkableV v) instance FromName v => SubstE v UnitE where substE _ UnitE = UnitE +instance SubstB v b => SubstB v (WithAttrB a b) where + substB env (WithAttrB x b) cont = + substB env b \env' b' -> cont env' $ WithAttrB x b' + instance (Traversable f, SubstE v e) => SubstE v (ComposeE f e) where substE env (ComposeE xs) = ComposeE $ fmap (substE env) xs instance (SubstE v e1, SubstE v e2) => SubstE v (PairE e1 e2) where substE env (PairE x y) = PairE (substE env x) (substE env y) -instance SubstB v b => SubstB v (WithExpl b) where - substB env (WithExpl x b) cont = - substB env b \env' b' -> cont env' $ WithExpl x b' - -instance (FromName v, SubstB v CBinder) => SubstB v RolePiBinder where - substB env (RolePiBinder role b) cont = - substB env b \env' b' -> cont env' $ RolePiBinder role b' - instance (SubstE v e1, SubstE v e2) => SubstE v (EitherE e1 e2) where substE env (LeftE x) = LeftE $ substE env x substE env (RightE x) = RightE $ substE env x diff --git a/src/lib/TopLevel.hs b/src/lib/TopLevel.hs index a07a9da47..fdcafcc51 100644 --- a/src/lib/TopLevel.hs +++ b/src/lib/TopLevel.hs @@ -43,7 +43,7 @@ import qualified LLVM.AST import AbstractSyntax import Builder -import CheckType ( CheckableE (..), asFFIFunType, checkHasType, checkExtends, checkDestLam) +import CheckType ( CheckableE (..), asFFIFunType, checkHasType) #ifdef DEX_DEBUG import CheckType (checkTypesM) #endif @@ -536,37 +536,32 @@ whenOpt x act = getConfig <&> optLevel >>= \case NoOptimize -> return x Optimize -> act x -evalBlock :: (Topper m, Mut n) => CBlock n -> m n (CAtom n) +evalBlock :: (Topper m, Mut n) => TopBlock CoreIR n -> m n (CAtom n) evalBlock typed = do -- Be careful when adding new compilation passes here. If you do, be sure to -- also check compileTopLevelFun, below, and Export.prepareFunctionForExport. -- In most cases it should be easiest to add new passes to simpOptimizations or -- loweredOptimizations, below, because those are reused in all three places. - checkEffects Pure typed synthed <- checkPass SynthPass $ synthTopE typed - simplifiedBlock <- checkPass SimpPass $ simplifyTopBlock synthed - SimplifiedBlock simp recon <- return simplifiedBlock - checkEffects Pure simp - NullaryLamExpr opt <- simpOptimizations $ NullaryLamExpr simp - checkEffects Pure opt + SimplifiedTopLam simp recon <- checkPass SimpPass $ simplifyTopBlock synthed + opt <- simpOptimizations simp simpResult <- case opt of - AtomicBlock result -> return result + TopLam _ _ (LamExpr Empty (WithoutDecls result)) -> return result _ -> do - lowered <- checkPass LowerPass $ lowerFullySequential $ NullaryLamExpr opt - checkDestLam lowered - lOpt <- loweredOptimizations lowered - checkDestLam lOpt + lowered <- checkPass LowerPass $ lowerFullySequential True opt + lOpt <- checkPass OptPass $ loweredOptimizations lowered cc <- getEntryFunCC impOpt <- checkPass ImpPass $ toImpFunction cc lOpt llvmOpt <- packageLLVMCallable impOpt resultVals <- liftIO $ callEntryFun llvmOpt [] - PiType bs (EffTy _ resultTy') <- return $ getDestLamExprType lOpt + TopLam _ destTy _ <- return lOpt + PiType bs (EffTy _ resultTy') <- return $ piTypeWithoutDest destTy let resultTy = ignoreHoistFailure $ hoist bs resultTy' repValAtom =<< repValFromFlatList resultTy resultVals applyReconTop recon simpResult {-# SCC evalBlock #-} -simpOptimizations :: Topper m => SLam n -> m n (SLam n) +simpOptimizations :: Topper m => STopLam n -> m n (STopLam n) simpOptimizations simp = do analyzed <- whenOpt simp $ checkPass OccAnalysisPass . analyzeOccurrences inlined <- whenOpt analyzed $ checkPass InlinePass . inlineBindings @@ -574,7 +569,7 @@ simpOptimizations simp = do inlined2 <- whenOpt analyzed2 $ checkPass InlinePass . inlineBindings whenOpt inlined2 $ checkPass OptPass . optimize -loweredOptimizations :: Topper m => DestLamExpr n -> m n (DestLamExpr n) +loweredOptimizations :: Topper m => STopLam n -> m n (STopLam n) loweredOptimizations lowered = do lopt <- whenOpt lowered $ checkPass LowerOptPass . (dceTop >=> hoistLoopInvariant) @@ -584,7 +579,7 @@ loweredOptimizations lowered = do logFiltered l VectPass $ return [TextOut $ pprint errs] checkPass VectPass $ return vo -loweredOptimizationsNoDest :: Topper m => SLam n -> m n (SLam n) +loweredOptimizationsNoDest :: Topper m => STopLam n -> m n (STopLam n) loweredOptimizationsNoDest lowered = do lopt <- whenOpt lowered $ checkPass LowerOptPass . (dceTop >=> hoistLoopInvariant) @@ -594,7 +589,7 @@ loweredOptimizationsNoDest lowered = do evalSpecializations :: (Topper m, Mut n) => [TopFunName n] -> m n () evalSpecializations fs = do fSimps <- toposortAnnVars <$> catMaybes <$> forM fs \f -> lookupTopFun f >>= \case - DexTopFun _ _ simp Waiting -> return $ Just (f, simp) + DexTopFun _ simp Waiting -> return $ Just (f, simp) _ -> return Nothing forM_ fSimps \(f, simp) -> do -- Prevents infinite loop in case compiling `v` ends up requiring `v` @@ -608,14 +603,14 @@ evalSpecializations fs = do evalDictSpecializations :: (Topper m, Mut n) => [SpecDictName n] -> m n () evalDictSpecializations ds = do - -- TODO Do we have to do these in order, like evalSpecializations, or are they - -- independent enough not to need it? - -- TODO Do we need to gate the status of these, too? + -- -- TODO Do we have to do these in order, like evalSpecializations, or are they + -- -- independent enough not to need it? + -- -- TODO Do we need to gate the status of these, too? forM_ ds \dName -> do SpecializedDict _ (Just fs) <- lookupSpecDict dName fs' <- forM fs \lam -> do opt <- simpOptimizations lam - lowered <- checkPass LowerPass $ lowerFullySequentialNoDest opt + lowered <- checkPass LowerPass $ lowerFullySequential False opt loweredOptimizationsNoDest lowered updateTopEnv $ LowerDictSpecialization dName fs' return () @@ -647,10 +642,10 @@ execUDecl mname decl = do {-# SCC execUDecl #-} compileTopLevelFun :: (Topper m, Mut n) - => CallingConvention -> SLam n -> m n (ImpFunction n) + => CallingConvention -> STopLam n -> m n (ImpFunction n) compileTopLevelFun cc fSimp = do fOpt <- simpOptimizations fSimp - fLower <- checkPass LowerPass $ lowerFullySequential fOpt + fLower <- checkPass LowerPass $ lowerFullySequential True fOpt flOpt <- loweredOptimizations fLower checkPass ImpPass $ toImpFunction cc flOpt {-# SCC compileTopLevelFun #-} @@ -659,7 +654,8 @@ printCodegen :: (Topper m, Mut n) => CAtom n -> m n String printCodegen x = do block <- liftBuilder $ buildBlock do emitExpr $ PrimOp $ MiscOp $ ShowAny $ sink x - getDexString =<< evalBlock block + topBlock <- asTopBlock block + getDexString =<< evalBlock topBlock loadObject :: (Topper m, Mut n) => FunObjCodeName n -> m n NativeFunction loadObject fname = @@ -733,7 +729,7 @@ funNameToObj :: (EnvReader m, Fallible1 m) => ImpFunName n -> m n (FunObjCodeName n) funNameToObj v = do lookupEnv v >>= \case - TopFunBinding (DexTopFun _ _ _ (Finished impl)) -> return $ topFunObjCode impl + TopFunBinding (DexTopFun _ _ (Finished impl)) -> return $ topFunObjCode impl b -> error $ "couldn't find object cache entry for " ++ pprint v ++ "\ngot:\n" ++ pprint b withCompileTime :: MonadIO m => m Result -> m Result @@ -756,11 +752,6 @@ checkPass name cont = do #endif return result -checkEffects :: (Topper m, HasEffects e r, IRRep r) => EffectRow r n -> e n -> m n () -checkEffects allowedEffs e = do - let actualEffs = getEffects e - checkExtends allowedEffs actualEffs - addResultCtx :: SourceBlock -> Result -> Result addResultCtx block (Result outs errs) = Result outs (addSrcTextContext (sbOffset block) (sbText block) errs) @@ -946,8 +937,8 @@ instance Generic TopStateEx where getLinearizationType :: SymbolicZeros -> CType n -> EnvReaderT Except n (Int, Int, CType n) getLinearizationType zeros = \case - Pi (CorePiType ExplicitApp bs (EffTy Pure resultTy)) -> do - (numIs, numEs) <- getNumImplicits $ fst $ unzipExpls bs + Pi (CorePiType ExplicitApp expls bs (EffTy Pure resultTy)) -> do + (numIs, numEs) <- getNumImplicits expls refreshAbs (Abs bs resultTy) \bs' resultTy' -> do PairB _ bsE <- return $ splitNestAt numIs bs' let explicitArgTys = nestToList (\b -> sink $ binderType b) bsE @@ -960,7 +951,7 @@ getLinearizationType zeros = \case Just rtt -> return rtt Nothing -> throw TypeErr $ "No tangent type for: " ++ pprint resultTy' let tanFunTy = Pi $ nonDepPiType argTanTys Pure resultTanTy - let fullTy = CorePiType ExplicitApp bs' $ EffTy Pure (PairTy resultTy' tanFunTy) + let fullTy = CorePiType ExplicitApp expls bs' $ EffTy Pure (PairTy resultTy' tanFunTy) return (numIs, numEs, Pi fullTy) _ -> throw TypeErr $ "Can't define a custom linearization for implicit or impure functions" where diff --git a/src/lib/Transpose.hs b/src/lib/Transpose.hs index 29032ff49..904e608d1 100644 --- a/src/lib/Transpose.hs +++ b/src/lib/Transpose.hs @@ -42,8 +42,8 @@ runTransposeM cont = runReaderT1 (ListE []) $ runSubstReaderT idSubst $ cont transposeTopFun :: (MonadFail1 m, EnvReader m) - => LamExpr SimpIR n -> m n (LamExpr SimpIR n) -transposeTopFun lam = liftBuilder $ runTransposeM do + => STopLam n -> m n (STopLam n) +transposeTopFun (TopLam False _ lam) = liftBuilder $ runTransposeM do (Abs bsNonlin (Abs bLin body), Abs bsNonlin'' outTy) <- unpackLinearLamExpr lam refreshBinders bsNonlin \bsNonlin' substFrag -> extendRenamer substFrag do outTy' <- applyRename (bsNonlin''@@> nestToNames bsNonlin') outTy @@ -54,7 +54,11 @@ transposeTopFun lam = liftBuilder $ runTransposeM do withAccumulator inTy \refSubstVal -> extendSubst (bLin @> refSubstVal) $ transposeBlock body (sink ct) - return $ LamExpr (bsNonlin' >>> UnaryNest bCT) body' + EffTy _ bodyTy <- blockEffTy body' + let piTy = PiType (bsNonlin' >>> UnaryNest bCT) (EffTy Pure bodyTy) + let lamT = LamExpr (bsNonlin' >>> UnaryNest bCT) body' + return $ TopLam False piTy lamT +transposeTopFun (TopLam True _ _) = error "shouldn't be transposing in destination passing style" unpackLinearLamExpr :: (MonadFail1 m, EnvReader m) => LamExpr SimpIR n @@ -63,7 +67,7 @@ unpackLinearLamExpr unpackLinearLamExpr lam@(LamExpr bs body) = do let numNonlin = nestLength bs - 1 PairB bsNonlin (UnaryNest bLin) <- return $ splitNestAt numNonlin bs - PiType bsTy (EffTy _ resultTy) <- return $ getLamExprType lam + PiType bsTy (EffTy _ resultTy) <- getLamExprType lam PairB bsNonlinTy (UnaryNest bLinTy) <- return $ splitNestAt numNonlin bsTy let resultTy' = ignoreHoistFailure $ hoist bLinTy resultTy return ( Abs bsNonlin $ Abs bLin body @@ -154,7 +158,7 @@ extendLinRegions v cont = local (\(ListE vs) -> ListE (v:vs)) cont -- === actual pass === transposeBlock :: Emits o => SBlock i -> SAtom o -> TransposeM i o () -transposeBlock (Block _ decls result) ct = transposeWithDecls decls result ct +transposeBlock (Abs decls result) ct = transposeWithDecls decls result ct transposeWithDecls :: Emits o => Nest SDecl i i' -> SAtom i' -> SAtom o -> TransposeM i o () transposeWithDecls Empty atom ct = transposeAtom atom ct diff --git a/src/lib/Types/Core.hs b/src/lib/Types/Core.hs index 8953b1944..d09f7c69f 100644 --- a/src/lib/Types/Core.hs +++ b/src/lib/Types/Core.hs @@ -131,8 +131,6 @@ type ClassName = Name ClassNameC type TyConName = Name TyConNameC type DataConName = Name DataConNameC type EffectName = Name EffectNameC -type EffectOpName = Name EffectOpNameC -type HandlerName = Name HandlerNameC type InstanceName = Name InstanceNameC type MethodName = Name MethodNameC type ModuleName = Name ModuleNameC @@ -153,7 +151,8 @@ data TyConDef n where -- binder name is in UExpr and Env TyConDef :: SourceName - -> RolePiBinders n l + -> [RoleExpl] + -> Nest CBinder n l -> DataConDefs l -> TyConDef n @@ -175,19 +174,13 @@ data ParamRole = TypeParam | DictParam | DataParam deriving (Show, Generic, Eq) data TyConParams n = TyConParams [Explicitness] [Atom CoreIR n] deriving (Show, Generic) --- The Type is the type of the result expression (and thus the type of the --- block). It's given by querying the result expression's type, and checking --- that it doesn't have any free names bound by the decls in the block. We store --- it separately as an optimization, to avoid having to traverse the block. --- If the decls are empty we can skip the type annotation, because then we can --- cheaply query the result, and, more importantly, there's no risk of having a --- type that mentions local variables. -data Block (r::IR) (n::S) where - Block :: BlockAnn r n l -> Nest (Decl r) n l -> Atom r l -> Block r n +type WithDecls (r::IR) = Abs (Decls r) :: E -> E +type Block (r::IR) = WithDecls r (Atom r) :: E -data BlockAnn r n l where - BlockAnn :: EffTy r n -> BlockAnn r n l - NoBlockAnn :: BlockAnn r n n +type TopBlock = TopLam -- used for nullary lambda +type IsDestLam = Bool +data TopLam (r::IR) (n::S) = TopLam IsDestLam (PiType r n) (LamExpr r n) + deriving (Show, Generic) data LamExpr (r::IR) (n::S) where LamExpr :: Nest (Binder r) n l -> Block r l -> LamExpr r n @@ -221,10 +214,8 @@ data TabPiType (r::IR) (n::S) where data PiType (r::IR) (n::S) where PiType :: Nest (Binder r) n l -> EffTy r l -> PiType r n -type CoreBinders = Nest (WithExpl CBinder) - data CorePiType (n::S) where - CorePiType :: AppExplicitness -> CoreBinders n l -> EffTy CoreIR l -> CorePiType n + CorePiType :: AppExplicitness -> [Explicitness] -> Nest CBinder n l -> EffTy CoreIR l -> CorePiType n data DepPairType (r::IR) (n::S) where DepPairType :: DepPairExplicitness -> Binder r n l -> Type r l -> DepPairType r n @@ -318,8 +309,7 @@ data PrimOp (r::IR) (n::S) where MiscOp :: MiscOp r n -> PrimOp r n Hof :: TypedHof r n -> PrimOp r n RefOp :: Atom r n -> RefOp r n -> PrimOp r n - DAMOp :: DAMOp SimpIR n -> PrimOp SimpIR n - UserEffectOp :: UserEffectOp n -> PrimOp CoreIR n + DAMOp :: DAMOp SimpIR n -> PrimOp SimpIR n deriving instance IRRep r => Show (PrimOp r n) deriving via WrapE (PrimOp r) n instance IRRep r => Generic (PrimOp r n) @@ -399,12 +389,6 @@ data RefOp r n = | ProjRef (Type r n) Projection deriving (Show, Generic) -data UserEffectOp n = - Handle (HandlerName n) [CAtom n] (CBlock n) - | Resume (CType n) (CAtom n) -- Resume from effect handler (type, arg) - | Perform (CAtom n) Int -- Call an effect operation (effect name) (op #) - deriving (Show, Generic) - -- === IR variants === type CAtom = Atom CoreIR @@ -416,6 +400,7 @@ type CDecl = Decl CoreIR type CDecls = Decls CoreIR type CAtomName = AtomName CoreIR type CAtomVar = AtomVar CoreIR +type CTopLam = TopLam CoreIR type SAtom = Atom SimpIR type SType = Type SimpIR @@ -429,6 +414,7 @@ type SAtomVar = AtomVar SimpIR type SBinder = Binder SimpIR type SRepVal = RepVal SimpIR type SLam = LamExpr SimpIR +type STopLam = TopLam SimpIR -- === newtypes === @@ -456,16 +442,15 @@ isSumCon = \case -- === type classes === -data RolePiBinder (n::S) (l::S) = RolePiBinder ParamRole (WithExpl CBinder n l) - deriving (Show, Generic) -type RolePiBinders = Nest RolePiBinder +type RoleExpl = (ParamRole, Explicitness) data ClassDef (n::S) where ClassDef :: SourceName -- name of class -> [SourceName] -- method source names -> [Maybe SourceName] -- parameter source names - -> RolePiBinders n1 n2 -- parameters + -> [RoleExpl] -- parameter info + -> Nest CBinder n1 n2 -- parameters -> Nest CBinder n2 n3 -- superclasses -> [CorePiType n3] -- method types -> ClassDef n1 @@ -473,7 +458,8 @@ data ClassDef (n::S) where data InstanceDef (n::S) where InstanceDef :: ClassName n1 - -> RolePiBinders n1 n2 -- parameters (types and dictionaries) + -> [RoleExpl] -- parameter info + -> Nest CBinder n1 n2 -- parameters (types and dictionaries) -> [CAtom n2] -- class parameters -> InstanceBody n2 -> InstanceDef n1 @@ -591,8 +577,8 @@ data TopEnvUpdate n = | AddCustomRule (CAtomName n) (AtomRules n) | UpdateLoadedModules ModuleSourceName (ModuleName n) | UpdateLoadedObjects (FunObjCodeName n) NativeFunction - | FinishDictSpecialization (SpecDictName n) [LamExpr SimpIR n] - | LowerDictSpecialization (SpecDictName n) [LamExpr SimpIR n] + | FinishDictSpecialization (SpecDictName n) [TopLam SimpIR n] + | LowerDictSpecialization (SpecDictName n) [TopLam SimpIR n] | UpdateTopFunEvalStatus (TopFunName n) (TopFunEvalStatus n) | UpdateInstanceDef (InstanceName n) (InstanceDef n) | UpdateTyConDef (TyConName n) (TyConDef n) @@ -660,9 +646,6 @@ data Binding (c::C) (n::S) where ClassBinding :: ClassDef n -> Binding ClassNameC n InstanceBinding :: InstanceDef n -> CorePiType n -> Binding InstanceNameC n MethodBinding :: ClassName n -> Int -> Binding MethodNameC n - EffectBinding :: EffectDef n -> Binding EffectNameC n - HandlerBinding :: HandlerDef n -> Binding HandlerNameC n - EffectOpBinding :: EffectOpDef n -> Binding EffectOpNameC n TopFunBinding :: TopFun n -> Binding TopFunNameC n FunObjCodeBinding :: CFunction n -> Binding FunObjCodeNameC n ModuleBinding :: Module n -> Binding ModuleNameC n @@ -711,33 +694,6 @@ instance RenameE EffectDef deriving instance Show (EffectDef n) deriving via WrapE EffectDef n instance Generic (EffectDef n) -data HandlerDef (n::S) where - HandlerDef :: EffectName n - -> CBinder n r -- body type arg - -> RolePiBinders r l - -> EffectRow CoreIR l - -> CType l -- return type - -> [Block CoreIR l] -- effect operations - -> Block CoreIR l -- return body - -> HandlerDef n - -instance GenericE HandlerDef where - type RepE HandlerDef = - EffectName `PairE` Abs (CBinder `PairB` RolePiBinders) - (EffectRow CoreIR `PairE` CType `PairE` ListE (Block CoreIR) `PairE` Block CoreIR) - fromE (HandlerDef name bodyTyArg bs effs ty ops ret) = - name `PairE` Abs (bodyTyArg `PairB` bs) (effs `PairE` ty `PairE` ListE ops `PairE` ret) - toE (name `PairE` Abs (bodyTyArg `PairB` bs) (effs `PairE` ty `PairE` ListE ops `PairE` ret)) = - HandlerDef name bodyTyArg bs effs ty ops ret - -instance SinkableE HandlerDef -instance HoistableE HandlerDef -instance AlphaEqE HandlerDef -instance AlphaHashableE HandlerDef -instance RenameE HandlerDef -deriving instance Show (HandlerDef n) -deriving via WrapE HandlerDef n instance Generic (HandlerDef n) - data EffectOpType (n::S) where EffectOpType :: UResumePolicy -> CType n -> EffectOpType n @@ -756,7 +712,7 @@ deriving instance Show (EffectOpType n) deriving via WrapE EffectOpType n instance Generic (EffectOpType n) instance GenericE SpecializedDictDef where - type RepE SpecializedDictDef = AbsDict `PairE` MaybeE (ListE (LamExpr SimpIR)) + type RepE SpecializedDictDef = AbsDict `PairE` MaybeE (ListE (TopLam SimpIR)) fromE (SpecializedDict ab methods) = ab `PairE` methods' where methods' = case methods of Just xs -> LeftE (ListE xs) Nothing -> RightE UnitE @@ -777,7 +733,7 @@ data EvalStatus a = Waiting | Running | Finished a type TopFunEvalStatus n = EvalStatus (TopFunLowerings n) data TopFun (n::S) = - DexTopFun (TopFunDef n) (PiType SimpIR n) (LamExpr SimpIR n) (TopFunEvalStatus n) + DexTopFun (TopFunDef n) (TopLam SimpIR n) (TopFunEvalStatus n) | FFITopFun String IFunType deriving (Show, Generic) @@ -866,7 +822,6 @@ data Effect (r::IR) (n::S) = RWSEffect RWS (Atom r n) | ExceptionEffect | IOEffect - | UserEffect (Name EffectNameC n) | InitEffect -- Internal effect modeling writing to a destination. deriving (Generic, Show) @@ -929,7 +884,7 @@ data SpecializedDictDef n = -- Methods (thunked if nullary), if they're available. -- We create specialized dict names during simplification, but we don't -- actually simplify/lower them until we return to TopLevel - (Maybe [LamExpr SimpIR n]) + (Maybe [TopLam SimpIR n]) deriving (Show, Generic) -- TODO: extend with AD-oriented specializations, backend-specific specializations etc. @@ -942,17 +897,13 @@ data LinearizationSpec (n::S) = LinearizationSpec (TopFunName n) [Active] deriving (Show, Generic) --- === BindsOneAtomName === +-- === Binder utils === -class BindsOneName b (AtomNameC r) => BindsOneAtomName (r::IR) (b::B) | b -> r where - binderType :: b n l -> Type r n - binderVar :: DExt n l => b n l -> AtomVar r l +binderType :: Binder r n l -> Type r n +binderType (_:>ty) = ty -bindersTypes :: (IRRep r, Distinct l, ProvesExt b, BindsNames b, BindsOneAtomName r b) - => Nest b n l -> [Type r l] -bindersTypes Empty = [] -bindersTypes n@(Nest b bs) = ty : bindersTypes bs - where ty = withExtEvidence n $ sink (binderType b) +binderVar :: (IRRep r, DExt n l) => Binder r n l -> AtomVar r l +binderVar (b:>ty) = AtomVar (binderName b) (sink ty) nestToAtomVars :: (Distinct l, Ext n l, IRRep r) => Nest (Binder r) n l -> [AtomVar r l] @@ -961,18 +912,6 @@ nestToAtomVars = \case Nest b bs -> withExtEvidence b $ withSubscopeDistinct bs $ sink (binderVar b) : nestToAtomVars bs -instance IRRep r => BindsOneAtomName r (BinderP (AtomNameC r) (Type r)) where - binderType (_ :> ty) = ty - binderVar (b:>t) = AtomVar (binderName b) (sink t) - -instance BindsOneAtomName CoreIR b => BindsOneAtomName CoreIR (WithExpl b) where - binderType (WithExpl _ b) = binderType b - binderVar (WithExpl _ b) = binderVar b - -toBinderNest :: BindsOneAtomName r b => Nest b n l -> Nest (Binder r) n l -toBinderNest Empty = Empty -toBinderNest (Nest b bs) = Nest (asNameBinder b :> binderType b) (toBinderNest bs) - -- === ToBinding === atomBindingToBinding :: AtomBinding r n -> Binding (AtomNameC r) n @@ -1006,14 +945,6 @@ instance (ToBinding e1 c, ToBinding e2 c) => ToBinding (EitherE e1 e2) c where toBinding (LeftE e) = toBinding e toBinding (RightE e) = toBinding e --- === HasArgType === - -class HasArgType (e::E) (r::IR) | e -> r where - argType :: e n -> Type r n - -instance HasArgType (TabPiType r) r where - argType (TabPiType _ (_:>ty) _) = ty - -- === Pattern synonyms === -- XXX: only use this pattern when you're actually expecting a type. If it's @@ -1125,12 +1056,11 @@ pattern UnaryLamExpr b body = LamExpr (UnaryNest b) body pattern BinaryLamExpr :: Binder r n l1 -> Binder r l1 l2 -> Block r l2 -> LamExpr r n pattern BinaryLamExpr b1 b2 body = LamExpr (BinaryNest b1 b2) body -pattern AtomicBlock :: Atom r n -> Block r n -pattern AtomicBlock atom <- Block _ Empty atom - where AtomicBlock atom = Block NoBlockAnn Empty atom +pattern WithoutDecls :: e n -> WithDecls r e n +pattern WithoutDecls x = Abs Empty x exprBlock :: IRRep r => Block r n -> Maybe (Expr r n) -exprBlock (Block _ (Nest (Let b (DeclBinding _ expr)) Empty) (Var (AtomVar n _))) +exprBlock (Abs (Nest (Let b (DeclBinding _ expr)) Empty) (Var (AtomVar n _))) | n == binderName b = Just expr exprBlock _ = Nothing {-# INLINE exprBlock #-} @@ -1228,10 +1158,10 @@ instance AlphaEqE DataConDefs instance AlphaHashableE DataConDefs instance GenericE TyConDef where - type RepE TyConDef = PairE (LiftE SourceName) (Abs RolePiBinders DataConDefs) - fromE (TyConDef sourceName bs cons) = PairE (LiftE sourceName) (Abs bs cons) + type RepE TyConDef = PairE (LiftE (SourceName, [RoleExpl])) (Abs (Nest CBinder) DataConDefs) + fromE (TyConDef sourceName expls bs cons) = PairE (LiftE (sourceName, expls)) (Abs bs cons) {-# INLINE fromE #-} - toE (PairE (LiftE sourceName) (Abs bs cons)) = TyConDef sourceName bs cons + toE (PairE (LiftE (sourceName, expls)) (Abs bs cons)) = TyConDef sourceName expls bs cons {-# INLINE toE #-} deriving instance Show (TyConDef n) @@ -1243,7 +1173,7 @@ instance AlphaEqE TyConDef instance AlphaHashableE TyConDef instance HasNameHint (TyConDef n) where - getNameHint (TyConDef v _ _) = getNameHint v + getNameHint (TyConDef v _ _ _) = getNameHint v instance GenericE DataConDef where type RepE DataConDef = (LiftE (SourceName, [[Projection]])) @@ -1324,29 +1254,6 @@ instance IRRep r => RenameE (BaseMonoid r) instance IRRep r => AlphaEqE (BaseMonoid r) instance IRRep r => AlphaHashableE (BaseMonoid r) -instance GenericE UserEffectOp where - type RepE UserEffectOp = EitherE3 - {- Handle -} (HandlerName `PairE` ListE CAtom `PairE` CBlock) - {- Resume -} (CType `PairE` CAtom) - {- Perform -} (CAtom `PairE` LiftE Int) - fromE = \case - Handle name args body -> Case0 $ name `PairE` ListE args `PairE` body - Resume x y -> Case1 $ x `PairE` y - Perform x i -> Case2 $ x `PairE` LiftE i - {-# INLINE fromE #-} - toE = \case - Case0 (name `PairE` ListE args `PairE` body) -> Handle name args body - Case1 (x `PairE` y) -> Resume x y - Case2 (x `PairE` LiftE i) -> Perform x i - _ -> error "impossible" - {-# INLINE toE #-} - -instance SinkableE UserEffectOp -instance HoistableE UserEffectOp -instance RenameE UserEffectOp -instance AlphaEqE UserEffectOp -instance AlphaHashableE UserEffectOp - instance IRRep r => GenericE (DAMOp r) where type RepE (DAMOp r) = EitherE5 {- Seq -} (EffectRow r `PairE` LiftE Direction `PairE` IxType r `PairE` Atom r `PairE` LamExpr r) @@ -1690,11 +1597,10 @@ instance IRRep r => GenericE (PrimOp r) where {- MemOp -} (MemOp r) {- VectorOp -} (VectorOp r) {- MiscOp -} (MiscOp r) - ) (EitherE4 + ) (EitherE3 {- Hof -} (TypedHof r) {- RefOp -} (Atom r `PairE` RefOp r) {- DAMOp -} (WhenSimp r (DAMOp SimpIR)) - {- UserEffectOp -} (WhenCore r UserEffectOp) ) fromE = \case UnOp op x -> Case0 $ Case0 $ LiftE op `PairE` x @@ -1705,7 +1611,6 @@ instance IRRep r => GenericE (PrimOp r) where Hof op -> Case1 $ Case0 op RefOp r op -> Case1 $ Case1 $ r `PairE` op DAMOp op -> Case1 $ Case2 $ WhenIRE op - UserEffectOp op -> Case1 $ Case3 $ WhenIRE op {-# INLINE fromE #-} toE = \case @@ -1720,7 +1625,6 @@ instance IRRep r => GenericE (PrimOp r) where Case0 op -> Hof op Case1 (r `PairE` op) -> RefOp r op Case2 (WhenIRE op) -> DAMOp op - Case3 (WhenIRE op) -> UserEffectOp op _ -> error "impossible" _ -> error "impossible" {-# INLINE toE #-} @@ -1890,27 +1794,6 @@ instance IRRep r => AlphaEqE (TC r) instance IRRep r => AlphaHashableE (TC r) instance IRRep r => RenameE (TC r) -instance IRRep r => GenericE (Block r) where - type RepE (Block r) = PairE (MaybeE (EffTy r)) (Abs (Nest (Decl r)) (Atom r)) - fromE (Block (BlockAnn effTy) decls result) = PairE (JustE effTy) (Abs decls result) - fromE (Block NoBlockAnn Empty result) = PairE NothingE (Abs Empty result) - fromE _ = error "impossible" - {-# INLINE fromE #-} - toE (PairE (JustE effTy) (Abs decls result)) = Block (BlockAnn effTy) decls result - toE (PairE NothingE (Abs Empty result)) = Block NoBlockAnn Empty result - toE _ = error "impossible" - {-# INLINE toE #-} - -deriving instance IRRep r => Show (BlockAnn r n l) - -instance IRRep r => SinkableE (Block r) -instance IRRep r => HoistableE (Block r) -instance IRRep r => AlphaEqE (Block r) -instance IRRep r => AlphaHashableE (Block r) -instance IRRep r => RenameE (Block r) -deriving instance IRRep r => Show (Block r n) -deriving via WrapE (Block r) n instance IRRep r => Generic (Block r n) - instance IRRep r => GenericB (NonDepNest r ann) where type RepB (NonDepNest r ann) = (LiftB (ListE ann)) `PairB` Nest (AtomNameBinder r) fromB (NonDepNest bs anns) = LiftB (ListE anns) `PairB` bs @@ -1926,39 +1809,15 @@ instance (IRRep r, AlphaEqE ann) => AlphaEqB (NonDepNest r ann) instance (IRRep r, AlphaHashableE ann) => AlphaHashableB (NonDepNest r ann) deriving instance (Show (ann n)) => IRRep r => Show (NonDepNest r ann n l) -instance GenericB RolePiBinder where - type RepB RolePiBinder = PairB (LiftB (LiftE ParamRole)) (WithExpl CBinder) - fromB (RolePiBinder role b) = PairB (LiftB (LiftE role)) b - toB (PairB (LiftB (LiftE role)) b) = RolePiBinder role b - -instance BindsAtMostOneName RolePiBinder (AtomNameC CoreIR) where - RolePiBinder _ b @> x = b @> x - {-# INLINE (@>) #-} - -instance BindsOneName RolePiBinder (AtomNameC CoreIR) where - binderName (RolePiBinder _ b) = binderName b - -instance BindsOneAtomName CoreIR RolePiBinder where - binderType (RolePiBinder _ b) = binderType b - binderVar (RolePiBinder _ b) = binderVar b - -instance ProvesExt RolePiBinder -instance BindsNames RolePiBinder -instance SinkableB RolePiBinder -instance HoistableB RolePiBinder -instance RenameB RolePiBinder -instance AlphaEqB RolePiBinder -instance AlphaHashableB RolePiBinder - instance GenericE ClassDef where type RepE ClassDef = - LiftE (SourceName, [SourceName], [Maybe SourceName]) - `PairE` Abs RolePiBinders (Abs (Nest CBinder) (ListE CorePiType)) - fromE (ClassDef name names paramNames b scs tys) = - LiftE (name, names, paramNames) `PairE` Abs b (Abs scs (ListE tys)) + LiftE (SourceName, [SourceName], [Maybe SourceName], [RoleExpl]) + `PairE` Abs (Nest CBinder) (Abs (Nest CBinder) (ListE CorePiType)) + fromE (ClassDef name names paramNames roleExpls b scs tys) = + LiftE (name, names, paramNames, roleExpls) `PairE` Abs b (Abs scs (ListE tys)) {-# INLINE fromE #-} - toE (LiftE (name, names, paramNames) `PairE` Abs b (Abs scs (ListE tys))) = - ClassDef name names paramNames b scs tys + toE (LiftE (name, names, paramNames, roleExpls) `PairE` Abs b (Abs scs (ListE tys))) = + ClassDef name names paramNames roleExpls b scs tys {-# INLINE toE #-} instance SinkableE ClassDef @@ -1971,11 +1830,11 @@ deriving via WrapE ClassDef n instance Generic (ClassDef n) instance GenericE InstanceDef where type RepE InstanceDef = - ClassName `PairE` Abs RolePiBinders (ListE CAtom `PairE` InstanceBody) - fromE (InstanceDef name bs params body) = - name `PairE` Abs bs (ListE params `PairE` body) - toE (name `PairE` Abs bs (ListE params `PairE` body)) = - InstanceDef name bs params body + ClassName `PairE` LiftE [RoleExpl] `PairE` Abs (Nest CBinder) (ListE CAtom `PairE` InstanceBody) + fromE (InstanceDef name expls bs params body) = + name `PairE` LiftE expls `PairE` Abs bs (ListE params `PairE` body) + toE (name `PairE` LiftE expls `PairE` Abs bs (ListE params `PairE` body)) = + InstanceDef name expls bs params body instance SinkableE InstanceDef instance HoistableE InstanceDef @@ -2107,10 +1966,10 @@ deriving instance Show (CoreLamExpr n) deriving via WrapE CoreLamExpr n instance Generic (CoreLamExpr n) instance GenericE CorePiType where - type RepE CorePiType = LiftE AppExplicitness `PairE` Abs CoreBinders (EffTy CoreIR) - fromE (CorePiType ex b effTy) = LiftE ex `PairE` Abs b effTy + type RepE CorePiType = LiftE (AppExplicitness, [Explicitness]) `PairE` Abs (Nest CBinder) (EffTy CoreIR) + fromE (CorePiType ex exs b effTy) = LiftE (ex, exs) `PairE` Abs b effTy {-# INLINE fromE #-} - toE (LiftE ex `PairE` Abs b effTy) = CorePiType ex b effTy + toE (LiftE (ex, exs) `PairE` Abs b effTy) = CorePiType ex exs b effTy {-# INLINE toE #-} instance SinkableE CorePiType @@ -2176,6 +2035,9 @@ instance IRRep r => AlphaEqE (TabPiType r) where instance IRRep r => AlphaHashableE (TabPiType r) where hashWithSaltE env salt (TabPiType _ b t) = hashWithSaltE env salt $ Abs b t +instance HasNameHint (TabPiType r n) where + getNameHint (TabPiType _ b _) = getNameHint b + instance IRRep r => SinkableE (TabPiType r) instance IRRep r => HoistableE (TabPiType r) instance IRRep r => RenameE (TabPiType r) @@ -2291,16 +2153,29 @@ instance RenameE TopFunDef instance AlphaEqE TopFunDef instance AlphaHashableE TopFunDef +instance IRRep r => GenericE (TopLam r) where + type RepE (TopLam r) = LiftE Bool `PairE` PiType r `PairE` LamExpr r + fromE (TopLam d x y) = LiftE d `PairE` x `PairE` y + {-# INLINE fromE #-} + toE (LiftE d `PairE` x `PairE` y) = TopLam d x y + {-# INLINE toE #-} + +instance IRRep r => SinkableE (TopLam r) +instance IRRep r => HoistableE (TopLam r) +instance IRRep r => RenameE (TopLam r) +instance IRRep r => AlphaEqE (TopLam r) +instance IRRep r => AlphaHashableE (TopLam r) + instance GenericE TopFun where type RepE TopFun = EitherE - (TopFunDef `PairE` PiType SimpIR `PairE` LamExpr SimpIR `PairE` ComposeE EvalStatus TopFunLowerings) + (TopFunDef `PairE` TopLam SimpIR `PairE` ComposeE EvalStatus TopFunLowerings) (LiftE (String, IFunType)) fromE = \case - DexTopFun def ty simp status -> LeftE (def `PairE` ty `PairE` simp `PairE` ComposeE status) + DexTopFun def lam status -> LeftE (def `PairE` lam `PairE` ComposeE status) FFITopFun name ty -> RightE (LiftE (name, ty)) {-# INLINE fromE #-} toE = \case - LeftE (def `PairE` ty `PairE` simp `PairE` ComposeE status) -> DexTopFun def ty simp status + LeftE (def `PairE` lam `PairE` ComposeE status) -> DexTopFun def lam status RightE (LiftE (name, ty)) -> FFITopFun name ty {-# INLINE toE #-} @@ -2371,14 +2246,11 @@ instance GenericE (Binding c) where (WhenC ClassNameC c (ClassDef)) (WhenC InstanceNameC c (InstanceDef `PairE` CorePiType)) (WhenC MethodNameC c (ClassName `PairE` LiftE Int))) - (EitherE7 + (EitherE4 (WhenC TopFunNameC c (TopFun)) (WhenC FunObjCodeNameC c (CFunction)) (WhenC ModuleNameC c (Module)) - (WhenC PtrNameC c (LiftE (PtrType, PtrLitVal))) - (WhenC EffectNameC c (EffectDef)) - (WhenC HandlerNameC c (HandlerDef)) - (WhenC EffectOpNameC c (EffectOpDef))) + (WhenC PtrNameC c (LiftE (PtrType, PtrLitVal)))) (EitherE2 (WhenC SpecializedDictNameC c (SpecializedDictDef)) (WhenC ImpNameC c (LiftE BaseType))) @@ -2394,9 +2266,6 @@ instance GenericE (Binding c) where FunObjCodeBinding cFun -> Case1 $ Case1 $ WhenC $ cFun ModuleBinding m -> Case1 $ Case2 $ WhenC $ m PtrBinding ty p -> Case1 $ Case3 $ WhenC $ LiftE (ty,p) - EffectBinding effDef -> Case1 $ Case4 $ WhenC $ effDef - HandlerBinding hDef -> Case1 $ Case5 $ WhenC $ hDef - EffectOpBinding opDef -> Case1 $ Case6 $ WhenC $ opDef SpecializedDictBinding def -> Case2 $ Case0 $ WhenC $ def ImpNameBinding ty -> Case2 $ Case1 $ WhenC $ LiftE ty {-# INLINE fromE #-} @@ -2412,9 +2281,6 @@ instance GenericE (Binding c) where Case1 (Case1 (WhenC (f))) -> FunObjCodeBinding f Case1 (Case2 (WhenC (m))) -> ModuleBinding m Case1 (Case3 (WhenC ((LiftE (ty,p))))) -> PtrBinding ty p - Case1 (Case4 (WhenC (effDef))) -> EffectBinding effDef - Case1 (Case5 (WhenC (hDef))) -> HandlerBinding hDef - Case1 (Case6 (WhenC (opDef))) -> EffectOpBinding opDef Case2 (Case0 (WhenC (def))) -> SpecializedDictBinding def Case2 (Case1 (WhenC ((LiftE ty)))) -> ImpNameBinding ty _ -> error "impossible" @@ -2471,23 +2337,20 @@ instance IRRep r => BindsNames (Decl r) instance IRRep r => GenericE (Effect r) where type RepE (Effect r) = - EitherE4 (PairE (LiftE RWS) (Atom r)) + EitherE3 (PairE (LiftE RWS) (Atom r)) (LiftE (Either () ())) - (Name EffectNameC) UnitE fromE = \case RWSEffect rws h -> Case0 (PairE (LiftE rws) h) ExceptionEffect -> Case1 (LiftE (Left ())) IOEffect -> Case1 (LiftE (Right ())) - UserEffect name -> Case2 name - InitEffect -> Case3 UnitE + InitEffect -> Case2 UnitE {-# INLINE fromE #-} toE = \case Case0 (PairE (LiftE rws) h) -> RWSEffect rws h Case1 (LiftE (Left ())) -> ExceptionEffect Case1 (LiftE (Right ())) -> IOEffect - Case2 name -> UserEffect name - Case3 UnitE -> InitEffect + Case2 UnitE -> InitEffect _ -> error "unreachable" {-# INLINE toE #-} @@ -2574,8 +2437,8 @@ instance GenericE TopEnvUpdate where {- UpdateLoadedModules -} (LiftE ModuleSourceName `PairE` ModuleName) {- UpdateLoadedObjects -} (FunObjCodeName `PairE` LiftE NativeFunction) ) ( EitherE6 - {- FinishDictSpecialization -} (SpecDictName `PairE` ListE (LamExpr SimpIR)) - {- LowerDictSpecialization -} (SpecDictName `PairE` ListE (LamExpr SimpIR)) + {- FinishDictSpecialization -} (SpecDictName `PairE` ListE (TopLam SimpIR)) + {- LowerDictSpecialization -} (SpecDictName `PairE` ListE (TopLam SimpIR)) {- UpdateTopFunEvalStatus -} (TopFunName `PairE` ComposeE EvalStatus TopFunLowerings) {- UpdateInstanceDef -} (InstanceName `PairE` InstanceDef) {- UpdateTyConDef -} (TyConName `PairE` TyConDef) @@ -2685,8 +2548,8 @@ applyUpdate e = \case updateEnv dName newBinding e UpdateTopFunEvalStatus f s -> do case lookupEnvPure e f of - TopFunBinding (DexTopFun def ty simp _) -> - updateEnv f (TopFunBinding $ DexTopFun def ty simp s) e + TopFunBinding (DexTopFun def lam _) -> + updateEnv f (TopFunBinding $ DexTopFun def lam s) e _ -> error "can't update ffi function impl" UpdateInstanceDef name def -> do case lookupEnvPure e name of @@ -2842,7 +2705,7 @@ instance Store (TyConParams n) instance Store (DataConDefs n) instance Store (TyConDef n) instance Store (DataConDef n) -instance IRRep r => Store (Block r n) +instance IRRep r => Store (TopLam r n) instance IRRep r => Store (LamExpr r n) instance IRRep r => Store (IxType r n) instance Store (CorePiType n) @@ -2857,8 +2720,6 @@ instance Store (DictType n) instance Store (DictExpr n) instance Store (EffectDef n) instance Store (EffectOpDef n) -instance Store (RolePiBinder n l) -instance Store (HandlerDef n) instance Store (EffectOpType n) instance Store (EffectOpIdx) instance Store (SynthCandidates n) @@ -2882,7 +2743,6 @@ instance IRRep r => Store (RefOp r n) instance IRRep r => Store (BaseMonoid r n) instance IRRep r => Store (DAMOp r n) instance IRRep r => Store (IxDict r n) -instance Store (UserEffectOp n) instance Store (NewtypeCon n) instance Store (NewtypeTyCon n) instance Store (DotMethods n) diff --git a/src/lib/Types/Primitives.hs b/src/lib/Types/Primitives.hs index a9230a3ed..002a6d09a 100644 --- a/src/lib/Types/Primitives.hs +++ b/src/lib/Types/Primitives.hs @@ -22,9 +22,7 @@ module Types.Primitives ( module Types.Primitives, UnOp (..), BinOp (..), CmpOp (..), Projection (..)) where -import Name import qualified Data.ByteString as BS -import Control.Monad import Data.Int import Data.Word import Data.Hashable @@ -35,7 +33,6 @@ import Foreign.Ptr import GHC.Generics (Generic (..)) import Occurrence -import Util (zipErr) import Types.OpNames (UnOp (..), BinOp (..), CmpOp (..), Projection (..)) type SourceName = String @@ -58,23 +55,6 @@ data Explicitness = data AppExplicitness = ExplicitApp | ImplicitApp deriving (Show, Generic, Eq) data DepPairExplicitness = ExplicitDepPair | ImplicitDepPair deriving (Show, Generic, Eq) -data WithExpl (b::B) (n::S) (l::S) = - WithExpl { getExpl :: Explicitness , withoutExpl :: b n l } - deriving (Show, Generic) - -unzipExpls :: Nest (WithExpl b) n l -> ([Explicitness], Nest b n l) -unzipExpls Empty = ([], Empty) -unzipExpls (Nest (WithExpl expl b) rest) = (expl:expls, Nest b bs) - where (expls, bs) = unzipExpls rest - -zipExpls :: [Explicitness] -> Nest b n l -> Nest (WithExpl b) n l -zipExpls [] Empty = Empty -zipExpls (expl:expls) (Nest b bs) = Nest (WithExpl expl b) (zipExpls expls bs) -zipExpls _ _ = error "zip error" - -addExpls :: Explicitness -> Nest b n l -> Nest (WithExpl b) n l -addExpls expl bs = fmapNest (\b -> WithExpl expl b) bs - data RequiredMethodAccess = Full | Partial Int deriving (Show, Eq, Ord, Generic) data LetAnn = @@ -225,40 +205,3 @@ instance Hashable AppExplicitness instance Hashable DepPairExplicitness instance Hashable InferenceMechanism instance Hashable RequiredMethodAccess - -instance Store (b n l) => Store (WithExpl b n l) - -instance (Color c, BindsOneName b c) => BindsOneName (WithExpl b) c where - binderName (WithExpl _ b) = binderName b - asNameBinder (WithExpl _ b) = asNameBinder b - -instance (Color c, BindsAtMostOneName b c) => BindsAtMostOneName (WithExpl b) c where - WithExpl _ b @> x = b @> x - {-# INLINE (@>) #-} - -instance AlphaEqB b => AlphaEqB (WithExpl b) where - withAlphaEqB (WithExpl a1 b1) (WithExpl a2 b2) cont = do - unless (a1 == a2) zipErr - withAlphaEqB b1 b2 cont - -instance AlphaHashableB b => AlphaHashableB (WithExpl b) where - hashWithSaltB env salt (WithExpl expl b) = do - let h = hashWithSalt salt expl - hashWithSaltB env h b - -instance BindsNames b => ProvesExt (WithExpl b) where -instance BindsNames b => BindsNames (WithExpl b) where - toScopeFrag (WithExpl _ b) = toScopeFrag b - -instance (SinkableB b) => SinkableB (WithExpl b) where - sinkingProofB fresh (WithExpl a b) cont = - sinkingProofB fresh b \fresh' b' -> - cont fresh' (WithExpl a b') - -instance (BindsNames b, RenameB b) => RenameB (WithExpl b) where - renameB env (WithExpl a b) cont = - renameB env b \env' b' -> - cont env' $ WithExpl a b' - -instance HoistableB b => HoistableB (WithExpl b) where - freeVarsB (WithExpl _ b) = freeVarsB b diff --git a/src/lib/Types/Source.hs b/src/lib/Types/Source.hs index 78bac3ddc..0c361236d 100644 --- a/src/lib/Types/Source.hs +++ b/src/lib/Types/Source.hs @@ -201,7 +201,6 @@ data UEffect (n::S) = URWSEffect RWS (SourceOrInternalName (AtomNameC CoreIR) n) | UExceptionEffect | UIOEffect - | UUserEffect (SourceOrInternalName EffectNameC n) deriving (Generic) data UEffectRow (n::S) = @@ -275,9 +274,12 @@ data FieldName' = | FieldNum Int deriving (Show, Eq, Ord) +type UAnnExplBinders req n l = ([Explicitness], Nest (UAnnBinder req) n l) +type UOptAnnExplBinders n l = UAnnExplBinders AnnOptional n l + data ULamExpr (n::S) where ULamExpr - :: Nest (WithExpl UOptAnnBinder) n l -- args + :: UOptAnnExplBinders n l -- args -> AppExplicitness -> Maybe (UEffectRow l) -- optional effect -> Maybe (UType l) -- optional result type @@ -285,7 +287,7 @@ data ULamExpr (n::S) where -> ULamExpr n data UPiExpr (n::S) where - UPiExpr :: Nest (WithExpl UOptAnnBinder) n l -> AppExplicitness -> UEffectRow l -> UType l -> UPiExpr n + UPiExpr :: UOptAnnExplBinders n l -> AppExplicitness -> UEffectRow l -> UType l -> UPiExpr n data UTabPiExpr (n::S) where UTabPiExpr :: UOptAnnBinder n l -> UType l -> UTabPiExpr n @@ -298,14 +300,14 @@ type UConDef (n::S) (l::S) = (SourceName, Nest UReqAnnBinder n l) data UDataDef (n::S) where UDataDef :: SourceName -- source name for pretty printing - -> Nest (WithExpl UOptAnnBinder) n l + -> UOptAnnExplBinders n l -> [(SourceName, UDataDefTrail l)] -- data constructor types -> UDataDef n data UStructDef (n::S) where UStructDef :: SourceName -- source name for pretty printing - -> Nest (WithExpl UOptAnnBinder) n l + -> UOptAnnExplBinders n l -> [(SourceName, UType l)] -- named payloads -> [(LetAnn, SourceName, Abs UAtomBinder ULamExpr l)] -- named methods (initial binder is for `self`) -> UStructDef n @@ -325,14 +327,14 @@ data UTopDecl (n::S) (l::S) where -> UStructDef l -- actual definition -> UTopDecl n l UInterface - :: Nest (WithExpl UOptAnnBinder) n p -- parameter binders + :: UOptAnnExplBinders n p -- parameter binders -> [UType p] -- method types -> UBinder ClassNameC n l' -- class name -> Nest (UBinder MethodNameC) l' l -- method names -> UTopDecl n l UInstance :: SourceNameOr (Name ClassNameC) n -- class name - -> Nest (WithExpl UOptAnnBinder) n l' + -> UOptAnnExplBinders n l' -> [UExpr l'] -- class parameters -> [UMethodDef l'] -- method definitions -- Maybe we should make a separate color (namespace) for instance names? @@ -347,7 +349,7 @@ data UTopDecl (n::S) (l::S) where UHandlerDecl :: SourceNameOr (Name EffectNameC) n -- effect name -> UAtomBinder n b -- body type argument - -> Nest (WithExpl UOptAnnBinder) b l' -- type args + -> UOptAnnExplBinders b l' -- type args -> UEffectRow l' -- returning effect -> UType l' -- returning type -> [UEffectOpDef l'] -- operation definitions diff --git a/src/lib/Vectorize.hs b/src/lib/Vectorize.hs index 69b7404e5..112060821 100644 --- a/src/lib/Vectorize.hs +++ b/src/lib/Vectorize.hs @@ -19,7 +19,7 @@ import Core import Err import CheapReduction import IRVariants -import Lower (DestBlock, DestLamExpr) +import Lower (DestBlock) import MTL1 import Name import Subst @@ -87,13 +87,13 @@ newtype TopVectorizeM (i::S) (o::S) (a:: *) = TopVectorizeM , EnvExtender, Builder SimpIR, ScopableBuilder SimpIR, Catchable , SubstReader Name) -vectorizeLoops :: EnvReader m => Word32 -> DestLamExpr n -> m n (DestLamExpr n, Errs) -vectorizeLoops width (LamExpr bsDestB body) = liftEnvReaderM do +vectorizeLoops :: EnvReader m => Word32 -> STopLam n -> m n (STopLam n, Errs) +vectorizeLoops width (TopLam d ty (LamExpr bsDestB body)) = liftEnvReaderM do case popNest bsDestB of Just (PairB bs b) -> refreshAbs (Abs bs (Abs b body)) \bs' body' -> do (Abs b'' body'', errs) <- liftTopVectorizeM width $ vectorizeLoopsDestBlock body' - return $ (LamExpr (bs' >>> UnaryNest b'') body'', errs) + return $ (TopLam d ty (LamExpr (bs' >>> UnaryNest b'') body''), errs) Nothing -> error "expected a trailing dest binder" liftTopVectorizeM :: (EnvReader m) @@ -139,7 +139,7 @@ vectorizeLoopsDestBlock (Abs (destb:>destTy) body) = do vectorizeLoopsBlock :: (Emits o) => Block SimpIR i -> TopVectorizeM i o (SAtom o) -vectorizeLoopsBlock (Block _ decls ans) = +vectorizeLoopsBlock (Abs decls ans) = vectorizeLoopsDecls decls $ renameM ans vectorizeLoopsDecls :: (Emits o) @@ -360,7 +360,7 @@ vectorizeLamExpr (LamExpr bs body) argStabilities = case (bs, argStabilities) of _ -> error "Zip error" vectorizeBlock :: Emits o => SBlock i -> VectorizeM i o (VAtom o) -vectorizeBlock block@(Block _ decls (ans :: SAtom i')) = +vectorizeBlock block@(Abs decls (ans :: SAtom i')) = addVectErrCtx "vectorizeBlock" ("Block:\n" ++ pprint block) $ go decls where @@ -497,11 +497,10 @@ vectorizePrimOp op = case op of -- complain about FFI calls and the like. Hof (TypedHof _ (RunIO body)) -> do -- TODO: buildBlockAux? - Abs decls (LiftE vy `PairE` yWithTy) <- buildScoped do + Abs decls (LiftE vy `PairE` y) <- buildScoped do VVal vy y <- vectorizeBlock body - PairE (LiftE vy) <$> withType y - body' <- absToBlock =<< computeAbsEffects (Abs decls yWithTy) - VVal vy <$> emitHof (RunIO body') + return $ PairE (LiftE vy) y + VVal vy <$> emitHof (RunIO $ Abs decls y) _ -> throwVectErr $ "Can't vectorize op: " ++ pprint op vectorizeType :: SType i -> VectorizeM i o (SType o) diff --git a/tests/uexpr-tests.dx b/tests/uexpr-tests.dx index c32378219..6583ce5bd 100644 --- a/tests/uexpr-tests.dx +++ b/tests/uexpr-tests.dx @@ -321,7 +321,7 @@ def bug(n|Data) -> () = > v#1:((RangeFrom n v#0) => ()) = for i:(RangeFrom n v#0). () > v#1 > Of type: ((RangeFrom n v#0) => ()) -> With effects: {} +> > > for w':n. > ^^^^^^^^^^ diff --git a/tests/unit/ConstantCastingSpec.hs b/tests/unit/ConstantCastingSpec.hs index 499946ef8..fe9abab12 100644 --- a/tests/unit/ConstantCastingSpec.hs +++ b/tests/unit/ConstantCastingSpec.hs @@ -24,6 +24,7 @@ import Types.Core import Types.Imp import Types.Primitives import Types.Source +import QueryType castOp :: ScalarBaseType -> (SAtom n) -> PrimOp SimpIR n castOp ty x = MiscOp $ CastOp (BaseTy (Scalar ty)) x @@ -43,7 +44,7 @@ exprToBlock expr = do compile :: (Topper m, Mut n) => ScalarBaseType -> ScalarBaseType -> m n LLVMCallable compile fromTy toTy = do - sLam <- liftEnvReaderM $ castLam fromTy toTy + sLam <- liftEnvReaderM (castLam fromTy toTy) >>= asTopLam compileTopLevelFun (EntryFunCC CUDANotRequired) sLam >>= packageLLVMCallable arbLitVal :: ScalarBaseType -> Gen LitVal diff --git a/tests/unit/JaxADTSpec.hs b/tests/unit/JaxADTSpec.hs index ee2725788..40c0ff785 100644 --- a/tests/unit/JaxADTSpec.hs +++ b/tests/unit/JaxADTSpec.hs @@ -18,6 +18,7 @@ import TopLevel import Types.Imp import Types.Primitives hiding (Sin) import Types.Source hiding (SourceName) +import QueryType x_nm, y_nm :: JSourceName x_nm = JSourceName 0 0 "x" @@ -48,7 +49,7 @@ compile jaxpr = do -- the jaxpr instead of just coercing it. Distinct <- getDistinct jRename <- liftRenameM $ renameJaxpr (unsafeCoerceE jaxpr) - jSimp <- liftJaxSimpM $ simplifyJaxpr jRename + jSimp <- liftJaxSimpM (simplifyJaxpr jRename) >>= asTopLam compileTopLevelFun (EntryFunCC CUDANotRequired) jSimp >>= packageLLVMCallable spec :: Spec diff --git a/tests/unit/OccAnalysisSpec.hs b/tests/unit/OccAnalysisSpec.hs index 06007fb1c..20cd6dc88 100644 --- a/tests/unit/OccAnalysisSpec.hs +++ b/tests/unit/OccAnalysisSpec.hs @@ -25,6 +25,7 @@ import Types.Imp (Backend (..)) import Types.Primitives import Types.Source import TopLevel +import QueryType sourceTextToBlocks :: (Topper m, Mut n) => Text -> m n [SBlock n] sourceTextToBlocks source = do @@ -44,11 +45,11 @@ uExprToBlock expr = do renamed <- renameSourceNamesUExpr expr typed <- inferTopUExpr renamed synthed <- synthTopE typed - (SimplifiedBlock block (CoerceRecon _)) <- simplifyTopBlock synthed + SimplifiedTopLam (TopLam _ _ (LamExpr Empty block)) (CoerceRecon _) <- simplifyTopBlock synthed return block findRunIOAnnotation :: SBlock n -> LetAnn -findRunIOAnnotation (Block _ decls _) = go decls where +findRunIOAnnotation (Abs decls _) = go decls where go :: Nest SDecl n l -> LetAnn go (Nest (Let _ (DeclBinding ann (PrimOp (Hof (TypedHof _ (RunIO _)))))) _) = ann go (Nest _ rest) = go rest @@ -57,7 +58,8 @@ findRunIOAnnotation (Block _ decls _) = go decls where analyze :: EvalConfig -> TopStateEx -> [Text] -> IO LetAnn analyze cfg env code = fst <$> runTopperM cfg env do [block] <- sourceTextToBlocks $ unlines code - NullaryLamExpr block' <- analyzeOccurrences $ NullaryLamExpr block + lam <- asTopLam $ LamExpr Empty block + TopLam _ _ (LamExpr Empty block') <- analyzeOccurrences lam -- The RunIO is generated by simplifying `unreachable()` in the examples -- below. If we want compound examples that have more than one RunIO block, -- we will need better pattern-matching.