Skip to content

Commit

Permalink
feat: OpenSearchQueryOperator using an endpoint with a self-signed certificate (#39788)
Browse files Browse the repository at this point in the history

* feat: added connection options

* feat: opensearch hook unit tests

* feat: fallback to RequestsHttpConnection

* fix: static checks

* fix: static checks

* fix: static checks

* feat: opensearch static module loading

---------

Co-authored-by: Lukas Verret <[email protected]>
  • Loading branch information
Lukas1v and Lukas Verret authored Jun 8, 2024
1 parent 85be186 commit 1a61eb3
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 6 deletions.
21 changes: 16 additions & 5 deletions airflow/providers/opensearch/hooks/opensearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,16 @@

import json
from functools import cached_property
from typing import Any
from typing import TYPE_CHECKING, Any

from opensearchpy import OpenSearch, RequestsHttpConnection

if TYPE_CHECKING:
from opensearchpy import Connection as OpenSearchConnectionClass

from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook
from airflow.utils.strings import to_boolean


class OpenSearchHook(BaseHook):
Expand All @@ -40,13 +44,20 @@ class OpenSearchHook(BaseHook):
conn_type = "opensearch"
hook_name = "OpenSearch Hook"

def __init__(self, open_search_conn_id: str, log_query: bool, **kwargs: Any):
def __init__(
self,
open_search_conn_id: str,
log_query: bool,
open_search_conn_class: type[OpenSearchConnectionClass] | None = RequestsHttpConnection,
**kwargs: Any,
):
super().__init__(**kwargs)
self.conn_id = open_search_conn_id
self.log_query = log_query

self.use_ssl = self.conn.extra_dejson.get("use_ssl", False)
self.verify_certs = self.conn.extra_dejson.get("verify_certs", False)
self.use_ssl = to_boolean(str(self.conn.extra_dejson.get("use_ssl", False)))
self.verify_certs = to_boolean(str(self.conn.extra_dejson.get("verify_certs", False)))
self.connection_class = open_search_conn_class
self.__SERVICE = "es"

@cached_property
Expand All @@ -62,7 +73,7 @@ def client(self) -> OpenSearch:
http_auth=auth,
use_ssl=self.use_ssl,
verify_certs=self.verify_certs,
connection_class=RequestsHttpConnection,
connection_class=self.connection_class,
)
return client

Expand Down
12 changes: 11 additions & 1 deletion airflow/providers/opensearch/operators/opensearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,16 @@
from functools import cached_property
from typing import TYPE_CHECKING, Any, Sequence

from opensearchpy import RequestsHttpConnection
from opensearchpy.exceptions import OpenSearchException

from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.opensearch.hooks.opensearch import OpenSearchHook

if TYPE_CHECKING:
from opensearchpy import Connection as OpenSearchConnectionClass

from airflow.utils.context import Context


Expand All @@ -42,6 +45,7 @@ class OpenSearchQueryOperator(BaseOperator):
:param search_object: A Search object from opensearch-dsl.
:param index_name: The name of the index to search for documents.
:param opensearch_conn_id: opensearch connection to use
:param opensearch_conn_class: opensearch connection class to use
:param log_query: Whether to log the query used. Defaults to True and logs query used.
"""

Expand All @@ -54,20 +58,26 @@ def __init__(
search_object: Any | None = None,
index_name: str | None = None,
opensearch_conn_id: str = "opensearch_default",
opensearch_conn_class: type[OpenSearchConnectionClass] | None = RequestsHttpConnection,
log_query: bool = True,
**kwargs,
) -> None:
super().__init__(**kwargs)
self.query = query
self.index_name = index_name
self.opensearch_conn_id = opensearch_conn_id
self.opensearch_conn_class = opensearch_conn_class
self.log_query = log_query
self.search_object = search_object

@cached_property
def hook(self) -> OpenSearchHook:
"""Get an instance of an OpenSearchHook."""
return OpenSearchHook(open_search_conn_id=self.opensearch_conn_id, log_query=self.log_query)
return OpenSearchHook(
open_search_conn_id=self.opensearch_conn_id,
open_search_conn_class=self.opensearch_conn_class,
log_query=self.log_query,
)

def execute(self, context: Context) -> Any:
"""Execute a search against a given index or a Search object on an OpenSearch Cluster."""
Expand Down
28 changes: 28 additions & 0 deletions tests/providers/opensearch/hooks/test_opensearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,21 @@
# under the License.
from __future__ import annotations

from unittest import mock

import opensearchpy
import pytest
from opensearchpy import Urllib3HttpConnection

from airflow.exceptions import AirflowException
from airflow.models import Connection
from airflow.providers.opensearch.hooks.opensearch import OpenSearchHook

pytestmark = pytest.mark.db_test


MOCK_SEARCH_RETURN = {"status": "test"}
DEFAULT_CONN = opensearchpy.connection.http_requests.RequestsHttpConnection


class TestOpenSearchHook:
Expand All @@ -46,3 +52,25 @@ def test_delete_check_parameters(self):
hook = OpenSearchHook(open_search_conn_id="opensearch_default", log_query=True)
with pytest.raises(AirflowException, match="must include one of either a query or a document id"):
hook.delete(index_name="test_index")

@mock.patch("airflow.hooks.base.BaseHook.get_connection")
def test_hook_param_bool(self, mock_get_connection):
    """Verify that string-valued SSL extras reach the hook as real booleans.

    The connection extras carry ``use_ssl`` and ``verify_certs`` as the
    string ``"True"``. The hook is expected to expose both as the boolean
    ``True`` rather than as a truthy string.

    :param mock_get_connection: patched ``BaseHook.get_connection``, set up
        to hand back the in-memory connection built here (presumably this is
        what the hook resolves its connection through).
    """
    mock_conn = Connection(
        conn_id="opensearch_default", extra={"use_ssl": "True", "verify_certs": "True"}
    )
    mock_get_connection.return_value = mock_conn
    hook = OpenSearchHook(open_search_conn_id="opensearch_default", log_query=True)

    # Identity checks pin both the type and the value: an isinstance-only
    # check would still pass if "True" were mistakenly parsed as False.
    assert hook.use_ssl is True
    assert hook.verify_certs is True

def test_load_conn_param(self, mock_hook):
    """Check which connection class the hook stores, with and without an override.

    Without ``open_search_conn_class`` the hook must keep ``DEFAULT_CONN``;
    with an explicit class it must keep exactly the class it was given.

    :param mock_hook: requested but not used directly in the body;
        presumably a fixture defined outside this view that stubs the hook's
        connection lookup.
    """
    shared_kwargs = {"open_search_conn_id": "opensearch_default", "log_query": True}

    # No override supplied: the hook falls back to its default class.
    default_hook = OpenSearchHook(**shared_kwargs)
    assert default_hook.connection_class == DEFAULT_CONN

    # Explicit override: the supplied class is stored unchanged.
    urllib3_hook = OpenSearchHook(open_search_conn_class=Urllib3HttpConnection, **shared_kwargs)
    assert urllib3_hook.connection_class == Urllib3HttpConnection

0 comments on commit 1a61eb3

Please sign in to comment.