Skip to content

Commit

Permalink
Start using alcotest
Browse files Browse the repository at this point in the history
  • Loading branch information
rounakdatta committed Apr 10, 2024
1 parent 5300009 commit 943cb53
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 9 deletions.
8 changes: 4 additions & 4 deletions lib/neuron.ml
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@ module Neuron = struct
dependencies : t list;
}

let data base =
base.data
let data base = base.data
let grad base = base.grad
let dependencies base = base.dependencies

let create ?(op="n/a") ?(deps=[]) dt = {
data = dt;
Expand Down Expand Up @@ -61,7 +62,7 @@ module Neuron = struct
let backpropagate base =
(* we topologically sort all the connected nodes from the base node *)
let rec sort_topologically visited resultant candidate =
if not (List.mem candidate visited) then
if not (List.memq candidate visited) then
let visited = candidate :: visited in
let visited, resultant =
List.fold_left (fun (visited, resultant) dependent ->
Expand All @@ -80,4 +81,3 @@ module Neuron = struct
(* and propagate the gradient changes *)
List.iter (fun v -> v.backward ()) (List.rev resultant)
end

6 changes: 6 additions & 0 deletions lib/neuron.mli
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@ module Neuron : sig
(* getter for the data *)
val data : t -> float

(* getter for the gradient or weight *)
val grad : t -> float

(* getter for the dependencies of the node *)
val dependencies : t -> t list

(* Constructor; constructs a unit neuron of a value and an operator. *)
val create : ?op:string -> ?deps:t list -> float -> t

Expand Down
2 changes: 1 addition & 1 deletion smolgrad.opam
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,5 @@ synopsis: "Algorithmic differentiation in OCaml"
depends: [
"dune"
"ppx_jane"
"ounit"
"alcotest"
]
2 changes: 1 addition & 1 deletion test/dune
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
(test
(name smolgrad_tests)
(libraries ounit2 smolgrad))
(libraries alcotest smolgrad))
35 changes: 32 additions & 3 deletions test/smolgrad_tests.ml
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,49 @@ let test_simple_operation () =
let b = Neuron.create 2.0 in

let abba = Neuron.(a + b) in
assert (Neuron.data abba = 6.0);
Alcotest.(check (float 0.0))
"Nodes add up correctly"
6.0 (Neuron.data abba);
;;

(* here we open the Neuron module wide open locally, thereby allowing the clean custom `+` operator usage *)
(* we'll avoid this pattern elsewhere in the tests *)
let test_custom_operator () =
let open Neuron in
let a = create 4.0 in
let b = create 2.0 in

let abba = a + b in
assert (Neuron.data abba = 6.0);
Alcotest.(check (float 0.0))
"Nodes add up correctly with custom operator"
6.0 (Neuron.data abba);
;;


let test_graph_construction () =
let a = Neuron.create 4.0 in
let b = Neuron.create 2.0 in
let c = Neuron.(a * b + b**3.0) in
let d = Neuron.(c + a) in

Alcotest.(check (list (float 0.0)))
"Dependency graph is constructed correctly"
(List.map (fun x -> Neuron.data x) [a; b]) (List.map (fun x -> Neuron.data x) (Neuron.dependencies d));
;;

let test_backpropagation () =
let a = Neuron.create 4.0 in
let b = Neuron.create 2.0 in
let c = Neuron.(a * b + b**3.0) in

Neuron.backpropagate c;
Alcotest.(check (float 0.0))
"Backpropagation yields correct gradient for a complex graph"
6.0 (Neuron.grad c);
;;

let () =
test_simple_operation ();
test_custom_operator ();
test_graph_construction ();
(* test_backpropagation (); *)
;;

0 comments on commit 943cb53

Please sign in to comment.