Skip to content

Commit

Permalink
Merge pull request #380 from allenai/shanea/storage-cleaner-download-…
Browse files Browse the repository at this point in the history
…upload

[Storage Cleaner] Add more basic functionality to storage adapters
  • Loading branch information
2015aroras authored Dec 7, 2023
2 parents 4e849e4 + 3f9d55a commit a120ab2
Showing 1 changed file with 108 additions and 8 deletions.
116 changes: 108 additions & 8 deletions scripts/storage_cleaner.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from argparse import ArgumentParser, _SubParsersAction
from dataclasses import dataclass
from enum import Enum, auto
from functools import partial
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
from urllib.parse import urlparse
Expand All @@ -17,6 +18,7 @@
from cached_path import add_scheme_client, cached_path, set_cache_dir
from cached_path.schemes import S3Client
from google.api_core.exceptions import NotFound
from rich.progress import Progress, TaskID, track

from olmo import util
from olmo.aliases import PathOrStr
Expand Down Expand Up @@ -44,7 +46,7 @@ def list_entries(self, directory: str, max_file_size: Optional[int] = None) -> L
"""Lists all the entries within the given directory.
Returns only top-level entries (i.e. not entries in subdirectories).
max_file_size sets a threshold (in bytes) for the largest size file to retain within entries.
`max_file_size`: Sets a threshold (in bytes) for the largest size file to retain within entries.
Any file of larger size is not included in the returned results.
"""

Expand Down Expand Up @@ -74,6 +76,14 @@ def get_file_size(self, path: str) -> int:
def is_dir(self, path: str) -> bool:
"""Returns whether the given path corresponds to an existing directory."""

@abstractmethod
def download_folder(self, directory_path: str, local_dest_folder: PathOrStr):
    """Download the content found at the directory path into a folder on the local FS.

    `directory_path`: The directory whose content is downloaded.
    `local_dest_folder`: The local FS folder that receives the content.
    """

@abstractmethod
def upload(self, local_src: PathOrStr, dest_path: str):
    """Upload the content of a local FS file or directory to the given path.

    `local_src`: The file or directory on the local FS to upload.
    `dest_path`: The path that receives the content.
    """

@classmethod
def create_storage_adapter(cls, storage_type: StorageType):
if storage_type == StorageType.LOCAL_FS:
Expand Down Expand Up @@ -109,19 +119,27 @@ class LocalFileSystemAdapter(StorageAdapter):
def __init__(self) -> None:
    """Set up the temp file/dir bookkeeping and the list of known archive extensions."""
    super().__init__()

    # The handles are kept so the temp files/dirs can be released later in __del__.
    self._temp_files: List[tempfile._TemporaryFileWrapper[bytes]] = []
    self._temp_dirs: List[tempfile.TemporaryDirectory] = []

    # Collect, lowercased, every extension that shutil knows how to unpack.
    known_extensions: List[str] = []
    for _format_name, format_extensions, _description in shutil.get_unpack_formats():
        known_extensions.extend(extension.lower() for extension in format_extensions)
    self._archive_extensions: List[str] = known_extensions

def __del__(self):
    """Close the tracked temp files and clean up the tracked temp directories."""
    for file_handle in self._temp_files:
        file_handle.close()

    for dir_handle in self._temp_dirs:
        dir_handle.cleanup()

def create_temp_file(self, suffix: Optional[str] = None) -> str:
    """Create a named temporary file and return its path.

    `suffix`: Optional suffix for the temporary file's name.

    The file handle is stored on the adapter, which keeps the file open until it is closed.
    """
    handle = tempfile.NamedTemporaryFile(suffix=suffix)
    self._temp_files.append(handle)

    path_on_disk = handle.name
    return path_on_disk

def create_temp_dir(self, suffix: Optional[str] = None) -> str:
    """Create a temporary directory and return its path.

    `suffix`: Optional suffix for the temporary directory's name.

    The directory handle is stored on the adapter so it can be cleaned up later.
    """
    handle = tempfile.TemporaryDirectory(suffix=suffix)
    self._temp_dirs.append(handle)

    path_on_disk = handle.name
    return path_on_disk

def has_supported_archive_extension(self, path: PathOrStr) -> bool:
    """Return whether the final component of `path` ends with one of the adapter's archive extensions.

    The file name is lowercased before it is compared against the stored extensions.
    """
    lowered_name = Path(path).name.lower()
    # str.endswith accepts a tuple of candidates, so one call covers every extension.
    return lowered_name.endswith(tuple(self._archive_extensions))
Expand Down Expand Up @@ -177,6 +195,25 @@ def is_dir(self, path: str) -> bool:

return path_obj.is_dir()

def download_folder(self, directory_path: str, local_dest_folder: PathOrStr):
    """Copy the directory tree at `directory_path` into `local_dest_folder`.

    `directory_path`: Existing local directory to copy from.
    `local_dest_folder`: Local folder to copy into; it may already exist.

    Raises ValueError if nothing exists at `directory_path` and RuntimeError if the
    entry there is not a directory.
    """
    source = Path(directory_path)

    if not source.exists():
        raise ValueError(f"No entry exists at path {directory_path}")
    if not source.is_dir():
        raise RuntimeError(f"Unexpected type of path {directory_path}")

    shutil.copytree(directory_path, str(local_dest_folder), dirs_exist_ok=True)

def upload(self, local_src: PathOrStr, dest_path: str):
    """Copy the local file or directory at `local_src` to `dest_path`.

    `local_src`: Local file or directory to copy.
    `dest_path`: Local path that receives the copy.

    Raises RuntimeError if `local_src` is neither a file nor a directory.
    """
    source = Path(local_src)

    if source.is_file():
        shutil.copy(str(source), dest_path)
        return

    if source.is_dir():
        # For the local FS, a directory upload is the same tree copy as a download.
        self.download_folder(str(local_src), dest_path)
        return

    raise RuntimeError(f"Unexpected type of local src path {local_src}")


class GoogleCloudStorageAdapter(StorageAdapter):
def __init__(self) -> None:
Expand Down Expand Up @@ -271,11 +308,11 @@ def _get_directory_entries(
def _list_entries(
self, directory: str, include_files: bool = True, max_file_size: Optional[int] = None
) -> List[str]:
if not self.is_dir(directory):
raise ValueError(f"{directory} is not an existing directory")

bucket_name, key = self._get_bucket_name_and_key(directory)

if not self._is_dir(bucket_name, key):
raise ValueError(f"{directory} is not an existing directory")

res = self._get_directory_entries(
bucket_name, key, include_files=include_files, max_file_size=max_file_size
)
Expand Down Expand Up @@ -306,14 +343,38 @@ def get_file_size(self, path: str) -> int:

return self._get_size(bucket_name, key)

def is_dir(self, path: str) -> bool:
path = f"{path}/" if not path.endswith("/") else path
bucket_name, key = self._get_bucket_name_and_key(path)
def _is_dir(self, bucket_name: str, key: str) -> bool:
    """Return whether `key` in bucket `bucket_name` corresponds to an existing directory.

    The key is treated as a directory when at least one blob is listed under the
    "/"-terminated key and the "/"-terminated key is not itself a file.
    """
    dir_prefix = key if key.endswith("/") else f"{key}/"

    # Listing a single blob is enough to know whether anything lives under the prefix.
    matches = list(self.gcs_client.bucket(bucket_name).list_blobs(prefix=dir_prefix, max_results=1))

    if self._is_file(bucket_name, dir_prefix):
        return False
    return len(matches) > 0

def is_dir(self, path: str) -> bool:
    """Return whether `path` corresponds to an existing directory.

    The path is split into a bucket name and key, and the check is done on those.
    """
    bucket, blob_key = self._get_bucket_name_and_key(path)
    is_directory = self._is_dir(bucket, blob_key)
    return is_directory

def download_folder(self, directory_path: str, local_dest_folder: PathOrStr):
    """Download every blob under `directory_path` into `local_dest_folder`.

    `directory_path`: The GCS directory whose content is downloaded.
    `local_dest_folder`: The local FS folder that receives the blobs. The layout
    below `directory_path` is recreated inside it, and missing local folders are created.

    Raises ValueError if `directory_path` is not an existing directory.
    """
    bucket_name, key = self._get_bucket_name_and_key(directory_path)

    if not self._is_dir(bucket_name, key):
        raise ValueError(f"Path {directory_path} is not a valid directory")

    # Terminate the prefix with "/" so entries that merely share leading characters
    # with the directory name (e.g. "run10/..." when downloading "run1") are not listed.
    prefix = key if key.endswith("/") else f"{key}/"

    bucket = self.gcs_client.bucket(bucket_name)
    blobs: List[gcs.Blob] = list(bucket.list_blobs(prefix=prefix))
    dest_root = Path(local_dest_folder)

    for blob in track(blobs, description=f"Downloading files at {directory_path}"):
        if not blob.name:
            raise NotImplementedError()
        blob_path: str = blob.name

        # Strip only the leading prefix. A plain str.replace would also rewrite any later
        # occurrence of the directory name inside the blob path.
        relative_path = blob_path[len(prefix):]
        if not relative_path or relative_path.endswith("/"):
            # A name ending in "/" has no file name to write to locally.
            continue

        blob_local_dest = dest_root / relative_path
        # The download call writes to an existing folder only, so create the parents first.
        blob_local_dest.parent.mkdir(parents=True, exist_ok=True)
        blob.download_to_filename(str(blob_local_dest))

def upload(self, local_src: PathOrStr, dest_path: str):
    """Not supported by this adapter yet; always raises NotImplementedError."""
    raise NotImplementedError()


class S3StorageAdapter(StorageAdapter):
def __init__(self, storage_type: StorageType):
Expand Down Expand Up @@ -473,18 +534,57 @@ def get_file_size(self, path: str) -> int:
return self._get_size(bucket_name, key)

def _is_dir(self, bucket_name: str, key: str) -> bool:
    """Return whether `key` in bucket `bucket_name` corresponds to an existing directory.

    The key is treated as a directory when the "/"-terminated key is not a file and
    at least one object is listed under it.
    """
    dir_prefix = key if key.endswith("/") else f"{key}/"

    if self._is_file(bucket_name, dir_prefix):
        return False

    # One listed object is enough to know that something lives under the prefix.
    listing = self._s3_client.list_objects_v2(Bucket=bucket_name, Prefix=dir_prefix, MaxKeys=1)
    return "Contents" in listing

def is_dir(self, path: str) -> bool:
    """Return whether `path` corresponds to an existing directory.

    The path is "/"-terminated, split into a bucket name and key, and checked on those.
    """
    dir_path = path if path.endswith("/") else f"{path}/"
    bucket, object_key = self._get_bucket_name_and_key(dir_path)

    is_directory = self._is_dir(bucket, object_key)
    return is_directory

def download_folder(self, directory_path: str, local_dest_folder: PathOrStr):
    """Download every object under `directory_path` into `local_dest_folder`.

    `directory_path`: The directory whose content is downloaded.
    `local_dest_folder`: The local FS folder that receives the objects. The layout
    below `directory_path` is recreated inside it, and missing local folders are created.

    Raises ValueError if `directory_path` is not an existing directory.
    """
    bucket_name, key = self._get_bucket_name_and_key(directory_path)

    if not self._is_dir(bucket_name, key):
        raise ValueError(f"Path {directory_path} is not a valid directory")

    # Terminate the prefix with "/" so entries that merely share leading characters
    # with the directory name (e.g. "run10/..." when downloading "run1") are not listed.
    prefix = key if key.endswith("/") else f"{key}/"

    # A single listing response holds a limited number of keys, so keep requesting
    # pages with the continuation token until the listing is no longer truncated.
    objects_metadata: List[Dict[str, Any]] = []
    list_kwargs: Dict[str, Any] = {"Bucket": bucket_name, "Prefix": prefix}
    while True:
        response = self._s3_client.list_objects_v2(**list_kwargs)
        objects_metadata.extend(response.get("Contents", []))
        if not response.get("IsTruncated"):
            break
        list_kwargs["ContinuationToken"] = response["NextContinuationToken"]

    dest_root = Path(local_dest_folder)
    for object_metadata in track(objects_metadata, description=f"Downloading files at {directory_path}"):
        object_key: str = object_metadata["Key"]

        # Strip only the leading prefix. A plain str.replace would also rewrite any later
        # occurrence of the directory name inside the object key.
        relative_path = object_key[len(prefix):]
        if not relative_path or relative_path.endswith("/"):
            # A key ending in "/" has no file name to write to locally.
            continue

        object_local_dest = dest_root / relative_path
        object_local_dest.parent.mkdir(parents=True, exist_ok=True)

        # Download the listed object itself; passing the directory key here would
        # request the folder instead of the file.
        self._s3_client.download_file(bucket_name, object_key, str(object_local_dest))

def upload(self, local_src: PathOrStr, dest_path: str):
    """Upload the local file or directory at `local_src` to `dest_path`.

    `local_src`: Local file or directory to upload. For a directory, every file
    below it is uploaded, keeping its path relative to `local_src`.
    `dest_path`: Destination path; for a directory upload it is the destination folder.

    Raises ValueError if `local_src` is neither a file nor a directory.
    """
    local_src_str = str(local_src)

    if self.local_fs_adapter.is_file(local_src_str):
        bucket_name, key = self._get_bucket_name_and_key(dest_path)
        self._s3_client.upload_file(local_src_str, bucket_name, key)
        return

    if not self.local_fs_adapter.is_dir(local_src_str):
        raise ValueError(f"Local source {local_src} does not correspond to a valid file or directory")

    src_root = Path(local_src)
    # Avoid a doubled "/" in the object keys when the destination already ends with one.
    dest_root = dest_path.rstrip("/")

    def upload_callback(progress: Progress, upload_task: TaskID, bytes_uploaded: int):
        progress.update(upload_task, advance=bytes_uploaded)

    for file_local_path in src_root.rglob("*"):
        # rglob also yields the sub-directories themselves; only files can be uploaded.
        if not file_local_path.is_file():
            continue

        # Build the destination from the path relative to the source root, with "/" separators.
        relative_path = file_local_path.relative_to(src_root).as_posix()
        file_dest_path = f"{dest_root}/{relative_path}"
        bucket_name, key = self._get_bucket_name_and_key(file_dest_path)

        with Progress(transient=True) as progress:
            size_in_bytes = file_local_path.stat().st_size
            upload_task = progress.add_task(f"Uploading {key}", total=size_in_bytes)
            callback = partial(upload_callback, progress, upload_task)

            self._s3_client.upload_file(str(file_local_path), bucket_name, key, Callback=callback)


@dataclass
class DeleteBadRunsConfig:
Expand Down

0 comments on commit a120ab2

Please sign in to comment.