Skip to content

Commit

Permalink
Restructure; add furthur operations
Browse files Browse the repository at this point in the history
  • Loading branch information
rounakdatta committed Apr 9, 2024
1 parent 6f584f9 commit fad34ca
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 16 deletions.
49 changes: 35 additions & 14 deletions lib/neuron.ml
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,48 @@ module Neuron = struct
mutable data : float;
mutable grad : float;
mutable backward : unit -> unit;

(* capturing the operator would be useful later when we add some viz *)
op : string;
prev : t list;
}

let create data operator = { data; grad = 0.;
backward = (fun () -> ()); op = operator; prev = []
let create data operator = {
data;
grad = 0.;
backward = (fun () -> ());
op = operator; prev = [];
}

let add a b =
let out = create (a.data +. b.data) "+" in
out.backward <- (fun () ->
a.grad <- a.grad +. out.grad;
b.grad <- b.grad +. out.grad;
let add base partner =
let resultant = create (base.data +. partner.data) "+" in

resultant.backward <- (fun () ->
base.grad <- base.grad +. resultant.grad;
partner.grad <- partner.grad +. resultant.grad;
);
out
resultant

let mul a b =
let out = create (a.data *. b.data) "*" in
out.backward <- (fun () ->
a.grad <- a.grad +. b.data *. out.grad;
b.grad <- b.grad +. a.data *. out.grad;
let mul base partner =
let resultant = create (base.data *. partner.data) "*" in

resultant.backward <- (fun () ->
base.grad <- base.grad +. partner.data *. resultant.grad;
partner.grad <- partner.grad +. base.data *. resultant.grad;
);
out
resultant

let exp base exponent =
let resultant = create (base.data ** exponent) "**" in

resultant.backward <- (fun () ->
base.grad <- base.grad +. exponent *. (base.data ** (exponent -. 1.)) *. resultant.grad;
)

let relu base =
let resultant = create (max 0. base.data) "relu" in

resultant.backward <- (fun () ->
base.grad <- base.grad +. (if base.data > 0. then resultant.grad else 0.);
)
end
14 changes: 12 additions & 2 deletions lib/neuron.mli
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,19 @@ module Neuron : sig
(* Constructor; constructs a unit neuron of a value and an operator. *)
val create : float -> string -> t

(* Adds two values, resulting in a new value. *)
(* Handles the gradient flows in addition operation. *)
val add : t -> t -> t

(* Multiplies two values, resulting in a new value. *)
(* Handles the gradient flows in multiplication operation. *)
val mul : t -> t -> t

(* Handles the gradient flows in exponent / power operation. *)
(* second argument is the exponent. *)
val exp : t -> int -> t

(* Handles the gradient flows in ReLU operation. *)
val relu : t -> t

(* Handles backpropagation of the gradients. *)
val backpropagate : t -> unit
end

0 comments on commit fad34ca

Please sign in to comment.