Skip to content

Commit

Permalink
[OP] Enable register via match tag (#57)
Browse files Browse the repository at this point in the history
* [OP] Enable register via match tag

* more docs on usage
  • Loading branch information
tqchen committed May 29, 2018
1 parent 8928f4e commit 7c4ed36
Show file tree
Hide file tree
Showing 3 changed files with 187 additions and 37 deletions.
38 changes: 22 additions & 16 deletions nnvm/example/src/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ NNVM_REGISTER_OP(reshape)
NNVM_REGISTER_OP(cast)
.describe("cast source type to target")
.set_num_inputs(1)
.include("ElementwiseOpAttr")
.set_attr_parser(
[](NodeAttrs* attrs) {
// parse attr parser to get target attribute
Expand All @@ -92,7 +93,6 @@ NNVM_REGISTER_OP(cast)
CHECK(is >> dtype);
attrs->parsed = std::move(dtype);
})
.set_attr<FInferShape>("FInferShape", SameShape)
.set_attr<FInferType>(
"FInferType", [](const NodeAttrs& attrs,
std::vector<int> *itype,
Expand All @@ -101,23 +101,10 @@ NNVM_REGISTER_OP(cast)
return true;
});

NNVM_REGISTER_OP(exp)
.describe("take exponential")
.set_num_inputs(1)
.set_attr<FInferShape>("FInferShape", SameShape)
.set_attr<FGradient>(
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds) {
return std::vector<NodeEntry>{
MakeNode("mul", n->attrs.name + "_grad",
{ograds[0], NodeEntry{n, 0, 0}})
};
});

NNVM_REGISTER_OP(identity)
.describe("identity function")
.set_num_inputs(1)
.set_attr<FInferShape>("FInferShape", SameShape)
.include("ElementwiseOpAttr")
.set_attr<FGradient>(
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds) {
Expand All @@ -128,7 +115,7 @@ NNVM_REGISTER_OP(add)
.describe("add two data together")
.set_num_inputs(2)
.add_alias("__add_symbol__")
.set_attr<FInferShape>("FInferShape", SameShape)
.include("ElementwiseOpAttr")
.set_attr<FInplaceOption>("FInplaceOption", InplaceIn0Out0)
.set_attr<FGradient>(
"FGradient", [](const NodePtr& n,
Expand All @@ -139,6 +126,7 @@ NNVM_REGISTER_OP(add)
NNVM_REGISTER_OP(mul)
.describe("multiply two data together")
.set_num_inputs(2)
.include("ElementwiseOpAttr")
.set_attr<FInferShape>("FInferShape", SameShape)
.set_attr<FInplaceOption>("FInplaceOption", InplaceIn0Out0)
.set_attr<FGradient>(
Expand Down Expand Up @@ -187,4 +175,22 @@ NNVM_REGISTER_OP(assign)
return std::vector<uint32_t>{0};
});

NNVM_REGISTER_OP_GROUP(ElementwiseOpAttr)
.set_attr<FInferShape>("FInferShape", SameShape);


NNVM_REGISTER_OP(exp)
.describe("take exponential")
.set_num_inputs(1)
.include("ElementwiseOpAttr")
.set_attr<FGradient>(
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds) {
return std::vector<NodeEntry>{
MakeNode("mul", n->attrs.name + "_grad",
{ograds[0], NodeEntry{n, 0, 0}})
};
});


} // namespace myproject
144 changes: 125 additions & 19 deletions nnvm/include/nnvm/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class Node;
struct NodeAttrs;
template<typename ValueType>
class OpMap;
class OpGroup;
class OpRegistryEntry;
using dmlc::ParamFieldInfo;

Expand All @@ -44,7 +45,13 @@ static const uint32_t kVarg = std::numeric_limits<uint32_t>::max();
* NNVM_REGISTER_OP(add)
* .describe("add two inputs together")
* .set_num_inputs(2)
* .set_attr<OpKernel>("gpu_kernel", AddKernel);
* .set_attr<OpKernel>("OpKernel<gpu>", AddKernel)
* .include("ElementwiseOpAttr");
*
* // can register attribute by group
* // all the ops that include the group get the attribute.
* NNVM_REGISTER_OP_GROUP(ElementwiseOpAttr)
* .set_attr<FInferShape>("FInferShape", ElementwiseInferShape);
*
* NNVM_REGISTER_OP(sub)
* .describe("substract one tensor from another")
Expand All @@ -53,7 +60,8 @@ static const uint32_t kVarg = std::numeric_limits<uint32_t>::max();
* // Can call regster multiple times in different files
* // to register different part of information
* NNVM_REGISTER_OP(sub)
* .set_attr<OpKernel>("gpu_kernel", SubKernel);
* .set_attr<OpKernel>("OpKernel<gpu>", SubKernel);
* .include("ElementwiseOpAttr");
*
* // get operators from registry.
* void my_function() {
Expand All @@ -65,7 +73,7 @@ static const uint32_t kVarg = std::numeric_limits<uint32_t>::max();
*
* // get additional registered information,
* // Assume user registered a OpKernel type attribute as gpu_kernel on each operator.
* const OpMap<OpKernel>& kernel = Op::GetAttr<OpKernel>("gpu_kernel");
* const OpMap<OpKernel>& kernel = Op::GetAttr<OpKernel>("OpKernel<gpu>");
* // we can get the kernel functions by using operator as key.
* auto add_kernel = kernel[add];
* auto sub_kernel = kernel[sub];
Expand Down Expand Up @@ -199,6 +207,23 @@ class Op {
* \return reference to self.
*/
inline Op& set_attr_parser(std::function<void (NodeAttrs* attrs)> fn); // NOLINT(*)
/*!
* \brief Register additional attributes to operator.
* \param attr_name The name of the attribute.
* \param value The value to be set.
* \param plevel The priority level of this set,
* an higher priority level attribute
* will replace lower priority level attribute.
* Must be bigger than 0.
*
* Cannot set with same plevel twice in the code.
*
* \tparam ValueType The type of the value to be set.
*/
template<typename ValueType>
inline Op& set_attr(const std::string& attr_name, // NOLINT(*)
const ValueType& value,
int plevel = 10);
/*!
* \brief Add another alias to this operator.
* The same Op can be queried with Op::Get(alias)
Expand All @@ -207,14 +232,13 @@ class Op {
*/
Op& add_alias(const std::string& alias); // NOLINT(*)
/*!
* \brief Register additional attributes to operator.
* \param attr_name The name of the attribute.
* \param value The value to be set.
* \tparam ValueType The type of the value to be set.
* \brief Include all the attributes from an registered op group.
* \param group_name The name of the group.
* \return reference to self.
*
* \sa NNVM_REGISTER_OP_GROUP
*/
template<typename ValueType>
inline Op& set_attr(const std::string& attr_name, // NOLINT(*)
const ValueType& value);
Op& include(const std::string& group_name);
/*!
* \brief Get an Op for a given operator name.
* Will raise an error if the op has not been registered.
Expand All @@ -235,6 +259,7 @@ class Op {
private:
template<typename ValueType>
friend class OpMap;
friend class OpGroup;
friend class dmlc::Registry<Op>;
// Program internal unique index of operator.
// Used to help index the program.
Expand All @@ -246,6 +271,13 @@ class Op {
// update the attribute OpMap
static void UpdateAttrMap(const std::string& key,
std::function<void(any*)> updater);
// add a trigger based on tag matching on certain tag attribute
// This will apply trigger on all the op such that
// include the corresponding group.
// The trigger will also be applied to all future registrations
// that calls include
static void AddGroupTrigger(const std::string& group_name,
std::function<void(Op*)> trigger);
};

/*!
Expand Down Expand Up @@ -285,14 +317,44 @@ class OpMap {
OpMap() = default;
};

/*!
* \brief auxiliary data structure used to
* set attributes to a group of operators
*/
class OpGroup {
public:
/*! \brief the tag key to be matched */
std::string group_name;
/*!
* \brief Register additional attributes to operator group.
* \param attr_name The name of the attribute.
* \param value The value to be set.
* \param plevel The priority level of this set,
* an higher priority level attribute
* will replace lower priority level attribute.
* Must be bigger than 0.
*
* Cannot set with same plevel twice in the code.
*
* \tparam ValueType The type of the value to be set.
*/
template<typename ValueType>
inline OpGroup& set_attr(const std::string& attr_name, // NOLINT(*)
const ValueType& value,
int plevel = 1);
};

// internal macros to make
#define NNVM_REGISTER_VAR_DEF(OpName) \
#define NNVM_REGISTER_VAR_DEF(OpName) \
static DMLC_ATTRIBUTE_UNUSED ::nnvm::Op & __make_ ## NnvmOp ## _ ## OpName

#define NNVM_REGISTER_GVAR_DEF(TagName) \
static DMLC_ATTRIBUTE_UNUSED ::nnvm::OpGroup __make_ ## NnvmOpGroup ## _ ## TagName

/*!
* \def NNVM_REGISTER_OP
* \brief Register
* This macro must be used under namespace dmlc, and only used once in cc file.
* \brief Register a new operator, or set attribute of the corresponding op.
*
* \param OpName The name of registry
*
* \code
Expand All @@ -308,6 +370,31 @@ class OpMap {
DMLC_STR_CONCAT(NNVM_REGISTER_VAR_DEF(OpName), __COUNTER__) = \
::dmlc::Registry<::nnvm::Op>::Get()->__REGISTER_OR_GET__(#OpName)

/*!
* \def NNVM_REGISTER_OP_GROUP
* \brief Register attribute to a group of operators.
* These attributes will be registered to Op that include the group.
*
* \param GroupName The name of the group.
*
* \code
*
* NNVM_REGISTER_OP(add)
* .include("ElementwiseOpAttr");
*
* // register same attributes to all the ops that include the group
* NNVM_REGISTER_OP_GROUP(ElementwiseOpAttr)
* .set_attr<FInferShape>("FInferShape", ElementwiseInferShape);
*
* NNVM_REGISTER_OP(mul)
* .include("ElementwiseOpAttr");
*
* \endcode
*/
#define NNVM_REGISTER_OP_GROUP(GroupName) \
DMLC_STR_CONCAT(NNVM_REGISTER_GVAR_DEF(GroupName), __COUNTER__) = \
::nnvm::OpGroup {#GroupName}

// implementations of template functions after this.
// member function of Op
template<typename ValueType>
Expand All @@ -330,9 +417,14 @@ inline const OpMap<ValueType>& Op::GetAttr(const std::string& key) {

template<typename ValueType>
inline Op& Op::set_attr( // NOLINT(*)
const std::string& attr_name, const ValueType& value) {
const std::string& attr_name,
const ValueType& value,
int plevel) {
CHECK_GT(plevel, 0)
<< "plevel in set_attr must be greater than 0";
// update the attribute map of the key by creating new empty if needed.
UpdateAttrMap(attr_name, [this, attr_name, value](any* pmap) {
UpdateAttrMap(attr_name,
[this, attr_name, value, plevel](any* pmap) {
// the callback is in lockscope so is threadsafe.
if (pmap->empty()) {
OpMap<ValueType> pm;
Expand All @@ -353,15 +445,18 @@ inline Op& Op::set_attr( // NOLINT(*)
std::make_pair(ValueType(), 0));
}
std::pair<ValueType, int>& p = vec[index_];
CHECK(p.second == 0)
CHECK(p.second != plevel)
<< "Attribute " << attr_name
<< " of operator " << this->name
<< " is already registered.";
vec[index_] = std::make_pair(value, 1);
<< " is already registered with same plevel=" << plevel;
if (p.second < plevel) {
vec[index_] = std::make_pair(value, plevel);
}
});
return *this;
}


inline Op& Op::describe(const std::string& descr) { // NOLINT(*)
this->description = descr;
return *this;
Expand Down Expand Up @@ -409,7 +504,7 @@ template<typename ValueType>
inline int OpMap<ValueType>::count(const Op* op) const {
if (op == nullptr) return 0;
const uint32_t idx = op->index_;
return idx < data_.size() ? data_[idx].second : 0;
return idx < data_.size() ? (data_[idx].second != 0) : 0;
}

template<typename ValueType>
Expand All @@ -433,6 +528,17 @@ inline const ValueType& OpMap<ValueType>::get(const Op* op, const ValueType& def
}
}

template<typename ValueType>
inline OpGroup& OpGroup::set_attr(const std::string& attr_name,
const ValueType& value,
int plevel) {
auto trigger = [attr_name, value, plevel](Op* op) {
op->set_attr<ValueType>(attr_name, value, plevel);
};
Op::AddGroupTrigger(group_name, trigger);
return *this;
}

} // namespace nnvm

#endif // NNVM_OP_H_
Loading

0 comments on commit 7c4ed36

Please sign in to comment.