Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[Numpy] Loading numpy-incompatible NDArray in numpy-compatible mode #16597

Merged
merged 2 commits into from
Oct 24, 2019
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1170,7 +1170,7 @@ MXNET_DLL int MXAutogradIsTraining(bool* curr);
* \param curr returns the current status
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXIsNumpyShape(bool* curr);
MXNET_DLL int MXIsNumpyShape(int* curr);
/*!
* \brief set numpy compatibility switch
* \param is_np_shape 1 when numpy shape semantics is thread local on,
Expand Down
10 changes: 6 additions & 4 deletions include/mxnet/imperative.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,12 +108,14 @@ class Imperative {
is_recording_ = is_recording;
return old;
}
/*! \brief whether numpy compatibility is on. */
bool is_np_shape() const {
/*! \brief return current numpy compatibility status,
* GlobalOn(2), ThreadLocalOn(1), Off(0).
* */
int is_np_shape() const {
if (is_np_shape_global_) {
return true;
return 2;
}
return is_np_shape_thread_local_;
return is_np_shape_thread_local_ ? 1 : 0;
}
/*! \brief specify numpy compatibility off, thread local on or global on. */
bool set_is_np_shape(int is_np_shape) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2778,7 +2778,7 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxDumpProfile
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxIsNumpyShape
(JNIEnv *env, jobject obj, jobject compatibleRef) {
bool isNumpyShape;
haojin2 marked this conversation as resolved.
Show resolved Hide resolved
int ret = MXIsNumpyShape(&isNumpyShape);
int ret = MXIsNumpyShape(static_cast<int*>(&isNumpyShape));
SetIntField(env, compatibleRef, static_cast<int>(isNumpyShape));
return ret;
}
Expand Down
2 changes: 1 addition & 1 deletion src/c_api/c_api_ndarray.cc
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ int MXAutogradSetIsRecording(int is_recording, int* prev) {
API_END();
}

int MXIsNumpyShape(bool* curr) {
int MXIsNumpyShape(int* curr) {
API_BEGIN();
*curr = Imperative::Get()->is_np_shape();
API_END();
Expand Down
3 changes: 2 additions & 1 deletion src/ndarray/ndarray.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1732,7 +1732,8 @@ bool NDArray::Load(dmlc::Stream *strm) {
" Please turn on np shape semantics in Python using `with np_shape(True)`"
" or decorator `use_np_shape` to scope the code of loading the ndarray.";
} else {
CHECK(!Imperative::Get()->is_np_shape())
// when the flag is global on, skip the check since it would be always global on.
CHECK(Imperative::Get()->is_np_shape() == GlobalOn || !Imperative::Get()->is_np_shape())
<< "ndarray was not saved in np shape semantics, but being loaded in np shape semantics."
" Please turn off np shape semantics in Python using `with np_shape(False)`"
" to scope the code of loading the ndarray.";
Expand Down