Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add introspection helper functions for third-party libraries #203

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 87 additions & 0 deletions src/test_typing_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from typing_extensions import clear_overloads, get_overloads, overload
from typing_extensions import NamedTuple
from typing_extensions import override, deprecated, Buffer, TypeAliasType, TypeVar
from typing_extensions import get_typing_objects_by_name_of, is_typing_name
from _typed_dict_test_helper import Foo, FooGeneric

# Flags used to mark tests that only apply after a specific
Expand Down Expand Up @@ -4988,5 +4989,91 @@ class MyAlias(TypeAliasType):
pass


class IntrospectionHelperTests(BaseTestCase):
def test_typing_objects_by_name_of(self):
for name in typing_extensions.__all__:
with self.subTest(name=name):
objs = get_typing_objects_by_name_of(name)
self.assertIsInstance(objs, tuple)
self.assertIn(len(objs), (1, 2))
te_obj = getattr(typing_extensions, name)
if len(objs) == 1:
self.assertIs(te_obj, getattr(typing, name, te_obj))
else:
self.assertTrue(hasattr(typing, name))
self.assertIsNot(te_obj, getattr(typing, name))

with self.assertRaisesRegex(
ValueError,
"Neither typing nor typing_extensions has an object called 'foo'"
):
get_typing_objects_by_name_of("foo")

def test_typing_objects_by_name_not_in_typing_extensions(self):
objs = get_typing_objects_by_name_of("ByteString")
self.assertIsInstance(objs, tuple)
self.assertEqual(len(objs), 1)
bytestring = objs[0]
self.assertIs(bytestring, typing.ByteString)
self.assertEqual(bytestring.__module__, "typing")

def test_typing_objects_by_name_of_2(self):
classvar_objs = get_typing_objects_by_name_of("ClassVar")
self.assertEqual(len(classvar_objs), 1)
classvar_obj = classvar_objs[0]
self.assertIs(classvar_obj, typing.ClassVar)
self.assertIs(classvar_obj, typing_extensions.ClassVar)
self.assertEqual(classvar_obj.__module__, "typing")

@skipIf(TYPING_3_12_0, "We reexport TypeAliasType from typing on 3.12+")
def test_typing_objects_by_name_of_2(self):
name = "TypeAliasType"
# Sanity check; the test won't work correctly if this doesn't hold true:
self.assertFalse(hasattr(typing, name))
typealiastype_objs = get_typing_objects_by_name_of(name)
self.assertEqual(len(typealiastype_objs), 1)
typealiastype_obj = typealiastype_objs[0]
self.assertIs(typealiastype_obj, typing_extensions.TypeAliasType)
self.assertEqual(typealiastype_obj.__module__, "typing_extensions")

@skipUnless(
(3, 8) <= sys.version_info < (3, 12),
(
"Needs a Python version where typing.Protocol "
"and typing_extensions.Protocol are different objects"
)
)
def test_typing_objects_by_name_of_3(self):
name = "Protocol"
# Sanity check; the test won't work correctly if this doesn't hold true:
self.assertTrue(hasattr(typing, name))
protocol_objs = get_typing_objects_by_name_of(name)
self.assertEqual(len(protocol_objs), 2)
modules = {obj.__module__ for obj in protocol_objs}
self.assertEqual(modules, {"typing", "typing_extensions"})

def test_is_typing_name(self):
for name in typing_extensions.__all__:
te_obj = getattr(typing_extensions, name)
self.assertTrue(is_typing_name(te_obj, name))
if hasattr(typing, name):
typing_obj = getattr(typing, name)
self.assertTrue(is_typing_name(typing_obj, name))

def test_is_typing_name_fails_appropriately(self):
self.assertFalse(is_typing_name(typing_extensions.NoReturn, "ClassVar"))
self.assertFalse(is_typing_name(typing.NoReturn, "ClassVar"))
error_msg = "Neither typing nor typing_extensions has an object called 'foo'"
with self.assertRaisesRegex(ValueError, error_msg):
is_typing_name(typing_extensions.NoReturn, "foo")
with self.assertRaisesRegex(ValueError, error_msg):
is_typing_name(typing_extensions.NoReturn, "foo")

def test_is_typing_name_not_in_typing_extensions(self):
# Sanity check -- this is a useless test otherwise:
self.assertFalse(hasattr(typing_extensions, "ByteString"))
self.assertTrue(is_typing_name(typing.ByteString, "ByteString"))


if __name__ == '__main__':
main()
40 changes: 40 additions & 0 deletions src/typing_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,11 @@
'NoReturn',
'Required',
'NotRequired',

# Introspection helpers unique to typing_extensions
# These will never be added to typing.py in CPython
'get_typing_objects_by_name_of',
'is_typing_name',
]

# for backward compatibility
Expand Down Expand Up @@ -2873,3 +2878,38 @@ def __ror__(self, left):
if not _is_unionable(left):
return NotImplemented
return typing.Union[left, self]


#############################################################
# Introspection helpers for third-party libraries
#
# These are not part of the typing-module API,
# and nor will they ever become part of the typing-module API.
#
# They are specific to typing-extensions
##############################################################


@functools.lru_cache(maxsize=None)
def get_typing_objects_by_name_of(name: str) -> typing.Tuple[Any, ...]:
try:
te_obj = globals()[name]
except KeyError:
try:
typing_obj = getattr(typing, name)
except AttributeError:
raise ValueError(
f"Neither typing nor typing_extensions has an object called {name!r}!"
) from None
else:
return (typing_obj,)
else:
if hasattr(typing, name):
typing_obj = getattr(typing, name)
if typing_obj is not te_obj:
return (te_obj, typing_obj)
return (te_obj,)


def is_typing_name(obj: object, name: str) -> bool:
return any(obj is thing for thing in get_typing_objects_by_name_of(name))