Skip to content

Commit

Permalink
[AIRFLOW-5906] Add authenticator parameter to snowflake_hook (#8642)
Browse files Browse the repository at this point in the history
(cherry picked from commit cd635dd)
  • Loading branch information
koszti authored and kaxil committed Jun 30, 2020
1 parent 98443d2 commit 8d23325
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 7 deletions.
8 changes: 5 additions & 3 deletions airflow/contrib/hooks/snowflake_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def __init__(self, *args, **kwargs):
self.region = kwargs.pop("region", None)
self.role = kwargs.pop("role", None)
self.schema = kwargs.pop("schema", None)
self.authenticator = kwargs.pop("authenticator", None)

def _get_conn_params(self):
"""
Expand All @@ -56,6 +57,7 @@ def _get_conn_params(self):
database = conn.extra_dejson.get('database', None)
region = conn.extra_dejson.get("region", None)
role = conn.extra_dejson.get('role', None)
authenticator = conn.extra_dejson.get('authenticator', 'snowflake')

conn_config = {
"user": conn.login,
Expand All @@ -65,8 +67,8 @@ def _get_conn_params(self):
"account": self.account or account or '',
"warehouse": self.warehouse or warehouse or '',
"region": self.region or region or '',
"role": self.role or role or ''

"role": self.role or role,
"authenticator": self.authenticator or authenticator
}

"""
Expand Down Expand Up @@ -103,7 +105,7 @@ def get_uri(self):
"""
conn_config = self._get_conn_params()
uri = 'snowflake://{user}:{password}@{account}/{database}/'
uri += '{schema}?warehouse={warehouse}&role={role}'
uri += '{schema}?warehouse={warehouse}&role={role}&authenticator={authenticator}'
return uri.format(**conn_config)

def get_conn(self):
Expand Down
13 changes: 11 additions & 2 deletions airflow/contrib/operators/snowflake_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,14 @@ class SnowflakeOperator(BaseOperator):
:type schema: str
:param role: name of role (will overwrite any role defined in
connection's extra JSON)
:param authenticator: authenticator for Snowflake.
'snowflake' (default) to use the internal Snowflake authenticator
'externalbrowser' to authenticate using your web browser and
Okta, ADFS or any other SAML 2.0-compliant identify provider
(IdP) that has been defined for your account
'https://<your_okta_account_name>.okta.com' to authenticate
through native Okta.
:type authenticator: str
"""

template_fields = ('sql',)
Expand All @@ -52,7 +60,7 @@ class SnowflakeOperator(BaseOperator):
def __init__(
self, sql, snowflake_conn_id='snowflake_default', parameters=None,
autocommit=True, warehouse=None, database=None, role=None,
schema=None, *args, **kwargs):
schema=None, authenticator=None, *args, **kwargs):
super(SnowflakeOperator, self).__init__(*args, **kwargs)
self.snowflake_conn_id = snowflake_conn_id
self.sql = sql
Expand All @@ -62,11 +70,12 @@ def __init__(
self.database = database
self.role = role
self.schema = schema
self.authenticator = authenticator

def get_hook(self):
return SnowflakeHook(snowflake_conn_id=self.snowflake_conn_id,
warehouse=self.warehouse, database=self.database,
role=self.role, schema=self.schema)
role=self.role, schema=self.schema, authenticator=self.authenticator)

def execute(self, context):
self.log.info('Executing: %s', self.sql)
Expand Down
6 changes: 4 additions & 2 deletions tests/contrib/hooks/test_snowflake_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,8 @@ def tearDown(self):
os.remove(self.nonEncryptedPrivateKey)

def test_get_uri(self):
uri_shouldbe = 'snowflake://user:pw@airflow/db/public?warehouse=af_wh&role=af_role'
uri_shouldbe = 'snowflake://user:pw@airflow/db/public?warehouse=af_wh&role=af_role' \
'&authenticator=snowflake'
self.assertEqual(uri_shouldbe, self.db_hook.get_uri())

def test_get_conn_params(self):
Expand All @@ -103,7 +104,8 @@ def test_get_conn_params(self):
'account': 'airflow',
'warehouse': 'af_wh',
'region': 'af_region',
'role': 'af_role'}
'role': 'af_role',
'authenticator': 'snowflake'}
self.assertEqual(conn_params_shouldbe, self.db_hook._get_conn_params())

def test_get_conn(self):
Expand Down

0 comments on commit 8d23325

Please sign in to comment.