Skip to content

Commit

Permalink
[FFI][BUGFIX] Fix memory leak when Pac callback argument is NDArray (#…
Browse files Browse the repository at this point in the history
…6744)

* [FFI][BUGFIX] Fix leak when Packed callback arg is ndarray.

Co-authored-by: Matthew Brookhart <[email protected]>

* Fix for rust ts and jvm

* Update rust/tvm-rt/src/to_function.rs

Co-authored-by: Junru Shao <[email protected]>

Co-authored-by: Matthew Brookhart <[email protected]>
Co-authored-by: Junru Shao <[email protected]>
  • Loading branch information
3 people authored Oct 24, 2020
1 parent bb4179e commit fc69f68
Show file tree
Hide file tree
Showing 7 changed files with 25 additions and 2 deletions.
3 changes: 2 additions & 1 deletion jvm/native/src/main/native/org_apache_tvm_native_c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,8 @@ extern "C" int funcInvokeCallback(TVMValue* args, int* typeCodes, int numArgs,
TVMValue arg = args[i];
int tcode = typeCodes[i];
if (tcode == kTVMObjectHandle || tcode == kTVMPackedFuncHandle ||
tcode == kTVMObjectRValueRefArg || tcode == kTVMModuleHandle) {
tcode == kTVMObjectRValueRefArg || tcode == kTVMModuleHandle ||
tcode == kTVMNDArrayHandle) {
TVMCbArgToReturn(&arg, &tcode);
}
jobject jarg = tvmRetValueToJava(env, arg, tcode);
Expand Down
4 changes: 3 additions & 1 deletion python/tvm/_ffi/_ctypes/packed_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,9 @@ def _get_global_func(name, allow_missing=False):
_return_module, ArgTypeCode.MODULE_HANDLE
)
C_TO_PY_ARG_SWITCH[ArgTypeCode.DLTENSOR_HANDLE] = lambda x: _make_array(x.v_handle, True, False)
C_TO_PY_ARG_SWITCH[ArgTypeCode.NDARRAY_HANDLE] = lambda x: _make_array(x.v_handle, False, True)
C_TO_PY_ARG_SWITCH[ArgTypeCode.NDARRAY_HANDLE] = _wrap_arg_func(
lambda x: _make_array(x.v_handle, False, True), ArgTypeCode.NDARRAY_HANDLE
)

_CLASS_MODULE = None
_CLASS_PACKED_FUNC = None
Expand Down
4 changes: 4 additions & 0 deletions python/tvm/_ffi/_cython/ndarray.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ cdef class NDArrayBase:
def __set__(self, value):
self._set_handle(value)

property is_view:
def __get__(self):
return self.c_is_view != 0

@property
def shape(self):
"""Shape of this array"""
Expand Down
1 change: 1 addition & 0 deletions python/tvm/_ffi/_cython/packed_func.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ cdef int tvm_callback(TVMValue* args,
if (tcode == kTVMObjectHandle or
tcode == kTVMPackedFuncHandle or
tcode == kTVMModuleHandle or
tcode == kTVMNDArrayHandle or
tcode == kTVMObjectRefArg or
tcode > kTVMExtBegin):
CALL(TVMCbArgToReturn(&value, &tcode))
Expand Down
2 changes: 2 additions & 0 deletions rust/tvm-rt/src/to_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,10 @@ pub trait ToFunction<I, O>: Sized {
value = args_list[i];
tcode = type_codes_list[i];
if tcode == ffi::TVMArgTypeCode_kTVMObjectHandle as c_int
|| tcode == ffi::TVMArgTypeCode_kTVMObjectRValueRefArg as c_int
|| tcode == ffi::TVMArgTypeCode_kTVMPackedFuncHandle as c_int
|| tcode == ffi::TVMArgTypeCode_kTVMModuleHandle as c_int
|| tcode == ffi::TVMArgTypeCode_kTVMNDArrayHandle as c_int
{
check_call!(ffi::TVMCbArgToReturn(
&mut value as *mut _,
Expand Down
12 changes: 12 additions & 0 deletions tests/python/unittest/test_runtime_packed_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,19 @@ def test_numpy_scalar():
assert tvm.testing.echo(np.int64(maxint)) == maxint


def test_ndarray_args():
def check(arr):
assert not arr.is_view
assert tvm.testing.object_use_count(arr) == 2

fcheck = tvm.runtime.convert(check)
x = tvm.nd.array([1, 2, 3])
fcheck(x)
assert tvm.testing.object_use_count(x) == 1


if __name__ == "__main__":
test_ndarray_args()
test_numpy_scalar()
test_rvalue_ref()
test_empty_array()
Expand Down
1 change: 1 addition & 0 deletions web/src/runtime.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1216,6 +1216,7 @@ export class Instance implements Disposable {
tcode == ArgTypeCode.TVMObjectHandle ||
tcode == ArgTypeCode.TVMObjectRValueRefArg ||
tcode == ArgTypeCode.TVMPackedFuncHandle ||
tcode == ArgTypeCode.TVMNDArrayHandle ||
tcode == ArgTypeCode.TVMModuleHandle
) {
lib.checkCall(
Expand Down

0 comments on commit fc69f68

Please sign in to comment.