Skip to content

Commit

Permalink
refactor backend for type-safety and clarity (#274)
Browse files Browse the repository at this point in the history
* add type-safe layer round FFI
* separate XLA Idris API from high-level usage of this API
* simplify functionality for reading and writing `Literal`s
  • Loading branch information
joelberkeley authored May 26, 2022
1 parent 8276a38 commit c604608
Show file tree
Hide file tree
Showing 37 changed files with 1,288 additions and 585 deletions.
46 changes: 32 additions & 14 deletions spidr.ipkg
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,39 @@ modules =
BayesianOptimization.Acquisition,
BayesianOptimization.Morphisms,

Compiler.FFI,
Compiler.Computation,
Compiler.Graph,
Compiler.TensorFlow.Compiler.XLA.Client.Lib.Math,
Compiler.TensorFlow.Compiler.XLA.Client.Lib.Matrix,
Compiler.TensorFlow.Compiler.XLA.Client.ClientLibrary,
Compiler.TensorFlow.Compiler.XLA.Client.LocalClient,
Compiler.TensorFlow.Compiler.XLA.Client.XlaBuilder,
Compiler.TensorFlow.Compiler.XLA.Client.XlaComputation,
Compiler.TensorFlow.Compiler.XLA.Literal,
Compiler.TensorFlow.Compiler.XLA.Service.PlatformUtil,
Compiler.TensorFlow.Compiler.XLA.Shape,
Compiler.TensorFlow.Compiler.XLA.ShapeUtil,
Compiler.TensorFlow.Compiler.XLA.XlaData,
Compiler.TensorFlow.Core.CommonRuntime.GPU.GPUInit,
Compiler.XLA,
Compiler.LiteralRW,

Compiler.Xla.Prim.TensorFlow.Compiler.Xla.Client.Lib.Math,
Compiler.Xla.Prim.TensorFlow.Compiler.Xla.Client.Lib.Matrix,
Compiler.Xla.Prim.TensorFlow.Compiler.Xla.Client.ClientLibrary,
Compiler.Xla.Prim.TensorFlow.Compiler.Xla.Client.LocalClient,
Compiler.Xla.Prim.TensorFlow.Compiler.Xla.Client.XlaBuilder,
Compiler.Xla.Prim.TensorFlow.Compiler.Xla.Client.XlaComputation,
Compiler.Xla.Prim.TensorFlow.Compiler.Xla.Literal,
Compiler.Xla.Prim.TensorFlow.Compiler.Xla.Service.PlatformUtil,
Compiler.Xla.Prim.TensorFlow.Compiler.Xla.Shape,
Compiler.Xla.Prim.TensorFlow.Compiler.Xla.ShapeUtil,
Compiler.Xla.Prim.TensorFlow.Core.CommonRuntime.GPU.GPUInit,
Compiler.Xla.Prim.TensorFlow.Core.Platform.Status,
Compiler.Xla.Prim.Util,

Compiler.Xla.TensorFlow.Compiler.Xla.Client.Lib.Math,
Compiler.Xla.TensorFlow.Compiler.Xla.Client.Lib.Matrix,
Compiler.Xla.TensorFlow.Compiler.Xla.Client.ClientLibrary,
Compiler.Xla.TensorFlow.Compiler.Xla.Client.LocalClient,
Compiler.Xla.TensorFlow.Compiler.Xla.Client.XlaBuilder,
Compiler.Xla.TensorFlow.Compiler.Xla.Client.XlaComputation,
Compiler.Xla.TensorFlow.Compiler.Xla.Literal,
Compiler.Xla.TensorFlow.Compiler.Xla.Service.PlatformUtil,
Compiler.Xla.TensorFlow.Compiler.Xla.Shape,
Compiler.Xla.TensorFlow.Compiler.Xla.ShapeUtil,
Compiler.Xla.TensorFlow.Compiler.Xla.XlaData,
Compiler.Xla.TensorFlow.Core.CommonRuntime.GPU.GPUInit,
Compiler.Xla.TensorFlow.Core.Platform.Status,
Compiler.Xla.TensorFlow.StreamExecutor.Platform,
Compiler.Xla.Util,

Data,
Distribution,
Expand Down
95 changes: 95 additions & 0 deletions src/Compiler/Computation.idr
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
{--
Copyright 2022 Joel Berkeley
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
--}
module Compiler.Computation

import Control.Monad.State
import Data.SortedMap

import Data.Hashable

import Compiler.Graph
import Compiler.Xla.TensorFlow.Compiler.Xla.Client.XlaBuilder
import Compiler.Xla.TensorFlow.Compiler.Xla.Client.XlaComputation
import Compiler.Xla.TensorFlow.Compiler.Xla.Shape
import Compiler.Xla.TensorFlow.Compiler.Xla.ShapeUtil
import Compiler.Xla.TensorFlow.Compiler.Xla.XlaData
import Types

public export
data CachingBuilder : Type where
MkCachingBuilder : XlaBuilder -> SortedMap Bits64 XlaOp -> CachingBuilder

public export
Computation : Type -> Type
Computation = StateT CachingBuilder IO

export
cached : Graph -> Computation XlaOp -> Computation XlaOp
cached graph xs = assert_total $ let graphHash = hash graph in do
builder <- get
case cacheLookup builder graphHash of
Just op => pure op
Nothing => do
op <- xs
builder <- get
put (cacheInsert builder graphHash op)
pure op

where
cacheInsert : CachingBuilder -> Bits64 -> XlaOp -> CachingBuilder
cacheInsert (MkCachingBuilder builder cache) key xlaOp =
MkCachingBuilder builder (insert key xlaOp cache)

cacheLookup : CachingBuilder -> Bits64 -> Maybe XlaOp
cacheLookup (MkCachingBuilder _ cache) key = lookup key cache

export
build : HasIO io => String -> Computation XlaOp -> io XlaComputation
build computationName x = do
builder <- mkXlaBuilder computationName
MkCachingBuilder builder _ <- liftIO $ execStateT (MkCachingBuilder builder empty) x
build builder

export
buildWithSubBuilder :
String -> List (Computation XlaOp) -> Computation XlaOp -> Computation XlaComputation
buildWithSubBuilder computationName computationArguments computationResult = do
MkCachingBuilder builder _ <- get
subBuilder <- createSubBuilder builder computationName
let cachingSubBuilder = MkCachingBuilder subBuilder empty
allOps = sequence_ (computationArguments ++ [computationResult])
MkCachingBuilder subBuilder _ <- liftIO $ execStateT cachingSubBuilder allOps
build subBuilder

export
opToString : Computation XlaOp -> String
opToString x = unsafePerformIO $ do
builder <- mkXlaBuilder "toString"
(MkCachingBuilder builder _, xlaOp) <- runStateT (MkCachingBuilder builder empty) x
pure $ opToString builder xlaOp

export
parameter : Primitive dtype => Nat -> Types.Shape -> String -> (Graph, Computation XlaOp)
parameter position shape name =
let graph = Parameter {dtype} shape position

param : Computation XlaOp
param = do
MkCachingBuilder builder _ <- get
xlaShape <- mkShape {dtype} shape
cached graph $ parameter builder position xlaShape name

in (graph, param)
6 changes: 3 additions & 3 deletions src/Compiler/Graph.idr
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@ limitations under the License.
--}
module Compiler.Graph

import Primitive
import Data.Hashable
import Data.Stream

import Compiler.Xla.TensorFlow.Compiler.Xla.XlaData
import Types
import Util
import Util.Hashable

||| A `Graph` represents a computational graph used to compute a tensor value. It is defined by
||| the following proprty: For any two `Graph`s gx and gy that compute tensors x and y respectively,
Expand Down
76 changes: 76 additions & 0 deletions src/Compiler/LiteralRW.idr
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
{--
Copyright 2022 Joel Berkeley
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
--}
module Compiler.LiteralRW

import Compiler.Xla.TensorFlow.Compiler.Xla.XlaData
import Compiler.Xla.TensorFlow.Compiler.Xla.Literal
import Literal
import Util

range : (n : Nat) -> Literal [n] Nat
range n = impl n []
where
impl : (p : Nat) -> Literal [q] Nat -> Literal [q + p] Nat
impl Z xs = rewrite plusZeroRightNeutral q in xs
impl (S p) xs = rewrite sym $ plusSuccRightSucc q p in impl p (Scalar p :: xs)

indexed : {shape : _} -> Literal shape (List Nat)
indexed = go shape []
where
concat : Literal [d] (Literal ds a) -> Literal (d :: ds) a
concat [] = []
concat (Scalar x :: xs) = x :: concat xs

go : (shape : Types.Shape) -> List Nat -> Literal shape (List Nat)
go [] idxs = Scalar idxs
go (0 :: _) _ = []
go (S d :: ds) idxs = concat $ map (\i => go ds (snoc idxs i)) (range (S d))

export
interface Primitive dtype => LiteralRW dtype ty where
set : Literal -> List Nat -> ty -> IO ()
get : Literal -> List Nat -> ty

export
write : (HasIO io, LiteralRW dtype a) => {shape : _} -> Literal shape a -> io Literal
write xs = liftIO $ do
literal <- allocLiteral {dtype} shape
sequence_ [| (\idxs => set {dtype} literal idxs) indexed xs |]
pure literal

export
read : LiteralRW dtype a => Literal -> {shape : _} -> Literal shape a
read lit = map (get {dtype} lit) indexed

export
LiteralRW PRED Bool where
set = set
get = get

export
LiteralRW F64 Double where
set = set
get = get

export
LiteralRW S32 Int where
set = set
get = get

export
LiteralRW U32 Nat where
set lit idx x = Int.set lit idx (cast x)
get = cast .: Int.get
44 changes: 0 additions & 44 deletions src/Compiler/TensorFlow/Compiler/XLA/Client/LocalClient.idr

This file was deleted.

Loading

0 comments on commit c604608

Please sign in to comment.