diff --git a/lib/neuron.ml b/lib/neuron.ml index b4ddae6..b3755a6 100644 --- a/lib/neuron.ml +++ b/lib/neuron.ml @@ -8,8 +8,9 @@ module Neuron = struct dependencies : t list; } - let data base = - base.data + let data base = base.data + let grad base = base.grad + let dependencies base = base.dependencies let create ?(op="n/a") ?(deps=[]) dt = { data = dt; @@ -61,7 +62,7 @@ module Neuron = struct let backpropagate base = (* we topologically sort all the connected nodes from the base node *) let rec sort_topologically visited resultant candidate = - if not (List.mem candidate visited) then + if not (List.memq candidate visited) then let visited = candidate :: visited in let visited, resultant = List.fold_left (fun (visited, resultant) dependent -> @@ -80,4 +81,3 @@ module Neuron = struct (* and propagate the gradient changes *) List.iter (fun v -> v.backward ()) (List.rev resultant) end - diff --git a/lib/neuron.mli b/lib/neuron.mli index f06ae7d..5cce869 100644 --- a/lib/neuron.mli +++ b/lib/neuron.mli @@ -4,6 +4,12 @@ module Neuron : sig (* getter for the data *) val data : t -> float + (* getter for the gradient or weight *) + val grad : t -> float + + (* getter for the dependencies of the node *) + val dependencies : t -> t list + (* Constructor; constructs a unit neuron of a value and an operator. *) val create : ?op:string -> ?deps:t list -> float -> t diff --git a/smolgrad.opam b/smolgrad.opam index b0db3a1..0953c2d 100644 --- a/smolgrad.opam +++ b/smolgrad.opam @@ -4,5 +4,5 @@ synopsis: "Algorithmic differentiation in OCaml" depends: [ "dune" "ppx_jane" - "ounit" + "alcotest" ] diff --git a/test/dune b/test/dune index 5b572af..fee8fda 100644 --- a/test/dune +++ b/test/dune @@ -1,3 +1,3 @@ (test (name smolgrad_tests) - (libraries ounit2 smolgrad)) + (libraries alcotest smolgrad)) diff --git a/test/smolgrad_tests.ml b/test/smolgrad_tests.ml index 7ee5522..a40c7a9 100644 --- a/test/smolgrad_tests.ml +++ b/test/smolgrad_tests.ml @@ -5,20 +5,49 @@ let test_simple_operation () = let b = Neuron.create 2.0 in let abba = Neuron.(a + b) in - assert (Neuron.data abba = 6.0); + Alcotest.(check (float 0.0)) + "Nodes add up correctly" + 6.0 (Neuron.data abba); ;; (* here we open the Neuron module wide open locally, thereby allowing the clean custom `+` operator usage *) +(* we'll avoid this pattern elsewhere in the tests *) let test_custom_operator () = let open Neuron in let a = create 4.0 in let b = create 2.0 in let abba = a + b in - assert (Neuron.data abba = 6.0); + Alcotest.(check (float 0.0)) + "Nodes add up correctly with custom operator" + 6.0 (Neuron.data abba); ;; - + +let test_graph_construction () = + let a = Neuron.create 4.0 in + let b = Neuron.create 2.0 in + let c = Neuron.(a * b + b**3.0) in + let d = Neuron.(c + a) in + + Alcotest.(check (list (float 0.0))) + "Dependency graph is constructed correctly" + (List.map (fun x -> Neuron.data x) [a; b]) (List.map (fun x -> Neuron.data x) (Neuron.dependencies d)); +;; + +let test_backpropagation () = + let a = Neuron.create 4.0 in + let b = Neuron.create 2.0 in + let c = Neuron.(a * b + b**3.0) in + + Neuron.backpropagate c; + Alcotest.(check (float 0.0)) + "Backpropagation yields correct gradient for a complex graph" + 6.0 (Neuron.grad c); +;; + let () = test_simple_operation (); test_custom_operator (); + test_graph_construction (); + (* test_backpropagation (); *) ;;