From c99938f98b3c3e87b4c0f250c0fee573a6c45ab7 Mon Sep 17 00:00:00 2001
From: Ailing Zhang
Date: Wed, 2 Nov 2022 11:51:55 +0800
Subject: [PATCH] [refactor] Simplify logic in get_function_body.

related: #5662
---
 python/taichi/lang/kernel_impl.py      | 52 +++++++++-----------------
 tests/python/test_kernel_arg_errors.py | 28 ++++++++++++++
 2 files changed, 45 insertions(+), 35 deletions(-)

diff --git a/python/taichi/lang/kernel_impl.py b/python/taichi/lang/kernel_impl.py
index 357655caa8e2d..563f2237dabbd 100644
--- a/python/taichi/lang/kernel_impl.py
+++ b/python/taichi/lang/kernel_impl.py
@@ -383,7 +383,11 @@ def extract_arg(arg, anno):
                                             arg.m), Layout.AOS
         # external arrays
         element_dim = 0 if anno.element_dim is None else anno.element_dim
-        shape = tuple(arg.shape)
+        shape = getattr(arg, 'shape', None)
+        if shape is None:
+            raise TaichiRuntimeTypeError(
+                f"Invalid argument into ti.types.ndarray(), got {arg}")
+        shape = tuple(shape)
         if len(shape) < element_dim:
             raise ValueError(
                 f"Invalid argument into ti.types.ndarray() - required element_dim={element_dim}, "
@@ -573,7 +577,7 @@ def taichi_ast_generator(kernel_cxx):
             taichi_kernel)
         self.compiled_kernels[key] = taichi_kernel
 
-    def get_torch_callbacks(self, v, has_torch, is_ndarray=True):
+    def get_torch_callbacks(self, v):
         callbacks = []
 
         def get_call_back(u, v):
@@ -582,8 +586,6 @@ def call_back():
             return call_back
 
-        assert has_torch
-        assert isinstance(v, torch.Tensor)
         if not v.is_contiguous():
             raise ValueError(
                 "Non contiguous tensors are not supported, please call tensor.contiguous() before passing it into taichi kernel."
             )
@@ -600,7 +602,7 @@ def call_back():
             callbacks.append(get_call_back(v, host_v))
         return tmp, callbacks
 
-    def get_paddle_callbacks(self, v, has_pp):
+    def get_paddle_callbacks(self, v):
         callbacks = []
 
         def get_call_back(u, v):
@@ -609,9 +611,6 @@ def call_back():
 
             return call_back
 
-        assert has_pp
-        assert isinstance(v, paddle.Tensor)
-
         tmp = v.value().get_tensor()
         taichi_arch = self.runtime.prog.config().arch
 
@@ -645,8 +644,6 @@ def func__(*args):
 
             tmps = []
             callbacks = []
-            has_torch = has_pytorch()
-            has_pp = has_paddle()
 
             actual_argument_slot = 0
             launch_ctx = t_kernel.make_launch_context()
@@ -685,14 +682,8 @@ def func__(*args):
                                 texture_type.RWTextureType) and isinstance(
                                     v, taichi.lang._texture.Texture):
                     launch_ctx.set_arg_rw_texture(actual_argument_slot, v.tex)
-                elif isinstance(
-                        needed,
-                        ndarray_type.NdarrayType) and (self.match_ext_arr(v)):
-                    is_numpy = isinstance(v, np.ndarray)
-                    is_torch = isinstance(v,
-                                          torch.Tensor) if has_torch else False
-
-                    # Element shapes are already spcialized in Taichi codegen.
+                elif isinstance(needed, ndarray_type.NdarrayType):
+                    # Element shapes are already specialized in Taichi codegen.
                     # The shape information for element dims are no longer needed.
                     # Therefore we strip the element shapes from the shape vector,
                     # so that it only holds "real" array shapes.
@@ -702,7 +693,7 @@ def func__(*args):
                     if element_dim:
                         array_shape = v.shape[
                             element_dim:] if is_soa else v.shape[:-element_dim]
-                    if is_numpy:
+                    if isinstance(v, np.ndarray):
                         if v.flags.c_contiguous:
                             launch_ctx.set_arg_external_array_with_shape(
                                 actual_argument_slot, int(v.ctypes.data),
@@ -725,22 +716,22 @@ def callback(original, updated):
                             raise ValueError(
                                 "Non contiguous numpy arrays are not supported, please call np.ascontiguousarray(arr) before passing it into taichi kernel."
                             )
-                    elif is_torch:
-                        is_ndarray = False
-                        tmp, torch_callbacks = self.get_torch_callbacks(
-                            v, has_torch, is_ndarray)
+                    elif has_pytorch() and isinstance(v, torch.Tensor):
+                        tmp, torch_callbacks = self.get_torch_callbacks(v)
                         callbacks += torch_callbacks
                         launch_ctx.set_arg_external_array_with_shape(
                             actual_argument_slot, int(tmp.data_ptr()),
                             tmp.element_size() * tmp.nelement(), array_shape)
-                    else:
+                    elif has_paddle() and isinstance(v, paddle.Tensor):
                         # For now, paddle.fluid.core.Tensor._ptr() is only available on develop branch
-                        tmp, paddle_callbacks = self.get_paddle_callbacks(
-                            v, has_pp)
+                        tmp, paddle_callbacks = self.get_paddle_callbacks(v)
                         callbacks += paddle_callbacks
                         launch_ctx.set_arg_external_array_with_shape(
                             actual_argument_slot, int(tmp._ptr()),
                             v.element_size() * v.size, array_shape)
+                    else:
+                        raise TaichiRuntimeTypeError.get(
+                            i, needed.to_string(), v)
 
                 elif isinstance(needed, MatrixType):
                     if needed.dtype in primitive_types.real_types:
@@ -832,15 +823,6 @@ def callback(original, updated):
 
         return func__
 
-    @staticmethod
-    def match_ext_arr(v):
-        has_array = isinstance(v, np.ndarray)
-        if not has_array and has_pytorch():
-            has_array = isinstance(v, torch.Tensor)
-        if not has_array and has_paddle():
-            has_array = isinstance(v, paddle.Tensor)
-        return has_array
-
     def ensure_compiled(self, *args):
         instance_id, arg_features = self.mapper.lookup(args)
         key = (self.func, instance_id, self.autodiff_mode)
diff --git a/tests/python/test_kernel_arg_errors.py b/tests/python/test_kernel_arg_errors.py
index 23a495a371d04..5c6bace947a01 100644
--- a/tests/python/test_kernel_arg_errors.py
+++ b/tests/python/test_kernel_arg_errors.py
@@ -20,6 +20,34 @@ def foo(a: ti.i32):
         foo(1.2)
 
 
+@test_utils.test(arch=ti.cpu)
+def test_pass_float_as_ndarray():
+    @ti.kernel
+    def foo(a: ti.types.ndarray()):
+        pass
+
+    with pytest.raises(
+            ti.TaichiRuntimeTypeError,
+            match=r"Invalid argument into ti.types.ndarray\(\), got 1.2"):
+        foo(1.2)
+
+
+@test_utils.test(arch=ti.cpu)
+def test_random_python_class_as_ndarray():
+    @ti.kernel
+    def foo(a: ti.types.ndarray()):
+        pass
+
+    class Bla:
+        pass
+
+    with pytest.raises(
+            ti.TaichiRuntimeTypeError,
+            match=r"Invalid argument into ti.types.ndarray\(\), got"):
+        b = Bla()
+        foo(b)
+
+
 @test_utils.test(exclude=[ti.metal])
 def test_pass_u64():
     if ti.lang.impl.current_cfg().arch == ti.vulkan and platform.system(
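
A minimal usage sketch of the behavior this change covers, assuming a CPU build. The kernel
body, the array size, and the helper name `fill` are illustrative only; the error type and
message are the ones asserted by the new tests, and real external arrays keep going through
the np.ndarray / torch.Tensor / paddle.Tensor branches as before.

    import numpy as np
    import taichi as ti

    ti.init(arch=ti.cpu)

    @ti.kernel
    def fill(a: ti.types.ndarray()):
        for i in range(a.shape[0]):
            a[i] = i

    # A numpy array still takes the np.ndarray branch in func__.
    fill(np.zeros(8, dtype=np.float32))

    try:
        # A plain float has no .shape attribute, so extract_arg rejects it up front.
        fill(1.2)
    except ti.TaichiRuntimeTypeError as e:
        print(e)  # "Invalid argument into ti.types.ndarray(), got 1.2"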