Skip to content

Commit

Permalink
tanh improvements
Browse files Browse the repository at this point in the history
Fix issues with `tanh` when the value is negative and improve code
quality.
  • Loading branch information
Alejandro Isaza committed Mar 28, 2016
1 parent 09c2c4b commit e5acb79
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 16 deletions.
4 changes: 2 additions & 2 deletions Source/Metal/LSTM.metal
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,11 @@ kernel void lstm_forward(const device float* input [[ buffer(0) ]],
}

const auto previousActivation = state[unit + batchElement * 2 * params.unitCount];
auto activation = sigmoid(forgetGate + 1) * previousActivation + sigmoid(inputGate) * safe_tanh(newInput);
auto activation = bc::sigmoid(forgetGate + 1) * previousActivation + bc::sigmoid(inputGate) * bc::tanh(newInput);
if (params.clipTo > 0) {
activation = clamp(activation, -params.clipTo, params.clipTo);
}
const auto out = sigmoid(outputGate) * safe_tanh(activation);
const auto out = bc::sigmoid(outputGate) * bc::tanh(activation);

output[unit + batchElement * params.unitCount] = out;
state[unit + batchElement * 2 * params.unitCount] = activation;
Expand Down
36 changes: 22 additions & 14 deletions Source/Metal/Utilities.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,33 +7,41 @@

#include <metal_stdlib>

namespace bc {

inline float safe_tanh(const float x) {
const auto P0 = -8.237728127e-1f;
const auto P1 = -3.831010665e-3f;
const auto Q0 = 2.471319654f;
const auto xLarge = 8.66433975699931636772f;
const auto xMedium = 5.4930614433405484570e-1f;
const auto xSmall = 4.22863966691620432990e-4f;
/// The metal standard library `tanh` function causes NaN errors on certain GPUs. This is likely due to a naïve implementation that uses the series expansion of tanh(x). This `tanh` implementation is based on "Accurate Hyperbolic Tangent Computation" by Nelson H. F. Beebe, http://www.math.utah.edu/~beebe/software/ieee/tanh.pdf
inline float tanh(const float x) {
static constexpr auto xLarge = 8.66433975699931636772f;
static constexpr auto xMedium = 5.4930614433405484570e-1f;
static constexpr auto xSmall = 4.22863966691620432990e-4f;

const auto sign = x > 0 ? 1 : -1;
const auto sign = metal::sign(x);
const auto absX = sign * x;

if (absX >= xLarge) {
return sign;
} else if (xMedium <= x) {
}

if (absX >= xMedium) {
const auto temp = 0.5 - 1 / (1 + metal::exp(2 * absX));
return sign * (temp + temp);
} else if (xSmall <= x) {
const auto g = absX * absX;
const auto R = g * (P1 * g + P0) / (g + Q0);
return x + x * R;
} else {
}

if (absX < xSmall) {
return sign * absX;
}

const auto P0 = -8.237728127e-1f;
const auto P1 = -3.831010665e-3f;
const auto Q0 = 2.471319654f;

const auto g = absX * absX;
const auto R = g * (P1 * g + P0) / (g + Q0);
return x + x * R;
}

/// Logistic sigmoid: 1 / (1 + e^(-x)).
///
/// - parameter x: The input value.
/// - returns: The sigmoid of `x`.
inline float sigmoid(const float x) {
    const auto decay = metal::exp(-x);
    return 1.0 / (1.0 + decay);
}

} // namespace bc

0 comments on commit e5acb79

Please sign in to comment.