From 5300009fbd911a9fa09d3012d52982ba0fe491b3 Mon Sep 17 00:00:00 2001
From: Rounak Datta
Date: Wed, 10 Apr 2024 17:54:46 +0530
Subject: [PATCH] Add custom operators; add simple tests

---
 lib/dune               |  5 +++--
 lib/network.ml         |  0
 lib/neuron.ml          | 25 +++++++++++++++++--------
 lib/neuron.mli         | 10 ++++++++--
 test/dune              |  3 +++
 test/smolgrad_tests.ml | 24 ++++++++++++++++++++++++
 6 files changed, 55 insertions(+), 12 deletions(-)
 delete mode 100644 lib/network.ml
 create mode 100644 test/dune
 create mode 100644 test/smolgrad_tests.ml

diff --git a/lib/dune b/lib/dune
index ea9b436..1c519e3 100644
--- a/lib/dune
+++ b/lib/dune
@@ -1,3 +1,4 @@
 (library
- (name nn)
- (modules Neuron Network))
+ (name smolgrad)
+ (public_name smolgrad)
+ (modules neuron))
diff --git a/lib/network.ml b/lib/network.ml
deleted file mode 100644
index e69de29..0000000
diff --git a/lib/neuron.ml b/lib/neuron.ml
index c53e201..b4ddae6 100644
--- a/lib/neuron.ml
+++ b/lib/neuron.ml
@@ -5,19 +5,22 @@ module Neuron = struct
     mutable backward : unit -> unit;
     operator : string;
-    dependents : t list;
+    dependencies : t list;
   }
 
-  let create dt op deps = {
+  let data base =
+    base.data
+
+  let create ?(op="n/a") ?(deps=[]) dt = {
     data = dt;
     grad = 0.;
     backward = (fun () -> ());
     operator = op;
-    dependents = deps;
+    dependencies = deps;
   }
 
   let add base partner =
-    let resultant = create (base.data +. partner.data) "+" [base; partner] in
+    let resultant = create ~op:"+" ~deps:[base; partner] (base.data +. partner.data) in
     resultant.backward <- (fun () ->
       base.grad <- base.grad +. resultant.grad;
@@ -26,7 +29,7 @@ module Neuron = struct
     resultant
 
   let mul base partner =
-    let resultant = create (base.data *. partner.data) "*" [base; partner] in
+    let resultant = create ~op:"*" ~deps:[base; partner] (base.data *. partner.data) in
     resultant.backward <- (fun () ->
       base.grad <- base.grad +. partner.data *. resultant.grad;
@@ -35,15 +38,20 @@ module Neuron = struct
     resultant
 
   let exp base exponent =
-    let resultant = create (base.data ** exponent) "**" [base] in
+    let resultant = create ~op:"**" ~deps:[base] (base.data ** exponent) in
     resultant.backward <- (fun () ->
       base.grad <- base.grad +. exponent *. (base.data ** (exponent -. 1.)) *. resultant.grad;
     );
     resultant
 
+  (* these are all the operator overloadings we need to associate with each of the binary operators *)
+  let ( + ) = add
+  let ( * ) = mul
+  let ( ** ) = exp
+
   let relu base =
-    let resultant = create (max 0. base.data) "relu" [base] in
+    let resultant = create ~op:"relu" ~deps:[base] (max 0. base.data) in
     resultant.backward <- (fun () ->
       base.grad <- base.grad +. (if base.data > 0. then resultant.grad else 0.);
@@ -58,7 +66,7 @@ module Neuron = struct
         let visited, resultant = List.fold_left (
           fun (visited, resultant) dependent ->
             sort_topologically visited resultant dependent
-        ) (visited, resultant) candidate.dependents
+        ) (visited, resultant) candidate.dependencies
         in
         (visited, candidate :: resultant)
       else
@@ -72,3 +80,4 @@ module Neuron = struct
     (* and propagate the gradient changes *)
     List.iter (fun v -> v.backward ()) (List.rev resultant)
 end
+
diff --git a/lib/neuron.mli b/lib/neuron.mli
index 5e70884..f06ae7d 100644
--- a/lib/neuron.mli
+++ b/lib/neuron.mli
@@ -1,22 +1,28 @@
 module Neuron : sig
   type t
 
+  (* getter for the data *)
+  val data : t -> float
+
   (* Constructor; constructs a unit neuron of a value and an operator. *)
-  val create : float -> string -> t list -> t
+  val create : ?op:string -> ?deps:t list -> float -> t
 
   (* Handles the gradient flows in addition operation. *)
   val add : t -> t -> t
+  val ( + ) : t -> t -> t
 
   (* Handles the gradient flows in multiplication operation. *)
   val mul : t -> t -> t
+  val ( * ) : t -> t -> t
 
   (* Handles the gradient flows in exponent / power operation. *)
   (* second argument is the exponent. *)
   val exp : t -> float -> t
+  val ( ** ) : t -> float -> t
 
   (* Handles the gradient flows in ReLU operation. *)
   val relu : t -> t
 
-  (* Handles backpropagation of the gradients for all the nodes connected to the specified base node. *)
+  (* Handles backpropagation of the gradients for all the nodes reachable from this base node. *)
   val backpropagate : t -> unit
 end
diff --git a/test/dune b/test/dune
new file mode 100644
index 0000000..5b572af
--- /dev/null
+++ b/test/dune
@@ -0,0 +1,3 @@
+(test
+ (name smolgrad_tests)
+ (libraries ounit2 smolgrad))
diff --git a/test/smolgrad_tests.ml b/test/smolgrad_tests.ml
new file mode 100644
index 0000000..7ee5522
--- /dev/null
+++ b/test/smolgrad_tests.ml
@@ -0,0 +1,24 @@
+open Smolgrad.Neuron
+
+let test_simple_operation () =
+  let a = Neuron.create 4.0 in
+  let b = Neuron.create 2.0 in
+
+  let abba = Neuron.(a + b) in
+  assert (Neuron.data abba = 6.0)
+;;
+
+(* Locally open the Neuron module so the custom `+` operator can be used directly. *)
+let test_custom_operator () =
+  let open Neuron in
+  let a = create 4.0 in
+  let b = create 2.0 in
+
+  let abba = a + b in
+  assert (Neuron.data abba = 6.0)
+;;
+
+let () =
+  test_simple_operation ();
+  test_custom_operator ()
+;;
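
Usage note (not part of the patch): the snippet below is a minimal sketch of how the API introduced here could be consumed, assuming an executable that links against the `smolgrad` library. It uses only what `neuron.mli` exposes (`create`, the new infix operators, `relu`, `data`, `backpropagate`); since the interface has no getter for `grad`, the sketch only prints the forward value before running backpropagation.

(* Hypothetical consumer of the Neuron API from this patch. *)
open Smolgrad.Neuron

let () =
  (* build a small expression graph with the new infix operators *)
  let a = Neuron.create 3.0 in
  let b = Neuron.create 2.0 in
  (* forward: relu (3. *. 2. +. 3. ** 2.) = relu 15. = 15. *)
  let y = Neuron.(relu ((a * b) + (a ** 2.))) in
  Printf.printf "forward value: %f\n" (Neuron.data y);
  (* runs each node's backward pass in reverse topological order *)
  Neuron.backpropagate y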