From dbb001f6209b8a91f1850455f54423c9c369943e Mon Sep 17 00:00:00 2001 From: Thirumalai Shaktivel Date: Sat, 25 Nov 2023 10:51:55 +0530 Subject: [PATCH 01/21] [ASR Pass] Symbolic: Use a function to create `basic_new_stack` BindC Function --- src/libasr/pass/replace_symbolic.cpp | 88 ++++++++++++++-------------- 1 file changed, 45 insertions(+), 43 deletions(-) diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index 7abf80f8fd..ddceec95f7 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -49,6 +49,48 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor symbolic_vars_to_omit; SymEngine_Stack symengine_stack; + /********************************** Utils *********************************/ + ASR::stmt_t *basic_new_stack(const Location &loc, ASR::expr_t *x) { + std::string fn_name = "basic_new_stack"; + symbolic_dependencies.push_back(fn_name); + ASR::ttype_t *type = ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)); + ASR::symbol_t* basic_new_stack_sym = current_scope->resolve_symbol(fn_name); + if ( !basic_new_stack_sym ) { + std::string header = "symengine/cwrapper.h"; + SymbolTable *fn_symtab = al.make_new(current_scope->parent); + + Vec args; + { + args.reserve(al, 1); + ASR::symbol_t *arg = ASR::down_cast(ASR::make_Variable_t( + al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, type, nullptr, + ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "x"), arg); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg))); + } + + Vec body; body.reserve(al, 1); + Vec dependencies; dependencies.reserve(al, 1); + basic_new_stack_sym = ASR::down_cast( + ASRUtils::make_Function_t_util(al, loc, fn_symtab, s2c(al, fn_name), + dependencies.p, dependencies.n, args.p, args.n, body.p, body.n, + nullptr, ASR::abiType::BindC, ASR::accessType::Public, + ASR::deftypeType::Interface, s2c(al, fn_name), false, false, + false, false, false, nullptr, 0, false, false, false, s2c(al, header))); + current_scope->parent->add_symbol(fn_name, basic_new_stack_sym); + } + + Vec call_args; call_args.reserve(al, 1); + ASR::call_arg_t call_arg; + call_arg.loc = loc; + call_arg.m_value = x; + call_args.push_back(al, call_arg); + return ASRUtils::STMT(ASR::make_SubroutineCall_t(al, loc, basic_new_stack_sym, + basic_new_stack_sym, call_args.p, call_args.n, nullptr)); + } + /********************************** Utils *********************************/ + void visit_Function(const ASR::Function_t &x) { // FIXME: this is a hack, we need to pass in a non-const `x`, // which requires to generate a TransformVisitor. @@ -143,39 +185,8 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitoradd_symbol(s2c(al, placeholder), sym2); - std::string new_name = "basic_new_stack"; - symbolic_dependencies.push_back(new_name); - if (!module_scope->get_symbol(new_name)) { - std::string header = "symengine/cwrapper.h"; - SymbolTable *fn_symtab = al.make_new(module_scope); - Vec args; - { - args.reserve(al, 1); - ASR::symbol_t *arg = ASR::down_cast(ASR::make_Variable_t( - al, xx.base.base.loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, - nullptr, nullptr, ASR::storage_typeType::Default, type1, nullptr, - ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); - fn_symtab->add_symbol(s2c(al, "x"), arg); - args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, xx.base.base.loc, arg))); - } - - Vec body; - body.reserve(al, 1); - - Vec dep; - dep.reserve(al, 1); - - ASR::asr_t* new_subrout = ASRUtils::make_Function_t_util(al, xx.base.base.loc, - fn_symtab, s2c(al, new_name), dep.p, dep.n, args.p, args.n, body.p, body.n, - nullptr, ASR::abiType::BindC, ASR::accessType::Public, - ASR::deftypeType::Interface, s2c(al, new_name), false, false, false, - false, false, nullptr, 0, false, false, false, s2c(al, header)); - ASR::symbol_t *new_symbol = ASR::down_cast(new_subrout); - module_scope->add_symbol(new_name, new_symbol); - } - - new_name = "basic_free_stack"; + std::string new_name = "basic_free_stack"; symbolic_dependencies.push_back(new_name); if (!module_scope->get_symbol(new_name)) { std::string header = "symengine/cwrapper.h"; @@ -228,21 +239,12 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorget_symbol("basic_new_stack"); - Vec call_args; - call_args.reserve(al, 1); - ASR::call_arg_t call_arg; - call_arg.loc = xx.base.base.loc; - call_arg.m_value = target2; - call_args.push_back(al, call_arg); - // defining the assignment statement ASR::stmt_t* stmt1 = ASRUtils::STMT(ASR::make_Assignment_t(al, xx.base.base.loc, target1, value1, nullptr)); ASR::stmt_t* stmt2 = ASRUtils::STMT(ASR::make_Assignment_t(al, xx.base.base.loc, target2, value2, nullptr)); ASR::stmt_t* stmt3 = ASRUtils::STMT(ASR::make_Assignment_t(al, xx.base.base.loc, target2, value3, nullptr)); - ASR::stmt_t* stmt4 = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, xx.base.base.loc, basic_new_stack_sym, - basic_new_stack_sym, call_args.p, call_args.n, nullptr)); + // statement 4 + ASR::stmt_t* stmt4 = basic_new_stack(x.base.base.loc, target2); pass_result.push_back(al, stmt1); pass_result.push_back(al, stmt2); From ed2997d2fbe00d2df095d4731e28e8826600cd30 Mon Sep 17 00:00:00 2001 From: Thirumalai Shaktivel Date: Sat, 25 Nov 2023 10:53:11 +0530 Subject: [PATCH 02/21] [ASR Pass] Symbolic: Use a function to create `basic_free_stack` BindC Function --- src/libasr/pass/replace_symbolic.cpp | 106 ++++++++++++--------------- 1 file changed, 45 insertions(+), 61 deletions(-) diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index ddceec95f7..303f6693ee 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -89,6 +89,46 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorresolve_symbol(fn_name); + if ( !basic_free_stack_sym ) { + std::string header = "symengine/cwrapper.h"; + SymbolTable *fn_symtab = al.make_new(current_scope->parent); + + Vec args; + { + args.reserve(al, 1); + ASR::symbol_t *arg = ASR::down_cast(ASR::make_Variable_t( + al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, type, nullptr, + ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "x"), arg); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg))); + } + + Vec body; body.reserve(al, 1); + Vec dependencies; dependencies.reserve(al, 1); + basic_free_stack_sym = ASR::down_cast( + ASRUtils::make_Function_t_util(al, loc, fn_symtab, s2c(al, fn_name), + dependencies.p, dependencies.n, args.p, args.n, body.p, body.n, + nullptr, ASR::abiType::BindC, ASR::accessType::Public, + ASR::deftypeType::Interface, s2c(al, fn_name), false, false, false, + false, false, nullptr, 0, false, false, false, s2c(al, header))); + current_scope->parent->add_symbol(fn_name, basic_free_stack_sym); + } + + Vec call_args; call_args.reserve(al, 1); + ASR::call_arg_t call_arg; + call_arg.loc = loc; + call_arg.m_value = x; + call_args.push_back(al, call_arg); + return ASRUtils::STMT(ASR::make_SubroutineCall_t(al, loc, basic_free_stack_sym, + basic_free_stack_sym, call_args.p, call_args.n, nullptr)); + } /********************************** Utils *********************************/ void visit_Function(const ASR::Function_t &x) { @@ -97,7 +137,6 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(x); SymbolTable* current_scope_copy = this->current_scope; this->current_scope = xx.m_symtab; - SymbolTable* module_scope = this->current_scope->parent; ASR::ttype_t* f_signature= xx.m_function_signature; ASR::FunctionType_t *f_type = ASR::down_cast(f_signature); @@ -132,22 +171,13 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorget_symbol(new_name); Vec func_body; func_body.from_pointer_n_copy(al, xx.m_body, xx.n_body); for (ASR::symbol_t* symbol : symbolic_vars_to_free) { if (symbolic_vars_to_omit.find(symbol) != symbolic_vars_to_omit.end()) continue; - Vec call_args; - call_args.reserve(al, 1); - ASR::call_arg_t call_arg; - call_arg.loc = xx.base.base.loc; - call_arg.m_value = ASRUtils::EXPR(ASR::make_Var_t(al, xx.base.base.loc, symbol)); - call_args.push_back(al, call_arg); - ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, xx.base.base.loc, basic_free_stack_sym, - basic_free_stack_sym, call_args.p, call_args.n, nullptr)); - func_body.push_back(al, stmt); + func_body.push_back(al, basic_free_stack(x.base.base.loc, + ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, symbol)))); } xx.n_body = func_body.size(); @@ -159,7 +189,6 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(x); if (xx.m_type->type == ASR::ttypeType::SymbolicExpression) { - SymbolTable* module_scope = current_scope->parent; std::string var_name = xx.m_name; std::string placeholder = "_" + std::string(var_name); @@ -184,40 +213,6 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitoradd_symbol(s2c(al, placeholder), sym2); - - - std::string new_name = "basic_free_stack"; - symbolic_dependencies.push_back(new_name); - if (!module_scope->get_symbol(new_name)) { - std::string header = "symengine/cwrapper.h"; - SymbolTable *fn_symtab = al.make_new(module_scope); - - Vec args; - { - args.reserve(al, 1); - ASR::symbol_t *arg = ASR::down_cast(ASR::make_Variable_t( - al, xx.base.base.loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, - nullptr, nullptr, ASR::storage_typeType::Default, type1, nullptr, - ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); - fn_symtab->add_symbol(s2c(al, "x"), arg); - args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, xx.base.base.loc, arg))); - } - - Vec body; - body.reserve(al, 1); - - Vec dep; - dep.reserve(al, 1); - - ASR::asr_t* new_subrout = ASRUtils::make_Function_t_util(al, xx.base.base.loc, - fn_symtab, s2c(al, new_name), dep.p, dep.n, args.p, args.n, body.p, body.n, - nullptr, ASR::abiType::BindC, ASR::accessType::Public, - ASR::deftypeType::Interface, s2c(al, new_name), false, false, false, - false, false, nullptr, 0, false, false, false, s2c(al, header)); - ASR::symbol_t *new_symbol = ASR::down_cast(new_subrout); - module_scope->add_symbol(new_name, new_symbol); - } - ASR::symbol_t* var_sym = current_scope->get_symbol(var_name); ASR::symbol_t* placeholder_sym = current_scope->get_symbol(placeholder); ASR::expr_t* target1 = ASRUtils::EXPR(ASR::make_Var_t(al, xx.base.base.loc, placeholder_sym)); @@ -1771,22 +1766,11 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorparent; - // freeing out variables - std::string new_name = "basic_free_stack"; - ASR::symbol_t* basic_free_stack_sym = module_scope->get_symbol(new_name); - for (ASR::symbol_t* symbol : symbolic_vars_to_free) { if (symbolic_vars_to_omit.find(symbol) != symbolic_vars_to_omit.end()) continue; - Vec call_args; - call_args.reserve(al, 1); - ASR::call_arg_t call_arg; - call_arg.loc = x.base.base.loc; - call_arg.m_value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, symbol)); - call_args.push_back(al, call_arg); - ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, x.base.base.loc, basic_free_stack_sym, - basic_free_stack_sym, call_args.p, call_args.n, nullptr)); - pass_result.push_back(al, stmt); + // freeing out variables + pass_result.push_back(al, basic_free_stack(x.base.base.loc, + ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, symbol)))); } symbolic_vars_to_free.clear(); pass_result.push_back(al, ASRUtils::STMT(ASR::make_Return_t(al, x.base.base.loc))); From 2fe8807e53b4289667c46c9258868999c1537501 Mon Sep 17 00:00:00 2001 From: Thirumalai Shaktivel Date: Sat, 25 Nov 2023 10:53:47 +0530 Subject: [PATCH 03/21] [ASR Pass] Symbolic: Add `basic_free_stack` to function dependencies --- src/libasr/pass/replace_symbolic.cpp | 24 ++++++++++-------------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index 303f6693ee..98f8750e9d 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -155,20 +155,6 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorcurrent_scope = current_scope_copy; - // freeing out variables if (!symbolic_vars_to_free.empty()) { Vec func_body; @@ -184,6 +170,16 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorcurrent_scope = current_scope_copy; } void visit_Variable(const ASR::Variable_t& x) { From 0c3d6f20ce1b5f546ac6f91a8296af2381ad1c57 Mon Sep 17 00:00:00 2001 From: Thirumalai Shaktivel Date: Sat, 25 Nov 2023 11:16:45 +0530 Subject: [PATCH 04/21] [ASR Pass] Symbolic: Simplify `basic_get_args` to return `SubroutineCall` --- src/libasr/pass/replace_symbolic.cpp | 100 +++++++++++++-------------- 1 file changed, 48 insertions(+), 52 deletions(-) diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index 98f8750e9d..9aeed38a9a 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -129,6 +129,53 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorresolve_symbol(fn_name); + if ( !basic_get_args_sym ) { + std::string header = "symengine/cwrapper.h"; + SymbolTable* fn_symtab = al.make_new(current_scope->parent); + + Vec args; + args.reserve(al, 2); + ASR::symbol_t* arg1 = ASR::down_cast(ASR::make_Variable_t( + al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, type, + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "x"), arg1); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg1))); + ASR::symbol_t* arg2 = ASR::down_cast(ASR::make_Variable_t( + al, loc, fn_symtab, s2c(al, "y"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, type, + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "y"), arg2); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg2))); + + Vec body; body.reserve(al, 1); + Vec dependencies; dependencies.reserve(al, 1); + basic_get_args_sym = ASR::down_cast( + ASRUtils::make_Function_t_util(al, loc, fn_symtab, s2c(al, fn_name), + dependencies.p, dependencies.n, args.p, args.n, body.p, body.n, + nullptr, ASR::abiType::BindC, ASR::accessType::Public, + ASR::deftypeType::Interface, s2c(al, fn_name), false, false, false, + false, false, nullptr, 0, false, false, false, s2c(al, header))); + current_scope->parent->add_symbol(s2c(al, fn_name), basic_get_args_sym); + } + + Vec call_args; + call_args.reserve(al, 2); + ASR::call_arg_t call_arg; + call_arg.loc = loc; + call_arg.m_value = x; + call_args.push_back(al, call_arg); + call_arg.m_value = y; + call_args.push_back(al, call_arg); + return ASRUtils::STMT(ASR::make_SubroutineCall_t(al, loc, basic_get_args_sym, + basic_get_args_sym, call_args.p, call_args.n, nullptr)); + } /********************************** Utils *********************************/ void visit_Function(const ASR::Function_t &x) { @@ -559,7 +606,6 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorm_args[0]); - ASR::symbol_t* basic_get_args_sym = declare_basic_get_args_function(al, loc, module_scope); ASR::symbol_t* vecbasic_new_sym = declare_vecbasic_new_function(al, loc, module_scope); ASR::symbol_t* vecbasic_get_sym = declare_vecbasic_get_function(al, loc, module_scope); ASR::symbol_t* vecbasic_size_sym = declare_vecbasic_size_function(al, loc, module_scope); @@ -584,18 +630,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor call_args2; - call_args2.reserve(al, 2); - ASR::call_arg_t call_arg1, call_arg2; - call_arg1.loc = loc; - call_arg1.m_value = value1; - call_arg2.loc = loc; - call_arg2.m_value = args; - call_args2.push_back(al, call_arg1); - call_args2.push_back(al, call_arg2); - ASR::stmt_t* stmt2 = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, loc, basic_get_args_sym, - basic_get_args_sym, call_args2.p, call_args2.n, nullptr)); - pass_result.push_back(al, stmt2); + pass_result.push_back(al, basic_get_args(loc, value1, args)); // Statement 3 Vec call_args3; @@ -798,45 +833,6 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorget_symbol(name); } - ASR::symbol_t* declare_basic_get_args_function(Allocator& al, const Location& loc, SymbolTable* module_scope) { - std::string name = "basic_get_args"; - symbolic_dependencies.push_back(name); - if (!module_scope->get_symbol(name)) { - std::string header = "symengine/cwrapper.h"; - SymbolTable* fn_symtab = al.make_new(module_scope); - - Vec args; - args.reserve(al, 2); - ASR::symbol_t* arg1 = ASR::down_cast(ASR::make_Variable_t( - al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, - nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)), - nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); - fn_symtab->add_symbol(s2c(al, "x"), arg1); - args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg1))); - ASR::symbol_t* arg2 = ASR::down_cast(ASR::make_Variable_t( - al, loc, fn_symtab, s2c(al, "y"), nullptr, 0, ASR::intentType::In, - nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)), - nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); - fn_symtab->add_symbol(s2c(al, "y"), arg2); - args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg2))); - - Vec body; - body.reserve(al, 1); - - Vec dep; - dep.reserve(al, 1); - - ASR::asr_t* subrout = ASRUtils::make_Function_t_util(al, loc, - fn_symtab, s2c(al, name), dep.p, dep.n, args.p, args.n, body.p, body.n, - nullptr, ASR::abiType::BindC, ASR::accessType::Public, - ASR::deftypeType::Interface, s2c(al, name), false, false, false, - false, false, nullptr, 0, false, false, false, s2c(al, header)); - ASR::symbol_t* symbol = ASR::down_cast(subrout); - module_scope->add_symbol(s2c(al, name), symbol); - } - return module_scope->get_symbol(name); - } - ASR::symbol_t* declare_vecbasic_new_function(Allocator& al, const Location& loc, SymbolTable* module_scope) { std::string name = "vecbasic_new"; symbolic_dependencies.push_back(name); From 93546df23165441435e94d15be920b45d4f82351 Mon Sep 17 00:00:00 2001 From: Thirumalai Shaktivel Date: Sat, 25 Nov 2023 12:55:27 +0530 Subject: [PATCH 05/21] [ASR Pass] Symbolic: Simplify `vecbasic_new` to return `FunctionCall` --- src/libasr/pass/replace_symbolic.cpp | 75 +++++++++++++--------------- 1 file changed, 36 insertions(+), 39 deletions(-) diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index 9aeed38a9a..8808170aab 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -176,6 +176,41 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorresolve_symbol(fn_name); + if ( !vecbasic_new_sym ) { + std::string header = "symengine/cwrapper.h"; + SymbolTable* fn_symtab = al.make_new(current_scope->parent); + + Vec args; + args.reserve(al, 1); + char *arg_name = s2c(al, "_lpython_return_variable"); + ASR::symbol_t* arg1 = ASR::down_cast( + ASR::make_Variable_t(al, loc, fn_symtab, arg_name, nullptr, 0, + ASR::intentType::ReturnVar, nullptr, nullptr, + ASR::storage_typeType::Default, ASRUtils::TYPE((ASR::make_CPtr_t(al, loc))), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, false)); + fn_symtab->add_symbol(arg_name, arg1); + + Vec body; body.reserve(al, 1); + Vec dep; dep.reserve(al, 1); + ASR::expr_t* return_var = ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg1)); + vecbasic_new_sym = ASR::down_cast( + ASRUtils::make_Function_t_util(al, loc, fn_symtab, s2c(al, fn_name), + dep.p, dep.n, args.p, args.n, body.p, body.n, return_var, + ASR::abiType::BindC, ASR::accessType::Public, + ASR::deftypeType::Interface, s2c(al, fn_name), false, false, false, + false, false, nullptr, 0, false, false, false, s2c(al, header))); + current_scope->parent->add_symbol(s2c(al, fn_name), vecbasic_new_sym); + } + Vec call_args; call_args.reserve(al, 1); + return ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, loc, + vecbasic_new_sym, vecbasic_new_sym, call_args.p, call_args.n, + ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)), nullptr, nullptr)); + } /********************************** Utils *********************************/ void visit_Function(const ASR::Function_t &x) { @@ -606,7 +641,6 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorm_args[0]); - ASR::symbol_t* vecbasic_new_sym = declare_vecbasic_new_function(al, loc, module_scope); ASR::symbol_t* vecbasic_get_sym = declare_vecbasic_get_function(al, loc, module_scope); ASR::symbol_t* vecbasic_size_sym = declare_vecbasic_size_function(al, loc, module_scope); @@ -621,11 +655,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor call_args1; - call_args1.reserve(al, 1); - ASR::expr_t* function_call1 = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, loc, - vecbasic_new_sym, vecbasic_new_sym, call_args1.p, call_args1.n, - ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)), nullptr, nullptr)); + ASR::expr_t* function_call1 = vecbasic_new(loc); ASR::stmt_t* stmt1 = ASRUtils::STMT(ASR::make_Assignment_t(al, loc, args, function_call1, nullptr)); pass_result.push_back(al, stmt1); @@ -833,39 +863,6 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorget_symbol(name); } - ASR::symbol_t* declare_vecbasic_new_function(Allocator& al, const Location& loc, SymbolTable* module_scope) { - std::string name = "vecbasic_new"; - symbolic_dependencies.push_back(name); - if (!module_scope->get_symbol(name)) { - std::string header = "symengine/cwrapper.h"; - SymbolTable* fn_symtab = al.make_new(module_scope); - - Vec args; - args.reserve(al, 1); - ASR::symbol_t* arg1 = ASR::down_cast(ASR::make_Variable_t( - al, loc, fn_symtab, s2c(al, "_lpython_return_variable"), nullptr, 0, ASR::intentType::ReturnVar, - nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE((ASR::make_CPtr_t(al, loc))), - nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, false)); - fn_symtab->add_symbol(s2c(al, "_lpython_return_variable"), arg1); - - Vec body; - body.reserve(al, 1); - - Vec dep; - dep.reserve(al, 1); - - ASR::expr_t* return_var = ASRUtils::EXPR(ASR::make_Var_t(al, loc, fn_symtab->get_symbol("_lpython_return_variable"))); - ASR::asr_t* subrout = ASRUtils::make_Function_t_util(al, loc, - fn_symtab, s2c(al, name), dep.p, dep.n, args.p, args.n, body.p, body.n, - return_var, ASR::abiType::BindC, ASR::accessType::Public, - ASR::deftypeType::Interface, s2c(al, name), false, false, false, - false, false, nullptr, 0, false, false, false, s2c(al, header)); - ASR::symbol_t* symbol = ASR::down_cast(subrout); - module_scope->add_symbol(s2c(al, name), symbol); - } - return module_scope->get_symbol(name); - } - ASR::symbol_t* declare_vecbasic_get_function(Allocator& al, const Location& loc, SymbolTable* module_scope) { std::string name = "vecbasic_get"; symbolic_dependencies.push_back(name); From 799932e60cab2a97b1c5b597590c19bdeb5543d2 Mon Sep 17 00:00:00 2001 From: Thirumalai Shaktivel Date: Sat, 25 Nov 2023 13:03:55 +0530 Subject: [PATCH 06/21] [ASR Pass] Symbolic: Simplify `vecbasic_get` to return `SubroutineCall` --- src/libasr/pass/replace_symbolic.cpp | 115 +++++++++++++-------------- 1 file changed, 54 insertions(+), 61 deletions(-) diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index 8808170aab..839b9fcd23 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -211,6 +211,59 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorresolve_symbol(name); + if ( !vecbasic_get_sym ) { + std::string header = "symengine/cwrapper.h"; + SymbolTable* fn_symtab = al.make_new(current_scope->parent); + + Vec args; args.reserve(al, 3); + ASR::symbol_t* arg1 = ASR::down_cast(ASR::make_Variable_t( + al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, cptr_type, + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "x"), arg1); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg1))); + ASR::symbol_t* arg2 = ASR::down_cast(ASR::make_Variable_t( + al, loc, fn_symtab, s2c(al, "y"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE((ASR::make_Integer_t(al, loc, 4))), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "y"), arg2); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg2))); + ASR::symbol_t* arg3 = ASR::down_cast(ASR::make_Variable_t( + al, loc, fn_symtab, s2c(al, "z"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, cptr_type, + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "z"), arg3); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg3))); + + Vec body; body.reserve(al, 1); + Vec dep; dep.reserve(al, 1); + vecbasic_get_sym = ASR::down_cast( + ASRUtils::make_Function_t_util(al, loc, fn_symtab, s2c(al, name), + dep.p, dep.n, args.p, args.n, body.p, body.n, nullptr, + ASR::abiType::BindC, ASR::accessType::Public, + ASR::deftypeType::Interface, s2c(al, name), false, false, false, + false, false, nullptr, 0, false, false, false, s2c(al, header))); + current_scope->parent->add_symbol(s2c(al, name), vecbasic_get_sym); + } + Vec call_args; + call_args.reserve(al, 3); + ASR::call_arg_t call_arg; + call_arg.loc = loc; + call_arg.m_value = x; + call_args.push_back(al, call_arg); + call_arg.m_value = y; + call_args.push_back(al, call_arg); + call_arg.m_value = z; + call_args.push_back(al, call_arg); + return ASRUtils::STMT(ASR::make_SubroutineCall_t(al, loc, vecbasic_get_sym, + vecbasic_get_sym, call_args.p, call_args.n, nullptr)); + } /********************************** Utils *********************************/ void visit_Function(const ASR::Function_t &x) { @@ -641,7 +694,6 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorm_args[0]); - ASR::symbol_t* vecbasic_get_sym = declare_vecbasic_get_function(al, loc, module_scope); ASR::symbol_t* vecbasic_size_sym = declare_vecbasic_size_function(al, loc, module_scope); // Define necessary variables @@ -682,21 +734,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor call_args4; - call_args4.reserve(al, 3); - ASR::call_arg_t call_arg4, call_arg5, call_arg6; - call_arg4.loc = loc; - call_arg4.m_value = args; - call_arg5.loc = loc; - call_arg5.m_value = x->m_args[1]; - call_arg6.loc = loc; - call_arg6.m_value = target; - call_args4.push_back(al, call_arg4); - call_args4.push_back(al, call_arg5); - call_args4.push_back(al, call_arg6); - ASR::stmt_t* stmt4 = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, loc, vecbasic_get_sym, - vecbasic_get_sym, call_args4.p, call_args4.n, nullptr)); - pass_result.push_back(al, stmt4); + pass_result.push_back(al, vecbasic_get(loc, args, x->m_args[1], target)); break; } default: { @@ -863,51 +901,6 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorget_symbol(name); } - ASR::symbol_t* declare_vecbasic_get_function(Allocator& al, const Location& loc, SymbolTable* module_scope) { - std::string name = "vecbasic_get"; - symbolic_dependencies.push_back(name); - if (!module_scope->get_symbol(name)) { - std::string header = "symengine/cwrapper.h"; - SymbolTable* fn_symtab = al.make_new(module_scope); - - Vec args; - args.reserve(al, 3); - ASR::symbol_t* arg1 = ASR::down_cast(ASR::make_Variable_t( - al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, - nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)), - nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); - fn_symtab->add_symbol(s2c(al, "x"), arg1); - args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg1))); - ASR::symbol_t* arg2 = ASR::down_cast(ASR::make_Variable_t( - al, loc, fn_symtab, s2c(al, "y"), nullptr, 0, ASR::intentType::In, - nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE((ASR::make_Integer_t(al, loc, 4))), - nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); - fn_symtab->add_symbol(s2c(al, "y"), arg2); - args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg2))); - ASR::symbol_t* arg3 = ASR::down_cast(ASR::make_Variable_t( - al, loc, fn_symtab, s2c(al, "z"), nullptr, 0, ASR::intentType::In, - nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)), - nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); - fn_symtab->add_symbol(s2c(al, "z"), arg3); - args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg3))); - - Vec body; - body.reserve(al, 1); - - Vec dep; - dep.reserve(al, 1); - - ASR::asr_t* subrout = ASRUtils::make_Function_t_util(al, loc, - fn_symtab, s2c(al, name), dep.p, dep.n, args.p, args.n, body.p, body.n, - nullptr, ASR::abiType::BindC, ASR::accessType::Public, - ASR::deftypeType::Interface, s2c(al, name), false, false, false, - false, false, nullptr, 0, false, false, false, s2c(al, header)); - ASR::symbol_t* symbol = ASR::down_cast(subrout); - module_scope->add_symbol(s2c(al, name), symbol); - } - return module_scope->get_symbol(name); - } - ASR::symbol_t* declare_vecbasic_size_function(Allocator& al, const Location& loc, SymbolTable* module_scope) { std::string name = "vecbasic_size"; symbolic_dependencies.push_back(name); From a7eae7b7fe81155a29ed59be7918ef749532b524 Mon Sep 17 00:00:00 2001 From: Thirumalai Shaktivel Date: Sat, 25 Nov 2023 13:09:46 +0530 Subject: [PATCH 07/21] [ASR Pass] Symbolic: Simplify `vecbasic_size` to return `FunctionCall` --- src/libasr/pass/replace_symbolic.cpp | 94 +++++++++++++--------------- 1 file changed, 45 insertions(+), 49 deletions(-) diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index 839b9fcd23..531173e3c9 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -264,6 +264,50 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorresolve_symbol(fn_name); + if ( !vecbasic_size_sym ) { + std::string header = "symengine/cwrapper.h"; + SymbolTable* fn_symtab = al.make_new(current_scope->parent); + + Vec args; args.reserve(al, 1); + char *return_var_name = s2c(al, "_lpython_return_variable"); + ASR::symbol_t* arg1 = ASR::down_cast(ASR::make_Variable_t( + al, loc, fn_symtab, return_var_name, nullptr, 0, ASR::intentType::ReturnVar, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, false)); + fn_symtab->add_symbol(return_var_name, arg1); + ASR::symbol_t* arg2 = ASR::down_cast(ASR::make_Variable_t( + al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "x"), arg2); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg2))); + + Vec body; body.reserve(al, 1); + Vec dep; dep.reserve(al, 1); + ASR::expr_t* return_var = ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg1)); + vecbasic_size_sym = ASR::down_cast( + ASRUtils::make_Function_t_util(al, loc, fn_symtab, s2c(al, fn_name), + dep.p, dep.n, args.p, args.n, body.p, body.n, + return_var, ASR::abiType::BindC, ASR::accessType::Public, + ASR::deftypeType::Interface, s2c(al, fn_name), false, false, false, + false, false, nullptr, 0, false, false, false, s2c(al, header))); + current_scope->parent->add_symbol(s2c(al, fn_name), vecbasic_size_sym); + } + Vec call_args; + call_args.reserve(al, 1); + ASR::call_arg_t call_arg; + call_arg.loc = loc; + call_arg.m_value = x; + call_args.push_back(al, call_arg); + return ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, loc, + vecbasic_size_sym, vecbasic_size_sym, call_args.p, call_args.n, + ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)), nullptr, nullptr)); + } /********************************** Utils *********************************/ void visit_Function(const ASR::Function_t &x) { @@ -694,7 +738,6 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorm_args[0]); - ASR::symbol_t* vecbasic_size_sym = declare_vecbasic_size_function(al, loc, module_scope); // Define necessary variables ASR::ttype_t* CPtr_type = ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)); @@ -715,15 +758,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor call_args3; - call_args3.reserve(al, 1); - ASR::call_arg_t call_arg3; - call_arg3.loc = loc; - call_arg3.m_value = args; - call_args3.push_back(al, call_arg3); - ASR::expr_t* function_call2 = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, loc, - vecbasic_size_sym, vecbasic_size_sym, call_args3.p, call_args3.n, - ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)), nullptr, nullptr)); + ASR::expr_t* function_call2 = vecbasic_size(loc, args); ASR::expr_t* test = ASRUtils::EXPR(ASR::make_IntegerCompare_t(al, loc, function_call2, ASR::cmpopType::Gt, x->m_args[1], ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)), nullptr)); std::string error_str = "tuple index out of range"; @@ -901,45 +936,6 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorget_symbol(name); } - ASR::symbol_t* declare_vecbasic_size_function(Allocator& al, const Location& loc, SymbolTable* module_scope) { - std::string name = "vecbasic_size"; - symbolic_dependencies.push_back(name); - if (!module_scope->get_symbol(name)) { - std::string header = "symengine/cwrapper.h"; - SymbolTable* fn_symtab = al.make_new(module_scope); - - Vec args; - args.reserve(al, 1); - ASR::symbol_t* arg1 = ASR::down_cast(ASR::make_Variable_t( - al, loc, fn_symtab, s2c(al, "_lpython_return_variable"), nullptr, 0, ASR::intentType::ReturnVar, - nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)), - nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, false)); - fn_symtab->add_symbol(s2c(al, "_lpython_return_variable"), arg1); - ASR::symbol_t* arg2 = ASR::down_cast(ASR::make_Variable_t( - al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, - nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)), - nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); - fn_symtab->add_symbol(s2c(al, "x"), arg2); - args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg2))); - - Vec body; - body.reserve(al, 1); - - Vec dep; - dep.reserve(al, 1); - - ASR::expr_t* return_var = ASRUtils::EXPR(ASR::make_Var_t(al, loc, fn_symtab->get_symbol("_lpython_return_variable"))); - ASR::asr_t* subrout = ASRUtils::make_Function_t_util(al, loc, - fn_symtab, s2c(al, name), dep.p, dep.n, args.p, args.n, body.p, body.n, - return_var, ASR::abiType::BindC, ASR::accessType::Public, - ASR::deftypeType::Interface, s2c(al, name), false, false, false, - false, false, nullptr, 0, false, false, false, s2c(al, header)); - ASR::symbol_t* symbol = ASR::down_cast(subrout); - module_scope->add_symbol(s2c(al, name), symbol); - } - return module_scope->get_symbol(name); - } - ASR::symbol_t* declare_basic_eq_function(Allocator& al, const Location& loc, SymbolTable* module_scope) { std::string name = "basic_eq"; symbolic_dependencies.push_back(name); From f497d07128198aa313d3b76e39fba15143dc2ee5 Mon Sep 17 00:00:00 2001 From: Thirumalai Shaktivel Date: Sat, 25 Nov 2023 13:18:40 +0530 Subject: [PATCH 08/21] [ASR Pass] Symbolic: Simplify `basic_assign` to return `SubroutineCall` --- src/libasr/pass/replace_symbolic.cpp | 122 ++++++++++++--------------- 1 file changed, 52 insertions(+), 70 deletions(-) diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index 531173e3c9..04cba31bd3 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -308,6 +308,52 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorresolve_symbol(fn_name); + if ( !basic_assign_sym ) { + std::string header = "symengine/cwrapper.h"; + SymbolTable* fn_symtab = al.make_new(current_scope->parent); + + Vec args; args.reserve(al, 2); + ASR::symbol_t* arg1 = ASR::down_cast(ASR::make_Variable_t( + al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, cptr_type, + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "x"), arg1); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg1))); + ASR::symbol_t* arg2 = ASR::down_cast(ASR::make_Variable_t( + al, loc, fn_symtab, s2c(al, "y"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, cptr_type, + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "y"), arg2); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg2))); + + Vec body; body.reserve(al, 1); + Vec dep; dep.reserve(al, 1); + basic_assign_sym = ASR::down_cast( + ASRUtils::make_Function_t_util(al, loc, fn_symtab, s2c(al, fn_name), + dep.p, dep.n, args.p, args.n, body.p, body.n, nullptr, + ASR::abiType::BindC, ASR::accessType::Public, + ASR::deftypeType::Interface, s2c(al, fn_name), false, false, false, + false, false, nullptr, 0, false, false, false, s2c(al, header))); + current_scope->parent->add_symbol(s2c(al, fn_name), basic_assign_sym); + } + Vec call_args; + call_args.reserve(al, 2); + ASR::call_arg_t call_arg; + call_arg.loc = loc; + call_arg.m_value = target; + call_args.push_back(al, call_arg); + call_arg.m_value = value; + call_args.push_back(al, call_arg); + return ASRUtils::STMT(ASR::make_SubroutineCall_t(al, loc, basic_assign_sym, + basic_assign_sym, call_args.p, call_args.n, nullptr)); + } /********************************** Utils *********************************/ void visit_Function(const ASR::Function_t &x) { @@ -780,45 +826,6 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorget_symbol(name)) { - std::string header = "symengine/cwrapper.h"; - SymbolTable* fn_symtab = al.make_new(module_scope); - - Vec args; - args.reserve(al, 2); - ASR::symbol_t* arg1 = ASR::down_cast(ASR::make_Variable_t( - al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, - nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)), - nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); - fn_symtab->add_symbol(s2c(al, "x"), arg1); - args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg1))); - ASR::symbol_t* arg2 = ASR::down_cast(ASR::make_Variable_t( - al, loc, fn_symtab, s2c(al, "y"), nullptr, 0, ASR::intentType::In, - nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)), - nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); - fn_symtab->add_symbol(s2c(al, "y"), arg2); - args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg2))); - - Vec body; - body.reserve(al, 1); - - Vec dep; - dep.reserve(al, 1); - - ASR::asr_t* subrout = ASRUtils::make_Function_t_util(al, loc, - fn_symtab, s2c(al, name), dep.p, dep.n, args.p, args.n, body.p, body.n, - nullptr, ASR::abiType::BindC, ASR::accessType::Public, - ASR::deftypeType::Interface, s2c(al, name), false, false, false, - false, false, nullptr, 0, false, false, false, s2c(al, header)); - ASR::symbol_t* symbol = ASR::down_cast(subrout); - module_scope->add_symbol(s2c(al, name), symbol); - } - return module_scope->get_symbol(name); - } - ASR::symbol_t* declare_basic_str_function(Allocator& al, const Location& loc, SymbolTable* module_scope) { std::string name = "basic_str"; symbolic_dependencies.push_back(name); @@ -1197,22 +1204,9 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(*x.m_value) && ASR::is_a(*ASRUtils::expr_type(x.m_value))) { ASR::symbol_t *v = ASR::down_cast(x.m_value)->m_v; if (symbolic_vars_to_free.find(v) == symbolic_vars_to_free.end()) return; - ASR::symbol_t* basic_assign_sym = declare_basic_assign_function(al, x.base.base.loc, module_scope); ASR::symbol_t* var_sym = ASR::down_cast(x.m_value)->m_v; - ASR::expr_t* target = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - - Vec call_args; - call_args.reserve(al, 2); - ASR::call_arg_t call_arg1, call_arg2; - call_arg1.loc = x.base.base.loc; - call_arg1.m_value = x.m_target; - call_arg2.loc = x.base.base.loc; - call_arg2.m_value = target; - call_args.push_back(al, call_arg1); - call_args.push_back(al, call_arg2); - ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, x.base.base.loc, basic_assign_sym, - basic_assign_sym, call_args.p, call_args.n, nullptr)); - pass_result.push_back(al, stmt); + pass_result.push_back(al, basic_assign(x.base.base.loc, x.m_target, + ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)))); } else if (ASR::is_a(*x.m_value)) { ASR::IntrinsicScalarFunction_t* intrinsic_func = ASR::down_cast(x.m_value); if (intrinsic_func->m_type->type == ASR::ttypeType::SymbolicExpression) { @@ -1305,22 +1299,10 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(*x.m_value)) { ASR::ListItem_t* list_item = ASR::down_cast(x.m_value); if (list_item->m_type->type == ASR::ttypeType::SymbolicExpression) { - ASR::ttype_t *CPtr_type = ASRUtils::TYPE(ASR::make_CPtr_t(al, x.base.base.loc)); - ASR::symbol_t* basic_assign_sym = declare_basic_assign_function(al, x.base.base.loc, module_scope); - - Vec call_args; - call_args.reserve(al, 2); - ASR::call_arg_t call_arg1, call_arg2; - call_arg1.loc = x.base.base.loc; - call_arg1.m_value = x.m_target; - call_arg2.loc = x.base.base.loc; - call_arg2.m_value = ASRUtils::EXPR(ASR::make_ListItem_t(al, x.base.base.loc, list_item->m_a, - list_item->m_pos, CPtr_type, nullptr)); - call_args.push_back(al, call_arg1); - call_args.push_back(al, call_arg2); - ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, x.base.base.loc, basic_assign_sym, - basic_assign_sym, call_args.p, call_args.n, nullptr)); - pass_result.push_back(al, stmt); + ASR::expr_t *value = ASRUtils::EXPR(ASR::make_ListItem_t(al, + x.base.base.loc, list_item->m_a, list_item->m_pos, + ASRUtils::TYPE(ASR::make_CPtr_t(al, x.base.base.loc)), nullptr)); + pass_result.push_back(al, basic_assign(x.base.base.loc, x.m_target, value)); } } else if (ASR::is_a(*x.m_value)) { ASR::SymbolicCompare_t *s = ASR::down_cast(x.m_value); From 0aa44351807ea71f73d9145afffd943c0f2a6405 Mon Sep 17 00:00:00 2001 From: Thirumalai Shaktivel Date: Sat, 25 Nov 2023 13:41:38 +0530 Subject: [PATCH 09/21] [ASR Pass] Symbolic: Simplify `basic_str` to return `FunctionCall` --- src/libasr/pass/replace_symbolic.cpp | 171 ++++++++++----------------- 1 file changed, 61 insertions(+), 110 deletions(-) diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index 04cba31bd3..f0778bce66 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -354,6 +354,51 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorresolve_symbol(fn_name); + if ( !basic_str_sym ) { + std::string header = "symengine/cwrapper.h"; + SymbolTable* fn_symtab = al.make_new(current_scope->parent); + + Vec args; args.reserve(al, 1); + char *return_var_name = s2c(al, "_lpython_return_variable"); + ASR::symbol_t* arg1 = ASR::down_cast(ASR::make_Variable_t( + al, loc, fn_symtab, return_var_name, nullptr, 0, ASR::intentType::ReturnVar, + nullptr, nullptr, ASR::storage_typeType::Default, + ASRUtils::TYPE(ASR::make_Character_t(al, loc, 1, -2, nullptr)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, false)); + fn_symtab->add_symbol(return_var_name, arg1); + ASR::symbol_t* arg2 = ASR::down_cast(ASR::make_Variable_t( + al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, + ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)), nullptr, + ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "x"), arg2); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg2))); + + Vec body; body.reserve(al, 1); + Vec dep; dep.reserve(al, 1); + ASR::expr_t* return_var = ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg1)); + basic_str_sym = ASR::down_cast( + ASRUtils::make_Function_t_util(al, loc, fn_symtab, s2c(al, fn_name), + dep.p, dep.n, args.p, args.n, body.p, body.n, return_var, + ASR::abiType::BindC, ASR::accessType::Public, + ASR::deftypeType::Interface, s2c(al, fn_name), false, false, false, + false, false, nullptr, 0, false, false, false, s2c(al, header))); + current_scope->parent->add_symbol(s2c(al, fn_name), basic_str_sym); + } + Vec call_args; call_args.reserve(al, 1); + ASR::call_arg_t call_arg; + call_arg.loc = loc; + call_arg.m_value = x; + call_args.push_back(al, call_arg); + return ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, loc, + basic_str_sym, basic_str_sym, call_args.p, call_args.n, + ASRUtils::TYPE(ASR::make_Character_t(al, loc, 1, -2, nullptr)), nullptr, nullptr)); + } /********************************** Utils *********************************/ void visit_Function(const ASR::Function_t &x) { @@ -826,45 +871,6 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorget_symbol(name)) { - std::string header = "symengine/cwrapper.h"; - SymbolTable* fn_symtab = al.make_new(module_scope); - - Vec args; - args.reserve(al, 1); - ASR::symbol_t* arg1 = ASR::down_cast(ASR::make_Variable_t( - al, loc, fn_symtab, s2c(al, "_lpython_return_variable"), nullptr, 0, ASR::intentType::ReturnVar, - nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_Character_t(al, loc, 1, -2, nullptr)), - nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, false)); - fn_symtab->add_symbol(s2c(al, "_lpython_return_variable"), arg1); - ASR::symbol_t* arg2 = ASR::down_cast(ASR::make_Variable_t( - al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, - nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)), - nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); - fn_symtab->add_symbol(s2c(al, "x"), arg2); - args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg2))); - - Vec body; - body.reserve(al, 1); - - Vec dep; - dep.reserve(al, 1); - - ASR::expr_t* return_var = ASRUtils::EXPR(ASR::make_Var_t(al, loc, fn_symtab->get_symbol("_lpython_return_variable"))); - ASR::asr_t* subrout = ASRUtils::make_Function_t_util(al, loc, - fn_symtab, s2c(al, name), dep.p, dep.n, args.p, args.n, body.p, body.n, - return_var, ASR::abiType::BindC, ASR::accessType::Public, - ASR::deftypeType::Interface, s2c(al, name), false, false, false, - false, false, nullptr, 0, false, false, false, s2c(al, header)); - ASR::symbol_t* symbol = ASR::down_cast(subrout); - module_scope->add_symbol(s2c(al, name), symbol); - } - return module_scope->get_symbol(name); - } - ASR::symbol_t* declare_integer_set_si_function(Allocator& al, const Location& loc, SymbolTable* module_scope) { std::string name = "integer_set_si"; symbolic_dependencies.push_back(name); @@ -1406,23 +1412,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(*val) && ASR::is_a(*ASRUtils::expr_type(val))) { ASR::symbol_t *v = ASR::down_cast(val)->m_v; if (symbolic_vars_to_free.find(v) == symbolic_vars_to_free.end()) return; - ASR::symbol_t* basic_str_sym = declare_basic_str_function(al, x.base.base.loc, module_scope); - - // Extract the symbol from value (Var) - ASR::symbol_t* var_sym = ASR::down_cast(val)->m_v; - ASR::expr_t* target = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - - // Now create the FunctionCall node for basic_str - Vec call_args; - call_args.reserve(al, 1); - ASR::call_arg_t call_arg; - call_arg.loc = x.base.base.loc; - call_arg.m_value = target; - call_args.push_back(al, call_arg); - ASR::expr_t* function_call = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, x.base.base.loc, - basic_str_sym, basic_str_sym, call_args.p, call_args.n, - ASRUtils::TYPE(ASR::make_Character_t(al, x.base.base.loc, 1, -2, nullptr)), nullptr, nullptr)); - print_tmp.push_back(function_call); + print_tmp.push_back(basic_str(x.base.base.loc, val)); } else if (ASR::is_a(*val)) { ASR::IntrinsicScalarFunction_t* intrinsic_func = ASR::down_cast(val); if (ASR::is_a(*ASRUtils::expr_type(val))) { @@ -1444,17 +1434,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor call_args; - call_args.reserve(al, 1); - ASR::call_arg_t call_arg; - call_arg.loc = x.base.base.loc; - call_arg.m_value = target; - call_args.push_back(al, call_arg); - ASR::expr_t* function_call = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, x.base.base.loc, - basic_str_sym, basic_str_sym, call_args.p, call_args.n, - ASRUtils::TYPE(ASR::make_Character_t(al, x.base.base.loc, 1, -2, nullptr)), nullptr, nullptr)); - print_tmp.push_back(function_call); + print_tmp.push_back(basic_str(x.base.base.loc, target)); } else if (ASR::is_a(*ASRUtils::expr_type(val))) { ASR::expr_t* function_call = process_attributes(al, x.base.base.loc, val, module_scope); print_tmp.push_back(function_call); @@ -1467,17 +1447,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor call_args; - call_args.reserve(al, 1); - ASR::call_arg_t call_arg; - call_arg.loc = x.base.base.loc; - call_arg.m_value = target; - call_args.push_back(al, call_arg); - ASR::expr_t* function_call = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, x.base.base.loc, - basic_str_sym, basic_str_sym, call_args.p, call_args.n, - ASRUtils::TYPE(ASR::make_Character_t(al, x.base.base.loc, 1, -2, nullptr)), nullptr, nullptr)); - print_tmp.push_back(function_call); + print_tmp.push_back(basic_str(x.base.base.loc, target)); } else if (ASR::is_a(*val)) { ASR::SymbolicCompare_t *s = ASR::down_cast(val); if (s->m_op == ASR::cmpopType::Eq || s->m_op == ASR::cmpopType::NotEq) { @@ -1507,20 +1477,10 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(*val)) { ASR::ListItem_t* list_item = ASR::down_cast(val); if (list_item->m_type->type == ASR::ttypeType::SymbolicExpression) { - ASR::ttype_t *CPtr_type = ASRUtils::TYPE(ASR::make_CPtr_t(al, x.base.base.loc)); - ASR::symbol_t* basic_str_sym = declare_basic_str_function(al, x.base.base.loc, module_scope); - - Vec call_args; - call_args.reserve(al, 1); - ASR::call_arg_t call_arg; - call_arg.loc = x.base.base.loc; - call_arg.m_value = ASRUtils::EXPR(ASR::make_ListItem_t(al, x.base.base.loc, list_item->m_a, - list_item->m_pos, CPtr_type, nullptr)); - call_args.push_back(al, call_arg); - ASR::expr_t* function_call = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, x.base.base.loc, - basic_str_sym, basic_str_sym, call_args.p, call_args.n, - ASRUtils::TYPE(ASR::make_Character_t(al, x.base.base.loc, 1, -2, nullptr)), nullptr, nullptr)); - print_tmp.push_back(function_call); + ASR::expr_t *value = ASRUtils::EXPR(ASR::make_ListItem_t(al, + x.base.base.loc, list_item->m_a, list_item->m_pos, + ASRUtils::TYPE(ASR::make_CPtr_t(al, x.base.base.loc)), nullptr)); + print_tmp.push_back(basic_str(x.base.base.loc, value)); } } else { print_tmp.push_back(x.m_values[i]); @@ -1623,8 +1583,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(*expr)) { var_sym = ASR::down_cast(expr)->m_v; @@ -1636,20 +1595,13 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(expr); this->visit_Cast(*cast_t); var_sym = current_scope->get_symbol(symengine_stack.pop()); + } else { + LCOMPILERS_ASSERT(false); } ASR::expr_t* target = ASRUtils::EXPR(ASR::make_Var_t(al, loc, var_sym)); - // Now create the FunctionCall node for basic_str - Vec call_args; - call_args.reserve(al, 1); - ASR::call_arg_t call_arg; - call_arg.loc = loc; - call_arg.m_value = target; - call_args.push_back(al, call_arg); - ASR::expr_t* function_call = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, loc, - basic_str_sym, basic_str_sym, call_args.p, call_args.n, - ASRUtils::TYPE(ASR::make_Character_t(al, loc, 1, -2, nullptr)), nullptr, nullptr)); - return function_call; + // Now create the FunctionCall node for basic_str and return + return basic_str(loc, target); } void visit_Assert(const ASR::Assert_t &x) { @@ -1703,16 +1655,15 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(*x.m_test)) { ASR::LogicalBinOp_t* binop = ASR::down_cast(x.m_test); if (ASR::is_a(*binop->m_left) && ASR::is_a(*binop->m_right)) { - ASR::symbol_t* basic_str_sym = declare_basic_str_function(al, x.base.base.loc, module_scope); ASR::SymbolicCompare_t *s1 = ASR::down_cast(binop->m_left); - left_tmp = process_with_basic_str(al, x.base.base.loc, s1->m_left, basic_str_sym); - right_tmp = process_with_basic_str(al, x.base.base.loc, s1->m_right, basic_str_sym); + left_tmp = process_with_basic_str(x.base.base.loc, s1->m_left); + right_tmp = process_with_basic_str(x.base.base.loc, s1->m_right); ASR::expr_t* test1 = ASRUtils::EXPR(ASR::make_StringCompare_t(al, x.base.base.loc, left_tmp, s1->m_op, right_tmp, s1->m_type, s1->m_value)); ASR::SymbolicCompare_t *s2 = ASR::down_cast(binop->m_right); - left_tmp = process_with_basic_str(al, x.base.base.loc, s2->m_left, basic_str_sym); - right_tmp = process_with_basic_str(al, x.base.base.loc, s2->m_right, basic_str_sym); + left_tmp = process_with_basic_str(x.base.base.loc, s2->m_left); + right_tmp = process_with_basic_str(x.base.base.loc, s2->m_right); ASR::expr_t* test2 = ASRUtils::EXPR(ASR::make_StringCompare_t(al, x.base.base.loc, left_tmp, s2->m_op, right_tmp, s2->m_type, s2->m_value)); From 3096c288bb32fef4cdd0afd042e065280db0cfa9 Mon Sep 17 00:00:00 2001 From: Thirumalai Shaktivel Date: Sat, 25 Nov 2023 14:22:10 +0530 Subject: [PATCH 10/21] [ASR Pass] Symbolic: Simplify `basic_get_type` to return `FunctionCall` --- src/libasr/pass/replace_symbolic.cpp | 137 ++++++++++----------------- 1 file changed, 48 insertions(+), 89 deletions(-) diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index f0778bce66..3410ce4cf9 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -399,6 +399,49 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorresolve_symbol(fn_name); + if ( !basic_get_type_sym ) { + std::string header = "symengine/cwrapper.h"; + SymbolTable* fn_symtab = al.make_new(current_scope->parent); + + Vec args; args.reserve(al, 1); + char *return_var_name =s2c(al, "_lpython_return_variable"); + ASR::symbol_t* arg1 = ASR::down_cast(ASR::make_Variable_t( + al, loc, fn_symtab, return_var_name, nullptr, 0, ASR::intentType::ReturnVar, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, false)); + fn_symtab->add_symbol(return_var_name, arg1); + ASR::symbol_t* arg2 = ASR::down_cast(ASR::make_Variable_t( + al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "x"), arg2); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg2))); + + Vec body; body.reserve(al, 1); + Vec dep; dep.reserve(al, 1); + ASR::expr_t* return_var = ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg1)); + basic_get_type_sym = ASR::down_cast( + ASRUtils::make_Function_t_util(al, loc, fn_symtab, s2c(al, fn_name), + dep.p, dep.n, args.p, args.n, body.p, body.n, return_var, + ASR::abiType::BindC, ASR::accessType::Public, + ASR::deftypeType::Interface, s2c(al, fn_name), false, false, false, + false, false, nullptr, 0, false, false, false, s2c(al, header))); + current_scope->parent->add_symbol(s2c(al, fn_name), basic_get_type_sym); + } + Vec call_args; call_args.reserve(al, 1); + ASR::call_arg_t call_arg; + call_arg.loc = loc; + call_arg.m_value = value; + call_args.push_back(al, call_arg); + return ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, loc, + basic_get_type_sym, basic_get_type_sym, call_args.p, call_args.n, + ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)), nullptr, nullptr)); + } /********************************** Utils *********************************/ void visit_Function(const ASR::Function_t &x) { @@ -910,45 +953,6 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorget_symbol(name); } - ASR::symbol_t* declare_basic_get_type_function(Allocator& al, const Location& loc, SymbolTable* module_scope) { - std::string name = "basic_get_type"; - symbolic_dependencies.push_back(name); - if (!module_scope->get_symbol(name)) { - std::string header = "symengine/cwrapper.h"; - SymbolTable* fn_symtab = al.make_new(module_scope); - - Vec args; - args.reserve(al, 1); - ASR::symbol_t* arg1 = ASR::down_cast(ASR::make_Variable_t( - al, loc, fn_symtab, s2c(al, "_lpython_return_variable"), nullptr, 0, ASR::intentType::ReturnVar, - nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)), - nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, false)); - fn_symtab->add_symbol(s2c(al, "_lpython_return_variable"), arg1); - ASR::symbol_t* arg2 = ASR::down_cast(ASR::make_Variable_t( - al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, - nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)), - nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); - fn_symtab->add_symbol(s2c(al, "x"), arg2); - args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg2))); - - Vec body; - body.reserve(al, 1); - - Vec dep; - dep.reserve(al, 1); - - ASR::expr_t* return_var = ASRUtils::EXPR(ASR::make_Var_t(al, loc, fn_symtab->get_symbol("_lpython_return_variable"))); - ASR::asr_t* subrout = ASRUtils::make_Function_t_util(al, loc, - fn_symtab, s2c(al, name), dep.p, dep.n, args.p, args.n, body.p, body.n, - return_var, ASR::abiType::BindC, ASR::accessType::Public, - ASR::deftypeType::Interface, s2c(al, name), false, false, false, - false, false, nullptr, 0, false, false, false, s2c(al, header)); - ASR::symbol_t* symbol = ASR::down_cast(subrout); - module_scope->add_symbol(s2c(al, name), symbol); - } - return module_scope->get_symbol(name); - } - ASR::symbol_t* declare_basic_eq_function(Allocator& al, const Location& loc, SymbolTable* module_scope) { std::string name = "basic_eq"; symbolic_dependencies.push_back(name); @@ -1106,17 +1110,8 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorm_args[0]); - Vec call_args; - call_args.reserve(al, 1); - ASR::call_arg_t call_arg; - call_arg.loc = loc; - call_arg.m_value = value1; - call_args.push_back(al, call_arg); - ASR::expr_t* function_call = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, loc, - basic_get_type_sym, basic_get_type_sym, call_args.p, call_args.n, - ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)), nullptr, nullptr)); + ASR::expr_t* function_call = basic_get_type(loc, value1); // Using 16 as the right value of the IntegerCompare node as it represents SYMENGINE_ADD through SYMENGINE_ENUM return ASRUtils::EXPR(ASR::make_IntegerCompare_t(al, loc, function_call, ASR::cmpopType::Eq, ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, loc, 16, ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)))), @@ -1124,17 +1119,8 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorm_args[0]); - Vec call_args; - call_args.reserve(al, 1); - ASR::call_arg_t call_arg; - call_arg.loc = loc; - call_arg.m_value = value1; - call_args.push_back(al, call_arg); - ASR::expr_t* function_call = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, loc, - basic_get_type_sym, basic_get_type_sym, call_args.p, call_args.n, - ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)), nullptr, nullptr)); + ASR::expr_t* function_call = basic_get_type(loc, value1); // Using 15 as the right value of the IntegerCompare node as it represents SYMENGINE_MUL through SYMENGINE_ENUM return ASRUtils::EXPR(ASR::make_IntegerCompare_t(al, loc, function_call, ASR::cmpopType::Eq, ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, loc, 15, ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)))), @@ -1142,17 +1128,8 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorm_args[0]); - Vec call_args; - call_args.reserve(al, 1); - ASR::call_arg_t call_arg; - call_arg.loc = loc; - call_arg.m_value = value1; - call_args.push_back(al, call_arg); - ASR::expr_t* function_call = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, loc, - basic_get_type_sym, basic_get_type_sym, call_args.p, call_args.n, - ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)), nullptr, nullptr)); + ASR::expr_t* function_call = basic_get_type(loc, value1); // Using 17 as the right value of the IntegerCompare node as it represents SYMENGINE_POW through SYMENGINE_ENUM return ASRUtils::EXPR(ASR::make_IntegerCompare_t(al, loc, function_call, ASR::cmpopType::Eq, ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, loc, 17, ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)))), @@ -1160,17 +1137,8 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorm_args[0]); - Vec call_args; - call_args.reserve(al, 1); - ASR::call_arg_t call_arg; - call_arg.loc = loc; - call_arg.m_value = value1; - call_args.push_back(al, call_arg); - ASR::expr_t* function_call = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, loc, - basic_get_type_sym, basic_get_type_sym, call_args.p, call_args.n, - ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)), nullptr, nullptr)); + ASR::expr_t* function_call = basic_get_type(loc, value1); // Using 29 as the right value of the IntegerCompare node as it represents SYMENGINE_LOG through SYMENGINE_ENUM return ASRUtils::EXPR(ASR::make_IntegerCompare_t(al, loc, function_call, ASR::cmpopType::Eq, ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, loc, 29, ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)))), @@ -1178,17 +1146,8 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorm_args[0]); - Vec call_args; - call_args.reserve(al, 1); - ASR::call_arg_t call_arg; - call_arg.loc = loc; - call_arg.m_value = value1; - call_args.push_back(al, call_arg); - ASR::expr_t* function_call = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, loc, - basic_get_type_sym, basic_get_type_sym, call_args.p, call_args.n, - ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)), nullptr, nullptr)); + ASR::expr_t* function_call = basic_get_type(loc, value1); // Using 35 as the right value of the IntegerCompare node as it represents SYMENGINE_SIN through SYMENGINE_ENUM return ASRUtils::EXPR(ASR::make_IntegerCompare_t(al, loc, function_call, ASR::cmpopType::Eq, ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, loc, 35, ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)))), From bb48bdbd410ead3fd6edb36b3d01e7b2866649a1 Mon Sep 17 00:00:00 2001 From: Thirumalai Shaktivel Date: Sat, 25 Nov 2023 14:38:35 +0530 Subject: [PATCH 11/21] [ASR Pass] Symbolic: Simplify `basic_eq` & `basic_neq` into `basic_compare` to return `FunctionCall` --- src/libasr/pass/replace_symbolic.cpp | 204 ++++++++------------------- 1 file changed, 60 insertions(+), 144 deletions(-) diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index 3410ce4cf9..bb3dbbaa3e 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -442,6 +442,57 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorresolve_symbol(fn_name); + if ( !basic_compare_sym ) { + std::string header = "symengine/cwrapper.h"; + SymbolTable* fn_symtab = al.make_new(current_scope->parent); + + Vec args; args.reserve(al, 1); + ASR::symbol_t* arg1 = ASR::down_cast(ASR::make_Variable_t( + al, loc, fn_symtab, s2c(al, "_lpython_return_variable"), nullptr, 0, ASR::intentType::ReturnVar, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, false)); + fn_symtab->add_symbol(s2c(al, "_lpython_return_variable"), arg1); + ASR::symbol_t* arg2 = ASR::down_cast(ASR::make_Variable_t( + al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "x"), arg2); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg2))); + ASR::symbol_t* arg3 = ASR::down_cast(ASR::make_Variable_t( + al, loc, fn_symtab, s2c(al, "y"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "y"), arg3); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg3))); + + Vec body; body.reserve(al, 1); + Vec dep; dep.reserve(al, 1); + ASR::expr_t* return_var = ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg1)); + basic_compare_sym = ASR::down_cast( + ASRUtils::make_Function_t_util(al, loc, fn_symtab, s2c(al, fn_name), + dep.p, dep.n, args.p, args.n, body.p, body.n, + return_var, ASR::abiType::BindC, ASR::accessType::Public, + ASR::deftypeType::Interface, s2c(al, fn_name), false, false, false, + false, false, nullptr, 0, false, false, false, s2c(al, header))); + current_scope->parent->add_symbol(s2c(al, fn_name), basic_compare_sym); + } + Vec call_args; + call_args.reserve(al, 1); + ASR::call_arg_t call_arg; + call_arg.loc = loc; + call_arg.m_value = handle_argument(al, loc, left); + call_args.push_back(al, call_arg); + call_arg.m_value = handle_argument(al, loc, right); + call_args.push_back(al, call_arg); + return ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, loc, + basic_compare_sym, basic_compare_sym, call_args.p, call_args.n, + ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)), nullptr, nullptr)); + } /********************************** Utils *********************************/ void visit_Function(const ASR::Function_t &x) { @@ -953,96 +1004,6 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorget_symbol(name); } - ASR::symbol_t* declare_basic_eq_function(Allocator& al, const Location& loc, SymbolTable* module_scope) { - std::string name = "basic_eq"; - symbolic_dependencies.push_back(name); - if (!module_scope->get_symbol(name)) { - std::string header = "symengine/cwrapper.h"; - SymbolTable* fn_symtab = al.make_new(module_scope); - - Vec args; - args.reserve(al, 1); - ASR::symbol_t* arg1 = ASR::down_cast(ASR::make_Variable_t( - al, loc, fn_symtab, s2c(al, "_lpython_return_variable"), nullptr, 0, ASR::intentType::ReturnVar, - nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)), - nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, false)); - fn_symtab->add_symbol(s2c(al, "_lpython_return_variable"), arg1); - ASR::symbol_t* arg2 = ASR::down_cast(ASR::make_Variable_t( - al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, - nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)), - nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); - fn_symtab->add_symbol(s2c(al, "x"), arg2); - args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg2))); - ASR::symbol_t* arg3 = ASR::down_cast(ASR::make_Variable_t( - al, loc, fn_symtab, s2c(al, "y"), nullptr, 0, ASR::intentType::In, - nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)), - nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); - fn_symtab->add_symbol(s2c(al, "y"), arg3); - args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg3))); - - Vec body; - body.reserve(al, 1); - - Vec dep; - dep.reserve(al, 1); - - ASR::expr_t* return_var = ASRUtils::EXPR(ASR::make_Var_t(al, loc, fn_symtab->get_symbol("_lpython_return_variable"))); - ASR::asr_t* subrout = ASRUtils::make_Function_t_util(al, loc, - fn_symtab, s2c(al, name), dep.p, dep.n, args.p, args.n, body.p, body.n, - return_var, ASR::abiType::BindC, ASR::accessType::Public, - ASR::deftypeType::Interface, s2c(al, name), false, false, false, - false, false, nullptr, 0, false, false, false, s2c(al, header)); - ASR::symbol_t* symbol = ASR::down_cast(subrout); - module_scope->add_symbol(s2c(al, name), symbol); - } - return module_scope->get_symbol(name); - } - - ASR::symbol_t* declare_basic_neq_function(Allocator& al, const Location& loc, SymbolTable* module_scope) { - std::string name = "basic_neq"; - symbolic_dependencies.push_back(name); - if (!module_scope->get_symbol(name)) { - std::string header = "symengine/cwrapper.h"; - SymbolTable* fn_symtab = al.make_new(module_scope); - - Vec args; - args.reserve(al, 1); - ASR::symbol_t* arg1 = ASR::down_cast(ASR::make_Variable_t( - al, loc, fn_symtab, s2c(al, "_lpython_return_variable"), nullptr, 0, ASR::intentType::ReturnVar, - nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)), - nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, false)); - fn_symtab->add_symbol(s2c(al, "_lpython_return_variable"), arg1); - ASR::symbol_t* arg2 = ASR::down_cast(ASR::make_Variable_t( - al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, - nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)), - nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); - fn_symtab->add_symbol(s2c(al, "x"), arg2); - args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg2))); - ASR::symbol_t* arg3 = ASR::down_cast(ASR::make_Variable_t( - al, loc, fn_symtab, s2c(al, "y"), nullptr, 0, ASR::intentType::In, - nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)), - nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); - fn_symtab->add_symbol(s2c(al, "y"), arg3); - args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg3))); - - Vec body; - body.reserve(al, 1); - - Vec dep; - dep.reserve(al, 1); - - ASR::expr_t* return_var = ASRUtils::EXPR(ASR::make_Var_t(al, loc, fn_symtab->get_symbol("_lpython_return_variable"))); - ASR::asr_t* subrout = ASRUtils::make_Function_t_util(al, loc, - fn_symtab, s2c(al, name), dep.p, dep.n, args.p, args.n, body.p, body.n, - return_var, ASR::abiType::BindC, ASR::accessType::Public, - ASR::deftypeType::Interface, s2c(al, name), false, false, false, - false, false, nullptr, 0, false, false, false, s2c(al, header)); - ASR::symbol_t* symbol = ASR::down_cast(subrout); - module_scope->add_symbol(s2c(al, name), symbol); - } - return module_scope->get_symbol(name); - } - ASR::expr_t* process_attributes(Allocator &al, const Location &loc, ASR::expr_t* expr, SymbolTable* module_scope) { if (ASR::is_a(*expr)) { @@ -1272,27 +1233,12 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(*x.m_value)) { ASR::SymbolicCompare_t *s = ASR::down_cast(x.m_value); if (s->m_op == ASR::cmpopType::Eq || s->m_op == ASR::cmpopType::NotEq) { - ASR::symbol_t* sym = nullptr; + ASR::expr_t* function_call = nullptr; if (s->m_op == ASR::cmpopType::Eq) { - sym = declare_basic_eq_function(al, x.base.base.loc, module_scope); + function_call = basic_compare(x.base.base.loc, "basic_eq", s->m_left, s->m_right); } else { - sym = declare_basic_neq_function(al, x.base.base.loc, module_scope); + function_call = basic_compare(x.base.base.loc, "basic_neq", s->m_left, s->m_right); } - ASR::expr_t* value1 = handle_argument(al, x.base.base.loc, s->m_left); - ASR::expr_t* value2 = handle_argument(al, x.base.base.loc, s->m_right); - - Vec call_args; - call_args.reserve(al, 1); - ASR::call_arg_t call_arg1, call_arg2; - call_arg1.loc = x.base.base.loc; - call_arg1.m_value = value1; - call_args.push_back(al, call_arg1); - call_arg2.loc = x.base.base.loc; - call_arg2.m_value = value2; - call_args.push_back(al, call_arg2); - - ASR::expr_t* function_call = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, x.base.base.loc, - sym, sym, call_args.p, call_args.n, ASRUtils::TYPE(ASR::make_Logical_t(al, x.base.base.loc, 4)), nullptr, nullptr)); ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_Assignment_t(al, x.base.base.loc, x.m_target, function_call, nullptr)); pass_result.push_back(al, stmt); } @@ -1410,27 +1356,12 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(*val)) { ASR::SymbolicCompare_t *s = ASR::down_cast(val); if (s->m_op == ASR::cmpopType::Eq || s->m_op == ASR::cmpopType::NotEq) { - ASR::symbol_t* sym = nullptr; + ASR::expr_t* function_call = nullptr; if (s->m_op == ASR::cmpopType::Eq) { - sym = declare_basic_eq_function(al, x.base.base.loc, module_scope); + function_call = basic_compare(x.base.base.loc, "basic_eq", s->m_left, s->m_right); } else { - sym = declare_basic_neq_function(al, x.base.base.loc, module_scope); + function_call = basic_compare(x.base.base.loc, "basic_neq", s->m_left, s->m_right); } - ASR::expr_t* value1 = handle_argument(al, x.base.base.loc, s->m_left); - ASR::expr_t* value2 = handle_argument(al, x.base.base.loc, s->m_right); - - Vec call_args; - call_args.reserve(al, 1); - ASR::call_arg_t call_arg1, call_arg2; - call_arg1.loc = x.base.base.loc; - call_arg1.m_value = value1; - call_args.push_back(al, call_arg1); - call_arg2.loc = x.base.base.loc; - call_arg2.m_value = value2; - call_args.push_back(al, call_arg2); - - ASR::expr_t* function_call = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, x.base.base.loc, - sym, sym, call_args.p, call_args.n, ASRUtils::TYPE(ASR::make_Logical_t(al, x.base.base.loc, 4)), nullptr, nullptr)); print_tmp.push_back(function_call); } } else if (ASR::is_a(*val)) { @@ -1580,27 +1511,12 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(*x.m_test)) { ASR::SymbolicCompare_t* s = ASR::down_cast(x.m_test); if (s->m_op == ASR::cmpopType::Eq || s->m_op == ASR::cmpopType::NotEq) { - ASR::symbol_t* sym = nullptr; + ASR::expr_t* function_call = nullptr; if (s->m_op == ASR::cmpopType::Eq) { - sym = declare_basic_eq_function(al, x.base.base.loc, module_scope); + function_call = basic_compare(x.base.base.loc, "basic_eq", s->m_left, s->m_right); } else { - sym = declare_basic_neq_function(al, x.base.base.loc, module_scope); + function_call = basic_compare(x.base.base.loc, "basic_neq", s->m_left, s->m_right); } - ASR::expr_t* value1 = handle_argument(al, x.base.base.loc, s->m_left); - ASR::expr_t* value2 = handle_argument(al, x.base.base.loc, s->m_right); - Vec call_args; - call_args.reserve(al, 1); - ASR::call_arg_t call_arg1, call_arg2; - call_arg1.loc = x.base.base.loc; - call_arg1.m_value = value1; - call_arg2.loc = x.base.base.loc; - call_arg2.m_value = value2; - call_args.push_back(al, call_arg1); - call_args.push_back(al, call_arg2); - ASR::expr_t* function_call = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, x.base.base.loc, - sym, sym, call_args.p, call_args.n, - ASRUtils::TYPE(ASR::make_Logical_t(al, x.base.base.loc, 4)), nullptr, nullptr)); - ASR::stmt_t *assert_stmt = ASRUtils::STMT(ASR::make_Assert_t(al, x.base.base.loc, function_call, x.m_msg)); pass_result.push_back(al, assert_stmt); } From e8724c15c21a711849359aba42e9457119be3717 Mon Sep 17 00:00:00 2001 From: Thirumalai Shaktivel Date: Sat, 25 Nov 2023 14:50:38 +0530 Subject: [PATCH 12/21] [ASR Pass] Symbolic: Simplify `integer_set_si` to return `SubroutineCall` --- src/libasr/pass/replace_symbolic.cpp | 131 ++++++++++----------------- 1 file changed, 50 insertions(+), 81 deletions(-) diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index bb3dbbaa3e..ba0e84040f 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -493,6 +493,53 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorresolve_symbol(fn_name); + if ( !integer_set_si_sym ) { + std::string header = "symengine/cwrapper.h"; + SymbolTable* fn_symtab = al.make_new(current_scope->parent); + + Vec args; args.reserve(al, 2); + ASR::symbol_t* arg1 = ASR::down_cast(ASR::make_Variable_t( + al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "x"), arg1); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg1))); + ASR::symbol_t* arg2 = ASR::down_cast(ASR::make_Variable_t( + al, loc, fn_symtab, s2c(al, "y"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 8)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "y"), arg2); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg2))); + + Vec body; body.reserve(al, 1); + Vec dep; dep.reserve(al, 1); + integer_set_si_sym = ASR::down_cast( + ASRUtils::make_Function_t_util(al, loc, fn_symtab, s2c(al, fn_name), + dep.p, dep.n, args.p, args.n, body.p, body.n, + nullptr, ASR::abiType::BindC, ASR::accessType::Public, + ASR::deftypeType::Interface, s2c(al, fn_name), false, false, false, + false, false, nullptr, 0, false, false, false, s2c(al, header))); + current_scope->parent->add_symbol(s2c(al, fn_name), integer_set_si_sym); + } + + Vec call_args; + call_args.reserve(al, 2); + ASR::call_arg_t call_arg; + call_arg.loc = loc; + call_arg.m_value = target; + call_args.push_back(al, call_arg); + call_arg.m_value = value; + call_args.push_back(al, call_arg); + + return ASRUtils::STMT(ASR::make_SubroutineCall_t(al, loc, integer_set_si_sym, + integer_set_si_sym, call_args.p, call_args.n, nullptr)); + } /********************************** Utils *********************************/ void visit_Function(const ASR::Function_t &x) { @@ -965,45 +1012,6 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorget_symbol(name)) { - std::string header = "symengine/cwrapper.h"; - SymbolTable* fn_symtab = al.make_new(module_scope); - - Vec args; - args.reserve(al, 2); - ASR::symbol_t* arg1 = ASR::down_cast(ASR::make_Variable_t( - al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, - nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)), - nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); - fn_symtab->add_symbol(s2c(al, "x"), arg1); - args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg1))); - ASR::symbol_t* arg2 = ASR::down_cast(ASR::make_Variable_t( - al, loc, fn_symtab, s2c(al, "y"), nullptr, 0, ASR::intentType::In, - nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 8)), - nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); - fn_symtab->add_symbol(s2c(al, "y"), arg2); - args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg2))); - - Vec body; - body.reserve(al, 1); - - Vec dep; - dep.reserve(al, 1); - - ASR::asr_t* subrout = ASRUtils::make_Function_t_util(al, loc, - fn_symtab, s2c(al, name), dep.p, dep.n, args.p, args.n, body.p, body.n, - nullptr, ASR::abiType::BindC, ASR::accessType::Public, - ASR::deftypeType::Interface, s2c(al, name), false, false, false, - false, false, nullptr, 0, false, false, false, s2c(al, header)); - ASR::symbol_t* symbol = ASR::down_cast(subrout); - module_scope->add_symbol(s2c(al, name), symbol); - } - return module_scope->get_symbol(name); - } - ASR::expr_t* process_attributes(Allocator &al, const Location &loc, ASR::expr_t* expr, SymbolTable* module_scope) { if (ASR::is_a(*expr)) { @@ -1148,22 +1156,10 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorm_arg; ASR::expr_t* cast_value = cast_t->m_value; if (ASR::is_a(*cast_arg)) { - ASR::symbol_t* integer_set_sym = declare_integer_set_si_function(al, x.base.base.loc, module_scope); ASR::ttype_t* cast_type = ASRUtils::TYPE(ASR::make_Integer_t(al, x.base.base.loc, 8)); ASR::expr_t* value = ASRUtils::EXPR(ASR::make_Cast_t(al, x.base.base.loc, cast_arg, (ASR::cast_kindType)ASR::cast_kindType::IntegerToInteger, cast_type, nullptr)); - Vec call_args; - call_args.reserve(al, 2); - ASR::call_arg_t call_arg1, call_arg2; - call_arg1.loc = x.base.base.loc; - call_arg1.m_value = x.m_target; - call_arg2.loc = x.base.base.loc; - call_arg2.m_value = value; - call_args.push_back(al, call_arg1); - call_args.push_back(al, call_arg2); - ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, x.base.base.loc, integer_set_sym, - integer_set_sym, call_args.p, call_args.n, nullptr)); - pass_result.push_back(al, stmt); + pass_result.push_back(al, integer_set_si(x.base.base.loc, x.m_target, value)); } else if (ASR::is_a(*cast_value)) { ASR::IntrinsicScalarFunction_t* intrinsic_func = ASR::down_cast(cast_value); int64_t intrinsic_id = intrinsic_func->m_intrinsic_id; @@ -1180,24 +1176,11 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorm_n; } - ASR::symbol_t* integer_set_sym = declare_integer_set_si_function(al, x.base.base.loc, module_scope); ASR::ttype_t* cast_type = ASRUtils::TYPE(ASR::make_Integer_t(al, x.base.base.loc, 8)); ASR::expr_t* value = ASRUtils::EXPR(ASR::make_Cast_t(al, x.base.base.loc, cast_arg, (ASR::cast_kindType)ASR::cast_kindType::IntegerToInteger, cast_type, ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, x.base.base.loc, const_value, cast_type)))); - Vec call_args; - call_args.reserve(al, 2); - ASR::call_arg_t call_arg1, call_arg2; - call_arg1.loc = x.base.base.loc; - call_arg1.m_value = x.m_target; - call_arg2.loc = x.base.base.loc; - call_arg2.m_value = value; - call_args.push_back(al, call_arg1); - call_args.push_back(al, call_arg2); - - ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, x.base.base.loc, integer_set_sym, - integer_set_sym, call_args.p, call_args.n, nullptr)); - pass_result.push_back(al, stmt); + pass_result.push_back(al, integer_set_si(x.base.base.loc, x.m_target, value)); } } } @@ -1416,7 +1399,6 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorparent; ASR::ttype_t *type = ASRUtils::TYPE(ASR::make_SymbolicExpression_t(al, x.base.base.loc)); std::string symengine_var = symengine_stack.push(); @@ -1451,24 +1433,11 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorm_n; } - ASR::symbol_t* integer_set_sym = declare_integer_set_si_function(al, x.base.base.loc, module_scope); ASR::ttype_t* cast_type = ASRUtils::TYPE(ASR::make_Integer_t(al, x.base.base.loc, 8)); ASR::expr_t* value = ASRUtils::EXPR(ASR::make_Cast_t(al, x.base.base.loc, cast_arg, (ASR::cast_kindType)ASR::cast_kindType::IntegerToInteger, cast_type, ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, x.base.base.loc, const_value, cast_type)))); - Vec call_args; - call_args.reserve(al, 2); - ASR::call_arg_t call_arg1, call_arg2; - call_arg1.loc = x.base.base.loc; - call_arg1.m_value = target; - call_arg2.loc = x.base.base.loc; - call_arg2.m_value = value; - call_args.push_back(al, call_arg1); - call_args.push_back(al, call_arg2); - - ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, x.base.base.loc, integer_set_sym, - integer_set_sym, call_args.p, call_args.n, nullptr)); - pass_result.push_back(al, stmt); + pass_result.push_back(al, integer_set_si(x.base.base.loc, target, value)); } } } From 274545101a857e9e3015cde4eb5671079c3141af Mon Sep 17 00:00:00 2001 From: Thirumalai Shaktivel Date: Sat, 25 Nov 2023 16:25:04 +0530 Subject: [PATCH 13/21] [ASR Pass] Symbolic: Simplify `symbol_set` to return `SubroutineCall` --- src/libasr/pass/replace_symbolic.cpp | 97 ++++++++++++++-------------- 1 file changed, 47 insertions(+), 50 deletions(-) diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index ba0e84040f..09725d7e0d 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -540,6 +540,52 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorresolve_symbol(fn_name); + if ( !symbol_set_sym ) { + std::string header = "symengine/cwrapper.h"; + SymbolTable* fn_symtab = al.make_new(current_scope->parent); + + Vec args; args.reserve(al, 1); + ASR::symbol_t* arg1 = ASR::down_cast(ASR::make_Variable_t( + al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "x"), arg1); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg1))); + ASR::symbol_t* arg2 = ASR::down_cast(ASR::make_Variable_t( + al, loc, fn_symtab, s2c(al, "s"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, + ASRUtils::TYPE(ASR::make_Character_t(al, loc, 1, -2, nullptr)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "s"), arg2); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg2))); + + Vec body; body.reserve(al, 1); + Vec dep; dep.reserve(al, 1); + symbol_set_sym = ASR::down_cast(ASRUtils::make_Function_t_util(al, loc, + fn_symtab, s2c(al, fn_name), dep.p, dep.n, args.p, args.n, body.p, body.n, + nullptr, ASR::abiType::BindC, ASR::accessType::Public, + ASR::deftypeType::Interface, s2c(al, fn_name), false, false, false, + false, false, nullptr, 0, false, false, false, s2c(al, header))); + current_scope->parent->add_symbol(s2c(al, fn_name), symbol_set_sym); + } + + Vec call_args; + call_args.reserve(al, 2); + ASR::call_arg_t call_arg; + call_arg.loc = loc; + call_arg.m_value = target; + call_args.push_back(al, call_arg); + call_arg.m_value = value; + call_args.push_back(al, call_arg); + + return ASRUtils::STMT(ASR::make_SubroutineCall_t(al, loc, symbol_set_sym, + symbol_set_sym, call_args.p, call_args.n, nullptr)); + } /********************************** Utils *********************************/ void visit_Function(const ASR::Function_t &x) { @@ -859,56 +905,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorm_intrinsic_id; switch (static_cast(intrinsic_id)) { case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicSymbol: { - std::string new_name = "symbol_set"; - symbolic_dependencies.push_back(new_name); - if (!module_scope->get_symbol(new_name)) { - std::string header = "symengine/cwrapper.h"; - SymbolTable* fn_symtab = al.make_new(module_scope); - - Vec args; - args.reserve(al, 1); - ASR::symbol_t* arg1 = ASR::down_cast(ASR::make_Variable_t( - al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, - nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)), - nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); - fn_symtab->add_symbol(s2c(al, "x"), arg1); - args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg1))); - ASR::symbol_t* arg2 = ASR::down_cast(ASR::make_Variable_t( - al, loc, fn_symtab, s2c(al, "s"), nullptr, 0, ASR::intentType::In, - nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_Character_t(al, loc, 1, -2, nullptr)), - nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); - fn_symtab->add_symbol(s2c(al, "s"), arg2); - args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg2))); - - Vec body; - body.reserve(al, 1); - - Vec dep; - dep.reserve(al, 1); - - ASR::asr_t* new_subrout = ASRUtils::make_Function_t_util(al, loc, - fn_symtab, s2c(al, new_name), dep.p, dep.n, args.p, args.n, body.p, body.n, - nullptr, ASR::abiType::BindC, ASR::accessType::Public, - ASR::deftypeType::Interface, s2c(al, new_name), false, false, false, - false, false, nullptr, 0, false, false, false, s2c(al, header)); - ASR::symbol_t* new_symbol = ASR::down_cast(new_subrout); - module_scope->add_symbol(s2c(al, new_name), new_symbol); - } - - ASR::symbol_t* symbol_set_sym = module_scope->get_symbol(new_name); - Vec call_args; - call_args.reserve(al, 2); - ASR::call_arg_t call_arg1, call_arg2; - call_arg1.loc = loc; - call_arg1.m_value = target; - call_arg2.loc = loc; - call_arg2.m_value = x->m_args[0]; - call_args.push_back(al, call_arg1); - call_args.push_back(al, call_arg2); - - ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, loc, symbol_set_sym, - symbol_set_sym, call_args.p, call_args.n, nullptr)); - pass_result.push_back(al, stmt); + pass_result.push_back(al, symbol_set(loc, target, x->m_args[0])); break; } case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicPi: { From 9440150ce9b1e16343e8bf668fdeac3cbd5248a9 Mon Sep 17 00:00:00 2001 From: Thirumalai Shaktivel Date: Sat, 25 Nov 2023 16:39:45 +0530 Subject: [PATCH 14/21] [ASR Pass] Symbolic: Simplify `basic_const` to return `SubroutineCall` --- src/libasr/pass/replace_symbolic.cpp | 91 ++++++++++++---------------- 1 file changed, 40 insertions(+), 51 deletions(-) diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index 09725d7e0d..7b3af33587 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -586,6 +586,44 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorresolve_symbol(fn_name); + if ( !basic_const_sym ) { + std::string header = "symengine/cwrapper.h"; + SymbolTable* fn_symtab = al.make_new(current_scope->parent); + + Vec args; args.reserve(al, 1); + ASR::symbol_t* arg = ASR::down_cast(ASR::make_Variable_t( + al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "x"), arg); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg))); + + Vec body; body.reserve(al, 1); + Vec dep; dep.reserve(al, 1); + basic_const_sym = ASR::down_cast( + ASRUtils::make_Function_t_util(al, loc, fn_symtab, s2c(al, fn_name), + dep.p, dep.n, args.p, args.n, body.p, body.n, + nullptr, ASR::abiType::BindC, ASR::accessType::Public, + ASR::deftypeType::Interface, s2c(al, fn_name), false, false, false, + false, false, nullptr, 0, false, false, false, s2c(al, header))); + current_scope->parent->add_symbol(s2c(al, fn_name), basic_const_sym); + } + + Vec call_args; + call_args.reserve(al, 1); + ASR::call_arg_t call_arg; + call_arg.loc = loc; + call_arg.m_value = value; + call_args.push_back(al, call_arg); + + return ASRUtils::STMT(ASR::make_SubroutineCall_t(al, loc, + basic_const_sym, basic_const_sym, call_args.p, call_args.n, nullptr)); + } /********************************** Utils *********************************/ void visit_Function(const ASR::Function_t &x) { @@ -824,50 +862,6 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorget_symbol(new_name)) { - std::string header = "symengine/cwrapper.h"; - SymbolTable* fn_symtab = al.make_new(module_scope); - - Vec args; - args.reserve(al, 1); - ASR::symbol_t* arg = ASR::down_cast(ASR::make_Variable_t( - al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, - nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)), - nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); - fn_symtab->add_symbol(s2c(al, "x"), arg); - args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg))); - - Vec body; - body.reserve(al, 1); - - Vec dep; - dep.reserve(al, 1); - - ASR::asr_t* new_subrout = ASRUtils::make_Function_t_util(al, loc, - fn_symtab, s2c(al, new_name), dep.p, dep.n, args.p, args.n, body.p, body.n, - nullptr, ASR::abiType::BindC, ASR::accessType::Public, - ASR::deftypeType::Interface, s2c(al, new_name), false, false, false, - false, false, nullptr, 0, false, false, false, s2c(al, header)); - ASR::symbol_t* new_symbol = ASR::down_cast(new_subrout); - module_scope->add_symbol(s2c(al, new_name), new_symbol); - } - - ASR::symbol_t* func_sym = module_scope->get_symbol(new_name); - Vec call_args; - call_args.reserve(al, 1); - ASR::call_arg_t call_arg; - call_arg.loc = loc; - call_arg.m_value = value; - call_args.push_back(al, call_arg); - - ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, loc, func_sym, - func_sym, call_args.p, call_args.n, nullptr)); - pass_result.push_back(al, stmt); - } - ASR::expr_t* handle_argument(Allocator &al, const Location &loc, ASR::expr_t* arg) { if (ASR::is_a(*arg)) { return arg; @@ -895,11 +889,6 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorm_intrinsic_id; @@ -909,11 +898,11 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor Date: Sat, 25 Nov 2023 16:51:31 +0530 Subject: [PATCH 15/21] [ASR Pass] Symbolic: Simplify `basic_binop` to return `SubroutineCall` --- src/libasr/pass/replace_symbolic.cpp | 141 +++++++++++++-------------- 1 file changed, 66 insertions(+), 75 deletions(-) diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index 7b3af33587..4e5e094e6e 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -624,6 +624,60 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorresolve_symbol(fn_name); + if ( !basic_binop_sym ) { + std::string header = "symengine/cwrapper.h"; + SymbolTable* fn_symtab = al.make_new(current_scope->parent); + + Vec args; args.reserve(al, 3); + ASR::symbol_t* arg1 = ASR::down_cast(ASR::make_Variable_t( + al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "x"), arg1); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg1))); + ASR::symbol_t* arg2 = ASR::down_cast(ASR::make_Variable_t( + al, loc, fn_symtab, s2c(al, "y"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "y"), arg2); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg2))); + ASR::symbol_t* arg3 = ASR::down_cast(ASR::make_Variable_t( + al, loc, fn_symtab, s2c(al, "z"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "z"), arg3); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg3))); + + Vec body; body.reserve(al, 1); + Vec dep; dep.reserve(al, 1); + basic_binop_sym = ASR::down_cast( + ASRUtils::make_Function_t_util(al, loc, fn_symtab, s2c(al, fn_name), + dep.p, dep.n, args.p, args.n, body.p, body.n, + nullptr, ASR::abiType::BindC, ASR::accessType::Public, + ASR::deftypeType::Interface, s2c(al, fn_name), false, false, false, + false, false, nullptr, 0, false, false, false, s2c(al, header))); + current_scope->parent->add_symbol(s2c(al, fn_name), basic_binop_sym); + } + + Vec call_args; + call_args.reserve(al, 3); + ASR::call_arg_t call_arg; + call_arg.loc = loc; + call_arg.m_value = target; + call_args.push_back(al, call_arg); + call_arg.m_value = handle_argument(al, loc, op_01); + call_args.push_back(al, call_arg); + call_arg.m_value = handle_argument(al, loc, op_02); + call_args.push_back(al, call_arg); + + return ASRUtils::STMT(ASR::make_SubroutineCall_t(al, loc, basic_binop_sym, + basic_binop_sym, call_args.p, call_args.n, nullptr)); + } /********************************** Utils *********************************/ void visit_Function(const ASR::Function_t &x) { @@ -747,68 +801,6 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorget_symbol(new_name)) { - std::string header = "symengine/cwrapper.h"; - SymbolTable* fn_symtab = al.make_new(module_scope); - - Vec args; - args.reserve(al, 3); - ASR::symbol_t* arg1 = ASR::down_cast(ASR::make_Variable_t( - al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, - nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)), - nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); - fn_symtab->add_symbol(s2c(al, "x"), arg1); - args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg1))); - ASR::symbol_t* arg2 = ASR::down_cast(ASR::make_Variable_t( - al, loc, fn_symtab, s2c(al, "y"), nullptr, 0, ASR::intentType::In, - nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)), - nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); - fn_symtab->add_symbol(s2c(al, "y"), arg2); - args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg2))); - ASR::symbol_t* arg3 = ASR::down_cast(ASR::make_Variable_t( - al, loc, fn_symtab, s2c(al, "z"), nullptr, 0, ASR::intentType::In, - nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)), - nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); - fn_symtab->add_symbol(s2c(al, "z"), arg3); - args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg3))); - - Vec body; - body.reserve(al, 1); - - Vec dep; - dep.reserve(al, 1); - - ASR::asr_t* new_subrout = ASRUtils::make_Function_t_util(al, loc, - fn_symtab, s2c(al, new_name), dep.p, dep.n, args.p, args.n, body.p, body.n, - nullptr, ASR::abiType::BindC, ASR::accessType::Public, - ASR::deftypeType::Interface, s2c(al, new_name), false, false, false, - false, false, nullptr, 0, false, false, false, s2c(al, header)); - ASR::symbol_t* new_symbol = ASR::down_cast(new_subrout); - module_scope->add_symbol(s2c(al, new_name), new_symbol); - } - - ASR::symbol_t* func_sym = module_scope->get_symbol(new_name); - Vec call_args; - call_args.reserve(al, 3); - ASR::call_arg_t call_arg1, call_arg2, call_arg3; - call_arg1.loc = loc; - call_arg1.m_value = value1; - call_arg2.loc = loc; - call_arg2.m_value = value2; - call_arg3.loc = loc; - call_arg3.m_value = value3; - call_args.push_back(al, call_arg1); - call_args.push_back(al, call_arg2); - call_args.push_back(al, call_arg3); - - ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, loc, func_sym, - func_sym, call_args.p, call_args.n, nullptr)); - pass_result.push_back(al, stmt); - } - void perform_symbolic_unary_operation(Allocator &al, const Location &loc, SymbolTable* module_scope, const std::string& new_name, ASR::expr_t* value1, ASR::expr_t* value2) { symbolic_dependencies.push_back(new_name); @@ -876,13 +868,6 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorm_args[0]); - ASR::expr_t* value2 = handle_argument(al, loc, x->m_args[1]); - perform_symbolic_binary_operation(al, loc, module_scope, new_name, target, value1, value2); - } - void process_unary_operator(Allocator &al, const Location &loc, ASR::IntrinsicScalarFunction_t* x, SymbolTable* module_scope, const std::string& new_name, ASR::expr_t* target) { ASR::expr_t* value1 = handle_argument(al, loc, x->m_args[0]); @@ -906,27 +891,33 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorm_args[0], x->m_args[1])); break; } case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicSub: { - process_binary_operator(al, loc, x, module_scope, "basic_sub", target); + pass_result.push_back(al, basic_binop(loc, "basic_sub", target, + x->m_args[0], x->m_args[1])); break; } case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicMul: { - process_binary_operator(al, loc, x, module_scope, "basic_mul", target); + pass_result.push_back(al, basic_binop(loc, "basic_mul", target, + x->m_args[0], x->m_args[1])); break; } case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicDiv: { - process_binary_operator(al, loc, x, module_scope, "basic_div", target); + pass_result.push_back(al, basic_binop(loc, "basic_div", target, + x->m_args[0], x->m_args[1])); break; } case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicPow: { - process_binary_operator(al, loc, x, module_scope, "basic_pow", target); + pass_result.push_back(al, basic_binop(loc, "basic_pow", target, + x->m_args[0], x->m_args[1])); break; } case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicDiff: { - process_binary_operator(al, loc, x, module_scope, "basic_diff", target); + pass_result.push_back(al, basic_binop(loc, "basic_diff", target, + x->m_args[0], x->m_args[1])); break; } case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicSin: { From 79066eef1ae2f7fc0ffbe305e5363893e2a7ea06 Mon Sep 17 00:00:00 2001 From: Thirumalai Shaktivel Date: Sat, 25 Nov 2023 17:00:05 +0530 Subject: [PATCH 16/21] [ASR Pass] Symbolic: Simplify `basic_unaryop` to return `SubroutineCall` --- src/libasr/pass/replace_symbolic.cpp | 122 +++++++++++++-------------- 1 file changed, 57 insertions(+), 65 deletions(-) diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index 4e5e094e6e..98d78ed6d0 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -678,6 +678,51 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorresolve_symbol(fn_name); + if ( !basic_unaryop_sym ) { + std::string header = "symengine/cwrapper.h"; + SymbolTable* fn_symtab = al.make_new(current_scope->parent); + + Vec args; args.reserve(al, 2); + ASR::symbol_t* arg1 = ASR::down_cast(ASR::make_Variable_t( + al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "x"), arg1); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg1))); + ASR::symbol_t* arg2 = ASR::down_cast(ASR::make_Variable_t( + al, loc, fn_symtab, s2c(al, "y"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "y"), arg2); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg2))); + + Vec body; body.reserve(al, 1); + Vec dep; dep.reserve(al, 1); + basic_unaryop_sym = ASR::down_cast(ASRUtils::make_Function_t_util(al, loc, + fn_symtab, s2c(al, fn_name), dep.p, dep.n, args.p, args.n, body.p, body.n, + nullptr, ASR::abiType::BindC, ASR::accessType::Public, + ASR::deftypeType::Interface, s2c(al, fn_name), false, false, false, + false, false, nullptr, 0, false, false, false, s2c(al, header))); + current_scope->parent->add_symbol(s2c(al, fn_name), basic_unaryop_sym); + } + + Vec call_args; + call_args.reserve(al, 2); + ASR::call_arg_t call_arg; + call_arg.loc = loc; + call_arg.m_value = target; + call_args.push_back(al, call_arg); + call_arg.m_value = handle_argument(al, loc, op_01); + call_args.push_back(al, call_arg); + + return ASRUtils::STMT(ASR::make_SubroutineCall_t(al, loc, basic_unaryop_sym, + basic_unaryop_sym, call_args.p, call_args.n, nullptr)); + } /********************************** Utils *********************************/ void visit_Function(const ASR::Function_t &x) { @@ -801,59 +846,6 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorget_symbol(new_name)) { - std::string header = "symengine/cwrapper.h"; - SymbolTable* fn_symtab = al.make_new(module_scope); - - Vec args; - args.reserve(al, 2); - ASR::symbol_t* arg1 = ASR::down_cast(ASR::make_Variable_t( - al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, - nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)), - nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); - fn_symtab->add_symbol(s2c(al, "x"), arg1); - args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg1))); - ASR::symbol_t* arg2 = ASR::down_cast(ASR::make_Variable_t( - al, loc, fn_symtab, s2c(al, "y"), nullptr, 0, ASR::intentType::In, - nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)), - nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); - fn_symtab->add_symbol(s2c(al, "y"), arg2); - args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg2))); - - Vec body; - body.reserve(al, 1); - - Vec dep; - dep.reserve(al, 1); - - ASR::asr_t* new_subrout = ASRUtils::make_Function_t_util(al, loc, - fn_symtab, s2c(al, new_name), dep.p, dep.n, args.p, args.n, body.p, body.n, - nullptr, ASR::abiType::BindC, ASR::accessType::Public, - ASR::deftypeType::Interface, s2c(al, new_name), false, false, false, - false, false, nullptr, 0, false, false, false, s2c(al, header)); - ASR::symbol_t* new_symbol = ASR::down_cast(new_subrout); - module_scope->add_symbol(s2c(al, new_name), new_symbol); - } - - ASR::symbol_t* func_sym = module_scope->get_symbol(new_name); - Vec call_args; - call_args.reserve(al, 2); - ASR::call_arg_t call_arg1, call_arg2; - call_arg1.loc = loc; - call_arg1.m_value = value1; - call_arg2.loc = loc; - call_arg2.m_value = value2; - call_args.push_back(al, call_arg1); - call_args.push_back(al, call_arg2); - - ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, loc, func_sym, - func_sym, call_args.p, call_args.n, nullptr)); - pass_result.push_back(al, stmt); - } - ASR::expr_t* handle_argument(Allocator &al, const Location &loc, ASR::expr_t* arg) { if (ASR::is_a(*arg)) { return arg; @@ -868,12 +860,6 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorm_args[0]); - perform_symbolic_unary_operation(al, loc, module_scope, new_name, target, value1); - } - void process_intrinsic_function(Allocator &al, const Location &loc, ASR::IntrinsicScalarFunction_t* x, SymbolTable* module_scope, ASR::expr_t* target){ int64_t intrinsic_id = x->m_intrinsic_id; @@ -921,27 +907,33 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorm_args[0])); break; } case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicCos: { - process_unary_operator(al, loc, x, module_scope, "basic_cos", target); + pass_result.push_back(al, basic_unaryop(loc, "basic_cos", target, + x->m_args[0])); break; } case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicLog: { - process_unary_operator(al, loc, x, module_scope, "basic_log", target); + pass_result.push_back(al, basic_unaryop(loc, "basic_log", target, + x->m_args[0])); break; } case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicExp: { - process_unary_operator(al, loc, x, module_scope, "basic_exp", target); + pass_result.push_back(al, basic_unaryop(loc, "basic_exp", target, + x->m_args[0])); break; } case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicAbs: { - process_unary_operator(al, loc, x, module_scope, "basic_abs", target); + pass_result.push_back(al, basic_unaryop(loc, "basic_abs", target, + x->m_args[0])); break; } case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicExpand: { - process_unary_operator(al, loc, x, module_scope, "basic_expand", target); + pass_result.push_back(al, basic_unaryop(loc, "basic_expand", target, + x->m_args[0])); break; } case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicGetArgument: { From d331f27b640f921748480363ae6a9e0190bc44f4 Mon Sep 17 00:00:00 2001 From: Thirumalai Shaktivel Date: Sat, 25 Nov 2023 17:00:35 +0530 Subject: [PATCH 17/21] [ASR Pass] Symbolic: Simplify `process_intrinsic_function` arguments --- src/libasr/pass/replace_symbolic.cpp | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index 98d78ed6d0..55a35fae33 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -860,8 +860,8 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorm_intrinsic_id; switch (static_cast(intrinsic_id)) { case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicSymbol: { @@ -1113,7 +1113,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(*x.m_value)) { ASR::IntrinsicScalarFunction_t* intrinsic_func = ASR::down_cast(x.m_value); if (intrinsic_func->m_type->type == ASR::ttypeType::SymbolicExpression) { - process_intrinsic_function(al, x.base.base.loc, intrinsic_func, module_scope, x.m_target); + process_intrinsic_function(x.base.base.loc, intrinsic_func, x.m_target); } else if (intrinsic_func->m_type->type == ASR::ttypeType::Logical) { ASR::expr_t* function_call = process_attributes(al, x.base.base.loc, x.m_value, module_scope); ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_Assignment_t(al, x.base.base.loc, x.m_target, function_call, nullptr)); @@ -1212,7 +1212,6 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorparent; Vec call_args; call_args.reserve(al, 1); @@ -1235,7 +1234,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitortype == ASR::ttypeType::SymbolicExpression) { - SymbolTable* module_scope = current_scope->parent; - ASR::ttype_t *type = ASRUtils::TYPE(ASR::make_SymbolicExpression_t(al, x.base.base.loc)); std::string symengine_var = symengine_stack.push(); ASR::symbol_t *arg = ASR::down_cast(ASR::make_Variable_t( @@ -1362,7 +1359,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(x); ASR::expr_t* target = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, arg)); - process_intrinsic_function(al, x.base.base.loc, &xx, module_scope, target); + process_intrinsic_function(x.base.base.loc, &xx, target); } } From f18ae18ec29e327b80afe37dc3e9e79b7014801b Mon Sep 17 00:00:00 2001 From: Thirumalai Shaktivel Date: Sat, 25 Nov 2023 17:31:19 +0530 Subject: [PATCH 18/21] [ASR Pass] Symbolic: Simplify `process_intrinsic_function` to use macros --- src/libasr/pass/replace_symbolic.cpp | 100 +++++++++------------------ 1 file changed, 32 insertions(+), 68 deletions(-) diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index 55a35fae33..64e5c0cfda 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -50,6 +50,24 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorm_args[0], x->m_args[1])); \ + break; } + + #define BASIC_UNARYOP(SYM, name) \ + case LCompilers::ASRUtils::IntrinsicScalarFunctions::Symbolic##SYM: { \ + pass_result.push_back(al, basic_unaryop(loc, "basic_"#name, \ + target, x->m_args[0])); \ + break; } + ASR::stmt_t *basic_new_stack(const Location &loc, ASR::expr_t *x) { std::string fn_name = "basic_new_stack"; symbolic_dependencies.push_back(fn_name); @@ -868,74 +886,20 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorm_args[0])); break; } - case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicPi: { - pass_result.push_back(al, basic_const(loc, "basic_const_pi", target)); - break; - } - case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicE: { - pass_result.push_back(al, basic_const(loc, "basic_const_E", target)); - break; - } - case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicAdd: { - pass_result.push_back(al, basic_binop(loc, "basic_add", target, - x->m_args[0], x->m_args[1])); - break; - } - case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicSub: { - pass_result.push_back(al, basic_binop(loc, "basic_sub", target, - x->m_args[0], x->m_args[1])); - break; - } - case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicMul: { - pass_result.push_back(al, basic_binop(loc, "basic_mul", target, - x->m_args[0], x->m_args[1])); - break; - } - case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicDiv: { - pass_result.push_back(al, basic_binop(loc, "basic_div", target, - x->m_args[0], x->m_args[1])); - break; - } - case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicPow: { - pass_result.push_back(al, basic_binop(loc, "basic_pow", target, - x->m_args[0], x->m_args[1])); - break; - } - case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicDiff: { - pass_result.push_back(al, basic_binop(loc, "basic_diff", target, - x->m_args[0], x->m_args[1])); - break; - } - case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicSin: { - pass_result.push_back(al, basic_unaryop(loc, "basic_sin", target, - x->m_args[0])); - break; - } - case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicCos: { - pass_result.push_back(al, basic_unaryop(loc, "basic_cos", target, - x->m_args[0])); - break; - } - case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicLog: { - pass_result.push_back(al, basic_unaryop(loc, "basic_log", target, - x->m_args[0])); - break; - } - case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicExp: { - pass_result.push_back(al, basic_unaryop(loc, "basic_exp", target, - x->m_args[0])); - break; - } - case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicAbs: { - pass_result.push_back(al, basic_unaryop(loc, "basic_abs", target, - x->m_args[0])); - break; - } - case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicExpand: { - pass_result.push_back(al, basic_unaryop(loc, "basic_expand", target, - x->m_args[0])); - break; - } + BASIC_CONST(Pi, pi) + BASIC_CONST(E, E) + BASIC_BINOP(Add, add) + BASIC_BINOP(Sub, sub) + BASIC_BINOP(Mul, mul) + BASIC_BINOP(Div, div) + BASIC_BINOP(Pow, pow) + BASIC_BINOP(Diff, diff) + BASIC_UNARYOP(Sin, sin) + BASIC_UNARYOP(Cos, cos) + BASIC_UNARYOP(Log, log) + BASIC_UNARYOP(Exp, exp) + BASIC_UNARYOP(Abs, abs) + BASIC_UNARYOP(Expand, expand) case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicGetArgument: { // Define necessary function symbols ASR::expr_t* value1 = handle_argument(al, loc, x->m_args[0]); From 760380a3916746ed3edb63e02dca29789aa82caf Mon Sep 17 00:00:00 2001 From: Thirumalai Shaktivel Date: Sat, 25 Nov 2023 17:50:21 +0530 Subject: [PATCH 19/21] [ASR Pass] Symbolic: Simplify `process_attributes` to use macros --- src/libasr/pass/intrinsic_function_registry.h | 8 +-- src/libasr/pass/replace_symbolic.cpp | 60 +++++-------------- 2 files changed, 18 insertions(+), 50 deletions(-) diff --git a/src/libasr/pass/intrinsic_function_registry.h b/src/libasr/pass/intrinsic_function_registry.h index cfb4f5cdfb..2fca01651d 100644 --- a/src/libasr/pass/intrinsic_function_registry.h +++ b/src/libasr/pass/intrinsic_function_registry.h @@ -251,11 +251,11 @@ class ASRBuilder { false, nullptr, 0, false, false, false)); // Types ------------------------------------------------------------------- - #define int32 TYPE(ASR::make_Integer_t(al, loc, 4)) + #define int32 ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)) #define int64 TYPE(ASR::make_Integer_t(al, loc, 8)) #define real32 TYPE(ASR::make_Real_t(al, loc, 4)) #define real64 TYPE(ASR::make_Real_t(al, loc, 8)) - #define logical TYPE(ASR::make_Logical_t(al, loc, 4)) + #define logical ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)) #define character(x) TYPE(ASR::make_Character_t(al, loc, 1, x, nullptr)) #define List(x) TYPE(ASR::make_List_t(al, loc, x)) @@ -285,7 +285,7 @@ class ASRBuilder { // Expressions ------------------------------------------------------------- #define i(x, t) EXPR(ASR::make_IntegerConstant_t(al, loc, x, t)) - #define i32(x) EXPR(ASR::make_IntegerConstant_t(al, loc, x, int32)) + #define i32(x) ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, loc, x, int32)) #define i32_n(x) EXPR(ASR::make_IntegerUnaryMinus_t(al, loc, i32(abs(x)), \ int32, i32(x))) #define i32_neg(x, t) EXPR(ASR::make_IntegerUnaryMinus_t(al, loc, x, t, nullptr)) @@ -414,7 +414,7 @@ class ASRBuilder { } // Compare ----------------------------------------------------------------- - #define iEq(x, y) EXPR(ASR::make_IntegerCompare_t(al, loc, x, \ + #define iEq(x, y) ASRUtils::EXPR(ASR::make_IntegerCompare_t(al, loc, x, \ ASR::cmpopType::Eq, y, logical, nullptr)) #define iNotEq(x, y) EXPR(ASR::make_IntegerCompare_t(al, loc, x, \ ASR::cmpopType::NotEq, y, logical, nullptr)) diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index 64e5c0cfda..273a84328a 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -68,6 +68,12 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorm_args[0])); \ break; } + #define BASIC_ATTR(SYM, N) \ + case LCompilers::ASRUtils::IntrinsicScalarFunctions::Symbolic##SYM: { \ + ASR::expr_t* function_call = basic_get_type(loc, \ + intrinsic_func->m_args[0]); \ + return iEq(function_call, i32(N)); } + ASR::stmt_t *basic_new_stack(const Location &loc, ASR::expr_t *x) { std::string fn_name = "basic_new_stack"; symbolic_dependencies.push_back(fn_name); @@ -454,7 +460,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor call_args; call_args.reserve(al, 1); ASR::call_arg_t call_arg; call_arg.loc = loc; - call_arg.m_value = value; + call_arg.m_value = handle_argument(al, loc, value); call_args.push_back(al, call_arg); return ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, loc, basic_get_type_sym, basic_get_type_sym, call_args.p, call_args.n, @@ -1011,51 +1017,13 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorm_args[0]); - ASR::expr_t* function_call = basic_get_type(loc, value1); - // Using 16 as the right value of the IntegerCompare node as it represents SYMENGINE_ADD through SYMENGINE_ENUM - return ASRUtils::EXPR(ASR::make_IntegerCompare_t(al, loc, function_call, ASR::cmpopType::Eq, - ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, loc, 16, ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)))), - ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)), nullptr)); - break; - } - case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicMulQ: { - ASR::expr_t* value1 = handle_argument(al, loc, intrinsic_func->m_args[0]); - ASR::expr_t* function_call = basic_get_type(loc, value1); - // Using 15 as the right value of the IntegerCompare node as it represents SYMENGINE_MUL through SYMENGINE_ENUM - return ASRUtils::EXPR(ASR::make_IntegerCompare_t(al, loc, function_call, ASR::cmpopType::Eq, - ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, loc, 15, ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)))), - ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)), nullptr)); - break; - } - case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicPowQ: { - ASR::expr_t* value1 = handle_argument(al, loc, intrinsic_func->m_args[0]); - ASR::expr_t* function_call = basic_get_type(loc, value1); - // Using 17 as the right value of the IntegerCompare node as it represents SYMENGINE_POW through SYMENGINE_ENUM - return ASRUtils::EXPR(ASR::make_IntegerCompare_t(al, loc, function_call, ASR::cmpopType::Eq, - ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, loc, 17, ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)))), - ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)), nullptr)); - break; - } - case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicLogQ: { - ASR::expr_t* value1 = handle_argument(al, loc, intrinsic_func->m_args[0]); - ASR::expr_t* function_call = basic_get_type(loc, value1); - // Using 29 as the right value of the IntegerCompare node as it represents SYMENGINE_LOG through SYMENGINE_ENUM - return ASRUtils::EXPR(ASR::make_IntegerCompare_t(al, loc, function_call, ASR::cmpopType::Eq, - ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, loc, 29, ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)))), - ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)), nullptr)); - break; - } - case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicSinQ: { - ASR::expr_t* value1 = handle_argument(al, loc, intrinsic_func->m_args[0]); - ASR::expr_t* function_call = basic_get_type(loc, value1); - // Using 35 as the right value of the IntegerCompare node as it represents SYMENGINE_SIN through SYMENGINE_ENUM - return ASRUtils::EXPR(ASR::make_IntegerCompare_t(al, loc, function_call, ASR::cmpopType::Eq, - ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, loc, 35, ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)))), - ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)), nullptr)); - break; - } + // (sym_name, n) where n = 16, 15, ... as the right value of the + // IntegerCompare node as it represents SYMENGINE_ADD through SYMENGINE_ENUM + BASIC_ATTR(AddQ, 16) + BASIC_ATTR(MulQ, 15) + BASIC_ATTR(PowQ, 17) + BASIC_ATTR(LogQ, 29) + BASIC_ATTR(SinQ, 35) default: { throw LCompilersException("IntrinsicFunction: `" + ASRUtils::get_intrinsic_name(intrinsic_id) From f6d0bd690eaf234995bb327bf5e5d0ad2d9fc640 Mon Sep 17 00:00:00 2001 From: Thirumalai Shaktivel Date: Sat, 25 Nov 2023 17:59:08 +0530 Subject: [PATCH 20/21] [ASR Pass] Symbolic: Simplify `basic_has_symbol` to return `FunctionCall` --- src/libasr/pass/replace_symbolic.cpp | 112 +++++++++++++-------------- 1 file changed, 54 insertions(+), 58 deletions(-) diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index 273a84328a..819320a427 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -747,6 +747,58 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorresolve_symbol(fn_name); + if ( !basic_has_symbol_sym ) { + std::string header = "symengine/cwrapper.h"; + SymbolTable* fn_symtab = al.make_new(current_scope->parent); + + Vec args; args.reserve(al, 1); + ASR::symbol_t* arg1 = ASR::down_cast(ASR::make_Variable_t( + al, loc, fn_symtab, s2c(al, "_lpython_return_variable"), nullptr, 0, ASR::intentType::ReturnVar, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, false)); + fn_symtab->add_symbol(s2c(al, "_lpython_return_variable"), arg1); + ASR::symbol_t* arg2 = ASR::down_cast(ASR::make_Variable_t( + al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "x"), arg2); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg2))); + ASR::symbol_t* arg3 = ASR::down_cast(ASR::make_Variable_t( + al, loc, fn_symtab, s2c(al, "y"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "y"), arg3); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg3))); + + Vec body; body.reserve(al, 1); + Vec dep; dep.reserve(al, 1); + ASR::expr_t* return_var = ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg1)); + basic_has_symbol_sym = ASR::down_cast(ASRUtils::make_Function_t_util(al, loc, + fn_symtab, s2c(al, fn_name), dep.p, dep.n, args.p, args.n, body.p, body.n, + return_var, ASR::abiType::BindC, ASR::accessType::Public, + ASR::deftypeType::Interface, s2c(al, fn_name), false, false, false, + false, false, nullptr, 0, false, false, false, s2c(al, header))); + current_scope->parent->add_symbol(s2c(al, fn_name), basic_has_symbol_sym); + } + + Vec call_args; + call_args.reserve(al, 1); + ASR::call_arg_t call_arg1, call_arg2; + call_arg1.loc = loc; + call_arg1.m_value = handle_argument(al, loc, value_01); + call_args.push_back(al, call_arg1); + call_arg2.loc = loc; + call_arg2.m_value = handle_argument(al, loc, value_02); + call_args.push_back(al, call_arg2); + return ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, loc, + basic_has_symbol_sym, basic_has_symbol_sym, call_args.p, call_args.n, + ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)), nullptr, nullptr)); + } /********************************** Utils *********************************/ void visit_Function(const ASR::Function_t &x) { @@ -958,64 +1010,8 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorm_intrinsic_id; switch (static_cast(intrinsic_id)) { case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicHasSymbolQ: { - std::string name = "basic_has_symbol"; - symbolic_dependencies.push_back(name); - if (!module_scope->get_symbol(name)) { - std::string header = "symengine/cwrapper.h"; - SymbolTable* fn_symtab = al.make_new(module_scope); - - Vec args; - args.reserve(al, 1); - ASR::symbol_t* arg1 = ASR::down_cast(ASR::make_Variable_t( - al, loc, fn_symtab, s2c(al, "_lpython_return_variable"), nullptr, 0, ASR::intentType::ReturnVar, - nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)), - nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, false)); - fn_symtab->add_symbol(s2c(al, "_lpython_return_variable"), arg1); - ASR::symbol_t* arg2 = ASR::down_cast(ASR::make_Variable_t( - al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, - nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)), - nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); - fn_symtab->add_symbol(s2c(al, "x"), arg2); - args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg2))); - ASR::symbol_t* arg3 = ASR::down_cast(ASR::make_Variable_t( - al, loc, fn_symtab, s2c(al, "y"), nullptr, 0, ASR::intentType::In, - nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)), - nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); - fn_symtab->add_symbol(s2c(al, "y"), arg3); - args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg3))); - - Vec body; - body.reserve(al, 1); - - Vec dep; - dep.reserve(al, 1); - - ASR::expr_t* return_var = ASRUtils::EXPR(ASR::make_Var_t(al, loc, fn_symtab->get_symbol("_lpython_return_variable"))); - ASR::asr_t* subrout = ASRUtils::make_Function_t_util(al, loc, - fn_symtab, s2c(al, name), dep.p, dep.n, args.p, args.n, body.p, body.n, - return_var, ASR::abiType::BindC, ASR::accessType::Public, - ASR::deftypeType::Interface, s2c(al, name), false, false, false, - false, false, nullptr, 0, false, false, false, s2c(al, header)); - ASR::symbol_t* symbol = ASR::down_cast(subrout); - module_scope->add_symbol(s2c(al, name), symbol); - } - - ASR::symbol_t* basic_has_symbol = module_scope->get_symbol(name); - ASR::expr_t* value1 = handle_argument(al, loc, intrinsic_func->m_args[0]); - ASR::expr_t* value2 = handle_argument(al, loc, intrinsic_func->m_args[1]); - Vec call_args; - call_args.reserve(al, 1); - ASR::call_arg_t call_arg1, call_arg2; - call_arg1.loc = loc; - call_arg1.m_value = value1; - call_args.push_back(al, call_arg1); - call_arg2.loc = loc; - call_arg2.m_value = value2; - call_args.push_back(al, call_arg2); - return ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, loc, - basic_has_symbol, basic_has_symbol, call_args.p, call_args.n, - ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)), nullptr, nullptr)); - break; + return basic_has_symbol(loc, intrinsic_func->m_args[0], + intrinsic_func->m_args[1]); } // (sym_name, n) where n = 16, 15, ... as the right value of the // IntegerCompare node as it represents SYMENGINE_ADD through SYMENGINE_ENUM From 6a7b2cdeaad635e4a8080413f1e64ae7d9b06034 Mon Sep 17 00:00:00 2001 From: Thirumalai Shaktivel Date: Sat, 25 Nov 2023 18:01:06 +0530 Subject: [PATCH 21/21] [ASR Pass] Symbolic: Simplify `process_attributes` arguments --- src/libasr/pass/replace_symbolic.cpp | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index 819320a427..4acbad65fa 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -1003,8 +1003,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(*expr)) { ASR::IntrinsicScalarFunction_t* intrinsic_func = ASR::down_cast(expr); int64_t intrinsic_id = intrinsic_func->m_intrinsic_id; @@ -1031,7 +1030,6 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorparent; if (ASR::is_a(*x.m_value) && ASR::is_a(*ASRUtils::expr_type(x.m_value))) { ASR::symbol_t *v = ASR::down_cast(x.m_value)->m_v; if (symbolic_vars_to_free.find(v) == symbolic_vars_to_free.end()) return; @@ -1043,7 +1041,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorm_type->type == ASR::ttypeType::SymbolicExpression) { process_intrinsic_function(x.base.base.loc, intrinsic_func, x.m_target); } else if (intrinsic_func->m_type->type == ASR::ttypeType::Logical) { - ASR::expr_t* function_call = process_attributes(al, x.base.base.loc, x.m_value, module_scope); + ASR::expr_t* function_call = process_attributes(x.base.base.loc, x.m_value); ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_Assignment_t(al, x.base.base.loc, x.m_target, function_call, nullptr)); pass_result.push_back(al, stmt); } @@ -1129,11 +1127,10 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(x); transform_stmts(xx.m_body, xx.n_body); transform_stmts(xx.m_orelse, xx.n_orelse); - SymbolTable* module_scope = current_scope->parent; if (ASR::is_a(*xx.m_test)) { ASR::IntrinsicScalarFunction_t* intrinsic_func = ASR::down_cast(xx.m_test); if (intrinsic_func->m_type->type == ASR::ttypeType::Logical) { - ASR::expr_t* function_call = process_attributes(al, xx.base.base.loc, xx.m_test, module_scope); + ASR::expr_t* function_call = process_attributes(xx.base.base.loc, xx.m_test); xx.m_test = function_call; } } @@ -1190,7 +1187,6 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor print_tmp; - SymbolTable* module_scope = current_scope->parent; for (size_t i=0; i(*val) && ASR::is_a(*ASRUtils::expr_type(val))) { @@ -1220,7 +1216,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(*ASRUtils::expr_type(val))) { - ASR::expr_t* function_call = process_attributes(al, x.base.base.loc, val, module_scope); + ASR::expr_t* function_call = process_attributes(x.base.base.loc, val); print_tmp.push_back(function_call); } } else if (ASR::is_a(*val)) { @@ -1358,14 +1354,13 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorparent; ASR::expr_t* left_tmp = nullptr; ASR::expr_t* right_tmp = nullptr; if (ASR::is_a(*x.m_test)) { ASR::LogicalCompare_t *l = ASR::down_cast(x.m_test); - left_tmp = process_attributes(al, x.base.base.loc, l->m_left, module_scope); - right_tmp = process_attributes(al, x.base.base.loc, l->m_right, module_scope); + left_tmp = process_attributes(x.base.base.loc, l->m_left); + right_tmp = process_attributes(x.base.base.loc, l->m_right); ASR::expr_t* test = ASRUtils::EXPR(ASR::make_LogicalCompare_t(al, x.base.base.loc, left_tmp, l->m_op, right_tmp, l->m_type, l->m_value)); @@ -1386,7 +1381,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(*x.m_test)) { ASR::IntrinsicScalarFunction_t* intrinsic_func = ASR::down_cast(x.m_test); if (intrinsic_func->m_type->type == ASR::ttypeType::Logical) { - ASR::expr_t* test = process_attributes(al, x.base.base.loc, x.m_test, module_scope); + ASR::expr_t* test = process_attributes(x.base.base.loc, x.m_test); ASR::stmt_t *assert_stmt = ASRUtils::STMT(ASR::make_Assert_t(al, x.base.base.loc, test, x.m_msg)); pass_result.push_back(al, assert_stmt); }