Skip to content

Commit

Permalink
[lang] Argpacks stores scalar values only in cpp implementation, and …
Browse files Browse the repository at this point in the history
…passes ndarray values by 'virtual parameters'

ghstack-source-id: d59cb7110531f92dfeaf5d5e4a8710d4f0b1852e
Pull Request resolved: #8266
  • Loading branch information
listerily committed Jul 7, 2023
1 parent bd30240 commit a0b3d31
Show file tree
Hide file tree
Showing 18 changed files with 482 additions and 286 deletions.
37 changes: 15 additions & 22 deletions python/taichi/lang/argpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def __init__(self, annotations, dtype, *args, **kwargs):
self.__dtype = dtype
self.__argpack = impl.get_runtime().prog.create_argpack(self.__dtype)
for i, (k, v) in enumerate(self.__entries.items()):
self._write_to_device(self.__annotations[k], type(v), v, i)
self._write_to_device(self.__annotations[k], type(v), v, self._calc_element_true_index(i))

def __del__(self):
if impl is not None and impl.get_runtime() is not None and impl.get_runtime().prog is not None:
Expand Down Expand Up @@ -119,7 +119,7 @@ def __getitem__(self, key):

def __setitem__(self, key, value):
self.__entries[key] = value
index = list(self.__annotations).index(key)
index = self._calc_element_true_index(list(self.__annotations).index(key))
self._write_to_device(self.__annotations[key], type(value), value, index)

def _set_entries(self, value):
Expand Down Expand Up @@ -181,6 +181,15 @@ def to_dict(self):
}
return res_dict

def _calc_element_true_index(self, old_index):
for i in range(old_index):
anno = list(self.__annotations.values())[i]
if isinstance(anno, sparse_matrix_builder) or isinstance(anno, ndarray_type.NdarrayType)\
or isinstance(anno, texture_type.TextureType) or isinstance(anno, texture_type.RWTextureType)\
or isinstance(anno, ndarray_type.NdarrayType):
old_index -= 1
return old_index

def _write_to_device(self, needed, provided, v, index):
if isinstance(needed, ArgPackType):
if not isinstance(v, ArgPack):
Expand Down Expand Up @@ -280,8 +289,7 @@ def __init__(self, **kwargs):
elements.append([dtype.dtype, k])
elif isinstance(dtype, ArgPackType):
self.members[k] = dtype
# Use i32 as a placeholder for nested argpacks
elements.append([primitive_types.i32, k])
raise TaichiSyntaxError("ArgPack nesting is not supported currently.")
elif isinstance(dtype, MatrixType):
# Convert MatrixType to StructType
if dtype.ndim == 1:
Expand All @@ -292,34 +300,19 @@ def __init__(self, **kwargs):
elements.append([_ti_core.get_type_factory_instance().get_struct_type(elements_), k])
elif isinstance(dtype, sparse_matrix_builder):
self.members[k] = dtype
elements.append([cook_dtype(primitive_types.u64), k])
elif isinstance(dtype, ndarray_type.NdarrayType):
self.members[k] = dtype
root_dtype = dtype.dtype
while isinstance(root_dtype, MatrixType):
root_dtype = root_dtype.dtype
elements.append(
[
_ti_core.DataType(
_ti_core.get_type_factory_instance().get_ndarray_struct_type(root_dtype, dtype.ndim, False)
),
k,
]
)
elif isinstance(dtype, texture_type.RWTextureType):
self.members[k] = dtype
elements.append(
[_ti_core.DataType(_ti_core.get_type_factory_instance().get_rwtexture_struct_type()), k]
)
elif isinstance(dtype, texture_type.TextureType):
self.members[k] = dtype
elements.append(
[_ti_core.DataType(_ti_core.get_type_factory_instance().get_rwtexture_struct_type()), k]
)
else:
dtype = cook_dtype(dtype)
self.members[k] = dtype
elements.append([dtype, k])
if len(elements) == 0:
# Use i32 as a placeholder for empty argpacks
elements.append([primitive_types.i32, k])
self.dtype = _ti_core.get_type_factory_instance().get_argpack_type(elements)

def __call__(self, *args, **kwargs):
Expand Down
102 changes: 66 additions & 36 deletions python/taichi/lang/ast/ast_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,68 +602,98 @@ def build_FunctionDef(ctx, node):
assert args.kw_defaults == []
assert args.kwarg is None

def decl_and_create_variable(annotation, name, arg_features, arg_depth):
def decl_and_create_variable(annotation, name, arg_features, invoke_later_dict, prefix_name, arg_depth):
full_name = prefix_name + "_" + name
if not isinstance(annotation, primitive_types.RefType):
ctx.kernel_args.append(name)
if isinstance(annotation, ArgPackType):
kernel_arguments.push_argpack_arg(name)
d = {}
items_to_put_in_dict = []
for j, (_name, anno) in enumerate(annotation.members.items()):
d[_name] = decl_and_create_variable(anno, _name, arg_features[j], arg_depth + 1)
return kernel_arguments.decl_argpack_arg(annotation, d)
result, obj = decl_and_create_variable(anno, _name, arg_features[j], invoke_later_dict,
full_name, arg_depth + 1)
if not result:
d[_name] = None
items_to_put_in_dict.append((full_name + "_" + _name, _name, obj))
else:
d[_name] = obj
argpack = kernel_arguments.decl_argpack_arg(annotation, d)
for item in items_to_put_in_dict:
invoke_later_dict[item[0]] = argpack, item[1], *item[2]
return True, argpack
if isinstance(annotation, annotations.template):
return ctx.global_vars[name]
return True, ctx.global_vars[name]
if isinstance(annotation, annotations.sparse_matrix_builder):
return kernel_arguments.decl_sparse_matrix(
to_taichi_type(arg_features),
name,
)
return False, (kernel_arguments.decl_sparse_matrix,
(
to_taichi_type(arg_features),
full_name,
))
if isinstance(annotation, ndarray_type.NdarrayType):
return kernel_arguments.decl_ndarray_arg(
to_taichi_type(arg_features[0]),
arg_features[1],
name,
arg_features[2],
arg_features[3],
)
return False, (kernel_arguments.decl_ndarray_arg,
(
to_taichi_type(arg_features[0]),
arg_features[1],
full_name,
arg_features[2],
arg_features[3],
))
if isinstance(annotation, texture_type.TextureType):
return kernel_arguments.decl_texture_arg(arg_features[0], name)
return False, (kernel_arguments.decl_texture_arg,
(arg_features[0], full_name))
if isinstance(annotation, texture_type.RWTextureType):
return kernel_arguments.decl_rw_texture_arg(
arg_features[0],
arg_features[1],
arg_features[2],
name,
)
return False, (kernel_arguments.decl_rw_texture_arg,
(arg_features[0],
arg_features[1],
arg_features[2],
full_name))
if isinstance(annotation, MatrixType):
return kernel_arguments.decl_matrix_arg(annotation, name, arg_depth)
return True, kernel_arguments.decl_matrix_arg(annotation, name, arg_depth)
if isinstance(annotation, StructType):
return kernel_arguments.decl_struct_arg(annotation, name, arg_depth)
return kernel_arguments.decl_scalar_arg(annotation, name, arg_depth)
return True, kernel_arguments.decl_struct_arg(annotation, name, arg_depth)
return True, kernel_arguments.decl_scalar_arg(annotation, name, arg_depth)

def transform_as_kernel():
# Treat return type
if node.returns is not None:
kernel_arguments.decl_ret(ctx.func.return_type, ctx.is_real_function)
impl.get_runtime().compiling_callable.finalize_rets()

invoke_later_dict = dict()
create_variable_later = dict()
for i, arg in enumerate(args.args):
if isinstance(ctx.func.arguments[i].annotation, ArgPackType):
d = {}
kernel_arguments.push_argpack_arg(ctx.func.arguments[i].name)
d = {}
items_to_put_in_dict = []
for j, (name, anno) in enumerate(ctx.func.arguments[i].annotation.members.items()):
d[name] = decl_and_create_variable(anno, name, ctx.arg_features[i][j], 1)
ctx.create_variable(arg.arg, kernel_arguments.decl_argpack_arg(ctx.func.arguments[i].annotation, d))
result, obj = decl_and_create_variable(anno, name, ctx.arg_features[i][j], invoke_later_dict,
"__argpack_" + name, 1)
if not result:
d[name] = None
items_to_put_in_dict.append(("__argpack_" + name, name, obj))
else:
d[name] = obj
argpack = kernel_arguments.decl_argpack_arg(ctx.func.arguments[i].annotation, d)
for item in items_to_put_in_dict:
invoke_later_dict[item[0]] = argpack, item[1], *item[2]
create_variable_later[arg.arg] = argpack
else:
ctx.create_variable(
arg.arg,
decl_and_create_variable(
ctx.func.arguments[i].annotation,
ctx.func.arguments[i].name,
ctx.arg_features[i] if ctx.arg_features is not None else None,
0,
),
result, obj = decl_and_create_variable(
ctx.func.arguments[i].annotation,
ctx.func.arguments[i].name,
ctx.arg_features[i] if ctx.arg_features is not None else None,
invoke_later_dict,
"",
0,
)
ctx.create_variable(arg.arg, obj if result else obj[0](*obj[1]))
for k, v in invoke_later_dict.items():
argpack, name, func, params = v
argpack[name] = func(*params)
for k, v in create_variable_later.items():
ctx.create_variable(k, v)

impl.get_runtime().compiling_callable.finalize_params()
# remove original args
Expand Down
Loading

0 comments on commit a0b3d31

Please sign in to comment.