From adb9fa4ae2757104bb4e05bf7e628490ae677c2f Mon Sep 17 00:00:00 2001 From: Rounak Datta Date: Thu, 11 Apr 2024 12:50:07 +0530 Subject: [PATCH] Add additional tests for neuron module --- lib/neuron.ml | 14 +++++++++++++- lib/neuron.mli | 10 ++++++++++ test/neuron_operations.ml | 25 +++++++++++++++++++++++++ test/smolgrad_tests.ml | 3 +++ 4 files changed, 51 insertions(+), 1 deletion(-) create mode 100644 test/neuron_operations.ml diff --git a/lib/neuron.ml b/lib/neuron.ml index a9c3e65..e4f21aa 100644 --- a/lib/neuron.ml +++ b/lib/neuron.ml @@ -8,13 +8,25 @@ module Neuron = struct is_non_linear : bool; } + (* this is just a getter data structure for the parameters *) + type out_t = { + weights : Variable.Variable.t list; + bias : Variable.Variable.t; + } + + (* note how we explicitly define the types to avoid confusions and fights with the compiler *) + let parameters (neuron: t) : out_t = { + weights = neuron.weights; + bias = neuron.bias; + } + let create number_of_inputs is_non_linear = { weights = List.init number_of_inputs (fun _ -> Variable.Variable.create (random_weight_initializer)); bias = Variable.Variable.create 0.0; is_non_linear = is_non_linear; } - let weigh_inputs neuron input_vector = + let weigh_inputs (neuron: t) input_vector = (* one-to-one multiplication of inputs to their corresponding weights *) let weighted_sum = List.fold_left2 (fun accumulator weight_i input_i -> Variable.Variable.(accumulator + weight_i * input_i)) (Variable.Variable.create 0.0) neuron.weights input_vector in diff --git a/lib/neuron.mli b/lib/neuron.mli index 7df98f2..223be7c 100644 --- a/lib/neuron.mli +++ b/lib/neuron.mli @@ -2,6 +2,16 @@ module Neuron : sig type t + (* we've defined this public type both here in the interface file as well as the implementation file *) + (* this is required for the internal attributes of the type to be accessible publicly *) + type out_t = { + weights : Variable.Variable.t list; + bias : Variable.Variable.t; + } + + (* getter for the parameters (aka weights and biases) of the neuron *) + val parameters : t -> out_t + (* Constructor; constructs a unit neuron *) val create : int -> bool -> t diff --git a/test/neuron_operations.ml b/test/neuron_operations.ml new file mode 100644 index 0000000..bd8ce1c --- /dev/null +++ b/test/neuron_operations.ml @@ -0,0 +1,25 @@ +open Smolgrad.Neuron + +let test_neuron_initialization () = + let n = Neuron.create 5 true in + let n_weights = (Neuron.parameters n).weights in + assert (Smolgrad.Variable.Variable.data (Neuron.parameters n).bias = 0.0); + + Alcotest.(check int) + "Neuron: Right number of parameters are set" + 5 + (List.length (Neuron.parameters n).weights); + + Alcotest.(check (float 0.0)) + "Neuron: Bias is initially set to 0.0" + 0.0 + (Smolgrad.Variable.Variable.data (Neuron.parameters n).bias); + + let are_weights_in_range weights = + List.for_all (fun x -> (Smolgrad.Variable.Variable.data x) >= -1.0 && (Smolgrad.Variable.Variable.data x) <= 1.0) weights in + + Alcotest.(check bool) + "Neuron: All weights are in range [-1, 1]" + true + (are_weights_in_range n_weights); +;; diff --git a/test/smolgrad_tests.ml b/test/smolgrad_tests.ml index fbded76..38cbabc 100644 --- a/test/smolgrad_tests.ml +++ b/test/smolgrad_tests.ml @@ -4,4 +4,7 @@ let () = Variable_operations.test_custom_operator (); Variable_operations.test_graph_construction (); Variable_operations.test_backpropagation (); + + (* unit tests for Neuron *) + Neuron_operations.test_neuron_initialization (); ;;