Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

First version of a rdm api interaction module #125 #161

Merged
merged 15 commits into from
Oct 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ dependencies = [
"pydantic==2.8.0",
"rioxarray>=0.13.0",
"scipy",
"duckdb>=1.1.0",
"tqdm",
"xarray>=2022.3.0"
]
Expand Down
7 changes: 7 additions & 0 deletions src/worldcereal/rdm_api/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
"""Tools for interacting with the WorldCereal RDM API.

Re-exports :class:`RdmInteraction` so it can be imported directly from
``worldcereal.rdm_api``.
"""

from worldcereal.rdm_api.rdm_interaction import RdmInteraction

# Public API of this sub-package.
__all__ = [
"RdmInteraction",
]
268 changes: 268 additions & 0 deletions src/worldcereal/rdm_api/rdm_interaction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,268 @@
"""Interaction with the WorldCereal RDM API. Used to generate the reference data in geoparquet format for the point extractions."""

from typing import Dict, List, Optional

import duckdb
import geopandas as gpd
import requests
from openeo.rest.auth.oidc import (
OidcClientInfo,
OidcDeviceAuthenticator,
OidcProviderInfo,
)
from requests.adapters import HTTPAdapter
from shapely import wkb
from shapely.geometry.base import BaseGeometry
from urllib3.util.retry import Retry


class NoIntersectingCollections(Exception):
    """Raised when no spatiotemporally intersecting collection IDs are found in the RDM.

    ``RdmInteraction._collections_from_rdm`` raises this when the collection
    search succeeds but returns an empty list for the requested bounding box
    and temporal extent.
    """


class RdmInteraction:
    """Class to interact with the WorldCereal RDM API.

    Typical usage is ``RdmInteraction().authenticate().query_rdm(geometry)``;
    authentication is optional and only adds a bearer token to the request
    headers.
    """

    # Define the default columns to be extracted from the RDM API
    DEFAULT_COLUMNS = [
        "sample_id",
        "ewoc_code",
        "valid_time",
        "quality_score_lc",
        "quality_score_ct",
    ]

    # RDM API Endpoint
    RDM_ENDPOINT = "https://ewoc-rdm-api.iiasa.ac.at"

    # Maximum number of retries for GET requests when the session is resilient
    MAX_RETRIES = 5

    # Timeout (in seconds) applied to every request sent to the RDM API
    REQUEST_TIMEOUT = 10

    def __init__(self, resilient: bool = True):
        """Create a new RDM client.

        Parameters
        ----------
        resilient : bool, optional
            If True (default), mount retrying adapters on the HTTP session.
        """
        # Extra headers (bearer token) set by `authenticate`; None while anonymous.
        self.headers: Optional[Dict[str, str]] = None
        self.session = requests.Session()
        if resilient:
            self._make_resilient()

    def _make_resilient(self):
        """Make the session resilient to connection errors."""
        retries = Retry(
            total=self.MAX_RETRIES,
            backoff_factor=0.1,
            status_forcelist=[500, 502, 503, 504],
            allowed_methods=["GET"],
        )
        self.session.mount("https://", HTTPAdapter(max_retries=retries))
        self.session.mount("http://", HTTPAdapter(max_retries=retries))

    def authenticate(self):
        """Authenticate the user with the RDM API via device code flow.

        Returns
        -------
        RdmInteraction
            This instance, so that calls can be chained.
        """
        self.headers = self._get_api_bearer_token()
        return self

    def _get_api_bearer_token(self) -> Dict[str, str]:
        """Get API bearer access token via device code flow.

        Returns
        -------
        Dict[str, str]
            A Dictionary containing the headers.
        """
        provider_info = OidcProviderInfo(
            issuer="https://sso.terrascope.be/auth/realms/terrascope"
        )

        client_info = OidcClientInfo(
            client_id="worldcereal-rdm",
            provider=provider_info,
        )

        authenticator = OidcDeviceAuthenticator(client_info=client_info)

        tokens = authenticator.get_tokens()

        return {"Authorization": f"Bearer {tokens.access_token}"}

    def _get_headers(self) -> Dict[str, str]:
        """
        Get the headers for the API requests.

        Returns
        -------
        Dict[str, str]
            A dictionary containing the headers. The authorization header is
            included only after `authenticate` has been called.
        """
        headers = {"accept": "*/*"}
        if self.headers:
            headers.update(self.headers)
        return headers

    @staticmethod
    def _sql_literal(value) -> str:
        """Render `value` as a single-quoted SQL string literal.

        Embedded single quotes are doubled so that the value cannot terminate
        the literal early.
        """
        return "'" + str(value).replace("'", "''") + "'"

    def _collections_from_rdm(
        self,
        geometry: BaseGeometry,
        temporal_extent: Optional[List[str]] = None,
    ) -> List[str]:
        """Queries the RDM API and finds all intersection collection IDs for a given geometry and temporal extent.

        Parameters
        ----------
        geometry : BaseGeometry
            A user-defined geometry for which all intersecting collection IDs need to be found.
            Only its bounding box is sent to the API.
        temporal_extent : Optional[List[str]], optional
            A list of two strings representing the temporal extent, by default None. If None, all available data will be queried.

        Returns
        -------
        List[str]
            A List containing the IDs of all intersecting collections.

        Raises
        ------
        Exception
            If the API does not answer with HTTP status 200.
        NoIntersectingCollections
            If the search succeeds but returns no collections.
        """

        bbox = geometry.bounds
        bbox_str = f"Bbox={bbox[0]}&Bbox={bbox[1]}&Bbox={bbox[2]}&Bbox={bbox[3]}"

        # "%3A" is the URL-encoded ":" of the "T00:00:00Z" time suffix.
        val_time = (
            f"&ValidityTime.Start={temporal_extent[0]}T00%3A00%3A00Z&ValidityTime.End={temporal_extent[1]}T00%3A00%3A00Z"
            if temporal_extent
            else ""
        )

        url = f"{self.RDM_ENDPOINT}/collections/search?{bbox_str}{val_time}"

        response = self.session.get(
            url=url, headers=self._get_headers(), timeout=self.REQUEST_TIMEOUT
        )

        if response.status_code != 200:
            raise Exception(f"Error fetching collections: {response.text}")

        response_json = response.json()

        col_ids = [col["collectionId"] for col in response_json]

        if not col_ids:
            raise NoIntersectingCollections(
                f"No spatiotemporally intersecting collection IDs found in the RDM for the given geometry: {bbox} and temporal extent: {temporal_extent}."
            )

        return col_ids

    def _get_download_urls(
        self, collection_ids: List[str], user_id: Optional[str] = None
    ) -> List[str]:
        """Queries the RDM API and finds all HTTP URLs for the GeoParquet files for each collection ID.

        Parameters
        ----------
        collection_ids : List[str]
            A list of collection IDs.
        user_id : Optional[str], optional
            Currently unused; kept so existing callers keep working.

        Returns
        -------
        List[str]
            A List containing the HTTPs URLs of the GeoParquet files for each collection ID.

        Raises
        ------
        Exception
            If the API does not answer with HTTP status 200 for a collection.
        """
        urls = []

        for collection_id in collection_ids:
            url = f"{self.RDM_ENDPOINT}/collections/{collection_id}/download"
            response = self.session.get(
                url, headers=self._get_headers(), timeout=self.REQUEST_TIMEOUT
            )
            if response.status_code != 200:
                raise Exception(
                    f"Failed to get download URL for collection {collection_id}: {response.text}"
                )
            # The response body is used verbatim as the download URL.
            urls.append(response.text)

        return urls

    def _setup_sql_query(
        self,
        urls: List[str],
        geometry: BaseGeometry,
        columns: List[str],
        temporal_extent: Optional[List[str]] = None,
    ) -> str:
        """Sets up the SQL query for the GeoParquet files.

        One ``SELECT`` is built per URL and the results are combined with
        ``UNION ALL``.

        Parameters
        ----------
        urls : List[str]
            A list of URLs of the GeoParquet files.
        geometry : BaseGeometry
            A user-defined geometry.
        columns : List[str]
            A list of column names to extract. They are inserted into the
            query as-is, so they must be trusted column names.
        temporal_extent : Optional[List[str]], optional
            A list of two strings representing the temporal extent, by default None. If None, all available data will be queried.

        Returns
        -------
        str
            A SQL query for the GeoParquet files. Empty if `urls` is empty.
        """
        columns_str = ", ".join(columns)

        # The geometry expression is identical for every file: build it once.
        geometry_sql = f"ST_GeomFromText({self._sql_literal(geometry)})"

        optional_temporal = (
            f"AND valid_time BETWEEN {self._sql_literal(temporal_extent[0])} AND {self._sql_literal(temporal_extent[1])}"
            if temporal_extent
            else ""
        )

        queries = [
            f"""
            SELECT {columns_str}, ST_AsWKB(ST_Intersection(ST_MakeValid(geometry), {geometry_sql})) AS wkb_geometry
            FROM read_parquet({self._sql_literal(url)})
            WHERE ST_Intersects(ST_MakeValid(geometry), {geometry_sql})
            {optional_temporal}
            """
            for url in urls
        ]

        return " UNION ALL ".join(queries)

    def query_rdm(
        self,
        geometry: BaseGeometry,
        temporal_extent: Optional[List[str]] = None,
        columns: Optional[List[str]] = None,
    ) -> gpd.GeoDataFrame:
        """Queries the RDM API and returns all samples intersecting the given geometry and temporal extent.

        Parameters
        ----------
        geometry : BaseGeometry
            A user-defined polygon. CRS should be EPSG:4326.
        temporal_extent : List[str], optional
            A list of two strings representing the temporal extent, by default None. If None, all available data will be queried.
            Dates should be in the format "YYYY-MM-DD".
        columns : List[str], optional
            A list of column names to extract, by default DEFAULT_COLUMNS.

        Returns
        -------
        gpd.GeoDataFrame
            A GeoDataFrame containing the extracted columns and the geometry.

        Raises
        ------
        NoIntersectingCollections
            If no collection intersects the requested extent.
        Exception
            If one of the RDM API requests fails.
        """
        # Resolve the default at call time: avoids a shared mutable default
        # argument and honours DEFAULT_COLUMNS overrides in subclasses.
        if columns is None:
            columns = list(self.DEFAULT_COLUMNS)

        collection_ids = self._collections_from_rdm(
            geometry=geometry, temporal_extent=temporal_extent
        )
        urls = self._get_download_urls(collection_ids)

        query = self._setup_sql_query(
            urls=urls,
            geometry=geometry,
            columns=columns,
            temporal_extent=temporal_extent,
        )

        con = duckdb.connect()
        try:
            con.execute("INSTALL spatial;")
            con.execute("LOAD spatial;")
            df = con.execute(query).fetch_df()
        finally:
            # Always release the DuckDB connection, even if the query fails.
            con.close()

        # Geometries come back as WKB blobs; decode them into shapely objects.
        df["geometry"] = df["wkb_geometry"].apply(lambda x: wkb.loads(bytes(x)))
        df.drop(columns=["wkb_geometry"], inplace=True)

        gdf = gpd.GeoDataFrame(df, geometry="geometry", crs="EPSG:4326")

        return gdf
92 changes: 92 additions & 0 deletions tests/worldcerealtests/test_rdm_interaction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
from unittest.mock import patch

import geopandas as gpd
import pytest
from shapely import Point, Polygon

from worldcereal.rdm_api.rdm_interaction import RdmInteraction


@pytest.fixture
def sample_polygon():
    """Square polygon spanning (0, 0) to (1, 1), used as the query geometry."""
    ring = [(0, 0), (0, 1), (1, 1), (1, 0), (0, 0)]
    return Polygon(ring)


@pytest.fixture
def sample_temporal_extent():
    """Two-element [start, end] list of "YYYY-MM-DD" dates covering 2021."""
    start, end = "2021-01-01", "2021-12-31"
    return [start, end]


class TestRdmInteraction:
    """Tests for RdmInteraction with all calls to the RDM API mocked out."""

    @patch("requests.Session.get")
    def test_collections_from_rdm(
        self, mock_requests_get, sample_polygon, sample_temporal_extent
    ):
        """The collection search builds the expected URL and returns the IDs."""
        mock_requests_get.return_value.status_code = 200
        mock_requests_get.return_value.json.return_value = [
            {"collectionId": "Foo"},
            {"collectionId": "Bar"},
        ]
        interaction = RdmInteraction()
        collection_ids = interaction._collections_from_rdm(
            geometry=sample_polygon, temporal_extent=sample_temporal_extent
        )

        assert collection_ids == ["Foo", "Bar"]

        bbox = sample_polygon.bounds
        geom = f"Bbox={bbox[0]}&Bbox={bbox[1]}&Bbox={bbox[2]}&Bbox={bbox[3]}"
        temporal = f"&ValidityTime.Start={sample_temporal_extent[0]}T00%3A00%3A00Z&ValidityTime.End={sample_temporal_extent[1]}T00%3A00%3A00Z"
        expected_url = f"{interaction.RDM_ENDPOINT}/collections/search?{geom}{temporal}"
        mock_requests_get.assert_called_with(
            url=expected_url, headers={"accept": "*/*"}, timeout=10
        )

    # NOTE: stacked @patch decorators are applied bottom-up, so the mock of the
    # bottom-most decorator is passed as the FIRST mock argument.
    @patch("worldcereal.rdm_api.rdm_interaction.RdmInteraction._get_download_urls")
    @patch("worldcereal.rdm_api.rdm_interaction.RdmInteraction._collections_from_rdm")
    def test_query_rdm(
        self,
        mock_collections_from_rdm,
        mock_get_download_urls,
        sample_polygon,
        sample_temporal_extent,
        tmp_path,
    ):
        """query_rdm filters a local GeoParquet file by columns, space and time."""
        data = {
            "col1": ["must", "include", "this", "column"],
            "col2": ["and", "this", "One", "Too"],
            "col3": ["but", "not", "This", "One"],
            "valid_time": [
                "2021-01-01",
                "2021-12-31",
                "2021-06-01",
                "2025-05-22",
            ],  # Last date not within sample_temporal_extent
            "geometry": [
                Point(0.5, 0.5),
                Point(0.25, 0.25),
                Point(2, 3),
                Point(0.75, 0.75),
            ],  # Third point not within sample_polygon
        }
        gdf = gpd.GeoDataFrame(data, crs="EPSG:4326")
        file_path = tmp_path / "sample.parquet"
        gdf.to_parquet(file_path)

        # The collection search yields IDs; the download lookup yields the
        # location of the parquet file (a local file stands in for the URL).
        mock_collections_from_rdm.return_value = ["Foo"]
        mock_get_download_urls.return_value = [str(file_path)]

        interaction = RdmInteraction()
        result_gdf = interaction.query_rdm(
            geometry=sample_polygon,
            temporal_extent=sample_temporal_extent,
            columns=["col1", "col2"],
        )

        # The IDs found by the search must be forwarded to the download lookup
        mock_collections_from_rdm.assert_called_once()
        mock_get_download_urls.assert_called_once_with(["Foo"])

        # Check that col3 and valid_time indeed not included
        assert result_gdf.columns.tolist() == ["col1", "col2", "geometry"]

        # Check that the third and fourth geometry are not included, as they are outside the spatiotemporal extent
        assert len(result_gdf) == 2
Loading