diff --git a/python/tvm/_ffi/_ctypes/ndarray.py b/python/tvm/_ffi/_ctypes/ndarray.py index 37a18cbe4051..da24b9cd41eb 100644 --- a/python/tvm/_ffi/_ctypes/ndarray.py +++ b/python/tvm/_ffi/_ctypes/ndarray.py @@ -24,6 +24,8 @@ def _from_dlpack(dltensor): dltensor = ctypes.py_object(dltensor) if ctypes.pythonapi.PyCapsule_IsValid(dltensor, _c_str_dltensor): ptr = ctypes.pythonapi.PyCapsule_GetPointer(dltensor, _c_str_dltensor) + # enforce type to make sure it works for all ctypes + ptr = ctypes.cast(ptr, ctypes.c_void_p) handle = TVMArrayHandle() check_call(_LIB.TVMArrayFromDLPack(ptr, ctypes.byref(handle))) ctypes.pythonapi.PyCapsule_SetName(dltensor, _c_str_used_dltensor) @@ -36,6 +38,8 @@ def _dlpack_deleter(pycapsule): pycapsule = ctypes.cast(pycapsule, ctypes.py_object) if ctypes.pythonapi.PyCapsule_IsValid(pycapsule, _c_str_dltensor): ptr = ctypes.pythonapi.PyCapsule_GetPointer(pycapsule, _c_str_dltensor) + # enforce type to make sure it works for all ctypes + ptr = ctypes.cast(ctypes.c_void_p, ptr) _LIB.TVMDLManagedTensorCallDeleter(ptr) ctypes.pythonapi.PyCapsule_SetDestructor(dltensor, TVMPyCapsuleDestructor(0))