diff --git a/.gitignore b/.gitignore index ae117927b23e3e..04b080904bb141 100644 --- a/.gitignore +++ b/.gitignore @@ -143,6 +143,7 @@ rat-results.txt *.generated *.tar.gz scripts/ci/kubernetes/kube/.generated/airflow.yaml +scripts/ci/kubernetes/docker/requirements.txt # Node & Webpack Stuff *.entry.js diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 83c51adf6a3251..09b48e3a4bbc6c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -35,8 +35,8 @@ repos: files: ^.*LICENSE.*$|^.*LICENCE.*$ pass_filenames: false require_serial: true - - repo: https://github.com/potiuk/pre-commit-hooks - rev: d5bee8590ea3405a299825e68dbfa0e1b951a2be + - repo: https://github.com/Lucas-C/pre-commit-hooks + rev: v1.1.7 hooks: - id: forbid-tabs exclude: ^airflow/_vendor/.*$|^docs/Makefile$ @@ -92,10 +92,20 @@ repos: - license-templates/LICENSE.txt - --fuzzy-match-generates-todo - id: insert-license - name: Add licence for shell files + name: Add licence for all shell files exclude: ^\.github/.*$"|^airflow/_vendor/.*$ types: [shell] - files: ^breeze$|^breeze-complete$ + files: ^breeze$|^breeze-complete$|\.sh$ + args: + - --comment-style + - "|#|" + - --license-filepath + - license-templates/LICENSE.txt + - --fuzzy-match-generates-todo + - id: insert-license + name: Add licence for all python files + exclude: ^\.github/.*$"|^airflow/_vendor/.*$ + types: [python] args: - --comment-style - "|#|" @@ -113,7 +123,7 @@ repos: - license-templates/LICENSE.txt - --fuzzy-match-generates-todo - id: insert-license - name: Add licence for yaml files + name: Add licence for all yaml files exclude: ^\.github/.*$"|^airflow/_vendor/.*$ types: [yaml] args: @@ -167,7 +177,8 @@ repos: language: docker_image entry: koalaman/shellcheck:stable -x -a types: [shell] - files: ^breeze$|^breeze-complete$ + files: ^breeze$|^breeze-complete$|\.sh$ + exclude: ^airflow/_vendor/.*$ - id: lint-dockerfile name: Lint dockerfile language: system diff --git a/BREEZE.rst b/BREEZE.rst index faf1fd574e5ff7..2758eac1f0aa47 100644 --- a/BREEZE.rst +++ b/BREEZE.rst @@ -19,35 +19,7 @@ :align: center :alt: Airflow Breeze Logo - -Table of Contents -================= - -* `Airflow Breeze <#airflow-breeze>`_ -* `Installation <#installation>`_ -* `Resource usage <#resource-usage>`_ -* `Setting up autocomplete <#setting-up-autocomplete>`_ -* `Using the Airflow Breeze environment <#using-the-airflow-breeze-environment>`_ - - `Entering the environment <#entering-the-environment>`_ - - `Running tests in Airflow Breeze <#running-tests-in-airflow-breeze>`_ - - `Debugging with ipdb <#debugging-with-ipdb>`_ - - `Airflow directory structure in Docker <#airflow-directory-structure-inside-docker>`_ - - `Port forwarding <#port-forwarding>`_ -* `Using your host IDE <#using-your-host-ide>`_ - - `Configuring local virtualenv <#configuring-local-virtualenv>`_ - - `Running unit tests via IDE <#running-unit-tests-via-ide>`_ - - `Debugging Airflow Breeze Tests in IDE <#debugging-airflow-breeze-tests-in-ide>`_ -* `Running commands via Airflow Breeze <#running-commands-via-airflow-breeze>`_ - - `Running static code checks <#running-static-code-checks>`_ - - `Building the documentation <#building-the-documentation>`_ - - `Running tests <#running-tests>`_ - - `Running commands inside Docker <#running-commands-inside-docker>`_ - - `Running Docker Compose commands <#running-docker-compose-commands>`_ - - `Convenience scripts <#convenience-scripts>`_ -* `Keeping images up-to-date <#keeping-images-up-to-date>`_ - - `Updating dependencies 
<#updating-dependencies>`_ - - `Pulling the images <#pulling-the-images>`_ -* `Airflow Breeze flags <#airflow-breeze-flags>`_ +.. contents:: :local: Airflow Breeze ============== @@ -56,7 +28,7 @@ Airflow Breeze is an easy-to-use integration test environment managed via `Docker Compose `_ . The environment is easy to use locally and it is also used by Airflow's CI Travis tests. -It's called **Airflow Breeze** as in "It's a *Breeze* to develop Airflow" +It's called *Airflow Breeze* as in **It's a Breeze to develop Airflow** The advantages and disadvantages of using the environment vs. other ways of testing Airflow are described in `CONTRIBUTING.md `_. @@ -71,120 +43,125 @@ Here is the short 10 minute video about Airflow Breeze :align: center :target: http://www.youtube.com/watch?v=ffKFHV6f3PQ +Prerequisites +============= -Installation -============ +Docker +------ -Prerequisites for the installation: +You need latest stable Docker Community Edition installed and on the PATH. It should be +configured to be able to run ``docker`` commands directly and not only via root user. Your user +should be in the ``docker`` group. See `Docker installation guide `_ +When you develop on Mac OS you usually have not enough disk space for Docker if you start using it +seriously. You should increase disk space available before starting to work with the environment. +Usually you have weird problems of docker containers when you run out of Disk space. It might not be +obvious that space is an issue. At least 128 GB of Disk space is recommended. You can also get by with smaller space but you should more +often clean the docker disk space periodically. -* - If you are on MacOS you need gnu ``getopt`` and ``gstat`` to get Airflow Breeze running. Typically - you need to run ``brew install gnu-getopt coreutils`` and then follow instructions (you need - to link the gnu getopt version to become first on the PATH). Make sure to re-login after you - make the suggested changes. +If you get into weird behaviour try `Cleaning up the images <#cleaning-up-the-images>`_. -* - Latest stable Docker Community Edition installed and on the PATH. It should be - configured to be able to run ``docker`` commands directly and not only via root user. Your user - should be in the ``docker`` group. See `Docker installation guide `_ +See also `Docker for Mac - Space `_ for details of increasing +disk space available for Docker on Mac. +Docker compose +-------------- -* - When you develop on Mac OS you usually have not enough disk space for Docker if you start using it - seriously. You should increase disk space available before starting to work with the environment. - Usually you have weird problems of docker containers when you run out of Disk space. It might not be - obvious that space is an issue. +Latest stable Docker Compose installed and on the PATH. It should be +configured to be able to run ``docker-compose`` command. +See `Docker compose installation guide `_ - If you get into weird behaviour try - `Cleaning Up Docker `_ +Getopt and gstat +---------------- - See `Docker for Mac - Space `_ for details of increasing - disk space available for Docker on Mac. +* If you are on MacOS - At least 128 GB of Disk space is recommended. You can also get by with smaller space but you should more - often clean the docker disk space periodically. + * you need gnu ``getopt`` and ``gstat`` to get Airflow Breeze running. 
-* - On MacOS, the default 2GB of RAM available for your docker containers, but more memory is recommended - (4GB should be comfortable). For details see - `Docker for Mac - Advanced tab `_ + * Typically you need to run ``brew install gnu-getopt coreutils`` and then follow instructions (you need to link the gnu getopt + version to become first on the PATH). Make sure to re-login after yoy make the suggested changes. -* - Latest stable Docker Compose installed and on the PATH. It should be - configured to be able to run ``docker-compose`` command. - See `Docker compose installation guide `_ + * Then (with brew) link the gnu-getopt to become default as suggested by brew. + * If you use bash, you should run this command (and re-login): -Your entry point for Airflow Breeze is `./breeze <./breeze>`_ -script. You can run it with ``-h`` option to see the list of available flags. -You can add the checked out airflow repository to your PATH to run breeze -without the ./ and from any directory if you have only one airflow directory checked out. +.. code-block:: bash -See `Airflow Breeze flags <#airflow-breeze-flags>`_ for details. + echo 'export PATH="/usr/local/opt/gnu-getopt/bin:$PATH"' >> ~/.bash_profile + . ~/.bash_profile -First time you run `./breeze <./breeze>`_ script, it will pull and build local version of docker images. -It will pull latest Airflow CI images from `Apache Airflow DockerHub `_ -and use them to build your local docker images. It will use latest sources from your source code. -Further on ``breeze`` uses md5sum calculation and Docker caching mechanisms to only rebuild what is needed. -Airflow Breeze will detect if Docker images need to be rebuilt and ask you for confirmation. + * If you use zsh, you should run this command ((and re-login): -Resource usage -============== +.. code-block:: bash -You can choose environment when you run Breeze with ``--env`` flag. -Running the default ``docker`` environment takes considerable amount of resources. You can run a slimmed-down -version of the environment - just the Apache Airflow container - by choosing ``bare`` environment instead. + echo 'export PATH="/usr/local/opt/gnu-getopt/bin:$PATH"' >> ~/.zprofile + . ~/.zprofile -The following environments are available: +* If you are on Linux - * The ``docker`` environment (default): starts all dependencies required by full integration test-suite - (postgres, mysql, celery, etc.). This option is resource intensive so do not forget to - [Stop environment](#stopping-the-environment) when you are finished. This option is also RAM intensive - and can slow down your machine. - * The ``kubernetes`` environment: Runs airflow tests within a kubernetes cluster (requires - ``KUBERNETES_VERSION`` and ``KUBERNETES_MODE`` variables). - * The ``bare`` environment: runs airflow in docker without any external dependencies. - It will only work for non-dependent tests. You can only run it with sqlite backend. + * run ``apt install util-linux coreutils`` or equivalent if your system is not Debian-based. + +Memory +------ + +Minimum 4GB RAM is required to run the full ``docker`` environment. + +On MacOS, the default 2GB of RAM available for your docker containers, but more memory is recommended +(4GB should be comfortable). For details see +`Docker for Mac - Advanced tab `_ + +How Breeze works +================ + +Entering Breeze +--------------- + +Your entry point for Airflow Breeze is `./breeze <./breeze>`_ script. You can run it with ``--help`` +option to see the list of available flags. 
See `Airflow Breeze flags <#airflow-breeze-flags>`_ for details. -After starting up, the environment runs in the background and takes resources. You can always stop -it via +You can also `Set up autocomplete <#setting-up-autocomplete>`_ for the command and add the +checked-out airflow repository to your PATH to run breeze without the ./ and from any directory. + +First time you run Breeze, it will pull and build local version of docker images. +It will pull latest Airflow CI images from `Airflow DockerHub `_ +and use them to build your local docker images. + +Stopping Breeze +--------------- + +After starting up, the environment runs in the background and takes precious memory. +You can always stop it via: .. code-block:: bash - breeze --stop-environment + ./breeze --stop-environment + + +Using the Airflow Breeze environment for testing +================================================ Setting up autocomplete -======================= +----------------------- -The ``breeze`` command comes with built-in bash/zsh autocomplete. When you start typing -`./breeze <./breeze>`_ command you can use to show all the available switches +The ``breeze`` command comes with built-in bash/zsh autocomplete for its flags. When you start typing +the command you can use to show all the available switches nd to get autocompletion on typical values of parameters that you can use. -You can setup auto-complete automatically by running this command (-a is shortcut for --setup-autocomplete): +You can setup auto-complete automatically by running: .. code-block:: bash ./breeze --setup-autocomplete - You get autocomplete working when you re-enter the shell. Zsh autocompletion is currently limited to only autocomplete flags. Bash autocompletion also completes flag values (for example python version or static check name). - -Using the Airflow Breeze environment -==================================== - Entering the environment ------------------------ -You enter the integration test environment by running the `./breeze <./breeze>`_ script. - -You can specify python version to use, backend to use and environment for testing - so that you can -recreate the same environments as we have in matrix builds in Travis CI. The defaults when you -run the environment are reasonable (python 3.6, sqlite, docker). +You enter the integration test environment by running the ``./breeze`` script. What happens next is the appropriate docker images are pulled, local sources are used to build local version of the image and you are dropped into bash shell of the airflow container - @@ -195,7 +172,15 @@ on a fast connection to start. Subsequent runs should be much faster. ./breeze -You can choose the optional flags you need with `./breeze <./breeze>`_. +Once you enter the environment you are dropped into bash shell and you can run tests immediately. + +Choosing environment +-------------------- + +You can choose the optional flags you need with ``breeze`` + +You can specify for example python version to use, backend to use and environment +for testing - you can recreate the same environments as we have in matrix builds in Travis CI. For example you could choose to run python 3.6 tests with mysql as backend and in docker environment by: @@ -205,17 +190,21 @@ environment by: ./breeze --python 3.6 --backend mysql --env docker The choices you made are persisted in ``./.build/`` cache directory so that next time when you use the -`./breeze <./breeze>`_ script, it will use the values that were used previously. 
This way you do not
-have to specify them when you run the script. You can delete the ``./.build/`` in case you want to
+``breeze`` script, it will use the values that were used previously. This way you do not
+have to specify them when you run the script. You can delete the ``.build/`` directory in case you want to
 restore default settings.

-Relevant sources of airflow are mounted inside the ``airflow-testing`` container that you enter,
+The defaults when you run the environment are reasonable (python 3.6, sqlite, docker).
+
+Mounting local sources to Breeze
+--------------------------------
+
+Important sources of airflow are mounted inside the ``airflow-testing`` container that you enter,
 which means that you can continue editing your changes in the host in your favourite IDE and have them
 visible in docker immediately and ready to test without rebuilding images.
 This can be disabled by specifying ``--skip-mounting-source-volume`` flag when running breeze, in which case
 you will have sources embedded in the container - and changes to those sources will not be persistent.

-Once you enter the environment you are dropped into bash shell and you can run tests immediately.

 After you run Breeze for the first time you will have an empty directory ``files`` in your source code
 that will be mapped to ``/files`` in your docker container. You can pass any files there you need
@@ -240,7 +229,6 @@ or a single test method:

      run-tests tests.core:TestCore.test_check_operators -- -s --logging-level=DEBUG

-
 The tests will run ``airflow db reset`` and ``airflow db init`` the first time you run tests in
 running container, so you can count on database being initialized.

@@ -254,6 +242,44 @@ the database.

      run-tests --with-db-init tests.core:TestCore.test_check_operators -- -s --logging-level=DEBUG

+Adding/modifying dependencies
+-----------------------------
+
+If you change apt dependencies in the ``Dockerfile``, add python packages in ``setup.py`` or add
+javascript dependencies in ``package.json``, you can add those dependencies temporarily for one Breeze
+session or permanently in ``setup.py``, ``Dockerfile``, ``package.json``.
+
+Installing dependencies for one Breeze session
+..............................................
+
+You can install dependencies inside the container using 'sudo apt install', 'pip install' or 'npm install'
+(in airflow/www folder) respectively. This is useful if you want to test something quickly while in the
+container. However, those changes are not persistent - they will disappear once you
+exit the container (except npm dependencies in case your sources are mounted to the container). Therefore
+if you want to persist a new dependency you have to follow with the second option.
+
+Adding dependencies permanently
+...............................
+
+You can add the dependencies to the Dockerfile, setup.py or package.json and rebuild the image. This
+should happen automatically if you modify any of setup.py, package.json or update Dockerfile itself.
+After you exit the container and re-run ``breeze``, Breeze detects changes in dependencies,
+asks you to confirm rebuilding of the image and proceeds to rebuild the image if you confirm (or skips it
+if you do not confirm). After rebuilding is done, it will drop you to shell. You might also provide
+``--build-only`` flag to only rebuild images and not go into shell - it will then rebuild the image
+and will not enter the shell.
+
+Optimisation for apt dependencies during development
+....................................................
+ +During development, changing dependencies in apt-get closer to the top of the Dockerfile +will invalidate cache for most of the image and it will take long time to rebuild the image by Breeze. +Therefore it is a recommended practice to add new dependencies initially closer to the end +of the Dockerfile. This way dependencies will be incrementally added. + +However before merge, those dependencies should be moved to the appropriate ``apt-get install`` command +which is already in the Dockerfile. + Debugging with ipdb ------------------- @@ -299,15 +325,16 @@ When you run Airflow Breeze, the following ports are automatically forwarded: You can connect to those ports/databases using: -* Webserver: (http://127.0.0.1:28080)[http://127.0.0.1:28080] +* Webserver: ``http://127.0.0.1:28080`` * Postgres: ``jdbc:postgresql://127.0.0.1:25433/airflow?user=postgres&password=airflow`` * Mysql: ``jdbc:mysql://localhost:23306/airflow?user=root`` Note that you need to start the webserver manually with ``airflow webserver`` command if you want to connect to the webserver (you can use ``tmux`` to multiply terminals). -For databases you need to run ``airflow resetdb`` at least once after you started Airflow Breeze to get -the database/tables created. You can connect to databases with IDE or any other Database client: +For databases you need to run ``airflow db reset`` at least once (or run some tests) after you started +Airflow Breeze to get the database/tables created. You can connect to databases +with IDE or any other Database client: .. image:: images/database_view.png :align: center @@ -315,106 +342,70 @@ the database/tables created. You can connect to databases with IDE or any other You can change host port numbers used by setting appropriate environment variables: -* WEBSERVER_HOST_PORT -* POSTGRES_HOST_PORT -* MYSQL_HOST_PORT +* ``WEBSERVER_HOST_PORT`` +* ``POSTGRES_HOST_PORT`` +* ``MYSQL_HOST_PORT`` When you set those variables, next time when you enter the environment the new ports should be in effect. -Using your host IDE -=================== +Cleaning up the images +---------------------- -Configuring local virtualenv ----------------------------- +You might need to cleanup your Docker environment occasionally. The images are quite big +(1.5GB for both images needed for static code analysis and CI tests). And if you often rebuild/update +images you might end up with some unused image data. -In order to use your host IDE (for example IntelliJ's PyCharm/Idea) you need to have virtual environments -setup. Ideally you should have virtualenvs for all python versions that Airflow supports (3.5, 3.6, 3.7). -You can create the virtualenv using ``virtualenvwrapper`` - that will allow you to easily switch between -virtualenvs using workon command and mange your virtual environments more easily. - -Typically creating the environment can be done by: +Cleanup can be performed with ``docker system prune`` command. +Make sure to `Stop Breeze <#stopping-breeze>`_ first with ``./breeze --stop-environment``. -.. code-block:: bash +If you run into disk space errors, we recommend you prune your docker images using the +``docker system prune --all`` command. You might need to restart the docker +engine before running this command. - mkvirtualenv --python=python +You can check if your docker is clean by running ``docker images --all`` and ``docker ps --all`` - both +should return an empty list of images and containers respectively. 
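For example, a typical cleanup session using the commands mentioned above might look like this (a sketch only - note that ``docker system prune --all`` removes all unused Docker data, not only the Airflow images):

.. code-block:: bash

    # Stop the Breeze environment first so that no containers keep the images in use
    ./breeze --stop-environment

    # Remove stopped containers, unused networks and (with --all) all unused images
    docker system prune --all

    # Verify that the local cache is clean - both commands should print empty lists
    docker images --all
    docker ps --all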
+If you are on Mac OS and you end up with not enough disk space for Docker you should increase disk space
+available for Docker. See `Prerequisites <#prerequisites>`_.

-After the virtualenv is created, you must initialize it. Simply enter the environment
-(using workon) and once you are in it run:
-
-.. code-block:: bash
-
-  ./breeze --initialize-local-virtualenv
-
-Once initialization is done, you should select the virtualenv you initialized as the project's default
-virtualenv in your IDE.
-
-Running unit tests via IDE
---------------------------
-
-After setting it up - you can use the usual "Run Test" option of the IDE and have all the
-autocomplete and documentation support from IDE as well as you can debug and click-through
-the sources of Airflow - which is very helpful during development. Usually you also can run most
-of the unit tests (those that do not require prerequisites) directly from the IDE:
-
-Running unit tests from IDE is as simple as:
-
-.. image:: images/running_unittests.png
-    :align: center
-    :alt: Running unit tests
-
-Some of the core tests use dags defined in ``tests/dags`` folder - those tests should have
-``AIRFLOW__CORE__UNIT_TEST_MODE`` set to True. You can set it up in your test configuration:
-
-.. image:: images/airflow_unit_test_mode.png
-    :align: center
-    :alt: Airflow Unit test mode
-
-
-You cannot run all the tests this way - only unit tests that do not require external dependencies
-such as postgres/mysql/hadoop etc. You should use
-`Running tests in Airflow Breeze <#running-tests-in-airflow-breeze>`_ in order to run those tests. You can
-still use your IDE to debug those tests as explained in the next chapter.
-
-Debugging Airflow Breeze Tests in IDE
--------------------------------------
-
-When you run example DAGs, even if you run them using UnitTests from within IDE, they are run in a separate
-container. This makes it a little harder to use with IDE built-in debuggers.
-Fortunately for IntelliJ/PyCharm it is fairly easy using remote debugging feature (note that remote
-debugging is only available in paid versions of IntelliJ/PyCharm).
+Troubleshooting
+---------------

-You can read general description `about remote debugging
-`_
+If you are having problems with the Breeze environment - try the following (after each step you
+can check if your problem is fixed):

-You can setup your remote debug session as follows:
+1. Check if you have enough disk space in Docker if you are on MacOS.
+2. Stop Breeze - use ``./breeze --stop-environment``
+3. Delete ``.build`` directory and run ``./breeze --force-pull-images``
+4. `Clean up docker images <#cleaning-up-the-images>`_
+5. Restart your docker engine and try again
+6. Restart your machine and try again
+7. Remove and re-install Docker CE and try again

-.. image:: images/setup_remote_debugging.png
-    :align: center
-    :alt: Setup remote debugging
+In case the problems are not solved, you can set the VERBOSE variable to "true" (``export VERBOSE="true"``),
+rerun the failing command, copy & paste the output from your terminal, describe the problem and
+post it in the #troubleshooting channel of `Airflow Slack <https://apache-airflow-slack.herokuapp.com/>`_.

-Not that if you are on ``MacOS`` you have to use the real IP address of your host rather than default
-localhost because on MacOS container runs in a virtual machine with different IP address.
-You also have to remember about configuring source code mapping in remote debugging configuration to map -your local sources into the ``/opt/airflow`` location of the sources within the container. +Using Breeze for other tasks +============================ -.. image:: images/source_code_mapping_ide.png - :align: center - :alt: Source code mapping +Running static code checks +-------------------------- +We have a number of static code checks that are run in Travis CI but you can run them locally as well. -Running commands via Airflow Breeze -=================================== +All these tests run in python3.5 environment. Note that the first time you run the checks it might take some +time to rebuild the docker images required to run the tests, but all subsequent runs will be much faster - +the build phase will just check if your code has changed and rebuild as needed. -Running static code checks --------------------------- +The checks below are run in a docker environment, which means that if you run them locally, +they should give the same results as the tests run in TravisCI without special environment preparation. -If you wish to run static code checks inside Docker environment you can do it via -``-S``, ``--static-check`` flags or ``-F``, ``--static-check-all-files``. The former will run appropriate -checks only for files changed and staged locally, the latter will run it on all files. It can take a lot of -time to run check for all files in case of pylint on MacOS due to slow filesystem for Mac OS Docker. -You can add arguments you should pass them after -- as extra arguments. +You run the checks via ``-S``, ``--static-check`` flags or ``-F``, ``--static-check-all-files``. +The former will run appropriate checks only for files changed and staged locally, the latter will run it +on all files. It can take a lot of time to run check for all files in case of pylint on MacOS due to slow +filesystem for Mac OS Docker. You can add arguments you should pass them after -- as extra arguments. You cannot pass ``--files`` flage if you selected ``--static-check-all-files`` option. You can see the list of available static checks via --help flag or use autocomplete. Most notably ``all`` @@ -500,14 +491,13 @@ The documentation is build using ``-O``, ``--build-docs`` command: ./breeze --build-docs - Results of the build can be found in ``docs/_build`` folder. Often errors during documentation generation come from the docstrings of auto-api generated classes. During the docs building auto-api generated files are stored in ``docs/_api`` folder - so that in case of problems with documentation you can find where the problems with documentation originated from. -Running tests -------------- +Running tests directly from host +-------------------------------- If you wish to run tests only and not drop into shell, you can run them by providing -t, --test-target flag. You can add extra nosetest flags after -- in the commandline. @@ -528,6 +518,18 @@ You can also specify individual tests or group of tests: ./breeze --test-target tests.core:TestCore +Pulling the latest images +------------------------- + +Sometimes the image on DockerHub is rebuilt from the scratch. This happens for example when there is a +security update of the python version that all the images are based on. +In this case it is usually faster to pull latest images rather than rebuild them +from the scratch. + +You can do it via ``--force-pull-images`` flag to force pull latest images from DockerHub. 
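For example (a sketch only - combine the flag with whatever other flags you normally use; the ``--build-only`` flag is described in the dependencies section above):

.. code-block:: bash

    # Pull the latest CI images from DockerHub, rebuild the local image and enter the shell
    ./breeze --force-pull-images

    # The same, but only rebuild the image without dropping into the shell
    ./breeze --force-pull-images --build-only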
+ +In the future Breeze will warn you when you are advised to force pull images. + Running commands inside Docker ------------------------------ @@ -555,60 +557,87 @@ after -- as extra arguments. ./breeze --docker-compose pull -- --ignore-pull-failures -Convenience scripts -------------------- +Using your host IDE +=================== -Once you run ./breeze you can also execute various actions via generated convenience scripts +Configuring local virtualenv +---------------------------- -.. code-block:: +In order to use your host IDE (for example IntelliJ's PyCharm/Idea) you need to have virtual environments +setup. Ideally you should have virtualenvs for all python versions that Airflow supports (3.5, 3.6, 3.7). +You can create the virtualenv using ``virtualenvwrapper`` - that will allow you to easily switch between +virtualenvs using workon command and mange your virtual environments more easily. - Enter the environment : ./.build/cmd_run - Run command in the environment : ./.build/cmd_run "[command with args]" [bash options] - Run tests in the environment : ./.build/test_run [test-target] [nosetest options] - Run Docker compose command : ./.build/dc [help/pull/...] [docker-compose options] +Typically creating the environment can be done by: -Keeping images up-to-date -========================= +.. code-block:: bash -Updating dependencies ---------------------- + mkvirtualenv --python=python -If you change apt dependencies in the Dockerfile or change setup.py or -add new apt dependencies or npm dependencies, you have two options how to update the dependencies. +After the virtualenv is created, you must initialize it. Simply enter the environment +(using workon) and once you are in it run: +.. code-block:: bash -* - You can install dependencies inside the container using 'sudo apt install', 'pip install' or 'npm install' - (in airflow/www folder) respectively. This is useful if you want to test somthing quickly while in the - container. However, those changes are not persistent - they will disappear once you - exit the container (except npm dependencies in case your sources are mounted to the container). Therefore - if you want to persist a new dependency you have to follow with the second option. + ./breeze --initialize-local-virtualenv -* - You can add the dependencies to the Dockerfile, setup.py or package.json and rebuild the image. This - should happen automatically if you modify any of setup.py, package.json or update Dockerfile itself. - After you exit the container and re-run `./breeze <./breeze>`_ the Breeze detects changes in dependencies, - ask you to confirm rebuilding of the image and proceed to rebuilding the image if you confirm (or skip it - if you won't confirm). After rebuilding is done, it will drop you to shell. You might also provide - ``--build-only`` flag to only rebuild images and not go into shell - it will then rebuild the image - and will not enter the shell. +Once initialization is done, you should select the virtualenv you initialized as the project's default +virtualenv in your IDE. -Note that during development, changing dependencies in apt-get closer to the top of the Dockerfile -will invalidate cache for most of the image and it will take long time to rebuild the image by breeze. -Therefore it is a recommended practice to add new dependencies closer to the bottom of -Dockerfile during development (to get the new dependencies incrementally added) and only move them to the -top when you are close to finalise the PR and merge the change. 
It's OK for development time to add separate -``apt-get install`` commands similar to those that are already there (but remember to move newly added -dependencies to the appropriate ``apt-get install`` command which is already in the Dockerfile. +Running unit tests via IDE +-------------------------- -Pulling the images ------------------- +After setting it up - you can use the usual "Run Test" option of the IDE and have all the +autocomplete and documentation support from IDE as well as you can debug and click-through +the sources of Airflow - which is very helpful during development. Usually you also can run most +of the unit tests (those that do not require prerequisites) directly from the IDE: + +Running unit tests from IDE is as simple as: + +.. image:: images/running_unittests.png + :align: center + :alt: Running unit tests + +Some of the core tests use dags defined in ``tests/dags`` folder - those tests should have +``AIRFLOW__CORE__UNIT_TEST_MODE`` set to True. You can set it up in your test configuration: + +.. image:: images/airflow_unit_test_mode.png + :align: center + :alt: Airflow Unit test mode + + +You cannot run all the tests this way - only unit tests that do not require external dependencies +such as postgres/mysql/hadoop etc. You should use +`Running tests in Airflow Breeze <#running-tests-in-airflow-breeze>`_ in order to run those tests. You can +still use your IDE to debug those tests as explained in the next chapter. + +Debugging Airflow Breeze Tests in IDE +------------------------------------- + +When you run example DAGs, even if you run them using UnitTests from within IDE, they are run in a separate +container. This makes it a little harder to use with IDE built-in debuggers. +Fortunately for IntelliJ/PyCharm it is fairly easy using remote debugging feature (note that remote +debugging is only available in paid versions of IntelliJ/PyCharm). + +You can read general description `about remote debugging +`_ + +You can setup your remote debug session as follows: + +.. image:: images/setup_remote_debugging.png + :align: center + :alt: Setup remote debugging + +Not that if you are on ``MacOS`` you have to use the real IP address of your host rather than default +localhost because on MacOS container runs in a virtual machine with different IP address. + +You also have to remember about configuring source code mapping in remote debugging configuration to map +your local sources into the ``/opt/airflow`` location of the sources within the container. + +.. image:: images/source_code_mapping_ide.png + :align: center + :alt: Source code mapping -Sometimes the image on DockerHub is rebuilt from the scratch. This happens for example when there is a -security update of the python version that all the images are based on. In this case it is much faster to -pull latest images rather than rebuild them from the scratch. Airflow Breeze will detect such case and -will ask you to confirm to pull and build the image and if you answer OK, it will pull and build the image. -You might also provide ``--force-pull-images`` flag to force pull latest images from DockerHub. Airflow Breeze flags ==================== @@ -783,3 +812,247 @@ These are the current flags of the `./breeze <./breeze>`_ script -c, --cleanup-images Cleanup your local docker cache of the airflow docker images. This will not reclaim space in docker cache. You need to 'docker system prune' (optionally with --all) to reclaim that space. 
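For example, to combine the two cleanup steps mentioned in the ``--cleanup-images`` description above (a sketch only - ``docker system prune`` affects all unused Docker data, not just the Airflow images):

.. code-block:: bash

    # Remove the local Airflow images from the Docker cache
    ./breeze --cleanup-images

    # Reclaim the disk space used by the removed images (add --all to be more aggressive)
    docker system prune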
+
+Internals of Airflow Breeze
+===========================
+
+Airflow Breeze is just a glorified bash script that is a "Swiss-Army-Knife" of Airflow testing. Under the
+hood it uses other scripts that you can also run manually if you have problems with running the Breeze
+environment. This chapter explains the inner details of Breeze.
+
+Available Airflow Breeze environments
+-------------------------------------
+
+You can choose the environment when you run Breeze with the ``--env`` flag.
+Running the default ``docker`` environment takes a considerable amount of resources. You can run a slimmed-down
+version of the environment - just the Apache Airflow container - by choosing the ``bare`` environment instead.
+
+The following environments are available:
+
+ * The ``docker`` environment (default): starts all dependencies required by the full integration test-suite
+   (postgres, mysql, celery, etc.). This option is resource intensive so do not forget to
+   `Stop Breeze <#stopping-breeze>`_ when you are finished. This option is also RAM intensive
+   and can slow down your machine.
+ * The ``kubernetes`` environment: Runs airflow tests within a kubernetes cluster.
+ * The ``bare`` environment: runs airflow in docker without any external dependencies.
+   It will only work for non-dependent tests. You can only run it with sqlite backend.
+
+Running manually static code checks
+-----------------------------------
+
+You can trigger the static checks from the host environment, without entering the Docker container. You
+do that by running the appropriate scripts (the same is done in Travis CI):
+
+* ``_ - checks if all licences are OK
+* ``_ - checks that documentation can be built without warnings
+* ``_ - runs flake8 source code style guide enforcement tool
+* ``_ - runs lint checker for the Dockerfile
+* ``_ - runs mypy type annotation consistency check
+* ``_ - runs pylint static code checker for main files
+* ``_ - runs pylint static code checker for tests
+
+The scripts will ask to rebuild the images if needed.
+
+You can force rebuilding of the images by deleting the ``.build`` directory. This directory keeps cached
+information about the images already built and you can safely delete it if you want to start from scratch.
+
+After the documentation is built, the html results are available in the ``docs/_build/html`` folder.
+This folder is mounted from the host so you can access those files in your host as well.
+
+Running manually static code checks in Docker
+---------------------------------------------
+
+If you are already in the Breeze Docker (by running the ``./breeze`` command) you can also run the
+same static checks from within the container:
+
+* Mypy: ``./scripts/ci/in_container/run_mypy.sh airflow tests``
+* Pylint for main files: ``./scripts/ci/in_container/run_pylint_main.sh``
+* Pylint for test files: ``./scripts/ci/in_container/run_pylint_tests.sh``
+* Flake8: ``./scripts/ci/in_container/run_flake8.sh``
+* Licence check: ``./scripts/ci/in_container/run_check_licence.sh``
+* Documentation: ``./scripts/ci/in_container/run_docs_build.sh``
+
+Running static code analysis for selected files
+-----------------------------------------------
+
+In all static check scripts - both in the container and on the host - you can also pass a module/file path as
+a parameter of the scripts to only check selected modules or files. For example:
+
+In container:
+
+.. code-block::
+
+  ./scripts/ci/in_container/run_pylint.sh ./airflow/example_dags/
+
+or
+
+..
code-block::
+
+  ./scripts/ci/in_container/run_pylint.sh ./airflow/example_dags/test_utils.py
+
+In host:
+
+.. code-block::
+
+  ./scripts/ci/ci_pylint.sh ./airflow/example_dags/
+
+
+.. code-block::
+
+  ./scripts/ci/ci_pylint.sh ./airflow/example_dags/test_utils.py
+
+And similarly for other scripts.
+
+Docker images used by Breeze
+----------------------------
+
+For all development tasks related to integration tests and static code checks we are using Docker
+images that are maintained in DockerHub under the ``apache/airflow`` repository.
+
+There are three images that we currently manage:
+
+* **Slim CI** image that is used for static code checks (size around 500MB) - tag follows the pattern
+  of ``-python-ci-slim`` (for example ``apache/airflow:master-python3.6-ci-slim``).
+  The image is built using the ``Dockerfile``.
+* **Full CI image** that is used for testing - containing a lot more test-related installed software
+  (size around 1GB) - tag follows the pattern of ``-python-ci``
+  (for example ``apache/airflow:master-python3.6-ci``). The image is built using the
+  ``_ dockerfile.
+* **Checklicence image** - an image that is used during licence check using Apache RAT tool. It does not
+  require any of the dependencies that the two CI images need so it is built using different Dockerfile
+  ``_ and only contains Java + Apache RAT tool. The image is
+  labeled with the ``checklicence`` label - for example ``apache/airflow:checklicence``. No versioning is used for
+  the checklicence image.
+
+We also use a very small ``_ dockerfile in order to fix file permissions
+for an obscure permission problem with Docker caching but it is not stored in the ``apache/airflow`` registry.
+
+Before you run tests, enter the environment or run local static checks, the necessary local images should be
+pulled and built from DockerHub. This happens automatically for the test environment but you need to
+manually trigger it for static checks as described in `Building the images <#building-the-images>`_
+and `Force pulling the images <#force-pulling-the-images>`_.
+The static checks will fail and inform you what to do if the image is not yet built.
+
+Note that building the image for the first time pulls the pre-built version of images from DockerHub, which
+might take some time - but this wait time will not repeat for subsequent source code changes.
+However, changes to sensitive files like setup.py or Dockerfile will trigger a rebuild
+that might take more time (but it is highly optimised to only rebuild what's needed).
+
+In most cases re-building an image requires network connectivity (for example to download new
+dependencies). In case you work offline and do not want to rebuild the images when needed - you might set
+the ``ASSUME_NO_TO_ALL_QUESTIONS`` variable to ``true`` as described in the
+`Default behaviour for user interaction <#default-behaviour-for-user-interaction>`_ chapter.
+
+See the `Troubleshooting section <#troubleshooting>`_ for steps you can take to clean up the environment.
+
+Default behaviour for user interaction
+--------------------------------------
+
+Sometimes during the build the user is asked whether to perform an action, skip it, or quit. This happens in
+case of image rebuilding and image removal - they can take a lot of time and they are potentially destructive.
+For automation scripts, you can export one of the three variables to control the default behaviour.
+
+..
code-block::
+
+  export ASSUME_YES_TO_ALL_QUESTIONS="true"
+
+If ``ASSUME_YES_TO_ALL_QUESTIONS`` is set to ``true``, the images will automatically rebuild when needed.
+Images are deleted without asking.
+
+.. code-block::
+
+  export ASSUME_NO_TO_ALL_QUESTIONS="true"
+
+If ``ASSUME_NO_TO_ALL_QUESTIONS`` is set to ``true``, the old images are used even if re-building is needed.
+This is useful when you work offline. Deleting images is aborted.
+
+.. code-block::
+
+  export ASSUME_QUIT_TO_ALL_QUESTIONS="true"
+
+If ``ASSUME_QUIT_TO_ALL_QUESTIONS`` is set to ``true``, the whole script is aborted. Deleting images is aborted.
+
+If more than one variable is set, YES takes precedence over NO which takes precedence over QUIT.
+
+Running the whole suite of tests via scripts
+--------------------------------------------
+
+Running all tests with default settings (python 3.6, sqlite backend, docker environment):
+
+.. code-block::
+
+  ./scripts/ci/local_ci_run_airflow_testing.sh
+
+
+Selecting python version, backend, docker environment:
+
+.. code-block::
+
+  PYTHON_VERSION=3.5 BACKEND=postgres ENV=docker ./scripts/ci/local_ci_run_airflow_testing.sh
+
+
+Running kubernetes tests:
+
+.. code-block::
+
+  KUBERNETES_VERSION=v1.13.0 KUBERNETES_MODE=persistent_mode BACKEND=postgres ENV=kubernetes \
+  ./scripts/ci/local_ci_run_airflow_testing.sh
+
+* PYTHON_VERSION might be one of 3.5/3.6/3.7
+* BACKEND might be one of postgres/sqlite/mysql
+* ENV might be one of docker/kubernetes/bare
+* KUBERNETES_VERSION - required for Kubernetes tests - currently KUBERNETES_VERSION=v1.13.0.
+* KUBERNETES_MODE - mode of kubernetes, one of persistent_mode, git_mode
+
+The available environments are described in ``
+
+Fixing file/directory ownership
+-------------------------------
+
+On Linux there is a problem with propagating ownership of created files (known Docker problem). Basically
+files and directories created in the container are not owned by the host user (but by the root user in our case).
+This might prevent you from switching branches for example if files owned by the root user are created within
+your sources. In case you are on a Linux host and have some files in your sources created by the root user,
+you can fix the ownership of those files by running:
+
+.. code-block::
+
+  ./scripts/ci/local_ci_fix_ownership.sh
+
+Building the images
+-------------------
+
+You can manually trigger building of the local images using:
+
+.. code-block::
+
+  ./scripts/ci/local_ci_build.sh
+
+The scripts that build the images are optimised to minimise the time needed to rebuild the image when
+the source code of Airflow evolves. This means that if you already had the image locally downloaded and built,
+the scripts will determine whether the rebuild is needed in the first place. Then they will make sure that a
+minimal number of steps is executed to rebuild the parts of the image (for example PIP dependencies) that will
+give you an image consistent with the one used during Continuous Integration.
+
+Force pulling the images
+------------------------
+
+You can also force-pull the images before building them locally so that you are sure that you download
+the latest images from the DockerHub repository before building. This can be done with:
+
+.. code-block::
+
+  ./scripts/ci/local_ci_pull_and_build.sh
+
+
+Convenience scripts
+-------------------
+
+Once you run ./breeze you can also execute various actions via generated convenience scripts
+
+..
code-block:: + + Enter the environment : ./.build/cmd_run + Run command in the environment : ./.build/cmd_run "[command with args]" [bash options] + Run tests in the environment : ./.build/test_run [test-target] [nosetest options] + Run Docker compose command : ./.build/dc [help/pull/...] [docker-compose options] diff --git a/CHANGELOG.txt b/CHANGELOG.txt index 9e65636d828c53..348441e0197493 100644 --- a/CHANGELOG.txt +++ b/CHANGELOG.txt @@ -1,5 +1,96 @@ -Airflow 1.10.4, - 2019-08-06 ----------------------------- +Airflow 1.10.5, 2019-09-04 +-------------------------- + +New Features +"""""""""""" +- [AIRFLOW-1498] Add feature for users to add Google Analytics to Airflow UI (#5850) +- [AIRFLOW-4074] Add option to add labels to Dataproc jobs (#5606) +- [AIRFLOW-4846] Allow specification of an existing secret containing git credentials for init containers (#5475) + +Improvements +"""""""""""" +- [AIRFLOW-5335] Update GCSHook methods so they need min IAM perms (#5939) +- [AIRFLOW-2692] Allow AWS Batch Operator to use templates in job_name parameter (#3557) +- [AIRFLOW-4768] Add Timeout parameter in example_gcp_video_intelligence (#5862) +- [AIRFLOW-5165] Make Dataproc highly available (#5781) +- [AIRFLOW-5139] Allow custom ES configs (#5760) +- [AIRFLOW-5340] Fix GCP DLP example (#594) +- [AIRFLOW-5211] Add pass_value to template_fields BigQueryValueCheckOperator (#5816) +- [AIRFLOW-5113] Support icon url in slack web hook (#5724) +- [AIRFLOW-4230] bigquery schema update options should be a list (#5766) +- [AIRFLOW-1523] Clicking on Graph View should display related DAG run (#5866) +- [AIRFLOW-5027] Generalized CloudWatch log grabbing for ECS and SageMaker operators (#5645) +- [AIRFLOW-5244] Add all possible themes to default_webserver_config.py (#5849) +- [AIRFLOW-5245] Add more metrics around the scheduler (#5853) +- [AIRFLOW-5048] Improve display of Kubernetes resources (#5665) +- [AIRFLOW-5284] Replace deprecated log.warn by log.warning (#5881) +- [AIRFLOW-5276] Remove unused helpers from airflow.utils.helpers (#5878) +- [AIRFLOW-4316] Support setting kubernetes_environment_variables config section from env var (#5668) + +Bug fixes +""""""""" +- [AIRFLOW-5168] Fix Dataproc operators that failed in 1.10.4 (#5928) +- [AIRFLOW-5136] Fix Bug with Incorrect template_fields in DataProc{*} Operators (#5751) +- [AIRFLOW-5169] Pass GCP Project ID explicitly to StorageClient in GCSHook (#5783) +- [AIRFLOW-5302] Fix bug in none_skipped Trigger Rule (#5902) +- [AIRFLOW-5350] Fix bug in the num_retires field in BigQueryHook (#5955) +- [AIRFLOW-5145] Fix rbac ui presents false choice to encrypt or not encrypt variable values (#5761) +- [AIRFLOW-5104] Set default schedule for GCP Transfer operators (#5726) +- [AIRFLOW-4462] Use datetime2 column types when using MSSQL backend (#5707) +- [AIRFLOW-5282] Add default timeout on kubeclient & catch HTTPError (#5880) +- [AIRFLOW-5315] TaskInstance not updating from DB when user changes executor_config (#5926) +- [AIRFLOW-4013] Mark success/failed is picking all execution date (#5616) +- [AIRFLOW-5152] Fix autodetect default value in GoogleCloudStorageToBigQueryOperator(#5771) +- [AIRFLOW-5100] Airflow scheduler does not respect safe mode setting (#5757) +- [AIRFLOW-4763] Allow list in DockerOperator.command (#5408) +- [AIRFLOW-5260] Allow empty uri arguments in connection strings (#5855) +- [AIRFLOW-5257] Fix ElasticSearch log handler errors when attemping to close logs (#5863) +- [AIRFLOW-1772] Google Updated Sensor doesnt work with CRON expressions (#5730) +- 
[AIRFLOW-5085] When you run kubernetes git-sync test from TAG, it fails (#5699) +- [AIRFLOW-5258] ElasticSearch log handler, has 2 times of hours (%H and %I) in _clean_execution_dat (#5864) +- [AIRFLOW-5348] Escape Label in deprecated chart view when set via JS (#5952) +- [AIRFLOW-5357] Fix Content-Type for exported variables.json file (#5963) +- [AIRFLOW-5109] Fix process races when killing processes (#5721) +- [AIRFLOW-5240] Latest version of Kombu is breaking airflow for py2 + +Misc/Internal +""""""""""""" +- [AIRFLOW-5111] Remove apt-get upgrade from the Dockerfile (#5722) +- [AIRFLOW-5209] Fix Documentation build (#5814) +- [AIRFLOW-5083] Check licence image building can be faster and moved to before-install (#5695) +- [AIRFLOW-5119] Cron job should always rebuild everything from scratch (#5733) +- [AIRFLOW-5108] In the CI local environment long-running kerberos might fail sometimes (#5719) +- [AIRFLOW-5092] Latest python image should be pulled locally in force_pull_and_build (#5705) +- [AIRFLOW-5225] Consistent licences can be added automatically for all JS files (#5827) +- [AIRFLOW-5229] Add licence to all other file types (#5831) +- [AIRFLOW-5227] Consistent licences for all .sql files (#5829) +- [AIRFLOW-5161] Add pre-commit hooks to run static checks for only changed files (#5777) +- [AIRFLOW-5159] Optimise checklicence image build (do not build if not needed) (#5774) +- [AIRFLOW-5263] Show diff on failure of pre-commit checks (#5869) +- [AIRFLOW-5204] Shell files should be checked with shellcheck and have identical licence (#5807) +- [AIRFLOW-5233] Check for consistency in whitespace (tabs/eols) and common problems (#5835) +- [AIRFLOW-5247] Getting all dependencies from NPM can be moved up in Dockerfile (#5870) +- [AIRFLOW-5143] Corrupted rat.jar became part of the Docker image (#5759) +- [AIRFLOW-5226] Consistent licences for all html JINJA templates (#5828) +- [AIRFLOW-5051] Coverage is not properly reported in the new CI system (#5732) +- [AIRFLOW-5239] Small typo and incorrect tests in CONTRIBUTING.md (#5844) +- [AIRFLOW-5287] Checklicence base image is not pulled (#5886) +- [AIRFLOW-5301] Some not-yet-available files from breeze are committed to master (#5901) +- [AIRFLOW-5285] Pre-commit pylint runs over todo files (#5884) +- [AIRFLOW-5288] Temporary container for static checks should be auto-removed (#5887) +- [AIRFLOW-5326] Fix teething proglems for Airflow breeze (#5933) +- [AIRFLOW-5206] All .md files should have all common licence, TOC (where applicable) (#5809) +- [AIRFLOW-5329] Easy way to add local files to docker (#5933) +- [AIRFLOW-4027] Make experimental api tests more stateless (#4854) + +Doc-only changes +"""""""""""""""" +- [AIRFLOW-XXX] Fixed Azkaban link (#5865) +- [AIRFLOW-XXX] Remove duplicate lines from CONTRIBUTING.md (#5830) +- [AIRFLOW-XXX] Fix incorrect docstring parameter in SchedulerJob (#5729) + +Airflow 1.10.4, 2019-08-06 +-------------------------- New Features """""""""""" diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index bef22fe62ff744..93306591d7185c 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -32,19 +32,9 @@ little bit helps, and credit will always be given. 
- [Improve Documentation](#improve-documentation)
 - [Submit Feedback](#submit-feedback)
 - [Documentation](#documentation)
-- [Local virtualenv development environment](#local-virtualenv-development-environment)
-  - [Installation](#installation)
-  - [Running individual tests](#running-individual-tests)
-  - [Running tests directly from the IDE](#running-tests-directly-from-the-ide)
-- [Integration test development environment](#integration-test-development-environment)
-  - [Prerequisites](#prerequisites)
-  - [Using the Docker Compose environment](#using-the-docker-compose-environment)
-  - [Running static code analysis](#running-static-code-analysis)
-  - [Docker images](#docker-images)
-  - [Default behaviour for user interaction](#default-behaviour-for-user-interaction)
-  - [Local Docker Compose scripts](#local-docker-compose-scripts)
-  - [Cleaning up cached Docker images/containers](#cleaning-up-cached-docker-imagescontainers)
-  - [Troubleshooting](#troubleshooting)
+- [Development environments](#development-environments)
+  - [Local virtualenv development environment](#local-virtualenv-development-environment)
+  - [Breeze development environment](#breeze-development-environment)
 - [Pylint checks](#pylint-checks)
 - [Pre-commit hooks](#pre-commit-hooks)
   - [Installing pre-commit hooks](#installing-pre-commit-hooks)
@@ -126,17 +116,25 @@ cd docs
 ./start_doc_server.sh
 ```

-# Local virtualenv development environment
+# Development environments

-When you develop Airflow you can create local virtualenv with all requirements required by Airflow.
+There are two development environments you can use to develop Apache Airflow:

-Advantage of local installation is that everything works locally, you do not have to enter Docker/container
-environment and you can easily debug the code locally. You can also have access to python virtualenv that
-contains all the necessary requirements and use it in your local IDE - this aids autocompletion, and
-running tests directly from within the IDE.
+The first is the local virtualenv development environment that can be used with your IDE and to
+run basic unit tests. The second is the Docker-based Breeze development environment described below.
+
+## Local virtualenv development environment
+
+All details about using and running the local virtualenv environment for Airflow can be found
+in [LOCAL_VIRTUALENV.rst](LOCAL_VIRTUALENV.rst).

 It is **STRONGLY** encouraged to also install and use [Pre commit hooks](#pre-commit-hooks) for your local
-development environment. They will speed up your development cycle speed a lot.
+development environment. Pre-commit hooks can speed up your development cycle a lot.
+
+The advantage of local installation is that you have packages installed locally. You do not have to
+enter the Docker/container environment and you can easily debug the code locally.
+You can also have access to a python virtualenv that contains the necessary requirements
+and use it in your local IDE - this aids autocompletion and running tests directly from within the IDE.

 The disadvantage is that you have to maintain your dependencies and local environment consistent with
 other development environments that you have on your local machine.
@@ -146,205 +144,29 @@ external components - mysql, postgres database, hadoop, mongo, cassandra, redis

 The tests in Airflow are a mixture of unit and integration tests and some of them require those
 components to be setup. Only real unit tests can be run by default in local environment.

-If you want to run integration tests, you need to configure and install the dependencies on your own.
+If you want to run integration tests, you can technically configure and install the dependencies on your own, +but it is usually complex and it's better to use +[Breeze development environment](#breeze-development-environment) instead. -It's also very difficult to make sure that your local environment is consistent with other environments. -This can often lead to "works for me" syndrome. It's better to use the Docker Compose integration test -environment in case you want reproducible environment consistent with other people. +Yet another disdvantage of using local virtualenv is that it is very difficult to make sure that your +local environment is consistent with other developer's environments. This can often lead to "works for me" +syndrome. The Breeze development environment provides reproducible environment that is +consistent with other developers. -## Installation - -Install Python (3.5 or 3.6), MySQL, and libxml by using system-level package -managers like yum, apt-get for Linux, or Homebrew for Mac OS at first. -Refer to the [Dockerfile](Dockerfile) for a comprehensive list of required packages. - -In order to use your IDE you need you can use the virtual environment. Ideally -you should setup virtualenv for all python versions that Airflow supports (3.5, 3.6). -An easy way to create the virtualenv is to use -[virtualenvwrapper](https://virtualenvwrapper.readthedocs.io/en/latest/) - it allows -you to easily switch between virtualenvs using `workon` command and mange -your virtual environments more easily. Typically creating the environment can be done by: - -``` -mkvirtualenv --python=python -``` - -Then you need to install python PIP requirements. Typically it can be done with: -`pip install -e ".[devel]"`. - -Note - if you have trouble installing mysql client on MacOS and you have an error similar to -``` -ld: library not found for -lssl -``` - -you should set LIBRARY_PATH before running `pip install`: - -``` -export LIBRARY_PATH=$LIBRARY_PATH:/usr/local/opt/openssl/lib/ -``` - -After creating the virtualenv, run this command to create the Airflow sqlite database: -``` -airflow db init -``` - -This can be automated if you do it within a virtualenv. -The [./breeze](./breeze) script has a flag -(-e or --initialize-local-virtualenv) that automatically installs dependencies -in the virtualenv you are logged in and resets the sqlite database as described below. - -After the virtualenv is created, you must initialize it. Simply enter the environment -(using `workon`) and once you are in it run: -``` -./breeze --initialize-local-virtualenv -```` - -Once initialization is done, you should select the virtualenv you initialized as the -project's default virtualenv in your IDE and run tests efficiently. - -After setting it up - you can use the usual "Run Test" option of the IDE and have -the autocomplete and documentation support from IDE as well as you can -debug and view the sources of Airflow - which is very helpful during -development. - -## Running individual tests - -Once you activate virtualenv (or enter docker container) as described below you should be able to run -`run-tests` at will (it is in the path in Docker environment but you need to prepend it with `./` in local -virtualenv (`./run-tests`). - -Note that this script has several flags that can be useful for your testing. - -```text -Usage: run-tests [FLAGS] [TESTS_TO_RUN] -- - -Runs tests specified (or all tests if no tests are specified) - -Flags: - --h, --help - Shows this help message. 
- --i, --with-db-init - Forces database initialization before tests - --s, --nocapture - Don't capture stdout when running the tests. This is useful if you are - debugging with ipdb and want to drop into console with it - by adding this line to source code: - - import ipdb; ipdb.set_trace() - --v, --verbose - Verbose output showing coloured output of tests being run and summary - of the tests - in a manner similar to the tests run in the CI environment. -``` - -You can pass extra parameters to nose, by adding nose arguments after `--` - -For example, in order to just execute the "core" unit tests and add ipdb set_trace method, you can -run the following command: - -```bash -./run-tests tests.core:TestCore --nocapture --verbose -``` - -or a single test method without colors or debug logs: - -```bash -./run-tests tests.core:TestCore.test_check_operators -``` -Note that `./run_tests` script runs tests but the first time it runs, it performs database initialisation. -If you run further tests without leaving the environment, the database will not be initialized, but you -can always force database initialization with `--with-db-init` (`-i`) switch. The scripts will -inform you what you can do when they are run. - -## Running tests directly from the IDE - -Once you configure your tests to use the virtualenv you created. running tests -from IDE is as simple as: - -![Run unittests](images/run_unittests.png) - -Note that while most of the tests are typical "unit" tests that do not -require external components, there are a number of tests that are more of -"integration" or even "system" tests (depending on the convention you use). -Those tests interact with external components. For those tests -you need to run complete Docker Compose - base environment below. - -# Integration test development environment - -This is the environment that is used during CI builds on Travis CI. We have scripts to reproduce the -Travis environment and you can enter the environment and run it locally. - -The scripts used by Travis CI run also image builds which make the images contain all the sources. You can -see which scripts are used in [.travis.yml](.travis.yml) file. - -## Prerequisites - -**Docker** - -You need to have [Docker CE](https://docs.docker.com/get-started/) installed. - -IMPORTANT!!! : Mac OS Docker default Disk size settings - -When you develop on Mac OS you usually have not enough disk space for Docker if you start using it seriously. -You should increase disk space available before starting to work with the environment. Usually you have weird -problems with docker containers when you run out of Disk space. It might not be obvious that space is an -issue. If you get into weird behaviour try [Cleaning Up Docker](#cleaning-up-cached-docker-imagescontainers) - -See [Docker for Mac - Space](https://docs.docker.com/docker-for-mac/space/) for details of increasing -disk space available for Docker on Mac. - -At least 128 GB of Disk space is recommended. You can also get by with smaller space but you should more often -clean the docker disk space periodically. - -**Getopt and coreutils** - -If you are on MacOS: - -* Run `brew install gnu-getopt coreutils` (if you use brew, or use equivalent command for ports) -* Then (with brew) link the gnu-getopt to become default as suggested by brew. - -If you use bash, you should run this command: - -```bash -echo 'export PATH="/usr/local/opt/gnu-getopt/bin:$PATH"' >> ~/.bash_profile -. 
~/.bash_profile -``` - -If you use zsh, you should run this command: - -```bash -echo 'export PATH="/usr/local/opt/gnu-getopt/bin:$PATH"' >> ~/.zprofile -. ~/.zprofile -``` - -if you use zsh - -* Login and logout afterwards - -If you are on Linux: - -* Run `apt install util-linux coreutils` or equivalent if your system is not Debian-based. - -## Using the Docker Compose environment - -Airflow has a super-easy-to-use integration test environment managed via -[Docker Compose](https://docs.docker.com/compose/) and used by Airflow's CI Travis tests. - -It's called **Airflow Breeze** as in "_It's a breeze to develop Airflow_" +## Breeze development environment All details about using and running Airflow Breeze can be found in [BREEZE.rst](BREEZE.rst) -The advantage of the Airflow Breeze Integration Tests environment is that it is a full environment +Using Breeze locally is easy. It's called **Airflow Breeze** as in "_It's a Breeze to develop Airflow_" + +The advantage of the Airflow Breeze environment is that it is a full environment including external components - mysql database, hadoop, mongo, cassandra, redis etc. Some of the tests in -Airflow require those external components. Integration test environment provides preconfigured environment -where all those services are running and can be used by tests automatically. +Airflow require those external components. The Breeze environment provides preconfigured docker compose +environment with all those services are running and can be used by tests automatically. Another advantage is that the Airflow Breeze environment is pretty much the same as used in [Travis CI](https://travis-ci.com/) automated builds, and if the tests run in -your local environment they will most likely work on Travis as well. +your local environment they will most likely work in CI as well. The disadvantage of Airflow Breeze is that it is fairly complex and requires time to setup. However it is all automated and easy to setup. Another disadvantage is that it takes a lot of space in your local Docker cache. @@ -353,265 +175,13 @@ around 3GB in total. Building and preparing the environment by default uses pre- (requires time to download and extract those GB of images) and less than 10 minutes per python version to build. -Note that those images are not supposed to be used in production environments. They are optimised -for repeatability of tests, maintainability and speed of building rather than performance - -### Running individual tests within the container - -Once you are inside the environment you can run individual tests as described in -[Running individual tests](#running-individual-tests). - -## Running static code analysis - -We have a number of static code checks that are run in Travis CI but you can run them locally as well. -All the scripts are available in [scripts/ci](scripts/ci) folder. - -All these tests run in python3.6 environment. Note that the first time you run the checks it might take some -time to rebuild the docker images required to run the tests, but all subsequent runs will be much faster - -the build phase will just check if your code has changed and rebuild as needed. - -The checks below are run in a docker environment, which means that if you run them locally, -they should give the same results as the tests run in TravisCI without special environment preparation. - -#### Running static code analysis from the host - -You can trigger the static checks from the host environment, without entering Docker container. 
You -do that by running appropriate scripts (The same is done in TravisCI) - -* [scripts/ci/ci_check_license.sh](scripts/ci/ci_check_license.sh) - checks if all licences are present in the sources -* [scripts/ci/ci_docs.sh](scripts/ci/ci_docs.sh) - checks that documentation can be built without warnings. -* [scripts/ci/ci_flake8.sh](scripts/ci/ci_flake8.sh) - runs flake8 source code style guide enforcement tool -* [scripts/ci/ci_lint_dockerfile.sh](scripts/ci/ci_lint_dockerfile.sh) - runs lint checker for the Dockerfile -* [scripts/ci/ci_mypy.sh](scripts/ci/ci_mypy.sh) - runs mypy type annotation consistency check -* [scripts/ci/ci_pylint_main.sh](scripts/ci/ci_pylint_main.sh) - runs pylint static code checker for main files -* [scripts/ci/ci_pylint_tests.sh](scripts/ci/ci_pylint_tests.sh) - runs pylint static code checker for tests - -The scripts will fail by default when image rebuild is needed (for example when dependencies change) -and provide instruction on how to rebuild the images. You can control the default behaviour as explained in -[Default behaviour for user interaction](#default-behaviour-for-user-interaction) - -You can force rebuilding of the images by deleting [.build](./build) directory. This directory keeps cached -information about the images already built and you can safely delete it if you want to start from the scratch. - -After Documentation is built, the html results are available in [docs/_build/html](docs/_build/html) folder. -This folder is mounted from the host so you can access those files in your host as well. - -#### Running static code analysis in the docker compose environment - -If you are already in the [Docker Compose Environment](#entering-bash-shell-in-docker-compose-environment) -you can also run the same static checks from within container: - -* Mypy: `./scripts/ci/in_container/run_mypy.sh airflow tests` -* Pylint for main files: `./scripts/ci/in_container/run_pylint_main.sh` -* Pylint for test files: `./scripts/ci/in_container/run_pylint_tests.sh` -* Flake8: `./scripts/ci/in_container/run_flake8.sh` -* Licence check: `./scripts/ci/in_container/run_check_licence.sh` -* Documentation: `./scripts/ci/in_container/run_docs_build.sh` - -#### Running static code analysis on selected files/modules - -In all static check scripts - both in container and in the host you can also pass module/file path as -parameters of the scripts to only check selected modules or files. For example: - -In container: - -`./scripts/ci/in_container/run_pylint.sh ./airflow/example_dags/` - -or - -`./scripts/ci/in_container/run_pylint.sh ./airflow/example_dags/test_utils.py` - -In host: - -`./scripts/ci/ci_pylint.sh ./airflow/example_dags/` - -or - -`./scripts/ci/ci_pylint.sh ./airflow/example_dags/test_utils.py` - -And similarly for other scripts. - -## Docker images - -For all development tasks related integration tests and static code checks we are using Docker -images that are maintained in Dockerhub under `apache/airflow` repository. - -There are three images that we currently manage: - -* Slim CI image that is used for static code checks (size around 500MB) - tag follows the pattern - of `-python-ci-slim` (for example `master-python3.6-ci-slim`). The image is built - using the [Dockerfile](Dockerfile) dockerfile. -* Full CI image that is used for testing - containing a lot more test-related installed software - (size around 1GB) - tag follows the pattern of `-python-ci` - (for example `master-python3.6-ci`). The image is built using the [Dockerfile](Dockerfile) dockerfile. 
-* Checklicence image - an image that is used during licence check using Apache RAT tool. It does not - require any of the dependencies that the two CI images need so it is built using different Dockerfile - [Dockerfile-checklicence](Dockerfile-checklicence) and only contains Java + Apache RAT tool. The image is - labeled with `checklicence` image. - -We also use a very small [Dockerfile-context](Dockerfile-context) dockerfile in order to fix file permissions -for an obscure permission problem with Docker caching but it is not stored in `apache/airflow` registry. - -Before you run tests or enter environment or run local static checks, the necessary local images should be -pulled and built from DockerHub. This happens automatically for the test environment but you need to -manually trigger it for static checks as described in -[Building the images](#building-the-images) and -[Force pulling and building the images](#force-pulling-the-images)). The static checks will fail and inform -what to do if the image is not yet built. - -Note that building image first time pulls the pre-built version of images from Dockerhub might take a bit -of time - but this wait-time will not repeat for any subsequent source code change. -However, changes to sensitive files like setup.py or Dockerfile will trigger a rebuild -that might take more time (but it is highly optimised to only rebuild what's needed) - -In most cases re-building an image requires connectivity to network (for example to download new -dependencies). In case you work offline and do not want to rebuild the images when needed - you might set -`ASSUME_NO_TO_ALL_QUESTIONS` variable to `true` as described in the -[Default behaviour for user interaction](#default-behaviour-for-user-interaction) chapter. - -See [Troubleshooting section](#troubleshooting) for steps you can make to clean the environment. - -## Default behaviour for user interaction - -Sometimes during the build user is asked whether to perform an action, skip it, or quit. This happens in case -of image rebuilding and image removal - they can take a lot of time and they are potentially destructive. -For automation scripts, you can export one of the three variables to control the default behaviour. -``` -export ASSUME_YES_TO_ALL_QUESTIONS="true" -``` -If `ASSUME_YES_TO_ALL_QUESTIONS` is set to `true`, the images will automatically rebuild when needed. -Images are deleted without asking. - -``` -export ASSUME_NO_TO_ALL_QUESTIONS="true" -``` -If `ASSUME_NO_TO_ALL_QUESTIONS` is set to `true`, the old images are used even if re-building is needed. -This is useful when you work offline. Deleting images is aborted. - -``` -export ASSUME_QUIT_TO_ALL_QUESTIONS="true" -``` -If `ASSUME_QUIT_TO_ALL_QUESTIONS` is set to `true`, the whole script is aborted. Deleting images is aborted. - -If more than one variable is set, YES takes precedence over NO which take precedence over QUIT. - -## Local Docker Compose scripts - -For your convenience, there are scripts that can be used in local development -- where local host sources are mounted to within the docker container. -Those "local" scripts starts with "local_" prefix in [scripts/ci](scripts/ci) folder and -they run Docker-Compose environment with relevant backends (mysql/postgres) -and additional services started. 
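As a quick illustration of how the interaction variables above combine with the local scripts (a minimal sketch, assuming the scripts honour the `ASSUME_*` variables exactly as described in this section):

```bash
# Let the scripts rebuild or delete images without prompting
ASSUME_YES_TO_ALL_QUESTIONS="true" ./scripts/ci/local_ci_build.sh

# When working offline, keep using the already built images even if a rebuild is suggested
ASSUME_NO_TO_ALL_QUESTIONS="true" ./scripts/ci/local_ci_run_airflow_testing.sh
```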
- -### Running the whole suite of tests - -Running all tests with default settings (python 3.6, sqlite backend, docker environment): - -```bash -./scripts/ci/local_ci_run_airflow_testing.sh -``` - -Selecting python version, backend, docker environment: - -```bash -PYTHON_VERSION=3.5 BACKEND=postgres ENV=docker ./scripts/ci/local_ci_run_airflow_testing.sh -``` - -Running kubernetes tests: -```bash -KUBERNETES_VERSION==v1.13.0 KUBERNETES_MODE=persistent_mode BACKEND=postgres ENV=kubernetes \ - ./scripts/ci/local_ci_run_airflow_testing.sh -``` - -* PYTHON_VERSION might be one of 3.5/3.6/3.7 -* BACKEND might be one of postgres/sqlite/mysql -* ENV might be one of docker/kubernetes/bare -* KUBERNETES_VERSION - required for Kubernetes tests - currently KUBERNETES_VERSION=v1.13.0. -* KUBERNETES_MODE - mode of kubernetes, one of persistent_mode, git_mode - -The following environments are possible: - - * The `docker` environment (default): starts all dependencies required by full integration test-suite - (postgres, mysql, celery, etc.). This option is resource intensive so do not forget to - [Stop environment](#stopping-the-environment) when you are finished. This option is also RAM intensive - and can slow down your machine. - * The `kubernetes` environment: Runs airflow tests within a kubernetes cluster (requires KUBERNETES_VERSION - and KUBERNETES_MODE variables). - * The `bare` environment: runs airflow in docker without any external dependencies. - It will only work for non-dependent tests. You can only run it with sqlite backend. You can only - enter the bare environment with `local_ci_enter_environment.sh` and run tests manually, you cannot execute - `local_ci_run_airflow_testing.sh` with it. - -Note: The Kubernetes environment will require setting up minikube/kubernetes so it -might require some host-network configuration. - -### Stopping the environment - -Docker-compose environment starts a number of docker containers and keep them running. -You can tear them down by running -[/scripts/ci/local_ci_stop_environment.sh](scripts/ci/local_ci_stop_environment.sh) - -### Fixing file/directory ownership - -On Linux there is a problem with propagating ownership of created files (known Docker problem). Basically -files and directories created in container are not owned by the host user (but by the root user in our case). -This might prevent you from switching branches for example if files owned by root user are created within -your sources. In case you are on Linux host and haa some files in your sources created by the root user, -you can fix the ownership of those files by running -[scripts/ci/local_ci_fix_ownership.sh](scripts/ci/local_ci_fix_ownership.sh) script. - -### Building the images - -You can manually trigger building of the local images using -[scripts/ci/local_ci_build.sh](scripts/ci/local_ci_build.sh). - -The scripts that build the images are optimised to minimise the time needed to rebuild the image when -the source code of Airflow evolves. This means that if you already had the image locally downloaded and built, -the scripts will determine, the rebuild is needed in the first place. Then it will make sure that minimal -number of steps are executed to rebuild the parts of image (for example PIP dependencies) that will give -you an image consistent with the one used during Continuous Integration. 
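Putting those scripts together, a typical local session could look like the sketch below (script names and variables are the ones documented above; the python version and backend values are just examples):

```bash
# Build or refresh the local images first
./scripts/ci/local_ci_build.sh

# Run the test suite with python 3.6 and the postgres backend
PYTHON_VERSION=3.6 BACKEND=postgres ENV=docker ./scripts/ci/local_ci_run_airflow_testing.sh

# Tear the containers down when you are finished
./scripts/ci/local_ci_stop_environment.sh

# On Linux, hand files created by the containers back to your user if needed
./scripts/ci/local_ci_fix_ownership.sh
```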
- -### Force pulling the images - -You can also force-pull the images before building them locally so that you are sure that you download -latest images from DockerHub repository before building. This can be done with -[scripts/ci/local_ci_pull_and_build.sh](scripts/ci/local_ci_pull_and_build.sh) script. - -## Cleaning up cached Docker images/containers - -Note that you might need to cleanup your Docker environment occasionally. The images are quite big -(1.5GB for both images needed for static code analysis and CI tests). And if you often rebuild/update -images you might end up with some unused image data. - -Cleanup can be performed with `docker system prune` command. - -If you run into disk space errors, we recommend you prune your docker images using the -`docker system prune --all` command. You might need to -[Stop the environment](#stopping-the-environment) or restart the docker engine before running this command. - -You can check if your docker is clean by running `docker images --all` and `docker ps --all` - both -should return an empty list of images and containers respectively. - -If you are on Mac OS and you end up with not enough disk space for Docker you should increase disk space -available for Docker. See [Docker for Mac - Space](https://docs.docker.com/docker-for-mac/space/) for details. - -## Troubleshooting - -If you are having problems with the Docker Compose environment - try the following (after each step you -can check if your problem is fixed) - -1. Check if you have [enough disk space](#prerequisites) in Docker if you are on MacOS. -2. [Stop the environment](#stopping-the-environment) -3. Delete [.build](.build) and [Force pull the images](#force-pulling-the-images) -4. [Clean Up Docker engine](#cleaning-up-cached-docker-imagescontainers) -5. [Fix file/directory ownership](#fixing-filedirectory-ownership) -6. Restart your docker engine and try again -7. Restart your machine and try again -8. Remove and re-install Docker CE, then start with [force pulling the images](#force-pulling-the-images) +The environment for Breeze runs in the background taking precious resources - disk space and CPU. You +can stop the environment manually after you use it or even use `bare` environment to decrease resource +usage. -In case the problems are not solved, you can set VERBOSE variable to "true" (`export VERBOSE="true"`) -and rerun failing command, and copy & paste the output from your terminal, describe the problem and -post it in [Airflow Slack](https://apache-airflow-slack.herokuapp.com/) #troubleshooting channel. +Note that the CI images are not supposed to be used in production environments. They are optimised +for repeatability of tests, maintainability and speed of building rather than production performance. +The production images are not yet available (but they will be). # Pylint checks @@ -629,8 +199,12 @@ as follows: 3) Fix all the issues reported by pylint 4) Re-run [scripts/ci/ci_pylint.sh](scripts/ci/ci_pylint.sh) 5) If you see "success" - submit PR following [Pull Request guidelines](#pull-request-guidelines) +6) You can refresh periodically [scripts/ci/pylint_todo.txt](scripts/ci/pylint_todo.txt) file. + You can do it by running + [scripts/ci/ci_refresh_pylint_todo.sh](scripts/ci/ci_refresh_pylint_todo.sh). + This can take quite some time (especially on MacOS)! 
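As a quick sketch of the loop described in the steps above (the commands are the ones referenced in this section; `./airflow/example_dags/` is only an example path):

```bash
# Re-run the pylint check for the module you removed from the todo list
./scripts/ci/ci_pylint.sh ./airflow/example_dags/

# Periodically regenerate the todo list - this can take a long time, especially on MacOS
./scripts/ci/ci_refresh_pylint_todo.sh
```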
-There are following guidelines when fixing pylint errors: +You can follow these guidelines when fixing pylint errors:
* Ideally fix the errors rather than disable pylint checks - often you can easily refactor the code (IntelliJ/PyCharm might be helpful when extracting methods in complex code or moving methods around)
@@ -670,7 +244,7 @@ dependencies are automatically installed). You can also install pre-commit manua
The pre-commit hooks require Docker Engine to be configured as the static checks are executed in the docker environment. You should build the images locally before installing pre-commit checks as -described in [Building the images](#building-the-images). In case you do not have your local images built +described in [Breeze](BREEZE.rst). In case you do not have your local images built the pre-commit hooks fail and provide instructions on what needs to be done.
## Installing pre-commit hooks
@@ -695,23 +269,23 @@ pre-commit install --help
## Docker images for pre-commit hooks
-Before running the pre-commit hooks you must first build the docker images locally as described in -[Building the images](#building-the-images) chapter. +Before running the pre-commit hooks you must first build the docker images as described in +[BREEZE](BREEZE.rst).
Sometimes your image is outdated (when dependencies change) and needs to be rebuilt because some -dependencies have been changed. In such case the docker build pre-commit will fail and inform -you that you should rebuild the image with REBUILD="true" environment variable set. +dependencies have been changed. In such a case the docker build pre-commit will inform +you that you should rebuild the image.
## Prerequisites for pre-commit hooks
The pre-commit hooks use several external linters that need to be installed before pre-commit is run. -Most of the linters are installed by running `pip install -e .[devel]` in the airflow sources as they -are python-only, however there are some that should be installed locally using different methods. -In Linux you typically install them with `sudo apt install` on MacOS with `brew install`. +Each of the checks installs its own environment, so you do not need to install those yourself, but there are some +checks that require locally installed binaries. On Linux you typically install them +with `sudo apt install`, on MacOS with `brew install`.
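Once the prerequisites below are in place and the hooks are installed, you can also trigger the checks manually at any time; this is standard pre-commit usage rather than anything Airflow-specific (the file path is just an example):

```bash
# Run every configured hook against the whole repository
pre-commit run --all-files

# Run the hooks only for selected files, e.g. before committing a change to the CLI
pre-commit run --files airflow/bin/cli.py
```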
The current list of prerequisites: -* xmllint: Linux - install via `sudo apt install xmllint`, MacOS - install via `brew install xmllint` +* `xmllint`: Linux - install via `sudo apt install xmllint`, MacOS - install via `brew install xmllint` ## Pre-commit hooks installed diff --git a/Dockerfile-checklicence b/Dockerfile-checklicence index 771fd2372ff8c7..171006be7c6dd8 100644 --- a/Dockerfile-checklicence +++ b/Dockerfile-checklicence @@ -53,6 +53,20 @@ RUN echo "Downloading RAT from ${RAT_URL} to ${RAT_JAR}" \ && jar -tf "${RAT_JAR}" \ && md5sum -c <<<"$(cat "${RAT_JAR_MD5}") ${RAT_JAR}" +ARG DUMB_INIT_VERSION="1.2.2" +ENV DUMB_INIT_VERSION="${DUMB_INIT_VERSION}" \ + DUMB_INIT_FILE="/tmp/dumb-init_${DUMB_INIT_VERSION}_amd64.deb" \ + DUMB_INIT_FILE_SHA256SUMS="/tmp/sha256sums" \ + DUMB_INIT_URL="https://github.com/Yelp/dumb-init/releases/download/v${DUMB_INIT_VERSION}/dumb-init_${DUMB_INIT_VERSION}_amd64.deb" \ + DUMB_INIT_URL_SHA256SUMS="https://github.com/Yelp/dumb-init/releases/download/v${DUMB_INIT_VERSION}/sha256sums" + +WORKDIR /tmp +RUN echo "Downloading dumb-init from ${DUMB_INIT_URL} to ${DUMB_INIT_FILE}" \ + && curl -sL "${DUMB_INIT_URL}" > "${DUMB_INIT_FILE}" \ + && curl -sL "${DUMB_INIT_URL_SHA256SUMS}" > "${DUMB_INIT_FILE_SHA256SUMS}" \ + && sha256sum --check --ignore-missing "${DUMB_INIT_FILE_SHA256SUMS}" \ + && dpkg -i "${DUMB_INIT_FILE}" + ARG AIRFLOW_USER=airflow ENV AIRFLOW_USER=${AIRFLOW_USER} diff --git a/LOCAL_VIRTUALENV.rst b/LOCAL_VIRTUALENV.rst new file mode 100644 index 00000000000000..5dfae63a39fcbd --- /dev/null +++ b/LOCAL_VIRTUALENV.rst @@ -0,0 +1,172 @@ + + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +.. contents:: :local: + + +Local virtualenv environment +============================ + +Installation +------------ + +Install Python (3.5 or 3.6), MySQL, and libxml by using system-level +package managers like yum, apt-get for Linux, or Homebrew for Mac OS at +first. Refer to the `Dockerfile `__ for a comprehensive list +of required packages. + +In order to use your IDE you need you can use the virtual environment. +Ideally you should setup virtualenv for all python versions that Airflow +supports (3.5, 3.6). An easy way to create the virtualenv is to use +`virtualenvwrapper `__ +- it allows you to easily switch between virtualenvs using ``workon`` +command and mange your virtual environments more easily. Typically +creating the environment can be done by: + +.. code:: bash + + mkvirtualenv --python=python + +Then you need to install python PIP requirements. Typically it can be +done with: ``pip install -e ".[devel]"``. + +After creating the virtualenv, run this command to create the Airflow +sqlite database: + +.. 
code:: bash + + airflow db init + + +Creating virtualenv can be automated with `Breeze environment `_ + +Once initialization is done, you should select the virtualenv you +initialized as the project's default virtualenv in your IDE. + +After setting it up - you can use the usual "Run Test" option of the IDE +and have the autocomplete and documentation support from IDE as well as +you can debug and view the sources of Airflow - which is very helpful +during development. + +Installing other extras +----------------------- + +You can also other extras (like ``[mysql]``, ``[gcp]`` etc. via +``pip install -e [EXTRA1,EXTRA2 ...]``. However some of the extras have additional +system requirements and you might need to install additional packages on your +local machine. + +For example if you have trouble installing mysql client on MacOS and you have +an error similar to + +.. code:: text + + ld: library not found for -lssl + +you should set LIBRARY\_PATH before running ``pip install``: + +.. code:: bash + + export LIBRARY_PATH=$LIBRARY_PATH:/usr/local/opt/openssl/lib/ + +The full list of extras is available in ``_ + + +Running individual tests +------------------------ + +Once you activate virtualenv (or enter docker container) as described +below you should be able to run ``run-tests`` at will (it is in the path +in Docker environment but you need to prepend it with ``./`` in local +virtualenv (``./run-tests``). + +Note that this script has several flags that can be useful for your +testing. + +.. code:: text + + Usage: run-tests [FLAGS] [TESTS_TO_RUN] -- + + Runs tests specified (or all tests if no tests are specified) + + Flags: + + -h, --help + Shows this help message. + + -i, --with-db-init + Forces database initialization before tests + + -s, --nocapture + Don't capture stdout when running the tests. This is useful if you are + debugging with ipdb and want to drop into console with it + by adding this line to source code: + + import ipdb; ipdb.set_trace() + + -v, --verbose + Verbose output showing coloured output of tests being run and summary + of the tests - in a manner similar to the tests run in the CI environment. + +You can pass extra parameters to nose, by adding nose arguments after +``--`` + +For example, in order to just execute the "core" unit tests and add ipdb +set\_trace method, you can run the following command: + +.. code:: bash + + ./run-tests tests.core:TestCore --nocapture --verbose + +or a single test method without colors or debug logs: + +.. code:: bash + + ./run-tests tests.core:TestCore.test_check_operators + +Note that ``./run_tests`` script runs tests but the first time it runs, +it performs database initialisation. If you run further tests without +leaving the environment, the database will not be initialized, but you +can always force database initialization with ``--with-db-init`` +(``-i``) switch. The scripts will inform you what you can do when they +are run. + +Running tests directly from the IDE +----------------------------------- + +Once you configure your tests to use the virtualenv you created. running +tests from IDE is as simple as: + +.. figure:: images/run_unittests.png + :alt: Run unittests + + +Running integration tests +------------------------- + +Note that while most of the tests are typical "unit" tests that do not +require external components, there are a number of tests that are more +of "integration" or even "system" tests. 
You can technically use local +virtualenv to run those tests, but it requires to setup a number of +external components (databases/queues/kubernetes and the like) so it is +much easier to use the `Breeze development environment `_ +for those tests. + +Note - soon we will separate the integration and system tests out +so that you can clearly know which tests are unit tests and can be run in +the local virtualenv and which should be run using Breeze. diff --git a/README.md b/README.md index 0967dfd34896a8..c829293081171c 100644 --- a/README.md +++ b/README.md @@ -159,6 +159,7 @@ Currently **officially** using Airflow: 1. [BelugaDB](https://belugadb.com) [[@fabio-nukui](https://github.com/fabio-nukui) & [@joao-sallaberry](http://github.com/joao-sallaberry) & [@lucianoviola](https://github.com/lucianoviola) & [@tmatuki](https://github.com/tmatuki)] 1. [Betterment](https://www.betterment.com/) [[@betterment](https://github.com/Betterment)] 1. [Bexs Bank](https://www.bexs.com.br/en) [[@felipefb](https://github.com/felipefb) & [@ilarsen](https://github.com/ishvann)] +1. [BigQuant](https://bigquant.com/) [[@bigquant](https://github.com/bigquant)] 1. [BlaBlaCar](https://www.blablacar.com) [[@puckel](https://github.com/puckel) & [@wmorin](https://github.com/wmorin)] 1. [Blacklane](https://www.blacklane.com) [[@serkef](https://github.com/serkef)] 1. [Bloc](https://www.bloc.io) [[@dpaola2](https://github.com/dpaola2)] @@ -207,11 +208,13 @@ Currently **officially** using Airflow: 1. [CreditCards.com](https://www.creditcards.com/)[[@vmAggies](https://github.com/vmAggies) & [@jay-wallaby](https://github.com/jay-wallaby)] 1. [Cryptalizer.com](https://www.cryptalizer.com/) 1. [Custom Ink](https://www.customink.com/) [[@david-dalisay](https://github.com/david-dalisay), [@dmartin11](https://github.com/dmartin11) & [@mpeteuil](https://github.com/mpeteuil)] +1. [Cyscale](https://cyscale.com) [[@ocical](https://github.com/ocical)] 1. [Dailymotion](http://www.dailymotion.com/fr) [[@germaintanguy](https://github.com/germaintanguy) & [@hc](https://github.com/hc)] 1. [Danamica](https://www.danamica.dk) [[@testvinder](https://github.com/testvinder)] 1. [Data Reply](https://www.datareply.co.uk/) [[@kaxil](https://github.com/kaxil)] 1. [DataCamp](https://datacamp.com/) [[@dgrtwo](https://github.com/dgrtwo)] 1. [DataFox](https://www.datafox.com/) [[@sudowork](https://github.com/sudowork)] +1. [Dentsu Inc.](http://www.dentsu.com/) [[@bryan831](https://github.com/bryan831) & [@loozhengyuan](https://github.com/loozhengyuan)] 1. [Digital First Media](http://www.digitalfirstmedia.com/) [[@duffn](https://github.com/duffn) & [@mschmo](https://github.com/mschmo) & [@seanmuth](https://github.com/seanmuth)] 1. [DigitalOcean](https://digitalocean.com/) [[@ajbosco](https://github.com/ajbosco)] 1. [DoorDash](https://www.doordash.com/) @@ -311,6 +314,7 @@ Currently **officially** using Airflow: 1. [Modernizing Medicine](https://www.modmed.com/)[[@kehv1n](https://github.com/kehv1n), [@dalupus](https://github.com/dalupus)] 1. [Multiply](https://www.multiply.com) [[@nrhvyc](https://github.com/nrhvyc)] 1. [mytaxi](https://mytaxi.com) [[@mytaxi](https://github.com/mytaxi)] +1. [National Bank of Canada](https://nbc.ca) [[@brilhana](https://github.com/brilhana)] 1. [Neoway](https://www.neoway.com.br/) [[@neowaylabs](https://github.com/orgs/NeowayLabs/people)] 1. [Nerdwallet](https://www.nerdwallet.com) 1. 
[New Relic](https://www.newrelic.com) [[@marcweil](https://github.com/marcweil)] @@ -332,6 +336,7 @@ Currently **officially** using Airflow: 1. [PayFit](https://payfit.com) [[@pcorbel](https://github.com/pcorbel)] 1. [PAYMILL](https://www.paymill.com/) [[@paymill](https://github.com/paymill) & [@matthiashuschle](https://github.com/matthiashuschle)] 1. [PayPal](https://www.paypal.com/) [[@r39132](https://github.com/r39132) & [@jhsenjaliya](https://github.com/jhsenjaliya)] +1. [Pecan](https://www.pecan.ai) [[@ohadmata](https://github.com/ohadmata)] 1. [Pernod-Ricard](https://www.pernod-ricard.com/) [[@romain-nio](https://github.com/romain-nio)] 1. [Plaid](https://www.plaid.com/) [[@plaid](https://github.com/plaid), [@AustinBGibbons](https://github.com/AustinBGibbons) & [@jeeyoungk](https://github.com/jeeyoungk)] 1. [Playbuzz](https://www.playbuzz.com/) [[@clintonboys](https://github.com/clintonboys) & [@dbn](https://github.com/dbn)] @@ -408,6 +413,7 @@ Currently **officially** using Airflow: 1. [Vidio](https://www.vidio.com/) 1. [Ville de Montréal](http://ville.montreal.qc.ca/)[@VilledeMontreal](https://github.com/VilledeMontreal/)] 1. [Vnomics](https://github.com/vnomics) [[@lpalum](https://github.com/lpalum)] +1. [Walmart Labs](https://www.walmartlabs.com) [[@bharathpalaksha](https://github.com/bharathpalaksha)] 1. [Waze](https://www.waze.com) [[@waze](https://github.com/wazeHQ)] 1. [WePay](http://www.wepay.com) [[@criccomini](https://github.com/criccomini) & [@mtagle](https://github.com/mtagle)] 1. [WeTransfer](https://github.com/WeTransfer) [[@coredipper](https://github.com/coredipper) & [@higee](https://github.com/higee) & [@azclub](https://github.com/azclub)] diff --git a/UPDATING.md b/UPDATING.md index e473078b487940..80355aa817e1f0 100644 --- a/UPDATING.md +++ b/UPDATING.md @@ -40,7 +40,6 @@ assists users migrating to a new version. ## Airflow Master - ### Changes to `aws_default` Connection's default region The region of Airflow's default connection to AWS (`aws_default`) has diff --git a/airflow/api/auth/backend/kerberos_auth.py b/airflow/api/auth/backend/kerberos_auth.py index 6db971383a34cd..b1bff2eb106d54 100644 --- a/airflow/api/auth/backend/kerberos_auth.py +++ b/airflow/api/auth/backend/kerberos_auth.py @@ -1,4 +1,21 @@ # -*- coding: utf-8 -*- +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + # # Copyright (c) 2013, Michael Komitee # All rights reserved. diff --git a/airflow/api/common/experimental/delete_dag.py b/airflow/api/common/experimental/delete_dag.py index 5ec32a31766e96..e6fc78c563f180 100644 --- a/airflow/api/common/experimental/delete_dag.py +++ b/airflow/api/common/experimental/delete_dag.py @@ -17,14 +17,13 @@ # specific language governing permissions and limitations # under the License. 
"""Delete DAGs APIs.""" -import os from sqlalchemy import or_ from airflow import models from airflow.models import TaskFail, DagModel from airflow.utils.db import provide_session -from airflow.exceptions import DagFileExists, DagNotFound +from airflow.exceptions import DagNotFound @provide_session @@ -41,10 +40,6 @@ def delete_dag(dag_id: str, keep_records_in_log: bool = True, session=None) -> i if dag is None: raise DagNotFound("Dag id {} not found".format(dag_id)) - if dag.fileloc and os.path.exists(dag.fileloc): - raise DagFileExists("Dag id {} is still in DagBag. " - "Remove the DAG file first: {}".format(dag_id, dag.fileloc)) - count = 0 # noinspection PyUnresolvedReferences,PyProtectedMember diff --git a/airflow/api/common/experimental/mark_tasks.py b/airflow/api/common/experimental/mark_tasks.py index be848caf36039c..c4528357493042 100644 --- a/airflow/api/common/experimental/mark_tasks.py +++ b/airflow/api/common/experimental/mark_tasks.py @@ -74,7 +74,7 @@ def set_state( for past tasks. Will verify integrity of past dag runs in order to create tasks that did not exist. It will not create dag runs that are missing on the schedule (but it will as for subdag dag runs if needed). - :param task: the task from which to work. task.task.dag needs to be set + :param tasks: the iterable of tasks from which to work. task.task.dag needs to be set :param execution_date: the execution date from which to start looking :param upstream: Mark all parents (upstream tasks) :param downstream: Mark all siblings (downstream tasks) of task_id, including SubDags diff --git a/airflow/bin/cli.py b/airflow/bin/cli.py index f2134c262dddef..cebc2c6b3ae501 100644 --- a/airflow/bin/cli.py +++ b/airflow/bin/cli.py @@ -17,7 +17,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - +import errno import importlib import logging @@ -34,6 +34,7 @@ import argparse from argparse import RawTextHelpFormatter +from airflow.utils.dot_renderer import render_dag from airflow.utils.timezone import parse as parsedate import json from tabulate import tabulate @@ -441,6 +442,33 @@ def set_is_paused(is_paused, args): print("Dag: {}, paused: {}".format(args.dag_id, str(is_paused))) +def show_dag(args): + dag = get_dag(args) + dot = render_dag(dag) + if args.save: + filename, _, fileformat = args.save.rpartition('.') + dot.render(filename=filename, format=fileformat, cleanup=True) + print("File {} saved".format(args.save)) + elif args.imgcat: + data = dot.pipe(format='png') + try: + proc = subprocess.Popen("imgcat", stdout=subprocess.PIPE, stdin=subprocess.PIPE) + except OSError as e: + if e.errno == errno.ENOENT: + raise AirflowException( + "Failed to execute. Make sure the imgcat executables are on your systems \'PATH\'" + ) + else: + raise + out, err = proc.communicate(data) + if out: + print(out.decode('utf-8')) + if err: + print(err.decode('utf-8')) + else: + print(dot.source) + + def _run(args, dag, ti): if args.local: run_job = jobs.LocalTaskJob( @@ -1789,6 +1817,26 @@ class CLIFactory: 'dag_regex': Arg( ("-dx", "--dag_regex"), "Search dag_id as regex instead of exact string", "store_true"), + # show_dag + 'save': Arg( + ("-s", "--save"), + "Saves the result to the indicated file.\n" + "\n" + "The file format is determined by the file extension. 
For more information about supported " + "format, see: https://www.graphviz.org/doc/info/output.html\n" + "\n" + "If you want to create a PNG file then you should execute the following command:\n" + "airflow dags show --save output.png\n" + "\n" + "If you want to create a DOT file then you should execute the following command:\n" + "airflow dags show --save output.dot\n" + ), + 'imgcat': Arg( + ("--imgcat", ), + "Displays graph using the imgcat tool. \n" + "\n" + "For more information, see: https://www.iterm2.com/documentation-images.html", + action='store_true'), # trigger_dag 'run_id': Arg(("-r", "--run_id"), "Helps to identify this run"), 'conf': Arg( @@ -1823,7 +1871,7 @@ class CLIFactory: help="Variable key"), 'var_value': Arg( ("value",), - metavar=('VALUE'), + metavar='VALUE', help="Variable value"), 'default': Arg( ("-d", "--default"), @@ -2182,6 +2230,12 @@ class CLIFactory: 'help': "Delete all DB records related to the specified DAG", 'args': ('dag_id', 'yes'), }, + { + 'func': show_dag, + 'name': 'show', + 'help': "Displays DAG's tasks with their dependencies", + 'args': ('dag_id', 'subdir', 'save', 'imgcat',), + }, { 'func': backfill, 'name': 'backfill', diff --git a/airflow/config_templates/default_airflow.cfg b/airflow/config_templates/default_airflow.cfg index 5e444009fb8245..14995e7a3df0a2 100644 --- a/airflow/config_templates/default_airflow.cfg +++ b/airflow/config_templates/default_airflow.cfg @@ -122,9 +122,10 @@ sql_alchemy_max_overflow = 10 # a lower config value will allow the system to recover faster. sql_alchemy_pool_recycle = 1800 -# How many seconds to retry re-establishing a DB connection after -# disconnects. Setting this to 0 disables retries. -sql_alchemy_reconnect_timeout = 300 +# Check connection at the start of each connection pool checkout. +# Typically, this is a simple statement like “SELECT 1”. +# More information here: https://docs.sqlalchemy.org/en/13/core/pooling.html#disconnect-handling-pessimistic +sql_alchemy_pool_pre_ping = True # The schema to use for the metadata database # SqlAlchemy supports databases with the concept of multiple schemas. diff --git a/airflow/config_templates/default_celery.py b/airflow/config_templates/default_celery.py index 4a6da2a50728e2..35a7c510ed810c 100644 --- a/airflow/config_templates/default_celery.py +++ b/airflow/config_templates/default_celery.py @@ -16,7 +16,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - +"""Default celery configuration.""" import ssl from airflow.configuration import conf diff --git a/airflow/contrib/example_dags/example_gcs_acl.py b/airflow/contrib/example_dags/example_gcs_acl.py deleted file mode 100644 index 27713eb9ede807..00000000000000 --- a/airflow/contrib/example_dags/example_gcs_acl.py +++ /dev/null @@ -1,77 +0,0 @@ -# -*- coding: utf-8 -*- -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -""" -Example Airflow DAG that creates a new ACL entry on the specified bucket and object. - -This DAG relies on the following OS environment variables - -* GCS_ACL_BUCKET - Name of a bucket. -* GCS_ACL_OBJECT - Name of the object. For information about how to URL encode object - names to be path safe, see: - https://cloud.google.com/storage/docs/json_api/#encoding -* GCS_ACL_ENTITY - The entity holding the permission. -* GCS_ACL_BUCKET_ROLE - The access permission for the entity for the bucket. -* GCS_ACL_OBJECT_ROLE - The access permission for the entity for the object. -""" - -import os - -import airflow -from airflow import models -from airflow.contrib.operators.gcs_acl_operator import \ - GoogleCloudStorageBucketCreateAclEntryOperator, \ - GoogleCloudStorageObjectCreateAclEntryOperator - -# [START howto_operator_gcs_acl_args_common] -GCS_ACL_BUCKET = os.environ.get('GCS_ACL_BUCKET', 'example-bucket') -GCS_ACL_OBJECT = os.environ.get('GCS_ACL_OBJECT', 'example-object') -GCS_ACL_ENTITY = os.environ.get('GCS_ACL_ENTITY', 'example-entity') -GCS_ACL_BUCKET_ROLE = os.environ.get('GCS_ACL_BUCKET_ROLE', 'example-bucket-role') -GCS_ACL_OBJECT_ROLE = os.environ.get('GCS_ACL_OBJECT_ROLE', 'example-object-role') -# [END howto_operator_gcs_acl_args_common] - -default_args = { - 'start_date': airflow.utils.dates.days_ago(1) -} - -with models.DAG( - 'example_gcs_acl', - default_args=default_args, - schedule_interval=None # Change to match your use case -) as dag: - # [START howto_operator_gcs_bucket_create_acl_entry_task] - gcs_bucket_create_acl_entry_task = GoogleCloudStorageBucketCreateAclEntryOperator( - bucket=GCS_ACL_BUCKET, - entity=GCS_ACL_ENTITY, - role=GCS_ACL_BUCKET_ROLE, - task_id="gcs_bucket_create_acl_entry_task" - ) - # [END howto_operator_gcs_bucket_create_acl_entry_task] - # [START howto_operator_gcs_object_create_acl_entry_task] - gcs_object_create_acl_entry_task = GoogleCloudStorageObjectCreateAclEntryOperator( - bucket=GCS_ACL_BUCKET, - object_name=GCS_ACL_OBJECT, - entity=GCS_ACL_ENTITY, - role=GCS_ACL_OBJECT_ROLE, - task_id="gcs_object_create_acl_entry_task" - ) - # [END howto_operator_gcs_object_create_acl_entry_task] - - gcs_bucket_create_acl_entry_task >> gcs_object_create_acl_entry_task diff --git a/airflow/contrib/example_dags/example_gcs_to_gdrive.py b/airflow/contrib/example_dags/example_gcs_to_gdrive.py new file mode 100644 index 00000000000000..f6b185953c5dd9 --- /dev/null +++ b/airflow/contrib/example_dags/example_gcs_to_gdrive.py @@ -0,0 +1,59 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Example DAG using GoogleCloudStorageToGoogleDriveOperator. +""" +import os + +import airflow +from airflow import models + +from airflow.contrib.operators.gcs_to_gdrive_operator import GcsToGDriveOperator + +GCS_TO_GDRIVE_BUCKET = os.environ.get("GCS_TO_DRIVE_BUCKET", "example-object") + +default_args = {"start_date": airflow.utils.dates.days_ago(1)} + +with models.DAG( + "example_gcs_to_gdrive", default_args=default_args, schedule_interval=None # Override to match your needs +) as dag: + # [START howto_operator_gcs_to_gdrive_copy_single_file] + copy_single_file = GcsToGDriveOperator( + task_id="copy_single_file", + source_bucket=GCS_TO_GDRIVE_BUCKET, + source_object="sales/january.avro", + destination_object="copied_sales/january-backup.avro", + ) + # [END howto_operator_gcs_to_gdrive_copy_single_file] + # [START howto_operator_gcs_to_gdrive_copy_files] + copy_files = GcsToGDriveOperator( + task_id="copy_files", + source_bucket=GCS_TO_GDRIVE_BUCKET, + source_object="sales/*", + destination_object="copied_sales/", + ) + # [END howto_operator_gcs_to_gdrive_copy_files] + # [START howto_operator_gcs_to_gdrive_move_files] + move_files = GcsToGDriveOperator( + task_id="move_files", + source_bucket=GCS_TO_GDRIVE_BUCKET, + source_object="sales/*.avro", + move_object=True, + ) + # [END howto_operator_gcs_to_gdrive_move_files] diff --git a/airflow/contrib/example_dags/example_jenkins_job_trigger_operator.py.notexecutable b/airflow/contrib/example_dags/example_jenkins_job_trigger_operator.py.notexecutable index 2d8906b0b3b9dd..c0c3df61f5c137 100644 --- a/airflow/contrib/example_dags/example_jenkins_job_trigger_operator.py.notexecutable +++ b/airflow/contrib/example_dags/example_jenkins_job_trigger_operator.py.notexecutable @@ -7,9 +7,9 @@ # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. 
You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -75,7 +75,6 @@ def grabArtifactFromJenkins(**context): artifact_grabber = PythonOperator( task_id='artifact_grabber', - provide_context=True, python_callable=grabArtifactFromJenkins, dag=dag) diff --git a/airflow/contrib/example_dags/example_qubole_operator.py b/airflow/contrib/example_dags/example_qubole_operator.py index 1f7e2a8ce9d8f2..ef4681a85b7982 100644 --- a/airflow/contrib/example_dags/example_qubole_operator.py +++ b/airflow/contrib/example_dags/example_qubole_operator.py @@ -97,7 +97,6 @@ def compare_result(**kwargs): t3 = PythonOperator( task_id='compare_result', - provide_context=True, python_callable=compare_result, trigger_rule="all_done", dag=dag) diff --git a/airflow/contrib/hooks/aws_sqs_hook.py b/airflow/contrib/hooks/aws_sqs_hook.py index f559a422d41cb9..2c0a671018f516 100644 --- a/airflow/contrib/hooks/aws_sqs_hook.py +++ b/airflow/contrib/hooks/aws_sqs_hook.py @@ -25,13 +25,16 @@ class SQSHook(AwsHook): """ - Get the SQS client using boto3 library - - :return: SQS client - :rtype: botocore.client.SQS + Interact with Amazon Simple Queue Service. """ def get_conn(self): + """ + Get the SQS client using boto3 library + + :return: SQS client + :rtype: botocore.client.SQS + """ return self.get_client_type('sqs') def create_queue(self, queue_name, attributes=None): diff --git a/airflow/contrib/hooks/azure_cosmos_hook.py b/airflow/contrib/hooks/azure_cosmos_hook.py index 01b4007b0308fd..8a3cc41d51805d 100644 --- a/airflow/contrib/hooks/azure_cosmos_hook.py +++ b/airflow/contrib/hooks/azure_cosmos_hook.py @@ -16,6 +16,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +""" +This module contains integration with Azure CosmosDB. + +AzureCosmosDBHook communicates via the Azure Cosmos library. Make sure that a +Airflow connection of type `azure_cosmos` exists. Authorization can be done by supplying a +login (=Endpoint uri), password (=secret key) and extra fields database_name and collection_name to specify +the default database and collection to use (see connection `azure_cosmos_default` for an example). +""" import azure.cosmos.cosmos_client as cosmos_client from azure.cosmos.errors import HTTPFailure import uuid diff --git a/airflow/contrib/hooks/azure_data_lake_hook.py b/airflow/contrib/hooks/azure_data_lake_hook.py index 9eb7af7f8a71e0..186cae5d1d4858 100644 --- a/airflow/contrib/hooks/azure_data_lake_hook.py +++ b/airflow/contrib/hooks/azure_data_lake_hook.py @@ -17,7 +17,14 @@ # specific language governing permissions and limitations # under the License. # - +""" +This module contains integration with Azure Data Lake. + +AzureDataLakeHook communicates via a REST API compatible with WebHDFS. Make sure that a +Airflow connection of type `azure_data_lake` exists. Authorization can be done by supplying a +login (=Client ID), password (=Client Secret) and extra fields tenant (Tenant) and account_name (Account Name) +(see connection `azure_data_lake_default` for an example). 
+""" from airflow.hooks.base_hook import BaseHook from azure.datalake.store import core, lib, multithread diff --git a/airflow/contrib/hooks/bigquery_hook.py b/airflow/contrib/hooks/bigquery_hook.py index 101670469b2a6a..3c702478d93290 100644 --- a/airflow/contrib/hooks/bigquery_hook.py +++ b/airflow/contrib/hooks/bigquery_hook.py @@ -1,4 +1,4 @@ -# -*- coding: utf-8 -*- # pylint: disable=too-many-lines +# -*- coding: utf-8 -*- # # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file @@ -16,2248 +16,21 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# -""" -This module contains a BigQuery Hook, as well as a very basic PEP 249 -implementation for BigQuery. -""" - -import time -from copy import deepcopy - -from googleapiclient.discovery import build -from googleapiclient.errors import HttpError -from pandas_gbq.gbq import \ - _check_google_client_version as gbq_check_google_client_version -from pandas_gbq import read_gbq -from pandas_gbq.gbq import \ - _test_google_api_imports as gbq_test_google_api_imports -from pandas_gbq.gbq import GbqConnector - -from airflow import AirflowException -from airflow.contrib.hooks.gcp_api_base_hook import GoogleCloudBaseHook -from airflow.hooks.dbapi_hook import DbApiHook -from airflow.utils.log.logging_mixin import LoggingMixin - - -class BigQueryHook(GoogleCloudBaseHook, DbApiHook): - """ - Interact with BigQuery. This hook uses the Google Cloud Platform - connection. - """ - conn_name_attr = 'bigquery_conn_id' - - def __init__(self, - bigquery_conn_id='google_cloud_default', - delegate_to=None, - use_legacy_sql=True, - location=None): - super().__init__( - gcp_conn_id=bigquery_conn_id, delegate_to=delegate_to) - self.use_legacy_sql = use_legacy_sql - self.location = location - - def get_conn(self): - """ - Returns a BigQuery PEP 249 connection object. - """ - service = self.get_service() - return BigQueryConnection( - service=service, - project_id=self.project_id, - use_legacy_sql=self.use_legacy_sql, - location=self.location, - num_retries=self.num_retries - ) - - def get_service(self): - """ - Returns a BigQuery service object. - """ - http_authorized = self._authorize() - return build( - 'bigquery', 'v2', http=http_authorized, cache_discovery=False) - - def insert_rows(self, table, rows, target_fields=None, commit_every=1000): - """ - Insertion is currently unsupported. Theoretically, you could use - BigQuery's streaming API to insert rows into a table, but this hasn't - been implemented. - """ - raise NotImplementedError() - - def get_pandas_df(self, sql, parameters=None, dialect=None): - """ - Returns a Pandas DataFrame for the results produced by a BigQuery - query. The DbApiHook method must be overridden because Pandas - doesn't support PEP 249 connections, except for SQLite. See: - - https://github.com/pydata/pandas/blob/master/pandas/io/sql.py#L447 - https://github.com/pydata/pandas/issues/6900 - - :param sql: The BigQuery SQL to execute. 
- :type sql: str - :param parameters: The parameters to render the SQL query with (not - used, leave to override superclass method) - :type parameters: mapping or iterable - :param dialect: Dialect of BigQuery SQL – legacy SQL or standard SQL - defaults to use `self.use_legacy_sql` if not specified - :type dialect: str in {'legacy', 'standard'} - """ - private_key = self._get_field('key_path', None) or self._get_field('keyfile_dict', None) - - if dialect is None: - dialect = 'legacy' if self.use_legacy_sql else 'standard' - - return read_gbq(sql, - project_id=self.project_id, - dialect=dialect, - verbose=False, - private_key=private_key) - - def table_exists(self, project_id, dataset_id, table_id): - """ - Checks for the existence of a table in Google BigQuery. - - :param project_id: The Google cloud project in which to look for the - table. The connection supplied to the hook must provide access to - the specified project. - :type project_id: str - :param dataset_id: The name of the dataset in which to look for the - table. - :type dataset_id: str - :param table_id: The name of the table to check the existence of. - :type table_id: str - """ - service = self.get_service() - try: - service.tables().get( # pylint: disable=no-member - projectId=project_id, datasetId=dataset_id, - tableId=table_id).execute(num_retries=self.num_retries) - return True - except HttpError as e: - if e.resp['status'] == '404': - return False - raise - - -class BigQueryPandasConnector(GbqConnector): - """ - This connector behaves identically to GbqConnector (from Pandas), except - that it allows the service to be injected, and disables a call to - self.get_credentials(). This allows Airflow to use BigQuery with Pandas - without forcing a three legged OAuth connection. Instead, we can inject - service account credentials into the binding. - """ - - def __init__(self, - project_id, - service, - reauth=False, - verbose=False, - dialect='legacy'): - super().__init__(project_id) - gbq_check_google_client_version() - gbq_test_google_api_imports() - self.project_id = project_id - self.reauth = reauth - self.service = service - self.verbose = verbose - self.dialect = dialect - - -class BigQueryConnection: - """ - BigQuery does not have a notion of a persistent connection. Thus, these - objects are small stateless factories for cursors, which do all the real - work. - """ - - def __init__(self, *args, **kwargs): - self._args = args - self._kwargs = kwargs - - def close(self): - """ BigQueryConnection does not have anything to close. """ - - def commit(self): - """ BigQueryConnection does not support transactions. """ - - def cursor(self): - """ Return a new :py:class:`Cursor` object using the connection. """ - return BigQueryCursor(*self._args, **self._kwargs) - - def rollback(self): - """ BigQueryConnection does not have transactions """ - raise NotImplementedError( - "BigQueryConnection does not have transactions") - - -class BigQueryBaseCursor(LoggingMixin): - """ - The BigQuery base cursor contains helper methods to execute queries against - BigQuery. The methods can be used directly by operators, in cases where a - PEP 249 cursor isn't needed. 
- """ - - def __init__(self, - service, - project_id, - use_legacy_sql=True, - api_resource_configs=None, - location=None, - num_retries=5): - - self.service = service - self.project_id = project_id - self.use_legacy_sql = use_legacy_sql - if api_resource_configs: - _validate_value("api_resource_configs", api_resource_configs, dict) - self.api_resource_configs = api_resource_configs \ - if api_resource_configs else {} - self.running_job_id = None - self.location = location - self.num_retries = num_retries - - # pylint: disable=too-many-arguments - def create_empty_table(self, - project_id, - dataset_id, - table_id, - schema_fields=None, - time_partitioning=None, - cluster_fields=None, - labels=None, - view=None, - encryption_configuration=None, - num_retries=5): - """ - Creates a new, empty table in the dataset. - To create a view, which is defined by a SQL query, parse a dictionary to 'view' kwarg - - :param project_id: The project to create the table into. - :type project_id: str - :param dataset_id: The dataset to create the table into. - :type dataset_id: str - :param table_id: The Name of the table to be created. - :type table_id: str - :param schema_fields: If set, the schema field list as defined here: - https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs#configuration.load.schema - :type schema_fields: list - :param labels: a dictionary containing labels for the table, passed to BigQuery - :type labels: dict - - **Example**: :: - - schema_fields=[{"name": "emp_name", "type": "STRING", "mode": "REQUIRED"}, - {"name": "salary", "type": "INTEGER", "mode": "NULLABLE"}] - - :param time_partitioning: configure optional time partitioning fields i.e. - partition by field, type and expiration as per API specifications. - - .. seealso:: - https://cloud.google.com/bigquery/docs/reference/rest/v2/tables#timePartitioning - :type time_partitioning: dict - :param cluster_fields: [Optional] The fields used for clustering. - Must be specified with time_partitioning, data in the table will be first - partitioned and subsequently clustered. - https://cloud.google.com/bigquery/docs/reference/rest/v2/tables#clustering.fields - :type cluster_fields: list - :param view: [Optional] A dictionary containing definition for the view. - If set, it will create a view instead of a table: - https://cloud.google.com/bigquery/docs/reference/rest/v2/tables#ViewDefinition - :type view: dict - - **Example**: :: - - view = { - "query": "SELECT * FROM `test-project-id.test_dataset_id.test_table_prefix*` LIMIT 1000", - "useLegacySql": False - } - - :param encryption_configuration: [Optional] Custom encryption configuration (e.g., Cloud KMS keys). 
- **Example**: :: - - encryption_configuration = { - "kmsKeyName": "projects/testp/locations/us/keyRings/test-kr/cryptoKeys/test-key" - } - :type encryption_configuration: dict - :return: None - """ - - project_id = project_id if project_id is not None else self.project_id - - table_resource = { - 'tableReference': { - 'tableId': table_id - } - } - - if schema_fields: - table_resource['schema'] = {'fields': schema_fields} - - if time_partitioning: - table_resource['timePartitioning'] = time_partitioning - - if cluster_fields: - table_resource['clustering'] = { - 'fields': cluster_fields - } - - if labels: - table_resource['labels'] = labels - - if view: - table_resource['view'] = view - - if encryption_configuration: - table_resource["encryptionConfiguration"] = encryption_configuration - - num_retries = num_retries if num_retries else self.num_retries - - self.log.info('Creating Table %s:%s.%s', - project_id, dataset_id, table_id) - - try: - self.service.tables().insert( - projectId=project_id, - datasetId=dataset_id, - body=table_resource).execute(num_retries=num_retries) - - self.log.info('Table created successfully: %s:%s.%s', - project_id, dataset_id, table_id) - - except HttpError as err: - raise AirflowException( - 'BigQuery job failed. Error was: {}'.format(err.content) - ) - - def create_external_table(self, # pylint: disable=too-many-locals,too-many-arguments - external_project_dataset_table, - schema_fields, - source_uris, - source_format='CSV', - autodetect=False, - compression='NONE', - ignore_unknown_values=False, - max_bad_records=0, - skip_leading_rows=0, - field_delimiter=',', - quote_character=None, - allow_quoted_newlines=False, - allow_jagged_rows=False, - src_fmt_configs=None, - labels=None, - encryption_configuration=None - ): - """ - Creates a new external table in the dataset with the data in Google - Cloud Storage. See here: - - https://cloud.google.com/bigquery/docs/reference/rest/v2/tables#resource - - for more details about these parameters. - - :param external_project_dataset_table: - The dotted ``(.|:).($)`` BigQuery - table name to create external table. - If ```` is not included, project will be the - project defined in the connection json. - :type external_project_dataset_table: str - :param schema_fields: The schema field list as defined here: - https://cloud.google.com/bigquery/docs/reference/rest/v2/tables#resource - :type schema_fields: list - :param source_uris: The source Google Cloud - Storage URI (e.g. gs://some-bucket/some-file.txt). A single wild - per-object name can be used. - :type source_uris: list - :param source_format: File format to export. - :type source_format: str - :param autodetect: Try to detect schema and format options automatically. - Any option specified explicitly will be honored. - :type autodetect: bool - :param compression: [Optional] The compression type of the data source. - Possible values include GZIP and NONE. - The default value is NONE. - This setting is ignored for Google Cloud Bigtable, - Google Cloud Datastore backups and Avro formats. - :type compression: str - :param ignore_unknown_values: [Optional] Indicates if BigQuery should allow - extra values that are not represented in the table schema. - If true, the extra values are ignored. If false, records with extra columns - are treated as bad records, and if there are too many bad records, an - invalid error is returned in the job result. 
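The table-creation helper above is reached through the PEP 249 connection returned by the hook. A minimal sketch, assuming placeholder project, dataset and table names and reusing the ``schema_fields`` example from the docstring::

    from airflow.contrib.hooks.bigquery_hook import BigQueryHook

    hook = BigQueryHook(bigquery_conn_id='google_cloud_default')
    cursor = hook.get_conn().cursor()
    # Creates an empty table with an explicit schema; time partitioning,
    # clustering, labels, views and KMS keys are all optional.
    cursor.create_empty_table(
        project_id='my-project',          # placeholder
        dataset_id='my_dataset',          # placeholder
        table_id='employees',             # placeholder
        schema_fields=[
            {"name": "emp_name", "type": "STRING", "mode": "REQUIRED"},
            {"name": "salary", "type": "INTEGER", "mode": "NULLABLE"},
        ],
    )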
- :type ignore_unknown_values: bool - :param max_bad_records: The maximum number of bad records that BigQuery can - ignore when running the job. - :type max_bad_records: int - :param skip_leading_rows: Number of rows to skip when loading from a CSV. - :type skip_leading_rows: int - :param field_delimiter: The delimiter to use when loading from a CSV. - :type field_delimiter: str - :param quote_character: The value that is used to quote data sections in a CSV - file. - :type quote_character: str - :param allow_quoted_newlines: Whether to allow quoted newlines (true) or not - (false). - :type allow_quoted_newlines: bool - :param allow_jagged_rows: Accept rows that are missing trailing optional columns. - The missing values are treated as nulls. If false, records with missing - trailing columns are treated as bad records, and if there are too many bad - records, an invalid error is returned in the job result. Only applicable when - soure_format is CSV. - :type allow_jagged_rows: bool - :param src_fmt_configs: configure optional fields specific to the source format - :type src_fmt_configs: dict - :param labels: a dictionary containing labels for the table, passed to BigQuery - :type labels: dict - :param encryption_configuration: [Optional] Custom encryption configuration (e.g., Cloud KMS keys). - **Example**: :: - - encryption_configuration = { - "kmsKeyName": "projects/testp/locations/us/keyRings/test-kr/cryptoKeys/test-key" - } - :type encryption_configuration: dict - """ - - if src_fmt_configs is None: - src_fmt_configs = {} - project_id, dataset_id, external_table_id = \ - _split_tablename(table_input=external_project_dataset_table, - default_project_id=self.project_id, - var_name='external_project_dataset_table') - - # bigquery only allows certain source formats - # we check to make sure the passed source format is valid - # if it's not, we raise a ValueError - # Refer to this link for more details: - # https://cloud.google.com/bigquery/docs/reference/rest/v2/tables#externalDataConfiguration.sourceFormat # noqa # pylint: disable=line-too-long - - source_format = source_format.upper() - allowed_formats = [ - "CSV", "NEWLINE_DELIMITED_JSON", "AVRO", "GOOGLE_SHEETS", - "DATASTORE_BACKUP", "PARQUET" - ] - if source_format not in allowed_formats: - raise ValueError("{0} is not a valid source format. " - "Please use one of the following types: {1}" - .format(source_format, allowed_formats)) - - compression = compression.upper() - allowed_compressions = ['NONE', 'GZIP'] - if compression not in allowed_compressions: - raise ValueError("{0} is not a valid compression format. 
" - "Please use one of the following types: {1}" - .format(compression, allowed_compressions)) - - table_resource = { - 'externalDataConfiguration': { - 'autodetect': autodetect, - 'sourceFormat': source_format, - 'sourceUris': source_uris, - 'compression': compression, - 'ignoreUnknownValues': ignore_unknown_values - }, - 'tableReference': { - 'projectId': project_id, - 'datasetId': dataset_id, - 'tableId': external_table_id, - } - } - - if schema_fields: - table_resource['externalDataConfiguration'].update({ - 'schema': { - 'fields': schema_fields - } - }) - - self.log.info('Creating external table: %s', external_project_dataset_table) - - if max_bad_records: - table_resource['externalDataConfiguration']['maxBadRecords'] = max_bad_records - - # if following fields are not specified in src_fmt_configs, - # honor the top-level params for backward-compatibility - if 'skipLeadingRows' not in src_fmt_configs: - src_fmt_configs['skipLeadingRows'] = skip_leading_rows - if 'fieldDelimiter' not in src_fmt_configs: - src_fmt_configs['fieldDelimiter'] = field_delimiter - if 'quote_character' not in src_fmt_configs: - src_fmt_configs['quote'] = quote_character - if 'allowQuotedNewlines' not in src_fmt_configs: - src_fmt_configs['allowQuotedNewlines'] = allow_quoted_newlines - if 'allowJaggedRows' not in src_fmt_configs: - src_fmt_configs['allowJaggedRows'] = allow_jagged_rows - - src_fmt_to_param_mapping = { - 'CSV': 'csvOptions', - 'GOOGLE_SHEETS': 'googleSheetsOptions' - } - - src_fmt_to_configs_mapping = { - 'csvOptions': [ - 'allowJaggedRows', 'allowQuotedNewlines', - 'fieldDelimiter', 'skipLeadingRows', - 'quote' - ], - 'googleSheetsOptions': ['skipLeadingRows'] - } - - if source_format in src_fmt_to_param_mapping.keys(): - - valid_configs = src_fmt_to_configs_mapping[ - src_fmt_to_param_mapping[source_format] - ] - - src_fmt_configs = { - k: v - for k, v in src_fmt_configs.items() if k in valid_configs - } - - table_resource['externalDataConfiguration'][src_fmt_to_param_mapping[ - source_format]] = src_fmt_configs - - if labels: - table_resource['labels'] = labels - - if encryption_configuration: - table_resource["encryptionConfiguration"] = encryption_configuration - - try: - self.service.tables().insert( - projectId=project_id, - datasetId=dataset_id, - body=table_resource - ).execute(num_retries=self.num_retries) - - self.log.info('External table created successfully: %s', - external_project_dataset_table) - - except HttpError as err: - raise Exception( - 'BigQuery job failed. Error was: {}'.format(err.content) - ) - - def patch_table(self, # pylint: disable=too-many-arguments - dataset_id, - table_id, - project_id=None, - description=None, - expiration_time=None, - external_data_configuration=None, - friendly_name=None, - labels=None, - schema=None, - time_partitioning=None, - view=None, - require_partition_filter=None, - encryption_configuration=None): - """ - Patch information in an existing table. - It only updates fileds that are provided in the request object. - - Reference: https://cloud.google.com/bigquery/docs/reference/rest/v2/tables/patch - - :param dataset_id: The dataset containing the table to be patched. - :type dataset_id: str - :param table_id: The Name of the table to be patched. - :type table_id: str - :param project_id: The project containing the table to be patched. - :type project_id: str - :param description: [Optional] A user-friendly description of this table. 
- :type description: str - :param expiration_time: [Optional] The time when this table expires, - in milliseconds since the epoch. - :type expiration_time: int - :param external_data_configuration: [Optional] A dictionary containing - properties of a table stored outside of BigQuery. - :type external_data_configuration: dict - :param friendly_name: [Optional] A descriptive name for this table. - :type friendly_name: str - :param labels: [Optional] A dictionary containing labels associated with this table. - :type labels: dict - :param schema: [Optional] If set, the schema field list as defined here: - https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs#configuration.load.schema - The supported schema modifications and unsupported schema modification are listed here: - https://cloud.google.com/bigquery/docs/managing-table-schemas - **Example**: :: - - schema=[{"name": "emp_name", "type": "STRING", "mode": "REQUIRED"}, - {"name": "salary", "type": "INTEGER", "mode": "NULLABLE"}] - - :type schema: list - :param time_partitioning: [Optional] A dictionary containing time-based partitioning - definition for the table. - :type time_partitioning: dict - :param view: [Optional] A dictionary containing definition for the view. - If set, it will patch a view instead of a table: - https://cloud.google.com/bigquery/docs/reference/rest/v2/tables#ViewDefinition - **Example**: :: - - view = { - "query": "SELECT * FROM `test-project-id.test_dataset_id.test_table_prefix*` LIMIT 500", - "useLegacySql": False - } - - :type view: dict - :param require_partition_filter: [Optional] If true, queries over the this table require a - partition filter. If false, queries over the table - :type require_partition_filter: bool - :param encryption_configuration: [Optional] Custom encryption configuration (e.g., Cloud KMS keys). - **Example**: :: - - encryption_configuration = { - "kmsKeyName": "projects/testp/locations/us/keyRings/test-kr/cryptoKeys/test-key" - } - :type encryption_configuration: dict - - """ - - project_id = project_id if project_id is not None else self.project_id - - table_resource = {} - - if description is not None: - table_resource['description'] = description - if expiration_time is not None: - table_resource['expirationTime'] = expiration_time - if external_data_configuration: - table_resource['externalDataConfiguration'] = external_data_configuration - if friendly_name is not None: - table_resource['friendlyName'] = friendly_name - if labels: - table_resource['labels'] = labels - if schema: - table_resource['schema'] = {'fields': schema} - if time_partitioning: - table_resource['timePartitioning'] = time_partitioning - if view: - table_resource['view'] = view - if require_partition_filter is not None: - table_resource['requirePartitionFilter'] = require_partition_filter - if encryption_configuration: - table_resource["encryptionConfiguration"] = encryption_configuration - - self.log.info('Patching Table %s:%s.%s', - project_id, dataset_id, table_id) - - try: - self.service.tables().patch( - projectId=project_id, - datasetId=dataset_id, - tableId=table_id, - body=table_resource).execute(num_retries=self.num_retries) - - self.log.info('Table patched successfully: %s:%s.%s', - project_id, dataset_id, table_id) - - except HttpError as err: - raise AirflowException( - 'BigQuery job failed. 
Error was: {}'.format(err.content) - ) - - # pylint: disable=too-many-locals,too-many-arguments, too-many-branches - def run_query(self, - sql, - destination_dataset_table=None, - write_disposition='WRITE_EMPTY', - allow_large_results=False, - flatten_results=None, - udf_config=None, - use_legacy_sql=None, - maximum_billing_tier=None, - maximum_bytes_billed=None, - create_disposition='CREATE_IF_NEEDED', - query_params=None, - labels=None, - schema_update_options=None, - priority='INTERACTIVE', - time_partitioning=None, - api_resource_configs=None, - cluster_fields=None, - location=None, - encryption_configuration=None): - """ - Executes a BigQuery SQL query. Optionally persists results in a BigQuery - table. See here: - - https://cloud.google.com/bigquery/docs/reference/v2/jobs - - For more details about these parameters. - - :param sql: The BigQuery SQL to execute. - :type sql: str - :param destination_dataset_table: The dotted ``.
`` - BigQuery table to save the query results. - :type destination_dataset_table: str - :param write_disposition: What to do if the table already exists in - BigQuery. - :type write_disposition: str - :param allow_large_results: Whether to allow large results. - :type allow_large_results: bool - :param flatten_results: If true and query uses legacy SQL dialect, flattens - all nested and repeated fields in the query results. ``allowLargeResults`` - must be true if this is set to false. For standard SQL queries, this - flag is ignored and results are never flattened. - :type flatten_results: bool - :param udf_config: The User Defined Function configuration for the query. - See https://cloud.google.com/bigquery/user-defined-functions for details. - :type udf_config: list - :param use_legacy_sql: Whether to use legacy SQL (true) or standard SQL (false). - If `None`, defaults to `self.use_legacy_sql`. - :type use_legacy_sql: bool - :param api_resource_configs: a dictionary that contain params - 'configuration' applied for Google BigQuery Jobs API: - https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs - for example, {'query': {'useQueryCache': False}}. You could use it - if you need to provide some params that are not supported by the - BigQueryHook like args. - :type api_resource_configs: dict - :param maximum_billing_tier: Positive integer that serves as a - multiplier of the basic price. - :type maximum_billing_tier: int - :param maximum_bytes_billed: Limits the bytes billed for this job. - Queries that will have bytes billed beyond this limit will fail - (without incurring a charge). If unspecified, this will be - set to your project default. - :type maximum_bytes_billed: float - :param create_disposition: Specifies whether the job is allowed to - create new tables. - :type create_disposition: str - :param query_params: a list of dictionary containing query parameter types and - values, passed to BigQuery - :type query_params: list - :param labels: a dictionary containing labels for the job/query, - passed to BigQuery - :type labels: dict - :param schema_update_options: Allows the schema of the destination - table to be updated as a side effect of the query job. - :type schema_update_options: Union[list, tuple, set] - :param priority: Specifies a priority for the query. - Possible values include INTERACTIVE and BATCH. - The default value is INTERACTIVE. - :type priority: str - :param time_partitioning: configure optional time partitioning fields i.e. - partition by field, type and expiration as per API specifications. - :type time_partitioning: dict - :param cluster_fields: Request that the result of this query be stored sorted - by one or more columns. This is only available in combination with - time_partitioning. The order of columns given determines the sort order. - :type cluster_fields: list[str] - :param location: The geographic location of the job. Required except for - US and EU. See details at - https://cloud.google.com/bigquery/docs/locations#specifying_your_location - :type location: str - :param encryption_configuration: [Optional] Custom encryption configuration (e.g., Cloud KMS keys). 
- **Example**: :: - - encryption_configuration = { - "kmsKeyName": "projects/testp/locations/us/keyRings/test-kr/cryptoKeys/test-key" - } - :type encryption_configuration: dict - """ - schema_update_options = list(schema_update_options or []) - - if time_partitioning is None: - time_partitioning = {} - - if location: - self.location = location - - if not api_resource_configs: - api_resource_configs = self.api_resource_configs - else: - _validate_value('api_resource_configs', - api_resource_configs, dict) - configuration = deepcopy(api_resource_configs) - if 'query' not in configuration: - configuration['query'] = {} - - else: - _validate_value("api_resource_configs['query']", - configuration['query'], dict) - - if sql is None and not configuration['query'].get('query', None): - raise TypeError('`BigQueryBaseCursor.run_query` ' - 'missing 1 required positional argument: `sql`') - - # BigQuery also allows you to define how you want a table's schema to change - # as a side effect of a query job - # for more details: - # https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs#configuration.query.schemaUpdateOptions # noqa # pylint: disable=line-too-long - - allowed_schema_update_options = [ - 'ALLOW_FIELD_ADDITION', "ALLOW_FIELD_RELAXATION" - ] - - if not set(allowed_schema_update_options - ).issuperset(set(schema_update_options)): - raise ValueError("{0} contains invalid schema update options. " - "Please only use one or more of the following " - "options: {1}" - .format(schema_update_options, - allowed_schema_update_options)) - - if schema_update_options: - if write_disposition not in ["WRITE_APPEND", "WRITE_TRUNCATE"]: - raise ValueError("schema_update_options is only " - "allowed if write_disposition is " - "'WRITE_APPEND' or 'WRITE_TRUNCATE'.") - - if destination_dataset_table: - destination_project, destination_dataset, destination_table = \ - _split_tablename(table_input=destination_dataset_table, - default_project_id=self.project_id) - - destination_dataset_table = { - 'projectId': destination_project, - 'datasetId': destination_dataset, - 'tableId': destination_table, - } - - if cluster_fields: - cluster_fields = {'fields': cluster_fields} - - query_param_list = [ - (sql, 'query', None, (str,)), - (priority, 'priority', 'INTERACTIVE', (str,)), - (use_legacy_sql, 'useLegacySql', self.use_legacy_sql, bool), - (query_params, 'queryParameters', None, list), - (udf_config, 'userDefinedFunctionResources', None, list), - (maximum_billing_tier, 'maximumBillingTier', None, int), - (maximum_bytes_billed, 'maximumBytesBilled', None, float), - (time_partitioning, 'timePartitioning', {}, dict), - (schema_update_options, 'schemaUpdateOptions', None, list), - (destination_dataset_table, 'destinationTable', None, dict), - (cluster_fields, 'clustering', None, dict), - ] - - for param, param_name, param_default, param_type in query_param_list: - if param_name not in configuration['query'] and param in [None, {}, ()]: - if param_name == 'timePartitioning': - param_default = _cleanse_time_partitioning( - destination_dataset_table, time_partitioning) - param = param_default - - if param in [None, {}, ()]: - continue - - _api_resource_configs_duplication_check( - param_name, param, configuration['query']) - - configuration['query'][param_name] = param - - # check valid type of provided param, - # it last step because we can get param from 2 sources, - # and first of all need to find it - - _validate_value(param_name, configuration['query'][param_name], - param_type) - - if param_name == 
'schemaUpdateOptions' and param: - self.log.info("Adding experimental 'schemaUpdateOptions': " - "%s", schema_update_options) - - if param_name != 'destinationTable': - continue - - for key in ['projectId', 'datasetId', 'tableId']: - if key not in configuration['query']['destinationTable']: - raise ValueError( - "Not correct 'destinationTable' in " - "api_resource_configs. 'destinationTable' " - "must be a dict with {'projectId':'', " - "'datasetId':'', 'tableId':''}") - - configuration['query'].update({ - 'allowLargeResults': allow_large_results, - 'flattenResults': flatten_results, - 'writeDisposition': write_disposition, - 'createDisposition': create_disposition, - }) - - if 'useLegacySql' in configuration['query'] and configuration['query']['useLegacySql'] and\ - 'queryParameters' in configuration['query']: - raise ValueError("Query parameters are not allowed " - "when using legacy SQL") - - if labels: - _api_resource_configs_duplication_check( - 'labels', labels, configuration) - configuration['labels'] = labels - - if encryption_configuration: - configuration["query"][ - "destinationEncryptionConfiguration" - ] = encryption_configuration - - return self.run_with_configuration(configuration) - - def run_extract( # noqa - self, - source_project_dataset_table, - destination_cloud_storage_uris, - compression='NONE', - export_format='CSV', - field_delimiter=',', - print_header=True, - labels=None): - """ - Executes a BigQuery extract command to copy data from BigQuery to - Google Cloud Storage. See here: - - https://cloud.google.com/bigquery/docs/reference/v2/jobs - - For more details about these parameters. - - :param source_project_dataset_table: The dotted ``.
`` - BigQuery table to use as the source data. - :type source_project_dataset_table: str - :param destination_cloud_storage_uris: The destination Google Cloud - Storage URI (e.g. gs://some-bucket/some-file.txt). Follows - convention defined here: - https://cloud.google.com/bigquery/exporting-data-from-bigquery#exportingmultiple - :type destination_cloud_storage_uris: list - :param compression: Type of compression to use. - :type compression: str - :param export_format: File format to export. - :type export_format: str - :param field_delimiter: The delimiter to use when extracting to a CSV. - :type field_delimiter: str - :param print_header: Whether to print a header for a CSV file extract. - :type print_header: bool - :param labels: a dictionary containing labels for the job/query, - passed to BigQuery - :type labels: dict - """ - - source_project, source_dataset, source_table = \ - _split_tablename(table_input=source_project_dataset_table, - default_project_id=self.project_id, - var_name='source_project_dataset_table') - - configuration = { - 'extract': { - 'sourceTable': { - 'projectId': source_project, - 'datasetId': source_dataset, - 'tableId': source_table, - }, - 'compression': compression, - 'destinationUris': destination_cloud_storage_uris, - 'destinationFormat': export_format, - } - } - - if labels: - configuration['labels'] = labels - - if export_format == 'CSV': - # Only set fieldDelimiter and printHeader fields if using CSV. - # Google does not like it if you set these fields for other export - # formats. - configuration['extract']['fieldDelimiter'] = field_delimiter - configuration['extract']['printHeader'] = print_header - - return self.run_with_configuration(configuration) - - def run_copy(self, # pylint: disable=invalid-name - source_project_dataset_tables, - destination_project_dataset_table, - write_disposition='WRITE_EMPTY', - create_disposition='CREATE_IF_NEEDED', - labels=None, - encryption_configuration=None): - """ - Executes a BigQuery copy command to copy data from one BigQuery table - to another. See here: - - https://cloud.google.com/bigquery/docs/reference/v2/jobs#configuration.copy - - For more details about these parameters. - - :param source_project_dataset_tables: One or more dotted - ``(project:|project.).
`` - BigQuery tables to use as the source data. Use a list if there are - multiple source tables. - If ```` is not included, project will be the project defined - in the connection json. - :type source_project_dataset_tables: list|string - :param destination_project_dataset_table: The destination BigQuery - table. Format is: ``(project:|project.).
`` - :type destination_project_dataset_table: str - :param write_disposition: The write disposition if the table already exists. - :type write_disposition: str - :param create_disposition: The create disposition if the table doesn't exist. - :type create_disposition: str - :param labels: a dictionary containing labels for the job/query, - passed to BigQuery - :type labels: dict - :param encryption_configuration: [Optional] Custom encryption configuration (e.g., Cloud KMS keys). - **Example**: :: - - encryption_configuration = { - "kmsKeyName": "projects/testp/locations/us/keyRings/test-kr/cryptoKeys/test-key" - } - :type encryption_configuration: dict - """ - source_project_dataset_tables = ([ - source_project_dataset_tables - ] if not isinstance(source_project_dataset_tables, list) else - source_project_dataset_tables) - - source_project_dataset_tables_fixup = [] - for source_project_dataset_table in source_project_dataset_tables: - source_project, source_dataset, source_table = \ - _split_tablename(table_input=source_project_dataset_table, - default_project_id=self.project_id, - var_name='source_project_dataset_table') - source_project_dataset_tables_fixup.append({ - 'projectId': - source_project, - 'datasetId': - source_dataset, - 'tableId': - source_table - }) - - destination_project, destination_dataset, destination_table = \ - _split_tablename(table_input=destination_project_dataset_table, - default_project_id=self.project_id) - configuration = { - 'copy': { - 'createDisposition': create_disposition, - 'writeDisposition': write_disposition, - 'sourceTables': source_project_dataset_tables_fixup, - 'destinationTable': { - 'projectId': destination_project, - 'datasetId': destination_dataset, - 'tableId': destination_table - } - } - } - - if labels: - configuration['labels'] = labels - - if encryption_configuration: - configuration["copy"][ - "destinationEncryptionConfiguration" - ] = encryption_configuration - - return self.run_with_configuration(configuration) - - def run_load(self, # pylint: disable=too-many-locals,too-many-arguments,invalid-name - destination_project_dataset_table, - source_uris, - schema_fields=None, - source_format='CSV', - create_disposition='CREATE_IF_NEEDED', - skip_leading_rows=0, - write_disposition='WRITE_EMPTY', - field_delimiter=',', - max_bad_records=0, - quote_character=None, - ignore_unknown_values=False, - allow_quoted_newlines=False, - allow_jagged_rows=False, - schema_update_options=None, - src_fmt_configs=None, - time_partitioning=None, - cluster_fields=None, - autodetect=False, - encryption_configuration=None): - """ - Executes a BigQuery load command to load data from Google Cloud Storage - to BigQuery. See here: - - https://cloud.google.com/bigquery/docs/reference/v2/jobs - - For more details about these parameters. - - :param destination_project_dataset_table: - The dotted ``(.|:).
($)`` BigQuery - table to load data into. If ```` is not included, project will be the - project defined in the connection json. If a partition is specified the - operator will automatically append the data, create a new partition or create - a new DAY partitioned table. - :type destination_project_dataset_table: str - :param schema_fields: The schema field list as defined here: - https://cloud.google.com/bigquery/docs/reference/v2/jobs#configuration.load - Required if autodetect=False; optional if autodetect=True. - :type schema_fields: list - :param autodetect: Attempt to autodetect the schema for CSV and JSON - source files. - :type autodetect: bool - :param source_uris: The source Google Cloud - Storage URI (e.g. gs://some-bucket/some-file.txt). A single wild - per-object name can be used. - :type source_uris: list - :param source_format: File format to export. - :type source_format: str - :param create_disposition: The create disposition if the table doesn't exist. - :type create_disposition: str - :param skip_leading_rows: Number of rows to skip when loading from a CSV. - :type skip_leading_rows: int - :param write_disposition: The write disposition if the table already exists. - :type write_disposition: str - :param field_delimiter: The delimiter to use when loading from a CSV. - :type field_delimiter: str - :param max_bad_records: The maximum number of bad records that BigQuery can - ignore when running the job. - :type max_bad_records: int - :param quote_character: The value that is used to quote data sections in a CSV - file. - :type quote_character: str - :param ignore_unknown_values: [Optional] Indicates if BigQuery should allow - extra values that are not represented in the table schema. - If true, the extra values are ignored. If false, records with extra columns - are treated as bad records, and if there are too many bad records, an - invalid error is returned in the job result. - :type ignore_unknown_values: bool - :param allow_quoted_newlines: Whether to allow quoted newlines (true) or not - (false). - :type allow_quoted_newlines: bool - :param allow_jagged_rows: Accept rows that are missing trailing optional columns. - The missing values are treated as nulls. If false, records with missing - trailing columns are treated as bad records, and if there are too many bad - records, an invalid error is returned in the job result. Only applicable when - soure_format is CSV. - :type allow_jagged_rows: bool - :param schema_update_options: Allows the schema of the destination - table to be updated as a side effect of the load job. - :type schema_update_options: Union[list, tuple, set] - :param src_fmt_configs: configure optional fields specific to the source format - :type src_fmt_configs: dict - :param time_partitioning: configure optional time partitioning fields i.e. - partition by field, type and expiration as per API specifications. - :type time_partitioning: dict - :param cluster_fields: Request that the result of this load be stored sorted - by one or more columns. This is only available in combination with - time_partitioning. The order of columns given determines the sort order. - :type cluster_fields: list[str] - :param encryption_configuration: [Optional] Custom encryption configuration (e.g., Cloud KMS keys). 
- **Example**: :: - - encryption_configuration = { - "kmsKeyName": "projects/testp/locations/us/keyRings/test-kr/cryptoKeys/test-key" - } - :type encryption_configuration: dict - """ - # To provide backward compatibility - schema_update_options = list(schema_update_options or []) - - # bigquery only allows certain source formats - # we check to make sure the passed source format is valid - # if it's not, we raise a ValueError - # Refer to this link for more details: - # https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs#configuration.query.tableDefinitions.(key).sourceFormat # noqa # pylint: disable=line-too-long - - if schema_fields is None and not autodetect: - raise ValueError( - 'You must either pass a schema or autodetect=True.') - - if src_fmt_configs is None: - src_fmt_configs = {} - - source_format = source_format.upper() - allowed_formats = [ - "CSV", "NEWLINE_DELIMITED_JSON", "AVRO", "GOOGLE_SHEETS", - "DATASTORE_BACKUP", "PARQUET" - ] - if source_format not in allowed_formats: - raise ValueError("{0} is not a valid source format. " - "Please use one of the following types: {1}" - .format(source_format, allowed_formats)) - - # bigquery also allows you to define how you want a table's schema to change - # as a side effect of a load - # for more details: - # https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs#configuration.load.schemaUpdateOptions - allowed_schema_update_options = [ - 'ALLOW_FIELD_ADDITION', "ALLOW_FIELD_RELAXATION" - ] - if not set(allowed_schema_update_options).issuperset( - set(schema_update_options)): - raise ValueError( - "{0} contains invalid schema update options." - "Please only use one or more of the following options: {1}" - .format(schema_update_options, allowed_schema_update_options)) - - destination_project, destination_dataset, destination_table = \ - _split_tablename(table_input=destination_project_dataset_table, - default_project_id=self.project_id, - var_name='destination_project_dataset_table') - - configuration = { - 'load': { - 'autodetect': autodetect, - 'createDisposition': create_disposition, - 'destinationTable': { - 'projectId': destination_project, - 'datasetId': destination_dataset, - 'tableId': destination_table, - }, - 'sourceFormat': source_format, - 'sourceUris': source_uris, - 'writeDisposition': write_disposition, - 'ignoreUnknownValues': ignore_unknown_values - } - } - - time_partitioning = _cleanse_time_partitioning( - destination_project_dataset_table, - time_partitioning - ) - if time_partitioning: - configuration['load'].update({ - 'timePartitioning': time_partitioning - }) - - if cluster_fields: - configuration['load'].update({'clustering': {'fields': cluster_fields}}) - - if schema_fields: - configuration['load']['schema'] = {'fields': schema_fields} - - if schema_update_options: - if write_disposition not in ["WRITE_APPEND", "WRITE_TRUNCATE"]: - raise ValueError("schema_update_options is only " - "allowed if write_disposition is " - "'WRITE_APPEND' or 'WRITE_TRUNCATE'.") - else: - self.log.info( - "Adding experimental 'schemaUpdateOptions': %s", - schema_update_options - ) - configuration['load'][ - 'schemaUpdateOptions'] = schema_update_options - - if max_bad_records: - configuration['load']['maxBadRecords'] = max_bad_records - - if encryption_configuration: - configuration["load"][ - "destinationEncryptionConfiguration" - ] = encryption_configuration - - # if following fields are not specified in src_fmt_configs, - # honor the top-level params for backward-compatibility - if 'skipLeadingRows' not in 
src_fmt_configs: - src_fmt_configs['skipLeadingRows'] = skip_leading_rows - if 'fieldDelimiter' not in src_fmt_configs: - src_fmt_configs['fieldDelimiter'] = field_delimiter - if 'ignoreUnknownValues' not in src_fmt_configs: - src_fmt_configs['ignoreUnknownValues'] = ignore_unknown_values - if quote_character is not None: - src_fmt_configs['quote'] = quote_character - if allow_quoted_newlines: - src_fmt_configs['allowQuotedNewlines'] = allow_quoted_newlines - - src_fmt_to_configs_mapping = { - 'CSV': [ - 'allowJaggedRows', 'allowQuotedNewlines', 'autodetect', - 'fieldDelimiter', 'skipLeadingRows', 'ignoreUnknownValues', - 'nullMarker', 'quote' - ], - 'DATASTORE_BACKUP': ['projectionFields'], - 'NEWLINE_DELIMITED_JSON': ['autodetect', 'ignoreUnknownValues'], - 'PARQUET': ['autodetect', 'ignoreUnknownValues'], - 'AVRO': ['useAvroLogicalTypes'], - } - valid_configs = src_fmt_to_configs_mapping[source_format] - src_fmt_configs = { - k: v - for k, v in src_fmt_configs.items() if k in valid_configs - } - configuration['load'].update(src_fmt_configs) - - if allow_jagged_rows: - configuration['load']['allowJaggedRows'] = allow_jagged_rows - - return self.run_with_configuration(configuration) - - def run_with_configuration(self, configuration): - """ - Executes a BigQuery SQL query. See here: - - https://cloud.google.com/bigquery/docs/reference/v2/jobs - - For more details about the configuration parameter. - - :param configuration: The configuration parameter maps directly to - BigQuery's configuration field in the job object. See - https://cloud.google.com/bigquery/docs/reference/v2/jobs for - details. - """ - jobs = self.service.jobs() - job_data = {'configuration': configuration} - - # Send query and wait for reply. - query_reply = jobs \ - .insert(projectId=self.project_id, body=job_data) \ - .execute(num_retries=self.num_retries) - self.running_job_id = query_reply['jobReference']['jobId'] - if 'location' in query_reply['jobReference']: - location = query_reply['jobReference']['location'] - else: - location = self.location - - # Wait for query to finish. - keep_polling_job = True - while keep_polling_job: - try: - keep_polling_job = self._check_query_status(jobs, keep_polling_job, location) - - except HttpError as err: - if err.resp.status in [500, 503]: - self.log.info( - '%s: Retryable error, waiting for job to complete: %s', - err.resp.status, self.running_job_id) - time.sleep(5) - else: - raise Exception( - 'BigQuery job status check failed. Final error was: {}'. - format(err.resp.status)) - - return self.running_job_id - - def _check_query_status(self, jobs, keep_polling_job, location): - if location: - job = jobs.get( - projectId=self.project_id, - jobId=self.running_job_id, - location=location).execute(num_retries=self.num_retries) - else: - job = jobs.get( - projectId=self.project_id, - jobId=self.running_job_id).execute(num_retries=self.num_retries) - - if job['status']['state'] == 'DONE': - keep_polling_job = False - # Check if job had errors. - if 'errorResult' in job['status']: - raise Exception( - 'BigQuery job failed. Final error was: {}. The job was: {}'.format( - job['status']['errorResult'], job)) - else: - self.log.info('Waiting for job to complete : %s, %s', - self.project_id, self.running_job_id) - time.sleep(5) - return keep_polling_job - - def poll_job_complete(self, job_id): - """ - Check if jobs completed. - - :param job_id: id of the job. 
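The ``run_*`` helpers above all funnel into ``run_with_configuration``, which accepts a raw BigQuery job configuration and polls until the job is done. A minimal sketch with a placeholder query::

    from airflow.contrib.hooks.bigquery_hook import BigQueryHook

    hook = BigQueryHook(bigquery_conn_id='google_cloud_default')
    cursor = hook.get_conn().cursor()
    # The dict maps directly to the job's "configuration" field.
    job_id = cursor.run_with_configuration({
        'query': {
            'query': 'SELECT 1 AS dummy',   # placeholder SQL
            'useLegacySql': False,
        }
    })
    # run_with_configuration() blocks until the job reaches DONE, but the
    # returned job id can still be inspected with poll_job_complete().
    print(cursor.poll_job_complete(job_id))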
- :type job_id: str - :rtype: bool - """ - jobs = self.service.jobs() - try: - if self.location: - job = jobs.get(projectId=self.project_id, - jobId=job_id, - location=self.location).execute(num_retries=self.num_retries) - else: - job = jobs.get(projectId=self.project_id, - jobId=job_id).execute(num_retries=self.num_retries) - if job['status']['state'] == 'DONE': - return True - except HttpError as err: - if err.resp.status in [500, 503]: - self.log.info( - '%s: Retryable error while polling job with id %s', - err.resp.status, job_id) - else: - raise Exception( - 'BigQuery job status check failed. Final error was: {}'. - format(err.resp.status)) - return False - - def cancel_query(self): - """ - Cancel all started queries that have not yet completed - """ - jobs = self.service.jobs() - if (self.running_job_id and - not self.poll_job_complete(self.running_job_id)): - self.log.info('Attempting to cancel job : %s, %s', self.project_id, - self.running_job_id) - if self.location: - jobs.cancel( - projectId=self.project_id, - jobId=self.running_job_id, - location=self.location).execute(num_retries=self.num_retries) - else: - jobs.cancel( - projectId=self.project_id, - jobId=self.running_job_id).execute(num_retries=self.num_retries) - else: - self.log.info('No running BigQuery jobs to cancel.') - return - - # Wait for all the calls to cancel to finish - max_polling_attempts = 12 - polling_attempts = 0 - - job_complete = False - while polling_attempts < max_polling_attempts and not job_complete: - polling_attempts = polling_attempts + 1 - job_complete = self.poll_job_complete(self.running_job_id) - if job_complete: - self.log.info('Job successfully canceled: %s, %s', - self.project_id, self.running_job_id) - elif polling_attempts == max_polling_attempts: - self.log.info( - "Stopping polling due to timeout. Job with id %s " - "has not completed cancel and may or may not finish.", - self.running_job_id) - else: - self.log.info('Waiting for canceled job with id %s to finish.', - self.running_job_id) - time.sleep(5) - - def get_schema(self, dataset_id, table_id): - """ - Get the schema for a given datset.table. - see https://cloud.google.com/bigquery/docs/reference/v2/tables#resource - - :param dataset_id: the dataset ID of the requested table - :param table_id: the table ID of the requested table - :return: a table schema - """ - tables_resource = self.service.tables() \ - .get(projectId=self.project_id, datasetId=dataset_id, tableId=table_id) \ - .execute(num_retries=self.num_retries) - return tables_resource['schema'] - - def get_tabledata(self, dataset_id, table_id, - max_results=None, selected_fields=None, page_token=None, - start_index=None): - """ - Get the data of a given dataset.table and optionally with selected columns. - see https://cloud.google.com/bigquery/docs/reference/v2/tabledata/list - - :param dataset_id: the dataset ID of the requested table. - :param table_id: the table ID of the requested table. - :param max_results: the maximum results to return. - :param selected_fields: List of fields to return (comma-separated). If - unspecified, all fields are returned. - :param page_token: page token, returned from a previous call, - identifying the result set. - :param start_index: zero based index of the starting row to read. - :return: map containing the requested rows. 
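The schema and table-data helpers documented above can be combined as in the following sketch; dataset, table and field names are placeholders::

    from airflow.contrib.hooks.bigquery_hook import BigQueryHook

    hook = BigQueryHook(bigquery_conn_id='google_cloud_default')
    cursor = hook.get_conn().cursor()
    # get_schema() returns the table resource's "schema" section.
    schema = cursor.get_schema(dataset_id='my_dataset', table_id='my_table')
    print([field['name'] for field in schema['fields']])
    # get_tabledata() pages through rows; here only two columns and at most
    # ten rows are requested.
    rows = cursor.get_tabledata(dataset_id='my_dataset',
                                table_id='my_table',
                                max_results=10,
                                selected_fields='emp_name,salary')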
- """ - optional_params = {} - if max_results: - optional_params['maxResults'] = max_results - if selected_fields: - optional_params['selectedFields'] = selected_fields - if page_token: - optional_params['pageToken'] = page_token - if start_index: - optional_params['startIndex'] = start_index - return (self.service.tabledata().list( - projectId=self.project_id, - datasetId=dataset_id, - tableId=table_id, - **optional_params).execute(num_retries=self.num_retries)) - - def run_table_delete(self, deletion_dataset_table, - ignore_if_missing=False): - """ - Delete an existing table from the dataset; - If the table does not exist, return an error unless ignore_if_missing - is set to True. - - :param deletion_dataset_table: A dotted - ``(.|:).
`` that indicates which table - will be deleted. - :type deletion_dataset_table: str - :param ignore_if_missing: if True, then return success even if the - requested table does not exist. - :type ignore_if_missing: bool - :return: - """ - deletion_project, deletion_dataset, deletion_table = \ - _split_tablename(table_input=deletion_dataset_table, - default_project_id=self.project_id) - - try: - self.service.tables() \ - .delete(projectId=deletion_project, - datasetId=deletion_dataset, - tableId=deletion_table) \ - .execute(num_retries=self.num_retries) - self.log.info('Deleted table %s:%s.%s.', deletion_project, - deletion_dataset, deletion_table) - except HttpError: - if not ignore_if_missing: - raise Exception('Table deletion failed. Table does not exist.') - else: - self.log.info('Table does not exist. Skipping.') - - def run_table_upsert(self, dataset_id, table_resource, project_id=None): - """ - creates a new, empty table in the dataset; - If the table already exists, update the existing table. - Since BigQuery does not natively allow table upserts, this is not an - atomic operation. - - :param dataset_id: the dataset to upsert the table into. - :type dataset_id: str - :param table_resource: a table resource. see - https://cloud.google.com/bigquery/docs/reference/v2/tables#resource - :type table_resource: dict - :param project_id: the project to upsert the table into. If None, - project will be self.project_id. - :return: - """ - # check to see if the table exists - table_id = table_resource['tableReference']['tableId'] - project_id = project_id if project_id is not None else self.project_id - tables_list_resp = self.service.tables().list( - projectId=project_id, datasetId=dataset_id).execute(num_retries=self.num_retries) - while True: - for table in tables_list_resp.get('tables', []): - if table['tableReference']['tableId'] == table_id: - # found the table, do update - self.log.info('Table %s:%s.%s exists, updating.', - project_id, dataset_id, table_id) - return self.service.tables().update( - projectId=project_id, - datasetId=dataset_id, - tableId=table_id, - body=table_resource).execute(num_retries=self.num_retries) - # If there is a next page, we need to check the next page. - if 'nextPageToken' in tables_list_resp: - tables_list_resp = self.service.tables()\ - .list(projectId=project_id, - datasetId=dataset_id, - pageToken=tables_list_resp['nextPageToken'])\ - .execute(num_retries=self.num_retries) - # If there is no next page, then the table doesn't exist. - else: - # do insert - self.log.info('Table %s:%s.%s does not exist. creating.', - project_id, dataset_id, table_id) - return self.service.tables().insert( - projectId=project_id, - datasetId=dataset_id, - body=table_resource).execute(num_retries=self.num_retries) - - def run_grant_dataset_view_access(self, - source_dataset, - view_dataset, - view_table, - source_project=None, - view_project=None): - """ - Grant authorized view access of a dataset to a view table. - If this view has already been granted access to the dataset, do nothing. - This method is not atomic. Running it may clobber a simultaneous update. - - :param source_dataset: the source dataset - :type source_dataset: str - :param view_dataset: the dataset that the view is in - :type view_dataset: str - :param view_table: the table of the view - :type view_table: str - :param source_project: the project of the source dataset. If None, - self.project_id will be used. - :type source_project: str - :param view_project: the project that the view is in. 
If None, - self.project_id will be used. - :type view_project: str - :return: the datasets resource of the source dataset. - """ - - # Apply default values to projects - source_project = source_project if source_project else self.project_id - view_project = view_project if view_project else self.project_id - - # we don't want to clobber any existing accesses, so we have to get - # info on the dataset before we can add view access - source_dataset_resource = self.service.datasets().get( - projectId=source_project, datasetId=source_dataset).execute(num_retries=self.num_retries) - access = source_dataset_resource[ - 'access'] if 'access' in source_dataset_resource else [] - view_access = { - 'view': { - 'projectId': view_project, - 'datasetId': view_dataset, - 'tableId': view_table - } - } - # check to see if the view we want to add already exists. - if view_access not in access: - self.log.info( - 'Granting table %s:%s.%s authorized view access to %s:%s dataset.', - view_project, view_dataset, view_table, source_project, - source_dataset) - access.append(view_access) - return self.service.datasets().patch( - projectId=source_project, - datasetId=source_dataset, - body={ - 'access': access - }).execute(num_retries=self.num_retries) - else: - # if view is already in access, do nothing. - self.log.info( - 'Table %s:%s.%s already has authorized view access to %s:%s dataset.', - view_project, view_dataset, view_table, source_project, source_dataset) - return source_dataset_resource - - def create_empty_dataset(self, dataset_id="", project_id="", - dataset_reference=None): - """ - Create a new empty dataset: - https://cloud.google.com/bigquery/docs/reference/rest/v2/datasets/insert - - :param project_id: The name of the project where we want to create - an empty a dataset. Don't need to provide, if projectId in dataset_reference. - :type project_id: str - :param dataset_id: The id of dataset. Don't need to provide, - if datasetId in dataset_reference. - :type dataset_id: str - :param dataset_reference: Dataset reference that could be provided - with request body. More info: - https://cloud.google.com/bigquery/docs/reference/rest/v2/datasets#resource - :type dataset_reference: dict - """ - - if dataset_reference: - _validate_value('dataset_reference', dataset_reference, dict) - else: - dataset_reference = {} - - if "datasetReference" not in dataset_reference: - dataset_reference["datasetReference"] = {} - - if not dataset_reference["datasetReference"].get("datasetId") and not dataset_id: - raise ValueError( - "{} not provided datasetId. Impossible to create dataset") - - dataset_required_params = [(dataset_id, "datasetId", ""), - (project_id, "projectId", self.project_id)] - for param_tuple in dataset_required_params: - param, param_name, param_default = param_tuple - if param_name not in dataset_reference['datasetReference']: - if param_default and not param: - self.log.info( - "%s was not specified. 
Will be used default value %s.", - param_name, param_default - ) - param = param_default - dataset_reference['datasetReference'].update( - {param_name: param}) - elif param: - _api_resource_configs_duplication_check( - param_name, param, - dataset_reference['datasetReference'], 'dataset_reference') - - dataset_id = dataset_reference.get("datasetReference").get("datasetId") - dataset_project_id = dataset_reference.get("datasetReference").get( - "projectId") - - self.log.info('Creating Dataset: %s in project: %s ', dataset_id, - dataset_project_id) - - try: - self.service.datasets().insert( - projectId=dataset_project_id, - body=dataset_reference).execute(num_retries=self.num_retries) - self.log.info('Dataset created successfully: In project %s ' - 'Dataset %s', dataset_project_id, dataset_id) - - except HttpError as err: - raise AirflowException( - 'BigQuery job failed. Error was: {}'.format(err.content) - ) - - def delete_dataset(self, project_id, dataset_id, delete_contents=False): - """ - Delete a dataset of Big query in your project. - - :param project_id: The name of the project where we have the dataset . - :type project_id: str - :param dataset_id: The dataset to be delete. - :type dataset_id: str - :param delete_contents: [Optional] Whether to force the deletion even if the dataset is not empty. - Will delete all tables (if any) in the dataset if set to True. - Will raise HttpError 400: "{dataset_id} is still in use" if set to False and dataset is not empty. - The default value is False. - :type delete_contents: bool - :return: - """ - project_id = project_id if project_id is not None else self.project_id - self.log.info('Deleting from project: %s Dataset:%s', - project_id, dataset_id) - - try: - self.service.datasets().delete( - projectId=project_id, - datasetId=dataset_id, - deleteContents=delete_contents).execute(num_retries=self.num_retries) - self.log.info('Dataset deleted successfully: In project %s ' - 'Dataset %s', project_id, dataset_id) - - except HttpError as err: - raise AirflowException( - 'BigQuery job failed. Error was: {}'.format(err.content) - ) - - def get_dataset(self, dataset_id, project_id=None): - """ - Method returns dataset_resource if dataset exist - and raised 404 error if dataset does not exist - - :param dataset_id: The BigQuery Dataset ID - :type dataset_id: str - :param project_id: The GCP Project ID - :type project_id: str - :return: dataset_resource - - .. seealso:: - For more information, see Dataset Resource content: - https://cloud.google.com/bigquery/docs/reference/rest/v2/datasets#resource - """ - - if not dataset_id or not isinstance(dataset_id, str): - raise ValueError("dataset_id argument must be provided and has " - "a type 'str'. You provided: {}".format(dataset_id)) - - dataset_project_id = project_id if project_id else self.project_id - - try: - dataset_resource = self.service.datasets().get( - datasetId=dataset_id, projectId=dataset_project_id).execute(num_retries=self.num_retries) - self.log.info("Dataset Resource: %s", dataset_resource) - except HttpError as err: - raise AirflowException( - 'BigQuery job failed. Error was: {}'.format(err.content)) - - return dataset_resource - - def get_datasets_list(self, project_id=None): - """ - Method returns full list of BigQuery datasets in the current project - - .. 
seealso:: - For more information, see: - https://cloud.google.com/bigquery/docs/reference/rest/v2/datasets/list - - :param project_id: Google Cloud Project for which you - try to get all datasets - :type project_id: str - :return: datasets_list - - Example of returned datasets_list: :: - - { - "kind":"bigquery#dataset", - "location":"US", - "id":"your-project:dataset_2_test", - "datasetReference":{ - "projectId":"your-project", - "datasetId":"dataset_2_test" - } - }, - { - "kind":"bigquery#dataset", - "location":"US", - "id":"your-project:dataset_1_test", - "datasetReference":{ - "projectId":"your-project", - "datasetId":"dataset_1_test" - } - } - ] - """ - dataset_project_id = project_id if project_id else self.project_id - - try: - datasets_list = self.service.datasets().list( - projectId=dataset_project_id).execute(num_retries=self.num_retries)['datasets'] - self.log.info("Datasets List: %s", datasets_list) - - except HttpError as err: - raise AirflowException( - 'BigQuery job failed. Error was: {}'.format(err.content)) - - return datasets_list - - def patch_dataset(self, dataset_id, dataset_resource, project_id=None): - """ - Patches information in an existing dataset. - It only replaces fields that are provided in the submitted dataset resource. - More info: - https://cloud.google.com/bigquery/docs/reference/rest/v2/datasets/patch - - :param dataset_id: The BigQuery Dataset ID - :type dataset_id: str - :param dataset_resource: Dataset resource that will be provided - in request body. - https://cloud.google.com/bigquery/docs/reference/rest/v2/datasets#resource - :type dataset_resource: dict - :param project_id: The GCP Project ID - :type project_id: str - :rtype: dataset - https://cloud.google.com/bigquery/docs/reference/rest/v2/datasets#resource - """ - - if not dataset_id or not isinstance(dataset_id, str): - raise ValueError( - "dataset_id argument must be provided and has " - "a type 'str'. You provided: {}".format(dataset_id) - ) - - dataset_project_id = project_id if project_id else self.project_id - - try: - dataset = ( - self.service.datasets() - .patch( - datasetId=dataset_id, - projectId=dataset_project_id, - body=dataset_resource, - ) - .execute(num_retries=self.num_retries) - ) - self.log.info("Dataset successfully patched: %s", dataset) - except HttpError as err: - raise AirflowException( - "BigQuery job failed. Error was: {}".format(err.content) - ) - - return dataset - - def update_dataset(self, dataset_id, dataset_resource, project_id=None): - """ - Updates information in an existing dataset. The update method replaces the entire - dataset resource, whereas the patch method only replaces fields that are provided - in the submitted dataset resource. - More info: - https://cloud.google.com/bigquery/docs/reference/rest/v2/datasets/update - - :param dataset_id: The BigQuery Dataset ID - :type dataset_id: str - :param dataset_resource: Dataset resource that will be provided - in request body. - https://cloud.google.com/bigquery/docs/reference/rest/v2/datasets#resource - :type dataset_resource: dict - :param project_id: The GCP Project ID - :type project_id: str - :rtype: dataset - https://cloud.google.com/bigquery/docs/reference/rest/v2/datasets#resource - """ - - if not dataset_id or not isinstance(dataset_id, str): - raise ValueError( - "dataset_id argument must be provided and has " - "a type 'str'. 
You provided: {}".format(dataset_id) - ) - - dataset_project_id = project_id if project_id else self.project_id - - try: - dataset = ( - self.service.datasets() - .update( - datasetId=dataset_id, - projectId=dataset_project_id, - body=dataset_resource, - ) - .execute(num_retries=self.num_retries) - ) - self.log.info("Dataset successfully updated: %s", dataset) - except HttpError as err: - raise AirflowException( - "BigQuery job failed. Error was: {}".format(err.content) - ) - - return dataset - - def insert_all(self, project_id, dataset_id, table_id, - rows, ignore_unknown_values=False, - skip_invalid_rows=False, fail_on_error=False): - """ - Method to stream data into BigQuery one record at a time without needing - to run a load job - - .. seealso:: - For more information, see: - https://cloud.google.com/bigquery/docs/reference/rest/v2/tabledata/insertAll - - :param project_id: The name of the project where we have the table - :type project_id: str - :param dataset_id: The name of the dataset where we have the table - :type dataset_id: str - :param table_id: The name of the table - :type table_id: str - :param rows: the rows to insert - :type rows: list - - **Example or rows**: - rows=[{"json": {"a_key": "a_value_0"}}, {"json": {"a_key": "a_value_1"}}] - - :param ignore_unknown_values: [Optional] Accept rows that contain values - that do not match the schema. The unknown values are ignored. - The default value is false, which treats unknown values as errors. - :type ignore_unknown_values: bool - :param skip_invalid_rows: [Optional] Insert all valid rows of a request, - even if invalid rows exist. The default value is false, which causes - the entire request to fail if any invalid rows exist. - :type skip_invalid_rows: bool - :param fail_on_error: [Optional] Force the task to fail if any errors occur. - The default value is false, which indicates the task should not fail - even if any insertion errors occur. - :type fail_on_error: bool - """ - - dataset_project_id = project_id if project_id else self.project_id - - body = { - "rows": rows, - "ignoreUnknownValues": ignore_unknown_values, - "kind": "bigquery#tableDataInsertAllRequest", - "skipInvalidRows": skip_invalid_rows, - } - - try: - self.log.info( - 'Inserting %s row(s) into Table %s:%s.%s', - len(rows), dataset_project_id, dataset_id, table_id - ) - - resp = self.service.tabledata().insertAll( - projectId=dataset_project_id, datasetId=dataset_id, - tableId=table_id, body=body - ).execute(num_retries=self.num_retries) - - if 'insertErrors' not in resp: - self.log.info( - 'All row(s) inserted successfully: %s:%s.%s', - dataset_project_id, dataset_id, table_id - ) - else: - error_msg = '{} insert error(s) occurred: {}:{}.{}. Details: {}'.format( - len(resp['insertErrors']), - dataset_project_id, dataset_id, table_id, resp['insertErrors']) - if fail_on_error: - raise AirflowException( - 'BigQuery job failed. Error was: {}'.format(error_msg) - ) - self.log.info(error_msg) - except HttpError as err: - raise AirflowException( - 'BigQuery job failed. Error was: {}'.format(err.content) - ) - - -class BigQueryCursor(BigQueryBaseCursor): - """ - A very basic BigQuery PEP 249 cursor implementation. 
The PyHive PEP 249 - implementation was used as a reference: - - https://github.com/dropbox/PyHive/blob/master/pyhive/presto.py - https://github.com/dropbox/PyHive/blob/master/pyhive/common.py - """ - - def __init__(self, service, project_id, use_legacy_sql=True, location=None, num_retries=5): - super().__init__( - service=service, - project_id=project_id, - use_legacy_sql=use_legacy_sql, - location=location, - num_retries=num_retries - ) - self.buffersize = None - self.page_token = None - self.job_id = None - self.buffer = [] - self.all_pages_loaded = False - - @property - def description(self): - """ The schema description method is not currently implemented. """ - raise NotImplementedError - - def close(self): - """ By default, do nothing """ - - @property - def rowcount(self): - """ By default, return -1 to indicate that this is not supported. """ - return -1 - - def execute(self, operation, parameters=None): - """ - Executes a BigQuery query, and returns the job ID. - - :param operation: The query to execute. - :type operation: str - :param parameters: Parameters to substitute into the query. - :type parameters: dict - """ - sql = _bind_parameters(operation, - parameters) if parameters else operation - self.job_id = self.run_query(sql) - - def executemany(self, operation, seq_of_parameters): - """ - Execute a BigQuery query multiple times with different parameters. - - :param operation: The query to execute. - :type operation: str - :param seq_of_parameters: List of dictionary parameters to substitute into the - query. - :type seq_of_parameters: list - """ - for parameters in seq_of_parameters: - self.execute(operation, parameters) - - def fetchone(self): - """ Fetch the next row of a query result set. """ - return self.next() - - def next(self): - """ - Helper method for fetchone, which returns the next row from a buffer. - If the buffer is empty, attempts to paginate through the result set for - the next page, and load it into the buffer. - """ - if not self.job_id: - return None - - if not self.buffer: - if self.all_pages_loaded: - return None - - query_results = (self.service.jobs().getQueryResults( - projectId=self.project_id, - jobId=self.job_id, - pageToken=self.page_token).execute(num_retries=self.num_retries)) - - if 'rows' in query_results and query_results['rows']: - self.page_token = query_results.get('pageToken') - fields = query_results['schema']['fields'] - col_types = [field['type'] for field in fields] - rows = query_results['rows'] - - for dict_row in rows: - typed_row = ([ - _bq_cast(vs['v'], col_types[idx]) - for idx, vs in enumerate(dict_row['f']) - ]) - self.buffer.append(typed_row) - - if not self.page_token: - self.all_pages_loaded = True - - else: - # Reset all state since we've exhausted the results. - self.page_token = None - self.job_id = None - self.page_token = None - return None - - return self.buffer.pop(0) - - def fetchmany(self, size=None): - """ - Fetch the next set of rows of a query result, returning a sequence of sequences - (e.g. a list of tuples). An empty sequence is returned when no more rows are - available. The number of rows to fetch per call is specified by the parameter. - If it is not given, the cursor's arraysize determines the number of rows to be - fetched. The method should try to fetch as many rows as indicated by the size - parameter. If this is not possible due to the specified number of rows not being - available, fewer rows may be returned. 
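The cursor above follows PEP 249, so everyday usage reads like any other DB-API cursor. A hedged sketch with placeholder project, dataset and table names; ``%(name)s`` placeholders are substituted by the ``_bind_parameters`` helper and result values are cast back to Python types by ``_bq_cast`` (both defined just below)::

    from airflow.contrib.hooks.bigquery_hook import BigQueryHook

    hook = BigQueryHook(bigquery_conn_id='google_cloud_default', use_legacy_sql=False)
    cursor = hook.get_conn().cursor()

    cursor.execute(
        "SELECT name, value FROM `my-project.my_dataset.my_table` WHERE name = %(name)s",
        parameters={'name': 'example'},
    )

    first_row = cursor.fetchone()       # one typed row, or None when exhausted
    next_batch = cursor.fetchmany(size=100)
    remaining = cursor.fetchall()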
An :py:class:`~pyhive.exc.Error` - (or subclass) exception is raised if the previous call to - :py:meth:`execute` did not produce any result set or no call was issued yet. - """ - if size is None: - size = self.arraysize - result = [] - for _ in range(size): - one = self.fetchone() - if one is None: - break - else: - result.append(one) - return result - - def fetchall(self): - """ - Fetch all (remaining) rows of a query result, returning them as a sequence of - sequences (e.g. a list of tuples). - """ - result = [] - while True: - one = self.fetchone() - if one is None: - break - else: - result.append(one) - return result - - def get_arraysize(self): - """ Specifies the number of rows to fetch at a time with .fetchmany() """ - return self.buffersize or 1 - - def set_arraysize(self, arraysize): - """ Specifies the number of rows to fetch at a time with .fetchmany() """ - self.buffersize = arraysize - - arraysize = property(get_arraysize, set_arraysize) - - def setinputsizes(self, sizes): - """ Does nothing by default """ - - def setoutputsize(self, size, column=None): - """ Does nothing by default """ - - -def _bind_parameters(operation, parameters): - """ Helper method that binds parameters to a SQL query. """ - # inspired by MySQL Python Connector (conversion.py) - string_parameters = {} - for (name, value) in parameters.items(): - if value is None: - string_parameters[name] = 'NULL' - elif isinstance(value, str): - string_parameters[name] = "'" + _escape(value) + "'" - else: - string_parameters[name] = str(value) - return operation % string_parameters - - -def _escape(s): - """ Helper method that escapes parameters to a SQL query. """ - e = s - e = e.replace('\\', '\\\\') - e = e.replace('\n', '\\n') - e = e.replace('\r', '\\r') - e = e.replace("'", "\\'") - e = e.replace('"', '\\"') - return e - - -def _bq_cast(string_field, bq_type): - """ - Helper method that casts a BigQuery row to the appropriate data types. - This is useful because BigQuery returns all fields as strings. - """ - if string_field is None: - return None - elif bq_type == 'INTEGER': - return int(string_field) - elif bq_type in ('FLOAT', 'TIMESTAMP'): - return float(string_field) - elif bq_type == 'BOOLEAN': - if string_field not in ['true', 'false']: - raise ValueError("{} must have value 'true' or 'false'".format( - string_field)) - return string_field == 'true' - else: - return string_field - - -def _split_tablename(table_input, default_project_id, var_name=None): - - if '.' not in table_input: - raise ValueError( - 'Expected target table name in the format of ' - '.
. Got: {}'.format(table_input)) - - if not default_project_id: - raise ValueError("INTERNAL: No default project is specified") - - def var_print(var_name): - if var_name is None: - return "" - else: - return "Format exception for {var}: ".format(var=var_name) - - if table_input.count('.') + table_input.count(':') > 3: - raise Exception(('{var}Use either : or . to specify project ' - 'got {input}').format( - var=var_print(var_name), input=table_input)) - cmpt = table_input.rsplit(':', 1) - project_id = None - rest = table_input - if len(cmpt) == 1: - project_id = None - rest = cmpt[0] - elif len(cmpt) == 2 and cmpt[0].count(':') <= 1: - if cmpt[-1].count('.') != 2: - project_id = cmpt[0] - rest = cmpt[1] - else: - raise Exception(('{var}Expect format of (.
, ' - 'got {input}').format( - var=var_print(var_name), input=table_input)) - - cmpt = rest.split('.') - if len(cmpt) == 3: - if project_id: - raise ValueError( - "{var}Use either : or . to specify project".format( - var=var_print(var_name))) - project_id = cmpt[0] - dataset_id = cmpt[1] - table_id = cmpt[2] - - elif len(cmpt) == 2: - dataset_id = cmpt[0] - table_id = cmpt[1] - else: - raise Exception( - ('{var}Expect format of (.
, ' - 'got {input}').format(var=var_print(var_name), input=table_input)) - - if project_id is None: - if var_name is not None: - log = LoggingMixin().log - log.info( - 'Project not included in %s: %s; using project "%s"', - var_name, table_input, default_project_id - ) - project_id = default_project_id - - return project_id, dataset_id, table_id - - -def _cleanse_time_partitioning(destination_dataset_table, time_partitioning_in): - # if it is a partitioned table ($ is in the table name) add partition load option - - if time_partitioning_in is None: - time_partitioning_in = {} - - time_partitioning_out = {} - if destination_dataset_table and '$' in destination_dataset_table: - time_partitioning_out['type'] = 'DAY' - time_partitioning_out.update(time_partitioning_in) - return time_partitioning_out - - -def _validate_value(key, value, expected_type): - """ function to check expected type and raise - error if type is not correct """ - if not isinstance(value, expected_type): - raise TypeError("{} argument must have a type {} not {}".format( - key, expected_type, type(value))) - - -def _api_resource_configs_duplication_check(key, value, config_dict, - config_dict_name='api_resource_configs'): - if key in config_dict and value != config_dict[key]: - raise ValueError("Values of {param_name} param are duplicated. " - "{dict_name} contained {param_name} param " - "in `query` config and {param_name} was also provided " - "with arg to run_query() method. Please remove duplicates." - .format(param_name=key, dict_name=config_dict_name)) +"""This module is deprecated. Please use `airflow.gcp.hooks.bigquery`.""" + +import warnings + +# pylint: disable=unused-import +from airflow.gcp.hooks.bigquery import ( # noqa + BigQueryPandasConnector, + BigQueryCursor, + BigQueryConnection, + BigQueryHook, + BigQueryBaseCursor, + GbqConnector, +) + +warnings.warn( + "This module is deprecated. Please use `airflow.gcp.hooks.bigquery`.", + DeprecationWarning, +) diff --git a/airflow/contrib/hooks/databricks_hook.py b/airflow/contrib/hooks/databricks_hook.py index 7b8ac5e8b3b983..b805f6aa6bbd29 100644 --- a/airflow/contrib/hooks/databricks_hook.py +++ b/airflow/contrib/hooks/databricks_hook.py @@ -16,7 +16,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Databricks hook.""" +""" +Databricks hook. + +This hook enable the submitting and running of jobs to the Databricks platform. Internally the +operators talk to the ``api/2.0/jobs/runs/submit`` +`endpoint `_. +""" from urllib.parse import urlparse from time import sleep diff --git a/airflow/contrib/hooks/gcs_hook.py b/airflow/contrib/hooks/gcs_hook.py index 9b046acf909725..54bffa3acd417b 100644 --- a/airflow/contrib/hooks/gcs_hook.py +++ b/airflow/contrib/hooks/gcs_hook.py @@ -16,566 +16,15 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# """ -This module contains a Google Cloud Storage hook. +This module is deprecated. Please use `airflow.gcp.hooks.gcs`. """ +import warnings -from typing import Optional -import gzip as gz -import os -import shutil - -from urllib.parse import urlparse -from google.cloud import storage - -from airflow.contrib.hooks.gcp_api_base_hook import GoogleCloudBaseHook -from airflow.exceptions import AirflowException -from airflow.version import version - - -class GoogleCloudStorageHook(GoogleCloudBaseHook): - """ - Interact with Google Cloud Storage. 
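Stepping back to the deprecation shim above (before the ``gcs_hook`` changes that follow): the old import path keeps working but now emits a ``DeprecationWarning``, so callers should migrate to ``airflow.gcp.hooks.bigquery``. A small sketch, assuming the module has not already been imported in the same interpreter::

    import warnings

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        # Old location still works, but warns on first import.
        from airflow.contrib.hooks.bigquery_hook import BigQueryHook  # noqa: F401

    assert any(issubclass(w.category, DeprecationWarning) for w in caught)

    # New code should import from the new location instead.
    from airflow.gcp.hooks.bigquery import BigQueryHook  # noqa: F401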
This hook uses the Google Cloud Platform - connection. - """ - - _conn = None # type: Optional[storage.Client] - - def __init__(self, - google_cloud_storage_conn_id='google_cloud_default', - delegate_to=None): - super().__init__(google_cloud_storage_conn_id, - delegate_to) - - def get_conn(self): - """ - Returns a Google Cloud Storage service object. - """ - if not self._conn: - self._conn = storage.Client(credentials=self._get_credentials(), - client_info=self.client_info, - project=self.project_id) - - return self._conn - - def copy(self, source_bucket, source_object, destination_bucket=None, - destination_object=None): - """ - Copies an object from a bucket to another, with renaming if requested. - - destination_bucket or destination_object can be omitted, in which case - source bucket/object is used, but not both. - - :param source_bucket: The bucket of the object to copy from. - :type source_bucket: str - :param source_object: The object to copy. - :type source_object: str - :param destination_bucket: The destination of the object to copied to. - Can be omitted; then the same bucket is used. - :type destination_bucket: str - :param destination_object: The (renamed) path of the object if given. - Can be omitted; then the same name is used. - :type destination_object: str - """ - destination_bucket = destination_bucket or source_bucket - destination_object = destination_object or source_object - - if source_bucket == destination_bucket and \ - source_object == destination_object: - - raise ValueError( - 'Either source/destination bucket or source/destination object ' - 'must be different, not both the same: bucket=%s, object=%s' % - (source_bucket, source_object)) - if not source_bucket or not source_object: - raise ValueError('source_bucket and source_object cannot be empty.') - - client = self.get_conn() - source_bucket = client.bucket(source_bucket) - source_object = source_bucket.blob(source_object) - destination_bucket = client.bucket(destination_bucket) - destination_object = source_bucket.copy_blob( - blob=source_object, - destination_bucket=destination_bucket, - new_name=destination_object) - - self.log.info('Object %s in bucket %s copied to object %s in bucket %s', - source_object.name, source_bucket.name, - destination_object.name, destination_bucket.name) - - def rewrite(self, source_bucket, source_object, destination_bucket, - destination_object=None): - """ - Has the same functionality as copy, except that will work on files - over 5 TB, as well as when copying between locations and/or storage - classes. - - destination_object can be omitted, in which case source_object is used. - - :param source_bucket: The bucket of the object to copy from. - :type source_bucket: str - :param source_object: The object to copy. - :type source_object: str - :param destination_bucket: The destination of the object to copied to. - :type destination_bucket: str - :param destination_object: The (renamed) path of the object if given. - Can be omitted; then the same name is used. 
- :type destination_object: str - """ - destination_object = destination_object or source_object - if (source_bucket == destination_bucket and - source_object == destination_object): - raise ValueError( - 'Either source/destination bucket or source/destination object ' - 'must be different, not both the same: bucket=%s, object=%s' % - (source_bucket, source_object)) - if not source_bucket or not source_object: - raise ValueError('source_bucket and source_object cannot be empty.') - - client = self.get_conn() - source_bucket = client.bucket(source_bucket) - source_object = source_bucket.blob(blob_name=source_object) - destination_bucket = client.bucket(destination_bucket) - - token, bytes_rewritten, total_bytes = destination_bucket.blob( - blob_name=destination_object).rewrite( - source=source_object - ) - - self.log.info('Total Bytes: %s | Bytes Written: %s', - total_bytes, bytes_rewritten) - - while token is not None: - token, bytes_rewritten, total_bytes = destination_bucket.blob( - blob_name=destination_object).rewrite( - source=source_object, token=token - ) - - self.log.info('Total Bytes: %s | Bytes Written: %s', - total_bytes, bytes_rewritten) - self.log.info('Object %s in bucket %s copied to object %s in bucket %s', - source_object.name, source_bucket.name, - destination_object, destination_bucket.name) - - def download(self, bucket_name, object_name, filename=None): - """ - Get a file from Google Cloud Storage. - - :param bucket_name: The bucket to fetch from. - :type bucket_name: str - :param object_name: The object to fetch. - :type object_name: str - :param filename: If set, a local file path where the file should be written to. - :type filename: str - """ - client = self.get_conn() - bucket = client.bucket(bucket_name) - blob = bucket.blob(blob_name=object_name) - - if filename: - blob.download_to_filename(filename) - self.log.info('File downloaded to %s', filename) - - return blob.download_as_string() - - def upload(self, bucket_name, object_name, filename, - mime_type='application/octet-stream', gzip=False): - """ - Uploads a local file to Google Cloud Storage. - - :param bucket_name: The bucket to upload to. - :type bucket_name: str - :param object_name: The object name to set when uploading the local file. - :type object_name: str - :param filename: The local file path to the file to be uploaded. - :type filename: str - :param mime_type: The MIME type to set when uploading the file. - :type mime_type: str - :param gzip: Option to compress file for upload - :type gzip: bool - """ - - if gzip: - filename_gz = filename + '.gz' - - with open(filename, 'rb') as f_in: - with gz.open(filename_gz, 'wb') as f_out: - shutil.copyfileobj(f_in, f_out) - filename = filename_gz - - client = self.get_conn() - bucket = client.bucket(bucket_name) - blob = bucket.blob(blob_name=object_name) - blob.upload_from_filename(filename=filename, - content_type=mime_type) - - if gzip: - os.remove(filename) - self.log.info('File %s uploaded to %s in %s bucket', filename, object_name, bucket_name) - - def exists(self, bucket_name, object_name): - """ - Checks for the existence of a file in Google Cloud Storage. - - :param bucket_name: The Google cloud storage bucket where the object is. - :type bucket_name: str - :param object_name: The name of the blob_name to check in the Google cloud - storage bucket. 
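A hedged sketch of the object-transfer methods shown above (``copy``, ``rewrite``, ``download`` and ``upload``); the bucket and object names are placeholders and the default ``google_cloud_default`` connection is assumed::

    from airflow.contrib.hooks.gcs_hook import GoogleCloudStorageHook

    hook = GoogleCloudStorageHook(google_cloud_storage_conn_id='google_cloud_default')

    # Copy with an optional rename; same-bucket plus same-object is rejected.
    hook.copy(source_bucket='src-bucket', source_object='data/file.csv',
              destination_bucket='dst-bucket', destination_object='archive/file.csv')

    # rewrite() behaves like copy() but also handles objects over 5 TB and
    # copies across locations or storage classes.
    hook.rewrite(source_bucket='src-bucket', source_object='big/file.bin',
                 destination_bucket='dst-bucket')

    # Download to a local path, then upload it again gzip-compressed.
    hook.download(bucket_name='src-bucket', object_name='data/file.csv',
                  filename='/tmp/file.csv')
    hook.upload(bucket_name='dst-bucket', object_name='data/file.csv.gz',
                filename='/tmp/file.csv', mime_type='text/csv', gzip=True)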
- :type object_name: str - """ - client = self.get_conn() - bucket = client.bucket(bucket_name) - blob = bucket.blob(blob_name=object_name) - return blob.exists() - - def is_updated_after(self, bucket_name, object_name, ts): - """ - Checks if an blob_name is updated in Google Cloud Storage. - - :param bucket_name: The Google cloud storage bucket where the object is. - :type bucket_name: str - :param object_name: The name of the object to check in the Google cloud - storage bucket. - :type object_name: str - :param ts: The timestamp to check against. - :type ts: datetime.datetime - """ - client = self.get_conn() - bucket = client.bucket(bucket_name) - blob = bucket.get_blob(blob_name=object_name) - - if blob is None: - raise ValueError("Object ({}) not found in Bucket ({})".format( - object_name, bucket_name)) - - blob_update_time = blob.updated - - if blob_update_time is not None: - import dateutil.tz - - if not ts.tzinfo: - ts = ts.replace(tzinfo=dateutil.tz.tzutc()) - - self.log.info("Verify object date: %s > %s", blob_update_time, ts) - - if blob_update_time > ts: - return True - - return False - - def delete(self, bucket_name, object_name): - """ - Deletes an object from the bucket. - - :param bucket_name: name of the bucket, where the object resides - :type bucket_name: str - :param object_name: name of the object to delete - :type object_name: str - """ - client = self.get_conn() - bucket = client.bucket(bucket_name) - blob = bucket.blob(blob_name=object_name) - blob.delete() - - self.log.info('Blob %s deleted.', object_name) - - def list(self, bucket_name, versions=None, max_results=None, prefix=None, delimiter=None): - """ - List all objects from the bucket with the give string prefix in name - - :param bucket_name: bucket name - :type bucket_name: str - :param versions: if true, list all versions of the objects - :type versions: bool - :param max_results: max count of items to return in a single page of responses - :type max_results: int - :param prefix: prefix string which filters objects whose name begin with - this prefix - :type prefix: str - :param delimiter: filters objects based on the delimiter (for e.g '.csv') - :type delimiter: str - :return: a stream of object names matching the filtering criteria - """ - client = self.get_conn() - bucket = client.bucket(bucket_name) - - ids = [] - page_token = None - while True: - blobs = bucket.list_blobs( - max_results=max_results, - page_token=page_token, - prefix=prefix, - delimiter=delimiter, - versions=versions - ) - - blob_names = [] - for blob in blobs: - blob_names.append(blob.name) - - prefixes = blobs.prefixes - if prefixes: - ids += list(prefixes) - else: - ids += blob_names - - page_token = blobs.next_page_token - if page_token is None: - # empty next page token - break - return ids - - def get_size(self, bucket_name, object_name): - """ - Gets the size of a file in Google Cloud Storage. - - :param bucket_name: The Google cloud storage bucket where the blob_name is. - :type bucket_name: str - :param object_name: The name of the object to check in the Google - cloud storage bucket_name. 
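The listing and lifecycle methods above combine naturally into simple housekeeping jobs. A sketch under the same placeholder names, deleting objects that have not been updated within the last week::

    from datetime import datetime, timedelta

    from airflow.contrib.hooks.gcs_hook import GoogleCloudStorageHook

    hook = GoogleCloudStorageHook()
    week_ago = datetime.utcnow() - timedelta(days=7)

    # Object names under a prefix, filtered by a delimiter such as '.csv'.
    for name in hook.list(bucket_name='my-bucket', prefix='exports/', delimiter='.csv'):
        if hook.exists(bucket_name='my-bucket', object_name=name) and \
                not hook.is_updated_after(bucket_name='my-bucket',
                                          object_name=name, ts=week_ago):
            hook.delete(bucket_name='my-bucket', object_name=name)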
- :type object_name: str - - """ - self.log.info('Checking the file size of object: %s in bucket_name: %s', - object_name, - bucket_name) - client = self.get_conn() - bucket = client.bucket(bucket_name) - blob = bucket.get_blob(blob_name=object_name) - blob_size = blob.size - self.log.info('The file size of %s is %s bytes.', object_name, blob_size) - return blob_size - - def get_crc32c(self, bucket_name, object_name): - """ - Gets the CRC32c checksum of an object in Google Cloud Storage. - - :param bucket_name: The Google cloud storage bucket where the blob_name is. - :type bucket_name: str - :param object_name: The name of the object to check in the Google cloud - storage bucket_name. - :type object_name: str - """ - self.log.info('Retrieving the crc32c checksum of ' - 'object_name: %s in bucket_name: %s', object_name, bucket_name) - client = self.get_conn() - bucket = client.bucket(bucket_name) - blob = bucket.get_blob(blob_name=object_name) - blob_crc32c = blob.crc32c - self.log.info('The crc32c checksum of %s is %s', object_name, blob_crc32c) - return blob_crc32c - - def get_md5hash(self, bucket_name, object_name): - """ - Gets the MD5 hash of an object in Google Cloud Storage. - - :param bucket_name: The Google cloud storage bucket where the blob_name is. - :type bucket_name: str - :param object_name: The name of the object to check in the Google cloud - storage bucket_name. - :type object_name: str - """ - self.log.info('Retrieving the MD5 hash of ' - 'object: %s in bucket: %s', object_name, bucket_name) - client = self.get_conn() - bucket = client.bucket(bucket_name) - blob = bucket.get_blob(blob_name=object_name) - blob_md5hash = blob.md5_hash - self.log.info('The md5Hash of %s is %s', object_name, blob_md5hash) - return blob_md5hash - - @GoogleCloudBaseHook.catch_http_exception - @GoogleCloudBaseHook.fallback_to_default_project_id - def create_bucket(self, - bucket_name, - resource=None, - storage_class='MULTI_REGIONAL', - location='US', - project_id=None, - labels=None - ): - """ - Creates a new bucket. Google Cloud Storage uses a flat namespace, so - you can't create a bucket with a name that is already in use. - - .. seealso:: - For more information, see Bucket Naming Guidelines: - https://cloud.google.com/storage/docs/bucketnaming.html#requirements - - :param bucket_name: The name of the bucket. - :type bucket_name: str - :param resource: An optional dict with parameters for creating the bucket. - For information on available parameters, see Cloud Storage API doc: - https://cloud.google.com/storage/docs/json_api/v1/buckets/insert - :type resource: dict - :param storage_class: This defines how objects in the bucket are stored - and determines the SLA and the cost of storage. Values include - - - ``MULTI_REGIONAL`` - - ``REGIONAL`` - - ``STANDARD`` - - ``NEARLINE`` - - ``COLDLINE``. - - If this value is not specified when the bucket is - created, it will default to STANDARD. - :type storage_class: str - :param location: The location of the bucket. - Object data for objects in the bucket resides in physical storage - within this region. Defaults to US. - - .. seealso:: - https://developers.google.com/storage/docs/bucket-locations - - :type location: str - :param project_id: The ID of the GCP Project. - :type project_id: str - :param labels: User-provided labels, in key/value pairs. - :type labels: dict - :return: If successful, it returns the ``id`` of the bucket. 
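Bucket creation and the metadata getters above can be sketched as follows; the bucket name, label and location are illustrative, and ``project_id`` falls back to the connection default when omitted::

    from airflow.contrib.hooks.gcs_hook import GoogleCloudStorageHook

    hook = GoogleCloudStorageHook()

    bucket_id = hook.create_bucket(
        bucket_name='my-new-bucket',
        storage_class='REGIONAL',
        location='EU',
        labels={'team': 'data-eng'},
    )

    # Size and checksums of a stored object.
    size = hook.get_size(bucket_name='my-new-bucket', object_name='data/file.csv')
    crc32c = hook.get_crc32c(bucket_name='my-new-bucket', object_name='data/file.csv')
    md5 = hook.get_md5hash(bucket_name='my-new-bucket', object_name='data/file.csv')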
- """ - - self.log.info('Creating Bucket: %s; Location: %s; Storage Class: %s', - bucket_name, location, storage_class) - - # Add airflow-version label to the bucket - labels = {} or labels - labels['airflow-version'] = 'v' + version.replace('.', '-').replace('+', '-') - - client = self.get_conn() - bucket = client.bucket(bucket_name=bucket_name) - bucket_resource = resource or {} - - for item in bucket_resource: - if item != "name": - bucket._patch_property(name=item, value=resource[item]) # pylint: disable=protected-access - - bucket.storage_class = storage_class - bucket.labels = labels - bucket.create(project=project_id, location=location) - return bucket.id - - def insert_bucket_acl(self, bucket_name, entity, role, user_project=None): - """ - Creates a new ACL entry on the specified bucket_name. - See: https://cloud.google.com/storage/docs/json_api/v1/bucketAccessControls/insert - - :param bucket_name: Name of a bucket_name. - :type bucket_name: str - :param entity: The entity holding the permission, in one of the following forms: - user-userId, user-email, group-groupId, group-email, domain-domain, - project-team-projectId, allUsers, allAuthenticatedUsers. - See: https://cloud.google.com/storage/docs/access-control/lists#scopes - :type entity: str - :param role: The access permission for the entity. - Acceptable values are: "OWNER", "READER", "WRITER". - :type role: str - :param user_project: (Optional) The project to be billed for this request. - Required for Requester Pays buckets. - :type user_project: str - """ - self.log.info('Creating a new ACL entry in bucket: %s', bucket_name) - client = self.get_conn() - bucket = client.bucket(bucket_name=bucket_name) - bucket.acl.reload() - bucket.acl.entity_from_dict(entity_dict={"entity": entity, "role": role}) - if user_project: - bucket.acl.user_project = user_project - bucket.acl.save() - - self.log.info('A new ACL entry created in bucket: %s', bucket_name) - - def insert_object_acl(self, bucket_name, object_name, entity, role, generation=None, user_project=None): - """ - Creates a new ACL entry on the specified object. - See: https://cloud.google.com/storage/docs/json_api/v1/objectAccessControls/insert - - :param bucket_name: Name of a bucket_name. - :type bucket_name: str - :param object_name: Name of the object. For information about how to URL encode - object names to be path safe, see: - https://cloud.google.com/storage/docs/json_api/#encoding - :type object_name: str - :param entity: The entity holding the permission, in one of the following forms: - user-userId, user-email, group-groupId, group-email, domain-domain, - project-team-projectId, allUsers, allAuthenticatedUsers - See: https://cloud.google.com/storage/docs/access-control/lists#scopes - :type entity: str - :param role: The access permission for the entity. - Acceptable values are: "OWNER", "READER". - :type role: str - :param generation: Optional. If present, selects a specific revision of this object. - :type generation: long - :param user_project: (Optional) The project to be billed for this request. - Required for Requester Pays buckets. - :type user_project: str - """ - self.log.info('Creating a new ACL entry for object: %s in bucket: %s', - object_name, bucket_name) - client = self.get_conn() - bucket = client.bucket(bucket_name=bucket_name) - blob = bucket.blob(blob_name=object_name, generation=generation) - # Reload fetches the current ACL from Cloud Storage. 
- blob.acl.reload() - blob.acl.entity_from_dict(entity_dict={"entity": entity, "role": role}) - if user_project: - blob.acl.user_project = user_project - blob.acl.save() - - self.log.info('A new ACL entry created for object: %s in bucket: %s', - object_name, bucket_name) - - def compose(self, bucket_name, source_objects, destination_object): - """ - Composes a list of existing object into a new object in the same storage bucket_name - - Currently it only supports up to 32 objects that can be concatenated - in a single operation - - https://cloud.google.com/storage/docs/json_api/v1/objects/compose - - :param bucket_name: The name of the bucket containing the source objects. - This is also the same bucket to store the composed destination object. - :type bucket_name: str - :param source_objects: The list of source objects that will be composed - into a single object. - :type source_objects: list - :param destination_object: The path of the object if given. - :type destination_object: str - """ - - if not source_objects: - raise ValueError('source_objects cannot be empty.') - - if not bucket_name or not destination_object: - raise ValueError('bucket_name and destination_object cannot be empty.') - - self.log.info("Composing %s to %s in the bucket %s", - source_objects, destination_object, bucket_name) - client = self.get_conn() - bucket = client.bucket(bucket_name) - destination_blob = bucket.blob(destination_object) - destination_blob.compose( - sources=[ - bucket.blob(blob_name=source_object) for source_object in source_objects - ]) - - self.log.info("Completed successfully.") - - -def _parse_gcs_url(gsurl): - """ - Given a Google Cloud Storage URL (gs:///), returns a - tuple containing the corresponding bucket and blob. - """ +# pylint: disable=unused-import +from airflow.gcp.hooks.gcs import GoogleCloudStorageHook, _parse_gcs_url # noqa - parsed_url = urlparse(gsurl) - if not parsed_url.netloc: - raise AirflowException('Please provide a bucket name') - else: - bucket = parsed_url.netloc - # Remove leading '/' but NOT trailing one - blob = parsed_url.path.lstrip('/') - return bucket, blob +warnings.warn( + "This module is deprecated. Please use `airflow.gcp.hooks.gcs`.", + DeprecationWarning +) diff --git a/airflow/contrib/hooks/gdrive_hook.py b/airflow/contrib/hooks/gdrive_hook.py new file mode 100644 index 00000000000000..2d8f144d8cd5a9 --- /dev/null +++ b/airflow/contrib/hooks/gdrive_hook.py @@ -0,0 +1,134 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
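Before moving on to the new ``gdrive_hook.py``, the ``compose()`` method and the ``_parse_gcs_url()`` helper shown above can be sketched like this (bucket and object names are placeholders; compose supports at most 32 source objects per call)::

    from airflow.contrib.hooks.gcs_hook import GoogleCloudStorageHook, _parse_gcs_url

    hook = GoogleCloudStorageHook()

    # Concatenate shard objects into one object in the same bucket.
    hook.compose(
        bucket_name='my-bucket',
        source_objects=['exports/part-000.csv', 'exports/part-001.csv'],
        destination_object='exports/full.csv',
    )

    # Split a gs:// URL into its bucket and blob components.
    bucket, blob = _parse_gcs_url('gs://my-bucket/exports/full.csv')
    # bucket == 'my-bucket', blob == 'exports/full.csv'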
+"""Hook for Google Drive service""" +from typing import Any + +from googleapiclient.discovery import build +from googleapiclient.http import MediaFileUpload + +from airflow.contrib.hooks.gcp_api_base_hook import GoogleCloudBaseHook + + +# noinspection PyAbstractClass +class GoogleDriveHook(GoogleCloudBaseHook): + """ + Hook for the Google Drive APIs. + + :param api_version: API version used (for example v3). + :type api_version: str + :param gcp_conn_id: The connection ID to use when fetching connection info. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate, if any. + For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + """ + + _conn = None + + def __init__( + self, api_version: str = "v3", gcp_conn_id: str = "google_cloud_default", delegate_to: str = None + ) -> None: + super().__init__(gcp_conn_id, delegate_to) + self.api_version = api_version + + def get_conn(self) -> Any: + """ + Retrieves the connection to Google Drive. + + :return: Google Drive services object. + """ + if not self._conn: + http_authorized = self._authorize() + self._conn = build("drive", self.api_version, http=http_authorized, cache_discovery=False) + return self._conn + + def _ensure_folders_exists(self, path: str) -> str: + service = self.get_conn() + current_parent = "root" + folders = path.split("/") + depth = 0 + # First tries to enter directories + for current_folder in folders: + self.log.debug("Looking for %s directory with %s parent", current_folder, current_parent) + conditions = [ + "mimeType = 'application/vnd.google-apps.folder'", + "name='{}'".format(current_folder), + "'{}' in parents".format(current_parent), + ] + result = ( + service.files() # pylint: disable=no-member + .list(q=" and ".join(conditions), spaces="drive", fields="files(id, name)") + .execute(num_retries=self.num_retries) + ) + files = result.get("files", []) + if not files: + self.log.info("Not found %s directory", current_folder) + # If the directory does not exist, break loops + break + depth += 1 + current_parent = files[0].get("id") + + # Check if there are directories to process + if depth != len(folders): + # Create missing directories + for current_folder in folders[depth:]: + file_metadata = { + "name": current_folder, + "mimeType": "application/vnd.google-apps.folder", + "parents": [current_parent], + } + file = ( + service.files() # pylint: disable=no-member + .create(body=file_metadata, fields="id") + .execute(num_retries=self.num_retries) + ) + self.log.info("Created %s directory", current_folder) + + current_parent = file.get("id") + # Return the ID of the last directory + return current_parent + + def upload_file(self, local_location: str, remote_location: str) -> str: + """ + Uploads a file that is available locally to a Google Drive service. + + :param local_location: The path where the file is available. 
+ :type local_location: str + :param remote_location: The path where the file will be send + :type remote_location: str + :return: File ID + :rtype: str + """ + service = self.get_conn() + directory_path, _, filename = remote_location.rpartition("/") + if directory_path: + parent = self._ensure_folders_exists(directory_path) + else: + parent = "root" + + file_metadata = {"name": filename, "parents": [parent]} + media = MediaFileUpload(local_location) + file = ( + service.files() # pylint: disable=no-member + .create(body=file_metadata, media_body=media, fields="id") + .execute(num_retries=self.num_retries) + ) + self.log.info("File %s uploaded to gdrive://%s.", local_location, remote_location) + return file.get("id") diff --git a/airflow/contrib/hooks/grpc_hook.py b/airflow/contrib/hooks/grpc_hook.py index a6ca8a7e5ca925..664ed70b7372af 100644 --- a/airflow/contrib/hooks/grpc_hook.py +++ b/airflow/contrib/hooks/grpc_hook.py @@ -1,17 +1,23 @@ # -*- coding: utf-8 -*- + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
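Returning to the new ``GoogleDriveHook``: ``upload_file()`` creates any missing folders in the remote path via ``_ensure_folders_exists()`` and returns the Drive file ID. A minimal sketch with placeholder paths and the default connection::

    from airflow.contrib.hooks.gdrive_hook import GoogleDriveHook

    hook = GoogleDriveHook(api_version="v3", gcp_conn_id="google_cloud_default")

    # Folders in "reports/2019" are created on the fly if they do not exist yet.
    file_id = hook.upload_file(
        local_location="/tmp/report.csv",
        remote_location="reports/2019/report.csv",
    )
    print(file_id)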
+"""GRPC Hook""" import grpc from google import auth as google_auth @@ -58,7 +64,7 @@ def get_conn(self): if auth_type == "NO_AUTH": channel = grpc.insecure_channel(base_url) - elif auth_type == "SSL" or auth_type == "TLS": + elif auth_type in {"SSL", "TLS"}: credential_file_name = self._get_field("credential_pem_file") creds = grpc.ssl_channel_credentials(open(credential_file_name).read()) channel = grpc.secure_channel(base_url, creds) @@ -91,7 +97,9 @@ def get_conn(self): return channel - def run(self, stub_class, call_func, streaming=False, data={}): + def run(self, stub_class, call_func, streaming=False, data=None): + if data is None: + data = {} with self.get_conn() as channel: stub = stub_class(channel) try: @@ -102,10 +110,14 @@ def run(self, stub_class, call_func, streaming=False, data={}): else: yield from response except grpc.RpcError as ex: + # noinspection PyUnresolvedReferences self.log.exception( "Error occurred when calling the grpc service: {0}, method: {1} \ status code: {2}, error details: {3}" - .format(stub.__class__.__name__, call_func, ex.code(), ex.details())) + .format(stub.__class__.__name__, + call_func, + ex.code(), # pylint: disable=no-member + ex.details())) # pylint: disable=no-member raise ex def _get_field(self, field_name, default=None): diff --git a/airflow/contrib/hooks/qubole_hook.py b/airflow/contrib/hooks/qubole_hook.py index c2177c19743e47..8a87c04aca22a9 100644 --- a/airflow/contrib/hooks/qubole_hook.py +++ b/airflow/contrib/hooks/qubole_hook.py @@ -17,13 +17,18 @@ # specific language governing permissions and limitations # under the License. # - +"""Qubole hook""" import os import pathlib import time import datetime import re +from qds_sdk.qubole import Qubole +from qds_sdk.commands import Command, HiveCommand, PrestoCommand, HadoopCommand, \ + PigCommand, ShellCommand, SparkCommand, DbTapQueryCommand, DbExportCommand, \ + DbImportCommand, SqlCommand + from airflow.exceptions import AirflowException from airflow.hooks.base_hook import BaseHook from airflow.configuration import conf @@ -31,11 +36,6 @@ from airflow.utils.state import State from airflow.models import TaskInstance -from qds_sdk.qubole import Qubole -from qds_sdk.commands import Command, HiveCommand, PrestoCommand, HadoopCommand, \ - PigCommand, ShellCommand, SparkCommand, DbTapQueryCommand, DbExportCommand, \ - DbImportCommand, SqlCommand - COMMAND_CLASSES = { "hivecmd": HiveCommand, "prestocmd": PrestoCommand, @@ -57,20 +57,24 @@ def flatten_list(list_of_lists): + """Flatten the list""" return [element for array in list_of_lists for element in array] def filter_options(options): + """Remove options from the list""" options_to_remove = ["help", "print-logs-live", "print-logs"] return [option for option in options if option not in options_to_remove] def get_options_list(command_class): + """Get options list""" options_list = [option.get_opt_string().strip("--") for option in command_class.optparser.option_list] return filter_options(options_list) def build_command_args(): + """Build Command argument from command and options""" command_args, hyphen_args = {}, set() for cmd in COMMAND_CLASSES: @@ -95,7 +99,8 @@ def build_command_args(): class QuboleHook(BaseHook): - def __init__(self, *args, **kwargs): + """Hook for Qubole communication""" + def __init__(self, *args, **kwargs): # pylint: disable=unused-argument conn = self.get_connection(kwargs['qubole_conn_id']) Qubole.configure(api_token=conn.password, api_url=conn.host) self.task_id = kwargs['task_id'] @@ -107,6 +112,7 @@ def 
__init__(self, *args, **kwargs): @staticmethod def handle_failure_retry(context): + """Handle retries in case of failures""" ti = context['ti'] cmd_id = ti.xcom_pull(key='qbol_cmd_id', task_ids=ti.task_id) @@ -123,6 +129,7 @@ def handle_failure_retry(context): cmd.cancel() def execute(self, context): + """Execute call""" args = self.cls.parse(self.create_cmd_args(context)) self.cmd = self.cls.create(**args) self.task_instance = context['task_instance'] @@ -197,7 +204,7 @@ def get_log(self, ti): """ if self.cmd is None: cmd_id = ti.xcom_pull(key="qbol_cmd_id", task_ids=self.task_id) - Command.get_log_id(self.cls, cmd_id) + Command.get_log_id(cmd_id) def get_jobs_id(self, ti): """ @@ -207,8 +214,9 @@ def get_jobs_id(self, ti): """ if self.cmd is None: cmd_id = ti.xcom_pull(key="qbol_cmd_id", task_ids=self.task_id) - Command.get_jobs_id(self.cls, cmd_id) + Command.get_jobs_id(cmd_id) + # noinspection PyMethodMayBeStatic def get_extra_links(self, operator, dttm): """ Get link to qubole command result page. @@ -229,28 +237,25 @@ def get_extra_links(self, operator, dttm): return url def create_cmd_args(self, context): + """Creates command arguments""" args = [] cmd_type = self.kwargs['command_type'] inplace_args = None tags = {self.dag_id, self.task_id, context['run_id']} positional_args_list = flatten_list(POSITIONAL_ARGS.values()) - for k, v in self.kwargs.items(): - if k in COMMAND_ARGS[cmd_type]: - if k in HYPHEN_ARGS: - args.append("--{0}={1}".format(k.replace('_', '-'), v)) - elif k in positional_args_list: - inplace_args = v - elif k == 'tags': - if isinstance(v, str): - tags.add(v) - elif isinstance(v, (list, tuple)): - for val in v: - tags.add(val) + for key, value in self.kwargs.items(): + if key in COMMAND_ARGS[cmd_type]: + if key in HYPHEN_ARGS: + args.append("--{0}={1}".format(key.replace('_', '-'), value)) + elif key in positional_args_list: + inplace_args = value + elif key == 'tags': + self._add_tags(tags, value) else: - args.append("--{0}={1}".format(k, v)) + args.append("--{0}={1}".format(key, value)) - if k == 'notify' and v is True: + if key == 'notify' and value is True: args.append("--notify") args.append("--tags={0}".format(','.join(filter(None, tags)))) @@ -259,3 +264,10 @@ def create_cmd_args(self, context): args += inplace_args.split(' ') return args + + @staticmethod + def _add_tags(tags, value): + if isinstance(value, str): + tags.add(value) + elif isinstance(value, (list, tuple)): + tags.extend(value) diff --git a/airflow/contrib/hooks/spark_submit_hook.py b/airflow/contrib/hooks/spark_submit_hook.py index bfce52ba36297b..449f072b42f924 100644 --- a/airflow/contrib/hooks/spark_submit_hook.py +++ b/airflow/contrib/hooks/spark_submit_hook.py @@ -175,7 +175,7 @@ def _resolve_connection(self): 'deploy_mode': None, 'spark_home': None, 'spark_binary': self._spark_binary or "spark-submit", - 'namespace': 'default'} + 'namespace': None} try: # Master can be local, yarn, spark://HOST:PORT, mesos://HOST:PORT and @@ -193,7 +193,7 @@ def _resolve_connection(self): conn_data['spark_home'] = extra.get('spark-home', None) conn_data['spark_binary'] = self._spark_binary or \ extra.get('spark-binary', "spark-submit") - conn_data['namespace'] = extra.get('namespace', 'default') + conn_data['namespace'] = extra.get('namespace') except AirflowException: self.log.info( "Could not load connection string %s, defaulting to %s", @@ -246,7 +246,7 @@ def _build_spark_submit_command(self, application): elif self._env_vars and self._connection['deploy_mode'] == "cluster": raise AirflowException( 
"SparkSubmitHook env_vars is not supported in standalone-cluster mode.") - if self._is_kubernetes: + if self._is_kubernetes and self._connection['namespace']: connection_cmd += ["--conf", "spark.kubernetes.namespace={}".format( self._connection['namespace'])] if self._files: diff --git a/airflow/contrib/hooks/ssh_hook.py b/airflow/contrib/hooks/ssh_hook.py index b431b9f703928b..ab0ab085b87d73 100644 --- a/airflow/contrib/hooks/ssh_hook.py +++ b/airflow/contrib/hooks/ssh_hook.py @@ -29,7 +29,6 @@ from airflow.hooks.base_hook import BaseHook -# noinspection PyAbstractClass class SSHHook(BaseHook): """ Hook for ssh remote execution using Paramiko. diff --git a/airflow/contrib/hooks/wasb_hook.py b/airflow/contrib/hooks/wasb_hook.py index d3a766cf69de74..077f8d92205e1f 100644 --- a/airflow/contrib/hooks/wasb_hook.py +++ b/airflow/contrib/hooks/wasb_hook.py @@ -17,7 +17,14 @@ # specific language governing permissions and limitations # under the License. # - +""" +This module contains integration with Azure Blob Storage. + +It communicate via the Window Azure Storage Blob protocol. Make sure that a +Airflow connection of type `wasb` exists. Authorization can be done by supplying a +login (=Storage account name) and password (=KEY), or login and SAS token in the extra +field (see connection `wasb_default` for an example). +""" from airflow.exceptions import AirflowException from airflow.hooks.base_hook import BaseHook @@ -26,7 +33,7 @@ class WasbHook(BaseHook): """ - Interacts with Azure Blob Storage through the wasb:// protocol. + Interacts with Azure Blob Storage through the ``wasb://`` protocol. Additional options passed in the 'extra' field of the connection will be passed to the `BlockBlockService()` constructor. For example, authenticate diff --git a/airflow/contrib/operators/aws_sqs_publish_operator.py b/airflow/contrib/operators/aws_sqs_publish_operator.py index 072339838a68cd..0bf2f7f84c6600 100644 --- a/airflow/contrib/operators/aws_sqs_publish_operator.py +++ b/airflow/contrib/operators/aws_sqs_publish_operator.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -# -# Licensed to the ApachMee Software Foundation (ASF) under one + +# Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information # regarding copyright ownership. The ASF licenses this file @@ -17,6 +17,7 @@ # specific language governing permissions and limitations # under the License. +"""Publish message to SQS queue""" from airflow.models import BaseOperator from airflow.utils.decorators import apply_defaults from airflow.contrib.hooks.aws_sqs_hook import SQSHook @@ -38,7 +39,6 @@ class SQSPublishOperator(BaseOperator): :param aws_conn_id: AWS connection id (default: aws_default) :type aws_conn_id: str """ - template_fields = ('sqs_queue', 'message_content', 'delay_seconds') ui_color = '#6ad3fa' diff --git a/airflow/contrib/operators/bigquery_check_operator.py b/airflow/contrib/operators/bigquery_check_operator.py index 41fd83d491866e..cfea8523abc3f9 100644 --- a/airflow/contrib/operators/bigquery_check_operator.py +++ b/airflow/contrib/operators/bigquery_check_operator.py @@ -16,182 +16,18 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -""" -This module contains Google BigQuery check operator. 
-""" -import warnings - -from airflow.contrib.hooks.bigquery_hook import BigQueryHook -from airflow.operators.check_operator import \ - CheckOperator, ValueCheckOperator, IntervalCheckOperator -from airflow.utils.decorators import apply_defaults - - -class BigQueryCheckOperator(CheckOperator): - """ - Performs checks against BigQuery. The ``BigQueryCheckOperator`` expects - a sql query that will return a single row. Each value on that - first row is evaluated using python ``bool`` casting. If any of the - values return ``False`` the check is failed and errors out. - - Note that Python bool casting evals the following as ``False``: - - * ``False`` - * ``0`` - * Empty string (``""``) - * Empty list (``[]``) - * Empty dictionary or set (``{}``) - - Given a query like ``SELECT COUNT(*) FROM foo``, it will fail only if - the count ``== 0``. You can craft much more complex query that could, - for instance, check that the table has the same number of rows as - the source table upstream, or that the count of today's partition is - greater than yesterday's partition, or that a set of metrics are less - than 3 standard deviation for the 7 day average. - - This operator can be used as a data quality check in your pipeline, and - depending on where you put it in your DAG, you have the choice to - stop the critical path, preventing from - publishing dubious data, or on the side and receive email alerts - without stopping the progress of the DAG. - - :param sql: the sql to be executed - :type sql: str - :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud Platform. - :type gcp_conn_id: str - :param bigquery_conn_id: (Deprecated) The connection ID used to connect to Google Cloud Platform. - This parameter has been deprecated. You should pass the gcp_conn_id parameter instead. - :type bigquery_conn_id: str - :param use_legacy_sql: Whether to use legacy SQL (true) - or standard SQL (false). - :type use_legacy_sql: bool - """ - - template_fields = ('sql', 'gcp_conn_id', ) - template_ext = ('.sql', ) - - @apply_defaults - def __init__(self, - sql, - gcp_conn_id='google_cloud_default', - bigquery_conn_id=None, - use_legacy_sql=True, - *args, **kwargs): - super().__init__(sql=sql, *args, **kwargs) - if not bigquery_conn_id: - warnings.warn( - "The bigquery_conn_id parameter has been deprecated. You should pass " - "the gcp_conn_id parameter.", DeprecationWarning, stacklevel=3) - gcp_conn_id = bigquery_conn_id - - self.gcp_conn_id = gcp_conn_id - self.sql = sql - self.use_legacy_sql = use_legacy_sql - - def get_db_hook(self): - return BigQueryHook(bigquery_conn_id=self.gcp_conn_id, - use_legacy_sql=self.use_legacy_sql) - - -class BigQueryValueCheckOperator(ValueCheckOperator): - """ - Performs a simple value check using sql code. +"""This module is deprecated. Please use `airflow.gcp.operators.bigquery`.""" - :param sql: the sql to be executed - :type sql: str - :param use_legacy_sql: Whether to use legacy SQL (true) - or standard SQL (false). - :type use_legacy_sql: bool - :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud Platform. - :type gcp_conn_id: str - :param bigquery_conn_id: (Deprecated) The connection ID used to connect to Google Cloud Platform. - This parameter has been deprecated. You should pass the gcp_conn_id parameter instead. 
- :type bigquery_conn_id: str - """ - - template_fields = ('sql', 'gcp_conn_id', 'pass_value', ) - template_ext = ('.sql', ) - - @apply_defaults - def __init__(self, sql, - pass_value, - tolerance=None, - gcp_conn_id='google_cloud_default', - bigquery_conn_id=None, - use_legacy_sql=True, - *args, **kwargs): - super().__init__( - sql=sql, pass_value=pass_value, tolerance=tolerance, - *args, **kwargs) - - if bigquery_conn_id: - warnings.warn( - "The bigquery_conn_id parameter has been deprecated. You should pass " - "the gcp_conn_id parameter.", DeprecationWarning, stacklevel=3) - gcp_conn_id = bigquery_conn_id - - self.gcp_conn_id = gcp_conn_id - self.use_legacy_sql = use_legacy_sql - - def get_db_hook(self): - return BigQueryHook(bigquery_conn_id=self.gcp_conn_id, - use_legacy_sql=self.use_legacy_sql) - - -class BigQueryIntervalCheckOperator(IntervalCheckOperator): - """ - Checks that the values of metrics given as SQL expressions are within - a certain tolerance of the ones from days_back before. - - This method constructs a query like so :: - - SELECT {metrics_threshold_dict_key} FROM {table} - WHERE {date_filter_column}= - - :param table: the table name - :type table: str - :param days_back: number of days between ds and the ds we want to check - against. Defaults to 7 days - :type days_back: int - :param metrics_threshold: a dictionary of ratios indexed by metrics, for - example 'COUNT(*)': 1.5 would require a 50 percent or less difference - between the current day, and the prior days_back. - :type metrics_threshold: dict - :param use_legacy_sql: Whether to use legacy SQL (true) - or standard SQL (false). - :type use_legacy_sql: bool - :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud Platform. - :type gcp_conn_id: str - :param bigquery_conn_id: (Deprecated) The connection ID used to connect to Google Cloud Platform. - This parameter has been deprecated. You should pass the gcp_conn_id parameter instead. - :type bigquery_conn_id: str - """ - - template_fields = ('table', 'gcp_conn_id', ) - - @apply_defaults - def __init__(self, - table, - metrics_thresholds, - date_filter_column='ds', - days_back=-7, - gcp_conn_id='google_cloud_default', - bigquery_conn_id=None, - use_legacy_sql=True, *args, **kwargs): - super().__init__( - table=table, metrics_thresholds=metrics_thresholds, - date_filter_column=date_filter_column, days_back=days_back, - *args, **kwargs) - - if bigquery_conn_id: - warnings.warn( - "The bigquery_conn_id parameter has been deprecated. You should pass " - "the gcp_conn_id parameter.", DeprecationWarning, stacklevel=3) - gcp_conn_id = bigquery_conn_id - - self.gcp_conn_id = gcp_conn_id - self.use_legacy_sql = use_legacy_sql +import warnings - def get_db_hook(self): - return BigQueryHook(bigquery_conn_id=self.gcp_conn_id, - use_legacy_sql=self.use_legacy_sql) +# pylint: disable=unused-import +from airflow.gcp.operators.bigquery import ( # noqa + BigQueryCheckOperator, + BigQueryIntervalCheckOperator, + BigQueryValueCheckOperator +) + +warnings.warn( + "This module is deprecated. Please use `airflow.gcp.operators.bigquery`.", + DeprecationWarning, +) diff --git a/airflow/contrib/operators/bigquery_get_data.py b/airflow/contrib/operators/bigquery_get_data.py index 48a4d3095f9b1d..3b03e69c6d7f77 100644 --- a/airflow/contrib/operators/bigquery_get_data.py +++ b/airflow/contrib/operators/bigquery_get_data.py @@ -16,116 +16,14 @@ # KIND, either express or implied. 
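The value and interval variants above follow the same pattern; inside the same kind of DAG context, hedged sketches with placeholder tables and thresholds could read::

    from airflow.contrib.operators.bigquery_check_operator import (
        BigQueryIntervalCheckOperator,
        BigQueryValueCheckOperator,
    )

    # Assert a single scalar result is close to an expected value.
    check_count = BigQueryValueCheckOperator(
        task_id='expect_about_10k_rows',
        sql="SELECT COUNT(*) FROM `my-project.my_dataset.events` WHERE ds = '{{ ds }}'",
        pass_value=10000,
        tolerance=0.1,                            # +/- 10 %
        use_legacy_sql=False,
    )

    # Compare today's metrics with the same metrics seven days back.
    check_drift = BigQueryIntervalCheckOperator(
        task_id='check_weekly_drift',
        table='my-project.my_dataset.events',
        metrics_thresholds={'COUNT(*)': 1.5},     # allow at most a 50 % difference
        date_filter_column='ds',
        days_back=-7,
        use_legacy_sql=False,
    )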
See the License for the # specific language governing permissions and limitations # under the License. -""" -This module contains a Google BigQuery data operator. -""" -import warnings - -from airflow.contrib.hooks.bigquery_hook import BigQueryHook -from airflow.models import BaseOperator -from airflow.utils.decorators import apply_defaults - - -class BigQueryGetDataOperator(BaseOperator): - """ - Fetches the data from a BigQuery table (alternatively fetch data for selected columns) - and returns data in a python list. The number of elements in the returned list will - be equal to the number of rows fetched. Each element in the list will again be a list - where element would represent the columns values for that row. - - **Example Result**: ``[['Tony', '10'], ['Mike', '20'], ['Steve', '15']]`` - - .. note:: - If you pass fields to ``selected_fields`` which are in different order than the - order of columns already in - BQ table, the data will still be in the order of BQ table. - For example if the BQ table has 3 columns as - ``[A,B,C]`` and you pass 'B,A' in the ``selected_fields`` - the data would still be of the form ``'A,B'``. - - **Example**: :: - - get_data = BigQueryGetDataOperator( - task_id='get_data_from_bq', - dataset_id='test_dataset', - table_id='Transaction_partitions', - max_results='100', - selected_fields='DATE', - gcp_conn_id='airflow-conn-id' - ) +"""This module is deprecated. Please use `airflow.gcp.operators.bigquery`.""" - :param dataset_id: The dataset ID of the requested table. (templated) - :type dataset_id: str - :param table_id: The table ID of the requested table. (templated) - :type table_id: str - :param max_results: The maximum number of records (rows) to be fetched - from the table. (templated) - :type max_results: str - :param selected_fields: List of fields to return (comma-separated). If - unspecified, all fields are returned. - :type selected_fields: str - :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud Platform. - :type gcp_conn_id: str - :param bigquery_conn_id: (Deprecated) The connection ID used to connect to Google Cloud Platform. - This parameter has been deprecated. You should pass the gcp_conn_id parameter instead. - :type bigquery_conn_id: str - :param delegate_to: The account to impersonate, if any. - For this to work, the service account making the request must have domain-wide - delegation enabled. - :type delegate_to: str - """ - template_fields = ('dataset_id', 'table_id', 'max_results') - ui_color = '#e4f0e8' - - @apply_defaults - def __init__(self, - dataset_id, - table_id, - max_results='100', - selected_fields=None, - gcp_conn_id='google_cloud_default', - bigquery_conn_id=None, - delegate_to=None, - *args, - **kwargs): - super().__init__(*args, **kwargs) - - if bigquery_conn_id: - warnings.warn( - "The bigquery_conn_id parameter has been deprecated. 
You should pass " - "the gcp_conn_id parameter.", DeprecationWarning, stacklevel=3) - gcp_conn_id = bigquery_conn_id - - self.dataset_id = dataset_id - self.table_id = table_id - self.max_results = max_results - self.selected_fields = selected_fields - self.gcp_conn_id = gcp_conn_id - self.delegate_to = delegate_to - - def execute(self, context): - self.log.info('Fetching Data from:') - self.log.info('Dataset: %s ; Table: %s ; Max Results: %s', - self.dataset_id, self.table_id, self.max_results) - - hook = BigQueryHook(bigquery_conn_id=self.gcp_conn_id, - delegate_to=self.delegate_to) - - conn = hook.get_conn() - cursor = conn.cursor() - response = cursor.get_tabledata(dataset_id=self.dataset_id, - table_id=self.table_id, - max_results=self.max_results, - selected_fields=self.selected_fields) - - self.log.info('Total Extracted rows: %s', response['totalRows']) - rows = response['rows'] +import warnings - table_data = [] - for dict_row in rows: - single_row = [] - for fields in dict_row['f']: - single_row.append(fields['v']) - table_data.append(single_row) +# pylint: disable=unused-import +from airflow.gcp.operators.bigquery import BigQueryGetDataOperator # noqa - return table_data +warnings.warn( + "This module is deprecated. Please use `airflow.gcp.operators.bigquery`.", + DeprecationWarning, +) diff --git a/airflow/contrib/operators/bigquery_operator.py b/airflow/contrib/operators/bigquery_operator.py index fa791a4e22ac10..26cd39fa11ed4e 100644 --- a/airflow/contrib/operators/bigquery_operator.py +++ b/airflow/contrib/operators/bigquery_operator.py @@ -16,926 +16,25 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -""" -This module contains Google BigQuery operators. -""" +"""This module is deprecated. Please use `airflow.gcp.operators.bigquery`.""" -import json import warnings -from typing import Iterable, List, Optional, Union -from airflow.contrib.hooks.bigquery_hook import BigQueryHook -from airflow.contrib.hooks.gcs_hook import GoogleCloudStorageHook, _parse_gcs_url -from airflow.exceptions import AirflowException -from airflow.models.baseoperator import BaseOperator, BaseOperatorLink -from airflow.models.taskinstance import TaskInstance -from airflow.utils.decorators import apply_defaults - -BIGQUERY_JOB_DETAILS_LINK_FMT = 'https://console.cloud.google.com/bigquery?j={job_id}' - - -class BigQueryConsoleLink(BaseOperatorLink): - """ - Helper class for constructing BigQuery link. - """ - name = 'BigQuery Console' - - def get_link(self, operator, dttm): - ti = TaskInstance(task=operator, execution_date=dttm) - job_id = ti.xcom_pull(task_ids=operator.task_id, key='job_id') - return BIGQUERY_JOB_DETAILS_LINK_FMT.format(job_id=job_id) if job_id else '' - - -class BigQueryConsoleIndexableLink(BaseOperatorLink): - """ - Helper class for constructing BigQuery link. 
- """ - - def __init__(self, index) -> None: - super().__init__() - self.index = index - - @property - def name(self) -> str: - return 'BigQuery Console #{index}'.format(index=self.index + 1) - - def get_link(self, operator, dttm): - ti = TaskInstance(task=operator, execution_date=dttm) - job_ids = ti.xcom_pull(task_ids=operator.task_id, key='job_id') - if not job_ids: - return None - if len(job_ids) < self.index: - return None - job_id = job_ids[self.index] - return BIGQUERY_JOB_DETAILS_LINK_FMT.format(job_id=job_id) - - -# pylint: disable=too-many-instance-attributes -class BigQueryOperator(BaseOperator): - """ - Executes BigQuery SQL queries in a specific BigQuery database - - :param sql: the sql code to be executed (templated) - :type sql: Can receive a str representing a sql statement, - a list of str (sql statements), or reference to a template file. - Template reference are recognized by str ending in '.sql'. - :param destination_dataset_table: A dotted - ``(.|:).
`` that, if set, will store the results - of the query. (templated) - :type destination_dataset_table: str - :param write_disposition: Specifies the action that occurs if the destination table - already exists. (default: 'WRITE_EMPTY') - :type write_disposition: str - :param create_disposition: Specifies whether the job is allowed to create new tables. - (default: 'CREATE_IF_NEEDED') - :type create_disposition: str - :param allow_large_results: Whether to allow large results. - :type allow_large_results: bool - :param flatten_results: If true and query uses legacy SQL dialect, flattens - all nested and repeated fields in the query results. ``allow_large_results`` - must be ``true`` if this is set to ``false``. For standard SQL queries, this - flag is ignored and results are never flattened. - :type flatten_results: bool - :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud Platform. - :type gcp_conn_id: str - :param bigquery_conn_id: (Deprecated) The connection ID used to connect to Google Cloud Platform. - This parameter has been deprecated. You should pass the gcp_conn_id parameter instead. - :type bigquery_conn_id: str - :param delegate_to: The account to impersonate, if any. - For this to work, the service account making the request must have domain-wide - delegation enabled. - :type delegate_to: str - :param udf_config: The User Defined Function configuration for the query. - See https://cloud.google.com/bigquery/user-defined-functions for details. - :type udf_config: list - :param use_legacy_sql: Whether to use legacy SQL (true) or standard SQL (false). - :type use_legacy_sql: bool - :param maximum_billing_tier: Positive integer that serves as a multiplier - of the basic price. - Defaults to None, in which case it uses the value set in the project. - :type maximum_billing_tier: int - :param maximum_bytes_billed: Limits the bytes billed for this job. - Queries that will have bytes billed beyond this limit will fail - (without incurring a charge). If unspecified, this will be - set to your project default. - :type maximum_bytes_billed: float - :param api_resource_configs: a dictionary that contain params - 'configuration' applied for Google BigQuery Jobs API: - https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs - for example, {'query': {'useQueryCache': False}}. You could use it - if you need to provide some params that are not supported by BigQueryOperator - like args. - :type api_resource_configs: dict - :param schema_update_options: Allows the schema of the destination - table to be updated as a side effect of the load job. - :type schema_update_options: Optional[Union[list, tuple, set]] - :param query_params: a list of dictionary containing query parameter types and - values, passed to BigQuery. The structure of dictionary should look like - 'queryParameters' in Google BigQuery Jobs API: - https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs. - For example, [{ 'name': 'corpus', 'parameterType': { 'type': 'STRING' }, - 'parameterValue': { 'value': 'romeoandjuliet' } }]. - :type query_params: list - :param labels: a dictionary containing labels for the job/query, - passed to BigQuery - :type labels: dict - :param priority: Specifies a priority for the query. - Possible values include INTERACTIVE and BATCH. - The default value is INTERACTIVE. - :type priority: str - :param time_partitioning: configure optional time partitioning fields i.e. - partition by field, type and expiration as per API specifications. 
- :type time_partitioning: dict - :param cluster_fields: Request that the result of this query be stored sorted - by one or more columns. This is only available in conjunction with - time_partitioning. The order of columns given determines the sort order. - :type cluster_fields: list[str] - :param location: The geographic location of the job. Required except for - US and EU. See details at - https://cloud.google.com/bigquery/docs/locations#specifying_your_location - :type location: str - :param encryption_configuration: [Optional] Custom encryption configuration (e.g., Cloud KMS keys). - **Example**: :: - - encryption_configuration = { - "kmsKeyName": "projects/testp/locations/us/keyRings/test-kr/cryptoKeys/test-key" - } - :type encryption_configuration: dict - """ - - template_fields = ('sql', 'destination_dataset_table', 'labels') - template_ext = ('.sql', ) - ui_color = '#e4f0e8' - - @property - def operator_extra_links(self): - """ - Return operator extra links - """ - if isinstance(self.sql, str): - return ( - BigQueryConsoleLink(), - ) - return ( - BigQueryConsoleIndexableLink(i) for i, _ in enumerate(self.sql) - ) - - # pylint: disable=too-many-arguments, too-many-locals - @apply_defaults - def __init__(self, - sql: Union[str, Iterable], - destination_dataset_table: Optional[str] = None, - write_disposition: Optional[str] = 'WRITE_EMPTY', - allow_large_results: Optional[bool] = False, - flatten_results: Optional[bool] = None, - gcp_conn_id: Optional[str] = 'google_cloud_default', - bigquery_conn_id: Optional[str] = None, - delegate_to: Optional[str] = None, - udf_config: Optional[list] = None, - use_legacy_sql: Optional[bool] = True, - maximum_billing_tier: Optional[int] = None, - maximum_bytes_billed: Optional[float] = None, - create_disposition: Optional[str] = 'CREATE_IF_NEEDED', - schema_update_options: Optional[Union[list, tuple, set]] = None, - query_params: Optional[list] = None, - labels: Optional[dict] = None, - priority: Optional[str] = 'INTERACTIVE', - time_partitioning: Optional[dict] = None, - api_resource_configs: Optional[dict] = None, - cluster_fields: Optional[List[str]] = None, - location: Optional[str] = None, - encryption_configuration=None, - *args, - **kwargs): - super().__init__(*args, **kwargs) - - if bigquery_conn_id: - warnings.warn( - "The bigquery_conn_id parameter has been deprecated. 
You should pass " - "the gcp_conn_id parameter.", DeprecationWarning, stacklevel=3) - gcp_conn_id = bigquery_conn_id - - self.sql = sql - self.destination_dataset_table = destination_dataset_table - self.write_disposition = write_disposition - self.create_disposition = create_disposition - self.allow_large_results = allow_large_results - self.flatten_results = flatten_results - self.gcp_conn_id = gcp_conn_id - self.delegate_to = delegate_to - self.udf_config = udf_config - self.use_legacy_sql = use_legacy_sql - self.maximum_billing_tier = maximum_billing_tier - self.maximum_bytes_billed = maximum_bytes_billed - self.schema_update_options = schema_update_options - self.query_params = query_params - self.labels = labels - self.bq_cursor = None - self.priority = priority - self.time_partitioning = time_partitioning - self.api_resource_configs = api_resource_configs - self.cluster_fields = cluster_fields - self.location = location - self.encryption_configuration = encryption_configuration - - def execute(self, context): - if self.bq_cursor is None: - self.log.info('Executing: %s', self.sql) - hook = BigQueryHook( - bigquery_conn_id=self.gcp_conn_id, - use_legacy_sql=self.use_legacy_sql, - delegate_to=self.delegate_to, - location=self.location, - ) - conn = hook.get_conn() - self.bq_cursor = conn.cursor() - if isinstance(self.sql, str): - job_id = self.bq_cursor.run_query( - sql=self.sql, - destination_dataset_table=self.destination_dataset_table, - write_disposition=self.write_disposition, - allow_large_results=self.allow_large_results, - flatten_results=self.flatten_results, - udf_config=self.udf_config, - maximum_billing_tier=self.maximum_billing_tier, - maximum_bytes_billed=self.maximum_bytes_billed, - create_disposition=self.create_disposition, - query_params=self.query_params, - labels=self.labels, - schema_update_options=self.schema_update_options, - priority=self.priority, - time_partitioning=self.time_partitioning, - api_resource_configs=self.api_resource_configs, - cluster_fields=self.cluster_fields, - encryption_configuration=self.encryption_configuration - ) - elif isinstance(self.sql, Iterable): - job_id = [ - self.bq_cursor.run_query( - sql=s, - destination_dataset_table=self.destination_dataset_table, - write_disposition=self.write_disposition, - allow_large_results=self.allow_large_results, - flatten_results=self.flatten_results, - udf_config=self.udf_config, - maximum_billing_tier=self.maximum_billing_tier, - maximum_bytes_billed=self.maximum_bytes_billed, - create_disposition=self.create_disposition, - query_params=self.query_params, - labels=self.labels, - schema_update_options=self.schema_update_options, - priority=self.priority, - time_partitioning=self.time_partitioning, - api_resource_configs=self.api_resource_configs, - cluster_fields=self.cluster_fields, - encryption_configuration=self.encryption_configuration - ) - for s in self.sql] - else: - raise AirflowException( - "argument 'sql' of type {} is neither a string nor an iterable".format(type(str))) - context['task_instance'].xcom_push(key='job_id', value=job_id) - - def on_kill(self): - super().on_kill() - if self.bq_cursor is not None: - self.log.info('Cancelling running query') - self.bq_cursor.cancel_query() - - -class BigQueryCreateEmptyTableOperator(BaseOperator): - """ - Creates a new, empty table in the specified BigQuery dataset, - optionally with schema. - - The schema to be used for the BigQuery table may be specified in one of - two ways. 
You may either directly pass the schema fields in, or you may - point the operator to a Google cloud storage object name. The object in - Google cloud storage must be a JSON file with the schema fields in it. - You can also create a table without schema. - - :param project_id: The project to create the table into. (templated) - :type project_id: str - :param dataset_id: The dataset to create the table into. (templated) - :type dataset_id: str - :param table_id: The Name of the table to be created. (templated) - :type table_id: str - :param schema_fields: If set, the schema field list as defined here: - https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs#configuration.load.schema - - **Example**: :: - - schema_fields=[{"name": "emp_name", "type": "STRING", "mode": "REQUIRED"}, - {"name": "salary", "type": "INTEGER", "mode": "NULLABLE"}] - - :type schema_fields: list - :param gcs_schema_object: Full path to the JSON file containing - schema (templated). For - example: ``gs://test-bucket/dir1/dir2/employee_schema.json`` - :type gcs_schema_object: str - :param time_partitioning: configure optional time partitioning fields i.e. - partition by field, type and expiration as per API specifications. - - .. seealso:: - https://cloud.google.com/bigquery/docs/reference/rest/v2/tables#timePartitioning - :type time_partitioning: dict - :param bigquery_conn_id: (Optional) The connection ID used to connect to Google Cloud Platform and - interact with the Bigquery service. - :type bigquery_conn_id: str - :param google_cloud_storage_conn_id: (Optional) The connection ID used to connect to Google Cloud - Platform and interact with the Google Cloud Storage service. - :type google_cloud_storage_conn_id: str - :param delegate_to: The account to impersonate, if any. For this to - work, the service account making the request must have domain-wide - delegation enabled. - :type delegate_to: str - :param labels: a dictionary containing labels for the table, passed to BigQuery - - **Example (with schema JSON in GCS)**: :: - - CreateTable = BigQueryCreateEmptyTableOperator( - task_id='BigQueryCreateEmptyTableOperator_task', - dataset_id='ODS', - table_id='Employees', - project_id='internal-gcp-project', - gcs_schema_object='gs://schema-bucket/employee_schema.json', - bigquery_conn_id='airflow-conn-id', - google_cloud_storage_conn_id='airflow-conn-id' - ) - - **Corresponding Schema file** (``employee_schema.json``): :: - - [ - { - "mode": "NULLABLE", - "name": "emp_name", - "type": "STRING" - }, - { - "mode": "REQUIRED", - "name": "salary", - "type": "INTEGER" - } - ] - - **Example (with schema in the DAG)**: :: - - CreateTable = BigQueryCreateEmptyTableOperator( - task_id='BigQueryCreateEmptyTableOperator_task', - dataset_id='ODS', - table_id='Employees', - project_id='internal-gcp-project', - schema_fields=[{"name": "emp_name", "type": "STRING", "mode": "REQUIRED"}, - {"name": "salary", "type": "INTEGER", "mode": "NULLABLE"}], - bigquery_conn_id='airflow-conn-id-account', - google_cloud_storage_conn_id='airflow-conn-id' - ) - :type labels: dict - :param encryption_configuration: [Optional] Custom encryption configuration (e.g., Cloud KMS keys). 
- **Example**: :: - - encryption_configuration = { - "kmsKeyName": "projects/testp/locations/us/keyRings/test-kr/cryptoKeys/test-key" - } - :type encryption_configuration: dict - """ - template_fields = ('dataset_id', 'table_id', 'project_id', - 'gcs_schema_object', 'labels') - ui_color = '#f0eee4' - - # pylint: disable=too-many-arguments - @apply_defaults - def __init__(self, - dataset_id, - table_id, - project_id=None, - schema_fields=None, - gcs_schema_object=None, - time_partitioning=None, - bigquery_conn_id='google_cloud_default', - google_cloud_storage_conn_id='google_cloud_default', - delegate_to=None, - labels=None, - encryption_configuration=None, - *args, **kwargs): - - super().__init__(*args, **kwargs) - - self.project_id = project_id - self.dataset_id = dataset_id - self.table_id = table_id - self.schema_fields = schema_fields - self.gcs_schema_object = gcs_schema_object - self.bigquery_conn_id = bigquery_conn_id - self.google_cloud_storage_conn_id = google_cloud_storage_conn_id - self.delegate_to = delegate_to - self.time_partitioning = {} if time_partitioning is None else time_partitioning - self.labels = labels - self.encryption_configuration = encryption_configuration - - def execute(self, context): - bq_hook = BigQueryHook(bigquery_conn_id=self.bigquery_conn_id, - delegate_to=self.delegate_to) - - if not self.schema_fields and self.gcs_schema_object: - - gcs_bucket, gcs_object = _parse_gcs_url(self.gcs_schema_object) - - gcs_hook = GoogleCloudStorageHook( - google_cloud_storage_conn_id=self.google_cloud_storage_conn_id, - delegate_to=self.delegate_to) - schema_fields = json.loads(gcs_hook.download( - gcs_bucket, - gcs_object).decode("utf-8")) - else: - schema_fields = self.schema_fields - - conn = bq_hook.get_conn() - cursor = conn.cursor() - - cursor.create_empty_table( - project_id=self.project_id, - dataset_id=self.dataset_id, - table_id=self.table_id, - schema_fields=schema_fields, - time_partitioning=self.time_partitioning, - labels=self.labels, - encryption_configuration=self.encryption_configuration - ) - - -# pylint: disable=too-many-instance-attributes -class BigQueryCreateExternalTableOperator(BaseOperator): - """ - Creates a new external table in the dataset with the data in Google Cloud - Storage. - - The schema to be used for the BigQuery table may be specified in one of - two ways. You may either directly pass the schema fields in, or you may - point the operator to a Google cloud storage object name. The object in - Google cloud storage must be a JSON file with the schema fields in it. - - :param bucket: The bucket to point the external table to. (templated) - :type bucket: str - :param source_objects: List of Google cloud storage URIs to point - table to. (templated) - If source_format is 'DATASTORE_BACKUP', the list must only contain a single URI. - :type source_objects: list - :param destination_project_dataset_table: The dotted ``(.).
`` - BigQuery table to load data into (templated). If ```` is not included, - project will be the project defined in the connection json. - :type destination_project_dataset_table: str - :param schema_fields: If set, the schema field list as defined here: - https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs#configuration.load.schema - - **Example**: :: - - schema_fields=[{"name": "emp_name", "type": "STRING", "mode": "REQUIRED"}, - {"name": "salary", "type": "INTEGER", "mode": "NULLABLE"}] - - Should not be set when source_format is 'DATASTORE_BACKUP'. - :type schema_fields: list - :param schema_object: If set, a GCS object path pointing to a .json file that - contains the schema for the table. (templated) - :type schema_object: str - :param source_format: File format of the data. - :type source_format: str - :param compression: [Optional] The compression type of the data source. - Possible values include GZIP and NONE. - The default value is NONE. - This setting is ignored for Google Cloud Bigtable, - Google Cloud Datastore backups and Avro formats. - :type compression: str - :param skip_leading_rows: Number of rows to skip when loading from a CSV. - :type skip_leading_rows: int - :param field_delimiter: The delimiter to use for the CSV. - :type field_delimiter: str - :param max_bad_records: The maximum number of bad records that BigQuery can - ignore when running the job. - :type max_bad_records: int - :param quote_character: The value that is used to quote data sections in a CSV file. - :type quote_character: str - :param allow_quoted_newlines: Whether to allow quoted newlines (true) or not (false). - :type allow_quoted_newlines: bool - :param allow_jagged_rows: Accept rows that are missing trailing optional columns. - The missing values are treated as nulls. If false, records with missing trailing - columns are treated as bad records, and if there are too many bad records, an - invalid error is returned in the job result. Only applicable to CSV, ignored - for other formats. - :type allow_jagged_rows: bool - :param bigquery_conn_id: (Optional) The connection ID used to connect to Google Cloud Platform and - interact with the Bigquery service. - :type bigquery_conn_id: str - :param google_cloud_storage_conn_id: (Optional) The connection ID used to connect to Google Cloud - Platform and interact with the Google Cloud Storage service. - cloud storage hook. - :type google_cloud_storage_conn_id: str - :param delegate_to: The account to impersonate, if any. For this to - work, the service account making the request must have domain-wide - delegation enabled. - :type delegate_to: str - :param src_fmt_configs: configure optional fields specific to the source format - :type src_fmt_configs: dict - :param labels: a dictionary containing labels for the table, passed to BigQuery - :type labels: dict - :param encryption_configuration: [Optional] Custom encryption configuration (e.g., Cloud KMS keys). 
- **Example**: :: - - encryption_configuration = { - "kmsKeyName": "projects/testp/locations/us/keyRings/test-kr/cryptoKeys/test-key" - } - :type encryption_configuration: dict - """ - template_fields = ('bucket', 'source_objects', - 'schema_object', 'destination_project_dataset_table', 'labels') - ui_color = '#f0eee4' - - # pylint: disable=too-many-arguments - @apply_defaults - def __init__(self, - bucket, - source_objects, - destination_project_dataset_table, - schema_fields=None, - schema_object=None, - source_format='CSV', - compression='NONE', - skip_leading_rows=0, - field_delimiter=',', - max_bad_records=0, - quote_character=None, - allow_quoted_newlines=False, - allow_jagged_rows=False, - bigquery_conn_id='google_cloud_default', - google_cloud_storage_conn_id='google_cloud_default', - delegate_to=None, - src_fmt_configs=None, - labels=None, - encryption_configuration=None, - *args, **kwargs): - - super().__init__(*args, **kwargs) - - # GCS config - self.bucket = bucket - self.source_objects = source_objects - self.schema_object = schema_object - - # BQ config - self.destination_project_dataset_table = destination_project_dataset_table - self.schema_fields = schema_fields - self.source_format = source_format - self.compression = compression - self.skip_leading_rows = skip_leading_rows - self.field_delimiter = field_delimiter - self.max_bad_records = max_bad_records - self.quote_character = quote_character - self.allow_quoted_newlines = allow_quoted_newlines - self.allow_jagged_rows = allow_jagged_rows - - self.bigquery_conn_id = bigquery_conn_id - self.google_cloud_storage_conn_id = google_cloud_storage_conn_id - self.delegate_to = delegate_to - - self.src_fmt_configs = src_fmt_configs if src_fmt_configs is not None else dict() - self.labels = labels - self.encryption_configuration = encryption_configuration - - def execute(self, context): - bq_hook = BigQueryHook(bigquery_conn_id=self.bigquery_conn_id, - delegate_to=self.delegate_to) - - if not self.schema_fields and self.schema_object \ - and self.source_format != 'DATASTORE_BACKUP': - gcs_hook = GoogleCloudStorageHook( - google_cloud_storage_conn_id=self.google_cloud_storage_conn_id, - delegate_to=self.delegate_to) - schema_fields = json.loads(gcs_hook.download( - self.bucket, - self.schema_object).decode("utf-8")) - else: - schema_fields = self.schema_fields - - source_uris = ['gs://{}/{}'.format(self.bucket, source_object) - for source_object in self.source_objects] - conn = bq_hook.get_conn() - cursor = conn.cursor() - - cursor.create_external_table( - external_project_dataset_table=self.destination_project_dataset_table, - schema_fields=schema_fields, - source_uris=source_uris, - source_format=self.source_format, - compression=self.compression, - skip_leading_rows=self.skip_leading_rows, - field_delimiter=self.field_delimiter, - max_bad_records=self.max_bad_records, - quote_character=self.quote_character, - allow_quoted_newlines=self.allow_quoted_newlines, - allow_jagged_rows=self.allow_jagged_rows, - src_fmt_configs=self.src_fmt_configs, - labels=self.labels, - encryption_configuration=self.encryption_configuration - ) - - -class BigQueryDeleteDatasetOperator(BaseOperator): - """ - This operator deletes an existing dataset from your Project in Big query. - https://cloud.google.com/bigquery/docs/reference/rest/v2/datasets/delete - - :param project_id: The project id of the dataset. - :type project_id: str - :param dataset_id: The dataset to be deleted. 
- :type dataset_id: str - :param delete_contents: (Optional) Whether to force the deletion even if the dataset is not empty. - Will delete all tables (if any) in the dataset if set to True. - Will raise HttpError 400: "{dataset_id} is still in use" if set to False and dataset is not empty. - The default value is False. - :type delete_contents: bool - :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud Platform. - :type gcp_conn_id: str - :param bigquery_conn_id: (Deprecated) The connection ID used to connect to Google Cloud Platform. - This parameter has been deprecated. You should pass the gcp_conn_id parameter instead. - :type bigquery_conn_id: str - - **Example**: :: - - delete_temp_data = BigQueryDeleteDatasetOperator( - dataset_id='temp-dataset', - project_id='temp-project', - delete_contents=True, # Force the deletion of the dataset as well as its tables (if any). - gcp_conn_id='_my_gcp_conn_', - task_id='Deletetemp', - dag=dag) - """ - - template_fields = ('dataset_id', 'project_id') - ui_color = '#f00004' - - @apply_defaults - def __init__(self, - dataset_id, - project_id=None, - delete_contents=False, - gcp_conn_id='google_cloud_default', - bigquery_conn_id=None, - delegate_to=None, - *args, **kwargs): - - if bigquery_conn_id: - warnings.warn( - "The bigquery_conn_id parameter has been deprecated. You should pass " - "the gcp_conn_id parameter.", DeprecationWarning, stacklevel=3) - gcp_conn_id = bigquery_conn_id - - self.dataset_id = dataset_id - self.project_id = project_id - self.delete_contents = delete_contents - self.gcp_conn_id = gcp_conn_id - self.delegate_to = delegate_to - - self.log.info('Dataset id: %s', self.dataset_id) - self.log.info('Project id: %s', self.project_id) - - super().__init__(*args, **kwargs) - - def execute(self, context): - bq_hook = BigQueryHook(bigquery_conn_id=self.gcp_conn_id, - delegate_to=self.delegate_to) - - conn = bq_hook.get_conn() - cursor = conn.cursor() - - cursor.delete_dataset( - project_id=self.project_id, - dataset_id=self.dataset_id, - delete_contents=self.delete_contents - ) - - -class BigQueryCreateEmptyDatasetOperator(BaseOperator): - """ - This operator is used to create new dataset for your Project in Big query. - https://cloud.google.com/bigquery/docs/reference/rest/v2/datasets#resource - - :param project_id: The name of the project where we want to create the dataset. - Don't need to provide, if projectId in dataset_reference. - :type project_id: str - :param dataset_id: The id of dataset. Don't need to provide, - if datasetId in dataset_reference. - :type dataset_id: str - :param dataset_reference: Dataset reference that could be provided with request body. - More info: - https://cloud.google.com/bigquery/docs/reference/rest/v2/datasets#resource - :type dataset_reference: dict - :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud Platform. - :type gcp_conn_id: str - :param bigquery_conn_id: (Deprecated) The connection ID used to connect to Google Cloud Platform. - This parameter has been deprecated. You should pass the gcp_conn_id parameter instead. 
- :type bigquery_conn_id: str - **Example**: :: - - create_new_dataset = BigQueryCreateEmptyDatasetOperator( - dataset_id='new-dataset', - project_id='my-project', - dataset_reference={"friendlyName": "New Dataset"} - gcp_conn_id='_my_gcp_conn_', - task_id='newDatasetCreator', - dag=dag) - - """ - - template_fields = ('dataset_id', 'project_id') - ui_color = '#f0eee4' - - @apply_defaults - def __init__(self, - dataset_id, - project_id=None, - dataset_reference=None, - gcp_conn_id='google_cloud_default', - bigquery_conn_id=None, - delegate_to=None, - *args, **kwargs): - - if bigquery_conn_id: - warnings.warn( - "The bigquery_conn_id parameter has been deprecated. You should pass " - "the gcp_conn_id parameter.", DeprecationWarning, stacklevel=3) - gcp_conn_id = bigquery_conn_id - - self.dataset_id = dataset_id - self.project_id = project_id - self.gcp_conn_id = gcp_conn_id - self.dataset_reference = dataset_reference if dataset_reference else {} - self.delegate_to = delegate_to - - self.log.info('Dataset id: %s', self.dataset_id) - self.log.info('Project id: %s', self.project_id) - - super().__init__(*args, **kwargs) - - def execute(self, context): - bq_hook = BigQueryHook(bigquery_conn_id=self.gcp_conn_id, - delegate_to=self.delegate_to) - - conn = bq_hook.get_conn() - cursor = conn.cursor() - - cursor.create_empty_dataset( - project_id=self.project_id, - dataset_id=self.dataset_id, - dataset_reference=self.dataset_reference) - - -class BigQueryGetDatasetOperator(BaseOperator): - """ - This operator is used to return the dataset specified by dataset_id. - - :param dataset_id: The id of dataset. Don't need to provide, - if datasetId in dataset_reference. - :type dataset_id: str - :param project_id: The name of the project where we want to create the dataset. - Don't need to provide, if projectId in dataset_reference. - :type project_id: str - :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud Platform. - :type gcp_conn_id: str - :rtype: dataset - https://cloud.google.com/bigquery/docs/reference/rest/v2/datasets#resource - """ - - template_fields = ('dataset_id', 'project_id') - ui_color = '#f0eee4' - - @apply_defaults - def __init__(self, - dataset_id, - project_id=None, - gcp_conn_id='google_cloud_default', - delegate_to=None, - *args, **kwargs): - self.dataset_id = dataset_id - self.project_id = project_id - self.gcp_conn_id = gcp_conn_id - self.delegate_to = delegate_to - super().__init__(*args, **kwargs) - - def execute(self, context): - bq_hook = BigQueryHook(bigquery_conn_id=self.gcp_conn_id, - delegate_to=self.delegate_to) - conn = bq_hook.get_conn() - cursor = conn.cursor() - - self.log.info('Start getting dataset: %s:%s', self.project_id, self.dataset_id) - - return cursor.get_dataset( - dataset_id=self.dataset_id, - project_id=self.project_id) - - -class BigQueryPatchDatasetOperator(BaseOperator): - """ - This operator is used to patch dataset for your Project in BigQuery. - It only replaces fields that are provided in the submitted dataset resource. - - :param dataset_id: The id of dataset. Don't need to provide, - if datasetId in dataset_reference. - :type dataset_id: str - :param dataset_resource: Dataset resource that will be provided with request body. - https://cloud.google.com/bigquery/docs/reference/rest/v2/datasets#resource - :type dataset_resource: dict - :param project_id: The name of the project where we want to create the dataset. - Don't need to provide, if projectId in dataset_reference. 
- :type project_id: str - :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud Platform. - :type gcp_conn_id: str - :rtype: dataset - https://cloud.google.com/bigquery/docs/reference/rest/v2/datasets#resource - """ - - template_fields = ('dataset_id', 'project_id') - ui_color = '#f0eee4' - - @apply_defaults - def __init__(self, - dataset_id, - dataset_resource, - project_id=None, - gcp_conn_id='google_cloud_default', - delegate_to=None, - *args, **kwargs): - self.dataset_id = dataset_id - self.project_id = project_id - self.gcp_conn_id = gcp_conn_id - self.dataset_resource = dataset_resource - self.delegate_to = delegate_to - super().__init__(*args, **kwargs) - - def execute(self, context): - bq_hook = BigQueryHook(bigquery_conn_id=self.gcp_conn_id, - delegate_to=self.delegate_to) - - conn = bq_hook.get_conn() - cursor = conn.cursor() - - self.log.info('Start patching dataset: %s:%s', self.project_id, self.dataset_id) - - return cursor.patch_dataset( - dataset_id=self.dataset_id, - dataset_resource=self.dataset_resource, - project_id=self.project_id) - - -class BigQueryUpdateDatasetOperator(BaseOperator): - """ - This operator is used to update dataset for your Project in BigQuery. - The update method replaces the entire dataset resource, whereas the patch - method only replaces fields that are provided in the submitted dataset resource. - - :param dataset_id: The id of dataset. Don't need to provide, - if datasetId in dataset_reference. - :type dataset_id: str - :param dataset_resource: Dataset resource that will be provided with request body. - https://cloud.google.com/bigquery/docs/reference/rest/v2/datasets#resource - :type dataset_resource: dict - :param project_id: The name of the project where we want to create the dataset. - Don't need to provide, if projectId in dataset_reference. - :type project_id: str - :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud Platform. - :type gcp_conn_id: str - :rtype: dataset - https://cloud.google.com/bigquery/docs/reference/rest/v2/datasets#resource - """ - - template_fields = ('dataset_id', 'project_id') - ui_color = '#f0eee4' - - @apply_defaults - def __init__(self, - dataset_id, - dataset_resource, - project_id=None, - gcp_conn_id='google_cloud_default', - delegate_to=None, - *args, **kwargs): - self.dataset_id = dataset_id - self.project_id = project_id - self.gcp_conn_id = gcp_conn_id - self.dataset_resource = dataset_resource - self.delegate_to = delegate_to - super().__init__(*args, **kwargs) - - def execute(self, context): - bq_hook = BigQueryHook(bigquery_conn_id=self.gcp_conn_id, - delegate_to=self.delegate_to) - - conn = bq_hook.get_conn() - cursor = conn.cursor() - - self.log.info('Start updating dataset: %s:%s', self.project_id, self.dataset_id) - - return cursor.update_dataset( - dataset_id=self.dataset_id, - dataset_resource=self.dataset_resource, - project_id=self.project_id) +# pylint: disable=unused-import +from airflow.gcp.operators.bigquery import ( # noqa + BigQueryConsoleLink, + BigQueryConsoleIndexableLink, + BigQueryOperator, + BigQueryCreateEmptyTableOperator, + BigQueryCreateExternalTableOperator, + BigQueryDeleteDatasetOperator, + BigQueryCreateEmptyDatasetOperator, + BigQueryGetDatasetOperator, + BigQueryPatchDatasetOperator, + BigQueryUpdateDatasetOperator, +) + +warnings.warn( + "This module is deprecated. 
Please use `airflow.gcp.operators.bigquery`.", + DeprecationWarning, +) diff --git a/airflow/contrib/operators/bigquery_table_delete_operator.py b/airflow/contrib/operators/bigquery_table_delete_operator.py index 2c176e6f1305c2..3ed1e28e868b51 100644 --- a/airflow/contrib/operators/bigquery_table_delete_operator.py +++ b/airflow/contrib/operators/bigquery_table_delete_operator.py @@ -16,68 +16,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -""" -This module contains Google BigQuery table delete operator. -""" -import warnings - -from airflow.contrib.hooks.bigquery_hook import BigQueryHook -from airflow.models import BaseOperator -from airflow.utils.decorators import apply_defaults - - -class BigQueryTableDeleteOperator(BaseOperator): - """ - Deletes BigQuery tables +"""This module is deprecated. Please use `airflow.gcp.operators.bigquery`.""" - :param deletion_dataset_table: A dotted - ``(.|:).
`` that indicates which table - will be deleted. (templated) - :type deletion_dataset_table: str - :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud Platform. - :type gcp_conn_id: str - :param bigquery_conn_id: (Deprecated) The connection ID used to connect to Google Cloud Platform. - This parameter has been deprecated. You should pass the gcp_conn_id parameter instead. - :type bigquery_conn_id: str - :param delegate_to: The account to impersonate, if any. - For this to work, the service account making the request must have domain-wide - delegation enabled. - :type delegate_to: str - :param ignore_if_missing: if True, then return success even if the - requested table does not exist. - :type ignore_if_missing: bool - """ - template_fields = ('deletion_dataset_table',) - ui_color = '#ffd1dc' - - @apply_defaults - def __init__(self, - deletion_dataset_table, - gcp_conn_id='google_cloud_default', - bigquery_conn_id=None, - delegate_to=None, - ignore_if_missing=False, - *args, - **kwargs): - super().__init__(*args, **kwargs) - - if bigquery_conn_id: - warnings.warn( - "The bigquery_conn_id parameter has been deprecated. You should pass " - "the gcp_conn_id parameter.", DeprecationWarning, stacklevel=3) - gcp_conn_id = bigquery_conn_id +import warnings - self.deletion_dataset_table = deletion_dataset_table - self.gcp_conn_id = gcp_conn_id - self.delegate_to = delegate_to - self.ignore_if_missing = ignore_if_missing +# pylint: disable=unused-import +from airflow.gcp.operators.bigquery import BigQueryTableDeleteOperator # noqa - def execute(self, context): - self.log.info('Deleting: %s', self.deletion_dataset_table) - hook = BigQueryHook(bigquery_conn_id=self.gcp_conn_id, - delegate_to=self.delegate_to) - conn = hook.get_conn() - cursor = conn.cursor() - cursor.run_table_delete( - deletion_dataset_table=self.deletion_dataset_table, - ignore_if_missing=self.ignore_if_missing) +warnings.warn( + "This module is deprecated. Please use `airflow.gcp.operators.bigquery`.", + DeprecationWarning, +) diff --git a/airflow/contrib/operators/bigquery_to_bigquery.py b/airflow/contrib/operators/bigquery_to_bigquery.py index 1b5c5f0606d605..c92c63b693cef1 100644 --- a/airflow/contrib/operators/bigquery_to_bigquery.py +++ b/airflow/contrib/operators/bigquery_to_bigquery.py @@ -16,105 +16,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -""" -This module contains a Google BigQuery to BigQuery operator. -""" -import warnings - -from airflow.contrib.hooks.bigquery_hook import BigQueryHook -from airflow.models import BaseOperator -from airflow.utils.decorators import apply_defaults - - -class BigQueryToBigQueryOperator(BaseOperator): - """ - Copies data from one BigQuery table to another. - - .. seealso:: - For more details about these parameters: - https://cloud.google.com/bigquery/docs/reference/v2/jobs#configuration.copy +"""This module is deprecated. Please use `airflow.operators.bigquery_to_bigquery`.""" - :param source_project_dataset_tables: One or more - dotted ``(project:|project.).
`` BigQuery tables to use as the - source data. If ```` is not included, project will be the - project defined in the connection json. Use a list if there are multiple - source tables. (templated) - :type source_project_dataset_tables: list|string - :param destination_project_dataset_table: The destination BigQuery - table. Format is: ``(project:|project.).
`` (templated) - :type destination_project_dataset_table: str - :param write_disposition: The write disposition if the table already exists. - :type write_disposition: str - :param create_disposition: The create disposition if the table doesn't exist. - :type create_disposition: str - :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud Platform. - :type gcp_conn_id: str - :param bigquery_conn_id: (Deprecated) The connection ID used to connect to Google Cloud Platform. - This parameter has been deprecated. You should pass the gcp_conn_id parameter instead. - :type bigquery_conn_id: str - :param delegate_to: The account to impersonate, if any. - For this to work, the service account making the request must have domain-wide - delegation enabled. - :type delegate_to: str - :param labels: a dictionary containing labels for the job/query, - passed to BigQuery - :type labels: dict - :param encryption_configuration: [Optional] Custom encryption configuration (e.g., Cloud KMS keys). - **Example**: :: - - encryption_configuration = { - "kmsKeyName": "projects/testp/locations/us/keyRings/test-kr/cryptoKeys/test-key" - } - :type encryption_configuration: dict - """ - template_fields = ('source_project_dataset_tables', - 'destination_project_dataset_table', 'labels') - template_ext = ('.sql',) - ui_color = '#e6f0e4' - - @apply_defaults - def __init__(self, - source_project_dataset_tables, - destination_project_dataset_table, - write_disposition='WRITE_EMPTY', - create_disposition='CREATE_IF_NEEDED', - gcp_conn_id='google_cloud_default', - bigquery_conn_id=None, - delegate_to=None, - labels=None, - encryption_configuration=None, - *args, - **kwargs): - super().__init__(*args, **kwargs) - - if bigquery_conn_id: - warnings.warn( - "The bigquery_conn_id parameter has been deprecated. You should pass " - "the gcp_conn_id parameter.", DeprecationWarning, stacklevel=3) - gcp_conn_id = bigquery_conn_id +import warnings - self.source_project_dataset_tables = source_project_dataset_tables - self.destination_project_dataset_table = destination_project_dataset_table - self.write_disposition = write_disposition - self.create_disposition = create_disposition - self.gcp_conn_id = gcp_conn_id - self.delegate_to = delegate_to - self.labels = labels - self.encryption_configuration = encryption_configuration +# pylint: disable=unused-import +from airflow.operators.bigquery_to_bigquery import BigQueryToBigQueryOperator # noqa - def execute(self, context): - self.log.info( - 'Executing copy of %s into: %s', - self.source_project_dataset_tables, self.destination_project_dataset_table - ) - hook = BigQueryHook(bigquery_conn_id=self.gcp_conn_id, - delegate_to=self.delegate_to) - conn = hook.get_conn() - cursor = conn.cursor() - cursor.run_copy( - source_project_dataset_tables=self.source_project_dataset_tables, - destination_project_dataset_table=self.destination_project_dataset_table, - write_disposition=self.write_disposition, - create_disposition=self.create_disposition, - labels=self.labels, - encryption_configuration=self.encryption_configuration) +warnings.warn( + "This module is deprecated. Please use `airflow.operators.bigquery_to_bigquery`.", + DeprecationWarning, +) diff --git a/airflow/contrib/operators/bigquery_to_gcs.py b/airflow/contrib/operators/bigquery_to_gcs.py index bc3f333b0155ff..dd352ac8de0c9c 100644 --- a/airflow/contrib/operators/bigquery_to_gcs.py +++ b/airflow/contrib/operators/bigquery_to_gcs.py @@ -16,105 +16,14 @@ # KIND, either express or implied. 
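# Illustrative migration sketch for the relocated transfer operator -- not taken
# from this patch; the table and task names are placeholders. The contrib module
# above still re-exports BigQueryToBigQueryOperator (with a DeprecationWarning),
# while new DAG code should import it from the core operators package:
from airflow.operators.bigquery_to_bigquery import BigQueryToBigQueryOperator

copy_table = BigQueryToBigQueryOperator(
    task_id="copy_table",
    source_project_dataset_tables="my_project.my_dataset.src_table",
    destination_project_dataset_table="my_project.my_dataset.dst_table",
    gcp_conn_id="google_cloud_default",  # replaces the deprecated bigquery_conn_id
)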
See the License for the # specific language governing permissions and limitations # under the License. -""" -This module contains a Google BigQuery to GCS operator. -""" -import warnings - -from airflow.contrib.hooks.bigquery_hook import BigQueryHook -from airflow.models import BaseOperator -from airflow.utils.decorators import apply_defaults - - -class BigQueryToCloudStorageOperator(BaseOperator): - """ - Transfers a BigQuery table to a Google Cloud Storage bucket. +"""This module is deprecated. Please use `airflow.operators.bigquery_to_gcs`.""" - .. seealso:: - For more details about these parameters: - https://cloud.google.com/bigquery/docs/reference/v2/jobs - - :param source_project_dataset_table: The dotted - ``(.|:).
`` BigQuery table to use as the - source data. If ```` is not included, project will be the project - defined in the connection json. (templated) - :type source_project_dataset_table: str - :param destination_cloud_storage_uris: The destination Google Cloud - Storage URI (e.g. gs://some-bucket/some-file.txt). (templated) Follows - convention defined here: - https://cloud.google.com/bigquery/exporting-data-from-bigquery#exportingmultiple - :type destination_cloud_storage_uris: list - :param compression: Type of compression to use. - :type compression: str - :param export_format: File format to export. - :type export_format: str - :param field_delimiter: The delimiter to use when extracting to a CSV. - :type field_delimiter: str - :param print_header: Whether to print a header for a CSV file extract. - :type print_header: bool - :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud Platform. - :type gcp_conn_id: str - :param bigquery_conn_id: (Deprecated) The connection ID used to connect to Google Cloud Platform. - This parameter has been deprecated. You should pass the gcp_conn_id parameter instead. - :type bigquery_conn_id: str - :param delegate_to: The account to impersonate, if any. - For this to work, the service account making the request must have domain-wide - delegation enabled. - :type delegate_to: str - :param labels: a dictionary containing labels for the job/query, - passed to BigQuery - :type labels: dict - """ - template_fields = ('source_project_dataset_table', - 'destination_cloud_storage_uris', 'labels') - template_ext = () - ui_color = '#e4e6f0' - - @apply_defaults - def __init__(self, # pylint: disable=too-many-arguments - source_project_dataset_table, - destination_cloud_storage_uris, - compression='NONE', - export_format='CSV', - field_delimiter=',', - print_header=True, - gcp_conn_id='google_cloud_default', - bigquery_conn_id=None, - delegate_to=None, - labels=None, - *args, - **kwargs): - super().__init__(*args, **kwargs) - - if bigquery_conn_id: - warnings.warn( - "The bigquery_conn_id parameter has been deprecated. You should pass " - "the gcp_conn_id parameter.", DeprecationWarning, stacklevel=3) - gcp_conn_id = bigquery_conn_id +import warnings - self.source_project_dataset_table = source_project_dataset_table - self.destination_cloud_storage_uris = destination_cloud_storage_uris - self.compression = compression - self.export_format = export_format - self.field_delimiter = field_delimiter - self.print_header = print_header - self.gcp_conn_id = gcp_conn_id - self.delegate_to = delegate_to - self.labels = labels +# pylint: disable=unused-import +from airflow.operators.bigquery_to_gcs import BigQueryToCloudStorageOperator # noqa - def execute(self, context): - self.log.info('Executing extract of %s into: %s', - self.source_project_dataset_table, - self.destination_cloud_storage_uris) - hook = BigQueryHook(bigquery_conn_id=self.gcp_conn_id, - delegate_to=self.delegate_to) - conn = hook.get_conn() - cursor = conn.cursor() - cursor.run_extract( - source_project_dataset_table=self.source_project_dataset_table, - destination_cloud_storage_uris=self.destination_cloud_storage_uris, - compression=self.compression, - export_format=self.export_format, - field_delimiter=self.field_delimiter, - print_header=self.print_header, - labels=self.labels) +warnings.warn( + "This module is deprecated. 
Please use `airflow.operators.bigquery_to_gcs`.", + DeprecationWarning, +) diff --git a/airflow/contrib/operators/bigquery_to_mysql_operator.py b/airflow/contrib/operators/bigquery_to_mysql_operator.py index b3ea23246e5d7e..705615c88085e7 100644 --- a/airflow/contrib/operators/bigquery_to_mysql_operator.py +++ b/airflow/contrib/operators/bigquery_to_mysql_operator.py @@ -16,130 +16,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -""" -This module contains a Google BigQuery to MySQL operator. -""" +"""This module is deprecated. Please use `airflow.operators.bigquery_to_mysql`.""" -from airflow.contrib.hooks.bigquery_hook import BigQueryHook -from airflow.models import BaseOperator -from airflow.utils.decorators import apply_defaults -from airflow.hooks.mysql_hook import MySqlHook +import warnings +# pylint: disable=unused-import +from airflow.operators.bigquery_to_mysql import BigQueryToMySqlOperator # noqa -class BigQueryToMySqlOperator(BaseOperator): - """ - Fetches the data from a BigQuery table (alternatively fetch data for selected columns) - and insert that data into a MySQL table. - - - .. note:: - If you pass fields to ``selected_fields`` which are in different order than the - order of columns already in - BQ table, the data will still be in the order of BQ table. - For example if the BQ table has 3 columns as - ``[A,B,C]`` and you pass 'B,A' in the ``selected_fields`` - the data would still be of the form ``'A,B'`` and passed through this form - to MySQL - - **Example**: :: - - transfer_data = BigQueryToMySqlOperator( - task_id='task_id', - dataset_table='origin_bq_table', - mysql_table='dest_table_name', - replace=True, - ) - - :param dataset_table: A dotted ``.
``: the big query table of origin - :type dataset_table: str - :param max_results: The maximum number of records (rows) to be fetched - from the table. (templated) - :type max_results: str - :param selected_fields: List of fields to return (comma-separated). If - unspecified, all fields are returned. - :type selected_fields: str - :param gcp_conn_id: reference to a specific GCP hook. - :type gcp_conn_id: str - :param delegate_to: The account to impersonate, if any. - For this to work, the service account making the request must have domain-wide - delegation enabled. - :type delegate_to: str - :param mysql_conn_id: reference to a specific mysql hook - :type mysql_conn_id: str - :param database: name of database which overwrite defined one in connection - :type database: str - :param replace: Whether to replace instead of insert - :type replace: bool - :param batch_size: The number of rows to take in each batch - :type batch_size: int - """ - template_fields = ('dataset_id', 'table_id', 'mysql_table') - - @apply_defaults - def __init__(self, - dataset_table, - mysql_table, - selected_fields=None, - gcp_conn_id='google_cloud_default', - mysql_conn_id='mysql_default', - database=None, - delegate_to=None, - replace=False, - batch_size=1000, - *args, - **kwargs): - super(BigQueryToMySqlOperator, self).__init__(*args, **kwargs) - self.selected_fields = selected_fields - self.gcp_conn_id = gcp_conn_id - self.mysql_conn_id = mysql_conn_id - self.database = database - self.mysql_table = mysql_table - self.replace = replace - self.delegate_to = delegate_to - self.batch_size = batch_size - try: - self.dataset_id, self.table_id = dataset_table.split('.') - except ValueError: - raise ValueError('Could not parse {} as .
' - .format(dataset_table)) - - def _bq_get_data(self): - self.log.info('Fetching Data from:') - self.log.info('Dataset: %s ; Table: %s', - self.dataset_id, self.table_id) - - hook = BigQueryHook(bigquery_conn_id=self.gcp_conn_id, - delegate_to=self.delegate_to) - - conn = hook.get_conn() - cursor = conn.cursor() - i = 0 - while True: - response = cursor.get_tabledata(dataset_id=self.dataset_id, - table_id=self.table_id, - max_results=self.batch_size, - selected_fields=self.selected_fields, - start_index=i * self.batch_size) - - if 'rows' in response: - rows = response['rows'] - else: - self.log.info('Job Finished') - return - - self.log.info('Total Extracted rows: %s', len(rows) + i * self.batch_size) - - table_data = [] - for dict_row in rows: - single_row = [] - for fields in dict_row['f']: - single_row.append(fields['v']) - table_data.append(single_row) - - yield table_data - i += 1 - - def execute(self, context): - mysql_hook = MySqlHook(schema=self.database, mysql_conn_id=self.mysql_conn_id) - for rows in self._bq_get_data(): - mysql_hook.insert_rows(self.mysql_table, rows, replace=self.replace) +warnings.warn( + "This module is deprecated. Please use `airflow.operators.bigquery_to_mysql`.", + DeprecationWarning, +) diff --git a/airflow/contrib/operators/cassandra_to_gcs.py b/airflow/contrib/operators/cassandra_to_gcs.py index a941675b0ae5f7..6c3981127916e8 100644 --- a/airflow/contrib/operators/cassandra_to_gcs.py +++ b/airflow/contrib/operators/cassandra_to_gcs.py @@ -17,349 +17,15 @@ # specific language governing permissions and limitations # under the License. """ -This module contains operator for copying -data from Cassandra to Google cloud storage in JSON format. +This module is deprecated. Please use `airflow.operators.cassandra_to_gcs`. """ -import json -import warnings -from base64 import b64encode -from datetime import datetime -from decimal import Decimal -from tempfile import NamedTemporaryFile -from uuid import UUID - -from cassandra.util import Date, Time, SortedSet, OrderedMapSerializedKey - -from airflow.contrib.hooks.gcs_hook import GoogleCloudStorageHook -from airflow.contrib.hooks.cassandra_hook import CassandraHook -from airflow.exceptions import AirflowException -from airflow.models import BaseOperator -from airflow.utils.decorators import apply_defaults - - -class CassandraToGoogleCloudStorageOperator(BaseOperator): - """ - Copy data from Cassandra to Google cloud storage in JSON format - - Note: Arrays of arrays are not supported. - - :param cql: The CQL to execute on the Cassandra table. - :type cql: str - :param bucket: The bucket to upload to. - :type bucket: str - :param filename: The filename to use as the object name when uploading - to Google cloud storage. A {} should be specified in the filename - to allow the operator to inject file numbers in cases where the - file is split due to size. - :type filename: str - :param schema_filename: If set, the filename to use as the object name - when uploading a .json file containing the BigQuery schema fields - for the table that was dumped from MySQL. - :type schema_filename: str - :param approx_max_file_size_bytes: This operator supports the ability - to split large table dumps into multiple files (see notes in the - filename param docs above). This param allows developers to specify the - file size of the splits. Check https://cloud.google.com/storage/quotas - to see the maximum allowed file size for a single object. 
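# Worked illustration -- not taken from this patch -- of the row flattening done
# by the removed _bq_get_data() / BigQueryGetDataOperator code earlier in this
# diff: BigQuery tabledata responses nest each row under 'f' (fields) and each
# cell under 'v' (value). The response dict below is hand-made sample data.
sample_response = {
    "totalRows": "2",
    "rows": [
        {"f": [{"v": "Tony"}, {"v": "10"}]},
        {"f": [{"v": "Mike"}, {"v": "20"}]},
    ],
}

table_data = []
for dict_row in sample_response["rows"]:
    table_data.append([fields["v"] for fields in dict_row["f"]])

assert table_data == [["Tony", "10"], ["Mike", "20"]]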
- :type approx_max_file_size_bytes: long - :param cassandra_conn_id: Reference to a specific Cassandra hook. - :type cassandra_conn_id: str - :param gzip: Option to compress file for upload - :type gzip: bool - :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud Platform. - :type gcp_conn_id: str - :param google_cloud_storage_conn_id: (Deprecated) The connection ID used to connect to Google Cloud - Platform. This parameter has been deprecated. You should pass the gcp_conn_id parameter instead. - :type google_cloud_storage_conn_id: str - :param delegate_to: The account to impersonate, if any. For this to - work, the service account making the request must have domain-wide - delegation enabled. - :type delegate_to: str - """ - template_fields = ('cql', 'bucket', 'filename', 'schema_filename',) - template_ext = ('.cql',) - ui_color = '#a0e08c' - - @apply_defaults - def __init__(self, - cql, - bucket, - filename, - schema_filename=None, - approx_max_file_size_bytes=1900000000, - gzip=False, - cassandra_conn_id='cassandra_default', - gcp_conn_id='google_cloud_default', - google_cloud_storage_conn_id=None, - delegate_to=None, - *args, - **kwargs): - super().__init__(*args, **kwargs) - - if google_cloud_storage_conn_id: - warnings.warn( - "The google_cloud_storage_conn_id parameter has been deprecated. You should pass " - "the gcp_conn_id parameter.", DeprecationWarning, stacklevel=3) - gcp_conn_id = google_cloud_storage_conn_id - - self.cql = cql - self.bucket = bucket - self.filename = filename - self.schema_filename = schema_filename - self.approx_max_file_size_bytes = approx_max_file_size_bytes - self.cassandra_conn_id = cassandra_conn_id - self.gcp_conn_id = gcp_conn_id - self.delegate_to = delegate_to - self.gzip = gzip - - self.hook = None - - # Default Cassandra to BigQuery type mapping - CQL_TYPE_MAP = { - 'BytesType': 'BYTES', - 'DecimalType': 'FLOAT', - 'UUIDType': 'BYTES', - 'BooleanType': 'BOOL', - 'ByteType': 'INTEGER', - 'AsciiType': 'STRING', - 'FloatType': 'FLOAT', - 'DoubleType': 'FLOAT', - 'LongType': 'INTEGER', - 'Int32Type': 'INTEGER', - 'IntegerType': 'INTEGER', - 'InetAddressType': 'STRING', - 'CounterColumnType': 'INTEGER', - 'DateType': 'TIMESTAMP', - 'SimpleDateType': 'DATE', - 'TimestampType': 'TIMESTAMP', - 'TimeUUIDType': 'BYTES', - 'ShortType': 'INTEGER', - 'TimeType': 'TIME', - 'DurationType': 'INTEGER', - 'UTF8Type': 'STRING', - 'VarcharType': 'STRING', - } - - def execute(self, context): - cursor = self._query_cassandra() - files_to_upload = self._write_local_data_files(cursor) - - # If a schema is set, create a BQ schema JSON file. - if self.schema_filename: - files_to_upload.update(self._write_local_schema_file(cursor)) - - # Flush all files before uploading - for file_handle in files_to_upload.values(): - file_handle.flush() - - self._upload_to_gcs(files_to_upload) - - # Close all temp file handles. - for file_handle in files_to_upload.values(): - file_handle.close() - - # Close all sessions and connection associated with this Cassandra cluster - self.hook.shutdown_cluster() - - def _query_cassandra(self): - """ - Queries cassandra and returns a cursor to the results. - """ - self.hook = CassandraHook(cassandra_conn_id=self.cassandra_conn_id) - session = self.hook.get_conn() - cursor = session.execute(self.cql) - return cursor - - def _write_local_data_files(self, cursor): - """ - Takes a cursor, and writes results to a local file. 
- - :return: A dictionary where keys are filenames to be used as object - names in GCS, and values are file handles to local files that - contain the data for the GCS objects. - """ - file_no = 0 - tmp_file_handle = NamedTemporaryFile(delete=True) - tmp_file_handles = {self.filename.format(file_no): tmp_file_handle} - for row in cursor: - row_dict = self.generate_data_dict(row._fields, row) - s = json.dumps(row_dict).encode('utf-8') - tmp_file_handle.write(s) - - # Append newline to make dumps BigQuery compatible. - tmp_file_handle.write(b'\n') - - if tmp_file_handle.tell() >= self.approx_max_file_size_bytes: - file_no += 1 - tmp_file_handle = NamedTemporaryFile(delete=True) - tmp_file_handles[self.filename.format(file_no)] = tmp_file_handle - return tmp_file_handles - - def _write_local_schema_file(self, cursor): - """ - Takes a cursor, and writes the BigQuery schema for the results to a - local file system. - - :return: A dictionary where key is a filename to be used as an object - name in GCS, and values are file handles to local files that - contains the BigQuery schema fields in .json format. - """ - schema = [] - tmp_schema_file_handle = NamedTemporaryFile(delete=True) - - for name, type in zip(cursor.column_names, cursor.column_types): - schema.append(self.generate_schema_dict(name, type)) - json_serialized_schema = json.dumps(schema).encode('utf-8') - - tmp_schema_file_handle.write(json_serialized_schema) - return {self.schema_filename: tmp_schema_file_handle} - - def _upload_to_gcs(self, files_to_upload): - hook = GoogleCloudStorageHook( - google_cloud_storage_conn_id=self.gcp_conn_id, - delegate_to=self.delegate_to) - for object, tmp_file_handle in files_to_upload.items(): - hook.upload(self.bucket, object, tmp_file_handle.name, 'application/json', self.gzip) - - @classmethod - def generate_data_dict(cls, names, values): - row_dict = {} - for name, value in zip(names, values): - row_dict.update({name: cls.convert_value(name, value)}) - return row_dict - - @classmethod - def convert_value(cls, name, value): - if not value: - return value - elif isinstance(value, (str, int, float, bool, dict)): - return value - elif isinstance(value, bytes): - return b64encode(value).decode('ascii') - elif isinstance(value, UUID): - return b64encode(value.bytes).decode('ascii') - elif isinstance(value, (datetime, Date)): - return str(value) - elif isinstance(value, Decimal): - return float(value) - elif isinstance(value, Time): - return str(value).split('.')[0] - elif isinstance(value, (list, SortedSet)): - return cls.convert_array_types(name, value) - elif hasattr(value, '_fields'): - return cls.convert_user_type(name, value) - elif isinstance(value, tuple): - return cls.convert_tuple_type(name, value) - elif isinstance(value, OrderedMapSerializedKey): - return cls.convert_map_type(name, value) - else: - raise AirflowException('unexpected value: ' + str(value)) - - @classmethod - def convert_array_types(cls, name, value): - return [cls.convert_value(name, nested_value) for nested_value in value] - - @classmethod - def convert_user_type(cls, name, value): - """ - Converts a user type to RECORD that contains n fields, where n is the - number of attributes. Each element in the user type class will be converted to its - corresponding data type in BQ. 
- """ - names = value._fields - values = [cls.convert_value(name, getattr(value, name)) for name in names] - return cls.generate_data_dict(names, values) - - @classmethod - def convert_tuple_type(cls, name, value): - """ - Converts a tuple to RECORD that contains n fields, each will be converted - to its corresponding data type in bq and will be named 'field_', where - index is determined by the order of the tuple elements defined in cassandra. - """ - names = ['field_' + str(i) for i in range(len(value))] - values = [cls.convert_value(name, value) for name, value in zip(names, value)] - return cls.generate_data_dict(names, values) - - @classmethod - def convert_map_type(cls, name, value): - """ - Converts a map to a repeated RECORD that contains two fields: 'key' and 'value', - each will be converted to its corresponding data type in BQ. - """ - converted_map = [] - for k, v in zip(value.keys(), value.values()): - converted_map.append({ - 'key': cls.convert_value('key', k), - 'value': cls.convert_value('value', v) - }) - return converted_map - - @classmethod - def generate_schema_dict(cls, name, type): - field_schema = dict() - field_schema.update({'name': name}) - field_schema.update({'type': cls.get_bq_type(type)}) - field_schema.update({'mode': cls.get_bq_mode(type)}) - fields = cls.get_bq_fields(name, type) - if fields: - field_schema.update({'fields': fields}) - return field_schema - - @classmethod - def get_bq_fields(cls, name, type): - fields = [] - - if not cls.is_simple_type(type): - names, types = [], [] - - if cls.is_array_type(type) and cls.is_record_type(type.subtypes[0]): - names = type.subtypes[0].fieldnames - types = type.subtypes[0].subtypes - elif cls.is_record_type(type): - names = type.fieldnames - types = type.subtypes - - if types and not names and type.cassname == 'TupleType': - names = ['field_' + str(i) for i in range(len(types))] - elif types and not names and type.cassname == 'MapType': - names = ['key', 'value'] - - for name, type in zip(names, types): - field = cls.generate_schema_dict(name, type) - fields.append(field) - - return fields - - @classmethod - def is_simple_type(cls, type): - return type.cassname in CassandraToGoogleCloudStorageOperator.CQL_TYPE_MAP - - @classmethod - def is_array_type(cls, type): - return type.cassname in ['ListType', 'SetType'] - - @classmethod - def is_record_type(cls, type): - return type.cassname in ['UserType', 'TupleType', 'MapType'] +import warnings - @classmethod - def get_bq_type(cls, type): - if cls.is_simple_type(type): - return CassandraToGoogleCloudStorageOperator.CQL_TYPE_MAP[type.cassname] - elif cls.is_record_type(type): - return 'RECORD' - elif cls.is_array_type(type): - return cls.get_bq_type(type.subtypes[0]) - else: - raise AirflowException('Not a supported type: ' + type.cassname) +# pylint: disable=unused-import +from airflow.operators.cassandra_to_gcs import CassandraToGoogleCloudStorageOperator # noqa - @classmethod - def get_bq_mode(cls, type): - if cls.is_array_type(type) or type.cassname == 'MapType': - return 'REPEATED' - elif cls.is_record_type(type) or cls.is_simple_type(type): - return 'NULLABLE' - else: - raise AirflowException('Not a supported type: ' + type.cassname) +warnings.warn( + "This module is deprecated. 
Please use `airflow.operators.cassandra_to_gcs`.", + DeprecationWarning, +) diff --git a/airflow/contrib/operators/docker_swarm_operator.py b/airflow/contrib/operators/docker_swarm_operator.py index 9ac3057a75faca..2f63ccd120f0b8 100644 --- a/airflow/contrib/operators/docker_swarm_operator.py +++ b/airflow/contrib/operators/docker_swarm_operator.py @@ -1,8 +1,4 @@ -''' -Run ephemeral Docker Swarm services -''' # -*- coding: utf-8 -*- -# # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -19,6 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +"""Run ephemeral Docker Swarm services""" from docker import types diff --git a/airflow/contrib/operators/gcs_acl_operator.py b/airflow/contrib/operators/gcs_acl_operator.py index 2c35a74a95ee5c..5d30cf20cb4775 100644 --- a/airflow/contrib/operators/gcs_acl_operator.py +++ b/airflow/contrib/operators/gcs_acl_operator.py @@ -17,151 +17,18 @@ # specific language governing permissions and limitations # under the License. """ -This module contains Google Cloud Storage ACL entry operator. +This module is deprecated. Please use `airflow.gcp.operators.gcs`. """ -import warnings - -from airflow.contrib.hooks.gcs_hook import GoogleCloudStorageHook -from airflow.models import BaseOperator -from airflow.utils.decorators import apply_defaults - - -class GoogleCloudStorageBucketCreateAclEntryOperator(BaseOperator): - """ - Creates a new ACL entry on the specified bucket. - - .. seealso:: - For more information on how to use this operator, take a look at the guide: - :ref:`howto/operator:GoogleCloudStorageBucketCreateAclEntryOperator` - - :param bucket: Name of a bucket. - :type bucket: str - :param entity: The entity holding the permission, in one of the following forms: - user-userId, user-email, group-groupId, group-email, domain-domain, - project-team-projectId, allUsers, allAuthenticatedUsers - :type entity: str - :param role: The access permission for the entity. - Acceptable values are: "OWNER", "READER", "WRITER". - :type role: str - :param user_project: (Optional) The project to be billed for this request. - Required for Requester Pays buckets. - :type user_project: str - :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud Platform. - :type gcp_conn_id: str - :param google_cloud_storage_conn_id: (Deprecated) The connection ID used to connect to Google Cloud - Platform. This parameter has been deprecated. You should pass the gcp_conn_id parameter instead. - :type google_cloud_storage_conn_id: str - """ - # [START gcs_bucket_create_acl_template_fields] - template_fields = ('bucket', 'entity', 'role', 'user_project') - # [END gcs_bucket_create_acl_template_fields] - - @apply_defaults - def __init__( - self, - bucket, - entity, - role, - user_project=None, - gcp_conn_id='google_cloud_default', - google_cloud_storage_conn_id=None, - *args, - **kwargs - ): - super().__init__(*args, - **kwargs) - - if google_cloud_storage_conn_id: - warnings.warn( - "The google_cloud_storage_conn_id parameter has been deprecated. 
You should pass " - "the gcp_conn_id parameter.", DeprecationWarning, stacklevel=3) - gcp_conn_id = google_cloud_storage_conn_id - self.bucket = bucket - self.entity = entity - self.role = role - self.user_project = user_project - self.gcp_conn_id = gcp_conn_id - - def execute(self, context): - hook = GoogleCloudStorageHook( - google_cloud_storage_conn_id=self.gcp_conn_id - ) - hook.insert_bucket_acl(bucket_name=self.bucket, entity=self.entity, role=self.role, - user_project=self.user_project) - - -class GoogleCloudStorageObjectCreateAclEntryOperator(BaseOperator): - """ - Creates a new ACL entry on the specified object. - - .. seealso:: - For more information on how to use this operator, take a look at the guide: - :ref:`howto/operator:GoogleCloudStorageObjectCreateAclEntryOperator` - - :param bucket: Name of a bucket. - :type bucket: str - :param object_name: Name of the object. For information about how to URL encode object - names to be path safe, see: - https://cloud.google.com/storage/docs/json_api/#encoding - :type object_name: str - :param entity: The entity holding the permission, in one of the following forms: - user-userId, user-email, group-groupId, group-email, domain-domain, - project-team-projectId, allUsers, allAuthenticatedUsers - :type entity: str - :param role: The access permission for the entity. - Acceptable values are: "OWNER", "READER". - :type role: str - :param generation: Optional. If present, selects a specific revision of this object. - :type generation: long - :param user_project: (Optional) The project to be billed for this request. - Required for Requester Pays buckets. - :type user_project: str - :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud Platform. - :type gcp_conn_id: str - :param google_cloud_storage_conn_id: (Deprecated) The connection ID used to connect to Google Cloud - Platform. This parameter has been deprecated. You should pass the gcp_conn_id parameter instead. - :type google_cloud_storage_conn_id: str - """ - # [START gcs_object_create_acl_template_fields] - template_fields = ('bucket', 'object_name', 'entity', 'generation', 'role', 'user_project') - # [END gcs_object_create_acl_template_fields] - - @apply_defaults - def __init__(self, - bucket, - object_name, - entity, - role, - generation=None, - user_project=None, - gcp_conn_id='google_cloud_default', - google_cloud_storage_conn_id=None, - *args, **kwargs): - super().__init__(*args, - **kwargs) - - if google_cloud_storage_conn_id: - warnings.warn( - "The google_cloud_storage_conn_id parameter has been deprecated. You should pass " - "the gcp_conn_id parameter.", DeprecationWarning, stacklevel=3) - gcp_conn_id = google_cloud_storage_conn_id +import warnings - self.bucket = bucket - self.object_name = object_name - self.entity = entity - self.role = role - self.generation = generation - self.user_project = user_project - self.gcp_conn_id = gcp_conn_id +# pylint: disable=unused-import +from airflow.gcp.operators.gcs import ( # noqa + GoogleCloudStorageObjectCreateAclEntryOperator, + GoogleCloudStorageBucketCreateAclEntryOperator +) - def execute(self, context): - hook = GoogleCloudStorageHook( - google_cloud_storage_conn_id=self.gcp_conn_id - ) - hook.insert_object_acl(bucket_name=self.bucket, - object_name=self.object_name, - entity=self.entity, - role=self.role, - generation=self.generation, - user_project=self.user_project) +warnings.warn( + "This module is deprecated. 
Please use `airflow.gcp.operators.gcs`.", + DeprecationWarning, +) diff --git a/airflow/contrib/operators/gcs_delete_operator.py b/airflow/contrib/operators/gcs_delete_operator.py index 183350366fc81b..4af605e38c24fb 100644 --- a/airflow/contrib/operators/gcs_delete_operator.py +++ b/airflow/contrib/operators/gcs_delete_operator.py @@ -17,82 +17,15 @@ # specific language governing permissions and limitations # under the License. """ -This module contains Google Cloud Storage delete operator. +This module is deprecated. Please use `airflow.gcp.operators.gcs`. """ -import warnings -from typing import Optional, Iterable - -from airflow.contrib.hooks.gcs_hook import GoogleCloudStorageHook -from airflow.models import BaseOperator -from airflow.utils.decorators import apply_defaults - - -class GoogleCloudStorageDeleteOperator(BaseOperator): - """ - Deletes objects from a Google Cloud Storage bucket, either - from an explicit list of object names or all objects - matching a prefix. - - :param bucket_name: The GCS bucket to delete from - :type bucket_name: str - :param objects: List of objects to delete. These should be the names - of objects in the bucket, not including gs://bucket/ - :type objects: Iterable[str] - :param prefix: Prefix of objects to delete. All objects matching this - prefix in the bucket will be deleted. - :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud Platform. - :type gcp_conn_id: str - :param google_cloud_storage_conn_id: (Deprecated) The connection ID used to connect to Google Cloud - Platform. This parameter has been deprecated. You should pass the gcp_conn_id parameter instead. - :type google_cloud_storage_conn_id: str - :param delegate_to: The account to impersonate, if any. - For this to work, the service account making the request must have - domain-wide delegation enabled. - :type delegate_to: str - """ - - template_fields = ('bucket_name', 'prefix', 'objects') - @apply_defaults - def __init__(self, - bucket_name: str, - objects: Optional[Iterable[str]] = None, - prefix: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - google_cloud_storage_conn_id: Optional[str] = None, - delegate_to: Optional[str] = None, - *args, **kwargs): - - if google_cloud_storage_conn_id: - warnings.warn( - "The google_cloud_storage_conn_id parameter has been deprecated. You should pass " - "the gcp_conn_id parameter.", DeprecationWarning, stacklevel=3) - gcp_conn_id = google_cloud_storage_conn_id - - self.bucket_name = bucket_name - self.objects = objects - self.prefix = prefix - self.gcp_conn_id = gcp_conn_id - self.delegate_to = delegate_to - - assert objects is not None or prefix is not None - - super().__init__(*args, **kwargs) - - def execute(self, context): - hook = GoogleCloudStorageHook( - google_cloud_storage_conn_id=self.gcp_conn_id, - delegate_to=self.delegate_to - ) +import warnings - if self.objects: - objects = self.objects - else: - objects = hook.list(bucket_name=self.bucket_name, - prefix=self.prefix) +# pylint: disable=unused-import +from airflow.gcp.operators.gcs import GoogleCloudStorageDeleteOperator # noqa - self.log.info("Deleting %s objects from %s", - len(objects), self.bucket_name) - for object_name in objects: - hook.delete(bucket_name=self.bucket_name, - object_name=object_name) +warnings.warn( + "This module is deprecated. 
Please use `airflow.gcp.operators.gcs`.", + DeprecationWarning, +) diff --git a/airflow/contrib/operators/gcs_download_operator.py b/airflow/contrib/operators/gcs_download_operator.py index 1388a1178413b2..eff5957dd0000b 100644 --- a/airflow/contrib/operators/gcs_download_operator.py +++ b/airflow/contrib/operators/gcs_download_operator.py @@ -17,98 +17,15 @@ # specific language governing permissions and limitations # under the License. """ -This module contains Google Cloud Storage download operator. +This module is deprecated. Please use `airflow.gcp.operators.gcs`. """ -import sys import warnings -from airflow.contrib.hooks.gcs_hook import GoogleCloudStorageHook -from airflow.models import BaseOperator -from airflow.models.xcom import MAX_XCOM_SIZE -from airflow.utils.decorators import apply_defaults -from airflow import AirflowException +# pylint: disable=unused-import +from airflow.gcp.operators.gcs import GoogleCloudStorageDownloadOperator # noqa - -class GoogleCloudStorageDownloadOperator(BaseOperator): - """ - Downloads a file from Google Cloud Storage. - - :param bucket: The Google cloud storage bucket where the object is. (templated) - :type bucket: str - :param object: The name of the object to download in the Google cloud - storage bucket. (templated) - :type object: str - :param filename: The file path on the local file system (where the - operator is being executed) that the file should be downloaded to. (templated) - If no filename passed, the downloaded data will not be stored on the local file - system. - :type filename: str - :param store_to_xcom_key: If this param is set, the operator will push - the contents of the downloaded file to XCom with the key set in this - parameter. If not set, the downloaded data will not be pushed to XCom. (templated) - :type store_to_xcom_key: str - :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud Platform. - :type gcp_conn_id: str - :param google_cloud_storage_conn_id: (Deprecated) The connection ID used to connect to Google Cloud - Platform. This parameter has been deprecated. You should pass the gcp_conn_id parameter instead. - :type google_cloud_storage_conn_id: str - :param delegate_to: The account to impersonate, if any. - For this to work, the service account making the request must have - domain-wide delegation enabled. - :type delegate_to: str - """ - template_fields = ('bucket', 'object', 'filename', 'store_to_xcom_key',) - ui_color = '#f0eee4' - - @apply_defaults - def __init__(self, - bucket, - object_name=None, - filename=None, - store_to_xcom_key=None, - gcp_conn_id='google_cloud_default', - google_cloud_storage_conn_id=None, - delegate_to=None, - *args, - **kwargs): - # To preserve backward compatibility - # TODO: Remove one day - if object_name is None: - if 'object' in kwargs: - object_name = kwargs['object'] - DeprecationWarning("Use 'object_name' instead of 'object'.") - else: - TypeError("__init__() missing 1 required positional argument: 'object_name'") - - if google_cloud_storage_conn_id: - warnings.warn( - "The google_cloud_storage_conn_id parameter has been deprecated. 
You should pass " - "the gcp_conn_id parameter.", DeprecationWarning, stacklevel=3) - gcp_conn_id = google_cloud_storage_conn_id - - super().__init__(*args, **kwargs) - self.bucket = bucket - self.object = object_name - self.filename = filename - self.store_to_xcom_key = store_to_xcom_key - self.gcp_conn_id = gcp_conn_id - self.delegate_to = delegate_to - - def execute(self, context): - self.log.info('Executing download: %s, %s, %s', self.bucket, - self.object, self.filename) - hook = GoogleCloudStorageHook( - google_cloud_storage_conn_id=self.gcp_conn_id, - delegate_to=self.delegate_to - ) - file_bytes = hook.download(bucket_name=self.bucket, - object_name=self.object, - filename=self.filename) - if self.store_to_xcom_key: - if sys.getsizeof(file_bytes) < MAX_XCOM_SIZE: - context['ti'].xcom_push(key=self.store_to_xcom_key, value=file_bytes) - else: - raise AirflowException( - 'The size of the downloaded file is too large to push to XCom!' - ) +warnings.warn( + "This module is deprecated. Please use `airflow.gcp.operators.gcs`.", + DeprecationWarning, +) diff --git a/airflow/contrib/operators/gcs_list_operator.py b/airflow/contrib/operators/gcs_list_operator.py index a143b8b3bf7029..dbefb34eb379fb 100644 --- a/airflow/contrib/operators/gcs_list_operator.py +++ b/airflow/contrib/operators/gcs_list_operator.py @@ -17,92 +17,15 @@ # specific language governing permissions and limitations # under the License. """ -This module contains a Google Cloud Storage list operator. +This module is deprecated. Please use `airflow.gcp.operators.gcs`. """ -import warnings -from typing import Iterable - -from airflow.contrib.hooks.gcs_hook import GoogleCloudStorageHook -from airflow.models import BaseOperator -from airflow.utils.decorators import apply_defaults - - -class GoogleCloudStorageListOperator(BaseOperator): - """ - List all objects from the bucket with the give string prefix and delimiter in name. - - This operator returns a python list with the name of objects which can be used by - `xcom` in the downstream task. - - :param bucket: The Google cloud storage bucket to find the objects. (templated) - :type bucket: str - :param prefix: Prefix string which filters objects whose name begin with - this prefix. (templated) - :type prefix: str - :param delimiter: The delimiter by which you want to filter the objects. (templated) - For e.g to lists the CSV files from in a directory in GCS you would use - delimiter='.csv'. - :type delimiter: str - :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud Platform. - :type gcp_conn_id: str - :param google_cloud_storage_conn_id: (Deprecated) The connection ID used to connect to Google Cloud - Platform. This parameter has been deprecated. You should pass the gcp_conn_id parameter instead. - :type google_cloud_storage_conn_id: - :param delegate_to: The account to impersonate, if any. - For this to work, the service account making the request must have - domain-wide delegation enabled. - :type delegate_to: str - - **Example**: - The following Operator would list all the Avro files from ``sales/sales-2017`` - folder in ``data`` bucket. 
:: - GCS_Files = GoogleCloudStorageListOperator( - task_id='GCS_Files', - bucket='data', - prefix='sales/sales-2017/', - delimiter='.avro', - gcp_conn_id=google_cloud_conn_id - ) - """ - template_fields = ('bucket', 'prefix', 'delimiter') # type: Iterable[str] - - ui_color = '#f0eee4' - - @apply_defaults - def __init__(self, - bucket, - prefix=None, - delimiter=None, - gcp_conn_id='google_cloud_default', - google_cloud_storage_conn_id=None, - delegate_to=None, - *args, - **kwargs): - super().__init__(*args, **kwargs) - - if google_cloud_storage_conn_id: - warnings.warn( - "The google_cloud_storage_conn_id parameter has been deprecated. You should pass " - "the gcp_conn_id parameter.", DeprecationWarning, stacklevel=3) - gcp_conn_id = google_cloud_storage_conn_id - - self.bucket = bucket - self.prefix = prefix - self.delimiter = delimiter - self.gcp_conn_id = gcp_conn_id - self.delegate_to = delegate_to - - def execute(self, context): - - hook = GoogleCloudStorageHook( - google_cloud_storage_conn_id=self.gcp_conn_id, - delegate_to=self.delegate_to - ) +import warnings - self.log.info('Getting list of the files. Bucket: %s; Delimiter: %s; Prefix: %s', - self.bucket, self.delimiter, self.prefix) +# pylint: disable=unused-import +from airflow.gcp.operators.gcs import GoogleCloudStorageListOperator # noqa - return hook.list(bucket_name=self.bucket, - prefix=self.prefix, - delimiter=self.delimiter) +warnings.warn( + "This module is deprecated. Please use `airflow.gcp.operators.gcs`.", + DeprecationWarning, +) diff --git a/airflow/contrib/operators/gcs_operator.py b/airflow/contrib/operators/gcs_operator.py index 2acebf34a664d2..b23dcce9ee4fb5 100644 --- a/airflow/contrib/operators/gcs_operator.py +++ b/airflow/contrib/operators/gcs_operator.py @@ -17,121 +17,15 @@ # specific language governing permissions and limitations # under the License. """ -This module contains a Google Cloud Storage Bucket operator. +This module is deprecated. Please use `airflow.gcp.operators.gcs`. """ -import warnings - -from airflow.contrib.hooks.gcs_hook import GoogleCloudStorageHook -from airflow.models import BaseOperator -from airflow.utils.decorators import apply_defaults - - -class GoogleCloudStorageCreateBucketOperator(BaseOperator): - """ - Creates a new bucket. Google Cloud Storage uses a flat namespace, - so you can't create a bucket with a name that is already in use. - - .. seealso:: - For more information, see Bucket Naming Guidelines: - https://cloud.google.com/storage/docs/bucketnaming.html#requirements - - :param bucket_name: The name of the bucket. (templated) - :type bucket_name: str - :param resource: An optional dict with parameters for creating the bucket. - For information on available parameters, see Cloud Storage API doc: - https://cloud.google.com/storage/docs/json_api/v1/buckets/insert - :type resource: dict - :param storage_class: This defines how objects in the bucket are stored - and determines the SLA and the cost of storage (templated). Values include - - - ``MULTI_REGIONAL`` - - ``REGIONAL`` - - ``STANDARD`` - - ``NEARLINE`` - - ``COLDLINE``. - - If this value is not specified when the bucket is - created, it will default to STANDARD. - :type storage_class: str - :param location: The location of the bucket. (templated) - Object data for objects in the bucket resides in physical storage - within this region. Defaults to US. - - .. seealso:: https://developers.google.com/storage/docs/bucket-locations - :type location: str - :param project_id: The ID of the GCP Project. 
(templated) - :type project_id: str - :param labels: User-provided labels, in key/value pairs. - :type labels: dict - :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud Platform. - :type gcp_conn_id: str - :param google_cloud_storage_conn_id: (Deprecated) The connection ID used to connect to Google Cloud - Platform. This parameter has been deprecated. You should pass the gcp_conn_id parameter instead. - :type google_cloud_storage_conn_id: str - :param delegate_to: The account to impersonate, if any. - For this to work, the service account making the request must - have domain-wide delegation enabled. - :type delegate_to: str - - The following Operator would create a new bucket ``test-bucket`` - with ``MULTI_REGIONAL`` storage class in ``EU`` region - - .. code-block:: python - - CreateBucket = GoogleCloudStorageCreateBucketOperator( - task_id='CreateNewBucket', - bucket_name='test-bucket', - storage_class='MULTI_REGIONAL', - location='EU', - labels={'env': 'dev', 'team': 'airflow'}, - gcp_conn_id='airflow-conn-id' - ) - - """ - template_fields = ('bucket_name', 'storage_class', - 'location', 'project_id') - ui_color = '#f0eee4' - - @apply_defaults - def __init__(self, - bucket_name, - resource=None, - storage_class='MULTI_REGIONAL', - location='US', - project_id=None, - labels=None, - gcp_conn_id='google_cloud_default', - google_cloud_storage_conn_id=None, - delegate_to=None, - *args, - **kwargs): - super().__init__(*args, **kwargs) - - if google_cloud_storage_conn_id: - warnings.warn( - "The google_cloud_storage_conn_id parameter has been deprecated. You should pass " - "the gcp_conn_id parameter.", DeprecationWarning, stacklevel=3) - gcp_conn_id = google_cloud_storage_conn_id - - self.bucket_name = bucket_name - self.resource = resource - self.storage_class = storage_class - self.location = location - self.project_id = project_id - self.labels = labels - self.gcp_conn_id = gcp_conn_id - self.delegate_to = delegate_to +import warnings - def execute(self, context): - hook = GoogleCloudStorageHook( - google_cloud_storage_conn_id=self.gcp_conn_id, - delegate_to=self.delegate_to - ) +# pylint: disable=unused-import +from airflow.gcp.operators.gcs import GoogleCloudStorageCreateBucketOperator # noqa - hook.create_bucket(bucket_name=self.bucket_name, - resource=self.resource, - storage_class=self.storage_class, - location=self.location, - project_id=self.project_id, - labels=self.labels) +warnings.warn( + "This module is deprecated. Please use `airflow.gcp.operators.gcs`.", + DeprecationWarning, +) diff --git a/airflow/contrib/operators/gcs_to_gdrive_operator.py b/airflow/contrib/operators/gcs_to_gdrive_operator.py new file mode 100644 index 00000000000000..71736add098212 --- /dev/null +++ b/airflow/contrib/operators/gcs_to_gdrive_operator.py @@ -0,0 +1,147 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +""" +This module contains a Google Cloud Storage operator. +""" +import tempfile +from typing import Optional + +from airflow.gcp.hooks.gcs import GoogleCloudStorageHook +from airflow.contrib.hooks.gdrive_hook import GoogleDriveHook +from airflow.models import BaseOperator +from airflow.utils.decorators import apply_defaults +from airflow.exceptions import AirflowException + +WILDCARD = "*" + + +class GcsToGDriveOperator(BaseOperator): + """ + Copies objects from a Google Cloud Storage service service to Google Drive service, with renaming + if requested. + + Using this operator requires the following OAuth 2.0 scope: + + .. code-block:: none + + https://www.googleapis.com/auth/drive + + :param source_bucket: The source Google Cloud Storage bucket where the object is. (templated) + :type source_bucket: str + :param source_object: The source name of the object to copy in the Google cloud + storage bucket. (templated) + You can use only one wildcard for objects (filenames) within your bucket. The wildcard can appear + inside the object name or at the end of the object name. Appending a wildcard to the bucket name + is unsupported. + :type source_object: str + :param destination_object: The destination name of the object in the destination Google Drive + service. (templated) + If a wildcard is supplied in the source_object argument, this is the prefix that will be prepended + to the final destination objects' paths. + Note that the source path's part before the wildcard will be removed; + if it needs to be retained it should be appended to destination_object. + For example, with prefix ``foo/*`` and destination_object ``blah/``, the file ``foo/baz`` will be + copied to ``blah/baz``; to retain the prefix write the destination_object as e.g. ``blah/foo``, in + which case the copied file will be named ``blah/foo/baz``. + :type destination_object: str + :param move_object: When move object is True, the object is moved instead of copied to the new location. + This is the equivalent of a mv command as opposed to a cp command. + :type move_object: bool + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud Platform. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate, if any. + For this to work, the service account making the request must have domain-wide delegation enabled. 
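To make the wildcard and prefix-replacement behaviour described above concrete, here is a minimal, hedged usage sketch of the new ``GcsToGDriveOperator``; the bucket, object names and the ``dag`` object are illustrative assumptions:

.. code-block:: python

    from airflow.contrib.operators.gcs_to_gdrive_operator import GcsToGDriveOperator

    copy_reports = GcsToGDriveOperator(
        task_id="copy_reports_to_gdrive",
        source_bucket="example-reports-bucket",   # assumption
        source_object="reports/2019/*.csv",       # a single wildcard is allowed
        destination_object="backups/",            # prepended after the pre-wildcard prefix is stripped
        move_object=False,                        # copy rather than move
        gcp_conn_id="google_cloud_default",
        dag=dag,
    )

With these arguments an object ``reports/2019/july.csv`` would be uploaded to Drive as ``backups/july.csv``, following the prefix-replacement rule documented above.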
+ :type delegate_to: str + """ + + template_fields = ("source_bucket", "source_object", "destination_object") + ui_color = "#f0eee4" + + @apply_defaults + def __init__( + self, + source_bucket: str, + source_object: str, + destination_object: str = None, + move_object: bool = False, + gcp_conn_id: str = "google_cloud_default", + delegate_to: str = None, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + + self.source_bucket = source_bucket + self.source_object = source_object + self.destination_object = destination_object + self.move_object = move_object + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.gcs_hook = None # type: Optional[GoogleCloudStorageHook] + self.gdrive_hook = None # type: Optional[GoogleDriveHook] + + def execute(self, context): + + self.gcs_hook = GoogleCloudStorageHook( + google_cloud_storage_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to + ) + self.gdrive_hook = GoogleDriveHook(gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to) + + if WILDCARD in self.source_object: + total_wildcards = self.source_object.count(WILDCARD) + if total_wildcards > 1: + error_msg = ( + "Only one wildcard '*' is allowed in source_object parameter. " + "Found {} in {}.".format(total_wildcards, self.source_object) + ) + + raise AirflowException(error_msg) + + prefix, delimiter = self.source_object.split(WILDCARD, 1) + objects = self.gcs_hook.list(self.source_bucket, prefix=prefix, delimiter=delimiter) + + for source_object in objects: + if self.destination_object is None: + destination_object = source_object + else: + destination_object = source_object.replace(prefix, self.destination_object, 1) + + self._copy_single_object(source_object=source_object, destination_object=destination_object) + else: + self._copy_single_object( + source_object=self.source_object, destination_object=self.destination_object + ) + + def _copy_single_object(self, source_object, destination_object): + self.log.info( + "Executing copy of gs://%s/%s to gdrive://%s", + self.source_bucket, + source_object, + destination_object, + ) + + with tempfile.NamedTemporaryFile() as file: + filename = file.name + self.gcs_hook.download( + bucket_name=self.source_bucket, object_name=source_object, filename=filename + ) + self.gdrive_hook.upload_file(local_location=filename, remote_location=destination_object) + + if self.move_object: + self.gcs_hook.delete(self.source_bucket, source_object) diff --git a/airflow/contrib/operators/kubernetes_pod_operator.py b/airflow/contrib/operators/kubernetes_pod_operator.py index 21e939fcd7a15e..7616b2ca7f9aaa 100644 --- a/airflow/contrib/operators/kubernetes_pod_operator.py +++ b/airflow/contrib/operators/kubernetes_pod_operator.py @@ -14,19 +14,24 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - +"""Executes task in a Kubernetes POD""" from airflow.exceptions import AirflowException from airflow.models import BaseOperator from airflow.utils.decorators import apply_defaults -from airflow.kubernetes import kube_client, pod_generator, pod_launcher -from airflow.kubernetes.pod import Resources +from airflow.kubernetes import pod_generator, kube_client, pod_launcher +from airflow.kubernetes.k8s_model import append_to_pod from airflow.utils.state import State -class KubernetesPodOperator(BaseOperator): +class KubernetesPodOperator(BaseOperator): # pylint: disable=too-many-instance-attributes """ Execute a task in a Kubernetes Pod + .. 
note:: + If you use `Google Kubernetes Engine `__, use + :class:`~airflow.gcp.operators.kubernetes_engine.GKEPodOperator`, which + simplifies the authorization process. + :param image: Docker image you wish to launch. Defaults to dockerhub.io, but fully qualified URLS will point to custom repositories :type image: str @@ -45,11 +50,11 @@ class KubernetesPodOperator(BaseOperator): comma separated list: secret_a,secret_b :type image_pull_secrets: str :param ports: ports for launched pod - :type ports: list[airflow.kubernetes.pod.Port] + :type ports: list[airflow.kubernetes.models.port.Port] :param volume_mounts: volumeMounts for launched pod - :type volume_mounts: list[airflow.contrib.kubernetes.volume_mount.VolumeMount] + :type volume_mounts: list[airflow.kubernetes.models.volume_mount.VolumeMount] :param volumes: volumes for launched pod. Includes ConfigMaps and PersistentVolumes - :type volumes: list[airflow.contrib.kubernetes.volume.Volume] + :type volumes: list[airflow.kubernetes.models.volume.Volume] :param labels: labels to apply to the Pod :type labels: dict :param startup_timeout_seconds: timeout in seconds to startup the pod @@ -61,7 +66,7 @@ class KubernetesPodOperator(BaseOperator): :type env_vars: dict :param secrets: Kubernetes secrets to inject in the container, They can be exposed as environment vars or files in a volume. - :type secrets: list[airflow.contrib.kubernetes.secret.Secret] + :type secrets: list[airflow.kubernetes.models.secret.Secret] :param in_cluster: run kubernetes client with in_cluster configuration :type in_cluster: bool :param cluster_context: context that points to kubernetes cluster. @@ -99,9 +104,11 @@ class KubernetesPodOperator(BaseOperator): :type configmaps: list[str] :param pod_runtime_info_envs: environment variables about pod runtime information (ip, namespace, nodeName, podName) - :type pod_runtime_info_envs: list[PodRuntimeEnv] + :type pod_runtime_info_envs: list[airflow.kubernetes.models.pod_runtime_info_env.PodRuntimeInfoEnv] :param dnspolicy: Specify a dnspolicy for the pod :type dnspolicy: str + :param full_pod_spec: The complete podSpec + :type full_pod_spec: kubernetes.client.models.V1Pod """ template_fields = ('cmds', 'arguments', 'env_vars', 'config_file') @@ -110,42 +117,42 @@ def execute(self, context): client = kube_client.get_kube_client(in_cluster=self.in_cluster, cluster_context=self.cluster_context, config_file=self.config_file) - gen = pod_generator.PodGenerator() - - for port in self.ports: - gen.add_port(port) - for mount in self.volume_mounts: - gen.add_mount(mount) - for volume in self.volumes: - gen.add_volume(volume) - pod = gen.make_pod( - namespace=self.namespace, + pod = pod_generator.PodGenerator( image=self.image, - pod_id=self.name, + namespace=self.namespace, cmds=self.cmds, - arguments=self.arguments, + args=self.arguments, labels=self.labels, - ) + name=self.name, + envs=self.env_vars, + extract_xcom=self.do_xcom_push, + image_pull_policy=self.image_pull_policy, + node_selectors=self.node_selectors, + annotations=self.annotations, + affinity=self.affinity, + image_pull_secrets=self.image_pull_secrets, + service_account_name=self.service_account_name, + hostnetwork=self.hostnetwork, + tolerations=self.tolerations, + configmaps=self.configmaps, + security_context=self.security_context, + dnspolicy=self.dnspolicy, + resources=self.resources, + pod=self.full_pod_spec, + ).gen_pod() - pod.service_account_name = self.service_account_name - pod.secrets = self.secrets - pod.envs = self.env_vars - pod.image_pull_policy = 
self.image_pull_policy - pod.image_pull_secrets = self.image_pull_secrets - pod.annotations = self.annotations - pod.resources = self.resources - pod.affinity = self.affinity - pod.node_selectors = self.node_selectors - pod.hostnetwork = self.hostnetwork - pod.tolerations = self.tolerations - pod.configmaps = self.configmaps - pod.security_context = self.security_context - pod.pod_runtime_info_envs = self.pod_runtime_info_envs - pod.dnspolicy = self.dnspolicy + pod = append_to_pod(pod, self.ports) + pod = append_to_pod(pod, self.pod_runtime_info_envs) + pod = append_to_pod(pod, self.volumes) + pod = append_to_pod(pod, self.volume_mounts) + pod = append_to_pod(pod, self.secrets) + + self.pod = pod launcher = pod_launcher.PodLauncher(kube_client=client, extract_xcom=self.do_xcom_push) + try: (final_state, result) = launcher.run_pod( pod, @@ -164,15 +171,8 @@ def execute(self, context): except AirflowException as ex: raise AirflowException('Pod Launching failed: {error}'.format(error=ex)) - def _set_resources(self, resources): - inputResource = Resources() - if resources: - for item in resources.keys(): - setattr(inputResource, item, resources[item]) - return inputResource - @apply_defaults - def __init__(self, + def __init__(self, # pylint: disable=too-many-arguments,too-many-locals namespace, image, name, @@ -204,9 +204,13 @@ def __init__(self, security_context=None, pod_runtime_info_envs=None, dnspolicy=None, + full_pod_spec=None, *args, **kwargs): super().__init__(*args, **kwargs) + + self.pod = None + self.image = image self.namespace = namespace self.cmds = cmds or [] @@ -229,7 +233,7 @@ def __init__(self, self.do_xcom_push = do_xcom_push if kwargs.get('xcom_push') is not None: raise AirflowException("'xcom_push' was deprecated, use 'do_xcom_push' instead") - self.resources = self._set_resources(resources) + self.resources = resources self.config_file = config_file self.image_pull_secrets = image_pull_secrets self.service_account_name = service_account_name @@ -240,3 +244,4 @@ def __init__(self, self.security_context = security_context or {} self.pod_runtime_info_envs = pod_runtime_info_envs or [] self.dnspolicy = dnspolicy + self.full_pod_spec = full_pod_spec diff --git a/airflow/contrib/operators/qubole_operator.py b/airflow/contrib/operators/qubole_operator.py index e591ed9d0894e4..ddd7e80542f00b 100644 --- a/airflow/contrib/operators/qubole_operator.py +++ b/airflow/contrib/operators/qubole_operator.py @@ -16,6 +16,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
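Stepping back to the ``KubernetesPodOperator`` refactor above: the constructor arguments documented there stay the same (plus the new optional ``full_pod_spec``), while internally the pod is now built by ``PodGenerator`` and ports, volumes, volume mounts and secrets are attached via ``append_to_pod``. A minimal, hedged usage sketch in which the image, commands and ``dag`` object are illustrative assumptions:

.. code-block:: python

    from airflow.contrib.operators.kubernetes_pod_operator import KubernetesPodOperator

    hello_pod = KubernetesPodOperator(
        task_id="hello_pod",
        namespace="default",
        image="python:3.6-slim",
        cmds=["python", "-c"],
        arguments=["print('hello from a pod')"],
        name="airflow-hello-pod",
        labels={"app": "example"},
        in_cluster=False,       # use the local kube config rather than in-cluster auth
        do_xcom_push=False,
        dag=dag,
    )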
+"""Qubole operator""" from typing import Iterable from airflow.models.baseoperator import BaseOperator, BaseOperatorLink @@ -25,7 +26,7 @@ class QDSLink(BaseOperatorLink): - + """Link to QDS""" name = 'Go to QDS' def get_link(self, operator, dttm): @@ -191,16 +192,19 @@ def on_kill(self, ti=None): self.get_hook().kill(ti) def get_results(self, ti=None, fp=None, inline=True, delim=None, fetch=True): + """get_results from Qubole""" return self.get_hook().get_results(ti, fp, inline, delim, fetch) def get_log(self, ti): + """get_log from Qubole""" return self.get_hook().get_log(ti) def get_jobs_id(self, ti): + """get jobs_id from Qubole""" return self.get_hook().get_jobs_id(ti) def get_hook(self): - # Reinitiating the hook, as some template fields might have changed + """Reinitialising the hook, as some template fields might have changed""" return QuboleHook(*self.args, **self.kwargs) def __getattribute__(self, name): diff --git a/airflow/contrib/operators/s3_to_gcs_operator.py b/airflow/contrib/operators/s3_to_gcs_operator.py index 6dd7c8bd04eb86..d71986b530b264 100644 --- a/airflow/contrib/operators/s3_to_gcs_operator.py +++ b/airflow/contrib/operators/s3_to_gcs_operator.py @@ -19,8 +19,8 @@ import warnings from tempfile import NamedTemporaryFile -from airflow.contrib.hooks.gcs_hook import (GoogleCloudStorageHook, - _parse_gcs_url) +from airflow.gcp.hooks.gcs import (GoogleCloudStorageHook, + _parse_gcs_url) from airflow.contrib.operators.s3_list_operator import S3ListOperator from airflow.exceptions import AirflowException from airflow.hooks.S3_hook import S3Hook @@ -214,7 +214,7 @@ def execute(self, context): return files # Following functionality may be better suited in - # airflow/contrib/hooks/gcs_hook.py + # airflow/contrib/hooks/gcs.py @staticmethod def _gcs_object_is_directory(object): _, blob = _parse_gcs_url(object) diff --git a/airflow/contrib/sensors/bash_sensor.py b/airflow/contrib/sensors/bash_sensor.py index e1ec0eb43c2edc..5405c302799de5 100644 --- a/airflow/contrib/sensors/bash_sensor.py +++ b/airflow/contrib/sensors/bash_sensor.py @@ -80,7 +80,6 @@ def poke(self, context): self.sp = sp self.log.info("Output:") - line = '' for line in iter(sp.stdout.readline, b''): line = line.decode(self.output_encoding).strip() self.log.info(line) diff --git a/airflow/contrib/sensors/bigquery_sensor.py b/airflow/contrib/sensors/bigquery_sensor.py index b3c2262f8887c6..1ef5d3e4d1804c 100644 --- a/airflow/contrib/sensors/bigquery_sensor.py +++ b/airflow/contrib/sensors/bigquery_sensor.py @@ -16,59 +16,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -""" -This module contains a Google Bigquery sensor. -""" +"""This module is deprecated. Please use `airflow.gcp.sensors.bigquery`.""" -from airflow.sensors.base_sensor_operator import BaseSensorOperator -from airflow.contrib.hooks.bigquery_hook import BigQueryHook -from airflow.utils.decorators import apply_defaults +import warnings +# pylint: disable=unused-import +from airflow.gcp.sensors.bigquery import BigQueryTableSensor # noqa -class BigQueryTableSensor(BaseSensorOperator): - """ - Checks for the existence of a table in Google Bigquery. - - :param project_id: The Google cloud project in which to look for the table. - The connection supplied to the hook must provide - access to the specified project. - :type project_id: str - :param dataset_id: The name of the dataset in which to look for the table. - storage bucket. 
- :type dataset_id: str - :param table_id: The name of the table to check the existence of. - :type table_id: str - :param bigquery_conn_id: The connection ID to use when connecting to - Google BigQuery. - :type bigquery_conn_id: str - :param delegate_to: The account to impersonate, if any. - For this to work, the service account making the request must - have domain-wide delegation enabled. - :type delegate_to: str - """ - template_fields = ('project_id', 'dataset_id', 'table_id',) - ui_color = '#f0eee4' - - @apply_defaults - def __init__(self, - project_id, - dataset_id, - table_id, - bigquery_conn_id='google_cloud_default', - delegate_to=None, - *args, **kwargs): - - super().__init__(*args, **kwargs) - self.project_id = project_id - self.dataset_id = dataset_id - self.table_id = table_id - self.bigquery_conn_id = bigquery_conn_id - self.delegate_to = delegate_to - - def poke(self, context): - table_uri = '{0}:{1}.{2}'.format(self.project_id, self.dataset_id, self.table_id) - self.log.info('Sensor checks existence of table: %s', table_uri) - hook = BigQueryHook( - bigquery_conn_id=self.bigquery_conn_id, - delegate_to=self.delegate_to) - return hook.table_exists(self.project_id, self.dataset_id, self.table_id) +warnings.warn( + "This module is deprecated. Please use `airflow.gcp.sensors.bigquery`.", + DeprecationWarning, +) diff --git a/airflow/contrib/sensors/gcp_transfer_sensor.py b/airflow/contrib/sensors/gcp_transfer_sensor.py index 91224095f1bb96..2f7dc4b832ada8 100644 --- a/airflow/contrib/sensors/gcp_transfer_sensor.py +++ b/airflow/contrib/sensors/gcp_transfer_sensor.py @@ -16,7 +16,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""This module is deprecated. Please use `airflow.gcp.operators.cloud_storage_transfer_service`.""" +"""This module is deprecated. Please use `airflow.gcp.sensors.cloud_storage_transfer_service`.""" import warnings @@ -26,6 +26,6 @@ ) warnings.warn( - "This module is deprecated. Please use `airflow.gcp.operators.cloud_storage_transfer_service`.", + "This module is deprecated. Please use `airflow.gcp.sensors.cloud_storage_transfer_service`.", DeprecationWarning, ) diff --git a/airflow/contrib/sensors/gcs_sensor.py b/airflow/contrib/sensors/gcs_sensor.py index c3c1bb996804d7..da1d5b4f88b622 100644 --- a/airflow/contrib/sensors/gcs_sensor.py +++ b/airflow/contrib/sensors/gcs_sensor.py @@ -16,305 +16,19 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -""" -This module contains Google Cloud Storage sensors. -""" - -import os -from datetime import datetime - -from airflow.contrib.hooks.gcs_hook import GoogleCloudStorageHook -from airflow.sensors.base_sensor_operator import BaseSensorOperator -from airflow.utils.decorators import apply_defaults -from airflow import AirflowException - - -class GoogleCloudStorageObjectSensor(BaseSensorOperator): - """ - Checks for the existence of a file in Google Cloud Storage. - - :param bucket: The Google cloud storage bucket where the object is. - :type bucket: str - :param object: The name of the object to check in the Google cloud - storage bucket. - :type object: str - :param google_cloud_conn_id: The connection ID to use when - connecting to Google cloud storage. - :type google_cloud_conn_id: str - :param delegate_to: The account to impersonate, if any. 
- For this to work, the service account making the request must have - domain-wide delegation enabled. - :type delegate_to: str - """ - template_fields = ('bucket', 'object') - ui_color = '#f0eee4' - - @apply_defaults - def __init__(self, - bucket, - object, # pylint:disable=redefined-builtin - google_cloud_conn_id='google_cloud_default', - delegate_to=None, - *args, **kwargs): - - super().__init__(*args, **kwargs) - self.bucket = bucket - self.object = object - self.google_cloud_conn_id = google_cloud_conn_id - self.delegate_to = delegate_to - - def poke(self, context): - self.log.info('Sensor checks existence of : %s, %s', self.bucket, self.object) - hook = GoogleCloudStorageHook( - google_cloud_storage_conn_id=self.google_cloud_conn_id, - delegate_to=self.delegate_to) - return hook.exists(self.bucket, self.object) - - -def ts_function(context): - """ - Default callback for the GoogleCloudStorageObjectUpdatedSensor. The default - behaviour is check for the object being updated after execution_date + - schedule_interval. - """ - return context['dag'].following_schedule(context['execution_date']) - - -class GoogleCloudStorageObjectUpdatedSensor(BaseSensorOperator): - """ - Checks if an object is updated in Google Cloud Storage. - - :param bucket: The Google cloud storage bucket where the object is. - :type bucket: str - :param object: The name of the object to download in the Google cloud - storage bucket. - :type object: str - :param ts_func: Callback for defining the update condition. The default callback - returns execution_date + schedule_interval. The callback takes the context - as parameter. - :type ts_func: function - :param google_cloud_conn_id: The connection ID to use when - connecting to Google cloud storage. - :type google_cloud_conn_id: str - :param delegate_to: The account to impersonate, if any. - For this to work, the service account making the request must have domain-wide - delegation enabled. - :type delegate_to: str - """ - template_fields = ('bucket', 'object') - ui_color = '#f0eee4' - - @apply_defaults - def __init__(self, - bucket, - object, # pylint:disable=redefined-builtin - ts_func=ts_function, - google_cloud_conn_id='google_cloud_default', - delegate_to=None, - *args, **kwargs): - - super().__init__(*args, **kwargs) - self.bucket = bucket - self.object = object - self.ts_func = ts_func - self.google_cloud_conn_id = google_cloud_conn_id - self.delegate_to = delegate_to - - def poke(self, context): - self.log.info('Sensor checks existence of : %s, %s', self.bucket, self.object) - hook = GoogleCloudStorageHook( - google_cloud_storage_conn_id=self.google_cloud_conn_id, - delegate_to=self.delegate_to) - return hook.is_updated_after(self.bucket, self.object, self.ts_func(context)) - - -class GoogleCloudStoragePrefixSensor(BaseSensorOperator): - """ - Checks for the existence of GCS objects at a given prefix, passing matches via XCom. - - When files matching the given prefix are found, the poke method's criteria will be - fulfilled and the matching objects will be returned from the operator and passed - through XCom for downstream tasks. - - :param bucket: The Google cloud storage bucket where the object is. - :type bucket: str - :param prefix: The name of the prefix to check in the Google cloud - storage bucket. - :type prefix: str - :param google_cloud_conn_id: The connection ID to use when - connecting to Google cloud storage. - :type google_cloud_conn_id: str - :param delegate_to: The account to impersonate, if any. 
- For this to work, the service account making the request must have - domain-wide delegation enabled. - :type delegate_to: str - """ - template_fields = ('bucket', 'prefix') - ui_color = '#f0eee4' - - @apply_defaults - def __init__(self, - bucket, - prefix, - google_cloud_conn_id='google_cloud_default', - delegate_to=None, - *args, **kwargs): - super().__init__(*args, **kwargs) - self.bucket = bucket - self.prefix = prefix - self.google_cloud_conn_id = google_cloud_conn_id - self.delegate_to = delegate_to - self._matches = [] - - def poke(self, context): - self.log.info('Sensor checks existence of objects: %s, %s', - self.bucket, self.prefix) - hook = GoogleCloudStorageHook( - google_cloud_storage_conn_id=self.google_cloud_conn_id, - delegate_to=self.delegate_to) - self._matches = hook.list(self.bucket, prefix=self.prefix) - return bool(self._matches) - - def execute(self, context): - """Overridden to allow matches to be passed""" - super(GoogleCloudStoragePrefixSensor, self).execute(context) - return self._matches - - -def get_time(): - """ - This is just a wrapper of datetime.datetime.now to simplify mocking in the - unittests. - """ - return datetime.now() - - -class GoogleCloudStorageUploadSessionCompleteSensor(BaseSensorOperator): - """ - Checks for changes in the number of objects at prefix in Google Cloud Storage - bucket and returns True if the inactivity period has passed with no - increase in the number of objects. Note, it is recommended to use reschedule - mode if you expect this sensor to run for hours. - - :param bucket: The Google cloud storage bucket where the objects are. - expected. - :type bucket: str - :param prefix: The name of the prefix to check in the Google cloud - storage bucket. - :param inactivity_period: The total seconds of inactivity to designate - an upload session is over. Note, this mechanism is not real time and - this operator may not return until a poke_interval after this period - has passed with no additional objects sensed. - :type inactivity_period: float - :param min_objects: The minimum number of objects needed for upload session - to be considered valid. - :type min_objects: int - :param previous_num_objects: The number of objects found during the last poke. - :type previous_num_objects: int - :param inactivity_seconds: The current seconds of the inactivity period. - :type inactivity_seconds: float - :param allow_delete: Should this sensor consider objects being deleted - between pokes valid behavior. If true a warning message will be logged - when this happens. If false an error will be raised. - :type allow_delete: bool - :param google_cloud_conn_id: The connection ID to use when connecting - to Google cloud storage. - :type google_cloud_conn_id: str - :param delegate_to: The account to impersonate, if any. For this to work, - the service account making the request must have domain-wide - delegation enabled. 
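The ``reschedule`` recommendation above is worth making concrete. A minimal, hedged sketch of this sensor under its new import path (``airflow.gcp.sensors.gcs``, per the shim below); the bucket, prefix and ``dag`` object are illustrative assumptions:

.. code-block:: python

    from airflow.gcp.sensors.gcs import GoogleCloudStorageUploadSessionCompleteSensor

    wait_for_upload_session = GoogleCloudStorageUploadSessionCompleteSensor(
        task_id="wait_for_upload_session",
        bucket="example-landing-bucket",   # assumption
        prefix="incoming/2019-08-01/",     # assumption
        inactivity_period=60 * 60,         # one hour with no new objects ends the session
        min_objects=1,
        allow_delete=True,                 # tolerate deletions between pokes
        mode="reschedule",                 # free the worker slot between pokes on long waits
        poke_interval=5 * 60,
        dag=dag,
    )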
- :type delegate_to: str - """ - - template_fields = ('bucket', 'prefix') - ui_color = '#f0eee4' - - @apply_defaults - def __init__(self, - bucket, - prefix, - inactivity_period=60 * 60, - min_objects=1, - previous_num_objects=0, - allow_delete=True, - google_cloud_conn_id='google_cloud_default', - delegate_to=None, - *args, **kwargs): - - super().__init__(*args, **kwargs) - - self.bucket = bucket - self.prefix = prefix - self.inactivity_period = inactivity_period - self.min_objects = min_objects - self.previous_num_objects = previous_num_objects - self.inactivity_seconds = 0 - self.allow_delete = allow_delete - self.google_cloud_conn_id = google_cloud_conn_id - self.delegate_to = delegate_to - self.last_activity_time = None - - def is_bucket_updated(self, current_num_objects): - """ - Checks whether new objects have been uploaded and the inactivity_period - has passed and updates the state of the sensor accordingly. - - :param current_num_objects: number of objects in bucket during last poke. - :type current_num_objects: int - """ - - if current_num_objects > self.previous_num_objects: - # When new objects arrived, reset the inactivity_seconds - # previous_num_objects for the next poke. - self.log.info("New objects found at %s resetting last_activity_time.", - os.path.join(self.bucket, self.prefix)) - self.last_activity_time = get_time() - self.inactivity_seconds = 0 - self.previous_num_objects = current_num_objects - return False - - if current_num_objects < self.previous_num_objects: - # During the last poke interval objects were deleted. - if self.allow_delete: - self.previous_num_objects = current_num_objects - self.last_activity_time = get_time() - self.log.warning( - """ - Objects were deleted during the last - poke interval. Updating the file counter and - resetting last_activity_time. - """ - ) - return False - - raise AirflowException( - """ - Illegal behavior: objects were deleted in {} between pokes. - """.format(os.path.join(self.bucket, self.prefix)) - ) - - if self.last_activity_time: - self.inactivity_seconds = (get_time() - self.last_activity_time).total_seconds() - else: - # Handles the first poke where last inactivity time is None. - self.last_activity_time = get_time() - self.inactivity_seconds = 0 - - if self.inactivity_seconds >= self.inactivity_period: - path = os.path.join(self.bucket, self.prefix) - - if current_num_objects >= self.min_objects: - self.log.info("""SUCCESS: - Sensor found %s objects at %s. - Waited at least %s seconds, with no new objects dropped. - """, current_num_objects, path, self.inactivity_period) - return True - - self.log.warning("FAILURE: Inactivity Period passed, not enough objects found in %s", path) - - return False - return False - - def poke(self, context): - hook = GoogleCloudStorageHook() - return self.is_bucket_updated(len(hook.list(self.bucket, prefix=self.prefix))) +"""This module is deprecated. Please use `airflow.gcp.sensors.gcs`.""" + +import warnings + +# pylint: disable=unused-import +from airflow.gcp.sensors.gcs import ( # noqa + GoogleCloudStorageObjectSensor, + GoogleCloudStorageObjectUpdatedSensor, + GoogleCloudStoragePrefixSensor, + GoogleCloudStorageUploadSessionCompleteSensor +) + +warnings.warn( + "This module is deprecated. 
Please use `airflow.gcp.sensors.gcs`.", + DeprecationWarning, +) diff --git a/airflow/contrib/sensors/python_sensor.py b/airflow/contrib/sensors/python_sensor.py index a2dc2031a86bb6..a4e5ec77aa520f 100644 --- a/airflow/contrib/sensors/python_sensor.py +++ b/airflow/contrib/sensors/python_sensor.py @@ -16,9 +16,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from airflow.operators.python_operator import PythonOperator from airflow.sensors.base_sensor_operator import BaseSensorOperator from airflow.utils.decorators import apply_defaults +from typing import Optional, Dict, Callable, List class PythonSensor(BaseSensorOperator): @@ -38,12 +40,6 @@ class PythonSensor(BaseSensorOperator): :param op_args: a list of positional arguments that will get unpacked when calling your callable :type op_args: list - :param provide_context: if set to true, Airflow will pass a set of - keyword arguments that can be used in your function. This set of - kwargs correspond exactly to what you can use in your jinja - templates. For this to work, you need to define `**kwargs` in your - function header. - :type provide_context: bool :param templates_dict: a dictionary where the values are templates that will get templated by the Airflow engine sometime between ``__init__`` and ``execute`` takes place and are made available @@ -56,24 +52,21 @@ class PythonSensor(BaseSensorOperator): @apply_defaults def __init__( self, - python_callable, - op_args=None, - op_kwargs=None, - provide_context=False, - templates_dict=None, + python_callable: Callable, + op_args: Optional[List] = None, + op_kwargs: Optional[Dict] = None, + templates_dict: Optional[Dict] = None, *args, **kwargs): super().__init__(*args, **kwargs) self.python_callable = python_callable self.op_args = op_args or [] self.op_kwargs = op_kwargs or {} - self.provide_context = provide_context self.templates_dict = templates_dict - def poke(self, context): - if self.provide_context: - context.update(self.op_kwargs) - context['templates_dict'] = self.templates_dict - self.op_kwargs = context + def poke(self, context: Dict): + context.update(self.op_kwargs) + context['templates_dict'] = self.templates_dict + self.op_kwargs = PythonOperator.determine_op_kwargs(self.python_callable, context, len(self.op_args)) self.log.info("Poking callable: %s", str(self.python_callable)) return_value = self.python_callable(*self.op_args, **self.op_kwargs) diff --git a/airflow/contrib/utils/gcp_field_validator.py b/airflow/contrib/utils/gcp_field_validator.py index ce380dc9ea814c..c633ac492a9d31 100644 --- a/airflow/contrib/utils/gcp_field_validator.py +++ b/airflow/contrib/utils/gcp_field_validator.py @@ -188,7 +188,7 @@ class GcpBodyFieldValidator(LoggingMixin): :type api_version: str """ - def __init__(self, validation_specs: Sequence[str], api_version: str) -> None: + def __init__(self, validation_specs: Sequence[Dict], api_version: str) -> None: super().__init__() self._validation_specs = validation_specs self._api_version = api_version diff --git a/airflow/example_dags/docker_copy_data.py b/airflow/example_dags/docker_copy_data.py index f091969777eebf..b41131d883d94b 100644 --- a/airflow/example_dags/docker_copy_data.py +++ b/airflow/example_dags/docker_copy_data.py @@ -1,4 +1,5 @@ # -*- coding: utf-8 -*- +# # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. 
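A note on the PythonSensor change above: with provide_context gone, the sensor builds op_kwargs from the callable's own signature via PythonOperator.determine_op_kwargs. The standalone sketch below only approximates that idea with the standard inspect module (determine_op_kwargs itself is not shown in this diff), and the wait_for_partition callable is purely illustrative.

    import inspect

    def filter_context_for(func, context, num_positional=0):
        # Approximation: keep only the context keys the callable can accept.
        # A callable with **kwargs receives the whole context; otherwise only
        # the keys matching its named parameters (after the first
        # ``num_positional`` ones, which are filled from op_args).
        params = list(inspect.signature(func).parameters.values())[num_positional:]
        if any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params):
            return dict(context)
        names = {p.name for p in params}
        return {key: value for key, value in context.items() if key in names}

    def wait_for_partition(ds, templates_dict=None):
        # Illustrative sensor callable: it asks only for ``ds`` and ``templates_dict``.
        return ds >= "2019-08-01" and (templates_dict or {}).get("ready", False)

    context = {"ds": "2019-08-02", "ti": object(), "templates_dict": {"ready": True}}
    print(wait_for_partition(**filter_context_for(wait_for_partition, context)))   # True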
See the NOTICE file # distributed with this work for additional information @@ -69,7 +70,6 @@ # # t_is_data_available = ShortCircuitOperator( # task_id='check_if_data_available', -# provide_context=True, # python_callable=is_data_available, # dag=dag) # diff --git a/airflow/example_dags/example_branch_python_dop_operator_3.py b/airflow/example_dags/example_branch_python_dop_operator_3.py index ec60cfc01b903b..7455ef7ebbd23e 100644 --- a/airflow/example_dags/example_branch_python_dop_operator_3.py +++ b/airflow/example_dags/example_branch_python_dop_operator_3.py @@ -58,7 +58,6 @@ def should_run(**kwargs): cond = BranchPythonOperator( task_id='condition', - provide_context=True, python_callable=should_run, dag=dag, ) diff --git a/airflow/example_dags/example_gcs_to_gcs.py b/airflow/example_dags/example_gcs_to_gcs.py new file mode 100644 index 00000000000000..ff7cddd1880b28 --- /dev/null +++ b/airflow/example_dags/example_gcs_to_gcs.py @@ -0,0 +1,63 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Example Airflow DAG for Google Cloud Storage to Google Cloud Storage transfer operators. 
+""" + +import os +import airflow +from airflow import models +from airflow.operators.gcs_to_gcs import GoogleCloudStorageSynchronizeBuckets + +default_args = {"start_date": airflow.utils.dates.days_ago(1)} + +BUCKET_1_SRC = os.environ.get("GCP_GCS_BUCKET_1_SRC", "test-gcs-sync-1-src") +BUCKET_1_DST = os.environ.get("GCP_GCS_BUCKET_1_DST", "test-gcs-sync-1-dst") + +BUCKET_2_SRC = os.environ.get("GCP_GCS_BUCKET_2_SRC", "test-gcs-sync-2-src") +BUCKET_2_DST = os.environ.get("GCP_GCS_BUCKET_2_DST", "test-gcs-sync-2-dst") + +BUCKET_3_SRC = os.environ.get("GCP_GCS_BUCKET_3_SRC", "test-gcs-sync-3-src") +BUCKET_3_DST = os.environ.get("GCP_GCS_BUCKET_3_DST", "test-gcs-sync-3-dst") + + +with models.DAG( + "example_gcs_to_gcs", default_args=default_args, schedule_interval=None +) as dag: + sync_full_bucket = GoogleCloudStorageSynchronizeBuckets( + task_id="sync-full-bucket", + source_bucket=BUCKET_1_SRC, + destination_bucket=BUCKET_1_DST + ) + + sync_to_subdirectory_and_delete_extra_files = GoogleCloudStorageSynchronizeBuckets( + task_id="sync_to_subdirectory_and_delete_extra_files", + source_bucket=BUCKET_1_SRC, + destination_bucket=BUCKET_1_DST, + destination_object="subdir/", + delete_extra_files=True, + ) + + sync_from_subdirectory_and_allow_overwrite_and_non_recursive = GoogleCloudStorageSynchronizeBuckets( + task_id="sync_from_subdirectory_and_allow_overwrite_and_non_recursive", + source_bucket=BUCKET_1_SRC, + source_object="subdir/", + destination_bucket=BUCKET_1_DST, + recursive=False, + ) diff --git a/airflow/example_dags/example_passing_params_via_test_command.py b/airflow/example_dags/example_passing_params_via_test_command.py index 152b8cde9e63b5..e8fc9c963a916a 100644 --- a/airflow/example_dags/example_passing_params_via_test_command.py +++ b/airflow/example_dags/example_passing_params_via_test_command.py @@ -37,17 +37,17 @@ ) -def my_py_command(**kwargs): +def my_py_command(test_mode, params): """ Print out the "foo" param passed in via `airflow tasks test example_passing_params_via_test_command run_this -tp '{"foo":"bar"}'` """ - if kwargs["test_mode"]: + if test_mode: print(" 'foo' was passed in via test={} command : kwargs[params][foo] \ - = {}".format(kwargs["test_mode"], kwargs["params"]["foo"])) + = {}".format(test_mode, params["foo"])) # Print out the value of "miff", passed in below via the Python Operator - print(" 'miff' was passed in via task params = {}".format(kwargs["params"]["miff"])) + print(" 'miff' was passed in via task params = {}".format(params["miff"])) return 1 @@ -58,7 +58,6 @@ def my_py_command(**kwargs): run_this = PythonOperator( task_id='run_this', - provide_context=True, python_callable=my_py_command, params={"miff": "agg"}, dag=dag, diff --git a/airflow/example_dags/example_python_operator.py b/airflow/example_dags/example_python_operator.py index 29c664f0a65ec0..86403ceb25b63b 100644 --- a/airflow/example_dags/example_python_operator.py +++ b/airflow/example_dags/example_python_operator.py @@ -48,7 +48,6 @@ def print_context(ds, **kwargs): run_this = PythonOperator( task_id='print_the_context', - provide_context=True, python_callable=print_context, dag=dag, ) diff --git a/airflow/example_dags/example_trigger_target_dag.py b/airflow/example_dags/example_trigger_target_dag.py index 4758176981152e..32255103d804eb 100644 --- a/airflow/example_dags/example_trigger_target_dag.py +++ b/airflow/example_dags/example_trigger_target_dag.py @@ -69,7 +69,6 @@ def run_this_func(**kwargs): run_this = PythonOperator( task_id='run_this', - provide_context=True, 
python_callable=run_this_func, dag=dag, ) diff --git a/airflow/example_dags/example_xcom.py b/airflow/example_dags/example_xcom.py index 5b7b79aca80f9b..fb043b021ebc3a 100644 --- a/airflow/example_dags/example_xcom.py +++ b/airflow/example_dags/example_xcom.py @@ -26,7 +26,6 @@ args = { 'owner': 'Airflow', 'start_date': airflow.utils.dates.days_ago(2), - 'provide_context': True, } dag = DAG('example_xcom', schedule_interval="@once", default_args=args) @@ -40,7 +39,7 @@ def push(**kwargs): kwargs['ti'].xcom_push(key='value from pusher 1', value=value_1) -def push_by_returning(): +def push_by_returning(**kwargs): """Pushes an XCom without a specific target, just by returning it""" return value_2 diff --git a/airflow/exceptions.py b/airflow/exceptions.py index 09abf79c4b3803..710c15ccf38b0d 100644 --- a/airflow/exceptions.py +++ b/airflow/exceptions.py @@ -19,6 +19,7 @@ # # Note: Any AirflowException raised is expected to cause the TaskInstance # to be marked in an ERROR state +"""Exceptions used by Airflow""" class AirflowException(Exception): @@ -40,11 +41,11 @@ class AirflowNotFoundException(AirflowException): class AirflowConfigException(AirflowException): - pass + """Raise when there is configuration problem""" class AirflowSensorTimeout(AirflowException): - pass + """Raise when there is a timeout on sensor polling""" class AirflowRescheduleException(AirflowException): @@ -52,30 +53,30 @@ class AirflowRescheduleException(AirflowException): Raise when the task should be re-scheduled at a later time. :param reschedule_date: The date when the task should be rescheduled - :type reschedule: datetime.datetime + :type reschedule_date: datetime.datetime """ def __init__(self, reschedule_date): self.reschedule_date = reschedule_date class InvalidStatsNameException(AirflowException): - pass + """Raise when name of the stats is invalid""" class AirflowTaskTimeout(AirflowException): - pass + """Raise when the task execution times-out""" class AirflowWebServerTimeout(AirflowException): - pass + """Raise when the web server times out""" class AirflowSkipException(AirflowException): - pass + """Raise when the task should be skipped""" class AirflowDagCycleException(AirflowException): - pass + """Raise when there is a cycle in Dag definition""" class DagNotFound(AirflowNotFoundException): diff --git a/airflow/executors/kubernetes_executor.py b/airflow/executors/kubernetes_executor.py index 6ebd0495ce2e41..676fbbca071169 100644 --- a/airflow/executors/kubernetes_executor.py +++ b/airflow/executors/kubernetes_executor.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
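The example-DAG updates above all drop provide_context=True; a callable now simply declares the context entries it needs (test_mode, params, ti, ...) and receives only those. A minimal sketch of the XCom example rewritten in that style, assuming the standard TaskInstance.xcom_push/xcom_pull API; the DAG id and pushed values are illustrative.

    from airflow import DAG
    from airflow.operators.python_operator import PythonOperator
    from airflow.utils import dates

    def push(ti):                       # declares ``ti`` directly, no **kwargs required
        ti.xcom_push(key="value from pusher 1", value=[1, 2, 3])

    def push_by_returning(**kwargs):    # the return value becomes the default XCom
        return {"a": "b"}

    def puller(ti):                     # pulls both values back
        print(ti.xcom_pull(key="value from pusher 1", task_ids="push"))
        print(ti.xcom_pull(task_ids="push_by_returning"))

    with DAG("example_xcom_sketch", schedule_interval="@once",
             default_args={"owner": "Airflow", "start_date": dates.days_ago(2)}) as dag:
        pusher = PythonOperator(task_id="push", python_callable=push)
        pusher_2 = PythonOperator(task_id="push_by_returning", python_callable=push_by_returning)
        pull = PythonOperator(task_id="puller", python_callable=puller)
        pusher >> pull
        pusher_2 >> pull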
- +"""Kubernetes executor""" import base64 import hashlib from queue import Empty @@ -22,16 +22,18 @@ import re import json import multiprocessing -from dateutil import parser from uuid import uuid4 + +from dateutil import parser + import kubernetes from kubernetes import watch, client from kubernetes.client.rest import ApiException from airflow.kubernetes.pod_launcher import PodLauncher from airflow.kubernetes.kube_client import get_kube_client from airflow.kubernetes.worker_configuration import WorkerConfiguration +from airflow.kubernetes.pod_generator import PodGenerator from airflow.executors.base_executor import BaseExecutor -from airflow.executors import Executors from airflow.models import KubeResourceVersion, KubeWorkerIdentifier, TaskInstance from airflow.utils.state import State from airflow.utils.db import provide_session, create_session @@ -40,89 +42,12 @@ from airflow.exceptions import AirflowConfigException, AirflowException from airflow.utils.log.logging_mixin import LoggingMixin +MAX_POD_ID_LEN = 253 +MAX_LABEL_LEN = 63 -class KubernetesExecutorConfig: - def __init__(self, image=None, image_pull_policy=None, request_memory=None, - request_cpu=None, limit_memory=None, limit_cpu=None, limit_gpu=None, - gcp_service_account_key=None, node_selectors=None, affinity=None, - annotations=None, volumes=None, volume_mounts=None, tolerations=None, labels=None): - self.image = image - self.image_pull_policy = image_pull_policy - self.request_memory = request_memory - self.request_cpu = request_cpu - self.limit_memory = limit_memory - self.limit_cpu = limit_cpu - self.limit_gpu = limit_gpu - self.gcp_service_account_key = gcp_service_account_key - self.node_selectors = node_selectors - self.affinity = affinity - self.annotations = annotations - self.volumes = volumes - self.volume_mounts = volume_mounts - self.tolerations = tolerations - self.labels = labels or {} - - def __repr__(self): - return "{}(image={}, image_pull_policy={}, request_memory={}, request_cpu={}, " \ - "limit_memory={}, limit_cpu={}, limit_gpu={}, gcp_service_account_key={}, " \ - "node_selectors={}, affinity={}, annotations={}, volumes={}, " \ - "volume_mounts={}, tolerations={}, labels={})" \ - .format(KubernetesExecutorConfig.__name__, self.image, self.image_pull_policy, - self.request_memory, self.request_cpu, self.limit_memory, - self.limit_cpu, self.limit_gpu, self.gcp_service_account_key, self.node_selectors, - self.affinity, self.annotations, self.volumes, self.volume_mounts, - self.tolerations, self.labels) - @staticmethod - def from_dict(obj): - if obj is None: - return KubernetesExecutorConfig() - - if not isinstance(obj, dict): - raise TypeError( - 'Cannot convert a non-dictionary object into a KubernetesExecutorConfig') - - namespaced = obj.get(Executors.KubernetesExecutor, {}) - - return KubernetesExecutorConfig( - image=namespaced.get('image', None), - image_pull_policy=namespaced.get('image_pull_policy', None), - request_memory=namespaced.get('request_memory', None), - request_cpu=namespaced.get('request_cpu', None), - limit_memory=namespaced.get('limit_memory', None), - limit_cpu=namespaced.get('limit_cpu', None), - limit_gpu=namespaced.get('limit_gpu', None), - gcp_service_account_key=namespaced.get('gcp_service_account_key', None), - node_selectors=namespaced.get('node_selectors', None), - affinity=namespaced.get('affinity', None), - annotations=namespaced.get('annotations', {}), - volumes=namespaced.get('volumes', []), - volume_mounts=namespaced.get('volume_mounts', []), - 
tolerations=namespaced.get('tolerations', None), - labels=namespaced.get('labels', {}), - ) - - def as_dict(self): - return { - 'image': self.image, - 'image_pull_policy': self.image_pull_policy, - 'request_memory': self.request_memory, - 'request_cpu': self.request_cpu, - 'limit_memory': self.limit_memory, - 'limit_cpu': self.limit_cpu, - 'limit_gpu': self.limit_gpu, - 'gcp_service_account_key': self.gcp_service_account_key, - 'node_selectors': self.node_selectors, - 'affinity': self.affinity, - 'annotations': self.annotations, - 'volumes': self.volumes, - 'volume_mounts': self.volume_mounts, - 'tolerations': self.tolerations, - 'labels': self.labels, - } - - -class KubeConfig: +class KubeConfig: # pylint: disable=too-many-instance-attributes + """Configuration for Kubernetes""" core_section = 'core' kubernetes_section = 'kubernetes' @@ -278,13 +203,14 @@ def __init__(self): # and only return a blank string if contexts are not set. def _get_security_context_val(self, scontext): val = conf.get(self.kubernetes_section, scontext) - if len(val) == 0: - return val + if not val: + return 0 else: return int(val) def _validate(self): # TODO: use XOR for dags_volume_claim and git_dags_folder_mount_point + # pylint: disable=too-many-boolean-expressions if not self.dags_volume_claim \ and not self.dags_volume_host \ and not self.dags_in_image \ @@ -304,9 +230,11 @@ def _validate(self): 'must be set for authentication through user credentials; ' 'or `git_ssh_key_secret_name` must be set for authentication ' 'through ssh key, but not both') + # pylint: enable=too-many-boolean-expressions class KubernetesJobWatcher(multiprocessing.Process, LoggingMixin): + """Watches for Kubernetes jobs""" def __init__(self, namespace, watcher_queue, resource_version, worker_uuid, kube_config): multiprocessing.Process.__init__(self) self.namespace = namespace @@ -316,6 +244,7 @@ def __init__(self, namespace, watcher_queue, resource_version, worker_uuid, kube self.kube_config = kube_config def run(self): + """Performs watching""" kube_client = get_kube_client() while True: try: @@ -361,6 +290,7 @@ def _run(self, kube_client, resource_version, worker_uuid, kube_config): return last_resource_version def process_error(self, event): + """Process error response""" self.log.error( 'Encountered Error response from k8s list namespaced pod stream => %s', event @@ -369,16 +299,17 @@ def process_error(self, event): if raw_object['code'] == 410: self.log.info( 'Kubernetes resource version is too old, must reset to 0 => %s', - raw_object['message'] + (raw_object['message'],) ) # Return resource version 0 return '0' raise AirflowException( - 'Kubernetes failure for %s with code %s and message: %s', - raw_object['reason'], raw_object['code'], raw_object['message'] + 'Kubernetes failure for %s with code %s and message: %s' % + (raw_object['reason'], raw_object['code'], raw_object['message']) ) def process_status(self, pod_id, status, labels, resource_version): + """Process status response""" if status == 'Pending': self.log.info('Event: %s Pending', pod_id) elif status == 'Failed': @@ -397,6 +328,7 @@ def process_status(self, pod_id, status, labels, resource_version): class AirflowKubernetesScheduler(LoggingMixin): + """Airflow Scheduler for Kubernetes""" def __init__(self, kube_config, task_queue, result_queue, kube_client, worker_uuid): self.log.debug("Creating Kubernetes executor") self.kube_config = kube_config @@ -430,7 +362,6 @@ def _health_check_kube_watcher(self): def run_next(self, next_job): """ - The run_next command will 
check the task_queue for any un-run jobs. It will then create a unique job-id, launch that job in the cluster, and store relevant info in the current_jobs map so we can track the job's @@ -439,22 +370,29 @@ def run_next(self, next_job): self.log.info('Kubernetes job is %s', str(next_job)) key, command, kube_executor_config = next_job dag_id, task_id, execution_date, try_number = key - self.log.debug("Kubernetes running for command %s", command) - self.log.debug("Kubernetes launching image %s", self.kube_config.kube_image) - pod = self.worker_configuration.make_pod( - namespace=self.namespace, worker_uuid=self.worker_uuid, + + config_pod = self.worker_configuration.make_pod( + namespace=self.namespace, + worker_uuid=self.worker_uuid, pod_id=self._create_pod_id(dag_id, task_id), dag_id=self._make_safe_label_value(dag_id), task_id=self._make_safe_label_value(task_id), try_number=try_number, execution_date=self._datetime_to_label_safe_datestring(execution_date), - airflow_command=command, kube_executor_config=kube_executor_config + airflow_command=command ) + # Reconcile the pod generated by the Operator and the Pod + # generated by the .cfg file + pod = PodGenerator.reconcile_pods(config_pod, kube_executor_config) + self.log.debug("Kubernetes running for command %s", command) + self.log.debug("Kubernetes launching image %s", pod.spec.containers[0].image) + # the watcher will monitor pods, so we do not block. self.launcher.run_pod_async(pod, **self.kube_config.kube_client_request_args) self.log.debug("Kubernetes Job created!") def delete_pod(self, pod_id: str) -> None: + """Deletes POD""" try: self.kube_client.delete_namespaced_pod( pod_id, self.namespace, body=client.V1DeleteOptions(), @@ -485,6 +423,7 @@ def sync(self): break def process_watcher_task(self, task): + """Process the task by watcher.""" pod_id, state, labels, resource_version = task self.log.info( 'Attempting to finish pod; pod_id: %s; state: %s; labels: %s', @@ -522,8 +461,6 @@ def _make_safe_pod_id(safe_dag_id, safe_task_id, safe_uuid): :param random_uuid: a uuid :return: ``str`` valid Pod name of appropriate length """ - MAX_POD_ID_LEN = 253 - safe_key = safe_dag_id + safe_task_id safe_pod_id = safe_key[:MAX_POD_ID_LEN - len(safe_uuid) - 1] + "-" + safe_uuid @@ -541,8 +478,6 @@ def _make_safe_label_value(string): way from the original value sent to this function, then we need to truncate to 53chars, and append it with a unique hash. 
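The docstring above describes the truncate-and-hash scheme; as a self-contained illustration, the function below mirrors _make_safe_label_value using the regex and MAX_LABEL_LEN from this hunk. The md5-based suffix is an assumption, since the hashing step itself is not visible in this excerpt.

    import hashlib
    import re

    MAX_LABEL_LEN = 63   # Kubernetes limits label values to 63 characters

    def make_safe_label_value(string):
        # Drop characters that are not allowed in a Kubernetes label value,
        # using the same regex as the executor code above.
        safe_label = re.sub(r'^[^a-z0-9A-Z]*|[^a-zA-Z0-9_\-\.]|[^a-z0-9A-Z]*$', '', string)
        if len(safe_label) > MAX_LABEL_LEN or string != safe_label:
            # Truncate and append a short deterministic hash so the value
            # stays unique and within the length limit (assumed construction).
            safe_hash = hashlib.md5(string.encode()).hexdigest()[:9]
            safe_label = safe_label[:MAX_LABEL_LEN - len(safe_hash) - 1] + "-" + safe_hash
        return safe_label

    print(make_safe_label_value("tutorial_dag"))    # already safe, returned unchanged
    print(make_safe_label_value("x" * 80))          # truncated, ends with "-<hash>"
    print(len(make_safe_label_value("x" * 80)))     # 63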
""" - MAX_LABEL_LEN = 63 - safe_label = re.sub(r'^[^a-z0-9A-Z]*|[^a-zA-Z0-9_\-\.]|[^a-z0-9A-Z]*$', '', string) if len(safe_label) > MAX_LABEL_LEN or string != safe_label: @@ -596,7 +531,7 @@ def _labels_to_key(self, labels): dag_id = labels['dag_id'] task_id = labels['task_id'] ex_time = self._label_safe_datestring_to_datetime(labels['execution_date']) - except Exception as e: + except Exception as e: # pylint: disable=broad-except self.log.warning( 'Error while retrieving labels; labels: %s; exception: %s', labels, e @@ -633,11 +568,13 @@ def _labels_to_key(self, labels): return None def terminate(self): + """Termninates the watcher.""" self.watcher_queue.join() self._manager.shutdown() class KubernetesExecutor(BaseExecutor, LoggingMixin): + """Executor for Kubernetes""" def __init__(self): self.kube_config = KubeConfig() self.task_queue = None @@ -674,6 +611,8 @@ def clear_not_launched_queued_tasks(self, session=None): ) for task in queued_tasks: + # noinspection PyProtectedMember + # pylint: disable=protected-access dict_string = ( "dag_id={},task_id={},execution_date={},airflow-worker={}".format( AirflowKubernetesScheduler._make_safe_label_value(task.dag_id), @@ -684,13 +623,14 @@ def clear_not_launched_queued_tasks(self, session=None): self.worker_uuid ) ) + # pylint: enable=protected-access kwargs = dict(label_selector=dict_string) if self.kube_config.kube_client_request_args: for key, value in self.kube_config.kube_client_request_args.iteritems(): kwargs[key] = value pod_list = self.kube_client.list_namespaced_pod( self.kube_config.kube_namespace, **kwargs) - if len(pod_list.items) == 0: + if not pod_list.items: self.log.info( 'TaskInstance: %s found in queued state but was not launched, ' 'rescheduling', task @@ -738,6 +678,7 @@ def _create_or_update_secret(secret_name, secret_path): _create_or_update_secret(service_account['name'], service_account['path']) def start(self): + """Starts the executor""" self.log.info('Start Kubernetes executor') self.worker_uuid = KubeWorkerIdentifier.get_or_create_current_kube_worker_uuid() self.log.debug('Start with worker_uuid: %s', self.worker_uuid) @@ -757,14 +698,17 @@ def start(self): self.clear_not_launched_queued_tasks() def execute_async(self, key, command, queue=None, executor_config=None): + """Executes task asynchronously""" self.log.info( 'Add task %s with command %s with executor_config %s', key, command, executor_config ) - kube_executor_config = KubernetesExecutorConfig.from_dict(executor_config) + + kube_executor_config = PodGenerator.from_obj(executor_config) self.task_queue.put((key, command, kube_executor_config)) def sync(self): + """Synchronize task state.""" if self.running: self.log.debug('self.running: %s', self.running) if self.queued_tasks: @@ -772,7 +716,7 @@ def sync(self): self.kube_scheduler.sync() last_resource_version = None - while True: + while True: # pylint: disable=too-many-nested-blocks try: results = self.result_queue.get_nowait() try: @@ -781,7 +725,7 @@ def sync(self): self.log.info('Changing state of %s to %s', results, state) try: self._change_state(key, state, pod_id) - except Exception as e: + except Exception as e: # pylint: disable=broad-except self.log.exception('Exception: %s when attempting ' + 'to change state of %s to %s, re-queueing.', e, results, state) self.result_queue.put(results) @@ -792,6 +736,7 @@ def sync(self): KubeResourceVersion.checkpoint_resource_version(last_resource_version) + # pylint: disable=too-many-nested-blocks for _ in range(self.kube_config.worker_pods_creation_batch_size): 
try: task = self.task_queue.get_nowait() @@ -805,6 +750,7 @@ def sync(self): self.task_queue.task_done() except Empty: break + # pylint: enable=too-many-nested-blocks def _change_state(self, key, state, pod_id: str) -> None: if state != State.RUNNING: @@ -818,6 +764,7 @@ def _change_state(self, key, state, pod_id: str) -> None: self.event_buffer[key] = state def end(self): + """Called when the executor shuts down""" self.log.info('Shutting down Kubernetes executor') self.task_queue.join() self.result_queue.join() diff --git a/airflow/gcp/example_dags/example_bigquery.py b/airflow/gcp/example_dags/example_bigquery.py new file mode 100644 index 00000000000000..f2d2cc7aec5dd9 --- /dev/null +++ b/airflow/gcp/example_dags/example_bigquery.py @@ -0,0 +1,247 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Example Airflow DAG for Google BigQuery service. +""" +import os +from urllib.parse import urlparse + +import airflow +from airflow import models +from airflow.gcp.operators.bigquery import ( + BigQueryOperator, + BigQueryCreateEmptyTableOperator, + BigQueryCreateEmptyDatasetOperator, + BigQueryGetDatasetOperator, + BigQueryPatchDatasetOperator, + BigQueryUpdateDatasetOperator, + BigQueryDeleteDatasetOperator, + BigQueryCreateExternalTableOperator, + BigQueryGetDataOperator, + BigQueryTableDeleteOperator, +) + +from airflow.operators.bigquery_to_bigquery import BigQueryToBigQueryOperator +from airflow.operators.bigquery_to_gcs import BigQueryToCloudStorageOperator +from airflow.operators.bash_operator import BashOperator + +# 0x06012c8cf97BEaD5deAe237070F9587f8E7A266d = CryptoKitties contract address +WALLET_ADDRESS = os.environ.get("GCP_ETH_WALLET_ADDRESS", "0x06012c8cf97BEaD5deAe237070F9587f8E7A266d") + +default_args = {"start_date": airflow.utils.dates.days_ago(1)} + +MOST_VALUABLE_INCOMING_TRANSACTIONS = """ +SELECT + value, to_address +FROM + `bigquery-public-data.ethereum_blockchain.transactions` +WHERE + 1 = 1 + AND DATE(block_timestamp) = "{{ ds }}" + AND to_address = LOWER(@to_address) +ORDER BY value DESC +LIMIT 1000 +""" + +MOST_ACTIVE_PLAYERS = """ +SELECT + COUNT(from_address) + , from_address +FROM + `bigquery-public-data.ethereum_blockchain.transactions` +WHERE + 1 = 1 + AND DATE(block_timestamp) = "{{ ds }}" + AND to_address = LOWER(@to_address) +GROUP BY from_address +ORDER BY COUNT(from_address) DESC +LIMIT 1000 +""" + +PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project") +BQ_LOCATION = "europe-north1" + +DATASET_NAME = os.environ.get("GCP_BIGQUERY_DATASET_NAME", "test_dataset") +LOCATION_DATASET_NAME = "{}_location".format(DATASET_NAME) +DATA_SAMPLE_GCS_URL = os.environ.get( + "GCP_BIGQUERY_DATA_GCS_URL", "gs://cloud-samples-data/bigquery/us-states/us-states.csv" +) + +DATA_SAMPLE_GCS_URL_PARTS 
= urlparse(DATA_SAMPLE_GCS_URL) +DATA_SAMPLE_GCS_BUCKET_NAME = DATA_SAMPLE_GCS_URL_PARTS.netloc +DATA_SAMPLE_GCS_OBJECT_NAME = DATA_SAMPLE_GCS_URL_PARTS.path[1:] + +DATA_EXPORT_BUCKET_NAME = os.environ.get("GCP_BIGQUERY_EXPORT_BUCKET_NAME", "test-bigquery-sample-data") + + +with models.DAG( + "example_bigquery", default_args=default_args, schedule_interval=None # Override to match your needs +) as dag: + + execute_query = BigQueryOperator( + task_id="execute-query", + sql=MOST_VALUABLE_INCOMING_TRANSACTIONS, + use_legacy_sql=False, + query_params=[ + { + "name": "to_address", + "parameterType": {"type": "STRING"}, + "parameterValue": {"value": WALLET_ADDRESS}, + } + ], + ) + + bigquery_execute_multi_query = BigQueryOperator( + task_id="execute-multi-query", + sql=[MOST_VALUABLE_INCOMING_TRANSACTIONS, MOST_ACTIVE_PLAYERS], + use_legacy_sql=False, + query_params=[ + { + "name": "to_address", + "parameterType": {"type": "STRING"}, + "parameterValue": {"value": WALLET_ADDRESS}, + } + ], + ) + + execute_query_save = BigQueryOperator( + task_id="execute-query-save", + sql=MOST_VALUABLE_INCOMING_TRANSACTIONS, + use_legacy_sql=False, + destination_dataset_table="{}.save_query_result".format(DATASET_NAME), + query_params=[ + { + "name": "to_address", + "parameterType": {"type": "STRING"}, + "parameterValue": {"value": WALLET_ADDRESS}, + } + ], + ) + + get_data = BigQueryGetDataOperator( + task_id="get-data", + dataset_id=DATASET_NAME, + table_id="save_query_result", + max_results="10", + selected_fields="value,to_address", + ) + + get_data_result = BashOperator( + task_id="get-data-result", bash_command="echo \"{{ task_instance.xcom_pull('get-data') }}\"" + ) + + create_external_table = BigQueryCreateExternalTableOperator( + task_id="create-external-table", + bucket=DATA_SAMPLE_GCS_BUCKET_NAME, + source_objects=[DATA_SAMPLE_GCS_OBJECT_NAME], + destination_project_dataset_table="{}.external_table".format(DATASET_NAME), + skip_leading_rows=1, + schema_fields=[{"name": "name", "type": "STRING"}, {"name": "post_abbr", "type": "STRING"}], + ) + + execute_query_external_table = BigQueryOperator( + task_id="execute-query-external-table", + destination_dataset_table="{}.selected_data_from_external_table".format(DATASET_NAME), + sql='SELECT * FROM `{}.external_table` WHERE name LIKE "W%"'.format(DATASET_NAME), + use_legacy_sql=False, + ) + + copy_from_selected_data = BigQueryToBigQueryOperator( + task_id="copy-from-selected-data", + source_project_dataset_tables="{}.selected_data_from_external_table".format(DATASET_NAME), + destination_project_dataset_table="{}.copy_of_selected_data_from_external_table".format(DATASET_NAME), + ) + + bigquery_to_gcs = BigQueryToCloudStorageOperator( + task_id="bigquery-to-gcs", + source_project_dataset_table="{}.selected_data_from_external_table".format(DATASET_NAME), + destination_cloud_storage_uris=["gs://{}/export-bigquery.csv".format(DATA_EXPORT_BUCKET_NAME)], + ) + + create_dataset = BigQueryCreateEmptyDatasetOperator(task_id="create-dataset", dataset_id=DATASET_NAME) + + create_dataset_with_location = BigQueryCreateEmptyDatasetOperator( + task_id="create_dataset_with_location", + dataset_id=LOCATION_DATASET_NAME, + location=BQ_LOCATION + ) + + create_table = BigQueryCreateEmptyTableOperator( + task_id="create-table", + dataset_id=DATASET_NAME, + table_id="test_table", + schema_fields=[ + {"name": "emp_name", "type": "STRING", "mode": "REQUIRED"}, + {"name": "salary", "type": "INTEGER", "mode": "NULLABLE"}, + ], + ) + + create_table_with_location = 
BigQueryCreateEmptyTableOperator( + task_id="create_table_with_location", + dataset_id=LOCATION_DATASET_NAME, + table_id="test_table", + schema_fields=[ + {"name": "emp_name", "type": "STRING", "mode": "REQUIRED"}, + {"name": "salary", "type": "INTEGER", "mode": "NULLABLE"}, + ], + ) + + delete_table = BigQueryTableDeleteOperator( + task_id="delete-table", deletion_dataset_table="{}.test_table".format(DATASET_NAME) + ) + + get_dataset = BigQueryGetDatasetOperator(task_id="get-dataset", dataset_id=DATASET_NAME) + + get_dataset_result = BashOperator( + task_id="get-dataset-result", + bash_command="echo \"{{ task_instance.xcom_pull('get-dataset')['id'] }}\"", + ) + + patch_dataset = BigQueryPatchDatasetOperator( + task_id="patch-dataset", + dataset_id=DATASET_NAME, + dataset_resource={"friendlyName": "Patched Dataset", "description": "Patched dataset"}, + ) + + update_dataset = BigQueryUpdateDatasetOperator( + task_id="update-dataset", dataset_id=DATASET_NAME, dataset_resource={"description": "Updated dataset"} + ) + + delete_dataset = BigQueryDeleteDatasetOperator( + task_id="delete-dataset", dataset_id=DATASET_NAME, delete_contents=True + ) + + delete_dataset_with_location = BigQueryDeleteDatasetOperator( + task_id="delete_dataset_with_location", + dataset_id=LOCATION_DATASET_NAME, + delete_contents=True + ) + + create_dataset >> execute_query_save >> delete_dataset + create_dataset >> create_table >> delete_dataset + create_dataset >> get_dataset >> delete_dataset + create_dataset >> patch_dataset >> update_dataset >> delete_dataset + execute_query_save >> get_data >> get_dataset_result + get_data >> delete_dataset + create_dataset >> create_external_table >> execute_query_external_table >> \ + copy_from_selected_data >> delete_dataset + execute_query_external_table >> bigquery_to_gcs >> delete_dataset + create_table >> delete_table >> delete_dataset + create_dataset_with_location >> create_table_with_location >> delete_dataset_with_location diff --git a/airflow/gcp/example_dags/example_bigquery_dts.py b/airflow/gcp/example_dags/example_bigquery_dts.py new file mode 100644 index 00000000000000..e1a308603dc198 --- /dev/null +++ b/airflow/gcp/example_dags/example_bigquery_dts.py @@ -0,0 +1,128 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Example Airflow DAG that creates and deletes Bigquery data transfer configurations. 
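Several of the example DAGs in this change (the BigQuery example above, the Memorystore example further below) split a gs:// URL into a bucket and an object name with urlparse. In isolation the split looks like this; the sample URL is the default used above.

    from urllib.parse import urlparse

    url = "gs://cloud-samples-data/bigquery/us-states/us-states.csv"
    parts = urlparse(url)

    bucket_name = parts.netloc     # "cloud-samples-data"
    object_name = parts.path[1:]   # "bigquery/us-states/us-states.csv" (leading "/" stripped)

    print(bucket_name)
    print(object_name)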
+""" +import os +import time + +from google.protobuf.json_format import ParseDict +from google.cloud.bigquery_datatransfer_v1.types import TransferConfig + +import airflow +from airflow import models +from airflow.gcp.operators.bigquery_dts import ( + BigQueryCreateDataTransferOperator, + BigQueryDeleteDataTransferConfigOperator, + BigQueryDataTransferServiceStartTransferRunsOperator, +) +from airflow.gcp.sensors.bigquery_dts import ( + BigQueryDataTransferServiceTransferRunSensor, +) + + +GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project") +BUCKET_URI = os.environ.get( + "GCP_DTS_BUCKET_URI", "gs://cloud-ml-tables-data/bank-marketing.csv" +) +GCP_DTS_BQ_DATASET = os.environ.get("GCP_DTS_BQ_DATASET", "test_dts") +GCP_DTS_BQ_TABLE = os.environ.get("GCP_DTS_BQ_TABLE", "GCS_Test") + +# [START howto_bigquery_dts_create_args] + +# In the case of Airflow, the customer needs to create a transfer +# config with the automatic scheduling disabled and then trigger +# a transfer run using a specialized Airflow operator +schedule_options = {"disable_auto_scheduling": True} + +PARAMS = { + "field_delimiter": ",", + "max_bad_records": "0", + "skip_leading_rows": "1", + "data_path_template": BUCKET_URI, + "destination_table_name_template": GCP_DTS_BQ_TABLE, + "file_format": "CSV", +} + +TRANSFER_CONFIG = ParseDict( + { + "destination_dataset_id": GCP_DTS_BQ_DATASET, + "display_name": "GCS Test Config", + "data_source_id": "google_cloud_storage", + "schedule_options": schedule_options, + "params": PARAMS, + }, + TransferConfig(), +) + +# [END howto_bigquery_dts_create_args] + +default_args = {"start_date": airflow.utils.dates.days_ago(1)} + +with models.DAG( + "example_gcp_bigquery_dts", + default_args=default_args, + schedule_interval=None, # Override to match your needs +) as dag: + # [START howto_bigquery_create_data_transfer] + gcp_bigquery_create_transfer = BigQueryCreateDataTransferOperator( + transfer_config=TRANSFER_CONFIG, + project_id=GCP_PROJECT_ID, + task_id="gcp_bigquery_create_transfer", + ) + + transfer_config_id = ( + "{{ task_instance.xcom_pull('gcp_bigquery_create_transfer', " + "key='transfer_config_id') }}" + ) + # [END howto_bigquery_create_data_transfer] + + # [START howto_bigquery_start_transfer] + gcp_bigquery_start_transfer = BigQueryDataTransferServiceStartTransferRunsOperator( + task_id="gcp_bigquery_start_transfer", + transfer_config_id=transfer_config_id, + requested_run_time={"seconds": int(time.time() + 60)}, + ) + run_id = ( + "{{ task_instance.xcom_pull('gcp_bigquery_start_transfer', " "key='run_id') }}" + ) + # [END howto_bigquery_start_transfer] + + # [START howto_bigquery_dts_sensor] + gcp_run_sensor = BigQueryDataTransferServiceTransferRunSensor( + task_id="gcp_run_sensor", + transfer_config_id=transfer_config_id, + run_id=run_id, + expected_statuses={"SUCCEEDED"}, + ) + # [END howto_bigquery_dts_sensor] + + # [START howto_bigquery_delete_data_transfer] + gcp_bigquery_delete_transfer = BigQueryDeleteDataTransferConfigOperator( + transfer_config_id=transfer_config_id, task_id="gcp_bigquery_delete_transfer" + ) + # [END howto_bigquery_delete_data_transfer] + + ( + gcp_bigquery_create_transfer # noqa + >> gcp_bigquery_start_transfer # noqa + >> gcp_run_sensor # noqa + >> gcp_bigquery_delete_transfer # noqa + ) diff --git a/airflow/gcp/example_dags/example_bigtable.py b/airflow/gcp/example_dags/example_bigtable.py index cecc57c86f39df..e8bbc8e665fab8 100644 --- a/airflow/gcp/example_dags/example_bigtable.py +++ b/airflow/gcp/example_dags/example_bigtable.py 
@@ -1,18 +1,18 @@ # -*- coding: utf-8 -*- -# + # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information # regarding copyright ownership. The ASF licenses this file # to you under the Apache License, Version 2.0 (the -# 'License'); you may not use this file except in compliance +# "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an -# 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. diff --git a/airflow/gcp/example_dags/example_cloud_memorystore.py b/airflow/gcp/example_dags/example_cloud_memorystore.py new file mode 100644 index 00000000000000..45026a6bd26156 --- /dev/null +++ b/airflow/gcp/example_dags/example_cloud_memorystore.py @@ -0,0 +1,226 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Example Airflow DAG for Google Cloud Memorystore service. 
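One detail worth spelling out in this DAG: before the export step, the gcs-set-acl-permission task further down grants the instance's service account access to the export bucket by pulling persistenceIamIdentity from the get-instance XCom and splitting off the "serviceAccount:" prefix. Stripped of Jinja, and with a made-up identity of the shape the Memorystore API returns, the expression reduces to:

    # Placeholder value; real instances report their own persistenceIamIdentity.
    persistence_iam_identity = "serviceAccount:redis-instance@example-project.iam.gserviceaccount.com"

    # Same split as the ``entity`` template of the gcs-set-acl-permission task.
    entity = "user-" + persistence_iam_identity.split(":", 2)[1]
    print(entity)   # user-redis-instance@example-project.iam.gserviceaccount.com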
+""" +import os +from urllib.parse import urlparse + +from google.cloud.redis_v1.gapic.enums import Instance, FailoverInstanceRequest + +from airflow import models +from airflow.gcp.operators.gcs import GoogleCloudStorageBucketCreateAclEntryOperator +from airflow.gcp.operators.cloud_memorystore import ( + CloudMemorystoreCreateInstanceOperator, + CloudMemorystoreDeleteInstanceOperator, + CloudMemorystoreExportInstanceOperator, + CloudMemorystoreFailoverInstanceOperator, + CloudMemorystoreGetInstanceOperator, + CloudMemorystoreImportOperator, + CloudMemorystoreListInstancesOperator, + CloudMemorystoreUpdateInstanceOperator, + CloudMemorystoreCreateInstanceAndImportOperator, + CloudMemorystoreScaleInstanceOperator, + CloudMemorystoreExportAndDeleteInstanceOperator) +from airflow.operators.bash_operator import BashOperator +from airflow.utils import dates + + +GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project") + +INSTANCE_NAME = os.environ.get("GCP_MEMORYSTORE_INSTANCE_NAME", "test-memorystore") +INSTANCE_NAME_2 = os.environ.get("GCP_MEMORYSTORE_INSTANCE_NAME2", "test-memorystore-2") +INSTANCE_NAME_3 = os.environ.get("GCP_MEMORYSTORE_INSTANCE_NAME3", "test-memorystore-3") + +EXPORT_GCS_URL = os.environ.get("GCP_MEMORYSTORE_EXPORT_GCS_URL", "gs://test-memorystore/my-export.rdb") +EXPORT_GCS_URL_PARTS = urlparse(EXPORT_GCS_URL) +BUCKET_NAME = EXPORT_GCS_URL_PARTS.netloc + +# [START howto_operator_instance] +FIRST_INSTANCE = {"tier": Instance.Tier.BASIC, "memory_size_gb": 1} +# [END howto_operator_instance] + +SECOND_INSTANCE = {"tier": Instance.Tier.STANDARD_HA, "memory_size_gb": 3} + + +default_args = {"start_date": dates.days_ago(1)} + +with models.DAG( + "gcp_cloud_memorystore", default_args=default_args, schedule_interval=None # Override to match your needs +) as dag: + # [START howto_operator_create_instance] + create_instance = CloudMemorystoreCreateInstanceOperator( + task_id="create-instance", + location="europe-north1", + instance_id=INSTANCE_NAME, + instance=FIRST_INSTANCE, + project_id=GCP_PROJECT_ID, + ) + # [END howto_operator_create_instance] + + # [START howto_operator_create_instance_result] + create_instance_result = BashOperator( + task_id="create-instance-result", + bash_command="echo \"{{ task_instance.xcom_pull('create-instance') }}\"", + ) + # [END howto_operator_create_instance_result] + + create_instance_2 = CloudMemorystoreCreateInstanceOperator( + task_id="create-instance-2", + location="europe-north1", + instance_id=INSTANCE_NAME_2, + instance=SECOND_INSTANCE, + project_id=GCP_PROJECT_ID, + ) + + # [START howto_operator_get_instance] + get_instance = CloudMemorystoreGetInstanceOperator( + task_id="get-instance", location="europe-north1", instance=INSTANCE_NAME, project_id=GCP_PROJECT_ID + ) + # [END howto_operator_get_instance] + + # [START howto_operator_get_instance_result] + get_instance_result = BashOperator( + task_id="get-instance-result", bash_command="echo \"{{ task_instance.xcom_pull('get-instance') }}\"" + ) + # [END howto_operator_get_instance_result] + + # [START howto_operator_failover_instance] + failover_instance = CloudMemorystoreFailoverInstanceOperator( + task_id="failover-instance", + location="europe-north1", + instance=INSTANCE_NAME_2, + data_protection_mode=FailoverInstanceRequest.DataProtectionMode.LIMITED_DATA_LOSS, + project_id=GCP_PROJECT_ID, + ) + # [END howto_operator_failover_instance] + + # [START howto_operator_list_instances] + list_instances = CloudMemorystoreListInstancesOperator( + task_id="list-instances", 
location="-", page_size=100, project_id=GCP_PROJECT_ID + ) + # [END howto_operator_list_instances] + + # [START howto_operator_list_instances_result] + list_instances_result = BashOperator( + task_id="list-instances-result", bash_command="echo \"{{ task_instance.xcom_pull('get-instance') }}\"" + ) + # [END howto_operator_list_instances_result] + + # [START howto_operator_update_instance] + update_instance = CloudMemorystoreUpdateInstanceOperator( + task_id="update-instance", + location="europe-north1", + instance_id=INSTANCE_NAME, + project_id=GCP_PROJECT_ID, + update_mask={"paths": ["memory_size_gb"]}, + instance={"memory_size_gb": 2}, + ) + # [END howto_operator_update_instance] + + # [START howto_operator_set_acl_permission] + set_acl_permission = GoogleCloudStorageBucketCreateAclEntryOperator( + task_id="gcs-set-acl-permission", + bucket=BUCKET_NAME, + entity="user-{{ task_instance.xcom_pull('get-instance')['persistenceIamIdentity']" + ".split(':', 2)[1] }}", + role="OWNER", + ) + # [END howto_operator_set_acl_permission] + + # [START howto_operator_export_instance] + export_instance = CloudMemorystoreExportInstanceOperator( + task_id="export-instance", + location="europe-north1", + instance=INSTANCE_NAME, + output_config={"gcs_destination": {"uri": EXPORT_GCS_URL}}, + project_id=GCP_PROJECT_ID, + ) + # [END howto_operator_export_instance] + + # [START howto_operator_import_instance] + import_instance = CloudMemorystoreImportOperator( + task_id="import-instance", + location="europe-north1", + instance=INSTANCE_NAME_2, + input_config={"gcs_source": {"uri": EXPORT_GCS_URL}}, + project_id=GCP_PROJECT_ID, + ) + # [END howto_operator_import_instance] + + # [START howto_operator_delete_instance] + delete_instance = CloudMemorystoreDeleteInstanceOperator( + task_id="delete-instance", location="europe-north1", instance=INSTANCE_NAME, project_id=GCP_PROJECT_ID + ) + # [END howto_operator_delete_instance] + + delete_instance_2 = CloudMemorystoreDeleteInstanceOperator( + task_id="delete-instance-2", + location="europe-north1", + instance=INSTANCE_NAME_2, + project_id=GCP_PROJECT_ID, + ) + + # [END howto_operator_create_instance_and_import] + create_instance_and_import = CloudMemorystoreCreateInstanceAndImportOperator( + task_id="create-instance-and-import", + location="europe-north1", + instance_id=INSTANCE_NAME_3, + instance=FIRST_INSTANCE, + input_config={"gcs_source": {"uri": EXPORT_GCS_URL}}, + project_id=GCP_PROJECT_ID, + ) + # [START howto_operator_create_instance_and_import] + + # [START howto_operator_scale_instance] + scale_instance = CloudMemorystoreScaleInstanceOperator( + task_id="scale-instance", + location="europe-north1", + instance_id=INSTANCE_NAME_3, + project_id=GCP_PROJECT_ID, + memory_size_gb=3, + ) + # [END howto_operator_scale_instance] + + # [END howto_operator_export_and_delete_instance] + export_and_delete_instance = CloudMemorystoreExportAndDeleteInstanceOperator( + task_id="export-and-delete-instance", + location="europe-north1", + instance=INSTANCE_NAME_3, + output_config={"gcs_destination": {"uri": EXPORT_GCS_URL}}, + project_id=GCP_PROJECT_ID, + ) + # [START howto_operator_export_and_delete_instance] + + create_instance >> get_instance >> get_instance_result + create_instance >> update_instance + create_instance >> create_instance_result + create_instance >> export_instance + create_instance_2 >> import_instance + create_instance >> list_instances >> list_instances_result + list_instances >> delete_instance + update_instance >> delete_instance + get_instance >> 
set_acl_permission >> export_instance + export_instance >> import_instance + export_instance >> delete_instance + import_instance >> delete_instance_2 + create_instance_2 >> failover_instance + failover_instance >> delete_instance_2 + + export_instance >> create_instance_and_import >> scale_instance >> export_and_delete_instance diff --git a/airflow/gcp/example_dags/example_cloud_sql.py b/airflow/gcp/example_dags/example_cloud_sql.py index 1f0d4feb8dc130..f778bade0b1f5f 100644 --- a/airflow/gcp/example_dags/example_cloud_sql.py +++ b/airflow/gcp/example_dags/example_cloud_sql.py @@ -38,7 +38,7 @@ CloudSqlInstanceDatabaseCreateOperator, CloudSqlInstanceDatabasePatchOperator, \ CloudSqlInstanceDatabaseDeleteOperator, CloudSqlInstanceExportOperator, \ CloudSqlInstanceImportOperator -from airflow.contrib.operators.gcs_acl_operator import \ +from airflow.gcp.operators.gcs import \ GoogleCloudStorageBucketCreateAclEntryOperator, \ GoogleCloudStorageObjectCreateAclEntryOperator diff --git a/airflow/gcp/example_dags/example_dataflow.py b/airflow/gcp/example_dags/example_dataflow.py new file mode 100644 index 00000000000000..a4e0c5e3c63ebc --- /dev/null +++ b/airflow/gcp/example_dags/example_dataflow.py @@ -0,0 +1,85 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +""" +Example Airflow DAG for Google Cloud Dataflow service +""" +import os + +import airflow +from airflow import models +from airflow.gcp.operators.dataflow import DataFlowJavaOperator, CheckJobRunning, DataFlowPythonOperator, \ + DataflowTemplateOperator + +GCP_PROJECT_ID = os.environ.get('GCP_PROJECT_ID', 'example-project') +GCS_TMP = os.environ.get('GCP_DATAFLOW_GCS_TMP', 'gs://test-dataflow-example/temp/') +GCS_STAGING = os.environ.get('GCP_DATAFLOW_GCS_STAGING', 'gs://test-dataflow-example/staging/') +GCS_OUTPUT = os.environ.get('GCP_DATAFLOW_GCS_OUTPUT', 'gs://test-dataflow-example/output') +GCS_JAR = os.environ.get('GCP_DATAFLOW_JAR', 'gs://test-dataflow-example/word-count-beam-bundled-0.1.jar') + +default_args = { + "start_date": airflow.utils.dates.days_ago(1), + 'dataflow_default_options': { + 'project': GCP_PROJECT_ID, + 'tempLocation': GCS_TMP, + 'stagingLocation': GCS_STAGING, + } +} + +with models.DAG( + "example_gcp_dataflow", + default_args=default_args, + schedule_interval=None, # Override to match your needs +) as dag: + + # [START howto_operator_start_java_job] + start_java_job = DataFlowJavaOperator( + task_id="start-java-job", + jar=GCS_JAR, + job_name='{{task.task_id}}22222255sss{{ macros.uuid.uuid4() }}', + options={ + 'output': GCS_OUTPUT, + }, + poll_sleep=10, + job_class='org.apache.beam.examples.WordCount', + check_if_running=CheckJobRunning.WaitForRun, + ) + # [END howto_operator_start_java_job] + + # [START howto_operator_start_python_job] + start_python_job = DataFlowPythonOperator( + task_id="start-python-job", + py_file='apache_beam.examples.wordcount', + py_options=['-m'], + job_name='{{task.task_id}}', + options={ + 'output': GCS_OUTPUT, + }, + check_if_running=CheckJobRunning.WaitForRun, + ) + # [END howto_operator_start_python_job] + + start_template_job = DataflowTemplateOperator( + task_id="start-template-job", + template='gs://dataflow-templates/latest/Word_Count', + parameters={ + 'inputFile': "gs://dataflow-samples/shakespeare/kinglear.txt", + 'output': GCS_OUTPUT + }, + ) diff --git a/airflow/gcp/example_dags/example_dataproc.py b/airflow/gcp/example_dags/example_dataproc.py new file mode 100644 index 00000000000000..9a5ee48210fba7 --- /dev/null +++ b/airflow/gcp/example_dags/example_dataproc.py @@ -0,0 +1,127 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Example Airflow DAG that show how to use various Dataproc +operators to manage a cluster and submit jobs. 
+""" + +import os +import airflow +from airflow import models +from airflow.gcp.operators.dataproc import ( + DataprocClusterCreateOperator, + DataprocClusterDeleteOperator, + DataprocClusterScaleOperator, + DataProcSparkSqlOperator, + DataProcSparkOperator, + DataProcPySparkOperator, + DataProcPigOperator, + DataProcHiveOperator, + DataProcHadoopOperator, +) + +PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "an-id") +CLUSTER_NAME = os.environ.get("GCP_DATAPROC_CLUSTER_NAME", "example-project") +REGION = os.environ.get("GCP_LOCATION", "europe-west1") +ZONE = os.environ.get("GCP_REGION", "europe-west-1b") +BUCKET = os.environ.get("GCP_DATAPROC_BUCKET", "dataproc-system-tests") +OUTPUT_FOLDER = "wordcount" +OUTPUT_PATH = "gs://{}/{}/".format(BUCKET, OUTPUT_FOLDER) +PYSPARK_MAIN = os.environ.get("PYSPARK_MAIN", "hello_world.py") +PYSPARK_URI = "gs://{}/{}".format(BUCKET, PYSPARK_MAIN) + +with models.DAG( + "example_gcp_dataproc", + default_args={"start_date": airflow.utils.dates.days_ago(1)}, + schedule_interval=None, +) as dag: + create_cluster = DataprocClusterCreateOperator( + task_id="create_cluster", + cluster_name=CLUSTER_NAME, + project_id=PROJECT_ID, + num_workers=2, + region=REGION, + ) + + scale_cluster = DataprocClusterScaleOperator( + task_id="scale_cluster", + num_workers=3, + cluster_name=CLUSTER_NAME, + project_id=PROJECT_ID, + region=REGION, + ) + + pig_task = DataProcPigOperator( + task_id="pig_task", + query="define sin HiveUDF('sin');", + region=REGION, + cluster_name=CLUSTER_NAME, + ) + + spark_sql_task = DataProcSparkSqlOperator( + task_id="spark_sql_task", + query="SHOW DATABASES;", + region=REGION, + cluster_name=CLUSTER_NAME, + ) + + spark_task = DataProcSparkOperator( + task_id="spark_task", + main_class="org.apache.spark.examples.SparkPi", + dataproc_jars="file:///usr/lib/spark/examples/jars/spark-examples.jar", + region=REGION, + cluster_name=CLUSTER_NAME, + ) + + pyspark_task = DataProcPySparkOperator( + task_id="pyspark_task", + main=PYSPARK_URI, + region=REGION, + cluster_name=CLUSTER_NAME, + ) + + hive_task = DataProcHiveOperator( + task_id="hive_task", + query="SHOW DATABASES;", + region=REGION, + cluster_name=CLUSTER_NAME, + ) + + hadoop_task = DataProcHadoopOperator( + task_id="hadoop_task", + main_jar="file:///usr/lib/hadoop-mapreduce/hadoop-mapreduce-examples.jar", + arguments=["wordcount", "gs://pub/shakespeare/rose.txt", OUTPUT_PATH], + region=REGION, + cluster_name=CLUSTER_NAME, + ) + + delete_cluster = DataprocClusterDeleteOperator( + task_id="delete_cluster", + project_id=PROJECT_ID, + cluster_name=CLUSTER_NAME, + region=REGION, + ) + + create_cluster >> scale_cluster + scale_cluster >> hive_task >> delete_cluster + scale_cluster >> pig_task >> delete_cluster + scale_cluster >> spark_sql_task >> delete_cluster + scale_cluster >> spark_task >> delete_cluster + scale_cluster >> pyspark_task >> delete_cluster + scale_cluster >> hadoop_task >> delete_cluster diff --git a/airflow/gcp/example_dags/example_datastore.py b/airflow/gcp/example_dags/example_datastore.py new file mode 100644 index 00000000000000..911bbc5fb01d5f --- /dev/null +++ b/airflow/gcp/example_dags/example_datastore.py @@ -0,0 +1,59 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Example Airflow DAG that shows how to use Datastore operators. + +This example requires that your project contains Datastore instance. +""" + +import os + +from airflow import models +from airflow.utils import dates +from airflow.gcp.operators.datastore import ( + DatastoreImportOperator, + DatastoreExportOperator, +) + +GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project") +BUCKET = os.environ.get("GCP_DATASTORE_BUCKET", "datastore-system-test") + +default_args = {"start_date": dates.days_ago(1)} + +with models.DAG( + "example_gcp_datastore", + default_args=default_args, + schedule_interval=None, # Override to match your needs +) as dag: + export_task = DatastoreExportOperator( + task_id="export_task", + bucket=BUCKET, + project_id=GCP_PROJECT_ID, + overwrite_existing=True, + ) + + bucket = "{{ task_instance.xcom_pull('export_task')['response']['outputUrl'].split('/')[2] }}" + file = "{{ '/'.join(task_instance.xcom_pull('export_task')['response']['outputUrl'].split('/')[3:]) }}" + + import_task = DatastoreImportOperator( + task_id="import_task", bucket=bucket, file=file, project_id=GCP_PROJECT_ID + ) + + export_task >> import_task diff --git a/airflow/gcp/example_dags/example_gcp_dataproc_create_cluster.py b/airflow/gcp/example_dags/example_gcp_dataproc_create_cluster.py deleted file mode 100644 index 07e201becb0293..00000000000000 --- a/airflow/gcp/example_dags/example_gcp_dataproc_create_cluster.py +++ /dev/null @@ -1,57 +0,0 @@ -# -*- coding: utf-8 -*- -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -""" -Example Airflow DAG that creates DataProc cluster. 
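Returning to the Datastore example above: the two Jinja expressions take the export's outputUrl apart into the bucket and the object path that the import operator needs. Stripped of templating, and with a made-up outputUrl of the shape a Datastore export returns, the same split is:

    # Hypothetical outputUrl; the real value comes from the export_task XCom.
    output_url = ("gs://datastore-system-test/2019-08-01T00:00:00_12345/"
                  "2019-08-01T00:00:00_12345.overall_export_metadata")

    bucket = output_url.split("/")[2]             # "datastore-system-test"
    file = "/".join(output_url.split("/")[3:])    # "2019-08-01T00:00:00_12345/..."

    print(bucket)
    print(file)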
-""" - -import os -import airflow -from airflow import models -from airflow.gcp.operators.dataproc import ( - DataprocClusterCreateOperator, - DataprocClusterDeleteOperator -) - -PROJECT_ID = os.environ.get('GCP_PROJECT_ID', 'an-id') -CLUSTER_NAME = os.environ.get('GCP_DATAPROC_CLUSTER_NAME', 'example-project') -REGION = os.environ.get('GCP_LOCATION', 'europe-west1') -ZONE = os.environ.get('GCP_REGION', 'europe-west-1b') - -with models.DAG( - "example_gcp_dataproc_create_cluster", - default_args={"start_date": airflow.utils.dates.days_ago(1)}, - schedule_interval=None, -) as dag: - create_cluster = DataprocClusterCreateOperator( - task_id="create_cluster", - cluster_name=CLUSTER_NAME, - project_id=PROJECT_ID, - num_workers=2, - region=REGION, - ) - - delete_cluster = DataprocClusterDeleteOperator( - task_id="delete_cluster", - project_id=PROJECT_ID, - cluster_name=CLUSTER_NAME, - region=REGION - ) - - create_cluster >> delete_cluster # pylint: disable=pointless-statement diff --git a/airflow/gcp/example_dags/example_gcp_dataproc_pig_operator.py b/airflow/gcp/example_dags/example_gcp_dataproc_pig_operator.py deleted file mode 100644 index 709608472b45a2..00000000000000 --- a/airflow/gcp/example_dags/example_gcp_dataproc_pig_operator.py +++ /dev/null @@ -1,66 +0,0 @@ -# -*- coding: utf-8 -*- -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-""" -Example Airflow DAG for Google Dataproc PigOperator -""" - -import os -import airflow -from airflow import models -from airflow.gcp.operators.dataproc import ( - DataProcPigOperator, - DataprocClusterCreateOperator, - DataprocClusterDeleteOperator -) - -default_args = {"start_date": airflow.utils.dates.days_ago(1)} - -CLUSTER_NAME = os.environ.get('GCP_DATAPROC_CLUSTER_NAME', 'example-project') -PROJECT_ID = os.environ.get('GCP_PROJECT_ID', 'an-id') -REGION = os.environ.get('GCP_LOCATION', 'europe-west1') - - -with models.DAG( - "example_gcp_dataproc_pig_operator", - default_args=default_args, - schedule_interval=None, -) as dag: - create_task = DataprocClusterCreateOperator( - task_id="create_task", - cluster_name=CLUSTER_NAME, - project_id=PROJECT_ID, - region=REGION, - num_workers=2 - ) - - pig_task = DataProcPigOperator( - task_id="pig_task", - query="define sin HiveUDF('sin');", - region=REGION, - cluster_name=CLUSTER_NAME - ) - - delete_task = DataprocClusterDeleteOperator( - task_id="delete_task", - project_id=PROJECT_ID, - cluster_name=CLUSTER_NAME, - region=REGION - ) - - create_task >> pig_task >> delete_task diff --git a/airflow/gcp/example_dags/example_gcs.py b/airflow/gcp/example_dags/example_gcs.py new file mode 100644 index 00000000000000..8c037320bc5870 --- /dev/null +++ b/airflow/gcp/example_dags/example_gcs.py @@ -0,0 +1,128 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Example Airflow DAG for Google Cloud Storage operators. 
+""" + +import os +import airflow +from airflow import models +from airflow.operators.bash_operator import BashOperator +from airflow.operators.local_to_gcs import FileToGoogleCloudStorageOperator +from airflow.operators.gcs_to_gcs import GoogleCloudStorageToGoogleCloudStorageOperator +from airflow.gcp.operators.gcs import ( + GoogleCloudStorageBucketCreateAclEntryOperator, + GoogleCloudStorageObjectCreateAclEntryOperator, + GoogleCloudStorageListOperator, + GoogleCloudStorageDeleteOperator, + GoogleCloudStorageDownloadOperator, + GoogleCloudStorageCreateBucketOperator +) + +default_args = {"start_date": airflow.utils.dates.days_ago(1)} + +# [START howto_operator_gcs_acl_args_common] +PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-id") +BUCKET_1 = os.environ.get("GCP_GCS_BUCKET_1", "test-gcs-example-bucket") +GCS_ACL_ENTITY = os.environ.get("GCS_ACL_ENTITY", "allUsers") +GCS_ACL_BUCKET_ROLE = "OWNER" +GCS_ACL_OBJECT_ROLE = "OWNER" +# [END howto_operator_gcs_acl_args_common] + +BUCKET_2 = os.environ.get("GCP_GCS_BUCKET_1", "test-gcs-example-bucket-2") + +PATH_TO_UPLOAD_FILE = os.environ.get( + "GCP_GCS_PATH_TO_UPLOAD_FILE", "test-gcs-example.txt" +) +PATH_TO_SAVED_FILE = os.environ.get( + "GCP_GCS_PATH_TO_SAVED_FILE", "test-gcs-example-download.txt" +) + +BUCKET_FILE_LOCATION = PATH_TO_UPLOAD_FILE.rpartition("/")[-1] + + +with models.DAG( + "example_gcs", default_args=default_args, schedule_interval=None +) as dag: + create_bucket1 = GoogleCloudStorageCreateBucketOperator( + task_id="create_bucket1", bucket_name=BUCKET_1, project_id=PROJECT_ID + ) + + create_bucket2 = GoogleCloudStorageCreateBucketOperator( + task_id="create_bucket2", bucket_name=BUCKET_2, project_id=PROJECT_ID + ) + + list_buckets = GoogleCloudStorageListOperator( + task_id="list_buckets", bucket=BUCKET_1 + ) + + list_buckets_result = BashOperator( + task_id="list_buckets_result", + bash_command="echo \"{{ task_instance.xcom_pull('list_buckets') }}\"", + ) + + upload_file = FileToGoogleCloudStorageOperator( + task_id="upload_file", + src=PATH_TO_UPLOAD_FILE, + dst=BUCKET_FILE_LOCATION, + bucket=BUCKET_1, + ) + + # [START howto_operator_gcs_bucket_create_acl_entry_task] + gcs_bucket_create_acl_entry_task = GoogleCloudStorageBucketCreateAclEntryOperator( + bucket=BUCKET_1, + entity=GCS_ACL_ENTITY, + role=GCS_ACL_BUCKET_ROLE, + task_id="gcs_bucket_create_acl_entry_task", + ) + # [END howto_operator_gcs_bucket_create_acl_entry_task] + + # [START howto_operator_gcs_object_create_acl_entry_task] + gcs_object_create_acl_entry_task = GoogleCloudStorageObjectCreateAclEntryOperator( + bucket=BUCKET_1, + object_name=BUCKET_FILE_LOCATION, + entity=GCS_ACL_ENTITY, + role=GCS_ACL_OBJECT_ROLE, + task_id="gcs_object_create_acl_entry_task", + ) + # [END howto_operator_gcs_object_create_acl_entry_task] + + download_file = GoogleCloudStorageDownloadOperator( + task_id="download_file", + object_name=BUCKET_FILE_LOCATION, + bucket=BUCKET_1, + filename=PATH_TO_SAVED_FILE, + ) + + copy_file = GoogleCloudStorageToGoogleCloudStorageOperator( + task_id="copy_file", + source_bucket=BUCKET_1, + source_object=BUCKET_FILE_LOCATION, + destination_bucket=BUCKET_2, + destination_object=BUCKET_FILE_LOCATION, + ) + + delete_files = GoogleCloudStorageDeleteOperator( + task_id="delete_files", bucket_name=BUCKET_1, prefix="" + ) + + [create_bucket1, create_bucket2] >> list_buckets >> list_buckets_result + [create_bucket1, create_bucket2] >> upload_file + upload_file >> [download_file, copy_file] + upload_file >> gcs_bucket_create_acl_entry_task >> 
gcs_object_create_acl_entry_task >> delete_files diff --git a/airflow/gcp/hooks/bigquery.py b/airflow/gcp/hooks/bigquery.py new file mode 100644 index 00000000000000..230c1a7dc163bb --- /dev/null +++ b/airflow/gcp/hooks/bigquery.py @@ -0,0 +1,2335 @@ +# -*- coding: utf-8 -*- # pylint: disable=too-many-lines +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +""" +This module contains a BigQuery Hook, as well as a very basic PEP 249 +implementation for BigQuery. +""" + +import time +from copy import deepcopy +from typing import Any, NoReturn, Mapping, Union, Iterable, Dict, List, Optional, Tuple, Type + +from googleapiclient.discovery import build +from googleapiclient.errors import HttpError +from pandas import DataFrame +from pandas_gbq.gbq import \ + _check_google_client_version as gbq_check_google_client_version +from pandas_gbq import read_gbq +from pandas_gbq.gbq import \ + _test_google_api_imports as gbq_test_google_api_imports +from pandas_gbq.gbq import GbqConnector + +from airflow import AirflowException +from airflow.contrib.hooks.gcp_api_base_hook import GoogleCloudBaseHook +from airflow.hooks.dbapi_hook import DbApiHook +from airflow.utils.log.logging_mixin import LoggingMixin + + +class BigQueryHook(GoogleCloudBaseHook, DbApiHook): + """ + Interact with BigQuery. This hook uses the Google Cloud Platform + connection. + """ + conn_name_attr = 'bigquery_conn_id' # type: str + + def __init__(self, + bigquery_conn_id: str = 'google_cloud_default', + delegate_to: Optional[str] = None, + use_legacy_sql: bool = True, + location: Optional[str] = None) -> None: + super().__init__( + gcp_conn_id=bigquery_conn_id, delegate_to=delegate_to) + self.use_legacy_sql = use_legacy_sql + self.location = location + + def get_conn(self) -> "BigQueryConnection": + """ + Returns a BigQuery PEP 249 connection object. + """ + service = self.get_service() + return BigQueryConnection( + service=service, + project_id=self.project_id, + use_legacy_sql=self.use_legacy_sql, + location=self.location, + num_retries=self.num_retries + ) + + def get_service(self) -> Any: + """ + Returns a BigQuery service object. + """ + http_authorized = self._authorize() + return build( + 'bigquery', 'v2', http=http_authorized, cache_discovery=False) + + def insert_rows( + self, table: Any, rows: Any, target_fields: Any = None, commit_every: Any = 1000, replace: Any = False + ) -> NoReturn: + """ + Insertion is currently unsupported. Theoretically, you could use + BigQuery's streaming API to insert rows into a table, but this hasn't + been implemented. 
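``BigQueryHook`` above only needs a configured Google Cloud Platform connection; everything else hangs off ``get_service()`` (the raw ``googleapiclient`` resource) and ``get_conn()`` (the PEP 249-style connection defined further down). A minimal sketch, where the job location is an assumption::

    from airflow.gcp.hooks.bigquery import BigQueryHook

    hook = BigQueryHook(
        bigquery_conn_id="google_cloud_default",  # default connection id
        use_legacy_sql=False,
        location="EU",                            # hypothetical job location
    )
    service = hook.get_service()  # raw BigQuery API client
    conn = hook.get_conn()        # PEP 249-style connection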
+ """ + raise NotImplementedError() + + def get_pandas_df( + self, sql: str, parameters: Optional[Union[Iterable, Mapping]] = None, dialect: Optional[str] = None + ) -> DataFrame: + """ + Returns a Pandas DataFrame for the results produced by a BigQuery + query. The DbApiHook method must be overridden because Pandas + doesn't support PEP 249 connections, except for SQLite. See: + + https://github.com/pydata/pandas/blob/master/pandas/io/sql.py#L447 + https://github.com/pydata/pandas/issues/6900 + + :param sql: The BigQuery SQL to execute. + :type sql: str + :param parameters: The parameters to render the SQL query with (not + used, leave to override superclass method) + :type parameters: mapping or iterable + :param dialect: Dialect of BigQuery SQL – legacy SQL or standard SQL + defaults to use `self.use_legacy_sql` if not specified + :type dialect: str in {'legacy', 'standard'} + """ + if dialect is None: + dialect = 'legacy' if self.use_legacy_sql else 'standard' + + credentials, project_id = self._get_credentials_and_project_id() + + return read_gbq(sql, + project_id=project_id, + dialect=dialect, + verbose=False, + credentials=credentials) + + def table_exists(self, project_id: str, dataset_id: str, table_id: str) -> bool: + """ + Checks for the existence of a table in Google BigQuery. + + :param project_id: The Google cloud project in which to look for the + table. The connection supplied to the hook must provide access to + the specified project. + :type project_id: str + :param dataset_id: The name of the dataset in which to look for the + table. + :type dataset_id: str + :param table_id: The name of the table to check the existence of. + :type table_id: str + """ + service = self.get_service() + try: + service.tables().get( # pylint: disable=no-member + projectId=project_id, datasetId=dataset_id, + tableId=table_id).execute(num_retries=self.num_retries) + return True + except HttpError as e: + if e.resp['status'] == '404': + return False + raise + + +class BigQueryPandasConnector(GbqConnector): + """ + This connector behaves identically to GbqConnector (from Pandas), except + that it allows the service to be injected, and disables a call to + self.get_credentials(). This allows Airflow to use BigQuery with Pandas + without forcing a three legged OAuth connection. Instead, we can inject + service account credentials into the binding. + """ + + def __init__( + self, project_id: str, service: str, reauth: bool = False, verbose: bool = False, dialect="legacy" + ) -> None: + super().__init__(project_id) + gbq_check_google_client_version() + gbq_test_google_api_imports() + self.project_id = project_id + self.reauth = reauth + self.service = service + self.verbose = verbose + self.dialect = dialect + + +class BigQueryConnection: + """ + BigQuery does not have a notion of a persistent connection. Thus, these + objects are small stateless factories for cursors, which do all the real + work. + """ + + def __init__(self, *args, **kwargs) -> None: + self._args = args + self._kwargs = kwargs + + def close(self) -> None: + """ BigQueryConnection does not have anything to close. """ + + def commit(self) -> None: + """ BigQueryConnection does not support transactions. """ + + def cursor(self) -> "BigQueryCursor": + """ Return a new :py:class:`Cursor` object using the connection. 
""" + return BigQueryCursor(*self._args, **self._kwargs) + + def rollback(self) -> NoReturn: + """ BigQueryConnection does not have transactions """ + raise NotImplementedError( + "BigQueryConnection does not have transactions") + + +class BigQueryBaseCursor(LoggingMixin): + """ + The BigQuery base cursor contains helper methods to execute queries against + BigQuery. The methods can be used directly by operators, in cases where a + PEP 249 cursor isn't needed. + """ + + def __init__(self, + service: Any, + project_id: str, + use_legacy_sql: bool = True, + api_resource_configs: Optional[Dict] = None, + location: Optional[str] = None, + num_retries: int = 5) -> None: + + self.service = service + self.project_id = project_id + self.use_legacy_sql = use_legacy_sql + if api_resource_configs: + _validate_value("api_resource_configs", api_resource_configs, dict) + self.api_resource_configs = api_resource_configs \ + if api_resource_configs else {} # type Dict + self.running_job_id = None # type: Optional[str] + self.location = location + self.num_retries = num_retries + + # pylint: disable=too-many-arguments + def create_empty_table(self, + project_id: str, + dataset_id: str, + table_id: str, + schema_fields: Optional[List] = None, + time_partitioning: Optional[Dict] = None, + cluster_fields: Optional[List] = None, + labels: Optional[Dict] = None, + view: Optional[Dict] = None, + encryption_configuration: Optional[Dict] = None, + num_retries: int = 5) -> None: + """ + Creates a new, empty table in the dataset. + To create a view, which is defined by a SQL query, parse a dictionary to 'view' kwarg + + :param project_id: The project to create the table into. + :type project_id: str + :param dataset_id: The dataset to create the table into. + :type dataset_id: str + :param table_id: The Name of the table to be created. + :type table_id: str + :param schema_fields: If set, the schema field list as defined here: + https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs#configuration.load.schema + :type schema_fields: list + :param labels: a dictionary containing labels for the table, passed to BigQuery + :type labels: dict + + **Example**: :: + + schema_fields=[{"name": "emp_name", "type": "STRING", "mode": "REQUIRED"}, + {"name": "salary", "type": "INTEGER", "mode": "NULLABLE"}] + + :param time_partitioning: configure optional time partitioning fields i.e. + partition by field, type and expiration as per API specifications. + + .. seealso:: + https://cloud.google.com/bigquery/docs/reference/rest/v2/tables#timePartitioning + :type time_partitioning: dict + :param cluster_fields: [Optional] The fields used for clustering. + Must be specified with time_partitioning, data in the table will be first + partitioned and subsequently clustered. + https://cloud.google.com/bigquery/docs/reference/rest/v2/tables#clustering.fields + :type cluster_fields: list + :param view: [Optional] A dictionary containing definition for the view. + If set, it will create a view instead of a table: + https://cloud.google.com/bigquery/docs/reference/rest/v2/tables#ViewDefinition + :type view: dict + + **Example**: :: + + view = { + "query": "SELECT * FROM `test-project-id.test_dataset_id.test_table_prefix*` LIMIT 1000", + "useLegacySql": False + } + + :param encryption_configuration: [Optional] Custom encryption configuration (e.g., Cloud KMS keys). 
+ **Example**: :: + + encryption_configuration = { + "kmsKeyName": "projects/testp/locations/us/keyRings/test-kr/cryptoKeys/test-key" + } + :type encryption_configuration: dict + :return: None + """ + + project_id = project_id if project_id is not None else self.project_id + + table_resource = { + 'tableReference': { + 'tableId': table_id + } + } # type: Dict[str, Any] + + if self.location: + table_resource['location'] = self.location + + if schema_fields: + table_resource['schema'] = {'fields': schema_fields} + + if time_partitioning: + table_resource['timePartitioning'] = time_partitioning + + if cluster_fields: + table_resource['clustering'] = { + 'fields': cluster_fields + } + + if labels: + table_resource['labels'] = labels + + if view: + table_resource['view'] = view + + if encryption_configuration: + table_resource["encryptionConfiguration"] = encryption_configuration + + num_retries = num_retries if num_retries else self.num_retries + + self.log.info('Creating Table %s:%s.%s', + project_id, dataset_id, table_id) + + try: + self.service.tables().insert( + projectId=project_id, + datasetId=dataset_id, + body=table_resource).execute(num_retries=num_retries) + + self.log.info('Table created successfully: %s:%s.%s', + project_id, dataset_id, table_id) + + except HttpError as err: + raise AirflowException( + 'BigQuery job failed. Error was: {}'.format(err.content) + ) + + def create_external_table(self, # pylint: disable=too-many-locals,too-many-arguments + external_project_dataset_table: str, + schema_fields: List, + source_uris: List, + source_format: str = 'CSV', + autodetect: bool = False, + compression: str = 'NONE', + ignore_unknown_values: bool = False, + max_bad_records: int = 0, + skip_leading_rows: int = 0, + field_delimiter: str = ',', + quote_character: Optional[str] = None, + allow_quoted_newlines: bool = False, + allow_jagged_rows: bool = False, + src_fmt_configs: Optional[Dict] = None, + labels: Optional[Dict] = None, + encryption_configuration: Optional[Dict] = None + ) -> None: + """ + Creates a new external table in the dataset with the data in Google + Cloud Storage. See here: + + https://cloud.google.com/bigquery/docs/reference/rest/v2/tables#resource + + for more details about these parameters. + + :param external_project_dataset_table: + The dotted ``(.|:).
($)`` BigQuery + table name to create external table. + If ```` is not included, project will be the + project defined in the connection json. + :type external_project_dataset_table: str + :param schema_fields: The schema field list as defined here: + https://cloud.google.com/bigquery/docs/reference/rest/v2/tables#resource + :type schema_fields: list + :param source_uris: The source Google Cloud + Storage URI (e.g. gs://some-bucket/some-file.txt). A single wild + per-object name can be used. + :type source_uris: list + :param source_format: File format to export. + :type source_format: str + :param autodetect: Try to detect schema and format options automatically. + Any option specified explicitly will be honored. + :type autodetect: bool + :param compression: [Optional] The compression type of the data source. + Possible values include GZIP and NONE. + The default value is NONE. + This setting is ignored for Google Cloud Bigtable, + Google Cloud Datastore backups and Avro formats. + :type compression: str + :param ignore_unknown_values: [Optional] Indicates if BigQuery should allow + extra values that are not represented in the table schema. + If true, the extra values are ignored. If false, records with extra columns + are treated as bad records, and if there are too many bad records, an + invalid error is returned in the job result. + :type ignore_unknown_values: bool + :param max_bad_records: The maximum number of bad records that BigQuery can + ignore when running the job. + :type max_bad_records: int + :param skip_leading_rows: Number of rows to skip when loading from a CSV. + :type skip_leading_rows: int + :param field_delimiter: The delimiter to use when loading from a CSV. + :type field_delimiter: str + :param quote_character: The value that is used to quote data sections in a CSV + file. + :type quote_character: str + :param allow_quoted_newlines: Whether to allow quoted newlines (true) or not + (false). + :type allow_quoted_newlines: bool + :param allow_jagged_rows: Accept rows that are missing trailing optional columns. + The missing values are treated as nulls. If false, records with missing + trailing columns are treated as bad records, and if there are too many bad + records, an invalid error is returned in the job result. Only applicable when + soure_format is CSV. + :type allow_jagged_rows: bool + :param src_fmt_configs: configure optional fields specific to the source format + :type src_fmt_configs: dict + :param labels: a dictionary containing labels for the table, passed to BigQuery + :type labels: dict + :param encryption_configuration: [Optional] Custom encryption configuration (e.g., Cloud KMS keys). 
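A sketch of how ``create_external_table``, documented above, might be called for CSV files sitting in Cloud Storage, reusing the ``cursor`` from the earlier sketch (bucket, table and schema are assumptions)::

    cursor.create_external_table(
        external_project_dataset_table="my-project.my_dataset.my_external_table",
        schema_fields=[
            {"name": "name", "type": "STRING", "mode": "NULLABLE"},
            {"name": "value", "type": "INTEGER", "mode": "NULLABLE"},
        ],
        source_uris=["gs://my-bucket/data/*.csv"],
        source_format="CSV",
        skip_leading_rows=1,
        field_delimiter=",",
    )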
+ **Example**: :: + + encryption_configuration = { + "kmsKeyName": "projects/testp/locations/us/keyRings/test-kr/cryptoKeys/test-key" + } + :type encryption_configuration: dict + """ + + if src_fmt_configs is None: + src_fmt_configs = {} + project_id, dataset_id, external_table_id = \ + _split_tablename(table_input=external_project_dataset_table, + default_project_id=self.project_id, + var_name='external_project_dataset_table') + + # bigquery only allows certain source formats + # we check to make sure the passed source format is valid + # if it's not, we raise a ValueError + # Refer to this link for more details: + # https://cloud.google.com/bigquery/docs/reference/rest/v2/tables#externalDataConfiguration.sourceFormat # noqa # pylint: disable=line-too-long + + source_format = source_format.upper() + allowed_formats = [ + "CSV", "NEWLINE_DELIMITED_JSON", "AVRO", "GOOGLE_SHEETS", + "DATASTORE_BACKUP", "PARQUET" + ] # type: List[str] + if source_format not in allowed_formats: + raise ValueError("{0} is not a valid source format. " + "Please use one of the following types: {1}" + .format(source_format, allowed_formats)) + + compression = compression.upper() + allowed_compressions = ['NONE', 'GZIP'] # type: List[str] + if compression not in allowed_compressions: + raise ValueError("{0} is not a valid compression format. " + "Please use one of the following types: {1}" + .format(compression, allowed_compressions)) + + table_resource = { + 'externalDataConfiguration': { + 'autodetect': autodetect, + 'sourceFormat': source_format, + 'sourceUris': source_uris, + 'compression': compression, + 'ignoreUnknownValues': ignore_unknown_values + }, + 'tableReference': { + 'projectId': project_id, + 'datasetId': dataset_id, + 'tableId': external_table_id, + } + } # type: Dict[str, Any] + + if self.location: + table_resource['location'] = self.location + + if schema_fields: + table_resource['externalDataConfiguration'].update({ + 'schema': { + 'fields': schema_fields + } + }) + + self.log.info('Creating external table: %s', external_project_dataset_table) + + if max_bad_records: + table_resource['externalDataConfiguration']['maxBadRecords'] = max_bad_records + + # if following fields are not specified in src_fmt_configs, + # honor the top-level params for backward-compatibility + backward_compatibility_configs = {'skipLeadingRows': skip_leading_rows, + 'fieldDelimiter': field_delimiter, + 'quote': quote_character, + 'allowQuotedNewlines': allow_quoted_newlines, + 'allowJaggedRows': allow_jagged_rows} + + src_fmt_to_param_mapping = { + 'CSV': 'csvOptions', + 'GOOGLE_SHEETS': 'googleSheetsOptions' + } + + src_fmt_to_configs_mapping = { + 'csvOptions': [ + 'allowJaggedRows', 'allowQuotedNewlines', + 'fieldDelimiter', 'skipLeadingRows', + 'quote' + ], + 'googleSheetsOptions': ['skipLeadingRows'] + } + + if source_format in src_fmt_to_param_mapping.keys(): + valid_configs = src_fmt_to_configs_mapping[ + src_fmt_to_param_mapping[source_format] + ] + + src_fmt_configs = _validate_src_fmt_configs(source_format, src_fmt_configs, valid_configs, + backward_compatibility_configs) + + table_resource['externalDataConfiguration'][src_fmt_to_param_mapping[ + source_format]] = src_fmt_configs + + if labels: + table_resource['labels'] = labels + + if encryption_configuration: + table_resource["encryptionConfiguration"] = encryption_configuration + + try: + self.service.tables().insert( + projectId=project_id, + datasetId=dataset_id, + body=table_resource + ).execute(num_retries=self.num_retries) + + self.log.info('External 
table created successfully: %s', + external_project_dataset_table) + + except HttpError as err: + raise Exception( + 'BigQuery job failed. Error was: {}'.format(err.content) + ) + + def patch_table(self, # pylint: disable=too-many-arguments + dataset_id: str, + table_id: str, + project_id: Optional[str] = None, + description: Optional[str] = None, + expiration_time: Optional[int] = None, + external_data_configuration: Optional[Dict] = None, + friendly_name: Optional[str] = None, + labels: Optional[Dict] = None, + schema: Optional[List] = None, + time_partitioning: Optional[Dict] = None, + view: Optional[Dict] = None, + require_partition_filter: Optional[bool] = None, + encryption_configuration: Optional[Dict] = None) -> None: + """ + Patch information in an existing table. + It only updates fileds that are provided in the request object. + + Reference: https://cloud.google.com/bigquery/docs/reference/rest/v2/tables/patch + + :param dataset_id: The dataset containing the table to be patched. + :type dataset_id: str + :param table_id: The Name of the table to be patched. + :type table_id: str + :param project_id: The project containing the table to be patched. + :type project_id: str + :param description: [Optional] A user-friendly description of this table. + :type description: str + :param expiration_time: [Optional] The time when this table expires, + in milliseconds since the epoch. + :type expiration_time: int + :param external_data_configuration: [Optional] A dictionary containing + properties of a table stored outside of BigQuery. + :type external_data_configuration: dict + :param friendly_name: [Optional] A descriptive name for this table. + :type friendly_name: str + :param labels: [Optional] A dictionary containing labels associated with this table. + :type labels: dict + :param schema: [Optional] If set, the schema field list as defined here: + https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs#configuration.load.schema + The supported schema modifications and unsupported schema modification are listed here: + https://cloud.google.com/bigquery/docs/managing-table-schemas + **Example**: :: + + schema=[{"name": "emp_name", "type": "STRING", "mode": "REQUIRED"}, + {"name": "salary", "type": "INTEGER", "mode": "NULLABLE"}] + + :type schema: list + :param time_partitioning: [Optional] A dictionary containing time-based partitioning + definition for the table. + :type time_partitioning: dict + :param view: [Optional] A dictionary containing definition for the view. + If set, it will patch a view instead of a table: + https://cloud.google.com/bigquery/docs/reference/rest/v2/tables#ViewDefinition + **Example**: :: + + view = { + "query": "SELECT * FROM `test-project-id.test_dataset_id.test_table_prefix*` LIMIT 500", + "useLegacySql": False + } + + :type view: dict + :param require_partition_filter: [Optional] If true, queries over the this table require a + partition filter. If false, queries over the table + :type require_partition_filter: bool + :param encryption_configuration: [Optional] Custom encryption configuration (e.g., Cloud KMS keys). 
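``patch_table`` above only sends the fields that are passed in, so a typical call updates one or two properties of an existing table; a sketch with hypothetical identifiers and values::

    cursor.patch_table(
        dataset_id="my_dataset",
        table_id="my_table",
        project_id="my-project",
        description="Nightly load of the orders feed",  # hypothetical
        expiration_time=1735689600000,                   # ms since epoch, hypothetical
    )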
+ **Example**: :: + + encryption_configuration = { + "kmsKeyName": "projects/testp/locations/us/keyRings/test-kr/cryptoKeys/test-key" + } + :type encryption_configuration: dict + + """ + + project_id = project_id if project_id is not None else self.project_id + + table_resource = {} # type: Dict[str, Any] + + if description is not None: + table_resource['description'] = description + if expiration_time is not None: + table_resource['expirationTime'] = expiration_time + if external_data_configuration: + table_resource['externalDataConfiguration'] = external_data_configuration + if friendly_name is not None: + table_resource['friendlyName'] = friendly_name + if labels: + table_resource['labels'] = labels + if schema: + table_resource['schema'] = {'fields': schema} + if time_partitioning: + table_resource['timePartitioning'] = time_partitioning + if view: + table_resource['view'] = view + if require_partition_filter is not None: + table_resource['requirePartitionFilter'] = require_partition_filter + if encryption_configuration: + table_resource["encryptionConfiguration"] = encryption_configuration + + self.log.info('Patching Table %s:%s.%s', + project_id, dataset_id, table_id) + + try: + self.service.tables().patch( + projectId=project_id, + datasetId=dataset_id, + tableId=table_id, + body=table_resource).execute(num_retries=self.num_retries) + + self.log.info('Table patched successfully: %s:%s.%s', + project_id, dataset_id, table_id) + + except HttpError as err: + raise AirflowException( + 'BigQuery job failed. Error was: {}'.format(err.content) + ) + + # pylint: disable=too-many-locals,too-many-arguments, too-many-branches + def run_query(self, + sql: str, + destination_dataset_table: Optional[str] = None, + write_disposition: str = 'WRITE_EMPTY', + allow_large_results: bool = False, + flatten_results: Optional[bool] = None, + udf_config: Optional[List] = None, + use_legacy_sql: Optional[bool] = None, + maximum_billing_tier: Optional[int] = None, + maximum_bytes_billed: Optional[float] = None, + create_disposition: str = 'CREATE_IF_NEEDED', + query_params: Optional[List] = None, + labels: Optional[Dict] = None, + schema_update_options: Optional[Iterable] = None, + priority: str = 'INTERACTIVE', + time_partitioning: Optional[Dict] = None, + api_resource_configs: Optional[Dict] = None, + cluster_fields: Optional[List[str]] = None, + location: Optional[str] = None, + encryption_configuration: Optional[Dict] = None) -> str: + """ + Executes a BigQuery SQL query. Optionally persists results in a BigQuery + table. See here: + + https://cloud.google.com/bigquery/docs/reference/v2/jobs + + For more details about these parameters. + + :param sql: The BigQuery SQL to execute. + :type sql: str + :param destination_dataset_table: The dotted ``.
`` + BigQuery table to save the query results. + :type destination_dataset_table: str + :param write_disposition: What to do if the table already exists in + BigQuery. + :type write_disposition: str + :param allow_large_results: Whether to allow large results. + :type allow_large_results: bool + :param flatten_results: If true and query uses legacy SQL dialect, flattens + all nested and repeated fields in the query results. ``allowLargeResults`` + must be true if this is set to false. For standard SQL queries, this + flag is ignored and results are never flattened. + :type flatten_results: bool + :param udf_config: The User Defined Function configuration for the query. + See https://cloud.google.com/bigquery/user-defined-functions for details. + :type udf_config: list + :param use_legacy_sql: Whether to use legacy SQL (true) or standard SQL (false). + If `None`, defaults to `self.use_legacy_sql`. + :type use_legacy_sql: bool + :param api_resource_configs: a dictionary that contain params + 'configuration' applied for Google BigQuery Jobs API: + https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs + for example, {'query': {'useQueryCache': False}}. You could use it + if you need to provide some params that are not supported by the + BigQueryHook like args. + :type api_resource_configs: dict + :param maximum_billing_tier: Positive integer that serves as a + multiplier of the basic price. + :type maximum_billing_tier: int + :param maximum_bytes_billed: Limits the bytes billed for this job. + Queries that will have bytes billed beyond this limit will fail + (without incurring a charge). If unspecified, this will be + set to your project default. + :type maximum_bytes_billed: float + :param create_disposition: Specifies whether the job is allowed to + create new tables. + :type create_disposition: str + :param query_params: a list of dictionary containing query parameter types and + values, passed to BigQuery + :type query_params: list + :param labels: a dictionary containing labels for the job/query, + passed to BigQuery + :type labels: dict + :param schema_update_options: Allows the schema of the destination + table to be updated as a side effect of the query job. + :type schema_update_options: Union[list, tuple, set] + :param priority: Specifies a priority for the query. + Possible values include INTERACTIVE and BATCH. + The default value is INTERACTIVE. + :type priority: str + :param time_partitioning: configure optional time partitioning fields i.e. + partition by field, type and expiration as per API specifications. + :type time_partitioning: dict + :param cluster_fields: Request that the result of this query be stored sorted + by one or more columns. This is only available in combination with + time_partitioning. The order of columns given determines the sort order. + :type cluster_fields: list[str] + :param location: The geographic location of the job. Required except for + US and EU. See details at + https://cloud.google.com/bigquery/docs/locations#specifying_your_location + :type location: str + :param encryption_configuration: [Optional] Custom encryption configuration (e.g., Cloud KMS keys). 
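``run_query`` above submits the job and returns its BigQuery job id, optionally materialising the result into a destination table; a sketch with hypothetical table names and label::

    job_id = cursor.run_query(
        sql="SELECT owner, COUNT(*) AS n FROM `my-project.my_dataset.my_table` GROUP BY owner",
        destination_dataset_table="my-project.my_dataset.my_summary",
        write_disposition="WRITE_TRUNCATE",
        use_legacy_sql=False,
        labels={"team": "data-eng"},
    )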
+ **Example**: :: + + encryption_configuration = { + "kmsKeyName": "projects/testp/locations/us/keyRings/test-kr/cryptoKeys/test-key" + } + :type encryption_configuration: dict + """ + schema_update_options = list(schema_update_options or []) + + if time_partitioning is None: + time_partitioning = {} + + if location: + self.location = location + + if not api_resource_configs: + api_resource_configs = self.api_resource_configs + else: + _validate_value('api_resource_configs', + api_resource_configs, dict) + configuration = deepcopy(api_resource_configs) + if 'query' not in configuration: + configuration['query'] = {} + + else: + _validate_value("api_resource_configs['query']", + configuration['query'], dict) + + if sql is None and not configuration['query'].get('query', None): + raise TypeError('`BigQueryBaseCursor.run_query` ' + 'missing 1 required positional argument: `sql`') + + # BigQuery also allows you to define how you want a table's schema to change + # as a side effect of a query job + # for more details: + # https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs#configuration.query.schemaUpdateOptions # noqa # pylint: disable=line-too-long + + allowed_schema_update_options = [ + 'ALLOW_FIELD_ADDITION', "ALLOW_FIELD_RELAXATION" + ] + + if not set(allowed_schema_update_options + ).issuperset(set(schema_update_options)): + raise ValueError("{0} contains invalid schema update options. " + "Please only use one or more of the following " + "options: {1}" + .format(schema_update_options, + allowed_schema_update_options)) + + if schema_update_options: + if write_disposition not in ["WRITE_APPEND", "WRITE_TRUNCATE"]: + raise ValueError("schema_update_options is only " + "allowed if write_disposition is " + "'WRITE_APPEND' or 'WRITE_TRUNCATE'.") + + if destination_dataset_table: + destination_project, destination_dataset, destination_table = \ + _split_tablename(table_input=destination_dataset_table, + default_project_id=self.project_id) + + destination_dataset_table = { # type: ignore + 'projectId': destination_project, + 'datasetId': destination_dataset, + 'tableId': destination_table, + } + + if cluster_fields: + cluster_fields = {'fields': cluster_fields} # type: ignore + + query_param_list = [ + (sql, 'query', None, (str,)), + (priority, 'priority', 'INTERACTIVE', (str,)), + (use_legacy_sql, 'useLegacySql', self.use_legacy_sql, bool), + (query_params, 'queryParameters', None, list), + (udf_config, 'userDefinedFunctionResources', None, list), + (maximum_billing_tier, 'maximumBillingTier', None, int), + (maximum_bytes_billed, 'maximumBytesBilled', None, float), + (time_partitioning, 'timePartitioning', {}, dict), + (schema_update_options, 'schemaUpdateOptions', None, list), + (destination_dataset_table, 'destinationTable', None, dict), + (cluster_fields, 'clustering', None, dict), + ] # type: List[Tuple] + + for param, param_name, param_default, param_type in query_param_list: + if param_name not in configuration['query'] and param in [None, {}, ()]: + if param_name == 'timePartitioning': + param_default = _cleanse_time_partitioning( + destination_dataset_table, time_partitioning) + param = param_default + + if param in [None, {}, ()]: + continue + + _api_resource_configs_duplication_check( + param_name, param, configuration['query']) + + configuration['query'][param_name] = param + + # check valid type of provided param, + # it last step because we can get param from 2 sources, + # and first of all need to find it + + _validate_value(param_name, 
configuration['query'][param_name], + param_type) + + if param_name == 'schemaUpdateOptions' and param: + self.log.info("Adding experimental 'schemaUpdateOptions': " + "%s", schema_update_options) + + if param_name != 'destinationTable': + continue + + for key in ['projectId', 'datasetId', 'tableId']: + if key not in configuration['query']['destinationTable']: + raise ValueError( + "Not correct 'destinationTable' in " + "api_resource_configs. 'destinationTable' " + "must be a dict with {'projectId':'', " + "'datasetId':'', 'tableId':''}") + + configuration['query'].update({ + 'allowLargeResults': allow_large_results, + 'flattenResults': flatten_results, + 'writeDisposition': write_disposition, + 'createDisposition': create_disposition, + }) + + if 'useLegacySql' in configuration['query'] and configuration['query']['useLegacySql'] and\ + 'queryParameters' in configuration['query']: + raise ValueError("Query parameters are not allowed " + "when using legacy SQL") + + if labels: + _api_resource_configs_duplication_check( + 'labels', labels, configuration) + configuration['labels'] = labels + + if encryption_configuration: + configuration["query"][ + "destinationEncryptionConfiguration" + ] = encryption_configuration + + return self.run_with_configuration(configuration) + + def run_extract( # noqa + self, + source_project_dataset_table: str, + destination_cloud_storage_uris: str, + compression: str = 'NONE', + export_format: str = 'CSV', + field_delimiter: str = ',', + print_header: bool = True, + labels: Optional[Dict] = None) -> str: + """ + Executes a BigQuery extract command to copy data from BigQuery to + Google Cloud Storage. See here: + + https://cloud.google.com/bigquery/docs/reference/v2/jobs + + For more details about these parameters. + + :param source_project_dataset_table: The dotted ``.
`` + BigQuery table to use as the source data. + :type source_project_dataset_table: str + :param destination_cloud_storage_uris: The destination Google Cloud + Storage URI (e.g. gs://some-bucket/some-file.txt). Follows + convention defined here: + https://cloud.google.com/bigquery/exporting-data-from-bigquery#exportingmultiple + :type destination_cloud_storage_uris: list + :param compression: Type of compression to use. + :type compression: str + :param export_format: File format to export. + :type export_format: str + :param field_delimiter: The delimiter to use when extracting to a CSV. + :type field_delimiter: str + :param print_header: Whether to print a header for a CSV file extract. + :type print_header: bool + :param labels: a dictionary containing labels for the job/query, + passed to BigQuery + :type labels: dict + """ + + source_project, source_dataset, source_table = \ + _split_tablename(table_input=source_project_dataset_table, + default_project_id=self.project_id, + var_name='source_project_dataset_table') + + configuration = { + 'extract': { + 'sourceTable': { + 'projectId': source_project, + 'datasetId': source_dataset, + 'tableId': source_table, + }, + 'compression': compression, + 'destinationUris': destination_cloud_storage_uris, + 'destinationFormat': export_format, + } + } # type: Dict[str, Any] + + if labels: + configuration['labels'] = labels + + if export_format == 'CSV': + # Only set fieldDelimiter and printHeader fields if using CSV. + # Google does not like it if you set these fields for other export + # formats. + configuration['extract']['fieldDelimiter'] = field_delimiter + configuration['extract']['printHeader'] = print_header + + return self.run_with_configuration(configuration) + + def run_copy(self, # pylint: disable=invalid-name + source_project_dataset_tables: Union[List, str], + destination_project_dataset_table: str, + write_disposition: str = 'WRITE_EMPTY', + create_disposition: str = 'CREATE_IF_NEEDED', + labels: Optional[Dict] = None, + encryption_configuration: Optional[Dict] = None) -> str: + """ + Executes a BigQuery copy command to copy data from one BigQuery table + to another. See here: + + https://cloud.google.com/bigquery/docs/reference/v2/jobs#configuration.copy + + For more details about these parameters. + + :param source_project_dataset_tables: One or more dotted + ``(project:|project.).
`` + BigQuery tables to use as the source data. Use a list if there are + multiple source tables. + If ```` is not included, project will be the project defined + in the connection json. + :type source_project_dataset_tables: list|string + :param destination_project_dataset_table: The destination BigQuery + table. Format is: ``(project:|project.).
`` + :type destination_project_dataset_table: str + :param write_disposition: The write disposition if the table already exists. + :type write_disposition: str + :param create_disposition: The create disposition if the table doesn't exist. + :type create_disposition: str + :param labels: a dictionary containing labels for the job/query, + passed to BigQuery + :type labels: dict + :param encryption_configuration: [Optional] Custom encryption configuration (e.g., Cloud KMS keys). + **Example**: :: + + encryption_configuration = { + "kmsKeyName": "projects/testp/locations/us/keyRings/test-kr/cryptoKeys/test-key" + } + :type encryption_configuration: dict + """ + source_project_dataset_tables = ([ + source_project_dataset_tables + ] if not isinstance(source_project_dataset_tables, list) else + source_project_dataset_tables) + + source_project_dataset_tables_fixup = [] + for source_project_dataset_table in source_project_dataset_tables: + source_project, source_dataset, source_table = \ + _split_tablename(table_input=source_project_dataset_table, + default_project_id=self.project_id, + var_name='source_project_dataset_table') + source_project_dataset_tables_fixup.append({ + 'projectId': + source_project, + 'datasetId': + source_dataset, + 'tableId': + source_table + }) + + destination_project, destination_dataset, destination_table = \ + _split_tablename(table_input=destination_project_dataset_table, + default_project_id=self.project_id) + configuration = { + 'copy': { + 'createDisposition': create_disposition, + 'writeDisposition': write_disposition, + 'sourceTables': source_project_dataset_tables_fixup, + 'destinationTable': { + 'projectId': destination_project, + 'datasetId': destination_dataset, + 'tableId': destination_table + } + } + } + + if labels: + configuration['labels'] = labels + + if encryption_configuration: + configuration["copy"][ + "destinationEncryptionConfiguration" + ] = encryption_configuration + + return self.run_with_configuration(configuration) + + def run_load(self, # pylint: disable=too-many-locals,too-many-arguments,invalid-name + destination_project_dataset_table: str, + source_uris: List, + schema_fields: Optional[List] = None, + source_format: str = 'CSV', + create_disposition: str = 'CREATE_IF_NEEDED', + skip_leading_rows: int = 0, + write_disposition: str = 'WRITE_EMPTY', + field_delimiter: str = ',', + max_bad_records: int = 0, + quote_character: Optional[str] = None, + ignore_unknown_values: bool = False, + allow_quoted_newlines: bool = False, + allow_jagged_rows: bool = False, + schema_update_options: Optional[Iterable] = None, + src_fmt_configs: Optional[Dict] = None, + time_partitioning: Optional[Dict] = None, + cluster_fields: Optional[List] = None, + autodetect: bool = False, + encryption_configuration: Optional[Dict] = None) -> str: + """ + Executes a BigQuery load command to load data from Google Cloud Storage + to BigQuery. See here: + + https://cloud.google.com/bigquery/docs/reference/v2/jobs + + For more details about these parameters. + + :param destination_project_dataset_table: + The dotted ``(.|:).
($)`` BigQuery + table to load data into. If ```` is not included, project will be the + project defined in the connection json. If a partition is specified the + operator will automatically append the data, create a new partition or create + a new DAY partitioned table. + :type destination_project_dataset_table: str + :param schema_fields: The schema field list as defined here: + https://cloud.google.com/bigquery/docs/reference/v2/jobs#configuration.load + Required if autodetect=False; optional if autodetect=True. + :type schema_fields: list + :param autodetect: Attempt to autodetect the schema for CSV and JSON + source files. + :type autodetect: bool + :param source_uris: The source Google Cloud + Storage URI (e.g. gs://some-bucket/some-file.txt). A single wild + per-object name can be used. + :type source_uris: list + :param source_format: File format to export. + :type source_format: str + :param create_disposition: The create disposition if the table doesn't exist. + :type create_disposition: str + :param skip_leading_rows: Number of rows to skip when loading from a CSV. + :type skip_leading_rows: int + :param write_disposition: The write disposition if the table already exists. + :type write_disposition: str + :param field_delimiter: The delimiter to use when loading from a CSV. + :type field_delimiter: str + :param max_bad_records: The maximum number of bad records that BigQuery can + ignore when running the job. + :type max_bad_records: int + :param quote_character: The value that is used to quote data sections in a CSV + file. + :type quote_character: str + :param ignore_unknown_values: [Optional] Indicates if BigQuery should allow + extra values that are not represented in the table schema. + If true, the extra values are ignored. If false, records with extra columns + are treated as bad records, and if there are too many bad records, an + invalid error is returned in the job result. + :type ignore_unknown_values: bool + :param allow_quoted_newlines: Whether to allow quoted newlines (true) or not + (false). + :type allow_quoted_newlines: bool + :param allow_jagged_rows: Accept rows that are missing trailing optional columns. + The missing values are treated as nulls. If false, records with missing + trailing columns are treated as bad records, and if there are too many bad + records, an invalid error is returned in the job result. Only applicable when + soure_format is CSV. + :type allow_jagged_rows: bool + :param schema_update_options: Allows the schema of the destination + table to be updated as a side effect of the load job. + :type schema_update_options: Union[list, tuple, set] + :param src_fmt_configs: configure optional fields specific to the source format + :type src_fmt_configs: dict + :param time_partitioning: configure optional time partitioning fields i.e. + partition by field, type and expiration as per API specifications. + :type time_partitioning: dict + :param cluster_fields: Request that the result of this load be stored sorted + by one or more columns. This is only available in combination with + time_partitioning. The order of columns given determines the sort order. + :type cluster_fields: list[str] + :param encryption_configuration: [Optional] Custom encryption configuration (e.g., Cloud KMS keys). 
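``run_load`` above needs either ``schema_fields`` or ``autodetect=True``; a sketch loading CSV files from Cloud Storage with an explicit schema, again reusing the ``cursor`` from the earlier sketches (URIs, names and schema are assumptions)::

    cursor.run_load(
        destination_project_dataset_table="my-project.my_dataset.my_table",
        source_uris=["gs://my-bucket/exports/part-*.csv"],
        schema_fields=[
            {"name": "name", "type": "STRING", "mode": "NULLABLE"},
            {"name": "value", "type": "INTEGER", "mode": "NULLABLE"},
        ],
        source_format="CSV",
        skip_leading_rows=1,
        write_disposition="WRITE_APPEND",
    )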
+ **Example**: :: + + encryption_configuration = { + "kmsKeyName": "projects/testp/locations/us/keyRings/test-kr/cryptoKeys/test-key" + } + :type encryption_configuration: dict + """ + # To provide backward compatibility + schema_update_options = list(schema_update_options or []) + + # bigquery only allows certain source formats + # we check to make sure the passed source format is valid + # if it's not, we raise a ValueError + # Refer to this link for more details: + # https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs#configuration.query.tableDefinitions.(key).sourceFormat # noqa # pylint: disable=line-too-long + + if schema_fields is None and not autodetect: + raise ValueError( + 'You must either pass a schema or autodetect=True.') + + if src_fmt_configs is None: + src_fmt_configs = {} + + source_format = source_format.upper() + allowed_formats = [ + "CSV", "NEWLINE_DELIMITED_JSON", "AVRO", "GOOGLE_SHEETS", + "DATASTORE_BACKUP", "PARQUET" + ] + if source_format not in allowed_formats: + raise ValueError("{0} is not a valid source format. " + "Please use one of the following types: {1}" + .format(source_format, allowed_formats)) + + # bigquery also allows you to define how you want a table's schema to change + # as a side effect of a load + # for more details: + # https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs#configuration.load.schemaUpdateOptions + allowed_schema_update_options = [ + 'ALLOW_FIELD_ADDITION', "ALLOW_FIELD_RELAXATION" + ] + if not set(allowed_schema_update_options).issuperset( + set(schema_update_options)): + raise ValueError( + "{0} contains invalid schema update options." + "Please only use one or more of the following options: {1}" + .format(schema_update_options, allowed_schema_update_options)) + + destination_project, destination_dataset, destination_table = \ + _split_tablename(table_input=destination_project_dataset_table, + default_project_id=self.project_id, + var_name='destination_project_dataset_table') + + configuration = { + 'load': { + 'autodetect': autodetect, + 'createDisposition': create_disposition, + 'destinationTable': { + 'projectId': destination_project, + 'datasetId': destination_dataset, + 'tableId': destination_table, + }, + 'sourceFormat': source_format, + 'sourceUris': source_uris, + 'writeDisposition': write_disposition, + 'ignoreUnknownValues': ignore_unknown_values + } + } + + time_partitioning = _cleanse_time_partitioning( + destination_project_dataset_table, + time_partitioning + ) + if time_partitioning: + configuration['load'].update({ + 'timePartitioning': time_partitioning + }) + + if cluster_fields: + configuration['load'].update({'clustering': {'fields': cluster_fields}}) + + if schema_fields: + configuration['load']['schema'] = {'fields': schema_fields} + + if schema_update_options: + if write_disposition not in ["WRITE_APPEND", "WRITE_TRUNCATE"]: + raise ValueError("schema_update_options is only " + "allowed if write_disposition is " + "'WRITE_APPEND' or 'WRITE_TRUNCATE'.") + else: + self.log.info( + "Adding experimental 'schemaUpdateOptions': %s", + schema_update_options + ) + configuration['load'][ + 'schemaUpdateOptions'] = schema_update_options + + if max_bad_records: + configuration['load']['maxBadRecords'] = max_bad_records + + if encryption_configuration: + configuration["load"][ + "destinationEncryptionConfiguration" + ] = encryption_configuration + + # if following fields are not specified in src_fmt_configs, + # honor the top-level params for backward-compatibility + if 'skipLeadingRows' not in 
src_fmt_configs: + src_fmt_configs['skipLeadingRows'] = skip_leading_rows + if 'fieldDelimiter' not in src_fmt_configs: + src_fmt_configs['fieldDelimiter'] = field_delimiter + if 'ignoreUnknownValues' not in src_fmt_configs: + src_fmt_configs['ignoreUnknownValues'] = ignore_unknown_values + if quote_character is not None: + src_fmt_configs['quote'] = quote_character + if allow_quoted_newlines: + src_fmt_configs['allowQuotedNewlines'] = allow_quoted_newlines + + src_fmt_to_configs_mapping = { + 'CSV': [ + 'allowJaggedRows', 'allowQuotedNewlines', 'autodetect', + 'fieldDelimiter', 'skipLeadingRows', 'ignoreUnknownValues', + 'nullMarker', 'quote' + ], + 'DATASTORE_BACKUP': ['projectionFields'], + 'NEWLINE_DELIMITED_JSON': ['autodetect', 'ignoreUnknownValues'], + 'PARQUET': ['autodetect', 'ignoreUnknownValues'], + 'AVRO': ['useAvroLogicalTypes'], + } + + valid_configs = src_fmt_to_configs_mapping[source_format] + + # if following fields are not specified in src_fmt_configs, + # honor the top-level params for backward-compatibility + backward_compatibility_configs = {'skipLeadingRows': skip_leading_rows, + 'fieldDelimiter': field_delimiter, + 'ignoreUnknownValues': ignore_unknown_values, + 'quote': quote_character, + 'allowQuotedNewlines': allow_quoted_newlines} + + src_fmt_configs = _validate_src_fmt_configs(source_format, src_fmt_configs, valid_configs, + backward_compatibility_configs) + + configuration['load'].update(src_fmt_configs) + + if allow_jagged_rows: + configuration['load']['allowJaggedRows'] = allow_jagged_rows + + return self.run_with_configuration(configuration) + + def run_with_configuration(self, configuration: Dict) -> str: + """ + Executes a BigQuery SQL query. See here: + + https://cloud.google.com/bigquery/docs/reference/v2/jobs + + For more details about the configuration parameter. + + :param configuration: The configuration parameter maps directly to + BigQuery's configuration field in the job object. See + https://cloud.google.com/bigquery/docs/reference/v2/jobs for + details. + """ + jobs = self.service.jobs() # type: Any + job_data = {'configuration': configuration} # type: Dict[str, Dict] + + # Send query and wait for reply. + query_reply = jobs \ + .insert(projectId=self.project_id, body=job_data) \ + .execute(num_retries=self.num_retries) + self.running_job_id = query_reply['jobReference']['jobId'] + if 'location' in query_reply['jobReference']: + location = query_reply['jobReference']['location'] + else: + location = self.location + + # Wait for query to finish. + keep_polling_job = True # type: bool + while keep_polling_job: + try: + keep_polling_job = self._check_query_status(jobs, keep_polling_job, location) + + except HttpError as err: + if err.resp.status in [500, 503]: + self.log.info( + '%s: Retryable error, waiting for job to complete: %s', + err.resp.status, self.running_job_id) + time.sleep(5) + else: + raise Exception( + 'BigQuery job status check failed. Final error was: {}'. + format(err.resp.status)) + + return self.running_job_id # type: ignore + + def _check_query_status(self, jobs: Any, keep_polling_job: bool, location: str) -> bool: + if location: + job = jobs.get( + projectId=self.project_id, + jobId=self.running_job_id, + location=location).execute(num_retries=self.num_retries) + else: + job = jobs.get( + projectId=self.project_id, + jobId=self.running_job_id).execute(num_retries=self.num_retries) + + if job['status']['state'] == 'DONE': + keep_polling_job = False + # Check if job had errors. 
+ if 'errorResult' in job['status']: + raise Exception( + 'BigQuery job failed. Final error was: {}. The job was: {}'.format( + job['status']['errorResult'], job)) + else: + self.log.info('Waiting for job to complete : %s, %s', + self.project_id, self.running_job_id) + time.sleep(5) + return keep_polling_job + + def poll_job_complete(self, job_id: str) -> bool: + """ + Check if jobs completed. + + :param job_id: id of the job. + :type job_id: str + :rtype: bool + """ + jobs = self.service.jobs() + try: + if self.location: + job = jobs.get(projectId=self.project_id, + jobId=job_id, + location=self.location).execute(num_retries=self.num_retries) + else: + job = jobs.get(projectId=self.project_id, + jobId=job_id).execute(num_retries=self.num_retries) + if job['status']['state'] == 'DONE': + return True + except HttpError as err: + if err.resp.status in [500, 503]: + self.log.info( + '%s: Retryable error while polling job with id %s', + err.resp.status, job_id) + else: + raise Exception( + 'BigQuery job status check failed. Final error was: {}'. + format(err.resp.status)) + return False + + def cancel_query(self) -> None: + """ + Cancel all started queries that have not yet completed + """ + jobs = self.service.jobs() + if (self.running_job_id and + not self.poll_job_complete(self.running_job_id)): + self.log.info('Attempting to cancel job : %s, %s', self.project_id, + self.running_job_id) + if self.location: + jobs.cancel( + projectId=self.project_id, + jobId=self.running_job_id, + location=self.location).execute(num_retries=self.num_retries) + else: + jobs.cancel( + projectId=self.project_id, + jobId=self.running_job_id).execute(num_retries=self.num_retries) + else: + self.log.info('No running BigQuery jobs to cancel.') + return + + # Wait for all the calls to cancel to finish + max_polling_attempts = 12 + polling_attempts = 0 + + job_complete = False + while polling_attempts < max_polling_attempts and not job_complete: + polling_attempts = polling_attempts + 1 + job_complete = self.poll_job_complete(self.running_job_id) + if job_complete: + self.log.info('Job successfully canceled: %s, %s', + self.project_id, self.running_job_id) + elif polling_attempts == max_polling_attempts: + self.log.info( + "Stopping polling due to timeout. Job with id %s " + "has not completed cancel and may or may not finish.", + self.running_job_id) + else: + self.log.info('Waiting for canceled job with id %s to finish.', + self.running_job_id) + time.sleep(5) + + def get_schema(self, dataset_id: str, table_id: str) -> Dict: + """ + Get the schema for a given datset.table. + see https://cloud.google.com/bigquery/docs/reference/v2/tables#resource + + :param dataset_id: the dataset ID of the requested table + :param table_id: the table ID of the requested table + :return: a table schema + """ + tables_resource = self.service.tables() \ + .get(projectId=self.project_id, datasetId=dataset_id, tableId=table_id) \ + .execute(num_retries=self.num_retries) + return tables_resource['schema'] + + def get_tabledata(self, dataset_id: str, table_id: str, + max_results: Optional[int] = None, selected_fields: Optional[str] = None, + page_token: Optional[str] = None, start_index: Optional[int] = None) -> Dict: + """ + Get the data of a given dataset.table and optionally with selected columns. + see https://cloud.google.com/bigquery/docs/reference/v2/tabledata/list + + :param dataset_id: the dataset ID of the requested table. + :param table_id: the table ID of the requested table. 
+ :param max_results: the maximum results to return. + :param selected_fields: List of fields to return (comma-separated). If + unspecified, all fields are returned. + :param page_token: page token, returned from a previous call, + identifying the result set. + :param start_index: zero based index of the starting row to read. + :return: map containing the requested rows. + """ + optional_params = {} # type: Dict[str, Any] + if self.location: + optional_params['location'] = self.location + if max_results: + optional_params['maxResults'] = max_results + if selected_fields: + optional_params['selectedFields'] = selected_fields + if page_token: + optional_params['pageToken'] = page_token + if start_index: + optional_params['startIndex'] = start_index + return (self.service.tabledata().list( + projectId=self.project_id, + datasetId=dataset_id, + tableId=table_id, + **optional_params).execute(num_retries=self.num_retries)) + + def run_table_delete(self, deletion_dataset_table: str, + ignore_if_missing: bool = False) -> None: + """ + Delete an existing table from the dataset; + If the table does not exist, return an error unless ignore_if_missing + is set to True. + + :param deletion_dataset_table: A dotted + ``(.|:).
`` that indicates which table + will be deleted. + :type deletion_dataset_table: str + :param ignore_if_missing: if True, then return success even if the + requested table does not exist. + :type ignore_if_missing: bool + :return: + """ + deletion_project, deletion_dataset, deletion_table = \ + _split_tablename(table_input=deletion_dataset_table, + default_project_id=self.project_id) + + try: + self.service.tables() \ + .delete(projectId=deletion_project, + datasetId=deletion_dataset, + tableId=deletion_table) \ + .execute(num_retries=self.num_retries) + self.log.info('Deleted table %s:%s.%s.', deletion_project, + deletion_dataset, deletion_table) + except HttpError: + if not ignore_if_missing: + raise Exception('Table deletion failed. Table does not exist.') + else: + self.log.info('Table does not exist. Skipping.') + + def run_table_upsert(self, dataset_id: str, table_resource: Dict, + project_id: Optional[str] = None) -> Dict: + """ + creates a new, empty table in the dataset; + If the table already exists, update the existing table. + Since BigQuery does not natively allow table upserts, this is not an + atomic operation. + + :param dataset_id: the dataset to upsert the table into. + :type dataset_id: str + :param table_resource: a table resource. see + https://cloud.google.com/bigquery/docs/reference/v2/tables#resource + :type table_resource: dict + :param project_id: the project to upsert the table into. If None, + project will be self.project_id. + :return: + """ + # check to see if the table exists + table_id = table_resource['tableReference']['tableId'] + project_id = project_id if project_id is not None else self.project_id + tables_list_resp = self.service.tables().list( + projectId=project_id, datasetId=dataset_id).execute(num_retries=self.num_retries) + while True: + for table in tables_list_resp.get('tables', []): + if table['tableReference']['tableId'] == table_id: + # found the table, do update + self.log.info('Table %s:%s.%s exists, updating.', + project_id, dataset_id, table_id) + return self.service.tables().update( + projectId=project_id, + datasetId=dataset_id, + tableId=table_id, + body=table_resource).execute(num_retries=self.num_retries) + # If there is a next page, we need to check the next page. + if 'nextPageToken' in tables_list_resp: + tables_list_resp = self.service.tables()\ + .list(projectId=project_id, + datasetId=dataset_id, + pageToken=tables_list_resp['nextPageToken'])\ + .execute(num_retries=self.num_retries) + # If there is no next page, then the table doesn't exist. + else: + # do insert + self.log.info('Table %s:%s.%s does not exist. creating.', + project_id, dataset_id, table_id) + return self.service.tables().insert( + projectId=project_id, + datasetId=dataset_id, + body=table_resource).execute(num_retries=self.num_retries) + + def run_grant_dataset_view_access(self, + source_dataset: str, + view_dataset: str, + view_table: str, + source_project: Optional[str] = None, + view_project: Optional[str] = None) -> Dict: + """ + Grant authorized view access of a dataset to a view table. + If this view has already been granted access to the dataset, do nothing. + This method is not atomic. Running it may clobber a simultaneous update. + + :param source_dataset: the source dataset + :type source_dataset: str + :param view_dataset: the dataset that the view is in + :type view_dataset: str + :param view_table: the table of the view + :type view_table: str + :param source_project: the project of the source dataset. If None, + self.project_id will be used. 
+ :type source_project: str + :param view_project: the project that the view is in. If None, + self.project_id will be used. + :type view_project: str + :return: the datasets resource of the source dataset. + """ + + # Apply default values to projects + source_project = source_project if source_project else self.project_id + view_project = view_project if view_project else self.project_id + + # we don't want to clobber any existing accesses, so we have to get + # info on the dataset before we can add view access + source_dataset_resource = self.service.datasets().get( + projectId=source_project, datasetId=source_dataset).execute(num_retries=self.num_retries) + access = source_dataset_resource[ + 'access'] if 'access' in source_dataset_resource else [] + view_access = { + 'view': { + 'projectId': view_project, + 'datasetId': view_dataset, + 'tableId': view_table + } + } + # check to see if the view we want to add already exists. + if view_access not in access: + self.log.info( + 'Granting table %s:%s.%s authorized view access to %s:%s dataset.', + view_project, view_dataset, view_table, source_project, + source_dataset) + access.append(view_access) + return self.service.datasets().patch( + projectId=source_project, + datasetId=source_dataset, + body={ + 'access': access + }).execute(num_retries=self.num_retries) + else: + # if view is already in access, do nothing. + self.log.info( + 'Table %s:%s.%s already has authorized view access to %s:%s dataset.', + view_project, view_dataset, view_table, source_project, source_dataset) + return source_dataset_resource + + def create_empty_dataset(self, + dataset_id: str = "", + project_id: str = "", + location: Optional[str] = None, + dataset_reference: Optional[Dict] = None) -> None: + """ + Create a new empty dataset: + https://cloud.google.com/bigquery/docs/reference/rest/v2/datasets/insert + + :param project_id: The name of the project where we want to create + an empty a dataset. Don't need to provide, if projectId in dataset_reference. + :type project_id: str + :param dataset_id: The id of dataset. Don't need to provide, + if datasetId in dataset_reference. + :type dataset_id: str + :param location: (Optional) The geographic location where the dataset should reside. + There is no default value but the dataset will be created in US if nothing is provided. + :type location: str + :param dataset_reference: Dataset reference that could be provided + with request body. More info: + https://cloud.google.com/bigquery/docs/reference/rest/v2/datasets#resource + :type dataset_reference: dict + """ + + if dataset_reference: + _validate_value('dataset_reference', dataset_reference, dict) + else: + dataset_reference = {} + + if "datasetReference" not in dataset_reference: + dataset_reference["datasetReference"] = {} + + if self.location: + dataset_reference['location'] = dataset_reference.get('location') or self.location + + if not dataset_reference["datasetReference"].get("datasetId") and not dataset_id: + raise ValueError( + "{} not provided datasetId. Impossible to create dataset") + + dataset_required_params = [(dataset_id, "datasetId", ""), + (project_id, "projectId", self.project_id)] + for param_tuple in dataset_required_params: + param, param_name, param_default = param_tuple + if param_name not in dataset_reference['datasetReference']: + if param_default and not param: + self.log.info( + "%s was not specified. 
Will be used default value %s.", + param_name, param_default + ) + param = param_default + dataset_reference['datasetReference'].update( + {param_name: param}) + elif param: + _api_resource_configs_duplication_check( + param_name, param, + dataset_reference['datasetReference'], 'dataset_reference') + + if location: + if 'location' not in dataset_reference: + dataset_reference.update({'location': location}) + else: + _api_resource_configs_duplication_check( + 'location', location, + dataset_reference, 'dataset_reference') + + dataset_id = dataset_reference.get("datasetReference").get("datasetId") # type: ignore + dataset_project_id = dataset_reference.get("datasetReference").get("projectId") # type: ignore + + self.log.info('Creating Dataset: %s in project: %s ', dataset_id, + dataset_project_id) + + try: + self.service.datasets().insert( + projectId=dataset_project_id, + body=dataset_reference).execute(num_retries=self.num_retries) + self.log.info('Dataset created successfully: In project %s ' + 'Dataset %s', dataset_project_id, dataset_id) + + except HttpError as err: + raise AirflowException( + 'BigQuery job failed. Error was: {}'.format(err.content) + ) + + def delete_dataset(self, project_id: str, dataset_id: str, delete_contents: bool = False) -> None: + """ + Delete a dataset of Big query in your project. + + :param project_id: The name of the project where we have the dataset . + :type project_id: str + :param dataset_id: The dataset to be delete. + :type dataset_id: str + :param delete_contents: [Optional] Whether to force the deletion even if the dataset is not empty. + Will delete all tables (if any) in the dataset if set to True. + Will raise HttpError 400: "{dataset_id} is still in use" if set to False and dataset is not empty. + The default value is False. + :type delete_contents: bool + :return: + """ + project_id = project_id if project_id is not None else self.project_id + self.log.info('Deleting from project: %s Dataset:%s', + project_id, dataset_id) + + try: + self.service.datasets().delete( + projectId=project_id, + datasetId=dataset_id, + deleteContents=delete_contents).execute(num_retries=self.num_retries) + self.log.info('Dataset deleted successfully: In project %s ' + 'Dataset %s', project_id, dataset_id) + + except HttpError as err: + raise AirflowException( + 'BigQuery job failed. Error was: {}'.format(err.content) + ) + + def get_dataset(self, dataset_id: str, project_id: Optional[str] = None) -> Dict: + """ + Method returns dataset_resource if dataset exist + and raised 404 error if dataset does not exist + + :param dataset_id: The BigQuery Dataset ID + :type dataset_id: str + :param project_id: The GCP Project ID + :type project_id: str + :return: dataset_resource + + .. seealso:: + For more information, see Dataset Resource content: + https://cloud.google.com/bigquery/docs/reference/rest/v2/datasets#resource + """ + + if not dataset_id or not isinstance(dataset_id, str): + raise ValueError("dataset_id argument must be provided and has " + "a type 'str'. You provided: {}".format(dataset_id)) + + dataset_project_id = project_id if project_id else self.project_id + + try: + dataset_resource = self.service.datasets().get( + datasetId=dataset_id, projectId=dataset_project_id).execute(num_retries=self.num_retries) + self.log.info("Dataset Resource: %s", dataset_resource) + except HttpError as err: + raise AirflowException( + 'BigQuery job failed. 
Error was: {}'.format(err.content)) + + return dataset_resource + + def get_datasets_list(self, project_id: Optional[str] = None) -> List: + """ + Method returns full list of BigQuery datasets in the current project + + .. seealso:: + For more information, see: + https://cloud.google.com/bigquery/docs/reference/rest/v2/datasets/list + + :param project_id: Google Cloud Project for which you + try to get all datasets + :type project_id: str + :return: datasets_list + + Example of returned datasets_list: :: + + { + "kind":"bigquery#dataset", + "location":"US", + "id":"your-project:dataset_2_test", + "datasetReference":{ + "projectId":"your-project", + "datasetId":"dataset_2_test" + } + }, + { + "kind":"bigquery#dataset", + "location":"US", + "id":"your-project:dataset_1_test", + "datasetReference":{ + "projectId":"your-project", + "datasetId":"dataset_1_test" + } + } + ] + """ + dataset_project_id = project_id if project_id else self.project_id + + try: + datasets_list = self.service.datasets().list( + projectId=dataset_project_id).execute(num_retries=self.num_retries)['datasets'] + self.log.info("Datasets List: %s", datasets_list) + + except HttpError as err: + raise AirflowException( + 'BigQuery job failed. Error was: {}'.format(err.content)) + + return datasets_list + + def patch_dataset(self, dataset_id: str, dataset_resource: str, project_id: Optional[str] = None) -> Dict: + """ + Patches information in an existing dataset. + It only replaces fields that are provided in the submitted dataset resource. + More info: + https://cloud.google.com/bigquery/docs/reference/rest/v2/datasets/patch + + :param dataset_id: The BigQuery Dataset ID + :type dataset_id: str + :param dataset_resource: Dataset resource that will be provided + in request body. + https://cloud.google.com/bigquery/docs/reference/rest/v2/datasets#resource + :type dataset_resource: dict + :param project_id: The GCP Project ID + :type project_id: str + :rtype: dataset + https://cloud.google.com/bigquery/docs/reference/rest/v2/datasets#resource + """ + + if not dataset_id or not isinstance(dataset_id, str): + raise ValueError( + "dataset_id argument must be provided and has " + "a type 'str'. You provided: {}".format(dataset_id) + ) + + dataset_project_id = project_id if project_id else self.project_id + + try: + dataset = ( + self.service.datasets() + .patch( + datasetId=dataset_id, + projectId=dataset_project_id, + body=dataset_resource, + ) + .execute(num_retries=self.num_retries) + ) + self.log.info("Dataset successfully patched: %s", dataset) + except HttpError as err: + raise AirflowException( + "BigQuery job failed. Error was: {}".format(err.content) + ) + + return dataset + + def update_dataset(self, dataset_id: str, + dataset_resource: Dict, project_id: Optional[str] = None) -> Dict: + """ + Updates information in an existing dataset. The update method replaces the entire + dataset resource, whereas the patch method only replaces fields that are provided + in the submitted dataset resource. + More info: + https://cloud.google.com/bigquery/docs/reference/rest/v2/datasets/update + + :param dataset_id: The BigQuery Dataset ID + :type dataset_id: str + :param dataset_resource: Dataset resource that will be provided + in request body. 
+ https://cloud.google.com/bigquery/docs/reference/rest/v2/datasets#resource + :type dataset_resource: dict + :param project_id: The GCP Project ID + :type project_id: str + :rtype: dataset + https://cloud.google.com/bigquery/docs/reference/rest/v2/datasets#resource + """ + + if not dataset_id or not isinstance(dataset_id, str): + raise ValueError( + "dataset_id argument must be provided and has " + "a type 'str'. You provided: {}".format(dataset_id) + ) + + dataset_project_id = project_id if project_id else self.project_id + + try: + dataset = ( + self.service.datasets() + .update( + datasetId=dataset_id, + projectId=dataset_project_id, + body=dataset_resource, + ) + .execute(num_retries=self.num_retries) + ) + self.log.info("Dataset successfully updated: %s", dataset) + except HttpError as err: + raise AirflowException( + "BigQuery job failed. Error was: {}".format(err.content) + ) + + return dataset + + def insert_all(self, project_id: str, dataset_id: str, table_id: str, + rows: List, ignore_unknown_values: bool = False, + skip_invalid_rows: bool = False, fail_on_error: bool = False) -> None: + """ + Method to stream data into BigQuery one record at a time without needing + to run a load job + + .. seealso:: + For more information, see: + https://cloud.google.com/bigquery/docs/reference/rest/v2/tabledata/insertAll + + :param project_id: The name of the project where we have the table + :type project_id: str + :param dataset_id: The name of the dataset where we have the table + :type dataset_id: str + :param table_id: The name of the table + :type table_id: str + :param rows: the rows to insert + :type rows: list + + **Example or rows**: + rows=[{"json": {"a_key": "a_value_0"}}, {"json": {"a_key": "a_value_1"}}] + + :param ignore_unknown_values: [Optional] Accept rows that contain values + that do not match the schema. The unknown values are ignored. + The default value is false, which treats unknown values as errors. + :type ignore_unknown_values: bool + :param skip_invalid_rows: [Optional] Insert all valid rows of a request, + even if invalid rows exist. The default value is false, which causes + the entire request to fail if any invalid rows exist. + :type skip_invalid_rows: bool + :param fail_on_error: [Optional] Force the task to fail if any errors occur. + The default value is false, which indicates the task should not fail + even if any insertion errors occur. + :type fail_on_error: bool + """ + + dataset_project_id = project_id if project_id else self.project_id + + body = { + "rows": rows, + "ignoreUnknownValues": ignore_unknown_values, + "kind": "bigquery#tableDataInsertAllRequest", + "skipInvalidRows": skip_invalid_rows, + } + + try: + self.log.info( + 'Inserting %s row(s) into Table %s:%s.%s', + len(rows), dataset_project_id, dataset_id, table_id + ) + + resp = self.service.tabledata().insertAll( + projectId=dataset_project_id, datasetId=dataset_id, + tableId=table_id, body=body + ).execute(num_retries=self.num_retries) + + if 'insertErrors' not in resp: + self.log.info( + 'All row(s) inserted successfully: %s:%s.%s', + dataset_project_id, dataset_id, table_id + ) + else: + error_msg = '{} insert error(s) occurred: {}:{}.{}. Details: {}'.format( + len(resp['insertErrors']), + dataset_project_id, dataset_id, table_id, resp['insertErrors']) + if fail_on_error: + raise AirflowException( + 'BigQuery job failed. Error was: {}'.format(error_msg) + ) + self.log.info(error_msg) + except HttpError as err: + raise AirflowException( + 'BigQuery job failed. 
Error was: {}'.format(err.content) + ) + + +class BigQueryCursor(BigQueryBaseCursor): + """ + A very basic BigQuery PEP 249 cursor implementation. The PyHive PEP 249 + implementation was used as a reference: + + https://github.com/dropbox/PyHive/blob/master/pyhive/presto.py + https://github.com/dropbox/PyHive/blob/master/pyhive/common.py + """ + + def __init__( + self, + service: Any, + project_id: str, + use_legacy_sql: bool = True, + location: Optional[str] = None, + num_retries: int = 5, + ) -> None: + super().__init__( + service=service, + project_id=project_id, + use_legacy_sql=use_legacy_sql, + location=location, + num_retries=num_retries + ) + self.buffersize = None # type: Optional[int] + self.page_token = None # type: Optional[str] + self.job_id = None # type: Optional[str] + self.buffer = [] # type: list + self.all_pages_loaded = False # type: bool + + @property + def description(self) -> NoReturn: + """ The schema description method is not currently implemented. """ + raise NotImplementedError + + def close(self) -> None: + """ By default, do nothing """ + + @property + def rowcount(self) -> int: + """ By default, return -1 to indicate that this is not supported. """ + return -1 + + def execute(self, operation: str, parameters: Optional[Dict] = None) -> None: + """ + Executes a BigQuery query, and returns the job ID. + + :param operation: The query to execute. + :type operation: str + :param parameters: Parameters to substitute into the query. + :type parameters: dict + """ + sql = _bind_parameters(operation, + parameters) if parameters else operation + self.job_id = self.run_query(sql) + + def executemany(self, operation: str, seq_of_parameters: List) -> None: + """ + Execute a BigQuery query multiple times with different parameters. + + :param operation: The query to execute. + :type operation: str + :param seq_of_parameters: List of dictionary parameters to substitute into the + query. + :type seq_of_parameters: list + """ + for parameters in seq_of_parameters: + self.execute(operation, parameters) + + def fetchone(self) -> Union[List, None]: + """ Fetch the next row of a query result set. """ + return self.next() + + def next(self) -> Union[List, None]: + """ + Helper method for fetchone, which returns the next row from a buffer. + If the buffer is empty, attempts to paginate through the result set for + the next page, and load it into the buffer. + """ + if not self.job_id: + return None + + if not self.buffer: + if self.all_pages_loaded: + return None + + query_results = (self.service.jobs().getQueryResults( + projectId=self.project_id, + jobId=self.job_id, + pageToken=self.page_token).execute(num_retries=self.num_retries)) + + if 'rows' in query_results and query_results['rows']: + self.page_token = query_results.get('pageToken') + fields = query_results['schema']['fields'] + col_types = [field['type'] for field in fields] + rows = query_results['rows'] + + for dict_row in rows: + typed_row = ([ + _bq_cast(vs['v'], col_types[idx]) + for idx, vs in enumerate(dict_row['f']) + ]) + self.buffer.append(typed_row) + + if not self.page_token: + self.all_pages_loaded = True + + else: + # Reset all state since we've exhausted the results. + self.page_token = None + self.job_id = None + self.page_token = None + return None + + return self.buffer.pop(0) + + def fetchmany(self, size: Optional[int] = None) -> List: + """ + Fetch the next set of rows of a query result, returning a sequence of sequences + (e.g. a list of tuples). 
An empty sequence is returned when no more rows are + available. The number of rows to fetch per call is specified by the parameter. + If it is not given, the cursor's arraysize determines the number of rows to be + fetched. The method should try to fetch as many rows as indicated by the size + parameter. If this is not possible due to the specified number of rows not being + available, fewer rows may be returned. An :py:class:`~pyhive.exc.Error` + (or subclass) exception is raised if the previous call to + :py:meth:`execute` did not produce any result set or no call was issued yet. + """ + if size is None: + size = self.arraysize + result = [] + for _ in range(size): + one = self.fetchone() + if one is None: + break + else: + result.append(one) + return result + + def fetchall(self) -> List[List]: + """ + Fetch all (remaining) rows of a query result, returning them as a sequence of + sequences (e.g. a list of tuples). + """ + result = [] + while True: + one = self.fetchone() + if one is None: + break + else: + result.append(one) + return result + + def get_arraysize(self) -> int: + """ Specifies the number of rows to fetch at a time with .fetchmany() """ + return self.buffersize or 1 + + def set_arraysize(self, arraysize: int) -> None: + """ Specifies the number of rows to fetch at a time with .fetchmany() """ + self.buffersize = arraysize + + arraysize = property(get_arraysize, set_arraysize) + + def setinputsizes(self, sizes: Any) -> None: + """ Does nothing by default """ + + def setoutputsize(self, size: Any, column: Any = None) -> None: + """ Does nothing by default """ + + +def _bind_parameters(operation: str, parameters: Dict) -> str: + """ Helper method that binds parameters to a SQL query. """ + # inspired by MySQL Python Connector (conversion.py) + string_parameters = {} # type Dict[str, str] + for (name, value) in parameters.items(): + if value is None: + string_parameters[name] = 'NULL' + elif isinstance(value, str): + string_parameters[name] = "'" + _escape(value) + "'" + else: + string_parameters[name] = str(value) + return operation % string_parameters + + +def _escape(s: str) -> str: + """ Helper method that escapes parameters to a SQL query. """ + e = s + e = e.replace('\\', '\\\\') + e = e.replace('\n', '\\n') + e = e.replace('\r', '\\r') + e = e.replace("'", "\\'") + e = e.replace('"', '\\"') + return e + + +def _bq_cast(string_field: str, bq_type: str) -> Union[None, int, float, bool, str]: + """ + Helper method that casts a BigQuery row to the appropriate data types. + This is useful because BigQuery returns all fields as strings. + """ + if string_field is None: + return None + elif bq_type == 'INTEGER': + return int(string_field) + elif bq_type in ('FLOAT', 'TIMESTAMP'): + return float(string_field) + elif bq_type == 'BOOLEAN': + if string_field not in ['true', 'false']: + raise ValueError("{} must have value 'true' or 'false'".format( + string_field)) + return string_field == 'true' + else: + return string_field + + +def _split_tablename(table_input: str, default_project_id: str, + var_name: Optional[str] = None) -> Tuple[str, str, str]: + + if '.' not in table_input: + raise ValueError( + 'Expected target table name in the format of ' + '.
. Got: {}'.format(table_input)) + + if not default_project_id: + raise ValueError("INTERNAL: No default project is specified") + + def var_print(var_name): + if var_name is None: + return "" + else: + return "Format exception for {var}: ".format(var=var_name) + + if table_input.count('.') + table_input.count(':') > 3: + raise Exception(('{var}Use either : or . to specify project ' + 'got {input}').format( + var=var_print(var_name), input=table_input)) + cmpt = table_input.rsplit(':', 1) + project_id = None + rest = table_input + if len(cmpt) == 1: + project_id = None + rest = cmpt[0] + elif len(cmpt) == 2 and cmpt[0].count(':') <= 1: + if cmpt[-1].count('.') != 2: + project_id = cmpt[0] + rest = cmpt[1] + else: + raise Exception(('{var}Expect format of (.
, ' + 'got {input}').format( + var=var_print(var_name), input=table_input)) + + cmpt = rest.split('.') + if len(cmpt) == 3: + if project_id: + raise ValueError( + "{var}Use either : or . to specify project".format( + var=var_print(var_name))) + project_id = cmpt[0] + dataset_id = cmpt[1] + table_id = cmpt[2] + + elif len(cmpt) == 2: + dataset_id = cmpt[0] + table_id = cmpt[1] + else: + raise Exception( + ('{var}Expect format of (.
, ' + 'got {input}').format(var=var_print(var_name), input=table_input)) + + if project_id is None: + if var_name is not None: + log = LoggingMixin().log + log.info( + 'Project not included in %s: %s; using project "%s"', + var_name, table_input, default_project_id + ) + project_id = default_project_id + + return project_id, dataset_id, table_id + + +def _cleanse_time_partitioning( + destination_dataset_table: Optional[str], time_partitioning_in: Optional[Dict] +) -> Dict: # if it is a partitioned table ($ is in the table name) add partition load option + + if time_partitioning_in is None: + time_partitioning_in = {} + + time_partitioning_out = {} + if destination_dataset_table and '$' in destination_dataset_table: + time_partitioning_out['type'] = 'DAY' + time_partitioning_out.update(time_partitioning_in) + return time_partitioning_out + + +def _validate_value(key: Any, value: Any, expected_type: Type) -> None: + """ function to check expected type and raise + error if type is not correct """ + if not isinstance(value, expected_type): + raise TypeError("{} argument must have a type {} not {}".format( + key, expected_type, type(value))) + + +def _api_resource_configs_duplication_check(key: Any, value: Any, config_dict: Dict, + config_dict_name='api_resource_configs') -> None: + if key in config_dict and value != config_dict[key]: + raise ValueError("Values of {param_name} param are duplicated. " + "{dict_name} contained {param_name} param " + "in `query` config and {param_name} was also provided " + "with arg to run_query() method. Please remove duplicates." + .format(param_name=key, dict_name=config_dict_name)) + + +def _validate_src_fmt_configs(source_format: str, + src_fmt_configs: Dict, + valid_configs: List[str], + backward_compatibility_configs: Optional[Dict] = None) -> Dict: + """ + Validates the given src_fmt_configs against a valid configuration for the source format. + Adds the backward compatiblity config to the src_fmt_configs. + + :param source_format: File format to export. + :type source_format: str + :param src_fmt_configs: Configure optional fields specific to the source format. + :type src_fmt_configs: dict + :param valid_configs: Valid configuration specific to the source format + :type valid_configs: List[str] + :param backward_compatibility_configs: The top-level params for backward-compatibility + :type backward_compatibility_configs: dict + """ + + if backward_compatibility_configs is None: + backward_compatibility_configs = {} + + for k, v in backward_compatibility_configs.items(): + if k not in src_fmt_configs and k in valid_configs: + src_fmt_configs[k] = v + + for k, v in src_fmt_configs.items(): + if k not in valid_configs: + raise ValueError("{0} is not a valid src_fmt_configs for type {1}." + .format(k, source_format)) + + return src_fmt_configs diff --git a/airflow/gcp/hooks/bigquery_dts.py b/airflow/gcp/hooks/bigquery_dts.py new file mode 100644 index 00000000000000..ca53b95f4341d4 --- /dev/null +++ b/airflow/gcp/hooks/bigquery_dts.py @@ -0,0 +1,277 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +""" +This module contains a BigQuery Hook. +""" +from typing import Union, Sequence, Tuple +from copy import copy + +from google.protobuf.json_format import MessageToDict, ParseDict +from google.api_core.retry import Retry +from google.cloud.bigquery_datatransfer_v1 import DataTransferServiceClient +from google.cloud.bigquery_datatransfer_v1.types import ( + TransferConfig, + StartManualTransferRunsResponse, + TransferRun, +) + +from airflow.contrib.hooks.gcp_api_base_hook import GoogleCloudBaseHook + + +def get_object_id(obj: dict) -> str: + """ + Returns unique id of the object. + """ + return obj["name"].rpartition("/")[-1] + + +class BiqQueryDataTransferServiceHook(GoogleCloudBaseHook): + """ + Hook for Google Bigquery Transfer API. + + All the methods in the hook where ``project_id`` is used must be called with + keyword arguments rather than positional. + + """ + + _conn = None + + def __init__( + self, gcp_conn_id: str = "google_cloud_default", delegate_to: str = None + ) -> None: + super().__init__(gcp_conn_id=gcp_conn_id, delegate_to=delegate_to) + + @staticmethod + def _disable_auto_scheduling(config: Union[dict, TransferConfig]) -> TransferConfig: + """ + In the case of Airflow, the customer needs to create a transfer config + with the automatic scheduling disabled (UI, CLI or an Airflow operator) and + then trigger a transfer run using a specialized Airflow operator that will + call start_manual_transfer_runs. + + :param config: Data transfer configuration to create. + :type config: Union[dict, google.cloud.bigquery_datatransfer_v1.types.TransferConfig] + """ + config = MessageToDict(config) if isinstance(config, TransferConfig) else config + new_config = copy(config) + schedule_options = new_config.get("schedule_options") + if schedule_options: + disable_auto_scheduling = schedule_options.get( + "disable_auto_scheduling", None + ) + if disable_auto_scheduling is None: + schedule_options["disable_auto_scheduling"] = True + else: + new_config["schedule_options"] = {"disable_auto_scheduling": True} + return ParseDict(new_config, TransferConfig()) + + def get_conn(self) -> DataTransferServiceClient: + """ + Retrieves connection to Google Bigquery. + + :return: Google Bigquery API client + :rtype: google.cloud.bigquery_datatransfer_v1.DataTransferServiceClient + """ + if not self._conn: + self._conn = DataTransferServiceClient( + credentials=self._get_credentials(), client_info=self.client_info + ) + return self._conn + + @GoogleCloudBaseHook.catch_http_exception + @GoogleCloudBaseHook.fallback_to_default_project_id + def create_transfer_config( + self, + transfer_config: Union[dict, TransferConfig], + project_id: str = None, + authorization_code: str = None, + retry: Retry = None, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = None, + ) -> TransferConfig: + """ + Creates a new data transfer configuration. + + :param transfer_config: Data transfer configuration to create. 
+ :type transfer_config: Union[dict, google.cloud.bigquery_datatransfer_v1.types.TransferConfig] + :param project_id: The BigQuery project id where the transfer configuration should be + created. If set to None or missing, the default project_id from the GCP connection is used. + :type project_id: str + :param authorization_code: authorization code to use with this transfer configuration. + This is required if new credentials are needed. + :type authorization_code: Optional[str] + :param retry: A retry object used to retry requests. If `None` is + specified, requests will not be retried. + :type retry: Optional[google.api_core.retry.Retry] + :param timeout: The amount of time, in seconds, to wait for the request to + complete. Note that if retry is specified, the timeout applies to each individual + attempt. + :type timeout: Optional[float] + :param metadata: Additional metadata that is provided to the method. + :type metadata: Optional[Sequence[Tuple[str, str]]] + :return: A ``google.cloud.bigquery_datatransfer_v1.types.TransferConfig`` instance. + """ + assert project_id is not None + client = self.get_conn() + parent = client.project_path(project_id) + return client.create_transfer_config( + parent=parent, + transfer_config=self._disable_auto_scheduling(transfer_config), + authorization_code=authorization_code, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + @GoogleCloudBaseHook.catch_http_exception + @GoogleCloudBaseHook.fallback_to_default_project_id + def delete_transfer_config( + self, + transfer_config_id: str, + project_id: str = None, + retry: Retry = None, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = None, + ) -> None: + """ + Deletes transfer configuration. + + :param transfer_config_id: Id of transfer config to be used. + :type transfer_config_id: str + :param project_id: The BigQuery project id where the transfer configuration should be + created. If set to None or missing, the default project_id from the GCP connection is used. + :type project_id: str + :param retry: A retry object used to retry requests. If `None` is + specified, requests will not be retried. + :type retry: Optional[google.api_core.retry.Retry] + :param timeout: The amount of time, in seconds, to wait for the request to + complete. Note that if retry is specified, the timeout applies to each individual + attempt. + :type timeout: Optional[float] + :param metadata: Additional metadata that is provided to the method. + :type metadata: Optional[Sequence[Tuple[str, str]]] + :return: None + """ + assert project_id is not None + client = self.get_conn() + name = client.project_transfer_config_path( + project=project_id, transfer_config=transfer_config_id + ) + return client.delete_transfer_config( + name=name, retry=retry, timeout=timeout, metadata=metadata + ) + + @GoogleCloudBaseHook.catch_http_exception + @GoogleCloudBaseHook.fallback_to_default_project_id + def start_manual_transfer_runs( + self, + transfer_config_id: str, + project_id: str = None, + requested_time_range: dict = None, + requested_run_time: dict = None, + retry: Retry = None, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = None, + ) -> StartManualTransferRunsResponse: + """ + Start manual transfer runs to be executed now with schedule_time equal + to current time. The transfer runs can be created for a time range where + the run_time is between start_time (inclusive) and end_time + (exclusive), or for a specific run_time. + + :param transfer_config_id: Id of transfer config to be used. 
+ :type transfer_config_id: str + :param requested_time_range: Time range for the transfer runs that should be started. + If a dict is provided, it must be of the same form as the protobuf + message `~google.cloud.bigquery_datatransfer_v1.types.TimeRange` + :type requested_time_range: Union[dict, ~google.cloud.bigquery_datatransfer_v1.types.TimeRange] + :param requested_run_time: Specific run_time for a transfer run to be started. The + requested_run_time must not be in the future. If a dict is provided, it + must be of the same form as the protobuf message + `~google.cloud.bigquery_datatransfer_v1.types.Timestamp` + :type requested_run_time: Union[dict, ~google.cloud.bigquery_datatransfer_v1.types.Timestamp] + :param project_id: The BigQuery project id where the transfer configuration should be + created. If set to None or missing, the default project_id from the GCP connection is used. + :type project_id: str + :param retry: A retry object used to retry requests. If `None` is + specified, requests will not be retried. + :type retry: Optional[google.api_core.retry.Retry] + :param timeout: The amount of time, in seconds, to wait for the request to + complete. Note that if retry is specified, the timeout applies to each individual + attempt. + :type timeout: Optional[float] + :param metadata: Additional metadata that is provided to the method. + :type metadata: Optional[Sequence[Tuple[str, str]]] + :return: An ``google.cloud.bigquery_datatransfer_v1.types.StartManualTransferRunsResponse`` instance. + """ + assert project_id is not None + client = self.get_conn() + parent = client.project_transfer_config_path( + project=project_id, transfer_config=transfer_config_id + ) + return client.start_manual_transfer_runs( + parent=parent, + requested_time_range=requested_time_range, + requested_run_time=requested_run_time, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + @GoogleCloudBaseHook.catch_http_exception + @GoogleCloudBaseHook.fallback_to_default_project_id + def get_transfer_run( + self, + run_id: str, + transfer_config_id: str, + project_id: str = None, + retry: Retry = None, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = None, + ) -> TransferRun: + """ + Returns information about the particular transfer run. + + :param run_id: ID of the transfer run. + :type run_id: str + :param transfer_config_id: ID of transfer config to be used. + :type transfer_config_id: str + :param project_id: The BigQuery project id where the transfer configuration should be + created. If set to None or missing, the default project_id from the GCP connection is used. + :type project_id: str + :param retry: A retry object used to retry requests. If `None` is + specified, requests will not be retried. + :type retry: Optional[google.api_core.retry.Retry] + :param timeout: The amount of time, in seconds, to wait for the request to + complete. Note that if retry is specified, the timeout applies to each individual + attempt. + :type timeout: Optional[float] + :param metadata: Additional metadata that is provided to the method. + :type metadata: Optional[Sequence[Tuple[str, str]]] + :return: An ``google.cloud.bigquery_datatransfer_v1.types.TransferRun`` instance. 
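+
+        **Example** (a minimal usage sketch; the project, config and run IDs
+        below are hypothetical): ::
+
+            hook = BiqQueryDataTransferServiceHook()
+            transfer_run = hook.get_transfer_run(
+                run_id="5e30a512-0000-0000-0000-000000000000",  # hypothetical
+                transfer_config_id="12345678-0000-0000-0000-000000000000",  # hypothetical
+                project_id="example-project",
+            )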
+ """ + assert project_id is not None + client = self.get_conn() + name = client.project_run_path( + project=project_id, transfer_config=transfer_config_id, run=run_id + ) + return client.get_transfer_run( + name=name, retry=retry, timeout=timeout, metadata=metadata + ) diff --git a/airflow/gcp/hooks/cloud_memorystore.py b/airflow/gcp/hooks/cloud_memorystore.py new file mode 100644 index 00000000000000..9b3b1a4876cfff --- /dev/null +++ b/airflow/gcp/hooks/cloud_memorystore.py @@ -0,0 +1,485 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Hooks for Cloud Memorystore service +""" +from typing import Dict, Sequence, Tuple, Union, Optional + +from google.api_core.exceptions import NotFound +from google.api_core.retry import Retry +from google.cloud.redis_v1.gapic.enums import FailoverInstanceRequest +from google.cloud.redis_v1.types import FieldMask, InputConfig, Instance, OutputConfig +from google.cloud.redis_v1 import CloudRedisClient +from google.protobuf.json_format import ParseDict + +from airflow import AirflowException, version +from airflow.contrib.hooks.gcp_api_base_hook import GoogleCloudBaseHook + + +class CloudMemorystoreHook(GoogleCloudBaseHook): + """ + Hook for Google Cloud Memorystore APIs. + + All the methods in the hook where project_id is used must be called with + keyword arguments rather than positional. + + :param gcp_conn_id: The connection ID to use when fetching connection info. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate, if any. + For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + """ + + def __init__(self, gcp_conn_id: str = "google_cloud_default", delegate_to: str = None): + super().__init__(gcp_conn_id, delegate_to) + self._client = None # type: Optional[CloudRedisClient] + + def get_conn(self,): + """ + Retrieves client library object that allow access to Cloud Memorystore service. + + """ + if not self._client: + self._client = CloudRedisClient(credentials=self._get_credentials()) + return self._client + + @staticmethod + def _append_label(instance: Instance, key: str, val: str) -> Instance: + """ + Append labels to provided Instance type + + Labels must fit the regex ``[a-z]([-a-z0-9]*[a-z0-9])?`` (current + airflow version string follows semantic versioning spec: x.y.z). 
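+        For example, a value such as ``1.10.5+composer`` would be stored as the
+        label value ``1-10-5-composer`` (dots and plus signs are replaced with hyphens).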
+ + :param instance: The proto to append resource_label airflow + version to + :type instance: google.cloud.container_v1.types.Cluster + :param key: The key label + :type key: str + :param val: + :type val: str + :return: The cluster proto updated with new label + """ + val = val.replace(".", "-").replace("+", "-") + instance.labels.update({key: val}) + return instance + + @GoogleCloudBaseHook.fallback_to_default_project_id + def create_instance( + self, + location: str, + instance_id: str, + instance: Union[Dict, Instance], + project_id: str = None, + retry: Retry = None, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = None, + ): + """ + Creates a Redis instance based on the specified tier and memory size. + + By default, the instance is accessible from the project's `default network + `__. + + :param location: The location of the Cloud Memorystore instance (for example europe-west1) + :type location: str + :param instance_id: Required. The logical name of the Redis instance in the customer project with the + following restrictions: + + - Must contain only lowercase letters, numbers, and hyphens. + - Must start with a letter. + - Must be between 1-40 characters. + - Must end with a number or a letter. + - Must be unique within the customer project / location + :type instance_id: str + :param instance: Required. A Redis [Instance] resource + + If a dict is provided, it must be of the same form as the protobuf message + :class:`~google.cloud.redis_v1.types.Instance` + :type instance: Union[Dict, google.cloud.redis_v1.types.Instance] + :param project_id: Project ID of the project that contains the instance. If set + to None or missing, the default project_id from the GCP connection is used. + :type project_id: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_conn() + parent = CloudRedisClient.location_path(project_id, location) + instance_name = CloudRedisClient.instance_path(project_id, location, instance_id) + try: + instance = client.get_instance( + name=instance_name, retry=retry, timeout=timeout, metadata=metadata + ) + self.log.info("Instance exists. Skipping creation.") + return instance + except NotFound: + self.log.info("Instance not exists.") + + if isinstance(instance, dict): + instance = ParseDict(instance, Instance()) + elif not isinstance(instance, Instance): + raise AirflowException("instance is not instance of Instance type or python dict") + + self._append_label(instance, "airflow-version", "v" + version.version) + + result = client.create_instance( + parent=parent, + instance_id=instance_id, + instance=instance, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + result.result() + self.log.info("Instance created.") + return client.get_instance(name=instance_name, retry=retry, timeout=timeout, metadata=metadata) + + @GoogleCloudBaseHook.fallback_to_default_project_id + def delete_instance( + self, + location: str, + instance: str, + project_id: str = None, + retry: Retry = None, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = None, + ): + """ + Deletes a specific Redis instance. 
Instance stops serving and data is deleted. + + :param location: The location of the Cloud Memorystore instance (for example europe-west1) + :type location: str + :param instance: The logical name of the Redis instance in the customer project. + :type instance: str + :param project_id: Project ID of the project that contains the instance. If set + to None or missing, the default project_id from the GCP connection is used. + :type project_id: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_conn() + name = CloudRedisClient.instance_path(project_id, location, instance) + self.log.info("Fetching Instance: %s", name) + instance = client.get_instance(name=name, retry=retry, timeout=timeout, metadata=metadata) + + if not instance: + return + + self.log.info("Deleting Instance: %s", name) + result = client.delete_instance(name=name, retry=retry, timeout=timeout, metadata=metadata) + result.result() + self.log.info("Instance deleted: %s", name) + + @GoogleCloudBaseHook.fallback_to_default_project_id + def export_instance( + self, + location: str, + instance: str, + output_config: Union[Dict, OutputConfig], + project_id: str = None, + retry: Retry = None, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = None, + ): + """ + Export Redis instance data into a Redis RDB format file in Cloud Storage. + + Redis will continue serving during this operation. + + :param location: The location of the Cloud Memorystore instance (for example europe-west1) + :type location: str + :param instance: The logical name of the Redis instance in the customer project. + :type instance: str + :param output_config: Required. Specify data to be exported. + + If a dict is provided, it must be of the same form as the protobuf message + :class:`~google.cloud.redis_v1.types.OutputConfig` + :type output_config: Union[Dict, google.cloud.redis_v1.types.OutputConfig] + :param project_id: Project ID of the project that contains the instance. If set + to None or missing, the default project_id from the GCP connection is used. + :type project_id: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. 
+ :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_conn() + name = CloudRedisClient.instance_path(project_id, location, instance) + self.log.info("Exporting Instance: %s", name) + result = client.export_instance( + name=name, output_config=output_config, retry=retry, timeout=timeout, metadata=metadata + ) + result.result() + self.log.info("Instance exported: %s", name) + + @GoogleCloudBaseHook.fallback_to_default_project_id + def failover_instance( + self, + location: str, + instance: str, + data_protection_mode: FailoverInstanceRequest.DataProtectionMode, + project_id: str = None, + retry: Retry = None, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = None, + ): + """ + Initiates a failover of the master node to current replica node for a specific STANDARD tier Cloud + Memorystore for Redis instance. + + :param location: The location of the Cloud Memorystore instance (for example europe-west1) + :type location: str + :param instance: The logical name of the Redis instance in the customer project. + :type instance: str + :param data_protection_mode: Optional. Available data protection modes that the user can choose. If + it's unspecified, data protection mode will be LIMITED_DATA_LOSS by default. + :type data_protection_mode: google.cloud.redis_v1.gapic.enums.FailoverInstanceRequest + .DataProtectionMode + :param project_id: Project ID of the project that contains the instance. If set + to None or missing, the default project_id from the GCP connection is used. + :type project_id: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_conn() + name = CloudRedisClient.instance_path(project_id, location, instance) + self.log.info("Failovering Instance: %s", name) + + result = client.failover_instance( + name=name, + data_protection_mode=data_protection_mode, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + result.result() + self.log.info("Instance failovered: %s", name) + + @GoogleCloudBaseHook.fallback_to_default_project_id + def get_instance( + self, + location: str, + instance: str, + project_id: str = None, + retry: Retry = None, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = None, + ): + """ + Gets the details of a specific Redis instance. + + :param location: The location of the Cloud Memorystore instance (for example europe-west1) + :type location: str + :param instance: The logical name of the Redis instance in the customer project. + :type instance: str + :param project_id: Project ID of the project that contains the instance. If set + to None or missing, the default project_id from the GCP connection is used. + :type project_id: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. 
+ :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_conn() + name = CloudRedisClient.instance_path(project_id, location, instance) + result = client.get_instance(name=name, retry=retry, timeout=timeout, metadata=metadata) + self.log.info("Fetched Instance: %s", name) + return result + + @GoogleCloudBaseHook.fallback_to_default_project_id + def import_instance( + self, + location: str, + instance: str, + input_config: Union[Dict, InputConfig], + project_id: str = None, + retry: Retry = None, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = None, + ): + """ + Import a Redis RDB snapshot file from Cloud Storage into a Redis instance. + + Redis may stop serving during this operation. Instance state will be IMPORTING for entire operation. + When complete, the instance will contain only data from the imported file. + + :param location: The location of the Cloud Memorystore instance (for example europe-west1) + :type location: str + :param instance: The logical name of the Redis instance in the customer project. + :type instance: str + :param input_config: Required. Specify data to be imported. + + If a dict is provided, it must be of the same form as the protobuf message + :class:`~google.cloud.redis_v1.types.InputConfig` + :type input_config: Union[Dict, google.cloud.redis_v1.types.InputConfig] + :param project_id: Project ID of the project that contains the instance. If set + to None or missing, the default project_id from the GCP connection is used. + :type project_id: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_conn() + name = CloudRedisClient.instance_path(project_id, location, instance) + self.log.info("Importing Instance: %s", name) + result = client.import_instance( + name=name, input_config=input_config, retry=retry, timeout=timeout, metadata=metadata + ) + result.result() + self.log.info("Instance imported: %s", name) + + @GoogleCloudBaseHook.fallback_to_default_project_id + def list_instances( + self, + location: str, + page_size: int, + project_id: str = None, + retry: Retry = None, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = None, + ): + """ + Lists all Redis instances owned by a project in either the specified location (region) or all + locations. + + :param location: The location of the Cloud Memorystore instance (for example europe-west1) + + If it is specified as ``-`` (wildcard), then all regions available to the project are + queried, and the results are aggregated. + :type location: str + :param page_size: The maximum number of resources contained in the underlying API response. If page + streaming is performed per- resource, this parameter does not affect the return value. If page + streaming is performed per-page, this determines the maximum number of resources in a page. + :type page_size: int + :param project_id: Project ID of the project that contains the instance. If set + to None or missing, the default project_id from the GCP connection is used. + :type project_id: str + :param retry: A retry object used to retry requests. 
If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_conn() + parent = CloudRedisClient.location_path(project_id, location) + result = client.list_instances( + parent=parent, page_size=page_size, retry=retry, timeout=timeout, metadata=metadata + ) + self.log.info("Fetched instances") + return result + + @GoogleCloudBaseHook.fallback_to_default_project_id + def update_instance( + self, + update_mask: Union[Dict, FieldMask], + instance: Union[Dict, Instance], + location: str = None, + instance_id: str = None, + project_id: str = None, + retry: Retry = None, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = None, + ): + """ + Updates the metadata and configuration of a specific Redis instance. + + :param update_mask: Required. Mask of fields to update. At least one path must be supplied in this + field. The elements of the repeated paths field may only include these fields from ``Instance``: + + - ``displayName`` + - ``labels`` + - ``memorySizeGb`` + - ``redisConfig`` + + If a dict is provided, it must be of the same form as the protobuf message + :class:`~google.cloud.redis_v1.types.FieldMask` + :type update_mask: Union[Dict, google.cloud.redis_v1.types.FieldMask] + :param instance: Required. Update description. Only fields specified in ``update_mask`` are updated. + + If a dict is provided, it must be of the same form as the protobuf message + :class:`~google.cloud.redis_v1.types.Instance` + :type instance: Union[Dict, google.cloud.redis_v1.types.Instance] + :param location: The location of the Cloud Memorystore instance (for example europe-west1) + :type location: str + :param instance_id: The logical name of the Redis instance in the customer project. + :type instance_id: str + :param project_id: Project ID of the project that contains the instance. If set + to None or missing, the default project_id from the GCP connection is used. + :type project_id: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. 
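# Illustrative usage sketch (not part of the patch): updating only the fields named in the
# mask, as described above. The dict forms stand in for the FieldMask / Instance protos; the
# hook class name is assumed, and the mask path follows proto (snake_case) naming.
from airflow.gcp.hooks.cloud_memorystore import CloudMemorystoreHook  # assumed class/module name

hook = CloudMemorystoreHook()
hook.update_instance(
    update_mask={"paths": ["memory_size_gb"]},  # only this field will be applied
    instance={"memory_size_gb": 2},
    location="europe-west1",
    instance_id="my-redis",  # hypothetical
)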
+ :type metadata: Sequence[Tuple[str, str]] + """ + client = self.get_conn() + + if isinstance(instance, dict): + instance = ParseDict(instance, Instance()) + elif not isinstance(instance, Instance): + raise AirflowException("instance is not instance of Instance type or python dict") + + if location and instance_id: + name = CloudRedisClient.instance_path(project_id, location, instance_id) + instance.name = name + + self.log.info("Updating instances: %s", instance.name) + result = client.update_instance( + update_mask=update_mask, instance=instance, retry=retry, timeout=timeout, metadata=metadata + ) + result.result() + self.log.info("Instance updated: %s", instance.name) diff --git a/airflow/gcp/hooks/cloud_storage_transfer_service.py b/airflow/gcp/hooks/cloud_storage_transfer_service.py index 3d4e766c39931c..7d5508ec661a50 100644 --- a/airflow/gcp/hooks/cloud_storage_transfer_service.py +++ b/airflow/gcp/hooks/cloud_storage_transfer_service.py @@ -25,7 +25,7 @@ import warnings from copy import deepcopy from datetime import timedelta -from typing import Dict, List, Tuple, Union, Optional +from typing import Dict, List, Union, Set, Optional from googleapiclient.discovery import build @@ -111,7 +111,7 @@ def __init__( self, api_version: str = 'v1', gcp_conn_id: str = 'google_cloud_default', - delegate_to: str = None + delegate_to: Optional[str] = None ) -> None: super().__init__(gcp_conn_id, delegate_to) self.api_version = api_version @@ -150,7 +150,7 @@ def create_transfer_job(self, body: Dict) -> Dict: @GoogleCloudBaseHook.fallback_to_default_project_id @GoogleCloudBaseHook.catch_http_exception - def get_transfer_job(self, job_name: str, project_id: str = None) -> Dict: + def get_transfer_job(self, job_name: str, project_id: Optional[str] = None) -> Dict: """ Gets the latest state of a long-running operation in Google Storage Transfer Service. @@ -172,7 +172,7 @@ def get_transfer_job(self, job_name: str, project_id: str = None) -> Dict: .execute(num_retries=self.num_retries) ) - def list_transfer_job(self, request_filter: Dict = None, **kwargs) -> List[Dict]: + def list_transfer_job(self, request_filter: Optional[Dict] = None, **kwargs) -> List[Dict]: """ Lists long-running operations in Google Storage Transfer Service that match the specified filter. @@ -230,7 +230,7 @@ def update_transfer_job(self, job_name: str, body: Dict) -> Dict: @GoogleCloudBaseHook.fallback_to_default_project_id @GoogleCloudBaseHook.catch_http_exception - def delete_transfer_job(self, job_name: str, project_id: str = None) -> None: + def delete_transfer_job(self, job_name: str, project_id: Optional[str] = None) -> None: """ Deletes a transfer job. This is a soft delete. After a transfer job is deleted, the job and all the transfer executions are subject to garbage @@ -293,7 +293,7 @@ def get_transfer_operation(self, operation_name: str) -> Dict: ) @GoogleCloudBaseHook.catch_http_exception - def list_transfer_operations(self, request_filter: Dict = None, **kwargs) -> List[Dict]: + def list_transfer_operations(self, request_filter: Optional[Dict] = None, **kwargs) -> List[Dict]: """ Gets an transfer operation in Google Storage Transfer Service. 
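# Illustrative usage sketch (not part of the patch): fetching and soft-deleting a transfer job
# with the hook methods above. The hook class name (GCPTransferServiceHook) is assumed from the
# surrounding module; job and project identifiers are hypothetical.
from airflow.gcp.hooks.cloud_storage_transfer_service import GCPTransferServiceHook  # assumed name

hook = GCPTransferServiceHook()
job = hook.get_transfer_job(job_name="transferJobs/12345", project_id="my-project")
hook.delete_transfer_job(job_name=job["name"], project_id="my-project")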
@@ -369,7 +369,7 @@ def resume_transfer_operation(self, operation_name: str): def wait_for_transfer_job( self, job: Dict, - expected_statuses: Tuple[str] = (GcpTransferOperationStatus.SUCCESS,), + expected_statuses: Optional[Set[str]] = None, timeout: Optional[Union[float, timedelta]] = None ) -> None: """ @@ -388,6 +388,9 @@ def wait_for_transfer_job( :type timeout: Optional[Union[float, timedelta]] :rtype: None """ + expected_statuses = ( + {GcpTransferOperationStatus.SUCCESS} if not expected_statuses else expected_statuses + ) if timeout is None: timeout = 60 elif isinstance(timeout, timedelta): @@ -417,7 +420,7 @@ def _inject_project_id(self, body: Dict, param_name: str, target_key: str) -> Di @staticmethod def operations_contain_expected_statuses( operations: List[Dict], - expected_statuses: Union[Tuple[str], str] + expected_statuses: Union[Set[str], str] ) -> bool: """ Checks whether the operation list has an operation with the diff --git a/airflow/gcp/hooks/dataflow.py b/airflow/gcp/hooks/dataflow.py index 45bd24edac0dd1..f8d2ee0cb32336 100644 --- a/airflow/gcp/hooks/dataflow.py +++ b/airflow/gcp/hooks/dataflow.py @@ -26,7 +26,7 @@ import subprocess import time import uuid -from typing import Dict, List, Callable, Any, Optional +from typing import Dict, List, Callable, Any, Optional, Union from googleapiclient.discovery import build @@ -38,6 +38,11 @@ DEFAULT_DATAFLOW_LOCATION = 'us-central1' +# https://github.com/apache/beam/blob/75eee7857bb80a0cdb4ce99ae3e184101092e2ed/sdks/go/pkg/beam/runners/ +# universal/runnerlib/execute.go#L85 +JOB_ID_PATTERN = re.compile(r'Submitted job:\s+([a-z|0-9|A-Z|\-|\_]+).*') + + class DataflowJobStatus: """ Helper class with Dataflow job statuses. @@ -61,7 +66,7 @@ def __init__( name: str, location: str, poll_sleep: int = 10, - job_id: str = None, + job_id: Optional[str] = None, num_retries: int = 0, multiple_jobs: bool = False ) -> None: @@ -75,7 +80,7 @@ def __init__( self._poll_sleep = poll_sleep self._jobs = self._get_jobs() - def is_job_running(self): + def is_job_running(self) -> bool: """ Helper method to check if jos is still running in dataflow @@ -88,7 +93,7 @@ def is_job_running(self): return False # pylint: disable=too-many-nested-blocks - def _get_dataflow_jobs(self): + def _get_dataflow_jobs(self) -> List[Dict]: """ Helper method to get list of jobs that start with job name or id @@ -96,10 +101,13 @@ def _get_dataflow_jobs(self): :rtype: list """ if not self._multiple_jobs and self._job_id: - return self._dataflow.projects().locations().jobs().get( - projectId=self._project_number, - location=self._job_location, - jobId=self._job_id).execute(num_retries=self._num_retries) + return [ + self._dataflow.projects().locations().jobs().get( + projectId=self._project_number, + location=self._job_location, + jobId=self._job_id + ).execute(num_retries=self._num_retries) + ] elif self._job_name: jobs = self._dataflow.projects().locations().jobs().list( projectId=self._project_number, @@ -116,7 +124,7 @@ def _get_dataflow_jobs(self): else: raise Exception('Missing both dataflow job ID and name.') - def _get_jobs(self): + def _get_jobs(self) -> List: """ Helper method to get all jobs by name @@ -145,7 +153,7 @@ def _get_jobs(self): return self._jobs # pylint: disable=too-many-nested-blocks - def check_dataflow_job_state(self, job): + def check_dataflow_job_state(self, job) -> bool: """ Helper method to check the state of all jobs in dataflow for this task if job failed raise exception @@ -176,7 +184,7 @@ def check_dataflow_job_state(self, job): 
DataflowJobStatus.JOB_STATE_PENDING}: time.sleep(self._poll_sleep) else: - self.log.debug(str(job)) + self.log.debug("Current job: %s", str(job)) raise Exception( "Google Cloud Dataflow job {} was unknown state: {}".format( job['name'], job['currentState'])) @@ -209,7 +217,7 @@ def get(self): class _Dataflow(LoggingMixin): - def __init__(self, cmd) -> None: + def __init__(self, cmd: Union[List, str]) -> None: self.log.info("Running command: %s", ' '.join(cmd)) self._proc = subprocess.Popen( cmd, @@ -218,23 +226,22 @@ def __init__(self, cmd) -> None: stderr=subprocess.PIPE, close_fds=True) - def _line(self, fd): + def _read_line_by_fd(self, fd): if fd == self._proc.stderr.fileno(): - line = b''.join(self._proc.stderr.readlines()) + line = self._proc.stderr.readline().decode() if line: self.log.warning(line[:-1]) return line if fd == self._proc.stdout.fileno(): - line = b''.join(self._proc.stdout.readlines()) + line = self._proc.stdout.readline().decode() if line: self.log.info(line[:-1]) return line raise Exception("No data in stderr or in stdout.") - @staticmethod - def _extract_job(line: bytes) -> Optional[str]: + def _extract_job(self, line: str) -> Optional[str]: """ Extracts job_id. @@ -244,11 +251,11 @@ def _extract_job(line: bytes) -> Optional[str]: :rtype: Optional[str] """ # Job id info: https://goo.gl/SE29y9. - job_id_pattern = re.compile( - br'.*console.cloud.google.com/dataflow.*/jobs/([a-z|0-9|A-Z|\-|\_]+).*') - matched_job = job_id_pattern.search(line or b'') + matched_job = JOB_ID_PATTERN.search(line) if matched_job: - return matched_job.group(1).decode() + job_id = matched_job.group(1) + self.log.info("Found Job ID: %s", job_id) + return job_id return None def wait_for_done(self) -> Optional[str]: @@ -265,14 +272,16 @@ def wait_for_done(self) -> Optional[str]: # terminated. process_ends = False while True: - ret = select.select(reads, [], [], 5) - if ret is None: + # Wait for at least one available fd. + readable_fbs, _, _ = select.select(reads, [], [], 5) + if readable_fbs is None: self.log.info("Waiting for DataFlow process to complete.") continue - for raw_line in ret[0]: - line = self._line(raw_line) - if line: + # Read available fds. + for readable_fb in readable_fbs: + line = self._read_line_by_fd(readable_fb) + if line and not job_id: job_id = job_id or self._extract_job(line) if process_ends: @@ -297,7 +306,7 @@ class DataFlowHook(GoogleCloudBaseHook): def __init__( self, gcp_conn_id: str = 'google_cloud_default', - delegate_to: str = None, + delegate_to: Optional[str] = None, poll_sleep: int = 10 ) -> None: self.poll_sleep = poll_sleep @@ -328,7 +337,7 @@ def _start_dataflow( .wait_for_done() @staticmethod - def _set_variables(variables: Dict): + def _set_variables(variables: Dict) -> Dict: if variables['project'] is None: raise Exception('Project not specified') if 'region' not in variables.keys(): @@ -340,7 +349,7 @@ def start_java_dataflow( job_name: str, variables: Dict, jar: str, - job_class: str = None, + job_class: Optional[str] = None, append_job_name: bool = True, multiple_jobs: bool = False ) -> None: @@ -377,7 +386,7 @@ def start_template_dataflow( variables: Dict, parameters: Dict, dataflow_template: str, - append_job_name=True + append_job_name: bool = True ) -> None: """ Starts Dataflow template job. @@ -404,7 +413,8 @@ def start_python_dataflow( variables: Dict, dataflow: str, py_options: List[str], - append_job_name: bool = True + append_job_name: bool = True, + py_interpreter: str = "python2" ): """ Starts Dataflow job. 
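# Illustrative usage sketch (not part of the patch): the new py_interpreter argument lets the
# Beam pipeline run under python3 instead of the previously hard-coded "python2". Paths,
# project and bucket names are hypothetical.
from airflow.gcp.hooks.dataflow import DataFlowHook

hook = DataFlowHook(gcp_conn_id="google_cloud_default", poll_sleep=10)
hook.start_python_dataflow(
    job_name="example-wordcount",
    variables={
        "project": "my-project",
        "region": "us-central1",
        "staging_location": "gs://my-bucket/staging",
    },
    dataflow="/path/to/wordcount.py",
    py_options=[],
    py_interpreter="python3",
)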
@@ -419,6 +429,11 @@ def start_python_dataflow( :type py_options: list :param append_job_name: True if unique suffix has to be appended to job name. :type append_job_name: bool + :param py_interpreter: Python version of the beam pipeline. + If None, this defaults to the python2. + To track python versions supported by beam and related + issues check: https://issues.apache.org/jira/browse/BEAM-1251 + :type py_interpreter: str """ name = self._build_dataflow_job_name(job_name, append_job_name) variables['job_name'] = name @@ -427,7 +442,7 @@ def label_formatter(labels_dict): return ['--labels={}={}'.format(key, value) for key, value in labels_dict.items()] - self._start_dataflow(variables, name, ["python2"] + py_options + [dataflow], + self._start_dataflow(variables, name, [py_interpreter] + py_options + [dataflow], label_formatter) @staticmethod diff --git a/airflow/gcp/hooks/dataproc.py b/airflow/gcp/hooks/dataproc.py index cf846e1545af7a..75e4706243f6db 100644 --- a/airflow/gcp/hooks/dataproc.py +++ b/airflow/gcp/hooks/dataproc.py @@ -565,6 +565,143 @@ def cancel(self, project_id: str, job_id: str, region: str = 'global') -> Dict: jobId=job_id ) + def get_final_cluster_state(self, project_id, region, cluster_name, logger): + """ + Poll for the state of a cluster until one is available + + :param project_id: + :param region: + :param cluster_name: + :param logger: + :return: + """ + while True: + state = DataProcHook.get_cluster_state(self.get_conn(), project_id, region, cluster_name) + if state is None: + logger.info("No state for cluster '%s'", cluster_name) + time.sleep(15) + else: + logger.info("State for cluster '%s' is %s", cluster_name, state) + return state + + @staticmethod + def get_cluster_state(service, project_id, region, cluster_name): + """ + Get the state of a cluster if it has one, otherwise None + :param service: + :param project_id: + :param region: + :param cluster_name: + :return: + """ + cluster = DataProcHook.find_cluster(service, project_id, region, cluster_name) + if cluster and 'status' in cluster: + return cluster['status']['state'] + else: + return None + + @staticmethod + def find_cluster(service, project_id, region, cluster_name): + """ + Retrieve a cluster from the project/region if it exists, otherwise None + :param service: + :param project_id: + :param region: + :param cluster_name: + :return: + """ + cluster_list = DataProcHook.get_cluster_list_for_project(service, project_id, region) + cluster = [c for c in cluster_list if c['clusterName'] == cluster_name] + if cluster: + return cluster[0] + return None + + @staticmethod + def get_cluster_list_for_project(service, project_id, region): + """ + List all clusters for a given project/region, an empty list if none exist + :param service: + :param project_id: + :param region: + :return: + """ + result = service.projects().regions().clusters().list( + projectId=project_id, + region=region + ).execute() + return result.get('clusters', []) + + @staticmethod + def execute_dataproc_diagnose(service, project_id, region, cluster_name): + """ + Execute the diagonse command against a given cluster, useful to get debugging + information if something has gone wrong or cluster creation failed. 
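# Illustrative usage sketch (not part of the patch): polling a cluster with the new helpers
# until it reports a state, then pulling diagnostics if it ended up in ERROR. Project, region
# and cluster names are hypothetical.
import logging

from airflow.gcp.hooks.dataproc import DataProcHook

hook = DataProcHook()
service = hook.get_conn()
state = hook.get_final_cluster_state(
    "my-project", "us-central1", "my-cluster", logging.getLogger(__name__)
)
if state == "ERROR":
    operation_name = DataProcHook.execute_dataproc_diagnose(
        service, "my-project", "us-central1", "my-cluster"
    )
    DataProcHook.wait_for_operation_done_or_error(service, operation_name)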
+ :param service: + :param project_id: + :param region: + :param cluster_name: + :return: + """ + response = service.projects().regions().clusters().diagnose( + projectId=project_id, + region=region, + clusterName=cluster_name, + body={} + ).execute() + operation_name = response['name'] + return operation_name + + @staticmethod + def execute_delete(service, project_id, region, cluster_name): + """ + Delete a specified cluster + :param service: + :param project_id: + :param region: + :param cluster_name: + :return: The identifier of the operation being executed + """ + response = service.projects().regions().clusters().delete( + projectId=project_id, + region=region, + clusterName=cluster_name + ).execute(num_retries=5) + operation_name = response['name'] + return operation_name + + @staticmethod + def wait_for_operation_done(service, operation_name): + """ + Poll for the completion of a specific GCP operation + :param service: + :param operation_name: + :return: The response code of the completed operation + """ + while True: + response = service.projects().regions().operations().get( + name=operation_name + ).execute(num_retries=5) + + if response.get('done'): + return response + time.sleep(15) + + @staticmethod + def wait_for_operation_done_or_error(service, operation_name): + """ + Block until the specified operation is done. Throws an AirflowException if + the operation completed but had an error + :param service: + :param operation_name: + :return: + """ + response = DataProcHook.wait_for_operation_done(service, operation_name) + if response.get('done'): + if 'error' in response: + raise AirflowException(str(response['error'])) + else: + return + setattr( DataProcHook, diff --git a/airflow/gcp/hooks/datastore.py b/airflow/gcp/hooks/datastore.py index ca16a627221556..13a89794cfd799 100644 --- a/airflow/gcp/hooks/datastore.py +++ b/airflow/gcp/hooks/datastore.py @@ -22,6 +22,7 @@ """ import time +from typing import Any, List, Dict, Union, Optional from googleapiclient.discovery import build @@ -40,14 +41,14 @@ class DatastoreHook(GoogleCloudBaseHook): """ def __init__(self, - datastore_conn_id='google_cloud_default', - delegate_to=None, - api_version='v1'): + datastore_conn_id: str = 'google_cloud_default', + delegate_to: Optional[str] = None, + api_version: str = 'v1') -> None: super().__init__(datastore_conn_id, delegate_to) self.connection = None self.api_version = api_version - def get_conn(self): + def get_conn(self) -> Any: """ Establishes a connection to the Google API. @@ -62,7 +63,7 @@ def get_conn(self): return self.connection @GoogleCloudBaseHook.fallback_to_default_project_id - def allocate_ids(self, partial_keys, project_id=None): + def allocate_ids(self, partial_keys: List, project_id: Optional[str] = None) -> List: """ Allocate IDs for incomplete keys. @@ -76,7 +77,7 @@ def allocate_ids(self, partial_keys, project_id=None): :return: a list of full keys. :rtype: list """ - conn = self.get_conn() + conn = self.get_conn() # type: Any resp = (conn # pylint:disable=no-member .projects() @@ -86,7 +87,7 @@ def allocate_ids(self, partial_keys, project_id=None): return resp['keys'] @GoogleCloudBaseHook.fallback_to_default_project_id - def begin_transaction(self, project_id=None): + def begin_transaction(self, project_id: Optional[str] = None) -> str: """ Begins a new transaction. @@ -98,7 +99,7 @@ def begin_transaction(self, project_id=None): :return: a transaction handle. 
:rtype: str """ - conn = self.get_conn() + conn = self.get_conn() # type: Any resp = (conn # pylint:disable=no-member .projects() @@ -108,7 +109,7 @@ def begin_transaction(self, project_id=None): return resp['transaction'] @GoogleCloudBaseHook.fallback_to_default_project_id - def commit(self, body, project_id=None): + def commit(self, body: Dict, project_id: Optional[str] = None) -> Dict: """ Commit a transaction, optionally creating, deleting or modifying some entities. @@ -122,7 +123,7 @@ def commit(self, body, project_id=None): :return: the response body of the commit request. :rtype: dict """ - conn = self.get_conn() + conn = self.get_conn() # type: Any resp = (conn # pylint:disable=no-member .projects() @@ -132,7 +133,11 @@ def commit(self, body, project_id=None): return resp @GoogleCloudBaseHook.fallback_to_default_project_id - def lookup(self, keys, read_consistency=None, transaction=None, project_id=None): + def lookup(self, + keys: List, + read_consistency: Optional[str] = None, + transaction: Optional[str] = None, + project_id: Optional[str] = None) -> Dict: """ Lookup some entities by key. @@ -151,9 +156,9 @@ def lookup(self, keys, read_consistency=None, transaction=None, project_id=None) :return: the response body of the lookup request. :rtype: dict """ - conn = self.get_conn() + conn = self.get_conn() # type: Any - body = {'keys': keys} + body = {'keys': keys} # type: Dict[str, Any] if read_consistency: body['readConsistency'] = read_consistency if transaction: @@ -166,7 +171,7 @@ def lookup(self, keys, read_consistency=None, transaction=None, project_id=None) return resp @GoogleCloudBaseHook.fallback_to_default_project_id - def rollback(self, transaction, project_id=None): + def rollback(self, transaction: str, project_id: Optional[str] = None) -> Any: """ Roll back a transaction. @@ -178,14 +183,14 @@ def rollback(self, transaction, project_id=None): :param project_id: Google Cloud Platform project ID against which to make the request. :type project_id: str """ - conn = self.get_conn() + conn = self.get_conn() # type: Any conn.projects().rollback( # pylint:disable=no-member projectId=project_id, body={'transaction': transaction} ).execute(num_retries=self.num_retries) @GoogleCloudBaseHook.fallback_to_default_project_id - def run_query(self, body, project_id=None): + def run_query(self, body: Dict, project_id: Optional[str] = None) -> Dict: """ Run a query for entities. @@ -199,7 +204,7 @@ def run_query(self, body, project_id=None): :return: the batch of query results. :rtype: dict """ - conn = self.get_conn() + conn = self.get_conn() # type: Any resp = (conn # pylint:disable=no-member .projects() @@ -208,7 +213,7 @@ def run_query(self, body, project_id=None): return resp['batch'] - def get_operation(self, name): + def get_operation(self, name: str) -> Dict: """ Gets the latest state of a long-running operation. @@ -220,7 +225,7 @@ def get_operation(self, name): :return: a resource operation instance. :rtype: dict """ - conn = self.get_conn() + conn = self.get_conn() # type: Any resp = (conn # pylint:disable=no-member .projects() @@ -230,7 +235,7 @@ def get_operation(self, name): return resp - def delete_operation(self, name): + def delete_operation(self, name: str) -> Dict: """ Deletes the long-running operation. @@ -242,7 +247,7 @@ def delete_operation(self, name): :return: none if successful. 
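# Illustrative usage sketch (not part of the patch): inserting one entity inside a transaction
# with begin_transaction() and commit(). The body follows the Cloud Datastore REST commit
# format; kind, project and property names are hypothetical.
from airflow.gcp.hooks.datastore import DatastoreHook

hook = DatastoreHook(datastore_conn_id="google_cloud_default")
txn = hook.begin_transaction(project_id="my-project")
hook.commit(
    body={
        "mode": "TRANSACTIONAL",
        "transaction": txn,
        "mutations": [
            {
                "insert": {
                    "key": {
                        "partitionId": {"projectId": "my-project"},
                        "path": [{"kind": "Task"}],
                    },
                    "properties": {"done": {"booleanValue": False}},
                }
            }
        ],
    },
    project_id="my-project",
)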
:rtype: dict """ - conn = self.get_conn() + conn = self.get_conn() # type: Any resp = (conn # pylint:disable=no-member .projects() @@ -252,7 +257,7 @@ def delete_operation(self, name): return resp - def poll_operation_until_done(self, name, polling_interval_in_seconds): + def poll_operation_until_done(self, name: str, polling_interval_in_seconds: int) -> Dict: """ Poll backup operation state until it's completed. @@ -264,9 +269,9 @@ def poll_operation_until_done(self, name, polling_interval_in_seconds): :rtype: dict """ while True: - result = self.get_operation(name) + result = self.get_operation(name) # type: Dict - state = result['metadata']['common']['state'] + state = result['metadata']['common']['state'] # type: str if state == 'PROCESSING': self.log.info('Operation is processing. Re-polling state in {} seconds' .format(polling_interval_in_seconds)) @@ -275,8 +280,12 @@ def poll_operation_until_done(self, name, polling_interval_in_seconds): return result @GoogleCloudBaseHook.fallback_to_default_project_id - def export_to_storage_bucket(self, bucket, namespace=None, entity_filter=None, - labels=None, project_id=None): + def export_to_storage_bucket(self, + bucket: str, + namespace: Optional[str] = None, + entity_filter: Optional[Dict] = None, + labels: Optional[Dict[str, str]] = None, + project_id: Optional[str] = None) -> Dict: """ Export entities from Cloud Datastore to Cloud Storage for backup. @@ -299,9 +308,9 @@ def export_to_storage_bucket(self, bucket, namespace=None, entity_filter=None, :return: a resource operation instance. :rtype: dict """ - admin_conn = self.get_conn() + admin_conn = self.get_conn() # type: Any - output_uri_prefix = 'gs://' + '/'.join(filter(None, [bucket, namespace])) + output_uri_prefix = 'gs://' + '/'.join(filter(None, [bucket, namespace])) # type: str if not entity_filter: entity_filter = {} if not labels: @@ -310,7 +319,7 @@ def export_to_storage_bucket(self, bucket, namespace=None, entity_filter=None, 'outputUrlPrefix': output_uri_prefix, 'entityFilter': entity_filter, 'labels': labels, - } + } # type: Dict resp = (admin_conn # pylint:disable=no-member .projects() .export(projectId=project_id, body=body) @@ -319,8 +328,13 @@ def export_to_storage_bucket(self, bucket, namespace=None, entity_filter=None, return resp @GoogleCloudBaseHook.fallback_to_default_project_id - def import_from_storage_bucket(self, bucket, file, namespace=None, - entity_filter=None, labels=None, project_id=None): + def import_from_storage_bucket(self, + bucket: str, + file: str, + namespace: Optional[str] = None, + entity_filter: Optional[Dict] = None, + labels: Optional[Union[Dict, str]] = None, + project_id: Optional[str] = None) -> Dict: """ Import a backup from Cloud Storage to Cloud Datastore. @@ -345,9 +359,9 @@ def import_from_storage_bucket(self, bucket, file, namespace=None, :return: a resource operation instance. 
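# Illustrative usage sketch (not part of the patch): exporting entities to Cloud Storage and
# polling the returned operation until it leaves the PROCESSING state. Bucket and project
# names are hypothetical.
from airflow.gcp.hooks.datastore import DatastoreHook

hook = DatastoreHook()
operation = hook.export_to_storage_bucket(bucket="my-backup-bucket", project_id="my-project")
result = hook.poll_operation_until_done(operation["name"], polling_interval_in_seconds=10)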
:rtype: dict """ - admin_conn = self.get_conn() + admin_conn = self.get_conn() # type: Any - input_url = 'gs://' + '/'.join(filter(None, [bucket, namespace, file])) + input_url = 'gs://' + '/'.join(filter(None, [bucket, namespace, file])) # type: str if not entity_filter: entity_filter = {} if not labels: @@ -356,7 +370,7 @@ def import_from_storage_bucket(self, bucket, file, namespace=None, 'inputUrl': input_url, 'entityFilter': entity_filter, 'labels': labels, - } + } # type: Dict resp = (admin_conn # pylint:disable=no-member .projects() .import_(projectId=project_id, body=body) diff --git a/airflow/gcp/hooks/discovery_api.py b/airflow/gcp/hooks/discovery_api.py new file mode 100644 index 00000000000000..3dbd8e5c1f0dbd --- /dev/null +++ b/airflow/gcp/hooks/discovery_api.py @@ -0,0 +1,147 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +""" +This module allows you to connect to the Google Discovery API Service and query it. +""" +from googleapiclient.discovery import build + +from airflow.contrib.hooks.gcp_api_base_hook import GoogleCloudBaseHook + + +class GoogleDiscoveryApiHook(GoogleCloudBaseHook): + """ + A hook to use the Google API Discovery Service. + + :param api_service_name: The name of the api service that is needed to get the data + for example 'youtube'. + :type api_service_name: str + :param api_version: The version of the api that will be requested for example 'v3'. + :type api_version: str + :param gcp_conn_id: The connection ID to use when fetching connection info. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate, if any. + For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + """ + _conn = None + + def __init__(self, api_service_name, api_version, gcp_conn_id='google_cloud_default', delegate_to=None): + super(GoogleDiscoveryApiHook, self).__init__(gcp_conn_id=gcp_conn_id, delegate_to=delegate_to) + self.api_service_name = api_service_name + self.api_version = api_version + + def get_conn(self): + """ + Creates an authenticated api client for the given api service name and credentials. + + :return: the authenticated api service. + :rtype: Resource + """ + self.log.info("Authenticating Google API Client") + + if not self._conn: + http_authorized = self._authorize() + self._conn = build( + serviceName=self.api_service_name, + version=self.api_version, + http=http_authorized, + cache_discovery=False + ) + return self._conn + + def query(self, endpoint, data, paginate=False, num_retries=0): + """ + Creates a dynamic API call to any Google API registered in Google's API Client Library + and queries it. + + :param endpoint: The client libraries path to the api call's executing method. 
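# Illustrative usage sketch (not part of the patch): querying the Analytics Reporting API, the
# endpoint the docstring itself uses as an example. The report request body is a hypothetical
# minimal one.
from airflow.gcp.hooks.discovery_api import GoogleDiscoveryApiHook

hook = GoogleDiscoveryApiHook(api_service_name="analyticsreporting", api_version="v4")
response = hook.query(
    endpoint="analyticsreporting.reports.batchGet",
    data={"body": {"reportRequests": [{"viewId": "12345"}]}},
    paginate=False,
    num_retries=3,
)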
+ For example: 'analyticsreporting.reports.batchGet' + + .. seealso:: https://developers.google.com/apis-explorer + for more information on what methods are available. + :type endpoint: str + :param data: The data (endpoint params) needed for the specific request to given endpoint. + :type data: dict + :param paginate: If set to True, it will collect all pages of data. + :type paginate: bool + :param num_retries: Define the number of retries for the requests being made if it fails. + :type num_retries: int + :return: the API response from the passed endpoint. + :rtype: dict + """ + google_api_conn_client = self.get_conn() + + api_response = self._call_api_request(google_api_conn_client, endpoint, data, paginate, num_retries) + return api_response + + def _call_api_request(self, google_api_conn_client, endpoint, data, paginate, num_retries): + api_endpoint_parts = endpoint.split('.') + + google_api_endpoint_instance = self._build_api_request( + google_api_conn_client, + api_sub_functions=api_endpoint_parts[1:], + api_endpoint_params=data + ) + + if paginate: + return self._paginate_api( + google_api_endpoint_instance, google_api_conn_client, api_endpoint_parts, num_retries + ) + + return google_api_endpoint_instance.execute(num_retries=num_retries) + + def _build_api_request(self, google_api_conn_client, api_sub_functions, api_endpoint_params): + for sub_function in api_sub_functions: + google_api_conn_client = getattr(google_api_conn_client, sub_function) + if sub_function != api_sub_functions[-1]: + google_api_conn_client = google_api_conn_client() + else: + google_api_conn_client = google_api_conn_client(**api_endpoint_params) + + return google_api_conn_client + + def _paginate_api( + self, google_api_endpoint_instance, google_api_conn_client, api_endpoint_parts, num_retries + ): + api_responses = [] + + while google_api_endpoint_instance: + api_response = google_api_endpoint_instance.execute(num_retries=num_retries) + api_responses.append(api_response) + + google_api_endpoint_instance = self._build_next_api_request( + google_api_conn_client, api_endpoint_parts[1:], google_api_endpoint_instance, api_response + ) + + return api_responses + + def _build_next_api_request( + self, google_api_conn_client, api_sub_functions, api_endpoint_instance, api_response + ): + for sub_function in api_sub_functions: + if sub_function != api_sub_functions[-1]: + google_api_conn_client = getattr(google_api_conn_client, sub_function) + google_api_conn_client = google_api_conn_client() + else: + google_api_conn_client = getattr(google_api_conn_client, sub_function + '_next') + google_api_conn_client = google_api_conn_client(api_endpoint_instance, api_response) + + return google_api_conn_client diff --git a/airflow/gcp/hooks/gcs.py b/airflow/gcp/hooks/gcs.py new file mode 100644 index 00000000000000..0c6865ae764ec1 --- /dev/null +++ b/airflow/gcp/hooks/gcs.py @@ -0,0 +1,783 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +""" +This module contains a Google Cloud Storage hook. +""" +import os +from os import path +from typing import Optional, Set, Tuple, Union +import gzip as gz +import shutil +from io import BytesIO + +from urllib.parse import urlparse +from google.cloud import storage + +from airflow.contrib.hooks.gcp_api_base_hook import GoogleCloudBaseHook +from airflow.exceptions import AirflowException +from airflow.version import version + + +class GoogleCloudStorageHook(GoogleCloudBaseHook): + """ + Interact with Google Cloud Storage. This hook uses the Google Cloud Platform + connection. + """ + + _conn = None # type: Optional[storage.Client] + + def __init__(self, + google_cloud_storage_conn_id='google_cloud_default', + delegate_to=None): + super().__init__(google_cloud_storage_conn_id, + delegate_to) + + def get_conn(self): + """ + Returns a Google Cloud Storage service object. + """ + if not self._conn: + self._conn = storage.Client(credentials=self._get_credentials(), + client_info=self.client_info, + project=self.project_id) + + return self._conn + + def copy(self, source_bucket, source_object, destination_bucket=None, + destination_object=None): + """ + Copies an object from a bucket to another, with renaming if requested. + + destination_bucket or destination_object can be omitted, in which case + source bucket/object is used, but not both. + + :param source_bucket: The bucket of the object to copy from. + :type source_bucket: str + :param source_object: The object to copy. + :type source_object: str + :param destination_bucket: The destination of the object to copied to. + Can be omitted; then the same bucket is used. + :type destination_bucket: str + :param destination_object: The (renamed) path of the object if given. + Can be omitted; then the same name is used. + :type destination_object: str + """ + destination_bucket = destination_bucket or source_bucket + destination_object = destination_object or source_object + + if source_bucket == destination_bucket and \ + source_object == destination_object: + + raise ValueError( + 'Either source/destination bucket or source/destination object ' + 'must be different, not both the same: bucket=%s, object=%s' % + (source_bucket, source_object)) + if not source_bucket or not source_object: + raise ValueError('source_bucket and source_object cannot be empty.') + + client = self.get_conn() + source_bucket = client.bucket(source_bucket) + source_object = source_bucket.blob(source_object) + destination_bucket = client.bucket(destination_bucket) + destination_object = source_bucket.copy_blob( + blob=source_object, + destination_bucket=destination_bucket, + new_name=destination_object) + + self.log.info('Object %s in bucket %s copied to object %s in bucket %s', + source_object.name, source_bucket.name, + destination_object.name, destination_bucket.name) + + def rewrite(self, source_bucket, source_object, destination_bucket, + destination_object=None): + """ + Has the same functionality as copy, except that will work on files + over 5 TB, as well as when copying between locations and/or storage + classes. 
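# Illustrative usage sketch (not part of the patch): copy() for ordinary objects, rewrite()
# when the object may exceed 5 TB or moves across storage classes/locations. Bucket and object
# names are hypothetical.
from airflow.gcp.hooks.gcs import GoogleCloudStorageHook

hook = GoogleCloudStorageHook()
hook.copy(source_bucket="my-bucket", source_object="raw/data.csv",
          destination_bucket="my-archive", destination_object="2019/data.csv")
hook.rewrite(source_bucket="my-bucket", source_object="raw/huge.parquet",
             destination_bucket="my-coldline-bucket")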
+ + destination_object can be omitted, in which case source_object is used. + + :param source_bucket: The bucket of the object to copy from. + :type source_bucket: str + :param source_object: The object to copy. + :type source_object: str + :param destination_bucket: The destination of the object to copied to. + :type destination_bucket: str + :param destination_object: The (renamed) path of the object if given. + Can be omitted; then the same name is used. + :type destination_object: str + """ + destination_object = destination_object or source_object + if (source_bucket == destination_bucket and + source_object == destination_object): + raise ValueError( + 'Either source/destination bucket or source/destination object ' + 'must be different, not both the same: bucket=%s, object=%s' % + (source_bucket, source_object)) + if not source_bucket or not source_object: + raise ValueError('source_bucket and source_object cannot be empty.') + + client = self.get_conn() + source_bucket = client.bucket(source_bucket) + source_object = source_bucket.blob(blob_name=source_object) + destination_bucket = client.bucket(destination_bucket) + + token, bytes_rewritten, total_bytes = destination_bucket.blob( + blob_name=destination_object).rewrite( + source=source_object + ) + + self.log.info('Total Bytes: %s | Bytes Written: %s', + total_bytes, bytes_rewritten) + + while token is not None: + token, bytes_rewritten, total_bytes = destination_bucket.blob( + blob_name=destination_object).rewrite( + source=source_object, token=token + ) + + self.log.info('Total Bytes: %s | Bytes Written: %s', + total_bytes, bytes_rewritten) + self.log.info('Object %s in bucket %s rewritten to object %s in bucket %s', + source_object.name, source_bucket.name, + destination_object, destination_bucket.name) + + def download(self, bucket_name, object_name, filename=None): + """ + Downloads a file from Google Cloud Storage. + + When no filename is supplied, the operator loads the file into memory and returns its + content. When a filename is supplied, it writes the file to the specified location and + returns the location. For file sizes that exceed the available memory it is recommended + to write to a file. + + :param bucket_name: The bucket to fetch from. + :type bucket_name: str + :param object_name: The object to fetch. + :type object_name: str + :param filename: If set, a local file path where the file should be written to. + :type filename: str + """ + client = self.get_conn() + bucket = client.bucket(bucket_name) + blob = bucket.blob(blob_name=object_name) + + if filename: + blob.download_to_filename(filename) + self.log.info('File downloaded to %s', filename) + return filename + else: + return blob.download_as_string() + + def upload(self, bucket_name: str, object_name: str, filename: str = None, + data: Union[str, bytes] = None, mime_type: str = None, gzip: bool = False, + encoding: str = 'utf-8') -> None: + """ + Uploads a local file or file data as string or bytes to Google Cloud Storage. + + :param bucket_name: The bucket to upload to. + :type bucket_name: str + :param object_name: The object name to set when uploading the file. + :type object_name: str + :param filename: The local file path to the file to be uploaded. + :type filename: str + :param data: The file's data as a string or bytes to be uploaded. + :type data: str + :param mime_type: The file's mime type set when uploading the file. 
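# Illustrative usage sketch (not part of the patch): upload() takes either a local file or an
# in-memory payload, never both; gzip=True compresses before sending. Names are hypothetical.
from airflow.gcp.hooks.gcs import GoogleCloudStorageHook

hook = GoogleCloudStorageHook()
# Upload a local file, gzip-compressed on the fly.
hook.upload(bucket_name="my-bucket", object_name="data/report.csv.gz",
            filename="/tmp/report.csv", gzip=True)
# Upload string data directly.
hook.upload(bucket_name="my-bucket", object_name="data/notes.txt",
            data="hello world", mime_type="text/plain")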
+ :type mime_type: str + :param gzip: Option to compress local file or file data for upload + :type gzip: bool + :param encoding: bytes encoding for file data if provided as string + :type encoding: str + """ + client = self.get_conn() + bucket = client.bucket(bucket_name) + blob = bucket.blob(blob_name=object_name) + if filename and data: + raise ValueError("'filename' and 'data' parameter provided. Please " + "specify a single parameter, either 'filename' for " + "local file uploads or 'data' for file content uploads.") + elif filename: + if not mime_type: + mime_type = 'application/octet-stream' + if gzip: + filename_gz = filename + '.gz' + + with open(filename, 'rb') as f_in: + with gz.open(filename_gz, 'wb') as f_out: + shutil.copyfileobj(f_in, f_out) + filename = filename_gz + + blob.upload_from_filename(filename=filename, + content_type=mime_type) + if gzip: + os.remove(filename) + self.log.info('File %s uploaded to %s in %s bucket', filename, object_name, bucket_name) + elif data: + if not mime_type: + mime_type = 'text/plain' + if gzip: + if isinstance(data, str): + data = bytes(data, encoding) + out = BytesIO() + with gz.GzipFile(fileobj=out, mode="w") as f: + f.write(data) + data = out.getvalue() + blob.upload_from_string(data, + content_type=mime_type) + self.log.info('Data stream uploaded to %s in %s bucket', object_name, bucket_name) + else: + raise ValueError("'filename' and 'data' parameter missing. " + "One is required to upload to gcs.") + + def exists(self, bucket_name, object_name): + """ + Checks for the existence of a file in Google Cloud Storage. + + :param bucket_name: The Google cloud storage bucket where the object is. + :type bucket_name: str + :param object_name: The name of the blob_name to check in the Google cloud + storage bucket. + :type object_name: str + """ + client = self.get_conn() + bucket = client.bucket(bucket_name) + blob = bucket.blob(blob_name=object_name) + return blob.exists() + + def is_updated_after(self, bucket_name, object_name, ts): + """ + Checks if an blob_name is updated in Google Cloud Storage. + + :param bucket_name: The Google cloud storage bucket where the object is. + :type bucket_name: str + :param object_name: The name of the object to check in the Google cloud + storage bucket. + :type object_name: str + :param ts: The timestamp to check against. + :type ts: datetime.datetime + """ + client = self.get_conn() + bucket = client.bucket(bucket_name) + blob = bucket.get_blob(blob_name=object_name) + + if blob is None: + raise ValueError("Object ({}) not found in Bucket ({})".format( + object_name, bucket_name)) + + blob_update_time = blob.updated + + if blob_update_time is not None: + import dateutil.tz + + if not ts.tzinfo: + ts = ts.replace(tzinfo=dateutil.tz.tzutc()) + + self.log.info("Verify object date: %s > %s", blob_update_time, ts) + + if blob_update_time > ts: + return True + + return False + + def delete(self, bucket_name, object_name): + """ + Deletes an object from the bucket. 
+ + :param bucket_name: name of the bucket, where the object resides + :type bucket_name: str + :param object_name: name of the object to delete + :type object_name: str + """ + client = self.get_conn() + bucket = client.bucket(bucket_name) + blob = bucket.blob(blob_name=object_name) + blob.delete() + + self.log.info('Blob %s deleted.', object_name) + + def list(self, bucket_name, versions=None, max_results=None, prefix=None, delimiter=None): + """ + List all objects from the bucket with the give string prefix in name + + :param bucket_name: bucket name + :type bucket_name: str + :param versions: if true, list all versions of the objects + :type versions: bool + :param max_results: max count of items to return in a single page of responses + :type max_results: int + :param prefix: prefix string which filters objects whose name begin with + this prefix + :type prefix: str + :param delimiter: filters objects based on the delimiter (for e.g '.csv') + :type delimiter: str + :return: a stream of object names matching the filtering criteria + """ + client = self.get_conn() + bucket = client.bucket(bucket_name) + + ids = [] + page_token = None + while True: + blobs = bucket.list_blobs( + max_results=max_results, + page_token=page_token, + prefix=prefix, + delimiter=delimiter, + versions=versions + ) + + blob_names = [] + for blob in blobs: + blob_names.append(blob.name) + + prefixes = blobs.prefixes + if prefixes: + ids += list(prefixes) + else: + ids += blob_names + + page_token = blobs.next_page_token + if page_token is None: + # empty next page token + break + return ids + + def get_size(self, bucket_name, object_name): + """ + Gets the size of a file in Google Cloud Storage. + + :param bucket_name: The Google cloud storage bucket where the blob_name is. + :type bucket_name: str + :param object_name: The name of the object to check in the Google + cloud storage bucket_name. + :type object_name: str + + """ + self.log.info('Checking the file size of object: %s in bucket_name: %s', + object_name, + bucket_name) + client = self.get_conn() + bucket = client.bucket(bucket_name) + blob = bucket.get_blob(blob_name=object_name) + blob_size = blob.size + self.log.info('The file size of %s is %s bytes.', object_name, blob_size) + return blob_size + + def get_crc32c(self, bucket_name, object_name): + """ + Gets the CRC32c checksum of an object in Google Cloud Storage. + + :param bucket_name: The Google cloud storage bucket where the blob_name is. + :type bucket_name: str + :param object_name: The name of the object to check in the Google cloud + storage bucket_name. + :type object_name: str + """ + self.log.info('Retrieving the crc32c checksum of ' + 'object_name: %s in bucket_name: %s', object_name, bucket_name) + client = self.get_conn() + bucket = client.bucket(bucket_name) + blob = bucket.get_blob(blob_name=object_name) + blob_crc32c = blob.crc32c + self.log.info('The crc32c checksum of %s is %s', object_name, blob_crc32c) + return blob_crc32c + + def get_md5hash(self, bucket_name, object_name): + """ + Gets the MD5 hash of an object in Google Cloud Storage. + + :param bucket_name: The Google cloud storage bucket where the blob_name is. + :type bucket_name: str + :param object_name: The name of the object to check in the Google cloud + storage bucket_name. 
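# Illustrative usage sketch (not part of the patch): listing object names under a prefix and
# checking their sizes; list() pages through the results internally. Names are hypothetical.
from airflow.gcp.hooks.gcs import GoogleCloudStorageHook

hook = GoogleCloudStorageHook()
for name in hook.list(bucket_name="my-bucket", prefix="logs/2019-"):
    print(name, hook.get_size(bucket_name="my-bucket", object_name=name))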
+ :type object_name: str + """ + self.log.info('Retrieving the MD5 hash of ' + 'object: %s in bucket: %s', object_name, bucket_name) + client = self.get_conn() + bucket = client.bucket(bucket_name) + blob = bucket.get_blob(blob_name=object_name) + blob_md5hash = blob.md5_hash + self.log.info('The md5Hash of %s is %s', object_name, blob_md5hash) + return blob_md5hash + + @GoogleCloudBaseHook.catch_http_exception + @GoogleCloudBaseHook.fallback_to_default_project_id + def create_bucket(self, + bucket_name, + resource=None, + storage_class='MULTI_REGIONAL', + location='US', + project_id=None, + labels=None + ): + """ + Creates a new bucket. Google Cloud Storage uses a flat namespace, so + you can't create a bucket with a name that is already in use. + + .. seealso:: + For more information, see Bucket Naming Guidelines: + https://cloud.google.com/storage/docs/bucketnaming.html#requirements + + :param bucket_name: The name of the bucket. + :type bucket_name: str + :param resource: An optional dict with parameters for creating the bucket. + For information on available parameters, see Cloud Storage API doc: + https://cloud.google.com/storage/docs/json_api/v1/buckets/insert + :type resource: dict + :param storage_class: This defines how objects in the bucket are stored + and determines the SLA and the cost of storage. Values include + + - ``MULTI_REGIONAL`` + - ``REGIONAL`` + - ``STANDARD`` + - ``NEARLINE`` + - ``COLDLINE``. + + If this value is not specified when the bucket is + created, it will default to STANDARD. + :type storage_class: str + :param location: The location of the bucket. + Object data for objects in the bucket resides in physical storage + within this region. Defaults to US. + + .. seealso:: + https://developers.google.com/storage/docs/bucket-locations + + :type location: str + :param project_id: The ID of the GCP Project. + :type project_id: str + :param labels: User-provided labels, in key/value pairs. + :type labels: dict + :return: If successful, it returns the ``id`` of the bucket. + """ + + self.log.info('Creating Bucket: %s; Location: %s; Storage Class: %s', + bucket_name, location, storage_class) + + # Add airflow-version label to the bucket + labels = labels or {} + labels['airflow-version'] = 'v' + version.replace('.', '-').replace('+', '-') + + client = self.get_conn() + bucket = client.bucket(bucket_name=bucket_name) + bucket_resource = resource or {} + + for item in bucket_resource: + if item != "name": + bucket._patch_property(name=item, value=resource[item]) # pylint: disable=protected-access + + bucket.storage_class = storage_class + bucket.labels = labels + bucket.create(project=project_id, location=location) + return bucket.id + + def insert_bucket_acl(self, bucket_name, entity, role, user_project=None): + """ + Creates a new ACL entry on the specified bucket_name. + See: https://cloud.google.com/storage/docs/json_api/v1/bucketAccessControls/insert + + :param bucket_name: Name of a bucket_name. + :type bucket_name: str + :param entity: The entity holding the permission, in one of the following forms: + user-userId, user-email, group-groupId, group-email, domain-domain, + project-team-projectId, allUsers, allAuthenticatedUsers. + See: https://cloud.google.com/storage/docs/access-control/lists#scopes + :type entity: str + :param role: The access permission for the entity. + Acceptable values are: "OWNER", "READER", "WRITER". + :type role: str + :param user_project: (Optional) The project to be billed for this request. + Required for Requester Pays buckets. 
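# Illustrative usage sketch (not part of the patch): creating a regional bucket with a label
# and granting a group read access on it. Bucket, project and group names are hypothetical.
from airflow.gcp.hooks.gcs import GoogleCloudStorageHook

hook = GoogleCloudStorageHook()
hook.create_bucket(
    bucket_name="my-new-bucket",
    storage_class="REGIONAL",
    location="europe-west1",
    project_id="my-project",
    labels={"team": "data-eng"},
)
hook.insert_bucket_acl(bucket_name="my-new-bucket",
                       entity="group-analysts@example.com", role="READER")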
+ :type user_project: str + """ + self.log.info('Creating a new ACL entry in bucket: %s', bucket_name) + client = self.get_conn() + bucket = client.bucket(bucket_name=bucket_name) + bucket.acl.reload() + bucket.acl.entity_from_dict(entity_dict={"entity": entity, "role": role}) + if user_project: + bucket.acl.user_project = user_project + bucket.acl.save() + + self.log.info('A new ACL entry created in bucket: %s', bucket_name) + + def insert_object_acl(self, bucket_name, object_name, entity, role, generation=None, user_project=None): + """ + Creates a new ACL entry on the specified object. + See: https://cloud.google.com/storage/docs/json_api/v1/objectAccessControls/insert + + :param bucket_name: Name of a bucket_name. + :type bucket_name: str + :param object_name: Name of the object. For information about how to URL encode + object names to be path safe, see: + https://cloud.google.com/storage/docs/json_api/#encoding + :type object_name: str + :param entity: The entity holding the permission, in one of the following forms: + user-userId, user-email, group-groupId, group-email, domain-domain, + project-team-projectId, allUsers, allAuthenticatedUsers + See: https://cloud.google.com/storage/docs/access-control/lists#scopes + :type entity: str + :param role: The access permission for the entity. + Acceptable values are: "OWNER", "READER". + :type role: str + :param generation: Optional. If present, selects a specific revision of this object. + :type generation: long + :param user_project: (Optional) The project to be billed for this request. + Required for Requester Pays buckets. + :type user_project: str + """ + self.log.info('Creating a new ACL entry for object: %s in bucket: %s', + object_name, bucket_name) + client = self.get_conn() + bucket = client.bucket(bucket_name=bucket_name) + blob = bucket.blob(blob_name=object_name, generation=generation) + # Reload fetches the current ACL from Cloud Storage. + blob.acl.reload() + blob.acl.entity_from_dict(entity_dict={"entity": entity, "role": role}) + if user_project: + blob.acl.user_project = user_project + blob.acl.save() + + self.log.info('A new ACL entry created for object: %s in bucket: %s', + object_name, bucket_name) + + def compose(self, bucket_name, source_objects, destination_object): + """ + Composes a list of existing object into a new object in the same storage bucket_name + + Currently it only supports up to 32 objects that can be concatenated + in a single operation + + https://cloud.google.com/storage/docs/json_api/v1/objects/compose + + :param bucket_name: The name of the bucket containing the source objects. + This is also the same bucket to store the composed destination object. + :type bucket_name: str + :param source_objects: The list of source objects that will be composed + into a single object. + :type source_objects: list + :param destination_object: The path of the object if given. 
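# Illustrative usage sketch (not part of the patch): composing up to 32 source objects from one
# bucket into a single destination object in the same bucket. Names are hypothetical.
from airflow.gcp.hooks.gcs import GoogleCloudStorageHook

hook = GoogleCloudStorageHook()
hook.compose(
    bucket_name="my-bucket",
    source_objects=["parts/part-0001.csv", "parts/part-0002.csv"],
    destination_object="merged/full.csv",
)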
+ :type destination_object: str + """ + + if not source_objects: + raise ValueError('source_objects cannot be empty.') + + if not bucket_name or not destination_object: + raise ValueError('bucket_name and destination_object cannot be empty.') + + self.log.info("Composing %s to %s in the bucket %s", + source_objects, destination_object, bucket_name) + client = self.get_conn() + bucket = client.bucket(bucket_name) + destination_blob = bucket.blob(destination_object) + destination_blob.compose( + sources=[ + bucket.blob(blob_name=source_object) for source_object in source_objects + ]) + + self.log.info("Completed successfully.") + + def sync( + self, + source_bucket: str, + destination_bucket: str, + source_object: Optional[str] = None, + destination_object: Optional[str] = None, + recursive: bool = True, + allow_overwrite: bool = False, + delete_extra_files: bool = False + ): + """ + Synchronizes the contents of the buckets. + + Parameters ``source_object`` and ``destination_object`` describe the root sync directories. If they + are not passed, the entire bucket will be synchronized. If they are passed, they should point + to directories. + + .. note:: + The synchronization of individual files is not supported. Only entire directories can be + synchronized. + + :param source_bucket: The name of the bucket containing the source objects. + :type source_bucket: str + :param destination_bucket: The name of the bucket containing the destination objects. + :type destination_bucket: str + :param source_object: The root sync directory in the source bucket. + :type source_object: Optional[str] + :param destination_object: The root sync directory in the destination bucket. + :type destination_object: Optional[str] + :param recursive: If True, subdirectories will be considered + :type recursive: bool + :param recursive: If True, subdirectories will be considered + :type recursive: bool + :param allow_overwrite: if True, the files will be overwritten if a mismatched file is found. + By default, overwriting files is not allowed + :type allow_overwrite: bool + :param delete_extra_files: if True, deletes additional files from the source that not found in the + destination. By default extra files are not deleted. + + .. note:: + This option can delete data quickly if you specify the wrong source/destination combination. + + :type delete_extra_files: bool + :return: none + """ + client = self.get_conn() + # Create bucket object + source_bucket_obj = client.bucket(source_bucket) + destination_bucket_obj = client.bucket(destination_bucket) + # Normalize parameters when they are passed + source_object = self._normalize_directory_path(source_object) + destination_object = self._normalize_directory_path(destination_object) + # Calculate the number of characters that remove from the name, because they contain information + # about the parent's path + source_object_prefix_len = len(source_object) if source_object else 0 + # Prepare synchronization plan + to_copy_blobs, to_delete_blobs, to_rewrite_blobs = self._prepare_sync_plan( + source_bucket=source_bucket_obj, + destination_bucket=destination_bucket_obj, + source_object=source_object, + destination_object=destination_object, + recursive=recursive + ) + self.log.info( + "Planned synchronization. 
To delete blobs count: %s, to upload blobs count: %s, " + "to rewrite blobs count: %s", + len(to_delete_blobs), + len(to_copy_blobs), + len(to_rewrite_blobs), + ) + + # Copy missing object to new bucket + if not to_copy_blobs: + self.log.info("Skipped blobs copying.") + else: + for blob in to_copy_blobs: + dst_object = self._calculate_sync_destination_path( + blob, destination_object, source_object_prefix_len + ) + self.copy( + source_bucket=source_bucket_obj.name, + source_object=blob.name, + destination_bucket=destination_bucket_obj.name, + destination_object=dst_object, + ) + self.log.info("Blobs copied.") + # Delete redundant files + if not to_delete_blobs: + self.log.info("Skipped blobs deleting.") + elif delete_extra_files: + # TODO: Add batch. I tried to do it, but the Google library is not stable at the moment. + for blob in to_delete_blobs: + self.delete(blob.bucket.name, blob.name) + self.log.info("Blobs deleted.") + + # Overwrite files that are different + if not to_rewrite_blobs: + self.log.info("Skipped blobs overwriting.") + elif allow_overwrite: + for blob in to_rewrite_blobs: + dst_object = self._calculate_sync_destination_path(blob, destination_object, + source_object_prefix_len) + self.rewrite( + source_bucket=source_bucket_obj.name, + source_object=blob.name, + destination_bucket=destination_bucket_obj.name, + destination_object=dst_object, + ) + self.log.info("Blobs rewritten.") + + self.log.info("Synchronization finished.") + + def _calculate_sync_destination_path( + self, + blob: storage.Blob, + destination_object: Optional[str], + source_object_prefix_len: int + ) -> str: + return ( + path.join(destination_object, blob.name[source_object_prefix_len:]) + if destination_object + else blob.name[source_object_prefix_len:] + ) + + def _normalize_directory_path(self, source_object: Optional[str]) -> Optional[str]: + return ( + source_object + "/" if source_object and not source_object.endswith("/") else source_object + ) + + @staticmethod + def _prepare_sync_plan( + source_bucket: storage.Bucket, + destination_bucket: storage.Bucket, + source_object: Optional[str], + destination_object: Optional[str], + recursive: bool, + ) -> Tuple[Set[storage.Blob], Set[storage.Blob], Set[storage.Blob]]: + # Calculate the number of characters that remove from the name, because they contain information + # about the parent's path + source_object_prefix_len = len(source_object) if source_object else 0 + destination_object_prefix_len = len(destination_object) if destination_object else 0 + delimiter = "/" if not recursive else None + # Fetch blobs list + source_blobs = list(source_bucket.list_blobs(prefix=source_object, delimiter=delimiter)) + destination_blobs = list( + destination_bucket.list_blobs(prefix=destination_object, delimiter=delimiter)) + # Create indexes that allow you to identify blobs based on their name + source_names_index = {a.name[source_object_prefix_len:]: a for a in source_blobs} + destination_names_index = {a.name[destination_object_prefix_len:]: a for a in destination_blobs} + # Create sets with names without parent object name + source_names = set(source_names_index.keys()) + destination_names = set(destination_names_index.keys()) + # Determine objects to copy and delete + to_copy = source_names - destination_names + to_delete = destination_names - source_names + to_copy_blobs = {source_names_index[a] for a in to_copy} # type: Set[storage.Blob] + to_delete_blobs = {destination_names_index[a] for a in to_delete} # type: Set[storage.Blob] + # Find names that are in 
both buckets + names_to_check = source_names.intersection(destination_names) + to_rewrite_blobs = set() # type: Set[storage.Blob] + # Compare objects based on crc32 + for current_name in names_to_check: + source_blob = source_names_index[current_name] + destination_blob = destination_names_index[current_name] + # if the objects are different, save it + if source_blob.crc32c != destination_blob.crc32c: + to_rewrite_blobs.add(source_blob) + return to_copy_blobs, to_delete_blobs, to_rewrite_blobs + + +def _parse_gcs_url(gsurl): + """ + Given a Google Cloud Storage URL (gs:///), returns a + tuple containing the corresponding bucket and blob. + """ + + parsed_url = urlparse(gsurl) + if not parsed_url.netloc: + raise AirflowException('Please provide a bucket name') + else: + bucket = parsed_url.netloc + # Remove leading '/' but NOT trailing one + blob = parsed_url.path.lstrip('/') + return bucket, blob diff --git a/airflow/gcp/hooks/kms.py b/airflow/gcp/hooks/kms.py index 3910ca0148985d..cf49fec2b2ae58 100644 --- a/airflow/gcp/hooks/kms.py +++ b/airflow/gcp/hooks/kms.py @@ -23,41 +23,68 @@ import base64 -from googleapiclient.discovery import build +from typing import Optional, Sequence, Tuple + +from google.api_core.retry import Retry +from google.cloud.kms_v1 import KeyManagementServiceClient from airflow.contrib.hooks.gcp_api_base_hook import GoogleCloudBaseHook -def _b64encode(s): +def _b64encode(s: bytes) -> str: """ Base 64 encodes a bytes object to a string """ - return base64.b64encode(s).decode('ascii') + return base64.b64encode(s).decode("ascii") -def _b64decode(s): +def _b64decode(s: str) -> bytes: """ Base 64 decodes a string to bytes. """ - return base64.b64decode(s.encode('utf-8')) + return base64.b64decode(s.encode("utf-8")) +# noinspection PyAbstractClass class GoogleCloudKMSHook(GoogleCloudBaseHook): """ - Interact with Google Cloud KMS. This hook uses the Google Cloud Platform - connection. + Hook for Google Cloud Key Management service. + + :param gcp_conn_id: The connection ID to use when fetching connection info. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate, if any. + For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str """ - def __init__(self, gcp_conn_id: str = 'google_cloud_default', delegate_to: str = None) -> None: - super().__init__(gcp_conn_id, delegate_to=delegate_to) + def __init__( + self, + gcp_conn_id: str = "google_cloud_default", + delegate_to: Optional[str] = None + ) -> None: + super().__init__(gcp_conn_id=gcp_conn_id, delegate_to=delegate_to) + self._conn = None # type: Optional[KeyManagementServiceClient] - def get_conn(self): + def get_conn(self) -> KeyManagementServiceClient: """ - Returns a KMS service object. + Retrieves connection to Cloud Key Management service. 
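+
+        The client is created on first use and cached on the hook, so repeated calls return
+        the same object. A round trip through the hook then looks roughly like this (an
+        illustrative sketch only; the key path and payload below are placeholders)::
+
+            hook = GoogleCloudKMSHook(gcp_conn_id="google_cloud_default")
+            key = ("projects/my-project/locations/global/"
+                   "keyRings/my-keyring/cryptoKeys/my-key")
+            ciphertext = hook.encrypt(key_name=key, plaintext=b"s3cr3t")
+            assert hook.decrypt(key_name=key, ciphertext=ciphertext) == b"s3cr3t"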
- :rtype: googleapiclient.discovery.Resource + :return: Cloud Key Management service object + :rtype: google.cloud.kms_v1.KeyManagementServiceClient """ - http_authorized = self._authorize() - return build( - 'cloudkms', 'v1', http=http_authorized, cache_discovery=False) - - def encrypt(self, key_name: str, plaintext: bytes, authenticated_data: bytes = None) -> str: + if not self._conn: + self._conn = KeyManagementServiceClient( + credentials=self._get_credentials(), + client_info=self.client_info + ) + return self._conn + + def encrypt( + self, + key_name: str, + plaintext: bytes, + authenticated_data: Optional[bytes] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> str: """ Encrypts a plaintext message using Google Cloud KMS. @@ -71,20 +98,37 @@ def encrypt(self, key_name: str, plaintext: bytes, authenticated_data: bytes = N must also be provided to decrypt the message. :type authenticated_data: bytes :return: The base 64 encoded ciphertext of the original message. + :param retry: A retry object used to retry requests. If None is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + retry is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]]] :rtype: str """ - keys = self.get_conn().projects().locations().keyRings().cryptoKeys() # pylint: disable=no-member - body = {'plaintext': _b64encode(plaintext)} - if authenticated_data: - body['additionalAuthenticatedData'] = _b64encode(authenticated_data) - - request = keys.encrypt(name=key_name, body=body) - response = request.execute(num_retries=self.num_retries) - - ciphertext = response['ciphertext'] + response = self.get_conn().encrypt( + name=key_name, + plaintext=plaintext, + additional_authenticated_data=authenticated_data, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + ciphertext = _b64encode(response.ciphertext) return ciphertext - def decrypt(self, key_name: str, ciphertext: str, authenticated_data: bytes = None) -> bytes: + def decrypt( + self, + key_name: str, + ciphertext: str, + authenticated_data: Optional[bytes] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> bytes: """ Decrypts a ciphertext message using Google Cloud KMS. @@ -96,16 +140,25 @@ def decrypt(self, key_name: str, ciphertext: str, authenticated_data: bytes = No :param authenticated_data: Any additional authenticated data that was provided when encrypting the message. :type authenticated_data: bytes + :param retry: A retry object used to retry requests. If None is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + retry is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: sequence[tuple[str, str]]] :return: The original message. 
:rtype: bytes """ - keys = self.get_conn().projects().locations().keyRings().cryptoKeys() # pylint: disable=no-member - body = {'ciphertext': ciphertext} - if authenticated_data: - body['additionalAuthenticatedData'] = _b64encode(authenticated_data) - - request = keys.decrypt(name=key_name, body=body) - response = request.execute(num_retries=self.num_retries) - - plaintext = _b64decode(response['plaintext']) + response = self.get_conn().decrypt( + name=key_name, + ciphertext=_b64decode(ciphertext), + additional_authenticated_data=authenticated_data, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + plaintext = response.plaintext return plaintext diff --git a/airflow/gcp/hooks/kubernetes_engine.py b/airflow/gcp/hooks/kubernetes_engine.py index 8c60c4a2a6107c..235ec7ac102b03 100644 --- a/airflow/gcp/hooks/kubernetes_engine.py +++ b/airflow/gcp/hooks/kubernetes_engine.py @@ -22,6 +22,7 @@ """ import time +import warnings from typing import Dict, Union, Optional from google.api_core.exceptions import AlreadyExists, NotFound @@ -58,7 +59,7 @@ def __init__( self._client = None self.location = location - def get_client(self) -> container_v1.ClusterManagerClient: + def get_conn(self) -> container_v1.ClusterManagerClient: """ Returns ClusterManagerCLinet object. @@ -72,6 +73,13 @@ def get_client(self) -> container_v1.ClusterManagerClient: ) return self._client + # To preserve backward compatibility + # TODO: remove one day + def get_client(self) -> container_v1.ClusterManagerClient: # pylint: disable=missing-docstring + warnings.warn("The get_client method has been deprecated. " + "You should use the get_conn method.", DeprecationWarning) + return self.get_conn() + def wait_for_operation(self, operation: Operation, project_id: str = None) -> Operation: """ Given an operation, continuously fetches the status from Google Cloud until either @@ -106,9 +114,9 @@ def get_operation(self, operation_name: str, project_id: str = None) -> Operatio :type project_id: str :return: The new, updated operation from Google Cloud """ - return self.get_client().get_operation(project_id=project_id or self.project_id, - zone=self.location, - operation_id=operation_name) + return self.get_conn().get_operation(project_id=project_id or self.project_id, + zone=self.location, + operation_id=operation_name) @staticmethod def _append_label(cluster_proto: Cluster, key: str, val: str) -> Cluster: @@ -131,6 +139,7 @@ def _append_label(cluster_proto: Cluster, key: str, val: str) -> Cluster: cluster_proto.resource_labels.update({key: val}) return cluster_proto + @GoogleCloudBaseHook.fallback_to_default_project_id def delete_cluster( self, name: str, @@ -161,15 +170,15 @@ def delete_cluster( """ self.log.info( - "Deleting (project_id=%s, zone=%s, cluster_id=%s)", self.project_id, self.location, name + "Deleting (project_id=%s, zone=%s, cluster_id=%s)", project_id, self.location, name ) try: - resource = self.get_client().delete_cluster(project_id=project_id or self.project_id, - zone=self.location, - cluster_id=name, - retry=retry, - timeout=timeout) + resource = self.get_conn().delete_cluster(project_id=project_id, + zone=self.location, + cluster_id=name, + retry=retry, + timeout=timeout) resource = self.wait_for_operation(resource) # Returns server-defined url for the resource return resource.self_link @@ -177,6 +186,7 @@ def delete_cluster( self.log.info('Assuming Success: %s', error.message) return None + @GoogleCloudBaseHook.fallback_to_default_project_id def create_cluster( self, cluster: Union[Dict, Cluster], @@ 
-219,14 +229,14 @@ def create_cluster( self.log.info( "Creating (project_id=%s, zone=%s, cluster_name=%s)", - self.project_id, self.location, cluster.name + project_id, self.location, cluster.name ) try: - resource = self.get_client().create_cluster(project_id=project_id or self.project_id, - zone=self.location, - cluster=cluster, - retry=retry, - timeout=timeout) + resource = self.get_conn().create_cluster(project_id=project_id, + zone=self.location, + cluster=cluster, + retry=retry, + timeout=timeout) resource = self.wait_for_operation(resource) return resource.target_link @@ -234,6 +244,7 @@ def create_cluster( self.log.info('Assuming Success: %s', error.message) return self.get_cluster(name=cluster.name).self_link + @GoogleCloudBaseHook.fallback_to_default_project_id def get_cluster( self, name: str, @@ -262,8 +273,8 @@ def get_cluster( project_id or self.project_id, self.location, name ) - return self.get_client().get_cluster(project_id=project_id or self.project_id, - zone=self.location, - cluster_id=name, - retry=retry, - timeout=timeout).self_link + return self.get_conn().get_cluster(project_id=project_id, + zone=self.location, + cluster_id=name, + retry=retry, + timeout=timeout).self_link diff --git a/airflow/gcp/hooks/mlengine.py b/airflow/gcp/hooks/mlengine.py index 0358ffbffc14f4..6ef8987abd8882 100644 --- a/airflow/gcp/hooks/mlengine.py +++ b/airflow/gcp/hooks/mlengine.py @@ -1,18 +1,20 @@ # -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. """ This module contains a Google ML Engine Hook. """ diff --git a/airflow/gcp/hooks/pubsub.py b/airflow/gcp/hooks/pubsub.py index a23aa355e40361..a53f027aea466e 100644 --- a/airflow/gcp/hooks/pubsub.py +++ b/airflow/gcp/hooks/pubsub.py @@ -19,7 +19,7 @@ """ This module contains a Google Pub/Sub Hook. 
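+
+A minimal usage sketch (the project, topic and payload below are purely illustrative)::
+
+    import base64
+    from airflow.gcp.hooks.pubsub import PubSubHook
+
+    hook = PubSubHook(gcp_conn_id='google_cloud_default')
+    hook.create_topic(project='my-project', topic='my-topic')
+    # Message data is passed to the Pub/Sub v1 REST API as a base64-encoded string.
+    hook.publish(
+        project='my-project',
+        topic='my-topic',
+        messages=[{'data': base64.b64encode(b'Hello').decode('ascii')}],
+    )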
""" - +from typing import Any, List, Dict, Optional from uuid import uuid4 from googleapiclient.discovery import build @@ -53,7 +53,7 @@ class PubSubHook(GoogleCloudBaseHook): def __init__(self, gcp_conn_id: str = 'google_cloud_default', delegate_to: str = None) -> None: super().__init__(gcp_conn_id, delegate_to=delegate_to) - def get_conn(self): + def get_conn(self) -> Any: """ Returns a Pub/Sub service object. @@ -63,7 +63,7 @@ def get_conn(self): return build( 'pubsub', 'v1', http=http_authorized, cache_discovery=False) - def publish(self, project, topic, messages): + def publish(self, project: str, topic: str, messages: List[Dict]) -> None: """ Publishes messages to a Pub/Sub topic. @@ -87,7 +87,7 @@ def publish(self, project, topic, messages): raise PubSubException( 'Error publishing to topic {}'.format(full_topic), e) - def create_topic(self, project, topic, fail_if_exists=False): + def create_topic(self, project: str, topic: str, fail_if_exists: bool = False) -> None: """ Creates a Pub/Sub topic, if it does not already exist. @@ -117,7 +117,7 @@ def create_topic(self, project, topic, fail_if_exists=False): raise PubSubException( 'Error creating topic {}'.format(full_topic), e) - def delete_topic(self, project, topic, fail_if_not_exists=False): + def delete_topic(self, project: str, topic: str, fail_if_not_exists: bool = False) -> None: """ Deletes a Pub/Sub topic if it exists. @@ -146,9 +146,15 @@ def delete_topic(self, project, topic, fail_if_not_exists=False): raise PubSubException( 'Error deleting topic {}'.format(full_topic), e) - def create_subscription(self, topic_project, topic, subscription=None, - subscription_project=None, ack_deadline_secs=10, - fail_if_exists=False): + def create_subscription( + self, + topic_project: str, + topic: str, + subscription: Optional[str] = None, + subscription_project: Optional[str] = None, + ack_deadline_secs: int = 10, + fail_if_exists: bool = False, + ) -> str: """ Creates a Pub/Sub subscription, if it does not already exist. @@ -204,8 +210,7 @@ def create_subscription(self, topic_project, topic, subscription=None, e) return subscription - def delete_subscription(self, project, subscription, - fail_if_not_exists=False): + def delete_subscription(self, project: str, subscription: str, fail_if_not_exists: bool = False) -> None: """ Deletes a Pub/Sub subscription, if it exists. @@ -236,8 +241,9 @@ def delete_subscription(self, project, subscription, 'Error deleting subscription {}'.format(full_subscription), e) - def pull(self, project, subscription, max_messages, - return_immediately=False): + def pull( + self, project: str, subscription: str, max_messages: int, return_immediately: bool = False + ) -> List[Dict]: """ Pulls up to ``max_messages`` messages from Pub/Sub subscription. @@ -273,7 +279,7 @@ def pull(self, project, subscription, max_messages, 'Error pulling messages from subscription {}'.format( full_subscription), e) - def acknowledge(self, project, subscription, ack_ids): + def acknowledge(self, project: str, subscription: str, ack_ids: List) -> None: """ Pulls up to ``max_messages`` messages from Pub/Sub subscription. diff --git a/airflow/gcp/hooks/tasks.py b/airflow/gcp/hooks/tasks.py index 99f2ee85ce5809..cc7b3e572fac30 100644 --- a/airflow/gcp/hooks/tasks.py +++ b/airflow/gcp/hooks/tasks.py @@ -466,7 +466,7 @@ def create_task( :type task_name: str :param response_view: (Optional) This field specifies which subset of the Task will be returned. 
- :type response_view: google.cloud.tasks_v2.types.Task.View + :type response_view: google.cloud.tasks_v2.enums.Task.View :param retry: (Optional) A retry object used to retry requests. If None is specified, requests will not be retried. :type retry: google.api_core.retry.Retry @@ -528,7 +528,7 @@ def get_task( :type project_id: str :param response_view: (Optional) This field specifies which subset of the Task will be returned. - :type response_view: google.cloud.tasks_v2.types.Task.View + :type response_view: google.cloud.tasks_v2.enums.Task.View :param retry: (Optional) A retry object used to retry requests. If None is specified, requests will not be retried. :type retry: google.api_core.retry.Retry @@ -577,7 +577,7 @@ def list_tasks( :type project_id: str :param response_view: (Optional) This field specifies which subset of the Task will be returned. - :type response_view: google.cloud.tasks_v2.types.Task.View + :type response_view: google.cloud.tasks_v2.enums.Task.View :param page_size: (Optional) The maximum number of resources contained in the underlying API response. :type page_size: int @@ -674,7 +674,7 @@ def run_task( :type project_id: str :param response_view: (Optional) This field specifies which subset of the Task will be returned. - :type response_view: google.cloud.tasks_v2.types.Task.View + :type response_view: google.cloud.tasks_v2.enums.Task.View :param retry: (Optional) A retry object used to retry requests. If None is specified, requests will not be retried. :type retry: google.api_core.retry.Retry diff --git a/airflow/gcp/operators/automl.py b/airflow/gcp/operators/automl.py index 005229294f8022..3a3b6b7ca56dd1 100644 --- a/airflow/gcp/operators/automl.py +++ b/airflow/gcp/operators/automl.py @@ -22,7 +22,7 @@ This module contains Google AutoML operators. 
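+
+A minimal usage sketch (the DAG, project, location and ``model`` payload below are
+purely illustrative)::
+
+    from airflow import DAG
+    from airflow.utils.dates import days_ago
+    from airflow.gcp.operators.automl import AutoMLTrainModelOperator
+
+    with DAG("example_automl_model", schedule_interval=None, start_date=days_ago(1)) as dag:
+        create_model = AutoMLTrainModelOperator(
+            task_id="create_model",
+            model={
+                "display_name": "my_model",
+                "dataset_id": "my_dataset_id",
+                "tables_model_metadata": {"train_budget_milli_node_hours": 1000},
+            },
+            location="us-central1",
+            project_id="my-project",
+        )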
""" import ast -from typing import Sequence, Tuple, Union, List, Dict +from typing import Sequence, Tuple, Union, List, Dict, Optional from google.api_core.retry import Retry from google.protobuf.json_format import MessageToDict @@ -31,6 +31,8 @@ from airflow.utils.decorators import apply_defaults from airflow.gcp.hooks.automl import CloudAutoMLHook +MetaData = Sequence[Tuple[str, str]] + class AutoMLTrainModelOperator(BaseOperator): """ @@ -66,14 +68,14 @@ def __init__( self, model: dict, location: str, - project_id: str = None, - metadata: Sequence[Tuple[str, str]] = None, - timeout: float = None, - retry: Retry = None, + project_id: Optional[str] = None, + metadata: Optional[MetaData] = None, + timeout: Optional[float] = None, + retry: Optional[Retry] = None, gcp_conn_id: str = "google_cloud_default", *args, **kwargs - ): + ) -> None: super().__init__(*args, **kwargs) self.model = model @@ -142,15 +144,15 @@ def __init__( model_id: str, location: str, payload: dict, - params: Dict[str, str] = None, - project_id: str = None, - metadata: Sequence[Tuple[str, str]] = None, - timeout: float = None, - retry: Retry = None, + params: Optional[Dict[str, str]] = None, + project_id: Optional[str] = None, + metadata: Optional[MetaData] = None, + timeout: Optional[float] = None, + retry: Optional[Retry] = None, gcp_conn_id: str = "google_cloud_default", *args, **kwargs - ): + ) -> None: super().__init__(*args, **kwargs) self.model_id = model_id @@ -236,15 +238,15 @@ def __init__( # pylint:disable=too-many-arguments input_config: dict, output_config: dict, location: str, - project_id: str = None, - params: Dict[str, str] = None, - metadata: Sequence[Tuple[str, str]] = None, - timeout: float = None, - retry: Retry = None, + project_id: Optional[str] = None, + params: Optional[Dict[str, str]] = None, + metadata: Optional[MetaData] = None, + timeout: Optional[float] = None, + retry: Optional[Retry] = None, gcp_conn_id: str = "google_cloud_default", *args, **kwargs - ): + ) -> None: super().__init__(*args, **kwargs) self.model_id = model_id @@ -314,14 +316,14 @@ def __init__( self, dataset: dict, location: str, - project_id: str = None, - metadata: Sequence[Tuple[str, str]] = None, - timeout: float = None, - retry: Retry = None, + project_id: Optional[str] = None, + metadata: Optional[MetaData] = None, + timeout: Optional[float] = None, + retry: Optional[Retry] = None, gcp_conn_id: str = "google_cloud_default", *args, **kwargs - ): + ) -> None: super().__init__(*args, **kwargs) self.dataset = dataset @@ -391,14 +393,14 @@ def __init__( dataset_id: str, location: str, input_config: dict, - project_id: str = None, - metadata: Sequence[Tuple[str, str]] = None, - timeout: float = None, - retry: Retry = None, + project_id: Optional[str] = None, + metadata: Optional[MetaData] = None, + timeout: Optional[float] = None, + retry: Optional[Retry] = None, gcp_conn_id: str = "google_cloud_default", *args, **kwargs - ): + ) -> None: super().__init__(*args, **kwargs) self.dataset_id = dataset_id @@ -482,17 +484,17 @@ def __init__( # pylint:disable=too-many-arguments dataset_id: str, table_spec_id: str, location: str, - field_mask: dict = None, - filter_: str = None, - page_size: int = None, - project_id: str = None, - metadata: Sequence[Tuple[str, str]] = None, - timeout: float = None, - retry: Retry = None, + field_mask: Optional[dict] = None, + filter_: Optional[str] = None, + page_size: Optional[int] = None, + project_id: Optional[str] = None, + metadata: Optional[MetaData] = None, + timeout: Optional[float] = 
None, + retry: Optional[Retry] = None, gcp_conn_id: str = "google_cloud_default", *args, **kwargs - ): + ) -> None: super().__init__(*args, **kwargs) self.dataset_id = dataset_id self.table_spec_id = table_spec_id @@ -567,15 +569,15 @@ def __init__( self, dataset: dict, location: str, - project_id: str = None, - update_mask: dict = None, - metadata: Sequence[Tuple[str, str]] = None, - timeout: float = None, - retry: Retry = None, + project_id: Optional[str] = None, + update_mask: Optional[dict] = None, + metadata: Optional[MetaData] = None, + timeout: Optional[float] = None, + retry: Optional[Retry] = None, gcp_conn_id: str = "google_cloud_default", *args, **kwargs - ): + ) -> None: super().__init__(*args, **kwargs) self.dataset = dataset @@ -638,14 +640,14 @@ def __init__( self, model_id: str, location: str, - project_id: str = None, - metadata: Sequence[Tuple[str, str]] = None, - timeout: float = None, - retry: Retry = None, + project_id: Optional[str] = None, + metadata: Optional[MetaData] = None, + timeout: Optional[float] = None, + retry: Optional[Retry] = None, gcp_conn_id: str = "google_cloud_default", *args, **kwargs - ): + ) -> None: super().__init__(*args, **kwargs) self.model_id = model_id @@ -705,14 +707,14 @@ def __init__( self, model_id: str, location: str, - project_id: str = None, - metadata: Sequence[Tuple[str, str]] = None, - timeout: float = None, - retry: Retry = None, + project_id: Optional[str] = None, + metadata: Optional[MetaData] = None, + timeout: Optional[float] = None, + retry: Optional[Retry] = None, gcp_conn_id: str = "google_cloud_default", *args, **kwargs - ): + ) -> None: super().__init__(*args, **kwargs) self.model_id = model_id @@ -782,15 +784,15 @@ def __init__( self, model_id: str, location: str, - project_id: str = None, - image_detection_metadata: dict = None, + project_id: Optional[str] = None, + image_detection_metadata: Optional[dict] = None, metadata: Sequence[Tuple[str, str]] = None, - timeout: float = None, - retry: Retry = None, + timeout: Optional[float] = None, + retry: Optional[Retry] = None, gcp_conn_id: str = "google_cloud_default", *args, **kwargs - ): + ) -> None: super().__init__(*args, **kwargs) self.model_id = model_id @@ -862,16 +864,16 @@ def __init__( self, dataset_id: str, location: str, - page_size: int = None, - filter_: str = None, - project_id: str = None, - metadata: Sequence[Tuple[str, str]] = None, - timeout: float = None, - retry: Retry = None, + page_size: Optional[int] = None, + filter_: Optional[str] = None, + project_id: Optional[str] = None, + metadata: Optional[MetaData] = None, + timeout: Optional[float] = None, + retry: Optional[Retry] = None, gcp_conn_id: str = "google_cloud_default", *args, **kwargs - ): + ) -> None: super().__init__(*args, **kwargs) self.dataset_id = dataset_id self.filter_ = filter_ @@ -933,14 +935,14 @@ class AutoMLListDatasetOperator(BaseOperator): def __init__( self, location: str, - project_id: str = None, - metadata: Sequence[Tuple[str, str]] = None, - timeout: float = None, - retry: Retry = None, + project_id: Optional[str] = None, + metadata: Optional[MetaData] = None, + timeout: Optional[float] = None, + retry: Optional[Retry] = None, gcp_conn_id: str = "google_cloud_default", *args, **kwargs - ): + ) -> None: super().__init__(*args, **kwargs) self.location = location self.project_id = project_id @@ -1005,14 +1007,14 @@ def __init__( self, dataset_id: Union[str, List[str]], location: str, - project_id: str = None, - metadata: Sequence[Tuple[str, str]] = None, - timeout: float = None, - 
retry: Retry = None, + project_id: Optional[str] = None, + metadata: Optional[MetaData] = None, + timeout: Optional[float] = None, + retry: Optional[Retry] = None, gcp_conn_id: str = "google_cloud_default", *args, **kwargs - ): + ) -> None: super().__init__(*args, **kwargs) self.dataset_id = dataset_id diff --git a/airflow/gcp/operators/bigquery.py b/airflow/gcp/operators/bigquery.py new file mode 100644 index 00000000000000..342a08a28080fe --- /dev/null +++ b/airflow/gcp/operators/bigquery.py @@ -0,0 +1,1301 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +This module contains Google BigQuery operators. +""" +# pylint:disable=too-many-lines + +import json +import warnings +from typing import Iterable, List, Optional, Union, Dict, Any, SupportsAbs + +from airflow.exceptions import AirflowException +from airflow.models.baseoperator import BaseOperator, BaseOperatorLink +from airflow.models.taskinstance import TaskInstance +from airflow.utils.decorators import apply_defaults +from airflow.operators.check_operator import \ + CheckOperator, ValueCheckOperator, IntervalCheckOperator +from airflow.gcp.hooks.bigquery import BigQueryHook +from airflow.gcp.hooks.gcs import GoogleCloudStorageHook, _parse_gcs_url + + +BIGQUERY_JOB_DETAILS_LINK_FMT = 'https://console.cloud.google.com/bigquery?j={job_id}' + + +class BigQueryCheckOperator(CheckOperator): + """ + Performs checks against BigQuery. The ``BigQueryCheckOperator`` expects + a sql query that will return a single row. Each value on that + first row is evaluated using python ``bool`` casting. If any of the + values return ``False`` the check is failed and errors out. + + Note that Python bool casting evals the following as ``False``: + + * ``False`` + * ``0`` + * Empty string (``""``) + * Empty list (``[]``) + * Empty dictionary or set (``{}``) + + Given a query like ``SELECT COUNT(*) FROM foo``, it will fail only if + the count ``== 0``. You can craft much more complex query that could, + for instance, check that the table has the same number of rows as + the source table upstream, or that the count of today's partition is + greater than yesterday's partition, or that a set of metrics are less + than 3 standard deviation for the 7 day average. + + This operator can be used as a data quality check in your pipeline, and + depending on where you put it in your DAG, you have the choice to + stop the critical path, preventing from + publishing dubious data, or on the side and receive email alerts + without stopping the progress of the DAG. + + :param sql: the sql to be executed + :type sql: str + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud Platform. 
+ :type gcp_conn_id: str + :param bigquery_conn_id: (Deprecated) The connection ID used to connect to Google Cloud Platform. + This parameter has been deprecated. You should pass the gcp_conn_id parameter instead. + :type bigquery_conn_id: str + :param use_legacy_sql: Whether to use legacy SQL (true) + or standard SQL (false). + :type use_legacy_sql: bool + """ + + template_fields = ('sql', 'gcp_conn_id', ) + template_ext = ('.sql', ) + + @apply_defaults + def __init__(self, + sql: str, + gcp_conn_id: str = 'google_cloud_default', + bigquery_conn_id: Optional[str] = None, + use_legacy_sql: bool = True, + *args, **kwargs) -> None: + super().__init__(sql=sql, *args, **kwargs) + if not bigquery_conn_id: + warnings.warn( + "The bigquery_conn_id parameter has been deprecated. You should pass " + "the gcp_conn_id parameter.", DeprecationWarning, stacklevel=3) + gcp_conn_id = bigquery_conn_id # type: ignore + + self.gcp_conn_id = gcp_conn_id + self.sql = sql + self.use_legacy_sql = use_legacy_sql + + def get_db_hook(self): + return BigQueryHook(bigquery_conn_id=self.gcp_conn_id, + use_legacy_sql=self.use_legacy_sql) + + +class BigQueryValueCheckOperator(ValueCheckOperator): + """ + Performs a simple value check using sql code. + + :param sql: the sql to be executed + :type sql: str + :param use_legacy_sql: Whether to use legacy SQL (true) + or standard SQL (false). + :type use_legacy_sql: bool + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud Platform. + :type gcp_conn_id: str + :param bigquery_conn_id: (Deprecated) The connection ID used to connect to Google Cloud Platform. + This parameter has been deprecated. You should pass the gcp_conn_id parameter instead. + :type bigquery_conn_id: str + """ + + template_fields = ('sql', 'gcp_conn_id', 'pass_value', ) + template_ext = ('.sql', ) + + @apply_defaults + def __init__(self, sql: str, + pass_value: Any, + tolerance: Any = None, + gcp_conn_id: str = 'google_cloud_default', + bigquery_conn_id: Optional[str] = None, + use_legacy_sql: bool = True, + *args, **kwargs) -> None: + super().__init__( + sql=sql, pass_value=pass_value, tolerance=tolerance, + *args, **kwargs) + + if bigquery_conn_id: + warnings.warn( + "The bigquery_conn_id parameter has been deprecated. You should pass " + "the gcp_conn_id parameter.", DeprecationWarning, stacklevel=3) + gcp_conn_id = bigquery_conn_id + + self.gcp_conn_id = gcp_conn_id + self.use_legacy_sql = use_legacy_sql + + def get_db_hook(self): + return BigQueryHook(bigquery_conn_id=self.gcp_conn_id, + use_legacy_sql=self.use_legacy_sql) + + +class BigQueryIntervalCheckOperator(IntervalCheckOperator): + """ + Checks that the values of metrics given as SQL expressions are within + a certain tolerance of the ones from days_back before. + + This method constructs a query like so :: + + SELECT {metrics_threshold_dict_key} FROM {table} + WHERE {date_filter_column}= + + :param table: the table name + :type table: str + :param days_back: number of days between ds and the ds we want to check + against. Defaults to 7 days + :type days_back: int + :param metrics_threshold: a dictionary of ratios indexed by metrics, for + example 'COUNT(*)': 1.5 would require a 50 percent or less difference + between the current day, and the prior days_back. + :type metrics_threshold: dict + :param use_legacy_sql: Whether to use legacy SQL (true) + or standard SQL (false). + :type use_legacy_sql: bool + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud Platform. 
+ :type gcp_conn_id: str + :param bigquery_conn_id: (Deprecated) The connection ID used to connect to Google Cloud Platform. + This parameter has been deprecated. You should pass the gcp_conn_id parameter instead. + :type bigquery_conn_id: str + """ + + template_fields = ('table', 'gcp_conn_id', ) + + @apply_defaults + def __init__(self, + table: str, + metrics_thresholds: dict, + date_filter_column: str = 'ds', + days_back: SupportsAbs[int] = -7, + gcp_conn_id: str = 'google_cloud_default', + bigquery_conn_id: Optional[str] = None, + use_legacy_sql: bool = True, + *args, + **kwargs) -> None: + super().__init__( + table=table, metrics_thresholds=metrics_thresholds, + date_filter_column=date_filter_column, days_back=days_back, + *args, **kwargs) + + if bigquery_conn_id: + warnings.warn( + "The bigquery_conn_id parameter has been deprecated. You should pass " + "the gcp_conn_id parameter.", DeprecationWarning, stacklevel=3) + gcp_conn_id = bigquery_conn_id + + self.gcp_conn_id = gcp_conn_id + self.use_legacy_sql = use_legacy_sql + + def get_db_hook(self): + return BigQueryHook(bigquery_conn_id=self.gcp_conn_id, + use_legacy_sql=self.use_legacy_sql) + + +class BigQueryGetDataOperator(BaseOperator): + """ + Fetches the data from a BigQuery table (alternatively fetch data for selected columns) + and returns data in a python list. The number of elements in the returned list will + be equal to the number of rows fetched. Each element in the list will again be a list + where element would represent the columns values for that row. + + **Example Result**: ``[['Tony', '10'], ['Mike', '20'], ['Steve', '15']]`` + + .. note:: + If you pass fields to ``selected_fields`` which are in different order than the + order of columns already in + BQ table, the data will still be in the order of BQ table. + For example if the BQ table has 3 columns as + ``[A,B,C]`` and you pass 'B,A' in the ``selected_fields`` + the data would still be of the form ``'A,B'``. + + **Example**: :: + + get_data = BigQueryGetDataOperator( + task_id='get_data_from_bq', + dataset_id='test_dataset', + table_id='Transaction_partitions', + max_results='100', + selected_fields='DATE', + gcp_conn_id='airflow-conn-id' + ) + + :param dataset_id: The dataset ID of the requested table. (templated) + :type dataset_id: str + :param table_id: The table ID of the requested table. (templated) + :type table_id: str + :param max_results: The maximum number of records (rows) to be fetched + from the table. (templated) + :type max_results: str + :param selected_fields: List of fields to return (comma-separated). If + unspecified, all fields are returned. + :type selected_fields: str + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud Platform. + :type gcp_conn_id: str + :param bigquery_conn_id: (Deprecated) The connection ID used to connect to Google Cloud Platform. + This parameter has been deprecated. You should pass the gcp_conn_id parameter instead. + :type bigquery_conn_id: str + :param delegate_to: The account to impersonate, if any. + For this to work, the service account making the request must have domain-wide + delegation enabled. + :type delegate_to: str + :param location: The location used for the operation. 
+ :type location: str + """ + template_fields = ('dataset_id', 'table_id', 'max_results') + ui_color = '#e4f0e8' + + @apply_defaults + def __init__(self, + dataset_id: str, + table_id: str, + max_results: str = '100', + selected_fields: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + bigquery_conn_id: Optional[str] = None, + delegate_to: Optional[str] = None, + location: str = None, + *args, + **kwargs) -> None: + super().__init__(*args, **kwargs) + + if bigquery_conn_id: + warnings.warn( + "The bigquery_conn_id parameter has been deprecated. You should pass " + "the gcp_conn_id parameter.", DeprecationWarning, stacklevel=3) + gcp_conn_id = bigquery_conn_id + + self.dataset_id = dataset_id + self.table_id = table_id + self.max_results = max_results + self.selected_fields = selected_fields + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.location = location + + def execute(self, context): + self.log.info('Fetching Data from:') + self.log.info('Dataset: %s ; Table: %s ; Max Results: %s', + self.dataset_id, self.table_id, self.max_results) + + hook = BigQueryHook(bigquery_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + location=self.location) + + conn = hook.get_conn() + cursor = conn.cursor() + response = cursor.get_tabledata(dataset_id=self.dataset_id, + table_id=self.table_id, + max_results=self.max_results, + selected_fields=self.selected_fields) + + self.log.info('Total Extracted rows: %s', response['totalRows']) + rows = response['rows'] + + table_data = [] + for dict_row in rows: + single_row = [] + for fields in dict_row['f']: + single_row.append(fields['v']) + table_data.append(single_row) + + return table_data + + +class BigQueryConsoleLink(BaseOperatorLink): + """ + Helper class for constructing BigQuery link. + """ + name = 'BigQuery Console' + + def get_link(self, operator, dttm): + ti = TaskInstance(task=operator, execution_date=dttm) + job_id = ti.xcom_pull(task_ids=operator.task_id, key='job_id') + return BIGQUERY_JOB_DETAILS_LINK_FMT.format(job_id=job_id) if job_id else '' + + +class BigQueryConsoleIndexableLink(BaseOperatorLink): + """ + Helper class for constructing BigQuery link. + """ + + def __init__(self, index) -> None: + super().__init__() + self.index = index + + @property + def name(self) -> str: + return 'BigQuery Console #{index}'.format(index=self.index + 1) + + def get_link(self, operator, dttm): + ti = TaskInstance(task=operator, execution_date=dttm) + job_ids = ti.xcom_pull(task_ids=operator.task_id, key='job_id') + if not job_ids: + return None + if len(job_ids) < self.index: + return None + job_id = job_ids[self.index] + return BIGQUERY_JOB_DETAILS_LINK_FMT.format(job_id=job_id) + + +# pylint: disable=too-many-instance-attributes +class BigQueryOperator(BaseOperator): + """ + Executes BigQuery SQL queries in a specific BigQuery database + + :param sql: the sql code to be executed (templated) + :type sql: Can receive a str representing a sql statement, + a list of str (sql statements), or reference to a template file. + Template reference are recognized by str ending in '.sql'. + :param destination_dataset_table: A dotted + ``(.|:).
`` that, if set, will store the results + of the query. (templated) + :type destination_dataset_table: str + :param write_disposition: Specifies the action that occurs if the destination table + already exists. (default: 'WRITE_EMPTY') + :type write_disposition: str + :param create_disposition: Specifies whether the job is allowed to create new tables. + (default: 'CREATE_IF_NEEDED') + :type create_disposition: str + :param allow_large_results: Whether to allow large results. + :type allow_large_results: bool + :param flatten_results: If true and query uses legacy SQL dialect, flattens + all nested and repeated fields in the query results. ``allow_large_results`` + must be ``true`` if this is set to ``false``. For standard SQL queries, this + flag is ignored and results are never flattened. + :type flatten_results: bool + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud Platform. + :type gcp_conn_id: str + :param bigquery_conn_id: (Deprecated) The connection ID used to connect to Google Cloud Platform. + This parameter has been deprecated. You should pass the gcp_conn_id parameter instead. + :type bigquery_conn_id: str + :param delegate_to: The account to impersonate, if any. + For this to work, the service account making the request must have domain-wide + delegation enabled. + :type delegate_to: str + :param udf_config: The User Defined Function configuration for the query. + See https://cloud.google.com/bigquery/user-defined-functions for details. + :type udf_config: list + :param use_legacy_sql: Whether to use legacy SQL (true) or standard SQL (false). + :type use_legacy_sql: bool + :param maximum_billing_tier: Positive integer that serves as a multiplier + of the basic price. + Defaults to None, in which case it uses the value set in the project. + :type maximum_billing_tier: int + :param maximum_bytes_billed: Limits the bytes billed for this job. + Queries that will have bytes billed beyond this limit will fail + (without incurring a charge). If unspecified, this will be + set to your project default. + :type maximum_bytes_billed: float + :param api_resource_configs: a dictionary that contain params + 'configuration' applied for Google BigQuery Jobs API: + https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs + for example, {'query': {'useQueryCache': False}}. You could use it + if you need to provide some params that are not supported by BigQueryOperator + like args. + :type api_resource_configs: dict + :param schema_update_options: Allows the schema of the destination + table to be updated as a side effect of the load job. + :type schema_update_options: Optional[Union[list, tuple, set]] + :param query_params: a list of dictionary containing query parameter types and + values, passed to BigQuery. The structure of dictionary should look like + 'queryParameters' in Google BigQuery Jobs API: + https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs. + For example, [{ 'name': 'corpus', 'parameterType': { 'type': 'STRING' }, + 'parameterValue': { 'value': 'romeoandjuliet' } }]. + :type query_params: list + :param labels: a dictionary containing labels for the job/query, + passed to BigQuery + :type labels: dict + :param priority: Specifies a priority for the query. + Possible values include INTERACTIVE and BATCH. + The default value is INTERACTIVE. + :type priority: str + :param time_partitioning: configure optional time partitioning fields i.e. + partition by field, type and expiration as per API specifications. 
+ :type time_partitioning: dict + :param cluster_fields: Request that the result of this query be stored sorted + by one or more columns. This is only available in conjunction with + time_partitioning. The order of columns given determines the sort order. + :type cluster_fields: list[str] + :param location: The geographic location of the job. Required except for + US and EU. See details at + https://cloud.google.com/bigquery/docs/locations#specifying_your_location + :type location: str + :param encryption_configuration: [Optional] Custom encryption configuration (e.g., Cloud KMS keys). + **Example**: :: + + encryption_configuration = { + "kmsKeyName": "projects/testp/locations/us/keyRings/test-kr/cryptoKeys/test-key" + } + :type encryption_configuration: dict + """ + + template_fields = ('sql', 'destination_dataset_table', 'labels') + template_ext = ('.sql', ) + ui_color = '#e4f0e8' + + @property + def operator_extra_links(self): + """ + Return operator extra links + """ + if isinstance(self.sql, str): + return ( + BigQueryConsoleLink(), + ) + return ( + BigQueryConsoleIndexableLink(i) for i, _ in enumerate(self.sql) + ) + + # pylint: disable=too-many-arguments, too-many-locals + @apply_defaults + def __init__(self, + sql: Union[str, Iterable], + destination_dataset_table: Optional[str] = None, + write_disposition: Optional[str] = 'WRITE_EMPTY', + allow_large_results: Optional[bool] = False, + flatten_results: Optional[bool] = None, + gcp_conn_id: Optional[str] = 'google_cloud_default', + bigquery_conn_id: Optional[str] = None, + delegate_to: Optional[str] = None, + udf_config: Optional[list] = None, + use_legacy_sql: Optional[bool] = True, + maximum_billing_tier: Optional[int] = None, + maximum_bytes_billed: Optional[float] = None, + create_disposition: Optional[str] = 'CREATE_IF_NEEDED', + schema_update_options: Optional[Union[list, tuple, set]] = None, + query_params: Optional[list] = None, + labels: Optional[dict] = None, + priority: Optional[str] = 'INTERACTIVE', + time_partitioning: Optional[dict] = None, + api_resource_configs: Optional[dict] = None, + cluster_fields: Optional[List[str]] = None, + location: Optional[str] = None, + encryption_configuration: Optional[dict] = None, + *args, + **kwargs) -> None: + super().__init__(*args, **kwargs) + + if bigquery_conn_id: + warnings.warn( + "The bigquery_conn_id parameter has been deprecated. 
You should pass " + "the gcp_conn_id parameter.", DeprecationWarning, stacklevel=3) + gcp_conn_id = bigquery_conn_id + + self.sql = sql + self.destination_dataset_table = destination_dataset_table + self.write_disposition = write_disposition + self.create_disposition = create_disposition + self.allow_large_results = allow_large_results + self.flatten_results = flatten_results + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.udf_config = udf_config + self.use_legacy_sql = use_legacy_sql + self.maximum_billing_tier = maximum_billing_tier + self.maximum_bytes_billed = maximum_bytes_billed + self.schema_update_options = schema_update_options + self.query_params = query_params + self.labels = labels + self.bq_cursor = None + self.priority = priority + self.time_partitioning = time_partitioning + self.api_resource_configs = api_resource_configs + self.cluster_fields = cluster_fields + self.location = location + self.encryption_configuration = encryption_configuration + + def execute(self, context): + if self.bq_cursor is None: + self.log.info('Executing: %s', self.sql) + hook = BigQueryHook( + bigquery_conn_id=self.gcp_conn_id, + use_legacy_sql=self.use_legacy_sql, + delegate_to=self.delegate_to, + location=self.location, + ) + conn = hook.get_conn() + self.bq_cursor = conn.cursor() + if isinstance(self.sql, str): + job_id = self.bq_cursor.run_query( + sql=self.sql, + destination_dataset_table=self.destination_dataset_table, + write_disposition=self.write_disposition, + allow_large_results=self.allow_large_results, + flatten_results=self.flatten_results, + udf_config=self.udf_config, + maximum_billing_tier=self.maximum_billing_tier, + maximum_bytes_billed=self.maximum_bytes_billed, + create_disposition=self.create_disposition, + query_params=self.query_params, + labels=self.labels, + schema_update_options=self.schema_update_options, + priority=self.priority, + time_partitioning=self.time_partitioning, + api_resource_configs=self.api_resource_configs, + cluster_fields=self.cluster_fields, + encryption_configuration=self.encryption_configuration + ) + elif isinstance(self.sql, Iterable): + job_id = [ + self.bq_cursor.run_query( + sql=s, + destination_dataset_table=self.destination_dataset_table, + write_disposition=self.write_disposition, + allow_large_results=self.allow_large_results, + flatten_results=self.flatten_results, + udf_config=self.udf_config, + maximum_billing_tier=self.maximum_billing_tier, + maximum_bytes_billed=self.maximum_bytes_billed, + create_disposition=self.create_disposition, + query_params=self.query_params, + labels=self.labels, + schema_update_options=self.schema_update_options, + priority=self.priority, + time_partitioning=self.time_partitioning, + api_resource_configs=self.api_resource_configs, + cluster_fields=self.cluster_fields, + encryption_configuration=self.encryption_configuration + ) + for s in self.sql] + else: + raise AirflowException( + "argument 'sql' of type {} is neither a string nor an iterable".format(type(str))) + context['task_instance'].xcom_push(key='job_id', value=job_id) + + def on_kill(self): + super().on_kill() + if self.bq_cursor is not None: + self.log.info('Cancelling running query') + self.bq_cursor.cancel_query() + + +class BigQueryCreateEmptyTableOperator(BaseOperator): + """ + Creates a new, empty table in the specified BigQuery dataset, + optionally with schema. + + The schema to be used for the BigQuery table may be specified in one of + two ways. 
You may either directly pass the schema fields in, or you may + point the operator to a Google cloud storage object name. The object in + Google cloud storage must be a JSON file with the schema fields in it. + You can also create a table without schema. + + :param project_id: The project to create the table into. (templated) + :type project_id: str + :param dataset_id: The dataset to create the table into. (templated) + :type dataset_id: str + :param table_id: The Name of the table to be created. (templated) + :type table_id: str + :param schema_fields: If set, the schema field list as defined here: + https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs#configuration.load.schema + + **Example**: :: + + schema_fields=[{"name": "emp_name", "type": "STRING", "mode": "REQUIRED"}, + {"name": "salary", "type": "INTEGER", "mode": "NULLABLE"}] + + :type schema_fields: list + :param gcs_schema_object: Full path to the JSON file containing + schema (templated). For + example: ``gs://test-bucket/dir1/dir2/employee_schema.json`` + :type gcs_schema_object: str + :param time_partitioning: configure optional time partitioning fields i.e. + partition by field, type and expiration as per API specifications. + + .. seealso:: + https://cloud.google.com/bigquery/docs/reference/rest/v2/tables#timePartitioning + :type time_partitioning: dict + :param bigquery_conn_id: (Optional) The connection ID used to connect to Google Cloud Platform and + interact with the Bigquery service. + :type bigquery_conn_id: str + :param google_cloud_storage_conn_id: (Optional) The connection ID used to connect to Google Cloud + Platform and interact with the Google Cloud Storage service. + :type google_cloud_storage_conn_id: str + :param delegate_to: The account to impersonate, if any. For this to + work, the service account making the request must have domain-wide + delegation enabled. + :type delegate_to: str + :param labels: a dictionary containing labels for the table, passed to BigQuery + + **Example (with schema JSON in GCS)**: :: + + CreateTable = BigQueryCreateEmptyTableOperator( + task_id='BigQueryCreateEmptyTableOperator_task', + dataset_id='ODS', + table_id='Employees', + project_id='internal-gcp-project', + gcs_schema_object='gs://schema-bucket/employee_schema.json', + bigquery_conn_id='airflow-conn-id', + google_cloud_storage_conn_id='airflow-conn-id' + ) + + **Corresponding Schema file** (``employee_schema.json``): :: + + [ + { + "mode": "NULLABLE", + "name": "emp_name", + "type": "STRING" + }, + { + "mode": "REQUIRED", + "name": "salary", + "type": "INTEGER" + } + ] + + **Example (with schema in the DAG)**: :: + + CreateTable = BigQueryCreateEmptyTableOperator( + task_id='BigQueryCreateEmptyTableOperator_task', + dataset_id='ODS', + table_id='Employees', + project_id='internal-gcp-project', + schema_fields=[{"name": "emp_name", "type": "STRING", "mode": "REQUIRED"}, + {"name": "salary", "type": "INTEGER", "mode": "NULLABLE"}], + bigquery_conn_id='airflow-conn-id-account', + google_cloud_storage_conn_id='airflow-conn-id' + ) + :type labels: dict + :param encryption_configuration: [Optional] Custom encryption configuration (e.g., Cloud KMS keys). + **Example**: :: + + encryption_configuration = { + "kmsKeyName": "projects/testp/locations/us/keyRings/test-kr/cryptoKeys/test-key" + } + :type encryption_configuration: dict + :param location: The location used for the operation. 
+ :type location: str + """ + template_fields = ('dataset_id', 'table_id', 'project_id', + 'gcs_schema_object', 'labels') + ui_color = '#f0eee4' + + # pylint: disable=too-many-arguments + @apply_defaults + def __init__(self, + dataset_id: str, + table_id: str, + project_id: Optional[str] = None, + schema_fields: Optional[List] = None, + gcs_schema_object: Optional[str] = None, + time_partitioning: Optional[Dict] = None, + bigquery_conn_id: str = 'google_cloud_default', + google_cloud_storage_conn_id: str = 'google_cloud_default', + delegate_to: Optional[str] = None, + labels: Optional[Dict] = None, + encryption_configuration: Optional[Dict] = None, + location: str = None, + *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + self.project_id = project_id + self.dataset_id = dataset_id + self.table_id = table_id + self.schema_fields = schema_fields + self.gcs_schema_object = gcs_schema_object + self.bigquery_conn_id = bigquery_conn_id + self.google_cloud_storage_conn_id = google_cloud_storage_conn_id + self.delegate_to = delegate_to + self.time_partitioning = {} if time_partitioning is None else time_partitioning + self.labels = labels + self.encryption_configuration = encryption_configuration + self.location = location + + def execute(self, context): + bq_hook = BigQueryHook(bigquery_conn_id=self.bigquery_conn_id, + delegate_to=self.delegate_to, + location=self.location) + + if not self.schema_fields and self.gcs_schema_object: + + gcs_bucket, gcs_object = _parse_gcs_url(self.gcs_schema_object) + + gcs_hook = GoogleCloudStorageHook( + google_cloud_storage_conn_id=self.google_cloud_storage_conn_id, + delegate_to=self.delegate_to) + schema_fields = json.loads(gcs_hook.download( + gcs_bucket, + gcs_object).decode("utf-8")) + else: + schema_fields = self.schema_fields + + conn = bq_hook.get_conn() + cursor = conn.cursor() + + cursor.create_empty_table( + project_id=self.project_id, + dataset_id=self.dataset_id, + table_id=self.table_id, + schema_fields=schema_fields, + time_partitioning=self.time_partitioning, + labels=self.labels, + encryption_configuration=self.encryption_configuration + ) + + +# pylint: disable=too-many-instance-attributes +class BigQueryCreateExternalTableOperator(BaseOperator): + """ + Creates a new external table in the dataset with the data in Google Cloud + Storage. + + The schema to be used for the BigQuery table may be specified in one of + two ways. You may either directly pass the schema fields in, or you may + point the operator to a Google cloud storage object name. The object in + Google cloud storage must be a JSON file with the schema fields in it. + + :param bucket: The bucket to point the external table to. (templated) + :type bucket: str + :param source_objects: List of Google cloud storage URIs to point + table to. (templated) + If source_format is 'DATASTORE_BACKUP', the list must only contain a single URI. + :type source_objects: list + :param destination_project_dataset_table: The dotted ``(.).
`` + BigQuery table to load data into (templated). If ```` is not included, + project will be the project defined in the connection json. + :type destination_project_dataset_table: str + :param schema_fields: If set, the schema field list as defined here: + https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs#configuration.load.schema + + **Example**: :: + + schema_fields=[{"name": "emp_name", "type": "STRING", "mode": "REQUIRED"}, + {"name": "salary", "type": "INTEGER", "mode": "NULLABLE"}] + + Should not be set when source_format is 'DATASTORE_BACKUP'. + :type schema_fields: list + :param schema_object: If set, a GCS object path pointing to a .json file that + contains the schema for the table. (templated) + :type schema_object: str + :param source_format: File format of the data. + :type source_format: str + :param compression: [Optional] The compression type of the data source. + Possible values include GZIP and NONE. + The default value is NONE. + This setting is ignored for Google Cloud Bigtable, + Google Cloud Datastore backups and Avro formats. + :type compression: str + :param skip_leading_rows: Number of rows to skip when loading from a CSV. + :type skip_leading_rows: int + :param field_delimiter: The delimiter to use for the CSV. + :type field_delimiter: str + :param max_bad_records: The maximum number of bad records that BigQuery can + ignore when running the job. + :type max_bad_records: int + :param quote_character: The value that is used to quote data sections in a CSV file. + :type quote_character: str + :param allow_quoted_newlines: Whether to allow quoted newlines (true) or not (false). + :type allow_quoted_newlines: bool + :param allow_jagged_rows: Accept rows that are missing trailing optional columns. + The missing values are treated as nulls. If false, records with missing trailing + columns are treated as bad records, and if there are too many bad records, an + invalid error is returned in the job result. Only applicable to CSV, ignored + for other formats. + :type allow_jagged_rows: bool + :param bigquery_conn_id: (Optional) The connection ID used to connect to Google Cloud Platform and + interact with the Bigquery service. + :type bigquery_conn_id: str + :param google_cloud_storage_conn_id: (Optional) The connection ID used to connect to Google Cloud + Platform and interact with the Google Cloud Storage service. + cloud storage hook. + :type google_cloud_storage_conn_id: str + :param delegate_to: The account to impersonate, if any. For this to + work, the service account making the request must have domain-wide + delegation enabled. + :type delegate_to: str + :param src_fmt_configs: configure optional fields specific to the source format + :type src_fmt_configs: dict + :param labels: a dictionary containing labels for the table, passed to BigQuery + :type labels: dict + :param encryption_configuration: [Optional] Custom encryption configuration (e.g., Cloud KMS keys). + **Example**: :: + + encryption_configuration = { + "kmsKeyName": "projects/testp/locations/us/keyRings/test-kr/cryptoKeys/test-key" + } + :type encryption_configuration: dict + :param location: The location used for the operation. 
+ :type location: str + """ + template_fields = ('bucket', 'source_objects', + 'schema_object', 'destination_project_dataset_table', 'labels') + ui_color = '#f0eee4' + + # pylint: disable=too-many-arguments + @apply_defaults + def __init__(self, + bucket: str, + source_objects: List, + destination_project_dataset_table: str, + schema_fields: Optional[List] = None, + schema_object: Optional[str] = None, + source_format: str = 'CSV', + compression: str = 'NONE', + skip_leading_rows: int = 0, + field_delimiter: str = ',', + max_bad_records: int = 0, + quote_character: Optional[str] = None, + allow_quoted_newlines: bool = False, + allow_jagged_rows: bool = False, + bigquery_conn_id: str = 'google_cloud_default', + google_cloud_storage_conn_id: str = 'google_cloud_default', + delegate_to: Optional[str] = None, + src_fmt_configs: Optional[dict] = None, + labels: Optional[Dict] = None, + encryption_configuration: Optional[Dict] = None, + location: str = None, + *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + # GCS config + self.bucket = bucket + self.source_objects = source_objects + self.schema_object = schema_object + + # BQ config + self.destination_project_dataset_table = destination_project_dataset_table + self.schema_fields = schema_fields + self.source_format = source_format + self.compression = compression + self.skip_leading_rows = skip_leading_rows + self.field_delimiter = field_delimiter + self.max_bad_records = max_bad_records + self.quote_character = quote_character + self.allow_quoted_newlines = allow_quoted_newlines + self.allow_jagged_rows = allow_jagged_rows + + self.bigquery_conn_id = bigquery_conn_id + self.google_cloud_storage_conn_id = google_cloud_storage_conn_id + self.delegate_to = delegate_to + + self.src_fmt_configs = src_fmt_configs if src_fmt_configs is not None else dict() + self.labels = labels + self.encryption_configuration = encryption_configuration + self.location = location + + def execute(self, context): + bq_hook = BigQueryHook(bigquery_conn_id=self.bigquery_conn_id, + delegate_to=self.delegate_to, + location=self.location) + + if not self.schema_fields and self.schema_object and self.source_format != 'DATASTORE_BACKUP': + gcs_hook = GoogleCloudStorageHook( + google_cloud_storage_conn_id=self.google_cloud_storage_conn_id, + delegate_to=self.delegate_to) + schema_fields = json.loads(gcs_hook.download( + self.bucket, + self.schema_object).decode("utf-8")) + else: + schema_fields = self.schema_fields + + source_uris = ['gs://{}/{}'.format(self.bucket, source_object) + for source_object in self.source_objects] + conn = bq_hook.get_conn() + cursor = conn.cursor() + + cursor.create_external_table( + external_project_dataset_table=self.destination_project_dataset_table, + schema_fields=schema_fields, + source_uris=source_uris, + source_format=self.source_format, + compression=self.compression, + skip_leading_rows=self.skip_leading_rows, + field_delimiter=self.field_delimiter, + max_bad_records=self.max_bad_records, + quote_character=self.quote_character, + allow_quoted_newlines=self.allow_quoted_newlines, + allow_jagged_rows=self.allow_jagged_rows, + src_fmt_configs=self.src_fmt_configs, + labels=self.labels, + encryption_configuration=self.encryption_configuration + ) + + +class BigQueryDeleteDatasetOperator(BaseOperator): + """ + This operator deletes an existing dataset from your Project in Big query. + https://cloud.google.com/bigquery/docs/reference/rest/v2/datasets/delete + + :param project_id: The project id of the dataset. 
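# --- Editor's sketch (not part of the patch): using BigQueryCreateExternalTableOperator ---
# A minimal, hypothetical example of pointing an external table at CSV files in GCS; the
# bucket, dataset and table names are placeholders and the import path is assumed.
from datetime import datetime

from airflow import DAG
from airflow.gcp.operators.bigquery import BigQueryCreateExternalTableOperator  # path assumed

with DAG("example_bq_external_table", start_date=datetime(2019, 1, 1), schedule_interval=None) as dag:
    create_external_table = BigQueryCreateExternalTableOperator(
        task_id="create_external_table",
        bucket="my-landing-bucket",
        source_objects=["sales/2019/*.csv"],
        destination_project_dataset_table="my_dataset.sales_external",
        schema_fields=[
            {"name": "emp_name", "type": "STRING", "mode": "REQUIRED"},
            {"name": "salary", "type": "INTEGER", "mode": "NULLABLE"},
        ],
        source_format="CSV",
        skip_leading_rows=1,
    )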
+ :type project_id: str + :param dataset_id: The dataset to be deleted. + :type dataset_id: str + :param delete_contents: (Optional) Whether to force the deletion even if the dataset is not empty. + Will delete all tables (if any) in the dataset if set to True. + Will raise HttpError 400: "{dataset_id} is still in use" if set to False and dataset is not empty. + The default value is False. + :type delete_contents: bool + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud Platform. + :type gcp_conn_id: str + :param bigquery_conn_id: (Deprecated) The connection ID used to connect to Google Cloud Platform. + This parameter has been deprecated. You should pass the gcp_conn_id parameter instead. + :type bigquery_conn_id: str + + **Example**: :: + + delete_temp_data = BigQueryDeleteDatasetOperator( + dataset_id='temp-dataset', + project_id='temp-project', + delete_contents=True, # Force the deletion of the dataset as well as its tables (if any). + gcp_conn_id='_my_gcp_conn_', + task_id='Deletetemp', + dag=dag) + """ + + template_fields = ('dataset_id', 'project_id') + ui_color = '#f00004' + + @apply_defaults + def __init__(self, + dataset_id: str, + project_id: Optional[str] = None, + delete_contents: bool = False, + gcp_conn_id: str = 'google_cloud_default', + bigquery_conn_id: Optional[str] = None, + delegate_to: Optional[str] = None, + *args, **kwargs) -> None: + if bigquery_conn_id: + warnings.warn( + "The bigquery_conn_id parameter has been deprecated. You should pass " + "the gcp_conn_id parameter.", DeprecationWarning, stacklevel=3) + gcp_conn_id = bigquery_conn_id + + self.dataset_id = dataset_id + self.project_id = project_id + self.delete_contents = delete_contents + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + + super().__init__(*args, **kwargs) + + def execute(self, context): + self.log.info('Dataset id: %s Project id: %s', self.dataset_id, self.project_id) + + bq_hook = BigQueryHook(bigquery_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to) + + conn = bq_hook.get_conn() + cursor = conn.cursor() + + cursor.delete_dataset( + project_id=self.project_id, + dataset_id=self.dataset_id, + delete_contents=self.delete_contents + ) + + +class BigQueryCreateEmptyDatasetOperator(BaseOperator): + """ + This operator is used to create new dataset for your Project in Big query. + https://cloud.google.com/bigquery/docs/reference/rest/v2/datasets#resource + + :param project_id: The name of the project where we want to create the dataset. + Don't need to provide, if projectId in dataset_reference. + :type project_id: str + :param dataset_id: The id of dataset. Don't need to provide, + if datasetId in dataset_reference. + :type dataset_id: str + :param location: (Optional) The geographic location where the dataset should reside. + There is no default value but the dataset will be created in US if nothing is provided. + :type location: str + :param dataset_reference: Dataset reference that could be provided with request body. + More info: + https://cloud.google.com/bigquery/docs/reference/rest/v2/datasets#resource + :type dataset_reference: dict + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud Platform. + :type gcp_conn_id: str + :param bigquery_conn_id: (Deprecated) The connection ID used to connect to Google Cloud Platform. + This parameter has been deprecated. You should pass the gcp_conn_id parameter instead. 
+    :type bigquery_conn_id: str
+
+    **Example**: ::
+
+        create_new_dataset = BigQueryCreateEmptyDatasetOperator(
+            dataset_id='new-dataset',
+            project_id='my-project',
+            dataset_reference={"friendlyName": "New Dataset"},
+            gcp_conn_id='_my_gcp_conn_',
+            task_id='newDatasetCreator',
+            dag=dag)
+    """
+
+    template_fields = ('dataset_id', 'project_id')
+    ui_color = '#f0eee4'
+
+    @apply_defaults
+    def __init__(self,
+                 dataset_id: str,
+                 project_id: Optional[str] = None,
+                 dataset_reference: Optional[Dict] = None,
+                 location: Optional[str] = None,
+                 gcp_conn_id: str = 'google_cloud_default',
+                 bigquery_conn_id: Optional[str] = None,
+                 delegate_to: Optional[str] = None,
+                 *args, **kwargs) -> None:
+
+        if bigquery_conn_id:
+            warnings.warn(
+                "The bigquery_conn_id parameter has been deprecated. You should pass "
+                "the gcp_conn_id parameter.", DeprecationWarning, stacklevel=3)
+            gcp_conn_id = bigquery_conn_id
+
+        self.dataset_id = dataset_id
+        self.project_id = project_id
+        self.location = location
+        self.gcp_conn_id = gcp_conn_id
+        self.dataset_reference = dataset_reference if dataset_reference else {}
+        self.delegate_to = delegate_to
+
+        super().__init__(*args, **kwargs)
+
+    def execute(self, context):
+        self.log.info('Dataset id: %s Project id: %s', self.dataset_id, self.project_id)
+
+        bq_hook = BigQueryHook(bigquery_conn_id=self.gcp_conn_id,
+                               delegate_to=self.delegate_to,
+                               location=self.location)
+
+        conn = bq_hook.get_conn()
+        cursor = conn.cursor()
+
+        cursor.create_empty_dataset(
+            project_id=self.project_id,
+            dataset_id=self.dataset_id,
+            dataset_reference=self.dataset_reference,
+            location=self.location)
+
+
+class BigQueryGetDatasetOperator(BaseOperator):
+    """
+    This operator is used to return the dataset specified by dataset_id.
+
+    :param dataset_id: The id of the dataset to fetch.
+    :type dataset_id: str
+    :param project_id: The project id of the dataset. If None, the default project_id
+        from the GCP connection is used.
+    :type project_id: str
+    :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud Platform.
+    :type gcp_conn_id: str
+    :rtype: dataset
+        https://cloud.google.com/bigquery/docs/reference/rest/v2/datasets#resource
+    """
+
+    template_fields = ('dataset_id', 'project_id')
+    ui_color = '#f0eee4'
+
+    @apply_defaults
+    def __init__(self,
+                 dataset_id: str,
+                 project_id: Optional[str] = None,
+                 gcp_conn_id: str = 'google_cloud_default',
+                 delegate_to: Optional[str] = None,
+                 *args, **kwargs) -> None:
+        self.dataset_id = dataset_id
+        self.project_id = project_id
+        self.gcp_conn_id = gcp_conn_id
+        self.delegate_to = delegate_to
+        super().__init__(*args, **kwargs)
+
+    def execute(self, context):
+        bq_hook = BigQueryHook(bigquery_conn_id=self.gcp_conn_id,
+                               delegate_to=self.delegate_to)
+        conn = bq_hook.get_conn()
+        cursor = conn.cursor()
+
+        self.log.info('Start getting dataset: %s:%s', self.project_id, self.dataset_id)
+
+        return cursor.get_dataset(
+            dataset_id=self.dataset_id,
+            project_id=self.project_id)
+
+
+class BigQueryPatchDatasetOperator(BaseOperator):
+    """
+    This operator is used to patch a dataset for your project in BigQuery.
+    It only replaces fields that are provided in the submitted dataset resource.
+
+    :param dataset_id: The id of the dataset to patch.
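# --- Editor's sketch (not part of the patch): reading a dataset with BigQueryGetDatasetOperator ---
# The operator returns the dataset resource from execute(), so it lands in XCom and a downstream
# task can template it. All ids are hypothetical; the import paths are assumed.
from datetime import datetime

from airflow import DAG
from airflow.gcp.operators.bigquery import BigQueryGetDatasetOperator  # path assumed
from airflow.operators.bash_operator import BashOperator

with DAG("example_bq_get_dataset", start_date=datetime(2019, 1, 1), schedule_interval=None) as dag:
    get_dataset = BigQueryGetDatasetOperator(
        task_id="get_dataset",
        dataset_id="my_dataset",
        project_id="example-project",
    )
    show_dataset = BashOperator(
        task_id="show_dataset",
        bash_command="echo '{{ ti.xcom_pull(task_ids=\"get_dataset\") }}'",
    )
    get_dataset >> show_dataset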
+ :type dataset_id: str + :param dataset_resource: Dataset resource that will be provided with request body. + https://cloud.google.com/bigquery/docs/reference/rest/v2/datasets#resource + :type dataset_resource: dict + :param project_id: The name of the project where we want to create the dataset. + Don't need to provide, if projectId in dataset_reference. + :type project_id: str + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud Platform. + :type gcp_conn_id: str + :rtype: dataset + https://cloud.google.com/bigquery/docs/reference/rest/v2/datasets#resource + """ + + template_fields = ('dataset_id', 'project_id') + ui_color = '#f0eee4' + + @apply_defaults + def __init__(self, + dataset_id: str, + dataset_resource: dict, + project_id: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + delegate_to: Optional[str] = None, + *args, **kwargs) -> None: + self.dataset_id = dataset_id + self.project_id = project_id + self.gcp_conn_id = gcp_conn_id + self.dataset_resource = dataset_resource + self.delegate_to = delegate_to + super().__init__(*args, **kwargs) + + def execute(self, context): + bq_hook = BigQueryHook(bigquery_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to) + + conn = bq_hook.get_conn() + cursor = conn.cursor() + + self.log.info('Start patching dataset: %s:%s', self.project_id, self.dataset_id) + + return cursor.patch_dataset( + dataset_id=self.dataset_id, + dataset_resource=self.dataset_resource, + project_id=self.project_id) + + +class BigQueryUpdateDatasetOperator(BaseOperator): + """ + This operator is used to update dataset for your Project in BigQuery. + The update method replaces the entire dataset resource, whereas the patch + method only replaces fields that are provided in the submitted dataset resource. + + :param dataset_id: The id of dataset. Don't need to provide, + if datasetId in dataset_reference. + :type dataset_id: str + :param dataset_resource: Dataset resource that will be provided with request body. + https://cloud.google.com/bigquery/docs/reference/rest/v2/datasets#resource + :type dataset_resource: dict + :param project_id: The name of the project where we want to create the dataset. + Don't need to provide, if projectId in dataset_reference. + :type project_id: str + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud Platform. 
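# --- Editor's sketch (not part of the patch): patching a single dataset field ---
# BigQueryPatchDatasetOperator only touches the keys present in ``dataset_resource``, while the
# update operator defined just below replaces the whole resource. Ids are placeholders; the
# import path is assumed.
from datetime import datetime

from airflow import DAG
from airflow.gcp.operators.bigquery import BigQueryPatchDatasetOperator  # path assumed

with DAG("example_bq_patch_dataset", start_date=datetime(2019, 1, 1), schedule_interval=None) as dag:
    patch_dataset = BigQueryPatchDatasetOperator(
        task_id="patch_dataset",
        dataset_id="my_dataset",
        project_id="example-project",
        dataset_resource={"friendlyName": "Sales", "description": "Curated sales data"},
    )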
+ :type gcp_conn_id: str + :rtype: dataset + https://cloud.google.com/bigquery/docs/reference/rest/v2/datasets#resource + """ + + template_fields = ('dataset_id', 'project_id') + ui_color = '#f0eee4' + + @apply_defaults + def __init__(self, + dataset_id: str, + dataset_resource: dict, + project_id: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + delegate_to: Optional[str] = None, + *args, **kwargs) -> None: + self.dataset_id = dataset_id + self.project_id = project_id + self.gcp_conn_id = gcp_conn_id + self.dataset_resource = dataset_resource + self.delegate_to = delegate_to + super().__init__(*args, **kwargs) + + def execute(self, context): + bq_hook = BigQueryHook(bigquery_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to) + + conn = bq_hook.get_conn() + cursor = conn.cursor() + + self.log.info('Start updating dataset: %s:%s', self.project_id, self.dataset_id) + + return cursor.update_dataset( + dataset_id=self.dataset_id, + dataset_resource=self.dataset_resource, + project_id=self.project_id) + + +class BigQueryTableDeleteOperator(BaseOperator): + """ + Deletes BigQuery tables + + :param deletion_dataset_table: A dotted + ``(.|:).
`` that indicates which table + will be deleted. (templated) + :type deletion_dataset_table: str + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud Platform. + :type gcp_conn_id: str + :param bigquery_conn_id: (Deprecated) The connection ID used to connect to Google Cloud Platform. + This parameter has been deprecated. You should pass the gcp_conn_id parameter instead. + :type bigquery_conn_id: str + :param delegate_to: The account to impersonate, if any. + For this to work, the service account making the request must have domain-wide + delegation enabled. + :type delegate_to: str + :param ignore_if_missing: if True, then return success even if the + requested table does not exist. + :type ignore_if_missing: bool + :param location: The location used for the operation. + :type location: str + """ + template_fields = ('deletion_dataset_table',) + ui_color = '#ffd1dc' + + @apply_defaults + def __init__(self, + deletion_dataset_table: str, + gcp_conn_id: str = 'google_cloud_default', + bigquery_conn_id: Optional[str] = None, + delegate_to: Optional[str] = None, + ignore_if_missing: bool = False, + location: str = None, + *args, + **kwargs) -> None: + super().__init__(*args, **kwargs) + + if bigquery_conn_id: + warnings.warn( + "The bigquery_conn_id parameter has been deprecated. You should pass " + "the gcp_conn_id parameter.", DeprecationWarning, stacklevel=3) + gcp_conn_id = bigquery_conn_id + + self.deletion_dataset_table = deletion_dataset_table + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.ignore_if_missing = ignore_if_missing + self.location = location + + def execute(self, context): + self.log.info('Deleting: %s', self.deletion_dataset_table) + hook = BigQueryHook(bigquery_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + location=self.location) + conn = hook.get_conn() + cursor = conn.cursor() + cursor.run_table_delete( + deletion_dataset_table=self.deletion_dataset_table, + ignore_if_missing=self.ignore_if_missing) diff --git a/airflow/gcp/operators/bigquery_dts.py b/airflow/gcp/operators/bigquery_dts.py new file mode 100644 index 00000000000000..2fde488ac16409 --- /dev/null +++ b/airflow/gcp/operators/bigquery_dts.py @@ -0,0 +1,256 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +This module contains Google BigQuery Data Transfer Service operators. 
+""" +from typing import Sequence, Tuple +from google.protobuf.json_format import MessageToDict +from google.api_core.retry import Retry + +from airflow.gcp.hooks.bigquery_dts import ( + BiqQueryDataTransferServiceHook, + get_object_id, +) +from airflow.models import BaseOperator +from airflow.utils.decorators import apply_defaults + + +class BigQueryCreateDataTransferOperator(BaseOperator): + """ + Creates a new data transfer configuration. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:BigQueryCreateDataTransferOperator` + + :param transfer_config: Data transfer configuration to create. + :type transfer_config: dict + :param project_id: The BigQuery project id where the transfer configuration should be + created. If set to None or missing, the default project_id from the GCP connection is used. + :type project_id: str + :param authorization_code: authorization code to use with this transfer configuration. + This is required if new credentials are needed. + :type authorization_code: Optional[str] + :param retry: A retry object used to retry requests. If `None` is + specified, requests will not be retried. + :type retry: Optional[google.api_core.retry.Retry] + :param timeout: The amount of time, in seconds, to wait for the request to + complete. Note that if retry is specified, the timeout applies to each individual + attempt. + :type timeout: Optional[float] + :param metadata: Additional metadata that is provided to the method. + :type metadata: Optional[Sequence[Tuple[str, str]]] + :param gcp_conn_id: The connection ID used to connect to Google Cloud Platform. + :type gcp_conn_id: str + """ + + template_fields = ( + "transfer_config", + "project_id", + "authorization_code", + "gcp_conn_id", + ) + + @apply_defaults + def __init__( + self, + transfer_config: dict, + project_id: str = None, + authorization_code: str = None, + retry: Retry = None, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = None, + gcp_conn_id="google_cloud_default", + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.transfer_config = transfer_config + self.authorization_code = authorization_code + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + + def execute(self, context): + hook = BiqQueryDataTransferServiceHook(gcp_conn_id=self.gcp_conn_id) + self.log.info("Creating DTS transfer config") + response = hook.create_transfer_config( + project_id=self.project_id, + transfer_config=self.transfer_config, + authorization_code=self.authorization_code, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + result = MessageToDict(response) + self.log.info("Created DTS transfer config %s", get_object_id(result)) + self.xcom_push(context, key="transfer_config_id", value=get_object_id(result)) + return result + + +class BigQueryDeleteDataTransferConfigOperator(BaseOperator): + """ + Deletes transfer configuration. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:BigQueryDeleteDataTransferConfigOperator` + + :param transfer_config_id: Id of transfer config to be used. + :type transfer_config_id: str + :param project_id: The BigQuery project id where the transfer configuration should be + created. If set to None or missing, the default project_id from the GCP connection is used. 
+ :type project_id: str + :param retry: A retry object used to retry requests. If `None` is + specified, requests will not be retried. + :type retry: Optional[google.api_core.retry.Retry] + :param timeout: The amount of time, in seconds, to wait for the request to + complete. Note that if retry is specified, the timeout applies to each individual + attempt. + :type timeout: Optional[float] + :param metadata: Additional metadata that is provided to the method. + :type metadata: Optional[Sequence[Tuple[str, str]]] + :param gcp_conn_id: The connection ID used to connect to Google Cloud Platform. + :type gcp_conn_id: str + """ + + template_fields = ("transfer_config_id", "project_id", "gcp_conn_id") + + @apply_defaults + def __init__( + self, + transfer_config_id: str, + project_id: str = None, + retry: Retry = None, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = None, + gcp_conn_id="google_cloud_default", + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.project_id = project_id + self.transfer_config_id = transfer_config_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + + def execute(self, context): + hook = BiqQueryDataTransferServiceHook(gcp_conn_id=self.gcp_conn_id) + hook.delete_transfer_config( + transfer_config_id=self.transfer_config_id, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + + +class BigQueryDataTransferServiceStartTransferRunsOperator(BaseOperator): + """ + Start manual transfer runs to be executed now with schedule_time equal + to current time. The transfer runs can be created for a time range where + the run_time is between start_time (inclusive) and end_time + (exclusive), or for a specific run_time. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:BigQueryDataTransferServiceStartTransferRunsOperator` + + :param transfer_config_id: Id of transfer config to be used. + :type transfer_config_id: str + :param requested_time_range: Time range for the transfer runs that should be started. + If a dict is provided, it must be of the same form as the protobuf + message `~google.cloud.bigquery_datatransfer_v1.types.TimeRange` + :type requested_time_range: Union[dict, ~google.cloud.bigquery_datatransfer_v1.types.TimeRange] + :param requested_run_time: Specific run_time for a transfer run to be started. The + requested_run_time must not be in the future. If a dict is provided, it + must be of the same form as the protobuf message + `~google.cloud.bigquery_datatransfer_v1.types.Timestamp` + :type requested_run_time: Union[dict, ~google.cloud.bigquery_datatransfer_v1.types.Timestamp] + :param project_id: The BigQuery project id where the transfer configuration should be + created. If set to None or missing, the default project_id from the GCP connection is used. + :type project_id: str + :param retry: A retry object used to retry requests. If `None` is + specified, requests will not be retried. + :type retry: Optional[google.api_core.retry.Retry] + :param timeout: The amount of time, in seconds, to wait for the request to + complete. Note that if retry is specified, the timeout applies to each individual + attempt. + :type timeout: Optional[float] + :param metadata: Additional metadata that is provided to the method. + :type metadata: Optional[Sequence[Tuple[str, str]]] + :param gcp_conn_id: The connection ID used to connect to Google Cloud Platform. 
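# --- Editor's sketch (not part of the patch): chaining the three DTS operators via XCom ---
# ``transfer_config_id`` is listed in template_fields, so the id pushed to XCom by the create
# task can be pulled with Jinja. Names are placeholders and TRANSFER_CONFIG is the hypothetical
# dict from the previous sketch.
from datetime import datetime

from airflow import DAG
from airflow.gcp.operators.bigquery_dts import (
    BigQueryCreateDataTransferOperator,
    BigQueryDataTransferServiceStartTransferRunsOperator,
    BigQueryDeleteDataTransferConfigOperator,
)

CONFIG_ID = "{{ ti.xcom_pull(task_ids='create_transfer', key='transfer_config_id') }}"

with DAG("example_bq_dts", start_date=datetime(2019, 1, 1), schedule_interval=None) as dag:
    create_transfer = BigQueryCreateDataTransferOperator(
        task_id="create_transfer", transfer_config=TRANSFER_CONFIG, project_id="example-project"
    )
    start_runs = BigQueryDataTransferServiceStartTransferRunsOperator(
        task_id="start_runs",
        transfer_config_id=CONFIG_ID,
        requested_run_time={"seconds": 1546300800},  # 2019-01-01T00:00:00Z; must not be in the future
    )
    delete_config = BigQueryDeleteDataTransferConfigOperator(
        task_id="delete_config", transfer_config_id=CONFIG_ID
    )
    create_transfer >> start_runs >> delete_config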
+ :type gcp_conn_id: str + """ + + template_fields = ( + "transfer_config_id", + "project_id", + "requested_time_range", + "requested_run_time", + "gcp_conn_id", + ) + + @apply_defaults + def __init__( + self, + transfer_config_id: str, + project_id: str = None, + requested_time_range: dict = None, + requested_run_time: dict = None, + retry: Retry = None, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = None, + gcp_conn_id="google_cloud_default", + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.project_id = project_id + self.transfer_config_id = transfer_config_id + self.requested_time_range = requested_time_range + self.requested_run_time = requested_run_time + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + + def execute(self, context): + hook = BiqQueryDataTransferServiceHook(gcp_conn_id=self.gcp_conn_id) + self.log.info('Submitting manual transfer for %s', self.transfer_config_id) + response = hook.start_manual_transfer_runs( + transfer_config_id=self.transfer_config_id, + requested_time_range=self.requested_time_range, + requested_run_time=self.requested_run_time, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + result = MessageToDict(response) + run_id = None + if 'runs' in result: + run_id = get_object_id(result['runs'][0]) + self.xcom_push(context, key="run_id", value=run_id) + self.log.info('Transfer run %s submitted successfully.', run_id) + return result diff --git a/airflow/gcp/operators/bigtable.py b/airflow/gcp/operators/bigtable.py index 6004f23f5a60e0..46ef211e1a84f3 100644 --- a/airflow/gcp/operators/bigtable.py +++ b/airflow/gcp/operators/bigtable.py @@ -19,10 +19,11 @@ """ This module contains Google Cloud Bigtable operators. """ - -from typing import Iterable +from enum import IntEnum +from typing import Iterable, List, Optional, Dict import google.api_core.exceptions +from google.cloud.bigtable.column_family import GarbageCollectionRule from airflow import AirflowException from airflow.models import BaseOperator @@ -86,29 +87,31 @@ class BigtableInstanceCreateOperator(BaseOperator, BigtableValidationMixin): :type timeout: int :param timeout: (optional) timeout (in seconds) for instance creation. If None is not specified, Operator will wait indefinitely. + :param gcp_conn_id: The connection ID to use to connect to Google Cloud Platform. 
+ :type gcp_conn_id: str """ REQUIRED_ATTRIBUTES = ('instance_id', 'main_cluster_id', - 'main_cluster_zone') + 'main_cluster_zone') # type: Iterable[str] template_fields = ['project_id', 'instance_id', 'main_cluster_id', - 'main_cluster_zone'] + 'main_cluster_zone'] # type: Iterable[str] @apply_defaults def __init__(self, # pylint: disable=too-many-arguments - instance_id, - main_cluster_id, - main_cluster_zone, - project_id=None, - replica_cluster_id=None, - replica_cluster_zone=None, - instance_display_name=None, - instance_type=None, - instance_labels=None, - cluster_nodes=None, - cluster_storage_type=None, - timeout=None, - gcp_conn_id='google_cloud_default', - *args, **kwargs): + instance_id: str, + main_cluster_id: str, + main_cluster_zone: str, + project_id: Optional[str] = None, + replica_cluster_id: Optional[str] = None, + replica_cluster_zone: Optional[str] = None, + instance_display_name: Optional[str] = None, + instance_type: Optional[IntEnum] = None, + instance_labels: Optional[int] = None, + cluster_nodes: Optional[int] = None, + cluster_storage_type: Optional[IntEnum] = None, + timeout: Optional[float] = None, + gcp_conn_id: str = 'google_cloud_default', + *args, **kwargs) -> None: self.project_id = project_id self.instance_id = instance_id self.main_cluster_id = main_cluster_id @@ -174,16 +177,18 @@ class BigtableInstanceDeleteOperator(BaseOperator, BigtableValidationMixin): :param project_id: Optional, the ID of the GCP project. If set to None or missing, the default project_id from the GCP connection is used. :type project_id: str + :param gcp_conn_id: The connection ID to use to connect to Google Cloud Platform. + :type gcp_conn_id: str """ - REQUIRED_ATTRIBUTES = ('instance_id',) - template_fields = ['project_id', 'instance_id'] + REQUIRED_ATTRIBUTES = ('instance_id',) # type: Iterable[str] + template_fields = ['project_id', 'instance_id'] # type: Iterable[str] @apply_defaults def __init__(self, - instance_id, - project_id=None, - gcp_conn_id='google_cloud_default', - *args, **kwargs): + instance_id: str, + project_id: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + *args, **kwargs) -> None: self.project_id = project_id self.instance_id = instance_id self._validate_inputs() @@ -232,19 +237,21 @@ class BigtableTableCreateOperator(BaseOperator, BigtableValidationMixin): :param column_families: (Optional) A map columns to create. The key is the column_id str and the value is a :class:`google.cloud.bigtable.column_family.GarbageCollectionRule` + :param gcp_conn_id: The connection ID to use to connect to Google Cloud Platform. 
+ :type gcp_conn_id: str """ - REQUIRED_ATTRIBUTES = ('instance_id', 'table_id') - template_fields = ['project_id', 'instance_id', 'table_id'] + REQUIRED_ATTRIBUTES = ('instance_id', 'table_id') # type: Iterable[str] + template_fields = ['project_id', 'instance_id', 'table_id'] # type: Iterable[str] @apply_defaults def __init__(self, - instance_id, - table_id, - project_id=None, - initial_split_keys=None, - column_families=None, - gcp_conn_id='google_cloud_default', - *args, **kwargs): + instance_id: str, + table_id: str, + project_id: Optional[str] = None, + initial_split_keys: Optional[List] = None, + column_families: Optional[Dict[str, GarbageCollectionRule]] = None, + gcp_conn_id: str = 'google_cloud_default', + *args, **kwargs) -> None: self.project_id = project_id self.instance_id = instance_id self.table_id = table_id @@ -320,18 +327,20 @@ class BigtableTableDeleteOperator(BaseOperator, BigtableValidationMixin): the default project_id from the GCP connection is used. :type app_profile_id: str :parm app_profile_id: Application profile. + :param gcp_conn_id: The connection ID to use to connect to Google Cloud Platform. + :type gcp_conn_id: str """ - REQUIRED_ATTRIBUTES = ('instance_id', 'table_id') - template_fields = ['project_id', 'instance_id', 'table_id'] + REQUIRED_ATTRIBUTES = ('instance_id', 'table_id') # type: Iterable[str] + template_fields = ['project_id', 'instance_id', 'table_id'] # type: Iterable[str] @apply_defaults def __init__(self, - instance_id, - table_id, - project_id=None, - app_profile_id=None, - gcp_conn_id='google_cloud_default', - *args, **kwargs): + instance_id: str, + table_id: str, + project_id: Optional[str] = None, + app_profile_id: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + *args, **kwargs) -> None: self.project_id = project_id self.instance_id = instance_id self.table_id = table_id @@ -383,18 +392,20 @@ class BigtableClusterUpdateOperator(BaseOperator, BigtableValidationMixin): :param nodes: The desired number of nodes for the Cloud Bigtable cluster. :type project_id: str :param project_id: Optional, the ID of the GCP project. + :param gcp_conn_id: The connection ID to use to connect to Google Cloud Platform. 
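# --- Editor's sketch (not part of the patch): the ``column_families`` shape hinted above ---
# The new type hint maps a column-family id to a GarbageCollectionRule; MaxVersionsGCRule is
# used here as one concrete rule. Instance and table ids are placeholders.
from datetime import datetime

from airflow import DAG
from airflow.gcp.operators.bigtable import BigtableTableCreateOperator
from google.cloud.bigtable.column_family import MaxVersionsGCRule

with DAG("example_bigtable_table", start_date=datetime(2019, 1, 1), schedule_interval=None) as dag:
    create_table = BigtableTableCreateOperator(
        task_id="create_table",
        instance_id="my-bt-instance",
        table_id="events",
        column_families={"cf1": MaxVersionsGCRule(2)},
    )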
+ :type gcp_conn_id: str """ - REQUIRED_ATTRIBUTES = ('instance_id', 'cluster_id', 'nodes') - template_fields = ['project_id', 'instance_id', 'cluster_id', 'nodes'] + REQUIRED_ATTRIBUTES = ('instance_id', 'cluster_id', 'nodes') # type: Iterable[str] + template_fields = ['project_id', 'instance_id', 'cluster_id', 'nodes'] # type: Iterable[str] @apply_defaults def __init__(self, - instance_id, - cluster_id, - nodes, - project_id=None, - gcp_conn_id='google_cloud_default', - *args, **kwargs): + instance_id: str, + cluster_id: str, + nodes: str, + project_id: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + *args, **kwargs) -> None: self.project_id = project_id self.instance_id = instance_id self.cluster_id = cluster_id diff --git a/airflow/gcp/operators/cloud_build.py b/airflow/gcp/operators/cloud_build.py index 8a11859db518c5..b63d3e55382060 100644 --- a/airflow/gcp/operators/cloud_build.py +++ b/airflow/gcp/operators/cloud_build.py @@ -19,6 +19,7 @@ """Operators that integrat with Google Cloud Build service.""" from copy import deepcopy import re +from typing import Dict, Iterable, Any, Optional from urllib.parse import urlparse, unquote from airflow import AirflowException @@ -42,7 +43,7 @@ class BuildProcessor: See: https://cloud.google.com/cloud-build/docs/api/reference/rest/Shared.Types/Build :type body: dict """ - def __init__(self, body): + def __init__(self, body: Dict) -> None: self.body = deepcopy(body) def _verify_source(self): @@ -127,7 +128,7 @@ def _convert_repo_url_to_dict(source): return source_dict @staticmethod - def _convert_storage_url_to_dict(storage_url): + def _convert_storage_url_to_dict(storage_url: str) -> Dict[str, Any]: """ Convert url to object in Google Cloud Storage to a format supported by the API @@ -165,18 +166,24 @@ class CloudBuildCreateBuildOperator(BaseOperator): :param body: The request body. See: https://cloud.google.com/cloud-build/docs/api/reference/rest/Shared.Types/Build :type body: dict + :param project_id: ID of the Google Cloud project if None then + default project_id is used. + :type project_id: str :param gcp_conn_id: The connection ID to use to connect to Google Cloud Platform. :type gcp_conn_id: str :param api_version: API version used (for example v1 or v1beta1). :type api_version: str """ - template_fields = ("body", "gcp_conn_id", "api_version") + template_fields = ("body", "gcp_conn_id", "api_version") # type: Iterable[str] @apply_defaults - def __init__( - self, body, project_id=None, gcp_conn_id="google_cloud_default", api_version="v1", *args, **kwargs - ): + def __init__(self, + body: dict, + project_id: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + api_version: str = "v1", + *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.body = body self.project_id = project_id diff --git a/airflow/gcp/operators/cloud_memorystore.py b/airflow/gcp/operators/cloud_memorystore.py new file mode 100644 index 00000000000000..41916656653ca7 --- /dev/null +++ b/airflow/gcp/operators/cloud_memorystore.py @@ -0,0 +1,934 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Operators for Google Cloud Memorystore service""" +from typing import Dict, Sequence, Tuple, Union + +from google.api_core.retry import Retry +from google.cloud.redis_v1.gapic.enums import FailoverInstanceRequest +from google.cloud.redis_v1.types import FieldMask, InputConfig, Instance, OutputConfig +from google.protobuf.json_format import MessageToDict + +from airflow.gcp.hooks.cloud_memorystore import CloudMemorystoreHook +from airflow.models import BaseOperator +from airflow.utils.decorators import apply_defaults + + +class CloudMemorystoreCreateInstanceOperator(BaseOperator): + """ + Creates a Redis instance based on the specified tier and memory size. + + By default, the instance is accessible from the project's `default network + `__. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudMemorystoreCreateInstanceOperator` + + :param location: The location of the Cloud Memorystore instance (for example europe-west1) + :type location: str + :param instance_id: Required. The logical name of the Redis instance in the customer project with the + following restrictions: + + - Must contain only lowercase letters, numbers, and hyphens. + - Must start with a letter. + - Must be between 1-40 characters. + - Must end with a number or a letter. + - Must be unique within the customer project / location + :type instance_id: str + :param instance: Required. A Redis [Instance] resource + + If a dict is provided, it must be of the same form as the protobuf message + :class:`~google.cloud.redis_v1.types.Instance` + :type instance: Union[Dict, google.cloud.redis_v1.types.Instance] + :param project_id: Project ID of the project that contains the instance. If set + to None or missing, the default project_id from the GCP connection is used. + :type project_id: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + :param gcp_conn_id: The connection ID to use connecting to Google Cloud Platform. 
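# --- Editor's sketch (not part of the patch): creating a Memorystore instance ---
# A hypothetical DAG using the operator documented above; ``tier`` 1 is assumed to be the
# BASIC tier of the v1 API enum, and all ids are placeholders.
from datetime import datetime

from airflow import DAG
from airflow.gcp.operators.cloud_memorystore import CloudMemorystoreCreateInstanceOperator

with DAG("example_memorystore_create", start_date=datetime(2019, 1, 1), schedule_interval=None) as dag:
    create_instance = CloudMemorystoreCreateInstanceOperator(
        task_id="create_instance",
        location="europe-west1",
        instance_id="test-memorystore",
        instance={"tier": 1, "memory_size_gb": 1},  # 1 == BASIC tier (assumed enum value)
        project_id="example-project",
    )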
+ :type gcp_conn_id: str + """ + + template_fields = ( + "location", + "instance_id", + "instance", + "project_id", + "retry", + "timeout", + "metadata", + "gcp_conn_id", + ) + + @apply_defaults + def __init__( + self, + location: str, + instance_id: str, + instance: Union[Dict, Instance], + project_id: str = None, + retry: Retry = None, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = None, + gcp_conn_id: str = "google_cloud_default", + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.location = location + self.instance_id = instance_id + self.instance = instance + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + + def execute(self, context: Dict): + hook = CloudMemorystoreHook(gcp_conn_id=self.gcp_conn_id) + result = hook.create_instance( + location=self.location, + instance_id=self.instance_id, + instance=self.instance, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + return MessageToDict(result) + + +class CloudMemorystoreDeleteInstanceOperator(BaseOperator): + """ + Deletes a specific Redis instance. Instance stops serving and data is deleted. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudMemorystoreDeleteInstanceOperator` + + :param location: The location of the Cloud Memorystore instance (for example europe-west1) + :type location: str + :param instance: The logical name of the Redis instance in the customer project. + :type instance: str + :param project_id: Project ID of the project that contains the instance. If set + to None or missing, the default project_id from the GCP connection is used. + :type project_id: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + :param gcp_conn_id: The connection ID to use connecting to Google Cloud Platform. + :type gcp_conn_id: str + """ + + template_fields = ("location", "instance", "project_id", "retry", "timeout", "metadata", "gcp_conn_id") + + @apply_defaults + def __init__( + self, + location: str, + instance: str, + project_id: str = None, + retry: Retry = None, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = None, + gcp_conn_id: str = "google_cloud_default", + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.location = location + self.instance = instance + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + + def execute(self, context: Dict): + hook = CloudMemorystoreHook(gcp_conn_id=self.gcp_conn_id) + hook.delete_instance( + location=self.location, + instance=self.instance, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + + +class CloudMemorystoreExportInstanceOperator(BaseOperator): + """ + Export Redis instance data into a Redis RDB format file in Cloud Storage. + + Redis will continue serving during this operation. + + .. 
seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudMemorystoreExportInstanceOperator` + + :param location: The location of the Cloud Memorystore instance (for example europe-west1) + :type location: str + :param instance: The logical name of the Redis instance in the customer project. + :type instance: str + :param output_config: Required. Specify data to be exported. + + If a dict is provided, it must be of the same form as the protobuf message + :class:`~google.cloud.redis_v1.types.OutputConfig` + :type output_config: Union[Dict, google.cloud.redis_v1.types.OutputConfig] + :param project_id: Project ID of the project that contains the instance. If set + to None or missing, the default project_id from the GCP connection is used. + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + :param gcp_conn_id: The connection ID to use connecting to Google Cloud Platform. + :type gcp_conn_id: str + """ + + template_fields = ( + "location", + "instance", + "output_config", + "project_id", + "retry", + "timeout", + "metadata", + "gcp_conn_id", + ) + + @apply_defaults + def __init__( + self, + location: str, + instance: str, + output_config: Union[Dict, OutputConfig], + project_id: str = None, + retry: Retry = None, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = None, + gcp_conn_id: str = "google_cloud_default", + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.location = location + self.instance = instance + self.output_config = output_config + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + + def execute(self, context: Dict): + hook = CloudMemorystoreHook(gcp_conn_id=self.gcp_conn_id) + + hook.export_instance( + location=self.location, + instance=self.instance, + output_config=self.output_config, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + + +class CloudMemorystoreFailoverInstanceOperator(BaseOperator): + """ + Initiates a failover of the master node to current replica node for a specific STANDARD tier Cloud + Memorystore for Redis instance. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudMemorystoreFailoverInstanceOperator` + + :param location: The location of the Cloud Memorystore instance (for example europe-west1) + :type location: str + :param instance: The logical name of the Redis instance in the customer project. + :type instance: str + :param data_protection_mode: Optional. Available data protection modes that the user can choose. If it's + unspecified, data protection mode will be LIMITED_DATA_LOSS by default. + :type data_protection_mode: google.cloud.redis_v1.gapic.enums.FailoverInstanceRequest.DataProtectionMode + :param project_id: Project ID of the project that contains the instance. If set + to None or missing, the default project_id from the GCP connection is used. + :param retry: A retry object used to retry requests. 
If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + :param gcp_conn_id: The connection ID to use connecting to Google Cloud Platform. + :type gcp_conn_id: str + """ + + template_fields = ( + "location", + "instance", + "data_protection_mode", + "project_id", + "retry", + "timeout", + "metadata", + "gcp_conn_id", + ) + + @apply_defaults + def __init__( + self, + location: str, + instance: str, + data_protection_mode: FailoverInstanceRequest.DataProtectionMode, + project_id: str = None, + retry: Retry = None, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = None, + gcp_conn_id: str = "google_cloud_default", + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.location = location + self.instance = instance + self.data_protection_mode = data_protection_mode + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + + def execute(self, context: Dict): + hook = CloudMemorystoreHook(gcp_conn_id=self.gcp_conn_id) + hook.failover_instance( + location=self.location, + instance=self.instance, + data_protection_mode=self.data_protection_mode, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + + +class CloudMemorystoreGetInstanceOperator(BaseOperator): + """ + Gets the details of a specific Redis instance. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudMemorystoreGetInstanceOperator` + + :param location: The location of the Cloud Memorystore instance (for example europe-west1) + :type location: str + :param instance: The logical name of the Redis instance in the customer project. + :type instance: str + :param project_id: Project ID of the project that contains the instance. If set + to None or missing, the default project_id from the GCP connection is used. + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + :param gcp_conn_id: The connection ID to use connecting to Google Cloud Platform. 
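# --- Editor's sketch (not part of the patch): passing an explicit failover protection mode ---
# Shows how the enum documented above is supplied; the FailoverInstanceRequest import matches
# the one at the top of this module, and the instance names are placeholders.
from datetime import datetime

from airflow import DAG
from airflow.gcp.operators.cloud_memorystore import CloudMemorystoreFailoverInstanceOperator
from google.cloud.redis_v1.gapic.enums import FailoverInstanceRequest

with DAG("example_memorystore_failover", start_date=datetime(2019, 1, 1), schedule_interval=None) as dag:
    failover = CloudMemorystoreFailoverInstanceOperator(
        task_id="failover",
        location="europe-west1",
        instance="test-memorystore",
        data_protection_mode=FailoverInstanceRequest.DataProtectionMode.LIMITED_DATA_LOSS,
        project_id="example-project",
    )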
+ :type gcp_conn_id: str + """ + + template_fields = ("location", "instance", "project_id", "retry", "timeout", "metadata", "gcp_conn_id") + + @apply_defaults + def __init__( + self, + location: str, + instance: str, + project_id: str = None, + retry: Retry = None, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = None, + gcp_conn_id: str = "google_cloud_default", + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.location = location + self.instance = instance + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + + def execute(self, context: Dict): + hook = CloudMemorystoreHook(gcp_conn_id=self.gcp_conn_id) + result = hook.get_instance( + location=self.location, + instance=self.instance, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + return MessageToDict(result) + + +class CloudMemorystoreImportOperator(BaseOperator): + """ + Import a Redis RDB snapshot file from Cloud Storage into a Redis instance. + + Redis may stop serving during this operation. Instance state will be IMPORTING for entire operation. When + complete, the instance will contain only data from the imported file. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudMemorystoreImportOperator` + + :param location: The location of the Cloud Memorystore instance (for example europe-west1) + :type location: str + :param instance: The logical name of the Redis instance in the customer project. + :type instance: str + :param input_config: Required. Specify data to be imported. + + If a dict is provided, it must be of the same form as the protobuf message + :class:`~google.cloud.redis_v1.types.InputConfig` + :type input_config: Union[Dict, google.cloud.redis_v1.types.InputConfig] + :param project_id: Project ID of the project that contains the instance. If set + to None or missing, the default project_id from the GCP connection is used. + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + :param gcp_conn_id: The connection ID to use connecting to Google Cloud Platform. 
+ :type gcp_conn_id: str + """ + + template_fields = ( + "location", + "instance", + "input_config", + "project_id", + "retry", + "timeout", + "metadata", + "gcp_conn_id", + ) + + @apply_defaults + def __init__( + self, + location: str, + instance: str, + input_config: Union[Dict, InputConfig], + project_id: str = None, + retry: Retry = None, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = None, + gcp_conn_id: str = "google_cloud_default", + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.location = location + self.instance = instance + self.input_config = input_config + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + + def execute(self, context: Dict): + hook = CloudMemorystoreHook(gcp_conn_id=self.gcp_conn_id) + hook.import_instance( + location=self.location, + instance=self.instance, + input_config=self.input_config, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + + +class CloudMemorystoreListInstancesOperator(BaseOperator): + """ + Lists all Redis instances owned by a project in either the specified location (region) or all locations. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudMemorystoreListInstancesOperator` + + :param location: The location of the Cloud Memorystore instance (for example europe-west1) + If it is specified as ``-`` (wildcard), then all regions available to the project are + queried, and the results are aggregated. + :type location: str + :param page_size: The maximum number of resources contained in the underlying API response. If page + streaming is performed per- resource, this parameter does not affect the return value. If page + streaming is performed per-page, this determines the maximum number of resources in a page. + :type page_size: int + :param project_id: Project ID of the project that contains the instance. If set + to None or missing, the default project_id from the GCP connection is used. + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + :param gcp_conn_id: The connection ID to use connecting to Google Cloud Platform. 
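# --- Editor's sketch (not part of the patch): the GCS-based export/import configs ---
# ``output_config`` and ``input_config`` take the dict form of OutputConfig/InputConfig shown in
# the docstrings above; the bucket, object and instance names are placeholders.
from datetime import datetime

from airflow import DAG
from airflow.gcp.operators.cloud_memorystore import (
    CloudMemorystoreExportInstanceOperator,
    CloudMemorystoreImportOperator,
)

EXPORT_GCS_URI = "gs://my-memorystore-bucket/my-instance.rdb"

with DAG("example_memorystore_export_import", start_date=datetime(2019, 1, 1), schedule_interval=None) as dag:
    export_rdb = CloudMemorystoreExportInstanceOperator(
        task_id="export_rdb",
        location="europe-west1",
        instance="source-memorystore",
        output_config={"gcs_destination": {"uri": EXPORT_GCS_URI}},
    )
    import_rdb = CloudMemorystoreImportOperator(
        task_id="import_rdb",
        location="europe-west1",
        instance="target-memorystore",
        input_config={"gcs_source": {"uri": EXPORT_GCS_URI}},
    )
    export_rdb >> import_rdb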
+ :type gcp_conn_id: str + """ + + template_fields = ("location", "page_size", "project_id", "retry", "timeout", "metadata", "gcp_conn_id") + + @apply_defaults + def __init__( + self, + location: str, + page_size: int, + project_id: str = None, + retry: Retry = None, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = None, + gcp_conn_id: str = "google_cloud_default", + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.location = location + self.page_size = page_size + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + + def execute(self, context: Dict): + hook = CloudMemorystoreHook(gcp_conn_id=self.gcp_conn_id) + result = hook.list_instances( + location=self.location, + page_size=self.page_size, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + instances = [MessageToDict(a) for a in result] + return instances + + +class CloudMemorystoreUpdateInstanceOperator(BaseOperator): + """ + Updates the metadata and configuration of a specific Redis instance. + + :param update_mask: Required. Mask of fields to update. At least one path must be supplied in this field. + The elements of the repeated paths field may only include these fields from ``Instance``: + + - ``displayName`` + - ``labels`` + - ``memorySizeGb`` + - ``redisConfig`` + + If a dict is provided, it must be of the same form as the protobuf message + :class:`~google.cloud.redis_v1.types.FieldMask` + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudMemorystoreUpdateInstanceOperator` + + :type update_mask: Union[Dict, google.cloud.redis_v1.types.FieldMask] + :param instance: Required. Update description. Only fields specified in update_mask are updated. + + If a dict is provided, it must be of the same form as the protobuf message + :class:`~google.cloud.redis_v1.types.Instance` + :type instance: Union[Dict, google.cloud.redis_v1.types.Instance] + :param location: The location of the Cloud Memorystore instance (for example europe-west1) + :type location: str + :param instance_id: The logical name of the Redis instance in the customer project. + :type instance_id: str + :param project_id: Project ID of the project that contains the instance. If set + to None or missing, the default project_id from the GCP connection is used. + :type project_id: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + :param gcp_conn_id: The connection ID to use connecting to Google Cloud Platform. 
+ :type gcp_conn_id: str + """ + + template_fields = ( + "update_mask", + "instance", + "location", + "instance_id", + "project_id", + "retry", + "timeout", + "metadata", + "gcp_conn_id", + ) + + @apply_defaults + def __init__( + self, + update_mask: Union[Dict, FieldMask], + instance: Union[Dict, Instance], + location: str = None, + instance_id: str = None, + project_id: str = None, + retry: Retry = None, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = None, + gcp_conn_id: str = "google_cloud_default", + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.update_mask = update_mask + self.instance = instance + self.location = location + self.instance_id = instance_id + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + + def execute(self, context: Dict): + hook = CloudMemorystoreHook(gcp_conn_id=self.gcp_conn_id) + hook.update_instance( + update_mask=self.update_mask, + instance=self.instance, + location=self.location, + instance_id=self.instance_id, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + + +class CloudMemorystoreScaleInstanceOperator(BaseOperator): + """ + Updates the metadata and configuration of a specific Redis instance. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CloudMemorystoreScaleInstanceOperator` + + :param memory_size_gb: Redis memory size in GiB. + :type memory_size_gb: int + :param location: The location of the Cloud Memorystore instance (for example europe-west1) + :type location: str + :param instance_id: The logical name of the Redis instance in the customer project. + :type instance_id: str + :param project_id: Project ID of the project that contains the instance. If set + to None or missing, the default project_id from the GCP connection is used. + :type project_id: str + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :type retry: google.api_core.retry.Retry + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :type timeout: float + :param metadata: Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]] + :param gcp_conn_id: The connection ID to use connecting to Google Cloud Platform. 
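# --- Editor's sketch (not part of the patch): a field-masked update ---
# Only the paths listed in ``update_mask`` are applied; the scale operator defined next is
# effectively this call with the ``memory_size_gb`` path pre-filled (see its execute() below).
# All ids and label values are placeholders.
from datetime import datetime

from airflow import DAG
from airflow.gcp.operators.cloud_memorystore import CloudMemorystoreUpdateInstanceOperator

with DAG("example_memorystore_update", start_date=datetime(2019, 1, 1), schedule_interval=None) as dag:
    update_instance = CloudMemorystoreUpdateInstanceOperator(
        task_id="update_instance",
        location="europe-west1",
        instance_id="test-memorystore",
        update_mask={"paths": ["labels", "memory_size_gb"]},
        instance={"labels": {"team": "analytics"}, "memory_size_gb": 2},
        project_id="example-project",
    )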
+    :type gcp_conn_id: str
+    """
+
+    template_fields = (
+        "memory_size_gb",
+        "location",
+        "instance_id",
+        "project_id",
+        "retry",
+        "timeout",
+        "metadata",
+        "gcp_conn_id",
+    )
+
+    @apply_defaults
+    def __init__(
+        self,
+        memory_size_gb: int,
+        location: str = None,
+        instance_id: str = None,
+        project_id: str = None,
+        retry: Retry = None,
+        timeout: float = None,
+        metadata: Sequence[Tuple[str, str]] = None,
+        gcp_conn_id: str = "google_cloud_default",
+        *args,
+        **kwargs
+    ):
+        super().__init__(*args, **kwargs)
+        self.memory_size_gb = memory_size_gb
+        self.location = location
+        self.instance_id = instance_id
+        self.project_id = project_id
+        self.retry = retry
+        self.timeout = timeout
+        self.metadata = metadata
+        self.gcp_conn_id = gcp_conn_id
+
+    def execute(self, context: Dict):
+        hook = CloudMemorystoreHook(gcp_conn_id=self.gcp_conn_id)
+
+        hook.update_instance(
+            update_mask={"paths": ["memory_size_gb"]},
+            instance={"memory_size_gb": self.memory_size_gb},
+            location=self.location,
+            instance_id=self.instance_id,
+            project_id=self.project_id,
+            retry=self.retry,
+            timeout=self.timeout,
+            metadata=self.metadata,
+        )
+
+
+class CloudMemorystoreCreateInstanceAndImportOperator(BaseOperator):
+    """
+    Creates a Redis instance based on the specified tier and memory size and imports a Redis RDB snapshot
+    file from Cloud Storage into that instance.
+
+    By default, the instance is accessible from the project's `default network
+    `__.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:CloudMemorystoreCreateInstanceAndImportOperator`
+
+    :param location: The location of the Cloud Memorystore instance (for example europe-west1)
+    :type location: str
+    :param instance_id: Required. The logical name of the Redis instance in the customer project with the
+        following restrictions:
+
+        - Must contain only lowercase letters, numbers, and hyphens.
+        - Must start with a letter.
+        - Must be between 1-40 characters.
+        - Must end with a number or a letter.
+        - Must be unique within the customer project / location
+    :type instance_id: str
+    :param instance: Required. A Redis [Instance] resource
+
+        If a dict is provided, it must be of the same form as the protobuf message
+        :class:`~google.cloud.redis_v1.types.Instance`
+    :type instance: Union[Dict, google.cloud.redis_v1.types.Instance]
+    :param input_config: Required. Specify data to be imported.
+
+        If a dict is provided, it must be of the same form as the protobuf message
+        :class:`~google.cloud.redis_v1.types.InputConfig`
+    :type input_config: Union[Dict, google.cloud.redis_v1.types.InputConfig]
+    :param project_id: Project ID of the project that contains the instance. If set
+        to None or missing, the default project_id from the GCP connection is used.
+    :type project_id: str
+    :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
+        retried.
+    :type retry: google.api_core.retry.Retry
+    :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+        ``retry`` is specified, the timeout applies to each individual attempt.
+    :type timeout: float
+    :param metadata: Additional metadata that is provided to the method.
+    :type metadata: Sequence[Tuple[str, str]]
+    :param gcp_conn_id: The connection ID to use connecting to Google Cloud Platform.
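A minimal usage sketch of the create-and-import operator documented above, assuming the dict forms of ``Instance`` and ``InputConfig`` (tier ``1`` is the BASIC tier); the project, bucket and instance names are illustrative assumptions:

.. code-block:: python

    from datetime import datetime

    from airflow import models
    from airflow.gcp.operators.cloud_memorystore import CloudMemorystoreCreateInstanceAndImportOperator

    with models.DAG("example_memorystore_import", start_date=datetime(2019, 1, 1), schedule_interval=None) as dag:
        create_and_import = CloudMemorystoreCreateInstanceAndImportOperator(
            task_id="create-and-import",
            location="europe-west1",                    # assumed region
            instance_id="example-memorystore",          # assumed instance name
            instance={"tier": 1, "memory_size_gb": 1},  # 1 = BASIC tier
            input_config={"gcs_source": {"uri": "gs://example-bucket/example.rdb"}},  # assumed bucket/object
            project_id="example-project",               # assumed project
        )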
+    :type gcp_conn_id: str
+    """
+
+    template_fields = (
+        "location",
+        "instance_id",
+        "instance",
+        "input_config",
+        "project_id",
+        "retry",
+        "timeout",
+        "metadata",
+        "gcp_conn_id",
+    )
+
+    @apply_defaults
+    def __init__(
+        self,
+        location: str,
+        instance_id: str,
+        instance: Union[Dict, Instance],
+        input_config: Union[Dict, InputConfig],
+        project_id: str = None,
+        retry: Retry = None,
+        timeout: float = None,
+        metadata: Sequence[Tuple[str, str]] = None,
+        gcp_conn_id: str = "google_cloud_default",
+        *args,
+        **kwargs
+    ):
+        super().__init__(*args, **kwargs)
+        self.location = location
+        self.instance_id = instance_id
+        self.instance = instance
+        self.input_config = input_config
+        self.project_id = project_id
+        self.retry = retry
+        self.timeout = timeout
+        self.metadata = metadata
+        self.gcp_conn_id = gcp_conn_id
+
+    def execute(self, context: Dict):
+        hook = CloudMemorystoreHook(gcp_conn_id=self.gcp_conn_id)
+
+        hook.create_instance(
+            location=self.location,
+            instance_id=self.instance_id,
+            instance=self.instance,
+            project_id=self.project_id,
+            retry=self.retry,
+            timeout=self.timeout,
+            metadata=self.metadata,
+        )
+
+        hook.import_instance(
+            location=self.location,
+            instance=self.instance_id,
+            input_config=self.input_config,
+            project_id=self.project_id,
+            retry=self.retry,
+            timeout=self.timeout,
+            metadata=self.metadata,
+        )
+
+
+class CloudMemorystoreExportAndDeleteInstanceOperator(BaseOperator):
+    """
+    Exports Redis instance data into a Redis RDB format file in Cloud Storage and, in the next step,
+    deletes the instance.
+
+    Redis will continue serving during this operation.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:CloudMemorystoreExportAndDeleteInstanceOperator`
+
+    :param location: The location of the Cloud Memorystore instance (for example europe-west1)
+    :type location: str
+    :param instance: The logical name of the Redis instance in the customer project.
+    :type instance: str
+    :param output_config: Required. Specify data to be exported.
+
+        If a dict is provided, it must be of the same form as the protobuf message
+        :class:`~google.cloud.redis_v1.types.OutputConfig`
+    :type output_config: Union[Dict, google.cloud.redis_v1.types.OutputConfig]
+    :param project_id: Project ID of the project that contains the instance. If set
+        to None or missing, the default project_id from the GCP connection is used.
+    :type project_id: str
+    :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
+        retried.
+    :type retry: google.api_core.retry.Retry
+    :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+        ``retry`` is specified, the timeout applies to each individual attempt.
+    :type timeout: float
+    :param metadata: Additional metadata that is provided to the method.
+    :type metadata: Sequence[Tuple[str, str]]
+    :param gcp_conn_id: The connection ID to use connecting to Google Cloud Platform.
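A minimal usage sketch of the export-and-delete operator documented above; the ``output_config`` dict mirrors ``OutputConfig`` and all names are illustrative assumptions:

.. code-block:: python

    from datetime import datetime

    from airflow import models
    from airflow.gcp.operators.cloud_memorystore import CloudMemorystoreExportAndDeleteInstanceOperator

    with models.DAG("example_memorystore_export", start_date=datetime(2019, 1, 1), schedule_interval=None) as dag:
        export_and_delete = CloudMemorystoreExportAndDeleteInstanceOperator(
            task_id="export-and-delete",
            location="europe-west1",            # assumed region
            instance="example-memorystore",     # assumed instance name
            output_config={"gcs_destination": {"uri": "gs://example-bucket/example.rdb"}},  # assumed destination
            project_id="example-project",       # assumed project
        )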
+ :type gcp_conn_id: str + """ + + template_fields = ( + "location", + "instance", + "output_config", + "project_id", + "retry", + "timeout", + "metadata", + "gcp_conn_id", + ) + + @apply_defaults + def __init__( + self, + location: str, + instance: str, + output_config: Union[Dict, OutputConfig], + project_id: str = None, + retry: Retry = None, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = None, + gcp_conn_id: str = "google_cloud_default", + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.location = location + self.instance = instance + self.output_config = output_config + self.project_id = project_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + + def execute(self, context: Dict): + hook = CloudMemorystoreHook(gcp_conn_id=self.gcp_conn_id) + + hook.export_instance( + location=self.location, + instance=self.instance, + output_config=self.output_config, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + + hook.delete_instance( + location=self.location, + instance=self.instance, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) diff --git a/airflow/gcp/operators/cloud_sql.py b/airflow/gcp/operators/cloud_sql.py index 455a8218c0e3eb..017dbc88f4732f 100644 --- a/airflow/gcp/operators/cloud_sql.py +++ b/airflow/gcp/operators/cloud_sql.py @@ -19,6 +19,7 @@ """ This module contains Google Cloud SQL operators. """ +from typing import Union, List, Optional, Iterable, Dict from googleapiclient.errors import HttpError @@ -153,18 +154,18 @@ class CloudSqlBaseOperator(BaseOperator): """ @apply_defaults def __init__(self, - instance, - project_id=None, - gcp_conn_id='google_cloud_default', - api_version='v1beta4', - *args, **kwargs): + instance: str, + project_id: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + api_version: str = 'v1beta4', + *args, **kwargs) -> None: self.project_id = project_id self.instance = instance self.gcp_conn_id = gcp_conn_id self.api_version = api_version self._validate_inputs() self._hook = CloudSqlHook(gcp_conn_id=self.gcp_conn_id, - api_version=self.api_version) + api_version=self.api_version) # type: CloudSqlHook super().__init__(*args, **kwargs) def _validate_inputs(self): @@ -235,13 +236,13 @@ class CloudSqlInstanceCreateOperator(CloudSqlBaseOperator): @apply_defaults def __init__(self, - body, - instance, - project_id=None, - gcp_conn_id='google_cloud_default', - api_version='v1beta4', - validate_body=True, - *args, **kwargs): + body: dict, + instance: str, + project_id: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + api_version: str = 'v1beta4', + validate_body: bool = True, + *args, **kwargs) -> None: self.body = body self.validate_body = validate_body super().__init__( @@ -309,12 +310,12 @@ class CloudSqlInstancePatchOperator(CloudSqlBaseOperator): @apply_defaults def __init__(self, - body, - instance, - project_id=None, - gcp_conn_id='google_cloud_default', - api_version='v1beta4', - *args, **kwargs): + body: dict, + instance: str, + project_id: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + api_version: str = 'v1beta4', + *args, **kwargs) -> None: self.body = body super().__init__( project_id=project_id, instance=instance, gcp_conn_id=gcp_conn_id, @@ -361,11 +362,11 @@ class CloudSqlInstanceDeleteOperator(CloudSqlBaseOperator): @apply_defaults def __init__(self, - instance, - project_id=None, - 
gcp_conn_id='google_cloud_default', - api_version='v1beta4', - *args, **kwargs): + instance: str, + project_id: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + api_version: str = 'v1beta4', + *args, **kwargs) -> None: super().__init__( project_id=project_id, instance=instance, gcp_conn_id=gcp_conn_id, api_version=api_version, *args, **kwargs) @@ -410,13 +411,13 @@ class CloudSqlInstanceDatabaseCreateOperator(CloudSqlBaseOperator): @apply_defaults def __init__(self, - instance, - body, - project_id=None, - gcp_conn_id='google_cloud_default', - api_version='v1beta4', - validate_body=True, - *args, **kwargs): + instance: str, + body: dict, + project_id: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + api_version: str = 'v1beta4', + validate_body: bool = True, + *args, **kwargs) -> None: self.body = body self.validate_body = validate_body super().__init__( @@ -483,14 +484,14 @@ class CloudSqlInstanceDatabasePatchOperator(CloudSqlBaseOperator): @apply_defaults def __init__(self, - instance, - database, - body, - project_id=None, - gcp_conn_id='google_cloud_default', - api_version='v1beta4', - validate_body=True, - *args, **kwargs): + instance: str, + database: str, + body: dict, + project_id: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + api_version: str = 'v1beta4', + validate_body: bool = True, + *args, **kwargs) -> None: self.database = database self.body = body self.validate_body = validate_body @@ -552,12 +553,12 @@ class CloudSqlInstanceDatabaseDeleteOperator(CloudSqlBaseOperator): @apply_defaults def __init__(self, - instance, - database, - project_id=None, - gcp_conn_id='google_cloud_default', - api_version='v1beta4', - *args, **kwargs): + instance: str, + database: str, + project_id: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + api_version: str = 'v1beta4', + *args, **kwargs) -> None: self.database = database super().__init__( project_id=project_id, instance=instance, gcp_conn_id=gcp_conn_id, @@ -614,13 +615,13 @@ class CloudSqlInstanceExportOperator(CloudSqlBaseOperator): @apply_defaults def __init__(self, - instance, - body, - project_id=None, - gcp_conn_id='google_cloud_default', - api_version='v1beta4', - validate_body=True, - *args, **kwargs): + instance: str, + body: dict, + project_id: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + api_version: str = 'v1beta4', + validate_body: bool = True, + *args, **kwargs) -> None: self.body = body self.validate_body = validate_body super().__init__( @@ -690,13 +691,13 @@ class CloudSqlInstanceImportOperator(CloudSqlBaseOperator): @apply_defaults def __init__(self, - instance, - body, - project_id=None, - gcp_conn_id='google_cloud_default', - api_version='v1beta4', - validate_body=True, - *args, **kwargs): + instance: str, + body: dict, + project_id: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + api_version: str = 'v1beta4', + validate_body: bool = True, + *args, **kwargs) -> None: self.body = body self.validate_body = validate_body super().__init__( @@ -737,7 +738,7 @@ class CloudSqlQueryOperator(BaseOperator): you can use CREATE TABLE IF NOT EXISTS to create a table. :type sql: str or list[str] :param parameters: (optional) the parameters to render the SQL query with. - :type parameters: mapping or iterable + :type parameters: dict or iterable :param autocommit: if True, each command is automatically committed. 
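A minimal usage sketch of ``CloudSqlQueryOperator`` with a list of statements and ``autocommit`` enabled; the Cloud SQL connection id and the SQL itself are illustrative assumptions:

.. code-block:: python

    from datetime import datetime

    from airflow import models
    from airflow.gcp.operators.cloud_sql import CloudSqlQueryOperator

    with models.DAG("example_cloudsql_query", start_date=datetime(2019, 1, 1), schedule_interval=None) as dag:
        run_queries = CloudSqlQueryOperator(
            task_id="run-example-queries",
            gcp_cloudsql_conn_id="example_cloudsql_conn",  # assumed connection, configured separately
            sql=[
                "CREATE TABLE IF NOT EXISTS demo (id INT)",
                "INSERT INTO demo VALUES (1)",
            ],
            autocommit=True,
        )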
(default value: False) :type autocommit: bool @@ -757,12 +758,12 @@ class CloudSqlQueryOperator(BaseOperator): @apply_defaults def __init__(self, - sql, - autocommit=False, - parameters=None, - gcp_conn_id='google_cloud_default', - gcp_cloudsql_conn_id='google_cloud_sql_default', - *args, **kwargs): + sql: Union[List[str], str], + autocommit: bool = False, + parameters: Optional[Union[Dict, Iterable]] = None, + gcp_conn_id: str = 'google_cloud_default', + gcp_cloudsql_conn_id: str = 'google_cloud_sql_default', + *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.sql = sql self.gcp_conn_id = gcp_conn_id diff --git a/airflow/gcp/operators/cloud_storage_transfer_service.py b/airflow/gcp/operators/cloud_storage_transfer_service.py index c0ccbe79974796..67268821e862a4 100644 --- a/airflow/gcp/operators/cloud_storage_transfer_service.py +++ b/airflow/gcp/operators/cloud_storage_transfer_service.py @@ -23,6 +23,7 @@ from copy import deepcopy from datetime import date, time +from typing import Optional, Dict from airflow import AirflowException from airflow.gcp.hooks.cloud_storage_transfer_service import ( @@ -67,7 +68,7 @@ class TransferJobPreprocessor: Helper class for preprocess of transfer job body. """ - def __init__(self, body, aws_conn_id='aws_default', default_schedule=False): + def __init__(self, body: dict, aws_conn_id: str = 'aws_default', default_schedule: bool = False) -> None: self.body = body self.aws_conn_id = aws_conn_id self.default_schedule = default_schedule @@ -142,7 +143,7 @@ class TransferJobValidator: """ Helper class for validating transfer job body. """ - def __init__(self, body): + def __init__(self, body: dict) -> None: if not body: raise AirflowException("The required parameter 'body' is empty or None") @@ -220,13 +221,13 @@ class GcpTransferServiceJobCreateOperator(BaseOperator): @apply_defaults def __init__( self, - body, - aws_conn_id='aws_default', - gcp_conn_id='google_cloud_default', - api_version='v1', + body: dict, + aws_conn_id: str = 'aws_default', + gcp_conn_id: str = 'google_cloud_default', + api_version: str = 'v1', *args, **kwargs - ): + ) -> None: super().__init__(*args, **kwargs) self.body = deepcopy(body) self.aws_conn_id = aws_conn_id @@ -279,14 +280,14 @@ class GcpTransferServiceJobUpdateOperator(BaseOperator): @apply_defaults def __init__( self, - job_name, - body, - aws_conn_id='aws_default', - gcp_conn_id='google_cloud_default', - api_version='v1', + job_name: str, + body: dict, + aws_conn_id: str = 'aws_default', + gcp_conn_id: str = 'google_cloud_default', + api_version: str = 'v1', *args, **kwargs - ): + ) -> None: super().__init__(*args, **kwargs) self.job_name = job_name self.body = body @@ -335,8 +336,14 @@ class GcpTransferServiceJobDeleteOperator(BaseOperator): @apply_defaults def __init__( - self, job_name, gcp_conn_id='google_cloud_default', api_version='v1', project_id=None, *args, **kwargs - ): + self, + job_name: str, + gcp_conn_id: str = "google_cloud_default", + api_version: str = "v1", + project_id: Optional[str] = None, + *args, + **kwargs + ) -> None: super().__init__(*args, **kwargs) self.job_name = job_name self.project_id = project_id @@ -376,7 +383,14 @@ class GcpTransferServiceOperationGetOperator(BaseOperator): # [END gcp_transfer_operation_get_template_fields] @apply_defaults - def __init__(self, operation_name, gcp_conn_id='google_cloud_default', api_version='v1', *args, **kwargs): + def __init__( + self, + operation_name: str, + gcp_conn_id: str = "google_cloud_default", + api_version: str = "v1", + *args, 
+ **kwargs + ) -> None: super().__init__(*args, **kwargs) self.operation_name = operation_name self.gcp_conn_id = gcp_conn_id @@ -402,9 +416,9 @@ class GcpTransferServiceOperationsListOperator(BaseOperator): For more information on how to use this operator, take a look at the guide: :ref:`howto/operator:GcpTransferServiceOperationsListOperator` - :param filter: (Required) A request filter, as described in + :param request_filter: (Required) A request filter, as described in https://cloud.google.com/storage-transfer/docs/reference/rest/v1/transferJobs/list#body.QUERY_PARAMETERS.filter - :type filter: dict + :type request_filter: dict :param gcp_conn_id: The connection ID used to connect to Google Cloud Platform. :type gcp_conn_id: str @@ -416,11 +430,11 @@ class GcpTransferServiceOperationsListOperator(BaseOperator): # [END gcp_transfer_operations_list_template_fields] def __init__(self, - request_filter=None, - gcp_conn_id='google_cloud_default', - api_version='v1', + request_filter: Optional[Dict] = None, + gcp_conn_id: str = 'google_cloud_default', + api_version: str = 'v1', *args, - **kwargs): + **kwargs) -> None: # To preserve backward compatibility # TODO: remove one day if request_filter is None: @@ -467,7 +481,14 @@ class GcpTransferServiceOperationPauseOperator(BaseOperator): # [END gcp_transfer_operation_pause_template_fields] @apply_defaults - def __init__(self, operation_name, gcp_conn_id='google_cloud_default', api_version='v1', *args, **kwargs): + def __init__( + self, + operation_name: str, + gcp_conn_id: str = "google_cloud_default", + api_version: str = "v1", + *args, + **kwargs + ) -> None: super().__init__(*args, **kwargs) self.operation_name = operation_name self.gcp_conn_id = gcp_conn_id @@ -503,7 +524,14 @@ class GcpTransferServiceOperationResumeOperator(BaseOperator): # [END gcp_transfer_operation_resume_template_fields] @apply_defaults - def __init__(self, operation_name, gcp_conn_id='google_cloud_default', api_version='v1', *args, **kwargs): + def __init__( + self, + operation_name: str, + gcp_conn_id: str = "google_cloud_default", + api_version: str = "v1", + *args, + **kwargs + ) -> None: self.operation_name = operation_name self.gcp_conn_id = gcp_conn_id self.api_version = api_version @@ -540,7 +568,14 @@ class GcpTransferServiceOperationCancelOperator(BaseOperator): # [END gcp_transfer_operation_cancel_template_fields] @apply_defaults - def __init__(self, operation_name, api_version='v1', gcp_conn_id='google_cloud_default', *args, **kwargs): + def __init__( + self, + operation_name: str, + gcp_conn_id: str = "google_cloud_default", + api_version: str = "v1", + *args, + **kwargs + ) -> None: super().__init__(*args, **kwargs) self.operation_name = operation_name self.api_version = api_version @@ -624,21 +659,21 @@ class S3ToGoogleCloudStorageTransferOperator(BaseOperator): @apply_defaults def __init__( # pylint: disable=too-many-arguments self, - s3_bucket, - gcs_bucket, - project_id=None, - aws_conn_id='aws_default', - gcp_conn_id='google_cloud_default', - delegate_to=None, - description=None, - schedule=None, - object_conditions=None, - transfer_options=None, - wait=True, - timeout=None, + s3_bucket: str, + gcs_bucket: str, + project_id: Optional[str] = None, + aws_conn_id: str = 'aws_default', + gcp_conn_id: str = 'google_cloud_default', + delegate_to: Optional[str] = None, + description: Optional[str] = None, + schedule: Optional[Dict] = None, + object_conditions: Optional[Dict] = None, + transfer_options: Optional[Dict] = None, + wait: bool = True, + timeout: 
Optional[float] = None, *args, **kwargs - ): + ) -> None: super().__init__(*args, **kwargs) self.s3_bucket = s3_bucket @@ -763,20 +798,20 @@ class GoogleCloudStorageToGoogleCloudStorageTransferOperator(BaseOperator): @apply_defaults def __init__( # pylint: disable=too-many-arguments self, - source_bucket, - destination_bucket, - project_id=None, - gcp_conn_id='google_cloud_default', - delegate_to=None, - description=None, - schedule=None, - object_conditions=None, - transfer_options=None, - wait=True, - timeout=None, + source_bucket: str, + destination_bucket: str, + project_id: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + delegate_to: Optional[str] = None, + description: Optional[str] = None, + schedule: Optional[Dict] = None, + object_conditions: Optional[Dict] = None, + transfer_options: Optional[Dict] = None, + wait: bool = True, + timeout: Optional[float] = None, *args, **kwargs - ): + ) -> None: super().__init__(*args, **kwargs) self.source_bucket = source_bucket diff --git a/airflow/gcp/operators/compute.py b/airflow/gcp/operators/compute.py index 60cd9abab910ac..08a9360c370f99 100644 --- a/airflow/gcp/operators/compute.py +++ b/airflow/gcp/operators/compute.py @@ -21,7 +21,7 @@ """ from copy import deepcopy -from typing import Dict +from typing import Dict, Optional, List, Any from json_merge_patch import merge from googleapiclient.errors import HttpError @@ -41,12 +41,12 @@ class GceBaseOperator(BaseOperator): @apply_defaults def __init__(self, - zone, - resource_id, - project_id=None, - gcp_conn_id='google_cloud_default', - api_version='v1', - *args, **kwargs): + zone: str, + resource_id: str, + project_id: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + api_version: str = 'v1', + *args, **kwargs) -> None: self.project_id = project_id self.zone = zone self.resource_id = resource_id @@ -98,12 +98,12 @@ class GceInstanceStartOperator(GceBaseOperator): @apply_defaults def __init__(self, - zone, - resource_id, - project_id=None, - gcp_conn_id='google_cloud_default', - api_version='v1', - *args, **kwargs): + zone: str, + resource_id: str, + project_id: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + api_version: str = 'v1', + *args, **kwargs) -> None: super().__init__( project_id=project_id, zone=zone, resource_id=resource_id, gcp_conn_id=gcp_conn_id, api_version=api_version, *args, **kwargs) @@ -146,12 +146,12 @@ class GceInstanceStopOperator(GceBaseOperator): @apply_defaults def __init__(self, - zone, - resource_id, - project_id=None, - gcp_conn_id='google_cloud_default', - api_version='v1', - *args, **kwargs): + zone: str, + resource_id: str, + project_id: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + api_version: str = 'v1', + *args, **kwargs) -> None: super().__init__( project_id=project_id, zone=zone, resource_id=resource_id, gcp_conn_id=gcp_conn_id, api_version=api_version, *args, **kwargs) @@ -204,16 +204,16 @@ class GceSetMachineTypeOperator(GceBaseOperator): @apply_defaults def __init__(self, - zone, - resource_id, - body, - project_id=None, - gcp_conn_id='google_cloud_default', - api_version='v1', - validate_body=True, - *args, **kwargs): + zone: str, + resource_id: str, + body: dict, + project_id: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + api_version: str = 'v1', + validate_body: bool = True, + *args, **kwargs) -> None: self.body = body - self._field_validator = None + self._field_validator = None # type: Optional[GcpBodyFieldValidator] if validate_body: 
self._field_validator = GcpBodyFieldValidator( SET_MACHINE_TYPE_VALIDATION_SPECIFICATION, api_version=api_version) @@ -262,7 +262,7 @@ def execute(self, context): dict(name="guestAccelerators", optional=True), # not validating deeper dict(name="minCpuPlatform", optional=True), ]), -] +] # type: List[Dict[str, Any]] GCE_INSTANCE_TEMPLATE_FIELDS_TO_SANITIZE = [ "kind", @@ -326,17 +326,17 @@ class GceInstanceTemplateCopyOperator(GceBaseOperator): @apply_defaults def __init__(self, - resource_id, - body_patch, - project_id=None, + resource_id: str, + body_patch: dict, + project_id: Optional[str] = None, request_id=None, - gcp_conn_id='google_cloud_default', - api_version='v1', - validate_body=True, - *args, **kwargs): + gcp_conn_id: str = 'google_cloud_default', + api_version: str = 'v1', + validate_body: bool = True, + *args, **kwargs) -> None: self.body_patch = body_patch self.request_id = request_id - self._field_validator = None + self._field_validator = None # Optional[GcpBodyFieldValidator] if 'name' not in self.body_patch: raise AirflowException("The body '{}' should contain at least " "name for the new operator in the 'name' field". @@ -436,16 +436,16 @@ class GceInstanceGroupManagerUpdateTemplateOperator(GceBaseOperator): @apply_defaults def __init__(self, - resource_id, - zone, - source_template, - destination_template, - project_id=None, + resource_id: str, + zone: str, + source_template: str, + destination_template: str, + project_id: Optional[str] = None, update_policy=None, request_id=None, - gcp_conn_id='google_cloud_default', + gcp_conn_id: str = 'google_cloud_default', api_version='beta', - *args, **kwargs): + *args, **kwargs) -> None: self.zone = zone self.source_template = source_template self.destination_template = destination_template diff --git a/airflow/gcp/operators/dataflow.py b/airflow/gcp/operators/dataflow.py index 675a7e21efa7f7..5f3d75f1a1eae3 100644 --- a/airflow/gcp/operators/dataflow.py +++ b/airflow/gcp/operators/dataflow.py @@ -24,9 +24,11 @@ import re import uuid import copy +import tempfile from enum import Enum +from typing import List, Optional -from airflow.contrib.hooks.gcs_hook import GoogleCloudStorageHook +from airflow.gcp.hooks.gcs import GoogleCloudStorageHook from airflow.gcp.hooks.dataflow import DataFlowHook from airflow.models import BaseOperator from airflow.version import version @@ -167,18 +169,18 @@ class DataFlowJavaOperator(BaseOperator): @apply_defaults def __init__( self, - jar, - job_name='{{task.task_id}}', - dataflow_default_options=None, - options=None, - gcp_conn_id='google_cloud_default', - delegate_to=None, - poll_sleep=10, - job_class=None, - check_if_running=CheckJobRunning.WaitForRun, - multiple_jobs=None, + jar: str, + job_name: str = '{{task.task_id}}', + dataflow_default_options: Optional[dict] = None, + options: Optional[dict] = None, + gcp_conn_id: str = 'google_cloud_default', + delegate_to: Optional[str] = None, + poll_sleep: int = 10, + job_class: Optional[str] = None, + check_if_running: CheckJobRunning = CheckJobRunning.WaitForRun, + multiple_jobs: Optional[bool] = None, *args, - **kwargs): + **kwargs) -> None: super().__init__(*args, **kwargs) dataflow_default_options = dataflow_default_options or {} @@ -295,15 +297,15 @@ class DataflowTemplateOperator(BaseOperator): @apply_defaults def __init__( self, - template, - job_name='{{task.task_id}}', - dataflow_default_options=None, - parameters=None, - gcp_conn_id='google_cloud_default', - delegate_to=None, - poll_sleep=10, + template: str, + job_name: str = 
'{{task.task_id}}', + dataflow_default_options: Optional[dict] = None, + parameters: Optional[dict] = None, + gcp_conn_id: str = 'google_cloud_default', + delegate_to: Optional[str] = None, + poll_sleep: int = 10, *args, - **kwargs): + **kwargs) -> None: super().__init__(*args, **kwargs) dataflow_default_options = dataflow_default_options or {} @@ -346,11 +348,16 @@ class DataFlowPythonOperator(BaseOperator): with key ``'jobName'`` or ``'job_name'`` in ``options`` will be overwritten. :type job_name: str :param py_options: Additional python options, e.g., ["-m", "-v"]. - :type pyt_options: list[str] + :type py_options: list[str] :param dataflow_default_options: Map of default job options. :type dataflow_default_options: dict :param options: Map of job specific options. :type options: dict + :param py_interpreter: Python version of the beam pipeline. + If None, this defaults to the python2. + To track python versions supported by beam and related + issues check: https://issues.apache.org/jira/browse/BEAM-1251 + :type py_interpreter: str :param gcp_conn_id: The connection ID to use connecting to Google Cloud Platform. :type gcp_conn_id: str @@ -368,16 +375,17 @@ class DataFlowPythonOperator(BaseOperator): @apply_defaults def __init__( self, - py_file, - job_name='{{task.task_id}}', - py_options=None, - dataflow_default_options=None, - options=None, - gcp_conn_id='google_cloud_default', - delegate_to=None, - poll_sleep=10, + py_file: str, + job_name: str = '{{task.task_id}}', + py_options: Optional[List[str]] = None, + dataflow_default_options: Optional[dict] = None, + options: Optional[dict] = None, + py_interpreter: str = "python2", + gcp_conn_id: str = 'google_cloud_default', + delegate_to: Optional[str] = None, + poll_sleep: int = 10, *args, - **kwargs): + **kwargs) -> None: super().__init__(*args, **kwargs) @@ -388,6 +396,7 @@ def __init__( self.options = options or {} self.options.setdefault('labels', {}).update( {'airflow-version': 'v' + version.replace('.', '-').replace('+', '-')}) + self.py_interpreter = py_interpreter self.gcp_conn_id = gcp_conn_id self.delegate_to = delegate_to self.poll_sleep = poll_sleep @@ -409,7 +418,7 @@ def execute(self, context): for key in dataflow_options} hook.start_python_dataflow( self.job_name, formatted_options, - self.py_file, self.py_options) + self.py_file, self.py_options, py_interpreter=self.py_interpreter) class GoogleCloudBucketHelper: @@ -417,11 +426,11 @@ class GoogleCloudBucketHelper: GCS_PREFIX_LENGTH = 5 def __init__(self, - gcp_conn_id='google_cloud_default', - delegate_to=None): + gcp_conn_id: str = 'google_cloud_default', + delegate_to: Optional[str] = None) -> None: self._gcs_hook = GoogleCloudStorageHook(gcp_conn_id, delegate_to) - def google_cloud_to_local(self, file_name): + def google_cloud_to_local(self, file_name: str) -> str: """ Checks whether the file specified by file_name is stored in Google Cloud Storage (GCS), if so, downloads the file and saves it locally. 
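A minimal usage sketch of ``DataFlowPythonOperator`` exercising the new ``py_interpreter`` argument described above; the pipeline path, project and pipeline options are illustrative assumptions:

.. code-block:: python

    from datetime import datetime

    from airflow import models
    from airflow.gcp.operators.dataflow import DataFlowPythonOperator

    with models.DAG("example_dataflow_python3", start_date=datetime(2019, 1, 1), schedule_interval=None) as dag:
        run_pipeline = DataFlowPythonOperator(
            task_id="run-wordcount",
            py_file="gs://example-bucket/pipelines/wordcount.py",  # assumed pipeline location
            py_interpreter="python3",                              # run the Beam pipeline under python3
            dataflow_default_options={"project": "example-project"},
            options={
                "temp_location": "gs://example-bucket/tmp/",
                "output": "gs://example-bucket/output",
            },
        )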
The full @@ -446,8 +455,10 @@ def google_cloud_to_local(self, file_name): bucket_id = path_components[0] object_id = '/'.join(path_components[1:]) - local_file = '/tmp/dataflow{}-{}'.format(str(uuid.uuid4())[:8], - path_components[-1]) + local_file = os.path.join( + tempfile.gettempdir(), + 'dataflow{}-{}'.format(str(uuid.uuid4())[:8], path_components[-1]) + ) self._gcs_hook.download(bucket_id, object_id, local_file) if os.stat(local_file).st_size > 0: diff --git a/airflow/gcp/operators/dataproc.py b/airflow/gcp/operators/dataproc.py index 2e5f5c38094d15..a45fbd4b3a2929 100644 --- a/airflow/gcp/operators/dataproc.py +++ b/airflow/gcp/operators/dataproc.py @@ -27,10 +27,11 @@ import re import time import uuid -from datetime import timedelta +from datetime import datetime, timedelta +from typing import List, Dict, Set, Optional from airflow.gcp.hooks.dataproc import DataProcHook -from airflow.contrib.hooks.gcs_hook import GoogleCloudStorageHook +from airflow.gcp.hooks.gcs import GoogleCloudStorageHook from airflow.exceptions import AirflowException from airflow.models import BaseOperator from airflow.utils.decorators import apply_defaults @@ -44,12 +45,12 @@ class DataprocOperationBaseOperator(BaseOperator): """ @apply_defaults def __init__(self, - project_id, - region='global', - gcp_conn_id='google_cloud_default', - delegate_to=None, + project_id: str, + region: str = 'global', + gcp_conn_id: str = 'google_cloud_default', + delegate_to: Optional[str] = None, *args, - **kwargs): + **kwargs) -> None: super().__init__(*args, **kwargs) self.gcp_conn_id = gcp_conn_id self.delegate_to = delegate_to @@ -194,42 +195,42 @@ class DataprocClusterCreateOperator(DataprocOperationBaseOperator): # pylint: disable=too-many-arguments,too-many-locals @apply_defaults def __init__(self, - project_id, - cluster_name, - num_workers, - zone=None, - network_uri=None, - subnetwork_uri=None, - internal_ip_only=None, - tags=None, - storage_bucket=None, - init_actions_uris=None, - init_action_timeout="10m", - metadata=None, - custom_image=None, - custom_image_project_id=None, - image_version=None, - autoscaling_policy=None, - properties=None, - optional_components=None, - num_masters=1, - master_machine_type='n1-standard-4', - master_disk_type='pd-standard', - master_disk_size=1024, - worker_machine_type='n1-standard-4', - worker_disk_type='pd-standard', - worker_disk_size=1024, - num_preemptible_workers=0, - labels=None, - region='global', - service_account=None, - service_account_scopes=None, - idle_delete_ttl=None, - auto_delete_time=None, - auto_delete_ttl=None, - customer_managed_key=None, + project_id: str, + cluster_name: str, + num_workers: int, + zone: Optional[str] = None, + network_uri: Optional[str] = None, + subnetwork_uri: Optional[str] = None, + internal_ip_only: Optional[bool] = None, + tags: Optional[List[str]] = None, + storage_bucket: Optional[str] = None, + init_actions_uris: Optional[List[str]] = None, + init_action_timeout: str = "10m", + metadata: Optional[Dict] = None, + custom_image: Optional[str] = None, + custom_image_project_id: Optional[str] = None, + image_version: Optional[str] = None, + autoscaling_policy: Optional[str] = None, + properties: Optional[Dict] = None, + optional_components: Optional[List[str]] = None, + num_masters: int = 1, + master_machine_type: str = 'n1-standard-4', + master_disk_type: str = 'pd-standard', + master_disk_size: int = 1024, + worker_machine_type: str = 'n1-standard-4', + worker_disk_type: str = 'pd-standard', + worker_disk_size: int = 1024, + 
num_preemptible_workers: int = 0, + labels: Optional[Dict] = None, + region: str = 'global', + service_account: Optional[str] = None, + service_account_scopes: Optional[List[str]] = None, + idle_delete_ttl: Optional[int] = None, + auto_delete_time: Optional[datetime] = None, + auto_delete_ttl: Optional[int] = None, + customer_managed_key: Optional[str] = None, *args, - **kwargs): + **kwargs) -> None: super().__init__(project_id=project_id, region=region, *args, **kwargs) self.cluster_name = cluster_name @@ -275,8 +276,36 @@ def __init__(self, ) ), "num_workers == 0 means single node mode - no preemptibles allowed" + def _cluster_ready(self, state, service): + if state == 'RUNNING': + return True + if state == 'DELETING': + raise Exception('Tried to create a cluster but it\'s in DELETING, something went wrong.') + if state == 'ERROR': + cluster = DataProcHook.find_cluster(service, self.project_id, self.region, self.cluster_name) + try: + error_details = cluster['status']['details'] + except KeyError: + error_details = 'Unknown error in cluster creation, ' \ + 'check Google Cloud console for details.' + + self.log.info('Dataproc cluster creation resulted in an ERROR state running diagnostics') + self.log.info(error_details) + diagnose_operation_name = \ + DataProcHook.execute_dataproc_diagnose(service, self.project_id, + self.region, self.cluster_name) + diagnose_result = DataProcHook.wait_for_operation_done(service, diagnose_operation_name) + if diagnose_result.get('response') and diagnose_result.get('response').get('outputUri'): + output_uri = diagnose_result.get('response').get('outputUri') + self.log.info('Diagnostic information for ERROR cluster available at [%s]', output_uri) + else: + self.log.info('Diagnostic information could not be retrieved!') + + raise Exception(error_details) + return False + def _get_init_action_timeout(self): - match = re.match(r"^(\d+)(s|m)$", self.init_action_timeout) + match = re.match(r"^(\d+)([sm])$", self.init_action_timeout) if match: if match.group(2) == "s": return self.init_action_timeout @@ -406,7 +435,7 @@ def _build_cluster_data(self): cluster_data['config']['softwareConfig']['imageVersion'] = self.image_version elif self.custom_image: - project_id = self.custom_image_project_id if (self.custom_image_project_id) else self.project_id + project_id = self.custom_image_project_id or self.project_id custom_image_url = 'https://www.googleapis.com/compute/beta/projects/' \ '{}/global/images/{}'.format(project_id, self.custom_image) @@ -444,11 +473,51 @@ def _build_cluster_data(self): return cluster_data + def _usable_existing_cluster_present(self, service): + existing_cluster = DataProcHook.find_cluster(service, self.project_id, self.region, self.cluster_name) + if existing_cluster: + self.log.info( + 'Cluster %s already exists... Checking status...', + self.cluster_name + ) + existing_status = self.hook.get_final_cluster_state(self.project_id, + self.region, self.cluster_name, self.log) + + if existing_status == 'RUNNING': + self.log.info('Cluster exists and is already running. 
Using it.') + return True + + elif existing_status == 'DELETING': + while DataProcHook.find_cluster(service, self.project_id, self.region, self.cluster_name) \ + and DataProcHook.get_cluster_state(service, self.project_id, + self.region, self.cluster_name) == 'DELETING': + self.log.info('Existing cluster is deleting, waiting for it to finish') + time.sleep(15) + + elif existing_status == 'ERROR': + self.log.info('Existing cluster in ERROR state, deleting it first') + + operation_name = DataProcHook.execute_delete(service, self.project_id, + self.region, self.cluster_name) + self.log.info("Cluster delete operation name: %s", operation_name) + DataProcHook.wait_for_operation_done_or_error(service, operation_name) + + return False + def start(self): """ Create a new cluster on Google Cloud Dataproc. """ self.log.info('Creating cluster: %s', self.cluster_name) + hook = DataProcHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to + ) + service = hook.get_conn() + + if self._usable_existing_cluster_present(service): + return True + cluster_data = self._build_cluster_data() return ( @@ -506,21 +575,21 @@ class DataprocClusterScaleOperator(DataprocOperationBaseOperator): @apply_defaults def __init__(self, - cluster_name, - project_id, - region='global', - num_workers=2, - num_preemptible_workers=0, - graceful_decommission_timeout=None, + cluster_name: str, + project_id: str, + region: str = 'global', + num_workers: int = 2, + num_preemptible_workers: int = 0, + graceful_decommission_timeout: Optional[str] = None, *args, - **kwargs): + **kwargs) -> None: super().__init__(project_id=project_id, region=region, *args, **kwargs) self.cluster_name = cluster_name self.num_workers = num_workers self.num_preemptible_workers = num_preemptible_workers # Optional - self.optional_arguments = {} + self.optional_arguments = {} # type: Dict if graceful_decommission_timeout: self.optional_arguments['gracefulDecommissionTimeout'] = \ self._get_graceful_decommission_timeout( @@ -541,7 +610,7 @@ def _build_scale_cluster_data(self): @staticmethod def _get_graceful_decommission_timeout(timeout): - match = re.match(r"^(\d+)(s|m|h|d)$", timeout) + match = re.match(r"^(\d+)([smdh])$", timeout) if match: if match.group(2) == "s": return timeout @@ -606,11 +675,11 @@ class DataprocClusterDeleteOperator(DataprocOperationBaseOperator): @apply_defaults def __init__(self, - cluster_name, - project_id, - region='global', + cluster_name: str, + project_id: str, + region: str = 'global', *args, - **kwargs): + **kwargs) -> None: super().__init__(project_id=project_id, region=region, *args, **kwargs) self.cluster_name = cluster_name @@ -674,17 +743,17 @@ class DataProcJobBaseOperator(BaseOperator): @apply_defaults def __init__(self, - job_name='{{task.task_id}}_{{ds_nodash}}', - cluster_name="cluster-1", - dataproc_properties=None, - dataproc_jars=None, - gcp_conn_id='google_cloud_default', - delegate_to=None, - labels=None, - region='global', - job_error_states=None, + job_name: str = '{{task.task_id}}_{{ds_nodash}}', + cluster_name: str = "cluster-1", + dataproc_properties: Optional[Dict] = None, + dataproc_jars: Optional[List[str]] = None, + gcp_conn_id: str = 'google_cloud_default', + delegate_to: Optional[str] = None, + labels: Optional[Dict] = None, + region: str = 'global', + job_error_states: Optional[Set[str]] = None, *args, - **kwargs): + **kwargs) -> None: super().__init__(*args, **kwargs) self.gcp_conn_id = gcp_conn_id self.delegate_to = delegate_to @@ -780,11 +849,11 @@ class 
DataProcPigOperator(DataProcJobBaseOperator): @apply_defaults def __init__( self, - query=None, - query_uri=None, - variables=None, + query: Optional[str] = None, + query_uri: Optional[str] = None, + variables: Optional[Dict] = None, *args, - **kwargs): + **kwargs) -> None: super().__init__(*args, **kwargs) self.query = query @@ -823,11 +892,11 @@ class DataProcHiveOperator(DataProcJobBaseOperator): @apply_defaults def __init__( self, - query=None, - query_uri=None, - variables=None, + query: Optional[str] = None, + query_uri: Optional[str] = None, + variables: Optional[Dict] = None, *args, - **kwargs): + **kwargs) -> None: super().__init__(*args, **kwargs) self.query = query @@ -867,11 +936,11 @@ class DataProcSparkSqlOperator(DataProcJobBaseOperator): @apply_defaults def __init__( self, - query=None, - query_uri=None, - variables=None, + query: Optional[str] = None, + query_uri: Optional[str] = None, + variables: Optional[Dict] = None, *args, - **kwargs): + **kwargs) -> None: super().__init__(*args, **kwargs) self.query = query @@ -918,13 +987,13 @@ class DataProcSparkOperator(DataProcJobBaseOperator): @apply_defaults def __init__( self, - main_jar=None, - main_class=None, - arguments=None, - archives=None, - files=None, + main_jar: Optional[str] = None, + main_class: Optional[str] = None, + arguments: Optional[List] = None, + archives: Optional[List] = None, + files: Optional[List] = None, *args, - **kwargs): + **kwargs) -> None: super().__init__(*args, **kwargs) self.main_jar = main_jar @@ -970,13 +1039,13 @@ class DataProcHadoopOperator(DataProcJobBaseOperator): @apply_defaults def __init__( self, - main_jar=None, - main_class=None, - arguments=None, - archives=None, - files=None, + main_jar: Optional[str] = None, + main_class: Optional[str] = None, + arguments: Optional[List] = None, + archives: Optional[List] = None, + files: Optional[List] = None, *args, - **kwargs): + **kwargs) -> None: super().__init__(*args, **kwargs) self.main_jar = main_jar @@ -1049,13 +1118,13 @@ def _upload_file_temp(self, bucket, local_file): @apply_defaults def __init__( self, - main, - arguments=None, - archives=None, - pyfiles=None, - files=None, + main: str, + arguments: Optional[List] = None, + archives: Optional[List] = None, + pyfiles: Optional[List] = None, + files: List = None, *args, - **kwargs): + **kwargs) -> None: super().__init__(*args, **kwargs) self.main = main @@ -1118,7 +1187,7 @@ class DataprocWorkflowTemplateInstantiateOperator(DataprocOperationBaseOperator) template_fields = ['template_id'] @apply_defaults - def __init__(self, template_id, parameters, *args, **kwargs): + def __init__(self, template_id: str, parameters: Dict[str, str], *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.template_id = template_id self.parameters = parameters @@ -1148,7 +1217,7 @@ class DataprocWorkflowTemplateInstantiateInlineOperator( https://cloud.google.com/dataproc/docs/reference/rest/v1beta2/projects.regions.workflowTemplates/instantiateInline :param template: The template contents. 
(templated) - :type template: map + :type template: dict :param project_id: The ID of the google cloud project in which the template runs :type project_id: str @@ -1165,7 +1234,7 @@ class DataprocWorkflowTemplateInstantiateInlineOperator( template_fields = ['template'] @apply_defaults - def __init__(self, template, *args, **kwargs): + def __init__(self, template: Dict, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.template = template diff --git a/airflow/gcp/operators/datastore.py b/airflow/gcp/operators/datastore.py index 8ee713fb656da8..568a7da6561694 100644 --- a/airflow/gcp/operators/datastore.py +++ b/airflow/gcp/operators/datastore.py @@ -20,9 +20,10 @@ """ This module contains Google Datastore operators. """ +from typing import Optional from airflow.gcp.hooks.datastore import DatastoreHook -from airflow.contrib.hooks.gcs_hook import GoogleCloudStorageHook +from airflow.gcp.hooks.gcs import GoogleCloudStorageHook from airflow.exceptions import AirflowException from airflow.models import BaseOperator from airflow.utils.decorators import apply_defaults @@ -59,21 +60,22 @@ class DatastoreExportOperator(BaseOperator): emptied prior to exports. This enables overwriting existing backups. :type overwrite_existing: bool """ + template_fields = ['bucket', 'namespace', 'entity_filter', 'labels'] @apply_defaults def __init__(self, # pylint:disable=too-many-arguments - bucket, - namespace=None, - datastore_conn_id='google_cloud_default', - cloud_storage_conn_id='google_cloud_default', - delegate_to=None, - entity_filter=None, - labels=None, - polling_interval_in_seconds=10, - overwrite_existing=False, - project_id=None, + bucket: str, + namespace: Optional[str] = None, + datastore_conn_id: str = 'google_cloud_default', + cloud_storage_conn_id: str = 'google_cloud_default', + delegate_to: Optional[str] = None, + entity_filter: Optional[dict] = None, + labels: Optional[dict] = None, + polling_interval_in_seconds: int = 10, + overwrite_existing: bool = False, + project_id: Optional[str] = None, *args, - **kwargs): + **kwargs) -> None: super().__init__(*args, **kwargs) self.datastore_conn_id = datastore_conn_id self.cloud_storage_conn_id = cloud_storage_conn_id @@ -111,7 +113,6 @@ def execute(self, context): state = result['metadata']['common']['state'] if state != 'SUCCESSFUL': raise AirflowException('Operation failed: result={}'.format(result)) - return result @@ -141,22 +142,24 @@ class DatastoreImportOperator(BaseOperator): :type delegate_to: str :param polling_interval_in_seconds: number of seconds to wait before polling for execution status again - :type polling_interval_in_seconds: int + :type polling_interval_in_seconds: float """ + template_fields = ['bucket', 'file', 'namespace', 'entity_filter', 'labels'] + @apply_defaults def __init__(self, - bucket, - file, - namespace=None, - entity_filter=None, - labels=None, - datastore_conn_id='google_cloud_default', - delegate_to=None, - polling_interval_in_seconds=10, - project_id=None, + bucket: str, + file: str, + namespace: Optional[str] = None, + entity_filter: Optional[dict] = None, + labels: Optional[dict] = None, + datastore_conn_id: str = 'google_cloud_default', + delegate_to: Optional[str] = None, + polling_interval_in_seconds: float = 10, + project_id: Optional[str] = None, *args, - **kwargs): + **kwargs) -> None: super().__init__(*args, **kwargs) self.datastore_conn_id = datastore_conn_id self.delegate_to = delegate_to diff --git a/airflow/gcp/operators/dlp.py b/airflow/gcp/operators/dlp.py index 
743e1297e917f8..a1c33d581d0387 100644 --- a/airflow/gcp/operators/dlp.py +++ b/airflow/gcp/operators/dlp.py @@ -23,6 +23,7 @@ which allow you to perform basic operations using Cloud DLP. """ +from typing import Optional from airflow.gcp.hooks.dlp import CloudDLPHook from airflow.models import BaseOperator @@ -58,11 +59,11 @@ class CloudDLPCancelDLPJobOperator(BaseOperator): def __init__( self, dlp_job_id, - project_id=None, + project_id: Optional[str] = None, retry=None, - timeout=None, + timeout: Optional[float] = None, metadata=None, - gcp_conn_id="google_cloud_default", + gcp_conn_id: str = "google_cloud_default", *args, **kwargs ): @@ -127,13 +128,13 @@ class CloudDLPCreateDeidentifyTemplateOperator(BaseOperator): def __init__( self, organization_id=None, - project_id=None, + project_id: Optional[str] = None, deidentify_template=None, template_id=None, retry=None, - timeout=None, + timeout: Optional[float] = None, metadata=None, - gcp_conn_id="google_cloud_default", + gcp_conn_id: str = "google_cloud_default", *args, **kwargs ): @@ -196,15 +197,15 @@ class CloudDLPCreateDLPJobOperator(BaseOperator): @apply_defaults def __init__( self, - project_id=None, + project_id: Optional[str] = None, inspect_job=None, risk_job=None, job_id=None, retry=None, - timeout=None, + timeout: Optional[float] = None, metadata=None, wait_until_finished=True, - gcp_conn_id="google_cloud_default", + gcp_conn_id: str = "google_cloud_default", *args, **kwargs ): @@ -275,13 +276,13 @@ class CloudDLPCreateInspectTemplateOperator(BaseOperator): def __init__( self, organization_id=None, - project_id=None, + project_id: Optional[str] = None, inspect_template=None, template_id=None, retry=None, - timeout=None, + timeout: Optional[float] = None, metadata=None, - gcp_conn_id="google_cloud_default", + gcp_conn_id: str = "google_cloud_default", *args, **kwargs ): @@ -340,13 +341,13 @@ class CloudDLPCreateJobTriggerOperator(BaseOperator): @apply_defaults def __init__( self, - project_id=None, + project_id: Optional[str] = None, job_trigger=None, trigger_id=None, retry=None, - timeout=None, + timeout: Optional[float] = None, metadata=None, - gcp_conn_id="google_cloud_default", + gcp_conn_id: str = "google_cloud_default", *args, **kwargs ): @@ -412,13 +413,13 @@ class CloudDLPCreateStoredInfoTypeOperator(BaseOperator): def __init__( self, organization_id=None, - project_id=None, + project_id: Optional[str] = None, config=None, stored_info_type_id=None, retry=None, - timeout=None, + timeout: Optional[float] = None, metadata=None, - gcp_conn_id="google_cloud_default", + gcp_conn_id: str = "google_cloud_default", *args, **kwargs ): @@ -497,16 +498,16 @@ class CloudDLPDeidentifyContentOperator(BaseOperator): @apply_defaults def __init__( self, - project_id=None, + project_id: Optional[str] = None, deidentify_config=None, inspect_config=None, item=None, inspect_template_name=None, deidentify_template_name=None, retry=None, - timeout=None, + timeout: Optional[float] = None, metadata=None, - gcp_conn_id="google_cloud_default", + gcp_conn_id: str = "google_cloud_default", *args, **kwargs ): @@ -570,11 +571,11 @@ def __init__( self, template_id, organization_id=None, - project_id=None, + project_id: Optional[str] = None, retry=None, - timeout=None, + timeout: Optional[float] = None, metadata=None, - gcp_conn_id="google_cloud_default", + gcp_conn_id: str = "google_cloud_default", *args, **kwargs ): @@ -629,11 +630,11 @@ class CloudDLPDeleteDlpJobOperator(BaseOperator): def __init__( self, dlp_job_id, - project_id=None, + project_id: 
Optional[str] = None, retry=None, - timeout=None, + timeout: Optional[float] = None, metadata=None, - gcp_conn_id="google_cloud_default", + gcp_conn_id: str = "google_cloud_default", *args, **kwargs ): @@ -689,11 +690,11 @@ def __init__( self, template_id, organization_id=None, - project_id=None, + project_id: Optional[str] = None, retry=None, - timeout=None, + timeout: Optional[float] = None, metadata=None, - gcp_conn_id="google_cloud_default", + gcp_conn_id: str = "google_cloud_default", *args, **kwargs ): @@ -747,11 +748,11 @@ class CloudDLPDeleteJobTriggerOperator(BaseOperator): def __init__( self, job_trigger_id, - project_id=None, + project_id: Optional[str] = None, retry=None, - timeout=None, + timeout: Optional[float] = None, metadata=None, - gcp_conn_id="google_cloud_default", + gcp_conn_id: str = "google_cloud_default", *args, **kwargs ): @@ -812,11 +813,11 @@ def __init__( self, stored_info_type_id, organization_id=None, - project_id=None, + project_id: Optional[str] = None, retry=None, - timeout=None, + timeout: Optional[float] = None, metadata=None, - gcp_conn_id="google_cloud_default", + gcp_conn_id: str = "google_cloud_default", *args, **kwargs ): @@ -875,11 +876,11 @@ def __init__( self, template_id, organization_id=None, - project_id=None, + project_id: Optional[str] = None, retry=None, - timeout=None, + timeout: Optional[float] = None, metadata=None, - gcp_conn_id="google_cloud_default", + gcp_conn_id: str = "google_cloud_default", *args, **kwargs ): @@ -934,11 +935,11 @@ class CloudDLPGetDlpJobOperator(BaseOperator): def __init__( self, dlp_job_id, - project_id=None, + project_id: Optional[str] = None, retry=None, - timeout=None, + timeout: Optional[float] = None, metadata=None, - gcp_conn_id="google_cloud_default", + gcp_conn_id: str = "google_cloud_default", *args, **kwargs ): @@ -995,11 +996,11 @@ def __init__( self, template_id, organization_id=None, - project_id=None, + project_id: Optional[str] = None, retry=None, - timeout=None, + timeout: Optional[float] = None, metadata=None, - gcp_conn_id="google_cloud_default", + gcp_conn_id: str = "google_cloud_default", *args, **kwargs ): @@ -1054,11 +1055,11 @@ class CloudDLPGetJobTripperOperator(BaseOperator): def __init__( self, job_trigger_id, - project_id=None, + project_id: Optional[str] = None, retry=None, - timeout=None, + timeout: Optional[float] = None, metadata=None, - gcp_conn_id="google_cloud_default", + gcp_conn_id: str = "google_cloud_default", *args, **kwargs ): @@ -1120,11 +1121,11 @@ def __init__( self, stored_info_type_id, organization_id=None, - project_id=None, + project_id: Optional[str] = None, retry=None, - timeout=None, + timeout: Optional[float] = None, metadata=None, - gcp_conn_id="google_cloud_default", + gcp_conn_id: str = "google_cloud_default", *args, **kwargs ): @@ -1191,14 +1192,14 @@ class CloudDLPInspectContentOperator(BaseOperator): @apply_defaults def __init__( self, - project_id=None, + project_id: Optional[str] = None, inspect_config=None, item=None, inspect_template_name=None, retry=None, - timeout=None, + timeout: Optional[float] = None, metadata=None, - gcp_conn_id="google_cloud_default", + gcp_conn_id: str = "google_cloud_default", *args, **kwargs ): @@ -1262,13 +1263,13 @@ class CloudDLPListDeidentifyTemplatesOperator(BaseOperator): def __init__( self, organization_id=None, - project_id=None, + project_id: Optional[str] = None, page_size=None, order_by=None, retry=None, - timeout=None, + timeout: Optional[float] = None, metadata=None, - gcp_conn_id="google_cloud_default", + 
gcp_conn_id: str = "google_cloud_default", *args, **kwargs ): @@ -1332,15 +1333,15 @@ class CloudDLPListDlpJobsOperator(BaseOperator): @apply_defaults def __init__( self, - project_id=None, + project_id: Optional[str] = None, results_filter=None, page_size=None, job_type=None, order_by=None, retry=None, - timeout=None, + timeout: Optional[float] = None, metadata=None, - gcp_conn_id="google_cloud_default", + gcp_conn_id: str = "google_cloud_default", *args, **kwargs ): @@ -1401,9 +1402,9 @@ def __init__( language_code=None, results_filter=None, retry=None, - timeout=None, + timeout: Optional[float] = None, metadata=None, - gcp_conn_id="google_cloud_default", + gcp_conn_id: str = "google_cloud_default", *args, **kwargs ): @@ -1463,13 +1464,13 @@ class CloudDLPListInspectTemplatesOperator(BaseOperator): def __init__( self, organization_id=None, - project_id=None, + project_id: Optional[str] = None, page_size=None, order_by=None, retry=None, - timeout=None, + timeout: Optional[float] = None, metadata=None, - gcp_conn_id="google_cloud_default", + gcp_conn_id: str = "google_cloud_default", *args, **kwargs ): @@ -1531,14 +1532,14 @@ class CloudDLPListJobTriggersOperator(BaseOperator): @apply_defaults def __init__( self, - project_id=None, + project_id: Optional[str] = None, page_size=None, order_by=None, results_filter=None, retry=None, - timeout=None, + timeout: Optional[float] = None, metadata=None, - gcp_conn_id="google_cloud_default", + gcp_conn_id: str = "google_cloud_default", *args, **kwargs ): @@ -1602,13 +1603,13 @@ class CloudDLPListStoredInfoTypesOperator(BaseOperator): def __init__( self, organization_id=None, - project_id=None, + project_id: Optional[str] = None, page_size=None, order_by=None, retry=None, - timeout=None, + timeout: Optional[float] = None, metadata=None, - gcp_conn_id="google_cloud_default", + gcp_conn_id: str = "google_cloud_default", *args, **kwargs ): @@ -1681,15 +1682,15 @@ class CloudDLPRedactImageOperator(BaseOperator): @apply_defaults def __init__( self, - project_id=None, + project_id: Optional[str] = None, inspect_config=None, image_redaction_configs=None, include_findings=None, byte_item=None, retry=None, - timeout=None, + timeout: Optional[float] = None, metadata=None, - gcp_conn_id="google_cloud_default", + gcp_conn_id: str = "google_cloud_default", *args, **kwargs ): @@ -1767,16 +1768,16 @@ class CloudDLPReidentifyContentOperator(BaseOperator): @apply_defaults def __init__( self, - project_id=None, + project_id: Optional[str] = None, reidentify_config=None, inspect_config=None, item=None, inspect_template_name=None, reidentify_template_name=None, retry=None, - timeout=None, + timeout: Optional[float] = None, metadata=None, - gcp_conn_id="google_cloud_default", + gcp_conn_id: str = "google_cloud_default", *args, **kwargs ): @@ -1852,13 +1853,13 @@ def __init__( self, template_id, organization_id=None, - project_id=None, + project_id: Optional[str] = None, deidentify_template=None, update_mask=None, retry=None, - timeout=None, + timeout: Optional[float] = None, metadata=None, - gcp_conn_id="google_cloud_default", + gcp_conn_id: str = "google_cloud_default", *args, **kwargs ): @@ -1932,13 +1933,13 @@ def __init__( self, template_id, organization_id=None, - project_id=None, + project_id: Optional[str] = None, inspect_template=None, update_mask=None, retry=None, - timeout=None, + timeout: Optional[float] = None, metadata=None, - gcp_conn_id="google_cloud_default", + gcp_conn_id: str = "google_cloud_default", *args, **kwargs ): @@ -2007,13 +2008,13 @@ class 
CloudDLPUpdateJobTriggerOperator(BaseOperator): def __init__( self, job_trigger_id, - project_id=None, + project_id: Optional[str] = None, job_trigger=None, update_mask=None, retry=None, - timeout=None, + timeout: Optional[float] = None, metadata=None, - gcp_conn_id="google_cloud_default", + gcp_conn_id: str = "google_cloud_default", *args, **kwargs ): @@ -2086,13 +2087,13 @@ def __init__( self, stored_info_type_id, organization_id=None, - project_id=None, + project_id: Optional[str] = None, config=None, update_mask=None, retry=None, - timeout=None, + timeout: Optional[float] = None, metadata=None, - gcp_conn_id="google_cloud_default", + gcp_conn_id: str = "google_cloud_default", *args, **kwargs ): diff --git a/airflow/gcp/operators/functions.py b/airflow/gcp/operators/functions.py index c27947d4a15430..7b7bbedf7f0d72 100644 --- a/airflow/gcp/operators/functions.py +++ b/airflow/gcp/operators/functions.py @@ -21,7 +21,7 @@ """ import re -from typing import Optional, Dict +from typing import Optional, List, Dict, Any from googleapiclient.errors import HttpError @@ -80,7 +80,7 @@ def _validate_max_instances(value): ]) ]) ]), -] +] # type: List[Dict[str, Any]] class GcfFunctionDeployOperator(BaseOperator): @@ -123,14 +123,14 @@ class GcfFunctionDeployOperator(BaseOperator): @apply_defaults def __init__(self, - location, - body, - project_id=None, - gcp_conn_id='google_cloud_default', - api_version='v1', - zip_path=None, - validate_body=True, - *args, **kwargs): + location: str, + body: Dict, + project_id: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + api_version: str = 'v1', + zip_path: Optional[str] = None, + validate_body: bool = True, + *args, **kwargs) -> None: self.project_id = project_id self.location = location self.body = body @@ -138,7 +138,7 @@ def __init__(self, self.api_version = api_version self.zip_path = zip_path self.zip_path_preprocessor = ZipPathPreprocessor(body, zip_path) - self._field_validator = None + self._field_validator = None # type: Optional[GcpBodyFieldValidator] if validate_body: self._field_validator = GcpBodyFieldValidator(CLOUD_FUNCTION_VALIDATION, api_version=api_version) @@ -223,13 +223,16 @@ class ZipPathPreprocessor: :param body: Body passed to the create/update method calls. :type body: dict - :param zip_path: path to the zip file containing source code. - :type body: dict + :param zip_path: (optional) Path to zip file containing source code of the function. If the path + is set, the sourceUploadUrl should not be specified in the body or it should + be empty. Then the zip file will be uploaded using the upload URL generated + via generateUploadUrl from the Cloud Functions API. 
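# A minimal usage sketch of the zip_path convention described above, assuming a DAG
# object named `dag` is in scope; the body fields shown here (name, entryPoint,
# runtime, httpsTrigger) and all example values are assumptions, not part of this patch.
from airflow.gcp.operators.functions import GcfFunctionDeployOperator

deploy_function = GcfFunctionDeployOperator(
    task_id="gcf_deploy_from_zip",
    location="europe-west1",
    body={
        "name": "projects/example-project/locations/europe-west1/functions/hello",
        "entryPoint": "hello",
        "runtime": "python37",
        "httpsTrigger": {},
        # sourceUploadUrl is deliberately omitted: with zip_path set, the archive is
        # uploaded through a URL obtained from generateUploadUrl.
    },
    zip_path="/files/hello_function.zip",  # local archive containing the function source
    validate_body=True,
    dag=dag,
)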
+ :type zip_path: str """ upload_function = None # type: Optional[bool] - def __init__(self, body, zip_path): + def __init__(self, body: dict, zip_path: Optional[str] = None) -> None: self.body = body self.zip_path = zip_path @@ -312,10 +315,10 @@ class GcfFunctionDeleteOperator(BaseOperator): @apply_defaults def __init__(self, - name, - gcp_conn_id='google_cloud_default', - api_version='v1', - *args, **kwargs): + name: str, + gcp_conn_id: str = 'google_cloud_default', + api_version: str = 'v1', + *args, **kwargs) -> None: self.name = name self.gcp_conn_id = gcp_conn_id self.api_version = api_version diff --git a/airflow/gcp/operators/gcs.py b/airflow/gcp/operators/gcs.py new file mode 100644 index 00000000000000..3ba1d1732170f9 --- /dev/null +++ b/airflow/gcp/operators/gcs.py @@ -0,0 +1,528 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +This module contains a Google Cloud Storage Bucket operator. +""" +import sys +import warnings +from typing import Dict, Optional, Iterable + +from airflow.gcp.hooks.gcs import GoogleCloudStorageHook +from airflow.models import BaseOperator +from airflow.utils.decorators import apply_defaults +from airflow.models.xcom import MAX_XCOM_SIZE +from airflow import AirflowException + + +class GoogleCloudStorageCreateBucketOperator(BaseOperator): + """ + Creates a new bucket. Google Cloud Storage uses a flat namespace, + so you can't create a bucket with a name that is already in use. + + .. seealso:: + For more information, see Bucket Naming Guidelines: + https://cloud.google.com/storage/docs/bucketnaming.html#requirements + + :param bucket_name: The name of the bucket. (templated) + :type bucket_name: str + :param resource: An optional dict with parameters for creating the bucket. + For information on available parameters, see Cloud Storage API doc: + https://cloud.google.com/storage/docs/json_api/v1/buckets/insert + :type resource: dict + :param storage_class: This defines how objects in the bucket are stored + and determines the SLA and the cost of storage (templated). Values include + + - ``MULTI_REGIONAL`` + - ``REGIONAL`` + - ``STANDARD`` + - ``NEARLINE`` + - ``COLDLINE``. + + If this value is not specified when the bucket is + created, it will default to STANDARD. + :type storage_class: str + :param location: The location of the bucket. (templated) + Object data for objects in the bucket resides in physical storage + within this region. Defaults to US. + + .. seealso:: https://developers.google.com/storage/docs/bucket-locations + + :type location: str + :param project_id: The ID of the GCP Project. (templated) + :type project_id: str + :param labels: User-provided labels, in key/value pairs. 
+ :type labels: dict + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud Platform. + :type gcp_conn_id: str + :param google_cloud_storage_conn_id: (Deprecated) The connection ID used to connect to Google Cloud + Platform. This parameter has been deprecated. You should pass the gcp_conn_id parameter instead. + :type google_cloud_storage_conn_id: str + :param delegate_to: The account to impersonate, if any. + For this to work, the service account making the request must + have domain-wide delegation enabled. + :type delegate_to: str + + The following Operator would create a new bucket ``test-bucket`` + with ``MULTI_REGIONAL`` storage class in ``EU`` region + + .. code-block:: python + + CreateBucket = GoogleCloudStorageCreateBucketOperator( + task_id='CreateNewBucket', + bucket_name='test-bucket', + storage_class='MULTI_REGIONAL', + location='EU', + labels={'env': 'dev', 'team': 'airflow'}, + gcp_conn_id='airflow-conn-id' + ) + + """ + template_fields = ('bucket_name', 'storage_class', + 'location', 'project_id') + ui_color = '#f0eee4' + + @apply_defaults + def __init__(self, + bucket_name: str, + resource: Optional[Dict] = None, + storage_class: str = 'MULTI_REGIONAL', + location: str = 'US', + project_id: Optional[str] = None, + labels: Optional[Dict] = None, + gcp_conn_id: str = 'google_cloud_default', + google_cloud_storage_conn_id: Optional[str] = None, + delegate_to: Optional[str] = None, + *args, + **kwargs) -> None: + super().__init__(*args, **kwargs) + + if google_cloud_storage_conn_id: + warnings.warn( + "The google_cloud_storage_conn_id parameter has been deprecated. You should pass " + "the gcp_conn_id parameter.", DeprecationWarning, stacklevel=3) + gcp_conn_id = google_cloud_storage_conn_id + + self.bucket_name = bucket_name + self.resource = resource + self.storage_class = storage_class + self.location = location + self.project_id = project_id + self.labels = labels + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + + def execute(self, context): + hook = GoogleCloudStorageHook( + google_cloud_storage_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to + ) + + hook.create_bucket(bucket_name=self.bucket_name, + resource=self.resource, + storage_class=self.storage_class, + location=self.location, + project_id=self.project_id, + labels=self.labels) + + +class GoogleCloudStorageListOperator(BaseOperator): + """ + List all objects from the bucket with the give string prefix and delimiter in name. + + This operator returns a python list with the name of objects which can be used by + `xcom` in the downstream task. + + :param bucket: The Google cloud storage bucket to find the objects. (templated) + :type bucket: str + :param prefix: Prefix string which filters objects whose name begin with + this prefix. (templated) + :type prefix: str + :param delimiter: The delimiter by which you want to filter the objects. (templated) + For e.g to lists the CSV files from in a directory in GCS you would use + delimiter='.csv'. + :type delimiter: str + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud Platform. + :type gcp_conn_id: str + :param google_cloud_storage_conn_id: (Deprecated) The connection ID used to connect to Google Cloud + Platform. This parameter has been deprecated. You should pass the gcp_conn_id parameter instead. + :type google_cloud_storage_conn_id: + :param delegate_to: The account to impersonate, if any. 
+ For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + + **Example**: + The following Operator would list all the Avro files from ``sales/sales-2017`` + folder in ``data`` bucket. :: + + GCS_Files = GoogleCloudStorageListOperator( + task_id='GCS_Files', + bucket='data', + prefix='sales/sales-2017/', + delimiter='.avro', + gcp_conn_id=google_cloud_conn_id + ) + """ + template_fields = ('bucket', 'prefix', 'delimiter') # type: Iterable[str] + + ui_color = '#f0eee4' + + @apply_defaults + def __init__(self, + bucket: str, + prefix: Optional[str] = None, + delimiter: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + google_cloud_storage_conn_id: Optional[str] = None, + delegate_to: Optional[str] = None, + *args, + **kwargs) -> None: + super().__init__(*args, **kwargs) + + if google_cloud_storage_conn_id: + warnings.warn( + "The google_cloud_storage_conn_id parameter has been deprecated. You should pass " + "the gcp_conn_id parameter.", DeprecationWarning, stacklevel=3) + gcp_conn_id = google_cloud_storage_conn_id + + self.bucket = bucket + self.prefix = prefix + self.delimiter = delimiter + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + + def execute(self, context): + + hook = GoogleCloudStorageHook( + google_cloud_storage_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to + ) + + self.log.info('Getting list of the files. Bucket: %s; Delimiter: %s; Prefix: %s', + self.bucket, self.delimiter, self.prefix) + + return hook.list(bucket_name=self.bucket, + prefix=self.prefix, + delimiter=self.delimiter) + + +class GoogleCloudStorageDownloadOperator(BaseOperator): + """ + Downloads a file from Google Cloud Storage. + + If a filename is supplied, it writes the file to the specified location, alternatively one can + set the ``store_to_xcom_key`` parameter to True push the file content into xcom. When the file size + exceeds the maximum size for xcom it is recommended to write to a file. + + :param bucket: The Google cloud storage bucket where the object is. + Must not contain 'gs://' prefix. (templated) + :type bucket: str + :param object: The name of the object to download in the Google cloud + storage bucket. (templated) + :type object: str + :param filename: The file path, including filename, on the local file system (where the + operator is being executed) that the file should be downloaded to. (templated) + If no filename passed, the downloaded data will not be stored on the local file + system. + :type filename: str + :param store_to_xcom_key: If this param is set, the operator will push + the contents of the downloaded file to XCom with the key set in this + parameter. If not set, the downloaded data will not be pushed to XCom. (templated) + :type store_to_xcom_key: str + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud Platform. + :type gcp_conn_id: str + :param google_cloud_storage_conn_id: (Deprecated) The connection ID used to connect to Google Cloud + Platform. This parameter has been deprecated. You should pass the gcp_conn_id parameter instead. + :type google_cloud_storage_conn_id: str + :param delegate_to: The account to impersonate, if any. + For this to work, the service account making the request must have + domain-wide delegation enabled. 
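# A short sketch of the download operator documented above, assuming a `dag` object
# in scope and example bucket/object names; filename and store_to_xcom_key are the
# two delivery modes and, per the constructor, only one of them may be set.
from airflow.gcp.operators.gcs import GoogleCloudStorageDownloadOperator

download_report = GoogleCloudStorageDownloadOperator(
    task_id="download_sales_report",
    bucket="data",
    object_name="sales/sales-2017/report.avro",  # the legacy 'object' keyword is still accepted
    filename="/tmp/report.avro",         # write the object to the local filesystem...
    # store_to_xcom_key="sales_report",  # ...or push its bytes to XCom instead (not both)
    dag=dag,
)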
+ :type delegate_to: str + """ + template_fields = ('bucket', 'object', 'filename', 'store_to_xcom_key',) + ui_color = '#f0eee4' + + @apply_defaults + def __init__(self, + bucket: str, + object_name: Optional[str] = None, + filename: Optional[str] = None, + store_to_xcom_key: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + google_cloud_storage_conn_id: Optional[str] = None, + delegate_to: Optional[str] = None, + *args, + **kwargs) -> None: + # To preserve backward compatibility + # TODO: Remove one day + if object_name is None: + if 'object' in kwargs: + object_name = kwargs['object'] + DeprecationWarning("Use 'object_name' instead of 'object'.") + else: + TypeError("__init__() missing 1 required positional argument: 'object_name'") + + if filename is not None and store_to_xcom_key is not None: + raise ValueError("Either filename or store_to_xcom_key can be set") + + if google_cloud_storage_conn_id: + warnings.warn( + "The google_cloud_storage_conn_id parameter has been deprecated. You should pass " + "the gcp_conn_id parameter.", DeprecationWarning, stacklevel=3) + gcp_conn_id = google_cloud_storage_conn_id + + super().__init__(*args, **kwargs) + self.bucket = bucket + self.object = object_name + self.filename = filename + self.store_to_xcom_key = store_to_xcom_key + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + + def execute(self, context): + self.log.info('Executing download: %s, %s, %s', self.bucket, + self.object, self.filename) + hook = GoogleCloudStorageHook( + google_cloud_storage_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to + ) + + if self.store_to_xcom_key: + file_bytes = hook.download(bucket_name=self.bucket, + object_name=self.object) + if sys.getsizeof(file_bytes) < MAX_XCOM_SIZE: + context['ti'].xcom_push(key=self.store_to_xcom_key, value=file_bytes) + else: + raise AirflowException( + 'The size of the downloaded file is too large to push to XCom!' + ) + else: + hook.download(bucket_name=self.bucket, + object_name=self.object, + filename=self.filename) + + +class GoogleCloudStorageDeleteOperator(BaseOperator): + """ + Deletes objects from a Google Cloud Storage bucket, either + from an explicit list of object names or all objects + matching a prefix. + + :param bucket_name: The GCS bucket to delete from + :type bucket_name: str + :param objects: List of objects to delete. These should be the names + of objects in the bucket, not including gs://bucket/ + :type objects: Iterable[str] + :param prefix: Prefix of objects to delete. All objects matching this + prefix in the bucket will be deleted. + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud Platform. + :type gcp_conn_id: str + :param google_cloud_storage_conn_id: (Deprecated) The connection ID used to connect to Google Cloud + Platform. This parameter has been deprecated. You should pass the gcp_conn_id parameter instead. + :type google_cloud_storage_conn_id: str + :param delegate_to: The account to impersonate, if any. + For this to work, the service account making the request must have + domain-wide delegation enabled. 
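# A sketch of the delete operator described above, assuming a `dag` object in scope
# and an example bucket; per its constructor, at least one of `objects` or `prefix`
# must be provided.
from airflow.gcp.operators.gcs import GoogleCloudStorageDeleteOperator

delete_staging = GoogleCloudStorageDeleteOperator(
    task_id="delete_staging_files",
    bucket_name="data",
    prefix="staging/2019-08-01/",  # every object under this prefix is removed
    # objects=["staging/a.csv"],   # ...or pass an explicit list of object names instead
    dag=dag,
)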
+ :type delegate_to: str + """ + + template_fields = ('bucket_name', 'prefix', 'objects') + + @apply_defaults + def __init__(self, + bucket_name: str, + objects: Optional[Iterable[str]] = None, + prefix: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + google_cloud_storage_conn_id: Optional[str] = None, + delegate_to: Optional[str] = None, + *args, **kwargs) -> None: + + if google_cloud_storage_conn_id: + warnings.warn( + "The google_cloud_storage_conn_id parameter has been deprecated. You should pass " + "the gcp_conn_id parameter.", DeprecationWarning, stacklevel=3) + gcp_conn_id = google_cloud_storage_conn_id + + self.bucket_name = bucket_name + self.objects = objects + self.prefix = prefix + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + + assert objects is not None or prefix is not None + + super().__init__(*args, **kwargs) + + def execute(self, context): + hook = GoogleCloudStorageHook( + google_cloud_storage_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to + ) + + if self.objects: + objects = self.objects + else: + objects = hook.list(bucket_name=self.bucket_name, + prefix=self.prefix) + + self.log.info("Deleting %s objects from %s", + len(objects), self.bucket_name) + for object_name in objects: + hook.delete(bucket_name=self.bucket_name, + object_name=object_name) + + +class GoogleCloudStorageBucketCreateAclEntryOperator(BaseOperator): + """ + Creates a new ACL entry on the specified bucket. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:GoogleCloudStorageBucketCreateAclEntryOperator` + + :param bucket: Name of a bucket. + :type bucket: str + :param entity: The entity holding the permission, in one of the following forms: + user-userId, user-email, group-groupId, group-email, domain-domain, + project-team-projectId, allUsers, allAuthenticatedUsers + :type entity: str + :param role: The access permission for the entity. + Acceptable values are: "OWNER", "READER", "WRITER". + :type role: str + :param user_project: (Optional) The project to be billed for this request. + Required for Requester Pays buckets. + :type user_project: str + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud Platform. + :type gcp_conn_id: str + :param google_cloud_storage_conn_id: (Deprecated) The connection ID used to connect to Google Cloud + Platform. This parameter has been deprecated. You should pass the gcp_conn_id parameter instead. + :type google_cloud_storage_conn_id: str + """ + # [START gcs_bucket_create_acl_template_fields] + template_fields = ('bucket', 'entity', 'role', 'user_project') + # [END gcs_bucket_create_acl_template_fields] + + @apply_defaults + def __init__( + self, + bucket: str, + entity: str, + role: str, + user_project: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + google_cloud_storage_conn_id: Optional[str] = None, + *args, + **kwargs + ) -> None: + super().__init__(*args, **kwargs) + + if google_cloud_storage_conn_id: + warnings.warn( + "The google_cloud_storage_conn_id parameter has been deprecated. 
You should pass " + "the gcp_conn_id parameter.", DeprecationWarning, stacklevel=3) + gcp_conn_id = google_cloud_storage_conn_id + + self.bucket = bucket + self.entity = entity + self.role = role + self.user_project = user_project + self.gcp_conn_id = gcp_conn_id + + def execute(self, context): + hook = GoogleCloudStorageHook( + google_cloud_storage_conn_id=self.gcp_conn_id + ) + hook.insert_bucket_acl(bucket_name=self.bucket, entity=self.entity, role=self.role, + user_project=self.user_project) + + +class GoogleCloudStorageObjectCreateAclEntryOperator(BaseOperator): + """ + Creates a new ACL entry on the specified object. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:GoogleCloudStorageObjectCreateAclEntryOperator` + + :param bucket: Name of a bucket. + :type bucket: str + :param object_name: Name of the object. For information about how to URL encode object + names to be path safe, see: + https://cloud.google.com/storage/docs/json_api/#encoding + :type object_name: str + :param entity: The entity holding the permission, in one of the following forms: + user-userId, user-email, group-groupId, group-email, domain-domain, + project-team-projectId, allUsers, allAuthenticatedUsers + :type entity: str + :param role: The access permission for the entity. + Acceptable values are: "OWNER", "READER". + :type role: str + :param generation: Optional. If present, selects a specific revision of this object. + :type generation: long + :param user_project: (Optional) The project to be billed for this request. + Required for Requester Pays buckets. + :type user_project: str + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud Platform. + :type gcp_conn_id: str + :param google_cloud_storage_conn_id: (Deprecated) The connection ID used to connect to Google Cloud + Platform. This parameter has been deprecated. You should pass the gcp_conn_id parameter instead. + :type google_cloud_storage_conn_id: str + """ + # [START gcs_object_create_acl_template_fields] + template_fields = ('bucket', 'object_name', 'entity', 'generation', 'role', 'user_project') + # [END gcs_object_create_acl_template_fields] + + @apply_defaults + def __init__(self, + bucket: str, + object_name: str, + entity: str, + role: str, + generation: Optional[int] = None, + user_project: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + google_cloud_storage_conn_id: Optional[str] = None, + *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + if google_cloud_storage_conn_id: + warnings.warn( + "The google_cloud_storage_conn_id parameter has been deprecated. 
You should pass " + "the gcp_conn_id parameter.", DeprecationWarning, stacklevel=3) + gcp_conn_id = google_cloud_storage_conn_id + + self.bucket = bucket + self.object_name = object_name + self.entity = entity + self.role = role + self.generation = generation + self.user_project = user_project + self.gcp_conn_id = gcp_conn_id + + def execute(self, context): + hook = GoogleCloudStorageHook( + google_cloud_storage_conn_id=self.gcp_conn_id + ) + hook.insert_object_acl(bucket_name=self.bucket, + object_name=self.object_name, + entity=self.entity, + role=self.role, + generation=self.generation, + user_project=self.user_project) diff --git a/airflow/gcp/operators/kubernetes_engine.py b/airflow/gcp/operators/kubernetes_engine.py index d75944e7bb60d5..f7fd043933dd2e 100644 --- a/airflow/gcp/operators/kubernetes_engine.py +++ b/airflow/gcp/operators/kubernetes_engine.py @@ -24,8 +24,10 @@ import os import subprocess import tempfile +from typing import Union, Dict, Optional from google.auth.environment_vars import CREDENTIALS +from google.cloud.container_v1.types import Cluster from airflow import AirflowException from airflow.gcp.hooks.kubernetes_engine import GKEClusterHook @@ -69,13 +71,13 @@ class GKEClusterDeleteOperator(BaseOperator): @apply_defaults def __init__(self, - project_id, - name, - location, - gcp_conn_id='google_cloud_default', - api_version='v2', + name: str, + location: str, + project_id: str = None, + gcp_conn_id: str = 'google_cloud_default', + api_version: str = 'v2', *args, - **kwargs): + **kwargs) -> None: super().__init__(*args, **kwargs) self.project_id = project_id @@ -83,6 +85,7 @@ def __init__(self, self.location = location self.api_version = api_version self.name = name + self._check_input() def _check_input(self): if not all([self.project_id, self.name, self.location]): @@ -91,7 +94,6 @@ def _check_input(self): raise AirflowException('Operator has incorrect or missing input.') def execute(self, context): - self._check_input() hook = GKEClusterHook(gcp_conn_id=self.gcp_conn_id, location=self.location) delete_result = hook.delete_cluster(name=self.name, project_id=self.project_id) return delete_result @@ -144,41 +146,34 @@ class GKEClusterCreateOperator(BaseOperator): @apply_defaults def __init__(self, - project_id, - location, - body=None, - gcp_conn_id='google_cloud_default', - api_version='v2', + location: str, + body: Optional[Union[Dict, Cluster]], + project_id: str = None, + gcp_conn_id: str = 'google_cloud_default', + api_version: str = 'v2', *args, - **kwargs): + **kwargs) -> None: super().__init__(*args, **kwargs) - if body is None: - body = {} self.project_id = project_id self.gcp_conn_id = gcp_conn_id self.location = location self.api_version = api_version self.body = body + self._check_input() def _check_input(self): - if all([self.project_id, self.location, self.body]): - if isinstance(self.body, dict) \ - and 'name' in self.body \ - and 'initial_node_count' in self.body: - # Don't throw error - return - # If not dict, then must - elif self.body.name and self.body.initial_node_count: - return - - self.log.error( - 'One of (project_id, location, body, body[\'name\'], ' - 'body[\'initial_node_count\']) is missing or incorrect') - raise AirflowException('Operator has incorrect or missing input.') + if not all([self.project_id, self.location, self.body]) or not ( + (isinstance(self.body, dict) and "name" in self.body and "initial_node_count" in self.body) or + (getattr(self.body, "name", None) and getattr(self.body, "initial_node_count", None)) + ): + 
self.log.error( + "One of (project_id, location, body, body['name'], " + "body['initial_node_count']) is missing or incorrect" + ) + raise AirflowException("Operator has incorrect or missing input.") def execute(self, context): - self._check_input() hook = GKEClusterHook(gcp_conn_id=self.gcp_conn_id, location=self.location) create_op = hook.create_cluster(cluster=self.body, project_id=self.project_id) return create_op @@ -231,10 +226,10 @@ class GKEPodOperator(KubernetesPodOperator): @apply_defaults def __init__(self, - project_id, - location, - cluster_name, - gcp_conn_id='google_cloud_default', + project_id: str, + location: str, + cluster_name: str, + gcp_conn_id: str = 'google_cloud_default', *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/airflow/gcp/operators/mlengine.py b/airflow/gcp/operators/mlengine.py index e2758333733273..5b1ed88516b453 100644 --- a/airflow/gcp/operators/mlengine.py +++ b/airflow/gcp/operators/mlengine.py @@ -1,23 +1,25 @@ # -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the 'License'); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an 'AS IS' BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. """ This module contains GCP MLEngine operators. """ import re - +from typing import List, Optional from airflow.gcp.hooks.mlengine import MLEngineHook from airflow.exceptions import AirflowException @@ -28,7 +30,7 @@ log = LoggingMixin().log -def _normalize_mlengine_job_id(job_id): +def _normalize_mlengine_job_id(job_id: str) -> str: """ Replaces invalid MLEngine job_id characters with '_'. @@ -136,8 +138,10 @@ class MLEngineBatchPredictionOperator(BaseOperator): :type uri: str :param max_worker_count: The maximum number of workers to be used - for parallel processing. Defaults to 10 if not specified. - :type max_worker_count: int + for parallel processing. Defaults to 10 if not specified. Should be a + string representing the worker count ("10" instead of 10, "50" instead + of 50, etc.) + :type max_worker_count: string :param runtime_version: The Google Cloud ML Engine runtime version to use for batch prediction. 
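# A sketch of a batch prediction job with the operator documented above, assuming a
# `dag` object in scope; the project, bucket, model and format values are example
# assumptions. Note the docstring above asks for max_worker_count as a string.
from airflow.gcp.operators.mlengine import MLEngineBatchPredictionOperator

batch_predict = MLEngineBatchPredictionOperator(
    task_id="example_batch_prediction",
    project_id="example-project",
    job_id="example_batch_prediction_job",
    region="us-central1",
    data_format="TEXT",
    input_paths=["gs://example-bucket/prediction_input/*.json"],
    output_path="gs://example-bucket/prediction_output/",
    model_name="example_model",
    version_name="v1",
    max_worker_count="10",  # passed as a string, per the docstring above
    dag=dag,
)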
@@ -173,22 +177,22 @@ class MLEngineBatchPredictionOperator(BaseOperator): @apply_defaults def __init__(self, # pylint:disable=too-many-arguments - project_id, - job_id, - region, - data_format, - input_paths, - output_path, - model_name=None, - version_name=None, - uri=None, - max_worker_count=None, - runtime_version=None, - signature_name=None, - gcp_conn_id='google_cloud_default', - delegate_to=None, + project_id: str, + job_id: str, + region: str, + data_format: str, + input_paths: List[str], + output_path: str, + model_name: Optional[str] = None, + version_name: Optional[str] = None, + uri: Optional[str] = None, + max_worker_count: Optional[int] = None, + runtime_version: Optional[str] = None, + signature_name: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + delegate_to: Optional[str] = None, *args, - **kwargs): + **kwargs) -> None: super().__init__(*args, **kwargs) self._project_id = project_id @@ -317,13 +321,13 @@ class MLEngineModelOperator(BaseOperator): @apply_defaults def __init__(self, - project_id, - model, - operation='create', - gcp_conn_id='google_cloud_default', - delegate_to=None, + project_id: str, + model: dict, + operation: str = 'create', + gcp_conn_id: str = 'google_cloud_default', + delegate_to: Optional[str] = None, *args, - **kwargs): + **kwargs) -> None: super().__init__(*args, **kwargs) self._project_id = project_id self._model = model @@ -406,15 +410,15 @@ class MLEngineVersionOperator(BaseOperator): @apply_defaults def __init__(self, - project_id, - model_name, - version_name=None, - version=None, - operation='create', - gcp_conn_id='google_cloud_default', - delegate_to=None, + project_id: str, + model_name: str, + version_name: Optional[str] = None, + version: Optional[dict] = None, + operation: str = 'create', + gcp_conn_id: str = 'google_cloud_default', + delegate_to: Optional[str] = None, *args, - **kwargs): + **kwargs) -> None: super().__init__(*args, **kwargs) self._project_id = project_id @@ -528,22 +532,22 @@ class MLEngineTrainingOperator(BaseOperator): @apply_defaults def __init__(self, # pylint:disable=too-many-arguments - project_id, - job_id, - package_uris, - training_python_module, - training_args, - region, - scale_tier=None, - master_type=None, - runtime_version=None, - python_version=None, - job_dir=None, - gcp_conn_id='google_cloud_default', - delegate_to=None, - mode='PRODUCTION', + project_id: str, + job_id: str, + package_uris: str, + training_python_module: str, + training_args: str, + region: str, + scale_tier: Optional[str] = None, + master_type: Optional[str] = None, + runtime_version: Optional[str] = None, + python_version: Optional[str] = None, + job_dir: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + delegate_to: Optional[str] = None, + mode: str = 'PRODUCTION', *args, - **kwargs): + **kwargs) -> None: super().__init__(*args, **kwargs) self._project_id = project_id self._job_id = job_id diff --git a/airflow/gcp/operators/natural_language.py b/airflow/gcp/operators/natural_language.py index c0cda6a7dfc654..f5f19a870eba61 100644 --- a/airflow/gcp/operators/natural_language.py +++ b/airflow/gcp/operators/natural_language.py @@ -19,11 +19,18 @@ """ This module contains Google Cloud Language operators. 
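# A sketch of the model operator typed above, assuming a `dag` object in scope and
# example project/model names; with operation='create' the model dict carries the
# resource fields, of which a name is the minimum.
from airflow.gcp.operators.mlengine import MLEngineModelOperator

create_model = MLEngineModelOperator(
    task_id="create_model",
    project_id="example-project",
    model={"name": "example_model"},
    operation="create",
    dag=dag,
)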
""" +from typing import Union, Tuple, Sequence, Optional + from google.protobuf.json_format import MessageToDict +from google.cloud.language_v1.types import Document +from google.cloud.language_v1 import enums +from google.api_core.retry import Retry from airflow.gcp.hooks.natural_language import CloudNaturalLanguageHook from airflow.models import BaseOperator +MetaData = Sequence[Tuple[str, str]] + class CloudLanguageAnalyzeEntitiesOperator(BaseOperator): """ @@ -38,14 +45,14 @@ class CloudLanguageAnalyzeEntitiesOperator(BaseOperator): If a dict is provided, it must be of the same form as the protobuf message Document :type document: dict or google.cloud.language_v1.types.Document :param encoding_type: The encoding type used by the API to calculate offsets. - :type encoding_type: google.cloud.language_v1.types.EncodingType + :type encoding_type: google.cloud.language_v1.enums.EncodingType :param retry: A retry object used to retry requests. If None is specified, requests will not be retried. :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if retry is specified, the timeout applies to each individual attempt. :type timeout: float :param metadata: Additional metadata that is provided to the method. - :type metadata: seq[tuple[str, str]]] + :type metadata: Sequence[Tuple[str, str]] :param gcp_conn_id: The connection ID to use connecting to Google Cloud Platform. :type gcp_conn_id: str """ @@ -55,15 +62,15 @@ class CloudLanguageAnalyzeEntitiesOperator(BaseOperator): def __init__( self, - document, - encoding_type=None, - retry=None, - timeout=None, - metadata=None, - gcp_conn_id="google_cloud_default", + document: Union[dict, Document], + encoding_type: Optional[enums.EncodingType] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[MetaData] = None, + gcp_conn_id: str = "google_cloud_default", *args, **kwargs - ): + ) -> None: super().__init__(*args, **kwargs) self.document = document self.encoding_type = encoding_type @@ -97,14 +104,14 @@ class CloudLanguageAnalyzeEntitySentimentOperator(BaseOperator): If a dict is provided, it must be of the same form as the protobuf message Document :type document: dict or google.cloud.language_v1.types.Document :param encoding_type: The encoding type used by the API to calculate offsets. - :type encoding_type: google.cloud.language_v1.types.EncodingType + :type encoding_type: google.cloud.language_v1.enums.EncodingType :param retry: A retry object used to retry requests. If None is specified, requests will not be retried. :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if retry is specified, the timeout applies to each individual attempt. :type timeout: float :param metadata: Additional metadata that is provided to the method. - :type metadata: seq[tuple[str, str]]] + :type metadata: Sequence[Tuple[str, str]]] :rtype: google.cloud.language_v1.types.AnalyzeEntitiesResponse :param gcp_conn_id: The connection ID to use connecting to Google Cloud Platform. 
:type gcp_conn_id: str @@ -115,15 +122,15 @@ class CloudLanguageAnalyzeEntitySentimentOperator(BaseOperator): def __init__( self, - document, - encoding_type=None, - retry=None, - timeout=None, - metadata=None, - gcp_conn_id="google_cloud_default", + document: Union[dict, Document], + encoding_type: Optional[enums.EncodingType] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[MetaData] = None, + gcp_conn_id: str = "google_cloud_default", *args, **kwargs - ): + ) -> None: super().__init__(*args, **kwargs) self.document = document self.encoding_type = encoding_type @@ -160,7 +167,7 @@ class CloudLanguageAnalyzeSentimentOperator(BaseOperator): If a dict is provided, it must be of the same form as the protobuf message Document :type document: dict or google.cloud.language_v1.types.Document :param encoding_type: The encoding type used by the API to calculate offsets. - :type encoding_type: google.cloud.language_v1.types.EncodingType + :type encoding_type: google.cloud.language_v1.enums.EncodingType :param retry: A retry object used to retry requests. If None is specified, requests will not be retried. :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if @@ -178,15 +185,15 @@ class CloudLanguageAnalyzeSentimentOperator(BaseOperator): def __init__( self, - document, - encoding_type=None, - retry=None, - timeout=None, - metadata=None, - gcp_conn_id="google_cloud_default", + document: Union[dict, Document], + encoding_type: Optional[enums.EncodingType] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[MetaData] = None, + gcp_conn_id: str = "google_cloud_default", *args, **kwargs - ): + ) -> None: super().__init__(*args, **kwargs) self.document = document self.encoding_type = encoding_type @@ -234,14 +241,14 @@ class CloudLanguageClassifyTextOperator(BaseOperator): def __init__( self, - document, - retry=None, - timeout=None, - metadata=None, - gcp_conn_id="google_cloud_default", + document: Union[dict, Document], + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[MetaData] = None, + gcp_conn_id: str = "google_cloud_default", *args, **kwargs - ): + ) -> None: super().__init__(*args, **kwargs) self.document = document self.retry = retry diff --git a/airflow/gcp/operators/pubsub.py b/airflow/gcp/operators/pubsub.py index 261587625ce9e8..e9a1df3cddde12 100644 --- a/airflow/gcp/operators/pubsub.py +++ b/airflow/gcp/operators/pubsub.py @@ -19,6 +19,8 @@ """ This module contains Google PubSub operators. 
""" +from typing import List, Optional + from airflow.gcp.hooks.pubsub import PubSubHook from airflow.models import BaseOperator from airflow.utils.decorators import apply_defaults @@ -75,13 +77,13 @@ class PubSubTopicCreateOperator(BaseOperator): @apply_defaults def __init__( self, - project, - topic, - fail_if_exists=False, - gcp_conn_id='google_cloud_default', - delegate_to=None, + project: str, + topic: str, + fail_if_exists: bool = False, + gcp_conn_id: str = 'google_cloud_default', + delegate_to: Optional[str] = None, *args, - **kwargs): + **kwargs) -> None: super().__init__(*args, **kwargs) self.project = project @@ -177,15 +179,15 @@ class PubSubSubscriptionCreateOperator(BaseOperator): def __init__( self, topic_project, - topic, + topic: str, subscription=None, subscription_project=None, - ack_deadline_secs=10, - fail_if_exists=False, - gcp_conn_id='google_cloud_default', - delegate_to=None, + ack_deadline_secs: int = 10, + fail_if_exists: bool = False, + gcp_conn_id: str = 'google_cloud_default', + delegate_to: Optional[str] = None, *args, - **kwargs): + **kwargs) -> None: super().__init__(*args, **kwargs) self.topic_project = topic_project self.topic = topic @@ -256,13 +258,13 @@ class PubSubTopicDeleteOperator(BaseOperator): @apply_defaults def __init__( self, - project, - topic, + project: str, + topic: str, fail_if_not_exists=False, - gcp_conn_id='google_cloud_default', - delegate_to=None, + gcp_conn_id: str = 'google_cloud_default', + delegate_to: Optional[str] = None, *args, - **kwargs): + **kwargs) -> None: super().__init__(*args, **kwargs) self.project = project @@ -331,13 +333,13 @@ class PubSubSubscriptionDeleteOperator(BaseOperator): @apply_defaults def __init__( self, - project, - subscription, + project: str, + subscription: str, fail_if_not_exists=False, - gcp_conn_id='google_cloud_default', - delegate_to=None, + gcp_conn_id: str = 'google_cloud_default', + delegate_to: Optional[str] = None, *args, - **kwargs): + **kwargs) -> None: super().__init__(*args, **kwargs) self.project = project @@ -408,13 +410,13 @@ class PubSubPublishOperator(BaseOperator): @apply_defaults def __init__( self, - project, - topic, - messages, - gcp_conn_id='google_cloud_default', - delegate_to=None, + project: str, + topic: str, + messages: List, + gcp_conn_id: str = 'google_cloud_default', + delegate_to: Optional[str] = None, *args, - **kwargs): + **kwargs) -> None: super().__init__(*args, **kwargs) self.gcp_conn_id = gcp_conn_id diff --git a/airflow/gcp/operators/spanner.py b/airflow/gcp/operators/spanner.py index a5c0048c9d92be..d4a29a9af6b0fa 100644 --- a/airflow/gcp/operators/spanner.py +++ b/airflow/gcp/operators/spanner.py @@ -19,6 +19,7 @@ """ This module contains Google Spanner operators. 
""" +from typing import List, Optional from airflow import AirflowException from airflow.gcp.hooks.spanner import CloudSpannerHook @@ -57,13 +58,13 @@ class CloudSpannerInstanceDeployOperator(BaseOperator): @apply_defaults def __init__(self, - instance_id, - configuration_name, - node_count, - display_name, - project_id=None, - gcp_conn_id='google_cloud_default', - *args, **kwargs): + instance_id: int, + configuration_name: str, + node_count: str, + display_name: str, + project_id: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + *args, **kwargs) -> None: self.instance_id = instance_id self.project_id = project_id self.configuration_name = configuration_name @@ -118,10 +119,10 @@ class CloudSpannerInstanceDeleteOperator(BaseOperator): @apply_defaults def __init__(self, - instance_id, - project_id=None, - gcp_conn_id='google_cloud_default', - *args, **kwargs): + instance_id: int, + project_id: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + *args, **kwargs) -> None: self.instance_id = instance_id self.project_id = project_id self.gcp_conn_id = gcp_conn_id @@ -174,12 +175,12 @@ class CloudSpannerInstanceDatabaseQueryOperator(BaseOperator): @apply_defaults def __init__(self, - instance_id, + instance_id: int, database_id, query, - project_id=None, - gcp_conn_id='google_cloud_default', - *args, **kwargs): + project_id: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + *args, **kwargs) -> None: self.instance_id = instance_id self.project_id = project_id self.database_id = database_id @@ -257,12 +258,12 @@ class CloudSpannerInstanceDatabaseDeployOperator(BaseOperator): @apply_defaults def __init__(self, - instance_id, - database_id, - ddl_statements, - project_id=None, - gcp_conn_id='google_cloud_default', - *args, **kwargs): + instance_id: int, + database_id: str, + ddl_statements: List[str], + project_id: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + *args, **kwargs) -> None: self.instance_id = instance_id self.project_id = project_id self.database_id = database_id @@ -331,13 +332,13 @@ class CloudSpannerInstanceDatabaseUpdateOperator(BaseOperator): @apply_defaults def __init__(self, - instance_id, - database_id, - ddl_statements, - project_id=None, - operation_id=None, - gcp_conn_id='google_cloud_default', - *args, **kwargs): + instance_id: int, + database_id: str, + ddl_statements: List[str], + project_id: Optional[str] = None, + operation_id: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + *args, **kwargs) -> None: self.instance_id = instance_id self.project_id = project_id self.database_id = database_id @@ -403,11 +404,11 @@ class CloudSpannerInstanceDatabaseDeleteOperator(BaseOperator): @apply_defaults def __init__(self, - instance_id, - database_id, - project_id=None, - gcp_conn_id='google_cloud_default', - *args, **kwargs): + instance_id: int, + database_id: str, + project_id: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + *args, **kwargs) -> None: self.instance_id = instance_id self.project_id = project_id self.database_id = database_id diff --git a/airflow/gcp/operators/speech_to_text.py b/airflow/gcp/operators/speech_to_text.py index 9f11b5733444a1..c3dffda4779535 100644 --- a/airflow/gcp/operators/speech_to_text.py +++ b/airflow/gcp/operators/speech_to_text.py @@ -19,9 +19,13 @@ """ This module contains a Google Speech to Text operator. 
""" +from typing import Optional + +from google.api_core.retry import Retry +from google.cloud.speech_v1.types import RecognitionConfig from airflow import AirflowException -from airflow.gcp.hooks.speech_to_text import GCPSpeechToTextHook +from airflow.gcp.hooks.speech_to_text import RecognitionAudio, GCPSpeechToTextHook from airflow.models import BaseOperator from airflow.utils.decorators import apply_defaults @@ -61,15 +65,15 @@ class GcpSpeechToTextRecognizeSpeechOperator(BaseOperator): @apply_defaults def __init__( self, - audio, - config, - project_id=None, - gcp_conn_id="google_cloud_default", - retry=None, - timeout=None, + audio: RecognitionAudio, + config: RecognitionConfig, + project_id: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + retry: Optional[Retry] = None, + timeout: Optional[float] = None, *args, **kwargs - ): + ) -> None: self.audio = audio self.config = config self.project_id = project_id diff --git a/airflow/gcp/operators/tasks.py b/airflow/gcp/operators/tasks.py index 4f58241dd02de6..9509b7c60f41fa 100644 --- a/airflow/gcp/operators/tasks.py +++ b/airflow/gcp/operators/tasks.py @@ -22,11 +22,18 @@ which allow you to perform basic operations using Cloud Tasks queues/tasks. """ +from typing import Tuple, Sequence, Union, Dict, Optional + +from google.api_core.retry import Retry +from google.cloud.tasks_v2.types import Queue, FieldMask, Task +from google.cloud.tasks_v2 import enums from airflow.gcp.hooks.tasks import CloudTasksHook from airflow.models import BaseOperator from airflow.utils.decorators import apply_defaults +MetaData = Sequence[Tuple[str, str]] + class CloudTasksQueueCreateOperator(BaseOperator): """ @@ -69,17 +76,17 @@ class CloudTasksQueueCreateOperator(BaseOperator): @apply_defaults def __init__( self, - location, - task_queue, - project_id=None, - queue_name=None, - retry=None, - timeout=None, - metadata=None, - gcp_conn_id="google_cloud_default", + location: str, + task_queue: Queue, + project_id: Optional[str] = None, + queue_name: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[MetaData] = None, + gcp_conn_id: str = "google_cloud_default", *args, **kwargs - ): + ) -> None: super().__init__(*args, **kwargs) self.location = location self.task_queue = task_queue @@ -150,18 +157,18 @@ class CloudTasksQueueUpdateOperator(BaseOperator): @apply_defaults def __init__( self, - task_queue, - project_id=None, - location=None, - queue_name=None, - update_mask=None, - retry=None, - timeout=None, - metadata=None, - gcp_conn_id="google_cloud_default", + task_queue: Queue, + project_id: Optional[str] = None, + location: Optional[str] = None, + queue_name: Optional[str] = None, + update_mask: Union[Dict, FieldMask] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[MetaData] = None, + gcp_conn_id: str = "google_cloud_default", *args, **kwargs - ): + ) -> None: super().__init__(*args, **kwargs) self.task_queue = task_queue self.project_id = project_id @@ -217,16 +224,16 @@ class CloudTasksQueueGetOperator(BaseOperator): @apply_defaults def __init__( self, - location, - queue_name, - project_id=None, - retry=None, - timeout=None, - metadata=None, - gcp_conn_id="google_cloud_default", + location: str, + queue_name: str, + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[MetaData] = None, + gcp_conn_id: str = "google_cloud_default", *args, **kwargs - ): + ) -> None: 
super().__init__(*args, **kwargs) self.location = location self.queue_name = queue_name @@ -282,17 +289,17 @@ class CloudTasksQueuesListOperator(BaseOperator): @apply_defaults def __init__( self, - location, - project_id=None, + location: str, + project_id: Optional[str] = None, results_filter=None, - page_size=None, - retry=None, - timeout=None, - metadata=None, - gcp_conn_id="google_cloud_default", + page_size: Optional[int] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[MetaData] = None, + gcp_conn_id: str = "google_cloud_default", *args, **kwargs - ): + ) -> None: super().__init__(*args, **kwargs) self.location = location self.project_id = project_id @@ -345,16 +352,16 @@ class CloudTasksQueueDeleteOperator(BaseOperator): @apply_defaults def __init__( self, - location, - queue_name, - project_id=None, - retry=None, - timeout=None, - metadata=None, - gcp_conn_id="google_cloud_default", + location: str, + queue_name: str, + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[MetaData] = None, + gcp_conn_id: str = "google_cloud_default", *args, **kwargs - ): + ) -> None: super().__init__(*args, **kwargs) self.location = location self.queue_name = queue_name @@ -406,16 +413,16 @@ class CloudTasksQueuePurgeOperator(BaseOperator): @apply_defaults def __init__( self, - location, - queue_name, - project_id=None, - retry=None, - timeout=None, - metadata=None, - gcp_conn_id="google_cloud_default", + location: str, + queue_name: str, + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[MetaData] = None, + gcp_conn_id: str = "google_cloud_default", *args, **kwargs - ): + ) -> None: super().__init__(*args, **kwargs) self.location = location self.queue_name = queue_name @@ -467,16 +474,16 @@ class CloudTasksQueuePauseOperator(BaseOperator): @apply_defaults def __init__( self, - location, - queue_name, - project_id=None, - retry=None, - timeout=None, - metadata=None, - gcp_conn_id="google_cloud_default", + location: str, + queue_name: str, + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[MetaData] = None, + gcp_conn_id: str = "google_cloud_default", *args, **kwargs - ): + ) -> None: super().__init__(*args, **kwargs) self.location = location self.queue_name = queue_name @@ -528,16 +535,16 @@ class CloudTasksQueueResumeOperator(BaseOperator): @apply_defaults def __init__( self, - location, - queue_name, - project_id=None, - retry=None, - timeout=None, - metadata=None, - gcp_conn_id="google_cloud_default", + location: str, + queue_name: str, + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[MetaData] = None, + gcp_conn_id: str = "google_cloud_default", *args, **kwargs - ): + ) -> None: super().__init__(*args, **kwargs) self.location = location self.queue_name = queue_name @@ -578,7 +585,7 @@ class CloudTasksTaskCreateOperator(BaseOperator): :type task_name: str :param response_view: (Optional) This field specifies which subset of the Task will be returned. - :type response_view: google.cloud.tasks_v2.types.Task.View + :type response_view: google.cloud.tasks_v2.enums.Task.View :param retry: (Optional) A retry object used to retry requests. If None is specified, requests will not be retried. 
:type retry: google.api_core.retry.Retry @@ -603,21 +610,21 @@ class CloudTasksTaskCreateOperator(BaseOperator): ) @apply_defaults - def __init__( + def __init__( # pylint: disable=too-many-arguments self, - location, - queue_name, - task, - project_id=None, - task_name=None, - response_view=None, - retry=None, - timeout=None, - metadata=None, - gcp_conn_id="google_cloud_default", + location: str, + queue_name: str, + task: Union[Dict, Task], + project_id: Optional[str] = None, + task_name: Optional[str] = None, + response_view: Optional[enums.Task.View] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[MetaData] = None, + gcp_conn_id: str = "google_cloud_default", *args, **kwargs - ): # pylint: disable=too-many-arguments + ) -> None: super().__init__(*args, **kwargs) self.location = location self.queue_name = queue_name @@ -660,7 +667,7 @@ class CloudTasksTaskGetOperator(BaseOperator): :type project_id: str :param response_view: (Optional) This field specifies which subset of the Task will be returned. - :type response_view: google.cloud.tasks_v2.types.Task.View + :type response_view: google.cloud.tasks_v2.enums.Task.View :param retry: (Optional) A retry object used to retry requests. If None is specified, requests will not be retried. :type retry: google.api_core.retry.Retry @@ -686,18 +693,18 @@ class CloudTasksTaskGetOperator(BaseOperator): @apply_defaults def __init__( self, - location, - queue_name, - task_name, - project_id=None, - response_view=None, - retry=None, - timeout=None, - metadata=None, - gcp_conn_id="google_cloud_default", + location: str, + queue_name: str, + task_name: str, + project_id: Optional[str] = None, + response_view: Optional[enums.Task.View] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[MetaData] = None, + gcp_conn_id: str = "google_cloud_default", *args, **kwargs - ): + ) -> None: super().__init__(*args, **kwargs) self.location = location self.queue_name = queue_name @@ -736,7 +743,7 @@ class CloudTasksTasksListOperator(BaseOperator): :type project_id: str :param response_view: (Optional) This field specifies which subset of the Task will be returned. - :type response_view: google.cloud.tasks_v2.types.Task.View + :type response_view: google.cloud.tasks_v2.enums.Task.View :param page_size: (Optional) The maximum number of resources contained in the underlying API response. 
:type page_size: int @@ -759,18 +766,18 @@ class CloudTasksTasksListOperator(BaseOperator): @apply_defaults def __init__( self, - location, - queue_name, - project_id=None, - response_view=None, - page_size=None, - retry=None, - timeout=None, - metadata=None, - gcp_conn_id="google_cloud_default", + location: str, + queue_name: str, + project_id: Optional[str] = None, + response_view: Optional[enums.Task.View] = None, + page_size: Optional[int] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[MetaData] = None, + gcp_conn_id: str = "google_cloud_default", *args, **kwargs - ): + ) -> None: super().__init__(*args, **kwargs) self.location = location self.queue_name = queue_name @@ -833,17 +840,17 @@ class CloudTasksTaskDeleteOperator(BaseOperator): @apply_defaults def __init__( self, - location, - queue_name, - task_name, - project_id=None, - retry=None, - timeout=None, - metadata=None, - gcp_conn_id="google_cloud_default", + location: str, + queue_name: str, + task_name: str, + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[MetaData] = None, + gcp_conn_id: str = "google_cloud_default", *args, **kwargs - ): + ) -> None: super().__init__(*args, **kwargs) self.location = location self.queue_name = queue_name @@ -882,7 +889,7 @@ class CloudTasksTaskRunOperator(BaseOperator): :type project_id: str :param response_view: (Optional) This field specifies which subset of the Task will be returned. - :type response_view: google.cloud.tasks_v2.types.Task.View + :type response_view: google.cloud.tasks_v2.enums.Task.View :param retry: (Optional) A retry object used to retry requests. If None is specified, requests will not be retried. :type retry: google.api_core.retry.Retry @@ -908,18 +915,18 @@ class CloudTasksTaskRunOperator(BaseOperator): @apply_defaults def __init__( self, - location, - queue_name, - task_name, - project_id=None, - response_view=None, - retry=None, - timeout=None, - metadata=None, - gcp_conn_id="google_cloud_default", + location: str, + queue_name: str, + task_name: str, + project_id: Optional[str] = None, + response_view: Optional[enums.Task.View] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[MetaData] = None, + gcp_conn_id: str = "google_cloud_default", *args, **kwargs - ): + ) -> None: super().__init__(*args, **kwargs) self.location = location self.queue_name = queue_name diff --git a/airflow/gcp/operators/text_to_speech.py b/airflow/gcp/operators/text_to_speech.py index 4f0f02d4e402d3..669a8bec3bf362 100644 --- a/airflow/gcp/operators/text_to_speech.py +++ b/airflow/gcp/operators/text_to_speech.py @@ -21,10 +21,14 @@ """ from tempfile import NamedTemporaryFile +from typing import Dict, Union, Optional + +from google.api_core.retry import Retry +from google.cloud.texttospeech_v1.types import SynthesisInput, VoiceSelectionParams, AudioConfig from airflow import AirflowException from airflow.gcp.hooks.text_to_speech import GCPTextToSpeechHook -from airflow.contrib.hooks.gcs_hook import GoogleCloudStorageHook +from airflow.gcp.hooks.gcs import GoogleCloudStorageHook from airflow.models import BaseOperator from airflow.utils.decorators import apply_defaults @@ -79,18 +83,18 @@ class GcpTextToSpeechSynthesizeOperator(BaseOperator): @apply_defaults def __init__( self, - input_data, - voice, - audio_config, - target_bucket_name, - target_filename, - project_id=None, - gcp_conn_id="google_cloud_default", - retry=None, 
- timeout=None, + input_data: Union[Dict, SynthesisInput], + voice: Union[Dict, VoiceSelectionParams], + audio_config: Union[Dict, AudioConfig], + target_bucket_name: str, + target_filename: str, + project_id: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + retry: Optional[Retry] = None, + timeout: Optional[float] = None, *args, **kwargs - ): + ) -> None: self.input_data = input_data self.voice = voice self.audio_config = audio_config diff --git a/airflow/gcp/operators/translate.py b/airflow/gcp/operators/translate.py index 5ab3af68836c71..27b491efe7d204 100644 --- a/airflow/gcp/operators/translate.py +++ b/airflow/gcp/operators/translate.py @@ -19,6 +19,7 @@ """ This module contains Google Translate operators. """ +from typing import List, Union from airflow import AirflowException from airflow.gcp.hooks.translate import CloudTranslateHook @@ -80,15 +81,15 @@ class CloudTranslateTextOperator(BaseOperator): @apply_defaults def __init__( self, - values, - target_language, - format_, - source_language, - model, + values: Union[List[str], str], + target_language: str, + format_: str, + source_language: str, + model: str, gcp_conn_id='google_cloud_default', *args, **kwargs - ): + ) -> None: super().__init__(*args, **kwargs) self.values = values self.target_language = target_language diff --git a/airflow/gcp/operators/translate_speech.py b/airflow/gcp/operators/translate_speech.py index ac86a5d1d0691d..c837b00008e24d 100644 --- a/airflow/gcp/operators/translate_speech.py +++ b/airflow/gcp/operators/translate_speech.py @@ -19,7 +19,10 @@ """ This module contains a Google Cloud Translate Speech operator. """ +from typing import Optional + from google.protobuf.json_format import MessageToDict +from google.cloud.speech_v1.types import RecognitionAudio, RecognitionConfig from airflow import AirflowException from airflow.gcp.hooks.speech_to_text import GCPSpeechToTextHook @@ -102,17 +105,17 @@ class GcpTranslateSpeechOperator(BaseOperator): @apply_defaults def __init__( self, - audio, - config, - target_language, - format_, - source_language, - model, - project_id=None, + audio: RecognitionAudio, + config: RecognitionConfig, + target_language: str, + format_: str, + source_language: str, + model: str, + project_id: Optional[str] = None, gcp_conn_id='google_cloud_default', *args, **kwargs - ): + ) -> None: super().__init__(*args, **kwargs) self.audio = audio self.config = config diff --git a/airflow/gcp/operators/video_intelligence.py b/airflow/gcp/operators/video_intelligence.py index 5f8cd69f94e7c5..9523d860e836f8 100644 --- a/airflow/gcp/operators/video_intelligence.py +++ b/airflow/gcp/operators/video_intelligence.py @@ -19,9 +19,12 @@ """ This module contains Google Cloud Vision operators. 
""" +from typing import Dict, Union, Optional +from google.api_core.retry import Retry from google.protobuf.json_format import MessageToDict from google.cloud.videointelligence_v1 import enums +from google.cloud.videointelligence_v1.types import VideoContext from airflow.gcp.hooks.video_intelligence import CloudVideoIntelligenceHook from airflow.models import BaseOperator @@ -68,17 +71,17 @@ class CloudVideoIntelligenceDetectVideoLabelsOperator(BaseOperator): def __init__( self, - input_uri, - input_content=None, - output_uri=None, - video_context=None, - location=None, - retry=None, - timeout=None, - gcp_conn_id="google_cloud_default", + input_uri: str, + input_content: Optional[bytes] = None, + output_uri: Optional[str] = None, + video_context: Union[Dict, VideoContext] = None, + location: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + gcp_conn_id: str = "google_cloud_default", *args, **kwargs - ): + ) -> None: super().__init__(*args, **kwargs) self.input_uri = input_uri self.input_content = input_content @@ -147,17 +150,17 @@ class CloudVideoIntelligenceDetectVideoExplicitContentOperator(BaseOperator): def __init__( self, - input_uri, - output_uri=None, - input_content=None, - video_context=None, - location=None, - retry=None, - timeout=None, - gcp_conn_id="google_cloud_default", + input_uri: str, + output_uri: Optional[str] = None, + input_content: Optional[bytes] = None, + video_context: Union[Dict, VideoContext] = None, + location: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + gcp_conn_id: str = "google_cloud_default", *args, **kwargs - ): + ) -> None: super().__init__(*args, **kwargs) self.input_uri = input_uri self.output_uri = output_uri @@ -226,17 +229,17 @@ class CloudVideoIntelligenceDetectVideoShotsOperator(BaseOperator): def __init__( self, - input_uri, - output_uri=None, - input_content=None, - video_context=None, - location=None, - retry=None, - timeout=None, - gcp_conn_id="google_cloud_default", + input_uri: str, + output_uri: Optional[str] = None, + input_content: Optional[bytes] = None, + video_context: Union[Dict, VideoContext] = None, + location: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + gcp_conn_id: str = "google_cloud_default", *args, **kwargs - ): + ) -> None: super().__init__(*args, **kwargs) self.input_uri = input_uri self.output_uri = output_uri diff --git a/airflow/gcp/operators/vision.py b/airflow/gcp/operators/vision.py index 59a3693d46682d..71b0be3d06ec9a 100644 --- a/airflow/gcp/operators/vision.py +++ b/airflow/gcp/operators/vision.py @@ -38,6 +38,8 @@ from airflow.models import BaseOperator from airflow.utils.decorators import apply_defaults +MetaData = Sequence[Tuple[str, str]] + class CloudVisionProductSetCreateOperator(BaseOperator): """ @@ -82,15 +84,15 @@ def __init__( self, product_set: Union[dict, ProductSet], location: str, - project_id: str = None, - product_set_id: str = None, - retry: Retry = None, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = None, + project_id: Optional[str] = None, + product_set_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[MetaData] = None, gcp_conn_id: str = "google_cloud_default", *args, **kwargs - ): + ) -> None: super().__init__(*args, **kwargs) self.location = location self.project_id = project_id @@ -158,14 +160,14 @@ def __init__( self, location: str, product_set_id: str, - project_id: str = 
None, - retry: Retry = None, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = None, + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[MetaData] = None, gcp_conn_id: str = 'google_cloud_default', *args, **kwargs - ): + ) -> None: super().__init__(*args, **kwargs) self.location = location self.project_id = project_id @@ -242,17 +244,17 @@ class CloudVisionProductSetUpdateOperator(BaseOperator): def __init__( self, product_set: Union[Dict, ProductSet], - location: str = None, - product_set_id: str = None, - project_id: str = None, + location: Optional[str] = None, + product_set_id: Optional[str] = None, + project_id: Optional[str] = None, update_mask: Union[Dict, FieldMask] = None, - retry: Retry = None, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[MetaData] = None, gcp_conn_id: str = 'google_cloud_default', *args, **kwargs - ): + ) -> None: super().__init__(*args, **kwargs) self.product_set = product_set self.update_mask = update_mask @@ -317,14 +319,14 @@ def __init__( self, location: str, product_set_id: str, - project_id: str = None, - retry: Retry = None, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = None, + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[MetaData] = None, gcp_conn_id: str = 'google_cloud_default', *args, **kwargs - ): + ) -> None: super().__init__(*args, **kwargs) self.location = location self.project_id = project_id @@ -395,15 +397,15 @@ def __init__( self, location: str, product: str, - project_id: str = None, - product_id: str = None, - retry: Retry = None, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = None, + project_id: Optional[str] = None, + product_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[MetaData] = None, gcp_conn_id: str = 'google_cloud_default', *args, **kwargs - ): + ) -> None: super().__init__(*args, **kwargs) self.location = location self.product = product @@ -474,14 +476,14 @@ def __init__( self, location: str, product_id: str, - project_id: str = None, - retry: Retry = None, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = None, + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[MetaData] = None, gcp_conn_id: str = "google_cloud_default", *args, **kwargs - ): + ) -> None: super().__init__(*args, **kwargs) self.location = location self.product_id = product_id @@ -569,17 +571,17 @@ class CloudVisionProductUpdateOperator(BaseOperator): def __init__( self, product: Union[Dict, Product], - location: str = None, - product_id: str = None, - project_id: str = None, + location: Optional[str] = None, + product_id: Optional[str] = None, + project_id: Optional[str] = None, update_mask: Union[Dict, FieldMask] = None, - retry: Retry = None, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[MetaData] = None, gcp_conn_id: str = 'google_cloud_default', *args, **kwargs - ): + ) -> None: super().__init__(*args, **kwargs) self.product = product self.location = location @@ -649,14 +651,14 @@ def __init__( self, location: str, product_id: str, - project_id: str = None, - retry: Retry = None, - 
timeout: float = None, - metadata: Sequence[Tuple[str, str]] = None, + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[MetaData] = None, gcp_conn_id: str = 'google_cloud_default', *args, **kwargs - ): + ) -> None: super().__init__(*args, **kwargs) self.location = location self.product_id = product_id @@ -709,12 +711,12 @@ class CloudVisionAnnotateImageOperator(BaseOperator): def __init__( self, request: Union[Dict, AnnotateImageRequest], - retry: Retry = None, - timeout: float = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, gcp_conn_id: str = 'google_cloud_default', *args, **kwargs - ): + ) -> None: super().__init__(*args, **kwargs) self.request = request self.retry = retry @@ -791,15 +793,15 @@ def __init__( location: str, reference_image: Union[Dict, ReferenceImage], product_id: str, - reference_image_id: str = None, - project_id: str = None, - retry: Retry = None, - timeout: str = None, - metadata: Sequence[Tuple[str, str]] = None, + reference_image_id: Optional[str] = None, + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[str] = None, + metadata: Optional[MetaData] = None, gcp_conn_id: str = 'google_cloud_default', *args, **kwargs - ): + ) -> None: super().__init__(*args, **kwargs) self.location = location self.product_id = product_id @@ -878,14 +880,14 @@ def __init__( product_set_id: str, product_id: str, location: str, - project_id: str = None, - retry: Retry = None, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = None, + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[MetaData] = None, gcp_conn_id: str = "google_cloud_default", *args, **kwargs - ): + ) -> None: super().__init__(*args, **kwargs) self.product_set_id = product_set_id self.product_id = product_id @@ -949,14 +951,14 @@ def __init__( product_set_id: str, product_id: str, location: str, - project_id: str = None, - retry: Retry = None, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = None, + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[MetaData] = None, gcp_conn_id: str = "google_cloud_default", *args, **kwargs - ): + ) -> None: super().__init__(*args, **kwargs) self.product_set_id = product_set_id self.product_id = product_id @@ -1015,16 +1017,16 @@ class CloudVisionDetectTextOperator(BaseOperator): def __init__( self, image: Union[Dict, Image], - max_results: int = None, - retry: Retry = None, - timeout: float = None, - language_hints: Union[str, List[str]] = None, - web_detection_params: Dict = None, - additional_properties: Dict = None, + max_results: Optional[int] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + language_hints: Optional[Union[str, List[str]]] = None, + web_detection_params: Optional[Dict] = None, + additional_properties: Optional[Dict] = None, gcp_conn_id: str = "google_cloud_default", *args, **kwargs - ): + ) -> None: super().__init__(*args, **kwargs) self.image = image self.max_results = max_results @@ -1078,22 +1080,22 @@ class CloudVisionDetectDocumentTextOperator(BaseOperator): :type additional_properties: dict """ # [START vision_document_detect_text_set_template_fields] - template_fields = ("image", "max_results", "timeout", "gcp_conn_id") + template_fields = ("image", "max_results", "timeout", "gcp_conn_id") # Iterable[str] # [END 
vision_document_detect_text_set_template_fields] def __init__( self, image: Union[Dict, Image], - max_results: int = None, - retry: Retry = None, - timeout: float = None, - language_hints: Union[str, List[str]] = None, - web_detection_params: Dict = None, - additional_properties: Dict = None, + max_results: Optional[int] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + language_hints: Optional[Union[str, List[str]]] = None, + web_detection_params: Optional[Dict] = None, + additional_properties: Optional[Dict] = None, gcp_conn_id: str = "google_cloud_default", *args, **kwargs - ): + ) -> None: super().__init__(*args, **kwargs) self.image = image self.max_results = max_results @@ -1146,14 +1148,14 @@ class CloudVisionDetectImageLabelsOperator(BaseOperator): def __init__( self, image: Union[Dict, Image], - max_results: int = None, - retry: Retry = None, - timeout: float = None, - additional_properties: Dict = None, + max_results: Optional[int] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + additional_properties: Optional[Dict] = None, gcp_conn_id: str = "google_cloud_default", *args, **kwargs - ): + ) -> None: super().__init__(*args, **kwargs) self.image = image self.max_results = max_results @@ -1202,14 +1204,14 @@ class CloudVisionDetectImageSafeSearchOperator(BaseOperator): def __init__( self, image: Union[Dict, Image], - max_results: int = None, - retry: Retry = None, - timeout: float = None, - additional_properties: Dict = None, + max_results: Optional[int] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + additional_properties: Optional[Dict] = None, gcp_conn_id: str = "google_cloud_default", *args, **kwargs - ): + ) -> None: super().__init__(*args, **kwargs) self.image = image self.max_results = max_results diff --git a/airflow/gcp/sensors/bigquery.py b/airflow/gcp/sensors/bigquery.py new file mode 100644 index 00000000000000..72acd3fe0869c4 --- /dev/null +++ b/airflow/gcp/sensors/bigquery.py @@ -0,0 +1,75 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +This module contains a Google BigQuery sensor. +""" +from typing import Optional + +from airflow.sensors.base_sensor_operator import BaseSensorOperator +from airflow.gcp.hooks.bigquery import BigQueryHook +from airflow.utils.decorators import apply_defaults + + +class BigQueryTableSensor(BaseSensorOperator): + """ + Checks for the existence of a table in Google BigQuery. + + :param project_id: The Google cloud project in which to look for the table. + The connection supplied to the hook must provide + access to the specified project. + :type project_id: str + :param dataset_id: The name of the dataset in which to look for the table.
+ :type dataset_id: str + :param table_id: The name of the table to check the existence of. + :type table_id: str + :param bigquery_conn_id: The connection ID to use when connecting to + Google BigQuery. + :type bigquery_conn_id: str + :param delegate_to: The account to impersonate, if any. + For this to work, the service account making the request must + have domain-wide delegation enabled. + :type delegate_to: str + """ + template_fields = ('project_id', 'dataset_id', 'table_id',) + ui_color = '#f0eee4' + + @apply_defaults + def __init__(self, + project_id: str, + dataset_id: str, + table_id: str, + bigquery_conn_id: str = 'google_cloud_default', + delegate_to: Optional[str] = None, + *args, **kwargs) -> None: + + super().__init__(*args, **kwargs) + self.project_id = project_id + self.dataset_id = dataset_id + self.table_id = table_id + self.bigquery_conn_id = bigquery_conn_id + self.delegate_to = delegate_to + + def poke(self, context): + table_uri = '{0}:{1}.{2}'.format(self.project_id, self.dataset_id, self.table_id) + self.log.info('Sensor checks existence of table: %s', table_uri) + hook = BigQueryHook( + bigquery_conn_id=self.bigquery_conn_id, + delegate_to=self.delegate_to) + return hook.table_exists(self.project_id, self.dataset_id, self.table_id) diff --git a/airflow/gcp/sensors/bigquery_dts.py b/airflow/gcp/sensors/bigquery_dts.py new file mode 100644 index 00000000000000..92bd72a0b526df --- /dev/null +++ b/airflow/gcp/sensors/bigquery_dts.py @@ -0,0 +1,112 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +This module contains a Google BigQuery Data Transfer Service sensor. +""" +from typing import Sequence, Tuple, Union, Set + +from google.api_core.retry import Retry +from google.protobuf.json_format import MessageToDict + +from airflow.gcp.hooks.bigquery_dts import BiqQueryDataTransferServiceHook +from airflow.sensors.base_sensor_operator import BaseSensorOperator +from airflow.utils.decorators import apply_defaults + + +class BigQueryDataTransferServiceTransferRunSensor(BaseSensorOperator): + """ + Waits for Data Transfer Service run to complete. + + .. seealso:: + For more information on how to use this sensor, take a look at the guide: + :ref:`howto/operator:BigQueryDataTransferServiceTransferRunSensor` + + :param expected_statuses: The expected state of the operation. + See: + https://cloud.google.com/storage-transfer/docs/reference/rest/v1/transferOperations#Status + :type expected_statuses: Union[Set[str], str] + :param run_id: ID of the transfer run. + :type run_id: str + :param transfer_config_id: ID of transfer config to be used. + :type transfer_config_id: str + :param project_id: The BigQuery project id where the transfer configuration should be + created. 
If set to None or missing, the default project_id from the GCP connection is used. + :type project_id: str + :param retry: A retry object used to retry requests. If `None` is + specified, requests will not be retried. + :type retry: Optional[google.api_core.retry.Retry] + :param request_timeout: The amount of time, in seconds, to wait for the request to + complete. Note that if retry is specified, the timeout applies to each individual + attempt. + :type request_timeout: Optional[float] + :param metadata: Additional metadata that is provided to the method. + :type metadata: Optional[Sequence[Tuple[str, str]]] + :return: An ``google.cloud.bigquery_datatransfer_v1.types.TransferRun`` instance. + """ + + template_fields = ( + "run_id", + "transfer_config_id", + "expected_statuses", + "project_id", + ) + + @apply_defaults + def __init__( + self, + run_id: str, + transfer_config_id: str, + expected_statuses: Union[Set[str], str] = 'SUCCEEDED', + project_id: str = None, + gcp_conn_id: str = "google_cloud_default", + retry: Retry = None, + request_timeout: float = None, + metadata: Sequence[Tuple[str, str]] = None, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.run_id = run_id + self.transfer_config_id = transfer_config_id + self.retry = retry + self.request_timeout = request_timeout + self.metadata = metadata + self.expected_statuses = ( + {expected_statuses} + if isinstance(expected_statuses, str) + else expected_statuses + ) + self.project_id = project_id + self.gcp_cloud_conn_id = gcp_conn_id + + def poke(self, context): + hook = BiqQueryDataTransferServiceHook(gcp_conn_id=self.gcp_cloud_conn_id) + run = hook.get_transfer_run( + run_id=self.run_id, + transfer_config_id=self.transfer_config_id, + project_id=self.project_id, + retry=self.retry, + timeout=self.request_timeout, + metadata=self.metadata, + ) + result = MessageToDict(run) + state = result["state"] + self.log.info("Status of %s run: %s", self.run_id, state) + + return state in self.expected_statuses diff --git a/airflow/gcp/sensors/cloud_storage_transfer_service.py b/airflow/gcp/sensors/cloud_storage_transfer_service.py index 9a9cb2285daee6..628a141820cb4f 100644 --- a/airflow/gcp/sensors/cloud_storage_transfer_service.py +++ b/airflow/gcp/sensors/cloud_storage_transfer_service.py @@ -19,6 +19,7 @@ """ This module contains a Google Cloud Transfer sensor. """ +from typing import Set, Union, Optional from airflow.gcp.hooks.cloud_storage_transfer_service import GCPTransferServiceHook from airflow.sensors.base_sensor_operator import BaseSensorOperator @@ -52,13 +53,13 @@ class GCPTransferServiceWaitForJobStatusSensor(BaseSensorOperator): @apply_defaults def __init__( self, - job_name, - expected_statuses, - project_id=None, - gcp_conn_id='google_cloud_default', + job_name: str, + expected_statuses: Union[Set[str], str], + project_id: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', *args, **kwargs - ): + ) -> None: super().__init__(*args, **kwargs) self.job_name = job_name self.expected_statuses = ( diff --git a/airflow/gcp/sensors/gcs.py b/airflow/gcp/sensors/gcs.py new file mode 100644 index 00000000000000..a7abbc7a142f98 --- /dev/null +++ b/airflow/gcp/sensors/gcs.py @@ -0,0 +1,319 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +This module contains Google Cloud Storage sensors. +""" + +import os +from datetime import datetime +from typing import Callable, List, Optional + +from airflow.gcp.hooks.gcs import GoogleCloudStorageHook +from airflow.sensors.base_sensor_operator import BaseSensorOperator +from airflow.utils.decorators import apply_defaults +from airflow import AirflowException + + +class GoogleCloudStorageObjectSensor(BaseSensorOperator): + """ + Checks for the existence of a file in Google Cloud Storage. + + :param bucket: The Google cloud storage bucket where the object is. + :type bucket: str + :param object: The name of the object to check in the Google cloud + storage bucket. + :type object: str + :param google_cloud_conn_id: The connection ID to use when + connecting to Google cloud storage. + :type google_cloud_conn_id: str + :param delegate_to: The account to impersonate, if any. + For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + """ + template_fields = ('bucket', 'object') + ui_color = '#f0eee4' + + @apply_defaults + def __init__(self, + bucket: str, + object: str, # pylint:disable=redefined-builtin + google_cloud_conn_id: str = 'google_cloud_default', + delegate_to: Optional[str] = None, + *args, **kwargs) -> None: + + super().__init__(*args, **kwargs) + self.bucket = bucket + self.object = object + self.google_cloud_conn_id = google_cloud_conn_id + self.delegate_to = delegate_to + + def poke(self, context): + self.log.info('Sensor checks existence of: %s, %s', self.bucket, self.object) + hook = GoogleCloudStorageHook( + google_cloud_storage_conn_id=self.google_cloud_conn_id, + delegate_to=self.delegate_to) + return hook.exists(self.bucket, self.object) + + +def ts_function(context): + """ + Default callback for the GoogleCloudStorageObjectUpdatedSensor. The default + behaviour is to check for the object being updated after execution_date + + schedule_interval. + """ + return context['dag'].following_schedule(context['execution_date']) + + +class GoogleCloudStorageObjectUpdatedSensor(BaseSensorOperator): + """ + Checks if an object is updated in Google Cloud Storage. + + :param bucket: The Google cloud storage bucket where the object is. + :type bucket: str + :param object: The name of the object to check in the Google cloud + storage bucket. + :type object: str + :param ts_func: Callback for defining the update condition. The default callback + returns execution_date + schedule_interval. The callback takes the context + as parameter. + :type ts_func: function + :param google_cloud_conn_id: The connection ID to use when + connecting to Google cloud storage. + :type google_cloud_conn_id: str + :param delegate_to: The account to impersonate, if any. + For this to work, the service account making the request must have domain-wide + delegation enabled.
+ :type delegate_to: str + """ + template_fields = ('bucket', 'object') + ui_color = '#f0eee4' + + @apply_defaults + def __init__(self, + bucket: str, + object: str, # pylint:disable=redefined-builtin + ts_func: Callable = ts_function, + google_cloud_conn_id: str = 'google_cloud_default', + delegate_to: Optional[str] = None, + *args, **kwargs) -> None: + + super().__init__(*args, **kwargs) + self.bucket = bucket + self.object = object + self.ts_func = ts_func + self.google_cloud_conn_id = google_cloud_conn_id + self.delegate_to = delegate_to + + def poke(self, context): + self.log.info('Sensor checks existence of: %s, %s', self.bucket, self.object) + hook = GoogleCloudStorageHook( + google_cloud_storage_conn_id=self.google_cloud_conn_id, + delegate_to=self.delegate_to) + return hook.is_updated_after(self.bucket, self.object, self.ts_func(context)) + + +class GoogleCloudStoragePrefixSensor(BaseSensorOperator): + """ + Checks for the existence of GCS objects at a given prefix, passing matches via XCom. + + When files matching the given prefix are found, the poke method's criteria will be + fulfilled and the matching objects will be returned from the operator and passed + through XCom for downstream tasks. + + :param bucket: The Google cloud storage bucket where the object is. + :type bucket: str + :param prefix: The name of the prefix to check in the Google cloud + storage bucket. + :type prefix: str + :param google_cloud_conn_id: The connection ID to use when + connecting to Google cloud storage. + :type google_cloud_conn_id: str + :param delegate_to: The account to impersonate, if any. + For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + """ + template_fields = ('bucket', 'prefix') + ui_color = '#f0eee4' + + @apply_defaults + def __init__(self, + bucket: str, + prefix: str, + google_cloud_conn_id: str = 'google_cloud_default', + delegate_to: Optional[str] = None, + *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.bucket = bucket + self.prefix = prefix + self.google_cloud_conn_id = google_cloud_conn_id + self.delegate_to = delegate_to + self._matches = [] # type: List[str] + + def poke(self, context): + self.log.info('Sensor checks existence of objects: %s, %s', + self.bucket, self.prefix) + hook = GoogleCloudStorageHook( + google_cloud_storage_conn_id=self.google_cloud_conn_id, + delegate_to=self.delegate_to) + self._matches = hook.list(self.bucket, prefix=self.prefix) + return bool(self._matches) + + def execute(self, context): + """Overridden to allow matches to be passed""" + super(GoogleCloudStoragePrefixSensor, self).execute(context) + return self._matches + + +def get_time(): + """ + This is just a wrapper of datetime.datetime.now to simplify mocking in the + unittests. + """ + return datetime.now() + + +class GoogleCloudStorageUploadSessionCompleteSensor(BaseSensorOperator): + """ + Checks for changes in the number of objects at prefix in Google Cloud Storage + bucket and returns True if the inactivity period has passed with no + increase in the number of objects. Note, it is recommended to use reschedule + mode if you expect this sensor to run for hours. + + :param bucket: The Google cloud storage bucket where the objects are + expected. + :type bucket: str + :param prefix: The name of the prefix to check in the Google cloud + storage bucket. + :param inactivity_period: The total seconds of inactivity to designate + an upload session is over.
Note, this mechanism is not real time and + this operator may not return until a poke_interval after this period + has passed with no additional objects sensed. + :type inactivity_period: float + :param min_objects: The minimum number of objects needed for upload session + to be considered valid. + :type min_objects: int + :param previous_num_objects: The number of objects found during the last poke. + :type previous_num_objects: int + :param allow_delete: Should this sensor consider objects being deleted + between pokes valid behavior. If true a warning message will be logged + when this happens. If false an error will be raised. + :type allow_delete: bool + :param google_cloud_conn_id: The connection ID to use when connecting + to Google cloud storage. + :type google_cloud_conn_id: str + :param delegate_to: The account to impersonate, if any. For this to work, + the service account making the request must have domain-wide + delegation enabled. + :type delegate_to: str + """ + + template_fields = ('bucket', 'prefix') + ui_color = '#f0eee4' + + @apply_defaults + def __init__(self, + bucket: str, + prefix: str, + inactivity_period: float = 60 * 60, + min_objects: int = 1, + previous_num_objects: int = 0, + allow_delete: bool = True, + google_cloud_conn_id: str = 'google_cloud_default', + delegate_to: Optional[str] = None, + *args, **kwargs) -> None: + + super().__init__(*args, **kwargs) + + self.bucket = bucket + self.prefix = prefix + self.inactivity_period = inactivity_period + self.min_objects = min_objects + self.previous_num_objects = previous_num_objects + self.inactivity_seconds = 0 + self.allow_delete = allow_delete + self.google_cloud_conn_id = google_cloud_conn_id + self.delegate_to = delegate_to + self.last_activity_time = None + + def is_bucket_updated(self, current_num_objects: int) -> bool: + """ + Checks whether new objects have been uploaded and the inactivity_period + has passed and updates the state of the sensor accordingly. + + :param current_num_objects: number of objects in bucket during last poke. + :type current_num_objects: int + """ + + if current_num_objects > self.previous_num_objects: + # When new objects arrived, reset the inactivity_seconds + # previous_num_objects for the next poke. + self.log.info("New objects found at %s resetting last_activity_time.", + os.path.join(self.bucket, self.prefix)) + self.last_activity_time = get_time() + self.inactivity_seconds = 0 + self.previous_num_objects = current_num_objects + return False + + if current_num_objects < self.previous_num_objects: + # During the last poke interval objects were deleted. + if self.allow_delete: + self.previous_num_objects = current_num_objects + self.last_activity_time = get_time() + self.log.warning( + """ + Objects were deleted during the last + poke interval. Updating the file counter and + resetting last_activity_time. + """ + ) + return False + + raise AirflowException( + """ + Illegal behavior: objects were deleted in {} between pokes. + """.format(os.path.join(self.bucket, self.prefix)) + ) + + if self.last_activity_time: + self.inactivity_seconds = (get_time() - self.last_activity_time).total_seconds() + else: + # Handles the first poke where last inactivity time is None. + self.last_activity_time = get_time() + self.inactivity_seconds = 0 + + if self.inactivity_seconds >= self.inactivity_period: + path = os.path.join(self.bucket, self.prefix) + + if current_num_objects >= self.min_objects: + self.log.info("""SUCCESS: + Sensor found %s objects at %s. 
+ Waited at least %s seconds, with no new objects dropped. + """, current_num_objects, path, self.inactivity_period) + return True + + self.log.warning("FAILURE: Inactivity Period passed, not enough objects found in %s", path) + + return False + return False + + def poke(self, context): + hook = GoogleCloudStorageHook() + return self.is_bucket_updated(len(hook.list(self.bucket, prefix=self.prefix))) diff --git a/airflow/gcp/sensors/pubsub.py b/airflow/gcp/sensors/pubsub.py index 16e921d29a6d9b..45af7c901c4850 100644 --- a/airflow/gcp/sensors/pubsub.py +++ b/airflow/gcp/sensors/pubsub.py @@ -19,6 +19,7 @@ """ This module contains a Google PubSub sensor. """ +from typing import Optional from airflow.gcp.hooks.pubsub import PubSubHook from airflow.sensors.base_sensor_operator import BaseSensorOperator @@ -68,15 +69,15 @@ class PubSubPullSensor(BaseSensorOperator): @apply_defaults def __init__( self, - project, - subscription, - max_messages=5, - return_immediately=False, - ack_messages=False, - gcp_conn_id='google_cloud_default', - delegate_to=None, + project: str, + subscription: str, + max_messages: int = 5, + return_immediately: bool = False, + ack_messages: bool = False, + gcp_conn_id: str = 'google_cloud_default', + delegate_to: Optional[str] = None, *args, - **kwargs): + **kwargs) -> None: super().__init__(*args, **kwargs) self.gcp_conn_id = gcp_conn_id diff --git a/airflow/gcp/utils/mlengine_operator_utils.py b/airflow/gcp/utils/mlengine_operator_utils.py index 66cdad8a171d4a..bc68e8deb2f671 100644 --- a/airflow/gcp/utils/mlengine_operator_utils.py +++ b/airflow/gcp/utils/mlengine_operator_utils.py @@ -1,18 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the 'License'); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 # -# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an 'AS IS' BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. """ This module contains helper functions for MLEngine operators. 
""" @@ -25,7 +28,7 @@ import dill -from airflow.contrib.hooks.gcs_hook import GoogleCloudStorageHook +from airflow.gcp.hooks.gcs import GoogleCloudStorageHook from airflow.gcp.operators.mlengine import MLEngineBatchPredictionOperator from airflow.gcp.operators.dataflow import DataFlowPythonOperator from airflow.exceptions import AirflowException @@ -223,6 +226,7 @@ def validate_err_and_count(summary): "metric_fn_encoded": metric_fn_encoded, "metric_keys": ','.join(metric_keys) }, + py_interpreter='python2', dag=dag) evaluate_summary.set_upstream(evaluate_prediction) @@ -240,7 +244,6 @@ def apply_validate_fn(*args, **kwargs): evaluate_validation = PythonOperator( task_id=(task_prefix + "-validation"), python_callable=apply_validate_fn, - provide_context=True, templates_dict={"prediction_path": prediction_path}, dag=dag) evaluate_validation.set_upstream(evaluate_summary) diff --git a/airflow/gcp/utils/mlengine_prediction_summary.py b/airflow/gcp/utils/mlengine_prediction_summary.py index beca4b49621246..1a0853160133cd 100644 --- a/airflow/gcp/utils/mlengine_prediction_summary.py +++ b/airflow/gcp/utils/mlengine_prediction_summary.py @@ -1,20 +1,21 @@ # flake8: noqa: F841 # -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the 'License'); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an 'AS IS' BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. """A template called by DataFlowPythonOperator to summarize BatchPrediction. It accepts a user function to calculate the metric(s) per instance in diff --git a/airflow/kubernetes/k8s_model.py b/airflow/kubernetes/k8s_model.py new file mode 100644 index 00000000000000..bca1cc0b9e1250 --- /dev/null +++ b/airflow/kubernetes/k8s_model.py @@ -0,0 +1,57 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Classes for interacting with Kubernetes API +""" + +from abc import ABC, abstractmethod +from typing import List, Optional +from functools import reduce +import kubernetes.client.models as k8s + + +class K8SModel(ABC): + """ + These Airflow Kubernetes models are here for backwards compatibility + reasons only. Ideally clients should use the kubernetes api + and the process of + + client input -> Airflow k8s models -> k8s models + + can be avoided. All of these models implement the + `attach_to_pod` method so that they integrate with the kubernetes client. + """ + @abstractmethod + def attach_to_pod(self, pod: k8s.V1Pod) -> k8s.V1Pod: + """ + :param pod: A pod to attach this Kubernetes object to + :type pod: kubernetes.client.models.V1Pod + :return: The pod with the object attached + """ + + +def append_to_pod(pod: k8s.V1Pod, k8s_objects: Optional[List[K8SModel]]): + """ + :param pod: A pod to attach a list of Kubernetes objects to + :type pod: kubernetes.client.models.V1Pod + :param k8s_objects: a potential None list of K8SModels + :type k8s_objects: Optional[List[K8SModel]] + :return: pod with the objects attached if they exist + """ + if not k8s_objects: + return pod + return reduce(lambda p, o: o.attach_to_pod(p), k8s_objects, pod) diff --git a/airflow/kubernetes/kube_client.py b/airflow/kubernetes/kube_client.py index 7a00be751f09f4..52d68f1324d4f6 100644 --- a/airflow/kubernetes/kube_client.py +++ b/airflow/kubernetes/kube_client.py @@ -14,11 +14,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +"""Client for kubernetes communication""" from airflow.configuration import conf try: from kubernetes import config, client - from kubernetes.client.rest import ApiException + from kubernetes.client.rest import ApiException # pylint: disable=unused-import has_kubernetes = True except ImportError as e: # We need an exception class to be able to use it in ``except`` elsewhere @@ -41,6 +42,18 @@ def _load_kube_config(in_cluster, cluster_context, config_file): def get_kube_client(in_cluster=conf.getboolean('kubernetes', 'in_cluster'), cluster_context=None, config_file=None): + """ + Retrieves Kubernetes client + + :param in_cluster: whether we are in cluster + :type in_cluster: bool + :param cluster_context: context of the cluster + :type cluster_context: str + :param config_file: configuration file + :type config_file: str + :return kubernetes client + :rtype client.CoreV1Api + """ if not in_cluster: if cluster_context is None: cluster_context = conf.get('kubernetes', 'cluster_context', fallback=None) diff --git a/airflow/kubernetes/kubernetes_request_factory/kubernetes_request_factory.py b/airflow/kubernetes/kubernetes_request_factory/kubernetes_request_factory.py deleted file mode 100644 index b3ca99360b2190..00000000000000 --- a/airflow/kubernetes/kubernetes_request_factory/kubernetes_request_factory.py +++ /dev/null @@ -1,256 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. 
See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -from abc import ABCMeta, abstractmethod - - -class KubernetesRequestFactory(metaclass=ABCMeta): - """ - Create requests to be sent to kube API. - Extend this class to talk to kubernetes and generate your specific resources. - This is equivalent of generating yaml files that can be used by `kubectl` - """ - - @abstractmethod - def create(self, pod): - """ - Creates the request for kubernetes API. - - :param pod: The pod object - """ - - @staticmethod - def extract_image(pod, req): - req['spec']['containers'][0]['image'] = pod.image - - @staticmethod - def extract_image_pull_policy(pod, req): - if pod.image_pull_policy: - req['spec']['containers'][0]['imagePullPolicy'] = pod.image_pull_policy - - @staticmethod - def add_secret_to_env(env, secret): - env.append({ - 'name': secret.deploy_target, - 'valueFrom': { - 'secretKeyRef': { - 'name': secret.secret, - 'key': secret.key - } - } - }) - - @staticmethod - def add_runtime_info_env(env, runtime_info): - env.append({ - 'name': runtime_info.name, - 'valueFrom': { - 'fieldRef': { - 'fieldPath': runtime_info.field_path - } - } - }) - - @staticmethod - def extract_labels(pod, req): - req['metadata']['labels'] = req['metadata'].get('labels', {}) - for k, v in pod.labels.items(): - req['metadata']['labels'][k] = v - - @staticmethod - def extract_annotations(pod, req): - req['metadata']['annotations'] = req['metadata'].get('annotations', {}) - for k, v in pod.annotations.items(): - req['metadata']['annotations'][k] = v - - @staticmethod - def extract_affinity(pod, req): - req['spec']['affinity'] = req['spec'].get('affinity', {}) - for k, v in pod.affinity.items(): - req['spec']['affinity'][k] = v - - @staticmethod - def extract_node_selector(pod, req): - req['spec']['nodeSelector'] = req['spec'].get('nodeSelector', {}) - for k, v in pod.node_selectors.items(): - req['spec']['nodeSelector'][k] = v - - @staticmethod - def extract_cmds(pod, req): - req['spec']['containers'][0]['command'] = pod.cmds - - @staticmethod - def extract_args(pod, req): - req['spec']['containers'][0]['args'] = pod.args - - @staticmethod - def attach_ports(pod, req): - req['spec']['containers'][0]['ports'] = ( - req['spec']['containers'][0].get('ports', [])) - if len(pod.ports) > 0: - req['spec']['containers'][0]['ports'].extend(pod.ports) - - @staticmethod - def attach_volumes(pod, req): - req['spec']['volumes'] = ( - req['spec'].get('volumes', [])) - if len(pod.volumes) > 0: - req['spec']['volumes'].extend(pod.volumes) - - @staticmethod - def attach_volume_mounts(pod, req): - if len(pod.volume_mounts) > 0: - req['spec']['containers'][0]['volumeMounts'] = ( - req['spec']['containers'][0].get('volumeMounts', [])) - req['spec']['containers'][0]['volumeMounts'].extend(pod.volume_mounts) - - @staticmethod - def extract_name(pod, req): - req['metadata']['name'] = pod.name - - @staticmethod - def 
extract_volume_secrets(pod, req): - vol_secrets = [s for s in pod.secrets if s.deploy_type == 'volume'] - if any(vol_secrets): - req['spec']['containers'][0]['volumeMounts'] = ( - req['spec']['containers'][0].get('volumeMounts', [])) - req['spec']['volumes'] = ( - req['spec'].get('volumes', [])) - for idx, vol in enumerate(vol_secrets): - vol_id = 'secretvol' + str(idx) - req['spec']['containers'][0]['volumeMounts'].append({ - 'mountPath': vol.deploy_target, - 'name': vol_id, - 'readOnly': True - }) - req['spec']['volumes'].append({ - 'name': vol_id, - 'secret': { - 'secretName': vol.secret - } - }) - - @staticmethod - def extract_env_and_secrets(pod, req): - envs_from_key_secrets = [ - env for env in pod.secrets if env.deploy_type == 'env' and env.key is not None - ] - - if len(pod.envs) > 0 or len(envs_from_key_secrets) > 0 or len(pod.pod_runtime_info_envs) > 0: - env = [] - for k in pod.envs.keys(): - env.append({'name': k, 'value': pod.envs[k]}) - for secret in envs_from_key_secrets: - KubernetesRequestFactory.add_secret_to_env(env, secret) - for runtime_info in pod.pod_runtime_info_envs: - KubernetesRequestFactory.add_runtime_info_env(env, runtime_info) - - req['spec']['containers'][0]['env'] = env - - KubernetesRequestFactory._apply_env_from(pod, req) - - @staticmethod - def extract_resources(pod, req): - if not pod.resources or pod.resources.is_empty_resource_request(): - return - - req['spec']['containers'][0]['resources'] = {} - - if pod.resources.has_requests(): - req['spec']['containers'][0]['resources']['requests'] = {} - if pod.resources.request_memory: - req['spec']['containers'][0]['resources']['requests'][ - 'memory'] = pod.resources.request_memory - if pod.resources.request_cpu: - req['spec']['containers'][0]['resources']['requests'][ - 'cpu'] = pod.resources.request_cpu - - if pod.resources.has_limits(): - req['spec']['containers'][0]['resources']['limits'] = {} - if pod.resources.limit_memory: - req['spec']['containers'][0]['resources']['limits'][ - 'memory'] = pod.resources.limit_memory - if pod.resources.limit_cpu: - req['spec']['containers'][0]['resources']['limits'][ - 'cpu'] = pod.resources.limit_cpu - if pod.resources.limit_gpu: - req['spec']['containers'][0]['resources']['limits'][ - 'nvidia.com/gpu'] = pod.resources.limit_gpu - - @staticmethod - def extract_init_containers(pod, req): - if pod.init_containers: - req['spec']['initContainers'] = pod.init_containers - - @staticmethod - def extract_service_account_name(pod, req): - if pod.service_account_name: - req['spec']['serviceAccountName'] = pod.service_account_name - - @staticmethod - def extract_hostnetwork(pod, req): - if pod.hostnetwork: - req['spec']['hostNetwork'] = pod.hostnetwork - - @staticmethod - def extract_dnspolicy(pod, req): - if pod.dnspolicy: - req['spec']['dnsPolicy'] = pod.dnspolicy - - @staticmethod - def extract_image_pull_secrets(pod, req): - if pod.image_pull_secrets: - req['spec']['imagePullSecrets'] = [{ - 'name': pull_secret - } for pull_secret in pod.image_pull_secrets.split(',')] - - @staticmethod - def extract_tolerations(pod, req): - if pod.tolerations: - req['spec']['tolerations'] = pod.tolerations - - @staticmethod - def extract_security_context(pod, req): - if pod.security_context: - req['spec']['securityContext'] = pod.security_context - - @staticmethod - def _apply_env_from(pod, req): - envs_from_secrets = [ - env for env in pod.secrets if env.deploy_type == 'env' and env.key is None - ] - - if pod.configmaps or envs_from_secrets: - req['spec']['containers'][0]['envFrom'] = [] 
- - for secret in envs_from_secrets: - req['spec']['containers'][0]['envFrom'].append( - { - 'secretRef': { - 'name': secret.secret - } - } - ) - - for configmap in pod.configmaps: - req['spec']['containers'][0]['envFrom'].append( - { - 'configMapRef': { - 'name': configmap - } - } - ) diff --git a/airflow/kubernetes/kubernetes_request_factory/pod_request_factory.py b/airflow/kubernetes/kubernetes_request_factory/pod_request_factory.py deleted file mode 100644 index f75a724304aef9..00000000000000 --- a/airflow/kubernetes/kubernetes_request_factory/pod_request_factory.py +++ /dev/null @@ -1,142 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -from typing import Dict - -import yaml -from airflow.kubernetes.pod import Pod -from airflow.kubernetes.kubernetes_request_factory.kubernetes_request_factory \ - import KubernetesRequestFactory - - -class SimplePodRequestFactory(KubernetesRequestFactory): - """ - Request generator for a pod. - - """ - _yaml = """apiVersion: v1 -kind: Pod -metadata: - name: name -spec: - containers: - - name: base - image: airflow-worker:latest - command: ["/usr/local/airflow/entrypoint.sh", "/bin/bash sleep 25"] - restartPolicy: Never - """ - - def __init__(self): - - pass - - def create(self, pod: Pod) -> Dict: - req = yaml.safe_load(self._yaml) - self.extract_name(pod, req) - self.extract_labels(pod, req) - self.extract_image(pod, req) - self.extract_image_pull_policy(pod, req) - self.extract_cmds(pod, req) - self.extract_args(pod, req) - self.extract_node_selector(pod, req) - self.extract_env_and_secrets(pod, req) - self.extract_volume_secrets(pod, req) - self.attach_ports(pod, req) - self.attach_volumes(pod, req) - self.attach_volume_mounts(pod, req) - self.extract_resources(pod, req) - self.extract_service_account_name(pod, req) - self.extract_init_containers(pod, req) - self.extract_image_pull_secrets(pod, req) - self.extract_annotations(pod, req) - self.extract_affinity(pod, req) - self.extract_hostnetwork(pod, req) - self.extract_tolerations(pod, req) - self.extract_security_context(pod, req) - self.extract_dnspolicy(pod, req) - return req - - -class ExtractXcomPodRequestFactory(KubernetesRequestFactory): - """ - Request generator for a pod with sidecar container. 
- - """ - - XCOM_MOUNT_PATH = '/airflow/xcom' - SIDECAR_CONTAINER_NAME = 'airflow-xcom-sidecar' - _yaml = """apiVersion: v1 -kind: Pod -metadata: - name: name -spec: - volumes: - - name: xcom - emptyDir: {{}} - containers: - - name: base - image: airflow-worker:latest - command: ["/usr/local/airflow/entrypoint.sh", "/bin/bash sleep 25"] - volumeMounts: - - name: xcom - mountPath: {xcomMountPath} - - name: {sidecarContainerName} - image: python:3.5-alpine - command: - - python - - -c - - | - import time - while True: - try: - time.sleep(3600) - except KeyboardInterrupt: - exit(0) - volumeMounts: - - name: xcom - mountPath: {xcomMountPath} - restartPolicy: Never - """.format(xcomMountPath=XCOM_MOUNT_PATH, sidecarContainerName=SIDECAR_CONTAINER_NAME) - - def __init__(self): - pass - - def create(self, pod: Pod) -> Dict: - req = yaml.safe_load(self._yaml) - self.extract_name(pod, req) - self.extract_labels(pod, req) - self.extract_image(pod, req) - self.extract_image_pull_policy(pod, req) - self.extract_cmds(pod, req) - self.extract_args(pod, req) - self.extract_node_selector(pod, req) - self.extract_env_and_secrets(pod, req) - self.extract_volume_secrets(pod, req) - self.attach_ports(pod, req) - self.attach_volumes(pod, req) - self.attach_volume_mounts(pod, req) - self.extract_resources(pod, req) - self.extract_service_account_name(pod, req) - self.extract_init_containers(pod, req) - self.extract_image_pull_secrets(pod, req) - self.extract_annotations(pod, req) - self.extract_affinity(pod, req) - self.extract_hostnetwork(pod, req) - self.extract_tolerations(pod, req) - self.extract_security_context(pod, req) - self.extract_dnspolicy(pod, req) - return req diff --git a/airflow/kubernetes/pod.py b/airflow/kubernetes/pod.py index b799c505f40855..bdce1ce50a042e 100644 --- a/airflow/kubernetes/pod.py +++ b/airflow/kubernetes/pod.py @@ -14,9 +14,30 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +""" +Classes for interacting with Kubernetes API +""" +import copy +import kubernetes.client.models as k8s +from airflow.kubernetes.k8s_model import K8SModel -class Resources: + +class Resources(K8SModel): + """ + Stores information about resources used by the Pod. 
+ + :param request_memory: requested memory + :type request_memory: str + :param request_cpu: requested CPU number + :type request_cpu: float | str + :param limit_memory: limit for memory usage + :type limit_memory: str + :param limit_cpu: Limit for CPU used + :type limit_cpu: float | str + :param limit_gpu: Limits for GPU used + :type limit_gpu: int + """ def __init__( self, request_memory=None, @@ -31,118 +52,50 @@ def __init__( self.limit_gpu = limit_gpu def is_empty_resource_request(self): + """Whether resource is empty""" return not self.has_limits() and not self.has_requests() def has_limits(self): + """Whether resource has limits""" return self.limit_cpu is not None or self.limit_memory is not None or self.limit_gpu is not None def has_requests(self): + """Whether resource has requests""" return self.request_cpu is not None or self.request_memory is not None - def __str__(self): - return "Request: [cpu: {}, memory: {}], Limit: [cpu: {}, memory: {}, gpu: {}]".format( - self.request_cpu, self.request_memory, self.limit_cpu, self.limit_memory, self.limit_gpu + def to_k8s_client_obj(self) -> k8s.V1ResourceRequirements: + """Converts to k8s client object""" + return k8s.V1ResourceRequirements( + limits={'cpu': self.limit_cpu, 'memory': self.limit_memory, 'nvidia.com/gpu': self.limit_gpu}, + requests={'cpu': self.request_cpu, 'memory': self.request_memory} ) + def attach_to_pod(self, pod: k8s.V1Pod) -> k8s.V1Pod: + """Attaches to pod""" + cp_pod = copy.deepcopy(pod) + resources = self.to_k8s_client_obj() + cp_pod.spec.containers[0].resources = resources + return cp_pod + -class Port: +class Port(K8SModel): + """POD port""" def __init__( self, name=None, container_port=None): + """Creates port""" self.name = name self.container_port = container_port + def to_k8s_client_obj(self) -> k8s.V1ContainerPort: + """Converts to k8s client object""" + return k8s.V1ContainerPort(name=self.name, container_port=self.container_port) -class Pod: - """ - Represents a kubernetes pod and manages execution of a single pod. - :param image: The docker image - :type image: str - :param envs: A dict containing the environment variables - :type envs: dict - :param cmds: The command to be run on the pod - :type cmds: list[str] - :param secrets: Secrets to be launched to the pod - :type secrets: list[airflow.contrib.kubernetes.secret.Secret] - :param result: The result that will be returned to the operator after - successful execution of the pod - :type result: any - :param image_pull_policy: Specify a policy to cache or always pull an image - :type image_pull_policy: str - :param image_pull_secrets: Any image pull secrets to be given to the pod. 
- If more than one secret is required, provide a comma separated list: - secret_a,secret_b - :type image_pull_secrets: str - :param affinity: A dict containing a group of affinity scheduling rules - :type affinity: dict - :param hostnetwork: If True enable host networking on the pod - :type hostnetwork: bool - :param tolerations: A list of kubernetes tolerations - :type tolerations: list - :param security_context: A dict containing the security context for the pod - :type security_context: dict - :param configmaps: A list containing names of configmaps object - mounting env variables to the pod - :type configmaps: list[str] - :param pod_runtime_info_envs: environment variables about - pod runtime information (ip, namespace, nodeName, podName) - :type pod_runtime_info_envs: list[PodRuntimeEnv] - :param dnspolicy: Specify a dnspolicy for the pod - :type dnspolicy: str - """ - def __init__( - self, - image, - envs, - cmds, - args=None, - secrets=None, - labels=None, - node_selectors=None, - name=None, - ports=None, - volumes=None, - volume_mounts=None, - namespace='default', - result=None, - image_pull_policy='IfNotPresent', - image_pull_secrets=None, - init_containers=None, - service_account_name=None, - resources=None, - annotations=None, - affinity=None, - hostnetwork=False, - tolerations=None, - security_context=None, - configmaps=None, - pod_runtime_info_envs=None, - dnspolicy=None - ): - self.image = image - self.envs = envs or {} - self.cmds = cmds - self.args = args or [] - self.secrets = secrets or [] - self.result = result - self.labels = labels or {} - self.name = name - self.ports = ports or [] - self.volumes = volumes or [] - self.volume_mounts = volume_mounts or [] - self.node_selectors = node_selectors or {} - self.namespace = namespace - self.image_pull_policy = image_pull_policy - self.image_pull_secrets = image_pull_secrets - self.init_containers = init_containers - self.service_account_name = service_account_name - self.resources = resources or Resources() - self.annotations = annotations or {} - self.affinity = affinity or {} - self.hostnetwork = hostnetwork or False - self.tolerations = tolerations or [] - self.security_context = security_context - self.configmaps = configmaps or [] - self.pod_runtime_info_envs = pod_runtime_info_envs or [] - self.dnspolicy = dnspolicy + def attach_to_pod(self, pod: k8s.V1Pod) -> k8s.V1Pod: + """Attaches to pod""" + cp_pod = copy.deepcopy(pod) + port = self.to_k8s_client_obj() + cp_pod.spec.containers[0].ports = cp_pod.spec.containers[0].ports or [] + cp_pod.spec.containers[0].ports.append(port) + return cp_pod diff --git a/airflow/kubernetes/pod_generator.py b/airflow/kubernetes/pod_generator.py index 5feb662036f952..d12b7525785df0 100644 --- a/airflow/kubernetes/pod_generator.py +++ b/airflow/kubernetes/pod_generator.py @@ -14,167 +14,337 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - -from airflow.kubernetes.pod import Pod, Port -from airflow.kubernetes.volume import Volume -from airflow.kubernetes.volume_mount import VolumeMount +""" +This module provides an interface between the previous Pod +API and outputs a kubernetes.client.models.V1Pod. +The advantage being that the full Kubernetes API +is supported and no serialization need be written. 
+""" + +import copy import uuid - -class PodGenerator: - """Contains Kubernetes Airflow Worker configuration logic""" - - def __init__(self, kube_config=None): - self.kube_config = kube_config - self.ports = [] - self.volumes = [] - self.volume_mounts = [] - self.init_containers = [] - - def add_init_container(self, - name, - image, - security_context, - init_environment, - volume_mounts - ): - """ - - Adds an init container to the launched pod. useful for pre- - - Args: - name (str): - image (str): - security_context (dict): - init_environment (dict): - volume_mounts (dict): - - Returns: - - """ - self.init_containers.append( - { - 'name': name, - 'image': image, - 'securityContext': security_context, - 'env': init_environment, - 'volumeMounts': volume_mounts +import kubernetes.client.models as k8s +from airflow.executors import Executors + + +class PodDefaults: + """ + Static defaults for the PodGenerator + """ + XCOM_MOUNT_PATH = '/airflow/xcom' + SIDECAR_CONTAINER_NAME = 'airflow-xcom-sidecar' + XCOM_CMD = 'trap "exit 0" INT; while true; do sleep 30; done;' + VOLUME_MOUNT = k8s.V1VolumeMount( + name='xcom', + mount_path=XCOM_MOUNT_PATH + ) + VOLUME = k8s.V1Volume( + name='xcom', + empty_dir=k8s.V1EmptyDirVolumeSource() + ) + SIDECAR_CONTAINER = k8s.V1Container( + name=SIDECAR_CONTAINER_NAME, + command=['sh', '-c', XCOM_CMD], + image='alpine', + volume_mounts=[VOLUME_MOUNT], + resources=k8s.V1ResourceRequirements( + requests={ + "cpu": "1m", } - ) - - def _get_init_containers(self): - return self.init_containers - - def add_port(self, port: Port): - """ - Adds a Port to the generator - - :param port: ports for generated pod - :type port: airflow.kubernetes.pod.Port - """ - self.ports.append({'name': port.name, 'containerPort': port.container_port}) - - def add_volume(self, volume: Volume): - """ - Adds a Volume to the generator - - :param volume: volume for generated pod - :type volume: airflow.kubernetes.volume.Volume - """ - - self._add_volume(name=volume.name, configs=volume.configs) - - def _add_volume(self, name, configs): - """ - - Args: - name (str): - configs (dict): Configurations for the volume. - Could be used to define PersistentVolumeClaim, ConfigMap, etc... + ), + ) - Returns: - """ - volume_map = {'name': name} - for k, v in configs.items(): - volume_map[k] = v - - self.volumes.append(volume_map) +class PodGenerator: + """ + Contains Kubernetes Airflow Worker configuration logic + + Represents a kubernetes pod and manages execution of a single pod. + :param image: The docker image + :type image: str + :param envs: A dict containing the environment variables + :type envs: Dict[str, str] + :param cmds: The command to be run on the pod + :type cmds: List[str] + :param secrets: Secrets to be launched to the pod + :type secrets: List[airflow.kubernetes.models.secret.Secret] + :param image_pull_policy: Specify a policy to cache or always pull an image + :type image_pull_policy: str + :param image_pull_secrets: Any image pull secrets to be given to the pod. 
+ If more than one secret is required, provide a comma separated list: + secret_a,secret_b + :type image_pull_secrets: str + :param affinity: A dict containing a group of affinity scheduling rules + :type affinity: dict + :param hostnetwork: If True enable host networking on the pod + :type hostnetwork: bool + :param tolerations: A list of kubernetes tolerations + :type tolerations: list + :param security_context: A dict containing the security context for the pod + :type security_context: dict + :param configmaps: Any configmap refs to envfrom. + If more than one configmap is required, provide a comma separated list + configmap_a,configmap_b + :type configmaps: str + :param dnspolicy: Specify a dnspolicy for the pod + :type dnspolicy: str + :param pod: The fully specified pod. + :type pod: kubernetes.client.models.V1Pod + """ + def __init__( # pylint: disable=too-many-arguments,too-many-locals + self, + image, + name=None, + namespace=None, + volume_mounts=None, + envs=None, + cmds=None, + args=None, + labels=None, + node_selectors=None, + ports=None, + volumes=None, + image_pull_policy='IfNotPresent', + restart_policy='Never', + image_pull_secrets=None, + init_containers=None, + service_account_name=None, + resources=None, + annotations=None, + affinity=None, + hostnetwork=False, + tolerations=None, + security_context=None, + configmaps=None, + dnspolicy=None, + pod=None, + extract_xcom=False, + ): + self.ud_pod = pod + self.pod = k8s.V1Pod() + self.pod.api_version = 'v1' + self.pod.kind = 'Pod' + + # Pod Metadata + self.metadata = k8s.V1ObjectMeta() + self.metadata.labels = labels + self.metadata.name = name + "-" + str(uuid.uuid4())[:8] if name else None + self.metadata.namespace = namespace + self.metadata.annotations = annotations + + # Pod Container + self.container = k8s.V1Container(name='base') + self.container.image = image + self.container.env = [] + + if envs: + if isinstance(envs, dict): + for key, val in envs.items(): + self.container.env.append(k8s.V1EnvVar( + name=key, + value=val + )) + elif isinstance(envs, list): + self.container.env.extend(envs) + + configmaps = configmaps or [] + self.container.env_from = [] + for configmap in configmaps: + self.container.env_from.append(k8s.V1EnvFromSource( + config_map_ref=k8s.V1ConfigMapEnvSource( + name=configmap + ) + )) + + self.container.command = cmds or [] + self.container.args = args or [] + self.container.image_pull_policy = image_pull_policy + self.container.ports = ports or [] + self.container.resources = resources + self.container.volume_mounts = volume_mounts or [] + + # Pod Spec + self.spec = k8s.V1PodSpec(containers=[]) + self.spec.security_context = security_context + self.spec.tolerations = tolerations + self.spec.dns_policy = dnspolicy + self.spec.host_network = hostnetwork + self.spec.affinity = affinity + self.spec.service_account_name = service_account_name + self.spec.init_containers = init_containers + self.spec.volumes = volumes or [] + self.spec.node_selector = node_selectors + self.spec.restart_policy = restart_policy + + self.spec.image_pull_secrets = [] + + if image_pull_secrets: + for image_pull_secret in image_pull_secrets.split(','): + self.spec.image_pull_secrets.append(k8s.V1LocalObjectReference( + name=image_pull_secret + )) + + # Attach sidecar + self.extract_xcom = extract_xcom + + def gen_pod(self) -> k8s.V1Pod: + """Generates pod""" + result = self.ud_pod + + if result is None: + result = self.pod + result.spec = self.spec + result.metadata = self.metadata + result.spec.containers = 
[self.container] + + if self.extract_xcom: + result = self.add_sidecar(result) + + return result + + @staticmethod + def add_sidecar(pod: k8s.V1Pod) -> k8s.V1Pod: + """Adds sidecar""" + pod_cp = copy.deepcopy(pod) + + pod_cp.spec.volumes.insert(0, PodDefaults.VOLUME) + pod_cp.spec.containers[0].volume_mounts.insert(0, PodDefaults.VOLUME_MOUNT) + pod_cp.spec.containers.append(PodDefaults.SIDECAR_CONTAINER) + + return pod_cp + + @staticmethod + def from_obj(obj) -> k8s.V1Pod: + """Converts to pod from obj""" + if obj is None: + return k8s.V1Pod() + + if isinstance(obj, PodGenerator): + return obj.gen_pod() + + if not isinstance(obj, dict): + raise TypeError( + 'Cannot convert a non-dictionary or non-PodGenerator ' + 'object into a KubernetesExecutorConfig') + + namespaced = obj.get(Executors.KubernetesExecutor, {}) + + resources = namespaced.get('resources') + + if resources is None: + requests = { + 'cpu': namespaced.get('request_cpu'), + 'memory': namespaced.get('request_memory') - def add_volume_with_configmap(self, name, config_map): - self.volumes.append( - { - 'name': name, - 'configMap': config_map } + limits = { + 'cpu': namespaced.get('limit_cpu'), + 'memory': namespaced.get('limit_memory') + } + all_resources = list(requests.values()) + list(limits.values()) + if all(r is None for r in all_resources): + resources = None + else: + resources = k8s.V1ResourceRequirements( + requests=requests, + limits=limits + ) + + annotations = namespaced.get('annotations', {}) + gcp_service_account_key = namespaced.get('gcp_service_account_key', None) + + if annotations is not None and gcp_service_account_key is not None: + annotations.update({ + 'iam.cloud.google.com/service-account': gcp_service_account_key + }) + + pod_spec_generator = PodGenerator( + image=namespaced.get('image'), + envs=namespaced.get('env'), + cmds=namespaced.get('cmds'), + args=namespaced.get('args'), + labels=namespaced.get('labels'), + node_selectors=namespaced.get('node_selectors'), + name=namespaced.get('name'), + ports=namespaced.get('ports'), + volumes=namespaced.get('volumes'), + volume_mounts=namespaced.get('volume_mounts'), + namespace=namespaced.get('namespace'), + image_pull_policy=namespaced.get('image_pull_policy'), + restart_policy=namespaced.get('restart_policy'), + image_pull_secrets=namespaced.get('image_pull_secrets'), + init_containers=namespaced.get('init_containers'), + service_account_name=namespaced.get('service_account_name'), + resources=resources, + annotations=namespaced.get('annotations'), + affinity=namespaced.get('affinity'), + hostnetwork=namespaced.get('hostnetwork'), + tolerations=namespaced.get('tolerations'), + security_context=namespaced.get('security_context'), + configmaps=namespaced.get('configmaps'), + dnspolicy=namespaced.get('dnspolicy'), + pod=namespaced.get('pod'), + extract_xcom=namespaced.get('extract_xcom'), ) - def _add_mount(self, - name, - mount_path, - sub_path, - read_only): - """ - - Args: - name (str): - mount_path (str): - sub_path (str): - read_only: - - Returns: + return pod_spec_generator.gen_pod() + @staticmethod + def reconcile_pods(base_pod: k8s.V1Pod, client_pod: k8s.V1Pod) -> k8s.V1Pod: """ - - self.volume_mounts.append({ - 'name': name, - 'mountPath': mount_path, - 'subPath': sub_path, - 'readOnly': read_only - }) - - def add_mount(self, - volume_mount: VolumeMount): + :param base_pod: has the base attributes which are overwritten if they exist + in the client pod and remain if they do not exist in the client_pod + :type base_pod: k8s.V1Pod + :param client_pod: 
the pod that the client wants to create. + :type client_pod: k8s.V1Pod + :return: the merged pods + + This can't be done recursively as certain fields are preserved, + some overwritten, and some concatenated, e.g. The command + should be preserved from base, the volumes appended to and + the other fields overwritten. """ - Adds a VolumeMount to the generator - :param volume_mount: volume for generated pod - :type volume_mount: airflow.kubernetes.volume_mount.VolumeMount - """ - self._add_mount( - name=volume_mount.name, - mount_path=volume_mount.mount_path, - sub_path=volume_mount.sub_path, - read_only=volume_mount.read_only - ) - - def _get_volumes_and_mounts(self): - return self.volumes, self.volume_mounts - - def _get_image_pull_secrets(self): - """Extracts any image pull secrets for fetching container(s)""" - if not self.kube_config.image_pull_secrets: - return [] - return self.kube_config.image_pull_secrets.split(',') - - def make_pod(self, namespace, image, pod_id, cmds, arguments, labels): - volumes, volume_mounts = self._get_volumes_and_mounts() - worker_init_container_spec = self._get_init_containers() - - return Pod( - namespace=namespace, - name=pod_id + "-" + str(uuid.uuid4())[:8], - image=image, - cmds=cmds, - args=arguments, - labels=labels, - envs={}, - secrets=[], - # service_account_name=self.kube_config.worker_service_account_name, - # image_pull_secrets=self.kube_config.image_pull_secrets, - init_containers=worker_init_container_spec, - ports=self.ports, - volumes=volumes, - volume_mounts=volume_mounts, - resources=None - ) + client_pod_cp = copy.deepcopy(client_pod) + + def merge_objects(base_obj, client_obj): + for base_key in base_obj.to_dict().keys(): + base_val = getattr(base_obj, base_key, None) + if not getattr(client_obj, base_key, None) and base_val: + setattr(client_obj, base_key, base_val) + + def extend_object_field(base_obj, client_obj, field_name): + base_obj_field = getattr(base_obj, field_name, None) + client_obj_field = getattr(client_obj, field_name, None) + if not base_obj_field: + return + if not client_obj_field: + setattr(client_obj, field_name, base_obj_field) + return + appended_fields = base_obj_field + client_obj_field + setattr(client_obj, field_name, appended_fields) + + # Values at the pod and metadata should be overwritten where they exist, + # but certain values at the spec and container level must be conserved. 
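+        # Concretely: a field that is set on the client pod (for example its
+        # image or labels) wins over the base pod, a field left unset falls back
+        # to the base pod's value, list fields such as volume_mounts, env,
+        # env_from, ports and volumes are concatenated from both pods, and the
+        # container command/args are always taken from the base container.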
+ base_container = base_pod.spec.containers[0] + client_container = client_pod_cp.spec.containers[0] + + extend_object_field(base_container, client_container, 'volume_mounts') + extend_object_field(base_container, client_container, 'env') + extend_object_field(base_container, client_container, 'env_from') + extend_object_field(base_container, client_container, 'ports') + extend_object_field(base_container, client_container, 'volume_devices') + client_container.command = base_container.command + client_container.args = base_container.args + merge_objects(base_pod.spec.containers[0], client_pod_cp.spec.containers[0]) + # Just append any additional containers from the base pod + client_pod_cp.spec.containers.extend(base_pod.spec.containers[1:]) + + merge_objects(base_pod.metadata, client_pod_cp.metadata) + + extend_object_field(base_pod.spec, client_pod_cp.spec, 'volumes') + merge_objects(base_pod.spec, client_pod_cp.spec) + merge_objects(base_pod, client_pod_cp) + + return client_pod_cp diff --git a/airflow/kubernetes/pod_launcher.py b/airflow/kubernetes/pod_launcher.py index 87bac43935324a..3e97112d6dc7cb 100644 --- a/airflow/kubernetes/pod_launcher.py +++ b/airflow/kubernetes/pod_launcher.py @@ -14,27 +14,32 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - +"""Launches PODs""" import json import time -import tenacity +from datetime import datetime as dt from typing import Tuple, Optional -from airflow.settings import pod_mutation_hook -from airflow.utils.log.logging_mixin import LoggingMixin -from airflow.utils.state import State -from datetime import datetime as dt -from airflow.kubernetes.pod import Pod -from airflow.kubernetes.kubernetes_request_factory import pod_request_factory as pod_factory +from requests.exceptions import BaseHTTPError + +import tenacity + from kubernetes import watch, client from kubernetes.client.rest import ApiException from kubernetes.stream import stream as kubernetes_stream +from kubernetes.client.models.v1_pod import V1Pod + +from airflow.settings import pod_mutation_hook +from airflow.utils.log.logging_mixin import LoggingMixin +from airflow.utils.state import State from airflow import AirflowException -from requests.exceptions import BaseHTTPError +from airflow.kubernetes.pod_generator import PodDefaults + from .kube_client import get_kube_client class PodStatus: + """Status of the PODs""" PENDING = 'pending' RUNNING = 'running' FAILED = 'failed' @@ -42,33 +47,49 @@ class PodStatus: class PodLauncher(LoggingMixin): - def __init__(self, kube_client=None, in_cluster=True, cluster_context=None, - extract_xcom=False): + """Launches PODS""" + def __init__(self, + kube_client: client.CoreV1Api = None, + in_cluster: bool = True, + cluster_context: str = None, + extract_xcom: bool = False): + """ + Creates the launcher. 
+ + :param kube_client: kubernetes client + :param in_cluster: whether we are in cluster + :param cluster_context: context of the cluster + :param extract_xcom: whether we should extract xcom + """ super().__init__() self._client = kube_client or get_kube_client(in_cluster=in_cluster, cluster_context=cluster_context) self._watch = watch.Watch() self.extract_xcom = extract_xcom - self.kube_req_factory = pod_factory.ExtractXcomPodRequestFactory( - ) if extract_xcom else pod_factory.SimplePodRequestFactory() - def run_pod_async(self, pod, **kwargs): + def run_pod_async(self, pod: V1Pod, **kwargs): + """Runs POD asynchronously""" pod_mutation_hook(pod) - req = self.kube_req_factory.create(pod) - self.log.debug('Pod Creation Request: \n%s', json.dumps(req, indent=2)) + sanitized_pod = self._client.api_client.sanitize_for_serialization(pod) + json_pod = json.dumps(sanitized_pod, indent=2) + + self.log.debug('Pod Creation Request: \n%s', json_pod) try: - resp = self._client.create_namespaced_pod(body=req, namespace=pod.namespace, **kwargs) + resp = self._client.create_namespaced_pod(body=sanitized_pod, + namespace=pod.metadata.namespace, **kwargs) self.log.debug('Pod Creation Response: %s', resp) - except ApiException: - self.log.exception('Exception when attempting to create Namespaced Pod.') - raise + except Exception as e: + self.log.exception('Exception when attempting ' + 'to create Namespaced Pod: %s', json_pod) + raise e return resp - def delete_pod(self, pod): + def delete_pod(self, pod: V1Pod): + """Deletes POD""" try: self._client.delete_namespaced_pod( - pod.name, pod.namespace, body=client.V1DeleteOptions()) + pod.metadata.name, pod.metadata.namespace, body=client.V1DeleteOptions()) except ApiException as e: # If the pod is already deleted if e.status != 404: @@ -76,7 +97,7 @@ def delete_pod(self, pod): def run_pod( self, - pod: Pod, + pod: V1Pod, startup_timeout: int = 120, get_logs: bool = True) -> Tuple[State, Optional[str]]: """ @@ -99,7 +120,7 @@ def run_pod( return self._monitor_pod(pod, get_logs) - def _monitor_pod(self, pod: Pod, get_logs: bool) -> Tuple[State, Optional[str]]: + def _monitor_pod(self, pod: V1Pod, get_logs: bool) -> Tuple[State, Optional[str]]: if get_logs: logs = self.read_pod_logs(pod) for line in logs: @@ -107,13 +128,13 @@ def _monitor_pod(self, pod: Pod, get_logs: bool) -> Tuple[State, Optional[str]]: result = None if self.extract_xcom: while self.base_container_is_running(pod): - self.log.info('Container %s has state %s', pod.name, State.RUNNING) + self.log.info('Container %s has state %s', pod.metadata.name, State.RUNNING) time.sleep(2) result = self._extract_xcom(pod) self.log.info(result) result = json.loads(result) while self.pod_is_running(pod): - self.log.info('Pod %s has state %s', pod.name, State.RUNNING) + self.log.info('Pod %s has state %s', pod.metadata.name, State.RUNNING) time.sleep(2) return self._task_status(self.read_pod(pod)), result @@ -124,18 +145,23 @@ def _task_status(self, event): status = self.process_status(event.metadata.name, event.status.phase) return status - def pod_not_started(self, pod): + def pod_not_started(self, pod: V1Pod): + """Tests if pod has not started""" state = self._task_status(self.read_pod(pod)) return state == State.QUEUED - def pod_is_running(self, pod): + def pod_is_running(self, pod: V1Pod): + """Tests if pod is running""" state = self._task_status(self.read_pod(pod)) - return state != State.SUCCESS and state != State.FAILED + return state not in (State.SUCCESS, State.FAILED) - def 
base_container_is_running(self, pod): + def base_container_is_running(self, pod: V1Pod): + """Tests if base container is running""" event = self.read_pod(pod) status = next(iter(filter(lambda s: s.name == 'base', event.status.container_statuses)), None) + if not status: + return False return status.state.running is not None @tenacity.retry( @@ -143,12 +169,12 @@ def base_container_is_running(self, pod): wait=tenacity.wait_exponential(), reraise=True ) - def read_pod_logs(self, pod): - + def read_pod_logs(self, pod: V1Pod): + """Reads log from the POD""" try: return self._client.read_namespaced_pod_log( - name=pod.name, - namespace=pod.namespace, + name=pod.metadata.name, + namespace=pod.metadata.namespace, container='base', follow=True, tail_lines=10, @@ -164,29 +190,30 @@ def read_pod_logs(self, pod): wait=tenacity.wait_exponential(), reraise=True ) - def read_pod(self, pod): + def read_pod(self, pod: V1Pod): + """Read POD information""" try: - return self._client.read_namespaced_pod(pod.name, pod.namespace) + return self._client.read_namespaced_pod(pod.metadata.name, pod.metadata.namespace) except BaseHTTPError as e: raise AirflowException( 'There was an error reading the kubernetes API: {}'.format(e) ) - def _extract_xcom(self, pod): + def _extract_xcom(self, pod: V1Pod): resp = kubernetes_stream(self._client.connect_get_namespaced_pod_exec, - pod.name, pod.namespace, - container=self.kube_req_factory.SIDECAR_CONTAINER_NAME, + pod.metadata.name, pod.metadata.namespace, + container=PodDefaults.SIDECAR_CONTAINER_NAME, command=['/bin/sh'], stdin=True, stdout=True, stderr=True, tty=False, _preload_content=False) try: result = self._exec_pod_command( - resp, 'cat {}/return.json'.format(self.kube_req_factory.XCOM_MOUNT_PATH)) + resp, 'cat {}/return.json'.format(PodDefaults.XCOM_MOUNT_PATH)) self._exec_pod_command(resp, 'kill -s SIGINT 1') finally: resp.close() if result is None: - raise AirflowException('Failed to extract xcom from pod: {}'.format(pod.name)) + raise AirflowException('Failed to extract xcom from pod: {}'.format(pod.metadata.name)) return result def _exec_pod_command(self, resp, command): @@ -200,8 +227,10 @@ def _exec_pod_command(self, resp, command): if resp.peek_stderr(): self.log.info(resp.read_stderr()) break + return None def process_status(self, job_id, status): + """Process status infomration for the JOB""" status = status.lower() if status == PodStatus.PENDING: return State.QUEUED diff --git a/airflow/kubernetes/pod_runtime_info_env.py b/airflow/kubernetes/pod_runtime_info_env.py index f52791ed43c918..bf1320dc431716 100644 --- a/airflow/kubernetes/pod_runtime_info_env.py +++ b/airflow/kubernetes/pod_runtime_info_env.py @@ -15,11 +15,15 @@ # specific language governing permissions and limitations # under the License. 
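
A minimal sketch of how the xcom pieces above fit together; the image, pod name and inline script are illustrative, and a reachable Kubernetes cluster is assumed. With ``extract_xcom=True`` the generator appends the ``airflow-xcom-sidecar`` container and its shared volume; the launcher waits for the base container to finish, reads ``/airflow/xcom/return.json`` through the sidecar, then sends it SIGINT, which the sidecar's ``trap "exit 0" INT`` turns into a clean exit.

    from airflow.kubernetes.pod_generator import PodDefaults, PodGenerator
    from airflow.kubernetes.pod_launcher import PodLauncher

    # The base container writes its result where the sidecar volume is mounted.
    write_result = (
        "import json; json.dump({'answer': 42}, "
        "open('%s/return.json', 'w'))" % PodDefaults.XCOM_MOUNT_PATH
    )

    pod = PodGenerator(
        image='python:3.7-alpine',   # illustrative image
        name='xcom-demo',
        namespace='default',
        cmds=['python', '-c'],
        args=[write_result],
        extract_xcom=True,           # adds the xcom sidecar and volume
    ).gen_pod()

    launcher = PodLauncher(extract_xcom=True)
    state, result = launcher.run_pod(pod, get_logs=False)
    # result is the JSON written by the base container, already decoded
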
""" -Classes for using Kubernetes Downward API +Classes for interacting with Kubernetes API """ +import copy +import kubernetes.client.models as k8s +from airflow.kubernetes.k8s_model import K8SModel -class PodRuntimeInfoEnv: + +class PodRuntimeInfoEnv(K8SModel): """Defines Pod runtime information as environment variable""" def __init__(self, name, field_path): @@ -34,3 +38,23 @@ def __init__(self, name, field_path): """ self.name = name self.field_path = field_path + + def to_k8s_client_obj(self) -> k8s.V1EnvVar: + """ + :return: kubernetes.client.models.V1EnvVar + """ + return k8s.V1EnvVar( + name=self.name, + value_from=k8s.V1EnvVarSource( + field_ref=k8s.V1ObjectFieldSelector( + self.field_path + ) + ) + ) + + def attach_to_pod(self, pod: k8s.V1Pod) -> k8s.V1Pod: + cp_pod = copy.deepcopy(pod) + env = self.to_k8s_client_obj() + cp_pod.spec.containers[0].env = cp_pod.spec.containers[0].env or [] + cp_pod.spec.containers[0].env.append(env) + return cp_pod diff --git a/airflow/kubernetes/secret.py b/airflow/kubernetes/secret.py index cc96915e3c5a8f..3d33739cbd6729 100644 --- a/airflow/kubernetes/secret.py +++ b/airflow/kubernetes/secret.py @@ -14,10 +14,19 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +""" +Classes for interacting with Kubernetes API +""" + +import uuid +import copy +from typing import Tuple +import kubernetes.client.models as k8s from airflow.exceptions import AirflowConfigException +from airflow.kubernetes.k8s_model import K8SModel -class Secret: +class Secret(K8SModel): """Defines Kubernetes Secret Volume""" def __init__(self, deploy_type, deploy_target, secret, key=None): @@ -36,6 +45,9 @@ def __init__(self, deploy_type, deploy_target, secret, key=None): if not provided in `deploy_type` `env` it will mount all secrets in object :type key: str or None """ + if deploy_type not in ('env', 'volume'): + raise AirflowConfigException("deploy_type must be env or volume") + self.deploy_type = deploy_type self.deploy_target = deploy_target @@ -51,6 +63,60 @@ def __init__(self, deploy_type, deploy_target, secret, key=None): self.secret = secret self.key = key + def to_env_secret(self) -> k8s.V1EnvVar: + """Stores es environment secret""" + return k8s.V1EnvVar( + name=self.deploy_target, + value_from=k8s.V1EnvVarSource( + secret_key_ref=k8s.V1SecretKeySelector( + name=self.secret, + key=self.key + ) + ) + ) + + def to_env_from_secret(self) -> k8s.V1EnvFromSource: + """Reads from environment to secret""" + return k8s.V1EnvFromSource( + secret_ref=k8s.V1SecretEnvSource(name=self.secret) + ) + + def to_volume_secret(self) -> Tuple[k8s.V1Volume, k8s.V1VolumeMount]: + """Converts to volume secret""" + vol_id = 'secretvol{}'.format(uuid.uuid4()) + return ( + k8s.V1Volume( + name=vol_id, + secret=k8s.V1SecretVolumeSource( + secret_name=self.secret + ) + ), + k8s.V1VolumeMount( + mount_path=self.deploy_target, + name=vol_id, + read_only=True + ) + ) + + def attach_to_pod(self, pod: k8s.V1Pod) -> k8s.V1Pod: + """Attaches to pod""" + cp_pod = copy.deepcopy(pod) + if self.deploy_type == 'volume': + volume, volume_mount = self.to_volume_secret() + cp_pod.spec.volumes = pod.spec.volumes or [] + cp_pod.spec.volumes.append(volume) + cp_pod.spec.containers[0].volume_mounts = pod.spec.containers[0].volume_mounts or [] + cp_pod.spec.containers[0].volume_mounts.append(volume_mount) + if self.deploy_type == 'env' and self.key is not None: + env = self.to_env_secret() + cp_pod.spec.containers[0].env = 
cp_pod.spec.containers[0].env or [] + cp_pod.spec.containers[0].env.append(env) + if self.deploy_type == 'env' and self.key is None: + env_from = self.to_env_from_secret() + cp_pod.spec.containers[0].env_from = cp_pod.spec.containers[0].env_from or [] + cp_pod.spec.containers[0].env_from.append(env_from) + return cp_pod + def __eq__(self, other): return ( self.deploy_type == other.deploy_type and diff --git a/airflow/kubernetes/volume.py b/airflow/kubernetes/volume.py index 94003fe48dcb39..679671cc87b335 100644 --- a/airflow/kubernetes/volume.py +++ b/airflow/kubernetes/volume.py @@ -14,20 +14,42 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +""" +Classes for interacting with Kubernetes API +""" +import copy +from typing import Dict +import kubernetes.client.models as k8s +from airflow.kubernetes.k8s_model import K8SModel -class Volume: - """Defines Kubernetes Volume""" - def __init__(self, name, configs): - """ Adds Kubernetes Volume to pod. allows pod to access features like ConfigMaps - and Persistent Volumes - :param name: the name of the volume mount - :type name: str - :param configs: dictionary of any features needed for volume. +class Volume(K8SModel): + """ + Adds Kubernetes Volume to pod. allows pod to access features like ConfigMaps + and Persistent Volumes + + :param name: the name of the volume mount + :type name: str + :param configs: dictionary of any features needed for volume. We purposely keep this vague since there are multiple volume types with changing configs. - :type configs: dict - """ + :type configs: dict + """ + def __init__(self, name, configs): self.name = name self.configs = configs + + def to_k8s_client_obj(self) -> Dict[str, str]: + """Converts to k8s object""" + return { + 'name': self.name, + **self.configs + } + + def attach_to_pod(self, pod: k8s.V1Pod) -> k8s.V1Pod: + cp_pod = copy.deepcopy(pod) + volume = self.to_k8s_client_obj() + cp_pod.spec.volumes = pod.spec.volumes or [] + cp_pod.spec.volumes.append(volume) + return cp_pod diff --git a/airflow/kubernetes/volume_mount.py b/airflow/kubernetes/volume_mount.py index 4bdf09c07c0e57..3f7b2b5bef5ec9 100644 --- a/airflow/kubernetes/volume_mount.py +++ b/airflow/kubernetes/volume_mount.py @@ -14,24 +14,58 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +""" +Classes for interacting with Kubernetes API +""" +import copy +import kubernetes.client.models as k8s +from airflow.kubernetes.k8s_model import K8SModel -class VolumeMount: - """Defines Kubernetes Volume Mount""" +class VolumeMount(K8SModel): + """ + Initialize a Kubernetes Volume Mount. Used to mount pod level volumes to + running container. + + :param name: the name of the volume mount + :type name: str + :param mount_path: + :type mount_path: str + :param sub_path: subpath within the volume mount + :type sub_path: str + :param read_only: whether to access pod with read-only mode + :type read_only: bool + """ def __init__(self, name, mount_path, sub_path, read_only): - """Initialize a Kubernetes Volume Mount. Used to mount pod level volumes to - running container. 
- :param name: the name of the volume mount - :type name: str - :param mount_path: - :type mount_path: str - :param sub_path: subpath within the volume mount - :type sub_path: str - :param read_only: whether to access pod with read-only mode - :type read_only: bool - """ self.name = name self.mount_path = mount_path self.sub_path = sub_path self.read_only = read_only + + def to_k8s_client_obj(self) -> k8s.V1VolumeMount: + """ + Converts to k8s object. + + :return Volume Mount k8s object + + """ + return k8s.V1VolumeMount( + name=self.name, + mount_path=self.mount_path, + sub_path=self.sub_path, + read_only=self.read_only + ) + + def attach_to_pod(self, pod: k8s.V1Pod) -> k8s.V1Pod: + """ + Attaches to pod + + :return Copy of the Pod object + + """ + cp_pod = copy.deepcopy(pod) + volume_mount = self.to_k8s_client_obj() + cp_pod.spec.containers[0].volume_mounts = pod.spec.containers[0].volume_mounts or [] + cp_pod.spec.containers[0].volume_mounts.append(volume_mount) + return cp_pod diff --git a/airflow/kubernetes/worker_configuration.py b/airflow/kubernetes/worker_configuration.py index 396abbcbc8df33..90e36f206ac983 100644 --- a/airflow/kubernetes/worker_configuration.py +++ b/airflow/kubernetes/worker_configuration.py @@ -14,13 +14,16 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - +"""Configuration of the worker""" import os +from typing import List, Dict +import kubernetes.client.models as k8s from airflow.configuration import conf -from airflow.kubernetes.pod import Pod, Resources -from airflow.kubernetes.secret import Secret +from airflow.kubernetes.pod_generator import PodGenerator from airflow.utils.log.logging_mixin import LoggingMixin +from airflow.kubernetes.secret import Secret +from airflow.kubernetes.k8s_model import append_to_pod class WorkerConfiguration(LoggingMixin): @@ -38,10 +41,9 @@ def __init__(self, kube_config): self.worker_airflow_home = self.kube_config.airflow_home self.worker_airflow_dags = self.kube_config.dags_folder self.worker_airflow_logs = self.kube_config.base_log_folder - super().__init__() - def _get_init_containers(self): + def _get_init_containers(self) -> List[k8s.V1Container]: """When using git to retrieve the DAGs, use the GitSync Init Container""" # If we're using volume claims to mount the dags, no init container is needed if self.kube_config.dags_volume_claim or \ @@ -49,115 +51,114 @@ def _get_init_containers(self): return [] # Otherwise, define a git-sync init container - init_environment = [{ - 'name': 'GIT_SYNC_REPO', - 'value': self.kube_config.git_repo - }, { - 'name': 'GIT_SYNC_BRANCH', - 'value': self.kube_config.git_branch - }, { - 'name': 'GIT_SYNC_ROOT', - 'value': self.kube_config.git_sync_root - }, { - 'name': 'GIT_SYNC_DEST', - 'value': self.kube_config.git_sync_dest - }, { - 'name': 'GIT_SYNC_DEPTH', - 'value': '1' - }, { - 'name': 'GIT_SYNC_ONE_TIME', - 'value': 'true' - }] + init_environment = [k8s.V1EnvVar( + name='GIT_SYNC_REPO', + value=self.kube_config.git_repo + ), k8s.V1EnvVar( + name='GIT_SYNC_BRANCH', + value=self.kube_config.git_branch + ), k8s.V1EnvVar( + name='GIT_SYNC_ROOT', + value=self.kube_config.git_sync_root + ), k8s.V1EnvVar( + name='GIT_SYNC_DEST', + value=self.kube_config.git_sync_dest + ), k8s.V1EnvVar( + name='GIT_SYNC_DEPTH', + value='1' + ), k8s.V1EnvVar( + name='GIT_SYNC_ONE_TIME', + value='true' + )] if self.kube_config.git_user: - init_environment.append({ - 'name': 'GIT_SYNC_USERNAME', - 'value': 
self.kube_config.git_user - }) + init_environment.append(k8s.V1EnvVar( + name='GIT_SYNC_USERNAME', + value=self.kube_config.git_user + )) if self.kube_config.git_password: - init_environment.append({ - 'name': 'GIT_SYNC_PASSWORD', - 'value': self.kube_config.git_password - }) + init_environment.append(k8s.V1EnvVar( + name='GIT_SYNC_PASSWORD', + value=self.kube_config.git_password + )) + + volume_mounts = [k8s.V1VolumeMount( + mount_path=self.kube_config.git_sync_root, + name=self.dags_volume_name, + read_only=False + )] if self.kube_config.git_sync_credentials_secret: init_environment.extend([ - { - 'name': 'GIT_SYNC_USERNAME', - 'valueFrom': { - 'secretKeyRef': { - 'name': self.kube_config.git_sync_credentials_secret, - 'key': 'GIT_SYNC_USERNAME' - } - } - }, - { - 'name': 'GIT_SYNC_PASSWORD', - 'valueFrom': { - 'secretKeyRef': { - 'name': self.kube_config.git_sync_credentials_secret, - 'key': 'GIT_SYNC_PASSWORD' - } - } - } + k8s.V1EnvVar( + name='GIT_SYNC_USERNAME', + value_from=k8s.V1EnvVarSource( + secret_key_ref=k8s.V1SecretKeySelector( + name=self.kube_config.git_sync_credentials_secret, + key='GIT_SYNC_USERNAME') + ) + ), + k8s.V1EnvVar( + name='GIT_SYNC_PASSWORD', + value_from=k8s.V1EnvVarSource( + secret_key_ref=k8s.V1SecretKeySelector( + name=self.kube_config.git_sync_credentials_secret, + key='GIT_SYNC_PASSWORD') + ) + ) ]) - volume_mounts = [{ - 'mountPath': self.kube_config.git_sync_root, - 'name': self.dags_volume_name, - 'readOnly': False - }] if self.kube_config.git_ssh_key_secret_name: - volume_mounts.append({ - 'name': self.git_sync_ssh_secret_volume_name, - 'mountPath': '/etc/git-secret/ssh', - 'subPath': 'ssh' - }) - init_environment.extend([ - { - 'name': 'GIT_SSH_KEY_FILE', - 'value': '/etc/git-secret/ssh' - }, - { - 'name': 'GIT_SYNC_SSH', - 'value': 'true' - }]) - if self.kube_config.git_ssh_known_hosts_configmap_name: - volume_mounts.append({ - 'name': self.git_sync_ssh_known_hosts_volume_name, - 'mountPath': '/etc/git-secret/known_hosts', - 'subPath': 'known_hosts' - }) + volume_mounts.append(k8s.V1VolumeMount( + name=self.git_sync_ssh_secret_volume_name, + mount_path='/etc/git-secret/ssh', + sub_path='ssh' + )) + init_environment.extend([ - { - 'name': 'GIT_KNOWN_HOSTS', - 'value': 'true' - }, - { - 'name': 'GIT_SSH_KNOWN_HOSTS_FILE', - 'value': '/etc/git-secret/known_hosts' - } + k8s.V1EnvVar( + name='GIT_SSH_KEY_FILE', + value='/etc/git-secret/ssh' + ), + k8s.V1EnvVar( + name='GIT_SYNC_SSH', + value='true' + ) ]) + + if self.kube_config.git_ssh_known_hosts_configmap_name: + volume_mounts.append(k8s.V1VolumeMount( + name=self.git_sync_ssh_known_hosts_volume_name, + mount_path='/etc/git-secret/known_hosts', + sub_path='known_hosts' + )) + init_environment.extend([k8s.V1EnvVar( + name='GIT_KNOWN_HOSTS', + value='true' + ), k8s.V1EnvVar( + name='GIT_SSH_KNOWN_HOSTS_FILE', + value='/etc/git-secret/known_hosts' + )]) else: - init_environment.append({ - 'name': 'GIT_KNOWN_HOSTS', - 'value': 'false' - }) - - init_containers = [{ - 'name': self.kube_config.git_sync_init_container_name, - 'image': self.kube_config.git_sync_container, - 'env': init_environment, - 'volumeMounts': volume_mounts - }] + init_environment.append(k8s.V1EnvVar( + name='GIT_KNOWN_HOSTS', + value='false' + )) + + init_containers = k8s.V1Container( + name=self.kube_config.git_sync_init_container_name, + image=self.kube_config.git_sync_container, + env=init_environment, + volume_mounts=volume_mounts + ) if self.kube_config.git_sync_run_as_user != "": - init_containers[0]['securityContext'] = { - 
'runAsUser': self.kube_config.git_sync_run_as_user # git-sync user - } + init_containers.security_context = k8s.V1SecurityContext( + run_as_user=self.kube_config.git_sync_run_as_user or 65533 + ) # git-sync user - return init_containers + return [init_containers] - def _get_environment(self): + def _get_environment(self) -> Dict[str, str]: """Defines any necessary environment variables for the pod executor""" env = {} @@ -182,11 +183,23 @@ def _get_environment(self): env['AIRFLOW__CORE__DAGS_FOLDER'] = dag_volume_mount_path return env - def _get_configmaps(self): + def _get_env_from(self) -> List[k8s.V1EnvFromSource]: """Extracts any configmapRefs to envFrom""" - if not self.kube_config.env_from_configmap_ref: - return [] - return self.kube_config.env_from_configmap_ref.split(',') + env_from = [] + + if self.kube_config.env_from_configmap_ref: + for config_map_ref in self.kube_config.env_from_configmap_ref.split(','): + env_from.append( + k8s.V1EnvFromSource(config_map_ref=k8s.V1ConfigMapEnvSource(config_map_ref)) + ) + + if self.kube_config.env_from_secret_ref: + for secret_ref in self.kube_config.env_from_secret_ref.split(','): + env_from.append( + k8s.V1EnvFromSource(secret_ref=k8s.V1SecretEnvSource(secret_ref)) + ) + + return env_from def _get_secrets(self): """Defines any necessary secrets for the pod executor""" @@ -206,50 +219,87 @@ def _get_secrets(self): return worker_secrets - def _get_image_pull_secrets(self): + def _get_image_pull_secrets(self) -> List[k8s.V1LocalObjectReference]: """Extracts any image pull secrets for fetching container(s)""" if not self.kube_config.image_pull_secrets: return [] - return self.kube_config.image_pull_secrets.split(',') + pull_secrets = self.kube_config.image_pull_secrets.split(',') + return list(map(k8s.V1LocalObjectReference, pull_secrets)) - def _get_security_context(self): + def _get_security_context(self) -> k8s.V1PodSecurityContext: """Defines the security context""" - security_context = {} + + security_context = k8s.V1PodSecurityContext() if self.kube_config.worker_run_as_user != "": - security_context['runAsUser'] = self.kube_config.worker_run_as_user + security_context.run_as_user = self.kube_config.worker_run_as_user if self.kube_config.worker_fs_group != "": - security_context['fsGroup'] = self.kube_config.worker_fs_group + security_context.fs_group = self.kube_config.worker_fs_group # set fs_group to 65533 if not explicitly specified and using git ssh keypair auth - if self.kube_config.git_ssh_key_secret_name and security_context.get('fsGroup') is None: - security_context['fsGroup'] = 65533 + if self.kube_config.git_ssh_key_secret_name and security_context.fs_group is None: + security_context.fs_group = 65533 return security_context - def _get_labels(self, kube_executor_labels, labels): + def _get_labels(self, kube_executor_labels, labels) -> k8s.V1LabelSelector: copy = self.kube_config.kube_labels.copy() copy.update(kube_executor_labels) copy.update(labels) return copy - def _get_volumes_and_mounts(self): - def _construct_volume(name, claim, host): - volume = { - 'name': name - } + def _get_volume_mounts(self) -> List[k8s.V1VolumeMount]: + volume_mounts = { + self.dags_volume_name: k8s.V1VolumeMount( + name=self.dags_volume_name, + mount_path=self.generate_dag_volume_mount_path(), + read_only=True, + ), + self.logs_volume_name: k8s.V1VolumeMount( + name=self.logs_volume_name, + mount_path=self.worker_airflow_logs, + ) + } + + if self.kube_config.dags_volume_subpath: + volume_mounts[self.dags_volume_name].sub_path = 
self.kube_config.dags_volume_subpath + + if self.kube_config.logs_volume_subpath: + volume_mounts[self.logs_volume_name].sub_path = self.kube_config.logs_volume_subpath + + if self.kube_config.dags_in_image: + del volume_mounts[self.dags_volume_name] + + # Mount the airflow.cfg file via a configmap the user has specified + if self.kube_config.airflow_configmap: + config_volume_name = 'airflow-config' + config_path = '{}/airflow.cfg'.format(self.worker_airflow_home) + volume_mounts[config_volume_name] = k8s.V1VolumeMount( + name=config_volume_name, + mount_path=config_path, + sub_path='airflow.cfg', + read_only=True + ) + + return list(volume_mounts.values()) + + def _get_volumes(self) -> List[k8s.V1Volume]: + def _construct_volume(name, claim, host) -> k8s.V1Volume: + volume = k8s.V1Volume(name=name) + if claim: - volume['persistentVolumeClaim'] = { - 'claimName': claim - } + volume.persistent_volume_claim = k8s.V1PersistentVolumeClaimVolumeSource( + claim_name=claim + ) elif host: - volume['hostPath'] = { - 'path': host, - 'type': '' - } + volume.host_path = k8s.V1HostPathVolumeSource( + path=host, + type='' + ) else: - volume['emptyDir'] = {} + volume.empty_dir = {} + return volume volumes = { @@ -265,127 +315,81 @@ def _construct_volume(name, claim, host): ) } - volume_mounts = { - self.dags_volume_name: { - 'name': self.dags_volume_name, - 'mountPath': self.generate_dag_volume_mount_path(), - 'readOnly': True, - }, - self.logs_volume_name: { - 'name': self.logs_volume_name, - 'mountPath': self.worker_airflow_logs, - } - } - - if self.kube_config.dags_volume_subpath: - volume_mounts[self.dags_volume_name]['subPath'] = self.kube_config.dags_volume_subpath - - if self.kube_config.logs_volume_subpath: - volume_mounts[self.logs_volume_name]['subPath'] = self.kube_config.logs_volume_subpath - if self.kube_config.dags_in_image: del volumes[self.dags_volume_name] - del volume_mounts[self.dags_volume_name] # Get the SSH key from secrets as a volume if self.kube_config.git_ssh_key_secret_name: - volumes[self.git_sync_ssh_secret_volume_name] = { - 'name': self.git_sync_ssh_secret_volume_name, - 'secret': { - 'secretName': self.kube_config.git_ssh_key_secret_name, - 'items': [{ - 'key': self.git_ssh_key_secret_key, - 'path': 'ssh', - 'mode': 0o440 - }] - } - } + volumes[self.git_sync_ssh_secret_volume_name] = k8s.V1Volume( + name=self.git_sync_ssh_secret_volume_name, + secret=k8s.V1SecretVolumeSource( + secret_name=self.kube_config.git_ssh_key_secret_name, + items=[k8s.V1KeyToPath( + key=self.git_ssh_key_secret_key, + path='ssh', + mode=0o440 + )] + ) + ) if self.kube_config.git_ssh_known_hosts_configmap_name: - volumes[self.git_sync_ssh_known_hosts_volume_name] = { - 'name': self.git_sync_ssh_known_hosts_volume_name, - 'configMap': { - 'name': self.kube_config.git_ssh_known_hosts_configmap_name - }, - 'mode': 0o440 - } + volumes[self.git_sync_ssh_known_hosts_volume_name] = k8s.V1Volume( + name=self.git_sync_ssh_known_hosts_volume_name, + config_map=k8s.V1ConfigMapVolumeSource( + name=self.kube_config.git_ssh_known_hosts_configmap_name, + default_mode=0o440 + ) + ) # Mount the airflow.cfg file via a configmap the user has specified if self.kube_config.airflow_configmap: config_volume_name = 'airflow-config' - config_path = '{}/airflow.cfg'.format(self.worker_airflow_home) - volumes[config_volume_name] = { - 'name': config_volume_name, - 'configMap': { - 'name': self.kube_config.airflow_configmap - } - } - volume_mounts[config_volume_name] = { - 'name': config_volume_name, - 'mountPath': 
config_path, - 'subPath': 'airflow.cfg', - 'readOnly': True - } - - return volumes, volume_mounts - - def generate_dag_volume_mount_path(self): - if self.kube_config.dags_volume_claim or self.kube_config.dags_volume_host: - dag_volume_mount_path = self.worker_airflow_dags - else: - dag_volume_mount_path = self.kube_config.git_dags_folder_mount_point - - return dag_volume_mount_path + volumes[config_volume_name] = k8s.V1Volume( + name=config_volume_name, + config_map=k8s.V1ConfigMapVolumeSource( + name=self.kube_config.airflow_configmap + ) + ) - def make_pod(self, namespace, worker_uuid, pod_id, dag_id, task_id, execution_date, - try_number, airflow_command, kube_executor_config): - volumes_dict, volume_mounts_dict = self._get_volumes_and_mounts() - worker_init_container_spec = self._get_init_containers() - resources = Resources( - request_memory=kube_executor_config.request_memory, - request_cpu=kube_executor_config.request_cpu, - limit_memory=kube_executor_config.limit_memory, - limit_cpu=kube_executor_config.limit_cpu, - limit_gpu=kube_executor_config.limit_gpu - ) - gcp_sa_key = kube_executor_config.gcp_service_account_key - annotations = dict(kube_executor_config.annotations) or self.kube_config.kube_annotations - if gcp_sa_key: - annotations['iam.cloud.google.com/service-account'] = gcp_sa_key + return list(volumes.values()) - volumes = [value for value in volumes_dict.values()] + kube_executor_config.volumes - volume_mounts = [value for value in volume_mounts_dict.values()] + kube_executor_config.volume_mounts + def generate_dag_volume_mount_path(self) -> str: + """Generate path for DAG volume""" + if self.kube_config.dags_volume_claim or self.kube_config.dags_volume_host: + return self.worker_airflow_dags - affinity = kube_executor_config.affinity or self.kube_config.kube_affinity - tolerations = kube_executor_config.tolerations or self.kube_config.kube_tolerations + return self.kube_config.git_dags_folder_mount_point - return Pod( + def make_pod(self, namespace, worker_uuid, pod_id, dag_id, task_id, execution_date, + try_number, airflow_command) -> k8s.V1Pod: + """Creates POD.""" + pod_generator = PodGenerator( namespace=namespace, name=pod_id, - image=kube_executor_config.image or self.kube_config.kube_image, - image_pull_policy=(kube_executor_config.image_pull_policy or - self.kube_config.kube_image_pull_policy), - cmds=airflow_command, - labels=self._get_labels(kube_executor_config.labels, { + image=self.kube_config.kube_image, + image_pull_policy=self.kube_config.kube_image_pull_policy, + labels={ 'airflow-worker': worker_uuid, 'dag_id': dag_id, 'task_id': task_id, 'execution_date': execution_date, 'try_number': str(try_number), - }), + }, + cmds=airflow_command, + volumes=self._get_volumes(), + volume_mounts=self._get_volume_mounts(), + init_containers=self._get_init_containers(), + annotations=self.kube_config.kube_annotations, + affinity=self.kube_config.kube_affinity, + tolerations=self.kube_config.kube_tolerations, envs=self._get_environment(), - secrets=self._get_secrets(), + node_selectors=self.kube_config.kube_node_selectors, service_account_name=self.kube_config.worker_service_account_name, - image_pull_secrets=self.kube_config.image_pull_secrets, - init_containers=worker_init_container_spec, - volumes=volumes, - volume_mounts=volume_mounts, - resources=resources, - annotations=annotations, - node_selectors=(kube_executor_config.node_selectors or - self.kube_config.kube_node_selectors), - affinity=affinity, - tolerations=tolerations, - 
security_context=self._get_security_context(), - configmaps=self._get_configmaps() ) + + pod = pod_generator.gen_pod() + pod.spec.containers[0].env_from = pod.spec.containers[0].env_from or [] + pod.spec.containers[0].env_from.extend(self._get_env_from()) + pod.spec.security_context = self._get_security_context() + + return append_to_pod(pod, self._get_secrets()) diff --git a/airflow/lineage/backend/__init__.py b/airflow/lineage/backend/__init__.py index 243b86973f3242..53a9b6751d70f9 100644 --- a/airflow/lineage/backend/__init__.py +++ b/airflow/lineage/backend/__init__.py @@ -16,9 +16,11 @@ # specific language governing permissions and limitations # under the License. # +"""Sends lineage metadata to a backend""" class LineageBackend: + """Sends lineage metadata to a backend""" def send_lineage(self, operator=None, inlets=None, outlets=None, context=None): """ diff --git a/airflow/macros/hive.py b/airflow/macros/hive.py index 914b9af2c4b1e2..5a0df1597c874f 100644 --- a/airflow/macros/hive.py +++ b/airflow/macros/hive.py @@ -93,6 +93,10 @@ def closest_ds_partition( :type ds: list[datetime.date] :param before: closest before (True), after (False) or either side of ds :type before: bool or None + :param schema: table schema + :type schema: str + :param metastore_conn_id: which matastore connection to use + :type metastore_conn_id: str :returns: The closest date :rtype: str or None diff --git a/airflow/migrations/versions/33ae817a1ff4_add_kubernetes_resource_checkpointing.py b/airflow/migrations/versions/33ae817a1ff4_add_kubernetes_resource_checkpointing.py index b77e49d6574f19..c1e62a2cc2a5f8 100644 --- a/airflow/migrations/versions/33ae817a1ff4_add_kubernetes_resource_checkpointing.py +++ b/airflow/migrations/versions/33ae817a1ff4_add_kubernetes_resource_checkpointing.py @@ -44,11 +44,11 @@ def upgrade(): conn = op.get_bind() # alembic creates an invalid SQL for mssql and mysql dialects - if conn.dialect.name in ("mysql"): + if conn.dialect.name in {"mysql"}: columns_and_constraints.append( sa.CheckConstraint("one_row_id<>0", name="kube_resource_version_one_row_id") ) - elif conn.dialect.name not in ("mssql"): + elif conn.dialect.name not in {"mssql"}: columns_and_constraints.append( sa.CheckConstraint("one_row_id", name="kube_resource_version_one_row_id") ) diff --git a/airflow/migrations/versions/86770d1215c0_add_kubernetes_scheduler_uniqueness.py b/airflow/migrations/versions/86770d1215c0_add_kubernetes_scheduler_uniqueness.py index 7e9e7fecaf34b5..86ad47131e4a1f 100644 --- a/airflow/migrations/versions/86770d1215c0_add_kubernetes_scheduler_uniqueness.py +++ b/airflow/migrations/versions/86770d1215c0_add_kubernetes_scheduler_uniqueness.py @@ -45,11 +45,11 @@ def upgrade(): conn = op.get_bind() # alembic creates an invalid SQL for mssql and mysql dialects - if conn.dialect.name in ("mysql"): + if conn.dialect.name in {"mysql"}: columns_and_constraints.append( sa.CheckConstraint("one_row_id<>0", name="kube_worker_one_row_id") ) - elif conn.dialect.name not in ("mssql"): + elif conn.dialect.name not in {"mssql"}: columns_and_constraints.append( sa.CheckConstraint("one_row_id", name="kube_worker_one_row_id") ) diff --git a/airflow/migrations/versions/dd25f486b8ea_add_idx_log_dag.py b/airflow/migrations/versions/dd25f486b8ea_add_idx_log_dag.py index 3249a2e0589cb3..560b763963dd6c 100644 --- a/airflow/migrations/versions/dd25f486b8ea_add_idx_log_dag.py +++ b/airflow/migrations/versions/dd25f486b8ea_add_idx_log_dag.py @@ -15,9 +15,6 @@ # KIND, either express or implied. 
See the License for the # specific language governing permissions and limitations # under the License. - -from alembic import op - """add idx_log_dag Revision ID: dd25f486b8ea @@ -25,6 +22,7 @@ Create Date: 2018-08-07 06:41:41.028249 """ +from alembic import op # revision identifiers, used by Alembic. revision = 'dd25f486b8ea' diff --git a/airflow/models/__init__.py b/airflow/models/__init__.py index 43b13f596c54bd..2dc56116233dc5 100644 --- a/airflow/models/__init__.py +++ b/airflow/models/__init__.py @@ -16,7 +16,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - +"""Airflow models""" from airflow.models.base import Base, ID_LEN # noqa: F401 from airflow.models.baseoperator import BaseOperator # noqa: F401 from airflow.models.connection import Connection # noqa: F401 diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 951cab41513f79..bdac7ab90dec11 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -16,16 +16,21 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - +""" +Base operator for all operators. +""" from abc import ABCMeta, abstractmethod -from cached_property import cached_property import copy import functools import logging import sys import warnings from datetime import timedelta, datetime -from typing import Callable, Dict, Iterable, List, Optional, Set, Any +from typing import Callable, Dict, Iterable, List, Optional, Set, Any, Union + +from dateutil.relativedelta import relativedelta + +from cached_property import cached_property import jinja2 @@ -49,7 +54,10 @@ from airflow.utils.trigger_rule import TriggerRule from airflow.utils.weight_rule import WeightRule +ScheduleInterval = Union[str, timedelta, relativedelta] + +# pylint: disable=too-many-instance-attributes,too-many-public-methods @functools.total_ordering class BaseOperator(LoggingMixin): """ @@ -221,14 +229,14 @@ class derived from this one results in the creation of a task object, # Defines which files extensions to look for in the templated fields template_ext = [] # type: Iterable[str] # Defines the color in the UI - ui_color = '#fff' - ui_fgcolor = '#000' + ui_color = '#fff' # type str + ui_fgcolor = '#000' # type str # base list which includes all the attrs that don't need deep copy. _base_operator_shallow_copy_attrs = ('user_defined_macros', 'user_defined_filters', 'params', - '_log',) + '_log',) # type: Iterable[str] # each operator should override this class attr for shallow copy attrs. 
shallow_copy_attrs = () # type: Iterable[str] @@ -246,7 +254,6 @@ class derived from this one results in the creation of a task object, 'retry_exponential_backoff', 'max_retry_delay', 'start_date', - 'schedule_interval', 'depends_on_past', 'wait_for_downstream', 'priority_weight', @@ -258,6 +265,8 @@ class derived from this one results in the creation of a task object, 'do_xcom_push', } + # noinspection PyUnusedLocal + # pylint: disable=too-many-arguments,too-many-locals @apply_defaults def __init__( self, @@ -266,18 +275,17 @@ def __init__( email: Optional[str] = None, email_on_retry: bool = True, email_on_failure: bool = True, - retries: int = None, + retries: Optional[int] = None, retry_delay: timedelta = timedelta(seconds=300), retry_exponential_backoff: bool = False, max_retry_delay: Optional[datetime] = None, start_date: Optional[datetime] = None, end_date: Optional[datetime] = None, - schedule_interval=None, # not hooked as of now depends_on_past: bool = False, wait_for_downstream: bool = False, dag: Optional[DAG] = None, params: Optional[Dict] = None, - default_args: Optional[Dict] = None, + default_args: Optional[Dict] = None, # pylint: disable=unused-argument priority_weight: int = 1, weight_rule: str = WeightRule.DOWNSTREAM, queue: str = conf.get('celery', 'default_queue'), @@ -340,14 +348,6 @@ def __init__( if wait_for_downstream: self.depends_on_past = True - if schedule_interval: - self.log.warning( - "schedule_interval is used for %s, though it has " - "been deprecated as a task parameter, you need to " - "specify it as a DAG parameter instead", - self - ) - self._schedule_interval = schedule_interval self.retries = retries if retries is not None else \ conf.getint('core', 'default_task_retries', fallback=0) self.queue = queue @@ -414,7 +414,7 @@ def __init__( self._outlets.update(outlets) def __eq__(self, other): - if (type(self) == type(other) and + if (type(self) == type(other) and # pylint: disable=unidiomatic-typecheck self.task_id == other.task_id): return all(self.__dict__.get(c, None) == other.__dict__.get(c, None) for c in self._comps) return False @@ -427,8 +427,8 @@ def __lt__(self, other): def __hash__(self): hash_components = [type(self)] - for c in self._comps: - val = getattr(self, c, None) + for component in self._comps: + val = getattr(self, component, None) try: hash(val) hash_components.append(val) @@ -512,7 +512,7 @@ def dag(self, dag): elif self.task_id not in dag.task_dict: dag.add_task(self) - self._dag = dag + self._dag = dag # pylint: disable=attribute-defined-outside-init def has_dag(self): """ @@ -522,6 +522,7 @@ def has_dag(self): @property def dag_id(self): + """Returns dag id if it has one or an adhoc + owner""" if self.has_dag(): return self.dag.dag_id else: @@ -541,19 +542,16 @@ def deps(self): } @property - def schedule_interval(self): - """ - The schedule interval of the DAG always wins over individual tasks so - that tasks within a DAG always line up. The task still needs a - schedule_interval as it may not be attached to a DAG. + def priority_weight_total(self): """ - if self.has_dag(): - return self.dag._schedule_interval - else: - return self._schedule_interval + Total priority weight for the task. It might include all upstream or downstream tasks. + depending on the weight rule. 
- @property - def priority_weight_total(self): + - WeightRule.ABSOLUTE - only own weight + - WeightRule.DOWNSTREAM - adds priority weight of all downstream tasks + - WeightRule.UPSTREAM - adds priority weight of all upstream tasks + + """ if self.weight_rule == WeightRule.ABSOLUTE: return self.priority_weight elif self.weight_rule == WeightRule.DOWNSTREAM: @@ -570,10 +568,12 @@ def priority_weight_total(self): @cached_property def operator_extra_link_dict(self): + """Returns dictionary of all extra links for the operator""" return {link.name: link for link in self.operator_extra_links} @cached_property def global_operator_extra_link_dict(self): + """Returns dictionary of all global extra links""" from airflow.plugins_manager import global_operator_extra_links return {link.name: link for link in global_operator_extra_links} @@ -618,7 +618,9 @@ def __deepcopy__(self, memo): result = cls.__new__(cls) memo[id(self)] = result - shallow_copy = cls.shallow_copy_attrs + cls._base_operator_shallow_copy_attrs + # noinspection PyProtectedMember + shallow_copy = cls.shallow_copy_attrs + \ + cls._base_operator_shallow_copy_attrs # pylint: disable=protected-access for k, v in self.__dict__.items(): if k not in shallow_copy: @@ -634,7 +636,7 @@ def __getstate__(self): return state def __setstate__(self, state): - self.__dict__ = state + self.__dict__ = state # pylint: disable=attribute-defined-outside-init self._log = logging.getLogger("airflow.task.operators") def render_template_fields(self, context: Dict, jinja_env: Optional[jinja2.Environment] = None) -> None: @@ -656,7 +658,7 @@ def render_template_fields(self, context: Dict, jinja_env: Optional[jinja2.Envir rendered_content = self.render_template(content, context, jinja_env) setattr(self, attr_name, rendered_content) - def render_template( + def render_template( # pylint: disable=too-many-return-statements self, content: Any, context: Dict, jinja_env: Optional[jinja2.Environment] = None ) -> Any: """ @@ -684,7 +686,7 @@ def render_template( return jinja_env.from_string(content).render(**context) if isinstance(content, tuple): - if type(content) is not tuple: + if type(content) is not tuple: # pylint: disable=unidiomatic-typecheck # Special case for named tuples return content.__class__( *(self.render_template(element, context, jinja_env) for element in content) @@ -717,8 +719,8 @@ def prepare_template(self): """ def resolve_template_files(self): - # Getting the content of files for template_field / template_ext - if self.template_ext: + """Getting the content of files for template_field / template_ext""" + if self.template_ext: # pylint: disable=too-many-nested-blocks for attr in self.template_fields: content = getattr(self, attr, None) if content is None: @@ -728,16 +730,16 @@ def resolve_template_files(self): env = self.get_template_env() try: setattr(self, attr, env.loader.get_source(env, content)[0]) - except Exception as e: + except Exception as e: # pylint: disable=broad-except self.log.exception(e) elif isinstance(content, list): env = self.dag.get_template_env() - for i in range(len(content)): + for i in range(len(content)): # pylint: disable=consider-using-enumerate if isinstance(content[i], str) and \ any([content[i].endswith(ext) for ext in self.template_ext]): try: content[i] = env.loader.get_source(env, content[i])[0] - except Exception as e: + except Exception as e: # pylint: disable=broad-except self.log.exception(e) self.prepare_template() @@ -748,6 +750,7 @@ def upstream_list(self): @property def upstream_task_ids(self): + 
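# A short sketch of the weight rules described in ``priority_weight_total`` above
# (dag and task ids are illustrative only). With WeightRule.DOWNSTREAM the total
# weight of ``extract`` also counts ``load`` (1 + 2 = 3); with WeightRule.ABSOLUTE
# ``load`` keeps only its own weight of 2.
from airflow import DAG
from airflow.operators.dummy_operator import DummyOperator
from airflow.utils.dates import days_ago
from airflow.utils.weight_rule import WeightRule

with DAG('weight_rule_example', start_date=days_ago(1), schedule_interval=None) as dag:
    extract = DummyOperator(task_id='extract', priority_weight=1,
                            weight_rule=WeightRule.DOWNSTREAM)
    load = DummyOperator(task_id='load', priority_weight=2,
                         weight_rule=WeightRule.ABSOLUTE)
    extract >> load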
"""@property: list of ids of tasks directly upstream""" return self._upstream_task_ids @property @@ -757,6 +760,7 @@ def downstream_list(self): @property def downstream_task_ids(self): + """@property: list of ids of tasks directly downstream""" return self._downstream_task_ids @provide_session @@ -851,14 +855,15 @@ def run( start_date = start_date or self.start_date end_date = end_date or self.end_date or timezone.utcnow() - for dt in self.dag.date_range(start_date, end_date=end_date): - TaskInstance(self, dt).run( + for execution_date in self.dag.date_range(start_date, end_date=end_date): + TaskInstance(self, execution_date).run( mark_success=mark_success, ignore_depends_on_past=( - dt == start_date and ignore_first_depends_on_past), + execution_date == start_date and ignore_first_depends_on_past), ignore_ti_state=ignore_ti_state) def dry_run(self): + """Performs dry run for the operator - just render template fields.""" self.log.info('Dry run') for attr in self.template_fields: content = getattr(self, attr) @@ -892,31 +897,35 @@ def __repr__(self): @property def task_type(self): + """@property: type of the task""" return self.__class__.__name__ def add_only_new(self, item_set, item): + """Adds only new items to item set""" if item in item_set: self.log.warning( - 'Dependency {self}, {item} already registered' - ''.format(self=self, item=item)) + 'Dependency %s, %s already registered', self, item) else: item_set.add(item) def _set_relatives(self, task_or_task_list, upstream=False): + """Sets relatives for the task.""" try: task_list = list(task_or_task_list) except TypeError: task_list = [task_or_task_list] - for t in task_list: - if not isinstance(t, BaseOperator): + for task in task_list: + if not isinstance(task, BaseOperator): raise AirflowException( "Relationships can only be set between " - "Operators; received {}".format(t.__class__.__name__)) + "Operators; received {}".format(task.__class__.__name__)) # relationships can only be set if the tasks share a single DAG. Tasks # without a DAG are assigned to that DAG. - dags = {t._dag.dag_id: t._dag for t in [self] + task_list if t.has_dag()} + dags = { + task._dag.dag_id: task._dag # pylint: disable=protected-access + for task in [self] + task_list if task.has_dag()} if len(dags) > 1: raise AirflowException( @@ -989,6 +998,7 @@ def xcom_pull( @cached_property def extra_links(self) -> Iterable[str]: + """@property: extra links for the task. 
""" return list(set(self.operator_extra_link_dict.keys()) .union(self.global_operator_extra_link_dict.keys())) @@ -1007,6 +1017,8 @@ def get_extra_links(self, dttm, link_name): return self.operator_extra_link_dict[link_name].get_link(self, dttm) elif link_name in self.global_operator_extra_link_dict: return self.global_operator_extra_link_dict[link_name].get_link(self, dttm) + else: + return None class BaseOperatorLink(metaclass=ABCMeta): diff --git a/airflow/models/connection.py b/airflow/models/connection.py index 29051e499be377..00f14875a8ed59 100644 --- a/airflow/models/connection.py +++ b/airflow/models/connection.py @@ -204,14 +204,14 @@ def get_hook(self): from airflow.hooks.mysql_hook import MySqlHook return MySqlHook(mysql_conn_id=self.conn_id) elif self.conn_type == 'google_cloud_platform': - from airflow.contrib.hooks.bigquery_hook import BigQueryHook + from airflow.gcp.hooks.bigquery import BigQueryHook return BigQueryHook(bigquery_conn_id=self.conn_id) elif self.conn_type == 'postgres': from airflow.hooks.postgres_hook import PostgresHook return PostgresHook(postgres_conn_id=self.conn_id) elif self.conn_type == 'pig_cli': from airflow.hooks.pig_hook import PigCliHook - return PigCliHook(pig_conn_id=self.conn_id) + return PigCliHook(pig_cli_conn_id=self.conn_id) elif self.conn_type == 'hive_cli': from airflow.hooks.hive_hooks import HiveCliHook return HiveCliHook(hive_cli_conn_id=self.conn_id) diff --git a/airflow/models/taskfail.py b/airflow/models/taskfail.py index ad0861a0485a4a..3fb08b3cee0009 100644 --- a/airflow/models/taskfail.py +++ b/airflow/models/taskfail.py @@ -16,7 +16,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - +"""Taskfail tracks the failed run durations of each task instance""" from sqlalchemy import Column, Index, Integer, String from airflow.models.base import Base, ID_LEN diff --git a/airflow/models/taskreschedule.py b/airflow/models/taskreschedule.py index 4bdb6431f1ae41..ebca6ee4e4ee6f 100644 --- a/airflow/models/taskreschedule.py +++ b/airflow/models/taskreschedule.py @@ -16,7 +16,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
- +"""TaskReschedule tracks rescheduled task instances.""" from sqlalchemy import Column, ForeignKeyConstraint, Index, Integer, String, asc from airflow.models.base import Base, ID_LEN diff --git a/airflow/operators/adls_to_gcs.py b/airflow/operators/adls_to_gcs.py index b5a58e6377e00b..b24d14305c1e54 100644 --- a/airflow/operators/adls_to_gcs.py +++ b/airflow/operators/adls_to_gcs.py @@ -23,10 +23,11 @@ import os import warnings from tempfile import NamedTemporaryFile +from typing import Optional from airflow.contrib.hooks.azure_data_lake_hook import AzureDataLakeHook from airflow.contrib.operators.adls_list_operator import AzureDataLakeStorageListOperator -from airflow.contrib.hooks.gcs_hook import GoogleCloudStorageHook, _parse_gcs_url +from airflow.gcp.hooks.gcs import GoogleCloudStorageHook, _parse_gcs_url from airflow.utils.decorators import apply_defaults @@ -99,16 +100,16 @@ class AdlsToGoogleCloudStorageOperator(AzureDataLakeStorageListOperator): @apply_defaults def __init__(self, - src_adls, - dest_gcs, - azure_data_lake_conn_id, - gcp_conn_id='google_cloud_default', - google_cloud_storage_conn_id=None, - delegate_to=None, - replace=False, - gzip=False, + src_adls: str, + dest_gcs: str, + azure_data_lake_conn_id: str, + gcp_conn_id: str = 'google_cloud_default', + google_cloud_storage_conn_id: Optional[str] = None, + delegate_to: Optional[str] = None, + replace: bool = False, + gzip: bool = False, *args, - **kwargs): + **kwargs) -> None: super().__init__( path=src_adls, diff --git a/airflow/operators/bigquery_to_bigquery.py b/airflow/operators/bigquery_to_bigquery.py new file mode 100644 index 00000000000000..5a5366647fef61 --- /dev/null +++ b/airflow/operators/bigquery_to_bigquery.py @@ -0,0 +1,126 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +This module contains Google BigQuery to BigQuery operator. +""" +import warnings +from typing import List, Optional, Union, Dict + +from airflow.models.baseoperator import BaseOperator +from airflow.utils.decorators import apply_defaults +from airflow.gcp.hooks.bigquery import BigQueryHook + + +class BigQueryToBigQueryOperator(BaseOperator): + """ + Copies data from one BigQuery table to another. + + .. seealso:: + For more details about these parameters: + https://cloud.google.com/bigquery/docs/reference/v2/jobs#configuration.copy + + :param source_project_dataset_tables: One or more + dotted ``(project:|project.).
`` BigQuery tables to use as the + source data. If ```` is not included, project will be the + project defined in the connection json. Use a list if there are multiple + source tables. (templated) + :type source_project_dataset_tables: list|string + :param destination_project_dataset_table: The destination BigQuery + table. Format is: ``(project:|project.).
`` (templated) + :type destination_project_dataset_table: str + :param write_disposition: The write disposition if the table already exists. + :type write_disposition: str + :param create_disposition: The create disposition if the table doesn't exist. + :type create_disposition: str + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud Platform. + :type gcp_conn_id: str + :param bigquery_conn_id: (Deprecated) The connection ID used to connect to Google Cloud Platform. + This parameter has been deprecated. You should pass the gcp_conn_id parameter instead. + :type bigquery_conn_id: str + :param delegate_to: The account to impersonate, if any. + For this to work, the service account making the request must have domain-wide + delegation enabled. + :type delegate_to: str + :param labels: a dictionary containing labels for the job/query, + passed to BigQuery + :type labels: dict + :param encryption_configuration: [Optional] Custom encryption configuration (e.g., Cloud KMS keys). + **Example**: :: + + encryption_configuration = { + "kmsKeyName": "projects/testp/locations/us/keyRings/test-kr/cryptoKeys/test-key" + } + :type encryption_configuration: dict + :param location: The location used for the operation. + :type location: str + """ + template_fields = ('source_project_dataset_tables', + 'destination_project_dataset_table', 'labels') + template_ext = ('.sql',) + ui_color = '#e6f0e4' + + @apply_defaults + def __init__(self, # pylint:disable=too-many-arguments + source_project_dataset_tables: Union[List[str], str], + destination_project_dataset_table: str, + write_disposition: str = 'WRITE_EMPTY', + create_disposition: str = 'CREATE_IF_NEEDED', + gcp_conn_id: str = 'google_cloud_default', + bigquery_conn_id: Optional[str] = None, + delegate_to: Optional[str] = None, + labels: Optional[Dict] = None, + encryption_configuration: Optional[Dict] = None, + location: str = None, + *args, + **kwargs) -> None: + super().__init__(*args, **kwargs) + + if bigquery_conn_id: + warnings.warn( + "The bigquery_conn_id parameter has been deprecated. 
You should pass " + "the gcp_conn_id parameter.", DeprecationWarning, stacklevel=3) + gcp_conn_id = bigquery_conn_id + + self.source_project_dataset_tables = source_project_dataset_tables + self.destination_project_dataset_table = destination_project_dataset_table + self.write_disposition = write_disposition + self.create_disposition = create_disposition + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.labels = labels + self.encryption_configuration = encryption_configuration + self.location = location + + def execute(self, context): + self.log.info( + 'Executing copy of %s into: %s', + self.source_project_dataset_tables, self.destination_project_dataset_table + ) + hook = BigQueryHook(bigquery_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + location=self.location) + conn = hook.get_conn() + cursor = conn.cursor() + cursor.run_copy( + source_project_dataset_tables=self.source_project_dataset_tables, + destination_project_dataset_table=self.destination_project_dataset_table, + write_disposition=self.write_disposition, + create_disposition=self.create_disposition, + labels=self.labels, + encryption_configuration=self.encryption_configuration) diff --git a/airflow/operators/bigquery_to_gcs.py b/airflow/operators/bigquery_to_gcs.py new file mode 100644 index 00000000000000..bbc386788e1047 --- /dev/null +++ b/airflow/operators/bigquery_to_gcs.py @@ -0,0 +1,126 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +This module contains Google BigQuery to Google CLoud Storage operator. +""" +import warnings +from typing import List, Optional, Dict + +from airflow.models.baseoperator import BaseOperator +from airflow.utils.decorators import apply_defaults +from airflow.gcp.hooks.bigquery import BigQueryHook + + +class BigQueryToCloudStorageOperator(BaseOperator): + """ + Transfers a BigQuery table to a Google Cloud Storage bucket. + + .. seealso:: + For more details about these parameters: + https://cloud.google.com/bigquery/docs/reference/v2/jobs + + :param source_project_dataset_table: The dotted + ``(.|:).
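# A minimal DAG sketch for the BigQueryToBigQueryOperator defined above; project,
# dataset and table names are hypothetical placeholders.
from airflow import DAG
from airflow.operators.bigquery_to_bigquery import BigQueryToBigQueryOperator
from airflow.utils.dates import days_ago

with DAG('bq_to_bq_example', start_date=days_ago(1), schedule_interval=None) as dag:
    copy_table = BigQueryToBigQueryOperator(
        task_id='copy_table',
        source_project_dataset_tables='my-project.source_dataset.source_table',
        destination_project_dataset_table='my-project.dest_dataset.dest_table',
        write_disposition='WRITE_TRUNCATE',      # overwrite the destination if it exists
        create_disposition='CREATE_IF_NEEDED',
        gcp_conn_id='google_cloud_default',
    )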
`` BigQuery table to use as the + source data. If ```` is not included, project will be the project + defined in the connection json. (templated) + :type source_project_dataset_table: str + :param destination_cloud_storage_uris: The destination Google Cloud + Storage URI (e.g. gs://some-bucket/some-file.txt). (templated) Follows + convention defined here: + https://cloud.google.com/bigquery/exporting-data-from-bigquery#exportingmultiple + :type destination_cloud_storage_uris: List[str] + :param compression: Type of compression to use. + :type compression: str + :param export_format: File format to export. + :type export_format: str + :param field_delimiter: The delimiter to use when extracting to a CSV. + :type field_delimiter: str + :param print_header: Whether to print a header for a CSV file extract. + :type print_header: bool + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud Platform. + :type gcp_conn_id: str + :param bigquery_conn_id: (Deprecated) The connection ID used to connect to Google Cloud Platform. + This parameter has been deprecated. You should pass the gcp_conn_id parameter instead. + :type bigquery_conn_id: str + :param delegate_to: The account to impersonate, if any. + For this to work, the service account making the request must have domain-wide + delegation enabled. + :type delegate_to: str + :param labels: a dictionary containing labels for the job/query, + passed to BigQuery + :type labels: dict + :param location: The location used for the operation. + :type location: str + """ + template_fields = ('source_project_dataset_table', + 'destination_cloud_storage_uris', 'labels') + template_ext = () + ui_color = '#e4e6f0' + + @apply_defaults + def __init__(self, # pylint: disable=too-many-arguments + source_project_dataset_table: str, + destination_cloud_storage_uris: List[str], + compression: str = 'NONE', + export_format: str = 'CSV', + field_delimiter: str = ',', + print_header: bool = True, + gcp_conn_id: str = 'google_cloud_default', + bigquery_conn_id: Optional[str] = None, + delegate_to: Optional[str] = None, + labels: Optional[Dict] = None, + location: str = None, + *args, + **kwargs) -> None: + super().__init__(*args, **kwargs) + + if bigquery_conn_id: + warnings.warn( + "The bigquery_conn_id parameter has been deprecated. 
You should pass " + "the gcp_conn_id parameter.", DeprecationWarning, stacklevel=3) + gcp_conn_id = bigquery_conn_id + + self.source_project_dataset_table = source_project_dataset_table + self.destination_cloud_storage_uris = destination_cloud_storage_uris + self.compression = compression + self.export_format = export_format + self.field_delimiter = field_delimiter + self.print_header = print_header + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.labels = labels + self.location = location + + def execute(self, context): + self.log.info('Executing extract of %s into: %s', + self.source_project_dataset_table, + self.destination_cloud_storage_uris) + hook = BigQueryHook(bigquery_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + location=self.location) + conn = hook.get_conn() + cursor = conn.cursor() + cursor.run_extract( + source_project_dataset_table=self.source_project_dataset_table, + destination_cloud_storage_uris=self.destination_cloud_storage_uris, + compression=self.compression, + export_format=self.export_format, + field_delimiter=self.field_delimiter, + print_header=self.print_header, + labels=self.labels) diff --git a/airflow/operators/bigquery_to_mysql.py b/airflow/operators/bigquery_to_mysql.py new file mode 100644 index 00000000000000..62c407b02f569b --- /dev/null +++ b/airflow/operators/bigquery_to_mysql.py @@ -0,0 +1,151 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +This module contains Google BigQuery to MySQL operator. +""" +from typing import Optional + +from airflow.models.baseoperator import BaseOperator +from airflow.utils.decorators import apply_defaults +from airflow.gcp.hooks.bigquery import BigQueryHook +from airflow.hooks.mysql_hook import MySqlHook + + +class BigQueryToMySqlOperator(BaseOperator): + """ + Fetches the data from a BigQuery table (alternatively fetch data for selected columns) + and insert that data into a MySQL table. + + + .. note:: + If you pass fields to ``selected_fields`` which are in different order than the + order of columns already in + BQ table, the data will still be in the order of BQ table. + For example if the BQ table has 3 columns as + ``[A,B,C]`` and you pass 'B,A' in the ``selected_fields`` + the data would still be of the form ``'A,B'`` and passed through this form + to MySQL + + **Example**: :: + + transfer_data = BigQueryToMySqlOperator( + task_id='task_id', + dataset_table='origin_bq_table', + mysql_table='dest_table_name', + replace=True, + ) + + :param dataset_table: A dotted ``.
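# A minimal DAG sketch for the BigQueryToCloudStorageOperator defined above;
# table, bucket and connection names are hypothetical placeholders.
from airflow import DAG
from airflow.operators.bigquery_to_gcs import BigQueryToCloudStorageOperator
from airflow.utils.dates import days_ago

with DAG('bq_to_gcs_example', start_date=days_ago(1), schedule_interval=None) as dag:
    export_table = BigQueryToCloudStorageOperator(
        task_id='export_table',
        source_project_dataset_table='my-project.my_dataset.my_table',
        destination_cloud_storage_uris=['gs://my-bucket/exports/my_table_*.csv'],
        export_format='CSV',
        field_delimiter=',',
        print_header=True,
        gcp_conn_id='google_cloud_default',
    )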
``: the big query table of origin + :type dataset_table: str + :param max_results: The maximum number of records (rows) to be fetched + from the table. (templated) + :type max_results: str + :param selected_fields: List of fields to return (comma-separated). If + unspecified, all fields are returned. + :type selected_fields: str + :param gcp_conn_id: reference to a specific GCP hook. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate, if any. + For this to work, the service account making the request must have domain-wide + delegation enabled. + :type delegate_to: str + :param mysql_conn_id: reference to a specific mysql hook + :type mysql_conn_id: str + :param database: name of database which overwrite defined one in connection + :type database: str + :param replace: Whether to replace instead of insert + :type replace: bool + :param batch_size: The number of rows to take in each batch + :type batch_size: int + :param location: The location used for the operation. + :type location: str + """ + template_fields = ('dataset_id', 'table_id', 'mysql_table') + + @apply_defaults + def __init__(self, # pylint:disable=too-many-arguments + dataset_table: str, + mysql_table: str, + selected_fields: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + mysql_conn_id: str = 'mysql_default', + database: Optional[str] = None, + delegate_to: Optional[str] = None, + replace: bool = False, + batch_size: int = 1000, + location: str = None, + *args, + **kwargs) -> None: + super(BigQueryToMySqlOperator, self).__init__(*args, **kwargs) + self.selected_fields = selected_fields + self.gcp_conn_id = gcp_conn_id + self.mysql_conn_id = mysql_conn_id + self.database = database + self.mysql_table = mysql_table + self.replace = replace + self.delegate_to = delegate_to + self.batch_size = batch_size + self.location = location + try: + self.dataset_id, self.table_id = dataset_table.split('.') + except ValueError: + raise ValueError('Could not parse {} as .
' + .format(dataset_table)) + + def _bq_get_data(self): + self.log.info('Fetching Data from:') + self.log.info('Dataset: %s ; Table: %s', + self.dataset_id, self.table_id) + + hook = BigQueryHook(bigquery_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + location=self.location) + + conn = hook.get_conn() + cursor = conn.cursor() + i = 0 + while True: + response = cursor.get_tabledata(dataset_id=self.dataset_id, + table_id=self.table_id, + max_results=self.batch_size, + selected_fields=self.selected_fields, + start_index=i * self.batch_size) + + if 'rows' in response: + rows = response['rows'] + else: + self.log.info('Job Finished') + return + + self.log.info('Total Extracted rows: %s', len(rows) + i * self.batch_size) + + table_data = [] + for dict_row in rows: + single_row = [] + for fields in dict_row['f']: + single_row.append(fields['v']) + table_data.append(single_row) + + yield table_data + i += 1 + + def execute(self, context): + mysql_hook = MySqlHook(schema=self.database, mysql_conn_id=self.mysql_conn_id) + for rows in self._bq_get_data(): + mysql_hook.insert_rows(self.mysql_table, rows, replace=self.replace) diff --git a/airflow/operators/cassandra_to_gcs.py b/airflow/operators/cassandra_to_gcs.py new file mode 100644 index 00000000000000..e11daa7ebead26 --- /dev/null +++ b/airflow/operators/cassandra_to_gcs.py @@ -0,0 +1,366 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +This module contains operator for copying +data from Cassandra to Google cloud storage in JSON format. +""" +import json +import warnings +from base64 import b64encode +from datetime import datetime +from decimal import Decimal +from tempfile import NamedTemporaryFile +from typing import Optional +from uuid import UUID + +from cassandra.util import Date, Time, SortedSet, OrderedMapSerializedKey + +from airflow.gcp.hooks.gcs import GoogleCloudStorageHook +from airflow.contrib.hooks.cassandra_hook import CassandraHook +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator +from airflow.utils.decorators import apply_defaults + + +class CassandraToGoogleCloudStorageOperator(BaseOperator): + """ + Copy data from Cassandra to Google cloud storage in JSON format + + Note: Arrays of arrays are not supported. + + :param cql: The CQL to execute on the Cassandra table. + :type cql: str + :param bucket: The bucket to upload to. + :type bucket: str + :param filename: The filename to use as the object name when uploading + to Google cloud storage. A {} should be specified in the filename + to allow the operator to inject file numbers in cases where the + file is split due to size. 
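# A short sketch of the BigQueryToMySqlOperator defined above, showing the
# ``selected_fields`` and ``batch_size`` parameters; all names are placeholders.
from airflow import DAG
from airflow.operators.bigquery_to_mysql import BigQueryToMySqlOperator
from airflow.utils.dates import days_ago

with DAG('bq_to_mysql_example', start_date=days_ago(1), schedule_interval=None) as dag:
    transfer = BigQueryToMySqlOperator(
        task_id='bq_to_mysql',
        dataset_table='my_dataset.my_table',   # must be a dotted dataset.table pair
        mysql_table='my_mysql_table',
        selected_fields='id,name',             # comma-separated list of columns to fetch
        batch_size=1000,                       # rows fetched per get_tabledata call
        replace=False,
        mysql_conn_id='mysql_default',
    )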
+ :type filename: str + :param schema_filename: If set, the filename to use as the object name + when uploading a .json file containing the BigQuery schema fields + for the table that was dumped from MySQL. + :type schema_filename: str + :param approx_max_file_size_bytes: This operator supports the ability + to split large table dumps into multiple files (see notes in the + filename param docs above). This param allows developers to specify the + file size of the splits. Check https://cloud.google.com/storage/quotas + to see the maximum allowed file size for a single object. + :type approx_max_file_size_bytes: long + :param cassandra_conn_id: Reference to a specific Cassandra hook. + :type cassandra_conn_id: str + :param gzip: Option to compress file for upload + :type gzip: bool + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud Platform. + :type gcp_conn_id: str + :param google_cloud_storage_conn_id: (Deprecated) The connection ID used to connect to Google Cloud + Platform. This parameter has been deprecated. You should pass the gcp_conn_id parameter instead. + :type google_cloud_storage_conn_id: str + :param delegate_to: The account to impersonate, if any. For this to + work, the service account making the request must have domain-wide + delegation enabled. + :type delegate_to: str + """ + template_fields = ('cql', 'bucket', 'filename', 'schema_filename',) + template_ext = ('.cql',) + ui_color = '#a0e08c' + + @apply_defaults + def __init__(self, + cql: str, + bucket: str, + filename: str, + schema_filename: Optional[str] = None, + approx_max_file_size_bytes: int = 1900000000, + gzip: bool = False, + cassandra_conn_id: str = 'cassandra_default', + gcp_conn_id: str = 'google_cloud_default', + google_cloud_storage_conn_id: Optional[str] = None, + delegate_to: Optional[str] = None, + *args, + **kwargs) -> None: + super().__init__(*args, **kwargs) + + if google_cloud_storage_conn_id: + warnings.warn( + "The google_cloud_storage_conn_id parameter has been deprecated. You should pass " + "the gcp_conn_id parameter.", DeprecationWarning, stacklevel=3) + gcp_conn_id = google_cloud_storage_conn_id + + self.cql = cql + self.bucket = bucket + self.filename = filename + self.schema_filename = schema_filename + self.approx_max_file_size_bytes = approx_max_file_size_bytes + self.cassandra_conn_id = cassandra_conn_id + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.gzip = gzip + + self.hook = None + + # Default Cassandra to BigQuery type mapping + CQL_TYPE_MAP = { + 'BytesType': 'BYTES', + 'DecimalType': 'FLOAT', + 'UUIDType': 'BYTES', + 'BooleanType': 'BOOL', + 'ByteType': 'INTEGER', + 'AsciiType': 'STRING', + 'FloatType': 'FLOAT', + 'DoubleType': 'FLOAT', + 'LongType': 'INTEGER', + 'Int32Type': 'INTEGER', + 'IntegerType': 'INTEGER', + 'InetAddressType': 'STRING', + 'CounterColumnType': 'INTEGER', + 'DateType': 'TIMESTAMP', + 'SimpleDateType': 'DATE', + 'TimestampType': 'TIMESTAMP', + 'TimeUUIDType': 'BYTES', + 'ShortType': 'INTEGER', + 'TimeType': 'TIME', + 'DurationType': 'INTEGER', + 'UTF8Type': 'STRING', + 'VarcharType': 'STRING', + } + + def execute(self, context): + cursor = self._query_cassandra() + files_to_upload = self._write_local_data_files(cursor) + + # If a schema is set, create a BQ schema JSON file. 
+ if self.schema_filename: + files_to_upload.update(self._write_local_schema_file(cursor)) + + # Flush all files before uploading + for file_handle in files_to_upload.values(): + file_handle.flush() + + self._upload_to_gcs(files_to_upload) + + # Close all temp file handles. + for file_handle in files_to_upload.values(): + file_handle.close() + + # Close all sessions and connection associated with this Cassandra cluster + self.hook.shutdown_cluster() + + def _query_cassandra(self): + """ + Queries cassandra and returns a cursor to the results. + """ + self.hook = CassandraHook(cassandra_conn_id=self.cassandra_conn_id) + session = self.hook.get_conn() + cursor = session.execute(self.cql) + return cursor + + def _write_local_data_files(self, cursor): + """ + Takes a cursor, and writes results to a local file. + + :return: A dictionary where keys are filenames to be used as object + names in GCS, and values are file handles to local files that + contain the data for the GCS objects. + """ + file_no = 0 + tmp_file_handle = NamedTemporaryFile(delete=True) + tmp_file_handles = {self.filename.format(file_no): tmp_file_handle} + for row in cursor: + row_dict = self.generate_data_dict(row._fields, row) + s = json.dumps(row_dict).encode('utf-8') + tmp_file_handle.write(s) + + # Append newline to make dumps BigQuery compatible. + tmp_file_handle.write(b'\n') + + if tmp_file_handle.tell() >= self.approx_max_file_size_bytes: + file_no += 1 + tmp_file_handle = NamedTemporaryFile(delete=True) + tmp_file_handles[self.filename.format(file_no)] = tmp_file_handle + + return tmp_file_handles + + def _write_local_schema_file(self, cursor): + """ + Takes a cursor, and writes the BigQuery schema for the results to a + local file system. + + :return: A dictionary where key is a filename to be used as an object + name in GCS, and values are file handles to local files that + contains the BigQuery schema fields in .json format. 
+ """ + schema = [] + tmp_schema_file_handle = NamedTemporaryFile(delete=True) + + for name, type in zip(cursor.column_names, cursor.column_types): + schema.append(self.generate_schema_dict(name, type)) + json_serialized_schema = json.dumps(schema).encode('utf-8') + + tmp_schema_file_handle.write(json_serialized_schema) + return {self.schema_filename: tmp_schema_file_handle} + + def _upload_to_gcs(self, files_to_upload): + hook = GoogleCloudStorageHook( + google_cloud_storage_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to) + for object, tmp_file_handle in files_to_upload.items(): + hook.upload(self.bucket, object, tmp_file_handle.name, 'application/json', self.gzip) + + @classmethod + def generate_data_dict(cls, names, values): + row_dict = {} + for name, value in zip(names, values): + row_dict.update({name: cls.convert_value(name, value)}) + return row_dict + + @classmethod + def convert_value(cls, name, value): + if not value: + return value + elif isinstance(value, (str, int, float, bool, dict)): + return value + elif isinstance(value, bytes): + return b64encode(value).decode('ascii') + elif isinstance(value, UUID): + return b64encode(value.bytes).decode('ascii') + elif isinstance(value, (datetime, Date)): + return str(value) + elif isinstance(value, Decimal): + return float(value) + elif isinstance(value, Time): + return str(value).split('.')[0] + elif isinstance(value, (list, SortedSet)): + return cls.convert_array_types(name, value) + elif hasattr(value, '_fields'): + return cls.convert_user_type(name, value) + elif isinstance(value, tuple): + return cls.convert_tuple_type(name, value) + elif isinstance(value, OrderedMapSerializedKey): + return cls.convert_map_type(name, value) + else: + raise AirflowException('unexpected value: ' + str(value)) + + @classmethod + def convert_array_types(cls, name, value): + return [cls.convert_value(name, nested_value) for nested_value in value] + + @classmethod + def convert_user_type(cls, name, value): + """ + Converts a user type to RECORD that contains n fields, where n is the + number of attributes. Each element in the user type class will be converted to its + corresponding data type in BQ. + """ + names = value._fields + values = [cls.convert_value(name, getattr(value, name)) for name in names] + return cls.generate_data_dict(names, values) + + @classmethod + def convert_tuple_type(cls, name, value): + """ + Converts a tuple to RECORD that contains n fields, each will be converted + to its corresponding data type in bq and will be named 'field_', where + index is determined by the order of the tuple elements defined in cassandra. + """ + names = ['field_' + str(i) for i in range(len(value))] + values = [cls.convert_value(name, value) for name, value in zip(names, value)] + return cls.generate_data_dict(names, values) + + @classmethod + def convert_map_type(cls, name, value): + """ + Converts a map to a repeated RECORD that contains two fields: 'key' and 'value', + each will be converted to its corresponding data type in BQ. 
+ """ + converted_map = [] + for k, v in zip(value.keys(), value.values()): + converted_map.append({ + 'key': cls.convert_value('key', k), + 'value': cls.convert_value('value', v) + }) + return converted_map + + @classmethod + def generate_schema_dict(cls, name, type): + field_schema = dict() + field_schema.update({'name': name}) + field_schema.update({'type': cls.get_bq_type(type)}) + field_schema.update({'mode': cls.get_bq_mode(type)}) + fields = cls.get_bq_fields(name, type) + if fields: + field_schema.update({'fields': fields}) + return field_schema + + @classmethod + def get_bq_fields(cls, name, type): + fields = [] + + if not cls.is_simple_type(type): + names, types = [], [] + + if cls.is_array_type(type) and cls.is_record_type(type.subtypes[0]): + names = type.subtypes[0].fieldnames + types = type.subtypes[0].subtypes + elif cls.is_record_type(type): + names = type.fieldnames + types = type.subtypes + + if types and not names and type.cassname == 'TupleType': + names = ['field_' + str(i) for i in range(len(types))] + elif types and not names and type.cassname == 'MapType': + names = ['key', 'value'] + + for name, type in zip(names, types): + field = cls.generate_schema_dict(name, type) + fields.append(field) + + return fields + + @classmethod + def is_simple_type(cls, type): + return type.cassname in CassandraToGoogleCloudStorageOperator.CQL_TYPE_MAP + + @classmethod + def is_array_type(cls, type): + return type.cassname in ['ListType', 'SetType'] + + @classmethod + def is_record_type(cls, type): + return type.cassname in ['UserType', 'TupleType', 'MapType'] + + @classmethod + def get_bq_type(cls, type): + if cls.is_simple_type(type): + return CassandraToGoogleCloudStorageOperator.CQL_TYPE_MAP[type.cassname] + elif cls.is_record_type(type): + return 'RECORD' + elif cls.is_array_type(type): + return cls.get_bq_type(type.subtypes[0]) + else: + raise AirflowException('Not a supported type: ' + type.cassname) + + @classmethod + def get_bq_mode(cls, type): + if cls.is_array_type(type) or type.cassname == 'MapType': + return 'REPEATED' + elif cls.is_record_type(type) or cls.is_simple_type(type): + return 'NULLABLE' + else: + raise AirflowException('Not a supported type: ' + type.cassname) diff --git a/airflow/operators/docker_operator.py b/airflow/operators/docker_operator.py index 005671c260dd05..d4adf3c6bea176 100644 --- a/airflow/operators/docker_operator.py +++ b/airflow/operators/docker_operator.py @@ -16,19 +16,23 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - +""" +Implements Docker operator +""" import json from typing import Union, List, Dict, Iterable +import ast +from docker import APIClient, tls + from airflow.hooks.docker_hook import DockerHook from airflow.exceptions import AirflowException from airflow.models import BaseOperator from airflow.utils.decorators import apply_defaults from airflow.utils.file import TemporaryDirectory -from docker import APIClient, tls -import ast +# pylint: disable=too-many-instance-attributes class DockerOperator(BaseOperator): """ Execute a command inside a docker container. 
@@ -119,6 +123,7 @@ class DockerOperator(BaseOperator): template_fields = ('command', 'environment', 'container_name') template_ext = ('.sh', '.bash',) + # pylint: disable=too-many-arguments,too-many-locals @apply_defaults def __init__( self, @@ -184,7 +189,12 @@ def __init__( self.cli = None self.container = None - def get_hook(self): + def get_hook(self) -> DockerHook: + """ + Retrieves hook for the operator. + + :return: The Docker Hook + """ return DockerHook( docker_conn_id=self.docker_conn_id, base_url=self.docker_url, @@ -238,6 +248,8 @@ def _run_image(self): if self.do_xcom_push: return self.cli.logs(container=self.container['Id']) \ if self.xcom_all else line.encode('utf-8') + else: + return None def execute(self, context): @@ -253,10 +265,10 @@ def execute(self, context): ) # Pull the docker image if `force_pull` is set or image does not exist locally - if self.force_pull or len(self.cli.images(name=self.image)) == 0: + if self.force_pull or not self.cli.images(name=self.image): self.log.info('Pulling docker image %s', self.image) - for l in self.cli.pull(self.image, stream=True): - output = json.loads(l.decode('utf-8').strip()) + for line in self.cli.pull(self.image, stream=True): + output = json.loads(line.decode('utf-8').strip()) if 'status' in output: self.log.info("%s", output['status']) @@ -265,6 +277,12 @@ def execute(self, context): self._run_image() def get_command(self): + """ + Retrieve command(s). if command string starts with [, it returns the command list) + + :return: the command (or commands) + :rtype: str | List[str] + """ if isinstance(self.command, str) and self.command.strip().find('[') == 0: commands = ast.literal_eval(self.command) else: @@ -279,11 +297,14 @@ def on_kill(self): def __get_tls_config(self): tls_config = None if self.tls_ca_cert and self.tls_client_cert and self.tls_client_key: + # Ignore type error on SSL version here - it is deprecated and type annotation is wrong + # it should be string + # noinspection PyTypeChecker tls_config = tls.TLSConfig( ca_cert=self.tls_ca_cert, client_cert=(self.tls_client_cert, self.tls_client_key), verify=True, - ssl_version=self.tls_ssl_version, + ssl_version=self.tls_ssl_version, # type: ignore assert_hostname=self.tls_hostname ) self.docker_url = self.docker_url.replace('tcp://', 'https://') diff --git a/airflow/operators/gcs_to_bq.py b/airflow/operators/gcs_to_bq.py index 7b811654b3e153..0890f4847a44a1 100644 --- a/airflow/operators/gcs_to_bq.py +++ b/airflow/operators/gcs_to_bq.py @@ -23,8 +23,8 @@ import json from airflow import AirflowException -from airflow.contrib.hooks.gcs_hook import GoogleCloudStorageHook -from airflow.contrib.hooks.bigquery_hook import BigQueryHook +from airflow.gcp.hooks.gcs import GoogleCloudStorageHook +from airflow.gcp.hooks.bigquery import BigQueryHook from airflow.models import BaseOperator from airflow.utils.decorators import apply_defaults diff --git a/airflow/operators/gcs_to_gcs.py b/airflow/operators/gcs_to_gcs.py index 7b653fdb59c145..a4d51e9a5157ac 100644 --- a/airflow/operators/gcs_to_gcs.py +++ b/airflow/operators/gcs_to_gcs.py @@ -20,8 +20,9 @@ This module contains a Google Cloud Storage operator. 
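# Sketch of the ``get_command`` behaviour documented above: a command string that
# starts with ``[`` is parsed into a list with ``ast.literal_eval``. The image and
# command below are illustrative only.
from airflow import DAG
from airflow.operators.docker_operator import DockerOperator
from airflow.utils.dates import days_ago

with DAG('docker_example', start_date=days_ago(1), schedule_interval=None) as dag:
    run_in_container = DockerOperator(
        task_id='run_in_container',
        image='python:3.7-slim',
        command="['python', '-c', 'print(1)']",   # rendered string becomes a command list
    )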
""" import warnings +from typing import Optional -from airflow.contrib.hooks.gcs_hook import GoogleCloudStorageHook +from airflow.gcp.hooks.gcs import GoogleCloudStorageHook from airflow.models import BaseOperator from airflow.utils.decorators import apply_defaults from airflow.exceptions import AirflowException @@ -209,3 +210,90 @@ def _copy_single_object(self, hook, source_object, destination_object): if self.move_object: hook.delete(self.source_bucket, source_object) + + +class GoogleCloudStorageSynchronizeBuckets(BaseOperator): + """ + Synchronizes the contents of the buckets or bucket's directories in the Google Cloud Services. + + Parameters ``source_object`` and ``destination_object`` describe the root sync directory. If they are + not passed, the entire bucket will be synchronized. They should point to directories. + + .. note:: + The synchronization of individual files is not supported. Only entire directories can be + synchronized. + + :param source_bucket: The name of the bucket containing the source objects. + :type source_bucket: str + :param destination_bucket: The name of the bucket containing the destination objects. + :type destination_bucket: str + :param source_object: The root sync directory in the source bucket. + :type source_object: Optional[str] + :param destination_object: The root sync directory in the destination bucket. + :type destination_object: Optional[str] + :param recursive: If True, subdirectories will be considered + :type recursive: bool + :param allow_overwrite: if True, the files will be overwritten if a mismatched file is found. + By default, overwriting files is not allowed + :type allow_overwrite: bool + :param delete_extra_files: if True, deletes additional files from the source that not found in the + destination. By default extra files are not deleted. + + .. note:: + This option can delete data quickly if you specify the wrong source/destination combination. 
+ + :type delete_extra_files: bool + """ + + template_fields = ( + 'source_bucket', + 'destination_bucket', + 'source_object', + 'destination_object', + 'recursive', + 'delete_extra_files', + 'allow_overwrite', + 'gcp_conn_id', + 'delegate_to', + ) + + @apply_defaults + def __init__( + self, + source_bucket: str, + destination_bucket: str, + source_object: Optional[str] = None, + destination_object: Optional[str] = None, + recursive: bool = True, + delete_extra_files: bool = False, + allow_overwrite: bool = False, + gcp_conn_id: str = 'google_cloud_default', + delegate_to: Optional[str] = None, + *args, + **kwargs + ) -> None: + super().__init__(*args, **kwargs) + self.source_bucket = source_bucket + self.destination_bucket = destination_bucket + self.source_object = source_object + self.destination_object = destination_object + self.recursive = recursive + self.delete_extra_files = delete_extra_files + self.allow_overwrite = allow_overwrite + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + + def execute(self, context): + hook = GoogleCloudStorageHook( + google_cloud_storage_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to + ) + hook.sync( + source_bucket=self.source_bucket, + destination_bucket=self.destination_bucket, + source_object=self.source_object, + destination_object=self.destination_object, + recursive=self.recursive, + delete_extra_files=self.delete_extra_files, + allow_overwrite=self.allow_overwrite + ) diff --git a/airflow/operators/gcs_to_s3.py b/airflow/operators/gcs_to_s3.py index 1c11801288e5e6..51e4227c399e3e 100644 --- a/airflow/operators/gcs_to_s3.py +++ b/airflow/operators/gcs_to_s3.py @@ -21,7 +21,7 @@ """ import warnings -from airflow.contrib.hooks.gcs_hook import GoogleCloudStorageHook +from airflow.gcp.hooks.gcs import GoogleCloudStorageHook from airflow.contrib.operators.gcs_list_operator import GoogleCloudStorageListOperator from airflow.utils.decorators import apply_defaults from airflow.hooks.S3_hook import S3Hook diff --git a/airflow/operators/google_api_to_s3_transfer.py b/airflow/operators/google_api_to_s3_transfer.py new file mode 100644 index 00000000000000..f65c5da126347e --- /dev/null +++ b/airflow/operators/google_api_to_s3_transfer.py @@ -0,0 +1,177 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +""" +This module allows you to transfer data from any Google API endpoint into a S3 Bucket. 
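# A minimal sketch of the GoogleCloudStorageSynchronizeBuckets operator defined
# above; bucket and directory names are hypothetical placeholders.
from airflow import DAG
from airflow.operators.gcs_to_gcs import GoogleCloudStorageSynchronizeBuckets
from airflow.utils.dates import days_ago

with DAG('gcs_sync_example', start_date=days_ago(1), schedule_interval=None) as dag:
    sync_buckets = GoogleCloudStorageSynchronizeBuckets(
        task_id='sync_buckets',
        source_bucket='my-source-bucket',
        destination_bucket='my-destination-bucket',
        source_object='data/',          # sync only this directory tree
        destination_object='data/',
        recursive=True,
        allow_overwrite=False,          # never overwrite mismatched files
        delete_extra_files=False,       # keep files that exist only in the destination
    )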
+""" +import json +import sys + +from airflow.models import BaseOperator +from airflow.models.xcom import MAX_XCOM_SIZE +from airflow.utils.decorators import apply_defaults + +from airflow.gcp.hooks.discovery_api import GoogleDiscoveryApiHook +from airflow.hooks.S3_hook import S3Hook + + +class GoogleApiToS3Transfer(BaseOperator): + """ + Basic class for transferring data from a Google API endpoint into a S3 Bucket. + + :param google_api_service_name: The specific API service that is being requested. + :type google_api_service_name: str + :param google_api_service_version: The version of the API that is being requested. + :type google_api_service_version: str + :param google_api_endpoint_path: The client libraries path to the api call's executing method. + For example: 'analyticsreporting.reports.batchGet' + + .. note:: See https://developers.google.com/apis-explorer + for more information on which methods are available. + + :type google_api_endpoint_path: str + :param google_api_endpoint_params: The params to control the corresponding endpoint result. + :type google_api_endpoint_params: dict + :param s3_destination_key: The url where to put the data retrieved from the endpoint in S3. + :type s3_destination_key: str + :param google_api_response_via_xcom: Can be set to expose the google api response to xcom. + :type google_api_response_via_xcom: str + :param google_api_endpoint_params_via_xcom: If set to a value this value will be used as a key + for pulling from xcom and updating the google api endpoint params. + :type google_api_endpoint_params_via_xcom: str + :param google_api_endpoint_params_via_xcom_task_ids: Task ids to filter xcom by. + :type google_api_endpoint_params_via_xcom_task_ids: str or list of str + :param google_api_pagination: If set to True Pagination will be enabled for this request + to retrieve all data. + + .. note:: This means the response will be a list of responses. + + :type google_api_pagination: bool + :param google_api_num_retries: Define the number of retries for the google api requests being made + if it fails. + :type google_api_num_retries: int + :param s3_overwrite: Specifies whether the s3 file will be overwritten if exists. + :type s3_overwrite: bool + :param gcp_conn_id: The connection ID to use when fetching connection info. + :type gcp_conn_id: string + :param delegate_to: The account to impersonate, if any. + For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: string + :param aws_conn_id: The connection id specifying the authentication information for the S3 Bucket. 
+ :type aws_conn_id: str + """ + + template_fields = ( + 'google_api_endpoint_params', + 's3_destination_key', + ) + template_ext = () + ui_color = '#cc181e' + + @apply_defaults + def __init__( + self, + google_api_service_name, + google_api_service_version, + google_api_endpoint_path, + google_api_endpoint_params, + s3_destination_key, + *args, + google_api_response_via_xcom=None, + google_api_endpoint_params_via_xcom=None, + google_api_endpoint_params_via_xcom_task_ids=None, + google_api_pagination=False, + google_api_num_retries=0, + s3_overwrite=False, + gcp_conn_id='google_cloud_default', + delegate_to=None, + aws_conn_id='aws_default', + **kwargs + ): + super(GoogleApiToS3Transfer, self).__init__(*args, **kwargs) + self.google_api_service_name = google_api_service_name + self.google_api_service_version = google_api_service_version + self.google_api_endpoint_path = google_api_endpoint_path + self.google_api_endpoint_params = google_api_endpoint_params + self.s3_destination_key = s3_destination_key + self.google_api_response_via_xcom = google_api_response_via_xcom + self.google_api_endpoint_params_via_xcom = google_api_endpoint_params_via_xcom + self.google_api_endpoint_params_via_xcom_task_ids = google_api_endpoint_params_via_xcom_task_ids + self.google_api_pagination = google_api_pagination + self.google_api_num_retries = google_api_num_retries + self.s3_overwrite = s3_overwrite + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.aws_conn_id = aws_conn_id + + def execute(self, context): + """ + Transfers Google APIs json data to S3. + + :param context: The context that is being provided when executing. + :type context: dict + """ + self.log.info('Transferring data from %s to s3', self.google_api_service_name) + + if self.google_api_endpoint_params_via_xcom: + self._update_google_api_endpoint_params_via_xcom(context['task_instance']) + + data = self._retrieve_data_from_google_api() + + self._load_data_to_s3(data) + + if self.google_api_response_via_xcom: + self._expose_google_api_response_via_xcom(context['task_instance'], data) + + def _retrieve_data_from_google_api(self): + google_discovery_api_hook = GoogleDiscoveryApiHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + api_service_name=self.google_api_service_name, + api_version=self.google_api_service_version + ) + google_api_response = google_discovery_api_hook.query( + endpoint=self.google_api_endpoint_path, + data=self.google_api_endpoint_params, + paginate=self.google_api_pagination, + num_retries=self.google_api_num_retries + ) + return google_api_response + + def _load_data_to_s3(self, data): + s3_hook = S3Hook(aws_conn_id=self.aws_conn_id) + s3_hook.load_string( + string_data=json.dumps(data), + key=self.s3_destination_key, + replace=self.s3_overwrite + ) + + def _update_google_api_endpoint_params_via_xcom(self, task_instance): + google_api_endpoint_params = task_instance.xcom_pull( + task_ids=self.google_api_endpoint_params_via_xcom_task_ids, + key=self.google_api_endpoint_params_via_xcom + ) + self.google_api_endpoint_params.update(google_api_endpoint_params) + + def _expose_google_api_response_via_xcom(self, task_instance, data): + if sys.getsizeof(data) < MAX_XCOM_SIZE: + task_instance.xcom_push(key=self.google_api_response_via_xcom, value=data) + else: + raise RuntimeError('The size of the downloaded data is too large to push to XCom!') diff --git a/airflow/operators/hive_stats_operator.py b/airflow/operators/hive_stats_operator.py index 9bd0a65fe2bda3..23d241aae30483 100644 --- 
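# A minimal sketch of the GoogleApiToS3Transfer operator defined above, using the
# Google Sheets API as an illustrative endpoint; the spreadsheet id, bucket and
# connection ids are hypothetical placeholders.
from airflow import DAG
from airflow.operators.google_api_to_s3_transfer import GoogleApiToS3Transfer
from airflow.utils.dates import days_ago

with DAG('google_api_to_s3_example', start_date=days_ago(1), schedule_interval=None) as dag:
    sheet_to_s3 = GoogleApiToS3Transfer(
        task_id='sheet_to_s3',
        google_api_service_name='sheets',
        google_api_service_version='v4',
        google_api_endpoint_path='sheets.spreadsheets.values.get',
        google_api_endpoint_params={'spreadsheetId': 'SPREADSHEET_ID', 'range': 'Sheet1'},
        s3_destination_key='s3://my-bucket/sheet_values.json',
        gcp_conn_id='google_cloud_default',
        aws_conn_id='aws_default',
    )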
a/airflow/operators/hive_stats_operator.py +++ b/airflow/operators/hive_stats_operator.py @@ -113,7 +113,6 @@ def execute(self, context=None): ('', 'count'): 'COUNT(*)' } for col, col_type in list(field_types.items()): - d = {} if self.assignment_func: d = self.assignment_func(col, col_type) if d is None: diff --git a/airflow/operators/local_to_gcs.py b/airflow/operators/local_to_gcs.py index 6170e104c51bb8..a9de78896fef7a 100644 --- a/airflow/operators/local_to_gcs.py +++ b/airflow/operators/local_to_gcs.py @@ -21,7 +21,7 @@ """ import warnings -from airflow.contrib.hooks.gcs_hook import GoogleCloudStorageHook +from airflow.gcp.hooks.gcs import GoogleCloudStorageHook from airflow.models import BaseOperator from airflow.utils.decorators import apply_defaults diff --git a/airflow/operators/python_operator.py b/airflow/operators/python_operator.py index 46430b215e93cc..4d3c8da19f60c0 100644 --- a/airflow/operators/python_operator.py +++ b/airflow/operators/python_operator.py @@ -23,8 +23,10 @@ import subprocess import sys import types +from inspect import signature +from itertools import islice from textwrap import dedent -from typing import Optional, Iterable, Dict, Callable +from typing import Optional, Iterable, Dict, Callable, List import dill @@ -51,12 +53,6 @@ class PythonOperator(BaseOperator): :param op_args: a list of positional arguments that will get unpacked when calling your callable :type op_args: list (templated) - :param provide_context: if set to true, Airflow will pass a set of - keyword arguments that can be used in your function. This set of - kwargs correspond exactly to what you can use in your jinja - templates. For this to work, you need to define `**kwargs` in your - function header. - :type provide_context: bool :param templates_dict: a dictionary where the values are templates that will get templated by the Airflow engine sometime between ``__init__`` and ``execute`` takes place and are made available @@ -77,11 +73,10 @@ class PythonOperator(BaseOperator): def __init__( self, python_callable: Callable, - op_args: Optional[Iterable] = None, + op_args: Optional[List] = None, op_kwargs: Optional[Dict] = None, - provide_context: bool = False, templates_dict: Optional[Dict] = None, - templates_exts: Optional[Iterable[str]] = None, + templates_exts: Optional[List[str]] = None, *args, **kwargs ) -> None: @@ -91,12 +86,47 @@ def __init__( self.python_callable = python_callable self.op_args = op_args or [] self.op_kwargs = op_kwargs or {} - self.provide_context = provide_context self.templates_dict = templates_dict if templates_exts: self.template_ext = templates_exts - def execute(self, context): + @staticmethod + def determine_op_kwargs(python_callable: Callable, + context: Dict, + num_op_args: int = 0) -> Dict: + """ + Function that will inspect the signature of a python_callable to determine which + values need to be passed to the function. 
+ + :param python_callable: The function that you want to invoke + :param context: The context provided by the execute method of the Operator/Sensor + :param num_op_args: The number of op_args provided, so we know how many to skip + :return: The op_kwargs dictionary which contains the values that are compatible with the Callable + """ + context_keys = context.keys() + sig = signature(python_callable).parameters.items() + op_args_names = islice(sig, num_op_args) + for name, _ in op_args_names: + # Check if it is part of the context + if name in context_keys: + # Raise an exception to let the user know that the keyword is reserved + raise ValueError( + "The key {} in the op_args is part of the context, and therefore reserved".format(name) + ) + + if any(str(param).startswith("**") for _, param in sig): + # If there is a ** argument then just dump everything. + op_kwargs = context + else: + # If there is, for example, only an execution_date, then pass only that in + op_kwargs = { + name: context[name] + for name, _ in sig + if name in context # If it isn't available on the context, then ignore + } + return op_kwargs + + def execute(self, context: Dict): # Export context to make it available for callables to use. airflow_context_vars = context_to_airflow_vars(context, in_env_var_format=True) self.log.info("Exporting the following env vars:\n%s", @@ -104,10 +134,10 @@ def execute(self, context): for k, v in airflow_context_vars.items()])) os.environ.update(airflow_context_vars) - if self.provide_context: - context.update(self.op_kwargs) - context['templates_dict'] = self.templates_dict - self.op_kwargs = context + context.update(self.op_kwargs) + context['templates_dict'] = self.templates_dict + + self.op_kwargs = PythonOperator.determine_op_kwargs(self.python_callable, context, len(self.op_args)) return_value = self.execute_callable() self.log.info("Done. Returned value was: %s", return_value) @@ -130,7 +160,8 @@ class BranchPythonOperator(PythonOperator, SkipMixin): downstream to allow for the DAG state to fill up and the DAG run's state to be inferred. """ - def execute(self, context): + + def execute(self, context: Dict): branch = super().execute(context) self.skip_all_except(context['ti'], branch) @@ -147,7 +178,8 @@ class ShortCircuitOperator(PythonOperator, SkipMixin): The condition is determined by the result of `python_callable`. """ - def execute(self, context): + + def execute(self, context: Dict): condition = super().execute(context) self.log.info("Condition result is %s", condition) @@ -200,12 +232,6 @@ class PythonVirtualenvOperator(PythonOperator): :type op_kwargs: list :param op_kwargs: A dict of keyword arguments to pass to python_callable. :type op_kwargs: dict - :param provide_context: if set to true, Airflow will pass a set of - keyword arguments that can be used in your function. This set of - kwargs correspond exactly to what you can use in your jinja - templates. For this to work, you need to define `**kwargs` in your - function header. - :type provide_context: bool :param string_args: Strings that are present in the global var virtualenv_string_args, available to python_callable at runtime as a list[str]. Note that args are split by newline.
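# --- Illustrative sketch (not part of the change set above) ---
# With provide_context removed, PythonOperator.determine_op_kwargs inspects the
# callable's signature and forwards only the context variables it names; a
# callable declaring **kwargs still receives the full context. The task id and
# upstream task id below are placeholders.
from airflow.operators.python_operator import PythonOperator

def notify(execution_date, task_instance):
    # Both arguments are injected because they appear in the signature;
    # no provide_context=True flag is needed (or accepted) any more.
    upstream_value = task_instance.xcom_pull(task_ids="upstream_task")
    print("run date:", execution_date, "pulled:", upstream_value)

# attach to a DAG in the usual way (dag=... or inside a `with DAG(...)` block)
notify_task = PythonOperator(task_id="notify", python_callable=notify)
# --- end of illustrative sketch ---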
@@ -219,6 +245,7 @@ class PythonVirtualenvOperator(PythonOperator): processing templated fields, for examples ``['.sql', '.hql']`` :type templates_exts: list[str] """ + @apply_defaults def __init__( self, @@ -229,7 +256,6 @@ def __init__( system_site_packages: bool = True, op_args: Iterable = None, op_kwargs: Dict = None, - provide_context: bool = False, string_args: Optional[Iterable[str]] = None, templates_dict: Optional[Dict] = None, templates_exts: Optional[Iterable[str]] = None, @@ -242,7 +268,6 @@ def __init__( op_kwargs=op_kwargs, templates_dict=templates_dict, templates_exts=templates_exts, - provide_context=provide_context, *args, **kwargs) self.requirements = requirements or [] @@ -264,8 +289,8 @@ def __init__( self.__class__.__name__) # check that args are passed iff python major version matches if (python_version is not None and - str(python_version)[0] != str(sys.version_info[0]) and - self._pass_op_args()): + str(python_version)[0] != str(sys.version_info[0]) and + self._pass_op_args()): raise AirflowException("Passing op_args or op_kwargs is not supported across " "different Python major versions " "for PythonVirtualenvOperator. " @@ -383,7 +408,7 @@ def _generate_python_code(self): fn = self.python_callable # dont try to read pickle if we didnt pass anything if self._pass_op_args(): - load_args_line = 'with open(sys.argv[1], "rb") as file: arg_dict = {}.load(file)'\ + load_args_line = 'with open(sys.argv[1], "rb") as file: arg_dict = {}.load(file)' \ .format(pickling_library) else: load_args_line = 'arg_dict = {"args": [], "kwargs": {}}' diff --git a/airflow/operators/sql_to_gcs.py b/airflow/operators/sql_to_gcs.py index 0d6bd345144b41..9287d771fb8441 100644 --- a/airflow/operators/sql_to_gcs.py +++ b/airflow/operators/sql_to_gcs.py @@ -27,7 +27,7 @@ import unicodecsv as csv -from airflow.contrib.hooks.gcs_hook import GoogleCloudStorageHook +from airflow.gcp.hooks.gcs import GoogleCloudStorageHook from airflow.models import BaseOperator diff --git a/airflow/security/kerberos.py b/airflow/security/kerberos.py index 1c891bef3867b3..18f52c3834c43e 100644 --- a/airflow/security/kerberos.py +++ b/airflow/security/kerberos.py @@ -1,4 +1,21 @@ #!/usr/bin/env python +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# # Licensed to Cloudera, Inc. under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -14,6 +31,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+"""Kerberos security provider""" from typing import Optional import socket @@ -29,10 +47,16 @@ log = LoggingMixin().log -def renew_from_kt(principal, keytab): +def renew_from_kt(principal: str, keytab: str): + """ + Renew kerberos token from keytab + + :param principal: principal + :param keytab: keytab file + :return: None + """ # The config is specified in seconds. But we ask for that same amount in # minutes to give ourselves a large renewal buffer. - renewal_lifetime = "%sm" % conf.getint('kerberos', 'reinit_frequency') cmd_principal = principal or conf.get('kerberos', 'principal').replace( @@ -47,7 +71,7 @@ def renew_from_kt(principal, keytab): "-c", conf.get('kerberos', 'ccache'), # specify credentials cache cmd_principal ] - log.info("Reinitting kerberos from keytab: %s", " ".join(cmdv)) + log.info("Re-initialising kerberos from keytab: %s", " ".join(cmdv)) subp = subprocess.Popen(cmdv, stdout=subprocess.PIPE, @@ -63,7 +87,7 @@ def renew_from_kt(principal, keytab): ) sys.exit(subp.returncode) - global NEED_KRB181_WORKAROUND + global NEED_KRB181_WORKAROUND # pylint: disable=global-statement if NEED_KRB181_WORKAROUND is None: NEED_KRB181_WORKAROUND = detect_conf_var() if NEED_KRB181_WORKAROUND: @@ -73,7 +97,13 @@ def renew_from_kt(principal, keytab): perform_krb181_workaround(principal) -def perform_krb181_workaround(principal): +def perform_krb181_workaround(principal: str): + """ + Workaround for Kerberos 1.8.1. + + :param principal: principal name + :return: None + """ cmdv = [conf.get('kerberos', 'kinit_path'), "-c", conf.get('kerberos', 'ccache'), "-R"] # Renew ticket_cache @@ -112,7 +142,14 @@ def detect_conf_var() -> bool: return b'X-CACHECONF:' in file.read() -def run(principal, keytab): +def run(principal: str, keytab: str): + """ + Run the kerbros renewer. + + :param principal: principal name + :param keytab: keytab file + :return: None + """ if not keytab: log.debug("Keytab renewer not starting, no keytab configured") sys.exit(0) diff --git a/airflow/security/utils.py b/airflow/security/utils.py index 33550107b5bd27..85e2cedba712ed 100644 --- a/airflow/security/utils.py +++ b/airflow/security/utils.py @@ -1,4 +1,21 @@ #!/usr/bin/env python +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + # Licensed to Cloudera, Inc. under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -15,7 +32,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# - +"""Various security-related utils.""" import re import socket @@ -34,6 +51,7 @@ def get_components(principal): def replace_hostname_pattern(components, host=None): + """Replaces hostname with the right pattern including lowercase of the name.""" fqdn = host if not fqdn or fqdn == '0.0.0.0': fqdn = get_hostname() @@ -41,7 +59,7 @@ def replace_hostname_pattern(components, host=None): def get_fqdn(hostname_or_ip=None): - # Get hostname + """Retrieves FQDN - hostname for the IP or hostname.""" try: if hostname_or_ip: fqdn = socket.gethostbyaddr(hostname_or_ip)[0] @@ -56,6 +74,7 @@ def get_fqdn(hostname_or_ip=None): def principal_from_username(username, realm): + """Retrieves principal from the user name and realm.""" if ('@' not in username) and realm: username = "{}@{}".format(username, realm) diff --git a/airflow/sensors/http_sensor.py b/airflow/sensors/http_sensor.py index 6e5d69946232e0..3102d33f4c049f 100644 --- a/airflow/sensors/http_sensor.py +++ b/airflow/sensors/http_sensor.py @@ -16,6 +16,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from typing import Dict, Callable + +from airflow.operators.python_operator import PythonOperator from airflow.exceptions import AirflowException from airflow.hooks.http_hook import HttpHook @@ -31,13 +34,17 @@ class HttpSensor(BaseSensorOperator): HTTP Error codes other than 404 (like 403) or Connection Refused Error would fail the sensor itself directly (no more poking). - The response check can access the template context by passing ``provide_context=True`` to the operator:: + The response check can access the template context to the operator: - def response_check(response, **context): - # Can look at context['ti'] etc. + def response_check(response, task_instance): + # The task_instance is injected, so you can pull data form xcom + # Other context variables such as dag, ds, execution_date are also available. + xcom_data = task_instance.xcom_pull(task_ids='pushing_task') + # In practice you would do something more sensible with this data.. + print(xcom_data) return True - HttpSensor(task_id='my_http_sensor', ..., provide_context=True, response_check=response_check) + HttpSensor(task_id='my_http_sensor', ..., response_check=response_check) :param http_conn_id: The connection to run the sensor against @@ -50,12 +57,6 @@ def response_check(response, **context): :type request_params: a dictionary of string key/value pairs :param headers: The HTTP headers to be added to the GET request :type headers: a dictionary of string key/value pairs - :param provide_context: if set to true, Airflow will pass a set of - keyword arguments that can be used in your function. This set of - kwargs correspond exactly to what you can use in your jinja - templates. For this to work, you need to define context in your - function header. - :type provide_context: bool :param response_check: A check against the 'requests' response object. Returns True for 'pass' and False otherwise. :type response_check: A lambda or defined function. 
@@ -69,14 +70,13 @@ def response_check(response, **context): @apply_defaults def __init__(self, - endpoint, - http_conn_id='http_default', - method='GET', - request_params=None, - headers=None, - response_check=None, - provide_context=False, - extra_options=None, *args, **kwargs): + endpoint: str, + http_conn_id: str = 'http_default', + method: str = 'GET', + request_params: Dict = None, + headers: Dict = None, + response_check: Callable = None, + extra_options: Dict = None, *args, **kwargs): super().__init__(*args, **kwargs) self.endpoint = endpoint self.http_conn_id = http_conn_id @@ -84,13 +84,12 @@ def __init__(self, self.headers = headers or {} self.extra_options = extra_options or {} self.response_check = response_check - self.provide_context = provide_context self.hook = HttpHook( method=method, http_conn_id=http_conn_id) - def poke(self, context): + def poke(self, context: Dict): self.log.info('Poking: %s', self.endpoint) try: response = self.hook.run(self.endpoint, @@ -98,10 +97,9 @@ def poke(self, context): headers=self.headers, extra_options=self.extra_options) if self.response_check: - if self.provide_context: - return self.response_check(response, **context) - else: - return self.response_check(response) + op_kwargs = PythonOperator.determine_op_kwargs(self.response_check, context) + return self.response_check(response, **op_kwargs) + except AirflowException as ae: if str(ae).startswith("404"): return False diff --git a/airflow/settings.py b/airflow/settings.py index 264d8410534fd1..ba6f4abd035d6b 100644 --- a/airflow/settings.py +++ b/airflow/settings.py @@ -32,7 +32,6 @@ import airflow from airflow.configuration import conf, AIRFLOW_HOME, WEBSERVER_CONFIG # NOQA F401 -from airflow.kubernetes.pod import Pod from airflow.logging_config import configure_logging from airflow.utils.sqlalchemy import setup_event_handlers @@ -102,10 +101,11 @@ def policy(task_instance): """ -def pod_mutation_hook(pod: Pod): +def pod_mutation_hook(pod): """ - This setting allows altering ``Pod`` objects before they are passed to - the Kubernetes client by the ``PodLauncher`` for scheduling. + This setting allows altering ``kubernetes.client.models.V1Pod`` object + before they are passed to the Kubernetes client by the ``PodLauncher`` + for scheduling. To define a pod mutation hook, add a ``airflow_local_settings`` module to your PYTHONPATH that defines this ``pod_mutation_hook`` function. @@ -144,10 +144,7 @@ def configure_orm(disable_connection_pool=False): # Pool size engine args not supported by sqlite. # If no config value is defined for the pool size, select a reasonable value. # 0 means no limit, which could lead to exceeding the Database connection limit. - try: - pool_size = conf.getint('core', 'SQL_ALCHEMY_POOL_SIZE') - except conf.AirflowConfigException: - pool_size = 5 + pool_size = conf.getint('core', 'SQL_ALCHEMY_POOL_SIZE', fallback=5) # The maximum overflow size of the pool. # When the number of checked-out connections reaches the size set in pool_size, @@ -159,24 +156,26 @@ def configure_orm(disable_connection_pool=False): # max_overflow can be set to -1 to indicate no overflow limit; # no limit will be placed on the total number # of concurrent connections. Defaults to 10. 
- try: - max_overflow = conf.getint('core', 'SQL_ALCHEMY_MAX_OVERFLOW') - except conf.AirflowConfigException: - max_overflow = 10 + max_overflow = conf.getint('core', 'SQL_ALCHEMY_MAX_OVERFLOW', fallback=10) # The DB server already has a value for wait_timeout (number of seconds after # which an idle sleeping connection should be killed). Since other DBs may # co-exist on the same server, SQLAlchemy should set its # pool_recycle to an equal or smaller value. - try: - pool_recycle = conf.getint('core', 'SQL_ALCHEMY_POOL_RECYCLE') - except conf.AirflowConfigException: - pool_recycle = 1800 + pool_recycle = conf.getint('core', 'SQL_ALCHEMY_POOL_RECYCLE', fallback=1800) + + # Check connection at the start of each connection pool checkout. + # Typically, this is a simple statement like “SELECT 1”, but may also make use + # of some DBAPI-specific method to test the connection for liveness. + # More information here: + # https://docs.sqlalchemy.org/en/13/core/pooling.html#disconnect-handling-pessimistic + pool_pre_ping = conf.getboolean('core', 'SQL_ALCHEMY_POOL_PRE_PING', fallback=True) log.info("settings.configure_orm(): Using pool settings. pool_size={}, max_overflow={}, " "pool_recycle={}, pid={}".format(pool_size, max_overflow, pool_recycle, os.getpid())) engine_args['pool_size'] = pool_size engine_args['pool_recycle'] = pool_recycle + engine_args['pool_pre_ping'] = pool_pre_ping engine_args['max_overflow'] = max_overflow # Allow the user to specify an encoding for their DB otherwise default @@ -187,8 +186,7 @@ def configure_orm(disable_connection_pool=False): engine_args['encoding'] = engine_args['encoding'].__str__() engine = create_engine(SQL_ALCHEMY_CONN, **engine_args) - reconnect_timeout = conf.getint('core', 'SQL_ALCHEMY_RECONNECT_TIMEOUT') - setup_event_handlers(engine, reconnect_timeout) + setup_event_handlers(engine) Session = scoped_session( sessionmaker(autocommit=False, diff --git a/airflow/utils/dot_renderer.py b/airflow/utils/dot_renderer.py new file mode 100644 index 00000000000000..2ba386b1c830ad --- /dev/null +++ b/airflow/utils/dot_renderer.py @@ -0,0 +1,66 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Renders the DAG (tasks and dependencies) as a graphviz object. +""" + +import graphviz + +from airflow import DAG + + +def _refine_color(color: str): + """ + Converts a color in #RGB (12 bits) format to #RRGGBB (24 bits), if possible. + Otherwise, it returns the original value. Graphviz does not support colors in #RGB format.
+ + :param color: Text representation of color + :return: Refined representation of color + """ + if len(color) == 4 and color[0] == "#": + color_r = color[1] + color_g = color[2] + color_b = color[3] + return "#" + color_r + color_r + color_g + color_g + color_b + color_b + return color + + +def render_dag(dag: DAG) -> graphviz.Digraph: + """ + Renders the DAG object to the DOT object. + + :param dag: DAG that will be rendered. + :return: Graphviz object + :rtype: graphviz.Digraph + """ + dot = graphviz.Digraph(dag.dag_id, graph_attr={"rankdir": "LR", "labelloc": "t", "label": dag.dag_id}) + for task in dag.tasks: + dot.node( + task.task_id, + _attributes={ + "shape": "rectangle", + "style": "filled,rounded", + "color": _refine_color(task.ui_fgcolor), + "fillcolor": _refine_color(task.ui_color), + }, + ) + for downstream_task_id in task.downstream_task_ids: + dot.edge(task.task_id, downstream_task_id) + return dot diff --git a/airflow/utils/log/colored_log.py b/airflow/utils/log/colored_log.py index 29a9a0f33428b7..496885cfa58f24 100644 --- a/airflow/utils/log/colored_log.py +++ b/airflow/utils/log/colored_log.py @@ -19,6 +19,7 @@ """ Class responsible for colouring logs based on log level. """ +import re import sys from typing import Any, Union @@ -55,14 +56,23 @@ def _color_arg(arg: Any) -> Union[str, float, int]: return arg return colored(str(arg), **ARGS) # type: ignore + @staticmethod + def _count_number_of_arguments_in_message(record: LogRecord) -> int: + matches = re.findall(r"%.", record.msg) + return len(matches) if matches else 0 + def _color_record_args(self, record: LogRecord) -> LogRecord: if isinstance(record.args, (tuple, list)): record.args = tuple(self._color_arg(arg) for arg in record.args) elif isinstance(record.args, dict): - # Case of logging.debug("a %(a)d b %(b)s", {'a':1, 'b':2}) - record.args = { - key: self._color_arg(value) for key, value in record.args.items() - } + if self._count_number_of_arguments_in_message(record) > 1: + # Case of logging.debug("a %(a)d b %(b)s", {'a':1, 'b':2}) + record.args = { + key: self._color_arg(value) for key, value in record.args.items() + } + else: + # Case of single dict passed to formatted string + record.args = self._color_arg(record.args) # type: ignore elif isinstance(record.args, str): record.args = self._color_arg(record.args) return record diff --git a/airflow/utils/log/file_task_handler.py b/airflow/utils/log/file_task_handler.py index 1e65539fb1ec53..b389dee4e18f2e 100644 --- a/airflow/utils/log/file_task_handler.py +++ b/airflow/utils/log/file_task_handler.py @@ -16,13 +16,16 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - +"""File logging handler for tasks.""" import logging import os +from typing import Optional + import requests from airflow.configuration import conf from airflow.configuration import AirflowConfigException +from airflow.models import TaskInstance from airflow.utils.file import mkdirs from airflow.utils.helpers import parse_template_string @@ -33,39 +36,38 @@ class FileTaskHandler(logging.Handler): task instance logs. It creates and delegates log handling to `logging.FileHandler` after receiving task instance context. It reads logs from task instance's host machine. + :param base_log_folder: Base log folder to place logs. + :param filename_template: template filename string """ - - def __init__(self, base_log_folder, filename_template): - """ - :param base_log_folder: Base log folder to place logs. 
- :param filename_template: template filename string - """ + def __init__(self, base_log_folder: str, filename_template: str): super().__init__() - self.handler = None + self.handler = None # type: Optional[logging.FileHandler] self.local_base = base_log_folder self.filename_template, self.filename_jinja_template = \ parse_template_string(filename_template) - def set_context(self, ti): + def set_context(self, ti: TaskInstance): """ Provide task_instance context to airflow task handler. + :param ti: task instance object """ local_loc = self._init_file(ti) self.handler = logging.FileHandler(local_loc) - self.handler.setFormatter(self.formatter) + if self.formatter: + self.handler.setFormatter(self.formatter) self.handler.setLevel(self.level) def emit(self, record): - if self.handler is not None: + if self.handler: self.handler.emit(record) def flush(self): - if self.handler is not None: + if self.handler: self.handler.flush() def close(self): - if self.handler is not None: + if self.handler: self.handler.close() def _render_filename(self, ti, try_number): @@ -79,7 +81,7 @@ def _render_filename(self, ti, try_number): execution_date=ti.execution_date.isoformat(), try_number=try_number) - def _read(self, ti, try_number, metadata=None): + def _read(self, ti, try_number, metadata=None): # pylint: disable=unused-argument """ Template method that contains custom logic of reading logs given the try_number. @@ -102,7 +104,7 @@ def _read(self, ti, try_number, metadata=None): with open(location) as file: log += "*** Reading local file: {}\n".format(location) log += "".join(file.readlines()) - except Exception as e: + except Exception as e: # pylint: disable=broad-except log = "*** Failed to load local log file: {}\n".format(location) log += "*** {}\n".format(str(e)) else: @@ -127,7 +129,7 @@ def _read(self, ti, try_number, metadata=None): response.raise_for_status() log += '\n' + response.text - except Exception as e: + except Exception as e: # pylint: disable=broad-except log += "*** Failed to fetch log file from worker. 
{}\n".format(str(e)) return log, {'end_of_log': True} @@ -159,13 +161,13 @@ def read(self, task_instance, try_number=None, metadata=None): try_numbers = [try_number] logs = [''] * len(try_numbers) - metadatas = [{}] * len(try_numbers) - for i, try_number in enumerate(try_numbers): - log, metadata = self._read(task_instance, try_number, metadata) + metadata_array = [{}] * len(try_numbers) + for i, try_number_element in enumerate(try_numbers): + log, metadata = self._read(task_instance, try_number_element, metadata) logs[i] += log - metadatas[i] = metadata + metadata_array[i] = metadata - return logs, metadatas + return logs, metadata_array def _init_file(self, ti): """ diff --git a/airflow/utils/log/gcs_task_handler.py b/airflow/utils/log/gcs_task_handler.py index c80e53e03199bc..b685d2d0fdbde5 100644 --- a/airflow/utils/log/gcs_task_handler.py +++ b/airflow/utils/log/gcs_task_handler.py @@ -46,7 +46,7 @@ def __init__(self, base_log_folder, gcs_log_folder, filename_template): def hook(self): remote_conn_id = conf.get('core', 'REMOTE_LOG_CONN_ID') try: - from airflow.contrib.hooks.gcs_hook import GoogleCloudStorageHook + from airflow.gcp.hooks.gcs import GoogleCloudStorageHook return GoogleCloudStorageHook( google_cloud_storage_conn_id=remote_conn_id ) diff --git a/airflow/utils/log/logging_mixin.py b/airflow/utils/log/logging_mixin.py index 887ba102d033d1..b7ee7098018ad7 100644 --- a/airflow/utils/log/logging_mixin.py +++ b/airflow/utils/log/logging_mixin.py @@ -16,14 +16,24 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - +import re import logging import sys -import warnings from contextlib import contextmanager from logging import Handler, StreamHandler +# 7-bit C1 ANSI escape sequences +ANSI_ESCAPE = re.compile(r'\x1B[@-_][0-?]*[ -/]*[@-~]') + + +def remove_escape_codes(text: str) -> str: + """ + Remove ANSI escapes codes from string. It's used to remove + "colors" from log messages. + """ + return ANSI_ESCAPE.sub("", text) + class LoggingMixin: """ @@ -32,19 +42,6 @@ class LoggingMixin: def __init__(self, context=None): self._set_context(context) - # We want to deprecate the logger property in Airflow 2.0 - # The log property is the de facto standard in most programming languages - @property - def logger(self): - warnings.warn( - 'Initializing logger for {} using logger(), which will ' - 'be replaced by .log in Airflow 2.0'.format( - self.__class__.__module__ + '.' + self.__class__.__name__ - ), - DeprecationWarning - ) - return self.log - @property def log(self): try: @@ -86,6 +83,12 @@ def closed(self): """ return False + def _propagate_log(self, message): + """ + Propagate message removing escape codes. 
+ """ + self.logger.log(self.level, remove_escape_codes(message)) + def write(self, message): """ Do whatever it takes to actually log the specified logging record @@ -95,7 +98,7 @@ def write(self, message): self._buffer += message else: self._buffer += message - self.logger.log(self.level, self._buffer.rstrip()) + self._propagate_log(self._buffer.rstrip()) self._buffer = str() def flush(self): @@ -103,7 +106,7 @@ def flush(self): Ensure all logging output has been flushed """ if len(self._buffer) > 0: - self.logger.log(self.level, self._buffer) + self._propagate_log(self._buffer) self._buffer = str() def isatty(self): diff --git a/airflow/utils/sqlalchemy.py b/airflow/utils/sqlalchemy.py index 55e018342191af..423cd283c39b27 100644 --- a/airflow/utils/sqlalchemy.py +++ b/airflow/utils/sqlalchemy.py @@ -21,11 +21,9 @@ import os import json import pendulum -import time -import random from dateutil import relativedelta -from sqlalchemy import event, exc, select +from sqlalchemy import event, exc from sqlalchemy.types import Text, DateTime, TypeDecorator from airflow.utils.log.logging_mixin import LoggingMixin @@ -34,73 +32,7 @@ utc = pendulum.timezone('UTC') -def setup_event_handlers(engine, - reconnect_timeout_seconds, - initial_backoff_seconds=0.2, - max_backoff_seconds=120): - @event.listens_for(engine, "engine_connect") - def ping_connection(connection, branch): # pylint: disable=unused-variable - """ - Pessimistic SQLAlchemy disconnect handling. Ensures that each - connection returned from the pool is properly connected to the database. - - http://docs.sqlalchemy.org/en/rel_1_1/core/pooling.html#disconnect-handling-pessimistic - """ - if branch: - # "branch" refers to a sub-connection of a connection, - # we don't want to bother pinging on these. - return - - start = time.time() - backoff = initial_backoff_seconds - - # turn off "close with result". This flag is only used with - # "connectionless" execution, otherwise will be False in any case - save_should_close_with_result = connection.should_close_with_result - - while True: - connection.should_close_with_result = False - - try: - connection.scalar(select([1])) - # If we made it here then the connection appears to be healthy - break - except exc.DBAPIError as err: - if time.time() - start >= reconnect_timeout_seconds: - log.error( - "Failed to re-establish DB connection within %s secs: %s", - reconnect_timeout_seconds, - err) - raise - if err.connection_invalidated: - # Don't log the first time -- this happens a lot and unless - # there is a problem reconnecting is not a sign of a - # problem - if backoff > initial_backoff_seconds: - log.warning("DB connection invalidated. Reconnecting...") - else: - log.debug("DB connection invalidated. Initial reconnect") - - # Use a truncated binary exponential backoff. Also includes - # a jitter to prevent the thundering herd problem of - # simultaneous client reconnects - backoff += backoff * random.random() - time.sleep(min(backoff, max_backoff_seconds)) - - # run the same SELECT again - the connection will re-validate - # itself and establish a new connection. The disconnect detection - # here also causes the whole connection pool to be invalidated - # so that all stale connections are discarded. - continue - else: - log.error( - "Unknown database connection error. 
Not retrying: %s", - err) - raise - finally: - # restore "close with result" - connection.should_close_with_result = save_should_close_with_result - +def setup_event_handlers(engine): @event.listens_for(engine, "connect") def connect(dbapi_connection, connection_record): # pylint: disable=unused-variable connection_record.info['pid'] = os.getpid() diff --git a/airflow/utils/strings.py b/airflow/utils/strings.py index 179d2b15fd29c0..63c44b821259bc 100644 --- a/airflow/utils/strings.py +++ b/airflow/utils/strings.py @@ -2,7 +2,6 @@ Common utility functions with strings ''' # -*- coding: utf-8 -*- -# # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -28,4 +27,4 @@ def get_random_string(length=8, choices=string.ascii_letters + string.digits): ''' Generate random string ''' - return ''.join([choice(choices) for i in range(length)]) + return ''.join([choice(choices) for _ in range(length)]) diff --git a/airflow/www/app.py b/airflow/www/app.py index 543ac41189a375..0ce638e9015b5c 100644 --- a/airflow/www/app.py +++ b/airflow/www/app.py @@ -90,7 +90,6 @@ def create_app(config=None, session=None, testing=False, app_name="Airflow"): configure_manifest_files(app) with app.app_context(): - from airflow.www.security import AirflowSecurityManager security_manager_class = app.config.get('SECURITY_MANAGER_CLASS') or \ AirflowSecurityManager @@ -108,6 +107,9 @@ def create_app(config=None, session=None, testing=False, app_name="Airflow"): def init_views(appbuilder): from airflow.www import views + # Remove the session from scoped_session registry to avoid + # reusing a session with a disconnected connection + appbuilder.session.remove() appbuilder.add_view_no_menu(views.Airflow()) appbuilder.add_view_no_menu(views.DagModelView()) appbuilder.add_view_no_menu(views.ConfigurationView()) diff --git a/airflow/www/security.py b/airflow/www/security.py index 12e93a318c502d..36fe22fd415d07 100644 --- a/airflow/www/security.py +++ b/airflow/www/security.py @@ -340,8 +340,8 @@ def clean_perms(self): pvms = ( sesh.query(sqla_models.PermissionView) .filter(or_( - sqla_models.PermissionView.permission == None, # noqa pylint: disable=singleton-comparison - sqla_models.PermissionView.view_menu == None, # noqa pylint: disable=singleton-comparison + sqla_models.PermissionView.permission is None, # noqa pylint: disable=singleton-comparison + sqla_models.PermissionView.view_menu is None, # noqa pylint: disable=singleton-comparison )) ) # Since FAB doesn't define ON DELETE CASCADE on these tables, we need diff --git a/airflow/www/templates/airflow/dags.html b/airflow/www/templates/airflow/dags.html index 8c17546e1fa132..e98c1e9ee39d49 100644 --- a/airflow/www/templates/airflow/dags.html +++ b/airflow/www/templates/airflow/dags.html @@ -444,7 +444,7 @@

DAGs

}) .attr('fill-opacity', 0) .attr('r', diameter/2) - .attr('title', function(d) {return d.state}) + .attr('title', function(d) {return d.state || 'none'}) .attr('style', function(d) { if (d.count > 0) return"cursor:pointer;" diff --git a/airflow/www/templates/airflow/tree.html b/airflow/www/templates/airflow/tree.html index 9565d5a42e7e4e..d968b82d73033f 100644 --- a/airflow/www/templates/airflow/tree.html +++ b/airflow/www/templates/airflow/tree.html @@ -82,14 +82,15 @@