Skip to content

Commit

Permalink
update code
Browse files Browse the repository at this point in the history
  • Loading branch information
HydrogenSulfate committed Aug 26, 2023
1 parent 62e9855 commit 3be132b
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 36 deletions.
69 changes: 61 additions & 8 deletions ppsci/utils/sym_to_func.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,23 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Sympy to python function conversion module
"""

from __future__ import annotations

import functools
from typing import TYPE_CHECKING
from typing import Dict
Expand All @@ -17,7 +33,6 @@

from ppsci.autodiff import hessian
from ppsci.autodiff import jacobian
from ppsci.utils import logger

if TYPE_CHECKING:
from ppsci import arch
Expand Down Expand Up @@ -235,7 +250,7 @@ def __init__(self, expr: Union[sp.Number, sp.NumberSymbol]):
self.expr = float(self.expr)
else:
raise TypeError(
f"expr({expr}) should be float/int/bool, but got {type(self.expr)}"
f"expr({expr}) should be Float/Integer/Boolean/Rational, but got {type(self.expr)}"
)
self.expr = paddle.to_tensor(self.expr)

Expand All @@ -253,10 +268,9 @@ class ComposedNode(nn.Layer):
Compose list of several callable objects together.
"""

def __init__(self, target: str, funcs: List[Node]):
def __init__(self, funcs: List[Node]):
super().__init__()
self.funcs = funcs
self.target = target

def forward(self, data_dict: Dict):
# call all funcs in order
Expand Down Expand Up @@ -299,19 +313,57 @@ def _post_traverse(cur_node: sp.Basic, nodes: List[sp.Basic]) -> List[sp.Basic]:


def sympy_to_function(
target: str,
expr: sp.Expr,
models: Optional[Union[arch.Arch, Tuple[arch.Arch, ...]]] = None,
) -> ComposedNode:
"""Convert sympy expression to callable function.
Args:
target (str): Alias of `expr`, such as "z" for expression: "z = a + b * c".
expr (sp.Expr): Sympy expression to be converted.
models (Optional[Union[arch.Arch, Tuple[arch.Arch, ...]]]): Model(s) for computing forward result in `LayerNode`.
Returns:
ComposedNode: Callable object for computing expr with necessary input(s) data in dict given.
Examples:
>>> import paddle
>>> import sympy as sp
>>> from ppsci import arch
>>> from ppsci.utils import sym_to_func
>>> a, b, c, x, y = sp.symbols("a b c x y")
>>> u = sp.Function("u")(x, y)
>>> v = sp.Function("v")(x, y)
>>> z = -a + b * (c ** 2) + u * v + 2.3
>>> model = arch.MLP(("x", "y"), ("u", "v"), 4, 16)
>>> batch_size = 13
>>> a_tensor = paddle.randn([batch_size, 1])
>>> b_tensor = paddle.randn([batch_size, 1])
>>> c_tensor = paddle.randn([batch_size, 1])
>>> x_tensor = paddle.randn([batch_size, 1])
>>> y_tensor = paddle.randn([batch_size, 1])
>>> model_output_dict = model({"x": x_tensor, "y": y_tensor})
>>> u_tensor, v_tensor = model_output_dict["u"], model_output_dict["v"]
>>> z_tensor_manually = (
... -a_tensor + b_tensor * (c_tensor ** 2)
... + u_tensor * v_tensor + 2.3
... )
>>> z_tensor_sympy = sym_to_func.sympy_to_function(z, model)(
... {
... "a": a_tensor,
... "b": b_tensor,
... "c": c_tensor,
... "x": x_tensor,
... "y": y_tensor,
... }
... )
>>> paddle.allclose(z_tensor_manually, z_tensor_sympy).item()
True
"""

# simplify expression to reduce nodes in tree
Expand All @@ -330,9 +382,10 @@ def sympy_to_function(
sympy_nodes = list(dict.fromkeys(sympy_nodes))

# convert sympy node to callable node
if not isinstance(models, (tuple, list)):
models = (models,)
callable_nodes = []
for i, node in enumerate(sympy_nodes):
logger.debug(f"tree node [{i + 1}/{len(sympy_nodes)}]: {node}")
if isinstance(node.func, sp.core.function.UndefinedFunction):
match = False
for model in models:
Expand All @@ -359,4 +412,4 @@ def sympy_to_function(
)

# Compose callable nodes into one callable object
return ComposedNode(target, callable_nodes)
return ComposedNode(callable_nodes)
6 changes: 3 additions & 3 deletions test/utils/test_linear_elasticity_sympy.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import ppsci
from ppsci import equation
from ppsci.autodiff import clear
from ppsci.utils import expression
from ppsci.utils import sym_to_func

__all__ = []

Expand Down Expand Up @@ -227,8 +227,8 @@ def test_linearelasticity(E, nu, lambda_, mu, rho, dim, time):
E, nu, lambda_, mu, rho, dim, time
).equations
for target, expr in sympy_expr_dict.items():
sympy_expr_dict[target] = expression.sympy_to_function(
target, expr, [disp_net, stress_net]
sympy_expr_dict[target] = sym_to_func.sympy_to_function(
expr, [disp_net, stress_net]
)

# compute equation with python function
Expand Down
25 changes: 0 additions & 25 deletions test/utils/test_navier_stokes_sympy.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,29 +127,6 @@ def __init__(self, nu, rho=1, dim=3, time=True):


class ZeroEquation_sympy:
"""
Zero Equation Turbulence model
Parameters
==========
nu : float
The kinematic viscosity of the fluid.
max_distance : float
The maximum wall distance in the flow field.
rho : float, Sympy sp.Symbol/Expr, str
The density. If `rho` is a str then it is
converted to Sympy sp.Function of form 'rho(x,y,z,t)'.
If 'rho' is a Sympy sp.Symbol or Expression then this
is substituted into the equation. Default is 1.
dim : int
Dimension of the Zero Equation Turbulence model (2 or 3).
Default is 3.
time : bool
If time-dependent equations or not. Default is True.
Example
"""

def __init__(
self, nu, max_distance, rho=1, dim=3, time=True
): # TODO add density into model
Expand Down Expand Up @@ -454,7 +431,6 @@ def momentum_y_f(out):
sympy_expr_dict = NavierStokes_sympy(nu_sympy, rho, dim, time).equations
for target, expr in sympy_expr_dict.items():
sympy_expr_dict[target] = sym_to_func.sympy_to_function(
target,
expr,
[
model,
Expand Down Expand Up @@ -528,7 +504,6 @@ def test_nu_constant(self, nu, rho, dim, time):
sympy_expr_dict = NavierStokes_sympy(nu_sympy, rho, dim, time).equations
for target, expr in sympy_expr_dict.items():
sympy_expr_dict[target] = sym_to_func.sympy_to_function(
target,
expr,
[
model,
Expand Down

0 comments on commit 3be132b

Please sign in to comment.