Skip to content

Commit

Permalink
Fix Newton solver code generation (#696)
Browse files Browse the repository at this point in the history
* Fixes compilation issue when the operator() was declared as const while it changed some of values outside of its scope in the sympy solver

* Use DefUseAnalyzeVisitor to find out whether the operator() is Using or Defining any of the variables outside its scope

* Update DefUseAnalyzeVisitor to be able to process LOCAL variables as well, given precedence to GLOBAL/RANGE variables, when they have the same name

* Added related unit tests

* Fix codegen issue in CONSTRUCTOR and DESTRUCTOR blocks code when compiled for CoreNEURON

Fixes #691
FIxes #692
  • Loading branch information
iomaganaris authored Aug 27, 2021
1 parent e98f281 commit 1321c7b
Show file tree
Hide file tree
Showing 7 changed files with 262 additions and 27 deletions.
74 changes: 72 additions & 2 deletions src/codegen/codegen_c_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
#include "parser/c11_driver.hpp"
#include "utils/logger.hpp"
#include "utils/string_utils.hpp"
#include "visitors/defuse_analyze_visitor.hpp"
#include "visitors/rename_visitor.hpp"
#include "visitors/symtab_visitor.hpp"
#include "visitors/var_usage_visitor.hpp"
#include "visitors/visitor_utils.hpp"

Expand All @@ -32,7 +34,11 @@ namespace codegen {

using namespace ast;

using visitor::DefUseAnalyzeVisitor;
using visitor::DUChain;
using visitor::DUState;
using visitor::RenameVisitor;
using visitor::SymtabVisitor;
using visitor::VarUsageVisitor;

using symtab::syminfo::NmodlType;
Expand Down Expand Up @@ -1717,6 +1723,54 @@ std::string CodegenCVisitor::find_var_unique_name(const std::string& original_na
return unique_name;
}

/**
* @brief Checks whether the functor_block generated by sympy solver modifies any variable outside
* its scope. If it does then return false, so that the operator() of the struct functor of the
* Eigen Newton solver doesn't have const qualifier.
*
* @param variable_block Statement Block of the variables declarations used in the functor struct of
* the solver
* @param functor_block Actual code being printed in the operator() of the functor struct of the
* solver
* @return True if operator() is const else False
*/
bool is_functor_const(const ast::StatementBlock& variable_block,
const ast::StatementBlock& functor_block) {
// Save DUChain for every variable in variable_block
std::unordered_map<std::string, DUChain> chains;

// Create complete_block with both variable declarations (done in variable_block) and solver
// part (done in functor_block) to be able to run the SymtabVisitor and DefUseAnalyzeVisitor
// then and get the proper DUChains for the variables defined in the variable_block
ast::StatementBlock complete_block(functor_block);
// Typically variable_block has only one statement, a statement containing the declaration
// of the local variables
for (const auto& statement: variable_block.get_statements()) {
complete_block.insert_statement(complete_block.get_statements().begin(), statement);
}

// Create Symbol Table for complete_block
auto model_symbol_table = std::make_shared<symtab::ModelSymbolTable>();
SymtabVisitor(model_symbol_table.get()).visit_statement_block(complete_block);
// Initialize DefUseAnalyzeVisitor to generate the DUChains for the variables defined in the
// variable_block
DefUseAnalyzeVisitor v(*complete_block.get_symbol_table());

// Check the DUChains for all the variables in the variable_block
// If variable is defined in complete_block don't add const quilifier in operator()
auto is_functor_const = true;
const auto& variables = collect_nodes(variable_block, {ast::AstNodeType::LOCAL_VAR});
for (const auto& variable: variables) {
const auto& chain = v.analyze(complete_block, variable->get_node_name());
is_functor_const = !(chain.eval() == DUState::D || chain.eval() == DUState::LD ||
chain.eval() == DUState::CD);
if (!is_functor_const)
break;
}

return is_functor_const;
}

void CodegenCVisitor::visit_eigen_newton_solver_block(const ast::EigenNewtonSolverBlock& node) {
// solution vector to store copy of state vars for Newton solver
printer->add_newline();
Expand Down Expand Up @@ -1759,13 +1813,23 @@ void CodegenCVisitor::visit_eigen_newton_solver_block(const ast::EigenNewtonSolv
instance_struct(), "{}"));

printer->add_indent();

const auto& variable_block = *node.get_variable_block();
const auto& functor_block = *node.get_functor_block();

printer->add_text(
"void operator()(const Eigen::Matrix<{0}, {1}, 1>& {2}, Eigen::Matrix<{0}, {1}, "
"1>& {3}, "
"Eigen::Matrix<{0}, {1}, {1}>& {4}) const"_format(float_type, N, X, F, Jm));
"Eigen::Matrix<{0}, {1}, {1}>& {4}) {5}"_format(
float_type,
N,
X,
F,
Jm,
is_functor_const(variable_block, functor_block) ? "const " : ""));
printer->start_block();
printer->add_line("{}* {} = {}.data();"_format(float_type, J, Jm));
print_statement_block(*node.get_functor_block(), false, false);
print_statement_block(functor_block, false, false);
printer->end_block(2);

// assign newton solver results in matrix X to state vars
Expand Down Expand Up @@ -3274,6 +3338,10 @@ void CodegenCVisitor::print_global_function_common_code(BlockType type) {
// We do not (currently) support DESTRUCTOR and CONSTRUCTOR blocks
// running anything on the GPU.
print_kernel_data_present_annotation_block_begin();
} else {
/// TODO: Remove this when the code generation is propery done
/// Related to https://github.com/BlueBrain/nmodl/issues/692
printer->add_line("#ifndef CORENEURON_BUILD");
}
printer->add_line("int nodecount = ml->nodecount;");
printer->add_line("int pnodecount = ml->_nodecount_padded;");
Expand Down Expand Up @@ -3373,6 +3441,7 @@ void CodegenCVisitor::print_nrn_constructor() {
const auto& block = info.constructor_node->get_statement_block();
print_statement_block(*block.get(), false, false);
}
printer->add_line("#endif");
printer->end_block(1);
}

Expand All @@ -3384,6 +3453,7 @@ void CodegenCVisitor::print_nrn_destructor() {
const auto& block = info.destructor_node->get_statement_block();
print_statement_block(*block.get(), false, false);
}
printer->add_line("#endif");
printer->end_block(1);
}

Expand Down
6 changes: 5 additions & 1 deletion src/language/templates/visitors/symtab_visitor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ namespace visitor {
*/
class SymtabVisitor: public AstVisitor {
private:
symtab::ModelSymbolTable* modsymtab;
symtab::ModelSymbolTable* modsymtab = nullptr;

std::unique_ptr<printer::JSONPrinter> printer;
std::set<std::string> block_to_solve;
Expand All @@ -48,6 +48,10 @@ class SymtabVisitor: public AstVisitor {
: printer(new printer::JSONPrinter())
, update(update) {}

SymtabVisitor(symtab::ModelSymbolTable* _modsymtab, bool update = false)
: modsymtab(_modsymtab)
, update(update) {}

SymtabVisitor(std::ostream& os, bool update = false)
: printer(new printer::JSONPrinter(os))
, update(update) {}
Expand Down
2 changes: 1 addition & 1 deletion src/symtab/symbol_table.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ class SymbolTable {
*/
class ModelSymbolTable {
/// symbol table for mod file (always top level symbol table)
std::shared_ptr<SymbolTable> symtab = nullptr;
std::shared_ptr<SymbolTable> symtab;

/// current symbol table being constructed
SymbolTable* current_symtab = nullptr;
Expand Down
45 changes: 31 additions & 14 deletions src/visitors/defuse_analyze_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,14 @@ std::string DUChain::to_string(bool compact) const {
* As these are innermost blocks, we have to just check first use
* of variable in this block and that's the result of this block.
*/
DUState DUInstance::sub_block_eval() const {
DUState DUInstance::sub_block_eval(DUVariableType variable_type = DUVariableType::Global) const {
DUState result = DUState::NONE;
for (const auto& chain: children) {
const auto& child_state = chain.eval();
if (child_state == DUState::U || child_state == DUState::D) {
const auto& child_state = chain.eval(variable_type);
if ((variable_type == DUVariableType::Global &&
(child_state == DUState::U || child_state == DUState::D)) ||
(variable_type == DUVariableType::Local &&
(child_state == DUState::LU || child_state == DUState::LD))) {
result = child_state;
break;
}
Expand Down Expand Up @@ -127,23 +130,27 @@ DUState DUInstance::sub_block_eval() const {
* block encountered, this means every block has either "D" or "CD". In
* this case we can say that entire block effectively has "D".
*/
DUState DUInstance::conditional_block_eval() const {
DUState DUInstance::conditional_block_eval(
DUVariableType variable_type = DUVariableType::Global) const {
DUState result = DUState::NONE;
bool block_with_none = false;

for (const auto& chain: children) {
auto child_state = chain.eval();
if (child_state == DUState::U) {
auto child_state = chain.eval(variable_type);
if ((variable_type == DUVariableType::Global && child_state == DUState::U) ||
(variable_type == DUVariableType::Local && child_state == DUState::LU)) {
result = child_state;
break;
}
if (child_state == DUState::NONE) {
block_with_none = true;
}
if (child_state == DUState::D || child_state == DUState::CD) {
if ((variable_type == DUVariableType::Global && child_state == DUState::D) ||
(variable_type == DUVariableType::Local && child_state == DUState::LD) ||
child_state == DUState::CD) {
result = DUState::CD;
if (chain.state == DUState::ELSE && !block_with_none) {
result = DUState::D;
result = child_state;
break;
}
}
Expand All @@ -155,12 +162,12 @@ DUState DUInstance::conditional_block_eval() const {
* Note that we are interested in "global" variable usage
* and hence we consider only [U,D] states and not [LU, LD]
*/
DUState DUInstance::eval() const {
DUState DUInstance::eval(DUVariableType variable_type = DUVariableType::Global) const {
auto result = state;
if (state == DUState::IF || state == DUState::ELSEIF || state == DUState::ELSE) {
result = sub_block_eval();
result = sub_block_eval(variable_type);
} else if (state == DUState::CONDITIONAL_BLOCK) {
result = conditional_block_eval();
result = conditional_block_eval(variable_type);
}
return result;
}
Expand All @@ -170,8 +177,9 @@ DUState DUInstance::eval() const {
DUState DUChain::eval() const {
auto result = DUState::NONE;
for (auto& inst: chain) {
auto re = inst.eval();
if (re == DUState::U || re == DUState::D) {
auto re = inst.eval(variable_type);
if ((variable_type == DUVariableType::Global && (re == DUState::U || re == DUState::D)) ||
(variable_type == DUVariableType::Local && (re == DUState::LU || re == DUState::LD))) {
result = re;
break;
}
Expand Down Expand Up @@ -418,9 +426,18 @@ DUChain DefUseAnalyzeVisitor::analyze(const ast::Ast& node, const std::string& n
visiting_lhs = false;
current_symtab = global_symtab;
unsupported_node = false;
auto global_symbol = global_symtab->lookup_in_scope(variable_name);
// If global_symbol exists in the global_symtab then search for a global variable. Otherwise the
// variable can only be local if it exists
auto global_symbol_properties = NmodlType::global_var | NmodlType::range_var;
if (global_symbol != nullptr && global_symbol->has_any_property(global_symbol_properties)) {
variable_type = DUVariableType::Global;
} else {
variable_type = DUVariableType::Local;
}

/// new chain
DUChain usage(node.get_node_type_name());
DUChain usage(node.get_node_type_name(), variable_type);
current_chain = &usage.chain;

/// analyze given node
Expand Down
29 changes: 22 additions & 7 deletions src/visitors/defuse_analyze_visitor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@ enum class DUState {
U,
/// global variable is defined
D,
/// global variable is conditionally defined
/// global or local variable is conditionally defined
CD,
/// local variable is used
LU,
/// local variable is used
/// local variable is defined
LD,
/// state not known
UNKNOWN,
Expand All @@ -50,6 +50,14 @@ enum class DUState {
NONE
};

/**
* Variable type processed by DefUseAnalyzeVisitor
*
* DUVariableType::Local means that we are looking for LD, LU and CD DUStates, while Global means we
* are looking for U, D and CD DUStates.
*/
enum class DUVariableType { Local, Global };

std::ostream& operator<<(std::ostream& os, DUState state);

/**
Expand Down Expand Up @@ -88,13 +96,13 @@ class DUInstance {
, binary_expression(binary_expression) {}

/// analyze all children and return "effective" usage
DUState eval() const;
DUState eval(DUVariableType variable_type) const;

/// if, elseif and else evaluation
DUState sub_block_eval() const;
DUState sub_block_eval(DUVariableType variable_type) const;

/// evaluate global usage i.e. with [D,U] states of children
DUState conditional_block_eval() const;
DUState conditional_block_eval(DUVariableType variable_type) const;

void print(printer::JSONPrinter& printer) const;

Expand Down Expand Up @@ -125,12 +133,16 @@ class DUChain {
/// name of the node
std::string name;

/// type of variable
DUVariableType variable_type;

/// def-use chain for a variable
std::vector<DUInstance> chain;

DUChain() = default;
explicit DUChain(std::string name)
: name(std::move(name)) {}
DUChain(std::string name, DUVariableType type)
: name(std::move(name))
, variable_type(type) {}

/// return "effective" usage of a variable
DUState eval() const;
Expand Down Expand Up @@ -217,6 +229,9 @@ class DefUseAnalyzeVisitor: protected ConstAstVisitor {
/// variable for which to construct def-use chain
std::string variable_name;

/// variable type (Local or Global)
DUVariableType variable_type;

/// indicate that there is unsupported construct encountered
bool unsupported_node = false;

Expand Down
2 changes: 1 addition & 1 deletion test/unit/parser/parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ std::string solve_construct(const std::string& equation, std::string method) {
return solution;
}

SCENARIO("Legacy differential equation solver from NEURON solve number of ODE types") {
SCENARIO("Legacy differential equation solver") {
GIVEN("A differential equation") {
int counter = 0;
for (const auto& test_case: diff_eq_constructs) {
Expand Down
Loading

0 comments on commit 1321c7b

Please sign in to comment.