Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix reentrancy-2 detector #288

Merged
merged 3 commits into from
Nov 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
299 changes: 144 additions & 155 deletions detectors/reentrancy-2/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,26 +1,24 @@
#![feature(rustc_private)]

extern crate rustc_abi;
extern crate rustc_ast;
extern crate rustc_hir;
extern crate rustc_middle;
extern crate rustc_span;
extern crate rustc_target;
extern crate rustc_type_ir;

use std::collections::{HashMap, HashSet};

use clippy_wrappers::span_lint_and_help;
use if_chain::if_chain;
use rustc_ast::ast::LitKind;
use rustc_hir::def::Res;
use rustc_hir::intravisit::{walk_expr, FnKind};
use rustc_hir::intravisit::{walk_local, Visitor};
use rustc_hir::{Body, FnDecl, HirId, Local, PatKind, QPath};
use rustc_hir::{Expr, ExprKind};
use rustc_hir::{
def::Res,
intravisit::{walk_expr, walk_local, FnKind, Visitor},
Body, Expr, ExprKind, FnDecl, HirId, Local, PatKind, QPath,
};
use rustc_lint::{LateContext, LateLintPass};
use rustc_middle::ty::TyKind;
use rustc_span::def_id::LocalDefId;
use rustc_span::{Span, Symbol};
use rustc_span::{def_id::LocalDefId, Span, Symbol};
use rustc_target::abi::VariantIdx;

const LINT_MESSAGE:&str = "External calls could open the opportunity for a malicious contract to execute any arbitrary code";
Expand Down Expand Up @@ -105,190 +103,181 @@ const INSERT: &str = "insert";
const MAPPING: &str = "Mapping";
const ACCOUNT_ID: &str = "AccountId";
const U128: &str = "u128";
const CALL_FLAGS: &str = "call_flags";
const ALLOW_REENTRY: &str = "ALLOW_REENTRY";

impl<'tcx> LateLintPass<'tcx> for Reentrancy2 {
fn check_fn(
&mut self,
cx: &LateContext<'tcx>,
_: FnKind<'tcx>,
_: &'tcx FnDecl<'_>,
body: &'tcx Body<'_>,
_: Span,
_: LocalDefId,
) {
struct ReentrancyVisitor<'a, 'tcx: 'a> {
cx: &'a LateContext<'tcx>,
contracts_tainted_for_reentrancy: HashSet<Symbol>,
current_method_call: Option<Symbol>,
bool_var_values: HashMap<HirId, bool>,
reentrancy_spans: Vec<Span>,
should_look_for_insert: bool,
has_insert_operation: bool,
}
struct ReentrancyVisitor<'a, 'tcx> {
cx: &'a LateContext<'tcx>,
tainted_contracts: HashSet<Symbol>,
current_method: Option<Symbol>,
bool_values: HashMap<HirId, bool>,
reentrancy_spans: Vec<Span>,
looking_for_insert: bool,
found_insert: bool,
}

// This function is called whenever a contract is identified as potentially susceptible to reentrancy.
fn set_tainted_contract(visitor: &mut ReentrancyVisitor) {
if let Some(method_calls) = &visitor.current_method_call {
visitor
.contracts_tainted_for_reentrancy
.insert(*method_calls);
visitor.current_method_call = None;
}
impl<'a, 'tcx> ReentrancyVisitor<'a, 'tcx> {
fn mark_current_as_tainted(&mut self) {
if let Some(method) = self.current_method.take() {
self.tainted_contracts.insert(method);
}
}

fn handle_set_allow_reentry(visitor: &mut ReentrancyVisitor, args: &&[Expr<'_>]) {
match &args[0].kind {
ExprKind::Lit(lit) => {
// If the argument is a boolean literal and it's true, call set_tainted_contract
if let LitKind::Bool(value) = lit.node {
if value {
set_tainted_contract(visitor);
}
}
}
ExprKind::Path(qpath) => {
// If the argument is a local variable, check if it's a boolean and if it's true
if_chain! {
if let res = visitor.cx.qpath_res(qpath, args[0].hir_id);
if let Res::Local(_) = res;
if let QPath::Resolved(_, path) = qpath;
then {
for path_segment in path.segments {
// If the argument is a known boolean variable, check if it's true
if let Res::Local(hir_id) = path_segment.res {
if visitor.bool_var_values.get(&hir_id).map_or(true, |v| *v) {
set_tainted_contract(visitor);
}
}
fn handle_set_allow_reentry(&mut self, args: &[Expr<'_>]) {
let is_reentry_enabled = match &args[0].kind {
ExprKind::Lit(lit) => matches!(lit.node, LitKind::Bool(true)),
ExprKind::Path(qpath) => {
if_chain! {
if let res = self.cx.qpath_res(qpath, args[0].hir_id);
if let Res::Local(_) = res;
if let QPath::Resolved(_, path) = qpath;
then {
path.segments.iter().any(|segment| {
if let Res::Local(hir_id) = segment.res {
self.bool_values.get(&hir_id).copied().unwrap_or(true)
} else {
false
}
}
})
} else {
false
}
}
_ => (),
}
_ => false,
};

if is_reentry_enabled {
self.mark_current_as_tainted();
}
}

fn handle_invoke_contract(
visitor: &mut ReentrancyVisitor,
args: &&[Expr<'_>],
expr: &Expr<'_>,
) {
if_chain! {
if let ExprKind::AddrOf(_, _, invoke_expr) = &args[0].kind;
if let ExprKind::Path(qpath) = &invoke_expr.kind;
if let QPath::Resolved(_, path) = qpath;
then{
for path_segment in path.segments {
// If the argument is a tainted contract, add the span of this expression to the span vector
if visitor.contracts_tainted_for_reentrancy.contains(&path_segment.ident.name) {
visitor.should_look_for_insert = true;
visitor.reentrancy_spans.push(expr.span);
}
fn handle_invoke_contract(&mut self, args: &[Expr<'_>], expr: &Expr<'_>) {
if_chain! {
if let ExprKind::AddrOf(_, _, invoke_expr) = &args[0].kind;
if let ExprKind::Path(qpath) = &invoke_expr.kind;
if let QPath::Resolved(_, path) = qpath;
then {
for segment in path.segments.iter() {
if self.tainted_contracts.contains(&segment.ident.name) {
self.looking_for_insert = true;
self.reentrancy_spans.push(expr.span);
}
}
}
}
}

fn handle_insert(visitor: &mut ReentrancyVisitor, expr: &Expr<'_>) {
if_chain! {
if let ExprKind::MethodCall(_, expr1, _, _) = &expr.kind;
if let object_type = visitor.cx.typeck_results().expr_ty(expr1);
if let TyKind::Adt(adt_def, substs) = object_type.kind();
if let Some(variant) = adt_def.variants().get(VariantIdx::from_u32(0));
if variant.name.as_str() == MAPPING;
if let mut has_account_id = false;
if let mut has_u128 = false;
then{
substs.types().for_each(|inner_type| {
let str_inner_type = inner_type.to_string();
if str_inner_type.contains(ACCOUNT_ID) {
has_account_id = true;
} else if str_inner_type.contains(U128) {
has_u128 = true;
}
});
visitor.has_insert_operation = has_account_id && has_u128;
}
fn handle_call_flags(&mut self, args: &[Expr<'_>]) {
if_chain! {
if let ExprKind::Path(qpath) = &args[0].kind;
if let QPath::TypeRelative(_, segment) = qpath;
if segment.ident.name.as_str() == ALLOW_REENTRY;
then {
self.mark_current_as_tainted();
}
}
}

impl<'a, 'tcx> Visitor<'tcx> for ReentrancyVisitor<'a, 'tcx> {
fn visit_local(&mut self, local: &'tcx Local<'tcx>) {
if let Some(init) = &local.init {
if let PatKind::Binding(_, _, ident, _) = &local.pat.kind {
match &init.kind {
// Check if the variable being declared is a boolean, if so, add it to the bool_declarations hashmap
ExprKind::Lit(lit) => {
if let LitKind::Bool(value) = lit.node {
self.bool_var_values.insert(local.pat.hir_id, value);
}
}
ExprKind::MethodCall(_, _, _, _) => {
self.current_method_call = Some(ident.name);
}
// Check if the variable being declared is a boolean, if so, add it to the bool_declarations hashmap
ExprKind::Path(QPath::Resolved(_, path)) => {
if let Some(segment) = path.segments.last() {
if let Res::Local(hir_id) = segment.res {
if let Some(value) = self.bool_var_values.get(&hir_id) {
self.bool_var_values.insert(local.pat.hir_id, *value);
}
}
fn handle_insert(&mut self, expr: &Expr<'_>) {
if_chain! {
if let ExprKind::MethodCall(_, receiver, _, _) = &expr.kind;
if let object_type = self.cx.typeck_results().expr_ty(receiver);
if let TyKind::Adt(adt_def, substs) = object_type.kind();
if let Some(variant) = adt_def.variants().get(VariantIdx::from_u32(0));
if variant.name.as_str() == MAPPING;
then {
let mut has_account_id = false;
let mut has_u128 = false;

substs.types().for_each(|ty| {
let type_str = ty.to_string();
has_account_id |= type_str.contains(ACCOUNT_ID);
has_u128 |= type_str.contains(U128);
});

self.found_insert = has_account_id && has_u128;
}
}
}
}

impl<'a, 'tcx> Visitor<'tcx> for ReentrancyVisitor<'a, 'tcx> {
fn visit_local(&mut self, local: &'tcx Local<'tcx>) {
if let Some(init) = &local.init {
if let PatKind::Binding(_, _, ident, _) = &local.pat.kind {
match &init.kind {
ExprKind::Lit(lit) => {
if let LitKind::Bool(value) = lit.node {
self.bool_values.insert(local.pat.hir_id, value);
}
}
ExprKind::MethodCall(_, _, _, _) => {
self.current_method = Some(ident.name);
}
ExprKind::Path(QPath::Resolved(_, path)) => {
if let Some(segment) = path.segments.last() {
if let Res::Local(hir_id) = segment.res {
if let Some(&value) = self.bool_values.get(&hir_id) {
self.bool_values.insert(local.pat.hir_id, value);
}
}
_ => (),
}
}
walk_local(self, local);
_ => (),
}
}
}
walk_local(self, local);
}

// This method is called for every expression.
fn visit_expr(&mut self, expr: &'tcx Expr<'tcx>) {
if let ExprKind::MethodCall(func, _, args, _) = &expr.kind {
let function_name = func.ident.name.as_str();
match function_name {
// The function "set_allow_reentry" is being called
SET_ALLOW_REENTRY => handle_set_allow_reentry(self, args),
// The function "invoke_contract" is being called
INVOKE_CONTRACT => handle_invoke_contract(self, args, expr),
// The function "insert" is being called
INSERT => {
if self.should_look_for_insert {
handle_insert(self, expr)
}
}
_ => (),
}
}
walk_expr(self, expr)
fn visit_expr(&mut self, expr: &'tcx Expr<'tcx>) {
if let ExprKind::MethodCall(func, _, args, _) = &expr.kind {
match func.ident.name.as_str() {
SET_ALLOW_REENTRY => self.handle_set_allow_reentry(args),
CALL_FLAGS => self.handle_call_flags(args),
INVOKE_CONTRACT => self.handle_invoke_contract(args, expr),
INSERT if self.looking_for_insert => self.handle_insert(expr),
_ => (),
}
}
walk_expr(self, expr);
}
}

// The main function where we start the visitor to traverse the AST.
let mut reentrancy_visitor = ReentrancyVisitor {
impl<'tcx> LateLintPass<'tcx> for Reentrancy2 {
fn check_fn(
&mut self,
cx: &LateContext<'tcx>,
_: FnKind<'tcx>,
_: &'tcx FnDecl<'_>,
body: &'tcx Body<'_>,
_: Span,
_: LocalDefId,
) {
let mut visitor = ReentrancyVisitor {
cx,
contracts_tainted_for_reentrancy: HashSet::new(),
current_method_call: None,
bool_var_values: HashMap::new(),
tainted_contracts: HashSet::new(),
current_method: None,
bool_values: HashMap::new(),
reentrancy_spans: Vec::new(),
has_insert_operation: false,
should_look_for_insert: false,
looking_for_insert: false,
found_insert: false,
};
walk_expr(&mut reentrancy_visitor, body.value);
walk_expr(&mut visitor, body.value);

// Iterate over all potential reentrancy spans and emit a warning for each.
if reentrancy_visitor.has_insert_operation {
reentrancy_visitor.reentrancy_spans.into_iter().for_each(|span| {
clippy_wrappers::span_lint_and_help(
if visitor.found_insert {
for span in visitor.reentrancy_spans {
span_lint_and_help(
cx,
REENTRANCY_2,
span,
LINT_MESSAGE,
None,
"This statement seems to call another contract after the flag set_allow_reentry was enabled [todo: check state changes after this statement]"
"This statement seems to call another contract after the flag \
set_allow_reentry was enabled [todo: check state changes after this statement]",
);
})
}
}
}
}
8 changes: 5 additions & 3 deletions test-cases/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@ members = ["*/*/*-example"]
resolver = "2"

[workspace.dependencies]
getrandom = { version = "0.2" }
ink = { version = "5.0.0", default-features = false }
scale = { package = "parity-scale-codec", version = "3", default-features = false, features = ["derive"] }
scale-info = { version = "2.6", default-features = false, features = ["derive"] }
ink_e2e = { version = "=5.0.0" }
getrandom = { version = "0.2" }
scale = { package = "parity-scale-codec", version = "3", default-features = false, features = [
"derive",
] }
scale-info = { version = "2.6", default-features = false, features = ["derive"] }

[profile.release]
codegen-units = 1
Expand Down
Loading
Loading