Skip to content

Commit

Permalink
fix comments (#8)
Browse files Browse the repository at this point in the history
* add base code for mkldnn 1.0

* fix comments

* Update mkldnn.mk
  • Loading branch information
rongzha1 authored and TaoLv committed Sep 9, 2019
1 parent 8cc4201 commit 2c1cece
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 6 deletions.
3 changes: 1 addition & 2 deletions src/ndarray/ndarray.cc
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ NDArray NDArray::MKLDNNDataReshape(const mxnet::TShape &shape) const {
// We shouldn't submit the reorder primitive here because submit will
// be called in operators.
mkldnn_format_tag_t format = ptr_->mkl_mem_->GetDefaultFormat();
// CHECK_NE(format, ptr_->mkl_mem_->GetFormat());
CHECK(ptr_->IsMKLDNN());
mkldnn::memory::desc def_desc = ptr_->mkl_mem_->GetDesc(format);
mkldnn::memory *def_mem = TmpMemMgr::Get()->Alloc(def_desc);
MKLDNNStream *stream = MKLDNNStream::Get();
Expand Down Expand Up @@ -1615,7 +1615,6 @@ void NDArray::Save(dmlc::Stream *strm) const {
nd_cpu = *this;
#if MXNET_USE_MKLDNN == 1
if (nd_cpu.IsMKLDNNData()) {
LOG(FATAL) << "TODO: MKL-DNN 1.0";
nd_cpu = nd_cpu.Reorder2Default();
}
#endif
Expand Down
8 changes: 4 additions & 4 deletions src/operator/nn/mkldnn/mkldnn_base-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -352,11 +352,11 @@ class MKLDNNStream {
public:
static MKLDNNStream *Get();

MKLDNNStream(): s(CpuEngine::Get()->get_engine()) {}
MKLDNNStream(): s(CpuEngine::Get()->get_engine()) {};

void RegisterPrimArgs(const mkldnn::primitive &prim,
const mkldnn_args_map_t &args) {
net_prim_args.push_back(std::make_pair(prim, args));
net_prim_args.emplace_back(prim, args);
}

void RegisterMem(std::shared_ptr<const mkldnn::memory> mem) {
Expand Down Expand Up @@ -399,7 +399,7 @@ typedef std::pair<OutDataOp, mkldnn::memory *> mkldnn_output_t;
void MKLDNNCopy(const mkldnn::memory &mem, const mkldnn::memory* this_mem);

/*
* Here we want to get MKLDNN memory whose primitive desc is exactly the same as
* Here we want to get MKLDNN memory whose desc is exactly the same as
* the given one. operator== can't guarantee that. == can return true even if
* the formats are different. I need to double check its format.
*/
Expand Down Expand Up @@ -496,7 +496,7 @@ inline bool same_shape(const mxnet::TShape &shape, int dtype,
}

/*
* There is a large overhead of getting mkldnn::memory::primitive_desc from
* There is a large overhead of getting mkldnn::memory::desc from
* mkldnn::memory. This class is created to cache the metadata of mkldnn memory
* to provide a much more lightweight method to access them.
*/
Expand Down

0 comments on commit 2c1cece

Please sign in to comment.