From cddb5c1b39877161db7af00eeda76955cc989cc8 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Sun, 11 Sep 2016 10:12:10 -0700 Subject: [PATCH] Strict gradient boundary check (#44) --- nnvm/src/pass/gradient.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/nnvm/src/pass/gradient.cc b/nnvm/src/pass/gradient.cc index 2a6bd00e0e8ef..0f3f57fd7cf4d 100644 --- a/nnvm/src/pass/gradient.cc +++ b/nnvm/src/pass/gradient.cc @@ -115,6 +115,8 @@ Graph Gradient(Graph src) { } std::vector input_grads = grad_fun_map[ptr->op()] (mirror_map.size() == 0 ? ptr : mirror_map.at(ptr.get()), out_agg_grads); + CHECK_EQ((*rit)->inputs.size(), input_grads.size()) + << "Gradient function not returning enough gradient"; auto git = input_grads.begin(); for (auto it = (*rit)->inputs.begin(); it != (*rit)->inputs.end(); ++it, ++git) { output_grads[it->node.get()][it->index].grads.emplace_back(std::move(*git));