Skip to content

Commit

Permalink
feat: Add support for apple metal attributes and code_before_func_def
Browse files Browse the repository at this point in the history
  • Loading branch information
FindDefinition committed Jul 27, 2024
1 parent f2c86b2 commit d89b9ab
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 11 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
# Changelog
## [0.4.13] - 2024-07-27
### Added
- add attributes and `code_before_func_def` support for apple metal

## [0.4.12] - 2024-01-03
### Added
- add support for param class reload by check id in `sys.module`
Expand Down
8 changes: 5 additions & 3 deletions pccm/builder/inliner.py
Original file line number Diff line number Diff line change
Expand Up @@ -717,8 +717,9 @@ def inline(self,

container_fcode.arg(arg_name,
mapped_cpp_type_str,
array=mapped_cpp_type.count)
inner_fcode.arg(cap.replaced_name, inner_cpp_type)
array=mapped_cpp_type.count,
userdata=obj)
inner_fcode.arg(cap.replaced_name, inner_cpp_type, userdata=obj)
for rr in cap.replace_range_pairs:
replace = Replace(cap.replaced_name, *rr)
replaces.append(replace)
Expand All @@ -727,7 +728,8 @@ def inline(self,
v, self.plugins, user_arg=user_arg)
container_fcode.arg(k,
str(mapped_cpp_type),
array=mapped_cpp_type.count)
array=mapped_cpp_type.count,
userdata=v)
args.append(v)

inner_code_str = execute_modifiers(code_str, replaces)
Expand Down
24 changes: 20 additions & 4 deletions pccm/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,9 +516,13 @@ def __init__(self,
self.ret_doc = None # type: Optional[str]
self.ret_pyanno = None # type: Optional[str]
self.func_doc = None # type: Optional[str]
self.code_before_include = None # type: Optional[str]
self.code_after_include = None # type: Optional[str]
self.code_in_ns = None # type: Optional[str]

# for metal kernel global constants
self.code_before_func_def = None # type: Optional[str]

self._additional_pre_attrs = [] # type: List[str]

self._impl_only_deps: List[Dependency] = []
Expand Down Expand Up @@ -721,6 +725,8 @@ def get_sig(self,
arg_strs = [] # type: List[str]
for arg in self.arguments:
arg_fmt = "{type} {name}".format(type=arg.type_str, name=arg.name)
if arg.attributes is not None:
arg_fmt += "[[" + " ".join(arg.attributes) + "]]"
if arg.default:
arg_fmt += " = {}".format(arg.default)
arg_strs.append(arg_fmt)
Expand All @@ -740,6 +746,7 @@ def get_sig(self,

def get_impl(self, name: str, meta: FunctionMeta, class_name: str = ""):
"""
code_before_func_def
template <args>
pre_attrs ret_type BoundClass::name(args) post_attrs {body};
"""
Expand Down Expand Up @@ -781,6 +788,8 @@ def get_impl(self, name: str, meta: FunctionMeta, class_name: str = ""):
else:
array_str = arg.array
arg_fmt += array_str
if arg.attributes is not None:
arg_fmt += "[[" + " ".join(arg.attributes) + "]]"
if arg.default and header_only:
arg_fmt += " = {}".format(arg.default)
arg_strs.append(arg_fmt)
Expand All @@ -795,6 +804,8 @@ def get_impl(self, name: str, meta: FunctionMeta, class_name: str = ""):
prefix_fmt = pre_attrs_str + " " + prefix_fmt
blocks = [] # List[Union[Block, str]]
blocks.extend(self._blocks)
if self.code_before_func_def is not None:
template_fmt = self.code_before_func_def + "\n" + template_fmt
block = Block(template_fmt + prefix_fmt, blocks, "}")
if meta.macro_guard is not None:
block = Block("#if {}".format(meta.macro_guard), [block], "#endif")
Expand All @@ -815,7 +826,9 @@ def arg(self,
default: Optional[str] = None,
pyanno: Optional[str] = None,
array: Optional[Union[int, str]] = None,
doc: Optional[str] = None):
doc: Optional[str] = None,
attributes: Optional[List[str]] = None,
userdata: Any = None):
"""add a argument.
"""
type = str(type).strip()
Expand All @@ -826,7 +839,8 @@ def arg(self,
if not part.strip():
raise ValueError("you provide a empty name in", name)
args.append(
Argument(part.strip(), type, default, pyanno=pyanno, array=array, doc=doc))
Argument(part.strip(), type, default, pyanno=pyanno, array=array, doc=doc,
attributes=attributes, userdata=userdata))
else:
arg_attrs = arg_parser(name)
for arg_with_attr in arg_attrs:
Expand All @@ -837,8 +851,10 @@ def arg(self,
type,
default,
pyanno=pyanno,
attrs=arg_with_attr.attrs, array=array,
doc=doc))
user_attrs=arg_with_attr.attrs, array=array,
doc=doc,
attributes=attributes,
userdata=userdata))
if type not in self._type_to_hook:
hook = _get_attr_hook(type)
if hook is not None:
Expand Down
10 changes: 7 additions & 3 deletions pccm/core/funccode.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,21 @@ def __init__(self,
array: Optional[Union[int, str]] = None,
pyanno: Optional[str] = None,
doc: Optional[str] = None,
attrs: Optional[List[Attr]] = None):
user_attrs: Optional[List[Attr]] = None,
attributes: Optional[List[str]] = None,
userdata: Any = None):
self.name = name.strip()
self.type_str = str(type).strip() # type: str
self.default = default
self.array = array
self.pyanno = pyanno
self.doc = doc
if attrs is None:
if user_attrs is None:
self.attrs: List[Attr] = []
else:
self.attrs: List[Attr] = attrs
self.attrs: List[Attr] = user_attrs
self.attributes = attributes
self.userdata = userdata
if pyanno is not None:
self.pyanno = pyanno.strip()
assert len(pyanno) != 0
Expand Down
2 changes: 1 addition & 1 deletion version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.4.12
0.4.13

0 comments on commit d89b9ab

Please sign in to comment.