Skip to content

Commit

Permalink
Implement any_over_expr for type alias and type params (#5866)
Browse files Browse the repository at this point in the history
Part of #5062
  • Loading branch information
zanieb authored Jul 19, 2023
1 parent a459d8f commit b27f0fa
Showing 1 changed file with 141 additions and 4 deletions.
145 changes: 141 additions & 4 deletions crates/ruff_python_ast/src/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use rustc_hash::FxHashMap;
use rustpython_ast::CmpOp;
use rustpython_parser::ast::{
self, Arguments, Constant, ExceptHandler, Expr, Keyword, MatchCase, Pattern, Ranged, Stmt,
TypeParam,
};
use rustpython_parser::{lexer, Mode, Tok};
use smallvec::SmallVec;
Expand Down Expand Up @@ -265,6 +266,19 @@ where
}
}

pub fn any_over_type_param<F>(type_param: &TypeParam, func: &F) -> bool
where
F: Fn(&Expr) -> bool,
{
match type_param {
TypeParam::TypeVar(ast::TypeParamTypeVar { bound, .. }) => bound
.as_ref()
.map_or(false, |value| any_over_expr(value, func)),
TypeParam::TypeVarTuple(ast::TypeParamTypeVarTuple { .. }) => false,
TypeParam::ParamSpec(ast::TypeParamParamSpec { .. }) => false,
}
}

pub fn any_over_pattern<F>(pattern: &Pattern, func: &F) -> bool
where
F: Fn(&Expr) -> bool,
Expand Down Expand Up @@ -391,6 +405,18 @@ where
targets,
range: _range,
}) => targets.iter().any(|expr| any_over_expr(expr, func)),
Stmt::TypeAlias(ast::StmtTypeAlias {
name,
type_params,
value,
..
}) => {
any_over_expr(name, func)
|| type_params
.iter()
.any(|type_param| any_over_type_param(type_param, func))
|| any_over_expr(value, func)
}
Stmt::Assign(ast::StmtAssign { targets, value, .. }) => {
targets.iter().any(|expr| any_over_expr(expr, func)) || any_over_expr(value, func)
}
Expand Down Expand Up @@ -539,7 +565,6 @@ where
range: _range,
}) => any_over_expr(value, func),
Stmt::Pass(_) | Stmt::Break(_) | Stmt::Continue(_) => false,
Stmt::TypeAlias(_) => todo!(),
}
}

Expand Down Expand Up @@ -1564,15 +1589,22 @@ pub fn locate_cmp_ops(expr: &Expr, locator: &Locator) -> Vec<LocatedCmpOp> {
mod tests {
use std::borrow::Cow;

use std::cell::RefCell;

use std::vec;

use anyhow::Result;
use ruff_text_size::{TextLen, TextRange, TextSize};
use rustpython_ast::{CmpOp, Expr, Ranged};
use rustpython_ast::{
self, CmpOp, Constant, Expr, ExprConstant, ExprContext, ExprName, Identifier, Ranged, Stmt,
StmtTypeAlias, TypeParam, TypeParamParamSpec, TypeParamTypeVar, TypeParamTypeVarTuple,
};
use rustpython_parser::ast::Suite;
use rustpython_parser::Parse;

use crate::helpers::{
first_colon_range, has_trailing_content, locate_cmp_ops, resolve_imported_module_path,
LocatedCmpOp,
any_over_stmt, any_over_type_param, first_colon_range, has_trailing_content,
locate_cmp_ops, resolve_imported_module_path, LocatedCmpOp,
};
use crate::source_code::Locator;

Expand Down Expand Up @@ -1746,4 +1778,109 @@ y = 2

Ok(())
}

#[test]
fn any_over_stmt_type_alias() {
let seen = RefCell::new(Vec::new());
let name = Expr::Name(ExprName {
id: "x".to_string(),
range: TextRange::default(),
ctx: ExprContext::Load,
});
let constant_one = Expr::Constant(ExprConstant {
value: Constant::Int(1.into()),
kind: Some("x".to_string()),
range: TextRange::default(),
});
let constant_two = Expr::Constant(ExprConstant {
value: Constant::Int(2.into()),
kind: Some("y".to_string()),
range: TextRange::default(),
});
let constant_three = Expr::Constant(ExprConstant {
value: Constant::Int(3.into()),
kind: Some("z".to_string()),
range: TextRange::default(),
});
let type_var_one = TypeParam::TypeVar(TypeParamTypeVar {
range: TextRange::default(),
bound: Some(Box::new(constant_one.clone())),
name: Identifier::new("x", TextRange::default()),
});
let type_var_two = TypeParam::TypeVar(TypeParamTypeVar {
range: TextRange::default(),
bound: Some(Box::new(constant_two.clone())),
name: Identifier::new("x", TextRange::default()),
});
let type_alias = Stmt::TypeAlias(StmtTypeAlias {
name: Box::new(name.clone()),
type_params: vec![type_var_one, type_var_two],
value: Box::new(constant_three.clone()),
range: TextRange::default(),
});
assert!(!any_over_stmt(&type_alias, &|expr| {
seen.borrow_mut().push(expr.clone());
false
}));
assert_eq!(
seen.take(),
vec![name, constant_one, constant_two, constant_three]
);
}

#[test]
fn any_over_type_param_type_var() {
let type_var_no_bound = TypeParam::TypeVar(TypeParamTypeVar {
range: TextRange::default(),
bound: None,
name: Identifier::new("x", TextRange::default()),
});
assert!(!any_over_type_param(&type_var_no_bound, &|_expr| true));

let bound = Expr::Constant(ExprConstant {
value: Constant::Int(1.into()),
kind: Some("x".to_string()),
range: TextRange::default(),
});

let type_var_with_bound = TypeParam::TypeVar(TypeParamTypeVar {
range: TextRange::default(),
bound: Some(Box::new(bound.clone())),
name: Identifier::new("x", TextRange::default()),
});
assert!(
any_over_type_param(&type_var_with_bound, &|expr| {
assert_eq!(
*expr, bound,
"the received expression should be the unwrapped bound"
);
true
}),
"if true is returned from `func` it should be respected"
);
}

#[test]
fn any_over_type_param_type_var_tuple() {
let type_var_tuple = TypeParam::TypeVarTuple(TypeParamTypeVarTuple {
range: TextRange::default(),
name: Identifier::new("x", TextRange::default()),
});
assert!(
!any_over_type_param(&type_var_tuple, &|_expr| true),
"type var tuples have no expressions to visit"
);
}

#[test]
fn any_over_type_param_param_spec() {
let type_param_spec = TypeParam::ParamSpec(TypeParamParamSpec {
range: TextRange::default(),
name: Identifier::new("x", TextRange::default()),
});
assert!(
!any_over_type_param(&type_param_spec, &|_expr| true),
"param specs have no expressions to visit"
);
}
}

0 comments on commit b27f0fa

Please sign in to comment.