Skip to content

Commit

Permalink
Remove some more uses of @@>
Browse files Browse the repository at this point in the history
  • Loading branch information
dougalm committed Jun 27, 2023
1 parent 43e42ec commit 336bb04
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 81 deletions.
6 changes: 2 additions & 4 deletions src/lib/Builder.hs
Original file line number Diff line number Diff line change
Expand Up @@ -1134,8 +1134,7 @@ naryTopAppInlined :: (Builder SimpIR m, Emits n) => TopFunName n -> [SAtom n] ->
naryTopAppInlined f xs = do
TopFunBinding f' <- lookupEnv f
case f' of
DexTopFun _ (TopLam _ _ (LamExpr bs body)) _ ->
applySubst (bs@@>(SubstVal<$>xs)) body >>= emitBlock
DexTopFun _ lam _ -> instantiate lam xs >>= emitBlock
_ -> naryTopApp f xs
{-# INLINE naryTopAppInlined #-}

Expand Down Expand Up @@ -1194,8 +1193,7 @@ applyIxMethod dict method args = case dict of
IxDictSpecialized _ d params -> do
SpecializedDict _ maybeFs <- lookupSpecDict d
Just fs <- return maybeFs
TopLam _ _ (LamExpr bs body) <- return $ fs !! fromEnum method
emitBlock =<< applySubst (bs @@> fmap SubstVal (params ++ args)) body
instantiate (fs !! fromEnum method) (params ++ args) >>= emitBlock

unsafeFromOrdinal :: (SBuilder m, Emits n) => IxType SimpIR n -> Atom SimpIR n -> m n (Atom SimpIR n)
unsafeFromOrdinal (IxType _ dict) i = applyIxMethod dict UnsafeFromOrdinal [i]
Expand Down
41 changes: 31 additions & 10 deletions src/lib/CheapReduction.hs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ module CheapReduction
, unwrapLeadingNewtypesType, wrapNewtypesData, liftSimpAtom, liftSimpType
, liftSimpFun, makeStructRepVal, NonAtomRenamer (..), Visitor (..), VisitGeneric (..)
, visitAtomPartial, visitTypePartial, visitAtomDefault, visitTypeDefault, Visitor2
, visitBinders, visitPiDefault, visitAlt, toAtomVar, instantiate
, bindersToVars, bindersToAtoms)
, visitBinders, visitPiDefault, visitAlt, toAtomVar, instantiate, withInstantiated
, bindersToVars, bindersToAtoms, instantiateNames, withInstantiatedNames)
where

import Control.Applicative
Expand Down Expand Up @@ -242,7 +242,7 @@ cheapReduceDictExpr resultTy d = case d of
args' <- mapM cheapReduceE args
InstanceDef _ _ bs _ body <- lookupInstanceDef instanceName
let InstanceBody superclasses _ = body
applySubst (bs@@>(SubstVal <$> args')) (superclasses !! superclassIx)
instantiate (Abs bs (superclasses !! superclassIx)) args'
child' -> return $ DictCon resultTy $ SuperclassProj child' superclassIx
InstantiatedGiven f xs ->
reduceApp <|> justSubst
Expand Down Expand Up @@ -285,19 +285,16 @@ instance IRRep r => CheaplyReducibleE r (Expr r) (Atom r) where
cheapReduceE dict >>= \case
DictCon _ (InstanceDict instanceName args) -> dropSubst do
args' <- mapM cheapReduceE args
InstanceDef _ _ bs _ (InstanceBody _ methods) <- lookupInstanceDef instanceName
let method = methods !! i
extendSubst (bs@@>(SubstVal <$> args')) do
method' <- cheapReduceE method
def <- lookupInstanceDef instanceName
withInstantiated def args' \(PairE _ (InstanceBody _ methods)) -> do
method' <- cheapReduceE $ methods !! i
cheapReduceApp method' explicitArgs'
_ -> empty
_ -> empty

cheapReduceApp :: CAtom o -> [CAtom o] -> CheapReducerM CoreIR i o (CAtom o)
cheapReduceApp f xs = case f of
Lam (CoreLamExpr _ (LamExpr bs body)) -> do
let subst = bs @@> fmap SubstVal xs
dropSubst $ extendSubst subst $ cheapReduceE body
Lam lam -> dropSubst $ withInstantiated lam xs \body -> cheapReduceE body
_ -> empty

instance IRRep r => CheaplyReducibleE r (IxType r) (IxType r) where
Expand Down Expand Up @@ -476,6 +473,30 @@ instantiate
instantiate e xs = case toAbs e of
Abs bs body -> applySubst (bs @@> (SubstVal <$> xs)) body

-- "lazy" subst-extending version of `instantiate`
withInstantiated
:: (SubstReader AtomSubstVal m, IRRep r, SubstE (SubstVal Atom) body, SinkableE body, ToBindersAbs e body r)
=> e i -> [Atom r o]
-> (forall i'. body i' -> m i' o a)
-> m i o a
withInstantiated e xs cont = case toAbs e of
Abs bs body -> extendSubst (bs @@> (SubstVal <$> xs)) $ cont body

instantiateNames
:: (EnvReader m, IRRep r, RenameE body, SinkableE body, ToBindersAbs e body r)
=> e n -> [AtomName r n] -> m n (body n)
instantiateNames e vs = case toAbs e of
Abs bs body -> applyRename (bs @@> vs) body

-- "lazy" subst-extending version of `instantiateNames`
withInstantiatedNames
:: (SubstReader Name m, IRRep r, RenameE body, SinkableE body, ToBindersAbs e body r)
=> e i -> [AtomName r o]
-> (forall i'. body i' -> m i' o a)
-> m i o a
withInstantiatedNames e vs cont = case toAbs e of
Abs bs body -> extendRenamer (bs @@> vs) $ cont body

-- Returns a representation type (type of an TypeCon-typed Newtype payload)
-- given a list of instantiated DataConDefs.
dataDefRep :: DataConDefs n -> CType n
Expand Down
99 changes: 43 additions & 56 deletions src/lib/QueryType.hs
Original file line number Diff line number Diff line change
Expand Up @@ -73,16 +73,15 @@ blockEff :: (EnvReader m, IRRep r) => Block r n -> m n (EffectRow r n)
blockEff b = blockEffTy b <&> \(EffTy eff _) -> eff

typeOfApp :: (IRRep r, EnvReader m) => Type r n -> [Atom r n] -> m n (Type r n)
typeOfApp (Pi (CorePiType _ _ bs (EffTy _ resultTy))) xs = do
let subst = bs @@> fmap SubstVal xs
applySubst subst resultTy
typeOfApp (Pi piTy) xs = withSubstReaderT $
withInstantiated piTy xs \(EffTy _ ty) -> substM ty
typeOfApp _ _ = error "expected a pi type"

typeOfTabApp :: (IRRep r, EnvReader m) => Type r n -> [Atom r n] -> m n (Type r n)
typeOfTabApp t [] = return t
typeOfTabApp (TabTy _ (b:>_) resultTy) (i:rest) = do
resultTy' <- applySubst (b@>SubstVal i) resultTy
typeOfTabApp resultTy' rest
typeOfTabApp (TabPi tabTy) (i:rest) = do
resultTy <- instantiate tabTy [i]
typeOfTabApp resultTy rest
typeOfTabApp ty _ = error $ "expected a table type. Got: " ++ pprint ty

typeOfApplyMethod :: EnvReader m => CAtom n -> Int -> [CAtom n] -> m n (EffTy CoreIR n)
Expand All @@ -93,23 +92,23 @@ typeOfApplyMethod d i args = do
typeOfDictExpr :: EnvReader m => DictExpr n -> m n (CType n)
typeOfDictExpr e = liftM ignoreExcept $ liftEnvReaderT $ case e of
InstanceDict instanceName args -> do
InstanceDef className _ bs params _ <- lookupInstanceDef instanceName
ClassDef sourceName _ _ _ _ _ _ <- lookupClassDef className
ListE params' <- applySubst (bs @@> map SubstVal args) $ ListE params
return $ DictTy $ DictType sourceName className params'
instanceDef@(InstanceDef className _ _ _ _) <- lookupInstanceDef instanceName
sourceName <- getSourceName <$> lookupClassDef className
PairE (ListE params) _ <- instantiate instanceDef args
return $ DictTy $ DictType sourceName className params
InstantiatedGiven given args -> typeOfApp (getType given) args
SuperclassProj d i -> do
DictTy (DictType _ className params) <- return $ getType d
ClassDef _ _ _ _ bs superclasses _ <- lookupClassDef className
applySubst (bs @@> map SubstVal params) $
getSuperclassType REmpty superclasses i
classDef <- lookupClassDef className
withSubstReaderT $ withInstantiated classDef params \(Abs superclasses _) -> do
substM $ getSuperclassType REmpty superclasses i
IxFin n -> liftM DictTy $ ixDictType $ NewtypeTyCon $ Fin n
DataData ty -> DictTy <$> dataDictType ty

typeOfTopApp :: EnvReader m => TopFunName n -> [SAtom n] -> m n (EffTy SimpIR n)
typeOfTopApp f xs = do
PiType bs effTy <- getTypeTopFun f
applySubst (bs @@> map SubstVal xs) effTy
piTy <- getTypeTopFun f
instantiate piTy xs

typeOfIndexRef :: (EnvReader m, Fallible1 m, IRRep r) => Type r n -> Atom r n -> m n (Type r n)
typeOfIndexRef (TC (RefType h s)) i = do
Expand All @@ -131,25 +130,16 @@ typeOfProjRef (TC (RefType h s)) p = do
typeOfProjRef _ _ = error "expected a reference"

appEffTy :: (IRRep r, EnvReader m) => Type r n -> [Atom r n] -> m n (EffTy r n)
appEffTy (Pi (CorePiType _ _ bs effTy)) xs = do
let subst = bs @@> fmap SubstVal xs
applySubst subst effTy
appEffTy (Pi piTy) xs = instantiate piTy xs
appEffTy t _ = error $ "expected a pi type, got: " ++ pprint t

partialAppType :: (IRRep r, EnvReader m) => Type r n -> [Atom r n] -> m n (Type r n)
partialAppType (Pi (CorePiType appExpl expls bs effTy)) xs = do
(_, expls2) <- return $ splitAt (length xs) expls
PairB bs1 bs2 <- return $ splitNestAt (length xs) bs
let subst = bs1 @@> fmap SubstVal xs
applySubst subst $ Pi $ CorePiType appExpl expls2 bs2 effTy
instantiate (Abs bs1 (Pi $ CorePiType appExpl expls2 bs2 effTy)) xs
partialAppType _ _ = error "expected a pi type"

appEffects :: (EnvReader m, IRRep r) => Type r n -> [Atom r n] -> m n (EffectRow r n)
appEffects (Pi (CorePiType _ _ bs (EffTy effs _))) xs = do
let subst = bs @@> fmap SubstVal xs
applySubst subst effs
appEffects _ _ = error "expected a pi type"

effTyOfHof :: (EnvReader m, IRRep r) => Hof r n -> m n (EffTy r n)
effTyOfHof hof = EffTy <$> hofEffects hof <*> typeOfHof hof

Expand Down Expand Up @@ -222,27 +212,24 @@ getMethodNameType :: EnvReader m => MethodName n -> m n (CType n)
getMethodNameType v = liftEnvReaderM $ lookupEnv v >>= \case
MethodBinding className i -> do
ClassDef _ _ paramNames _ paramBs scBinders methodTys <- lookupClassDef className
refreshAbs (Abs paramBs $ Abs scBinders (methodTys !! i)) \paramBs' (Abs scBinders' piTy) -> do
refreshAbs (Abs paramBs $ Abs scBinders (methodTys !! i)) \paramBs' absPiTy -> do
let params = Var <$> nestToAtomVars paramBs'
dictTy <- DictTy <$> dictType (sink className) params
withFreshBinder noHint dictTy \dictB -> do
scDicts <- getSuperclassDicts (Var $ binderVar dictB)
piTy' <- applySubst (scBinders'@@>(SubstVal<$>scDicts)) piTy
CorePiType appExpl methodExpls methodBs effTy <- return piTy'
CorePiType appExpl methodExpls methodBs effTy <- instantiate (sink absPiTy) scDicts
let paramExpls = paramNames <&> \name -> Inferred name Unify
let expls = paramExpls <> [Inferred Nothing (Synth $ Partial $ succ i)] <> methodExpls
return $ Pi $ CorePiType appExpl expls (paramBs' >>> UnaryNest dictB >>> methodBs) effTy

getMethodType :: EnvReader m => Dict n -> Int -> m n (CorePiType n)
getMethodType dict i = do
getMethodType dict i = liftEnvReaderM $ withSubstReaderT do
~(DictTy (DictType _ className params)) <- return $ getType dict
ClassDef _ _ _ _ paramBs classBs methodTys <- lookupClassDef className
let methodTy = methodTys !! i
superclassDicts <- getSuperclassDicts dict
let subst = ( paramBs @@> map SubstVal params
<.> classBs @@> map SubstVal superclassDicts)
applySubst subst methodTy
{-# INLINE getMethodType #-}
classDef <- lookupClassDef className
withInstantiated classDef params \ab -> do
withInstantiated ab superclassDicts \(ListE methodTys) ->
substM $ methodTys !! i

getTyConNameType :: EnvReader m => TyConName n -> m n (Type CoreIR n)
getTyConNameType v = do
Expand All @@ -252,28 +239,29 @@ getTyConNameType v = do
_ -> return $ Pi $ CorePiType ExplicitApp (snd <$> expls) bs $ EffTy Pure TyKind

getDataConNameType :: EnvReader m => DataConName n -> m n (Type CoreIR n)
getDataConNameType dataCon = liftEnvReaderM do
getDataConNameType dataCon = liftEnvReaderM $ withSubstReaderT do
(tyCon, i) <- lookupDataCon dataCon
lookupTyCon tyCon >>= \case
tyConDef@(TyConDef tcSn _ paramBs ~(ADTCons dataCons)) ->
buildDataConType tyConDef \expls paramBs' paramVs params -> do
DataConDef _ ab _ _ <- applyRename (paramBs @@> paramVs) (dataCons !! i)
refreshAbs ab \dataBs UnitE -> do
let appExpl = case dataBs of Empty -> ImplicitApp
_ -> ExplicitApp
let resultTy = NewtypeTyCon $ UserADTType tcSn (sink tyCon) (sink params)
let dataExpls = nestToList (const $ Explicit) dataBs
return $ Pi $ CorePiType appExpl (expls <> dataExpls) (paramBs' >>> dataBs) (EffTy Pure resultTy)
tyConDef <- lookupTyCon tyCon
buildDataConType tyConDef \expls paramBs' paramVs params -> do
withInstantiatedNames tyConDef paramVs \(ADTCons dataCons) -> do
DataConDef _ ab _ _ <- renameM (dataCons !! i)
refreshAbs ab \dataBs UnitE -> do
let appExpl = case dataBs of Empty -> ImplicitApp
_ -> ExplicitApp
let resultTy = NewtypeTyCon $ UserADTType (getSourceName tyConDef) (sink tyCon) (sink params)
let dataExpls = nestToList (const $ Explicit) dataBs
return $ Pi $ CorePiType appExpl (expls <> dataExpls) (paramBs' >>> dataBs) (EffTy Pure resultTy)

getStructDataConType :: EnvReader m => TyConName n -> m n (CType n)
getStructDataConType tyCon = liftEnvReaderM do
tyConDef@(TyConDef tcSn _ paramBs ~(StructFields fields)) <- lookupTyCon tyCon
getStructDataConType tyCon = liftEnvReaderM $ withSubstReaderT do
tyConDef <- lookupTyCon tyCon
buildDataConType tyConDef \expls paramBs' paramVs params -> do
fieldTys <- forM fields \(_, t) -> applyRename (paramBs @@> paramVs) t
let resultTy = NewtypeTyCon $ UserADTType tcSn (sink tyCon) params
Abs dataBs resultTy' <- return $ typesAsBinderNest fieldTys resultTy
let dataExpls = nestToList (const Explicit) dataBs
return $ Pi $ CorePiType ExplicitApp (expls <> dataExpls) (paramBs' >>> dataBs) (EffTy Pure resultTy')
withInstantiatedNames tyConDef paramVs \(StructFields fields) -> do
fieldTys <- forM fields \(_, t) -> renameM t
let resultTy = NewtypeTyCon $ UserADTType (getSourceName tyConDef) (sink tyCon) params
Abs dataBs resultTy' <- return $ typesAsBinderNest fieldTys resultTy
let dataExpls = nestToList (const Explicit) dataBs
return $ Pi $ CorePiType ExplicitApp (expls <> dataExpls) (paramBs' >>> dataBs) (EffTy Pure resultTy')

buildDataConType
:: (EnvReader m, EnvExtender m)
Expand Down Expand Up @@ -371,8 +359,7 @@ getSuperclassTys :: EnvReader m => DictType n -> m n [CType n]
getSuperclassTys (DictType _ className params) = do
ClassDef _ _ _ _ bs superclasses _ <- lookupClassDef className
forM [0 .. nestLength superclasses - 1] \i -> do
applySubst (bs @@> map SubstVal params) $
getSuperclassType REmpty superclasses i
instantiate (Abs bs $ getSuperclassType REmpty superclasses i) params

getTypeTopFun :: EnvReader m => TopFunName n -> m n (PiType SimpIR n)
getTypeTopFun f = lookupTopFun f >>= \case
Expand Down
22 changes: 11 additions & 11 deletions src/lib/Simplify.hs
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,7 @@ simplifyApp hint resultTy f xs = case f of
_ -> slow =<< simplifyAtomAndInline f
where
fast :: LamExpr CoreIR i' -> SimplifyM i' o (CAtom o)
fast (LamExpr bs body) = extendSubst (bs@@>(SubstVal<$>xs)) $ simplifyBlock body
fast lam = withInstantiated lam xs \body -> simplifyBlock body

slow :: CAtom o -> SimplifyM i o (CAtom o)
slow = \case
Expand All @@ -462,10 +462,10 @@ simplifyApp hint resultTy f xs = case f of
extendSubst (b@>SubstVal x) do
xs' <- mapM sinkM xs
simplifyApp hint (sink resultTy) body xs'
SimpInCore (LiftSimpFun _ (LamExpr bs body)) -> do
SimpInCore (LiftSimpFun _ lam) -> do
xs' <- mapM toDataAtomIgnoreRecon xs
body' <- applySubst (bs@@>map SubstVal xs') body
liftSimpAtom resultTy =<< emitBlock body'
result <- instantiate lam xs' >>= emitBlock
liftSimpAtom resultTy result
Var v -> do
lookupAtomName (atomVarName v) >>= \case
NoinlineFun _ _ -> simplifyTopFunApp v xs
Expand Down Expand Up @@ -549,10 +549,10 @@ simplifyTabApp f [] = return f
simplifyTabApp f@(SimpInCore sic) xs = case sic of
TabLam _ _ -> do
case fromNaryTabLam (length xs) f of
Just (bsCount, Abs bs block) -> do
Just (bsCount, ab) -> do
let (xsPref, xsRest) = splitAt bsCount xs
xsPref' <- mapM toDataAtomIgnoreRecon xsPref
block' <- applySubst (bs@@>(SubstVal <$> xsPref')) block
block' <- instantiate ab xsPref'
atom <- emitDecls block'
simplifyTabApp atom xsRest
Nothing -> error "should never happen"
Expand Down Expand Up @@ -788,10 +788,10 @@ applyDictMethod resultTy d i methodArgs = do
cheapNormalize d >>= \case
DictCon _ (InstanceDict instanceName instanceArgs) -> dropSubst do
instanceArgs' <- mapM simplifyAtom instanceArgs
InstanceDef _ _ bsInstance _ body <- lookupInstanceDef instanceName
let InstanceBody _ methods = body
let method = methods !! i
extendSubst (bsInstance @@> (SubstVal <$> instanceArgs')) do
instanceDef <- lookupInstanceDef instanceName
withInstantiated instanceDef instanceArgs' \(PairE _ body) -> do
let InstanceBody _ methods = body
let method = methods !! i
simplifyApp noHint resultTy method methodArgs
DictCon _ (IxFin n) -> applyIxFinMethod (toEnum i) n methodArgs
d' -> error $ "Not a simplified dict: " ++ pprint d'
Expand Down Expand Up @@ -975,7 +975,7 @@ simplifyCustomLinearization (Abs runtimeBs staticArgs) actives rule = do
CustomLinearize nImplicit nExplicit zeros fCustom <- return rule
linearized <- withSimplifiedBinders runtimeBs \runtimeBs' runtimeArgs -> do
Abs runtimeBs' <$> buildScoped do
ListE staticArgs' <- applySubst (runtimeBs @@> (SubstVal . sink <$> runtimeArgs)) staticArgs
ListE staticArgs' <- instantiate (sink $ Abs runtimeBs staticArgs) (sink <$> runtimeArgs)
fCustom' <- sinkM fCustom
resultTy <- typeOfApp (getType fCustom') staticArgs'
pairResult <- dropSubst $ simplifyApp noHint resultTy fCustom' staticArgs'
Expand Down
20 changes: 20 additions & 0 deletions src/lib/Types/Core.hs
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,9 @@ class ToBindersAbs (e::E) (body::E) (r::IR) | e -> body, e -> r where
instance ToBindersAbs CorePiType (EffTy CoreIR) CoreIR where
toAbs (CorePiType _ _ bs effTy) = Abs bs effTy

instance ToBindersAbs CoreLamExpr (Block CoreIR) CoreIR where
toAbs (CoreLamExpr _ lam) = toAbs lam

instance ToBindersAbs (Abs (Nest (Binder r)) body) body r where
toAbs = id

Expand All @@ -254,6 +257,18 @@ instance ToBindersAbs (TabPiType r) (Type r) r where
instance ToBindersAbs (DepPairType r) (Type r) r where
toAbs (DepPairType _ b rhsTy) = Abs (UnaryNest b) rhsTy

instance ToBindersAbs InstanceDef (ListE CAtom `PairE` InstanceBody) CoreIR where
toAbs (InstanceDef _ _ bs params body) = Abs bs (ListE params `PairE` body)

instance ToBindersAbs TyConDef DataConDefs CoreIR where
toAbs (TyConDef _ _ bs body) = Abs bs body

instance ToBindersAbs ClassDef (Abs (Nest CBinder) (ListE CorePiType)) CoreIR where
toAbs (ClassDef _ _ _ _ bs scBs tys) = Abs bs (Abs scBs (ListE tys))

instance ToBindersAbs (TopLam r) (Block r) r where
toAbs (TopLam _ _ lam) = toAbs lam

-- === GenericOp class ===

class IsPrimOp (e::IR->E) where
Expand Down Expand Up @@ -1198,6 +1213,9 @@ instance AlphaHashableE TyConDef
instance HasNameHint (TyConDef n) where
getNameHint (TyConDef v _ _ _) = getNameHint v

instance HasSourceName (TyConDef n) where
getSourceName (TyConDef v _ _ _) = v

instance GenericE DataConDef where
type RepE DataConDef = (LiftE (SourceName, [[Projection]]))
`PairE` EmptyAbs (Nest CBinder) `PairE` Type CoreIR
Expand Down Expand Up @@ -1850,6 +1868,8 @@ instance AlphaHashableE ClassDef
instance RenameE ClassDef
deriving instance Show (ClassDef n)
deriving via WrapE ClassDef n instance Generic (ClassDef n)
instance HasSourceName (ClassDef n) where
getSourceName = \case ClassDef name _ _ _ _ _ _ -> name

instance GenericE InstanceDef where
type RepE InstanceDef =
Expand Down

0 comments on commit 336bb04

Please sign in to comment.