From 5ebde07119d929b45b87c245e0055b959866fa4e Mon Sep 17 00:00:00 2001 From: anutosh491 Date: Tue, 4 Jul 2023 16:18:21 +0530 Subject: [PATCH 1/3] Added support for symbolic elementary functions --- src/libasr/codegen/asr_to_c_cpp.h | 71 +++++++---- src/libasr/pass/intrinsic_function_registry.h | 118 ++++++++++++------ src/lpython/semantics/python_ast_to_asr.cpp | 12 +- 3 files changed, 138 insertions(+), 63 deletions(-) diff --git a/src/libasr/codegen/asr_to_c_cpp.h b/src/libasr/codegen/asr_to_c_cpp.h index 25cad0a724..e635f1f654 100644 --- a/src/libasr/codegen/asr_to_c_cpp.h +++ b/src/libasr/codegen/asr_to_c_cpp.h @@ -2702,7 +2702,7 @@ PyMODINIT_FUNC PyInit_lpython_module_)" + fn_name + R"((void) { out += func_name; break; \ } - std::string performSymbolicOperation(const std::string& functionName, const ASR::IntrinsicFunction_t& x) { + std::string performBinarySymbolicOperation(const std::string& functionName, const ASR::IntrinsicFunction_t& x) { headers.insert("symengine/cwrapper.h"); std::string indent(4, ' '); LCOMPILERS_ASSERT(x.n_args == 2); @@ -2727,6 +2727,23 @@ PyMODINIT_FUNC PyInit_lpython_module_)" + fn_name + R"((void) { return target; } + std::string performUnarySymbolicOperation(const std::string& functionName, const ASR::IntrinsicFunction_t& x) { + headers.insert("symengine/cwrapper.h"); + std::string indent(4, ' '); + LCOMPILERS_ASSERT(x.n_args == 1); + std::string target = symengine_queue.push(); + std::string target_src = symengine_src; + this->visit_expr(*x.m_args[0]); + std::string arg1 = src; + std::string arg1_src = symengine_src; + if (ASR::is_a(*x.m_args[0])) { + symengine_queue.pop(); + } + symengine_src = target_src + arg1_src; + symengine_src += indent + functionName + "(" + target + ", " + arg1 + ");\n"; + return target; + } + void visit_IntrinsicFunction(const ASR::IntrinsicFunction_t &x) { std::string out; std::string indent(4, ' '); @@ -2745,27 +2762,51 @@ PyMODINIT_FUNC PyInit_lpython_module_)" + fn_name + R"((void) { SET_INTRINSIC_NAME(Exp2, "exp2"); SET_INTRINSIC_NAME(Expm1, "expm1"); case (static_cast(ASRUtils::IntrinsicFunctions::SymbolicAdd)): { - src = performSymbolicOperation("basic_add", x); + src = performBinarySymbolicOperation("basic_add", x); return; } case (static_cast(ASRUtils::IntrinsicFunctions::SymbolicSub)): { - src = performSymbolicOperation("basic_sub", x); + src = performBinarySymbolicOperation("basic_sub", x); return; } case (static_cast(ASRUtils::IntrinsicFunctions::SymbolicMul)): { - src = performSymbolicOperation("basic_mul", x); + src = performBinarySymbolicOperation("basic_mul", x); return; } case (static_cast(ASRUtils::IntrinsicFunctions::SymbolicDiv)): { - src = performSymbolicOperation("basic_div", x); + src = performBinarySymbolicOperation("basic_div", x); return; } case (static_cast(ASRUtils::IntrinsicFunctions::SymbolicPow)): { - src = performSymbolicOperation("basic_pow", x); + src = performBinarySymbolicOperation("basic_pow", x); return; } case (static_cast(ASRUtils::IntrinsicFunctions::SymbolicDiff)): { - src = performSymbolicOperation("basic_diff", x); + src = performBinarySymbolicOperation("basic_diff", x); + return; + } + case (static_cast(ASRUtils::IntrinsicFunctions::SymbolicSin)): { + src = performUnarySymbolicOperation("basic_sin", x); + return; + } + case (static_cast(ASRUtils::IntrinsicFunctions::SymbolicCos)): { + src = performUnarySymbolicOperation("basic_cos", x); + return; + } + case (static_cast(ASRUtils::IntrinsicFunctions::SymbolicLog)): { + src = performUnarySymbolicOperation("basic_log", x); + return; + } + case (static_cast(ASRUtils::IntrinsicFunctions::SymbolicExp)): { + src = performUnarySymbolicOperation("basic_exp", x); + return; + } + case (static_cast(ASRUtils::IntrinsicFunctions::SymbolicAbs)): { + src = performUnarySymbolicOperation("basic_abs", x); + return; + } + case (static_cast(ASRUtils::IntrinsicFunctions::SymbolicExpand)): { + src = performUnarySymbolicOperation("basic_expand", x); return; } case (static_cast(ASRUtils::IntrinsicFunctions::SymbolicPi)): { @@ -2794,22 +2835,6 @@ PyMODINIT_FUNC PyInit_lpython_module_)" + fn_name + R"((void) { src = target; return; } - case (static_cast(ASRUtils::IntrinsicFunctions::SymbolicExpand)): { - headers.insert("symengine/cwrapper.h"); - LCOMPILERS_ASSERT(x.n_args == 1); - std::string target = symengine_queue.push(); - std::string target_src = symengine_src; - this->visit_expr(*x.m_args[0]); - std::string arg1 = src; - std::string arg1_src = symengine_src; - if (ASR::is_a(*x.m_args[0])) { - symengine_queue.pop(); - } - symengine_src = target_src + arg1_src; - symengine_src += indent + "basic_expand(" + target + ", " + arg1 + ");\n"; - src = target; - return; - } default : { throw LCompilersException("IntrinsicFunction: `" + ASRUtils::get_intrinsic_name(x.m_intrinsic_id) diff --git a/src/libasr/pass/intrinsic_function_registry.h b/src/libasr/pass/intrinsic_function_registry.h index daa0ab8a45..a5e71b46ff 100644 --- a/src/libasr/pass/intrinsic_function_registry.h +++ b/src/libasr/pass/intrinsic_function_registry.h @@ -74,6 +74,11 @@ enum class IntrinsicFunctions : int64_t { SymbolicInteger, SymbolicDiff, SymbolicExpand, + SymbolicSin, + SymbolicCos, + SymbolicLog, + SymbolicExp, + SymbolicAbs, Sum, // ... }; @@ -2169,45 +2174,52 @@ namespace SymbolicInteger { } } // namespace SymbolicInteger -namespace SymbolicExpand { - - static inline void verify_args(const ASR::IntrinsicFunction_t& x, diag::Diagnostics& diagnostics) { - const Location& loc = x.base.base.loc; - ASRUtils::require_impl(x.n_args == 1, - "SymbolicExpand must have exactly 1 input argument", - loc, diagnostics); - - ASR::ttype_t* input_type = ASRUtils::expr_type(x.m_args[0]); - ASRUtils::require_impl(ASR::is_a(*input_type), - "SymbolicExpand expects an argument of type SymbolicExpression", - x.base.base.loc, diagnostics); - } - - static inline ASR::expr_t *eval_SymbolicExpand(Allocator &/*al*/, - const Location &/*loc*/, Vec& /*args*/) { - // TODO - return nullptr; - } - - static inline ASR::asr_t* create_SymbolicExpand(Allocator& al, const Location& loc, - Vec& args, - const std::function err) { - if (args.size() != 1) { - err("Intrinsic expand function accepts exactly 1 argument", loc); - } - - ASR::ttype_t* argtype = ASRUtils::expr_type(args[0]); - if(!ASR::is_a(*argtype)) { - err("Argument of SymbolicExpand function must be of type SymbolicExpression", - args[0]->base.loc); - } - - ASR::ttype_t *to_type = ASRUtils::TYPE(ASR::make_SymbolicExpression_t(al, loc)); - return UnaryIntrinsicFunction::create_UnaryFunction(al, loc, args, eval_SymbolicExpand, - static_cast(ASRUtils::IntrinsicFunctions::SymbolicExpand), 0, to_type); - } +#define create_symbolic_unary_macro(X) \ +namespace X { \ + \ + static inline void verify_args(const ASR::IntrinsicFunction_t& x, \ + diag::Diagnostics& diagnostics) { \ + const Location& loc = x.base.base.loc; \ + ASRUtils::require_impl(x.n_args == 1, \ + #X " must have exactly 1 input argument", loc, diagnostics); \ + \ + ASR::ttype_t* input_type = ASRUtils::expr_type(x.m_args[0]); \ + ASRUtils::require_impl(ASR::is_a(*input_type), \ + #X " expects an argument of type SymbolicExpression", loc, diagnostics); \ + } \ + \ + static inline ASR::expr_t* eval_##X(Allocator &/*al*/, const Location &/*loc*/, \ + Vec &/*args*/) { \ + /*TODO*/ \ + return nullptr; \ + } \ + \ + static inline ASR::asr_t* create_##X(Allocator& al, const Location& loc, \ + Vec& args, \ + const std::function err) { \ + if (args.size() != 1) { \ + err("Intrinsic " #X " function accepts exactly 1 argument", loc); \ + } \ + \ + ASR::ttype_t* argtype = ASRUtils::expr_type(args[0]); \ + if (!ASR::is_a(*argtype)) { \ + err("Argument of " #X " function must be of type SymbolicExpression", \ + args[0]->base.loc); \ + } \ + \ + ASR::ttype_t *to_type = ASRUtils::TYPE(ASR::make_SymbolicExpression_t(al, loc)); \ + return UnaryIntrinsicFunction::create_UnaryFunction(al, loc, args, eval_##X, \ + static_cast(ASRUtils::IntrinsicFunctions::X), 0, to_type); \ + } \ + \ +} // namespace X -} // namespace SymbolicExpand +create_symbolic_unary_macro(SymbolicSin) +create_symbolic_unary_macro(SymbolicCos) +create_symbolic_unary_macro(SymbolicLog) +create_symbolic_unary_macro(SymbolicExp) +create_symbolic_unary_macro(SymbolicAbs) +create_symbolic_unary_macro(SymbolicExpand) namespace IntrinsicFunctionRegistry { @@ -2275,6 +2287,16 @@ namespace IntrinsicFunctionRegistry { {nullptr, &SymbolicDiff::verify_args}}, {static_cast(ASRUtils::IntrinsicFunctions::SymbolicExpand), {nullptr, &SymbolicExpand::verify_args}}, + {static_cast(ASRUtils::IntrinsicFunctions::SymbolicSin), + {nullptr, &SymbolicSin::verify_args}}, + {static_cast(ASRUtils::IntrinsicFunctions::SymbolicCos), + {nullptr, &SymbolicCos::verify_args}}, + {static_cast(ASRUtils::IntrinsicFunctions::SymbolicLog), + {nullptr, &SymbolicLog::verify_args}}, + {static_cast(ASRUtils::IntrinsicFunctions::SymbolicExp), + {nullptr, &SymbolicExp::verify_args}}, + {static_cast(ASRUtils::IntrinsicFunctions::SymbolicAbs), + {nullptr, &SymbolicAbs::verify_args}}, }; static const std::map& intrinsic_function_id_to_name = { @@ -2333,6 +2355,16 @@ namespace IntrinsicFunctionRegistry { "SymbolicDiff"}, {static_cast(ASRUtils::IntrinsicFunctions::SymbolicExpand), "SymbolicExpand"}, + {static_cast(ASRUtils::IntrinsicFunctions::SymbolicSin), + "SymbolicSin"}, + {static_cast(ASRUtils::IntrinsicFunctions::SymbolicCos), + "SymbolicCos"}, + {static_cast(ASRUtils::IntrinsicFunctions::SymbolicLog), + "SymbolicLog"}, + {static_cast(ASRUtils::IntrinsicFunctions::SymbolicExp), + "SymbolicExp"}, + {static_cast(ASRUtils::IntrinsicFunctions::SymbolicAbs), + "SymbolicAbs"}, {static_cast(ASRUtils::IntrinsicFunctions::Any), "any"}, {static_cast(ASRUtils::IntrinsicFunctions::Sum), @@ -2372,6 +2404,11 @@ namespace IntrinsicFunctionRegistry { {"SymbolicInteger", {&SymbolicInteger::create_SymbolicInteger, &SymbolicInteger::eval_SymbolicInteger}}, {"diff", {&SymbolicDiff::create_SymbolicDiff, &SymbolicDiff::eval_SymbolicDiff}}, {"expand", {&SymbolicExpand::create_SymbolicExpand, &SymbolicExpand::eval_SymbolicExpand}}, + {"SymbolicSin", {&SymbolicSin::create_SymbolicSin, &SymbolicSin::eval_SymbolicSin}}, + {"SymbolicCos", {&SymbolicCos::create_SymbolicCos, &SymbolicCos::eval_SymbolicCos}}, + {"SymbolicLog", {&SymbolicLog::create_SymbolicLog, &SymbolicLog::eval_SymbolicLog}}, + {"SymbolicExp", {&SymbolicExp::create_SymbolicExp, &SymbolicExp::eval_SymbolicExp}}, + {"SymbolicAbs", {&SymbolicAbs::create_SymbolicAbs, &SymbolicAbs::eval_SymbolicAbs}}, }; static inline bool is_intrinsic_function(const std::string& name) { @@ -2488,6 +2525,11 @@ inline std::string get_intrinsic_name(int x) { INTRINSIC_NAME_CASE(SymbolicInteger) INTRINSIC_NAME_CASE(SymbolicDiff) INTRINSIC_NAME_CASE(SymbolicExpand) + INTRINSIC_NAME_CASE(SymbolicSin) + INTRINSIC_NAME_CASE(SymbolicCos) + INTRINSIC_NAME_CASE(SymbolicLog) + INTRINSIC_NAME_CASE(SymbolicExp) + INTRINSIC_NAME_CASE(SymbolicAbs) INTRINSIC_NAME_CASE(Sum) default : { throw LCompilersException("pickle: intrinsic_id not implemented"); diff --git a/src/lpython/semantics/python_ast_to_asr.cpp b/src/lpython/semantics/python_ast_to_asr.cpp index 64d7c3ebdc..5ea163e50e 100644 --- a/src/lpython/semantics/python_ast_to_asr.cpp +++ b/src/lpython/semantics/python_ast_to_asr.cpp @@ -7265,15 +7265,23 @@ class BodyVisitor : public CommonVisitor { } if (!s) { + std::string intrinsic_name = call_name; std::set not_cpython_builtin = { "sin", "cos", "gamma", "tan", "asin", "acos", "atan", "sinh", "cosh", "tanh", "exp", "exp2", "expm1", "Symbol", "diff", "expand", "sum" // For sum called over lists }; - if (ASRUtils::IntrinsicFunctionRegistry::is_intrinsic_function(call_name) && + std::set symbolic_functions = { + "sin", "cos", "log", "exp", "Abs" + }; + if ((symbolic_functions.find(call_name) != symbolic_functions.end()) && + imported_functions[call_name] == "sympy"){ + intrinsic_name = "Symbolic" + std::string(1, std::toupper(call_name[0])) + call_name.substr(1); + } + if (ASRUtils::IntrinsicFunctionRegistry::is_intrinsic_function(intrinsic_name) && (not_cpython_builtin.find(call_name) == not_cpython_builtin.end() || imported_functions.find(call_name) != imported_functions.end() )) { ASRUtils::create_intrinsic_function create_func = - ASRUtils::IntrinsicFunctionRegistry::get_create_function(call_name); + ASRUtils::IntrinsicFunctionRegistry::get_create_function(intrinsic_name); Vec args_; args_.reserve(al, x.n_args); visit_expr_list(x.m_args, x.n_args, args_); if (ASRUtils::is_array(ASRUtils::expr_type(args_[0])) && From 38162332559009b54be824d8931f0abdbebe5029 Mon Sep 17 00:00:00 2001 From: anutosh491 Date: Tue, 4 Jul 2023 17:02:12 +0530 Subject: [PATCH 2/3] Added tests --- integration_tests/CMakeLists.txt | 1 + integration_tests/symbolics_06.py | 37 +++++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+) create mode 100644 integration_tests/symbolics_06.py diff --git a/integration_tests/CMakeLists.txt b/integration_tests/CMakeLists.txt index f22eb0767f..cc9a59bdbc 100644 --- a/integration_tests/CMakeLists.txt +++ b/integration_tests/CMakeLists.txt @@ -605,6 +605,7 @@ RUN(NAME symbolics_02 LABELS cpython_sym c_sym) RUN(NAME symbolics_03 LABELS cpython_sym c_sym) RUN(NAME symbolics_04 LABELS cpython_sym c_sym) RUN(NAME symbolics_05 LABELS cpython_sym c_sym) +RUN(NAME symbolics_06 LABELS cpython_sym c_sym) RUN(NAME sizeof_01 LABELS llvm c EXTRAFILES sizeof_01b.c) diff --git a/integration_tests/symbolics_06.py b/integration_tests/symbolics_06.py new file mode 100644 index 0000000000..1af60dff6b --- /dev/null +++ b/integration_tests/symbolics_06.py @@ -0,0 +1,37 @@ +from sympy import Symbol, sin, cos, exp, log, Abs, pi, diff +from lpython import S + +def test_elementary_functions(): + + # test sin, cos + x: S = Symbol('x') + assert(sin(pi) == S(0)) + assert(sin(pi/S(2)) == S(1)) + assert(sin(S(2)*pi) == S(0)) + assert(cos(pi) == S(-1)) + assert(cos(pi/S(2)) == S(0)) + assert(cos(S(2)*pi) == S(1)) + assert(diff(sin(x), x) == cos(x)) + assert(diff(cos(x), x) == S(-1)*sin(x)) + + # test exp, log + assert(exp(S(0)) == S(1)) + assert(log(S(1)) == S(0)) + assert(diff(exp(x), x) == exp(x)) + assert(diff(log(x), x) == S(1)/x) + + # test Abs + assert(Abs(S(-10)) == S(10)) + assert(Abs(S(10)) == S(10)) + assert(Abs(-x) == Abs(x)) + + # test composite functions + a: S = exp(x) + b: S = sin(a) + c: S = cos(b) + d: S = log(c) + e: S = Abs(d) + print(e) + assert(e == Abs(log(cos(sin(exp(x)))))) + +test_elementary_functions() \ No newline at end of file From 486db54f38d36f383dfe5f5973f51a4b68a8dba6 Mon Sep 17 00:00:00 2001 From: anutosh491 Date: Tue, 4 Jul 2023 17:08:57 +0530 Subject: [PATCH 3/3] Updated test --- integration_tests/symbolics_06.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integration_tests/symbolics_06.py b/integration_tests/symbolics_06.py index 1af60dff6b..f56aa52c76 100644 --- a/integration_tests/symbolics_06.py +++ b/integration_tests/symbolics_06.py @@ -23,7 +23,7 @@ def test_elementary_functions(): # test Abs assert(Abs(S(-10)) == S(10)) assert(Abs(S(10)) == S(10)) - assert(Abs(-x) == Abs(x)) + assert(Abs(S(-1)*x) == Abs(x)) # test composite functions a: S = exp(x)