diff --git a/crates/ruff_python_codegen/src/generator.rs b/crates/ruff_python_codegen/src/generator.rs index b488956d9549a..b1110f440f6b0 100644 --- a/crates/ruff_python_codegen/src/generator.rs +++ b/crates/ruff_python_codegen/src/generator.rs @@ -34,6 +34,7 @@ mod precedence { pub(crate) const COMMA: u8 = 21; pub(crate) const NAMED_EXPR: u8 = 23; pub(crate) const ASSERT: u8 = 23; + pub(crate) const COMPREHENSION_ELEMENT: u8 = 27; pub(crate) const LAMBDA: u8 = 27; pub(crate) const IF_EXP: u8 = 27; pub(crate) const COMPREHENSION: u8 = 29; @@ -1052,7 +1053,7 @@ impl<'a> Generator<'a> { range: _range, }) => { self.p("["); - self.unparse_expr(elt, precedence::MAX); + self.unparse_expr(elt, precedence::COMPREHENSION_ELEMENT); self.unparse_comp(generators); self.p("]"); } @@ -1062,7 +1063,7 @@ impl<'a> Generator<'a> { range: _range, }) => { self.p("{"); - self.unparse_expr(elt, precedence::MAX); + self.unparse_expr(elt, precedence::COMPREHENSION_ELEMENT); self.unparse_comp(generators); self.p("}"); } @@ -1073,9 +1074,9 @@ impl<'a> Generator<'a> { range: _range, }) => { self.p("{"); - self.unparse_expr(key, precedence::MAX); + self.unparse_expr(key, precedence::COMPREHENSION_ELEMENT); self.p(": "); - self.unparse_expr(value, precedence::MAX); + self.unparse_expr(value, precedence::COMPREHENSION_ELEMENT); self.unparse_comp(generators); self.p("}"); } @@ -1085,7 +1086,7 @@ impl<'a> Generator<'a> { range: _range, }) => { self.p("("); - self.unparse_expr(elt, precedence::COMMA); + self.unparse_expr(elt, precedence::COMPREHENSION_ELEMENT); self.unparse_comp(generators); self.p(")"); } @@ -1570,6 +1571,8 @@ mod tests { assert_round_trip!("foo(1)"); assert_round_trip!("foo(1, 2)"); assert_round_trip!("foo(x for x in y)"); + assert_round_trip!("foo([x for x in y])"); + assert_round_trip!("foo([(x := 2) for x in y])"); assert_round_trip!("x = yield 1"); assert_round_trip!("return (yield 1)"); assert_round_trip!("lambda: (1, 2, 3)"); @@ -1622,8 +1625,8 @@ mod tests { r#"def f() -> (int, str): pass"# ); - assert_round_trip!("[(await x) async for x in y]"); - assert_round_trip!("[(await i) for i in b if await c]"); + assert_round_trip!("[await x async for x in y]"); + assert_round_trip!("[await i for i in b if await c]"); assert_round_trip!("(await x async for x in y)"); assert_round_trip!( r#"async def read_data(db): @@ -1719,6 +1722,18 @@ class Foo: pass"# ); + assert_round_trip!(r#"[lambda n: n for n in range(10)]"#); + assert_round_trip!(r#"[n[0:2] for n in range(10)]"#); + assert_round_trip!(r#"[n[0] for n in range(10)]"#); + assert_round_trip!(r#"[(n, n * 2) for n in range(10)]"#); + assert_round_trip!(r#"[1 if n % 2 == 0 else 0 for n in range(10)]"#); + assert_round_trip!(r#"[n % 2 == 0 or 0 for n in range(10)]"#); + assert_round_trip!(r#"[(n := 2) for n in range(10)]"#); + assert_round_trip!(r#"((n := 2) for n in range(10))"#); + assert_round_trip!(r#"[n * 2 for n in range(10)]"#); + assert_round_trip!(r#"{n * 2 for n in range(10)}"#); + assert_round_trip!(r#"{i: n * 2 for i, n in enumerate(range(10))}"#); + // Type aliases assert_round_trip!(r#"type Foo = int | str"#); assert_round_trip!(r#"type Foo[T] = list[T]"#);