diff --git a/src/artifacts/ast/mod.rs b/src/artifacts/ast/mod.rs index c3fc731b..ad863323 100644 --- a/src/artifacts/ast/mod.rs +++ b/src/artifacts/ast/mod.rs @@ -182,7 +182,7 @@ ast_node!( #[serde(default, deserialize_with = "serde_helpers::default_for_null")] used_events: Vec, #[serde(default, rename = "internalFunctionIDs")] - internal_function_ids: BTreeMap, + internal_function_ids: BTreeMap, } ); @@ -908,6 +908,7 @@ pub enum AssemblyReferenceSuffix { /// Inline assembly flags. #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub enum InlineAssemblyFlag { + #[serde(rename = "memory-safe")] MemorySafe, } diff --git a/src/artifacts/ast/visitor.rs b/src/artifacts/ast/visitor.rs index 8b137891..d2405190 100644 --- a/src/artifacts/ast/visitor.rs +++ b/src/artifacts/ast/visitor.rs @@ -1 +1,596 @@ +use super::*; +pub trait Visitor { + fn visit_source_unit(&mut self, _source_unit: &SourceUnit) {} + fn visit_import_directive(&mut self, _directive: &ImportDirective) {} + fn visit_pragma_directive(&mut self, _directive: &PragmaDirective) {} + fn visit_block(&mut self, _block: &Block) {} + fn visit_statement(&mut self, _statement: &Statement) {} + fn visit_expression(&mut self, _expression: &Expression) {} + fn visit_function_call(&mut self, _function_call: &FunctionCall) {} + fn visit_user_defined_type_name(&mut self, _type_name: &UserDefinedTypeName) {} + fn visit_identifier_path(&mut self, _identifier_path: &IdentifierPath) {} + fn visit_type_name(&mut self, _type_name: &TypeName) {} + fn visit_parameter_list(&mut self, _parameter_list: &ParameterList) {} + fn visit_function_definition(&mut self, _definition: &FunctionDefinition) {} + fn visit_enum_definition(&mut self, _definition: &EnumDefinition) {} + fn visit_error_definition(&mut self, _definition: &ErrorDefinition) {} + fn visit_event_definition(&mut self, _definition: &EventDefinition) {} + fn visit_struct_definition(&mut self, _definition: &StructDefinition) {} + fn visit_modifier_definition(&mut self, _definition: &ModifierDefinition) {} + fn visit_variable_declaration(&mut self, _declaration: &VariableDeclaration) {} + fn visit_overrides(&mut self, _specifier: &OverrideSpecifier) {} + fn visit_user_defined_value_type(&mut self, _value_type: &UserDefinedValueTypeDefinition) {} + fn visit_contract_definition(&mut self, _definition: &ContractDefinition) {} + fn visit_using_for(&mut self, _directive: &UsingForDirective) {} + fn visit_unary_operation(&mut self, _unary_op: &UnaryOperation) {} + fn visit_binary_operation(&mut self, _binary_op: &BinaryOperation) {} + fn visit_conditional(&mut self, _conditional: &Conditional) {} + fn visit_tuple_expression(&mut self, _tuple_expression: &TupleExpression) {} + fn visit_new_expression(&mut self, _new_expression: &NewExpression) {} + fn visit_assignment(&mut self, _assignment: &Assignment) {} + fn visit_identifier(&mut self, _identifier: &Identifier) {} + fn visit_index_access(&mut self, _index_access: &IndexAccess) {} + fn visit_index_range_access(&mut self, _index_range_access: &IndexRangeAccess) {} + fn visit_while_statement(&mut self, _while_statement: &WhileStatement) {} + fn visit_for_statement(&mut self, _for_statement: &ForStatement) {} + fn visit_if_statement(&mut self, _if_statement: &IfStatement) {} + fn visit_do_while_statement(&mut self, _do_while_statement: &DoWhileStatement) {} + fn visit_emit_statement(&mut self, _emit_statement: &EmitStatement) {} + fn visit_unchecked_block(&mut self, _unchecked_block: &UncheckedBlock) {} + fn visit_try_statement(&mut self, 
_try_statement: &TryStatement) {} + fn visit_revert_statement(&mut self, _revert_statement: &RevertStatement) {} + fn visit_member_access(&mut self, _member_access: &MemberAccess) {} + fn visit_mapping(&mut self, _mapping: &Mapping) {} + fn visit_elementary_type_name(&mut self, _elementary_type_name: &ElementaryTypeName) {} + fn visit_literal(&mut self, _literal: &Literal) {} + fn visit_function_type_name(&mut self, _function_type_name: &FunctionTypeName) {} + fn visit_array_type_name(&mut self, _array_type_name: &ArrayTypeName) {} + fn visit_function_call_options(&mut self, _function_call: &FunctionCallOptions) {} + fn visit_return(&mut self, _return: &Return) {} + fn visit_inheritance_specifier(&mut self, _specifier: &InheritanceSpecifier) {} + fn visit_modifier_invocation(&mut self, _invocation: &ModifierInvocation) {} +} + +pub trait Walk { + fn walk(&self, visitor: &mut dyn Visitor); +} + +macro_rules! impl_walk { + // Implement `Walk` for a type, calling the given function. + ($ty:ty, | $val:ident, $visitor:ident | $e:expr) => { + impl Walk for $ty { + fn walk(&self, visitor: &mut dyn Visitor) { + let $val = self; + let $visitor = visitor; + $e + } + } + }; + ($ty:ty, $func:ident) => { + impl_walk!($ty, |obj, visitor| { + visitor.$func(obj); + }); + }; + ($ty:ty, $func:ident, | $val:ident, $visitor:ident | $e:expr) => { + impl_walk!($ty, |$val, $visitor| { + $visitor.$func($val); + $e + }); + }; +} + +impl_walk!(SourceUnit, visit_source_unit, |source_unit, visitor| { + source_unit.nodes.iter().for_each(|part| { + part.walk(visitor); + }); +}); + +impl_walk!(SourceUnitPart, |part, visitor| { + match part { + SourceUnitPart::ContractDefinition(contract) => { + contract.walk(visitor); + } + SourceUnitPart::UsingForDirective(directive) => { + directive.walk(visitor); + } + SourceUnitPart::ErrorDefinition(error) => { + error.walk(visitor); + } + SourceUnitPart::StructDefinition(struct_) => { + struct_.walk(visitor); + } + SourceUnitPart::VariableDeclaration(declaration) => { + declaration.walk(visitor); + } + SourceUnitPart::FunctionDefinition(function) => { + function.walk(visitor); + } + SourceUnitPart::UserDefinedValueTypeDefinition(value_type) => { + value_type.walk(visitor); + } + SourceUnitPart::ImportDirective(directive) => { + directive.walk(visitor); + } + SourceUnitPart::EnumDefinition(enum_) => { + enum_.walk(visitor); + } + SourceUnitPart::PragmaDirective(directive) => { + directive.walk(visitor); + } + } +}); + +impl_walk!(ContractDefinition, visit_contract_definition, |contract, visitor| { + contract.base_contracts.iter().for_each(|base_contract| { + base_contract.walk(visitor); + }); + + for part in &contract.nodes { + match part { + ContractDefinitionPart::FunctionDefinition(function) => { + function.walk(visitor); + } + ContractDefinitionPart::ErrorDefinition(error) => { + error.walk(visitor); + } + ContractDefinitionPart::EventDefinition(event) => { + event.walk(visitor); + } + ContractDefinitionPart::StructDefinition(struct_) => { + struct_.walk(visitor); + } + ContractDefinitionPart::VariableDeclaration(declaration) => { + declaration.walk(visitor); + } + ContractDefinitionPart::ModifierDefinition(modifier) => { + modifier.walk(visitor); + } + ContractDefinitionPart::UserDefinedValueTypeDefinition(definition) => { + definition.walk(visitor); + } + ContractDefinitionPart::UsingForDirective(directive) => { + directive.walk(visitor); + } + ContractDefinitionPart::EnumDefinition(enum_) => { + enum_.walk(visitor); + } + } + } +}); + +impl_walk!(Expression, visit_expression, 
|expr, visitor| { + match expr { + Expression::FunctionCall(expression) => { + expression.walk(visitor); + } + Expression::MemberAccess(member_access) => { + member_access.walk(visitor); + } + Expression::IndexAccess(index_access) => { + index_access.walk(visitor); + } + Expression::UnaryOperation(unary_op) => { + unary_op.walk(visitor); + } + Expression::BinaryOperation(expression) => { + expression.walk(visitor); + } + Expression::Conditional(expression) => { + expression.walk(visitor); + } + Expression::TupleExpression(tuple) => { + tuple.walk(visitor); + } + Expression::NewExpression(expression) => { + expression.walk(visitor); + } + Expression::Assignment(expression) => { + expression.walk(visitor); + } + Expression::Identifier(identifier) => { + identifier.walk(visitor); + } + Expression::FunctionCallOptions(function_call) => { + function_call.walk(visitor); + } + Expression::IndexRangeAccess(range_access) => { + range_access.walk(visitor); + } + Expression::Literal(literal) => { + literal.walk(visitor); + } + Expression::ElementaryTypeNameExpression(type_name) => { + type_name.walk(visitor); + } + } +}); + +impl_walk!(Statement, visit_statement, |statement, visitor| { + match statement { + Statement::Block(block) => { + block.walk(visitor); + } + Statement::WhileStatement(statement) => { + statement.walk(visitor); + } + Statement::ForStatement(statement) => { + statement.walk(visitor); + } + Statement::IfStatement(statement) => { + statement.walk(visitor); + } + Statement::DoWhileStatement(statement) => { + statement.walk(visitor); + } + Statement::EmitStatement(statement) => { + statement.walk(visitor); + } + Statement::VariableDeclarationStatement(statement) => { + statement.walk(visitor); + } + Statement::ExpressionStatement(statement) => { + statement.walk(visitor); + } + Statement::UncheckedBlock(statement) => { + statement.walk(visitor); + } + Statement::TryStatement(statement) => { + statement.walk(visitor); + } + Statement::RevertStatement(statement) => { + statement.walk(visitor); + } + Statement::Return(statement) => { + statement.walk(visitor); + } + Statement::Break(_) + | Statement::Continue(_) + | Statement::InlineAssembly(_) + | Statement::PlaceholderStatement(_) => {} + } +}); + +impl_walk!(FunctionDefinition, visit_function_definition, |function, visitor| { + function.parameters.walk(visitor); + function.return_parameters.walk(visitor); + + if let Some(overrides) = &function.overrides { + overrides.walk(visitor); + } + + if let Some(body) = &function.body { + body.walk(visitor); + } + + function.modifiers.iter().for_each(|m| m.walk(visitor)); +}); + +impl_walk!(ErrorDefinition, visit_error_definition, |error, visitor| { + error.parameters.walk(visitor); +}); + +impl_walk!(EventDefinition, visit_event_definition, |event, visitor| { + event.parameters.walk(visitor); +}); + +impl_walk!(StructDefinition, visit_struct_definition, |struct_, visitor| { + struct_.members.iter().for_each(|member| member.walk(visitor)); +}); + +impl_walk!(ModifierDefinition, visit_modifier_definition, |modifier, visitor| { + modifier.body.walk(visitor); + if let Some(override_) = &modifier.overrides { + override_.walk(visitor); + } + modifier.parameters.walk(visitor); +}); + +impl_walk!(VariableDeclaration, visit_variable_declaration, |declaration, visitor| { + if let Some(value) = &declaration.value { + value.walk(visitor); + } + + if let Some(type_name) = &declaration.type_name { + type_name.walk(visitor); + } +}); + +impl_walk!(OverrideSpecifier, visit_overrides, |override_, visitor| { + 
override_.overrides.iter().for_each(|type_name| { + type_name.walk(visitor); + }); +}); + +impl_walk!(UserDefinedValueTypeDefinition, visit_user_defined_value_type, |value_type, visitor| { + value_type.underlying_type.walk(visitor); +}); + +impl_walk!(FunctionCallOptions, visit_function_call_options, |function_call, visitor| { + function_call.expression.walk(visitor); + function_call.options.iter().for_each(|option| { + option.walk(visitor); + }); +}); + +impl_walk!(Return, visit_return, |return_, visitor| { + if let Some(expr) = return_.expression.as_ref() { + expr.walk(visitor); + } +}); + +impl_walk!(UsingForDirective, visit_using_for, |directive, visitor| { + if let Some(type_name) = &directive.type_name { + type_name.walk(visitor); + } + if let Some(library_name) = &directive.library_name { + library_name.walk(visitor); + } + for function in &directive.function_list { + function.function.walk(visitor); + } +}); + +impl_walk!(UnaryOperation, visit_unary_operation, |unary_op, visitor| { + unary_op.sub_expression.walk(visitor); +}); + +impl_walk!(BinaryOperation, visit_binary_operation, |binary_op, visitor| { + binary_op.lhs.walk(visitor); + binary_op.rhs.walk(visitor); +}); + +impl_walk!(Conditional, visit_conditional, |conditional, visitor| { + conditional.condition.walk(visitor); + conditional.true_expression.walk(visitor); + conditional.false_expression.walk(visitor); +}); + +impl_walk!(TupleExpression, visit_tuple_expression, |tuple_expression, visitor| { + tuple_expression.components.iter().filter_map(|component| component.as_ref()).for_each( + |component| { + component.walk(visitor); + }, + ); +}); + +impl_walk!(NewExpression, visit_new_expression, |new_expression, visitor| { + new_expression.type_name.walk(visitor); +}); + +impl_walk!(Assignment, visit_assignment, |assignment, visitor| { + assignment.lhs.walk(visitor); + assignment.rhs.walk(visitor); +}); +impl_walk!(IfStatement, visit_if_statement, |if_statement, visitor| { + if_statement.condition.walk(visitor); + if_statement.true_body.walk(visitor); + + if let Some(false_body) = &if_statement.false_body { + false_body.walk(visitor); + } +}); + +impl_walk!(IndexAccess, visit_index_access, |index_access, visitor| { + index_access.base_expression.walk(visitor); + if let Some(index_expression) = &index_access.index_expression { + index_expression.walk(visitor); + } +}); + +impl_walk!(IndexRangeAccess, visit_index_range_access, |index_range_access, visitor| { + index_range_access.base_expression.walk(visitor); + if let Some(start_expression) = &index_range_access.start_expression { + start_expression.walk(visitor); + } + if let Some(end_expression) = &index_range_access.end_expression { + end_expression.walk(visitor); + } +}); + +impl_walk!(WhileStatement, visit_while_statement, |while_statement, visitor| { + while_statement.condition.walk(visitor); + while_statement.body.walk(visitor); +}); + +impl_walk!(ForStatement, visit_for_statement, |for_statement, visitor| { + for_statement.body.walk(visitor); + if let Some(condition) = &for_statement.condition { + condition.walk(visitor); + } + + if let Some(loop_expression) = &for_statement.loop_expression { + loop_expression.walk(visitor); + } + + if let Some(initialization_expr) = &for_statement.initialization_expression { + initialization_expr.walk(visitor); + } +}); + +impl_walk!(DoWhileStatement, visit_do_while_statement, |do_while_statement, visitor| { + do_while_statement.block.walk(visitor); + do_while_statement.condition.walk(visitor); +}); + +impl_walk!(EmitStatement, 
visit_emit_statement, |emit_statement, visitor| { + emit_statement.event_call.walk(visitor); +}); + +impl_walk!(VariableDeclarationStatement, |stmt, visitor| { + stmt.declarations.iter().filter_map(|d| d.as_ref()).for_each(|declaration| { + declaration.walk(visitor); + }); + if let Some(initial_value) = &stmt.initial_value { + initial_value.walk(visitor); + } +}); + +impl_walk!(UncheckedBlock, visit_unchecked_block, |unchecked_block, visitor| { + unchecked_block.statements.iter().for_each(|statement| { + statement.walk(visitor); + }); +}); + +impl_walk!(TryStatement, visit_try_statement, |try_statement, visitor| { + try_statement.clauses.iter().for_each(|clause| { + clause.block.walk(visitor); + + if let Some(parameter_list) = &clause.parameters { + parameter_list.walk(visitor); + } + }); + + try_statement.external_call.walk(visitor); +}); + +impl_walk!(RevertStatement, visit_revert_statement, |revert_statement, visitor| { + revert_statement.error_call.walk(visitor); +}); + +impl_walk!(MemberAccess, visit_member_access, |member_access, visitor| { + member_access.expression.walk(visitor); +}); + +impl_walk!(FunctionCall, visit_function_call, |function_call, visitor| { + function_call.expression.walk(visitor); + function_call.arguments.iter().for_each(|argument| { + argument.walk(visitor); + }); +}); + +impl_walk!(Block, visit_block, |block, visitor| { + block.statements.iter().for_each(|statement| { + statement.walk(visitor); + }); +}); + +impl_walk!(UserDefinedTypeName, visit_user_defined_type_name, |type_name, visitor| { + if let Some(path_node) = &type_name.path_node { + path_node.walk(visitor); + } +}); + +impl_walk!(TypeName, visit_type_name, |type_name, visitor| { + match type_name { + TypeName::ElementaryTypeName(type_name) => { + type_name.walk(visitor); + } + TypeName::UserDefinedTypeName(type_name) => { + type_name.walk(visitor); + } + TypeName::Mapping(mapping) => { + mapping.walk(visitor); + } + TypeName::ArrayTypeName(array) => { + array.walk(visitor); + } + TypeName::FunctionTypeName(function) => { + function.walk(visitor); + } + } +}); + +impl_walk!(FunctionTypeName, visit_function_type_name, |function, visitor| { + function.parameter_types.walk(visitor); + function.return_parameter_types.walk(visitor); +}); + +impl_walk!(ParameterList, visit_parameter_list, |parameter_list, visitor| { + parameter_list.parameters.iter().for_each(|parameter| { + parameter.walk(visitor); + }); +}); + +impl_walk!(Mapping, visit_mapping, |mapping, visitor| { + mapping.key_type.walk(visitor); + mapping.value_type.walk(visitor); +}); + +impl_walk!(ArrayTypeName, visit_array_type_name, |array, visitor| { + array.base_type.walk(visitor); + if let Some(length) = &array.length { + length.walk(visitor); + } +}); + +impl_walk!(InheritanceSpecifier, visit_inheritance_specifier, |specifier, visitor| { + specifier.base_name.walk(visitor); + specifier.arguments.iter().for_each(|arg| { + arg.walk(visitor); + }); +}); + +impl_walk!(ModifierInvocation, visit_modifier_invocation, |invocation, visitor| { + invocation.arguments.iter().for_each(|arg| arg.walk(visitor)); + invocation.modifier_name.walk(visitor); +}); + +impl_walk!(ElementaryTypeName, visit_elementary_type_name); +impl_walk!(Literal, visit_literal); +impl_walk!(ImportDirective, visit_import_directive); +impl_walk!(PragmaDirective, visit_pragma_directive); +impl_walk!(IdentifierPath, visit_identifier_path); +impl_walk!(EnumDefinition, visit_enum_definition); +impl_walk!(Identifier, visit_identifier); + +impl_walk!(UserDefinedTypeNameOrIdentifierPath, 
|type_name, visitor| { + match type_name { + UserDefinedTypeNameOrIdentifierPath::UserDefinedTypeName(type_name) => { + type_name.walk(visitor); + } + UserDefinedTypeNameOrIdentifierPath::IdentifierPath(identifier_path) => { + identifier_path.walk(visitor); + } + } +}); + +impl_walk!(BlockOrStatement, |block_or_statement, visitor| { + match block_or_statement { + BlockOrStatement::Block(block) => { + block.walk(visitor); + } + BlockOrStatement::Statement(statement) => { + statement.walk(visitor); + } + } +}); + +impl_walk!(ExpressionOrVariableDeclarationStatement, |val, visitor| { + match val { + ExpressionOrVariableDeclarationStatement::ExpressionStatement(expression) => { + expression.walk(visitor); + } + ExpressionOrVariableDeclarationStatement::VariableDeclarationStatement(stmt) => { + stmt.walk(visitor); + } + } +}); + +impl_walk!(IdentifierOrIdentifierPath, |val, visitor| { + match val { + IdentifierOrIdentifierPath::Identifier(ident) => { + ident.walk(visitor); + } + IdentifierOrIdentifierPath::IdentifierPath(path) => { + path.walk(visitor); + } + } +}); + +impl_walk!(ExpressionStatement, |expression_statement, visitor| { + expression_statement.expression.walk(visitor); +}); + +impl_walk!(ElementaryTypeNameExpression, |type_name, visitor| { + type_name.type_name.walk(visitor); +}); + +impl_walk!(ElementaryOrRawTypeName, |type_name, visitor| { + match type_name { + ElementaryOrRawTypeName::ElementaryTypeName(type_name) => { + type_name.walk(visitor); + } + ElementaryOrRawTypeName::Raw(_) => {} + } +}); diff --git a/src/config.rs b/src/config.rs index fe7ffdee..8a523692 100644 --- a/src/config.rs +++ b/src/config.rs @@ -2,13 +2,14 @@ use crate::{ artifacts::{output_selection::ContractOutputSelection, Settings}, cache::SOLIDITY_FILES_CACHE_FILENAME, error::{Result, SolcError, SolcIoError}, + flatten::collect_ordered_deps, remappings::Remapping, resolver::{Graph, SolImportAlias}, utils, Source, Sources, }; use serde::{Deserialize, Serialize}; use std::{ - collections::{BTreeSet, HashSet}, + collections::BTreeSet, fmt::{self, Formatter}, fs, ops::{Deref, DerefMut}, @@ -401,110 +402,100 @@ impl ProjectPathsConfig { // part of the graph if it's not imported by any input file let flatten_target = target.to_path_buf(); if !input_files.contains(&flatten_target) { - input_files.push(flatten_target); + input_files.push(flatten_target.clone()); } let sources = Source::read_all_files(input_files)?; let graph = Graph::resolve_sources(self, sources)?; - self.flatten_node(target, &graph, &mut Default::default(), false, false, false).map(|x| { - format!("{}\n", utils::RE_THREE_OR_MORE_NEWLINES.replace_all(&x, "\n\n").trim()) - }) - } + let ordered_deps = collect_ordered_deps(&flatten_target, self, &graph)?; - /// Flattens a single node from the dependency graph - fn flatten_node( - &self, - target: &Path, - graph: &Graph, - imported: &mut HashSet, - strip_version_pragma: bool, - strip_experimental_pragma: bool, - strip_license: bool, - ) -> Result { - let target_dir = target.parent().ok_or_else(|| { - SolcError::msg(format!("failed to get parent directory for \"{:?}\"", target.display())) - })?; - let target_index = graph.files().get(target).ok_or_else(|| { - SolcError::msg(format!("cannot resolve file at {:?}", target.display())) - })?; - - if imported.contains(target_index) { - // short circuit nodes that were already imported, if both A.sol and B.sol import C.sol - return Ok(String::new()); - } - imported.insert(*target_index); + let mut sources = Vec::new(); - let target_node = 
graph.node(*target_index); + let mut result = String::new(); - let mut imports = target_node.imports().clone(); - imports.sort_by_key(|x| x.loc().start); + for path in ordered_deps.iter() { + let node_id = graph.files().get(path).ok_or_else(|| { + SolcError::msg(format!("cannot resolve file at {}", path.display())) + })?; + let node = graph.node(*node_id); + let content = node.content().to_owned(); - let mut content = target_node.content().to_owned(); + // Firstly we strip all licesnses, verson pragmas + // We keep target file pragma and license placing them in the beginning of the result. + let mut ranges_to_remove = Vec::new(); - for alias in imports.iter().flat_map(|i| i.data().aliases()) { - let (alias, target) = match alias { - SolImportAlias::Contract(alias, target) => (alias.clone(), target.clone()), - _ => continue, - }; - let name_regex = utils::create_contract_or_lib_name_regex(&alias); - let target_len = target.len() as isize; - let mut replace_offset = 0; - for cap in name_regex.captures_iter(&content.clone()) { - if cap.name("ignore").is_some() { - continue; + if let Some(license) = node.license() { + ranges_to_remove.push(license.loc()); + if *path == flatten_target { + result.push_str(&content[license.loc()]); + result.push('\n'); } - if let Some(name_match) = ["n1", "n2", "n3"].iter().find_map(|name| cap.name(name)) - { - let name_match_range = - utils::range_by_offset(&name_match.range(), replace_offset); - replace_offset += target_len - (name_match_range.len() as isize); - content.replace_range(name_match_range, &target); + } + if let Some(version) = node.version() { + ranges_to_remove.push(version.loc()); + if *path == flatten_target { + result.push_str(&content[version.loc()]); + result.push('\n'); } } - } + if let Some(experimental) = node.experimental() { + ranges_to_remove.push(experimental.loc()); + if *path == flatten_target { + result.push_str(&content[experimental.loc()]); + result.push('\n'); + } + } + for import in node.imports() { + ranges_to_remove.push(import.loc()); + } + ranges_to_remove.sort_by_key(|loc| loc.start); - let mut content = content.as_bytes().to_vec(); - let mut offset = 0_isize; - - let mut statements = [ - (target_node.license(), strip_license), - (target_node.version(), strip_version_pragma), - (target_node.experimental(), strip_experimental_pragma), - ] - .iter() - .filter_map(|(data, condition)| if *condition { data.to_owned().as_ref() } else { None }) - .collect::>(); - statements.sort_by_key(|x| x.loc().start); - - let (mut imports, mut statements) = - (imports.iter().peekable(), statements.iter().peekable()); - while imports.peek().is_some() || statements.peek().is_some() { - let (next_import_start, next_statement_start) = ( - imports.peek().map_or(usize::max_value(), |x| x.loc().start), - statements.peek().map_or(usize::max_value(), |x| x.loc().start), - ); - if next_statement_start < next_import_start { - let repl_range = statements.next().unwrap().loc_by_offset(offset); + let mut content = content.as_bytes().to_vec(); + let mut offset = 0_isize; + + for range in ranges_to_remove { + let repl_range = utils::range_by_offset(&range, offset); offset -= repl_range.len() as isize; content.splice(repl_range, std::iter::empty()); - } else { - let import = imports.next().unwrap(); - let import_path = self.resolve_import(target_dir, import.data().path())?; - let s = self.flatten_node(&import_path, graph, imported, true, true, true)?; - - let import_content = s.as_bytes(); - let import_content_len = import_content.len() as isize; - let 
import_range = import.loc_by_offset(offset); - offset += import_content_len - (import_range.len() as isize); - content.splice(import_range, import_content.iter().copied()); } + + let mut content = String::from_utf8(content).map_err(|err| { + SolcError::msg(format!("failed to convert extended bytes to string: {err}")) + })?; + + // Iterate over all aliased imports, and replace alias with real name via regexes + for alias in node.imports().iter().flat_map(|i| i.data().aliases()) { + let (alias, target) = match alias { + SolImportAlias::Contract(alias, target) => (alias.clone(), target.clone()), + _ => continue, + }; + let name_regex = utils::create_contract_or_lib_name_regex(&alias); + let target_len = target.len() as isize; + let mut replace_offset = 0; + for cap in name_regex.captures_iter(&content.clone()) { + if cap.name("ignore").is_some() { + continue; + } + if let Some(name_match) = + ["n1", "n2", "n3"].iter().find_map(|name| cap.name(name)) + { + let name_match_range = + utils::range_by_offset(&name_match.range(), replace_offset); + replace_offset += target_len - (name_match_range.len() as isize); + content.replace_range(name_match_range, &target); + } + } + } + + sources.push(content); } - let result = String::from_utf8(content).map_err(|err| { - SolcError::msg(format!("failed to convert extended bytes to string: {err}")) - })?; + for source in sources { + result.push_str(&source); + result.push_str("\n\n"); + } - Ok(result) + Ok(format!("{}\n", utils::RE_THREE_OR_MORE_NEWLINES.replace_all(&result, "\n\n").trim())) } } diff --git a/src/flatten.rs b/src/flatten.rs new file mode 100644 index 00000000..2afcd8f8 --- /dev/null +++ b/src/flatten.rs @@ -0,0 +1,628 @@ +use std::{ + collections::{HashMap, HashSet}, + hash::Hash, + path::{Path, PathBuf}, +}; + +use crate::{ + artifacts::{ + ast::SourceLocation, + visitor::{Visitor, Walk}, + ContractDefinitionPart, Identifier, IdentifierPath, MemberAccess, Source, SourceUnit, + SourceUnitPart, Sources, UserDefinedTypeName, + }, + error::SolcError, + utils, Graph, Project, ProjectCompileOutput, ProjectPathsConfig, Result, +}; + +/// Alternative of `SourceLocation` which includes path of the file. 
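Both the rewritten `flatten` above and `FlatteningResult::new` below edit sources by splicing byte ranges out of a buffer while tracking a running `isize` offset. A minimal standalone sketch of that technique, with an invented source string and invented ranges:

```rust
use std::ops::Range;

/// Removes several byte ranges (given relative to the *original* string) from `content`.
/// Every removal shifts later ranges to the left, so a running offset keeps them aligned,
/// mirroring the splice-with-offset loop used by `flatten`.
fn remove_ranges(content: &str, mut ranges: Vec<Range<usize>>) -> String {
    ranges.sort_by_key(|r| r.start);
    let mut bytes = content.as_bytes().to_vec();
    let mut offset = 0_isize;
    for range in ranges {
        let start = (range.start as isize + offset) as usize;
        let end = (range.end as isize + offset) as usize;
        offset -= (end - start) as isize;
        bytes.splice(start..end, std::iter::empty());
    }
    String::from_utf8(bytes).expect("ranges lie on character boundaries")
}

fn main() {
    let src = "// SPDX\npragma solidity ^0.8.0;\nimport \"./A.sol\";\ncontract B {}\n";
    // Hypothetical ranges covering the license comment and the import line.
    let flattened = remove_ranges(src, vec![0..8, 32..50]);
    assert!(!flattened.contains("import"));
    println!("{flattened}");
}
```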
+#[derive(Debug, PartialEq, Eq, Hash, Clone)] +struct ItemLocation { + path: PathBuf, + start: usize, + end: usize, +} + +impl ItemLocation { + fn try_from_source_loc(src: &SourceLocation, path: PathBuf) -> Option { + let start = src.start?; + let end = start + src.length?; + + Some(ItemLocation { path, start, end }) + } +} + +/// Visitor exploring AST and collecting all references to any declarations +struct ReferencesCollector { + path: PathBuf, + references: HashMap>, +} + +impl ReferencesCollector { + fn process_referenced_declaration(&mut self, id: isize, src: &SourceLocation) { + if let Some(loc) = ItemLocation::try_from_source_loc(src, self.path.clone()) { + self.references.entry(id).or_default().insert(loc); + } + } +} + +impl Visitor for ReferencesCollector { + fn visit_identifier(&mut self, identifier: &Identifier) { + if let Some(id) = identifier.referenced_declaration { + self.process_referenced_declaration(id, &identifier.src); + } + } + + fn visit_user_defined_type_name(&mut self, type_name: &UserDefinedTypeName) { + self.process_referenced_declaration(type_name.referenced_declaration, &type_name.src); + } + + fn visit_member_access(&mut self, member_access: &MemberAccess) { + if let Some(id) = member_access.referenced_declaration { + self.process_referenced_declaration(id, &member_access.src); + } + } + + fn visit_identifier_path(&mut self, path: &IdentifierPath) { + self.process_referenced_declaration(path.referenced_declaration, &path.src); + } +} + +/// Visitor exploring AST and collecting all references to any declarations found in +/// `UserDefinedTypeName` nodes +struct UserDefinedTypeNamesCollector { + path: PathBuf, + references: HashMap>, +} + +impl Visitor for UserDefinedTypeNamesCollector { + fn visit_user_defined_type_name(&mut self, type_name: &UserDefinedTypeName) { + if let Some(loc) = ItemLocation::try_from_source_loc(&type_name.src, self.path.clone()) { + self.references.entry(type_name.referenced_declaration).or_default().insert(loc); + } + } +} + +/// Updates to be applied to the sources. +/// source_path -> (start, end, new_value) +type Updates = HashMap>; + +struct FlatteningResult<'a> { + /// Updated source in the order they shoud be written to the output file. + sources: Vec, + /// Pragmas that should be present in the target file. + pragmas: Vec<&'a str>, + /// License identifier that should be present in the target file. 
+ license: Option<&'a str>, +} + +impl<'a> FlatteningResult<'a> { + fn new( + flattener: &Flattener, + mut updates: Updates, + pragmas: Vec<&'a str>, + license: Option<&'a str>, + ) -> Self { + let mut sources = Vec::new(); + + for path in &flattener.ordered_sources { + let mut content = flattener.sources.get(path).unwrap().content.as_bytes().to_vec(); + let mut offset: isize = 0; + if let Some(updates) = updates.remove(path) { + let mut updates = updates.iter().collect::>(); + updates.sort_by_key(|(start, _, _)| *start); + for (start, end, new_value) in updates { + let start = (*start as isize + offset) as usize; + let end = (*end as isize + offset) as usize; + + content.splice(start..end, new_value.bytes()); + offset += new_value.len() as isize - (end - start) as isize; + } + } + sources.push(String::from_utf8(content).unwrap()); + } + + Self { sources, pragmas, license } + } + + fn get_flattened_target(&self) -> String { + let mut result = String::new(); + + if let Some(license) = &self.license { + result.push_str(&format!("{}\n", license)); + } + for pragma in &self.pragmas { + result.push_str(&format!("{}\n", pragma)); + } + for source in &self.sources { + result.push_str(&format!("{}\n\n", source)); + } + + format!("{}\n", utils::RE_THREE_OR_MORE_NEWLINES.replace_all(&result, "\n\n").trim()) + } +} + +/// Context for flattening. Stores all sources and ASTs that are in scope of the flattening target. +pub struct Flattener { + /// Target file to flatten. + target: PathBuf, + /// Sources including only target and it dependencies (imports of any depth). + sources: Sources, + /// Vec of (path, ast) pairs. + asts: Vec<(PathBuf, SourceUnit)>, + /// Sources in the order they should be written to the output file. + ordered_sources: Vec, +} + +impl Flattener { + /// Compilation output is expected to contain all artifacts for all sources. + /// Flattener caller is expected to resolve all imports of target file, compile them and pass + /// into this function. + pub fn new(project: &Project, output: &ProjectCompileOutput, target: &Path) -> Result { + let input_files = output + .artifacts_with_files() + .map(|(file, _, _)| PathBuf::from(file)) + .collect::>() + .into_iter() + .collect::>(); + + let sources = Source::read_all_files(input_files)?; + let graph = Graph::resolve_sources(&project.paths, sources)?; + + let ordered_deps = collect_ordered_deps(&target.to_path_buf(), &project.paths, &graph)?; + + let sources = Source::read_all(&ordered_deps)?; + + // Convert all ASTs from artifacts to strongly typed ASTs + let mut asts: Vec<(PathBuf, SourceUnit)> = Vec::new(); + for (path, ast) in output.artifacts_with_files().filter_map(|(path, _, artifact)| { + if let Some(ast) = artifact.ast.as_ref() { + if sources.contains_key(&PathBuf::from(path)) { + Some((path, ast)) + } else { + None + } + } else { + None + } + }) { + asts.push((PathBuf::from(path), serde_json::from_str(&serde_json::to_string(ast)?)?)); + } + + Ok(Flattener { target: target.into(), sources, asts, ordered_sources: ordered_deps }) + } + + /// Flattens target file and returns the result as a string + /// + /// Flattening process includes following steps: + /// 1. Find all file-level definitions and rename references to them via aliased or qualified + /// imports. + /// 2. Find all duplicates among file-level definitions and rename them to avoid conflicts. + /// 3. Remove all imports. + /// 4. Remove all pragmas except for the ones in the target file. + /// 5. Remove all license identifiers except for the one in the target file. 
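A caller is expected to drive the new flattener roughly as follows. This is a minimal sketch: the project root and target path are hypothetical, while the call shape follows the way the tests below construct the `Flattener`.

```rust
use std::path::Path;

use foundry_compilers::{flatten::Flattener, Project, ProjectPathsConfig};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Hypothetical project layout; adjust the root and target to a real project.
    let paths = ProjectPathsConfig::hardhat("/path/to/project")?;
    let project = Project::builder().paths(paths).build()?;

    // The caller is expected to compile first so that artifacts (including ASTs)
    // exist for the target file and all of its imports.
    let output = project.compile()?;

    let target = Path::new("/path/to/project/contracts/Counter.sol");
    let flattened = Flattener::new(&project, &output, target)?.flatten();
    println!("{flattened}");
    Ok(())
}
```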
+ pub fn flatten(&self) -> String { + let mut updates = Updates::new(); + + let top_level_names = self.rename_top_level_definitions(&mut updates); + self.rename_contract_level_types_references(&top_level_names, &mut updates); + self.remove_imports(&mut updates); + let target_pragmas = self.process_pragmas(&mut updates); + let target_license = self.process_licenses(&mut updates); + + self.flatten_result(updates, target_pragmas, target_license).get_flattened_target() + } + + fn flatten_result<'a>( + &'a self, + updates: Updates, + target_pragmas: Vec<&'a str>, + target_license: Option<&'a str>, + ) -> FlatteningResult { + FlatteningResult::new(self, updates, target_pragmas, target_license) + } + + /// Finds and goes over all references to file-level definitions and updates them to match + /// definition name. This is needed for two reasons: + /// 1. We want to rename all aliased or qualified imports. + /// 2. We want to find any duplicates and rename them to avoid conflicts. + /// + /// If we find more than 1 declaration with the same name, it's name is getting changed. + /// Two Counter contracts will be renamed to Counter_0 and Counter_1 + /// + /// Returns mapping from top-level declaration id to its name (possibly updated) + fn rename_top_level_definitions(&self, updates: &mut Updates) -> HashMap { + let top_level_definitions = self.collect_top_level_definitions(); + let references = self.collect_references(); + + let mut top_level_names = HashMap::new(); + + for (name, ids) in top_level_definitions { + let mut definition_name = name.to_string(); + let needs_rename = ids.len() > 1; + + let mut ids = ids.clone().into_iter().collect::>(); + if needs_rename { + // `loc.path` is expected to be different for each id because there can't be 2 + // top-level eclarations with the same name in the same file. + // + // Sorting by loc.path to make the renaming process deterministic + ids.sort_by(|(_, loc_0), (_, loc_1)| loc_0.path.cmp(&loc_1.path)); + } + for (i, (id, loc)) in ids.iter().enumerate() { + if needs_rename { + definition_name = format!("{}_{}", name, i); + } + updates.entry(loc.path.clone()).or_default().insert(( + loc.start, + loc.end, + definition_name.clone(), + )); + if let Some(references) = references.get(&(*id as isize)) { + for loc in references { + updates.entry(loc.path.clone()).or_default().insert(( + loc.start, + loc.end, + definition_name.clone(), + )); + } + } + + top_level_names.insert(*id, definition_name.clone()); + } + } + top_level_names + } + + /// This is a workaround to be able to correctly process definitions which types + /// are present in the form of `ParentName.ChildName` where `ParentName` is a + /// contract name and `ChildName` is a struct/enum name. + /// + /// Such types are represented as `UserDefinedTypeName` in AST and don't include any + /// information about parent in which the definition of child is present. 
+ fn rename_contract_level_types_references( + &self, + top_level_names: &HashMap, + updates: &mut Updates, + ) { + let contract_level_definitions = self.collect_contract_level_definitions(); + + for (path, ast) in &self.asts { + for node in &ast.nodes { + let current_contract_scope = match node { + SourceUnitPart::ContractDefinition(contract) => Some(contract.id), + _ => None, + }; + let mut collector = UserDefinedTypeNamesCollector { + path: self.target.clone(), + references: HashMap::new(), + }; + + node.walk(&mut collector); + + // Now this contains all definitions found in all UserDefinedTypeName nodes in the + // given source unit + let references = collector.references; + + for (id, locs) in references { + if let Some((name, contract_id)) = + contract_level_definitions.get(&(id as usize)) + { + if let Some(current_scope) = current_contract_scope { + // If this is a contract-level definition reference inside of the same + // contract it declared in, we replace it with its name + if current_scope == *contract_id { + updates.entry(path.clone()).or_default().extend( + locs.iter().map(|loc| (loc.start, loc.end, name.to_string())), + ); + continue; + } + } + // If we are in some other contract or in global scope (file-level), then we + // should replace type name with `ParentName.ChildName`` + let parent_name = top_level_names.get(contract_id).unwrap(); + updates.entry(path.clone()).or_default().extend( + locs.iter().map(|loc| { + (loc.start, loc.end, format!("{}.{}", parent_name, name)) + }), + ); + } + } + } + } + } + + /// Processes all ASTs and collects all top-level definitions in the form of + /// a mapping from name to (definition id, source location) + fn collect_top_level_definitions(&self) -> HashMap<&String, HashSet<(usize, ItemLocation)>> { + self.asts + .iter() + .flat_map(|(path, ast)| { + ast.nodes + .iter() + .filter_map(|node| match node { + SourceUnitPart::ContractDefinition(contract) => { + Some((&contract.name, contract.id, &contract.src)) + } + SourceUnitPart::EnumDefinition(enum_) => { + Some((&enum_.name, enum_.id, &enum_.src)) + } + SourceUnitPart::StructDefinition(struct_) => { + Some((&struct_.name, struct_.id, &struct_.src)) + } + SourceUnitPart::FunctionDefinition(function) => { + Some((&function.name, function.id, &function.src)) + } + SourceUnitPart::VariableDeclaration(variable) => { + Some((&variable.name, variable.id, &variable.src)) + } + SourceUnitPart::UserDefinedValueTypeDefinition(value_type) => { + Some((&value_type.name, value_type.id, &value_type.src)) + } + _ => None, + }) + .map(|(name, id, src)| { + // Find location of name in source + let content: &str = &self.sources.get(path).unwrap().content; + let start = src.start.unwrap(); + let end = start + src.length.unwrap(); + + let name_start = content[start..end].find(name).unwrap(); + let name_end = name_start + name.len(); + + let loc = ItemLocation { + path: path.clone(), + start: start + name_start, + end: start + name_end, + }; + + (name, (id, loc)) + }) + }) + .fold(HashMap::new(), |mut acc, (name, (id, item_location))| { + acc.entry(name).or_default().insert((id, item_location)); + acc + }) + } + + /// Collect all contract-level definitions in the form of a mapping from definition id to + /// (definition name, contract id) + fn collect_contract_level_definitions(&self) -> HashMap { + self.asts + .iter() + .flat_map(|(_, ast)| { + ast.nodes.iter().filter_map(|node| match node { + SourceUnitPart::ContractDefinition(contract) => { + Some((contract.id, &contract.nodes)) + } + _ => None, + }) + 
}) + .flat_map(|(contract_id, nodes)| { + nodes.iter().filter_map(move |node| match node { + ContractDefinitionPart::EnumDefinition(enum_) => { + Some((enum_.id, (&enum_.name, contract_id))) + } + ContractDefinitionPart::ErrorDefinition(error) => { + Some((error.id, (&error.name, contract_id))) + } + ContractDefinitionPart::EventDefinition(event) => { + Some((event.id, (&event.name, contract_id))) + } + ContractDefinitionPart::StructDefinition(struct_) => { + Some((struct_.id, (&struct_.name, contract_id))) + } + ContractDefinitionPart::FunctionDefinition(function) => { + Some((function.id, (&function.name, contract_id))) + } + ContractDefinitionPart::VariableDeclaration(variable) => { + Some((variable.id, (&variable.name, contract_id))) + } + ContractDefinitionPart::UserDefinedValueTypeDefinition(value_type) => { + Some((value_type.id, (&value_type.name, contract_id))) + } + _ => None, + }) + }) + .collect() + } + + /// Collects all references to any declaration in the form of a mapping from declaration id to + /// set of source locations it appears in + fn collect_references(&self) -> HashMap> { + self.asts + .iter() + .flat_map(|(path, ast)| { + let mut collector = + ReferencesCollector { path: path.clone(), references: HashMap::new() }; + ast.walk(&mut collector); + collector.references + }) + .fold(HashMap::new(), |mut acc, (id, locs)| { + acc.entry(id).or_default().extend(locs); + acc + }) + } + + /// Removes all imports from all sources. + fn remove_imports(&self, updates: &mut Updates) { + for loc in self.collect_imports() { + updates.entry(loc.path.clone()).or_default().insert(( + loc.start, + loc.end, + "".to_string(), + )); + } + } + + // Collects all imports locations. + fn collect_imports(&self) -> HashSet { + self.asts + .iter() + .flat_map(|(path, ast)| { + ast.nodes.iter().filter_map(|node| match node { + SourceUnitPart::ImportDirective(import) => { + ItemLocation::try_from_source_loc(&import.src, path.clone()) + } + _ => None, + }) + }) + .collect() + } + + /// Removes all pragma directives from all sources. Returns Vec of pragmas that were found in + /// target file. + fn process_pragmas(&self, updates: &mut Updates) -> Vec<&str> { + // Pragmas that will be used in the resulted file + let mut target_pragmas = Vec::new(); + + let pragmas = self.collect_pragmas(); + + let mut seen_experimental = false; + + for loc in &pragmas { + let pragma_content = self.read_location(loc); + if pragma_content.contains("experimental") { + if !seen_experimental { + seen_experimental = true; + target_pragmas.push(loc); + } + } else if loc.path == self.target { + target_pragmas.push(loc); + } + + updates.entry(loc.path.clone()).or_default().insert(( + loc.start, + loc.end, + "".to_string(), + )); + } + + target_pragmas.sort_by_key(|loc| loc.start); + target_pragmas.iter().map(|loc| self.read_location(loc)).collect::>() + } + + // Collects all pragma directives locations. + fn collect_pragmas(&self) -> HashSet { + self.asts + .iter() + .flat_map(|(path, ast)| { + ast.nodes.iter().filter_map(|node| match node { + SourceUnitPart::PragmaDirective(import) => { + ItemLocation::try_from_source_loc(&import.src, path.clone()) + } + _ => None, + }) + }) + .collect() + } + + /// Removes all license identifiers from all sources. Returns licesnse identifier from target + /// file, if any. 
+ fn process_licenses(&self, updates: &mut Updates) -> Option<&str> { + let mut target_license = None; + + for loc in &self.collect_licenses() { + if loc.path == self.target { + target_license = Some(self.read_location(loc)); + } + updates.entry(loc.path.clone()).or_default().insert(( + loc.start, + loc.end, + "".to_string(), + )); + } + + target_license + } + + // Collects all SPDX-License-Identifier locations. + fn collect_licenses(&self) -> HashSet { + self.sources + .iter() + .flat_map(|(path, source)| { + let mut licenses = HashSet::new(); + if let Some(license_start) = source.content.find("SPDX-License-Identifier:") { + let start = + source.content[..license_start].rfind('\n').map(|i| i + 1).unwrap_or(0); + let end = start + + source.content[start..] + .find('\n') + .unwrap_or(source.content.len() - start); + licenses.insert(ItemLocation { path: path.clone(), start, end }); + } + licenses + }) + .collect() + } + + // Reads value from the given location of a source file. + fn read_location(&self, loc: &ItemLocation) -> &str { + let content: &str = &self.sources.get(&loc.path).unwrap().content; + &content[loc.start..loc.end] + } +} + +/// Performs DFS to collect all dependencies of a target +fn collect_deps( + path: &PathBuf, + paths: &ProjectPathsConfig, + graph: &Graph, + deps: &mut HashSet, +) -> Result<()> { + if deps.insert(path.clone()) { + let target_dir = path.parent().ok_or_else(|| { + SolcError::msg(format!("failed to get parent directory for \"{}\"", path.display())) + })?; + + let node_id = graph + .files() + .get(path) + .ok_or_else(|| SolcError::msg(format!("cannot resolve file at {}", path.display())))?; + + for import in graph.node(*node_id).imports() { + let path = paths.resolve_import(target_dir, import.data().path())?; + collect_deps(&path, paths, graph, deps)?; + } + } + Ok(()) +} + +/// We want to make order in which sources are written to resulted flattened file +/// deterministic. +/// +/// We can't just sort files alphabetically as it might break compilation, because Solidity +/// does not allow base class definitions to appear after derived contract +/// definitions. +/// +/// Instead, we sort files by the number of their dependencies (imports of any depth) in ascending +/// order. If files have the same number of dependencies, we sort them alphabetically. +/// Target file is always placed last. 
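To make the ordering rule above concrete, here is a tiny self-contained sketch of the same sort key; the file names and dependency counts are invented:

```rust
fn main() {
    // (number of transitive imports, path) for every dependency of the target.
    let mut deps = vec![
        (2, "src/Token.sol".to_string()),
        (0, "lib/SafeMath.sol".to_string()),
        (0, "lib/Context.sol".to_string()),
        (1, "src/Ownable.sol".to_string()),
    ];

    // Same rule as `collect_ordered_deps`: fewest dependencies first, ties broken
    // alphabetically by path, and the flatten target itself is appended last.
    deps.sort();
    let mut ordered: Vec<String> = deps.into_iter().map(|(_, path)| path).collect();
    ordered.push("src/Vault.sol".to_string()); // the target file

    assert_eq!(ordered.first().map(String::as_str), Some("lib/Context.sol"));
    assert_eq!(ordered.last().map(String::as_str), Some("src/Vault.sol"));
    println!("{ordered:?}");
}
```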
+pub fn collect_ordered_deps( + path: &PathBuf, + paths: &ProjectPathsConfig, + graph: &Graph, +) -> Result> { + let mut deps = HashSet::new(); + collect_deps(path, paths, graph, &mut deps)?; + + // Remove path prior counting dependencies + // It will be added later to the end of resulted Vec + deps.remove(path); + + let mut paths_with_deps_count = Vec::new(); + for path in deps { + let mut path_deps = HashSet::new(); + collect_deps(&path, paths, graph, &mut path_deps)?; + paths_with_deps_count.push((path_deps.len(), path)); + } + + paths_with_deps_count.sort(); + + let mut ordered_deps = + paths_with_deps_count.into_iter().map(|(_, path)| path).collect::>(); + + ordered_deps.push(path.clone()); + + Ok(ordered_deps) +} diff --git a/src/lib.rs b/src/lib.rs index 294b3312..686adad7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -17,6 +17,7 @@ use std::collections::{BTreeMap, HashSet}; mod artifact_output; pub mod buildinfo; pub mod cache; +pub mod flatten; pub mod hh; pub use artifact_output::*; diff --git a/tests/project.rs b/tests/project.rs index 1ac9cdbc..c5f08c2c 100644 --- a/tests/project.rs +++ b/tests/project.rs @@ -9,6 +9,7 @@ use foundry_compilers::{ buildinfo::BuildInfo, cache::{SolFilesCache, SOLIDITY_FILES_CACHE_FILENAME}, error::SolcError, + flatten::Flattener, info::ContractInfo, project_util::*, remappings::Remapping, @@ -432,6 +433,18 @@ fn copy_dir_all(src: impl AsRef, dst: impl AsRef) -> io::Result<()> Ok(()) } +// Runs both `flatten` implementations, asserts that their outputs match and runs additional checks +// against the output. +fn test_flatteners(project: &TempProject, target: &Path, additional_checks: fn(&str)) { + let result = project.flatten(target).unwrap(); + let solc_result = + Flattener::new(project.project(), &project.compile().unwrap(), target).unwrap().flatten(); + + assert_eq!(result, solc_result); + + additional_checks(&result); +} + #[test] fn can_flatten_file() { let root = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("test-data/test-contract-libs"); @@ -440,15 +453,22 @@ fn can_flatten_file() { .sources(root.join("src")) .lib(root.join("lib1")) .lib(root.join("lib2")); + let project = TempProject::::new(paths).unwrap(); - let result = project.flatten(&target); - assert!(result.is_ok()); + test_flatteners(&project, &target, |result| { + assert_eq!( + result, + r#"pragma solidity 0.8.6; - let result = result.unwrap(); - assert!(!result.contains("import")); - assert!(result.contains("contract Foo")); - assert!(result.contains("contract Bar")); +contract Bar {} + +contract Baz {} + +contract Foo is Bar, Baz {} +"# + ); + }); } #[test] @@ -461,13 +481,11 @@ fn can_flatten_file_with_external_lib() { let target = root.join("contracts").join("Greeter.sol"); - let result = project.flatten(&target); - assert!(result.is_ok()); - - let result = result.unwrap(); - assert!(!result.contains("import")); - assert!(result.contains("library console")); - assert!(result.contains("contract Greeter")); + test_flatteners(&project, &target, |result| { + assert!(!result.contains("import")); + assert!(result.contains("library console")); + assert!(result.contains("contract Greeter")); + }); } #[test] @@ -478,21 +496,19 @@ fn can_flatten_file_in_dapp_sample() { let target = root.join("src/Dapp.t.sol"); - let result = project.flatten(&target); - assert!(result.is_ok()); - - let result = result.unwrap(); - assert!(!result.contains("import")); - assert!(result.contains("contract DSTest")); - assert!(result.contains("contract Dapp")); - assert!(result.contains("contract DappTest")); 
+ test_flatteners(&project, &target, |result| { + assert!(!result.contains("import")); + assert!(result.contains("contract DSTest")); + assert!(result.contains("contract Dapp")); + assert!(result.contains("contract DappTest")); + }); } #[test] fn can_flatten_unique() { let project = TempProject::dapptools().unwrap(); - let f = project + let target = project .add_source( "A", r#" @@ -526,26 +542,26 @@ contract C { } ) .unwrap(); - let result = project.flatten(&f).unwrap(); + test_flatteners(&project, &target, |result| { + assert_eq!( + result, + r#"pragma solidity ^0.8.10; - assert_eq!( - result, - r"pragma solidity ^0.8.10; +contract B { } contract C { } -contract B { } - contract A { } -" - ); +"# + ); + }); } #[test] fn can_flatten_experimental_pragma() { let project = TempProject::dapptools().unwrap(); - let f = project + let target = project .add_source( "A", r#" @@ -582,20 +598,20 @@ contract C { } ) .unwrap(); - let result = project.flatten(&f).unwrap(); - - assert_eq!( - result, - r"pragma solidity ^0.8.10; + test_flatteners(&project, &target, |result| { + assert_eq!( + result, + r"pragma solidity ^0.8.10; pragma experimental ABIEncoderV2; -contract C { } - contract B { } +contract C { } + contract A { } " - ); + ); + }); } #[test] @@ -606,13 +622,10 @@ fn can_flatten_file_with_duplicates() { let target = root.join("contracts/FooBar.sol"); - let result = project.flatten(&target); - assert!(result.is_ok()); - - let result = result.unwrap(); - assert_eq!( - result, - r"//SPDX-License-Identifier: UNLICENSED + test_flatteners(&project, &target, |result| { + assert_eq!( + result, + r"//SPDX-License-Identifier: UNLICENSED pragma solidity >=0.6.0; contract Bar {} @@ -621,7 +634,8 @@ contract Foo {} contract FooBar {} " - ); + ); + }); } #[test] @@ -633,7 +647,7 @@ fn can_flatten_on_solang_failure() { let target = root.join("contracts/Contract.sol"); - let result = project.flatten(&target); + let result = project.flatten(target.as_path()); assert!(result.is_ok()); let result = result.unwrap(); @@ -656,7 +670,7 @@ contract Contract { fn can_flatten_multiline() { let project = TempProject::dapptools().unwrap(); - let f = project + let target = project .add_source( "A", r#" @@ -692,10 +706,10 @@ contract C { } ) .unwrap(); - let result = project.flatten(&f).unwrap(); - assert_eq!( - result, - r"pragma solidity ^0.8.10; + test_flatteners(&project, &target, |result| { + assert_eq!( + result, + r"pragma solidity ^0.8.10; contract C { } @@ -704,14 +718,15 @@ error IllegalState(); contract A { } " - ); + ); + }); } #[test] fn can_flatten_remove_extra_spacing() { let project = TempProject::dapptools().unwrap(); - let f = project + let target = project .add_source( "A", r#"pragma solidity ^0.8.10; @@ -744,10 +759,10 @@ contract C { } ) .unwrap(); - let result = project.flatten(&f).unwrap(); - assert_eq!( - result, - r"pragma solidity ^0.8.10; + test_flatteners(&project, &target, |result| { + assert_eq!( + result, + r"pragma solidity ^0.8.10; contract C { } @@ -757,14 +772,15 @@ contract B { } contract A { } " - ); + ); + }); } #[test] fn can_flatten_with_alias() { let project = TempProject::dapptools().unwrap(); - let f = project + let target = project .add_source( "Contract", r#"pragma solidity ^0.8.10; @@ -785,14 +801,10 @@ contract Contract is Parent, Peer public peer; - error Peer(); - constructor(address _peer) { peer = Peer(_peer); - } - - function Math(uint256 value) external pure returns (uint256) { - return Math.minusOne(Math.max() - value.diffMax()); + peer = new Peer(); + uint256 x = 
Math.minusOne(Math.max()); } } "#, @@ -856,17 +868,13 @@ library SomeLib { } ) .unwrap(); - let result = project.flatten(&f).unwrap(); - assert_eq!( - result, - r#"pragma solidity ^0.8.10; - -contract ParentContract { } + test_flatteners(&project, &target, |result| { + assert_eq!( + result, + r#"pragma solidity ^0.8.10; contract AnotherParentContract { } -contract PeerContract { } - library MathLibrary { function minusOne(uint256 val) internal returns (uint256) { return val - 1; @@ -881,6 +889,10 @@ library MathLibrary { } } +contract ParentContract { } + +contract PeerContract { } + library SomeLib { } contract Contract is ParentContract, @@ -894,25 +906,22 @@ contract Contract is ParentContract, PeerContract public peer; - error Peer(); - constructor(address _peer) { peer = PeerContract(_peer); - } - - function Math(uint256 value) external pure returns (uint256) { - return MathLibrary.minusOne(MathLibrary.max() - value.diffMax()); + peer = new PeerContract(); + uint256 x = MathLibrary.minusOne(MathLibrary.max()); } } "# - ); + ); + }); } #[test] fn can_flatten_with_version_pragma_after_imports() { let project = TempProject::dapptools().unwrap(); - let f = project + let target = project .add_source( "A", r#" @@ -929,7 +938,7 @@ contract A { } .add_source( "B", r#" -import D from "./D.sol"; +import {D} from "./D.sol"; pragma solidity ^0.8.10; import * as C from "./C.sol"; contract B { } @@ -957,18 +966,312 @@ contract D { } ) .unwrap(); - let result = project.flatten(&f).unwrap(); + test_flatteners(&project, &target, |result| { + assert_eq!( + result, + r#"pragma solidity ^0.8.10; + +contract C { } + +contract D { } + +contract B { } + +contract A { } +"# + ); + }); +} + +#[test] +fn can_flatten_with_duplicates() { + let project = TempProject::dapptools().unwrap(); + + project + .add_source( + "Foo.sol", + r#" +pragma solidity ^0.8.10; + +contract Foo { + function foo() public pure returns (uint256) { + return 1; + } +} + +contract Bar is Foo {} +"#, + ) + .unwrap(); + + let target = project + .add_source( + "Bar.sol", + r#" +pragma solidity ^0.8.10; +import {Foo} from "./Foo.sol"; + +contract Bar is Foo {} +"#, + ) + .unwrap(); + + let result = + Flattener::new(project.project(), &project.compile().unwrap(), &target).unwrap().flatten(); assert_eq!( result, r"pragma solidity ^0.8.10; -contract D { } +contract Foo { + function foo() public pure returns (uint256) { + return 1; + } +} -contract C { } +contract Bar_1 is Foo {} -contract B { } +contract Bar_0 is Foo {} +" + ); +} -contract A { } +#[test] +fn can_flatten_complex_aliases_setup_with_duplicates() { + let project = TempProject::dapptools().unwrap(); + + project + .add_source( + "A.sol", + r#" +pragma solidity ^0.8.10; + +contract A { + type SomeCustomValue is uint256; + + struct SomeStruct { + uint256 field; + } + + enum SomeEnum { VALUE1, VALUE2 } + + function foo() public pure returns (uint256) { + return 1; + } +} +"#, + ) + .unwrap(); + + project + .add_source( + "B.sol", + r#" +pragma solidity ^0.8.10; +import "./A.sol" as A_File; + +contract A is A_File.A {} +"#, + ) + .unwrap(); + + project + .add_source( + "C.sol", + r#" +pragma solidity ^0.8.10; +import "./B.sol" as B_File; + +contract A is B_File.A_File.A {} +"#, + ) + .unwrap(); + + let target = project + .add_source( + "D.sol", + r#" +pragma solidity ^0.8.10; +import "./C.sol" as C_File; + +C_File.B_File.A_File.A.SomeCustomValue constant fileLevelValue = C_File.B_File.A_File.A.SomeCustomValue.wrap(1); + +contract D is C_File.B_File.A_File.A { + 
C_File.B_File.A_File.A.SomeStruct public someStruct; + C_File.B_File.A_File.A.SomeEnum public someEnum = C_File.B_File.A_File.A.SomeEnum.VALUE1; + + constructor() C_File.B_File.A_File.A() { + someStruct = C_File.B_File.A_File.A.SomeStruct(1); + someEnum = C_File.B_File.A_File.A.SomeEnum.VALUE2; + } + + function getSelector() public pure returns (bytes4) { + return C_File.B_File.A_File.A.foo.selector; + } + + function getEnumValue1() public pure returns (C_File.B_File.A_File.A.SomeEnum) { + return C_File.B_File.A_File.A.SomeEnum.VALUE1; + } + + function getStruct() public pure returns (C_File.B_File.A_File.A.SomeStruct memory) { + return C_File.B_File.A_File.A.SomeStruct(1); + } +} +"#,).unwrap(); + + let result = + Flattener::new(project.project(), &project.compile().unwrap(), &target).unwrap().flatten(); + assert_eq!( + result, + r"pragma solidity ^0.8.10; + +contract A_0 { + type SomeCustomValue is uint256; + + struct SomeStruct { + uint256 field; + } + + enum SomeEnum { VALUE1, VALUE2 } + + function foo() public pure returns (uint256) { + return 1; + } +} + +contract A_1 is A_0 {} + +contract A_2 is A_0 {} + +A_0.SomeCustomValue constant fileLevelValue = A_0.SomeCustomValue.wrap(1); + +contract D is A_0 { + A_0.SomeStruct public someStruct; + A_0.SomeEnum public someEnum = A_0.SomeEnum.VALUE1; + + constructor() A_0() { + someStruct = A_0.SomeStruct(1); + someEnum = A_0.SomeEnum.VALUE2; + } + + function getSelector() public pure returns (bytes4) { + return A_0.foo.selector; + } + + function getEnumValue1() public pure returns (A_0.SomeEnum) { + return A_0.SomeEnum.VALUE1; + } + + function getStruct() public pure returns (A_0.SomeStruct memory) { + return A_0.SomeStruct(1); + } +} +" + ); +} + +// https://github.com/foundry-rs/compilers/issues/34 +#[test] +fn can_flatten_34_repro() { + let project = TempProject::dapptools().unwrap(); + let target = project + .add_source( + "FlieA.sol", + r#"pragma solidity ^0.8.10; +import {B} from "./FileB.sol"; + +interface FooBar { + function foo() external; +} +contract A { + function execute() external { + FooBar(address(0)).foo(); + } +}"#, + ) + .unwrap(); + + project + .add_source( + "FileB.sol", + r#"pragma solidity ^0.8.10; + +interface FooBar { + function bar() external; +} +contract B { + function execute() external { + FooBar(address(0)).bar(); + } +}"#, + ) + .unwrap(); + + let result = + Flattener::new(project.project(), &project.compile().unwrap(), &target).unwrap().flatten(); + assert_eq!( + result, + r#"pragma solidity ^0.8.10; + +interface FooBar_0 { + function bar() external; +} +contract B { + function execute() external { + FooBar_0(address(0)).bar(); + } +} + +interface FooBar_1 { + function foo() external; +} +contract A { + function execute() external { + FooBar_1(address(0)).foo(); + } +} +"# + ); +} + +#[test] +fn can_flatten_experimental_in_other_file() { + let project = TempProject::dapptools().unwrap(); + + project + .add_source( + "A.sol", + r#" +pragma solidity 0.6.12; +pragma experimental ABIEncoderV2; + +contract A {} +"#, + ) + .unwrap(); + + let target = project + .add_source( + "B.sol", + r#" +pragma solidity 0.6.12; + +import "./A.sol"; + +contract B is A {} +"#, + ) + .unwrap(); + + let result = + Flattener::new(project.project(), &project.compile().unwrap(), &target).unwrap().flatten(); + assert_eq!( + result, + r"pragma solidity 0.6.12; +pragma experimental ABIEncoderV2; + +contract A {} + +contract B is A {} " ); } @@ -996,6 +1299,41 @@ fn can_detect_type_error() { assert!(compiled.has_compiler_errors()); } 
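The `Visitor`/`Walk` pair introduced in `visitor.rs` can also be used on its own against a typed AST. A minimal sketch: the import paths mirror the ones used in `flatten.rs`, and obtaining the `SourceUnit` (e.g. by deserializing the `ast` field of a compiled artifact, as `Flattener::new` does) is left to the caller.

```rust
use foundry_compilers::artifacts::{
    visitor::{Visitor, Walk},
    FunctionDefinition, SourceUnit,
};

/// Collects the name of every function definition encountered while walking the AST.
#[derive(Default)]
struct FunctionNames(Vec<String>);

impl Visitor for FunctionNames {
    fn visit_function_definition(&mut self, definition: &FunctionDefinition) {
        self.0.push(definition.name.clone());
    }
}

/// Walks a typed source unit and returns all function names found in it.
fn function_names(ast: &SourceUnit) -> Vec<String> {
    let mut collector = FunctionNames::default();
    ast.walk(&mut collector);
    collector.0
}
```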
+#[test] +fn can_flatten_aliases_with_pragma_and_license_after_source() { + let project = TempProject::dapptools().unwrap(); + + project + .add_source( + "A", + r#"pragma solidity ^0.8.10; +contract A { } +"#, + ) + .unwrap(); + + let target = project + .add_source( + "B", + r#"contract B is AContract {} +import {A as AContract} from "./A.sol"; +pragma solidity ^0.8.10;"#, + ) + .unwrap(); + + test_flatteners(&project, &target, |result| { + assert_eq!( + result, + r"pragma solidity ^0.8.10; + +contract A { } + +contract B is A {} +" + ); + }); +} + #[test] fn can_compile_single_files() { let tmp = TempProject::dapptools().unwrap();