diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9be695daae43b2..a9ce160bb3169b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -190,6 +190,47 @@ repos: files: ^BREEZE.rst$|^breeze$|^breeze-complete$ pass_filenames: false require_serial: true + - id: base-operator + language: pygrep + name: Make sure BaseOperator is imported from airflow.models.baseoperator in core + entry: "from airflow.models import.* BaseOperator" + files: \.py$ + pass_filenames: true + exclude: > + (?x) + ^airflow/gcp/.*$| + ^airflow/hooks/.*$| + ^airflow/operators/.*$| + ^airflow/sensors/.*$| + ^airflow/providers/.*$| + ^airflow/contrib/.*$ + - id: base-operator + language: pygrep + name: Make sure BaseOperator is imported from airflow.models outside of core + entry: "from airflow.models.baseoperator import.* BaseOperator" + pass_filenames: true + files: > + (?x) + ^airflow/gcp/.*$| + ^airflow/hooks/.*$| + ^airflow/operators/.*$| + ^airflow/sensors/.*$| + ^airflow/providers/.*\.py$| + ^airflow/contrib/.*\.py$ + - id: airflow-exception + language: pygrep + name: Make sure AirflowException is imported using 'from airflow import AirflowException' + entry: "from airflow.exceptions import.* AirflowException" + pass_filenames: true + exclude: ^airflow/__init__\.py$ + files: \.py$ + - id: airflow-dag + language: pygrep + name: Make sure DAG is imported using 'from airflow import DAG' + entry: "from airflow.models import.* DAG|from airflow.models.dag import.* DAG" + pass_filenames: true + exclude: ^airflow/models/__init__\.py$|^airflow/__init__\.py$ + files: \.py$ - id: build name: Check if image build is needed entry: ./scripts/ci/pre_commit_ci_build.sh @@ -217,12 +258,14 @@ repos: files: \.py$ exclude: ^tests/.*\.py$|^airflow/_vendor/.*$ pass_filenames: true + require_serial: true - id: pylint name: Run pylint for tests language: system entry: "./scripts/ci/pre_commit_pylint_tests.sh" files: ^tests/.*\.py$ pass_filenames: true + require_serial: true - id: flake8 name: Run flake8 language: system diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst index 272e07a10c2ac2..ec66eb71af9020 100644 --- a/CONTRIBUTING.rst +++ b/CONTRIBUTING.rst @@ -414,7 +414,11 @@ image built locally): =================================== ================================================================ ============ **Hooks** **Description** **Breeze** =================================== ================================================================ ============ -``airflow-settings`` Check if airflow import settings are used well. +``airflow-dag`` Make sure airflow DAG is imported from correct package +----------------------------------- ---------------------------------------------------------------- ------------ +``airflow-exception`` Make sure airflow exception is imported from correct package +----------------------------------- ---------------------------------------------------------------- ------------ +``base-operator`` Checks that BaseOperator is imported properly ----------------------------------- ---------------------------------------------------------------- ------------ ``build`` Builds image for check-apache-licence, mypy, pylint, flake8. * ----------------------------------- ---------------------------------------------------------------- ------------ @@ -509,6 +513,51 @@ You can always skip running the tests by providing ``--no-verify`` flag to the To check other usage types of the pre-commit framework, see `Pre-commit website `__. +Importing Airflow core objects +============================== + +When you implement core features or DAGs you might need to import some of the core objects or modules. +Since Apache Airflow can be used both as application (by internal classes) and as library (by DAGs), there are +different ways those core objects and packages are imported. + +Airflow imports some of the core objects directly to 'airflow' package so that they can be used from there. + +Those criteria were assumed for choosing what import path to use: + +* If you work on a core feature inside Apache Airflow, you should import the objects directly from the + package where the object is defined - this minimises the risk of cyclic imports. +* If you import the objects from any of 'providers' classes, you should import the objects from + 'airflow' or 'airflow.models', It is very important for back-porting operators/hooks/sensors + to Airflow 1.10.* (AIP-21) +* If you import objects from within a DAG you write, you should import them from 'airflow' or + 'airflow.models' package where stable location of such import is important. + +Those checks enforced for the most important and repeated objects via pre-commit hooks as described below. + +BaseOperator +------------ + +The BaseOperator should be imported: +* as ``from airflow.models import BaseOperator`` in external DAG, provider's operator, or custom operator +* as ``from airflow.models.baseoperator import BaseOperator`` in Airflow core to avoid cyclic imports + +DAG +--- + +The DAG should be imported: +* as ``from airflow import DAG`` in external DAG, provider's operator, or custom operator +* as ``from airflow.models.dag import DAG`` in Airflow core to avoid cyclic imports + + +AirflowException +---------------- + +The AirflowException should be imported directly from airflow package: + +.. code-block:: python + + from airflow import AirflowException + Travis CI Testing Framework =========================== diff --git a/UPDATING.md b/UPDATING.md index b54c406a0d8a22..28a43ea6f53b02 100644 --- a/UPDATING.md +++ b/UPDATING.md @@ -41,6 +41,14 @@ assists users migrating to a new version. ## Airflow Master +### Changes to settings + +CONTEXT_MANAGER_DAG was removed from settings. It's role has been taken by `DagContext` in +'airflow.models.dag'. One of the reasons was that settings should be rather static than store +dynamic context from the DAG, but the main one is that moving the context out of settings allowed to +untangle cyclic imports between DAG, BaseOperator, SerializedDAG, SerializedBaseOperator which was +part of AIRFLOW-6010. + ### Removal of redirect_stdout, redirect_stderr Function `redirect_stderr` and `redirect_stdout` from `airflow.utils.log.logging_mixin` module has @@ -1443,7 +1451,7 @@ Type "help", "copyright", "credits" or "license" for more information. >>> from airflow.settings import * >>> >>> from datetime import datetime ->>> from airflow import DAG +>>> from airflow.models.dag import DAG >>> from airflow.operators.dummy_operator import DummyOperator >>> >>> dag = DAG('simple_dag', start_date=datetime(2017, 9, 1)) diff --git a/airflow/__init__.py b/airflow/__init__.py index f0ce14b0626b3c..f3750b053f6fe0 100644 --- a/airflow/__init__.py +++ b/airflow/__init__.py @@ -31,12 +31,15 @@ # pylint:disable=wrong-import-position from typing import Callable, Optional +# noinspection PyUnresolvedReferences +from airflow import utils from airflow import settings from airflow import version +from airflow.executors.all_executors import AllExecutors from airflow.utils.log.logging_mixin import LoggingMixin from airflow.configuration import conf from airflow.exceptions import AirflowException -from airflow.models import DAG +from airflow.models.dag import DAG __version__ = version.version @@ -44,7 +47,6 @@ login = None # type: Optional[Callable] -from airflow import executors from airflow import hooks from airflow import macros from airflow import operators @@ -60,5 +62,5 @@ def __init__(self, namespace): operators._integrate_plugins() # pylint:disable=protected-access sensors._integrate_plugins() # pylint:disable=protected-access hooks._integrate_plugins() # pylint:disable=protected-access -executors._integrate_plugins() # pylint:disable=protected-access +AllExecutors._integrate_plugins() # pylint:disable=protected-access macros._integrate_plugins() # pylint:disable=protected-access diff --git a/airflow/api/__init__.py b/airflow/api/__init__.py index f138caaf0ae0b8..1072ffb173e084 100644 --- a/airflow/api/__init__.py +++ b/airflow/api/__init__.py @@ -20,8 +20,9 @@ from importlib import import_module +from airflow import AirflowException from airflow.configuration import conf -from airflow.exceptions import AirflowConfigException, AirflowException +from airflow.exceptions import AirflowConfigException from airflow.utils.log.logging_mixin import LoggingMixin diff --git a/airflow/api/common/experimental/delete_dag.py b/airflow/api/common/experimental/delete_dag.py index b473bdd4ead03d..75777595108949 100644 --- a/airflow/api/common/experimental/delete_dag.py +++ b/airflow/api/common/experimental/delete_dag.py @@ -22,7 +22,8 @@ from airflow import models from airflow.exceptions import DagNotFound -from airflow.models import DagModel, SerializedDagModel, TaskFail +from airflow.models import DagModel, TaskFail +from airflow.models.serialized_dag import SerializedDagModel from airflow.settings import STORE_SERIALIZED_DAGS from airflow.utils.db import provide_session from airflow.utils.log.logging_mixin import LoggingMixin diff --git a/airflow/api/common/experimental/get_code.py b/airflow/api/common/experimental/get_code.py index 4cdac7a195175f..5da70b6e3ec10d 100644 --- a/airflow/api/common/experimental/get_code.py +++ b/airflow/api/common/experimental/get_code.py @@ -17,8 +17,8 @@ # specific language governing permissions and limitations # under the License. """Get code APIs.""" +from airflow import AirflowException from airflow.api.common.experimental import check_and_get_dag -from airflow.exceptions import AirflowException from airflow.www import utils as wwwutils diff --git a/airflow/api/common/experimental/mark_tasks.py b/airflow/api/common/experimental/mark_tasks.py index 1818da3b866387..2b0aa032b19639 100644 --- a/airflow/api/common/experimental/mark_tasks.py +++ b/airflow/api/common/experimental/mark_tasks.py @@ -24,7 +24,8 @@ from sqlalchemy import or_ from airflow.jobs import BackfillJob -from airflow.models import BaseOperator, DagRun, TaskInstance +from airflow.models import DagRun, TaskInstance +from airflow.models.baseoperator import BaseOperator from airflow.operators.subdag_operator import SubDagOperator from airflow.utils import timezone from airflow.utils.db import provide_session diff --git a/airflow/cli/commands/task_command.py b/airflow/cli/commands/task_command.py index 0eae9be334155b..7a770d97026dd9 100644 --- a/airflow/cli/commands/task_command.py +++ b/airflow/cli/commands/task_command.py @@ -25,7 +25,7 @@ from contextlib import redirect_stderr, redirect_stdout from airflow import DAG, AirflowException, conf, jobs, settings -from airflow.executors import get_default_executor +from airflow.executors.all_executors import AllExecutors from airflow.models import DagPickle, TaskInstance from airflow.ti_deps.dep_context import SCHEDULER_QUEUED_DEPS, DepContext from airflow.utils import cli as cli_utils, db @@ -69,7 +69,7 @@ def _run(args, dag, ti): print(e) raise e - executor = get_default_executor() + executor = AllExecutors.get_default_executor() executor.start() print("Sending to executor.") executor.queue_task_instance( diff --git a/airflow/config_templates/default_celery.py b/airflow/config_templates/default_celery.py index 35a7c510ed810c..fbc6499b4b3cfd 100644 --- a/airflow/config_templates/default_celery.py +++ b/airflow/config_templates/default_celery.py @@ -19,8 +19,9 @@ """Default celery configuration.""" import ssl +from airflow import AirflowException from airflow.configuration import conf -from airflow.exceptions import AirflowConfigException, AirflowException +from airflow.exceptions import AirflowConfigException from airflow.utils.log.logging_mixin import LoggingMixin diff --git a/airflow/configuration.py b/airflow/configuration.py index 1eec83ee7f3175..a8528cc7eb96b7 100644 --- a/airflow/configuration.py +++ b/airflow/configuration.py @@ -574,7 +574,7 @@ def get_airflow_test_config(airflow_home): has_option = conf.has_option remove_option = conf.remove_option as_dict = conf.as_dict -set = conf.set # noqa +set = conf.set # noqa for func in [load_test_config, get, getboolean, getfloat, getint, has_option, remove_option, as_dict, set]: diff --git a/airflow/contrib/example_dags/example_jenkins_job_trigger_operator.py.notexecutable b/airflow/contrib/example_dags/example_jenkins_job_trigger_operator.py.notexecutable index 8269bc7843ead6..205e2b31a5aa71 100644 --- a/airflow/contrib/example_dags/example_jenkins_job_trigger_operator.py.notexecutable +++ b/airflow/contrib/example_dags/example_jenkins_job_trigger_operator.py.notexecutable @@ -19,7 +19,7 @@ -from airflow import DAG +from airflow.models.dag import DAG from airflow.contrib.operators.jenkins_job_trigger_operator import JenkinsJobTriggerOperator from airflow.operators.python_operator import PythonOperator from airflow.contrib.hooks.jenkins_hook import JenkinsHook diff --git a/airflow/contrib/example_dags/example_kubernetes_executor.py b/airflow/contrib/example_dags/example_kubernetes_executor.py index 683bef3a56aef5..eb09a50c698a06 100644 --- a/airflow/contrib/example_dags/example_kubernetes_executor.py +++ b/airflow/contrib/example_dags/example_kubernetes_executor.py @@ -22,7 +22,7 @@ import os import airflow -from airflow.models import DAG +from airflow import DAG from airflow.operators.python_operator import PythonOperator args = { diff --git a/airflow/contrib/example_dags/example_kubernetes_executor_config.py b/airflow/contrib/example_dags/example_kubernetes_executor_config.py index e235777a696665..288190915ec8f7 100644 --- a/airflow/contrib/example_dags/example_kubernetes_executor_config.py +++ b/airflow/contrib/example_dags/example_kubernetes_executor_config.py @@ -22,8 +22,8 @@ import os import airflow +from airflow import DAG from airflow.contrib.example_dags.libs.helper import print_stuff -from airflow.models import DAG from airflow.operators.python_operator import PythonOperator default_args = { diff --git a/airflow/contrib/example_dags/example_kubernetes_operator.py b/airflow/contrib/example_dags/example_kubernetes_operator.py index ce3c40cf8475e7..7e824763820407 100644 --- a/airflow/contrib/example_dags/example_kubernetes_operator.py +++ b/airflow/contrib/example_dags/example_kubernetes_operator.py @@ -19,7 +19,7 @@ """ This is an example dag for using the KubernetesPodOperator. """ -from airflow.models import DAG +from airflow import DAG from airflow.utils.dates import days_ago from airflow.utils.log.logging_mixin import LoggingMixin diff --git a/airflow/contrib/example_dags/example_papermill_operator.py b/airflow/contrib/example_dags/example_papermill_operator.py index 6840bfd2f04761..b4023c780c1a20 100644 --- a/airflow/contrib/example_dags/example_papermill_operator.py +++ b/airflow/contrib/example_dags/example_papermill_operator.py @@ -25,7 +25,7 @@ from datetime import timedelta import airflow -from airflow.models import DAG +from airflow import DAG from airflow.operators.papermill_operator import PapermillOperator default_args = { diff --git a/airflow/contrib/example_dags/example_winrm_operator.py b/airflow/contrib/example_dags/example_winrm_operator.py index e3fd907fa3583a..c0dd65d045b342 100644 --- a/airflow/contrib/example_dags/example_winrm_operator.py +++ b/airflow/contrib/example_dags/example_winrm_operator.py @@ -32,9 +32,9 @@ from datetime import timedelta import airflow +from airflow import DAG from airflow.contrib.hooks.winrm_hook import WinRMHook from airflow.contrib.operators.winrm_operator import WinRMOperator -from airflow.models import DAG from airflow.operators.dummy_operator import DummyOperator default_args = { diff --git a/airflow/contrib/hooks/aws_dynamodb_hook.py b/airflow/contrib/hooks/aws_dynamodb_hook.py index abf594decdc56d..ce76661daf941b 100644 --- a/airflow/contrib/hooks/aws_dynamodb_hook.py +++ b/airflow/contrib/hooks/aws_dynamodb_hook.py @@ -21,8 +21,8 @@ """ This module contains the AWS DynamoDB hook """ +from airflow import AirflowException from airflow.contrib.hooks.aws_hook import AwsHook -from airflow.exceptions import AirflowException class AwsDynamoDBHook(AwsHook): diff --git a/airflow/contrib/hooks/aws_hook.py b/airflow/contrib/hooks/aws_hook.py index cd763277af359f..1c99e9affa43e9 100644 --- a/airflow/contrib/hooks/aws_hook.py +++ b/airflow/contrib/hooks/aws_hook.py @@ -26,7 +26,7 @@ import boto3 -from airflow.exceptions import AirflowException +from airflow import AirflowException from airflow.hooks.base_hook import BaseHook diff --git a/airflow/contrib/hooks/azure_container_instance_hook.py b/airflow/contrib/hooks/azure_container_instance_hook.py index af329ccb7bd5ec..79e3169eb3655c 100644 --- a/airflow/contrib/hooks/azure_container_instance_hook.py +++ b/airflow/contrib/hooks/azure_container_instance_hook.py @@ -25,7 +25,7 @@ from azure.mgmt.containerinstance import ContainerInstanceManagementClient from zope.deprecation import deprecation -from airflow.exceptions import AirflowException +from airflow import AirflowException from airflow.hooks.base_hook import BaseHook diff --git a/airflow/contrib/hooks/cloudant_hook.py b/airflow/contrib/hooks/cloudant_hook.py index a29f71dcb3a817..aa773fc630e875 100644 --- a/airflow/contrib/hooks/cloudant_hook.py +++ b/airflow/contrib/hooks/cloudant_hook.py @@ -19,7 +19,7 @@ """Hook for Cloudant""" from cloudant import cloudant -from airflow.exceptions import AirflowException +from airflow import AirflowException from airflow.hooks.base_hook import BaseHook diff --git a/airflow/contrib/hooks/databricks_hook.py b/airflow/contrib/hooks/databricks_hook.py index b284fb171a2e42..ed020140f8ad7f 100644 --- a/airflow/contrib/hooks/databricks_hook.py +++ b/airflow/contrib/hooks/databricks_hook.py @@ -30,8 +30,7 @@ from requests import exceptions as requests_exceptions from requests.auth import AuthBase -from airflow import __version__ -from airflow.exceptions import AirflowException +from airflow import AirflowException, __version__ from airflow.hooks.base_hook import BaseHook RESTART_CLUSTER_ENDPOINT = ("POST", "api/2.0/clusters/restart") diff --git a/airflow/contrib/hooks/datadog_hook.py b/airflow/contrib/hooks/datadog_hook.py index 21fb415c63a41f..af4e65beb8651d 100644 --- a/airflow/contrib/hooks/datadog_hook.py +++ b/airflow/contrib/hooks/datadog_hook.py @@ -21,7 +21,7 @@ from datadog import api, initialize -from airflow.exceptions import AirflowException +from airflow import AirflowException from airflow.hooks.base_hook import BaseHook from airflow.utils.log.logging_mixin import LoggingMixin diff --git a/airflow/contrib/hooks/discord_webhook_hook.py b/airflow/contrib/hooks/discord_webhook_hook.py index 1b3c1aa7d92438..10313c14e96346 100644 --- a/airflow/contrib/hooks/discord_webhook_hook.py +++ b/airflow/contrib/hooks/discord_webhook_hook.py @@ -20,7 +20,7 @@ import json import re -from airflow.exceptions import AirflowException +from airflow import AirflowException from airflow.hooks.http_hook import HttpHook diff --git a/airflow/contrib/hooks/emr_hook.py b/airflow/contrib/hooks/emr_hook.py index 52232f3a294b7e..e21698218e3f61 100644 --- a/airflow/contrib/hooks/emr_hook.py +++ b/airflow/contrib/hooks/emr_hook.py @@ -17,8 +17,8 @@ # specific language governing permissions and limitations # under the License. +from airflow import AirflowException from airflow.contrib.hooks.aws_hook import AwsHook -from airflow.exceptions import AirflowException class EmrHook(AwsHook): diff --git a/airflow/contrib/hooks/jira_hook.py b/airflow/contrib/hooks/jira_hook.py index 088e2d2a69dba8..bd62fdd5ef7927 100644 --- a/airflow/contrib/hooks/jira_hook.py +++ b/airflow/contrib/hooks/jira_hook.py @@ -20,7 +20,7 @@ from jira import JIRA from jira.exceptions import JIRAError -from airflow.exceptions import AirflowException +from airflow import AirflowException from airflow.hooks.base_hook import BaseHook diff --git a/airflow/contrib/hooks/pagerduty_hook.py b/airflow/contrib/hooks/pagerduty_hook.py index 93f997af71ffda..2c9ee126c7e419 100644 --- a/airflow/contrib/hooks/pagerduty_hook.py +++ b/airflow/contrib/hooks/pagerduty_hook.py @@ -21,7 +21,7 @@ import pypd -from airflow.exceptions import AirflowException +from airflow import AirflowException from airflow.hooks.base_hook import BaseHook diff --git a/airflow/contrib/hooks/pinot_hook.py b/airflow/contrib/hooks/pinot_hook.py index c4edc24177238b..e9ff7978b6f919 100644 --- a/airflow/contrib/hooks/pinot_hook.py +++ b/airflow/contrib/hooks/pinot_hook.py @@ -22,7 +22,7 @@ from pinotdb import connect -from airflow.exceptions import AirflowException +from airflow import AirflowException from airflow.hooks.base_hook import BaseHook from airflow.hooks.dbapi_hook import DbApiHook diff --git a/airflow/contrib/hooks/qubole_check_hook.py b/airflow/contrib/hooks/qubole_check_hook.py index 118fbe4a3ab7ba..dc744c60ea324c 100644 --- a/airflow/contrib/hooks/qubole_check_hook.py +++ b/airflow/contrib/hooks/qubole_check_hook.py @@ -22,8 +22,8 @@ from qds_sdk.commands import Command +from airflow import AirflowException from airflow.contrib.hooks.qubole_hook import QuboleHook -from airflow.exceptions import AirflowException from airflow.utils.log.logging_mixin import LoggingMixin COL_DELIM = '\t' diff --git a/airflow/contrib/hooks/qubole_hook.py b/airflow/contrib/hooks/qubole_hook.py index 6c25e6d6f92ee5..7edef2e76ce958 100644 --- a/airflow/contrib/hooks/qubole_hook.py +++ b/airflow/contrib/hooks/qubole_hook.py @@ -30,8 +30,8 @@ ) from qds_sdk.qubole import Qubole +from airflow import AirflowException from airflow.configuration import conf -from airflow.exceptions import AirflowException from airflow.hooks.base_hook import BaseHook from airflow.models import TaskInstance from airflow.utils.log.logging_mixin import LoggingMixin @@ -123,7 +123,7 @@ def handle_failure_retry(context): log = LoggingMixin().log if cmd.status == 'done': log.info('Command ID: %s has been succeeded, hence marking this ' - 'TI as Success.', cmd_id) + 'TaskInstance as Success.', cmd_id) ti.state = State.SUCCESS elif cmd.status == 'running': log.info('Cancelling the Qubole Command Id: %s', cmd_id) diff --git a/airflow/contrib/hooks/sagemaker_hook.py b/airflow/contrib/hooks/sagemaker_hook.py index 3480b4b5f41f9f..7c39c5b61ee8c7 100644 --- a/airflow/contrib/hooks/sagemaker_hook.py +++ b/airflow/contrib/hooks/sagemaker_hook.py @@ -25,9 +25,9 @@ from botocore.exceptions import ClientError +from airflow import AirflowException from airflow.contrib.hooks.aws_hook import AwsHook from airflow.contrib.hooks.aws_logs_hook import AwsLogsHook -from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.s3 import S3Hook from airflow.utils import timezone diff --git a/airflow/contrib/hooks/segment_hook.py b/airflow/contrib/hooks/segment_hook.py index 6bd6433c840447..bb91e6630d915c 100644 --- a/airflow/contrib/hooks/segment_hook.py +++ b/airflow/contrib/hooks/segment_hook.py @@ -27,7 +27,7 @@ """ import analytics -from airflow.exceptions import AirflowException +from airflow import AirflowException from airflow.hooks.base_hook import BaseHook diff --git a/airflow/contrib/hooks/slack_webhook_hook.py b/airflow/contrib/hooks/slack_webhook_hook.py index 100a59b9da8f0c..bb94390f89d527 100644 --- a/airflow/contrib/hooks/slack_webhook_hook.py +++ b/airflow/contrib/hooks/slack_webhook_hook.py @@ -19,7 +19,7 @@ # import json -from airflow.exceptions import AirflowException +from airflow import AirflowException from airflow.hooks.http_hook import HttpHook diff --git a/airflow/contrib/hooks/spark_jdbc_hook.py b/airflow/contrib/hooks/spark_jdbc_hook.py index 8e37816298c2ae..79192d8053aba2 100644 --- a/airflow/contrib/hooks/spark_jdbc_hook.py +++ b/airflow/contrib/hooks/spark_jdbc_hook.py @@ -19,8 +19,8 @@ # import os +from airflow import AirflowException from airflow.contrib.hooks.spark_submit_hook import SparkSubmitHook -from airflow.exceptions import AirflowException class SparkJDBCHook(SparkSubmitHook): diff --git a/airflow/contrib/hooks/spark_sql_hook.py b/airflow/contrib/hooks/spark_sql_hook.py index a8f8704d2ca878..fcba0a460e1437 100644 --- a/airflow/contrib/hooks/spark_sql_hook.py +++ b/airflow/contrib/hooks/spark_sql_hook.py @@ -19,7 +19,7 @@ # import subprocess -from airflow.exceptions import AirflowException +from airflow import AirflowException from airflow.hooks.base_hook import BaseHook diff --git a/airflow/contrib/hooks/spark_submit_hook.py b/airflow/contrib/hooks/spark_submit_hook.py index 6354df38811ae1..0cae3511161351 100644 --- a/airflow/contrib/hooks/spark_submit_hook.py +++ b/airflow/contrib/hooks/spark_submit_hook.py @@ -22,7 +22,7 @@ import subprocess import time -from airflow.exceptions import AirflowException +from airflow import AirflowException from airflow.hooks.base_hook import BaseHook from airflow.kubernetes import kube_client from airflow.utils.log.logging_mixin import LoggingMixin diff --git a/airflow/contrib/hooks/sqoop_hook.py b/airflow/contrib/hooks/sqoop_hook.py index ea9ad430129906..95ef3d658febd6 100644 --- a/airflow/contrib/hooks/sqoop_hook.py +++ b/airflow/contrib/hooks/sqoop_hook.py @@ -24,7 +24,7 @@ import subprocess from copy import deepcopy -from airflow.exceptions import AirflowException +from airflow import AirflowException from airflow.hooks.base_hook import BaseHook diff --git a/airflow/contrib/hooks/ssh_hook.py b/airflow/contrib/hooks/ssh_hook.py index 1c06b34c3097b1..928934f078605e 100644 --- a/airflow/contrib/hooks/ssh_hook.py +++ b/airflow/contrib/hooks/ssh_hook.py @@ -27,7 +27,7 @@ from paramiko.config import SSH_PORT from sshtunnel import SSHTunnelForwarder -from airflow.exceptions import AirflowException +from airflow import AirflowException from airflow.hooks.base_hook import BaseHook diff --git a/airflow/contrib/hooks/wasb_hook.py b/airflow/contrib/hooks/wasb_hook.py index 9f888a80f2062d..af1e6d53f56b41 100644 --- a/airflow/contrib/hooks/wasb_hook.py +++ b/airflow/contrib/hooks/wasb_hook.py @@ -27,7 +27,7 @@ """ from azure.storage.blob import BlockBlobService -from airflow.exceptions import AirflowException +from airflow import AirflowException from airflow.hooks.base_hook import BaseHook diff --git a/airflow/contrib/hooks/winrm_hook.py b/airflow/contrib/hooks/winrm_hook.py index f81756c6fa8168..3cc7559d662194 100644 --- a/airflow/contrib/hooks/winrm_hook.py +++ b/airflow/contrib/hooks/winrm_hook.py @@ -22,7 +22,7 @@ from winrm.protocol import Protocol -from airflow.exceptions import AirflowException +from airflow import AirflowException from airflow.hooks.base_hook import BaseHook diff --git a/airflow/contrib/operators/awsbatch_operator.py b/airflow/contrib/operators/awsbatch_operator.py index b28e85e42a948c..c52bb51b0dd7f4 100644 --- a/airflow/contrib/operators/awsbatch_operator.py +++ b/airflow/contrib/operators/awsbatch_operator.py @@ -23,8 +23,8 @@ from time import sleep from typing import Optional +from airflow import AirflowException from airflow.contrib.hooks.aws_hook import AwsHook -from airflow.exceptions import AirflowException from airflow.models import BaseOperator from airflow.typing_compat import Protocol from airflow.utils.decorators import apply_defaults diff --git a/airflow/contrib/operators/azure_container_instances_operator.py b/airflow/contrib/operators/azure_container_instances_operator.py index 06b126377f98ed..1f820ad74109ff 100644 --- a/airflow/contrib/operators/azure_container_instances_operator.py +++ b/airflow/contrib/operators/azure_container_instances_operator.py @@ -27,10 +27,11 @@ ) from msrestazure.azure_exceptions import CloudError +from airflow import AirflowException from airflow.contrib.hooks.azure_container_instance_hook import AzureContainerInstanceHook from airflow.contrib.hooks.azure_container_registry_hook import AzureContainerRegistryHook from airflow.contrib.hooks.azure_container_volume_hook import AzureContainerVolumeHook -from airflow.exceptions import AirflowException, AirflowTaskTimeout +from airflow.exceptions import AirflowTaskTimeout from airflow.models import BaseOperator from airflow.utils.decorators import apply_defaults diff --git a/airflow/contrib/operators/databricks_operator.py b/airflow/contrib/operators/databricks_operator.py index 34c706b0caf2ec..feca6cfa4e745e 100644 --- a/airflow/contrib/operators/databricks_operator.py +++ b/airflow/contrib/operators/databricks_operator.py @@ -23,8 +23,8 @@ import time +from airflow import AirflowException from airflow.contrib.hooks.databricks_hook import DatabricksHook -from airflow.exceptions import AirflowException from airflow.models import BaseOperator from airflow.utils.decorators import apply_defaults diff --git a/airflow/contrib/operators/discord_webhook_operator.py b/airflow/contrib/operators/discord_webhook_operator.py index b0f38f0d3a721c..df7f3493414bda 100644 --- a/airflow/contrib/operators/discord_webhook_operator.py +++ b/airflow/contrib/operators/discord_webhook_operator.py @@ -17,8 +17,8 @@ # specific language governing permissions and limitations # under the License. # +from airflow import AirflowException from airflow.contrib.hooks.discord_webhook_hook import DiscordWebhookHook -from airflow.exceptions import AirflowException from airflow.operators.http_operator import SimpleHttpOperator from airflow.utils.decorators import apply_defaults diff --git a/airflow/contrib/operators/docker_swarm_operator.py b/airflow/contrib/operators/docker_swarm_operator.py index bac8ec0d8638ee..8c6792a8a45246 100644 --- a/airflow/contrib/operators/docker_swarm_operator.py +++ b/airflow/contrib/operators/docker_swarm_operator.py @@ -19,7 +19,7 @@ from docker import types -from airflow.exceptions import AirflowException +from airflow import AirflowException from airflow.operators.docker_operator import DockerOperator from airflow.utils.decorators import apply_defaults from airflow.utils.strings import get_random_string diff --git a/airflow/contrib/operators/dynamodb_to_s3.py b/airflow/contrib/operators/dynamodb_to_s3.py index 6afe2159d0e164..ab83ea636073a2 100644 --- a/airflow/contrib/operators/dynamodb_to_s3.py +++ b/airflow/contrib/operators/dynamodb_to_s3.py @@ -32,7 +32,7 @@ from boto.compat import json # type: ignore from airflow.contrib.hooks.aws_dynamodb_hook import AwsDynamoDBHook -from airflow.models.baseoperator import BaseOperator +from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.s3 import S3Hook diff --git a/airflow/contrib/operators/ecs_operator.py b/airflow/contrib/operators/ecs_operator.py index d00b6cc43e02e9..c5ff288a9612df 100644 --- a/airflow/contrib/operators/ecs_operator.py +++ b/airflow/contrib/operators/ecs_operator.py @@ -21,9 +21,9 @@ from datetime import datetime from typing import Optional +from airflow import AirflowException from airflow.contrib.hooks.aws_hook import AwsHook from airflow.contrib.hooks.aws_logs_hook import AwsLogsHook -from airflow.exceptions import AirflowException from airflow.models import BaseOperator from airflow.typing_compat import Protocol from airflow.utils.decorators import apply_defaults diff --git a/airflow/contrib/operators/emr_add_steps_operator.py b/airflow/contrib/operators/emr_add_steps_operator.py index 9ff5488f87ba26..6c1c17d78385e4 100644 --- a/airflow/contrib/operators/emr_add_steps_operator.py +++ b/airflow/contrib/operators/emr_add_steps_operator.py @@ -16,8 +16,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from airflow import AirflowException from airflow.contrib.hooks.emr_hook import EmrHook -from airflow.exceptions import AirflowException from airflow.models import BaseOperator from airflow.utils.decorators import apply_defaults diff --git a/airflow/contrib/operators/emr_create_job_flow_operator.py b/airflow/contrib/operators/emr_create_job_flow_operator.py index a190d017cc2ecf..f44bdfb73de922 100644 --- a/airflow/contrib/operators/emr_create_job_flow_operator.py +++ b/airflow/contrib/operators/emr_create_job_flow_operator.py @@ -16,8 +16,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from airflow import AirflowException from airflow.contrib.hooks.emr_hook import EmrHook -from airflow.exceptions import AirflowException from airflow.models import BaseOperator from airflow.utils.decorators import apply_defaults diff --git a/airflow/contrib/operators/emr_terminate_job_flow_operator.py b/airflow/contrib/operators/emr_terminate_job_flow_operator.py index 6042e5ab80ec06..174dbe68130e74 100644 --- a/airflow/contrib/operators/emr_terminate_job_flow_operator.py +++ b/airflow/contrib/operators/emr_terminate_job_flow_operator.py @@ -16,8 +16,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from airflow import AirflowException from airflow.contrib.hooks.emr_hook import EmrHook -from airflow.exceptions import AirflowException from airflow.models import BaseOperator from airflow.utils.decorators import apply_defaults diff --git a/airflow/contrib/operators/gcs_to_gdrive_operator.py b/airflow/contrib/operators/gcs_to_gdrive_operator.py index 2ff7a6a487469b..9486471be37cc2 100644 --- a/airflow/contrib/operators/gcs_to_gdrive_operator.py +++ b/airflow/contrib/operators/gcs_to_gdrive_operator.py @@ -22,8 +22,8 @@ import tempfile from typing import Optional +from airflow import AirflowException from airflow.contrib.hooks.gdrive_hook import GoogleDriveHook -from airflow.exceptions import AirflowException from airflow.gcp.hooks.gcs import GoogleCloudStorageHook from airflow.models import BaseOperator from airflow.utils.decorators import apply_defaults diff --git a/airflow/contrib/operators/jenkins_job_trigger_operator.py b/airflow/contrib/operators/jenkins_job_trigger_operator.py index 5006bf954622ba..a4ff293bae1db0 100644 --- a/airflow/contrib/operators/jenkins_job_trigger_operator.py +++ b/airflow/contrib/operators/jenkins_job_trigger_operator.py @@ -26,8 +26,8 @@ from jenkins import JenkinsException from requests import Request +from airflow import AirflowException from airflow.contrib.hooks.jenkins_hook import JenkinsHook -from airflow.exceptions import AirflowException from airflow.models import BaseOperator from airflow.utils.decorators import apply_defaults diff --git a/airflow/contrib/operators/jira_operator.py b/airflow/contrib/operators/jira_operator.py index 583146f48d2c1e..81ae3af42bf9d6 100644 --- a/airflow/contrib/operators/jira_operator.py +++ b/airflow/contrib/operators/jira_operator.py @@ -18,8 +18,8 @@ # under the License. +from airflow import AirflowException from airflow.contrib.hooks.jira_hook import JIRAError, JiraHook -from airflow.exceptions import AirflowException from airflow.models import BaseOperator from airflow.utils.decorators import apply_defaults diff --git a/airflow/contrib/operators/kubernetes_pod_operator.py b/airflow/contrib/operators/kubernetes_pod_operator.py index 04888dd27ef7fd..38179fc8052958 100644 --- a/airflow/contrib/operators/kubernetes_pod_operator.py +++ b/airflow/contrib/operators/kubernetes_pod_operator.py @@ -20,7 +20,7 @@ import kubernetes.client.models as k8s -from airflow.exceptions import AirflowException +from airflow import AirflowException from airflow.kubernetes import kube_client, pod_generator, pod_launcher from airflow.kubernetes.k8s_model import append_to_pod from airflow.kubernetes.pod import Port, Resources diff --git a/airflow/contrib/operators/qubole_check_operator.py b/airflow/contrib/operators/qubole_check_operator.py index 73e8cc47f336c1..fbfe1c5b1faf05 100644 --- a/airflow/contrib/operators/qubole_check_operator.py +++ b/airflow/contrib/operators/qubole_check_operator.py @@ -17,9 +17,9 @@ # specific language governing permissions and limitations # under the License. # +from airflow import AirflowException from airflow.contrib.hooks.qubole_check_hook import QuboleCheckHook from airflow.contrib.operators.qubole_operator import QuboleOperator -from airflow.exceptions import AirflowException from airflow.operators.check_operator import CheckOperator, ValueCheckOperator from airflow.utils.decorators import apply_defaults diff --git a/airflow/contrib/operators/qubole_operator.py b/airflow/contrib/operators/qubole_operator.py index c39c22adc2f084..6896af220f122b 100644 --- a/airflow/contrib/operators/qubole_operator.py +++ b/airflow/contrib/operators/qubole_operator.py @@ -23,7 +23,7 @@ from airflow.contrib.hooks.qubole_hook import ( COMMAND_ARGS, HYPHEN_ARGS, POSITIONAL_ARGS, QuboleHook, flatten_list, ) -from airflow.models.baseoperator import BaseOperator, BaseOperatorLink +from airflow.models import BaseOperator, BaseOperatorLink from airflow.utils.decorators import apply_defaults diff --git a/airflow/contrib/operators/s3_delete_objects_operator.py b/airflow/contrib/operators/s3_delete_objects_operator.py index 167bfad9102969..8b3b357a5cc2f1 100644 --- a/airflow/contrib/operators/s3_delete_objects_operator.py +++ b/airflow/contrib/operators/s3_delete_objects_operator.py @@ -17,7 +17,7 @@ # specific language governing permissions and limitations # under the License. -from airflow.exceptions import AirflowException +from airflow import AirflowException from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.s3 import S3Hook from airflow.utils.decorators import apply_defaults diff --git a/airflow/contrib/operators/s3_to_gcs_operator.py b/airflow/contrib/operators/s3_to_gcs_operator.py index 0fe8cb3b36fd28..eaa3acb932d194 100644 --- a/airflow/contrib/operators/s3_to_gcs_operator.py +++ b/airflow/contrib/operators/s3_to_gcs_operator.py @@ -19,8 +19,8 @@ import warnings from tempfile import NamedTemporaryFile +from airflow import AirflowException from airflow.contrib.operators.s3_list_operator import S3ListOperator -from airflow.exceptions import AirflowException from airflow.gcp.hooks.gcs import GoogleCloudStorageHook, _parse_gcs_url from airflow.providers.amazon.aws.hooks.s3 import S3Hook from airflow.utils.decorators import apply_defaults diff --git a/airflow/contrib/operators/sagemaker_endpoint_config_operator.py b/airflow/contrib/operators/sagemaker_endpoint_config_operator.py index f0ddbd1b58c261..9dcd9ef63d9a29 100644 --- a/airflow/contrib/operators/sagemaker_endpoint_config_operator.py +++ b/airflow/contrib/operators/sagemaker_endpoint_config_operator.py @@ -17,8 +17,8 @@ # specific language governing permissions and limitations # under the License. +from airflow import AirflowException from airflow.contrib.operators.sagemaker_base_operator import SageMakerBaseOperator -from airflow.exceptions import AirflowException from airflow.utils.decorators import apply_defaults diff --git a/airflow/contrib/operators/sagemaker_endpoint_operator.py b/airflow/contrib/operators/sagemaker_endpoint_operator.py index aba16ce615af78..112dfe641e82e9 100644 --- a/airflow/contrib/operators/sagemaker_endpoint_operator.py +++ b/airflow/contrib/operators/sagemaker_endpoint_operator.py @@ -17,9 +17,9 @@ # specific language governing permissions and limitations # under the License. +from airflow import AirflowException from airflow.contrib.hooks.aws_hook import AwsHook from airflow.contrib.operators.sagemaker_base_operator import SageMakerBaseOperator -from airflow.exceptions import AirflowException from airflow.utils.decorators import apply_defaults diff --git a/airflow/contrib/operators/sagemaker_model_operator.py b/airflow/contrib/operators/sagemaker_model_operator.py index c6855232e9d76c..b4c6ca8d4ba298 100644 --- a/airflow/contrib/operators/sagemaker_model_operator.py +++ b/airflow/contrib/operators/sagemaker_model_operator.py @@ -17,9 +17,9 @@ # specific language governing permissions and limitations # under the License. +from airflow import AirflowException from airflow.contrib.hooks.aws_hook import AwsHook from airflow.contrib.operators.sagemaker_base_operator import SageMakerBaseOperator -from airflow.exceptions import AirflowException from airflow.utils.decorators import apply_defaults diff --git a/airflow/contrib/operators/sagemaker_training_operator.py b/airflow/contrib/operators/sagemaker_training_operator.py index 135fcf57ce6060..e50e3919c4c5cc 100644 --- a/airflow/contrib/operators/sagemaker_training_operator.py +++ b/airflow/contrib/operators/sagemaker_training_operator.py @@ -17,9 +17,9 @@ # specific language governing permissions and limitations # under the License. +from airflow import AirflowException from airflow.contrib.hooks.aws_hook import AwsHook from airflow.contrib.operators.sagemaker_base_operator import SageMakerBaseOperator -from airflow.exceptions import AirflowException from airflow.utils.decorators import apply_defaults diff --git a/airflow/contrib/operators/sagemaker_transform_operator.py b/airflow/contrib/operators/sagemaker_transform_operator.py index 71475cf8311959..4fec010c058a18 100644 --- a/airflow/contrib/operators/sagemaker_transform_operator.py +++ b/airflow/contrib/operators/sagemaker_transform_operator.py @@ -17,9 +17,9 @@ # specific language governing permissions and limitations # under the License. +from airflow import AirflowException from airflow.contrib.hooks.aws_hook import AwsHook from airflow.contrib.operators.sagemaker_base_operator import SageMakerBaseOperator -from airflow.exceptions import AirflowException from airflow.utils.decorators import apply_defaults diff --git a/airflow/contrib/operators/sagemaker_tuning_operator.py b/airflow/contrib/operators/sagemaker_tuning_operator.py index a5893519adbdee..893ebc6aa8df2d 100644 --- a/airflow/contrib/operators/sagemaker_tuning_operator.py +++ b/airflow/contrib/operators/sagemaker_tuning_operator.py @@ -17,9 +17,9 @@ # specific language governing permissions and limitations # under the License. +from airflow import AirflowException from airflow.contrib.hooks.aws_hook import AwsHook from airflow.contrib.operators.sagemaker_base_operator import SageMakerBaseOperator -from airflow.exceptions import AirflowException from airflow.utils.decorators import apply_defaults diff --git a/airflow/contrib/operators/sftp_operator.py b/airflow/contrib/operators/sftp_operator.py index d71e8252823dc6..496e0cd9265d2b 100644 --- a/airflow/contrib/operators/sftp_operator.py +++ b/airflow/contrib/operators/sftp_operator.py @@ -18,8 +18,8 @@ # under the License. import os +from airflow import AirflowException from airflow.contrib.hooks.ssh_hook import SSHHook -from airflow.exceptions import AirflowException from airflow.models import BaseOperator from airflow.utils.decorators import apply_defaults diff --git a/airflow/contrib/operators/sqoop_operator.py b/airflow/contrib/operators/sqoop_operator.py index 7058593b70d028..bf58122709842b 100644 --- a/airflow/contrib/operators/sqoop_operator.py +++ b/airflow/contrib/operators/sqoop_operator.py @@ -24,8 +24,8 @@ import os import signal +from airflow import AirflowException from airflow.contrib.hooks.sqoop_hook import SqoopHook -from airflow.exceptions import AirflowException from airflow.models import BaseOperator from airflow.utils.decorators import apply_defaults diff --git a/airflow/contrib/operators/ssh_operator.py b/airflow/contrib/operators/ssh_operator.py index 52df8d85186e45..1e889f579ac6ec 100644 --- a/airflow/contrib/operators/ssh_operator.py +++ b/airflow/contrib/operators/ssh_operator.py @@ -20,9 +20,9 @@ from base64 import b64encode from select import select +from airflow import AirflowException from airflow.configuration import conf from airflow.contrib.hooks.ssh_hook import SSHHook -from airflow.exceptions import AirflowException from airflow.models import BaseOperator from airflow.utils.decorators import apply_defaults diff --git a/airflow/contrib/operators/winrm_operator.py b/airflow/contrib/operators/winrm_operator.py index 25d7831950b952..3aac8039174496 100644 --- a/airflow/contrib/operators/winrm_operator.py +++ b/airflow/contrib/operators/winrm_operator.py @@ -22,9 +22,9 @@ from winrm.exceptions import WinRMOperationTimeoutError +from airflow import AirflowException from airflow.configuration import conf from airflow.contrib.hooks.winrm_hook import WinRMHook -from airflow.exceptions import AirflowException from airflow.models import BaseOperator from airflow.utils.decorators import apply_defaults diff --git a/airflow/contrib/sensors/datadog_sensor.py b/airflow/contrib/sensors/datadog_sensor.py index 622d6fffc936a5..c4354e3530bff9 100644 --- a/airflow/contrib/sensors/datadog_sensor.py +++ b/airflow/contrib/sensors/datadog_sensor.py @@ -18,8 +18,8 @@ # under the License. from datadog import api +from airflow import AirflowException from airflow.contrib.hooks.datadog_hook import DatadogHook -from airflow.exceptions import AirflowException from airflow.sensors.base_sensor_operator import BaseSensorOperator from airflow.utils.decorators import apply_defaults diff --git a/airflow/contrib/sensors/emr_base_sensor.py b/airflow/contrib/sensors/emr_base_sensor.py index b8cf5468817ec2..9217704e8b9959 100644 --- a/airflow/contrib/sensors/emr_base_sensor.py +++ b/airflow/contrib/sensors/emr_base_sensor.py @@ -16,7 +16,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from airflow.exceptions import AirflowException +from airflow import AirflowException from airflow.sensors.base_sensor_operator import BaseSensorOperator from airflow.utils.decorators import apply_defaults diff --git a/airflow/contrib/sensors/qubole_sensor.py b/airflow/contrib/sensors/qubole_sensor.py index 39432dca039c52..c0fb1f90a7f186 100644 --- a/airflow/contrib/sensors/qubole_sensor.py +++ b/airflow/contrib/sensors/qubole_sensor.py @@ -20,7 +20,7 @@ from qds_sdk.qubole import Qubole from qds_sdk.sensors import FileSensor, PartitionSensor -from airflow.exceptions import AirflowException +from airflow import AirflowException from airflow.hooks.base_hook import BaseHook from airflow.sensors.base_sensor_operator import BaseSensorOperator from airflow.utils.decorators import apply_defaults diff --git a/airflow/contrib/sensors/sagemaker_base_sensor.py b/airflow/contrib/sensors/sagemaker_base_sensor.py index 72183093bfb804..b09a39cf25e68b 100644 --- a/airflow/contrib/sensors/sagemaker_base_sensor.py +++ b/airflow/contrib/sensors/sagemaker_base_sensor.py @@ -16,7 +16,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from airflow.exceptions import AirflowException +from airflow import AirflowException from airflow.sensors.base_sensor_operator import BaseSensorOperator from airflow.utils.decorators import apply_defaults diff --git a/airflow/example_dags/example_bash_operator.py b/airflow/example_dags/example_bash_operator.py index 02b4a95df82dab..3e540115725041 100644 --- a/airflow/example_dags/example_bash_operator.py +++ b/airflow/example_dags/example_bash_operator.py @@ -22,7 +22,7 @@ from datetime import timedelta import airflow -from airflow.models import DAG +from airflow import DAG from airflow.operators.bash_operator import BashOperator from airflow.operators.dummy_operator import DummyOperator diff --git a/airflow/example_dags/example_branch_operator.py b/airflow/example_dags/example_branch_operator.py index 307860364f77ca..7ffbe41cc7314b 100644 --- a/airflow/example_dags/example_branch_operator.py +++ b/airflow/example_dags/example_branch_operator.py @@ -22,7 +22,7 @@ import random import airflow -from airflow.models import DAG +from airflow import DAG from airflow.operators.dummy_operator import DummyOperator from airflow.operators.python_operator import BranchPythonOperator diff --git a/airflow/example_dags/example_branch_python_dop_operator_3.py b/airflow/example_dags/example_branch_python_dop_operator_3.py index 7455ef7ebbd23e..8c84e9cbaea0de 100644 --- a/airflow/example_dags/example_branch_python_dop_operator_3.py +++ b/airflow/example_dags/example_branch_python_dop_operator_3.py @@ -23,7 +23,7 @@ """ import airflow -from airflow.models import DAG +from airflow import DAG from airflow.operators.dummy_operator import DummyOperator from airflow.operators.python_operator import BranchPythonOperator diff --git a/airflow/example_dags/example_latest_only.py b/airflow/example_dags/example_latest_only.py index ae4339a44c4b83..488369b485ab16 100644 --- a/airflow/example_dags/example_latest_only.py +++ b/airflow/example_dags/example_latest_only.py @@ -22,7 +22,7 @@ import datetime as dt import airflow -from airflow.models import DAG +from airflow import DAG from airflow.operators.dummy_operator import DummyOperator from airflow.operators.latest_only_operator import LatestOnlyOperator diff --git a/airflow/example_dags/example_latest_only_with_trigger.py b/airflow/example_dags/example_latest_only_with_trigger.py index 3559afb0c85b9e..bb44940025b44d 100644 --- a/airflow/example_dags/example_latest_only_with_trigger.py +++ b/airflow/example_dags/example_latest_only_with_trigger.py @@ -22,7 +22,7 @@ import datetime as dt import airflow -from airflow.models import DAG +from airflow import DAG from airflow.operators.dummy_operator import DummyOperator from airflow.operators.latest_only_operator import LatestOnlyOperator from airflow.utils.trigger_rule import TriggerRule diff --git a/airflow/example_dags/example_pig_operator.py b/airflow/example_dags/example_pig_operator.py index bd2720e801583c..147d81396f8977 100644 --- a/airflow/example_dags/example_pig_operator.py +++ b/airflow/example_dags/example_pig_operator.py @@ -20,7 +20,7 @@ """Example DAG demonstrating the usage of the PigOperator.""" import airflow -from airflow.models import DAG +from airflow import DAG from airflow.operators.pig_operator import PigOperator args = { diff --git a/airflow/example_dags/example_python_operator.py b/airflow/example_dags/example_python_operator.py index ecfba76cba1a42..354a5286e13717 100644 --- a/airflow/example_dags/example_python_operator.py +++ b/airflow/example_dags/example_python_operator.py @@ -23,7 +23,7 @@ from pprint import pprint import airflow -from airflow.models import DAG +from airflow import DAG from airflow.operators.python_operator import PythonOperator, PythonVirtualenvOperator args = { diff --git a/airflow/example_dags/example_short_circuit_operator.py b/airflow/example_dags/example_short_circuit_operator.py index 7cbac6819a1ded..ae8b010c03a1ea 100644 --- a/airflow/example_dags/example_short_circuit_operator.py +++ b/airflow/example_dags/example_short_circuit_operator.py @@ -20,7 +20,7 @@ """Example DAG demonstrating the usage of the ShortCircuitOperator.""" import airflow.utils.helpers -from airflow.models import DAG +from airflow import DAG from airflow.operators.dummy_operator import DummyOperator from airflow.operators.python_operator import ShortCircuitOperator diff --git a/airflow/example_dags/example_skip_dag.py b/airflow/example_dags/example_skip_dag.py index 205cd7a6404f1b..a754c4b1248c0b 100644 --- a/airflow/example_dags/example_skip_dag.py +++ b/airflow/example_dags/example_skip_dag.py @@ -20,8 +20,8 @@ """Example DAG demonstrating the DummyOperator and a custom DummySkipOperator which skips by default.""" import airflow +from airflow import DAG from airflow.exceptions import AirflowSkipException -from airflow.models import DAG from airflow.operators.dummy_operator import DummyOperator args = { diff --git a/airflow/example_dags/example_subdag_operator.py b/airflow/example_dags/example_subdag_operator.py index 97d497007d1e0a..1e00b8a4433d87 100644 --- a/airflow/example_dags/example_subdag_operator.py +++ b/airflow/example_dags/example_subdag_operator.py @@ -20,8 +20,8 @@ """Example DAG demonstrating the usage of the SubDagOperator.""" import airflow +from airflow import DAG from airflow.example_dags.subdags.subdag import subdag -from airflow.models import DAG from airflow.operators.dummy_operator import DummyOperator from airflow.operators.subdag_operator import SubDagOperator diff --git a/airflow/example_dags/example_trigger_target_dag.py b/airflow/example_dags/example_trigger_target_dag.py index 15e3e0c63766bf..9e24aae0d3ea0a 100644 --- a/airflow/example_dags/example_trigger_target_dag.py +++ b/airflow/example_dags/example_trigger_target_dag.py @@ -24,7 +24,7 @@ """ import airflow.utils.dates -from airflow.models import DAG +from airflow import DAG from airflow.operators.bash_operator import BashOperator from airflow.operators.python_operator import PythonOperator diff --git a/airflow/example_dags/subdags/subdag.py b/airflow/example_dags/subdags/subdag.py index 46ae6b5674050d..b5d16c52597011 100644 --- a/airflow/example_dags/subdags/subdag.py +++ b/airflow/example_dags/subdags/subdag.py @@ -19,7 +19,7 @@ """Helper function to generate a DAG and operators given some arguments.""" -from airflow.models import DAG +from airflow import DAG from airflow.operators.dummy_operator import DummyOperator diff --git a/airflow/example_dags/test_utils.py b/airflow/example_dags/test_utils.py index 3fc8af1dada06c..0170db13d98dea 100644 --- a/airflow/example_dags/test_utils.py +++ b/airflow/example_dags/test_utils.py @@ -18,7 +18,7 @@ # under the License. """Used for unit tests""" import airflow -from airflow.models import DAG +from airflow import DAG from airflow.operators.bash_operator import BashOperator dag = DAG(dag_id='test_utils', schedule_interval=None) diff --git a/airflow/executors/__init__.py b/airflow/executors/__init__.py deleted file mode 100644 index 5e638a79575d69..00000000000000 --- a/airflow/executors/__init__.py +++ /dev/null @@ -1,99 +0,0 @@ -# -*- coding: utf-8 -*- -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -# pylint:disable=missing-docstring - -import sys -from typing import Optional - -from airflow.configuration import conf -from airflow.exceptions import AirflowException -from airflow.executors.base_executor import BaseExecutor -from airflow.executors.local_executor import LocalExecutor -from airflow.executors.sequential_executor import SequentialExecutor -from airflow.utils.log.logging_mixin import LoggingMixin - -DEFAULT_EXECUTOR = None # type: Optional[BaseExecutor] - - -def _integrate_plugins(): - """Integrate plugins to the context.""" - from airflow.plugins_manager import executors_modules - for executors_module in executors_modules: - sys.modules[executors_module.__name__] = executors_module - globals()[executors_module._name] = executors_module # pylint:disable=protected-access - - -def get_default_executor(): - """Creates a new instance of the configured executor if none exists and returns it""" - global DEFAULT_EXECUTOR # pylint:disable=global-statement - - if DEFAULT_EXECUTOR is not None: - return DEFAULT_EXECUTOR - - executor_name = conf.get('core', 'EXECUTOR') - - DEFAULT_EXECUTOR = _get_executor(executor_name) - - log = LoggingMixin().log - log.info("Using executor %s", executor_name) - - return DEFAULT_EXECUTOR - - -class Executors: - LocalExecutor = "LocalExecutor" - SequentialExecutor = "SequentialExecutor" - CeleryExecutor = "CeleryExecutor" - DaskExecutor = "DaskExecutor" - KubernetesExecutor = "KubernetesExecutor" - - -def _get_executor(executor_name): - """ - Creates a new instance of the named executor. - In case the executor name is not know in airflow, - look for it in the plugins - """ - if executor_name == Executors.LocalExecutor: - return LocalExecutor() - elif executor_name == Executors.SequentialExecutor: - return SequentialExecutor() - elif executor_name == Executors.CeleryExecutor: - from airflow.executors.celery_executor import CeleryExecutor - return CeleryExecutor() - elif executor_name == Executors.DaskExecutor: - from airflow.executors.dask_executor import DaskExecutor - return DaskExecutor() - elif executor_name == Executors.KubernetesExecutor: - from airflow.executors.kubernetes_executor import KubernetesExecutor - return KubernetesExecutor() - else: - # Loading plugins - _integrate_plugins() - executor_path = executor_name.split('.') - if len(executor_path) != 2: - raise AirflowException( - "Executor {0} not supported: " - "please specify in format plugin_module.executor".format(executor_name)) - - if executor_path[0] in globals(): - return globals()[executor_path[0]].__dict__[executor_path[1]]() - else: - raise AirflowException("Executor {0} not supported.".format(executor_name)) diff --git a/airflow/executors/all_executors.py b/airflow/executors/all_executors.py new file mode 100644 index 00000000000000..3db346f5cd3f5d --- /dev/null +++ b/airflow/executors/all_executors.py @@ -0,0 +1,98 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Available executors""" +import sys +from typing import Optional + +from airflow import AirflowException +from airflow.configuration import conf +from airflow.executors.base_executor import BaseExecutor +from airflow.utils.log.logging_mixin import LoggingMixin + + +class AllExecutors: + """ + Keeps constants for all the currently available executors. + """ + + LOCAL_EXECUTOR = "LocalExecutor" + SEQUENTIAL_EXECUTOR = "SequentialExecutor" + CELERY_EXECUTOR = "CeleryExecutor" + DASK_EXECUTOR = "DaskExecutor" + KUBERNETES_EXECUTOR = "KubernetesExecutor" + + _default_executor: Optional[BaseExecutor] = None + + @classmethod + def get_default_executor(cls): + """Creates a new instance of the configured executor if none exists and returns it""" + if cls._default_executor is not None: + return cls._default_executor + + executor_name = conf.get('core', 'EXECUTOR') + + cls._default_executor = AllExecutors._get_executor(executor_name) + + log = LoggingMixin().log + log.info("Using executor %s", executor_name) + + return cls._default_executor + + @staticmethod + def _get_executor(executor_name): + """ + Creates a new instance of the named executor. + In case the executor name is not know in airflow, + look for it in the plugins + """ + if executor_name == AllExecutors.LOCAL_EXECUTOR: + from airflow.executors.local_executor import LocalExecutor + return LocalExecutor() + elif executor_name == AllExecutors.SEQUENTIAL_EXECUTOR: + from airflow.executors.sequential_executor import SequentialExecutor + return SequentialExecutor() + elif executor_name == AllExecutors.CELERY_EXECUTOR: + from airflow.executors.celery_executor import CeleryExecutor + return CeleryExecutor() + elif executor_name == AllExecutors.DASK_EXECUTOR: + from airflow.executors.dask_executor import DaskExecutor + return DaskExecutor() + elif executor_name == AllExecutors.KUBERNETES_EXECUTOR: + from airflow.executors.kubernetes_executor import KubernetesExecutor + return KubernetesExecutor() + else: + # Loading plugins + AllExecutors._integrate_plugins() + executor_path = executor_name.split('.') + if len(executor_path) != 2: + raise AirflowException( + "Executor {0} not supported: " + "please specify in format plugin_module.executor".format(executor_name)) + + if executor_path[0] in globals(): + return globals()[executor_path[0]].__dict__[executor_path[1]]() + else: + raise AirflowException("Executor {0} not supported.".format(executor_name)) + + @staticmethod + def _integrate_plugins(): + """Integrate plugins to the context.""" + from airflow.plugins_manager import executors_modules + for executors_module in executors_modules: + sys.modules[executors_module.__name__] = executors_module + # noinspection PyProtectedMember + globals()[executors_module._name] = executors_module # pylint:disable=protected-access diff --git a/airflow/executors/celery_executor.py b/airflow/executors/celery_executor.py index 83fc44b59f3f90..3994ce98370243 100644 --- a/airflow/executors/celery_executor.py +++ b/airflow/executors/celery_executor.py @@ -26,9 +26,9 @@ from celery import Celery, states as celery_states +from airflow import AirflowException from airflow.config_templates.default_celery import DEFAULT_CELERY_CONFIG from airflow.configuration import conf -from airflow.exceptions import AirflowException from airflow.executors.base_executor import BaseExecutor from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.module_loading import import_string diff --git a/airflow/executors/dask_executor.py b/airflow/executors/dask_executor.py index b355c4b6e7e86e..454b51b023f009 100644 --- a/airflow/executors/dask_executor.py +++ b/airflow/executors/dask_executor.py @@ -45,8 +45,7 @@ def __init__(self, cluster_address=None): def start(self): if self.tls_ca or self.tls_key or self.tls_cert: - from distributed.security import Security - security = Security( + security = distributed.Security( tls_client_key=self.tls_key, tls_client_cert=self.tls_cert, tls_ca_file=self.tls_ca, @@ -90,6 +89,7 @@ def sync(self): self._process_future(future) def end(self): + import distributed for future in distributed.as_completed(self.futures.copy()): self._process_future(future) diff --git a/airflow/executors/kubernetes_executor.py b/airflow/executors/kubernetes_executor.py index f6d61472eeaf11..dea36ad976ec12 100644 --- a/airflow/executors/kubernetes_executor.py +++ b/airflow/executors/kubernetes_executor.py @@ -29,9 +29,9 @@ from kubernetes import client, watch from kubernetes.client.rest import ApiException -from airflow import settings +from airflow import AirflowException, settings from airflow.configuration import conf -from airflow.exceptions import AirflowConfigException, AirflowException +from airflow.exceptions import AirflowConfigException from airflow.executors.base_executor import BaseExecutor from airflow.kubernetes.kube_client import get_kube_client from airflow.kubernetes.pod_generator import PodGenerator diff --git a/airflow/executors/local_executor.py b/airflow/executors/local_executor.py index 0086dcbf12a017..621e96a6f327b7 100644 --- a/airflow/executors/local_executor.py +++ b/airflow/executors/local_executor.py @@ -73,7 +73,7 @@ def execute_work(self, key, command): """ Executes command received and stores result state in queue. - :param key: the key to identify the TI + :param key: the key to identify the TaskInstance :type key: tuple(dag_id, task_id, execution_date) :param command: the command to execute :type command: str @@ -139,7 +139,7 @@ def start(self): def execute_async(self, key, command): """ - :param key: the key to identify the TI + :param key: the key to identify the TaskInstance :type key: tuple(dag_id, task_id, execution_date) :param command: the command to execute :type command: str @@ -182,7 +182,7 @@ def start(self): def execute_async(self, key, command): """ - :param key: the key to identify the TI + :param key: the key to identify the TaskInstance :type key: tuple(dag_id, task_id, execution_date) :param command: the command to execute :type command: str diff --git a/airflow/gcp/example_dags/example_automl_nl_text_classification.py b/airflow/gcp/example_dags/example_automl_nl_text_classification.py index 6179af12253f67..7be439bbd5a3c2 100644 --- a/airflow/gcp/example_dags/example_automl_nl_text_classification.py +++ b/airflow/gcp/example_dags/example_automl_nl_text_classification.py @@ -58,7 +58,7 @@ extract_object_id = CloudAutoMLHook.extract_object_id # Example DAG for AutoML Natural Language Text Classification -with models.DAG( +with models.dag.DAG( "example_automl_text_cls", default_args=default_args, schedule_interval=None, # Override to match your needs diff --git a/airflow/gcp/example_dags/example_automl_tables.py b/airflow/gcp/example_dags/example_automl_tables.py index cf3d1438b24da7..2a7064eaa4bd6a 100644 --- a/airflow/gcp/example_dags/example_automl_tables.py +++ b/airflow/gcp/example_dags/example_automl_tables.py @@ -75,7 +75,7 @@ def get_target_column_spec(columns_specs: List[Dict], column_name: str) -> str: # Example DAG to create dataset, train model_id and deploy it. -with models.DAG( +with models.dag.DAG( "example_create_and_deploy", default_args=default_args, schedule_interval=None, # Override to match your needs @@ -183,7 +183,7 @@ def get_target_column_spec(columns_specs: List[Dict], column_name: str) -> str: # Example DAG for AutoML datasets operations -with models.DAG( +with models.dag.DAG( "example_automl_dataset", default_args=default_args, schedule_interval=None, # Override to match your needs @@ -248,7 +248,7 @@ def get_target_column_spec(columns_specs: List[Dict], column_name: str) -> str: >> delete_datasets_task # noqa ) -with models.DAG( +with models.dag.DAG( "example_gcp_get_deploy", default_args=default_args, schedule_interval=None, # Override to match your needs @@ -272,7 +272,7 @@ def get_target_column_spec(columns_specs: List[Dict], column_name: str) -> str: # [END howto_operator_deploy_model] -with models.DAG( +with models.dag.DAG( "example_gcp_predict", default_args=default_args, schedule_interval=None, # Override to match your needs diff --git a/airflow/gcp/example_dags/example_automl_translation.py b/airflow/gcp/example_dags/example_automl_translation.py index e52857a68ba5ec..9fbfb081a64e0f 100644 --- a/airflow/gcp/example_dags/example_automl_translation.py +++ b/airflow/gcp/example_dags/example_automl_translation.py @@ -62,7 +62,7 @@ # Example DAG for AutoML Translation -with models.DAG( +with models.dag.DAG( "example_automl_translation", default_args=default_args, schedule_interval=None, # Override to match your needs diff --git a/airflow/gcp/example_dags/example_bigtable.py b/airflow/gcp/example_dags/example_bigtable.py index 3eaf79ae976a16..04074b476bde64 100644 --- a/airflow/gcp/example_dags/example_bigtable.py +++ b/airflow/gcp/example_dags/example_bigtable.py @@ -77,7 +77,7 @@ 'start_date': airflow.utils.dates.days_ago(1) } -with models.DAG( +with models.dag.DAG( 'example_gcp_bigtable_operators', default_args=default_args, schedule_interval=None # Override to match your needs diff --git a/airflow/gcp/example_dags/example_cloud_sql.py b/airflow/gcp/example_dags/example_cloud_sql.py index 53806b0fa0a71c..481d5f8f13419a 100644 --- a/airflow/gcp/example_dags/example_cloud_sql.py +++ b/airflow/gcp/example_dags/example_cloud_sql.py @@ -180,7 +180,7 @@ 'start_date': airflow.utils.dates.days_ago(1) } -with models.DAG( +with models.dag.DAG( 'example_gcp_sql', default_args=default_args, schedule_interval=None # Override to match your needs diff --git a/airflow/gcp/example_dags/example_dlp_operator.py b/airflow/gcp/example_dags/example_dlp_operator.py index 2067130cdcadf4..f28c3a6f8f66fb 100644 --- a/airflow/gcp/example_dags/example_dlp_operator.py +++ b/airflow/gcp/example_dags/example_dlp_operator.py @@ -30,11 +30,11 @@ from google.cloud.dlp_v2.types import ContentItem, InspectConfig, InspectTemplate import airflow +from airflow import DAG from airflow.gcp.operators.dlp import ( CloudDLPCreateInspectTemplateOperator, CloudDLPDeleteInspectTemplateOperator, CloudDLPInspectContentOperator, ) -from airflow.models import DAG default_args = {"start_date": airflow.utils.dates.days_ago(1)} diff --git a/airflow/gcp/example_dags/example_kubernetes_engine.py b/airflow/gcp/example_dags/example_kubernetes_engine.py index 73fdfb18155c19..2c25a7485ee1ec 100644 --- a/airflow/gcp/example_dags/example_kubernetes_engine.py +++ b/airflow/gcp/example_dags/example_kubernetes_engine.py @@ -36,7 +36,7 @@ default_args = {"start_date": airflow.utils.dates.days_ago(1)} -with models.DAG( +with models.dag.DAG( "example_gcp_gke", default_args=default_args, schedule_interval=None, # Override to match your needs diff --git a/airflow/gcp/example_dags/example_tasks.py b/airflow/gcp/example_dags/example_tasks.py index 1e36262c1fd9fe..ce483b8de62b84 100644 --- a/airflow/gcp/example_dags/example_tasks.py +++ b/airflow/gcp/example_dags/example_tasks.py @@ -31,10 +31,10 @@ from google.protobuf import timestamp_pb2 import airflow +from airflow import DAG from airflow.gcp.operators.tasks import ( CloudTasksQueueCreateOperator, CloudTasksTaskCreateOperator, CloudTasksTaskRunOperator, ) -from airflow.models import DAG default_args = {"start_date": airflow.utils.dates.days_ago(1)} timestamp = timestamp_pb2.Timestamp() diff --git a/airflow/gcp/hooks/base.py b/airflow/gcp/hooks/base.py index bdac2e049b3b59..95dc9f6714b2e5 100644 --- a/airflow/gcp/hooks/base.py +++ b/airflow/gcp/hooks/base.py @@ -41,8 +41,7 @@ from google.auth.environment_vars import CREDENTIALS from googleapiclient.errors import HttpError -from airflow import LoggingMixin, version -from airflow.exceptions import AirflowException +from airflow import AirflowException, LoggingMixin, version from airflow.hooks.base_hook import BaseHook logger = LoggingMixin().log diff --git a/airflow/gcp/hooks/cloud_storage_transfer_service.py b/airflow/gcp/hooks/cloud_storage_transfer_service.py index 9f075f27181dd7..bfe0a11c6ce756 100644 --- a/airflow/gcp/hooks/cloud_storage_transfer_service.py +++ b/airflow/gcp/hooks/cloud_storage_transfer_service.py @@ -29,7 +29,7 @@ from googleapiclient.discovery import build -from airflow.exceptions import AirflowException +from airflow import AirflowException from airflow.gcp.hooks.base import GoogleCloudBaseHook # Time to sleep between active checks of the operation results diff --git a/airflow/gcp/hooks/gcs.py b/airflow/gcp/hooks/gcs.py index efb206069f58de..2c9c877444c57d 100644 --- a/airflow/gcp/hooks/gcs.py +++ b/airflow/gcp/hooks/gcs.py @@ -31,7 +31,7 @@ from google.cloud import storage -from airflow.exceptions import AirflowException +from airflow import AirflowException from airflow.gcp.hooks.base import GoogleCloudBaseHook from airflow.version import version diff --git a/airflow/gcp/hooks/gsheets.py b/airflow/gcp/hooks/gsheets.py index b6cab3d87eb462..000c18d2d500ac 100644 --- a/airflow/gcp/hooks/gsheets.py +++ b/airflow/gcp/hooks/gsheets.py @@ -25,7 +25,7 @@ from googleapiclient.discovery import build -from airflow.exceptions import AirflowException +from airflow import AirflowException from airflow.gcp.hooks.base import GoogleCloudBaseHook diff --git a/airflow/gcp/operators/bigquery.py b/airflow/gcp/operators/bigquery.py index 6ec11288d9f6d5..1021a60992c18c 100644 --- a/airflow/gcp/operators/bigquery.py +++ b/airflow/gcp/operators/bigquery.py @@ -28,10 +28,10 @@ from googleapiclient.errors import HttpError -from airflow.exceptions import AirflowException +from airflow import AirflowException from airflow.gcp.hooks.bigquery import BigQueryHook from airflow.gcp.hooks.gcs import GoogleCloudStorageHook, _parse_gcs_url -from airflow.models.baseoperator import BaseOperator, BaseOperatorLink +from airflow.models import BaseOperator, BaseOperatorLink from airflow.models.taskinstance import TaskInstance from airflow.operators.check_operator import CheckOperator, IntervalCheckOperator, ValueCheckOperator from airflow.utils.decorators import apply_defaults diff --git a/airflow/gcp/operators/datastore.py b/airflow/gcp/operators/datastore.py index 313e0b087e4f39..31610fa744b6e1 100644 --- a/airflow/gcp/operators/datastore.py +++ b/airflow/gcp/operators/datastore.py @@ -22,7 +22,7 @@ """ from typing import Optional -from airflow.exceptions import AirflowException +from airflow import AirflowException from airflow.gcp.hooks.datastore import DatastoreHook from airflow.gcp.hooks.gcs import GoogleCloudStorageHook from airflow.models import BaseOperator diff --git a/airflow/gcp/operators/mlengine.py b/airflow/gcp/operators/mlengine.py index fe00a51d7a1f20..72d6926a98e7d2 100644 --- a/airflow/gcp/operators/mlengine.py +++ b/airflow/gcp/operators/mlengine.py @@ -22,7 +22,7 @@ import warnings from typing import List, Optional -from airflow.exceptions import AirflowException +from airflow import AirflowException from airflow.gcp.hooks.mlengine import MLEngineHook from airflow.models import BaseOperator from airflow.utils.decorators import apply_defaults diff --git a/airflow/gcp/utils/credentials_provider.py b/airflow/gcp/utils/credentials_provider.py index e759dfcd90288a..db8188cd7e9267 100644 --- a/airflow/gcp/utils/credentials_provider.py +++ b/airflow/gcp/utils/credentials_provider.py @@ -29,7 +29,7 @@ from google.auth.environment_vars import CREDENTIALS -from airflow.exceptions import AirflowException +from airflow import AirflowException AIRFLOW_CONN_GOOGLE_CLOUD_DEFAULT = "AIRFLOW_CONN_GOOGLE_CLOUD_DEFAULT" diff --git a/airflow/gcp/utils/mlengine_operator_utils.py b/airflow/gcp/utils/mlengine_operator_utils.py index f838879ad6de5b..5fcb6a0d6888ff 100644 --- a/airflow/gcp/utils/mlengine_operator_utils.py +++ b/airflow/gcp/utils/mlengine_operator_utils.py @@ -28,7 +28,7 @@ import dill -from airflow.exceptions import AirflowException +from airflow import AirflowException from airflow.gcp.hooks.gcs import GoogleCloudStorageHook from airflow.gcp.operators.dataflow import DataFlowPythonOperator from airflow.gcp.operators.mlengine import MLEngineBatchPredictionOperator @@ -174,7 +174,7 @@ def validate_err_and_count(summary): :type version_name: str :param dag: The `DAG` to use for all Operators. - :type dag: airflow.models.DAG + :type dag: airflow.models.dag.DAG :param py_interpreter: Python version of the beam pipeline. If None, this defaults to the python2. diff --git a/airflow/hooks/base_hook.py b/airflow/hooks/base_hook.py index 96a9faed9fd6ca..e3193497826165 100644 --- a/airflow/hooks/base_hook.py +++ b/airflow/hooks/base_hook.py @@ -21,7 +21,7 @@ import random from typing import Iterable -from airflow.exceptions import AirflowException +from airflow import AirflowException from airflow.models import Connection from airflow.utils.db import provide_session from airflow.utils.log.logging_mixin import LoggingMixin diff --git a/airflow/hooks/dbapi_hook.py b/airflow/hooks/dbapi_hook.py index cf02385f61c8db..c0cc75c6534f1d 100644 --- a/airflow/hooks/dbapi_hook.py +++ b/airflow/hooks/dbapi_hook.py @@ -23,7 +23,7 @@ from sqlalchemy import create_engine -from airflow.exceptions import AirflowException +from airflow import AirflowException from airflow.hooks.base_hook import BaseHook from airflow.typing_compat import Protocol diff --git a/airflow/hooks/docker_hook.py b/airflow/hooks/docker_hook.py index 0d8d881d678c24..66e3bce9cf80ef 100644 --- a/airflow/hooks/docker_hook.py +++ b/airflow/hooks/docker_hook.py @@ -20,7 +20,7 @@ from docker import APIClient from docker.errors import APIError -from airflow.exceptions import AirflowException +from airflow import AirflowException from airflow.hooks.base_hook import BaseHook from airflow.utils.log.logging_mixin import LoggingMixin diff --git a/airflow/hooks/druid_hook.py b/airflow/hooks/druid_hook.py index 848fc0366d8d12..741796d0b3fdb5 100644 --- a/airflow/hooks/druid_hook.py +++ b/airflow/hooks/druid_hook.py @@ -22,7 +22,7 @@ import requests from pydruid.db import connect -from airflow.exceptions import AirflowException +from airflow import AirflowException from airflow.hooks.base_hook import BaseHook from airflow.hooks.dbapi_hook import DbApiHook diff --git a/airflow/hooks/hdfs_hook.py b/airflow/hooks/hdfs_hook.py index f9b3bebcc05f6c..4b5f0d7b6f53f3 100644 --- a/airflow/hooks/hdfs_hook.py +++ b/airflow/hooks/hdfs_hook.py @@ -17,8 +17,8 @@ # specific language governing permissions and limitations # under the License. """Hook for HDFS operations""" +from airflow import AirflowException from airflow.configuration import conf -from airflow.exceptions import AirflowException from airflow.hooks.base_hook import BaseHook try: diff --git a/airflow/hooks/hive_hooks.py b/airflow/hooks/hive_hooks.py index 021c1c07fdc034..7cd33e6299b3c7 100644 --- a/airflow/hooks/hive_hooks.py +++ b/airflow/hooks/hive_hooks.py @@ -28,8 +28,8 @@ import unicodecsv as csv +from airflow import AirflowException from airflow.configuration import conf -from airflow.exceptions import AirflowException from airflow.hooks.base_hook import BaseHook from airflow.security import utils from airflow.utils.file import TemporaryDirectory diff --git a/airflow/hooks/http_hook.py b/airflow/hooks/http_hook.py index 316fe3fa4048f0..1ce86717c98e0a 100644 --- a/airflow/hooks/http_hook.py +++ b/airflow/hooks/http_hook.py @@ -20,7 +20,7 @@ import requests import tenacity -from airflow.exceptions import AirflowException +from airflow import AirflowException from airflow.hooks.base_hook import BaseHook diff --git a/airflow/hooks/pig_hook.py b/airflow/hooks/pig_hook.py index 2b4d015bac0272..38d5cba1240ccc 100644 --- a/airflow/hooks/pig_hook.py +++ b/airflow/hooks/pig_hook.py @@ -20,7 +20,7 @@ import subprocess from tempfile import NamedTemporaryFile -from airflow.exceptions import AirflowException +from airflow import AirflowException from airflow.hooks.base_hook import BaseHook from airflow.utils.file import TemporaryDirectory diff --git a/airflow/hooks/slack_hook.py b/airflow/hooks/slack_hook.py index 907595ed8630a7..06e09784cace67 100644 --- a/airflow/hooks/slack_hook.py +++ b/airflow/hooks/slack_hook.py @@ -21,7 +21,7 @@ from slackclient import SlackClient -from airflow.exceptions import AirflowException +from airflow import AirflowException from airflow.hooks.base_hook import BaseHook diff --git a/airflow/hooks/webhdfs_hook.py b/airflow/hooks/webhdfs_hook.py index ec19132ed8b53e..19b48dcf2c40b3 100644 --- a/airflow/hooks/webhdfs_hook.py +++ b/airflow/hooks/webhdfs_hook.py @@ -19,8 +19,8 @@ """Hook for Web HDFS""" from hdfs import HdfsError, InsecureClient +from airflow import AirflowException from airflow.configuration import conf -from airflow.exceptions import AirflowException from airflow.hooks.base_hook import BaseHook from airflow.utils.log.logging_mixin import LoggingMixin diff --git a/airflow/jobs/backfill_job.py b/airflow/jobs/backfill_job.py index 2b12dcfca5e721..391b2b10797a90 100644 --- a/airflow/jobs/backfill_job.py +++ b/airflow/jobs/backfill_job.py @@ -24,13 +24,15 @@ from sqlalchemy.orm.session import Session, make_transient -from airflow import executors, models +from airflow import DAG, models from airflow.exceptions import ( AirflowException, DagConcurrencyLimitReached, NoAvailablePoolSlot, PoolNotFound, TaskConcurrencyLimitReached, ) +from airflow.executors.local_executor import LocalExecutor +from airflow.executors.sequential_executor import SequentialExecutor from airflow.jobs.base_job import BaseJob -from airflow.models import DAG, DagPickle, DagRun +from airflow.models import DagPickle, DagRun from airflow.ti_deps.dep_context import BACKFILL_QUEUED_DEPS, DepContext from airflow.utils import timezone from airflow.utils.configuration import tmp_configuration_copy @@ -487,8 +489,7 @@ def _per_task_process(task, key, ti, session=None): session.merge(ti) cfg_path = None - if executor.__class__ in (executors.LocalExecutor, - executors.SequentialExecutor): + if executor.__class__ in (LocalExecutor, SequentialExecutor): cfg_path = tmp_configuration_copy() executor.queue_task_instance( @@ -740,8 +741,7 @@ def _execute(self, session=None): # picklin' pickle_id = None - if not self.donot_pickle and self.executor.__class__ not in ( - executors.LocalExecutor, executors.SequentialExecutor): + if not self.donot_pickle and self.executor.__class__ not in (LocalExecutor, SequentialExecutor): pickle = DagPickle(self.dag) session.add(pickle) session.commit() diff --git a/airflow/jobs/base_job.py b/airflow/jobs/base_job.py index 874c0c6f058cec..20b89d2d8c29ea 100644 --- a/airflow/jobs/base_job.py +++ b/airflow/jobs/base_job.py @@ -26,9 +26,9 @@ from sqlalchemy.exc import OperationalError from sqlalchemy.orm.session import make_transient -from airflow import executors, models +from airflow import AirflowException, models from airflow.configuration import conf -from airflow.exceptions import AirflowException +from airflow.executors.all_executors import AllExecutors from airflow.models.base import ID_LEN, Base from airflow.stats import Stats from airflow.utils import helpers, timezone @@ -79,7 +79,7 @@ def __init__( heartrate=None, *args, **kwargs): self.hostname = get_hostname() - self.executor = executor or executors.get_default_executor() + self.executor = executor or AllExecutors.get_default_executor() self.executor_class = executor.__class__.__name__ self.start_date = timezone.utcnow() self.latest_heartbeat = timezone.utcnow() @@ -255,21 +255,21 @@ def reset_state_for_orphaned_tasks(self, filter_by_dag_run=None, session=None): running_tis = self.executor.running resettable_states = [State.SCHEDULED, State.QUEUED] - TI = models.TaskInstance + TaskInstance = models.TaskInstance DR = models.DagRun if filter_by_dag_run is None: resettable_tis = ( session - .query(TI) + .query(TaskInstance) .join( DR, and_( - TI.dag_id == DR.dag_id, - TI.execution_date == DR.execution_date)) + TaskInstance.dag_id == DR.dag_id, + TaskInstance.execution_date == DR.execution_date)) .filter( DR.state == State.RUNNING, DR.run_id.notlike(BackfillJob.ID_PREFIX + '%'), - TI.state.in_(resettable_states))).all() + TaskInstance.state.in_(resettable_states))).all() else: resettable_tis = filter_by_dag_run.get_task_instances(state=resettable_states, session=session) @@ -283,14 +283,14 @@ def reset_state_for_orphaned_tasks(self, filter_by_dag_run=None, session=None): return [] def query(result, items): - filter_for_tis = ([and_(TI.dag_id == ti.dag_id, - TI.task_id == ti.task_id, - TI.execution_date == ti.execution_date) + filter_for_tis = ([and_(TaskInstance.dag_id == ti.dag_id, + TaskInstance.task_id == ti.task_id, + TaskInstance.execution_date == ti.execution_date) for ti in items]) reset_tis = ( session - .query(TI) - .filter(or_(*filter_for_tis), TI.state.in_(resettable_states)) + .query(TaskInstance) + .filter(or_(*filter_for_tis), TaskInstance.state.in_(resettable_states)) .with_for_update() .all()) for ti in reset_tis: diff --git a/airflow/jobs/local_task_job.py b/airflow/jobs/local_task_job.py index 3f57d99a1a81cc..fd1a5a242d16b6 100644 --- a/airflow/jobs/local_task_job.py +++ b/airflow/jobs/local_task_job.py @@ -22,8 +22,8 @@ import signal import time +from airflow import AirflowException from airflow.configuration import conf -from airflow.exceptions import AirflowException from airflow.jobs.base_job import BaseJob from airflow.stats import Stats from airflow.task.task_runner import get_task_runner diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py index 2640f1232efe51..19b5bc4af65f95 100644 --- a/airflow/jobs/scheduler_job.py +++ b/airflow/jobs/scheduler_job.py @@ -35,11 +35,12 @@ from sqlalchemy import and_, func, not_, or_ from sqlalchemy.orm.session import make_transient -from airflow import executors, models, settings +from airflow import DAG, AirflowException, models, settings from airflow.configuration import conf -from airflow.exceptions import AirflowException +from airflow.executors.local_executor import LocalExecutor +from airflow.executors.sequential_executor import SequentialExecutor from airflow.jobs.base_job import BaseJob -from airflow.models import DAG, DagRun, SlaMiss, errors +from airflow.models import DagRun, SlaMiss, errors from airflow.stats import Stats from airflow.ti_deps.dep_context import SCHEDULEABLE_STATES, SCHEDULED_DEPS, DepContext from airflow.ti_deps.deps.pool_slots_available_dep import STATES_TO_COUNT_AS_RUNNING @@ -403,25 +404,25 @@ def manage_slas(self, dag, session=None): self.log.info("Skipping SLA check for %s because no tasks in DAG have SLAs", dag) return - TI = models.TaskInstance + TaskInstance = models.TaskInstance sq = ( session .query( - TI.task_id, - func.max(TI.execution_date).label('max_ti')) - .with_hint(TI, 'USE INDEX (PRIMARY)', dialect_name='mysql') - .filter(TI.dag_id == dag.dag_id) + TaskInstance.task_id, + func.max(TaskInstance.execution_date).label('max_ti')) + .with_hint(TaskInstance, 'USE INDEX (PRIMARY)', dialect_name='mysql') + .filter(TaskInstance.dag_id == dag.dag_id) .filter(or_( - TI.state == State.SUCCESS, - TI.state == State.SKIPPED)) - .filter(TI.task_id.in_(dag.task_ids)) - .group_by(TI.task_id).subquery('sq') + TaskInstance.state == State.SUCCESS, + TaskInstance.state == State.SKIPPED)) + .filter(TaskInstance.task_id.in_(dag.task_ids)) + .group_by(TaskInstance.task_id).subquery('sq') ) - max_tis = session.query(TI).filter( - TI.dag_id == dag.dag_id, - TI.task_id == sq.c.task_id, - TI.execution_date == sq.c.max_ti, + max_tis = session.query(TaskInstance).filter( + TaskInstance.dag_id == dag.dag_id, + TaskInstance.task_id == sq.c.task_id, + TaskInstance.execution_date == sq.c.max_ti, ).all() ts = timezone.utcnow() @@ -452,11 +453,11 @@ def manage_slas(self, dag, session=None): sla_dates = [sla.execution_date for sla in slas] qry = ( session - .query(TI) + .query(TaskInstance) .filter( - TI.state != State.SUCCESS, - TI.execution_date.in_(sla_dates), - TI.dag_id == dag.dag_id + TaskInstance.state != State.SUCCESS, + TaskInstance.execution_date.in_(sla_dates), + TaskInstance.dag_id == dag.dag_id ).all() ) blocking_tis = [] @@ -808,12 +809,12 @@ def __get_concurrency_maps(self, states, session=None): :rtype: dict[tuple[str, str], int] """ - TI = models.TaskInstance + TaskInstance = models.TaskInstance ti_concurrency_query = ( session - .query(TI.task_id, TI.dag_id, func.count('*')) - .filter(TI.state.in_(states)) - .group_by(TI.task_id, TI.dag_id) + .query(TaskInstance.task_id, TaskInstance.dag_id, func.count('*')) + .filter(TaskInstance.state.in_(states)) + .group_by(TaskInstance.task_id, TaskInstance.dag_id) ).all() dag_map = defaultdict(int) task_map = defaultdict(int) @@ -844,31 +845,32 @@ def _find_executable_task_instances(self, simple_dag_bag, states, session=None): # Get all task instances associated with scheduled # DagRuns which are not backfilled, in the given states, # and the dag is not paused - TI = models.TaskInstance + TaskInstance = models.TaskInstance DR = models.DagRun DM = models.DagModel ti_query = ( session - .query(TI) - .filter(TI.dag_id.in_(simple_dag_bag.dag_ids)) + .query(TaskInstance) + .filter(TaskInstance.dag_id.in_(simple_dag_bag.dag_ids)) .outerjoin( DR, - and_(DR.dag_id == TI.dag_id, DR.execution_date == TI.execution_date) + and_(DR.dag_id == TaskInstance.dag_id, DR.execution_date == TaskInstance.execution_date) ) - .filter(or_(DR.run_id == None, # noqa: E711 pylint: disable=singleton-comparison + .filter(or_(DR.run_id is None, # noqa: E711 pylint: disable=singleton-comparison not_(DR.run_id.like(BackfillJob.ID_PREFIX + '%')))) - .outerjoin(DM, DM.dag_id == TI.dag_id) - .filter(or_(DM.dag_id == None, # noqa: E711 pylint: disable=singleton-comparison + .outerjoin(DM, DM.dag_id == TaskInstance.dag_id) + .filter(or_(DM.dag_id is None, # noqa: E711 pylint: disable=singleton-comparison not_(DM.is_paused))) ) # Additional filters on task instance state if None in states: ti_query = ti_query.filter( - or_(TI.state == None, TI.state.in_(states)) # noqa: E711 pylint: disable=singleton-comparison + or_(TaskInstance.state is None, + TaskInstance.state.in_(states)) # noqa: E711 pylint: disable=singleton-comparison ) else: - ti_query = ti_query.filter(TI.state.in_(states)) + ti_query = ti_query.filter(TaskInstance.state.in_(states)) task_instances_to_examine = ti_query.all() @@ -1019,24 +1021,24 @@ def _change_state_for_executable_task_instances(self, task_instances, session.commit() return [] - TI = models.TaskInstance + TaskInstance = models.TaskInstance filter_for_ti_state_change = ( [and_( - TI.dag_id == ti.dag_id, - TI.task_id == ti.task_id, - TI.execution_date == ti.execution_date) + TaskInstance.dag_id == ti.dag_id, + TaskInstance.task_id == ti.task_id, + TaskInstance.execution_date == ti.execution_date) for ti in task_instances]) ti_query = ( session - .query(TI) + .query(TaskInstance) .filter(or_(*filter_for_ti_state_change))) if None in acceptable_states: ti_query = ti_query.filter( - or_(TI.state == None, TI.state.in_(acceptable_states)) # noqa pylint: disable=singleton-comparison + or_(TaskInstance.state == None, TaskInstance.state.in_(acceptable_states)) # noqa pylint: disable=singleton-comparison ) else: - ti_query = ti_query.filter(TI.state.in_(acceptable_states)) + ti_query = ti_query.filter(TaskInstance.state.in_(acceptable_states)) tis_to_set_to_queued = ( ti_query @@ -1080,11 +1082,11 @@ def _enqueue_task_instances_with_queued_state(self, simple_dag_bag, :param simple_dag_bag: Should contains all of the task_instances' dags :type simple_dag_bag: airflow.utils.dag_processing.SimpleDagBag """ - TI = models.TaskInstance + TaskInstance = models.TaskInstance # actually enqueue them for simple_task_instance in simple_task_instances: simple_dag = simple_dag_bag.get_dag(simple_task_instance.dag_id) - command = TI.generate_command( + command = TaskInstance.generate_command( simple_task_instance.dag_id, simple_task_instance.task_id, simple_task_instance.execution_date, @@ -1157,19 +1159,19 @@ def _change_state_for_tasks_failed_to_execute(self, session): :param session: session for ORM operations """ if self.executor.queued_tasks: - TI = models.TaskInstance + TaskInstance = models.TaskInstance filter_for_ti_state_change = ( [and_( - TI.dag_id == dag_id, - TI.task_id == task_id, - TI.execution_date == execution_date, - # The TI.try_number will return raw try_number+1 since the + TaskInstance.dag_id == dag_id, + TaskInstance.task_id == task_id, + TaskInstance.execution_date == execution_date, + # The TaskInstance.try_number will return raw try_number+1 since the # ti is not running. And we need to -1 to match the DB record. - TI._try_number == try_number - 1, - TI.state == State.QUEUED) + TaskInstance._try_number == try_number - 1, + TaskInstance.state == State.QUEUED) for dag_id, task_id, execution_date, try_number in self.executor.queued_tasks.keys()]) - ti_query = (session.query(TI) + ti_query = (session.query(TaskInstance) .filter(or_(*filter_for_ti_state_change))) tis_to_set_to_scheduled = (ti_query .with_for_update() @@ -1238,7 +1240,7 @@ def _process_executor_events(self, simple_dag_bag, session=None): """ # TODO: this shares quite a lot of code with _manage_executor_state - TI = models.TaskInstance + TaskInstance = models.TaskInstance for key, state in list(self.executor.get_event_buffer(simple_dag_bag.dag_ids) .items()): dag_id, task_id, execution_date, try_number = key @@ -1248,9 +1250,9 @@ def _process_executor_events(self, simple_dag_bag, session=None): dag_id, task_id, execution_date, state, try_number ) if state == State.FAILED or state == State.SUCCESS: - qry = session.query(TI).filter(TI.dag_id == dag_id, - TI.task_id == task_id, - TI.execution_date == execution_date) + qry = session.query(TaskInstance).filter(TaskInstance.dag_id == dag_id, + TaskInstance.task_id == task_id, + TaskInstance.execution_date == execution_date) ti = qry.first() if not ti: self.log.warning("TaskInstance %s went missing from the database", ti) @@ -1282,8 +1284,7 @@ def _execute(self): # DAGs can be pickled for easier remote execution by some executors pickle_dags = False - if self.do_pickle and self.executor.__class__ not in \ - (executors.LocalExecutor, executors.SequentialExecutor): + if self.do_pickle and self.executor.__class__ not in (LocalExecutor, SequentialExecutor): pickle_dags = True self.log.info("Processing each file at most %s times", self.num_runs) @@ -1550,8 +1551,8 @@ def process_file(self, file_path, zombies, pickle_dags=False, session=None): ti = models.TaskInstance(task, ti_key[2]) ti.refresh_from_db(session=session, lock_for_update=True) - # We check only deps needed to set TI to SCHEDULED state here. - # Deps needed to set TI to QUEUED state will be batch checked later + # We check only deps needed to set TaskInstance to SCHEDULED state here. + # Deps needed to set TaskInstance to QUEUED state will be batch checked later # by the scheduler for better performance. dep_context = DepContext(deps=SCHEDULED_DEPS, ignore_task_deps=True) diff --git a/airflow/kubernetes/pod_generator.py b/airflow/kubernetes/pod_generator.py index 380db2a798cf1f..d25a88c89c46e7 100644 --- a/airflow/kubernetes/pod_generator.py +++ b/airflow/kubernetes/pod_generator.py @@ -26,7 +26,7 @@ import kubernetes.client.models as k8s -from airflow.executors import Executors +from airflow.executors.all_executors import AllExecutors class PodDefaults: @@ -230,7 +230,7 @@ def from_obj(obj) -> k8s.V1Pod: 'Cannot convert a non-dictionary or non-PodGenerator ' 'object into a KubernetesExecutorConfig') - namespaced = obj.get(Executors.KubernetesExecutor, {}) + namespaced = obj.get(AllExecutors.KUBERNETES_EXECUTOR, {}) resources = namespaced.get('resources') diff --git a/airflow/migrations/versions/211e584da130_add_ti_state_index.py b/airflow/migrations/versions/211e584da130_add_ti_state_index.py index b17f390e0b65bf..82ad6ae97c9545 100644 --- a/airflow/migrations/versions/211e584da130_add_ti_state_index.py +++ b/airflow/migrations/versions/211e584da130_add_ti_state_index.py @@ -17,7 +17,7 @@ # specific language governing permissions and limitations # under the License. -"""add TI state index +"""add TaskInstance state index Revision ID: 211e584da130 Revises: 2e82aab8ef20 diff --git a/airflow/models/__init__.py b/airflow/models/__init__.py index 6f78f72f1f4084..4624022efc7cb2 100644 --- a/airflow/models/__init__.py +++ b/airflow/models/__init__.py @@ -18,7 +18,7 @@ # under the License. """Airflow models""" from airflow.models.base import ID_LEN, Base # noqa: F401 -from airflow.models.baseoperator import BaseOperator # noqa: F401 +from airflow.models.baseoperator import BaseOperator, BaseOperatorLink # noqa: F401 from airflow.models.connection import Connection # noqa: F401 from airflow.models.dag import DAG, DagModel # noqa: F401 from airflow.models.dagbag import DagBag # noqa: F401 @@ -28,7 +28,6 @@ from airflow.models.kubernetes import KubeResourceVersion, KubeWorkerIdentifier # noqa: F401 from airflow.models.log import Log # noqa: F401 from airflow.models.pool import Pool # noqa: F401 -from airflow.models.serialized_dag import SerializedDagModel # noqa: F401 from airflow.models.skipmixin import SkipMixin # noqa: F401 from airflow.models.slamiss import SlaMiss # noqa: F401 from airflow.models.taskfail import TaskFail # noqa: F401 @@ -39,4 +38,5 @@ # Load SQLAlchemy models during package initialization # Must be loaded after loading DAG model. +# noinspection PyUnresolvedReferences import airflow.jobs # noqa: F401 isort # isort:skip diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 513708ca1de43b..743a05dfdcf492 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -26,18 +26,18 @@ import warnings from abc import ABCMeta, abstractmethod from datetime import datetime, timedelta -from typing import Any, Callable, Dict, FrozenSet, Iterable, List, Optional, Set, Type, Union +from typing import Any, Callable, Dict, FrozenSet, Iterable, List, Optional, Set, Tuple, Type, Union import jinja2 from cached_property import cached_property from dateutil.relativedelta import relativedelta -from airflow import settings +from airflow import AirflowException from airflow.configuration import conf -from airflow.exceptions import AirflowException, DuplicateTaskIdFound +from airflow.exceptions import DuplicateTaskIdFound from airflow.lineage import DataSet, apply_lineage, prepare_lineage -from airflow.models.dag import DAG from airflow.models.pool import Pool +# noinspection PyPep8Naming from airflow.models.taskinstance import TaskInstance, clear_task_instances from airflow.models.xcom import XCOM_RETURN_KEY from airflow.ti_deps.deps.not_in_retry_period_dep import NotInRetryPeriodDep @@ -233,11 +233,10 @@ class derived from this one results in the creation of a task object, result :type do_xcom_push: bool """ - # For derived classes to define which fields will get jinjaified - template_fields = [] # type: Iterable[str] + template_fields: Iterable[str] = [] # Defines which files extensions to look for in the templated fields - template_ext = [] # type: Iterable[str] + template_ext: Iterable[str] = [] # Defines the color in the UI ui_color = '#fff' # type: str ui_fgcolor = '#000' # type: str @@ -245,19 +244,17 @@ class derived from this one results in the creation of a task object, pool = "" # type: str # base list which includes all the attrs that don't need deep copy. - _base_operator_shallow_copy_attrs = ('user_defined_macros', - 'user_defined_filters', - 'params', - '_log',) # type: Iterable[str] + BASE_OPERATOR_SHALLOW_COPY_ATTRS: Tuple[str, ...] = \ + ('user_defined_macros', 'user_defined_filters', 'params', '_log',) # each operator should override this class attr for shallow copy attrs. - shallow_copy_attrs = () # type: Iterable[str] + shallow_copy_attrs: Tuple[str, ...] = () # Defines the operator level extra links - operator_extra_links = () # type: Iterable[BaseOperatorLink] + operator_extra_links: Iterable['BaseOperatorLink'] = () - # Set at end of file - _serialized_fields = frozenset() # type: FrozenSet[str] + # The _serialized_fields are lazily loaded when get_serialized_fields() method is called + _serialized_fields: Optional[FrozenSet[str]] = None _comps = { 'task_id', @@ -281,7 +278,7 @@ class derived from this one results in the creation of a task object, } # noinspection PyUnusedLocal - # pylint: disable=too-many-arguments,too-many-locals + # pylint: disable=too-many-arguments,too-many-locals, too-many-statements @apply_defaults def __init__( self, @@ -298,7 +295,7 @@ def __init__( end_date: Optional[datetime] = None, depends_on_past: bool = False, wait_for_downstream: bool = False, - dag: Optional[DAG] = None, + dag=None, params: Optional[Dict] = None, default_args: Optional[Dict] = None, # pylint: disable=unused-argument priority_weight: int = 1, @@ -321,7 +318,8 @@ def __init__( *args, **kwargs ): - + from airflow.models.dag import DagContext + super().__init__() if args or kwargs: if not conf.getboolean('operators', 'ALLOW_ILLEGAL_ARGUMENTS'): raise AirflowException( @@ -381,7 +379,7 @@ def __init__( self.retry_delay = retry_delay else: self.log.debug("Retry_delay isn't timedelta object, assuming secs") - self.retry_delay = timedelta(seconds=retry_delay) + self.retry_delay = timedelta(seconds=retry_delay.seconds) self.retry_exponential_backoff = retry_exponential_backoff self.max_retry_delay = max_retry_delay self.params = params or {} # Available in templates! @@ -393,27 +391,24 @@ def __init__( .format(all_weight_rules=WeightRule.all_weight_rules, d=dag.dag_id if dag else "", t=task_id, tr=weight_rule)) self.weight_rule = weight_rule - - self.resources = Resources(**resources) if resources is not None else None + self.resources: Optional[Resources] = Resources(**resources) if resources else None self.run_as_user = run_as_user self.task_concurrency = task_concurrency self.executor_config = executor_config or {} self.do_xcom_push = do_xcom_push # Private attributes - self._upstream_task_ids = set() # type: Set[str] - self._downstream_task_ids = set() # type: Set[str] + self._upstream_task_ids: Set[str] = set() + self._downstream_task_ids: Set[str] = set() + self._dag = None - if not dag and settings.CONTEXT_MANAGER_DAG: - dag = settings.CONTEXT_MANAGER_DAG - if dag: - self.dag = dag + self.dag = dag or DagContext.get_current_dag() self._log = logging.getLogger("airflow.task.operators") # lineage - self.inlets = [] # type: List[DataSet] - self.outlets = [] # type: List[DataSet] + self.inlets: List[DataSet] = [] + self.outlets: List[DataSet] = [] self.lineage_data = None self._inlets = { @@ -422,9 +417,9 @@ def __init__( "datasets": [], } - self._outlets = { + self._outlets: Dict[str, Iterable] = { "datasets": [], - } # type: Dict + } if inlets: self._inlets.update(inlets) @@ -433,8 +428,7 @@ def __init__( self._outlets.update(outlets) def __eq__(self, other): - if (type(self) == type(other) and # pylint: disable=unidiomatic-typecheck - self.task_id == other.task_id): + if type(self) is type(other) and self.task_id == other.task_id: return all(self.__dict__.get(c, None) == other.__dict__.get(c, None) for c in self._comps) return False @@ -463,6 +457,7 @@ def __rshift__(self, other): If "Other" is a DAG, the DAG is assigned to the Operator. """ + from airflow import DAG if isinstance(other, DAG): # if this dag is already assigned, do nothing # otherwise, do normal dag assignment @@ -478,6 +473,7 @@ def __lshift__(self, other): If "Other" is a DAG, the DAG is assigned to the Operator. """ + from airflow import DAG if isinstance(other, DAG): # if this dag is already assigned, do nothing # otherwise, do normal dag assignment @@ -522,6 +518,10 @@ def dag(self, dag): Operators can be assigned to one DAG, one time. Repeat assignments to that same DAG are ok. """ + from airflow import DAG + if dag is None: + self._dag = None + return if not isinstance(dag, DAG): raise TypeError( 'Expected DAG; received {}'.format(dag.__class__.__name__)) @@ -654,8 +654,7 @@ def __deepcopy__(self, memo): memo[id(self)] = result # noinspection PyProtectedMember - shallow_copy = cls.shallow_copy_attrs + \ - cls._base_operator_shallow_copy_attrs # pylint: disable=protected-access + shallow_copy = cls.shallow_copy_attrs + cls.BASE_OPERATOR_SHALLOW_COPY_ATTRS for k, v in self.__dict__.items(): if k not in shallow_copy: @@ -834,13 +833,12 @@ def clear(self, Clears the state of task instances associated with the task, following the parameters specified. """ - TI = TaskInstance - qry = session.query(TI).filter(TI.dag_id == self.dag_id) + qry = session.query(TaskInstance).filter(TaskInstance.dag_id == self.dag_id) if start_date: - qry = qry.filter(TI.execution_date >= start_date) + qry = qry.filter(TaskInstance.execution_date >= start_date) if end_date: - qry = qry.filter(TI.execution_date <= end_date) + qry = qry.filter(TaskInstance.execution_date <= end_date) tasks = [self.task_id] @@ -852,7 +850,7 @@ def clear(self, tasks += [ t.task_id for t in self.get_flat_relatives(upstream=False)] - qry = qry.filter(TI.task_id.in_(tasks)) + qry = qry.filter(TaskInstance.task_id.in_(tasks)) count = qry.count() @@ -1026,8 +1024,8 @@ def set_upstream(self, task_or_task_list): """ self._set_relatives(task_or_task_list, upstream=True) + @staticmethod def xcom_push( - self, context, key, value, @@ -1040,8 +1038,8 @@ def xcom_push( value=value, execution_date=execution_date) + @staticmethod def xcom_pull( - self, context, task_ids=None, dag_id=None, @@ -1081,13 +1079,16 @@ def get_extra_links(self, dttm, link_name): else: return None - -# pylint: disable=protected-access -BaseOperator._serialized_fields = frozenset( - vars(BaseOperator(task_id='test')).keys() - { - 'inlets', 'outlets', '_upstream_task_ids', 'default_args' - } | {'_task_type', 'subdag', 'ui_color', 'ui_fgcolor', 'template_fields'} -) + @classmethod + def get_serialized_fields(cls): + """Stringified DAGs and operators contain exactly these fields.""" + if not cls._serialized_fields: + cls._serialized_fields = frozenset( + vars(BaseOperator(task_id='test')).keys() - { + 'inlets', 'outlets', '_upstream_task_ids', 'default_args', 'dag', '_dag' + } | {'_task_type', 'subdag', 'ui_color', 'ui_fgcolor', 'template_fields'} + ) + return cls._serialized_fields class BaseOperatorLink(metaclass=ABCMeta): diff --git a/airflow/models/connection.py b/airflow/models/connection.py index 1d28fa8f6099fa..92883994806b24 100644 --- a/airflow/models/connection.py +++ b/airflow/models/connection.py @@ -24,8 +24,7 @@ from sqlalchemy.ext.declarative import declared_attr from sqlalchemy.orm import synonym -from airflow import LoggingMixin -from airflow.exceptions import AirflowException +from airflow import AirflowException, LoggingMixin from airflow.models.base import ID_LEN, Base from airflow.models.crypto import get_fernet diff --git a/airflow/models/crypto.py b/airflow/models/crypto.py index ef88e2a016128d..955e2718604be6 100644 --- a/airflow/models/crypto.py +++ b/airflow/models/crypto.py @@ -21,8 +21,8 @@ from cryptography.fernet import Fernet, MultiFernet +from airflow import AirflowException from airflow.configuration import conf -from airflow.exceptions import AirflowException from airflow.typing_compat import Protocol from airflow.utils.log.logging_mixin import LoggingMixin diff --git a/airflow/models/dag.py b/airflow/models/dag.py index f9da5dbe171d87..5fee17ca658694 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -26,7 +26,7 @@ import traceback from collections import OrderedDict, defaultdict from datetime import datetime, timedelta -from typing import TYPE_CHECKING, Callable, Dict, FrozenSet, Iterable, List, Optional, Type, Union +from typing import Callable, Dict, FrozenSet, Iterable, List, Optional, Type, Union import jinja2 import pendulum @@ -34,16 +34,17 @@ from dateutil.relativedelta import relativedelta from sqlalchemy import Boolean, Column, Index, Integer, String, Text, func, or_ -from airflow import settings, utils +from airflow import AirflowException, settings, utils from airflow.configuration import conf from airflow.dag.base_dag import BaseDag -from airflow.exceptions import AirflowDagCycleException, AirflowException, DagNotFound, DuplicateTaskIdFound -from airflow.executors import LocalExecutor, get_default_executor +from airflow.exceptions import AirflowDagCycleException, DagNotFound, DuplicateTaskIdFound +from airflow.executors.all_executors import AllExecutors +from airflow.executors.local_executor import LocalExecutor from airflow.models.base import ID_LEN, Base +from airflow.models.baseoperator import BaseOperator from airflow.models.dagbag import DagBag from airflow.models.dagpickle import DagPickle from airflow.models.dagrun import DagRun -from airflow.models.serialized_dag import SerializedDagModel from airflow.models.taskinstance import TaskInstance, clear_task_instances from airflow.settings import MIN_SERIALIZED_DAG_UPDATE_INTERVAL, STORE_SERIALIZED_DAGS from airflow.utils import timezone @@ -55,9 +56,6 @@ from airflow.utils.sqlalchemy import Interval, UtcDateTime from airflow.utils.state import State -if TYPE_CHECKING: - from airflow.models.baseoperator import BaseOperator # Avoid circular dependency - ScheduleInterval = Union[str, timedelta, relativedelta] @@ -186,7 +184,6 @@ class DAG(BaseDag, LoggingMixin): :type jinja_environment_kwargs: dict """ - _comps = { 'dag_id', 'task_ids', @@ -198,7 +195,7 @@ class DAG(BaseDag, LoggingMixin): 'last_loaded', } - _serialized_fields = frozenset() # type: FrozenSet[str] + _serialized_fields: Optional[FrozenSet[str]] = None def __init__( self, @@ -311,7 +308,6 @@ def __init__( self.on_failure_callback = on_failure_callback self.doc_md = doc_md - self._old_context_manager_dags = [] # type: Iterable[DAG] self._access_control = access_control self.is_paused_upon_creation = is_paused_upon_creation @@ -351,14 +347,12 @@ def __hash__(self): return hash(tuple(hash_components)) # Context Manager ----------------------------------------------- - def __enter__(self): - self._old_context_manager_dags.append(settings.CONTEXT_MANAGER_DAG) - settings.CONTEXT_MANAGER_DAG = self + DagContext.push_context_managed_dag(self) return self def __exit__(self, _type, _value, _tb): - settings.CONTEXT_MANAGER_DAG = self._old_context_manager_dags.pop() + DagContext.pop_context_managed_dag() # /Context Manager ---------------------------------------------- @@ -583,10 +577,10 @@ def owner(self): @provide_session def _get_concurrency_reached(self, session=None): - TI = TaskInstance - qry = session.query(func.count(TI.task_id)).filter( - TI.dag_id == self.dag_id, - TI.state == State.RUNNING, + TaskInstance + qry = session.query(func.count(TaskInstance.task_id)).filter( + TaskInstance.dag_id == self.dag_id, + TaskInstance.state == State.RUNNING, ) return qry.scalar() >= self.concurrency @@ -812,12 +806,12 @@ def get_task_instances( return tis @property - def roots(self) -> List["BaseOperator"]: + def roots(self) -> List[BaseOperator]: """Return nodes with no parents. These are first to execute and are called roots or root nodes.""" return [task for task in self.tasks if not task.upstream_list] @property - def leaves(self) -> List["BaseOperator"]: + def leaves(self) -> List[BaseOperator]: """Return nodes with no children. These are last to execute and are called leaves or leaf nodes.""" return [task for task in self.tasks if not task.downstream_list] @@ -913,20 +907,20 @@ def clear( Clears a set of task instances associated with the current dag for a specified date range. """ - TI = TaskInstance - tis = session.query(TI) + TaskInstance + tis = session.query(TaskInstance) if include_subdags: # Crafting the right filter for dag_id and task_ids combo conditions = [] for dag in self.subdags + [self]: conditions.append( - TI.dag_id.like(dag.dag_id) & - TI.task_id.in_(dag.task_ids) + TaskInstance.dag_id.like(dag.dag_id) & + TaskInstance.task_id.in_(dag.task_ids) ) tis = tis.filter(or_(*conditions)) else: - tis = session.query(TI).filter(TI.dag_id == self.dag_id) - tis = tis.filter(TI.task_id.in_(self.task_ids)) + tis = session.query(TaskInstance).filter(TaskInstance.dag_id == self.dag_id) + tis = tis.filter(TaskInstance.task_id.in_(self.task_ids)) if include_parentdag and self.is_subdag: @@ -948,15 +942,15 @@ def clear( )) if start_date: - tis = tis.filter(TI.execution_date >= start_date) + tis = tis.filter(TaskInstance.execution_date >= start_date) if end_date: - tis = tis.filter(TI.execution_date <= end_date) + tis = tis.filter(TaskInstance.execution_date <= end_date) if only_failed: tis = tis.filter(or_( - TI.state == State.FAILED, - TI.state == State.UPSTREAM_FAILED)) + TaskInstance.state == State.FAILED, + TaskInstance.state == State.UPSTREAM_FAILED)) if only_running: - tis = tis.filter(TI.state == State.RUNNING) + tis = tis.filter(TaskInstance.state == State.RUNNING) if get_tis: return tis @@ -1100,8 +1094,8 @@ def sub_dag(self, task_regex, include_downstream=False, for t in dag.tasks: # Removing upstream/downstream references to tasks that did not # made the cut - t._upstream_task_ids = t._upstream_task_ids.intersection(dag.task_dict.keys()) - t._downstream_task_ids = t._downstream_task_ids.intersection( + t._upstream_task_ids = t.upstream_task_ids.intersection(dag.task_dict.keys()) + t._downstream_task_ids = t.downstream_task_ids.intersection( dag.task_dict.keys()) if len(dag.tasks) < len(self.tasks): @@ -1263,7 +1257,7 @@ def run( if not executor and local: executor = LocalExecutor() elif not executor: - executor = get_default_executor() + executor = AllExecutors.get_default_executor() job = BackfillJob( self, start_date=start_date, @@ -1353,6 +1347,7 @@ def sync_to_db(self, owner=None, sync_time=None, session=None): :type sync_time: datetime :return: None """ + from airflow.models.serialized_dag import SerializedDagModel if owner is None: owner = self.owner @@ -1510,6 +1505,18 @@ def _test_cycle_helper(self, visit_map, task_id): visit_map[task_id] = DagBag.CYCLE_DONE + @classmethod + def get_serialized_fields(cls): + """Stringified DAGs and operators contain exactly these fields.""" + if not cls._serialized_fields: + cls._serialized_fields = frozenset(vars(DAG(dag_id='test')).keys()) - { + 'parent_dag', '_old_context_manager_dags', 'safe_dag_id', 'last_loaded', + '_full_filepath', 'user_defined_filters', 'user_defined_macros', + '_schedule_interval', 'partial', '_old_context_manager_dags', + '_pickle_id', '_log', 'is_subdag', 'task_dict' + } + return cls._serialized_fields + class DagModel(Base): @@ -1702,12 +1709,26 @@ def deactivate_deleted_dags(cls, alive_dag_filelocs: List[str], session=None): raise -# Stringified DAGs and operators contain exactly these fields. +class DagContext: -# pylint: disable=protected-access -DAG._serialized_fields = frozenset(vars(DAG(dag_id='test')).keys()) - { - 'parent_dag', '_old_context_manager_dags', 'safe_dag_id', 'last_loaded', - '_full_filepath', 'user_defined_filters', 'user_defined_macros', - '_schedule_interval', 'partial', '_old_context_manager_dags', - '_pickle_id', '_log', 'is_subdag', 'task_dict' -} + _context_managed_dag: Optional[DAG] = None + _previous_context_managed_dags: List[DAG] = [] + + @classmethod + def push_context_managed_dag(cls, dag: DAG): + if cls._context_managed_dag: + cls._previous_context_managed_dags.append(cls._context_managed_dag) + cls._context_managed_dag = dag + + @classmethod + def pop_context_managed_dag(cls) -> Optional[DAG]: + old_dag = cls._context_managed_dag + if len(cls._previous_context_managed_dags): + cls._context_managed_dag = cls._previous_context_managed_dags.pop() + else: + cls._context_managed_dag = None + return old_dag + + @classmethod + def get_current_dag(cls) -> Optional[DAG]: + return cls._context_managed_dag diff --git a/airflow/models/dagbag.py b/airflow/models/dagbag.py index 265b4b941d6297..168549bbb72e72 100644 --- a/airflow/models/dagbag.py +++ b/airflow/models/dagbag.py @@ -33,8 +33,7 @@ from airflow.configuration import conf from airflow.dag.base_dag import BaseDagBag from airflow.exceptions import AirflowDagCycleException -from airflow.executors import get_default_executor -from airflow.models.serialized_dag import SerializedDagModel +from airflow.executors.all_executors import AllExecutors from airflow.stats import Stats from airflow.utils import timezone from airflow.utils.dag_processing import correct_maybe_zipped, list_py_file_paths @@ -89,7 +88,7 @@ def __init__( # do not use default arg in signature, to fix import cycle on plugin load if executor is None: - executor = get_default_executor() + executor = AllExecutors.get_default_executor() dag_folder = dag_folder or settings.DAGS_FOLDER self.dag_folder = dag_folder self.dags = {} @@ -124,7 +123,8 @@ def get_dag(self, dag_id, from_file_only=False): :param from_file_only: returns a DAG loaded from file. :type from_file_only: bool """ - from airflow.models.dag import DagModel # Avoid circular import + # Avoid circular import + from airflow.models.dag import DagModel # Only read DAGs from DB if this dagbag is store_serialized_dags. # from_file_only is an exception, currently it is for renderring templates @@ -133,6 +133,7 @@ def get_dag(self, dag_id, from_file_only=False): # FIXME: this exception should be removed in future, then webserver can be # decoupled from DAG files. if self.store_serialized_dags and not from_file_only: + from airflow.models.serialized_dag import SerializedDagModel if dag_id not in self.dags: # Load from DB if not (yet) in the bag row = SerializedDagModel.get(dag_id) @@ -185,7 +186,7 @@ def process_file(self, filepath, only_if_updated=True, safe_mode=True): Given a path to a python module or zip file, this method imports the module and look for dag objects within it. """ - from airflow.models.dag import DAG # Avoid circular import + from airflow import DAG # Avoid circular import found_dags = [] @@ -444,6 +445,7 @@ def collect_dags( def collect_dags_from_db(self): """Collects DAGs from database.""" + from airflow.models.serialized_dag import SerializedDagModel start_dttm = timezone.utcnow() self.log.info("Filling up the DagBag from database") diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py index 6d199d3a74a1f8..e6779ae01207ea 100644 --- a/airflow/models/dagrun.py +++ b/airflow/models/dagrun.py @@ -25,7 +25,7 @@ from sqlalchemy.orm import synonym from sqlalchemy.orm.session import Session -from airflow.exceptions import AirflowException +from airflow import AirflowException from airflow.models.base import ID_LEN, Base from airflow.stats import Stats from airflow.ti_deps.dep_context import DepContext @@ -208,11 +208,11 @@ def get_task_instance(self, task_id, session=None): """ from airflow.models.taskinstance import TaskInstance # Avoid circular import - TI = TaskInstance - ti = session.query(TI).filter( - TI.dag_id == self.dag_id, - TI.execution_date == self.execution_date, - TI.task_id == task_id + TaskInstance + ti = session.query(TaskInstance).filter( + TaskInstance.dag_id == self.dag_id, + TaskInstance.execution_date == self.execution_date, + TaskInstance.task_id == task_id ).first() return ti diff --git a/airflow/models/serialized_dag.py b/airflow/models/serialized_dag.py index fbc4285cf07e26..81f0a97762fe0d 100644 --- a/airflow/models/serialized_dag.py +++ b/airflow/models/serialized_dag.py @@ -21,23 +21,19 @@ import hashlib from datetime import timedelta -from typing import TYPE_CHECKING, Any, Dict, List, Optional +from typing import Any, Dict, List, Optional import sqlalchemy_jsonfield from sqlalchemy import Column, Index, Integer, String, and_ from sqlalchemy.sql import exists from airflow.models.base import ID_LEN, Base +from airflow.serialization.serialized_dag import SerializedDAG from airflow.settings import json from airflow.utils import db, timezone from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.sqlalchemy import UtcDateTime -if TYPE_CHECKING: - from airflow.models import DAG # noqa: F401; # pylint: disable=cyclic-import - from airflow.serialization import SerializedDAG # noqa: F401 - - log = LoggingMixin().log @@ -59,6 +55,7 @@ class SerializedDagModel(Base): Because reading from database is lightweight compared to importing from files, it solves the webserver scalability issue. """ + from airflow import DAG __tablename__ = 'serialized_dag' dag_id = Column(String(ID_LEN), primary_key=True) @@ -72,9 +69,7 @@ class SerializedDagModel(Base): Index('idx_fileloc_hash', fileloc_hash, unique=False), ) - def __init__(self, dag: 'DAG'): - from airflow.serialization import SerializedDAG # noqa # pylint: disable=redefined-outer-name - + def __init__(self, dag: DAG): self.dag_id = dag.dag_id self.fileloc = dag.full_filepath self.fileloc_hash = self.dag_fileloc_hash(self.fileloc) @@ -96,7 +91,7 @@ def dag_fileloc_hash(full_filepath: str) -> int: @classmethod @db.provide_session - def write_dag(cls, dag: 'DAG', min_update_interval: Optional[int] = None, session=None): + def write_dag(cls, dag: DAG, min_update_interval: Optional[int] = None, session=None): """Serializes a DAG and writes it into database. :param dag: a DAG to be written into database @@ -143,8 +138,6 @@ def read_all_dags(cls, session=None) -> Dict[str, 'SerializedDAG']: @property def dag(self): """The DAG deserialized from the ``data`` column""" - from airflow.serialization import SerializedDAG # noqa # pylint: disable=redefined-outer-name - if isinstance(self.data, dict): dag = SerializedDAG.from_dict(self.data) # type: Any else: diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index d97b2e710c2b24..7c212b84852560 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -213,7 +213,7 @@ def try_number(self): Return the try number that this task number will be when it is actually run. - If the TI is currently running, this will match the column in the + If the TaskInstance is currently running, this will match the column in the database, in all other cases this will be incremented. """ # This is designed so that task logs end up in the right file. @@ -395,11 +395,11 @@ def current_state(self, session=None): we use and looking up the state becomes part of the session, otherwise a new session is used. """ - TI = TaskInstance - ti = session.query(TI).filter( - TI.dag_id == self.dag_id, - TI.task_id == self.task_id, - TI.execution_date == self.execution_date, + TaskInstance + ti = session.query(TaskInstance).filter( + TaskInstance.dag_id == self.dag_id, + TaskInstance.task_id == self.task_id, + TaskInstance.execution_date == self.execution_date, ).all() if ti: state = ti[0].state @@ -429,12 +429,12 @@ def refresh_from_db(self, session=None, lock_for_update=False, refresh_executor_ lock the TaskInstance (issuing a FOR UPDATE clause) until the session is committed. """ - TI = TaskInstance + TaskInstance - qry = session.query(TI).filter( - TI.dag_id == self.dag_id, - TI.task_id == self.task_id, - TI.execution_date == self.execution_date) + qry = session.query(TaskInstance).filter( + TaskInstance.dag_id == self.dag_id, + TaskInstance.task_id == self.task_id, + TaskInstance.execution_date == self.execution_date) if lock_for_update: ti = qry.with_for_update().first() @@ -528,7 +528,7 @@ def _get_previous_ti( # LEGACY: most likely running from unit tests if not dr: - # Means that this TI is NOT being run from a DR, but from a catchup + # Means that this TaskInstance is NOT being run from a DR, but from a catchup previous_scheduled_date = dag.previous_schedule(self.execution_date) if not previous_scheduled_date: return None @@ -742,7 +742,7 @@ def _check_and_change_state_before_execution( :type ignore_all_deps: bool :param ignore_depends_on_past: Ignore depends_on_past DAG attribute :type ignore_depends_on_past: bool - :param ignore_task_deps: Don't check the dependencies of this TI's task + :param ignore_task_deps: Don't check the dependencies of this TaskInstance's task :type ignore_task_deps: bool :param ignore_ti_state: Disregards previous task instance state :type ignore_ti_state: bool @@ -1056,7 +1056,7 @@ def handle_failure(self, error, test_mode=False, context=None, session=None): # Let's go deeper try: - # Since this function is called only when the TI state is running, + # Since this function is called only when the TaskInstance state is running, # try_number contains the current try_number (not the next). We # only mark task instance as FAILED if the next task instance # try_number exceeds the max_tries. @@ -1379,12 +1379,12 @@ def xcom_pull( @provide_session def get_num_running_task_instances(self, session): - TI = TaskInstance + TaskInstance # .count() is inefficient return session.query(func.count()).filter( - TI.dag_id == self.dag_id, - TI.task_id == self.task_id, - TI.state == State.RUNNING + TaskInstance.dag_id == self.dag_id, + TaskInstance.task_id == self.task_id, + TaskInstance.state == State.RUNNING ).scalar() def init_run_context(self, raw=False): diff --git a/airflow/operators/bash_operator.py b/airflow/operators/bash_operator.py index 9a0781f76163fb..9fc01354492a0d 100644 --- a/airflow/operators/bash_operator.py +++ b/airflow/operators/bash_operator.py @@ -24,7 +24,7 @@ from tempfile import gettempdir from typing import Dict, Optional -from airflow.exceptions import AirflowException +from airflow import AirflowException from airflow.models import BaseOperator from airflow.utils.decorators import apply_defaults from airflow.utils.file import TemporaryDirectory diff --git a/airflow/operators/bigquery_to_bigquery.py b/airflow/operators/bigquery_to_bigquery.py index 3595926ada90d2..19057668b3de7c 100644 --- a/airflow/operators/bigquery_to_bigquery.py +++ b/airflow/operators/bigquery_to_bigquery.py @@ -23,7 +23,7 @@ from typing import Dict, List, Optional, Union from airflow.gcp.hooks.bigquery import BigQueryHook -from airflow.models.baseoperator import BaseOperator +from airflow.models import BaseOperator from airflow.utils.decorators import apply_defaults diff --git a/airflow/operators/bigquery_to_gcs.py b/airflow/operators/bigquery_to_gcs.py index e4293d7f2625da..c0f8e8c469d214 100644 --- a/airflow/operators/bigquery_to_gcs.py +++ b/airflow/operators/bigquery_to_gcs.py @@ -23,7 +23,7 @@ from typing import Dict, List, Optional from airflow.gcp.hooks.bigquery import BigQueryHook -from airflow.models.baseoperator import BaseOperator +from airflow.models import BaseOperator from airflow.utils.decorators import apply_defaults diff --git a/airflow/operators/bigquery_to_mysql.py b/airflow/operators/bigquery_to_mysql.py index 967830b2a03311..d13ade37716a33 100644 --- a/airflow/operators/bigquery_to_mysql.py +++ b/airflow/operators/bigquery_to_mysql.py @@ -23,7 +23,7 @@ from airflow.gcp.hooks.bigquery import BigQueryHook from airflow.hooks.mysql_hook import MySqlHook -from airflow.models.baseoperator import BaseOperator +from airflow.models import BaseOperator from airflow.utils.decorators import apply_defaults diff --git a/airflow/operators/cassandra_to_gcs.py b/airflow/operators/cassandra_to_gcs.py index 40b197363d2d41..ada549a52c5333 100644 --- a/airflow/operators/cassandra_to_gcs.py +++ b/airflow/operators/cassandra_to_gcs.py @@ -31,7 +31,7 @@ from cassandra.util import Date, OrderedMapSerializedKey, SortedSet, Time -from airflow.exceptions import AirflowException +from airflow import AirflowException from airflow.gcp.hooks.gcs import GoogleCloudStorageHook from airflow.models import BaseOperator from airflow.providers.apache.cassandra.hooks.cassandra import CassandraHook diff --git a/airflow/operators/check_operator.py b/airflow/operators/check_operator.py index 0da9480c07d33b..71bae213ad08ec 100644 --- a/airflow/operators/check_operator.py +++ b/airflow/operators/check_operator.py @@ -19,7 +19,7 @@ from typing import Any, Dict, Iterable, Optional, SupportsAbs -from airflow.exceptions import AirflowException +from airflow import AirflowException from airflow.hooks.base_hook import BaseHook from airflow.models import BaseOperator from airflow.utils.decorators import apply_defaults diff --git a/airflow/operators/docker_operator.py b/airflow/operators/docker_operator.py index 7e530385c3fe77..96577ecd02495a 100644 --- a/airflow/operators/docker_operator.py +++ b/airflow/operators/docker_operator.py @@ -25,7 +25,7 @@ from docker import APIClient, tls -from airflow.exceptions import AirflowException +from airflow import AirflowException from airflow.hooks.docker_hook import DockerHook from airflow.models import BaseOperator from airflow.utils.decorators import apply_defaults diff --git a/airflow/operators/druid_check_operator.py b/airflow/operators/druid_check_operator.py index 92ad337d45d4a3..abc12701765558 100644 --- a/airflow/operators/druid_check_operator.py +++ b/airflow/operators/druid_check_operator.py @@ -17,7 +17,7 @@ # specific language governing permissions and limitations # under the License. -from airflow.exceptions import AirflowException +from airflow import AirflowException from airflow.hooks.druid_hook import DruidDbApiHook from airflow.operators.check_operator import CheckOperator from airflow.utils.decorators import apply_defaults diff --git a/airflow/operators/gcs_to_gcs.py b/airflow/operators/gcs_to_gcs.py index fcd4e60f49a8ef..0dc09269c70b33 100644 --- a/airflow/operators/gcs_to_gcs.py +++ b/airflow/operators/gcs_to_gcs.py @@ -22,7 +22,7 @@ import warnings from typing import Optional -from airflow.exceptions import AirflowException +from airflow import AirflowException from airflow.gcp.hooks.gcs import GoogleCloudStorageHook from airflow.models import BaseOperator from airflow.utils.decorators import apply_defaults diff --git a/airflow/operators/hive_stats_operator.py b/airflow/operators/hive_stats_operator.py index 62528d5dd796f5..0bb63b9e531b99 100644 --- a/airflow/operators/hive_stats_operator.py +++ b/airflow/operators/hive_stats_operator.py @@ -21,7 +21,7 @@ from collections import OrderedDict from typing import Callable, Dict, List, Optional -from airflow.exceptions import AirflowException +from airflow import AirflowException from airflow.hooks.hive_hooks import HiveMetastoreHook from airflow.hooks.mysql_hook import MySqlHook from airflow.hooks.presto_hook import PrestoHook diff --git a/airflow/operators/http_operator.py b/airflow/operators/http_operator.py index 150b3b7d9330c2..dd3a49482179d4 100644 --- a/airflow/operators/http_operator.py +++ b/airflow/operators/http_operator.py @@ -18,7 +18,7 @@ # under the License. from typing import Any, Callable, Dict, Optional -from airflow.exceptions import AirflowException +from airflow import AirflowException from airflow.hooks.http_hook import HttpHook from airflow.models import BaseOperator from airflow.utils.decorators import apply_defaults diff --git a/airflow/operators/python_operator.py b/airflow/operators/python_operator.py index 8efbce80d751ec..6380f2a4cda822 100644 --- a/airflow/operators/python_operator.py +++ b/airflow/operators/python_operator.py @@ -30,7 +30,7 @@ import dill -from airflow.exceptions import AirflowException +from airflow import AirflowException from airflow.models import BaseOperator, SkipMixin from airflow.utils.decorators import apply_defaults from airflow.utils.file import TemporaryDirectory diff --git a/airflow/operators/s3_file_transform_operator.py b/airflow/operators/s3_file_transform_operator.py index ccaa2eac216af9..14d5c25b58cdc5 100644 --- a/airflow/operators/s3_file_transform_operator.py +++ b/airflow/operators/s3_file_transform_operator.py @@ -22,7 +22,7 @@ from tempfile import NamedTemporaryFile from typing import Optional, Union -from airflow.exceptions import AirflowException +from airflow import AirflowException from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.s3 import S3Hook from airflow.utils.decorators import apply_defaults diff --git a/airflow/operators/s3_to_hive_operator.py b/airflow/operators/s3_to_hive_operator.py index 52ef128d20b08d..0a51fd4ada0656 100644 --- a/airflow/operators/s3_to_hive_operator.py +++ b/airflow/operators/s3_to_hive_operator.py @@ -24,7 +24,7 @@ from tempfile import NamedTemporaryFile from typing import Dict, Optional, Union -from airflow.exceptions import AirflowException +from airflow import AirflowException from airflow.hooks.hive_hooks import HiveCliHook from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.s3 import S3Hook diff --git a/airflow/operators/slack_operator.py b/airflow/operators/slack_operator.py index f9df8537f1501e..b2e642cc0b100b 100644 --- a/airflow/operators/slack_operator.py +++ b/airflow/operators/slack_operator.py @@ -20,7 +20,7 @@ import json from typing import Dict, List, Optional -from airflow.exceptions import AirflowException +from airflow import AirflowException from airflow.hooks.slack_hook import SlackHook from airflow.models import BaseOperator from airflow.utils.decorators import apply_defaults diff --git a/airflow/operators/subdag_operator.py b/airflow/operators/subdag_operator.py index e2f120632689a1..a84003c5dccbb1 100644 --- a/airflow/operators/subdag_operator.py +++ b/airflow/operators/subdag_operator.py @@ -24,11 +24,11 @@ from sqlalchemy.orm.session import Session -from airflow import settings +from airflow import DAG, AirflowException from airflow.api.common.experimental.get_task_instance import get_task_instance -from airflow.exceptions import AirflowException, TaskInstanceNotFound +from airflow.exceptions import TaskInstanceNotFound from airflow.models import DagRun -from airflow.models.dag import DAG +from airflow.models.dag import DagContext from airflow.models.pool import Pool from airflow.models.taskinstance import TaskInstance from airflow.sensors.base_sensor_operator import BaseSensorOperator @@ -78,7 +78,7 @@ def __init__(self, self._validate_pool(session) def _validate_dag(self, kwargs): - dag = kwargs.get('dag') or settings.CONTEXT_MANAGER_DAG + dag = kwargs.get('dag') or DagContext.get_current_dag() if not dag: raise AirflowException('Please pass in the `dag` param or call within a DAG context manager') diff --git a/airflow/plugins_manager.py b/airflow/plugins_manager.py index 6dfa9dbe7b6c99..f5d3fc826099ef 100644 --- a/airflow/plugins_manager.py +++ b/airflow/plugins_manager.py @@ -130,6 +130,12 @@ def is_valid_plugin(plugin_obj, existing_plugins): return False +def converts_camel_case_to_snake_case(name: str) -> str: + """Converts SomeCase name to some_case.""" + s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name) + return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower() + + plugins = [] # type: List[AirflowPlugin] norm_pattern = re.compile(r'[/|.]') @@ -205,7 +211,9 @@ def make_module(name, objects): ) hooks_modules.append(make_module('airflow.hooks.' + p.name, p.hooks)) executors_modules.append( - make_module('airflow.executors.' + p.name, p.executors)) + make_module('airflow.executors.' + + converts_camel_case_to_snake_case(p.name) + "." + + p.name, p.executors)) macros_modules.append(make_module('airflow.macros.' + p.name, p.macros)) admin_views.extend(p.admin_views) diff --git a/airflow/providers/amazon/aws/hooks/datasync.py b/airflow/providers/amazon/aws/hooks/datasync.py index 76c3eea53dfd32..284bedbc08e618 100644 --- a/airflow/providers/amazon/aws/hooks/datasync.py +++ b/airflow/providers/amazon/aws/hooks/datasync.py @@ -23,8 +23,9 @@ import time +from airflow import AirflowException from airflow.contrib.hooks.aws_hook import AwsHook -from airflow.exceptions import AirflowBadRequest, AirflowException, AirflowTaskTimeout +from airflow.exceptions import AirflowBadRequest, AirflowTaskTimeout class AWSDataSyncHook(AwsHook): diff --git a/airflow/providers/amazon/aws/hooks/s3.py b/airflow/providers/amazon/aws/hooks/s3.py index 7a1fc2daf457e4..b61c7026e49291 100644 --- a/airflow/providers/amazon/aws/hooks/s3.py +++ b/airflow/providers/amazon/aws/hooks/s3.py @@ -29,8 +29,8 @@ from botocore.exceptions import ClientError +from airflow import AirflowException from airflow.contrib.hooks.aws_hook import AwsHook -from airflow.exceptions import AirflowException def provide_bucket_name(func): diff --git a/airflow/providers/amazon/aws/operators/datasync.py b/airflow/providers/amazon/aws/operators/datasync.py index 70ecdb09638787..2018617b08382a 100644 --- a/airflow/providers/amazon/aws/operators/datasync.py +++ b/airflow/providers/amazon/aws/operators/datasync.py @@ -21,7 +21,7 @@ Get, Create, Update, Delete and execute an AWS DataSync Task. """ -from airflow.exceptions import AirflowException +from airflow import AirflowException from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.datasync import AWSDataSyncHook from airflow.utils.decorators import apply_defaults diff --git a/airflow/providers/amazon/aws/sensors/athena.py b/airflow/providers/amazon/aws/sensors/athena.py index d900fc5076bafd..49e707847b8e6a 100644 --- a/airflow/providers/amazon/aws/sensors/athena.py +++ b/airflow/providers/amazon/aws/sensors/athena.py @@ -17,7 +17,7 @@ # specific language governing permissions and limitations # under the License. -from airflow.exceptions import AirflowException +from airflow import AirflowException from airflow.providers.amazon.aws.hooks.athena import AWSAthenaHook from airflow.sensors.base_sensor_operator import BaseSensorOperator from airflow.utils.decorators import apply_defaults diff --git a/airflow/providers/amazon/aws/sensors/sqs.py b/airflow/providers/amazon/aws/sensors/sqs.py index 120a05d9445b52..b740cad4007307 100644 --- a/airflow/providers/amazon/aws/sensors/sqs.py +++ b/airflow/providers/amazon/aws/sensors/sqs.py @@ -20,7 +20,7 @@ Reads and then deletes the message from SQS queue """ -from airflow.exceptions import AirflowException +from airflow import AirflowException from airflow.providers.amazon.aws.hooks.sqs import SQSHook from airflow.sensors.base_sensor_operator import BaseSensorOperator from airflow.utils.decorators import apply_defaults diff --git a/airflow/providers/google/cloud/example_dags/example_dataproc.py b/airflow/providers/google/cloud/example_dags/example_dataproc.py index baf0ebe37db848..cf9a67ab712eae 100644 --- a/airflow/providers/google/cloud/example_dags/example_dataproc.py +++ b/airflow/providers/google/cloud/example_dags/example_dataproc.py @@ -120,7 +120,7 @@ }, } -with models.DAG( +with models.dag.DAG( "example_gcp_dataproc", default_args={"start_date": airflow.utils.dates.days_ago(1)}, schedule_interval=None, diff --git a/airflow/providers/google/cloud/example_dags/example_sftp_to_gcs.py b/airflow/providers/google/cloud/example_dags/example_sftp_to_gcs.py index 08d621d2b2814d..04078ff6ffee37 100644 --- a/airflow/providers/google/cloud/example_dags/example_sftp_to_gcs.py +++ b/airflow/providers/google/cloud/example_dags/example_sftp_to_gcs.py @@ -39,7 +39,7 @@ OBJECT_SRC_3 = "parent-3.txt" -with models.DAG( +with models.dag.DAG( "example_sftp_to_gcs", default_args=default_args, schedule_interval=None ) as dag: # [START howto_operator_sftp_to_gcs_copy_single_file] diff --git a/airflow/providers/google/cloud/operators/dataproc.py b/airflow/providers/google/cloud/operators/dataproc.py index 84d45736810dc8..f8642d7b107833 100644 --- a/airflow/providers/google/cloud/operators/dataproc.py +++ b/airflow/providers/google/cloud/operators/dataproc.py @@ -39,7 +39,7 @@ ) from google.protobuf.json_format import MessageToDict -from airflow.exceptions import AirflowException +from airflow import AirflowException from airflow.gcp.hooks.gcs import GoogleCloudStorageHook from airflow.models import BaseOperator from airflow.providers.google.cloud.hooks.dataproc import DataprocHook, DataProcJobBuilder diff --git a/airflow/providers/google/marketing_platform/operators/campaign_manager.py b/airflow/providers/google/marketing_platform/operators/campaign_manager.py index e27b1c54bdcd21..82949c7c04d4e5 100644 --- a/airflow/providers/google/marketing_platform/operators/campaign_manager.py +++ b/airflow/providers/google/marketing_platform/operators/campaign_manager.py @@ -27,7 +27,7 @@ from airflow import AirflowException from airflow.gcp.hooks.gcs import GoogleCloudStorageHook -from airflow.models.baseoperator import BaseOperator +from airflow.models import BaseOperator from airflow.providers.google.marketing_platform.hooks.campaign_manager import GoogleCampaignManagerHook from airflow.utils.decorators import apply_defaults diff --git a/airflow/providers/google/marketing_platform/operators/display_video.py b/airflow/providers/google/marketing_platform/operators/display_video.py index 3bc3dbb0754ce8..2ffd401cc70029 100644 --- a/airflow/providers/google/marketing_platform/operators/display_video.py +++ b/airflow/providers/google/marketing_platform/operators/display_video.py @@ -27,7 +27,7 @@ from airflow import AirflowException from airflow.gcp.hooks.gcs import GoogleCloudStorageHook -from airflow.models.baseoperator import BaseOperator +from airflow.models import BaseOperator from airflow.providers.google.marketing_platform.hooks.display_video import GoogleDisplayVideo360Hook from airflow.utils.decorators import apply_defaults diff --git a/airflow/providers/google/marketing_platform/operators/search_ads.py b/airflow/providers/google/marketing_platform/operators/search_ads.py index dcd15f4efeddb4..54a0cd1eeb8ecb 100644 --- a/airflow/providers/google/marketing_platform/operators/search_ads.py +++ b/airflow/providers/google/marketing_platform/operators/search_ads.py @@ -24,7 +24,7 @@ from airflow import AirflowException from airflow.gcp.hooks.gcs import GoogleCloudStorageHook -from airflow.models.baseoperator import BaseOperator +from airflow.models import BaseOperator from airflow.providers.google.marketing_platform.hooks.search_ads import GoogleSearchAdsHook from airflow.utils.decorators import apply_defaults diff --git a/airflow/sensors/external_task_sensor.py b/airflow/sensors/external_task_sensor.py index 90360fb742ccdc..089fe49253cbe7 100644 --- a/airflow/sensors/external_task_sensor.py +++ b/airflow/sensors/external_task_sensor.py @@ -21,7 +21,7 @@ from sqlalchemy import func -from airflow.exceptions import AirflowException +from airflow import AirflowException from airflow.models import DagBag, DagModel, DagRun, TaskInstance from airflow.sensors.base_sensor_operator import BaseSensorOperator from airflow.utils.db import provide_session @@ -118,7 +118,6 @@ def poke(self, context, session=None): ) DM = DagModel - TI = TaskInstance DR = DagRun # we only do the check for 1st time, no need for subsequent poke @@ -146,10 +145,10 @@ def poke(self, context, session=None): if self.external_task_id: # .count() is inefficient count = session.query(func.count()).filter( - TI.dag_id == self.external_dag_id, - TI.task_id == self.external_task_id, - TI.state.in_(self.allowed_states), - TI.execution_date.in_(dttm_filter), + TaskInstance.dag_id == self.external_dag_id, + TaskInstance.task_id == self.external_task_id, + TaskInstance.state.in_(self.allowed_states), + TaskInstance.execution_date.in_(dttm_filter), ).scalar() else: # .count() is inefficient diff --git a/airflow/sensors/http_sensor.py b/airflow/sensors/http_sensor.py index ee036095ca090c..1260f758c7d14c 100644 --- a/airflow/sensors/http_sensor.py +++ b/airflow/sensors/http_sensor.py @@ -18,7 +18,7 @@ # under the License. from typing import Callable, Dict, Optional -from airflow.exceptions import AirflowException +from airflow import AirflowException from airflow.hooks.http_hook import HttpHook from airflow.operators.python_operator import PythonOperator from airflow.sensors.base_sensor_operator import BaseSensorOperator diff --git a/airflow/sensors/s3_key_sensor.py b/airflow/sensors/s3_key_sensor.py index bc24c7fbbed800..3882f0eddece66 100644 --- a/airflow/sensors/s3_key_sensor.py +++ b/airflow/sensors/s3_key_sensor.py @@ -20,7 +20,7 @@ from urllib.parse import urlparse -from airflow.exceptions import AirflowException +from airflow import AirflowException from airflow.sensors.base_sensor_operator import BaseSensorOperator from airflow.utils.decorators import apply_defaults diff --git a/airflow/sensors/sql_sensor.py b/airflow/sensors/sql_sensor.py index 8941bd8e98499d..ad2221973176c4 100644 --- a/airflow/sensors/sql_sensor.py +++ b/airflow/sensors/sql_sensor.py @@ -19,7 +19,7 @@ from typing import Iterable -from airflow.exceptions import AirflowException +from airflow import AirflowException from airflow.hooks.base_hook import BaseHook from airflow.sensors.base_sensor_operator import BaseSensorOperator from airflow.utils.decorators import apply_defaults diff --git a/airflow/serialization/__init__.py b/airflow/serialization/__init__.py deleted file mode 100644 index 3d378c1260d8b9..00000000000000 --- a/airflow/serialization/__init__.py +++ /dev/null @@ -1,26 +0,0 @@ -# -*- coding: utf-8 -*- -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""DAG serialization.""" -# pylint: disable=cyclic-import -from airflow.serialization.serialized_baseoperator import SerializedBaseOperator -# pylint: disable=cyclic-import -from airflow.serialization.serialized_dag import SerializedDAG - -__ALL__ = [SerializedBaseOperator, SerializedDAG] diff --git a/airflow/serialization/serialization.py b/airflow/serialization/base_serialization.py similarity index 82% rename from airflow/serialization/serialization.py rename to airflow/serialization/base_serialization.py index 4c971db12aacba..a36807f75b0b5d 100644 --- a/airflow/serialization/serialization.py +++ b/airflow/serialization/base_serialization.py @@ -22,17 +22,17 @@ import datetime import enum import logging -from typing import TYPE_CHECKING, Dict, Optional, Union +from inspect import Parameter +from typing import Dict, Optional, Set, Union import pendulum from dateutil import relativedelta -import airflow -from airflow.exceptions import AirflowException +from airflow import DAG, AirflowException from airflow.models.baseoperator import BaseOperator from airflow.models.connection import Connection -from airflow.models.dag import DAG from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding +from airflow.serialization.json_schema import Validator from airflow.settings import json from airflow.utils.log.logging_mixin import LoggingMixin from airflow.www.utils import get_python_source @@ -43,8 +43,8 @@ FAILED = 'serialization_failed' -class Serialization: - """Serialization provides utils for serialization.""" +class BaseSerialization: + """BaseSerialization provides utils for serialization.""" # JSON primitive types. _primitive_types = (int, bool, float, str) @@ -78,14 +78,12 @@ def to_dict(cls, var: Union[DAG, BaseOperator, dict, list, set, tuple]) -> dict: raise NotImplementedError() @classmethod - def from_json(cls, serialized_obj: str) -> Union[ - 'SerializedDAG', 'SerializedBaseOperator', dict, list, set, tuple]: + def from_json(cls, serialized_obj: str) -> Union['BaseSerialization', dict, list, set, tuple]: """Deserializes json_str and reconstructs all DAGs and operators it contains.""" return cls.from_dict(json.loads(serialized_obj)) @classmethod - def from_dict(cls, serialized_obj: dict) -> Union[ - 'SerializedDAG', 'SerializedBaseOperator', dict, list, set, tuple]: + def from_dict(cls, serialized_obj: dict) -> Union['BaseSerialization', dict, list, set, tuple]: """Deserializes a python dict stored with type decorators and reconstructs all DAGs and operators it contains.""" return cls._deserialize(serialized_obj) @@ -123,6 +121,27 @@ def _is_excluded(cls, var, attrname, instance): cls._value_is_hardcoded_default(attrname, var) ) + @classmethod + def serialize_to_json(cls, object_to_serialize: Union[BaseOperator, DAG], decorated_fields: Set): + """Serializes an object to json""" + serialized_object = {} + keys_to_serialize = object_to_serialize.get_serialized_fields() + for key in keys_to_serialize: + # None is ignored in serialized form and is added back in deserialization. + value = getattr(object_to_serialize, key, None) + if cls._is_excluded(value, key, object_to_serialize): + continue + + if key in decorated_fields: + serialized_object[key] = cls._serialize(value) + else: + value = cls._serialize(value) + # TODO: Why? + if isinstance(value, dict) and "__type" in value: + value = value["__var"] + serialized_object[key] = value + return serialized_object + @classmethod def _serialize(cls, var): # pylint: disable=too-many-return-statements """Helper function of depth first search for serialization. @@ -135,6 +154,8 @@ def _serialize(cls, var): # pylint: disable=too-many-return-statements (3) Operator has a special field CLASS to record the original class name for displaying in UI. """ + from airflow.serialization.serialized_dag import SerializedDAG + from airflow.serialization.serialized_baseoperator import SerializedBaseOperator try: if cls._is_primitive(var): # enum.IntEnum is an int instance, it causes json dumps error so we use its value. @@ -149,9 +170,9 @@ def _serialize(cls, var): # pylint: disable=too-many-return-statements elif isinstance(var, list): return [cls._serialize(v) for v in var] elif isinstance(var, DAG): - return airflow.serialization.SerializedDAG.serialize_dag(var) + return SerializedDAG.serialize_dag(var) elif isinstance(var, BaseOperator): - return airflow.serialization.SerializedBaseOperator.serialize_operator(var) + return SerializedBaseOperator.serialize_operator(var) elif isinstance(var, cls._datetime_types): return cls._encode(var.timestamp(), type_=DAT.DATETIME) elif isinstance(var, datetime.timedelta): @@ -186,6 +207,8 @@ def _serialize(cls, var): # pylint: disable=too-many-return-statements @classmethod def _deserialize(cls, encoded_var): # pylint: disable=too-many-return-statements """Helper function of depth first search for deserialization.""" + from airflow.serialization.serialized_dag import SerializedDAG + from airflow.serialization.serialized_baseoperator import SerializedBaseOperator # JSON primitives (except for dict) are not encoded. if cls._is_primitive(encoded_var): return encoded_var @@ -199,9 +222,9 @@ def _deserialize(cls, encoded_var): # pylint: disable=too-many-return-statement if type_ == DAT.DICT: return {k: cls._deserialize(v) for k, v in var.items()} elif type_ == DAT.DAG: - return airflow.serialization.SerializedDAG.deserialize_dag(var) + return SerializedDAG.deserialize_dag(var) elif type_ == DAT.OP: - return airflow.serialization.SerializedBaseOperator.deserialize_operator(var) + return SerializedBaseOperator.deserialize_operator(var) elif type_ == DAT.DATETIME: return pendulum.from_timestamp(var) elif type_ == DAT.TIMEDELTA: @@ -242,10 +265,3 @@ def _value_is_hardcoded_default(cls, attrname, value): if attrname in cls._CONSTRUCTOR_PARAMS and cls._CONSTRUCTOR_PARAMS[attrname].default is value: return True return False - - -if TYPE_CHECKING: - from airflow.serialization.json_schema import Validator - from airflow.serialization.serialized_baseoperator import SerializedBaseOperator # noqa: F401, E501; # pylint: disable=cyclic-import - from airflow.serialization.serialized_dag import SerializedDAG # noqa: F401, E501; # pylint: disable=cyclic-import - from inspect import Parameter diff --git a/airflow/serialization/json_schema.py b/airflow/serialization/json_schema.py index 7a09b5b8161754..07b81b4c6d81d1 100644 --- a/airflow/serialization/json_schema.py +++ b/airflow/serialization/json_schema.py @@ -25,7 +25,7 @@ import jsonschema from typing_extensions import Protocol -from airflow.exceptions import AirflowException +from airflow import AirflowException from airflow.settings import json diff --git a/airflow/serialization/serialized_baseoperator.py b/airflow/serialization/serialized_baseoperator.py index c44989e932a73d..a0464483382e97 100644 --- a/airflow/serialization/serialized_baseoperator.py +++ b/airflow/serialization/serialized_baseoperator.py @@ -20,11 +20,11 @@ """Operator serialization with JSON.""" from inspect import signature -from airflow.models import BaseOperator -from airflow.serialization.serialization import Serialization # pylint: disable=cyclic-import +from airflow.models.baseoperator import BaseOperator +from airflow.serialization.base_serialization import BaseSerialization -class SerializedBaseOperator(BaseOperator, Serialization): +class SerializedBaseOperator(BaseOperator, BaseSerialization): """A JSON serializable representation of operator. All operators are casted to SerializedBaseOperator after deserialization. @@ -66,34 +66,16 @@ def task_type(self, task_type: str): def serialize_operator(cls, op: BaseOperator) -> dict: """Serializes operator into a JSON object. """ - serialize_op = {} - - # pylint: disable=protected-access - for k in op._serialized_fields: - # None is ignored in serialized form and is added back in deserialization. - v = getattr(op, k, None) - if cls._is_excluded(v, k, op): - continue - - if k in cls._decorated_fields: - serialize_op[k] = cls._serialize(v) - else: - v = cls._serialize(v) - if isinstance(v, dict) and "__type" in v: - v = v["__var"] - serialize_op[k] = v - - # Adds a new task_type field to record the original operator class. + serialize_op = cls.serialize_to_json(op, cls._decorated_fields) serialize_op['_task_type'] = op.__class__.__name__ serialize_op['_task_module'] = op.__class__.__module__ - return serialize_op @classmethod def deserialize_operator(cls, encoded_op: dict) -> BaseOperator: """Deserializes an operator from a JSON object. """ - from airflow.serialization import SerializedDAG + from airflow.serialization.serialized_dag import SerializedDAG from airflow.plugins_manager import operator_extra_links op = SerializedBaseOperator(task_id=encoded_op['task_id']) @@ -119,14 +101,13 @@ def deserialize_operator(cls, encoded_op: dict) -> BaseOperator: v = cls._deserialize_timedelta(v) elif k.endswith("_date"): v = cls._deserialize_datetime(v) - elif k in cls._decorated_fields or k not in op._serialized_fields: # noqa: E501; # pylint: disable=protected-access + elif k in cls._decorated_fields or k not in op.get_serialized_fields(): v = cls._deserialize(v) # else use v as it is setattr(op, k, v) - # pylint: disable=protected-access - for k in op._serialized_fields - encoded_op.keys() - cls._CONSTRUCTOR_PARAMS.keys(): + for k in op.get_serialized_fields() - encoded_op.keys() - cls._CONSTRUCTOR_PARAMS.keys(): setattr(op, k, None) return op diff --git a/airflow/serialization/serialized_dag.py b/airflow/serialization/serialized_dag.py index 7845a9db0899b3..45daf778e121ca 100644 --- a/airflow/serialization/serialized_dag.py +++ b/airflow/serialization/serialized_dag.py @@ -21,12 +21,12 @@ from inspect import signature from typing import cast -from airflow.models import DAG +from airflow import DAG +from airflow.serialization.base_serialization import BaseSerialization from airflow.serialization.json_schema import load_dag_schema -from airflow.serialization.serialization import Serialization # pylint: disable=cyclic-import -class SerializedDAG(DAG, Serialization): +class SerializedDAG(DAG, BaseSerialization): """ A JSON serializable representation of DAG. @@ -62,23 +62,7 @@ def __get_constructor_defaults(): # pylint: disable=no-method-argument def serialize_dag(cls, dag: DAG) -> dict: """Serializes a DAG into a JSON object. """ - serialize_dag = {} - - # pylint: disable=protected-access - for k in dag._serialized_fields: - # None is ignored in serialized form and is added back in deserialization. - v = getattr(dag, k, None) - if cls._is_excluded(v, k, dag): - continue - - if k in cls._decorated_fields: - serialize_dag[k] = cls._serialize(v) - else: - v = cls._serialize(v) - # TODO: Why? - if isinstance(v, dict) and "__type" in v: - v = v["__var"] - serialize_dag[k] = v + serialize_dag = cls.serialize_to_json(dag, cls._decorated_fields) serialize_dag["tasks"] = [cls._serialize(task) for _, task in dag.task_dict.items()] return serialize_dag @@ -87,7 +71,7 @@ def serialize_dag(cls, dag: DAG) -> dict: def deserialize_dag(cls, encoded_dag: dict) -> "SerializedDAG": """Deserializes a DAG from a JSON object. """ - from airflow.serialization import SerializedBaseOperator + from airflow.serialization.serialized_baseoperator import SerializedBaseOperator dag = SerializedDAG(dag_id=encoded_dag['_dag_id']) @@ -111,26 +95,26 @@ def deserialize_dag(cls, encoded_dag: dict) -> "SerializedDAG": setattr(dag, k, v) - # pylint: disable=protected-access - keys_to_set_none = dag._serialized_fields - encoded_dag.keys() - cls._CONSTRUCTOR_PARAMS.keys() + keys_to_set_none = dag.get_serialized_fields() - encoded_dag.keys() - cls._CONSTRUCTOR_PARAMS.keys() for k in keys_to_set_none: setattr(dag, k, None) setattr(dag, 'full_filepath', dag.fileloc) for task in dag.task_dict.values(): task.dag = dag - task = cast(SerializedBaseOperator, task) + serializable_task: SerializedBaseOperator = cast(SerializedBaseOperator, task) for date_attr in ["start_date", "end_date"]: - if getattr(task, date_attr) is None: - setattr(task, date_attr, getattr(dag, date_attr)) + if getattr(serializable_task, date_attr) is None: + setattr(serializable_task, date_attr, getattr(dag, date_attr)) - if task.subdag is not None: - setattr(task.subdag, 'parent_dag', dag) - task.subdag.is_subdag = True + if serializable_task.subdag is not None: + setattr(serializable_task.subdag, 'parent_dag', dag) + serializable_task.subdag.is_subdag = True - for task_id in task.downstream_task_ids: + for task_id in serializable_task.downstream_task_ids: # Bypass set_upstream etc here - it does more than we want + # noinspection PyProtectedMember dag.task_dict[task_id]._upstream_task_ids.add(task_id) # pylint: disable=protected-access return dag diff --git a/airflow/settings.py b/airflow/settings.py index bfae1464ddbfa9..c97a83d344fe52 100644 --- a/airflow/settings.py +++ b/airflow/settings.py @@ -31,7 +31,7 @@ from sqlalchemy.orm.session import Session as SASession from sqlalchemy.pool import NullPool -import airflow +# noinspection PyUnresolvedReferences from airflow.configuration import AIRFLOW_HOME, WEBSERVER_CONFIG, conf # NOQA F401 from airflow.logging_config import configure_logging from airflow.utils.module_loading import import_string @@ -322,9 +322,6 @@ def initialize(): WEB_COLORS = {'LIGHTBLUE': '#4d9de0', 'LIGHTORANGE': '#FF9933'} -# Used by DAG context_managers -CONTEXT_MANAGER_DAG = None # type: Optional[airflow.models.dag.DAG] - # If store_serialized_dags is True, scheduler writes serialized DAGs to DB, and webserver # reads DAGs from DB instead of importing from files. STORE_SERIALIZED_DAGS = conf.getboolean('core', 'store_serialized_dags', fallback=False) diff --git a/airflow/task/task_runner/__init__.py b/airflow/task/task_runner/__init__.py index 945f1b656495e4..eaf96721855fe2 100644 --- a/airflow/task/task_runner/__init__.py +++ b/airflow/task/task_runner/__init__.py @@ -19,8 +19,8 @@ # pylint:disable=missing-docstring +from airflow import AirflowException from airflow.configuration import conf -from airflow.exceptions import AirflowException from airflow.task.task_runner.standard_task_runner import StandardTaskRunner _TASK_RUNNER = conf.get('core', 'TASK_RUNNER') diff --git a/airflow/ti_deps/deps/trigger_rule_dep.py b/airflow/ti_deps/deps/trigger_rule_dep.py index f004378d5908a1..7fa974947d6fa5 100644 --- a/airflow/ti_deps/deps/trigger_rule_dep.py +++ b/airflow/ti_deps/deps/trigger_rule_dep.py @@ -20,6 +20,7 @@ from sqlalchemy import case, func import airflow +from airflow.models import TaskInstance from airflow.ti_deps.deps.base_ti_dep import BaseTIDep from airflow.utils.db import provide_session from airflow.utils.state import State @@ -36,7 +37,6 @@ class TriggerRuleDep(BaseTIDep): @provide_session def _get_dep_statuses(self, ti, session, dep_context): - TI = airflow.models.TaskInstance TR = airflow.utils.trigger_rule.TriggerRule # Checking that all upstream dependencies have succeeded @@ -56,20 +56,20 @@ def _get_dep_statuses(self, ti, session, dep_context): session .query( func.coalesce(func.sum( - case([(TI.state == State.SUCCESS, 1)], else_=0)), 0), + case([(TaskInstance.state == State.SUCCESS, 1)], else_=0)), 0), func.coalesce(func.sum( - case([(TI.state == State.SKIPPED, 1)], else_=0)), 0), + case([(TaskInstance.state == State.SKIPPED, 1)], else_=0)), 0), func.coalesce(func.sum( - case([(TI.state == State.FAILED, 1)], else_=0)), 0), + case([(TaskInstance.state == State.FAILED, 1)], else_=0)), 0), func.coalesce(func.sum( - case([(TI.state == State.UPSTREAM_FAILED, 1)], else_=0)), 0), - func.count(TI.task_id), + case([(TaskInstance.state == State.UPSTREAM_FAILED, 1)], else_=0)), 0), + func.count(TaskInstance.task_id), ) .filter( - TI.dag_id == ti.dag_id, - TI.task_id.in_(ti.task.upstream_task_ids), - TI.execution_date == ti.execution_date, - TI.state.in_([ + TaskInstance.dag_id == ti.dag_id, + TaskInstance.task_id.in_(ti.task.upstream_task_ids), + TaskInstance.execution_date == ti.execution_date, + TaskInstance.state.in_([ State.SUCCESS, State.FAILED, State.UPSTREAM_FAILED, State.SKIPPED]), ) diff --git a/airflow/ti_deps/deps/valid_state_dep.py b/airflow/ti_deps/deps/valid_state_dep.py index 38b1c0446b6241..a7b9b30d6e2400 100644 --- a/airflow/ti_deps/deps/valid_state_dep.py +++ b/airflow/ti_deps/deps/valid_state_dep.py @@ -17,7 +17,7 @@ # specific language governing permissions and limitations # under the License. -from airflow.exceptions import AirflowException +from airflow import AirflowException from airflow.ti_deps.deps.base_ti_dep import BaseTIDep from airflow.utils.db import provide_session diff --git a/airflow/utils/dag_processing.py b/airflow/utils/dag_processing.py index 153d675f276f4a..9eaf2158b2edb5 100644 --- a/airflow/utils/dag_processing.py +++ b/airflow/utils/dag_processing.py @@ -39,10 +39,10 @@ # To avoid circular imports import airflow.models +from airflow import AirflowException from airflow.configuration import conf from airflow.dag.base_dag import BaseDag, BaseDagBag -from airflow.exceptions import AirflowException -from airflow.models import errors +from airflow.models import TaskInstance, errors from airflow.settings import STORE_SERIALIZED_DAGS from airflow.stats import Stats from airflow.utils import timezone @@ -220,12 +220,10 @@ def construct_task_instance(self, session=None, lock_for_update=False): lock the TaskInstance (issuing a FOR UPDATE clause) until the session is committed. """ - TI = airflow.models.TaskInstance - - qry = session.query(TI).filter( - TI.dag_id == self._dag_id, - TI.task_id == self._task_id, - TI.execution_date == self._execution_date) + qry = session.query(TaskInstance).filter( + TaskInstance.dag_id == self._dag_id, + TaskInstance.task_id == self._task_id, + TaskInstance.execution_date == self._execution_date) if lock_for_update: ti = qry.with_for_update().first() @@ -904,7 +902,7 @@ def _refresh_dag_dir(self): self.log.exception("Error removing old import errors") if STORE_SERIALIZED_DAGS: - from airflow.models import SerializedDagModel + from airflow.models.serialized_dag import SerializedDagModel from airflow.models.dag import DagModel SerializedDagModel.remove_deleted_dags(self._file_paths) DagModel.deactivate_deleted_dags(self._file_paths) @@ -1274,15 +1272,15 @@ def _find_zombies(self, session): # to avoid circular imports from airflow.jobs import LocalTaskJob as LJ self.log.info("Finding 'running' jobs without a recent heartbeat") - TI = airflow.models.TaskInstance + TaskInstance = airflow.models.TaskInstance limit_dttm = timezone.utcnow() - timedelta( seconds=self._zombie_threshold_secs) self.log.info("Failing jobs without heartbeat after %s", limit_dttm) tis = ( - session.query(TI) - .join(LJ, TI.job_id == LJ.id) - .filter(TI.state == State.RUNNING) + session.query(TaskInstance) + .join(LJ, TaskInstance.job_id == LJ.id) + .filter(TaskInstance.state == State.RUNNING) .filter( or_( LJ.state != State.RUNNING, diff --git a/airflow/utils/db.py b/airflow/utils/db.py index 26ee66f09fcf8c..ace3074b11fc36 100644 --- a/airflow/utils/db.py +++ b/airflow/utils/db.py @@ -336,8 +336,12 @@ def resetdb(): Clear out the database """ from airflow import models + # We need to add this model manually to get reset working well + # noinspection PyUnresolvedReferences + from airflow.models.serialized_dag import SerializedDagModel # noqa: F401 # alembic adds significant import time, so we import it lazily + # noinspection PyUnresolvedReferences from alembic.migration import MigrationContext log.info("Dropping tables that exist") diff --git a/airflow/utils/decorators.py b/airflow/utils/decorators.py index 2e77aaee8a783c..46523fbd895efd 100644 --- a/airflow/utils/decorators.py +++ b/airflow/utils/decorators.py @@ -23,8 +23,7 @@ from copy import copy from functools import wraps -from airflow import settings -from airflow.exceptions import AirflowException +from airflow import AirflowException signature = inspect.signature @@ -53,13 +52,14 @@ def apply_defaults(func): @wraps(func) def wrapper(*args, **kwargs): + from airflow.models.dag import DagContext if len(args) > 1: raise AirflowException( "Use keyword arguments when initializing operators") dag_args = {} dag_params = {} - dag = kwargs.get('dag', None) or settings.CONTEXT_MANAGER_DAG + dag = kwargs.get('dag', None) or DagContext.get_current_dag() if dag: dag_args = copy(dag.default_args) or {} dag_params = copy(dag.params) or {} @@ -93,6 +93,7 @@ def wrapper(*args, **kwargs): return result return wrapper + if 'BUILDING_AIRFLOW_DOCS' in os.environ: # flake8: noqa: F811 # Monkey patch hook to get good function headers while building docs diff --git a/airflow/utils/helpers.py b/airflow/utils/helpers.py index a41ac9d6751bea..b5b669c13284c1 100644 --- a/airflow/utils/helpers.py +++ b/airflow/utils/helpers.py @@ -28,8 +28,8 @@ import psutil from jinja2 import Template +from airflow import AirflowException from airflow.configuration import conf -from airflow.exceptions import AirflowException try: # Fix Python > 3.7 deprecation @@ -167,8 +167,7 @@ def chain(*tasks): :param tasks: List of tasks or List[airflow.models.BaseOperator] to set dependencies :type tasks: List[airflow.models.BaseOperator] or airflow.models.BaseOperator """ - from airflow.models import BaseOperator - + from airflow.models.baseoperator import BaseOperator for index, up_task in enumerate(tasks[:-1]): down_task = tasks[index + 1] if isinstance(up_task, BaseOperator): diff --git a/airflow/utils/log/gcs_task_handler.py b/airflow/utils/log/gcs_task_handler.py index e1e987382834ba..ca3644f128e439 100644 --- a/airflow/utils/log/gcs_task_handler.py +++ b/airflow/utils/log/gcs_task_handler.py @@ -21,8 +21,8 @@ from cached_property import cached_property +from airflow import AirflowException from airflow.configuration import conf -from airflow.exceptions import AirflowException from airflow.utils.log.file_task_handler import FileTaskHandler from airflow.utils.log.logging_mixin import LoggingMixin diff --git a/airflow/utils/operator_resources.py b/airflow/utils/operator_resources.py index 08220e1f30578b..5c509f05795aae 100644 --- a/airflow/utils/operator_resources.py +++ b/airflow/utils/operator_resources.py @@ -17,8 +17,8 @@ # specific language governing permissions and limitations # under the License. +from airflow import AirflowException from airflow.configuration import conf -from airflow.exceptions import AirflowException # Constants for resources (megabytes are the base unit) MB = 1 diff --git a/airflow/utils/tests.py b/airflow/utils/tests.py index 2835a4ecd1a243..01bfc8c5bb5c35 100644 --- a/airflow/utils/tests.py +++ b/airflow/utils/tests.py @@ -20,8 +20,7 @@ import re import unittest -from airflow.models import BaseOperator -from airflow.models.baseoperator import BaseOperatorLink +from airflow.models.baseoperator import BaseOperator, BaseOperatorLink from airflow.utils.decorators import apply_defaults diff --git a/airflow/www/api/experimental/endpoints.py b/airflow/www/api/experimental/endpoints.py index 67204f72d2c344..2810922e2f75c0 100644 --- a/airflow/www/api/experimental/endpoints.py +++ b/airflow/www/api/experimental/endpoints.py @@ -19,14 +19,13 @@ from flask import Blueprint, g, jsonify, request, url_for import airflow.api -from airflow import models +from airflow import AirflowException, models from airflow.api.common.experimental import delete_dag as delete, pool as pool_api, trigger_dag as trigger from airflow.api.common.experimental.get_code import get_code from airflow.api.common.experimental.get_dag_run_state import get_dag_run_state from airflow.api.common.experimental.get_dag_runs import get_dag_runs from airflow.api.common.experimental.get_task import get_task from airflow.api.common.experimental.get_task_instance import get_task_instance -from airflow.exceptions import AirflowException from airflow.utils import timezone from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.strings import to_boolean diff --git a/airflow/www/security.py b/airflow/www/security.py index 4bb804b42c6ab5..44753eb7e12e94 100644 --- a/airflow/www/security.py +++ b/airflow/www/security.py @@ -23,8 +23,7 @@ from flask_appbuilder.security.sqla.manager import SecurityManager from sqlalchemy import and_, or_ -from airflow import models -from airflow.exceptions import AirflowException +from airflow import AirflowException, models from airflow.utils.db import provide_session from airflow.utils.log.logging_mixin import LoggingMixin from airflow.www.app import appbuilder diff --git a/airflow/www/utils.py b/airflow/www/utils.py index c8b6685e9fc78b..b359bc994035c5 100644 --- a/airflow/www/utils.py +++ b/airflow/www/utils.py @@ -37,7 +37,7 @@ from pygments.formatters import HtmlFormatter from airflow.configuration import conf -from airflow.models import BaseOperator +from airflow.models.baseoperator import BaseOperator from airflow.operators.subdag_operator import SubDagOperator from airflow.utils import timezone from airflow.utils.json import AirflowJsonEncoder diff --git a/airflow/www/views.py b/airflow/www/views.py index 49877058d4d7a9..65d305b31c145b 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -51,6 +51,7 @@ set_dag_run_state_to_failed, set_dag_run_state_to_success, ) from airflow.configuration import AIRFLOW_CONFIG, conf +from airflow.executors.all_executors import AllExecutors from airflow.models import Connection, DagModel, DagRun, Log, SlaMiss, TaskFail, XCom, errors from airflow.settings import STORE_SERIALIZED_DAGS from airflow.ti_deps.dep_context import RUNNING_DEPS, SCHEDULER_QUEUED_DEPS, DepContext @@ -331,7 +332,7 @@ def dag_stats(self, session=None): @has_access @provide_session def task_stats(self, session=None): - TI = models.TaskInstance + TaskInstance = models.TaskInstance DagRun = models.DagRun Dag = models.DagModel @@ -361,16 +362,16 @@ def task_stats(self, session=None): # Select all task_instances from active dag_runs. # If no dag_run is active, return task instances from most recent dag_run. LastTI = ( - session.query(TI.dag_id.label('dag_id'), TI.state.label('state')) + session.query(TaskInstance.dag_id.label('dag_id'), TaskInstance.state.label('state')) .join(LastDagRun, - and_(LastDagRun.c.dag_id == TI.dag_id, - LastDagRun.c.execution_date == TI.execution_date)) + and_(LastDagRun.c.dag_id == TaskInstance.dag_id, + LastDagRun.c.execution_date == TaskInstance.execution_date)) ) RunningTI = ( - session.query(TI.dag_id.label('dag_id'), TI.state.label('state')) + session.query(TaskInstance.dag_id.label('dag_id'), TaskInstance.state.label('state')) .join(RunningDagRun, - and_(RunningDagRun.c.dag_id == TI.dag_id, - RunningDagRun.c.execution_date == TI.execution_date)) + and_(RunningDagRun.c.dag_id == TaskInstance.dag_id, + RunningDagRun.c.execution_date == TaskInstance.execution_date)) ) UnionTI = union_all(LastTI, RunningTI).alias('union_ti') @@ -462,11 +463,11 @@ def dag_details(self, session=None): title = "DAG details" root = request.args.get('root', '') - TI = models.TaskInstance + TaskInstance = models.TaskInstance states = ( - session.query(TI.state, sqla.func.count(TI.dag_id)) - .filter(TI.dag_id == dag_id) - .group_by(TI.state) + session.query(TaskInstance.state, sqla.func.count(TaskInstance.dag_id)) + .filter(TaskInstance.dag_id == dag_id) + .group_by(TaskInstance.state) .all() ) @@ -670,7 +671,7 @@ def elasticsearch(self, session=None): @has_access @action_logging def task(self): - TI = models.TaskInstance + TaskInstance = models.TaskInstance dag_id = request.args.get('dag_id') task_id = request.args.get('task_id') @@ -690,7 +691,7 @@ def task(self): return redirect(url_for('Airflow.index')) task = copy.copy(dag.get_task(task_id)) task.resolve_template_files() - ti = TI(task=task, execution_date=dttm) + ti = TaskInstance(task=task, execution_date=dttm) ti.refresh_from_db() ti_attrs = [] @@ -807,8 +808,7 @@ def run(self): ignore_task_deps = request.form.get('ignore_task_deps') == "true" ignore_ti_state = request.form.get('ignore_ti_state') == "true" - from airflow.executors import get_default_executor - executor = get_default_executor() + executor = AllExecutors.get_default_executor() valid_celery_config = False valid_kubernetes_config = False diff --git a/dags/test_dag.py b/dags/test_dag.py index 358a5de1a1ca5c..412f33fd9c9912 100644 --- a/dags/test_dag.py +++ b/dags/test_dag.py @@ -21,8 +21,9 @@ """ from datetime import datetime, timedelta -from airflow import DAG, utils +from airflow import DAG from airflow.operators.dummy_operator import DummyOperator +from airflow.utils.dates import days_ago now = datetime.now() now_to_the_hour = ( @@ -34,7 +35,7 @@ default_args = { 'owner': 'airflow', 'depends_on_past': True, - 'start_date': utils.dates.days_ago(2) + 'start_date': days_ago(2) } dag = DAG(DAG_NAME, schedule_interval='*/10 * * * *', default_args=default_args) diff --git a/docs/autoapi_templates/index.rst b/docs/autoapi_templates/index.rst index 114ee880713eff..48675b5c7b6820 100644 --- a/docs/autoapi_templates/index.rst +++ b/docs/autoapi_templates/index.rst @@ -146,6 +146,13 @@ persisted in the database. airflow/models/index +Exceptions +---------- +Exceptions are the mechanism to raise exceptions by Airflow operators, hooks and sensors. + +* :mod:`airflow.exceptions` + + Core and community package -------------------------- Formerly the core code was maintained by the original creators - Airbnb. The code diff --git a/docs/conf.py b/docs/conf.py index 44a534357971c6..b0fc15b817caae 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -36,7 +36,7 @@ import sys from typing import Dict -import airflow +import airflow.settings autodoc_mock_imports = [ 'MySQLdb', @@ -227,6 +227,11 @@ '_api/airflow/providers/amazon/aws/example_dags', '_api/airflow/providers/apache/index.rst', '_api/airflow/providers/apache/cassandra/index.rst', + '_api/enums/index.rst', + '_api/json_schema/index.rst', + '_api/base_serialization/index.rst', + '_api/serialized_baseoperator/index.rst', + '_api/serialized_dag/index.rst', 'autoapi_templates', 'howto/operator/gcp/_partials', ] @@ -471,7 +476,6 @@ '*/airflow/contrib/operators/gcs_to_gcs_transfer_operator.py', '*/airflow/contrib/operators/gcs_to_gcs_transfer_operator.py', '*/airflow/kubernetes/kubernetes_request_factory/*', - '*/node_modules/*', '*/migrations/*', ] diff --git a/docs/scheduler.rst b/docs/scheduler.rst index c958ec77f514f8..9fd9df72be2086 100644 --- a/docs/scheduler.rst +++ b/docs/scheduler.rst @@ -117,7 +117,7 @@ interval series. Code that goes along with the Airflow tutorial located at: https://github.com/apache/airflow/blob/master/airflow/example_dags/tutorial.py """ - from airflow import DAG + from airflow.models.dag import DAG from airflow.operators.bash_operator import BashOperator from datetime import datetime, timedelta diff --git a/docs/tutorial.rst b/docs/tutorial.rst index d89ce509adaef5..82b713d06e0a05 100644 --- a/docs/tutorial.rst +++ b/docs/tutorial.rst @@ -36,7 +36,7 @@ complicated, a line by line explanation follows below. Code that goes along with the Airflow tutorial located at: https://github.com/apache/airflow/blob/master/airflow/example_dags/tutorial.py """ - from airflow import DAG + from airflow.models.dag import DAG from airflow.operators.bash_operator import BashOperator from datetime import datetime, timedelta @@ -116,7 +116,7 @@ Airflow DAG object. Let's start by importing the libraries we will need. .. code:: python # The DAG object; we'll need this to instantiate a DAG - from airflow import DAG + from airflow.models.dag import DAG # Operators; we need this to operate! from airflow.operators.bash_operator import BashOperator @@ -312,7 +312,7 @@ something like this: Code that goes along with the Airflow tutorial located at: https://github.com/apache/airflow/blob/master/airflow/example_dags/tutorial.py """ - from airflow import DAG + from airflow.models.dag import DAG from airflow.operators.bash_operator import BashOperator from datetime import datetime, timedelta diff --git a/scripts/perf/dags/perf_dag_1.py b/scripts/perf/dags/perf_dag_1.py index 4ac25c226aa307..64977c9c55aedc 100644 --- a/scripts/perf/dags/perf_dag_1.py +++ b/scripts/perf/dags/perf_dag_1.py @@ -22,7 +22,7 @@ from datetime import timedelta import airflow -from airflow.models import DAG +from airflow import DAG from airflow.operators.bash_operator import BashOperator args = { diff --git a/scripts/perf/dags/perf_dag_2.py b/scripts/perf/dags/perf_dag_2.py index b4f4370fa02220..511c1cfe698690 100644 --- a/scripts/perf/dags/perf_dag_2.py +++ b/scripts/perf/dags/perf_dag_2.py @@ -22,7 +22,7 @@ from datetime import timedelta import airflow -from airflow.models import DAG +from airflow import DAG from airflow.operators.bash_operator import BashOperator args = { diff --git a/scripts/perf/scheduler_ops_metrics.py b/scripts/perf/scheduler_ops_metrics.py index 26b07e8481503d..23d19fbffe0cf1 100644 --- a/scripts/perf/scheduler_ops_metrics.py +++ b/scripts/perf/scheduler_ops_metrics.py @@ -69,11 +69,10 @@ def print_stats(self): Print operational metrics for the scheduler test. """ session = settings.Session() - TI = TaskInstance tis = ( session - .query(TI) - .filter(TI.dag_id.in_(DAG_IDS)) + .query(TaskInstance) + .filter(TaskInstance.dag_id.in_(DAG_IDS)) .all() ) successful_tis = [x for x in tis if x.state == State.SUCCESS] @@ -109,12 +108,11 @@ def heartbeat(self): super().heartbeat() session = settings.Session() # Get all the relevant task instances - TI = TaskInstance successful_tis = ( session - .query(TI) - .filter(TI.dag_id.in_(DAG_IDS)) - .filter(TI.state.in_([State.SUCCESS])) + .query(TaskInstance) + .filter(TaskInstance.dag_id.in_(DAG_IDS)) + .filter(TaskInstance.state.in_([State.SUCCESS])) .all() ) session.commit() @@ -155,11 +153,10 @@ def clear_dag_task_instances(): Remove any existing task instances for the perf test DAGs. """ session = settings.Session() - TI = TaskInstance tis = ( session - .query(TI) - .filter(TI.dag_id.in_(DAG_IDS)) + .query(TaskInstance) + .filter(TaskInstance.dag_id.in_(DAG_IDS)) .all() ) for ti in tis: diff --git a/tests/api/common/experimental/test_delete_dag.py b/tests/api/common/experimental/test_delete_dag.py index 8d45ca7e094fd5..b81a1dd4fdf42d 100644 --- a/tests/api/common/experimental/test_delete_dag.py +++ b/tests/api/common/experimental/test_delete_dag.py @@ -29,7 +29,7 @@ DM = models.DagModel DR = models.DagRun -TI = models.TaskInstance +TaskInstance = models.TaskInstance LOG = models.log.Log TF = models.taskfail.TaskFail TR = models.taskreschedule.TaskReschedule @@ -66,9 +66,9 @@ def setUp(self): with create_session() as session: session.add(DM(dag_id=self.key, fileloc=self.dag_file_path)) session.add(DR(dag_id=self.key)) - session.add(TI(task=task, - execution_date=test_date, - state=State.SUCCESS)) + session.add(TaskInstance(task=task, + execution_date=test_date, + state=State.SUCCESS)) # flush to ensure task instance if written before # task reschedule because of FK constraint session.flush() @@ -86,7 +86,7 @@ def tearDown(self): with create_session() as session: session.query(TR).filter(TR.dag_id == self.key).delete() session.query(TF).filter(TF.dag_id == self.key).delete() - session.query(TI).filter(TI.dag_id == self.key).delete() + session.query(TaskInstance).filter(TaskInstance.dag_id == self.key).delete() session.query(DR).filter(DR.dag_id == self.key).delete() session.query(DM).filter(DM.dag_id == self.key).delete() session.query(LOG).filter(LOG.dag_id == self.key).delete() @@ -96,7 +96,7 @@ def test_delete_dag_successful_delete(self): with create_session() as session: self.assertEqual(session.query(DM).filter(DM.dag_id == self.key).count(), 1) self.assertEqual(session.query(DR).filter(DR.dag_id == self.key).count(), 1) - self.assertEqual(session.query(TI).filter(TI.dag_id == self.key).count(), 1) + self.assertEqual(session.query(TaskInstance).filter(TaskInstance.dag_id == self.key).count(), 1) self.assertEqual(session.query(TF).filter(TF.dag_id == self.key).count(), 1) self.assertEqual(session.query(TR).filter(TR.dag_id == self.key).count(), 1) self.assertEqual(session.query(LOG).filter(LOG.dag_id == self.key).count(), 1) @@ -108,7 +108,7 @@ def test_delete_dag_successful_delete(self): with create_session() as session: self.assertEqual(session.query(DM).filter(DM.dag_id == self.key).count(), 0) self.assertEqual(session.query(DR).filter(DR.dag_id == self.key).count(), 0) - self.assertEqual(session.query(TI).filter(TI.dag_id == self.key).count(), 0) + self.assertEqual(session.query(TaskInstance).filter(TaskInstance.dag_id == self.key).count(), 0) self.assertEqual(session.query(TF).filter(TF.dag_id == self.key).count(), 0) self.assertEqual(session.query(TR).filter(TR.dag_id == self.key).count(), 0) self.assertEqual(session.query(LOG).filter(LOG.dag_id == self.key).count(), 1) @@ -120,7 +120,7 @@ def test_delete_dag_successful_delete_not_keeping_records_in_log(self): with create_session() as session: self.assertEqual(session.query(DM).filter(DM.dag_id == self.key).count(), 1) self.assertEqual(session.query(DR).filter(DR.dag_id == self.key).count(), 1) - self.assertEqual(session.query(TI).filter(TI.dag_id == self.key).count(), 1) + self.assertEqual(session.query(TaskInstance).filter(TaskInstance.dag_id == self.key).count(), 1) self.assertEqual(session.query(TF).filter(TF.dag_id == self.key).count(), 1) self.assertEqual(session.query(TR).filter(TR.dag_id == self.key).count(), 1) self.assertEqual(session.query(LOG).filter(LOG.dag_id == self.key).count(), 1) @@ -132,7 +132,7 @@ def test_delete_dag_successful_delete_not_keeping_records_in_log(self): with create_session() as session: self.assertEqual(session.query(DM).filter(DM.dag_id == self.key).count(), 0) self.assertEqual(session.query(DR).filter(DR.dag_id == self.key).count(), 0) - self.assertEqual(session.query(TI).filter(TI.dag_id == self.key).count(), 0) + self.assertEqual(session.query(TaskInstance).filter(TaskInstance.dag_id == self.key).count(), 0) self.assertEqual(session.query(TF).filter(TF.dag_id == self.key).count(), 0) self.assertEqual(session.query(TR).filter(TR.dag_id == self.key).count(), 0) self.assertEqual(session.query(LOG).filter(LOG.dag_id == self.key).count(), 0) diff --git a/tests/api/common/experimental/test_mark_tasks.py b/tests/api/common/experimental/test_mark_tasks.py index e234d34245a28b..072b27fda24b09 100644 --- a/tests/api/common/experimental/test_mark_tasks.py +++ b/tests/api/common/experimental/test_mark_tasks.py @@ -84,20 +84,20 @@ def tearDown(self): @staticmethod def snapshot_state(dag, execution_dates): - TI = models.TaskInstance + TaskInstance = models.TaskInstance with create_session() as session: - return session.query(TI).filter( - TI.dag_id == dag.dag_id, - TI.execution_date.in_(execution_dates) + return session.query(TaskInstance).filter( + TaskInstance.dag_id == dag.dag_id, + TaskInstance.execution_date.in_(execution_dates) ).all() @provide_session def verify_state(self, dag, task_ids, execution_dates, state, old_tis, session=None): - TI = models.TaskInstance + TaskInstance = models.TaskInstance - tis = session.query(TI).filter( - TI.dag_id == dag.dag_id, - TI.execution_date.in_(execution_dates) + tis = session.query(TaskInstance).filter( + TaskInstance.dag_id == dag.dag_id, + TaskInstance.execution_date.in_(execution_dates) ).all() self.assertTrue(len(tis) > 0) @@ -311,9 +311,9 @@ def _verify_task_instance_states_remain_default(self, dr): @provide_session def _verify_task_instance_states(self, dag, date, state, session=None): - TI = models.TaskInstance - tis = session.query(TI)\ - .filter(TI.dag_id == dag.dag_id, TI.execution_date == date) + TaskInstance = models.TaskInstance + tis = session.query(TaskInstance)\ + .filter(TaskInstance.dag_id == dag.dag_id, TaskInstance.execution_date == date) for ti in tis: self.assertEqual(ti.state, state) diff --git a/tests/api/common/experimental/test_trigger_dag.py b/tests/api/common/experimental/test_trigger_dag.py index b787f840ffbc39..6bbd1e4a1f762c 100644 --- a/tests/api/common/experimental/test_trigger_dag.py +++ b/tests/api/common/experimental/test_trigger_dag.py @@ -21,9 +21,9 @@ import unittest from unittest import mock +from airflow import DAG, AirflowException from airflow.api.common.experimental.trigger_dag import _trigger_dag -from airflow.exceptions import AirflowException -from airflow.models import DAG, DagRun +from airflow.models import DagRun class TestTriggerDag(unittest.TestCase): diff --git a/tests/contrib/hooks/test_azure_cosmos_hook.py b/tests/contrib/hooks/test_azure_cosmos_hook.py index ee0c39a6274648..e6fa8a3e9ac5e4 100644 --- a/tests/contrib/hooks/test_azure_cosmos_hook.py +++ b/tests/contrib/hooks/test_azure_cosmos_hook.py @@ -24,8 +24,8 @@ import unittest import uuid +from airflow import AirflowException from airflow.contrib.hooks.azure_cosmos_hook import AzureCosmosDBHook -from airflow.exceptions import AirflowException from airflow.models import Connection from airflow.utils import db from tests.compat import mock diff --git a/tests/contrib/hooks/test_cloudant_hook.py b/tests/contrib/hooks/test_cloudant_hook.py index dc17a97d9f8c9e..d3ecc3b9482024 100644 --- a/tests/contrib/hooks/test_cloudant_hook.py +++ b/tests/contrib/hooks/test_cloudant_hook.py @@ -18,8 +18,8 @@ # under the License. import unittest +from airflow import AirflowException from airflow.contrib.hooks.cloudant_hook import CloudantHook -from airflow.exceptions import AirflowException from airflow.models import Connection from tests.compat import patch diff --git a/tests/contrib/hooks/test_databricks_hook.py b/tests/contrib/hooks/test_databricks_hook.py index 8b96622ca21f12..63f68fe1a75f59 100644 --- a/tests/contrib/hooks/test_databricks_hook.py +++ b/tests/contrib/hooks/test_databricks_hook.py @@ -24,9 +24,8 @@ from requests import exceptions as requests_exceptions -from airflow import __version__ +from airflow import AirflowException, __version__ from airflow.contrib.hooks.databricks_hook import SUBMIT_RUN_ENDPOINT, DatabricksHook, RunState -from airflow.exceptions import AirflowException from airflow.models import Connection from airflow.utils import db from tests.compat import mock diff --git a/tests/contrib/hooks/test_datadog_hook.py b/tests/contrib/hooks/test_datadog_hook.py index df6d1bcdc59620..d0eec20eb6b9ed 100644 --- a/tests/contrib/hooks/test_datadog_hook.py +++ b/tests/contrib/hooks/test_datadog_hook.py @@ -21,8 +21,8 @@ import unittest from unittest import mock +from airflow import AirflowException from airflow.contrib.hooks.datadog_hook import DatadogHook -from airflow.exceptions import AirflowException from airflow.models import Connection APP_KEY = 'app_key' diff --git a/tests/contrib/hooks/test_pinot_hook.py b/tests/contrib/hooks/test_pinot_hook.py index 3765f7d46d1a48..ca5d720927e63e 100644 --- a/tests/contrib/hooks/test_pinot_hook.py +++ b/tests/contrib/hooks/test_pinot_hook.py @@ -24,8 +24,8 @@ import unittest from unittest import mock +from airflow import AirflowException from airflow.contrib.hooks.pinot_hook import PinotAdminHook, PinotDbApiHook -from airflow.exceptions import AirflowException class TestPinotAdminHook(unittest.TestCase): diff --git a/tests/contrib/hooks/test_sagemaker_hook.py b/tests/contrib/hooks/test_sagemaker_hook.py index 499c5c18f8a75d..e9b58f9e2f00d4 100644 --- a/tests/contrib/hooks/test_sagemaker_hook.py +++ b/tests/contrib/hooks/test_sagemaker_hook.py @@ -24,11 +24,11 @@ from tzlocal import get_localzone +from airflow import AirflowException from airflow.contrib.hooks.aws_logs_hook import AwsLogsHook from airflow.contrib.hooks.sagemaker_hook import ( LogState, SageMakerHook, secondary_training_status_changed, secondary_training_status_message, ) -from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.s3 import S3Hook from tests.compat import mock diff --git a/tests/contrib/hooks/test_sqoop_hook.py b/tests/contrib/hooks/test_sqoop_hook.py index 05e8c47688a3ae..a2cc80bc602abb 100644 --- a/tests/contrib/hooks/test_sqoop_hook.py +++ b/tests/contrib/hooks/test_sqoop_hook.py @@ -24,8 +24,8 @@ from io import StringIO from unittest.mock import call, patch +from airflow import AirflowException from airflow.contrib.hooks.sqoop_hook import SqoopHook -from airflow.exceptions import AirflowException from airflow.models import Connection from airflow.utils import db diff --git a/tests/contrib/operators/test_awsbatch_operator.py b/tests/contrib/operators/test_awsbatch_operator.py index c1723e504b9ece..869b5b68bed5c2 100644 --- a/tests/contrib/operators/test_awsbatch_operator.py +++ b/tests/contrib/operators/test_awsbatch_operator.py @@ -21,8 +21,8 @@ import sys import unittest +from airflow import AirflowException from airflow.contrib.operators.awsbatch_operator import AWSBatchOperator -from airflow.exceptions import AirflowException from tests.compat import mock RESPONSE_WITHOUT_FAILURES = { diff --git a/tests/contrib/operators/test_azure_container_instances_operator.py b/tests/contrib/operators/test_azure_container_instances_operator.py index 56f4f6ff77764d..7d0eadb3c35dc7 100644 --- a/tests/contrib/operators/test_azure_container_instances_operator.py +++ b/tests/contrib/operators/test_azure_container_instances_operator.py @@ -23,8 +23,8 @@ from azure.mgmt.containerinstance.models import ContainerState, Event +from airflow import AirflowException from airflow.contrib.operators.azure_container_instances_operator import AzureContainerInstancesOperator -from airflow.exceptions import AirflowException from tests.compat import mock diff --git a/tests/contrib/operators/test_databricks_operator.py b/tests/contrib/operators/test_databricks_operator.py index 78d140684293af..9acb8da6324dd7 100644 --- a/tests/contrib/operators/test_databricks_operator.py +++ b/tests/contrib/operators/test_databricks_operator.py @@ -22,12 +22,11 @@ from datetime import datetime import airflow.contrib.operators.databricks_operator as databricks_operator +from airflow import DAG, AirflowException from airflow.contrib.hooks.databricks_hook import RunState from airflow.contrib.operators.databricks_operator import ( DatabricksRunNowOperator, DatabricksSubmitRunOperator, ) -from airflow.exceptions import AirflowException -from airflow.models import DAG from tests.compat import mock DATE = '2017-04-20' diff --git a/tests/contrib/operators/test_ecs_operator.py b/tests/contrib/operators/test_ecs_operator.py index 8fbe58912f41ae..5450c91620d3e2 100644 --- a/tests/contrib/operators/test_ecs_operator.py +++ b/tests/contrib/operators/test_ecs_operator.py @@ -24,8 +24,8 @@ from parameterized import parameterized +from airflow import AirflowException from airflow.contrib.operators.ecs_operator import ECSOperator -from airflow.exceptions import AirflowException from tests.compat import mock RESPONSE_WITHOUT_FAILURES = { diff --git a/tests/contrib/operators/test_jenkins_operator.py b/tests/contrib/operators/test_jenkins_operator.py index e637d649ffcf65..38b2352e9e1e09 100644 --- a/tests/contrib/operators/test_jenkins_operator.py +++ b/tests/contrib/operators/test_jenkins_operator.py @@ -21,9 +21,9 @@ import jenkins +from airflow import AirflowException from airflow.contrib.hooks.jenkins_hook import JenkinsHook from airflow.contrib.operators.jenkins_job_trigger_operator import JenkinsJobTriggerOperator -from airflow.exceptions import AirflowException from tests.compat import mock diff --git a/tests/contrib/operators/test_qubole_check_operator.py b/tests/contrib/operators/test_qubole_check_operator.py index d64879d822c948..b53ccba0a9a4ff 100644 --- a/tests/contrib/operators/test_qubole_check_operator.py +++ b/tests/contrib/operators/test_qubole_check_operator.py @@ -22,11 +22,10 @@ from qds_sdk.commands import HiveCommand +from airflow import DAG, AirflowException from airflow.contrib.hooks.qubole_check_hook import QuboleCheckHook from airflow.contrib.hooks.qubole_hook import QuboleHook from airflow.contrib.operators.qubole_check_operator import QuboleValueCheckOperator -from airflow.exceptions import AirflowException -from airflow.models import DAG from tests.compat import mock diff --git a/tests/contrib/operators/test_qubole_operator.py b/tests/contrib/operators/test_qubole_operator.py index 92554932b781d8..94513423c88b35 100644 --- a/tests/contrib/operators/test_qubole_operator.py +++ b/tests/contrib/operators/test_qubole_operator.py @@ -20,10 +20,10 @@ import unittest -from airflow import settings +from airflow import DAG, settings from airflow.contrib.hooks.qubole_hook import QuboleHook from airflow.contrib.operators.qubole_operator import QuboleOperator -from airflow.models import DAG, Connection +from airflow.models import Connection from airflow.models.taskinstance import TaskInstance from airflow.utils import db from airflow.utils.timezone import datetime diff --git a/tests/contrib/operators/test_s3_to_sftp_operator.py b/tests/contrib/operators/test_s3_to_sftp_operator.py index ec688b53372e2b..70b4612578c9c4 100644 --- a/tests/contrib/operators/test_s3_to_sftp_operator.py +++ b/tests/contrib/operators/test_s3_to_sftp_operator.py @@ -22,11 +22,11 @@ import boto3 from moto import mock_s3 -from airflow import models +from airflow import DAG, models from airflow.configuration import conf from airflow.contrib.operators.s3_to_sftp_operator import S3ToSFTPOperator from airflow.contrib.operators.ssh_operator import SSHOperator -from airflow.models import DAG, TaskInstance +from airflow.models import TaskInstance from airflow.settings import Session from airflow.utils import timezone from airflow.utils.timezone import datetime diff --git a/tests/contrib/operators/test_sagemaker_endpoint_config_operator.py b/tests/contrib/operators/test_sagemaker_endpoint_config_operator.py index cd1bcbf76fb99a..72b24b69297a79 100644 --- a/tests/contrib/operators/test_sagemaker_endpoint_config_operator.py +++ b/tests/contrib/operators/test_sagemaker_endpoint_config_operator.py @@ -19,9 +19,9 @@ import unittest +from airflow import AirflowException from airflow.contrib.hooks.sagemaker_hook import SageMakerHook from airflow.contrib.operators.sagemaker_endpoint_config_operator import SageMakerEndpointConfigOperator -from airflow.exceptions import AirflowException from tests.compat import mock model_name = 'test-model-name' diff --git a/tests/contrib/operators/test_sagemaker_endpoint_operator.py b/tests/contrib/operators/test_sagemaker_endpoint_operator.py index cc895941738a50..1a210e989c0874 100644 --- a/tests/contrib/operators/test_sagemaker_endpoint_operator.py +++ b/tests/contrib/operators/test_sagemaker_endpoint_operator.py @@ -19,9 +19,9 @@ import unittest +from airflow import AirflowException from airflow.contrib.hooks.sagemaker_hook import SageMakerHook from airflow.contrib.operators.sagemaker_endpoint_operator import SageMakerEndpointOperator -from airflow.exceptions import AirflowException from tests.compat import mock role = 'arn:aws:iam:role/test-role' diff --git a/tests/contrib/operators/test_sagemaker_model_operator.py b/tests/contrib/operators/test_sagemaker_model_operator.py index 252120dda4b2ad..5c234fa397646f 100644 --- a/tests/contrib/operators/test_sagemaker_model_operator.py +++ b/tests/contrib/operators/test_sagemaker_model_operator.py @@ -19,9 +19,9 @@ import unittest +from airflow import AirflowException from airflow.contrib.hooks.sagemaker_hook import SageMakerHook from airflow.contrib.operators.sagemaker_model_operator import SageMakerModelOperator -from airflow.exceptions import AirflowException from tests.compat import mock role = 'arn:aws:iam:role/test-role' diff --git a/tests/contrib/operators/test_sagemaker_training_operator.py b/tests/contrib/operators/test_sagemaker_training_operator.py index b385af9c141e0e..befcddc637eb69 100644 --- a/tests/contrib/operators/test_sagemaker_training_operator.py +++ b/tests/contrib/operators/test_sagemaker_training_operator.py @@ -19,9 +19,9 @@ import unittest +from airflow import AirflowException from airflow.contrib.hooks.sagemaker_hook import SageMakerHook from airflow.contrib.operators.sagemaker_training_operator import SageMakerTrainingOperator -from airflow.exceptions import AirflowException from tests.compat import mock role = 'arn:aws:iam:role/test-role' diff --git a/tests/contrib/operators/test_sagemaker_transform_operator.py b/tests/contrib/operators/test_sagemaker_transform_operator.py index 01adcf12b4cc31..c8f76e555f6d96 100644 --- a/tests/contrib/operators/test_sagemaker_transform_operator.py +++ b/tests/contrib/operators/test_sagemaker_transform_operator.py @@ -19,9 +19,9 @@ import unittest +from airflow import AirflowException from airflow.contrib.hooks.sagemaker_hook import SageMakerHook from airflow.contrib.operators.sagemaker_transform_operator import SageMakerTransformOperator -from airflow.exceptions import AirflowException from tests.compat import mock role = 'arn:aws:iam:role/test-role' diff --git a/tests/contrib/operators/test_sagemaker_tuning_operator.py b/tests/contrib/operators/test_sagemaker_tuning_operator.py index e5dfe4955fd99c..c55de09a9cc34c 100644 --- a/tests/contrib/operators/test_sagemaker_tuning_operator.py +++ b/tests/contrib/operators/test_sagemaker_tuning_operator.py @@ -19,9 +19,9 @@ import unittest +from airflow import AirflowException from airflow.contrib.hooks.sagemaker_hook import SageMakerHook from airflow.contrib.operators.sagemaker_tuning_operator import SageMakerTuningOperator -from airflow.exceptions import AirflowException from tests.compat import mock role = 'arn:aws:iam:role/test-role' diff --git a/tests/contrib/operators/test_sftp_operator.py b/tests/contrib/operators/test_sftp_operator.py index 3d4c6a7e70e2a6..eef023e5918e63 100644 --- a/tests/contrib/operators/test_sftp_operator.py +++ b/tests/contrib/operators/test_sftp_operator.py @@ -22,10 +22,10 @@ from base64 import b64encode from unittest import mock -from airflow import AirflowException, models +from airflow import DAG, AirflowException from airflow.contrib.operators.sftp_operator import SFTPOperation, SFTPOperator from airflow.contrib.operators.ssh_operator import SSHOperator -from airflow.models import DAG, TaskInstance +from airflow.models import TaskInstance from airflow.settings import Session from airflow.utils import timezone from airflow.utils.timezone import datetime @@ -38,7 +38,7 @@ def reset(dag_id=TEST_DAG_ID): session = Session() - tis = session.query(models.TaskInstance).filter_by(dag_id=dag_id) + tis = session.query(TaskInstance).filter_by(dag_id=dag_id) tis.delete() session.commit() session.close() diff --git a/tests/contrib/operators/test_sftp_to_s3_operator.py b/tests/contrib/operators/test_sftp_to_s3_operator.py index 533ae4277f29f1..a0acc2e5a9a92c 100644 --- a/tests/contrib/operators/test_sftp_to_s3_operator.py +++ b/tests/contrib/operators/test_sftp_to_s3_operator.py @@ -22,11 +22,11 @@ import boto3 from moto import mock_s3 -from airflow import models +from airflow import DAG, models from airflow.contrib.hooks.ssh_hook import SSHHook from airflow.contrib.operators.sftp_to_s3_operator import SFTPToS3Operator from airflow.contrib.operators.ssh_operator import SSHOperator -from airflow.models import DAG, TaskInstance +from airflow.models import TaskInstance from airflow.providers.amazon.aws.hooks.s3 import S3Hook from airflow.settings import Session from airflow.utils import timezone diff --git a/tests/contrib/operators/test_sqoop_operator.py b/tests/contrib/operators/test_sqoop_operator.py index 0a673834c4d778..21984ee56c2934 100644 --- a/tests/contrib/operators/test_sqoop_operator.py +++ b/tests/contrib/operators/test_sqoop_operator.py @@ -21,9 +21,8 @@ import datetime import unittest -from airflow import DAG +from airflow import DAG, AirflowException from airflow.contrib.operators.sqoop_operator import SqoopOperator -from airflow.exceptions import AirflowException class TestSqoopOperator(unittest.TestCase): diff --git a/tests/contrib/operators/test_ssh_operator.py b/tests/contrib/operators/test_ssh_operator.py index a9033496e068ef..24758f246f7cda 100644 --- a/tests/contrib/operators/test_ssh_operator.py +++ b/tests/contrib/operators/test_ssh_operator.py @@ -17,15 +17,14 @@ # specific language governing permissions and limitations # under the License. -import unittest import unittest.mock from base64 import b64encode from parameterized import parameterized -from airflow import AirflowException, models +from airflow import DAG, AirflowException, models from airflow.contrib.operators.ssh_operator import SSHOperator -from airflow.models import DAG, TaskInstance +from airflow.models import TaskInstance from airflow.settings import Session from airflow.utils import timezone from airflow.utils.timezone import datetime diff --git a/tests/contrib/operators/test_winrm_operator.py b/tests/contrib/operators/test_winrm_operator.py index f9d15fb590d82c..1611ec3e26c98d 100644 --- a/tests/contrib/operators/test_winrm_operator.py +++ b/tests/contrib/operators/test_winrm_operator.py @@ -20,8 +20,8 @@ import unittest from unittest import mock +from airflow import AirflowException from airflow.contrib.operators.winrm_operator import WinRMOperator -from airflow.exceptions import AirflowException class TestWinRMOperator(unittest.TestCase): diff --git a/tests/contrib/sensors/test_emr_base_sensor.py b/tests/contrib/sensors/test_emr_base_sensor.py index b4b2d124a5293a..87547d4182ad58 100644 --- a/tests/contrib/sensors/test_emr_base_sensor.py +++ b/tests/contrib/sensors/test_emr_base_sensor.py @@ -19,8 +19,8 @@ import unittest +from airflow import AirflowException from airflow.contrib.sensors.emr_base_sensor import EmrBaseSensor -from airflow.exceptions import AirflowException class TestEmrBaseSensor(unittest.TestCase): diff --git a/tests/contrib/sensors/test_qubole_sensor.py b/tests/contrib/sensors/test_qubole_sensor.py index 17828d56422d90..a5ced303f0964e 100644 --- a/tests/contrib/sensors/test_qubole_sensor.py +++ b/tests/contrib/sensors/test_qubole_sensor.py @@ -22,9 +22,9 @@ from datetime import datetime from unittest.mock import patch +from airflow import DAG, AirflowException from airflow.contrib.sensors.qubole_sensor import QuboleFileSensor, QubolePartitionSensor -from airflow.exceptions import AirflowException -from airflow.models import DAG, Connection +from airflow.models import Connection from airflow.utils import db DAG_ID = "qubole_test_dag" diff --git a/tests/contrib/sensors/test_sagemaker_base_sensor.py b/tests/contrib/sensors/test_sagemaker_base_sensor.py index 2892dc80711564..94b6a2bda5b1ca 100644 --- a/tests/contrib/sensors/test_sagemaker_base_sensor.py +++ b/tests/contrib/sensors/test_sagemaker_base_sensor.py @@ -19,8 +19,8 @@ import unittest +from airflow import AirflowException from airflow.contrib.sensors.sagemaker_base_sensor import SageMakerBaseSensor -from airflow.exceptions import AirflowException class TestSagemakerBaseSensor(unittest.TestCase): diff --git a/tests/contrib/sensors/test_sagemaker_endpoint_sensor.py b/tests/contrib/sensors/test_sagemaker_endpoint_sensor.py index 65800da219da56..7b4a717d2f3524 100644 --- a/tests/contrib/sensors/test_sagemaker_endpoint_sensor.py +++ b/tests/contrib/sensors/test_sagemaker_endpoint_sensor.py @@ -19,9 +19,9 @@ import unittest +from airflow import AirflowException from airflow.contrib.hooks.sagemaker_hook import SageMakerHook from airflow.contrib.sensors.sagemaker_endpoint_sensor import SageMakerEndpointSensor -from airflow.exceptions import AirflowException from tests.compat import mock DESCRIBE_ENDPOINT_CREATING_RESPONSE = { diff --git a/tests/contrib/sensors/test_sagemaker_training_sensor.py b/tests/contrib/sensors/test_sagemaker_training_sensor.py index c380e42bc406b4..5e6c89668e9832 100644 --- a/tests/contrib/sensors/test_sagemaker_training_sensor.py +++ b/tests/contrib/sensors/test_sagemaker_training_sensor.py @@ -20,10 +20,10 @@ import unittest from datetime import datetime +from airflow import AirflowException from airflow.contrib.hooks.aws_logs_hook import AwsLogsHook from airflow.contrib.hooks.sagemaker_hook import LogState, SageMakerHook from airflow.contrib.sensors.sagemaker_training_sensor import SageMakerTrainingSensor -from airflow.exceptions import AirflowException from tests.compat import mock DESCRIBE_TRAINING_COMPELETED_RESPONSE = { diff --git a/tests/contrib/sensors/test_sagemaker_transform_sensor.py b/tests/contrib/sensors/test_sagemaker_transform_sensor.py index e26bf708e31195..d33bbce3ba75c4 100644 --- a/tests/contrib/sensors/test_sagemaker_transform_sensor.py +++ b/tests/contrib/sensors/test_sagemaker_transform_sensor.py @@ -19,9 +19,9 @@ import unittest +from airflow import AirflowException from airflow.contrib.hooks.sagemaker_hook import SageMakerHook from airflow.contrib.sensors.sagemaker_transform_sensor import SageMakerTransformSensor -from airflow.exceptions import AirflowException from tests.compat import mock DESCRIBE_TRANSFORM_INPROGRESS_RESPONSE = { diff --git a/tests/contrib/sensors/test_sagemaker_tuning_sensor.py b/tests/contrib/sensors/test_sagemaker_tuning_sensor.py index be6cbc55331670..ad360058f344fa 100644 --- a/tests/contrib/sensors/test_sagemaker_tuning_sensor.py +++ b/tests/contrib/sensors/test_sagemaker_tuning_sensor.py @@ -19,9 +19,9 @@ import unittest +from airflow import AirflowException from airflow.contrib.hooks.sagemaker_hook import SageMakerHook from airflow.contrib.sensors.sagemaker_tuning_sensor import SageMakerTuningSensor -from airflow.exceptions import AirflowException from tests.compat import mock DESCRIBE_TUNING_INPROGRESS_RESPONSE = { diff --git a/tests/contrib/utils/test_task_handler_with_custom_formatter.py b/tests/contrib/utils/test_task_handler_with_custom_formatter.py index 95d6c857ae387f..94c4ff173239c2 100644 --- a/tests/contrib/utils/test_task_handler_with_custom_formatter.py +++ b/tests/contrib/utils/test_task_handler_with_custom_formatter.py @@ -20,9 +20,10 @@ import logging import unittest +from airflow import DAG from airflow.config_templates.airflow_local_settings import DEFAULT_LOGGING_CONFIG from airflow.configuration import conf -from airflow.models import DAG, TaskInstance +from airflow.models import TaskInstance from airflow.operators.dummy_operator import DummyOperator from airflow.utils.log.logging_mixin import set_context from airflow.utils.timezone import datetime diff --git a/tests/core.py b/tests/core.py index 5ff43dc9fe3be4..9a432d36506d22 100644 --- a/tests/core.py +++ b/tests/core.py @@ -31,14 +31,13 @@ from numpy.testing import assert_array_almost_equal from pendulum import utcnow -from airflow import DAG, configuration, exceptions, jobs, settings, utils +from airflow import DAG, AirflowException, configuration, exceptions, jobs, settings, utils from airflow.bin import cli from airflow.configuration import AirflowConfigException, conf, run_command -from airflow.exceptions import AirflowException -from airflow.executors import SequentialExecutor from airflow.hooks.base_hook import BaseHook from airflow.hooks.sqlite_hook import SqliteHook -from airflow.models import BaseOperator, Connection, DagBag, DagRun, Pool, TaskFail, TaskInstance, Variable +from airflow.models import Connection, DagBag, DagRun, Pool, TaskFail, TaskInstance, Variable +from airflow.models.baseoperator import BaseOperator from airflow.operators.bash_operator import BashOperator from airflow.operators.check_operator import CheckOperator, ValueCheckOperator from airflow.operators.dummy_operator import DummyOperator @@ -643,8 +642,7 @@ def __bool__(self): t.resolve_template_files() def test_task_get_template(self): - TI = TaskInstance - ti = TI( + ti = TaskInstance( task=self.runme_0, execution_date=DEFAULT_DATE) ti.dag = self.dag_bash ti.run(ignore_ti_state=True) @@ -673,15 +671,13 @@ def test_task_get_template(self): self.assertEqual(context['tomorrow_ds_nodash'], '20150102') def test_local_task_job(self): - TI = TaskInstance - ti = TI( + ti = TaskInstance( task=self.runme_0, execution_date=DEFAULT_DATE) job = jobs.LocalTaskJob(task_instance=ti, ignore_ti_state=True) job.run() def test_raw_job(self): - TI = TaskInstance - ti = TI( + ti = TaskInstance( task=self.runme_0, execution_date=DEFAULT_DATE) ti.dag = self.dag_bash ti.run(ignore_ti_state=True) @@ -870,11 +866,11 @@ def test_bad_trigger_rule(self): def test_terminate_task(self): """If a task instance's db state get deleted, it should fail""" - TI = TaskInstance + from airflow.executors.sequential_executor import SequentialExecutor dag = self.dagbag.dags.get('test_utils') task = dag.task_dict.get('sleeps_forever') - ti = TI(task=task, execution_date=DEFAULT_DATE) + ti = TaskInstance(task=task, execution_date=DEFAULT_DATE) job = jobs.LocalTaskJob( task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor()) @@ -887,7 +883,7 @@ def test_terminate_task(self): ti.refresh_from_db(session=session) # making sure it's actually running self.assertEqual(State.RUNNING, ti.state) - ti = session.query(TI).filter_by( + ti = session.query(TaskInstance).filter_by( dag_id=task.dag_id, task_id=task.task_id, execution_date=DEFAULT_DATE @@ -950,7 +946,6 @@ def test_run_command(self): self.assertRaises(AirflowConfigException, run_command, 'bash -c "exit 1"') def test_externally_triggered_dagrun(self): - TI = TaskInstance # Create the dagrun between two "scheduled" execution dates of the DAG EXECUTION_DATE = DEFAULT_DATE + timedelta(days=2) @@ -971,7 +966,7 @@ def test_externally_triggered_dagrun(self): task.run( start_date=EXECUTION_DATE, end_date=EXECUTION_DATE) - ti = TI(task=task, execution_date=EXECUTION_DATE) + ti = TaskInstance(task=task, execution_date=EXECUTION_DATE) context = ti.get_template_context() # next_ds/prev_ds should be the execution date for manually triggered runs diff --git a/tests/dags/test_backfill_pooled_tasks.py b/tests/dags/test_backfill_pooled_tasks.py index 8320f51c86e761..0f8f3992b164b8 100644 --- a/tests/dags/test_backfill_pooled_tasks.py +++ b/tests/dags/test_backfill_pooled_tasks.py @@ -25,7 +25,7 @@ """ from datetime import datetime -from airflow.models import DAG +from airflow import DAG from airflow.operators.dummy_operator import DummyOperator dag = DAG(dag_id='test_backfill_pooled_task_dag') diff --git a/tests/dags/test_clear_subdag.py b/tests/dags/test_clear_subdag.py index 7385b686f676b1..3bd44269edf16c 100644 --- a/tests/dags/test_clear_subdag.py +++ b/tests/dags/test_clear_subdag.py @@ -20,7 +20,7 @@ import datetime -from airflow.models import DAG +from airflow import DAG from airflow.operators.bash_operator import BashOperator from airflow.operators.subdag_operator import SubDagOperator diff --git a/tests/dags/test_cli_triggered_dags.py b/tests/dags/test_cli_triggered_dags.py index 64d827dc9cf83b..a67fd226a1fbbf 100644 --- a/tests/dags/test_cli_triggered_dags.py +++ b/tests/dags/test_cli_triggered_dags.py @@ -20,7 +20,7 @@ from datetime import timedelta -from airflow.models import DAG +from airflow import DAG from airflow.operators.python_operator import PythonOperator from airflow.utils.timezone import datetime diff --git a/tests/dags/test_default_impersonation.py b/tests/dags/test_default_impersonation.py index a9b83a1a2d9feb..465bba0d9bdd5a 100644 --- a/tests/dags/test_default_impersonation.py +++ b/tests/dags/test_default_impersonation.py @@ -20,7 +20,7 @@ from datetime import datetime from textwrap import dedent -from airflow.models import DAG +from airflow import DAG from airflow.operators.bash_operator import BashOperator DEFAULT_DATE = datetime(2016, 1, 1) diff --git a/tests/dags/test_double_trigger.py b/tests/dags/test_double_trigger.py index a53d7638dfefc2..c00e01e9bcd84a 100644 --- a/tests/dags/test_double_trigger.py +++ b/tests/dags/test_double_trigger.py @@ -18,7 +18,7 @@ # under the License. from datetime import datetime -from airflow.models import DAG +from airflow import DAG from airflow.operators.dummy_operator import DummyOperator DEFAULT_DATE = datetime(2016, 1, 1) diff --git a/tests/dags/test_example_bash_operator.py b/tests/dags/test_example_bash_operator.py index ae965c9260606d..b8f6f8283a4244 100644 --- a/tests/dags/test_example_bash_operator.py +++ b/tests/dags/test_example_bash_operator.py @@ -19,7 +19,7 @@ from datetime import timedelta import airflow -from airflow.models import DAG +from airflow import DAG from airflow.operators.bash_operator import BashOperator from airflow.operators.dummy_operator import DummyOperator diff --git a/tests/dags/test_heartbeat_failed_fast.py b/tests/dags/test_heartbeat_failed_fast.py index 12c4cca7bfd051..5ac3dad3c8df93 100644 --- a/tests/dags/test_heartbeat_failed_fast.py +++ b/tests/dags/test_heartbeat_failed_fast.py @@ -18,7 +18,7 @@ # under the License. from datetime import datetime -from airflow.models import DAG +from airflow import DAG from airflow.operators.bash_operator import BashOperator DEFAULT_DATE = datetime(2016, 1, 1) diff --git a/tests/dags/test_impersonation.py b/tests/dags/test_impersonation.py index 95887f36e7a897..a034f5aa9f29f3 100644 --- a/tests/dags/test_impersonation.py +++ b/tests/dags/test_impersonation.py @@ -20,7 +20,7 @@ from datetime import datetime from textwrap import dedent -from airflow.models import DAG +from airflow import DAG from airflow.operators.bash_operator import BashOperator DEFAULT_DATE = datetime(2016, 1, 1) diff --git a/tests/dags/test_impersonation_subdag.py b/tests/dags/test_impersonation_subdag.py index e7da2da01b3cce..95fdfc01c17f26 100644 --- a/tests/dags/test_impersonation_subdag.py +++ b/tests/dags/test_impersonation_subdag.py @@ -19,7 +19,7 @@ from datetime import datetime -from airflow.models import DAG +from airflow import DAG from airflow.operators.bash_operator import BashOperator from airflow.operators.python_operator import PythonOperator from airflow.operators.subdag_operator import SubDagOperator diff --git a/tests/dags/test_invalid_cron.py b/tests/dags/test_invalid_cron.py index 51a0e43cb500a4..1f181c3ca88083 100644 --- a/tests/dags/test_invalid_cron.py +++ b/tests/dags/test_invalid_cron.py @@ -17,7 +17,7 @@ # specific language governing permissions and limitations # under the License. -from airflow.models import DAG +from airflow import DAG from airflow.operators.dummy_operator import DummyOperator from airflow.utils.timezone import datetime diff --git a/tests/dags/test_issue_1225.py b/tests/dags/test_issue_1225.py index 24be0222a10205..7d6d9e6fedb9f0 100644 --- a/tests/dags/test_issue_1225.py +++ b/tests/dags/test_issue_1225.py @@ -25,7 +25,7 @@ """ from datetime import datetime, timedelta -from airflow.models import DAG +from airflow import DAG from airflow.operators.dummy_operator import DummyOperator from airflow.operators.python_operator import PythonOperator from airflow.utils.trigger_rule import TriggerRule diff --git a/tests/dags/test_latest_runs.py b/tests/dags/test_latest_runs.py index e8cd99adc907e0..271df94c1d2143 100644 --- a/tests/dags/test_latest_runs.py +++ b/tests/dags/test_latest_runs.py @@ -20,7 +20,7 @@ from datetime import datetime -from airflow.models import DAG +from airflow import DAG from airflow.operators.dummy_operator import DummyOperator for i in range(1, 2): diff --git a/tests/dags/test_mark_success.py b/tests/dags/test_mark_success.py index dfbcbc1b5d16b0..34d8b938dbd68e 100644 --- a/tests/dags/test_mark_success.py +++ b/tests/dags/test_mark_success.py @@ -18,7 +18,7 @@ # under the License. from datetime import datetime -from airflow.models import DAG +from airflow import DAG from airflow.operators.bash_operator import BashOperator DEFAULT_DATE = datetime(2016, 1, 1) diff --git a/tests/dags/test_no_impersonation.py b/tests/dags/test_no_impersonation.py index 9382733737f2e7..8e6c4e75ae76a9 100644 --- a/tests/dags/test_no_impersonation.py +++ b/tests/dags/test_no_impersonation.py @@ -20,7 +20,7 @@ from datetime import datetime from textwrap import dedent -from airflow.models import DAG +from airflow import DAG from airflow.operators.bash_operator import BashOperator DEFAULT_DATE = datetime(2016, 1, 1) diff --git a/tests/dags/test_on_kill.py b/tests/dags/test_on_kill.py index c5ae29852b8519..f2642d3d5e6190 100644 --- a/tests/dags/test_on_kill.py +++ b/tests/dags/test_on_kill.py @@ -18,7 +18,7 @@ # under the License. import time -from airflow.models import DAG +from airflow import DAG from airflow.operators.dummy_operator import DummyOperator from airflow.utils.timezone import datetime diff --git a/tests/dags/test_scheduler_dags.py b/tests/dags/test_scheduler_dags.py index ec42a335608b99..b2cbefc2161e08 100644 --- a/tests/dags/test_scheduler_dags.py +++ b/tests/dags/test_scheduler_dags.py @@ -19,7 +19,7 @@ from datetime import datetime, timedelta -from airflow.models import DAG +from airflow import DAG from airflow.operators.dummy_operator import DummyOperator DEFAULT_DATE = datetime(2016, 1, 1) diff --git a/tests/dags/test_subdag.py b/tests/dags/test_subdag.py index 52bc9ec81c2b4d..e50cf9dec7d7c0 100644 --- a/tests/dags/test_subdag.py +++ b/tests/dags/test_subdag.py @@ -24,7 +24,7 @@ from datetime import datetime, timedelta -from airflow.models import DAG +from airflow import DAG from airflow.operators.dummy_operator import DummyOperator from airflow.operators.subdag_operator import SubDagOperator diff --git a/tests/dags/test_task_view_type_check.py b/tests/dags/test_task_view_type_check.py index 7bad77ea96d30e..282be6cb6af68a 100644 --- a/tests/dags/test_task_view_type_check.py +++ b/tests/dags/test_task_view_type_check.py @@ -25,7 +25,7 @@ import logging from datetime import datetime -from airflow.models import DAG +from airflow import DAG from airflow.operators.python_operator import PythonOperator DEFAULT_DATE = datetime(2016, 1, 1) diff --git a/tests/dags_corrupted/test_impersonation_custom.py b/tests/dags_corrupted/test_impersonation_custom.py index 2d73046f6f5dde..fb1a3bb5d6aee3 100644 --- a/tests/dags_corrupted/test_impersonation_custom.py +++ b/tests/dags_corrupted/test_impersonation_custom.py @@ -28,7 +28,7 @@ # variable correctly. from fake_datetime import FakeDatetime -from airflow.models import DAG +from airflow import DAG from airflow.operators.python_operator import PythonOperator DEFAULT_DATE = datetime(2016, 1, 1) diff --git a/tests/dags_with_system_exit/a_system_exit.py b/tests/dags_with_system_exit/a_system_exit.py index 67aa5ec42170f2..eff8c1e5ed1aef 100644 --- a/tests/dags_with_system_exit/a_system_exit.py +++ b/tests/dags_with_system_exit/a_system_exit.py @@ -23,7 +23,7 @@ import sys from datetime import datetime -from airflow.models import DAG +from airflow import DAG DEFAULT_DATE = datetime(2100, 1, 1) diff --git a/tests/dags_with_system_exit/b_test_scheduler_dags.py b/tests/dags_with_system_exit/b_test_scheduler_dags.py index bde36ba5dfecf0..9a47ca64f40bad 100644 --- a/tests/dags_with_system_exit/b_test_scheduler_dags.py +++ b/tests/dags_with_system_exit/b_test_scheduler_dags.py @@ -19,7 +19,7 @@ from datetime import datetime -from airflow.models import DAG +from airflow import DAG from airflow.operators.dummy_operator import DummyOperator DEFAULT_DATE = datetime(2000, 1, 1) diff --git a/tests/dags_with_system_exit/c_system_exit.py b/tests/dags_with_system_exit/c_system_exit.py index 39721c186c7879..b1aa6ef1294be7 100644 --- a/tests/dags_with_system_exit/c_system_exit.py +++ b/tests/dags_with_system_exit/c_system_exit.py @@ -23,7 +23,7 @@ import sys from datetime import datetime -from airflow.models import DAG +from airflow import DAG DEFAULT_DATE = datetime(2100, 1, 1) diff --git a/tests/gcp/hooks/test_cloud_sql.py b/tests/gcp/hooks/test_cloud_sql.py index 49963faba6789c..9134a134cf87ff 100644 --- a/tests/gcp/hooks/test_cloud_sql.py +++ b/tests/gcp/hooks/test_cloud_sql.py @@ -25,7 +25,7 @@ from googleapiclient.errors import HttpError from parameterized import parameterized -from airflow.exceptions import AirflowException +from airflow import AirflowException from airflow.gcp.hooks.cloud_sql import CloudSqlDatabaseHook, CloudSqlHook from airflow.models import Connection from tests.compat import PropertyMock, mock diff --git a/tests/gcp/hooks/test_gcs.py b/tests/gcp/hooks/test_gcs.py index bdb7cacd18a343..9665088b63fc86 100644 --- a/tests/gcp/hooks/test_gcs.py +++ b/tests/gcp/hooks/test_gcs.py @@ -28,7 +28,7 @@ import dateutil from google.cloud import exceptions, storage -from airflow.exceptions import AirflowException +from airflow import AirflowException from airflow.gcp.hooks import gcs from airflow.version import version from tests.compat import mock diff --git a/tests/gcp/operators/test_bigquery.py b/tests/gcp/operators/test_bigquery.py index 0a7ebfccd21f5f..5451ff4cf7c04a 100644 --- a/tests/gcp/operators/test_bigquery.py +++ b/tests/gcp/operators/test_bigquery.py @@ -21,15 +21,14 @@ from datetime import datetime from unittest.mock import MagicMock -from airflow import models -from airflow.exceptions import AirflowException +from airflow import DAG, AirflowException, models from airflow.gcp.operators.bigquery import ( BigQueryConsoleLink, BigQueryCreateEmptyDatasetOperator, BigQueryCreateEmptyTableOperator, BigQueryCreateExternalTableOperator, BigQueryDeleteDatasetOperator, BigQueryGetDataOperator, BigQueryGetDatasetOperator, BigQueryGetDatasetTablesOperator, BigQueryOperator, BigQueryPatchDatasetOperator, BigQueryTableDeleteOperator, BigQueryUpdateDatasetOperator, ) -from airflow.models import DAG, TaskFail, TaskInstance, XCom +from airflow.models import TaskFail, TaskInstance, XCom from airflow.settings import Session from airflow.utils.db import provide_session from tests.compat import mock diff --git a/tests/gcp/operators/test_cloud_storage_transfer_service.py b/tests/gcp/operators/test_cloud_storage_transfer_service.py index 2eb30114e08eaf..6a147ad9365d6f 100644 --- a/tests/gcp/operators/test_cloud_storage_transfer_service.py +++ b/tests/gcp/operators/test_cloud_storage_transfer_service.py @@ -26,7 +26,7 @@ from freezegun import freeze_time from parameterized import parameterized -from airflow import AirflowException +from airflow import DAG, AirflowException from airflow.gcp.hooks.cloud_storage_transfer_service import ( ACCESS_KEY_ID, AWS_ACCESS_KEY, AWS_S3_DATA_SOURCE, BUCKET_NAME, FILTER_JOB_NAMES, GCS_DATA_SINK, GCS_DATA_SOURCE, HTTP_DATA_SOURCE, LIST_URL, NAME, SCHEDULE, SCHEDULE_END_DATE, SCHEDULE_START_DATE, @@ -40,7 +40,7 @@ GoogleCloudStorageToGoogleCloudStorageTransferOperator, S3ToGoogleCloudStorageTransferOperator, TransferJobPreprocessor, TransferJobValidator, ) -from airflow.models import DAG, TaskInstance +from airflow.models import TaskInstance from airflow.utils import timezone from tests.compat import mock diff --git a/tests/gcp/operators/test_compute.py b/tests/gcp/operators/test_compute.py index dac0188d0f7bfc..b14e8d7d28ddbf 100644 --- a/tests/gcp/operators/test_compute.py +++ b/tests/gcp/operators/test_compute.py @@ -26,12 +26,12 @@ import httplib2 from googleapiclient.errors import HttpError -from airflow import AirflowException +from airflow import DAG, AirflowException from airflow.gcp.operators.compute import ( GceInstanceGroupManagerUpdateTemplateOperator, GceInstanceStartOperator, GceInstanceStopOperator, GceInstanceTemplateCopyOperator, GceSetMachineTypeOperator, ) -from airflow.models import DAG, TaskInstance +from airflow.models import TaskInstance from airflow.utils import timezone from tests.compat import mock diff --git a/tests/gcp/operators/test_mlengine.py b/tests/gcp/operators/test_mlengine.py index 9485ea913e1b04..28c209c4a16747 100644 --- a/tests/gcp/operators/test_mlengine.py +++ b/tests/gcp/operators/test_mlengine.py @@ -23,8 +23,7 @@ import httplib2 from googleapiclient.errors import HttpError -from airflow import DAG -from airflow.exceptions import AirflowException +from airflow import DAG, AirflowException from airflow.gcp.operators.mlengine import ( MLEngineBatchPredictionOperator, MLEngineCreateModelOperator, MLEngineCreateVersionOperator, MLEngineDeleteModelOperator, MLEngineDeleteVersionOperator, MLEngineGetModelOperator, diff --git a/tests/gcp/operators/test_mlengine_utils.py b/tests/gcp/operators/test_mlengine_utils.py index 68ea9047a68975..bc955c4584f511 100644 --- a/tests/gcp/operators/test_mlengine_utils.py +++ b/tests/gcp/operators/test_mlengine_utils.py @@ -19,8 +19,7 @@ import unittest from unittest.mock import ANY, patch -from airflow import DAG -from airflow.exceptions import AirflowException +from airflow import DAG, AirflowException from airflow.gcp.utils import mlengine_operator_utils from airflow.version import version diff --git a/tests/hooks/test_docker_hook.py b/tests/hooks/test_docker_hook.py index e83a649822bf63..f9ca42db5f1586 100644 --- a/tests/hooks/test_docker_hook.py +++ b/tests/hooks/test_docker_hook.py @@ -19,7 +19,7 @@ import unittest -from airflow.exceptions import AirflowException +from airflow import AirflowException from airflow.models import Connection from airflow.utils import db from tests.compat import mock diff --git a/tests/hooks/test_druid_hook.py b/tests/hooks/test_druid_hook.py index 215f8f45b1a826..6647ab17860d5c 100644 --- a/tests/hooks/test_druid_hook.py +++ b/tests/hooks/test_druid_hook.py @@ -24,7 +24,7 @@ import requests import requests_mock -from airflow.exceptions import AirflowException +from airflow import AirflowException from airflow.hooks.druid_hook import DruidDbApiHook, DruidHook diff --git a/tests/hooks/test_hive_hook.py b/tests/hooks/test_hive_hook.py index 59cad448ef4360..a9d7109ad72ff5 100644 --- a/tests/hooks/test_hive_hook.py +++ b/tests/hooks/test_hive_hook.py @@ -29,8 +29,7 @@ import pandas as pd from hmsclient import HMSClient -from airflow import DAG -from airflow.exceptions import AirflowException +from airflow import DAG, AirflowException from airflow.hooks.hive_hooks import HiveCliHook, HiveMetastoreHook, HiveServer2Hook from airflow.models.connection import Connection from airflow.operators.hive_operator import HiveOperator diff --git a/tests/hooks/test_http_hook.py b/tests/hooks/test_http_hook.py index e9722d299de0e0..49b9e792172282 100644 --- a/tests/hooks/test_http_hook.py +++ b/tests/hooks/test_http_hook.py @@ -23,7 +23,7 @@ import requests_mock import tenacity -from airflow.exceptions import AirflowException +from airflow import AirflowException from airflow.hooks.http_hook import HttpHook from airflow.models import Connection from tests.compat import mock diff --git a/tests/hooks/test_pig_hook.py b/tests/hooks/test_pig_hook.py index 529da68e6f93db..8b069bd6f75fbb 100644 --- a/tests/hooks/test_pig_hook.py +++ b/tests/hooks/test_pig_hook.py @@ -65,7 +65,7 @@ def test_run_cli_fail(self, popen_mock): hook = self.pig_hook() - from airflow.exceptions import AirflowException + from airflow import AirflowException self.assertRaises(AirflowException, hook.run_cli, "") @mock.patch('subprocess.Popen') diff --git a/tests/hooks/test_slack_hook.py b/tests/hooks/test_slack_hook.py index 7f45f996a2d9aa..d8f43d85f8a0af 100644 --- a/tests/hooks/test_slack_hook.py +++ b/tests/hooks/test_slack_hook.py @@ -19,7 +19,7 @@ import unittest -from airflow.exceptions import AirflowException +from airflow import AirflowException from airflow.hooks.slack_hook import SlackHook from tests.compat import mock diff --git a/tests/jobs/test_backfill_job.py b/tests/jobs/test_backfill_job.py index 739f4c78c24d7e..79cd08a43332f4 100644 --- a/tests/jobs/test_backfill_job.py +++ b/tests/jobs/test_backfill_job.py @@ -27,14 +27,15 @@ import sqlalchemy from parameterized import parameterized -from airflow import AirflowException, settings +from airflow import DAG, settings from airflow.bin import cli from airflow.configuration import conf from airflow.exceptions import ( - AirflowTaskTimeout, DagConcurrencyLimitReached, NoAvailablePoolSlot, TaskConcurrencyLimitReached, + AirflowException, AirflowTaskTimeout, DagConcurrencyLimitReached, NoAvailablePoolSlot, + TaskConcurrencyLimitReached, ) from airflow.jobs import BackfillJob, SchedulerJob -from airflow.models import DAG, DagBag, DagRun, Pool, TaskInstance as TI +from airflow.models import DagBag, DagRun, Pool, TaskInstance as TaskInstance from airflow.operators.dummy_operator import DummyOperator from airflow.utils import timezone from airflow.utils.db import create_session @@ -562,8 +563,8 @@ def test_backfill_run_rescheduled(self): ) job.run() - ti = TI(task=dag.get_task('test_backfill_run_rescheduled_task-1'), - execution_date=DEFAULT_DATE) + ti = TaskInstance(task=dag.get_task('test_backfill_run_rescheduled_task-1'), + execution_date=DEFAULT_DATE) ti.refresh_from_db() ti.set_state(State.UP_FOR_RESCHEDULE) @@ -574,8 +575,8 @@ def test_backfill_run_rescheduled(self): rerun_failed_tasks=True ) job.run() - ti = TI(task=dag.get_task('test_backfill_run_rescheduled_task-1'), - execution_date=DEFAULT_DATE) + ti = TaskInstance(task=dag.get_task('test_backfill_run_rescheduled_task-1'), + execution_date=DEFAULT_DATE) ti.refresh_from_db() self.assertEqual(ti.state, State.SUCCESS) @@ -601,8 +602,8 @@ def test_backfill_rerun_failed_tasks(self): ) job.run() - ti = TI(task=dag.get_task('test_backfill_rerun_failed_task-1'), - execution_date=DEFAULT_DATE) + ti = TaskInstance(task=dag.get_task('test_backfill_rerun_failed_task-1'), + execution_date=DEFAULT_DATE) ti.refresh_from_db() ti.set_state(State.FAILED) @@ -613,8 +614,8 @@ def test_backfill_rerun_failed_tasks(self): rerun_failed_tasks=True ) job.run() - ti = TI(task=dag.get_task('test_backfill_rerun_failed_task-1'), - execution_date=DEFAULT_DATE) + ti = TaskInstance(task=dag.get_task('test_backfill_rerun_failed_task-1'), + execution_date=DEFAULT_DATE) ti.refresh_from_db() self.assertEqual(ti.state, State.SUCCESS) @@ -641,8 +642,8 @@ def test_backfill_rerun_upstream_failed_tasks(self): ) job.run() - ti = TI(task=dag.get_task('test_backfill_rerun_upstream_failed_task-1'), - execution_date=DEFAULT_DATE) + ti = TaskInstance(task=dag.get_task('test_backfill_rerun_upstream_failed_task-1'), + execution_date=DEFAULT_DATE) ti.refresh_from_db() ti.set_state(State.UPSTREAM_FAILED) @@ -653,8 +654,8 @@ def test_backfill_rerun_upstream_failed_tasks(self): rerun_failed_tasks=True ) job.run() - ti = TI(task=dag.get_task('test_backfill_rerun_upstream_failed_task-1'), - execution_date=DEFAULT_DATE) + ti = TaskInstance(task=dag.get_task('test_backfill_rerun_upstream_failed_task-1'), + execution_date=DEFAULT_DATE) ti.refresh_from_db() self.assertEqual(ti.state, State.SUCCESS) @@ -680,8 +681,8 @@ def test_backfill_rerun_failed_tasks_without_flag(self): ) job.run() - ti = TI(task=dag.get_task('test_backfill_rerun_failed_task-1'), - execution_date=DEFAULT_DATE) + ti = TaskInstance(task=dag.get_task('test_backfill_rerun_failed_task-1'), + execution_date=DEFAULT_DATE) ti.refresh_from_db() ti.set_state(State.FAILED) @@ -779,7 +780,7 @@ def test_backfill_pooled_tasks(self): job.run() except AirflowTaskTimeout: pass - ti = TI( + ti = TaskInstance( task=dag.get_task('test_backfill_pooled_task'), execution_date=DEFAULT_DATE) ti.refresh_from_db() @@ -807,7 +808,7 @@ def test_backfill_depends_on_past(self): ignore_first_depends_on_past=True).run() # ti should have succeeded - ti = TI(dag.tasks[0], run_date) + ti = TaskInstance(dag.tasks[0], run_date) ti.refresh_from_db() self.assertEqual(ti.state, State.SUCCESS) @@ -832,7 +833,7 @@ def test_backfill_depends_on_past_backwards(self): **kwargs) job.run() - ti = TI(dag.get_task('test_dop_task'), end_date) + ti = TaskInstance(dag.get_task('test_dop_task'), end_date) ti.refresh_from_db() # runs fine forwards self.assertEqual(ti.state, State.SUCCESS) @@ -1200,19 +1201,19 @@ def test_subdag_clear_parentdag_downstream_clear(self): with timeout(seconds=30): job.run() - ti_subdag = TI( + ti_subdag = TaskInstance( task=dag.get_task('daily_job'), execution_date=DEFAULT_DATE) ti_subdag.refresh_from_db() self.assertEqual(ti_subdag.state, State.SUCCESS) - ti_irrelevant = TI( + ti_irrelevant = TaskInstance( task=dag.get_task('daily_job_irrelevant'), execution_date=DEFAULT_DATE) ti_irrelevant.refresh_from_db() self.assertEqual(ti_irrelevant.state, State.SUCCESS) - ti_downstream = TI( + ti_downstream = TaskInstance( task=dag.get_task('daily_job_downstream'), execution_date=DEFAULT_DATE) ti_downstream.refresh_from_db() @@ -1256,7 +1257,7 @@ def test_backfill_execute_subdag_with_removed_task(self): executor=executor, donot_pickle=True) - removed_task_ti = TI( + removed_task_ti = TaskInstance( task=DummyOperator(task_id='removed_task'), execution_date=DEFAULT_DATE, state=State.REMOVED) @@ -1269,10 +1270,10 @@ def test_backfill_execute_subdag_with_removed_task(self): job.run() for task in subdag.tasks: - instance = session.query(TI).filter( - TI.dag_id == subdag.dag_id, - TI.task_id == task.task_id, - TI.execution_date == DEFAULT_DATE).first() + instance = session.query(TaskInstance).filter( + TaskInstance.dag_id == subdag.dag_id, + TaskInstance.task_id == task.task_id, + TaskInstance.execution_date == DEFAULT_DATE).first() self.assertIsNotNone(instance) self.assertEqual(instance.state, State.SUCCESS) @@ -1301,7 +1302,7 @@ def test_update_counters(self): execution_date=DEFAULT_DATE, start_date=DEFAULT_DATE, session=session) - ti = TI(task1, dr.execution_date) + ti = TaskInstance(task1, dr.execution_date) ti.refresh_from_db() ti_status = BackfillJob._DagRunTaskStatus() @@ -1421,9 +1422,9 @@ def test_backfill_run_backwards(self): job.run() session = settings.Session() - tis = session.query(TI).filter( - TI.dag_id == 'test_start_date_scheduling' and TI.task_id == 'dummy' - ).order_by(TI.execution_date).all() + tis = session.query(TaskInstance).filter( + TaskInstance.dag_id == 'test_start_date_scheduling' and TaskInstance.task_id == 'dummy' + ).order_by(TaskInstance.execution_date).all() queued_times = [ti.queued_dttm for ti in tis] self.assertTrue(queued_times == sorted(queued_times, reverse=True)) diff --git a/tests/jobs/test_local_task_job.py b/tests/jobs/test_local_task_job.py index ce881bcf58cda1..6395c5d6bb8c95 100644 --- a/tests/jobs/test_local_task_job.py +++ b/tests/jobs/test_local_task_job.py @@ -22,11 +22,12 @@ import time import unittest -from airflow import AirflowException, models, settings +from ariflow.models import TaskInstance + +from airflow import DAG, AirflowException, models, settings from airflow.configuration import conf -from airflow.executors import SequentialExecutor +from airflow.executors.sequential_executor import SequentialExecutor from airflow.jobs import LocalTaskJob -from airflow.models import DAG, TaskInstance as TI from airflow.operators.dummy_operator import DummyOperator from airflow.utils import timezone from airflow.utils.db import create_session @@ -143,7 +144,7 @@ def heartbeat_recorder(**kwargs): execution_date=DEFAULT_DATE, start_date=DEFAULT_DATE, session=session) - ti = TI(task=task, execution_date=DEFAULT_DATE) + ti = TaskInstance(task=task, execution_date=DEFAULT_DATE) ti.refresh_from_db() ti.state = State.RUNNING ti.hostname = get_hostname() @@ -184,7 +185,7 @@ def test_mark_success_no_kill(self): execution_date=DEFAULT_DATE, start_date=DEFAULT_DATE, session=session) - ti = TI(task=task, execution_date=DEFAULT_DATE) + ti = TaskInstance(task=task, execution_date=DEFAULT_DATE) ti.refresh_from_db() job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True) process = multiprocessing.Process(target=job1.run) @@ -228,7 +229,7 @@ def test_localtaskjob_double_trigger(self): session.merge(ti) session.commit() - ti_run = TI(task=task, execution_date=DEFAULT_DATE) + ti_run = TaskInstance(task=task, execution_date=DEFAULT_DATE) ti_run.refresh_from_db() job1 = LocalTaskJob(task_instance=ti_run, executor=SequentialExecutor()) diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py index fa2789ebb508c3..cf6abb9e3869bc 100644 --- a/tests/jobs/test_scheduler_job.py +++ b/tests/jobs/test_scheduler_job.py @@ -29,11 +29,11 @@ from parameterized import parameterized import airflow.example_dags -from airflow import AirflowException, models, settings +from airflow import DAG, AirflowException, models, settings from airflow.configuration import conf from airflow.executors import BaseExecutor from airflow.jobs import BackfillJob, SchedulerJob -from airflow.models import DAG, DagBag, DagModel, DagRun, Pool, SlaMiss, TaskInstance as TI, errors +from airflow.models import DagBag, DagModel, DagRun, Pool, SlaMiss, TaskInstance as TaskInstance, errors from airflow.operators.bash_operator import BashOperator from airflow.operators.dummy_operator import DummyOperator from airflow.utils import timezone @@ -157,7 +157,7 @@ def test_process_executor_events(self, mock_stats_incr): scheduler = SchedulerJob() session = settings.Session() - ti1 = TI(task1, DEFAULT_DATE) + ti1 = TaskInstance(task1, DEFAULT_DATE) ti1.state = State.QUEUED session.merge(ti1) session.commit() @@ -200,7 +200,7 @@ def test_execute_task_instances_is_paused_wont_execute(self): session = settings.Session() dr1 = scheduler.create_dag_run(dag) - ti1 = TI(task1, DEFAULT_DATE) + ti1 = TaskInstance(task1, DEFAULT_DATE) ti1.state = State.SCHEDULED dr1.state = State.RUNNING dagmodel = models.DagModel() @@ -230,7 +230,7 @@ def test_execute_task_instances_no_dagrun_task_will_execute(self): session = settings.Session() scheduler.create_dag_run(dag) - ti1 = TI(task1, DEFAULT_DATE) + ti1 = TaskInstance(task1, DEFAULT_DATE) ti1.state = State.SCHEDULED ti1.execution_date = ti1.execution_date + datetime.timedelta(days=1) session.merge(ti1) @@ -256,7 +256,7 @@ def test_execute_task_instances_backfill_tasks_wont_execute(self): dr1 = scheduler.create_dag_run(dag) dr1.run_id = BackfillJob.ID_PREFIX + '_blah' - ti1 = TI(task1, dr1.execution_date) + ti1 = TaskInstance(task1, dr1.execution_date) ti1.refresh_from_db() ti1.state = State.SCHEDULED session.merge(ti1) @@ -283,9 +283,9 @@ def test_find_executable_task_instances_backfill_nodagrun(self): dr2 = scheduler.create_dag_run(dag) dr2.run_id = BackfillJob.ID_PREFIX + 'asdf' - ti_no_dagrun = TI(task1, DEFAULT_DATE - datetime.timedelta(days=1)) - ti_backfill = TI(task1, dr2.execution_date) - ti_with_dagrun = TI(task1, dr1.execution_date) + ti_no_dagrun = TaskInstance(task1, DEFAULT_DATE - datetime.timedelta(days=1)) + ti_backfill = TaskInstance(task1, dr2.execution_date) + ti_with_dagrun = TaskInstance(task1, dr1.execution_date) # ti_with_paused ti_no_dagrun.state = State.SCHEDULED ti_backfill.state = State.SCHEDULED @@ -323,10 +323,10 @@ def test_find_executable_task_instances_pool(self): dr2 = scheduler.create_dag_run(dag) tis = ([ - TI(task1, dr1.execution_date), - TI(task2, dr1.execution_date), - TI(task1, dr2.execution_date), - TI(task2, dr2.execution_date) + TaskInstance(task1, dr1.execution_date), + TaskInstance(task2, dr1.execution_date), + TaskInstance(task1, dr2.execution_date), + TaskInstance(task2, dr2.execution_date) ]) for ti in tis: ti.state = State.SCHEDULED @@ -364,8 +364,8 @@ def test_find_executable_task_instances_in_default_pool(self): dr1 = scheduler.create_dag_run(dag) dr2 = scheduler.create_dag_run(dag) - ti1 = TI(task=t1, execution_date=dr1.execution_date) - ti2 = TI(task=t2, execution_date=dr2.execution_date) + ti1 = TaskInstance(task=t1, execution_date=dr1.execution_date) + ti2 = TaskInstance(task=t2, execution_date=dr2.execution_date) ti1.state = State.SCHEDULED ti2.state = State.SCHEDULED @@ -406,7 +406,7 @@ def test_nonexistent_pool(self): dr = scheduler.create_dag_run(dag) - ti = TI(task, dr.execution_date) + ti = TaskInstance(task, dr.execution_date) ti.state = State.SCHEDULED session.merge(ti) session.commit() @@ -450,9 +450,9 @@ def test_find_executable_task_instances_concurrency(self): dr2 = scheduler.create_dag_run(dag) dr3 = scheduler.create_dag_run(dag) - ti1 = TI(task1, dr1.execution_date) - ti2 = TI(task1, dr2.execution_date) - ti3 = TI(task1, dr3.execution_date) + ti1 = TaskInstance(task1, dr1.execution_date) + ti2 = TaskInstance(task1, dr2.execution_date) + ti3 = TaskInstance(task1, dr3.execution_date) ti1.state = State.RUNNING ti2.state = State.SCHEDULED ti3.state = State.SCHEDULED @@ -494,9 +494,9 @@ def test_find_executable_task_instances_concurrency_queued(self): session = settings.Session() dag_run = scheduler.create_dag_run(dag) - ti1 = TI(task1, dag_run.execution_date) - ti2 = TI(task2, dag_run.execution_date) - ti3 = TI(task3, dag_run.execution_date) + ti1 = TaskInstance(task1, dag_run.execution_date) + ti2 = TaskInstance(task2, dag_run.execution_date) + ti3 = TaskInstance(task3, dag_run.execution_date) ti1.state = State.RUNNING ti2.state = State.QUEUED ti3.state = State.SCHEDULED @@ -532,8 +532,8 @@ def test_find_executable_task_instances_task_concurrency(self): dr2 = scheduler.create_dag_run(dag) dr3 = scheduler.create_dag_run(dag) - ti1_1 = TI(task1, dr1.execution_date) - ti2 = TI(task2, dr1.execution_date) + ti1_1 = TaskInstance(task1, dr1.execution_date) + ti2 = TaskInstance(task2, dr1.execution_date) ti1_1.state = State.SCHEDULED ti2.state = State.SCHEDULED @@ -550,7 +550,7 @@ def test_find_executable_task_instances_task_concurrency(self): ti1_1.state = State.RUNNING ti2.state = State.RUNNING - ti1_2 = TI(task1, dr2.execution_date) + ti1_2 = TaskInstance(task1, dr2.execution_date) ti1_2.state = State.SCHEDULED session.merge(ti1_1) session.merge(ti2) @@ -565,7 +565,7 @@ def test_find_executable_task_instances_task_concurrency(self): self.assertEqual(1, len(res)) ti1_2.state = State.RUNNING - ti1_3 = TI(task1, dr3.execution_date) + ti1_3 = TaskInstance(task1, dr3.execution_date) ti1_3.state = State.SCHEDULED session.merge(ti1_2) session.merge(ti1_3) @@ -646,9 +646,9 @@ def test_change_state_for_executable_task_instances_no_tis_with_state(self): dr2 = scheduler.create_dag_run(dag) dr3 = scheduler.create_dag_run(dag) - ti1 = TI(task1, dr1.execution_date) - ti2 = TI(task1, dr2.execution_date) - ti3 = TI(task1, dr3.execution_date) + ti1 = TaskInstance(task1, dr1.execution_date) + ti2 = TaskInstance(task1, dr2.execution_date) + ti3 = TaskInstance(task1, dr3.execution_date) ti1.state = State.SCHEDULED ti2.state = State.SCHEDULED ti3.state = State.SCHEDULED @@ -678,9 +678,9 @@ def test_change_state_for_executable_task_instances_none_state(self): dr2 = scheduler.create_dag_run(dag) dr3 = scheduler.create_dag_run(dag) - ti1 = TI(task1, dr1.execution_date) - ti2 = TI(task1, dr2.execution_date) - ti3 = TI(task1, dr3.execution_date) + ti1 = TaskInstance(task1, dr1.execution_date) + ti2 = TaskInstance(task1, dr2.execution_date) + ti3 = TaskInstance(task1, dr3.execution_date) ti1.state = State.SCHEDULED ti2.state = State.QUEUED ti3.state = State.NONE @@ -712,7 +712,7 @@ def test_enqueue_task_instances_with_queued_state(self): dr1 = scheduler.create_dag_run(dag) - ti1 = TI(task1, dr1.execution_date) + ti1 = TaskInstance(task1, dr1.execution_date) session.merge(ti1) session.commit() @@ -732,7 +732,7 @@ def test_execute_task_instances_nothing(self): session = settings.Session() dr1 = scheduler.create_dag_run(dag) - ti1 = TI(task1, dr1.execution_date) + ti1 = TaskInstance(task1, dr1.execution_date) ti1.state = State.SCHEDULED session.merge(ti1) session.commit() @@ -757,8 +757,8 @@ def test_execute_task_instances(self): # create first dag run with 1 running and 1 queued dr1 = scheduler.create_dag_run(dag) - ti1 = TI(task1, dr1.execution_date) - ti2 = TI(task2, dr1.execution_date) + ti1 = TaskInstance(task1, dr1.execution_date) + ti2 = TaskInstance(task2, dr1.execution_date) ti1.refresh_from_db() ti2.refresh_from_db() ti1.state = State.RUNNING @@ -777,8 +777,8 @@ def test_execute_task_instances(self): # create second dag run dr2 = scheduler.create_dag_run(dag) - ti3 = TI(task1, dr2.execution_date) - ti4 = TI(task2, dr2.execution_date) + ti3 = TaskInstance(task1, dr2.execution_date) + ti4 = TaskInstance(task2, dr2.execution_date) ti3.refresh_from_db() ti4.refresh_from_db() # manually set to scheduled so we can pick them up @@ -828,8 +828,8 @@ def test_execute_task_instances_limit(self): tis = [] for _ in range(0, 4): dr = scheduler.create_dag_run(dag) - ti1 = TI(task1, dr.execution_date) - ti2 = TI(task2, dr.execution_date) + ti1 = TaskInstance(task1, dr.execution_date) + ti2 = TaskInstance(task2, dr.execution_date) tis.append(ti1) tis.append(ti2) ti1.refresh_from_db() @@ -886,7 +886,7 @@ def test_change_state_for_tis_without_dagrun(self): ti2.state = State.SCHEDULED session.commit() - ti3 = TI(dag3.get_task('dummy'), DEFAULT_DATE) + ti3 = TaskInstance(dag3.get_task('dummy'), DEFAULT_DATE) ti3.state = State.SCHEDULED session.merge(ti3) session.commit() @@ -958,11 +958,11 @@ def test_change_state_for_tasks_failed_to_execute(self): mock_logger.info.assert_not_called() # Tasks failed to execute with QUEUED state will be set to SCHEDULED state. - session.query(TI).delete() + session.query(TaskInstance).delete() session.commit() key = 'dag_id', 'task_id', DEFAULT_DATE, 1 test_executor.queued_tasks[key] = 'value' - ti = TI(task, DEFAULT_DATE) + ti = TaskInstance(task, DEFAULT_DATE) ti.state = State.QUEUED session.merge(ti) session.commit() @@ -973,7 +973,7 @@ def test_change_state_for_tasks_failed_to_execute(self): self.assertEqual(State.SCHEDULED, ti.state) # Tasks failed to execute with RUNNING state will not be set to SCHEDULED state. - session.query(TI).delete() + session.query(TaskInstance).delete() session.commit() ti.state = State.RUNNING @@ -1111,7 +1111,7 @@ def evaluate_dagrun( # test tasks for task_id, expected_state in expected_task_states.items(): task = dag.get_task(task_id) - ti = TI(task, ex_date) + ti = TaskInstance(task, ex_date) ti.refresh_from_db() self.assertEqual(ti.state, expected_state) @@ -1255,7 +1255,7 @@ def test_scheduler_start_date(self): # zero tasks ran self.assertEqual( - len(session.query(TI).filter(TI.dag_id == dag_id).all()), 0) + len(session.query(TaskInstance).filter(TaskInstance.dag_id == dag_id).all()), 0) session.commit() self.assertListEqual([], self.null_exec.sorted_tasks) @@ -1273,7 +1273,7 @@ def test_scheduler_start_date(self): # one task ran self.assertEqual( - len(session.query(TI).filter(TI.dag_id == dag_id).all()), 1) + len(session.query(TaskInstance).filter(TaskInstance.dag_id == dag_id).all()), 1) self.assertListEqual( [ ((dag.dag_id, 'dummy', DEFAULT_DATE, 1), State.SUCCESS), @@ -1290,7 +1290,7 @@ def test_scheduler_start_date(self): # still one task self.assertEqual( - len(session.query(TI).filter(TI.dag_id == dag_id).all()), 1) + len(session.query(TaskInstance).filter(TaskInstance.dag_id == dag_id).all()), 1) session.commit() self.assertListEqual([], self.null_exec.sorted_tasks) @@ -1309,9 +1309,9 @@ def test_scheduler_task_start_date(self): scheduler.run() session = settings.Session() - tiq = session.query(TI).filter(TI.dag_id == dag_id) - ti1s = tiq.filter(TI.task_id == 'dummy1').all() - ti2s = tiq.filter(TI.task_id == 'dummy2').all() + tiq = session.query(TaskInstance).filter(TaskInstance.dag_id == dag_id) + ti1s = tiq.filter(TaskInstance.task_id == 'dummy1').all() + ti2s = tiq.filter(TaskInstance.task_id == 'dummy2').all() self.assertEqual(len(ti1s), 0) self.assertEqual(len(ti2s), 2) for t in ti2s: @@ -1336,7 +1336,7 @@ def test_scheduler_multiprocessing(self): dag_id = 'test_start_date_scheduling' session = settings.Session() self.assertEqual( - len(session.query(TI).filter(TI.dag_id == dag_id).all()), 0) + len(session.query(TaskInstance).filter(TaskInstance.dag_id == dag_id).all()), 0) def test_scheduler_dagrun_once(self): """ @@ -1683,7 +1683,7 @@ def test_scheduler_max_active_runs_respected_after_clear(self): (dag.dag_id, dag_task1.task_id, DEFAULT_DATE, TRY_NUMBER) ) - @patch.object(TI, 'pool_full') + @patch.object(TaskInstance, 'pool_full') def test_scheduler_verify_pool_full(self, mock_pool_full): """ Test task instances not queued when pool is full @@ -1724,7 +1724,7 @@ def test_scheduler_verify_pool_full(self, mock_pool_full): # Recreated part of the scheduler here, to kick off tasks -> executor for ti_key in task_instances_list: task = dag.get_task(ti_key[1]) - ti = TI(task, ti_key[2]) + ti = TaskInstance(task, ti_key[2]) # Task starts out in the scheduled state. All tasks in the # scheduled state will be sent to the executor ti.state = State.SCHEDULED @@ -2135,7 +2135,7 @@ def run_with_error(task, ignore_ti_state=False): # At this point, scheduler has tried to schedule the task once and # heartbeated the executor once, which moved the state of the task from # SCHEDULED to QUEUED and then to SCHEDULED, to fail the task execution - # we need to ignore the TI state as SCHEDULED is not a valid state to start + # we need to ignore the TaskInstance state as SCHEDULED is not a valid state to start # executing task. run_with_error(ti, ignore_ti_state=True) self.assertEqual(ti.state, State.UP_FOR_RETRY) @@ -2185,8 +2185,8 @@ def test_retry_handling_job(self): scheduler.run() session = settings.Session() - ti = session.query(TI).filter(TI.dag_id == dag.dag_id, - TI.task_id == dag_task1.task_id).first() + ti = session.query(TaskInstance).filter(TaskInstance.dag_id == dag.dag_id, + TaskInstance.task_id == dag_task1.task_id).first() # make sure the counter has increased self.assertEqual(ti.try_number, 2) @@ -2215,7 +2215,7 @@ def test_dag_with_system_exit(self): scheduler.run() with create_session() as session: self.assertEqual( - len(session.query(TI).filter(TI.dag_id == dag_id).all()), 1) + len(session.query(TaskInstance).filter(TaskInstance.dag_id == dag_id).all()), 1) def test_dag_get_active_runs(self): """ @@ -2690,8 +2690,8 @@ def test_reset_orphaned_tasks_with_orphans(self): dr1_tis = [] dr2_tis = [] for i, (task, state) in enumerate(zip(tasks, states)): - ti1 = TI(task, dr1.execution_date) - ti2 = TI(task, dr2.execution_date) + ti1 = TaskInstance(task, dr1.execution_date) + ti2 = TaskInstance(task, dr2.execution_date) ti1.refresh_from_db() ti2.refresh_from_db() ti1.state = state diff --git a/tests/lineage/backend/test_atlas.py b/tests/lineage/backend/test_atlas.py index a174c6586db998..af0a1c4c7da872 100644 --- a/tests/lineage/backend/test_atlas.py +++ b/tests/lineage/backend/test_atlas.py @@ -20,10 +20,11 @@ import unittest from configparser import DuplicateSectionError +from airflow import DAG from airflow.configuration import AirflowConfigException, conf from airflow.lineage.backend.atlas import AtlasBackend from airflow.lineage.datasets import File -from airflow.models import DAG, TaskInstance as TI +from airflow.models import TaskInstance as TaskInstance from airflow.operators.dummy_operator import DummyOperator from airflow.utils import timezone from tests.compat import mock @@ -69,7 +70,7 @@ def test_lineage_send(self, atlas_mock): inlets={"datasets": inlets_d}, outlets={"datasets": outlets_d}) - ctx = {"ti": TI(task=op1, execution_date=DEFAULT_DATE)} + ctx = {"ti": TaskInstance(task=op1, execution_date=DEFAULT_DATE)} self.atlas.send_lineage(operator=op1, inlets=inlets_d, outlets=outlets_d, context=ctx) diff --git a/tests/lineage/test_lineage.py b/tests/lineage/test_lineage.py index 451d1f32eb35c9..b27f80ddc88f90 100644 --- a/tests/lineage/test_lineage.py +++ b/tests/lineage/test_lineage.py @@ -18,9 +18,10 @@ # under the License. import unittest +from airflow import DAG from airflow.lineage import apply_lineage, prepare_lineage from airflow.lineage.datasets import File -from airflow.models import DAG, TaskInstance as TI +from airflow.models import TaskInstance as TaskInstance from airflow.operators.dummy_operator import DummyOperator from airflow.utils import timezone from tests.compat import mock @@ -64,10 +65,10 @@ def test_lineage(self, _get_backend): op3.set_downstream(op4) op4.set_downstream(op5) - ctx1 = {"ti": TI(task=op1, execution_date=DEFAULT_DATE)} - ctx2 = {"ti": TI(task=op2, execution_date=DEFAULT_DATE)} - ctx3 = {"ti": TI(task=op3, execution_date=DEFAULT_DATE)} - ctx5 = {"ti": TI(task=op5, execution_date=DEFAULT_DATE)} + ctx1 = {"ti": TaskInstance(task=op1, execution_date=DEFAULT_DATE)} + ctx2 = {"ti": TaskInstance(task=op2, execution_date=DEFAULT_DATE)} + ctx3 = {"ti": TaskInstance(task=op3, execution_date=DEFAULT_DATE)} + ctx5 = {"ti": TaskInstance(task=op5, execution_date=DEFAULT_DATE)} func = mock.Mock() func.__name__ = 'foo' diff --git a/tests/models/test_baseoperator.py b/tests/models/test_baseoperator.py index 9d2b99b33b3709..af0a373676ad35 100644 --- a/tests/models/test_baseoperator.py +++ b/tests/models/test_baseoperator.py @@ -26,7 +26,8 @@ import jinja2 from parameterized import parameterized -from airflow.models import DAG, BaseOperator +from airflow import DAG +from airflow.models.baseoperator import BaseOperator from airflow.operators.dummy_operator import DummyOperator from airflow.utils.decorators import apply_defaults from tests.models import DEFAULT_DATE diff --git a/tests/models/test_cleartasks.py b/tests/models/test_cleartasks.py index d4acc0c8b5969e..aba8f743f7d3de 100644 --- a/tests/models/test_cleartasks.py +++ b/tests/models/test_cleartasks.py @@ -21,9 +21,9 @@ import os import unittest -from airflow import settings +from airflow import DAG, settings from airflow.configuration import conf -from airflow.models import DAG, TaskInstance as TI, XCom, clear_task_instances +from airflow.models import TaskInstance as TaskInstance, XCom, clear_task_instances from airflow.operators.dummy_operator import DummyOperator from airflow.utils import timezone from airflow.utils.db import create_session @@ -35,21 +35,21 @@ class TestClearTasks(unittest.TestCase): def tearDown(self): with create_session() as session: - session.query(TI).delete() + session.query(TaskInstance).delete() def test_clear_task_instances(self): dag = DAG('test_clear_task_instances', start_date=DEFAULT_DATE, end_date=DEFAULT_DATE + datetime.timedelta(days=10)) task0 = DummyOperator(task_id='0', owner='test', dag=dag) task1 = DummyOperator(task_id='1', owner='test', dag=dag, retries=2) - ti0 = TI(task=task0, execution_date=DEFAULT_DATE) - ti1 = TI(task=task1, execution_date=DEFAULT_DATE) + ti0 = TaskInstance(task=task0, execution_date=DEFAULT_DATE) + ti1 = TaskInstance(task=task1, execution_date=DEFAULT_DATE) ti0.run() ti1.run() with create_session() as session: - qry = session.query(TI).filter( - TI.dag_id == dag.dag_id).all() + qry = session.query(TaskInstance).filter( + TaskInstance.dag_id == dag.dag_id).all() clear_task_instances(qry, session, dag=dag) ti0.refresh_from_db() @@ -65,8 +65,8 @@ def test_clear_task_instances_without_task(self): end_date=DEFAULT_DATE + datetime.timedelta(days=10)) task0 = DummyOperator(task_id='task0', owner='test', dag=dag) task1 = DummyOperator(task_id='task1', owner='test', dag=dag, retries=2) - ti0 = TI(task=task0, execution_date=DEFAULT_DATE) - ti1 = TI(task=task1, execution_date=DEFAULT_DATE) + ti0 = TaskInstance(task=task0, execution_date=DEFAULT_DATE) + ti1 = TaskInstance(task=task1, execution_date=DEFAULT_DATE) ti0.run() ti1.run() @@ -76,8 +76,8 @@ def test_clear_task_instances_without_task(self): self.assertFalse(dag.has_task(task1.task_id)) with create_session() as session: - qry = session.query(TI).filter( - TI.dag_id == dag.dag_id).all() + qry = session.query(TaskInstance).filter( + TaskInstance.dag_id == dag.dag_id).all() clear_task_instances(qry, session) # When dag is None, max_tries will be maximum of original max_tries or try_number. @@ -94,14 +94,14 @@ def test_clear_task_instances_without_dag(self): end_date=DEFAULT_DATE + datetime.timedelta(days=10)) task0 = DummyOperator(task_id='task_0', owner='test', dag=dag) task1 = DummyOperator(task_id='task_1', owner='test', dag=dag, retries=2) - ti0 = TI(task=task0, execution_date=DEFAULT_DATE) - ti1 = TI(task=task1, execution_date=DEFAULT_DATE) + ti0 = TaskInstance(task=task0, execution_date=DEFAULT_DATE) + ti1 = TaskInstance(task=task1, execution_date=DEFAULT_DATE) ti0.run() ti1.run() with create_session() as session: - qry = session.query(TI).filter( - TI.dag_id == dag.dag_id).all() + qry = session.query(TaskInstance).filter( + TaskInstance.dag_id == dag.dag_id).all() clear_task_instances(qry, session) # When dag is None, max_tries will be maximum of original max_tries or try_number. @@ -117,7 +117,7 @@ def test_dag_clear(self): dag = DAG('test_dag_clear', start_date=DEFAULT_DATE, end_date=DEFAULT_DATE + datetime.timedelta(days=10)) task0 = DummyOperator(task_id='test_dag_clear_task_0', owner='test', dag=dag) - ti0 = TI(task=task0, execution_date=DEFAULT_DATE) + ti0 = TaskInstance(task=task0, execution_date=DEFAULT_DATE) # Next try to run will be try 1 self.assertEqual(ti0.try_number, 1) ti0.run() @@ -130,7 +130,7 @@ def test_dag_clear(self): task1 = DummyOperator(task_id='test_dag_clear_task_1', owner='test', dag=dag, retries=2) - ti1 = TI(task=task1, execution_date=DEFAULT_DATE) + ti1 = TaskInstance(task=task1, execution_date=DEFAULT_DATE) self.assertEqual(ti1.max_tries, 2) ti1.try_number = 1 # Next try will be 2 @@ -156,9 +156,8 @@ def test_dags_clear(self): for i in range(num_of_dags): dag = DAG('test_dag_clear_' + str(i), start_date=DEFAULT_DATE, end_date=DEFAULT_DATE + datetime.timedelta(days=10)) - ti = TI(task=DummyOperator(task_id='test_task_clear_' + str(i), owner='test', - dag=dag), - execution_date=DEFAULT_DATE) + ti = TaskInstance(task=DummyOperator(task_id='test_task_clear_' + str(i), owner='test', dag=dag), + execution_date=DEFAULT_DATE) dags.append(dag) tis.append(ti) @@ -220,8 +219,8 @@ def test_operator_clear(self): t2.set_upstream(t1) - ti1 = TI(task=t1, execution_date=DEFAULT_DATE) - ti2 = TI(task=t2, execution_date=DEFAULT_DATE) + ti1 = TaskInstance(task=t1, execution_date=DEFAULT_DATE) + ti2 = TaskInstance(task=t2, execution_date=DEFAULT_DATE) ti2.run() # Dependency not met self.assertEqual(ti2.try_number, 1) diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py index 8cfdbc5e0066b1..33f2ac3a7e20f0 100644 --- a/tests/models/test_dag.py +++ b/tests/models/test_dag.py @@ -29,10 +29,10 @@ import pendulum -from airflow import models, settings +from airflow import DAG, AirflowException, models, settings from airflow.configuration import conf -from airflow.exceptions import AirflowDagCycleException, AirflowException, DuplicateTaskIdFound -from airflow.models import DAG, DagModel, TaskInstance as TI +from airflow.exceptions import AirflowDagCycleException, DuplicateTaskIdFound +from airflow.models import DagModel, TaskInstance from airflow.operators.bash_operator import BashOperator from airflow.operators.dummy_operator import DummyOperator from airflow.operators.subdag_operator import SubDagOperator @@ -383,13 +383,13 @@ def test_get_num_task_instances(self): test_dag = DAG(dag_id=test_dag_id, start_date=DEFAULT_DATE) test_task = DummyOperator(task_id=test_task_id, dag=test_dag) - ti1 = TI(task=test_task, execution_date=DEFAULT_DATE) + ti1 = TaskInstance(task=test_task, execution_date=DEFAULT_DATE) ti1.state = None - ti2 = TI(task=test_task, execution_date=DEFAULT_DATE + datetime.timedelta(days=1)) + ti2 = TaskInstance(task=test_task, execution_date=DEFAULT_DATE + datetime.timedelta(days=1)) ti2.state = State.RUNNING - ti3 = TI(task=test_task, execution_date=DEFAULT_DATE + datetime.timedelta(days=2)) + ti3 = TaskInstance(task=test_task, execution_date=DEFAULT_DATE + datetime.timedelta(days=2)) ti3.state = State.QUEUED - ti4 = TI(task=test_task, execution_date=DEFAULT_DATE + datetime.timedelta(days=3)) + ti4 = TaskInstance(task=test_task, execution_date=DEFAULT_DATE + datetime.timedelta(days=3)) ti4.state = State.RUNNING session = settings.Session() session.merge(ti1) diff --git a/tests/models/test_dagbag.py b/tests/models/test_dagbag.py index 67dd9f16955a9c..f5cc073ca4bd1f 100644 --- a/tests/models/test_dagbag.py +++ b/tests/models/test_dagbag.py @@ -29,7 +29,7 @@ import airflow.example_dags from airflow import models from airflow.configuration import conf -from airflow.models import DagBag, DagModel, TaskInstance as TI +from airflow.models import DagBag, DagModel, TaskInstance as TaskInstance from airflow.utils.dag_processing import SimpleTaskInstance from airflow.utils.db import create_session from airflow.utils.state import State @@ -315,7 +315,7 @@ def validate_dags(self, expected_parent_dag, actual_found_dags, actual_dagbag, def test_load_subdags(self): # Define Dag to load def standard_subdag(): - from airflow.models import DAG + from airflow import DAG from airflow.operators.dummy_operator import DummyOperator from airflow.operators.subdag_operator import SubDagOperator import datetime @@ -370,7 +370,7 @@ def subdag_1(): # Define Dag to load def nested_subdags(): - from airflow.models import DAG + from airflow import DAG from airflow.operators.dummy_operator import DummyOperator from airflow.operators.subdag_operator import SubDagOperator import datetime @@ -468,7 +468,7 @@ def test_skip_cycle_dags(self): # Define Dag to load def basic_cycle(): - from airflow.models import DAG + from airflow import DAG from airflow.operators.dummy_operator import DummyOperator import datetime DAG_NAME = 'cycle_dag' @@ -501,7 +501,7 @@ def basic_cycle(): # Define Dag to load def nested_subdag_cycle(): - from airflow.models import DAG + from airflow import DAG from airflow.operators.dummy_operator import DummyOperator from airflow.operators.subdag_operator import SubDagOperator import datetime @@ -603,18 +603,18 @@ def test_process_file_with_none(self): self.assertEqual([], dagbag.process_file(None)) - @patch.object(TI, 'handle_failure') + @patch.object(TaskInstance, 'handle_failure') def test_kill_zombies(self, mock_ti_handle_failure): """ Test that kill zombies call TIs failure handler with proper context """ dagbag = models.DagBag(dag_folder=self.empty_dir, include_examples=True) with create_session() as session: - session.query(TI).delete() + session.query(TaskInstance).delete() dag = dagbag.get_dag('example_branch_operator') task = dag.get_task(task_id='run_this_first') - ti = TI(task, DEFAULT_DATE, State.RUNNING) + ti = TaskInstance(task, DEFAULT_DATE, State.RUNNING) session.add(ti) session.commit() diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py index 0d9e920d791fd3..10b9e900254667 100644 --- a/tests/models/test_dagrun.py +++ b/tests/models/test_dagrun.py @@ -20,9 +20,9 @@ import datetime import unittest -from airflow import models, settings +from airflow import DAG, models, settings from airflow.jobs import BackfillJob -from airflow.models import DAG, DagRun, TaskInstance as TI, clear_task_instances +from airflow.models import DagRun, TaskInstance as TaskInstance, clear_task_instances from airflow.operators.dummy_operator import DummyOperator from airflow.operators.python_operator import ShortCircuitOperator from airflow.utils import timezone @@ -71,11 +71,11 @@ def test_clear_task_instances_for_backfill_dagrun(self): self.create_dag_run(dag, execution_date=now, is_backfill=True) task0 = DummyOperator(task_id='backfill_task_0', owner='test', dag=dag) - ti0 = TI(task=task0, execution_date=now) + ti0 = TaskInstance(task=task0, execution_date=now) ti0.run() - qry = session.query(TI).filter( - TI.dag_id == dag.dag_id).all() + qry = session.query(TaskInstance).filter( + TaskInstance.dag_id == dag.dag_id).all() clear_task_instances(qry, session) session.commit() ti0.refresh_from_db() diff --git a/tests/models/test_pool.py b/tests/models/test_pool.py index 22db8f6eb84527..37ebecb76b0b7f 100644 --- a/tests/models/test_pool.py +++ b/tests/models/test_pool.py @@ -19,10 +19,9 @@ import unittest -from airflow import settings -from airflow.models import DAG +from airflow import DAG, settings from airflow.models.pool import Pool -from airflow.models.taskinstance import TaskInstance as TI +from airflow.models.taskinstance import TaskInstance as TaskInstance from airflow.operators.dummy_operator import DummyOperator from airflow.utils import timezone from airflow.utils.state import State @@ -48,8 +47,8 @@ def test_open_slots(self): start_date=DEFAULT_DATE, ) t1 = DummyOperator(task_id='dummy1', dag=dag, pool='test_pool') t2 = DummyOperator(task_id='dummy2', dag=dag, pool='test_pool') - ti1 = TI(task=t1, execution_date=DEFAULT_DATE) - ti2 = TI(task=t2, execution_date=DEFAULT_DATE) + ti1 = TaskInstance(task=t1, execution_date=DEFAULT_DATE) + ti2 = TaskInstance(task=t2, execution_date=DEFAULT_DATE) ti1.state = State.RUNNING ti2.state = State.QUEUED @@ -72,8 +71,8 @@ def test_infinite_slots(self): start_date=DEFAULT_DATE, ) t1 = DummyOperator(task_id='dummy1', dag=dag, pool='test_pool') t2 = DummyOperator(task_id='dummy2', dag=dag, pool='test_pool') - ti1 = TI(task=t1, execution_date=DEFAULT_DATE) - ti2 = TI(task=t2, execution_date=DEFAULT_DATE) + ti1 = TaskInstance(task=t1, execution_date=DEFAULT_DATE) + ti2 = TaskInstance(task=t2, execution_date=DEFAULT_DATE) ti1.state = State.RUNNING ti2.state = State.QUEUED @@ -98,8 +97,8 @@ def test_default_pool_open_slots(self): start_date=DEFAULT_DATE, ) t1 = DummyOperator(task_id='dummy1', dag=dag) t2 = DummyOperator(task_id='dummy2', dag=dag) - ti1 = TI(task=t1, execution_date=DEFAULT_DATE) - ti2 = TI(task=t2, execution_date=DEFAULT_DATE) + ti1 = TaskInstance(task=t1, execution_date=DEFAULT_DATE) + ti2 = TaskInstance(task=t2, execution_date=DEFAULT_DATE) ti1.state = State.RUNNING ti2.state = State.QUEUED diff --git a/tests/models/test_serialized_dag.py b/tests/models/test_serialized_dag.py index 05d61b65a1072e..009e91960c922f 100644 --- a/tests/models/test_serialized_dag.py +++ b/tests/models/test_serialized_dag.py @@ -22,7 +22,8 @@ import unittest from airflow import example_dags as example_dags_module -from airflow.models import DagBag, SerializedDagModel as SDM +from airflow.models import DagBag +from airflow.models.serialized_dag import SerializedDagModel as SDM from airflow.serialization.serialized_dag import SerializedDAG from airflow.utils import db diff --git a/tests/models/test_skipmixin.py b/tests/models/test_skipmixin.py index 148476112a63d7..a5b21f3fa3b5a4 100644 --- a/tests/models/test_skipmixin.py +++ b/tests/models/test_skipmixin.py @@ -23,8 +23,8 @@ import pendulum -from airflow import settings -from airflow.models import DAG, SkipMixin, TaskInstance as TI +from airflow import DAG, settings +from airflow.models import SkipMixin, TaskInstance as TaskInstance from airflow.operators.dummy_operator import DummyOperator from airflow.utils import timezone from airflow.utils.state import State @@ -55,12 +55,12 @@ def test_skip(self, mock_now): tasks=tasks, session=session) - session.query(TI).filter( - TI.dag_id == 'dag', - TI.task_id == 'task', - TI.state == State.SKIPPED, - TI.start_date == now, - TI.end_date == now, + session.query(TaskInstance).filter( + TaskInstance.dag_id == 'dag', + TaskInstance.task_id == 'task', + TaskInstance.state == State.SKIPPED, + TaskInstance.start_date == now, + TaskInstance.end_date == now, ).one() @patch('airflow.utils.timezone.utcnow') @@ -80,12 +80,12 @@ def test_skip_none_dagrun(self, mock_now): tasks=tasks, session=session) - session.query(TI).filter( - TI.dag_id == 'dag', - TI.task_id == 'task', - TI.state == State.SKIPPED, - TI.start_date == now, - TI.end_date == now, + session.query(TaskInstance).filter( + TaskInstance.dag_id == 'dag', + TaskInstance.task_id == 'task', + TaskInstance.state == State.SKIPPED, + TaskInstance.start_date == now, + TaskInstance.end_date == now, ).one() def test_skip_none_tasks(self): diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index 2ad2e52925d79d..45d5ba2b183645 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -29,11 +29,11 @@ from parameterized import param, parameterized from sqlalchemy.orm.session import Session -from airflow import models, settings +from airflow import DAG, AirflowException, models, settings from airflow.configuration import conf from airflow.contrib.sensors.python_sensor import PythonSensor -from airflow.exceptions import AirflowException, AirflowSkipException -from airflow.models import DAG, DagRun, Pool, TaskFail, TaskInstance as TI, TaskReschedule +from airflow.exceptions import AirflowSkipException +from airflow.models import DagRun, Pool, TaskFail, TaskInstance as TaskInstance, TaskReschedule from airflow.operators.bash_operator import BashOperator from airflow.operators.dummy_operator import DummyOperator from airflow.operators.python_operator import PythonOperator @@ -109,7 +109,7 @@ def test_timezone_awareness(self): # check ti without dag (just for bw compat) op_no_dag = DummyOperator(task_id='op_no_dag') - ti = TI(task=op_no_dag, execution_date=NAIVE_DATETIME) + ti = TaskInstance(task=op_no_dag, execution_date=NAIVE_DATETIME) self.assertEqual(ti.execution_date, DEFAULT_DATE) @@ -117,7 +117,7 @@ def test_timezone_awareness(self): dag = DAG('dag', start_date=DEFAULT_DATE) op1 = DummyOperator(task_id='op_1') dag.add_task(op1) - ti = TI(task=op1, execution_date=NAIVE_DATETIME) + ti = TaskInstance(task=op1, execution_date=NAIVE_DATETIME) self.assertEqual(ti.execution_date, DEFAULT_DATE) @@ -125,7 +125,7 @@ def test_timezone_awareness(self): tz = pendulum.timezone("Europe/Amsterdam") execution_date = timezone.datetime(2016, 1, 1, 1, 0, 0, tzinfo=tz) utc_date = timezone.convert_to_utc(execution_date) - ti = TI(task=op1, execution_date=execution_date) + ti = TaskInstance(task=op1, execution_date=execution_date) self.assertEqual(ti.execution_date, utc_date) def test_task_naive_datetime(self): @@ -232,8 +232,8 @@ def test_requeue_over_dag_concurrency(self, mock_concurrency_reached): max_active_runs=1, concurrency=2) task = DummyOperator(task_id='test_requeue_over_dag_concurrency_op', dag=dag) - ti = TI(task=task, execution_date=timezone.utcnow(), state=State.QUEUED) - # TI.run() will sync from DB before validating deps. + ti = TaskInstance(task=task, execution_date=timezone.utcnow(), state=State.QUEUED) + # TaskInstance.run() will sync from DB before validating deps. with create_session() as session: session.add(ti) session.commit() @@ -246,8 +246,8 @@ def test_requeue_over_task_concurrency(self): task = DummyOperator(task_id='test_requeue_over_task_concurrency_op', dag=dag, task_concurrency=0) - ti = TI(task=task, execution_date=timezone.utcnow(), state=State.QUEUED) - # TI.run() will sync from DB before validating deps. + ti = TaskInstance(task=task, execution_date=timezone.utcnow(), state=State.QUEUED) + # TaskInstance.run() will sync from DB before validating deps. with create_session() as session: session.add(ti) session.commit() @@ -260,8 +260,8 @@ def test_requeue_over_pool_concurrency(self): task = DummyOperator(task_id='test_requeue_over_pool_concurrency_op', dag=dag, task_concurrency=0) - ti = TI(task=task, execution_date=timezone.utcnow(), state=State.QUEUED) - # TI.run() will sync from DB before validating deps. + ti = TaskInstance(task=task, execution_date=timezone.utcnow(), state=State.QUEUED) + # TaskInstance.run() will sync from DB before validating deps. with create_session() as session: pool = session.query(Pool).filter(Pool.pool == 'test_pool').one() pool.slots = 0 @@ -280,7 +280,7 @@ def test_not_requeue_non_requeueable_task_instance(self): pool='test_pool', owner='airflow', start_date=timezone.datetime(2016, 2, 1, 0, 0, 0)) - ti = TI( + ti = TaskInstance( task=task, execution_date=timezone.utcnow(), state=State.QUEUED) with create_session() as session: session.add(ti) @@ -322,9 +322,9 @@ def test_mark_non_runnable_task_as_success(self): pool='test_pool', owner='airflow', start_date=timezone.datetime(2016, 2, 1, 0, 0, 0)) - ti = TI( + ti = TaskInstance( task=task, execution_date=timezone.utcnow(), state=non_runnable_state) - # TI.run() will sync from DB before validating deps. + # TaskInstance.run() will sync from DB before validating deps. with create_session() as session: session.add(ti) session.commit() @@ -339,7 +339,7 @@ def test_run_pooling_task(self): task = DummyOperator(task_id='test_run_pooling_task_op', dag=dag, pool='test_pool', owner='airflow', start_date=timezone.datetime(2016, 2, 1, 0, 0, 0)) - ti = TI( + ti = TaskInstance( task=task, execution_date=timezone.utcnow()) ti.run() @@ -355,7 +355,7 @@ def test_ti_updates_with_task(self, session=None): task = DummyOperator(task_id='test_run_pooling_task_op', dag=dag, owner='airflow', executor_config={'foo': 'bar'}, start_date=timezone.datetime(2016, 2, 1, 0, 0, 0)) - ti = TI( + ti = TaskInstance( task=task, execution_date=timezone.utcnow()) ti.run(session=session) @@ -366,7 +366,7 @@ def test_ti_updates_with_task(self, session=None): executor_config={'bar': 'baz'}, start_date=timezone.datetime(2016, 2, 1, 0, 0, 0)) - ti = TI( + ti = TaskInstance( task=task2, execution_date=timezone.utcnow()) ti.run(session=session) tis = dag.get_task_instances() @@ -385,7 +385,7 @@ def test_run_pooling_task_with_mark_success(self): pool='test_pool', owner='airflow', start_date=timezone.datetime(2016, 2, 1, 0, 0, 0)) - ti = TI( + ti = TaskInstance( task=task, execution_date=timezone.utcnow()) ti.run(mark_success=True) self.assertEqual(ti.state, State.SUCCESS) @@ -406,7 +406,7 @@ def raise_skip_exception(): python_callable=raise_skip_exception, owner='airflow', start_date=timezone.datetime(2016, 2, 1, 0, 0, 0)) - ti = TI( + ti = TaskInstance( task=task, execution_date=timezone.utcnow()) ti.run() self.assertEqual(State.SKIPPED, ti.state) @@ -431,7 +431,7 @@ def run_with_error(ti): except AirflowException: pass - ti = TI( + ti = TaskInstance( task=task, execution_date=timezone.utcnow()) self.assertEqual(ti.try_number, 1) @@ -469,7 +469,7 @@ def run_with_error(ti): except AirflowException: pass - ti = TI( + ti = TaskInstance( task=task, execution_date=timezone.utcnow()) self.assertEqual(ti.try_number, 1) @@ -485,7 +485,7 @@ def run_with_error(ti): self.assertEqual(ti._try_number, 2) self.assertEqual(ti.try_number, 3) - # Clear the TI state since you can't run a task with a FAILED state without + # Clear the TaskInstance state since you can't run a task with a FAILED state without # clearing it first dag.clear() @@ -517,7 +517,7 @@ def test_next_retry_datetime(self): dag=dag, owner='airflow', start_date=timezone.datetime(2016, 2, 1, 0, 0, 0)) - ti = TI( + ti = TaskInstance( task=task, execution_date=DEFAULT_DATE) ti.end_date = pendulum.instance(timezone.utcnow()) @@ -561,7 +561,7 @@ def test_next_retry_datetime_short_intervals(self): dag=dag, owner='airflow', start_date=timezone.datetime(2016, 2, 1, 0, 0, 0)) - ti = TI( + ti = TaskInstance( task=task, execution_date=DEFAULT_DATE) ti.end_date = pendulum.instance(timezone.utcnow()) @@ -570,7 +570,7 @@ def test_next_retry_datetime_short_intervals(self): period = ti.end_date.add(seconds=1) - ti.end_date.add(seconds=15) self.assertTrue(dt in period) - @patch.object(TI, 'pool_full') + @patch.object(TaskInstance, 'pool_full') def test_reschedule_handling(self, mock_pool_full): """ Test that task reschedules are handled properly @@ -597,7 +597,7 @@ def callable(): pool='test_pool', start_date=timezone.datetime(2016, 2, 1, 0, 0, 0)) - ti = TI(task=task, execution_date=timezone.utcnow()) + ti = TaskInstance(task=task, execution_date=timezone.utcnow()) self.assertEqual(ti._try_number, 0) self.assertEqual(ti.try_number, 1) @@ -666,7 +666,7 @@ def run_ti_and_assert(run_date, expected_start_date, expected_end_date, done, fail = True, False run_ti_and_assert(date4, date3, date4, 60, State.SUCCESS, 3, 0) - @patch.object(TI, 'pool_full') + @patch.object(TaskInstance, 'pool_full') def test_reschedule_handling_clear_reschedules(self, mock_pool_full): """ Test that task reschedules clearing are handled properly @@ -693,7 +693,7 @@ def callable(): pool='test_pool', start_date=timezone.datetime(2016, 2, 1, 0, 0, 0)) - ti = TI(task=task, execution_date=timezone.utcnow()) + ti = TaskInstance(task=task, execution_date=timezone.utcnow()) self.assertEqual(ti._try_number, 0) self.assertEqual(ti.try_number, 1) @@ -746,7 +746,7 @@ def test_depends_on_past(self): run_date = task.start_date + datetime.timedelta(days=5) - ti = TI(task, run_date) + ti = TaskInstance(task, run_date) # depends_on_past prevents the run task.run(start_date=run_date, end_date=run_date) @@ -820,7 +820,7 @@ def test_check_task_dependencies(self, trigger_rule, successes, skipped, task.set_downstream(downstream) run_date = task.start_date + datetime.timedelta(days=5) - ti = TI(downstream, run_date) + ti = TaskInstance(downstream, run_date) dep_results = TriggerRuleDep()._evaluate_trigger_rule( ti=ti, successes=successes, @@ -846,12 +846,12 @@ def test_xcom_pull(self): # Push a value task1 = DummyOperator(task_id='test_xcom_1', dag=dag, owner='airflow') - ti1 = TI(task=task1, execution_date=exec_date) + ti1 = TaskInstance(task=task1, execution_date=exec_date) ti1.xcom_push(key='foo', value='bar') # Push another value with the same key (but by a different task) task2 = DummyOperator(task_id='test_xcom_2', dag=dag, owner='airflow') - ti2 = TI(task=task2, execution_date=exec_date) + ti2 = TaskInstance(task=task2, execution_date=exec_date) ti2.xcom_push(key='foo', value='baz') # Pull with no arguments @@ -886,7 +886,7 @@ def test_xcom_pull_after_success(self): owner='airflow', start_date=timezone.datetime(2016, 6, 2, 0, 0, 0)) exec_date = timezone.utcnow() - ti = TI( + ti = TaskInstance( task=task, execution_date=exec_date) ti.run(mark_success=True) ti.xcom_push(key=key, value=value) @@ -920,14 +920,14 @@ def test_xcom_pull_different_execution_date(self): owner='airflow', start_date=timezone.datetime(2016, 6, 2, 0, 0, 0)) exec_date = timezone.utcnow() - ti = TI( + ti = TaskInstance( task=task, execution_date=exec_date) ti.run(mark_success=True) ti.xcom_push(key=key, value=value) self.assertEqual(ti.xcom_pull(task_ids='test_xcom', key=key), value) ti.run() exec_date += datetime.timedelta(days=1) - ti = TI( + ti = TaskInstance( task=task, execution_date=exec_date) ti.run() # We have set a new execution date (and did not pass in @@ -957,7 +957,7 @@ def test_xcom_push_flag(self): owner='airflow', start_date=datetime.datetime(2017, 1, 1) ) - ti = TI(task=task, execution_date=datetime.datetime(2017, 1, 1)) + ti = TaskInstance(task=task, execution_date=datetime.datetime(2017, 1, 1)) ti.run() self.assertEqual( ti.xcom_pull( @@ -987,7 +987,7 @@ def post_execute(self, context, result): python_callable=lambda: 'error', owner='airflow', start_date=timezone.datetime(2017, 2, 1)) - ti = TI(task=task, execution_date=timezone.utcnow()) + ti = TaskInstance(task=task, execution_date=timezone.utcnow()) with self.assertRaises(TestError): ti.run() @@ -995,7 +995,7 @@ def post_execute(self, context, result): def test_check_and_change_state_before_execution(self): dag = models.DAG(dag_id='test_check_and_change_state_before_execution') task = DummyOperator(task_id='task', dag=dag, start_date=DEFAULT_DATE) - ti = TI( + ti = TaskInstance( task=task, execution_date=timezone.utcnow()) self.assertEqual(ti._try_number, 0) self.assertTrue(ti._check_and_change_state_before_execution()) @@ -1008,7 +1008,7 @@ def test_check_and_change_state_before_execution_dep_not_met(self): task = DummyOperator(task_id='task', dag=dag, start_date=DEFAULT_DATE) task2 = DummyOperator(task_id='task2', dag=dag, start_date=DEFAULT_DATE) task >> task2 - ti = TI( + ti = TaskInstance( task=task2, execution_date=timezone.utcnow()) self.assertFalse(ti._check_and_change_state_before_execution()) @@ -1018,7 +1018,7 @@ def test_try_number(self): """ dag = models.DAG(dag_id='test_check_and_change_state_before_execution') task = DummyOperator(task_id='task', dag=dag, start_date=DEFAULT_DATE) - ti = TI(task=task, execution_date=timezone.utcnow()) + ti = TaskInstance(task=task, execution_date=timezone.utcnow()) self.assertEqual(1, ti.try_number) ti.try_number = 2 ti.state = State.RUNNING @@ -1034,9 +1034,9 @@ def test_get_num_running_task_instances(self): task = DummyOperator(task_id='task', dag=dag, start_date=DEFAULT_DATE) task2 = DummyOperator(task_id='task', dag=dag2, start_date=DEFAULT_DATE) - ti1 = TI(task=task, execution_date=DEFAULT_DATE) - ti2 = TI(task=task, execution_date=DEFAULT_DATE + datetime.timedelta(days=1)) - ti3 = TI(task=task2, execution_date=DEFAULT_DATE) + ti1 = TaskInstance(task=task, execution_date=DEFAULT_DATE) + ti2 = TaskInstance(task=task, execution_date=DEFAULT_DATE + datetime.timedelta(days=1)) + ti3 = TaskInstance(task=task2, execution_date=DEFAULT_DATE) ti1.state = State.RUNNING ti2.state = State.QUEUED ti3.state = State.RUNNING @@ -1053,7 +1053,7 @@ def test_get_num_running_task_instances(self): # now = pendulum.now('Europe/Brussels') # dag = DAG('dag', start_date=DEFAULT_DATE) # task = DummyOperator(task_id='op', dag=dag) - # ti = TI(task=task, execution_date=now) + # ti = TaskInstance(task=task, execution_date=now) # d = urllib.parse.parse_qs( # urllib.parse.urlparse(ti.log_url).query, # keep_blank_values=True, strict_parsing=True) @@ -1064,7 +1064,7 @@ def test_get_num_running_task_instances(self): def test_log_url(self): dag = DAG('dag', start_date=DEFAULT_DATE) task = DummyOperator(task_id='op', dag=dag) - ti = TI(task=task, execution_date=datetime.datetime(2018, 1, 1)) + ti = TaskInstance(task=task, execution_date=datetime.datetime(2018, 1, 1)) expected_url = ( 'http://localhost:8080/log?' @@ -1078,7 +1078,7 @@ def test_mark_success_url(self): now = pendulum.now('Europe/Brussels') dag = DAG('dag', start_date=DEFAULT_DATE) task = DummyOperator(task_id='op', dag=dag) - ti = TI(task=task, execution_date=now) + ti = TaskInstance(task=task, execution_date=now) d = urllib.parse.parse_qs( urllib.parse.urlparse(ti.mark_success_url).query, keep_blank_values=True, strict_parsing=True) @@ -1088,7 +1088,7 @@ def test_mark_success_url(self): def test_overwrite_params_with_dag_run_conf(self): task = DummyOperator(task_id='op') - ti = TI(task=task, execution_date=datetime.datetime.now()) + ti = TaskInstance(task=task, execution_date=datetime.datetime.now()) dag_run = DagRun() dag_run.conf = {"override": True} params = {"override": False} @@ -1099,7 +1099,7 @@ def test_overwrite_params_with_dag_run_conf(self): def test_overwrite_params_with_dag_run_none(self): task = DummyOperator(task_id='op') - ti = TI(task=task, execution_date=datetime.datetime.now()) + ti = TaskInstance(task=task, execution_date=datetime.datetime.now()) params = {"override": False} ti.overwrite_params_with_dag_run_conf(params, None) @@ -1108,7 +1108,7 @@ def test_overwrite_params_with_dag_run_none(self): def test_overwrite_params_with_dag_run_conf_none(self): task = DummyOperator(task_id='op') - ti = TI(task=task, execution_date=datetime.datetime.now()) + ti = TaskInstance(task=task, execution_date=datetime.datetime.now()) params = {"override": False} dag_run = DagRun() @@ -1126,7 +1126,7 @@ def test_email_alert(self, mock_send_email): start_date=DEFAULT_DATE, email='to') - ti = TI(task=task, execution_date=datetime.datetime.now()) + ti = TaskInstance(task=task, execution_date=datetime.datetime.now()) try: ti.run() @@ -1149,7 +1149,7 @@ def test_email_alert_with_config(self, mock_send_email): start_date=DEFAULT_DATE, email='to') - ti = TI( + ti = TaskInstance( task=task, execution_date=datetime.datetime.now()) conf.set('email', 'subject_template', '/subject/path') @@ -1169,7 +1169,7 @@ def test_email_alert_with_config(self, mock_send_email): def test_set_duration(self): task = DummyOperator(task_id='op', email='test@test.test') - ti = TI( + ti = TaskInstance( task=task, execution_date=datetime.datetime.now(), ) @@ -1180,7 +1180,7 @@ def test_set_duration(self): def test_set_duration_empty_dates(self): task = DummyOperator(task_id='op', email='test@test.test') - ti = TI(task=task, execution_date=datetime.datetime.now()) + ti = TaskInstance(task=task, execution_date=datetime.datetime.now()) ti.set_duration() self.assertIsNone(ti.duration) @@ -1196,10 +1196,10 @@ def wrap_task_instance(self, ti): def success_handler(self, context): # pylint: disable=unused-argument self.callback_ran = True session = settings.Session() - temp_instance = session.query(TI).filter( - TI.task_id == self.task_id).filter( - TI.dag_id == self.dag_id).filter( - TI.execution_date == self.execution_date).one() + temp_instance = session.query(TaskInstance).filter( + TaskInstance.task_id == self.task_id).filter( + TaskInstance.dag_id == self.dag_id).filter( + TaskInstance.execution_date == self.execution_date).one() self.task_state_in_callback = temp_instance.state cw = CallbackWrapper() @@ -1207,7 +1207,7 @@ def success_handler(self, context): # pylint: disable=unused-argument end_date=DEFAULT_DATE + datetime.timedelta(days=10)) task = DummyOperator(task_id='op', email='test@test.test', on_success_callback=cw.success_handler, dag=dag) - ti = TI(task=task, execution_date=datetime.datetime.now()) + ti = TaskInstance(task=task, execution_date=datetime.datetime.now()) ti.state = State.RUNNING session = settings.Session() session.merge(ti) @@ -1226,7 +1226,7 @@ def _test_previous_dates_setup(schedule_interval: Union[str, datetime.timedelta, dag = models.DAG(dag_id=dag_id, schedule_interval=schedule_interval, catchup=catchup) task = DummyOperator(task_id='task', dag=dag, start_date=DEFAULT_DATE) - def get_test_ti(session, execution_date: pendulum.datetime, state: str) -> TI: + def get_test_ti(session, execution_date: pendulum.datetime, state: str) -> TaskInstance: dag.create_dagrun( run_id='scheduled__{}'.format(execution_date.to_iso8601_string()), state=state, @@ -1234,7 +1234,7 @@ def get_test_ti(session, execution_date: pendulum.datetime, state: str) -> TI: start_date=pendulum.utcnow(), session=session ) - ti = TI(task=task, execution_date=execution_date) + ti = TaskInstance(task=task, execution_date=execution_date) ti.set_state(state=State.SUCCESS, session=session) return ti @@ -1341,7 +1341,7 @@ def test_pendulum_template_dates(self): start_date=timezone.datetime(2016, 6, 1, 0, 0, 0)) task = DummyOperator(task_id='test_pendulum_template_dates_task', dag=dag) - ti = TI(task=task, execution_date=timezone.utcnow()) + ti = TaskInstance(task=task, execution_date=timezone.utcnow()) template_context = ti.get_template_context() diff --git a/tests/operators/test_branch_operator.py b/tests/operators/test_branch_operator.py index bde83d24cbcab9..489c5d2292d8f4 100644 --- a/tests/operators/test_branch_operator.py +++ b/tests/operators/test_branch_operator.py @@ -20,7 +20,8 @@ import datetime import unittest -from airflow.models import DAG, DagRun, TaskInstance as TI +from airflow import DAG +from airflow.models import DagRun, TaskInstance as TaskInstance from airflow.operators.branch_operator import BaseBranchOperator from airflow.operators.dummy_operator import DummyOperator from airflow.utils import timezone @@ -48,7 +49,7 @@ def setUpClass(cls): with create_session() as session: session.query(DagRun).delete() - session.query(TI).delete() + session.query(TaskInstance).delete() def setUp(self): self.dag = DAG('branch_operator_test', @@ -67,7 +68,7 @@ def tearDown(self): with create_session() as session: session.query(DagRun).delete() - session.query(TI).delete() + session.query(TaskInstance).delete() def test_without_dag_run(self): """This checks the defensive against non existent tasks in a dag run""" @@ -79,9 +80,9 @@ def test_without_dag_run(self): self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) with create_session() as session: - tis = session.query(TI).filter( - TI.dag_id == self.dag.dag_id, - TI.execution_date == DEFAULT_DATE + tis = session.query(TaskInstance).filter( + TaskInstance.dag_id == self.dag.dag_id, + TaskInstance.execution_date == DEFAULT_DATE ) for ti in tis: @@ -107,9 +108,9 @@ def test_branch_list_without_dag_run(self): self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) with create_session() as session: - tis = session.query(TI).filter( - TI.dag_id == self.dag.dag_id, - TI.execution_date == DEFAULT_DATE + tis = session.query(TaskInstance).filter( + TaskInstance.dag_id == self.dag.dag_id, + TaskInstance.execution_date == DEFAULT_DATE ) expected = { diff --git a/tests/operators/test_check_operator.py b/tests/operators/test_check_operator.py index 5ff5af6fce297d..803098aa1da7d0 100644 --- a/tests/operators/test_check_operator.py +++ b/tests/operators/test_check_operator.py @@ -20,8 +20,7 @@ import unittest from datetime import datetime -from airflow.exceptions import AirflowException -from airflow.models import DAG +from airflow import DAG, AirflowException from airflow.operators.check_operator import CheckOperator, IntervalCheckOperator, ValueCheckOperator from tests.compat import mock diff --git a/tests/operators/test_dagrun_operator.py b/tests/operators/test_dagrun_operator.py index 586701dc63c75b..11bef2347f72db 100644 --- a/tests/operators/test_dagrun_operator.py +++ b/tests/operators/test_dagrun_operator.py @@ -22,7 +22,8 @@ from datetime import datetime from unittest import TestCase -from airflow.models import DAG, DagModel, DagRun, Log, TaskInstance +from airflow import DAG +from airflow.models import DagModel, DagRun, Log, TaskInstance from airflow.operators.dagrun_operator import TriggerDagRunOperator from airflow.utils import timezone from airflow.utils.db import create_session @@ -32,7 +33,7 @@ TRIGGERED_DAG_ID = "triggerdag" DAG_SCRIPT = ( "from datetime import datetime\n\n" - "from airflow.models import DAG\n" + "from airflow import DAG\n" "from airflow.operators.dummy_operator import DummyOperator\n\n" "dag = DAG(\n" 'dag_id="{dag_id}", \n' diff --git a/tests/operators/test_docker_operator.py b/tests/operators/test_docker_operator.py index b3fb522efb0c76..0af3163d4bf073 100644 --- a/tests/operators/test_docker_operator.py +++ b/tests/operators/test_docker_operator.py @@ -20,7 +20,7 @@ import logging import unittest -from airflow.exceptions import AirflowException +from airflow import AirflowException from tests.compat import mock try: diff --git a/tests/operators/test_docker_swarm_operator.py b/tests/operators/test_docker_swarm_operator.py index 3ce8a02a214220..4cbd6cbab8836c 100644 --- a/tests/operators/test_docker_swarm_operator.py +++ b/tests/operators/test_docker_swarm_operator.py @@ -21,8 +21,8 @@ from docker import APIClient +from airflow import AirflowException from airflow.contrib.operators.docker_swarm_operator import DockerSwarmOperator -from airflow.exceptions import AirflowException from tests.compat import mock diff --git a/tests/operators/test_druid_check_operator.py b/tests/operators/test_druid_check_operator.py index 43d13cc74889e8..51ca3e2c67ce77 100644 --- a/tests/operators/test_druid_check_operator.py +++ b/tests/operators/test_druid_check_operator.py @@ -21,8 +21,7 @@ import unittest from datetime import datetime -from airflow.exceptions import AirflowException -from airflow.models import DAG +from airflow import DAG, AirflowException from airflow.operators.druid_check_operator import DruidCheckOperator from tests.compat import mock diff --git a/tests/operators/test_gcs_to_gcs.py b/tests/operators/test_gcs_to_gcs.py index ee7d1c07f3488b..177f609296aae5 100644 --- a/tests/operators/test_gcs_to_gcs.py +++ b/tests/operators/test_gcs_to_gcs.py @@ -20,7 +20,7 @@ import unittest from datetime import datetime -from airflow.exceptions import AirflowException +from airflow import AirflowException from airflow.operators.gcs_to_gcs import ( WILDCARD, GoogleCloudStorageSynchronizeBuckets, GoogleCloudStorageToGoogleCloudStorageOperator, ) diff --git a/tests/operators/test_gcs_to_sftp.py b/tests/operators/test_gcs_to_sftp.py index 526d20a563881f..2ffa8927e2814b 100644 --- a/tests/operators/test_gcs_to_sftp.py +++ b/tests/operators/test_gcs_to_sftp.py @@ -21,7 +21,7 @@ import os import unittest -from airflow.exceptions import AirflowException +from airflow import AirflowException from airflow.operators.gcs_to_sftp import GoogleCloudStorageToSFTPOperator from tests.compat import mock diff --git a/tests/operators/test_http_operator.py b/tests/operators/test_http_operator.py index 36c7b4020ed5d5..f341b34c4cbfda 100644 --- a/tests/operators/test_http_operator.py +++ b/tests/operators/test_http_operator.py @@ -21,7 +21,7 @@ import requests_mock -from airflow.exceptions import AirflowException +from airflow import AirflowException from airflow.operators.http_operator import SimpleHttpOperator from tests.compat import mock diff --git a/tests/operators/test_python_operator.py b/tests/operators/test_python_operator.py index 8d30492c5751c8..bd889930fd57a2 100644 --- a/tests/operators/test_python_operator.py +++ b/tests/operators/test_python_operator.py @@ -25,8 +25,8 @@ from collections import namedtuple from datetime import date, timedelta -from airflow.exceptions import AirflowException -from airflow.models import DAG, DagRun, TaskInstance as TI +from airflow import DAG, AirflowException +from airflow.models import DagRun, TaskInstance from airflow.operators.dummy_operator import DummyOperator from airflow.operators.python_operator import BranchPythonOperator, PythonOperator, ShortCircuitOperator from airflow.utils import timezone @@ -75,7 +75,7 @@ def setUpClass(cls): with create_session() as session: session.query(DagRun).delete() - session.query(TI).delete() + session.query(TaskInstance).delete() def setUp(self): super().setUp() @@ -94,7 +94,7 @@ def tearDown(self): with create_session() as session: session.query(DagRun).delete() - session.query(TI).delete() + session.query(TaskInstance).delete() def do_run(self): self.run = True @@ -328,7 +328,7 @@ def setUpClass(cls): with create_session() as session: session.query(DagRun).delete() - session.query(TI).delete() + session.query(TaskInstance).delete() def setUp(self): self.dag = DAG('branch_operator_test', @@ -345,7 +345,7 @@ def tearDown(self): with create_session() as session: session.query(DagRun).delete() - session.query(TI).delete() + session.query(TaskInstance).delete() def test_without_dag_run(self): """This checks the defensive against non existent tasks in a dag run""" @@ -359,9 +359,9 @@ def test_without_dag_run(self): self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) with create_session() as session: - tis = session.query(TI).filter( - TI.dag_id == self.dag.dag_id, - TI.execution_date == DEFAULT_DATE + tis = session.query(TaskInstance).filter( + TaskInstance.dag_id == self.dag.dag_id, + TaskInstance.execution_date == DEFAULT_DATE ) for ti in tis: @@ -389,9 +389,9 @@ def test_branch_list_without_dag_run(self): self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) with create_session() as session: - tis = session.query(TI).filter( - TI.dag_id == self.dag.dag_id, - TI.execution_date == DEFAULT_DATE + tis = session.query(TaskInstance).filter( + TaskInstance.dag_id == self.dag.dag_id, + TaskInstance.execution_date == DEFAULT_DATE ) expected = { @@ -502,14 +502,14 @@ def setUpClass(cls): with create_session() as session: session.query(DagRun).delete() - session.query(TI).delete() + session.query(TaskInstance).delete() def tearDown(self): super().tearDown() with create_session() as session: session.query(DagRun).delete() - session.query(TI).delete() + session.query(TaskInstance).delete() def test_without_dag_run(self): """This checks the defensive against non existent tasks in a dag run""" @@ -534,9 +534,9 @@ def test_without_dag_run(self): short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) with create_session() as session: - tis = session.query(TI).filter( - TI.dag_id == dag.dag_id, - TI.execution_date == DEFAULT_DATE + tis = session.query(TaskInstance).filter( + TaskInstance.dag_id == dag.dag_id, + TaskInstance.execution_date == DEFAULT_DATE ) for ti in tis: diff --git a/tests/operators/test_s3_file_transform_operator.py b/tests/operators/test_s3_file_transform_operator.py index 2f16dafac548a3..bd2678aec14848 100644 --- a/tests/operators/test_s3_file_transform_operator.py +++ b/tests/operators/test_s3_file_transform_operator.py @@ -30,7 +30,7 @@ import boto3 from moto import mock_s3 -from airflow.exceptions import AirflowException +from airflow import AirflowException from airflow.operators.s3_file_transform_operator import S3FileTransformOperator diff --git a/tests/operators/test_s3_to_hive_operator.py b/tests/operators/test_s3_to_hive_operator.py index 2f63712642057f..7363ed820f13f8 100644 --- a/tests/operators/test_s3_to_hive_operator.py +++ b/tests/operators/test_s3_to_hive_operator.py @@ -28,7 +28,7 @@ from itertools import product from tempfile import NamedTemporaryFile, mkdtemp -from airflow.exceptions import AirflowException +from airflow import AirflowException from airflow.operators.s3_to_hive_operator import S3ToHiveTransfer from tests.compat import mock diff --git a/tests/operators/test_slack_operator.py b/tests/operators/test_slack_operator.py index 653369cf49ca66..3b488a69f1323c 100644 --- a/tests/operators/test_slack_operator.py +++ b/tests/operators/test_slack_operator.py @@ -20,7 +20,7 @@ import json import unittest -from airflow.exceptions import AirflowException +from airflow import AirflowException from airflow.operators.slack_operator import SlackAPIPostOperator from tests.compat import mock diff --git a/tests/operators/test_subdag_operator.py b/tests/operators/test_subdag_operator.py index 0cd92b565e910b..697abd16e18f0c 100644 --- a/tests/operators/test_subdag_operator.py +++ b/tests/operators/test_subdag_operator.py @@ -24,8 +24,8 @@ from parameterized import parameterized import airflow -from airflow.exceptions import AirflowException -from airflow.models import DAG, DagRun, TaskInstance +from airflow import DAG, AirflowException +from airflow.models import DagRun, TaskInstance from airflow.operators.dummy_operator import DummyOperator from airflow.operators.subdag_operator import SkippedStatePropagationOptions, SubDagOperator from airflow.utils.db import create_session diff --git a/tests/operators/test_virtualenv_operator.py b/tests/operators/test_virtualenv_operator.py index d59804dcd91e6d..63bd4b58631b2d 100644 --- a/tests/operators/test_virtualenv_operator.py +++ b/tests/operators/test_virtualenv_operator.py @@ -24,8 +24,7 @@ import funcsigs -from airflow import DAG -from airflow.exceptions import AirflowException +from airflow import AirflowException from airflow.operators.python_operator import PythonVirtualenvOperator from airflow.utils import timezone diff --git a/tests/providers/amazon/aws/operators/test_athena.py b/tests/providers/amazon/aws/operators/test_athena.py index e2b1b6773d140a..6c80624be5a907 100644 --- a/tests/providers/amazon/aws/operators/test_athena.py +++ b/tests/providers/amazon/aws/operators/test_athena.py @@ -18,7 +18,8 @@ import unittest -from airflow.models import DAG, TaskInstance +from airflow import DAG +from airflow.models import TaskInstance from airflow.providers.amazon.aws.hooks.athena import AWSAthenaHook from airflow.providers.amazon.aws.operators.athena import AWSAthenaOperator from airflow.utils import timezone diff --git a/tests/providers/amazon/aws/operators/test_datasync.py b/tests/providers/amazon/aws/operators/test_datasync.py index 660bd30174fc6a..06a502fd5c4ca0 100644 --- a/tests/providers/amazon/aws/operators/test_datasync.py +++ b/tests/providers/amazon/aws/operators/test_datasync.py @@ -21,8 +21,8 @@ import boto3 -from airflow.exceptions import AirflowException -from airflow.models import DAG, TaskInstance +from airflow import DAG, AirflowException +from airflow.models import TaskInstance from airflow.providers.amazon.aws.hooks.datasync import AWSDataSyncHook from airflow.providers.amazon.aws.operators.datasync import ( AWSDataSyncCreateTaskOperator, AWSDataSyncDeleteTaskOperator, AWSDataSyncGetTasksOperator, diff --git a/tests/providers/amazon/aws/sensors/test_sqs.py b/tests/providers/amazon/aws/sensors/test_sqs.py index 022661c06698ee..bbee9bf8ee523f 100644 --- a/tests/providers/amazon/aws/sensors/test_sqs.py +++ b/tests/providers/amazon/aws/sensors/test_sqs.py @@ -23,8 +23,7 @@ from moto import mock_sqs -from airflow import DAG -from airflow.exceptions import AirflowException +from airflow import DAG, AirflowException from airflow.providers.amazon.aws.hooks.sqs import SQSHook from airflow.providers.amazon.aws.sensors.sqs import SQSSensor from airflow.utils import timezone diff --git a/tests/providers/google/cloud/operators/test_sftp_to_gcs.py b/tests/providers/google/cloud/operators/test_sftp_to_gcs.py index e79187460eb88b..daf16fbc812e22 100644 --- a/tests/providers/google/cloud/operators/test_sftp_to_gcs.py +++ b/tests/providers/google/cloud/operators/test_sftp_to_gcs.py @@ -21,7 +21,7 @@ import os import unittest -from airflow.exceptions import AirflowException +from airflow import AirflowException from airflow.providers.google.cloud.operators.sftp_to_gcs import SFTPToGoogleCloudStorageOperator from tests.compat import mock diff --git a/tests/sensors/test_base_sensor.py b/tests/sensors/test_base_sensor.py index d5b92ee8b09904..84ce287dda52ef 100644 --- a/tests/sensors/test_base_sensor.py +++ b/tests/sensors/test_base_sensor.py @@ -24,8 +24,8 @@ from freezegun import freeze_time -from airflow import DAG, settings -from airflow.exceptions import AirflowException, AirflowRescheduleException, AirflowSensorTimeout +from airflow import DAG, AirflowException, settings +from airflow.exceptions import AirflowRescheduleException, AirflowSensorTimeout from airflow.models import DagRun, TaskInstance, TaskReschedule from airflow.operators.dummy_operator import DummyOperator from airflow.sensors.base_sensor_operator import BaseSensorOperator diff --git a/tests/sensors/test_external_task_sensor.py b/tests/sensors/test_external_task_sensor.py index 4a84b97fc5e2fa..45198776ae6881 100644 --- a/tests/sensors/test_external_task_sensor.py +++ b/tests/sensors/test_external_task_sensor.py @@ -19,8 +19,8 @@ import unittest from datetime import time, timedelta -from airflow import DAG, exceptions, settings -from airflow.exceptions import AirflowException, AirflowSensorTimeout +from airflow import DAG, AirflowException, settings +from airflow.exceptions import AirflowSensorTimeout from airflow.models import DagBag, TaskInstance from airflow.operators.bash_operator import BashOperator from airflow.operators.dummy_operator import DummyOperator @@ -144,7 +144,6 @@ def test_external_task_sensor_fn_multiple_execution_dates(self): ignore_ti_state=True) session = settings.Session() - TI = TaskInstance try: task_external_with_failure.run( start_date=DEFAULT_DATE, @@ -154,10 +153,10 @@ def test_external_task_sensor_fn_multiple_execution_dates(self): # once per minute (the run on the first second of # each minute). except Exception as e: - failed_tis = session.query(TI).filter( - TI.dag_id == dag_external_id, - TI.state == State.FAILED, - TI.execution_date == DEFAULT_DATE + timedelta(seconds=1)).all() + failed_tis = session.query(TaskInstance).filter( + TaskInstance.dag_id == dag_external_id, + TaskInstance.state == State.FAILED, + TaskInstance.execution_date == DEFAULT_DATE + timedelta(seconds=1)).all() if len(failed_tis) == 1 and \ failed_tis[0].task_id == 'task_external_with_failure': pass @@ -246,7 +245,7 @@ def test_external_task_sensor_fn(self): poke_interval=1, dag=self.dag ) - with self.assertRaises(exceptions.AirflowSensorTimeout): + with self.assertRaises(AirflowSensorTimeout): t2.run( start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, diff --git a/tests/sensors/test_http_sensor.py b/tests/sensors/test_http_sensor.py index d249247dc9cb1b..82c08c262e23c8 100644 --- a/tests/sensors/test_http_sensor.py +++ b/tests/sensors/test_http_sensor.py @@ -21,8 +21,8 @@ import requests -from airflow import DAG -from airflow.exceptions import AirflowException, AirflowSensorTimeout +from airflow import DAG, AirflowException +from airflow.exceptions import AirflowSensorTimeout from airflow.models import TaskInstance from airflow.operators.http_operator import SimpleHttpOperator from airflow.sensors.http_sensor import HttpSensor diff --git a/tests/sensors/test_s3_key_sensor.py b/tests/sensors/test_s3_key_sensor.py index 4a5418aea34c79..5297f9d575fd82 100644 --- a/tests/sensors/test_s3_key_sensor.py +++ b/tests/sensors/test_s3_key_sensor.py @@ -22,7 +22,7 @@ from parameterized import parameterized -from airflow.exceptions import AirflowException +from airflow import AirflowException from airflow.sensors.s3_key_sensor import S3KeySensor diff --git a/tests/sensors/test_sql_sensor.py b/tests/sensors/test_sql_sensor.py index 336833dd32feb8..1dc5813d1fdf58 100644 --- a/tests/sensors/test_sql_sensor.py +++ b/tests/sensors/test_sql_sensor.py @@ -19,9 +19,8 @@ import unittest from unittest import mock -from airflow import DAG +from airflow import DAG, AirflowException from airflow.configuration import conf -from airflow.exceptions import AirflowException from airflow.sensors.sql_sensor import SqlSensor from airflow.utils.timezone import datetime diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py index 3c3b91c532f8ad..f1e41bfe1a20a3 100644 --- a/tests/serialization/test_dag_serialization.py +++ b/tests/serialization/test_dag_serialization.py @@ -31,10 +31,12 @@ from airflow.contrib import example_dags as contrib_example_dags from airflow.gcp import example_dags as gcp_example_dags from airflow.hooks.base_hook import BaseHook -from airflow.models import DAG, BaseOperator, Connection, DagBag +from airflow.models import DAG, Connection, DagBag +from airflow.models.baseoperator import BaseOperator from airflow.operators.bash_operator import BashOperator from airflow.operators.subdag_operator import SubDagOperator -from airflow.serialization import SerializedBaseOperator, SerializedDAG +from airflow.serialization.serialized_baseoperator import SerializedBaseOperator +from airflow.serialization.serialized_dag import SerializedDAG from airflow.utils.tests import CustomBaseOperator, GoogleLink serialized_simple_dag_ground_truth = { diff --git a/tests/task/task_runner/test_standard_task_runner.py b/tests/task/task_runner/test_standard_task_runner.py index de7416c5be50c9..745f192fcb254e 100644 --- a/tests/task/task_runner/test_standard_task_runner.py +++ b/tests/task/task_runner/test_standard_task_runner.py @@ -26,7 +26,7 @@ from airflow import models, settings from airflow.jobs import LocalTaskJob -from airflow.models import TaskInstance as TI +from airflow.models import TaskInstance as TaskInstance from airflow.task.task_runner import StandardTaskRunner from airflow.utils import timezone from airflow.utils.state import State @@ -115,7 +115,7 @@ def test_on_kill(self): execution_date=DEFAULT_DATE, start_date=DEFAULT_DATE, session=session) - ti = TI(task=task, execution_date=DEFAULT_DATE) + ti = TaskInstance(task=task, execution_date=DEFAULT_DATE) job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True) runner = StandardTaskRunner(job1) diff --git a/tests/ti_deps/deps/test_dagrun_exists_dep.py b/tests/ti_deps/deps/test_dagrun_exists_dep.py index 23dfae1868bfcc..44dcb0bfd2c4c0 100644 --- a/tests/ti_deps/deps/test_dagrun_exists_dep.py +++ b/tests/ti_deps/deps/test_dagrun_exists_dep.py @@ -20,7 +20,8 @@ import unittest from unittest.mock import Mock, patch -from airflow.models import DAG, DagRun +from airflow import DAG +from airflow.models import DagRun from airflow.ti_deps.deps.dagrun_exists_dep import DagrunRunningDep from airflow.utils.state import State diff --git a/tests/ti_deps/deps/test_prev_dagrun_dep.py b/tests/ti_deps/deps/test_prev_dagrun_dep.py index 53cc822a58e633..1425da6e1c2536 100644 --- a/tests/ti_deps/deps/test_prev_dagrun_dep.py +++ b/tests/ti_deps/deps/test_prev_dagrun_dep.py @@ -21,7 +21,8 @@ from datetime import datetime from unittest.mock import Mock -from airflow.models import DAG, BaseOperator +from airflow import DAG +from airflow.models.baseoperator import BaseOperator from airflow.ti_deps.dep_context import DepContext from airflow.ti_deps.deps.prev_dagrun_dep import PrevDagrunDep from airflow.utils.state import State @@ -68,7 +69,7 @@ def test_context_ignore_depends_on_past(self): def test_first_task_run(self): """ - The first task run for a TI should pass since it has no previous dagrun. + The first task run for a TaskInstance should pass since it has no previous dagrun. """ task = self._get_task(depends_on_past=True, start_date=datetime(2016, 1, 1), @@ -82,7 +83,7 @@ def test_first_task_run(self): def test_prev_ti_bad_state(self): """ - If the previous TI did not complete execution this dep should fail. + If the previous TaskInstance did not complete execution this dep should fail. """ task = self._get_task(depends_on_past=True, start_date=datetime(2016, 1, 1), @@ -97,9 +98,9 @@ def test_prev_ti_bad_state(self): def test_failed_wait_for_downstream(self): """ - If the previous TI specified to wait for the downstream tasks of the + If the previous TaskInstance specified to wait for the downstream tasks of the previous dagrun then it should fail this dep if the downstream TIs of - the previous TI are not done. + the previous TaskInstance are not done. """ task = self._get_task(depends_on_past=True, start_date=datetime(2016, 1, 1), diff --git a/tests/ti_deps/deps/test_ready_to_reschedule_dep.py b/tests/ti_deps/deps/test_ready_to_reschedule_dep.py index 0408a05ed265f4..dd0e8ee9c94492 100644 --- a/tests/ti_deps/deps/test_ready_to_reschedule_dep.py +++ b/tests/ti_deps/deps/test_ready_to_reschedule_dep.py @@ -21,7 +21,8 @@ from datetime import timedelta from unittest.mock import Mock, patch -from airflow.models import DAG, TaskInstance, TaskReschedule +from airflow import DAG +from airflow.models import TaskInstance, TaskReschedule from airflow.ti_deps.dep_context import DepContext from airflow.ti_deps.deps.ready_to_reschedule import ReadyToRescheduleDep from airflow.utils.state import State diff --git a/tests/ti_deps/deps/test_task_concurrency.py b/tests/ti_deps/deps/test_task_concurrency.py index 76987ca2f72320..2232252b581495 100644 --- a/tests/ti_deps/deps/test_task_concurrency.py +++ b/tests/ti_deps/deps/test_task_concurrency.py @@ -21,7 +21,8 @@ from datetime import datetime from unittest.mock import Mock -from airflow.models import DAG, BaseOperator +from airflow.models import DAG +from airflow.models.baseoperator import BaseOperator from airflow.ti_deps.dep_context import DepContext from airflow.ti_deps.deps.task_concurrency_dep import TaskConcurrencyDep diff --git a/tests/ti_deps/deps/test_trigger_rule_dep.py b/tests/ti_deps/deps/test_trigger_rule_dep.py index e86fb3e8cbea4b..1c09c73a1ac514 100644 --- a/tests/ti_deps/deps/test_trigger_rule_dep.py +++ b/tests/ti_deps/deps/test_trigger_rule_dep.py @@ -20,7 +20,8 @@ import unittest from datetime import datetime -from airflow.models import BaseOperator, TaskInstance +from airflow.models import TaskInstance +from airflow.models.baseoperator import BaseOperator from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep from airflow.utils.db import create_session from airflow.utils.state import State @@ -39,7 +40,7 @@ def _get_task_instance(self, trigger_rule=TriggerRule.ALL_SUCCESS, def test_no_upstream_tasks(self): """ - If the TI has no upstream TIs then there is nothing to check and the dep is passed + If the TaskInstance has no upstream TIs then there is nothing to check and the dep is passed """ ti = self._get_task_instance(TriggerRule.ALL_DONE, State.UP_FOR_RETRY) self.assertTrue(TriggerRuleDep().is_met(ti=ti)) diff --git a/tests/utils/log/test_es_task_handler.py b/tests/utils/log/test_es_task_handler.py index 64126b2096b644..1df21a396aa900 100644 --- a/tests/utils/log/test_es_task_handler.py +++ b/tests/utils/log/test_es_task_handler.py @@ -26,8 +26,9 @@ import elasticsearch import pendulum +from airflow import DAG from airflow.configuration import conf -from airflow.models import DAG, TaskInstance +from airflow.modesl import TaskInstance from airflow.operators.dummy_operator import DummyOperator from airflow.utils import timezone from airflow.utils.log.es_task_handler import ElasticsearchTaskHandler diff --git a/tests/utils/log/test_s3_task_handler.py b/tests/utils/log/test_s3_task_handler.py index 1e5955bebf0ae0..c5f9f5802c4fcd 100644 --- a/tests/utils/log/test_s3_task_handler.py +++ b/tests/utils/log/test_s3_task_handler.py @@ -21,7 +21,8 @@ import unittest from unittest import mock -from airflow.models import DAG, TaskInstance +from airflow import DAG +from airflow.models import TaskInstance from airflow.operators.dummy_operator import DummyOperator from airflow.providers.amazon.aws.hooks.s3 import S3Hook from airflow.utils.log.s3_task_handler import S3TaskHandler diff --git a/tests/utils/test_dag_processing.py b/tests/utils/test_dag_processing.py index c34810925d461e..7fa8c33119591b 100644 --- a/tests/utils/test_dag_processing.py +++ b/tests/utils/test_dag_processing.py @@ -28,7 +28,7 @@ from airflow.configuration import conf from airflow.jobs import DagFileProcessor, LocalTaskJob as LJ -from airflow.models import DagBag, TaskInstance as TI +from airflow.models import DagBag, TaskInstance as TaskInstance from airflow.utils import timezone from airflow.utils.dag_processing import ( DagFileProcessorAgent, DagFileProcessorManager, DagFileStat, SimpleTaskInstance, correct_maybe_zipped, @@ -195,7 +195,7 @@ def test_find_zombies(self): dag = dagbag.get_dag('example_branch_operator') task = dag.get_task(task_id='run_this_first') - ti = TI(task, DEFAULT_DATE, State.RUNNING) + ti = TaskInstance(task, DEFAULT_DATE, State.RUNNING) lj = LJ(ti) lj.state = State.SHUTDOWN lj.id = 1 @@ -215,7 +215,7 @@ def test_find_zombies(self): self.assertEqual(ti.task_id, zombies[0].task_id) self.assertEqual(ti.execution_date, zombies[0].execution_date) - session.query(TI).delete() + session.query(TaskInstance).delete() session.query(LJ).delete() def test_zombies_are_correctly_passed_to_dag_file_processor(self): @@ -231,7 +231,7 @@ def test_zombies_are_correctly_passed_to_dag_file_processor(self): dag = dagbag.get_dag('test_example_bash_operator') task = dag.get_task(task_id='run_this_last') - ti = TI(task, DEFAULT_DATE, State.RUNNING) + ti = TaskInstance(task, DEFAULT_DATE, State.RUNNING) lj = LJ(ti) lj.state = State.SHUTDOWN lj.id = 1 diff --git a/tests/utils/test_decorators.py b/tests/utils/test_decorators.py index 6623ac3cf5c176..15e722288a147f 100644 --- a/tests/utils/test_decorators.py +++ b/tests/utils/test_decorators.py @@ -19,7 +19,7 @@ import unittest -from airflow.exceptions import AirflowException +from airflow import AirflowException from airflow.utils.decorators import apply_defaults diff --git a/tests/utils/test_helpers.py b/tests/utils/test_helpers.py index 33850833026b3f..4ae11cfc4dec4b 100644 --- a/tests/utils/test_helpers.py +++ b/tests/utils/test_helpers.py @@ -27,8 +27,7 @@ import psutil -from airflow import DAG -from airflow.exceptions import AirflowException +from airflow import DAG, AirflowException from airflow.models import TaskInstance from airflow.operators.dummy_operator import DummyOperator from airflow.utils import helpers diff --git a/tests/utils/test_log_handlers.py b/tests/utils/test_log_handlers.py index cb4ca31f38dcf1..ea22d1bc91c870 100644 --- a/tests/utils/test_log_handlers.py +++ b/tests/utils/test_log_handlers.py @@ -22,8 +22,9 @@ import os import unittest +from airflow import DAG from airflow.config_templates.airflow_local_settings import DEFAULT_LOGGING_CONFIG -from airflow.models import DAG, DagRun, TaskInstance +from airflow.models import DagRun, TaskInstance from airflow.operators.dummy_operator import DummyOperator from airflow.operators.python_operator import PythonOperator from airflow.utils.db import create_session diff --git a/tests/utils/test_sqlalchemy.py b/tests/utils/test_sqlalchemy.py index c893a24a7d279e..d916d08a58cd4a 100644 --- a/tests/utils/test_sqlalchemy.py +++ b/tests/utils/test_sqlalchemy.py @@ -22,16 +22,14 @@ from sqlalchemy.exc import StatementError -from airflow import settings -from airflow.models import DAG -from airflow.settings import Session +from airflow import DAG, settings from airflow.utils.state import State from airflow.utils.timezone import utcnow class TestSqlAlchemyUtils(unittest.TestCase): def setUp(self): - session = Session() + session = settings.Session() # make sure NOT to run in UTC. Only postgres supports storing # timezone information in the datetime field diff --git a/tests/www/test_security.py b/tests/www/test_security.py index 55cb95bc6a0402..db36c35be3aa6e 100644 --- a/tests/www/test_security.py +++ b/tests/www/test_security.py @@ -28,7 +28,7 @@ from flask_appbuilder.views import BaseView, ModelView from sqlalchemy import Column, Date, Float, Integer, String -from airflow.exceptions import AirflowException +from airflow import AirflowException from airflow.www.security import AirflowSecurityManager READ_WRITE = {'can_dag_read', 'can_dag_edit'} diff --git a/tests/www/test_views.py b/tests/www/test_views.py index 5213cf89a10bc8..60cc559f7a6b95 100644 --- a/tests/www/test_views.py +++ b/tests/www/test_views.py @@ -38,13 +38,13 @@ from werkzeug.test import Client from werkzeug.wrappers import BaseResponse -from airflow import models, settings +from airflow import DAG, models, settings from airflow.config_templates.airflow_local_settings import DEFAULT_LOGGING_CONFIG from airflow.configuration import conf from airflow.executors.celery_executor import CeleryExecutor from airflow.jobs import BaseJob -from airflow.models import DAG, BaseOperator, Connection, DagRun, TaskInstance -from airflow.models.baseoperator import BaseOperatorLink +from airflow.models import Connection, DagRun, TaskInstance +from airflow.models.baseoperator import BaseOperator, BaseOperatorLink from airflow.operators.dummy_operator import DummyOperator from airflow.settings import Session from airflow.ti_deps.dep_context import QUEUEABLE_STATES, RUNNABLE_STATES @@ -583,7 +583,7 @@ def test_run(self): resp = self.client.post('run', data=form) self.check_content_in_response('', resp, resp_code=302) - @mock.patch('airflow.executors.get_default_executor') + @mock.patch('airflow.executors.all_executors.AllExecutors.get_default_executor') def test_run_with_runnable_states(self, get_default_executor_function): executor = CeleryExecutor() executor.heartbeat = lambda: True @@ -613,7 +613,7 @@ def test_run_with_runnable_states(self, get_default_executor_function): .format(state) + "The task must be cleared in order to be run" self.assertFalse(re.search(msg, resp.get_data(as_text=True))) - @mock.patch('airflow.executors.get_default_executor') + @mock.patch('airflow.executors.all_excutors.AllExecutors.get_default_executor') def test_run_with_not_runnable_states(self, get_default_executor_function): get_default_executor_function.return_value = CeleryExecutor()