Skip to content

Commit

Permalink
[refactor] Simplify callback logic for torch & paddle args (taichi-de…
Browse files Browse the repository at this point in the history
…v#6626)

related: taichi-dev#5662

We need runtime arg handling to be more modular to prepare for struct
args. Will send a few refactor PRs on top of this one.
  • Loading branch information
ailzhang authored and quadpixels committed May 13, 2023
1 parent d97ac48 commit f86861d
Showing 1 changed file with 46 additions and 62 deletions.
108 changes: 46 additions & 62 deletions python/taichi/lang/kernel_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,64 +593,6 @@ def taichi_ast_generator(kernel_cxx):
taichi_kernel)
self.compiled_kernels[key] = taichi_kernel

def get_torch_callbacks(self, v):
callbacks = []

def get_call_back(u, v):
def call_back():
u.copy_(v)

return call_back

if not v.is_contiguous():
raise ValueError(
"Non contiguous tensors are not supported, please call tensor.contiguous() before passing it into taichi kernel."
)
tmp = v
taichi_arch = self.runtime.prog.config().arch

if str(v.device).startswith('cuda'):
# External tensor on cuda
if taichi_arch != _ti_core.Arch.cuda:
# copy data back to cpu
host_v = v.to(device='cpu', copy=True)
tmp = host_v
callbacks.append(get_call_back(v, host_v))
return tmp, callbacks

def get_paddle_callbacks(self, v):
callbacks = []

def get_call_back(u, v):
def call_back():
u.copy_(v, False)

return call_back

tmp = v.value().get_tensor()
taichi_arch = self.runtime.prog.config().arch

if v.place.is_gpu_place():
# External tensor on cuda
if taichi_arch != _ti_core.Arch.cuda:
# copy data back to cpu
host_v = v.cpu()
tmp = host_v.value().get_tensor()
callbacks.append(get_call_back(v, host_v))
elif v.place.is_cpu_place():
# External tensor on cpu
if taichi_arch == _ti_core.Arch.cuda:
gpu_v = v.cuda()
tmp = gpu_v.value().get_tensor()
callbacks.append(get_call_back(v, gpu_v))
else:
# Paddle do support many other backends like XPU, NPU, MLU, IPU.
raise TaichiRuntimeError(
f"Taichi do not support backend {v.place} that Paddle support."
)

return tmp, callbacks

def get_function_body(self, t_kernel):
# The actual function body
def func__(*args):
Expand Down Expand Up @@ -733,15 +675,57 @@ def callback(original, updated):
"Non contiguous numpy arrays are not supported, please call np.ascontiguousarray(arr) before passing it into taichi kernel."
)
elif has_pytorch() and isinstance(v, torch.Tensor):
tmp, torch_callbacks = self.get_torch_callbacks(v)
callbacks += torch_callbacks
if not v.is_contiguous():
raise ValueError(
"Non contiguous tensors are not supported, please call tensor.contiguous() before passing it into taichi kernel."
)
taichi_arch = self.runtime.prog.config().arch

def get_call_back(u, v):
def call_back():
u.copy_(v)

return call_back

tmp = v
if str(v.device).startswith(
'cuda') and taichi_arch != _ti_core.Arch.cuda:
# Getting a torch CUDA tensor on Taichi non-cuda arch:
# We just replace it with a CPU tensor and by the end of kernel execution we'll use the callback to copy the values back to the original CUDA tensor.
host_v = v.to(device='cpu', copy=True)
tmp = host_v
callbacks.append(get_call_back(v, host_v))

launch_ctx.set_arg_external_array_with_shape(
actual_argument_slot, int(tmp.data_ptr()),
tmp.element_size() * tmp.nelement(), array_shape)
elif has_paddle() and isinstance(v, paddle.Tensor):
# For now, paddle.fluid.core.Tensor._ptr() is only available on develop branch
tmp, paddle_callbacks = self.get_paddle_callbacks(v)
callbacks += paddle_callbacks
def get_call_back(u, v):
def call_back():
u.copy_(v, False)

return call_back

tmp = v.value().get_tensor()
taichi_arch = self.runtime.prog.config().arch
if v.place.is_gpu_place():
if taichi_arch != _ti_core.Arch.cuda:
# Paddle cuda tensor on Taichi non-cuda arch
host_v = v.cpu()
tmp = host_v.value().get_tensor()
callbacks.append(get_call_back(v, host_v))
elif v.place.is_cpu_place():
if taichi_arch == _ti_core.Arch.cuda:
# Paddle cpu tensor on Taichi cuda arch
gpu_v = v.cuda()
tmp = gpu_v.value().get_tensor()
callbacks.append(get_call_back(v, gpu_v))
else:
# Paddle do support many other backends like XPU, NPU, MLU, IPU
raise TaichiRuntimeTypeError(
f"Taichi do not support backend {v.place} that Paddle support"
)
launch_ctx.set_arg_external_array_with_shape(
actual_argument_slot, int(tmp._ptr()),
v.element_size() * v.size, array_shape)
Expand Down

0 comments on commit f86861d

Please sign in to comment.