Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: update SnowflakeSqlApiHook to support OAuth #37922

Merged
merged 1 commit into from
Mar 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions airflow/providers/snowflake/hooks/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,15 @@ def _get_conn_params(self) -> dict[str, str | None]:
conn_config["private_key"] = pkb
conn_config.pop("password", None)

refresh_token = self._get_field(extra_dict, "refresh_token") or ""
if refresh_token:
conn_config["refresh_token"] = refresh_token
conn_config["authenticator"] = "oauth"
conn_config["client_id"] = conn.login
conn_config["client_secret"] = conn.password
conn_config.pop("login", None)
conn_config.pop("password", None)

return conn_config

def get_uri(self) -> str:
Expand Down
74 changes: 60 additions & 14 deletions airflow/providers/snowflake/hooks/snowflake_sql_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import requests
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from requests.auth import HTTPBasicAuth

from airflow.exceptions import AirflowException
from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook
Expand All @@ -39,8 +40,9 @@ class SnowflakeSqlApiHook(SnowflakeHook):
poll to check the status of the execution of a statement. Fetch query results asynchronously.

This hook requires the snowflake_conn_id connection. This hooks mainly uses account, schema, database,
warehouse, private_key_file or private_key_content field must be setup in the connection. Other inputs
can be defined in the connection or hook instantiation.
warehouse, and an authentication mechanism from one of below:
1. JWT Token generated from private_key_file or private_key_content. Other inputs can be defined in the connection or hook instantiation.
2. OAuth Token generated from the refresh_token, client_id and client_secret specified in the connection

:param snowflake_conn_id: Reference to
:ref:`Snowflake connection id<howto/connection:snowflake>`
Expand Down Expand Up @@ -81,6 +83,17 @@ def __init__(
super().__init__(snowflake_conn_id, *args, **kwargs)
self.private_key: Any = None

@property
def account_identifier(self) -> str:
"""Returns snowflake account identifier."""
conn_config = self._get_conn_params()
account_identifier = f"https://{conn_config['account']}"

if conn_config["region"]:
account_identifier += f".{conn_config['region']}"

return account_identifier

def get_private_key(self) -> None:
"""Get the private key from snowflake connection."""
conn = self.get_connection(self.snowflake_conn_id)
Expand Down Expand Up @@ -137,10 +150,7 @@ def execute_query(
conn_config = self._get_conn_params()

req_id = uuid.uuid4()
url = (
f"https://{conn_config['account']}.{conn_config['region']}"
f".snowflakecomputing.com/api/v2/statements"
)
url = f"{self.account_identifier}.snowflakecomputing.com/api/v2/statements"
params: dict[str, Any] | None = {"requestId": str(req_id), "async": True, "pageSize": 10}
headers = self.get_headers()
if bindings is None:
Expand Down Expand Up @@ -175,12 +185,27 @@ def execute_query(
return self.query_ids

def get_headers(self) -> dict[str, Any]:
"""Form JWT Token and header based on the private key, and connection details."""
"""Form auth headers based on either OAuth token or JWT token from private key."""
conn_config = self._get_conn_params()

# Use OAuth if refresh_token and client_id and client_secret are provided
if all(
[conn_config.get("refresh_token"), conn_config.get("client_id"), conn_config.get("client_secret")]
):
oauth_token = self.get_oauth_token()
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {oauth_token}",
"Accept": "application/json",
"User-Agent": "snowflakeSQLAPI/1.0",
"X-Snowflake-Authorization-Token-Type": "OAUTH",
}
return headers

# Alternatively, get the JWT token from the connection details and the private key
if not self.private_key:
self.get_private_key()
conn_config = self._get_conn_params()

# Get the JWT token from the connection details and the private key
token = JWTGenerator(
conn_config["account"], # type: ignore[arg-type]
conn_config["user"], # type: ignore[arg-type]
Expand All @@ -198,20 +223,41 @@ def get_headers(self) -> dict[str, Any]:
}
return headers

def get_oauth_token(self) -> str:
"""Generate temporary OAuth access token using refresh token in connection details."""
conn_config = self._get_conn_params()
url = f"{self.account_identifier}.snowflakecomputing.com/oauth/token-request"
data = {
"grant_type": "refresh_token",
"refresh_token": conn_config["refresh_token"],
"redirect_uri": conn_config.get("redirect_uri", "https://localhost.com"),
}
response = requests.post(
url,
data=data,
headers={
"Content-Type": "application/x-www-form-urlencoded",
},
auth=HTTPBasicAuth(conn_config["client_id"], conn_config["client_secret"]), # type: ignore[arg-type]
)

try:
response.raise_for_status()
except requests.exceptions.HTTPError as e: # pragma: no cover
msg = f"Response: {e.response.content.decode()} Status Code: {e.response.status_code}"
raise AirflowException(msg)
return response.json()["access_token"]

def get_request_url_header_params(self, query_id: str) -> tuple[dict[str, Any], dict[str, Any], str]:
"""
Build the request header Url with account name identifier and query id from the connection params.

:param query_id: statement handles query ids for the individual statements.
"""
conn_config = self._get_conn_params()
req_id = uuid.uuid4()
header = self.get_headers()
params = {"requestId": str(req_id)}
url = (
f"https://{conn_config['account']}.{conn_config['region']}"
f".snowflakecomputing.com/api/v2/statements/{query_id}"
)
url = f"{self.account_identifier}.snowflakecomputing.com/api/v2/statements/{query_id}"
return header, params, url

def check_query_output(self, query_ids: list[str]) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ Extra (optional)
* ``region``: Warehouse region.
* ``warehouse``: Snowflake warehouse name.
* ``role``: Snowflake role.
* ``authenticator``: To connect using OAuth set this parameter ``oath``.
* ``authenticator``: To connect using OAuth set this parameter ``oauth``.
* ``refresh_token``: Specify refresh_token for OAuth connection.
* ``private_key_file``: Specify the path to the private key file.
* ``private_key_content``: Specify the content of the private key file.
* ``session_parameters``: Specify `session level parameters <https://docs.snowflake.com/en/user-guide/python-connector-example.html#setting-session-parameters>`_.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,13 +82,13 @@ the connection metadata is structured as follows:
* - Parameter
- Input
* - Login: string
- Snowflake user name
- Snowflake user name. If using `OAuth connection <https://docs.snowflake.com/en/developer-guide/sql-api/authenticating#using-oauth>`__ this is the ``client_id``
* - Password: string
- Password for Snowflake user
- Password for Snowflake user. If using OAuth this is the ``client_secret``
* - Schema: string
- Set schema to execute SQL operations on by default
* - Extra: dictionary
- ``warehouse``, ``account``, ``database``, ``region``, ``role``, ``authenticator``
- ``warehouse``, ``account``, ``database``, ``region``, ``role``, ``authenticator``, ``refresh_token``. If using OAuth must specify ``refresh_token`` (`obtained here <https://community.snowflake.com/s/article/HOW-TO-OAUTH-TOKEN-GENERATION-USING-SNOWFLAKE-CUSTOM-OAUTH>`__)

An example usage of the SnowflakeSqlApiHook is as follows:

Expand Down
62 changes: 60 additions & 2 deletions tests/providers/snowflake/hooks/test_snowflake_sql_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
"role": "af_role",
},
}

CONN_PARAMS = {
"account": "airflow",
"application": "AIRFLOW",
Expand All @@ -71,6 +72,22 @@
"user": "user",
"warehouse": "af_wh",
}

CONN_PARAMS_OAUTH = {
"account": "airflow",
"application": "AIRFLOW",
"authenticator": "oauth",
"database": "db",
"client_id": "test_client_id",
"client_secret": "test_client_pw",
"refresh_token": "secrettoken",
"region": "af_region",
"role": "af_role",
"schema": "public",
"session_parameters": None,
"warehouse": "af_wh",
}

HEADERS = {
"Content-Type": "application/json",
"Authorization": "Bearer newT0k3n",
Expand All @@ -79,6 +96,15 @@
"X-Snowflake-Authorization-Token-Type": "KEYPAIR_JWT",
}

HEADERS_OAUTH = {
"Content-Type": "application/json",
"Authorization": "Bearer newT0k3n",
"Accept": "application/json",
"User-Agent": "snowflakeSQLAPI/1.0",
"X-Snowflake-Authorization-Token-Type": "OAUTH",
}


GET_RESPONSE = {
"resultSetMetaData": {
"numRows": 10000,
Expand Down Expand Up @@ -173,7 +199,7 @@ def test_execute_query(
@mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.requests")
@mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook._get_conn_params")
@mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook.get_headers")
def test_execute_query_exception_without_statement_handel(
def test_execute_query_exception_without_statement_handle(
self,
mock_get_header,
mock_conn_param,
Expand Down Expand Up @@ -250,14 +276,46 @@ def test_get_request_url_header_params(self, mock_get_header, mock_conn_param):
@mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook.get_private_key")
@mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook._get_conn_params")
@mock.patch("airflow.providers.snowflake.utils.sql_api_generate_jwt.JWTGenerator.get_token")
def test_get_headers(self, mock_get_token, mock_conn_param, mock_private_key):
def test_get_headers_should_support_private_key(self, mock_get_token, mock_conn_param, mock_private_key):
"""Test get_headers method by mocking get_private_key and _get_conn_params method"""
mock_get_token.return_value = "newT0k3n"
mock_conn_param.return_value = CONN_PARAMS
hook = SnowflakeSqlApiHook(snowflake_conn_id="mock_conn_id")
result = hook.get_headers()
assert result == HEADERS

@mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook.get_oauth_token")
@mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook._get_conn_params")
def test_get_headers_should_support_oauth(self, mock_conn_param, mock_oauth_token):
"""Test get_headers method by mocking get_oauth_token and _get_conn_params method"""
mock_conn_param.return_value = CONN_PARAMS_OAUTH
mock_oauth_token.return_value = "newT0k3n"
hook = SnowflakeSqlApiHook(snowflake_conn_id="mock_conn_id")
result = hook.get_headers()
assert result == HEADERS_OAUTH

@mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.HTTPBasicAuth")
@mock.patch("requests.post")
@mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook._get_conn_params")
def test_get_oauth_token(self, mock_conn_param, requests_post, mock_auth):
"""Test get_oauth_token method makes the right http request"""
BASIC_AUTH = {"Authorization": "Basic usernamepassword"}
mock_conn_param.return_value = CONN_PARAMS_OAUTH
requests_post.return_value.status_code = 200
mock_auth.return_value = BASIC_AUTH
hook = SnowflakeSqlApiHook(snowflake_conn_id="mock_conn_id")
hook.get_oauth_token()
requests_post.assert_called_once_with(
f"https://{CONN_PARAMS_OAUTH['account']}.{CONN_PARAMS_OAUTH['region']}.snowflakecomputing.com/oauth/token-request",
data={
"grant_type": "refresh_token",
"refresh_token": CONN_PARAMS_OAUTH["refresh_token"],
"redirect_uri": "https://localhost.com",
},
headers={"Content-Type": "application/x-www-form-urlencoded"},
auth=BASIC_AUTH,
)

@pytest.fixture
def non_encrypted_temporary_private_key(self, tmp_path: Path) -> Path:
"""Encrypt the pem file from the path"""
Expand Down