Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactored waiting function for Tableau Jobs #17034

Merged
merged 11 commits into from
Jul 21, 2021
41 changes: 41 additions & 0 deletions airflow/providers/tableau/hooks/tableau.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import time
import warnings
from distutils.util import strtobool
from enum import Enum
Expand All @@ -22,9 +23,14 @@
from tableauserverclient import Pager, PersonalAccessTokenAuth, Server, TableauAuth
from tableauserverclient.server import Auth

from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook


class TableauJobFailedException(AirflowException):
    """Raised to signal that a Tableau job did not complete successfully."""


class TableauJobFinishCode(Enum):
"""
The finish code indicates the status of the job.
Expand Down Expand Up @@ -133,3 +139,38 @@ def get_all(self, resource_name: str) -> Pager:
except AttributeError:
raise ValueError(f"Resource name {resource_name} is not found.")
return Pager(resource.get)

def get_job_status(self, job_id: str) -> TableauJobFinishCode:
    """
    Look up a Tableau job by id and return its finish code as an enum member.

    .. seealso:: https://tableau.github.io/server-client-python/docs/api-ref#jobs

    :param job_id: The id of the job to check.
    :type job_id: str
    :return: The enum member that matches the integer value of the job's finish code.
    :rtype: TableauJobFinishCode
    """
    job_item = self.server.jobs.get_by_id(job_id)
    # The finish code is converted with int() before the enum lookup.
    numeric_code = int(job_item.finish_code)
    return TableauJobFinishCode(numeric_code)

def wait_for_state(self, job_id: str, target_state: TableauJobFinishCode, check_interval: float) -> bool:
    """
    Poll a Tableau job until it reaches ``target_state`` or is no longer PENDING.

    :param job_id: The id of the job to check.
    :type job_id: str
    :param target_state: The job state the caller is waiting for.
    :type target_state: TableauJobFinishCode
    :param check_interval: Time in seconds to sleep between two consecutive
        status checks.
    :type check_interval: float
    :return: True if the last observed state equals ``target_state``, False otherwise.
    :rtype: bool
    """
    current_state = self.get_job_status(job_id=job_id)
    while True:
        # Stop as soon as the job has left PENDING or already matches the target
        # (the latter covers callers that wait for PENDING itself).
        if current_state != TableauJobFinishCode.PENDING or current_state == target_state:
            return current_state == target_state
        self.log.info("job state: %s", current_state)
        time.sleep(check_interval)
        current_state = self.get_job_status(job_id=job_id)
28 changes: 18 additions & 10 deletions airflow/providers/tableau/operators/tableau_refresh_workbook.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,11 @@

from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.tableau.hooks.tableau import TableauHook, TableauJobFinishCode
from airflow.providers.tableau.sensors.tableau_job_status import TableauJobFailedException
from airflow.providers.tableau.hooks.tableau import (
TableauHook,
TableauJobFailedException,
TableauJobFinishCode,
)


class TableauRefreshWorkbookOperator(BaseOperator):
Expand All @@ -41,6 +44,9 @@ class TableauRefreshWorkbookOperator(BaseOperator):
containing the credentials to authenticate to the Tableau Server. Default:
'tableau_default'.
:type tableau_conn_id: str
:param check_interval: time in seconds to wait between
consecutive job state checks until the operation is completed
:type check_interval: float
"""

def __init__(
Expand All @@ -50,13 +56,15 @@ def __init__(
site_id: Optional[str] = None,
blocking: bool = True,
tableau_conn_id: str = 'tableau_default',
check_interval: float = 20,
**kwargs,
) -> None:
super().__init__(**kwargs)
self.workbook_name = workbook_name
self.site_id = site_id
self.blocking = blocking
self.tableau_conn_id = tableau_conn_id
self.check_interval = check_interval

def execute(self, context: dict) -> str:
"""
Expand All @@ -72,14 +80,14 @@ def execute(self, context: dict) -> str:

job_id = self._refresh_workbook(tableau_hook, workbook.id)
if self.blocking:
finish_code = TableauJobFinishCode.PENDING
negative_codes = (TableauJobFinishCode.ERROR, TableauJobFinishCode.CANCELED)
while not finish_code == TableauJobFinishCode.SUCCESS:
return_code = int(tableau_hook.server.jobs.get_by_id(job_id).finish_code)
finish_code = TableauJobFinishCode(return_code)
if finish_code in negative_codes:
raise TableauJobFailedException('The Tableau Refresh Workbook Job failed!')
self.log.info('Workbook %s has been successfully refreshed.', self.workbook_name)
if not tableau_hook.wait_for_state(
job_id=job_id,
check_interval=self.check_interval,
target_state=TableauJobFinishCode.SUCCESS,
):
raise TableauJobFailedException('The Tableau Refresh Workbook Job failed!')

self.log.info('Workbook %s has been successfully refreshed.', self.workbook_name)
return job_id

def _get_workbook_by_name(self, tableau_hook: TableauHook) -> WorkbookItem:
Expand Down
17 changes: 8 additions & 9 deletions airflow/providers/tableau/sensors/tableau_job_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,14 @@
# under the License.
from typing import Optional

from airflow.exceptions import AirflowException
from airflow.providers.tableau.hooks.tableau import TableauHook, TableauJobFinishCode
from airflow.providers.tableau.hooks.tableau import (
TableauHook,
TableauJobFailedException,
TableauJobFinishCode,
)
from airflow.sensors.base import BaseSensorOperator


class TableauJobFailedException(AirflowException):
    """Raised to signal that a Tableau job did not complete successfully."""


class TableauJobStatusSensor(BaseSensorOperator):
"""
Watches the status of a Tableau Server Job.
Expand Down Expand Up @@ -65,10 +64,10 @@ def poke(self, context: dict) -> bool:
:rtype: bool
"""
with TableauHook(self.site_id, self.tableau_conn_id) as tableau_hook:
finish_code = TableauJobFinishCode(
int(tableau_hook.server.jobs.get_by_id(self.job_id).finish_code)
)
finish_code = tableau_hook.get_job_status(job_id=self.job_id)
self.log.info('Current finishCode is %s (%s)', finish_code.name, finish_code.value)

if finish_code in (TableauJobFinishCode.ERROR, TableauJobFinishCode.CANCELED):
raise TableauJobFailedException('The Tableau Refresh Workbook Job failed!')

return finish_code == TableauJobFinishCode.SUCCESS
94 changes: 92 additions & 2 deletions tests/providers/tableau/hooks/test_tableau.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@
# under the License.

import unittest
from unittest.mock import patch
from unittest.mock import MagicMock, patch

from parameterized import parameterized

from airflow import configuration, models
from airflow.providers.tableau.hooks.tableau import TableauHook
from airflow.providers.tableau.hooks.tableau import TableauHook, TableauJobFinishCode
from airflow.utils import db


Expand Down Expand Up @@ -189,3 +191,91 @@ def test_get_all(self, mock_pager, mock_server, mock_tableau_auth):
assert jobs == mock_pager.return_value

mock_pager.assert_called_once_with(mock_server.return_value.jobs.get)

@parameterized.expand(
    [
        (0, TableauJobFinishCode.SUCCESS),
        (1, TableauJobFinishCode.ERROR),
        (2, TableauJobFinishCode.CANCELED),
    ]
)
@patch('airflow.providers.tableau.hooks.tableau.Server')
def test_get_job_status(self, finish_code, expected_status, mock_tableau_server):
    """
    Check that get_job_status maps the finish code reported by the (mocked)
    server to the expected TableauJobFinishCode member.
    """
    # The job item returned by the mocked server carries the raw finish code.
    mocked_job = mock_tableau_server.jobs.get_by_id.return_value
    mocked_job.finish_code = finish_code

    with TableauHook(tableau_conn_id='tableau_test_password') as hook:
        hook.server = mock_tableau_server
        observed_status = hook.get_job_status(job_id='j1')
        assert observed_status == expected_status

@patch('airflow.providers.tableau.hooks.tableau.Server')
def test_wait_for_state(self, mock_tableau_server):
    """
    Test wait_for_state

    For every scenario, get_job_status is replaced by a mock that yields the
    listed statuses in order, and the boolean returned by wait_for_state is
    compared with the expected outcome.
    """
    # Each entry: (statuses reported by get_job_status, target state, expected result)
    scenarios = [
        # Test SUCCESS Positive
        (
            [TableauJobFinishCode.PENDING, TableauJobFinishCode.SUCCESS],
            TableauJobFinishCode.SUCCESS,
            True,
        ),
        # Test SUCCESS Negative
        (
            [
                TableauJobFinishCode.PENDING,
                TableauJobFinishCode.PENDING,
                TableauJobFinishCode.ERROR,
            ],
            TableauJobFinishCode.SUCCESS,
            False,
        ),
        # Test ERROR Positive
        (
            [
                TableauJobFinishCode.PENDING,
                TableauJobFinishCode.PENDING,
                TableauJobFinishCode.ERROR,
            ],
            TableauJobFinishCode.ERROR,
            True,
        ),
        # Test CANCELLED Positive
        (
            [
                TableauJobFinishCode.PENDING,
                TableauJobFinishCode.PENDING,
                TableauJobFinishCode.CANCELED,
            ],
            TableauJobFinishCode.CANCELED,
            True,
        ),
        # Test PENDING Positive: the target equals the first reported status
        (
            [TableauJobFinishCode.PENDING, TableauJobFinishCode.ERROR],
            TableauJobFinishCode.PENDING,
            True,
        ),
    ]

    for statuses, target_state, expected in scenarios:
        with TableauHook(tableau_conn_id='tableau_test_password') as tableau_hook:
            tableau_hook.get_job_status = MagicMock(
                name='get_job_status',
                side_effect=statuses,
            )
            outcome = tableau_hook.wait_for_state(
                job_id='j1', target_state=target_state, check_interval=1
            )
            assert outcome == expected
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def setUp(self):
mock_workbook.id = i
mock_workbook.name = f'wb_{i}'
self.mocked_workbooks.append(mock_workbook)
self.kwargs = {'site_id': 'test_site', 'task_id': 'task', 'dag': None}
self.kwargs = {'site_id': 'test_site', 'task_id': 'task', 'dag': None, 'check_interval': 1}

@patch('airflow.providers.tableau.operators.tableau_refresh_workbook.TableauHook')
def test_execute(self, mock_tableau_hook):
Expand Down Expand Up @@ -72,7 +72,9 @@ def test_execute_blocking(self, mock_tableau_hook):

mock_tableau_hook.server.workbooks.refresh.assert_called_once_with(2)
assert mock_tableau_hook.server.workbooks.refresh.return_value.id == job_id
mock_tableau_hook.server.jobs.get_by_id.assert_called_once_with(job_id)
mock_tableau_hook.wait_for_state.assert_called_once_with(
job_id=job_id, check_interval=1, target_state=TableauJobFinishCode.SUCCESS
)

@patch('airflow.providers.tableau.operators.tableau_refresh_workbook.TableauHook')
def test_execute_missing_workbook(self, mock_tableau_hook):
Expand Down
13 changes: 6 additions & 7 deletions tests/providers/tableau/sensors/test_tableau_job_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from airflow.providers.tableau.sensors.tableau_job_status import (
TableauJobFailedException,
TableauJobFinishCode,
TableauJobStatusSensor,
)

Expand All @@ -41,26 +42,24 @@ def test_poke(self, mock_tableau_hook):
Test poke
"""
mock_tableau_hook.return_value.__enter__ = Mock(return_value=mock_tableau_hook)
mock_get = mock_tableau_hook.server.jobs.get_by_id
mock_get.return_value.finish_code = '0'
mock_tableau_hook.get_job_status.return_value = TableauJobFinishCode.SUCCESS
sensor = TableauJobStatusSensor(**self.kwargs)

job_finished = sensor.poke(context={})

assert job_finished
mock_get.assert_called_once_with(sensor.job_id)
mock_tableau_hook.get_job_status.assert_called_once_with(job_id=sensor.job_id)

@parameterized.expand([('1',), ('2',)])
@parameterized.expand([(TableauJobFinishCode.ERROR,), (TableauJobFinishCode.CANCELED,)])
@patch('airflow.providers.tableau.sensors.tableau_job_status.TableauHook')
def test_poke_failed(self, finish_code, mock_tableau_hook):
"""
Test poke failed
"""
mock_tableau_hook.return_value.__enter__ = Mock(return_value=mock_tableau_hook)
mock_get = mock_tableau_hook.server.jobs.get_by_id
mock_get.return_value.finish_code = finish_code
mock_tableau_hook.get_job_status.return_value = finish_code
sensor = TableauJobStatusSensor(**self.kwargs)

with pytest.raises(TableauJobFailedException):
sensor.poke({})
mock_get.assert_called_once_with(sensor.job_id)
mock_tableau_hook.get_job_status.assert_called_once_with(job_id=sensor.job_id)