Skip to content

Commit

Permalink
Switch colcon_core.extension_point to importlib.metadata
Browse files Browse the repository at this point in the history
  • Loading branch information
cottsay committed Jul 11, 2023
1 parent 088bd12 commit f8ee3a4
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 79 deletions.
76 changes: 38 additions & 38 deletions colcon_core/extension_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,17 @@
import os
import traceback

try:
from importlib.metadata import distributions
from importlib.metadata import EntryPoint
from importlib.metadata import entry_points
except ImportError:
from importlib_metadata import distributions
from importlib_metadata import EntryPoint
from importlib_metadata import entry_points

from colcon_core.environment_variable import EnvironmentVariable
from colcon_core.logging import colcon_logger
from pkg_resources import EntryPoint
from pkg_resources import iter_entry_points
from pkg_resources import WorkingSet

"""Environment variable to block extensions"""
EXTENSION_BLOCKLIST_ENVIRONMENT_VARIABLE = EnvironmentVariable(
Expand Down Expand Up @@ -44,27 +50,25 @@ def get_all_extension_points():
colcon_extension_points.setdefault(EXTENSION_POINT_GROUP_NAME, None)

entry_points = defaultdict(dict)
working_set = WorkingSet()
for dist in sorted(working_set):
entry_map = dist.get_entry_map()
for group_name in entry_map.keys():
seen = set()
for dist in distributions():
if dist.name in seen:
continue
seen.add(dist.name)
for entry_point in dist.entry_points:
# skip groups which are not registered as extension points
if group_name not in colcon_extension_points:
if entry_point.group not in colcon_extension_points:
continue

group = entry_map[group_name]
for entry_point_name, entry_point in group.items():
if entry_point_name in entry_points[group_name]:
previous = entry_points[group_name][entry_point_name]
logger.error(
f"Entry point '{group_name}.{entry_point_name}' is "
f"declared multiple times, '{entry_point}' "
f"overwriting '{previous}'")
value = entry_point.module_name
if entry_point.attrs:
value += f":{'.'.join(entry_point.attrs)}"
entry_points[group_name][entry_point_name] = (
value, dist.project_name, getattr(dist, 'version', None))
if entry_point.name in entry_points[entry_point.group]:
previous = entry_points[entry_point.group][entry_point.name]
logger.error(
f"Entry point '{entry_point.group}.{entry_point.name}' is "
f"declared multiple times, '{entry_point.value}' "
f"from '{dist._path}' "
f"overwriting '{previous}'")
entry_points[entry_point.group][entry_point.name] = \
(entry_point.value, dist.name, dist.version)
return entry_points


Expand All @@ -76,19 +80,21 @@ def get_extension_points(group):
:returns: mapping of extension point names to extension point values
:rtype: dict
"""
entry_points = {}
for entry_point in iter_entry_points(group=group):
if entry_point.name in entry_points:
previous_entry_point = entry_points[entry_point.name]
extension_points = {}
try:
# Python 3.10 and newer
query = entry_points(group=group)
except TypeError:
query = entry_points().get(group, ())
for entry_point in query:
if entry_point.name in extension_points:
previous_entry_point = extension_points[entry_point.name]
logger.error(
f"Entry point '{group}.{entry_point.name}' is declared "
f"multiple times, '{entry_point}' overwriting "
f"multiple times, '{entry_point.value}' overwriting "
f"'{previous_entry_point}'")
value = entry_point.module_name
if entry_point.attrs:
value += f":{'.'.join(entry_point.attrs)}"
entry_points[entry_point.name] = value
return entry_points
extension_points[entry_point.name] = entry_point.value
return extension_points


def load_extension_points(group, *, excludes=None):
Expand Down Expand Up @@ -146,10 +152,4 @@ def load_extension_point(name, value, group):
raise RuntimeError(
'The entry point name is listed in the environment variable '
f"'{EXTENSION_BLOCKLIST_ENVIRONMENT_VARIABLE.name}'")
if ':' in value:
module_name, attr = value.split(':', 1)
attrs = attr.split('.')
else:
module_name = value
attrs = ()
return EntryPoint(name, module_name, attrs).resolve()
return EntryPoint(name, value, group).load()
2 changes: 1 addition & 1 deletion stdeb.cfg
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[colcon-core]
No-Python2:
Depends3: python3-distlib, python3-empy, python3-pytest, python3-setuptools
Depends3: python3-distlib, python3-empy, python3-pytest, python3-setuptools, python3 (>= 3.7) | python3-importlib-metadata
Recommends3: python3-pytest-cov
Suggests3: python3-pytest-repeat, python3-pytest-rerunfailures
Suite: bionic focal jammy stretch buster bullseye
Expand Down
1 change: 0 additions & 1 deletion test/spell_check.words
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ apache
argcomplete
argparse
asyncio
attrs
autouse
basepath
bazqux
Expand Down
78 changes: 39 additions & 39 deletions test/test_extension_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,96 +17,96 @@
from .environment_context import EnvironmentContext


Group1 = EntryPoint('group1', 'g1')
Group2 = EntryPoint('group2', 'g2')
Group1 = EntryPoint('group1', 'g1', EXTENSION_POINT_GROUP_NAME)
Group2 = EntryPoint('group2', 'g2', EXTENSION_POINT_GROUP_NAME)
ExtA = EntryPoint('extA', 'eA', Group1.name)
ExtB = EntryPoint('extB', 'eB', Group1.name)


class Dist():

project_name = 'dist'
version = '0.0.0'

def __init__(self, group_name, group):
self._group_name = group_name
self._group = group
def __init__(self, entry_points):
self.name = f'dist-{id(self)}'
self._entry_points = entry_points

def __lt__(self, other):
return self._group_name < other._group_name
@property
def entry_points(self):
return list(self._entry_points)

def get_entry_map(self):
return self._group


def iter_entry_points(*, group):
def iter_entry_points(*, group=None):
if group == EXTENSION_POINT_GROUP_NAME:
return [Group1, Group2]
assert group == Group1.name
ep1 = EntryPoint('extA', 'eA')
ep2 = EntryPoint('extB', 'eB')
return [ep1, ep2]
elif group == Group1.name:
return [ExtA, ExtB]
assert not group
return {
EXTENSION_POINT_GROUP_NAME: [Group1, Group2],
Group1.name: [ExtA, ExtB],
}


def working_set():
def distributions():
return [
Dist('group1', {
'group1': {ep.name: ep for ep in iter_entry_points(group='group1')}
}),
Dist('group2', {'group2': {'extC': EntryPoint('extC', 'eC')}}),
Dist('groupX', {'groupX': {'extD': EntryPoint('extD', 'eD')}}),
Dist(iter_entry_points(group='group1')),
Dist([EntryPoint('extC', 'eC', Group2.name)]),
Dist([EntryPoint('extD', 'eD', 'groupX')]),
]


def test_all_extension_points():
with patch(
'colcon_core.extension_point.iter_entry_points',
'colcon_core.extension_point.entry_points',
side_effect=iter_entry_points
):
with patch(
'colcon_core.extension_point.WorkingSet',
side_effect=working_set
'colcon_core.extension_point.distributions',
side_effect=distributions
):
# successfully load a known entry point
extension_points = get_all_extension_points()
assert set(extension_points.keys()) == {'group1', 'group2'}
assert set(extension_points['group1'].keys()) == {'extA', 'extB'}
assert extension_points['group1']['extA'] == (
'eA', Dist.project_name, None)
assert extension_points['group1']['extA'][0] == 'eA'


def test_extension_point_blocklist():
# successful loading of extension point without a blocklist
with patch(
'colcon_core.extension_point.iter_entry_points',
'colcon_core.extension_point.entry_points',
side_effect=iter_entry_points
):
with patch(
'colcon_core.extension_point.WorkingSet',
side_effect=working_set
'colcon_core.extension_point.distributions',
side_effect=distributions
):
extension_points = get_extension_points('group1')
assert 'extA' in extension_points.keys()
extension_point = extension_points['extA']
assert extension_point == 'eA'

with patch.object(EntryPoint, 'resolve', return_value=None) as resolve:
with patch.object(EntryPoint, 'load', return_value=None) as load:
load_extension_point('extA', 'eA', 'group1')
assert resolve.call_count == 1
assert load.call_count == 1

# successful loading of entry point not in blocklist
resolve.reset_mock()
load.reset_mock()
with EnvironmentContext(COLCON_EXTENSION_BLOCKLIST=os.pathsep.join([
'group1.extB', 'group2.extC'])
):
load_extension_point('extA', 'eA', 'group1')
assert resolve.call_count == 1
assert load.call_count == 1

# entry point in a blocked group can't be loaded
resolve.reset_mock()
load.reset_mock()
with EnvironmentContext(COLCON_EXTENSION_BLOCKLIST='group1'):
with pytest.raises(RuntimeError) as e:
load_extension_point('extA', 'eA', 'group1')
assert 'The entry point group name is listed in the environment ' \
'variable' in str(e.value)
assert resolve.call_count == 0
assert load.call_count == 0

# entry point listed in the blocklist can't be loaded
with EnvironmentContext(COLCON_EXTENSION_BLOCKLIST=os.pathsep.join([
Expand All @@ -116,10 +116,10 @@ def test_extension_point_blocklist():
load_extension_point('extA', 'eA', 'group1')
assert 'The entry point name is listed in the environment ' \
'variable' in str(e.value)
assert resolve.call_count == 0
assert load.call_count == 0


def entry_point_resolve(self, *args, **kwargs):
def entry_point_load(self, *args, **kwargs):
if self.name == 'exception':
raise Exception('entry point raising exception')
if self.name == 'runtime_error':
Expand All @@ -129,7 +129,7 @@ def entry_point_resolve(self, *args, **kwargs):
return DEFAULT


@patch.object(EntryPoint, 'resolve', entry_point_resolve)
@patch.object(EntryPoint, 'load', entry_point_load)
@patch(
'colcon_core.extension_point.get_extension_points',
return_value={'exception': 'a', 'runtime_error': 'b', 'success': 'c'}
Expand Down

0 comments on commit f8ee3a4

Please sign in to comment.