This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Allow google cloud storage locations for cached_path #5173

Merged 5 commits on May 3, 2021
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ### Changed
 
 - Use `dist_reduce_sum` in distributed metrics.
+- Allow Google Cloud Storage paths in `cached_path` ("gs://...").
 
 ### Added
 
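In practice, the new scheme behaves like the existing s3:// support. A minimal usage sketch (the bucket and object names below are hypothetical, and `storage.Client()` must be able to find Google Cloud credentials in the environment):

    from allennlp.common.file_utils import cached_path

    # The first call downloads the object and stores it in the local cache;
    # subsequent calls return the path of the cached copy.
    local_path = cached_path("gs://my-bucket/vocab.txt")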
73 changes: 65 additions & 8 deletions allennlp/common/file_utils.py
@@ -45,6 +45,8 @@
 import botocore
 import torch
 from filelock import FileLock as _FileLock
+from google.cloud import storage
+from google.api_core.exceptions import NotFound
 import numpy as np
 from overrides import overrides
 import requests
@@ -211,7 +213,7 @@ def cached_path(
     then return the path to the cached file. If it's already a local path,
     make sure the file exists and return the path.
 
-    For URLs, "http://", "https://", "s3://", and "hf://" are all supported.
+    For URLs, "http://", "https://", "s3://", "gs://", and "hf://" are all supported.
     The latter corresponds to the HuggingFace Hub.
 
     For example, to download the PyTorch weights for the model `epwalsh/bert-xsmall-dummy`
@@ -281,7 +283,7 @@ def cached_path(
 
     parsed = urlparse(url_or_filename)
 
-    if parsed.scheme in ("http", "https", "s3", "hf"):
+    if parsed.scheme in ("http", "https", "s3", "hf", "gs"):
         # URL, so get it from the cache (downloading if necessary)
         file_path = get_from_cache(url_or_filename, cache_dir)
 
@@ -373,20 +375,28 @@ def is_url_or_existing_file(url_or_filename: Union[str, Path, None]) -> bool:
         return False
     url_or_filename = os.path.expanduser(str(url_or_filename))
     parsed = urlparse(url_or_filename)
-    return parsed.scheme in ("http", "https", "s3") or os.path.exists(url_or_filename)
+    return parsed.scheme in ("http", "https", "s3", "gs") or os.path.exists(url_or_filename)
 
 
 def _split_s3_path(url: str) -> Tuple[str, str]:
+    return _split_cloud_path(url, "s3")
+
+
+def _split_gcs_path(url: str) -> Tuple[str, str]:
+    return _split_cloud_path(url, "gs")
+
+
+def _split_cloud_path(url: str, provider: str) -> Tuple[str, str]:
     """Split a full s3 path into the bucket name and path."""
     parsed = urlparse(url)
     if not parsed.netloc or not parsed.path:
-        raise ValueError("bad s3 path {}".format(url))
+        raise ValueError("bad {} path {}".format(provider, url))
     bucket_name = parsed.netloc
-    s3_path = parsed.path
+    provider_path = parsed.path
     # Remove '/' at beginning of path.
-    if s3_path.startswith("/"):
-        s3_path = s3_path[1:]
-    return bucket_name, s3_path
+    if provider_path.startswith("/"):
+        provider_path = provider_path[1:]
+    return bucket_name, provider_path
 
 
 def _s3_request(func: Callable):
@@ -437,6 +447,49 @@ def _s3_get(url: str, temp_file: IO) -> None:
     s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)
 
 
+def _gcs_request(func: Callable):
+    """
+    Wrapper function for gcs requests in order to create more helpful error
+    messages.
+    """
+
+    @wraps(func)
+    def wrapper(url: str, *args, **kwargs):
+        try:
+            return func(url, *args, **kwargs)
+        except NotFound:
+            raise FileNotFoundError("file {} not found".format(url))
+
+    return wrapper
+
+
+def _get_gcs_client():
+    storage_client = storage.Client()
+    return storage_client
+
+
+def _get_gcs_blob(url: str) -> storage.blob.Blob:
+    gcs_resource = _get_gcs_client()
+    bucket_name, gcs_path = _split_gcs_path(url)
+    bucket = gcs_resource.bucket(bucket_name)
+    blob = bucket.blob(gcs_path)
+    return blob
+
+
+@_gcs_request
+def _gcs_md5(url: str) -> Optional[str]:
+    """Get GCS object's md5."""
+    blob = _get_gcs_blob(url)
+    return blob.md5_hash
+
+
+@_gcs_request
+def _gcs_get(url: str, temp_filename: str) -> None:
+    """Pull a file directly from GCS."""
+    blob = _get_gcs_blob(url)
+    blob.download_to_filename(temp_filename)
+
+
 def _session_with_backoff() -> requests.Session:
     """
     We ran into an issue where http requests to s3 were timing out,
@@ -923,6 +976,8 @@ def get_from_cache(url: str, cache_dir: Union[str, Path] = None) -> str:
     try:
         if url.startswith("s3://"):
             etag = _s3_etag(url)
+        elif url.startswith("gs://"):
+            etag = _gcs_md5(url)
         else:
             etag = _http_etag(url)
     except (requests.exceptions.ConnectionError, botocore.exceptions.EndpointConnectionError):
@@ -977,6 +1032,8 @@ def get_from_cache(url: str, cache_dir: Union[str, Path] = None) -> str:
         # GET file object
         if url.startswith("s3://"):
             _s3_get(url, cache_file)
+        elif url.startswith("gs://"):
+            _gcs_get(url, cache_file.name)
         else:
             _http_get(url, cache_file)
 
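A note on the helpers above: `get_from_cache` records the value returned by `_gcs_md5` as the cached file's version tag, playing the role the ETag plays for http:// and s3:// URLs. Keep in mind that a `Blob` obtained via `bucket.blob(...)` starts with no metadata, so its `md5_hash` attribute is only populated once the blob has been loaded from the API. A standalone sketch of fetching the hash explicitly (bucket and object names hypothetical, credentials assumed available):

    from google.cloud import storage

    client = storage.Client()
    blob = client.bucket("my-bucket").blob("subdir/file.txt")
    blob.reload()  # fetch object metadata; raises NotFound if the object does not exist
    print(blob.md5_hash)  # base64-encoded MD5 digest of the object's contents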
1 change: 1 addition & 0 deletions setup.py
@@ -74,6 +74,7 @@
         "checklist==0.0.10",
         "wandb>=0.10.0,<0.11.0",
         "huggingface_hub>=0.0.8",
+        "google-cloud-storage>=1.38.0,<1.39.0",
     ],
     entry_points={"console_scripts": ["allennlp=allennlp.__main__:run"]},
     include_package_data=True,
12 changes: 12 additions & 0 deletions tests/common/file_utils_test.py
@@ -19,6 +19,7 @@
     get_from_cache,
     cached_path,
     _split_s3_path,
+    _split_gcs_path,
     open_compressed,
     CacheFile,
     _Meta,
@@ -228,6 +229,17 @@ def test_split_s3_path(self):
             _split_s3_path("s3://myfile.txt")
             _split_s3_path("myfile.txt")
 
+    def test_split_gcs_path(self):
+        # Test splitting good urls.
+        assert _split_gcs_path("gs://my-bucket/subdir/file.txt") == ("my-bucket", "subdir/file.txt")
+        assert _split_gcs_path("gs://my-bucket/file.txt") == ("my-bucket", "file.txt")
+
+        # Test splitting bad urls.
+        with pytest.raises(ValueError):
+            _split_gcs_path("gs://")
+            _split_gcs_path("gs://myfile.txt")
+            _split_gcs_path("myfile.txt")
+
     @responses.activate
     def test_get_from_cache(self):
         url = "http://fake.datastore.com/glove.txt.gz"
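The added test covers only the path-splitting helper; the download path itself is not exercised here. One way to cover it without real GCS credentials would be to patch the client factory. A sketch along those lines, not part of this diff, with the test and fixture names purely illustrative:

    from unittest import mock

    from allennlp.common.file_utils import _gcs_get


    def test_gcs_get_downloads_blob(tmp_path):
        fake_blob = mock.Mock()
        fake_client = mock.Mock()
        fake_client.bucket.return_value.blob.return_value = fake_blob
        target = str(tmp_path / "file.txt")

        # _get_gcs_blob obtains the client through _get_gcs_client, so patching
        # that factory keeps the test offline.
        with mock.patch("allennlp.common.file_utils._get_gcs_client", return_value=fake_client):
            _gcs_get("gs://my-bucket/file.txt", target)

        fake_client.bucket.assert_called_once_with("my-bucket")
        fake_client.bucket.return_value.blob.assert_called_once_with("file.txt")
        fake_blob.download_to_filename.assert_called_once_with(target)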