Skip to content

Commit

Permalink
Unbundle binders from their role/expl attributes.
Browse files Browse the repository at this point in the history
Fancy B-kinded things are a pain and they're about to get worse when we add
decls to binders. An earlier attempt at adding decls without doing this forced
me to create lots of complicated type classes to handle all the `WithExpl` and
`RolePiBinder` variants.
  • Loading branch information
dougalm committed Jun 27, 2023
1 parent c7fef43 commit 75eacbf
Show file tree
Hide file tree
Showing 21 changed files with 543 additions and 575 deletions.
49 changes: 29 additions & 20 deletions src/lib/AbstractSyntax.hs
Original file line number Diff line number Diff line change
Expand Up @@ -110,23 +110,23 @@ topDecl = dropSrc topDecl' where
topDecl' (CSDecl ann d) = ULocalDecl <$> decl ann (WithSrc emptySrcPosCtx d)
topDecl' (CData name tyConParams givens constructors) = do
tyConParams' <- aExplicitParams tyConParams
givens' <- toNest <$> fromMaybeM givens [] aGivens
givens' <- aOptGivens givens
constructors' <- forM constructors \(v, ps) -> do
ps' <- toNest <$> mapM tyOptBinder ps
return (v, ps')
return $ UDataDefDecl
(UDataDef name (givens' >>> tyConParams') $
(UDataDef name (catUOptAnnExplBinders givens' tyConParams') $
map (\(name', cons) -> (name', UDataDefTrail cons)) constructors')
(fromString name)
(toNest $ map (fromString . fst) constructors')
topDecl' (CStruct name params givens fields defs) = do
params' <- aExplicitParams params
givens' <- toNest <$> fromMaybeM givens [] aGivens
givens' <- aOptGivens givens
fields' <- forM fields \(v, ty) -> (v,) <$> expr ty
methods <- forM defs \(ann, d) -> do
(methodName, lam) <- aDef d
return (ann, methodName, Abs (UBindSource emptySrcPosCtx "self") lam)
return $ UStructDecl (fromString name) (UStructDef name (givens' >>> params') fields' methods)
return $ UStructDecl (fromString name) (UStructDef name (catUOptAnnExplBinders givens' params') fields' methods)
topDecl' (CInterface name params methods) = do
params' <- aExplicitParams params
(methodNames, methodTys) <- unzip <$> forM methods \(methodName, ty) -> do
Expand All @@ -153,7 +153,7 @@ aInstanceDef :: CInstanceDef -> SyntaxM (UTopDecl VoidS VoidS)
aInstanceDef (CInstanceDef clName args givens methods instNameAndParams) = do
let clName' = fromString clName
args' <- mapM expr args
givens' <- toNest <$> fromMaybeM givens [] aGivens
givens' <- aOptGivens givens
methods' <- catMaybes <$> mapM aMethod methods
case instNameAndParams of
Nothing -> return $ UInstance clName' givens' args' methods' NothingB ImplicitApp
Expand All @@ -162,7 +162,7 @@ aInstanceDef (CInstanceDef clName args givens methods instNameAndParams) = do
case optParams of
Just params -> do
params' <- aExplicitParams params
return $ UInstance clName' (givens' >>> params') args' methods' instName' ExplicitApp
return $ UInstance clName' (catUOptAnnExplBinders givens' params') args' methods' instName' ExplicitApp
Nothing -> return $ UInstance clName' givens' args' methods' instName' ImplicitApp

aDef :: CDef -> SyntaxM (SourceName, ULamExpr VoidS)
Expand All @@ -173,19 +173,27 @@ aDef (CDef name params optRhs optGivens body) = do
effs <- fromMaybeM optEffs UPure aEffects
resultTy' <- expr resultTy
return (expl, Just effs, Just resultTy')
implicitParams <- toNest <$> fromMaybeM optGivens [] aGivens
let allParams = implicitParams >>> explicitParams
implicitParams <- aOptGivens optGivens
let allParams = catUOptAnnExplBinders implicitParams explicitParams
body' <- block body
return (name, ULamExpr allParams expl effs resultTy body')

catUOptAnnExplBinders :: UOptAnnExplBinders n l -> UOptAnnExplBinders l l' -> UOptAnnExplBinders n l'
catUOptAnnExplBinders (expls, bs) (expls', bs') = (expls <> expls', bs >>> bs')

stripParens :: Group -> Group
stripParens (WithSrc _ (CParens [g])) = stripParens g
stripParens g = g

aExplicitParams :: ExplicitParams -> SyntaxM (Nest (WithExpl UOptAnnBinder) VoidS VoidS)
aExplicitParams :: ExplicitParams -> SyntaxM ([Explicitness], Nest UOptAnnBinder VoidS VoidS)
aExplicitParams gs = generalBinders DataParam Explicit gs

aGivens :: GivenClause -> SyntaxM [WithExpl UOptAnnBinder VoidS VoidS]
aOptGivens :: Maybe GivenClause -> SyntaxM (UOptAnnExplBinders VoidS VoidS)
aOptGivens optGivens = do
(expls, implicitParams) <- unzip <$> fromMaybeM optGivens [] aGivens
return (expls, toNest implicitParams)

aGivens :: GivenClause -> SyntaxM [(Explicitness, UOptAnnBinder VoidS VoidS)]
aGivens (implicits, optConstraints) = do
implicits' <- mapM (generalBinder DataParam (Inferred Nothing Unify)) implicits
constraints <- fromMaybeM optConstraints [] \gs -> do
Expand All @@ -194,23 +202,24 @@ aGivens (implicits, optConstraints) = do

generalBinders
:: ParamStyle -> Explicitness -> [Group]
-> SyntaxM (Nest (WithExpl UOptAnnBinder) VoidS VoidS)
generalBinders paramStyle expl params = toNest . concat <$>
forM params \case
-> SyntaxM ([Explicitness], Nest UOptAnnBinder VoidS VoidS)
generalBinders paramStyle expl params = do
(expls, bs) <- unzip . concat <$> forM params \case
WithSrc _ (CGivens gs) -> aGivens gs
p -> (:[]) <$> generalBinder paramStyle expl p
return (expls, toNest bs)

generalBinder :: ParamStyle -> Explicitness -> Group
-> SyntaxM (WithExpl UOptAnnBinder VoidS VoidS)
-> SyntaxM (Explicitness, UOptAnnBinder VoidS VoidS)
generalBinder paramStyle expl g = case expl of
Inferred _ (Synth _) -> WithExpl expl <$> tyOptBinder g
Inferred _ (Synth _) -> (expl,) <$> tyOptBinder g
Inferred _ Unify -> do
b <- binderOptTy g
expl' <- return case b of
UAnnBinder (UBindSource _ s) _ _ -> Inferred (Just s) Unify
_ -> expl
return $ WithExpl expl' b
Explicit -> WithExpl expl <$> case paramStyle of
return (expl', b)
Explicit -> (expl,) <$> case paramStyle of
TypeParam -> tyOptBinder g
DataParam -> binderOptTy g

Expand Down Expand Up @@ -347,7 +356,7 @@ aMethod (WithSrc src d) = Just . WithSrcE src <$> addSrcContext src case d of
(name, lam) <- aDef def
return $ UMethodDef (fromString name) lam
CLet (WithSrc _ (CIdentifier name)) rhs -> do
rhs' <- ULamExpr Empty ImplicitApp Nothing Nothing <$> block rhs
rhs' <- ULamExpr ([], Empty) ImplicitApp Nothing Nothing <$> block rhs
return $ UMethodDef (fromString name) rhs'
_ -> throw SyntaxErr "Unexpected method definition. Expected `def` or `x = ...`."

Expand All @@ -368,10 +377,10 @@ blockDecls [WithSrc src d] = addSrcContext src case d of
CExpr g -> (Empty,) <$> expr g
_ -> throw SyntaxErr "Block must end in expression"
blockDecls (WithSrc pos (CBind b rhs):ds) = do
WithExpl _ b' <- generalBinder DataParam Explicit b
(_, b') <- generalBinder DataParam Explicit b
rhs' <- asExpr <$> block rhs
body <- block $ IndentedBlock ds
let lam = ULam $ ULamExpr (UnaryNest (WithExpl Explicit b')) ExplicitApp Nothing Nothing body
let lam = ULam $ ULamExpr ([Explicit], UnaryNest b') ExplicitApp Nothing Nothing body
return (Empty, WithSrcE pos $ extendAppRight rhs' (ns lam))
blockDecls (d:ds) = do
d' <- decl PlainLet d
Expand Down
15 changes: 3 additions & 12 deletions src/lib/Builder.hs
Original file line number Diff line number Diff line change
Expand Up @@ -602,22 +602,13 @@ buildBlock
-> m n (Block r n)
buildBlock = buildScoped

coreLamExpr :: EnvReader m => AppExplicitness
-> Abs (Nest (WithExpl CBinder)) (PairE (EffectRow CoreIR) CBlock) n
-> m n (CoreLamExpr n)
coreLamExpr appExpl ab = liftEnvReaderM do
refreshAbs ab \bs' (PairE effs' body') -> do
EffTy _ resultTy <- blockEffTy body'
let bs'' = fmapNest withoutExpl bs'
return $ CoreLamExpr (CorePiType appExpl bs' (EffTy effs' resultTy)) (LamExpr bs'' body')

buildCoreLam
:: ScopableBuilder CoreIR m
=> CorePiType n
-> (forall l. (Emits l, DExt n l) => [CAtomVar l] -> m l (CAtom l))
-> m n (CoreLamExpr n)
buildCoreLam piTy@(CorePiType _ bs _) cont = do
lam <- buildLamExpr (EmptyAbs $ fmapNest withoutExpl bs) cont
buildCoreLam piTy@(CorePiType _ _ bs _) cont = do
lam <- buildLamExpr (EmptyAbs bs) cont
return $ CoreLamExpr piTy lam

buildAbs
Expand Down Expand Up @@ -1083,7 +1074,7 @@ projectStructRef i x = do

getStructProjections :: EnvReader m => Int -> CType n -> m n [Projection]
getStructProjections i (NewtypeTyCon (UserADTType _ tyConName _)) = do
TyConDef _ _ ~(StructFields fields) <- lookupTyCon tyConName
TyConDef _ _ _ ~(StructFields fields) <- lookupTyCon tyConName
return case fields of
[_] | i == 0 -> [UnwrapNewtype]
| otherwise -> error "bad index"
Expand Down
14 changes: 6 additions & 8 deletions src/lib/CheapReduction.hs
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ cheapReduceDictExpr resultTy d = case d of
cheapReduceE child >>= \case
DictCon _ (InstanceDict instanceName args) -> dropSubst do
args' <- mapM cheapReduceE args
InstanceDef _ bs _ body <- lookupInstanceDef instanceName
InstanceDef _ _ bs _ body <- lookupInstanceDef instanceName
let InstanceBody superclasses _ = body
applySubst (bs@@>(SubstVal <$> args')) (superclasses !! superclassIx)
child' -> return $ DictCon resultTy $ SuperclassProj child' superclassIx
Expand Down Expand Up @@ -285,7 +285,7 @@ instance IRRep r => CheaplyReducibleE r (Expr r) (Atom r) where
cheapReduceE dict >>= \case
DictCon _ (InstanceDict instanceName args) -> dropSubst do
args' <- mapM cheapReduceE args
InstanceDef _ bs _ (InstanceBody _ methods) <- lookupInstanceDef instanceName
InstanceDef _ _ bs _ (InstanceBody _ methods) <- lookupInstanceDef instanceName
let method = methods !! i
extendSubst (bs@@>(SubstVal <$> args')) do
method' <- cheapReduceE method
Expand Down Expand Up @@ -466,7 +466,7 @@ wrapNewtypesData [] x = x
wrapNewtypesData (c:cs) x = NewtypeCon c $ wrapNewtypesData cs x

instantiateTyConDef :: EnvReader m => TyConDef n -> TyConParams n -> m n (DataConDefs n)
instantiateTyConDef (TyConDef _ bs conDefs) (TyConParams _ xs) = do
instantiateTyConDef (TyConDef _ _ bs conDefs) (TyConParams _ xs) = do
applySubst (bs @@> (SubstVal <$> xs)) conDefs
{-# INLINE instantiateTyConDef #-}

Expand All @@ -487,7 +487,7 @@ dataDefRep (StructFields fields) = case map snd fields of

makeStructRepVal :: (Fallible1 m, EnvReader m) => TyConName n -> [CAtom n] -> m n (CAtom n)
makeStructRepVal tyConName args = do
TyConDef _ _ (StructFields fields) <- lookupTyCon tyConName
TyConDef _ _ _ (StructFields fields) <- lookupTyCon tyConName
case fields of
[_] -> case args of
[arg] -> return arg
Expand Down Expand Up @@ -725,11 +725,9 @@ instance VisitGeneric CoreLamExpr CoreIR where
visitGeneric (CoreLamExpr t lam) = CoreLamExpr <$> visitGeneric t <*> visitGeneric lam

instance VisitGeneric CorePiType CoreIR where
visitGeneric (CorePiType app bsExpl effty) = do
let (expls, bs) = unzipExpls bsExpl
visitGeneric (CorePiType app expl bs effty) = do
PiType bs' effty' <- visitGeneric $ PiType bs effty
let bsExpl' = zipExpls expls bs'
return $ CorePiType app bsExpl' effty'
return $ CorePiType app expl bs' effty'

instance IRRep r => VisitGeneric (TabPiType r) r where
visitGeneric (TabPiType d b eltTy) = do
Expand Down
31 changes: 14 additions & 17 deletions src/lib/CheckType.hs
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ instance IRRep r => HasType r (Type r) where
TC tyCon -> typeCheckPrimTC tyCon
DepPairTy ty -> getTypeE ty
DictTy (DictType _ className params) -> do
ClassDef _ _ _ paramBs _ _ <- renameM className >>= lookupClassDef
ClassDef _ _ _ _ paramBs _ _ <- renameM className >>= lookupClassDef
params' <- mapM renameM params
checkArgTys paramBs params'
return TyKind
Expand Down Expand Up @@ -293,9 +293,6 @@ instance (ToBinding ann c, Color c, CheckableE r ann) => CheckableB r (BinderP c
extendRenamer (b@>binderName b') $
cont b'

instance (BindsNames b, CheckableB r b) => CheckableB r (WithExpl b) where
checkB (WithExpl expl b) cont = checkB b \b' -> cont (WithExpl expl b')

typeCheckExpr :: (Typer m r, IRRep r) => EffectRow r o -> Expr r i -> m i o (Type r o)
typeCheckExpr effs expr = addContext ("Checking expr:\n" ++ pprint expr) case expr of
App (EffTy _ reqTy) f xs -> do
Expand All @@ -318,7 +315,7 @@ typeCheckExpr effs expr = addContext ("Checking expr:\n" ++ pprint expr) case ex
return resultTy'
ApplyMethod (EffTy _ reqTy) dict i args -> do
DictTy (DictType _ className params) <- getTypeE dict
ClassDef _ _ _ paramBs classBs methodTys <- lookupClassDef className
ClassDef _ _ _ _ paramBs classBs methodTys <- lookupClassDef className
let methodTy = methodTys !! i
superclassDicts <- getSuperclassDicts =<< renameM dict
let subst = ( paramBs @@> map SubstVal params
Expand All @@ -342,8 +339,8 @@ dictExprType :: Typer m CoreIR => DictExpr i -> m i o (CType o)
dictExprType e = case e of
InstanceDict instanceName args -> do
instanceName' <- renameM instanceName
InstanceDef className bs params _ <- lookupInstanceDef instanceName'
ClassDef sourceName _ _ _ _ _ <- lookupClassDef className
InstanceDef className _ bs params _ <- lookupInstanceDef instanceName'
ClassDef sourceName _ _ _ _ _ _ <- lookupClassDef className
args' <- mapM renameM args
checkArgTys bs args'
ListE params' <- applySubst (bs@@>(SubstVal<$>args')) (ListE params)
Expand All @@ -353,7 +350,7 @@ dictExprType e = case e of
checkApp Pure givenTy (toList args)
SuperclassProj d i -> do
DictTy (DictType _ className params) <- getTypeE d
ClassDef _ _ _ bs superclasses _ <- lookupClassDef className
ClassDef _ _ _ _ bs superclasses _ <- lookupClassDef className
let scType = getSuperclassType REmpty superclasses i
checkedApplyNaryAbs (Abs bs scType) params
IxFin n -> do
Expand All @@ -370,7 +367,7 @@ instance IRRep r => HasType r (DepPairType r) where
return TyKind

instance HasType CoreIR CorePiType where
getTypeE (CorePiType _ bs (EffTy eff resultTy)) = do
getTypeE (CorePiType _ _ bs (EffTy eff resultTy)) = do
checkB bs \_ -> do
void $ checkE eff
resultTy|:TyKind
Expand Down Expand Up @@ -407,14 +404,14 @@ checkAgainstGiven givenTy computedTy = do
return givenTy'

checkCoreLam :: Typer m CoreIR => CorePiType o -> LamExpr CoreIR i -> m i o ()
checkCoreLam (CorePiType _ Empty (EffTy effs resultTy)) (LamExpr Empty body) = do
checkCoreLam (CorePiType _ _ Empty (EffTy effs resultTy)) (LamExpr Empty body) = do
resultTy' <- checkBlockWithEffs effs body
checkTypesEq resultTy resultTy'
checkCoreLam (CorePiType expl (Nest piB piBs) effTy) (LamExpr (Nest lamB lamBs) body) = do
checkCoreLam (CorePiType expl (_:expls) (Nest piB piBs) effTy) (LamExpr (Nest lamB lamBs) body) = do
argTy <- renameM $ binderType lamB
checkTypesEq (binderType piB) argTy
withFreshBinder (getNameHint lamB) argTy \b -> do
piTy <- applyRename (piB@>binderName b) (CorePiType expl piBs effTy)
piTy <- applyRename (piB@>binderName b) (CorePiType expl expls piBs effTy)
extendRenamer (lamB@>binderName b) do
checkCoreLam piTy (LamExpr lamBs body)
checkCoreLam _ _ = throw TypeErr "zip error"
Expand Down Expand Up @@ -446,7 +443,7 @@ typeCheckNewtypeCon con x = case con of
FinCon n -> n|:NatTy >> x|:NatTy >> renameM (Fin n)
UserADTData _ d params -> do
d' <- renameM d
def@(TyConDef sn _ _) <- lookupTyCon d'
def@(TyConDef sn _ _ _) <- lookupTyCon d'
params' <- renameM params
void $ checkedInstantiateTyConDef def params'
return $ UserADTType sn d' params'
Expand Down Expand Up @@ -773,7 +770,7 @@ checkAlt resultTyReq bTyReq effs (Abs b body) = do

checkApp :: (Typer m r, IRRep r) => EffectRow r o -> Type r o -> [Atom r i] -> m i o (Type r o)
checkApp allowedEffs fTy xs = case fTy of
Pi (CorePiType _ bs effTy) -> do
Pi (CorePiType _ _ bs effTy) -> do
xs' <- mapM renameM xs
checkArgTys bs xs'
let subst = bs @@> fmap SubstVal xs'
Expand Down Expand Up @@ -929,7 +926,7 @@ checkUnOp op x = do
checkedInstantiateTyConDef
:: (EnvReader m, Fallible1 m)
=> TyConDef n -> TyConParams n -> m n (DataConDefs n)
checkedInstantiateTyConDef (TyConDef _ bs cons) (TyConParams _ xs) = do
checkedInstantiateTyConDef (TyConDef _ _ bs cons) (TyConParams _ xs) = do
checkedApplyNaryAbs (Abs bs cons) xs

checkedApplyNaryAbs
Expand Down Expand Up @@ -995,7 +992,7 @@ asFFIFunType ty = return do
return (impTy, piTy)

checkFFIFunTypeM :: Fallible m => CorePiType n -> m IFunType
checkFFIFunTypeM (CorePiType appExpl (Nest b bs) effTy) = do
checkFFIFunTypeM (CorePiType appExpl (_:expls) (Nest b bs) effTy) = do
argTy <- checkScalar $ binderType b
case bs of
Empty -> do
Expand All @@ -1006,7 +1003,7 @@ checkFFIFunTypeM (CorePiType appExpl (Nest b bs) effTy) = do
_ -> FFIMultiResultCC
return $ IFunType cc [argTy] resultTys
Nest b' rest -> do
let naryPiRest = CorePiType appExpl (Nest b' rest) effTy
let naryPiRest = CorePiType appExpl expls (Nest b' rest) effTy
IFunType cc argTys resultTys <- checkFFIFunTypeM naryPiRest
return $ IFunType cc (argTy:argTys) resultTys
checkFFIFunTypeM _ = error "expected at least one argument"
Expand Down
12 changes: 4 additions & 8 deletions src/lib/Core.hs
Original file line number Diff line number Diff line change
Expand Up @@ -218,17 +218,13 @@ instance BindsEnv EnvFrag where
toEnvFrag frag = frag
{-# INLINE toEnvFrag #-}

instance BindsEnv b => BindsEnv (WithExpl b) where
toEnvFrag (WithExpl _ b) = toEnvFrag b
{-# INLINE toEnvFrag #-}

instance BindsEnv RolePiBinder where
toEnvFrag (RolePiBinder _ b) = toEnvFrag b
{-# INLINE toEnvFrag #-}

instance BindsEnv (RecSubstFrag Binding) where
toEnvFrag frag = EnvFrag frag

instance BindsEnv b => BindsEnv (WithAttrB a b) where
toEnvFrag (WithAttrB _ b) = toEnvFrag b
{-# INLINE toEnvFrag #-}

instance (BindsEnv b1, BindsEnv b2)
=> (BindsEnv (PairB b1 b2)) where
toEnvFrag (PairB b1 b2) = do
Expand Down
10 changes: 5 additions & 5 deletions src/lib/Export.hs
Original file line number Diff line number Diff line change
Expand Up @@ -100,26 +100,26 @@ liftExportSigM cont = do

corePiToExportSig :: CallingConvention
-> CorePiType i -> ExportSigM CoreIR i o (ExportedSignature o)
corePiToExportSig cc (CorePiType _ tbs (EffTy effs resultTy)) = do
corePiToExportSig cc (CorePiType _ expls tbs (EffTy effs resultTy)) = do
case effs of
Pure -> return ()
_ -> throw TypeErr "Only pure functions can be exported"
goArgs cc Empty [] tbs resultTy
goArgs cc Empty [] (zipAttrs expls tbs) resultTy

simpPiToExportSig :: CallingConvention
-> PiType SimpIR i -> ExportSigM SimpIR i o (ExportedSignature o)
simpPiToExportSig cc (PiType bs (EffTy effs resultTy)) = do
case effs of
Pure -> return ()
_ -> throw TypeErr "Only pure functions can be exported"
bs' <- return $ fmapNest (\b -> WithExpl Explicit b) bs
bs' <- return $ fmapNest (\b -> WithAttrB Explicit b) bs
goArgs cc Empty [] bs' resultTy

goArgs :: (IRRep r)
=> CallingConvention
-> Nest ExportArg o o'
-> [CAtomName o']
-> Nest (WithExpl (Binder r)) i i'
-> Nest (WithAttrB Explicitness (Binder r)) i i'
-> Type r i'
-> ExportSigM r i o' (ExportedSignature o)
goArgs cc argSig argVs piBs piRes = case piBs of
Expand All @@ -128,7 +128,7 @@ goArgs cc argSig argVs piBs piRes = case piBs of
StandardCC -> (fromListE $ sink $ ListE argVs) ++ nestToList (sink . binderName) resSig
XLACC -> []
_ -> error $ "calling convention not supported: " ++ show cc
Nest (WithExpl expl (b:>ty)) bs -> do
Nest (WithAttrB expl (b:>ty)) bs -> do
ety <- toExportType ty
withFreshBinder (getNameHint b) ety \(v:>_) ->
extendSubst (b @> Rename (binderName v)) $ do
Expand Down
Loading

0 comments on commit 75eacbf

Please sign in to comment.