Skip to content

Commit

Permalink
Enable aux data (apache#24)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed May 26, 2018
1 parent bee4698 commit 1edb62b
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 2 deletions.
8 changes: 7 additions & 1 deletion nnvm/include/nnvm/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,10 @@ class IndexedGraph {
inline const std::vector<uint32_t>& input_nodes() const {
return input_nodes_;
}
/*! \return list of mutable nodes */
inline const std::unordered_set<uint32_t>& mutable_input_nodes() const {
return mutable_input_nodes_;
}
/*! \return list of output entries */
inline const std::vector<NodeEntry>& outputs() const {
return outputs_;
Expand All @@ -161,8 +165,10 @@ class IndexedGraph {
explicit IndexedGraph(const Graph& other);
// node pointers in CSR structure.
std::vector<Node> nodes_;
// index to input nodes
// index all to input nodes
std::vector<uint32_t> input_nodes_;
// index to mutable input nodes
std::unordered_set<uint32_t> mutable_input_nodes_;
// space to store the outputs entries
std::vector<NodeEntry> outputs_;
// mapping from node to index.
Expand Down
3 changes: 3 additions & 0 deletions nnvm/include/nnvm/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -368,12 +368,14 @@ inline Op& Op::set_attr_parser(std::function<void (NodeAttrs* attrs)> fn) { //
// member functions of OpMap
template<typename ValueType>
inline int OpMap<ValueType>::count(const Op* op) const {
if (op == nullptr) return 0;
const uint32_t idx = op->index_;
return idx < data_.size() ? data_[idx].second : 0;
}

template<typename ValueType>
inline const ValueType& OpMap<ValueType>::operator[](const Op* op) const {
CHECK(op != nullptr);
const uint32_t idx = op->index_;
CHECK(idx < data_.size() && data_[idx].second)
<< "Attribute " << attr_name_
Expand All @@ -383,6 +385,7 @@ inline const ValueType& OpMap<ValueType>::operator[](const Op* op) const {

template<typename ValueType>
inline const ValueType& OpMap<ValueType>::get(const Op* op, const ValueType& def_value) const {
if (op == nullptr) return def_value;
const uint32_t idx = op->index_;
if (idx < data_.size() && data_[idx].second) {
return data_[idx].first;
Expand Down
9 changes: 9 additions & 0 deletions nnvm/src/core/graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
* \brief Graph node data structure.
*/
#include <nnvm/graph.h>
#include <nnvm/op_attr_types.h>
#include <limits>

namespace nnvm {
Expand Down Expand Up @@ -57,12 +58,20 @@ IndexedGraph::IndexedGraph(const Graph &g) {
node2index_.at(e.node.get()), e.index, e.version});
}

static auto& fmutate_inputs = Op::GetAttr<FMutateInputs>("FMutateInputs");
std::unordered_set<uint32_t> mutable_inputs;
// setup array view
// input_entries_ and control_rptr must not change after this step.
const NodeEntry* iptr = dmlc::BeginPtr(input_entries_);
for (size_t nid = 0; nid < nodes_.size(); ++nid) {
nodes_[nid].inputs = array_view<NodeEntry>(
iptr + inputs_rptr[nid], iptr + inputs_rptr[nid + 1]);
if (nodes_[nid].source->op != nullptr &&
fmutate_inputs.count(nodes_[nid].source->op)) {
for (uint32_t i : fmutate_inputs[nodes_[nid].source->op](nodes_[nid].source->attrs)) {
mutable_input_nodes_.insert(nodes_[nid].inputs[i].node_id);
}
}
}
const uint32_t* cptr = dmlc::BeginPtr(control_deps_);
for (size_t nid = 0; nid < nodes_.size(); ++nid) {
Expand Down
2 changes: 1 addition & 1 deletion nnvm/src/pass/infer_shape_type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ NNVM_REGISTER_PASS(InferShape)
.set_body([](Graph ret) {
return InferAttr<TShape>(
std::move(ret), TShape(),
"FInferShape", "shape_args", "shape_attr_key",
"FInferShape", "shape_inputs", "shape_attr_key",
"shape", "shape_num_unknown_nodes",
[](const TShape& s) { return s.ndim() == 0; });
})
Expand Down

0 comments on commit 1edb62b

Please sign in to comment.