Skip to content

Commit

Permalink
Remove MExtend and MAsk
Browse files Browse the repository at this point in the history
  • Loading branch information
dougalm committed Apr 23, 2024
1 parent d5997a8 commit e49f882
Show file tree
Hide file tree
Showing 24 changed files with 66 additions and 346 deletions.
1 change: 0 additions & 1 deletion dex.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ library
, ConcreteSyntax
, Core
, Err
, Export
, Generalize
, Imp
, ImpToLLVM
Expand Down
7 changes: 0 additions & 7 deletions src/lib/AbstractSyntax.hs
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,6 @@ decl ann (WithSrcs sid _ d) = WithSrcB sid <$> case d of
CLet binder rhs -> do
(p, ty) <- patOptAnn binder
ULet ann p ty <$> asExpr <$> block rhs
CBind _ _ -> throw sid TopLevelArrowBinder
CDefDecl def -> do
(name, lam) <- aDef def
return $ ULet ann (fromSourceNameW name) Nothing (WithSrcE sid (ULam lam))
Expand Down Expand Up @@ -382,12 +381,6 @@ blockDecls [] = error "shouldn't have empty list of decls"
blockDecls [WithSrcs sid _ d] = case d of
CExpr g -> (Empty,) <$> expr g
_ -> throw sid BlockWithoutFinalExpr
blockDecls (WithSrcs sid _ (CBind b rhs):ds) = do
b' <- binderOptTy Explicit b
rhs' <- asExpr <$> block rhs
body <- block $ IndentedBlock sid ds -- Not really the right SrcId
let lam = ULam $ ULamExpr (UnaryNest b') ExplicitApp Nothing body
return (Empty, WithSrcE sid $ extendAppRight rhs' (WithSrcE sid lam))
blockDecls (d:ds) = do
d' <- decl PlainLet d
(ds', e) <- blockDecls ds
Expand Down
28 changes: 10 additions & 18 deletions src/lib/Builder.hs
Original file line number Diff line number Diff line change
Expand Up @@ -806,13 +806,6 @@ maybeTangentType' ty = case ty of
_ -> empty
where rec = maybeTangentType'

tangentBaseMonoidFor :: (Emits n, SBuilder m) => SType n -> m n (BaseMonoid SimpIR n)
tangentBaseMonoidFor ty = do
zero <- zeroAt ty
adder <- liftBuilder $ buildBinaryLamExpr (noHint, ty) (noHint, ty) \x y ->
addTangent (toAtom x) (toAtom y)
return $ BaseMonoid zero adder

addTangent :: (Emits n, SBuilder m) => SAtom n -> SAtom n -> m n (SAtom n)
addTangent x y = do
case getTyCon x of
Expand Down Expand Up @@ -934,7 +927,7 @@ projectStruct i x = do

projectStructRef :: (Builder CoreIR m, Emits n) => Int -> CAtom n -> m n (CAtom n)
projectStructRef i x = do
RefTy _ valTy <- return $ getType x
RefTy valTy <- return $ getType x
projs <- getStructProjections i valTy
applyProjectionsRef projs x
{-# INLINE projectStructRef #-}
Expand Down Expand Up @@ -973,16 +966,15 @@ mkBlock (Abs decls body) = do
return $ Block effTy block

blockEffTy :: (EnvReader m, IRRep r) => Block r n -> m n (EffTy r n)
blockEffTy _ = undefined
-- blockEffTy block = liftEnvReaderM $ refreshAbs block \decls result -> do
-- effs <- declsEffects decls mempty
-- return $ ignoreHoistFailure $ hoist decls $ EffTy effs $ getType result
-- where
-- declsEffects :: IRRep r => Nest (Decl r) n l -> EffectRow r l -> EnvReaderM l (EffectRow r l)
-- declsEffects Empty !acc = return acc
-- declsEffects n@(Nest (Let _ (DeclBinding _ expr)) rest) !acc = withExtEvidence n do
-- expr' <- sinkM expr
-- declsEffects rest $ acc <> getEffects expr'
blockEffTy block = liftEnvReaderM $ refreshAbs block \decls result -> do
effs <- declsEffects decls mempty
return $ ignoreHoistFailure $ hoist decls $ EffTy effs $ getType result
where
declsEffects :: IRRep r => Nest (Decl r) n l -> Effects r l -> EnvReaderM l (Effects r l)
declsEffects Empty !acc = return acc
declsEffects n@(Nest (Let _ (DeclBinding _ expr)) rest) !acc = withExtEvidence n do
expr' <- sinkM expr
declsEffects rest $ acc <> getEffects expr'

mkApp :: EnvReader m => CAtom n -> [CAtom n] -> m n (CExpr n)
mkApp f xs = do
Expand Down
6 changes: 1 addition & 5 deletions src/lib/CheapReduction.hs
Original file line number Diff line number Diff line change
Expand Up @@ -442,9 +442,6 @@ instance IRRep r => VisitGeneric (Hof r) r where
Linearize lam x -> Linearize <$> visitGeneric lam <*> visitGeneric x
Transpose lam x -> Transpose <$> visitGeneric lam <*> visitGeneric x

instance IRRep r => VisitGeneric (BaseMonoid r) r where
visitGeneric (BaseMonoid x lam) = BaseMonoid <$> visitGeneric x <*> visitGeneric lam

instance IRRep r => VisitGeneric (Effects r) r where
visitGeneric = \case
Pure -> return Pure
Expand Down Expand Up @@ -550,7 +547,7 @@ instance IRRep r => VisitGeneric (TyCon r) r where
BaseType bt -> return $ BaseType bt
ProdType tys -> ProdType <$> mapM visitGeneric tys
SumType tys -> SumType <$> mapM visitGeneric tys
RefType h t -> RefType h <$> visitGeneric t
RefType t -> RefType <$> visitGeneric t
TabPi t -> TabPi <$> visitGeneric t
DepPairTy t -> DepPairTy <$> visitGeneric t
TypeKind -> return TypeKind
Expand Down Expand Up @@ -686,7 +683,6 @@ instance SubstE AtomSubstVal IExpr
instance SubstE AtomSubstVal RepVal
instance SubstE AtomSubstVal TyConParams
instance SubstE AtomSubstVal DataConDef
instance IRRep r => SubstE AtomSubstVal (BaseMonoid r)
instance IRRep r => SubstE AtomSubstVal (TypedHof r)
instance IRRep r => SubstE AtomSubstVal (Hof r)
instance IRRep r => SubstE AtomSubstVal (TyCon r)
Expand Down
20 changes: 6 additions & 14 deletions src/lib/CheckType.hs
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ instance IRRep r => CheckableE r (Stuck r) where
Var name -> do
name' <- checkE name
case getType name' of
RawRefTy _ -> affineUsed $ atomVarName name'
RefTy _ -> affineUsed $ atomVarName name'
_ -> return ()
return $ Var name'
StuckUnwrap x -> do
Expand Down Expand Up @@ -372,7 +372,7 @@ instance IRRep r => CheckableE r (TyCon r) where
BaseType b -> return $ BaseType b
ProdType tys -> ProdType <$> mapM checkE tys
SumType cs -> SumType <$> mapM checkE cs
RefType r a -> RefType r <$> checkE a
RefType a -> RefType <$> checkE a
TypeKind -> return TypeKind
Pi t -> Pi <$> checkE t
TabPi t -> TabPi <$> checkE t
Expand Down Expand Up @@ -463,23 +463,18 @@ instance IRRep r => CheckableE r (PrimOp r) where
MiscOp op -> MiscOp <$> checkE op
MemOp op -> MemOp <$> checkE op
RefOp ref m -> do
(ref', TyCon (RefType h s)) <- checkAndGetType ref
(ref', TyCon (RefType s)) <- checkAndGetType ref
m' <- case m of
MGet -> return MGet
MPut x -> do
x' <- x|:s
return $ MPut x'
MAsk -> return MAsk
MExtend b x -> do
b' <- checkE b
x' <- x|:s
return $ MExtend b' x'
IndexRef givenTy i -> do
givenTy' <- checkE givenTy
TyCon (TabPi tabTy) <- return s
i' <- checkE i
eltTy' <- checkInstantiation tabTy [i']
checkTypesEq givenTy' (TyCon $ RefType h eltTy')
checkTypesEq givenTy' (TyCon $ RefType eltTy')
return $ IndexRef givenTy' i'
ProjRef givenTy p -> do
givenTy' <- checkE givenTy
Expand All @@ -490,16 +485,13 @@ instance IRRep r => CheckableE r (PrimOp r) where
UnwrapNewtype -> do
TyCon (NewtypeTyCon tc) <- return s
snd <$> unwrapNewtypeType tc
checkTypesEq givenTy' (TyCon $ RefType h resultEltTy)
checkTypesEq givenTy' (TyCon $ RefType resultEltTy)
return $ ProjRef givenTy' p
return $ RefOp ref' m'

instance IRRep r => CheckableE r (EffTy r) where
checkE (EffTy effs ty) = EffTy <$> checkE effs <*> checkE ty

instance IRRep r => CheckableE r (BaseMonoid r) where
checkE = renameM -- TODO: check

instance IRRep r => CheckableE r (MemOp r) where
checkE = \case
IOAlloc n -> do
Expand Down Expand Up @@ -592,7 +584,7 @@ instance IRRep r => CheckableE r (VectorOp r) where
return $ VectorIdx tbl' i' ty'
VectorSubref ref i ty -> do
ref' <- checkE ref
RefTy _ (TabTy _ b (BaseTy (Scalar sbt))) <- return $ getType ref'
RefTy (TabTy _ b (BaseTy (Scalar sbt))) <- return $ getType ref'
i' <- i |: binderType b
ty'@(BaseTy (Vector _ sbt')) <- checkE ty
unless (sbt == sbt') $ throwInternal "Scalar type mismatch"
Expand Down
1 change: 0 additions & 1 deletion src/lib/ConcreteSyntax.hs
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,6 @@ simpleLet = do
next <- nextChar
case next of
'=' -> sym "=" >> CLet lhs <$> cBlock
'<' -> sym "<-" >> CBind lhs <$> cBlock
_ -> return $ CExpr lhs

instanceDef :: Bool -> Parser CInstanceDef
Expand Down
2 changes: 1 addition & 1 deletion src/lib/Generalize.hs
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ traverseTyParams (TyCon ty) f = liftM TyCon $ getDistinct >>= \Distinct -> case
return $ TabPi $ TabPiType d' (b':>iTy') resultTy'
BaseType b -> return $ BaseType b
ProdType tys -> ProdType <$> forM tys \t -> f' TypeParam TyKind t
RefType _ _ -> error "not implemented" -- how should we handle the ParamRole for the heap parameter?
RefType _ -> error "not implemented"
SumType tys -> SumType <$> forM tys \t -> f' TypeParam TyKind t
TypeKind -> return TypeKind
NewtypeTyCon con -> NewtypeTyCon <$> case con of
Expand Down
59 changes: 6 additions & 53 deletions src/lib/Imp.hs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ toImpFunction cc (TopLam True destTy lam) = do
argAtoms <- interpretImpArgs (sink $ EmptyAbs bs) vs
extendSubst (bs @@> (SubstVal <$> argAtoms)) do
dest <- case binderType destB of
RefTy _ ansTy -> allocDestUnmanaged =<< substM ansTy
RefTy ansTy -> allocDestUnmanaged =<< substM ansTy
_ -> error "Expected a reference type for body destination"
extendSubst (destB @> SubstVal (destToAtom dest)) do
void $ translateExpr body
Expand Down Expand Up @@ -327,12 +327,6 @@ toImpRefOp :: Emits o
toImpRefOp refDest' m = do
refDest <- atomToDest =<< substM refDest'
substM m >>= \case
MAsk -> loadAtom refDest
MExtend (BaseMonoid _ combine) x -> do
xTy <- return $ getType x
refVal <- loadAtom refDest
liftMonoidCombine refDest xTy combine refVal x
return UnitVal
MPut x -> storeAtom refDest x >> return UnitVal
MGet -> do
Dest resultTy _ <- return refDest
Expand All @@ -343,32 +337,6 @@ toImpRefOp refDest' m = do
loadAtom dest
IndexRef _ i -> destToAtom <$> indexDest refDest i
ProjRef _ ~(ProjectProduct i) -> return $ destToAtom $ projectDest i refDest
where
liftMonoidCombine :: Emits o
=> (Dest o) -> SType o -> LamExpr SimpIR o
-> SAtom o -> SAtom o -> SubstImpM n o ()
liftMonoidCombine accDest accTy bc x y = do
LamExpr (Nest (_:>baseTy) _) _ <- return bc
alphaEq accTy baseTy >>= \case
-- Immediately beta-reduce, beacuse Imp doesn't reduce non-table applications.
True -> do
BinaryLamExpr xb yb body <- return bc
body' <- applySubst (xb @> SubstVal x <.> yb @> SubstVal y) body
ans <- liftBuilderImp $ emit (sink body')
storeAtom accDest ans
False -> case accTy of
TyCon (TabPi t) -> do
let ixTy = tabIxType t
n <- indexSetSizeImp ixTy
emitLoop noHint Fwd n \i -> do
idx <- unsafeFromOrdinalImp (sink ixTy) i
xElt <- liftBuilderImp $ tabApp (sink x) (sink idx)
yElt <- liftBuilderImp $ tabApp (sink y) (sink idx)
eltTy <- instantiate (sink t) [idx]
ithDest <- indexDest (sink accDest) idx
liftMonoidCombine ithDest eltTy (sink bc) xElt yElt
_ -> error $ "Base monoid type mismatch: can't lift " ++
pprint baseTy ++ " to " ++ pprint accTy

toImpOp :: forall i o . Emits o => PrimOp SimpIR i -> SubstImpM i o (SAtom o)
toImpOp op = case op of
Expand Down Expand Up @@ -399,7 +367,7 @@ toImpVectorOp = \case
refi <- destToAtom <$> indexDest refDest i
refi' <- fromScalarAtom refi
resultVal <- castPtrToVectorType refi' (toIVectorType vty)
repValAtom $ RepVal (RefTy State vty) (Leaf resultVal)
repValAtom $ RepVal (RefTy vty) (Leaf resultVal)
where
returnIExprVal x = return $ toScalarAtom x

Expand Down Expand Up @@ -605,7 +573,7 @@ typeToTree tyTop = return $ go REmpty tyTop
go ctx (TyCon con) = case con of
BaseType b -> Leaf $ LeafType (unRNest ctx) b
TabPi (TabPiType d b bodyTy) -> go (RNest ctx (TabCtx (PairB (LiftB d) b))) bodyTy
RefType _ t -> go (RNest ctx RefCtx) t
RefType t -> go (RNest ctx RefCtx) t
DepPairTy (DepPairType _ (b:>t1) (t2)) -> do
let tree1 = rec t1
let tree2 = go (RNest ctx (DepPairCtx (JustB (b:>t1)))) t2
Expand Down Expand Up @@ -639,7 +607,7 @@ valueToTree (RepVal tyTop valTop) = do
go ctx (TyCon ty) val = case ty of
BaseType b -> return $ Leaf $ LeafType (unRNest ctx) b
TabPi (TabPiType d b bodyTy) -> go (RNest ctx (TabCtx (PairB (LiftB d) b))) bodyTy val
RefType _ t -> go (RNest ctx RefCtx) t val
RefType t -> go (RNest ctx RefCtx) t val
DepPairTy (DepPairType _ (b:>t1) (t2)) -> case val of
Branch [v1, v2] -> do
case allDepPairCtxs (unRNest ctx) of
Expand Down Expand Up @@ -795,11 +763,11 @@ atomToRepVal x = RepVal (getType x) <$> go x where
-- from the dest. This version is not that. It just lifts a dest into an atom of
-- type `Ref _`.
destToAtom :: Dest n -> SAtom n
destToAtom (Dest valTy tree) = toAtom $ RepVal (RefTy State valTy) tree
destToAtom (Dest valTy tree) = toAtom $ RepVal (RefTy valTy) tree

atomToDest :: EnvReader m => SAtom n -> m n (Dest n)
atomToDest (Stuck _ (RepValAtom val)) = do
(RepVal ~(RefTy _ valTy) valTree) <- return val
(RepVal ~(RefTy valTy) valTree) <- return val
return $ Dest valTy valTree
atomToDest atom = error $ "Expected a non-var atom of type `RawRef _`, got: " ++ pprint atom
{-# INLINE atomToDest #-}
Expand Down Expand Up @@ -1276,21 +1244,6 @@ ordinalImp (IxType _ (DictCon dict)) i = fromScalarAtom =<< case dict of
IxSpecialized d params -> do
appSpecializedIxMethod d Ordinal (params ++ [i])

unsafeFromOrdinalImp :: Emits n => IxType SimpIR n -> IExpr n -> SubstImpM i n (SAtom n)
unsafeFromOrdinalImp (IxType _ (DictCon dict)) i = do
let i' = toScalarAtom i
case dict of
IxRawFin _ -> return i'
IxSpecialized d params ->
appSpecializedIxMethod d UnsafeFromOrdinal (params ++ [i'])

indexSetSizeImp :: Emits n => IxType SimpIR n -> SubstImpM i n (IExpr n)
indexSetSizeImp (IxType _ (DictCon dict)) = do
fromScalarAtom =<< case dict of
IxRawFin n -> return n
IxSpecialized d params ->
appSpecializedIxMethod d Size (params ++ [])

appSpecializedIxMethod :: Emits n => SpecDictName n -> IxMethod -> [SAtom n] -> SubstImpM i n (SAtom n)
appSpecializedIxMethod d method args = do
SpecializedDict _ (Just fs) <- lookupSpecDict d
Expand Down
12 changes: 5 additions & 7 deletions src/lib/Inference.hs
Original file line number Diff line number Diff line change
Expand Up @@ -701,8 +701,8 @@ getFieldDefs sid ty = case ty of
(FieldName field, FieldDotMethod f params)
return $ M.fromList $ concat projFields ++ methodFields
ADTCons _ -> noFields
RefType _ valTy -> case valTy of
RefTy _ _ -> noFields
RefType valTy -> case valTy of
RefTy _ -> noFields
_ -> do
valFields <- getFieldDefs sid valTy
return $ M.filter isProj valFields
Expand All @@ -720,7 +720,7 @@ projectField i x = case getType x of
TyCon con -> case con of
ProdType _ -> proj i x
NewtypeTyCon _ -> projectStruct i x
RefType _ valTy -> case valTy of
RefType valTy -> case valTy of
TyCon (ProdType _) -> getProjRef (ProjectProduct i) x
TyCon (NewtypeTyCon _) -> projectStructRef i x
_ -> bad
Expand Down Expand Up @@ -1029,14 +1029,12 @@ matchPrimApp = \case
UMemOp op -> \x -> emit =<< MemOp <$> matchGenericOp op x
UBinOp op -> \case ~[x, y] -> emit $ BinOp op x y
UUnOp op -> \case ~[x] -> emit $ UnOp op x
UMAsk -> \case ~[r] -> emit $ RefOp r MAsk
UMGet -> \case ~[r] -> emit $ RefOp r MGet
UMPut -> \case ~[r, x] -> emit $ RefOp r $ MPut x
UIndexRef -> \case ~[r, i] -> indexRef r i
UApplyMethod i -> \case ~(d:args) -> emit =<< mkApplyMethod (fromJust $ toMaybeDict d) i args
ULinearize -> \case ~[f, x] -> do f' <- lam1 f; emitHof $ Linearize f' x
UTranspose -> \case ~[f, x] -> do f' <- lam1 f; emitHof $ Transpose f' x
UMExtend -> \case ~[r, z, f, x] -> do f' <- lam2 f; emit $ RefOp r $ MExtend (BaseMonoid z f') x
p -> \case xs -> throwInternal $ "Bad primitive application: " ++ show (p, xs)
where
lam2 :: Fallible m => CAtom n -> m (LamExpr CoreIR n)
Expand Down Expand Up @@ -1704,8 +1702,8 @@ instance Unifiable (TyCon CoreIR) where
{ SumType ts' <- matchit; unifyLists ts ts'}
( ProdType ts ) -> do
{ ProdType ts' <- matchit; unifyLists ts ts'}
( RefType h t ) -> do
{ RefType h' t' <- matchit; guard (h == h'); unify t t'}
( RefType t ) -> do
{ RefType t' <- matchit; unify t t'}
( DepPairTy t ) -> do
{ DepPairTy t' <- matchit; unify t t'}
where matchit = return t2
Expand Down
7 changes: 0 additions & 7 deletions src/lib/Lexing.hs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ data KeyWord = DefKW | ForKW | For_KW | RofKW | Rof_KW | CaseKW | OfKW
| InstanceKW | GivenKW | WithKW | SatisfyingKW
| IfKW | ThenKW | ElseKW | DoKW
| ImportKW | ForeignKW | NamedInstanceKW
| EffectKW | HandlerKW | JmpKW | CtlKW | ReturnKW | ResumeKW
| CustomLinearizationKW | CustomLinearizationSymbolicKW | PassKW
deriving (Enum)

Expand All @@ -118,12 +117,6 @@ keyWordToken = \case
DoKW -> "do"
ImportKW -> "import"
ForeignKW -> "foreign"
EffectKW -> "effect"
HandlerKW -> "handler"
JmpKW -> "jmp"
CtlKW -> "ctl"
ReturnKW -> "return"
ResumeKW -> "resume"
CustomLinearizationKW -> "custom-linearization"
CustomLinearizationSymbolicKW -> "custom-linearization-symbolic"
PassKW -> "pass"
Expand Down
Loading

0 comments on commit e49f882

Please sign in to comment.