diff --git a/Makefile b/Makefile index 9d5740ff0c7e..2c9eb787889b 100644 --- a/Makefile +++ b/Makefile @@ -81,16 +81,16 @@ engine.o: src/dag_engine/simple_engine.cc narray.o: src/narray/narray.cc narray_op_cpu.o: src/narray/narray_op_cpu.cc src/narray/narray_op-inl.h narray_op_gpu.o: src/narray/narray_op_gpu.cu src/narray/narray_op-inl.h -static_operator.o: src/static_operator/static_operator.cc -static_operator_cpu.o: src/static_operator/static_operator_cpu.cc -static_operator_gpu.o: src/static_operator/static_operator_gpu.cu +static_operator.o: src/operator/static_operator/static_operator.cc +static_operator_cpu.o: src/operator/static_operator/static_operator_cpu.cc +static_operator_gpu.o: src/operator/static_operator/static_operator_gpu.cu symbol.o: src/symbol/symbol.cc static_graph.o : src/symbol/static_graph.cc registry.o: src/registry.cc c_api.o: src/c_api.cc operator.o: src/operator/static_operator_wrapper.cc -fully_connect_op_cpu.o: src/static_operator/fully_connect_op.cc -fully_connect_op_gpu.o: src/static_operator/fully_connect_op.cu +fully_connect_op_cpu.o: src/operator/static_operator/fully_connect_op.cc +fully_connect_op_gpu.o: src/operator/static_operator/fully_connect_op.cu lib/libmxnet.a: $(OBJ) $(OBJCXX11) $(CUOBJ) diff --git a/include/mxnet/atomic_symbol.h b/include/mxnet/atomic_symbol.h deleted file mode 100644 index cfc8a2eb6c20..000000000000 --- a/include/mxnet/atomic_symbol.h +++ /dev/null @@ -1,94 +0,0 @@ -/*! - * Copyright (c) 2015 by Contributors - * \file atomic_symbol.h - * \brief atomic symbol interface of mxnet - */ -#ifndef MXNET_ATOMIC_SYMBOL_H_ -#define MXNET_ATOMIC_SYMBOL_H_ - -#include -#include -#include -#include -#include "./base.h" -#include "./tensor_blob.h" - -namespace mxnet { -// forward declare StaticOperator -class StaticOperator; -/*! - * \brief AtomicSymbol is the base class of all atomic symbols. - * This is not meant to be used by user, it should be wrapped in Symbol, so that the same instance - * of AtomicSymbol can be shared in the graphs of different Symbols - */ -class AtomicSymbol { - public: - /*! - * \brief virtual destructor - */ - virtual ~AtomicSymbol() {} - /*! \brief get the descriptions of inputs for this symbol */ - virtual std::vector ListArguments() const { - // default implementation returns "data" - return std::vector(1, std::string("data")); - } - /*! \brief get the descriptions of outputs for this symbol */ - virtual std::vector ListReturns() const { - // default implementation returns "output" - return std::vector(1, std::string("output")); - } - /*! \brief number of outputs of the symbol */ - virtual int NumReturns() const { - return 1; - } - /*! - * \brief set param for the symbol from string - * \param name parameter name - * \param val string for the configuration - */ - virtual void SetParam(const char *name, const char *val) {} - /*! - * \brief infer the shapes of outputs and unknown input arguments - * \param in_shape the shape of input arguments of the operator - * this should be of same length as the vector returned by DescribeArgs - * in_shape allows unknown elements, which are checked by shape.ndim() == 0. - * For unknown shapes, InferShape will try to fill in the correct Shape in in_shape - * For known shapes, InferShape will check shape consistency - * - * common practice: set the shape of data input, and usually weight's shape can be infered - * - * \param out_shape the shape of outputs of the operator - * InferShape will modify the vector to fill output TShape - * \return if the shape inference is successful, return true, else return false. - */ - virtual bool InferShape(std::vector *in_shape, std::vector *out_shape) const = 0; - /*! - * \brief Copy this AtomicSymbol and returns a pointer to the copied object. - * this is a virtual function because different subclass of AtomicSymbol would copy differently. - * \return a pointer to the copied atomic symbol - */ - virtual AtomicSymbol* Copy() const = 0; - /*! - * \brief Bind this AtomicSymbol to a context and get back a static operator - * Bind function of AtomicSymbol does not return NArrayOperator, but static operator. - * Calling bind from the Symbol wrapper would generate a NArrayOperator. - */ - template - StaticOperator* Bind(Context ctx) const; - /*! - * \brief return the type string of the atomic symbol - * subclasses override this function. - */ - virtual std::string TypeString() const = 0; - friend class Symbol; - - /*! - * \brief create atomic symbol by type name - * \param type_name the type string of the AtomicSymbol - * \return a new constructed AtomicSymbol - */ - static AtomicSymbol *Create(const char* type_name); -}; - -} // namespace mxnet -#endif // MXNET_ATOMIC_SYMBOL_H_ diff --git a/include/mxnet/base.h b/include/mxnet/base.h index 67c3a1b24b74..6256947faf5c 100644 --- a/include/mxnet/base.h +++ b/include/mxnet/base.h @@ -74,5 +74,47 @@ enum Property { kForwardRequireRnd = 2, }; +/*! \brief context information about the execution enviroment */ +struct Context { + /*! \brief the device type we run the op can be cpu::kDevMask or gpu::kDevMask */ + int dev_mask; + /*! \brief device id we are going to run it on */ + int dev_id; + /*! \brief constructor */ + Context() : dev_mask(cpu::kDevMask), dev_id(0) {} + /*! + * \brief constructor of context + * \param dev_mask the device mask + * \param dev_id the device id + */ + Context(int dev_mask, int dev_id) + : dev_mask(dev_mask), dev_id(dev_id) {} + /*! + * \brief check if current context equals another one + * \param b another context to compare + * \return whether dev mask and id are same + */ + inline bool operator==(const Context &b) const { + return dev_mask == b.dev_mask && dev_id == b.dev_id; + } +}; + + +/*! + * \brief execution context provides the information needed + * in runtime to actually execute the operation + */ +struct RunContext { + /*! + * \brief the stream of the device, can be NULL or Stream* in GPU mode + */ + void *stream; +}; + +/*! \brief dynamic shape type */ +typedef mshadow::TShape TShape; +/*! \brief storage container type */ +typedef mshadow::TBlob TBlob; + } // namespace mxnet #endif // MXNET_BASE_H_ diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 9ad75b4e5954..29c9691e8ff5 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -245,6 +245,16 @@ MXNET_DLL int MXSymbolCreateFromAtomicSymbol(AtomicSymbolCreator creator, * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXSymbolCreateVariable(const char *name, SymbolHandle *out); +/*! + * \brief Create a Symbol by grouping list of symbols together + * \param num_symbols number of symbols to be grouped + * \param symbols array of symbol handles + * \param out pointer to the created symbol handle + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXSymbolCreateGroup(mx_uint num_symbols, + SymbolHandle *symbols, + SymbolHandle *out); /*! * \brief Create symbol from config. * \param cfg configuration string diff --git a/include/mxnet/dag_engine.h b/include/mxnet/dag_engine.h index cf4008f9eb95..9e65d6108f60 100644 --- a/include/mxnet/dag_engine.h +++ b/include/mxnet/dag_engine.h @@ -15,7 +15,6 @@ #include #include #include "./base.h" -#include "./tensor_blob.h" namespace mxnet { /*! diff --git a/include/mxnet/narray.h b/include/mxnet/narray.h index 4e7b4448e667..1c4098ce7ac8 100644 --- a/include/mxnet/narray.h +++ b/include/mxnet/narray.h @@ -10,7 +10,6 @@ #include #include "./base.h" #include "./storage.h" -#include "./tensor_blob.h" #include "./dag_engine.h" // check c++11 #if DMLC_USE_CXX11 == 0 diff --git a/include/mxnet/operator.h b/include/mxnet/operator.h index 97f8ca035ebd..c1a53df61fa9 100644 --- a/include/mxnet/operator.h +++ b/include/mxnet/operator.h @@ -10,12 +10,70 @@ #include #include #include "./base.h" -#include "./tensor_blob.h" -#include "./static_operator.h" +#if DMLC_USE_CXX11 == 1 #include "./narray.h" #include "./dag_engine.h" - +#endif namespace mxnet { +/*! + * \brief StaticOperator interface + * StaticOperator is a stateful object that can be used to call forward and backprop + * + * This interface relies on pre-allocated memory in TBlob, the caller need to set + * the memory region in TBlob correctly before calling Forward and Backward + * + * \sa TBlob, TShape + */ +class StaticOperator { + public: + /*! \brief destructor */ + virtual ~StaticOperator() {} + /*! + * \brief describe property of op + * \return a bit map in int + */ + virtual int DescribeProperty() const { + // default most of layer only conatin internal state + return kContainInteralState; + } + /*! + * \brief perform a forward operation of StaticOperator, save the output to TBlob + * \param opt option on Forward such as whether this is training phase + * \param ctx runtime context + * \param in_data array of input data, it is const + * \param out_data array of output data, + * the space of TBlob in out_data must be pre-allocated with InferShape + */ + virtual void Forward(Option opt, + RunContext ctx, + const std::vector &in_data, + const std::vector &out_data) = 0; + /*! + * \brief perform a backward operation of the StaticOperator to get the gradient + * \param ctx runtime context + * \param grad_next the gradient value we get from output of the StaticOperator + * \param in_data the array of input data + * \param out_grad array of output gradient, there could be three possible TBlob + * in the each element in the array + * \param req request types of the gradient saving operation + * only inplace will change input data + * \sa GradReqType + */ + virtual void Backward(RunContext ctx, + const std::vector &grad_next, + const std::vector &in_data, + const std::vector &out_grad, + const std::vector &req) = 0; + /*! + * \brief factory function, create a new StaticOperator + * \param type the type of StaticOperator + * \param ctx the context device type of StaticOperator + * \return a pointer of StaticOperator object + */ + static StaticOperator *Create(const char *type, Context ctx); +}; + +#if DMLC_USE_CXX11 == 1 /*! * \brief operator interface * operator is an object can be scheduled by DAG engine directly. @@ -74,5 +132,6 @@ class Operator { */ static Operator *CreateWrapper(StaticOperator *op, Context ctx); }; // class operator +#endif } // namespace mxnet #endif // MXNET_OPERATOR_H_ diff --git a/include/mxnet/registry.h b/include/mxnet/registry.h index dcc87b6ee232..04a3eb1abb51 100644 --- a/include/mxnet/registry.h +++ b/include/mxnet/registry.h @@ -12,7 +12,7 @@ #include #include "./base.h" #include "./narray.h" -#include "./symbol.h" +#include "./symbolic.h" namespace mxnet { diff --git a/include/mxnet/static_graph.h b/include/mxnet/static_graph.h deleted file mode 100644 index 1e3b8352de83..000000000000 --- a/include/mxnet/static_graph.h +++ /dev/null @@ -1,84 +0,0 @@ -/*! - * Copyright (c) 2015 by Contributors - * \file static_graph.h - * \brief The static graph of symbols - */ -#ifndef MXNET_STATIC_GRAPH_H_ -#define MXNET_STATIC_GRAPH_H_ - -#include -#include -#include -#include "./base.h" -#include "./atomic_symbol.h" - -namespace mxnet { -/*! - * \brief StaticGraph is the configuration of computation graphs. - * This is the "configuration file" of mxnet. - * It can be converted to/from Symbol, and can be used to bind to operators. - */ -class StaticGraph { - public: - /*! \brief represents a data in the graph */ - struct DataEntry { - /*! \brief the source node id in the computation graph */ - uint32_t source_id; - /*! - * \brief index of output from the source. - * If index == -1, it represents all the outputs. - */ - int32_t index; - }; - /*! \brief Operation Node in static graph */ - struct Node { - /*! \brief wrapped atomic symbol */ - std::unique_ptr sym; - /*! \brief name of the node */ - std::string name; - /*! \brief inputs (node_id, index) for of the nodes*/ - std::vector inputs; - }; - /*! \brief all nodes in the graph */ - std::vector nodes; - /*! \brief index is nodes that correspods to arguments */ - std::vector arg_nodes; - /*! \brief outputs(heads) of the graph */ - std::vector outputs; - // funtions to help inference in static graph - /*! - * \brief Perform a topological sort on the graph - * \return a topological order of node indices. - */ - std::vector TopoSort() const; - /*! - * \brief infer the node shapes in the computation graph. - * - * When calling this function, user can setup the shape information known into right position. - * Unknown shape are indicated by shape.ndim() == 0. - * - * \param topo_order The topological order of node index, as created by TopoSort. - * \param node_out_shapes The shapes of the each outputs of nodes in the graph. - * \return if the shape inference is successful, return true, else return false. - */ - bool InferNodeShapes(const std::vector &topo_order, - std::vector > *node_out_shapes) const; - /*! - * \brief infer the shapes of outputs and unknown input arguments - * \param in_shape the shape of input arguments of the operator - * this should be of same length as the vector returned by ListArguments - * in_shape allows unknown elements, which are checked by shape.ndim() == 0. - * For unknown shapes, InferShape will try to fill in the correct Shape in in_shape - * For known shapes, InferShape will check shape consistency - * - * common practice: set the shape of data input, and usually weight's shape can be infered - * - * \param out_shape the shape of outputs of the operator - * InferShape will modify the vector to fill output TShape - * \return if the shape inference is successful, return true, else return false. - */ - bool InferShape(std::vector *in_shape, - std::vector *out_shape) const; -}; -} // namespace mxnet -#endif // MXNET_STATIC_GRAPH_H_ diff --git a/include/mxnet/static_operator.h b/include/mxnet/static_operator.h deleted file mode 100644 index e3d4d68d9d85..000000000000 --- a/include/mxnet/static_operator.h +++ /dev/null @@ -1,73 +0,0 @@ -/*! - * Copyright (c) 2015 by Contributors - * \file static_operator.h - * \brief static operator interface of mxnet - */ -#ifndef MXNET_STATIC_OPERATOR_H_ -#define MXNET_STATIC_OPERATOR_H_ -// this file will be seen by cuda, no c++11 for now -#include -#include -#include "./base.h" -#include "./tensor_blob.h" - -namespace mxnet { -/*! - * \brief StaticOperator interface - * StaticOperator is a stateful object that can be used to call forward and backprop - * - * This interface relies on pre-allocated memory in TBlob, the caller need to set - * the memory region in TBlob correctly before calling Forward and Backward - * - * \sa TBlob, TShape - */ -class StaticOperator { - public: - /*! \brief destructor */ - virtual ~StaticOperator() {} - /*! - * \brief describe property of op - * \return a bit map in int - */ - virtual int DescribeProperty() const { - // default most of layer only conatin internal state - return kContainInteralState; - } - /*! - * \brief perform a forward operation of StaticOperator, save the output to TBlob - * \param opt option on Forward such as whether this is training phase - * \param ctx runtime context - * \param in_data array of input data, it is const - * \param out_data array of output data, - * the space of TBlob in out_data must be pre-allocated with InferShape - */ - virtual void Forward(Option opt, - RunContext ctx, - const std::vector &in_data, - const std::vector &out_data) = 0; - /*! - * \brief perform a backward operation of the StaticOperator to get the gradient - * \param ctx runtime context - * \param grad_next the gradient value we get from output of the StaticOperator - * \param in_data the array of input data - * \param out_grad array of output gradient, there could be three possible TBlob - * in the each element in the array - * \param req request types of the gradient saving operation - * only inplace will change input data - * \sa GradReqType - */ - virtual void Backward(RunContext ctx, - const std::vector &grad_next, - const std::vector &in_data, - const std::vector &out_grad, - const std::vector &req) = 0; - /*! - * \brief factory function, create a new StaticOperator - * \param type the type of StaticOperator - * \param ctx the context device type of StaticOperator - * \return a pointer of StaticOperator object - */ - static StaticOperator *Create(const char *type, Context ctx); -}; -} // namespace mxnet -#endif // MXNET_STATIC_OPERATOR_H_ diff --git a/include/mxnet/storage.h b/include/mxnet/storage.h index 3bb123b44816..6afc2885a746 100644 --- a/include/mxnet/storage.h +++ b/include/mxnet/storage.h @@ -6,7 +6,6 @@ #ifndef MXNET_STORAGE_H_ #define MXNET_STORAGE_H_ #include "./base.h" -#include "./tensor_blob.h" namespace mxnet { /*! \brief memory allocator of storage */ diff --git a/include/mxnet/symbol.h b/include/mxnet/symbolic.h similarity index 51% rename from include/mxnet/symbol.h rename to include/mxnet/symbolic.h index 18b4466706d4..0b3aead32e54 100644 --- a/include/mxnet/symbol.h +++ b/include/mxnet/symbolic.h @@ -1,27 +1,169 @@ /*! - * Copyright (c) 2015 by Contributors - * \file symbol.h - * \brief symbolic interface of mxnet - */ -#ifndef MXNET_SYMBOL_H_ -#define MXNET_SYMBOL_H_ + * Copyright (c) 2015 by Contributors + * \file symbolic.h + * \brief + * \author Bing Xu +*/ + +#ifndef MXNET_SYMBOLIC_H_ +#define MXNET_SYMBOLIC_H_ -#include -#include #include #include -#include #include -#include -#include +#include +#if DMLC_USE_CXX11 == 1 #include #include +#endif #include "./base.h" -#include "./tensor_blob.h" -#include "./operator.h" -#include "./static_graph.h" namespace mxnet { +// forward declare StaticOperator +class StaticOperator; +/*! + * \brief AtomicSymbol is the base class of all atomic symbols. + * This is not meant to be used by user, it should be wrapped in Symbol, so that the same instance + * of AtomicSymbol can be shared in the graphs of different Symbols + */ +class AtomicSymbol { + public: + /*! + * \brief virtual destructor + */ + virtual ~AtomicSymbol() {} + /*! \brief get the descriptions of inputs for this symbol */ + virtual std::vector ListArguments() const { + // default implementation returns "data" + return std::vector(1, std::string("data")); + } + /*! \brief get the descriptions of outputs for this symbol */ + virtual std::vector ListReturns() const { + // default implementation returns "output" + return std::vector(1, std::string("output")); + } + /*! \brief number of outputs of the symbol */ + virtual int NumReturns() const { + return 1; + } + /*! + * \brief set param for the symbol from string + * \param name parameter name + * \param val string for the configuration + */ + virtual void SetParam(const char *name, const char *val) {} + /*! + * \brief infer the shapes of outputs and unknown input arguments + * \param in_shape the shape of input arguments of the operator + * this should be of same length as the vector returned by DescribeArgs + * in_shape allows unknown elements, which are checked by shape.ndim() == 0. + * For unknown shapes, InferShape will try to fill in the correct Shape in in_shape + * For known shapes, InferShape will check shape consistency + * + * common practice: set the shape of data input, and usually weight's shape can be infered + * + * \param out_shape the shape of outputs of the operator + * InferShape will modify the vector to fill output TShape + * \return if the shape inference is successful, return true, else return false. + */ + virtual bool InferShape(std::vector *in_shape, std::vector *out_shape) const = 0; + /*! + * \brief Copy this AtomicSymbol and returns a pointer to the copied object. + * this is a virtual function because different subclass of AtomicSymbol would copy differently. + * \return a pointer to the copied atomic symbol + */ + virtual AtomicSymbol* Copy() const = 0; + /*! + * \brief Bind this AtomicSymbol to a context and get back a static operator + * Bind function of AtomicSymbol does not return NArrayOperator, but static operator. + * Calling bind from the Symbol wrapper would generate a NArrayOperator. + */ + template + StaticOperator* Bind(Context ctx) const; + /*! + * \brief return the type string of the atomic symbol + * subclasses override this function. + */ + virtual std::string TypeString() const = 0; + friend class Symbol; + + /*! + * \brief create atomic symbol by type name + * \param type_name the type string of the AtomicSymbol + * \return a new constructed AtomicSymbol + */ + static AtomicSymbol *Create(const char* type_name); +}; +#if DMLC_USE_CXX11 == 1 +/*! + * \brief StaticGraph is the configuration of computation graphs. + * This is the "configuration file" of mxnet. + * It can be converted to/from Symbol, and can be used to bind to operators. + */ +class StaticGraph { + public: + /*! \brief represents a data in the graph */ + struct DataEntry { + /*! \brief the source node id in the computation graph */ + uint32_t source_id; + /*! + * \brief index of output from the source. + * If index == -1, it represents all the outputs. + */ + int32_t index; + }; + /*! \brief Operation Node in static graph */ + struct Node { + /*! \brief wrapped atomic symbol */ + std::unique_ptr sym; + /*! \brief name of the node */ + std::string name; + /*! \brief inputs (node_id, index) for of the nodes*/ + std::vector inputs; + }; + /*! \brief all nodes in the graph */ + std::vector nodes; + /*! \brief index is nodes that correspods to arguments */ + std::vector arg_nodes; + /*! \brief outputs(heads) of the graph */ + std::vector outputs; + // funtions to help inference in static graph + /*! + * \brief Perform a topological sort on the graph + * \return a topological order of node indices. + */ + std::vector TopoSort() const; + /*! + * \brief infer the node shapes in the computation graph. + * + * When calling this function, user can setup the shape information known into right position. + * Unknown shape are indicated by shape.ndim() == 0. + * + * \param topo_order The topological order of node index, as created by TopoSort. + * \param node_out_shapes The shapes of the each outputs of nodes in the graph. + * \return if the shape inference is successful, return true, else return false. + */ + bool InferNodeShapes(const std::vector &topo_order, + std::vector > *node_out_shapes) const; + /*! + * \brief infer the shapes of outputs and unknown input arguments + * \param in_shape the shape of input arguments of the operator + * this should be of same length as the vector returned by ListArguments + * in_shape allows unknown elements, which are checked by shape.ndim() == 0. + * For unknown shapes, InferShape will try to fill in the correct Shape in in_shape + * For known shapes, InferShape will check shape consistency + * + * common practice: set the shape of data input, and usually weight's shape can be infered + * + * \param out_shape the shape of outputs of the operator + * InferShape will modify the vector to fill output TShape + * \return if the shape inference is successful, return true, else return false. + */ + bool InferShape(std::vector *in_shape, + std::vector *out_shape) const; +}; +#endif +#if DMLC_USE_CXX11 == 1 /*! * \brief Symbol is used to represent dynamically generated symbolic computation graph. * @@ -60,7 +202,7 @@ class Symbol { * \param index index of multi output * \return the symbol corresponds to the indexed element. */ - Symbol operator[] (int index) const; + Symbol operator[] (size_t index) const; /*! * \brief Compose the symbol with arguments, this changes current symbol. * @@ -82,6 +224,14 @@ class Symbol { */ void Compose(const std::unordered_map& kwargs, const std::string& name); + /*! + * \brief Convert a list of symbols into static graph + * + * The user can go further to call bind function on static graph + * + * \param out_graph the pointer holder of the output graph + */ + void ToStaticGraph(StaticGraph *out_graph) const; /*! * \brief Apply the symbol as a function, compose with arguments * \param args positional arguments for the symbol @@ -101,7 +251,7 @@ class Symbol { * \return a new symbol which is the composition of current symbol with its arguments */ inline Symbol operator () (const std::unordered_map& kwargs, - const std::string& name) { + const std::string& name) const { Symbol s = this->Copy(); s.Compose(kwargs, name); return s; @@ -121,11 +271,18 @@ class Symbol { * \return if the shape inference is successful, return true, else return false. */ inline bool InferShape(std::vector *in_shape, - std::vector *out_shape) { + std::vector *out_shape) const { StaticGraph g; - Symbol::Convert({*this}, &g); + this->ToStaticGraph(&g); return g.InferShape(in_shape, out_shape); } + /*! + * \brief get number of outputs of this symbol + * \return number of outputs + */ + inline size_t NumReturns() const { + return heads_.size(); + } /*! * \brief create Symbol by wrapping AtomicSymbol * This function takes the ownership of atomic_symbol. @@ -136,20 +293,24 @@ class Symbol { */ static Symbol Create(AtomicSymbol *atomic_symbol); /*! - * \brief create equivalence of symbols from static graphs + * \brief create equivalence of symbol from static graphs * \param graph the static graph - * \return list of Symbols representing outputs of the graph + * \return the created symbol */ - static std::vector Create(const StaticGraph &graph); + static Symbol Create(const StaticGraph &graph); + /*! - * \brief Convert a list of symbols into static graph - * - * The user can go further to call bind function on static graph - * - * \param heads the heads of the graph - * \param out_graph the pointer holder of the output graph + * \brief create equivalence of symbol by grouping the symbols together + * \param symbols list of symbols + * \return the grouped symbol */ - static void Convert(const std::vector &heads, StaticGraph *out_graph); + static Symbol CreateGroup(const std::vector &symbols) { + Symbol ret; + for (const auto &s : symbols) { + ret.heads_.insert(ret.heads_.end(), s.heads_.begin(), s.heads_.end()); + } + return std::move(ret); + } /*! * \brief create variable symbol node * \param name name of the variable @@ -157,7 +318,7 @@ class Symbol { */ inline static Symbol CreateVariable(const std::string &name) { Symbol s; - s.head_ = DataEntry(std::make_shared(nullptr, name), 0); + s.heads_.push_back(DataEntry(std::make_shared(nullptr, name), 0)); return std::move(s); } @@ -170,13 +331,12 @@ class Symbol { std::shared_ptr source; /*! * \brief index of output from the source. - * If index == -1, it represents all the outputs. */ - int index; + uint32_t index; /*! \brief enabled default copy constructor */ DataEntry() {} /*! \brief constructor from index */ - DataEntry(std::shared_ptr source, int index) + DataEntry(std::shared_ptr source, uint32_t index) : source(source), index(index) {} }; /*! @@ -212,18 +372,14 @@ class Symbol { return sym == nullptr; } }; - /*! \brief the head node of the Symbol */ - DataEntry head_; - - private: - /*! \brief DFS Visit for symbol with single head - * This function call is specail case for DFSVisit_ - * \param fvisit function applied for each visit. - * \tparam FVisit visiting function type + /*! + * \brief the head nodes of Symbols + * This head is only effective when */ - template - inline void DFSVisit(FVisit fvisit) const { - DFSVisit({*this}, fvisit); + std::vector heads_; + /*! \return whwther the symbol is AtomicSymbol */ + inline bool is_atomic() const { + return heads_.size() == 1 && heads_[0].source->is_atomic(); } /*! * \brief Visit all the nodes in left-to-right depth first order. @@ -235,13 +391,12 @@ class Symbol { * \tparam FVisit visiting function type */ template - static inline void DFSVisit(const std::vector &heads, - FVisit fvisit) { + inline void DFSVisit(FVisit fvisit) const { std::vector stack; std::unordered_set visited; // put the head into the graph - for (auto &head : heads) { - Node *ptr = head.head_.source.get(); + for (auto &head : heads_) { + Node *ptr = head.source.get(); if (visited.count(ptr) == 0) { stack.push_back(ptr); visited.insert(ptr); @@ -267,5 +422,6 @@ class Symbol { */ int FindDuplicateArgs(std::unordered_map *out) const; }; +#endif } // namespace mxnet -#endif // MXNET_SYMBOL_H_ +#endif // MXNET_SYMBOLIC_H_ diff --git a/include/mxnet/tensor_blob.h b/include/mxnet/tensor_blob.h deleted file mode 100644 index b39939cb1425..000000000000 --- a/include/mxnet/tensor_blob.h +++ /dev/null @@ -1,53 +0,0 @@ -/*! - * Copyright (c) 2015 by Contributors - * \file tensor_blob.h - * \brief tensor blob used to hold static memory used by - */ -#ifndef MXNET_TENSOR_BLOB_H_ -#define MXNET_TENSOR_BLOB_H_ -#include - -namespace mxnet { -/*! \brief context information about the execution enviroment */ -struct Context { - /*! \brief the device type we run the op can be cpu::kDevMask or gpu::kDevMask */ - int dev_mask; - /*! \brief device id we are going to run it on */ - int dev_id; - /*! \brief constructor */ - Context() : dev_mask(cpu::kDevMask), dev_id(0) {} - /*! - * \brief constructor of context - * \param dev_mask the device mask - * \param dev_id the device id - */ - Context(int dev_mask, int dev_id) - : dev_mask(dev_mask), dev_id(dev_id) {} - /*! - * \brief check if current context equals another one - * \param b another context to compare - * \return whether dev mask and id are same - */ - inline bool operator==(const Context &b) const { - return dev_mask == b.dev_mask && dev_id == b.dev_id; - } -}; - - -/*! - * \brief execution context provides the information needed - * in runtime to actually execute the operation - */ -struct RunContext { - /*! - * \brief the stream of the device, can be NULL or Stream* in GPU mode - */ - void *stream; -}; - -/*! \brief dynamic shape type */ -typedef mshadow::TShape TShape; -/*! \brief storage container type */ -typedef mshadow::TBlob TBlob; -} // namespace mxnet -#endif // MXNET_TENSOR_BLOB_H_ diff --git a/python/mxnet/symbol_creator.py b/python/mxnet/symbol_creator.py index a0617ce395f8..d4b87e401e3b 100644 --- a/python/mxnet/symbol_creator.py +++ b/python/mxnet/symbol_creator.py @@ -106,3 +106,26 @@ def Variable(self, name): handle = SymbolHandle() check_call(_LIB.MXSymbolCreateVariable(name, ctypes.byref(handle))) return Symbol(handle) + + def Group(self, symbols): + """Create a symbolic variable that groups several symbols together. + + Parameters + ---------- + symbols : list + List of symbols to be grouped. + + Returns + ------- + sym : Symbol + The created group symbol. + """ + ihandles = [] + for sym in symbols: + if not isinstance(sym, Symbol): + raise TypeError('Expect Symbols in the list input') + ihandles.append(sym.handle) + handle = SymbolHandle() + check_call(_LIB.MXSymbolCreateGroup( + len(ihandles), c_array(SymbolHandle, ihandles), ctypes.byref(handle))) + return Symbol(handle) diff --git a/python/test_symbol.py b/python/test_symbol.py index f4823a087e3c..6d876fd46fb8 100644 --- a/python/test_symbol.py +++ b/python/test_symbol.py @@ -20,4 +20,8 @@ composed_fc4 = fc4(fc3_data=fc2, name='composed') print composed_fc4.debug_str() +multi_out = mx.sym.Group([composed_fc4, fc2]) +print multi_out.debug_str() +print multi_out.list_arguments() +print multi_out.list_returns() diff --git a/src/c_api.cc b/src/c_api.cc index df8eb349752c..9620840cf3b0 100644 --- a/src/c_api.cc +++ b/src/c_api.cc @@ -7,8 +7,7 @@ #include #include #include -#include -#include +#include #include #include #include @@ -99,11 +98,11 @@ using namespace mxnet; /*! \brief macro to guard beginning and end section of all functions */ #define API_BEGIN() try { -/*! \brief every function starts with API_BEGIN(); +/*! \brief every function starts with API_BEGIN(); and finishes with API_END() or API_END_HANDLE_ERROR */ #define API_END() } catch(dmlc::Error &e) { return MXHandleException(e); } return 0; /*! - * \brief every function starts with API_BEGIN(); + * \brief every function starts with API_BEGIN(); * and finishes with API_END() or API_END_HANDLE_ERROR * The finally clause contains procedure to cleanup states when an error happens. */ @@ -317,6 +316,21 @@ int MXSymbolCreateVariable(const char *name, SymbolHandle *out) { API_END_HANDLE_ERROR(delete s); } +int MXSymbolCreateGroup(mx_uint num_symbols, + SymbolHandle *symbols, + SymbolHandle *out) { + MXAPISymbolWrapper *s = new MXAPISymbolWrapper(); + MXAPISymbolWrapper **sym_arr = (MXAPISymbolWrapper**)symbols; // NOLINT(*) + API_BEGIN(); + std::vector syms; + for (mx_uint i = 0; i < num_symbols; ++i) { + syms.push_back(sym_arr[i]->sym); + } + s->sym = Symbol::CreateGroup(syms); + *out = s; + API_END_HANDLE_ERROR(delete s); +} + int MXSymbolFree(SymbolHandle symbol) { API_BEGIN(); delete static_cast(symbol); @@ -348,11 +362,10 @@ int MXSymbolListArguments(SymbolHandle symbol, const char ***out_str_array) { MXAPISymbolWrapper *s = static_cast(symbol); API_BEGIN(); - if (s->ret_vec_charp.size() == 0) { - s->ret_vec_str = std::move((s->sym).ListArguments()); - for (size_t i = 0; i < s->ret_vec_str.size(); ++i) { - s->ret_vec_charp.push_back(s->ret_vec_str[i].c_str()); - } + s->ret_vec_str = std::move((s->sym).ListArguments()); + s->ret_vec_charp.clear(); + for (size_t i = 0; i < s->ret_vec_str.size(); ++i) { + s->ret_vec_charp.push_back(s->ret_vec_str[i].c_str()); } *out_size = static_cast(s->ret_vec_charp.size()); *out_str_array = dmlc::BeginPtr(s->ret_vec_charp); @@ -364,12 +377,10 @@ int MXSymbolListReturns(SymbolHandle symbol, const char ***out_str_array) { MXAPISymbolWrapper *s = static_cast(symbol); API_BEGIN(); + s->ret_vec_str = std::move((s->sym).ListReturns()); s->ret_vec_charp.clear(); - if (s->ret_vec_charp.size() == 0) { - s->ret_vec_str = std::move((s->sym).ListReturns()); - for (size_t i = 0; i < s->ret_vec_str.size(); ++i) { - s->ret_vec_charp.push_back(s->ret_vec_str[i].c_str()); - } + for (size_t i = 0; i < s->ret_vec_str.size(); ++i) { + s->ret_vec_charp.push_back(s->ret_vec_str[i].c_str()); } *out_size = static_cast(s->ret_vec_charp.size()); *out_str_array = dmlc::BeginPtr(s->ret_vec_charp); diff --git a/src/narray/narray_op.h b/src/narray/narray_op.h index 1ce546ed295d..2c39363fba32 100644 --- a/src/narray/narray_op.h +++ b/src/narray/narray_op.h @@ -8,7 +8,6 @@ #include #include #include -#include namespace mxnet { /*! \brief namespace to support all possible NArray operator */ diff --git a/src/operator/composite_operator.h b/src/operator/composite_operator.h index ddaa0a50f561..12297dc41c43 100644 --- a/src/operator/composite_operator.h +++ b/src/operator/composite_operator.h @@ -6,13 +6,12 @@ */ #ifndef MXNET_OPERATOR_COMPOSITE_OPERATOR_H_ #define MXNET_OPERATOR_COMPOSITE_OPERATOR_H_ +#include +#include +#include #include #include #include -#include "./atomic_symbol.h" -#include "./base.h" -#include "./static_graph.h" -#include "./static_operator.h" namespace mxnet { /*! @@ -34,11 +33,13 @@ class CompositeOperator : public Operator { /*! \brief Make operator by using graph * \param ctx ctx context of the created operator * \param in input narray - * \param graph input static graph + * \param grad gradient narray + * \param req gradient request */ void Bind(Context ctx, const std::vector &in, - std::shared_ptr graph); + const std::vector &grad + const std::vector &req); /*! * \brief perform a forward operation of operator, save the output to NArray * This method only pushes an execution request to the DAG engine, and @@ -54,6 +55,11 @@ class CompositeOperator : public Operator { RunContext ctx, const std::vector &in_data, const std::vector &out_data); + /*! + * \brief perform a forward operation of operator (no change to binded NArray) + * \param opt option on Forward such as whether this is training phase + */ + virtual void Forward(Option opt); /*! * \brief perform a backward operation of the operator to get the gradient * This method only pushes an execution request to the DAG engine, and @@ -72,31 +78,54 @@ class CompositeOperator : public Operator { const std::vector &out_grad, const std::vector &req); /*! - * \brief perform an extraction operation to get feature map + * \brief perform a backward operation of the operator to get the gradient + * No change to Binded NArray + */ + virtual void Backward(); + /*! + * \brief perform an extraction operation to get feature map * \param name of symbol need to be extracted * \return empty narray for invalid name or narray of the feature map */ virtual NArray Extract(const std::string &symbol_name); private: - /*! \brief - struct Connection { - + /*! + * \brief Update connections data in/after bind + * \param in input narray + * \param grad gradient narray + * \param req gradient request + */ + void UpdateConnection(const std::vector &in, + const std::vector &grad, + const std::vector &req); + /*! + * \brief Allocate each op node + */ + void AllocateNodes(RunContext ctx); + /*! + * \brief Structure for OpNode + */ + struct OpNode { + /*! \brief Static Operator */ + std::unique_ptr op; + /*! \brief inputs (init after setting output correctly) */ + std::vector inputs; + /*! \brief outputs */ + std::vector outputs; + /*! \brief gradient for output */ + std::vector outputs_grad; + /*! \brief gradient req for grad */ + std::vector req; + /*! \brief is variable */ + bool is_variable; }; - /*! \brief static operators for each node */ - std::vector > static_ops_; - /*! \brief feature map for each op */ - std::vector > feature_maps_; - /*! \brief input NArray link */ - std::vector > in_data_; - /*! \brief input NArray gradient */ - std::vector > in_grad_; - /*! \brief output NArray link */ - std::vector > out_data_; + /*! \brief connections */ + std::vector nodes_; + /*! \brief topo order of connections */ + std::vector topo_order_; /*! \brief static graph */ - std::shared_ptr graph_; + StaticGraph graph_; }; // class CompositeOperator } // namespace mxnet #endif // MXNET_OPERATOR_COMPOSITE_OPERATOR_H_ - - diff --git a/src/static_operator/activation_op-inl.h b/src/operator/static_operator/activation_op-inl.h similarity index 91% rename from src/static_operator/activation_op-inl.h rename to src/operator/static_operator/activation_op-inl.h index b1ad0d090706..c888b35c9c61 100644 --- a/src/static_operator/activation_op-inl.h +++ b/src/operator/static_operator/activation_op-inl.h @@ -4,11 +4,11 @@ * \brief activation operator of mxnet */ -#ifndef MXNET_STATIC_OPERATOR_ACTIVATION_OP_INL_H_ -#define MXNET_STATIC_OPERATOR_ACTIVATION_OP_INL_H_ +#ifndef MXNET_OPERATOR_STATIC_OPERATOR_ACTIVATION_OP_INL_H_ +#define MXNET_OPERATOR_STATIC_OPERATOR_ACTIVATION_OP_INL_H_ #include -#include +#include #include #include "./static_operator_common.h" @@ -57,4 +57,4 @@ class ActivationOp : public StaticOperator { } // namespace op } // namespace mxnet -#endif // MXNET_STATIC_OPERATOR_ACTIVATION_OP_INL_H_ +#endif // MXNET_OPERATOR_STATIC_OPERATOR_ACTIVATION_OP_INL_H_ diff --git a/src/static_operator/convolution_op-inl.h b/src/operator/static_operator/convolution_op-inl.h similarity index 98% rename from src/static_operator/convolution_op-inl.h rename to src/operator/static_operator/convolution_op-inl.h index 0f7c5ccbb631..2271839b697a 100644 --- a/src/static_operator/convolution_op-inl.h +++ b/src/operator/static_operator/convolution_op-inl.h @@ -4,10 +4,10 @@ * \brief convolution op * \author Bing Xu */ -#ifndef MXNET_STATIC_OPERATOR_CONVOLUTION_OP_INL_H_ -#define MXNET_STATIC_OPERATOR_CONVOLUTION_OP_INL_H_ +#ifndef MXNET_OPERATOR_STATIC_OPERATOR_CONVOLUTION_OP_INL_H_ +#define MXNET_OPERATOR_STATIC_OPERATOR_CONVOLUTION_OP_INL_H_ -#include +#include #include #include #include "./static_operator_common.h" @@ -266,4 +266,4 @@ class ConvolutionOp : public StaticOperator { }; // class ConvolutionOp } // namespace op } // namespace mxnet -#endif // MXNET_STATIC_OPERATOR_CONVOLUTION_OP_INL_H_ +#endif // MXNET_OPERATOR_STATIC_OPERATOR_CONVOLUTION_OP_INL_H_ diff --git a/src/static_operator/dropout_op-inl.h b/src/operator/static_operator/dropout_op-inl.h similarity index 93% rename from src/static_operator/dropout_op-inl.h rename to src/operator/static_operator/dropout_op-inl.h index aba19ad3c88b..b79a79fbea65 100644 --- a/src/static_operator/dropout_op-inl.h +++ b/src/operator/static_operator/dropout_op-inl.h @@ -4,10 +4,10 @@ * \brief dropout operator * \author Bing Xu */ -#ifndef MXNET_STATIC_OPERATOR_DROPOUT_OP_INL_H_ -#define MXNET_STATIC_OPERATOR_DROPOUT_OP_INL_H_ +#ifndef MXNET_OPERATOR_STATIC_OPERATOR_DROPOUT_OP_INL_H_ +#define MXNET_OPERATOR_STATIC_OPERATOR_DROPOUT_OP_INL_H_ -#include +#include #include #include "./mshadow_op.h" @@ -90,4 +90,4 @@ class DropoutOp : public StaticOperator { }; // class DropoutOp } // namespace op } // namespace mxnet -#endif // MXNET_STATIC_OPERATOR_DROPOUT_OP_INL_H_ +#endif // MXNET_OPERATOR_STATIC_OPERATOR_DROPOUT_OP_INL_H_ diff --git a/src/static_operator/fully_connect_op-inl.h b/src/operator/static_operator/fully_connect_op-inl.h similarity index 95% rename from src/static_operator/fully_connect_op-inl.h rename to src/operator/static_operator/fully_connect_op-inl.h index 15f8e857d3cf..d39335deeeff 100644 --- a/src/static_operator/fully_connect_op-inl.h +++ b/src/operator/static_operator/fully_connect_op-inl.h @@ -3,12 +3,12 @@ * \file fully_connect_op-inl.h * \brief fully connect operator and symbol */ -#ifndef MXNET_STATIC_OPERATOR_FULLY_CONNECT_OP_INL_H_ -#define MXNET_STATIC_OPERATOR_FULLY_CONNECT_OP_INL_H_ +#ifndef MXNET_OPERATOR_STATIC_OPERATOR_FULLY_CONNECT_OP_INL_H_ +#define MXNET_OPERATOR_STATIC_OPERATOR_FULLY_CONNECT_OP_INL_H_ #include -#include -#include +#include +#include #include #include #include "./static_operator_common.h" @@ -160,4 +160,4 @@ class FullyConnectSymbol : public AtomicSymbol { } // namespace op } // namespace mxnet -#endif // MXNET_STATIC_OPERATOR_FULLY_CONNECT_OP_INL_H_ +#endif // MXNET_OPERATOR_STATIC_OPERATOR_FULLY_CONNECT_OP_INL_H_ diff --git a/src/static_operator/fully_connect_op.cc b/src/operator/static_operator/fully_connect_op.cc similarity index 100% rename from src/static_operator/fully_connect_op.cc rename to src/operator/static_operator/fully_connect_op.cc diff --git a/src/static_operator/fully_connect_op.cu b/src/operator/static_operator/fully_connect_op.cu similarity index 100% rename from src/static_operator/fully_connect_op.cu rename to src/operator/static_operator/fully_connect_op.cu diff --git a/src/static_operator/mshadow_op.h b/src/operator/static_operator/mshadow_op.h similarity index 92% rename from src/static_operator/mshadow_op.h rename to src/operator/static_operator/mshadow_op.h index 2954b1f81a48..bb33471f168a 100644 --- a/src/static_operator/mshadow_op.h +++ b/src/operator/static_operator/mshadow_op.h @@ -4,8 +4,8 @@ * \brief extra mshadow operation for mxnet * \author Bing Xu */ -#ifndef MXNET_STATIC_OPERATOR_MSHADOW_OP_H_ -#define MXNET_STATIC_OPERATOR_MSHADOW_OP_H_ +#ifndef MXNET_OPERATOR_STATIC_OPERATOR_MSHADOW_OP_H_ +#define MXNET_OPERATOR_STATIC_OPERATOR_MSHADOW_OP_H_ #include #include @@ -102,5 +102,5 @@ struct square_root { } // namespace op } // namespace mxnet -#endif // MXNET_STATIC_OPERATOR_MSHADOW_OP_H_ +#endif // MXNET_OPERATOR_STATIC_OPERATOR_MSHADOW_OP_H_ diff --git a/src/static_operator/param.h b/src/operator/static_operator/param.h similarity index 93% rename from src/static_operator/param.h rename to src/operator/static_operator/param.h index c2829aced8ae..f6e91293eca3 100644 --- a/src/static_operator/param.h +++ b/src/operator/static_operator/param.h @@ -4,8 +4,8 @@ * \brief operator params * \author Bing Xu */ -#ifndef MXNET_STATIC_OPERATOR_PARAM_H_ -#define MXNET_STATIC_OPERATOR_PARAM_H_ +#ifndef MXNET_OPERATOR_STATIC_OPERATOR_PARAM_H_ +#define MXNET_OPERATOR_STATIC_OPERATOR_PARAM_H_ namespace mxnet { namespace op { @@ -68,6 +68,6 @@ struct Param { } // namespace op } // namespace mxnet -#endif // MXNET_STATIC_OPERATOR_PARAM_H_ +#endif // MXNET_OPERATOR_STATIC_OPERATOR_PARAM_H_ diff --git a/src/static_operator/pooling_op-inl.h b/src/operator/static_operator/pooling_op-inl.h similarity index 96% rename from src/static_operator/pooling_op-inl.h rename to src/operator/static_operator/pooling_op-inl.h index e4bf344f7e5a..db5e40ffb4a7 100644 --- a/src/static_operator/pooling_op-inl.h +++ b/src/operator/static_operator/pooling_op-inl.h @@ -4,10 +4,10 @@ * \brief pooling operator * \author Bing Xu */ -#ifndef MXNET_STATIC_OPERATOR_POOLING_OP_INL_H_ -#define MXNET_STATIC_OPERATOR_POOLING_OP_INL_H_ +#ifndef MXNET_OPERATOR_STATIC_OPERATOR_POOLING_OP_INL_H_ +#define MXNET_OPERATOR_STATIC_OPERATOR_POOLING_OP_INL_H_ -#include +#include #include #include #include "./param.h" @@ -149,4 +149,4 @@ class PoolingOp : public StaticOperator { } // namespace op } // namespace mxnet -#endif // MXNET_STATIC_OPERATOR_POOLING_OP_INL_H_ +#endif // MXNET_OPERATOR_STATIC_OPERATOR_POOLING_OP_INL_H_ diff --git a/src/static_operator/reshape_op-inl.h b/src/operator/static_operator/reshape_op-inl.h similarity index 92% rename from src/static_operator/reshape_op-inl.h rename to src/operator/static_operator/reshape_op-inl.h index eb05a460573d..44d8f8fcef24 100644 --- a/src/static_operator/reshape_op-inl.h +++ b/src/operator/static_operator/reshape_op-inl.h @@ -4,10 +4,10 @@ * \brief * \author Bing Xu */ -#ifndef MXNET_STATIC_OPERATOR_RESHAPE_OP_INL_H_ -#define MXNET_STATIC_OPERATOR_RESHAPE_OP_INL_H_ +#ifndef MXNET_OPERATOR_STATIC_OPERATOR_RESHAPE_OP_INL_H_ +#define MXNET_OPERATOR_STATIC_OPERATOR_RESHAPE_OP_INL_H_ -#include +#include #include namespace mxnet { @@ -72,4 +72,4 @@ class ReshapeOp : public StaticOperator { } // namespace op } // namespace mxnet -#endif // MXNET_STATIC_OPERATOR_RESHAPE_OP_INL_H_ +#endif // MXNET_OPERATOR_STATIC_OPERATOR_RESHAPE_OP_INL_H_ diff --git a/src/static_operator/static_operator-inl.h b/src/operator/static_operator/static_operator-inl.h similarity index 85% rename from src/static_operator/static_operator-inl.h rename to src/operator/static_operator/static_operator-inl.h index 99776b3db621..f03a6a51532e 100644 --- a/src/static_operator/static_operator-inl.h +++ b/src/operator/static_operator/static_operator-inl.h @@ -4,11 +4,11 @@ * \brief static device invarient code to create operators * \author Bing Xu */ -#ifndef MXNET_STATIC_OPERATOR_STATIC_OPERATOR_INL_H_ -#define MXNET_STATIC_OPERATOR_STATIC_OPERATOR_INL_H_ +#ifndef MXNET_OPERATOR_STATIC_OPERATOR_STATIC_OPERATOR_INL_H_ +#define MXNET_OPERATOR_STATIC_OPERATOR_STATIC_OPERATOR_INL_H_ #include #include -#include +#include #include "./mshadow_op.h" #include "./activation_op-inl.h" #include "./convolution_op-inl.h" @@ -46,4 +46,4 @@ inline StaticOperator *CreateOperator_(OpType type, mshadow::Random *prnd) } } // namespace op } // namespace mxnet -#endif // MXNET_STATIC_OPERATOR_STATIC_OPERATOR_INL_H_ +#endif // MXNET_OPERATOR_STATIC_OPERATOR_STATIC_OPERATOR_INL_H_ diff --git a/src/static_operator/static_operator.cc b/src/operator/static_operator/static_operator.cc similarity index 96% rename from src/static_operator/static_operator.cc rename to src/operator/static_operator/static_operator.cc index 4a2a121532dd..671ef76f2f9c 100644 --- a/src/static_operator/static_operator.cc +++ b/src/operator/static_operator/static_operator.cc @@ -6,7 +6,7 @@ */ #include #include -#include +#include #include #include "./static_operator_common.h" diff --git a/src/static_operator/static_operator_common.h b/src/operator/static_operator/static_operator_common.h similarity index 88% rename from src/static_operator/static_operator_common.h rename to src/operator/static_operator/static_operator_common.h index 0d1553703200..06eb307b8ca0 100644 --- a/src/static_operator/static_operator_common.h +++ b/src/operator/static_operator/static_operator_common.h @@ -6,11 +6,11 @@ * common type definitions * \author Bing Xu */ -#ifndef MXNET_STATIC_OPERATOR_STATIC_OPERATOR_COMMON_H_ -#define MXNET_STATIC_OPERATOR_STATIC_OPERATOR_COMMON_H_ +#ifndef MXNET_OPERATOR_STATIC_OPERATOR_STATIC_OPERATOR_COMMON_H_ +#define MXNET_OPERATOR_STATIC_OPERATOR_STATIC_OPERATOR_COMMON_H_ #include -#include +#include #include namespace mxnet { namespace op { @@ -70,4 +70,4 @@ template StaticOperator *CreateOperator(OpType type); } // namespace op } // namespace mxnet -#endif // MXNET_STATIC_OPERATOR_STATIC_OPERATOR_COMMON_H_ +#endif // MXNET_OPERATOR_STATIC_OPERATOR_STATIC_OPERATOR_COMMON_H_ diff --git a/src/static_operator/static_operator_cpu.cc b/src/operator/static_operator/static_operator_cpu.cc similarity index 100% rename from src/static_operator/static_operator_cpu.cc rename to src/operator/static_operator/static_operator_cpu.cc diff --git a/src/static_operator/static_operator_gpu.cu b/src/operator/static_operator/static_operator_gpu.cu similarity index 93% rename from src/static_operator/static_operator_gpu.cu rename to src/operator/static_operator/static_operator_gpu.cu index 580fe65d630d..a66167431dd1 100644 --- a/src/static_operator/static_operator_gpu.cu +++ b/src/operator/static_operator/static_operator_gpu.cu @@ -5,7 +5,6 @@ * \author Bing Xu */ #include -#include #include "static_operator-inl.h" namespace mxnet { diff --git a/src/operator/static_operator_wrapper.cc b/src/operator/static_operator_wrapper.cc index 97ed3b307291..afd4bae6241c 100644 --- a/src/operator/static_operator_wrapper.cc +++ b/src/operator/static_operator_wrapper.cc @@ -6,8 +6,6 @@ */ #include #include -#include -#include #include #include #include diff --git a/src/registry.cc b/src/registry.cc index f3ce0bd28ff0..04f391cb617c 100644 --- a/src/registry.cc +++ b/src/registry.cc @@ -6,7 +6,7 @@ #include #include #include -#include +#include namespace mxnet { diff --git a/src/symbol/static_graph.cc b/src/symbol/static_graph.cc index 9175fa9d55b0..ce54ad818bfe 100644 --- a/src/symbol/static_graph.cc +++ b/src/symbol/static_graph.cc @@ -4,7 +4,7 @@ * \brief static graph of mxnet */ #include -#include +#include #include #include diff --git a/src/symbol/symbol.cc b/src/symbol/symbol.cc index 2a4ed45df1a8..e3700dd127f4 100644 --- a/src/symbol/symbol.cc +++ b/src/symbol/symbol.cc @@ -4,9 +4,8 @@ * \brief symbol of mxnet */ #include -#include +#include #include -#include #include #include #include @@ -33,13 +32,15 @@ Symbol Symbol::Copy() const { } // set the head Symbol s; - s.head_ = DataEntry(old_new[head_.source.get()], head_.index); + for (auto &head : heads_) { + s.heads_.push_back(DataEntry(old_new[head.source.get()], head.index)); + } return s; } void Symbol::Print(std::ostream &os) const { - if (head_.source->is_atomic()) { - os << "AtomicSymbol "<< " Type:" << head_.source->sym->TypeString() << '\n' + if (this->is_atomic()) { + os << "AtomicSymbol "<< " Type:" << heads_[0].source->sym->TypeString() << '\n' << "Inputs:"; std::vector args = this->ListArguments(); for (size_t i = 0; i < args.size(); ++i) { @@ -47,6 +48,11 @@ void Symbol::Print(std::ostream &os) const { } } else { // use DFSVisit to copy all the nodes + os << "Outputs:\n"; + for (size_t i = 0; i < heads_.size(); ++i) { + os << "\toutput[" << i << "]=" << heads_[i].source->name + << '(' << heads_[i].index << ")\n"; + } this->DFSVisit([&os](Node *node) { if (node->is_variable()) { os << "Variable:" << node->name << '\n'; @@ -81,23 +87,24 @@ int Symbol::FindDuplicateArgs(std::unordered_map *out) const { void Symbol::Compose(const std::vector& args, const std::string& name) { - CHECK(!head_.source->is_variable()) << "PlaceHolder cannot be composed"; - head_.source->name = name; + CHECK_EQ(NumReturns(), 1) << "Only composition of value function is supported currently"; + CHECK(!heads_[0].source->is_variable()) << "Variable cannot be composed"; + heads_[0].source->name = name; for (size_t i = 0; i < args.size(); ++i) { - CHECK_NE(args[i].head_.index, -1) + CHECK_NE(args[i].NumReturns(), 1) << "Argument " << i << " is a tuple, scalar is required"; } // positional arguments requires all arguments for now. // TODO(bing) consider partial assignments - if (head_.source->is_atomic()) { + if (this->is_atomic()) { // atomic symbol do not have place holder for all the arguments - std::vector req_args = head_.source->sym->ListArguments(); + std::vector req_args = heads_[0].source->sym->ListArguments(); CHECK_EQ(args.size(), req_args.size()) << "Incorrect number of arguments, requires " << req_args.size() << ", provided " << args.size(); - head_.source->inputs.resize(args.size()); + heads_[0].source->inputs.resize(args.size()); for (size_t i = 0; i < args.size(); ++i) { - head_.source->inputs[i] = args[i].head_; + heads_[0].source->inputs[i] = args[i].heads_[0]; } } else { // find all the place holders @@ -114,7 +121,7 @@ void Symbol::Compose(const std::vector& args, auto iter = replace_map.find(e->source.get()); if (iter == replace_map.end()) { if (arg_counter < args.size()) { - target = &(args[arg_counter].head_); + target = &(args[arg_counter].heads_[0]); replace_map[e->source.get()] = target; } ++arg_counter; @@ -137,38 +144,38 @@ void Symbol::Compose(const std::vector& args, void Symbol::Compose(const std::unordered_map& kwargs, const std::string& name) { - CHECK(!head_.source->is_variable()) << "PlaceHolder cannot be composed"; - head_.source->name = name; + CHECK_EQ(NumReturns(), 1) << "Only composition of value function is supported currently"; + CHECK(!heads_[0].source->is_variable()) << "Variable cannot be composed"; + heads_[0].source->name = name; for (const auto& kv : kwargs) { - CHECK_NE(kv.second.head_.index, -1) + CHECK_EQ(kv.second.NumReturns(), 1) << "Keyword Argument " << kv.first << " is a tuple, scalar is required"; } size_t nmatched = 0; - if (head_.source->is_atomic()) { + if (this->is_atomic()) { // atomic symbol do not have place holder for all the arguments - std::vector req_args = head_.source->sym->ListArguments(); - head_.source->inputs.resize(req_args.size()); + std::vector req_args = heads_[0].source->sym->ListArguments(); + heads_[0].source->inputs.resize(req_args.size()); for (size_t i = 0; i < req_args.size(); ++i) { auto iter = kwargs.find(req_args[i]); if (iter != kwargs.end()) { - head_.source->inputs[i] = iter->second.head_; - + heads_[0].source->inputs[i] = iter->second.heads_[0]; ++nmatched; } else { // create a variable node // TODO(bing): think of naming convention if (name.length() == 0) { - head_.source->inputs[i] = DataEntry( + heads_[0].source->inputs[i] = DataEntry( std::make_shared(nullptr, req_args[i]), 0); } else { - head_.source->inputs[i] = DataEntry( + heads_[0].source->inputs[i] = DataEntry( std::make_shared(nullptr, name + '_' + req_args[i]), 0); } } } // if things goes wrong recover the old state if (nmatched != kwargs.size()) { - head_.source->inputs.clear(); + heads_[0].source->inputs.clear(); } } else { // find all the arguments positions @@ -194,7 +201,7 @@ void Symbol::Compose(const std::unordered_map& kwargs, const DataEntry *target = nullptr; auto iter = kwargs.find(e->source->name); if (iter != kwargs.end()) { - target = &(iter->second.head_); + target = &(iter->second.heads_[0]); // count how many arguments have been matched. if (visited.count(e->source.get()) == 0) { visited.insert(e->source.get()); @@ -228,18 +235,22 @@ void Symbol::Compose(const std::unordered_map& kwargs, } } -Symbol Symbol::operator[] (int index) const { - CHECK_EQ(head_.index, -1) << "Current symbol can't be indexed because it returns a scalar."; - CHECK_GE(index, 0) << "Symbol only accept nonnegative index"; - Symbol s = *this; - s.head_.index = index; - return s; +Symbol Symbol::operator[] (size_t index) const { + size_t nreturn = NumReturns(); + CHECK_LT(index, nreturn) << "Symbol only accept nonnegative index"; + if (nreturn == 1) { + return *this; + } else { + Symbol s; + s.heads_.push_back(heads_[index]); + return s; + } } std::vector Symbol::ListArguments() const { std::vector ret; - if (head_.source->is_atomic()) { - return head_.source->sym->ListArguments(); + if (this->is_atomic()) { + return heads_[0].source->sym->ListArguments(); } else { this->DFSVisit([&ret](Node *node) { if (node->is_variable()) { @@ -251,25 +262,43 @@ std::vector Symbol::ListArguments() const { } std::vector Symbol::ListReturns() const { - return head_.source->sym->ListReturns(); + std::vector ret; + for (auto &head : heads_) { + if (head.source->is_variable()) { + ret.push_back(head.source->name); + } else { + // TODO(bing) rethink about output naming + auto &hname = head.source->name; + std::string rname = head.source->sym->ListReturns()[head.index]; + if (hname.length() == 0) { + ret.push_back(std::move(rname)); + } else { + ret.push_back(hname + '_' + rname); + } + } + } + return std::move(ret); } Symbol Symbol::Create(AtomicSymbol *atomic_symbol) { // use special representation for atomic symbol + auto node = std::make_shared(atomic_symbol, ""); + size_t nret = atomic_symbol->NumReturns(); Symbol s; - s.head_ = DataEntry(std::make_shared(atomic_symbol, ""), - atomic_symbol->NumReturns() > 1 ? -1 : 0); + for (uint32_t i = 0; i < nret; ++i) { + s.heads_.push_back(DataEntry(node, i)); + } return s; } -void Symbol::Convert(const std::vector &heads, StaticGraph *out_graph) { +void Symbol::ToStaticGraph(StaticGraph *out_graph) const { // TODO(bing): Check unique name std::vector node_order; std::unordered_map node_index; auto &arg_nodes = out_graph->arg_nodes; arg_nodes.clear(); - DFSVisit(heads, [&node_order, &node_index, &arg_nodes](Node *n) { + this->DFSVisit([&node_order, &node_index, &arg_nodes](Node *n) { uint32_t nid = static_cast(node_index.size()); node_index[n] = nid; if (n->is_variable()) { @@ -297,19 +326,11 @@ void Symbol::Convert(const std::vector &heads, StaticGraph *out_graph) { } // setup heads out_graph->outputs.clear(); - for (auto &head : heads) { + for (auto &head : heads_) { StaticGraph::DataEntry e; - e.source_id = node_index[head.head_.source.get()]; - if (head.head_.index == -1) { - int nout = head.head_.source->sym->NumReturns(); - for (int i = 0; i < nout; ++i) { - e.index = i; - out_graph->outputs.push_back(e); - } - } else { - e.index = head.head_.index; - out_graph->outputs.push_back(e); - } + e.source_id = node_index[head.source.get()]; + e.index = head.index; + out_graph->outputs.push_back(e); } } } // namespace mxnet