Skip to content

Commit

Permalink
Add args to docker service ContainerSpec (apache#39464)
Browse files Browse the repository at this point in the history
* Add args to docker service ContainerSpec

* args is a list in ContainerSpec

* fix ContainerSpec assertion

* fix args formatter

* fix ContainerSpec assert

* remove some spaces

* add docker service args list test case

* replace ast.literal_eval with json.loads

* remove json string representation

---------

Co-authored-by: Guy Driesen <[email protected]>
  • Loading branch information
2 people authored and romsharon98 committed Jul 26, 2024
1 parent eb53295 commit ecb5b98
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 0 deletions.
19 changes: 19 additions & 0 deletions airflow/providers/docker/operators/docker_swarm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from __future__ import annotations

import re
import shlex
from datetime import datetime
from time import sleep
from typing import TYPE_CHECKING
Expand Down Expand Up @@ -58,6 +59,7 @@ class DockerSwarmOperator(DockerOperator):
container's process exits.
The default is False.
:param command: Command to be run in the container. (templated)
:param args: Arguments to the command.
:param docker_url: URL of the host running the docker daemon.
Default is the value of the ``DOCKER_HOST`` environment variable or unix://var/run/docker.sock
if it is unset.
Expand Down Expand Up @@ -106,6 +108,7 @@ def __init__(
self,
*,
image: str,
args: str | list[str] | None = None,
enable_logging: bool = True,
configs: list[types.ConfigReference] | None = None,
secrets: list[types.SecretReference] | None = None,
Expand All @@ -116,6 +119,7 @@ def __init__(
**kwargs,
) -> None:
super().__init__(image=image, **kwargs)
self.args = args
self.enable_logging = enable_logging
self.service = None
self.configs = configs
Expand All @@ -136,6 +140,7 @@ def _run_service(self) -> None:
container_spec=types.ContainerSpec(
image=self.image,
command=self.format_command(self.command),
args=self.format_args(self.args),
mounts=self.mounts,
env=self.environment,
user=self.user,
Expand Down Expand Up @@ -225,6 +230,20 @@ def stream_new_logs(last_line_logged, since=0):
sleep(2)
last_line_logged, last_timestamp = stream_new_logs(last_line_logged, since=last_timestamp)

@staticmethod
def format_args(args: list[str] | str | None) -> list[str] | None:
"""Retrieve args.
The args string is parsed to a list.
:param args: args to the docker service
:return: the args as list
"""
if isinstance(args, str):
return shlex.split(args)
return args

def on_kill(self) -> None:
if self.hook.client_created and self.service is not None:
self.log.info("Removing docker service: %s", self.service["ID"])
Expand Down
77 changes: 77 additions & 0 deletions tests/providers/docker/operators/test_docker_swarm.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def _client_service_logs_effect():
types_mock.ContainerSpec.assert_called_once_with(
image="ubuntu:latest",
command="env",
args=None,
user="unittest",
mounts=[types.Mount(source="/host/path", target="/container/path", type="bind")],
tty=True,
Expand Down Expand Up @@ -254,3 +255,79 @@ def test_container_resources(self, types_mock, docker_api_client_patcher):
placement=None,
)
types_mock.Resources.assert_not_called()

@mock.patch("airflow.providers.docker.operators.docker_swarm.types")
def test_service_args_str(self, types_mock, docker_api_client_patcher):
mock_obj = mock.Mock()

client_mock = mock.Mock(spec=APIClient)
client_mock.create_service.return_value = {"ID": "some_id"}
client_mock.images.return_value = []
client_mock.pull.return_value = [b'{"status":"pull log"}']
client_mock.tasks.return_value = [{"Status": {"State": "complete"}}]
types_mock.TaskTemplate.return_value = mock_obj
types_mock.ContainerSpec.return_value = mock_obj
types_mock.RestartPolicy.return_value = mock_obj
types_mock.Resources.return_value = mock_obj

docker_api_client_patcher.return_value = client_mock

operator = DockerSwarmOperator(
image="ubuntu:latest",
command="env",
args="--show",
task_id="unittest",
auto_remove="success",
enable_logging=False,
)
operator.execute(None)

types_mock.ContainerSpec.assert_called_once_with(
image="ubuntu:latest",
command="env",
args=["--show"],
user=None,
mounts=[],
tty=False,
env={"AIRFLOW_TMP_DIR": "/tmp/airflow"},
configs=None,
secrets=None,
)

@mock.patch("airflow.providers.docker.operators.docker_swarm.types")
def test_service_args_list(self, types_mock, docker_api_client_patcher):
mock_obj = mock.Mock()

client_mock = mock.Mock(spec=APIClient)
client_mock.create_service.return_value = {"ID": "some_id"}
client_mock.images.return_value = []
client_mock.pull.return_value = [b'{"status":"pull log"}']
client_mock.tasks.return_value = [{"Status": {"State": "complete"}}]
types_mock.TaskTemplate.return_value = mock_obj
types_mock.ContainerSpec.return_value = mock_obj
types_mock.RestartPolicy.return_value = mock_obj
types_mock.Resources.return_value = mock_obj

docker_api_client_patcher.return_value = client_mock

operator = DockerSwarmOperator(
image="ubuntu:latest",
command="env",
args=["--show"],
task_id="unittest",
auto_remove="success",
enable_logging=False,
)
operator.execute(None)

types_mock.ContainerSpec.assert_called_once_with(
image="ubuntu:latest",
command="env",
args=["--show"],
user=None,
mounts=[],
tty=False,
env={"AIRFLOW_TMP_DIR": "/tmp/airflow"},
configs=None,
secrets=None,
)

0 comments on commit ecb5b98

Please sign in to comment.