Skip to content

Commit

Permalink
Add GKEDeleteCustomResourceOperator operator
Browse files Browse the repository at this point in the history
  • Loading branch information
MaksYermak committed Feb 26, 2024
1 parent 86d42e1 commit 006dc8a
Show file tree
Hide file tree
Showing 8 changed files with 559 additions and 134 deletions.
34 changes: 19 additions & 15 deletions airflow/providers/cncf/kubernetes/operators/resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from __future__ import annotations

from functools import cached_property
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Sequence

import yaml
from kubernetes.utils import create_from_yaml
Expand All @@ -46,9 +46,11 @@ class KubernetesResourceBaseOperator(BaseOperator):
this parameter has no effect.
:param kubernetes_conn_id: The :ref:`kubernetes connection id <howto/connection:kubernetes>`
for the Kubernetes cluster.
:param namespaced: specified that Kubernetes resource is or isn't in a namespace.
This parameter works only when custom_resource_definition parameter is True.
"""

template_fields = ("yaml_conf",)
template_fields: Sequence[str] = ("yaml_conf",)
template_fields_renderers = {"yaml_conf": "yaml"}

def __init__(
Expand All @@ -58,6 +60,7 @@ def __init__(
namespace: str | None = None,
kubernetes_conn_id: str | None = KubernetesHook.default_conn_name,
custom_resource_definition: bool = False,
namespaced: bool = True,
config_file: str | None = None,
**kwargs,
) -> None:
Expand All @@ -66,6 +69,7 @@ def __init__(
self.kubernetes_conn_id = kubernetes_conn_id
self.yaml_conf = yaml_conf
self.custom_resource_definition = custom_resource_definition
self.namespaced = namespaced
self.config_file = config_file

@cached_property
Expand Down Expand Up @@ -109,22 +113,19 @@ class KubernetesCreateResourceOperator(KubernetesResourceBaseOperator):

def create_custom_from_yaml_object(self, body: dict):
group, version, namespace, plural = self.get_crd_fields(body)
self.custom_object_client.create_namespaced_custom_object(group=group, version=version,
namespace=namespace, plural=plural,
body=body)
if self.namespaced:
self.custom_object_client.create_namespaced_custom_object(group, version, namespace, plural, body)
else:
self.custom_object_client.create_cluster_custom_object(group, version, plural, body)

def execute(self, context) -> None:
resources = yaml.safe_load_all(self.yaml_conf)
if not self.custom_resource_definition:
try:
create_from_yaml(
k8s_client=self.client,
yaml_objects=resources,
verbose=True,
namespace=self.get_namespace(),
)
except Exception as exc:
self.log.info("Some error happened: %s", exc)
create_from_yaml(
k8s_client=self.client,
yaml_objects=resources,
namespace=self.get_namespace(),
)
else:
k8s_resource_iterator(self.create_custom_from_yaml_object, resources)

Expand All @@ -135,7 +136,10 @@ class KubernetesDeleteResourceOperator(KubernetesResourceBaseOperator):
def delete_custom_from_yaml_object(self, body: dict):
name = body["metadata"]["name"]
group, version, namespace, plural = self.get_crd_fields(body)
self.custom_object_client.delete_namespaced_custom_object(group, version, namespace, plural, name)
if self.namespaced:
self.custom_object_client.delete_namespaced_custom_object(group, version, namespace, plural, name)
else:
self.custom_object_client.delete_cluster_custom_object(group, version, plural, name)

def execute(self, context) -> None:
resources = yaml.safe_load_all(self.yaml_conf)
Expand Down
143 changes: 47 additions & 96 deletions airflow/providers/google/cloud/hooks/kubernetes_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,44 @@
OPERATIONAL_POLL_INTERVAL = 15


class GKEClusterConnection:
"""Helper for establishing connection to GKE cluster."""

def __init__(self, cluster_url: str, ssl_ca_cert: str, credentials: google.auth.credentials.Credentials):
self._cluster_url = cluster_url
self._ssl_ca_cert = ssl_ca_cert
self._credentials = credentials

def get_conn(self) -> client.ApiClient:
configuration = self._get_config()
configuration.refresh_api_key_hook = self._refresh_api_key_hook
return client.ApiClient(configuration)

def _refresh_api_key_hook(self, configuration: client.configuration.Configuration):
configuration.api_key = {"authorization": self._get_token(self._credentials)}

def _get_config(self) -> client.configuration.Configuration:
configuration = client.Configuration(
host=self._cluster_url,
api_key_prefix={"authorization": "Bearer"},
api_key={"authorization": self._get_token(self._credentials)},
)
configuration.ssl_ca_cert = FileOrData(
{
"certificate-authority-data": self._ssl_ca_cert,
},
file_key_name="certificate-authority",
).as_file()
return configuration

@staticmethod
def _get_token(creds: google.auth.credentials.Credentials) -> str:
if creds.token is None or creds.expired:
auth_req = google_requests.Request()
creds.refresh(auth_req)
return creds.token


class GKEHook(GoogleBaseHook):
"""Google Kubernetes Engine cluster APIs.
Expand Down Expand Up @@ -360,33 +398,9 @@ def apps_v1_client(self) -> client.AppsV1Api:
return client.AppsV1Api(api_client=self.api_client)

def get_conn(self) -> client.ApiClient:
configuration = self._get_config()
configuration.refresh_api_key_hook = self._refresh_api_key_hook
return client.ApiClient(configuration)

def _refresh_api_key_hook(self, configuration: client.configuration.Configuration):
configuration.api_key = {"authorization": self._get_token(self.get_credentials())}

def _get_config(self) -> client.configuration.Configuration:
configuration = client.Configuration(
host=self._cluster_url,
api_key_prefix={"authorization": "Bearer"},
api_key={"authorization": self._get_token(self.get_credentials())},
)
configuration.ssl_ca_cert = FileOrData(
{
"certificate-authority-data": self._ssl_ca_cert,
},
file_key_name="certificate-authority",
).as_file()
return configuration

@staticmethod
def _get_token(creds: google.auth.credentials.Credentials) -> str:
if creds.token is None or creds.expired:
auth_req = google_requests.Request()
creds.refresh(auth_req)
return creds.token
return GKEClusterConnection(
cluster_url=self._cluster_url, ssl_ca_cert=self._ssl_ca_cert, credentials=self.get_credentials()
).get_conn()

def check_kueue_deployment_running(self, name, namespace):
timeout = 300
Expand Down Expand Up @@ -435,54 +449,16 @@ def __init__(

@cached_property
def api_client(self) -> client.ApiClient:
self.log.info("in get conn")
return self.get_conn()

@cached_property
def core_v1_client(self) -> client.CoreV1Api:
return client.CoreV1Api(self.api_client)

@cached_property
def batch_v1_client(self) -> client.BatchV1Api:
return client.BatchV1Api(self.api_client)

@cached_property
def apps_v1_client(self) -> client.AppsV1Api:
return client.AppsV1Api(api_client=self.api_client)

@cached_property
def custom_object_client(self) -> client.CustomObjectsApi:
self.log.info("in custom hook obj")
return client.CustomObjectsApi(api_client=self.api_client)

def get_conn(self) -> client.ApiClient:
configuration = self._get_config()
configuration.refresh_api_key_hook = self._refresh_api_key_hook
return client.ApiClient(configuration)

def _refresh_api_key_hook(self, configuration: client.configuration.Configuration):
configuration.api_key = {"authorization": self._get_token(self.get_credentials())}

def _get_config(self) -> client.configuration.Configuration:
configuration = client.Configuration(
host=self._cluster_url,
api_key_prefix={"authorization": "Bearer"},
api_key={"authorization": self._get_token(self.get_credentials())},
)
configuration.ssl_ca_cert = FileOrData(
{
"certificate-authority-data": self._ssl_ca_cert,
},
file_key_name="certificate-authority",
).as_file()
return configuration

@staticmethod
def _get_token(creds: google.auth.credentials.Credentials) -> str:
if creds.token is None or creds.expired:
auth_req = google_requests.Request()
creds.refresh(auth_req)
return creds.token
return GKEClusterConnection(
cluster_url=self._cluster_url, ssl_ca_cert=self._ssl_ca_cert, credentials=self.get_credentials()
).get_conn()


class GKEAsyncHook(GoogleBaseAsyncHook):
Expand Down Expand Up @@ -589,7 +565,6 @@ def get_conn(self) -> client.ApiClient:

if self.disable_tcp_keepalive is not True:
_enable_tcp_keepalive()
self.log.info("in get_conn")

return client.ApiClient(configuration)

Expand Down Expand Up @@ -682,33 +657,9 @@ def batch_v1_client(self) -> client.BatchV1Api:
return client.BatchV1Api(self.api_client)

def get_conn(self) -> client.ApiClient:
configuration = self._get_config()
configuration.refresh_api_key_hook = self._refresh_api_key_hook
return client.ApiClient(configuration)

def _refresh_api_key_hook(self, configuration: client.configuration.Configuration):
configuration.api_key = {"authorization": self._get_token(self.get_credentials())}

def _get_config(self) -> client.configuration.Configuration:
configuration = client.Configuration(
host=self._cluster_url,
api_key_prefix={"authorization": "Bearer"},
api_key={"authorization": self._get_token(self.get_credentials())},
)
configuration.ssl_ca_cert = FileOrData(
{
"certificate-authority-data": self._ssl_ca_cert,
},
file_key_name="certificate-authority",
).as_file()
return configuration

@staticmethod
def _get_token(creds: google.auth.credentials.Credentials) -> str:
if creds.token is None or creds.expired:
auth_req = google_requests.Request()
creds.refresh(auth_req)
return creds.token
return GKEClusterConnection(
cluster_url=self._cluster_url, ssl_ca_cert=self._ssl_ca_cert, credentials=self.get_credentials()
).get_conn()


class GKEPodAsyncHook(GoogleBaseAsyncHook):
Expand Down
Loading

0 comments on commit 006dc8a

Please sign in to comment.