Skip to content

Commit

Permalink
Add additional tests for neuron module
Browse files Browse the repository at this point in the history
  • Loading branch information
rounakdatta committed Apr 11, 2024
1 parent bd27de6 commit adb9fa4
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 1 deletion.
14 changes: 13 additions & 1 deletion lib/neuron.ml
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,25 @@ module Neuron = struct
is_non_linear : bool;
}

(* this is just a getter data structure for the parameters *)
type out_t = {
weights : Variable.Variable.t list;
bias : Variable.Variable.t;
}

(* note how we explicitly define the types to avoid confusions and fights with the compiler *)
let parameters (neuron: t) : out_t = {
weights = neuron.weights;
bias = neuron.bias;
}

let create number_of_inputs is_non_linear = {
weights = List.init number_of_inputs (fun _ -> Variable.Variable.create (random_weight_initializer));
bias = Variable.Variable.create 0.0;
is_non_linear = is_non_linear;
}

let weigh_inputs neuron input_vector =
let weigh_inputs (neuron: t) input_vector =
(* one-to-one multiplication of inputs to their corresponding weights *)
let weighted_sum = List.fold_left2 (fun accumulator weight_i input_i -> Variable.Variable.(accumulator + weight_i * input_i))
(Variable.Variable.create 0.0) neuron.weights input_vector in
Expand Down
10 changes: 10 additions & 0 deletions lib/neuron.mli
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,16 @@
module Neuron : sig
type t

(* we've defined this public type both here in the interface file as well as the implementation file *)
(* this is required for the internal attributes of the type to be accessible publicly *)
type out_t = {
weights : Variable.Variable.t list;
bias : Variable.Variable.t;
}

(* getter for the parameters (aka weights and biases) of the neuron *)
val parameters : t -> out_t

(* Constructor; constructs a unit neuron *)
val create : int -> bool -> t

Expand Down
25 changes: 25 additions & 0 deletions test/neuron_operations.ml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
open Smolgrad.Neuron

let test_neuron_initialization () =
let n = Neuron.create 5 true in
let n_weights = (Neuron.parameters n).weights in
assert (Smolgrad.Variable.Variable.data (Neuron.parameters n).bias = 0.0);

Alcotest.(check int)
"Neuron: Right number of parameters are set"
5
(List.length (Neuron.parameters n).weights);

Alcotest.(check (float 0.0))
"Neuron: Bias is initially set to 0.0"
0.0
(Smolgrad.Variable.Variable.data (Neuron.parameters n).bias);

let are_weights_in_range weights =
List.for_all (fun x -> (Smolgrad.Variable.Variable.data x) >= -1.0 && (Smolgrad.Variable.Variable.data x) <= 1.0) weights in

Alcotest.(check bool)
"Neuron: All weights are in range [-1, 1]"
true
(are_weights_in_range n_weights);
;;
3 changes: 3 additions & 0 deletions test/smolgrad_tests.ml
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,7 @@ let () =
Variable_operations.test_custom_operator ();
Variable_operations.test_graph_construction ();
Variable_operations.test_backpropagation ();

(* unit tests for Neuron *)
Neuron_operations.test_neuron_initialization ();
;;

0 comments on commit adb9fa4

Please sign in to comment.