diff --git a/airflow/sensors/base_sensor_operator.py b/airflow/sensors/base_sensor_operator.py index 5e7910b621175c..9d421932c8675e 100644 --- a/airflow/sensors/base_sensor_operator.py +++ b/airflow/sensors/base_sensor_operator.py @@ -20,6 +20,7 @@ from time import sleep from datetime import timedelta +from typing import Dict, Iterable from airflow.exceptions import AirflowException, AirflowSensorTimeout, \ AirflowSkipException, AirflowRescheduleException @@ -57,17 +58,17 @@ class BaseSensorOperator(BaseOperator, SkipMixin): prevent too much load on the scheduler. :type mode: str """ - ui_color = '#e6f1f2' - valid_modes = ['poke', 'reschedule'] + ui_color = '#e6f1f2' # type: str + valid_modes = ['poke', 'reschedule'] # type: Iterable[str] @apply_defaults def __init__(self, - poke_interval=60, - timeout=60 * 60 * 24 * 7, - soft_fail=False, - mode='poke', + poke_interval: float = 60, + timeout: float = 60 * 60 * 24 * 7, + soft_fail: bool = False, + mode: str = 'poke', *args, - **kwargs): + **kwargs) -> None: super().__init__(*args, **kwargs) self.poke_interval = poke_interval self.soft_fail = soft_fail @@ -75,7 +76,7 @@ def __init__(self, self.mode = mode self._validate_input_values() - def _validate_input_values(self): + def _validate_input_values(self) -> None: if not isinstance(self.poke_interval, (int, float)) or self.poke_interval < 0: raise AirflowException( "The poke_interval must be a non-negative number") @@ -90,14 +91,14 @@ def _validate_input_values(self): d=self.dag.dag_id if self.dag else "", t=self.task_id, m=self.mode)) - def poke(self, context): + def poke(self, context: Dict) -> bool: """ Function that the sensors defined while deriving this class should override. """ raise AirflowException('Override me.') - def execute(self, context): + def execute(self, context: Dict) -> None: started_at = timezone.utcnow() if self.reschedule: # If reschedule, use first start date of current try @@ -122,7 +123,7 @@ def execute(self, context): sleep(self.poke_interval) self.log.info("Success criteria met. Exiting.") - def _do_skip_downstream_tasks(self, context): + def _do_skip_downstream_tasks(self, context: Dict) -> None: downstream_tasks = context['task'].get_flat_relatives(upstream=False) self.log.debug("Downstream task_ids %s", downstream_tasks) if downstream_tasks: