diff --git a/src/metric.jl b/src/metric.jl index 3f35e7455ac5..a9e4db2dea4e 100644 --- a/src/metric.jl +++ b/src/metric.jl @@ -113,4 +113,45 @@ function reset!(metric :: Accuracy) metric.n_sample = 0 end +#=doc +.. class:: MSE + + Mean Squared Error. + + Calculates the mean squared error regression loss in one dimension. +=# + +type MSE <: AbstractEvalMetric + mse_sum :: Float64 + n_sample :: Int + MSE() = new(0.0, 0) +end + +function _update_single_output(metric :: MSE, label :: NDArray, pred :: NDArray) + label = copy(label) + pred = copy(pred) + + n_sample = size(pred)[end] + metric.n_sample += n_sample + + for i = 1:n_sample + metric.mse_sum += (label[i] - pred[i])^2 + end +end + +function update!(metric :: MSE, labels :: Vector{NDArray}, preds :: Vector{NDArray}) + @assert length(labels) == length(preds) + for i = 1:length(labels) + _update_single_output(metric, labels[i], preds[i]) + end +end + +function get(metric :: MSE) + return [(:MSE, metric.mse_sum / metric.n_sample)] +end + +function reset!(metric :: MSE) + metric.mse_sum = 0.0 + metric.n_sample = 0 +end