Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Hackathon 7th No.37】为 Paddle 代码转换工具新增 API 转换规则(第 4 组) #481

Merged
merged 23 commits into from
Oct 12, 2024

Conversation

monster1015
Copy link
Contributor

@monster1015 monster1015 commented Sep 21, 2024

PR Docs

PaddlePaddle/docs#6884

PR APIs

torch.mvlgamma
torch.Tensor.igamma
torch.Tensor.igamma_
torch.Tensor.igammac
torch.Tensor.igammac_
torch.Tensor.mvlgamma
torch.Tensor.mvlgamma_
torch.ormqr
torch.Tensor.ormqr
torch.Tensor.orgqr
torch.special.ndtr
torch.Tensor.matrix_exp
torch.linalg.inv_ex
torch.linalg.cholesky_ex
torch._assert
torch.testing.make_tensor

@monster1015
Copy link
Contributor Author

请问如何查看CI的具体结果?我点进去是一片白的 @luotao1

@paddle-bot paddle-bot bot added the contributor External developers label Sep 21, 2024
@luotao1
Copy link
Collaborator

luotao1 commented Sep 23, 2024

请问如何查看CI的具体结果?我点进去是一片白的

cc @tianshuo78520a @risemeup1

@luotao1
Copy link
Collaborator

luotao1 commented Sep 23, 2024

@monster1015 日志已打开

@monster1015
Copy link
Contributor Author

image
image
image
结果不是对得上吗?还有就是torch.Tensor.ormqr就是有至少两个输入的 @luotao1 @zhwesky2010

@monster1015
Copy link
Contributor Author

image
还有这个,self.paddleClass不是获取对应输入的?

"out"
],
"unsupport_args": [
"check_errors"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

文档改了这里不改?

"out"
],
"unsupport_args": [
"check_errors"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

文档改了这里不改?

"unsupport_args": [
"noncontiguous",
"exclude_zero",
"memory_format"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这些其实没必要全部算作 无法支持,只有影响到网络计算的才算。

return code


class TensorMatrixExpMatcher(BaseMatcher):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个matcher能否复用class TensorFunc2PaddleFunc

else:
kwargs["x"] = kwargs["input"]
if "left" not in kwargs.keys():
kwargs["left"] = True
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这些参数默认不是True吗,还是说默认参数不同,必须要设置一下

建议转写代码尽可能简洁无冗余

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

参数是一致的,我已经修改了,docs和对应的mathcer也已经修改

Copy link
Collaborator

@zhwesky2010 zhwesky2010 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

丰富下测试case,以下四种情况的测试case必须全部包含:

  1. 传入所有参数且全部指定关键字
  2. 传入所有参数且全部不指定关键字
  3. 改变关键字顺序
  4. 默认参数均不指定

"min_input_args": 1,
"args_list": [
"A",
"check_errors",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个不对吧,放在*的前后意义是不同的,是 位置参数 还是 指定关键字参数

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

他是指定关键字参数

"torch.Tensor.orgqr": {},
"torch.Tensor.ormqr": {},
"torch.Tensor.orgqr": {
"Matcher": "OrgqrMatcher",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

命名TensorOrgqrMatcher

[0.74611458, 1.24800785, 0.88039371]])
out1 = torch.randn(3, 3)
info1 = torch.tensor([1, 2, 3], dtype=torch.int32)
torch.linalg.cholesky_ex(a, check_errors=False, upper=True, out=(out1, info1))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

check_errors=True测一下

[0.48016702, 0.14235102, 0.42620817]])
out1 = torch.tensor([])
info1 = torch.tensor([1, 2, 3], dtype=torch.int32)
out1, info1 = torch.linalg.inv_ex(x, check_errors=False, out=(out1, info1))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

check_errors=True的情况也测下

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已经完善测试case了

@monster1015
Copy link
Contributor Author

已经全都跑通了 @zhwesky2010

API_TEMPLATE = textwrap.dedent(
"""
out = paddle.uniform({}, dtype={}, min={}, max={}).to({})
out.stop_gradient = not {}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

如果没有requires_grad,这个地方可以简写为一行
有requires_grad才需要扩充为三行

参考下GenericMatcher

class LinalgCholeskyExMatcher(BaseMatcher):
def generate_code(self, kwargs):
if "upper" not in kwargs:
kwargs["upper"] = False
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个一定要单独设置吗?默认值不就是False吗,直接self.kwargs_to_str不就行了吗

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

image
主要是为了配合下面的代码

]
},
"torch.Tensor.ormqr": {
"Matcher": "OrmqrMatcher",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个配置TensorFunc2PaddleFunc就可以吧

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个可以

@@ -13933,6 +14051,18 @@
"verbose"
]
},
"torch.ormqr": {
"Matcher": "OrmqrMatcher",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个配置GenericMatcher就可以吧

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个也可以

return code


class OrmqrMatcher(BaseMatcher):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个逻辑简单 应该不用单独写Matcher了,分别配置 TensorFunc2PaddleFunc和GenericMatcher 就行

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

image
这个我觉得应该设置的,其他是已经改完了的。请review

@@ -807,6 +807,79 @@ def generate_code(self, kwargs):
return code


class TensorOrgqrMatcher(BaseMatcher):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个也可以用GenericMatcher,配置一个kwargs_change

"Matcher": "MakeTMatcher",
"min_input_args": 3,
"args_list": [
"shape",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个看起来shape是 可变参数位置参数

其他的应该是 指定关键字参数

所以应该这么写:

"args_list": [
    "*shape",
    "*",
    "device",
    "dtype",
    "low",
    "high",
    "requires_grad",
    "noncontiguous",
    "exclude_zero",
    "memory_format"
]

只看torch文档可能不精确

"""
)
code = API_TEMPLATE.format(
kwargs["shape"],
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里面的这个shape比较复杂,因为其既可以是 可变参数

torch.testing.make_tensor(3, 2, dtype=torch.float32, device=torch.device('cpu'))

也可以是 位置参数

torch.testing.make_tensor([1, 2], dtype=torch.float32, device=torch.device('cpu'))

参考下CreateMatcher的写法吧,如果比较复杂不会写,可以看看是不是复用CreateMatcher,再配置下json字段就行

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的好的,我试试

@monster1015
Copy link
Contributor Author

已经好了,请review @zhwesky2010

Copy link
Collaborator

@zhwesky2010 zhwesky2010 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@zhwesky2010 zhwesky2010 merged commit b6c624e into PaddlePaddle:master Oct 12, 2024
7 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
contributor External developers
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants