Skip to content

Commit

Permalink
Merge pull request apache#25 from Andy-P/pull-request/f5f27779
Browse files Browse the repository at this point in the history
Adds Mean Squared Error evaluation metric
  • Loading branch information
pluskid committed Nov 12, 2015
2 parents 3226713 + f5f2777 commit 94e17a6
Showing 1 changed file with 41 additions and 0 deletions.
41 changes: 41 additions & 0 deletions src/metric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,4 +113,45 @@ function reset!(metric :: Accuracy)
metric.n_sample = 0
end

#=doc
.. class:: MSE
Mean Squared Error.
Calculates the mean squared error regression loss in one dimension.
=#

type MSE <: AbstractEvalMetric
mse_sum :: Float64
n_sample :: Int

MSE() = new(0.0, 0)
end

function _update_single_output(metric :: MSE, label :: NDArray, pred :: NDArray)
label = copy(label)
pred = copy(pred)

n_sample = size(pred)[end]
metric.n_sample += n_sample

for i = 1:n_sample
metric.mse_sum += (label[i] - pred[i])^2
end
end

function update!(metric :: MSE, labels :: Vector{NDArray}, preds :: Vector{NDArray})
@assert length(labels) == length(preds)
for i = 1:length(labels)
_update_single_output(metric, labels[i], preds[i])
end
end

function get(metric :: MSE)
return [(:MSE, metric.mse_sum / metric.n_sample)]
end

function reset!(metric :: MSE)
metric.mse_sum = 0.0
metric.n_sample = 0
end

0 comments on commit 94e17a6

Please sign in to comment.