diff --git a/scripts/plot_socket_profile.py b/scripts/plot_socket_profile.py index c9e364abc..96a6b76c4 100644 --- a/scripts/plot_socket_profile.py +++ b/scripts/plot_socket_profile.py @@ -6,7 +6,7 @@ import matplotlib.pyplot as plt # type: ignore from tqdm import tqdm -from skyplane import skyplane_root +from skyplane import __root__ def plot(file): @@ -28,7 +28,7 @@ def plot(file): parser = argparse.ArgumentParser() parser.add_argument("profile_file", help="Path to the profile file") parser.add_argument( - "--plot_dir", default=skyplane_root / "data" / "figures" / "socket_profiles", help="Path to the directory where to save the plot" + "--plot_dir", default=__root__ / "data" / "figures" / "socket_profiles", help="Path to the directory where to save the plot" ) args = parser.parse_args() diff --git a/skyplane/__init__.py b/skyplane/__init__.py index 858d601f9..0501f5403 100644 --- a/skyplane/__init__.py +++ b/skyplane/__init__.py @@ -1,54 +1,22 @@ -import os from pathlib import Path -from skyplane.config import SkyplaneConfig -from skyplane.gateway_version import gateway_version - # version __version__ = "0.2.1" # paths -skyplane_root = Path(__file__).parent.parent -config_root = Path("~/.skyplane").expanduser() -config_root.mkdir(exist_ok=True) - -if "SKYPLANE_CONFIG" in os.environ: - config_path = Path(os.environ["SKYPLANE_CONFIG"]).expanduser() -else: - config_path = config_root / "config" - -aws_config_path = config_root / "aws_config" -azure_config_path = config_root / "azure_config" -azure_sku_path = config_root / "azure_sku_mapping" -gcp_config_path = config_root / "gcp_config" - -key_root = config_root / "keys" +__root__ = Path(__file__).parent.parent +__config_root__ = Path("~/.skyplane").expanduser() +__config_root__.mkdir(exist_ok=True) tmp_log_dir = Path("/tmp/skyplane") -tmp_log_dir.mkdir(exist_ok=True) - -# definitions -KB = 1024 -MB = 1024 * 1024 -GB = 1024 * 1024 * 1024 - - -def format_bytes(bytes: int): - if bytes < KB: - return f"{bytes}B" - elif bytes < MB: - return f"{bytes / KB:.2f}KB" - elif bytes < GB: - return f"{bytes / MB:.2f}MB" - else: - return f"{bytes / GB:.2f}GB" - - -if config_path.exists(): - cloud_config = SkyplaneConfig.load_config(config_path) -else: - cloud_config = SkyplaneConfig.default_config() -is_gateway_env = os.environ.get("SKYPLANE_IS_GATEWAY", None) == "1" -# load gateway docker image version -def gateway_docker_image(): - return "public.ecr.aws/s6m1p0n8/skyplane:" + gateway_version +__all__ = [ + "__root__", + "__config_root__", + "__version__", + "SkyplaneClient", + "Dataplane", + "TransferConfig", + "AWSConfig", + "AzureConfig", + "GCPConfig", +] diff --git a/skyplane/cli/__init__.py b/skyplane/cli/__init__.py new file mode 100644 index 000000000..23675a404 --- /dev/null +++ b/skyplane/cli/__init__.py @@ -0,0 +1,26 @@ +import functools +import os +from pathlib import Path + +from skyplane import __config_root__ +from skyplane.config import SkyplaneConfig + + +@functools.lru_cache +def load_config_path(): + if "SKYPLANE_CONFIG" in os.environ: + return Path(os.environ["SKYPLANE_CONFIG"]).expanduser() + else: + return __config_root__ / "config" + + +@functools.lru_cache +def load_cloud_config(path): + if path.exists(): + return SkyplaneConfig.load_config(path) + else: + return SkyplaneConfig.default_config() + + +config_path = load_config_path() +cloud_config = load_cloud_config(config_path) diff --git a/skyplane/cli/cli.py b/skyplane/cli/cli.py index e52672b3a..ea67b5ecc 100644 --- a/skyplane/cli/cli.py +++ b/skyplane/cli/cli.py @@ -22,7 +22,9 @@ import skyplane.cli.usage.client import skyplane.cli.usage.definitions import skyplane.cli.usage.definitions -from skyplane import GB, cloud_config, config_path, exceptions, skyplane_root +from skyplane import exceptions, __root__ +from skyplane.cli import config_path, cloud_config +from skyplane.utils.definitions import GB from skyplane.cli.cli_impl.cp_replicate import ( confirm_transfer, enrich_dest_objs, @@ -76,9 +78,7 @@ def cp( # solver solve: bool = typer.Option(False, help="If true, will use solver to optimize transfer, else direct path is chosen"), solver_target_tput_per_vm_gbits: float = typer.Option(4, help="Solver option: Required throughput in Gbps"), - solver_throughput_grid: Path = typer.Option( - skyplane_root / "profiles" / "throughput.csv", "--throughput-grid", help="Throughput grid file" - ), + solver_throughput_grid: Path = typer.Option(__root__ / "profiles" / "throughput.csv", "--throughput-grid", help="Throughput grid file"), solver_verbose: bool = False, ): """ @@ -278,9 +278,7 @@ def sync( # solver solve: bool = typer.Option(False, help="If true, will use solver to optimize transfer, else direct path is chosen"), solver_target_tput_per_vm_gbits: float = typer.Option(4, help="Solver option: Required throughput in Gbps per instance"), - solver_throughput_grid: Path = typer.Option( - skyplane_root / "profiles" / "throughput.csv", "--throughput-grid", help="Throughput grid file" - ), + solver_throughput_grid: Path = typer.Option(__root__ / "profiles" / "throughput.csv", "--throughput-grid", help="Throughput grid file"), solver_verbose: bool = False, ): """ diff --git a/skyplane/cli/cli_aws.py b/skyplane/cli/cli_aws.py index 7b48579ea..5619423c9 100644 --- a/skyplane/cli/cli_aws.py +++ b/skyplane/cli/cli_aws.py @@ -7,7 +7,7 @@ import typer -from skyplane import GB +from skyplane.utils.definitions import GB from skyplane.compute.aws.aws_auth import AWSAuthentication from skyplane.compute.aws.aws_cloud_provider import AWSCloudProvider from skyplane.obj_store.s3_interface import S3Interface diff --git a/skyplane/cli/cli_azure.py b/skyplane/cli/cli_azure.py index 5a011027f..3548cadf8 100644 --- a/skyplane/cli/cli_azure.py +++ b/skyplane/cli/cli_azure.py @@ -13,7 +13,7 @@ from skyplane.compute.azure.azure_cloud_provider import AzureCloudProvider from skyplane.utils.fn import do_parallel from skyplane.utils import logger -from skyplane import cloud_config +from skyplane.cli import cloud_config from rich import print as rprint from skyplane.compute.azure.azure_auth import AzureAuthentication diff --git a/skyplane/cli/cli_config.py b/skyplane/cli/cli_config.py index 3426fd65d..96e74dfdd 100644 --- a/skyplane/cli/cli_config.py +++ b/skyplane/cli/cli_config.py @@ -10,7 +10,7 @@ import typer -from skyplane import cloud_config, config_path +from skyplane.cli import config_path, cloud_config from skyplane.cli.common import console from skyplane.cli.usage.client import UsageClient diff --git a/skyplane/cli/cli_impl/cp_replicate.py b/skyplane/cli/cli_impl/cp_replicate.py index aa6326698..6b385747e 100644 --- a/skyplane/cli/cli_impl/cp_replicate.py +++ b/skyplane/cli/cli_impl/cp_replicate.py @@ -8,7 +8,9 @@ import typer from rich import print as rprint -from skyplane import exceptions, GB, format_bytes, gateway_docker_image, skyplane_root, cloud_config +from skyplane import exceptions, __root__ +from skyplane.cli import cloud_config +from skyplane.utils.definitions import GB, format_bytes, gateway_docker_image from skyplane.cli.common import console from skyplane.cli.usage.client import UsageClient from skyplane.compute.cloud_providers import CloudProvider @@ -31,7 +33,7 @@ def generate_topology( solver_class: str = "ILP", solver_total_gbyte_to_transfer: Optional[float] = None, solver_target_tput_per_vm_gbits: Optional[float] = None, - solver_throughput_grid: Optional[pathlib.Path] = skyplane_root / "profiles" / "throughput.csv", + solver_throughput_grid: Optional[pathlib.Path] = __root__ / "profiles" / "throughput.csv", solver_verbose: Optional[bool] = False, args: Optional[Dict] = None, ) -> ReplicationTopology: diff --git a/skyplane/cli/cli_impl/init.py b/skyplane/cli/cli_impl/init.py index cdc63df9f..241b8a997 100644 --- a/skyplane/cli/cli_impl/init.py +++ b/skyplane/cli/cli_impl/init.py @@ -10,11 +10,11 @@ from rich.progress import Progress, SpinnerColumn, TextColumn import questionary -from skyplane import SkyplaneConfig, aws_config_path, gcp_config_path -from skyplane.compute.aws.aws_auth import AWSAuthentication +from skyplane.compute.aws.aws_auth import AWSAuthentication, aws_config_path from skyplane.compute.azure.azure_auth import AzureAuthentication from skyplane.compute.azure.azure_server import AzureServer -from skyplane.compute.gcp.gcp_auth import GCPAuthentication +from skyplane.compute.gcp.gcp_auth import GCPAuthentication, gcp_config_path +from skyplane.config import SkyplaneConfig def load_aws_config(config: SkyplaneConfig, non_interactive: bool = False) -> SkyplaneConfig: diff --git a/skyplane/cli/cli_internal.py b/skyplane/cli/cli_internal.py index 1943e3758..cb62b0846 100644 --- a/skyplane/cli/cli_internal.py +++ b/skyplane/cli/cli_internal.py @@ -3,7 +3,7 @@ import typer -from skyplane import skyplane_root +from skyplane import __root__ from skyplane.cli.cli_impl.cp_replicate import confirm_transfer, launch_replication_job from skyplane.cli.common import print_header from skyplane.obj_store.object_store_interface import ObjectStoreObject @@ -91,9 +91,7 @@ def replicate_random_solve( reuse_gateways: bool = False, solve: bool = typer.Option(False, help="If true, will use solver to optimize transfer, else direct path is chosen"), throughput_per_instance_gbits: float = typer.Option(2, help="Solver option: Required throughput in gbps."), - solver_throughput_grid: Path = typer.Option( - skyplane_root / "profiles" / "throughput.csv", "--throughput-grid", help="Throughput grid file" - ), + solver_throughput_grid: Path = typer.Option(__root__ / "profiles" / "throughput.csv", "--throughput-grid", help="Throughput grid file"), solver_verbose: bool = False, debug: bool = False, ): diff --git a/skyplane/cli/experiments/cli_profile.py b/skyplane/cli/experiments/cli_profile.py index 78718a9d7..49ad63ad5 100644 --- a/skyplane/cli/experiments/cli_profile.py +++ b/skyplane/cli/experiments/cli_profile.py @@ -10,7 +10,8 @@ import typer from rich.progress import Progress -from skyplane import GB, skyplane_root +from skyplane import __root__ +from skyplane.utils.definitions import GB from skyplane.cli.experiments.provision import provision from skyplane.compute.aws.aws_cloud_provider import AWSCloudProvider from skyplane.compute.azure.azure_cloud_provider import AzureCloudProvider @@ -240,7 +241,7 @@ def setup(server: Server): experiment_tag_words = os.popen("bash scripts/get_random_word_hash.sh").read().strip() timestamp = datetime.now(timezone.utc).strftime("%Y.%m.%d_%H.%M") experiment_tag = f"{timestamp}_{experiment_tag_words}_{iperf3_runtime}s_{iperf3_connections}c" - data_dir = skyplane_root / "data" + data_dir = __root__ / "data" log_dir = data_dir / "logs" / "throughput_grid" / f"{experiment_tag}" raw_iperf3_log_dir = log_dir / "raw_iperf3_logs" @@ -433,7 +434,7 @@ def setup(server: Server): experiment_tag_words = os.popen("bash scripts/get_random_word_hash.sh").read().strip() timestamp = datetime.now(timezone.utc).strftime("%Y.%m.%d_%H.%M") experiment_tag = f"{timestamp}_{experiment_tag_words}" - data_dir = skyplane_root / "data" + data_dir = __root__ / "data" log_dir = data_dir / "logs" / "latency_grid" / f"{experiment_tag}" # ask for confirmation diff --git a/skyplane/cli/experiments/cli_query.py b/skyplane/cli/experiments/cli_query.py index 4d1707099..bb5d26ef3 100644 --- a/skyplane/cli/experiments/cli_query.py +++ b/skyplane/cli/experiments/cli_query.py @@ -2,7 +2,7 @@ import typer -from skyplane import skyplane_root +from skyplane import __root__ from skyplane.replicate.solver import ThroughputSolver @@ -11,7 +11,7 @@ def util_grid_throughput( dest: str, src_tier: str = "PREMIUM", dest_tier: str = "PREMIUM", - throughput_grid: Path = typer.Option(skyplane_root / "profiles" / "throughput.csv", help="Throughput grid file"), + throughput_grid: Path = typer.Option(__root__ / "profiles" / "throughput.csv", help="Throughput grid file"), ): solver = ThroughputSolver(throughput_grid) print(solver.get_path_throughput(src, dest, src_tier, dest_tier) / 2**30) @@ -22,7 +22,7 @@ def util_grid_cost( dest: str, src_tier: str = "PREMIUM", dest_tier: str = "PREMIUM", - throughput_grid: Path = typer.Option(skyplane_root / "profiles" / "throughput.csv", help="Throughput grid file"), + throughput_grid: Path = typer.Option(__root__ / "profiles" / "throughput.csv", help="Throughput grid file"), ): solver = ThroughputSolver(throughput_grid) print(solver.get_path_cost(src, dest, src_tier, dest_tier)) @@ -42,7 +42,7 @@ def get_max_throughput(region_tag: str): def dump_full_util_cost_grid( - throughput_grid: Path = typer.Option(skyplane_root / "profiles" / "throughput.csv", help="Throughput grid file"), + throughput_grid: Path = typer.Option(__root__ / "profiles" / "throughput.csv", help="Throughput grid file"), ): solver = ThroughputSolver(throughput_grid) regions = solver.get_regions() diff --git a/skyplane/cli/usage/client.py b/skyplane/cli/usage/client.py index 1bbe0888b..af917462a 100644 --- a/skyplane/cli/usage/client.py +++ b/skyplane/cli/usage/client.py @@ -14,7 +14,8 @@ from rich import print as rprint import skyplane.cli.usage.definitions -from skyplane import cloud_config, config_path, tmp_log_dir +from skyplane import tmp_log_dir +from skyplane.cli import config_path, cloud_config from skyplane.config import _map_type from skyplane.replicate.replicator_client import TransferStats from skyplane.utils import logger diff --git a/skyplane/compute/aws/aws_auth.py b/skyplane/compute/aws/aws_auth.py index ad51a6b4e..39240d4a3 100644 --- a/skyplane/compute/aws/aws_auth.py +++ b/skyplane/compute/aws/aws_auth.py @@ -1,10 +1,12 @@ from typing import Optional -from skyplane import aws_config_path -from skyplane import config_path +from skyplane import __config_root__ +from skyplane.cli import config_path from skyplane.config import SkyplaneConfig from skyplane.utils import imports +aws_config_path = __config_root__ / "aws_config" + class AWSAuthentication: def __init__(self, config: Optional[SkyplaneConfig] = None, access_key: Optional[str] = None, secret_key: Optional[str] = None): diff --git a/skyplane/compute/aws/aws_key_manager.py b/skyplane/compute/aws/aws_key_manager.py index f40715bcc..a1a5d4fc6 100644 --- a/skyplane/compute/aws/aws_key_manager.py +++ b/skyplane/compute/aws/aws_key_manager.py @@ -2,7 +2,7 @@ from pathlib import Path from skyplane import exceptions as skyplane_exceptions -from skyplane import key_root +from skyplane.compute.server import key_root from skyplane.compute.aws.aws_auth import AWSAuthentication from skyplane.utils import logger diff --git a/skyplane/compute/aws/aws_pricing.py b/skyplane/compute/aws/aws_pricing.py index ef942999e..22a6e6544 100644 --- a/skyplane/compute/aws/aws_pricing.py +++ b/skyplane/compute/aws/aws_pricing.py @@ -1,4 +1,4 @@ -from skyplane import skyplane_root +from skyplane import __root__ from skyplane.utils import logger try: @@ -16,7 +16,7 @@ def __init__(self): def transfer_df(self): if pd: if not self._transfer_df: - self._transfer_df = pd.read_csv(skyplane_root / "profiles" / "aws_transfer_costs.csv").set_index(["src", "dst"]) + self._transfer_df = pd.read_csv(__root__ / "profiles" / "aws_transfer_costs.csv").set_index(["src", "dst"]) return self._transfer_df else: return None diff --git a/skyplane/compute/aws/aws_server.py b/skyplane/compute/aws/aws_server.py index 92f1ebfd7..503e49555 100644 --- a/skyplane/compute/aws/aws_server.py +++ b/skyplane/compute/aws/aws_server.py @@ -11,9 +11,9 @@ warnings.filterwarnings("ignore", category=CryptographyDeprecationWarning) import paramiko -from skyplane import exceptions, key_root +from skyplane import exceptions from skyplane.compute.aws.aws_auth import AWSAuthentication -from skyplane.compute.server import Server, ServerState +from skyplane.compute.server import Server, ServerState, key_root from skyplane.utils import imports from skyplane.utils.cache import ignore_lru_cache diff --git a/skyplane/compute/azure/azure_auth.py b/skyplane/compute/azure/azure_auth.py index 665be9316..a6331e252 100644 --- a/skyplane/compute/azure/azure_auth.py +++ b/skyplane/compute/azure/azure_auth.py @@ -4,10 +4,9 @@ import subprocess from typing import Dict, List, Optional -from skyplane import azure_config_path -from skyplane import azure_sku_path -from skyplane import config_path -from skyplane import is_gateway_env +from skyplane import __config_root__ +from skyplane.cli import config_path +from skyplane.utils.definitions import is_gateway_env from skyplane.compute.const_cmds import query_which_cloud from skyplane.config import SkyplaneConfig from skyplane.utils import imports @@ -168,3 +167,7 @@ def get_container_client(ContainerClient, self, account_url: str, container_name @imports.inject("azure.storage.blob.BlobServiceClient", pip_extra="azure") def get_blob_service_client(BlobServiceClient, self, account_url: str): return BlobServiceClient(account_url=account_url, credential=self.credential) + + +azure_config_path = __config_root__ / "azure_config" +azure_sku_path = __config_root__ / "azure_sku_mapping" diff --git a/skyplane/compute/azure/azure_cloud_provider.py b/skyplane/compute/azure/azure_cloud_provider.py index c9024dfc4..eb37b7785 100644 --- a/skyplane/compute/azure/azure_cloud_provider.py +++ b/skyplane/compute/azure/azure_cloud_provider.py @@ -12,10 +12,12 @@ warnings.filterwarnings("ignore", category=CryptographyDeprecationWarning) import paramiko -from skyplane import cloud_config, exceptions, key_root +from skyplane import exceptions +from skyplane.cli import cloud_config from skyplane.compute.azure.azure_auth import AzureAuthentication from skyplane.compute.azure.azure_server import AzureServer from skyplane.compute.cloud_providers import CloudProvider +from skyplane.compute.server import key_root from skyplane.utils import logger, imports from skyplane.utils.timer import Timer diff --git a/skyplane/compute/azure/azure_server.py b/skyplane/compute/azure/azure_server.py index c74f3cb86..f3fbace59 100644 --- a/skyplane/compute/azure/azure_server.py +++ b/skyplane/compute/azure/azure_server.py @@ -7,9 +7,9 @@ warnings.filterwarnings("ignore", category=CryptographyDeprecationWarning) import paramiko -from skyplane import exceptions, key_root +from skyplane import exceptions from skyplane.compute.azure.azure_auth import AzureAuthentication -from skyplane.compute.server import Server, ServerState +from skyplane.compute.server import Server, ServerState, key_root from skyplane.utils import imports from skyplane.utils.cache import ignore_lru_cache from skyplane.utils.fn import PathLike diff --git a/skyplane/compute/gcp/gcp_auth.py b/skyplane/compute/gcp/gcp_auth.py index 31c04e20e..a973ffd54 100644 --- a/skyplane/compute/gcp/gcp_auth.py +++ b/skyplane/compute/gcp/gcp_auth.py @@ -3,7 +3,9 @@ from pathlib import Path from typing import Optional -from skyplane import config_path, gcp_config_path, key_root +from skyplane import __config_root__ +from skyplane.cli import config_path +from skyplane.compute.server import key_root from skyplane.config import SkyplaneConfig from skyplane.utils import logger, imports from skyplane.utils.retry import retry_backoff @@ -205,3 +207,6 @@ def get_gcp_instances(self, gcp_region: str): def check_compute_engine_enabled(self): """Check if the GCP compute engine API is enabled""" + + +gcp_config_path = __config_root__ / "gcp_config" diff --git a/skyplane/compute/gcp/gcp_cloud_provider.py b/skyplane/compute/gcp/gcp_cloud_provider.py index fd75a65c4..da47f03da 100644 --- a/skyplane/compute/gcp/gcp_cloud_provider.py +++ b/skyplane/compute/gcp/gcp_cloud_provider.py @@ -13,12 +13,12 @@ warnings.filterwarnings("ignore", category=CryptographyDeprecationWarning) import paramiko -from skyplane import exceptions, key_root +from skyplane import exceptions from skyplane.compute.azure.azure_cloud_provider import AzureCloudProvider from skyplane.compute.cloud_providers import CloudProvider from skyplane.compute.gcp.gcp_auth import GCPAuthentication from skyplane.compute.gcp.gcp_server import GCPServer -from skyplane.compute.server import Server, ServerState +from skyplane.compute.server import Server, ServerState, key_root from skyplane.utils import logger from skyplane.utils.fn import wait_for diff --git a/skyplane/compute/gcp/gcp_server.py b/skyplane/compute/gcp/gcp_server.py index 3c0262bd8..ad0c2ac80 100644 --- a/skyplane/compute/gcp/gcp_server.py +++ b/skyplane/compute/gcp/gcp_server.py @@ -8,9 +8,9 @@ warnings.filterwarnings("ignore", category=CryptographyDeprecationWarning) import paramiko -from skyplane import key_root, exceptions +from skyplane import exceptions from skyplane.compute.gcp.gcp_auth import GCPAuthentication -from skyplane.compute.server import Server, ServerState +from skyplane.compute.server import Server, ServerState, key_root from skyplane.utils.fn import PathLike diff --git a/skyplane/compute/server.py b/skyplane/compute/server.py index fca10362a..548d8ba08 100644 --- a/skyplane/compute/server.py +++ b/skyplane/compute/server.py @@ -10,7 +10,8 @@ import urllib3 -from skyplane import config_path, key_root, cloud_config +from skyplane import __config_root__ +from skyplane.cli import config_path, cloud_config from skyplane.compute.const_cmds import make_autoshutdown_script, make_dozzle_command, make_sysctl_tcp_tuning_command from skyplane.utils import logger from skyplane.utils.fn import PathLike, wait_for @@ -372,3 +373,6 @@ def is_api_ready(): raise e finally: logging.disable(logging.NOTSET) + + +key_root = __config_root__ / "keys" diff --git a/skyplane/gateway/gateway_daemon.py b/skyplane/gateway/gateway_daemon.py index f71e03ff1..962660e56 100644 --- a/skyplane/gateway/gateway_daemon.py +++ b/skyplane/gateway/gateway_daemon.py @@ -12,7 +12,7 @@ from threading import BoundedSemaphore from typing import Dict -from skyplane import MB +from skyplane.utils.definitions import MB from skyplane.chunk import ChunkState from skyplane.gateway.chunk_store import ChunkStore from skyplane.gateway.gateway_daemon_api import GatewayDaemonAPI diff --git a/skyplane/gateway/gateway_obj_store.py b/skyplane/gateway/gateway_obj_store.py index 7742e54da..bd3a18430 100644 --- a/skyplane/gateway/gateway_obj_store.py +++ b/skyplane/gateway/gateway_obj_store.py @@ -6,7 +6,7 @@ from multiprocessing import Event, Manager, Process, Value, Queue from typing import Dict, Optional -from skyplane import cloud_config +from skyplane.cli import cloud_config from skyplane.chunk import ChunkRequest from skyplane.gateway.chunk_store import ChunkStore from skyplane.obj_store.object_store_interface import ObjectStoreInterface diff --git a/skyplane/gateway/gateway_receiver.py b/skyplane/gateway/gateway_receiver.py index 468fb92db..a5f51cfaf 100644 --- a/skyplane/gateway/gateway_receiver.py +++ b/skyplane/gateway/gateway_receiver.py @@ -11,7 +11,7 @@ import lz4.frame import nacl.secret -from skyplane import MB +from skyplane.utils.definitions import MB from skyplane.chunk import WireProtocolHeader from skyplane.gateway.cert import generate_self_signed_certificate from skyplane.gateway.chunk_store import ChunkStore diff --git a/skyplane/gateway/gateway_sender.py b/skyplane/gateway/gateway_sender.py index d5f9c3a2b..cda599f97 100644 --- a/skyplane/gateway/gateway_sender.py +++ b/skyplane/gateway/gateway_sender.py @@ -12,7 +12,7 @@ import nacl.secret import urllib3 -from skyplane import MB +from skyplane.utils.definitions import MB from skyplane.chunk import ChunkRequest from skyplane.gateway.chunk_store import ChunkStore from skyplane.utils import logger diff --git a/skyplane/replicate/replication_plan.py b/skyplane/replicate/replication_plan.py index 39f9e2c26..e9306fa1a 100644 --- a/skyplane/replicate/replication_plan.py +++ b/skyplane/replicate/replication_plan.py @@ -3,7 +3,7 @@ from dataclasses import dataclass from typing import Dict, List, Optional, Set, Tuple -from skyplane import MB +from skyplane.utils.definitions import MB from skyplane.chunk import ChunkRequest from skyplane.obj_store.object_store_interface import ObjectStoreObject from skyplane.utils import logger diff --git a/skyplane/replicate/replicator_client.py b/skyplane/replicate/replicator_client.py index ec01d8da6..5d0cf57a2 100644 --- a/skyplane/replicate/replicator_client.py +++ b/skyplane/replicate/replicator_client.py @@ -15,7 +15,8 @@ import urllib3 from rich.progress import BarColumn, DownloadColumn, Progress, SpinnerColumn, TextColumn, TimeRemainingColumn, TransferSpeedColumn -from skyplane import GB, MB, exceptions, gateway_docker_image, tmp_log_dir +from skyplane import exceptions, tmp_log_dir +from skyplane.utils.definitions import MB, GB, gateway_docker_image from skyplane.chunk import Chunk, ChunkRequest, ChunkState from skyplane.compute.aws.aws_cloud_provider import AWSCloudProvider from skyplane.compute.azure.azure_cloud_provider import AzureCloudProvider diff --git a/skyplane/replicate/solver.py b/skyplane/replicate/solver.py index 1bbc1e0fa..da48e8055 100644 --- a/skyplane/replicate/solver.py +++ b/skyplane/replicate/solver.py @@ -7,7 +7,7 @@ import numpy as np import pandas as pd -from skyplane import GB +from skyplane.utils.definitions import GB from skyplane.compute.cloud_providers import CloudProvider from skyplane.replicate.replication_plan import ReplicationTopology from skyplane.utils import logger diff --git a/skyplane/utils/definitions.py b/skyplane/utils/definitions.py new file mode 100644 index 000000000..f855ec207 --- /dev/null +++ b/skyplane/utils/definitions.py @@ -0,0 +1,25 @@ +import os + +from skyplane.gateway_version import gateway_version + +KB = 1024 +MB = 1024 * 1024 +GB = 1024 * 1024 * 1024 + + +def format_bytes(bytes_int: int): + if bytes_int < KB: + return f"{bytes_int}B" + elif bytes_int < MB: + return f"{bytes_int / KB:.2f}KB" + elif bytes_int < GB: + return f"{bytes_int / MB:.2f}MB" + else: + return f"{bytes_int / GB:.2f}GB" + + +is_gateway_env = os.environ.get("SKYPLANE_IS_GATEWAY", None) == "1" + + +def gateway_docker_image(): + return "public.ecr.aws/s6m1p0n8/skyplane:" + gateway_version diff --git a/skyplane/utils/logger.py b/skyplane/utils/logger.py index 45105110b..a475da0b2 100644 --- a/skyplane/utils/logger.py +++ b/skyplane/utils/logger.py @@ -5,7 +5,7 @@ from rich import print as rprint -from skyplane import is_gateway_env +from skyplane.utils.definitions import is_gateway_env log_file = None diff --git a/tests/integration/cp.py b/tests/integration/cp.py index 3b9126ec4..6543104f6 100644 --- a/tests/integration/cp.py +++ b/tests/integration/cp.py @@ -3,7 +3,7 @@ import tempfile import time import uuid -from skyplane import MB +from skyplane.utils.definitions import MB from skyplane.obj_store.object_store_interface import ObjectStoreInterface from skyplane.cli.cli import cp from skyplane.utils import logger diff --git a/tests/interface_util.py b/tests/interface_util.py index e01f3bf97..f8179933f 100644 --- a/tests/interface_util.py +++ b/tests/interface_util.py @@ -3,7 +3,7 @@ import os import tempfile import uuid -from skyplane import MB +from skyplane.utils.definitions import MB from skyplane.obj_store.object_store_interface import ObjectStoreInterface from skyplane.utils.fn import wait_for