Skip to content

Commit

Permalink
Add custom operators; add simple tests
Browse files Browse the repository at this point in the history
  • Loading branch information
rounakdatta committed Apr 10, 2024
1 parent 2b9c099 commit 5300009
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 12 deletions.
5 changes: 3 additions & 2 deletions lib/dune
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
(library
(name nn)
(modules Neuron Network))
(name smolgrad)
(public_name smolgrad)
(modules neuron))
Empty file removed lib/network.ml
Empty file.
25 changes: 17 additions & 8 deletions lib/neuron.ml
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,22 @@ module Neuron = struct
mutable backward : unit -> unit;

operator : string;
dependents : t list;
dependencies : t list;
}

let create dt op deps = {
let data base =
base.data

let create ?(op="n/a") ?(deps=[]) dt = {
data = dt;
grad = 0.;
backward = (fun () -> ());
operator = op;
dependents = deps;
dependencies = deps;
}

let add base partner =
let resultant = create (base.data +. partner.data) "+" [base; partner] in
let resultant = create ~op:"+" ~deps:[base; partner] (base.data +. partner.data) in

resultant.backward <- (fun () ->
base.grad <- base.grad +. resultant.grad;
Expand All @@ -26,7 +29,7 @@ module Neuron = struct
resultant

let mul base partner =
let resultant = create (base.data *. partner.data) "*" [base; partner] in
let resultant = create ~op:"*" ~deps:[base; partner] (base.data *. partner.data) in

resultant.backward <- (fun () ->
base.grad <- base.grad +. partner.data *. resultant.grad;
Expand All @@ -35,15 +38,20 @@ module Neuron = struct
resultant

let exp base exponent =
let resultant = create (base.data ** exponent) "**" [base] in
let resultant = create ~op:"**" ~deps:[base] (base.data ** exponent) in

resultant.backward <- (fun () ->
base.grad <- base.grad +. exponent *. (base.data ** (exponent -. 1.)) *. resultant.grad;
);
resultant

(* these are all the operator overloadings we need to associate with each of the binary operators *)
let ( + ) = add
let ( * ) = mul
let ( ** ) = exp

let relu base =
let resultant = create (max 0. base.data) "relu" [base] in
let resultant = create ~op:"relu" ~deps:[base] (max 0. base.data) in

resultant.backward <- (fun () ->
base.grad <- base.grad +. (if base.data > 0. then resultant.grad else 0.);
Expand All @@ -58,7 +66,7 @@ module Neuron = struct
let visited, resultant =
List.fold_left (fun (visited, resultant) dependent ->
sort_topologically visited resultant dependent
) (visited, resultant) candidate.dependents
) (visited, resultant) candidate.dependencies
in
(visited, candidate :: resultant)
else
Expand All @@ -72,3 +80,4 @@ module Neuron = struct
(* and propagate the gradient changes *)
List.iter (fun v -> v.backward ()) (List.rev resultant)
end

10 changes: 8 additions & 2 deletions lib/neuron.mli
Original file line number Diff line number Diff line change
@@ -1,22 +1,28 @@
module Neuron : sig
type t

(* getter for the data *)
val data : t -> float

(* Constructor; constructs a unit neuron of a value and an operator. *)
val create : float -> string -> t list -> t
val create : ?op:string -> ?deps:t list -> float -> t

(* Handles the gradient flows in addition operation. *)
val add : t -> t -> t
val ( + ) : t -> t -> t

(* Handles the gradient flows in multiplication operation. *)
val mul : t -> t -> t
val ( * ) : t -> t -> t

(* Handles the gradient flows in exponent / power operation. *)
(* second argument is the exponent. *)
val exp : t -> float -> t
val ( ** ) : t -> float -> t

(* Handles the gradient flows in ReLU operation. *)
val relu : t -> t

(* Handles backpropagation of the gradients for all the nodes connected to the specified base node. *)
(* Handles backpropagation of the gradients for all the nodes connected to this as the base node. *)
val backpropagate : t -> unit
end
3 changes: 3 additions & 0 deletions test/dune
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
(test
(name smolgrad_tests)
(libraries ounit2 smolgrad))
24 changes: 24 additions & 0 deletions test/smolgrad_tests.ml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
open Smolgrad.Neuron

let test_simple_operation () =
let a = Neuron.create 4.0 in
let b = Neuron.create 2.0 in

let abba = Neuron.(a + b) in
assert (Neuron.data abba = 6.0);
;;

(* here we open the Neuron module wide open locally, thereby allowing the clean custom `+` operator usage *)
let test_custom_operator () =
let open Neuron in
let a = create 4.0 in
let b = create 2.0 in

let abba = a + b in
assert (Neuron.data abba = 6.0);
;;

let () =
test_simple_operation ();
test_custom_operator ();
;;

0 comments on commit 5300009

Please sign in to comment.