From d41e34a0c518c181b341390b5d9097c0e5b696b7 Mon Sep 17 00:00:00 2001 From: Ethan Leba Date: Sun, 7 Nov 2021 15:35:06 -0500 Subject: [PATCH] Allow booleans to be narrowed to literal types (#10389) --- mypy/checker.py | 19 +++--- mypy/checkexpr.py | 18 ++++-- mypy/typeops.py | 57 +++++++++------- test-data/unit/check-dynamic-typing.test | 4 +- test-data/unit/check-expressions.test | 4 +- test-data/unit/check-narrowing.test | 75 +++++++++++++++++++++- test-data/unit/check-newsemanal.test | 1 + test-data/unit/check-python38.test | 4 +- test-data/unit/check-unreachable-code.test | 8 +-- test-data/unit/typexport-basic.test | 4 +- 10 files changed, 142 insertions(+), 52 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 976eee749653..d4548f581a28 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -51,7 +51,7 @@ map_type_from_supertype, bind_self, erase_to_bound, make_simplified_union, erase_def_to_union_or_bound, erase_to_union_or_bound, coerce_to_literal, try_getting_str_literals_from_type, try_getting_int_literals_from_type, - tuple_fallback, is_singleton_type, try_expanding_enum_to_union, + tuple_fallback, is_singleton_type, try_expanding_sum_type_to_union, true_only, false_only, function_type, get_type_vars, custom_special_method, is_literal_type_like, ) @@ -4583,8 +4583,10 @@ def has_no_custom_eq_checks(t: Type) -> bool: # Restrict the type of the variable to True-ish/False-ish in the if and else branches # respectively - vartype = type_map[node] - self._check_for_truthy_type(vartype, node) + original_vartype = type_map[node] + self._check_for_truthy_type(original_vartype, node) + vartype = try_expanding_sum_type_to_union(original_vartype, "builtins.bool") + if_type = true_only(vartype) # type: Type else_type = false_only(vartype) # type: Type ref = node # type: Expression @@ -4857,10 +4859,11 @@ def refine_identity_comparison_expression(self, if singleton_index == -1: singleton_index = possible_target_indices[-1] - enum_name = None + sum_type_name = None target = get_proper_type(target) - if isinstance(target, LiteralType) and target.is_enum_literal(): - enum_name = target.fallback.type.fullname + if (isinstance(target, LiteralType) and + (target.is_enum_literal() or isinstance(target.value, bool))): + sum_type_name = target.fallback.type.fullname target_type = [TypeRange(target, is_upper_bound=False)] @@ -4881,8 +4884,8 @@ def refine_identity_comparison_expression(self, expr = operands[i] expr_type = coerce_to_literal(operand_types[i]) - if enum_name is not None: - expr_type = try_expanding_enum_to_union(expr_type, enum_name) + if sum_type_name is not None: + expr_type = try_expanding_sum_type_to_union(expr_type, sum_type_name) # We intentionally use 'conditional_type_map' directly here instead of # 'self.conditional_type_map_with_intersection': we only compute ad-hoc diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 51ab7bcdf1f3..7625ee0b64c2 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -66,8 +66,9 @@ FunctionContext, FunctionSigContext, ) from mypy.typeops import ( - tuple_fallback, make_simplified_union, true_only, false_only, erase_to_union_or_bound, - function_type, callable_type, try_getting_str_literals, custom_special_method, + try_expanding_sum_type_to_union, tuple_fallback, make_simplified_union, + true_only, false_only, erase_to_union_or_bound, function_type, + callable_type, try_getting_str_literals, custom_special_method, is_literal_type_like, ) import mypy.errorcodes as codes @@ -2800,6 +2801,9 @@ def check_boolean_op(self, e: OpExpr, context: Context) -> Type: # '[1] or []' are inferred correctly. ctx = self.type_context[-1] left_type = self.accept(e.left, ctx) + expanded_left_type = try_expanding_sum_type_to_union( + self.accept(e.left, ctx), "builtins.bool" + ) assert e.op in ('and', 'or') # Checked by visit_op_expr @@ -2834,7 +2838,7 @@ def check_boolean_op(self, e: OpExpr, context: Context) -> Type: # to be unreachable and therefore any errors found in the right branch # should be suppressed. with (self.msg.disable_errors() if right_map is None else nullcontext()): - right_type = self.analyze_cond_branch(right_map, e.right, left_type) + right_type = self.analyze_cond_branch(right_map, e.right, expanded_left_type) if right_map is None: # The boolean expression is statically known to be the left value @@ -2846,11 +2850,11 @@ def check_boolean_op(self, e: OpExpr, context: Context) -> Type: return right_type if e.op == 'and': - restricted_left_type = false_only(left_type) - result_is_left = not left_type.can_be_true + restricted_left_type = false_only(expanded_left_type) + result_is_left = not expanded_left_type.can_be_true elif e.op == 'or': - restricted_left_type = true_only(left_type) - result_is_left = not left_type.can_be_false + restricted_left_type = true_only(expanded_left_type) + result_is_left = not expanded_left_type.can_be_false if isinstance(restricted_left_type, UninhabitedType): # The left operand can never be the result diff --git a/mypy/typeops.py b/mypy/typeops.py index 6adf21fd84e9..af758852a929 100644 --- a/mypy/typeops.py +++ b/mypy/typeops.py @@ -701,7 +701,7 @@ def is_singleton_type(typ: Type) -> bool: ) -def try_expanding_enum_to_union(typ: Type, target_fullname: str) -> ProperType: +def try_expanding_sum_type_to_union(typ: Type, target_fullname: str) -> ProperType: """Attempts to recursively expand any enum Instances with the given target_fullname into a Union of all of its component LiteralTypes. @@ -723,28 +723,34 @@ class Status(Enum): typ = get_proper_type(typ) if isinstance(typ, UnionType): - items = [try_expanding_enum_to_union(item, target_fullname) for item in typ.items] + items = [try_expanding_sum_type_to_union(item, target_fullname) for item in typ.items] return make_simplified_union(items, contract_literals=False) - elif isinstance(typ, Instance) and typ.type.is_enum and typ.type.fullname == target_fullname: - new_items = [] - for name, symbol in typ.type.names.items(): - if not isinstance(symbol.node, Var): - continue - # Skip "_order_" and "__order__", since Enum will remove it - if name in ("_order_", "__order__"): - continue - new_items.append(LiteralType(name, typ)) - # SymbolTables are really just dicts, and dicts are guaranteed to preserve - # insertion order only starting with Python 3.7. So, we sort these for older - # versions of Python to help make tests deterministic. - # - # We could probably skip the sort for Python 3.6 since people probably run mypy - # only using CPython, but we might as well for the sake of full correctness. - if sys.version_info < (3, 7): - new_items.sort(key=lambda lit: lit.value) - return make_simplified_union(new_items, contract_literals=False) - else: - return typ + elif isinstance(typ, Instance) and typ.type.fullname == target_fullname: + if typ.type.is_enum: + new_items = [] + for name, symbol in typ.type.names.items(): + if not isinstance(symbol.node, Var): + continue + # Skip "_order_" and "__order__", since Enum will remove it + if name in ("_order_", "__order__"): + continue + new_items.append(LiteralType(name, typ)) + # SymbolTables are really just dicts, and dicts are guaranteed to preserve + # insertion order only starting with Python 3.7. So, we sort these for older + # versions of Python to help make tests deterministic. + # + # We could probably skip the sort for Python 3.6 since people probably run mypy + # only using CPython, but we might as well for the sake of full correctness. + if sys.version_info < (3, 7): + new_items.sort(key=lambda lit: lit.value) + return make_simplified_union(new_items, contract_literals=False) + elif typ.type.fullname == "builtins.bool": + return make_simplified_union( + [LiteralType(True, typ), LiteralType(False, typ)], + contract_literals=False + ) + + return typ def try_contracting_literals_in_union(types: Sequence[Type]) -> List[ProperType]: @@ -762,9 +768,12 @@ def try_contracting_literals_in_union(types: Sequence[Type]) -> List[ProperType] for idx, typ in enumerate(proper_types): if isinstance(typ, LiteralType): fullname = typ.fallback.type.fullname - if typ.fallback.type.is_enum: + if typ.fallback.type.is_enum or isinstance(typ.value, bool): if fullname not in sum_types: - sum_types[fullname] = (set(get_enum_values(typ.fallback)), []) + sum_types[fullname] = (set(get_enum_values(typ.fallback)) + if typ.fallback.type.is_enum + else set((True, False)), + []) literals, indexes = sum_types[fullname] literals.discard(typ.value) indexes.append(idx) diff --git a/test-data/unit/check-dynamic-typing.test b/test-data/unit/check-dynamic-typing.test index 8ff82819c104..7b016c342e95 100644 --- a/test-data/unit/check-dynamic-typing.test +++ b/test-data/unit/check-dynamic-typing.test @@ -159,9 +159,9 @@ a or d if int(): c = a in d # E: Incompatible types in assignment (expression has type "bool", variable has type "C") if int(): - c = b and d # E: Incompatible types in assignment (expression has type "Union[bool, Any]", variable has type "C") + c = b and d # E: Incompatible types in assignment (expression has type "Union[Literal[False], Any]", variable has type "C") if int(): - c = b or d # E: Incompatible types in assignment (expression has type "Union[bool, Any]", variable has type "C") + c = b or d # E: Incompatible types in assignment (expression has type "Union[Literal[True], Any]", variable has type "C") if int(): b = a + d if int(): diff --git a/test-data/unit/check-expressions.test b/test-data/unit/check-expressions.test index 818ecc6628b7..70fda50d617e 100644 --- a/test-data/unit/check-expressions.test +++ b/test-data/unit/check-expressions.test @@ -323,11 +323,11 @@ if int(): if int(): b = b or b if int(): - b = b and a # E: Incompatible types in assignment (expression has type "Union[bool, A]", variable has type "bool") + b = b and a # E: Incompatible types in assignment (expression has type "Union[Literal[False], A]", variable has type "bool") if int(): b = a and b # E: Incompatible types in assignment (expression has type "Union[A, bool]", variable has type "bool") if int(): - b = b or a # E: Incompatible types in assignment (expression has type "Union[bool, A]", variable has type "bool") + b = b or a # E: Incompatible types in assignment (expression has type "Union[Literal[True], A]", variable has type "bool") if int(): b = a or b # E: Incompatible types in assignment (expression has type "Union[A, bool]", variable has type "bool") class A: pass diff --git a/test-data/unit/check-narrowing.test b/test-data/unit/check-narrowing.test index 6b030ae49a86..d6b25ef456d9 100644 --- a/test-data/unit/check-narrowing.test +++ b/test-data/unit/check-narrowing.test @@ -1047,8 +1047,81 @@ else: if str_or_bool_literal is not True and str_or_bool_literal is not False: reveal_type(str_or_bool_literal) # N: Revealed type is "builtins.str" else: - reveal_type(str_or_bool_literal) # N: Revealed type is "Union[Literal[False], Literal[True]]" + reveal_type(str_or_bool_literal) # N: Revealed type is "builtins.bool" +[builtins fixtures/primitives.pyi] + +[case testNarrowingBooleanIdentityCheck] +# flags: --strict-optional +from typing import Optional +from typing_extensions import Literal + +bool_val: bool + +if bool_val is not False: + reveal_type(bool_val) # N: Revealed type is "Literal[True]" +else: + reveal_type(bool_val) # N: Revealed type is "Literal[False]" + +opt_bool_val: Optional[bool] + +if opt_bool_val is not None: + reveal_type(opt_bool_val) # N: Revealed type is "builtins.bool" + +if opt_bool_val is not False: + reveal_type(opt_bool_val) # N: Revealed type is "Union[Literal[True], None]" +else: + reveal_type(opt_bool_val) # N: Revealed type is "Literal[False]" +[builtins fixtures/primitives.pyi] + +[case testNarrowingBooleanTruthiness] +# flags: --strict-optional +from typing import Optional +from typing_extensions import Literal + +bool_val: bool + +if bool_val: + reveal_type(bool_val) # N: Revealed type is "Literal[True]" +else: + reveal_type(bool_val) # N: Revealed type is "Literal[False]" +reveal_type(bool_val) # N: Revealed type is "builtins.bool" + +opt_bool_val: Optional[bool] + +if opt_bool_val: + reveal_type(opt_bool_val) # N: Revealed type is "Literal[True]" +else: + reveal_type(opt_bool_val) # N: Revealed type is "Union[Literal[False], None]" +reveal_type(opt_bool_val) # N: Revealed type is "Union[builtins.bool, None]" +[builtins fixtures/primitives.pyi] + +[case testNarrowingBooleanBoolOp] +# flags: --strict-optional +from typing import Optional +from typing_extensions import Literal + +bool_a: bool +bool_b: bool + +if bool_a and bool_b: + reveal_type(bool_a) # N: Revealed type is "Literal[True]" + reveal_type(bool_b) # N: Revealed type is "Literal[True]" +else: + reveal_type(bool_a) # N: Revealed type is "builtins.bool" + reveal_type(bool_b) # N: Revealed type is "builtins.bool" + +if not bool_a or bool_b: + reveal_type(bool_a) # N: Revealed type is "builtins.bool" + reveal_type(bool_b) # N: Revealed type is "builtins.bool" +else: + reveal_type(bool_a) # N: Revealed type is "Literal[True]" + reveal_type(bool_b) # N: Revealed type is "Literal[False]" + +if True and bool_b: + reveal_type(bool_b) # N: Revealed type is "Literal[True]" +x = True and bool_b +reveal_type(x) # N: Revealed type is "builtins.bool" [builtins fixtures/primitives.pyi] [case testNarrowingTypedDictUsingEnumLiteral] diff --git a/test-data/unit/check-newsemanal.test b/test-data/unit/check-newsemanal.test index ed999d1f46b6..d44d8ad0348e 100644 --- a/test-data/unit/check-newsemanal.test +++ b/test-data/unit/check-newsemanal.test @@ -304,6 +304,7 @@ from a import x def f(): pass [targets a, b, a, a.y, b.f, __main__] +[builtins fixtures/tuple.pyi] [case testNewAnalyzerRedefinitionAndDeferral1b] import a diff --git a/test-data/unit/check-python38.test b/test-data/unit/check-python38.test index c7289593e810..b72beb400f35 100644 --- a/test-data/unit/check-python38.test +++ b/test-data/unit/check-python38.test @@ -411,10 +411,10 @@ from typing import Optional maybe_str: Optional[str] if (is_str := maybe_str is not None): - reveal_type(is_str) # N: Revealed type is "builtins.bool" + reveal_type(is_str) # N: Revealed type is "Literal[True]" reveal_type(maybe_str) # N: Revealed type is "builtins.str" else: - reveal_type(is_str) # N: Revealed type is "builtins.bool" + reveal_type(is_str) # N: Revealed type is "Literal[False]" reveal_type(maybe_str) # N: Revealed type is "None" reveal_type(maybe_str) # N: Revealed type is "Union[builtins.str, None]" diff --git a/test-data/unit/check-unreachable-code.test b/test-data/unit/check-unreachable-code.test index 2a55db1a33c6..832bb5737a37 100644 --- a/test-data/unit/check-unreachable-code.test +++ b/test-data/unit/check-unreachable-code.test @@ -533,11 +533,11 @@ f = (PY3 or PY2) and 's' g = (PY2 or PY3) or 's' h = (PY3 or PY2) or 's' reveal_type(a) # N: Revealed type is "builtins.bool" -reveal_type(b) # N: Revealed type is "builtins.str" -reveal_type(c) # N: Revealed type is "builtins.str" +reveal_type(b) # N: Revealed type is "Literal['s']" +reveal_type(c) # N: Revealed type is "Literal['s']" reveal_type(d) # N: Revealed type is "builtins.bool" -reveal_type(e) # N: Revealed type is "builtins.str" -reveal_type(f) # N: Revealed type is "builtins.str" +reveal_type(e) # N: Revealed type is "Literal['s']" +reveal_type(f) # N: Revealed type is "Literal['s']" reveal_type(g) # N: Revealed type is "builtins.bool" reveal_type(h) # N: Revealed type is "builtins.bool" [builtins fixtures/ops.pyi] diff --git a/test-data/unit/typexport-basic.test b/test-data/unit/typexport-basic.test index deb43f6d316f..4f40117d18d2 100644 --- a/test-data/unit/typexport-basic.test +++ b/test-data/unit/typexport-basic.test @@ -247,7 +247,7 @@ elif not a: [out] NameExpr(3) : builtins.bool IntExpr(4) : Literal[1]? -NameExpr(5) : builtins.bool +NameExpr(5) : Literal[False] UnaryExpr(5) : builtins.bool IntExpr(6) : Literal[1]? @@ -259,7 +259,7 @@ while a: [builtins fixtures/bool.pyi] [out] NameExpr(3) : builtins.bool -NameExpr(4) : builtins.bool +NameExpr(4) : Literal[True] -- Simple type inference