diff --git a/crates/ruff_python_ast/src/source_code/generator.rs b/crates/ruff_python_ast/src/source_code/generator.rs index 412b897eedd767..5efc6fe0c07a14 100644 --- a/crates/ruff_python_ast/src/source_code/generator.rs +++ b/crates/ruff_python_ast/src/source_code/generator.rs @@ -6,7 +6,8 @@ use std::ops::Deref; use rustpython_literal::escape::{AsciiEscape, Escape, UnicodeEscape}; use rustpython_parser::ast::{ self, Alias, Arg, Arguments, BoolOp, CmpOp, Comprehension, Constant, ConversionFlag, - ExceptHandler, Expr, Identifier, MatchCase, Operator, Pattern, Stmt, Suite, WithItem, + ExceptHandler, Expr, Identifier, MatchCase, Operator, Pattern, Stmt, Suite, TypeParam, + TypeParamParamSpec, TypeParamTypeVar, TypeParamTypeVarTuple, WithItem, }; use ruff_python_whitespace::LineEnding; @@ -207,6 +208,7 @@ impl<'a> Generator<'a> { body, returns, decorator_list, + type_params, .. }) => { self.newlines(if self.indent_depth == 0 { 2 } else { 1 }); @@ -219,6 +221,7 @@ impl<'a> Generator<'a> { statement!({ self.p("def "); self.p_id(name); + self.unparse_type_params(type_params); self.p("("); self.unparse_args(args); self.p(")"); @@ -239,6 +242,7 @@ impl<'a> Generator<'a> { body, returns, decorator_list, + type_params, .. }) => { self.newlines(if self.indent_depth == 0 { 2 } else { 1 }); @@ -251,6 +255,7 @@ impl<'a> Generator<'a> { statement!({ self.p("async def "); self.p_id(name); + self.unparse_type_params(type_params); self.p("("); self.unparse_args(args); self.p(")"); @@ -271,8 +276,8 @@ impl<'a> Generator<'a> { keywords, body, decorator_list, + type_params, range: _, - type_params: _, }) => { self.newlines(if self.indent_depth == 0 { 2 } else { 1 }); for decorator in decorator_list { @@ -284,6 +289,7 @@ impl<'a> Generator<'a> { statement!({ self.p("class "); self.p_id(name); + self.unparse_type_params(type_params); let mut first = true; for base in bases { self.p_if(first, "("); @@ -525,6 +531,18 @@ impl<'a> Generator<'a> { self.indent_depth = self.indent_depth.saturating_sub(1); } } + Stmt::TypeAlias(ast::StmtTypeAlias { + name, + range: _range, + type_params, + value, + }) => { + self.p("type "); + self.unparse_expr(name, precedence::MAX); + self.unparse_type_params(type_params); + self.p(" = "); + self.unparse_expr(value, precedence::ASSIGN); + } Stmt::Raise(ast::StmtRaise { exc, cause, @@ -702,7 +720,6 @@ impl<'a> Generator<'a> { self.p("continue"); }); } - Stmt::TypeAlias(_) => todo!(), } } @@ -830,6 +847,38 @@ impl<'a> Generator<'a> { self.body(&ast.body); } + fn unparse_type_params(&mut self, type_params: &Vec) { + if !type_params.is_empty() { + self.p("["); + let mut first = true; + for type_param in type_params { + self.p_delim(&mut first, ", "); + self.unparse_type_param(type_param); + } + self.p("]"); + } + } + + pub(crate) fn unparse_type_param(&mut self, ast: &TypeParam) { + match ast { + TypeParam::TypeVar(TypeParamTypeVar { name, bound, .. }) => { + self.p_id(name); + if let Some(expr) = bound { + self.p(": "); + self.unparse_expr(expr, precedence::MAX); + } + } + TypeParam::TypeVarTuple(TypeParamTypeVarTuple { name, .. }) => { + self.p("*"); + self.p_id(name); + } + TypeParam::ParamSpec(TypeParamParamSpec { name, .. }) => { + self.p("**"); + self.p_id(name); + } + } + } + pub(crate) fn unparse_expr(&mut self, ast: &Expr, level: u8) { macro_rules! opprec { ($opty:ident, $x:expr, $enu:path, $($var:ident($op:literal, $prec:ident)),*$(,)?) => { @@ -1510,6 +1559,26 @@ mod tests { ); assert_round_trip!( r#"class Foo(Bar, object): + pass"# + ); + assert_round_trip!( + r#"class Foo[T]: + pass"# + ); + assert_round_trip!( + r#"class Foo[T](Bar): + pass"# + ); + assert_round_trip!( + r#"class Foo[*Ts]: + pass"# + ); + assert_round_trip!( + r#"class Foo[**P]: + pass"# + ); + assert_round_trip!( + r#"class Foo[T, U, *Ts, **P]: pass"# ); assert_round_trip!( @@ -1541,6 +1610,22 @@ mod tests { ); assert_round_trip!( r#"def test(a, b=4, /, c=8, d=9): + pass"# + ); + assert_round_trip!( + r#"def test[T](): + pass"# + ); + assert_round_trip!( + r#"def test[*Ts](): + pass"# + ); + assert_round_trip!( + r#"def test[**P](): + pass"# + ); + assert_round_trip!( + r#"def test[T, U, *Ts, **P](): pass"# ); assert_round_trip!( @@ -1596,6 +1681,13 @@ class Foo: def f(): pass"# ); + + // Type aliases + assert_round_trip!(r#"type Foo = int | str"#); + assert_round_trip!(r#"type Foo[T] = list[T]"#); + assert_round_trip!(r#"type Foo[*Ts] = ..."#); + assert_round_trip!(r#"type Foo[**P] = ..."#); + assert_round_trip!(r#"type Foo[T, U, *Ts, **P] = ..."#); } #[test]