Skip to content

Commit

Permalink
Implement Comparable for type aliases and parameters (astral-sh#5865)
Browse files Browse the repository at this point in the history
  • Loading branch information
zanieb authored and evanrittenhouse committed Jul 19, 2023
1 parent 3b29822 commit 1b153ba
Showing 1 changed file with 78 additions and 7 deletions.
85 changes: 78 additions & 7 deletions crates/ruff_python_ast/src/comparable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -935,6 +935,7 @@ pub struct StmtFunctionDef<'a> {
args: ComparableArguments<'a>,
body: Vec<ComparableStmt<'a>>,
decorator_list: Vec<ComparableDecorator<'a>>,
type_params: Vec<ComparableTypeParam<'a>>,
returns: Option<ComparableExpr<'a>>,
type_comment: Option<&'a str>,
}
Expand All @@ -945,6 +946,7 @@ pub struct StmtAsyncFunctionDef<'a> {
args: ComparableArguments<'a>,
body: Vec<ComparableStmt<'a>>,
decorator_list: Vec<ComparableDecorator<'a>>,
type_params: Vec<ComparableTypeParam<'a>>,
returns: Option<ComparableExpr<'a>>,
type_comment: Option<&'a str>,
}
Expand All @@ -956,6 +958,7 @@ pub struct StmtClassDef<'a> {
keywords: Vec<ComparableKeyword<'a>>,
body: Vec<ComparableStmt<'a>>,
decorator_list: Vec<ComparableDecorator<'a>>,
type_params: Vec<ComparableTypeParam<'a>>,
}

#[derive(Debug, PartialEq, Eq, Hash)]
Expand All @@ -968,6 +971,61 @@ pub struct StmtDelete<'a> {
targets: Vec<ComparableExpr<'a>>,
}

#[derive(Debug, PartialEq, Eq, Hash)]
pub struct StmtTypeAlias<'a> {
pub name: Box<ComparableExpr<'a>>,
pub type_params: Vec<ComparableTypeParam<'a>>,
pub value: Box<ComparableExpr<'a>>,
}

#[derive(Debug, PartialEq, Eq, Hash)]
pub enum ComparableTypeParam<'a> {
TypeVar(TypeParamTypeVar<'a>),
ParamSpec(TypeParamParamSpec<'a>),
TypeVarTuple(TypeParamTypeVarTuple<'a>),
}

impl<'a> From<&'a ast::TypeParam> for ComparableTypeParam<'a> {
fn from(type_param: &'a ast::TypeParam) -> Self {
match type_param {
ast::TypeParam::TypeVar(ast::TypeParamTypeVar {
name,
bound,
range: _,
}) => Self::TypeVar(TypeParamTypeVar {
name: name.as_str(),
bound: bound.as_ref().map(Into::into),
}),
ast::TypeParam::TypeVarTuple(ast::TypeParamTypeVarTuple { name, range: _ }) => {
Self::TypeVarTuple(TypeParamTypeVarTuple {
name: name.as_str(),
})
}
ast::TypeParam::ParamSpec(ast::TypeParamParamSpec { name, range: _ }) => {
Self::ParamSpec(TypeParamParamSpec {
name: name.as_str(),
})
}
}
}
}

#[derive(Debug, PartialEq, Eq, Hash)]
pub struct TypeParamTypeVar<'a> {
pub name: &'a str,
pub bound: Option<Box<ComparableExpr<'a>>>,
}

#[derive(Debug, PartialEq, Eq, Hash)]
pub struct TypeParamParamSpec<'a> {
pub name: &'a str,
}

#[derive(Debug, PartialEq, Eq, Hash)]
pub struct TypeParamTypeVarTuple<'a> {
pub name: &'a str,
}

#[derive(Debug, PartialEq, Eq, Hash)]
pub struct StmtAssign<'a> {
targets: Vec<ComparableExpr<'a>>,
Expand Down Expand Up @@ -1117,6 +1175,7 @@ pub enum ComparableStmt<'a> {
Raise(StmtRaise<'a>),
Try(StmtTry<'a>),
TryStar(StmtTryStar<'a>),
TypeAlias(StmtTypeAlias<'a>),
Assert(StmtAssert<'a>),
Import(StmtImport<'a>),
ImportFrom(StmtImportFrom<'a>),
Expand All @@ -1138,15 +1197,16 @@ impl<'a> From<&'a ast::Stmt> for ComparableStmt<'a> {
decorator_list,
returns,
type_comment,
range: _,
type_params: _,
type_params,
range: _range,
}) => Self::FunctionDef(StmtFunctionDef {
name: name.as_str(),
args: args.into(),
body: body.iter().map(Into::into).collect(),
decorator_list: decorator_list.iter().map(Into::into).collect(),
returns: returns.as_ref().map(Into::into),
type_comment: type_comment.as_ref().map(String::as_str),
type_params: type_params.iter().map(Into::into).collect(),
}),
ast::Stmt::AsyncFunctionDef(ast::StmtAsyncFunctionDef {
name,
Expand All @@ -1155,30 +1215,32 @@ impl<'a> From<&'a ast::Stmt> for ComparableStmt<'a> {
decorator_list,
returns,
type_comment,
range: _,
type_params: _,
type_params,
range: _range,
}) => Self::AsyncFunctionDef(StmtAsyncFunctionDef {
name: name.as_str(),
args: args.into(),
body: body.iter().map(Into::into).collect(),
decorator_list: decorator_list.iter().map(Into::into).collect(),
returns: returns.as_ref().map(Into::into),
type_comment: type_comment.as_ref().map(String::as_str),
type_params: type_params.iter().map(Into::into).collect(),
}),
ast::Stmt::ClassDef(ast::StmtClassDef {
name,
bases,
keywords,
body,
decorator_list,
range: _,
type_params: _,
type_params,
range: _range,
}) => Self::ClassDef(StmtClassDef {
name: name.as_str(),
bases: bases.iter().map(Into::into).collect(),
keywords: keywords.iter().map(Into::into).collect(),
body: body.iter().map(Into::into).collect(),
decorator_list: decorator_list.iter().map(Into::into).collect(),
type_params: type_params.iter().map(Into::into).collect(),
}),
ast::Stmt::Return(ast::StmtReturn {
value,
Expand All @@ -1192,6 +1254,16 @@ impl<'a> From<&'a ast::Stmt> for ComparableStmt<'a> {
}) => Self::Delete(StmtDelete {
targets: targets.iter().map(Into::into).collect(),
}),
ast::Stmt::TypeAlias(ast::StmtTypeAlias {
range: _range,
name,
type_params,
value,
}) => Self::TypeAlias(StmtTypeAlias {
name: name.into(),
type_params: type_params.iter().map(Into::into).collect(),
value: value.into(),
}),
ast::Stmt::Assign(ast::StmtAssign {
targets,
value,
Expand Down Expand Up @@ -1377,7 +1449,6 @@ impl<'a> From<&'a ast::Stmt> for ComparableStmt<'a> {
ast::Stmt::Pass(_) => Self::Pass,
ast::Stmt::Break(_) => Self::Break,
ast::Stmt::Continue(_) => Self::Continue,
ast::Stmt::TypeAlias(_) => todo!(),
}
}
}

0 comments on commit 1b153ba

Please sign in to comment.