Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow any expression in comprehensions' evaluated expression #936

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 26 additions & 6 deletions libcst/_nodes/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -1983,6 +1983,25 @@ def _visit_and_replace_children(self, visitor: CSTVisitorT) -> "Parameters":
star_kwarg=visit_optional(self, "star_kwarg", self.star_kwarg, visitor),
)

def _safe_to_join_with_lambda(self) -> bool:
"""
Determine if Parameters need a space after the `lambda` keyword. Returns True
iff it's safe to omit the space between `lambda` and these Parameters.

See also `BaseExpression._safe_to_use_with_word_operator`.

For example: `lambda*_: pass`
"""
if len(self.posonly_params) != 0:
return False

# posonly_ind can't appear if above condition is false

if len(self.params) > 0 and self.params[0].star not in {"*", "**"}:
return False

return True

def _codegen_impl(self, state: CodegenState) -> None: # noqa: C901
# Compute the star existence first so we can ask about whether
# each element is the last in the list or not.
Expand Down Expand Up @@ -2115,6 +2134,7 @@ def _validate(self) -> None:
if (
isinstance(whitespace_after_lambda, BaseParenthesizableWhitespace)
and whitespace_after_lambda.empty
and not self.params._safe_to_join_with_lambda()
):
raise CSTValidationError(
"Must have at least one space after lambda when specifying params"
Expand Down Expand Up @@ -3495,7 +3515,7 @@ class BaseSimpleComp(BaseComp, ABC):
#: The expression evaluated during each iteration of the comprehension. This
#: lexically comes before the ``for_in`` clause, but it is semantically the
#: inner-most element, evaluated inside the ``for_in`` clause.
elt: BaseAssignTargetExpression
elt: BaseExpression

#: The ``for ... in ... if ...`` clause that lexically comes after ``elt``. This may
#: be a nested structure for nested comprehensions. See :class:`CompFor` for
Expand Down Expand Up @@ -3528,7 +3548,7 @@ class GeneratorExp(BaseSimpleComp):
"""

#: The expression evaluated and yielded during each iteration of the generator.
elt: BaseAssignTargetExpression
elt: BaseExpression

#: The ``for ... in ... if ...`` clause that comes after ``elt``. This may be a
#: nested structure for nested comprehensions. See :class:`CompFor` for details.
Expand Down Expand Up @@ -3579,7 +3599,7 @@ class ListComp(BaseList, BaseSimpleComp):
"""

#: The expression evaluated and stored during each iteration of the comprehension.
elt: BaseAssignTargetExpression
elt: BaseExpression

#: The ``for ... in ... if ...`` clause that comes after ``elt``. This may be a
#: nested structure for nested comprehensions. See :class:`CompFor` for details.
Expand Down Expand Up @@ -3621,7 +3641,7 @@ class SetComp(BaseSet, BaseSimpleComp):
"""

#: The expression evaluated and stored during each iteration of the comprehension.
elt: BaseAssignTargetExpression
elt: BaseExpression

#: The ``for ... in ... if ...`` clause that comes after ``elt``. This may be a
#: nested structure for nested comprehensions. See :class:`CompFor` for details.
Expand Down Expand Up @@ -3663,10 +3683,10 @@ class DictComp(BaseDict, BaseComp):
"""

#: The key inserted into the dictionary during each iteration of the comprehension.
key: BaseAssignTargetExpression
key: BaseExpression
#: The value associated with the ``key`` inserted into the dictionary during each
#: iteration of the comprehension.
value: BaseAssignTargetExpression
value: BaseExpression

#: The ``for ... in ... if ...`` clause that lexically comes after ``key`` and
#: ``value``. This may be a nested structure for nested comprehensions. See
Expand Down
11 changes: 11 additions & 0 deletions libcst/_nodes/tests/test_dict_comp.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,17 @@ class DictCompTest(CSTNodeTest):
"parser": parse_expression,
"expected_position": CodeRange((1, 0), (1, 17)),
},
# non-trivial keys & values in DictComp
{
"node": cst.DictComp(
cst.BinaryOperation(cst.Name("k1"), cst.Add(), cst.Name("k2")),
cst.BinaryOperation(cst.Name("v1"), cst.Add(), cst.Name("v2")),
cst.CompFor(target=cst.Name("a"), iter=cst.Name("b")),
),
"code": "{k1 + k2: v1 + v2 for a in b}",
"parser": parse_expression,
"expected_position": CodeRange((1, 0), (1, 29)),
},
# custom whitespace around colon
{
"node": cst.DictComp(
Expand Down
56 changes: 32 additions & 24 deletions libcst/_nodes/tests/test_lambda.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,30 +303,6 @@ def test_valid(
),
"at least one space after lambda",
),
(
lambda: cst.Lambda(
cst.Parameters(star_arg=cst.Param(cst.Name("arg"))),
cst.Integer("5"),
whitespace_after_lambda=cst.SimpleWhitespace(""),
),
"at least one space after lambda",
),
(
lambda: cst.Lambda(
cst.Parameters(kwonly_params=(cst.Param(cst.Name("arg")),)),
cst.Integer("5"),
whitespace_after_lambda=cst.SimpleWhitespace(""),
),
"at least one space after lambda",
),
(
lambda: cst.Lambda(
cst.Parameters(star_kwarg=cst.Param(cst.Name("arg"))),
cst.Integer("5"),
whitespace_after_lambda=cst.SimpleWhitespace(""),
),
"at least one space after lambda",
),
(
lambda: cst.Lambda(
cst.Parameters(
Expand Down Expand Up @@ -944,6 +920,38 @@ class LambdaParserTest(CSTNodeTest):
),
"( lambda : 5 )",
),
# No space between lambda and params
(
cst.Lambda(
cst.Parameters(star_arg=cst.Param(cst.Name("args"), star="*")),
cst.Integer("5"),
whitespace_after_lambda=cst.SimpleWhitespace(""),
),
"lambda*args: 5",
),
(
cst.Lambda(
cst.Parameters(star_kwarg=cst.Param(cst.Name("kwargs"), star="**")),
cst.Integer("5"),
whitespace_after_lambda=cst.SimpleWhitespace(""),
),
"lambda**kwargs: 5",
),
(
cst.Lambda(
cst.Parameters(
star_arg=cst.ParamStar(
comma=cst.Comma(
cst.SimpleWhitespace(""), cst.SimpleWhitespace("")
)
),
kwonly_params=[cst.Param(cst.Name("args"), star="")],
),
cst.Integer("5"),
whitespace_after_lambda=cst.SimpleWhitespace(""),
),
"lambda*,args: 5",
),
)
)
def test_valid(
Expand Down
27 changes: 27 additions & 0 deletions libcst/_nodes/tests/test_simple_comp.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,33 @@ class SimpleCompTest(CSTNodeTest):
"code": "{a for b in c}",
"parser": parse_expression,
},
# non-trivial elt in GeneratorExp
{
"node": cst.GeneratorExp(
cst.BinaryOperation(cst.Name("a1"), cst.Add(), cst.Name("a2")),
cst.CompFor(target=cst.Name("b"), iter=cst.Name("c")),
),
"code": "(a1 + a2 for b in c)",
"parser": parse_expression,
},
# non-trivial elt in ListComp
{
"node": cst.ListComp(
cst.BinaryOperation(cst.Name("a1"), cst.Add(), cst.Name("a2")),
cst.CompFor(target=cst.Name("b"), iter=cst.Name("c")),
),
"code": "[a1 + a2 for b in c]",
"parser": parse_expression,
},
# non-trivial elt in SetComp
{
"node": cst.SetComp(
cst.BinaryOperation(cst.Name("a1"), cst.Add(), cst.Name("a2")),
cst.CompFor(target=cst.Name("b"), iter=cst.Name("c")),
),
"code": "{a1 + a2 for b in c}",
"parser": parse_expression,
},
# async GeneratorExp
{
"node": cst.GeneratorExp(
Expand Down
Loading