This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

NDArray saved in Python cannot be loaded in Scala #10338

Closed
vascokk opened this issue Mar 30, 2018 · 2 comments · Fixed by #13678

Comments

vascokk commented Mar 30, 2018

Description

I am trying to convert a NumPy array into an NDArray and save it. I intend to load the saved NDArray file in Scala, but I get this error: src/ndarray/ndarray.cc:672: Check failed: fi->Read(data) Invalid NDArray file format.

Environment info (Required)

Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
CPU(s): 4
On-line CPU(s) list: 0-3
Thread(s) per core: 1
Core(s) per socket: 4
Socket(s): 1
NUMA node(s): 1
Vendor ID: GenuineIntel
CPU family: 6
Model: 60
Model name: Intel(R) Core(TM) i5-4570 CPU @ 3.20GHz
Stepping: 3
CPU MHz: 3192.584
CPU max MHz: 3600.0000
CPU min MHz: 800.0000
BogoMIPS: 6385.16
Virtualization: VT-x
L1d cache: 32K
L1i cache: 32K
L2 cache: 256K
L3 cache: 6144K
NUMA node0 CPU(s): 0-3
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm abm cpuid_fault epb invpcid_single pti retpoline tpr_shadow vnmi flexpriority ept vpid fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm xsaveopt dtherm ida arat pln pts
----------Python Info----------
('Version :', '2.7.14')
('Compiler :', 'GCC 7.2.0')
('Build :', ('default', 'Sep 23 2017 22:06:14'))
('Arch :', ('64bit', ''))
------------Pip Info-----------
('Version :', '9.0.3')
('Directory :', '/usr/local/lib/python2.7/dist-packages/pip')
----------MXNet Info-----------
('Version :', '1.1.0')
('Directory :', '/usr/local/lib/python2.7/dist-packages/mxnet')
('Commit Hash :', '07a83a0325a3d782513a04f47d711710972cb144')
----------System Info----------
('Platform :', 'Linux-4.13.0-37-generic-x86_64-with-Ubuntu-17.10-artful')
('system :', 'Linux')
('node :', 'desktop-102016')
('release :', '4.13.0-37-generic')
('version :', '#42-Ubuntu SMP Wed Mar 7 14:13:23 UTC 2018')
----------Hardware Info----------
('machine :', 'x86_64')
('processor :', 'x86_64')
----------Network Test----------
Setting timeout: 10
Timing for MXNet: https://github.com/apache/incubator-mxnet, DNS: 0.0008 sec, LOAD: 0.6838 sec.
Timing for PYPI: https://pypi.python.org/pypi/pip, DNS: 0.0010 sec, LOAD: 0.1299 sec.
Timing for FashionMNIST: https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/dataset/fashion-mnist/train-labels-idx1-ubyte.gz, DNS: 0.0004 sec, LOAD: 0.8468 sec.
Timing for Conda: https://repo.continuum.io/pkgs/free/, DNS: 0.0006 sec, LOAD: 0.0610 sec.
Timing for Gluon Tutorial(en): http://gluon.mxnet.io, DNS: 0.0014 sec, LOAD: 0.0648 sec.
Timing for Gluon Tutorial(cn): https://zh.gluon.ai, DNS: 0.0016 sec, LOAD: 0.7765 sec.

Package used (Python/R/Scala/Julia):
Python 2.7

For Scala user, please provide:

  1. Java version: (java -version):

java version "1.8.0_144"
Java(TM) SE Runtime Environment (build 1.8.0_144-b01)
Java HotSpot(TM) 64-Bit Server VM (build 25.144-b01, mixed mode)

  2. Maven version: (mvn -version)
    Apache Maven 3.5.0
    Maven home: /usr/share/maven
    Java version: 1.8.0_112, vendor: Oracle Corporation
    Java home: /usr/local/java/jdk1.8.0_112/jre
    Default locale: en_IE, platform encoding: UTF-8
    OS name: "linux", version: "4.13.0-37-generic", arch: "amd64", family: "unix"

  3. Scala runtime if applicable: (scala -version)
    Scala code runner version 2.11.8 -- Copyright 2002-2016, LAMP/EPFL

For R user, please provide R sessionInfo():

Build info (Required if built from source)

Compiler (gcc/clang/mingw/visual studio):

MXNet commit hash:
(Paste the output of git rev-parse HEAD here.)

Build config:
(Paste the content of config.mk, or the build command.)

Error Message:

[10:10:45] /home/ubuntu/mxnet/dmlc-core/include/dmlc/logging.h:300: [10:10:45] src/ndarray/ndarray.cc:672: Check failed: fi->Read(data) Invalid NDArray file format

Stack trace returned 5 entries:
[bt] (0) /tmp/mxnet7044930640479576680/mxnet-scala(_ZN4dmlc15LogMessageFatalD1Ev+0x39) [0x7f529b12e6b9]
[bt] (1) /tmp/mxnet7044930640479576680/mxnet-scala(_ZN5mxnet7NDArray4LoadEPN4dmlc6StreamEPSt6vectorIS0_SaIS0_EEPS4_INSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEESaISD_EE+0x207) [0x7f529ba4e747]
[bt] (2) /tmp/mxnet7044930640479576680/mxnet-scala(MXNDArrayLoad+0x4cd) [0x7f529bd7c26d]
[bt] (3) /tmp/mxnet7044930640479576680/mxnet-scala(Java_ml_dmlc_mxnet_LibInfo_mxNDArrayLoad+0x64) [0x7f529b122254]
[bt] (4) [0x7f52f9017a34]

Exception in thread "main" ml.dmlc.mxnet.MXNetError: [10:10:45] src/ndarray/ndarray.cc:672: Check failed: fi->Read(data) Invalid NDArray file format

Minimum reproducible example

(If you are using your own code, please provide a short script that reproduces the error. Otherwise, please provide link to the existing example.)

Steps to reproduce

(Paste the commands you ran that produced the error.)

$ python
Python 2.7.14 (default, Sep 23 2017, 22:06:14)
[GCC 7.2.0] on linux2
Type "help", "copyright", "credits" or "license" for more information.

import mxnet as mx
a = mx.nd.ones((100, 50))
mx.nd.save("a", [a])

then in Scala:

import ml.dmlc.mxnet._
val a = NDArray.load("a")

What have you tried to solve it?

Member

gigasquid commented May 7, 2018

I did some investigating into this after running into a problem on the Clojure side. I'll add it here in case it helps:

  • Saving as a dict seems to work fine: mx.nd.save("a2", {'a': a})

Also related: not all of the Python NDArray dtypes are supported on the Scala side.

The Python NDArray dtypes are:

# pylint: disable= no-member
_DTYPE_NP_TO_MX = {
    None: -1,
    np.float32: 0,
    np.float64: 1,
    np.float16: 2,
    np.uint8: 3,
    np.int32: 4,
    np.int8: 5,
    np.int64: 6,
}

While the Scala dtypes are:

object DType extends Enumeration {
  type DType = Value
  val Float32 = Value(0, "float32")
  val Float64 = Value(1, "float64")
  val Float16 = Value(2, "float16")
  val UInt8 = Value(3, "uint8")
  val Int32 = Value(4, "int32")
  val Unknown = Value(-1, "unknown")
  private[mxnet] def numOfBytes(dtype: DType): Int = {
    dtype match {
      case DType.UInt8 => 1
      case DType.Int32 => 4
      case DType.Float16 => 2
      case DType.Float32 => 4
      case DType.Float64 => 8
      case DType.Unknown => 0
    }
  }

For example, if a user saves a Python NDArray as int64, they will not be able to use it from Scala. Maybe an error message on load would be good?
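Until the loader reports this itself, a check on the Python side before saving could catch the mismatch early. The sketch below only encodes the two dtype tables quoted above; the function names are hypothetical and are not part of MXNet, and the tables reflect the versions discussed in this issue, not necessarily later releases.

```python
# dtype codes copied from the two listings in this issue.
# Python side: _DTYPE_NP_TO_MX (keyed here by dtype name instead of NumPy type).
PYTHON_DTYPE_CODES = {
    "float32": 0,
    "float64": 1,
    "float16": 2,
    "uint8": 3,
    "int32": 4,
    "int8": 5,
    "int64": 6,
}

# Scala side: the values declared in object DType (excluding Unknown).
SCALA_DTYPE_CODES = {
    "float32": 0,
    "float64": 1,
    "float16": 2,
    "uint8": 3,
    "int32": 4,
}


def dtypes_missing_in_scala():
    """Return the dtype names Python can save but Scala does not declare."""
    return sorted(name for name in PYTHON_DTYPE_CODES if name not in SCALA_DTYPE_CODES)


def check_loadable_in_scala(dtype_name):
    """Hypothetical pre-save check: return the shared dtype code or raise."""
    if dtype_name not in PYTHON_DTYPE_CODES:
        raise ValueError("unknown dtype: %s" % dtype_name)
    if dtype_name not in SCALA_DTYPE_CODES:
        raise TypeError("dtype %s has no Scala DType counterpart" % dtype_name)
    return SCALA_DTYPE_CODES[dtype_name]
```

In practice one would call something like check_loadable_in_scala(str(a.dtype)) before mx.nd.save, assuming the array's dtype prints as one of the names above.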

Hodapp87 commented May 7, 2018

I did a quick fix below that seems to resolve it, but I have tested only with @gigasquid's Clojure bindings, and I still need to move this to upstream master and check the resultant values more carefully. An actual error message as she suggests would also be very helpful.

diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/DType.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/DType.scala
index bfe757d5..0fc53918 100644
--- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/DType.scala
+++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/DType.scala
@@ -24,6 +24,8 @@ object DType extends Enumeration {
   val Float16 = Value(2, "float16")
   val UInt8 = Value(3, "uint8")
   val Int32 = Value(4, "int32")
+  val Int8 = Value(5, "int8")
+  val Int64 = Value(6, "int64")
   private[mxnet] def numOfBytes(dtype: DType): Int = {
     dtype match {
       case DType.UInt8 => 1
@@ -31,6 +33,8 @@ object DType extends Enumeration {
       case DType.Float16 => 2
       case DType.Float32 => 4
       case DType.Float64 => 8
+      case DType.Int8 => 1
+      case DType.Int64 => 8
     }
   }
 }
diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala
index 7cfc059c..423040fd 100644
--- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala
+++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala
@@ -1167,8 +1167,8 @@ private[mxnet] class NDArrayInternal (private val internal: Array[Byte], private
     dtype match {
       case DType.Float32 => units.map(wrapBytes(_).getFloat.toDouble)
       case DType.Float64 => units.map(wrapBytes(_).getDouble)
-      case DType.Int32 => units.map(wrapBytes(_).getInt.toDouble)
-      case DType.UInt8 => internal.map(_.toDouble)
+      case DType.Int32 | DType.Int64 => units.map(wrapBytes(_).getInt.toDouble)
+      case DType.UInt8 | DType.Int8 => internal.map(_.toDouble)
     }
   }
   def toFloatArray: Array[Float] = {
@@ -1176,8 +1176,8 @@ private[mxnet] class NDArrayInternal (private val internal: Array[Byte], private
     dtype match {
       case DType.Float32 => units.map(wrapBytes(_).getFloat)
       case DType.Float64 => units.map(wrapBytes(_).getDouble.toFloat)
-      case DType.Int32 => units.map(wrapBytes(_).getInt.toFloat)
-      case DType.UInt8 => internal.map(_.toFloat)
+      case DType.Int32 | DType.Int64 => units.map(wrapBytes(_).getInt.toFloat)
+      case DType.UInt8 | DType.Int8 => internal.map(_.toFloat)
     }
   }
   def toIntArray: Array[Int] = {
@@ -1185,8 +1185,8 @@ private[mxnet] class NDArrayInternal (private val internal: Array[Byte], private
     dtype match {
       case DType.Float32 => units.map(wrapBytes(_).getFloat.toInt)
       case DType.Float64 => units.map(wrapBytes(_).getDouble.toInt)
-      case DType.Int32 => units.map(wrapBytes(_).getInt)
-      case DType.UInt8 => internal.map(_.toInt)
+      case DType.Int32 | DType.Int64 => units.map(wrapBytes(_).getInt)
+      case DType.UInt8 | DType.Int8 => internal.map(_.toInt)
     }
   }
   def toByteArray: Array[Byte] = {
@@ -1194,8 +1194,8 @@ private[mxnet] class NDArrayInternal (private val internal: Array[Byte], private
     dtype match {
       case DType.Float16 | DType.Float32 => units.map(wrapBytes(_).getFloat.toByte)
       case DType.Float64 => units.map(wrapBytes(_).getDouble.toByte)
-      case DType.Int32 => units.map(wrapBytes(_).getInt.toByte)
-      case DType.UInt8 => internal.clone()
+      case DType.Int32 | DType.Int64 => units.map(wrapBytes(_).getInt.toByte)
+      case DType.UInt8 | DType.Int8 => internal.clone()
     }
   }
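One thing that may be worth checking when verifying the resulting values: the patch declares Int64 as 8 bytes in numOfBytes but reads it with getInt, which consumes only 4 bytes of each unit. The Python sketch below mimics that with the struct module. It assumes little-endian data (an assumption; the byte order used by wrapBytes is not shown in this issue) and the helper names are made up for illustration. On the JVM, ByteBuffer.getLong would be the 8-byte counterpart.

```python
import struct


def read_first_4_bytes_as_int32(raw):
    """Mimic a 4-byte read (like getInt) on an 8-byte little-endian unit."""
    return struct.unpack_from("<i", raw, 0)[0]


def read_as_int64(raw):
    """Read all 8 bytes as a signed 64-bit integer (like getLong)."""
    return struct.unpack("<q", raw)[0]


# A value that fits in 32 bits survives the 4-byte read.
small = struct.pack("<q", 42)

# A value above 2**32 loses its high bytes when only 4 bytes are read.
large = struct.pack("<q", 2**32 + 7)

if __name__ == "__main__":
    print(read_first_4_bytes_as_int32(small), read_as_int64(small))
    print(read_first_4_bytes_as_int32(large), read_as_int64(large))
```

So the quick fix would likely return correct values for small int64 entries and silently truncate larger ones, which a test with only small numbers would not reveal.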
