Skip to content

Commit

Permalink
Add tests for Network
Browse files Browse the repository at this point in the history
  • Loading branch information
rounakdatta committed Apr 13, 2024
1 parent 1ecf443 commit 89d120b
Show file tree
Hide file tree
Showing 7 changed files with 42 additions and 2 deletions.
3 changes: 3 additions & 0 deletions lib/layer.ml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ module Layer = struct
neurons : Neuron.Neuron.t list
}

let parameters (layer: t) =
List.map Neuron.Neuron.parameters layer.neurons

let create number_of_input_dimensions number_of_neurons is_non_linear =
let neurons = List.init number_of_neurons (fun _ -> Neuron.Neuron.create number_of_input_dimensions is_non_linear) in
{ neurons = neurons }
Expand Down
3 changes: 3 additions & 0 deletions lib/layer.mli
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
module Layer : sig
type t

(* getter for the parameters (aka weights and biases) of the layer *)
val parameters : t -> Neuron.Neuron.out_t list

(* Constructor; constructs a layer of neurons *)
val create : int -> int -> bool -> t

Expand Down
6 changes: 5 additions & 1 deletion lib/network.ml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ module Network = struct
layers : Layer.Layer.t list;
}

let parameters (network : t) =
List.map (fun x -> Layer.Layer.parameters x) network.layers

(* the number of output parameters of a layer is the number of neurons in that layer *)
(* important to understand that the input layer isn't truly a layer with weights and biases; just an abstraction *)
let create number_of_input_dimensions number_of_neurons_per_layer =
Expand All @@ -11,7 +14,8 @@ module Network = struct
let rec build_layers stacked_layers sizes =
match sizes with
| input :: output :: rest ->
let layer = Layer.Layer.create input output true in
let has_non_linear_activation_for_layer = (rest <> []) in
let layer = Layer.Layer.create input output has_non_linear_activation_for_layer in

(* note how we are stacking the layers in reverse order *)
build_layers (layer :: stacked_layers) (output :: rest)
Expand Down
3 changes: 3 additions & 0 deletions lib/network.mli
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
module Network : sig
type t

(* getter for the parameters (aka weights and biases) of the entire network *)
val parameters : t -> Neuron.Neuron.out_t list list

val create : int -> int list -> t

val propagate_input : t -> Variable.Variable.t list -> Variable.Variable.t list
Expand Down
23 changes: 23 additions & 0 deletions test/network_operations.ml
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
open Smolgrad.Variable
open Smolgrad.Network

let test_neural_network_initialization () =
let nn = Network.create 3 [5; 2; 3] in

Alcotest.(check int)
"Network: Number of layers are expected"
3
(List.length (Network.parameters nn));
;;

let test_neural_network_propagation_of_input () =
let nn = Network.create 3 [5; 2; 6] in
let input_vector = [Variable.create 3.0; Variable.create 2.0; Variable.create 1.0] in
let output_vector = Network.propagate_input nn input_vector in

(* the dimension of the output vector should be equal to the number of neurons in the last layer *)
Alcotest.(check int)
"Network: Output vector has the expected number of dimensions"
6
(List.length output_vector);
;;
2 changes: 1 addition & 1 deletion test/neuron_operations.ml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
open Smolgrad.Neuron
open Smolgrad.Variable
open Smolgrad.Neuron

let test_neuron_initialization () =
let n = Neuron.create 5 true in
Expand Down
4 changes: 4 additions & 0 deletions test/smolgrad_tests.ml
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,8 @@ let () =
(* unit tests for Neuron *)
Neuron_operations.test_neuron_initialization ();
Neuron_operations.test_neuron_weights_reacting_to_input ();

(* unit tests for Network *)
Network_operations.test_neural_network_initialization ();
Network_operations.test_neural_network_propagation_of_input ();
;;

0 comments on commit 89d120b

Please sign in to comment.