Skip to content

Commit

Permalink
fix ambiguity in Graphs where non-Tensor values were ignored (#275)
Browse files Browse the repository at this point in the history
* add dedicated `Graph` constructor for each op
* delete `Show` for `Graph`
  • Loading branch information
joelberkeley authored May 25, 2022
1 parent 3699466 commit 8276a38
Show file tree
Hide file tree
Showing 4 changed files with 141 additions and 203 deletions.
97 changes: 48 additions & 49 deletions src/Compiler/Graph.idr
Original file line number Diff line number Diff line change
Expand Up @@ -15,62 +15,61 @@ limitations under the License.
--}
module Compiler.Graph

import Primitive
import Data.Hashable
import Data.Stream
import Types
import Util

||| A `Graph` represents the graph computation of a tensor value. It is equivalent to the
||| computation graph used at runtime, but might not be an exact representation. Specifically, given
||| two `Graph`s gx and gy that compute tensors x and y respectively, if gx is equal to gy, then x
||| is equal to y, but the computations used to compute x and y may be different.
||| A `Graph` represents a computational graph used to compute a tensor value. It is defined by
||| the following proprty: For any two `Graph`s gx and gy that compute tensors x and y respectively,
||| if gx is identical to gy, then the values of x and y are equal.
|||
||| It is primarily used for memoization in constructing the computation graph.
public export
data Graph : Type where
||| Represents a function application.
|||
||| @name The name of the operation.
||| @arguments The arguments the operation is called on.
||| @shape The shape of the resulting tensor.
||| @type A string representation of the data type of the resulting tensor.
Operation :
(name : String) ->
(arguments : List Graph) ->
(shape : Shape) ->
(type : String) ->
Graph

||| Represents a tensor value. This tensor can have a concrete value, or correspond to a
||| function argument.
|||
||| @name The name of the method of instantiating the tensor.
||| @id_ An identifier to differentiate this tensor from other tensors.
||| @shape The shape of this tensor.
||| @type A string representation of the data type of the tensor.
Leaf : (name : String) -> (id_ : Bits64) -> (shape : Shape) -> (type : String) -> Graph

export covering
Eq Graph where
Operation lname largs lshape ltype == Operation rname rargs rshape rtype =
lname == rname && largs == rargs && lshape == rshape && ltype == rtype
Leaf lname lhash lshape ltype == Leaf rname rhash rshape rtype =
lname == rname && lhash == rhash && lshape == rshape && ltype == rtype
_ == _ = False

export covering
Show Graph where
show xs = impl 0 xs where
impl : Nat -> Graph -> String
impl depth =
let indent = pack $ take (2 * depth) (repeat ' ')
in \case
Operation name args shape type =>
let init = indent ++ "\{type}\{show shape} \{name}"
in foldl (\acc, g => acc ++ "\n" ++ impl (S depth) g) init args
Leaf name hash shape type => indent ++ "\{type}\{show shape} \{name}"
FromLiteral : Primitive dtype => Shape -> (hash : Bits64) -> Graph
Parameter : Primitive dtype => Shape -> Nat -> Graph
Reshape : Shape -> Graph -> Graph
Slice : Nat -> Nat -> Nat -> Graph -> Graph
Concat : Nat -> Graph -> Graph -> Graph
Diag : Graph -> Graph
Triangle : (lower : Bool) -> Graph -> Graph
Transpose : Graph -> Graph
Identity : Primitive dtype => Nat -> Graph
Broadcast : Shape -> Graph -> Graph
Map : Graph -> List Graph -> Graph
Reduce : Graph -> Nat -> Graph -> Graph
ElementwiseBinary : (name : String) -> Graph -> Graph -> Graph
ElementwiseUnary : (name : String) -> Graph -> Graph
Select : Graph -> Graph -> Graph -> Graph
Cond : Graph -> Graph -> Graph -> Graph -> Graph -> Graph
Dot : Graph -> Graph -> Graph
Cholesky : Graph -> Graph
TriangularSolve : (lower : Bool) -> Graph -> Graph -> Graph

export covering
Hashable Graph where
hashWithSalt salt (Operation name arguments shape type) =
salt `hashWithSalt` name `hashWithSalt` arguments `hashWithSalt` shape `hashWithSalt` type
hashWithSalt salt (Leaf name id_ shape type) =
salt `hashWithSalt` name `hashWithSalt` id_ `hashWithSalt` shape `hashWithSalt` type
hashWithSalt salt (FromLiteral {dtype} hash shape) =
salt `hashWithSalt` ("FromLiteral", typeString {dtype}, shape, hash)
hashWithSalt salt (Parameter {dtype} shape position) =
salt `hashWithSalt` ("Parameter", typeString {dtype}, shape, position)
hashWithSalt salt (Reshape to x) = salt `hashWithSalt` ("Reshape", to, x)
hashWithSalt salt (Slice axis from to x) = salt `hashWithSalt` ("Slice", axis, from, to)
hashWithSalt salt (Concat axis x y) = salt `hashWithSalt` ("Concat", axis, x, y)
hashWithSalt salt (Diag x) = salt `hashWithSalt` ("Diag", x)
hashWithSalt salt (Triangle lower x) = salt `hashWithSalt` ("Triangle", lower, x)
hashWithSalt salt (Transpose x) = salt `hashWithSalt` ("Transpose", x)
hashWithSalt salt (Identity {dtype} n) = salt `hashWithSalt` ("Identity", typeString {dtype}, n)
hashWithSalt salt (Broadcast to x) = salt `hashWithSalt` ("Broadcast", to, x)
hashWithSalt salt (Map f xs) = salt `hashWithSalt` ("Map", f, xs)
hashWithSalt salt (Reduce monoid axis x) = salt `hashWithSalt` ("Reduce", monoid, axis, x)
hashWithSalt salt (ElementwiseBinary name x y) = hashWithSalt salt (name, x, y)
hashWithSalt salt (ElementwiseUnary name x) = hashWithSalt salt (name, x)
hashWithSalt salt (Select pred f t) = salt `hashWithSalt` ("Select", pred, f, t)
hashWithSalt salt (Cond pred fTrue true fFalse false) =
salt `hashWithSalt` "Cond" `hashWithSalt` (pred, fTrue, true, fFalse, false)
hashWithSalt salt (Dot x y) = salt `hashWithSalt` ("Dot", x, y)
hashWithSalt salt (Cholesky x) = salt `hashWithSalt` ("Cholesky", x)
hashWithSalt salt (TriangularSolve lower x y) =
salt `hashWithSalt` ("TriangularSolve", lower, x, y)
Loading

0 comments on commit 8276a38

Please sign in to comment.