diff --git a/airflow/contrib/executors/kubernetes_executor.py b/airflow/contrib/executors/kubernetes_executor.py new file mode 100644 index 00000000000000..416b2d76ff29da --- /dev/null +++ b/airflow/contrib/executors/kubernetes_executor.py @@ -0,0 +1,20 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from airflow.executors import kubernetes_executor # noqa diff --git a/airflow/contrib/kubernetes/pod.py b/airflow/contrib/kubernetes/pod.py index 0ab36160ba5275..0ce580037498e4 100644 --- a/airflow/contrib/kubernetes/pod.py +++ b/airflow/contrib/kubernetes/pod.py @@ -19,7 +19,18 @@ import warnings # pylint: disable=unused-import -from airflow.kubernetes.pod import Port, Resources # noqa +from typing import List, Union + +from kubernetes.client import models as k8s + +from airflow.kubernetes.pod import Port, Resources # noqa +from airflow.kubernetes.volume import Volume +from airflow.kubernetes.volume_mount import VolumeMount +from airflow.kubernetes.secret import Secret + +from kubernetes.client.api_client import ApiClient + +api_client = ApiClient() warnings.warn( "This module is deprecated. Please use `airflow.kubernetes.pod`.", @@ -120,7 +131,7 @@ def __init__( self.affinity = affinity or {} self.hostnetwork = hostnetwork or False self.tolerations = tolerations or [] - self.security_context = security_context + self.security_context = security_context or {} self.configmaps = configmaps or [] self.pod_runtime_info_envs = pod_runtime_info_envs or [] self.dnspolicy = dnspolicy @@ -154,6 +165,7 @@ def to_v1_kubernetes_pod(self): dns_policy=self.dnspolicy, host_network=self.hostnetwork, tolerations=self.tolerations, + affinity=self.affinity, security_context=self.security_context, ) @@ -161,17 +173,18 @@ def to_v1_kubernetes_pod(self): spec=spec, metadata=meta, ) - for port in self.ports: + for port in _extract_ports(self.ports): pod = port.attach_to_pod(pod) - for volume in self.volumes: + volumes, _ = _extract_volumes_and_secrets(self.volumes, self.volume_mounts) + for volume in volumes: pod = volume.attach_to_pod(pod) - for volume_mount in self.volume_mounts: + for volume_mount in _extract_volume_mounts(self.volume_mounts): pod = volume_mount.attach_to_pod(pod) for secret in self.secrets: pod = secret.attach_to_pod(pod) for runtime_info in self.pod_runtime_info_envs: pod = runtime_info.attach_to_pod(pod) - pod = self.resources.attach_to_pod(pod) + pod = _extract_resources(self.resources).attach_to_pod(pod) return pod def as_dict(self): @@ -182,3 +195,115 @@ def as_dict(self): res['volumes'] = [volume.as_dict() for volume in res['volumes']] return res + + +def _extract_env_vars_and_secrets(env_vars): + result = {} + env_vars = env_vars or [] # type: List[Union[k8s.V1EnvVar, dict]] + secrets = [] + for env_var in env_vars: + if isinstance(env_var, k8s.V1EnvVar): + secret = _extract_env_secret(env_var) + if secret: + secrets.append(secret) + continue + env_var = api_client.sanitize_for_serialization(env_var) + result[env_var.get("name")] = env_var.get("value") + return result, secrets + + +def _extract_env_secret(env_var): + if env_var.value_from and env_var.value_from.secret_key_ref: + secret = env_var.value_from.secret_key_ref # type: k8s.V1SecretKeySelector + name = secret.name + key = secret.key + return Secret("env", deploy_target=env_var.name, secret=name, key=key) + return None + + +def _extract_ports(ports): + result = [] + ports = ports or [] # type: List[Union[k8s.V1ContainerPort, dict]] + for port in ports: + if isinstance(port, k8s.V1ContainerPort): + port = api_client.sanitize_for_serialization(port) + port = Port(name=port.get("name"), container_port=port.get("containerPort")) + elif not isinstance(port, Port): + port = Port(name=port.get("name"), container_port=port.get("containerPort")) + result.append(port) + return result + + +def _extract_resources(resources): + if isinstance(resources, k8s.V1ResourceRequirements): + requests = resources.requests + limits = resources.limits + return Resources( + request_memory=requests.get('memory', None), + request_cpu=requests.get('cpu', None), + request_ephemeral_storage=requests.get('ephemeral-storage', None), + limit_memory=limits.get('memory', None), + limit_cpu=limits.get('cpu', None), + limit_ephemeral_storage=limits.get('ephemeral-storage', None), + limit_gpu=limits.get('nvidia.com/gpu') + ) + elif isinstance(resources, Resources): + return resources + + +def _extract_security_context(security_context): + if isinstance(security_context, k8s.V1PodSecurityContext): + security_context = api_client.sanitize_for_serialization(security_context) + return security_context + + +def _extract_volume_mounts(volume_mounts): + result = [] + volume_mounts = volume_mounts or [] # type: List[Union[k8s.V1VolumeMount, dict]] + for volume_mount in volume_mounts: + if isinstance(volume_mount, k8s.V1VolumeMount): + volume_mount = api_client.sanitize_for_serialization(volume_mount) + volume_mount = VolumeMount( + name=volume_mount.get("name"), + mount_path=volume_mount.get("mountPath"), + sub_path=volume_mount.get("subPath"), + read_only=volume_mount.get("readOnly") + ) + elif not isinstance(volume_mount, VolumeMount): + volume_mount = VolumeMount( + name=volume_mount.get("name"), + mount_path=volume_mount.get("mountPath"), + sub_path=volume_mount.get("subPath"), + read_only=volume_mount.get("readOnly") + ) + + result.append(volume_mount) + return result + + +def _extract_volumes_and_secrets(volumes, volume_mounts): + result = [] + volumes = volumes or [] # type: List[Union[k8s.V1Volume, dict]] + secrets = [] + volume_mount_dict = { + volume_mount.name: volume_mount + for volume_mount in _extract_volume_mounts(volume_mounts) + } + for volume in volumes: + if isinstance(volume, k8s.V1Volume): + secret = _extract_volume_secret(volume, volume_mount_dict.get(volume.name, None)) + if secret: + secrets.append(secret) + continue + volume = api_client.sanitize_for_serialization(volume) + volume = Volume(name=volume.get("name"), configs=volume) + if not isinstance(volume, Volume): + volume = Volume(name=volume.get("name"), configs=volume) + result.append(volume) + return result, secrets + + +def _extract_volume_secret(volume, volume_mount): + if not volume.secret: + return None + return Secret("volume", volume_mount.mount_path, volume.name, volume.secret.secret_name) diff --git a/airflow/executors/kubernetes_executor.py b/airflow/executors/kubernetes_executor.py index 7bbdc985ebcee6..3ad4222b59135f 100644 --- a/airflow/executors/kubernetes_executor.py +++ b/airflow/executors/kubernetes_executor.py @@ -417,6 +417,12 @@ def run_next(self, next_job): kube_executor_config=kube_executor_config, worker_config=self.worker_configuration_pod ) + + sanitized_pod = self.launcher._client.api_client.sanitize_for_serialization(pod) + json_pod = json.dumps(sanitized_pod, indent=2) + + self.log.debug('Pod Creation Request before mutation: \n%s', json_pod) + # Reconcile the pod generated by the Operator and the Pod # generated by the .cfg file self.log.debug("Kubernetes running for command %s", command) diff --git a/airflow/kubernetes/pod.py b/airflow/kubernetes/pod.py index 9e455afe405e54..67dc98348263dd 100644 --- a/airflow/kubernetes/pod.py +++ b/airflow/kubernetes/pod.py @@ -20,7 +20,7 @@ import copy -import kubernetes.client.models as k8s +from kubernetes.client import models as k8s from airflow.kubernetes.k8s_model import K8SModel @@ -87,18 +87,25 @@ def has_requests(self): self.request_ephemeral_storage is not None def to_k8s_client_obj(self): - return k8s.V1ResourceRequirements( - limits={ - 'cpu': self.limit_cpu, - 'memory': self.limit_memory, - 'nvidia.com/gpu': self.limit_gpu, - 'ephemeral-storage': self.limit_ephemeral_storage - }, - requests={ - 'cpu': self.request_cpu, - 'memory': self.request_memory, - 'ephemeral-storage': self.request_ephemeral_storage} + limits_raw = { + 'cpu': self.limit_cpu, + 'memory': self.limit_memory, + 'nvidia.com/gpu': self.limit_gpu, + 'ephemeral-storage': self.limit_ephemeral_storage + } + requests_raw = { + 'cpu': self.request_cpu, + 'memory': self.request_memory, + 'ephemeral-storage': self.request_ephemeral_storage + } + + limits = {k: v for k, v in limits_raw.items() if v} + requests = {k: v for k, v in requests_raw.items() if v} + resource_req = k8s.V1ResourceRequirements( + limits=limits, + requests=requests ) + return resource_req def attach_to_pod(self, pod): cp_pod = copy.deepcopy(pod) diff --git a/airflow/kubernetes/pod_generator.py b/airflow/kubernetes/pod_generator.py index d11c1755716cd2..090e2b14c9f1e0 100644 --- a/airflow/kubernetes/pod_generator.py +++ b/airflow/kubernetes/pod_generator.py @@ -36,6 +36,7 @@ import kubernetes.client.models as k8s import yaml from kubernetes.client.api_client import ApiClient +from airflow.contrib.kubernetes.pod import _extract_volume_mounts from airflow.exceptions import AirflowConfigException from airflow.version import version as airflow_version @@ -249,7 +250,7 @@ def __init__( self.container.image_pull_policy = image_pull_policy self.container.ports = ports or [] self.container.resources = resources - self.container.volume_mounts = volume_mounts or [] + self.container.volume_mounts = [v.to_k8s_client_obj() for v in _extract_volume_mounts(volume_mounts)] # Pod Spec self.spec = k8s.V1PodSpec(containers=[]) @@ -370,6 +371,11 @@ def extract(cpu, memory, ephemeral_storage, limit_gpu=None): requests=requests, limits=limits ) + elif isinstance(resources, dict): + resources = k8s.V1ResourceRequirements( + requests=resources['requests'], + limits=resources['limits'] + ) annotations = namespaced.get('annotations', {}) gcp_service_account_key = namespaced.get('gcp_service_account_key', None) @@ -402,12 +408,35 @@ def reconcile_pods(base_pod, client_pod): client_pod_cp = copy.deepcopy(client_pod) client_pod_cp.spec = PodGenerator.reconcile_specs(base_pod.spec, client_pod_cp.spec) - - client_pod_cp.metadata = merge_objects(base_pod.metadata, client_pod_cp.metadata) + client_pod_cp.metadata = PodGenerator.reconcile_metadata(base_pod.metadata, client_pod_cp.metadata) client_pod_cp = merge_objects(base_pod, client_pod_cp) return client_pod_cp + @staticmethod + def reconcile_metadata(base_meta, client_meta): + """ + :param base_meta: has the base attributes which are overwritten if they exist + in the client_meta and remain if they do not exist in the client_meta + :type base_meta: k8s.V1ObjectMeta + :param client_meta: the spec that the client wants to create. + :type client_meta: k8s.V1ObjectMeta + :return: the merged specs + """ + if base_meta and not client_meta: + return base_meta + if not base_meta and client_meta: + return client_meta + elif client_meta and base_meta: + client_meta.labels = merge_objects(base_meta.labels, client_meta.labels) + client_meta.annotations = merge_objects(base_meta.annotations, client_meta.annotations) + extend_object_field(base_meta, client_meta, 'managed_fields') + extend_object_field(base_meta, client_meta, 'finalizers') + extend_object_field(base_meta, client_meta, 'owner_references') + return merge_objects(base_meta, client_meta) + + return None + @staticmethod def reconcile_specs(base_spec, client_spec): @@ -580,10 +609,17 @@ def merge_objects(base_obj, client_obj): client_obj_cp = copy.deepcopy(client_obj) + if isinstance(base_obj, dict) and isinstance(client_obj_cp, dict): + client_obj_cp.update(base_obj) + return client_obj_cp + for base_key in base_obj.to_dict().keys(): base_val = getattr(base_obj, base_key, None) if not getattr(client_obj, base_key, None) and base_val: - setattr(client_obj_cp, base_key, base_val) + if not isinstance(client_obj_cp, dict): + setattr(client_obj_cp, base_key, base_val) + else: + client_obj_cp[base_key] = base_val return client_obj_cp @@ -610,6 +646,36 @@ def extend_object_field(base_obj, client_obj, field_name): setattr(client_obj_cp, field_name, base_obj_field) return client_obj_cp - appended_fields = base_obj_field + client_obj_field + base_obj_set = _get_dict_from_list(base_obj_field) + client_obj_set = _get_dict_from_list(client_obj_field) + + appended_fields = _merge_list_of_objects(base_obj_set, client_obj_set) + setattr(client_obj_cp, field_name, appended_fields) return client_obj_cp + + +def _merge_list_of_objects(base_obj_set, client_obj_set): + for k, v in base_obj_set.items(): + if k not in client_obj_set: + client_obj_set[k] = v + else: + client_obj_set[k] = merge_objects(v, client_obj_set[k]) + appended_field_keys = sorted(client_obj_set.keys()) + appended_fields = [client_obj_set[k] for k in appended_field_keys] + return appended_fields + + +def _get_dict_from_list(base_list): + """ + :type base_list: list(Optional[dict, *to_dict]) + """ + result = {} + for obj in base_list: + if isinstance(obj, dict): + result[obj['name']] = obj + elif hasattr(obj, "to_dict"): + result[obj.name] = obj + else: + raise AirflowConfigException("Trying to merge invalid object {}".format(obj)) + return result diff --git a/airflow/kubernetes/pod_launcher.py b/airflow/kubernetes/pod_launcher.py index d6507df7812387..39ed83697deae5 100644 --- a/airflow/kubernetes/pod_launcher.py +++ b/airflow/kubernetes/pod_launcher.py @@ -22,18 +22,21 @@ import tenacity from kubernetes import watch, client +from kubernetes.client import models as k8s from kubernetes.client.rest import ApiException from kubernetes.stream import stream as kubernetes_stream from requests.exceptions import BaseHTTPError from airflow import AirflowException -from airflow.kubernetes.pod_launcher_helper import convert_to_airflow_pod -from airflow.kubernetes.pod_generator import PodDefaults from airflow import settings +from airflow.contrib.kubernetes.pod import ( + Pod, _extract_env_vars_and_secrets, _extract_volumes_and_secrets, _extract_volume_mounts, + _extract_ports, _extract_security_context +) +from airflow.kubernetes.kube_client import get_kube_client +from airflow.kubernetes.pod_generator import PodDefaults, PodGenerator from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.state import State -import kubernetes.client.models as k8s # noqa -from .kube_client import get_kube_client class PodStatus: @@ -90,19 +93,22 @@ def run_pod_async(self, pod, **kwargs): def _mutate_pod_backcompat(pod): """Backwards compatible Pod Mutation Hook""" try: - settings.pod_mutation_hook(pod) - # attempts to run pod_mutation_hook using k8s.V1Pod, if this - # fails we attempt to run by converting pod to Old Pod - except AttributeError: + dummy_pod = _convert_to_airflow_pod(pod) + settings.pod_mutation_hook(dummy_pod) warnings.warn( "Using `airflow.contrib.kubernetes.pod.Pod` is deprecated. " "Please use `k8s.V1Pod` instead.", DeprecationWarning, stacklevel=2 ) - dummy_pod = convert_to_airflow_pod(pod) - settings.pod_mutation_hook(dummy_pod) dummy_pod = dummy_pod.to_v1_kubernetes_pod() - return dummy_pod - return pod + + new_pod = PodGenerator.reconcile_pods(pod, dummy_pod) + except AttributeError as e: + try: + settings.pod_mutation_hook(pod) + return pod + except AttributeError as e2: + raise Exception([e, e2]) + return new_pod def delete_pod(self, pod): """Deletes POD""" @@ -269,7 +275,7 @@ def _exec_pod_command(self, resp, command): return None def process_status(self, job_id, status): - """Process status infomration for the JOB""" + """Process status information for the JOB""" status = status.lower() if status == PodStatus.PENDING: return State.QUEUED @@ -284,3 +290,35 @@ def process_status(self, job_id, status): else: self.log.info('Event: Invalid state %s on job %s', status, job_id) return State.FAILED + + +def _convert_to_airflow_pod(pod): + base_container = pod.spec.containers[0] # type: k8s.V1Container + env_vars, secrets = _extract_env_vars_and_secrets(base_container.env) + volumes, vol_secrets = _extract_volumes_and_secrets(pod.spec.volumes, base_container.volume_mounts) + secrets.extend(vol_secrets) + dummy_pod = Pod( + image=base_container.image, + envs=env_vars, + cmds=base_container.command, + args=base_container.args, + labels=pod.metadata.labels, + annotations=pod.metadata.annotations, + node_selectors=pod.spec.node_selector, + name=pod.metadata.name, + ports=_extract_ports(base_container.ports), + volumes=volumes, + volume_mounts=_extract_volume_mounts(base_container.volume_mounts), + namespace=pod.metadata.namespace, + image_pull_policy=base_container.image_pull_policy or 'IfNotPresent', + tolerations=pod.spec.tolerations, + init_containers=pod.spec.init_containers, + image_pull_secrets=pod.spec.image_pull_secrets, + resources=base_container.resources, + service_account_name=pod.spec.service_account_name, + secrets=secrets, + affinity=pod.spec.affinity, + hostnetwork=pod.spec.host_network, + security_context=_extract_security_context(pod.spec.security_context) + ) + return dummy_pod diff --git a/airflow/kubernetes/pod_launcher_helper.py b/airflow/kubernetes/pod_launcher_helper.py deleted file mode 100644 index 8c9fc6ee7eec35..00000000000000 --- a/airflow/kubernetes/pod_launcher_helper.py +++ /dev/null @@ -1,96 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from typing import List, Union - -import kubernetes.client.models as k8s # noqa - -from airflow.kubernetes.volume import Volume -from airflow.kubernetes.volume_mount import VolumeMount -from airflow.kubernetes.pod import Port -from airflow.contrib.kubernetes.pod import Pod - - -def convert_to_airflow_pod(pod): - base_container = pod.spec.containers[0] # type: k8s.V1Container - - dummy_pod = Pod( - image=base_container.image, - envs=_extract_env_vars(base_container.env), - volumes=_extract_volumes(pod.spec.volumes), - volume_mounts=_extract_volume_mounts(base_container.volume_mounts), - labels=pod.metadata.labels, - name=pod.metadata.name, - namespace=pod.metadata.namespace, - image_pull_policy=base_container.image_pull_policy or 'IfNotPresent', - cmds=[], - ports=_extract_ports(base_container.ports) - ) - return dummy_pod - - -def _extract_env_vars(env_vars): - """ - - :param env_vars: - :type env_vars: list - :return: result - :rtype: dict - """ - result = {} - env_vars = env_vars or [] # type: List[Union[k8s.V1EnvVar, dict]] - for env_var in env_vars: - if isinstance(env_var, k8s.V1EnvVar): - env_var.to_dict() - result[env_var.get("name")] = env_var.get("value") - return result - - -def _extract_volumes(volumes): - result = [] - volumes = volumes or [] # type: List[Union[k8s.V1Volume, dict]] - for volume in volumes: - if isinstance(volume, k8s.V1Volume): - volume = volume.to_dict() - result.append(Volume(name=volume.get("name"), configs=volume)) - return result - - -def _extract_volume_mounts(volume_mounts): - result = [] - volume_mounts = volume_mounts or [] # type: List[Union[k8s.V1VolumeMount, dict]] - for volume_mount in volume_mounts: - if isinstance(volume_mount, k8s.V1VolumeMount): - volume_mount = volume_mount.to_dict() - result.append( - VolumeMount( - name=volume_mount.get("name"), - mount_path=volume_mount.get("mount_path"), - sub_path=volume_mount.get("sub_path"), - read_only=volume_mount.get("read_only")) - ) - - return result - - -def _extract_ports(ports): - result = [] - ports = ports or [] # type: List[Union[k8s.V1ContainerPort, dict]] - for port in ports: - if isinstance(port, k8s.V1ContainerPort): - port = port.to_dict() - result.append(Port(name=port.get("name"), container_port=port.get("container_port"))) - return result diff --git a/airflow/kubernetes/volume.py b/airflow/kubernetes/volume.py index 9d85959bf5b6c2..9e5e5c44dd1405 100644 --- a/airflow/kubernetes/volume.py +++ b/airflow/kubernetes/volume.py @@ -37,9 +37,15 @@ def __init__(self, name, configs): self.configs = configs def to_k8s_client_obj(self): - configs = self.configs - configs['name'] = self.name - return configs + from kubernetes.client import models as k8s + resp = k8s.V1Volume(name=self.name) + for k, v in self.configs.items(): + snake_key = Volume._convert_to_snake_case(k) + if hasattr(resp, snake_key): + setattr(resp, snake_key, v) + else: + raise AttributeError("V1Volume does not have attribute {}".format(k)) + return resp def attach_to_pod(self, pod): cp_pod = copy.deepcopy(pod) @@ -47,3 +53,8 @@ def attach_to_pod(self, pod): cp_pod.spec.volumes = pod.spec.volumes or [] cp_pod.spec.volumes.append(volume) return cp_pod + + # source: https://www.geeksforgeeks.org/python-program-to-convert-camel-case-string-to-snake-case/ + @staticmethod + def _convert_to_snake_case(str): + return ''.join(['_' + i.lower() if i.isupper() else i for i in str]).lstrip('_') diff --git a/airflow/operators/python_operator.py b/airflow/operators/python_operator.py index 78b6a410e98929..392e0fcd5521bb 100644 --- a/airflow/operators/python_operator.py +++ b/airflow/operators/python_operator.py @@ -234,8 +234,8 @@ def __init__( python_version=None, # type: Optional[str] use_dill=False, # type: bool system_site_packages=True, # type: bool - op_args=None, # type: Iterable - op_kwargs=None, # type: Dict + op_args=None, # type: Optional[Iterable] + op_kwargs=None, # type: Optional[Dict] provide_context=False, # type: bool string_args=None, # type: Optional[Iterable[str]] templates_dict=None, # type: Optional[Dict] diff --git a/docs/conf.py b/docs/conf.py index d18b6ea2840c39..101d050e6a9c41 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -220,6 +220,7 @@ '_api/airflow/version', '_api/airflow/www', '_api/airflow/www_rbac', + '_api/kubernetes_executor', '_api/main', '_api/mesos_executor', 'autoapi_templates', diff --git a/kubernetes_tests/test_kubernetes_pod_operator.py b/kubernetes_tests/test_kubernetes_pod_operator.py index b6cecda1981805..50a12588edadf4 100644 --- a/kubernetes_tests/test_kubernetes_pod_operator.py +++ b/kubernetes_tests/test_kubernetes_pod_operator.py @@ -404,7 +404,6 @@ def test_pod_resources(self): 'limits': { 'memory': '64Mi', 'cpu': 0.25, - 'nvidia.com/gpu': None, 'ephemeral-storage': '2Gi' } } diff --git a/tests/kubernetes/models/test_pod.py b/tests/kubernetes/models/test_pod.py index 2e53d60438fa9a..8de33bf788c013 100644 --- a/tests/kubernetes/models/test_pod.py +++ b/tests/kubernetes/models/test_pod.py @@ -75,11 +75,16 @@ def test_port_attach_to_pod(self, mock_uuid): } }, result) - def test_to_v1_pod(self): + @mock.patch('uuid.uuid4') + def test_to_v1_pod(self, mock_uuid): from airflow.contrib.kubernetes.pod import Pod as DeprecatedPod from airflow.kubernetes.volume import Volume from airflow.kubernetes.volume_mount import VolumeMount + from airflow.kubernetes.secret import Secret from airflow.kubernetes.pod import Resources + import uuid + static_uuid = uuid.UUID('cf4a56d2-8101-4217-b027-2af6216feb48') + mock_uuid.return_value = static_uuid pod = DeprecatedPod( image="foo", @@ -93,7 +98,14 @@ def test_to_v1_pod(self): request_cpu="100Mi", limit_gpu="100G" ), - volumes=[Volume(name="foo", configs={})], + volumes=[ + Volume(name="foo", configs={}), + {"name": "bar", 'secret': {'secretName': 'volume-secret'}} + ], + secrets=[ + Secret('env', "AIRFLOW_SECRET", 'secret_name', "airflow_config"), + Secret("volume", "/opt/airflow", "volume-secret", "secret-key") + ], volume_mounts=[VolumeMount(name="foo", mount_path="/mnt", sub_path="/", read_only=True)] ) @@ -103,55 +115,35 @@ def test_to_v1_pod(self): result = k8s_client.sanitize_for_serialization(result) expected = \ - { - 'metadata': - { - 'labels': {}, - 'name': 'bar', - 'namespace': 'baz' - }, - 'spec': - {'containers': - [ - { - 'args': [], - 'command': ['airflow'], - 'env': [{'name': 'test_key', 'value': 'test_value'}], - 'image': 'foo', - 'imagePullPolicy': 'Never', - 'name': 'base', - 'volumeMounts': - [ - { - 'mountPath': '/mnt', - 'name': 'foo', - 'readOnly': True, 'subPath': '/' - } - ], # noqa - 'resources': - { - 'limits': - { - 'cpu': None, - 'memory': None, - 'nvidia.com/gpu': '100G', - 'ephemeral-storage': None - }, - 'requests': - { - 'cpu': '100Mi', - 'memory': '1G', - 'ephemeral-storage': None - } - } - } - ], - 'hostNetwork': False, - 'tolerations': [], - 'volumes': [ - {'name': 'foo'} - ] - } - } + {'metadata': {'labels': {}, 'name': 'bar', 'namespace': 'baz'}, + 'spec': {'affinity': {}, + 'containers': [{'args': [], + 'command': ['airflow'], + 'env': [{'name': 'test_key', 'value': 'test_value'}, + {'name': 'AIRFLOW_SECRET', + 'valueFrom': {'secretKeyRef': {'key': 'airflow_config', + 'name': 'secret_name'}}}], + 'image': 'foo', + 'imagePullPolicy': 'Never', + 'name': 'base', + 'resources': {'limits': {'nvidia.com/gpu': '100G'}, + 'requests': {'cpu': '100Mi', + 'memory': '1G'}}, + 'volumeMounts': [{'mountPath': '/mnt', + 'name': 'foo', + 'readOnly': True, + 'subPath': '/'}, + {'mountPath': '/opt/airflow', + 'name': 'secretvol' + str(static_uuid), + 'readOnly': True}]}], + 'hostNetwork': False, + 'securityContext': {}, + 'tolerations': [], + 'volumes': [{'name': 'foo'}, + {'name': 'bar', + 'secret': {'secretName': 'volume-secret'}}, + {'name': 'secretvol' + str(static_uuid), + 'secret': {'secretName': 'volume-secret'}} + ]}} self.maxDiff = None - self.assertEquals(expected, result) + self.assertEqual(expected, result) diff --git a/tests/kubernetes/models/test_volume.py b/tests/kubernetes/models/test_volume.py new file mode 100644 index 00000000000000..c1b8e29a83d257 --- /dev/null +++ b/tests/kubernetes/models/test_volume.py @@ -0,0 +1,40 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import unittest + +from kubernetes.client import models as k8s + +from airflow.kubernetes.volume import Volume + + +class TestVolume(unittest.TestCase): + def test_to_k8s_object(self): + volume_config = { + 'persistentVolumeClaim': + { + 'claimName': 'test-volume' + } + } + volume = Volume(name='test-volume', configs=volume_config) + expected_volume = k8s.V1Volume( + name="test-volume", + persistent_volume_claim={ + "claimName": "test-volume" + } + ) + result = volume.to_k8s_client_obj() + self.assertEqual(result, expected_volume) diff --git a/tests/kubernetes/test_pod_generator.py b/tests/kubernetes/test_pod_generator.py index d0faf4c4972ea2..bb714d4cb4a1a6 100644 --- a/tests/kubernetes/test_pod_generator.py +++ b/tests/kubernetes/test_pod_generator.py @@ -255,6 +255,20 @@ def test_from_obj(self, mock_uuid): "name": "example-kubernetes-test-volume", }, ], + "resources": { + "requests": { + "memory": "256Mi", + "cpu": "500m", + "ephemeral-storage": "2G", + "nvidia.com/gpu": "0" + }, + "limits": { + "memory": "512Mi", + "cpu": "1000m", + "ephemeral-storage": "2G", + "nvidia.com/gpu": "0" + } + } } }) result = self.k8s_client.sanitize_for_serialization(result) @@ -277,6 +291,92 @@ def test_from_obj(self, mock_uuid): 'mountPath': '/foo/', 'name': 'example-kubernetes-test-volume' }], + "resources": { + "requests": { + "memory": "256Mi", + "cpu": "500m", + "ephemeral-storage": "2G", + "nvidia.com/gpu": "0" + }, + "limits": { + "memory": "512Mi", + "cpu": "1000m", + "ephemeral-storage": "2G", + "nvidia.com/gpu": "0" + } + } + }], + 'hostNetwork': False, + 'imagePullSecrets': [], + 'volumes': [{ + 'hostPath': {'path': '/tmp/'}, + 'name': 'example-kubernetes-test-volume' + }], + } + }, result) + + @mock.patch('uuid.uuid4') + def test_from_obj_with_resources_object(self, mock_uuid): + mock_uuid.return_value = self.static_uuid + result = PodGenerator.from_obj({ + "KubernetesExecutor": { + "annotations": {"test": "annotation"}, + "volumes": [ + { + "name": "example-kubernetes-test-volume", + "hostPath": {"path": "/tmp/"}, + }, + ], + "volume_mounts": [ + { + "mountPath": "/foo/", + "name": "example-kubernetes-test-volume", + }, + ], + "resources": { + "requests": { + "memory": "256Mi", + "cpu": "500m", + "ephemeral-storage": "2G", + "nvidia.com/gpu": "0" + }, + "limits": { + "memory": "512Mi", + "cpu": "1000m", + "ephemeral-storage": "2G", + "nvidia.com/gpu": "0" + } + } + } + }) + result = self.k8s_client.sanitize_for_serialization(result) + + self.assertEqual({ + 'apiVersion': 'v1', + 'kind': 'Pod', + 'metadata': { + 'annotations': {'test': 'annotation'}, + }, + 'spec': { + 'containers': [{ + 'args': [], + 'command': [], + 'env': [], + 'envFrom': [], + 'name': 'base', + 'ports': [], + 'volumeMounts': [{ + 'mountPath': '/foo/', + 'name': 'example-kubernetes-test-volume' + }], + 'resources': {'limits': {'cpu': '1000m', + 'ephemeral-storage': '2G', + 'memory': '512Mi', + 'nvidia.com/gpu': '0'}, + 'requests': {'cpu': '500m', + 'ephemeral-storage': '2G', + 'memory': '256Mi', + 'nvidia.com/gpu': '0'}}, }], 'hostNetwork': False, 'imagePullSecrets': [], @@ -586,7 +686,7 @@ def test_construct_pod_empty_worker_config(self, mock_uuid): }, sanitized_result) @mock.patch('uuid.uuid4') - def test_construct_pod_empty_execuctor_config(self, mock_uuid): + def test_construct_pod_empty_executor_config(self, mock_uuid): mock_uuid.return_value = self.static_uuid worker_config = k8s.V1Pod( spec=k8s.V1PodSpec( @@ -731,6 +831,92 @@ def test_construct_pod(self, mock_uuid): } }, sanitized_result) + @mock.patch('uuid.uuid4') + def test_construct_pod_with_mutation(self, mock_uuid): + mock_uuid.return_value = self.static_uuid + worker_config = k8s.V1Pod( + metadata=k8s.V1ObjectMeta( + name='gets-overridden-by-dynamic-args', + annotations={ + 'should': 'stay' + } + ), + spec=k8s.V1PodSpec( + containers=[ + k8s.V1Container( + name='doesnt-override', + resources=k8s.V1ResourceRequirements( + limits={ + 'cpu': '1m', + 'memory': '1G' + } + ), + security_context=k8s.V1SecurityContext( + run_as_user=1 + ) + ) + ] + ) + ) + executor_config = k8s.V1Pod( + spec=k8s.V1PodSpec( + containers=[ + k8s.V1Container( + name='doesnt-override-either', + resources=k8s.V1ResourceRequirements( + limits={ + 'cpu': '2m', + 'memory': '2G' + } + ) + ) + ] + ) + ) + + result = PodGenerator.construct_pod( + 'dag_id', + 'task_id', + 'pod_id', + 3, + 'date', + ['command'], + executor_config, + worker_config, + 'namespace', + 'uuid', + ) + sanitized_result = self.k8s_client.sanitize_for_serialization(result) + + self.metadata.update({'annotations': {'should': 'stay'}}) + + self.assertEqual({ + 'apiVersion': 'v1', + 'kind': 'Pod', + 'metadata': self.metadata, + 'spec': { + 'containers': [{ + 'args': [], + 'command': ['command'], + 'env': [], + 'envFrom': [], + 'name': 'base', + 'ports': [], + 'resources': { + 'limits': { + 'cpu': '2m', + 'memory': '2G' + } + }, + 'volumeMounts': [], + 'securityContext': {'runAsUser': 1} + }], + 'hostNetwork': False, + 'imagePullSecrets': [], + 'volumes': [] + } + }, sanitized_result) + def test_merge_objects_empty(self): annotations = {'foo1': 'bar1'} base_obj = k8s.V1ObjectMeta(annotations=annotations) @@ -901,3 +1087,21 @@ def test_validate_pod_generator(self): PodGenerator(image='k') PodGenerator(pod_template_file='tests/kubernetes/pod.yaml') PodGenerator(pod=k8s.V1Pod()) + + def test_add_custom_label(self): + from kubernetes.client import models as k8s + + pod = PodGenerator.construct_pod( + namespace="test", + worker_uuid="test", + pod_id="test", + dag_id="test", + task_id="test", + try_number=1, + date="23-07-2020", + command="test", + kube_executor_config=None, + worker_config=k8s.V1Pod(metadata=k8s.V1ObjectMeta(labels={"airflow-test": "airflow-task-pod"}, + annotations={"my.annotation": "foo"}))) + self.assertIn("airflow-test", pod.metadata.labels) + self.assertIn("my.annotation", pod.metadata.annotations) diff --git a/tests/kubernetes/test_pod_launcher.py b/tests/kubernetes/test_pod_launcher.py index 09ba339da655c7..c86a6cf7fe8345 100644 --- a/tests/kubernetes/test_pod_launcher.py +++ b/tests/kubernetes/test_pod_launcher.py @@ -16,11 +16,17 @@ # under the License. import unittest import mock +from kubernetes.client import models as k8s from requests.exceptions import BaseHTTPError from airflow import AirflowException -from airflow.kubernetes.pod_launcher import PodLauncher +from airflow.contrib.kubernetes.pod import Pod +from airflow.kubernetes.pod import Port +from airflow.kubernetes.pod_launcher import PodLauncher, _convert_to_airflow_pod +from airflow.kubernetes.volume import Volume +from airflow.kubernetes.secret import Secret +from airflow.kubernetes.volume_mount import VolumeMount class TestPodLauncher(unittest.TestCase): @@ -162,3 +168,132 @@ def test_read_pod_retries_fails(self): self.pod_launcher.read_pod, mock.sentinel ) + + +class TestPodLauncherHelper(unittest.TestCase): + def test_convert_to_airflow_pod(self): + input_pod = k8s.V1Pod( + metadata=k8s.V1ObjectMeta( + name="foo", + namespace="bar" + ), + spec=k8s.V1PodSpec( + containers=[ + k8s.V1Container( + name="base", + command=["foo"], + image="myimage", + env=[ + k8s.V1EnvVar( + name="AIRFLOW_SECRET", + value_from=k8s.V1EnvVarSource( + secret_key_ref=k8s.V1SecretKeySelector( + name="ai", + key="secret_key" + ) + )) + ], + ports=[ + k8s.V1ContainerPort( + name="myport", + container_port=8080, + ) + ], + volume_mounts=[ + k8s.V1VolumeMount( + name="myvolume", + mount_path="/tmp/mount", + read_only="True" + ), + k8s.V1VolumeMount( + name='airflow-config', + mount_path='/config', + sub_path='airflow.cfg', + read_only=True + ), + k8s.V1VolumeMount( + name="airflow-secret", + mount_path="/opt/mount", + read_only=True + )] + ) + ], + security_context=k8s.V1PodSecurityContext( + run_as_user=0, + fs_group=0, + ), + volumes=[ + k8s.V1Volume( + name="myvolume" + ), + k8s.V1Volume( + name="airflow-config", + config_map=k8s.V1ConfigMap( + data="airflow-data" + ) + ), + k8s.V1Volume( + name="airflow-secret", + secret=k8s.V1SecretVolumeSource( + secret_name="secret-name", + + ) + ) + ] + ) + ) + result_pod = _convert_to_airflow_pod(input_pod) + + expected = Pod( + name="foo", + namespace="bar", + envs={}, + cmds=["foo"], + image="myimage", + ports=[ + Port(name="myport", container_port=8080) + ], + volume_mounts=[ + VolumeMount( + name="myvolume", + mount_path="/tmp/mount", + sub_path=None, + read_only="True" + ), + VolumeMount( + name="airflow-config", + read_only=True, + mount_path="/config", + sub_path="airflow.cfg" + ), + VolumeMount( + name="airflow-secret", + read_only=True, + mount_path="/opt/mount", + sub_path=None, + )], + secrets=[Secret("env", "AIRFLOW_SECRET", "ai", "secret_key"), + Secret('volume', '/opt/mount', 'airflow-secret', "secret-name")], + security_context={'fsGroup': 0, 'runAsUser': 0}, + volumes=[Volume(name="myvolume", configs={'name': 'myvolume'}), + Volume(name="airflow-config", configs={'configMap': {'data': 'airflow-data'}, + 'name': 'airflow-config'})] + ) + expected_dict = expected.as_dict() + result_dict = result_pod.as_dict() + parsed_configs = self.pull_out_volumes(result_dict) + result_dict['volumes'] = parsed_configs + self.assertDictEqual(expected_dict, result_dict) + + @staticmethod + def pull_out_volumes(result_dict): + parsed_configs = [] + for volume in result_dict['volumes']: + vol = {'name': volume['name']} + confs = {} + for k, v in volume['configs'].items(): + if v and k[0] != '_': + confs[k] = v + vol['configs'] = confs + parsed_configs.append(vol) + return parsed_configs diff --git a/tests/kubernetes/test_pod_launcher_helper.py b/tests/kubernetes/test_pod_launcher_helper.py deleted file mode 100644 index 761d138f4bf82b..00000000000000 --- a/tests/kubernetes/test_pod_launcher_helper.py +++ /dev/null @@ -1,98 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import unittest - -from airflow.kubernetes.pod import Port -from airflow.kubernetes.volume_mount import VolumeMount -from airflow.kubernetes.volume import Volume -from airflow.kubernetes.pod_launcher_helper import convert_to_airflow_pod -from airflow.contrib.kubernetes.pod import Pod -import kubernetes.client.models as k8s - - -class TestPodLauncherHelper(unittest.TestCase): - def test_convert_to_airflow_pod(self): - input_pod = k8s.V1Pod( - metadata=k8s.V1ObjectMeta( - name="foo", - namespace="bar" - ), - spec=k8s.V1PodSpec( - containers=[ - k8s.V1Container( - name="base", - command="foo", - image="myimage", - ports=[ - k8s.V1ContainerPort( - name="myport", - container_port=8080, - ) - ], - volume_mounts=[k8s.V1VolumeMount( - name="mymount", - mount_path="/tmp/mount", - read_only="True" - )] - ) - ], - volumes=[ - k8s.V1Volume( - name="myvolume" - ) - ] - ) - ) - result_pod = convert_to_airflow_pod(input_pod) - - expected = Pod( - name="foo", - namespace="bar", - envs={}, - cmds=[], - image="myimage", - ports=[ - Port(name="myport", container_port=8080) - ], - volume_mounts=[VolumeMount( - name="mymount", - mount_path="/tmp/mount", - sub_path=None, - read_only="True" - )], - volumes=[Volume(name="myvolume", configs={'name': 'myvolume'})] - ) - expected_dict = expected.as_dict() - result_dict = result_pod.as_dict() - parsed_configs = self.pull_out_volumes(result_dict) - result_dict['volumes'] = parsed_configs - self.maxDiff = None - - self.assertDictEqual(expected_dict, result_dict) - - @staticmethod - def pull_out_volumes(result_dict): - parsed_configs = [] - for volume in result_dict['volumes']: - vol = {'name': volume['name']} - confs = {} - for k, v in volume['configs'].items(): - if v and k[0] != '_': - confs[k] = v - vol['configs'] = confs - parsed_configs.append(vol) - return parsed_configs diff --git a/tests/kubernetes/test_worker_configuration.py b/tests/kubernetes/test_worker_configuration.py index a94a1124941648..0273ae89260464 100644 --- a/tests/kubernetes/test_worker_configuration.py +++ b/tests/kubernetes/test_worker_configuration.py @@ -173,6 +173,13 @@ def test_worker_environment_no_dags_folder(self): self.assertNotIn('AIRFLOW__CORE__DAGS_FOLDER', env) + @conf_vars({ + ('kubernetes', 'airflow_configmap'): 'airflow-configmap'}) + def test_worker_adds_config(self): + worker_config = WorkerConfiguration(self.kube_config) + volumes = worker_config._get_volumes() + print(volumes) + def test_worker_environment_when_dags_folder_specified(self): self.kube_config.airflow_configmap = 'airflow-configmap' self.kube_config.git_dags_folder_mount_point = '' diff --git a/tests/test_local_settings.py b/tests/test_local_settings.py deleted file mode 100644 index 0e45ad8f5c0e2e..00000000000000 --- a/tests/test_local_settings.py +++ /dev/null @@ -1,269 +0,0 @@ -# -*- coding: utf-8 -*- -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -import os -import sys -import tempfile -import unittest -from airflow.kubernetes import pod_generator -from tests.compat import MagicMock, Mock, call, patch - - -SETTINGS_FILE_POLICY = """ -def test_policy(task_instance): - task_instance.run_as_user = "myself" -""" - -SETTINGS_FILE_POLICY_WITH_DUNDER_ALL = """ -__all__ = ["test_policy"] - -def test_policy(task_instance): - task_instance.run_as_user = "myself" - -def not_policy(): - print("This shouldn't be imported") -""" - -SETTINGS_FILE_POD_MUTATION_HOOK = """ -from airflow.kubernetes.volume import Volume -from airflow.kubernetes.pod import Port, Resources - -def pod_mutation_hook(pod): - pod.namespace = 'airflow-tests' - pod.image = 'my_image' - pod.volumes.append(Volume(name="bar", configs={})) - pod.ports = [Port(container_port=8080)] - pod.resources = Resources( - request_memory="2G", - request_cpu="200Mi", - limit_gpu="200G" - ) - -""" - -SETTINGS_FILE_POD_MUTATION_HOOK_V1_POD = """ -def pod_mutation_hook(pod): - pod.spec.containers[0].image = "test-image" - -""" - - -class SettingsContext: - def __init__(self, content, module_name): - self.content = content - self.settings_root = tempfile.mkdtemp() - filename = "{}.py".format(module_name) - self.settings_file = os.path.join(self.settings_root, filename) - - def __enter__(self): - with open(self.settings_file, 'w') as handle: - handle.writelines(self.content) - sys.path.append(self.settings_root) - return self.settings_file - - def __exit__(self, *exc_info): - sys.path.remove(self.settings_root) - - -class LocalSettingsTest(unittest.TestCase): - # Make sure that the configure_logging is not cached - def setUp(self): - self.old_modules = dict(sys.modules) - - def tearDown(self): - # Remove any new modules imported during the test run. This lets us - # import the same source files for more than one test. - for mod in [m for m in sys.modules if m not in self.old_modules]: - del sys.modules[mod] - - @patch("airflow.settings.import_local_settings") - @patch("airflow.settings.prepare_syspath") - def test_initialize_order(self, prepare_syspath, import_local_settings): - """ - Tests that import_local_settings is called after prepare_classpath - """ - mock = Mock() - mock.attach_mock(prepare_syspath, "prepare_syspath") - mock.attach_mock(import_local_settings, "import_local_settings") - - import airflow.settings - airflow.settings.initialize() - - mock.assert_has_calls([call.prepare_syspath(), call.import_local_settings()]) - - def test_import_with_dunder_all_not_specified(self): - """ - Tests that if __all__ is specified in airflow_local_settings, - only module attributes specified within are imported. - """ - with SettingsContext(SETTINGS_FILE_POLICY_WITH_DUNDER_ALL, "airflow_local_settings"): - from airflow import settings - settings.import_local_settings() # pylint: ignore - - with self.assertRaises(AttributeError): - settings.not_policy() - - def test_import_with_dunder_all(self): - """ - Tests that if __all__ is specified in airflow_local_settings, - only module attributes specified within are imported. - """ - with SettingsContext(SETTINGS_FILE_POLICY_WITH_DUNDER_ALL, "airflow_local_settings"): - from airflow import settings - settings.import_local_settings() # pylint: ignore - - task_instance = MagicMock() - settings.test_policy(task_instance) - - assert task_instance.run_as_user == "myself" - - @patch("airflow.settings.log.debug") - def test_import_local_settings_without_syspath(self, log_mock): - """ - Tests that an ImportError is raised in import_local_settings - if there is no airflow_local_settings module on the syspath. - """ - from airflow import settings - settings.import_local_settings() - log_mock.assert_called_with("Failed to import airflow_local_settings.", exc_info=True) - - def test_policy_function(self): - """ - Tests that task instances are mutated by the policy - function in airflow_local_settings. - """ - with SettingsContext(SETTINGS_FILE_POLICY, "airflow_local_settings"): - from airflow import settings - settings.import_local_settings() # pylint: ignore - - task_instance = MagicMock() - settings.test_policy(task_instance) - - assert task_instance.run_as_user == "myself" - - def test_pod_mutation_hook(self): - """ - Tests that pods are mutated by the pod_mutation_hook - function in airflow_local_settings. - """ - with SettingsContext(SETTINGS_FILE_POD_MUTATION_HOOK, "airflow_local_settings"): - from airflow import settings - settings.import_local_settings() # pylint: ignore - - pod = MagicMock() - pod.volumes = [] - settings.pod_mutation_hook(pod) - - assert pod.namespace == 'airflow-tests' - self.assertEqual(pod.volumes[0].name, "bar") - - def test_pod_mutation_to_k8s_pod(self): - with SettingsContext(SETTINGS_FILE_POD_MUTATION_HOOK, "airflow_local_settings"): - from airflow import settings - settings.import_local_settings() # pylint: ignore - from airflow.kubernetes.pod_launcher import PodLauncher - - self.mock_kube_client = Mock() - self.pod_launcher = PodLauncher(kube_client=self.mock_kube_client) - pod = pod_generator.PodGenerator( - image="foo", - name="bar", - namespace="baz", - image_pull_policy="Never", - cmds=["foo"], - volume_mounts=[ - {"name": "foo", "mount_path": "/mnt", "sub_path": "/", "read_only": "True"} - ], - volumes=[{"name": "foo"}] - ).gen_pod() - - self.assertEqual(pod.metadata.namespace, "baz") - self.assertEqual(pod.spec.containers[0].image, "foo") - self.assertEqual(pod.spec.volumes, [{'name': 'foo'}]) - self.assertEqual(pod.spec.containers[0].ports, []) - self.assertEqual(pod.spec.containers[0].resources, None) - - pod = self.pod_launcher._mutate_pod_backcompat(pod) - - self.assertEqual(pod.metadata.namespace, "airflow-tests") - self.assertEqual(pod.spec.containers[0].image, "my_image") - self.assertEqual(pod.spec.volumes, [{'name': 'foo'}, {'name': 'bar'}]) - self.maxDiff = None - self.assertEqual( - pod.spec.containers[0].ports[0].to_dict(), - { - "container_port": 8080, - "host_ip": None, - "host_port": None, - "name": None, - "protocol": None - } - ) - self.assertEqual( - pod.spec.containers[0].resources.to_dict(), - { - 'limits': { - 'cpu': None, - 'memory': None, - 'ephemeral-storage': None, - 'nvidia.com/gpu': '200G'}, - 'requests': {'cpu': '200Mi', 'ephemeral-storage': None, 'memory': '2G'} - } - ) - - def test_pod_mutation_v1_pod(self): - with SettingsContext(SETTINGS_FILE_POD_MUTATION_HOOK_V1_POD, "airflow_local_settings"): - from airflow import settings - settings.import_local_settings() # pylint: ignore - from airflow.kubernetes.pod_launcher import PodLauncher - - self.mock_kube_client = Mock() - self.pod_launcher = PodLauncher(kube_client=self.mock_kube_client) - pod = pod_generator.PodGenerator( - image="myimage", - cmds=["foo"], - volume_mounts={ - "name": "foo", "mount_path": "/mnt", "sub_path": "/", "read_only": "True" - }, - volumes=[{"name": "foo"}] - ).gen_pod() - - self.assertEqual(pod.spec.containers[0].image, "myimage") - pod = self.pod_launcher._mutate_pod_backcompat(pod) - self.assertEqual(pod.spec.containers[0].image, "test-image") - - -class TestStatsWithAllowList(unittest.TestCase): - - def setUp(self): - from airflow.settings import SafeStatsdLogger, AllowListValidator - self.statsd_client = Mock() - self.stats = SafeStatsdLogger(self.statsd_client, AllowListValidator("stats_one, stats_two")) - - def test_increment_counter_with_allowed_key(self): - self.stats.incr('stats_one') - self.statsd_client.incr.assert_called_once_with('stats_one', 1, 1) - - def test_increment_counter_with_allowed_prefix(self): - self.stats.incr('stats_two.bla') - self.statsd_client.incr.assert_called_once_with('stats_two.bla', 1, 1) - - def test_not_increment_counter_if_not_allowed(self): - self.stats.incr('stats_three') - self.statsd_client.assert_not_called() diff --git a/tests/test_local_settings/__init__.py b/tests/test_local_settings/__init__.py new file mode 100644 index 00000000000000..13a83393a9124b --- /dev/null +++ b/tests/test_local_settings/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/test_local_settings/test_local_settings.py b/tests/test_local_settings/test_local_settings.py new file mode 100644 index 00000000000000..ece813d12c4eb4 --- /dev/null +++ b/tests/test_local_settings/test_local_settings.py @@ -0,0 +1,441 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +import os +import sys +import tempfile +import unittest +from airflow.kubernetes import pod_generator +from kubernetes.client import ApiClient +import kubernetes.client.models as k8s +from tests.compat import MagicMock, Mock, mock, call, patch + +api_client = ApiClient() + +SETTINGS_FILE_POLICY = """ +def test_policy(task_instance): + task_instance.run_as_user = "myself" +""" + +SETTINGS_FILE_POLICY_WITH_DUNDER_ALL = """ +__all__ = ["test_policy"] + +def test_policy(task_instance): + task_instance.run_as_user = "myself" + +def not_policy(): + print("This shouldn't be imported") +""" + +SETTINGS_FILE_POD_MUTATION_HOOK = """ +from airflow.kubernetes.volume import Volume +from airflow.kubernetes.pod import Port, Resources + +def pod_mutation_hook(pod): + pod.namespace = 'airflow-tests' + pod.image = 'my_image' + pod.volumes.append(Volume(name="bar", configs={})) + pod.ports = [Port(container_port=8080), {"containerPort": 8081}] + pod.resources = Resources( + request_memory="2G", + request_cpu="200Mi", + limit_gpu="200G" + ) + + secret_volume = { + "name": "airflow-secrets-mount", + "secret": { + "secretName": "airflow-test-secrets" + } + } + secret_volume_mount = { + "name": "airflow-secrets-mount", + "readOnly": True, + "mountPath": "/opt/airflow/secrets/" + } + + pod.volumes.append(secret_volume) + pod.volume_mounts.append(secret_volume_mount) + + pod.labels.update({"test_label": "test_value"}) + pod.envs.update({"TEST_USER": "ADMIN"}) + + pod.tolerations += [ + {"key": "dynamic-pods", "operator": "Equal", "value": "true", "effect": "NoSchedule"} + ] + pod.affinity.update( + {"nodeAffinity": + {"requiredDuringSchedulingIgnoredDuringExecution": + {"nodeSelectorTerms": + [{ + "matchExpressions": [ + {"key": "test/dynamic-pods", "operator": "In", "values": ["true"]} + ] + }] + } + } + } + ) + + if 'fsGroup' in pod.security_context and pod.security_context['fsGroup'] == 0 : + del pod.security_context['fsGroup'] + if 'runAsUser' in pod.security_context and pod.security_context['runAsUser'] == 0 : + del pod.security_context['runAsUser'] + + if pod.args and pod.args[0] == "/bin/sh": + pod.args = ['/bin/sh', '-c', 'touch /tmp/healthy2'] + +""" + +SETTINGS_FILE_POD_MUTATION_HOOK_V1_POD = """ +def pod_mutation_hook(pod): + from kubernetes.client import models as k8s + secret_volume = { + "name": "airflow-secrets-mount", + "secret": { + "secretName": "airflow-test-secrets" + } + } + secret_volume_mount = { + "name": "airflow-secrets-mount", + "readOnly": True, + "mountPath": "/opt/airflow/secrets/" + } + base_container = pod.spec.containers[0] + base_container.image = "test-image" + base_container.volume_mounts.append(secret_volume_mount) + base_container.env.extend([{'name': 'TEST_USER', 'value': 'ADMIN'}]) + base_container.ports.extend([{'containerPort': 8080}, k8s.V1ContainerPort(container_port=8081)]) + + pod.spec.volumes.append(secret_volume) + pod.metadata.namespace = 'airflow-tests' + +""" + + +class SettingsContext: + def __init__(self, content, module_name): + self.content = content + self.settings_root = tempfile.mkdtemp() + filename = "{}.py".format(module_name) + self.settings_file = os.path.join(self.settings_root, filename) + + def __enter__(self): + with open(self.settings_file, 'w') as handle: + handle.writelines(self.content) + sys.path.append(self.settings_root) + return self.settings_file + + def __exit__(self, *exc_info): + sys.path.remove(self.settings_root) + + +class LocalSettingsTest(unittest.TestCase): + # Make sure that the configure_logging is not cached + def setUp(self): + self.old_modules = dict(sys.modules) + self.maxDiff = None + + def tearDown(self): + # Remove any new modules imported during the test run. This lets us + # import the same source files for more than one test. + for mod in [m for m in sys.modules if m not in self.old_modules]: + del sys.modules[mod] + + @patch("airflow.settings.import_local_settings") + @patch("airflow.settings.prepare_syspath") + def test_initialize_order(self, prepare_syspath, import_local_settings): + """ + Tests that import_local_settings is called after prepare_classpath + """ + mock = Mock() + mock.attach_mock(prepare_syspath, "prepare_syspath") + mock.attach_mock(import_local_settings, "import_local_settings") + + import airflow.settings + airflow.settings.initialize() + + mock.assert_has_calls([call.prepare_syspath(), call.import_local_settings()]) + + def test_import_with_dunder_all_not_specified(self): + """ + Tests that if __all__ is specified in airflow_local_settings, + only module attributes specified within are imported. + """ + with SettingsContext(SETTINGS_FILE_POLICY_WITH_DUNDER_ALL, "airflow_local_settings"): + from airflow import settings + settings.import_local_settings() # pylint: ignore + + with self.assertRaises(AttributeError): + settings.not_policy() + + def test_import_with_dunder_all(self): + """ + Tests that if __all__ is specified in airflow_local_settings, + only module attributes specified within are imported. + """ + with SettingsContext(SETTINGS_FILE_POLICY_WITH_DUNDER_ALL, "airflow_local_settings"): + from airflow import settings + settings.import_local_settings() # pylint: ignore + + task_instance = MagicMock() + settings.test_policy(task_instance) + + assert task_instance.run_as_user == "myself" + + @patch("airflow.settings.log.debug") + def test_import_local_settings_without_syspath(self, log_mock): + """ + Tests that an ImportError is raised in import_local_settings + if there is no airflow_local_settings module on the syspath. + """ + from airflow import settings + settings.import_local_settings() + log_mock.assert_called_with("Failed to import airflow_local_settings.", exc_info=True) + + def test_policy_function(self): + """ + Tests that task instances are mutated by the policy + function in airflow_local_settings. + """ + with SettingsContext(SETTINGS_FILE_POLICY, "airflow_local_settings"): + from airflow import settings + settings.import_local_settings() # pylint: ignore + + task_instance = MagicMock() + settings.test_policy(task_instance) + + assert task_instance.run_as_user == "myself" + + def test_pod_mutation_hook(self): + """ + Tests that pods are mutated by the pod_mutation_hook + function in airflow_local_settings. + """ + with SettingsContext(SETTINGS_FILE_POD_MUTATION_HOOK, "airflow_local_settings"): + from airflow import settings + settings.import_local_settings() # pylint: ignore + + pod = MagicMock() + pod.volumes = [] + settings.pod_mutation_hook(pod) + + assert pod.namespace == 'airflow-tests' + self.assertEqual(pod.volumes[0].name, "bar") + + def test_pod_mutation_to_k8s_pod(self): + with SettingsContext(SETTINGS_FILE_POD_MUTATION_HOOK, "airflow_local_settings"): + from airflow import settings + settings.import_local_settings() # pylint: ignore + from airflow.kubernetes.pod_launcher import PodLauncher + + self.mock_kube_client = Mock() + self.pod_launcher = PodLauncher(kube_client=self.mock_kube_client) + pod = pod_generator.PodGenerator( + image="foo", + name="bar", + namespace="baz", + image_pull_policy="Never", + cmds=["foo"], + args=["/bin/sh", "-c", "touch /tmp/healthy"], + tolerations=[ + {'effect': 'NoSchedule', + 'key': 'static-pods', + 'operator': 'Equal', + 'value': 'true'} + ], + volume_mounts=[ + {"name": "foo", "mountPath": "/mnt", "subPath": "/", "readOnly": True} + ], + security_context=k8s.V1PodSecurityContext(fs_group=0, run_as_user=1), + volumes=[k8s.V1Volume(name="foo")] + ).gen_pod() + + sanitized_pod_pre_mutation = api_client.sanitize_for_serialization(pod) + self.assertEqual( + sanitized_pod_pre_mutation, + {'apiVersion': 'v1', + 'kind': 'Pod', + 'metadata': {'name': mock.ANY, + 'namespace': 'baz'}, + 'spec': {'containers': [{'args': ['/bin/sh', '-c', 'touch /tmp/healthy'], + 'command': ['foo'], + 'env': [], + 'envFrom': [], + 'image': 'foo', + 'imagePullPolicy': 'Never', + 'name': 'base', + 'ports': [], + 'volumeMounts': [{'mountPath': '/mnt', + 'name': 'foo', + 'readOnly': True, + 'subPath': '/'}]}], + 'hostNetwork': False, + 'imagePullSecrets': [], + 'tolerations': [{'effect': 'NoSchedule', + 'key': 'static-pods', + 'operator': 'Equal', + 'value': 'true'}], + 'volumes': [{'name': 'foo'}], + 'securityContext': {'fsGroup': 0, 'runAsUser': 1}}}, + ) + + # Apply Pod Mutation Hook + pod = self.pod_launcher._mutate_pod_backcompat(pod) + + sanitized_pod_post_mutation = api_client.sanitize_for_serialization(pod) + + self.assertEqual( + sanitized_pod_post_mutation, + {"apiVersion": "v1", + "kind": "Pod", + 'metadata': {'labels': {'test_label': 'test_value'}, + 'name': mock.ANY, + 'namespace': 'airflow-tests'}, + 'spec': {'affinity': {'nodeAffinity': {'requiredDuringSchedulingIgnoredDuringExecution': { + 'nodeSelectorTerms': [{'matchExpressions': [{'key': 'test/dynamic-pods', + 'operator': 'In', + 'values': ['true']}]}]}}}, + 'containers': [{'args': ['/bin/sh', '-c', 'touch /tmp/healthy2'], + 'command': ['foo'], + 'env': [{'name': 'TEST_USER', 'value': 'ADMIN'}], + 'image': 'my_image', + 'imagePullPolicy': 'Never', + 'name': 'base', + 'ports': [{'containerPort': 8080}, + {'containerPort': 8081}], + 'resources': {'limits': {'nvidia.com/gpu': '200G'}, + 'requests': {'cpu': '200Mi', + 'memory': '2G'}}, + 'volumeMounts': [{'mountPath': '/opt/airflow/secrets/', + 'name': 'airflow-secrets-mount', + 'readOnly': True}, + {'mountPath': '/mnt', + 'name': 'foo', + 'readOnly': True, + 'subPath': '/'} + ]}], + 'hostNetwork': False, + 'imagePullSecrets': [], + 'tolerations': [{'effect': 'NoSchedule', + 'key': 'static-pods', + 'operator': 'Equal', + 'value': 'true'}, + {'effect': 'NoSchedule', + 'key': 'dynamic-pods', + 'operator': 'Equal', + 'value': 'true'}], + 'volumes': [{'name': 'airflow-secrets-mount', + 'secret': {'secretName': 'airflow-test-secrets'}}, + {'name': 'bar'}, + {'name': 'foo'}, + ], + 'securityContext': {'runAsUser': 1}}} + ) + + def test_pod_mutation_v1_pod(self): + with SettingsContext(SETTINGS_FILE_POD_MUTATION_HOOK_V1_POD, "airflow_local_settings"): + from airflow import settings + settings.import_local_settings() # pylint: ignore + from airflow.kubernetes.pod_launcher import PodLauncher + + self.mock_kube_client = Mock() + self.pod_launcher = PodLauncher(kube_client=self.mock_kube_client) + pod = pod_generator.PodGenerator( + image="myimage", + cmds=["foo"], + namespace="baz", + volume_mounts=[ + {"name": "foo", "mountPath": "/mnt", "subPath": "/", "readOnly": True} + ], + volumes=[{"name": "foo"}] + ).gen_pod() + + sanitized_pod_pre_mutation = api_client.sanitize_for_serialization(pod) + + self.assertEqual( + sanitized_pod_pre_mutation, + {'apiVersion': 'v1', + 'kind': 'Pod', + 'metadata': {'namespace': 'baz'}, + 'spec': {'containers': [{'args': [], + 'command': ['foo'], + 'env': [], + 'envFrom': [], + 'image': 'myimage', + 'name': 'base', + 'ports': [], + 'volumeMounts': [{'mountPath': '/mnt', + 'name': 'foo', + 'readOnly': True, + 'subPath': '/'}]}], + 'hostNetwork': False, + 'imagePullSecrets': [], + 'volumes': [{'name': 'foo'}]}} + ) + + # Apply Pod Mutation Hook + pod = self.pod_launcher._mutate_pod_backcompat(pod) + + sanitized_pod_post_mutation = api_client.sanitize_for_serialization(pod) + self.assertEqual( + sanitized_pod_post_mutation, + {'apiVersion': 'v1', + 'kind': 'Pod', + 'metadata': {'namespace': 'airflow-tests'}, + 'spec': {'containers': [{'args': [], + 'command': ['foo'], + 'env': [{'name': 'TEST_USER', 'value': 'ADMIN'}], + 'envFrom': [], + 'image': 'test-image', + 'name': 'base', + 'ports': [{'containerPort': 8080}, {'containerPort': 8081}], + 'volumeMounts': [{'mountPath': '/mnt', + 'name': 'foo', + 'readOnly': True, + 'subPath': '/'}, + {'mountPath': '/opt/airflow/secrets/', + 'name': 'airflow-secrets-mount', + 'readOnly': True}]}], + 'hostNetwork': False, + 'imagePullSecrets': [], + 'volumes': [{'name': 'foo'}, + {'name': 'airflow-secrets-mount', + 'secret': {'secretName': 'airflow-test-secrets'}}]}} + ) + + +class TestStatsWithAllowList(unittest.TestCase): + + def setUp(self): + from airflow.settings import SafeStatsdLogger, AllowListValidator + self.statsd_client = Mock() + self.stats = SafeStatsdLogger(self.statsd_client, AllowListValidator("stats_one, stats_two")) + + def test_increment_counter_with_allowed_key(self): + self.stats.incr('stats_one') + self.statsd_client.incr.assert_called_once_with('stats_one', 1, 1) + + def test_increment_counter_with_allowed_prefix(self): + self.stats.incr('stats_two.bla') + self.statsd_client.incr.assert_called_once_with('stats_two.bla', 1, 1) + + def test_not_increment_counter_if_not_allowed(self): + self.stats.incr('stats_three') + self.statsd_client.assert_not_called()