diff --git a/crates/red_knot/src/db.rs b/crates/red_knot/src/db.rs index 7ddb426336c81..cb50307253121 100644 --- a/crates/red_knot/src/db.rs +++ b/crates/red_knot/src/db.rs @@ -1,4 +1,3 @@ -use std::path::Path; use std::sync::Arc; pub use jars::{HasJar, HasJars}; @@ -7,12 +6,12 @@ pub use runtime::DbRuntime; pub use storage::JarsStorage; use crate::files::FileId; -use crate::lint::{Diagnostics, LintSemanticStorage, LintSyntaxStorage}; -use crate::module::{Module, ModuleData, ModuleName, ModuleResolver, ModuleSearchPath}; -use crate::parse::{Parsed, ParsedStorage}; -use crate::source::{Source, SourceStorage}; -use crate::symbols::{SymbolId, SymbolTable, SymbolTablesStorage}; -use crate::types::{Type, TypeStore}; +use crate::lint::{LintSemanticStorage, LintSyntaxStorage}; +use crate::module::ModuleResolver; +use crate::parse::ParsedStorage; +use crate::source::SourceStorage; +use crate::symbols::SymbolTablesStorage; +use crate::types::TypeStore; mod jars; mod query; @@ -61,6 +60,8 @@ pub trait ParallelDatabase: Database + Send { fn snapshot(&self) -> Snapshot; } +pub trait DbWithJar: Database + HasJar {} + /// Readonly snapshot of a database. /// /// ## Dead locks @@ -96,45 +97,24 @@ where } } +pub trait Upcast { + fn upcast(&self) -> &T; +} + // Red knot specific databases code. -pub trait SourceDb: Database { +pub trait SourceDb: DbWithJar { // queries fn file_id(&self, path: &std::path::Path) -> FileId; fn file_path(&self, file_id: FileId) -> Arc; - - fn source(&self, file_id: FileId) -> QueryResult; - - fn parse(&self, file_id: FileId) -> QueryResult; -} - -pub trait SemanticDb: SourceDb { - // queries - fn resolve_module(&self, name: ModuleName) -> QueryResult>; - - fn file_to_module(&self, file_id: FileId) -> QueryResult>; - - fn path_to_module(&self, path: &Path) -> QueryResult>; - - fn symbol_table(&self, file_id: FileId) -> QueryResult>; - - fn infer_symbol_type(&self, file_id: FileId, symbol_id: SymbolId) -> QueryResult; - - // mutations - - fn add_module(&mut self, path: &Path) -> Option<(Module, Vec>)>; - - fn set_module_search_paths(&mut self, paths: Vec); } -pub trait LintDb: SemanticDb { - fn lint_syntax(&self, file_id: FileId) -> QueryResult; +pub trait SemanticDb: SourceDb + DbWithJar + Upcast {} - fn lint_semantic(&self, file_id: FileId) -> QueryResult; -} +pub trait LintDb: SemanticDb + DbWithJar + Upcast {} -pub trait Db: LintDb {} +pub trait Db: LintDb + Upcast {} #[derive(Debug, Default)] pub struct SourceJar { @@ -161,19 +141,10 @@ pub(crate) mod tests { use std::sync::Arc; use crate::db::{ - Database, DbRuntime, HasJar, HasJars, JarsStorage, LintDb, LintJar, QueryResult, SourceDb, - SourceJar, + Database, DbRuntime, DbWithJar, HasJar, HasJars, JarsStorage, LintDb, LintJar, QueryResult, + SourceDb, SourceJar, Upcast, }; use crate::files::{FileId, Files}; - use crate::lint::{lint_semantic, lint_syntax, Diagnostics}; - use crate::module::{ - add_module, file_to_module, path_to_module, resolve_module, set_module_search_paths, - Module, ModuleData, ModuleName, ModuleSearchPath, - }; - use crate::parse::{parse, Parsed}; - use crate::source::{source_text, Source}; - use crate::symbols::{symbol_table, SymbolId, SymbolTable}; - use crate::types::{infer_symbol_type, Type}; use super::{SemanticDb, SemanticJar}; @@ -223,56 +194,36 @@ pub(crate) mod tests { fn file_path(&self, file_id: FileId) -> Arc { self.files.path(file_id) } - - fn source(&self, file_id: FileId) -> QueryResult { - source_text(self, file_id) - } - - fn parse(&self, file_id: FileId) -> QueryResult { - parse(self, file_id) - } } - impl SemanticDb for TestDb { - fn resolve_module(&self, name: ModuleName) -> QueryResult> { - resolve_module(self, name) - } - - fn file_to_module(&self, file_id: FileId) -> QueryResult> { - file_to_module(self, file_id) - } + impl DbWithJar for TestDb {} - fn path_to_module(&self, path: &Path) -> QueryResult> { - path_to_module(self, path) + impl Upcast for TestDb { + fn upcast(&self) -> &(dyn SourceDb + 'static) { + self } + } - fn symbol_table(&self, file_id: FileId) -> QueryResult> { - symbol_table(self, file_id) - } + impl SemanticDb for TestDb {} - fn infer_symbol_type(&self, file_id: FileId, symbol_id: SymbolId) -> QueryResult { - infer_symbol_type(self, file_id, symbol_id) - } + impl DbWithJar for TestDb {} - fn add_module(&mut self, path: &Path) -> Option<(Module, Vec>)> { - add_module(self, path) - } - - fn set_module_search_paths(&mut self, paths: Vec) { - set_module_search_paths(self, paths); + impl Upcast for TestDb { + fn upcast(&self) -> &(dyn SemanticDb + 'static) { + self } } - impl LintDb for TestDb { - fn lint_syntax(&self, file_id: FileId) -> QueryResult { - lint_syntax(self, file_id) - } + impl LintDb for TestDb {} - fn lint_semantic(&self, file_id: FileId) -> QueryResult { - lint_semantic(self, file_id) + impl Upcast for TestDb { + fn upcast(&self) -> &(dyn LintDb + 'static) { + self } } + impl DbWithJar for TestDb {} + impl HasJars for TestDb { type Jars = (SourceJar, SemanticJar, LintJar); diff --git a/crates/red_knot/src/lint.rs b/crates/red_knot/src/lint.rs index c9a55ac18ab84..800287478997e 100644 --- a/crates/red_knot/src/lint.rs +++ b/crates/red_knot/src/lint.rs @@ -7,19 +7,17 @@ use ruff_python_ast::visitor::Visitor; use ruff_python_ast::{ModModule, StringLiteral}; use crate::cache::KeyValueCache; -use crate::db::{HasJar, LintDb, LintJar, QueryResult, SemanticDb}; +use crate::db::{LintDb, LintJar, QueryResult}; use crate::files::FileId; -use crate::parse::Parsed; -use crate::source::Source; -use crate::symbols::{Definition, SymbolId, SymbolTable}; -use crate::types::Type; +use crate::parse::{parse, Parsed}; +use crate::source::{source_text, Source}; +use crate::symbols::{symbol_table, Definition, SymbolId, SymbolTable}; +use crate::types::{infer_symbol_type, Type}; #[tracing::instrument(level = "debug", skip(db))] -pub(crate) fn lint_syntax(db: &Db, file_id: FileId) -> QueryResult -where - Db: LintDb + HasJar, -{ - let storage = &db.jar()?.lint_syntax; +pub(crate) fn lint_syntax(db: &dyn LintDb, file_id: FileId) -> QueryResult { + let lint_jar: &LintJar = db.jar()?; + let storage = &lint_jar.lint_syntax; #[allow(clippy::print_stdout)] if std::env::var("RED_KNOT_SLOW_LINT").is_ok() { @@ -33,10 +31,10 @@ where storage.get(&file_id, |file_id| { let mut diagnostics = Vec::new(); - let source = db.source(*file_id)?; + let source = source_text(db.upcast(), *file_id)?; lint_lines(source.text(), &mut diagnostics); - let parsed = db.parse(*file_id)?; + let parsed = parse(db.upcast(), *file_id)?; if parsed.errors().is_empty() { let ast = parsed.ast(); @@ -73,16 +71,14 @@ fn lint_lines(source: &str, diagnostics: &mut Vec) { } #[tracing::instrument(level = "debug", skip(db))] -pub(crate) fn lint_semantic(db: &Db, file_id: FileId) -> QueryResult -where - Db: LintDb + HasJar, -{ - let storage = &db.jar()?.lint_semantic; +pub(crate) fn lint_semantic(db: &dyn LintDb, file_id: FileId) -> QueryResult { + let lint_jar: &LintJar = db.jar()?; + let storage = &lint_jar.lint_semantic; storage.get(&file_id, |file_id| { - let source = db.source(*file_id)?; - let parsed = db.parse(*file_id)?; - let symbols = db.symbol_table(*file_id)?; + let source = source_text(db.upcast(), *file_id)?; + let parsed = parse(db.upcast(), *file_id)?; + let symbols = symbol_table(db.upcast(), *file_id)?; let context = SemanticLintContext { file_id: *file_id, @@ -145,7 +141,7 @@ pub struct SemanticLintContext<'a> { source: Source, parsed: Parsed, symbols: Arc, - db: &'a dyn SemanticDb, + db: &'a dyn LintDb, diagnostics: RefCell>, } @@ -167,7 +163,7 @@ impl<'a> SemanticLintContext<'a> { } pub fn infer_symbol_type(&self, symbol_id: SymbolId) -> QueryResult { - self.db.infer_symbol_type(self.file_id, symbol_id) + infer_symbol_type(self.db.upcast(), self.file_id, symbol_id) } pub fn push_diagnostic(&self, diagnostic: String) { diff --git a/crates/red_knot/src/main.rs b/crates/red_knot/src/main.rs index 53007bd088d62..5d2c5ff459c95 100644 --- a/crates/red_knot/src/main.rs +++ b/crates/red_knot/src/main.rs @@ -11,8 +11,8 @@ use tracing_subscriber::layer::{Context, Filter, SubscriberExt}; use tracing_subscriber::{Layer, Registry}; use tracing_tree::time::Uptime; -use red_knot::db::{HasJar, ParallelDatabase, QueryError, SemanticDb, SourceDb, SourceJar}; -use red_knot::module::{ModuleSearchPath, ModuleSearchPathKind}; +use red_knot::db::{HasJar, ParallelDatabase, QueryError, SourceDb, SourceJar}; +use red_knot::module::{set_module_search_paths, ModuleSearchPath, ModuleSearchPathKind}; use red_knot::program::check::ExecutionMode; use red_knot::program::{FileWatcherChange, Program}; use red_knot::watch::FileWatcher; @@ -49,7 +49,7 @@ fn main() -> anyhow::Result<()> { ModuleSearchPathKind::FirstParty, ); let mut program = Program::new(workspace); - program.set_module_search_paths(vec![workspace_search_path]); + set_module_search_paths(&mut program, vec![workspace_search_path]); let entry_id = program.file_id(entry_point); program.workspace_mut().open_file(entry_id); diff --git a/crates/red_knot/src/module.rs b/crates/red_knot/src/module.rs index 60cf17e40dd62..f98405055b4d7 100644 --- a/crates/red_knot/src/module.rs +++ b/crates/red_knot/src/module.rs @@ -7,7 +7,7 @@ use std::sync::Arc; use dashmap::mapref::entry::Entry; use smol_str::SmolStr; -use crate::db::{HasJar, QueryResult, SemanticDb, SemanticJar}; +use crate::db::{QueryResult, SemanticDb, SemanticJar}; use crate::files::FileId; use crate::symbols::Dependency; use crate::FxDashMap; @@ -17,41 +17,32 @@ use crate::FxDashMap; pub struct Module(u32); impl Module { - pub fn name(&self, db: &Db) -> QueryResult - where - Db: HasJar, - { - let modules = &db.jar()?.module_resolver; + pub fn name(&self, db: &dyn SemanticDb) -> QueryResult { + let jar: &SemanticJar = db.jar()?; + let modules = &jar.module_resolver; Ok(modules.modules.get(self).unwrap().name.clone()) } - pub fn path(&self, db: &Db) -> QueryResult - where - Db: HasJar, - { - let modules = &db.jar()?.module_resolver; + pub fn path(&self, db: &dyn SemanticDb) -> QueryResult { + let jar: &SemanticJar = db.jar()?; + let modules = &jar.module_resolver; Ok(modules.modules.get(self).unwrap().path.clone()) } - pub fn kind(&self, db: &Db) -> QueryResult - where - Db: HasJar, - { - let modules = &db.jar()?.module_resolver; + pub fn kind(&self, db: &dyn SemanticDb) -> QueryResult { + let jar: &SemanticJar = db.jar()?; + let modules = &jar.module_resolver; Ok(modules.modules.get(self).unwrap().kind) } - pub fn resolve_dependency( + pub fn resolve_dependency( &self, - db: &Db, + db: &dyn SemanticDb, dependency: &Dependency, - ) -> QueryResult> - where - Db: HasJar, - { + ) -> QueryResult> { let (level, module) = match dependency { Dependency::Module(module) => return Ok(Some(module.clone())), Dependency::Relative { level, module } => (*level, module.as_deref()), @@ -244,12 +235,9 @@ pub struct ModuleData { /// TODO: This would not work with Salsa because `ModuleName` isn't an ingredient and, therefore, cannot be used as part of a query. /// For this to work with salsa, it would be necessary to intern all `ModuleName`s. #[tracing::instrument(level = "debug", skip(db))] -pub fn resolve_module(db: &Db, name: ModuleName) -> QueryResult> -where - Db: SemanticDb + HasJar, -{ - let jar = db.jar(); - let modules = &jar?.module_resolver; +pub fn resolve_module(db: &dyn SemanticDb, name: ModuleName) -> QueryResult> { + let jar: &SemanticJar = db.jar()?; + let modules = &jar.module_resolver; let entry = modules.by_name.entry(name.clone()); @@ -303,10 +291,7 @@ where /// /// Returns `None` if the path is not a module in `sys.path`. #[tracing::instrument(level = "debug", skip(db))] -pub fn path_to_module(db: &Db, path: &Path) -> QueryResult> -where - Db: SemanticDb + HasJar, -{ +pub fn path_to_module(db: &dyn SemanticDb, path: &Path) -> QueryResult> { let file = db.file_id(path); file_to_module(db, file) } @@ -315,11 +300,8 @@ where /// /// Returns `None` if the file is not a module in `sys.path`. #[tracing::instrument(level = "debug", skip(db))] -pub fn file_to_module(db: &Db, file: FileId) -> QueryResult> -where - Db: SemanticDb + HasJar, -{ - let jar = db.jar()?; +pub fn file_to_module(db: &dyn SemanticDb, file: FileId) -> QueryResult> { + let jar: &SemanticJar = db.jar()?; let modules = &jar.module_resolver; if let Some(existing) = modules.by_file.get(&file) { @@ -381,11 +363,8 @@ where ////////////////////////////////////////////////////// /// Changes the module search paths to `search_paths`. -pub fn set_module_search_paths(db: &mut Db, search_paths: Vec) -where - Db: SemanticDb + HasJar, -{ - let jar = db.jar_mut(); +pub fn set_module_search_paths(db: &mut dyn SemanticDb, search_paths: Vec) { + let jar: &mut SemanticJar = db.jar_mut(); jar.module_resolver = ModuleResolver::new(search_paths); } @@ -397,10 +376,7 @@ where /// Returns `Some` with the id of the module and the ids of the modules that need re-resolving /// because they were part of a namespace package and might now resolve differently. /// Note: This won't work with salsa because `Path` is not an ingredient. -pub fn add_module(db: &mut Db, path: &Path) -> Option<(Module, Vec>)> -where - Db: SemanticDb + HasJar, -{ +pub fn add_module(db: &mut dyn SemanticDb, path: &Path) -> Option<(Module, Vec>)> { // No locking is required because we're holding a mutable reference to `modules`. // TODO This needs tests @@ -426,7 +402,7 @@ where let mut to_remove = Vec::new(); - let jar = db.jar_mut(); + let jar: &mut SemanticJar = db.jar_mut(); let modules = &mut jar.module_resolver; modules.by_file.retain(|_, id| { @@ -676,11 +652,15 @@ impl PackageKind { #[cfg(test)] mod tests { + use std::num::NonZeroU32; + use crate::db::tests::TestDb; - use crate::db::{SemanticDb, SourceDb}; - use crate::module::{ModuleKind, ModuleName, ModuleSearchPath, ModuleSearchPathKind}; + use crate::db::SourceDb; + use crate::module::{ + path_to_module, resolve_module, set_module_search_paths, ModuleKind, ModuleName, + ModuleSearchPath, ModuleSearchPathKind, + }; use crate::symbols::Dependency; - use std::num::NonZeroU32; struct TestCase { temp_dir: tempfile::TempDir, @@ -708,7 +688,7 @@ mod tests { let roots = vec![src.clone(), site_packages.clone()]; let mut db = TestDb::default(); - db.set_module_search_paths(roots); + set_module_search_paths(&mut db, roots); Ok(TestCase { temp_dir, @@ -730,16 +710,19 @@ mod tests { let foo_path = src.path().join("foo.py"); std::fs::write(&foo_path, "print('Hello, world!')")?; - let foo_module = db.resolve_module(ModuleName::new("foo"))?.unwrap(); + let foo_module = resolve_module(&db, ModuleName::new("foo"))?.unwrap(); - assert_eq!(Some(foo_module), db.resolve_module(ModuleName::new("foo"))?); + assert_eq!( + Some(foo_module), + resolve_module(&db, ModuleName::new("foo"))? + ); assert_eq!(ModuleName::new("foo"), foo_module.name(&db)?); assert_eq!(&src, foo_module.path(&db)?.root()); assert_eq!(ModuleKind::Module, foo_module.kind(&db)?); assert_eq!(&foo_path, &*db.file_path(foo_module.path(&db)?.file())); - assert_eq!(Some(foo_module), db.path_to_module(&foo_path)?); + assert_eq!(Some(foo_module), path_to_module(&db, &foo_path)?); Ok(()) } @@ -758,16 +741,16 @@ mod tests { std::fs::create_dir(&foo_dir)?; std::fs::write(&foo_path, "print('Hello, world!')")?; - let foo_module = db.resolve_module(ModuleName::new("foo"))?.unwrap(); + let foo_module = resolve_module(&db, ModuleName::new("foo"))?.unwrap(); assert_eq!(ModuleName::new("foo"), foo_module.name(&db)?); assert_eq!(&src, foo_module.path(&db)?.root()); assert_eq!(&foo_path, &*db.file_path(foo_module.path(&db)?.file())); - assert_eq!(Some(foo_module), db.path_to_module(&foo_path)?); + assert_eq!(Some(foo_module), path_to_module(&db, &foo_path)?); // Resolving by directory doesn't resolve to the init file. - assert_eq!(None, db.path_to_module(&foo_dir)?); + assert_eq!(None, path_to_module(&db, &foo_dir)?); Ok(()) } @@ -789,14 +772,14 @@ mod tests { let foo_py = src.path().join("foo.py"); std::fs::write(&foo_py, "print('Hello, world!')")?; - let foo_module = db.resolve_module(ModuleName::new("foo"))?.unwrap(); + let foo_module = resolve_module(&db, ModuleName::new("foo"))?.unwrap(); assert_eq!(&src, foo_module.path(&db)?.root()); assert_eq!(&foo_init, &*db.file_path(foo_module.path(&db)?.file())); assert_eq!(ModuleKind::Package, foo_module.kind(&db)?); - assert_eq!(Some(foo_module), db.path_to_module(&foo_init)?); - assert_eq!(None, db.path_to_module(&foo_py)?); + assert_eq!(Some(foo_module), path_to_module(&db, &foo_init)?); + assert_eq!(None, path_to_module(&db, &foo_py)?); Ok(()) } @@ -815,13 +798,13 @@ mod tests { std::fs::write(&foo_stub, "x: int")?; std::fs::write(&foo_py, "print('Hello, world!')")?; - let foo = db.resolve_module(ModuleName::new("foo"))?.unwrap(); + let foo = resolve_module(&db, ModuleName::new("foo"))?.unwrap(); assert_eq!(&src, foo.path(&db)?.root()); assert_eq!(&foo_stub, &*db.file_path(foo.path(&db)?.file())); - assert_eq!(Some(foo), db.path_to_module(&foo_stub)?); - assert_eq!(None, db.path_to_module(&foo_py)?); + assert_eq!(Some(foo), path_to_module(&db, &foo_stub)?); + assert_eq!(None, path_to_module(&db, &foo_py)?); Ok(()) } @@ -844,12 +827,12 @@ mod tests { std::fs::write(bar.join("__init__.py"), "")?; std::fs::write(&baz, "print('Hello, world!')")?; - let baz_module = db.resolve_module(ModuleName::new("foo.bar.baz"))?.unwrap(); + let baz_module = resolve_module(&db, ModuleName::new("foo.bar.baz"))?.unwrap(); assert_eq!(&src, baz_module.path(&db)?.root()); assert_eq!(&baz, &*db.file_path(baz_module.path(&db)?.file())); - assert_eq!(Some(baz_module), db.path_to_module(&baz)?); + assert_eq!(Some(baz_module), path_to_module(&db, &baz)?); Ok(()) } @@ -890,16 +873,12 @@ mod tests { std::fs::create_dir_all(&child2)?; std::fs::write(&two, "print('Hello, world!')")?; - let one_module = db - .resolve_module(ModuleName::new("parent.child.one"))? - .unwrap(); + let one_module = resolve_module(&db, ModuleName::new("parent.child.one"))?.unwrap(); - assert_eq!(Some(one_module), db.path_to_module(&one)?); + assert_eq!(Some(one_module), path_to_module(&db, &one)?); - let two_module = db - .resolve_module(ModuleName::new("parent.child.two"))? - .unwrap(); - assert_eq!(Some(two_module), db.path_to_module(&two)?); + let two_module = resolve_module(&db, ModuleName::new("parent.child.two"))?.unwrap(); + assert_eq!(Some(two_module), path_to_module(&db, &two)?); Ok(()) } @@ -941,15 +920,13 @@ mod tests { std::fs::create_dir_all(&child2)?; std::fs::write(two, "print('Hello, world!')")?; - let one_module = db - .resolve_module(ModuleName::new("parent.child.one"))? - .unwrap(); + let one_module = resolve_module(&db, ModuleName::new("parent.child.one"))?.unwrap(); - assert_eq!(Some(one_module), db.path_to_module(&one)?); + assert_eq!(Some(one_module), path_to_module(&db, &one)?); assert_eq!( None, - db.resolve_module(ModuleName::new("parent.child.two"))? + resolve_module(&db, ModuleName::new("parent.child.two"))? ); Ok(()) } @@ -969,13 +946,13 @@ mod tests { std::fs::write(&foo_src, "")?; std::fs::write(&foo_site_packages, "")?; - let foo_module = db.resolve_module(ModuleName::new("foo"))?.unwrap(); + let foo_module = resolve_module(&db, ModuleName::new("foo"))?.unwrap(); assert_eq!(&src, foo_module.path(&db)?.root()); assert_eq!(&foo_src, &*db.file_path(foo_module.path(&db)?.file())); - assert_eq!(Some(foo_module), db.path_to_module(&foo_src)?); - assert_eq!(None, db.path_to_module(&foo_site_packages)?); + assert_eq!(Some(foo_module), path_to_module(&db, &foo_src)?); + assert_eq!(None, path_to_module(&db, &foo_site_packages)?); Ok(()) } @@ -996,8 +973,8 @@ mod tests { std::fs::write(&foo, "")?; std::os::unix::fs::symlink(&foo, &bar)?; - let foo_module = db.resolve_module(ModuleName::new("foo"))?.unwrap(); - let bar_module = db.resolve_module(ModuleName::new("bar"))?.unwrap(); + let foo_module = resolve_module(&db, ModuleName::new("foo"))?.unwrap(); + let bar_module = resolve_module(&db, ModuleName::new("bar"))?.unwrap(); assert_ne!(foo_module, bar_module); @@ -1010,8 +987,8 @@ mod tests { assert_eq!(foo_module.path(&db)?.file(), bar_module.path(&db)?.file()); assert_eq!(&foo, &*db.file_path(bar_module.path(&db)?.file())); - assert_eq!(Some(foo_module), db.path_to_module(&foo)?); - assert_eq!(Some(bar_module), db.path_to_module(&bar)?); + assert_eq!(Some(foo_module), path_to_module(&db, &foo)?); + assert_eq!(Some(bar_module), path_to_module(&db, &bar)?); Ok(()) } @@ -1033,8 +1010,8 @@ mod tests { std::fs::write(foo_path, "from .bar import test")?; std::fs::write(bar_path, "test = 'Hello world'")?; - let foo_module = db.resolve_module(ModuleName::new("foo"))?.unwrap(); - let bar_module = db.resolve_module(ModuleName::new("foo.bar"))?.unwrap(); + let foo_module = resolve_module(&db, ModuleName::new("foo"))?.unwrap(); + let bar_module = resolve_module(&db, ModuleName::new("foo.bar"))?.unwrap(); // `from . import bar` in `foo/__init__.py` resolves to `foo` assert_eq!( diff --git a/crates/red_knot/src/parse.rs b/crates/red_knot/src/parse.rs index e76cd06706c13..6856315dcb494 100644 --- a/crates/red_knot/src/parse.rs +++ b/crates/red_knot/src/parse.rs @@ -6,8 +6,9 @@ use ruff_python_parser::{Mode, ParseError}; use ruff_text_size::{Ranged, TextRange}; use crate::cache::KeyValueCache; -use crate::db::{HasJar, QueryResult, SourceDb, SourceJar}; +use crate::db::{QueryResult, SourceDb}; use crate::files::FileId; +use crate::source::source_text; #[derive(Debug, Clone, PartialEq)] pub struct Parsed { @@ -64,14 +65,11 @@ impl Parsed { } #[tracing::instrument(level = "debug", skip(db))] -pub(crate) fn parse(db: &Db, file_id: FileId) -> QueryResult -where - Db: SourceDb + HasJar, -{ - let parsed = db.jar()?; - - parsed.parsed.get(&file_id, |file_id| { - let source = db.source(*file_id)?; +pub(crate) fn parse(db: &dyn SourceDb, file_id: FileId) -> QueryResult { + let jar = db.jar()?; + + jar.parsed.get(&file_id, |file_id| { + let source = source_text(db, *file_id)?; Ok(Parsed::from_text(source.text())) }) diff --git a/crates/red_knot/src/program/check.rs b/crates/red_knot/src/program/check.rs index 1b60893c619cb..8e116bec5f32d 100644 --- a/crates/red_knot/src/program/check.rs +++ b/crates/red_knot/src/program/check.rs @@ -1,11 +1,12 @@ use rayon::{current_num_threads, yield_local}; use rustc_hash::FxHashSet; -use crate::db::{Database, LintDb, QueryError, QueryResult, SemanticDb}; +use crate::db::{Database, QueryError, QueryResult}; use crate::files::FileId; -use crate::lint::Diagnostics; +use crate::lint::{lint_semantic, lint_syntax, Diagnostics}; +use crate::module::{file_to_module, resolve_module}; use crate::program::Program; -use crate::symbols::Dependency; +use crate::symbols::{symbol_table, Dependency}; impl Program { /// Checks all open files in the workspace and its dependencies. @@ -27,11 +28,11 @@ impl Program { fn check_file(&self, file: FileId, context: &CheckFileContext) -> QueryResult { self.cancelled()?; - let symbol_table = self.symbol_table(file)?; + let symbol_table = symbol_table(self, file)?; let dependencies = symbol_table.dependencies(); if !dependencies.is_empty() { - let module = self.file_to_module(file)?; + let module = file_to_module(self, file)?; // TODO scheduling all dependencies here is wasteful if we don't infer any types on them // but I think that's unlikely, so it is okay? @@ -50,7 +51,7 @@ impl Program { // TODO We may want to have a different check functions for non-first-party // files because we only need to index them and not check them. // Supporting non-first-party code also requires supporting typing stubs. - if let Some(dependency) = self.resolve_module(dependency_name)? { + if let Some(dependency) = resolve_module(self, dependency_name)? { if dependency.path(self)?.root().kind().is_first_party() { context.schedule_dependency(dependency.path(self)?.file()); } @@ -62,8 +63,8 @@ impl Program { let mut diagnostics = Vec::new(); if self.workspace().is_file_open(file) { - diagnostics.extend_from_slice(&self.lint_syntax(file)?); - diagnostics.extend_from_slice(&self.lint_semantic(file)?); + diagnostics.extend_from_slice(&lint_syntax(self, file)?); + diagnostics.extend_from_slice(&lint_semantic(self, file)?); } Ok(Diagnostics::from(diagnostics)) diff --git a/crates/red_knot/src/program/mod.rs b/crates/red_knot/src/program/mod.rs index 4f7c538ea639f..4650b65967d06 100644 --- a/crates/red_knot/src/program/mod.rs +++ b/crates/red_knot/src/program/mod.rs @@ -5,19 +5,10 @@ use std::sync::Arc; use rustc_hash::FxHashMap; use crate::db::{ - Database, Db, DbRuntime, HasJar, HasJars, JarsStorage, LintDb, LintJar, ParallelDatabase, - QueryResult, SemanticDb, SemanticJar, Snapshot, SourceDb, SourceJar, + Database, Db, DbRuntime, DbWithJar, HasJar, HasJars, JarsStorage, LintDb, LintJar, + ParallelDatabase, QueryResult, SemanticDb, SemanticJar, Snapshot, SourceDb, SourceJar, Upcast, }; use crate::files::{FileId, Files}; -use crate::lint::{lint_semantic, lint_syntax, Diagnostics}; -use crate::module::{ - add_module, file_to_module, path_to_module, resolve_module, set_module_search_paths, Module, - ModuleData, ModuleName, ModuleSearchPath, -}; -use crate::parse::{parse, Parsed}; -use crate::source::{source_text, Source}; -use crate::symbols::{symbol_table, SymbolId, SymbolTable}; -use crate::types::{infer_symbol_type, Type}; use crate::Workspace; pub mod check; @@ -83,54 +74,33 @@ impl SourceDb for Program { fn file_path(&self, file_id: FileId) -> Arc { self.files.path(file_id) } - - fn source(&self, file_id: FileId) -> QueryResult { - source_text(self, file_id) - } - - fn parse(&self, file_id: FileId) -> QueryResult { - parse(self, file_id) - } } -impl SemanticDb for Program { - fn resolve_module(&self, name: ModuleName) -> QueryResult> { - resolve_module(self, name) - } +impl DbWithJar for Program {} - fn file_to_module(&self, file_id: FileId) -> QueryResult> { - file_to_module(self, file_id) - } +impl SemanticDb for Program {} - fn path_to_module(&self, path: &Path) -> QueryResult> { - path_to_module(self, path) - } - - fn symbol_table(&self, file_id: FileId) -> QueryResult> { - symbol_table(self, file_id) - } +impl DbWithJar for Program {} - fn infer_symbol_type(&self, file_id: FileId, symbol_id: SymbolId) -> QueryResult { - infer_symbol_type(self, file_id, symbol_id) - } +impl LintDb for Program {} - // Mutations - fn add_module(&mut self, path: &Path) -> Option<(Module, Vec>)> { - add_module(self, path) - } +impl DbWithJar for Program {} - fn set_module_search_paths(&mut self, paths: Vec) { - set_module_search_paths(self, paths); +impl Upcast for Program { + fn upcast(&self) -> &(dyn SemanticDb + 'static) { + self } } -impl LintDb for Program { - fn lint_syntax(&self, file_id: FileId) -> QueryResult { - lint_syntax(self, file_id) +impl Upcast for Program { + fn upcast(&self) -> &(dyn SourceDb + 'static) { + self } +} - fn lint_semantic(&self, file_id: FileId) -> QueryResult { - lint_semantic(self, file_id) +impl Upcast for Program { + fn upcast(&self) -> &(dyn LintDb + 'static) { + self } } diff --git a/crates/red_knot/src/source.rs b/crates/red_knot/src/source.rs index 69092d684453f..6746b90f81b7c 100644 --- a/crates/red_knot/src/source.rs +++ b/crates/red_knot/src/source.rs @@ -1,18 +1,17 @@ -use crate::cache::KeyValueCache; -use crate::db::{HasJar, QueryResult, SourceDb, SourceJar}; -use ruff_notebook::Notebook; -use ruff_python_ast::PySourceType; use std::ops::{Deref, DerefMut}; use std::sync::Arc; +use ruff_notebook::Notebook; +use ruff_python_ast::PySourceType; + +use crate::cache::KeyValueCache; +use crate::db::{QueryResult, SourceDb}; use crate::files::FileId; #[tracing::instrument(level = "debug", skip(db))] -pub(crate) fn source_text(db: &Db, file_id: FileId) -> QueryResult -where - Db: SourceDb + HasJar, -{ - let sources = &db.jar()?.sources; +pub(crate) fn source_text(db: &dyn SourceDb, file_id: FileId) -> QueryResult { + let jar = db.jar()?; + let sources = &jar.sources; sources.get(&file_id, |file_id| { let path = db.file_path(*file_id); diff --git a/crates/red_knot/src/symbols.rs b/crates/red_knot/src/symbols.rs index be2ae264cae89..5f00f0d45db94 100644 --- a/crates/red_knot/src/symbols.rs +++ b/crates/red_knot/src/symbols.rs @@ -16,21 +16,19 @@ use ruff_python_ast::visitor::preorder::PreorderVisitor; use crate::ast_ids::TypedNodeKey; use crate::cache::KeyValueCache; -use crate::db::{HasJar, QueryResult, SemanticDb, SemanticJar}; +use crate::db::{QueryResult, SemanticDb, SemanticJar}; use crate::files::FileId; use crate::module::ModuleName; +use crate::parse::parse; use crate::Name; #[allow(unreachable_pub)] #[tracing::instrument(level = "debug", skip(db))] -pub fn symbol_table(db: &Db, file_id: FileId) -> QueryResult> -where - Db: SemanticDb + HasJar, -{ - let jar = db.jar()?; +pub fn symbol_table(db: &dyn SemanticDb, file_id: FileId) -> QueryResult> { + let jar: &SemanticJar = db.jar()?; jar.symbol_tables.get(&file_id, |_| { - let parsed = db.parse(file_id)?; + let parsed = parse(db.upcast(), file_id)?; Ok(Arc::from(SymbolTable::from_ast(parsed.ast()))) }) } diff --git a/crates/red_knot/src/types.rs b/crates/red_knot/src/types.rs index 871c43092bb4c..c10583757b959 100644 --- a/crates/red_knot/src/types.rs +++ b/crates/red_knot/src/types.rs @@ -1,16 +1,18 @@ #![allow(dead_code)] + +use rustc_hash::FxHashMap; + +pub(crate) use infer::infer_symbol_type; +use ruff_index::{newtype_index, IndexVec}; + use crate::ast_ids::NodeKey; -use crate::db::{HasJar, QueryResult, SemanticDb, SemanticJar}; +use crate::db::{QueryResult, SemanticDb, SemanticJar}; use crate::files::FileId; -use crate::symbols::{ScopeId, SymbolId}; +use crate::symbols::{symbol_table, ScopeId, SymbolId}; use crate::{FxDashMap, FxIndexSet, Name}; -use ruff_index::{newtype_index, IndexVec}; -use rustc_hash::FxHashMap; pub(crate) mod infer; -pub(crate) use infer::infer_symbol_type; - /// unique ID for a type #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] pub enum Type { @@ -262,15 +264,14 @@ pub struct ClassTypeId { } impl ClassTypeId { - fn get_own_class_member(self, db: &Db, name: &Name) -> QueryResult> - where - Db: SemanticDb + HasJar, - { + fn get_own_class_member(self, db: &dyn SemanticDb, name: &Name) -> QueryResult> { + let jar: &SemanticJar = db.jar()?; + // TODO: this should distinguish instance-only members (e.g. `x: int`) and not return them - let ClassType { scope_id, .. } = *db.jar()?.type_store.get_class(self); - let table = db.symbol_table(self.file_id)?; + let ClassType { scope_id, .. } = *jar.type_store.get_class(self); + let table = symbol_table(db, self.file_id)?; if let Some(symbol_id) = table.symbol_id_by_name(scope_id, name) { - Ok(Some(db.infer_symbol_type(self.file_id, symbol_id)?)) + Ok(Some(infer_symbol_type(db, self.file_id, symbol_id)?)) } else { Ok(None) } @@ -526,11 +527,12 @@ impl IntersectionType { #[cfg(test)] mod tests { + use std::path::Path; + use crate::files::Files; use crate::symbols::SymbolTable; use crate::types::{Type, TypeStore}; use crate::FxIndexSet; - use std::path::Path; #[test] fn add_class() { diff --git a/crates/red_knot/src/types/infer.rs b/crates/red_knot/src/types/infer.rs index efdd3a484cb6a..43ab0b131f70a 100644 --- a/crates/red_knot/src/types/infer.rs +++ b/crates/red_knot/src/types/infer.rs @@ -1,34 +1,35 @@ #![allow(dead_code)] +use ruff_python_ast as ast; use ruff_python_ast::AstNode; -use crate::db::{HasJar, QueryResult, SemanticDb, SemanticJar}; -use crate::module::ModuleName; -use crate::symbols::{ClassDefinition, Definition, ImportFromDefinition, SymbolId}; +use crate::db::{QueryResult, SemanticDb, SemanticJar}; + +use crate::module::{resolve_module, ModuleName}; +use crate::parse::parse; +use crate::symbols::{symbol_table, ClassDefinition, Definition, ImportFromDefinition, SymbolId}; use crate::types::Type; use crate::FileId; -use ruff_python_ast as ast; // FIXME: Figure out proper dead-lock free synchronisation now that this takes `&db` instead of `&mut db`. #[tracing::instrument(level = "trace", skip(db))] -pub fn infer_symbol_type(db: &Db, file_id: FileId, symbol_id: SymbolId) -> QueryResult -where - Db: SemanticDb + HasJar, -{ - let symbols = db.symbol_table(file_id)?; +pub fn infer_symbol_type( + db: &dyn SemanticDb, + file_id: FileId, + symbol_id: SymbolId, +) -> QueryResult { + let symbols = symbol_table(db, file_id)?; let defs = symbols.definitions(symbol_id); - if let Some(ty) = db - .jar()? - .type_store - .get_cached_symbol_type(file_id, symbol_id) - { + let jar: &SemanticJar = db.jar()?; + let type_store = &jar.type_store; + + if let Some(ty) = type_store.get_cached_symbol_type(file_id, symbol_id) { return Ok(ty); } // TODO handle multiple defs, conditional defs... assert_eq!(defs.len(), 1); - let type_store = &db.jar()?.type_store; let ty = match &defs[0] { Definition::ImportFrom(ImportFromDefinition { @@ -39,11 +40,11 @@ where // TODO relative imports assert!(matches!(level, 0)); let module_name = ModuleName::new(module.as_ref().expect("TODO relative imports")); - if let Some(module) = db.resolve_module(module_name)? { + if let Some(module) = resolve_module(db, module_name)? { let remote_file_id = module.path(db)?.file(); - let remote_symbols = db.symbol_table(remote_file_id)?; + let remote_symbols = symbol_table(db, remote_file_id)?; if let Some(remote_symbol_id) = remote_symbols.root_symbol_id_by_name(name) { - db.infer_symbol_type(remote_file_id, remote_symbol_id)? + infer_symbol_type(db, remote_file_id, remote_symbol_id)? } else { Type::Unknown } @@ -55,7 +56,7 @@ where if let Some(ty) = type_store.get_cached_node_type(file_id, node_key.erased()) { ty } else { - let parsed = db.parse(file_id)?; + let parsed = parse(db.upcast(), file_id)?; let ast = parsed.ast(); let node = node_key.resolve_unwrap(ast.as_any_node_ref()); @@ -75,7 +76,7 @@ where if let Some(ty) = type_store.get_cached_node_type(file_id, node_key.erased()) { ty } else { - let parsed = db.parse(file_id)?; + let parsed = parse(db.upcast(), file_id)?; let ast = parsed.ast(); let node = node_key .resolve(ast.as_any_node_ref()) @@ -95,7 +96,7 @@ where } } Definition::Assignment(node_key) => { - let parsed = db.parse(file_id)?; + let parsed = parse(db.upcast(), file_id)?; let ast = parsed.ast(); let node = node_key.resolve_unwrap(ast.as_any_node_ref()); // TODO handle unpacking assignment correctly @@ -110,16 +111,13 @@ where Ok(ty) } -fn infer_expr_type(db: &Db, file_id: FileId, expr: &ast::Expr) -> QueryResult -where - Db: SemanticDb + HasJar, -{ +fn infer_expr_type(db: &dyn SemanticDb, file_id: FileId, expr: &ast::Expr) -> QueryResult { // TODO cache the resolution of the type on the node - let symbols = db.symbol_table(file_id)?; + let symbols = symbol_table(db, file_id)?; match expr { ast::Expr::Name(name) => { if let Some(symbol_id) = symbols.root_symbol_id_by_name(&name.id) { - db.infer_symbol_type(file_id, symbol_id) + infer_symbol_type(db, file_id, symbol_id) } else { Ok(Type::Unknown) } @@ -131,9 +129,12 @@ where #[cfg(test)] mod tests { use crate::db::tests::TestDb; - use crate::db::{HasJar, SemanticDb, SemanticJar}; - use crate::module::{ModuleName, ModuleSearchPath, ModuleSearchPathKind}; - use crate::types::Type; + use crate::db::{HasJar, SemanticJar}; + use crate::module::{ + resolve_module, set_module_search_paths, ModuleName, ModuleSearchPath, ModuleSearchPathKind, + }; + use crate::symbols::symbol_table; + use crate::types::{infer_symbol_type, Type}; use crate::Name; // TODO with virtual filesystem we shouldn't have to write files to disk for these @@ -156,7 +157,7 @@ mod tests { let roots = vec![src.clone()]; let mut db = TestDb::default(); - db.set_module_search_paths(roots); + set_module_search_paths(&mut db, roots); Ok(TestCase { temp_dir, db, src }) } @@ -170,17 +171,16 @@ mod tests { let b_path = case.src.path().join("b.py"); std::fs::write(a_path, "from b import C as D; E = D")?; std::fs::write(b_path, "class C: pass")?; - let a_file = db - .resolve_module(ModuleName::new("a"))? + let a_file = resolve_module(db, ModuleName::new("a"))? .expect("module should be found") .path(db)? .file(); - let a_syms = db.symbol_table(a_file)?; + let a_syms = symbol_table(db, a_file)?; let e_sym = a_syms .root_symbol_id_by_name("E") .expect("E symbol should be found"); - let ty = db.infer_symbol_type(a_file, e_sym)?; + let ty = infer_symbol_type(db, a_file, e_sym)?; let jar = HasJar::::jar(db)?; assert!(matches!(ty, Type::Class(_))); @@ -196,17 +196,16 @@ mod tests { let path = case.src.path().join("mod.py"); std::fs::write(path, "class Base: pass\nclass Sub(Base): pass")?; - let file = db - .resolve_module(ModuleName::new("mod"))? + let file = resolve_module(db, ModuleName::new("mod"))? .expect("module should be found") .path(db)? .file(); - let syms = db.symbol_table(file)?; + let syms = symbol_table(db, file)?; let sym = syms .root_symbol_id_by_name("Sub") .expect("Sub symbol should be found"); - let ty = db.infer_symbol_type(file, sym)?; + let ty = infer_symbol_type(db, file, sym)?; let Type::Class(class_id) = ty else { panic!("Sub is not a Class") @@ -232,17 +231,16 @@ mod tests { let path = case.src.path().join("mod.py"); std::fs::write(path, "class C:\n def f(self): pass")?; - let file = db - .resolve_module(ModuleName::new("mod"))? + let file = resolve_module(db, ModuleName::new("mod"))? .expect("module should be found") .path(db)? .file(); - let syms = db.symbol_table(file)?; + let syms = symbol_table(db, file)?; let sym = syms .root_symbol_id_by_name("C") .expect("C symbol should be found"); - let ty = db.infer_symbol_type(file, sym)?; + let ty = infer_symbol_type(db, file, sym)?; let Type::Class(class_id) = ty else { panic!("C is not a Class");