From c6046089a6e48d84e1b6683bdc0d4710c1b84bc9 Mon Sep 17 00:00:00 2001 From: Joel Berkeley <16429957+joelberkeley@users.noreply.github.com> Date: Thu, 26 May 2022 15:06:29 +0100 Subject: [PATCH] refactor backend for type-safety and clarity (#274) * add type-safe layer round FFI * separate XLA Idris API from high-level usage of this API * simplify functionality for reading and writing `Literal`s --- spidr.ipkg | 46 ++- src/Compiler/Computation.idr | 95 +++++ src/Compiler/Graph.idr | 6 +- src/Compiler/LiteralRW.idr | 76 ++++ .../Compiler/XLA/Client/LocalClient.idr | 44 --- .../TensorFlow/Compiler/XLA/Literal.idr | 115 ------ src/Compiler/XLA.idr | 103 ----- .../Compiler/Xla}/Client/ClientLibrary.idr | 7 +- .../Compiler/Xla}/Client/Lib/Math.idr | 4 +- .../Compiler/Xla}/Client/Lib/Matrix.idr | 4 +- .../Compiler/Xla/Client/LocalClient.idr | 27 ++ .../Compiler/Xla}/Client/XlaBuilder.idr | 41 +- .../Compiler/Xla}/Client/XlaComputation.idr | 9 +- .../Prim/TensorFlow/Compiler/Xla/Literal.idr | 52 +++ .../Compiler/Xla}/Service/PlatformUtil.idr | 4 +- .../Prim/TensorFlow/Compiler/Xla}/Shape.idr | 9 +- .../TensorFlow/Compiler/Xla/ShapeUtil.idr | 24 ++ .../Core/CommonRuntime/GPU/GPUInit.idr | 4 +- .../Prim}/TensorFlow/Core/Platform/Status.idr | 14 +- src/Compiler/Xla/Prim/Util.idr | 30 ++ .../Compiler/Xla/Client/ClientLibrary.idr | 28 ++ .../Compiler/Xla/Client/Lib/Math.idr | 67 ++++ .../Compiler/Xla/Client/Lib/Matrix.idr | 42 ++ .../Compiler/Xla/Client/LocalClient.idr | 31 ++ .../Compiler/Xla/Client/XlaBuilder.idr | 359 ++++++++++++++++++ .../Compiler/Xla/Client/XlaComputation.idr | 26 ++ .../Xla/TensorFlow/Compiler/Xla/Literal.idr | 79 ++++ .../Compiler/Xla/Service/PlatformUtil.idr | 25 ++ .../Xla/TensorFlow/Compiler/Xla/Shape.idr | 27 ++ .../TensorFlow/Compiler/Xla}/ShapeUtil.idr | 22 +- .../TensorFlow/Compiler/Xla}/XlaData.idr | 6 +- .../Core/CommonRuntime/GPU/GPUInit.idr | 33 ++ .../Xla/TensorFlow/Core/Platform/Status.idr | 31 ++ .../TensorFlow/StreamExecutor/Platform.idr | 20 + src/Compiler/{FFI.idr => Xla/Util.idr} | 29 +- src/Primitive.idr | 6 +- src/Tensor.idr | 328 +++++++--------- 37 files changed, 1288 insertions(+), 585 deletions(-) create mode 100644 src/Compiler/Computation.idr create mode 100644 src/Compiler/LiteralRW.idr delete mode 100644 src/Compiler/TensorFlow/Compiler/XLA/Client/LocalClient.idr delete mode 100644 src/Compiler/TensorFlow/Compiler/XLA/Literal.idr delete mode 100644 src/Compiler/XLA.idr rename src/Compiler/{TensorFlow/Compiler/XLA => Xla/Prim/TensorFlow/Compiler/Xla}/Client/ClientLibrary.idr (82%) rename src/Compiler/{TensorFlow/Compiler/XLA => Xla/Prim/TensorFlow/Compiler/Xla}/Client/Lib/Math.idr (93%) rename src/Compiler/{TensorFlow/Compiler/XLA => Xla/Prim/TensorFlow/Compiler/Xla}/Client/Lib/Matrix.idr (90%) create mode 100644 src/Compiler/Xla/Prim/TensorFlow/Compiler/Xla/Client/LocalClient.idr rename src/Compiler/{TensorFlow/Compiler/XLA => Xla/Prim/TensorFlow/Compiler/Xla}/Client/XlaBuilder.idr (88%) rename src/Compiler/{TensorFlow/Compiler/XLA => Xla/Prim/TensorFlow/Compiler/Xla}/Client/XlaComputation.idr (82%) create mode 100644 src/Compiler/Xla/Prim/TensorFlow/Compiler/Xla/Literal.idr rename src/Compiler/{TensorFlow/Compiler/XLA => Xla/Prim/TensorFlow/Compiler/Xla}/Service/PlatformUtil.idr (87%) rename src/Compiler/{TensorFlow/Compiler/XLA => Xla/Prim/TensorFlow/Compiler/Xla}/Shape.idr (84%) create mode 100644 src/Compiler/Xla/Prim/TensorFlow/Compiler/Xla/ShapeUtil.idr rename src/Compiler/{ => Xla/Prim}/TensorFlow/Core/CommonRuntime/GPU/GPUInit.idr (88%) rename src/Compiler/{ => Xla/Prim}/TensorFlow/Core/Platform/Status.idr (75%) create mode 100644 src/Compiler/Xla/Prim/Util.idr create mode 100644 src/Compiler/Xla/TensorFlow/Compiler/Xla/Client/ClientLibrary.idr create mode 100644 src/Compiler/Xla/TensorFlow/Compiler/Xla/Client/Lib/Math.idr create mode 100644 src/Compiler/Xla/TensorFlow/Compiler/Xla/Client/Lib/Matrix.idr create mode 100644 src/Compiler/Xla/TensorFlow/Compiler/Xla/Client/LocalClient.idr create mode 100644 src/Compiler/Xla/TensorFlow/Compiler/Xla/Client/XlaBuilder.idr create mode 100644 src/Compiler/Xla/TensorFlow/Compiler/Xla/Client/XlaComputation.idr create mode 100644 src/Compiler/Xla/TensorFlow/Compiler/Xla/Literal.idr create mode 100644 src/Compiler/Xla/TensorFlow/Compiler/Xla/Service/PlatformUtil.idr create mode 100644 src/Compiler/Xla/TensorFlow/Compiler/Xla/Shape.idr rename src/Compiler/{TensorFlow/Compiler/XLA => Xla/TensorFlow/Compiler/Xla}/ShapeUtil.idr (55%) rename src/Compiler/{TensorFlow/Compiler/XLA => Xla/TensorFlow/Compiler/Xla}/XlaData.idr (77%) create mode 100644 src/Compiler/Xla/TensorFlow/Core/CommonRuntime/GPU/GPUInit.idr create mode 100644 src/Compiler/Xla/TensorFlow/Core/Platform/Status.idr create mode 100644 src/Compiler/Xla/TensorFlow/StreamExecutor/Platform.idr rename src/Compiler/{FFI.idr => Xla/Util.idr} (71%) diff --git a/spidr.ipkg b/spidr.ipkg index 6bb02a73e..003875471 100644 --- a/spidr.ipkg +++ b/spidr.ipkg @@ -9,21 +9,39 @@ modules = BayesianOptimization.Acquisition, BayesianOptimization.Morphisms, - Compiler.FFI, + Compiler.Computation, Compiler.Graph, - Compiler.TensorFlow.Compiler.XLA.Client.Lib.Math, - Compiler.TensorFlow.Compiler.XLA.Client.Lib.Matrix, - Compiler.TensorFlow.Compiler.XLA.Client.ClientLibrary, - Compiler.TensorFlow.Compiler.XLA.Client.LocalClient, - Compiler.TensorFlow.Compiler.XLA.Client.XlaBuilder, - Compiler.TensorFlow.Compiler.XLA.Client.XlaComputation, - Compiler.TensorFlow.Compiler.XLA.Literal, - Compiler.TensorFlow.Compiler.XLA.Service.PlatformUtil, - Compiler.TensorFlow.Compiler.XLA.Shape, - Compiler.TensorFlow.Compiler.XLA.ShapeUtil, - Compiler.TensorFlow.Compiler.XLA.XlaData, - Compiler.TensorFlow.Core.CommonRuntime.GPU.GPUInit, - Compiler.XLA, + Compiler.LiteralRW, + + Compiler.Xla.Prim.TensorFlow.Compiler.Xla.Client.Lib.Math, + Compiler.Xla.Prim.TensorFlow.Compiler.Xla.Client.Lib.Matrix, + Compiler.Xla.Prim.TensorFlow.Compiler.Xla.Client.ClientLibrary, + Compiler.Xla.Prim.TensorFlow.Compiler.Xla.Client.LocalClient, + Compiler.Xla.Prim.TensorFlow.Compiler.Xla.Client.XlaBuilder, + Compiler.Xla.Prim.TensorFlow.Compiler.Xla.Client.XlaComputation, + Compiler.Xla.Prim.TensorFlow.Compiler.Xla.Literal, + Compiler.Xla.Prim.TensorFlow.Compiler.Xla.Service.PlatformUtil, + Compiler.Xla.Prim.TensorFlow.Compiler.Xla.Shape, + Compiler.Xla.Prim.TensorFlow.Compiler.Xla.ShapeUtil, + Compiler.Xla.Prim.TensorFlow.Core.CommonRuntime.GPU.GPUInit, + Compiler.Xla.Prim.TensorFlow.Core.Platform.Status, + Compiler.Xla.Prim.Util, + + Compiler.Xla.TensorFlow.Compiler.Xla.Client.Lib.Math, + Compiler.Xla.TensorFlow.Compiler.Xla.Client.Lib.Matrix, + Compiler.Xla.TensorFlow.Compiler.Xla.Client.ClientLibrary, + Compiler.Xla.TensorFlow.Compiler.Xla.Client.LocalClient, + Compiler.Xla.TensorFlow.Compiler.Xla.Client.XlaBuilder, + Compiler.Xla.TensorFlow.Compiler.Xla.Client.XlaComputation, + Compiler.Xla.TensorFlow.Compiler.Xla.Literal, + Compiler.Xla.TensorFlow.Compiler.Xla.Service.PlatformUtil, + Compiler.Xla.TensorFlow.Compiler.Xla.Shape, + Compiler.Xla.TensorFlow.Compiler.Xla.ShapeUtil, + Compiler.Xla.TensorFlow.Compiler.Xla.XlaData, + Compiler.Xla.TensorFlow.Core.CommonRuntime.GPU.GPUInit, + Compiler.Xla.TensorFlow.Core.Platform.Status, + Compiler.Xla.TensorFlow.StreamExecutor.Platform, + Compiler.Xla.Util, Data, Distribution, diff --git a/src/Compiler/Computation.idr b/src/Compiler/Computation.idr new file mode 100644 index 000000000..b2b3c98fd --- /dev/null +++ b/src/Compiler/Computation.idr @@ -0,0 +1,95 @@ +{-- +Copyright 2022 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +module Compiler.Computation + +import Control.Monad.State +import Data.SortedMap + +import Data.Hashable + +import Compiler.Graph +import Compiler.Xla.TensorFlow.Compiler.Xla.Client.XlaBuilder +import Compiler.Xla.TensorFlow.Compiler.Xla.Client.XlaComputation +import Compiler.Xla.TensorFlow.Compiler.Xla.Shape +import Compiler.Xla.TensorFlow.Compiler.Xla.ShapeUtil +import Compiler.Xla.TensorFlow.Compiler.Xla.XlaData +import Types + +public export +data CachingBuilder : Type where + MkCachingBuilder : XlaBuilder -> SortedMap Bits64 XlaOp -> CachingBuilder + +public export +Computation : Type -> Type +Computation = StateT CachingBuilder IO + +export +cached : Graph -> Computation XlaOp -> Computation XlaOp +cached graph xs = assert_total $ let graphHash = hash graph in do + builder <- get + case cacheLookup builder graphHash of + Just op => pure op + Nothing => do + op <- xs + builder <- get + put (cacheInsert builder graphHash op) + pure op + + where + cacheInsert : CachingBuilder -> Bits64 -> XlaOp -> CachingBuilder + cacheInsert (MkCachingBuilder builder cache) key xlaOp = + MkCachingBuilder builder (insert key xlaOp cache) + + cacheLookup : CachingBuilder -> Bits64 -> Maybe XlaOp + cacheLookup (MkCachingBuilder _ cache) key = lookup key cache + +export +build : HasIO io => String -> Computation XlaOp -> io XlaComputation +build computationName x = do + builder <- mkXlaBuilder computationName + MkCachingBuilder builder _ <- liftIO $ execStateT (MkCachingBuilder builder empty) x + build builder + +export +buildWithSubBuilder : + String -> List (Computation XlaOp) -> Computation XlaOp -> Computation XlaComputation +buildWithSubBuilder computationName computationArguments computationResult = do + MkCachingBuilder builder _ <- get + subBuilder <- createSubBuilder builder computationName + let cachingSubBuilder = MkCachingBuilder subBuilder empty + allOps = sequence_ (computationArguments ++ [computationResult]) + MkCachingBuilder subBuilder _ <- liftIO $ execStateT cachingSubBuilder allOps + build subBuilder + +export +opToString : Computation XlaOp -> String +opToString x = unsafePerformIO $ do + builder <- mkXlaBuilder "toString" + (MkCachingBuilder builder _, xlaOp) <- runStateT (MkCachingBuilder builder empty) x + pure $ opToString builder xlaOp + +export +parameter : Primitive dtype => Nat -> Types.Shape -> String -> (Graph, Computation XlaOp) +parameter position shape name = + let graph = Parameter {dtype} shape position + + param : Computation XlaOp + param = do + MkCachingBuilder builder _ <- get + xlaShape <- mkShape {dtype} shape + cached graph $ parameter builder position xlaShape name + + in (graph, param) diff --git a/src/Compiler/Graph.idr b/src/Compiler/Graph.idr index d2bc275a3..e272e5399 100644 --- a/src/Compiler/Graph.idr +++ b/src/Compiler/Graph.idr @@ -15,11 +15,11 @@ limitations under the License. --} module Compiler.Graph -import Primitive import Data.Hashable -import Data.Stream + +import Compiler.Xla.TensorFlow.Compiler.Xla.XlaData import Types -import Util +import Util.Hashable ||| A `Graph` represents a computational graph used to compute a tensor value. It is defined by ||| the following proprty: For any two `Graph`s gx and gy that compute tensors x and y respectively, diff --git a/src/Compiler/LiteralRW.idr b/src/Compiler/LiteralRW.idr new file mode 100644 index 000000000..b68bdf9c1 --- /dev/null +++ b/src/Compiler/LiteralRW.idr @@ -0,0 +1,76 @@ +{-- +Copyright 2022 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +module Compiler.LiteralRW + +import Compiler.Xla.TensorFlow.Compiler.Xla.XlaData +import Compiler.Xla.TensorFlow.Compiler.Xla.Literal +import Literal +import Util + +range : (n : Nat) -> Literal [n] Nat +range n = impl n [] + where + impl : (p : Nat) -> Literal [q] Nat -> Literal [q + p] Nat + impl Z xs = rewrite plusZeroRightNeutral q in xs + impl (S p) xs = rewrite sym $ plusSuccRightSucc q p in impl p (Scalar p :: xs) + +indexed : {shape : _} -> Literal shape (List Nat) +indexed = go shape [] + where + concat : Literal [d] (Literal ds a) -> Literal (d :: ds) a + concat [] = [] + concat (Scalar x :: xs) = x :: concat xs + + go : (shape : Types.Shape) -> List Nat -> Literal shape (List Nat) + go [] idxs = Scalar idxs + go (0 :: _) _ = [] + go (S d :: ds) idxs = concat $ map (\i => go ds (snoc idxs i)) (range (S d)) + +export +interface Primitive dtype => LiteralRW dtype ty where + set : Literal -> List Nat -> ty -> IO () + get : Literal -> List Nat -> ty + +export +write : (HasIO io, LiteralRW dtype a) => {shape : _} -> Literal shape a -> io Literal +write xs = liftIO $ do + literal <- allocLiteral {dtype} shape + sequence_ [| (\idxs => set {dtype} literal idxs) indexed xs |] + pure literal + +export +read : LiteralRW dtype a => Literal -> {shape : _} -> Literal shape a +read lit = map (get {dtype} lit) indexed + +export +LiteralRW PRED Bool where + set = set + get = get + +export +LiteralRW F64 Double where + set = set + get = get + +export +LiteralRW S32 Int where + set = set + get = get + +export +LiteralRW U32 Nat where + set lit idx x = Int.set lit idx (cast x) + get = cast .: Int.get diff --git a/src/Compiler/TensorFlow/Compiler/XLA/Client/LocalClient.idr b/src/Compiler/TensorFlow/Compiler/XLA/Client/LocalClient.idr deleted file mode 100644 index d60acfa22..000000000 --- a/src/Compiler/TensorFlow/Compiler/XLA/Client/LocalClient.idr +++ /dev/null @@ -1,44 +0,0 @@ -{-- -Copyright 2022 Joel Berkeley - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. ---} -module Compiler.TensorFlow.Compiler.XLA.Client.LocalClient - -import System.FFI - -import Compiler.FFI -import Compiler.TensorFlow.Compiler.XLA.Client.XlaComputation -import Compiler.TensorFlow.Compiler.XLA.Literal - -public export -LocalClient : Type -LocalClient = Struct "LocalClient" [] - -%foreign (libxla "LocalClient_TransferToServer") -prim__transferToServerImpl : LocalClient -> GCAnyPtr -> PrimIO AnyPtr - -export -prim__transferToServer : LocalClient -> GCAnyPtr -> IO GCAnyPtr -prim__transferToServer client literal = do - globalData <- primIO (prim__transferToServerImpl client literal) - onCollectAny globalData free - -%foreign (libxla "LocalClient_ExecuteAndTransfer") -prim__executeAndTransferImpl : LocalClient -> GCAnyPtr -> AnyPtr -> Int -> PrimIO AnyPtr - -export -prim__executeAndTransfer : LocalClient -> GCAnyPtr -> AnyPtr -> Int -> IO GCAnyPtr -prim__executeAndTransfer client computation arguments argumentsLen = do - literal <- primIO (prim__executeAndTransferImpl client computation arguments argumentsLen) - onCollectAny literal Literal.delete diff --git a/src/Compiler/TensorFlow/Compiler/XLA/Literal.idr b/src/Compiler/TensorFlow/Compiler/XLA/Literal.idr deleted file mode 100644 index b11bbae4f..000000000 --- a/src/Compiler/TensorFlow/Compiler/XLA/Literal.idr +++ /dev/null @@ -1,115 +0,0 @@ -{-- -Copyright 2021 Joel Berkeley - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. ---} -module Compiler.TensorFlow.Compiler.XLA.Literal - -import System.FFI - -import Compiler.FFI -import Compiler.TensorFlow.Compiler.XLA.Shape -import Compiler.TensorFlow.Compiler.XLA.ShapeUtil -import Compiler.TensorFlow.Compiler.XLA.XlaData -import Literal -import Types -import Util - -export -interface Primitive dtype => LiteralPrimitiveRW dtype ty where - set : GCAnyPtr -> GCPtr Int -> ty -> PrimIO () - get : GCAnyPtr -> GCPtr Int -> ty - -export -%foreign (libxla "Literal_new") -prim__allocLiteral : GCAnyPtr -> PrimIO AnyPtr - -%foreign (libxla "Literal_delete") -prim__delete : AnyPtr -> PrimIO () - -export -delete : AnyPtr -> IO () -delete = primIO . prim__delete - -%foreign (libxla "Literal_Set_bool") -prim__literalSetBool : GCAnyPtr -> GCPtr Int -> Int -> PrimIO () - -%foreign (libxla "Literal_Get_bool") -literalGetBool : GCAnyPtr -> GCPtr Int -> Int - -export -LiteralPrimitiveRW PRED Bool where - set lit idxs x = prim__literalSetBool lit idxs (if x then 1 else 0) - get lit idxs = cIntToBool (literalGetBool lit idxs) - -%foreign (libxla "Literal_Set_double") -prim__literalSetDouble : GCAnyPtr -> GCPtr Int -> Double -> PrimIO () - -%foreign (libxla "Literal_Get_double") -literalGetDouble : GCAnyPtr -> GCPtr Int -> Double - -export -LiteralPrimitiveRW F64 Double where - set = prim__literalSetDouble - get = literalGetDouble - -%foreign (libxla "Literal_Set_int") -prim__literalSetInt : GCAnyPtr -> GCPtr Int -> Int -> PrimIO () - -%foreign (libxla "Literal_Get_int") -literalGetInt : GCAnyPtr -> GCPtr Int -> Int - -export -LiteralPrimitiveRW S32 Int where - set = prim__literalSetInt - get = literalGetInt - -export -LiteralPrimitiveRW U32 Nat where - set lit idx x = prim__literalSetInt lit idx (cast x) - get = cast .: literalGetInt - -enumerate : {d : _} -> {ds : _} -> Literal (d :: ds) dtype -> Vect d (Nat, Literal ds dtype) -enumerate xs = Vect.enumerate (toVect xs) where - toVect : {0 d : _} -> Literal (d :: ds) dtype -> Vect d (Literal ds dtype) - toVect {d=0} [] = [] - toVect (x :: xs) = x :: toVect xs - -populateLiteral : {shape : _} -> LiteralPrimitiveRW dtype a => Literal shape a -> GCAnyPtr -> IO () -populateLiteral {shape} lit ptr = impl shape [] lit where - impl : (shape', idxs : Shape) -> Literal shape' a -> IO () - impl [] idxs (Scalar x) = primIO (set {dtype} ptr !(mkIntArray idxs) x) - impl (0 :: _) _ _ = pure () - impl (S _ :: ds) idxs (x :: xs) = - traverse_ (\(idx, ys) => impl ds (idxs ++ [idx]) ys) (enumerate (x :: xs)) - -export -mkLiteral : HasIO io => LiteralPrimitiveRW dtype a => {shape : _} -> Literal shape a -> io GCAnyPtr -mkLiteral xs = do - xlaShape <- mkShape {dtype} shape - literal <- primIO $ prim__allocLiteral xlaShape - literal <- onCollectAny literal Literal.delete - liftIO $ populateLiteral {dtype} xs literal - pure literal - -concat : Vect d (Literal ds a) -> Literal (d :: ds) a -concat [] = [] -concat (x :: xs) = x :: concat xs - -export -toLiteral : {shape : _} -> GCAnyPtr -> LiteralPrimitiveRW dtype a => Literal shape a -toLiteral lit = impl shape [] where - impl : (shape', idxs : Shape) -> Literal shape' a - impl [] idxs = Scalar (unsafePerformIO $ (map (get {dtype} lit) (mkIntArray idxs))) - impl (0 :: ds) idxs = [] - impl (S d :: ds) idxs = concat $ map (\i => impl ds (snoc idxs i)) (range (S d)) diff --git a/src/Compiler/XLA.idr b/src/Compiler/XLA.idr deleted file mode 100644 index 562c50531..000000000 --- a/src/Compiler/XLA.idr +++ /dev/null @@ -1,103 +0,0 @@ -{-- -Copyright 2022 Joel Berkeley - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. ---} -module Compiler.XLA - -import Data.Hashable -import Control.Monad.State -import Data.SortedMap - -import Compiler.Graph -import Compiler.TensorFlow.Compiler.XLA.Client.XlaBuilder -import Compiler.TensorFlow.Compiler.XLA.Client.XlaComputation -import Compiler.TensorFlow.Compiler.XLA.ShapeUtil -import Compiler.TensorFlow.Compiler.XLA.XlaData -import Types - -public export -data XlaBuilder : Type where - MkXlaBuilder : GCAnyPtr -> SortedMap Bits64 GCAnyPtr -> XlaBuilder - --- note, the type of thing pointed to by the GCAnyPtr can be anything, and must be inferred from the --- context. -public export -ComputationComponent : Type -ComputationComponent = StateT XlaBuilder IO GCAnyPtr - -cacheInsert : XlaBuilder -> Bits64 -> GCAnyPtr -> XlaBuilder -cacheInsert (MkXlaBuilder ptr cache) k v = MkXlaBuilder ptr (insert k v cache) - -cacheLookup : XlaBuilder -> Bits64 -> Maybe GCAnyPtr -cacheLookup (MkXlaBuilder _ cache) k = lookup k cache - -export -cached : Graph -> ComputationComponent -> ComputationComponent -cached graph xs = let graphHash = assert_total $ hash graph in do - builder <- get - case cacheLookup builder graphHash of - Just opPtr => pure opPtr - Nothing => do - opPtr <- xs - builder <- get - put (cacheInsert builder graphHash opPtr) - pure opPtr - -mkXlaBuilder : String -> IO XlaBuilder -mkXlaBuilder computationName = do - ptr <- primIO (prim__mkXlaBuilder computationName) - ptr <- onCollectAny ptr XlaBuilder.delete - pure (MkXlaBuilder ptr empty) - -export -build : String -> ComputationComponent -> IO GCAnyPtr -build computationName x = do - builder <- mkXlaBuilder computationName - (MkXlaBuilder ptr _) <- execStateT builder x - onCollectAny (prim__build ptr) XlaComputation.delete - -export -buildWithSubBuilder : - String -> List ComputationComponent -> ComputationComponent -> ComputationComponent -buildWithSubBuilder computationName args res = do - MkXlaBuilder ptr _ <- get - subPtr <- primIO (prim__createSubBuilder ptr computationName) - subPtr <- onCollectAny subPtr XlaBuilder.delete - let subBuilder = MkXlaBuilder subPtr empty - allOps = sequence_ (args ++ [res]) - MkXlaBuilder subPtr _ <- liftIO $ execStateT subBuilder allOps - let computation = prim__build subPtr - onCollectAny computation XlaComputation.delete - -export -prim__opToString : ComputationComponent -> IO String -prim__opToString xs = do - builder <- mkXlaBuilder "toString" - (MkXlaBuilder ptr _, op) <- runStateT builder xs - pure (XlaBuilder.prim__opToString ptr op) - -export -prim__constantLiteral : GCAnyPtr -> Graph -> ComputationComponent -prim__constantLiteral literal graph = do - MkXlaBuilder ptr _ <- get - op <- primIO $ prim__constantLiteral ptr literal - onCollectAny op XlaOp.delete - -export -prim__parameter : Primitive dtype => Int -> Shape -> String -> ComputationComponent -prim__parameter position shape name = do - (MkXlaBuilder ptr _) <- get - xlaShape <- mkShape {dtype} shape - op <- primIO $ prim__parameter ptr position xlaShape name - onCollectAny op XlaOp.delete diff --git a/src/Compiler/TensorFlow/Compiler/XLA/Client/ClientLibrary.idr b/src/Compiler/Xla/Prim/TensorFlow/Compiler/Xla/Client/ClientLibrary.idr similarity index 82% rename from src/Compiler/TensorFlow/Compiler/XLA/Client/ClientLibrary.idr rename to src/Compiler/Xla/Prim/TensorFlow/Compiler/Xla/Client/ClientLibrary.idr index bba04649e..d5bd3bf0a 100644 --- a/src/Compiler/TensorFlow/Compiler/XLA/Client/ClientLibrary.idr +++ b/src/Compiler/Xla/Prim/TensorFlow/Compiler/Xla/Client/ClientLibrary.idr @@ -13,13 +13,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. --} -module Compiler.TensorFlow.Compiler.XLA.Client.ClientLibrary +module Compiler.Xla.Prim.TensorFlow.Compiler.Xla.Client.ClientLibrary import System.FFI -import Compiler.FFI -import Compiler.TensorFlow.Compiler.XLA.Client.LocalClient +import Compiler.Xla.Prim.Util export %foreign (libxla "ClientLibrary_GetOrCreateLocalClient") -prim__getOrCreateLocalClient : AnyPtr -> AnyPtr -> Int -> PrimIO LocalClient +prim__getOrCreateLocalClient : AnyPtr -> AnyPtr -> Int -> PrimIO AnyPtr diff --git a/src/Compiler/TensorFlow/Compiler/XLA/Client/Lib/Math.idr b/src/Compiler/Xla/Prim/TensorFlow/Compiler/Xla/Client/Lib/Math.idr similarity index 93% rename from src/Compiler/TensorFlow/Compiler/XLA/Client/Lib/Math.idr rename to src/Compiler/Xla/Prim/TensorFlow/Compiler/Xla/Client/Lib/Math.idr index c6eb2b335..911523d8b 100644 --- a/src/Compiler/TensorFlow/Compiler/XLA/Client/Lib/Math.idr +++ b/src/Compiler/Xla/Prim/TensorFlow/Compiler/Xla/Client/Lib/Math.idr @@ -13,11 +13,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. --} -module Compiler.TensorFlow.Compiler.XLA.Client.Lib.Math +module Compiler.Xla.Prim.TensorFlow.Compiler.Xla.Client.Lib.Math import System.FFI -import Compiler.FFI +import Compiler.Xla.Prim.Util export %foreign (libxla "Square") diff --git a/src/Compiler/TensorFlow/Compiler/XLA/Client/Lib/Matrix.idr b/src/Compiler/Xla/Prim/TensorFlow/Compiler/Xla/Client/Lib/Matrix.idr similarity index 90% rename from src/Compiler/TensorFlow/Compiler/XLA/Client/Lib/Matrix.idr rename to src/Compiler/Xla/Prim/TensorFlow/Compiler/Xla/Client/Lib/Matrix.idr index 1eb8bb727..d5c36e0e3 100644 --- a/src/Compiler/TensorFlow/Compiler/XLA/Client/Lib/Matrix.idr +++ b/src/Compiler/Xla/Prim/TensorFlow/Compiler/Xla/Client/Lib/Matrix.idr @@ -13,11 +13,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. --} -module Compiler.TensorFlow.Compiler.XLA.Client.Lib.Matrix +module Compiler.Xla.Prim.TensorFlow.Compiler.Xla.Client.Lib.Matrix import System.FFI -import Compiler.FFI +import Compiler.Xla.Prim.Util export %foreign (libxla "IdentityMatrix") diff --git a/src/Compiler/Xla/Prim/TensorFlow/Compiler/Xla/Client/LocalClient.idr b/src/Compiler/Xla/Prim/TensorFlow/Compiler/Xla/Client/LocalClient.idr new file mode 100644 index 000000000..15ab2fa9f --- /dev/null +++ b/src/Compiler/Xla/Prim/TensorFlow/Compiler/Xla/Client/LocalClient.idr @@ -0,0 +1,27 @@ +{-- +Copyright 2022 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +module Compiler.Xla.Prim.TensorFlow.Compiler.Xla.Client.LocalClient + +import System.FFI + +import Compiler.Xla.Prim.Util + +%foreign (libxla "LocalClient_TransferToServer") +prim__transferToServer : AnyPtr -> GCAnyPtr -> PrimIO AnyPtr + +export +%foreign (libxla "LocalClient_ExecuteAndTransfer") +prim__executeAndTransfer : AnyPtr -> GCAnyPtr -> AnyPtr -> Int -> PrimIO AnyPtr diff --git a/src/Compiler/TensorFlow/Compiler/XLA/Client/XlaBuilder.idr b/src/Compiler/Xla/Prim/TensorFlow/Compiler/Xla/Client/XlaBuilder.idr similarity index 88% rename from src/Compiler/TensorFlow/Compiler/XLA/Client/XlaBuilder.idr rename to src/Compiler/Xla/Prim/TensorFlow/Compiler/Xla/Client/XlaBuilder.idr index b321b7d83..11b7dc9ee 100644 --- a/src/Compiler/TensorFlow/Compiler/XLA/Client/XlaBuilder.idr +++ b/src/Compiler/Xla/Prim/TensorFlow/Compiler/Xla/Client/XlaBuilder.idr @@ -13,31 +13,17 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. --} -module Compiler.TensorFlow.Compiler.XLA.Client.XlaBuilder +module Compiler.Xla.Prim.TensorFlow.Compiler.Xla.Client.XlaBuilder import System.FFI -import Compiler.FFI -import Compiler.Graph -import Compiler.TensorFlow.Compiler.XLA.Client.XlaComputation -import Compiler.TensorFlow.Compiler.XLA.Literal -import Types -import Util - -{- - - - - XlaBuilder - - - -} +import Compiler.Xla.Prim.Util namespace XlaBuilder + export %foreign (libxla "XlaBuilder_delete") prim__delete : AnyPtr -> PrimIO () - export - delete : AnyPtr -> IO () - delete = primIO . prim__delete - export %foreign (libxla "XlaBuilder_new") prim__mkXlaBuilder : String -> PrimIO AnyPtr @@ -54,34 +40,19 @@ export %foreign (libxla "XlaBuilder_OpToString") prim__opToString : GCAnyPtr -> GCAnyPtr -> String -{- - - - - XlaOp - - - -} - namespace XlaOp + export %foreign (libxla "XlaOp_delete") prim__delete : AnyPtr -> PrimIO () - export - delete : AnyPtr -> IO () - delete = primIO . prim__delete - +export %foreign (libxla "sizeof_XlaOp") sizeOfXlaOp : Int +export %foreign (libxla "set_array_XlaOp") prim__setArrayXlaOp : AnyPtr -> Int -> GCAnyPtr -> PrimIO () -export -mkXlaOpArray : HasIO io => List GCAnyPtr -> io GCAnyPtr -mkXlaOpArray ops = do - arr <- malloc (cast (length ops) * sizeOfXlaOp) - traverse_ (\(idx, op) => - primIO $ prim__setArrayXlaOp arr (cast idx) op) (enumerate (fromList ops)) - onCollectAny arr free - export %foreign (libxla "Parameter") prim__parameter : GCAnyPtr -> Int -> GCAnyPtr -> String -> PrimIO AnyPtr diff --git a/src/Compiler/TensorFlow/Compiler/XLA/Client/XlaComputation.idr b/src/Compiler/Xla/Prim/TensorFlow/Compiler/Xla/Client/XlaComputation.idr similarity index 82% rename from src/Compiler/TensorFlow/Compiler/XLA/Client/XlaComputation.idr rename to src/Compiler/Xla/Prim/TensorFlow/Compiler/Xla/Client/XlaComputation.idr index ea71ae866..26519f137 100644 --- a/src/Compiler/TensorFlow/Compiler/XLA/Client/XlaComputation.idr +++ b/src/Compiler/Xla/Prim/TensorFlow/Compiler/Xla/Client/XlaComputation.idr @@ -13,15 +13,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. --} -module Compiler.TensorFlow.Compiler.XLA.Client.XlaComputation +module Compiler.Xla.Prim.TensorFlow.Compiler.Xla.Client.XlaComputation import System.FFI -import Compiler.FFI +import Compiler.Xla.Prim.Util +export %foreign (libxla "XlaComputation_delete") prim__delete : AnyPtr -> PrimIO () - -export -delete : AnyPtr -> IO () -delete = primIO . prim__delete diff --git a/src/Compiler/Xla/Prim/TensorFlow/Compiler/Xla/Literal.idr b/src/Compiler/Xla/Prim/TensorFlow/Compiler/Xla/Literal.idr new file mode 100644 index 000000000..043d9245c --- /dev/null +++ b/src/Compiler/Xla/Prim/TensorFlow/Compiler/Xla/Literal.idr @@ -0,0 +1,52 @@ +{-- +Copyright 2021 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +module Compiler.Xla.Prim.TensorFlow.Compiler.Xla.Literal + +import System.FFI + +import Compiler.Xla.Prim.Util + +export +%foreign (libxla "Literal_new") +prim__allocLiteral : GCAnyPtr -> PrimIO AnyPtr + +export +%foreign (libxla "Literal_delete") +prim__delete : AnyPtr -> PrimIO () + +export +%foreign (libxla "Literal_Set_bool") +prim__literalSetBool : GCAnyPtr -> GCPtr Int -> Int -> PrimIO () + +export +%foreign (libxla "Literal_Get_bool") +literalGetBool : GCAnyPtr -> GCPtr Int -> Int + +export +%foreign (libxla "Literal_Set_double") +prim__literalSetDouble : GCAnyPtr -> GCPtr Int -> Double -> PrimIO () + +export +%foreign (libxla "Literal_Get_double") +literalGetDouble : GCAnyPtr -> GCPtr Int -> Double + +export +%foreign (libxla "Literal_Set_int") +prim__literalSetInt : GCAnyPtr -> GCPtr Int -> Int -> PrimIO () + +export +%foreign (libxla "Literal_Get_int") +literalGetInt : GCAnyPtr -> GCPtr Int -> Int diff --git a/src/Compiler/TensorFlow/Compiler/XLA/Service/PlatformUtil.idr b/src/Compiler/Xla/Prim/TensorFlow/Compiler/Xla/Service/PlatformUtil.idr similarity index 87% rename from src/Compiler/TensorFlow/Compiler/XLA/Service/PlatformUtil.idr rename to src/Compiler/Xla/Prim/TensorFlow/Compiler/Xla/Service/PlatformUtil.idr index 605b14328..83a0e849c 100644 --- a/src/Compiler/TensorFlow/Compiler/XLA/Service/PlatformUtil.idr +++ b/src/Compiler/Xla/Prim/TensorFlow/Compiler/Xla/Service/PlatformUtil.idr @@ -13,11 +13,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. --} -module Compiler.TensorFlow.Compiler.XLA.Service.PlatformUtil +module Compiler.Xla.Prim.TensorFlow.Compiler.Xla.Service.PlatformUtil import System.FFI -import Compiler.FFI +import Compiler.Xla.Prim.Util export %foreign (libxla "PlatformUtil_GetPlatform") diff --git a/src/Compiler/TensorFlow/Compiler/XLA/Shape.idr b/src/Compiler/Xla/Prim/TensorFlow/Compiler/Xla/Shape.idr similarity index 84% rename from src/Compiler/TensorFlow/Compiler/XLA/Shape.idr rename to src/Compiler/Xla/Prim/TensorFlow/Compiler/Xla/Shape.idr index 94b87eaf3..64827ca22 100644 --- a/src/Compiler/TensorFlow/Compiler/XLA/Shape.idr +++ b/src/Compiler/Xla/Prim/TensorFlow/Compiler/Xla/Shape.idr @@ -13,15 +13,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. --} -module Compiler.TensorFlow.Compiler.XLA.Shape +module Compiler.Xla.Prim.TensorFlow.Compiler.Xla.Shape import System.FFI -import Compiler.FFI +import Compiler.Xla.Prim.Util +export %foreign (libxla "Shape_delete") prim__delete : AnyPtr -> PrimIO () - -export -delete : AnyPtr -> IO () -delete = primIO . prim__delete diff --git a/src/Compiler/Xla/Prim/TensorFlow/Compiler/Xla/ShapeUtil.idr b/src/Compiler/Xla/Prim/TensorFlow/Compiler/Xla/ShapeUtil.idr new file mode 100644 index 000000000..75fecd19f --- /dev/null +++ b/src/Compiler/Xla/Prim/TensorFlow/Compiler/Xla/ShapeUtil.idr @@ -0,0 +1,24 @@ +{-- +Copyright 2022 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +module Compiler.Xla.Prim.TensorFlow.Compiler.Xla.ShapeUtil + +import System.FFI + +import Compiler.Xla.Prim.Util + +export +%foreign (libxla "MakeShape") +prim__mkShape : Int -> GCPtr Int -> Int -> PrimIO AnyPtr diff --git a/src/Compiler/TensorFlow/Core/CommonRuntime/GPU/GPUInit.idr b/src/Compiler/Xla/Prim/TensorFlow/Core/CommonRuntime/GPU/GPUInit.idr similarity index 88% rename from src/Compiler/TensorFlow/Core/CommonRuntime/GPU/GPUInit.idr rename to src/Compiler/Xla/Prim/TensorFlow/Core/CommonRuntime/GPU/GPUInit.idr index b1a8a78c6..cf08c6184 100644 --- a/src/Compiler/TensorFlow/Core/CommonRuntime/GPU/GPUInit.idr +++ b/src/Compiler/Xla/Prim/TensorFlow/Core/CommonRuntime/GPU/GPUInit.idr @@ -13,11 +13,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. --} -module Compiler.TensorFlow.Core.CommonRuntime.GPU.GPUInit +module Compiler.Xla.Prim.TensorFlow.Core.CommonRuntime.GPU.GPUInit import System.FFI -import Compiler.FFI +import Compiler.Xla.Prim.Util export %foreign (libxla "ValidateGPUMachineManager") diff --git a/src/Compiler/TensorFlow/Core/Platform/Status.idr b/src/Compiler/Xla/Prim/TensorFlow/Core/Platform/Status.idr similarity index 75% rename from src/Compiler/TensorFlow/Core/Platform/Status.idr rename to src/Compiler/Xla/Prim/TensorFlow/Core/Platform/Status.idr index 9f8cabe0c..e0609cfe7 100644 --- a/src/Compiler/TensorFlow/Core/Platform/Status.idr +++ b/src/Compiler/Xla/Prim/TensorFlow/Core/Platform/Status.idr @@ -13,22 +13,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. --} -module Compiler.TensorFlow.Core.Platform.Status +module Compiler.Xla.Prim.TensorFlow.Core.Platform.Status import System.FFI -import Compiler.FFI +import Compiler.Xla.Prim.Util +export %foreign (libxla "Status_delete") prim__delete : AnyPtr -> PrimIO () export -delete : AnyPtr -> IO () -delete = primIO . prim__delete - %foreign (libxla "Status_ok") -prim__okImpl : GCAnyPtr -> Int - -export -prim__ok : GCAnyPtr -> Bool -prim__ok = cIntToBool . prim__okImpl +prim__ok : GCAnyPtr -> Int diff --git a/src/Compiler/Xla/Prim/Util.idr b/src/Compiler/Xla/Prim/Util.idr new file mode 100644 index 000000000..69df2c87c --- /dev/null +++ b/src/Compiler/Xla/Prim/Util.idr @@ -0,0 +1,30 @@ +{-- +Copyright 2022 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +module Compiler.Xla.Prim.Util + +import System.FFI + +public export +libxla : String -> String +libxla fname = "C:" ++ fname ++ ",libc_xla_extension" + +export +%foreign (libxla "sizeof_int") +sizeofInt : Int + +export +%foreign (libxla "set_array_int") +prim__setArrayInt : Ptr Int -> Int -> Int -> PrimIO () diff --git a/src/Compiler/Xla/TensorFlow/Compiler/Xla/Client/ClientLibrary.idr b/src/Compiler/Xla/TensorFlow/Compiler/Xla/Client/ClientLibrary.idr new file mode 100644 index 000000000..f5cf1a552 --- /dev/null +++ b/src/Compiler/Xla/TensorFlow/Compiler/Xla/Client/ClientLibrary.idr @@ -0,0 +1,28 @@ +{-- +Copyright 2022 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +module Compiler.Xla.TensorFlow.Compiler.Xla.Client.ClientLibrary + +import System.FFI + +import Compiler.Xla.Prim.TensorFlow.Compiler.Xla.Client.ClientLibrary +import Compiler.Xla.TensorFlow.Compiler.Xla.Client.LocalClient +import Compiler.Xla.TensorFlow.StreamExecutor.Platform + +export +getOrCreateLocalClient : Platform -> IO LocalClient +getOrCreateLocalClient (MkPlatform platform) = do + client <- primIO $ prim__getOrCreateLocalClient platform prim__getNullAnyPtr 0 + pure (MkLocalClient client) diff --git a/src/Compiler/Xla/TensorFlow/Compiler/Xla/Client/Lib/Math.idr b/src/Compiler/Xla/TensorFlow/Compiler/Xla/Client/Lib/Math.idr new file mode 100644 index 000000000..9be943824 --- /dev/null +++ b/src/Compiler/Xla/TensorFlow/Compiler/Xla/Client/Lib/Math.idr @@ -0,0 +1,67 @@ +{-- +Copyright 2022 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +module Compiler.Xla.TensorFlow.Compiler.Xla.Client.Lib.Math + +import Compiler.Xla.Prim.TensorFlow.Compiler.Xla.Client.Lib.Math +import Compiler.Xla.TensorFlow.Compiler.Xla.Client.XlaBuilder + +export +square : HasIO io => XlaOp -> io XlaOp +square = unaryOp prim__square + +export +reciprocal : HasIO io => XlaOp -> io XlaOp +reciprocal = unaryOp prim__reciprocal + +export +acos : HasIO io => XlaOp -> io XlaOp +acos = unaryOp prim__acos + +export +asin : HasIO io => XlaOp -> io XlaOp +asin = unaryOp prim__asin + +export +atan : HasIO io => XlaOp -> io XlaOp +atan = unaryOp prim__atan + +export +tan : HasIO io => XlaOp -> io XlaOp +tan = unaryOp prim__tan + +export +acosh : HasIO io => XlaOp -> io XlaOp +acosh = unaryOp prim__acosh + +export +asinh : HasIO io => XlaOp -> io XlaOp +asinh = unaryOp prim__asinh + +export +atanh : HasIO io => XlaOp -> io XlaOp +atanh = unaryOp prim__atanh + +export +cosh : HasIO io => XlaOp -> io XlaOp +cosh = unaryOp prim__cosh + +export +sinh : HasIO io => XlaOp -> io XlaOp +sinh = unaryOp prim__sinh + +export +erf : HasIO io => XlaOp -> io XlaOp +erf = unaryOp prim__erf diff --git a/src/Compiler/Xla/TensorFlow/Compiler/Xla/Client/Lib/Matrix.idr b/src/Compiler/Xla/TensorFlow/Compiler/Xla/Client/Lib/Matrix.idr new file mode 100644 index 000000000..14f3e033a --- /dev/null +++ b/src/Compiler/Xla/TensorFlow/Compiler/Xla/Client/Lib/Matrix.idr @@ -0,0 +1,42 @@ +{-- +Copyright 2022 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +module Compiler.Xla.TensorFlow.Compiler.Xla.Client.Lib.Matrix + +import Compiler.Xla.Prim.TensorFlow.Compiler.Xla.Client.Lib.Matrix +import Compiler.Xla.TensorFlow.Compiler.Xla.Client.XlaBuilder +import Compiler.Xla.TensorFlow.Compiler.Xla.XlaData +import Compiler.Xla.Util + +export +identityMatrix : HasIO io => Primitive dtype => XlaBuilder -> Nat -> Nat -> io XlaOp +identityMatrix (MkXlaBuilder builder) m n = do + opPtr <- primIO $ prim__identityMatrix builder (xlaIdentifier {dtype}) (cast m) (cast n) + opPtr <- onCollectAny opPtr XlaOp.delete + pure (MkXlaOp opPtr) + +export +getMatrixDiagonal : HasIO io => XlaOp -> io XlaOp +getMatrixDiagonal (MkXlaOp x) = do + opPtr <- primIO $ prim__getMatrixDiagonal x + opPtr <- onCollectAny opPtr XlaOp.delete + pure (MkXlaOp opPtr) + +export +triangle : HasIO io => XlaOp -> Bool -> io XlaOp +triangle (MkXlaOp x) lower = do + opPtr <- primIO $ prim__triangle x (boolToCInt lower) + opPtr <- onCollectAny opPtr XlaOp.delete + pure (MkXlaOp opPtr) diff --git a/src/Compiler/Xla/TensorFlow/Compiler/Xla/Client/LocalClient.idr b/src/Compiler/Xla/TensorFlow/Compiler/Xla/Client/LocalClient.idr new file mode 100644 index 000000000..c1a12946e --- /dev/null +++ b/src/Compiler/Xla/TensorFlow/Compiler/Xla/Client/LocalClient.idr @@ -0,0 +1,31 @@ +{-- +Copyright 2022 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +module Compiler.Xla.TensorFlow.Compiler.Xla.Client.LocalClient + +import Compiler.Xla.Prim.TensorFlow.Compiler.Xla.Client.LocalClient +import Compiler.Xla.TensorFlow.Compiler.Xla.Literal +import Compiler.Xla.TensorFlow.Compiler.Xla.Client.XlaComputation + +public export +data LocalClient : Type where + MkLocalClient : AnyPtr -> LocalClient + +export +executeAndTransfer : LocalClient -> XlaComputation -> IO Literal +executeAndTransfer (MkLocalClient client) (MkXlaComputation computation) = do + literal <- primIO $ prim__executeAndTransfer client computation prim__getNullAnyPtr 0 + literal <- onCollectAny literal Literal.delete + pure (MkLiteral literal) diff --git a/src/Compiler/Xla/TensorFlow/Compiler/Xla/Client/XlaBuilder.idr b/src/Compiler/Xla/TensorFlow/Compiler/Xla/Client/XlaBuilder.idr new file mode 100644 index 000000000..82fff0343 --- /dev/null +++ b/src/Compiler/Xla/TensorFlow/Compiler/Xla/Client/XlaBuilder.idr @@ -0,0 +1,359 @@ +{-- +Copyright 2022 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +module Compiler.Xla.TensorFlow.Compiler.Xla.Client.XlaBuilder + +import System.FFI + +import Compiler.Xla.Prim.TensorFlow.Compiler.Xla.Client.XlaBuilder +import Compiler.Xla.TensorFlow.Compiler.Xla.Client.XlaComputation +import Compiler.Xla.TensorFlow.Compiler.Xla.Literal +import Compiler.Xla.TensorFlow.Compiler.Xla.Shape +import Compiler.Xla.Util +import Types +import Util + +public export +data XlaBuilder : Type where + MkXlaBuilder : GCAnyPtr -> XlaBuilder + +public export +data XlaOp : Type where + MkXlaOp : GCAnyPtr -> XlaOp + +namespace XlaBuilder + export + delete : HasIO io => AnyPtr -> io () + delete = primIO . XlaBuilder.prim__delete + +namespace XlaOp + export + delete : HasIO io => AnyPtr -> io () + delete = primIO . XlaOp.prim__delete + +export +mkXlaBuilder : HasIO io => String -> io XlaBuilder +mkXlaBuilder computationName = do + ptr <- primIO (prim__mkXlaBuilder computationName) + ptr <- onCollectAny ptr XlaBuilder.delete + pure (MkXlaBuilder ptr) + +export +createSubBuilder : HasIO io => XlaBuilder -> String -> io XlaBuilder +createSubBuilder (MkXlaBuilder builderPtr) computationName = do + subBuilderPtr <- primIO (prim__createSubBuilder builderPtr computationName) + subBuilderPtr <- onCollectAny subBuilderPtr XlaBuilder.delete + pure (MkXlaBuilder subBuilderPtr) + +export +build : HasIO io => XlaBuilder -> io XlaComputation +build (MkXlaBuilder ptr) = do + let computationPtr = prim__build ptr + computationPtr <- onCollectAny computationPtr XlaComputation.delete + pure (MkXlaComputation computationPtr) + +export +opToString : XlaBuilder -> XlaOp -> String +opToString (MkXlaBuilder builderPtr) (MkXlaOp opPtr) = prim__opToString builderPtr opPtr + +data XlaOpArray : Type where + MkXlaOpArray : GCAnyPtr -> XlaOpArray + +export +mkXlaOpArray : HasIO io => List XlaOp -> io XlaOpArray +mkXlaOpArray ops = do + arr <- malloc (cast (length ops) * sizeOfXlaOp) + traverse_ (\(idx, (MkXlaOp opPtr)) => + primIO $ prim__setArrayXlaOp arr (cast idx) opPtr) (enumerate (fromList ops)) + arr <- onCollectAny arr free + pure (MkXlaOpArray arr) + +export +parameter : HasIO io => XlaBuilder -> Nat -> Xla.Shape -> String -> io XlaOp +parameter (MkXlaBuilder builderPtr) position (MkShape shapePtr) name = do + opPtr <- primIO $ prim__parameter builderPtr (cast position) shapePtr name + opPtr <- onCollectAny opPtr XlaOp.delete + pure (MkXlaOp opPtr) + +export +constantLiteral : HasIO io => XlaBuilder -> Literal -> io XlaOp +constantLiteral (MkXlaBuilder builderPtr) (MkLiteral literalPtr) = do + opPtr <- primIO (prim__constantLiteral builderPtr literalPtr) + opPtr <- onCollectAny opPtr XlaOp.delete + pure (MkXlaOp opPtr) + +export +broadcast : HasIO io => XlaOp -> List Nat -> io XlaOp +broadcast (MkXlaOp opPtr) broadcastSizes = do + MkIntArray broadcastSizesArrayPtr <- mkIntArray broadcastSizes + opPtr <- primIO $ prim__broadcast opPtr broadcastSizesArrayPtr (cast $ length broadcastSizes) + opPtr <- onCollectAny opPtr XlaOp.delete + pure (MkXlaOp opPtr) + +export +broadcastInDim : HasIO io => XlaOp -> List Nat -> List Nat -> io XlaOp +broadcastInDim (MkXlaOp opPtr) outDimSize broadcastDimensions = do + MkIntArray outDimSizeArrayPtr <- mkIntArray outDimSize + MkIntArray broadcastDimensionsArrayPtr <- mkIntArray broadcastDimensions + opPtr <- primIO $ prim__broadcastInDim + opPtr + outDimSizeArrayPtr (cast $ length outDimSize) + broadcastDimensionsArrayPtr (cast $ length broadcastDimensions) + opPtr <- onCollectAny opPtr XlaOp.delete + pure (MkXlaOp opPtr) + +export +reshape : HasIO io => XlaOp -> List Nat -> List Nat -> io XlaOp +reshape (MkXlaOp opPtr) dimensions newSizes = do + MkIntArray dimensionsArrayPtr <- mkIntArray dimensions + MkIntArray newSizesArrayPtr <- mkIntArray newSizes + opPtr <- primIO $ prim__reshape + opPtr + dimensionsArrayPtr (cast $ length dimensions) + newSizesArrayPtr (cast $ length newSizes) + opPtr <- onCollectAny opPtr XlaOp.delete + pure (MkXlaOp opPtr) + +export +slice : HasIO io => XlaOp -> List Nat -> List Nat -> List Nat -> io XlaOp +slice (MkXlaOp opPtr) startIndices limitIndices strides = do + MkIntArray startIndicesArrayPtr <- mkIntArray startIndices + MkIntArray limitIndicesArrayPtr <- mkIntArray limitIndices + MkIntArray stridesArrayPtr <- mkIntArray strides + let rank = cast (length startIndices) + opPtr <- primIO $ prim__slice + opPtr + startIndicesArrayPtr rank + limitIndicesArrayPtr rank + stridesArrayPtr rank + opPtr <- onCollectAny opPtr XlaOp.delete + pure (MkXlaOp opPtr) + +export +concatInDim : + HasIO io => + XlaBuilder -> + (operands : List XlaOp) -> + {auto 0 _ : NonEmpty operands} -> + Nat -> + io XlaOp +concatInDim (MkXlaBuilder builder) operands dimension = do + MkXlaOpArray xlaOpArrayPtr <- mkXlaOpArray operands + opPtr <- primIO $ prim__concatInDim + builder xlaOpArrayPtr (cast $ length operands) (cast dimension) + opPtr <- onCollectAny opPtr XlaOp.delete + pure (MkXlaOp opPtr) + +export +select : HasIO io => XlaOp -> XlaOp -> XlaOp -> io XlaOp +select (MkXlaOp pred) (MkXlaOp onTrue) (MkXlaOp onFalse) = do + opPtr <- primIO $ prim__select pred onTrue onFalse + opPtr <- onCollectAny opPtr XlaOp.delete + pure (MkXlaOp opPtr) + +binaryOp : HasIO io => (GCAnyPtr -> GCAnyPtr -> PrimIO AnyPtr) -> XlaOp -> XlaOp -> io XlaOp +binaryOp prim__f (MkXlaOp x) (MkXlaOp y) = do + opPtr <- primIO $ prim__f x y + opPtr <- onCollectAny opPtr XlaOp.delete + pure (MkXlaOp opPtr) + +export +eq : HasIO io => XlaOp -> XlaOp -> io XlaOp +eq = binaryOp prim__eq + +export +ne : HasIO io => XlaOp -> XlaOp -> io XlaOp +ne = binaryOp prim__ne + +export +ge : HasIO io => XlaOp -> XlaOp -> io XlaOp +ge = binaryOp prim__ge + +export +gt : HasIO io => XlaOp -> XlaOp -> io XlaOp +gt = binaryOp prim__gt + +export +lt : HasIO io => XlaOp -> XlaOp -> io XlaOp +lt = binaryOp prim__lt + +export +le : HasIO io => XlaOp -> XlaOp -> io XlaOp +le = binaryOp prim__le + +export +dot : HasIO io => XlaOp -> XlaOp -> io XlaOp +dot = binaryOp prim__dot + +public export +data Transpose = NoTranspose | Transpose_ | Adjoint + +export +triangularSolve : HasIO io => XlaOp -> XlaOp -> Bool -> Bool -> Bool -> Transpose -> io XlaOp +triangularSolve (MkXlaOp a) (MkXlaOp b) leftSide lower unitDiagonal transposeA = do + let transposeA : Int = case transposeA of + NoTranspose => 1 + Transpose_ => 2 + Adjoint => 3 + opPtr <- primIO $ prim__triangularSolve + a b (boolToCInt leftSide) (boolToCInt lower) (boolToCInt unitDiagonal) transposeA + opPtr <- onCollectAny opPtr XlaOp.delete + pure (MkXlaOp opPtr) + +export +cholesky : HasIO io => XlaOp -> Bool -> io XlaOp +cholesky (MkXlaOp a) lower = do + opPtr <- primIO $ prim__cholesky a (boolToCInt lower) + opPtr <- onCollectAny opPtr XlaOp.delete + pure (MkXlaOp opPtr) + +export +add : HasIO io => XlaOp -> XlaOp -> io XlaOp +add = binaryOp prim__add + +export +sub : HasIO io => XlaOp -> XlaOp -> io XlaOp +sub = binaryOp prim__sub + +export +mul : HasIO io => XlaOp -> XlaOp -> io XlaOp +mul = binaryOp prim__mul + +export +div : HasIO io => XlaOp -> XlaOp -> io XlaOp +div = binaryOp prim__div + +export +max : HasIO io => XlaOp -> XlaOp -> io XlaOp +max = binaryOp prim__max + +export +min : HasIO io => XlaOp -> XlaOp -> io XlaOp +min = binaryOp prim__min + +export +and : HasIO io => XlaOp -> XlaOp -> io XlaOp +and = binaryOp prim__and + +export +or : HasIO io => XlaOp -> XlaOp -> io XlaOp +or = binaryOp prim__or + +export +unaryOp : HasIO io => (GCAnyPtr -> PrimIO AnyPtr) -> XlaOp -> io XlaOp +unaryOp prim__f (MkXlaOp x) = do + opPtr <- primIO $ prim__f x + opPtr <- onCollectAny opPtr XlaOp.delete + pure (MkXlaOp opPtr) + +export +not : HasIO io => XlaOp -> io XlaOp +not = unaryOp prim__not + +export +reduce : HasIO io => XlaOp -> XlaOp -> XlaComputation -> List Nat -> io XlaOp +reduce (MkXlaOp operand) (MkXlaOp initValue) (MkXlaComputation computation) dimensions = do + MkIntArray dimensionsIntArrayPtr <- mkIntArray dimensions + opPtr <- primIO $ prim__reduce + operand initValue computation dimensionsIntArrayPtr (cast $ length dimensions) + opPtr <- onCollectAny opPtr XlaOp.delete + pure (MkXlaOp opPtr) + +export +abs : HasIO io => XlaOp -> io XlaOp +abs = unaryOp prim__abs + +export +exp : HasIO io => XlaOp -> io XlaOp +exp = unaryOp prim__exp + +export +floor : HasIO io => XlaOp -> io XlaOp +floor = unaryOp prim__floor + +export +ceil : HasIO io => XlaOp -> io XlaOp +ceil = unaryOp prim__ceil + +export +log : HasIO io => XlaOp -> io XlaOp +log = unaryOp prim__log + +export +logistic : HasIO io => XlaOp -> io XlaOp +logistic = unaryOp prim__logistic + +export +cos : HasIO io => XlaOp -> io XlaOp +cos = unaryOp prim__cos + +export +sin : HasIO io => XlaOp -> io XlaOp +sin = unaryOp prim__sin + +export +tanh : HasIO io => XlaOp -> io XlaOp +tanh = unaryOp prim__tanh + +export +sqrt : HasIO io => XlaOp -> io XlaOp +sqrt = unaryOp prim__sqrt + +export +pow : HasIO io => XlaOp -> XlaOp -> io XlaOp +pow = binaryOp prim__pow + +export +neg : HasIO io => XlaOp -> io XlaOp +neg = unaryOp prim__neg + +export +transpose : HasIO io => XlaOp -> List Nat -> io XlaOp +transpose (MkXlaOp operand) permutation = do + MkIntArray permutationIntArrayPtr <- mkIntArray permutation + opPtr <- primIO $ prim__transpose operand permutationIntArrayPtr (cast $ length permutation) + opPtr <- onCollectAny opPtr XlaOp.delete + pure (MkXlaOp opPtr) + +export +map : HasIO io => XlaBuilder -> List XlaOp -> XlaComputation -> List Nat -> io XlaOp +map (MkXlaBuilder builder) operands (MkXlaComputation computation) dimensions = do + MkXlaOpArray operandsXlaOpArrayPtr <- mkXlaOpArray operands + MkIntArray dimensionsIntArrayPtr <- mkIntArray dimensions + opPtr <- primIO $ prim__map + builder + operandsXlaOpArrayPtr (cast $ length operands) + computation + dimensionsIntArrayPtr (cast $ length dimensions) + prim__getNullAnyPtr 0 + opPtr <- onCollectAny opPtr XlaOp.delete + pure (MkXlaOp opPtr) + +export +conditional : HasIO io => XlaOp -> XlaOp -> XlaComputation -> XlaOp -> XlaComputation -> io XlaOp +conditional + (MkXlaOp pred) + (MkXlaOp trueOperand) + (MkXlaComputation trueComputation) + (MkXlaOp falseOperand) + (MkXlaComputation falseComputation) = do + opPtr <- primIO $ prim__conditional + pred + trueOperand + trueComputation + falseOperand + falseComputation + opPtr <- onCollectAny opPtr XlaOp.delete + pure (MkXlaOp opPtr) diff --git a/src/Compiler/Xla/TensorFlow/Compiler/Xla/Client/XlaComputation.idr b/src/Compiler/Xla/TensorFlow/Compiler/Xla/Client/XlaComputation.idr new file mode 100644 index 000000000..b0d54a23f --- /dev/null +++ b/src/Compiler/Xla/TensorFlow/Compiler/Xla/Client/XlaComputation.idr @@ -0,0 +1,26 @@ +{-- +Copyright 2022 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +module Compiler.Xla.TensorFlow.Compiler.Xla.Client.XlaComputation + +import Compiler.Xla.Prim.TensorFlow.Compiler.Xla.Client.XlaComputation + +public export +data XlaComputation : Type where + MkXlaComputation : GCAnyPtr -> XlaComputation + +export +delete : AnyPtr -> IO () +delete = primIO . prim__delete diff --git a/src/Compiler/Xla/TensorFlow/Compiler/Xla/Literal.idr b/src/Compiler/Xla/TensorFlow/Compiler/Xla/Literal.idr new file mode 100644 index 000000000..2bd44b3e3 --- /dev/null +++ b/src/Compiler/Xla/TensorFlow/Compiler/Xla/Literal.idr @@ -0,0 +1,79 @@ +{-- +Copyright 2022 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +module Compiler.Xla.TensorFlow.Compiler.Xla.Literal + +import Compiler.Xla.Prim.TensorFlow.Compiler.Xla.Literal +import Compiler.Xla.TensorFlow.Compiler.Xla.Shape +import Compiler.Xla.TensorFlow.Compiler.Xla.ShapeUtil +import Compiler.Xla.TensorFlow.Compiler.Xla.XlaData +import Compiler.Xla.Util +import Types + +namespace Xla + public export + data Literal : Type where + MkLiteral : GCAnyPtr -> Literal + +export +delete : AnyPtr -> IO () +delete = primIO . prim__delete + +export +allocLiteral : HasIO io => Primitive dtype => Types.Shape -> io Literal +allocLiteral shape = do + MkShape shapePtr <- mkShape {dtype} shape + litPtr <- primIO $ prim__allocLiteral shapePtr + litPtr <- onCollectAny litPtr Literal.delete + pure (MkLiteral litPtr) + +namespace Bool + export + set : Literal -> List Nat -> Bool -> IO () + set (MkLiteral lit) idxs value = do + MkIntArray idxsArrayPtr <- mkIntArray idxs + primIO $ prim__literalSetBool lit idxsArrayPtr (if value then 1 else 0) + + export + get : Literal -> List Nat -> Bool + get (MkLiteral lit) idxs = unsafePerformIO $ do + MkIntArray idxsArrayPtr <- mkIntArray idxs + pure $ cIntToBool $ literalGetBool lit idxsArrayPtr + +namespace Double + export + set : Literal -> List Nat -> Double -> IO () + set (MkLiteral lit) idxs value = do + MkIntArray idxsArrayPtr <- mkIntArray idxs + primIO $ prim__literalSetDouble lit idxsArrayPtr value + + export + get : Literal -> List Nat -> Double + get (MkLiteral lit) idxs = unsafePerformIO $ do + MkIntArray idxsArrayPtr <- mkIntArray idxs + pure $ literalGetDouble lit idxsArrayPtr + +namespace Int + export + set : Literal -> List Nat -> Int -> IO () + set (MkLiteral lit) idxs value = do + MkIntArray idxsArrayPtr <- mkIntArray idxs + primIO $ prim__literalSetInt lit idxsArrayPtr value + + export + get : Literal -> List Nat -> Int + get (MkLiteral lit) idxs = unsafePerformIO $ do + MkIntArray idxsArrayPtr <- mkIntArray idxs + pure $ literalGetInt lit idxsArrayPtr diff --git a/src/Compiler/Xla/TensorFlow/Compiler/Xla/Service/PlatformUtil.idr b/src/Compiler/Xla/TensorFlow/Compiler/Xla/Service/PlatformUtil.idr new file mode 100644 index 000000000..626f12149 --- /dev/null +++ b/src/Compiler/Xla/TensorFlow/Compiler/Xla/Service/PlatformUtil.idr @@ -0,0 +1,25 @@ +{-- +Copyright 2022 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +module Compiler.Xla.TensorFlow.Compiler.Xla.Service.PlatformUtil + +import Compiler.Xla.Prim.TensorFlow.Compiler.Xla.Service.PlatformUtil +import Compiler.Xla.TensorFlow.StreamExecutor.Platform + +export +getPlatform : String -> IO Platform +getPlatform platformName = do + platform <- primIO $ prim__getPlatform platformName + pure (MkPlatform platform) diff --git a/src/Compiler/Xla/TensorFlow/Compiler/Xla/Shape.idr b/src/Compiler/Xla/TensorFlow/Compiler/Xla/Shape.idr new file mode 100644 index 000000000..0a558e988 --- /dev/null +++ b/src/Compiler/Xla/TensorFlow/Compiler/Xla/Shape.idr @@ -0,0 +1,27 @@ +{-- +Copyright 2022 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +module Compiler.Xla.TensorFlow.Compiler.Xla.Shape + +import Compiler.Xla.Prim.TensorFlow.Compiler.Xla.Shape + +namespace Xla + public export + data Shape : Type where + MkShape : GCAnyPtr -> Shape + +export +delete : AnyPtr -> IO () +delete = primIO . prim__delete diff --git a/src/Compiler/TensorFlow/Compiler/XLA/ShapeUtil.idr b/src/Compiler/Xla/TensorFlow/Compiler/Xla/ShapeUtil.idr similarity index 55% rename from src/Compiler/TensorFlow/Compiler/XLA/ShapeUtil.idr rename to src/Compiler/Xla/TensorFlow/Compiler/Xla/ShapeUtil.idr index 9792814fb..3ab047ea7 100644 --- a/src/Compiler/TensorFlow/Compiler/XLA/ShapeUtil.idr +++ b/src/Compiler/Xla/TensorFlow/Compiler/Xla/ShapeUtil.idr @@ -13,21 +13,19 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. --} -module Compiler.TensorFlow.Compiler.XLA.ShapeUtil +module Compiler.Xla.TensorFlow.Compiler.Xla.ShapeUtil -import System.FFI - -import Compiler.FFI -import Compiler.TensorFlow.Compiler.XLA.Shape -import Compiler.TensorFlow.Compiler.XLA.XlaData +import Compiler.Xla.Prim.TensorFlow.Compiler.Xla.ShapeUtil +import Compiler.Xla.TensorFlow.Compiler.Xla.Shape +import Compiler.Xla.TensorFlow.Compiler.Xla.XlaData +import Compiler.Xla.Util import Types -%foreign (libxla "MakeShape") -prim__mkShape : Int -> GCPtr Int -> Int -> PrimIO AnyPtr - export -mkShape : HasIO io => Primitive dtype => Shape -> io GCAnyPtr +mkShape : (HasIO io, Primitive dtype) => Types.Shape -> io Xla.Shape mkShape shape = do let dtypeEnum = xlaIdentifier {dtype} - shapePtr <- primIO $ prim__mkShape dtypeEnum !(mkIntArray shape) (cast (length shape)) - onCollectAny shapePtr Shape.delete + MkIntArray shapeArrayPtr <- mkIntArray shape + shapePtr <- primIO $ prim__mkShape dtypeEnum shapeArrayPtr (cast $ length shape) + shapePtr <- onCollectAny shapePtr Shape.delete + pure (MkShape shapePtr) diff --git a/src/Compiler/TensorFlow/Compiler/XLA/XlaData.idr b/src/Compiler/Xla/TensorFlow/Compiler/Xla/XlaData.idr similarity index 77% rename from src/Compiler/TensorFlow/Compiler/XLA/XlaData.idr rename to src/Compiler/Xla/TensorFlow/Compiler/Xla/XlaData.idr index 16e3e97a1..27ce038da 100644 --- a/src/Compiler/TensorFlow/Compiler/XLA/XlaData.idr +++ b/src/Compiler/Xla/TensorFlow/Compiler/Xla/XlaData.idr @@ -13,12 +13,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. --} -||| This module contains Idris types that represent primitive types supported by the XLA backend. -||| These Idris types have no values (at least not Idris values). Instead they carry metadata that -||| allows us to manage memory layouts in the backend. -module Compiler.TensorFlow.Compiler.XLA.XlaData +module Compiler.Xla.TensorFlow.Compiler.Xla.XlaData -||| A `Primitive` is an Idris representation of a primitive type supported by the XLA backend. export interface Primitive dtype where xlaIdentifier : Int diff --git a/src/Compiler/Xla/TensorFlow/Core/CommonRuntime/GPU/GPUInit.idr b/src/Compiler/Xla/TensorFlow/Core/CommonRuntime/GPU/GPUInit.idr new file mode 100644 index 000000000..4e4769fad --- /dev/null +++ b/src/Compiler/Xla/TensorFlow/Core/CommonRuntime/GPU/GPUInit.idr @@ -0,0 +1,33 @@ +{-- +Copyright 2022 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +module Compiler.Xla.TensorFlow.Core.CommonRuntime.GPU.GPUInit + +import Compiler.Xla.Prim.TensorFlow.Core.CommonRuntime.GPU.GPUInit +import Compiler.Xla.TensorFlow.Core.Platform.Status +import Compiler.Xla.TensorFlow.StreamExecutor.Platform + +export +validateGPUMachineManager : IO Status +validateGPUMachineManager = do + status <- primIO prim__validateGPUMachineManager + status <- onCollectAny status Status.delete + pure (MkStatus status) + +export +gpuMachineManager : IO Platform +gpuMachineManager = do + platform <- primIO prim__gpuMachineManager + pure (MkPlatform platform) diff --git a/src/Compiler/Xla/TensorFlow/Core/Platform/Status.idr b/src/Compiler/Xla/TensorFlow/Core/Platform/Status.idr new file mode 100644 index 000000000..f727e73dc --- /dev/null +++ b/src/Compiler/Xla/TensorFlow/Core/Platform/Status.idr @@ -0,0 +1,31 @@ +{-- +Copyright 2022 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +module Compiler.Xla.TensorFlow.Core.Platform.Status + +import Compiler.Xla.Prim.TensorFlow.Core.Platform.Status +import Compiler.Xla.Util + +public export +data Status : Type where + MkStatus : GCAnyPtr -> Status + +export +delete : AnyPtr -> IO () +delete = primIO . prim__delete + +export +ok : Status -> Bool +ok (MkStatus ptr) = cIntToBool (prim__ok ptr) diff --git a/src/Compiler/Xla/TensorFlow/StreamExecutor/Platform.idr b/src/Compiler/Xla/TensorFlow/StreamExecutor/Platform.idr new file mode 100644 index 000000000..b9697a761 --- /dev/null +++ b/src/Compiler/Xla/TensorFlow/StreamExecutor/Platform.idr @@ -0,0 +1,20 @@ +{-- +Copyright 2022 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +module Compiler.Xla.TensorFlow.StreamExecutor.Platform + +public export +data Platform : Type where + MkPlatform : AnyPtr -> Platform diff --git a/src/Compiler/FFI.idr b/src/Compiler/Xla/Util.idr similarity index 71% rename from src/Compiler/FFI.idr rename to src/Compiler/Xla/Util.idr index ec6e5db20..456e6ddea 100644 --- a/src/Compiler/FFI.idr +++ b/src/Compiler/Xla/Util.idr @@ -13,22 +13,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. --} -module Compiler.FFI +module Compiler.Xla.Util -import Data.Vect import System.FFI -import Types +import Compiler.Xla.Prim.Util import Util -export -free : Ptr t -> IO () -free = System.FFI.free . prim__forgetPtr - -public export -libxla : String -> String -libxla fname = "C:" ++ fname ++ ",libc_xla_extension" - export cIntToBool : Int -> Bool cIntToBool 0 = False @@ -37,16 +28,20 @@ cIntToBool x = let msg = "Internal error: expected 0 or 1 from XLA C API for boolean conversion, got " ++ show x in (assert_total idris_crash) msg -%foreign (libxla "sizeof_int") -sizeofInt : Int +export +boolToCInt : Bool -> Int +boolToCInt True = 1 +boolToCInt False = 0 -%foreign (libxla "set_array_int") -prim__setArrayInt : Ptr Int -> Int -> Int -> PrimIO () +public export +data IntArray : Type where + MkIntArray : GCPtr Int -> IntArray export -mkIntArray : HasIO io => Cast ty Int => List ty -> io (GCPtr Int) +mkIntArray : (HasIO io, Cast a Int) => List a -> io IntArray mkIntArray xs = do ptr <- malloc (cast (length xs) * sizeofInt) let ptr = prim__castPtr ptr traverse_ (\(idx, x) => primIO $ prim__setArrayInt ptr (cast idx) (cast x)) (enumerate xs) - onCollect ptr free + ptr <- onCollect ptr (free . prim__forgetPtr) + pure (MkIntArray ptr) diff --git a/src/Primitive.idr b/src/Primitive.idr index 120bc9eb6..e90f8c877 100644 --- a/src/Primitive.idr +++ b/src/Primitive.idr @@ -23,8 +23,8 @@ module Primitive import Data.Hashable -import Compiler.TensorFlow.Compiler.XLA.Literal -import public Compiler.TensorFlow.Compiler.XLA.XlaData +import Compiler.LiteralRW +import public Compiler.Xla.TensorFlow.Compiler.Xla.XlaData import public Util.Hashable %hide Prelude.Num @@ -81,7 +81,7 @@ export Ord F64 where ||| A `PrimitiveRW dtype idr` means that values of type `idr` can be used to construct backend ||| data with data type `dtype`. export -interface Hashable idr => LiteralPrimitiveRW dtype idr => PrimitiveRW dtype idr | dtype where +interface Hashable idr => LiteralRW dtype idr => PrimitiveRW dtype idr | dtype where export PrimitiveRW PRED Bool where export PrimitiveRW S32 Int where diff --git a/src/Tensor.idr b/src/Tensor.idr index 3415070e8..d420a8019 100644 --- a/src/Tensor.idr +++ b/src/Tensor.idr @@ -17,31 +17,33 @@ limitations under the License. ||| number of functions operating on numeric `Tensor`s. module Tensor -import Data.SortedMap import Control.Monad.State -import Data.Hashable import public Data.List import public Data.List.Elem import Decidable.Equality import System.FFI +import Data.Hashable + +import Compiler.Computation +import Compiler.Graph +import Compiler.LiteralRW +import Compiler.Xla.TensorFlow.Compiler.Xla.Client.Lib.Math +import Compiler.Xla.TensorFlow.Compiler.Xla.Client.Lib.Matrix +import Compiler.Xla.TensorFlow.Compiler.Xla.Client.ClientLibrary +import Compiler.Xla.TensorFlow.Compiler.Xla.Client.LocalClient +import Compiler.Xla.TensorFlow.Compiler.Xla.Client.XlaBuilder +import Compiler.Xla.TensorFlow.Compiler.Xla.Client.XlaComputation +import Compiler.Xla.TensorFlow.Compiler.Xla.Literal +import Compiler.Xla.TensorFlow.Compiler.Xla.Service.PlatformUtil +import Compiler.Xla.TensorFlow.Compiler.Xla.ShapeUtil +import Compiler.Xla.TensorFlow.Core.CommonRuntime.GPU.GPUInit +import Compiler.Xla.TensorFlow.Core.Platform.Status +import Compiler.Xla.TensorFlow.StreamExecutor.Platform import Literal import public Primitive import public Types import public Util -import Compiler.XLA -import Compiler.FFI -import Compiler.Graph -import Compiler.TensorFlow.Compiler.XLA.Client.Lib.Math -import Compiler.TensorFlow.Compiler.XLA.Client.Lib.Matrix -import Compiler.TensorFlow.Compiler.XLA.Client.ClientLibrary -import Compiler.TensorFlow.Compiler.XLA.Client.LocalClient -import Compiler.TensorFlow.Compiler.XLA.Client.XlaBuilder -import Compiler.TensorFlow.Compiler.XLA.Service.PlatformUtil -import Compiler.TensorFlow.Core.CommonRuntime.GPU.GPUInit -import Compiler.TensorFlow.Core.Platform.Status -import public Compiler.TensorFlow.Compiler.XLA.Literal -import Compiler.TensorFlow.Compiler.XLA.ShapeUtil ----------------------------- core definitions ---------------------------- @@ -53,7 +55,7 @@ import Compiler.TensorFlow.Compiler.XLA.ShapeUtil ||| @dtype The element type. export data Tensor : (0 shape : Shape) -> (0 dtype : Type) -> Type where - MkTensor : {shape : _} -> Graph -> ComputationComponent -> Tensor shape dtype + MkTensor : {shape : _} -> Graph -> Computation XlaOp -> Tensor shape dtype ||| Construct a `Tensor` from `Literal` data. export @@ -61,8 +63,9 @@ fromLiteral : PrimitiveRW dtype a => {shape : _} -> Literal shape a -> Tensor sh fromLiteral xs = let graph = FromLiteral {dtype} shape (hashWithSalt defaultSalt xs) in MkTensor graph $ cached graph $ do - lit <- mkLiteral {dtype} xs - prim__constantLiteral lit graph + MkCachingBuilder builder _ <- get + literal <- write {dtype} xs + constantLiteral builder literal namespace F64 export @@ -89,32 +92,24 @@ namespace S32 export toLiteral : PrimitiveRW dtype ty => Tensor shape dtype -> Literal shape ty toLiteral (MkTensor {shape} _ xs) = unsafePerformIO $ do - gpuStatus <- primIO prim__validateGPUMachineManager - gpuStatus <- onCollectAny gpuStatus Status.delete - platform <- - if prim__ok gpuStatus - then primIO prim__gpuMachineManager - else primIO (prim__getPlatform "Host") - + gpuStatus <- validateGPUMachineManager + platform <- if ok gpuStatus then gpuMachineManager else getPlatform "Host" computation <- build "" xs - client <- primIO $ prim__getOrCreateLocalClient platform prim__getNullAnyPtr 0 - lit <- prim__executeAndTransfer client computation prim__getNullAnyPtr 0 - pure (toLiteral {dtype} lit) + client <- getOrCreateLocalClient platform + lit <- executeAndTransfer client computation + pure (read {dtype} lit) -||| A string representation of an unevaluated `Tensor`, detailing all enqueued XLA operations. +||| A string representation of an unevaluated `Tensor`, detailing all enqueued Xla operations. ||| Useful for debugging. export Show (Tensor shape dtype) where - show (MkTensor _ xs) = unsafePerformIO (prim__opToString xs) + show (MkTensor _ xs) = opToString xs ----------------------------- structural operations ---------------------------- -reshapeImpl : (from, to : Shape) -> ComputationComponent -> ComputationComponent -reshapeImpl from to xs = do - dimOrder <- mkIntArray (range (length from)) - cto <- mkIntArray to - reshaped <- primIO $ prim__reshape !xs dimOrder (cast (length from)) cto (cast (length to)) - onCollectAny reshaped XlaOp.delete +reshapeWithDefaultOrdering : + (from, to : Shape) -> Computation XlaOp -> Computation XlaOp +reshapeWithDefaultOrdering from to xs = reshape !xs (range $ length from) to ||| Reshape a `Tensor`. For example, `reshape {to=[2, 1]} (fromLiteral [3, 4])` is ||| `fromLiteral [[3], [4]]`. The output can have a different rank to the input. @@ -123,7 +118,7 @@ reshape : Primitive dtype => {to : _} -> product from = product to => Tensor from dtype -> Tensor to dtype reshape (MkTensor {shape=from} graph xs) = let graph = Reshape to graph - in MkTensor graph $ cached graph $ reshapeImpl from to xs + in MkTensor graph $ cached graph $ reshapeWithDefaultOrdering from to xs ||| Add a dimension of length one at the specified `axis`. The new dimension will be at the ||| specified `axis` in the new `Tensor` (as opposed to the original `Tensor`). For example, @@ -135,7 +130,7 @@ expand : Primitive dtype => (axis : Nat) -> axis `LTE` length shape => Tensor sh expand axis (MkTensor {shape} graph xs) = let to = insertAt axis 1 shape graph = Reshape to graph - in MkTensor graph $ cached graph $ reshapeImpl shape to xs + in MkTensor graph $ cached graph $ reshapeWithDefaultOrdering shape to xs namespace Squeezable ||| A `Squeezable from to` constitutes proof that the shape `from` can be squeezed to the @@ -177,9 +172,9 @@ namespace Squeezable ||| ``` export squeeze : Primitive dtype => {to : _} -> Squeezable from to => Tensor from dtype -> Tensor to dtype -squeeze (MkTensor graph xs) = +squeeze (MkTensor {shape=from} graph xs) = let graph = Reshape to graph - in MkTensor graph $ cached graph $ reshapeImpl from to xs + in MkTensor graph $ cached graph $ reshapeWithDefaultOrdering from to xs ||| Take a slice from a single `Tensor` axis. For example, for ||| ``` @@ -223,13 +218,11 @@ slice : (axis, from, to : Nat) -> from `LTE` to => InBounds axis shape => Tensor shape dtype -> Tensor (replaceAt axis (to `minus` from) shape) dtype slice axis from to (MkTensor graph xs) = let graph = Slice axis from to graph - in MkTensor graph $ cached graph $ do - let rank = length shape - start <- mkIntArray (replicate axis 0 ++ [from] ++ replicate (rank `minus` axis) 0) - stop <- mkIntArray (replaceAt axis to shape) - strides <- mkIntArray (the (List Int) $ replicate rank 1) - sliced <- primIO $ prim__slice !xs start (cast rank) stop (cast rank) strides (cast rank) - onCollectAny sliced XlaOp.delete + rank = length shape + start = replicate axis 0 ++ [from] ++ replicate (rank `minus` (S axis)) 0 + stop = replaceAt axis to shape + strides = replicate rank 1 + in MkTensor graph $ cached graph $ do slice !xs start stop strides ||| Get the `idx`-th element from the specified `axis` of a tensor. For example, ||| `index 0 1 $ fromLiteral [[1, 2], [3, 4], [5, 6]]` is `fromLiteral [3, 4]`, and @@ -246,7 +239,7 @@ index axis idx xs with (xs) slice @{lteSuccRight (reflexive {ty=Nat})} axis idx (S idx) xs to = deleteAt axis shape graph = Reshape to graph - in MkTensor graph $ cached graph $ reshapeImpl shape to sliced + in MkTensor graph $ cached graph $ reshapeWithDefaultOrdering shape to sliced ||| Split a `Tensor` along a given axis at the specified index. For example, ||| `split 0 2 fromLiteral [[1, 2], [3, 4], [5, 6]]` is @@ -288,10 +281,8 @@ concat : Primitive dtype => (axis : Nat) -> Tensor s dtype -> Tensor s' dtype concat axis (MkTensor graphL l) (MkTensor graphR r) = let graph = Concat axis graphL graphR in MkTensor graph $ cached graph $ do - operands <- mkXlaOpArray [!l, !r] - MkXlaBuilder ptr _ <- get - res <- primIO $ prim__concatInDim ptr operands 2 (cast axis) - onCollectAny res XlaOp.delete + MkCachingBuilder builder _ <- get + concatInDim builder [!l, !r] (cast axis) ||| The diagonal of a matrix as a vector. For example, for ||| ``` @@ -305,9 +296,7 @@ export diag : Primitive dtype => Tensor [n, n] dtype -> Tensor [n] dtype diag (MkTensor graph xs) = let graph = Diag graph - in MkTensor graph $ cached graph $ do - xs <- primIO (prim__getMatrixDiagonal !xs) - onCollectAny xs XlaOp.delete + in MkTensor graph $ cached graph $ do getMatrixDiagonal !xs ||| Represents the upper- or lower-trinagular component of a matrix. public export @@ -331,9 +320,7 @@ export triangle : Primitive dtype => Triangle -> Tensor [n, n] dtype -> Tensor [n, n] dtype triangle tri (MkTensor graph xs) = let graph = Triangle (case tri of Upper => False; Lower => True) graph - in MkTensor graph $ cached graph $ do - op <- primIO $ prim__triangle !xs (case tri of Upper => 0; Lower => 1) - onCollectAny op XlaOp.delete + in MkTensor graph $ cached graph $ do triangle !xs (case tri of Upper => False; Lower => True) ||| Tranpose a matrix. For example, `(fromLiteral [[1, 2], [3, 4]]).T` is ||| `fromLiteral [[1, 3], [2, 4]]`. @@ -341,10 +328,7 @@ export (.T) : Primitive dtype => Tensor [m, n] dtype -> Tensor [n, m] dtype (MkTensor graph xs).T = let graph = Transpose graph - in MkTensor graph $ cached graph $ do - permutations <- mkIntArray $ the (List Int) $ [1, 0] - op <- primIO $ prim__transpose !xs permutations 2 - onCollectAny op XlaOp.delete + in MkTensor graph $ cached graph $ do transpose !xs [1, 0] ||| The identity tensor, with inferred shape and element type. For example, ||| ``` @@ -363,9 +347,8 @@ identity = let graph = Identity {dtype} n n = cast n in MkTensor graph $ cached graph $ do - MkXlaBuilder ptr _ <- get - op <- primIO $ prim__identityMatrix ptr (xlaIdentifier {dtype}) n n - onCollectAny op XlaOp.delete + MkCachingBuilder builder _ <- get + identityMatrix {dtype} builder n n ||| A `DimBroadcastable from to` proves that a dimension of size `from` can be broadcast to a ||| dimension of size `to`. @@ -435,37 +418,24 @@ broadcast xs with (xs) let graph = Broadcast to graph in case (isElem 0 to, from == to) of (Yes _, False) => MkTensor graph $ cached graph $ do - xlaShape <- mkShape {dtype} to - literal <- primIO $ prim__allocLiteral xlaShape - literal <- onCollectAny literal Literal.delete - prim__constantLiteral literal graph + MkCachingBuilder builder _ <- get + literal <- allocLiteral {dtype} to + constantLiteral builder literal _ => impl [] to xs where - broadcast : List Nat -> ComputationComponent -> ComputationComponent - broadcast broadcastSizes xs = do - broadcastSizesPtr <- mkIntArray broadcastSizes - op <- primIO (prim__broadcast !xs broadcastSizesPtr (cast $ length broadcastSizes)) - onCollectAny op XlaOp.delete - - broadcastInDim : Shape -> Shape -> ComputationComponent -> ComputationComponent - broadcastInDim ods bcd xs = do - odsPtr <- mkIntArray ods - bcdPtr <- mkIntArray bcd - let len = cast (length ods) - op <- primIO (prim__broadcastInDim !xs odsPtr len bcdPtr len) - onCollectAny op XlaOp.delete impl : {from, to : _} -> (toLeading, toTrailing : List Nat) -> {auto prf : Broadcastable from toTrailing} -> Tensor from dtype -> Tensor to dtype impl toLeading _ {prf=Same} (MkTensor _ mkOp) = let graph = Broadcast to graph - in MkTensor graph $ cached graph $ - if (length toLeading == 0) then mkOp else broadcast toLeading mkOp + in MkTensor graph $ cached graph $ + if (length toLeading == 0) then mkOp else do broadcast !mkOp toLeading impl toLeading (th' :: tt') {prf=Match _} (MkTensor _ mkOp) = let graph = Broadcast to graph - in MkTensor graph $ cached graph $ - broadcast toLeading (broadcastInDim (th' :: tt') (range (length from)) mkOp) + in MkTensor graph $ cached graph $ do + x <- broadcastInDim !mkOp (th' :: tt') (range (length from)) + broadcast x toLeading impl toLeading (th' :: tt') {prf=Nest _} xs = impl (toLeading ++ [th']) tt' xs %hint @@ -502,26 +472,13 @@ fill = broadcast {prf=scalarToAnyOk shape} . fromLiteral . Scalar export map : (Primitive a, Primitive b) => (Tensor [] a -> Tensor [] b) -> Tensor shape a -> Tensor shape b map f (MkTensor graph xs) = - let graph0 = Parameter {dtype=a} [] 0 - p0 = cached graph0 $ prim__parameter 0 [] "" {dtype=a} + let (graph0, p0) = parameter 0 [] "" {dtype=a} MkTensor graphf res = f (MkTensor graph0 p0) graph = Map graphf [graph] in MkTensor graph $ cached graph $ do computation <- buildWithSubBuilder "computation" [p0] res - - operands <- mkXlaOpArray [!xs] - let rank = length shape - dimensions <- mkIntArray (range rank) - MkXlaBuilder ptr _ <- get - - res <- primIO (prim__map - ptr - operands 1 - computation - dimensions (cast rank) - prim__getNullAnyPtr 0 - ) - onCollectAny res XlaOp.delete + MkCachingBuilder builder _ <- get + map builder [!xs] computation (range $ length shape) ||| Lift a binary function on scalars to an element-wise function on `Tensor`s of arbitrary shape. ||| For example, @@ -536,28 +493,14 @@ export map2 : (Primitive a, Primitive b, Primitive c) => (Tensor [] a -> Tensor [] b -> Tensor [] c) -> Tensor shape a -> Tensor shape b -> Tensor shape c map2 f (MkTensor graphL l) (MkTensor graphR r) = - let graph0 = Parameter {dtype=a} [] 0 - graph1 = Parameter {dtype=b} [] 1 - p0 = cached graph0 $ prim__parameter 0 [] "" {dtype=a} - p1 = cached graph1 $ prim__parameter 1 [] "" {dtype=b} + let (graph0, p0) = parameter 0 [] "" {dtype=a} + (graph1, p1) = parameter 1 [] "" {dtype=b} MkTensor graphf res = f (MkTensor graph0 p0) (MkTensor graph1 p1) graph = Map graphf [graphL, graphR] in MkTensor graph $ cached graph $ do computation <- buildWithSubBuilder "computation" [p0, p1] res - - operands <- mkXlaOpArray [!l, !r] - let rank = length shape - dimensions <- mkIntArray (range rank) - MkXlaBuilder ptr _ <- get - - res <- primIO (prim__map - ptr - operands 2 - computation - dimensions (cast rank) - prim__getNullAnyPtr 0 - ) - onCollectAny res XlaOp.delete + MkCachingBuilder builder _ <- get + map builder [!l, !r] computation (range $ length shape) ||| Reduce elements along one `axis` of a `Tensor` according to a specified `reducer` `Monoid`. ||| For example, if `x = fromLiteral [[0, 1, 2], [3, 4, 5]]`, then reduce @{Sum} 0 x` is @@ -572,77 +515,75 @@ reduce axis (MkTensor graph xs) = let semigroup : Monoid a -> Semigroup a semigroup _ = %search - in let graph0 = Parameter {dtype} [] 0 - graph1 = Parameter {dtype} [] 1 - p0 = cached graph0 $ prim__parameter 0 [] "" {dtype} - p1 = cached graph1 $ prim__parameter 1 [] "" {dtype} - MkTensor graphf resf = (<+>) @{semigroup reducer} (MkTensor graph0 p0) (MkTensor graph1 p1) + in let (graph0, p0) = parameter 0 [] "" {dtype} + (graph1, p1) = parameter 1 [] "" {dtype} + MkTensor graphf resf = + (<+>) @{semigroup reducer} (MkTensor graph0 p0) (MkTensor graph1 p1) graph = Reduce graphf axis graph in MkTensor graph $ cached graph $ do computation <- buildWithSubBuilder "computation" [p0, p1] resf let MkTensor _ init = neutral @{reducer} - op <- primIO $ prim__reduce !xs !init computation !(mkIntArray [axis]) 1 - onCollectAny op XlaOp.delete + reduce !xs !init computation [axis] ----------------------------- numeric operations ---------------------------- -unaryOp : String -> (GCAnyPtr -> PrimIO AnyPtr) -> Tensor shape a -> Tensor shape b -unaryOp fnName primOperator (MkTensor graph xs) = +unaryOp : + Primitive b => String -> (XlaOp -> Computation XlaOp) -> Tensor shape a -> Tensor shape b +unaryOp fnName xlaOperation (MkTensor graph xs) = let graph = ElementwiseUnary fnName graph - in MkTensor graph $ cached graph $ do - op <- primIO (primOperator !xs) - onCollectAny op XlaOp.delete - -binaryOp : String -> (GCAnyPtr -> GCAnyPtr -> PrimIO AnyPtr) - -> Tensor shape a -> Tensor shape b -> Tensor shape c -binaryOp fnName primOperator (MkTensor graphL l) (MkTensor graphR r) = + in MkTensor graph $ cached graph $ do xlaOperation !xs + +binaryOp : + Primitive c => + String -> + (XlaOp -> XlaOp -> Computation XlaOp) -> + Tensor shape a -> Tensor shape b -> Tensor shape c +binaryOp fnName xlaOperation (MkTensor graphL l) (MkTensor graphR r) = let graph = ElementwiseBinary fnName graphL graphR - in MkTensor graph $ cached graph $ do - op <- primIO (primOperator !l !r) - onCollectAny op XlaOp.delete + in MkTensor graph $ cached graph $ do xlaOperation !l !r ||| Element-wise equality. For example, `fromLiteral [1, 2] == fromLiteral [1, 3]` is ||| `fromLiteral [True, False]`. export (==) : Primitive.Eq dtype => Tensor shape dtype -> Tensor shape dtype -> Tensor shape PRED -(==) = binaryOp "(==)" prim__eq +(==) = binaryOp "(==)" eq ||| Element-wise inequality. For example, `fromLiteral [1, 2] /= fromLiteral [1, 3]` is ||| `fromLiteral [False, True]`. export (/=) : Primitive.Eq dtype => Tensor shape dtype -> Tensor shape dtype -> Tensor shape PRED -(/=) = binaryOp "(/=)" prim__ne +(/=) = binaryOp "(/=)" ne ||| Element-wise less than. For example, `fromLiteral [1, 2, 3] < fromLiteral [2, 2, 2]` is ||| `fromLiteral [True, False, False]`. export (<) : Primitive.Ord dtype => Tensor shape dtype -> Tensor shape dtype -> Tensor shape PRED -(<) = binaryOp "(<)" prim__lt +(<) = binaryOp "(<)" lt ||| Element-wise greater than. For example, `fromLiteral [1, 2, 3] > fromLiteral [2, 2, 2]` is ||| `fromLiteral [False, False, True]`. export (>) : Primitive.Ord dtype => Tensor shape dtype -> Tensor shape dtype -> Tensor shape PRED -(>) = binaryOp "(>)" prim__gt +(>) = binaryOp "(>)" gt ||| Element-wise less than or equal. For example, `fromLiteral [1, 2, 3] <= fromLiteral [2, 2, 2]` ||| is `fromLiteral [True, True, False]`. export (<=) : Primitive.Ord dtype => Tensor shape dtype -> Tensor shape dtype -> Tensor shape PRED -(<=) = binaryOp "(<=)" prim__le +(<=) = binaryOp "(<=)" le ||| Element-wise greater than or equal. For example, ||| `fromLiteral [1, 2, 3] >= fromLiteral [2, 2, 2]` is `fromLiteral [False, True, True]`. export (>=) : Primitive.Ord dtype => Tensor shape dtype -> Tensor shape dtype -> Tensor shape PRED -(>=) = binaryOp "(>=)" prim__ge +(>=) = binaryOp "(>=)" ge ||| Element-wise boolean and. For example, ||| `fromLiteral [True, True, False, False] && fromLiteral [True, False, True, False]` is ||| `fromLiteral [True, False, False, False]`. export (&&) : Tensor shape PRED -> Tensor shape PRED -> Tensor shape PRED -(&&) = binaryOp "(&&)" prim__and +(&&) = binaryOp "(&&)" and namespace Semigroup export @@ -659,7 +600,7 @@ namespace Monoid ||| `fromLiteral [True, True, True, False]`. export (||) : Tensor shape PRED -> Tensor shape PRED -> Tensor shape PRED -(||) = binaryOp "(||)" prim__or +(||) = binaryOp "(||)" or namespace Semigroup export @@ -675,7 +616,7 @@ namespace Monoid ||| `fromLiteral [False, True]`. export not : Tensor shape PRED -> Tensor shape PRED -not = unaryOp "not" prim__not +not = unaryOp "not" not ||| Choose elements from two `Tensor`s based on a `Tensor` of predicates. For each element in the ||| predicates, the output will use the corresponding element from `onTrue` if the element is @@ -699,9 +640,7 @@ select : Primitive dtype => Tensor shape PRED -> (onTrue : Tensor shape dtype) -> (onFalse : Tensor shape dtype) -> Tensor shape dtype select (MkTensor gPred pred) (MkTensor gTrue true) (MkTensor gFalse false) = let graph = Select gPred gTrue gFalse - in MkTensor graph $ cached graph $ do - op <- primIO $ prim__select !pred !true !false - onCollectAny op XlaOp.delete + in MkTensor graph $ cached graph $ do select !pred !true !false ||| Use a scalar predicate to choose which of two functions to evaluate. If the predicte is truthy, ||| evaluate `onTrue` on the corresponding specified argument, otherwise evaluate `onFalse` on the @@ -732,18 +671,15 @@ cond (MkTensor graphPred pred) onTrue (MkTensor graphTrue true) onFalse (MkTensor graphFalse false) = - let grapht = Parameter {dtype=tt} ts 0 - graphf = Parameter {dtype=ft} fs 0 - pt = cached grapht $ prim__parameter 0 ts "" {dtype} - pf = cached graphf $ prim__parameter 0 fs "" {dtype} + let (grapht, pt) = parameter 0 ts "" {dtype=tt} + (graphf, pf) = parameter 0 fs "" {dtype=ft} MkTensor graphOnTrue trueRes = onTrue (MkTensor grapht pt) MkTensor graphOnFalse falseRes = onFalse (MkTensor graphf pf) graph = Cond graphPred graphOnTrue graphTrue graphOnFalse graphFalse in MkTensor graph $ cached graph $ do trueComp <- buildWithSubBuilder "truthy computation" [pt] trueRes falseComp <- buildWithSubBuilder "falsy computation" [pf] falseRes - op <- primIO $ prim__conditional !pred !true trueComp !false falseComp - onCollectAny op XlaOp.delete + conditional !pred !true trueComp !false falseComp -- see https://www.python.org/dev/peps/pep-0465/#precedence-and-associativity infixl 9 @@ @@ -758,9 +694,7 @@ namespace Vector (@@) : Primitive.Num dtype => Tensor [S m] dtype -> Tensor [S m] dtype -> Tensor [] dtype (MkTensor graphL l) @@ (MkTensor graphR r) = let graph = Dot graphL graphR - in MkTensor graph $ cached graph $ do - op <- primIO $ prim__dot !l !r - onCollectAny op XlaOp.delete + in MkTensor graph $ cached graph $ do dot !l !r namespace Matrix ||| Matrix multiplication with a matrix or vector. Contraction is along the last axis of the first @@ -791,15 +725,13 @@ namespace Matrix -> length tl `LTE` 1 => Tensor (n :: tl) dtype (MkTensor graphL l) @@ (MkTensor graphR r) = let graph = Dot graphL graphR - in MkTensor graph $ cached graph $ do - op <- primIO $ prim__dot !l !r - onCollectAny op XlaOp.delete + in MkTensor graph $ cached graph $ do dot !l !r ||| Element-wise addition. For example, `fromLiteral [1, 2] + fromLiteral [3, 4]` is ||| `fromLiteral [4, 6]`. export (+) : Primitive.Num dtype => Tensor shape dtype -> Tensor shape dtype -> Tensor shape dtype -(+) = binaryOp "(+)" prim__add +(+) = binaryOp "(+)" add namespace Semigroup export @@ -815,19 +747,19 @@ namespace Monoid ||| Element-wise negation. For example, `- fromLiteral [1, -2]` is `fromLiteral [-1, 2]`. export negate : Primitive.Neg dtype => Tensor shape dtype -> Tensor shape dtype -negate = unaryOp "negate" prim__neg +negate = unaryOp "negate" neg ||| Element-wise subtraction. For example, `fromLiteral [3, 4] - fromLiteral [4, 2]` is ||| `fromLiteral [-1, 2]`. export (-) : Primitive.Neg dtype => Tensor shape dtype -> Tensor shape dtype -> Tensor shape dtype -(-) = binaryOp "(-)" prim__sub +(-) = binaryOp "(-)" sub ||| Element-wise multiplication. For example, `fromLiteral [2, 3] * fromLiteral [4, 5]` is ||| `fromLiteral [8, 15]`. export (*) : Primitive.Num dtype => Tensor shape dtype -> Tensor shape dtype -> Tensor shape dtype -(*) = binaryOp "(*)" prim__mul +(*) = binaryOp "(*)" mul namespace Scalarwise ||| Multiplication by a scalar. For example, `fromLiteral 2 * fromLiteral [3, 5]` is @@ -854,7 +786,7 @@ namespace Monoid ||| `fromLiteral [0.5, 0.6]`. export (/) : Primitive.Fractional dtype => Tensor shape dtype -> Tensor shape dtype -> Tensor shape dtype -(/) = binaryOp "(/)" prim__div +(/) = binaryOp "(/)" div namespace Scalarwise ||| Floating point division by a scalar. For example, `fromLiteral [3.4, -5.6] / fromLiteral 2` is @@ -871,7 +803,7 @@ namespace Scalarwise ||| is `fromLiteral [-0.5, nan, 5]`. export recip : Tensor shape F64 -> Tensor shape F64 -recip = unaryOp "recip" prim__reciprocal +recip = unaryOp "recip" reciprocal infixr 9 ^ @@ -884,121 +816,121 @@ infixr 9 ^ ||| Note: The first root is used. export (^) : Tensor shape F64 -> Tensor shape F64 -> Tensor shape F64 -(^) = binaryOp "(^)" prim__pow +(^) = binaryOp "(^)" pow ||| Element-wise absolute value. For example, `abs (fromLiteral [-2, 3])` is ||| `fromLiteral [2, 3]`. export abs : Primitive.Abs dtype => Tensor shape dtype -> Tensor shape dtype -abs = unaryOp "abs" prim__abs +abs = unaryOp "abs" abs ||| The element-wise natural exponential. For example, `exp (fromLiteral [-1, 0, 2])` is ||| `fromLiteral [1 / euler, 1, pow euler 2]`. export exp : Tensor shape F64 -> Tensor shape F64 -exp = unaryOp "exp" prim__exp +exp = unaryOp "exp" exp ||| The element-wise floor function. For example, ||| `floor (fromLiteral [-1.6, -1.5, -1.4, -1.0, 1.0, 1.4, 1.5, 1.6])` is ||| `fromLiteral [-2.0, -2.0, -2.0, -1.0, 1.0, 1.0, 1.0, 1.0]`. export floor : Tensor shape F64 -> Tensor shape F64 -floor = unaryOp "floor" prim__floor +floor = unaryOp "floor" floor ||| The element-wise ceiling function. For example, ||| `ceil (fromLiteral [-1.6, -1.5, -1.4, -1.0, 1.0, 1.4, 1.5, 1.6])` is ||| `fromLiteral [-1.0, -1.0, -1.0, -1.0, 1.0, 2.0, 2.0, 2.0]`. export ceil : Tensor shape F64 -> Tensor shape F64 -ceil = unaryOp "ceil" prim__ceil +ceil = unaryOp "ceil" ceil ||| The element-wise natural logarithm. Negative inputs yield NaN output. For example, ||| `log (fromLiteral [1 / euler, 1, euler * euler])` is `fromLiteral [-1, 0, 2]`. export log : Tensor shape F64 -> Tensor shape F64 -log = unaryOp "log" prim__log +log = unaryOp "log" log ||| The element-wise logistic function equivalent to `1 / 1 + exp (-x)`. export logistic : Tensor shape F64 -> Tensor shape F64 -logistic = unaryOp "logistic" prim__logistic +logistic = unaryOp "logistic" logistic ||| The element-wise sine. export sin : Tensor shape F64 -> Tensor shape F64 -sin = unaryOp "sin" prim__sin +sin = unaryOp "sin" sin ||| The element-wise cosine. export cos : Tensor shape F64 -> Tensor shape F64 -cos = unaryOp "cos" prim__cos +cos = unaryOp "cos" cos ||| The element-wise tangent. export tan : Tensor shape F64 -> Tensor shape F64 -tan = unaryOp "tan" prim__tan +tan = unaryOp "tan" tan ||| The element-wise inverse sine. export asin : Tensor shape F64 -> Tensor shape F64 -asin = unaryOp "asin" prim__asin +asin = unaryOp "asin" asin ||| The element-wise inverse cosine. export acos : Tensor shape F64 -> Tensor shape F64 -acos = unaryOp "acos" prim__acos +acos = unaryOp "acos" acos ||| The element-wise inverse tangent. export atan : Tensor shape F64 -> Tensor shape F64 -atan = unaryOp "atan" prim__atan +atan = unaryOp "atan" atan ||| The element-wise hyperbolic sine. export sinh : Tensor shape F64 -> Tensor shape F64 -sinh = unaryOp "sinh" prim__sinh +sinh = unaryOp "sinh" sinh ||| The element-wise hyperbolic cosine. export cosh : Tensor shape F64 -> Tensor shape F64 -cosh = unaryOp "cosh" prim__cosh +cosh = unaryOp "cosh" cosh ||| The element-wise hyperbolic tangent. export tanh : Tensor shape F64 -> Tensor shape F64 -tanh = unaryOp "tanh" prim__tanh +tanh = unaryOp "tanh" tanh ||| The element-wise inverse hyperbolic sine. export asinh : Tensor shape F64 -> Tensor shape F64 -asinh = unaryOp "asinh" prim__asinh +asinh = unaryOp "asinh" asinh ||| The element-wise inverse hyperbolic cosine. export acosh : Tensor shape F64 -> Tensor shape F64 -acosh = unaryOp "acosh" prim__acosh +acosh = unaryOp "acosh" acosh ||| The element-wise inverse hyperbolic tangent. export atanh : Tensor shape F64 -> Tensor shape F64 -atanh = unaryOp "atanh" prim__atanh +atanh = unaryOp "atanh" atanh ||| An approximation to the element-wise error function. export erf : Tensor shape F64 -> Tensor shape F64 -erf = unaryOp "erf" prim__erf +erf = unaryOp "erf" erf ||| The element-wise square. For example, `square (fromLiteral [-2, 0, 3])` ||| is `fromLiteral [4, 0, 9]`. export square : Tensor shape F64 -> Tensor shape F64 -square = unaryOp "square" prim__square +square = unaryOp "square" square ||| The element-wise square root. The first root is used. Negative inputs yield NaN output. ||| For example, `sqrt (fromLiteral [0, 9])` is `fromLiteral [0, 3]`. export sqrt : Tensor shape F64 -> Tensor shape F64 -sqrt = unaryOp "sqrt" prim__sqrt +sqrt = unaryOp "sqrt" sqrt ||| The element-wise minimum of the first argument compared to the second. For example, ||| `min (fromLiteral [-3, -1, 3]) (fromLiteral [-1, 0, 1])` is `fromLiteral [-3, -1, 1]`. @@ -1006,7 +938,7 @@ sqrt = unaryOp "sqrt" prim__sqrt ||| **Note:** There is a known issue where sometimes the wrong value is chosen if one value is NaN. export min : Primitive.Ord dtype => Tensor shape dtype -> Tensor shape dtype -> Tensor shape dtype -min = binaryOp "min" prim__min +min = binaryOp "min" min ||| **Note:** There is a known issue where sometimes the wrong value is chosen if one value is NaN. namespace Semigroup @@ -1027,7 +959,7 @@ namespace Monoid ||| **Note:** There is a known issue where sometimes the wrong value is chosen if one value is NaN. export max : Primitive.Ord dtype => Tensor shape dtype -> Tensor shape dtype -> Tensor shape dtype -max = binaryOp "max" prim__max +max = binaryOp "max" max ||| **Note:** There is a known issue where sometimes the wrong value is chosen if one value is NaN. namespace Semigroup @@ -1052,9 +984,7 @@ export cholesky : Tensor [S n, S n] F64 -> Tensor [S n, S n] F64 cholesky (MkTensor graph xs) = let graph = Cholesky graph - in triangle Lower $ MkTensor graph $ cached graph $ do - res <- primIO $ prim__cholesky !xs 1 - onCollectAny res XlaOp.delete + in triangle Lower $ MkTensor graph $ cached graph $ do cholesky !xs True infix 9 |\, \| @@ -1071,8 +1001,7 @@ namespace Matrix (MkTensor graphA a) |\ (MkTensor graphB b) = let graph = TriangularSolve True graphA graphB in MkTensor graph $ cached graph $ do - op <- primIO $ prim__triangularSolve !a !b 1 1 0 1 - onCollectAny op XlaOp.delete + triangularSolve !a !b True True False NoTranspose ||| Solve the set of linear equations `a @@ x = b` for `x` where `a` is an upper-triangular ||| matrix. `a` is given by the upper-triangular elements of the first argument. Values in the @@ -1086,8 +1015,7 @@ namespace Matrix (MkTensor graphA a) \| (MkTensor graphB b) = let graph = TriangularSolve False graphA graphB in MkTensor graph $ cached graph $ do - op <- primIO $ prim__triangularSolve !a !b 1 0 0 1 - onCollectAny op XlaOp.delete + triangularSolve !a !b True False False NoTranspose namespace Vector ||| Solve the set of linear equations `a @@ x = b` for `x` where `a` is a lower-triangular matrix.