Update example_redshift and example_redshift_s3_transfers to use `RedshiftDataHook` instead of `RedshiftSQLHook` (#40970)
vincbeck authored Jul 24, 2024
1 parent 152d03e commit b4e82cf
Showing 6 changed files with 20 additions and 88 deletions.
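
The net effect: the two system-test DAGs stop bootstrapping an Airflow Connection from the cluster endpoint and running SQL through `RedshiftSQLHook`/`SQLExecuteQueryOperator`; the transfer operators now address the cluster through the Redshift Data API via `redshift_data_api_kwargs`, which both operators also expose as a template field. A minimal sketch of the new call style, using the same keys the diff below adds (bucket, key, and credential values here are placeholders, not values from the repository):

    from airflow.providers.amazon.aws.transfers.s3_to_redshift import S3ToRedshiftOperator

    # No redshift_conn_id or Connection object needed: these kwargs are forwarded
    # to the Redshift Data API when the operator issues its COPY statement.
    transfer_s3_to_redshift = S3ToRedshiftOperator(
        task_id="transfer_s3_to_redshift",
        redshift_data_api_kwargs={
            "database": "dev",                     # DB_NAME in the example DAG
            "cluster_identifier": "my-redshift-cluster",
            "db_user": "awsuser",                  # DB_LOGIN in the example DAG
            "wait_for_completion": True,
        },
        s3_bucket="my-bucket",
        s3_key="my-key",
        schema="PUBLIC",
        table="my_table",
        copy_options=["csv"],
    )
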
1 change: 1 addition & 0 deletions airflow/providers/amazon/aws/transfers/redshift_to_s3.py
@@ -84,6 +84,7 @@ class RedshiftToS3Operator(BaseOperator):
         "unload_options",
         "select_query",
         "redshift_conn_id",
+        "redshift_data_api_kwargs",
     )
     template_ext: Sequence[str] = (".sql",)
     template_fields_renderers = {"select_query": "sql"}
1 change: 1 addition & 0 deletions airflow/providers/amazon/aws/transfers/s3_to_redshift.py
@@ -77,6 +77,7 @@ class S3ToRedshiftOperator(BaseOperator):
         "copy_options",
         "redshift_conn_id",
         "method",
+        "redshift_data_api_kwargs",
         "aws_conn_id",
     )
     template_ext: Sequence[str] = ()
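
Adding `redshift_data_api_kwargs` to `template_fields` on both operators means its values are rendered by Jinja before execution, just like the connection id and query fields above it. A small, hypothetical sketch of what that enables (the Variable name `redshift_cluster_id` is invented for illustration):

    from airflow.providers.amazon.aws.transfers.redshift_to_s3 import RedshiftToS3Operator

    transfer_redshift_to_s3 = RedshiftToS3Operator(
        task_id="unload_fruit",
        redshift_data_api_kwargs={
            # Rendered at runtime now that redshift_data_api_kwargs is a template field.
            "cluster_identifier": "{{ var.value.redshift_cluster_id }}",
            "database": "dev",
            "db_user": "awsuser",
        },
        s3_bucket="my-bucket",
        s3_key="fruit/",
        schema="PUBLIC",
        table="fruit",
    )
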
11 changes: 0 additions & 11 deletions tests/providers/amazon/aws/transfers/test_redshift_to_s3.py
@@ -364,17 +364,6 @@ def test_table_unloading_role_arn(
         assert extra["role_arn"] in unload_query
         assert_equal_ignore_multiple_spaces(mock_run.call_args.args[0], unload_query)
 
-    def test_template_fields_overrides(self):
-        assert RedshiftToS3Operator.template_fields == (
-            "s3_bucket",
-            "s3_key",
-            "schema",
-            "table",
-            "unload_options",
-            "select_query",
-            "redshift_conn_id",
-        )
-
     @pytest.mark.parametrize("param", ["sql", "parameters"])
     def test_invalid_param_in_redshift_data_api_kwargs(self, param):
         """
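
The dropped `test_template_fields_overrides` hard-coded the tuple that the first hunk in this commit extends. The surviving parametrized test still checks that `sql` and `parameters` are not accepted inside `redshift_data_api_kwargs`, since the operator composes the UNLOAD statement itself. Roughly, the rejected usage looks like this (values are placeholders):

    from airflow.providers.amazon.aws.transfers.redshift_to_s3 import RedshiftToS3Operator

    # Rejected by the operator (see the parametrized test above): the UNLOAD SQL
    # is built internally, so user-supplied "sql" or "parameters" are not allowed here.
    bad_transfer = RedshiftToS3Operator(
        task_id="bad_kwargs",
        s3_bucket="my-bucket",
        s3_key="my-key",
        schema="PUBLIC",
        table="fruit",
        redshift_data_api_kwargs={"sql": "SELECT 1"},
    )
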
13 changes: 0 additions & 13 deletions tests/providers/amazon/aws/transfers/test_s3_to_redshift.py
@@ -381,19 +381,6 @@ def test_different_region(self, mock_run, mock_session, mock_connection, mock_ho
         assert mock_run.call_count == 1
         assert_equal_ignore_multiple_spaces(actual_copy_query, expected_copy_query)
 
-    def test_template_fields_overrides(self):
-        assert S3ToRedshiftOperator.template_fields == (
-            "s3_bucket",
-            "s3_key",
-            "schema",
-            "table",
-            "column_list",
-            "copy_options",
-            "redshift_conn_id",
-            "method",
-            "aws_conn_id",
-        )
-
     def test_execute_unavailable_method(self):
         """
         Test execute unavailable method
35 changes: 0 additions & 35 deletions tests/system/providers/amazon/aws/example_redshift.py
@@ -20,12 +20,8 @@
 
 from datetime import datetime
 
-from airflow import settings
-from airflow.decorators import task
-from airflow.models import Connection
 from airflow.models.baseoperator import chain
 from airflow.models.dag import DAG
-from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftHook
 from airflow.providers.amazon.aws.operators.redshift_cluster import (
     RedshiftCreateClusterOperator,
     RedshiftCreateClusterSnapshotOperator,
@@ -36,7 +32,6 @@
 )
 from airflow.providers.amazon.aws.operators.redshift_data import RedshiftDataOperator
 from airflow.providers.amazon.aws.sensors.redshift_cluster import RedshiftClusterSensor
-from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator
 from airflow.utils.trigger_rule import TriggerRule
 from tests.system.providers.amazon.aws.utils import ENV_ID_KEY, SystemTestContextBuilder
 
@@ -56,24 +51,6 @@
 POLL_INTERVAL = 10
 
 
-@task
-def create_connection(conn_id_name: str, cluster_id: str):
-    redshift_hook = RedshiftHook()
-    cluster_endpoint = redshift_hook.get_conn().describe_clusters(ClusterIdentifier=cluster_id)["Clusters"][0]
-    conn = Connection(
-        conn_id=conn_id_name,
-        conn_type="redshift",
-        host=cluster_endpoint["Endpoint"]["Address"],
-        login=DB_LOGIN,
-        password=DB_PASS,
-        port=cluster_endpoint["Endpoint"]["Port"],
-        schema=cluster_endpoint["DBName"],
-    )
-    session = settings.Session()
-    session.add(conn)
-    session.commit()
-
-
 with DAG(
     dag_id=DAG_ID,
     start_date=datetime(2021, 1, 1),
@@ -87,7 +64,6 @@ def create_connection(conn_id_name: str, cluster_id: str):
     cluster_subnet_group_name = test_context[CLUSTER_SUBNET_GROUP_KEY]
     redshift_cluster_identifier = f"{env_id}-redshift-cluster"
     redshift_cluster_snapshot_identifier = f"{env_id}-snapshot"
-    conn_id_name = f"{env_id}-conn-id"
     sg_name = f"{env_id}-sg"
 
     # [START howto_operator_redshift_cluster]
@@ -164,8 +140,6 @@ def create_connection(conn_id_name: str, cluster_id: str):
         timeout=60 * 30,
     )
 
-    set_up_connection = create_connection(conn_id_name, cluster_id=redshift_cluster_identifier)
-
     # [START howto_operator_redshift_data]
     create_table_redshift_data = RedshiftDataOperator(
         task_id="create_table_redshift_data",
@@ -201,13 +175,6 @@ def create_connection(conn_id_name: str, cluster_id: str):
         wait_for_completion=True,
     )
 
-    drop_table = SQLExecuteQueryOperator(
-        task_id="drop_table",
-        conn_id=conn_id_name,
-        sql="DROP TABLE IF EXISTS fruit",
-        trigger_rule=TriggerRule.ALL_DONE,
-    )
-
     # [START howto_operator_redshift_delete_cluster]
     delete_cluster = RedshiftDeleteClusterOperator(
         task_id="delete_cluster",
@@ -236,10 +203,8 @@ def create_connection(conn_id_name: str, cluster_id: str):
         wait_cluster_paused,
         resume_cluster,
         wait_cluster_available_after_resume,
-        set_up_connection,
         create_table_redshift_data,
         insert_data,
-        drop_table,
         delete_cluster_snapshot,
         delete_cluster,
     )
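
With the Connection bootstrap and the `SQLExecuteQueryOperator` teardown gone, any SQL this DAG still needs runs through `RedshiftDataOperator` against the Data API. As an illustration (not necessarily the code the DAG now contains), the removed `drop_table` step could be expressed like this, reusing names analogous to the DAG's cluster identifier and DB constants (placeholder values below):

    from airflow.providers.amazon.aws.operators.redshift_data import RedshiftDataOperator
    from airflow.utils.trigger_rule import TriggerRule

    # Placeholders standing in for the example DAG's constants.
    redshift_cluster_identifier = "example-redshift-cluster"
    DB_NAME = "dev"
    DB_LOGIN = "adminuser"

    # Teardown via the Redshift Data API instead of a SQL connection.
    drop_table = RedshiftDataOperator(
        task_id="drop_table",
        cluster_identifier=redshift_cluster_identifier,
        database=DB_NAME,
        db_user=DB_LOGIN,
        sql="DROP TABLE IF EXISTS fruit",
        wait_for_completion=True,
        trigger_rule=TriggerRule.ALL_DONE,  # run cleanup even if upstream tasks failed
    )
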
47 changes: 18 additions & 29 deletions tests/system/providers/amazon/aws/example_redshift_s3_transfers.py
@@ -18,12 +18,8 @@
 
 from datetime import datetime
 
-from airflow import settings
-from airflow.decorators import task
-from airflow.models import Connection
 from airflow.models.baseoperator import chain
 from airflow.models.dag import DAG
-from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftHook
 from airflow.providers.amazon.aws.operators.redshift_cluster import (
     RedshiftCreateClusterOperator,
     RedshiftDeleteClusterOperator,
@@ -75,24 +71,6 @@
 DATA = "0, 'Airflow', 'testing'"
 
 
-@task
-def create_connection(conn_id_name: str, cluster_id: str):
-    redshift_hook = RedshiftHook()
-    cluster_endpoint = redshift_hook.get_conn().describe_clusters(ClusterIdentifier=cluster_id)["Clusters"][0]
-    conn = Connection(
-        conn_id=conn_id_name,
-        conn_type="redshift",
-        host=cluster_endpoint["Endpoint"]["Address"],
-        login=DB_LOGIN,
-        password=DB_PASS,
-        port=cluster_endpoint["Endpoint"]["Port"],
-        schema=cluster_endpoint["DBName"],
-    )
-    session = settings.Session()
-    session.add(conn)
-    session.commit()
-
-
 with DAG(
     dag_id=DAG_ID,
     start_date=datetime(2021, 1, 1),
@@ -105,7 +83,6 @@ def create_connection(conn_id_name: str, cluster_id: str):
     security_group_id = test_context[SECURITY_GROUP_KEY]
     cluster_subnet_group_name = test_context[CLUSTER_SUBNET_GROUP_KEY]
     redshift_cluster_identifier = f"{env_id}-redshift-cluster"
-    conn_id_name = f"{env_id}-conn-id"
     sg_name = f"{env_id}-sg"
     bucket_name = f"{env_id}-bucket"
 
@@ -134,8 +111,6 @@ def create_connection(conn_id_name: str, cluster_id: str):
         timeout=60 * 30,
     )
 
-    set_up_connection = create_connection(conn_id_name, cluster_id=redshift_cluster_identifier)
-
     create_object = S3CreateObjectOperator(
         task_id="create_object",
         s3_bucket=bucket_name,
@@ -165,7 +140,12 @@ def create_connection(conn_id_name: str, cluster_id: str):
     # [START howto_transfer_redshift_to_s3]
     transfer_redshift_to_s3 = RedshiftToS3Operator(
         task_id="transfer_redshift_to_s3",
-        redshift_conn_id=conn_id_name,
+        redshift_data_api_kwargs={
+            "database": DB_NAME,
+            "cluster_identifier": redshift_cluster_identifier,
+            "db_user": DB_LOGIN,
+            "wait_for_completion": True,
+        },
         s3_bucket=bucket_name,
         s3_key=S3_KEY,
         schema="PUBLIC",
@@ -182,7 +162,12 @@ def create_connection(conn_id_name: str, cluster_id: str):
     # [START howto_transfer_s3_to_redshift]
     transfer_s3_to_redshift = S3ToRedshiftOperator(
         task_id="transfer_s3_to_redshift",
-        redshift_conn_id=conn_id_name,
+        redshift_data_api_kwargs={
+            "database": DB_NAME,
+            "cluster_identifier": redshift_cluster_identifier,
+            "db_user": DB_LOGIN,
+            "wait_for_completion": True,
+        },
         s3_bucket=bucket_name,
         s3_key=S3_KEY_2,
         schema="PUBLIC",
@@ -194,7 +179,12 @@ def create_connection(conn_id_name: str, cluster_id: str):
     # [START howto_transfer_s3_to_redshift_multiple_keys]
    transfer_s3_to_redshift_multiple = S3ToRedshiftOperator(
         task_id="transfer_s3_to_redshift_multiple",
-        redshift_conn_id=conn_id_name,
+        redshift_data_api_kwargs={
+            "database": DB_NAME,
+            "cluster_identifier": redshift_cluster_identifier,
+            "db_user": DB_LOGIN,
+            "wait_for_completion": True,
+        },
         s3_bucket=bucket_name,
         s3_key=S3_KEY_PREFIX,
         schema="PUBLIC",
@@ -231,7 +221,6 @@ def create_connection(conn_id_name: str, cluster_id: str):
         create_bucket,
         create_cluster,
         wait_cluster_available,
-        set_up_connection,
         create_object,
         create_table_redshift_data,
         insert_data,
