From 39c45b81aff9336278d987c5ddc8735d501bb1f3 Mon Sep 17 00:00:00 2001 From: Dougal Date: Thu, 6 Jul 2023 20:46:51 -0400 Subject: [PATCH] Update type checker in anticipation of decls-in-types. --- src/lib/CheckType.hs | 1471 ++++++++++++++++++--------------------- src/lib/Core.hs | 1 + src/lib/Generalize.hs | 1 - src/lib/Imp.hs | 2 +- src/lib/Inference.hs | 48 +- src/lib/Linearize.hs | 2 +- src/lib/QueryType.hs | 46 +- src/lib/Simplify.hs | 15 +- src/lib/TopLevel.hs | 8 +- src/lib/Types/Core.hs | 8 +- src/lib/Types/Source.hs | 9 + 11 files changed, 790 insertions(+), 821 deletions(-) diff --git a/src/lib/CheckType.hs b/src/lib/CheckType.hs index d25ce524f..a49e5989f 100644 --- a/src/lib/CheckType.hs +++ b/src/lib/CheckType.hs @@ -6,19 +6,13 @@ {-# LANGUAGE UndecidableInstances #-} -module CheckType ( - CheckableE (..), CheckableB (..), - checkTypes, checkTypesM, checkHasType, - checkExtends, tryGetType, isData, asFFIFunType, checkBlock - ) where +module CheckType (CheckableE (..), CheckableB (..), checkBlock, checkTypes, checkTypeIs) where import Prelude hiding (id) import Control.Category ((>>>)) import Control.Monad import Control.Monad.Reader import Control.Monad.State.Class -import Data.Maybe (isJust) -import Data.Foldable (toList) import Data.Functor import CheapReduction @@ -29,98 +23,70 @@ import MTL1 import Name import Subst import PPrint () -import QueryType hiding (HasType) +import QueryType import Types.Core -import Types.Imp import Types.Primitives import Types.Source -- === top-level API === -checkTypes :: (EnvReader m, CheckableE r e) => e n -> m n (Except ()) -checkTypes e = liftTyperT $ checkE e -{-# SCC checkTypes #-} +checkTypes :: (EnvReader m, Fallible1 m, CheckableE r e) => e n -> m n () +checkTypes e = liftTyperM (void $ checkE e) >>= liftExcept -checkTypesM :: (EnvReader m, Fallible1 m, CheckableE r e) => e n -> m n () -checkTypesM e = liftExcept =<< checkTypes e - -tryGetType :: (EnvReader m, Fallible1 m, HasType r e) => e n -> m n (Type r n) -tryGetType e = liftExcept =<< liftTyperT (getTypeE e) -{-# INLINE tryGetType #-} - -checkHasType :: (EnvReader m, HasType r e) => e n -> Type r n -> m n (Except ()) -checkHasType e ty = liftTyperT $ e |: ty -{-# INLINE checkHasType #-} +checkTypeIs :: (EnvReader m, Fallible1 m, CheckableE r e, IRRep r, HasType r e) => e n -> Type r n -> m n () +checkTypeIs e ty = liftTyperM (void $ e |: ty) >>= liftExcept -- === the type checking/querying monad === --- TODO: not clear why we need the explicit `Monad2` here since it should --- already be a superclass, transitively, through both Fallible2 and --- MonadAtomSubst. -class ( Monad2 m, Fallible2 m, SubstReader Name m - , EnvReader2 m, EnvExtender2 m) - => Typer (m::MonadKind2) (r::IR) | m -> r where - affineUsed :: AtomName r o -> m i o () - parallelAffines_ :: [m i o ()] -> m i o () - -newtype TyperT (m::MonadKind) (r::IR) (i::S) (o::S) (a :: *) = - TyperT { runTyperT' :: SubstReaderT Name (StateT1 (NameMap (AtomNameC r) Int) (EnvReaderT m)) i o a } - deriving ( Functor, Applicative, Monad - , SubstReader Name - , MonadFail - , Fallible - , ScopeReader +newtype TyperM (r::IR) (i::S) (o::S) (a :: *) = + TyperM { runTyperT' :: SubstReaderT Name (StateT1 (NameMap (AtomNameC r) Int) FallibleEnvReaderM) i o a } + deriving ( Functor, Applicative, Monad , SubstReader Name , MonadFail , Fallible , ScopeReader , EnvReader, EnvExtender) -liftTyperT :: (EnvReader m', Fallible m) => TyperT m r n n a -> m' n (m a) -liftTyperT cont = - liftEnvReaderT $ +liftTyperM :: EnvReader m => TyperM r n n a -> m n (Except a) +liftTyperM cont = + liftM runFallibleM $ liftEnvReaderT $ flip evalStateT1 mempty $ runSubstReaderT idSubst $ runTyperT' cont -{-# INLINE liftTyperT #-} - -instance Fallible m => Typer (TyperT m r) r where - -- I can't make up my mind whether a `Seq` loop should be allowed to - -- close over a dest from an enclosing scope. Status quo permits this. - affineUsed name = TyperT $ do - affines <- get - case lookupNameMap name affines of - Just n -> if n > 0 then - throw TypeErr $ "Affine name " ++ pprint name ++ " used " ++ show (n + 1) ++ " times." - else - put $ insertNameMap name (n + 1) affines - Nothing -> put $ insertNameMap name 1 affines - parallelAffines_ actions = TyperT $ do - -- This method permits using an affine variable in each branch of a `case`. - -- We check each `case` branch in isolation, detecting affine overuse within - -- the branch; then we check whether the union of the variables used in the - -- branches reuses a variable from outside that it shouldn't. - -- This has the down-side of localizing such an error to the case rather - -- than to the offending in-branch use, but that can be improved later. - affines <- get - isolateds <- forM actions \act -> do - put mempty - runTyperT' act - get - put affines - forM_ (toListNameMap $ unionsWithNameMap max isolateds) \(name, ct) -> - case ct of - 0 -> return () - 1 -> runTyperT' $ affineUsed name - _ -> error $ "Unexpected multi-used affine name " ++ show name ++ " from case branches." +{-# INLINE liftTyperM #-} + +-- I can't make up my mind whether a `Seq` loop should be allowed to +-- close over a dest from an enclosing scope. Status quo permits this. +affineUsed :: AtomName r o -> TyperM r i o () +affineUsed name = TyperM $ do + affines <- get + case lookupNameMap name affines of + Just n -> if n > 0 then + throw TypeErr $ "Affine name " ++ pprint name ++ " used " ++ show (n + 1) ++ " times." + else + put $ insertNameMap name (n + 1) affines + Nothing -> put $ insertNameMap name 1 affines + +parallelAffines :: [TyperM r i o a] -> TyperM r i o [a] +parallelAffines actions = TyperM $ do + -- This method permits using an affine variable in each branch of a `case`. + -- We check each `case` branch in isolation, detecting affine overuse within + -- the branch; then we check whether the union of the variables used in the + -- branches reuses a variable from outside that it shouldn't. + -- This has the down-side of localizing such an error to the case rather + -- than to the offending in-branch use, but that can be improved later. + affines <- get + (results, isolateds) <- unzip <$> forM actions \act -> do + put mempty + result <- runTyperT' act + (result,) <$> get + put affines + forM_ (toListNameMap $ unionsWithNameMap max isolateds) \(name, ct) -> + case ct of + 0 -> return () + 1 -> runTyperT' $ affineUsed name + _ -> error $ "Unexpected multi-used affine name " ++ show name ++ " from case branches." + return results -- === typeable things === -class (SinkableE e, RenameE e, PrettyE e, IRRep r) => HasType (r::IR) (e::E) | e -> r where - getTypeE :: Typer m r => e i -> m i o (Type r o) - -checkTypeE :: (HasType r e, Typer m r) => Type r o -> e i -> m i o (e o) -checkTypeE reqTy e = do - e |: reqTy - renameM e - -checkTypesEq :: (Typer m r, IRRep r) => Type r o -> Type r o -> m i o () +checkTypesEq :: IRRep r => Type r o -> Type r o -> TyperM r i o () checkTypesEq reqTy ty = alphaEq reqTy ty >>= \case True -> return () False -> {-# SCC typeNormalization #-} do @@ -131,446 +97,482 @@ checkTypesEq reqTy ty = alphaEq reqTy ty >>= \case False -> throw TypeErr $ pprint reqTy' ++ " != " ++ pprint ty' {-# INLINE checkTypesEq #-} -class (SinkableE e) => CheckableE (r::IR) (e::E) | e -> r where - checkE :: Typer m r => e i -> m i o () +class SinkableE e => CheckableE (r::IR) (e::E) | e -> r where + checkE :: e i -> TyperM r i o (e o) class HasNamesB b => CheckableB (r::IR) (b::B) | b -> r where - checkB :: Typer m r - => b i i' - -> (forall o'. Ext o o' => b o o' -> m i' o' ()) - -> m i o () + checkB :: b i i' + -> (forall o'. DExt o o' => b o o' -> TyperM r i' o' a) + -> TyperM r i o a -checkBEvidenced :: (CheckableB r b, Typer m r) +class SinkableE e => CheckableWithEffects (r::IR) (e::E) | e -> r where + checkWithEffects :: EffectRow r o -> e i -> TyperM r i o (e o) + +checkBEvidenced :: CheckableB r b => b i i' - -> (forall o'. ExtEvidence o o' -> b o o' -> m i' o' ()) - -> m i o () + -> (forall o'. Distinct o' => ExtEvidence o o' -> b o o' -> TyperM r i' o' a) + -> TyperM r i o a checkBEvidenced b cont = checkB b \b' -> cont getExtEvidence b' -- === convenience functions === infixr 7 |: -(|:) :: (Typer m r, HasType r e) => e i -> Type r o -> m i o () +(|:) :: (HasType r e, CheckableE r e, IRRep r) => e i -> Type r o -> TyperM r i o (e o) (|:) x reqTy = do - ty <- getTypeE x - -- TODO: Write an alphaEq variant that works modulo an equivalence - -- relation on names. - checkTypesEq reqTy ty + x' <- checkE x + checkTypesEq reqTy (getType x') + return x' + +checkAndGetType :: (HasType r e, CheckableE r e, IRRep r) => e i -> TyperM r i o (e o, Type r o) +checkAndGetType x = do + x' <- checkE x + return (x', getType x') instance CheckableE CoreIR SourceMap where - checkE _ = return () + checkE sm = renameM sm -- TODO? instance (CheckableE r e1, CheckableE r e2) => CheckableE r (PairE e1 e2) where - checkE (PairE e1 e2) = checkE e1 >> checkE e2 + checkE (PairE e1 e2) = PairE <$> checkE e1 <*> checkE e2 instance (CheckableE r e1, CheckableE r e2) => CheckableE r (EitherE e1 e2) where - checkE ( LeftE e) = checkE e - checkE (RightE e) = checkE e + checkE ( LeftE e) = LeftE <$> checkE e + checkE (RightE e) = RightE <$> checkE e instance (CheckableB r b, CheckableE r e) => CheckableE r (Abs b e) where - checkE (Abs b e) = checkB b \_ -> checkE e - -instance (IRRep r) => CheckableE r (LamExpr r) where - checkE (LamExpr bs body) = checkB bs \_ -> void $ checkBlock body + checkE (Abs b e) = checkB b \b' -> Abs b' <$> checkE e -- === type checking core === instance IRRep r => CheckableE r (TopLam r) where - checkE (TopLam _ piTy lam) = do + checkE (TopLam destFlag piTy lam) = do -- TODO: check destination-passing flag - checkE piTy - piTy' <- renameM piTy - piTy'' <- checkLamExpr lam - alphaEq piTy' piTy'' >>= \case - True -> return () - False -> throw TypeErr $ pprint piTy' ++ " != " ++ pprint piTy'' + piTy' <- checkE piTy + lam' <- checkLamExpr piTy' lam + return $ TopLam destFlag piTy' lam' -instance IRRep r => CheckableE r (PiType r) where - checkE piTy = void $ getTypeE piTy +instance IRRep r => CheckableE r (AtomName r) where + checkE = renameM instance IRRep r => CheckableE r (Atom r) where - checkE atom = void $ getTypeE atom - -instance IRRep r => CheckableE r (Type r) where - checkE atom = void $ getTypeE atom - -instance IRRep r => HasType r (AtomName r) where - getTypeE name = do - name' <- renameM name - getType <$> lookupAtomName name' - {-# INLINE getTypeE #-} - -instance IRRep r => HasType r (Atom r) where - getTypeE atom = case atom of + checkE = \case Var name -> do - ty <- getTypeE name - case ty of - RawRefTy _ -> renameM name >>= affineUsed . atomVarName + name' <- checkE name + case getType name' of + RawRefTy _ -> affineUsed $ atomVarName name' _ -> return () - return ty - Lam (CoreLamExpr piTy lam) -> do - Pi piTy' <- checkTypeE TyKind $ Pi piTy - checkCoreLam piTy' lam - return $ Pi piTy' + return $ Var name' + Lam lam -> Lam <$> checkE lam DepPair l r ty -> do - ty' <- checkTypeE TyKind ty - l' <- checkTypeE (depPairLeftTy ty') l - rTy <- instantiate ty' [l'] - r |: rTy - return $ DepPairTy ty' - Con con -> typeCheckPrimCon con - Eff eff -> checkE eff $> EffKind - PtrVar t _ -> return $ PtrTy t -- TODO: check against env - DictCon _ dictExpr -> getTypeE dictExpr -- TODO: check against cached type - RepValAtom (RepVal ty _) -> renameM ty - NewtypeCon con x -> NewtypeTyCon <$> typeCheckNewtypeCon con x - SimpInCore x -> getTypeE x - DictHole _ ty _ -> checkTypeE TyKind ty - ProjectElt _ UnwrapNewtype x -> do - NewtypeTyCon con <- getTypeE x - snd <$> unwrapNewtypeType con - ProjectElt _ (ProjectProduct i) x -> do - ty <- getTypeE x - case ty of + l' <- checkE l + ty' <- checkE ty + rTy <- checkInstantiation ty' [l'] + r' <- r |: rTy + return $ DepPair l' r' ty' + Con con -> Con <$> checkE con + Eff eff -> Eff <$> checkE eff + PtrVar t v -> PtrVar t <$> renameM v + -- TODO: check against cached type + DictCon ty dictExpr -> DictCon <$> checkE ty <*> checkE dictExpr + RepValAtom repVal -> RepValAtom <$> renameM repVal -- TODO: check + NewtypeCon con x -> do + (x', xTy) <- checkAndGetType x + con' <- typeCheckNewtypeCon con xTy + return $ NewtypeCon con' x' + SimpInCore x -> SimpInCore <$> checkE x + DictHole ctx ty access -> do + ty' <- ty |: TyKind + return $ DictHole ctx ty' access + ProjectElt resultTy UnwrapNewtype x -> do + resultTy' <- resultTy |: TyKind + (x', NewtypeTyCon con) <- checkAndGetType x + resultTy'' <- snd <$> unwrapNewtypeType con + checkTypesEq resultTy' resultTy'' + return $ ProjectElt resultTy' UnwrapNewtype x' + ProjectElt resultTy (ProjectProduct i) x -> do + resultTy' <- resultTy |: TyKind + (x', xTy) <- checkAndGetType x + resultTy'' <- case xTy of ProdTy tys -> return $ tys !! i DepPairTy t | i == 0 -> return $ depPairLeftTy t DepPairTy t | i == 1 -> do - x' <- renameM x xFst <- normalizeProj (ProjectProduct 0) x' - instantiate t [xFst] - _ -> throw TypeErr $ "Not a product type:" ++ pprint ty - TypeAsAtom ty -> getTypeE ty - -instance IRRep r => HasType r (AtomVar r) where - getTypeE (AtomVar v t1) = do + checkInstantiation t [xFst] + _ -> throw TypeErr $ "Not a product type:" ++ pprint xTy + checkTypesEq resultTy' resultTy'' + return $ ProjectElt resultTy' (ProjectProduct i) x' + TypeAsAtom ty -> TypeAsAtom <$> checkE ty + +instance IRRep r => CheckableE r (AtomVar r) where + checkE (AtomVar v t1) = do t1' <- renameM t1 v' <- renameM v t2 <- getType <$> lookupAtomName v' checkTypesEq t1' t2 - return t1' - -instance IRRep r => HasType r (Type r) where - getTypeE atom = case atom of - Pi piType -> getTypeE piType - TabPi piType -> getTypeE piType - NewtypeTyCon t -> typeCheckNewtypeTyCon t - TC tyCon -> typeCheckPrimTC tyCon - DepPairTy ty -> getTypeE ty - DictTy (DictType _ className params) -> do - ClassDef _ _ _ _ paramBs _ _ <- renameM className >>= lookupClassDef - params' <- mapM renameM params - checkArgTys paramBs params' - return TyKind - TyVar v -> getTypeE v + return $ AtomVar v' t1' + +instance IRRep r => CheckableE r (Type r) where + checkE = \case + Pi t -> Pi <$> checkE t + TabPi t -> TabPi <$> checkE t + NewtypeTyCon t -> NewtypeTyCon <$> checkE t + TC t -> TC <$> checkE t + DepPairTy t -> DepPairTy <$> checkE t + DictTy (DictType sn className params) -> do + className' <- renameM className + ClassDef _ _ _ _ paramBs _ _ <- lookupClassDef className' + params' <- mapM checkE params + void $ checkInstantiation (Abs paramBs UnitE) params' + return $ DictTy (DictType sn className' params') + TyVar v -> TyVar <$> checkE v ProjectEltTy resultTy UnwrapNewtype x -> do - resultTy' <- renameM resultTy - NewtypeTyCon con <- getTypeE x + resultTy' <- resultTy |: TyKind + x' <- checkE x + NewtypeTyCon con <- return $ getType x' ty <- snd <$> unwrapNewtypeType con checkTypesEq resultTy' ty - return ty - ProjectEltTy _ (ProjectProduct i) x -> do - ty <- getTypeE x - case ty of + return $ ProjectEltTy resultTy' UnwrapNewtype x' + ProjectEltTy resultTy (ProjectProduct i) x -> do + resultTy' <- resultTy |: TyKind + (x', ty) <- checkAndGetType x + resultTy'' <- case ty of ProdTy tys -> return $ tys !! i DepPairTy t | i == 0 -> return $ depPairLeftTy t DepPairTy t | i == 1 -> do - x' <- renameM x xFst <- normalizeProj (ProjectProduct 0) x' instantiate t [xFst] _ -> throw TypeErr $ "Not a product type:" ++ pprint ty + checkTypesEq resultTy' resultTy'' + return $ ProjectEltTy resultTy' (ProjectProduct i) x' -instance HasType CoreIR SimpInCore where - getTypeE = \case - LiftSimp ty _ -> renameM ty -- TODO: check - LiftSimpFun piTy _ -> Pi <$> renameM piTy -- TODO: check - ACase _ _ ty -> renameM ty -- TODO: check - TabLam t _ -> TabPi <$> renameM t -- TODO: check +instance CheckableE CoreIR SimpInCore where + checkE x = renameM x -- TODO: check instance (ToBinding ann c, Color c, CheckableE r ann) => CheckableB r (BinderP c ann) where checkB (b:>ann) cont = do - checkE ann - ann' <- renameM ann + ann' <- checkE ann withFreshBinder (getNameHint b) ann' \b' -> extendRenamer (b@>binderName b') $ cont b' -typeCheckExpr :: (Typer m r, IRRep r) => EffectRow r o -> Expr r i -> m i o (Type r o) -typeCheckExpr effs expr = addContext ("Checking expr:\n" ++ pprint expr) case expr of - App (EffTy _ reqTy) f xs -> do - fTy <- getTypeE f - checkApp effs fTy xs >>= checkAgainstGiven reqTy - TabApp reqTy f xs -> do - fTy <- getTypeE f - checkTabApp fTy xs >>= checkAgainstGiven reqTy - TopApp (EffTy _ reqTy) f xs -> do - PiType bs (EffTy _ resultTy) <- getTypeTopFun =<< renameM f - xs' <- mapM renameM xs - checkedApplyNaryAbs (Abs bs resultTy) xs' >>= checkAgainstGiven reqTy - Atom x -> getTypeE x - PrimOp op -> typeCheckPrimOp effs op - Case e alts (EffTy caseEffs resultTy) -> do - caseEffs' <- renameM caseEffs - resultTy' <- renameM resultTy - checkCase e alts resultTy' caseEffs' - checkExtends effs caseEffs' - return resultTy' - ApplyMethod (EffTy _ reqTy) dict i args -> do - DictTy (DictType _ className params) <- getTypeE dict - ClassDef _ _ _ _ paramBs classBs methodTys <- lookupClassDef className - let methodTy = methodTys !! i - superclassDicts <- getSuperclassDicts =<< renameM dict - let subst = ( paramBs @@> map SubstVal params - <.> classBs @@> map SubstVal superclassDicts) - methodTy' <- applySubst subst methodTy - checkApp effs (Pi methodTy') args >>= checkAgainstGiven reqTy - TabCon _ ty xs -> do - ty'@(TabPi (TabPiType _ b restTy)) <- checkTypeE TyKind ty - case fromConstAbs (Abs b restTy) of - HoistSuccess elTy -> forM_ xs (|: elTy) - -- XXX: in the dependent case we don't check that the element types - -- match the annotation because that would require concretely evaluating - -- each index from the ix dict. - HoistFailure _ -> forM_ xs checkE - return ty' +checkBinderType + :: (IRRep r) => Type r o -> Binder r i i' + -> (forall o'. DExt o o' => Binder r o o' -> TyperM r i' o' a) + -> TyperM r i o a +checkBinderType ty b cont = do + checkB b \b' -> do + checkTypesEq (sink $ binderType b') (sink ty) + cont b' + +instance IRRep r => CheckableWithEffects r (Expr r) where + checkWithEffects allowedEffs expr = addContext ("Checking expr:\n" ++ pprint expr) case expr of + App effTy f xs -> do + effTy' <- checkEffTy allowedEffs effTy + f' <- checkE f + Pi piTy <- return $ getType f' + xs' <- mapM checkE xs + effTy'' <- checkInstantiation piTy xs' + checkAlphaEq effTy' effTy'' + return $ App effTy' f' xs' + TabApp reqTy f xs -> do + reqTy' <- reqTy |: TyKind + (f', tabTy) <- checkAndGetType f + xs' <- mapM checkE xs + ty' <- checkTabApp tabTy xs' + checkTypesEq reqTy' ty' + return $ TabApp reqTy' f' xs' + TopApp effTy f xs -> do + f' <- renameM f + effTy' <- checkEffTy allowedEffs effTy + piTy <- getTypeTopFun f' + xs' <- mapM checkE xs + effTy'' <- checkInstantiation piTy xs' + checkAlphaEq effTy' effTy'' + return $ TopApp effTy' f' xs' + Atom x -> Atom <$> checkE x + PrimOp op -> PrimOp <$> checkWithEffects allowedEffs op + Case scrut alts effTy -> do + effTy' <- checkEffTy allowedEffs effTy + scrut' <- checkE scrut + altsBinderTys <- checkCaseAltsBinderTys $ getType scrut' + assertEq (length altsBinderTys) (length alts) "" + alts' <- parallelAffines $ (zip alts altsBinderTys) <&> \(Abs b body, reqBinderTy) -> do + checkB b \b' -> do + checkTypesEq (sink reqBinderTy) (sink $ binderType b') + Abs b' <$> checkBlock (sink effTy') body + return $ Case scrut' alts' effTy' + ApplyMethod effTy dict i args -> do + effTy' <- checkEffTy allowedEffs effTy + dict' <- checkE dict + args' <- mapM checkE args + methodTy <- getMethodType dict' i + effTy'' <- checkInstantiation methodTy args' + checkAlphaEq effTy' effTy'' + return $ ApplyMethod effTy' dict' i args' + TabCon maybeD ty xs -> do + ty'@(TabPi (TabPiType _ b restTy)) <- ty |: TyKind + maybeD' <- mapM renameM maybeD -- TODO: check + xs' <- case fromConstAbs (Abs b restTy) of + HoistSuccess elTy -> forM xs (|: elTy) + -- XXX: in the dependent case we don't check that the element types + -- match the annotation because that would require concretely evaluating + -- each index from the ix dict. + HoistFailure _ -> forM xs checkE + return $ TabCon maybeD' ty' xs' instance CheckableE CoreIR TyConParams where - checkE (TyConParams _ params) = mapM_ checkE params - -dictExprType :: Typer m CoreIR => DictExpr i -> m i o (CType o) -dictExprType e = case e of - InstanceDict instanceName args -> do - instanceName' <- renameM instanceName - InstanceDef className _ bs params _ <- lookupInstanceDef instanceName' - ClassDef sourceName _ _ _ _ _ _ <- lookupClassDef className - args' <- mapM renameM args - checkArgTys bs args' - ListE params' <- applySubst (bs@@>(SubstVal<$>args')) (ListE params) - return $ DictTy $ DictType sourceName className params' - InstantiatedGiven given args -> do - givenTy <- getTypeE given - checkApp Pure givenTy (toList args) - SuperclassProj d i -> do - DictTy (DictType _ className params) <- getTypeE d - ClassDef _ _ _ _ bs superclasses _ <- lookupClassDef className - let scType = getSuperclassType REmpty superclasses i - checkedApplyNaryAbs (Abs bs scType) params - IxFin n -> do - n' <- checkTypeE NatTy n - liftM DictTy $ ixDictType $ NewtypeTyCon $ Fin n' - DataData ty -> DictTy <$> (dataDictType =<< checkTypeE TyKind ty) - -instance HasType CoreIR DictExpr where - getTypeE e = dictExprType e - -instance IRRep r => HasType r (DepPairType r) where - getTypeE (DepPairType _ b ty) = do - checkB b \_ -> ty |: TyKind - return TyKind - -instance HasType CoreIR CorePiType where - getTypeE (CorePiType _ _ bs (EffTy eff resultTy)) = do - checkB bs \_ -> do - void $ checkE eff - resultTy|:TyKind - return TyKind - -instance IRRep r => HasType r (PiType r) where - getTypeE (PiType bs (EffTy eff resultTy)) = do - checkB bs \_ -> do - void $ checkE eff - resultTy|:TyKind - return TyKind + checkE (TyConParams expls params) = TyConParams expls <$> mapM checkE params + +instance CheckableE CoreIR DictExpr where + checkE = \case + InstanceDict instanceName args -> do + instanceName' <- renameM instanceName + args' <- mapM checkE args + instanceDef <- lookupInstanceDef instanceName' + void $ checkInstantiation instanceDef args' + return $ InstanceDict instanceName' args' + InstantiatedGiven given args -> do + (given', Pi piTy) <- checkAndGetType given + args' <- mapM checkE args + EffTy Pure _ <- checkInstantiation piTy args' + return $ InstantiatedGiven given' args' + SuperclassProj d i -> SuperclassProj <$> checkE d <*> pure i -- TODO: check index in range + IxFin n -> IxFin <$> n |: NatTy + DataData ty -> DataData <$> ty |: TyKind + +instance IRRep r => CheckableE r (DepPairType r) where + checkE (DepPairType expl b ty) = do + checkB b \b' -> do + ty' <- ty |: TyKind + return $ DepPairType expl b' ty' + +instance CheckableE CoreIR CorePiType where + checkE (CorePiType expl expls bs effTy) = do + checkB bs \bs' -> do + effTy' <- checkE effTy + return $ CorePiType expl expls bs' effTy' + +instance IRRep r => CheckableE r (PiType r) where + checkE (PiType bs effTy) = do + checkB bs \bs' -> do + effTy' <- checkE effTy + return $ PiType bs' effTy' + +instance IRRep r => CheckableE r (IxDict r) where + checkE = renameM -- TODO: check instance IRRep r => CheckableE r (IxType r) where - checkE (IxType t _) = checkE t + checkE (IxType t d) = IxType <$> checkE t <*> checkE d -instance IRRep r => HasType r (TabPiType r) where - getTypeE (TabPiType _ b resultTy) = do - checkB b \_ -> resultTy|:TyKind - return TyKind +instance IRRep r => CheckableE r (TabPiType r) where + checkE (TabPiType d b resultTy) = do + d' <- checkE d + checkB b \b' -> do + resultTy' <- resultTy|:TyKind + return $ TabPiType d' b' resultTy' instance (BindsNames b, CheckableB r b) => CheckableB r (Nest b) where checkB nest cont = case nest of - Empty -> cont Empty + Empty -> getDistinct >>= \Distinct -> cont Empty Nest b rest -> checkBEvidenced b \ext1 b' -> checkBEvidenced rest \ext2 rest' -> withExtEvidence (ext1 >>> ext2) $ cont $ Nest b' rest' -checkAgainstGiven :: (Typer m r, IRRep r) => Type r i -> Type r o -> m i o (Type r o) -checkAgainstGiven givenTy computedTy = do - givenTy' <- renameM givenTy - checkTypesEq givenTy' computedTy - return givenTy' - -checkCoreLam :: Typer m CoreIR => CorePiType o -> LamExpr CoreIR i -> m i o () -checkCoreLam (CorePiType _ _ Empty (EffTy effs resultTy)) (LamExpr Empty body) = do - resultTy' <- checkBlockWithEffs effs body - checkTypesEq resultTy resultTy' -checkCoreLam (CorePiType expl (_:expls) (Nest piB piBs) effTy) (LamExpr (Nest lamB lamBs) body) = do - argTy <- renameM $ binderType lamB - checkTypesEq (binderType piB) argTy - withFreshBinder (getNameHint lamB) argTy \b -> do - piTy <- applyRename (piB@>binderName b) (CorePiType expl expls piBs effTy) - extendRenamer (lamB@>binderName b) do - checkCoreLam piTy (LamExpr lamBs body) -checkCoreLam _ _ = throw TypeErr "zip error" - -typeCheckPrimTC :: (Typer m r, IRRep r) => TC r i -> m i o (Type r o) -typeCheckPrimTC tc = case tc of - BaseType _ -> return TyKind - ProdType tys -> mapM_ (|:TyKind) tys >> return TyKind - SumType cs -> mapM_ (|:TyKind) cs >> return TyKind - RefType r a -> r|:TC HeapType >> a|:TyKind >> return TyKind - TypeKind -> return TyKind - HeapType -> return TyKind - -typeCheckPrimCon :: (Typer m r, IRRep r) => Con r i -> m i o (Type r o) -typeCheckPrimCon con = case con of - Lit l -> return $ BaseTy $ litType l - ProdCon xs -> ProdTy <$> mapM getTypeE xs - SumCon tys tag payload -> do - caseTys <- traverse renameM tys - unless (0 <= tag && tag < length caseTys) $ throw TypeErr "Invalid SumType tag" - payload |: (caseTys !! tag) - return $ SumTy caseTys - HeapVal -> return $ TC HeapType +instance CheckableE CoreIR CoreLamExpr where + checkE (CoreLamExpr piTy lamExpr) = do + CorePiType expl expls bs effTy <- checkE piTy + lamExpr' <- checkLamExpr (PiType bs effTy) lamExpr + return $ CoreLamExpr (CorePiType expl expls bs effTy) lamExpr' + +instance IRRep r => CheckableE r (TC r) where + checkE = \case + BaseType b -> return $ BaseType b + ProdType tys -> ProdType <$> mapM (|:TyKind) tys + SumType cs -> SumType <$> mapM (|:TyKind) cs + RefType r a -> RefType <$> r|:TC HeapType <*> a|:TyKind + TypeKind -> return TypeKind + HeapType -> return HeapType + +instance IRRep r => CheckableE r (Con r) where + checkE = \case + Lit l -> return $ Lit l + ProdCon xs -> ProdCon <$> mapM checkE xs + SumCon tys tag payload -> do + tys' <- mapM (|:TyKind) tys + unless (0 <= tag && tag < length tys') $ throw TypeErr "Invalid SumType tag" + payload' <- payload |: (tys' !! tag) + return $ SumCon tys' tag payload' + HeapVal -> return HeapVal typeCheckNewtypeCon - :: Typer m CoreIR => NewtypeCon i -> CAtom i -> m i o (NewtypeTyCon o) -typeCheckNewtypeCon con x = case con of - NatCon -> x|:IdxRepTy >> return Nat - FinCon n -> n|:NatTy >> x|:NatTy >> renameM (Fin n) - UserADTData _ d params -> do + :: NewtypeCon i -> CType o -> TyperM CoreIR i o (NewtypeCon o) +typeCheckNewtypeCon con xTy = case con of + NatCon -> checkAlphaEq IdxRepTy xTy >> return NatCon + FinCon n -> do + n' <- n|:NatTy + checkAlphaEq xTy NatTy + return $ FinCon n' + UserADTData sn d params -> do d' <- renameM d - def@(TyConDef sn _ _ _) <- lookupTyCon d' - params' <- renameM params - void $ checkedInstantiateTyConDef def params' - return $ UserADTType sn d' params' - -typeCheckNewtypeTyCon :: Typer m CoreIR => NewtypeTyCon i -> m i o (CType o) -typeCheckNewtypeTyCon = \case - Nat -> return TyKind - Fin n -> checkTypeE NatTy n >> return TyKind - EffectRowKind -> return TyKind - UserADTType _ d params -> do - def <- lookupTyCon =<< renameM d - params' <- renameM params - void $ checkedInstantiateTyConDef def params' - return TyKind - -typeCheckPrimOp :: (Typer m r, IRRep r) => EffectRow r o -> PrimOp r i -> m i o (Type r o) -typeCheckPrimOp effs op = case op of - Hof (TypedHof effTy hof) -> do - EffTy effs' resultTy <- renameM effTy - checkExtends effs effs' - resultTy' <- typeCheckPrimHof effs hof - checkTypesEq resultTy resultTy' - return resultTy - VectorOp vOp -> typeCheckVectorOp vOp - BinOp binop x y -> do - xTy <- typeCheckBaseType x - yTy <- typeCheckBaseType y - TC <$> BaseType <$> checkBinOp binop xTy yTy - UnOp unop x -> do - xTy <- typeCheckBaseType x - TC <$> BaseType <$> checkUnOp unop xTy - MiscOp x -> typeCheckMiscOp effs x - MemOp x -> typeCheckMemOp effs x - DAMOp op' -> typeCheckDAMOp effs op' - RefOp ref m -> do - TC (RefType h s) <- getTypeE ref - case m of - MGet -> declareEff effs (RWSEffect State h) $> s - MPut x -> x|:s >> declareEff effs (RWSEffect State h) $> UnitTy - MAsk -> declareEff effs (RWSEffect Reader h) $> s - MExtend _ x -> x|:s >> declareEff effs (RWSEffect Writer h) $> UnitTy - IndexRef givenTy i -> do - TabTy _ (b:>iTy) eltTy <- return s - i' <- checkTypeE iTy i - eltTy' <- applyAbs (Abs b eltTy) (SubstVal i') - checkAgainstGiven givenTy (TC $ RefType h eltTy') - ProjRef givenTy p -> do - resultEltTy <- case p of - ProjectProduct i -> do - ProdTy tys <- return s - return $ tys !! i - UnwrapNewtype -> do - NewtypeTyCon tc <- return s - snd <$> unwrapNewtypeType tc - checkAgainstGiven givenTy (TC $ RefType h resultEltTy) - -typeCheckMemOp :: forall r m i o. (Typer m r, IRRep r) => EffectRow r o -> MemOp r i -> m i o (Type r o) -typeCheckMemOp effs = \case - IOAlloc n -> do - n |: IdxRepTy - declareEff effs IOEffect - return $ PtrTy (CPU, Scalar Word8Type) - IOFree ptr -> do - PtrTy _ <- getTypeE ptr - declareEff effs IOEffect - return UnitTy - PtrOffset arr off -> do - PtrTy (a, b) <- getTypeE arr - off |: IdxRepTy - return $ PtrTy (a, b) - PtrLoad ptr -> do - PtrTy (_, t) <- getTypeE ptr - declareEff effs IOEffect - return $ BaseTy t - PtrStore ptr val -> do - PtrTy (_, t) <- getTypeE ptr - val |: BaseTy t - declareEff effs IOEffect - return $ UnitTy - -typeCheckMiscOp :: forall r m i o. (Typer m r, IRRep r) => EffectRow r o -> MiscOp r i -> m i o (Type r o) -typeCheckMiscOp effs = \case - Select p x y -> do - p |: (BaseTy $ Scalar Word8Type) - ty <- getTypeE x - y |: ty - return ty - CastOp t@(TyVar _) _ -> t |: TyKind >> renameM t - CastOp destTy e -> do - sourceTy' <- getTypeE e - destTy |: TyKind - destTy' <- renameM destTy - checkValidCast sourceTy' destTy' - return $ destTy' - BitcastOp t@(TyVar _) _ -> t |: TyKind >> renameM t - BitcastOp destTy e -> do - sourceTy <- getTypeE e - case (destTy, sourceTy) of - (BaseTy dbt@(Scalar _), BaseTy sbt@(Scalar _)) | sizeOf sbt == sizeOf dbt -> - return $ BaseTy dbt - _ -> throw TypeErr $ "Invalid bitcast: " ++ pprint sourceTy ++ " -> " ++ pprint destTy - UnsafeCoerce t _ -> renameM t - GarbageVal t -> renameM t - SumTag x -> do - void $ getTypeE x >>= checkSomeSumType - return TagRepTy - ToEnum t x -> do - x |: Word8Ty - t' <- checkTypeE TyKind t - cases <- checkSomeSumType t' - forM_ cases \cty -> checkTypesEq cty UnitTy - return t' - OutputStream -> - return $ BaseTy $ hostPtrTy $ Scalar Word8Type - where hostPtrTy ty = PtrType (CPU, ty) - ShowAny x -> - -- TODO: constrain `ShowAny` to have `HasCore r` - checkE x >> return rawStrType - ShowScalar x -> do - BaseTy (Scalar _) <- getTypeE x - return $ PairTy IdxRepTy $ rawFinTabType (IdxRepVal showStringBufferSize) CharRepTy - ThrowError ty -> ty|:TyKind >> renameM ty - ThrowException ty -> do - declareEff effs ExceptionEffect - ty|:TyKind >> renameM ty - -checkSomeSumType :: (Typer m r, IRRep r) => Type r o -> m i o [Type r o] + TyConParams expls params' <- checkE params + def <- lookupTyCon d' + void $ checkInstantiation def params' + return $ UserADTData sn d' (TyConParams expls params') + +instance CheckableE CoreIR NewtypeTyCon where + checkE = \case + Nat -> return Nat + Fin n -> Fin <$> n|:NatTy + EffectRowKind -> return EffectRowKind + UserADTType sn d params -> do + d' <- renameM d + TyConParams expls params' <- checkE params + def <- lookupTyCon d' + void $ checkInstantiation def params' + return $ UserADTType sn d' (TyConParams expls params') + +instance IRRep r => CheckableWithEffects r (PrimOp r) where + checkWithEffects effs = \case + Hof (TypedHof effTy hof) -> do + effTy'@(EffTy effs' resultTy) <- checkE effTy + checkExtends effs effs' + -- TODO: we should be able to use the `effTy` from the `TypedHof`, which + -- might have fewer effects than `effs`. But that exposes an error in + -- which we under-report the `Init` effect in the `TypedHof` effect + -- annotation. We should fix that. + hof' <- checkHof (EffTy effs resultTy) hof + return $ Hof (TypedHof effTy' hof') + VectorOp vOp -> VectorOp <$> checkE vOp + BinOp binop x y -> do + x' <- checkE x + y' <- checkE y + TC (BaseType xTy) <- return $ getType x' + TC (BaseType yTy) <- return $ getType y' + checkBinOp binop xTy yTy + return $ BinOp binop x' y' + UnOp unop x -> do + x' <- checkE x + TC (BaseType xTy) <- return $ getType x' + checkUnOp unop xTy + return $ UnOp unop x' + MiscOp op -> MiscOp <$> checkWithEffects effs op + MemOp op -> MemOp <$> checkWithEffects effs op + DAMOp op -> DAMOp <$> checkWithEffects effs op + RefOp ref m -> do + (ref', TC (RefType h s)) <- checkAndGetType ref + m' <- case m of + MGet -> declareEff effs (RWSEffect State h) $> MGet + MPut x -> do + x' <- x|:s + declareEff effs (RWSEffect State h) + return $ MPut x' + MAsk -> declareEff effs (RWSEffect Reader h) $> MAsk + MExtend b x -> do + b' <- checkE b + x' <- x|:s + declareEff effs (RWSEffect Writer h) + return $ MExtend b' x' + IndexRef givenTy i -> do + givenTy' <- givenTy |: TyKind + TabPi tabTy <- return s + i' <- checkE i + eltTy' <- checkInstantiation tabTy [i'] + checkTypesEq givenTy' (TC $ RefType h eltTy') + return $ IndexRef givenTy' i' + ProjRef givenTy p -> do + givenTy' <- givenTy |: TyKind + resultEltTy <- case p of + ProjectProduct i -> do + ProdTy tys <- return s + return $ tys !! i + UnwrapNewtype -> do + NewtypeTyCon tc <- return s + snd <$> unwrapNewtypeType tc + checkTypesEq givenTy' (TC $ RefType h resultEltTy) + return $ ProjRef givenTy' p + return $ RefOp ref' m' + +instance IRRep r => CheckableE r (EffTy r) where + checkE (EffTy effs ty) = EffTy <$> checkE effs <*> checkE ty + +instance IRRep r => CheckableE r (BaseMonoid r) where + checkE = renameM -- TODO: check + +instance IRRep r => CheckableWithEffects r (MemOp r) where + checkWithEffects effs = \case + IOAlloc n -> do + declareEff effs IOEffect + IOAlloc <$> (n |: IdxRepTy) + IOFree ptr -> do + declareEff effs IOEffect + IOFree <$> checkIsPtr ptr + PtrOffset ptr off -> do + ptr' <- checkIsPtr ptr + off' <- off |: IdxRepTy + return $ PtrOffset ptr' off' + PtrLoad ptr -> do + declareEff effs IOEffect + PtrLoad <$> checkIsPtr ptr + PtrStore ptr val -> do + declareEff effs IOEffect + ptr' <- checkE ptr + PtrTy (_, t) <- return $ getType ptr' + val' <- val |: BaseTy t + return $ PtrStore ptr' val' + +checkIsPtr :: IRRep r => Atom r i -> TyperM r i o (Atom r o) +checkIsPtr ptr = do + ptr' <- checkE ptr + PtrTy _ <- return $ getType ptr' + return ptr' + +instance IRRep r => CheckableWithEffects r (MiscOp r) where + checkWithEffects effs = \case + Select p x y -> do + p' <- p |: (BaseTy $ Scalar Word8Type) + x' <- checkE x + y' <- y |: getType x' + return $ Select p' x' y' + CastOp t@(TyVar _) e -> CastOp <$> (t|:TyKind) <*> renameM e + CastOp destTy e -> do + e' <- checkE e + destTy' <- destTy |: TyKind + checkValidCast (getType e') destTy' + return $ CastOp destTy' e' + BitcastOp t@(TyVar _) e -> BitcastOp <$> (t|:TyKind) <*> renameM e + BitcastOp destTy e -> do + destTy' <- destTy |: TyKind + e' <- checkE e + let sourceTy = getType e' + case (destTy', sourceTy) of + (BaseTy dbt@(Scalar _), BaseTy sbt@(Scalar _)) | sizeOf sbt == sizeOf dbt -> + return $ BitcastOp destTy' e' + _ -> throw TypeErr $ "Invalid bitcast: " ++ pprint sourceTy ++ " -> " ++ pprint destTy + UnsafeCoerce t e -> UnsafeCoerce <$> t|:TyKind <*> renameM e + GarbageVal t -> GarbageVal <$> (t|:TyKind) + SumTag x -> do + x' <- checkE x + void $ checkSomeSumType $ getType x' + return $ SumTag x' + ToEnum t x -> do + t' <- t |: TyKind + x' <- x |: Word8Ty + cases <- checkSomeSumType t' + forM_ cases \cty -> checkTypesEq cty UnitTy + return $ ToEnum t' x' + OutputStream -> return OutputStream + ShowAny x -> ShowAny <$> checkE x + ShowScalar x -> do + x' <- checkE x + BaseTy (Scalar _) <- return $ getType x' + return $ ShowScalar x' + ThrowError ty -> ThrowError <$> (ty|:TyKind) + ThrowException ty -> ThrowException <$> do + declareEff effs ExceptionEffect + ty|:TyKind + +checkSomeSumType :: IRRep r => Type r o -> TyperM r i o [Type r o] checkSomeSumType = \case SumTy cases -> return cases NewtypeTyCon con -> do @@ -578,224 +580,208 @@ checkSomeSumType = \case return cases t -> error $ "not some sum type: " ++ pprint t -typeCheckVectorOp :: (Typer m r, IRRep r) => VectorOp r i -> m i o (Type r o) -typeCheckVectorOp = \case - VectorBroadcast v ty -> do - ty'@(BaseTy (Vector _ sbt)) <- checkTypeE TyKind ty - v |: BaseTy (Scalar sbt) - return ty' - VectorIota ty -> do - ty'@(BaseTy (Vector _ _)) <- checkTypeE TyKind ty - return ty' - VectorIdx tbl i ty -> do - TabTy _ b (BaseTy (Scalar sbt)) <- getTypeE tbl - i |: binderType b - ty'@(BaseTy (Vector _ sbt')) <- checkTypeE TyKind ty - unless (sbt == sbt') $ throw TypeErr "Scalar type mismatch" - return ty' - VectorSubref ref i ty -> do - RefTy heap (TabTy _ b (BaseTy (Scalar sbt))) <- getTypeE ref - i |: binderType b - ty'@(BaseTy (Vector _ sbt')) <- checkTypeE TyKind ty - unless (sbt == sbt') $ throw TypeErr "Scalar type mismatch" - return $ RefTy heap ty' - -typeCheckPrimHof :: forall r m i o. (Typer m r, IRRep r) => EffectRow r o -> Hof r i -> m i o (Type r o) -typeCheckPrimHof effs hof = addContext ("Checking HOF:\n" ++ pprint hof) case hof of - For _ ixTy f -> do - IxType t d <- renameM ixTy - PiType (UnaryNest (b:>argTy)) (EffTy _ eltTy) <- checkLamExpr f - checkTypesEq t argTy - return $ TabTy d (b:>t) eltTy +instance IRRep r => CheckableE r (VectorOp r) where + checkE = \case + VectorBroadcast v ty -> do + ty'@(BaseTy (Vector _ sbt)) <- ty |: TyKind + v' <- v |: BaseTy (Scalar sbt) + return $ VectorBroadcast v' ty' + VectorIota ty -> do + ty'@(BaseTy (Vector _ _)) <- ty |: TyKind + return $ VectorIota ty' + VectorIdx tbl i ty -> do + tbl' <- checkE tbl + TabTy _ b (BaseTy (Scalar sbt)) <- return $ getType tbl' + i' <- i |: binderType b + ty'@(BaseTy (Vector _ sbt')) <- ty |: TyKind + unless (sbt == sbt') $ throw TypeErr "Scalar type mismatch" + return $ VectorIdx tbl' i' ty' + VectorSubref ref i ty -> do + ref' <- checkE ref + RefTy _ (TabTy _ b (BaseTy (Scalar sbt))) <- return $ getType ref' + i' <- i |: binderType b + ty'@(BaseTy (Vector _ sbt')) <- ty |: TyKind + unless (sbt == sbt') $ throw TypeErr "Scalar type mismatch" + return $ VectorSubref ref' i' ty' + +checkBlock :: IRRep r => EffTy r o -> Block r i -> TyperM r i o (Block r o) +checkBlock (EffTy effs ty) (Abs decls result) = + checkDecls effs decls \decls' -> do + result' <- result |: sink ty + return $ Abs decls' result' + +checkHof :: IRRep r => EffTy r o -> Hof r i -> TyperM r i o (Hof r o) +checkHof (EffTy effs reqTy) = \case + For dir ixTy f -> do + IxType t d <- checkE ixTy + LamExpr (UnaryNest b) body <- return f + TabPi tabTy <- return reqTy + checkBinderType t b \b' -> do + resultTy <- checkInstantiation (sink tabTy) [Var $ binderVar b'] + body' <- checkBlock (EffTy (sink effs) resultTy) body + return $ For dir (IxType t d) (LamExpr (UnaryNest b') body') While body -> do - condTy <- checkBlockWithEffs effs body - checkTypesEq (BaseTy $ Scalar Word8Type) condTy - return UnitTy + let effTy = EffTy effs (BaseTy $ Scalar Word8Type) + checkTypesEq reqTy UnitTy + While <$> checkBlock effTy body Linearize f x -> do - PiType (UnaryNest (binder:>a)) (EffTy Pure b) <- checkLamExpr f - b' <- liftHoistExcept $ hoist binder b - fLinTy <- return $ Pi $ nonDepPiType [a] Pure b' - x |: a - return $ PairTy b' fLinTy + (x', xTy) <- checkAndGetType x + LamExpr (UnaryNest b) body <- return f + checkBinderType xTy b \b' -> do + PairTy resultTy fLinTy <- sinkM reqTy + body' <- checkBlock (EffTy Pure resultTy) body + checkTypesEq fLinTy (Pi $ nonDepPiType [sink xTy] Pure resultTy) + return $ Linearize (LamExpr (UnaryNest b') body') x' Transpose f x -> do - PiType (UnaryNest (binder:>a)) (EffTy Pure b) <- checkLamExpr f - b' <- liftHoistExcept $ hoist binder b - x |: b' - return a + (x', xTy) <- checkAndGetType x + LamExpr (UnaryNest b) body <- return f + checkB b \b' -> do + body' <- checkBlock (EffTy Pure (sink xTy)) body + checkTypesEq (sink $ binderType b') (sink reqTy) + return $ Transpose (LamExpr (UnaryNest b') body') x' RunReader r f -> do - (resultTy, readTy) <- checkRWSAction effs Reader f - r |: readTy - return resultTy - RunWriter d _ f -> do + (r', rTy) <- checkAndGetType r + f' <- checkRWSAction reqTy rTy effs Reader f + return $ RunReader r' f' + RunWriter d bm f -> do -- XXX: We can't verify compatibility between the base monoid and f, because -- the only way in which they are related in the runAccum definition is via -- the AccumMonoid typeclass. The frontend constraints should be sufficient -- to ensure that only well typed programs are accepted, but it is a bit -- disappointing that we cannot verify that internally. We might want to consider -- e.g. only disabling this check for prelude. - (resultTy, accTy) <- checkRWSAction effs Writer f - case d of - Nothing -> return () + bm' <- checkE bm + PairTy resultTy accTy <- return reqTy + f' <- checkRWSAction resultTy accTy effs Writer f + d' <- case d of + Nothing -> return Nothing Just dest -> do - dest |: RawRefTy accTy + dest' <- dest |: RawRefTy accTy declareEff effs InitEffect - return $ PairTy resultTy accTy + return $ Just dest' + return $ RunWriter d' bm' f' RunState d s f -> do - (resultTy, stateTy) <- checkRWSAction effs State f - s |: stateTy - case d of - Nothing -> return () + (s', sTy) <- checkAndGetType s + PairTy resultTy sTy' <- return reqTy + checkTypesEq sTy sTy' + f' <- checkRWSAction resultTy sTy effs State f + d' <- case d of + Nothing -> return Nothing Just dest -> do - dest |: RawRefTy stateTy declareEff effs InitEffect - return $ PairTy resultTy stateTy - RunIO body -> checkBlockWithEffs (extendEffect IOEffect effs) body - RunInit body -> checkBlockWithEffs (extendEffect InitEffect effs) body - CatchException reqTy body -> do - ty <- checkBlockWithEffs (extendEffect ExceptionEffect effs) body - makePreludeMaybeTy ty >>= checkAgainstGiven reqTy - -typeCheckDAMOp :: forall r m i o . (Typer m r, IRRep r) => EffectRow r o -> DAMOp r i -> m i o (Type r o) -typeCheckDAMOp effs op = addContext ("Checking DAMOp:\n" ++ pprint op) case op of - Seq effAnn _ ixTy' carry f -> do - effAnn' <- renameM effAnn - checkExtends effs effAnn' - ixTy <- renameM ixTy' - carryTy' <- getTypeE carry - let badCarry = throw TypeErr $ "Seq carry should be a product of raw references, got: " ++ pprint carryTy' - case carryTy' of - ProdTy refTys -> forM_ refTys \case RawRefTy _ -> return (); _ -> badCarry - _ -> badCarry - PiType (UnaryNest b) _ <- checkLamExprWithEffs effs f - checkTypesEq (PairTy (ixTypeType ixTy) carryTy') (binderType b) - return carryTy' - RememberDest effAnn d body -> do - effAnn' <- renameM effAnn - checkExtends effs effAnn' - dTy@(RawRefTy _) <- getTypeE d - PiType (UnaryNest b) (EffTy _ UnitTy) <- checkLamExpr body - checkTypesEq (binderType b) dTy - return dTy - AllocDest ty -> RawRefTy <$> checkTypeE TyKind ty - Place ref val -> do - ty <- getTypeE val - ref |: RawRefTy ty - declareEff effs InitEffect - return UnitTy - Freeze ref -> do - RawRefTy ty <- getTypeE ref - return ty - -checkLamExpr :: (Typer m r, IRRep r) => LamExpr r i -> m i o (PiType r o) -checkLamExpr (LamExpr bsTop body) = case bsTop of - Empty -> do - EffTy effs resultTy <- checkBlock body - return $ PiType Empty $ EffTy effs resultTy - Nest (b:>ty) bs -> do - ty' <- checkTypeE TyKind ty - withFreshBinder (getNameHint b) ty' \b' -> - extendRenamer (b@>binderName b') do - PiType bs' effTy <- checkLamExpr (LamExpr bs body) - return $ PiType (Nest b' bs') effTy - -checkLamExprWithEffs :: (Typer m r, IRRep r) => EffectRow r o -> LamExpr r i -> m i o (PiType r o) -checkLamExprWithEffs allowedEffs lam = do - piTy@(PiType bs (EffTy effs _)) <- checkLamExpr lam - effs' <- liftHoistExcept $ hoist bs effs - checkExtends allowedEffs effs' - return piTy - -checkBlockWithEffs :: forall i o r m. (Typer m r, IRRep r) => EffectRow r o -> Block r i -> m i o (Type r o) -checkBlockWithEffs allowedEffs (Abs decls result) = do - checkDecls allowedEffs decls \decls' -> do - resultTy <- getTypeE result - liftHoistExcept $ hoist decls' resultTy + Just <$> dest |: RawRefTy sTy + return $ RunState d' s' f' + RunIO body -> RunIO <$> checkBlock (EffTy (extendEffect IOEffect effs) reqTy) body + RunInit body -> RunInit <$> checkBlock (EffTy (extendEffect InitEffect effs) reqTy) body + CatchException reqTy' body -> do + reqTy'' <- checkE reqTy' + checkTypesEq reqTy reqTy'' + TypeCon _ _ (TyConParams _[Type ty]) <- return reqTy'' -- TODO: take more care in unpacking Maybe + body' <- checkBlock (EffTy (extendEffect ExceptionEffect effs) ty) body + return $ CatchException reqTy'' body' + +instance IRRep r => CheckableWithEffects r (DAMOp r) where + checkWithEffects effs = \case + Seq effAnn dir ixTy carry lam -> do + LamExpr (UnaryNest b) body <- return lam + effAnn' <- checkE effAnn + checkExtends effs effAnn' + ixTy' <- checkE ixTy + (carry', carryTy') <- checkAndGetType carry + let badCarry = throw TypeErr $ "Seq carry should be a product of raw references, got: " ++ pprint carryTy' + case carryTy' of + ProdTy refTys -> forM_ refTys \case RawRefTy _ -> return (); _ -> badCarry + _ -> badCarry + let binderReqTy = PairTy (ixTypeType ixTy') carryTy' + checkBinderType binderReqTy b \b' -> do + body' <- checkBlock (EffTy (sink effAnn') UnitTy) body + return $ Seq effAnn' dir ixTy' carry' $ LamExpr (UnaryNest b') body' + RememberDest effAnn d lam -> do + LamExpr (UnaryNest b) body <- return lam + effAnn' <- checkE effAnn + checkExtends effs effAnn' + (d', dTy@(RawRefTy _)) <- checkAndGetType d + checkBinderType dTy b \b' -> do + body' <- checkBlock (EffTy (sink effAnn') UnitTy) body + return $ RememberDest effAnn' d' $ LamExpr (UnaryNest b') body' + AllocDest ty -> AllocDest <$> ty|:TyKind + Place ref val -> do + val' <- checkE val + ref' <- ref |: RawRefTy (getType val') + declareEff effs InitEffect + return $ Place ref' val' + Freeze ref -> do + ref' <- checkE ref + RawRefTy _ <- return $ getType ref' + return $ Freeze ref' + +checkLamExpr :: IRRep r => PiType r o -> LamExpr r i -> TyperM r i o (LamExpr r o) +checkLamExpr piTy (LamExpr bs body) = + checkB bs \bs' -> do + effTy <- checkInstantiation (sink piTy) (Var <$> bindersVars bs') + body' <- checkBlock effTy body + return $ LamExpr bs' body' checkDecls - :: (Typer m r, IRRep r) + :: IRRep r => EffectRow r o -> Decls r i i' - -> (forall o'. DExt o o' => Decls r o o' -> m i' o' a) - -> m i o a + -> (forall o'. DExt o o' => Decls r o o' -> TyperM r i' o' a) + -> TyperM r i o a checkDecls _ Empty cont = getDistinct >>= \Distinct -> cont Empty -checkDecls effs (Nest (Let b rhs@(DeclBinding _ expr)) decls) cont = do - void $ typeCheckExpr effs expr - rhs' <- renameM rhs - withFreshBinder (getNameHint b) rhs' \(b':>_) -> do +checkDecls effs (Nest (Let b (DeclBinding ann expr)) decls) cont = do + rhs <- DeclBinding ann <$> checkWithEffects effs expr + withFreshBinder (getNameHint b) rhs \(b':>_) -> do extendRenamer (b@>binderName b') do - let decl' = Let b' rhs' + let decl' = Let b' rhs checkDecls (sink effs) decls \decls' -> cont $ Nest decl' decls' -checkBlock :: (Typer m r, IRRep r) => Block r i -> m i o (EffTy r o) -checkBlock block = do - EffTy effs _ <- blockEffTy =<< renameM block - ty <- checkBlockWithEffs effs block - return $ EffTy effs ty - -checkRWSAction :: (Typer m r, IRRep r) => EffectRow r o -> RWS -> LamExpr r i -> m i o (Type r o, Type r o) -checkRWSAction effs rws f = do +checkRWSAction + :: IRRep r => Type r o -> Type r o -> EffectRow r o + -> RWS -> LamExpr r i -> TyperM r i o (LamExpr r o) +checkRWSAction resultTy referentTy effs rws f = do BinaryLamExpr bH bR body <- return f - renameBinders bH \bH' -> renameBinders bR \bR' -> do - h <- sinkM $ binderVar bH' - let effs' = extendEffect (RWSEffect rws $ Var h) (sink effs) - RefTy _ referentTy <- sinkM $ binderType bR' - resultTy <- checkBlockWithEffs effs' body - liftM fromPairE $ liftHoistExcept $ hoist (PairB bH' bR') $ PairE resultTy referentTy - -checkCase :: (Typer m r, IRRep r) => Atom r i -> [Alt r i] -> Type r o -> EffectRow r o -> m i o () -checkCase scrut alts resultTy effs = do - scrutTy <- getTypeE scrut - altsBinderTys <- checkCaseAltsBinderTys scrutTy - parallelAffines_ $ zipWith (\alt bs -> - checkAlt resultTy bs effs alt) alts altsBinderTys - -checkCaseAltsBinderTys :: (Fallible1 m, EnvReader m, IRRep r) => Type r n -> m n [Type r n] + checkBinderType (TC HeapType) bH \bH' -> do + let h = Var $ binderVar bH' + let refTy = RefTy h (sink referentTy) + checkBinderType refTy bR \bR' -> do + let effs' = extendEffect (RWSEffect rws $ sink h) (sink effs) + body' <- checkBlock (EffTy effs' (sink resultTy)) body + return $ BinaryLamExpr bH' bR' body' + +checkCaseAltsBinderTys :: IRRep r => Type r n -> TyperM r i n [Type r n] checkCaseAltsBinderTys ty = case ty of SumTy types -> return types NewtypeTyCon t -> case t of - UserADTType _ defName params -> do + UserADTType _ defName (TyConParams _ params) -> do def <- lookupTyCon defName - ADTCons cons <- checkedInstantiateTyConDef def params + ADTCons cons <- checkInstantiation def params return [repTy | DataConDef _ _ repTy _ <- cons] _ -> fail msg _ -> fail msg where msg = "Case analysis only supported on ADTs, not on " ++ pprint ty -checkAlt :: (Typer m r, IRRep r) => Type r o -> Type r o -> EffectRow r o -> Alt r i -> m i o () -checkAlt resultTyReq bTyReq effs (Abs b body) = do - bTy <- renameM $ binderType b - checkTypesEq bTyReq bTy - renameBinders b \_ -> do - resultTy <- checkBlockWithEffs (sink effs) body - checkTypesEq (sink resultTyReq) resultTy - -checkApp :: (Typer m r, IRRep r) => EffectRow r o -> Type r o -> [Atom r i] -> m i o (Type r o) -checkApp allowedEffs fTy xs = case fTy of - Pi (CorePiType _ _ bs effTy) -> do - xs' <- mapM renameM xs - checkArgTys bs xs' - let subst = bs @@> fmap SubstVal xs' - EffTy effs' resultTy' <- applySubst subst effTy - checkExtends allowedEffs effs' - return resultTy' - _ -> throw TypeErr $ - "Not a type: " ++ pprint fTy ++ " (tried to apply it to: " ++ pprint xs ++ ")" - -checkTabApp :: (Typer m r, IRRep r) => Type r o -> [Atom r i] -> m i o (Type r o) +checkTabApp :: (IRRep r) => Type r o -> [Atom r o] -> TyperM r i o (Type r o) checkTabApp ty [] = return ty checkTabApp ty (i:rest) = do - TabTy _ (b :> ixTy) resultTy <- return ty - i' <- checkTypeE ixTy i - resultTy' <- applySubst (b@>SubstVal i') resultTy - checkTabApp resultTy' rest - -checkArgTys :: (Typer m r, IRRep r) => Nest (Binder r) o o' -> [Atom r o] -> m i o () -checkArgTys Empty [] = return () -checkArgTys (Nest b bs) (x:xs) = do - dropSubst $ x |: binderType b - Abs bs' UnitE <- applySubst (b@>SubstVal x) (EmptyAbs bs) - checkArgTys bs' xs -checkArgTys _ _ = throw TypeErr $ "wrong number of args" -{-# INLINE checkArgTys #-} + TabPi tabTy <- return ty + resultTy <- checkInstantiation tabTy [i] + checkTabApp resultTy rest + +checkInstantiation + :: forall r e body i o . + (IRRep r, SinkableE body, SubstE AtomSubstVal body, ToBindersAbs e body r) + => e o -> [Atom r o] -> TyperM r i o (body o) +checkInstantiation abTop xsTop = do + Abs bs body <- return $ toAbs abTop + go (Abs bs body) xsTop + where + go :: Abs (Nest (Binder r)) body o' -> [Atom r o'] -> TyperM r i o' (body o') + go (Abs Empty body) [] = return body + go (Abs (Nest b bs) body) (x:xs) = do + checkTypesEq (getType x) (binderType b) + rest <- applySubst (b@>SubstVal x) (Abs bs body) + go rest xs + go _ _ = throw ZipErr "Wrong number of args" checkIntBaseType :: Fallible m => BaseType -> m () checkIntBaseType t = case t of @@ -841,12 +827,6 @@ checkValidBaseCast sourceTy@(Vector sourceSizes _) destTy@(Vector destSizes _) = checkValidBaseCast sourceTy destTy = throw TypeErr $ "Can't cast " ++ pprint sourceTy ++ " to " ++ pprint destTy -typeCheckBaseType :: Typer m r => HasType r e => e i -> m i o BaseType -typeCheckBaseType e = - getTypeE e >>= \case - TC (BaseType b) -> return b - ty -> throw TypeErr $ "Expected a base type. Got: " ++ pprint ty - scalarOrVectorLike :: Fallible m => BaseType -> ScalarBaseType -> m BaseType scalarOrVectorLike x sbt = case x of Scalar _ -> return $ Scalar sbt @@ -854,7 +834,6 @@ scalarOrVectorLike x sbt = case x of _ -> throw CompilerErr "only scalar or vector base types should occur here" data ArgumentType = SomeFloatArg | SomeIntArg | SomeUIntArg -data ReturnType = SameReturn | Word8Return checkOpArgType :: Fallible m => ArgumentType -> BaseType -> m () checkOpArgType argTy x = @@ -864,182 +843,76 @@ checkOpArgType argTy x = assertEq x x' "" SomeFloatArg -> checkFloatBaseType x -checkBinOp :: Fallible m => BinOp -> BaseType -> BaseType -> m BaseType +checkBinOp :: Fallible m => BinOp -> BaseType -> BaseType -> m () checkBinOp op x y = do checkOpArgType argTy x assertEq x y "" - case retTy of - SameReturn -> return x - Word8Return -> scalarOrVectorLike x Word8Type where - (argTy, retTy) = case op of - IAdd -> (ia, sr); ISub -> (ia, sr) - IMul -> (ia, sr); IDiv -> (ia, sr) - IRem -> (ia, sr); - ICmp _ -> (ia, br) - FAdd -> (fa, sr); FSub -> (fa, sr) - FMul -> (fa, sr); FDiv -> (fa, sr); - FPow -> (fa, sr) - FCmp _ -> (fa, br) - BAnd -> (ia, sr); BOr -> (ia, sr) - BXor -> (ia, sr) - BShL -> (ia, sr); BShR -> (ia, sr) + argTy = case op of + IAdd -> ia; ISub -> ia + IMul -> ia; IDiv -> ia + IRem -> ia; + ICmp _ -> ia + FAdd -> fa; FSub -> fa + FMul -> fa; FDiv -> fa; + FPow -> fa + FCmp _ -> fa + BAnd -> ia; BOr -> ia + BXor -> ia + BShL -> ia; BShR -> ia where ia = SomeIntArg; fa = SomeFloatArg - br = Word8Return; sr = SameReturn -checkUnOp :: Fallible m => UnOp -> BaseType -> m BaseType -checkUnOp op x = do - checkOpArgType argTy x - case retTy of - SameReturn -> return x - _ -> throw CompilerErr "all supported unary operations have the same argument and return type" +checkUnOp :: Fallible m => UnOp -> BaseType -> m () +checkUnOp op x = checkOpArgType argTy x where - (argTy, retTy) = case op of - Exp -> (f, sr) - Exp2 -> (f, sr) - Log -> (f, sr) - Log2 -> (f, sr) - Log10 -> (f, sr) - Log1p -> (f, sr) - Sin -> (f, sr) - Cos -> (f, sr) - Tan -> (f, sr) - Sqrt -> (f, sr) - Floor -> (f, sr) - Ceil -> (f, sr) - Round -> (f, sr) - LGamma -> (f, sr) - Erf -> (f, sr) - Erfc -> (f, sr) - FNeg -> (f, sr) - BNot -> (u, sr) + argTy = case op of + Exp -> f + Exp2 -> f + Log -> f + Log2 -> f + Log10 -> f + Log1p -> f + Sin -> f + Cos -> f + Tan -> f + Sqrt -> f + Floor -> f + Ceil -> f + Round -> f + LGamma -> f + Erf -> f + Erfc -> f + FNeg -> f + BNot -> u where - u = SomeUIntArg; f = SomeFloatArg; sr = SameReturn - --- === various helpers for querying types === - -checkedInstantiateTyConDef - :: (EnvReader m, Fallible1 m) - => TyConDef n -> TyConParams n -> m n (DataConDefs n) -checkedInstantiateTyConDef (TyConDef _ _ bs cons) (TyConParams _ xs) = do - checkedApplyNaryAbs (Abs bs cons) xs - -checkedApplyNaryAbs - :: forall r e o m - . ( EnvReader m, Fallible1 m, SinkableE e , SubstE AtomSubstVal e, IRRep r) - => Abs (Nest (Binder r)) e o -> [Atom r o] -> m o (e o) -checkedApplyNaryAbs (Abs bsTop e) xsTop = do - go (EmptyAbs bsTop) xsTop - applySubst (bsTop@@>(SubstVal<$>xsTop)) e - where - go :: EmptyAbs (Nest (Binder r)) o -> [Atom r o] -> m o () - go (Abs Empty UnitE) [] = return () - go (Abs (Nest b bs) UnitE) (x:xs) = do - checkAlphaEq (binderType b) (getType x) - bs' <- applySubst (b@>SubstVal x) (Abs bs UnitE) - go bs' xs - go _ _ = throw TypeErr "wrong number of arguments" + u = SomeUIntArg; f = SomeFloatArg; -- === effects === instance IRRep r => CheckableE r (EffectRow r) where checkE (EffectRow effs effTail) = do - forM_ (eSetToList effs) \eff -> case eff of - RWSEffect _ v -> v |: TC HeapType - ExceptionEffect -> return () - IOEffect -> return () - InitEffect -> return () - case effTail of - NoTail -> return () + effs' <- eSetFromList <$> forM (eSetToList effs) \eff -> case eff of + RWSEffect rws v -> do + v' <- v |: TC HeapType + return $ RWSEffect rws v' + ExceptionEffect -> return ExceptionEffect + IOEffect -> return IOEffect + InitEffect -> return InitEffect + effTail' <- case effTail of + NoTail -> return NoTail EffectRowTail v -> do v' <- renameM v ty <- getType <$> lookupAtomName (atomVarName v') checkTypesEq EffKind ty + return $ EffectRowTail v' + return $ EffectRow effs' effTail' -declareEff :: forall r m i o. (IRRep r, Typer m r) => EffectRow r o -> Effect r o -> m i o () +declareEff :: IRRep r => EffectRow r o -> Effect r o -> TyperM r i o () declareEff allowedEffs eff = checkExtends allowedEffs $ OneEffect eff -checkExtends :: (Fallible m, IRRep r) => EffectRow r n -> EffectRow r n -> m () -checkExtends allowed (EffectRow effs effTail) = do - let (EffectRow allowedEffs allowedEffTail) = allowed - case effTail of - EffectRowTail _ -> assertEq allowedEffTail effTail "" - NoTail -> return () - forM_ (eSetToList effs) \eff -> unless (eff `eSetMember` allowedEffs) $ - throw CompilerErr $ "Unexpected effect: " ++ pprint eff ++ - "\nAllowed: " ++ pprint allowed - --- === "Data" type class === - -runCheck - :: (EnvReader m, SinkableE e) - => (forall l. DExt n l => TyperT Maybe r l l (e l)) - -> m n (Maybe (e n)) -runCheck cont = do - Distinct <- getDistinct - liftTyperT $ cont - -asFFIFunType :: EnvReader m => CType n -> m n (Maybe (IFunType, CorePiType n)) -asFFIFunType ty = return do - Pi piTy <- return ty - impTy <- checkFFIFunTypeM piTy - return (impTy, piTy) - -checkFFIFunTypeM :: Fallible m => CorePiType n -> m IFunType -checkFFIFunTypeM (CorePiType appExpl (_:expls) (Nest b bs) effTy) = do - argTy <- checkScalar $ binderType b - case bs of - Empty -> do - resultTys <- checkScalarOrPairType (etTy effTy) - let cc = case length resultTys of - 0 -> error "Not implemented" - 1 -> FFICC - _ -> FFIMultiResultCC - return $ IFunType cc [argTy] resultTys - Nest b' rest -> do - let naryPiRest = CorePiType appExpl expls (Nest b' rest) effTy - IFunType cc argTys resultTys <- checkFFIFunTypeM naryPiRest - return $ IFunType cc (argTy:argTys) resultTys -checkFFIFunTypeM _ = error "expected at least one argument" - -checkScalar :: (IRRep r, Fallible m) => Type r n -> m BaseType -checkScalar (BaseTy ty) = return ty -checkScalar ty = throw TypeErr $ pprint ty - -checkScalarOrPairType :: (IRRep r, Fallible m) => Type r n -> m [BaseType] -checkScalarOrPairType (PairTy a b) = do - tys1 <- checkScalarOrPairType a - tys2 <- checkScalarOrPairType b - return $ tys1 ++ tys2 -checkScalarOrPairType (BaseTy ty) = return [ty] -checkScalarOrPairType ty = throw TypeErr $ pprint ty - -isData :: EnvReader m => Type CoreIR n -> m n Bool -isData ty = liftM isJust $ runCheck do - checkDataLike (sink ty) - return UnitE - -checkDataLike :: Typer m r => Type CoreIR i -> m i o () -checkDataLike ty = case ty of - TyVar _ -> notData - TabPi (TabPiType _ b eltTy) -> do - renameBinders b \_ -> - checkDataLike eltTy - DepPairTy (DepPairType _ b@(_:>l) r) -> do - recur l - renameBinders b \_ -> checkDataLike r - NewtypeTyCon nt -> do - (_, ty') <- unwrapNewtypeType =<< renameM nt - dropSubst $ recur ty' - TC con -> case con of - BaseType _ -> return () - ProdType as -> mapM_ recur as - SumType cs -> mapM_ recur cs - RefType _ _ -> return () - HeapType -> return () - _ -> notData - _ -> notData - where - recur = checkDataLike - notData = throw TypeErr $ pprint ty +checkEffTy :: IRRep r => EffectRow r o -> EffTy r i -> TyperM r i o (EffTy r o) +checkEffTy allowedEffs effTy = do + EffTy declaredEffs resultTy <- checkE effTy + checkExtends allowedEffs declaredEffs + return $ EffTy declaredEffs resultTy diff --git a/src/lib/Core.hs b/src/lib/Core.hs index f6fb57452..a7a107b49 100644 --- a/src/lib/Core.hs +++ b/src/lib/Core.hs @@ -82,6 +82,7 @@ newtype EnvReaderT (m::MonadKind) (n::S) (a:: *) = , MonadWriter w, Fallible, Searcher, Alternative) type EnvReaderM = EnvReaderT Identity +type FallibleEnvReaderM = EnvReaderT FallibleM runEnvReaderM :: Distinct n => Env n -> EnvReaderM n a -> a runEnvReaderM bindings m = runIdentity $ runEnvReaderT bindings m diff --git a/src/lib/Generalize.hs b/src/lib/Generalize.hs index dacb584fb..58c0721d4 100644 --- a/src/lib/Generalize.hs +++ b/src/lib/Generalize.hs @@ -8,7 +8,6 @@ module Generalize (generalizeArgs, generalizeIxDict) where import Control.Monad -import CheckType (isData) import Core import Err import Types.Core diff --git a/src/lib/Imp.hs b/src/lib/Imp.hs index 32160c868..9a9ab2e71 100644 --- a/src/lib/Imp.hs +++ b/src/lib/Imp.hs @@ -1535,7 +1535,7 @@ impInstrTypes instr = case instr of where hostPtrTy ty = PtrType (CPU, ty) instance CheckableE SimpIR ImpFunction where - checkE _ = return () -- TODO + checkE = renameM -- TODO -- TODO: Don't use Core Envs for Imp! instance BindsEnv ImpDecl where diff --git a/src/lib/Inference.hs b/src/lib/Inference.hs index 9db292f9e..cd28d7c5e 100644 --- a/src/lib/Inference.hs +++ b/src/lib/Inference.hs @@ -11,7 +11,7 @@ module Inference ( inferTopUDecl, checkTopUType, inferTopUExpr , trySynthTerm, generalizeDict, asTopBlock - , synthTopE, UDeclInferenceResult (..)) where + , synthTopE, UDeclInferenceResult (..), asFFIFunType) where import Prelude hiding ((.), id) import Control.Category @@ -45,6 +45,7 @@ import SourceInfo import Subst import QueryType import Types.Core +import Types.Imp import Types.Primitives import Types.Source import Util hiding (group) @@ -3184,6 +3185,43 @@ withFabricatedEmitsInf cont = fromWrapWithEmitsInf newtype WrapWithEmitsInf n r = WrapWithEmitsInf { fromWrapWithEmitsInf :: EmitsInf n => r } +-- === IFunType === + +asFFIFunType :: EnvReader m => CType n -> m n (Maybe (IFunType, CorePiType n)) +asFFIFunType ty = return do + Pi piTy <- return ty + impTy <- checkFFIFunTypeM piTy + return (impTy, piTy) + +checkFFIFunTypeM :: Fallible m => CorePiType n -> m IFunType +checkFFIFunTypeM (CorePiType appExpl (_:expls) (Nest b bs) effTy) = do + argTy <- checkScalar $ binderType b + case bs of + Empty -> do + resultTys <- checkScalarOrPairType (etTy effTy) + let cc = case length resultTys of + 0 -> error "Not implemented" + 1 -> FFICC + _ -> FFIMultiResultCC + return $ IFunType cc [argTy] resultTys + Nest b' rest -> do + let naryPiRest = CorePiType appExpl expls (Nest b' rest) effTy + IFunType cc argTys resultTys <- checkFFIFunTypeM naryPiRest + return $ IFunType cc (argTy:argTys) resultTys +checkFFIFunTypeM _ = error "expected at least one argument" + +checkScalar :: (IRRep r, Fallible m) => Type r n -> m BaseType +checkScalar (BaseTy ty) = return ty +checkScalar ty = throw TypeErr $ pprint ty + +checkScalarOrPairType :: (IRRep r, Fallible m) => Type r n -> m [BaseType] +checkScalarOrPairType (PairTy a b) = do + tys1 <- checkScalarOrPairType a + tys2 <- checkScalarOrPairType b + return $ tys1 ++ tys2 +checkScalarOrPairType (BaseTy ty) = return [ty] +checkScalarOrPairType ty = throw TypeErr $ pprint ty + -- === instances === instance PrettyE e => Pretty (UDeclInferenceResult e l) where @@ -3197,9 +3235,11 @@ instance SinkableE e => SinkableE (UDeclInferenceResult e) where instance (RenameE e, CheckableE CoreIR e) => CheckableE CoreIR (UDeclInferenceResult e) where checkE = \case - UDeclResultDone _ -> return () - UDeclResultBindName _ block _ -> checkE block - UDeclResultBindPattern _ block _ -> checkE block + UDeclResultDone e -> UDeclResultDone <$> checkE e + UDeclResultBindName ann block ab -> + UDeclResultBindName ann <$> checkE block <*> renameM ab -- TODO: check result + UDeclResultBindPattern hint block recon -> + UDeclResultBindPattern hint <$> checkE block <*> renameM recon -- TODO: check recon instance HasType CoreIR InfEmission where getType = \case diff --git a/src/lib/Linearize.hs b/src/lib/Linearize.hs index 3d6b91a3f..d32b5230a 100644 --- a/src/lib/Linearize.hs +++ b/src/lib/Linearize.hs @@ -291,7 +291,7 @@ linearize f x = runPrimalMInit $ linearizeLambdaApp f x linearizeTopLam :: STopLam n -> [Active] -> DoubleBuilder SimpIR n (STopLam n, STopLam n) linearizeTopLam (TopLam False _ (LamExpr bs body)) actives = do (primalFun, tangentFun) <- runPrimalMInit $ refreshBinders bs \bs' frag -> extendSubst frag do - let allPrimals = nestToAtomVars bs' + let allPrimals = bindersVars bs' activeVs <- catMaybes <$> forM (zip actives allPrimals) \(active, v) -> case active of True -> return $ Just v False -> return $ Nothing diff --git a/src/lib/QueryType.hs b/src/lib/QueryType.hs index 9bfa3a2c4..50a976816 100644 --- a/src/lib/QueryType.hs +++ b/src/lib/QueryType.hs @@ -213,7 +213,7 @@ getMethodNameType v = liftEnvReaderM $ lookupEnv v >>= \case MethodBinding className i -> do ClassDef _ _ paramNames _ paramBs scBinders methodTys <- lookupClassDef className refreshAbs (Abs paramBs $ Abs scBinders (methodTys !! i)) \paramBs' absPiTy -> do - let params = Var <$> nestToAtomVars paramBs' + let params = Var <$> bindersVars paramBs' dictTy <- DictTy <$> dictType (sink className) params withFreshBinder noHint dictTy \dictB -> do scDicts <- getSuperclassDicts (Var $ binderVar dictB) @@ -384,3 +384,47 @@ liftIFunType (IFunType _ argTys resultTys) = liftEnvReaderM $ go argTys where t:ts -> withFreshBinder noHint (BaseTy t) \b -> do PiType bs effTy <- go ts return $ PiType (Nest b bs) effTy + +-- === Data constraints === + +isData :: EnvReader m => Type CoreIR n -> m n Bool +isData ty = do + result <- liftEnvReaderT $ withSubstReaderT $ checkDataLike ty + case runFallibleM result of + Success () -> return True + Failure _ -> return False + +checkDataLike :: Type CoreIR i -> SubstReaderT Name FallibleEnvReaderM i o () +checkDataLike ty = case ty of + TyVar _ -> notData + TabPi (TabPiType _ b eltTy) -> do + renameBinders b \_ -> + checkDataLike eltTy + DepPairTy (DepPairType _ b@(_:>l) r) -> do + recur l + renameBinders b \_ -> checkDataLike r + NewtypeTyCon nt -> do + (_, ty') <- unwrapNewtypeType =<< renameM nt + dropSubst $ recur ty' + TC con -> case con of + BaseType _ -> return () + ProdType as -> mapM_ recur as + SumType cs -> mapM_ recur cs + RefType _ _ -> return () + HeapType -> return () + _ -> notData + _ -> notData + where + recur = checkDataLike + notData = throw TypeErr $ pprint ty + +checkExtends :: (Fallible m, IRRep r) => EffectRow r n -> EffectRow r n -> m () +checkExtends allowed (EffectRow effs effTail) = do + let (EffectRow allowedEffs allowedEffTail) = allowed + case effTail of + EffectRowTail _ -> assertEq allowedEffTail effTail "" + NoTail -> return () + forM_ (eSetToList effs) \eff -> unless (eff `eSetMember` allowedEffs) $ + throw CompilerErr $ "Unexpected effect: " ++ pprint eff ++ + "\nAllowed: " ++ pprint allowed + diff --git a/src/lib/Simplify.hs b/src/lib/Simplify.hs index b3cafc992..028e35981 100644 --- a/src/lib/Simplify.hs +++ b/src/lib/Simplify.hs @@ -19,7 +19,7 @@ import Data.Text.Prettyprint.Doc (Pretty (..), hardline) import Builder import CheapReduction -import CheckType (CheckableE (..), isData, checkBlock) +import CheckType import Core import Err import Generalize @@ -274,9 +274,12 @@ instance SinkableE SimplifiedBlock instance RenameE SimplifiedBlock instance HoistableE SimplifiedBlock instance CheckableE SimpIR SimplifiedBlock where - checkE (SimplifiedBlock block _) = - -- TODO: CheckableE instance for the recon too - void $ checkBlock block + checkE (SimplifiedBlock block recon) = do + block' <- renameM block + effTy <- blockEffTy block' -- TODO: store this in the simplified block instead + block'' <- dropSubst $ checkBlock effTy block' + recon' <- renameM recon -- TODO: CheckableE instance for the recon too + return $ SimplifiedBlock block'' recon' instance Pretty (SimplifiedBlock n) where pretty (SimplifiedBlock block recon) = @@ -286,9 +289,9 @@ instance SinkableE SimplifiedTopLam where sinkingProofE = todoSinkableProof instance CheckableE SimpIR SimplifiedTopLam where - checkE (SimplifiedTopLam lam _) = do + checkE (SimplifiedTopLam lam recon) = -- TODO: CheckableE instance for the recon too - checkE lam + SimplifiedTopLam <$> checkE lam <*> renameM recon instance Pretty (SimplifiedTopLam n) where pretty (SimplifiedTopLam lam recon) = pretty lam <> hardline <> pretty recon diff --git a/src/lib/TopLevel.hs b/src/lib/TopLevel.hs index 4fdda0f02..caf3d591c 100644 --- a/src/lib/TopLevel.hs +++ b/src/lib/TopLevel.hs @@ -43,9 +43,9 @@ import qualified LLVM.AST import AbstractSyntax import Builder -import CheckType ( CheckableE (..), asFFIFunType, checkHasType) +import CheckType ( CheckableE (..), checkTypeIs) #ifdef DEX_DEBUG -import CheckType (checkTypesM) +import CheckType (checkTypes) #endif import Core import ConcreteSyntax @@ -316,7 +316,7 @@ evalSourceBlock' mname block = case sbContents block of _ -> evalUExpr expr fType <- getType <$> toAtomVar fname' (nimplicit, nexplicit, linFunTy) <- liftExceptEnvReaderM $ getLinearizationType zeros fType - impl `checkHasType` linFunTy >>= \case + liftEnvReaderT (impl `checkTypeIs` linFunTy) >>= \case Failure _ -> do let implTy = getType impl throw TypeErr $ unlines @@ -744,7 +744,7 @@ checkPass name cont = do return result #ifdef DEX_DEBUG logTop $ MiscLog $ "Running checks" - checkTypesM result + checkTypes result logTop $ MiscLog $ "Checks passed" #else logTop $ MiscLog $ "Checks skipped (not a debug build)" diff --git a/src/lib/Types/Core.hs b/src/lib/Types/Core.hs index 02ab9c8bf..067af737e 100644 --- a/src/lib/Types/Core.hs +++ b/src/lib/Types/Core.hs @@ -943,12 +943,12 @@ binderType (_:>ty) = ty binderVar :: (IRRep r, DExt n l) => Binder r n l -> AtomVar r l binderVar (b:>ty) = AtomVar (binderName b) (sink ty) -nestToAtomVars :: (Distinct l, Ext n l, IRRep r) - => Nest (Binder r) n l -> [AtomVar r l] -nestToAtomVars = \case +bindersVars :: (Distinct l, Ext n l, IRRep r) + => Nest (Binder r) n l -> [AtomVar r l] +bindersVars = \case Empty -> [] Nest b bs -> withExtEvidence b $ withSubscopeDistinct bs $ - sink (binderVar b) : nestToAtomVars bs + sink (binderVar b) : bindersVars bs -- === ToBinding === diff --git a/src/lib/Types/Source.hs b/src/lib/Types/Source.hs index 0c361236d..21a6974ee 100644 --- a/src/lib/Types/Source.hs +++ b/src/lib/Types/Source.hs @@ -701,6 +701,15 @@ instance Color c => BindsAtMostOneName (UBinder c) c where UIgnore -> emptyInFrag UBind _ _ b' -> b' @> x +instance Color c => SinkableB (UBinder c) where + sinkingProofB _ _ _ = todoSinkableProof + +instance Color c => RenameB (UBinder c) where + renameB env ub cont = case ub of + UBindSource pos sn -> cont env $ UBindSource pos sn + UIgnore -> cont env UIgnore + UBind ctx sn b -> renameB env b \env' b' -> cont env' $ UBind ctx sn b' + instance ProvesExt (UAnnBinder req) where instance BindsNames (UAnnBinder req) where toScopeFrag (UAnnBinder b _ _) = toScopeFrag b