Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Add JSON serialization, based on commit from DP #91

Merged
merged 4 commits into from
Sep 18, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dmlc-core
48 changes: 34 additions & 14 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,10 +123,10 @@ MXNET_DLL int MXNDArraySaveRawBytes(NDArrayHandle handle,
* \param keys the name of the NDArray, optional, can be NULL
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXNDArrayListSave(const char* fname,
mx_uint num_args,
NDArrayHandle* args,
const char** keys);
MXNET_DLL int MXNDArraySave(const char* fname,
mx_uint num_args,
NDArrayHandle* args,
const char** keys);
/*!
* \brief Load list of narray from the file.
* \param fname name of the file.
Expand All @@ -136,11 +136,11 @@ MXNET_DLL int MXNDArrayListSave(const char* fname,
* \param out_names the names of returning NDArrays, can be NULL
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXNDArrayListLoad(const char* fname,
mx_uint *out_size,
NDArrayHandle** out_arr,
mx_uint *out_name_size,
const char*** out_names);
MXNET_DLL int MXNDArrayLoad(const char* fname,
mx_uint *out_size,
NDArrayHandle** out_arr,
mx_uint *out_name_size,
const char*** out_names);
/*!
* \brief Perform a synchronize copy from a continugous CPU memory region.
*
Expand Down Expand Up @@ -359,13 +359,33 @@ MXNET_DLL int MXSymbolCreateGroup(mx_uint num_symbols,
SymbolHandle *symbols,
SymbolHandle *out);
/*!
* \brief Create symbol from config.
* \param cfg configuration string
* \param out created symbol handle
* \brief Load a symbol from a json file.
* \param fname the file name.
* \param out the output symbol.
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXSymbolCreateFromFile(const char *fname, SymbolHandle *out);
/*!
* \brief Load a symbol from a json string.
* \param json the json string.
* \param out the output symbol.
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXSymbolCreateFromJSON(const char *json, SymbolHandle *out);
/*!
* \brief Save a symbol into a json file.
* \param sym the input symbol.
* \param fname the file name.
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXSymbolSaveToFile(SymbolHandle symbol, const char *fname);
/*!
* \brief Save a symbol into a json string
* \param sym the input symbol.
* \param out_json output json string.
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXSymbolCreateFromConfig(const char *cfg,
SymbolHandle *out);
MXNET_DLL int MXSymbolSaveToJSON(SymbolHandle symbol, const char **out_json);
/*!
* \brief Free the symbol handle.
* \param symbol the symbol
Expand Down
21 changes: 20 additions & 1 deletion include/mxnet/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
#include <dmlc/io.h>
#include <dmlc/type_traits.h>
#include <dmlc/registry.h>
#include <vector>
#include <string>
#include <memory>
#include "./base.h"
#include "./storage.h"
Expand Down Expand Up @@ -244,6 +246,24 @@ class NDArray {
inline void CheckAndAlloc() const {
ptr_->CheckAndAlloc();
}
/*!
* \brief Save list of narray into the file.
* \param fname name of the file.
* \param data the NDArrays to be saved.
* \param keys the name of the NDArray, optional, can be zero length.
*/
static void Save(const std::string& fname,
const std::vector<NDArray>& data,
const std::vector<std::string>& names);
/*!
* \brief Load list of narray into from the file.
* \param fname name of the file.
* \param data the NDArrays to be loaded
* \param keys the name of the NDArray, if saved in the file.
*/
static void Load(const std::string& fname,
std::vector<NDArray>* data,
std::vector<std::string>* keys);

private:
/*! \brief the real data chunk that backs NDArray */
Expand Down Expand Up @@ -397,7 +417,6 @@ void SampleUniform(real_t begin, real_t end, NDArray *out);
* \param out output NDArray.
*/
void SampleGaussian(real_t mu, real_t sigma, NDArray *out);

//--------------------------------------------------------------
// The following part are API Registration of NDArray functions.
//--------------------------------------------------------------
Expand Down
27 changes: 25 additions & 2 deletions include/mxnet/operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@
#define MXNET_OPERATOR_H_

#include <dmlc/base.h>
#include <dmlc/json.h>
#include <dmlc/logging.h>
#include <dmlc/registry.h>
#include <vector>
#include <map>
#include <string>
#include <utility>
#include "./base.h"
Expand Down Expand Up @@ -149,6 +151,11 @@ class OperatorProperty {
* \param kwargs the keyword arguments parameters
*/
virtual void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) = 0;
/*!
* \brief Get a map representation of internal parameters.
* This can be used by Init to recover the state of OperatorProperty.
*/
virtual std::map<std::string, std::string> GetParams() const = 0;
/*!
* \brief Get input arguments of the Operator.
* \return vector of arguments.
Expand Down Expand Up @@ -221,6 +228,7 @@ class OperatorProperty {
/*!
* \brief return the type string of the Operator
* subclasses override this function.
* \return The type string.
*/
virtual std::string TypeString() const = 0;
//--------------------------------------------------------
Expand Down Expand Up @@ -415,6 +423,19 @@ struct OperatorPropertyReg
this->key_var_num_args = key;
return *this;
}
/*!
* \brief Check if TypeString of the type matches the registered name
*/
inline OperatorPropertyReg& check_name() {
OperatorProperty *p = this->body();
std::string type = p->TypeString();
delete p;
CHECK_EQ(this->name, type)
<< "Register Name and TypeString mismatch, name=\"" << this->name << "\","
<< " but TypeString=\"" << type <<"\"";
return *this;
}

/*! \brief The key num_args name. */
std::string key_var_num_args;
};
Expand All @@ -434,10 +455,12 @@ struct OperatorPropertyReg
*/
#define MXNET_REGISTER_OP_PROPERTY(name, OperatorPropertyType) \
static ::mxnet::OperatorProperty* __create__ ## OperatorProperty ## name ## __() { \
return new OperatorPropertyType; \
OperatorProperty* ret = new OperatorPropertyType(); \
return ret; \
} \
DMLC_REGISTRY_REGISTER(::mxnet::OperatorPropertyReg, OperatorPropertyReg, name) \
.set_body(__create__ ## OperatorProperty ## name ## __)
.set_body(__create__ ## OperatorProperty ## name ## __) \
.check_name()

#endif // DMLC_USE_CXX11
} // namespace mxnet
Expand Down
94 changes: 87 additions & 7 deletions include/mxnet/symbolic.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,11 @@
#define MXNET_SYMBOLIC_H_

#include <dmlc/base.h>
#include <dmlc/json.h>
#include <algorithm>
#include <vector>
#include <memory>
#include <map>
#include <string>
#include <utility>
#include <functional>
Expand Down Expand Up @@ -64,6 +67,25 @@ class StaticGraph {
if (source_id == other.source_id) return index < other.index;
return source_id < other.source_id;
}
/*!
* \brief interface for json serialization.
* \param writer the JSON writer to write json into.
*/
inline void Save(dmlc::JSONWriter *writer) const {
writer->BeginArray(false);
writer->WriteArrayItem(source_id);
writer->WriteArrayItem(index);
writer->EndArray();
}
/*!
* \brief interface for json serialization.
* \param reader the JSON reader to read json from.
*/
inline void Load(dmlc::JSONReader *reader) {
std::pair<uint32_t, uint32_t> p;
reader->Read(&p);
*this = DataEntry(p.first, p.second);
}
};
/*!
* \brief Operation Node in static graphs.
Expand Down Expand Up @@ -95,6 +117,23 @@ class StaticGraph {
int32_t backward_source_id;
/*! \brief default constructor */
Node() : backward_source_id(-1) {}

friend void swap(Node& lhs, Node& rhs) {
std::swap(lhs.op, rhs.op);
std::swap(lhs.name, rhs.name);
std::swap(lhs.inputs, rhs.inputs);
std::swap(lhs.backward_source_id, rhs.backward_source_id);
}
/*! \brief copy constructor in favor of serialization. */
Node(const Node& another) : op(another.op.get() ? another.op.get()->Copy() : nullptr),
name(another.name),
inputs(another.inputs),
backward_source_id(another.backward_source_id) {}

inline Node& operator=(Node another) {
swap(*this, another);
return *this;
}
/*! \return whether the node is forward op node */
inline bool is_forward() const {
return op != nullptr;
Expand All @@ -107,13 +146,33 @@ class StaticGraph {
inline bool is_variable() const {
return op == nullptr && !is_backward();
}
/*!
* \brief interface for json serialization.
* \param writer the JSON writer write json.
*/
void Save(dmlc::JSONWriter *writer) const;
/*!
* \brief interface for json serialization.
* \param reader the JSON read to read json.
*/
void Load(dmlc::JSONReader *reader);
};
/*! \brief all nodes in the graph */
std::vector<Node> nodes;
/*! \brief index of nodes that correspods to arguments */
std::vector<uint32_t> arg_nodes;
/*! \brief heads outputs of the graph */
std::vector<DataEntry> heads;
/*!
* \brief interface for json serialization.
* \param writer the JSON writer write json.
*/
void Save(dmlc::JSONWriter *writer) const;
/*!
* \brief interface for json serialization.
* \param reader the JSON read to read json.
*/
void Load(dmlc::JSONReader *reader);
// funtions to help inference in static graph
/*!
* \brief Perform a topological sort on the graph
Expand Down Expand Up @@ -246,6 +305,12 @@ class Symbol {
* \param out_graph the pointer holder of the output graph
*/
void ToStaticGraph(StaticGraph *out_graph) const;
/*!
* \brief create equivalence of symbol from static graphs.
* This operation will change the content of current symbol.
* \param graph the static graph
*/
void FromStaticGraph(const StaticGraph &graph);
/*!
* \brief Apply the symbol as a function, compose with arguments
* \param args positional arguments for the symbol
Expand All @@ -267,7 +332,6 @@ class Symbol {
* \return the new symbol with gradient graph
*/
Symbol Grad(const std::vector<std::string>& wrt) const;

/*!
* \brief infer the shapes of outputs and unknown input arguments
* \param arg_shapes the shape of input arguments of the operator
Expand Down Expand Up @@ -299,6 +363,24 @@ class Symbol {
std::vector<TShape> *arg_shapes,
std::vector<TShape> *out_shapes,
std::vector<TShape> *aux_shapes) const;
/*!
* \brief interface for json serialization.
* \param writer the JSON writer write json.
*/
inline void Save(dmlc::JSONWriter *writer) const {
StaticGraph g;
this->ToStaticGraph(&g);
g.Save(writer);
}
/*!
* \brief interface for json serialization.
* \param reader the JSON read to read json.
*/
inline void Load(dmlc::JSONReader *reader) {
StaticGraph g;
g.Load(reader);
this->FromStaticGraph(g);
}
/*!
* \brief get number of outputs of this symbol
* \return number of outputs
Expand All @@ -315,12 +397,6 @@ class Symbol {
* \sa OperatorProperty::Create
*/
static Symbol Create(OperatorProperty *op);
/*!
* \brief create equivalence of symbol from static graphs
* \param graph the static graph
* \return the created symbol
*/
static Symbol Create(const StaticGraph &graph);
/*!
* \brief create equivalence of symbol by grouping the symbols together
* \param symbols list of symbols
Expand Down Expand Up @@ -430,4 +506,8 @@ class Executor {
const std::vector<NDArray> &aux_states);
}; // class operator
} // namespace mxnet

namespace dmlc {
DMLC_DECLARE_TRAITS(is_pod, ::mxnet::StaticGraph::DataEntry, true);
}
#endif // MXNET_SYMBOLIC_H_
Loading