Skip to content

Commit

Permalink
WIP: Add support for TypeAlias and TypeParam
Browse files Browse the repository at this point in the history
  • Loading branch information
zanieb committed Jul 17, 2023
1 parent bfaa1f9 commit e34cfeb
Show file tree
Hide file tree
Showing 8 changed files with 202 additions and 2 deletions.
73 changes: 73 additions & 0 deletions crates/ruff_python_ast/src/comparable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -915,6 +915,7 @@ pub struct StmtFunctionDef<'a> {
args: ComparableArguments<'a>,
body: Vec<ComparableStmt<'a>>,
decorator_list: Vec<ComparableDecorator<'a>>,
type_params: Vec<ComparableTypeParam<'a>>,
returns: Option<ComparableExpr<'a>>,
type_comment: Option<&'a str>,
}
Expand All @@ -925,6 +926,7 @@ pub struct StmtAsyncFunctionDef<'a> {
args: ComparableArguments<'a>,
body: Vec<ComparableStmt<'a>>,
decorator_list: Vec<ComparableDecorator<'a>>,
type_params: Vec<ComparableTypeParam<'a>>,
returns: Option<ComparableExpr<'a>>,
type_comment: Option<&'a str>,
}
Expand All @@ -936,6 +938,7 @@ pub struct StmtClassDef<'a> {
keywords: Vec<ComparableKeyword<'a>>,
body: Vec<ComparableStmt<'a>>,
decorator_list: Vec<ComparableDecorator<'a>>,
type_params: Vec<ComparableTypeParam<'a>>,
}

#[derive(Debug, PartialEq, Eq, Hash)]
Expand All @@ -948,6 +951,59 @@ pub struct StmtDelete<'a> {
targets: Vec<ComparableExpr<'a>>,
}

#[derive(Debug, PartialEq, Eq, Hash)]
pub struct StmtTypeAlias<'a> {
pub name: Box<ComparableExpr<'a>>,
pub type_params: Vec<ComparableTypeParam<'a>>,
pub value: Box<ComparableExpr<'a>>,
}

#[derive(Debug, PartialEq, Eq, Hash)]
pub enum ComparableTypeParam<'a> {
TypeVar(TypeParamTypeVar<'a>),
ParamSpec(TypeParamParamSpec<'a>),
TypeVarTuple(TypeParamTypeVarTuple<'a>),
}

impl<'a> From<&'a ast::TypeParam> for ComparableTypeParam<'a> {
fn from(type_param: &'a ast::TypeParam) -> Self {
match type_param {
ast::TypeParam::TypeVar(ast::TypeParamTypeVar { name, bound, .. }) => {
Self::TypeVar(TypeParamTypeVar {
name: name.as_str(),
bound: bound.as_ref().map(Into::into),
})
}
ast::TypeParam::TypeVarTuple(ast::TypeParamTypeVarTuple { name, .. }) => {
Self::TypeVarTuple(TypeParamTypeVarTuple {
name: name.as_str(),
})
}
ast::TypeParam::ParamSpec(ast::TypeParamParamSpec { name, .. }) => {
Self::ParamSpec(TypeParamParamSpec {
name: name.as_str(),
})
}
}
}
}

#[derive(Debug, PartialEq, Eq, Hash)]
pub struct TypeParamTypeVar<'a> {
pub name: &'a str,
pub bound: Option<Box<ComparableExpr<'a>>>,
}

#[derive(Debug, PartialEq, Eq, Hash)]
pub struct TypeParamParamSpec<'a> {
pub name: &'a str,
}

#[derive(Debug, PartialEq, Eq, Hash)]
pub struct TypeParamTypeVarTuple<'a> {
pub name: &'a str,
}

#[derive(Debug, PartialEq, Eq, Hash)]
pub struct StmtAssign<'a> {
targets: Vec<ComparableExpr<'a>>,
Expand Down Expand Up @@ -1097,6 +1153,7 @@ pub enum ComparableStmt<'a> {
Raise(StmtRaise<'a>),
Try(StmtTry<'a>),
TryStar(StmtTryStar<'a>),
TypeAlias(StmtTypeAlias<'a>),
Assert(StmtAssert<'a>),
Import(StmtImport<'a>),
ImportFrom(StmtImportFrom<'a>),
Expand All @@ -1118,6 +1175,7 @@ impl<'a> From<&'a ast::Stmt> for ComparableStmt<'a> {
decorator_list,
returns,
type_comment,
type_params,
range: _range,
}) => Self::FunctionDef(StmtFunctionDef {
name: name.as_str(),
Expand All @@ -1126,6 +1184,7 @@ impl<'a> From<&'a ast::Stmt> for ComparableStmt<'a> {
decorator_list: decorator_list.iter().map(Into::into).collect(),
returns: returns.as_ref().map(Into::into),
type_comment: type_comment.as_ref().map(String::as_str),
type_params: type_params.iter().map(Into::into).collect(),
}),
ast::Stmt::AsyncFunctionDef(ast::StmtAsyncFunctionDef {
name,
Expand All @@ -1134,6 +1193,7 @@ impl<'a> From<&'a ast::Stmt> for ComparableStmt<'a> {
decorator_list,
returns,
type_comment,
type_params,
range: _range,
}) => Self::AsyncFunctionDef(StmtAsyncFunctionDef {
name: name.as_str(),
Expand All @@ -1142,20 +1202,23 @@ impl<'a> From<&'a ast::Stmt> for ComparableStmt<'a> {
decorator_list: decorator_list.iter().map(Into::into).collect(),
returns: returns.as_ref().map(Into::into),
type_comment: type_comment.as_ref().map(String::as_str),
type_params: type_params.iter().map(Into::into).collect(),
}),
ast::Stmt::ClassDef(ast::StmtClassDef {
name,
bases,
keywords,
body,
decorator_list,
type_params,
range: _range,
}) => Self::ClassDef(StmtClassDef {
name: name.as_str(),
bases: bases.iter().map(Into::into).collect(),
keywords: keywords.iter().map(Into::into).collect(),
body: body.iter().map(Into::into).collect(),
decorator_list: decorator_list.iter().map(Into::into).collect(),
type_params: type_params.iter().map(Into::into).collect(),
}),
ast::Stmt::Return(ast::StmtReturn {
value,
Expand All @@ -1169,6 +1232,16 @@ impl<'a> From<&'a ast::Stmt> for ComparableStmt<'a> {
}) => Self::Delete(StmtDelete {
targets: targets.iter().map(Into::into).collect(),
}),
ast::Stmt::TypeAlias(ast::StmtTypeAlias {
range: _range,
name,
type_params,
value,
}) => Self::TypeAlias(StmtTypeAlias {
name: name.into(),
type_params: type_params.iter().map(Into::into).collect(),
value: value.into(),
}),
ast::Stmt::Assign(ast::StmtAssign {
targets,
value,
Expand Down
25 changes: 24 additions & 1 deletion crates/ruff_python_ast/src/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use ruff_text_size::{TextRange, TextSize};
use rustc_hash::FxHashMap;
use rustpython_ast::CmpOp;
use rustpython_parser::ast::{
self, Arguments, Constant, ExceptHandler, Expr, Keyword, MatchCase, Pattern, Ranged, Stmt,
self, Arguments, Constant, ExceptHandler, Expr, Keyword, MatchCase, Pattern, Ranged, Stmt, TypeParam
};
use rustpython_parser::{lexer, Mode, Tok};
use smallvec::SmallVec;
Expand Down Expand Up @@ -265,6 +265,24 @@ where
}
}


pub fn any_over_type_param<F>(type_param: &TypeParam, func: &F) -> bool
where
F: Fn(&Expr) -> bool,
{
match type_param {
TypeParam::TypeVar(ast::TypeParamTypeVar { bound, .. }) => {
bound.as_ref().map_or(false, |value| any_over_expr(value, func))
}
TypeParam::TypeVarTuple(ast::TypeParamTypeVarTuple { .. }) => {
false
}
TypeParam::ParamSpec(ast::TypeParamParamSpec { .. }) => {
false
}
}
}

pub fn any_over_pattern<F>(pattern: &Pattern, func: &F) -> bool
where
F: Fn(&Expr) -> bool,
Expand Down Expand Up @@ -391,6 +409,11 @@ where
targets,
range: _range,
}) => targets.iter().any(|expr| any_over_expr(expr, func)),
Stmt::TypeAlias(ast::StmtTypeAlias { name, type_params, value, .. }) => {
any_over_expr(name, func)
|| type_params.iter().any(|type_param| any_over_type_param(type_param, func))
|| any_over_expr(value, func)
}
Stmt::Assign(ast::StmtAssign { targets, value, .. }) => {
targets.iter().any(|expr| any_over_expr(expr, func)) || any_over_expr(value, func)
}
Expand Down
Loading

0 comments on commit e34cfeb

Please sign in to comment.