Skip to content

Commit

Permalink
Respond to PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
jsignell committed Mar 30, 2023
1 parent ab28740 commit 2f9e57f
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 19 deletions.
32 changes: 23 additions & 9 deletions pystac/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import os
from copy import deepcopy
from functools import partial
from html import escape
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -335,7 +336,7 @@ def add_items(
self.add_item(item, strategy=strategy)

def get_child(
self, id: str, recursive: bool = False
self, id: str, recursive: bool = False, sort_links_by_id: bool = True
) -> Optional[Union["Catalog", Collection]]:
"""Gets the child of this catalog with the given ID, if it exists.
Expand All @@ -344,20 +345,33 @@ def get_child(
recursive : If True, search this catalog and all children for the
item; otherwise, only search the children of this catalog. Defaults
to False.
sort_links_by_id : If True, links containing the ID will be checked
first. If links doe not contin the ID then setting this to False
will improve performance. Defaults to True.
Return:
Optional Catalog or Collection: The child with the given ID,
or None if not found.
"""
if not recursive:
return next(
(
cast(Union[pystac.Catalog, pystac.Collection], c)
for c in self.get_stac_objects(pystac.RelType.CHILD, _prefer=id)
if c.id == id
),
None,
)
children: Iterable[Union["Catalog", pystac.Collection]]
if not sort_links_by_id:
children = self.get_children()
else:

def sort_function(id: str, links: List[Link]) -> List[Link]:
return sorted(
links,
key=lambda x: (href := x.get_href()) is None or id not in href,
)

children = map(
lambda x: cast(Union[pystac.Catalog, pystac.Collection], x),
self.get_stac_objects(
pystac.RelType.CHILD, modify_links=partial(sort_function, id)
),
)
return next((c for c in children if c.id == id), None)
else:
for root, _, _ in self.walk():
child = root.get_child(id, recursive=False)
Expand Down
15 changes: 7 additions & 8 deletions pystac/stac_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
List,
Expand Down Expand Up @@ -349,7 +350,7 @@ def get_stac_objects(
self,
rel: Union[str, pystac.RelType],
typ: Optional[Type[STACObject]] = None,
_prefer: Optional[str] = None,
modify_links: Optional[Callable[[List[Link]], List[Link]]] = None,
) -> Iterable[STACObject]:
"""Gets the :class:`~pystac.STACObject` instances that are linked to
by links with their ``rel`` property matching the passed in argument.
Expand All @@ -359,20 +360,18 @@ def get_stac_objects(
``rel`` property against.
typ : If not ``None``, objects will only be yielded if they are instances of
``typ``.
modify_links : A function that modifies the list of links before they are
iterated over. For instance this option can be used to sort the list
so that links matching a particular pattern are earlier in the iterator.
Returns:
Iterable[STACObjects]: A possibly empty iterable of STACObjects that are
connected to this object through links with the given ``rel`` and are of
type ``typ`` (if given).
"""
links = self.links[:]
if _prefer is not None:
# if _prefer is set, put links that contain it in their hrefs first
links = sorted(
links,
key=lambda x: (href := x.get_href()) is None
or str(_prefer) not in href,
)
if modify_links:
links = modify_links(links)

for i in range(0, len(links)):
link = links[i]
Expand Down
28 changes: 26 additions & 2 deletions tests/test_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,9 @@ def test_clone_cant_mutate_original(self) -> None:

def test_multiple_extents(self) -> None:
cat1 = TestCases.case_1()
country = cat1.get_child("country-2")
country = cat1.get_child("country-1")
assert country is not None
col1 = country.get_child("area-2-2")
col1 = country.get_child("area-1-1")
assert col1 is not None
col1.validate()
self.assertIsInstance(col1, Collection)
Expand Down Expand Up @@ -541,3 +541,27 @@ def test_remove_hierarchical_links(
for link in collection.links:
assert not link.is_hierarchical()
assert bool(collection.get_single_link("canonical")) == add_canonical


@pytest.mark.parametrize("child", ["country-1", "country-2"])
def test_get_child_checks_links_where_hrefs_contains_id_first(
test_case_1_catalog: Catalog, child: str
) -> None:
cat1 = test_case_1_catalog
country = cat1.get_child(child)
assert country is not None
child_links = [link for link in cat1.links if link.rel == pystac.RelType.CHILD]
for link in child_links:
if country.id not in link.href:
assert not link.is_resolved()


def test_get_child_sort_links_by_id_is_configurable(
test_case_1_catalog: Catalog,
) -> None:
cat1 = test_case_1_catalog
country = cat1.get_child("country-2", sort_links_by_id=False)
assert country is not None
child_links = [link for link in cat1.links if link.rel == pystac.RelType.CHILD]
for link in child_links:
assert link.is_resolved()

0 comments on commit 2f9e57f

Please sign in to comment.