-
Notifications
You must be signed in to change notification settings - Fork 169
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add export & inference for hPINNs #902
Conversation
… into dev_model feat: add export and infer functions for hpinns
Thanks for your contribution! |
model_re = ppsci.arch.MLP(**cfg.MODEL.re_net) | ||
model_im = ppsci.arch.MLP(**cfg.MODEL.im_net) | ||
model_eps = ppsci.arch.MLP(**cfg.MODEL.eps_net) | ||
# wrap to a model_list |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
当前导出后推理会报错:
RuntimeError: (PreconditionNotMet) The variable named x is not found in the scope of the executor.
[Hint: scope->FindVar(name) should not be null.] (at ../paddle/fluid/inference/api/analysis_predictor.cc:2333)
这是由于导出时的input_keys和推理时的input_keys不匹配导致的。
这个模型中有input transform,即:("x","y")->transform->("x_cos_1","y_cons_1",...)->forward->("e_re",...),当前export函数中的input_keys是("x_cos_1","y_cons_1",...),而infer中是("x","y")。
因此,需要更改:
- 这里增加transform register代码
# register transform
model_re.register_input_transform(func_module.transform_in)
model_im.register_input_transform(func_module.transform_in)
model_eps.register_input_transform(func_module.transform_in)
model_re.register_output_transform(func_module.transform_out_real_part)
model_im.register_output_transform(func_module.transform_out_imaginary_part)
model_eps.register_output_transform(func_module.transform_out_epsilon)
- 更改下面
input_spec
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改,感谢纠正🌹
examples/hpinns/holography.py
Outdated
input_spec = [ | ||
{ | ||
key: InputSpec([None, 1], "float32", name=key) | ||
for key in model_list.input_keys |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
改为for key in ["x","y"]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
} | ||
|
||
ppsci.visualize.save_vtu_from_dict( | ||
"./hpinns_pred.vtu", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
改为:osp.join(cfg.output_dir, "hpinns_pred.vtu"),
以防文件覆盖
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
改为:
osp.join(cfg.output_dir, "hpinns_pred.vtu"),
以防文件覆盖
inference的产物文件就不放到output_dir下面了吧,这个就保持hpinns_pred.vtu
比较好
… into dev_model fix:register transform before export
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
* feat: add export and infer functions for hpinns * fix:register transform brfore export
PR types
Others
PR changes
APIs
Describe
为hPINNs添加导出和推理代码