Skip to content

Commit

Permalink
Avoid some uses of :> and @>.
Browse files Browse the repository at this point in the history
  • Loading branch information
dougalm committed Jun 27, 2023
1 parent 75eacbf commit b274115
Show file tree
Hide file tree
Showing 10 changed files with 61 additions and 76 deletions.
12 changes: 6 additions & 6 deletions src/lib/Builder.hs
Original file line number Diff line number Diff line change
Expand Up @@ -786,8 +786,8 @@ buildMap :: (Emits n, ScopableBuilder r m)
-> (forall l. (Emits l, DExt n l) => Atom r l -> m l (Atom r l))
-> m n (Atom r n)
buildMap xs f = do
TabTy d (_:>t) _ <- return $ getType xs
buildFor noHint Fwd (IxType t d) \i ->
TabPi t <- return $ getType xs
buildFor noHint Fwd (tabIxType t) \i ->
tabApp (sink xs) (Var i) >>= f

unzipTab :: (Emits n, Builder r m) => Atom r n -> m n (Atom r n, Atom r n)
Expand Down Expand Up @@ -857,8 +857,8 @@ zeroAt ty = liftEmitBuilder $ go ty where
go = \case
BaseTy bt -> return $ Con $ Lit $ zeroLit bt
ProdTy tys -> ProdVal <$> mapM go tys
TabTy d (b:>t) bodyTy -> buildFor (getNameHint b) Fwd (IxType t d) \i ->
go =<< applySubst (b @> SubstVal (Var i)) bodyTy
TabPi tabPi -> buildFor (getNameHint tabPi) Fwd (tabIxType tabPi) \i ->
go =<< instantiateTabPiTy (sink tabPi) (Var i)
_ -> unreachable
zeroLit bt = case bt of
Scalar Float64Type -> Float64Lit 0.0
Expand Down Expand Up @@ -902,8 +902,8 @@ tangentBaseMonoidFor ty = do
addTangent :: (Emits n, SBuilder m) => SAtom n -> SAtom n -> m n (SAtom n)
addTangent x y = do
case getType x of
TabTy d (b:>t) _ ->
liftEmitBuilder $ buildFor (getNameHint b) Fwd (IxType t d) \i -> do
TabPi t ->
liftEmitBuilder $ buildFor (getNameHint t) Fwd (tabIxType t) \i -> do
bindM2 addTangent (tabApp (sink x) (Var i)) (tabApp (sink y) (Var i))
TC con -> case con of
BaseType (Scalar _) -> emitOp $ BinOp FAdd x y
Expand Down
6 changes: 5 additions & 1 deletion src/lib/CheapReduction.hs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ module CheapReduction
, unwrapLeadingNewtypesType, wrapNewtypesData, liftSimpAtom, liftSimpType
, liftSimpFun, makeStructRepVal, NonAtomRenamer (..), Visitor (..), VisitGeneric (..)
, visitAtomPartial, visitTypePartial, visitAtomDefault, visitTypeDefault, Visitor2
, visitBinders, visitPiDefault, visitAlt, toAtomVar, instantiatePiTy
, visitBinders, visitPiDefault, visitAlt, toAtomVar, instantiatePiTy, instantiateTabPiTy
, bindersToVars, bindersToAtoms)
where

Expand Down Expand Up @@ -474,6 +474,10 @@ instantiatePiTy :: (EnvReader m, IRRep r) => PiType r n -> [Atom r n] -> m n (Ef
instantiatePiTy (PiType bs effTy) xs = do
applySubst (bs @@> (SubstVal <$> xs)) effTy

instantiateTabPiTy :: (EnvReader m, IRRep r) => TabPiType r n -> Atom r n -> m n (Type r n)
instantiateTabPiTy (TabPiType _ b resultTy) x = do
applySubst (b @> SubstVal x) resultTy

-- Returns a representation type (type of an TypeCon-typed Newtype payload)
-- given a list of instantiated DataConDefs.
dataDefRep :: DataConDefs n -> CType n
Expand Down
15 changes: 5 additions & 10 deletions src/lib/CheckType.hs
Original file line number Diff line number Diff line change
Expand Up @@ -788,11 +788,7 @@ checkTabApp ty (i:rest) = do
resultTy' <- applySubst (b@>SubstVal i') resultTy
checkTabApp resultTy' rest

checkArgTys
:: (Typer m r, SubstB AtomSubstVal b, BindsNames b, BindsOneAtomName r b, IRRep r)
=> Nest b o o'
-> [Atom r o]
-> m i o ()
checkArgTys :: (Typer m r, IRRep r) => Nest (Binder r) o o' -> [Atom r o] -> m i o ()
checkArgTys Empty [] = return ()
checkArgTys (Nest b bs) (x:xs) = do
dropSubst $ x |: binderType b
Expand Down Expand Up @@ -930,15 +926,14 @@ checkedInstantiateTyConDef (TyConDef _ _ bs cons) (TyConParams _ xs) = do
checkedApplyNaryAbs (Abs bs cons) xs

checkedApplyNaryAbs
:: forall b r e o m
. ( BindsOneAtomName r b, EnvReader m, Fallible1 m, SinkableE e
, SubstE AtomSubstVal e, IRRep r, SubstB AtomSubstVal b)
=> Abs (Nest b) e o -> [Atom r o] -> m o (e o)
:: forall r e o m
. ( EnvReader m, Fallible1 m, SinkableE e , SubstE AtomSubstVal e, IRRep r)
=> Abs (Nest (Binder r)) e o -> [Atom r o] -> m o (e o)
checkedApplyNaryAbs (Abs bsTop e) xsTop = do
go (EmptyAbs bsTop) xsTop
applySubst (bsTop@@>(SubstVal<$>xsTop)) e
where
go :: EmptyAbs (Nest b) o -> [Atom r o] -> m o ()
go :: EmptyAbs (Nest (Binder r)) o -> [Atom r o] -> m o ()
go (Abs Empty UnitE) [] = return ()
go (Abs (Nest b bs) UnitE) (x:xs) = do
checkAlphaEq (binderType b) (getType x)
Expand Down
4 changes: 2 additions & 2 deletions src/lib/Export.hs
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,8 @@ parseTabTy = go []
NewtypeTyCon Nat -> return $ Just $ RectContArrayPtr IdxRepScalarBaseTy shape
TabTy d (b:>ixty) a -> do
maybeN <- case IxType ixty d of
(IxType (NewtypeTyCon (Fin n)) _) -> return $ Just n
(IxType _ (IxDictRawFin n)) -> return $ Just n
IxType (NewtypeTyCon (Fin n)) _ -> return $ Just n
IxType _ (IxDictRawFin n) -> return $ Just n
_ -> return Nothing
maybeDim <- case maybeN of
Just (Var v) -> do
Expand Down
36 changes: 18 additions & 18 deletions src/lib/Imp.hs
Original file line number Diff line number Diff line change
Expand Up @@ -366,16 +366,16 @@ toImpRefOp refDest' m = do
ans <- liftBuilderImp $ emitBlock (sink body')
storeAtom accDest ans
False -> case accTy of
TabTy d (b:>t) eltTy -> do
let ixTy = IxType t d
TabPi t -> do
let ixTy = tabIxType t
n <- indexSetSizeImp ixTy
emitLoop noHint Fwd n \i -> do
idx <- unsafeFromOrdinalImp (sink ixTy) i
xElt <- liftBuilderImp $ tabApp (sink x) (sink idx)
yElt <- liftBuilderImp $ tabApp (sink y) (sink idx)
eltTy' <- applySubst (b@>SubstVal idx) eltTy
eltTy <- instantiateTabPiTy (sink t) idx
ithDest <- indexDest (sink accDest) idx
liftMonoidCombine ithDest eltTy' (sink bc) xElt yElt
liftMonoidCombine ithDest eltTy (sink bc) xElt yElt
_ -> error $ "Base monoid type mismatch: can't lift " ++
pprint baseTy ++ " to " ++ pprint accTy

Expand Down Expand Up @@ -578,15 +578,15 @@ toImpTypedHof (TypedHof (EffTy _ resultTy') hof) = do
alphaEq xTy accTy >>= \case
True -> storeAtom accDest x
False -> case accTy of
TabTy d (b:>t) eltTy -> do
let ixTy = IxType t d
TabPi t -> do
let ixTy = tabIxType t
n <- indexSetSizeImp ixTy
emitLoop noHint Fwd n \i -> do
idx <- unsafeFromOrdinalImp (sink ixTy) i
x' <- sinkM x
eltTy' <- applySubst (b@>SubstVal idx) eltTy
eltTy <- instantiateTabPiTy (sink t) idx
ithDest <- indexDest (sink accDest) idx
liftMonoidEmpty ithDest eltTy' x'
liftMonoidEmpty ithDest eltTy x'
_ -> error $ "Base monoid type mismatch: can't lift " ++
pprint xTy ++ " to " ++ pprint accTy

Expand Down Expand Up @@ -1002,11 +1002,11 @@ buildGarbageVal ty =
-- === Operations on dests ===

indexDest :: Emits n => Dest n -> SAtom n -> SubstImpM i n (Dest n)
indexDest (Dest destValTy@(TabTy d (b:>t) eltTy) tree) i = do
eltTy' <- applySubst (b@>SubstVal i) eltTy
ord <- ordinalImp (IxType t d) i
leafTys <- typeToTree destValTy
Dest eltTy' <$> forM (zipTrees leafTys tree) \(leafTy, ptr) -> do
indexDest (Dest (TabPi tabTy) tree) i = do
eltTy <- instantiateTabPiTy tabTy i
ord <- ordinalImp (tabIxType tabTy) i
leafTys <- typeToTree $ TabPi tabTy
Dest eltTy <$> forM (zipTrees leafTys tree) \(leafTy, ptr) -> do
BufferType ixStruct _ <- return $ getRefBufferType leafTy
offset <- computeOffsetImp ixStruct ord
impOffset ptr offset
Expand All @@ -1026,10 +1026,10 @@ indexRepValParam :: Emits n
=> SRepVal n -> SAtom n -> (SType n -> SType n)
-> (IExpr n -> SubstImpM i n (IExpr n))
-> SubstImpM i n (SRepVal n)
indexRepValParam (RepVal tabTy@(TabPi (TabPiType d (b:>t) eltTy)) vals) i tyFunc func = do
eltTy' <- applySubst (b@>SubstVal i) eltTy
ord <- ordinalImp (IxType t d) i
leafTys <- typeToTree tabTy
indexRepValParam (RepVal (TabPi tabTy) vals) i tyFunc func = do
eltTy <- instantiateTabPiTy tabTy i
ord <- ordinalImp (tabIxType tabTy) i
leafTys <- typeToTree (TabPi tabTy)
vals' <- forM (zipTrees leafTys vals) \(leafTy, ptr) -> do
BufferPtr (BufferType ixStruct _) <- return $ getIExprInterpretation leafTy
offset <- computeOffsetImp ixStruct ord
Expand All @@ -1041,7 +1041,7 @@ indexRepValParam (RepVal tabTy@(TabPi (TabPiType d (b:>t) eltTy)) vals) i tyFunc
_ -> func ptr'
-- `func` may have changed the types of the `vals'`. The caller must also
-- supply `tyFunc` to reflect that change in the SType.
return $ RepVal (tyFunc eltTy') vals'
return $ RepVal (tyFunc eltTy) vals'
indexRepValParam _ _ _ _ = error "expected table type"
{-# INLINE indexRepValParam #-}

Expand Down
6 changes: 3 additions & 3 deletions src/lib/Lower.hs
Original file line number Diff line number Diff line change
Expand Up @@ -153,12 +153,12 @@ lowerFor _ _ _ _ _ = error "expected a unary lambda expression"
lowerTabCon :: forall i o. Emits o
=> Maybe (Dest SimpIR o) -> SType i -> [SAtom i] -> LowerM i o (SExpr o)
lowerTabCon maybeDest tabTy elems = do
tabTy'@(TabPi (TabPiType dict (_:>t) _)) <- substM tabTy
TabPi tabTy' <- substM tabTy
dest <- case maybeDest of
Just d -> return d
Nothing -> emitExpr $ PrimOp $ DAMOp $ AllocDest tabTy'
Nothing -> emitExpr $ PrimOp $ DAMOp $ AllocDest $ TabPi tabTy'
Abs bord ufoBlock <- buildAbs noHint IdxRepTy \ord -> do
buildBlock $ unsafeFromOrdinal (sink $ IxType t dict) $ Var $ sink ord
buildBlock $ unsafeFromOrdinal (sink $ tabIxType tabTy') $ Var $ sink ord
-- This is emitting a chain of RememberDest ops to force `dest` to be used
-- linearly, and to force reads of the `Freeze dest'` result not to be
-- reordered in front of the writes.
Expand Down
7 changes: 5 additions & 2 deletions src/lib/QueryTypePure.hs
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,8 @@ instance IRRep r => HasType r (Con r) where

getSuperclassType :: RNest CBinder n l -> Nest CBinder l l' -> Int -> CType n
getSuperclassType _ Empty = error "bad index"
getSuperclassType bsAbove (Nest b bs) = \case
0 -> ignoreHoistFailure $ hoist bsAbove $ binderType b
getSuperclassType bsAbove (Nest b@(_:>t) bs) = \case
0 -> ignoreHoistFailure $ hoist bsAbove t
i -> getSuperclassType (RNest bsAbove b) bs (i-1)

instance IRRep r => HasType r (Expr r) where
Expand Down Expand Up @@ -213,6 +213,9 @@ rawStrType = case newName "n" of
rawFinTabType :: IRRep r => Atom r n -> Type r n -> Type r n
rawFinTabType n eltTy = IxType IdxRepTy (IxDictRawFin n) ==> eltTy

tabIxType :: TabPiType r n -> IxType r n
tabIxType (TabPiType d (_:>t) _) = IxType t d

typesAsBinderNest
:: (SinkableE e, HoistableE e, IRRep r)
=> [Type r n] -> e n -> Abs (Nest (Binder r)) e n
Expand Down
8 changes: 4 additions & 4 deletions src/lib/RuntimePrint.hs
Original file line number Diff line number Diff line change
Expand Up @@ -185,8 +185,8 @@ bufferTy h = do
extendBuffer :: (Emits n, CBuilder m) => CAtom n -> CAtom n -> m n ()
extendBuffer buf tab = do
RefTy h _ <- return $ getType buf
TabTy d (_:>t) _ <- return $ getType tab
n <- applyIxMethodCore Size (IxType t d) []
TabPi t <- return $ getType tab
n <- applyIxMethodCore Size (tabIxType t) []
void $ applyPreludeFunction "stack_extend_internal" [n, h, buf, tab]

-- argument has type `Word8`
Expand Down Expand Up @@ -237,8 +237,8 @@ forEachTabElt
-> (forall l. (Emits l, DExt n l) => CAtom l -> CAtom l -> m l ())
-> m n ()
forEachTabElt tab cont = do
TabTy d (_:>t) _ <- return $ getType tab
let ixTy = IxType t d
TabPi t <- return $ getType tab
let ixTy = tabIxType t
void $ buildFor "i" Fwd ixTy \i -> do
x <- tabApp (sink tab) (Var i)
i' <- applyIxMethodCore Ordinal (sink ixTy) [Var i]
Expand Down
10 changes: 5 additions & 5 deletions src/lib/Simplify.hs
Original file line number Diff line number Diff line change
Expand Up @@ -153,12 +153,12 @@ getRepType ty = go ty where
x <- liftSimpAtom (sink l) (Var $ binderVar b')
r' <- go =<< applySubst (b@>SubstVal x) r
return $ DepPairTy $ DepPairType expl b' r'
TabPi (TabPiType d (b:>t) bodyTy) -> do
let ixTy = IxType t d
TabPi tabTy -> do
let ixTy = tabIxType tabTy
IxType t' d' <- simplifyIxType ixTy
withFreshBinder (getNameHint b) t' \b' -> do
withFreshBinder (getNameHint tabTy) t' \b' -> do
x <- liftSimpAtom (sink $ ixTypeType ixTy) (Var $ binderVar b')
bodyTy' <- go =<< applySubst (b@>SubstVal x) bodyTy
bodyTy' <- go =<< instantiateTabPiTy (sink tabTy) x
return $ TabPi $ TabPiType d' b' bodyTy'
NewtypeTyCon con -> do
(_, ty') <- unwrapNewtypeType con
Expand Down Expand Up @@ -1025,7 +1025,7 @@ simplifyCustomLinearization (Abs runtimeBs staticArgs) actives rule = do
return $ activeArg':rest
buildTangentArgs _ _ _ = error "zip error"

fromNonDepNest :: (HoistableB b, BindsOneAtomName CoreIR b) => Nest b n l -> [CType n]
fromNonDepNest :: Nest CBinder n l -> [CType n]
fromNonDepNest Empty = []
fromNonDepNest (Nest b bs) =
case ignoreHoistFailure $ hoist b (Abs bs UnitE) of
Expand Down
33 changes: 8 additions & 25 deletions src/lib/Types/Core.hs
Original file line number Diff line number Diff line change
Expand Up @@ -897,17 +897,13 @@ data LinearizationSpec (n::S) =
LinearizationSpec (TopFunName n) [Active]
deriving (Show, Generic)

-- === BindsOneAtomName ===
-- === Binder utils ===

class BindsOneName b (AtomNameC r) => BindsOneAtomName (r::IR) (b::B) | b -> r where
binderType :: b n l -> Type r n
binderVar :: DExt n l => b n l -> AtomVar r l
binderType :: Binder r n l -> Type r n
binderType (_:>ty) = ty

bindersTypes :: (IRRep r, Distinct l, ProvesExt b, BindsNames b, BindsOneAtomName r b)
=> Nest b n l -> [Type r l]
bindersTypes Empty = []
bindersTypes n@(Nest b bs) = ty : bindersTypes bs
where ty = withExtEvidence n $ sink (binderType b)
binderVar :: (IRRep r, DExt n l) => Binder r n l -> AtomVar r l
binderVar (b:>ty) = AtomVar (binderName b) (sink ty)

nestToAtomVars :: (Distinct l, Ext n l, IRRep r)
=> Nest (Binder r) n l -> [AtomVar r l]
Expand All @@ -916,14 +912,6 @@ nestToAtomVars = \case
Nest b bs -> withExtEvidence b $ withSubscopeDistinct bs $
sink (binderVar b) : nestToAtomVars bs

instance IRRep r => BindsOneAtomName r (BinderP (AtomNameC r) (Type r)) where
binderType (_ :> ty) = ty
binderVar (b:>t) = AtomVar (binderName b) (sink t)

toBinderNest :: BindsOneAtomName r b => Nest b n l -> Nest (Binder r) n l
toBinderNest Empty = Empty
toBinderNest (Nest b bs) = Nest (asNameBinder b :> binderType b) (toBinderNest bs)

-- === ToBinding ===

atomBindingToBinding :: AtomBinding r n -> Binding (AtomNameC r) n
Expand Down Expand Up @@ -957,14 +945,6 @@ instance (ToBinding e1 c, ToBinding e2 c) => ToBinding (EitherE e1 e2) c where
toBinding (LeftE e) = toBinding e
toBinding (RightE e) = toBinding e

-- === HasArgType ===

class HasArgType (e::E) (r::IR) | e -> r where
argType :: e n -> Type r n

instance HasArgType (TabPiType r) r where
argType (TabPiType _ (_:>ty) _) = ty

-- === Pattern synonyms ===

-- XXX: only use this pattern when you're actually expecting a type. If it's
Expand Down Expand Up @@ -2055,6 +2035,9 @@ instance IRRep r => AlphaEqE (TabPiType r) where
instance IRRep r => AlphaHashableE (TabPiType r) where
hashWithSaltE env salt (TabPiType _ b t) = hashWithSaltE env salt $ Abs b t

instance HasNameHint (TabPiType r n) where
getNameHint (TabPiType _ b _) = getNameHint b

instance IRRep r => SinkableE (TabPiType r)
instance IRRep r => HoistableE (TabPiType r)
instance IRRep r => RenameE (TabPiType r)
Expand Down

0 comments on commit b274115

Please sign in to comment.