Skip to content

Commit

Permalink
Allow booleans to be narrowed to literal types (#10389)
Browse files Browse the repository at this point in the history
  • Loading branch information
ethan-leba committed Nov 7, 2021
1 parent 871ec6b commit d41e34a
Show file tree
Hide file tree
Showing 10 changed files with 142 additions and 52 deletions.
19 changes: 11 additions & 8 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
map_type_from_supertype, bind_self, erase_to_bound, make_simplified_union,
erase_def_to_union_or_bound, erase_to_union_or_bound, coerce_to_literal,
try_getting_str_literals_from_type, try_getting_int_literals_from_type,
tuple_fallback, is_singleton_type, try_expanding_enum_to_union,
tuple_fallback, is_singleton_type, try_expanding_sum_type_to_union,
true_only, false_only, function_type, get_type_vars, custom_special_method,
is_literal_type_like,
)
Expand Down Expand Up @@ -4583,8 +4583,10 @@ def has_no_custom_eq_checks(t: Type) -> bool:

# Restrict the type of the variable to True-ish/False-ish in the if and else branches
# respectively
vartype = type_map[node]
self._check_for_truthy_type(vartype, node)
original_vartype = type_map[node]
self._check_for_truthy_type(original_vartype, node)
vartype = try_expanding_sum_type_to_union(original_vartype, "builtins.bool")

if_type = true_only(vartype) # type: Type
else_type = false_only(vartype) # type: Type
ref = node # type: Expression
Expand Down Expand Up @@ -4857,10 +4859,11 @@ def refine_identity_comparison_expression(self,
if singleton_index == -1:
singleton_index = possible_target_indices[-1]

enum_name = None
sum_type_name = None
target = get_proper_type(target)
if isinstance(target, LiteralType) and target.is_enum_literal():
enum_name = target.fallback.type.fullname
if (isinstance(target, LiteralType) and
(target.is_enum_literal() or isinstance(target.value, bool))):
sum_type_name = target.fallback.type.fullname

target_type = [TypeRange(target, is_upper_bound=False)]

Expand All @@ -4881,8 +4884,8 @@ def refine_identity_comparison_expression(self,
expr = operands[i]
expr_type = coerce_to_literal(operand_types[i])

if enum_name is not None:
expr_type = try_expanding_enum_to_union(expr_type, enum_name)
if sum_type_name is not None:
expr_type = try_expanding_sum_type_to_union(expr_type, sum_type_name)

# We intentionally use 'conditional_type_map' directly here instead of
# 'self.conditional_type_map_with_intersection': we only compute ad-hoc
Expand Down
18 changes: 11 additions & 7 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,9 @@
FunctionContext, FunctionSigContext,
)
from mypy.typeops import (
tuple_fallback, make_simplified_union, true_only, false_only, erase_to_union_or_bound,
function_type, callable_type, try_getting_str_literals, custom_special_method,
try_expanding_sum_type_to_union, tuple_fallback, make_simplified_union,
true_only, false_only, erase_to_union_or_bound, function_type,
callable_type, try_getting_str_literals, custom_special_method,
is_literal_type_like,
)
import mypy.errorcodes as codes
Expand Down Expand Up @@ -2800,6 +2801,9 @@ def check_boolean_op(self, e: OpExpr, context: Context) -> Type:
# '[1] or []' are inferred correctly.
ctx = self.type_context[-1]
left_type = self.accept(e.left, ctx)
expanded_left_type = try_expanding_sum_type_to_union(
self.accept(e.left, ctx), "builtins.bool"
)

assert e.op in ('and', 'or') # Checked by visit_op_expr

Expand Down Expand Up @@ -2834,7 +2838,7 @@ def check_boolean_op(self, e: OpExpr, context: Context) -> Type:
# to be unreachable and therefore any errors found in the right branch
# should be suppressed.
with (self.msg.disable_errors() if right_map is None else nullcontext()):
right_type = self.analyze_cond_branch(right_map, e.right, left_type)
right_type = self.analyze_cond_branch(right_map, e.right, expanded_left_type)

if right_map is None:
# The boolean expression is statically known to be the left value
Expand All @@ -2846,11 +2850,11 @@ def check_boolean_op(self, e: OpExpr, context: Context) -> Type:
return right_type

if e.op == 'and':
restricted_left_type = false_only(left_type)
result_is_left = not left_type.can_be_true
restricted_left_type = false_only(expanded_left_type)
result_is_left = not expanded_left_type.can_be_true
elif e.op == 'or':
restricted_left_type = true_only(left_type)
result_is_left = not left_type.can_be_false
restricted_left_type = true_only(expanded_left_type)
result_is_left = not expanded_left_type.can_be_false

if isinstance(restricted_left_type, UninhabitedType):
# The left operand can never be the result
Expand Down
57 changes: 33 additions & 24 deletions mypy/typeops.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,7 +701,7 @@ def is_singleton_type(typ: Type) -> bool:
)


def try_expanding_enum_to_union(typ: Type, target_fullname: str) -> ProperType:
def try_expanding_sum_type_to_union(typ: Type, target_fullname: str) -> ProperType:
"""Attempts to recursively expand any enum Instances with the given target_fullname
into a Union of all of its component LiteralTypes.
Expand All @@ -723,28 +723,34 @@ class Status(Enum):
typ = get_proper_type(typ)

if isinstance(typ, UnionType):
items = [try_expanding_enum_to_union(item, target_fullname) for item in typ.items]
items = [try_expanding_sum_type_to_union(item, target_fullname) for item in typ.items]
return make_simplified_union(items, contract_literals=False)
elif isinstance(typ, Instance) and typ.type.is_enum and typ.type.fullname == target_fullname:
new_items = []
for name, symbol in typ.type.names.items():
if not isinstance(symbol.node, Var):
continue
# Skip "_order_" and "__order__", since Enum will remove it
if name in ("_order_", "__order__"):
continue
new_items.append(LiteralType(name, typ))
# SymbolTables are really just dicts, and dicts are guaranteed to preserve
# insertion order only starting with Python 3.7. So, we sort these for older
# versions of Python to help make tests deterministic.
#
# We could probably skip the sort for Python 3.6 since people probably run mypy
# only using CPython, but we might as well for the sake of full correctness.
if sys.version_info < (3, 7):
new_items.sort(key=lambda lit: lit.value)
return make_simplified_union(new_items, contract_literals=False)
else:
return typ
elif isinstance(typ, Instance) and typ.type.fullname == target_fullname:
if typ.type.is_enum:
new_items = []
for name, symbol in typ.type.names.items():
if not isinstance(symbol.node, Var):
continue
# Skip "_order_" and "__order__", since Enum will remove it
if name in ("_order_", "__order__"):
continue
new_items.append(LiteralType(name, typ))
# SymbolTables are really just dicts, and dicts are guaranteed to preserve
# insertion order only starting with Python 3.7. So, we sort these for older
# versions of Python to help make tests deterministic.
#
# We could probably skip the sort for Python 3.6 since people probably run mypy
# only using CPython, but we might as well for the sake of full correctness.
if sys.version_info < (3, 7):
new_items.sort(key=lambda lit: lit.value)
return make_simplified_union(new_items, contract_literals=False)
elif typ.type.fullname == "builtins.bool":
return make_simplified_union(
[LiteralType(True, typ), LiteralType(False, typ)],
contract_literals=False
)

return typ


def try_contracting_literals_in_union(types: Sequence[Type]) -> List[ProperType]:
Expand All @@ -762,9 +768,12 @@ def try_contracting_literals_in_union(types: Sequence[Type]) -> List[ProperType]
for idx, typ in enumerate(proper_types):
if isinstance(typ, LiteralType):
fullname = typ.fallback.type.fullname
if typ.fallback.type.is_enum:
if typ.fallback.type.is_enum or isinstance(typ.value, bool):
if fullname not in sum_types:
sum_types[fullname] = (set(get_enum_values(typ.fallback)), [])
sum_types[fullname] = (set(get_enum_values(typ.fallback))
if typ.fallback.type.is_enum
else set((True, False)),
[])
literals, indexes = sum_types[fullname]
literals.discard(typ.value)
indexes.append(idx)
Expand Down
4 changes: 2 additions & 2 deletions test-data/unit/check-dynamic-typing.test
Original file line number Diff line number Diff line change
Expand Up @@ -159,9 +159,9 @@ a or d
if int():
c = a in d # E: Incompatible types in assignment (expression has type "bool", variable has type "C")
if int():
c = b and d # E: Incompatible types in assignment (expression has type "Union[bool, Any]", variable has type "C")
c = b and d # E: Incompatible types in assignment (expression has type "Union[Literal[False], Any]", variable has type "C")
if int():
c = b or d # E: Incompatible types in assignment (expression has type "Union[bool, Any]", variable has type "C")
c = b or d # E: Incompatible types in assignment (expression has type "Union[Literal[True], Any]", variable has type "C")
if int():
b = a + d
if int():
Expand Down
4 changes: 2 additions & 2 deletions test-data/unit/check-expressions.test
Original file line number Diff line number Diff line change
Expand Up @@ -323,11 +323,11 @@ if int():
if int():
b = b or b
if int():
b = b and a # E: Incompatible types in assignment (expression has type "Union[bool, A]", variable has type "bool")
b = b and a # E: Incompatible types in assignment (expression has type "Union[Literal[False], A]", variable has type "bool")
if int():
b = a and b # E: Incompatible types in assignment (expression has type "Union[A, bool]", variable has type "bool")
if int():
b = b or a # E: Incompatible types in assignment (expression has type "Union[bool, A]", variable has type "bool")
b = b or a # E: Incompatible types in assignment (expression has type "Union[Literal[True], A]", variable has type "bool")
if int():
b = a or b # E: Incompatible types in assignment (expression has type "Union[A, bool]", variable has type "bool")
class A: pass
Expand Down
75 changes: 74 additions & 1 deletion test-data/unit/check-narrowing.test
Original file line number Diff line number Diff line change
Expand Up @@ -1047,8 +1047,81 @@ else:
if str_or_bool_literal is not True and str_or_bool_literal is not False:
reveal_type(str_or_bool_literal) # N: Revealed type is "builtins.str"
else:
reveal_type(str_or_bool_literal) # N: Revealed type is "Union[Literal[False], Literal[True]]"
reveal_type(str_or_bool_literal) # N: Revealed type is "builtins.bool"
[builtins fixtures/primitives.pyi]

[case testNarrowingBooleanIdentityCheck]
# flags: --strict-optional
from typing import Optional
from typing_extensions import Literal

bool_val: bool

if bool_val is not False:
reveal_type(bool_val) # N: Revealed type is "Literal[True]"
else:
reveal_type(bool_val) # N: Revealed type is "Literal[False]"

opt_bool_val: Optional[bool]

if opt_bool_val is not None:
reveal_type(opt_bool_val) # N: Revealed type is "builtins.bool"

if opt_bool_val is not False:
reveal_type(opt_bool_val) # N: Revealed type is "Union[Literal[True], None]"
else:
reveal_type(opt_bool_val) # N: Revealed type is "Literal[False]"
[builtins fixtures/primitives.pyi]

[case testNarrowingBooleanTruthiness]
# flags: --strict-optional
from typing import Optional
from typing_extensions import Literal

bool_val: bool

if bool_val:
reveal_type(bool_val) # N: Revealed type is "Literal[True]"
else:
reveal_type(bool_val) # N: Revealed type is "Literal[False]"
reveal_type(bool_val) # N: Revealed type is "builtins.bool"

opt_bool_val: Optional[bool]

if opt_bool_val:
reveal_type(opt_bool_val) # N: Revealed type is "Literal[True]"
else:
reveal_type(opt_bool_val) # N: Revealed type is "Union[Literal[False], None]"
reveal_type(opt_bool_val) # N: Revealed type is "Union[builtins.bool, None]"
[builtins fixtures/primitives.pyi]

[case testNarrowingBooleanBoolOp]
# flags: --strict-optional
from typing import Optional
from typing_extensions import Literal

bool_a: bool
bool_b: bool

if bool_a and bool_b:
reveal_type(bool_a) # N: Revealed type is "Literal[True]"
reveal_type(bool_b) # N: Revealed type is "Literal[True]"
else:
reveal_type(bool_a) # N: Revealed type is "builtins.bool"
reveal_type(bool_b) # N: Revealed type is "builtins.bool"

if not bool_a or bool_b:
reveal_type(bool_a) # N: Revealed type is "builtins.bool"
reveal_type(bool_b) # N: Revealed type is "builtins.bool"
else:
reveal_type(bool_a) # N: Revealed type is "Literal[True]"
reveal_type(bool_b) # N: Revealed type is "Literal[False]"

if True and bool_b:
reveal_type(bool_b) # N: Revealed type is "Literal[True]"

x = True and bool_b
reveal_type(x) # N: Revealed type is "builtins.bool"
[builtins fixtures/primitives.pyi]

[case testNarrowingTypedDictUsingEnumLiteral]
Expand Down
1 change: 1 addition & 0 deletions test-data/unit/check-newsemanal.test
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,7 @@ from a import x
def f(): pass

[targets a, b, a, a.y, b.f, __main__]
[builtins fixtures/tuple.pyi]

[case testNewAnalyzerRedefinitionAndDeferral1b]
import a
Expand Down
4 changes: 2 additions & 2 deletions test-data/unit/check-python38.test
Original file line number Diff line number Diff line change
Expand Up @@ -411,10 +411,10 @@ from typing import Optional
maybe_str: Optional[str]

if (is_str := maybe_str is not None):
reveal_type(is_str) # N: Revealed type is "builtins.bool"
reveal_type(is_str) # N: Revealed type is "Literal[True]"
reveal_type(maybe_str) # N: Revealed type is "builtins.str"
else:
reveal_type(is_str) # N: Revealed type is "builtins.bool"
reveal_type(is_str) # N: Revealed type is "Literal[False]"
reveal_type(maybe_str) # N: Revealed type is "None"

reveal_type(maybe_str) # N: Revealed type is "Union[builtins.str, None]"
Expand Down
8 changes: 4 additions & 4 deletions test-data/unit/check-unreachable-code.test
Original file line number Diff line number Diff line change
Expand Up @@ -533,11 +533,11 @@ f = (PY3 or PY2) and 's'
g = (PY2 or PY3) or 's'
h = (PY3 or PY2) or 's'
reveal_type(a) # N: Revealed type is "builtins.bool"
reveal_type(b) # N: Revealed type is "builtins.str"
reveal_type(c) # N: Revealed type is "builtins.str"
reveal_type(b) # N: Revealed type is "Literal['s']"
reveal_type(c) # N: Revealed type is "Literal['s']"
reveal_type(d) # N: Revealed type is "builtins.bool"
reveal_type(e) # N: Revealed type is "builtins.str"
reveal_type(f) # N: Revealed type is "builtins.str"
reveal_type(e) # N: Revealed type is "Literal['s']"
reveal_type(f) # N: Revealed type is "Literal['s']"
reveal_type(g) # N: Revealed type is "builtins.bool"
reveal_type(h) # N: Revealed type is "builtins.bool"
[builtins fixtures/ops.pyi]
Expand Down
4 changes: 2 additions & 2 deletions test-data/unit/typexport-basic.test
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ elif not a:
[out]
NameExpr(3) : builtins.bool
IntExpr(4) : Literal[1]?
NameExpr(5) : builtins.bool
NameExpr(5) : Literal[False]
UnaryExpr(5) : builtins.bool
IntExpr(6) : Literal[1]?

Expand All @@ -259,7 +259,7 @@ while a:
[builtins fixtures/bool.pyi]
[out]
NameExpr(3) : builtins.bool
NameExpr(4) : builtins.bool
NameExpr(4) : Literal[True]


-- Simple type inference
Expand Down

0 comments on commit d41e34a

Please sign in to comment.