From 43d7b98633996cbe0ac39ff9d817d159d8fd0f26 Mon Sep 17 00:00:00 2001 From: Dougal Date: Tue, 11 Jul 2023 08:38:46 -0400 Subject: [PATCH 1/2] Make a `BinderAndDecls` data type, at first just a wrapper around `Binder`. Next we'll actually make decls part of it. But that will involve substantive changes. This first step just makes the type distinct from `Binder` so we can make all the administrative changes while still passing tests. --- src/lib/Builder.hs | 66 ++++--- src/lib/CheapReduction.hs | 112 +++++++----- src/lib/CheckType.hs | 25 ++- src/lib/Core.hs | 6 +- src/lib/Export.hs | 13 +- src/lib/Generalize.hs | 25 +-- src/lib/Imp.hs | 86 ++++----- src/lib/Inference.hs | 281 ++++++++++++++++-------------- src/lib/Inline.hs | 11 +- src/lib/JAX/ToSimp.hs | 4 +- src/lib/Linearize.hs | 29 +-- src/lib/Lower.hs | 10 +- src/lib/OccAnalysis.hs | 4 +- src/lib/Optimize.hs | 19 +- src/lib/PPrint.hs | 5 +- src/lib/QueryType.hs | 47 +++-- src/lib/QueryTypePure.hs | 20 +-- src/lib/RuntimePrint.hs | 4 +- src/lib/Simplify.hs | 67 +++---- src/lib/Transpose.hs | 22 +-- src/lib/Types/Core.hs | 127 ++++++++++---- src/lib/Vectorize.hs | 45 ++--- tests/unit/ConstantCastingSpec.hs | 6 +- 23 files changed, 576 insertions(+), 458 deletions(-) diff --git a/src/lib/Builder.hs b/src/lib/Builder.hs index c539d01b4..b00baa6d1 100644 --- a/src/lib/Builder.hs +++ b/src/lib/Builder.hs @@ -649,7 +649,7 @@ buildAbs hint binding cont = do typesFromNonDepBinderNest :: (EnvReader m, Fallible1 m, IRRep r) - => Nest (Binder r) n l -> m n [Type r n] + => Binders r n l -> m n [Type r n] typesFromNonDepBinderNest Empty = return [] typesFromNonDepBinderNest (Nest b rest) = do Abs rest' UnitE <- return $ assumeConst $ Abs (UnaryNest b) $ Abs rest UnitE @@ -662,7 +662,7 @@ buildUnaryLamExpr -> (forall l. (Emits l, Distinct l, DExt n l) => AtomVar r l -> m l (Atom r l)) -> m n (LamExpr r n) buildUnaryLamExpr hint ty cont = do - bs <- withFreshBinder hint ty \b -> return $ EmptyAbs (UnaryNest b) + bs <- withFreshBinder hint ty \b -> return $ EmptyAbs (UnaryNest (PlainBD b)) buildLamExpr bs \[v] -> cont v buildBinaryLamExpr @@ -672,21 +672,21 @@ buildBinaryLamExpr -> m n (LamExpr r n) buildBinaryLamExpr (h1,t1) (h2,t2) cont = do bs <- withFreshBinder h1 t1 \b1 -> withFreshBinder h2 (sink t2) \b2 -> - return $ EmptyAbs $ BinaryNest b1 b2 + return $ EmptyAbs $ BinaryNest (PlainBD b1) (PlainBD b2) buildLamExpr bs \[v1, v2] -> cont v1 v2 buildLamExpr :: ScopableBuilder r m - => (EmptyAbs (Nest (Binder r)) n) + => (Abs (Binders r) any n) -> (forall l. (Emits l, Distinct l, DExt n l) => [AtomVar r l] -> m l (Atom r l)) -> m n (LamExpr r n) -buildLamExpr (Abs bs UnitE) cont = case bs of +buildLamExpr (Abs bs _) cont = case bs of Empty -> LamExpr Empty <$> buildBlock (cont []) Nest b rest -> do Abs b' (LamExpr bs' body') <- buildAbs (getNameHint b) (binderType b) \v -> do - rest' <- applySubst (b@>SubstVal (Var v)) $ EmptyAbs rest + rest' <- instantiate (Abs (UnaryNest b) (EmptyAbs rest)) [Var v] buildLamExpr rest' \vs -> cont $ sink v : vs - return $ LamExpr (Nest b' bs') body' + return $ LamExpr (Nest (PlainBD b') bs') body' buildTopLamFromPi :: ScopableBuilder r m @@ -765,7 +765,7 @@ buildEffLam hint ty body = do let ref = binderVar b hVar <- sinkM $ binderVar h body' <- buildBlock $ body (sink hVar) $ sink ref - return $ LamExpr (BinaryNest h b) body' + return $ LamExpr (BinaryNest (PlainBD h) (PlainBD b)) body' buildForAnn :: (Emits n, ScopableBuilder r m) @@ -776,7 +776,7 @@ buildForAnn hint ann (IxType iTy ixDict) body = do lam <- withFreshBinder hint iTy \b -> do let v = binderVar b body' <- buildBlock $ body $ sink v - return $ LamExpr (UnaryNest b) body' + return $ UnaryLamExpr b body' emitHof $ For ann (IxType iTy ixDict) lam buildFor :: (Emits n, ScopableBuilder r m) @@ -862,7 +862,7 @@ zeroAt ty = liftEmitBuilder $ go ty where BaseTy bt -> return $ Con $ Lit $ zeroLit bt ProdTy tys -> ProdVal <$> mapM go tys TabPi tabPi -> buildFor (getNameHint tabPi) Fwd (tabIxType tabPi) \i -> - go =<< instantiate (sink tabPi) [Var i] + go =<< instantiate tabPi [Var i] _ -> unreachable zeroLit bt = case bt of Scalar Float64Type -> Float64Lit 0.0 @@ -1336,17 +1336,15 @@ runMaybeWhile body = do type ReconAbs r e = Abs (ReconBinders r) e -data ReconBinders r n l = ReconBinders - (TelescopeType (AtomNameC r) (Type r) n) - (Nest (NameBinder (AtomNameC r)) n l) +data ReconBinders r n l = ReconBinders (TelescopeType r n) (Nest (NameBinder (AtomNameC r)) n l) -data TelescopeType c e n = - DepTelescope (TelescopeType c e n) (Abs (BinderP c e) (TelescopeType c e) n) - | ProdTelescope [e n] +data TelescopeType r n = + DepTelescope (TelescopeType r n) (Abs (BinderAndDecls r) (TelescopeType r) n) + | ProdTelescope [Type r n] instance IRRep r => GenericB (ReconBinders r) where type RepB (ReconBinders r) = - PairB (LiftB (TelescopeType (AtomNameC r) (Type r))) + PairB (LiftB (TelescopeType r)) (Nest (NameBinder (AtomNameC r))) fromB (ReconBinders x y) = PairB (LiftB x) y {-# INLINE fromB #-} @@ -1365,10 +1363,10 @@ instance IRRep r => ProvesExt (ReconBinders r) instance IRRep r => BindsNames (ReconBinders r) instance IRRep r => HoistableB (ReconBinders r) -instance GenericE (TelescopeType c e) where - type RepE (TelescopeType c e) = EitherE - (PairE (TelescopeType c e) (Abs (BinderP c e) (TelescopeType c e))) - (ListE e) +instance GenericE (TelescopeType r) where + type RepE (TelescopeType r) = EitherE + (PairE (TelescopeType r) (Abs (BinderAndDecls r) (TelescopeType r))) + (ListE (Type r)) fromE (DepTelescope lhs ab) = LeftE (PairE lhs ab) fromE (ProdTelescope tys) = RightE (ListE tys) {-# INLINE fromE #-} @@ -1376,10 +1374,10 @@ instance GenericE (TelescopeType c e) where toE (RightE (ListE tys)) = ProdTelescope tys {-# INLINE toE #-} -instance (Color c, SinkableE e) => SinkableE (TelescopeType c e) -instance (Color c, SinkableE e, RenameE e) => RenameE (TelescopeType c e) -instance (Color c, ToBinding e c, SubstE AtomSubstVal e) => SubstE AtomSubstVal (TelescopeType c e) -instance (Color c, HoistableE e) => HoistableE (TelescopeType c e) +instance IRRep r => SinkableE (TelescopeType r) +instance IRRep r => RenameE (TelescopeType r) +instance IRRep r => SubstE AtomSubstVal (TelescopeType r) +instance IRRep r => HoistableE (TelescopeType r) telescopicCapture :: (EnvReader m, HoistableE e, HoistableB b, IRRep r) @@ -1405,27 +1403,27 @@ applyReconAbs (Abs bs result) x = do applySubst (bs @@> map SubstVal xs) result buildTelescopeTy - :: (EnvReader m, EnvExtender m, Color c, HoistableE e) - => [AnnVar c e n] -> m n (TelescopeType c e n) + :: (EnvReader m, EnvExtender m, IRRep r) + => [AnnVar (AtomNameC r) (Type r) n] -> m n (TelescopeType r n) buildTelescopeTy [] = return (ProdTelescope []) buildTelescopeTy ((v,ty):xs) = do rhs <- buildTelescopeTy xs Abs b rhs' <- return $ abstractFreeVar v rhs case hoist b rhs' of HoistSuccess rhs'' -> return $ prependTelescopeTy ty rhs'' - HoistFailure _ -> return $ DepTelescope (ProdTelescope []) (Abs (b:>ty) rhs') + HoistFailure _ -> return $ DepTelescope (ProdTelescope []) (Abs (BD (b:>ty)) rhs') -prependTelescopeTy :: e n -> TelescopeType c e n -> TelescopeType c e n +prependTelescopeTy :: Type r n -> TelescopeType r n -> TelescopeType r n prependTelescopeTy x = \case DepTelescope lhs rhs -> DepTelescope (prependTelescopeTy x lhs) rhs ProdTelescope xs -> ProdTelescope (x:xs) buildTelescopeVal :: (EnvReader m, IRRep r) => [Atom r n] - -> TelescopeType (AtomNameC r) (Type r) n -> m n (Atom r n) + -> TelescopeType r n -> m n (Atom r n) buildTelescopeVal xsTop tyTop = fst <$> go tyTop xsTop where go :: (EnvReader m, IRRep r) - => TelescopeType (AtomNameC r) (Type r) n -> [Atom r n] + => TelescopeType r n -> [Atom r n] -> m n (Atom r n, [Atom r n]) go ty rest = case ty of ProdTelescope tys -> do @@ -1433,12 +1431,12 @@ buildTelescopeVal xsTop tyTop = fst <$> go tyTop xsTop where return (ProdVal xs, rest') DepTelescope ty1 (Abs b ty2) -> do (x1, ~(xDep : rest')) <- go ty1 rest - ty2' <- applySubst (b@>SubstVal xDep) ty2 + ty2' <- instantiate (Abs b ty2) [xDep] (x2, rest'') <- go ty2' rest' let depPairTy = DepPairType ExplicitDepPair b (telescopeTypeType ty2) return (PairVal x1 (DepPair xDep x2 depPairTy), rest'') -telescopeTypeType :: TelescopeType (AtomNameC r) (Type r) n -> Type r n +telescopeTypeType :: TelescopeType r n -> Type r n telescopeTypeType (ProdTelescope tys) = ProdTy tys telescopeTypeType (DepTelescope lhs (Abs b rhs)) = do let lhs' = telescopeTypeType lhs @@ -1450,7 +1448,7 @@ unpackTelescope => ReconBinders r l1 l2 -> Atom r n -> m n [Atom r n] unpackTelescope (ReconBinders tyTop _) xTop = go tyTop xTop where go :: (Fallible1 m, EnvReader m, IRRep r) - => TelescopeType c e l-> Atom r n -> m n [Atom r n] + => TelescopeType r l-> Atom r n -> m n [Atom r n] go ty x = case ty of ProdTelescope _ -> getUnpacked x DepTelescope ty1 (Abs _ ty2) -> do diff --git a/src/lib/CheapReduction.hs b/src/lib/CheapReduction.hs index c4cc41bb1..ff55abdf7 100644 --- a/src/lib/CheapReduction.hs +++ b/src/lib/CheapReduction.hs @@ -16,7 +16,8 @@ module CheapReduction , liftSimpFun, makeStructRepVal, NonAtomRenamer (..), Visitor (..), VisitGeneric (..) , visitAtomPartial, visitTypePartial, visitAtomDefault, visitTypeDefault, Visitor2 , visitBinders, visitPiDefault, visitAlt, toAtomVar, instantiate, withInstantiated - , bindersToVars, bindersToAtoms, instantiateNames, withInstantiatedNames, assumeConst) + , instantiateNames, withInstantiatedNames, assumeConst, tryAsConst + , extendSubstBD, arity) where import Control.Applicative @@ -218,12 +219,9 @@ instance IRRep r => CheaplyReducibleE r (Type r) (Type r) where -- means that we will follow the full call chain, so it's really expensive! -- TODO: we don't collect the dict holes here, so there's a danger of -- dropping them if they turn out to be phantom. - TabPi (TabPiType d (b:>t) resultTy) -> do - t' <- cheapReduceE t + TabPi (TabPiType d b resultTy) -> do d' <- cheapReduceE d - withFreshBinder (getNameHint b) t' \b' -> do - resultTy' <- extendSubst (b@>Rename (binderName b')) $ cheapReduceE resultTy - return $ TabPi $ TabPiType d' b' resultTy' + cheapReduceBinder b \b' -> TabPi <$> TabPiType d' b' <$> cheapReduceE resultTy -- We traverse the Atom constructors that might contain lambda expressions -- explicitly, to make sure that we can skip normalizing free vars inside those. NewtypeTyCon (Fin n) -> NewtypeTyCon . Fin <$> cheapReduceE n @@ -234,6 +232,16 @@ instance IRRep r => CheaplyReducibleE r (Type r) (Type r) where a' <- substM a dropSubst $ traverseNames cheapReduceName a' +cheapReduceBinder + :: IRRep r + => BinderAndDecls r i i' + -> (forall o'. DExt o o' => BinderAndDecls r o o' -> CheapReducerM r i' o' a) + -> CheapReducerM r i o a +cheapReduceBinder (BD (b:>ty)) cont = do + ty' <- cheapReduceE ty + withFreshBinder (getNameHint b) ty' \b' -> do + extendSubst (b@>Rename (binderName b')) $ cont (BD b') + cheapReduceDictExpr :: CType o -> DictExpr i -> CheapReducerM CoreIR i o (CAtom o) cheapReduceDictExpr resultTy d = case d of SuperclassProj child superclassIx -> do @@ -401,9 +409,9 @@ liftSimpAtom ty simpAtom = case simpAtom of (BaseTy _ , Con (Lit v)) -> return $ Con $ Lit v (ProdTy tys, Con (ProdCon xs)) -> Con . ProdCon <$> zipWithM rec tys xs (SumTy tys, Con (SumCon _ i x)) -> Con . SumCon tys i <$> rec (tys!!i) x - (DepPairTy dpt@(DepPairType _ (b:>t1) t2), DepPair x1 x2 _) -> do - x1' <- rec t1 x1 - t2' <- applySubst (b@>SubstVal x1') t2 + (DepPairTy dpt, DepPair x1 x2 _) -> do + x1' <- rec (depPairLeftTy dpt) x1 + t2' <- instantiate dpt [x1'] x2' <- rec t2' x2 return $ DepPair x1' x2' dpt _ -> error $ "can't lift " <> pprint simpAtom <> " to " <> pprint ty' @@ -426,8 +434,8 @@ confuseGHC = getDistinct -- them. Maybe a common set of low-level type-querying utils that both -- CheapReduction and QueryType import? -depPairLeftTy :: DepPairType r n -> Type r n -depPairLeftTy (DepPairType _ (_:>ty) _) = ty +depPairLeftTy :: IRRep r => DepPairType r n -> Type r n +depPairLeftTy (DepPairType _ b _) = binderType b {-# INLINE depPairLeftTy #-} unwrapNewtypeType :: EnvReader m => NewtypeTyCon n -> m n (NewtypeCon n, Type CoreIR n) @@ -463,19 +471,33 @@ wrapNewtypesData [] x = x wrapNewtypesData (c:cs) x = NewtypeCon c $ wrapNewtypesData cs x instantiateTyConDef :: EnvReader m => TyConDef n -> TyConParams n -> m n (DataConDefs n) -instantiateTyConDef (TyConDef _ _ bs conDefs) (TyConParams _ xs) = do - applySubst (bs @@> (SubstVal <$> xs)) conDefs +instantiateTyConDef tyConDef (TyConParams _ xs) = instantiate tyConDef xs {-# INLINE instantiateTyConDef #-} assumeConst :: (IRRep r, HoistableE body, SinkableE body, ToBindersAbs e body r) => e n -> body n assumeConst e = case toAbs e of Abs bs body -> ignoreHoistFailure $ hoist bs body +arity :: (IRRep r, ToBindersAbs e body r) => e n -> Int +arity e = case toAbs e of Abs bs _ -> nestLength bs + +tryAsConst + :: (IRRep r, HoistableE body, SinkableE body, ToBindersAbs e body r) => e n -> Maybe (body n) +tryAsConst e = + case toAbs e of + Abs bs body -> case hoist bs body of + HoistFailure _ -> Nothing + HoistSuccess e' -> Just e' + instantiate - :: (EnvReader m, IRRep r, SubstE (SubstVal Atom) body, SinkableE body, ToBindersAbs e body r) - => e n -> [Atom r n] -> m n (body n) -instantiate e xs = case toAbs e of - Abs bs body -> applySubst (bs @@> (SubstVal <$> xs)) body + :: (EnvReader m, IRRep r, SubstE (SubstVal Atom) body, SinkableE body, SinkableE e, + ToBindersAbs e body r, Ext h n) + => e h -> [Atom r n] -> m n (body n) +instantiate e xs = do + Abs bs body <- sinkM $ toAbs e + let bs' = fmapNest (\(BD b) -> b) bs + applySubst (bs' @@> (SubstVal <$> xs)) body +{-# INLINE instantiate #-} -- "lazy" subst-extending version of `instantiate` withInstantiated @@ -483,14 +505,18 @@ withInstantiated => e i -> [Atom r o] -> (forall i'. body i' -> m i' o a) -> m i o a -withInstantiated e xs cont = case toAbs e of - Abs bs body -> extendSubst (bs @@> (SubstVal <$> xs)) $ cont body +withInstantiated e xs cont = do + Abs bs body <- return $ toAbs e + let bs' = fmapNest (\(BD b) -> b) bs + extendSubst (bs' @@> (SubstVal <$> xs)) $ cont body instantiateNames - :: (EnvReader m, IRRep r, RenameE body, SinkableE body, ToBindersAbs e body r) - => e n -> [AtomName r n] -> m n (body n) -instantiateNames e vs = case toAbs e of - Abs bs body -> applyRename (bs @@> vs) body + :: (EnvReader m, IRRep r, RenameE body, SinkableE body, ToBindersAbs e body r, Ext h n) + => e h -> [AtomName r n] -> m n (body n) +instantiateNames e vs = do + Abs bs body <- sinkM $ toAbs e + let bs' = fmapNest (\(BD b) -> b) bs + applyRename (bs' @@> vs) body -- "lazy" subst-extending version of `instantiateNames` withInstantiatedNames @@ -498,8 +524,22 @@ withInstantiatedNames => e i -> [AtomName r o] -> (forall i'. body i' -> m i' o a) -> m i o a -withInstantiatedNames e vs cont = case toAbs e of - Abs bs body -> extendRenamer (bs @@> vs) $ cont body +withInstantiatedNames e vs cont = do + Abs bs body <- return $ toAbs e + let bs' = fmapNest (\(BD b) -> b) bs + extendRenamer (bs' @@> vs) $ cont body + +extendSubstBD + :: forall v m b r i i' o a + . (SubstReader v m, ToBinders b r, IRRep r) + => b i i' -> [v (AtomNameC r) o] -> m i' o a -> m i o a +extendSubstBD bsTop xsTop contTop = go (toBinders bsTop) xsTop contTop + where + go :: Binders r ii ii' -> [v (AtomNameC r) o] -> m ii' o a -> m ii o a + go Empty [] cont = cont + go (Nest (BD b) bs) (x:xs) cont = extendSubst (b@>x) $ go bs xs cont + go _ _ _ = error "zip error" +{-# INLINE extendSubstBD #-} -- Returns a representation type (type of an TypeCon-typed Newtype payload) -- given a list of instantiated DataConDefs. @@ -549,8 +589,8 @@ visitBlock b = visitGeneric (LamExpr Empty b) >>= \case visitAlt :: Visitor m r i o => Alt r i -> m (Alt r o) visitAlt (Abs b body) = do - visitGeneric (LamExpr (UnaryNest b) body) >>= \case - LamExpr (UnaryNest b') body' -> return $ Abs b' body' + visitGeneric (UnaryLamExpr b body) >>= \case + UnaryLamExpr b' body' -> return $ Abs b' body' _ -> error "not an alt" traverseOpTerm @@ -585,16 +625,16 @@ visitPiDefault (PiType bs effty) = do visitBinders :: (Visitor2 m r, IRRep r, FromName v, AtomSubstReader v m, EnvExtender2 m) - => Nest (Binder r) i i' - -> (forall o'. DExt o o' => Nest (Binder r) o o' -> m i' o' a) + => Binders r i i' + -> (forall o'. DExt o o' => Binders r o o' -> m i' o' a) -> m i o a visitBinders Empty cont = getDistinct >>= \Distinct -> cont Empty -visitBinders (Nest (b:>ty) bs) cont = do +visitBinders (Nest (BD (b:>ty)) bs) cont = do ty' <- visitType ty withFreshBinder (getNameHint b) ty' \b' -> do extendRenamer (b@>binderName b') do visitBinders bs \bs' -> - cont $ Nest b' bs' + cont $ Nest (BD b') bs' -- XXX: This doesn't handle the `Var`, `ProjectElt`, `SimpInCore` cases. These -- should be handled explicitly beforehand. TODO: split out these cases under a @@ -807,15 +847,6 @@ toAtomVar v = do ty <- getType <$> lookupAtomName v return $ AtomVar v ty -bindersToVars :: (EnvReader m, IRRep r) => Nest (Binder r) n' n -> m n [AtomVar r n] -bindersToVars bs = do - withExtEvidence bs do - Distinct <- getDistinct - mapM toAtomVar $ nestToNames bs - -bindersToAtoms :: (EnvReader m, IRRep r) => Nest (Binder r) n' n -> m n [Atom r n] -bindersToAtoms bs = liftM (Var <$>) $ bindersToVars bs - newtype SubstVisitor i o a = SubstVisitor { runSubstVisitor :: Reader (Env o, Subst AtomSubstVal i o) a } deriving (Functor, Applicative, Monad, MonadReader (Env o, Subst AtomSubstVal i o)) @@ -919,6 +950,7 @@ instance IRRep r => SubstE AtomSubstVal (DepPairType r) instance SubstE AtomSubstVal SolverBinding instance IRRep r => SubstE AtomSubstVal (DeclBinding r) instance IRRep r => SubstB AtomSubstVal (Decl r) +instance IRRep r => SubstB AtomSubstVal (BinderAndDecls r) instance SubstE AtomSubstVal NewtypeTyCon instance SubstE AtomSubstVal NewtypeCon instance IRRep r => SubstE AtomSubstVal (IxDict r) diff --git a/src/lib/CheckType.hs b/src/lib/CheckType.hs index a49e5989f..2f6ac6e66 100644 --- a/src/lib/CheckType.hs +++ b/src/lib/CheckType.hs @@ -254,6 +254,9 @@ instance (ToBinding ann c, Color c, CheckableE r ann) => CheckableB r (BinderP c extendRenamer (b@>binderName b') $ cont b' +instance IRRep r => CheckableB r (BinderAndDecls r) where + checkB (BD b) cont = checkB b \b' -> cont $ BD b' + checkBinderType :: (IRRep r) => Type r o -> Binder r i i' -> (forall o'. DExt o o' => Binder r o o' -> TyperM r i' o' a) @@ -263,6 +266,12 @@ checkBinderType ty b cont = do checkTypesEq (sink $ binderType b') (sink ty) cont b' +checkBinderAndDecls + :: (IRRep r) => Type r o -> BinderAndDecls r i i' + -> (forall o'. DExt o o' => BinderAndDecls r o o' -> TyperM r i' o' a) + -> TyperM r i o a +checkBinderAndDecls ty (BD b) cont = checkBinderType ty b \b' -> cont (BD b') + instance IRRep r => CheckableWithEffects r (Expr r) where checkWithEffects allowedEffs expr = addContext ("Checking expr:\n" ++ pprint expr) case expr of App effTy f xs -> do @@ -616,7 +625,7 @@ checkHof (EffTy effs reqTy) = \case IxType t d <- checkE ixTy LamExpr (UnaryNest b) body <- return f TabPi tabTy <- return reqTy - checkBinderType t b \b' -> do + checkBinderAndDecls t b \b' -> do resultTy <- checkInstantiation (sink tabTy) [Var $ binderVar b'] body' <- checkBlock (EffTy (sink effs) resultTy) body return $ For dir (IxType t d) (LamExpr (UnaryNest b') body') @@ -627,7 +636,7 @@ checkHof (EffTy effs reqTy) = \case Linearize f x -> do (x', xTy) <- checkAndGetType x LamExpr (UnaryNest b) body <- return f - checkBinderType xTy b \b' -> do + checkBinderAndDecls xTy b \b' -> do PairTy resultTy fLinTy <- sinkM reqTy body' <- checkBlock (EffTy Pure resultTy) body checkTypesEq fLinTy (Pi $ nonDepPiType [sink xTy] Pure resultTy) @@ -693,7 +702,7 @@ instance IRRep r => CheckableWithEffects r (DAMOp r) where ProdTy refTys -> forM_ refTys \case RawRefTy _ -> return (); _ -> badCarry _ -> badCarry let binderReqTy = PairTy (ixTypeType ixTy') carryTy' - checkBinderType binderReqTy b \b' -> do + checkBinderAndDecls binderReqTy b \b' -> do body' <- checkBlock (EffTy (sink effAnn') UnitTy) body return $ Seq effAnn' dir ixTy' carry' $ LamExpr (UnaryNest b') body' RememberDest effAnn d lam -> do @@ -701,7 +710,7 @@ instance IRRep r => CheckableWithEffects r (DAMOp r) where effAnn' <- checkE effAnn checkExtends effs effAnn' (d', dTy@(RawRefTy _)) <- checkAndGetType d - checkBinderType dTy b \b' -> do + checkBinderAndDecls dTy b \b' -> do body' <- checkBlock (EffTy (sink effAnn') UnitTy) body return $ RememberDest effAnn' d' $ LamExpr (UnaryNest b') body' AllocDest ty -> AllocDest <$> ty|:TyKind @@ -740,10 +749,10 @@ checkRWSAction -> RWS -> LamExpr r i -> TyperM r i o (LamExpr r o) checkRWSAction resultTy referentTy effs rws f = do BinaryLamExpr bH bR body <- return f - checkBinderType (TC HeapType) bH \bH' -> do + checkBinderAndDecls (TC HeapType) bH \bH' -> do let h = Var $ binderVar bH' let refTy = RefTy h (sink referentTy) - checkBinderType refTy bR \bR' -> do + checkBinderAndDecls refTy bR \bR' -> do let effs' = extendEffect (RWSEffect rws $ sink h) (sink effs) body' <- checkBlock (EffTy effs' (sink resultTy)) body return $ BinaryLamExpr bH' bR' body' @@ -775,9 +784,9 @@ checkInstantiation abTop xsTop = do Abs bs body <- return $ toAbs abTop go (Abs bs body) xsTop where - go :: Abs (Nest (Binder r)) body o' -> [Atom r o'] -> TyperM r i o' (body o') + go :: Abs (Binders r) body o' -> [Atom r o'] -> TyperM r i o' (body o') go (Abs Empty body) [] = return body - go (Abs (Nest b bs) body) (x:xs) = do + go (Abs (Nest (BD b) bs) body) (x:xs) = do checkTypesEq (getType x) (binderType b) rest <- applySubst (b@>SubstVal x) (Abs bs body) go rest xs diff --git a/src/lib/Core.hs b/src/lib/Core.hs index a7a107b49..f6d53574e 100644 --- a/src/lib/Core.hs +++ b/src/lib/Core.hs @@ -215,6 +215,10 @@ instance IRRep r => BindsEnv (Decl r) where toEnvFrag (Let b binding) = toEnvFrag $ b :> binding {-# INLINE toEnvFrag #-} +instance IRRep r => BindsEnv (BinderAndDecls r) where + toEnvFrag (BD b) = toEnvFrag b + {-# INLINE toEnvFrag #-} + instance BindsEnv EnvFrag where toEnvFrag frag = frag {-# INLINE toEnvFrag #-} @@ -415,7 +419,7 @@ liftLamExpr f (TopLam d ty (LamExpr bs body)) = liftM (TopLam d ty) $ liftEnvRea fromNaryForExpr :: IRRep r => Int -> Expr r n -> Maybe (Int, LamExpr r n) fromNaryForExpr maxDepth | maxDepth <= 0 = error "expected non-negative number of args" fromNaryForExpr maxDepth = \case - PrimOp (Hof (TypedHof _ (For _ _ (UnaryLamExpr b body)))) -> + PrimOp (Hof (TypedHof _ (For _ _ (LamExpr (UnaryNest b) body)))) -> extend <|> (Just $ (1, LamExpr (Nest b Empty) body)) where extend = do diff --git a/src/lib/Export.hs b/src/lib/Export.hs index f7ab3184d..090b090cf 100644 --- a/src/lib/Export.hs +++ b/src/lib/Export.hs @@ -19,6 +19,7 @@ import Foreign.C.String import Foreign.Ptr import Builder +import CheapReduction import Core import Err import IRVariants @@ -119,7 +120,7 @@ goArgs :: (IRRep r) => CallingConvention -> Nest ExportArg o o' -> [CAtomName o'] - -> Nest (WithAttrB Explicitness (Binder r)) i i' + -> Nest (WithAttrB Explicitness (BinderAndDecls r)) i i' -> Type r i' -> ExportSigM r i o' (ExportedSignature o) goArgs cc argSig argVs piBs piRes = case piBs of @@ -128,10 +129,10 @@ goArgs cc argSig argVs piBs piRes = case piBs of StandardCC -> (fromListE $ sink $ ListE argVs) ++ nestToList (sink . binderName) resSig XLACC -> [] _ -> error $ "calling convention not supported: " ++ show cc - Nest (WithAttrB expl (b:>ty)) bs -> do - ety <- toExportType ty + Nest (WithAttrB expl b) bs -> do + ety <- toExportType $ binderType b withFreshBinder (getNameHint b) ety \(v:>_) -> - extendSubst (b @> Rename (binderName v)) $ do + extendSubstBD b [Rename (binderName v)] do vis <- case expl of Explicit -> return ExplicitArg Inferred _ _ -> return ImplicitArg @@ -173,8 +174,8 @@ parseTabTy = go [] go shape = \case BaseTy (Scalar sbt) -> return $ Just $ RectContArrayPtr sbt shape NewtypeTyCon Nat -> return $ Just $ RectContArrayPtr IdxRepScalarBaseTy shape - TabTy d (b:>ixty) a -> do - maybeN <- case IxType ixty d of + TabTy d b a -> do + maybeN <- case IxType (binderType b) d of IxType (NewtypeTyCon (Fin n)) _ -> return $ Just n IxType _ (IxDictRawFin n) -> return $ Just n _ -> return Nothing diff --git a/src/lib/Generalize.hs b/src/lib/Generalize.hs index 58c0721d4..cebb3a690 100644 --- a/src/lib/Generalize.hs +++ b/src/lib/Generalize.hs @@ -14,12 +14,13 @@ import Types.Core import Inference import IRVariants import QueryType +import CheapReduction import Name import Subst import MTL1 import Types.Primitives -type RolePiBinder = WithAttrB RoleExpl CBinder +type RolePiBinder = WithAttrB RoleExpl CBinderAndDecls type RolePiBinders = Nest RolePiBinder generalizeIxDict :: EnvReader m => Atom CoreIR n -> m n (Generalized CoreIR CAtom n) @@ -36,7 +37,7 @@ generalizeArgs fTy argsTop = liftGeneralizerM $ runSubstReaderT idSubst do PairE (CorePiType _ expls bs _) (ListE argsTop') <- sinkM $ PairE fTy (ListE argsTop) ListE <$> go (zipAttrs expls bs) argsTop' where - go :: Nest (WithAttrB Explicitness CBinder) i i' -> [Atom CoreIR n] + go :: Nest (WithAttrB Explicitness CBinderAndDecls) i i' -> [Atom CoreIR n] -> SubstReaderT AtomSubstVal GeneralizerM i n [Atom CoreIR n] go (Nest (WithAttrB expl b) bs) (arg:args) = do ty' <- substM $ binderType b @@ -52,7 +53,7 @@ generalizeArgs fTy argsTop = liftGeneralizerM $ runSubstReaderT idSubst do -- non-type, non-dict arguments (e.g. a function). We just don't -- generalize in that case. return arg - args'' <- extendSubst (b@>SubstVal arg') $ go bs args + args'' <- extendSubstBD b [SubstVal arg'] $ go bs args return $ arg' : args'' go Empty [] = return [] go _ _ = error "zip error" @@ -80,12 +81,12 @@ liftGeneralizerM cont = do return (Abs bs e, vals) where -- OPTIMIZE: something not O(N^2) - hoistGeneralizationVals :: Nest GeneralizationEmission n l -> (Nest (Binder CoreIR) n l, [Atom CoreIR n]) + hoistGeneralizationVals :: Nest GeneralizationEmission n l -> (Binders CoreIR n l, [Atom CoreIR n]) hoistGeneralizationVals Empty = (Empty, []) hoistGeneralizationVals (Nest (GeneralizationEmission b val) bs) = do let (bs', vals) = hoistGeneralizationVals bs case hoist b (ListE vals) of - HoistSuccess (ListE vals') -> (Nest b bs', val:vals') + HoistSuccess (ListE vals') -> (Nest (PlainBD b) bs', val:vals') HoistFailure _ -> error "should't happen" -- when we do the generalization, -- the "local" values we emit never mention the new generalization binders. -- TODO: consider trying to encode this constraint using scope parameters. @@ -130,13 +131,13 @@ traverseTyParams ty f = getDistinct >>= \Distinct -> case ty of Abs paramRoles UnitE <- getClassRoleBinders name params' <- traverseRoleBinders f paramRoles params return $ DictTy $ DictType sn name params' - TabPi (TabPiType (IxDictAtom d) (b:>iTy) resultTy) -> do - iTy' <- f' TypeParam TyKind iTy - dictTy <- liftM ignoreExcept $ runFallibleT1 $ DictTy <$> ixDictType iTy' + TabPi tabTy@(TabPiType (IxDictAtom d) b _) -> do + iTy <- f' TypeParam TyKind $ binderType b + dictTy <- liftM ignoreExcept $ runFallibleT1 $ DictTy <$> ixDictType iTy d' <- f DictParam dictTy d - withFreshBinder (getNameHint b) iTy' \(b':>_) -> do - resultTy' <- applyRename (b@>binderName b') resultTy >>= (f' TypeParam TyKind) - return $ TabTy (IxDictAtom d') (b':>iTy') resultTy' + withFreshBinder (getNameHint b) iTy \b' -> do + resultTy' <- instantiate tabTy [Var $ binderVar b'] >>= (f' TypeParam TyKind) + return $ TabTy (IxDictAtom d') (PlainBD b') resultTy' -- shouldn't need this once we can exclude IxDictFin and IxDictSpecialized from CoreI TabPi t -> return $ TabPi t TC tc -> TC <$> case tc of @@ -178,7 +179,7 @@ traverseRoleBinders f allBinders allParams = ty' <- substM $ binderType b Distinct <- getDistinct param' <- liftSubstReaderT $ f role ty' param - params'' <- extendSubst (b@>SubstVal param') $ go bs params + params'' <- extendSubstBD b [SubstVal param'] $ go bs params return $ param' : params'' go _ _ = error "zip error" {-# INLINE traverseRoleBinders #-} diff --git a/src/lib/Imp.hs b/src/lib/Imp.hs index 9a9ab2e71..86c6fb2db 100644 --- a/src/lib/Imp.hs +++ b/src/lib/Imp.hs @@ -59,18 +59,18 @@ toImpFunction cc (TopLam True destTy lam) = do case cc of EntryFunCC _ -> do argAtoms <- interpretImpArgs (sink $ EmptyAbs bs) vs - extendSubst (bs @@> (SubstVal <$> argAtoms)) do + extendSubstBD bs (SubstVal <$> argAtoms) do dest <- case binderType destB of RefTy _ ansTy -> allocDestUnmanaged =<< substM ansTy _ -> error "Expected a reference type for body destination" - extendSubst (destB @> SubstVal (destToAtom dest)) do + extendSubstBD destB [SubstVal $ destToAtom dest] do void $ translateBlock body resultAtom <- loadAtom dest repValToList <$> atomToRepVal resultAtom _ -> do (argAtoms, resultDest) <- interpretImpArgsWithCC cc (sink ty) vs - extendSubst (bs @@> (SubstVal <$> argAtoms)) do - extendSubst (destB @> SubstVal (destToAtom (sink resultDest))) do + extendSubstBD bs (SubstVal <$> argAtoms) do + extendSubstBD destB [SubstVal $ destToAtom (sink resultDest)] do void $ translateBlock body return [] toImpFunction _ (TopLam False _ _) = error "expected a lambda in destination-passing form" @@ -129,9 +129,9 @@ getNaryLamImpArgTypes t = liftEnvReaderM $ go t where interpretImpArgsWithDest :: EnvReader m => PiType SimpIR n -> [IExpr n] -> m n ([SAtom n], Dest n) interpretImpArgsWithDest t xs = do - (PiType bs (EffTy _ resultTy)) <- return t + piTy@(PiType bs _) <- return t (args, xsLeft) <- _interpretImpArgs (EmptyAbs bs) xs - resultTy' <- applySubst (bs @@> (SubstVal <$> args)) resultTy + EffTy _ resultTy' <- instantiate piTy args (destTree, xsRest) <- listToTree resultTy' xsLeft case xsRest of [] -> return () @@ -139,7 +139,7 @@ interpretImpArgsWithDest t xs = do return (args, Dest resultTy' destTree) interpretImpArgs :: EnvReader m - => EmptyAbs (Nest SBinder) n -> [IExpr n] -> m n [SAtom n] + => EmptyAbs SBinders n -> [IExpr n] -> m n [SAtom n] interpretImpArgs t args = do (args', xsLeft) <- _interpretImpArgs t args case xsLeft of @@ -147,17 +147,17 @@ interpretImpArgs t args = do _ -> error "Shouldn't have any Imp arguments left" _interpretImpArgs :: EnvReader m - => EmptyAbs (Nest SBinder) n -> [IExpr n] -> m n ([SAtom n], [IExpr n]) + => EmptyAbs SBinders n -> [IExpr n] -> m n ([SAtom n], [IExpr n]) _interpretImpArgs t args = liftEnvReaderM $ runSubstReaderT idSubst $ go t args where - go :: EmptyAbs (Nest SBinder) i -> [IExpr o] + go :: EmptyAbs SBinders i -> [IExpr o] -> SubstReaderT AtomSubstVal EnvReaderM i o ([SAtom o], [IExpr o]) go (Abs bs UnitE) xs = case bs of - Nest (b:>argTy) rest -> do - argTy' <- substM argTy - (argTree, xsRest) <- listToTree argTy' xs - arg <- repValAtom $ RepVal argTy' argTree - (args', xsLeft) <- extendSubst (b @> SubstVal arg) $ go (EmptyAbs rest) xsRest + Nest b rest -> do + argTy <- substM $ binderType b + (argTree, xsRest) <- listToTree argTy xs + arg <- repValAtom $ RepVal argTy argTree + (args', xsLeft) <- extendSubstBD b [SubstVal arg] $ go (EmptyAbs rest) xsRest return (arg:args', xsLeft) Empty -> return ([], xs) @@ -358,13 +358,13 @@ toImpRefOp refDest' m = do => (Dest o) -> SType o -> LamExpr SimpIR o -> SAtom o -> SAtom o -> SubstImpM n o () liftMonoidCombine accDest accTy bc x y = do - LamExpr (Nest (_:>baseTy) _) _ <- return bc + LamExpr (Nest b _) _ <- return bc + let baseTy = binderType b alphaEq accTy baseTy >>= \case -- Immediately beta-reduce, beacuse Imp doesn't reduce non-table applications. True -> do - BinaryLamExpr xb yb body <- return bc - body' <- applySubst (xb @> SubstVal x <.> yb @> SubstVal y) body - ans <- liftBuilderImp $ emitBlock (sink body') + body <- instantiate bc [x, y] + ans <- liftBuilderImp $ emitBlock $ sink body storeAtom accDest ans False -> case accTy of TabPi t -> do @@ -374,7 +374,7 @@ toImpRefOp refDest' m = do idx <- unsafeFromOrdinalImp (sink ixTy) i xElt <- liftBuilderImp $ tabApp (sink x) (sink idx) yElt <- liftBuilderImp $ tabApp (sink y) (sink idx) - eltTy <- instantiate (sink t) [idx] + eltTy <- instantiate t [idx] ithDest <- indexDest (sink accDest) idx liftMonoidCombine ithDest eltTy (sink bc) xElt yElt _ -> error $ "Base monoid type mismatch: can't lift " ++ @@ -536,14 +536,12 @@ toImpTypedHof (TypedHof (EffTy _ resultTy') hof) = do emitStatement $ IWhile body' return UnitVal RunReader r f -> do - BinaryLamExpr h ref body <- return f r' <- substM r rDest <- allocDest $ getType r' storeAtom rDest r' - extendSubst (h @> SubstVal (Con HeapVal) <.> ref @> SubstVal (destToAtom rDest)) $ + withInstantiated f [Con HeapVal, destToAtom rDest] \body -> translateBlock body RunWriter d (BaseMonoid e _) f -> do - BinaryLamExpr h ref body <- return f let PairTy ansTy accTy = resultTy (aDest, wDest) <- case d of Nothing -> destPairUnpack <$> allocDest resultTy @@ -554,11 +552,10 @@ toImpTypedHof (TypedHof (EffTy _ resultTy') hof) = do e' <- substM e PairE accTy' e'' <- sinkM $ PairE accTy e' liftMonoidEmpty wDest accTy' e'' - extendSubst (h @> SubstVal (Con HeapVal) <.> ref @> SubstVal (destToAtom wDest)) $ + withInstantiated f [Con HeapVal, destToAtom wDest] \body -> translateBlock body >>= storeAtom aDest PairVal <$> loadAtom aDest <*> loadAtom wDest RunState d s f -> do - BinaryLamExpr h ref body <- return f let PairTy ansTy _ = resultTy (aDest, sDest) <- case d of Nothing -> destPairUnpack <$> allocDest resultTy @@ -567,7 +564,7 @@ toImpTypedHof (TypedHof (EffTy _ resultTy') hof) = do sDest <- atomToDest =<< substM d' return (aDest, sDest) storeAtom sDest =<< substM s - extendSubst (h @> SubstVal (Con HeapVal) <.> ref @> SubstVal (destToAtom sDest)) $ + withInstantiated f [Con HeapVal, destToAtom sDest] \body -> do translateBlock body >>= storeAtom aDest PairVal <$> loadAtom aDest <*> loadAtom sDest RunIO body -> translateBlock body @@ -585,7 +582,7 @@ toImpTypedHof (TypedHof (EffTy _ resultTy') hof) = do emitLoop noHint Fwd n \i -> do idx <- unsafeFromOrdinalImp (sink ixTy) i x' <- sinkM x - eltTy <- instantiate (sink t) [idx] + eltTy <- instantiate t [idx] ithDest <- indexDest (sink accDest) idx liftMonoidEmpty ithDest eltTy x' _ -> error $ "Base monoid type mismatch: can't lift " ++ @@ -610,13 +607,13 @@ type IndexStructure r = EmptyAbs (IdxNest r) :: E pattern Singleton :: IndexStructure r n pattern Singleton = EmptyAbs Empty -type IxBinder r = PairB (LiftB (IxDict r)) (Binder r) +type IxBinder r = PairB (LiftB (IxDict r)) (BinderAndDecls r) type IdxNest r = Nest (IxBinder r) data TypeCtxLayer (r::IR) (n::S) (l::S) where - TabCtx :: IxBinder r n l -> TypeCtxLayer r n l - DepPairCtx :: MaybeB (Binder r) n l -> TypeCtxLayer r n l - RefCtx :: TypeCtxLayer r n n + TabCtx :: IxBinder r n l -> TypeCtxLayer r n l + DepPairCtx :: MaybeB (BinderAndDecls r) n l -> TypeCtxLayer r n l + RefCtx :: TypeCtxLayer r n n instance SinkableE Dest where sinkingProofE = undefined @@ -663,7 +660,7 @@ allNothingBs Empty = Just UnitB allNothingBs (Nest (LeftB _) _) = Nothing allNothingBs (Nest (RightB UnitB) rest) = allNothingBs rest -splitLeadingDepPairs :: TypeCtx SimpIR n l -> PairB (Nest (MaybeB SBinder)) (TypeCtx SimpIR) n l +splitLeadingDepPairs :: TypeCtx SimpIR n l -> PairB (Nest (MaybeB SBinderAndDecls)) (TypeCtx SimpIR) n l splitLeadingDepPairs = \case Empty -> PairB Empty Empty Nest (DepPairCtx b) rest -> case splitLeadingDepPairs rest of @@ -699,9 +696,10 @@ typeToTree tyTop = return $ go REmpty tyTop BaseTy b -> Leaf $ LeafType (unRNest ctx) b TabTy d b bodyTy -> go (RNest ctx (TabCtx (PairB (LiftB d) b))) bodyTy RefTy _ t -> go (RNest ctx RefCtx) t - DepPairTy (DepPairType _ (b:>t1) (t2)) -> do + DepPairTy (DepPairType _ b (t2)) -> do + let t1 = binderType b let tree1 = rec t1 - let tree2 = go (RNest ctx (DepPairCtx (JustB (b:>t1)))) t2 + let tree2 = go (RNest ctx (DepPairCtx (JustB b))) t2 Branch [tree1, tree2] ProdTy ts -> Branch $ map rec ts SumTy ts -> do @@ -735,18 +733,20 @@ valueToTree (RepVal tyTop valTop) = do BaseTy b -> return $ Leaf $ LeafType (unRNest ctx) b TabTy d b bodyTy -> go (RNest ctx (TabCtx (PairB (LiftB d) b))) bodyTy val RefTy _ t -> go (RNest ctx RefCtx) t val - DepPairTy (DepPairType _ (b:>t1) (t2)) -> case val of + DepPairTy dpTy@(DepPairType _ b t2) -> case val of Branch [v1, v2] -> do case allDepPairCtxs (unRNest ctx) of Just UnitB -> do + let t1 = depPairLeftTy dpTy tree1 <- rec t1 v1 x <- repValAtom $ RepVal t1 v1 - t2' <- applySubst (b@>SubstVal x) t2 + t2' <- instantiate dpTy [x] tree2 <- go (RNest ctx (DepPairCtx NothingB )) t2' v2 return $ Branch [tree1, tree2] Nothing -> do + let t1 = depPairLeftTy dpTy tree1 <- rec t1 v1 - tree2 <- go (RNest ctx (DepPairCtx (JustB (b:>t1)))) t2 v2 + tree2 <- go (RNest ctx (DepPairCtx (JustB b))) t2 v2 return $ Branch [tree1, tree2] _ -> error "expected a branch" ProdTy ts -> case val of @@ -1110,8 +1110,8 @@ computeElemCount idxNest' = do elemCountPoly :: Emits n => IndexStructure SimpIR n -> SBuilderM n (Atom SimpIR n) elemCountPoly (Abs bs UnitE) = case bs of Empty -> return $ IdxRepVal 1 - Nest b@(PairB (LiftB d) (_:>t)) rest -> do - curSize <- indexSetSize $ IxType t d + Nest b@(PairB (LiftB d) b') rest -> do + curSize <- indexSetSize $ IxType (binderType b') d restSizes <- computeSizeGivenOrdinal b $ EmptyAbs rest sumUsingPolysImp curSize restSizes @@ -1119,7 +1119,7 @@ computeSizeGivenOrdinal :: EnvReader m => IxBinder SimpIR n l -> IndexStructure SimpIR l -> m n (Abs (Binder SimpIR) (Block SimpIR) n) -computeSizeGivenOrdinal (PairB (LiftB d) (b:>t)) idxStruct = liftBuilder do +computeSizeGivenOrdinal (PairB (LiftB d) (BD (b:>t))) idxStruct = liftBuilder do withFreshBinder noHint IdxRepTy \bOrdinal -> Abs bOrdinal <$> buildBlock do i <- unsafeFromOrdinal (sink $ IxType t d) $ Var $ sink $ binderVar bOrdinal @@ -1212,8 +1212,8 @@ withFreshIBinder hint ty cont = do emitCall :: Emits n => PiType SimpIR n -> ImpFunName n -> [SAtom n] -> SubstImpM i n (SAtom n) -emitCall (PiType bs (EffTy _ resultTy)) f xs = do - resultTy' <- applySubst (bs @@> map SubstVal xs) resultTy +emitCall piTy f xs = do + EffTy _ resultTy' <- instantiate piTy xs dest <- allocDest resultTy' argsImp <- forM xs \x -> repValToList <$> atomToRepVal x destImp <- repValToList <$> atomToRepVal (destToAtom dest) @@ -1426,8 +1426,8 @@ indexSetSizeImp (IxType _ dict) = do appSpecializedIxMethod :: Emits n => SpecDictName n -> IxMethod -> [SAtom n] -> SubstImpM i n (SAtom n) appSpecializedIxMethod d method args = do SpecializedDict _ (Just fs) <- lookupSpecDict d - TopLam _ _ (LamExpr bs body) <- return $ fs !! fromEnum method - dropSubst $ extendSubst (bs @@> map SubstVal args) $ translateBlock body + TopLam _ _ lam <- return $ fs !! fromEnum method + dropSubst $ withInstantiated lam args \body -> translateBlock body -- === Abstracting link-time objects === diff --git a/src/lib/Inference.hs b/src/lib/Inference.hs index cd28d7c5e..b04d8029a 100644 --- a/src/lib/Inference.hs +++ b/src/lib/Inference.hs @@ -182,22 +182,32 @@ class ( MonadFail1 m, Fallible1 m, Catchable1 m, CtxReader1 m, Builder CoreIR m -> (forall l. (EmitsInf l, DExt n l) => CAtomVar l -> m l (e l)) -> m n (Abs CBinder e n) +buildAbsInfWithDecls + :: (InfBuilder m, SinkableE e, HoistableE e, RenameE e, SubstE AtomSubstVal e) + => EmitsInf n + => NameHint -> Explicitness -> CType n + -> (forall l. (EmitsInf l, DExt n l) => CAtomVar l -> m l (e l)) + -> m n (Abs CBinderAndDecls e n) +buildAbsInfWithDecls hint expl ty cont = do + Abs b e <- buildAbsInf hint expl ty cont + return $ Abs (PlainBD b) e + buildAbsInfWithExpl :: (InfBuilder m, SinkableE e, HoistableE e, RenameE e, SubstE AtomSubstVal e) => EmitsInf n => NameHint -> Explicitness -> CType n -> (forall l. (EmitsInf l, DExt n l) => CAtomVar l -> m l (e l)) - -> m n (Abs (WithExpl CBinder) e n) + -> m n (Abs (WithExpl CBinderAndDecls) e n) buildAbsInfWithExpl hint expl ty cont = do Abs b e <- buildAbsInf hint expl ty cont - return $ Abs (WithAttrB expl b) e + return $ Abs (WithAttrB expl (PlainBD b)) e buildNaryAbsInfWithExpl :: (Inferer m, SinkableE e, HoistableE e, RenameE e, SubstE AtomSubstVal e, Inferer m) => EmitsInf n - => [Explicitness] -> EmptyAbs (Nest CBinder) n + => [Explicitness] -> EmptyAbs CBinders n -> (forall l. (EmitsInf l, DExt n l) => [CAtomVar l] -> m i l (e l)) - -> m i n (Abs (Nest (WithExpl CBinder)) e n) + -> m i n (Abs (Nest (WithExpl CBinderAndDecls)) e n) buildNaryAbsInfWithExpl expls bs cont = do Abs bs' e <- buildNaryAbsInf expls bs cont return $ Abs (zipAttrs expls bs') e @@ -205,13 +215,13 @@ buildNaryAbsInfWithExpl expls bs cont = do buildNaryAbsInf :: (SinkableE e, HoistableE e, RenameE e, SubstE AtomSubstVal e, Inferer m) => EmitsInf n - => [Explicitness] -> EmptyAbs (Nest CBinder) n + => [Explicitness] -> EmptyAbs CBinders n -> (forall l. (EmitsInf l, DExt n l) => [CAtomVar l] -> m i l (e l)) - -> m i n (Abs (Nest CBinder) e n) + -> m i n (Abs CBinders e n) buildNaryAbsInf [] (Abs Empty UnitE) cont = getDistinct >>= \Distinct -> Abs Empty <$> cont [] -buildNaryAbsInf (expl:expls) (Abs (Nest (b:>ty) bs) UnitE) cont = - prependAbs <$> buildAbsInf (getNameHint b) expl ty \v -> do - bs' <- applyRename (b@>atomVarName v) (Abs bs UnitE) +buildNaryAbsInf (expl:expls) (Abs (Nest b bs) UnitE) cont = + prependAbs <$> buildAbsInfWithDecls (getNameHint b) expl (binderType b) \v -> do + bs' <- instantiateNames (Abs b $ Abs bs UnitE) [atomVarName v] buildNaryAbsInf expls bs' \vs -> cont (sink v:vs) buildNaryAbsInf _ _ _ = error "zip error" @@ -857,8 +867,8 @@ extendSynthCandidates (Inferred _ (Synth _)) v (Env topEnv (ModuleEnv a b scs)) extendSynthCandidates _ _ env = env {-# INLINE extendSynthCandidates #-} -extendSynthCandidatess :: Distinct n => [Explicitness] -> Nest CBinder n' n -> Env n -> Env n -extendSynthCandidatess (expl:expls) (Nest b bs) env = +extendSynthCandidatess :: Distinct n => [Explicitness] -> CBinders n' n -> Env n -> Env n +extendSynthCandidatess (expl:expls) (Nest (BD b) bs) env = extendSynthCandidatess expls bs env' where env' = extendSynthCandidates expl (withExtEvidence bs $ sink $ binderName b) env extendSynthCandidatess [] Empty env = env @@ -887,9 +897,9 @@ checkSigma hint expr sTy = confuseGHC >>= \_ -> case sTy of expr' <- inferWithoutInstantiation expr >>= zonk dropSubst $ checkOrInferApp expr' explicits [] (Check resultTy) DepPairTy depPairTy -> case depPairTy of - DepPairType ImplicitDepPair (_ :> lhsTy) _ -> do + DepPairType ImplicitDepPair b _ -> do -- TODO: check for the case that we're given some of the implicit dependent pair args explicitly - lhsVal <- Var <$> freshInferenceName MiscInfVar lhsTy + lhsVal <- Var <$> freshInferenceName MiscInfVar (binderType b) -- TODO: make an InfVarDesc case for dep pair instantiation rhsTy <- instantiate depPairTy [lhsVal] rhsVal <- checkSigma noHint expr rhsTy @@ -998,8 +1008,8 @@ checkOrInferRho hint uExprWithSrc@(WithSrcE pos expr) reqTy = do withReducibleEmissions msg $ checkUType rhs UDepPair lhs rhs -> do case reqTy of - Check (DepPairTy ty@(DepPairType _ (_ :> lhsTy) _)) -> do - lhs' <- checkSigmaDependent noHint lhs lhsTy + Check (DepPairTy ty) -> do + lhs' <- checkSigmaDependent noHint lhs (depPairLeftTy ty) rhsTy <- instantiate ty [lhs'] rhs' <- checkSigma noHint rhs rhsTy return $ DepPair lhs' rhs' ty @@ -1226,28 +1236,28 @@ etaExpandExplicits :: EmitsInf o => SourceName -> CorePiType o -> (forall o'. (EmitsBoth o', DExt o o') => [CAtom o'] -> InfererM i o' (CAtom o')) -> InfererM i o (CoreLamExpr o) -etaExpandExplicits fSourceName (CorePiType _ explsTop bsTop (EffTy effs _)) contTop = do +etaExpandExplicits fSourceName piTy@(CorePiType _ explsTop bsTop _) contTop = do Abs bs body <- go explsTop bsTop \xs -> do - effs' <- applySubst (bsTop@@>(SubstVal<$>xs)) effs - withAllowedEffects effs' do + EffTy effs _ <- instantiate piTy xs + withAllowedEffects effs do body <- buildBlockInf $ contTop $ sinkList xs - return $ PairE effs' body + return $ PairE effs body let (expls, bs') = unzipAttrs bs coreLamExpr ExplicitApp expls $ Abs bs' body where go :: (EmitsInf o, SinkableE e, RenameE e, SubstE AtomSubstVal e, HoistableE e ) - => [Explicitness] -> Nest CBinder o any + => [Explicitness] -> CBinders o any -> (forall o'. (EmitsInf o', DExt o o') => [CAtom o'] -> InfererM i o' (e o')) - -> InfererM i o (Abs (Nest (WithExpl CBinder)) e o) + -> InfererM i o (Abs (Nest (WithExpl CBinderAndDecls)) e o) go [] Empty cont = getDistinct >>= \Distinct -> Abs Empty <$> cont [] - go (expl:expls) (Nest (b:>ty) rest) cont = case expl of + go (expl:expls) (Nest b rest) cont = case expl of Explicit -> do - prependAbs <$> buildAbsInfWithExpl (getNameHint b) expl ty \v -> do - Abs rest' UnitE <- applyRename (b@>atomVarName v) $ Abs rest UnitE + prependAbs <$> buildAbsInfWithExpl (getNameHint b) expl (binderType b) \v -> do + Abs rest' UnitE <- instantiateNames (Abs b $ Abs rest UnitE) [atomVarName v] go expls rest' \args -> cont (sink (Var v) : args) Inferred argSourceName infMech -> do - arg <- getImplicitArg (fSourceName, fromMaybe "_" argSourceName) infMech ty - Abs rest' UnitE <- applySubst (b@>SubstVal arg) $ Abs rest UnitE + arg <- getImplicitArg (fSourceName, fromMaybe "_" argSourceName) infMech (binderType b) + Abs rest' UnitE <- instantiate (Abs b $ Abs rest UnitE) [arg] go expls rest' \args -> cont (sink arg : args) go _ _ _ = error "zip error" @@ -1256,23 +1266,23 @@ buildLamInf -> (forall o' . (EmitsBoth o', DExt o o') => [(Explicitness, CAtom o')] -> CType o' -> InfererM i o' (CAtom o')) -> InfererM i o (CoreLamExpr o) -buildLamInf (CorePiType appExpl explsTop bsTop effTy) contTop = do +buildLamInf piTy@(CorePiType appExpl explsTop bsTop _) contTop = do ab <- go explsTop bsTop \xs -> do let (expls, xs') = unzip xs - EffTy effs' resultTy' <- applySubst (bsTop@@>(SubstVal<$>xs')) effTy + EffTy effs' resultTy' <- instantiate piTy xs' withAllowedEffects effs' do body <- buildBlockInf $ contTop (zip expls $ sinkList xs') (sink resultTy') return $ PairE effs' body coreLamExpr appExpl explsTop ab where go :: (EmitsInf o, HoistableE e, SinkableE e, SubstE AtomSubstVal e, RenameE e) - => [Explicitness] -> Nest CBinder o any + => [Explicitness] -> CBinders o any -> (forall o'. (EmitsInf o', DExt o o') => [(Explicitness, CAtom o')] -> InfererM i o' (e o')) - -> InfererM i o (Abs (Nest CBinder) e o) + -> InfererM i o (Abs CBinders e o) go [] Empty cont = getDistinct >>= \Distinct -> Abs Empty <$> cont [] go (expl:expls) (Nest b rest) cont = do - prependAbs <$> buildAbsInf (getNameHint b) expl (binderType b) \v -> do - Abs rest' UnitE <- applyRename (b@>atomVarName v) $ Abs rest UnitE + prependAbs <$> buildAbsInfWithDecls (getNameHint b) expl (binderType b) \v -> do + Abs rest' UnitE <- instantiateNames (Abs b $ Abs rest UnitE) [atomVarName v] go expls rest' \args -> cont $ (expl, sink $ Var v) : args go _ _ _ = error "zip error" @@ -1416,9 +1426,9 @@ applyDataCon tc conIx topArgs = do where nargs = length args; ntys = length tys (curArgs, remArgs) = splitAt (ntys - 1) args - DepPairTy dpt@(DepPairType _ b rty') -> do - rty'' <- applySubst (b@>SubstVal h) rty' - ans <- wrap rty'' t + DepPairTy dpt -> do + rty' <- instantiate dpt [h] + ans <- wrap rty' t return $ DepPair h ans dpt where h:t = args _ -> error $ "Unexpected data con representation type: " ++ pprint rty @@ -1430,33 +1440,35 @@ emitExprWithEffects expr = do checkArity :: [Explicitness] -> [a] -> InfererM i o () checkArity expls args = do - let arity = length [() | Explicit <- expls] + let explArity = length [() | Explicit <- expls] let numArgs = length args - when (numArgs /= arity) do + when (numArgs /= explArity) do throw TypeErr $ "Wrong number of positional arguments provided. Expected " ++ - pprint arity ++ " but got " ++ pprint numArgs + pprint explArity ++ " but got " ++ pprint numArgs -- TODO: check that there are no extra named args provided inferMixedArgs :: forall arg i o e . (ExplicitArg arg, EmitsBoth o, SubstE (SubstVal Atom) e, SinkableE e, HoistableE e) => SourceName -> [Explicitness] - -> Abs (Nest CBinder) e o -> [arg i] -> [(SourceName, arg i)] + -> Abs CBinders e o -> [arg i] -> [(SourceName, arg i)] -> InfererM i o [CAtom o] inferMixedArgs fSourceName explsTop bsAbs posArgs namedArgs = do checkNamedArgValidity explsTop (map fst namedArgs) liftM fst $ runStreamReaderT1 posArgs $ go explsTop bsAbs where go :: (EmitsBoth o, SubstE (SubstVal Atom) e, SinkableE e, HoistableE e) - => [Explicitness] -> Abs (Nest CBinder) e o + => [Explicitness] -> Abs CBinders e o -> StreamReaderT1 (arg i) (InfererM i) o [CAtom o] go [] (Abs Empty _) = return [] go (expl:expls) (Abs (Nest b bs) result) = do let rest = Abs bs result - let isDependent = binderName b `isFreeIn` rest + isDependent <- return case tryAsConst (Abs b rest) of + Nothing -> True + Just _ -> False arg <- inferMixedArg isDependent (binderType b) expl arg' <- lift11 $ zonk arg - rest' <- applySubst (b @> SubstVal arg') rest + rest' <- instantiate (Abs b rest) [arg'] (arg:) <$> go expls rest' go _ _ = error "zip error" @@ -1541,7 +1553,7 @@ matchPrimApp = \case lam1 :: Fallible m => CAtom n -> m (LamExpr CoreIR n) lam1 x = do ExplicitCoreLam (UnaryNest b) body <- return x - return $ UnaryLamExpr b body + return $ LamExpr (UnaryNest b) body lam0 :: Fallible m => CAtom n -> m (CBlock n) lam0 x = do @@ -1558,7 +1570,7 @@ matchPrimApp = \case _ -> return $ Right x return $ fromJust $ toOp $ GenericOpRep op tyArgs dataArgs [] -pattern ExplicitCoreLam :: Nest CBinder n l -> CBlock l -> CAtom n +pattern ExplicitCoreLam :: CBinders n l -> CBlock l -> CAtom n pattern ExplicitCoreLam bs body <- Lam (CoreLamExpr _ (LamExpr bs body)) -- === n-ary applications === @@ -1575,14 +1587,16 @@ inferNaryTabAppArgs => CType o -> [UExpr i] -> InfererM i o [CAtom o] inferNaryTabAppArgs _ [] = return [] inferNaryTabAppArgs tabTy (arg:rest) = do - TabPiType _ b resultTy <- fromTabPiType True tabTy + tabPi@(TabPiType _ b _) <- fromTabPiType True tabTy let ixTy = binderType b - let isDependent = binderName b `isFreeIn` resultTy + isDependent <- return case tryAsConst tabPi of + Nothing -> True + Just _ -> False arg' <- if isDependent - then checkSigmaDependent (getNameHint b) arg ixTy - else checkSigma (getNameHint b) arg ixTy + then checkSigmaDependent (getNameHint tabPi) arg ixTy + else checkSigma (getNameHint tabPi) arg ixTy arg'' <- zonk arg' - resultTy' <- applySubst (b @> SubstVal arg'') resultTy + resultTy' <- instantiate tabPi [arg''] rest' <- inferNaryTabAppArgs resultTy' rest return $ arg'':rest' @@ -1665,8 +1679,7 @@ instanceFun :: EnvReader m => InstanceName n -> AppExplicitness -> m n (CAtom n) instanceFun instanceName appExpl = do InstanceDef _ expls bs _ _ <- lookupInstanceDef instanceName ab <- liftEnvReaderM $ refreshAbs (Abs bs UnitE) \bs' UnitE -> do - args <- mapM toAtomVar $ nestToNames bs' - result <- mkDictAtom $ InstanceDict (sink instanceName) (Var <$> args) + result <- mkDictAtom $ InstanceDict (sink instanceName) (Var <$> bindersVars bs') return $ Abs bs' (PairE Pure (WithoutDecls result)) Lam <$> coreLamExpr appExpl (snd<$>expls) ab @@ -1726,7 +1739,7 @@ inferDotMethod tc (Abs uparamBs (Abs selfB lam)) = do Abs paramBs'' (Abs selfB' lam') <- return ab return $ prependCoreLamExpr (paramBs'' >>> UnaryNest selfB') lam' -prependCoreLamExpr :: Nest (WithExpl CBinder) n l -> CoreLamExpr l -> CoreLamExpr n +prependCoreLamExpr :: Nest (WithExpl CBinderAndDecls) n l -> CoreLamExpr l -> CoreLamExpr n prependCoreLamExpr bs e = case e of CoreLamExpr (CorePiType appExpl piExpls piBs effTy) (LamExpr lamBs body) -> do let (expls, bs') = unzipAttrs bs @@ -1742,12 +1755,12 @@ inferDataCon (sourceName, UDataDefTrail argBs) = do let (repTy, projIdxs) = dataConRepTy argBs'' return $ DataConDef sourceName argBs'' repTy projIdxs -dataConRepTy :: EmptyAbs (Nest CBinder) n -> (CType n, [[Projection]]) +dataConRepTy :: EmptyAbs CBinders n -> (CType n, [[Projection]]) dataConRepTy (Abs topBs UnitE) = case topBs of Empty -> (UnitTy, []) _ -> go [] [UnwrapNewtype] topBs where - go :: [CType l] -> [Projection] -> Nest (Binder CoreIR) l p -> (CType l, [[Projection]]) + go :: [CType l] -> [Projection] -> CBinders l p -> (CType l, [[Projection]]) go revAcc projIdxs = \case Empty -> case revAcc of [] -> error "should never happen" @@ -1789,8 +1802,8 @@ inferClassDef className methodNames paramBs@(expls, paramBs') methods = do return $ ClassDef className methodNames paramNames roleExpls bs' scs mtys identifySuperclasses - :: RenameE e => Abs (Nest (WithRoleExpl CBinder)) e n - -> InfererM i n (Abs (PairB (Nest (WithRoleExpl CBinder)) (Nest CBinder)) e n) + :: RenameE e => Abs (Nest (WithRoleExpl CBinderAndDecls)) e n + -> InfererM i n (Abs (PairB (Nest (WithRoleExpl CBinderAndDecls)) CBinders) e n) identifySuperclasses ab = do refreshAbs ab \bs e -> do bs' <- partitionBinders bs \b@(WithAttrB (_, expl) b') -> case expl of @@ -1803,7 +1816,7 @@ withUBinders :: (EmitsInf o, HasNamesE e, SubstE AtomSubstVal e, SinkableE e) => UAnnExplBinders req i i' -> (forall o'. (EmitsInf o', DExt o o') => [CAtomVar o'] -> InfererM i' o' (e o')) - -> InfererM i o (Abs (Nest (WithExpl CBinder)) e o) + -> InfererM i o (Abs (Nest (WithExpl CBinderAndDecls)) e o) withUBinders bs cont = case bs of ([], Empty) -> getDistinct >>= \Distinct -> Abs Empty <$> cont [] (expl:expls, Nest (UAnnBinder b ann cs) rest) -> do @@ -1819,7 +1832,7 @@ withConstraintBinders => [UConstraint i] -> CAtomVar o -> (forall o'. (EmitsInf o', DExt o o') => InfererM i o' (e o')) - -> InfererM i o (Abs (Nest (WithExpl CBinder)) e o) + -> InfererM i o (Abs (Nest (WithExpl CBinderAndDecls)) e o) withConstraintBinders [] _ cont = getDistinct >>= \Distinct -> Abs Empty <$> cont withConstraintBinders (c:cs) v cont = do Type dictTy <- withReducibleEmissions "Can't reduce interface constraint" do @@ -1832,12 +1845,12 @@ withRoleUBinders :: forall i i' o e req. (EmitsInf o, HasNamesE e, SubstE AtomSubstVal e, SinkableE e) => UAnnExplBinders req i i' -> (forall o'. (EmitsInf o', DExt o o') => InfererM i' o' (e o')) - -> InfererM i o (Abs (Nest (WithRoleExpl CBinder)) e o) + -> InfererM i o (Abs (Nest (WithRoleExpl CBinderAndDecls)) e o) withRoleUBinders roleBs cont = case roleBs of ([], Empty) -> getDistinct >>= \Distinct -> Abs Empty <$> cont (expl:expls, Nest (UAnnBinder b ann cs) rest) -> do ann' <- checkAnn (getSourceName b) ann - Abs b' (Abs bs' e) <- buildAbsInf (getNameHint b) expl ann' \v -> do + Abs b' (Abs bs' e) <- buildAbsInfWithDecls (getNameHint b) expl ann' \v -> do Abs ds (Abs bs' e) <- withConstraintBinders cs v $ extendSubst (b@>sink (atomVarName v)) $ withRoleUBinders (expls, rest) cont let ds' = fmapNest (\(WithAttrB expl' b') -> WithAttrB (DictParam, expl') b') ds @@ -1865,30 +1878,29 @@ inferULam (ULamExpr bs appExpl effs resultTy body) = do ExplicitApp -> return () coreLamExpr appExpl expls $ Abs bs'' $ PairE effs' body' -checkImplicitLamRestrictions :: Nest CBinder o o' -> EffectRow CoreIR o' -> InfererM i o () +checkImplicitLamRestrictions :: CBinders o o' -> EffectRow CoreIR o' -> InfererM i o () checkImplicitLamRestrictions _ _ = return () -- TODO checkUForExpr :: EmitsBoth o => UForExpr i -> TabPiType CoreIR o -> InfererM i o (LamExpr CoreIR o) checkUForExpr (UForExpr (UAnnBinder bFor ann cs) body) tabPi@(TabPiType _ bPi _) = do unless (null cs) $ throw TypeErr "`for` binders shouldn't have constraints" - let iTy = binderAnn bPi + let iTy = binderType bPi case ann of UNoAnn -> return () UAnn forAnn -> checkUType forAnn >>= constrainTypesEq iTy - Abs b body' <- buildAbsInf (getNameHint bFor) Explicit iTy \i -> do + Abs b body' <- buildAbsInfWithDecls (getNameHint bFor) Explicit iTy \i -> do extendRenamer (bFor@>atomVarName i) do - TabPiType _ bPi' resultTy <- sinkM tabPi - resultTy' <- applyRename (bPi'@>atomVarName i) resultTy + resultTy <- instantiate tabPi [Var i] buildBlockInf do withBlockDecls body \result -> - checkSigma noHint result $ sink resultTy' + checkSigma noHint result $ sink resultTy return $ LamExpr (UnaryNest b) body' inferUForExpr :: EmitsBoth o => UForExpr i -> InfererM i o (LamExpr CoreIR o) inferUForExpr (UForExpr (UAnnBinder bFor ann cs) body) = do unless (null cs) $ throw TypeErr "`for` binders shouldn't have constraints" iTy <- checkAnn (getSourceName bFor) ann - Abs b body' <- buildAbsInf (getNameHint bFor) Explicit iTy \i -> + Abs b body' <- buildAbsInfWithDecls (getNameHint bFor) Explicit iTy \i -> extendRenamer (bFor@>atomVarName i) $ buildBlockInf $ withBlockDecls body \result -> checkOrInferRho noHint result Infer @@ -1896,12 +1908,12 @@ inferUForExpr (UForExpr (UAnnBinder bFor ann cs) body) = do checkULam :: EmitsInf o => ULamExpr i -> CorePiType o -> InfererM i o (CoreLamExpr o) checkULam (ULamExpr (_, lamBs) lamAppExpl lamEffs lamResultTy body) - (CorePiType piAppExpl expls piBs effTy) = do + piTy@(CorePiType piAppExpl expls piBs _) = do checkArity expls (nestToList (const ()) lamBs) when (piAppExpl /= lamAppExpl) $ throw TypeErr $ "Wrong arrow. Expected " ++ pprint piAppExpl ++ " got " ++ pprint lamAppExpl Abs explBs body' <- checkLamBinders expls piBs lamBs \vs -> do - EffTy piEffs' piResultTy' <- applyRename (piBs@@>map atomVarName vs) effTy + EffTy piEffs' piResultTy' <- instantiateNames piTy (atomVarName <$> vs) case lamResultTy of Nothing -> return () Just t -> checkUType t >>= constrainTypesEq piResultTy' @@ -1919,16 +1931,17 @@ checkULam (ULamExpr (_, lamBs) lamAppExpl lamEffs lamResultTy body) checkLamBinders :: (EmitsInf o, SinkableE e, HoistableE e, SubstE AtomSubstVal e, RenameE e) - => [Explicitness] -> Nest CBinder o any + => [Explicitness] -> CBinders o any -> Nest UOptAnnBinder i i' -> (forall o'. (EmitsInf o', DExt o o') => [CAtomVar o'] -> InfererM i' o' (e o')) - -> InfererM i o (Abs (Nest (WithExpl CBinder)) e o) + -> InfererM i o (Abs (Nest (WithExpl CBinderAndDecls)) e o) checkLamBinders [] Empty Empty cont = getDistinct >>= \Distinct -> Abs Empty <$> cont [] -checkLamBinders (piExpl:piExpls) (Nest (piB:>piAnn) piBs) lamBs cont = do +checkLamBinders (piExpl:piExpls) (Nest piB piBs) lamBs cont = do + let piAnn = binderType piB prependAbs <$> case piExpl of Inferred _ _ -> buildAbsInfWithExpl (getNameHint piB) piExpl piAnn \v -> do - Abs piBs' UnitE <- applyRename (piB@>atomVarName v) $ Abs piBs UnitE + Abs piBs' UnitE <- instantiateNames (Abs piB $ Abs piBs UnitE) [atomVarName v] checkLamBinders piExpls piBs' lamBs \vs -> cont (sink v:vs) Explicit -> case lamBs of @@ -1938,22 +1951,22 @@ checkLamBinders (piExpl:piExpls) (Nest (piB:>piAnn) piBs) lamBs cont = do UNoAnn -> return () buildAbsInfWithExpl (getNameHint lamB) Explicit piAnn \v -> do concatAbs <$> withConstraintBinders cs v do - Abs piBs' UnitE <- applyRename (piB@>sink (atomVarName v)) $ Abs piBs UnitE + Abs piBs' UnitE <- instantiateNames (Abs piB $ Abs piBs UnitE) [sink $ atomVarName v] extendRenamer (lamB@>sink (atomVarName v)) $ checkLamBinders piExpls piBs' lamBsRest \vs -> cont (sink v:vs) Empty -> error "zip error" checkLamBinders _ _ _ _ = error "zip error" -checkInstanceParams :: EmitsInf o => [Explicitness] -> Nest CBinder o any -> [UExpr i] -> InfererM i o [CAtom o] +checkInstanceParams :: EmitsInf o => [Explicitness] -> CBinders o any -> [UExpr i] -> InfererM i o [CAtom o] checkInstanceParams expls bsTop paramsTop = do checkArity expls paramsTop go bsTop paramsTop where - go :: EmitsInf o => Nest CBinder o any -> [UExpr i] -> InfererM i o [CAtom o] + go :: EmitsInf o => CBinders o any -> [UExpr i] -> InfererM i o [CAtom o] go Empty [] = return [] - go (Nest (b:>ty) bs) (x:xs) = do - x' <- checkUParam ty x - Abs bs' UnitE <- applySubst (b@>SubstVal x') $ Abs bs UnitE + go (Nest b bs) (x:xs) = do + x' <- checkUParam (binderType b) x + Abs bs' UnitE <- instantiate (Abs b $ Abs bs UnitE) [x'] (x':) <$> go bs' xs go _ _ = error "zip error" @@ -1961,11 +1974,11 @@ checkInstanceBody :: EmitsInf o => ClassName o -> [CAtom o] -> [UMethodDef i] -> InfererM i o (InstanceBody o) checkInstanceBody className params methods = do - ClassDef _ methodNames _ _ paramBs scBs methodTys <- lookupClassDef className - Abs scBs' methodTys' <- applySubst (paramBs @@> (SubstVal <$> params)) $ Abs scBs $ ListE methodTys + classDef@(ClassDef _ methodNames _ _ _ _ _) <- lookupClassDef className + superclassAbs@(Abs scBs' _) <- instantiate classDef params superclassTys <- superclassDictTys scBs' superclassDicts <- mapM (flip trySynthTerm Full) superclassTys - ListE methodTys'' <- applySubst (scBs'@@>(SubstVal<$>superclassDicts)) methodTys' + ListE methodTys'' <- instantiate superclassAbs superclassDicts methodsChecked <- mapM (checkMethodDef className methodTys'') methods let (idxs, methods') = unzip $ sortOn fst $ methodsChecked forM_ (repeated idxs) \i -> @@ -1974,7 +1987,7 @@ checkInstanceBody className params methods = do throw TypeErr $ "Missing method: " ++ pprint (methodNames!!i) return $ InstanceBody superclassDicts methods' -superclassDictTys :: Nest CBinder o o' -> InfererM i o [CType o] +superclassDictTys :: CBinders o o' -> InfererM i o [CType o] superclassDictTys Empty = return [] superclassDictTys (Nest b bs) = do Abs bs' UnitE <- liftHoistExcept $ hoist b $ Abs bs UnitE @@ -2057,20 +2070,20 @@ checkCasePat (WithSrcB pos pat) scrutineeTy cont = addSrcContext pos $ case pat _ -> throw TypeErr $ "Case patterns must start with a data constructor or variant pattern" inferParams :: (EmitsBoth o, HasNamesE e, SinkableE e, SubstE AtomSubstVal e) - => SourceName -> [RoleExpl] -> Abs (Nest CBinder) e o -> InfererM i o (TyConParams o, e o) + => SourceName -> [RoleExpl] -> Abs CBinders e o -> InfererM i o (TyConParams o, e o) inferParams sourceName roleExpls (Abs paramBs bodyTop) = do let expls = snd <$> roleExpls (params, e') <- go expls (Abs paramBs bodyTop) return (TyConParams expls params, e') where go :: (EmitsBoth o, HasNamesE e, SinkableE e, SubstE AtomSubstVal e) - => [Explicitness] -> Abs (Nest CBinder) e o -> InfererM i o ([CAtom o], e o) + => [Explicitness] -> Abs CBinders e o -> InfererM i o ([CAtom o], e o) go [] (Abs Empty body) = return ([], body) - go (expl:expls) (Abs (Nest (b:>ty) bs) body) = do + go (expl:expls) (Abs (Nest b bs) body) = do x <- case expl of - Explicit -> Var <$> freshInferenceName (TypeInstantiationInfVar sourceName) ty - Inferred argName infMech -> getImplicitArg (sourceName, fromMaybe "_" argName) infMech ty - rest <- applySubst (b@>SubstVal x) $ Abs bs body + Explicit -> Var <$> freshInferenceName (TypeInstantiationInfVar sourceName) (binderType b) + Inferred argName infMech -> getImplicitArg (sourceName, fromMaybe "_" argName) infMech (binderType b) + rest <- instantiate (Abs b $ Abs bs body) [x] (params, body') <- go expls rest return (x:params, body') go _ _ = error "zip error" @@ -2162,21 +2175,21 @@ inferTabCon hint xs reqTy = do xs' <- forM xs \x -> checkRho noHint x elemTy dTy <- DictTy <$> dataDictType elemTy liftM Var $ emitHinted hint $ TabCon (dataDictHole dTy) tabTy xs' - Check tabTy -> do - TabPiType _ b elemTy <- fromTabPiType True tabTy + Check ty -> do + tabTy@(TabPiType _ b _) <- fromTabPiType True ty constrainTypesEq (binderType b) finTy xs' <- forM (enumerate xs) \(i, x) -> do let i' = NewtypeCon (FinCon (NatVal n)) (NatVal $ fromIntegral i) :: CAtom o - elemTy' <- applySubst (b@>SubstVal i') elemTy + elemTy' <- instantiate tabTy [i'] checkRho noHint x elemTy' - dTy <- case hoist b elemTy of - HoistSuccess elemTy' -> DictTy <$> dataDictType elemTy' - HoistFailure _ -> ignoreExcept <$> liftEnvReaderT do + dTy <- case tryAsConst tabTy of + Nothing -> ignoreExcept <$> liftEnvReaderT do withFreshBinder noHint finTy \b' -> do - elemTy' <- applyRename (b@>binderName b') elemTy + elemTy' <- instantiate tabTy [Var $ binderVar b'] dTy <- DictTy <$> dataDictType elemTy' - return $ Pi $ CorePiType ImplicitApp [Inferred Nothing Unify] (UnaryNest b') (EffTy Pure dTy) - liftM Var $ emitHinted hint $ TabCon (dataDictHole dTy) tabTy xs' + return $ Pi $ CorePiType ImplicitApp [Inferred Nothing Unify] (UnaryNest (BD b')) (EffTy Pure dTy) + Just elemTy' -> DictTy <$> dataDictType elemTy' + liftM Var $ emitHinted hint $ TabCon (dataDictHole dTy) ty xs' -- Bool flag is just to tweak the reported error message fromTabPiType :: EmitsBoth o => Bool -> CType o -> InfererM i o (TabPiType CoreIR o) @@ -2536,16 +2549,16 @@ instance Unifiable CorePiType where go (Abs bsTop1 effTy1) (Abs bsTop2 effTy2) where go :: EmitsInf n - => Abs (Nest CBinder) (EffTy CoreIR) n - -> Abs (Nest CBinder) (EffTy CoreIR) n + => Abs CBinders (EffTy CoreIR) n + -> Abs CBinders (EffTy CoreIR) n -> SolverM n () go (Abs Empty (EffTy e1 t1)) (Abs Empty (EffTy e2 t2)) = unify t1 t2 >> unify e1 e2 - go (Abs (Nest (b1:>t1) bs1) rest1) - (Abs (Nest (b2:>t2) bs2) rest2) = do - unify t1 t2 - v <- freshSkolemName t1 - ab1 <- zonk =<< applySubst (b1@>SubstVal (Var v)) (Abs bs1 rest1) - ab2 <- zonk =<< applySubst (b2@>SubstVal (Var v)) (Abs bs2 rest2) + go (Abs (Nest b1 bs1) rest1) + (Abs (Nest b2 bs2) rest2) = do + unify (binderType b1) (binderType b2) + v <- freshSkolemName (binderType b1) + ab1 <- zonk =<< instantiate (Abs b1 (Abs bs1 rest1)) [Var v] + ab2 <- zonk =<< instantiate (Abs b2 (Abs bs2 rest2)) [Var v] go ab1 ab2 go _ _ = empty @@ -2555,8 +2568,8 @@ unifyTabPiType (TabPiType _ b1 ty1) (TabPiType _ b2 ty2) = do let ann2 = binderType b2 unify ann1 ann2 v <- freshSkolemName ann1 - ty1' <- applySubst (b1@>SubstVal (Var v)) ty1 - ty2' <- applySubst (b2@>SubstVal (Var v)) ty2 + ty1' <- instantiate (Abs b1 ty1) [Var v] + ty2' <- instantiate (Abs b2 ty2) [Var v] unify ty1' ty2' extendSolution :: CAtomName n -> CAtom n -> SolverM n () @@ -2664,9 +2677,9 @@ generalizeDictRec dict = do DataData ty -> DataData <$> TyVar <$> freshInferenceName MiscInfVar ty where notSimplifiedDict = error $ "Not a simplified dict: " ++ pprint dict -generalizeInstanceArgs :: EmitsInf n => [RoleExpl] -> Nest CBinder n l -> [CAtom n] -> SolverM n [CAtom n] +generalizeInstanceArgs :: EmitsInf n => [RoleExpl] -> CBinders n l -> [CAtom n] -> SolverM n [CAtom n] generalizeInstanceArgs [] Empty [] = return [] -generalizeInstanceArgs ((role,_):expls) (Nest (b:>ty) bs) (arg:args) = do +generalizeInstanceArgs ((role,_):expls) (Nest b bs) (arg:args) = do arg' <- case role of -- XXX: for `TypeParam` we can just emit a fresh inference name rather than -- traversing the whole type like we do in `Generalize.hs`. The reason is @@ -2674,9 +2687,9 @@ generalizeInstanceArgs ((role,_):expls) (Nest (b:>ty) bs) (arg:args) = do -- fresh dictionary, and if we were to do that, we would infer this type -- parameter exactly as we do here, using inference. TypeParam -> Var <$> freshInferenceName MiscInfVar TyKind - DictParam -> generalizeDictAndUnify ty arg - DataParam -> Var <$> freshInferenceName MiscInfVar ty - Abs bs' UnitE <- applySubst (b@>SubstVal arg') (Abs bs UnitE) + DictParam -> generalizeDictAndUnify (binderType b) arg + DataParam -> Var <$> freshInferenceName MiscInfVar (binderType b) + Abs bs' UnitE <- instantiate (Abs b (Abs bs UnitE)) [arg'] args' <- generalizeInstanceArgs expls bs' args return $ arg':args' generalizeInstanceArgs _ _ _ = error "zip error" @@ -2703,9 +2716,9 @@ pattern InstanceDefAbsBody :: [CAtom n] -> [CAtom n] -> [CAtom n] -> [CAtom n] pattern InstanceDefAbsBody params superclasses doneMethods todoMethods = ListE params `PairE` (ListE superclasses) `PairE` (ListE doneMethods) `PairE` (ListE todoMethods) -type InstanceDefAbsT n = ([RoleExpl], Abs (Nest CBinder) InstanceDefAbsBodyT n) +type InstanceDefAbsT n = ([RoleExpl], Abs CBinders InstanceDefAbsBodyT n) -pattern InstanceDefAbs :: [RoleExpl] -> Nest CBinder h n -> [CAtom n] -> [CAtom n] -> [CAtom n] -> [CAtom n] +pattern InstanceDefAbs :: [RoleExpl] -> CBinders h n -> [CAtom n] -> [CAtom n] -> [CAtom n] -> [CAtom n] -> InstanceDefAbsT h pattern InstanceDefAbs expls bs params superclasses doneMethods todoMethods = (expls, Abs bs (InstanceDefAbsBody params superclasses doneMethods todoMethods)) @@ -2760,7 +2773,7 @@ trySynthTerm ty reqMethodAccess = do {-# SCC trySynthTerm #-} type SynthAtom = CAtom -type SynthPiType n = ([Explicitness], Abs (Nest CBinder) DictType n) +type SynthPiType n = ([Explicitness], Abs CBinders DictType n) data SynthType n = SynthDictType (DictType n) | SynthPiType (SynthPiType n) @@ -2881,7 +2894,7 @@ synthTerm targetTy reqMethodAccess = confuseGHC >>= \_ -> case targetTy of {-# SCC synthTerm #-} coreLamExpr :: EnvReader m => AppExplicitness - -> [Explicitness] -> Abs (Nest CBinder) (PairE (EffectRow CoreIR) CBlock) n + -> [Explicitness] -> Abs CBinders (PairE (EffectRow CoreIR) CBlock) n -> m n (CoreLamExpr n) coreLamExpr appExpl expls ab = liftEnvReaderM do refreshAbs ab \bs' (PairE effs' body') -> do @@ -2889,20 +2902,20 @@ coreLamExpr appExpl expls ab = liftEnvReaderM do return $ CoreLamExpr (CorePiType appExpl expls bs' (EffTy effs' resultTy)) (LamExpr bs' body') withGivenBinders - :: (SinkableE e, RenameE e) => [Explicitness] -> Abs (Nest CBinder) e n - -> (forall l. DExt n l => Nest CBinder n l -> e l -> SyntherM l a) + :: (SinkableE e, RenameE e) => [Explicitness] -> Abs CBinders e n + -> (forall l. DExt n l => CBinders n l -> e l -> SyntherM l a) -> SyntherM n a withGivenBinders explsTop (Abs bsTop e) contTop = runSubstReaderT idSubst $ go explsTop bsTop \bsTop' -> do e' <- renameM e liftSubstReaderT $ contTop bsTop' e' where - go :: [Explicitness] -> Nest CBinder i i' - -> (forall o'. DExt o o' => Nest CBinder o o' -> SubstReaderT Name SyntherM i' o' a) + go :: [Explicitness] -> CBinders i i' + -> (forall o'. DExt o o' => CBinders o o' -> SubstReaderT Name SyntherM i' o' a) -> SubstReaderT Name SyntherM i o a go expls bs cont = case (expls, bs) of ([], Empty) -> getDistinct >>= \Distinct -> cont Empty - (expl:explsRest, Nest b rest) -> do + (expl:explsRest, Nest (BD b) rest) -> do argTy <- renameM $ binderType b withFreshBinder (getNameHint b) argTy \b' -> do givens <- case expl of @@ -2911,7 +2924,7 @@ withGivenBinders explsTop (Abs bsTop e) contTop = s <- getSubst liftSubstReaderT $ extendGivens givens $ runSubstReaderT (s <>> b@>binderName b') $ - go explsRest rest \rest' -> cont (Nest b' rest') + go explsRest rest \rest' -> cont (Nest (BD b') rest') _ -> error "zip error" isMethodAccessAllowedBy :: EnvReader m => RequiredMethodAccess -> InstanceName n -> m n Bool @@ -2954,7 +2967,7 @@ instantiateSynthArgs targetTop (explsTop, Abs bsTop resultTyTop) = do arg -> return arg where go :: EmitsInf o - => DictType o -> [Explicitness] -> Abs (Nest CBinder) DictType i + => DictType o -> [Explicitness] -> Abs CBinders DictType i -> SubstReaderT AtomSubstVal SolverM i o [CAtom o] go target allExpls (Abs bs proposed) = case (allExpls, bs) of ([], Empty) -> do @@ -2967,7 +2980,7 @@ instantiateSynthArgs targetTop (explsTop, Abs bsTop resultTyTop) = do Explicit -> error "instances shouldn't have explicit args" Inferred _ Unify -> Var <$> freshInferenceName MiscInfVar argTy Inferred _ (Synth req) -> return $ DictHole (AlwaysEqual emptySrcPosCtx) argTy req - liftM (arg:) $ extendSubst (b@>SubstVal arg) $ go target expls (Abs rest proposed) + liftM (arg:) $ extendSubstBD b [SubstVal arg] $ go target expls (Abs rest proposed) _ -> error "zip error" synthDictForData :: forall n. DictType n -> SyntherM n (SynthAtom n) @@ -2976,8 +2989,8 @@ synthDictForData dictTy@(DictType "Data" dName [Type ty]) = case ty of -- The "Var" case is different TyVar _ -> synthDictFromGiven dictTy TabPi (TabPiType _ b eltTy) -> recurBinder (Abs b eltTy) >> success - DepPairTy (DepPairType _ b@(_:>l) r) -> do - recur l >> recurBinder (Abs b r) >> success + DepPairTy (DepPairType _ b r) -> do + recur (binderType b) >> recurBinder (Abs b r) >> success NewtypeTyCon nt -> do (_, ty') <- unwrapNewtypeType nt recur ty' >> success @@ -3075,16 +3088,16 @@ instance DictSynthTraversable CType where instance DictSynthTraversable DataConDefs where dsTraverse = visitGeneric dsTraverseExplBinders - :: [Explicitness] -> Nest CBinder i i' - -> (forall o'. DExt o o' => Nest CBinder o o' -> DictSynthTraverserM i' o' a) + :: [Explicitness] -> CBinders i i' + -> (forall o'. DExt o o' => CBinders o o' -> DictSynthTraverserM i' o' a) -> DictSynthTraverserM i o a dsTraverseExplBinders [] Empty cont = getDistinct >>= \Distinct -> cont Empty -dsTraverseExplBinders (expl:expls) (Nest b bs) cont = do +dsTraverseExplBinders (expl:expls) (Nest (BD b) bs) cont = do ty <- dsTraverse $ binderType b withFreshBinder (getNameHint b) ty \b' -> do let v = binderName b' extendSynthCandidatesDict expl v $ extendRenamer (b@>v) do - dsTraverseExplBinders expls bs \bs' -> cont $ Nest b' bs' + dsTraverseExplBinders expls bs \bs' -> cont $ Nest (BD b') bs' dsTraverseExplBinders _ _ _ = error "zip error" extendSynthCandidatesDict :: Explicitness -> CAtomName n -> DictSynthTraverserM i n a -> DictSynthTraverserM i n a @@ -3141,7 +3154,7 @@ buildTabPiInf -> InfererM i n (TabPiType CoreIR n) buildTabPiInf hint (IxType t d) body = do Abs b resultTy <- buildAbsInf hint Explicit t \v -> withoutEffects $ body v - return $ TabPiType d b resultTy + return $ TabPiType d (PlainBD b) resultTy buildDepPairTyInf :: EmitsInf n @@ -3150,7 +3163,7 @@ buildDepPairTyInf -> InfererM i n (DepPairType CoreIR n) buildDepPairTyInf hint expl ty body = do Abs b resultTy <- buildAbsInf hint Explicit ty body - return $ DepPairType expl b resultTy + return $ DepPairType expl (PlainBD b) resultTy buildAltInf :: EmitsInf n @@ -3270,7 +3283,7 @@ instance BindsEnv InfOutFrag where toEnvFrag (InfOutFrag frag _ _) = toEnvFrag frag instance GenericE SynthType where - type RepE SynthType = EitherE2 DictType (PairE (LiftE [Explicitness]) (Abs (Nest CBinder) DictType)) + type RepE SynthType = EitherE2 DictType (PairE (LiftE [Explicitness]) (Abs CBinders DictType)) fromE (SynthDictType d) = Case0 d fromE (SynthPiType (expl, t)) = Case1 (PairE (LiftE expl) t) toE (Case0 d) = SynthDictType d diff --git a/src/lib/Inline.hs b/src/lib/Inline.hs index bcf5bdc64..75039399b 100644 --- a/src/lib/Inline.hs +++ b/src/lib/Inline.hs @@ -319,15 +319,15 @@ instance Inlinable SLam where inlineDecls decls $ inline Stop ans withBinders - :: Nest SBinder i i' - -> (forall o'. DExt o o' => Nest SBinder o o' -> InlineM i' o' a) + :: SBinders i i' + -> (forall o'. DExt o o' => SBinders o o' -> InlineM i' o' a) -> InlineM i o a withBinders Empty cont = getDistinct >>= \Distinct -> cont Empty -withBinders (Nest (b:>ty) bs) cont = do +withBinders (Nest (BD (b:>ty)) bs) cont = do ty' <- buildScopedAssumeNoDecls $ inline Stop ty withFreshBinder (getNameHint b) ty' \b' -> extendRenamer (b@>binderName b') do - withBinders bs \bs' -> cont $ Nest b' bs' + withBinders bs \bs' -> cont $ Nest (BD b') bs' instance Inlinable (PiType SimpIR) where inline ctx (PiType bs effTy) = @@ -381,8 +381,7 @@ reconstructTabApp ctx expr ixs = ixsPref' <- mapM (inline $ EmitToNameCtx Stop) ixsPref let ixsPref'' = [v | AtomVar v _ <- ixsPref'] s <- getSubst - let moreSubst = bs @@> map Rename ixsPref'' - dropSubst $ extendSubst moreSubst do + dropSubst $ extendSubstBD bs (Rename <$> ixsPref'') do -- Decision here. These decls have already been processed by the -- inliner once, so their occurrence information is stale (and should -- have been erased). Do we rerun occurrence analysis, or just complete diff --git a/src/lib/JAX/ToSimp.hs b/src/lib/JAX/ToSimp.hs index e2e183955..dc9d3810e 100644 --- a/src/lib/JAX/ToSimp.hs +++ b/src/lib/JAX/ToSimp.hs @@ -46,7 +46,7 @@ simplifyJaxpr (Jaxpr invars constvars eqns outvars) = do simplifyJBinders :: Nest JBinder i i' - -> (forall o'. DExt o o' => Nest SBinder o o' -> JaxSimpM i' o' a) + -> (forall o'. DExt o o' => SBinders o o' -> JaxSimpM i' o' a) -> JaxSimpM i o a simplifyJBinders Empty cont = getDistinct >>= \Distinct -> cont Empty simplifyJBinders (Nest jb jbs) cont = case jb of @@ -55,7 +55,7 @@ simplifyJBinders (Nest jb jbs) cont = case jb of ty <- simplifyJTy jTy withFreshBinder (getNameHint suffix) ty \b' -> do extendSubst (b @> Rename (binderName b')) do - simplifyJBinders jbs \bs' -> cont (Nest b' bs') + simplifyJBinders jbs \bs' -> cont (Nest (PlainBD b') bs') simplifyJTy :: JVarType -> JaxSimpM i o (SType o) simplifyJTy JArrayName{shape, dtype} = go shape $ simplifyDType dtype where diff --git a/src/lib/Linearize.hs b/src/lib/Linearize.hs index d32b5230a..2ec562777 100644 --- a/src/lib/Linearize.hs +++ b/src/lib/Linearize.hs @@ -72,6 +72,9 @@ extendActiveSubst => b i i' -> SAtomVar o -> PrimalM i' o a -> PrimalM i o a extendActiveSubst b v cont = extendSubst (b@>atomVarName v) $ extendActivePrimals v cont +extendActiveSubstBD :: BinderAndDecls SimpIR i i' -> SAtomVar o -> PrimalM i' o a -> PrimalM i o a +extendActiveSubstBD (BD b) v cont = extendActiveSubst b v cont + extendActiveEffs :: Effect SimpIR o -> PrimalM i o a -> PrimalM i o a extendActiveEffs eff = local \primals -> primals { activeEffs = extendEffRow (eSetSingleton eff) (activeEffs primals)} @@ -169,7 +172,7 @@ isActive e = do vs <- (S.fromList . map atomVarName . activeVars) <$> getActivePrimals return $ any (`S.member` vs) (freeAtomVarsList e) --- === converision between monadic and reified version of functions === +-- === conversion between monadic and reified version of functions === tangentFunAsLambda :: Emits o @@ -181,10 +184,10 @@ tangentFunAsLambda cont = do buildLamExpr tangentTys \tangentVars -> do liftTangentM (TangentArgs $ map sink tangentVars) cont -getTangentArgTys :: (Fallible1 m, EnvExtender m) => [SAtomVar n] -> m n (EmptyAbs (Nest SBinder) n) +getTangentArgTys :: (Fallible1 m, EnvExtender m) => [SAtomVar n] -> m n (EmptyAbs SBinders n) getTangentArgTys topVs = go mempty topVs where go :: (Fallible1 m, EnvExtender m) - => EMap SAtomName SAtomVar n -> [SAtomVar n] -> m n (EmptyAbs (Nest SBinder) n) + => EMap SAtomName SAtomVar n -> [SAtomVar n] -> m n (EmptyAbs SBinders n) go _ [] = return $ EmptyAbs Empty go heapMap (v:vs) = case getType v of -- This is a hack to handle heaps/references. They normally come in pairs @@ -195,7 +198,7 @@ getTangentArgTys topVs = go mempty topVs where withFreshBinder (getNameHint v) (TC HeapType) \hb -> do let newHeapMap = sink heapMap <> eMapSingleton (sink (atomVarName v)) (binderVar hb) Abs bs UnitE <- go newHeapMap $ sinkList vs - return $ EmptyAbs $ Nest hb bs + return $ EmptyAbs $ Nest (PlainBD hb) bs RefTy (Var h) referentTy -> do case lookupEMap heapMap (atomVarName h) of Nothing -> error "shouldn't happen?" @@ -204,12 +207,12 @@ getTangentArgTys topVs = go mempty topVs where let refTy = RefTy (Var h') tt withFreshBinder (getNameHint v) refTy \refb -> do Abs bs UnitE <- go (sink heapMap) $ sinkList vs - return $ EmptyAbs $ Nest refb bs + return $ EmptyAbs $ Nest (PlainBD refb) bs ty -> do tt <- tangentType ty withFreshBinder (getNameHint v) tt \b -> do Abs bs UnitE <- go (sink heapMap) $ sinkList vs - return $ EmptyAbs $ Nest b bs + return $ EmptyAbs $ Nest (PlainBD b) bs class ReconFunctor (f :: E -> E) where capture @@ -276,9 +279,9 @@ linearizeBlockDefuncGeneral locals block = do -- Inverse of tangentFunAsLambda. Should be used inside a returned tangent action. applyLinLam :: Emits o => SLam i -> SubstReaderT AtomSubstVal TangentM i o (Atom SimpIR o) -applyLinLam (LamExpr bs body) = do +applyLinLam lam = do TangentArgs args <- liftSubstReaderT $ getTangentArgs - extendSubst (bs @@> ((Rename . atomVarName) <$> args)) do + withInstantiated lam (Var <$> args) \body -> substM body >>= emitBlock -- === actual linearization passs === @@ -307,9 +310,9 @@ linearizeTopLam (TopLam False _ (LamExpr bs body)) actives = do tangentBody' <- buildBlock do ts <- getUnpacked $ Var $ sink $ binderVar bTangent let substFrag = bsRecon @@> map (SubstVal . sink) xs - <.> bsTangent @@> map (SubstVal . sink) ts + <.> (fmapNest (\(BD b) -> b) bsTangent) @@> map (SubstVal . sink) ts emitBlock =<< applySubst substFrag tangentBody - return $ LamExpr (bs' >>> BinaryNest bResidual bTangent) tangentBody' + return $ LamExpr (bs' >>> BinaryNest (PlainBD bResidual) (PlainBD bTangent)) tangentBody' return (primalFun, tangentFun) (,) <$> asTopLam primalFun <*> asTopLam tangentFun linearizeTopLam (TopLam True _ _) _ = error "expected a non-destination-passing function" @@ -659,11 +662,11 @@ linearizeHof hof = case hof of linearizeEffectFun :: RWS -> SLam i -> PrimalM i o (SLam o, LinLamAbs o) linearizeEffectFun rws (BinaryLamExpr hB refB body) = do withFreshBinder noHint (TC HeapType) \h -> do - bTy <- extendSubst (hB@>binderName h) $ renameM $ binderType refB + bTy <- extendSubstBD hB [binderName h] $ renameM $ binderType refB withFreshBinder noHint bTy \b -> do let ref = binderVar b hVar <- sinkM $ binderVar h - (body', linLam) <- extendActiveSubst hB hVar $ extendActiveSubst refB ref $ + (body', linLam) <- extendActiveSubstBD hB hVar $ extendActiveSubstBD refB ref $ -- TODO: maybe we should check whether we need to extend the active effects extendActiveEffs (RWSEffect rws (Var hVar)) do linearizeBlockDefunc body @@ -671,7 +674,7 @@ linearizeEffectFun rws (BinaryLamExpr hB refB body) = do -- ensures that such references can never be *used* once the effect runner -- returns, but technically it's legal to return them. let linLam' = ignoreHoistFailure $ hoist (PairB h b) linLam - return (BinaryLamExpr h b body', linLam') + return (BinaryLamExpr (BD h) (BD b) body', linLam') linearizeEffectFun _ _ = error "expect effect function to be a binary lambda" withT :: PrimalM i o (e1 o) diff --git a/src/lib/Lower.hs b/src/lib/Lower.hs index 0c441b298..3362b539a 100644 --- a/src/lib/Lower.hs +++ b/src/lib/Lower.hs @@ -66,10 +66,10 @@ lowerFullySequential wantDestStyle (TopLam False piTy (LamExpr bs body)) = liftE lam <- case wantDestStyle of True -> do refreshAbs (Abs bs body) \bs' body' -> do - xs <- bindersToAtoms bs' - EffTy _ resultTy <- instantiate (sink piTy) xs + let xs = Var <$> bindersVars bs' + EffTy _ resultTy <- instantiate piTy xs Abs b body'' <- lowerFullySequentialBlock resultTy body' - return $ LamExpr (bs' >>> UnaryNest b) body'' + return $ LamExpr (bs' >>> UnaryNest (PlainBD b)) body'' False -> do refreshAbs (Abs bs body) \bs' body' -> do body'' <- lowerFullySequentialBlockNoDest body' @@ -323,13 +323,13 @@ lowerExprWithDest dest expr = case expr of :: SType o -> LamExpr SimpIR i -> (Maybe (Dest SimpIR o) -> LamExpr SimpIR o -> LowerM i o (Hof SimpIR o)) -> LowerM i o (SExpr o) - traverseRWS referentTy (LamExpr (BinaryNest hb rb) body) cont = do + traverseRWS referentTy lam@(LamExpr (BinaryNest _ rb) _) cont = do unpackRWSDest dest >>= \case Nothing -> generic Just (bodyDest, refDest) -> do hof <- cont refDest =<< buildEffLam (getNameHint rb) referentTy \hb' rb' -> - extendRenamer (hb@>atomVarName hb' <.> rb@>atomVarName rb') do + withInstantiated lam [Var hb', Var rb'] \body -> case bodyDest of Nothing -> lowerBlock body Just bd -> lowerBlockWithDest (sink bd) body diff --git a/src/lib/OccAnalysis.hs b/src/lib/OccAnalysis.hs index bff364f47..fddcb5ab6 100644 --- a/src/lib/OccAnalysis.hs +++ b/src/lib/OccAnalysis.hs @@ -451,11 +451,11 @@ instance HasOCC (Hof SimpIR) where oneShot :: Access n -> [IxExpr n] -> LamExpr SimpIR n -> OCCM n (LamExpr SimpIR n) oneShot acc [] (LamExpr Empty body) = LamExpr Empty <$> occNest acc body -oneShot acc (ix:ixs) (LamExpr (Nest b bs) body) = do +oneShot acc (ix:ixs) (LamExpr (Nest (BD b) bs) body) = do occWithBinder (Abs b (LamExpr bs body)) \b' restLam -> extend b' (sink ix) do LamExpr bs' body' <- oneShot (sink acc) (map sink ixs) restLam - return $ LamExpr (Nest b' bs') body' + return $ LamExpr (Nest (BD b') bs') body' oneShot _ _ _ = error "zip error" -- Going under a lambda binder. diff --git a/src/lib/Optimize.hs b/src/lib/Optimize.hs index e4331e484..2aceee2fc 100644 --- a/src/lib/Optimize.hs +++ b/src/lib/Optimize.hs @@ -14,6 +14,7 @@ import Data.Word import Data.Bits import Data.Bits.Floating import Data.List +import Control.Category ((>>>)) import Control.Monad import Control.Monad.State.Strict import GHC.Float @@ -258,8 +259,8 @@ ulExpr expr = case expr of extendSubst (b' @> SubstVal (IdxRepVal i)) $ emitSubstBlock block' inc $ fromIntegral n -- To account for the TabCon we emit below getLamExprType body' >>= \case - PiType (UnaryNest (tb:>_)) (EffTy _ valTy) -> do - let tabTy = TabPi $ TabPiType (IxDictRawFin (IdxRepVal n)) (tb:>IdxRepTy) valTy + PiType (UnaryNest tb) (EffTy _ valTy) -> do + let tabTy = TabPi $ TabPiType (IxDictRawFin (IdxRepVal n)) tb valTy emitExpr $ TabCon Nothing tabTy vals _ -> error "Expected `for` body to have a Pi type" _ -> error "Expected `for` body to be a lambda expression" @@ -320,7 +321,7 @@ licmExpr = \case Abs decls ans <- buildBlock $ visitBlockEmits body -- Now, we process the decls and decide which ones to hoist. liftEnvReaderM $ runSubstReaderT idSubst $ - seqLICM REmpty mempty (asNameBinder b') REmpty decls ans + seqLICM REmpty mempty b' REmpty decls ans PairE (ListE extraDests) ab <- emitDecls $ Abs hdecls destsAndBody extraDests' <- mapM toAtomVar extraDests -- Append the destinations of hoisted Allocs as loop carried values. @@ -329,12 +330,12 @@ licmExpr = \case let lbTy = case ix' of IxType ixTy _ -> PairTy ixTy carryTy extraDestsTyped <- forM extraDests' \(AtomVar d t) -> return (d, t) Abs extraDestBs (Abs lb bodyAbs) <- return $ abstractFreeVars extraDestsTyped ab + let extraDestBs' = fmapNest PlainBD extraDestBs body' <- withFreshBinder noHint lbTy \lb' -> do (oldIx, allCarries) <- fromPair $ Var $ binderVar lb' (oldCarries, newCarries) <- splitAt numCarriesOriginal <$> getUnpacked allCarries let oldLoopBinderVal = PairVal oldIx (ProdVal oldCarries) - let s = extraDestBs @@> map SubstVal newCarries <.> lb @> SubstVal oldLoopBinderVal - block <- applySubst s bodyAbs + block <- instantiate (Abs (extraDestBs' >>> UnaryNest lb) bodyAbs) (newCarries <> [oldLoopBinderVal]) return $ UnaryLamExpr lb' block emitSeq dir ix' dests'' body' PrimOp (Hof (TypedHof _ (For dir ix (LamExpr (UnaryNest b) body)))) -> do @@ -342,25 +343,25 @@ licmExpr = \case Abs hdecls destsAndBody <- visitBinders (UnaryNest b) \(UnaryNest b') -> do Abs decls ans <- buildBlock $ visitBlockEmits body liftEnvReaderM $ runSubstReaderT idSubst $ - seqLICM REmpty mempty (asNameBinder b') REmpty decls ans + seqLICM REmpty mempty b' REmpty decls ans PairE (ListE []) (Abs lnb bodyAbs) <- emitDecls $ Abs hdecls destsAndBody ixTy <- substM $ binderType b body' <- withFreshBinder noHint ixTy \i -> do - block <- applyRename (lnb@>binderName i) bodyAbs + block <- instantiateNames (Abs lnb bodyAbs) [binderName i] return $ UnaryLamExpr i block emitHof $ For dir ix' body' expr -> visitGeneric expr >>= emitExpr seqLICM :: RNest SDecl n1 n2 -- hoisted decls -> [SAtomName n2] -- hoisted dests - -> AtomNameBinder SimpIR n2 n3 -- loop binder + -> BinderAndDecls SimpIR n2 n3 -- loop binder -> RNest SDecl n3 n4 -- loop-dependent decls -> Nest SDecl m1 m2 -- decls remaining to process -> SAtom m2 -- loop result -> SubstReaderT AtomSubstVal EnvReaderM m1 n4 (Abs (Nest SDecl) -- hoisted decls (PairE (ListE SAtomName) -- hoisted allocs (these should go in the loop carry) - (Abs (AtomNameBinder SimpIR) -- loop binder + (Abs (BinderAndDecls SimpIR) -- loop binder (Abs (Nest SDecl) -- non-hoisted decls SAtom))) n1) -- final result seqLICM !top !topDestNames !lb !reg decls ans = case decls of diff --git a/src/lib/PPrint.hs b/src/lib/PPrint.hs index 23bc7ea60..b3ab47ecc 100644 --- a/src/lib/PPrint.hs +++ b/src/lib/PPrint.hs @@ -159,6 +159,9 @@ instance PrettyPrec a => PrettyPrec [a] where instance PrettyE ann => Pretty (BinderP c ann n l) where pretty (b:>ty) = p b <> ":" <> p ty +instance IRRep r => Pretty (BinderAndDecls r n l) where + pretty (BD b) = pretty b + instance IRRep r => Pretty (Expr r n) where pretty = prettyFromPrettyPrec instance IRRep r => PrettyPrec (Expr r n) where prettyPrec (Atom x) = prettyPrec x @@ -324,7 +327,7 @@ withExplParens (Inferred _ Unify) x = braces $ x withExplParens (Inferred _ (Synth _)) x = brackets x instance IRRep r => Pretty (TabPiType r n) where - pretty (TabPiType dict (b:>ty) body) = let + pretty (TabPiType dict (BD (b:>ty)) body) = let prettyBody = case body of Pi subpi -> pretty subpi _ -> pLowest body diff --git a/src/lib/QueryType.hs b/src/lib/QueryType.hs index 50a976816..92a74dfad 100644 --- a/src/lib/QueryType.hs +++ b/src/lib/QueryType.hs @@ -65,9 +65,11 @@ blockTy b = blockEffTy b <&> \(EffTy _ t) -> t piTypeWithoutDest :: PiType SimpIR n -> PiType SimpIR n piTypeWithoutDest (PiType bsRefB _) = case popNest bsRefB of - Just (PairB bs (_:>RawRefTy ansTy)) -> do - PiType bs $ EffTy Pure ansTy -- XXX: we ignore the effects here - _ -> error "expected trailing dest binder" + Just (PairB bs refB) -> do + case binderType refB of + RawRefTy ansTy -> PiType bs $ EffTy Pure ansTy -- XXX: we ignore the effects here + _ -> error "expected ref type" + _ -> error "expected trailing binder" blockEff :: (EnvReader m, IRRep r) => Block r n -> m n (EffectRow r n) blockEff b = blockEffTy b <&> \(EffTy eff _) -> eff @@ -150,13 +152,13 @@ typeOfHof = \case _ -> error "expected a unary pi type" While _ -> return UnitTy Linearize f _ -> getLamExprType f >>= \case - PiType (UnaryNest (binder:>a)) (EffTy Pure b) -> do + PiType (UnaryNest binder) (EffTy Pure b) -> do let b' = ignoreHoistFailure $ hoist binder b - let fLinTy = Pi $ nonDepPiType [a] Pure b' + let fLinTy = Pi $ nonDepPiType [binderType binder] Pure b' return $ PairTy b' fLinTy _ -> error "expected a unary pi type" Transpose f _ -> getLamExprType f >>= \case - PiType (UnaryNest (_:>a)) _ -> return a + PiType (UnaryNest b) _ -> return $ binderType b _ -> error "expected a unary pi type" RunReader _ f -> do (resultTy, _) <- getTypeRWSAction f @@ -217,10 +219,10 @@ getMethodNameType v = liftEnvReaderM $ lookupEnv v >>= \case dictTy <- DictTy <$> dictType (sink className) params withFreshBinder noHint dictTy \dictB -> do scDicts <- getSuperclassDicts (Var $ binderVar dictB) - CorePiType appExpl methodExpls methodBs effTy <- instantiate (sink absPiTy) scDicts + CorePiType appExpl methodExpls methodBs effTy <- instantiate absPiTy scDicts let paramExpls = paramNames <&> \name -> Inferred name Unify let expls = paramExpls <> [Inferred Nothing (Synth $ Partial $ succ i)] <> methodExpls - return $ Pi $ CorePiType appExpl expls (paramBs' >>> UnaryNest dictB >>> methodBs) effTy + return $ Pi $ CorePiType appExpl expls (paramBs' >>> UnaryNest (BD dictB) >>> methodBs) effTy getMethodType :: EnvReader m => Dict n -> Int -> m n (CorePiType n) getMethodType dict i = liftEnvReaderM $ withSubstReaderT do @@ -266,7 +268,7 @@ getStructDataConType tyCon = liftEnvReaderM $ withSubstReaderT do buildDataConType :: (EnvReader m, EnvExtender m) => TyConDef n - -> (forall l. DExt n l => [Explicitness] -> Nest CBinder n l -> [CAtomName l] -> TyConParams l -> m l a) + -> (forall l. DExt n l => [Explicitness] -> CBinders n l -> [CAtomName l] -> TyConParams l -> m l a) -> m n a buildDataConType (TyConDef _ roleExpls bs _) cont = do let expls = snd <$> roleExpls @@ -274,9 +276,8 @@ buildDataConType (TyConDef _ roleExpls bs _) cont = do Explicit -> return $ Inferred Nothing Unify expl -> return $ expl refreshAbs (Abs bs UnitE) \bs' UnitE -> do - let vs = nestToNames bs' - vs' <- mapM toAtomVar vs - cont expls' bs' vs $ TyConParams expls (Var <$> vs') + let vs = bindersVars bs' + cont expls' bs' (atomVarName <$> vs) $ TyConParams expls (Var <$> vs) makeTyConParams :: EnvReader m => TyConName n -> [CAtom n] -> m n (TyConParams n) makeTyConParams tc params = do @@ -322,12 +323,13 @@ functionEffs f = getLamExprType f >>= \case PiType b (EffTy effs _) -> return $ ignoreHoistFailure $ hoist b effs rwsFunEffects :: (IRRep r, EnvReader m) => RWS -> LamExpr r n -> m n (EffectRow r n) -rwsFunEffects rws f = getLamExprType f >>= \case +rwsFunEffects rws f = liftEnvReaderM $ getLamExprType f >>= \case PiType (BinaryNest h ref) et -> do let effs' = ignoreHoistFailure $ hoist ref (etEff et) - let hVal = Var $ AtomVar (binderName h) (TC HeapType) - let effs'' = deleteEff (RWSEffect rws hVal) effs' - return $ ignoreHoistFailure $ hoist h effs'' + refreshAbs (Abs h effs') \h' effs'' -> do + let hVal = Var $ binderVar h' + let effs''' = deleteEff (RWSEffect rws hVal) effs'' + return $ ignoreHoistFailure $ hoist h' effs''' _ -> error "Expected a binary function type" getLamExprType :: (IRRep r, EnvReader m) => LamExpr r n -> m n (PiType r n) @@ -361,6 +363,13 @@ getSuperclassTys (DictType _ className params) = do forM [0 .. nestLength superclasses - 1] \i -> do instantiate (Abs bs $ getSuperclassType REmpty superclasses i) params +getSuperclassType :: RNest (BinderAndDecls CoreIR) n l -> CBinders l l' -> Int -> CType n +getSuperclassType _ Empty = error "bad index" +getSuperclassType bsAbove (Nest b bs) = \case + 0 -> ignoreHoistFailure $ hoist bsAbove (binderType b) + i -> getSuperclassType (RNest bsAbove b) bs (i-1) + + getTypeTopFun :: EnvReader m => TopFunName n -> m n (PiType SimpIR n) getTypeTopFun f = lookupTopFun f >>= \case DexTopFun _ (TopLam _ piTy _) _ -> return piTy @@ -383,7 +392,7 @@ liftIFunType (IFunType _ argTys resultTys) = liftEnvReaderM $ go argTys where _ -> error $ "Not a valid FFI return type: " ++ pprint resultTys t:ts -> withFreshBinder noHint (BaseTy t) \b -> do PiType bs effTy <- go ts - return $ PiType (Nest b bs) effTy + return $ PiType (Nest (PlainBD b) bs) effTy -- === Data constraints === @@ -400,8 +409,8 @@ checkDataLike ty = case ty of TabPi (TabPiType _ b eltTy) -> do renameBinders b \_ -> checkDataLike eltTy - DepPairTy (DepPairType _ b@(_:>l) r) -> do - recur l + DepPairTy (DepPairType _ b r) -> do + recur $ binderType b renameBinders b \_ -> checkDataLike r NewtypeTyCon nt -> do (_, ty') <- unwrapNewtypeType =<< renameM nt diff --git a/src/lib/QueryTypePure.hs b/src/lib/QueryTypePure.hs index 9be267241..b1fee94ac 100644 --- a/src/lib/QueryTypePure.hs +++ b/src/lib/QueryTypePure.hs @@ -117,12 +117,6 @@ instance IRRep r => HasType r (Con r) where SumCon tys _ _ -> SumTy tys HeapVal -> TC HeapType -getSuperclassType :: RNest CBinder n l -> Nest CBinder l l' -> Int -> CType n -getSuperclassType _ Empty = error "bad index" -getSuperclassType bsAbove (Nest b@(_:>t) bs) = \case - 0 -> ignoreHoistFailure $ hoist bsAbove t - i -> getSuperclassType (RNest bsAbove b) bs (i-1) - instance IRRep r => HasType r (Expr r) where getType expr = case expr of App (EffTy _ ty) _ _ -> ty @@ -207,19 +201,21 @@ rawStrType :: IRRep r => Type r n rawStrType = case newName "n" of Abs b v -> do let tabTy = rawFinTabType (Var $ AtomVar v IdxRepTy) CharRepTy - DepPairTy $ DepPairType ExplicitDepPair (b:>IdxRepTy) tabTy + DepPairTy $ DepPairType ExplicitDepPair (PlainBD (b:>IdxRepTy)) tabTy -- `n` argument is IdxRepVal, not Nat rawFinTabType :: IRRep r => Atom r n -> Type r n -> Type r n rawFinTabType n eltTy = IxType IdxRepTy (IxDictRawFin n) ==> eltTy -tabIxType :: TabPiType r n -> IxType r n -tabIxType (TabPiType d (_:>t) _) = IxType t d +tabIxType :: IRRep r => TabPiType r n -> IxType r n +tabIxType (TabPiType d b _) = IxType (binderType b) d typesAsBinderNest :: (SinkableE e, HoistableE e, IRRep r) - => [Type r n] -> e n -> Abs (Nest (Binder r)) e n -typesAsBinderNest types body = toConstBinderNest types body + => [Type r n] -> e n -> Abs (Binders r) e n +typesAsBinderNest types body = + case toConstBinderNest types body of + Abs bs body' -> Abs (fmapNest PlainBD bs) body' nonDepPiType :: [CType n] -> EffectRow CoreIR n -> CType n -> CorePiType n nonDepPiType argTys eff resultTy = case typesAsBinderNest argTys (PairE eff resultTy) of @@ -230,7 +226,7 @@ nonDepPiType argTys eff resultTy = case typesAsBinderNest argTys (PairE eff resu nonDepTabPiType :: IRRep r => IxType r n -> Type r n -> TabPiType r n nonDepTabPiType (IxType t d) resultTy = case toConstAbsPure resultTy of - Abs b resultTy' -> TabPiType d (b:>t) resultTy' + Abs b resultTy' -> TabPiType d (PlainBD (b:>t)) resultTy' corePiTypeToPiType :: CorePiType n -> PiType CoreIR n corePiTypeToPiType (CorePiType _ _ bs effTy) = PiType bs effTy diff --git a/src/lib/RuntimePrint.hs b/src/lib/RuntimePrint.hs index 4a4c2c6a5..3c627e705 100644 --- a/src/lib/RuntimePrint.hs +++ b/src/lib/RuntimePrint.hs @@ -169,10 +169,10 @@ withBuffer cont = do body <- buildBlock do cont $ sink $ Var $ binderVar b return UnitVal - let binders = BinaryNest h b + let binders = BinaryNest (PlainBD h) (PlainBD b) let expls = [Inferred Nothing Unify, Explicit] let piTy = CorePiType ExplicitApp expls binders $ EffTy eff UnitTy - let lam = LamExpr (BinaryNest h b) body + let lam = LamExpr (BinaryNest (PlainBD h) (PlainBD b)) body return $ Lam $ CoreLamExpr piTy lam applyPreludeFunction "with_stack_internal" [lam] diff --git a/src/lib/Simplify.hs b/src/lib/Simplify.hs index 028e35981..b3be6487c 100644 --- a/src/lib/Simplify.hs +++ b/src/lib/Simplify.hs @@ -109,18 +109,18 @@ forceTabLam (PairE ixTy (Abs b ab)) = result <- applyRename (b@>(atomVarName v)) ab >>= emitDecls toDataAtomIgnoreRecon result -type NaryTabLamExpr = Abs (Nest SBinder) (Abs (Nest SDecl) CAtom) +type NaryTabLamExpr = Abs SBinders (Abs (Nest SDecl) CAtom) fromNaryTabLam :: Int -> CAtom n -> Maybe (Int, NaryTabLamExpr n) fromNaryTabLam maxDepth | maxDepth <= 0 = error "expected positive number of args" fromNaryTabLam maxDepth = \case SimpInCore (TabLam _ (PairE _ (Abs b body))) -> - extend <|> (Just $ (1, Abs (Nest b Empty) body)) + extend <|> (Just $ (1, Abs (Nest (PlainBD b) Empty) body)) where extend = case body of Abs Empty lam | maxDepth > 1 -> do (d, Abs (Nest b2 bs2) body2) <- fromNaryTabLam (maxDepth - 1) lam - return $ (d + 1, Abs (Nest b (Nest b2 bs2)) body2) + return $ (d + 1, Abs (Nest (PlainBD b) (Nest b2 bs2)) body2) _ -> Nothing _ -> Nothing @@ -147,19 +147,19 @@ getRepType ty = go ty where RefType h a -> RefType <$> toDataAtomIgnoreReconAssumeNoDecls h <*> go a TypeKind -> error $ notDataType HeapType -> return $ HeapType - DepPairTy (DepPairType expl b@(_:>l) r) -> do - l' <- go l + DepPairTy depPairTy@(DepPairType expl b _) -> do + l' <- go $ binderType b withFreshBinder (getNameHint b) l' \b' -> do - x <- liftSimpAtom (sink l) (Var $ binderVar b') - r' <- go =<< applySubst (b@>SubstVal x) r - return $ DepPairTy $ DepPairType expl b' r' + x <- liftSimpAtom (sink $ binderType b) (Var $ binderVar b') + r' <- go =<< instantiate depPairTy [x] + return $ DepPairTy $ DepPairType expl (PlainBD b') r' TabPi tabTy -> do let ixTy = tabIxType tabTy IxType t' d' <- simplifyIxType ixTy withFreshBinder (getNameHint tabTy) t' \b' -> do x <- liftSimpAtom (sink $ ixTypeType ixTy) (Var $ binderVar b') - bodyTy' <- go =<< instantiate (sink tabTy) [x] - return $ TabPi $ TabPiType d' b' bodyTy' + bodyTy' <- go =<< instantiate tabTy [x] + return $ TabPi $ TabPiType d' (PlainBD b') bodyTy' NewtypeTyCon con -> do (_, ty') <- unwrapNewtypeType con go ty' @@ -188,18 +188,18 @@ toDataAtomIgnoreReconAssumeNoDecls x = do _ -> error "unexpected decls" withSimplifiedBinders - :: Nest (Binder CoreIR) o any - -> (forall o'. DExt o o' => Nest (Binder SimpIR) o o' -> [CAtom o'] -> SimplifyM i o' a) + :: Binders CoreIR o any + -> (forall o'. DExt o o' => Binders SimpIR o o' -> [CAtom o'] -> SimplifyM i o' a) -> SimplifyM i o a withSimplifiedBinders Empty cont = getDistinct >>= \Distinct -> cont Empty [] -withSimplifiedBinders (Nest (bCore:>ty) bsCore) cont = do +withSimplifiedBinders (Nest (BD (bCore:>ty)) bsCore) cont = do simpTy <- getRepType ty withFreshBinder (getNameHint bCore) simpTy \bSimp -> do x <- liftSimpAtom (sink ty) (Var $ binderVar bSimp) -- TODO: carry a substitution instead of doing N^2 work like this Abs bsCore' UnitE <- applySubst (bCore@>SubstVal x) (EmptyAbs bsCore) withSimplifiedBinders bsCore' \bsSimp xs -> - cont (Nest bSimp bsSimp) (sink x:xs) + cont (Nest (BD bSimp) bsSimp) (sink x:xs) -- === Reconstructions === @@ -532,13 +532,13 @@ emitSpecialization s = do return name specializedFunCoreDefinition :: (Mut n, TopBuilder m) => SpecializationSpec n -> m n (TopLam CoreIR n) -specializedFunCoreDefinition (AppSpecialization f (Abs bs staticArgs)) = do - (asTopLam =<<) $ liftBuilder $ buildLamExpr (EmptyAbs bs) \runtimeArgs -> do +specializedFunCoreDefinition (AppSpecialization f ab) = do + (asTopLam =<<) $ liftBuilder $ buildLamExpr ab \runtimeArgs -> do -- This avoids an infinite loop. Otherwise, in simplifyTopFunction, -- where we eta-expand and try to simplify `App f args`, we would see `f` as a -- "noinline" function and defer its simplification. NoinlineFun _ f' <- lookupAtomName (atomVarName (sink f)) - ListE staticArgs' <- applyRename (bs@@>(atomVarName <$> runtimeArgs)) staticArgs + ListE staticArgs' <- instantiate ab (Var <$> runtimeArgs) naryApp f' staticArgs' simplifyTabApp :: forall i o. Emits o @@ -605,11 +605,11 @@ requireIxDictCache dictAbs = do {-# INLINE requireIxDictCache #-} simplifyDictMethod :: Mut n => AbsDict n -> IxMethod -> TopBuilderM n (TopLam SimpIR n) -simplifyDictMethod absDict@(Abs bs dict) method = do +simplifyDictMethod absDict method = do ty <- liftEnvReaderM $ ixMethodType method absDict lamExpr <- liftBuilder $ buildTopLamFromPi ty \allArgs -> do - let (extraArgs, methodArgs) = splitAt (nestLength bs) allArgs - dict' <- applyRename (bs @@> (atomVarName <$> extraArgs)) dict + let (extraArgs, methodArgs) = splitAt (arity absDict) allArgs + dict' <- instantiate absDict (Var <$> extraArgs) emitExpr =<< mkApplyMethod dict' (fromEnum method) (Var <$> methodArgs) simplifyTopFunction lamExpr @@ -656,14 +656,14 @@ simplifyLam :: LamExpr CoreIR i -> SimplifyM i o (LamExpr SimpIR o, Abs (Nest (AtomNameBinder SimpIR)) ReconstructAtom o) simplifyLam (LamExpr bsTop body) = case bsTop of - Nest (b:>ty) bs -> do - ty' <- substM ty + Nest b bs -> do + ty' <- substM $ binderType b tySimp <- getRepType ty' withFreshBinder (getNameHint b) tySimp \b''@(b':>_) -> do x <- liftSimpAtom (sink ty') (Var $ binderVar b'') - extendSubst (b@>SubstVal x) do + extendSubstBD b [SubstVal x] do (LamExpr bs' body', Abs bsRecon recon) <- simplifyLam $ LamExpr bs body - return (LamExpr (Nest (b':>tySimp) bs') body', Abs (Nest b' bsRecon) recon) + return (LamExpr (Nest (PlainBD (b':>tySimp)) bs') body', Abs (Nest b' bsRecon) recon) Empty -> do SimplifiedBlock body' recon <- buildSimplifiedBlock $ simplifyBlock body return (LamExpr Empty body', Abs Empty recon) @@ -829,7 +829,8 @@ simplifyHof _hint resultTy = \case let recon' = ignoreHoistFailure $ hoist b recon applyRecon recon' ans RunWriter Nothing (BaseMonoid e combine) lam -> do - LamExpr (BinaryNest h (_:>RefTy _ wTy)) _ <- return lam + LamExpr (BinaryNest h refB) _ <- return lam + RefTy _ wTy <- return $ binderType refB wTy' <- substM $ ignoreHoistFailure $ hoist h wTy e' <- simplifyDataAtom e (combine', CoerceReconAbs) <- simplifyLam combine @@ -961,7 +962,7 @@ tryGetCustomRule f' = do _ -> return Nothing _ -> return Nothing -type Linearized = Abs (Nest SBinder) -- primal args +type Linearized = Abs SBinders -- primal args (Abs (Nest SDecl) -- primal decls (PairE SAtom -- primal result SLam)) -- tangent function @@ -973,7 +974,7 @@ simplifyCustomLinearization (Abs runtimeBs staticArgs) actives rule = do CustomLinearize nImplicit nExplicit zeros fCustom <- return rule linearized <- withSimplifiedBinders runtimeBs \runtimeBs' runtimeArgs -> do Abs runtimeBs' <$> buildScoped do - ListE staticArgs' <- instantiate (sink $ Abs runtimeBs staticArgs) (sink <$> runtimeArgs) + ListE staticArgs' <- instantiate (Abs runtimeBs staticArgs) (sink <$> runtimeArgs) fCustom' <- sinkM fCustom resultTy <- typeOfApp (getType fCustom') staticArgs' pairResult <- dropSubst $ simplifyApp noHint resultTy fCustom' staticArgs' @@ -1023,7 +1024,7 @@ simplifyCustomLinearization (Abs runtimeBs staticArgs) actives rule = do return $ activeArg':rest buildTangentArgs _ _ _ = error "zip error" - fromNonDepNest :: Nest CBinder n l -> [CType n] + fromNonDepNest :: Nest CBinderAndDecls n l -> [CType n] fromNonDepNest Empty = [] fromNonDepNest (Nest b bs) = case ignoreHoistFailure $ hoist b (Abs bs UnitE) of @@ -1038,13 +1039,13 @@ defuncLinearized ab = liftBuilder $ refreshAbs ab \bs ab' -> do LamExpr tBs _ <- return fLin residualsTangentsBs <- withFreshBinder "residual" rTy \rB -> do Abs tBs' UnitE <- sinkM $ Abs tBs UnitE - return $ Abs (Nest rB tBs') UnitE + return $ Abs (Nest (BD rB) tBs') UnitE residualsTangentsBs' <- return $ ignoreHoistFailure $ hoist decls residualsTangentsBs return (Abs decls (PairVal primalResult residuals), reconAbs, residualsTangentsBs') let primalFun = LamExpr bs declsAndResult LamExpr residualAndTangentBs tangentBody <- buildLamExpr residualsTangentsBs \(residuals:tangents) -> do - LamExpr tangentBs' body <- applyReconAbs (sink reconAbs) (Var residuals) - applyRename (tangentBs' @@> (atomVarName <$> tangents)) body >>= emitBlock + lam <- applyReconAbs (sink reconAbs) (Var residuals) + instantiate lam (Var <$> tangents) >>= emitBlock let tangentFun = LamExpr (bs >>> residualAndTangentBs) tangentBody return $ PairE primalFun tangentFun @@ -1095,7 +1096,7 @@ exceptToMaybeExpr expr = case expr of s' <- substM s BinaryLamExpr h ref body <- return lam result <- emitRunState noHint s' \h' ref' -> - extendSubst (h @> Rename (atomVarName h') <.> ref @> Rename (atomVarName ref')) do + extendSubstBD (BinaryNest h ref) [Rename (atomVarName h'), Rename (atomVarName ref')] do blockResultTy <- blockTy =<< substM body -- TODO: avoid this by caching the type exceptToMaybeBlock blockResultTy body (maybeAns, newState) <- fromPair result @@ -1107,7 +1108,7 @@ exceptToMaybeExpr expr = case expr of monoid' <- substM monoid PairTy _ accumTy <- substM resultTy result <- emitRunWriter noHint accumTy monoid' \h' ref' -> - extendSubst (h @> Rename (atomVarName h') <.> ref @> Rename (atomVarName ref')) do + extendSubstBD (BinaryNest h ref) [Rename (atomVarName h'), Rename (atomVarName ref')] do blockResultTy <- blockTy =<< substM body -- TODO: avoid this by caching the type exceptToMaybeBlock blockResultTy body (maybeAns, accumResult) <- fromPair result diff --git a/src/lib/Transpose.hs b/src/lib/Transpose.hs index 904e608d1..e75e5cd4c 100644 --- a/src/lib/Transpose.hs +++ b/src/lib/Transpose.hs @@ -44,26 +44,26 @@ transposeTopFun :: (MonadFail1 m, EnvReader m) => STopLam n -> m n (STopLam n) transposeTopFun (TopLam False _ lam) = liftBuilder $ runTransposeM do - (Abs bsNonlin (Abs bLin body), Abs bsNonlin'' outTy) <- unpackLinearLamExpr lam + (Abs bsNonlin (Abs bLin body), outTyAbs) <- unpackLinearLamExpr lam refreshBinders bsNonlin \bsNonlin' substFrag -> extendRenamer substFrag do - outTy' <- applyRename (bsNonlin''@@> nestToNames bsNonlin') outTy - withFreshBinder "ct" outTy' \bCT -> do + outTy <- instantiate outTyAbs (Var <$> bindersVars bsNonlin') + withFreshBinder "ct" outTy \bCT -> do let ct = Var $ binderVar bCT body' <- buildBlock do inTy <- substNonlin $ binderType bLin withAccumulator inTy \refSubstVal -> - extendSubst (bLin @> refSubstVal) $ + extendSubstBD bLin [refSubstVal] $ transposeBlock body (sink ct) EffTy _ bodyTy <- blockEffTy body' - let piTy = PiType (bsNonlin' >>> UnaryNest bCT) (EffTy Pure bodyTy) - let lamT = LamExpr (bsNonlin' >>> UnaryNest bCT) body' + let piTy = PiType (bsNonlin' >>> UnaryNest (PlainBD bCT)) (EffTy Pure bodyTy) + let lamT = LamExpr (bsNonlin' >>> UnaryNest (PlainBD bCT)) body' return $ TopLam False piTy lamT transposeTopFun (TopLam True _ _) = error "shouldn't be transposing in destination passing style" unpackLinearLamExpr :: (MonadFail1 m, EnvReader m) => LamExpr SimpIR n - -> m n ( Abs (Nest SBinder) (Abs SBinder SBlock) n - , Abs (Nest SBinder) SType n) + -> m n ( Abs SBinders (Abs SBinderAndDecls SBlock) n + , Abs SBinders SType n) unpackLinearLamExpr lam@(LamExpr bs body) = do let numNonlin = nestLength bs - 1 PairB bsNonlin (UnaryNest bLin) <- return $ splitNestAt numNonlin bs @@ -338,7 +338,7 @@ transposeHof hof ct = case hof of RunState Nothing s (BinaryLamExpr hB refB body) -> do (ctBody, ctState) <- fromPair ct (_, cts) <- (fromPair =<<) $ emitRunState noHint ctState \h ref -> do - extendSubst (hB@>RenameNonlin (atomVarName h)) $ extendSubst (refB@>RenameNonlin (atomVarName ref)) $ + extendSubstBD hB [RenameNonlin (atomVarName h)] $ extendSubstBD refB [RenameNonlin (atomVarName ref)] $ extendLinRegions h $ transposeBlock body (sink ctBody) return UnitVal @@ -347,7 +347,7 @@ transposeHof hof ct = case hof of accumTy <- substNonlin $ getType r baseMonoid <- tangentBaseMonoidFor accumTy (_, ct') <- (fromPair =<<) $ emitRunWriter noHint accumTy baseMonoid \h ref -> do - extendSubst (hB@>RenameNonlin (atomVarName h)) $ extendSubst (refB@>RenameNonlin (atomVarName ref)) $ + extendSubstBD hB [RenameNonlin (atomVarName h)] $ extendSubstBD refB [RenameNonlin (atomVarName ref)] $ extendLinRegions h $ transposeBlock body (sink ct) return UnitVal @@ -356,7 +356,7 @@ transposeHof hof ct = case hof of -- TODO: check we have the 0/+ monoid (ctBody, ctEff) <- fromPair ct void $ emitRunReader noHint ctEff \h ref -> do - extendSubst (hB@>RenameNonlin (atomVarName h)) $ extendSubst (refB@>RenameNonlin (atomVarName ref)) $ + extendSubstBD hB [RenameNonlin (atomVarName h)] $ extendSubstBD refB [RenameNonlin (atomVarName ref)] $ extendLinRegions h $ transposeBlock body (sink ctBody) return UnitVal diff --git a/src/lib/Types/Core.hs b/src/lib/Types/Core.hs index 067af737e..9bfeec19b 100644 --- a/src/lib/Types/Core.hs +++ b/src/lib/Types/Core.hs @@ -143,6 +143,11 @@ type AtomBinderP (r::IR) = BinderP (AtomNameC r) type Binder r = AtomBinderP r (Type r) :: B type Alt r = Abs (Binder r) (Block r) :: E +-- This doesn't actually include the decls yet. I'm starting by making the type +-- distinct from the underlying binder without changing anything else. +data BinderAndDecls (r::IR) (n::S) (l::S) = BD (Binder r n l) +type Binders r = Nest (BinderAndDecls r) + newtype DotMethods n = DotMethods (M.Map SourceName (CAtomName n)) deriving (Show, Generic, Monoid, Semigroup) @@ -152,7 +157,7 @@ data TyConDef n where TyConDef :: SourceName -> [RoleExpl] - -> Nest CBinder n l + -> CBinders n l -> DataConDefs l -> TyConDef n @@ -164,7 +169,7 @@ data DataConDefs n = data DataConDef n = -- Name for pretty printing, constructor elements, representation type, -- list of projection indices that recovers elements from the representation. - DataConDef SourceName (EmptyAbs (Nest CBinder) n) (CType n) [[Projection]] + DataConDef SourceName (EmptyAbs CBinders n) (CType n) [[Projection]] deriving (Show, Generic) data ParamRole = TypeParam | DictParam | DataParam deriving (Show, Generic, Eq) @@ -183,7 +188,7 @@ data TopLam (r::IR) (n::S) = TopLam IsDestLam (PiType r n) (LamExpr r n) deriving (Show, Generic) data LamExpr (r::IR) (n::S) where - LamExpr :: Nest (Binder r) n l -> Block r l -> LamExpr r n + LamExpr :: Binders r n l -> Block r l -> LamExpr r n data CoreLamExpr (n::S) = CoreLamExpr (CorePiType n) (LamExpr CoreIR n) @@ -209,16 +214,16 @@ data IxType (r::IR) (n::S) = deriving (Show, Generic) data TabPiType (r::IR) (n::S) where - TabPiType :: IxDict r n -> Binder r n l -> Type r l -> TabPiType r n + TabPiType :: IxDict r n -> BinderAndDecls r n l -> Type r l -> TabPiType r n data PiType (r::IR) (n::S) where - PiType :: Nest (Binder r) n l -> EffTy r l -> PiType r n + PiType :: Binders r n l -> EffTy r l -> PiType r n data CorePiType (n::S) where - CorePiType :: AppExplicitness -> [Explicitness] -> Nest CBinder n l -> EffTy CoreIR l -> CorePiType n + CorePiType :: AppExplicitness -> [Explicitness] -> CBinders n l -> EffTy CoreIR l -> CorePiType n data DepPairType (r::IR) (n::S) where - DepPairType :: DepPairExplicitness -> Binder r n l -> Type r l -> DepPairType r n + DepPairType :: DepPairExplicitness -> BinderAndDecls r n l -> Type r l -> DepPairType r n type Val = Atom type Kind = Type @@ -233,8 +238,17 @@ data NonDepNest r ann n l = NonDepNest (Nest (AtomNameBinder r) n l) [ann n] -- === ToAtomAbs class === +class ToBinders (b::B) (r::IR) | b -> r where + toBinders :: b n l -> Binders r n l + +instance ToBinders (BinderAndDecls r) r where + toBinders b = UnaryNest b + +instance ToBinders (Binders r) r where + toBinders b = b + class ToBindersAbs (e::E) (body::E) (r::IR) | e -> body, e -> r where - toAbs :: e n -> Abs (Nest (Binder r)) body n + toAbs :: e n -> Abs (Binders r) body n instance ToBindersAbs CorePiType (EffTy CoreIR) CoreIR where toAbs (CorePiType _ _ bs effTy) = Abs bs effTy @@ -242,8 +256,8 @@ instance ToBindersAbs CorePiType (EffTy CoreIR) CoreIR where instance ToBindersAbs CoreLamExpr (Block CoreIR) CoreIR where toAbs (CoreLamExpr _ lam) = toAbs lam -instance ToBindersAbs (Abs (Nest (Binder r)) body) body r where - toAbs = id +instance ToBinders b r => ToBindersAbs (Abs b body) body r where + toAbs (Abs b e) = Abs (toBinders b) e instance ToBindersAbs (PiType r) (EffTy r) r where toAbs (PiType bs effTy) = Abs bs effTy @@ -263,7 +277,7 @@ instance ToBindersAbs InstanceDef (ListE CAtom `PairE` InstanceBody) CoreIR wher instance ToBindersAbs TyConDef DataConDefs CoreIR where toAbs (TyConDef _ _ bs body) = Abs bs body -instance ToBindersAbs ClassDef (Abs (Nest CBinder) (ListE CorePiType)) CoreIR where +instance ToBindersAbs ClassDef (Abs CBinders (ListE CorePiType)) CoreIR where toAbs (ClassDef _ _ _ _ bs scBs tys) = Abs bs (Abs scBs (ListE tys)) instance ToBindersAbs (TopLam r) (Block r) r where @@ -432,6 +446,8 @@ data RefOp r n = type CAtom = Atom CoreIR type CType = Type CoreIR type CBinder = Binder CoreIR +type CBinderAndDecls = BinderAndDecls CoreIR +type CBinders = Binders CoreIR type CExpr = Expr CoreIR type CBlock = Block CoreIR type CDecl = Decl CoreIR @@ -450,6 +466,8 @@ type SDecls = Decls SimpIR type SAtomName = AtomName SimpIR type SAtomVar = AtomVar SimpIR type SBinder = Binder SimpIR +type SBinderAndDecls = BinderAndDecls SimpIR +type SBinders = Binders SimpIR type SRepVal = RepVal SimpIR type SLam = LamExpr SimpIR type STopLam = TopLam SimpIR @@ -488,8 +506,8 @@ data ClassDef (n::S) where -> [SourceName] -- method source names -> [Maybe SourceName] -- parameter source names -> [RoleExpl] -- parameter info - -> Nest CBinder n1 n2 -- parameters - -> Nest CBinder n2 n3 -- superclasses + -> CBinders n1 n2 -- parameters + -> CBinders n2 n3 -- superclasses -> [CorePiType n3] -- method types -> ClassDef n1 @@ -497,7 +515,7 @@ data InstanceDef (n::S) where InstanceDef :: ClassName n1 -> [RoleExpl] -- parameter info - -> Nest CBinder n1 n2 -- parameters (types and dictionaries) + -> CBinders n1 n2 -- parameters (types and dictionaries) -> [CAtom n2] -- class parameters -> InstanceBody n2 -> InstanceDef n1 @@ -913,7 +931,7 @@ instance IRRep r => Store (Effect r n) -- === Specialization and generalization === type Generalized (r::IR) (e::E) (n::S) = (Abstracted r e n, [Atom r n]) -type Abstracted (r::IR) (e::E) = Abs (Nest (Binder r)) e +type Abstracted (r::IR) (e::E) = Abs (Binders r) e type AbsDict = Abstracted CoreIR Dict data SpecializedDictDef n = @@ -937,14 +955,22 @@ data LinearizationSpec (n::S) = -- === Binder utils === -binderType :: Binder r n l -> Type r n -binderType (_:>ty) = ty +class BindsNames b => ToBinderVar (b::B) (r::IR) | b -> r where + binderType :: b n l -> Type r n + binderVar :: (IRRep r, DExt n l) => b n l -> AtomVar r l -binderVar :: (IRRep r, DExt n l) => Binder r n l -> AtomVar r l -binderVar (b:>ty) = AtomVar (binderName b) (sink ty) +instance IRRep r => ToBinderVar (BinderAndDecls r) r where + binderType (BD (_:>ty)) = ty + binderVar (BD (b:>ty)) = + AtomVar (sink $ binderName b) (sink ty) -bindersVars :: (Distinct l, Ext n l, IRRep r) - => Nest (Binder r) n l -> [AtomVar r l] +instance IRRep r => ToBinderVar (Binder r) r where + binderType (_:>ty) = ty + binderVar (b:>ty) = AtomVar (sink $ binderName b) (sink ty) + +bindersVars + :: (Distinct l, Ext n l, IRRep r, ToBinderVar b r) + => Nest b n l -> [AtomVar r l] bindersVars = \case Empty -> [] Nest b bs -> withExtEvidence b $ withSubscopeDistinct bs $ @@ -1064,7 +1090,7 @@ pattern RefTy r a = TC (RefType r a) pattern RawRefTy :: Type r n -> Type r n pattern RawRefTy a = TC (RefType (Con HeapVal) a) -pattern TabTy :: IxDict r n -> Binder r n l -> Type r l -> Type r n +pattern TabTy :: IxDict r n -> BinderAndDecls r n l -> Type r l -> Type r n pattern TabTy d b body = TabPi (TabPiType d b body) pattern FinTy :: Atom CoreIR n -> Type CoreIR n @@ -1085,13 +1111,24 @@ pattern EffKind = NewtypeTyCon EffectRowKind pattern FinConst :: Word32 -> Type CoreIR n pattern FinConst n = NewtypeTyCon (Fin (NatVal n)) +pattern PlainBD :: Binder r n l -> BinderAndDecls r n l +pattern PlainBD b = BD b -- this will become `BD b Empty` + pattern NullaryLamExpr :: Block r n -> LamExpr r n pattern NullaryLamExpr body = LamExpr Empty body +asUnaryLamExpr :: LamExpr r n -> Maybe (Abs (Binder r) (Block r) n) +asUnaryLamExpr (LamExpr (UnaryNest (BD b)) (Abs decls result)) = + Just $ Abs b $ Abs decls result +-- asUnaryLamExpr (LamExpr (UnaryNest (BD b decls)) (Abs decls' result)) = +-- Just $ Abs b $ Abs (decls >>> decls') result +asUnaryLamExpr _ = Nothing + pattern UnaryLamExpr :: Binder r n l -> Block r l -> LamExpr r n -pattern UnaryLamExpr b body = LamExpr (UnaryNest b) body +pattern UnaryLamExpr b body <- (asUnaryLamExpr -> Just (Abs b body)) + where UnaryLamExpr b body = LamExpr (UnaryNest (BD b)) body -pattern BinaryLamExpr :: Binder r n l1 -> Binder r l1 l2 -> Block r l2 -> LamExpr r n +pattern BinaryLamExpr :: BinderAndDecls r n l1 -> BinderAndDecls r l1 l2 -> Block r l2 -> LamExpr r n pattern BinaryLamExpr b1 b2 body = LamExpr (BinaryNest b1 b2) body pattern WithoutDecls :: e n -> WithDecls r e n @@ -1196,7 +1233,7 @@ instance AlphaEqE DataConDefs instance AlphaHashableE DataConDefs instance GenericE TyConDef where - type RepE TyConDef = PairE (LiftE (SourceName, [RoleExpl])) (Abs (Nest CBinder) DataConDefs) + type RepE TyConDef = PairE (LiftE (SourceName, [RoleExpl])) (Abs CBinders DataConDefs) fromE (TyConDef sourceName expls bs cons) = PairE (LiftE (sourceName, expls)) (Abs bs cons) {-# INLINE fromE #-} toE (PairE (LiftE (sourceName, expls)) (Abs bs cons)) = TyConDef sourceName expls bs cons @@ -1218,7 +1255,7 @@ instance HasSourceName (TyConDef n) where instance GenericE DataConDef where type RepE DataConDef = (LiftE (SourceName, [[Projection]])) - `PairE` EmptyAbs (Nest CBinder) `PairE` Type CoreIR + `PairE` EmptyAbs CBinders `PairE` Type CoreIR fromE (DataConDef name bs repTy idxs) = (LiftE (name, idxs)) `PairE` bs `PairE` repTy {-# INLINE fromE #-} toE ((LiftE (name, idxs)) `PairE` bs `PairE` repTy) = DataConDef name bs repTy idxs @@ -1853,7 +1890,7 @@ deriving instance (Show (ann n)) => IRRep r => Show (NonDepNest r ann n l) instance GenericE ClassDef where type RepE ClassDef = LiftE (SourceName, [SourceName], [Maybe SourceName], [RoleExpl]) - `PairE` Abs (Nest CBinder) (Abs (Nest CBinder) (ListE CorePiType)) + `PairE` Abs CBinders (Abs CBinders (ListE CorePiType)) fromE (ClassDef name names paramNames roleExpls b scs tys) = LiftE (name, names, paramNames, roleExpls) `PairE` Abs b (Abs scs (ListE tys)) {-# INLINE fromE #-} @@ -1873,7 +1910,7 @@ instance HasSourceName (ClassDef n) where instance GenericE InstanceDef where type RepE InstanceDef = - ClassName `PairE` LiftE [RoleExpl] `PairE` Abs (Nest CBinder) (ListE CAtom `PairE` InstanceBody) + ClassName `PairE` LiftE [RoleExpl] `PairE` Abs CBinders (ListE CAtom `PairE` InstanceBody) fromE (InstanceDef name expls bs params body) = name `PairE` LiftE expls `PairE` Abs bs (ListE params `PairE` body) toE (name `PairE` LiftE expls `PairE` Abs bs (ListE params `PairE` body)) = @@ -1979,7 +2016,7 @@ instance Semigroup (Cache n) where Cache (y1<>x1) (y2<>x2) (y3<>x3) (y4<>x4) (x5<>y5) (x6<>y6) instance GenericE (LamExpr r) where - type RepE (LamExpr r) = Abs (Nest (Binder r)) (Block r) + type RepE (LamExpr r) = Abs (Binders r) (Block r) fromE (LamExpr b block) = Abs b block {-# INLINE fromE #-} toE (Abs b block) = LamExpr b block @@ -2009,7 +2046,7 @@ deriving instance Show (CoreLamExpr n) deriving via WrapE CoreLamExpr n instance Generic (CoreLamExpr n) instance GenericE CorePiType where - type RepE CorePiType = LiftE (AppExplicitness, [Explicitness]) `PairE` Abs (Nest CBinder) (EffTy CoreIR) + type RepE CorePiType = LiftE (AppExplicitness, [Explicitness]) `PairE` Abs CBinders (EffTy CoreIR) fromE (CorePiType ex exs b effTy) = LiftE (ex, exs) `PairE` Abs b effTy {-# INLINE fromE #-} toE (LiftE (ex, exs) `PairE` Abs b effTy) = CorePiType ex exs b effTy @@ -2065,7 +2102,7 @@ instance IRRep r => AlphaHashableE (IxType r) where hashWithSaltE env salt (IxType t _) = hashWithSaltE env salt t instance IRRep r => GenericE (TabPiType r) where - type RepE (TabPiType r) = PairE (IxDict r) (Abs (Binder r) (Type r)) + type RepE (TabPiType r) = PairE (IxDict r) (Abs (BinderAndDecls r) (Type r)) fromE (TabPiType d b resultTy) = PairE d (Abs b resultTy) {-# INLINE fromE #-} toE (PairE d (Abs b resultTy)) = TabPiType d b resultTy @@ -2088,7 +2125,7 @@ deriving instance IRRep r => Show (TabPiType r n) deriving via WrapE (TabPiType r) n instance IRRep r => Generic (TabPiType r n) instance GenericE (PiType r) where - type RepE (PiType r) = Abs (Nest (Binder r)) (EffTy r) + type RepE (PiType r) = Abs (Binders r) (EffTy r) fromE (PiType bs effTy) = Abs bs effTy {-# INLINE fromE #-} toE (Abs bs effTy) = PiType bs effTy @@ -2104,7 +2141,7 @@ deriving via WrapE (PiType r) n instance IRRep r => Generic (PiType r n) instance IRRep r => Store (PiType r n) instance GenericE (DepPairType r) where - type RepE (DepPairType r) = PairE (LiftE DepPairExplicitness) (Abs (Binder r) (Type r)) + type RepE (DepPairType r) = PairE (LiftE DepPairExplicitness) (Abs (BinderAndDecls r) (Type r)) fromE (DepPairType expl b resultTy) = LiftE expl `PairE` Abs b resultTy {-# INLINE fromE #-} toE (LiftE expl `PairE` Abs b resultTy) = DepPairType expl b resultTy @@ -2230,7 +2267,7 @@ instance AlphaHashableE TopFun instance GenericE SpecializationSpec where type RepE SpecializationSpec = - PairE (AtomVar CoreIR) (Abs (Nest (Binder CoreIR)) (ListE CAtom)) + PairE (AtomVar CoreIR) (Abs (Binders CoreIR) (ListE CAtom)) fromE (AppSpecialization fname (Abs bs args)) = PairE fname (Abs bs args) {-# INLINE fromE #-} toE (PairE fname (Abs bs args)) = AppSpecialization fname (Abs bs args) @@ -2378,6 +2415,27 @@ instance IRRep r => AlphaHashableB (Decl r) instance IRRep r => ProvesExt (Decl r) instance IRRep r => BindsNames (Decl r) +instance GenericB (BinderAndDecls r) where + type RepB (BinderAndDecls r) = Binder r + fromB (BD b) = b + {-# INLINE fromB #-} + toB b = BD b + {-# INLINE toB #-} + +instance HasNameHint (BinderAndDecls r n l) where + getNameHint (BD b) = getNameHint b + +deriving instance IRRep r => Show (BinderAndDecls r n l) +deriving via WrapB (BinderAndDecls r) n l instance IRRep r => Generic (BinderAndDecls r n l) + +instance IRRep r => SinkableB (BinderAndDecls r) +instance IRRep r => HoistableB (BinderAndDecls r) +instance IRRep r => RenameB (BinderAndDecls r) +instance IRRep r => AlphaEqB (BinderAndDecls r) +instance IRRep r => AlphaHashableB (BinderAndDecls r) +instance IRRep r => ProvesExt (BinderAndDecls r) +instance IRRep r => BindsNames (BinderAndDecls r) + instance IRRep r => GenericE (Effect r) where type RepE (Effect r) = EitherE3 (PairE (LiftE RWS) (Atom r)) @@ -2744,6 +2802,7 @@ instance Store (SpecializationSpec n) instance Store (LinearizationSpec n) instance IRRep r => Store (DeclBinding r n) instance IRRep r => Store (Decl r n l) +instance IRRep r => Store (BinderAndDecls r n l) instance Store (TyConParams n) instance Store (DataConDefs n) instance Store (TyConDef n) diff --git a/src/lib/Vectorize.hs b/src/lib/Vectorize.hs index d9a62728a..e1cc78587 100644 --- a/src/lib/Vectorize.hs +++ b/src/lib/Vectorize.hs @@ -26,7 +26,6 @@ import Subst import PPrint import QueryType import Types.Core -import Types.OpNames qualified as P import Types.Primitives import Util (allM, zipWithZ) @@ -94,10 +93,10 @@ newtype TopVectorizeM (i::S) (o::S) (a:: *) = TopVectorizeM vectorizeLoops :: EnvReader m => Word32 -> STopLam n -> m n (STopLam n, Errs) vectorizeLoops width (TopLam d ty (LamExpr bsDestB body)) = liftEnvReaderM do case popNest bsDestB of - Just (PairB bs b) -> + Just (PairB bs (BD b)) -> refreshAbs (Abs bs (Abs b body)) \bs' body' -> do (Abs b'' body'', errs) <- liftTopVectorizeM width $ vectorizeLoopsDestBlock body' - return $ (TopLam d ty (LamExpr (bs' >>> UnaryNest b'') body''), errs) + return $ (TopLam d ty (LamExpr (bs' >>> UnaryNest (PlainBD b'')) body''), errs) Nothing -> error "expected a trailing dest binder" liftTopVectorizeM :: (EnvReader m) @@ -159,12 +158,12 @@ vectorizeLoopsDecls nest cont = vectorizeLoopsLamExpr :: LamExpr SimpIR i -> TopVectorizeM i o (LamExpr SimpIR o) vectorizeLoopsLamExpr (LamExpr bs body) = case bs of Empty -> LamExpr Empty <$> buildBlock (vectorizeLoopsBlock body) - Nest (b:>ty) rest -> do - ty' <- renameM ty - withFreshBinder (getNameHint b) ty' \b' -> do - extendRenamer (b @> binderName b') do + Nest b rest -> do + ty <- renameM $ binderType b + withFreshBinder (getNameHint b) ty \b' -> do + extendSubstBD b [binderName b'] do LamExpr bs' body' <- vectorizeLoopsLamExpr $ LamExpr rest body - return $ LamExpr (Nest b' bs') body' + return $ LamExpr (Nest (BD b') bs') body' vectorizeLoopsExpr :: (Emits o) => SExpr i -> TopVectorizeM i o (SExpr o) vectorizeLoopsExpr expr = do @@ -197,8 +196,8 @@ vectorizeLoopsExpr expr = do item' <- renameM item itemTy <- return $ getType item' lam <- buildEffLam noHint itemTy \hb refb -> - extendRenamer (hb' @> atomVarName hb) do - extendRenamer (refb' @> atomVarName refb) do + extendSubstBD hb' [atomVarName hb] do + extendSubstBD refb' [atomVarName refb] do vectorizeLoopsBlock body PrimOp . Hof <$> mkTypedHof (RunReader item' lam) PrimOp (Hof (TypedHof (EffTy _ ty) @@ -208,8 +207,8 @@ vectorizeLoopsExpr expr = do commutativity <- monoidCommutativity monoid' PairTy _ accTy <- renameM ty lam <- buildEffLam noHint accTy \hb refb -> - extendRenamer (hb' @> atomVarName hb) do - extendRenamer (refb' @> atomVarName refb) do + extendSubstBD hb'[atomVarName hb] do + extendSubstBD refb' [atomVarName refb] do extendCommuteMap (atomVarName hb) commutativity do vectorizeLoopsBlock body PrimOp . Hof <$> mkTypedHof (RunWriter (Just dest') monoid' lam) @@ -257,22 +256,12 @@ monoidCommutativity monoid = case isAdditionMonoid monoid of Nothing -> return DoesNotCommute {-# INLINE monoidCommutativity #-} +-- XXX: this is wrong. We need to add a mechanism for specifying commutativity isAdditionMonoid :: BaseMonoid SimpIR n -> Maybe () isAdditionMonoid monoid = do BaseMonoid { baseEmpty = (Con (Lit l)) - , baseCombine = BinaryLamExpr (b1:>_) (b2:>_) body } <- Just monoid + , baseCombine = BinaryLamExpr _ _ _ } <- Just monoid unless (_isZeroLit l) Nothing - PrimOp (BinOp op (Var b1') (Var b2')) <- exprBlock body - unless (op `elem` [P.IAdd, P.FAdd]) Nothing - case (binderName b1, atomVarName b1', binderName b2, atomVarName b2') of - -- Checking the raw names here because (i) I don't know how to convince the - -- name system to let me check the well-typed names (which is because b2 - -- might shadow b1), and (ii) there are so few patterns that I can just - -- enumerate them. - (UnsafeMakeName n1, UnsafeMakeName n1', UnsafeMakeName n2, UnsafeMakeName n2') -> do - when (n1 == n2) Nothing - unless ((n1 == n1' && n2 == n2') || (n1 == n2' && n2 == n1')) Nothing - Just () _isZeroLit :: LitVal -> Bool _isZeroLit = \case @@ -369,14 +358,14 @@ vectorizeLamExpr (LamExpr bs body) argStabilities = case (bs, argStabilities) of vectorizeBlock body >>= \case (VVal _ ans) -> return ans (VRename v) -> Var <$> toAtomVar v) - (Nest (b:>ty) rest, (stab:stabs)) -> do - ty' <- vectorizeType ty + (Nest b rest, (stab:stabs)) -> do + ty' <- vectorizeType $ binderType b ty'' <- promoteTypeByStability ty' stab withFreshBinder (getNameHint b) ty'' \b' -> do var <- toAtomVar $ binderName b' - extendSubst (b @> VVal stab (Var var)) do + extendSubstBD b [VVal stab (Var var)] do LamExpr rest' body' <- vectorizeLamExpr (LamExpr rest body) stabs - return $ LamExpr (Nest b' rest') body' + return $ LamExpr (Nest (PlainBD b') rest') body' _ -> error "Zip error" vectorizeBlock :: Emits o => SBlock i -> VectorizeM i o (VAtom o) diff --git a/tests/unit/ConstantCastingSpec.hs b/tests/unit/ConstantCastingSpec.hs index fe9abab12..e7ff079f1 100644 --- a/tests/unit/ConstantCastingSpec.hs +++ b/tests/unit/ConstantCastingSpec.hs @@ -31,9 +31,9 @@ castOp ty x = MiscOp $ CastOp (BaseTy (Scalar ty)) x castLam :: EnvExtender m => ScalarBaseType -> ScalarBaseType -> m n (SLam n) castLam fromTy toTy = do - withFreshBinder noHint (BaseTy (Scalar fromTy)) \x -> do - body <- exprToBlock $ PrimOp $ castOp toTy $ Var $ binderVar x - return $ LamExpr (Nest x Empty) body + withFreshBinder noHint (BaseTy (Scalar fromTy)) \b -> do + body <- exprToBlock $ PrimOp $ castOp toTy $ Var $ binderVar b + return $ LamExpr (Nest (PlainBD b) Empty) body exprToBlock :: EnvReader m => SExpr n -> m n (SBlock n) exprToBlock expr = do From 84d12b8c39f415df7da0972fc686723158190680 Mon Sep 17 00:00:00 2001 From: Dougal Date: Fri, 21 Jul 2023 12:34:25 -0400 Subject: [PATCH 2/2] WIP --- dex.cabal | 3 +- src/lib/Builder.hs | 728 ++++++++++++++++++++++++++--- src/lib/CheapReduction.hs | 958 -------------------------------------- src/lib/CheckType.hs | 150 +++--- src/lib/Core.hs | 10 +- src/lib/Export.hs | 1 - src/lib/Generalize.hs | 17 +- src/lib/Imp.hs | 172 +++---- src/lib/Inference.hs | 794 ++++++++++++++++--------------- src/lib/Inline.hs | 16 +- src/lib/Linearize.hs | 86 ++-- src/lib/Lower.hs | 40 +- src/lib/OccAnalysis.hs | 16 +- src/lib/Optimize.hs | 56 +-- src/lib/PPrint.hs | 17 +- src/lib/QueryType.hs | 736 ++++++++++++----------------- src/lib/QueryTypePure.hs | 312 ------------- src/lib/RuntimePrint.hs | 26 +- src/lib/Simplify.hs | 404 ++++++++-------- src/lib/Subst.hs | 148 +++++- src/lib/TopLevel.hs | 20 +- src/lib/Transpose.hs | 60 ++- src/lib/Types/Core.hs | 92 ++-- src/lib/Vectorize.hs | 66 +-- src/lib/Visitor.hs | 269 +++++++++++ 25 files changed, 2422 insertions(+), 2775 deletions(-) delete mode 100644 src/lib/CheapReduction.hs delete mode 100644 src/lib/QueryTypePure.hs create mode 100644 src/lib/Visitor.hs diff --git a/dex.cabal b/dex.cabal index 1b6b7b771..b44c6a2ae 100644 --- a/dex.cabal +++ b/dex.cabal @@ -49,7 +49,7 @@ library , Builder , CUDA , Cat - , CheapReduction + , Visitor , CheckType , ConcreteSyntax , Core @@ -96,7 +96,6 @@ library , Types.OpNames , Types.Source , QueryType - , QueryTypePure , Util , Vectorize if flag(live) diff --git a/src/lib/Builder.hs b/src/lib/Builder.hs index b00baa6d1..cf3940b10 100644 --- a/src/lib/Builder.hs +++ b/src/lib/Builder.hs @@ -10,19 +10,26 @@ module Builder where import Control.Applicative +import Control.Category ((>>>)) import Control.Monad import Control.Monad.Reader import Control.Monad.Writer.Strict hiding (Alt) import Control.Monad.State.Strict (MonadState (..), StateT (..), runStateT) import qualified Data.Map.Strict as M import Data.Foldable (fold) +import Data.Functor ((<&>)) +import Data.Graph (graphFromEdges, topSort) +import Data.List (elemIndex) +import Data.Text.Prettyprint.Doc (Pretty (..)) +import Foreign.Ptr + +import qualified Unsafe.Coerce as TrulyUnsafe import Data.Graph (graphFromEdges, topSort) import Data.Text.Prettyprint.Doc (Pretty (..)) import Foreign.Ptr import qualified Unsafe.Coerce as TrulyUnsafe -import CheapReduction import Core import Err import IRVariants @@ -34,7 +41,8 @@ import Types.Core import Types.Imp import Types.Primitives import Types.Source -import Util (enumerate, transitiveClosureM, bindM2, toSnocList, (...)) +import Util (enumerate, transitiveClosureM, bindM2, toSnocList, (...), Tree (..)) +import Visitor -- === Ordinary (local) builder class === @@ -49,6 +57,12 @@ class Builder r m => ScopableBuilder (r::IR) (m::MonadKind1) | m -> r where -> (forall l. DExt n l => Nest (Decl r) n l -> e l -> m l a) -> m n a +buildScoped2 + :: (ScopableBuilder r m, SinkableE e) + => (forall l. (Emits l, DExt n l) => m l (Nest (Decl r) n l -> a)) + -> m n a +buildScoped2 cont = undefined -- buildScopedAndThen cont \decls body -> return $ Abs decls body + buildScoped :: (ScopableBuilder r m, SinkableE e) => (forall l. (Emits l, DExt n l) => m l (e l)) @@ -61,6 +75,10 @@ type CBuilder = Builder CoreIR type Builder2 (r::IR) (m :: MonadKind2) = forall i. Builder r (m i) type ScopableBuilder2 (r::IR) (m :: MonadKind2) = forall i. ScopableBuilder r (m i) +type BuilderEmits r m n = (Builder r m, Emits n, IRRep r) +type CBuilderEmits m n = BuilderEmits CoreIR m n +type SBuilderEmits m n = BuilderEmits SimpIR m n + emitDecl :: (Builder r m, Emits n) => NameHint -> LetAnn -> Expr r n -> m n (AtomVar r n) emitDecl _ _ (Atom (Var n)) = return n emitDecl hint ann expr = rawEmitDecl hint ann expr @@ -78,23 +96,23 @@ emitOp :: (Builder r m, IsPrimOp e, Emits n) => e r n -> m n (Atom r n) emitOp op = Var <$> emit (PrimOp $ toPrimOp op) {-# INLINE emitOp #-} -emitExpr :: (Builder r m, Emits n) => Expr r n -> m n (Atom r n) +emitExpr :: BuilderEmits r m n => Expr r n -> m n (Atom r n) emitExpr expr = Var <$> emit expr {-# INLINE emitExpr #-} -emitHof :: (Builder r m, Emits n) => Hof r n -> m n (Atom r n) +emitHof :: BuilderEmits r m n => Hof r n -> m n (Atom r n) emitHof hof = mkTypedHof hof >>= emitOp -mkTypedHof :: (EnvReader m, IRRep r) => Hof r n -> m n (TypedHof r n) +mkTypedHof :: BuilderEmits r m n => Hof r n -> m n (TypedHof r n) mkTypedHof hof = do effTy <- effTyOfHof hof return $ TypedHof effTy hof -emitUnOp :: (Builder r m, Emits n) => UnOp -> Atom r n -> m n (Atom r n) +emitUnOp :: BuilderEmits r m n => UnOp -> Atom r n -> m n (Atom r n) emitUnOp op x = emitOp $ UnOp op x {-# INLINE emitUnOp #-} -emitBlock :: (Builder r m, Emits n) => Block r n -> m n (Atom r n) +emitBlock :: BuilderEmits r m n => Block r n -> m n (Atom r n) emitBlock = emitDecls emitDecls :: (Builder r m, Emits n, RenameE e, SinkableE e) @@ -109,7 +127,7 @@ emitDecls' (Nest (Let b (DeclBinding ann expr)) rest) e = do AtomVar v _ <- emitDecl (getNameHint b) ann expr' extendSubst (b @> v) $ emitDecls' rest e -emitExprToAtom :: (Builder r m, Emits n) => Expr r n -> m n (Atom r n) +emitExprToAtom :: BuilderEmits r m n => Expr r n -> m n (Atom r n) emitExprToAtom (Atom atom) = return atom emitExprToAtom expr = Var <$> emit expr {-# INLINE emitExprToAtom #-} @@ -680,13 +698,14 @@ buildLamExpr => (Abs (Binders r) any n) -> (forall l. (Emits l, Distinct l, DExt n l) => [AtomVar r l] -> m l (Atom r l)) -> m n (LamExpr r n) -buildLamExpr (Abs bs _) cont = case bs of - Empty -> LamExpr Empty <$> buildBlock (cont []) - Nest b rest -> do - Abs b' (LamExpr bs' body') <- buildAbs (getNameHint b) (binderType b) \v -> do - rest' <- instantiate (Abs (UnaryNest b) (EmptyAbs rest)) [Var v] - buildLamExpr rest' \vs -> cont $ sink v : vs - return $ LamExpr (Nest (PlainBD b') bs') body' +buildLamExpr (Abs bs _) cont = undefined +-- buildLamExpr (Abs bs _) cont = case bs of + -- Empty -> LamExpr Empty <$> buildBlock (cont []) + -- Nest b rest -> do + -- Abs b' (LamExpr bs' body') <- buildAbs (getNameHint b) (binderType b) \v -> do + -- rest' <- instantiate (Abs (UnaryNest b) (EmptyAbs rest)) [Var v] + -- buildLamExpr rest' \vs -> cont $ sink v : vs + -- return $ LamExpr (Nest (PlainBD b') bs') body' buildTopLamFromPi :: ScopableBuilder r m @@ -1030,44 +1049,39 @@ getProjRef i r = emitOp =<< mkProjRef r i -- ProjectElt atoms are always fully reduced (to avoid type errors between two -- equivalent types spelled differently). getUnpacked :: (Fallible1 m, EnvReader m, IRRep r) => Atom r n -> m n [Atom r n] -getUnpacked atom = do - atom' <- cheapNormalize atom - ty <- return $ getType atom' - positions <- case ty of - ProdTy tys -> return $ void tys - DepPairTy _ -> return [(), ()] - _ -> error $ "not a product type: " ++ pprint ty - forM (enumerate positions) \(i, _) -> - normalizeProj (ProjectProduct i) atom' -{-# SCC getUnpacked #-} +getUnpacked atom = undefined +-- getUnpacked atom = do +-- atom' <- cheapNormalize atom +-- ty <- return $ getType atom' +-- positions <- case ty of +-- ProdTy tys -> return $ void tys +-- DepPairTy _ -> return [(), ()] +-- _ -> error $ "not a product type: " ++ pprint ty +-- forM (enumerate positions) \(i, _) -> +-- normalizeProj (ProjectProduct i) atom' +-- {-# SCC getUnpacked #-} getProj :: (Builder r m, Emits n) => Int -> Atom r n -> m n (Atom r n) -getProj i atom = do - atom' <- cheapNormalize atom - normalizeProj (ProjectProduct i) atom' +getProj i atom = undefined +-- getProj i atom = do +-- atom' <- cheapNormalize atom +-- normalizeProj (ProjectProduct i) atom' emitUnpacked :: (Builder r m, Emits n) => Atom r n -> m n [AtomVar r n] emitUnpacked tup = do xs <- getUnpacked tup forM xs \x -> emit $ Atom x -unwrapNewtype :: EnvReader m => CAtom n -> m n (CAtom n) -unwrapNewtype (NewtypeCon _ x) = return x -unwrapNewtype x = case getType x of - NewtypeTyCon con -> do - (_, ty) <- unwrapNewtypeType con - return $ ProjectElt ty UnwrapNewtype x - _ -> error "not a newtype" -{-# INLINE unwrapNewtype #-} - projectTuple :: (IRRep r, EnvReader m) => Int -> Atom r n -> m n (Atom r n) -projectTuple i x = normalizeProj (ProjectProduct i) x +projectTuple i x = undefined +-- projectTuple i x = normalizeProj (ProjectProduct i) x projectStruct :: EnvReader m => Int -> CAtom n -> m n (CAtom n) -projectStruct i x = do - projs <- getStructProjections i (getType x) - normalizeNaryProj projs x -{-# INLINE projectStruct #-} +projectStruct i x = undefined +-- projectStruct i x = do +-- projs <- getStructProjections i (getType x) +-- normalizeNaryProj projs x +-- {-# INLINE projectStruct #-} projectStructRef :: (Builder CoreIR m, Emits n) => Int -> CAtom n -> m n (CAtom n) projectStructRef i x = do @@ -1085,27 +1099,27 @@ getStructProjections i (NewtypeTyCon (UserADTType _ tyConName _)) = do _ -> [ProjectProduct i, UnwrapNewtype] getStructProjections _ _ = error "not a struct" -mkApp :: EnvReader m => CAtom n -> [CAtom n] -> m n (CExpr n) +mkApp :: BuilderEmits CoreIR m n => CAtom n -> [CAtom n] -> m n (CExpr n) mkApp f xs = do et <- appEffTy (getType f) xs return $ App et f xs -mkTabApp :: (EnvReader m, IRRep r) => Atom r n -> [Atom r n] -> m n (Expr r n) +mkTabApp :: BuilderEmits r m n => Atom r n -> [Atom r n] -> m n (Expr r n) mkTabApp xs ixs = do ty <- typeOfTabApp (getType xs) ixs return $ TabApp ty xs ixs -mkTopApp :: EnvReader m => TopFunName n -> [SAtom n] -> m n (SExpr n) +mkTopApp :: BuilderEmits SimpIR m n => TopFunName n -> [SAtom n] -> m n (SExpr n) mkTopApp f xs = do resultTy <- typeOfTopApp f xs return $ TopApp resultTy f xs -mkApplyMethod :: EnvReader m => CAtom n -> Int -> [CAtom n] -> m n (CExpr n) +mkApplyMethod :: BuilderEmits CoreIR m n => CAtom n -> Int -> [CAtom n] -> m n (CExpr n) mkApplyMethod d i xs = do resultTy <- typeOfApplyMethod d i xs return $ ApplyMethod resultTy d i xs -mkDictAtom :: EnvReader m => DictExpr n -> m n (CAtom n) +mkDictAtom :: BuilderEmits CoreIR m n => DictExpr n -> m n (CAtom n) mkDictAtom d = do ty <- typeOfDictExpr d return $ DictCon ty d @@ -1146,41 +1160,41 @@ naryAppHinted :: (CBuilder m, Emits n) => NameHint -> CAtom n -> [CAtom n] -> m n (CAtom n) naryAppHinted hint f xs = Var <$> (mkApp f xs >>= emitHinted hint) -tabApp :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) +tabApp :: BuilderEmits r m n => Atom r n -> Atom r n -> m n (Atom r n) tabApp x i = mkTabApp x [i] >>= emitExpr -naryTabApp :: (Builder r m, Emits n) => Atom r n -> [Atom r n] -> m n (Atom r n) +naryTabApp :: BuilderEmits r m n => Atom r n -> [Atom r n] -> m n (Atom r n) naryTabApp = naryTabAppHinted noHint {-# INLINE naryTabApp #-} -naryTabAppHinted :: (Builder r m, Emits n) +naryTabAppHinted :: BuilderEmits r m n => NameHint -> Atom r n -> [Atom r n] -> m n (Atom r n) naryTabAppHinted hint f xs = do expr <- mkTabApp f xs Var <$> emitHinted hint expr -indexRef :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) +indexRef :: BuilderEmits r m n => Atom r n -> Atom r n -> m n (Atom r n) indexRef ref i = emitOp =<< mkIndexRef ref i -naryIndexRef :: (Builder r m, Emits n) => Atom r n -> [Atom r n] -> m n (Atom r n) +naryIndexRef :: BuilderEmits r m n => Atom r n -> [Atom r n] -> m n (Atom r n) naryIndexRef ref is = foldM indexRef ref is -ptrOffset :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) +ptrOffset :: BuilderEmits r m n => Atom r n -> Atom r n -> m n (Atom r n) ptrOffset x (IdxRepVal 0) = return x ptrOffset x i = emitOp $ MemOp $ PtrOffset x i {-# INLINE ptrOffset #-} -unsafePtrLoad :: (Builder r m, Emits n) => Atom r n -> m n (Atom r n) +unsafePtrLoad :: BuilderEmits r m n => Atom r n -> m n (Atom r n) unsafePtrLoad x = do body <- liftEmitBuilder $ buildBlock $ emitOp . MemOp . PtrLoad =<< sinkM x emitHof $ RunIO body -mkIndexRef :: (EnvReader m, Fallible1 m, IRRep r) => Atom r n -> Atom r n -> m n (PrimOp r n) +mkIndexRef :: BuilderEmits r m n => Atom r n -> Atom r n -> m n (PrimOp r n) mkIndexRef ref i = do resultTy <- typeOfIndexRef (getType ref) i return $ RefOp ref $ IndexRef resultTy i -mkProjRef :: (EnvReader m, IRRep r) => Atom r n -> Projection -> m n (PrimOp r n) +mkProjRef :: BuilderEmits r m n => Atom r n -> Projection -> m n (PrimOp r n) mkProjRef ref i = do resultTy <- typeOfProjRef (getType ref) i return $ RefOp ref $ ProjRef resultTy i @@ -1380,8 +1394,8 @@ instance IRRep r => SubstE AtomSubstVal (TelescopeType r) instance IRRep r => HoistableE (TelescopeType r) telescopicCapture - :: (EnvReader m, HoistableE e, HoistableB b, IRRep r) - => b n l -> e l -> m l (Atom r l, ReconAbs r e n) + :: (ScopableBuilder r m, HoistableE e, HoistableB b) + => b n l -> e l -> m l (Block r l, ReconAbs r e n) telescopicCapture bs e = do vs <- localVarsAndTypeVars bs e vTys <- forM vs \v -> getType <$> toAtomVar v @@ -1389,7 +1403,7 @@ telescopicCapture bs e = do let vsSorted = map fst vsTysSorted ty <- liftEnvReaderM $ buildTelescopeTy vsTysSorted valsSorted <- forM vsSorted \v -> Var <$> toAtomVar v - result <- buildTelescopeVal valsSorted ty + result <- buildScoped $ buildTelescopeVal (sink <$> valsSorted) (sink ty) reconAbs <- return $ ignoreHoistFailure $ hoist bs do case abstractFreeVarsNoAnn vsSorted e of Abs bs' e' -> Abs (ReconBinders ty bs') e' @@ -1411,20 +1425,16 @@ buildTelescopeTy ((v,ty):xs) = do Abs b rhs' <- return $ abstractFreeVar v rhs case hoist b rhs' of HoistSuccess rhs'' -> return $ prependTelescopeTy ty rhs'' - HoistFailure _ -> return $ DepTelescope (ProdTelescope []) (Abs (BD (b:>ty)) rhs') + HoistFailure _ -> return $ DepTelescope (ProdTelescope []) (Abs (PlainBD (b:>ty)) rhs') prependTelescopeTy :: Type r n -> TelescopeType r n -> TelescopeType r n prependTelescopeTy x = \case DepTelescope lhs rhs -> DepTelescope (prependTelescopeTy x lhs) rhs ProdTelescope xs -> ProdTelescope (x:xs) -buildTelescopeVal - :: (EnvReader m, IRRep r) => [Atom r n] - -> TelescopeType r n -> m n (Atom r n) +buildTelescopeVal :: BuilderEmits r m n => [Atom r n] -> TelescopeType r n -> m n (Atom r n) buildTelescopeVal xsTop tyTop = fst <$> go tyTop xsTop where - go :: (EnvReader m, IRRep r) - => TelescopeType r n -> [Atom r n] - -> m n (Atom r n, [Atom r n]) + go :: BuilderEmits r m n => TelescopeType r n -> [Atom r n] -> m n (Atom r n, [Atom r n]) go ty rest = case ty of ProdTelescope tys -> do (xs, rest') <- return $ splitAt (length tys) rest @@ -1592,3 +1602,589 @@ applyFloatBinOp f x y = case (x, y) of _applyFloatUnOp :: (forall a. (Num a, Fractional a) => a -> a) -> Atom r n -> Atom r n _applyFloatUnOp f x = applyFloatBinOp (\_ -> f) (error "shouldn't be needed") x + + +-- === Stuff that used to be in Normalize/CheapReduce === + +-- We should figure out how to organize these things into modules but for now +-- I'm putting them in one file to avoid circular dependency issues. + +repValAtom :: EnvReader m => SRepVal n -> m n (SAtom n) +repValAtom (RepVal ty tree) = case ty of + ProdTy ts -> case tree of + Branch trees -> ProdVal <$> mapM repValAtom (zipWith RepVal ts trees) + _ -> malformed + BaseTy _ -> case tree of + Leaf x -> case x of + ILit l -> return $ Con $ Lit l + _ -> fallback + _ -> malformed + _ -> fallback + where fallback = return $ RepValAtom $ RepVal ty tree + malformed = error "malformed repval" +{-# INLINE repValAtom #-} + +liftSimpType :: EnvReader m => SType n -> m n (CType n) +liftSimpType = \case + BaseTy t -> return $ BaseTy t + ProdTy ts -> ProdTy <$> mapM rec ts + SumTy ts -> SumTy <$> mapM rec ts + t -> error $ "not implemented: " ++ pprint t + where rec = liftSimpType +{-# INLINE liftSimpType #-} + +liftSimpFun :: EnvReader m => Type CoreIR n -> LamExpr SimpIR n -> m n (CAtom n) +liftSimpFun (Pi piTy) f = return $ SimpInCore $ LiftSimpFun piTy f +liftSimpFun _ _ = error "not a pi type" + +unwrapNewtypeType :: BuilderEmits CoreIR m n => NewtypeTyCon n -> m n (NewtypeCon n, Type CoreIR n) +unwrapNewtypeType = \case + Nat -> return (NatCon, IdxRepTy) + Fin n -> return (FinCon n, NatTy) + UserADTType sn defName params -> do + def <- lookupTyCon defName + ty' <- dataDefRep <$> instantiateTyConDef def params + return (UserADTData sn defName params, ty') + ty -> error $ "Shouldn't be projecting: " ++ pprint ty +{-# INLINE unwrapNewtypeType #-} + +unwrapLeadingNewtypesType :: BuilderEmits CoreIR m n => CType n -> m n ([NewtypeCon n], CType n) +unwrapLeadingNewtypesType = \case + NewtypeTyCon tyCon -> do + (dataCon, ty) <- unwrapNewtypeType tyCon + (dataCons, ty') <- unwrapLeadingNewtypesType ty + return (dataCon:dataCons, ty') + ty -> return ([], ty) + +wrapNewtypesData :: [NewtypeCon n] -> CAtom n-> CAtom n +wrapNewtypesData [] x = x +wrapNewtypesData (c:cs) x = NewtypeCon c $ wrapNewtypesData cs x + +instantiateTyConDef :: BuilderEmits CoreIR m n => TyConDef n -> TyConParams n -> m n (DataConDefs n) +instantiateTyConDef tyConDef (TyConParams _ xs) = instantiate tyConDef xs +{-# INLINE instantiateTyConDef #-} + +assumeConst + :: (IRRep r, HoistableE body, SinkableE body, ToBindersAbs e body r) => e n -> body n +assumeConst e = case toAbs e of Abs bs body -> ignoreHoistFailure $ hoist bs body + +arity :: (IRRep r, ToBindersAbs e body r) => e n -> Int +arity e = case toAbs e of Abs bs _ -> nestLength bs + +tryAsConst + :: (IRRep r, HoistableE body, SinkableE body, ToBindersAbs e body r) => e n -> Maybe (body n) +tryAsConst e = + case toAbs e of + Abs bs body -> case hoist bs body of + HoistFailure _ -> Nothing + HoistSuccess e' -> Just e' + +instantiate + :: (Builder r m, Emits n, IRRep r, SubstE (SubstVal Atom) body, SinkableE body, SinkableE e, + ToBindersAbs e body r, Ext h n) + => e h -> [Atom r n] -> m n (body n) +instantiate e xs = undefined +-- instantiate e xs = do +-- Abs bs body <- sinkM $ toAbs e +-- let bs' = fmapNest (\(BD b) -> b) bs +-- applySubst (bs' @@> (SubstVal <$> xs)) body +-- {-# INLINE instantiate #-} + +-- "lazy" subst-extending version of `instantiate` +withInstantiated + :: (SubstReader AtomSubstVal m, IRRep r, SubstE (SubstVal Atom) body, SinkableE body, ToBindersAbs e body r) + => e i -> [Atom r o] + -> (forall i'. body i' -> m i' o a) + -> m i o a +withInstantiated e xs cont = undefined +-- withInstantiated e xs cont = do +-- Abs bs body <- return $ toAbs e +-- let bs' = fmapNest (\(BD b) -> b) bs +-- extendSubst (bs' @@> (SubstVal <$> xs)) $ cont body + +instantiateNames + :: (EnvReader m, IRRep r, RenameE body, SinkableE body, ToBindersAbs e body r, Ext h n) + => e h -> [AtomName r n] -> m n (body n) +instantiateNames e vs = undefined +-- instantiateNames e vs = do +-- Abs bs body <- sinkM $ toAbs e +-- let bs' = fmapNest (\(BD b) -> b) bs +-- applyRename (bs' @@> vs) body + +-- "lazy" subst-extending version of `instantiateNames` +withInstantiatedNames + :: (SubstReader Name m, IRRep r, RenameE body, SinkableE body, ToBindersAbs e body r) + => e i -> [AtomName r o] + -> (forall i'. body i' -> m i' o a) + -> m i o a +withInstantiatedNames e vs cont = undefined +-- withInstantiatedNames e vs cont = do +-- Abs bs body <- return $ toAbs e +-- let bs' = fmapNest (\(BD b) -> b) bs +-- extendRenamer (bs' @@> vs) $ cont body + +extendSubstBD + :: forall v m b r i i' o a + . (SubstReader v m, ToBinders b r, IRRep r) + => b i i' -> [v (AtomNameC r) o] -> m i' o a -> m i o a +extendSubstBD bsTop xsTop contTop = undefined -- go (toBinders bsTop) xsTop contTop +-- extendSubstBD bsTop xsTop contTop = go (toBinders bsTop) xsTop contTop +-- where +-- go :: Binders r ii ii' -> [v (AtomNameC r) o] -> m ii' o a -> m ii o a +-- go Empty [] cont = cont +-- go (Nest (BD b) bs) (x:xs) cont = extendSubst (b@>x) $ go bs xs cont +-- go _ _ _ = error "zip error" +-- {-# INLINE extendSubstBD #-} + +-- Returns a representation type (type of an TypeCon-typed Newtype payload) +-- given a list of instantiated DataConDefs. +dataDefRep :: DataConDefs n -> CType n +dataDefRep (ADTCons cons) = case cons of + [] -> error "unreachable" -- There's no representation for a void type + [DataConDef _ _ ty _] -> ty + tys -> SumTy $ tys <&> \(DataConDef _ _ ty _) -> ty +dataDefRep (StructFields fields) = case map snd fields of + [ty] -> ty + tys -> ProdTy tys + +makeStructRepVal :: (Fallible1 m, EnvReader m) => TyConName n -> [CAtom n] -> m n (CAtom n) +makeStructRepVal tyConName args = do + TyConDef _ _ _ (StructFields fields) <- lookupTyCon tyConName + case fields of + [_] -> case args of + [arg] -> return arg + _ -> error "wrong number of args" + _ -> return $ ProdVal args + +-- === Exposed helpers for querying types and effects === + +caseAltsBinderTys :: BuilderEmits r m n => Type r n -> m n [Type r n] +caseAltsBinderTys ty = case ty of + SumTy types -> return types + NewtypeTyCon t -> case t of + UserADTType _ defName params -> do + def <- lookupTyCon defName + ~(ADTCons cons) <- instantiateTyConDef def params + return [repTy | DataConDef _ _ repTy _ <- cons] + _ -> error msg + _ -> error msg + where msg = "Case analysis only supported on ADTs, not on " ++ pprint ty + +extendEffect :: IRRep r => Effect r n -> EffectRow r n -> EffectRow r n +extendEffect eff (EffectRow effs t) = EffectRow (effs <> eSetSingleton eff) t + +blockEffTy :: (EnvReader m, IRRep r) => Block r n -> m n (EffTy r n) +blockEffTy block = liftEnvReaderM $ refreshAbs block \decls result -> do + effs <- declsEffects decls mempty + return $ ignoreHoistFailure $ hoist decls $ EffTy effs $ getType result + where + declsEffects :: IRRep r => Nest (Decl r) n l -> EffectRow r l -> EnvReaderM l (EffectRow r l) + declsEffects Empty !acc = return acc + declsEffects n@(Nest (Let _ (DeclBinding _ expr)) rest) !acc = withExtEvidence n do + expr' <- sinkM expr + declsEffects rest $ acc <> getEffects expr' + +blockTy :: (EnvReader m, IRRep r) => Block r n -> m n (Type r n) +blockTy b = blockEffTy b <&> \(EffTy _ t) -> t + +piTypeWithoutDest :: PiType SimpIR n -> PiType SimpIR n +piTypeWithoutDest (PiType bsRefB _) = + case popNest bsRefB of + Just (PairB bs refB) -> do + case binderType refB of + RawRefTy ansTy -> PiType bs $ EffTy Pure ansTy -- XXX: we ignore the effects here + _ -> error "expected ref type" + _ -> error "expected trailing binder" + +blockEff :: (EnvReader m, IRRep r) => Block r n -> m n (EffectRow r n) +blockEff b = blockEffTy b <&> \(EffTy eff _) -> eff + +typeOfApp :: (IRRep r, EnvReader m) => Type r n -> [Atom r n] -> m n (Type r n) +typeOfApp (Pi piTy) xs = withSubstReaderT $ + withInstantiated piTy xs \(EffTy _ ty) -> substM ty +typeOfApp _ _ = error "expected a pi type" + +typeOfTabApp :: BuilderEmits r m n => Type r n -> [Atom r n] -> m n (Type r n) +typeOfTabApp t [] = return t +typeOfTabApp (TabPi tabTy) (i:rest) = do + resultTy <- instantiate tabTy [i] + typeOfTabApp resultTy rest +typeOfTabApp ty _ = error $ "expected a table type. Got: " ++ pprint ty + +typeOfApplyMethod :: BuilderEmits CoreIR m n => CAtom n -> Int -> [CAtom n] -> m n (EffTy CoreIR n) +typeOfApplyMethod d i args = do + ty <- Pi <$> getMethodType d i + appEffTy ty args + +typeOfDictExpr :: BuilderEmits CoreIR m n => DictExpr n -> m n (CType n) +typeOfDictExpr e = liftEmitBuilder case e of + InstanceDict instanceName args -> do + instanceDef@(InstanceDef className _ _ _ _) <- lookupInstanceDef instanceName + sourceName <- getSourceName <$> lookupClassDef className + PairE (ListE params) _ <- instantiate instanceDef args + return $ DictTy $ DictType sourceName className params + InstantiatedGiven given args -> typeOfApp (getType given) args + SuperclassProj d i -> do + DictTy (DictType _ className params) <- return $ getType d + classDef <- lookupClassDef className + withSubstReaderT $ withInstantiated classDef params \(Abs superclasses _) -> do + substM $ getSuperclassType REmpty superclasses i + IxFin n -> liftM DictTy $ ixDictType $ NewtypeTyCon $ Fin n + DataData ty -> DictTy <$> dataDictType ty + +typeOfTopApp :: BuilderEmits SimpIR m n => TopFunName n -> [SAtom n] -> m n (EffTy SimpIR n) +typeOfTopApp f xs = do + piTy <- getTypeTopFun f + instantiate piTy xs + +typeOfIndexRef :: BuilderEmits r m n => Type r n -> Atom r n -> m n (Type r n) +typeOfIndexRef (TC (RefType h s)) i = do + TabPi tabPi <- return s + eltTy <- instantiate tabPi [i] + return $ TC $ RefType h eltTy +typeOfIndexRef _ _ = error "expected a ref type" + +typeOfProjRef :: BuilderEmits r m n => Type r n -> Projection -> m n (Type r n) +typeOfProjRef (TC (RefType h s)) p = do + TC . RefType h <$> case p of + ProjectProduct i -> do + ~(ProdTy tys) <- return s + return $ tys !! i + UnwrapNewtype -> do + case s of + NewtypeTyCon tc -> snd <$> unwrapNewtypeType tc + _ -> error "expected a newtype" +typeOfProjRef _ _ = error "expected a reference" + +appEffTy :: BuilderEmits r m n => Type r n -> [Atom r n] -> m n (EffTy r n) +appEffTy (Pi piTy) xs = instantiate piTy xs +appEffTy t _ = error $ "expected a pi type, got: " ++ pprint t + +partialAppType :: BuilderEmits r m n => Type r n -> [Atom r n] -> m n (Type r n) +partialAppType (Pi (CorePiType appExpl expls bs effTy)) xs = do + (_, expls2) <- return $ splitAt (length xs) expls + PairB bs1 bs2 <- return $ splitNestAt (length xs) bs + instantiate (Abs bs1 (Pi $ CorePiType appExpl expls2 bs2 effTy)) xs +partialAppType _ _ = error "expected a pi type" + +effTyOfHof :: BuilderEmits r m n => Hof r n -> m n (EffTy r n) +effTyOfHof hof = EffTy <$> hofEffects hof <*> typeOfHof hof + +typeOfHof :: BuilderEmits r m n => Hof r n -> m n (Type r n) +typeOfHof = \case + For _ ixTy f -> getLamExprType f >>= \case + PiType (UnaryNest b) (EffTy _ eltTy) -> return $ TabTy (ixTypeDict ixTy) b eltTy + _ -> error "expected a unary pi type" + While _ -> return UnitTy + Linearize f _ -> getLamExprType f >>= \case + PiType (UnaryNest binder) (EffTy Pure b) -> do + let b' = ignoreHoistFailure $ hoist binder b + let fLinTy = Pi $ nonDepPiType [binderType binder] Pure b' + return $ PairTy b' fLinTy + _ -> error "expected a unary pi type" + Transpose f _ -> getLamExprType f >>= \case + PiType (UnaryNest b) _ -> return $ binderType b + _ -> error "expected a unary pi type" + RunReader _ f -> do + (resultTy, _) <- getTypeRWSAction f + return resultTy + RunWriter _ _ f -> uncurry PairTy <$> getTypeRWSAction f + RunState _ _ f -> do + (resultTy, stateTy) <- getTypeRWSAction f + return $ PairTy resultTy stateTy + RunIO f -> blockTy f + RunInit f -> blockTy f + CatchException ty _ -> return ty + +hofEffects :: (EnvReader m, IRRep r) => Hof r n -> m n (EffectRow r n) +hofEffects = \case + For _ _ f -> functionEffs f + While body -> blockEff body + Linearize _ _ -> return Pure -- Body has to be a pure function + Transpose _ _ -> return Pure -- Body has to be a pure function + RunReader _ f -> rwsFunEffects Reader f + RunWriter d _ f -> maybeInit d <$> rwsFunEffects Writer f + RunState d _ f -> maybeInit d <$> rwsFunEffects State f + RunIO f -> deleteEff IOEffect <$> blockEff f + RunInit f -> deleteEff InitEffect <$> blockEff f + CatchException _ f -> deleteEff ExceptionEffect <$> blockEff f + where maybeInit :: IRRep r => Maybe (Atom r i) -> (EffectRow r o -> EffectRow r o) + maybeInit d = case d of Just _ -> (<>OneEffect InitEffect); Nothing -> id + +deleteEff :: IRRep r => Effect r n -> EffectRow r n -> EffectRow r n +deleteEff eff (EffectRow effs t) = EffectRow (effs `eSetDifference` eSetSingleton eff) t + +getMethodIndex :: EnvReader m => ClassName n -> SourceName -> m n Int +getMethodIndex className methodSourceName = do + ClassDef _ methodNames _ _ _ _ _ <- lookupClassDef className + case elemIndex methodSourceName methodNames of + Nothing -> error $ methodSourceName ++ " is not a method of " ++ pprint className + Just i -> return i +{-# INLINE getMethodIndex #-} + +getUVarType :: BuilderEmits CoreIR m n => UVar n -> m n (CType n) +getUVarType = \case + UAtomVar v -> getType <$> toAtomVar v + UTyConVar v -> getTyConNameType v + UDataConVar v -> getDataConNameType v + UPunVar v -> getStructDataConType v + UClassVar v -> do + ClassDef _ _ _ roleExpls bs _ _ <- lookupClassDef v + return $ Pi $ CorePiType ExplicitApp (map snd roleExpls) bs $ EffTy Pure TyKind + UMethodVar v -> getMethodNameType v + UEffectVar _ -> error "not implemented" + UEffectOpVar _ -> error "not implemented" + +getMethodNameType :: BuilderEmits CoreIR m n => MethodName n -> m n (CType n) +getMethodNameType v = liftEmitBuilder $ lookupEnv v >>= \case + MethodBinding className i -> do + ClassDef _ _ paramNames _ paramBs scBinders methodTys <- lookupClassDef className + refreshAbs (Abs paramBs $ Abs scBinders (methodTys !! i)) \paramBs' absPiTy -> do + let params = Var <$> bindersVars paramBs' + dictTy <- DictTy <$> dictType (sink className) params + withFreshBinderAndDecls noHint dictTy \dict -> do + piTy <- instantiate absPiTy =<< getSuperclassDicts (Var dict) + CorePiType appExpl methodExpls methodBs effTy <- return piTy + let paramExpls = paramNames <&> \name -> Inferred name Unify + let expls = paramExpls <> [Inferred Nothing (Synth $ Partial $ succ i)] <> methodExpls + return \dictB -> + Pi $ CorePiType appExpl expls (paramBs' >>> UnaryNest dictB >>> methodBs) effTy + +withFreshBinderAndDecls + :: (IRRep r, EnvExtender m, ScopableBuilder r m) + => NameHint -> Type r n + -> (forall l. (DExt n l, Emits l) => AtomVar r l -> m l (BinderAndDecls r n l -> a)) + -> m n a +withFreshBinderAndDecls _ _ _ = undefined + +-- dictTy <- DictTy <$> dictType (sink className) params + -- withFreshBinder noHint dictTy \dictB -> do + -- scDicts <- getSuperclassDicts (Var $ binderVar dictB) + -- CorePiType appExpl methodExpls methodBs effTy <- instantiate absPiTy scDicts + -- let paramExpls = paramNames <&> \name -> Inferred name Unify + -- let expls = paramExpls <> [Inferred Nothing (Synth $ Partial $ succ i)] <> methodExpls + -- return $ Pi $ CorePiType appExpl expls (paramBs' >>> UnaryNest (PlainBD dictB) >>> methodBs) effTy + + +getMethodType :: BuilderEmits CoreIR m n => Dict n -> Int -> m n (CorePiType n) +getMethodType dict i = liftEmitBuilder $ withSubstReaderT do + ~(DictTy (DictType _ className params)) <- return $ getType dict + superclassDicts <- getSuperclassDicts dict + classDef <- lookupClassDef className + withInstantiated classDef params \ab -> do + withInstantiated ab superclassDicts \(ListE methodTys) -> + substM $ methodTys !! i + +getTyConNameType :: EnvReader m => TyConName n -> m n (Type CoreIR n) +getTyConNameType v = do + TyConDef _ expls bs _ <- lookupTyCon v + case bs of + Empty -> return TyKind + _ -> return $ Pi $ CorePiType ExplicitApp (snd <$> expls) bs $ EffTy Pure TyKind + +getDataConNameType :: EnvReader m => DataConName n -> m n (Type CoreIR n) +getDataConNameType dataCon = liftEnvReaderM $ withSubstReaderT do + (tyCon, i) <- lookupDataCon dataCon + tyConDef <- lookupTyCon tyCon + buildDataConType tyConDef \expls paramBs' paramVs params -> do + withInstantiatedNames tyConDef paramVs \(ADTCons dataCons) -> do + DataConDef _ ab _ _ <- renameM (dataCons !! i) + refreshAbs ab \dataBs UnitE -> do + let appExpl = case dataBs of Empty -> ImplicitApp + _ -> ExplicitApp + let resultTy = NewtypeTyCon $ UserADTType (getSourceName tyConDef) (sink tyCon) (sink params) + let dataExpls = nestToList (const $ Explicit) dataBs + return $ Pi $ CorePiType appExpl (expls <> dataExpls) (paramBs' >>> dataBs) (EffTy Pure resultTy) + +getStructDataConType :: EnvReader m => TyConName n -> m n (CType n) +getStructDataConType tyCon = liftEnvReaderM $ withSubstReaderT do + tyConDef <- lookupTyCon tyCon + buildDataConType tyConDef \expls paramBs' paramVs params -> do + withInstantiatedNames tyConDef paramVs \(StructFields fields) -> do + fieldTys <- forM fields \(_, t) -> renameM t + let resultTy = NewtypeTyCon $ UserADTType (getSourceName tyConDef) (sink tyCon) params + Abs dataBs resultTy' <- return $ typesAsBinderNest fieldTys resultTy + let dataExpls = nestToList (const Explicit) dataBs + return $ Pi $ CorePiType ExplicitApp (expls <> dataExpls) (paramBs' >>> dataBs) (EffTy Pure resultTy') + +buildDataConType + :: (EnvReader m, EnvExtender m) + => TyConDef n + -> (forall l. DExt n l => [Explicitness] -> CBinders n l -> [CAtomName l] -> TyConParams l -> m l a) + -> m n a +buildDataConType (TyConDef _ roleExpls bs _) cont = do + let expls = snd <$> roleExpls + expls' <- forM expls \case + Explicit -> return $ Inferred Nothing Unify + expl -> return $ expl + refreshAbs (Abs bs UnitE) \bs' UnitE -> do + let vs = bindersVars bs' + cont expls' bs' (atomVarName <$> vs) $ TyConParams expls (Var <$> vs) + +makeTyConParams :: EnvReader m => TyConName n -> [CAtom n] -> m n (TyConParams n) +makeTyConParams tc params = do + TyConDef _ expls _ _ <- lookupTyCon tc + return $ TyConParams (map snd expls) params + +getDataClassName :: (Fallible1 m, EnvReader m) => m n (ClassName n) +getDataClassName = lookupSourceMap "Data" >>= \case + Nothing -> throw CompilerErr $ "Data interface needed but not defined!" + Just (UClassVar v) -> return v + Just _ -> error "not a class var" + +dataDictType :: (Fallible1 m, EnvReader m) => CType n -> m n (DictType n) +dataDictType ty = do + dataClassName <- getDataClassName + dictType dataClassName [Type ty] + +getIxClassName :: (Fallible1 m, EnvReader m) => m n (ClassName n) +getIxClassName = lookupSourceMap "Ix" >>= \case + Nothing -> throw CompilerErr $ "Ix interface needed but not defined!" + Just (UClassVar v) -> return v + Just _ -> error "not a class var" + +dictType :: EnvReader m => ClassName n -> [CAtom n] -> m n (DictType n) +dictType className params = do + ClassDef sourceName _ _ _ _ _ _ <- lookupClassDef className + return $ DictType sourceName className params + +ixDictType :: (Fallible1 m, EnvReader m) => CType n -> m n (DictType n) +ixDictType ty = do + ixClassName <- getIxClassName + dictType ixClassName [Type ty] + +makePreludeMaybeTy :: EnvReader m => CType n -> m n (CType n) +makePreludeMaybeTy ty = do + ~(Just (UTyConVar tyConName)) <- lookupSourceMap "Maybe" + return $ TypeCon "Maybe" tyConName $ TyConParams [Explicit] [Type ty] + +-- === computing effects === + +functionEffs :: (IRRep r, EnvReader m) => LamExpr r n -> m n (EffectRow r n) +functionEffs f = getLamExprType f >>= \case + PiType b (EffTy effs _) -> return $ ignoreHoistFailure $ hoist b effs + +rwsFunEffects :: (IRRep r, EnvReader m) => RWS -> LamExpr r n -> m n (EffectRow r n) +rwsFunEffects rws f = liftEnvReaderM $ getLamExprType f >>= \case + PiType (BinaryNest h ref) et -> do + let effs' = ignoreHoistFailure $ hoist ref (etEff et) + refreshAbs (Abs h effs') \h' effs'' -> do + let hVal = Var $ binderVar h' + let effs''' = deleteEff (RWSEffect rws hVal) effs'' + return $ ignoreHoistFailure $ hoist h' effs''' + _ -> error "Expected a binary function type" + +getLamExprType :: (IRRep r, EnvReader m) => LamExpr r n -> m n (PiType r n) +getLamExprType (LamExpr bs body) = liftEnvReaderM $ + refreshAbs (Abs bs body) \bs' body' -> do + effTy <- blockEffTy body' + return $ PiType bs' effTy + +getTypeRWSAction :: (IRRep r, EnvReader m) => LamExpr r n -> m n (Type r n, Type r n) +getTypeRWSAction f = getLamExprType f >>= \case + PiType (BinaryNest regionBinder refBinder) (EffTy _ resultTy) -> do + case binderType refBinder of + RefTy _ referentTy -> do + let referentTy' = ignoreHoistFailure $ hoist regionBinder referentTy + let resultTy' = ignoreHoistFailure $ hoist (PairB regionBinder refBinder) resultTy + return (resultTy', referentTy') + _ -> error "expected a ref" + _ -> error "expected a pi type" + +getSuperclassDicts :: BuilderEmits CoreIR m n => CAtom n -> m n ([CAtom n]) +getSuperclassDicts dict = do + case getType dict of + DictTy dTy -> do + ts <- getSuperclassTys dTy + forM (enumerate ts) \(i, t) -> return $ DictCon t $ SuperclassProj dict i + _ -> error "expected a dict type" + +getSuperclassTys :: BuilderEmits CoreIR m n => DictType n -> m n [CType n] +getSuperclassTys (DictType _ className params) = do + ClassDef _ _ _ _ bs superclasses _ <- lookupClassDef className + forM [0 .. nestLength superclasses - 1] \i -> do + instantiate (Abs bs $ getSuperclassType REmpty superclasses i) params + +getSuperclassType :: RNest (BinderAndDecls CoreIR) n l -> CBinders l l' -> Int -> CType n +getSuperclassType _ Empty = error "bad index" +getSuperclassType bsAbove (Nest b bs) = \case + 0 -> ignoreHoistFailure $ hoist bsAbove (binderType b) + i -> getSuperclassType (RNest bsAbove b) bs (i-1) + + +getTypeTopFun :: EnvReader m => TopFunName n -> m n (PiType SimpIR n) +getTypeTopFun f = lookupTopFun f >>= \case + DexTopFun _ (TopLam _ piTy _) _ -> return piTy + FFITopFun _ iTy -> liftIFunType iTy + +asTopLam :: (EnvReader m, IRRep r) => LamExpr r n -> m n (TopLam r n) +asTopLam lam = do + piTy <- getLamExprType lam + return $ TopLam False piTy lam + +liftIFunType :: (IRRep r, EnvReader m) => IFunType -> m n (PiType r n) +liftIFunType (IFunType _ argTys resultTys) = liftEnvReaderM $ go argTys where + go :: IRRep r => [BaseType] -> EnvReaderM n (PiType r n) + go = \case + [] -> return $ PiType Empty (EffTy (OneEffect IOEffect) resultTy) + where resultTy = case resultTys of + [] -> UnitTy + [t] -> BaseTy t + [t1, t2] -> PairTy (BaseTy t1) (BaseTy t2) + _ -> error $ "Not a valid FFI return type: " ++ pprint resultTys + t:ts -> withFreshBinder noHint (BaseTy t) \b -> do + PiType bs effTy <- go ts + return $ PiType (Nest (PlainBD b) bs) effTy + +unwrapNewtype :: EnvReader m => CAtom n -> m n (CAtom n) +unwrapNewtype (NewtypeCon _ x) = return x +-- unwrapNewtype x = case getType x of +-- NewtypeTyCon con -> do +-- (_, ty) <- unwrapNewtypeType con +-- return $ ProjectElt ty UnwrapNewtype x +-- _ -> error "not a newtype" +-- {-# INLINE unwrapNewtype #-} + +-- === Data constraints === + +isData :: EnvReader m => Type CoreIR n -> m n Bool +isData ty = do + result <- liftEnvReaderT $ withSubstReaderT $ checkDataLike ty + case runFallibleM result of + Success () -> return True + Failure _ -> return False + +checkDataLike :: Type CoreIR i -> SubstReaderT Name FallibleEnvReaderM i o () +checkDataLike ty = undefined +-- checkDataLike ty = case ty of +-- TyVar _ -> notData +-- TabPi (TabPiType _ b eltTy) -> do +-- renameBinders b \_ -> +-- checkDataLike eltTy +-- DepPairTy (DepPairType _ b r) -> do +-- recur $ binderType b +-- renameBinders b \_ -> checkDataLike r +-- NewtypeTyCon nt -> do +-- (_, ty') <- unwrapNewtypeType =<< renameM nt +-- dropSubst $ recur ty' +-- TC con -> case con of +-- BaseType _ -> return () +-- ProdType as -> mapM_ recur as +-- SumType cs -> mapM_ recur cs +-- RefType _ _ -> return () +-- HeapType -> return () +-- _ -> notData +-- _ -> notData +-- where +-- recur = checkDataLike +-- notData = throw TypeErr $ pprint ty + +checkExtends :: (Fallible m, IRRep r) => EffectRow r n -> EffectRow r n -> m () +checkExtends allowed (EffectRow effs effTail) = do + let (EffectRow allowedEffs allowedEffTail) = allowed + case effTail of + EffectRowTail _ -> assertEq allowedEffTail effTail "" + NoTail -> return () + forM_ (eSetToList effs) \eff -> unless (eff `eSetMember` allowedEffs) $ + throw CompilerErr $ "Unexpected effect: " ++ pprint eff ++ + "\nAllowed: " ++ pprint allowed diff --git a/src/lib/CheapReduction.hs b/src/lib/CheapReduction.hs deleted file mode 100644 index ff55abdf7..000000000 --- a/src/lib/CheapReduction.hs +++ /dev/null @@ -1,958 +0,0 @@ --- Copyright 2021 Google LLC --- --- Use of this source code is governed by a BSD-style --- license that can be found in the LICENSE file or at --- https://developers.google.com/open-source/licenses/bsd - -{-# LANGUAGE UndecidableInstances #-} -{-# OPTIONS_GHC -Wno-orphans #-} - -module CheapReduction - ( CheaplyReducibleE (..), cheapReduce, cheapReduceWithDecls, cheapNormalize - , normalizeProj, asNaryProj, normalizeNaryProj - , depPairLeftTy, instantiateTyConDef - , dataDefRep, unwrapNewtypeType, repValAtom - , unwrapLeadingNewtypesType, wrapNewtypesData, liftSimpAtom, liftSimpType - , liftSimpFun, makeStructRepVal, NonAtomRenamer (..), Visitor (..), VisitGeneric (..) - , visitAtomPartial, visitTypePartial, visitAtomDefault, visitTypeDefault, Visitor2 - , visitBinders, visitPiDefault, visitAlt, toAtomVar, instantiate, withInstantiated - , instantiateNames, withInstantiatedNames, assumeConst, tryAsConst - , extendSubstBD, arity) - where - -import Control.Applicative -import Control.Monad.Trans -import Control.Monad.Writer.Strict hiding (Alt) -import Control.Monad.State.Strict -import Control.Monad.Reader -import Data.Foldable (toList) -import Data.Functor.Identity -import Data.Functor ((<&>)) -import qualified Data.List.NonEmpty as NE -import qualified Data.Map.Strict as M - -import Subst -import Core -import Err -import IRVariants -import MTL1 -import Name -import PPrint () -import QueryTypePure -import Types.Core -import Types.Imp -import Types.Primitives -import Util -import {-# SOURCE #-} Inference (trySynthTerm) - --- Carry out the reductions we are willing to carry out during type --- inference. The goal is to support type aliases like `Int = Int32` --- and type-level functions like `def Fin (n:Int) : Type = Range 0 n`. --- The reductions in question are mostly inlining and beta-reducing --- functions. There's also a bunch of stuff to do with synthesizing --- class dictionaries, because types often contain polymorphic --- literals (e.g., `Fin 10`). - --- === api === - -type NiceE r e = (HoistableE e, SinkableE e, SubstE AtomSubstVal e, RenameE e, IRRep r) - -cheapReduce :: forall r e' e m n - . (EnvReader m, CheaplyReducibleE r e e', NiceE r e, NiceE r e') - => e n -> m n (Maybe (e' n)) -cheapReduce e = liftCheapReducerM idSubst $ cheapReduceE e -{-# INLINE cheapReduce #-} -{-# SCC cheapReduce #-} - -cheapReduceWithDecls - :: forall r e' e m n l - . ( CheaplyReducibleE r e e', NiceE r e', NiceE r e, EnvReader m ) - => Nest (Decl r) n l -> e l -> m n (Maybe (e' n)) -cheapReduceWithDecls decls result = do - Abs decls' result' <- sinkM $ Abs decls result - liftCheapReducerM idSubst $ - cheapReduceWithDeclsB decls' $ - cheapReduceE result' -{-# INLINE cheapReduceWithDecls #-} -{-# SCC cheapReduceWithDecls #-} - -cheapNormalize :: (EnvReader m, CheaplyReducibleE r e e, NiceE r e) => e n -> m n (e n) -cheapNormalize a = cheapReduce a >>= \case - Just ans -> return ans - _ -> error "couldn't normalize expression" -{-# INLINE cheapNormalize #-} - --- === internal === - -newtype CheapReducerM (r::IR) (i :: S) (o :: S) (a :: *) = - CheapReducerM - (SubstReaderT AtomSubstVal - (MaybeT1 - (ScopedT1 (MapE (AtomName r) (MaybeE (Atom r))) - (EnvReaderT Identity))) i o a) - deriving (Functor, Applicative, Monad, Alternative) - -deriving instance IRRep r => ScopeReader (CheapReducerM r i) -deriving instance IRRep r => EnvReader (CheapReducerM r i) -deriving instance IRRep r => EnvExtender (CheapReducerM r i) -deriving instance IRRep r => SubstReader AtomSubstVal (CheapReducerM r) - -class ( Alternative2 m, SubstReader AtomSubstVal m - , EnvReader2 m, EnvExtender2 m) => CheapReducer m r | m -> r where - updateCache :: AtomName r o -> Maybe (Atom r o) -> m i o () - lookupCache :: AtomName r o -> m i o (Maybe (Maybe (Atom r o))) - -instance IRRep r => CheapReducer (CheapReducerM r) r where - updateCache v u = CheapReducerM $ SubstReaderT $ lift $ lift11 $ - modify (MapE . M.insert v (toMaybeE u) . fromMapE) - lookupCache v = CheapReducerM $ SubstReaderT $ lift $ lift11 $ - fmap fromMaybeE <$> gets (M.lookup v . fromMapE) - -liftCheapReducerM - :: (EnvReader m, IRRep r) - => Subst AtomSubstVal i o -> CheapReducerM r i o a - -> m o (Maybe a) -liftCheapReducerM subst (CheapReducerM m) = do - liftM runIdentity $ liftEnvReaderT $ runScopedT1 - (runMaybeT1 $ runSubstReaderT subst m) mempty -{-# INLINE liftCheapReducerM #-} - -cheapReduceWithDeclsB - :: NiceE r e - => Nest (Decl r) i i' - -> (forall o'. Ext o o' => CheapReducerM r i' o' (e o')) - -> CheapReducerM r i o (e o) -cheapReduceWithDeclsB decls cont = do - Abs irreducibleDecls result <- cheapReduceWithDeclsRec decls cont - case hoist irreducibleDecls result of - HoistSuccess result' -> return result' - HoistFailure _ -> empty - -cheapReduceWithDeclsRec - :: NiceE r e - => Nest (Decl r) i i' - -> (forall o'. Ext o o' => CheapReducerM r i' o' (e o')) - -> CheapReducerM r i o (Abs (Nest (Decl r)) e o) -cheapReduceWithDeclsRec decls cont = case decls of - Empty -> Abs Empty <$> cont - Nest (Let b binding@(DeclBinding _ expr)) rest -> do - optional (cheapReduceE expr) >>= \case - Nothing -> do - binding' <- substM binding - withFreshBinder (getNameHint b) binding' \(b':>_) -> do - updateCache (binderName b') Nothing - extendSubst (b@>Rename (binderName b')) do - Abs decls' result <- cheapReduceWithDeclsRec rest cont - return $ Abs (Nest (Let b' binding') decls') result - Just x -> - extendSubst (b@>SubstVal x) $ - cheapReduceWithDeclsRec rest cont - -cheapReduceName :: forall c r i o . (IRRep r, Color c) => Name c o -> CheapReducerM r i o (AtomSubstVal c o) -cheapReduceName v = - case eqColorRep @c @(AtomNameC r) of - Just ColorsEqual -> - lookupEnv v >>= \case - AtomNameBinding binding -> cheapReduceAtomBinding v binding - Nothing -> stuck - where stuck = return $ Rename v - -cheapReduceAtomBinding - :: forall r i o. IRRep r - => AtomName r o -> AtomBinding r o -> CheapReducerM r i o (AtomSubstVal (AtomNameC r) o) -cheapReduceAtomBinding v = \case - LetBound (DeclBinding _ e) -> do - cachedVal <- lookupCache v >>= \case - Nothing -> do - result <- optional (dropSubst $ cheapReduceE e) - updateCache v result - return result - Just result -> return result - case cachedVal of - Nothing -> stuck - Just ans -> return $ SubstVal ans - _ -> stuck - where stuck = return $ Rename v - -class CheaplyReducibleE (r::IR) (e::E) (e'::E) | e -> e', e -> r where - cheapReduceE :: e i -> CheapReducerM r i o (e' o) - -instance IRRep r => CheaplyReducibleE r (Atom r) (Atom r) where - cheapReduceE :: forall i o. Atom r i -> CheapReducerM r i o (Atom r o) - cheapReduceE a = confuseGHC >>= \_ -> case a of - -- Don't try to eagerly reduce lambda bodies. We might get stuck long before - -- we have a chance to apply tham. Also, recursive traversal of those bodies - -- means that we will follow the full call chain, so it's really expensive! - -- TODO: we don't collect the dict holes here, so there's a danger of - -- dropping them if they turn out to be phantom. - Lam _ -> substM a - DictHole ctx ty' access -> do - ty <- cheapReduceE ty' - runFallibleT1 (trySynthTerm ty access) >>= \case - Success d -> return d - Failure _ -> return $ DictHole ctx ty access - -- We traverse the Atom constructors that might contain lambda expressions - -- explicitly, to make sure that we can skip normalizing free vars inside those. - Con con -> Con <$> traverseOp con cheapReduceE cheapReduceE (error "unexpected lambda") - DictCon t d -> do - t' <- cheapReduceE t - cheapReduceDictExpr t' d - SimpInCore (LiftSimp t x) -> do - t' <- cheapReduceE t - x' <- substM x - liftSimpAtom t' x' - -- These two are a special-case hack. TODO(dougalm): write a traversal over - -- the NewtypeTyCon (or types generally) - NewtypeCon NatCon n -> NewtypeCon NatCon <$> cheapReduceE n - -- Do recursive reduction via substitution - -- TODO: we don't collect the dict holes here, so there's a danger of - -- dropping them if they turn out to be phantom. - _ -> do - a' <- substM a - dropSubst $ traverseNames cheapReduceName a' - -instance IRRep r => CheaplyReducibleE r (Type r) (Type r) where - cheapReduceE :: forall i o. Type r i -> CheapReducerM r i o (Type r o) - cheapReduceE a = case a of - -- Don't try to eagerly reduce lambda bodies. We might get stuck long before - -- we have a chance to apply tham. Also, recursive traversal of those bodies - -- means that we will follow the full call chain, so it's really expensive! - -- TODO: we don't collect the dict holes here, so there's a danger of - -- dropping them if they turn out to be phantom. - TabPi (TabPiType d b resultTy) -> do - d' <- cheapReduceE d - cheapReduceBinder b \b' -> TabPi <$> TabPiType d' b' <$> cheapReduceE resultTy - -- We traverse the Atom constructors that might contain lambda expressions - -- explicitly, to make sure that we can skip normalizing free vars inside those. - NewtypeTyCon (Fin n) -> NewtypeTyCon . Fin <$> cheapReduceE n - -- Do recursive reduction via substitution - -- TODO: we don't collect the dict holes here, so there's a danger of - -- dropping them if they turn out to be phantom. - _ -> do - a' <- substM a - dropSubst $ traverseNames cheapReduceName a' - -cheapReduceBinder - :: IRRep r - => BinderAndDecls r i i' - -> (forall o'. DExt o o' => BinderAndDecls r o o' -> CheapReducerM r i' o' a) - -> CheapReducerM r i o a -cheapReduceBinder (BD (b:>ty)) cont = do - ty' <- cheapReduceE ty - withFreshBinder (getNameHint b) ty' \b' -> do - extendSubst (b@>Rename (binderName b')) $ cont (BD b') - -cheapReduceDictExpr :: CType o -> DictExpr i -> CheapReducerM CoreIR i o (CAtom o) -cheapReduceDictExpr resultTy d = case d of - SuperclassProj child superclassIx -> do - cheapReduceE child >>= \case - DictCon _ (InstanceDict instanceName args) -> dropSubst do - args' <- mapM cheapReduceE args - InstanceDef _ _ bs _ body <- lookupInstanceDef instanceName - let InstanceBody superclasses _ = body - instantiate (Abs bs (superclasses !! superclassIx)) args' - child' -> return $ DictCon resultTy $ SuperclassProj child' superclassIx - InstantiatedGiven f xs -> - reduceApp <|> justSubst - where reduceApp = do - f' <- cheapReduceE f - xs' <- mapM cheapReduceE (toList xs) - cheapReduceApp f' xs' - InstanceDict _ _ -> justSubst - IxFin _ -> justSubst - DataData ty -> DictCon resultTy . DataData <$> cheapReduceE ty - where justSubst = DictCon resultTy <$> substM d - -instance CheaplyReducibleE CoreIR TyConParams TyConParams where - cheapReduceE (TyConParams infs ps) = - TyConParams infs <$> mapM cheapReduceE ps - -instance (CheaplyReducibleE r e e', NiceE r e') => CheaplyReducibleE r (Abs (Nest (Decl r)) e) e' where - cheapReduceE (Abs decls result) = cheapReduceWithDeclsB decls $ cheapReduceE result - -instance IRRep r => CheaplyReducibleE r (Expr r) (Atom r) where - cheapReduceE expr = confuseGHC >>= \_ -> case expr of - Atom atom -> cheapReduceE atom - App _ f' xs' -> do - xs <- mapM cheapReduceE xs' - f <- cheapReduceE f' - cheapReduceApp f xs - -- TODO: Make sure that this wraps correctly - -- TODO: Other casts? - PrimOp (MiscOp (CastOp ty' val')) -> do - ty <- cheapReduceE ty' - case ty of - BaseTy (Scalar Word32Type) -> do - val <- cheapReduceE val' - case val of - Con (Lit (Word64Lit v)) -> return $ Con $ Lit $ Word32Lit $ fromIntegral v - _ -> empty - _ -> empty - ApplyMethod _ dict i explicitArgs -> do - explicitArgs' <- mapM cheapReduceE explicitArgs - cheapReduceE dict >>= \case - DictCon _ (InstanceDict instanceName args) -> dropSubst do - args' <- mapM cheapReduceE args - def <- lookupInstanceDef instanceName - withInstantiated def args' \(PairE _ (InstanceBody _ methods)) -> do - method' <- cheapReduceE $ methods !! i - cheapReduceApp method' explicitArgs' - _ -> empty - _ -> empty - -cheapReduceApp :: CAtom o -> [CAtom o] -> CheapReducerM CoreIR i o (CAtom o) -cheapReduceApp f xs = case f of - Lam lam -> dropSubst $ withInstantiated lam xs \body -> cheapReduceE body - _ -> empty - -instance IRRep r => CheaplyReducibleE r (IxType r) (IxType r) where - cheapReduceE (IxType t d) = IxType <$> cheapReduceE t <*> cheapReduceE d - -instance IRRep r => CheaplyReducibleE r (IxDict r) (IxDict r) where - cheapReduceE = \case - IxDictAtom x -> IxDictAtom <$> cheapReduceE x - IxDictRawFin n -> IxDictRawFin <$> cheapReduceE n - IxDictSpecialized t d xs -> - IxDictSpecialized <$> cheapReduceE t <*> substM d <*> mapM cheapReduceE xs - -instance (CheaplyReducibleE r e1 e1', CheaplyReducibleE r e2 e2') - => CheaplyReducibleE r (PairE e1 e2) (PairE e1' e2') where - cheapReduceE (PairE e1 e2) = PairE <$> cheapReduceE e1 <*> cheapReduceE e2 - -instance (CheaplyReducibleE r e1 e1', CheaplyReducibleE r e2 e2') - => CheaplyReducibleE r (EitherE e1 e2) (EitherE e1' e2') where - cheapReduceE (LeftE e) = LeftE <$> cheapReduceE e - cheapReduceE (RightE e) = RightE <$> cheapReduceE e - --- XXX: TODO: figure out exactly what our normalization invariants are. We --- shouldn't have to choose `normalizeProj` or `asNaryProj` on a --- case-by-case basis. This is here for now because it makes it easier to switch --- to the new version of `ProjectElt`. -asNaryProj :: IRRep r => Projection -> Atom r n -> (NE.NonEmpty Projection, AtomVar r n) -asNaryProj p (Var v) = (p NE.:| [], v) -asNaryProj p1 (ProjectElt _ p2 x) = do - let (p2' NE.:| ps, v) = asNaryProj p2 x - (p1 NE.:| (p2':ps), v) -asNaryProj p x = error $ "Can't normalize projection: " ++ pprint p ++ " " ++ pprint x - --- assumes the atom is already normalized -normalizeNaryProj :: IRRep r => EnvReader m => [Projection] -> Atom r n -> m n (Atom r n) -normalizeNaryProj [] x = return x -normalizeNaryProj (i:is) x = normalizeProj i =<< normalizeNaryProj is x - --- assumes the atom itself is already normalized -normalizeProj :: IRRep r => EnvReader m => Projection -> Atom r n -> m n (Atom r n) -normalizeProj UnwrapNewtype atom = case atom of - NewtypeCon _ x -> return x - SimpInCore (LiftSimp (NewtypeTyCon t) x) -> do - t' <- snd <$> unwrapNewtypeType t - return $ SimpInCore $ LiftSimp t' x - x -> case getType x of - NewtypeTyCon t -> do - t' <- snd <$> unwrapNewtypeType t - return $ ProjectElt t' UnwrapNewtype x - _ -> error "expected a newtype" -normalizeProj (ProjectProduct i) atom = case atom of - Con (ProdCon xs) -> return $ xs !! i - DepPair l _ _ | i == 0 -> return l - DepPair _ r _ | i == 1 -> return r - SimpInCore (LiftSimp _ x) -> do - x' <- normalizeProj (ProjectProduct i) x - resultTy <- getResultTy - return $ SimpInCore $ LiftSimp resultTy x' - RepValAtom (RepVal _ tree) -> case tree of - Branch trees -> do - resultTy <- getResultTy - repValAtom $ RepVal resultTy (trees!!i) - Leaf _ -> error "unexpected leaf" - _ -> do - resultTy <- getResultTy - return $ ProjectElt resultTy (ProjectProduct i) atom - where - getResultTy = projType i (getType atom) atom -{-# INLINE normalizeProj #-} - --- === lifting imp to simp and simp to core === - -repValAtom :: EnvReader m => SRepVal n -> m n (SAtom n) -repValAtom (RepVal ty tree) = case ty of - ProdTy ts -> case tree of - Branch trees -> ProdVal <$> mapM repValAtom (zipWith RepVal ts trees) - _ -> malformed - BaseTy _ -> case tree of - Leaf x -> case x of - ILit l -> return $ Con $ Lit l - _ -> fallback - _ -> malformed - _ -> fallback - where fallback = return $ RepValAtom $ RepVal ty tree - malformed = error "malformed repval" -{-# INLINE repValAtom #-} - -liftSimpType :: EnvReader m => SType n -> m n (CType n) -liftSimpType = \case - BaseTy t -> return $ BaseTy t - ProdTy ts -> ProdTy <$> mapM rec ts - SumTy ts -> SumTy <$> mapM rec ts - t -> error $ "not implemented: " ++ pprint t - where rec = liftSimpType -{-# INLINE liftSimpType #-} - -liftSimpAtom :: EnvReader m => Type CoreIR n -> SAtom n -> m n (CAtom n) -liftSimpAtom ty simpAtom = case simpAtom of - Var _ -> justLift - ProjectElt _ _ _ -> justLift - RepValAtom _ -> justLift -- TODO(dougalm): should we make more effort to pull out products etc? - _ -> do - (cons , ty') <- unwrapLeadingNewtypesType ty - atom <- case (ty', simpAtom) of - (BaseTy _ , Con (Lit v)) -> return $ Con $ Lit v - (ProdTy tys, Con (ProdCon xs)) -> Con . ProdCon <$> zipWithM rec tys xs - (SumTy tys, Con (SumCon _ i x)) -> Con . SumCon tys i <$> rec (tys!!i) x - (DepPairTy dpt, DepPair x1 x2 _) -> do - x1' <- rec (depPairLeftTy dpt) x1 - t2' <- instantiate dpt [x1'] - x2' <- rec t2' x2 - return $ DepPair x1' x2' dpt - _ -> error $ "can't lift " <> pprint simpAtom <> " to " <> pprint ty' - return $ wrapNewtypesData cons atom - where - rec = liftSimpAtom - justLift = return $ SimpInCore $ LiftSimp ty simpAtom -{-# INLINE liftSimpAtom #-} - -liftSimpFun :: EnvReader m => Type CoreIR n -> LamExpr SimpIR n -> m n (CAtom n) -liftSimpFun (Pi piTy) f = return $ SimpInCore $ LiftSimpFun piTy f -liftSimpFun _ _ = error "not a pi type" - --- See Note [Confuse GHC] from Simplify.hs -confuseGHC :: IRRep r => CheapReducerM r i n (DistinctEvidence n) -confuseGHC = getDistinct -{-# INLINE confuseGHC #-} - --- TODO: These used to live in QueryType. Think about a better way to organize --- them. Maybe a common set of low-level type-querying utils that both --- CheapReduction and QueryType import? - -depPairLeftTy :: IRRep r => DepPairType r n -> Type r n -depPairLeftTy (DepPairType _ b _) = binderType b -{-# INLINE depPairLeftTy #-} - -unwrapNewtypeType :: EnvReader m => NewtypeTyCon n -> m n (NewtypeCon n, Type CoreIR n) -unwrapNewtypeType = \case - Nat -> return (NatCon, IdxRepTy) - Fin n -> return (FinCon n, NatTy) - UserADTType sn defName params -> do - def <- lookupTyCon defName - ty' <- dataDefRep <$> instantiateTyConDef def params - return (UserADTData sn defName params, ty') - ty -> error $ "Shouldn't be projecting: " ++ pprint ty -{-# INLINE unwrapNewtypeType #-} - -projType :: (IRRep r, EnvReader m) => Int -> Type r n -> Atom r n -> m n (Type r n) -projType i ty x = case ty of - ProdTy xs -> return $ xs !! i - DepPairTy t | i == 0 -> return $ depPairLeftTy t - DepPairTy t | i == 1 -> do - xFst <- normalizeProj (ProjectProduct 0) x - instantiate t [xFst] - _ -> error $ "Can't project type: " ++ pprint ty - -unwrapLeadingNewtypesType :: EnvReader m => CType n -> m n ([NewtypeCon n], CType n) -unwrapLeadingNewtypesType = \case - NewtypeTyCon tyCon -> do - (dataCon, ty) <- unwrapNewtypeType tyCon - (dataCons, ty') <- unwrapLeadingNewtypesType ty - return (dataCon:dataCons, ty') - ty -> return ([], ty) - -wrapNewtypesData :: [NewtypeCon n] -> CAtom n-> CAtom n -wrapNewtypesData [] x = x -wrapNewtypesData (c:cs) x = NewtypeCon c $ wrapNewtypesData cs x - -instantiateTyConDef :: EnvReader m => TyConDef n -> TyConParams n -> m n (DataConDefs n) -instantiateTyConDef tyConDef (TyConParams _ xs) = instantiate tyConDef xs -{-# INLINE instantiateTyConDef #-} - -assumeConst - :: (IRRep r, HoistableE body, SinkableE body, ToBindersAbs e body r) => e n -> body n -assumeConst e = case toAbs e of Abs bs body -> ignoreHoistFailure $ hoist bs body - -arity :: (IRRep r, ToBindersAbs e body r) => e n -> Int -arity e = case toAbs e of Abs bs _ -> nestLength bs - -tryAsConst - :: (IRRep r, HoistableE body, SinkableE body, ToBindersAbs e body r) => e n -> Maybe (body n) -tryAsConst e = - case toAbs e of - Abs bs body -> case hoist bs body of - HoistFailure _ -> Nothing - HoistSuccess e' -> Just e' - -instantiate - :: (EnvReader m, IRRep r, SubstE (SubstVal Atom) body, SinkableE body, SinkableE e, - ToBindersAbs e body r, Ext h n) - => e h -> [Atom r n] -> m n (body n) -instantiate e xs = do - Abs bs body <- sinkM $ toAbs e - let bs' = fmapNest (\(BD b) -> b) bs - applySubst (bs' @@> (SubstVal <$> xs)) body -{-# INLINE instantiate #-} - --- "lazy" subst-extending version of `instantiate` -withInstantiated - :: (SubstReader AtomSubstVal m, IRRep r, SubstE (SubstVal Atom) body, SinkableE body, ToBindersAbs e body r) - => e i -> [Atom r o] - -> (forall i'. body i' -> m i' o a) - -> m i o a -withInstantiated e xs cont = do - Abs bs body <- return $ toAbs e - let bs' = fmapNest (\(BD b) -> b) bs - extendSubst (bs' @@> (SubstVal <$> xs)) $ cont body - -instantiateNames - :: (EnvReader m, IRRep r, RenameE body, SinkableE body, ToBindersAbs e body r, Ext h n) - => e h -> [AtomName r n] -> m n (body n) -instantiateNames e vs = do - Abs bs body <- sinkM $ toAbs e - let bs' = fmapNest (\(BD b) -> b) bs - applyRename (bs' @@> vs) body - --- "lazy" subst-extending version of `instantiateNames` -withInstantiatedNames - :: (SubstReader Name m, IRRep r, RenameE body, SinkableE body, ToBindersAbs e body r) - => e i -> [AtomName r o] - -> (forall i'. body i' -> m i' o a) - -> m i o a -withInstantiatedNames e vs cont = do - Abs bs body <- return $ toAbs e - let bs' = fmapNest (\(BD b) -> b) bs - extendRenamer (bs' @@> vs) $ cont body - -extendSubstBD - :: forall v m b r i i' o a - . (SubstReader v m, ToBinders b r, IRRep r) - => b i i' -> [v (AtomNameC r) o] -> m i' o a -> m i o a -extendSubstBD bsTop xsTop contTop = go (toBinders bsTop) xsTop contTop - where - go :: Binders r ii ii' -> [v (AtomNameC r) o] -> m ii' o a -> m ii o a - go Empty [] cont = cont - go (Nest (BD b) bs) (x:xs) cont = extendSubst (b@>x) $ go bs xs cont - go _ _ _ = error "zip error" -{-# INLINE extendSubstBD #-} - --- Returns a representation type (type of an TypeCon-typed Newtype payload) --- given a list of instantiated DataConDefs. -dataDefRep :: DataConDefs n -> CType n -dataDefRep (ADTCons cons) = case cons of - [] -> error "unreachable" -- There's no representation for a void type - [DataConDef _ _ ty _] -> ty - tys -> SumTy $ tys <&> \(DataConDef _ _ ty _) -> ty -dataDefRep (StructFields fields) = case map snd fields of - [ty] -> ty - tys -> ProdTy tys - -makeStructRepVal :: (Fallible1 m, EnvReader m) => TyConName n -> [CAtom n] -> m n (CAtom n) -makeStructRepVal tyConName args = do - TyConDef _ _ _ (StructFields fields) <- lookupTyCon tyConName - case fields of - [_] -> case args of - [arg] -> return arg - _ -> error "wrong number of args" - _ -> return $ ProdVal args - --- === traversable terms === - -class Monad m => NonAtomRenamer m i o | m -> i, m -> o where - renameN :: (IsAtomName c ~ False, Color c) => Name c i -> m (Name c o) - -class NonAtomRenamer m i o => Visitor m r i o | m -> i, m -> o where - visitType :: Type r i -> m (Type r o) - visitAtom :: Atom r i -> m (Atom r o) - visitLam :: LamExpr r i -> m (LamExpr r o) - visitPi :: PiType r i -> m (PiType r o) - -class VisitGeneric (e:: E) (r::IR) | e -> r where - visitGeneric :: Visitor m r i o => e i -> m (e o) - -type Visitor2 (m::MonadKind2) r = forall i o . Visitor (m i o) r i o - -instance VisitGeneric (Atom r) r where visitGeneric = visitAtom -instance VisitGeneric (Type r) r where visitGeneric = visitType -instance VisitGeneric (LamExpr r) r where visitGeneric = visitLam -instance VisitGeneric (PiType r) r where visitGeneric = visitPi - -visitBlock :: Visitor m r i o => Block r i -> m (Block r o) -visitBlock b = visitGeneric (LamExpr Empty b) >>= \case - LamExpr Empty b' -> return b' - _ -> error "not a block" - -visitAlt :: Visitor m r i o => Alt r i -> m (Alt r o) -visitAlt (Abs b body) = do - visitGeneric (UnaryLamExpr b body) >>= \case - UnaryLamExpr b' body' -> return $ Abs b' body' - _ -> error "not an alt" - -traverseOpTerm - :: (GenericOp e, Visitor m r i o, OpConst e r ~ OpConst e r) - => e r i -> m (e r o) -traverseOpTerm e = traverseOp e visitGeneric visitGeneric visitGeneric - -visitAtomDefault - :: (IRRep r, Visitor (m i o) r i o, AtomSubstReader v m, EnvReader2 m) - => Atom r i -> m i o (Atom r o) -visitAtomDefault atom = case atom of - Var _ -> atomSubstM atom - SimpInCore _ -> atomSubstM atom - ProjectElt t i x -> ProjectElt <$> visitType t <*> pure i <*> visitGeneric x - _ -> visitAtomPartial atom - -visitTypeDefault - :: (IRRep r, Visitor (m i o) r i o, AtomSubstReader v m, EnvReader2 m) - => Type r i -> m i o (Type r o) -visitTypeDefault = \case - TyVar v -> atomSubstM $ TyVar v - ProjectEltTy t i x -> ProjectEltTy <$> visitType t <*> pure i <*> visitGeneric x - x -> visitTypePartial x - -visitPiDefault - :: (Visitor2 m r, IRRep r, FromName v, AtomSubstReader v m, EnvExtender2 m) - => PiType r i -> m i o (PiType r o) -visitPiDefault (PiType bs effty) = do - visitBinders bs \bs' -> do - effty' <- visitGeneric effty - return $ PiType bs' effty' - -visitBinders - :: (Visitor2 m r, IRRep r, FromName v, AtomSubstReader v m, EnvExtender2 m) - => Binders r i i' - -> (forall o'. DExt o o' => Binders r o o' -> m i' o' a) - -> m i o a -visitBinders Empty cont = getDistinct >>= \Distinct -> cont Empty -visitBinders (Nest (BD (b:>ty)) bs) cont = do - ty' <- visitType ty - withFreshBinder (getNameHint b) ty' \b' -> do - extendRenamer (b@>binderName b') do - visitBinders bs \bs' -> - cont $ Nest (BD b') bs' - --- XXX: This doesn't handle the `Var`, `ProjectElt`, `SimpInCore` cases. These --- should be handled explicitly beforehand. TODO: split out these cases under a --- separate constructor, perhaps even a `hole` paremeter to `Atom` or part of --- `IR`. -visitAtomPartial :: (IRRep r, Visitor m r i o) => Atom r i -> m (Atom r o) -visitAtomPartial = \case - Var _ -> error "Not handled generically" - SimpInCore _ -> error "Not handled generically" - ProjectElt _ _ _ -> error "Not handled generically" - Con con -> Con <$> visitGeneric con - PtrVar t v -> PtrVar t <$> renameN v - DepPair x y t -> do - x' <- visitGeneric x - y' <- visitGeneric y - ~(DepPairTy t') <- visitGeneric $ DepPairTy t - return $ DepPair x' y' t' - Lam lam -> Lam <$> visitGeneric lam - Eff eff -> Eff <$> visitGeneric eff - DictCon t d -> DictCon <$> visitType t <*> visitGeneric d - NewtypeCon con x -> NewtypeCon <$> visitGeneric con <*> visitGeneric x - DictHole ctx ty access -> DictHole ctx <$> visitGeneric ty <*> pure access - TypeAsAtom t -> TypeAsAtom <$> visitGeneric t - RepValAtom repVal -> RepValAtom <$> visitGeneric repVal - --- XXX: This doesn't handle the `TyVar` or `ProjectEltTy` cases. These should be --- handled explicitly beforehand. -visitTypePartial :: (IRRep r, Visitor m r i o) => Type r i -> m (Type r o) -visitTypePartial = \case - TyVar _ -> error "Not handled generically" - ProjectEltTy _ _ _ -> error "Not handled generically" - NewtypeTyCon t -> NewtypeTyCon <$> visitGeneric t - Pi t -> Pi <$> visitGeneric t - TabPi t -> TabPi <$> visitGeneric t - TC t -> TC <$> visitGeneric t - DepPairTy t -> DepPairTy <$> visitGeneric t - DictTy t -> DictTy <$> visitGeneric t - -instance IRRep r => VisitGeneric (Expr r) r where - visitGeneric = \case - TopApp et v xs -> TopApp <$> visitGeneric et <*> renameN v <*> mapM visitGeneric xs - TabApp t tab xs -> TabApp <$> visitType t <*> visitGeneric tab <*> mapM visitGeneric xs - -- TODO: should we reuse the original effects? Whether it's valid depends on - -- the type-preservation requirements for a visitor. We should clarify what - -- those are. - Case x alts effTy -> do - x' <- visitGeneric x - alts' <- mapM visitAlt alts - effTy' <- visitGeneric effTy - return $ Case x' alts' effTy' - Atom x -> Atom <$> visitGeneric x - TabCon Nothing t xs -> TabCon Nothing <$> visitGeneric t <*> mapM visitGeneric xs - TabCon (Just (WhenIRE d)) t xs -> TabCon <$> (Just . WhenIRE <$> visitGeneric d) <*> visitGeneric t <*> mapM visitGeneric xs - PrimOp op -> PrimOp <$> visitGeneric op - App et fAtom xs -> App <$> visitGeneric et <*> visitGeneric fAtom <*> mapM visitGeneric xs - ApplyMethod et m i xs -> ApplyMethod <$> visitGeneric et <*> visitGeneric m <*> pure i <*> mapM visitGeneric xs - -instance IRRep r => VisitGeneric (PrimOp r) r where - visitGeneric = \case - UnOp op x -> UnOp op <$> visitGeneric x - BinOp op x y -> BinOp op <$> visitGeneric x <*> visitGeneric y - MemOp op -> MemOp <$> visitGeneric op - VectorOp op -> VectorOp <$> visitGeneric op - MiscOp op -> MiscOp <$> visitGeneric op - Hof op -> Hof <$> visitGeneric op - DAMOp op -> DAMOp <$> visitGeneric op - RefOp r op -> RefOp <$> visitGeneric r <*> traverseOp op visitGeneric visitGeneric visitGeneric - -instance IRRep r => VisitGeneric (TypedHof r) r where - visitGeneric (TypedHof eff hof) = TypedHof <$> visitGeneric eff <*> visitGeneric hof - -instance IRRep r => VisitGeneric (Hof r) r where - visitGeneric = \case - For ann d lam -> For ann <$> visitGeneric d <*> visitGeneric lam - RunReader x body -> RunReader <$> visitGeneric x <*> visitGeneric body - RunWriter dest bm body -> RunWriter <$> mapM visitGeneric dest <*> visitGeneric bm <*> visitGeneric body - RunState dest s body -> RunState <$> mapM visitGeneric dest <*> visitGeneric s <*> visitGeneric body - While b -> While <$> visitBlock b - RunIO b -> RunIO <$> visitBlock b - RunInit b -> RunInit <$> visitBlock b - CatchException t b -> CatchException <$> visitType t <*> visitBlock b - Linearize lam x -> Linearize <$> visitGeneric lam <*> visitGeneric x - Transpose lam x -> Transpose <$> visitGeneric lam <*> visitGeneric x - -instance IRRep r => VisitGeneric (BaseMonoid r) r where - visitGeneric (BaseMonoid x lam) = BaseMonoid <$> visitGeneric x <*> visitGeneric lam - -instance IRRep r => VisitGeneric (DAMOp r) r where - visitGeneric = \case - Seq eff dir d x lam -> Seq <$> visitGeneric eff <*> pure dir <*> visitGeneric d <*> visitGeneric x <*> visitGeneric lam - RememberDest eff x lam -> RememberDest <$> visitGeneric eff <*> visitGeneric x <*> visitGeneric lam - AllocDest t -> AllocDest <$> visitGeneric t - Place x y -> Place <$> visitGeneric x <*> visitGeneric y - Freeze x -> Freeze <$> visitGeneric x - -instance IRRep r => VisitGeneric (Effect r) r where - visitGeneric = \case - RWSEffect rws h -> RWSEffect rws <$> visitGeneric h - ExceptionEffect -> pure ExceptionEffect - IOEffect -> pure IOEffect - InitEffect -> pure InitEffect - -instance IRRep r => VisitGeneric (EffectRow r) r where - visitGeneric (EffectRow effs tailVar) = do - effs' <- eSetFromList <$> mapM visitGeneric (eSetToList effs) - tailEffRow <- case tailVar of - NoTail -> return $ EffectRow mempty NoTail - EffectRowTail v -> visitGeneric (Var v) <&> \case - Var v' -> EffectRow mempty (EffectRowTail v') - Eff r -> r - _ -> error "Not a valid effect substitution" - return $ extendEffRow effs' tailEffRow - -instance VisitGeneric DictExpr CoreIR where - visitGeneric = \case - InstantiatedGiven x xs -> InstantiatedGiven <$> visitGeneric x <*> mapM visitGeneric xs - SuperclassProj x i -> SuperclassProj <$> visitGeneric x <*> pure i - InstanceDict v xs -> InstanceDict <$> renameN v <*> mapM visitGeneric xs - IxFin x -> IxFin <$> visitGeneric x - DataData t -> DataData <$> visitGeneric t - -instance VisitGeneric NewtypeCon CoreIR where - visitGeneric = \case - UserADTData sn t params -> UserADTData sn <$> renameN t <*> visitGeneric params - NatCon -> return NatCon - FinCon x -> FinCon <$> visitGeneric x - -instance VisitGeneric NewtypeTyCon CoreIR where - visitGeneric = \case - Nat -> return Nat - Fin x -> Fin <$> visitGeneric x - EffectRowKind -> return EffectRowKind - UserADTType n v params -> UserADTType n <$> renameN v <*> visitGeneric params - -instance VisitGeneric TyConParams CoreIR where - visitGeneric (TyConParams expls xs) = TyConParams expls <$> mapM visitGeneric xs - -instance IRRep r => VisitGeneric (IxDict r) r where - visitGeneric = \case - IxDictAtom x -> IxDictAtom <$> visitGeneric x - IxDictRawFin x -> IxDictRawFin <$> visitGeneric x - IxDictSpecialized t v xs -> IxDictSpecialized <$> visitGeneric t <*> renameN v <*> mapM visitGeneric xs - -instance IRRep r => VisitGeneric (IxType r) r where - visitGeneric (IxType t d) = IxType <$> visitType t <*> visitGeneric d - -instance VisitGeneric DictType CoreIR where - visitGeneric (DictType n v xs) = DictType n <$> renameN v <*> mapM visitGeneric xs - -instance VisitGeneric CoreLamExpr CoreIR where - visitGeneric (CoreLamExpr t lam) = CoreLamExpr <$> visitGeneric t <*> visitGeneric lam - -instance VisitGeneric CorePiType CoreIR where - visitGeneric (CorePiType app expl bs effty) = do - PiType bs' effty' <- visitGeneric $ PiType bs effty - return $ CorePiType app expl bs' effty' - -instance IRRep r => VisitGeneric (TabPiType r) r where - visitGeneric (TabPiType d b eltTy) = do - d' <- visitGeneric d - visitGeneric (PiType (UnaryNest b) (EffTy Pure eltTy)) <&> \case - PiType (UnaryNest b') (EffTy Pure eltTy') -> TabPiType d' b' eltTy' - _ -> error "not a table pi type" - -instance IRRep r => VisitGeneric (DepPairType r) r where - visitGeneric (DepPairType expl b ty) = do - visitGeneric (PiType (UnaryNest b) (EffTy Pure ty)) <&> \case - PiType (UnaryNest b') (EffTy Pure ty') -> DepPairType expl b' ty' - _ -> error "not a dependent pair type" - -instance VisitGeneric (RepVal SimpIR) SimpIR where - visitGeneric (RepVal ty tree) = RepVal <$> visitGeneric ty <*> mapM renameIExpr tree - where renameIExpr = \case - ILit l -> return $ ILit l - IVar v t -> IVar <$> renameN v <*> pure t - IPtrVar v t -> IPtrVar <$> renameN v <*> pure t - -instance IRRep r => VisitGeneric (DeclBinding r) r where - visitGeneric (DeclBinding ann expr) = DeclBinding ann <$> visitGeneric expr - -instance IRRep r => VisitGeneric (EffTy r) r where - visitGeneric (EffTy eff ty) = - EffTy <$> visitGeneric eff <*> visitGeneric ty - -instance VisitGeneric DataConDefs CoreIR where - visitGeneric = \case - ADTCons cons -> ADTCons <$> mapM visitGeneric cons - StructFields defs -> do - let (names, tys) = unzip defs - tys' <- mapM visitGeneric tys - return $ StructFields $ zip names tys' - -instance VisitGeneric DataConDef CoreIR where - visitGeneric (DataConDef sn (Abs bs UnitE) repTy ps) = do - PiType bs' _ <- visitGeneric $ PiType bs $ EffTy Pure UnitTy - repTy' <- visitGeneric repTy - return $ DataConDef sn (Abs bs' UnitE) repTy' ps - -instance VisitGeneric (Con r) r where visitGeneric = traverseOpTerm -instance VisitGeneric (TC r) r where visitGeneric = traverseOpTerm -instance VisitGeneric (MiscOp r) r where visitGeneric = traverseOpTerm -instance VisitGeneric (VectorOp r) r where visitGeneric = traverseOpTerm -instance VisitGeneric (MemOp r) r where visitGeneric = traverseOpTerm - --- === SubstE/SubstB instances === --- These live here, as orphan instances, because we normalize as we substitute. - -toAtomVar :: (EnvReader m, IRRep r) => AtomName r n -> m n (AtomVar r n) -toAtomVar v = do - ty <- getType <$> lookupAtomName v - return $ AtomVar v ty - -newtype SubstVisitor i o a = SubstVisitor { runSubstVisitor :: Reader (Env o, Subst AtomSubstVal i o) a } - deriving (Functor, Applicative, Monad, MonadReader (Env o, Subst AtomSubstVal i o)) - -substV :: (Distinct o, SubstE AtomSubstVal e) => e i -> SubstVisitor i o (e o) -substV x = ask <&> \env -> substE env x - -instance Distinct o => NonAtomRenamer (SubstVisitor i o) i o where - renameN = substV - -instance (Distinct o, IRRep r) => Visitor (SubstVisitor i o) r i o where - visitType = substV - visitAtom = substV - visitLam = substV - visitPi = substV - -instance Color c => SubstE AtomSubstVal (AtomSubstVal c) where - substE (_, env) (Rename name) = env ! name - substE env (SubstVal val) = SubstVal $ substE env val - -instance SubstV (SubstVal Atom) (SubstVal Atom) where - -instance IRRep r => SubstE AtomSubstVal (Atom r) where - substE es@(env, subst) = \case - Var (AtomVar v ty) -> case subst!v of - Rename v' -> Var $ AtomVar v' (substE es ty) - SubstVal x -> x - SimpInCore x -> SimpInCore (substE es x) - ProjectElt _ i x -> do - let x' = substE es x - runEnvReaderM env $ normalizeProj i x' - atom -> runReader (runSubstVisitor $ visitAtomPartial atom) es - -instance IRRep r => SubstE AtomSubstVal (Type r) where - substE es@(env, subst) = \case - TyVar (AtomVar v ty) -> case subst ! v of - Rename v' -> TyVar $ AtomVar v' (substE es ty) - SubstVal (Type t) -> t - SubstVal atom -> error $ "bad substitution: " ++ pprint v ++ " -> " ++ pprint atom - ProjectEltTy _ i x -> do - let x' = substE es x - case runEnvReaderM env $ normalizeProj i x' of - Type t -> t - _ -> error "bad substitution" - ty -> runReader (runSubstVisitor $ visitTypePartial ty) es - -instance SubstE AtomSubstVal SimpInCore - -instance IRRep r => SubstE AtomSubstVal (EffectRow r) where - substE env (EffectRow effs tailVar) = do - let effs' = eSetFromList $ map (substE env) (eSetToList effs) - let tailEffRow = case tailVar of - NoTail -> EffectRow mempty NoTail - EffectRowTail (AtomVar v _) -> case snd env ! v of - Rename v' -> do - let v'' = runEnvReaderM (fst env) $ toAtomVar v' - EffectRow mempty (EffectRowTail v'') - SubstVal (Var v') -> EffectRow mempty (EffectRowTail v') - SubstVal (Eff r) -> r - _ -> error "Not a valid effect substitution" - extendEffRow effs' tailEffRow - -instance IRRep r => SubstE AtomSubstVal (Effect r) - -instance SubstE AtomSubstVal SpecializationSpec where - substE env (AppSpecialization (AtomVar f _) ab) = do - let f' = case snd env ! f of - Rename v -> runEnvReaderM (fst env) $ toAtomVar v - SubstVal (Var v) -> v - _ -> error "bad substitution" - AppSpecialization f' (substE env ab) - -instance SubstE AtomSubstVal EffectDef -instance SubstE AtomSubstVal EffectOpType -instance SubstE AtomSubstVal IExpr -instance IRRep r => SubstE AtomSubstVal (RepVal r) -instance SubstE AtomSubstVal TyConParams -instance SubstE AtomSubstVal DataConDef -instance IRRep r => SubstE AtomSubstVal (BaseMonoid r) -instance IRRep r => SubstE AtomSubstVal (DAMOp r) -instance IRRep r => SubstE AtomSubstVal (TypedHof r) -instance IRRep r => SubstE AtomSubstVal (Hof r) -instance IRRep r => SubstE AtomSubstVal (TC r) -instance IRRep r => SubstE AtomSubstVal (Con r) -instance IRRep r => SubstE AtomSubstVal (MiscOp r) -instance IRRep r => SubstE AtomSubstVal (VectorOp r) -instance IRRep r => SubstE AtomSubstVal (MemOp r) -instance IRRep r => SubstE AtomSubstVal (PrimOp r) -instance IRRep r => SubstE AtomSubstVal (RefOp r) -instance IRRep r => SubstE AtomSubstVal (EffTy r) -instance IRRep r => SubstE AtomSubstVal (Expr r) -instance IRRep r => SubstE AtomSubstVal (GenericOpRep const r) -instance SubstE AtomSubstVal InstanceBody -instance SubstE AtomSubstVal DictType -instance SubstE AtomSubstVal DictExpr -instance IRRep r => SubstE AtomSubstVal (LamExpr r) -instance SubstE AtomSubstVal CorePiType -instance SubstE AtomSubstVal CoreLamExpr -instance IRRep r => SubstE AtomSubstVal (TabPiType r) -instance IRRep r => SubstE AtomSubstVal (PiType r) -instance IRRep r => SubstE AtomSubstVal (DepPairType r) -instance SubstE AtomSubstVal SolverBinding -instance IRRep r => SubstE AtomSubstVal (DeclBinding r) -instance IRRep r => SubstB AtomSubstVal (Decl r) -instance IRRep r => SubstB AtomSubstVal (BinderAndDecls r) -instance SubstE AtomSubstVal NewtypeTyCon -instance SubstE AtomSubstVal NewtypeCon -instance IRRep r => SubstE AtomSubstVal (IxDict r) -instance IRRep r => SubstE AtomSubstVal (IxType r) -instance SubstE AtomSubstVal DataConDefs diff --git a/src/lib/CheckType.hs b/src/lib/CheckType.hs index 2f6ac6e66..3390126b2 100644 --- a/src/lib/CheckType.hs +++ b/src/lib/CheckType.hs @@ -15,12 +15,12 @@ import Control.Monad.Reader import Control.Monad.State.Class import Data.Functor -import CheapReduction import Core import Err import IRVariants import MTL1 import Name +import Builder import Subst import PPrint () import QueryType @@ -87,15 +87,16 @@ parallelAffines actions = TyperM $ do -- === typeable things === checkTypesEq :: IRRep r => Type r o -> Type r o -> TyperM r i o () -checkTypesEq reqTy ty = alphaEq reqTy ty >>= \case - True -> return () - False -> {-# SCC typeNormalization #-} do - reqTy' <- cheapNormalize reqTy - ty' <- cheapNormalize ty - alphaEq reqTy' ty' >>= \case - True -> return () - False -> throw TypeErr $ pprint reqTy' ++ " != " ++ pprint ty' -{-# INLINE checkTypesEq #-} +checkTypesEq reqTy ty = undefined +-- checkTypesEq reqTy ty = alphaEq reqTy ty >>= \case +-- True -> return () +-- False -> {-# SCC typeNormalization #-} do +-- reqTy' <- cheapNormalize reqTy +-- ty' <- cheapNormalize ty +-- alphaEq reqTy' ty' >>= \case +-- True -> return () +-- False -> throw TypeErr $ pprint reqTy' ++ " != " ++ pprint ty' +-- {-# INLINE checkTypesEq #-} class SinkableE e => CheckableE (r::IR) (e::E) | e -> r where checkE :: e i -> TyperM r i o (e o) @@ -179,27 +180,6 @@ instance IRRep r => CheckableE r (Atom r) where con' <- typeCheckNewtypeCon con xTy return $ NewtypeCon con' x' SimpInCore x -> SimpInCore <$> checkE x - DictHole ctx ty access -> do - ty' <- ty |: TyKind - return $ DictHole ctx ty' access - ProjectElt resultTy UnwrapNewtype x -> do - resultTy' <- resultTy |: TyKind - (x', NewtypeTyCon con) <- checkAndGetType x - resultTy'' <- snd <$> unwrapNewtypeType con - checkTypesEq resultTy' resultTy'' - return $ ProjectElt resultTy' UnwrapNewtype x' - ProjectElt resultTy (ProjectProduct i) x -> do - resultTy' <- resultTy |: TyKind - (x', xTy) <- checkAndGetType x - resultTy'' <- case xTy of - ProdTy tys -> return $ tys !! i - DepPairTy t | i == 0 -> return $ depPairLeftTy t - DepPairTy t | i == 1 -> do - xFst <- normalizeProj (ProjectProduct 0) x' - checkInstantiation t [xFst] - _ -> throw TypeErr $ "Not a product type:" ++ pprint xTy - checkTypesEq resultTy' resultTy'' - return $ ProjectElt resultTy' (ProjectProduct i) x' TypeAsAtom ty -> TypeAsAtom <$> checkE ty instance IRRep r => CheckableE r (AtomVar r) where @@ -224,25 +204,6 @@ instance IRRep r => CheckableE r (Type r) where void $ checkInstantiation (Abs paramBs UnitE) params' return $ DictTy (DictType sn className' params') TyVar v -> TyVar <$> checkE v - ProjectEltTy resultTy UnwrapNewtype x -> do - resultTy' <- resultTy |: TyKind - x' <- checkE x - NewtypeTyCon con <- return $ getType x' - ty <- snd <$> unwrapNewtypeType con - checkTypesEq resultTy' ty - return $ ProjectEltTy resultTy' UnwrapNewtype x' - ProjectEltTy resultTy (ProjectProduct i) x -> do - resultTy' <- resultTy |: TyKind - (x', ty) <- checkAndGetType x - resultTy'' <- case ty of - ProdTy tys -> return $ tys !! i - DepPairTy t | i == 0 -> return $ depPairLeftTy t - DepPairTy t | i == 1 -> do - xFst <- normalizeProj (ProjectProduct 0) x' - instantiate t [xFst] - _ -> throw TypeErr $ "Not a product type:" ++ pprint ty - checkTypesEq resultTy' resultTy'' - return $ ProjectEltTy resultTy' (ProjectProduct i) x' instance CheckableE CoreIR SimpInCore where checkE x = renameM x -- TODO: check @@ -255,7 +216,10 @@ instance (ToBinding ann c, Color c, CheckableE r ann) => CheckableB r (BinderP c cont b' instance IRRep r => CheckableB r (BinderAndDecls r) where - checkB (BD b) cont = checkB b \b' -> cont $ BD b' + checkB (BD b ds) cont = checkB b \b' -> checkB ds \ds' -> cont $ BD b' ds' + +instance IRRep r => CheckableB r (Decl r) where + checkB _ _ = undefined checkBinderType :: (IRRep r) => Type r o -> Binder r i i' @@ -270,7 +234,7 @@ checkBinderAndDecls :: (IRRep r) => Type r o -> BinderAndDecls r i i' -> (forall o'. DExt o o' => BinderAndDecls r o o' -> TyperM r i' o' a) -> TyperM r i o a -checkBinderAndDecls ty (BD b) cont = checkBinderType ty b \b' -> cont (BD b') +checkBinderAndDecls ty _ _ = undefined instance IRRep r => CheckableWithEffects r (Expr r) where checkWithEffects allowedEffs expr = addContext ("Checking expr:\n" ++ pprint expr) case expr of @@ -309,14 +273,14 @@ instance IRRep r => CheckableWithEffects r (Expr r) where checkTypesEq (sink reqBinderTy) (sink $ binderType b') Abs b' <$> checkBlock (sink effTy') body return $ Case scrut' alts' effTy' - ApplyMethod effTy dict i args -> do - effTy' <- checkEffTy allowedEffs effTy - dict' <- checkE dict - args' <- mapM checkE args - methodTy <- getMethodType dict' i - effTy'' <- checkInstantiation methodTy args' - checkAlphaEq effTy' effTy'' - return $ ApplyMethod effTy' dict' i args' + -- ApplyMethod effTy dict i args -> do + -- effTy' <- checkEffTy allowedEffs effTy + -- dict' <- checkE dict + -- args' <- mapM checkE args + -- methodTy <- getMethodType dict' i + -- effTy'' <- checkInstantiation methodTy args' + -- checkAlphaEq effTy' effTy'' + -- return $ ApplyMethod effTy' dict' i args' TabCon maybeD ty xs -> do ty'@(TabPi (TabPiType _ b restTy)) <- ty |: TyKind maybeD' <- mapM renameM maybeD -- TODO: check @@ -327,6 +291,27 @@ instance IRRep r => CheckableWithEffects r (Expr r) where -- each index from the ix dict. HoistFailure _ -> forM xs checkE return $ TabCon maybeD' ty' xs' + DictHole ctx ty access -> do + ty' <- ty |: TyKind + return $ DictHole ctx ty' access + ProjectElt resultTy UnwrapNewtype x -> undefined + -- resultTy' <- resultTy |: TyKind + -- (x', NewtypeTyCon con) <- checkAndGetType x + -- resultTy'' <- snd <$> unwrapNewtypeType con + -- checkTypesEq resultTy' resultTy'' + -- return $ ProjectElt resultTy' UnwrapNewtype x' + -- ProjectElt resultTy (ProjectProduct i) x -> do + -- resultTy' <- resultTy |: TyKind + -- (x', xTy) <- checkAndGetType x + -- resultTy'' <- case xTy of + -- ProdTy tys -> return $ tys !! i + -- DepPairTy t | i == 0 -> return $ depPairLeftTy t + -- DepPairTy t | i == 1 -> do + -- xFst <- normalizeProj (ProjectProduct 0) x' + -- checkInstantiation t [xFst] + -- _ -> throw TypeErr $ "Not a product type:" ++ pprint xTy + -- checkTypesEq resultTy' resultTy'' + -- return $ ProjectElt resultTy' (ProjectProduct i) x' instance CheckableE CoreIR TyConParams where checkE (TyConParams expls params) = TyConParams expls <$> mapM checkE params @@ -489,17 +474,17 @@ instance IRRep r => CheckableWithEffects r (PrimOp r) where eltTy' <- checkInstantiation tabTy [i'] checkTypesEq givenTy' (TC $ RefType h eltTy') return $ IndexRef givenTy' i' - ProjRef givenTy p -> do - givenTy' <- givenTy |: TyKind - resultEltTy <- case p of - ProjectProduct i -> do - ProdTy tys <- return s - return $ tys !! i - UnwrapNewtype -> do - NewtypeTyCon tc <- return s - snd <$> unwrapNewtypeType tc - checkTypesEq givenTy' (TC $ RefType h resultEltTy) - return $ ProjRef givenTy' p + -- ProjRef givenTy p -> do + -- givenTy' <- givenTy |: TyKind + -- resultEltTy <- case p of + -- ProjectProduct i -> do + -- ProdTy tys <- return s + -- return $ tys !! i + -- UnwrapNewtype -> do + -- NewtypeTyCon tc <- return s + -- snd <$> unwrapNewtypeType tc + -- checkTypesEq givenTy' (TC $ RefType h resultEltTy) + -- return $ ProjRef givenTy' p return $ RefOp ref' m' instance IRRep r => CheckableE r (EffTy r) where @@ -582,12 +567,13 @@ instance IRRep r => CheckableWithEffects r (MiscOp r) where ty|:TyKind checkSomeSumType :: IRRep r => Type r o -> TyperM r i o [Type r o] -checkSomeSumType = \case - SumTy cases -> return cases - NewtypeTyCon con -> do - (_, SumTy cases) <- unwrapNewtypeType con - return cases - t -> error $ "not some sum type: " ++ pprint t +checkSomeSumType = undefined -- need to produce TypeBlock instead? Or just use Emits? +-- checkSomeSumType = \case +-- SumTy cases -> return cases +-- NewtypeTyCon con -> do +-- (_, SumTy cases) <- unwrapNewtypeType con +-- return cases +-- t -> error $ "not some sum type: " ++ pprint t instance IRRep r => CheckableE r (VectorOp r) where checkE = \case @@ -786,11 +772,11 @@ checkInstantiation abTop xsTop = do where go :: Abs (Binders r) body o' -> [Atom r o'] -> TyperM r i o' (body o') go (Abs Empty body) [] = return body - go (Abs (Nest (BD b) bs) body) (x:xs) = do - checkTypesEq (getType x) (binderType b) - rest <- applySubst (b@>SubstVal x) (Abs bs body) - go rest xs - go _ _ = throw ZipErr "Wrong number of args" + -- go (Abs (Nest (BD b) bs) body) (x:xs) = do + -- checkTypesEq (getType x) (binderType b) + -- rest <- applySubst (b@>SubstVal x) (Abs bs body) + -- go rest xs + -- go _ _ = throw ZipErr "Wrong number of args" checkIntBaseType :: Fallible m => BaseType -> m () checkIntBaseType t = case t of diff --git a/src/lib/Core.hs b/src/lib/Core.hs index f6d53574e..e7c221f27 100644 --- a/src/lib/Core.hs +++ b/src/lib/Core.hs @@ -40,6 +40,14 @@ import Types.Core import Types.Imp import Types.Primitives import Types.Source +import QueryType + +-- === Various helpers === + +toAtomVar :: (EnvReader m, IRRep r) => AtomName r n -> m n (AtomVar r n) +toAtomVar v = do + ty <- getType <$> lookupAtomName v + return $ AtomVar v ty -- === Typeclasses for monads === @@ -216,7 +224,7 @@ instance IRRep r => BindsEnv (Decl r) where {-# INLINE toEnvFrag #-} instance IRRep r => BindsEnv (BinderAndDecls r) where - toEnvFrag (BD b) = toEnvFrag b + toEnvFrag (BD b d) = toEnvFrag (PairB b d) {-# INLINE toEnvFrag #-} instance BindsEnv EnvFrag where diff --git a/src/lib/Export.hs b/src/lib/Export.hs index 090b090cf..4b47bd141 100644 --- a/src/lib/Export.hs +++ b/src/lib/Export.hs @@ -19,7 +19,6 @@ import Foreign.C.String import Foreign.Ptr import Builder -import CheapReduction import Core import Err import IRVariants diff --git a/src/lib/Generalize.hs b/src/lib/Generalize.hs index cebb3a690..5fc9595a1 100644 --- a/src/lib/Generalize.hs +++ b/src/lib/Generalize.hs @@ -12,9 +12,9 @@ import Core import Err import Types.Core import Inference +import Builder import IRVariants import QueryType -import CheapReduction import Name import Subst import MTL1 @@ -131,13 +131,14 @@ traverseTyParams ty f = getDistinct >>= \Distinct -> case ty of Abs paramRoles UnitE <- getClassRoleBinders name params' <- traverseRoleBinders f paramRoles params return $ DictTy $ DictType sn name params' - TabPi tabTy@(TabPiType (IxDictAtom d) b _) -> do - iTy <- f' TypeParam TyKind $ binderType b - dictTy <- liftM ignoreExcept $ runFallibleT1 $ DictTy <$> ixDictType iTy - d' <- f DictParam dictTy d - withFreshBinder (getNameHint b) iTy \b' -> do - resultTy' <- instantiate tabTy [Var $ binderVar b'] >>= (f' TypeParam TyKind) - return $ TabTy (IxDictAtom d') (PlainBD b') resultTy' + TabPi tabTy@(TabPiType (IxDictAtom d) b _) -> undefined + -- TabPi tabTy@(TabPiType (IxDictAtom d) b _) -> do + -- iTy <- f' TypeParam TyKind $ binderType b + -- dictTy <- liftM ignoreExcept $ runFallibleT1 $ DictTy <$> ixDictType iTy + -- d' <- f DictParam dictTy d + -- withFreshBinder (getNameHint b) iTy \b' -> do + -- resultTy' <- instantiate tabTy [Var $ binderVar b'] >>= (f' TypeParam TyKind) + -- return $ TabTy (IxDictAtom d') (PlainBD b') resultTy' -- shouldn't need this once we can exclude IxDictFin and IxDictSpecialized from CoreI TabPi t -> return $ TabPi t TC tc -> TC <$> case tc of diff --git a/src/lib/Imp.hs b/src/lib/Imp.hs index 86c6fb2db..ba641666e 100644 --- a/src/lib/Imp.hs +++ b/src/lib/Imp.hs @@ -32,7 +32,6 @@ import qualified Control.Monad.State.Strict as MTL import Algebra import Builder -import CheapReduction import CheckType (CheckableE (..)) import Core import Err @@ -126,12 +125,12 @@ getNaryLamImpArgTypes t = liftEnvReaderM $ go t where return (ts:argTys, resultTys) Empty -> ([],) <$> getDestBaseTypes (etTy effTy) -interpretImpArgsWithDest :: EnvReader m - => PiType SimpIR n -> [IExpr n] -> m n ([SAtom n], Dest n) +interpretImpArgsWithDest :: Emits n + => PiType SimpIR n -> [IExpr n] -> SubstImpM i n ([SAtom n], Dest n) interpretImpArgsWithDest t xs = do piTy@(PiType bs _) <- return t (args, xsLeft) <- _interpretImpArgs (EmptyAbs bs) xs - EffTy _ resultTy' <- instantiate piTy args + EffTy _ resultTy' <- instantiateImp piTy args (destTree, xsRest) <- listToTree resultTy' xsLeft case xsRest of [] -> return () @@ -316,21 +315,37 @@ translateExpr expr = confuseGHC >>= \_ -> case expr of Just (con, arg) -> do Abs b body <- return $ alts !! con extendSubst (b @> SubstVal arg) $ translateBlock body - Nothing -> do - RepVal sumTy (Branch (tag:xss)) <- atomToRepVal e' - ts <- caseAltsBinderTys sumTy - tag' <- repValAtom $ RepVal TagRepTy tag - xss' <- zipWithM (\t x -> repValAtom $ RepVal t x) ts xss - go tag' xss' - where - go tag xss = do - tag' <- fromScalarAtom tag - emitSwitch tag' (zip xss alts) $ - \(xs, Abs b body) -> - extendSubst (b @> SubstVal (sink xs)) $ - void $ translateBlock body - return UnitVal + -- Nothing -> do + -- RepVal sumTy (Branch (tag:xss)) <- atomToRepVal e' + -- ts <- caseAltsBinderTys sumTy + -- tag' <- repValAtom $ RepVal TagRepTy tag + -- xss' <- zipWithM (\t x -> repValAtom $ RepVal t x) ts xss + -- go tag' xss' + -- where + -- go tag xss = do + -- tag' <- fromScalarAtom tag + -- emitSwitch tag' (zip xss alts) $ + -- \(xs, Abs b body) -> + -- extendSubst (b @> SubstVal (sink xs)) $ + -- void $ translateBlock body + -- return UnitVal TabCon _ _ _ -> error "Unexpected `TabCon` in Imp pass." + ProjectElt _ p val -> undefined + -- ProjectElt _ p val -> do + -- (ps, v) <- return $ asNaryProj p val + -- lookupAtomName (atomVarName v) >>= \case + -- TopDataBound (RepVal _ tree) -> applyProjection (toList ps) tree + -- _ -> error "should only be projecting a data atom" + -- where + -- applyProjection :: [Projection] -> Tree (IExpr n) -> SubstImpM i n (Tree (IExpr n)) + -- applyProjection [] t = return t + -- applyProjection (i:is) t = do + -- t' <- applyProjection is t + -- case i of + -- UnwrapNewtype -> error "impossible" + -- ProjectProduct idx -> case t' of + -- Branch ts -> return $ ts !! idx + -- _ -> error "should only be projecting a branch" toImpRefOp :: Emits o => SAtom i -> RefOp SimpIR i -> SubstImpM i o (SAtom o) @@ -363,7 +378,7 @@ toImpRefOp refDest' m = do alphaEq accTy baseTy >>= \case -- Immediately beta-reduce, beacuse Imp doesn't reduce non-table applications. True -> do - body <- instantiate bc [x, y] + body <- instantiateImp bc [x, y] ans <- liftBuilderImp $ emitBlock $ sink body storeAtom accDest ans False -> case accTy of @@ -374,7 +389,7 @@ toImpRefOp refDest' m = do idx <- unsafeFromOrdinalImp (sink ixTy) i xElt <- liftBuilderImp $ tabApp (sink x) (sink idx) yElt <- liftBuilderImp $ tabApp (sink y) (sink idx) - eltTy <- instantiate t [idx] + eltTy <- instantiateImp t [idx] ithDest <- indexDest (sink accDest) idx liftMonoidCombine ithDest eltTy (sink bc) xElt yElt _ -> error $ "Base monoid type mismatch: can't lift " ++ @@ -582,7 +597,7 @@ toImpTypedHof (TypedHof (EffTy _ resultTy') hof) = do emitLoop noHint Fwd n \i -> do idx <- unsafeFromOrdinalImp (sink ixTy) i x' <- sinkM x - eltTy <- instantiate t [idx] + eltTy <- instantiateImp t [idx] ithDest <- indexDest (sink accDest) idx liftMonoidEmpty ithDest eltTy x' _ -> error $ "Base monoid type mismatch: can't lift " ++ @@ -724,43 +739,44 @@ storeRepVal (Dest _ destTree) repVal@(RepVal _ valTree) = do -- Like `typeToTree`, but when we additionally have the value, we can populate -- the existentially-hidden fields. valueToTree :: EnvReader m => SRepVal n -> m n (Tree (LeafType n)) -valueToTree (RepVal tyTop valTop) = do - go REmpty tyTop valTop - where - go :: EnvReader m => RNest (TypeCtxLayer SimpIR) n l -> SType l -> Tree (IExpr n) - -> m n (Tree (LeafType n)) - go ctx ty val = case ty of - BaseTy b -> return $ Leaf $ LeafType (unRNest ctx) b - TabTy d b bodyTy -> go (RNest ctx (TabCtx (PairB (LiftB d) b))) bodyTy val - RefTy _ t -> go (RNest ctx RefCtx) t val - DepPairTy dpTy@(DepPairType _ b t2) -> case val of - Branch [v1, v2] -> do - case allDepPairCtxs (unRNest ctx) of - Just UnitB -> do - let t1 = depPairLeftTy dpTy - tree1 <- rec t1 v1 - x <- repValAtom $ RepVal t1 v1 - t2' <- instantiate dpTy [x] - tree2 <- go (RNest ctx (DepPairCtx NothingB )) t2' v2 - return $ Branch [tree1, tree2] - Nothing -> do - let t1 = depPairLeftTy dpTy - tree1 <- rec t1 v1 - tree2 <- go (RNest ctx (DepPairCtx (JustB b))) t2 v2 - return $ Branch [tree1, tree2] - _ -> error "expected a branch" - ProdTy ts -> case val of - Branch vals -> Branch <$> zipWithM rec ts vals - _ -> error "expected a branch" - SumTy ts -> case val of - Branch (tagVal:vals) -> do - tag <- rec TagRepTy tagVal - results <- zipWithM rec ts vals - return $ Branch $ tag : results - _ -> error "expected a branch" - _ -> error $ "not implemented " ++ pprint ty - where rec = go ctx -{-# INLINE valueToTree #-} +valueToTree (RepVal tyTop valTop) = undefined +-- valueToTree (RepVal tyTop valTop) = do +-- go REmpty tyTop valTop +-- where +-- go :: EnvReader m => RNest (TypeCtxLayer SimpIR) n l -> SType l -> Tree (IExpr n) +-- -> m n (Tree (LeafType n)) +-- go ctx ty val = case ty of +-- BaseTy b -> return $ Leaf $ LeafType (unRNest ctx) b +-- TabTy d b bodyTy -> go (RNest ctx (TabCtx (PairB (LiftB d) b))) bodyTy val +-- RefTy _ t -> go (RNest ctx RefCtx) t val +-- DepPairTy dpTy@(DepPairType _ b t2) -> case val of +-- Branch [v1, v2] -> do +-- case allDepPairCtxs (unRNest ctx) of +-- Just UnitB -> do +-- let t1 = depPairLeftTy dpTy +-- tree1 <- rec t1 v1 +-- x <- repValAtom $ RepVal t1 v1 +-- t2' <- instantiate dpTy [x] +-- tree2 <- go (RNest ctx (DepPairCtx NothingB )) t2' v2 +-- return $ Branch [tree1, tree2] +-- Nothing -> do +-- let t1 = depPairLeftTy dpTy +-- tree1 <- rec t1 v1 +-- tree2 <- go (RNest ctx (DepPairCtx (JustB b))) t2 v2 +-- return $ Branch [tree1, tree2] +-- _ -> error "expected a branch" +-- ProdTy ts -> case val of +-- Branch vals -> Branch <$> zipWithM rec ts vals +-- _ -> error "expected a branch" +-- SumTy ts -> case val of +-- Branch (tagVal:vals) -> do +-- tag <- rec TagRepTy tagVal +-- results <- zipWithM rec ts vals +-- return $ Branch $ tag : results +-- _ -> error "expected a branch" +-- _ -> error $ "not implemented " ++ pprint ty +-- where rec = go ctx +-- {-# INLINE valueToTree #-} allDepPairCtxs :: TypeCtx SimpIR n l -> Maybe (UnitB n l) allDepPairCtxs ctx = case splitLeadingDepPairs ctx of @@ -877,21 +893,6 @@ atomToRepVal x = RepVal (getType x) <$> go x where TopDataBound (RepVal _ tree) -> return tree _ -> error "should only have pointer and data atom names left" PtrVar ty p -> return $ Leaf $ IPtrVar p ty - ProjectElt _ p val -> do - (ps, v) <- return $ asNaryProj p val - lookupAtomName (atomVarName v) >>= \case - TopDataBound (RepVal _ tree) -> applyProjection (toList ps) tree - _ -> error "should only be projecting a data atom" - where - applyProjection :: [Projection] -> Tree (IExpr n) -> SubstImpM i n (Tree (IExpr n)) - applyProjection [] t = return t - applyProjection (i:is) t = do - t' <- applyProjection is t - case i of - UnwrapNewtype -> error "impossible" - ProjectProduct idx -> case t' of - Branch ts -> return $ ts !! idx - _ -> error "should only be projecting a branch" -- XXX: We used to have a function called `destToAtom` which loaded the value -- from the dest. This version is not that. It just lifts a dest into an atom of @@ -1004,7 +1005,7 @@ buildGarbageVal ty = indexDest :: Emits n => Dest n -> SAtom n -> SubstImpM i n (Dest n) indexDest (Dest (TabPi tabTy) tree) i = do - eltTy <- instantiate tabTy [i] + eltTy <- instantiateImp tabTy [i] ord <- ordinalImp (tabIxType tabTy) i leafTys <- typeToTree $ TabPi tabTy Dest eltTy <$> forM (zipTrees leafTys tree) \(leafTy, ptr) -> do @@ -1028,7 +1029,7 @@ indexRepValParam :: Emits n -> (IExpr n -> SubstImpM i n (IExpr n)) -> SubstImpM i n (SRepVal n) indexRepValParam (RepVal (TabPi tabTy) vals) i tyFunc func = do - eltTy <- instantiate tabTy [i] + eltTy <- instantiateImp tabTy [i] ord <- ordinalImp (tabIxType tabTy) i leafTys <- typeToTree (TabPi tabTy) vals' <- forM (zipTrees leafTys vals) \(leafTy, ptr) -> do @@ -1119,12 +1120,13 @@ computeSizeGivenOrdinal :: EnvReader m => IxBinder SimpIR n l -> IndexStructure SimpIR l -> m n (Abs (Binder SimpIR) (Block SimpIR) n) -computeSizeGivenOrdinal (PairB (LiftB d) (BD (b:>t))) idxStruct = liftBuilder do - withFreshBinder noHint IdxRepTy \bOrdinal -> - Abs bOrdinal <$> buildBlock do - i <- unsafeFromOrdinal (sink $ IxType t d) $ Var $ sink $ binderVar bOrdinal - idxStruct' <- applySubst (b@>SubstVal i) idxStruct - elemCountPoly $ sink idxStruct' +computeSizeGivenOrdinal _ _ = undefined +-- computeSizeGivenOrdinal (PairB (LiftB d) (BD (b:>t))) idxStruct = liftBuilder do +-- withFreshBinder noHint IdxRepTy \bOrdinal -> +-- Abs bOrdinal <$> buildBlock do +-- i <- unsafeFromOrdinal (sink $ IxType t d) $ Var $ sink $ binderVar bOrdinal +-- idxStruct' <- applySubst (b@>SubstVal i) idxStruct +-- elemCountPoly $ sink idxStruct' -- Split the index structure into a prefix of non-dependent index types -- and a trailing nest of indices that can contain inter-dependencies. @@ -1209,11 +1211,17 @@ withFreshIBinder hint ty cont = do cont $ IBinder b ty {-# INLINE withFreshIBinder #-} +instantiateImp + :: (Emits n, IRRep r, SubstE (SubstVal Atom) body, SinkableE body, SinkableE e, + ToBindersAbs e body r, Ext h n) + => e h -> [Atom r n] -> SubstImpM i n (body n) +instantiateImp = undefined + emitCall :: Emits n => PiType SimpIR n -> ImpFunName n -> [SAtom n] -> SubstImpM i n (SAtom n) emitCall piTy f xs = do - EffTy _ resultTy' <- instantiate piTy xs + EffTy _ resultTy' <- instantiateImp piTy xs dest <- allocDest resultTy' argsImp <- forM xs \x -> repValToList <$> atomToRepVal x destImp <- repValToList <$> atomToRepVal (destToAtom dest) diff --git a/src/lib/Inference.hs b/src/lib/Inference.hs index b04d8029a..4999bee13 100644 --- a/src/lib/Inference.hs +++ b/src/lib/Inference.hs @@ -34,7 +34,6 @@ import qualified Unsafe.Coerce as TrulyUnsafe import GHC.Generics (Generic (..)) import Builder -import CheapReduction import CheckType import Core import Err @@ -50,6 +49,7 @@ import Types.Primitives import Types.Source import Util hiding (group) import PPrint (prettyBlock) +import Visitor -- === Top-level interface === @@ -868,9 +868,10 @@ extendSynthCandidates _ _ env = env {-# INLINE extendSynthCandidates #-} extendSynthCandidatess :: Distinct n => [Explicitness] -> CBinders n' n -> Env n -> Env n -extendSynthCandidatess (expl:expls) (Nest (BD b) bs) env = - extendSynthCandidatess expls bs env' - where env' = extendSynthCandidates expl (withExtEvidence bs $ sink $ binderName b) env +extendSynthCandidatess (expl:expls) _ _ = undefined +-- extendSynthCandidatess (expl:expls) (Nest (BD b) bs) env = +-- extendSynthCandidatess expls bs env' +-- where env' = extendSynthCandidates expl (withExtEvidence bs $ sink $ binderName b) env extendSynthCandidatess [] Empty env = env extendSynthCandidatess _ _ _ = error "zip error" {-# INLINE extendSynthCandidatess #-} @@ -924,11 +925,12 @@ inferRho hint expr = checkOrInferRho hint expr Infer {-# INLINE inferRho #-} getImplicitArg :: EmitsInf o => InferenceArgDesc -> InferenceMechanism -> CType o -> InfererM i o (CAtom o) -getImplicitArg desc inf argTy = case inf of - Unify -> Var <$> freshInferenceName (ImplicitArgInfVar desc) argTy - Synth reqMethodAccess -> do - ctx <- srcPosCtx <$> getErrCtx - return $ DictHole (AlwaysEqual ctx) argTy reqMethodAccess +getImplicitArg desc inf argTy = undefined +-- getImplicitArg desc inf argTy = case inf of +-- Unify -> Var <$> freshInferenceName (ImplicitArgInfVar desc) argTy +-- Synth reqMethodAccess -> do +-- ctx <- srcPosCtx <$> getErrCtx +-- return $ DictHole (AlwaysEqual ctx) argTy reqMethodAccess withBlockDecls :: EmitsBoth o @@ -1084,15 +1086,16 @@ checkOrInferRho hint uExprWithSrc@(WithSrcE pos expr) reqTy = do {-# INLINE inferAndInstantiate #-} applyFromLiteralMethod :: EmitsBoth n => SourceName -> CAtom n -> DefaultType -> CAtom n -> InfererM i n (CAtom n) -applyFromLiteralMethod methodName defaultVal defaultTy litVal = do - lookupSourceMap methodName >>= \case - Nothing -> return defaultVal - Just ~(UMethodVar methodName') -> do - MethodBinding className _ <- lookupEnv methodName' - resultTyVar <- freshInferenceName MiscInfVar TyKind - dictTy <- DictTy <$> dictType className [Var resultTyVar] - addDefault (atomVarName resultTyVar) defaultTy - emitExpr =<< mkApplyMethod (DictHole (AlwaysEqual emptySrcPosCtx) dictTy Full) 0 [litVal] +applyFromLiteralMethod methodName defaultVal defaultTy litVal = undefined +-- applyFromLiteralMethod methodName defaultVal defaultTy litVal = do +-- lookupSourceMap methodName >>= \case +-- Nothing -> return defaultVal +-- Just ~(UMethodVar methodName') -> do +-- MethodBinding className _ <- lookupEnv methodName' +-- resultTyVar <- freshInferenceName MiscInfVar TyKind +-- dictTy <- DictTy <$> dictType className [Var resultTyVar] +-- addDefault (atomVarName resultTyVar) defaultTy +-- emitExpr =<< mkApplyMethod (DictHole (AlwaysEqual emptySrcPosCtx) dictTy Full) 0 [litVal] -- atom that requires instantiation to become a rho type data SigmaAtom n = @@ -1168,30 +1171,31 @@ data FieldDef (n::S) = deriving (Show, Generic) getFieldDefs :: CType n -> InfererM i n (M.Map FieldName' (FieldDef n)) -getFieldDefs ty = case ty of - NewtypeTyCon (UserADTType _ tyName params) -> do - TyConBinding ~(Just tyDef) (DotMethods dotMethods) <- lookupEnv tyName - instantiateTyConDef tyDef params >>= \case - StructFields fields -> do - let projFields = enumerate fields <&> \(i, (field, _)) -> - [(FieldName field, FieldProj i), (FieldNum i, FieldProj i)] - let methodFields = M.toList dotMethods <&> \(field, f) -> - (FieldName field, FieldDotMethod f params) - return $ M.fromList $ concat projFields ++ methodFields - ADTCons _ -> noFields "" - RefTy _ valTy -> case valTy of - RefTy _ _ -> noFields "" - _ -> do - valFields <- getFieldDefs valTy - return $ M.filter isProj valFields - where isProj = \case - FieldProj _ -> True - _ -> False - ProdTy ts -> return $ M.fromList $ enumerate ts <&> \(i, _) -> (FieldNum i, FieldProj i) - TabPi _ -> noFields "\nArray indexing uses [] now." - _ -> noFields "" - where - noFields s = throw TypeErr $ "Can't get fields for type " ++ pprint ty ++ s +getFieldDefs ty = undefined +-- getFieldDefs ty = case ty of +-- NewtypeTyCon (UserADTType _ tyName params) -> do +-- TyConBinding ~(Just tyDef) (DotMethods dotMethods) <- lookupEnv tyName +-- instantiateTyConDef tyDef params >>= \case +-- StructFields fields -> do +-- let projFields = enumerate fields <&> \(i, (field, _)) -> +-- [(FieldName field, FieldProj i), (FieldNum i, FieldProj i)] +-- let methodFields = M.toList dotMethods <&> \(field, f) -> +-- (FieldName field, FieldDotMethod f params) +-- return $ M.fromList $ concat projFields ++ methodFields +-- ADTCons _ -> noFields "" +-- RefTy _ valTy -> case valTy of +-- RefTy _ _ -> noFields "" +-- _ -> do +-- valFields <- getFieldDefs valTy +-- return $ M.filter isProj valFields +-- where isProj = \case +-- FieldProj _ -> True +-- _ -> False +-- ProdTy ts -> return $ M.fromList $ enumerate ts <&> \(i, _) -> (FieldNum i, FieldProj i) +-- TabPi _ -> noFields "\nArray indexing uses [] now." +-- _ -> noFields "" +-- where +-- noFields s = throw TypeErr $ "Can't get fields for type " ++ pprint ty ++ s instantiateSigma :: forall i o. EmitsBoth o => SigmaAtom o -> InfererM i o (CAtom o) instantiateSigma sigmaAtom = case getType sigmaAtom of @@ -1236,55 +1240,57 @@ etaExpandExplicits :: EmitsInf o => SourceName -> CorePiType o -> (forall o'. (EmitsBoth o', DExt o o') => [CAtom o'] -> InfererM i o' (CAtom o')) -> InfererM i o (CoreLamExpr o) -etaExpandExplicits fSourceName piTy@(CorePiType _ explsTop bsTop _) contTop = do - Abs bs body <- go explsTop bsTop \xs -> do - EffTy effs _ <- instantiate piTy xs - withAllowedEffects effs do - body <- buildBlockInf $ contTop $ sinkList xs - return $ PairE effs body - let (expls, bs') = unzipAttrs bs - coreLamExpr ExplicitApp expls $ Abs bs' body - where - go :: (EmitsInf o, SinkableE e, RenameE e, SubstE AtomSubstVal e, HoistableE e ) - => [Explicitness] -> CBinders o any - -> (forall o'. (EmitsInf o', DExt o o') => [CAtom o'] -> InfererM i o' (e o')) - -> InfererM i o (Abs (Nest (WithExpl CBinderAndDecls)) e o) - go [] Empty cont = getDistinct >>= \Distinct -> Abs Empty <$> cont [] - go (expl:expls) (Nest b rest) cont = case expl of - Explicit -> do - prependAbs <$> buildAbsInfWithExpl (getNameHint b) expl (binderType b) \v -> do - Abs rest' UnitE <- instantiateNames (Abs b $ Abs rest UnitE) [atomVarName v] - go expls rest' \args -> cont (sink (Var v) : args) - Inferred argSourceName infMech -> do - arg <- getImplicitArg (fSourceName, fromMaybe "_" argSourceName) infMech (binderType b) - Abs rest' UnitE <- instantiate (Abs b $ Abs rest UnitE) [arg] - go expls rest' \args -> cont (sink arg : args) - go _ _ _ = error "zip error" +etaExpandExplicits fSourceName piTy@(CorePiType _ explsTop bsTop _) contTop = undefined +-- etaExpandExplicits fSourceName piTy@(CorePiType _ explsTop bsTop _) contTop = do +-- Abs bs body <- go explsTop bsTop \xs -> do +-- EffTy effs _ <- instantiate piTy xs +-- withAllowedEffects effs do +-- body <- buildBlockInf $ contTop $ sinkList xs +-- return $ PairE effs body +-- let (expls, bs') = unzipAttrs bs +-- coreLamExpr ExplicitApp expls $ Abs bs' body +-- where +-- go :: (EmitsInf o, SinkableE e, RenameE e, SubstE AtomSubstVal e, HoistableE e ) +-- => [Explicitness] -> CBinders o any +-- -> (forall o'. (EmitsInf o', DExt o o') => [CAtom o'] -> InfererM i o' (e o')) +-- -> InfererM i o (Abs (Nest (WithExpl CBinderAndDecls)) e o) +-- go [] Empty cont = getDistinct >>= \Distinct -> Abs Empty <$> cont [] +-- go (expl:expls) (Nest b rest) cont = case expl of +-- Explicit -> do +-- prependAbs <$> buildAbsInfWithExpl (getNameHint b) expl (binderType b) \v -> do +-- Abs rest' UnitE <- instantiateNames (Abs b $ Abs rest UnitE) [atomVarName v] +-- go expls rest' \args -> cont (sink (Var v) : args) +-- Inferred argSourceName infMech -> do +-- arg <- getImplicitArg (fSourceName, fromMaybe "_" argSourceName) infMech (binderType b) +-- Abs rest' UnitE <- instantiate (Abs b $ Abs rest UnitE) [arg] +-- go expls rest' \args -> cont (sink arg : args) +-- go _ _ _ = error "zip error" buildLamInf :: EmitsInf o => CorePiType o -> (forall o' . (EmitsBoth o', DExt o o') => [(Explicitness, CAtom o')] -> CType o' -> InfererM i o' (CAtom o')) -> InfererM i o (CoreLamExpr o) -buildLamInf piTy@(CorePiType appExpl explsTop bsTop _) contTop = do - ab <- go explsTop bsTop \xs -> do - let (expls, xs') = unzip xs - EffTy effs' resultTy' <- instantiate piTy xs' - withAllowedEffects effs' do - body <- buildBlockInf $ contTop (zip expls $ sinkList xs') (sink resultTy') - return $ PairE effs' body - coreLamExpr appExpl explsTop ab - where - go :: (EmitsInf o, HoistableE e, SinkableE e, SubstE AtomSubstVal e, RenameE e) - => [Explicitness] -> CBinders o any - -> (forall o'. (EmitsInf o', DExt o o') => [(Explicitness, CAtom o')] -> InfererM i o' (e o')) - -> InfererM i o (Abs CBinders e o) - go [] Empty cont = getDistinct >>= \Distinct -> Abs Empty <$> cont [] - go (expl:expls) (Nest b rest) cont = do - prependAbs <$> buildAbsInfWithDecls (getNameHint b) expl (binderType b) \v -> do - Abs rest' UnitE <- instantiateNames (Abs b $ Abs rest UnitE) [atomVarName v] - go expls rest' \args -> cont $ (expl, sink $ Var v) : args - go _ _ _ = error "zip error" +buildLamInf piTy@(CorePiType appExpl explsTop bsTop _) contTop = undefined +-- buildLamInf piTy@(CorePiType appExpl explsTop bsTop _) contTop = do +-- ab <- go explsTop bsTop \xs -> do +-- let (expls, xs') = unzip xs +-- EffTy effs' resultTy' <- instantiate piTy xs' +-- withAllowedEffects effs' do +-- body <- buildBlockInf $ contTop (zip expls $ sinkList xs') (sink resultTy') +-- return $ PairE effs' body +-- coreLamExpr appExpl explsTop ab +-- where +-- go :: (EmitsInf o, HoistableE e, SinkableE e, SubstE AtomSubstVal e, RenameE e) +-- => [Explicitness] -> CBinders o any +-- -> (forall o'. (EmitsInf o', DExt o o') => [(Explicitness, CAtom o')] -> InfererM i o' (e o')) +-- -> InfererM i o (Abs CBinders e o) +-- go [] Empty cont = getDistinct >>= \Distinct -> Abs Empty <$> cont [] +-- go (expl:expls) (Nest b rest) cont = do +-- prependAbs <$> buildAbsInfWithDecls (getNameHint b) expl (binderType b) \v -> do +-- Abs rest' UnitE <- instantiateNames (Abs b $ Abs rest UnitE) [atomVarName v] +-- go expls rest' \args -> cont $ (expl, sink $ Var v) : args +-- go _ _ _ = error "zip error" class ExplicitArg (e::E) where checkExplicitArg :: EmitsBoth o => IsDependent -> e i -> CType o -> InfererM i o (CAtom o) @@ -1352,7 +1358,7 @@ checkOrInferApp f' posArgs namedArgs reqTy = do ty <- return $ getType x constrainTypesEq req ty -maybeInterpretPunsAsTyCons :: RequiredTy CType n -> SigmaAtom n -> InfererM i n (SigmaAtom n) +maybeInterpretPunsAsTyCons :: Emits n => RequiredTy CType n -> SigmaAtom n -> InfererM i n (SigmaAtom n) maybeInterpretPunsAsTyCons (Check TyKind) (SigmaUVar sn _ (UPunVar v)) = do let v' = UTyConVar v ty <- getUVarType v' @@ -1416,7 +1422,7 @@ applyDataCon tc conIx topArgs = do where conTys = conDefs <&> \(DataConDef _ _ rty _) -> rty return $ NewtypeCon (UserADTData sn tc params) repVal where - wrap :: EnvReader m => CType n -> [CAtom n] -> m n (CAtom n) + wrap :: Emits n => CType n -> [CAtom n] -> InfererM i n (CAtom n) wrap _ [arg] = return $ arg wrap rty args = case rty of ProdTy tys -> @@ -1468,7 +1474,7 @@ inferMixedArgs fSourceName explsTop bsAbs posArgs namedArgs = do Just _ -> False arg <- inferMixedArg isDependent (binderType b) expl arg' <- lift11 $ zonk arg - rest' <- instantiate (Abs b rest) [arg'] + rest' <- lift11 $ instantiate (Abs b rest) [arg'] (arg:) <$> go expls rest' go _ _ = error "zip error" @@ -1503,14 +1509,15 @@ checkNamedArgValidity expls offeredNames = do ++ "\nShould be one of: " ++ pprint acceptedNames inferPrimArg :: EmitsBoth o => UExpr i -> InfererM i o (CAtom o) -inferPrimArg x = do - xBlock <- buildBlockInf $ inferRho noHint x - EffTy _ ty <- blockEffTy xBlock - case ty of - TyKind -> cheapReduce xBlock >>= \case - Just reduced -> return reduced - _ -> throw CompilerErr "Type args to primops must be reducible" - _ -> emitBlock xBlock +inferPrimArg x = undefined +-- inferPrimArg x = do +-- xBlock <- buildBlockInf $ inferRho noHint x +-- EffTy _ ty <- blockEffTy xBlock +-- case ty of +-- TyKind -> cheapReduce xBlock >>= \case +-- Just reduced -> return reduced +-- _ -> throw CompilerErr "Type args to primops must be reducible" +-- _ -> emitBlock xBlock matchPrimApp :: Emits o => PrimName -> [CAtom o] -> InfererM i o (CAtom o) matchPrimApp = \case @@ -1611,15 +1618,16 @@ checkSigmaDependent hint e@(WithSrcE ctx _) ty = addSrcContext ctx $ withReducibleEmissions :: ( EmitsInf o, SinkableE e, RenameE e, SubstE AtomSubstVal e - , HoistableE e, CheaplyReducibleE CoreIR e e) + , HoistableE e) => String -> (forall o' . (EmitsBoth o', DExt o o') => InfererM i o' (e o')) -> InfererM i o (e o) -withReducibleEmissions msg cont = do - Abs decls result <- buildDeclsInf cont - cheapReduceWithDecls decls result >>= \case - Just t -> return t - _ -> throw TypeErr msg +withReducibleEmissions msg cont = undefined +-- withReducibleEmissions msg cont = do +-- Abs decls result <- buildDeclsInf cont +-- cheapReduceWithDecls decls result >>= \case +-- Just t -> return t +-- _ -> throw TypeErr msg -- === sorting case alternatives === @@ -1676,12 +1684,13 @@ buildSortedCase scrut alts resultTy = do -- TODO: cache this with the instance def (requires a recursive binding) instanceFun :: EnvReader m => InstanceName n -> AppExplicitness -> m n (CAtom n) -instanceFun instanceName appExpl = do - InstanceDef _ expls bs _ _ <- lookupInstanceDef instanceName - ab <- liftEnvReaderM $ refreshAbs (Abs bs UnitE) \bs' UnitE -> do - result <- mkDictAtom $ InstanceDict (sink instanceName) (Var <$> bindersVars bs') - return $ Abs bs' (PairE Pure (WithoutDecls result)) - Lam <$> coreLamExpr appExpl (snd<$>expls) ab +instanceFun instanceName appExpl = undefined +-- instanceFun instanceName appExpl = do +-- InstanceDef _ expls bs _ _ <- lookupInstanceDef instanceName +-- ab <- liftEnvReaderM $ refreshAbs (Abs bs UnitE) \bs' UnitE -> do +-- result <- mkDictAtom $ InstanceDict (sink instanceName) (Var <$> bindersVars bs') +-- return $ Abs bs' (PairE Pure (WithoutDecls result)) +-- Lam <$> coreLamExpr appExpl (snd<$>expls) ab checkMaybeAnnExpr :: EmitsBoth o => NameHint -> Maybe (UType i) -> UExpr i -> InfererM i o (CAtom o) @@ -1860,41 +1869,43 @@ withRoleUBinders roleBs cont = case roleBs of _ -> error "zip error" inferULam :: EmitsInf o => ULamExpr i -> InfererM i o (CoreLamExpr o) -inferULam (ULamExpr bs appExpl effs resultTy body) = do - ab <- withUBinders bs \_ -> do - effs' <- fromMaybe Pure <$> mapM checkUEffRow effs - resultTy' <- mapM checkUType resultTy - body' <- buildBlockInf $ withAllowedEffects (sink effs') do - case resultTy' of - Nothing -> withBlockDecls body \result -> inferSigma noHint result - Just resultTy'' -> - withBlockDecls body \result -> - checkSigma noHint result (sink resultTy'') - return (PairE effs' body') - Abs bs' (PairE effs' body') <- return ab - let (expls, bs'') = unzipAttrs bs' - case appExpl of - ImplicitApp -> checkImplicitLamRestrictions bs'' effs' - ExplicitApp -> return () - coreLamExpr appExpl expls $ Abs bs'' $ PairE effs' body' +inferULam (ULamExpr bs appExpl effs resultTy body) = undefined +-- inferULam (ULamExpr bs appExpl effs resultTy body) = do +-- ab <- withUBinders bs \_ -> do +-- effs' <- fromMaybe Pure <$> mapM checkUEffRow effs +-- resultTy' <- mapM checkUType resultTy +-- body' <- buildBlockInf $ withAllowedEffects (sink effs') do +-- case resultTy' of +-- Nothing -> withBlockDecls body \result -> inferSigma noHint result +-- Just resultTy'' -> +-- withBlockDecls body \result -> +-- checkSigma noHint result (sink resultTy'') +-- return (PairE effs' body') +-- Abs bs' (PairE effs' body') <- return ab +-- let (expls, bs'') = unzipAttrs bs' +-- case appExpl of +-- ImplicitApp -> checkImplicitLamRestrictions bs'' effs' +-- ExplicitApp -> return () +-- coreLamExpr appExpl expls $ Abs bs'' $ PairE effs' body' checkImplicitLamRestrictions :: CBinders o o' -> EffectRow CoreIR o' -> InfererM i o () checkImplicitLamRestrictions _ _ = return () -- TODO checkUForExpr :: EmitsBoth o => UForExpr i -> TabPiType CoreIR o -> InfererM i o (LamExpr CoreIR o) -checkUForExpr (UForExpr (UAnnBinder bFor ann cs) body) tabPi@(TabPiType _ bPi _) = do - unless (null cs) $ throw TypeErr "`for` binders shouldn't have constraints" - let iTy = binderType bPi - case ann of - UNoAnn -> return () - UAnn forAnn -> checkUType forAnn >>= constrainTypesEq iTy - Abs b body' <- buildAbsInfWithDecls (getNameHint bFor) Explicit iTy \i -> do - extendRenamer (bFor@>atomVarName i) do - resultTy <- instantiate tabPi [Var i] - buildBlockInf do - withBlockDecls body \result -> - checkSigma noHint result $ sink resultTy - return $ LamExpr (UnaryNest b) body' +checkUForExpr (UForExpr (UAnnBinder bFor ann cs) body) tabPi@(TabPiType _ bPi _) = undefined +-- checkUForExpr (UForExpr (UAnnBinder bFor ann cs) body) tabPi@(TabPiType _ bPi _) = do +-- unless (null cs) $ throw TypeErr "`for` binders shouldn't have constraints" +-- let iTy = binderType bPi +-- case ann of +-- UNoAnn -> return () +-- UAnn forAnn -> checkUType forAnn >>= constrainTypesEq iTy +-- Abs b body' <- buildAbsInfWithDecls (getNameHint bFor) Explicit iTy \i -> do +-- extendRenamer (bFor@>atomVarName i) do +-- resultTy <- instantiate tabPi [Var i] +-- buildBlockInf do +-- withBlockDecls body \result -> +-- checkSigma noHint result $ sink resultTy +-- return $ LamExpr (UnaryNest b) body' inferUForExpr :: EmitsBoth o => UForExpr i -> InfererM i o (LamExpr CoreIR o) inferUForExpr (UForExpr (UAnnBinder bFor ann cs) body) = do @@ -1958,34 +1969,36 @@ checkLamBinders (piExpl:piExpls) (Nest piB piBs) lamBs cont = do checkLamBinders _ _ _ _ = error "zip error" checkInstanceParams :: EmitsInf o => [Explicitness] -> CBinders o any -> [UExpr i] -> InfererM i o [CAtom o] -checkInstanceParams expls bsTop paramsTop = do - checkArity expls paramsTop - go bsTop paramsTop - where - go :: EmitsInf o => CBinders o any -> [UExpr i] -> InfererM i o [CAtom o] - go Empty [] = return [] - go (Nest b bs) (x:xs) = do - x' <- checkUParam (binderType b) x - Abs bs' UnitE <- instantiate (Abs b $ Abs bs UnitE) [x'] - (x':) <$> go bs' xs - go _ _ = error "zip error" +checkInstanceParams expls bsTop paramsTop = undefined +-- checkInstanceParams expls bsTop paramsTop = do +-- checkArity expls paramsTop +-- go bsTop paramsTop +-- where +-- go :: EmitsInf o => CBinders o any -> [UExpr i] -> InfererM i o [CAtom o] +-- go Empty [] = return [] +-- go (Nest b bs) (x:xs) = do +-- x' <- checkUParam (binderType b) x +-- Abs bs' UnitE <- instantiate (Abs b $ Abs bs UnitE) [x'] +-- (x':) <$> go bs' xs +-- go _ _ = error "zip error" checkInstanceBody :: EmitsInf o => ClassName o -> [CAtom o] -> [UMethodDef i] -> InfererM i o (InstanceBody o) -checkInstanceBody className params methods = do - classDef@(ClassDef _ methodNames _ _ _ _ _) <- lookupClassDef className - superclassAbs@(Abs scBs' _) <- instantiate classDef params - superclassTys <- superclassDictTys scBs' - superclassDicts <- mapM (flip trySynthTerm Full) superclassTys - ListE methodTys'' <- instantiate superclassAbs superclassDicts - methodsChecked <- mapM (checkMethodDef className methodTys'') methods - let (idxs, methods') = unzip $ sortOn fst $ methodsChecked - forM_ (repeated idxs) \i -> - throw TypeErr $ "Duplicate method: " ++ pprint (methodNames!!i) - forM_ ([0..(length methodTys'' - 1)] `listDiff` idxs) \i -> - throw TypeErr $ "Missing method: " ++ pprint (methodNames!!i) - return $ InstanceBody superclassDicts methods' +checkInstanceBody className params methods = undefined +-- checkInstanceBody className params methods = do +-- classDef@(ClassDef _ methodNames _ _ _ _ _) <- lookupClassDef className +-- superclassAbs@(Abs scBs' _) <- instantiate classDef params +-- superclassTys <- superclassDictTys scBs' +-- superclassDicts <- mapM (flip trySynthTerm Full) superclassTys +-- ListE methodTys'' <- instantiate superclassAbs superclassDicts +-- methodsChecked <- mapM (checkMethodDef className methodTys'') methods +-- let (idxs, methods') = unzip $ sortOn fst $ methodsChecked +-- forM_ (repeated idxs) \i -> +-- throw TypeErr $ "Duplicate method: " ++ pprint (methodNames!!i) +-- forM_ ([0..(length methodTys'' - 1)] `listDiff` idxs) \i -> +-- throw TypeErr $ "Missing method: " ++ pprint (methodNames!!i) +-- return $ InstanceBody superclassDicts methods' superclassDictTys :: CBinders o o' -> InfererM i o [CType o] superclassDictTys Empty = return [] @@ -2052,22 +2065,23 @@ checkCasePat :: EmitsBoth o -> CType o -> (forall o'. (EmitsBoth o', Ext o o') => InfererM i' o' (CAtom o')) -> InfererM i o (Alt CoreIR o) -checkCasePat (WithSrcB pos pat) scrutineeTy cont = addSrcContext pos $ case pat of - UPatCon ~(InternalName _ _ conName) ps -> do - (dataDefName, con) <- renameM conName >>= lookupDataCon - TyConDef sourceName roleExpls paramBs (ADTCons cons) <- lookupTyCon dataDefName - DataConDef _ _ repTy idxs <- return $ cons !! con - when (length idxs /= nestLength ps) $ throw TypeErr $ - "Unexpected number of pattern binders. Expected " ++ show (length idxs) - ++ " got " ++ show (nestLength ps) - (params, repTy') <- inferParams sourceName roleExpls (Abs paramBs repTy) - constrainTypesEq scrutineeTy $ TypeCon sourceName dataDefName params - buildAltInf repTy' \arg -> do - args <- forM idxs \projs -> do - ans <- normalizeNaryProj (init projs) (Var arg) - emit $ Atom ans - bindLetPats ps args $ cont - _ -> throw TypeErr $ "Case patterns must start with a data constructor or variant pattern" +checkCasePat (WithSrcB pos pat) scrutineeTy cont = undefined +-- checkCasePat (WithSrcB pos pat) scrutineeTy cont = addSrcContext pos $ case pat of +-- UPatCon ~(InternalName _ _ conName) ps -> do +-- (dataDefName, con) <- renameM conName >>= lookupDataCon +-- TyConDef sourceName roleExpls paramBs (ADTCons cons) <- lookupTyCon dataDefName +-- DataConDef _ _ repTy idxs <- return $ cons !! con +-- when (length idxs /= nestLength ps) $ throw TypeErr $ +-- "Unexpected number of pattern binders. Expected " ++ show (length idxs) +-- ++ " got " ++ show (nestLength ps) +-- (params, repTy') <- inferParams sourceName roleExpls (Abs paramBs repTy) +-- constrainTypesEq scrutineeTy $ TypeCon sourceName dataDefName params +-- buildAltInf repTy' \arg -> do +-- args <- forM idxs \projs -> do +-- ans <- normalizeNaryProj (init projs) (Var arg) +-- emit $ Atom ans +-- bindLetPats ps args $ cont +-- _ -> throw TypeErr $ "Case patterns must start with a data constructor or variant pattern" inferParams :: (EmitsBoth o, HasNamesE e, SinkableE e, SubstE AtomSubstVal e) => SourceName -> [RoleExpl] -> Abs CBinders e o -> InfererM i o (TyConParams o, e o) @@ -2097,14 +2111,14 @@ bindLetPats _ _ _ = error "mismatched number of args" bindLetPat :: EmitsBoth o => UPat i i' -> CAtomVar o -> InfererM i' o a -> InfererM i o a bindLetPat (WithSrcB pos pat) v cont = addSrcContext pos $ case pat of UPatBinder b -> extendSubst (b @> atomVarName v) cont - UPatProd ps -> do - let n = nestLength ps - ty <- return $ getType v - _ <- fromProdType n ty - x <- zonk $ Var v - xs <- forM (iota n) \i -> do - normalizeProj (ProjectProduct i) x >>= emit . Atom - bindLetPats ps xs cont + -- UPatProd ps -> do + -- let n = nestLength ps + -- ty <- return $ getType v + -- _ <- fromProdType n ty + -- x <- zonk $ Var v + -- xs <- forM (iota n) \i -> do + -- normalizeProj (ProjectProduct i) x >>= emit . Atom + -- bindLetPats ps xs cont UPatDepPair (PairB p1 p2) -> do let x = Var v ty <- return $ getType x @@ -2115,20 +2129,20 @@ bindLetPat (WithSrcB pos pat) v cont = addSrcContext pos $ case pat of x2 <- getSnd x' >>= zonk >>= emit . Atom bindLetPat p2 x2 do cont - UPatCon ~(InternalName _ _ conName) ps -> do - (dataDefName, _) <- lookupDataCon =<< renameM conName - TyConDef sourceName roleExpls paramBs cons <- lookupTyCon dataDefName - case cons of - ADTCons [DataConDef _ _ _ idxss] -> do - when (length idxss /= nestLength ps) $ throw TypeErr $ - "Unexpected number of pattern binders. Expected " ++ show (length idxss) - ++ " got " ++ show (nestLength ps) - (params, UnitE) <- inferParams sourceName roleExpls (Abs paramBs UnitE) - constrainVarTy v $ TypeCon sourceName dataDefName params - x <- cheapNormalize =<< zonk (Var v) - xs <- forM idxss \idxs -> normalizeNaryProj idxs x >>= emit . Atom - bindLetPats ps xs cont - _ -> throw TypeErr $ "sum type constructor in can't-fail pattern" + -- UPatCon ~(InternalName _ _ conName) ps -> do + -- (dataDefName, _) <- lookupDataCon =<< renameM conName + -- TyConDef sourceName roleExpls paramBs cons <- lookupTyCon dataDefName + -- case cons of + -- ADTCons [DataConDef _ _ _ idxss] -> do + -- when (length idxss /= nestLength ps) $ throw TypeErr $ + -- "Unexpected number of pattern binders. Expected " ++ show (length idxss) + -- ++ " got " ++ show (nestLength ps) + -- (params, UnitE) <- inferParams sourceName roleExpls (Abs paramBs UnitE) + -- constrainVarTy v $ TypeCon sourceName dataDefName params + -- x <- cheapNormalize =<< zonk (Var v) + -- xs <- forM idxss \idxs -> normalizeNaryProj idxs x >>= emit . Atom + -- bindLetPats ps xs cont + -- _ -> throw TypeErr $ "sum type constructor in can't-fail pattern" UPatTable ps -> do elemTy <- freshType let n = fromIntegral (nestLength ps) :: Word32 @@ -2160,36 +2174,37 @@ checkUParam k uty@(WithSrcE pos _) = addSrcContext pos $ inferTabCon :: forall i o. EmitsBoth o => NameHint -> [UExpr i] -> RequiredTy CType o -> InfererM i o (CAtom o) -inferTabCon hint xs reqTy = do - let n = fromIntegral (length xs) :: Word32 - let finTy = FinConst n - ctx <- srcPosCtx <$> getErrCtx - let dataDictHole dTy = Just $ WhenIRE $ DictHole (AlwaysEqual ctx) dTy Full - case reqTy of - Infer -> do - elemTy <- case xs of - [] -> freshType - (x:_) -> getType <$> inferRho noHint x - ixTy <- asIxType finTy - let tabTy = ixTy ==> elemTy - xs' <- forM xs \x -> checkRho noHint x elemTy - dTy <- DictTy <$> dataDictType elemTy - liftM Var $ emitHinted hint $ TabCon (dataDictHole dTy) tabTy xs' - Check ty -> do - tabTy@(TabPiType _ b _) <- fromTabPiType True ty - constrainTypesEq (binderType b) finTy - xs' <- forM (enumerate xs) \(i, x) -> do - let i' = NewtypeCon (FinCon (NatVal n)) (NatVal $ fromIntegral i) :: CAtom o - elemTy' <- instantiate tabTy [i'] - checkRho noHint x elemTy' - dTy <- case tryAsConst tabTy of - Nothing -> ignoreExcept <$> liftEnvReaderT do - withFreshBinder noHint finTy \b' -> do - elemTy' <- instantiate tabTy [Var $ binderVar b'] - dTy <- DictTy <$> dataDictType elemTy' - return $ Pi $ CorePiType ImplicitApp [Inferred Nothing Unify] (UnaryNest (BD b')) (EffTy Pure dTy) - Just elemTy' -> DictTy <$> dataDictType elemTy' - liftM Var $ emitHinted hint $ TabCon (dataDictHole dTy) ty xs' +inferTabCon hint xs reqTy = undefined +-- inferTabCon hint xs reqTy = do +-- let n = fromIntegral (length xs) :: Word32 +-- let finTy = FinConst n +-- ctx <- srcPosCtx <$> getErrCtx +-- let dataDictHole dTy = Just $ WhenIRE $ DictHole (AlwaysEqual ctx) dTy Full +-- case reqTy of +-- Infer -> do +-- elemTy <- case xs of +-- [] -> freshType +-- (x:_) -> getType <$> inferRho noHint x +-- ixTy <- asIxType finTy +-- let tabTy = ixTy ==> elemTy +-- xs' <- forM xs \x -> checkRho noHint x elemTy +-- dTy <- DictTy <$> dataDictType elemTy +-- liftM Var $ emitHinted hint $ TabCon (dataDictHole dTy) tabTy xs' +-- Check ty -> do +-- tabTy@(TabPiType _ b _) <- fromTabPiType True ty +-- constrainTypesEq (binderType b) finTy +-- xs' <- forM (enumerate xs) \(i, x) -> do +-- let i' = NewtypeCon (FinCon (NatVal n)) (NatVal $ fromIntegral i) :: CAtom o +-- elemTy' <- instantiate tabTy [i'] +-- checkRho noHint x elemTy' +-- dTy <- case tryAsConst tabTy of +-- Nothing -> ignoreExcept <$> liftEnvReaderT do +-- withFreshBinder noHint finTy \b' -> do +-- elemTy' <- instantiate tabTy [Var $ binderVar b'] +-- dTy <- DictTy <$> dataDictType elemTy' +-- return $ Pi $ CorePiType ImplicitApp [Inferred Nothing Unify] (UnaryNest (BD b')) (EffTy Pure dTy) +-- Just elemTy' -> DictTy <$> dataDictType elemTy' +-- liftM Var $ emitHinted hint $ TabCon (dataDictHole dTy) ty xs' -- Bool flag is just to tweak the reported error message fromTabPiType :: EmitsBoth o => Bool -> CType o -> InfererM i o (TabPiType CoreIR o) @@ -2236,11 +2251,12 @@ openEffectRow (EffectRow effs NoTail) = extendEffRow effs <$> freshEff openEffectRow effRow = return effRow asIxType :: CType o -> InfererM i o (IxType CoreIR o) -asIxType ty = do - dictTy <- DictTy <$> ixDictType ty - ctx <- srcPosCtx <$> getErrCtx - return $ IxType ty $ IxDictAtom $ DictHole (AlwaysEqual ctx) dictTy Full -{-# SCC asIxType #-} +asIxType ty = undefined +-- asIxType ty = do +-- dictTy <- DictTy <$> ixDictType ty +-- ctx <- srcPosCtx <$> getErrCtx +-- return $ IxType ty $ IxDictAtom $ DictHole (AlwaysEqual ctx) dictTy Full +-- {-# SCC asIxType #-} -- === Solver === @@ -2552,25 +2568,27 @@ instance Unifiable CorePiType where => Abs CBinders (EffTy CoreIR) n -> Abs CBinders (EffTy CoreIR) n -> SolverM n () - go (Abs Empty (EffTy e1 t1)) (Abs Empty (EffTy e2 t2)) = unify t1 t2 >> unify e1 e2 - go (Abs (Nest b1 bs1) rest1) - (Abs (Nest b2 bs2) rest2) = do - unify (binderType b1) (binderType b2) - v <- freshSkolemName (binderType b1) - ab1 <- zonk =<< instantiate (Abs b1 (Abs bs1 rest1)) [Var v] - ab2 <- zonk =<< instantiate (Abs b2 (Abs bs2 rest2)) [Var v] - go ab1 ab2 - go _ _ = empty + go = undefined + -- go (Abs Empty (EffTy e1 t1)) (Abs Empty (EffTy e2 t2)) = unify t1 t2 >> unify e1 e2 + -- go (Abs (Nest b1 bs1) rest1) + -- (Abs (Nest b2 bs2) rest2) = do + -- unify (binderType b1) (binderType b2) + -- v <- freshSkolemName (binderType b1) + -- ab1 <- zonk =<< instantiate (Abs b1 (Abs bs1 rest1)) [Var v] + -- ab2 <- zonk =<< instantiate (Abs b2 (Abs bs2 rest2)) [Var v] + -- go ab1 ab2 + -- go _ _ = empty unifyTabPiType :: EmitsInf n => TabPiType CoreIR n -> TabPiType CoreIR n -> SolverM n () -unifyTabPiType (TabPiType _ b1 ty1) (TabPiType _ b2 ty2) = do - let ann1 = binderType b1 - let ann2 = binderType b2 - unify ann1 ann2 - v <- freshSkolemName ann1 - ty1' <- instantiate (Abs b1 ty1) [Var v] - ty2' <- instantiate (Abs b2 ty2) [Var v] - unify ty1' ty2' +unifyTabPiType (TabPiType _ b1 ty1) (TabPiType _ b2 ty2) = undefined +-- unifyTabPiType (TabPiType _ b1 ty1) (TabPiType _ b2 ty2) = do +-- let ann1 = binderType b1 +-- let ann2 = binderType b2 +-- unify ann1 ann2 +-- v <- freshSkolemName ann1 +-- ty1' <- instantiate (Abs b1 ty1) [Var v] +-- ty2' <- instantiate (Abs b2 ty2) [Var v] +-- unify ty1' ty2' extendSolution :: CAtomName n -> CAtom n -> SolverM n () extendSolution v t = @@ -2661,38 +2679,40 @@ generalizeDictAndUnify ty dict = do zonk dict' generalizeDictRec :: EmitsInf n => Dict n -> SolverM n (Dict n) -generalizeDictRec dict = do - -- TODO: we should be able to avoid the normalization here . We only need it - -- because we sometimes end up with superclass projections. But they shouldn't - -- really be allowed to occur in the post-simplification IR. - DictCon _ dict' <- cheapNormalize dict - mkDictAtom =<< case dict' of - InstanceDict instanceName args -> do - InstanceDef _ roleExpls bs _ _ <- lookupInstanceDef instanceName - args' <- generalizeInstanceArgs roleExpls bs args - return $ InstanceDict instanceName args' - IxFin _ -> IxFin <$> Var <$> freshInferenceName MiscInfVar NatTy - InstantiatedGiven _ _ -> notSimplifiedDict - SuperclassProj _ _ -> notSimplifiedDict - DataData ty -> DataData <$> TyVar <$> freshInferenceName MiscInfVar ty - where notSimplifiedDict = error $ "Not a simplified dict: " ++ pprint dict +generalizeDictRec dict = undefined +-- generalizeDictRec dict = do +-- -- TODO: we should be able to avoid the normalization here . We only need it +-- -- because we sometimes end up with superclass projections. But they shouldn't +-- -- really be allowed to occur in the post-simplification IR. +-- DictCon _ dict' <- cheapNormalize dict +-- mkDictAtom =<< case dict' of +-- InstanceDict instanceName args -> do +-- InstanceDef _ roleExpls bs _ _ <- lookupInstanceDef instanceName +-- args' <- generalizeInstanceArgs roleExpls bs args +-- return $ InstanceDict instanceName args' +-- IxFin _ -> IxFin <$> Var <$> freshInferenceName MiscInfVar NatTy +-- InstantiatedGiven _ _ -> notSimplifiedDict +-- SuperclassProj _ _ -> notSimplifiedDict +-- DataData ty -> DataData <$> TyVar <$> freshInferenceName MiscInfVar ty +-- where notSimplifiedDict = error $ "Not a simplified dict: " ++ pprint dict generalizeInstanceArgs :: EmitsInf n => [RoleExpl] -> CBinders n l -> [CAtom n] -> SolverM n [CAtom n] -generalizeInstanceArgs [] Empty [] = return [] -generalizeInstanceArgs ((role,_):expls) (Nest b bs) (arg:args) = do - arg' <- case role of - -- XXX: for `TypeParam` we can just emit a fresh inference name rather than - -- traversing the whole type like we do in `Generalize.hs`. The reason is - -- that it's valid to implement `generalizeDict` by synthesizing an entirely - -- fresh dictionary, and if we were to do that, we would infer this type - -- parameter exactly as we do here, using inference. - TypeParam -> Var <$> freshInferenceName MiscInfVar TyKind - DictParam -> generalizeDictAndUnify (binderType b) arg - DataParam -> Var <$> freshInferenceName MiscInfVar (binderType b) - Abs bs' UnitE <- instantiate (Abs b (Abs bs UnitE)) [arg'] - args' <- generalizeInstanceArgs expls bs' args - return $ arg':args' -generalizeInstanceArgs _ _ _ = error "zip error" +generalizeInstanceArgs [] Empty [] = undefined +-- generalizeInstanceArgs [] Empty [] = return [] +-- generalizeInstanceArgs ((role,_):expls) (Nest b bs) (arg:args) = do +-- arg' <- case role of +-- -- XXX: for `TypeParam` we can just emit a fresh inference name rather than +-- -- traversing the whole type like we do in `Generalize.hs`. The reason is +-- -- that it's valid to implement `generalizeDict` by synthesizing an entirely +-- -- fresh dictionary, and if we were to do that, we would infer this type +-- -- parameter exactly as we do here, using inference. +-- TypeParam -> Var <$> freshInferenceName MiscInfVar TyKind +-- DictParam -> generalizeDictAndUnify (binderType b) arg +-- DataParam -> Var <$> freshInferenceName MiscInfVar (binderType b) +-- Abs bs' UnitE <- instantiate (Abs b (Abs bs UnitE)) [arg'] +-- args' <- generalizeInstanceArgs expls bs' args +-- return $ arg':args' +-- generalizeInstanceArgs _ _ _ = error "zip error" synthInstanceDefAndAddSynthCandidate :: (Mut n, TopBuilder m, EnvReader m, Fallible1 m) => InstanceDef n -> m n (InstanceName n) @@ -2760,17 +2780,18 @@ synthInstanceDef (InstanceDef className expls bs params body) = do -- main entrypoint to dictionary synthesizer trySynthTerm :: (Fallible1 m, EnvReader m) => CType n -> RequiredMethodAccess -> m n (SynthAtom n) -trySynthTerm ty reqMethodAccess = do - hasInferenceVars ty >>= \case - True -> throw TypeErr "Can't synthesize a dictionary for a type with inference vars" - False -> do - synthTy <- liftExcept $ typeAsSynthType ty - solutions <- liftSyntherM $ synthTerm synthTy reqMethodAccess - case solutions of - [] -> throw TypeErr $ "Couldn't synthesize a class dictionary for: " ++ pprint ty - [d] -> cheapNormalize d -- normalize to reduce code size - _ -> throw TypeErr $ "Multiple candidate class dictionaries for: " ++ pprint ty -{-# SCC trySynthTerm #-} +trySynthTerm ty reqMethodAccess = undefined +-- trySynthTerm ty reqMethodAccess = do +-- hasInferenceVars ty >>= \case +-- True -> throw TypeErr "Can't synthesize a dictionary for a type with inference vars" +-- False -> do +-- synthTy <- liftExcept $ typeAsSynthType ty +-- solutions <- liftSyntherM $ synthTerm synthTy reqMethodAccess +-- case solutions of +-- [] -> throw TypeErr $ "Couldn't synthesize a class dictionary for: " ++ pprint ty +-- [d] -> cheapNormalize d -- normalize to reduce code size +-- _ -> throw TypeErr $ "Multiple candidate class dictionaries for: " ++ pprint ty +-- {-# SCC trySynthTerm #-} type SynthAtom = CAtom type SynthPiType n = ([Explicitness], Abs CBinders DictType n) @@ -2866,7 +2887,7 @@ getSuperclassClosurePure env givens newGivens = synthTy <- return $ getSynthType synthExpr superclasses <- case synthTy of SynthPiType _ -> return [] - SynthDictType dTy -> getSuperclassTys dTy + -- SynthDictType dTy -> getSuperclassTys dTy forM (enumerate superclasses) \(i, ty) -> do return $ DictCon ty $ SuperclassProj synthExpr i @@ -2905,27 +2926,28 @@ withGivenBinders :: (SinkableE e, RenameE e) => [Explicitness] -> Abs CBinders e n -> (forall l. DExt n l => CBinders n l -> e l -> SyntherM l a) -> SyntherM n a -withGivenBinders explsTop (Abs bsTop e) contTop = - runSubstReaderT idSubst $ go explsTop bsTop \bsTop' -> do - e' <- renameM e - liftSubstReaderT $ contTop bsTop' e' - where - go :: [Explicitness] -> CBinders i i' - -> (forall o'. DExt o o' => CBinders o o' -> SubstReaderT Name SyntherM i' o' a) - -> SubstReaderT Name SyntherM i o a - go expls bs cont = case (expls, bs) of - ([], Empty) -> getDistinct >>= \Distinct -> cont Empty - (expl:explsRest, Nest (BD b) rest) -> do - argTy <- renameM $ binderType b - withFreshBinder (getNameHint b) argTy \b' -> do - givens <- case expl of - Inferred _ (Synth _) -> return [Var $ binderVar b'] - _ -> return [] - s <- getSubst - liftSubstReaderT $ extendGivens givens $ - runSubstReaderT (s <>> b@>binderName b') $ - go explsRest rest \rest' -> cont (Nest (BD b') rest') - _ -> error "zip error" +withGivenBinders explsTop (Abs bsTop e) contTop = undefined +-- withGivenBinders explsTop (Abs bsTop e) contTop = +-- runSubstReaderT idSubst $ go explsTop bsTop \bsTop' -> do +-- e' <- renameM e +-- liftSubstReaderT $ contTop bsTop' e' +-- where +-- go :: [Explicitness] -> CBinders i i' +-- -> (forall o'. DExt o o' => CBinders o o' -> SubstReaderT Name SyntherM i' o' a) +-- -> SubstReaderT Name SyntherM i o a +-- go expls bs cont = case (expls, bs) of +-- ([], Empty) -> getDistinct >>= \Distinct -> cont Empty +-- (expl:explsRest, Nest (BD b) rest) -> do +-- argTy <- renameM $ binderType b +-- withFreshBinder (getNameHint b) argTy \b' -> do +-- givens <- case expl of +-- Inferred _ (Synth _) -> return [Var $ binderVar b'] +-- _ -> return [] +-- s <- getSubst +-- liftSubstReaderT $ extendGivens givens $ +-- runSubstReaderT (s <>> b@>binderName b') $ +-- go explsRest rest \rest' -> cont (Nest (BD b') rest') +-- _ -> error "zip error" isMethodAccessAllowedBy :: EnvReader m => RequiredMethodAccess -> InstanceName n -> m n Bool isMethodAccessAllowedBy access instanceName = do @@ -2958,30 +2980,32 @@ synthDictFromInstance targetTy@(DictType _ targetClass _) = do return $ DictCon (DictTy targetTy) $ InstanceDict candidate args instantiateSynthArgs :: DictType n -> SynthPiType n -> SyntherM n [CAtom n] -instantiateSynthArgs targetTop (explsTop, Abs bsTop resultTyTop) = do - ListE args <- (liftExceptAlt =<<) $ liftSolverM $ solveLocal do - args <- runSubstReaderT idSubst $ go (sink targetTop) explsTop (sink $ Abs bsTop resultTyTop) - zonk $ ListE args - forM args \case - DictHole _ argTy req -> liftExceptAlt (typeAsSynthType argTy) >>= flip synthTerm req - arg -> return arg - where - go :: EmitsInf o - => DictType o -> [Explicitness] -> Abs CBinders DictType i - -> SubstReaderT AtomSubstVal SolverM i o [CAtom o] - go target allExpls (Abs bs proposed) = case (allExpls, bs) of - ([], Empty) -> do - proposed' <- substM proposed - liftSubstReaderT $ unify target proposed' - return [] - (expl:expls, Nest b rest) -> do - argTy <- substM $ binderType b - arg <- liftSubstReaderT case expl of - Explicit -> error "instances shouldn't have explicit args" - Inferred _ Unify -> Var <$> freshInferenceName MiscInfVar argTy - Inferred _ (Synth req) -> return $ DictHole (AlwaysEqual emptySrcPosCtx) argTy req - liftM (arg:) $ extendSubstBD b [SubstVal arg] $ go target expls (Abs rest proposed) - _ -> error "zip error" +instantiateSynthArgs targetTop (explsTop, Abs bsTop resultTyTop) = undefined +-- instantiateSynthArgs targetTop (explsTop, Abs bsTop resultTyTop) = do +-- ListE args <- (liftExceptAlt =<<) $ liftSolverM $ solveLocal do +-- args <- runSubstReaderT idSubst $ go (sink targetTop) explsTop (sink $ Abs bsTop resultTyTop) +-- zonk $ ListE args +-- forM args \case +-- DictHole _ argTy req -> liftExceptAlt (typeAsSynthType argTy) >>= flip synthTerm req +-- arg -> return arg +-- where +-- go :: EmitsInf o +-- => DictType o -> [Explicitness] -> Abs CBinders DictType i +-- -> SubstReaderT AtomSubstVal SolverM i o [CAtom o] +-- go target allExpls (Abs bs proposed) = undefined + -- go target allExpls (Abs bs proposed) = case (allExpls, bs) of + -- ([], Empty) -> do + -- proposed' <- substM proposed + -- liftSubstReaderT $ unify target proposed' + -- return [] + -- (expl:expls, Nest b rest) -> do + -- argTy <- substM $ binderType b + -- arg <- liftSubstReaderT case expl of + -- Explicit -> error "instances shouldn't have explicit args" + -- Inferred _ Unify -> Var <$> freshInferenceName MiscInfVar argTy + -- Inferred _ (Synth req) -> return $ DictHole (AlwaysEqual emptySrcPosCtx) argTy req + -- liftM (arg:) $ extendSubstBD b [SubstVal arg] $ go target expls (Abs rest proposed) + -- _ -> error "zip error" synthDictForData :: forall n. DictType n -> SyntherM n (SynthAtom n) synthDictForData dictTy@(DictType "Data" dName [Type ty]) = case ty of @@ -2991,9 +3015,10 @@ synthDictForData dictTy@(DictType "Data" dName [Type ty]) = case ty of TabPi (TabPiType _ b eltTy) -> recurBinder (Abs b eltTy) >> success DepPairTy (DepPairType _ b r) -> do recur (binderType b) >> recurBinder (Abs b r) >> success - NewtypeTyCon nt -> do - (_, ty') <- unwrapNewtypeType nt - recur ty' >> success + NewtypeTyCon nt -> undefined + -- NewtypeTyCon nt -> do + -- (_, ty') <- unwrapNewtypeType nt + -- recur ty' >> success TC con -> case con of BaseType _ -> success ProdType as -> mapM_ recur as >> success @@ -3058,23 +3083,23 @@ instance DictSynthTraversable (TopLam CoreIR) where dsTraverse (TopLam d ty lam) = TopLam d <$> visitPiDefault ty <*> visitLamNoEmits lam instance DictSynthTraversable CAtom where - dsTraverse atom = case atom of - DictHole (AlwaysEqual ctx) ty access -> do - ty' <- cheapNormalize =<< dsTraverse ty - ans <- liftEnvReaderT $ addSrcContext ctx $ trySynthTerm ty' access - case ans of - Failure errs -> put (LiftE errs) >> renameM atom - Success d -> return d - Lam (CoreLamExpr piTy@(CorePiType _ expls _ _) (LamExpr bsLam (Abs decls result))) -> do - Pi piTy' <- dsTraverse $ Pi piTy - lam' <- dsTraverseExplBinders expls bsLam \bsLam' -> do - visitDeclsNoEmits decls \decls' -> do - LamExpr bsLam' <$> Abs decls' <$> dsTraverse result - return $ Lam $ CoreLamExpr piTy' lam' - Var _ -> renameM atom - SimpInCore _ -> renameM atom - ProjectElt _ _ _ -> renameM atom - _ -> visitAtomPartial atom + dsTraverse atom = undefined + -- dsTraverse atom = case atom of + -- DictHole (AlwaysEqual ctx) ty access -> do + -- ty' <- cheapNormalize =<< dsTraverse ty + -- ans <- liftEnvReaderT $ addSrcContext ctx $ trySynthTerm ty' access + -- case ans of + -- Failure errs -> put (LiftE errs) >> renameM atom + -- Success d -> return d + -- Lam (CoreLamExpr piTy@(CorePiType _ expls _ _) (LamExpr bsLam (Abs decls result))) -> do + -- Pi piTy' <- dsTraverse $ Pi piTy + -- lam' <- dsTraverseExplBinders expls bsLam \bsLam' -> do + -- visitDeclsNoEmits decls \decls' -> do + -- LamExpr bsLam' <$> Abs decls' <$> dsTraverse result + -- return $ Lam $ CoreLamExpr piTy' lam' + -- Var _ -> renameM atom + -- SimpInCore _ -> renameM atom + -- _ -> visitAtomPartial atom instance DictSynthTraversable CType where dsTraverse ty = case ty of @@ -3082,7 +3107,6 @@ instance DictSynthTraversable CType where dsTraverseExplBinders expls bs \bs' -> do CorePiType appExpl expls bs' <$> (EffTy <$> renameM effs <*> dsTraverse resultTy) TyVar _ -> renameM ty - ProjectEltTy _ _ _ -> renameM ty _ -> visitTypePartial ty instance DictSynthTraversable DataConDefs where dsTraverse = visitGeneric @@ -3092,13 +3116,13 @@ dsTraverseExplBinders -> (forall o'. DExt o o' => CBinders o o' -> DictSynthTraverserM i' o' a) -> DictSynthTraverserM i o a dsTraverseExplBinders [] Empty cont = getDistinct >>= \Distinct -> cont Empty -dsTraverseExplBinders (expl:expls) (Nest (BD b) bs) cont = do - ty <- dsTraverse $ binderType b - withFreshBinder (getNameHint b) ty \b' -> do - let v = binderName b' - extendSynthCandidatesDict expl v $ extendRenamer (b@>v) do - dsTraverseExplBinders expls bs \bs' -> cont $ Nest (BD b') bs' -dsTraverseExplBinders _ _ _ = error "zip error" +-- dsTraverseExplBinders (expl:expls) (Nest (BD b) bs) cont = do +-- ty <- dsTraverse $ binderType b +-- withFreshBinder (getNameHint b) ty \b' -> do +-- let v = binderName b' +-- extendSynthCandidatesDict expl v $ extendRenamer (b@>v) do +-- dsTraverseExplBinders expls bs \bs' -> cont $ Nest (BD b') bs' +-- dsTraverseExplBinders _ _ _ = error "zip error" extendSynthCandidatesDict :: Explicitness -> CAtomName n -> DictSynthTraverserM i n a -> DictSynthTraverserM i n a extendSynthCandidatesDict c v cont = DictSynthTraverserM do @@ -3124,28 +3148,30 @@ buildBlockInf :: EmitsInf n => (forall l. (EmitsBoth l, DExt n l) => InfererM i l (CAtom l)) -> InfererM i n (CBlock n) -buildBlockInf cont = do - Abs decls (PairE result ty) <- buildDeclsInf do - ans <- cont - ty <- cheapNormalize $ getType ans - return $ PairE ans ty - let msg = "Block:" <> nest 1 (prettyBlock decls result) <> line - <> group ("Of type:" <> nest 2 (line <> pretty ty)) <> line - void $ liftHoistExcept' (docAsStr msg) $ hoist decls ty - return $ Abs decls result -{-# INLINE buildBlockInf #-} +buildBlockInf cont = undefined +-- buildBlockInf cont = do +-- Abs decls (PairE result ty) <- buildDeclsInf do +-- ans <- cont +-- ty <- cheapNormalize $ getType ans +-- return $ PairE ans ty +-- let msg = "Block:" <> nest 1 (prettyBlock decls result) <> line +-- <> group ("Of type:" <> nest 2 (line <> pretty ty)) <> line +-- void $ liftHoistExcept' (docAsStr msg) $ hoist decls ty +-- return $ Abs decls result +-- {-# INLINE buildBlockInf #-} buildBlockInfWithRecon :: (EmitsInf n, RenameE e, HoistableE e, SinkableE e) => (forall l. (EmitsBoth l, DExt n l) => InfererM i l (e l)) -> InfererM i n (PairE CBlock (ReconAbs CoreIR e) n) -buildBlockInfWithRecon cont = do - ab <- buildDeclsInfUnzonked cont - (block, recon) <- refreshAbs ab \decls result -> do - (newResult, recon) <- telescopicCapture decls result - return (Abs decls newResult, recon) - return $ PairE block recon -{-# INLINE buildBlockInfWithRecon #-} +buildBlockInfWithRecon cont = undefined +-- buildBlockInfWithRecon cont = do +-- ab <- buildDeclsInfUnzonked cont +-- (block, recon) <- refreshAbs ab \decls result -> do +-- (newResult, recon) <- telescopicCapture decls result +-- return (Abs decls newResult, recon) +-- return $ PairE block recon +-- {-# INLINE buildBlockInfWithRecon #-} buildTabPiInf :: EmitsInf n diff --git a/src/lib/Inline.hs b/src/lib/Inline.hs index 75039399b..c3db61eae 100644 --- a/src/lib/Inline.hs +++ b/src/lib/Inline.hs @@ -11,7 +11,6 @@ import Data.List.NonEmpty qualified as NE import Builder import Core import Err -import CheapReduction import IRVariants import Name import Subst @@ -19,6 +18,7 @@ import Occurrence hiding (Var) import Optimize import Types.Core import Types.Primitives +import Visitor -- === External API === @@ -280,10 +280,6 @@ inlineExpr ctx = \case inlineAtom :: Emits o => Context SExpr e o -> SAtom i -> InlineM i o (e o) inlineAtom ctx = \case Var name -> inlineName ctx name - ProjectElt _ i x -> do - let (idxs, v) = asNaryProj i x - ans <- normalizeNaryProj (NE.toList idxs) =<< inline Stop (Var v) - reconstruct ctx $ Atom ans atom -> (Atom <$> visitAtomPartial atom) >>= reconstruct ctx inlineName :: Emits o => Context SExpr e o -> SAtomVar i -> InlineM i o (e o) @@ -323,11 +319,11 @@ withBinders -> (forall o'. DExt o o' => SBinders o o' -> InlineM i' o' a) -> InlineM i o a withBinders Empty cont = getDistinct >>= \Distinct -> cont Empty -withBinders (Nest (BD (b:>ty)) bs) cont = do - ty' <- buildScopedAssumeNoDecls $ inline Stop ty - withFreshBinder (getNameHint b) ty' \b' -> - extendRenamer (b@>binderName b') do - withBinders bs \bs' -> cont $ Nest (BD b') bs' +-- withBinders (Nest (BD (b:>ty)) bs) cont = do +-- ty' <- buildScopedAssumeNoDecls $ inline Stop ty +-- withFreshBinder (getNameHint b) ty' \b' -> +-- extendRenamer (b@>binderName b') do +-- withBinders bs \bs' -> cont $ Nest (BD b') bs' instance Inlinable (PiType SimpIR) where inline ctx (PiType bs effTy) = diff --git a/src/lib/Linearize.hs b/src/lib/Linearize.hs index 2ec562777..c17390536 100644 --- a/src/lib/Linearize.hs +++ b/src/lib/Linearize.hs @@ -17,7 +17,6 @@ import GHC.Stack import Builder import Core -import CheapReduction import Imp import IRVariants import MTL1 @@ -73,7 +72,7 @@ extendActiveSubst extendActiveSubst b v cont = extendSubst (b@>atomVarName v) $ extendActivePrimals v cont extendActiveSubstBD :: BinderAndDecls SimpIR i i' -> SAtomVar o -> PrimalM i' o a -> PrimalM i o a -extendActiveSubstBD (BD b) v cont = extendActiveSubst b v cont +extendActiveSubstBD (BD b decls) v cont = undefined -- extendActiveSubst b v cont extendActiveEffs :: Effect SimpIR o -> PrimalM i o a -> PrimalM i o a extendActiveEffs eff = local \primals -> @@ -235,11 +234,12 @@ data ObligateReconAbs (e::E) (n::S) = ObligateRecon (SType n) (ReconAbs SimpIR e n) instance ReconFunctor MaybeReconAbs where - capture locals original toCapture = do - (reconVal, recon) <- telescopicCapture locals toCapture - case recon of - Abs (ReconBinders _ Empty) toCapture' -> return (original, TrivialRecon toCapture') - _ -> return (PairVal original reconVal, ReconWithData recon) + capture locals original toCapture = undefined + -- capture locals original toCapture = do + -- (reconVal, recon) <- telescopicCapture locals toCapture + -- case recon of + -- Abs (ReconBinders _ Empty) toCapture' -> return (original, TrivialRecon toCapture') + -- _ -> return (PairVal original reconVal, ReconWithData recon) reconstruct primalAux recon = case recon of TrivialRecon linLam -> return (primalAux, linLam) @@ -249,11 +249,12 @@ instance ReconFunctor MaybeReconAbs where return (primal, linLam') instance ReconFunctor ObligateReconAbs where - capture locals original toCapture = do - (reconVal, recon) <- telescopicCapture locals toCapture - -- TODO: telescopicCapture should probably return the hoisted type - reconValTy <- return $ ignoreHoistFailure $ hoist locals $ getType reconVal - return (PairVal original reconVal, ObligateRecon reconValTy recon) + capture locals original toCapture = undefined + -- capture locals original toCapture = do + -- (reconVal, recon) <- telescopicCapture locals toCapture + -- -- TODO: telescopicCapture should probably return the hoisted type + -- reconValTy <- return $ ignoreHoistFailure $ hoist locals $ getType reconVal + -- return (PairVal original reconVal, ObligateRecon reconValTy recon) reconstruct primalAux recon = case recon of ObligateRecon _ reconAbs -> do @@ -292,29 +293,30 @@ linearize f x = runPrimalMInit $ linearizeLambdaApp f x {-# SCC linearize #-} linearizeTopLam :: STopLam n -> [Active] -> DoubleBuilder SimpIR n (STopLam n, STopLam n) -linearizeTopLam (TopLam False _ (LamExpr bs body)) actives = do - (primalFun, tangentFun) <- runPrimalMInit $ refreshBinders bs \bs' frag -> extendSubst frag do - let allPrimals = bindersVars bs' - activeVs <- catMaybes <$> forM (zip actives allPrimals) \(active, v) -> case active of - True -> return $ Just v - False -> return $ Nothing - (body', linLamAbs) <- extendActivePrimalss activeVs do - linearizeBlockDefuncGeneral emptyOutFrag body - let primalFun = LamExpr bs' body' - ObligateRecon ty (Abs bsRecon (LamExpr bsTangent tangentBody)) <- return linLamAbs - tangentFun <- withFreshBinder "residuals" ty \bResidual -> do - xs <- unpackTelescope bsRecon $ Var $ binderVar bResidual - Abs bsTangent' UnitE <- applySubst (bsRecon @@> map SubstVal xs) (Abs bsTangent UnitE) - tangentTy <- ProdTy <$> typesFromNonDepBinderNest bsTangent' - withFreshBinder "t" tangentTy \bTangent -> do - tangentBody' <- buildBlock do - ts <- getUnpacked $ Var $ sink $ binderVar bTangent - let substFrag = bsRecon @@> map (SubstVal . sink) xs - <.> (fmapNest (\(BD b) -> b) bsTangent) @@> map (SubstVal . sink) ts - emitBlock =<< applySubst substFrag tangentBody - return $ LamExpr (bs' >>> BinaryNest (PlainBD bResidual) (PlainBD bTangent)) tangentBody' - return (primalFun, tangentFun) - (,) <$> asTopLam primalFun <*> asTopLam tangentFun +linearizeTopLam (TopLam False _ (LamExpr bs body)) actives = undefined +-- linearizeTopLam (TopLam False _ (LamExpr bs body)) actives = do +-- (primalFun, tangentFun) <- runPrimalMInit $ refreshBinders bs \bs' frag -> extendSubst frag do +-- let allPrimals = bindersVars bs' +-- activeVs <- catMaybes <$> forM (zip actives allPrimals) \(active, v) -> case active of +-- True -> return $ Just v +-- False -> return $ Nothing +-- (body', linLamAbs) <- extendActivePrimalss activeVs do +-- linearizeBlockDefuncGeneral emptyOutFrag body +-- let primalFun = LamExpr bs' body' +-- ObligateRecon ty (Abs bsRecon (LamExpr bsTangent tangentBody)) <- return linLamAbs +-- tangentFun <- withFreshBinder "residuals" ty \bResidual -> do +-- xs <- unpackTelescope bsRecon $ Var $ binderVar bResidual +-- Abs bsTangent' UnitE <- applySubst (bsRecon @@> map SubstVal xs) (Abs bsTangent UnitE) +-- tangentTy <- ProdTy <$> typesFromNonDepBinderNest bsTangent' +-- withFreshBinder "t" tangentTy \bTangent -> do +-- tangentBody' <- buildBlock do +-- ts <- getUnpacked $ Var $ sink $ binderVar bTangent +-- let substFrag = bsRecon @@> map (SubstVal . sink) xs +-- <.> (fmapNest (\(BD b) -> b) bsTangent) @@> map (SubstVal . sink) ts +-- emitBlock =<< applySubst substFrag tangentBody +-- return $ LamExpr (bs' >>> BinaryNest (PlainBD bResidual) (PlainBD bTangent)) tangentBody' +-- return (primalFun, tangentFun) +-- (,) <$> asTopLam primalFun <*> asTopLam tangentFun linearizeTopLam (TopLam True _ _) _ = error "expected a non-destination-passing function" -- reify the tangent builder as a lambda @@ -337,12 +339,6 @@ linearizeAtom atom = case atom of Con con -> linearizePrimCon con DepPair _ _ _ -> notImplemented PtrVar _ _ -> emitZeroT - ProjectElt _ i x -> do - WithTangent x' tx <- linearizeAtom x - xi <- normalizeProj i x' - return $ WithTangent xi do - t <- tx - normalizeProj i t RepValAtom _ -> emitZeroT where emitZeroT = withZeroT $ renameM atom @@ -431,6 +427,12 @@ linearizeExpr expr = case expr of ty' <- renameM ty seqLin (map linearizeAtom xs) `bindLin` \(ComposeE xs') -> emitExpr $ TabCon Nothing (sink ty') xs' + ProjectElt _ i x -> undefined + -- WithTangent x' tx <- linearizeAtom x + -- xi <- normalizeProj i x' + -- return $ WithTangent xi do + -- t <- tx + -- normalizeProj i t linearizeOp :: Emits o => PrimOp SimpIR i -> LinM i o SAtom SAtom linearizeOp op = case op of @@ -674,7 +676,7 @@ linearizeEffectFun rws (BinaryLamExpr hB refB body) = do -- ensures that such references can never be *used* once the effect runner -- returns, but technically it's legal to return them. let linLam' = ignoreHoistFailure $ hoist (PairB h b) linLam - return (BinaryLamExpr (BD h) (BD b) body', linLam') + return (BinaryLamExpr (PlainBD h) (PlainBD b) body', linLam') linearizeEffectFun _ _ = error "expect effect function to be a binary lambda" withT :: PrimalM i o (e1 o) diff --git a/src/lib/Lower.hs b/src/lib/Lower.hs index 3362b539a..6dec28625 100644 --- a/src/lib/Lower.hs +++ b/src/lib/Lower.hs @@ -21,7 +21,6 @@ import Unsafe.Coerce import Builder import Core import Imp -import CheapReduction import IRVariants import Name import Subst @@ -29,6 +28,7 @@ import QueryType import Types.Core import Types.Primitives import Util (enumerate) +import Visitor -- === For loop resolution === @@ -64,12 +64,12 @@ type DestBlock = Abs (SBinder) SBlock lowerFullySequential :: EnvReader m => Bool -> STopLam n -> m n (STopLam n) lowerFullySequential wantDestStyle (TopLam False piTy (LamExpr bs body)) = liftEnvReaderM $ do lam <- case wantDestStyle of - True -> do - refreshAbs (Abs bs body) \bs' body' -> do - let xs = Var <$> bindersVars bs' - EffTy _ resultTy <- instantiate piTy xs - Abs b body'' <- lowerFullySequentialBlock resultTy body' - return $ LamExpr (bs' >>> UnaryNest (PlainBD b)) body'' + -- True -> do + -- refreshAbs (Abs bs body) \bs' body' -> do + -- let xs = Var <$> bindersVars bs' + -- EffTy _ resultTy <- instantiate piTy xs + -- Abs b body'' <- lowerFullySequentialBlock resultTy body' + -- return $ LamExpr (bs' >>> UnaryNest (PlainBD b)) body'' False -> do refreshAbs (Abs bs body) \bs' body' -> do body'' <- lowerFullySequentialBlockNoDest body' @@ -143,7 +143,7 @@ lowerFor ansTy maybeDest dir ixTy (UnaryLamExpr (ib:>ty) body) = do let destTy = getType initDest body' <- buildUnaryLamExpr noHint (PairTy ty' destTy) \b' -> do (i, destProd) <- fromPair $ Var b' - dest <- normalizeProj (ProjectProduct 0) destProd + dest <- getProj 0 destProd idest <- emitOp =<< mkIndexRef dest i extendSubst (ib @> SubstVal i) $ lowerBlockWithDest idest body $> UnitVal ans <- emitSeq dir ixTy' initDest body' >>= getProj 0 @@ -237,12 +237,13 @@ lookupDest = flip lookupNameMap -- -- XXX: When adding more cases, be careful about potentially repeated vars in the output! decomposeDest :: Emits o => Dest SimpIR o -> SAtom i' -> LowerM i o (Maybe (DestAssignment i' o)) -decomposeDest dest = \case - Var v -> return $ Just $ singletonNameMap (atomVarName v) $ FullDest dest - ProjectElt _ p x -> do - (ps, v) <- return $ asNaryProj p x - return $ Just $ singletonNameMap (atomVarName v) $ ProjDest ps dest - _ -> return Nothing +decomposeDest dest = undefined +-- decomposeDest dest = \case +-- Var v -> return $ Just $ singletonNameMap (atomVarName v) $ FullDest dest +-- ProjectElt _ p x -> do +-- (ps, v) <- return $ asNaryProj p x +-- return $ Just $ singletonNameMap (atomVarName v) $ ProjDest ps dest +-- _ -> return Nothing lowerBlockWithDest :: Emits o => Dest SimpIR o -> SBlock i -> LowerM i o (SAtom o) lowerBlockWithDest dest (Abs decls ans) = do @@ -349,11 +350,12 @@ lowerExprWithDest dest expr = case expr of ProjDest _ _ -> return Nothing place :: Emits o => ProjDest o -> SAtom o -> LowerM i o () -place pd x = case pd of - FullDest d -> void $ emitOp $ DAMOp $ Place d x - ProjDest p d -> do - x' <- normalizeNaryProj (NE.toList p) x - void $ emitOp $ DAMOp $ Place d x' +place pd x = undefined +-- place pd x = case pd of +-- FullDest d -> void $ emitOp $ DAMOp $ Place d x +-- ProjDest p d -> do +-- x' <- normalizeNaryProj (NE.toList p) x +-- void $ emitOp $ DAMOp $ Place d x' -- === Extensions to the name system === diff --git a/src/lib/OccAnalysis.hs b/src/lib/OccAnalysis.hs index fddcb5ab6..eb61f6768 100644 --- a/src/lib/OccAnalysis.hs +++ b/src/lib/OccAnalysis.hs @@ -12,7 +12,6 @@ import Data.Maybe (fromMaybe) import Control.Monad.Reader.Class import Core -import CheapReduction import IRVariants import Name import MTL1 @@ -21,6 +20,7 @@ import Occurrence qualified as Occ import Types.Core import Types.Primitives import QueryType +import Visitor -- === External API === @@ -239,7 +239,6 @@ instance HasOCC SAtom where modify (<> FV (singletonNameMapE n $ AccessInfo One a)) ty' <- occTy ty return $ Var (AtomVar n ty') - ProjectElt t i x -> ProjectElt <$> occ a t <*> pure i <*> occ a x atom -> runOCCMVisitor a $ visitAtomPartial atom instance HasOCC SType where @@ -451,12 +450,13 @@ instance HasOCC (Hof SimpIR) where oneShot :: Access n -> [IxExpr n] -> LamExpr SimpIR n -> OCCM n (LamExpr SimpIR n) oneShot acc [] (LamExpr Empty body) = LamExpr Empty <$> occNest acc body -oneShot acc (ix:ixs) (LamExpr (Nest (BD b) bs) body) = do - occWithBinder (Abs b (LamExpr bs body)) \b' restLam -> - extend b' (sink ix) do - LamExpr bs' body' <- oneShot (sink acc) (map sink ixs) restLam - return $ LamExpr (Nest (BD b') bs') body' -oneShot _ _ _ = error "zip error" +oneShot acc (ix:ixs) (LamExpr (Nest _ bs) body) = undefined +-- oneShot acc (ix:ixs) (LamExpr (Nest (BD b) bs) body) = do +-- occWithBinder (Abs b (LamExpr bs body)) \b' restLam -> +-- extend b' (sink ix) do +-- LamExpr bs' body' <- oneShot (sink acc) (map sink ixs) restLam +-- return $ LamExpr (Nest (BD b') bs') body' +-- oneShot _ _ _ = error "zip error" -- Going under a lambda binder. occWithBinder diff --git a/src/lib/Optimize.hs b/src/lib/Optimize.hs index 2aceee2fc..5162c0054 100644 --- a/src/lib/Optimize.hs +++ b/src/lib/Optimize.hs @@ -26,11 +26,11 @@ import Name import Subst import IRVariants import Core -import CheapReduction import Builder import QueryType import Util (iota) import Err +import Visitor optimize :: EnvReader m => STopLam n -> m n (STopLam n) optimize = dceTop -- Clean up user code @@ -312,32 +312,33 @@ hoistLoopInvariant = liftLamExpr hoistLoopInvariantBlock licmExpr :: Emits o => SExpr i -> LICMM i o (SAtom o) licmExpr = \case - PrimOp (DAMOp (Seq _ dir ix (ProdVal dests) (LamExpr (UnaryNest b) body))) -> do - ix' <- substM ix - dests' <- mapM visitAtom dests - let numCarriesOriginal = length dests' - Abs hdecls destsAndBody <- visitBinders (UnaryNest b) \(UnaryNest b') -> do - -- First, traverse the block, to allow any Hofs inside it to hoist their own decls. - Abs decls ans <- buildBlock $ visitBlockEmits body - -- Now, we process the decls and decide which ones to hoist. - liftEnvReaderM $ runSubstReaderT idSubst $ - seqLICM REmpty mempty b' REmpty decls ans - PairE (ListE extraDests) ab <- emitDecls $ Abs hdecls destsAndBody - extraDests' <- mapM toAtomVar extraDests - -- Append the destinations of hoisted Allocs as loop carried values. - let dests'' = ProdVal $ dests' ++ (Var <$> extraDests') - let carryTy = getType dests'' - let lbTy = case ix' of IxType ixTy _ -> PairTy ixTy carryTy - extraDestsTyped <- forM extraDests' \(AtomVar d t) -> return (d, t) - Abs extraDestBs (Abs lb bodyAbs) <- return $ abstractFreeVars extraDestsTyped ab - let extraDestBs' = fmapNest PlainBD extraDestBs - body' <- withFreshBinder noHint lbTy \lb' -> do - (oldIx, allCarries) <- fromPair $ Var $ binderVar lb' - (oldCarries, newCarries) <- splitAt numCarriesOriginal <$> getUnpacked allCarries - let oldLoopBinderVal = PairVal oldIx (ProdVal oldCarries) - block <- instantiate (Abs (extraDestBs' >>> UnaryNest lb) bodyAbs) (newCarries <> [oldLoopBinderVal]) - return $ UnaryLamExpr lb' block - emitSeq dir ix' dests'' body' + PrimOp (DAMOp (Seq _ dir ix (ProdVal dests) (LamExpr (UnaryNest b) body))) -> undefined + -- PrimOp (DAMOp (Seq _ dir ix (ProdVal dests) (LamExpr (UnaryNest b) body))) -> do + -- ix' <- substM ix + -- dests' <- mapM visitAtom dests + -- let numCarriesOriginal = length dests' + -- Abs hdecls destsAndBody <- visitBinders (UnaryNest b) \(UnaryNest b') -> do + -- -- First, traverse the block, to allow any Hofs inside it to hoist their own decls. + -- Abs decls ans <- buildBlock $ visitBlockEmits body + -- -- Now, we process the decls and decide which ones to hoist. + -- liftEnvReaderM $ runSubstReaderT idSubst $ + -- seqLICM REmpty mempty b' REmpty decls ans + -- PairE (ListE extraDests) ab <- emitDecls $ Abs hdecls destsAndBody + -- extraDests' <- mapM toAtomVar extraDests + -- -- Append the destinations of hoisted Allocs as loop carried values. + -- let dests'' = ProdVal $ dests' ++ (Var <$> extraDests') + -- let carryTy = getType dests'' + -- let lbTy = case ix' of IxType ixTy _ -> PairTy ixTy carryTy + -- extraDestsTyped <- forM extraDests' \(AtomVar d t) -> return (d, t) + -- Abs extraDestBs (Abs lb bodyAbs) <- return $ abstractFreeVars extraDestsTyped ab + -- let extraDestBs' = fmapNest PlainBD extraDestBs + -- body' <- withFreshBinder noHint lbTy \lb' -> do + -- (oldIx, allCarries) <- fromPair $ Var $ binderVar lb' + -- (oldCarries, newCarries) <- splitAt numCarriesOriginal <$> getUnpacked allCarries + -- let oldLoopBinderVal = PairVal oldIx (ProdVal oldCarries) + -- block <- instantiate (Abs (extraDestBs' >>> UnaryNest lb) bodyAbs) (newCarries <> [oldLoopBinderVal]) + -- return $ UnaryLamExpr lb' block + -- emitSeq dir ix' dests'' body' PrimOp (Hof (TypedHof _ (For dir ix (LamExpr (UnaryNest b) body)))) -> do ix' <- substM ix Abs hdecls destsAndBody <- visitBinders (UnaryNest b) \(UnaryNest b') -> do @@ -427,7 +428,6 @@ instance Color c => HasDCE (Name c) where instance HasDCE SAtom where dce = \case Var n -> modify (<> FV (freeVarsE n)) $> Var n - ProjectElt t i x -> ProjectElt <$> dce t <*> pure i <*> dce x atom -> visitAtomPartial atom instance HasDCE SType where dce = visitTypePartial diff --git a/src/lib/PPrint.hs b/src/lib/PPrint.hs index b3ab47ecc..235a14836 100644 --- a/src/lib/PPrint.hs +++ b/src/lib/PPrint.hs @@ -41,7 +41,7 @@ import Types.Imp import Types.Misc import Types.Primitives import Types.Source -import QueryTypePure +import QueryType import Util (Tree (..)) -- A DocPrec is a slightly context-aware Doc, specifically one that @@ -160,7 +160,8 @@ instance PrettyE ann => Pretty (BinderP c ann n l) where pretty (b:>ty) = p b <> ":" <> p ty instance IRRep r => Pretty (BinderAndDecls r n l) where - pretty (BD b) = pretty b + pretty (BD b Empty) = pretty b + pretty (BD b ds) = pretty b <> pretty ds instance IRRep r => Pretty (Expr r n) where pretty = prettyFromPrettyPrec instance IRRep r => PrettyPrec (Expr r n) where @@ -172,6 +173,8 @@ instance IRRep r => PrettyPrec (Expr r n) where prettyPrec (TabCon _ _ es) = atPrec ArgPrec $ list $ pApp <$> es prettyPrec (PrimOp op) = prettyPrec op prettyPrec (ApplyMethod _ d i xs) = atPrec AppPrec $ "applyMethod" <+> p d <+> p i <+> p xs + prettyPrec (ProjectElt _ idxs v) = atPrec LowestPrec $ "ProjectElt" <+> p idxs <+> p v + prettyPrec (DictHole _ e _) = atPrec LowestPrec $ "synthesize" <+> pApp e prettyPrecCase :: IRRep r => Doc ann -> Atom r n -> [Alt r n] -> EffectRow r n -> DocPrec ann prettyPrecCase name e alts effs = atPrec LowestPrec $ @@ -254,10 +257,8 @@ instance IRRep r => PrettyPrec (Atom r n) where PtrVar _ v -> atPrec ArgPrec $ p v DictCon _ d -> atPrec LowestPrec $ p d RepValAtom x -> atPrec LowestPrec $ pretty x - ProjectElt _ idxs v -> atPrec LowestPrec $ "ProjectElt" <+> p idxs <+> p v NewtypeCon con x -> prettyPrecNewtype con x SimpInCore x -> prettyPrec x - DictHole _ e _ -> atPrec LowestPrec $ "synthesize" <+> pApp e TypeAsAtom ty -> prettyPrec ty instance IRRep r => Pretty (Type r n) where pretty = prettyFromPrettyPrec @@ -270,8 +271,6 @@ instance IRRep r => PrettyPrec (Type r n) where DictTy t -> atPrec LowestPrec $ p t NewtypeTyCon con -> prettyPrec con TyVar v -> atPrec ArgPrec $ p v - ProjectEltTy _ idxs v -> - atPrec LowestPrec $ "ProjectElt" <+> p idxs <+> p v instance Pretty (SimpInCore n) where pretty = prettyFromPrettyPrec instance PrettyPrec (SimpInCore n) where @@ -327,16 +326,16 @@ withExplParens (Inferred _ Unify) x = braces $ x withExplParens (Inferred _ (Synth _)) x = brackets x instance IRRep r => Pretty (TabPiType r n) where - pretty (TabPiType dict (BD (b:>ty)) body) = let + pretty (TabPiType dict (BD (b:>ty) ds) body) = let prettyBody = case body of Pi subpi -> pretty subpi _ -> pLowest body prettyBinder = case dict of - IxDictRawFin n -> if binderName b `isFreeIn` body + IxDictRawFin n -> if binderName b `isFreeIn` (Abs ds body) then parens $ p b <> ":" <> prettyTy else prettyTy where prettyTy = "RawFin" <+> p n - _ -> prettyBinderHelper (b:>ty) body + _ -> prettyBinderHelper (b:>ty) (Abs ds body) in prettyBinder <> prettyIxDict dict <> (group $ line <> "=>" <+> prettyBody) -- A helper to let us turn dict printing on and off. We mostly want it off to diff --git a/src/lib/QueryType.hs b/src/lib/QueryType.hs index 92a74dfad..eedce70fd 100644 --- a/src/lib/QueryType.hs +++ b/src/lib/QueryType.hs @@ -1,439 +1,315 @@ --- Copyright 2022 Google LLC +-- Copyright 2023 Google LLC -- -- Use of this source code is governed by a BSD-style -- license that can be found in the LICENSE file or at -- https://developers.google.com/open-source/licenses/bsd -module QueryType (module QueryType, module QueryTypePure, toAtomVar) where - -import Control.Category ((>>>)) -import Control.Monad -import Data.List (elemIndex) -import Data.Functor ((<&>)) +module QueryType where import Types.Primitives import Types.Core -import Types.Source -import Types.Imp import IRVariants -import Core -import Err -import Name hiding (withFreshM) -import Subst -import Util -import PPrint () -import QueryTypePure -import CheapReduction - -sourceNameType :: (EnvReader m, Fallible1 m) => SourceName -> m n (Type CoreIR n) -sourceNameType v = do - lookupSourceMap v >>= \case - Nothing -> throw UnboundVarErr $ pprint v - Just uvar -> getUVarType uvar - --- === Exposed helpers for querying types and effects === - -caseAltsBinderTys :: (EnvReader m, IRRep r) => Type r n -> m n [Type r n] -caseAltsBinderTys ty = case ty of - SumTy types -> return types - NewtypeTyCon t -> case t of - UserADTType _ defName params -> do - def <- lookupTyCon defName - ~(ADTCons cons) <- instantiateTyConDef def params - return [repTy | DataConDef _ _ repTy _ <- cons] - _ -> error msg - _ -> error msg - where msg = "Case analysis only supported on ADTs, not on " ++ pprint ty - -extendEffect :: IRRep r => Effect r n -> EffectRow r n -> EffectRow r n -extendEffect eff (EffectRow effs t) = EffectRow (effs <> eSetSingleton eff) t - -blockEffTy :: (EnvReader m, IRRep r) => Block r n -> m n (EffTy r n) -blockEffTy block = liftEnvReaderM $ refreshAbs block \decls result -> do - effs <- declsEffects decls mempty - return $ ignoreHoistFailure $ hoist decls $ EffTy effs $ getType result - where - declsEffects :: IRRep r => Nest (Decl r) n l -> EffectRow r l -> EnvReaderM l (EffectRow r l) - declsEffects Empty !acc = return acc - declsEffects n@(Nest (Let _ (DeclBinding _ expr)) rest) !acc = withExtEvidence n do - expr' <- sinkM expr - declsEffects rest $ acc <> getEffects expr' - -blockTy :: (EnvReader m, IRRep r) => Block r n -> m n (Type r n) -blockTy b = blockEffTy b <&> \(EffTy _ t) -> t - -piTypeWithoutDest :: PiType SimpIR n -> PiType SimpIR n -piTypeWithoutDest (PiType bsRefB _) = - case popNest bsRefB of - Just (PairB bs refB) -> do - case binderType refB of - RawRefTy ansTy -> PiType bs $ EffTy Pure ansTy -- XXX: we ignore the effects here - _ -> error "expected ref type" - _ -> error "expected trailing binder" - -blockEff :: (EnvReader m, IRRep r) => Block r n -> m n (EffectRow r n) -blockEff b = blockEffTy b <&> \(EffTy eff _) -> eff - -typeOfApp :: (IRRep r, EnvReader m) => Type r n -> [Atom r n] -> m n (Type r n) -typeOfApp (Pi piTy) xs = withSubstReaderT $ - withInstantiated piTy xs \(EffTy _ ty) -> substM ty -typeOfApp _ _ = error "expected a pi type" - -typeOfTabApp :: (IRRep r, EnvReader m) => Type r n -> [Atom r n] -> m n (Type r n) -typeOfTabApp t [] = return t -typeOfTabApp (TabPi tabTy) (i:rest) = do - resultTy <- instantiate tabTy [i] - typeOfTabApp resultTy rest -typeOfTabApp ty _ = error $ "expected a table type. Got: " ++ pprint ty - -typeOfApplyMethod :: EnvReader m => CAtom n -> Int -> [CAtom n] -> m n (EffTy CoreIR n) -typeOfApplyMethod d i args = do - ty <- Pi <$> getMethodType d i - appEffTy ty args - -typeOfDictExpr :: EnvReader m => DictExpr n -> m n (CType n) -typeOfDictExpr e = liftM ignoreExcept $ liftEnvReaderT $ case e of - InstanceDict instanceName args -> do - instanceDef@(InstanceDef className _ _ _ _) <- lookupInstanceDef instanceName - sourceName <- getSourceName <$> lookupClassDef className - PairE (ListE params) _ <- instantiate instanceDef args - return $ DictTy $ DictType sourceName className params - InstantiatedGiven given args -> typeOfApp (getType given) args - SuperclassProj d i -> do - DictTy (DictType _ className params) <- return $ getType d - classDef <- lookupClassDef className - withSubstReaderT $ withInstantiated classDef params \(Abs superclasses _) -> do - substM $ getSuperclassType REmpty superclasses i - IxFin n -> liftM DictTy $ ixDictType $ NewtypeTyCon $ Fin n - DataData ty -> DictTy <$> dataDictType ty - -typeOfTopApp :: EnvReader m => TopFunName n -> [SAtom n] -> m n (EffTy SimpIR n) -typeOfTopApp f xs = do - piTy <- getTypeTopFun f - instantiate piTy xs - -typeOfIndexRef :: (EnvReader m, Fallible1 m, IRRep r) => Type r n -> Atom r n -> m n (Type r n) -typeOfIndexRef (TC (RefType h s)) i = do - TabPi tabPi <- return s - eltTy <- instantiate tabPi [i] - return $ TC $ RefType h eltTy -typeOfIndexRef _ _ = error "expected a ref type" - -typeOfProjRef :: EnvReader m => Type r n -> Projection -> m n (Type r n) -typeOfProjRef (TC (RefType h s)) p = do - TC . RefType h <$> case p of - ProjectProduct i -> do - ~(ProdTy tys) <- return s - return $ tys !! i - UnwrapNewtype -> do - case s of - NewtypeTyCon tc -> snd <$> unwrapNewtypeType tc - _ -> error "expected a newtype" -typeOfProjRef _ _ = error "expected a reference" - -appEffTy :: (IRRep r, EnvReader m) => Type r n -> [Atom r n] -> m n (EffTy r n) -appEffTy (Pi piTy) xs = instantiate piTy xs -appEffTy t _ = error $ "expected a pi type, got: " ++ pprint t - -partialAppType :: (IRRep r, EnvReader m) => Type r n -> [Atom r n] -> m n (Type r n) -partialAppType (Pi (CorePiType appExpl expls bs effTy)) xs = do - (_, expls2) <- return $ splitAt (length xs) expls - PairB bs1 bs2 <- return $ splitNestAt (length xs) bs - instantiate (Abs bs1 (Pi $ CorePiType appExpl expls2 bs2 effTy)) xs -partialAppType _ _ = error "expected a pi type" - -effTyOfHof :: (EnvReader m, IRRep r) => Hof r n -> m n (EffTy r n) -effTyOfHof hof = EffTy <$> hofEffects hof <*> typeOfHof hof - -typeOfHof :: (EnvReader m, IRRep r) => Hof r n -> m n (Type r n) -typeOfHof = \case - For _ ixTy f -> getLamExprType f >>= \case - PiType (UnaryNest b) (EffTy _ eltTy) -> return $ TabTy (ixTypeDict ixTy) b eltTy - _ -> error "expected a unary pi type" - While _ -> return UnitTy - Linearize f _ -> getLamExprType f >>= \case - PiType (UnaryNest binder) (EffTy Pure b) -> do - let b' = ignoreHoistFailure $ hoist binder b - let fLinTy = Pi $ nonDepPiType [binderType binder] Pure b' - return $ PairTy b' fLinTy - _ -> error "expected a unary pi type" - Transpose f _ -> getLamExprType f >>= \case - PiType (UnaryNest b) _ -> return $ binderType b - _ -> error "expected a unary pi type" - RunReader _ f -> do - (resultTy, _) <- getTypeRWSAction f - return resultTy - RunWriter _ _ f -> uncurry PairTy <$> getTypeRWSAction f - RunState _ _ f -> do - (resultTy, stateTy) <- getTypeRWSAction f - return $ PairTy resultTy stateTy - RunIO f -> blockTy f - RunInit f -> blockTy f - CatchException ty _ -> return ty - -hofEffects :: (EnvReader m, IRRep r) => Hof r n -> m n (EffectRow r n) -hofEffects = \case - For _ _ f -> functionEffs f - While body -> blockEff body - Linearize _ _ -> return Pure -- Body has to be a pure function - Transpose _ _ -> return Pure -- Body has to be a pure function - RunReader _ f -> rwsFunEffects Reader f - RunWriter d _ f -> maybeInit d <$> rwsFunEffects Writer f - RunState d _ f -> maybeInit d <$> rwsFunEffects State f - RunIO f -> deleteEff IOEffect <$> blockEff f - RunInit f -> deleteEff InitEffect <$> blockEff f - CatchException _ f -> deleteEff ExceptionEffect <$> blockEff f - where maybeInit :: IRRep r => Maybe (Atom r i) -> (EffectRow r o -> EffectRow r o) - maybeInit d = case d of Just _ -> (<>OneEffect InitEffect); Nothing -> id - -deleteEff :: IRRep r => Effect r n -> EffectRow r n -> EffectRow r n -deleteEff eff (EffectRow effs t) = EffectRow (effs `eSetDifference` eSetSingleton eff) t - -getMethodIndex :: EnvReader m => ClassName n -> SourceName -> m n Int -getMethodIndex className methodSourceName = do - ClassDef _ methodNames _ _ _ _ _ <- lookupClassDef className - case elemIndex methodSourceName methodNames of - Nothing -> error $ methodSourceName ++ " is not a method of " ++ pprint className - Just i -> return i -{-# INLINE getMethodIndex #-} - -getUVarType :: EnvReader m => UVar n -> m n (CType n) -getUVarType = \case - UAtomVar v -> getType <$> toAtomVar v - UTyConVar v -> getTyConNameType v - UDataConVar v -> getDataConNameType v - UPunVar v -> getStructDataConType v - UClassVar v -> do - ClassDef _ _ _ roleExpls bs _ _ <- lookupClassDef v - return $ Pi $ CorePiType ExplicitApp (map snd roleExpls) bs $ EffTy Pure TyKind - UMethodVar v -> getMethodNameType v - UEffectVar _ -> error "not implemented" - UEffectOpVar _ -> error "not implemented" - -getMethodNameType :: EnvReader m => MethodName n -> m n (CType n) -getMethodNameType v = liftEnvReaderM $ lookupEnv v >>= \case - MethodBinding className i -> do - ClassDef _ _ paramNames _ paramBs scBinders methodTys <- lookupClassDef className - refreshAbs (Abs paramBs $ Abs scBinders (methodTys !! i)) \paramBs' absPiTy -> do - let params = Var <$> bindersVars paramBs' - dictTy <- DictTy <$> dictType (sink className) params - withFreshBinder noHint dictTy \dictB -> do - scDicts <- getSuperclassDicts (Var $ binderVar dictB) - CorePiType appExpl methodExpls methodBs effTy <- instantiate absPiTy scDicts - let paramExpls = paramNames <&> \name -> Inferred name Unify - let expls = paramExpls <> [Inferred Nothing (Synth $ Partial $ succ i)] <> methodExpls - return $ Pi $ CorePiType appExpl expls (paramBs' >>> UnaryNest (BD dictB) >>> methodBs) effTy - -getMethodType :: EnvReader m => Dict n -> Int -> m n (CorePiType n) -getMethodType dict i = liftEnvReaderM $ withSubstReaderT do - ~(DictTy (DictType _ className params)) <- return $ getType dict - superclassDicts <- getSuperclassDicts dict - classDef <- lookupClassDef className - withInstantiated classDef params \ab -> do - withInstantiated ab superclassDicts \(ListE methodTys) -> - substM $ methodTys !! i - -getTyConNameType :: EnvReader m => TyConName n -> m n (Type CoreIR n) -getTyConNameType v = do - TyConDef _ expls bs _ <- lookupTyCon v - case bs of - Empty -> return TyKind - _ -> return $ Pi $ CorePiType ExplicitApp (snd <$> expls) bs $ EffTy Pure TyKind - -getDataConNameType :: EnvReader m => DataConName n -> m n (Type CoreIR n) -getDataConNameType dataCon = liftEnvReaderM $ withSubstReaderT do - (tyCon, i) <- lookupDataCon dataCon - tyConDef <- lookupTyCon tyCon - buildDataConType tyConDef \expls paramBs' paramVs params -> do - withInstantiatedNames tyConDef paramVs \(ADTCons dataCons) -> do - DataConDef _ ab _ _ <- renameM (dataCons !! i) - refreshAbs ab \dataBs UnitE -> do - let appExpl = case dataBs of Empty -> ImplicitApp - _ -> ExplicitApp - let resultTy = NewtypeTyCon $ UserADTType (getSourceName tyConDef) (sink tyCon) (sink params) - let dataExpls = nestToList (const $ Explicit) dataBs - return $ Pi $ CorePiType appExpl (expls <> dataExpls) (paramBs' >>> dataBs) (EffTy Pure resultTy) - -getStructDataConType :: EnvReader m => TyConName n -> m n (CType n) -getStructDataConType tyCon = liftEnvReaderM $ withSubstReaderT do - tyConDef <- lookupTyCon tyCon - buildDataConType tyConDef \expls paramBs' paramVs params -> do - withInstantiatedNames tyConDef paramVs \(StructFields fields) -> do - fieldTys <- forM fields \(_, t) -> renameM t - let resultTy = NewtypeTyCon $ UserADTType (getSourceName tyConDef) (sink tyCon) params - Abs dataBs resultTy' <- return $ typesAsBinderNest fieldTys resultTy - let dataExpls = nestToList (const Explicit) dataBs - return $ Pi $ CorePiType ExplicitApp (expls <> dataExpls) (paramBs' >>> dataBs) (EffTy Pure resultTy') - -buildDataConType - :: (EnvReader m, EnvExtender m) - => TyConDef n - -> (forall l. DExt n l => [Explicitness] -> CBinders n l -> [CAtomName l] -> TyConParams l -> m l a) - -> m n a -buildDataConType (TyConDef _ roleExpls bs _) cont = do - let expls = snd <$> roleExpls - expls' <- forM expls \case - Explicit -> return $ Inferred Nothing Unify - expl -> return $ expl - refreshAbs (Abs bs UnitE) \bs' UnitE -> do - let vs = bindersVars bs' - cont expls' bs' (atomVarName <$> vs) $ TyConParams expls (Var <$> vs) - -makeTyConParams :: EnvReader m => TyConName n -> [CAtom n] -> m n (TyConParams n) -makeTyConParams tc params = do - TyConDef _ expls _ _ <- lookupTyCon tc - return $ TyConParams (map snd expls) params - -getDataClassName :: (Fallible1 m, EnvReader m) => m n (ClassName n) -getDataClassName = lookupSourceMap "Data" >>= \case - Nothing -> throw CompilerErr $ "Data interface needed but not defined!" - Just (UClassVar v) -> return v - Just _ -> error "not a class var" - -dataDictType :: (Fallible1 m, EnvReader m) => CType n -> m n (DictType n) -dataDictType ty = do - dataClassName <- getDataClassName - dictType dataClassName [Type ty] - -getIxClassName :: (Fallible1 m, EnvReader m) => m n (ClassName n) -getIxClassName = lookupSourceMap "Ix" >>= \case - Nothing -> throw CompilerErr $ "Ix interface needed but not defined!" - Just (UClassVar v) -> return v - Just _ -> error "not a class var" - -dictType :: EnvReader m => ClassName n -> [CAtom n] -> m n (DictType n) -dictType className params = do - ClassDef sourceName _ _ _ _ _ _ <- lookupClassDef className - return $ DictType sourceName className params - -ixDictType :: (Fallible1 m, EnvReader m) => CType n -> m n (DictType n) -ixDictType ty = do - ixClassName <- getIxClassName - dictType ixClassName [Type ty] - -makePreludeMaybeTy :: EnvReader m => CType n -> m n (CType n) -makePreludeMaybeTy ty = do - ~(Just (UTyConVar tyConName)) <- lookupSourceMap "Maybe" - return $ TypeCon "Maybe" tyConName $ TyConParams [Explicit] [Type ty] - --- === computing effects === - -functionEffs :: (IRRep r, EnvReader m) => LamExpr r n -> m n (EffectRow r n) -functionEffs f = getLamExprType f >>= \case - PiType b (EffTy effs _) -> return $ ignoreHoistFailure $ hoist b effs - -rwsFunEffects :: (IRRep r, EnvReader m) => RWS -> LamExpr r n -> m n (EffectRow r n) -rwsFunEffects rws f = liftEnvReaderM $ getLamExprType f >>= \case - PiType (BinaryNest h ref) et -> do - let effs' = ignoreHoistFailure $ hoist ref (etEff et) - refreshAbs (Abs h effs') \h' effs'' -> do - let hVal = Var $ binderVar h' - let effs''' = deleteEff (RWSEffect rws hVal) effs'' - return $ ignoreHoistFailure $ hoist h' effs''' - _ -> error "Expected a binary function type" - -getLamExprType :: (IRRep r, EnvReader m) => LamExpr r n -> m n (PiType r n) -getLamExprType (LamExpr bs body) = liftEnvReaderM $ - refreshAbs (Abs bs body) \bs' body' -> do - effTy <- blockEffTy body' - return $ PiType bs' effTy - -getTypeRWSAction :: (IRRep r, EnvReader m) => LamExpr r n -> m n (Type r n, Type r n) -getTypeRWSAction f = getLamExprType f >>= \case - PiType (BinaryNest regionBinder refBinder) (EffTy _ resultTy) -> do - case binderType refBinder of - RefTy _ referentTy -> do - let referentTy' = ignoreHoistFailure $ hoist regionBinder referentTy - let resultTy' = ignoreHoistFailure $ hoist (PairB regionBinder refBinder) resultTy - return (resultTy', referentTy') - _ -> error "expected a ref" - _ -> error "expected a pi type" - -getSuperclassDicts :: EnvReader m => CAtom n -> m n ([CAtom n]) -getSuperclassDicts dict = do - case getType dict of - DictTy dTy -> do - ts <- getSuperclassTys dTy - forM (enumerate ts) \(i, t) -> return $ DictCon t $ SuperclassProj dict i - _ -> error "expected a dict type" - -getSuperclassTys :: EnvReader m => DictType n -> m n [CType n] -getSuperclassTys (DictType _ className params) = do - ClassDef _ _ _ _ bs superclasses _ <- lookupClassDef className - forM [0 .. nestLength superclasses - 1] \i -> do - instantiate (Abs bs $ getSuperclassType REmpty superclasses i) params - -getSuperclassType :: RNest (BinderAndDecls CoreIR) n l -> CBinders l l' -> Int -> CType n -getSuperclassType _ Empty = error "bad index" -getSuperclassType bsAbove (Nest b bs) = \case - 0 -> ignoreHoistFailure $ hoist bsAbove (binderType b) - i -> getSuperclassType (RNest bsAbove b) bs (i-1) - - -getTypeTopFun :: EnvReader m => TopFunName n -> m n (PiType SimpIR n) -getTypeTopFun f = lookupTopFun f >>= \case - DexTopFun _ (TopLam _ piTy _) _ -> return piTy - FFITopFun _ iTy -> liftIFunType iTy - -asTopLam :: (EnvReader m, IRRep r) => LamExpr r n -> m n (TopLam r n) -asTopLam lam = do - piTy <- getLamExprType lam - return $ TopLam False piTy lam - -liftIFunType :: (IRRep r, EnvReader m) => IFunType -> m n (PiType r n) -liftIFunType (IFunType _ argTys resultTys) = liftEnvReaderM $ go argTys where - go :: IRRep r => [BaseType] -> EnvReaderM n (PiType r n) - go = \case - [] -> return $ PiType Empty (EffTy (OneEffect IOEffect) resultTy) - where resultTy = case resultTys of - [] -> UnitTy - [t] -> BaseTy t - [t1, t2] -> PairTy (BaseTy t1) (BaseTy t2) - _ -> error $ "Not a valid FFI return type: " ++ pprint resultTys - t:ts -> withFreshBinder noHint (BaseTy t) \b -> do - PiType bs effTy <- go ts - return $ PiType (Nest (PlainBD b) bs) effTy - --- === Data constraints === - -isData :: EnvReader m => Type CoreIR n -> m n Bool -isData ty = do - result <- liftEnvReaderT $ withSubstReaderT $ checkDataLike ty - case runFallibleM result of - Success () -> return True - Failure _ -> return False - -checkDataLike :: Type CoreIR i -> SubstReaderT Name FallibleEnvReaderM i o () -checkDataLike ty = case ty of - TyVar _ -> notData - TabPi (TabPiType _ b eltTy) -> do - renameBinders b \_ -> - checkDataLike eltTy - DepPairTy (DepPairType _ b r) -> do - recur $ binderType b - renameBinders b \_ -> checkDataLike r - NewtypeTyCon nt -> do - (_, ty') <- unwrapNewtypeType =<< renameM nt - dropSubst $ recur ty' - TC con -> case con of - BaseType _ -> return () - ProdType as -> mapM_ recur as - SumType cs -> mapM_ recur cs - RefType _ _ -> return () - HeapType -> return () - _ -> notData - _ -> notData - where - recur = checkDataLike - notData = throw TypeErr $ pprint ty - -checkExtends :: (Fallible m, IRRep r) => EffectRow r n -> EffectRow r n -> m () -checkExtends allowed (EffectRow effs effTail) = do - let (EffectRow allowedEffs allowedEffTail) = allowed - case effTail of - EffectRowTail _ -> assertEq allowedEffTail effTail "" - NoTail -> return () - forM_ (eSetToList effs) \eff -> unless (eff `eSetMember` allowedEffs) $ - throw CompilerErr $ "Unexpected effect: " ++ pprint eff ++ - "\nAllowed: " ++ pprint allowed - +import Name + +class HasType (r::IR) (e::E) | e -> r where + getType :: e n -> Type r n + +class HasEffects (e::E) (r::IR) | e -> r where + getEffects :: e n -> EffectRow r n + +isPure :: (IRRep r, HasEffects e r) => e n -> Bool +isPure e = case getEffects e of + Pure -> True + _ -> False + +-- === querying types implementation === + +instance IRRep r => HasType r (AtomBinding r) where + getType = \case + LetBound (DeclBinding _ e) -> getType e + MiscBound ty -> ty + SolverBound (InfVarBound ty _) -> ty + SolverBound (SkolemBound ty) -> ty + NoinlineFun ty _ -> ty + TopDataBound (RepVal ty _) -> ty + FFIFunBound piTy _ -> Pi piTy + +litType :: LitVal -> BaseType +litType v = case v of + Int64Lit _ -> Scalar Int64Type + Int32Lit _ -> Scalar Int32Type + Word8Lit _ -> Scalar Word8Type + Word32Lit _ -> Scalar Word32Type + Word64Lit _ -> Scalar Word64Type + Float64Lit _ -> Scalar Float64Type + Float32Lit _ -> Scalar Float32Type + PtrLit ty _ -> PtrType ty + +typeBinOp :: BinOp -> BaseType -> BaseType +typeBinOp binop xTy = case binop of + IAdd -> xTy; ISub -> xTy + IMul -> xTy; IDiv -> xTy + IRem -> xTy; + ICmp _ -> Scalar Word8Type + FAdd -> xTy; FSub -> xTy + FMul -> xTy; FDiv -> xTy; + FPow -> xTy + FCmp _ -> Scalar Word8Type + BAnd -> xTy; BOr -> xTy + BXor -> xTy + BShL -> xTy; BShR -> xTy + +typeUnOp :: UnOp -> BaseType -> BaseType +typeUnOp = const id -- All unary ops preserve the type of the input + + +instance IRRep r => HasType r (AtomVar r) where + getType (AtomVar _ ty) = ty + {-# INLINE getType #-} + +instance IRRep r => HasType r (Atom r) where + getType atom = case atom of + Var name -> getType name + Lam (CoreLamExpr piTy _) -> Pi piTy + DepPair _ _ ty -> DepPairTy ty + Con con -> getType con + Eff _ -> EffKind + PtrVar t _ -> PtrTy t + DictCon ty _ -> ty + NewtypeCon con _ -> getNewtypeType con + RepValAtom (RepVal ty _) -> ty + SimpInCore x -> getType x + TypeAsAtom ty -> getType ty + +instance IRRep r => HasType r (Type r) where + getType = \case + NewtypeTyCon con -> getType con + Pi _ -> TyKind + TabPi _ -> TyKind + DepPairTy _ -> TyKind + TC _ -> TyKind + DictTy _ -> TyKind + TyVar v -> getType v + +instance HasType CoreIR SimpInCore where + getType = \case + LiftSimp t _ -> t + LiftSimpFun piTy _ -> Pi $ piTy + TabLam t _ -> TabPi $ t + ACase _ _ t -> t + +instance HasType CoreIR NewtypeTyCon where + getType _ = TyKind + +getNewtypeType :: NewtypeCon n -> CType n +getNewtypeType con = case con of + NatCon -> NewtypeTyCon Nat + FinCon n -> NewtypeTyCon $ Fin n + UserADTData sn d params -> NewtypeTyCon $ UserADTType sn d params + +instance IRRep r => HasType r (Con r) where + getType = \case + Lit l -> BaseTy $ litType l + ProdCon xs -> ProdTy $ map getType xs + SumCon tys _ _ -> SumTy tys + HeapVal -> TC HeapType + +instance IRRep r => HasType r (Expr r) where + getType expr = case expr of + App (EffTy _ ty) _ _ -> ty + TopApp (EffTy _ ty) _ _ -> ty + TabApp t _ _ -> t + Atom x -> getType x + TabCon _ ty _ -> ty + PrimOp op -> getType op + Case _ _ (EffTy _ resultTy) -> resultTy + ApplyMethod (EffTy _ t) _ _ _ -> t + ProjectElt t _ _ -> t + DictHole _ ty _ -> ty + +instance IRRep r => HasType r (DAMOp r) where + getType = \case + AllocDest ty -> RawRefTy ty + Place _ _ -> UnitTy + Freeze ref -> case getType ref of + RawRefTy ty -> ty + ty -> error $ "Not a reference type: " ++ show ty + Seq _ _ _ cinit _ -> getType cinit + RememberDest _ d _ -> getType d + +instance IRRep r => HasType r (PrimOp r) where + getType primOp = case primOp of + BinOp op x _ -> TC $ BaseType $ typeBinOp op $ getTypeBaseType x + UnOp op x -> TC $ BaseType $ typeUnOp op $ getTypeBaseType x + Hof (TypedHof (EffTy _ ty) _) -> ty + MemOp op -> getType op + MiscOp op -> getType op + VectorOp op -> getType op + DAMOp op -> getType op + RefOp ref m -> case getType ref of + TC (RefType _ s) -> case m of + MGet -> s + MPut _ -> UnitTy + MAsk -> s + MExtend _ _ -> UnitTy + IndexRef t _ -> t + ProjRef t _ -> t + _ -> error "not a reference type" + +getTypeBaseType :: (IRRep r, HasType r e) => e n -> BaseType +getTypeBaseType e = case getType e of + TC (BaseType b) -> b + ty -> error $ "Expected a base type. Got: " ++ show ty + +instance IRRep r => HasType r (MemOp r) where + getType = \case + IOAlloc _ -> PtrTy (CPU, Scalar Word8Type) + IOFree _ -> UnitTy + PtrOffset arr _ -> getType arr + PtrLoad ptr -> do + let PtrTy (_, t) = getType ptr + BaseTy t + PtrStore _ _ -> UnitTy + +instance IRRep r => HasType r (VectorOp r) where + getType = \case + VectorBroadcast _ vty -> vty + VectorIota vty -> vty + VectorIdx _ _ vty -> vty + VectorSubref ref _ vty -> case getType ref of + TC (RefType h _) -> TC $ RefType h vty + ty -> error $ "Not a reference type: " ++ show ty + +instance IRRep r => HasType r (MiscOp r) where + getType = \case + Select _ x _ -> getType x + ThrowError t -> t + ThrowException t -> t + CastOp t _ -> t + BitcastOp t _ -> t + UnsafeCoerce t _ -> t + GarbageVal t -> t + SumTag _ -> TagRepTy + ToEnum t _ -> t + OutputStream -> BaseTy $ hostPtrTy $ Scalar Word8Type + where hostPtrTy ty = PtrType (CPU, ty) + ShowAny _ -> rawStrType -- TODO: constrain `ShowAny` to have `HasCore r` + ShowScalar _ -> PairTy IdxRepTy $ rawFinTabType (IdxRepVal showStringBufferSize) CharRepTy + +rawStrType :: IRRep r => Type r n +rawStrType = case newName "n" of + Abs b v -> do + let tabTy = rawFinTabType (Var $ AtomVar v IdxRepTy) CharRepTy + DepPairTy $ DepPairType ExplicitDepPair (PlainBD (b:>IdxRepTy)) tabTy + +-- `n` argument is IdxRepVal, not Nat +rawFinTabType :: IRRep r => Atom r n -> Type r n -> Type r n +rawFinTabType n eltTy = IxType IdxRepTy (IxDictRawFin n) ==> eltTy + +depPairLeftTy :: IRRep r => DepPairType r n -> Type r n +depPairLeftTy (DepPairType _ b _) = binderType b +{-# INLINE depPairLeftTy #-} + +tabIxType :: IRRep r => TabPiType r n -> IxType r n +tabIxType (TabPiType d b _) = IxType (binderType b) d + +typesAsBinderNest + :: (SinkableE e, HoistableE e, IRRep r) + => [Type r n] -> e n -> Abs (Binders r) e n +typesAsBinderNest types body = + case toConstBinderNest types body of + Abs bs body' -> Abs (fmapNest PlainBD bs) body' + +nonDepPiType :: [CType n] -> EffectRow CoreIR n -> CType n -> CorePiType n +nonDepPiType argTys eff resultTy = case typesAsBinderNest argTys (PairE eff resultTy) of + Abs bs (PairE eff' resultTy') -> do + let expls = nestToList (const Explicit) bs + CorePiType ExplicitApp expls bs $ EffTy eff' resultTy' + +nonDepTabPiType :: IRRep r => IxType r n -> Type r n -> TabPiType r n +nonDepTabPiType (IxType t d) resultTy = + case toConstAbsPure resultTy of + Abs b resultTy' -> TabPiType d (PlainBD (b:>t)) resultTy' + +corePiTypeToPiType :: CorePiType n -> PiType CoreIR n +corePiTypeToPiType (CorePiType _ _ bs effTy) = PiType bs effTy + +coreLamToTopLam :: CoreLamExpr n -> TopLam CoreIR n +coreLamToTopLam (CoreLamExpr ty f) = TopLam False (corePiTypeToPiType ty) f + +(==>) :: IRRep r => IxType r n -> Type r n -> Type r n +a ==> b = TabPi $ nonDepTabPiType a b + +litFinIxTy :: Int -> IxType r n +litFinIxTy n = finIxTy $ IdxRepVal $ fromIntegral n + +finIxTy :: Atom r n -> IxType r n +finIxTy n = IxType IdxRepTy (IxDictRawFin n) + +ixTyFromDict :: IRRep r => IxDict r n -> IxType r n +ixTyFromDict ixDict = flip IxType ixDict $ case ixDict of + IxDictAtom dict -> case getType dict of + DictTy (DictType "Ix" _ [Type iTy]) -> iTy + _ -> error $ "Not an Ix dict: " ++ show dict + IxDictRawFin _ -> IdxRepTy + IxDictSpecialized n _ _ -> n + +-- === querying effects implementation === + +instance IRRep r => HasEffects (Expr r) r where + getEffects = \case + Atom _ -> Pure + App (EffTy eff _) _ _ -> eff + TopApp (EffTy eff _) _ _ -> eff + TabApp _ _ _ -> Pure + Case _ _ (EffTy effs _) -> effs + TabCon _ _ _ -> Pure + ApplyMethod (EffTy eff _) _ _ _ -> eff + PrimOp primOp -> getEffects primOp + +instance IRRep r => HasEffects (DeclBinding r) r where + getEffects (DeclBinding _ expr) = getEffects expr + {-# INLINE getEffects #-} + +instance IRRep r => HasEffects (PrimOp r) r where + getEffects = \case + UnOp _ _ -> Pure + BinOp _ _ _ -> Pure + VectorOp _ -> Pure + MemOp op -> case op of + IOAlloc _ -> OneEffect IOEffect + IOFree _ -> OneEffect IOEffect + PtrLoad _ -> OneEffect IOEffect + PtrStore _ _ -> OneEffect IOEffect + PtrOffset _ _ -> Pure + MiscOp op -> case op of + ThrowException _ -> OneEffect ExceptionEffect + Select _ _ _ -> Pure + ThrowError _ -> Pure + CastOp _ _ -> Pure + UnsafeCoerce _ _ -> Pure + GarbageVal _ -> Pure + BitcastOp _ _ -> Pure + SumTag _ -> Pure + ToEnum _ _ -> Pure + OutputStream -> Pure + ShowAny _ -> Pure + ShowScalar _ -> Pure + RefOp ref m -> case getType ref of + TC (RefType h _) -> case m of + MGet -> OneEffect (RWSEffect State h) + MPut _ -> OneEffect (RWSEffect State h) + MAsk -> OneEffect (RWSEffect Reader h) + -- XXX: We don't verify the base monoid. See note about RunWriter. + MExtend _ _ -> OneEffect (RWSEffect Writer h) + IndexRef _ _ -> Pure + ProjRef _ _ -> Pure + _ -> error "not a ref" + DAMOp op -> case op of + Place _ _ -> OneEffect InitEffect + Seq eff _ _ _ _ -> eff + RememberDest eff _ _ -> eff + AllocDest _ -> Pure -- is this correct? + Freeze _ -> Pure -- is this correct? + Hof (TypedHof (EffTy eff _) _) -> eff + {-# INLINE getEffects #-} diff --git a/src/lib/QueryTypePure.hs b/src/lib/QueryTypePure.hs deleted file mode 100644 index b1fee94ac..000000000 --- a/src/lib/QueryTypePure.hs +++ /dev/null @@ -1,312 +0,0 @@ --- Copyright 2023 Google LLC --- --- Use of this source code is governed by a BSD-style --- license that can be found in the LICENSE file or at --- https://developers.google.com/open-source/licenses/bsd - -module QueryTypePure where - -import Types.Primitives -import Types.Core -import IRVariants -import Name - -class HasType (r::IR) (e::E) | e -> r where - getType :: e n -> Type r n - -class HasEffects (e::E) (r::IR) | e -> r where - getEffects :: e n -> EffectRow r n - -isPure :: (IRRep r, HasEffects e r) => e n -> Bool -isPure e = case getEffects e of - Pure -> True - _ -> False - --- === querying types implementation === - -instance IRRep r => HasType r (AtomBinding r) where - getType = \case - LetBound (DeclBinding _ e) -> getType e - MiscBound ty -> ty - SolverBound (InfVarBound ty _) -> ty - SolverBound (SkolemBound ty) -> ty - NoinlineFun ty _ -> ty - TopDataBound (RepVal ty _) -> ty - FFIFunBound piTy _ -> Pi piTy - -litType :: LitVal -> BaseType -litType v = case v of - Int64Lit _ -> Scalar Int64Type - Int32Lit _ -> Scalar Int32Type - Word8Lit _ -> Scalar Word8Type - Word32Lit _ -> Scalar Word32Type - Word64Lit _ -> Scalar Word64Type - Float64Lit _ -> Scalar Float64Type - Float32Lit _ -> Scalar Float32Type - PtrLit ty _ -> PtrType ty - -typeBinOp :: BinOp -> BaseType -> BaseType -typeBinOp binop xTy = case binop of - IAdd -> xTy; ISub -> xTy - IMul -> xTy; IDiv -> xTy - IRem -> xTy; - ICmp _ -> Scalar Word8Type - FAdd -> xTy; FSub -> xTy - FMul -> xTy; FDiv -> xTy; - FPow -> xTy - FCmp _ -> Scalar Word8Type - BAnd -> xTy; BOr -> xTy - BXor -> xTy - BShL -> xTy; BShR -> xTy - -typeUnOp :: UnOp -> BaseType -> BaseType -typeUnOp = const id -- All unary ops preserve the type of the input - - -instance IRRep r => HasType r (AtomVar r) where - getType (AtomVar _ ty) = ty - {-# INLINE getType #-} - -instance IRRep r => HasType r (Atom r) where - getType atom = case atom of - Var name -> getType name - Lam (CoreLamExpr piTy _) -> Pi piTy - DepPair _ _ ty -> DepPairTy ty - Con con -> getType con - Eff _ -> EffKind - PtrVar t _ -> PtrTy t - DictCon ty _ -> ty - NewtypeCon con _ -> getNewtypeType con - RepValAtom (RepVal ty _) -> ty - ProjectElt t _ _ -> t - SimpInCore x -> getType x - DictHole _ ty _ -> ty - TypeAsAtom ty -> getType ty - -instance IRRep r => HasType r (Type r) where - getType = \case - NewtypeTyCon con -> getType con - Pi _ -> TyKind - TabPi _ -> TyKind - DepPairTy _ -> TyKind - TC _ -> TyKind - DictTy _ -> TyKind - TyVar v -> getType v - ProjectEltTy t _ _ -> t - -instance HasType CoreIR SimpInCore where - getType = \case - LiftSimp t _ -> t - LiftSimpFun piTy _ -> Pi $ piTy - TabLam t _ -> TabPi $ t - ACase _ _ t -> t - -instance HasType CoreIR NewtypeTyCon where - getType _ = TyKind - -getNewtypeType :: NewtypeCon n -> CType n -getNewtypeType con = case con of - NatCon -> NewtypeTyCon Nat - FinCon n -> NewtypeTyCon $ Fin n - UserADTData sn d params -> NewtypeTyCon $ UserADTType sn d params - -instance IRRep r => HasType r (Con r) where - getType = \case - Lit l -> BaseTy $ litType l - ProdCon xs -> ProdTy $ map getType xs - SumCon tys _ _ -> SumTy tys - HeapVal -> TC HeapType - -instance IRRep r => HasType r (Expr r) where - getType expr = case expr of - App (EffTy _ ty) _ _ -> ty - TopApp (EffTy _ ty) _ _ -> ty - TabApp t _ _ -> t - Atom x -> getType x - TabCon _ ty _ -> ty - PrimOp op -> getType op - Case _ _ (EffTy _ resultTy) -> resultTy - ApplyMethod (EffTy _ t) _ _ _ -> t - -instance IRRep r => HasType r (DAMOp r) where - getType = \case - AllocDest ty -> RawRefTy ty - Place _ _ -> UnitTy - Freeze ref -> case getType ref of - RawRefTy ty -> ty - ty -> error $ "Not a reference type: " ++ show ty - Seq _ _ _ cinit _ -> getType cinit - RememberDest _ d _ -> getType d - -instance IRRep r => HasType r (PrimOp r) where - getType primOp = case primOp of - BinOp op x _ -> TC $ BaseType $ typeBinOp op $ getTypeBaseType x - UnOp op x -> TC $ BaseType $ typeUnOp op $ getTypeBaseType x - Hof (TypedHof (EffTy _ ty) _) -> ty - MemOp op -> getType op - MiscOp op -> getType op - VectorOp op -> getType op - DAMOp op -> getType op - RefOp ref m -> case getType ref of - TC (RefType _ s) -> case m of - MGet -> s - MPut _ -> UnitTy - MAsk -> s - MExtend _ _ -> UnitTy - IndexRef t _ -> t - ProjRef t _ -> t - _ -> error "not a reference type" - -getTypeBaseType :: (IRRep r, HasType r e) => e n -> BaseType -getTypeBaseType e = case getType e of - TC (BaseType b) -> b - ty -> error $ "Expected a base type. Got: " ++ show ty - -instance IRRep r => HasType r (MemOp r) where - getType = \case - IOAlloc _ -> PtrTy (CPU, Scalar Word8Type) - IOFree _ -> UnitTy - PtrOffset arr _ -> getType arr - PtrLoad ptr -> do - let PtrTy (_, t) = getType ptr - BaseTy t - PtrStore _ _ -> UnitTy - -instance IRRep r => HasType r (VectorOp r) where - getType = \case - VectorBroadcast _ vty -> vty - VectorIota vty -> vty - VectorIdx _ _ vty -> vty - VectorSubref ref _ vty -> case getType ref of - TC (RefType h _) -> TC $ RefType h vty - ty -> error $ "Not a reference type: " ++ show ty - -instance IRRep r => HasType r (MiscOp r) where - getType = \case - Select _ x _ -> getType x - ThrowError t -> t - ThrowException t -> t - CastOp t _ -> t - BitcastOp t _ -> t - UnsafeCoerce t _ -> t - GarbageVal t -> t - SumTag _ -> TagRepTy - ToEnum t _ -> t - OutputStream -> BaseTy $ hostPtrTy $ Scalar Word8Type - where hostPtrTy ty = PtrType (CPU, ty) - ShowAny _ -> rawStrType -- TODO: constrain `ShowAny` to have `HasCore r` - ShowScalar _ -> PairTy IdxRepTy $ rawFinTabType (IdxRepVal showStringBufferSize) CharRepTy - -rawStrType :: IRRep r => Type r n -rawStrType = case newName "n" of - Abs b v -> do - let tabTy = rawFinTabType (Var $ AtomVar v IdxRepTy) CharRepTy - DepPairTy $ DepPairType ExplicitDepPair (PlainBD (b:>IdxRepTy)) tabTy - --- `n` argument is IdxRepVal, not Nat -rawFinTabType :: IRRep r => Atom r n -> Type r n -> Type r n -rawFinTabType n eltTy = IxType IdxRepTy (IxDictRawFin n) ==> eltTy - -tabIxType :: IRRep r => TabPiType r n -> IxType r n -tabIxType (TabPiType d b _) = IxType (binderType b) d - -typesAsBinderNest - :: (SinkableE e, HoistableE e, IRRep r) - => [Type r n] -> e n -> Abs (Binders r) e n -typesAsBinderNest types body = - case toConstBinderNest types body of - Abs bs body' -> Abs (fmapNest PlainBD bs) body' - -nonDepPiType :: [CType n] -> EffectRow CoreIR n -> CType n -> CorePiType n -nonDepPiType argTys eff resultTy = case typesAsBinderNest argTys (PairE eff resultTy) of - Abs bs (PairE eff' resultTy') -> do - let expls = nestToList (const Explicit) bs - CorePiType ExplicitApp expls bs $ EffTy eff' resultTy' - -nonDepTabPiType :: IRRep r => IxType r n -> Type r n -> TabPiType r n -nonDepTabPiType (IxType t d) resultTy = - case toConstAbsPure resultTy of - Abs b resultTy' -> TabPiType d (PlainBD (b:>t)) resultTy' - -corePiTypeToPiType :: CorePiType n -> PiType CoreIR n -corePiTypeToPiType (CorePiType _ _ bs effTy) = PiType bs effTy - -coreLamToTopLam :: CoreLamExpr n -> TopLam CoreIR n -coreLamToTopLam (CoreLamExpr ty f) = TopLam False (corePiTypeToPiType ty) f - -(==>) :: IRRep r => IxType r n -> Type r n -> Type r n -a ==> b = TabPi $ nonDepTabPiType a b - -litFinIxTy :: Int -> IxType r n -litFinIxTy n = finIxTy $ IdxRepVal $ fromIntegral n - -finIxTy :: Atom r n -> IxType r n -finIxTy n = IxType IdxRepTy (IxDictRawFin n) - -ixTyFromDict :: IRRep r => IxDict r n -> IxType r n -ixTyFromDict ixDict = flip IxType ixDict $ case ixDict of - IxDictAtom dict -> case getType dict of - DictTy (DictType "Ix" _ [Type iTy]) -> iTy - _ -> error $ "Not an Ix dict: " ++ show dict - IxDictRawFin _ -> IdxRepTy - IxDictSpecialized n _ _ -> n - --- === querying effects implementation === - -instance IRRep r => HasEffects (Expr r) r where - getEffects = \case - Atom _ -> Pure - App (EffTy eff _) _ _ -> eff - TopApp (EffTy eff _) _ _ -> eff - TabApp _ _ _ -> Pure - Case _ _ (EffTy effs _) -> effs - TabCon _ _ _ -> Pure - ApplyMethod (EffTy eff _) _ _ _ -> eff - PrimOp primOp -> getEffects primOp - -instance IRRep r => HasEffects (DeclBinding r) r where - getEffects (DeclBinding _ expr) = getEffects expr - {-# INLINE getEffects #-} - -instance IRRep r => HasEffects (PrimOp r) r where - getEffects = \case - UnOp _ _ -> Pure - BinOp _ _ _ -> Pure - VectorOp _ -> Pure - MemOp op -> case op of - IOAlloc _ -> OneEffect IOEffect - IOFree _ -> OneEffect IOEffect - PtrLoad _ -> OneEffect IOEffect - PtrStore _ _ -> OneEffect IOEffect - PtrOffset _ _ -> Pure - MiscOp op -> case op of - ThrowException _ -> OneEffect ExceptionEffect - Select _ _ _ -> Pure - ThrowError _ -> Pure - CastOp _ _ -> Pure - UnsafeCoerce _ _ -> Pure - GarbageVal _ -> Pure - BitcastOp _ _ -> Pure - SumTag _ -> Pure - ToEnum _ _ -> Pure - OutputStream -> Pure - ShowAny _ -> Pure - ShowScalar _ -> Pure - RefOp ref m -> case getType ref of - TC (RefType h _) -> case m of - MGet -> OneEffect (RWSEffect State h) - MPut _ -> OneEffect (RWSEffect State h) - MAsk -> OneEffect (RWSEffect Reader h) - -- XXX: We don't verify the base monoid. See note about RunWriter. - MExtend _ _ -> OneEffect (RWSEffect Writer h) - IndexRef _ _ -> Pure - ProjRef _ _ -> Pure - _ -> error "not a ref" - DAMOp op -> case op of - Place _ _ -> OneEffect InitEffect - Seq eff _ _ _ _ -> eff - RememberDest eff _ _ -> eff - AllocDest _ -> Pure -- is this correct? - Freeze _ -> Pure -- is this correct? - Hof (TypedHof (EffTy eff _) _) -> eff - {-# INLINE getEffects #-} diff --git a/src/lib/RuntimePrint.hs b/src/lib/RuntimePrint.hs index 3c627e705..41764e4a7 100644 --- a/src/lib/RuntimePrint.hs +++ b/src/lib/RuntimePrint.hs @@ -15,12 +15,12 @@ import Err import IRVariants import MTL1 import Name -import CheapReduction import Types.Core import Types.Source import Types.Primitives import QueryType import Util (enumerate) +import Visitor newtype Printer (n::S) (a :: *) = Printer { runPrinter' :: ReaderT1 (Atom CoreIR) (BuilderM CoreIR) n a } deriving ( Functor, Applicative, Monad, EnvReader, MonadReader (Atom CoreIR n) @@ -78,7 +78,6 @@ showAnyRec atom = case getType atom of parens $ sepBy ", " $ map rec xs -- TODO: traverse the type and print out data components TypeKind -> printAsConstant - ProjectEltTy _ _ _ -> error "not implemented" Pi _ -> printTypeOnly "function" TabPi _ -> brackets $ forEachTabElt atom \iOrd x -> do isFirst <- ieq iOrd (NatVal 0) @@ -94,7 +93,7 @@ showAnyRec atom = case getType atom of EffectRowKind -> printAsConstant -- hack to print strings nicely. TODO: make `Char` a newtype UserADTType "List" _ (TyConParams [Explicit] [Type Word8Ty]) -> do - charTab <- normalizeNaryProj [ProjectProduct 1, UnwrapNewtype] atom + charTab <- getProj 1 =<< unwrapNewtype atom emitCharLit '"' emitCharTab charTab emitCharLit '"' @@ -113,15 +112,16 @@ showAnyRec atom = case getType atom of rec =<< projectStruct i atom where showDataCon :: Emits n' => DataConDef n' -> CAtom n' -> Print n' - showDataCon (DataConDef sn _ _ projss) arg = do - case projss of - [] -> emitLit sn - _ -> parens do - emitLit (sn ++ " ") - sepBy " " $ projss <&> \projs -> - -- we use `init` to strip off the `UnwrapCompoundNewtype` since - -- we're already under the case alternative - rec =<< normalizeNaryProj (init projs) arg + showDataCon (DataConDef sn _ _ projss) arg = undefined + -- showDataCon (DataConDef sn _ _ projss) arg = do + -- case projss of + -- [] -> emitLit sn + -- _ -> parens do + -- emitLit (sn ++ " ") + -- sepBy " " $ projss <&> \projs -> + -- -- we use `init` to strip off the `UnwrapCompoundNewtype` since + -- -- we're already under the case alternative + -- rec =<< normalizeNaryProj (init projs) arg DepPairTy _ -> parens do (x, y) <- fromPair atom rec x >> emitLit " ,> " >> rec y @@ -200,7 +200,7 @@ stringLitAsCharTab s = do t <- finTabTyCore (NatVal $ fromIntegral $ length s) CharRepTy emitExpr $ TabCon Nothing t (map charRepVal s) -finTabTyCore :: (Fallible1 m, EnvReader m) => CAtom n -> CType n -> m n (CType n) +finTabTyCore :: CBuilderEmits m n => CAtom n -> CType n -> m n (CType n) finTabTyCore n eltTy = do d <- mkDictAtom $ IxFin n return $ IxType (FinTy n) (IxDictAtom d) ==> eltTy diff --git a/src/lib/Simplify.hs b/src/lib/Simplify.hs index b3be6487c..705aa6769 100644 --- a/src/lib/Simplify.hs +++ b/src/lib/Simplify.hs @@ -18,7 +18,6 @@ import Data.Maybe import Data.Text.Prettyprint.Doc (Pretty (..), hardline) import Builder -import CheapReduction import CheckType import Core import Err @@ -85,10 +84,6 @@ tryAsDataAtom atom = do DepPair x y ty -> do DepPairTy ty' <- getRepType $ DepPairTy ty DepPair <$> go x <*> go y <*> pure ty' - ProjectElt _ UnwrapNewtype x -> go x - -- TODO: do we need to think about a case like `fst (1, \x.x)`, where - -- the projection is data but the argument isn't? - ProjectElt _ (ProjectProduct i) x -> normalizeProj (ProjectProduct i) =<< go x NewtypeCon _ x -> go x SimpInCore x -> case x of LiftSimp _ x' -> return x' @@ -98,7 +93,6 @@ tryAsDataAtom atom = do Lam _ -> notData DictCon _ _ -> notData Eff _ -> notData - DictHole _ _ _ -> notData TypeAsAtom _ -> notData where notData = error $ "Not runtime-representable data: " ++ pprint atom @@ -124,21 +118,20 @@ fromNaryTabLam maxDepth = \case _ -> Nothing _ -> Nothing -forceACase :: Emits n => SAtom n -> [Abs SBinder CAtom n] -> CType n -> SimplifyM i n (SAtom n) +forceACase :: Emits n => SAtom n -> [Abs SBinderAndDecls CAtom n] -> CType n -> SimplifyM i n (SAtom n) forceACase scrut alts resultTy = do resultTy' <- getRepType resultTy buildCase scrut resultTy' \i arg -> do - Abs b result <- return $ alts !! i - applySubst (b@>SubstVal arg) result >>= toDataAtomIgnoreRecon + instantiate (alts !! i) [arg] >>= toDataAtomIgnoreRecon -tryGetRepType :: Type CoreIR n -> SimplifyM i n (Maybe (SType n)) +tryGetRepType :: Emits n => Type CoreIR n -> SimplifyM i n (Maybe (SType n)) tryGetRepType t = isData t >>= \case False -> return Nothing True -> Just <$> getRepType t -getRepType :: Type CoreIR n -> SimplifyM i n (SType n) +getRepType :: Emits n => Type CoreIR n -> SimplifyM i n (SType n) getRepType ty = go ty where - go :: Type CoreIR n -> SimplifyM i n (SType n) + go :: Emits n => Type CoreIR n -> SimplifyM i n (SType n) go = \case TC con -> TC <$> case con of BaseType b -> return $ BaseType b @@ -148,27 +141,49 @@ getRepType ty = go ty where TypeKind -> error $ notDataType HeapType -> return $ HeapType DepPairTy depPairTy@(DepPairType expl b _) -> do - l' <- go $ binderType b - withFreshBinder (getNameHint b) l' \b' -> do - x <- liftSimpAtom (sink $ binderType b) (Var $ binderVar b') - r' <- go =<< instantiate depPairTy [x] - return $ DepPairTy $ DepPairType expl (PlainBD b') r' + let bTy = binderType b + l' <- go bTy + withFreshBinderAndDecls (getNameHint b) l' \v -> do + x <- liftSimpAtom (sink bTy) (Var v) + r' <- go =<< liftCoreBuilder (instantiate depPairTy [x]) + return \b' -> DepPairTy $ DepPairType expl b' r' TabPi tabTy -> do let ixTy = tabIxType tabTy IxType t' d' <- simplifyIxType ixTy - withFreshBinder (getNameHint tabTy) t' \b' -> do - x <- liftSimpAtom (sink $ ixTypeType ixTy) (Var $ binderVar b') - bodyTy' <- go =<< instantiate tabTy [x] - return $ TabPi $ TabPiType d' (PlainBD b') bodyTy' + withFreshBinderAndDecls (getNameHint tabTy) t' \v -> do + x <- liftSimpAtom (sink $ ixTypeType ixTy) (Var v) + bodyTy' <- go =<< liftCoreBuilder (instantiate tabTy [x]) + return \b' -> TabPi $ TabPiType d' b' bodyTy' NewtypeTyCon con -> do - (_, ty') <- unwrapNewtypeType con + (_, ty') <- liftCoreBuilder $ unwrapNewtypeType con go ty' Pi _ -> error notDataType DictTy _ -> error notDataType TyVar _ -> error "Shouldn't have type variables in CoreIR IR with SimpIR builder names" - ProjectEltTy _ _ _ -> error "Shouldn't have this left" where notDataType = "Not a type of runtime-representable data: " ++ pprint ty +liftSimpAtom :: Emits n => Type CoreIR n -> SAtom n -> SimplifyM i n (CAtom n) +liftSimpAtom ty simpAtom = case simpAtom of + Var _ -> justLift + RepValAtom _ -> justLift -- TODO(dougalm): should we make more effort to pull out products etc? + _ -> do + (cons , ty') <- liftCoreBuilder $ unwrapLeadingNewtypesType ty + atom <- case (ty', simpAtom) of + (BaseTy _ , Con (Lit v)) -> return $ Con $ Lit v + (ProdTy tys, Con (ProdCon xs)) -> Con . ProdCon <$> zipWithM rec tys xs + (SumTy tys, Con (SumCon _ i x)) -> Con . SumCon tys i <$> rec (tys!!i) x + (DepPairTy dpt, DepPair x1 x2 _) -> do + x1' <- rec (depPairLeftTy dpt) x1 + t2' <- liftCoreBuilder $ instantiate dpt [x1'] + x2' <- rec t2' x2 + return $ DepPair x1' x2' dpt + _ -> error $ "can't lift " <> pprint simpAtom <> " to " <> pprint ty' + return $ wrapNewtypesData cons atom + where + rec = liftSimpAtom + justLift = return $ SimpInCore $ LiftSimp ty simpAtom +{-# INLINE liftSimpAtom #-} + toDataAtom :: Emits n => CAtom n -> SimplifyM i n (SAtom n, Type CoreIR n) toDataAtom x = tryAsDataAtom x >>= \case Just x' -> return x' @@ -187,19 +202,24 @@ toDataAtomIgnoreReconAssumeNoDecls x = do Empty -> return result _ -> error "unexpected decls" -withSimplifiedBinders - :: Binders CoreIR o any - -> (forall o'. DExt o o' => Binders SimpIR o o' -> [CAtom o'] -> SimplifyM i o' a) +simplifyBinders + :: Emits o => Binders CoreIR i i' + -> (forall o'. DExt o o' => [CAtom o'] -> SimplifyM i' o' (Binders SimpIR o o' -> a)) -> SimplifyM i o a -withSimplifiedBinders Empty cont = getDistinct >>= \Distinct -> cont Empty [] -withSimplifiedBinders (Nest (BD (bCore:>ty)) bsCore) cont = do - simpTy <- getRepType ty - withFreshBinder (getNameHint bCore) simpTy \bSimp -> do - x <- liftSimpAtom (sink ty) (Var $ binderVar bSimp) - -- TODO: carry a substitution instead of doing N^2 work like this - Abs bsCore' UnitE <- applySubst (bCore@>SubstVal x) (EmptyAbs bsCore) - withSimplifiedBinders bsCore' \bsSimp xs -> - cont (Nest (BD bSimp) bsSimp) (sink x:xs) +simplifyBinders bs cont = case bs of + Empty -> do + Distinct <- getDistinct + f <- cont [] + return $ f Empty + Nest b bs -> do + cTy <- substM $ binderType b + sTy <- getRepType cTy + withFreshBinderAndDecls (getNameHint b) sTy \sVar -> do + x <- liftSimpAtom (sink cTy) (Var sVar) + extendSubstBD b [SubstVal x] do + simplifyBinders bs \xs -> do + f <- cont (sink x:xs) + return \sBinders -> \sBinder -> f (Nest sBinder sBinders) -- === Reconstructions === @@ -207,7 +227,7 @@ data ReconstructAtom (n::S) = CoerceRecon (Type CoreIR n) | LamRecon (ReconAbs SimpIR CAtom n) -applyRecon :: (EnvReader m, Fallible1 m) => ReconstructAtom n -> SAtom n -> m n (CAtom n) +applyRecon :: Emits n => ReconstructAtom n -> SAtom n -> SimplifyM i n (CAtom n) applyRecon (CoerceRecon ty) x = liftSimpAtom ty x applyRecon (LamRecon ab) x = applyReconAbs ab x @@ -235,6 +255,9 @@ liftSimplifyM cont = do liftDoubleBuilderToSimplifyM :: DoubleBuilder SimpIR o a -> SimplifyM i o a liftDoubleBuilderToSimplifyM cont = SimplifyM $ liftSubstReaderT cont +liftCoreBuilder :: Emits o => BuilderM CoreIR o a -> SimplifyM i o a +liftCoreBuilder = undefined + instance Simplifier SimplifyM deriving instance ScopableBuilder SimpIR (SimplifyM i) @@ -253,15 +276,16 @@ simplifyTopBlock _ = error "not a block (nullary lambda)" {-# SCC simplifyTopBlock #-} simplifyTopFunction :: (TopBuilder m, Mut n) => CTopLam n -> m n (STopLam n) -simplifyTopFunction (TopLam False _ f) = do - asTopLam =<< liftSimplifyM do - (lam, CoerceReconAbs) <- simplifyLam f - return lam -simplifyTopFunction _ = error "shouldn't be in destination-passing style already" -{-# SCC simplifyTopFunction #-} +simplifyTopFunction (TopLam False _ f) = undefined +-- simplifyTopFunction (TopLam False _ f) = do +-- asTopLam =<< liftSimplifyM do +-- (lam, CoerceReconAbs) <- simplifyLam f +-- return lam +-- simplifyTopFunction _ = error "shouldn't be in destination-passing style already" +-- {-# SCC simplifyTopFunction #-} applyReconTop :: (EnvReader m, Fallible1 m) => ReconstructAtom n -> SAtom n -> m n (CAtom n) -applyReconTop = applyRecon +applyReconTop = undefined -- applyRecon instance GenericE SimplifiedBlock where type RepE SimplifiedBlock = PairE SBlock ReconstructAtom @@ -369,20 +393,19 @@ defuncCaseCore :: Emits o -> SimplifyM i o (CAtom o) defuncCaseCore scrut resultTy cont = do tryAsDataAtom scrut >>= \case - Just (scrutSimp, _) -> do - altBinderTys <- caseAltsBinderTys $ getType scrut - defuncCase scrutSimp resultTy \i x -> do - let xCoreTy = altBinderTys !! i - x' <- liftSimpAtom (sink xCoreTy) x - cont i x' + -- Just (scrutSimp, _) -> do + -- altBinderTys <- caseAltsBinderTys $ getType scrut + -- defuncCase scrutSimp resultTy \i x -> do + -- let xCoreTy = altBinderTys !! i + -- x' <- liftSimpAtom (sink xCoreTy) x + -- cont i x' Nothing -> case trySelectBranch scrut of Just (i, arg) -> getDistinct >>= \Distinct -> cont i arg Nothing -> go scrut where go = \case SimpInCore (ACase scrutSimp alts _) -> do defuncCase scrutSimp resultTy \i x -> do - Abs altb altAtom <- return $ alts !! i - altAtom' <- applySubst (altb @> SubstVal x) altAtom + altAtom' <- instantiate (alts !! i) [x] cont i altAtom' NewtypeCon con scrut' | isSumCon con -> go scrut' _ -> nope @@ -416,7 +439,9 @@ defuncCase scrut resultTy cont = do caseResult <- emitExpr $ caseExpr (dataVal, sumVal) <- fromPair caseResult reconAlts <- forM (zip closureTys recons) \(ty, recon) -> - buildAbs noHint ty \v -> applyRecon (sink recon) (Var v) + withFreshBinderAndDecls noHint ty \v -> do + result <- applyRecon (sink recon) (Var v) + return \b -> Abs b result let nonDataVal = SimpInCore $ ACase sumVal reconAlts newNonDataTy Distinct <- getDistinct fromSplit split dataVal nonDataVal @@ -426,21 +451,22 @@ simplifyAlt -> SType o -> (forall o'. (Emits o', DExt o o') => SAtom o' -> SimplifyM i o' (CAtom o')) -> SimplifyM i o (Alt SimpIR o, SType o, ReconstructAtom o) -simplifyAlt split ty cont = do - withFreshBinder noHint ty \b -> do - ab <- buildScoped $ cont $ sink $ Var $ binderVar b - (body, recon) <- refreshAbs ab \decls result -> do - let locals = toScopeFrag b >>> toScopeFrag decls - -- TODO: this might be too cautious. The type only needs to - -- be hoistable above the decls. In principle it can still - -- mention vars from the lambda binders. - Distinct <- getDistinct - (resultData, resultNonData) <- toSplit split result - (newResult, reconAbs) <- telescopicCapture locals resultNonData - return (Abs decls (PairVal resultData newResult), LamRecon reconAbs) - EffTy _ (PairTy _ nonDataType) <- blockEffTy body - let nonDataType' = ignoreHoistFailure $ hoist b nonDataType - return (Abs b body, nonDataType', recon) +simplifyAlt split ty cont = undefined +-- simplifyAlt split ty cont = do +-- withFreshBinder noHint ty \b -> do +-- ab <- buildScoped $ cont $ sink $ binderAtom b +-- (body, recon) <- refreshAbs ab \decls result -> do +-- let locals = toScopeFrag b >>> toScopeFrag decls +-- -- TODO: this might be too cautious. The type only needs to +-- -- be hoistable above the decls. In principle it can still +-- -- mention vars from the lambda binders. +-- Distinct <- getDistinct +-- (resultData, resultNonData) <- toSplit split result +-- (newResult, reconAbs) <- telescopicCapture locals resultNonData +-- return (Abs decls (PairVal resultData newResult), LamRecon reconAbs) +-- EffTy _ (PairTy _ nonDataType) <- blockEffTy body +-- let nonDataType' = ignoreHoistFailure $ hoist b nonDataType +-- return (Abs b body, nonDataType', recon) simplifyApp :: forall i o. Emits o => NameHint -> CType o -> CAtom i -> [CAtom o] -> SimplifyM i o (CAtom o) @@ -457,7 +483,7 @@ simplifyApp hint resultTy f xs = case f of SimpInCore (ACase e alts _) -> dropSubst do defuncCase e resultTy \i x -> do Abs b body <- return $ alts !! i - extendSubst (b@>SubstVal x) do + extendSubstBD b [SubstVal x] do xs' <- mapM sinkM xs simplifyApp hint (sink resultTy) body xs' SimpInCore (LiftSimpFun _ lam) -> do @@ -485,12 +511,6 @@ simplifyAtomAndInline atom = confuseGHC >>= \_ -> case atom of Rename v' -> doInline =<< toAtomVar v' SubstVal (Var v') -> doInline v' SubstVal x -> return x - -- This is a hack because we weren't normalize the unwrapping of - -- `unit_type_scale` in `plot.dx`. We need a better system for deciding how to - -- normalize and inline. - ProjectElt _ i x -> do - x' <- simplifyAtom x >>= normalizeProj i - dropSubst $ simplifyAtomAndInline x' _ -> simplifyAtom atom >>= \case Var v -> doInline v ans -> return ans @@ -554,23 +574,23 @@ simplifyTabApp f@(SimpInCore sic) xs = case sic of atom <- emitDecls block' simplifyTabApp atom xsRest Nothing -> error "should never happen" - ACase e alts ty -> dropSubst do - resultTy <- typeOfTabApp ty xs - defuncCase e resultTy \i x -> do - Abs b body <- return $ alts !! i - extendSubst (b@>SubstVal x) do - xs' <- mapM sinkM xs - body' <- substM body - simplifyTabApp body' xs' - LiftSimp _ f' -> do - fTy <- return $ getType f - resultTy <- typeOfTabApp fTy xs - xs' <- mapM toDataAtomIgnoreRecon xs - liftSimpAtom resultTy =<< naryTabApp f' xs' + -- ACase e alts ty -> dropSubst do + -- resultTy <- typeOfTabApp ty xs + -- defuncCase e resultTy \i x -> do + -- Abs b body <- return $ alts !! i + -- extendSubst (b@>SubstVal x) do + -- xs' <- mapM sinkM xs + -- body' <- substM body + -- simplifyTabApp body' xs' + -- LiftSimp _ f' -> do + -- fTy <- return $ getType f + -- resultTy <- typeOfTabApp fTy xs + -- xs' <- mapM toDataAtomIgnoreRecon xs + -- liftSimpAtom resultTy =<< naryTabApp f' xs' LiftSimpFun _ _ -> error "not implemented" simplifyTabApp f _ = error $ "Unexpected table: " ++ pprint f -simplifyIxType :: IxType CoreIR o -> SimplifyM i o (IxType SimpIR o) +simplifyIxType :: Emits o => IxType CoreIR o -> SimplifyM i o (IxType SimpIR o) simplifyIxType (IxType t ixDict) = do t' <- getRepType t IxType t' <$> case ixDict of @@ -578,7 +598,7 @@ simplifyIxType (IxType t ixDict) = do n' <- toDataAtomIgnoreReconAssumeNoDecls n return $ IxDictRawFin n' IxDictAtom d -> do - (dictAbs, params) <- generalizeIxDict =<< cheapNormalize d + (dictAbs, params) <- generalizeIxDict d params' <- mapM toDataAtomIgnoreReconAssumeNoDecls params sdName <- requireIxDictCache dictAbs return $ IxDictSpecialized t' sdName params' @@ -629,10 +649,8 @@ simplifyAtom atom = confuseGHC >>= \_ -> case atom of Con con -> Con <$> traverseOp con substM simplifyAtom (error "unexpected lambda") Eff eff -> Eff <$> substM eff PtrVar t v -> PtrVar t <$> substM v - DictCon t d -> (DictCon <$> substM t <*> substM d) >>= cheapNormalize - DictHole _ _ _ -> error "shouldn't have dict holes past inference" + DictCon t d -> DictCon <$> substM t <*> substM d NewtypeCon _ _ -> substM atom - ProjectElt _ i x -> normalizeProj i =<< simplifyAtom x SimpInCore _ -> substM atom TypeAsAtom _ -> substM atom @@ -653,29 +671,21 @@ simplifyVar v = do -- Assumes first order (args/results are "data", allowing newtypes), monormophic simplifyLam - :: LamExpr CoreIR i - -> SimplifyM i o (LamExpr SimpIR o, Abs (Nest (AtomNameBinder SimpIR)) ReconstructAtom o) -simplifyLam (LamExpr bsTop body) = case bsTop of - Nest b bs -> do - ty' <- substM $ binderType b - tySimp <- getRepType ty' - withFreshBinder (getNameHint b) tySimp \b''@(b':>_) -> do - x <- liftSimpAtom (sink ty') (Var $ binderVar b'') - extendSubstBD b [SubstVal x] do - (LamExpr bs' body', Abs bsRecon recon) <- simplifyLam $ LamExpr bs body - return (LamExpr (Nest (PlainBD (b':>tySimp)) bs') body', Abs (Nest b' bsRecon) recon) - Empty -> do + :: Emits o => LamExpr CoreIR i + -> SimplifyM i o (LamExpr SimpIR o, Abs SBinders ReconstructAtom o) +simplifyLam (LamExpr bs body) = + simplifyBinders bs \_ -> do SimplifiedBlock body' recon <- buildSimplifiedBlock $ simplifyBlock body - return (LamExpr Empty body', Abs Empty recon) + return \bs' -> (LamExpr bs' body', Abs bs' recon) data SplitDataNonData n = SplitDataNonData { dataTy :: Type SimpIR n , nonDataTy :: Type CoreIR n , toSplit :: forall i l . CAtom l -> SimplifyM i l (SAtom l, CAtom l) - , fromSplit :: forall i l . DExt n l => SAtom l -> CAtom l -> SimplifyM i l (CAtom l) } + , fromSplit :: forall i l . (Emits l, DExt n l) => SAtom l -> CAtom l -> SimplifyM i l (CAtom l) } -- bijection between that type and a (data, non-data) pair type. -splitDataComponents :: Type CoreIR n -> SimplifyM i n (SplitDataNonData n) +splitDataComponents :: Emits n => Type CoreIR n -> SimplifyM i n (SplitDataNonData n) splitDataComponents = \case ProdTy tys -> do splits <- mapM splitDataComponents tys @@ -717,8 +727,8 @@ buildSimplifiedBlock cont = do case eitherResult of LeftE ans -> do (block, recon) <- refreshAbs (Abs decls ans) \decls' ans' -> do - (newResult, reconAbs) <- telescopicCapture (toScopeFrag decls') ans' - return (Abs decls' newResult, LamRecon reconAbs) + (Abs decls'' newResult, reconAbs) <- telescopicCapture (toScopeFrag decls') ans' + return (Abs (decls' >>> decls'') newResult, LamRecon reconAbs) return $ SimplifiedBlock block recon RightE (ans `PairE` ty) -> do let ty' = ignoreHoistFailure $ hoist (toScopeFrag decls) ty @@ -783,7 +793,7 @@ pattern CoerceReconAbs <- Abs _ (CoerceRecon _) applyDictMethod :: Emits o => CType o -> CAtom o -> Int -> [CAtom o] -> SimplifyM i o (CAtom o) applyDictMethod resultTy d i methodArgs = do - cheapNormalize d >>= \case + case d of DictCon _ (InstanceDict instanceName instanceArgs) -> dropSubst do instanceArgs' <- mapM simplifyAtom instanceArgs instanceDef <- lookupInstanceDef instanceName @@ -811,13 +821,15 @@ simplifyHof _hint resultTy = \case ans <- emitHof $ For d ixTypeSimp lam' case recon of CoerceRecon _ -> liftSimpAtom resultTy ans - LamRecon (Abs bsClosure reconResult) -> do + LamRecon (Abs bsClosure reconResult) -> dropSubst do TabPi resultTabTy <- return resultTy liftM (SimpInCore . TabLam resultTabTy) $ PairE ixTypeSimp <$> buildAbs noHint (ixTypeType ixTypeSimp) \i -> buildScoped do i' <- sinkM i - xs <- unpackTelescope bsClosure =<< tabApp (sink ans) (Var i') - applySubst (bIx@>Rename (atomVarName i') <.> bsClosure @@> map SubstVal xs) reconResult + extendSubstBD bIx [Rename (atomVarName i')] do + xs <- unpackTelescope bsClosure =<< tabApp (sink ans) (Var i') + extendSubst (bsClosure @@> map SubstVal xs) $ + substM reconResult While body -> do SimplifiedBlock body' (CoerceRecon _) <- buildSimplifiedBlock $ simplifyBlock body result <- emitHof $ While body' @@ -895,11 +907,13 @@ fmapMaybe fmapMaybe scrut f = do ~(MaybeTy justTy) <- return $ getType scrut (justAlt, resultJustTy) <- withFreshBinder noHint justTy \b -> do - result <- f (Var $ binderVar b) + result <- f (binderAtom b) resultTy <- return $ ignoreHoistFailure $ hoist b (getType result) result' <- preludeJustVal result - return (Abs b result', resultTy) - nothingAlt <- buildAbs noHint UnitTy \_ -> preludeNothingVal $ sink resultJustTy + return (Abs (PlainBD b) result', resultTy) + nothingAlt <- withFreshBinder noHint UnitTy \b -> do + result <- preludeNothingVal $ sink resultJustTy + return $ Abs (PlainBD b) result resultMaybeTy <- makePreludeMaybeTy resultJustTy return $ SimpInCore $ ACase scrut [nothingAlt, justAlt] resultMaybeTy @@ -970,84 +984,86 @@ type Linearized = Abs SBinders -- primal args simplifyCustomLinearization :: Abstracted CoreIR (ListE CAtom) n -> [Active] -> AtomRules n -> SimplifyM i n (PairE STopLam STopLam n) -simplifyCustomLinearization (Abs runtimeBs staticArgs) actives rule = do - CustomLinearize nImplicit nExplicit zeros fCustom <- return rule - linearized <- withSimplifiedBinders runtimeBs \runtimeBs' runtimeArgs -> do - Abs runtimeBs' <$> buildScoped do - ListE staticArgs' <- instantiate (Abs runtimeBs staticArgs) (sink <$> runtimeArgs) - fCustom' <- sinkM fCustom - resultTy <- typeOfApp (getType fCustom') staticArgs' - pairResult <- dropSubst $ simplifyApp noHint resultTy fCustom' staticArgs' - (primalResult, fLin) <- fromPair pairResult - primalResult' <- toDataAtomIgnoreRecon primalResult - let explicitPrimalArgs = drop nImplicit staticArgs' - allTangentTys <- forM explicitPrimalArgs \primalArg -> do - tangentType =<< getRepType (getType primalArg) - let actives' = drop (length actives - nExplicit) actives - activeTangentTys <- catMaybes <$> forM (zip allTangentTys actives') - \(t, active) -> return case active of True -> Just t; False -> Nothing - fLin' <- buildUnaryLamExpr "t" (ProdTy activeTangentTys) \activeTangentArg -> do - activeTangentArgs <- getUnpacked $ Var activeTangentArg - ListE allTangentTys' <- sinkM $ ListE allTangentTys - tangentArgs <- buildTangentArgs zeros (zip allTangentTys' actives') activeTangentArgs - -- TODO: we're throwing away core type information here. Once we - -- support core-level tangent types we should make an effort to - -- correctly restore the core types before applying `fLin`. Right now, - -- a custom linearization defined for a function on ADTs will - -- not work. - fLin' <- sinkM fLin - Pi (CorePiType _ _ bs _) <- return $ getType fLin' - let tangentCoreTys = fromNonDepNest bs - tangentArgs' <- zipWithM liftSimpAtom tangentCoreTys tangentArgs - resultTyTangent <- typeOfApp (getType fLin') tangentArgs' - tangentResult <- dropSubst $ simplifyApp noHint resultTyTangent fLin' tangentArgs' - toDataAtomIgnoreRecon tangentResult - return $ PairE primalResult' fLin' - PairE primalFun tangentFun <- defuncLinearized linearized - primalFun' <- asTopLam primalFun - tangentFun' <- asTopLam tangentFun - return $ PairE primalFun' tangentFun' - where - buildTangentArgs :: Emits n => SymbolicZeros -> [(SType n, Active)] -> [SAtom n] -> SimplifyM i n [SAtom n] - buildTangentArgs _ [] [] = return [] - buildTangentArgs zeros ((t, False):tys) activeArgs = do - inactiveArg <- case zeros of - SymbolicZeros -> symbolicTangentZero t - InstantiateZeros -> zeroAt t - rest <- buildTangentArgs zeros tys activeArgs - return $ inactiveArg:rest - buildTangentArgs zeros ((_, True):tys) (activeArg:activeArgs) = do - activeArg' <- case zeros of - SymbolicZeros -> symbolicTangentNonZero activeArg - InstantiateZeros -> return activeArg - rest <- buildTangentArgs zeros tys activeArgs - return $ activeArg':rest - buildTangentArgs _ _ _ = error "zip error" - - fromNonDepNest :: Nest CBinderAndDecls n l -> [CType n] - fromNonDepNest Empty = [] - fromNonDepNest (Nest b bs) = - case ignoreHoistFailure $ hoist b (Abs bs UnitE) of - Abs bs' UnitE -> binderType b : fromNonDepNest bs' - -defuncLinearized :: EnvReader m => Linearized n -> m n (PairE SLam SLam n) -defuncLinearized ab = liftBuilder $ refreshAbs ab \bs ab' -> do - (declsAndResult, reconAbs, residualsTangentsBs) <- - refreshAbs ab' \decls (PairE primalResult fLin) -> do - (residuals, reconAbs) <- telescopicCapture (toScopeFrag decls) fLin - let rTy = getType residuals - LamExpr tBs _ <- return fLin - residualsTangentsBs <- withFreshBinder "residual" rTy \rB -> do - Abs tBs' UnitE <- sinkM $ Abs tBs UnitE - return $ Abs (Nest (BD rB) tBs') UnitE - residualsTangentsBs' <- return $ ignoreHoistFailure $ hoist decls residualsTangentsBs - return (Abs decls (PairVal primalResult residuals), reconAbs, residualsTangentsBs') - let primalFun = LamExpr bs declsAndResult - LamExpr residualAndTangentBs tangentBody <- buildLamExpr residualsTangentsBs \(residuals:tangents) -> do - lam <- applyReconAbs (sink reconAbs) (Var residuals) - instantiate lam (Var <$> tangents) >>= emitBlock - let tangentFun = LamExpr (bs >>> residualAndTangentBs) tangentBody - return $ PairE primalFun tangentFun +simplifyCustomLinearization _ = undefined -- (Abs runtimeBs staticArgs) actives rule = undefined +-- simplifyCustomLinearization (Abs runtimeBs staticArgs) actives rule = do +-- CustomLinearize nImplicit nExplicit zeros fCustom <- return rule +-- linearized <- simplifyBinders runtimeBs \runtimeBs' runtimeArgs -> do +-- Abs runtimeBs' <$> buildScoped do +-- ListE staticArgs' <- instantiate (Abs runtimeBs staticArgs) (sink <$> runtimeArgs) +-- fCustom' <- sinkM fCustom +-- resultTy <- typeOfApp (getType fCustom') staticArgs' +-- pairResult <- dropSubst $ simplifyApp noHint resultTy fCustom' staticArgs' +-- (primalResult, fLin) <- fromPair pairResult +-- primalResult' <- toDataAtomIgnoreRecon primalResult +-- let explicitPrimalArgs = drop nImplicit staticArgs' +-- allTangentTys <- forM explicitPrimalArgs \primalArg -> do +-- tangentType =<< getRepType (getType primalArg) +-- let actives' = drop (length actives - nExplicit) actives +-- activeTangentTys <- catMaybes <$> forM (zip allTangentTys actives') +-- \(t, active) -> return case active of True -> Just t; False -> Nothing +-- fLin' <- buildUnaryLamExpr "t" (ProdTy activeTangentTys) \activeTangentArg -> do +-- activeTangentArgs <- getUnpacked $ Var activeTangentArg +-- ListE allTangentTys' <- sinkM $ ListE allTangentTys +-- tangentArgs <- buildTangentArgs zeros (zip allTangentTys' actives') activeTangentArgs +-- -- TODO: we're throwing away core type information here. Once we +-- -- support core-level tangent types we should make an effort to +-- -- correctly restore the core types before applying `fLin`. Right now, +-- -- a custom linearization defined for a function on ADTs will +-- -- not work. +-- fLin' <- sinkM fLin +-- Pi (CorePiType _ _ bs _) <- return $ getType fLin' +-- let tangentCoreTys = fromNonDepNest bs +-- tangentArgs' <- zipWithM liftSimpAtom tangentCoreTys tangentArgs +-- resultTyTangent <- typeOfApp (getType fLin') tangentArgs' +-- tangentResult <- dropSubst $ simplifyApp noHint resultTyTangent fLin' tangentArgs' +-- toDataAtomIgnoreRecon tangentResult +-- return $ PairE primalResult' fLin' +-- PairE primalFun tangentFun <- defuncLinearized linearized +-- primalFun' <- asTopLam primalFun +-- tangentFun' <- asTopLam tangentFun +-- return $ PairE primalFun' tangentFun' +-- where +-- buildTangentArgs :: Emits n => SymbolicZeros -> [(SType n, Active)] -> [SAtom n] -> SimplifyM i n [SAtom n] +-- buildTangentArgs _ [] [] = return [] +-- buildTangentArgs zeros ((t, False):tys) activeArgs = do +-- inactiveArg <- case zeros of +-- SymbolicZeros -> symbolicTangentZero t +-- InstantiateZeros -> zeroAt t +-- rest <- buildTangentArgs zeros tys activeArgs +-- return $ inactiveArg:rest +-- buildTangentArgs zeros ((_, True):tys) (activeArg:activeArgs) = do +-- activeArg' <- case zeros of +-- SymbolicZeros -> symbolicTangentNonZero activeArg +-- InstantiateZeros -> return activeArg +-- rest <- buildTangentArgs zeros tys activeArgs +-- return $ activeArg':rest +-- buildTangentArgs _ _ _ = error "zip error" + +-- fromNonDepNest :: Nest CBinderAndDecls n l -> [CType n] +-- fromNonDepNest Empty = [] +-- fromNonDepNest (Nest b bs) = +-- case ignoreHoistFailure $ hoist b (Abs bs UnitE) of +-- Abs bs' UnitE -> binderType b : fromNonDepNest bs' + +_defuncLinearized :: EnvReader m => Linearized n -> m n (PairE SLam SLam n) +_defuncLinearized _ = undefined +-- defuncLinearized ab = liftBuilder $ refreshAbs ab \bs ab' -> do +-- (declsAndResult, reconAbs, residualsTangentsBs) <- +-- refreshAbs ab' \decls (PairE primalResult fLin) -> do +-- (residuals, reconAbs) <- telescopicCapture (toScopeFrag decls) fLin +-- let rTy = getType residuals +-- LamExpr tBs _ <- return fLin +-- residualsTangentsBs <- withFreshBinder "residual" rTy \rB -> do +-- Abs tBs' UnitE <- sinkM $ Abs tBs UnitE +-- return $ Abs (Nest (BD rB) tBs') UnitE +-- residualsTangentsBs' <- return $ ignoreHoistFailure $ hoist decls residualsTangentsBs +-- return (Abs decls (PairVal primalResult residuals), reconAbs, residualsTangentsBs') +-- let primalFun = LamExpr bs declsAndResult +-- LamExpr residualAndTangentBs tangentBody <- buildLamExpr residualsTangentsBs \(residuals:tangents) -> do +-- lam <- applyReconAbs (sink reconAbs) (Var residuals) +-- instantiate lam (Var <$> tangents) >>= emitBlock +-- let tangentFun = LamExpr (bs >>> residualAndTangentBs) tangentBody +-- return $ PairE primalFun tangentFun -- === exception-handling pass === diff --git a/src/lib/Subst.hs b/src/lib/Subst.hs index 5b13ef624..5de201e8d 100644 --- a/src/lib/Subst.hs +++ b/src/lib/Subst.hs @@ -14,6 +14,7 @@ import Control.Applicative import Control.Monad.Identity import Control.Monad.Reader import Control.Monad.State.Strict +import Data.Functor ((<&>)) import Name import IRVariants @@ -21,6 +22,8 @@ import Types.Core import Core import qualified RawName as R import Err +import Visitor +import Types.Imp -- === SubstReader class === @@ -245,10 +248,6 @@ data SubstVal (atom::IR->E) (c::C) (n::S) where Rename :: Name c n -> SubstVal atom c n type AtomSubstVal = SubstVal Atom -type family IsAtomName (c::C) where - IsAtomName (AtomNameC r) = True - IsAtomName _ = False - instance (Color c, IsAtomName c ~ False) => SubstE (SubstVal atom) (Name c) where substE (_, env) v = case env ! v of Rename v' -> v' @@ -498,3 +497,144 @@ instance (SubstE v e0, SubstE v e1, SubstE v e2, Case6 e -> Case6 $ substE env e Case7 e -> Case7 $ substE env e {-# INLINE substE #-} + +newtype SubstVisitor i o a = SubstVisitor { runSubstVisitor :: Reader (Env o, Subst AtomSubstVal i o) a } + deriving (Functor, Applicative, Monad, MonadReader (Env o, Subst AtomSubstVal i o)) + +substV :: (Distinct o, SubstE AtomSubstVal e) => e i -> SubstVisitor i o (e o) +substV x = ask <&> \env -> substE env x + +instance Distinct o => NonAtomRenamer (SubstVisitor i o) i o where + renameN = substV + +instance (Distinct o, IRRep r) => Visitor (SubstVisitor i o) r i o where + visitType = substV + visitAtom = substV + visitLam = substV + visitPi = substV + +instance Color c => SubstE AtomSubstVal (AtomSubstVal c) where + substE (_, env) (Rename name) = env ! name + substE env (SubstVal val) = SubstVal $ substE env val + +instance SubstV (SubstVal Atom) (SubstVal Atom) where + +instance IRRep r => SubstE AtomSubstVal (Atom r) where + substE es@(_, subst) = \case + Var (AtomVar v ty) -> case subst!v of + Rename v' -> Var $ AtomVar v' (substE es ty) + SubstVal x -> x + SimpInCore x -> SimpInCore (substE es x) + atom -> runReader (runSubstVisitor $ visitAtomPartial atom) es + +instance IRRep r => SubstE AtomSubstVal (Type r) where + substE es@(_, subst) = \case + TyVar (AtomVar v ty) -> case subst ! v of + Rename v' -> TyVar $ AtomVar v' (substE es ty) + SubstVal (Type t) -> t + SubstVal atom -> error $ "bad substitution: " ++ pprint v ++ " -> " ++ pprint atom + ty -> runReader (runSubstVisitor $ visitTypePartial ty) es + +instance SubstE AtomSubstVal SimpInCore + +instance IRRep r => SubstE AtomSubstVal (EffectRow r) where + substE env (EffectRow effs tailVar) = do + let effs' = eSetFromList $ map (substE env) (eSetToList effs) + let tailEffRow = case tailVar of + NoTail -> EffectRow mempty NoTail + EffectRowTail (AtomVar v _) -> case snd env ! v of + Rename v' -> do + let v'' = runEnvReaderM (fst env) $ toAtomVar v' + EffectRow mempty (EffectRowTail v'') + SubstVal (Var v') -> EffectRow mempty (EffectRowTail v') + SubstVal (Eff r) -> r + _ -> error "Not a valid effect substitution" + extendEffRow effs' tailEffRow + +instance IRRep r => SubstE AtomSubstVal (Effect r) + +instance SubstE AtomSubstVal SpecializationSpec where + substE env (AppSpecialization (AtomVar f _) ab) = do + let f' = case snd env ! f of + Rename v -> runEnvReaderM (fst env) $ toAtomVar v + SubstVal (Var v) -> v + _ -> error "bad substitution" + AppSpecialization f' (substE env ab) + +instance SubstE AtomSubstVal EffectDef +instance SubstE AtomSubstVal EffectOpType +instance SubstE AtomSubstVal IExpr +instance IRRep r => SubstE AtomSubstVal (RepVal r) +instance SubstE AtomSubstVal TyConParams +instance SubstE AtomSubstVal DataConDef +instance IRRep r => SubstE AtomSubstVal (BaseMonoid r) +instance IRRep r => SubstE AtomSubstVal (DAMOp r) +instance IRRep r => SubstE AtomSubstVal (TypedHof r) +instance IRRep r => SubstE AtomSubstVal (Hof r) +instance IRRep r => SubstE AtomSubstVal (TC r) +instance IRRep r => SubstE AtomSubstVal (Con r) +instance IRRep r => SubstE AtomSubstVal (MiscOp r) +instance IRRep r => SubstE AtomSubstVal (VectorOp r) +instance IRRep r => SubstE AtomSubstVal (MemOp r) +instance IRRep r => SubstE AtomSubstVal (PrimOp r) +instance IRRep r => SubstE AtomSubstVal (RefOp r) +instance IRRep r => SubstE AtomSubstVal (EffTy r) +instance IRRep r => SubstE AtomSubstVal (Expr r) +instance IRRep r => SubstE AtomSubstVal (GenericOpRep const r) +instance SubstE AtomSubstVal InstanceBody +instance SubstE AtomSubstVal DictType +instance SubstE AtomSubstVal DictExpr +instance IRRep r => SubstE AtomSubstVal (LamExpr r) +instance SubstE AtomSubstVal CorePiType +instance SubstE AtomSubstVal CoreLamExpr +instance IRRep r => SubstE AtomSubstVal (TabPiType r) +instance IRRep r => SubstE AtomSubstVal (PiType r) +instance IRRep r => SubstE AtomSubstVal (DepPairType r) +instance SubstE AtomSubstVal SolverBinding +instance IRRep r => SubstE AtomSubstVal (DeclBinding r) +instance IRRep r => SubstB AtomSubstVal (Decl r) +instance IRRep r => SubstB AtomSubstVal (BinderAndDecls r) +instance SubstE AtomSubstVal NewtypeTyCon +instance SubstE AtomSubstVal NewtypeCon +instance IRRep r => SubstE AtomSubstVal (IxDict r) +instance IRRep r => SubstE AtomSubstVal (IxType r) +instance SubstE AtomSubstVal DataConDefs + +-- === defaults visitors based on substitution === + +visitAtomDefault + :: (IRRep r, Visitor (m i o) r i o, AtomSubstReader v m, EnvReader2 m) + => Atom r i -> m i o (Atom r o) +visitAtomDefault atom = case atom of + Var _ -> atomSubstM atom + SimpInCore _ -> atomSubstM atom + _ -> visitAtomPartial atom + +visitTypeDefault + :: (IRRep r, Visitor (m i o) r i o, AtomSubstReader v m, EnvReader2 m) + => Type r i -> m i o (Type r o) +visitTypeDefault = \case + TyVar v -> atomSubstM $ TyVar v + x -> visitTypePartial x + +visitPiDefault + :: (Visitor2 m r, IRRep r, FromName v, AtomSubstReader v m, EnvExtender2 m) + => PiType r i -> m i o (PiType r o) +visitPiDefault (PiType bs effty) = do + visitBinders bs \bs' -> do + effty' <- visitGeneric effty + return $ PiType bs' effty' + +visitBinders + :: (Visitor2 m r, IRRep r, FromName v, AtomSubstReader v m, EnvExtender2 m) + => Binders r i i' + -> (forall o'. DExt o o' => Binders r o o' -> m i' o' a) + -> m i o a +visitBinders Empty cont = getDistinct >>= \Distinct -> cont Empty +visitBinders _ _ = undefined +-- visitBinders (Nest (BD (b:>ty)) bs) cont = do +-- ty' <- visitType ty +-- withFreshBinder (getNameHint b) ty' \b' -> do +-- extendRenamer (b@>binderName b') do +-- visitBinders bs \bs' -> +-- cont $ Nest (BD b') bs' diff --git a/src/lib/TopLevel.hs b/src/lib/TopLevel.hs index caf3d591c..e6eb208cb 100644 --- a/src/lib/TopLevel.hs +++ b/src/lib/TopLevel.hs @@ -49,7 +49,6 @@ import CheckType (checkTypes) #endif import Core import ConcreteSyntax -import CheapReduction import Err import IRVariants import Imp @@ -277,10 +276,10 @@ evalSourceBlock' mname block = case sbContents block of -- "Can't export functions with captured pointers (not implemented)." -- _ -> return $ Con $ Lit val -- logTop $ ExportedFun name f - GetType -> do -- TODO: don't actually evaluate it - val <- evalUExpr expr - ty <- cheapNormalize $ getType val - logTop $ TextOut $ pprintCanonicalized ty + -- GetType -> do -- TODO: don't actually evaluate it + -- val <- evalUExpr expr + -- ty <- cheapNormalize $ getType val + -- logTop $ TextOut $ pprintCanonicalized ty DeclareForeign fname dexName cTy -> do let b = fromString dexName :: UBinder (AtomNameC CoreIR) VoidS VoidS ty <- evalUType =<< parseExpr cTy @@ -328,9 +327,14 @@ evalSourceBlock' mname block = case sbContents block of $ "Custom linearization can only be defined for functions" UnParseable _ s -> throw ParseErr s Misc m -> case m of - GetNameType v -> do - ty <- cheapNormalize =<< sourceNameType v - logTop $ TextOut $ pprintCanonicalized ty + -- GetNameType v -> do + -- ty <- cheapNormalize =<< sourceNameType v + -- logTop $ TextOut $ pprintCanonicalized ty + -- sourceNameType :: (EnvReader m, Fallible1 m) => SourceName -> m n (Type CoreIR n) + -- sourceNameType v = do + -- lookupSourceMap v >>= \case + -- Nothing -> throw UnboundVarErr $ pprint v + -- Just uvar -> getUVarType uvar ImportModule moduleName -> importModule moduleName QueryEnv query -> void $ runEnvQuery query $> UnitE ProseBlock _ -> return () diff --git a/src/lib/Transpose.hs b/src/lib/Transpose.hs index e75e5cd4c..474e82a07 100644 --- a/src/lib/Transpose.hs +++ b/src/lib/Transpose.hs @@ -15,7 +15,6 @@ import GHC.Stack import Builder import Core -import CheapReduction import Err import Imp import IRVariants @@ -43,22 +42,23 @@ runTransposeM cont = runReaderT1 (ListE []) $ runSubstReaderT idSubst $ cont transposeTopFun :: (MonadFail1 m, EnvReader m) => STopLam n -> m n (STopLam n) -transposeTopFun (TopLam False _ lam) = liftBuilder $ runTransposeM do - (Abs bsNonlin (Abs bLin body), outTyAbs) <- unpackLinearLamExpr lam - refreshBinders bsNonlin \bsNonlin' substFrag -> extendRenamer substFrag do - outTy <- instantiate outTyAbs (Var <$> bindersVars bsNonlin') - withFreshBinder "ct" outTy \bCT -> do - let ct = Var $ binderVar bCT - body' <- buildBlock do - inTy <- substNonlin $ binderType bLin - withAccumulator inTy \refSubstVal -> - extendSubstBD bLin [refSubstVal] $ - transposeBlock body (sink ct) - EffTy _ bodyTy <- blockEffTy body' - let piTy = PiType (bsNonlin' >>> UnaryNest (PlainBD bCT)) (EffTy Pure bodyTy) - let lamT = LamExpr (bsNonlin' >>> UnaryNest (PlainBD bCT)) body' - return $ TopLam False piTy lamT -transposeTopFun (TopLam True _ _) = error "shouldn't be transposing in destination passing style" +transposeTopFun (TopLam False _ lam) = undefined +-- transposeTopFun (TopLam False _ lam) = liftBuilder $ runTransposeM do +-- (Abs bsNonlin (Abs bLin body), outTyAbs) <- unpackLinearLamExpr lam +-- refreshBinders bsNonlin \bsNonlin' substFrag -> extendRenamer substFrag do +-- outTy <- instantiate outTyAbs (Var <$> bindersVars bsNonlin') +-- withFreshBinder "ct" outTy \bCT -> do +-- let ct = Var $ binderVar bCT +-- body' <- buildBlock do +-- inTy <- substNonlin $ binderType bLin +-- withAccumulator inTy \refSubstVal -> +-- extendSubstBD bLin [refSubstVal] $ +-- transposeBlock body (sink ct) +-- EffTy _ bodyTy <- blockEffTy body' +-- let piTy = PiType (bsNonlin' >>> UnaryNest (PlainBD bCT)) (EffTy Pure bodyTy) +-- let lamT = LamExpr (bsNonlin' >>> UnaryNest (PlainBD bCT)) body' +-- return $ TopLam False piTy lamT +-- transposeTopFun (TopLam True _ _) = error "shouldn't be transposing in destination passing style" unpackLinearLamExpr :: (MonadFail1 m, EnvReader m) => LamExpr SimpIR n @@ -216,15 +216,6 @@ transposeExpr expr ct = case expr of refProj <- naryIndexRef ref (toList is') emitCTToRef refProj ct LinTrivial -> return () - ProjectElt _ i' x' -> do - let (idxs, v) = asNaryProj i' x' - lookupSubstM (atomVarName v) >>= \case - RenameNonlin _ -> error "an error, probably" - LinRef ref -> do - ref' <- getNaryProjRef (toList idxs) ref - refProj <- naryIndexRef ref' (toList is') - emitCTToRef refProj ct - LinTrivial -> return () _ -> error $ "shouldn't occur: " ++ pprint x PrimOp op -> transposeOp op ct Case e alts _ -> do @@ -245,6 +236,15 @@ transposeExpr expr ct = case expr of forM_ (enumerate es) \(ordinalIdx, e) -> do i <- unsafeFromOrdinal idxTy (IdxRepVal $ fromIntegral ordinalIdx) tabApp ct i >>= transposeAtom e + ProjectElt _ i' x' -> undefined + -- ProjectElt _ i' x' -> do + -- let (idxs, v) = asNaryProj i' x' + -- lookupSubstM (atomVarName v) >>= \case + -- RenameNonlin _ -> error "an error, probably" + -- LinRef ref -> do + -- ref' <- getNaryProjRef (toList idxs) ref + -- emitCTToRef ref' ct + -- LinTrivial -> return () transposeOp :: Emits o => PrimOp SimpIR i -> SAtom o -> TransposeM i o () transposeOp op ct = case op of @@ -315,14 +315,6 @@ transposeAtom atom ct = case atom of Con con -> transposeCon con ct DepPair _ _ _ -> notImplemented PtrVar _ _ -> notTangent - ProjectElt _ i' x' -> do - let (idxs, v) = asNaryProj i' x' - lookupSubstM (atomVarName v) >>= \case - RenameNonlin _ -> error "an error, probably" - LinRef ref -> do - ref' <- getNaryProjRef (toList idxs) ref - emitCTToRef ref' ct - LinTrivial -> return () RepValAtom _ -> error "not implemented" where notTangent = error $ "Not a tangent atom: " ++ pprint atom diff --git a/src/lib/Types/Core.hs b/src/lib/Types/Core.hs index 9bfeec19b..a60ad149f 100644 --- a/src/lib/Types/Core.hs +++ b/src/lib/Types/Core.hs @@ -23,6 +23,7 @@ module Types.Core (module Types.Core, SymbolicZeros (..)) where +import Control.Category ((>>>)) import Data.Word import Data.Maybe (fromJust) import Data.Functor @@ -51,15 +52,12 @@ data Atom (r::IR) (n::S) where Var :: AtomVar r n -> Atom r n Con :: Con r n -> Atom r n PtrVar :: PtrType -> PtrName n -> Atom r n - ProjectElt :: Type r n -> Projection -> Atom r n -> Atom r n DepPair :: Atom r n -> Atom r n -> DepPairType r n -> Atom r n -- === CoreIR only === Lam :: CoreLamExpr n -> Atom CoreIR n Eff :: EffectRow CoreIR n -> Atom CoreIR n DictCon :: Type CoreIR n -> DictExpr n -> Atom CoreIR n NewtypeCon :: NewtypeCon n -> Atom CoreIR n -> Atom CoreIR n - DictHole :: AlwaysEqual SrcPosCtx -> Type CoreIR n -> RequiredMethodAccess - -> Atom CoreIR n TypeAsAtom :: Type CoreIR n -> Atom CoreIR n -- === Shims between IRs === SimpInCore :: SimpInCore n -> Atom CoreIR n @@ -73,10 +71,6 @@ data Type (r::IR) (n::S) where DictTy :: DictType n -> Type CoreIR n Pi :: CorePiType n -> Type CoreIR n NewtypeTyCon :: NewtypeTyCon n -> Type CoreIR n - -- It was bad enough having this in `Atom`, but it's even worse now that it's - -- replicated in `Type` too. We should be able to remove both once - -- we represent types as normalized blocks. - ProjectEltTy :: CType n -> Projection -> CAtom n -> Type CoreIR n data AtomVar (r::IR) (n::S) = AtomVar { atomVarName :: AtomName r n @@ -88,7 +82,7 @@ data SimpInCore (n::S) = LiftSimp (CType n) (SAtom n) | LiftSimpFun (CorePiType n) (LamExpr SimpIR n) | TabLam (TabPiType CoreIR n) (TabLamExpr n) - | ACase (SAtom n) [Abs SBinder CAtom n] (CType n) + | ACase (SAtom n) [Abs SBinderAndDecls CAtom n] (CType n) deriving (Show, Generic) deriving instance IRRep r => Show (Atom r n) @@ -105,6 +99,8 @@ data Expr r n where PrimOp :: PrimOp r n -> Expr r n App :: EffTy CoreIR n -> CAtom n -> [CAtom n] -> Expr CoreIR n ApplyMethod :: EffTy CoreIR n -> CAtom n -> Int -> [CAtom n] -> Expr CoreIR n + ProjectElt :: Type r n -> Projection -> Atom r n -> Expr r n + DictHole :: AlwaysEqual SrcPosCtx -> Type CoreIR n -> RequiredMethodAccess -> Expr CoreIR n deriving instance IRRep r => Show (Expr r n) deriving via WrapE (Expr r) n instance IRRep r => Generic (Expr r n) @@ -143,9 +139,9 @@ type AtomBinderP (r::IR) = BinderP (AtomNameC r) type Binder r = AtomBinderP r (Type r) :: B type Alt r = Abs (Binder r) (Block r) :: E --- This doesn't actually include the decls yet. I'm starting by making the type --- distinct from the underlying binder without changing anything else. -data BinderAndDecls (r::IR) (n::S) (l::S) = BD (Binder r n l) +data BinderAndDecls (r::IR) (n::S) (l::S) where + BD :: Binder r n l -> Nest (Decl r) l l' -> BinderAndDecls r n l' + type Binders r = Nest (BinderAndDecls r) newtype DotMethods n = DotMethods (M.Map SourceName (CAtomName n)) @@ -529,6 +525,7 @@ data InstanceBody (n::S) = data DictType (n::S) = DictType SourceName (ClassName n) [CAtom n] deriving (Show, Generic) +-- TODO: remove this and just do these operations in decls data DictExpr (n::S) = InstantiatedGiven (CAtom n) [CAtom n] | SuperclassProj (CAtom n) Int -- (could instantiate here too, but we don't need it for now) @@ -960,9 +957,9 @@ class BindsNames b => ToBinderVar (b::B) (r::IR) | b -> r where binderVar :: (IRRep r, DExt n l) => b n l -> AtomVar r l instance IRRep r => ToBinderVar (BinderAndDecls r) r where - binderType (BD (_:>ty)) = ty - binderVar (BD (b:>ty)) = - AtomVar (sink $ binderName b) (sink ty) + binderType (BD (_:>ty) _) = ty + binderVar (BD (b:>ty) ds) = + AtomVar (withExtEvidence ds $ sink $ binderName b) (sink ty) instance IRRep r => ToBinderVar (Binder r) r where binderType (_:>ty) = ty @@ -976,6 +973,9 @@ bindersVars = \case Nest b bs -> withExtEvidence b $ withSubscopeDistinct bs $ sink (binderVar b) : bindersVars bs +binderAtom :: (IRRep r, DExt n l, ToBinderVar b r) => b n l -> Atom r l +binderAtom b = Var $ binderVar b + -- === ToBinding === atomBindingToBinding :: AtomBinding r n -> Binding (AtomNameC r) n @@ -1015,11 +1015,9 @@ instance (ToBinding e1 c, ToBinding e2 c) => ToBinding (EitherE e1 e2) c where -- a Var, it doesn't check whether it's a type. pattern Type :: CType n -> CAtom n pattern Type t <- ((\case Var v -> Just (TyVar v) - ProjectElt t i x -> Just $ ProjectEltTy t i x TypeAsAtom t -> Just t _ -> Nothing) -> Just t) where Type (TyVar v) = Var v - Type (ProjectEltTy t i x) = ProjectElt t i x Type t = TypeAsAtom t pattern IdxRepScalarBaseTy :: ScalarBaseType @@ -1112,21 +1110,19 @@ pattern FinConst :: Word32 -> Type CoreIR n pattern FinConst n = NewtypeTyCon (Fin (NatVal n)) pattern PlainBD :: Binder r n l -> BinderAndDecls r n l -pattern PlainBD b = BD b -- this will become `BD b Empty` +pattern PlainBD b = BD b Empty pattern NullaryLamExpr :: Block r n -> LamExpr r n pattern NullaryLamExpr body = LamExpr Empty body asUnaryLamExpr :: LamExpr r n -> Maybe (Abs (Binder r) (Block r) n) -asUnaryLamExpr (LamExpr (UnaryNest (BD b)) (Abs decls result)) = - Just $ Abs b $ Abs decls result --- asUnaryLamExpr (LamExpr (UnaryNest (BD b decls)) (Abs decls' result)) = --- Just $ Abs b $ Abs (decls >>> decls') result +asUnaryLamExpr (LamExpr (UnaryNest (BD b decls)) (Abs decls' result)) = + Just $ Abs b $ Abs (decls >>> decls') result asUnaryLamExpr _ = Nothing pattern UnaryLamExpr :: Binder r n l -> Block r l -> LamExpr r n pattern UnaryLamExpr b body <- (asUnaryLamExpr -> Just (Abs b body)) - where UnaryLamExpr b body = LamExpr (UnaryNest (BD b)) body + where UnaryLamExpr b body = LamExpr (UnaryNest (PlainBD b)) body pattern BinaryLamExpr :: BinderAndDecls r n l1 -> BinderAndDecls r l1 l2 -> Block r l2 -> LamExpr r n pattern BinaryLamExpr b1 b2 body = LamExpr (BinaryNest b1 b2) body @@ -1462,7 +1458,7 @@ instance GenericE SimpInCore where {- LiftSimp -} (CType `PairE` SAtom) {- LiftSimpFun -} (CorePiType `PairE` LamExpr SimpIR) {- TabLam -} (TabPiType CoreIR `PairE` TabLamExpr) - {- ACase -} (SAtom `PairE` ListE (Abs SBinder CAtom) `PairE` CType) + {- ACase -} (SAtom `PairE` ListE (Abs SBinderAndDecls CAtom) `PairE` CType) fromE = \case LiftSimp ty x -> Case0 $ ty `PairE` x LiftSimpFun ty x -> Case1 $ ty `PairE` x @@ -1490,17 +1486,13 @@ instance IRRep r => GenericE (Atom r) where -- toE/fromE entirely. If you wish to modify the order, please consult the -- GHC Core dump to make sure you haven't regressed this optimization. type RepE (Atom r) = EitherE3 - (EitherE4 + (EitherE3 {- Var -} (AtomVar r) - {- ProjectElt -} (Type r `PairE` LiftE Projection `PairE` Atom r) {- Lam -} (WhenCore r CoreLamExpr) {- DepPair -} (Atom r `PairE` Atom r `PairE` DepPairType r) - ) (EitherE4 + ) (EitherE3 {- DictCon -} (WhenCore r (CType `PairE` DictExpr)) {- NewtypeCon -} (WhenCore r (NewtypeCon `PairE` Atom r)) - {- DictHole -} (WhenCore r (LiftE (AlwaysEqual SrcPosCtx) `PairE` - (Type CoreIR) `PairE` - (LiftE RequiredMethodAccess))) {- Con -} (Con r) ) (EitherE5 {- Eff -} ( WhenCore r (EffectRow r)) @@ -1512,13 +1504,11 @@ instance IRRep r => GenericE (Atom r) where fromE atom = case atom of Var v -> Case0 (Case0 v) - ProjectElt t idxs x -> Case0 (Case1 (t `PairE` LiftE idxs `PairE` x)) - Lam lamExpr -> Case0 (Case2 (WhenIRE lamExpr)) - DepPair l r ty -> Case0 (Case3 $ l `PairE` r `PairE` ty) + Lam lamExpr -> Case0 (Case1 (WhenIRE lamExpr)) + DepPair l r ty -> Case0 (Case2 $ l `PairE` r `PairE` ty) DictCon t d -> Case1 $ Case0 $ WhenIRE $ t `PairE` d NewtypeCon c x -> Case1 $ Case1 $ WhenIRE (c `PairE` x) - DictHole s t access -> Case1 $ Case2 $ WhenIRE (LiftE s `PairE` t `PairE` LiftE access) - Con con -> Case1 $ Case3 con + Con con -> Case1 $ Case2 con Eff effs -> Case2 $ Case0 $ WhenIRE effs PtrVar t v -> Case2 $ Case1 $ LiftE t `PairE` v RepValAtom rv -> Case2 $ Case2 $ WhenIRE $ rv @@ -1529,15 +1519,13 @@ instance IRRep r => GenericE (Atom r) where toE atom = case atom of Case0 val -> case val of Case0 v -> Var v - Case1 (t `PairE` LiftE idxs `PairE` x) -> ProjectElt t idxs x - Case2 (WhenIRE (lamExpr)) -> Lam lamExpr - Case3 (l `PairE` r `PairE` ty) -> DepPair l r ty + Case1 (WhenIRE (lamExpr)) -> Lam lamExpr + Case2 (l `PairE` r `PairE` ty) -> DepPair l r ty _ -> error "impossible" Case1 val -> case val of Case0 (WhenIRE (t `PairE` d)) -> DictCon t d Case1 (WhenIRE (c `PairE` x)) -> NewtypeCon c x - Case2 (WhenIRE (LiftE s `PairE` t `PairE` LiftE access)) -> DictHole s t access - Case3 con -> Con con + Case2 con -> Con con _ -> error "impossible" Case2 val -> case val of Case0 (WhenIRE effs) -> Eff effs @@ -1583,7 +1571,7 @@ instance IRRep r => AlphaHashableE (AtomVar r) where instance IRRep r => RenameE (AtomVar r) instance IRRep r => GenericE (Type r) where - type RepE (Type r) = EitherE8 + type RepE (Type r) = EitherE7 {- TyVar -} (WhenCore r CAtomVar) {- Pi -} (WhenCore r CorePiType) {- TabPi -} (TabPiType r) @@ -1591,7 +1579,6 @@ instance IRRep r => GenericE (Type r) where {- DictTy -} (WhenCore r DictType) {- NewtypeTyCon -} (WhenCore r NewtypeTyCon) {- TC -} (TC r) - {- ProjectEltTy -} (WhenCore r (Type r `PairE` LiftE Projection `PairE` Atom r)) fromE = \case TyVar v -> Case0 $ WhenIRE v @@ -1601,7 +1588,6 @@ instance IRRep r => GenericE (Type r) where DictTy d -> Case4 $ WhenIRE d NewtypeTyCon t -> Case5 $ WhenIRE t TC con -> Case6 $ con - ProjectEltTy t idxs x -> Case7 (WhenIRE (t `PairE` LiftE idxs `PairE` x)) {-# INLINE fromE #-} toE = \case @@ -1612,7 +1598,7 @@ instance IRRep r => GenericE (Type r) where Case4 (WhenIRE d) -> DictTy d Case5 (WhenIRE t) -> NewtypeTyCon t Case6 con -> TC con - Case7 (WhenIRE (t `PairE` LiftE idxs `PairE` x)) -> ProjectEltTy t idxs x + _ -> error "impossible" {-# INLINE toE #-} instance IRRep r => SinkableE (Type r) @@ -1630,10 +1616,14 @@ instance IRRep r => GenericE (Expr r) where {- Atom -} (Atom r) {- TopApp -} (WhenSimp r (EffTy r `PairE` TopFunName `PairE` ListE (Atom r))) ) - ( EitherE3 + ( EitherE5 {- TabCon -} (MaybeE (WhenCore r Dict) `PairE` Type r `PairE` ListE (Atom r)) {- PrimOp -} (PrimOp r) - {- ApplyMethod -} (WhenCore r (EffTy r `PairE` Atom r `PairE` LiftE Int `PairE` ListE (Atom r)))) + {- ApplyMethod -} (WhenCore r (EffTy r `PairE` Atom r `PairE` LiftE Int `PairE` ListE (Atom r))) + {- ProjectElt -} (Type r `PairE` LiftE Projection `PairE` Atom r) + {- DictHole -} (WhenCore r (LiftE (AlwaysEqual SrcPosCtx) `PairE` + (Type CoreIR) `PairE` + (LiftE RequiredMethodAccess)))) fromE = \case App et f xs -> Case0 $ Case0 (WhenIRE (et `PairE` f `PairE` ListE xs)) @@ -1644,6 +1634,8 @@ instance IRRep r => GenericE (Expr r) where TabCon d ty xs -> Case1 $ Case0 (toMaybeE d `PairE` ty `PairE` ListE xs) PrimOp op -> Case1 $ Case1 op ApplyMethod et d i xs -> Case1 $ Case2 (WhenIRE (et `PairE` d `PairE` LiftE i `PairE` ListE xs)) + ProjectElt t idxs x -> Case1 $ Case3 (t `PairE` LiftE idxs `PairE` x) + DictHole s t access -> Case1 $ Case4 $ WhenIRE (LiftE s `PairE` t `PairE` LiftE access) {-# INLINE fromE #-} toE = \case Case0 case0 -> case case0 of @@ -1657,6 +1649,8 @@ instance IRRep r => GenericE (Expr r) where Case0 (d `PairE` ty `PairE` ListE xs) -> TabCon (fromMaybeE d) ty xs Case1 op -> PrimOp op Case2 (WhenIRE (et `PairE` d `PairE` LiftE i `PairE` ListE xs)) -> ApplyMethod et d i xs + Case3 (t `PairE` LiftE idxs `PairE` x) -> ProjectElt t idxs x + Case4 (WhenIRE (LiftE s `PairE` t `PairE` LiftE access)) -> DictHole s t access _ -> error "impossible" _ -> error "impossible" {-# INLINE toE #-} @@ -2416,14 +2410,14 @@ instance IRRep r => ProvesExt (Decl r) instance IRRep r => BindsNames (Decl r) instance GenericB (BinderAndDecls r) where - type RepB (BinderAndDecls r) = Binder r - fromB (BD b) = b + type RepB (BinderAndDecls r) = PairB (Binder r) (Nest (Decl r)) + fromB (BD b ds) = PairB b ds {-# INLINE fromB #-} - toB b = BD b + toB (PairB b ds) = BD b ds {-# INLINE toB #-} instance HasNameHint (BinderAndDecls r n l) where - getNameHint (BD b) = getNameHint b + getNameHint (BD b _) = getNameHint b deriving instance IRRep r => Show (BinderAndDecls r n l) deriving via WrapB (BinderAndDecls r) n l instance IRRep r => Generic (BinderAndDecls r n l) diff --git a/src/lib/Vectorize.hs b/src/lib/Vectorize.hs index e1cc78587..c1315c87b 100644 --- a/src/lib/Vectorize.hs +++ b/src/lib/Vectorize.hs @@ -17,7 +17,6 @@ import Control.Monad.State.Strict import Builder import Core import Err -import CheapReduction import IRVariants import Lower (DestBlock) import MTL1 @@ -28,6 +27,7 @@ import QueryType import Types.Core import Types.Primitives import Util (allM, zipWithZ) +import Visitor -- === Vectorization === @@ -91,13 +91,14 @@ newtype TopVectorizeM (i::S) (o::S) (a:: *) = TopVectorizeM , SubstReader Name) vectorizeLoops :: EnvReader m => Word32 -> STopLam n -> m n (STopLam n, Errs) -vectorizeLoops width (TopLam d ty (LamExpr bsDestB body)) = liftEnvReaderM do - case popNest bsDestB of - Just (PairB bs (BD b)) -> - refreshAbs (Abs bs (Abs b body)) \bs' body' -> do - (Abs b'' body'', errs) <- liftTopVectorizeM width $ vectorizeLoopsDestBlock body' - return $ (TopLam d ty (LamExpr (bs' >>> UnaryNest (PlainBD b'')) body''), errs) - Nothing -> error "expected a trailing dest binder" +vectorizeLoops width (TopLam d ty (LamExpr bsDestB body)) = undefined +-- vectorizeLoops width (TopLam d ty (LamExpr bsDestB body)) = liftEnvReaderM do +-- case popNest bsDestB of +-- Just (PairB bs (BD b)) -> +-- refreshAbs (Abs bs (Abs b body)) \bs' body' -> do +-- (Abs b'' body'', errs) <- liftTopVectorizeM width $ vectorizeLoopsDestBlock body' +-- return $ (TopLam d ty (LamExpr (bs' >>> UnaryNest (PlainBD b'')) body''), errs) +-- Nothing -> error "expected a trailing dest binder" liftTopVectorizeM :: (EnvReader m) => Word32 -> TopVectorizeM i i a -> m i (a, Errs) @@ -156,14 +157,15 @@ vectorizeLoopsDecls nest cont = vectorizeLoopsDecls rest cont vectorizeLoopsLamExpr :: LamExpr SimpIR i -> TopVectorizeM i o (LamExpr SimpIR o) -vectorizeLoopsLamExpr (LamExpr bs body) = case bs of - Empty -> LamExpr Empty <$> buildBlock (vectorizeLoopsBlock body) - Nest b rest -> do - ty <- renameM $ binderType b - withFreshBinder (getNameHint b) ty \b' -> do - extendSubstBD b [binderName b'] do - LamExpr bs' body' <- vectorizeLoopsLamExpr $ LamExpr rest body - return $ LamExpr (Nest (BD b') bs') body' +vectorizeLoopsLamExpr (LamExpr bs body) = undefined +-- vectorizeLoopsLamExpr (LamExpr bs body) = case bs of +-- Empty -> LamExpr Empty <$> buildBlock (vectorizeLoopsBlock body) +-- Nest b rest -> do +-- ty <- renameM $ binderType b +-- withFreshBinder (getNameHint b) ty \b' -> do +-- extendSubstBD b [binderName b'] do +-- LamExpr bs' body' <- vectorizeLoopsLamExpr $ LamExpr rest body +-- return $ LamExpr (Nest (BD b') bs') body' vectorizeLoopsExpr :: (Emits o) => SExpr i -> TopVectorizeM i o (SExpr o) vectorizeLoopsExpr expr = do @@ -225,12 +227,13 @@ vectorizeLoopsExpr expr = do simplifyIxSize :: (EnvReader m, ScopableBuilder SimpIR m) => IxType SimpIR n -> m n (Maybe Word32) -simplifyIxSize ixty = do - sizeMethod <- buildBlock $ applyIxMethod (sink $ ixTypeDict ixty) Size [] - cheapReduce sizeMethod >>= \case - Just (IdxRepVal n) -> return $ Just n - _ -> return Nothing -{-# INLINE simplifyIxSize #-} +simplifyIxSize ixty = undefined +-- simplifyIxSize ixty = do +-- sizeMethod <- buildBlock $ applyIxMethod (sink $ ixTypeDict ixty) Size [] +-- cheapReduce sizeMethod >>= \case +-- Just (IdxRepVal n) -> return $ Just n +-- _ -> return Nothing +-- {-# INLINE simplifyIxSize #-} -- Really we should check this by seeing whether there is an instance for a -- `Commutative` class, or something like that, but for now just pattern-match @@ -396,6 +399,16 @@ vectorizeExpr expr = addVectErrCtx "vectorizeExpr" ("Expr:\n" ++ pprint expr) do throwVectErr $ "bad type: " ++ pprint tblTy ++ "\ntbl' : " ++ pprint tbl' Atom atom -> vectorizeAtom atom PrimOp op -> vectorizePrimOp op + -- Vectors of base newtypes are already newtype-stripped. + ProjectElt _ (ProjectProduct i) x -> undefined + -- ProjectElt _ (ProjectProduct i) x -> do + -- VVal vv x' <- vectorizeAtom x + -- ov <- case vv of + -- ProdStability sbs -> return $ sbs !! i + -- _ -> throwVectErr "Invalid projection" + -- x'' <- normalizeProj (ProjectProduct i) x' + -- return $ VVal ov x'' + ProjectElt _ UnwrapNewtype _ -> error "Shouldn't have newtypes left" -- TODO: check statically _ -> throwVectErr $ "Cannot vectorize expr: " ++ pprint expr vectorizeDAMOp :: Emits o => DAMOp SimpIR i -> VectorizeM i o (VAtom o) @@ -527,15 +540,6 @@ vectorizeAtom atom = addVectErrCtx "vectorizeAtom" ("Atom:\n" ++ pprint atom) do Var v -> lookupSubstM (atomVarName v) >>= \case VRename v' -> VVal Uniform . Var <$> toAtomVar v' v' -> return v' - -- Vectors of base newtypes are already newtype-stripped. - ProjectElt _ (ProjectProduct i) x -> do - VVal vv x' <- vectorizeAtom x - ov <- case vv of - ProdStability sbs -> return $ sbs !! i - _ -> throwVectErr "Invalid projection" - x'' <- normalizeProj (ProjectProduct i) x' - return $ VVal ov x'' - ProjectElt _ UnwrapNewtype _ -> error "Shouldn't have newtypes left" -- TODO: check statically Con (Lit l) -> return $ VVal Uniform $ Con $ Lit l _ -> do subst <- getSubst diff --git a/src/lib/Visitor.hs b/src/lib/Visitor.hs new file mode 100644 index 000000000..1fdd85a09 --- /dev/null +++ b/src/lib/Visitor.hs @@ -0,0 +1,269 @@ +-- Copyright 2023 Google LLC +-- +-- Use of this source code is governed by a BSD-style +-- license that can be found in the LICENSE file or at +-- https://developers.google.com/open-source/licenses/bsd + +module Visitor where + +import Control.Applicative +import Control.Monad.Trans +import Control.Monad.Writer.Strict hiding (Alt) +import Control.Monad.State.Strict +import Control.Monad.Reader +import Data.Foldable (toList) +import Data.Functor.Identity +import Data.Functor ((<&>)) +import qualified Data.List.NonEmpty as NE +import qualified Data.Map.Strict as M + +import Core +import Err +import IRVariants +import MTL1 +import Name +import PPrint () +import QueryType +import Types.Core +import Types.Imp +import Types.Primitives + +type family IsAtomName (c::C) where + IsAtomName (AtomNameC r) = True + IsAtomName _ = False + +class Monad m => NonAtomRenamer m i o | m -> i, m -> o where + renameN :: (IsAtomName c ~ False, Color c) => Name c i -> m (Name c o) + +class NonAtomRenamer m i o => Visitor m r i o | m -> i, m -> o where + visitType :: Type r i -> m (Type r o) + visitAtom :: Atom r i -> m (Atom r o) + visitLam :: LamExpr r i -> m (LamExpr r o) + visitPi :: PiType r i -> m (PiType r o) + +class VisitGeneric (e:: E) (r::IR) | e -> r where + visitGeneric :: Visitor m r i o => e i -> m (e o) + +type Visitor2 (m::MonadKind2) r = forall i o . Visitor (m i o) r i o + +instance VisitGeneric (Atom r) r where visitGeneric = visitAtom +instance VisitGeneric (Type r) r where visitGeneric = visitType +instance VisitGeneric (LamExpr r) r where visitGeneric = visitLam +instance VisitGeneric (PiType r) r where visitGeneric = visitPi + +visitBlock :: Visitor m r i o => Block r i -> m (Block r o) +visitBlock b = visitGeneric (LamExpr Empty b) >>= \case + LamExpr Empty b' -> return b' + _ -> error "not a block" + +visitAlt :: Visitor m r i o => Alt r i -> m (Alt r o) +visitAlt (Abs b body) = do + visitGeneric (UnaryLamExpr b body) >>= \case + UnaryLamExpr b' body' -> return $ Abs b' body' + _ -> error "not an alt" + +traverseOpTerm + :: (GenericOp e, Visitor m r i o, OpConst e r ~ OpConst e r) + => e r i -> m (e r o) +traverseOpTerm e = traverseOp e visitGeneric visitGeneric visitGeneric + +-- XXX: This doesn't handle the `Var`, `ProjectElt`, `SimpInCore` cases. These +-- should be handled explicitly beforehand. TODO: split out these cases under a +-- separate constructor, perhaps even a `hole` paremeter to `Atom` or part of +-- `IR`. +visitAtomPartial :: (IRRep r, Visitor m r i o) => Atom r i -> m (Atom r o) +visitAtomPartial = \case + Var _ -> error "Not handled generically" + SimpInCore _ -> error "Not handled generically" + Con con -> Con <$> visitGeneric con + PtrVar t v -> PtrVar t <$> renameN v + DepPair x y t -> do + x' <- visitGeneric x + y' <- visitGeneric y + ~(DepPairTy t') <- visitGeneric $ DepPairTy t + return $ DepPair x' y' t' + Lam lam -> Lam <$> visitGeneric lam + Eff eff -> Eff <$> visitGeneric eff + DictCon t d -> DictCon <$> visitType t <*> visitGeneric d + NewtypeCon con x -> NewtypeCon <$> visitGeneric con <*> visitGeneric x + TypeAsAtom t -> TypeAsAtom <$> visitGeneric t + RepValAtom repVal -> RepValAtom <$> visitGeneric repVal + +-- XXX: This doesn't handle the `TyVar` or `ProjectEltTy` cases. These should be +-- handled explicitly beforehand. +visitTypePartial :: (IRRep r, Visitor m r i o) => Type r i -> m (Type r o) +visitTypePartial = \case + TyVar _ -> error "Not handled generically" + NewtypeTyCon t -> NewtypeTyCon <$> visitGeneric t + Pi t -> Pi <$> visitGeneric t + TabPi t -> TabPi <$> visitGeneric t + TC t -> TC <$> visitGeneric t + DepPairTy t -> DepPairTy <$> visitGeneric t + DictTy t -> DictTy <$> visitGeneric t + +instance IRRep r => VisitGeneric (Expr r) r where + visitGeneric = \case + TopApp et v xs -> TopApp <$> visitGeneric et <*> renameN v <*> mapM visitGeneric xs + TabApp t tab xs -> TabApp <$> visitType t <*> visitGeneric tab <*> mapM visitGeneric xs + -- TODO: should we reuse the original effects? Whether it's valid depends on + -- the type-preservation requirements for a visitor. We should clarify what + -- those are. + Case x alts effTy -> do + x' <- visitGeneric x + alts' <- mapM visitAlt alts + effTy' <- visitGeneric effTy + return $ Case x' alts' effTy' + Atom x -> Atom <$> visitGeneric x + TabCon Nothing t xs -> TabCon Nothing <$> visitGeneric t <*> mapM visitGeneric xs + TabCon (Just (WhenIRE d)) t xs -> TabCon <$> (Just . WhenIRE <$> visitGeneric d) <*> visitGeneric t <*> mapM visitGeneric xs + PrimOp op -> PrimOp <$> visitGeneric op + App et fAtom xs -> App <$> visitGeneric et <*> visitGeneric fAtom <*> mapM visitGeneric xs + ApplyMethod et m i xs -> ApplyMethod <$> visitGeneric et <*> visitGeneric m <*> pure i <*> mapM visitGeneric xs + DictHole ctx ty access -> DictHole ctx <$> visitGeneric ty <*> pure access + +instance IRRep r => VisitGeneric (PrimOp r) r where + visitGeneric = \case + UnOp op x -> UnOp op <$> visitGeneric x + BinOp op x y -> BinOp op <$> visitGeneric x <*> visitGeneric y + MemOp op -> MemOp <$> visitGeneric op + VectorOp op -> VectorOp <$> visitGeneric op + MiscOp op -> MiscOp <$> visitGeneric op + Hof op -> Hof <$> visitGeneric op + DAMOp op -> DAMOp <$> visitGeneric op + RefOp r op -> RefOp <$> visitGeneric r <*> traverseOp op visitGeneric visitGeneric visitGeneric + +instance IRRep r => VisitGeneric (TypedHof r) r where + visitGeneric (TypedHof eff hof) = TypedHof <$> visitGeneric eff <*> visitGeneric hof + +instance IRRep r => VisitGeneric (Hof r) r where + visitGeneric = \case + For ann d lam -> For ann <$> visitGeneric d <*> visitGeneric lam + RunReader x body -> RunReader <$> visitGeneric x <*> visitGeneric body + RunWriter dest bm body -> RunWriter <$> mapM visitGeneric dest <*> visitGeneric bm <*> visitGeneric body + RunState dest s body -> RunState <$> mapM visitGeneric dest <*> visitGeneric s <*> visitGeneric body + While b -> While <$> visitBlock b + RunIO b -> RunIO <$> visitBlock b + RunInit b -> RunInit <$> visitBlock b + CatchException t b -> CatchException <$> visitType t <*> visitBlock b + Linearize lam x -> Linearize <$> visitGeneric lam <*> visitGeneric x + Transpose lam x -> Transpose <$> visitGeneric lam <*> visitGeneric x + +instance IRRep r => VisitGeneric (BaseMonoid r) r where + visitGeneric (BaseMonoid x lam) = BaseMonoid <$> visitGeneric x <*> visitGeneric lam + +instance IRRep r => VisitGeneric (DAMOp r) r where + visitGeneric = \case + Seq eff dir d x lam -> Seq <$> visitGeneric eff <*> pure dir <*> visitGeneric d <*> visitGeneric x <*> visitGeneric lam + RememberDest eff x lam -> RememberDest <$> visitGeneric eff <*> visitGeneric x <*> visitGeneric lam + AllocDest t -> AllocDest <$> visitGeneric t + Place x y -> Place <$> visitGeneric x <*> visitGeneric y + Freeze x -> Freeze <$> visitGeneric x + +instance IRRep r => VisitGeneric (Effect r) r where + visitGeneric = \case + RWSEffect rws h -> RWSEffect rws <$> visitGeneric h + ExceptionEffect -> pure ExceptionEffect + IOEffect -> pure IOEffect + InitEffect -> pure InitEffect + +instance IRRep r => VisitGeneric (EffectRow r) r where + visitGeneric (EffectRow effs tailVar) = do + effs' <- eSetFromList <$> mapM visitGeneric (eSetToList effs) + tailEffRow <- case tailVar of + NoTail -> return $ EffectRow mempty NoTail + EffectRowTail v -> visitGeneric (Var v) <&> \case + Var v' -> EffectRow mempty (EffectRowTail v') + Eff r -> r + _ -> error "Not a valid effect substitution" + return $ extendEffRow effs' tailEffRow + +instance VisitGeneric DictExpr CoreIR where + visitGeneric = \case + InstantiatedGiven x xs -> InstantiatedGiven <$> visitGeneric x <*> mapM visitGeneric xs + SuperclassProj x i -> SuperclassProj <$> visitGeneric x <*> pure i + InstanceDict v xs -> InstanceDict <$> renameN v <*> mapM visitGeneric xs + IxFin x -> IxFin <$> visitGeneric x + DataData t -> DataData <$> visitGeneric t + +instance VisitGeneric NewtypeCon CoreIR where + visitGeneric = \case + UserADTData sn t params -> UserADTData sn <$> renameN t <*> visitGeneric params + NatCon -> return NatCon + FinCon x -> FinCon <$> visitGeneric x + +instance VisitGeneric NewtypeTyCon CoreIR where + visitGeneric = \case + Nat -> return Nat + Fin x -> Fin <$> visitGeneric x + EffectRowKind -> return EffectRowKind + UserADTType n v params -> UserADTType n <$> renameN v <*> visitGeneric params + +instance VisitGeneric TyConParams CoreIR where + visitGeneric (TyConParams expls xs) = TyConParams expls <$> mapM visitGeneric xs + +instance IRRep r => VisitGeneric (IxDict r) r where + visitGeneric = \case + IxDictAtom x -> IxDictAtom <$> visitGeneric x + IxDictRawFin x -> IxDictRawFin <$> visitGeneric x + IxDictSpecialized t v xs -> IxDictSpecialized <$> visitGeneric t <*> renameN v <*> mapM visitGeneric xs + +instance IRRep r => VisitGeneric (IxType r) r where + visitGeneric (IxType t d) = IxType <$> visitType t <*> visitGeneric d + +instance VisitGeneric DictType CoreIR where + visitGeneric (DictType n v xs) = DictType n <$> renameN v <*> mapM visitGeneric xs + +instance VisitGeneric CoreLamExpr CoreIR where + visitGeneric (CoreLamExpr t lam) = CoreLamExpr <$> visitGeneric t <*> visitGeneric lam + +instance VisitGeneric CorePiType CoreIR where + visitGeneric (CorePiType app expl bs effty) = do + PiType bs' effty' <- visitGeneric $ PiType bs effty + return $ CorePiType app expl bs' effty' + +instance IRRep r => VisitGeneric (TabPiType r) r where + visitGeneric (TabPiType d b eltTy) = do + d' <- visitGeneric d + visitGeneric (PiType (UnaryNest b) (EffTy Pure eltTy)) <&> \case + PiType (UnaryNest b') (EffTy Pure eltTy') -> TabPiType d' b' eltTy' + _ -> error "not a table pi type" + +instance IRRep r => VisitGeneric (DepPairType r) r where + visitGeneric (DepPairType expl b ty) = do + visitGeneric (PiType (UnaryNest b) (EffTy Pure ty)) <&> \case + PiType (UnaryNest b') (EffTy Pure ty') -> DepPairType expl b' ty' + _ -> error "not a dependent pair type" + +instance VisitGeneric (RepVal SimpIR) SimpIR where + visitGeneric (RepVal ty tree) = RepVal <$> visitGeneric ty <*> mapM renameIExpr tree + where renameIExpr = \case + ILit l -> return $ ILit l + IVar v t -> IVar <$> renameN v <*> pure t + IPtrVar v t -> IPtrVar <$> renameN v <*> pure t + +instance IRRep r => VisitGeneric (DeclBinding r) r where + visitGeneric (DeclBinding ann expr) = DeclBinding ann <$> visitGeneric expr + +instance IRRep r => VisitGeneric (EffTy r) r where + visitGeneric (EffTy eff ty) = + EffTy <$> visitGeneric eff <*> visitGeneric ty + +instance VisitGeneric DataConDefs CoreIR where + visitGeneric = \case + ADTCons cons -> ADTCons <$> mapM visitGeneric cons + StructFields defs -> do + let (names, tys) = unzip defs + tys' <- mapM visitGeneric tys + return $ StructFields $ zip names tys' + +instance VisitGeneric DataConDef CoreIR where + visitGeneric (DataConDef sn (Abs bs UnitE) repTy ps) = do + PiType bs' _ <- visitGeneric $ PiType bs $ EffTy Pure UnitTy + repTy' <- visitGeneric repTy + return $ DataConDef sn (Abs bs' UnitE) repTy' ps + +instance VisitGeneric (Con r) r where visitGeneric = traverseOpTerm +instance VisitGeneric (TC r) r where visitGeneric = traverseOpTerm +instance VisitGeneric (MiscOp r) r where visitGeneric = traverseOpTerm +instance VisitGeneric (VectorOp r) r where visitGeneric = traverseOpTerm +instance VisitGeneric (MemOp r) r where visitGeneric = traverseOpTerm