Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

wrap api client to add defaults #669

Merged
merged 1 commit into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 31 additions & 30 deletions src/codeflare_sdk/cluster/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,17 +93,7 @@ def __init__(
self.token = token
self.server = server
self.skip_tls = skip_tls
self.ca_cert_path = self._gen_ca_cert_path(ca_cert_path)

def _gen_ca_cert_path(self, ca_cert_path: str):
if ca_cert_path is not None:
return ca_cert_path
elif "CF_SDK_CA_CERT_PATH" in os.environ:
return os.environ.get("CF_SDK_CA_CERT_PATH")
elif os.path.exists(WORKBENCH_CA_CERT_PATH):
return WORKBENCH_CA_CERT_PATH
else:
return None
self.ca_cert_path = _gen_ca_cert_path(ca_cert_path)

def login(self) -> str:
"""
Expand All @@ -119,25 +109,14 @@ def login(self) -> str:
configuration.host = self.server
configuration.api_key["authorization"] = self.token

api_client = client.ApiClient(configuration)
if not self.skip_tls:
if self.ca_cert_path is None:
configuration.ssl_ca_cert = None
elif os.path.isfile(self.ca_cert_path):
print(
f"Authenticated with certificate located at {self.ca_cert_path}"
)
configuration.ssl_ca_cert = self.ca_cert_path
else:
raise FileNotFoundError(
f"Certificate file not found at {self.ca_cert_path}"
)
configuration.verify_ssl = True
_client_with_cert(api_client, self.ca_cert_path)
else:
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
print("Insecure request warnings have been disabled")
configuration.verify_ssl = False

api_client = client.ApiClient(configuration)
client.AuthenticationApi(api_client).get_api_group()
config_path = None
return "Logged into %s" % self.server
Expand Down Expand Up @@ -211,11 +190,33 @@ def config_check() -> str:
return config_path


def api_config_handler() -> Optional[client.ApiClient]:
"""
This function is used to load the api client if the user has logged in
"""
if api_client != None and config_path == None:
return api_client
def _client_with_cert(client: client.ApiClient, ca_cert_path: Optional[str] = None):
client.configuration.verify_ssl = True
cert_path = _gen_ca_cert_path(ca_cert_path)
if cert_path is None:
client.configuration.ssl_ca_cert = None
elif os.path.isfile(cert_path):
client.configuration.ssl_ca_cert = cert_path
else:
raise FileNotFoundError(f"Certificate file not found at {cert_path}")


def _gen_ca_cert_path(ca_cert_path: Optional[str]):
"""Gets the path to the default CA certificate file either through env config or default path"""
if ca_cert_path is not None:
return ca_cert_path
elif "CF_SDK_CA_CERT_PATH" in os.environ:
return os.environ.get("CF_SDK_CA_CERT_PATH")
elif os.path.exists(WORKBENCH_CA_CERT_PATH):
return WORKBENCH_CA_CERT_PATH
else:
return None


def get_api_client() -> client.ApiClient:
"This function should load the api client with defaults"
if api_client != None:
return api_client
to_return = client.ApiClient()
_client_with_cert(to_return)
return to_return
6 changes: 3 additions & 3 deletions src/codeflare_sdk/cluster/awload.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

from kubernetes import client, config
from ..utils.kube_api_helpers import _kube_api_error_handling
from .auth import config_check, api_config_handler
from .auth import config_check, get_api_client


class AWManager:
Expand Down Expand Up @@ -59,7 +59,7 @@ def submit(self) -> None:
"""
try:
config_check()
api_instance = client.CustomObjectsApi(api_config_handler())
api_instance = client.CustomObjectsApi(get_api_client())
api_instance.create_namespaced_custom_object(
group="workload.codeflare.dev",
version="v1beta2",
Expand All @@ -84,7 +84,7 @@ def remove(self) -> None:

try:
config_check()
api_instance = client.CustomObjectsApi(api_config_handler())
api_instance = client.CustomObjectsApi(get_api_client())
api_instance.delete_namespaced_custom_object(
group="workload.codeflare.dev",
version="v1beta2",
Expand Down
39 changes: 18 additions & 21 deletions src/codeflare_sdk/cluster/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from kubernetes import config
from ray.job_submission import JobSubmissionClient

from .auth import config_check, api_config_handler
from .auth import config_check, get_api_client
from ..utils import pretty_print
from ..utils.generate_yaml import (
generate_appwrapper,
Expand Down Expand Up @@ -81,7 +81,7 @@ def __init__(self, config: ClusterConfiguration):

@property
def _client_headers(self):
k8_client = api_config_handler() or client.ApiClient()
k8_client = get_api_client()
return {
"Authorization": k8_client.configuration.get_api_key_with_prefix(
"authorization"
Expand All @@ -96,7 +96,7 @@ def _client_verify_tls(self):

@property
def job_client(self):
k8client = api_config_handler() or client.ApiClient()
k8client = get_api_client()
if self._job_submission_client:
return self._job_submission_client
if is_openshift_cluster():
Expand Down Expand Up @@ -142,7 +142,7 @@ def up(self):

try:
config_check()
api_instance = client.CustomObjectsApi(api_config_handler())
api_instance = client.CustomObjectsApi(get_api_client())
if self.config.appwrapper:
if self.config.write_to_file:
with open(self.app_wrapper_yaml) as f:
Expand Down Expand Up @@ -173,7 +173,7 @@ def up(self):
return _kube_api_error_handling(e)

def _throw_for_no_raycluster(self):
api_instance = client.CustomObjectsApi(api_config_handler())
api_instance = client.CustomObjectsApi(get_api_client())
try:
api_instance.list_namespaced_custom_object(
group="ray.io",
Expand All @@ -200,7 +200,7 @@ def down(self):
self._throw_for_no_raycluster()
try:
config_check()
api_instance = client.CustomObjectsApi(api_config_handler())
api_instance = client.CustomObjectsApi(get_api_client())
if self.config.appwrapper:
api_instance.delete_namespaced_custom_object(
group="workload.codeflare.dev",
Expand Down Expand Up @@ -359,7 +359,7 @@ def cluster_dashboard_uri(self) -> str:
config_check()
if is_openshift_cluster():
try:
api_instance = client.CustomObjectsApi(api_config_handler())
api_instance = client.CustomObjectsApi(get_api_client())
routes = api_instance.list_namespaced_custom_object(
group="route.openshift.io",
version="v1",
Expand All @@ -381,7 +381,7 @@ def cluster_dashboard_uri(self) -> str:
return f"{protocol}://{route['spec']['host']}"
else:
try:
api_instance = client.NetworkingV1Api(api_config_handler())
api_instance = client.NetworkingV1Api(get_api_client())
ingresses = api_instance.list_namespaced_ingress(self.config.namespace)
except Exception as e: # pragma no cover
return _kube_api_error_handling(e)
Expand Down Expand Up @@ -580,9 +580,6 @@ def get_current_namespace(): # pragma: no cover
return active_context
except Exception as e:
print("Unable to find current namespace")

if api_config_handler() != None:
return None
print("trying to gather from current context")
try:
_, active_context = config.list_kube_config_contexts(config_check())
Expand All @@ -602,7 +599,7 @@ def get_cluster(
):
try:
config_check()
api_instance = client.CustomObjectsApi(api_config_handler())
api_instance = client.CustomObjectsApi(get_api_client())
rcs = api_instance.list_namespaced_custom_object(
group="ray.io",
version="v1",
Expand Down Expand Up @@ -657,7 +654,7 @@ def _create_resources(yamls, namespace: str, api_instance: client.CustomObjectsA
def _check_aw_exists(name: str, namespace: str) -> bool:
try:
config_check()
api_instance = client.CustomObjectsApi(api_config_handler())
api_instance = client.CustomObjectsApi(get_api_client())
aws = api_instance.list_namespaced_custom_object(
group="workload.codeflare.dev",
version="v1beta2",
Expand All @@ -684,7 +681,7 @@ def _get_ingress_domain(self): # pragma: no cover

if is_openshift_cluster():
try:
api_instance = client.CustomObjectsApi(api_config_handler())
api_instance = client.CustomObjectsApi(get_api_client())

routes = api_instance.list_namespaced_custom_object(
group="route.openshift.io",
Expand All @@ -703,7 +700,7 @@ def _get_ingress_domain(self): # pragma: no cover
domain = route["spec"]["host"]
else:
try:
api_client = client.NetworkingV1Api(api_config_handler())
api_client = client.NetworkingV1Api(get_api_client())
ingresses = api_client.list_namespaced_ingress(namespace)
except Exception as e: # pragma: no cover
return _kube_api_error_handling(e)
Expand All @@ -717,7 +714,7 @@ def _get_ingress_domain(self): # pragma: no cover
def _app_wrapper_status(name, namespace="default") -> Optional[AppWrapper]:
try:
config_check()
api_instance = client.CustomObjectsApi(api_config_handler())
api_instance = client.CustomObjectsApi(get_api_client())
aws = api_instance.list_namespaced_custom_object(
group="workload.codeflare.dev",
version="v1beta2",
Expand All @@ -736,7 +733,7 @@ def _app_wrapper_status(name, namespace="default") -> Optional[AppWrapper]:
def _ray_cluster_status(name, namespace="default") -> Optional[RayCluster]:
try:
config_check()
api_instance = client.CustomObjectsApi(api_config_handler())
api_instance = client.CustomObjectsApi(get_api_client())
rcs = api_instance.list_namespaced_custom_object(
group="ray.io",
version="v1",
Expand All @@ -758,7 +755,7 @@ def _get_ray_clusters(
list_of_clusters = []
try:
config_check()
api_instance = client.CustomObjectsApi(api_config_handler())
api_instance = client.CustomObjectsApi(get_api_client())
rcs = api_instance.list_namespaced_custom_object(
group="ray.io",
version="v1",
Expand Down Expand Up @@ -787,7 +784,7 @@ def _get_app_wrappers(

try:
config_check()
api_instance = client.CustomObjectsApi(api_config_handler())
api_instance = client.CustomObjectsApi(get_api_client())
aws = api_instance.list_namespaced_custom_object(
group="workload.codeflare.dev",
version="v1beta2",
Expand Down Expand Up @@ -816,7 +813,7 @@ def _map_to_ray_cluster(rc) -> Optional[RayCluster]:
dashboard_url = None
if is_openshift_cluster():
try:
api_instance = client.CustomObjectsApi(api_config_handler())
api_instance = client.CustomObjectsApi(get_api_client())
routes = api_instance.list_namespaced_custom_object(
group="route.openshift.io",
version="v1",
Expand All @@ -835,7 +832,7 @@ def _map_to_ray_cluster(rc) -> Optional[RayCluster]:
dashboard_url = f"{protocol}://{route['spec']['host']}"
else:
try:
api_instance = client.NetworkingV1Api(api_config_handler())
api_instance = client.NetworkingV1Api(get_api_client())
ingresses = api_instance.list_namespaced_ingress(
rc["metadata"]["namespace"]
)
Expand Down
4 changes: 2 additions & 2 deletions src/codeflare_sdk/cluster/widgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from .config import ClusterConfiguration
from .model import RayClusterStatus
from ..utils.kube_api_helpers import _kube_api_error_handling
from .auth import config_check, api_config_handler
from .auth import config_check, get_api_client


def cluster_up_down_buttons(cluster: "codeflare_sdk.cluster.Cluster") -> widgets.Button:
Expand Down Expand Up @@ -343,7 +343,7 @@ def _delete_cluster(

try:
config_check()
api_instance = client.CustomObjectsApi(api_config_handler())
api_instance = client.CustomObjectsApi(get_api_client())

if _check_aw_exists(cluster_name, namespace):
api_instance.delete_namespaced_custom_object(
Expand Down
4 changes: 2 additions & 2 deletions src/codeflare_sdk/utils/generate_cert.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from cryptography import x509
from cryptography.x509.oid import NameOID
import datetime
from ..cluster.auth import config_check, api_config_handler
from ..cluster.auth import config_check, get_api_client
from kubernetes import client, config
from .kube_api_helpers import _kube_api_error_handling

Expand Down Expand Up @@ -103,7 +103,7 @@ def generate_tls_cert(cluster_name, namespace, days=30):
# oc get secret ca-secret-<cluster-name> -o template='{{index .data "ca.key"}}'
# oc get secret ca-secret-<cluster-name> -o template='{{index .data "ca.crt"}}'|base64 -d > ${TLSDIR}/ca.crt
config_check()
v1 = client.CoreV1Api(api_config_handler())
v1 = client.CoreV1Api(get_api_client())

# Secrets have a suffix appended to the end so we must list them and gather the secret that includes cluster_name-ca-secret-
secret_name = get_secret_name(cluster_name, namespace, v1)
Expand Down
8 changes: 4 additions & 4 deletions src/codeflare_sdk/utils/generate_yaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import uuid
from kubernetes import client, config
from .kube_api_helpers import _kube_api_error_handling
from ..cluster.auth import api_config_handler, config_check
from ..cluster.auth import get_api_client, config_check
from os import urandom
from base64 import b64encode
from urllib3.util import parse_url
Expand Down Expand Up @@ -57,7 +57,7 @@ def gen_names(name):
def is_openshift_cluster():
try:
config_check()
for api in client.ApisApi(api_config_handler()).get_api_versions().groups:
for api in client.ApisApi(get_api_client()).get_api_versions().groups:
for v in api.versions:
if "route.openshift.io/v1" in v.group_version:
return True
Expand Down Expand Up @@ -235,7 +235,7 @@ def get_default_kueue_name(namespace: str):
# If the local queue is set, use it. Otherwise, try to use the default queue.
try:
config_check()
api_instance = client.CustomObjectsApi(api_config_handler())
api_instance = client.CustomObjectsApi(get_api_client())
local_queues = api_instance.list_namespaced_custom_object(
group="kueue.x-k8s.io",
version="v1beta1",
Expand All @@ -261,7 +261,7 @@ def local_queue_exists(namespace: str, local_queue_name: str):
# get all local queues in the namespace
try:
config_check()
api_instance = client.CustomObjectsApi(api_config_handler())
api_instance = client.CustomObjectsApi(get_api_client())
local_queues = api_instance.list_namespaced_custom_object(
group="kueue.x-k8s.io",
version="v1beta1",
Expand Down
Loading