From b7be43977fd6643b5e1024f79b6a4bf88eadb478 Mon Sep 17 00:00:00 2001 From: MilkeeyCat Date: Tue, 29 Oct 2024 14:11:53 +0200 Subject: [PATCH] refactor(parser): new kind of ast node `Item` closes #84 --- src/archs/amd64/amd64.rs | 29 +++--- src/codegen/codegen.rs | 35 +++++--- src/compile.rs | 17 ++-- src/parser/item.rs | 53 +++++++++++ src/parser/mod.rs | 10 ++- src/parser/parser.rs | 155 +++++++++++++++----------------- src/parser/stmt.rs | 30 +------ src/passes/macro_expansion.rs | 4 +- src/passes/mod.rs | 6 +- src/passes/pass.rs | 9 -- src/passes/symbol_resolver.rs | 161 ++++++++++++++++++++-------------- src/passes/type_checker.rs | 72 ++++++++++----- src/type_table.rs | 12 +-- 13 files changed, 333 insertions(+), 260 deletions(-) create mode 100644 src/parser/item.rs delete mode 100644 src/passes/pass.rs diff --git a/src/archs/amd64/amd64.rs b/src/archs/amd64/amd64.rs index 48e33a2..2b9af05 100644 --- a/src/archs/amd64/amd64.rs +++ b/src/archs/amd64/amd64.rs @@ -4,7 +4,7 @@ use crate::{ operands::{self, Base, EffectiveAddress, Immediate, Memory, Offset}, Argument, Destination, Source, }, - parser::{BitwiseOp, Block, CmpOp, Stmt}, + parser::{BitwiseOp, Block, CmpOp, Item, Stmt}, register::{ allocator::{AllocatorError, RegisterAllocator}, Register, @@ -671,10 +671,10 @@ impl Architecture for Amd64 { for stmt in &block.0 { match stmt { - Stmt::VarDecl(stmt2) => { - offset -= self.size(&stmt2.type_, scopes, scope_id) as isize; + Stmt::Item(Item::Variable(item)) => { + offset -= self.size(&item.type_, scopes, scope_id) as isize; - match scopes.get_symbol_mut(&stmt2.name, scope_id).unwrap() { + match scopes.get_symbol_mut(&item.name, scope_id).unwrap() { Symbol::Local(local) => { local.offset = Offset(offset); } @@ -694,23 +694,22 @@ impl Architecture for Amd64 { offset = self.populate_offsets(&stmt.block, scopes, scope_id + 1, offset)?; } Stmt::For(stmt) => { - if let Some(Stmt::VarDecl(stmt2)) = stmt.initializer.as_deref() { + if let Some(Stmt::Item(Item::Variable(stmt2))) = stmt.initializer.as_deref() { offset -= self.size(&stmt2.type_, scopes, scope_id) as isize; - //FIXME - unreachable!(); - //match stmt.block.scope.symbol_table.find_mut(&stmt2.name).unwrap() { - // Symbol::Local(local) => { - // local.offset = Offset(offset); - // } - // _ => unreachable!(), - //}; + match scopes.get_symbol_mut(&stmt2.name, scope_id).unwrap() { + Symbol::Local(local) => { + local.offset = Offset(offset); + } + _ => unreachable!(), + }; } offset = self.populate_offsets(&stmt.block, scopes, scope_id + 1, offset)?; } - Stmt::Return(_) | Stmt::Expr(_) | Stmt::Continue | Stmt::Break => (), - Stmt::Function(_) => unreachable!(), + Stmt::Return(_) | Stmt::Expr(_) | Stmt::Continue | Stmt::Break | Stmt::Item(_) => { + () + } } } diff --git a/src/codegen/codegen.rs b/src/codegen/codegen.rs index 137d09c..f4f6136 100644 --- a/src/codegen/codegen.rs +++ b/src/codegen/codegen.rs @@ -7,7 +7,7 @@ use crate::{ parser::{ BinOp, BitwiseOp, CmpOp, Expr, ExprArray, ExprArrayAccess, ExprBinary, ExprFunctionCall, ExprIdent, ExprLit, ExprStruct, ExprStructAccess, ExprStructMethod, ExprUnary, Expression, - Stmt, StmtFor, StmtFunction, StmtIf, StmtReturn, StmtVarDecl, StmtWhile, UnOp, + Item, ItemFn, ItemVariable, Stmt, StmtFor, StmtIf, StmtReturn, StmtWhile, UnOp, }, register::Register, scope::Scopes, @@ -45,7 +45,7 @@ impl CodeGen { } } - fn declare(&mut self, variable: StmtVarDecl) -> Result<(), CodeGenError> { + fn declare(&mut self, variable: ItemVariable) -> Result<(), CodeGenError> { if self.scope_id == 0 { self.arch.declare( &variable.name, @@ -68,7 +68,9 @@ impl CodeGen { Ok(()) } - fn function(&mut self, func: StmtFunction) -> Result<(), CodeGenError> { + fn function(&mut self, func: ItemFn) -> Result<(), CodeGenError> { + self.scope_id += 1; + let offset = self .arch .populate_offsets( @@ -80,11 +82,10 @@ impl CodeGen { .unsigned_abs() .next_multiple_of(self.arch.stack_alignment()); - self.scope_id += 1; - self.arch.fn_preamble( - &func.name, + &func.signature.name, &func + .signature .params .iter() .map(|(_, type_)| type_.to_owned()) @@ -94,7 +95,7 @@ impl CodeGen { self.scope_id, )?; self.scope_infos.push(ScopeInfo::Function { - label: func.name.clone(), + label: func.signature.name.clone(), }); for stmt in func.block.0 { @@ -102,7 +103,7 @@ impl CodeGen { } self.scope_infos.pop(); - self.arch.fn_postamble(&func.name, offset); + self.arch.fn_postamble(&func.signature.name, offset); Ok(()) } @@ -724,9 +725,8 @@ impl CodeGen { fn stmt(&mut self, stmt: Stmt) -> Result<(), CodeGenError> { match stmt { + Stmt::Item(item) => self.item(item), Stmt::Expr(expr) => self.expr(expr, None, None).map(|_| ()), - Stmt::VarDecl(var_decl) => self.declare(var_decl), - Stmt::Function(func) => self.function(func), Stmt::Return(ret) => self.ret(ret), Stmt::If(stmt) => self.if_stmt(stmt), Stmt::While(stmt) => self.while_stmt(stmt), @@ -1127,9 +1127,18 @@ impl CodeGen { }) } - pub fn compile(&mut self, program: Vec) -> Result, CodeGenError> { - for stmt in program { - self.stmt(stmt)?; + pub fn item(&mut self, item: Item) -> Result<(), CodeGenError> { + Ok(match item { + Item::Variable(item) => self.declare(item)?, + Item::Fn(item) => self.function(item)?, + Item::Struct(_) => todo!(), + Item::ForeignFn(_) => (), + }) + } + + pub fn compile(&mut self, program: Vec) -> Result, CodeGenError> { + for item in program { + self.item(item)?; } Ok(self.arch.finish()) diff --git a/src/compile.rs b/src/compile.rs index 5f042a3..a136a49 100644 --- a/src/compile.rs +++ b/src/compile.rs @@ -3,7 +3,7 @@ use crate::{ codegen::CodeGen, lexer::Lexer, parser, - passes::{MacroExpansion, SymbolResolver, TypeChecker}, + passes::{SymbolResolver, TypeChecker}, }; use clap::Parser; use std::{ @@ -46,16 +46,17 @@ pub fn compile(args: CompileArgs) -> Result<(), Box> { file.read_to_string(&mut source_code)?; let lexer = Lexer::new(source_code); - let mut stmts = parser::Parser::new(lexer)?.parse()?; + let mut items = parser::Parser::new(lexer)?.parse()?; + //MacroExpansion::new(args.macro_libs).expand(&mut items); + let scopes = SymbolResolver::new().run_pass(&mut items)?; - MacroExpansion::new(args.macro_libs).expand(&mut stmts); - let scopes = SymbolResolver::new().run_pass(&mut stmts)?; - TypeChecker::new(&scopes).run_pass(&stmts)?; - - dbg!(&stmts); dbg!(&scopes); - let code = CodeGen::new(Box::new(Amd64::new()), scopes).compile(stmts)?; + TypeChecker::new(&scopes).run_pass(&items)?; + + dbg!(&items); + + let code = CodeGen::new(Box::new(Amd64::new()), scopes).compile(items)?; if args.assembly_only { let asm_filename = args.file.with_extension("s"); diff --git a/src/parser/item.rs b/src/parser/item.rs new file mode 100644 index 0000000..ad45f27 --- /dev/null +++ b/src/parser/item.rs @@ -0,0 +1,53 @@ +use super::{Block, Expr}; +use crate::types::Type; + +#[derive(Debug, Clone, PartialEq)] +pub enum Item { + Variable(ItemVariable), + Struct(ItemStruct), + Fn(ItemFn), + ForeignFn(ItemForeignFn), +} + +#[derive(Debug, Clone, PartialEq)] +pub struct ItemVariable { + pub type_: Type, + pub name: String, + pub value: Option, +} + +#[derive(Debug, Clone, PartialEq)] +pub struct ItemStruct { + pub name: String, + pub fields: Vec<(String, Type)>, + pub methods: Vec, +} + +#[derive(Debug, Clone, PartialEq)] +pub enum MethodKind { + Static, + Instance, +} + +#[derive(Debug, Clone, PartialEq)] +pub struct Method { + pub kind: MethodKind, + pub signature: Signature, + pub block: Block, +} + +#[derive(Debug, Clone, PartialEq)] +pub struct Signature { + pub name: String, + pub params: Vec<(String, Type)>, + pub return_type: Type, +} + +#[derive(Debug, Clone, PartialEq)] +pub struct ItemFn { + pub signature: Signature, + pub block: Block, +} + +#[derive(Debug, Clone, PartialEq)] +pub struct ItemForeignFn(pub Signature); diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 0ac61b7..f7d6ad5 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -1,13 +1,19 @@ mod error; -pub mod expr; +mod item; mod op; mod parser; mod precedence; mod stmt; +pub mod expr; + pub use error::ParserError; pub use expr::*; +pub use item::{Item, ItemFn, ItemForeignFn, ItemStruct, ItemVariable, Signature}; pub use op::{BinOp, BitwiseOp, CmpOp, OpParseError, UnOp}; pub use parser::Parser; pub use precedence::Precedence; -pub use stmt::{Block, Stmt, StmtFor, StmtFunction, StmtIf, StmtReturn, StmtVarDecl, StmtWhile}; +pub use stmt::{Stmt, StmtFor, StmtIf, StmtReturn, StmtWhile}; + +#[derive(Debug, Clone, PartialEq)] +pub struct Block(pub Vec); diff --git a/src/parser/parser.rs b/src/parser/parser.rs index 283509e..784d972 100644 --- a/src/parser/parser.rs +++ b/src/parser/parser.rs @@ -1,14 +1,14 @@ use super::{ expr::{ExprBinary, ExprLit, ExprUnary}, + item::{ItemForeignFn, ItemStruct, ItemVariable, Method, MethodKind, Signature}, precedence::Precedence, stmt::{StmtFor, StmtIf, StmtReturn, StmtWhile}, BinOp, Block, Expr, ExprArray, ExprArrayAccess, ExprCast, ExprIdent, ExprStruct, - ExprStructMethod, MacroCall, ParserError, Stmt, StmtFunction, StmtVarDecl, UIntLitRepr, UnOp, + ExprStructMethod, Item, ItemFn, MacroCall, ParserError, Stmt, UIntLitRepr, UnOp, }; use crate::{ lexer::{LexerError, Token}, parser::{ExprFunctionCall, ExprStructAccess}, - type_table::TypeStructMethod, types::{IntType, Type, TypeArray, TypeError, UintType}, }; use std::collections::HashMap; @@ -20,7 +20,6 @@ pub struct Parser>> { lexer: T, cur_token: Option, peek_token: Option, - global_stms: Vec, prefix_fns: HashMap>, infix_fns: HashMap>, } @@ -31,7 +30,6 @@ impl>> Parser { cur_token: lexer.next().transpose()?, peek_token: lexer.next().transpose()?, lexer, - global_stms: Vec::new(), prefix_fns: HashMap::from([ (Token::Ident(Default::default()), Self::ident as PrefixFn), (Token::String(Default::default()), Self::string_lit), @@ -101,24 +99,18 @@ impl>> Parser { } } - pub fn parse(&mut self) -> Result, ParserError> { + pub fn parse(&mut self) -> Result, ParserError> { let mut stmts = Vec::new(); while let Some(token) = &self.cur_token { - match token { + stmts.push(match token { Token::Struct => self.parse_struct()?, - Token::Let => stmts.push(self.var_decl()?), - Token::Fn => { - if let Some(stmt) = self.function(true)? { - stmts.push(stmt) - } - } + Token::Let => self.var_decl()?, + Token::Fn => self.function(true)?, _ => unreachable!(), - } + }); } - stmts.extend_from_slice(&self.global_stms); - Ok(stmts) } @@ -152,7 +144,7 @@ impl>> Parser { left } - fn parse_struct(&mut self) -> Result<(), ParserError> { + fn parse_struct(&mut self) -> Result { self.expect(&Token::Struct)?; let name = match self @@ -193,17 +185,15 @@ impl>> Parser { let type_ = self.parse_type()?; let block = self.compound_statement()?; - methods.push(TypeStructMethod { - return_type: type_.clone(), - name: method_name.clone(), - params: params.clone(), - }); - self.global_stms.push(Stmt::Function(StmtFunction { - return_type: type_, - name: format!("{name}__{method_name}"), - params, + methods.push(Method { + kind: MethodKind::Instance, + signature: Signature { + name: method_name, + params, + return_type: type_, + }, block, - })); + }) } else { let name = match self.next_token()? { Some(Token::Ident(ident)) => ident, @@ -226,35 +216,39 @@ impl>> Parser { self.expect(&Token::RBrace)?; - Ok(()) + Ok(Item::Struct(ItemStruct { + name, + fields, + methods, + })) } - fn stmt(&mut self) -> Result, ParserError> { + fn stmt(&mut self) -> Result { match self.cur_token.as_ref().unwrap() { - Token::Return => Ok(Some(self.parse_return()?)), - Token::If => Ok(Some(self.if_stmt()?)), - Token::While => Ok(Some(self.while_stmt()?)), - Token::For => Ok(Some(self.for_stmt()?)), - Token::Let => Ok(Some(self.var_decl()?)), + Token::Return => self.parse_return(), + Token::If => self.if_stmt(), + Token::While => self.while_stmt(), + Token::For => self.for_stmt(), + Token::Let => Ok(Stmt::Item(self.var_decl()?)), Token::Continue => { self.expect(&Token::Continue)?; self.expect(&Token::Semicolon)?; - Ok(Some(Stmt::Continue)) + Ok(Stmt::Continue) } Token::Break => { self.expect(&Token::Break)?; self.expect(&Token::Semicolon)?; - Ok(Some(Stmt::Break)) + Ok(Stmt::Break) } - Token::Fn => self.function(true), + Token::Fn => Ok(Stmt::Item(self.function(true)?)), _ => { let expr = Stmt::Expr(self.expr(Precedence::default())?); self.expect(&Token::Semicolon)?; - Ok(Some(expr)) + Ok(expr) } } } @@ -265,9 +259,7 @@ impl>> Parser { self.expect(&Token::LBrace)?; while !self.cur_token_is(&Token::RBrace) { - if let Some(stmt) = self.stmt()? { - stmts.push(stmt); - } + stmts.push(self.stmt()?); } self.expect(&Token::RBrace)?; @@ -280,9 +272,7 @@ impl>> Parser { let mut stmts = Vec::new(); while self.cur_token.is_some() { - if let Some(stmt) = self.stmt()? { - stmts.push(stmt); - } + stmts.push(self.stmt()?); } Ok(stmts) @@ -388,7 +378,7 @@ impl>> Parser { None } else { let stmt = if self.cur_token_is(&Token::Let) { - self.var_decl()? + Stmt::Item(self.var_decl()?) } else { Stmt::Expr(self.expr(Precedence::default())?) }; @@ -440,7 +430,7 @@ impl>> Parser { Ok(()) } - fn var_decl(&mut self) -> Result { + fn var_decl(&mut self) -> Result { self.expect(&Token::Let)?; let name = match self.next_token()?.unwrap() { @@ -458,7 +448,7 @@ impl>> Parser { self.array_type(&mut type_)?; - let expr = if self.cur_token_is(&Token::Assign) { + let value = if self.cur_token_is(&Token::Assign) { self.expect(&Token::Assign)?; Some(self.expr(Precedence::default())?) @@ -468,10 +458,10 @@ impl>> Parser { self.expect(&Token::Semicolon)?; - Ok(Stmt::VarDecl(StmtVarDecl::new(type_, name, expr))) + Ok(Item::Variable(ItemVariable { name, type_, value })) } - fn function(&mut self, func_definition: bool) -> Result, ParserError> { + fn function(&mut self, func_definition: bool) -> Result { self.expect(&Token::Fn)?; let name = match self.next_token()?.unwrap() { @@ -497,17 +487,18 @@ impl>> Parser { panic!("Function definition is not supported here"); } + let signature = Signature { + name, + params, + return_type: type_, + }; + if let Some(block) = block { - Ok(Some(Stmt::Function(StmtFunction { - return_type: type_, - name, - params, - block, - }))) + Ok(Item::Fn(ItemFn { signature, block })) } else { self.expect(&Token::Semicolon)?; - Ok(None) + Ok(Item::ForeignFn(ItemForeignFn(signature))) } } @@ -787,8 +778,8 @@ mod test { use crate::{ lexer::Lexer, parser::{ - BinOp, Expr, ExprBinary, ExprCast, ExprIdent, ExprLit, ExprUnary, ParserError, Stmt, - StmtVarDecl, UIntLitRepr, UnOp, + item::ItemVariable, BinOp, Expr, ExprBinary, ExprCast, ExprIdent, ExprLit, ExprUnary, + Item, ParserError, Stmt, UIntLitRepr, UnOp, }, types::{IntType, Type, UintType}, }; @@ -831,11 +822,11 @@ mod test { } ", vec![ - Stmt::VarDecl(StmtVarDecl::new( - Type::UInt(UintType::U8), - "foo".to_owned(), - None, - )), + Stmt::Item(Item::Variable(ItemVariable { + name: "foo".to_owned(), + type_: Type::UInt(UintType::U8), + value: None, + })), Stmt::Expr(Expr::Binary(ExprBinary { op: BinOp::Assign, left: Box::new(Expr::Ident(ExprIdent("foo".to_owned()))), @@ -862,16 +853,16 @@ mod test { } ", vec![ - Stmt::VarDecl(StmtVarDecl::new( - Type::UInt(UintType::U8), - "foo".to_owned(), - None, - )), - Stmt::VarDecl(StmtVarDecl::new( - Type::Int(IntType::I8), - "bar".to_owned(), - None, - )), + Stmt::Item(Item::Variable(ItemVariable { + name: "foo".to_owned(), + type_: Type::UInt(UintType::U8), + value: None, + })), + Stmt::Item(Item::Variable(ItemVariable { + name: "bar".to_owned(), + type_: Type::Int(IntType::I8), + value: None, + })), Stmt::Expr(Expr::Binary(ExprBinary { op: BinOp::Assign, left: Box::new(Expr::Ident(ExprIdent("bar".to_owned()))), @@ -919,16 +910,16 @@ mod test { } ", vec![ - Stmt::VarDecl(StmtVarDecl::new( - Type::UInt(UintType::U8), - "a".to_owned(), - None, - )), - Stmt::VarDecl(StmtVarDecl::new( - Type::UInt(UintType::U8), - "b".to_owned(), - None, - )), + Stmt::Item(Item::Variable(ItemVariable { + name: "a".to_owned(), + type_: Type::UInt(UintType::U8), + value: None, + })), + Stmt::Item(Item::Variable(ItemVariable { + name: "b".to_owned(), + type_: Type::UInt(UintType::U8), + value: None, + })), Stmt::Expr(Expr::Binary(ExprBinary { op: BinOp::Assign, left: Box::new(Expr::Ident(ExprIdent("a".to_owned()))), diff --git a/src/parser/stmt.rs b/src/parser/stmt.rs index 14e226f..4a37a6c 100644 --- a/src/parser/stmt.rs +++ b/src/parser/stmt.rs @@ -1,11 +1,9 @@ -use super::Expr; -use crate::types::Type; +use super::{Block, Expr, Item}; #[derive(Debug, Clone, PartialEq)] pub enum Stmt { - VarDecl(StmtVarDecl), + Item(Item), Expr(Expr), - Function(StmtFunction), Return(StmtReturn), If(StmtIf), While(StmtWhile), @@ -14,22 +12,6 @@ pub enum Stmt { Break, } -#[derive(Debug, Clone, PartialEq)] -pub struct Block(pub Vec); - -#[derive(Debug, Clone, PartialEq)] -pub struct StmtVarDecl { - pub type_: Type, - pub name: String, - pub value: Option, -} - -impl StmtVarDecl { - pub fn new(type_: Type, name: String, value: Option) -> Self { - Self { type_, name, value } - } -} - #[derive(Debug, Clone, PartialEq)] pub struct StmtReturn { pub expr: Option, @@ -55,11 +37,3 @@ pub struct StmtFor { pub increment: Option, pub block: Block, } - -#[derive(Debug, Clone, PartialEq)] -pub struct StmtFunction { - pub return_type: Type, - pub name: String, - pub params: Vec<(String, Type)>, - pub block: Block, -} diff --git a/src/passes/macro_expansion.rs b/src/passes/macro_expansion.rs index a92296d..49fc52e 100644 --- a/src/passes/macro_expansion.rs +++ b/src/passes/macro_expansion.rs @@ -1,7 +1,7 @@ use crate::{ lexer::{self, LexerError}, macros::{self, symbol_to_macros, Macro, MacroFn, Slice}, - parser::{Block, Expr, Parser, Precedence, Stmt}, + parser::{Block, Expr, Item, Parser, Precedence, Stmt}, }; use libloading::Library; use std::ffi::OsStr; @@ -22,7 +22,7 @@ impl MacroExpansion { } } - pub fn expand(self, stmts: &mut Vec) { + pub fn expand(self, stmts: &mut Vec) { (0..stmts.len()).for_each(|i| self.check_stmt(stmts, i)); } diff --git a/src/passes/mod.rs b/src/passes/mod.rs index b10778f..2ff10d4 100644 --- a/src/passes/mod.rs +++ b/src/passes/mod.rs @@ -1,9 +1,7 @@ -mod macro_expansion; -mod pass; +//mod macro_expansion; mod symbol_resolver; mod type_checker; -pub use macro_expansion::MacroExpansion; -pub use pass::Pass; +//pub use macro_expansion::MacroExpansion; pub use symbol_resolver::SymbolResolver; pub use type_checker::TypeChecker; diff --git a/src/passes/pass.rs b/src/passes/pass.rs deleted file mode 100644 index a118f30..0000000 --- a/src/passes/pass.rs +++ /dev/null @@ -1,9 +0,0 @@ -use crate::{parser::Stmt, scope::Scope}; - -pub trait Pass { - type Output; - type State; - - fn new(state: Self::State) -> Self; - fn run_pass(self, stmts: &mut Vec, scope: &mut Scope) -> Self::Output; -} diff --git a/src/passes/symbol_resolver.rs b/src/passes/symbol_resolver.rs index cca6761..1c33c01 100644 --- a/src/passes/symbol_resolver.rs +++ b/src/passes/symbol_resolver.rs @@ -1,9 +1,9 @@ use crate::{ codegen::Offset, - parser::{Block, Expr, Expression, ParserError, Stmt}, + parser::{Block, Expr, Expression, Item, ItemFn, ItemForeignFn, ParserError, Stmt}, scope::Scopes, symbol_table::{Symbol, SymbolFunction, SymbolGlobal, SymbolLocal, SymbolParam}, - type_table as tt, + type_table::{self as tt, TypeStruct}, types::{Type, TypeError}, }; @@ -20,62 +20,41 @@ impl SymbolResolver { } } - pub fn run_pass(mut self, stmts: &mut Vec) -> Result { + pub fn run_pass(mut self, items: &mut [Item]) -> Result { self.scopes.enter_new_global(); // Resolve globals - for stmt in &mut *stmts { - match stmt { - Stmt::Function(stmt) => { - stmt.params - .iter() - .enumerate() - .for_each(|(i, (name, type_))| { - self.scopes - .push_symbol( - name.to_owned(), - Symbol::Param(SymbolParam { - preceding: stmt.params[..i] - .iter() - .map(|(_, type_)| type_.to_owned()) - .collect(), - type_: type_.to_owned(), - offset: Offset::default(), - }), - ) - .unwrap(); - }); - - self.scopes.push_symbol( - stmt.name.clone(), - Symbol::Function(SymbolFunction { - return_type: stmt.return_type.clone(), - parameters: stmt - .params - .clone() - .into_iter() - .map(|(_, type_)| type_) - .collect(), - }), - )?; - } - Stmt::VarDecl(stmt) => { - self.scopes.push_symbol( - stmt.name.clone(), - Symbol::Global(SymbolGlobal { - label: stmt.name.clone(), - type_: stmt.type_.clone(), - }), - )?; - } - _ => unreachable!(), - } + for item in &mut *items { + self.remove_item(item)?; } - // Resolve locals - for stmt in stmts { - if let Stmt::Function(stmt) = stmt { - self.resolve_block(&mut stmt.block)?; + for item in items { + if let Item::Fn(item) = item { + let tmp = self.scope_id; + self.scope_id = self.scopes.enter_new_local(self.scope_id); + + item.signature + .params + .iter() + .enumerate() + .for_each(|(i, (name, type_))| { + self.scopes + .push_symbol( + name.to_owned(), + Symbol::Param(SymbolParam { + preceding: item.signature.params[..i] + .iter() + .map(|(_, type_)| type_.to_owned()) + .collect(), + type_: type_.to_owned(), + offset: Offset::default(), + }), + ) + .unwrap(); + }); + + self.resolve_block(&mut item.block)?; + self.scope_id = tmp; } } @@ -83,33 +62,69 @@ impl SymbolResolver { } fn resolve_block(&mut self, block: &Block) -> Result<(), ParserError> { - let tmp = self.scope_id; - self.scope_id = self.scopes.enter_new_local(self.scope_id); - for stmt in &block.0 { self.resolve_stmt(stmt)?; } - self.scope_id = tmp; - Ok(()) } - fn resolve_stmt(&mut self, stmt: &Stmt) -> Result<(), ParserError> { - Ok(match stmt { - Stmt::VarDecl(stmt) => { + fn remove_item(&mut self, item: &Item) -> Result<(), ParserError> { + Ok(match item { + Item::Fn(ItemFn { signature, .. }) | Item::ForeignFn(ItemForeignFn(signature)) => { self.scopes.push_symbol( - stmt.name.clone(), - Symbol::Local(SymbolLocal { - type_: stmt.type_.clone(), - offset: Offset::default(), + signature.name.clone(), + Symbol::Function(SymbolFunction { + return_type: signature.return_type.clone(), + parameters: signature + .params + .clone() + .into_iter() + .map(|(_, type_)| type_) + .collect(), }), )?; } + Item::Variable(item) => { + if self.scope_id > 0 { + self.scopes.push_symbol( + item.name.clone(), + Symbol::Local(SymbolLocal { + offset: Default::default(), + type_: item.type_.clone(), + }), + )?; + } else { + self.scopes.push_symbol( + item.name.clone(), + Symbol::Global(SymbolGlobal { + label: item.name.clone(), + type_: item.type_.clone(), + }), + )?; + } + } + Item::Struct(item) => { + self.scopes.push_type(tt::Type::Struct(TypeStruct { + name: item.name.clone(), + fields: item.fields.clone(), + methods: item + .methods + .clone() + .into_iter() + .map(|method| method.signature) + .collect(), + })); + } + }) + } + + fn resolve_stmt(&mut self, stmt: &Stmt) -> Result<(), ParserError> { + Ok(match stmt { + Stmt::Item(item) => self.remove_item(item)?, Stmt::Expr(expr) => { self.resolve_expr(expr)?; } - Stmt::Function(_) => unreachable!(), Stmt::Return(stmt) => { if let Some(expr) = &stmt.expr { self.resolve_expr(expr)?; @@ -117,15 +132,26 @@ impl SymbolResolver { } Stmt::If(stmt) => { self.resolve_expr(&stmt.condition)?; + + let tmp = self.scope_id; + self.scope_id = self.scopes.enter_new_local(self.scope_id); self.resolve_block(&stmt.consequence)?; + self.scope_id = tmp; if let Some(alternative) = &stmt.alternative { + let tmp = self.scope_id; + self.scope_id = self.scopes.enter_new_local(self.scope_id); self.resolve_block(alternative)?; + self.scope_id = tmp; } } Stmt::While(stmt) => { self.resolve_expr(&stmt.condition)?; + + let tmp = self.scope_id; + self.scope_id = self.scopes.enter_new_local(self.scope_id); self.resolve_block(&stmt.block)?; + self.scope_id = tmp; } Stmt::For(stmt) => { if let Some(initializer) = &stmt.initializer { @@ -140,7 +166,10 @@ impl SymbolResolver { self.resolve_expr(increment)?; } + let tmp = self.scope_id; + self.scope_id = self.scopes.enter_new_local(self.scope_id); self.resolve_block(&stmt.block)?; + self.scope_id = tmp; } Stmt::Continue | Stmt::Break => (), }) diff --git a/src/passes/type_checker.rs b/src/passes/type_checker.rs index 6f1f3b1..4f41075 100644 --- a/src/passes/type_checker.rs +++ b/src/passes/type_checker.rs @@ -1,5 +1,5 @@ use crate::{ - parser::{BinOp, Block, Expr, ExprBinary, Expression, ParserError, Stmt, UnOp}, + parser::{BinOp, Block, Expr, ExprBinary, Expression, Item, ParserError, Stmt, UnOp}, scope::Scopes, symbol_table::SymbolTableError, type_table as tt, @@ -22,16 +22,61 @@ impl<'a> TypeChecker<'a> { } } - pub fn run_pass(mut self, stmts: &[Stmt]) -> Result<(), ParserError> { + pub fn run_pass(mut self, items: &[Item]) -> Result<(), ParserError> { self.check_type_table()?; - for stmt in stmts { - self.check_stmt(stmt)?; + for item in items { + self.check_item(item)?; } Ok(()) } + fn check_item(&mut self, item: &Item) -> Result<(), ParserError> { + Ok(match item { + Item::Fn(item) => { + self.check_type(&item.signature.return_type)?; + self.return_type = Some(item.signature.return_type.clone()); + + for (_, type_) in &item.signature.params { + self.check_type(type_)?; + } + + self.check_block(&item.block)?; + self.return_type = None; + } + Item::Variable(item) => { + self.check_type(&item.type_)?; + + if let Some(expr) = &item.value { + self.check_assign(item.type_.clone(), expr)?; + } + } + Item::Struct(item) => { + for (_, type_) in &item.fields { + self.check_type(type_)?; + } + + for method in &item.methods { + self.check_type(&method.signature.return_type)?; + + for (_, type_) in &method.signature.params { + self.check_type(type_)?; + } + + self.check_block(&method.block)?; + } + } + Item::ForeignFn(item) => { + self.check_type(&item.0.return_type)?; + + for (_, type_) in &item.0.params { + self.check_type(type_)?; + } + } + }) + } + fn check_block(&mut self, block: &Block) -> Result<(), ParserError> { self.scope_id += 1; @@ -44,27 +89,10 @@ impl<'a> TypeChecker<'a> { fn check_stmt(&mut self, stmt: &Stmt) -> Result<(), ParserError> { Ok(match stmt { - Stmt::VarDecl(stmt) => { - self.check_type(&stmt.type_)?; - - if let Some(expr) = &stmt.value { - self.check_assign(stmt.type_.clone(), expr)?; - } - } + Stmt::Item(item) => self.check_item(item)?, Stmt::Expr(expr) => { self.check_expr(expr)?; } - Stmt::Function(stmt) => { - self.check_type(&stmt.return_type)?; - self.return_type = Some(stmt.return_type.clone()); - - for (_, type_) in &stmt.params { - self.check_type(type_)?; - } - - self.return_type = None; - self.check_block(&stmt.block)?; - } Stmt::Return(stmt) => { if let Some(expr) = &stmt.expr { self.check_assign(self.return_type.clone().unwrap(), expr)?; diff --git a/src/type_table.rs b/src/type_table.rs index 505a0d1..e571250 100644 --- a/src/type_table.rs +++ b/src/type_table.rs @@ -1,6 +1,7 @@ use crate::{ archs::Arch, codegen::Offset, + parser::Signature, scope::Scopes, types::{self, TypeError}, }; @@ -14,14 +15,7 @@ pub enum Type { pub struct TypeStruct { pub name: String, pub fields: Vec<(String, types::Type)>, - pub methods: Vec, -} - -#[derive(Debug, Clone, PartialEq)] -pub struct TypeStructMethod { - pub return_type: types::Type, - pub name: String, - pub params: Vec<(String, types::Type)>, + pub methods: Vec, } impl TypeStruct { @@ -54,7 +48,7 @@ impl TypeStruct { .map(|(_, type_)| type_) } - pub fn find_method(&self, name: &str) -> Option<&TypeStructMethod> { + pub fn find_method(&self, name: &str) -> Option<&Signature> { self.methods.iter().find(|method| method.name == name) }