Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow booleans to be narrowed to literal types #10389

Merged
merged 8 commits into from
Nov 7, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
map_type_from_supertype, bind_self, erase_to_bound, make_simplified_union,
erase_def_to_union_or_bound, erase_to_union_or_bound, coerce_to_literal,
try_getting_str_literals_from_type, try_getting_int_literals_from_type,
tuple_fallback, is_singleton_type, try_expanding_enum_to_union,
tuple_fallback, is_singleton_type, try_expanding_sum_type_to_union,
true_only, false_only, function_type, get_type_vars, custom_special_method,
is_literal_type_like,
)
Expand Down Expand Up @@ -4513,8 +4513,10 @@ def has_no_custom_eq_checks(t: Type) -> bool:

# Restrict the type of the variable to True-ish/False-ish in the if and else branches
# respectively
vartype = type_map[node]
self._check_for_truthy_type(vartype, node)
original_vartype = type_map[node]
self._check_for_truthy_type(original_vartype, node)
vartype = try_expanding_sum_type_to_union(original_vartype, "builtins.bool")

if_type = true_only(vartype) # type: Type
else_type = false_only(vartype) # type: Type
ref = node # type: Expression
Expand Down Expand Up @@ -4789,10 +4791,11 @@ def refine_identity_comparison_expression(self,
if singleton_index == -1:
singleton_index = possible_target_indices[-1]

enum_name = None
sum_type_name = None
target = get_proper_type(target)
if isinstance(target, LiteralType) and target.is_enum_literal():
enum_name = target.fallback.type.fullname
if (isinstance(target, LiteralType) and
(target.is_enum_literal() or isinstance(target.value, bool))):
sum_type_name = target.fallback.type.fullname

target_type = [TypeRange(target, is_upper_bound=False)]

Expand All @@ -4813,8 +4816,8 @@ def refine_identity_comparison_expression(self,
expr = operands[i]
expr_type = coerce_to_literal(operand_types[i])

if enum_name is not None:
expr_type = try_expanding_enum_to_union(expr_type, enum_name)
if sum_type_name is not None:
expr_type = try_expanding_sum_type_to_union(expr_type, sum_type_name)

# We intentionally use 'conditional_type_map' directly here instead of
# 'self.conditional_type_map_with_intersection': we only compute ad-hoc
Expand Down
18 changes: 11 additions & 7 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,9 @@
FunctionContext, FunctionSigContext,
)
from mypy.typeops import (
tuple_fallback, make_simplified_union, true_only, false_only, erase_to_union_or_bound,
function_type, callable_type, try_getting_str_literals, custom_special_method,
try_expanding_sum_type_to_union, tuple_fallback, make_simplified_union,
true_only, false_only, erase_to_union_or_bound, function_type,
callable_type, try_getting_str_literals, custom_special_method,
is_literal_type_like,
)
import mypy.errorcodes as codes
Expand Down Expand Up @@ -2787,6 +2788,9 @@ def check_boolean_op(self, e: OpExpr, context: Context) -> Type:
# '[1] or []' are inferred correctly.
ctx = self.type_context[-1]
left_type = self.accept(e.left, ctx)
expanded_left_type = try_expanding_sum_type_to_union(
self.accept(e.left, ctx), "builtins.bool"
)

assert e.op in ('and', 'or') # Checked by visit_op_expr

Expand Down Expand Up @@ -2821,7 +2825,7 @@ def check_boolean_op(self, e: OpExpr, context: Context) -> Type:
# to be unreachable and therefore any errors found in the right branch
# should be suppressed.
with (self.msg.disable_errors() if right_map is None else nullcontext()):
right_type = self.analyze_cond_branch(right_map, e.right, left_type)
right_type = self.analyze_cond_branch(right_map, e.right, expanded_left_type)

if right_map is None:
# The boolean expression is statically known to be the left value
Expand All @@ -2833,11 +2837,11 @@ def check_boolean_op(self, e: OpExpr, context: Context) -> Type:
return right_type

if e.op == 'and':
restricted_left_type = false_only(left_type)
result_is_left = not left_type.can_be_true
restricted_left_type = false_only(expanded_left_type)
result_is_left = not expanded_left_type.can_be_true
elif e.op == 'or':
restricted_left_type = true_only(left_type)
result_is_left = not left_type.can_be_false
restricted_left_type = true_only(expanded_left_type)
result_is_left = not expanded_left_type.can_be_false

if isinstance(restricted_left_type, UninhabitedType):
# The left operand can never be the result
Expand Down
57 changes: 33 additions & 24 deletions mypy/typeops.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,7 +699,7 @@ def is_singleton_type(typ: Type) -> bool:
)


def try_expanding_enum_to_union(typ: Type, target_fullname: str) -> ProperType:
def try_expanding_sum_type_to_union(typ: Type, target_fullname: str) -> ProperType:
"""Attempts to recursively expand any enum Instances with the given target_fullname
into a Union of all of its component LiteralTypes.

Expand All @@ -721,28 +721,34 @@ class Status(Enum):
typ = get_proper_type(typ)

if isinstance(typ, UnionType):
items = [try_expanding_enum_to_union(item, target_fullname) for item in typ.items]
items = [try_expanding_sum_type_to_union(item, target_fullname) for item in typ.items]
return make_simplified_union(items, contract_literals=False)
elif isinstance(typ, Instance) and typ.type.is_enum and typ.type.fullname == target_fullname:
new_items = []
for name, symbol in typ.type.names.items():
if not isinstance(symbol.node, Var):
continue
# Skip "_order_" and "__order__", since Enum will remove it
if name in ("_order_", "__order__"):
continue
new_items.append(LiteralType(name, typ))
# SymbolTables are really just dicts, and dicts are guaranteed to preserve
# insertion order only starting with Python 3.7. So, we sort these for older
# versions of Python to help make tests deterministic.
#
# We could probably skip the sort for Python 3.6 since people probably run mypy
# only using CPython, but we might as well for the sake of full correctness.
if sys.version_info < (3, 7):
new_items.sort(key=lambda lit: lit.value)
return make_simplified_union(new_items, contract_literals=False)
else:
return typ
elif isinstance(typ, Instance) and typ.type.fullname == target_fullname:
if typ.type.is_enum:
new_items = []
for name, symbol in typ.type.names.items():
if not isinstance(symbol.node, Var):
continue
# Skip "_order_" and "__order__", since Enum will remove it
if name in ("_order_", "__order__"):
continue
new_items.append(LiteralType(name, typ))
# SymbolTables are really just dicts, and dicts are guaranteed to preserve
# insertion order only starting with Python 3.7. So, we sort these for older
# versions of Python to help make tests deterministic.
#
# We could probably skip the sort for Python 3.6 since people probably run mypy
# only using CPython, but we might as well for the sake of full correctness.
if sys.version_info < (3, 7):
new_items.sort(key=lambda lit: lit.value)
return make_simplified_union(new_items, contract_literals=False)
elif typ.type.fullname == "builtins.bool":
return make_simplified_union(
[LiteralType(True, typ), LiteralType(False, typ)],
contract_literals=False
)

return typ


def try_contracting_literals_in_union(types: Sequence[Type]) -> List[ProperType]:
Expand All @@ -760,9 +766,12 @@ def try_contracting_literals_in_union(types: Sequence[Type]) -> List[ProperType]
for idx, typ in enumerate(proper_types):
if isinstance(typ, LiteralType):
fullname = typ.fallback.type.fullname
if typ.fallback.type.is_enum:
if typ.fallback.type.is_enum or isinstance(typ.value, bool):
if fullname not in sum_types:
sum_types[fullname] = (set(get_enum_values(typ.fallback)), [])
sum_types[fullname] = (set(get_enum_values(typ.fallback))
if typ.fallback.type.is_enum
else set((True, False)),
[])
literals, indexes = sum_types[fullname]
literals.discard(typ.value)
indexes.append(idx)
Expand Down
4 changes: 2 additions & 2 deletions test-data/unit/check-dynamic-typing.test
Original file line number Diff line number Diff line change
Expand Up @@ -159,9 +159,9 @@ a or d
if int():
c = a in d # E: Incompatible types in assignment (expression has type "bool", variable has type "C")
if int():
c = b and d # E: Incompatible types in assignment (expression has type "Union[bool, Any]", variable has type "C")
c = b and d # E: Incompatible types in assignment (expression has type "Union[Literal[False], Any]", variable has type "C")
if int():
c = b or d # E: Incompatible types in assignment (expression has type "Union[bool, Any]", variable has type "C")
c = b or d # E: Incompatible types in assignment (expression has type "Union[Literal[True], Any]", variable has type "C")
if int():
b = a + d
if int():
Expand Down
4 changes: 2 additions & 2 deletions test-data/unit/check-expressions.test
Original file line number Diff line number Diff line change
Expand Up @@ -316,11 +316,11 @@ if int():
if int():
b = b or b
if int():
b = b and a # E: Incompatible types in assignment (expression has type "Union[bool, A]", variable has type "bool")
b = b and a # E: Incompatible types in assignment (expression has type "Union[Literal[False], A]", variable has type "bool")
if int():
b = a and b # E: Incompatible types in assignment (expression has type "Union[A, bool]", variable has type "bool")
if int():
b = b or a # E: Incompatible types in assignment (expression has type "Union[bool, A]", variable has type "bool")
b = b or a # E: Incompatible types in assignment (expression has type "Union[Literal[True], A]", variable has type "bool")
if int():
b = a or b # E: Incompatible types in assignment (expression has type "Union[A, bool]", variable has type "bool")
class A: pass
Expand Down
75 changes: 74 additions & 1 deletion test-data/unit/check-narrowing.test
Original file line number Diff line number Diff line change
Expand Up @@ -1026,8 +1026,81 @@ else:
if str_or_bool_literal is not True and str_or_bool_literal is not False:
reveal_type(str_or_bool_literal) # N: Revealed type is "builtins.str"
else:
reveal_type(str_or_bool_literal) # N: Revealed type is "Union[Literal[False], Literal[True]]"
reveal_type(str_or_bool_literal) # N: Revealed type is "builtins.bool"
[builtins fixtures/primitives.pyi]

[case testNarrowingBooleanIdentityCheck]
# flags: --strict-optional
from typing import Optional
from typing_extensions import Literal

bool_val: bool

if bool_val is not False:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should if bool_val and if not bool_val also do narrowing? These can't be overridden for bool.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, that seems useful, i'll look into it 👍

reveal_type(bool_val) # N: Revealed type is "Literal[True]"
else:
reveal_type(bool_val) # N: Revealed type is "Literal[False]"

opt_bool_val: Optional[bool]

if opt_bool_val is not None:
reveal_type(opt_bool_val) # N: Revealed type is "builtins.bool"

if opt_bool_val is not False:
reveal_type(opt_bool_val) # N: Revealed type is "Union[Literal[True], None]"
else:
reveal_type(opt_bool_val) # N: Revealed type is "Literal[False]"
[builtins fixtures/primitives.pyi]

[case testNarrowingBooleanTruthiness]
# flags: --strict-optional
from typing import Optional
from typing_extensions import Literal

bool_val: bool

if bool_val:
reveal_type(bool_val) # N: Revealed type is "Literal[True]"
else:
reveal_type(bool_val) # N: Revealed type is "Literal[False]"
reveal_type(bool_val) # N: Revealed type is "builtins.bool"

opt_bool_val: Optional[bool]

if opt_bool_val:
reveal_type(opt_bool_val) # N: Revealed type is "Literal[True]"
else:
reveal_type(opt_bool_val) # N: Revealed type is "Union[Literal[False], None]"
reveal_type(opt_bool_val) # N: Revealed type is "Union[builtins.bool, None]"
[builtins fixtures/primitives.pyi]

[case testNarrowingBooleanBoolOp]
# flags: --strict-optional
from typing import Optional
from typing_extensions import Literal

bool_a: bool
bool_b: bool

if bool_a and bool_b:
reveal_type(bool_a) # N: Revealed type is "Literal[True]"
reveal_type(bool_b) # N: Revealed type is "Literal[True]"
else:
reveal_type(bool_a) # N: Revealed type is "builtins.bool"
reveal_type(bool_b) # N: Revealed type is "builtins.bool"

if not bool_a or bool_b:
reveal_type(bool_a) # N: Revealed type is "builtins.bool"
reveal_type(bool_b) # N: Revealed type is "builtins.bool"
else:
reveal_type(bool_a) # N: Revealed type is "Literal[True]"
reveal_type(bool_b) # N: Revealed type is "Literal[False]"

if True and bool_b:
reveal_type(bool_b) # N: Revealed type is "Literal[True]"

x = True and bool_b
reveal_type(x) # N: Revealed type is "builtins.bool"
[builtins fixtures/primitives.pyi]

[case testNarrowingTypedDictUsingEnumLiteral]
Expand Down
1 change: 1 addition & 0 deletions test-data/unit/check-newsemanal.test
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,7 @@ from a import x
def f(): pass

[targets a, b, a, a.y, b.f, __main__]
[builtins fixtures/tuple.pyi]

[case testNewAnalyzerRedefinitionAndDeferral1b]
import a
Expand Down
4 changes: 2 additions & 2 deletions test-data/unit/check-python38.test
Original file line number Diff line number Diff line change
Expand Up @@ -411,10 +411,10 @@ from typing import Optional
maybe_str: Optional[str]

if (is_str := maybe_str is not None):
reveal_type(is_str) # N: Revealed type is "builtins.bool"
reveal_type(is_str) # N: Revealed type is "Literal[True]"
reveal_type(maybe_str) # N: Revealed type is "builtins.str"
else:
reveal_type(is_str) # N: Revealed type is "builtins.bool"
reveal_type(is_str) # N: Revealed type is "Literal[False]"
reveal_type(maybe_str) # N: Revealed type is "None"

reveal_type(maybe_str) # N: Revealed type is "Union[builtins.str, None]"
Expand Down
8 changes: 4 additions & 4 deletions test-data/unit/check-unreachable-code.test
Original file line number Diff line number Diff line change
Expand Up @@ -533,11 +533,11 @@ f = (PY3 or PY2) and 's'
g = (PY2 or PY3) or 's'
h = (PY3 or PY2) or 's'
reveal_type(a) # N: Revealed type is "builtins.bool"
reveal_type(b) # N: Revealed type is "builtins.str"
reveal_type(c) # N: Revealed type is "builtins.str"
reveal_type(b) # N: Revealed type is "Literal['s']"
reveal_type(c) # N: Revealed type is "Literal['s']"
reveal_type(d) # N: Revealed type is "builtins.bool"
reveal_type(e) # N: Revealed type is "builtins.str"
reveal_type(f) # N: Revealed type is "builtins.str"
reveal_type(e) # N: Revealed type is "Literal['s']"
reveal_type(f) # N: Revealed type is "Literal['s']"
reveal_type(g) # N: Revealed type is "builtins.bool"
reveal_type(h) # N: Revealed type is "builtins.bool"
[builtins fixtures/ops.pyi]
Expand Down
4 changes: 2 additions & 2 deletions test-data/unit/typexport-basic.test
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ elif not a:
[out]
NameExpr(3) : builtins.bool
IntExpr(4) : Literal[1]?
NameExpr(5) : builtins.bool
NameExpr(5) : Literal[False]
UnaryExpr(5) : builtins.bool
IntExpr(6) : Literal[1]?

Expand All @@ -259,7 +259,7 @@ while a:
[builtins fixtures/bool.pyi]
[out]
NameExpr(3) : builtins.bool
NameExpr(4) : builtins.bool
NameExpr(4) : Literal[True]


-- Simple type inference
Expand Down