From 19c2abedb1a924e378ef575b6a13fb22ee76174c Mon Sep 17 00:00:00 2001 From: Kevin Date: Mon, 16 Sep 2024 10:47:24 -0400 Subject: [PATCH] wrap api client to add defaults Signed-off-by: Kevin --- src/codeflare_sdk/cluster/auth.py | 61 ++++++++++++------------ src/codeflare_sdk/cluster/awload.py | 6 +-- src/codeflare_sdk/cluster/cluster.py | 39 +++++++-------- src/codeflare_sdk/utils/generate_cert.py | 4 +- src/codeflare_sdk/utils/generate_yaml.py | 8 ++-- tests/unit_test.py | 15 +++--- 6 files changed, 66 insertions(+), 67 deletions(-) diff --git a/src/codeflare_sdk/cluster/auth.py b/src/codeflare_sdk/cluster/auth.py index c39fe1d4a..fbba0c226 100644 --- a/src/codeflare_sdk/cluster/auth.py +++ b/src/codeflare_sdk/cluster/auth.py @@ -93,17 +93,7 @@ def __init__( self.token = token self.server = server self.skip_tls = skip_tls - self.ca_cert_path = self._gen_ca_cert_path(ca_cert_path) - - def _gen_ca_cert_path(self, ca_cert_path: str): - if ca_cert_path is not None: - return ca_cert_path - elif "CF_SDK_CA_CERT_PATH" in os.environ: - return os.environ.get("CF_SDK_CA_CERT_PATH") - elif os.path.exists(WORKBENCH_CA_CERT_PATH): - return WORKBENCH_CA_CERT_PATH - else: - return None + self.ca_cert_path = _gen_ca_cert_path(ca_cert_path) def login(self) -> str: """ @@ -119,25 +109,14 @@ def login(self) -> str: configuration.host = self.server configuration.api_key["authorization"] = self.token + api_client = client.ApiClient(configuration) if not self.skip_tls: - if self.ca_cert_path is None: - configuration.ssl_ca_cert = None - elif os.path.isfile(self.ca_cert_path): - print( - f"Authenticated with certificate located at {self.ca_cert_path}" - ) - configuration.ssl_ca_cert = self.ca_cert_path - else: - raise FileNotFoundError( - f"Certificate file not found at {self.ca_cert_path}" - ) - configuration.verify_ssl = True + _client_with_cert(api_client, self.ca_cert_path) else: urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) print("Insecure request warnings have been disabled") configuration.verify_ssl = False - api_client = client.ApiClient(configuration) client.AuthenticationApi(api_client).get_api_group() config_path = None return "Logged into %s" % self.server @@ -211,11 +190,33 @@ def config_check() -> str: return config_path -def api_config_handler() -> Optional[client.ApiClient]: - """ - This function is used to load the api client if the user has logged in - """ - if api_client != None and config_path == None: - return api_client +def _client_with_cert(client: client.ApiClient, ca_cert_path: Optional[str] = None): + client.configuration.verify_ssl = True + cert_path = _gen_ca_cert_path(ca_cert_path) + if cert_path is None: + client.configuration.ssl_ca_cert = None + elif os.path.isfile(cert_path): + client.configuration.ssl_ca_cert = cert_path + else: + raise FileNotFoundError(f"Certificate file not found at {cert_path}") + + +def _gen_ca_cert_path(ca_cert_path: Optional[str]): + """Gets the path to the default CA certificate file either through env config or default path""" + if ca_cert_path is not None: + return ca_cert_path + elif "CF_SDK_CA_CERT_PATH" in os.environ: + return os.environ.get("CF_SDK_CA_CERT_PATH") + elif os.path.exists(WORKBENCH_CA_CERT_PATH): + return WORKBENCH_CA_CERT_PATH else: return None + + +def get_api_client() -> client.ApiClient: + "This function should load the api client with defaults" + if api_client != None: + return api_client + to_return = client.ApiClient() + _client_with_cert(to_return) + return to_return diff --git a/src/codeflare_sdk/cluster/awload.py b/src/codeflare_sdk/cluster/awload.py index 7455b2161..1ead59146 100644 --- a/src/codeflare_sdk/cluster/awload.py +++ b/src/codeflare_sdk/cluster/awload.py @@ -24,7 +24,7 @@ from kubernetes import client, config from ..utils.kube_api_helpers import _kube_api_error_handling -from .auth import config_check, api_config_handler +from .auth import config_check, get_api_client class AWManager: @@ -59,7 +59,7 @@ def submit(self) -> None: """ try: config_check() - api_instance = client.CustomObjectsApi(api_config_handler()) + api_instance = client.CustomObjectsApi(get_api_client()) api_instance.create_namespaced_custom_object( group="workload.codeflare.dev", version="v1beta2", @@ -84,7 +84,7 @@ def remove(self) -> None: try: config_check() - api_instance = client.CustomObjectsApi(api_config_handler()) + api_instance = client.CustomObjectsApi(get_api_client()) api_instance.delete_namespaced_custom_object( group="workload.codeflare.dev", version="v1beta2", diff --git a/src/codeflare_sdk/cluster/cluster.py b/src/codeflare_sdk/cluster/cluster.py index 7c652a186..79c353116 100644 --- a/src/codeflare_sdk/cluster/cluster.py +++ b/src/codeflare_sdk/cluster/cluster.py @@ -25,7 +25,7 @@ from kubernetes import config from ray.job_submission import JobSubmissionClient -from .auth import config_check, api_config_handler +from .auth import config_check, get_api_client from ..utils import pretty_print from ..utils.generate_yaml import ( generate_appwrapper, @@ -80,7 +80,7 @@ def __init__(self, config: ClusterConfiguration): @property def _client_headers(self): - k8_client = api_config_handler() or client.ApiClient() + k8_client = get_api_client() return { "Authorization": k8_client.configuration.get_api_key_with_prefix( "authorization" @@ -95,7 +95,7 @@ def _client_verify_tls(self): @property def job_client(self): - k8client = api_config_handler() or client.ApiClient() + k8client = get_api_client() if self._job_submission_client: return self._job_submission_client if is_openshift_cluster(): @@ -141,7 +141,7 @@ def up(self): try: config_check() - api_instance = client.CustomObjectsApi(api_config_handler()) + api_instance = client.CustomObjectsApi(get_api_client()) if self.config.appwrapper: if self.config.write_to_file: with open(self.app_wrapper_yaml) as f: @@ -172,7 +172,7 @@ def up(self): return _kube_api_error_handling(e) def _throw_for_no_raycluster(self): - api_instance = client.CustomObjectsApi(api_config_handler()) + api_instance = client.CustomObjectsApi(get_api_client()) try: api_instance.list_namespaced_custom_object( group="ray.io", @@ -199,7 +199,7 @@ def down(self): self._throw_for_no_raycluster() try: config_check() - api_instance = client.CustomObjectsApi(api_config_handler()) + api_instance = client.CustomObjectsApi(get_api_client()) if self.config.appwrapper: api_instance.delete_namespaced_custom_object( group="workload.codeflare.dev", @@ -358,7 +358,7 @@ def cluster_dashboard_uri(self) -> str: config_check() if is_openshift_cluster(): try: - api_instance = client.CustomObjectsApi(api_config_handler()) + api_instance = client.CustomObjectsApi(get_api_client()) routes = api_instance.list_namespaced_custom_object( group="route.openshift.io", version="v1", @@ -380,7 +380,7 @@ def cluster_dashboard_uri(self) -> str: return f"{protocol}://{route['spec']['host']}" else: try: - api_instance = client.NetworkingV1Api(api_config_handler()) + api_instance = client.NetworkingV1Api(get_api_client()) ingresses = api_instance.list_namespaced_ingress(self.config.namespace) except Exception as e: # pragma no cover return _kube_api_error_handling(e) @@ -579,9 +579,6 @@ def get_current_namespace(): # pragma: no cover return active_context except Exception as e: print("Unable to find current namespace") - - if api_config_handler() != None: - return None print("trying to gather from current context") try: _, active_context = config.list_kube_config_contexts(config_check()) @@ -601,7 +598,7 @@ def get_cluster( ): try: config_check() - api_instance = client.CustomObjectsApi(api_config_handler()) + api_instance = client.CustomObjectsApi(get_api_client()) rcs = api_instance.list_namespaced_custom_object( group="ray.io", version="v1", @@ -656,7 +653,7 @@ def _create_resources(yamls, namespace: str, api_instance: client.CustomObjectsA def _check_aw_exists(name: str, namespace: str) -> bool: try: config_check() - api_instance = client.CustomObjectsApi(api_config_handler()) + api_instance = client.CustomObjectsApi(get_api_client()) aws = api_instance.list_namespaced_custom_object( group="workload.codeflare.dev", version="v1beta2", @@ -683,7 +680,7 @@ def _get_ingress_domain(self): # pragma: no cover if is_openshift_cluster(): try: - api_instance = client.CustomObjectsApi(api_config_handler()) + api_instance = client.CustomObjectsApi(get_api_client()) routes = api_instance.list_namespaced_custom_object( group="route.openshift.io", @@ -702,7 +699,7 @@ def _get_ingress_domain(self): # pragma: no cover domain = route["spec"]["host"] else: try: - api_client = client.NetworkingV1Api(api_config_handler()) + api_client = client.NetworkingV1Api(get_api_client()) ingresses = api_client.list_namespaced_ingress(namespace) except Exception as e: # pragma: no cover return _kube_api_error_handling(e) @@ -716,7 +713,7 @@ def _get_ingress_domain(self): # pragma: no cover def _app_wrapper_status(name, namespace="default") -> Optional[AppWrapper]: try: config_check() - api_instance = client.CustomObjectsApi(api_config_handler()) + api_instance = client.CustomObjectsApi(get_api_client()) aws = api_instance.list_namespaced_custom_object( group="workload.codeflare.dev", version="v1beta2", @@ -735,7 +732,7 @@ def _app_wrapper_status(name, namespace="default") -> Optional[AppWrapper]: def _ray_cluster_status(name, namespace="default") -> Optional[RayCluster]: try: config_check() - api_instance = client.CustomObjectsApi(api_config_handler()) + api_instance = client.CustomObjectsApi(get_api_client()) rcs = api_instance.list_namespaced_custom_object( group="ray.io", version="v1", @@ -757,7 +754,7 @@ def _get_ray_clusters( list_of_clusters = [] try: config_check() - api_instance = client.CustomObjectsApi(api_config_handler()) + api_instance = client.CustomObjectsApi(get_api_client()) rcs = api_instance.list_namespaced_custom_object( group="ray.io", version="v1", @@ -786,7 +783,7 @@ def _get_app_wrappers( try: config_check() - api_instance = client.CustomObjectsApi(api_config_handler()) + api_instance = client.CustomObjectsApi(get_api_client()) aws = api_instance.list_namespaced_custom_object( group="workload.codeflare.dev", version="v1beta2", @@ -815,7 +812,7 @@ def _map_to_ray_cluster(rc) -> Optional[RayCluster]: dashboard_url = None if is_openshift_cluster(): try: - api_instance = client.CustomObjectsApi(api_config_handler()) + api_instance = client.CustomObjectsApi(get_api_client()) routes = api_instance.list_namespaced_custom_object( group="route.openshift.io", version="v1", @@ -834,7 +831,7 @@ def _map_to_ray_cluster(rc) -> Optional[RayCluster]: dashboard_url = f"{protocol}://{route['spec']['host']}" else: try: - api_instance = client.NetworkingV1Api(api_config_handler()) + api_instance = client.NetworkingV1Api(get_api_client()) ingresses = api_instance.list_namespaced_ingress( rc["metadata"]["namespace"] ) diff --git a/src/codeflare_sdk/utils/generate_cert.py b/src/codeflare_sdk/utils/generate_cert.py index 5de56882b..f3dc80e94 100644 --- a/src/codeflare_sdk/utils/generate_cert.py +++ b/src/codeflare_sdk/utils/generate_cert.py @@ -19,7 +19,7 @@ from cryptography import x509 from cryptography.x509.oid import NameOID import datetime -from ..cluster.auth import config_check, api_config_handler +from ..cluster.auth import config_check, get_api_client from kubernetes import client, config from .kube_api_helpers import _kube_api_error_handling @@ -103,7 +103,7 @@ def generate_tls_cert(cluster_name, namespace, days=30): # oc get secret ca-secret- -o template='{{index .data "ca.key"}}' # oc get secret ca-secret- -o template='{{index .data "ca.crt"}}'|base64 -d > ${TLSDIR}/ca.crt config_check() - v1 = client.CoreV1Api(api_config_handler()) + v1 = client.CoreV1Api(get_api_client()) # Secrets have a suffix appended to the end so we must list them and gather the secret that includes cluster_name-ca-secret- secret_name = get_secret_name(cluster_name, namespace, v1) diff --git a/src/codeflare_sdk/utils/generate_yaml.py b/src/codeflare_sdk/utils/generate_yaml.py index c4e1755d8..7a17e0103 100755 --- a/src/codeflare_sdk/utils/generate_yaml.py +++ b/src/codeflare_sdk/utils/generate_yaml.py @@ -27,7 +27,7 @@ import uuid from kubernetes import client, config from .kube_api_helpers import _kube_api_error_handling -from ..cluster.auth import api_config_handler, config_check +from ..cluster.auth import get_api_client, config_check from os import urandom from base64 import b64encode from urllib3.util import parse_url @@ -57,7 +57,7 @@ def gen_names(name): def is_openshift_cluster(): try: config_check() - for api in client.ApisApi(api_config_handler()).get_api_versions().groups: + for api in client.ApisApi(get_api_client()).get_api_versions().groups: for v in api.versions: if "route.openshift.io/v1" in v.group_version: return True @@ -235,7 +235,7 @@ def get_default_kueue_name(namespace: str): # If the local queue is set, use it. Otherwise, try to use the default queue. try: config_check() - api_instance = client.CustomObjectsApi(api_config_handler()) + api_instance = client.CustomObjectsApi(get_api_client()) local_queues = api_instance.list_namespaced_custom_object( group="kueue.x-k8s.io", version="v1beta1", @@ -261,7 +261,7 @@ def local_queue_exists(namespace: str, local_queue_name: str): # get all local queues in the namespace try: config_check() - api_instance = client.CustomObjectsApi(api_config_handler()) + api_instance = client.CustomObjectsApi(get_api_client()) local_queues = api_instance.list_namespaced_custom_object( group="kueue.x-k8s.io", version="v1beta1", diff --git a/tests/unit_test.py b/tests/unit_test.py index 388723c50..cdb2d1117 100644 --- a/tests/unit_test.py +++ b/tests/unit_test.py @@ -135,7 +135,7 @@ def test_token_auth_creation(): assert token_auth.skip_tls == True assert token_auth.ca_cert_path == None - os.environ["CF_SDK_CA_CERT_PATH"] = f"/etc/pki/tls/custom-certs/ca-bundle.crt" + os.environ["CF_SDK_CA_CERT_PATH"] = "/etc/pki/tls/custom-certs/ca-bundle.crt" token_auth = TokenAuthentication(token="token", server="server", skip_tls=False) assert token_auth.token == "token" assert token_auth.server == "server" @@ -154,7 +154,8 @@ def test_token_auth_creation(): assert token_auth.skip_tls == False assert token_auth.ca_cert_path == f"{parent}/tests/auth-test.crt" - except Exception: + except Exception as e: + raise e assert 0 == 1 @@ -295,6 +296,7 @@ def test_cluster_creation(mocker): ) +@patch.dict("os.environ", {"NB_PREFIX": "test-prefix"}) def test_cluster_no_kueue_no_aw(mocker): mocker.patch("kubernetes.client.ApisApi.get_api_versions") mocker.patch( @@ -302,7 +304,6 @@ def test_cluster_no_kueue_no_aw(mocker): return_value={"spec": {"domain": "apps.cluster.awsroute.org"}}, ) mocker.patch("kubernetes.client.CustomObjectsApi.list_namespaced_custom_object") - mocker.patch("os.environ.get", return_value="test-prefix") config = createClusterConfig() config.appwrapper = False config.name = "unit-test-no-kueue" @@ -351,6 +352,7 @@ def get_local_queue(group, version, namespace, plural): return local_queues +@patch.dict("os.environ", {"NB_PREFIX": "test-prefix"}) def test_cluster_creation_no_mcad(mocker): # Create Ray Cluster with no local queue specified mocker.patch("kubernetes.client.ApisApi.get_api_versions") @@ -362,7 +364,6 @@ def test_cluster_creation_no_mcad(mocker): "kubernetes.client.CustomObjectsApi.list_namespaced_custom_object", return_value=get_local_queue("kueue.x-k8s.io", "v1beta1", "ns", "localqueues"), ) - mocker.patch("os.environ.get", return_value="test-prefix") config = createClusterConfig() config.name = "unit-test-cluster-ray" @@ -380,6 +381,7 @@ def test_cluster_creation_no_mcad(mocker): ) +@patch.dict("os.environ", {"NB_PREFIX": "test-prefix"}) def test_cluster_creation_no_mcad_local_queue(mocker): # With written resources # Create Ray Cluster with local queue specified @@ -392,7 +394,6 @@ def test_cluster_creation_no_mcad_local_queue(mocker): "kubernetes.client.CustomObjectsApi.list_namespaced_custom_object", return_value=get_local_queue("kueue.x-k8s.io", "v1beta1", "ns", "localqueues"), ) - mocker.patch("os.environ.get", return_value="test-prefix") config = createClusterConfig() config.name = "unit-test-cluster-ray" config.appwrapper = False @@ -461,6 +462,7 @@ def test_default_cluster_creation(mocker): assert cluster.config.namespace == "opendatahub" +@patch.dict("os.environ", {"NB_PREFIX": "test-prefix"}) def test_cluster_creation_with_custom_image(mocker): # With written resources # Create Ray Cluster with local queue specified @@ -473,7 +475,6 @@ def test_cluster_creation_with_custom_image(mocker): "kubernetes.client.CustomObjectsApi.list_namespaced_custom_object", return_value=get_local_queue("kueue.x-k8s.io", "v1beta1", "ns", "localqueues"), ) - mocker.patch("os.environ.get", return_value="test-prefix") config = createClusterConfig() config.name = "unit-test-cluster-custom-image" config.appwrapper = False @@ -2169,7 +2170,7 @@ def test_map_to_ray_cluster(mocker): mock_api_client = mocker.MagicMock(spec=client.ApiClient) mocker.patch( - "codeflare_sdk.cluster.auth.api_config_handler", return_value=mock_api_client + "codeflare_sdk.cluster.auth.get_api_client", return_value=mock_api_client ) mock_routes = {