Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[Numpy] Loading numpy-incompatible NDArray in numpy-compatible mode #16597

Merged
merged 2 commits into from
Oct 24, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1170,7 +1170,7 @@ MXNET_DLL int MXAutogradIsTraining(bool* curr);
* \param curr returns the current status
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXIsNumpyShape(bool* curr);
MXNET_DLL int MXIsNumpyShape(int* curr);
/*!
* \brief set numpy compatibility switch
* \param is_np_shape 1 when numpy shape semantics is thread local on,
Expand Down
10 changes: 6 additions & 4 deletions include/mxnet/imperative.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,12 +108,14 @@ class Imperative {
is_recording_ = is_recording;
return old;
}
/*! \brief whether numpy compatibility is on. */
bool is_np_shape() const {
/*! \brief return current numpy compatibility status,
* GlobalOn(2), ThreadLocalOn(1), Off(0).
* */
int is_np_shape() const {
if (is_np_shape_global_) {
return true;
return 2;
}
return is_np_shape_thread_local_;
return is_np_shape_thread_local_ ? 1 : 0;
}
/*! \brief specify numpy compatibility off, thread local on or global on. */
bool set_is_np_shape(int is_np_shape) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2777,9 +2777,9 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxDumpProfile
// Numpy
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxIsNumpyShape
(JNIEnv *env, jobject obj, jobject compatibleRef) {
bool isNumpyShape;
int isNumpyShape;
int ret = MXIsNumpyShape(&isNumpyShape);
SetIntField(env, compatibleRef, static_cast<int>(isNumpyShape));
SetIntField(env, compatibleRef, isNumpyShape);
return ret;
}

Expand Down
2 changes: 1 addition & 1 deletion src/c_api/c_api_ndarray.cc
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ int MXAutogradSetIsRecording(int is_recording, int* prev) {
API_END();
}

int MXIsNumpyShape(bool* curr) {
int MXIsNumpyShape(int* curr) {
API_BEGIN();
*curr = Imperative::Get()->is_np_shape();
API_END();
Expand Down
3 changes: 2 additions & 1 deletion src/ndarray/ndarray.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1732,7 +1732,8 @@ bool NDArray::Load(dmlc::Stream *strm) {
" Please turn on np shape semantics in Python using `with np_shape(True)`"
" or decorator `use_np_shape` to scope the code of loading the ndarray.";
} else {
CHECK(!Imperative::Get()->is_np_shape())
// when the flag is global on, skip the check since it would be always global on.
CHECK(Imperative::Get()->is_np_shape() == GlobalOn || !Imperative::Get()->is_np_shape())
<< "ndarray was not saved in np shape semantics, but being loaded in np shape semantics."
" Please turn off np shape semantics in Python using `with np_shape(False)`"
" to scope the code of loading the ndarray.";
Expand Down