Skip to content

Commit

Permalink
feat: update SnowflakeSqlApiHook to support OAuth
Browse files Browse the repository at this point in the history
  • Loading branch information
andyguwc committed Mar 5, 2024
1 parent 2867951 commit 0d14c7e
Show file tree
Hide file tree
Showing 5 changed files with 134 additions and 20 deletions.
9 changes: 9 additions & 0 deletions airflow/providers/snowflake/hooks/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,15 @@ def _get_conn_params(self) -> dict[str, str | None]:
conn_config["private_key"] = pkb
conn_config.pop("password", None)

refresh_token = self._get_field(extra_dict, "refresh_token") or ""
if refresh_token:
conn_config["refresh_token"] = refresh_token
conn_config["authenticator"] = "oauth"
conn_config["client_id"] = conn.login
conn_config["client_secret"] = conn.password
conn_config.pop("login", None)
conn_config.pop("password", None)

return conn_config

def get_uri(self) -> str:
Expand Down
74 changes: 60 additions & 14 deletions airflow/providers/snowflake/hooks/snowflake_sql_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import requests
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from requests.auth import HTTPBasicAuth

from airflow.exceptions import AirflowException
from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook
Expand All @@ -39,8 +40,9 @@ class SnowflakeSqlApiHook(SnowflakeHook):
poll to check the status of the execution of a statement. Fetch query results asynchronously.
This hook requires the snowflake_conn_id connection. This hooks mainly uses account, schema, database,
warehouse, private_key_file or private_key_content field must be setup in the connection. Other inputs
can be defined in the connection or hook instantiation.
warehouse, and an authentication mechanism from one of below:
1. JWT Token generated from private_key_file or private_key_content. Other inputs can be defined in the connection or hook instantiation.
2. OAuth Token generated from the refresh_token, client_id and client_secret specified in the connection
:param snowflake_conn_id: Reference to
:ref:`Snowflake connection id<howto/connection:snowflake>`
Expand Down Expand Up @@ -81,6 +83,17 @@ def __init__(
super().__init__(snowflake_conn_id, *args, **kwargs)
self.private_key: Any = None

@property
def account_identifier(self) -> str:
"""Returns snowflake account identifier."""
conn_config = self._get_conn_params()
account_identifier = f"https://{conn_config['account']}"

if conn_config["region"]:
account_identifier += f".{conn_config['region']}"

return account_identifier

def get_private_key(self) -> None:
"""Get the private key from snowflake connection."""
conn = self.get_connection(self.snowflake_conn_id)
Expand Down Expand Up @@ -137,10 +150,7 @@ def execute_query(
conn_config = self._get_conn_params()

req_id = uuid.uuid4()
url = (
f"https://{conn_config['account']}.{conn_config['region']}"
f".snowflakecomputing.com/api/v2/statements"
)
url = f"{self.account_identifier}.snowflakecomputing.com/api/v2/statements"
params: dict[str, Any] | None = {"requestId": str(req_id), "async": True, "pageSize": 10}
headers = self.get_headers()
if bindings is None:
Expand Down Expand Up @@ -175,12 +185,27 @@ def execute_query(
return self.query_ids

def get_headers(self) -> dict[str, Any]:
"""Form JWT Token and header based on the private key, and connection details."""
"""Form auth headers based on either OAuth token or JWT token from private key."""
conn_config = self._get_conn_params()

# Use OAuth if refresh_token and client_id and client_secret are provided
if all(
[conn_config.get("refresh_token"), conn_config.get("client_id"), conn_config.get("client_secret")]
):
oauth_token = self.get_oauth_token()
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {oauth_token}",
"Accept": "application/json",
"User-Agent": "snowflakeSQLAPI/1.0",
"X-Snowflake-Authorization-Token-Type": "OAUTH",
}
return headers

# Alternatively, get the JWT token from the connection details and the private key
if not self.private_key:
self.get_private_key()
conn_config = self._get_conn_params()

# Get the JWT token from the connection details and the private key
token = JWTGenerator(
conn_config["account"], # type: ignore[arg-type]
conn_config["user"], # type: ignore[arg-type]
Expand All @@ -198,20 +223,41 @@ def get_headers(self) -> dict[str, Any]:
}
return headers

def get_oauth_token(self) -> str:
"""Generate temporary OAuth access token using refresh token in connection details."""
conn_config = self._get_conn_params()
url = f"{self.account_identifier}.snowflakecomputing.com/oauth/token-request"
data = {
"grant_type": "refresh_token",
"refresh_token": conn_config["refresh_token"],
"redirect_uri": conn_config.get("redirect_uri", "https://localhost.com"),
}
response = requests.post(
url,
data=data,
headers={
"Content-Type": "application/x-www-form-urlencoded",
},
auth=HTTPBasicAuth(conn_config["client_id"], conn_config["client_secret"]), # type: ignore[arg-type]
)

try:
response.raise_for_status()
except requests.exceptions.HTTPError as e: # pragma: no cover
msg = f"Response: {e.response.content.decode()} Status Code: {e.response.status_code}"
raise AirflowException(msg)
return response.json()["access_token"]

def get_request_url_header_params(self, query_id: str) -> tuple[dict[str, Any], dict[str, Any], str]:
"""
Build the request header Url with account name identifier and query id from the connection params.
:param query_id: statement handles query ids for the individual statements.
"""
conn_config = self._get_conn_params()
req_id = uuid.uuid4()
header = self.get_headers()
params = {"requestId": str(req_id)}
url = (
f"https://{conn_config['account']}.{conn_config['region']}"
f".snowflakecomputing.com/api/v2/statements/{query_id}"
)
url = f"{self.account_identifier}.snowflakecomputing.com/api/v2/statements/{query_id}"
return header, params, url

def check_query_output(self, query_ids: list[str]) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ Extra (optional)
* ``region``: Warehouse region.
* ``warehouse``: Snowflake warehouse name.
* ``role``: Snowflake role.
* ``authenticator``: To connect using OAuth set this parameter ``oath``.
* ``authenticator``: To connect using OAuth set this parameter ``oauth``.
* ``refresh_token``: Specify refresh_token for OAuth connection.
* ``private_key_file``: Specify the path to the private key file.
* ``private_key_content``: Specify the content of the private key file.
* ``session_parameters``: Specify `session level parameters <https://docs.snowflake.com/en/user-guide/python-connector-example.html#setting-session-parameters>`_.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,13 +82,13 @@ the connection metadata is structured as follows:
* - Parameter
- Input
* - Login: string
- Snowflake user name
- Snowflake user name. If using `OAuth connection <https://docs.snowflake.com/en/developer-guide/sql-api/authenticating#using-oauth>`__ this is the ``client_id``
* - Password: string
- Password for Snowflake user
- Password for Snowflake user. If using OAuth this is the ``client_secret``
* - Schema: string
- Set schema to execute SQL operations on by default
* - Extra: dictionary
- ``warehouse``, ``account``, ``database``, ``region``, ``role``, ``authenticator``
- ``warehouse``, ``account``, ``database``, ``region``, ``role``, ``authenticator``, ``refresh_token``. If using OAuth must specify ``refresh_token`` (`obtained here <https://community.snowflake.com/s/article/HOW-TO-OAUTH-TOKEN-GENERATION-USING-SNOWFLAKE-CUSTOM-OAUTH>`__)

An example usage of the SnowflakeSqlApiHook is as follows:

Expand Down
62 changes: 60 additions & 2 deletions tests/providers/snowflake/hooks/test_snowflake_sql_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
"role": "af_role",
},
}

CONN_PARAMS = {
"account": "airflow",
"application": "AIRFLOW",
Expand All @@ -71,6 +72,22 @@
"user": "user",
"warehouse": "af_wh",
}

CONN_PARAMS_OAUTH = {
"account": "airflow",
"application": "AIRFLOW",
"authenticator": "oauth",
"database": "db",
"client_id": "test_client_id",
"client_secret": "test_client_pw",
"refresh_token": "secrettoken",
"region": "af_region",
"role": "af_role",
"schema": "public",
"session_parameters": None,
"warehouse": "af_wh",
}

HEADERS = {
"Content-Type": "application/json",
"Authorization": "Bearer newT0k3n",
Expand All @@ -79,6 +96,15 @@
"X-Snowflake-Authorization-Token-Type": "KEYPAIR_JWT",
}

HEADERS_OAUTH = {
"Content-Type": "application/json",
"Authorization": "Bearer newT0k3n",
"Accept": "application/json",
"User-Agent": "snowflakeSQLAPI/1.0",
"X-Snowflake-Authorization-Token-Type": "OAUTH",
}


GET_RESPONSE = {
"resultSetMetaData": {
"numRows": 10000,
Expand Down Expand Up @@ -173,7 +199,7 @@ def test_execute_query(
@mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.requests")
@mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook._get_conn_params")
@mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook.get_headers")
def test_execute_query_exception_without_statement_handel(
def test_execute_query_exception_without_statement_handle(
self,
mock_get_header,
mock_conn_param,
Expand Down Expand Up @@ -250,14 +276,46 @@ def test_get_request_url_header_params(self, mock_get_header, mock_conn_param):
@mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook.get_private_key")
@mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook._get_conn_params")
@mock.patch("airflow.providers.snowflake.utils.sql_api_generate_jwt.JWTGenerator.get_token")
def test_get_headers(self, mock_get_token, mock_conn_param, mock_private_key):
def test_get_headers_should_support_private_key(self, mock_get_token, mock_conn_param, mock_private_key):
"""Test get_headers method by mocking get_private_key and _get_conn_params method"""
mock_get_token.return_value = "newT0k3n"
mock_conn_param.return_value = CONN_PARAMS
hook = SnowflakeSqlApiHook(snowflake_conn_id="mock_conn_id")
result = hook.get_headers()
assert result == HEADERS

@mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook.get_oauth_token")
@mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook._get_conn_params")
def test_get_headers_should_support_oauth(self, mock_conn_param, mock_oauth_token):
"""Test get_headers method by mocking get_oauth_token and _get_conn_params method"""
mock_conn_param.return_value = CONN_PARAMS_OAUTH
mock_oauth_token.return_value = "newT0k3n"
hook = SnowflakeSqlApiHook(snowflake_conn_id="mock_conn_id")
result = hook.get_headers()
assert result == HEADERS_OAUTH

@mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.HTTPBasicAuth")
@mock.patch("requests.post")
@mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook._get_conn_params")
def test_get_oauth_token(self, mock_conn_param, requests_post, mock_auth):
"""Test get_oauth_token method makes the right http request"""
BASIC_AUTH = {"Authorization": "Basic usernamepassword"}
mock_conn_param.return_value = CONN_PARAMS_OAUTH
requests_post.return_value.status_code = 200
mock_auth.return_value = BASIC_AUTH
hook = SnowflakeSqlApiHook(snowflake_conn_id="mock_conn_id")
hook.get_oauth_token()
requests_post.assert_called_once_with(
f"https://{CONN_PARAMS_OAUTH['account']}.{CONN_PARAMS_OAUTH['region']}.snowflakecomputing.com/oauth/token-request",
data={
"grant_type": "refresh_token",
"refresh_token": CONN_PARAMS_OAUTH["refresh_token"],
"redirect_uri": "https://localhost.com",
},
headers={"Content-Type": "application/x-www-form-urlencoded"},
auth=BASIC_AUTH,
)

@pytest.fixture
def non_encrypted_temporary_private_key(self, tmp_path: Path) -> Path:
"""Encrypt the pem file from the path"""
Expand Down

0 comments on commit 0d14c7e

Please sign in to comment.