Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add function for dumping paddle thirdparty apis #102

Merged
merged 1 commit into from
Feb 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

## 最近更新

- 支持添加Paddle自定义算子
- 支持单模型运行并dump相关数据
- 提供离线对齐工具

Expand Down
8 changes: 8 additions & 0 deletions docs/Tutorial.md
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,14 @@ assign_weight(layer, module)

设置环境变量可以打开该功能: `export PADIFF_API_CHECK=ON`

### 添加自定义算子

当`PADIFF_API_CHECK`开启时,可以添加PD_BUILD_OP方式注册的自定义算子,该功能通过`PADDLE_THIRDPARTY_API`环境变量开启

设置时,需将自定义算子的module、api名写全,多个算子之间用逗号分隔,如:

`export PADDLE_THIRDPARTY_API=paddle3d.ops.iou3d_nms,paddle3d.ops.hard_voxelize,paddle_xpu_nn.xpu_rms_norm`

### 略过 wrap_layer

`export PADIFF_SIKP_WRAP_LAYER=TRUE`
Expand Down
21 changes: 21 additions & 0 deletions padiff/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,12 +228,33 @@ def create_module(self, spec):
return None


def add_thirdparty_apis(thirdparty_apis):
# Exp. 1:
# thirdparty_apis = "paddle3d.ops.iou3d_nms,paddle3d.ops.hard_voxelize,paddle_xpu_nn.xpu_rms_norm"
# json.THIRD_PARTY = {"paddle3d.ops": {"iou3d_nms", "hard_voxelize"}, "paddle_xpu_nn": {"xpu_rms_norm"}}
thirdparty_apis = thirdparty_api.replace(" ", "").split(",")
json.THIRD_PARTY = {}
for fullname in thirdparty_apis:
module = fullname.rpartition(".")[0]
api = fullname.rpartition(".")[2]
if not module in json.THIRD_PARTY:
json.THIRD_PARTY[module] = {api}
else:
json.THIRD_PARTY[module].add(api)

self.paddle_apis.update(self.THIRD_PARTY)


if os.getenv("PADIFF_API_CHECK") == "ON":
for name in jsons.TORCH_PATH:
if name in sys.modules.keys():
module = sys.modules[name]
wrap_api_method(module)

thirdparty_apis = os.getenv("PADDLE_THIRDPARTY_API")
if thirdparty_apis is not None:
add_thirdparty_apis(thirdparty_apis)

for name in jsons.PADDLE_PATH:
if name in sys.modules.keys():
module = sys.modules[name]
Expand Down
Loading