From d15f4ed71dc388f37f782f0186cfe4d743f1de00 Mon Sep 17 00:00:00 2001
From: Akshesh
Date: Wed, 26 Jun 2019 14:52:49 +0700
Subject: [PATCH] [AIRFLOW-4843] Allow orchestration via Docker Swarm (SwarmOperator)

Add support for running Docker containers via Docker Swarm
which allows the task to run on any machine (node) which
is a part of your Swarm cluster

More details: https://issues.apache.org/jira/browse/AIRFLOW-4843
---
 .../example_dags/example_swarm_operator.py    |  59 ++++++
 airflow/operators/docker_operator.py          |  55 +++--
 airflow/operators/swarm_operator.py           | 195 ++++++++++++++++++
 airflow/utils/strings.py                      |  40 ++++
 tests/operators/test_swarm_operator.py        | 152 ++++++++++++++
 5 files changed, 481 insertions(+), 20 deletions(-)
 create mode 100644 airflow/example_dags/example_swarm_operator.py
 create mode 100644 airflow/operators/swarm_operator.py
 create mode 100644 airflow/utils/strings.py
 create mode 100644 tests/operators/test_swarm_operator.py

diff --git a/airflow/example_dags/example_swarm_operator.py b/airflow/example_dags/example_swarm_operator.py
new file mode 100644
index 0000000000000..0bbc5835d51d9
--- /dev/null
+++ b/airflow/example_dags/example_swarm_operator.py
@@ -0,0 +1,59 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Example DAG demonstrating the SwarmOperator.
+
+The DAG definition below is kept inside this docstring, so it is not executed
+when this module is imported: running it needs the ``docker`` package and a
+reachable Docker Swarm manager. To try it, copy the code into your own DAG
+file and adjust ``docker_url`` and ``image`` for your environment.
+
+from datetime import timedelta
+
+from airflow import DAG
+from airflow.operators.swarm_operator import SwarmOperator
+from airflow.utils.dates import days_ago
+
+default_args = {
+    'owner': 'airflow',
+    'depends_on_past': False,
+    'start_date': days_ago(1),
+    'email': ['airflow@example.com'],
+    'email_on_failure': False,
+    'email_on_retry': False
+}
+
+dag = DAG(
+    'swarm_sample',
+    default_args=default_args,
+    schedule_interval=timedelta(minutes=10),
+    catchup=False
+)
+
+with dag:
+    t1 = SwarmOperator(
+        api_version='auto',
+        docker_url='tcp://localhost:2375',  # Set your docker URL
+        command='sleep 10',
+        image='centos:latest',
+        auto_remove=True,
+        task_id='sleep_with_swarm',
+    )
+"""
diff --git a/airflow/operators/docker_operator.py b/airflow/operators/docker_operator.py
index 08f263f48ba42..0fdfe236e6b64 100644
--- a/airflow/operators/docker_operator.py
+++ b/airflow/operators/docker_operator.py
@@ -190,29 +190,15 @@ def get_hook(self):
             tls=self.__get_tls_config()
         )
 
-    def execute(self, context):
+    def _execute(self):
+        """
+        Run the image as a container; ``self.cli`` must already be set.
+
+        Subclasses override this to run the image in another way.
+        """
         self.log.info('Starting docker container from image %s', self.image)
 
-        tls_config = self.__get_tls_config()
-
-        if self.docker_conn_id:
-            self.cli = self.get_hook().get_conn()
-        else:
-            self.cli = APIClient(
-                base_url=self.docker_url,
-                version=self.api_version,
-                tls=tls_config
-            )
-
-        if self.force_pull or len(self.cli.images(name=self.image)) == 0:
-            self.log.info('Pulling docker image %s', self.image)
-            for l in self.cli.pull(self.image, stream=True):
-                output = json.loads(l.decode('utf-8').strip())
-                if 'status' in output:
-                    self.log.info("%s", output['status'])
-
         with TemporaryDirectory(prefix='airflowtmp', dir=self.host_tmp_dir) as host_tmp_dir:
-            self.environment['AIRFLOW_TMP_DIR'] = self.tmp_dir
             self.volumes.append('{0}:{1}'.format(host_tmp_dir, self.tmp_dir))
 
             self.container = self.cli.create_container(
@@ -248,6 +234,35 @@ def execute(self, context):
                 return self.cli.logs(container=self.container['Id']) \
                     if self.xcom_all else str(line)
 
+    def execute(self, context):
+        """
+        Connect to the docker daemon, pull the image if needed and run it.
+
+        :return: the result of ``_execute``; it has to be returned here so that
+            the caller still receives it now that the work is split in two.
+        """
+        tls_config = self.__get_tls_config()
+
+        if self.docker_conn_id:
+            self.cli = self.get_hook().get_conn()
+        else:
+            self.cli = APIClient(
+                base_url=self.docker_url,
+                version=self.api_version,
+                tls=tls_config
+            )
+
+        if self.force_pull or len(self.cli.images(name=self.image)) == 0:
+            self.log.info('Pulling docker image %s', self.image)
+            for l in self.cli.pull(self.image, stream=True):
+                output = json.loads(l.decode('utf-8').strip())
+                if 'status' in output:
+                    self.log.info("%s", output['status'])
+
+        self.environment['AIRFLOW_TMP_DIR'] = self.tmp_dir
+
+        return self._execute()
+
     def get_command(self):
         if self.command is not None and self.command.strip().find('[') == 0:
             commands = ast.literal_eval(self.command)
diff --git a/airflow/operators/swarm_operator.py b/airflow/operators/swarm_operator.py
new file mode 100644
index 0000000000000..df9d79019f2d3
--- /dev/null
+++ b/airflow/operators/swarm_operator.py
@@ -0,0 +1,195 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Operator that runs a command as an ephemeral Docker Swarm service."""
+
+import time
+
+from docker import types
+
+from airflow.exceptions import AirflowException
+from airflow.operators.docker_operator import DockerOperator
+from airflow.utils.decorators import apply_defaults
+from airflow.utils.strings import get_random_string
+
+
+class SwarmOperator(DockerOperator):
+    """
+    Execute a command as an ephemeral docker swarm service.
+    Example use-case - Using Docker Swarm orchestration to make one-time
+    scripts highly available.
+
+    The service is created with a restart policy of ``none`` and the operator
+    polls the service's task until it reaches a terminal state.
+
+    The value of ``tmp_dir`` is passed to the service in the environment
+    variable ``AIRFLOW_TMP_DIR``. Unlike ``DockerOperator``, this operator
+    does not mount a host directory at that path.
+
+    If a login to a private registry is required prior to pulling the image, a
+    Docker connection needs to be configured in Airflow and the connection ID
+    be provided with the parameter ``docker_conn_id``.
+
+    :param image: Docker image from which to create the container.
+        If image tag is omitted, "latest" will be used.
+    :type image: str
+    :param api_version: Remote API version. Set to ``auto`` to automatically
+        detect the server's version.
+    :type api_version: str
+    :param auto_remove: Remove the swarm service once its task has finished.
+        The default is False.
+    :type auto_remove: bool
+    :param command: Command to be run in the container. (templated)
+    :type command: str or list
+    :param docker_url: URL of the host running the docker daemon.
+        Default is unix://var/run/docker.sock
+    :type docker_url: str
+    :param environment: Environment variables to set in the container. (templated)
+    :type environment: dict
+    :param force_pull: Pull the docker image on every run. Default is False.
+    :type force_pull: bool
+    :param mem_limit: Maximum amount of memory the container can use.
+        Either a float value, which represents the limit in bytes,
+        or a string like ``128m`` or ``1g``.
+    :type mem_limit: float or str
+    :param tls_ca_cert: Path to a PEM-encoded certificate authority
+        to secure the docker connection.
+    :type tls_ca_cert: str
+    :param tls_client_cert: Path to the PEM-encoded certificate
+        used to authenticate docker client.
+    :type tls_client_cert: str
+    :param tls_client_key: Path to the PEM-encoded key used to authenticate docker client.
+    :type tls_client_key: str
+    :param tls_hostname: Hostname to match against
+        the docker server certificate or False to disable the check.
+    :type tls_hostname: str or bool
+    :param tls_ssl_version: Version of SSL to use when communicating with docker daemon.
+    :type tls_ssl_version: str
+    :param tmp_dir: Value of the environment variable ``AIRFLOW_TMP_DIR``
+        inside the container.
+    :type tmp_dir: str
+    :param user: Default user inside the docker container.
+    :type user: int or str
+    :param docker_conn_id: ID of the Airflow connection to use
+    :type docker_conn_id: str
+    :param poll_interval: Seconds to sleep between two checks of the state of
+        the service's task. Default is 1.
+    :type poll_interval: float
+    """
+
+    # Task states after which the task does not run any further.
+    TERMINAL_STATES = ('complete', 'failed', 'shutdown', 'rejected', 'orphaned', 'remove')
+
+    @apply_defaults
+    def __init__(
+            self,
+            image,
+            api_version=None,
+            command=None,
+            docker_url='unix://var/run/docker.sock',
+            environment=None,
+            force_pull=False,
+            mem_limit=None,
+            tls_ca_cert=None,
+            tls_client_cert=None,
+            tls_client_key=None,
+            tls_hostname=None,
+            tls_ssl_version=None,
+            tmp_dir='/tmp/airflow',
+            user=None,
+            docker_conn_id=None,
+            auto_remove=False,
+            poll_interval=1,
+            *args,
+            **kwargs):
+
+        super().__init__(*args, image=image, **kwargs)
+        self.api_version = api_version
+        self.auto_remove = auto_remove
+        self.command = command
+        self.docker_url = docker_url
+        self.environment = environment or {}
+        self.force_pull = force_pull
+        self.image = image
+        self.mem_limit = mem_limit
+        self.tls_ca_cert = tls_ca_cert
+        self.tls_client_cert = tls_client_cert
+        self.tls_client_key = tls_client_key
+        self.tls_hostname = tls_hostname
+        self.tls_ssl_version = tls_ssl_version
+        self.tmp_dir = tmp_dir
+        self.user = user
+        self.docker_conn_id = docker_conn_id
+        self.poll_interval = poll_interval
+
+        self.cli = None
+        self.service = None
+
+    def _execute(self):
+        """
+        Create the swarm service and wait for its task to finish.
+
+        :raises AirflowException: if the task ends in a state other than
+            ``complete``.
+        """
+        self.log.info('Starting docker service from image %s', self.image)
+
+        self.service = self.cli.create_service(
+            types.TaskTemplate(
+                container_spec=types.ContainerSpec(
+                    image=self.image,
+                    command=self.get_command(),
+                    env=self.environment,
+                    user=self.user
+                ),
+                restart_policy=types.RestartPolicy(condition='none'),
+                resources=types.Resources(mem_limit=self.mem_limit)
+            ),
+            name='airflow-%s' % get_random_string(),
+            labels={'name': 'airflow__%s__%s' % (self.dag_id, self.task_id)}
+        )
+
+        self.log.info('Service started: %s', self.service)
+
+        status = self._wait_for_terminal_state()
+        self.log.info('Service status before exiting: %s', status)
+
+        if self.auto_remove:
+            self.cli.remove_service(self.service['ID'])
+        if status != 'complete':
+            raise AirflowException('Service failed: ' + repr(self.service))
+
+    def _wait_for_terminal_state(self):
+        """Poll the service's first task until it stops and return its state."""
+        service_filter = {'service': self.service['ID']}
+        while True:
+            # The task list is empty until the swarm has created the task.
+            tasks = self.cli.tasks(filters=service_filter)
+            if tasks:
+                status = tasks[0]['Status']['State']
+                if status in self.TERMINAL_STATES:
+                    return status
+            # Pause between polls so the docker API is not queried in a tight loop.
+            time.sleep(self.poll_interval)
+
+    def on_kill(self):
+        """Remove the swarm service, if one was created, when the task is killed."""
+        if self.cli is not None and self.service is not None:
+            self.log.info('Removing docker service: %s', self.service['ID'])
+            self.cli.remove_service(self.service['ID'])
diff --git a/airflow/utils/strings.py b/airflow/utils/strings.py
new file mode 100644
index 0000000000000..00c46e788a233
--- /dev/null
+++ b/airflow/utils/strings.py
@@ -0,0 +1,40 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Helpers for working with strings."""
+
+import string
+from random import choice
+
+
+def get_random_string(length=10, choices=string.ascii_letters + string.digits):
+    """
+    Generate a random string.
+
+    The characters are picked with ``random.choice``, so the result must not
+    be used where an unguessable value is required.
+
+    :param length: Number of characters in the result.
+    :type length: int
+    :param choices: Characters to draw from.
+    :type choices: str
+    :return: The generated string.
+    :rtype: str
+    """
+    return ''.join(choice(choices) for _ in range(length))
diff --git a/tests/operators/test_swarm_operator.py b/tests/operators/test_swarm_operator.py
new file mode 100644
index 0000000000000..02278e3127388
--- /dev/null
+++ b/tests/operators/test_swarm_operator.py
@@ -0,0 +1,152 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import unittest
+
+try:
+    # docker-py is optional: without it the test case below is skipped
+    from airflow.operators.swarm_operator import SwarmOperator
+    from docker import APIClient
+except ImportError:
+    SwarmOperator = None
+    APIClient = None
+
+from airflow.exceptions import AirflowException
+from tests.compat import mock
+
+
+@unittest.skipIf(SwarmOperator is None, 'docker-py is not installed')
+class SwarmOperatorTestCase(unittest.TestCase):
+
+    @staticmethod
+    def _setup_mocks(types_mock, client_class_mock):
+        """Return a fake docker client and the object every spec type returns."""
+        mock_obj = mock.Mock()
+
+        client_mock = mock.Mock(spec=APIClient)
+        client_mock.create_service.return_value = {'ID': 'akki'}
+        client_mock.images.return_value = []
+        client_mock.pull.return_value = [b'{"status":"pull log"}']
+        types_mock.TaskTemplate.return_value = mock_obj
+        types_mock.ContainerSpec.return_value = mock_obj
+        types_mock.RestartPolicy.return_value = mock_obj
+        types_mock.Resources.return_value = mock_obj
+
+        client_class_mock.return_value = client_mock
+        return client_mock, mock_obj
+
+    @mock.patch('airflow.operators.docker_operator.APIClient')
+    @mock.patch('airflow.operators.swarm_operator.types')
+    def test_execute(self, types_mock, client_class_mock):
+        client_mock, mock_obj = self._setup_mocks(types_mock, client_class_mock)
+        # no task yet, then a task that is still pending, then a finished one
+        client_mock.tasks.side_effect = [
+            [],
+            [{'Status': {'State': 'pending'}}],
+            [{'Status': {'State': 'complete'}}],
+        ]
+
+        operator = SwarmOperator(
+            api_version='1.19', command='env', environment={'UNIT': 'TEST'}, image='ubuntu:latest',
+            mem_limit='128m', user='unittest', task_id='unittest', auto_remove=True,
+            poll_interval=0
+        )
+        operator.execute(None)
+
+        types_mock.TaskTemplate.assert_called_with(
+            container_spec=mock_obj, restart_policy=mock_obj, resources=mock_obj
+        )
+        types_mock.ContainerSpec.assert_called_with(
+            image='ubuntu:latest', command='env', user='unittest',
+            env={'UNIT': 'TEST', 'AIRFLOW_TMP_DIR': '/tmp/airflow'}
+        )
+        types_mock.RestartPolicy.assert_called_with(condition='none')
+        types_mock.Resources.assert_called_with(mem_limit='128m')
+
+        client_class_mock.assert_called_with(
+            base_url='unix://var/run/docker.sock', tls=None, version='1.19'
+        )
+
+        csargs, cskwargs = client_mock.create_service.call_args_list[0]
+        self.assertEqual(
+            len(csargs), 1, 'create_service called with different number of arguments than expected'
+        )
+        self.assertEqual(csargs, (mock_obj, ))
+        self.assertEqual(cskwargs['labels'], {'name': 'airflow__adhoc_airflow__unittest'})
+        self.assertTrue(cskwargs['name'].startswith('airflow-'))
+        self.assertEqual(client_mock.tasks.call_count, 3)
+        client_mock.remove_service.assert_called_with('akki')
+
+    @mock.patch('airflow.operators.docker_operator.APIClient')
+    @mock.patch('airflow.operators.swarm_operator.types')
+    def test_no_auto_remove(self, types_mock, client_class_mock):
+        client_mock, _ = self._setup_mocks(types_mock, client_class_mock)
+        client_mock.tasks.return_value = [{'Status': {'State': 'complete'}}]
+
+        operator = SwarmOperator(image='', auto_remove=False, task_id='unittest')
+        operator.execute(None)
+
+        self.assertEqual(
+            client_mock.remove_service.call_count, 0,
+            'Docker service being removed even when `auto_remove` set to `False`'
+        )
+
+    @mock.patch('airflow.operators.docker_operator.APIClient')
+    @mock.patch('airflow.operators.swarm_operator.types')
+    def test_failed_service_raises_error(self, types_mock, client_class_mock):
+        client_mock, _ = self._setup_mocks(types_mock, client_class_mock)
+        client_mock.tasks.return_value = [{'Status': {'State': 'failed'}}]
+
+        operator = SwarmOperator(image='', auto_remove=False, task_id='unittest')
+        msg = "Service failed: {'ID': 'akki'}"
+        with self.assertRaises(AirflowException) as error:
+            operator.execute(None)
+        self.assertEqual(str(error.exception), msg)
+
+    @mock.patch('airflow.operators.docker_operator.APIClient')
+    @mock.patch('airflow.operators.swarm_operator.types')
+    def test_rejected_service_raises_error(self, types_mock, client_class_mock):
+        client_mock, _ = self._setup_mocks(types_mock, client_class_mock)
+        # 'rejected' is terminal too: the operator must fail instead of polling forever
+        client_mock.tasks.return_value = [{'Status': {'State': 'rejected'}}]
+
+        operator = SwarmOperator(image='', auto_remove=False, task_id='unittest')
+        with self.assertRaises(AirflowException):
+            operator.execute(None)
+
+    def test_on_kill(self):
+        client_mock = mock.Mock(spec=APIClient)
+
+        operator = SwarmOperator(image='', auto_remove=False, task_id='unittest')
+        operator.cli = client_mock
+        operator.service = {'ID': 'akki'}
+
+        operator.on_kill()
+
+        client_mock.remove_service.assert_called_with('akki')
+
+    def test_on_kill_before_service_is_created(self):
+        client_mock = mock.Mock(spec=APIClient)
+
+        operator = SwarmOperator(image='', auto_remove=False, task_id='unittest')
+        operator.cli = client_mock
+
+        operator.on_kill()
+
+        self.assertEqual(client_mock.remove_service.call_count, 0)