Working on GH-8: Make the first training example in the Sequential guide pass.

Adding the RMSProp algorithm.
cesarsouza committed Aug 27, 2017
1 parent 2b16adc commit 6200ec4
Showing 7 changed files with 109 additions and 18 deletions.
9 changes: 6 additions & 3 deletions Sources/Backends/Base/IBackend.cs
@@ -53,6 +53,10 @@ public interface IBackend : IDisposable

Tensor clip(Tensor norms, int v, int maxValue);

+ Tensor zeros(int[] shape, TFDataType dtype = Utils.DEFAULT_DTYPE, string name = null);
+
+ Tensor zeros(int?[] shape, TFDataType dtype = Utils.DEFAULT_DTYPE, string name = null);
+
float epsilon();

TFDataType floatx();
@@ -113,6 +117,7 @@ public interface IBackend : IDisposable

Tensor subtract<T>(Tensor a, T b);


Tensor subtract<T>(T a, Tensor b);


@@ -213,9 +218,7 @@ public interface IBackend : IDisposable
Tensor update_add(Tensor iterations, int v);


- object get_variable_shape(Tensor p);
-
- Tensor get_variable_shape(object s);
+ int?[] get_variable_shape(Tensor x);

Tensor sum(double v, Tensor tensor);

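Taken together, the new `zeros` overloads and the retyped `get_variable_shape` give backends a way to allocate zero-initialized state shaped like an existing variable, which is exactly what the optimizers below do. A minimal sketch, assuming `K` is some `IBackend` implementation and `p` an already-created parameter tensor (both placeholder names, not part of this commit):

```csharp
// Sketch only: `K` (an IBackend) and `p` (a parameter Tensor) are assumed to exist.
int?[] shape = K.get_variable_shape(p); // static shape; entries may be null for unknown dimensions
Tensor acc = K.zeros(shape);            // zero-filled state with the same shape as `p`
```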
27 changes: 19 additions & 8 deletions Sources/Backends/TensorFlowBackend.cs
@@ -326,15 +326,14 @@ public Function function(object inputs, List<Tensor> list, List<Tensor> updates,




- public object get_variable_shape(Tensor p)
- {
-     throw new NotImplementedException();
- }

- public Tensor get_variable_shape(object s)
+ /// <summary>
+ /// Returns the shape of a variable.
+ /// </summary>
+ ///
+ public int?[] get_variable_shape(Tensor x)
  {
-     throw new NotImplementedException();
+     // https://github.com/fchollet/keras/blob/f65a56fb65062c8d14d215c9f4b1015b97cc5bf3/keras/backend/tensorflow_backend.py#L2192
+     return int_shape(x);
  }

public List<Tensor> gradients(ILoss loss, object param)
@@ -919,6 +918,18 @@ public Tensor tensor(TFOutput output)
}


+ /// <summary>
+ /// Instantiates an all-zeros variable and returns it.
+ /// </summary>
+ /// <param name="shape">Tuple of integers, shape of returned Keras variable.</param>
+ /// <param name="dtype">Data type of returned Keras variable.</param>
+ /// <param name="name">String, name of returned Keras variable.</param>
+ /// <returns>A variable (including Keras metadata), filled with <c>0.0</c>.</returns>
+ public Tensor zeros(int?[] shape, TFDataType dtype = Utils.DEFAULT_DTYPE, string name = null)
+ {
+     return zeros(shape.Select(i => i.Value).ToArray(), dtype, name);
+ }
+
/// <summary>
/// Instantiates an all-zeros variable and returns it.
/// </summary>
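One detail of the new `int?[]` overload above: it unwraps every dimension with `i.Value`, so the shape must be fully known. A hedged illustration (`backend` is a placeholder for however a TensorFlowBackend instance is obtained; this commit does not show its construction):

```csharp
// Sketch: every dimension must be non-null, or `i.Value` throws InvalidOperationException.
Tensor ok = backend.zeros(new int?[] { 3, 4 });        // fine: shape fully known
// Tensor bad = backend.zeros(new int?[] { null, 4 }); // would throw when unwrapped
```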
5 changes: 5 additions & 0 deletions Sources/Engine/Topology/Tensor.cs
@@ -139,6 +139,11 @@ public static implicit operator TFOutput(Tensor t)
return b.K.add(a, b);
}

+ public static Tensor operator +(Tensor a, double b)
+ {
+     return a.K.add(a, b);
+ }
+
public static Tensor operator -(double a, Tensor b)
{
return b.K.subtract(a, b);
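The new `operator +(Tensor, double)` is what allows the RMSProp update further down to add a scalar epsilon directly to a tensor expression. A one-line sketch reusing names from that code (`new_a` is the accumulator tensor):

```csharp
// Sketch: the double operand is forwarded to a.K.add(a, b) by the new overload.
Tensor denom = K.sqrt(new_a) + 1e-8;
```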
4 changes: 2 additions & 2 deletions Sources/Optimizers/Base/Optimizer.cs
@@ -50,8 +50,8 @@ namespace KerasSharp.Optimizers
[DataContract]
public abstract class OptimizerBase
{
- private List<Tensor> updates;
- private List<Tensor> weights;
+ protected List<Tensor> updates;
+ protected List<Tensor> weights;
public double clipnorm;
public double clipvalue;

74 changes: 72 additions & 2 deletions Sources/Optimizers/RMSProp.cs
@@ -35,13 +35,83 @@ namespace KerasSharp.Optimizers

using static KerasSharp.Backends.Current;
using KerasSharp.Engine.Topology;
+ using System.Linq;

/// <summary>
/// RMSProp optimizer.
/// </summary>
+ ///
+ /// <remarks>
+ /// It is recommended to leave the parameters of this optimizer at their default values (except the learning rate, which can be freely tuned).
+ /// </remarks>
+ ///
[DataContract]
public class RMSProp : OptimizerBase, IOptimizer
{
-     public List<Tensor> get_updates(List<Tensor> collected_trainable_weights, Dictionary<Tensor, IWeightConstraint> constraints, Tensor total_loss)
+     private Tensor decay;
+     private double initial_decay;
+     private Tensor iterations;
+     private Tensor lr;
+     private Tensor rho;
+     private double epsilon;
+
+     // https://github.com/fchollet/keras/blob/f65a56fb65062c8d14d215c9f4b1015b97cc5bf3/keras/optimizers.py#L190
+
+     public RMSProp()
+         : this(lr: 0.001, rho: 0.9, epsilon: 1e-8, decay: 0.0)
+     {
+     }
+
+     public RMSProp(double lr, double rho = 0.9, double epsilon = 1e-8, double decay = 0.0)
+     {
+         this.lr = K.variable(lr, name: "lr");
+         this.rho = K.variable(rho, name: "rho");
+         this.epsilon = epsilon;
+         this.decay = K.variable(decay, name: "decay");
+         this.initial_decay = decay;
+         this.iterations = K.variable(0.0, name: "iterations");
+     }
+
+     public List<Tensor> get_updates(List<Tensor> parameters, Dictionary<Tensor, IWeightConstraint> constraints, Tensor loss)
      {
-         throw new NotImplementedException();
+         // https://github.com/fchollet/keras/blob/f65a56fb65062c8d14d215c9f4b1015b97cc5bf3/keras/optimizers.py#L221
+
+         List<Tensor> grads = this.get_gradients(loss, parameters);
+         List<int?[]> shapes = parameters.Select(p => K.get_variable_shape(p)).ToList();
+         List<Tensor> accumulators = shapes.Select(shape => K.zeros(shape)).ToList();
+         this.weights = accumulators;
+         this.updates = new List<Tensor>();
+
+         Tensor lr = this.lr;
+         if (this.initial_decay > 0)
+         {
+             lr = lr * (1.0 / (1.0 + this.decay * this.iterations));
+             this.updates.Add(K.update_add(this.iterations, 1));
+         }
+
+         for (int i = 0; i < parameters.Count; i++)
+         {
+             Tensor p = parameters[i];
+             Tensor g = grads[i];
+             Tensor a = accumulators[i];
+
+             // update accumulator
+             Tensor new_a = this.rho * a + (1.0 - this.rho) * K.square(g);
+             this.updates.Add(K.update(a, new_a));
+             Tensor new_p = p - lr * g / (K.sqrt(new_a) + this.epsilon);
+
+             // apply constraints
+             if (constraints.ContainsKey(p))
+             {
+                 IWeightConstraint c = constraints[p];
+                 new_p = c.Call(new_p);
+             }
+
+             this.updates.Add(K.update(p, new_p));
+         }
+
+         return this.updates;
      }
}
}
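For reference, the `get_updates` loop above implements the standard RMSProp rule from the linked Keras source. In equation form, with learning rate $\eta$ (`lr`), decay rate $\rho$ (`rho`), and gradient $g_t$ of the loss with respect to parameter $\theta$:

$$a_t = \rho\, a_{t-1} + (1 - \rho)\, g_t^2, \qquad \theta_t = \theta_{t-1} - \frac{\eta\, g_t}{\sqrt{a_t} + \epsilon}$$

When `decay` is positive, the learning rate is additionally annealed as $\eta_t = \eta / (1 + \text{decay} \cdot t)$, with $t$ tracked by the `update_add(iterations, 1)` bookkeeping.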
4 changes: 2 additions & 2 deletions Sources/Optimizers/SGD.cs
@@ -95,8 +95,8 @@ public List<Tensor> get_updates(List<Tensor> param, Dictionary<Tensor, IWeightCo
}

// momentum
- var shapes = param.Select(p => K.get_variable_shape(p)).ToList();
- List<Tensor> moments = shapes.Select(s => K.get_variable_shape(s)).ToList();
+ List<int?[]> shapes = param.Select(p => K.get_variable_shape(p)).ToList();
+ List<Tensor> moments = shapes.Select(s => K.zeros(s)).ToList();

this.weights = new[] { this.iterations }.Concat(moments).ToList();

4 changes: 3 additions & 1 deletion Tests/SequentialTest.cs
@@ -405,7 +405,9 @@ import numpy as np
Assert.AreEqual(model._constraints, model.model._constraints);
Assert.AreEqual(model._feed_inputs, model.model._feed_inputs);
Assert.AreEqual(model._feed_input_names, model.model._feed_input_names);
- Assert.AreEqual(model._feed_sample_weight_modes, model.model._feed_sample_weight_modes);
+ Assert.AreEqual(null, model._feed_sample_weight_modes);
+ Assert.AreEqual(1, model.model._feed_sample_weight_modes.Count);
+ Assert.AreEqual(null, model.model._feed_sample_weight_modes[0]);
Assert.AreEqual(model._initial_weights, model.model._initial_weights);
Assert.AreEqual(model._losses, model.model._losses);
Assert.AreEqual(model._non_trainable_weights, model.model._non_trainable_weights);
