Skip to content

Commit

Permalink
Add vectorization under a loop, and another example even closer to th…
Browse files Browse the repository at this point in the history
…e matmul kernel.

In the process, discovered and fixed two typing bugs that happened
to cancel out on previous test cases:
- Vectorizing a LamExpr may change the types of the arguments (if they
  are now vectors)
- Vector-indexing returns an object of different type from the element
  type that was being indexed (namely, the vector of those), and
  vectorIndexRepVal in Imp needs to accommodate that.
  • Loading branch information
axch committed Jun 22, 2023
1 parent 5fbe15a commit e446273
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 15 deletions.
27 changes: 17 additions & 10 deletions src/lib/Imp.hs
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,7 @@ toImpVectorOp = \case
-- VectorIdx requires that tbl' have a scalar element type, which is
-- ultimately enforced by `Lower.getVectorType` barfing on non-scalars.
tbl <- atomToRepVal tbl'
repValAtom =<< vectorIndexRepVal tbl i (toIVectorType vty)
repValAtom =<< vectorIndexRepVal tbl i vty
VectorSubref ref i vty -> do
refDest <- atomToDest ref
refi <- destToAtom <$> indexDest refDest i
Expand Down Expand Up @@ -1024,9 +1024,10 @@ naryIndexRepVal x (ix:ixs) = do

-- TODO: de-dup with indexDest?
indexRepValParam :: Emits n
=> RepVal SimpIR n -> SAtom n -> (IExpr n -> SubstImpM i n (IExpr n))
-> SubstImpM i n (RepVal SimpIR n)
indexRepValParam (RepVal tabTy@(TabPi (TabPiType d (b:>t) eltTy)) vals) i func = do
=> SRepVal n -> SAtom n -> (SType n -> SType n)
-> (IExpr n -> SubstImpM i n (IExpr n))
-> SubstImpM i n (SRepVal n)
indexRepValParam (RepVal tabTy@(TabPi (TabPiType d (b:>t) eltTy)) vals) i tyFunc func = do
eltTy' <- applySubst (b@>SubstVal i) eltTy
ord <- ordinalImp (IxType t d) i
leafTys <- typeToTree tabTy
Expand All @@ -1039,20 +1040,26 @@ indexRepValParam (RepVal tabTy@(TabPi (TabPiType d (b:>t) eltTy)) vals) i func =
case ixStruct of
EmptyAbs (Nest _ Empty) -> func ptr' >>= load
_ -> func ptr'
return $ RepVal eltTy' vals'
indexRepValParam _ _ _ = error "expected table type"
-- `func` may have changed the types of the `vals'`. The caller must also
-- supply `tyFunc` to reflect that change in the SType.
return $ RepVal (tyFunc eltTy') vals'
indexRepValParam _ _ _ _ = error "expected table type"
{-# INLINE indexRepValParam #-}

indexRepVal :: Emits n
=> RepVal SimpIR n -> SAtom n -> SubstImpM i n (RepVal SimpIR n)
indexRepVal rep i = indexRepValParam rep i return
indexRepVal rep i = indexRepValParam rep i id return
{-# INLINE indexRepVal #-}

vectorIndexRepVal :: Emits n
=> RepVal SimpIR n -> SAtom n -> IVectorType
=> RepVal SimpIR n -> SAtom n -> SType n
-> SubstImpM i n (RepVal SimpIR n)
vectorIndexRepVal rep i vty = indexRepValParam rep i action where
action ptr = castPtrToVectorType ptr vty
vectorIndexRepVal rep i vty =
-- Passing `const vty` here depends on knowing that `vectorIndexRepVal` is
-- only called on references of scalar base type, so that the given `vty` is,
-- in fact, the type of the result of the indexing operation.
indexRepValParam rep i (const vty) action where
action ptr = castPtrToVectorType ptr (toIVectorType vty)
{-# INLINE vectorIndexRepVal #-}

projectDest :: Int -> Dest n -> Dest n
Expand Down
42 changes: 37 additions & 5 deletions src/lib/Vectorize.hs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import QueryType
import Types.Core
import Types.OpNames qualified as P
import Types.Primitives
import Util (allM)
import Util (allM, zipWithZ)

-- === Vectorization ===

Expand Down Expand Up @@ -152,6 +152,16 @@ vectorizeLoopsDecls nest cont =
extendSubst (b @> atomVarName v) $
vectorizeLoopsDecls rest cont

-- Rebuilds a lambda one binder at a time.  For each binder, its type is
-- renamed into the output scope, a fresh binder carrying the same name hint is
-- introduced, and the old binder is mapped to the new one while the remaining
-- binders (and eventually the body) are processed.  Once no binders remain,
-- the body is rebuilt with `buildBlock (vectorizeLoopsBlock body)`.
vectorizeLoopsLamExpr :: LamExpr SimpIR i -> TopVectorizeM i o (LamExpr SimpIR o)
vectorizeLoopsLamExpr (LamExpr Empty body) =
  -- No binders left: only the body needs processing.
  LamExpr Empty <$> buildBlock (vectorizeLoopsBlock body)
vectorizeLoopsLamExpr (LamExpr (Nest (binder:>binderTy) others) body) = do
  freshTy <- renameM binderTy
  withFreshBinder (getNameHint binder) freshTy \freshBinder ->
    extendRenamer (binder @> binderName freshBinder) do
      -- Recurse on the tail, then put the fresh binder back in front.
      LamExpr others' body' <- vectorizeLoopsLamExpr (LamExpr others body)
      return (LamExpr (Nest freshBinder others') body')

vectorizeLoopsExpr :: (Emits o) => SExpr i -> TopVectorizeM i o (SExpr o)
vectorizeLoopsExpr expr = do
vectorByteWidth <- askVectorByteWidth
Expand All @@ -175,7 +185,8 @@ vectorizeLoopsExpr expr = do
ctx = mempty { messageCtx = [msg] }
errs' = prependCtxToErrs ctx errs
modify (<> LiftE errs')
renameM expr
recurSeq expr
PrimOp (DAMOp (Seq _ _ _ _ _)) -> recurSeq expr
PrimOp (Hof (TypedHof _ (RunReader item (BinaryLamExpr hb' refb' body)))) -> do
item' <- renameM item
itemTy <- return $ getType item'
Expand All @@ -197,6 +208,15 @@ vectorizeLoopsExpr expr = do
vectorizeLoopsBlock body
PrimOp . Hof <$> mkTypedHof (RunWriter (Just dest') monoid' lam)
_ -> renameM expr
where
-- Rebuilds a `Seq` expression: the effects, index type, and destination are
-- each passed through `renameM`, the body lambda goes through
-- `vectorizeLoopsLamExpr`, and the direction is reused unchanged.
recurSeq :: (Emits o) => SExpr i -> TopVectorizeM i o (SExpr o)
recurSeq (PrimOp (DAMOp (Seq effs dir ixty dest body))) = do
effs' <- renameM effs
ixty' <- renameM ixty
dest' <- renameM dest
body' <- vectorizeLoopsLamExpr body
return $ PrimOp $ DAMOp $ Seq effs' dir ixty' dest' body'
-- Any expression that is not a `Seq` is treated as a programming error.
recurSeq _ = error "Impossible"

-- Really we should check this by seeing whether there is an instance for a
-- `Commutative` class, or something like that, but for now just pattern-match
Expand Down Expand Up @@ -331,7 +351,8 @@ vectorizeLamExpr (LamExpr bs body) argStabilities = case (bs, argStabilities) of
(VRename v) -> Var <$> toAtomVar v)
(Nest (b:>ty) rest, (stab:stabs)) -> do
ty' <- vectorizeType ty
withFreshBinder (getNameHint b) ty' \b' -> do
ty'' <- promoteTypeByStability ty' stab
withFreshBinder (getNameHint b) ty'' \b' -> do
var <- toAtomVar $ binderName b'
extendSubst (b @> VVal stab (Var var)) do
LamExpr rest' body' <- vectorizeLamExpr (LamExpr rest body) stabs
Expand Down Expand Up @@ -396,14 +417,16 @@ vectorizeRefOp ref' op =
VVal xStab x <- vectorizeAtom x'
basemonoid <- case refStab of
Uniform -> case xStab of
Uniform -> vectorizeBaseMonoid basemonoid' Uniform Uniform
Uniform -> do
vectorizeBaseMonoid basemonoid' Uniform Uniform
-- This case represents accumulating something loop-varying into a
-- loop-invariant accumulator, as e.g. sum. We can implement that for
-- commutative monoids, but we would want to have started with private
-- accumulators (one per lane), and then reduce them with an
-- appropriate sequence of vector reduction intrinsics at the end.
_ -> throwVectErr $ "Vectorizing non-sliced accumulation not implemented"
Contiguous -> vectorizeBaseMonoid basemonoid' Varying xStab
Contiguous -> do
vectorizeBaseMonoid basemonoid' Varying xStab
s -> throwVectErr $ "Cannot vectorize reference with loop-varying stability " ++ show s
VVal Uniform <$> emitOp (RefOp ref $ MExtend basemonoid x)
IndexRef _ i' -> do
Expand Down Expand Up @@ -543,6 +566,15 @@ ensureVarying (VRename v) = do
x <- Var <$> toAtomVar v
ensureVarying (VVal Uniform x)

-- Adjusts a type according to the stability of the value it describes:
--   * `Uniform` and `Contiguous` leave the type as it is;
--   * `Varying` replaces it with the result of `getVectorType`;
--   * `ProdStability` requires a `ProdTy` and recurses element-wise, pairing
--     element types with stabilities via `zipWithZ`; any other type is
--     reported by throwing a `ZipErr`.
promoteTypeByStability :: SType o -> Stability -> VectorizeM i o (SType o)
promoteTypeByStability ty stability = case stability of
  Varying -> getVectorType ty
  ProdStability stabs
    | ProdTy elts <- ty -> ProdTy <$> zipWithZ promoteTypeByStability elts stabs
    | otherwise -> throw ZipErr "Type and stability"
  Uniform -> return ty
  Contiguous -> return ty

-- === computing byte widths ===

newtype CalcWidthM i o a = CalcWidthM {
Expand Down
21 changes: 21 additions & 0 deletions tests/opt-tests.dx
Original file line number Diff line number Diff line change
Expand Up @@ -190,3 +190,24 @@ _ = yield_accum (AddMonoid Int32) \ref.
-- CHECK: [[xi:v#[0-9]+]]:<16xInt32> =
-- CHECK-NEXT: vslice
-- CHECK: extend [[refi]] [[xi]]

"vectorizing under an outer loop, like matmul"
-- CHECK-LABEL: vectorizing under an outer loop, like matmul

mat1 = for i:(Fin 32). for j:(Fin 32).
(n_to_i32 (ordinal i)) * (n_to_i32 (ordinal j)) + 1

mat2 = for i:(Fin 32). for j:(Fin 32).
(n_to_i32 (ordinal i)) * (n_to_i32 (ordinal j)) + 7

%passes vect
_ = yield_accum (AddMonoid Int32) \result.
for k:(Fin 32).
for j:(Fin 32).
result!(3@(Fin 32))!j += mat1[3@_][k] * mat2[k][j]
-- CHECK: seq (RawFin 0x2)
-- CHECK: [[refj:v#[0-9]+]]:(Ref {{v#[0-9]+}} <16xInt32>) = vrefslice
-- CHECK: [[mat2j:v#[0-9]+]]:<16xInt32> = vslice
-- CHECK: [[mat1:v#[0-9]+]]:<16xInt32> = vbroadcast
-- CHECK: [[prodj:v#[0-9]+]]:<16xInt32> = %imul [[mat1]] [[mat2j]]
-- CHECK: extend [[refj]] [[prodj]]

0 comments on commit e446273

Please sign in to comment.