Skip to content

Commit

Permalink
Add GKECreateCustomResourceOperator and GKEDeleteCustomResourceOperat…
Browse files Browse the repository at this point in the history
…or operators (apache#37616)

* Add GKECreateCustomResource operator

* Add GKEDeleteCustomResourceOperator operator

* Add 'yaml_conf_file' path option for operators

---------

Co-authored-by: Ulada Zakharava <[email protected]>
  • Loading branch information
2 people authored and utkarsharma2 committed Apr 22, 2024
1 parent ca25f25 commit 3b3c832
Show file tree
Hide file tree
Showing 8 changed files with 710 additions and 68 deletions.
60 changes: 47 additions & 13 deletions airflow/providers/cncf/kubernetes/operators/resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@

from __future__ import annotations

import os
from functools import cached_property
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Sequence

import yaml
from kubernetes.utils import create_from_yaml

from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.cncf.kubernetes.hooks.kubernetes import KubernetesHook
from airflow.providers.cncf.kubernetes.utils.delete_from import delete_from_yaml
Expand All @@ -40,34 +42,44 @@ class KubernetesResourceBaseOperator(BaseOperator):
Abstract base class for all Kubernetes Resource operators.
:param yaml_conf: string. Contains the kubernetes resources to Create or Delete
:param yaml_conf_file: path to the kubernetes resources file (templated)
:param namespace: string. Contains the namespace to create all resources inside.
The namespace must preexist otherwise the resource creation will fail.
If the API object in the yaml file already contains a namespace definition then
this parameter has no effect.
:param kubernetes_conn_id: The :ref:`kubernetes connection id <howto/connection:kubernetes>`
for the Kubernetes cluster.
:param namespaced: specified that Kubernetes resource is or isn't in a namespace.
This parameter works only when custom_resource_definition parameter is True.
"""

template_fields = ("yaml_conf",)
template_fields: Sequence[str] = ("yaml_conf", "yaml_conf_file")
template_fields_renderers = {"yaml_conf": "yaml"}

def __init__(
self,
*,
yaml_conf: str,
yaml_conf: str | None = None,
yaml_conf_file: str | None = None,
namespace: str | None = None,
kubernetes_conn_id: str | None = KubernetesHook.default_conn_name,
custom_resource_definition: bool = False,
namespaced: bool = True,
config_file: str | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
self._namespace = namespace
self.kubernetes_conn_id = kubernetes_conn_id
self.yaml_conf = yaml_conf
self.yaml_conf_file = yaml_conf_file
self.custom_resource_definition = custom_resource_definition
self.namespaced = namespaced
self.config_file = config_file

if not any([self.yaml_conf, self.yaml_conf_file]):
raise AirflowException("One of `yaml_conf` or `yaml_conf_file` arguments must be provided")

@cached_property
def client(self) -> ApiClient:
return self.hook.api_client
Expand Down Expand Up @@ -109,18 +121,29 @@ class KubernetesCreateResourceOperator(KubernetesResourceBaseOperator):

def create_custom_from_yaml_object(self, body: dict):
group, version, namespace, plural = self.get_crd_fields(body)
self.custom_object_client.create_namespaced_custom_object(group, version, namespace, plural, body)
if self.namespaced:
self.custom_object_client.create_namespaced_custom_object(group, version, namespace, plural, body)
else:
self.custom_object_client.create_cluster_custom_object(group, version, plural, body)

def execute(self, context) -> None:
resources = yaml.safe_load_all(self.yaml_conf)
def _create_objects(self, objects):
if not self.custom_resource_definition:
create_from_yaml(
k8s_client=self.client,
yaml_objects=resources,
yaml_objects=objects,
namespace=self.get_namespace(),
)
else:
k8s_resource_iterator(self.create_custom_from_yaml_object, resources)
k8s_resource_iterator(self.create_custom_from_yaml_object, objects)

def execute(self, context) -> None:
if self.yaml_conf:
self._create_objects(yaml.safe_load_all(self.yaml_conf))
elif self.yaml_conf_file and os.path.exists(self.yaml_conf_file):
with open(self.yaml_conf_file) as stream:
self._create_objects(yaml.safe_load_all(stream))
else:
raise AirflowException("File %s not found", self.yaml_conf_file)


class KubernetesDeleteResourceOperator(KubernetesResourceBaseOperator):
Expand All @@ -129,15 +152,26 @@ class KubernetesDeleteResourceOperator(KubernetesResourceBaseOperator):
def delete_custom_from_yaml_object(self, body: dict):
name = body["metadata"]["name"]
group, version, namespace, plural = self.get_crd_fields(body)
self.custom_object_client.delete_namespaced_custom_object(group, version, namespace, plural, name)
if self.namespaced:
self.custom_object_client.delete_namespaced_custom_object(group, version, namespace, plural, name)
else:
self.custom_object_client.delete_cluster_custom_object(group, version, plural, name)

def execute(self, context) -> None:
resources = yaml.safe_load_all(self.yaml_conf)
def _delete_objects(self, objects):
if not self.custom_resource_definition:
delete_from_yaml(
k8s_client=self.client,
yaml_objects=resources,
yaml_objects=objects,
namespace=self.get_namespace(),
)
else:
k8s_resource_iterator(self.delete_custom_from_yaml_object, resources)
k8s_resource_iterator(self.delete_custom_from_yaml_object, objects)

def execute(self, context) -> None:
if self.yaml_conf:
self._delete_objects(yaml.safe_load_all(self.yaml_conf))
elif self.yaml_conf_file and os.path.exists(self.yaml_conf_file):
with open(self.yaml_conf_file) as stream:
self._delete_objects(yaml.safe_load_all(stream))
else:
raise AirflowException("File %s not found", self.yaml_conf_file)
126 changes: 72 additions & 54 deletions airflow/providers/google/cloud/hooks/kubernetes_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,44 @@
OPERATIONAL_POLL_INTERVAL = 15


class GKEClusterConnection:
"""Helper for establishing connection to GKE cluster."""

def __init__(self, cluster_url: str, ssl_ca_cert: str, credentials: google.auth.credentials.Credentials):
self._cluster_url = cluster_url
self._ssl_ca_cert = ssl_ca_cert
self._credentials = credentials

def get_conn(self) -> client.ApiClient:
configuration = self._get_config()
configuration.refresh_api_key_hook = self._refresh_api_key_hook
return client.ApiClient(configuration)

def _refresh_api_key_hook(self, configuration: client.configuration.Configuration):
configuration.api_key = {"authorization": self._get_token(self._credentials)}

def _get_config(self) -> client.configuration.Configuration:
configuration = client.Configuration(
host=self._cluster_url,
api_key_prefix={"authorization": "Bearer"},
api_key={"authorization": self._get_token(self._credentials)},
)
configuration.ssl_ca_cert = FileOrData(
{
"certificate-authority-data": self._ssl_ca_cert,
},
file_key_name="certificate-authority",
).as_file()
return configuration

@staticmethod
def _get_token(creds: google.auth.credentials.Credentials) -> str:
if creds.token is None or creds.expired:
auth_req = google_requests.Request()
creds.refresh(auth_req)
return creds.token


class GKEHook(GoogleBaseHook):
"""Google Kubernetes Engine cluster APIs.
Expand Down Expand Up @@ -360,33 +398,9 @@ def apps_v1_client(self) -> client.AppsV1Api:
return client.AppsV1Api(api_client=self.api_client)

def get_conn(self) -> client.ApiClient:
configuration = self._get_config()
configuration.refresh_api_key_hook = self._refresh_api_key_hook
return client.ApiClient(configuration)

def _refresh_api_key_hook(self, configuration: client.configuration.Configuration):
configuration.api_key = {"authorization": self._get_token(self.get_credentials())}

def _get_config(self) -> client.configuration.Configuration:
configuration = client.Configuration(
host=self._cluster_url,
api_key_prefix={"authorization": "Bearer"},
api_key={"authorization": self._get_token(self.get_credentials())},
)
configuration.ssl_ca_cert = FileOrData(
{
"certificate-authority-data": self._ssl_ca_cert,
},
file_key_name="certificate-authority",
).as_file()
return configuration

@staticmethod
def _get_token(creds: google.auth.credentials.Credentials) -> str:
if creds.token is None or creds.expired:
auth_req = google_requests.Request()
creds.refresh(auth_req)
return creds.token
return GKEClusterConnection(
cluster_url=self._cluster_url, ssl_ca_cert=self._ssl_ca_cert, credentials=self.get_credentials()
).get_conn()

def check_kueue_deployment_running(self, name, namespace):
timeout = 300
Expand Down Expand Up @@ -419,6 +433,34 @@ def check_kueue_deployment_running(self, name, namespace):
raise AirflowException("Deployment timed out")


class GKECustomResourceHook(GoogleBaseHook, KubernetesHook):
"""Google Kubernetes Engine Custom Resource APIs."""

def __init__(
self,
cluster_url: str,
ssl_ca_cert: str,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self._cluster_url = cluster_url
self._ssl_ca_cert = ssl_ca_cert

@cached_property
def api_client(self) -> client.ApiClient:
return self.get_conn()

@cached_property
def custom_object_client(self) -> client.CustomObjectsApi:
return client.CustomObjectsApi(api_client=self.api_client)

def get_conn(self) -> client.ApiClient:
return GKEClusterConnection(
cluster_url=self._cluster_url, ssl_ca_cert=self._ssl_ca_cert, credentials=self.get_credentials()
).get_conn()


class GKEAsyncHook(GoogleBaseAsyncHook):
"""Asynchronous client of GKE."""

Expand Down Expand Up @@ -615,33 +657,9 @@ def batch_v1_client(self) -> client.BatchV1Api:
return client.BatchV1Api(self.api_client)

def get_conn(self) -> client.ApiClient:
configuration = self._get_config()
configuration.refresh_api_key_hook = self._refresh_api_key_hook
return client.ApiClient(configuration)

def _refresh_api_key_hook(self, configuration: client.configuration.Configuration):
configuration.api_key = {"authorization": self._get_token(self.get_credentials())}

def _get_config(self) -> client.configuration.Configuration:
configuration = client.Configuration(
host=self._cluster_url,
api_key_prefix={"authorization": "Bearer"},
api_key={"authorization": self._get_token(self.get_credentials())},
)
configuration.ssl_ca_cert = FileOrData(
{
"certificate-authority-data": self._ssl_ca_cert,
},
file_key_name="certificate-authority",
).as_file()
return configuration

@staticmethod
def _get_token(creds: google.auth.credentials.Credentials) -> str:
if creds.token is None or creds.expired:
auth_req = google_requests.Request()
creds.refresh(auth_req)
return creds.token
return GKEClusterConnection(
cluster_url=self._cluster_url, ssl_ca_cert=self._ssl_ca_cert, credentials=self.get_credentials()
).get_conn()


class GKEPodAsyncHook(GoogleBaseAsyncHook):
Expand Down
Loading

0 comments on commit 3b3c832

Please sign in to comment.