diff --git a/src/lib/Builder.hs b/src/lib/Builder.hs index 8e7563511..5f1c372de 100644 --- a/src/lib/Builder.hs +++ b/src/lib/Builder.hs @@ -786,8 +786,8 @@ buildMap :: (Emits n, ScopableBuilder r m) -> (forall l. (Emits l, DExt n l) => Atom r l -> m l (Atom r l)) -> m n (Atom r n) buildMap xs f = do - TabTy d (_:>t) _ <- return $ getType xs - buildFor noHint Fwd (IxType t d) \i -> + TabPi t <- return $ getType xs + buildFor noHint Fwd (tabIxType t) \i -> tabApp (sink xs) (Var i) >>= f unzipTab :: (Emits n, Builder r m) => Atom r n -> m n (Atom r n, Atom r n) @@ -857,8 +857,8 @@ zeroAt ty = liftEmitBuilder $ go ty where go = \case BaseTy bt -> return $ Con $ Lit $ zeroLit bt ProdTy tys -> ProdVal <$> mapM go tys - TabTy d (b:>t) bodyTy -> buildFor (getNameHint b) Fwd (IxType t d) \i -> - go =<< applySubst (b @> SubstVal (Var i)) bodyTy + TabPi tabPi -> buildFor (getNameHint tabPi) Fwd (tabIxType tabPi) \i -> + go =<< instantiateTabPiTy (sink tabPi) (Var i) _ -> unreachable zeroLit bt = case bt of Scalar Float64Type -> Float64Lit 0.0 @@ -902,8 +902,8 @@ tangentBaseMonoidFor ty = do addTangent :: (Emits n, SBuilder m) => SAtom n -> SAtom n -> m n (SAtom n) addTangent x y = do case getType x of - TabTy d (b:>t) _ -> - liftEmitBuilder $ buildFor (getNameHint b) Fwd (IxType t d) \i -> do + TabPi t -> + liftEmitBuilder $ buildFor (getNameHint t) Fwd (tabIxType t) \i -> do bindM2 addTangent (tabApp (sink x) (Var i)) (tabApp (sink y) (Var i)) TC con -> case con of BaseType (Scalar _) -> emitOp $ BinOp FAdd x y diff --git a/src/lib/CheapReduction.hs b/src/lib/CheapReduction.hs index 9a35c2ed5..4c42bbda2 100644 --- a/src/lib/CheapReduction.hs +++ b/src/lib/CheapReduction.hs @@ -15,7 +15,7 @@ module CheapReduction , unwrapLeadingNewtypesType, wrapNewtypesData, liftSimpAtom, liftSimpType , liftSimpFun, makeStructRepVal, NonAtomRenamer (..), Visitor (..), VisitGeneric (..) , visitAtomPartial, visitTypePartial, visitAtomDefault, visitTypeDefault, Visitor2 - , visitBinders, visitPiDefault, visitAlt, toAtomVar, instantiatePiTy + , visitBinders, visitPiDefault, visitAlt, toAtomVar, instantiatePiTy, instantiateTabPiTy , bindersToVars, bindersToAtoms) where @@ -474,6 +474,10 @@ instantiatePiTy :: (EnvReader m, IRRep r) => PiType r n -> [Atom r n] -> m n (Ef instantiatePiTy (PiType bs effTy) xs = do applySubst (bs @@> (SubstVal <$> xs)) effTy +instantiateTabPiTy :: (EnvReader m, IRRep r) => TabPiType r n -> Atom r n -> m n (Type r n) +instantiateTabPiTy (TabPiType _ b resultTy) x = do + applySubst (b @> SubstVal x) resultTy + -- Returns a representation type (type of an TypeCon-typed Newtype payload) -- given a list of instantiated DataConDefs. dataDefRep :: DataConDefs n -> CType n diff --git a/src/lib/CheckType.hs b/src/lib/CheckType.hs index 671067d04..47cf2df19 100644 --- a/src/lib/CheckType.hs +++ b/src/lib/CheckType.hs @@ -788,11 +788,7 @@ checkTabApp ty (i:rest) = do resultTy' <- applySubst (b@>SubstVal i') resultTy checkTabApp resultTy' rest -checkArgTys - :: (Typer m r, SubstB AtomSubstVal b, BindsNames b, BindsOneAtomName r b, IRRep r) - => Nest b o o' - -> [Atom r o] - -> m i o () +checkArgTys :: (Typer m r, IRRep r) => Nest (Binder r) o o' -> [Atom r o] -> m i o () checkArgTys Empty [] = return () checkArgTys (Nest b bs) (x:xs) = do dropSubst $ x |: binderType b @@ -930,15 +926,14 @@ checkedInstantiateTyConDef (TyConDef _ _ bs cons) (TyConParams _ xs) = do checkedApplyNaryAbs (Abs bs cons) xs checkedApplyNaryAbs - :: forall b r e o m - . ( BindsOneAtomName r b, EnvReader m, Fallible1 m, SinkableE e - , SubstE AtomSubstVal e, IRRep r, SubstB AtomSubstVal b) - => Abs (Nest b) e o -> [Atom r o] -> m o (e o) + :: forall r e o m + . ( EnvReader m, Fallible1 m, SinkableE e , SubstE AtomSubstVal e, IRRep r) + => Abs (Nest (Binder r)) e o -> [Atom r o] -> m o (e o) checkedApplyNaryAbs (Abs bsTop e) xsTop = do go (EmptyAbs bsTop) xsTop applySubst (bsTop@@>(SubstVal<$>xsTop)) e where - go :: EmptyAbs (Nest b) o -> [Atom r o] -> m o () + go :: EmptyAbs (Nest (Binder r)) o -> [Atom r o] -> m o () go (Abs Empty UnitE) [] = return () go (Abs (Nest b bs) UnitE) (x:xs) = do checkAlphaEq (binderType b) (getType x) diff --git a/src/lib/Export.hs b/src/lib/Export.hs index 42dcc7ba1..f7ab3184d 100644 --- a/src/lib/Export.hs +++ b/src/lib/Export.hs @@ -175,8 +175,8 @@ parseTabTy = go [] NewtypeTyCon Nat -> return $ Just $ RectContArrayPtr IdxRepScalarBaseTy shape TabTy d (b:>ixty) a -> do maybeN <- case IxType ixty d of - (IxType (NewtypeTyCon (Fin n)) _) -> return $ Just n - (IxType _ (IxDictRawFin n)) -> return $ Just n + IxType (NewtypeTyCon (Fin n)) _ -> return $ Just n + IxType _ (IxDictRawFin n) -> return $ Just n _ -> return Nothing maybeDim <- case maybeN of Just (Var v) -> do diff --git a/src/lib/Imp.hs b/src/lib/Imp.hs index 1a2ab54dd..bfb73537c 100644 --- a/src/lib/Imp.hs +++ b/src/lib/Imp.hs @@ -366,16 +366,16 @@ toImpRefOp refDest' m = do ans <- liftBuilderImp $ emitBlock (sink body') storeAtom accDest ans False -> case accTy of - TabTy d (b:>t) eltTy -> do - let ixTy = IxType t d + TabPi t -> do + let ixTy = tabIxType t n <- indexSetSizeImp ixTy emitLoop noHint Fwd n \i -> do idx <- unsafeFromOrdinalImp (sink ixTy) i xElt <- liftBuilderImp $ tabApp (sink x) (sink idx) yElt <- liftBuilderImp $ tabApp (sink y) (sink idx) - eltTy' <- applySubst (b@>SubstVal idx) eltTy + eltTy <- instantiateTabPiTy (sink t) idx ithDest <- indexDest (sink accDest) idx - liftMonoidCombine ithDest eltTy' (sink bc) xElt yElt + liftMonoidCombine ithDest eltTy (sink bc) xElt yElt _ -> error $ "Base monoid type mismatch: can't lift " ++ pprint baseTy ++ " to " ++ pprint accTy @@ -578,15 +578,15 @@ toImpTypedHof (TypedHof (EffTy _ resultTy') hof) = do alphaEq xTy accTy >>= \case True -> storeAtom accDest x False -> case accTy of - TabTy d (b:>t) eltTy -> do - let ixTy = IxType t d + TabPi t -> do + let ixTy = tabIxType t n <- indexSetSizeImp ixTy emitLoop noHint Fwd n \i -> do idx <- unsafeFromOrdinalImp (sink ixTy) i x' <- sinkM x - eltTy' <- applySubst (b@>SubstVal idx) eltTy + eltTy <- instantiateTabPiTy (sink t) idx ithDest <- indexDest (sink accDest) idx - liftMonoidEmpty ithDest eltTy' x' + liftMonoidEmpty ithDest eltTy x' _ -> error $ "Base monoid type mismatch: can't lift " ++ pprint xTy ++ " to " ++ pprint accTy @@ -1002,11 +1002,11 @@ buildGarbageVal ty = -- === Operations on dests === indexDest :: Emits n => Dest n -> SAtom n -> SubstImpM i n (Dest n) -indexDest (Dest destValTy@(TabTy d (b:>t) eltTy) tree) i = do - eltTy' <- applySubst (b@>SubstVal i) eltTy - ord <- ordinalImp (IxType t d) i - leafTys <- typeToTree destValTy - Dest eltTy' <$> forM (zipTrees leafTys tree) \(leafTy, ptr) -> do +indexDest (Dest (TabPi tabTy) tree) i = do + eltTy <- instantiateTabPiTy tabTy i + ord <- ordinalImp (tabIxType tabTy) i + leafTys <- typeToTree $ TabPi tabTy + Dest eltTy <$> forM (zipTrees leafTys tree) \(leafTy, ptr) -> do BufferType ixStruct _ <- return $ getRefBufferType leafTy offset <- computeOffsetImp ixStruct ord impOffset ptr offset @@ -1026,10 +1026,10 @@ indexRepValParam :: Emits n => SRepVal n -> SAtom n -> (SType n -> SType n) -> (IExpr n -> SubstImpM i n (IExpr n)) -> SubstImpM i n (SRepVal n) -indexRepValParam (RepVal tabTy@(TabPi (TabPiType d (b:>t) eltTy)) vals) i tyFunc func = do - eltTy' <- applySubst (b@>SubstVal i) eltTy - ord <- ordinalImp (IxType t d) i - leafTys <- typeToTree tabTy +indexRepValParam (RepVal (TabPi tabTy) vals) i tyFunc func = do + eltTy <- instantiateTabPiTy tabTy i + ord <- ordinalImp (tabIxType tabTy) i + leafTys <- typeToTree (TabPi tabTy) vals' <- forM (zipTrees leafTys vals) \(leafTy, ptr) -> do BufferPtr (BufferType ixStruct _) <- return $ getIExprInterpretation leafTy offset <- computeOffsetImp ixStruct ord @@ -1041,7 +1041,7 @@ indexRepValParam (RepVal tabTy@(TabPi (TabPiType d (b:>t) eltTy)) vals) i tyFunc _ -> func ptr' -- `func` may have changed the types of the `vals'`. The caller must also -- supply `tyFunc` to reflect that change in the SType. - return $ RepVal (tyFunc eltTy') vals' + return $ RepVal (tyFunc eltTy) vals' indexRepValParam _ _ _ _ = error "expected table type" {-# INLINE indexRepValParam #-} diff --git a/src/lib/Lower.hs b/src/lib/Lower.hs index 8eaf8cbeb..bce5b8050 100644 --- a/src/lib/Lower.hs +++ b/src/lib/Lower.hs @@ -153,12 +153,12 @@ lowerFor _ _ _ _ _ = error "expected a unary lambda expression" lowerTabCon :: forall i o. Emits o => Maybe (Dest SimpIR o) -> SType i -> [SAtom i] -> LowerM i o (SExpr o) lowerTabCon maybeDest tabTy elems = do - tabTy'@(TabPi (TabPiType dict (_:>t) _)) <- substM tabTy + TabPi tabTy' <- substM tabTy dest <- case maybeDest of Just d -> return d - Nothing -> emitExpr $ PrimOp $ DAMOp $ AllocDest tabTy' + Nothing -> emitExpr $ PrimOp $ DAMOp $ AllocDest $ TabPi tabTy' Abs bord ufoBlock <- buildAbs noHint IdxRepTy \ord -> do - buildBlock $ unsafeFromOrdinal (sink $ IxType t dict) $ Var $ sink ord + buildBlock $ unsafeFromOrdinal (sink $ tabIxType tabTy') $ Var $ sink ord -- This is emitting a chain of RememberDest ops to force `dest` to be used -- linearly, and to force reads of the `Freeze dest'` result not to be -- reordered in front of the writes. diff --git a/src/lib/QueryTypePure.hs b/src/lib/QueryTypePure.hs index 2501cbf8f..9be267241 100644 --- a/src/lib/QueryTypePure.hs +++ b/src/lib/QueryTypePure.hs @@ -119,8 +119,8 @@ instance IRRep r => HasType r (Con r) where getSuperclassType :: RNest CBinder n l -> Nest CBinder l l' -> Int -> CType n getSuperclassType _ Empty = error "bad index" -getSuperclassType bsAbove (Nest b bs) = \case - 0 -> ignoreHoistFailure $ hoist bsAbove $ binderType b +getSuperclassType bsAbove (Nest b@(_:>t) bs) = \case + 0 -> ignoreHoistFailure $ hoist bsAbove t i -> getSuperclassType (RNest bsAbove b) bs (i-1) instance IRRep r => HasType r (Expr r) where @@ -213,6 +213,9 @@ rawStrType = case newName "n" of rawFinTabType :: IRRep r => Atom r n -> Type r n -> Type r n rawFinTabType n eltTy = IxType IdxRepTy (IxDictRawFin n) ==> eltTy +tabIxType :: TabPiType r n -> IxType r n +tabIxType (TabPiType d (_:>t) _) = IxType t d + typesAsBinderNest :: (SinkableE e, HoistableE e, IRRep r) => [Type r n] -> e n -> Abs (Nest (Binder r)) e n diff --git a/src/lib/RuntimePrint.hs b/src/lib/RuntimePrint.hs index 3255773ad..4a4c2c6a5 100644 --- a/src/lib/RuntimePrint.hs +++ b/src/lib/RuntimePrint.hs @@ -185,8 +185,8 @@ bufferTy h = do extendBuffer :: (Emits n, CBuilder m) => CAtom n -> CAtom n -> m n () extendBuffer buf tab = do RefTy h _ <- return $ getType buf - TabTy d (_:>t) _ <- return $ getType tab - n <- applyIxMethodCore Size (IxType t d) [] + TabPi t <- return $ getType tab + n <- applyIxMethodCore Size (tabIxType t) [] void $ applyPreludeFunction "stack_extend_internal" [n, h, buf, tab] -- argument has type `Word8` @@ -237,8 +237,8 @@ forEachTabElt -> (forall l. (Emits l, DExt n l) => CAtom l -> CAtom l -> m l ()) -> m n () forEachTabElt tab cont = do - TabTy d (_:>t) _ <- return $ getType tab - let ixTy = IxType t d + TabPi t <- return $ getType tab + let ixTy = tabIxType t void $ buildFor "i" Fwd ixTy \i -> do x <- tabApp (sink tab) (Var i) i' <- applyIxMethodCore Ordinal (sink ixTy) [Var i] diff --git a/src/lib/Simplify.hs b/src/lib/Simplify.hs index c151937b6..a40b01fa5 100644 --- a/src/lib/Simplify.hs +++ b/src/lib/Simplify.hs @@ -153,12 +153,12 @@ getRepType ty = go ty where x <- liftSimpAtom (sink l) (Var $ binderVar b') r' <- go =<< applySubst (b@>SubstVal x) r return $ DepPairTy $ DepPairType expl b' r' - TabPi (TabPiType d (b:>t) bodyTy) -> do - let ixTy = IxType t d + TabPi tabTy -> do + let ixTy = tabIxType tabTy IxType t' d' <- simplifyIxType ixTy - withFreshBinder (getNameHint b) t' \b' -> do + withFreshBinder (getNameHint tabTy) t' \b' -> do x <- liftSimpAtom (sink $ ixTypeType ixTy) (Var $ binderVar b') - bodyTy' <- go =<< applySubst (b@>SubstVal x) bodyTy + bodyTy' <- go =<< instantiateTabPiTy (sink tabTy) x return $ TabPi $ TabPiType d' b' bodyTy' NewtypeTyCon con -> do (_, ty') <- unwrapNewtypeType con @@ -1025,7 +1025,7 @@ simplifyCustomLinearization (Abs runtimeBs staticArgs) actives rule = do return $ activeArg':rest buildTangentArgs _ _ _ = error "zip error" - fromNonDepNest :: (HoistableB b, BindsOneAtomName CoreIR b) => Nest b n l -> [CType n] + fromNonDepNest :: Nest CBinder n l -> [CType n] fromNonDepNest Empty = [] fromNonDepNest (Nest b bs) = case ignoreHoistFailure $ hoist b (Abs bs UnitE) of diff --git a/src/lib/Types/Core.hs b/src/lib/Types/Core.hs index 9e9062d3f..d09f7c69f 100644 --- a/src/lib/Types/Core.hs +++ b/src/lib/Types/Core.hs @@ -897,17 +897,13 @@ data LinearizationSpec (n::S) = LinearizationSpec (TopFunName n) [Active] deriving (Show, Generic) --- === BindsOneAtomName === +-- === Binder utils === -class BindsOneName b (AtomNameC r) => BindsOneAtomName (r::IR) (b::B) | b -> r where - binderType :: b n l -> Type r n - binderVar :: DExt n l => b n l -> AtomVar r l +binderType :: Binder r n l -> Type r n +binderType (_:>ty) = ty -bindersTypes :: (IRRep r, Distinct l, ProvesExt b, BindsNames b, BindsOneAtomName r b) - => Nest b n l -> [Type r l] -bindersTypes Empty = [] -bindersTypes n@(Nest b bs) = ty : bindersTypes bs - where ty = withExtEvidence n $ sink (binderType b) +binderVar :: (IRRep r, DExt n l) => Binder r n l -> AtomVar r l +binderVar (b:>ty) = AtomVar (binderName b) (sink ty) nestToAtomVars :: (Distinct l, Ext n l, IRRep r) => Nest (Binder r) n l -> [AtomVar r l] @@ -916,14 +912,6 @@ nestToAtomVars = \case Nest b bs -> withExtEvidence b $ withSubscopeDistinct bs $ sink (binderVar b) : nestToAtomVars bs -instance IRRep r => BindsOneAtomName r (BinderP (AtomNameC r) (Type r)) where - binderType (_ :> ty) = ty - binderVar (b:>t) = AtomVar (binderName b) (sink t) - -toBinderNest :: BindsOneAtomName r b => Nest b n l -> Nest (Binder r) n l -toBinderNest Empty = Empty -toBinderNest (Nest b bs) = Nest (asNameBinder b :> binderType b) (toBinderNest bs) - -- === ToBinding === atomBindingToBinding :: AtomBinding r n -> Binding (AtomNameC r) n @@ -957,14 +945,6 @@ instance (ToBinding e1 c, ToBinding e2 c) => ToBinding (EitherE e1 e2) c where toBinding (LeftE e) = toBinding e toBinding (RightE e) = toBinding e --- === HasArgType === - -class HasArgType (e::E) (r::IR) | e -> r where - argType :: e n -> Type r n - -instance HasArgType (TabPiType r) r where - argType (TabPiType _ (_:>ty) _) = ty - -- === Pattern synonyms === -- XXX: only use this pattern when you're actually expecting a type. If it's @@ -2055,6 +2035,9 @@ instance IRRep r => AlphaEqE (TabPiType r) where instance IRRep r => AlphaHashableE (TabPiType r) where hashWithSaltE env salt (TabPiType _ b t) = hashWithSaltE env salt $ Abs b t +instance HasNameHint (TabPiType r n) where + getNameHint (TabPiType _ b _) = getNameHint b + instance IRRep r => SinkableE (TabPiType r) instance IRRep r => HoistableE (TabPiType r) instance IRRep r => RenameE (TabPiType r)