-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add custom operators; add simple tests
- Loading branch information
1 parent
2b9c099
commit 5300009
Showing
6 changed files
with
55 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
(library | ||
(name nn) | ||
(modules Neuron Network)) | ||
(name smolgrad) | ||
(public_name smolgrad) | ||
(modules neuron)) |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,22 +1,28 @@ | ||
module Neuron : sig | ||
type t | ||
|
||
(* getter for the data *) | ||
val data : t -> float | ||
|
||
(* Constructor; constructs a unit neuron of a value and an operator. *) | ||
val create : float -> string -> t list -> t | ||
val create : ?op:string -> ?deps:t list -> float -> t | ||
|
||
(* Handles the gradient flows in addition operation. *) | ||
val add : t -> t -> t | ||
val ( + ) : t -> t -> t | ||
|
||
(* Handles the gradient flows in multiplication operation. *) | ||
val mul : t -> t -> t | ||
val ( * ) : t -> t -> t | ||
|
||
(* Handles the gradient flows in exponent / power operation. *) | ||
(* second argument is the exponent. *) | ||
val exp : t -> float -> t | ||
val ( ** ) : t -> float -> t | ||
|
||
(* Handles the gradient flows in ReLU operation. *) | ||
val relu : t -> t | ||
|
||
(* Handles backpropagation of the gradients for all the nodes connected to the specified base node. *) | ||
(* Handles backpropagation of the gradients for all the nodes connected to this as the base node. *) | ||
val backpropagate : t -> unit | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
(test | ||
(name smolgrad_tests) | ||
(libraries ounit2 smolgrad)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
open Smolgrad.Neuron | ||
|
||
let test_simple_operation () = | ||
let a = Neuron.create 4.0 in | ||
let b = Neuron.create 2.0 in | ||
|
||
let abba = Neuron.(a + b) in | ||
assert (Neuron.data abba = 6.0); | ||
;; | ||
|
||
(* here we open the Neuron module wide open locally, thereby allowing the clean custom `+` operator usage *) | ||
let test_custom_operator () = | ||
let open Neuron in | ||
let a = create 4.0 in | ||
let b = create 2.0 in | ||
|
||
let abba = a + b in | ||
assert (Neuron.data abba = 6.0); | ||
;; | ||
|
||
let () = | ||
test_simple_operation (); | ||
test_custom_operator (); | ||
;; |