diff --git a/skylark/cli/cli.py b/skylark/cli/cli.py
index 8507ed6ce..57e477879 100644
--- a/skylark/cli/cli.py
+++ b/skylark/cli/cli.py
@@ -84,6 +84,10 @@ def cp(src: str, dst: str):
         copy_local_gcs(Path(path_src), bucket_dst, path_dst)
     elif provider_src == "gcs" and provider_dst == "local":
         copy_gcs_local(bucket_src, path_src, Path(path_dst))
+    elif provider_src == "local" and provider_dst == "azure":
+        copy_local_azure(Path(path_src), bucket_dst, path_dst)
+    elif provider_src == "azure" and provider_dst == "local":
+        copy_azure_local(bucket_src, path_src, Path(path_dst))
     else:
         raise NotImplementedError(f"{provider_src} to {provider_dst} not supported yet")

diff --git a/skylark/cli/cli_helper.py b/skylark/cli/cli_helper.py
index d21f01152..efd09ae0b 100644
--- a/skylark/cli/cli_helper.py
+++ b/skylark/cli/cli_helper.py
@@ -93,6 +93,14 @@ def copy_gcs_local(src_bucket: str, src_key: str, dst: Path):
     raise NotImplementedError(f"GCS not yet supported")


+def copy_local_azure(src: Path, dst_bucket: str, dst_key: str):
+    raise NotImplementedError("Azure not yet supported")
+
+
+def copy_azure_local(src_bucket: str, src_key: str, dst: Path):
+    raise NotImplementedError("Azure not yet supported")
+
+
 def copy_local_s3(src: Path, dst_bucket: str, dst_key: str, use_tls: bool = True):
     s3 = S3Interface(None, dst_bucket, use_tls=use_tls)
     ops: List[concurrent.futures.Future] = []
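Both new helpers are stubs for now. Once Azure copies are wired up, they could delegate to the AzureInterface introduced below, mirroring how copy_local_s3 delegates to S3Interface. A minimal sketch under that assumption (the hard-coded region is a placeholder; AzureInterface reads AZURE_STORAGE_CONNECTION_STRING from the environment itself):

from pathlib import Path

from skylark.obj_store.azure_interface import AzureInterface


def copy_local_azure(src: Path, dst_bucket: str, dst_key: str):
    # Sketch only: create the container if needed, then block on the upload future.
    azure = AzureInterface("eastus", dst_bucket)  # placeholder region
    if not azure.container_exists():
        azure.create_container()
    azure.upload_object(src, dst_key).result()


def copy_azure_local(src_bucket: str, src_key: str, dst: Path):
    # Sketch only: block on the download future returned by AzureInterface.
    azure = AzureInterface("eastus", src_bucket)  # placeholder region
    azure.download_object(src_key, dst).result()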
diff --git a/skylark/obj_store/azure_interface.py b/skylark/obj_store/azure_interface.py
new file mode 100644
index 000000000..61872037d
--- /dev/null
+++ b/skylark/obj_store/azure_interface.py
@@ -0,0 +1,146 @@
+import os
+
+import typer
+from concurrent.futures import Future, ThreadPoolExecutor
+from typing import Iterator, List
+
+from azure.storage.blob import BlobServiceClient
+from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError
+
+from skylark.obj_store.object_store_interface import NoSuchObjectException, ObjectStoreInterface, ObjectStoreObject
+
+
+class AzureObject(ObjectStoreObject):
+    def full_path(self):
+        raise NotImplementedError()
+
+
+class AzureInterface(ObjectStoreInterface):
+    def __init__(self, azure_region, container_name):
+        # TODO: the azure region should come from a corresponding os.getenv()
+        self.azure_region = azure_region
+
+        self.container_name = container_name
+        self.bucket_name = self.container_name  # for compatibility with the other interfaces
+        self.pending_downloads, self.completed_downloads = 0, 0
+        self.pending_uploads, self.completed_uploads = 0, 0
+
+        # The storage connection string is read from the AZURE_STORAGE_CONNECTION_STRING
+        # environment variable; if it is set after the application is launched, the shell
+        # or application must be restarted for it to be picked up.
+        self._connect_str = os.getenv("AZURE_STORAGE_CONNECTION_STRING")
+        # Create the BlobServiceClient object, which is used to create container clients
+        self.blob_service_client = BlobServiceClient.from_connection_string(self._connect_str)
+
+        self.container_client = None
+
+        # TODO: figure out the right pool size, since Azure defaults to 15 workers
+        self.pool = ThreadPoolExecutor(max_workers=1)
+        self.max_concurrency = 24
+
+    def _on_done_download(self, **kwargs):
+        self.completed_downloads += 1
+        self.pending_downloads -= 1
+
+    def _on_done_upload(self, **kwargs):
+        self.completed_uploads += 1
+        self.pending_uploads -= 1
+
+    def container_exists(self):
+        # Get a client to interact with the container, which may not exist yet
+        if self.container_client is None:
+            self.container_client = self.blob_service_client.get_container_client(self.container_name)
+        try:
+            self.container_client.get_container_properties()
+            return True
+        except ResourceNotFoundError:
+            return False
+
+    def create_container(self):
+        try:
+            self.container_client = self.blob_service_client.create_container(self.container_name)
+            self.properties = self.container_client.get_container_properties()
+        except ResourceExistsError:
+            typer.secho("Container already exists. Exiting")
+            exit(-1)
+
+    def create_bucket(self):
+        return self.create_container()
+
+    def delete_container(self):
+        if self.container_client is None:
+            self.container_client = self.blob_service_client.get_container_client(self.container_name)
+        try:
+            self.container_client.delete_container()
+        except ResourceNotFoundError:
+            typer.secho("Container doesn't exist. Unable to delete")
+
+    def delete_bucket(self):
+        return self.delete_container()
+
+    def list_objects(self, prefix="") -> Iterator[AzureObject]:
+        if self.container_client is None:
+            self.container_client = self.blob_service_client.get_container_client(self.container_name)
+        blobs = self.container_client.list_blobs()
+        for blob in blobs:
+            yield AzureObject("azure", blob.container, blob.name, blob.size, blob.last_modified)
+
+    def delete_objects(self, keys: List[str]):
+        for key in keys:
+            blob_client = self.blob_service_client.get_blob_client(container=self.container_name, blob=key)
+            blob_client.delete_blob()
+
+    def get_obj_metadata(self, obj_name):
+        blob_client = self.blob_service_client.get_blob_client(container=self.container_name, blob=obj_name)
+        try:
+            return blob_client.get_blob_properties()
+        except ResourceNotFoundError:
+            raise NoSuchObjectException(f"No blob found with name {obj_name}")
+
+    def get_obj_size(self, obj_name):
+        return self.get_obj_metadata(obj_name).size
+
+    def exists(self, obj_name):
+        blob_client = self.blob_service_client.get_blob_client(container=self.container_name, blob=obj_name)
+        try:
+            blob_client.get_blob_properties()
+            return True
+        except ResourceNotFoundError:
+            return False
+
+    def download_object(self, src_object_name, dst_file_path) -> Future:
+        src_object_name, dst_file_path = str(src_object_name), str(dst_file_path)
+        src_object_name = src_object_name if src_object_name[0] != "/" else src_object_name[1:]
+
+        def _download_object_helper(offset, **kwargs):
+            blob_client = self.blob_service_client.get_blob_client(container=self.container_name, blob=src_object_name)
+            # Download the whole blob and write it to dst_file_path; to bound memory use,
+            # blob_client.download_blob().chunks() could be used to stream the data instead.
+            with open(dst_file_path, "wb") as download_file:
+                download_file.write(blob_client.download_blob(max_concurrency=self.max_concurrency).readall())
+
+        return self.pool.submit(_download_object_helper, 0)
+
+    def upload_object(self, src_file_path, dst_object_name, content_type="infer") -> Future:
+        src_file_path, dst_object_name = str(src_file_path), str(dst_object_name)
+        dst_object_name = dst_object_name if dst_object_name[0] != "/" else dst_object_name[1:]
+
+        def _upload_object_helper():
+            blob_client = self.blob_service_client.get_blob_client(container=self.container_name, blob=dst_object_name)
+            with open(src_file_path, "rb") as data:
+                blob_client.upload_blob(data)
+            return True
+
+        return self.pool.submit(_upload_object_helper)
diff --git a/skylark/obj_store/object_store_interface.py b/skylark/obj_store/object_store_interface.py
index 2c42e1d7f..7f2132f6b 100644
--- a/skylark/obj_store/object_store_interface.py
+++ b/skylark/obj_store/object_store_interface.py
@@ -22,6 +22,9 @@ def bucket_exists(self):
     def create_bucket(self):
         raise NotImplementedError
 
+    def delete_bucket(self):
+        raise NotImplementedError
+
     def list_objects(self, prefix=""):
         raise NotImplementedError
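For reference, a minimal end-to-end sketch of driving the new interface (the region, container name, and file paths below are placeholders, and AZURE_STORAGE_CONNECTION_STRING must be set in the environment):

from skylark.obj_store.azure_interface import AzureInterface

azure = AzureInterface("eastus", "skylark-demo-container")  # placeholder region and container
if not azure.container_exists():
    azure.create_container()

azure.upload_object("/tmp/demo.txt", "demo.txt").result()  # upload_object returns a Future
azure.download_object("demo.txt", "/tmp/demo_copy.txt").result()
print(azure.get_obj_size("demo.txt"))

azure.delete_objects(["demo.txt"])
azure.delete_container()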
diff --git a/skylark/test/test_azure.py b/skylark/test/test_azure.py
new file mode 100644
index 000000000..1b06b56af
--- /dev/null
+++ b/skylark/test/test_azure.py
@@ -0,0 +1,88 @@
+import os, uuid, time
+from azure.storage.blob import BlobServiceClient, __version__
+
+try:
+    print("Azure Blob Storage v" + __version__)
+    # The storage connection string is read from the AZURE_STORAGE_CONNECTION_STRING
+    # environment variable; if it is set after the application is launched, the shell
+    # or application must be restarted for it to be picked up.
+    connect_str = os.getenv("AZURE_STORAGE_CONNECTION_STRING")
+    # Create the BlobServiceClient object, which is used to create container clients
+    blob_service_client = BlobServiceClient.from_connection_string(connect_str)
+
+    # Create a unique name for the container
+    container_name = str(uuid.uuid4())
+    print(f"Creating container: {container_name}")
+
+    # Create the container
+    container_client = blob_service_client.create_container(container_name)
+
+    # Local directory that holds the blob data
+    local_path = "./data"
+
+    # Existing file in the local data directory to upload and download
+    local_file_name = "demo.txt"
+
+    upload_file_path = os.path.join(local_path, local_file_name)
+    print("\nFile Size (MB):", os.path.getsize(upload_file_path) / (1024 * 1024))
+
+    # Create a blob client using the local file name as the name for the blob
+    blob_client = blob_service_client.get_blob_client(container=container_name, blob=local_file_name)
+
+    print("\nUploading to Azure Storage as blob:\n\t" + local_file_name)
+
+    # Upload the file
+    upload_start_time = time.time()
+    with open(upload_file_path, "rb") as data:
+        blob_client.upload_blob(data)
+    print("\nTime to upload from filesystem (s):", time.time() - upload_start_time)
+
+    print("\nListing blobs...")
+
+    # List the blobs in the container
+    blob_list = container_client.list_blobs()
+    for blob in blob_list:
+        print("\t" + blob.name)
+
+    # Download the blob to a local file using readall()
+    # Add 'DOWNLOAD' before the .txt extension so both files are visible in the data directory
+    download_file_path = os.path.join(local_path, str.replace(local_file_name, ".txt", "DOWNLOAD_READ_ALL.txt"))
+    print("\nDownloading blob to \n\t" + download_file_path)
+
+    download_start_time = time.time()
+    with open(download_file_path, "wb") as download_file:
+        download_file.write(blob_client.download_blob(max_concurrency=24).readall())
+    print("\nTime to download and write to file (s):", time.time() - download_start_time)
+
+    # Download the blob again, this time using chunks()
+    stream = blob_client.download_blob()
+    block_list = []
+
+    download_chunk_start_time = time.time()
+    # Read the data in chunks to avoid loading it all into memory at once
+    for chunk in stream.chunks():
+        # process the data here (each `chunk` is a 4 MB byte string)
+        # print(chunk.decode())
+        # block_id = str(uuid.uuid4())
+        # blob_client.stage_block(block_id=block_id, data=chunk)
+        block_list.append([chunk])
+
+    print("\nTime to download as chunks (s):", time.time() - download_chunk_start_time)
+
+    # Clean up
+    print("\nPress the Enter key to begin clean up")
+    input()
+
+    print("Deleting blob container...")
+    container_client.delete_container()
+
+    print("Deleting the locally downloaded file...")
+    os.remove(download_file_path)
+
+    print("Done")
+
+except Exception as ex:
+    print("Exception:")
+    print(ex)
diff --git a/skylark/test/test_azure_interface.py b/skylark/test/test_azure_interface.py
new file mode 100644
index 000000000..67f7188f3
--- /dev/null
+++ b/skylark/test/test_azure_interface.py
@@ -0,0 +1,53 @@
+import hashlib
+import os
+import tempfile
+
+from skylark import MB
+from skylark.obj_store.azure_interface import AzureInterface
+from skylark.utils.utils import Timer
+
+
+def test_azure_interface():
+    azure_interface = AzureInterface("us-east1", "sky-us-east-2")
+    assert azure_interface.bucket_name == "sky-us-east-2"
+    assert azure_interface.azure_region == "us-east1"
+    azure_interface.create_bucket()
+
+    # generate file and upload
+    obj_name = "test.txt"
+    file_size_mb = 128
+    with tempfile.NamedTemporaryFile() as tmp:
+        fpath = tmp.name
+        with open(fpath, "wb") as f:
+            f.write(os.urandom(int(file_size_mb * MB)))
+        file_md5 = hashlib.md5(open(fpath, "rb").read()).hexdigest()
+
+        with Timer() as t:
+            upload_future = azure_interface.upload_object(fpath, obj_name)
+            upload_future.result()
+        assert azure_interface.get_obj_size(obj_name) == os.path.getsize(fpath)
+        assert azure_interface.exists(obj_name)
+        assert not azure_interface.exists("random_nonexistent_file")
+
+    # download object
+    with tempfile.NamedTemporaryFile() as tmp:
+        fpath = tmp.name
+        if os.path.exists(fpath):
+            os.remove(fpath)
+        with Timer() as t:
+            download_future = azure_interface.download_object(obj_name, fpath)
+            download_future.result()
+
+        # check md5
+        dl_file_md5 = hashlib.md5(open(fpath, "rb").read()).hexdigest()
+        assert dl_file_md5 == file_md5
+
+    # clean up Azure
+    azure_interface.delete_objects([obj_name])
+    assert not azure_interface.exists(obj_name)
+    azure_interface.delete_bucket()
+    assert not azure_interface.container_exists()
+
+
+if __name__ == "__main__":
+    test_azure_interface()
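Both test modules above run against a live storage account and assume AZURE_STORAGE_CONNECTION_STRING is set; a small pytest guard along these lines (a sketch, not part of this PR) could skip them cleanly when it is not:

import os

import pytest

# Hypothetical marker for skipping the live-Azure tests when no credentials are configured.
requires_azure = pytest.mark.skipif(
    os.getenv("AZURE_STORAGE_CONNECTION_STRING") is None,
    reason="AZURE_STORAGE_CONNECTION_STRING is not set",
)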
diff --git a/skylark/test/test_replicator_client.py b/skylark/test/test_replicator_client.py
index 3c15c2e8f..6b5334406 100644
--- a/skylark/test/test_replicator_client.py
+++ b/skylark/test/test_replicator_client.py
@@ -8,6 +8,7 @@
 import os
 from skylark.obj_store.s3_interface import S3Interface
 from skylark.obj_store.gcs_interface import GCSInterface
+from skylark.obj_store.azure_interface import AzureInterface
 import tempfile
 import concurrent
 
@@ -69,6 +70,8 @@ def main(args):
         obj_store_interface_src = S3Interface(args.src_region.split(":")[1], src_bucket)
     elif "gcp" in args.src_region:
         obj_store_interface_src = GCSInterface(args.src_region.split(":")[1][:-2], src_bucket)
+    elif "azure" in args.src_region:
+        obj_store_interface_src = AzureInterface(args.src_region.split(":")[1][:-2], src_bucket)
     else:
         raise ValueError(f"No region in source region {args.src_region}")
 
@@ -76,6 +79,8 @@
         obj_store_interface_dst = S3Interface(args.dest_region.split(":")[1], dst_bucket)
     elif "gcp" in args.dest_region:
         obj_store_interface_dst = GCSInterface(args.dest_region.split(":")[1][:-2], dst_bucket)
+    elif "azure" in args.dest_region:
+        obj_store_interface_dst = AzureInterface(args.dest_region.split(":")[1][:-2], dst_bucket)
     else:
         raise ValueError(f"No region in destination region {args.dst_region}")