Skip to content

Commit

Permalink
Add RedshiftResumeClusterOperator and RedshiftPauseClusterOperator (#…
Browse files Browse the repository at this point in the history
…19665)

These operators provide the ability to pause and resume a redshift cluster.
  • Loading branch information
dbarrundiag authored Dec 13, 2021
1 parent 08e8357 commit e77c05f
Show file tree
Hide file tree
Showing 3 changed files with 170 additions and 2 deletions.
84 changes: 83 additions & 1 deletion airflow/providers/amazon/aws/operators/redshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from typing import Dict, Iterable, Optional, Union

from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.redshift import RedshiftSQLHook
from airflow.providers.amazon.aws.hooks.redshift import RedshiftHook, RedshiftSQLHook


class RedshiftSQLOperator(BaseOperator):
Expand Down Expand Up @@ -71,3 +71,85 @@ def execute(self, context: dict) -> None:
self.log.info(f"Executing statement: {self.sql}")
hook = self.get_hook()
hook.run(self.sql, autocommit=self.autocommit, parameters=self.parameters)


class RedshiftResumeClusterOperator(BaseOperator):
"""
Resume a paused AWS Redshift Cluster
.. seealso::
For more information on how to use this operator, take a look at the guide:
:ref:`howto/operator:RedshiftResumeClusterOperator`
:param cluster_identifier: id of the AWS Redshift Cluster
:type cluster_identifier: str
:param aws_conn_id: aws connection to use
:type aws_conn_id: str
"""

template_fields = ("cluster_identifier",)
ui_color = "#eeaa11"
ui_fgcolor = "#ffffff"

def __init__(
self,
*,
cluster_identifier: str,
aws_conn_id: str = "aws_default",
**kwargs,
):
super().__init__(**kwargs)
self.cluster_identifier = cluster_identifier
self.aws_conn_id = aws_conn_id

def execute(self, context):
redshift_hook = RedshiftHook(aws_conn_id=self.aws_conn_id)
cluster_state = redshift_hook.cluster_status(cluster_identifier=self.cluster_identifier)
if cluster_state == 'paused':
self.log.info("Starting Redshift cluster %s", self.cluster_identifier)
redshift_hook.get_conn().resume_cluster(ClusterIdentifier=self.cluster_identifier)
else:
self.log.warning(
"Unable to resume cluster since cluster is currently in status: %s", cluster_state
)


class RedshiftPauseClusterOperator(BaseOperator):
"""
Pause an AWS Redshift Cluster if it has status `available`.
.. seealso::
For more information on how to use this operator, take a look at the guide:
:ref:`howto/operator:RedshiftPauseClusterOperator`
:param cluster_identifier: id of the AWS Redshift Cluster
:type cluster_identifier: str
:param aws_conn_id: aws connection to use
:type aws_conn_id: str
"""

template_fields = ("cluster_identifier",)
ui_color = "#eeaa11"
ui_fgcolor = "#ffffff"

def __init__(
self,
*,
cluster_identifier: str,
aws_conn_id: str = "aws_default",
**kwargs,
):
super().__init__(**kwargs)
self.cluster_identifier = cluster_identifier
self.aws_conn_id = aws_conn_id

def execute(self, context):
redshift_hook = RedshiftHook(aws_conn_id=self.aws_conn_id)
cluster_state = redshift_hook.cluster_status(cluster_identifier=self.cluster_identifier)
if cluster_state == 'available':
self.log.info("Pausing Redshift cluster %s", self.cluster_identifier)
redshift_hook.get_conn().pause_cluster(ClusterIdentifier=self.cluster_identifier)
else:
self.log.warning(
"Unable to pause cluster since cluster is currently in status: %s", cluster_state
)
22 changes: 22 additions & 0 deletions docs/apache-airflow-providers-amazon/operators/redshift.rst
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,25 @@ All together, here is our DAG:
:language: python
:start-after: [START redshift_operator_howto_guide]
:end-before: [END redshift_operator_howto_guide]


.. _howto/operator:RedshiftResumeClusterOperator:

Resume a Redshift Cluster
"""""""""""""""""""""""""""""""""""""""""""

To resume a 'paused' AWS Redshift Cluster you can use
:class:`RedshiftResumeClusterOperator <airflow.providers.amazon.aws.operators.redshift>`

This Operator leverages the AWS CLI
`resume-cluster <https://docs.aws.amazon.com/cli/latest/reference/redshift/resume-cluster.html>`__ API

.. _howto/operator:RedshiftPauseClusterOperator:

Pause a Redshift Cluster
"""""""""""""""""""""""""""""""""""""""""""

To pause an 'available' AWS Redshift Cluster you can use
:class:`RedshiftPauseClusterOperator <airflow.providers.amazon.aws.operators.redshift>`
This Operator leverages the AWS CLI
`pause-cluster <https://docs.aws.amazon.com/cli/latest/reference/redshift/pause-cluster.html>`__ API
66 changes: 65 additions & 1 deletion tests/providers/amazon/aws/operators/test_redshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,11 @@

from parameterized import parameterized

from airflow.providers.amazon.aws.operators.redshift import RedshiftSQLOperator
from airflow.providers.amazon.aws.operators.redshift import (
RedshiftPauseClusterOperator,
RedshiftResumeClusterOperator,
RedshiftSQLOperator,
)


class TestRedshiftSQLOperator(unittest.TestCase):
Expand All @@ -42,3 +46,63 @@ def test_redshift_operator(self, test_autocommit, test_parameters, mock_get_hook
autocommit=test_autocommit,
parameters=test_parameters,
)


class TestResumeClusterOperator:
def test_init(self):
redshift_operator = RedshiftResumeClusterOperator(
task_id="task_test", cluster_identifier="test_cluster", aws_conn_id="aws_conn_test"
)
assert redshift_operator.task_id == "task_test"
assert redshift_operator.cluster_identifier == "test_cluster"
assert redshift_operator.aws_conn_id == "aws_conn_test"

@mock.patch("airflow.providers.amazon.aws.hooks.redshift.RedshiftHook.cluster_status")
@mock.patch("airflow.providers.amazon.aws.hooks.redshift.RedshiftHook.get_conn")
def test_resume_cluster_is_called_when_cluster_is_paused(self, mock_get_conn, mock_cluster_status):
mock_cluster_status.return_value = 'paused'
redshift_operator = RedshiftResumeClusterOperator(
task_id="task_test", cluster_identifier="test_cluster", aws_conn_id="aws_conn_test"
)
redshift_operator.execute(None)
mock_get_conn.return_value.resume_cluster.assert_called_once_with(ClusterIdentifier='test_cluster')

@mock.patch("airflow.providers.amazon.aws.hooks.redshift.RedshiftHook.cluster_status")
@mock.patch("airflow.providers.amazon.aws.hooks.redshift.RedshiftHook.get_conn")
def test_resume_cluster_not_called_when_cluster_is_not_paused(self, mock_get_conn, mock_cluster_status):
mock_cluster_status.return_value = 'available'
redshift_operator = RedshiftResumeClusterOperator(
task_id="task_test", cluster_identifier="test_cluster", aws_conn_id="aws_conn_test"
)
redshift_operator.execute(None)
mock_get_conn.return_value.resume_cluster.assert_not_called()


class TestPauseClusterOperator:
def test_init(self):
redshift_operator = RedshiftPauseClusterOperator(
task_id="task_test", cluster_identifier="test_cluster", aws_conn_id="aws_conn_test"
)
assert redshift_operator.task_id == "task_test"
assert redshift_operator.cluster_identifier == "test_cluster"
assert redshift_operator.aws_conn_id == "aws_conn_test"

@mock.patch("airflow.providers.amazon.aws.hooks.redshift.RedshiftHook.cluster_status")
@mock.patch("airflow.providers.amazon.aws.hooks.redshift.RedshiftHook.get_conn")
def test_pause_cluster_is_called_when_cluster_is_available(self, mock_get_conn, mock_cluster_status):
mock_cluster_status.return_value = 'available'
redshift_operator = RedshiftPauseClusterOperator(
task_id="task_test", cluster_identifier="test_cluster", aws_conn_id="aws_conn_test"
)
redshift_operator.execute(None)
mock_get_conn.return_value.pause_cluster.assert_called_once_with(ClusterIdentifier='test_cluster')

@mock.patch("airflow.providers.amazon.aws.hooks.redshift.RedshiftHook.cluster_status")
@mock.patch("airflow.providers.amazon.aws.hooks.redshift.RedshiftHook.get_conn")
def test_pause_cluster_not_called_when_cluster_is_not_available(self, mock_get_conn, mock_cluster_status):
mock_cluster_status.return_value = 'paused'
redshift_operator = RedshiftPauseClusterOperator(
task_id="task_test", cluster_identifier="test_cluster", aws_conn_id="aws_conn_test"
)
redshift_operator.execute(None)
mock_get_conn.return_value.pause_cluster.assert_not_called()

0 comments on commit e77c05f

Please sign in to comment.