diff --git a/crates/ruff_python_ast/src/helpers.rs b/crates/ruff_python_ast/src/helpers.rs index 0dfad2d123ea9..e9da97b63d322 100644 --- a/crates/ruff_python_ast/src/helpers.rs +++ b/crates/ruff_python_ast/src/helpers.rs @@ -8,6 +8,7 @@ use rustc_hash::FxHashMap; use rustpython_ast::CmpOp; use rustpython_parser::ast::{ self, Arguments, Constant, ExceptHandler, Expr, Keyword, MatchCase, Pattern, Ranged, Stmt, + TypeParam, }; use rustpython_parser::{lexer, Mode, Tok}; use smallvec::SmallVec; @@ -265,6 +266,19 @@ where } } +pub fn any_over_type_param(type_param: &TypeParam, func: &F) -> bool +where + F: Fn(&Expr) -> bool, +{ + match type_param { + TypeParam::TypeVar(ast::TypeParamTypeVar { bound, .. }) => bound + .as_ref() + .map_or(false, |value| any_over_expr(value, func)), + TypeParam::TypeVarTuple(ast::TypeParamTypeVarTuple { .. }) => false, + TypeParam::ParamSpec(ast::TypeParamParamSpec { .. }) => false, + } +} + pub fn any_over_pattern(pattern: &Pattern, func: &F) -> bool where F: Fn(&Expr) -> bool, @@ -391,6 +405,18 @@ where targets, range: _range, }) => targets.iter().any(|expr| any_over_expr(expr, func)), + Stmt::TypeAlias(ast::StmtTypeAlias { + name, + type_params, + value, + .. + }) => { + any_over_expr(name, func) + || type_params + .iter() + .any(|type_param| any_over_type_param(type_param, func)) + || any_over_expr(value, func) + } Stmt::Assign(ast::StmtAssign { targets, value, .. }) => { targets.iter().any(|expr| any_over_expr(expr, func)) || any_over_expr(value, func) } @@ -539,7 +565,6 @@ where range: _range, }) => any_over_expr(value, func), Stmt::Pass(_) | Stmt::Break(_) | Stmt::Continue(_) => false, - Stmt::TypeAlias(_) => todo!(), } } @@ -1564,15 +1589,22 @@ pub fn locate_cmp_ops(expr: &Expr, locator: &Locator) -> Vec { mod tests { use std::borrow::Cow; + use std::cell::RefCell; + + use std::vec; + use anyhow::Result; use ruff_text_size::{TextLen, TextRange, TextSize}; - use rustpython_ast::{CmpOp, Expr, Ranged}; + use rustpython_ast::{ + self, CmpOp, Constant, Expr, ExprConstant, ExprContext, ExprName, Identifier, Ranged, Stmt, + StmtTypeAlias, TypeParam, TypeParamParamSpec, TypeParamTypeVar, TypeParamTypeVarTuple, + }; use rustpython_parser::ast::Suite; use rustpython_parser::Parse; use crate::helpers::{ - first_colon_range, has_trailing_content, locate_cmp_ops, resolve_imported_module_path, - LocatedCmpOp, + any_over_stmt, any_over_type_param, first_colon_range, has_trailing_content, + locate_cmp_ops, resolve_imported_module_path, LocatedCmpOp, }; use crate::source_code::Locator; @@ -1746,4 +1778,109 @@ y = 2 Ok(()) } + + #[test] + fn any_over_stmt_type_alias() { + let seen = RefCell::new(Vec::new()); + let name = Expr::Name(ExprName { + id: "x".to_string(), + range: TextRange::default(), + ctx: ExprContext::Load, + }); + let constant_one = Expr::Constant(ExprConstant { + value: Constant::Int(1.into()), + kind: Some("x".to_string()), + range: TextRange::default(), + }); + let constant_two = Expr::Constant(ExprConstant { + value: Constant::Int(2.into()), + kind: Some("y".to_string()), + range: TextRange::default(), + }); + let constant_three = Expr::Constant(ExprConstant { + value: Constant::Int(3.into()), + kind: Some("z".to_string()), + range: TextRange::default(), + }); + let type_var_one = TypeParam::TypeVar(TypeParamTypeVar { + range: TextRange::default(), + bound: Some(Box::new(constant_one.clone())), + name: Identifier::new("x", TextRange::default()), + }); + let type_var_two = TypeParam::TypeVar(TypeParamTypeVar { + range: TextRange::default(), + bound: Some(Box::new(constant_two.clone())), + name: Identifier::new("x", TextRange::default()), + }); + let type_alias = Stmt::TypeAlias(StmtTypeAlias { + name: Box::new(name.clone()), + type_params: vec![type_var_one, type_var_two], + value: Box::new(constant_three.clone()), + range: TextRange::default(), + }); + assert!(!any_over_stmt(&type_alias, &|expr| { + seen.borrow_mut().push(expr.clone()); + false + })); + assert_eq!( + seen.take(), + vec![name, constant_one, constant_two, constant_three] + ); + } + + #[test] + fn any_over_type_param_type_var() { + let type_var_no_bound = TypeParam::TypeVar(TypeParamTypeVar { + range: TextRange::default(), + bound: None, + name: Identifier::new("x", TextRange::default()), + }); + assert!(!any_over_type_param(&type_var_no_bound, &|_expr| true)); + + let bound = Expr::Constant(ExprConstant { + value: Constant::Int(1.into()), + kind: Some("x".to_string()), + range: TextRange::default(), + }); + + let type_var_with_bound = TypeParam::TypeVar(TypeParamTypeVar { + range: TextRange::default(), + bound: Some(Box::new(bound.clone())), + name: Identifier::new("x", TextRange::default()), + }); + assert!( + any_over_type_param(&type_var_with_bound, &|expr| { + assert_eq!( + *expr, bound, + "the received expression should be the unwrapped bound" + ); + true + }), + "if true is returned from `func` it should be respected" + ); + } + + #[test] + fn any_over_type_param_type_var_tuple() { + let type_var_tuple = TypeParam::TypeVarTuple(TypeParamTypeVarTuple { + range: TextRange::default(), + name: Identifier::new("x", TextRange::default()), + }); + assert!( + !any_over_type_param(&type_var_tuple, &|_expr| true), + "type var tuples have no expressions to visit" + ); + } + + #[test] + fn any_over_type_param_param_spec() { + let type_param_spec = TypeParam::ParamSpec(TypeParamParamSpec { + range: TextRange::default(), + name: Identifier::new("x", TextRange::default()), + }); + assert!( + !any_over_type_param(&type_param_spec, &|_expr| true), + "param specs have no expressions to visit" + ); + } }