Skip to content

Commit

Permalink
Implement unparse for type aliases and parameters (astral-sh#5869)
Browse files Browse the repository at this point in the history
  • Loading branch information
zanieb authored and evanrittenhouse committed Jul 19, 2023
1 parent babc25d commit 4511a59
Showing 1 changed file with 95 additions and 3 deletions.
98 changes: 95 additions & 3 deletions crates/ruff_python_ast/src/source_code/generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ use std::ops::Deref;
use rustpython_literal::escape::{AsciiEscape, Escape, UnicodeEscape};
use rustpython_parser::ast::{
self, Alias, Arg, Arguments, BoolOp, CmpOp, Comprehension, Constant, ConversionFlag,
ExceptHandler, Expr, Identifier, MatchCase, Operator, Pattern, Stmt, Suite, WithItem,
ExceptHandler, Expr, Identifier, MatchCase, Operator, Pattern, Stmt, Suite, TypeParam,
TypeParamParamSpec, TypeParamTypeVar, TypeParamTypeVarTuple, WithItem,
};

use ruff_python_whitespace::LineEnding;
Expand Down Expand Up @@ -207,6 +208,7 @@ impl<'a> Generator<'a> {
body,
returns,
decorator_list,
type_params,
..
}) => {
self.newlines(if self.indent_depth == 0 { 2 } else { 1 });
Expand All @@ -219,6 +221,7 @@ impl<'a> Generator<'a> {
statement!({
self.p("def ");
self.p_id(name);
self.unparse_type_params(type_params);
self.p("(");
self.unparse_args(args);
self.p(")");
Expand All @@ -239,6 +242,7 @@ impl<'a> Generator<'a> {
body,
returns,
decorator_list,
type_params,
..
}) => {
self.newlines(if self.indent_depth == 0 { 2 } else { 1 });
Expand All @@ -251,6 +255,7 @@ impl<'a> Generator<'a> {
statement!({
self.p("async def ");
self.p_id(name);
self.unparse_type_params(type_params);
self.p("(");
self.unparse_args(args);
self.p(")");
Expand All @@ -271,8 +276,8 @@ impl<'a> Generator<'a> {
keywords,
body,
decorator_list,
type_params,
range: _,
type_params: _,
}) => {
self.newlines(if self.indent_depth == 0 { 2 } else { 1 });
for decorator in decorator_list {
Expand All @@ -284,6 +289,7 @@ impl<'a> Generator<'a> {
statement!({
self.p("class ");
self.p_id(name);
self.unparse_type_params(type_params);
let mut first = true;
for base in bases {
self.p_if(first, "(");
Expand Down Expand Up @@ -525,6 +531,18 @@ impl<'a> Generator<'a> {
self.indent_depth = self.indent_depth.saturating_sub(1);
}
}
Stmt::TypeAlias(ast::StmtTypeAlias {
name,
range: _range,
type_params,
value,
}) => {
self.p("type ");
self.unparse_expr(name, precedence::MAX);
self.unparse_type_params(type_params);
self.p(" = ");
self.unparse_expr(value, precedence::ASSIGN);
}
Stmt::Raise(ast::StmtRaise {
exc,
cause,
Expand Down Expand Up @@ -702,7 +720,6 @@ impl<'a> Generator<'a> {
self.p("continue");
});
}
Stmt::TypeAlias(_) => todo!(),
}
}

Expand Down Expand Up @@ -830,6 +847,38 @@ impl<'a> Generator<'a> {
self.body(&ast.body);
}

fn unparse_type_params(&mut self, type_params: &Vec<TypeParam>) {
if !type_params.is_empty() {
self.p("[");
let mut first = true;
for type_param in type_params {
self.p_delim(&mut first, ", ");
self.unparse_type_param(type_param);
}
self.p("]");
}
}

pub(crate) fn unparse_type_param(&mut self, ast: &TypeParam) {
match ast {
TypeParam::TypeVar(TypeParamTypeVar { name, bound, .. }) => {
self.p_id(name);
if let Some(expr) = bound {
self.p(": ");
self.unparse_expr(expr, precedence::MAX);
}
}
TypeParam::TypeVarTuple(TypeParamTypeVarTuple { name, .. }) => {
self.p("*");
self.p_id(name);
}
TypeParam::ParamSpec(TypeParamParamSpec { name, .. }) => {
self.p("**");
self.p_id(name);
}
}
}

pub(crate) fn unparse_expr(&mut self, ast: &Expr, level: u8) {
macro_rules! opprec {
($opty:ident, $x:expr, $enu:path, $($var:ident($op:literal, $prec:ident)),*$(,)?) => {
Expand Down Expand Up @@ -1510,6 +1559,26 @@ mod tests {
);
assert_round_trip!(
r#"class Foo(Bar, object):
pass"#
);
assert_round_trip!(
r#"class Foo[T]:
pass"#
);
assert_round_trip!(
r#"class Foo[T](Bar):
pass"#
);
assert_round_trip!(
r#"class Foo[*Ts]:
pass"#
);
assert_round_trip!(
r#"class Foo[**P]:
pass"#
);
assert_round_trip!(
r#"class Foo[T, U, *Ts, **P]:
pass"#
);
assert_round_trip!(
Expand Down Expand Up @@ -1541,6 +1610,22 @@ mod tests {
);
assert_round_trip!(
r#"def test(a, b=4, /, c=8, d=9):
pass"#
);
assert_round_trip!(
r#"def test[T]():
pass"#
);
assert_round_trip!(
r#"def test[*Ts]():
pass"#
);
assert_round_trip!(
r#"def test[**P]():
pass"#
);
assert_round_trip!(
r#"def test[T, U, *Ts, **P]():
pass"#
);
assert_round_trip!(
Expand Down Expand Up @@ -1596,6 +1681,13 @@ class Foo:
def f():
pass"#
);

// Type aliases
assert_round_trip!(r#"type Foo = int | str"#);
assert_round_trip!(r#"type Foo[T] = list[T]"#);
assert_round_trip!(r#"type Foo[*Ts] = ..."#);
assert_round_trip!(r#"type Foo[**P] = ..."#);
assert_round_trip!(r#"type Foo[T, U, *Ts, **P] = ..."#);
}

#[test]
Expand Down

0 comments on commit 4511a59

Please sign in to comment.