From 274c4014a561ef2e6707a012c4b386919b877a16 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Sat, 9 Mar 2019 15:58:14 -0500 Subject: [PATCH] [DLPACK] fix flaky ctypes support (#2759) --- python/tvm/_ffi/_ctypes/ndarray.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/tvm/_ffi/_ctypes/ndarray.py b/python/tvm/_ffi/_ctypes/ndarray.py index 37a18cbe4051..da24b9cd41eb 100644 --- a/python/tvm/_ffi/_ctypes/ndarray.py +++ b/python/tvm/_ffi/_ctypes/ndarray.py @@ -24,6 +24,8 @@ def _from_dlpack(dltensor): dltensor = ctypes.py_object(dltensor) if ctypes.pythonapi.PyCapsule_IsValid(dltensor, _c_str_dltensor): ptr = ctypes.pythonapi.PyCapsule_GetPointer(dltensor, _c_str_dltensor) + # enforce type to make sure it works for all ctypes + ptr = ctypes.cast(ptr, ctypes.c_void_p) handle = TVMArrayHandle() check_call(_LIB.TVMArrayFromDLPack(ptr, ctypes.byref(handle))) ctypes.pythonapi.PyCapsule_SetName(dltensor, _c_str_used_dltensor) @@ -36,6 +38,8 @@ def _dlpack_deleter(pycapsule): pycapsule = ctypes.cast(pycapsule, ctypes.py_object) if ctypes.pythonapi.PyCapsule_IsValid(pycapsule, _c_str_dltensor): ptr = ctypes.pythonapi.PyCapsule_GetPointer(pycapsule, _c_str_dltensor) + # enforce type to make sure it works for all ctypes + ptr = ctypes.cast(ctypes.c_void_p, ptr) _LIB.TVMDLManagedTensorCallDeleter(ptr) ctypes.pythonapi.PyCapsule_SetDestructor(dltensor, TVMPyCapsuleDestructor(0))