Skip to content

Commit

Permalink
[refactor] Simplify logic in get_function_body.
Browse files Browse the repository at this point in the history
  • Loading branch information
Ailing Zhang authored and taichi-gardener committed Nov 2, 2022
1 parent 8e5e002 commit c99938f
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 35 deletions.
52 changes: 17 additions & 35 deletions python/taichi/lang/kernel_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,11 @@ def extract_arg(arg, anno):
arg.m), Layout.AOS
# external arrays
element_dim = 0 if anno.element_dim is None else anno.element_dim
shape = tuple(arg.shape)
shape = getattr(arg, 'shape', None)
if shape is None:
raise TaichiRuntimeTypeError(
f"Invalid argument into ti.types.ndarray(), got {arg}")
shape = tuple(shape)
if len(shape) < element_dim:
raise ValueError(
f"Invalid argument into ti.types.ndarray() - required element_dim={element_dim}, "
Expand Down Expand Up @@ -573,7 +577,7 @@ def taichi_ast_generator(kernel_cxx):
taichi_kernel)
self.compiled_kernels[key] = taichi_kernel

def get_torch_callbacks(self, v, has_torch, is_ndarray=True):
def get_torch_callbacks(self, v):
callbacks = []

def get_call_back(u, v):
Expand All @@ -582,8 +586,6 @@ def call_back():

return call_back

assert has_torch
assert isinstance(v, torch.Tensor)
if not v.is_contiguous():
raise ValueError(
"Non contiguous tensors are not supported, please call tensor.contiguous() before passing it into taichi kernel."
Expand All @@ -600,7 +602,7 @@ def call_back():
callbacks.append(get_call_back(v, host_v))
return tmp, callbacks

def get_paddle_callbacks(self, v, has_pp):
def get_paddle_callbacks(self, v):
callbacks = []

def get_call_back(u, v):
Expand All @@ -609,9 +611,6 @@ def call_back():

return call_back

assert has_pp
assert isinstance(v, paddle.Tensor)

tmp = v.value().get_tensor()
taichi_arch = self.runtime.prog.config().arch

Expand Down Expand Up @@ -645,8 +644,6 @@ def func__(*args):

tmps = []
callbacks = []
has_torch = has_pytorch()
has_pp = has_paddle()

actual_argument_slot = 0
launch_ctx = t_kernel.make_launch_context()
Expand Down Expand Up @@ -685,14 +682,8 @@ def func__(*args):
texture_type.RWTextureType) and isinstance(
v, taichi.lang._texture.Texture):
launch_ctx.set_arg_rw_texture(actual_argument_slot, v.tex)
elif isinstance(
needed,
ndarray_type.NdarrayType) and (self.match_ext_arr(v)):
is_numpy = isinstance(v, np.ndarray)
is_torch = isinstance(v,
torch.Tensor) if has_torch else False

# Element shapes are already spcialized in Taichi codegen.
elif isinstance(needed, ndarray_type.NdarrayType):
# Element shapes are already specialized in Taichi codegen.
# The shape information for element dims are no longer needed.
# Therefore we strip the element shapes from the shape vector,
# so that it only holds "real" array shapes.
Expand All @@ -702,7 +693,7 @@ def func__(*args):
if element_dim:
array_shape = v.shape[
element_dim:] if is_soa else v.shape[:-element_dim]
if is_numpy:
if isinstance(v, np.ndarray):
if v.flags.c_contiguous:
launch_ctx.set_arg_external_array_with_shape(
actual_argument_slot, int(v.ctypes.data),
Expand All @@ -725,22 +716,22 @@ def callback(original, updated):
raise ValueError(
"Non contiguous numpy arrays are not supported, please call np.ascontiguousarray(arr) before passing it into taichi kernel."
)
elif is_torch:
is_ndarray = False
tmp, torch_callbacks = self.get_torch_callbacks(
v, has_torch, is_ndarray)
elif has_pytorch() and isinstance(v, torch.Tensor):
tmp, torch_callbacks = self.get_torch_callbacks(v)
callbacks += torch_callbacks
launch_ctx.set_arg_external_array_with_shape(
actual_argument_slot, int(tmp.data_ptr()),
tmp.element_size() * tmp.nelement(), array_shape)
else:
elif has_paddle() and isinstance(v, paddle.Tensor):
# For now, paddle.fluid.core.Tensor._ptr() is only available on develop branch
tmp, paddle_callbacks = self.get_paddle_callbacks(
v, has_pp)
tmp, paddle_callbacks = self.get_paddle_callbacks(v)
callbacks += paddle_callbacks
launch_ctx.set_arg_external_array_with_shape(
actual_argument_slot, int(tmp._ptr()),
v.element_size() * v.size, array_shape)
else:
raise TaichiRuntimeTypeError.get(
i, needed.to_string(), v)

elif isinstance(needed, MatrixType):
if needed.dtype in primitive_types.real_types:
Expand Down Expand Up @@ -832,15 +823,6 @@ def callback(original, updated):

return func__

@staticmethod
def match_ext_arr(v):
has_array = isinstance(v, np.ndarray)
if not has_array and has_pytorch():
has_array = isinstance(v, torch.Tensor)
if not has_array and has_paddle():
has_array = isinstance(v, paddle.Tensor)
return has_array

def ensure_compiled(self, *args):
instance_id, arg_features = self.mapper.lookup(args)
key = (self.func, instance_id, self.autodiff_mode)
Expand Down
28 changes: 28 additions & 0 deletions tests/python/test_kernel_arg_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,34 @@ def foo(a: ti.i32):
foo(1.2)


@test_utils.test(arch=ti.cpu)
def test_pass_float_as_ndarray():
@ti.kernel
def foo(a: ti.types.ndarray()):
pass

with pytest.raises(
ti.TaichiRuntimeTypeError,
match=r"Invalid argument into ti.types.ndarray\(\), got 1.2"):
foo(1.2)


@test_utils.test(arch=ti.cpu)
def test_random_python_class_as_ndarray():
@ti.kernel
def foo(a: ti.types.ndarray()):
pass

class Bla:
pass

with pytest.raises(
ti.TaichiRuntimeTypeError,
match=r"Invalid argument into ti.types.ndarray\(\), got"):
b = Bla()
foo(b)


@test_utils.test(exclude=[ti.metal])
def test_pass_u64():
if ti.lang.impl.current_cfg().arch == ti.vulkan and platform.system(
Expand Down

0 comments on commit c99938f

Please sign in to comment.