Skip to content

Commit

Permalink
# This is a combination of 2 commits.
Browse files Browse the repository at this point in the history
# This is the 1st commit message:

adds password field for aws secret

# This is the commit message apache#2:

adds password field for aws secret
  • Loading branch information
sunkickr committed Mar 11, 2021
1 parent 46ceb92 commit 5e8c40b
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 22 deletions.
2 changes: 1 addition & 1 deletion airflow/operators/bash.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@


class BashOperator(BaseOperator):
r"""
"""
Execute a Bash script, command or set of commands.
.. seealso::
Expand Down
43 changes: 32 additions & 11 deletions airflow/providers/snowflake/hooks/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,12 @@

class SnowflakeHook(DbApiHook):
"""
A client to interact with Snowflake
A client to interact with Snowflake.
This hook requires the snowflake_conn_id connection. The snowflake host, login,
and, password field must be setup in the connection. Other inputs can be defined
in the connection or hook instantiation. If used with the S3ToSnowflakeOperator
add 'aws_access_key_id' and 'aws_secret_access_key' to extra field in the connection.
:param account: snowflake account name
:type account: Optional[str]
Expand Down Expand Up @@ -72,9 +77,9 @@ class SnowflakeHook(DbApiHook):
@staticmethod
def get_connection_form_widgets() -> Dict[str, Any]:
"""Returns connection widgets to add to connection form"""
from flask_appbuilder.fieldwidgets import BS3TextFieldWidget
from flask_appbuilder.fieldwidgets import BS3PasswordFieldWidget, BS3TextFieldWidget
from flask_babel import lazy_gettext
from wtforms import StringField
from wtforms import PasswordField, StringField

return {
"extra__snowflake__account": StringField(lazy_gettext('Account'), widget=BS3TextFieldWidget()),
Expand All @@ -83,6 +88,12 @@ def get_connection_form_widgets() -> Dict[str, Any]:
),
"extra__snowflake__database": StringField(lazy_gettext('Database'), widget=BS3TextFieldWidget()),
"extra__snowflake__region": StringField(lazy_gettext('Region'), widget=BS3TextFieldWidget()),
"extra__snowflake__aws_access_key_id": StringField(
lazy_gettext('AWS Access Key'), widget=BS3TextFieldWidget()
),
"extra__snowflake__aws_secret_access_key": PasswordField(
lazy_gettext('AWS Secret Key'), widget=BS3PasswordFieldWidget()
),
}

@staticmethod
Expand All @@ -100,8 +111,6 @@ def get_ui_field_behaviour() -> Dict:
"authenticator": "snowflake oauth",
"private_key_file": "private key",
"session_parameters": "session parameters",
"aws_access_key_id": "aws access key",
"aws_secret_access_key": "aws secret key",
},
indent=1,
),
Expand All @@ -113,6 +122,8 @@ def get_ui_field_behaviour() -> Dict:
'extra__snowflake__warehouse': 'snowflake warehouse name',
'extra__snowflake__database': 'snowflake db name',
'extra__snowflake__region': 'snowflake hosted region',
'extra__snowflake__aws_access_key_id': 'aws access key id (S3ToSnowflakeOperator)',
'extra__snowflake__aws_secret_access_key': 'aws secret access key (S3ToSnowflakeOperator)',
},
}

Expand All @@ -135,10 +146,16 @@ def _get_conn_params(self) -> Dict[str, Optional[str]]:
conn = self.get_connection(
self.snowflake_conn_id # type: ignore[attr-defined] # pylint: disable=no-member
)
account = conn.extra_dejson.get('extra__snowflake__account', '')
warehouse = conn.extra_dejson.get('extra__snowflake__warehouse', '')
database = conn.extra_dejson.get('extra__snowflake__database', '')
region = conn.extra_dejson.get('extra__snowflake__region', '')
account = conn.extra_dejson.get('extra__snowflake__account', '') or conn.extra_dejson.get(
'account', ''
)
warehouse = conn.extra_dejson.get('extra__snowflake__warehouse', '') or conn.extra_dejson.get(
'warehouse', ''
)
database = conn.extra_dejson.get('extra__snowflake__database', '') or conn.extra_dejson.get(
'database', ''
)
region = conn.extra_dejson.get('extra__snowflake__region', '') or conn.extra_dejson.get('region', '')
role = conn.extra_dejson.get('role', '')
schema = conn.schema or ''
authenticator = conn.extra_dejson.get('authenticator', 'snowflake')
Expand Down Expand Up @@ -211,8 +228,12 @@ def _get_aws_credentials(self) -> Tuple[Optional[Any], Optional[Any]]:
self.snowflake_conn_id # type: ignore[attr-defined] # pylint: disable=no-member
)
if 'aws_secret_access_key' in connection_object.extra_dejson:
aws_access_key_id = connection_object.extra_dejson.get('aws_access_key_id')
aws_secret_access_key = connection_object.extra_dejson.get('aws_secret_access_key')
aws_access_key_id = connection_object.extra_dejson.get(
'aws_access_key_id'
) or connection_object.extra_dejson.get('aws_access_key_id')
aws_secret_access_key = connection_object.extra_dejson.get(
'aws_secret_access_key'
) or connection_object.extra_dejson.get('aws_secret_access_key')
return aws_access_key_id, aws_secret_access_key

def set_autocommit(self, conn, autocommit: Any) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ function discover_all_connection_form_widgets() {

COLUMNS=180 airflow providers widgets

local expected_number_of_widgets=19
local expected_number_of_widgets=25
local actual_number_of_widgets
actual_number_of_widgets=$(airflow providers widgets --output table | grep -c ^extra)
if [[ ${actual_number_of_widgets} != "${expected_number_of_widgets}" ]]; then
Expand All @@ -176,7 +176,7 @@ function discover_all_field_behaviours() {
group_start "Listing connections with custom behaviours via 'airflow providers behaviours'"
COLUMNS=180 airflow providers behaviours

local expected_number_of_connections_with_behaviours=11
local expected_number_of_connections_with_behaviours=12
local actual_number_of_connections_with_behaviours
actual_number_of_connections_with_behaviours=$(airflow providers behaviours --output table | grep -v "===" | \
grep -v field_behaviours | grep -cv "^ " | xargs)
Expand Down
14 changes: 10 additions & 4 deletions tests/core/test_providers_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,7 @@
'apache-airflow-providers-sftp',
'apache-airflow-providers-singularity',
'apache-airflow-providers-slack',
# Uncomment when https://github.com/apache/airflow/issues/12881 is fixed
# 'apache-airflow-providers-snowflake',
'apache-airflow-providers-snowflake',
'apache-airflow-providers-sqlite',
'apache-airflow-providers-ssh',
'apache-airflow-providers-tableau',
Expand Down Expand Up @@ -139,8 +138,7 @@
'samba',
'segment',
'sftp',
# Uncomment when https://github.com/apache/airflow/issues/12881 is fixed
# 'snowflake',
'snowflake',
'spark',
'spark_jdbc',
'spark_sql',
Expand Down Expand Up @@ -174,6 +172,13 @@
'extra__yandexcloud__public_ssh_key',
'extra__yandexcloud__service_account_json',
'extra__yandexcloud__service_account_json_path',
'extra__snowflake__account',
'extra__snowflake__warehouse',
'extra__snowflake__database',
'extra__snowflake__region',
'extra__snowflake__aws_access_key_id',
'extra__snowflake__aws_secret_access_key',

]

CONNECTIONS_WITH_FIELD_BEHAVIOURS = [
Expand All @@ -188,6 +193,7 @@
'spark',
'ssh',
'yandexcloud',
'snowflake',
]

EXTRA_LINKS = [
Expand Down
8 changes: 4 additions & 4 deletions tests/providers/snowflake/hooks/test_snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,10 @@ def setUp(self):
self.conn.password = 'pw'
self.conn.schema = 'public'
self.conn.extra_dejson = {
'extra__snowflake__database': 'db',
'extra__snowflake__account': 'airflow',
'extra__snowflake__warehouse': 'af_wh',
'extra__snowflake__region': 'af_region',
'database': 'db',
'account': 'airflow',
'warehouse': 'af_wh',
'region': 'af_region',
'role': 'af_role',
}

Expand Down

0 comments on commit 5e8c40b

Please sign in to comment.