Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modify ssl verification in druid operator #37673

Merged
43 changes: 28 additions & 15 deletions airflow/providers/apache/druid/hooks/druid.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@

import time
from enum import Enum
from typing import Any, Iterable
from functools import cached_property
from typing import TYPE_CHECKING, Any, Iterable

import requests
from pydruid.db import connect
Expand All @@ -28,6 +29,9 @@
from airflow.hooks.base import BaseHook
from airflow.providers.common.sql.hooks.sql import DbApiHook

if TYPE_CHECKING:
from airflow.models import Connection


class IngestionType(Enum):
"""
Expand All @@ -53,17 +57,16 @@ class DruidHook(BaseHook):
the Druid job for the status of the ingestion job.
Must be greater than or equal to 1
:param max_ingestion_time: The maximum ingestion time before assuming the job failed
:param verify_ssl: Either a boolean, in which case it controls whether we verify the server's TLS
certificate, or a string, in which case it must be a path to a CA bundle to use.
Defaults to True
:param verify_ssl: Whether to verify the server's TLS certificate when submitting the indexing job.
If set to False, the connection extras are checked for a ``ca_bundle_path`` and, when one is set,
that CA bundle is used for verification instead. Defaults to True
"""

def __init__(
self,
druid_ingest_conn_id: str = "druid_ingest_default",
timeout: int = 1,
max_ingestion_time: int | None = None,
verify_ssl: bool | str = True,
verify_ssl: bool = True,
) -> None:
super().__init__()
self.druid_ingest_conn_id = druid_ingest_conn_id
Expand All @@ -75,16 +78,19 @@ def __init__(
if self.timeout < 1:
raise ValueError("Druid timeout should be equal or greater than 1")

# Connection for ``druid_ingest_conn_id``, obtained through ``get_connection``.
# Declared as a ``cached_property``: the lookup runs on first access and the
# result is reused on later accesses of the same hook instance.
@cached_property
def conn(self) -> Connection:
return self.get_connection(self.druid_ingest_conn_id)

def get_conn_url(self, ingestion_type: IngestionType = IngestionType.BATCH) -> str:
"""Get Druid connection url."""
conn = self.get_connection(self.druid_ingest_conn_id)
host = conn.host
port = conn.port
conn_type = conn.conn_type or "http"
host = self.conn.host
port = self.conn.port
conn_type = self.conn.conn_type or "http"
if ingestion_type == IngestionType.BATCH:
endpoint = conn.extra_dejson.get("endpoint", "")
endpoint = self.conn.extra_dejson.get("endpoint", "")
else:
endpoint = conn.extra_dejson.get("msq_endpoint", "")
endpoint = self.conn.extra_dejson.get("msq_endpoint", "")
return f"{conn_type}://{host}:{port}/{endpoint}"

def get_auth(self) -> requests.auth.HTTPBasicAuth | None:
Expand All @@ -93,14 +99,21 @@ def get_auth(self) -> requests.auth.HTTPBasicAuth | None:

If these details have not been set then returns None.
"""
conn = self.get_connection(self.druid_ingest_conn_id)
user = conn.login
password = conn.password
user = self.conn.login
password = self.conn.password
if user is not None and password is not None:
return requests.auth.HTTPBasicAuth(user, password)
else:
return None

# Resolve the value handed to ``requests.post(..., verify=...)`` when a job is submitted.
# Returns the ``ca_bundle_path`` connection extra (a string) only when ``verify_ssl``
# is falsy AND that extra holds a truthy value; otherwise returns ``verify_ssl`` unchanged.
def get_verify(self) -> bool | str:
# A missing "ca_bundle_path" extra yields None, which fails the truthiness check below.
ca_bundle_path: str | None = self.conn.extra_dejson.get("ca_bundle_path", None)
# NOTE(review): a configured CA bundle is ignored whenever verify_ssl is True; this
# matches the parametrized test cases, but confirm it is the intended precedence.
if not self.verify_ssl and ca_bundle_path:
self.log.info("Using CA bundle to verify connection")
return ca_bundle_path

# Either verify_ssl is truthy, or no CA bundle path is configured.
return self.verify_ssl

def submit_indexing_job(
self, json_index_spec: dict[str, Any] | str, ingestion_type: IngestionType = IngestionType.BATCH
) -> None:
Expand All @@ -109,7 +122,7 @@ def submit_indexing_job(

self.log.info("Druid ingestion spec: %s", json_index_spec)
req_index = requests.post(
url, data=json_index_spec, headers=self.header, auth=self.get_auth(), verify=self.verify_ssl
url, data=json_index_spec, headers=self.header, auth=self.get_auth(), verify=self.get_verify()
)

code = req_index.status_code
Expand Down
7 changes: 3 additions & 4 deletions airflow/providers/apache/druid/operators/druid.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,8 @@ class DruidOperator(BaseOperator):
of the ingestion job. Must be greater than or equal to 1
:param max_ingestion_time: The maximum ingestion time before assuming the job failed
:param ingestion_type: The ingestion type of the job. Could be IngestionType.Batch or IngestionType.MSQ
:param verify_ssl: Either a boolean, in which case it controls whether we verify the server's TLS
certificate, or a string, in which case it must be a path to a CA bundle to use.
Defaults to True.
:param verify_ssl: Whether to verify the server's TLS certificate when submitting the indexing job.
If set to False, the connection extras are checked for a ``ca_bundle_path`` and, when one is set,
that CA bundle is used for verification instead. Defaults to True
"""

template_fields: Sequence[str] = ("json_index_file",)
Expand All @@ -54,7 +53,7 @@ def __init__(
timeout: int = 1,
max_ingestion_time: int | None = None,
ingestion_type: IngestionType = IngestionType.BATCH,
verify_ssl: bool | str = True,
verify_ssl: bool = True,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
Expand Down
32 changes: 30 additions & 2 deletions tests/providers/apache/druid/hooks/test_druid.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def test_submit_sql_based_ingestion_ok(self, requests_mock):
assert status_check.called_once

def test_submit_with_correct_ssl_arg(self, requests_mock):
self.db_hook.verify_ssl = "/path/to/ca.crt"
self.db_hook.verify_ssl = False
task_post = requests_mock.post(
"http://druid-overlord:8081/druid/indexer/v1/task",
text='{"task":"9f8a7359-77d4-4612-b0cd-cc2f6a3c28de"}',
Expand All @@ -113,7 +113,7 @@ def test_submit_with_correct_ssl_arg(self, requests_mock):
assert status_check.called_once
if task_post.called_once:
verify_ssl = task_post.request_history[0].verify
assert "/path/to/ca.crt" == verify_ssl
assert False is verify_ssl

def test_submit_correct_json_body(self, requests_mock):
task_post = requests_mock.post(
Expand Down Expand Up @@ -199,6 +199,17 @@ def get_conn_url(self, ingestion_type: IngestionType = IngestionType.BATCH):

self.db_hook = TestDRuidhook()

# Checks that ``DruidHook.conn`` returns the object produced by the patched
# ``DruidHook.get_connection``.
@patch("airflow.providers.apache.druid.hooks.druid.DruidHook.get_connection")
def test_conn_property(self, mock_get_connection):
# Stand-in connection object; the attributes set here are not asserted on
# individually, the test only compares the returned object as a whole.
get_conn_value = MagicMock()
get_conn_value.host = "test_host"
get_conn_value.conn_type = "https"
get_conn_value.port = "1"
get_conn_value.extra_dejson = {"endpoint": "ingest"}
mock_get_connection.return_value = get_conn_value
hook = DruidHook()
assert hook.conn == get_conn_value

@patch("airflow.providers.apache.druid.hooks.druid.DruidHook.get_connection")
def test_get_conn_url(self, mock_get_connection):
get_conn_value = MagicMock()
Expand Down Expand Up @@ -254,6 +265,23 @@ def test_get_auth_with_no_user_and_password(self, mock_get_connection):
mock_get_connection.return_value = get_conn_value
assert self.db_hook.get_auth() is None

# Parametrized check of ``DruidHook.get_verify`` across the combinations of the
# ``verify_ssl`` constructor argument and the ``ca_bundle_path`` connection extra.
# Per the cases below, the CA bundle path is returned only when ``verify_ssl`` is
# False and a path is set; in every other case ``verify_ssl`` itself is returned.
@pytest.mark.parametrize(
"verify_ssl_arg, ca_bundle_path, expected_return_value",
[
(False, None, False),
(True, None, True),
(False, "path/to/ca_bundle", "path/to/ca_bundle"),
(True, "path/to/ca_bundle", True),
],
)
@patch("airflow.providers.apache.druid.hooks.druid.DruidHook.get_connection")
def test_get_verify(self, mock_get_connection, verify_ssl_arg, ca_bundle_path, expected_return_value):
# Only ``extra_dejson`` is populated on the mocked connection; it carries the
# parametrized ``ca_bundle_path`` value (possibly None).
get_conn_value = MagicMock()
get_conn_value.extra_dejson = {"ca_bundle_path": ca_bundle_path}
mock_get_connection.return_value = get_conn_value
hook = DruidHook(verify_ssl=verify_ssl_arg)
assert hook.get_verify() == expected_return_value


class TestDruidDbApiHook:
def setup_method(self):
Expand Down
2 changes: 1 addition & 1 deletion tests/providers/apache/druid/operators/test_druid.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def test_execute_calls_druid_hook_with_the_right_parameters(mock_druid_hook):
druid_ingest_conn_id = "druid_ingest_default"
max_ingestion_time = 5
timeout = 5
verify_ssl = "/path/to/ca.crt"
verify_ssl = False
operator = DruidOperator(
task_id="spark_submit_job",
json_index_file=json_index_file,
Expand Down