diff --git a/test/test_ansible_inventory.py b/test/test_ansible_inventory.py new file mode 100644 index 00000000..70404b55 --- /dev/null +++ b/test/test_ansible_inventory.py @@ -0,0 +1,95 @@ +import pytest + +from testinfra.backend import parse_hostspec +from testinfra.utils.ansible_runner import expand_pattern, get_hosts, Inventory + + +@pytest.fixture +def inventory() -> Inventory: + """Hosts are always under a group, the default is "ungrouped" if using the + ini file format. The "all" meta-group always contains all hosts when + expanded.""" + return { + "_meta": { + "hostvars": { + "a": None, + "b": None, + "c": None, + } + }, + "all": { + "children": ["nested"], + }, + "left": { + "hosts": ["a", "b"], + }, + "right": { + "hosts": ["b", "c"], + }, + "combined": { + "children": ["left", "right"], + }, + "nested": { + "children": ["combined"], + } + } + + +def test_expand_pattern_simple(inventory: Inventory): + """Simple names are matched, recurring into groups if needed.""" + # direct hostname + assert expand_pattern("a", inventory) == {"a"} + # group + assert expand_pattern("left", inventory) == {"a", "b"} + # meta-group + assert expand_pattern("combined", inventory) == {"a", "b", "c"} + # meta-meta-group + assert expand_pattern("nested", inventory) == {"a", "b", "c"} + + +def test_expand_pattern_fnmatch(inventory: Inventory): + """Simple names are matched, recurring into groups if needed.""" + # l->left + assert expand_pattern("l*", inventory) == {"a", "b"} + # any single letter name + assert expand_pattern("?", inventory) == {"a", "b", "c"} + + +def test_expand_pattern_regex(inventory: Inventory): + """Simple names are matched, recurring into groups if needed.""" + # simple character matching - "l" matches "left" but not "all" + assert expand_pattern("~l", inventory) == {"a", "b"} + # "b" matches an exact host, not any group + assert expand_pattern("~b", inventory) == {"b"} + # "a" will match all + assert expand_pattern("~a", inventory) == {"a", "b", "c"} + + +def test_get_hosts(inventory: Inventory): + """Multiple names/patterns can be combined.""" + assert get_hosts("a", inventory) == ["a"] + # the two pattern separators are handled + assert get_hosts("a:b", inventory) == ["a", "b"] + assert get_hosts("a,b", inventory) == ["a", "b"] + # difference works + assert get_hosts("left:!right", inventory) == ["a"] + # intersection works + assert get_hosts("left:&right", inventory) == ["b"] + # intersection is taken with the intersection of the intersection groups + assert get_hosts("all:&left:&right", inventory) == ["b"] + # when the intersections ends up empty, so does the result + assert get_hosts("all:&a:&c", inventory) == [] + # negation is taken with the union of negation groups + assert get_hosts("all:!a:!c", inventory) == ["b"] + + +@pytest.mark.parametrize("left", ["h1", "!h1", "&h1", "~h1", "*h1"]) +@pytest.mark.parametrize("sep", [":", ","]) +@pytest.mark.parametrize("right", ["h2", "!h2", "&h2", "~h2", "*h2", ""]) +def test_parse_hostspec(left: str, sep: str, right: str): + """Ansible's host patterns are parsed without issue.""" + if right: + pattern = f"{left}{sep}{right}" + else: + pattern = left + assert parse_hostspec(pattern) == (pattern, {}) diff --git a/testinfra/utils/ansible_runner.py b/testinfra/utils/ansible_runner.py index 084c03b6..54e845e9 100644 --- a/testinfra/utils/ansible_runner.py +++ b/testinfra/utils/ansible_runner.py @@ -16,8 +16,9 @@ import ipaddress import json import os +import re import tempfile -from typing import Any, Callable, Iterator, Optional, Union +from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Union import testinfra import testinfra.host @@ -26,6 +27,86 @@ local = testinfra.get_host("local://") +Inventory = Dict[str, Any] + + +def expand_group(name: str, inventory: Inventory) -> Iterator[str]: + """Return all the underlying hostnames for the given group name/pattern.""" + group = inventory.get(name) + if group is None: + return + + # this is a meta-group so recurse + children = group.get("children") + if children is not None: + for child in children: + yield from expand_group(child, inventory) + + # this is a regular group + hosts = group.get("hosts") + if hosts is not None: + yield from iter(hosts) + + +def expand_pattern(pattern: str, inventory: Inventory) -> Set[str]: + """Return all underlying hostnames for the given name/pattern.""" + if pattern.startswith("~"): + # this is a regex, so cut off the indicating character + pattern = re.compile(pattern[1:]) + # match is used, not search or fullmatch + filter_ = lambda l: [i for i in l if pattern.match(i)] + else: + filter_ = lambda l: fnmatch.filter(l, pattern) + + # hosts in the inventory directly matched by the pattern + matching_hosts = set(filter_(expand_group('all', inventory))) + + # look for matches in the groups + for group in filter_(inventory.keys()): + if group == "_meta": + continue + matching_hosts.update(expand_group(group, inventory)) + + return matching_hosts + + +def get_hosts(pattern: str, inventory: Inventory) -> List[str]: + """Return hostnames with a name/group that matches the given name/pattern. + + Reference: + https://docs.ansible.com/ansible/latest/inventory_guide/intro_patterns.html + + This is but a shadow of Ansible's full InventoryManager. The source of the + `inventory_hostnames` module would be a good starting point for a more + faithful reproduction if this turns out to be insufficient. + """ + from ansible.inventory.manager import split_host_pattern + + patterns = split_host_pattern(pattern) + + positive = set() + intersect = None + negative = set() + + for requirement in patterns: + if requirement.startswith('&'): + expanded = expand_pattern(requirement[1:], inventory) + if intersect is None: + intersect = expanded + else: + intersect &= expanded + elif requirement.startswith('!'): + negative.update(expand_pattern(requirement[1:], inventory)) + else: + positive.update(expand_pattern(requirement, inventory)) + + result = positive + if intersect is not None: + result &= intersect + if negative: + result -= negative + return sorted(result) + def get_ansible_config() -> configparser.ConfigParser: fname = os.environ.get("ANSIBLE_CONFIG") @@ -45,9 +126,6 @@ def get_ansible_config() -> configparser.ConfigParser: return config -Inventory = dict[str, Any] - - def get_ansible_inventory( config: configparser.ConfigParser, inventory_file: Optional[str] ) -> Inventory: @@ -216,16 +294,8 @@ def get_config( return testinfra.get_host(spec, **kwargs) -def itergroup(inventory: Inventory, group: str) -> Iterator[str]: - for host in inventory.get(group, {}).get("hosts", []): - yield host - for g in inventory.get(group, {}).get("children", []): - for host in itergroup(inventory, g): - yield host - - def is_empty_inventory(inventory: Inventory) -> bool: - return not any(True for _ in itergroup(inventory, "all")) + return next(expand_group("all", inventory), None) is None class AnsibleRunner: @@ -275,25 +345,15 @@ def __init__(self, inventory_file: Optional[str] = None): def get_hosts(self, pattern: str = "all") -> list[str]: inventory = self.inventory - result = set() if is_empty_inventory(inventory): # empty inventory should not return any hosts except for localhost if pattern == "localhost": - result.add("localhost") - else: - raise RuntimeError( - "No inventory was parsed (missing file ?), " - "only implicit localhost is available" - ) - else: - for group in inventory: - groupmatch = fnmatch.fnmatch(group, pattern) - if groupmatch: - result |= set(itergroup(inventory, group)) - for host in inventory[group].get("hosts", []): - if fnmatch.fnmatch(host, pattern): - result.add(host) - return sorted(result) + return ["localhost"] + raise RuntimeError( + "No inventory was parsed (missing file ?), " + "only implicit localhost is available" + ) + return get_hosts(pattern, inventory) @functools.cached_property def inventory(self) -> Inventory: @@ -315,7 +375,7 @@ def get_variables(self, host: str) -> dict[str, Any]: for group in sorted(inventory): if group == "_meta": continue - groups[group] = sorted(itergroup(inventory, group)) + groups[group] = sorted(expand_group(group, inventory)) if host in groups[group]: group_names.append(group)