This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Allow google cloud storage locations for cached_path #5173

Merged · 5 commits · May 3, 2021
Changes from 4 commits

1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 ### Changed

 - Use `dist_reduce_sum` in distributed metrics.
+- Allow Google Cloud Storage paths in `cached_path`.

 ### Added

73 changes: 65 additions & 8 deletions allennlp/common/file_utils.py
@@ -45,6 +45,8 @@
 import botocore
 import torch
 from filelock import FileLock as _FileLock
+from google.cloud import storage
+from google.api_core.exceptions import NotFound
 import numpy as np
 from overrides import overrides
 import requests
@@ -211,7 +213,7 @@ def cached_path(
     then return the path to the cached file. If it's already a local path,
     make sure the file exists and return the path.

-    For URLs, "http://", "https://", "s3://", and "hf://" are all supported.
+    For URLs, "http://", "https://", "s3://", "gs://", and "hf://" are all supported.
     The latter corresponds to the HuggingFace Hub.

     For example, to download the PyTorch weights for the model `epwalsh/bert-xsmall-dummy`
@@ -281,7 +283,7 @@ def cached_path(

     parsed = urlparse(url_or_filename)

-    if parsed.scheme in ("http", "https", "s3", "hf"):
+    if parsed.scheme in ("http", "https", "s3", "hf", "gs"):
         # URL, so get it from the cache (downloading if necessary)
         file_path = get_from_cache(url_or_filename, cache_dir)

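With the scheme check above in place, a gs:// URL flows through the same cache machinery as the existing schemes. A minimal usage sketch (hypothetical bucket and object names; assumes GCS credentials are available in the environment, e.g. via GOOGLE_APPLICATION_CREDENTIALS):

```python
from allennlp.common.file_utils import cached_path

# First call downloads the object into the local AllenNLP cache; later
# calls return the cached copy while the object's MD5 hash is unchanged.
local_path = cached_path("gs://my-bucket/vocab.txt")
```
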
@@ -373,20 +375,28 @@ def is_url_or_existing_file(url_or_filename: Union[str, Path, None]) -> bool:
         return False
     url_or_filename = os.path.expanduser(str(url_or_filename))
     parsed = urlparse(url_or_filename)
-    return parsed.scheme in ("http", "https", "s3") or os.path.exists(url_or_filename)
+    return parsed.scheme in ("http", "https", "s3", "gs") or os.path.exists(url_or_filename)


 def _split_s3_path(url: str) -> Tuple[str, str]:
+    return _split_cloud_path(url, "s3")
+
+
+def _split_gcs_path(url: str) -> Tuple[str, str]:
+    return _split_cloud_path(url, "gs")
+
+
+def _split_cloud_path(url: str, provider: str) -> Tuple[str, str]:
     """Split a full s3 path into the bucket name and path."""
     parsed = urlparse(url)
     if not parsed.netloc or not parsed.path:
-        raise ValueError("bad s3 path {}".format(url))
+        raise ValueError("bad {} path {}".format(provider, url))
     bucket_name = parsed.netloc
-    s3_path = parsed.path
+    provider_path = parsed.path
     # Remove '/' at beginning of path.
-    if s3_path.startswith("/"):
-        s3_path = s3_path[1:]
-    return bucket_name, s3_path
+    if provider_path.startswith("/"):
+        provider_path = provider_path[1:]
+    return bucket_name, provider_path


 def _s3_request(func: Callable):
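The `ValueError` guard above relies on how `urlparse` decomposes a cloud URL: the bucket name lands in `netloc` and the object key in `path`. A quick sketch of the cases it accepts and rejects (mirroring the unit tests added below):

```python
from urllib.parse import urlparse

parsed = urlparse("gs://my-bucket/subdir/file.txt")
print(parsed.netloc)  # "my-bucket"
print(parsed.path)    # "/subdir/file.txt"; _split_cloud_path strips the leading "/"

# Both of these fail the netloc/path check and raise ValueError:
print(urlparse("gs://").netloc)          # "" (no bucket)
print(urlparse("gs://myfile.txt").path)  # "" (bucket but no object key)
```
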
@@ -437,6 +447,49 @@ def _s3_get(url: str, temp_file: IO) -> None:
     s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)


+def _gcs_request(func: Callable):
+    """
+    Wrapper function for gcs requests in order to create more helpful error
+    messages.
+    """
+
+    @wraps(func)
+    def wrapper(url: str, *args, **kwargs):
+        try:
+            return func(url, *args, **kwargs)
+        except NotFound:
+            raise FileNotFoundError("file {} not found".format(url))
+
+    return wrapper
+
+
+def _get_gcs_client():
+    storage_client = storage.Client()
+    return storage_client
+
+
+def _get_gcs_blob(url: str) -> storage.blob.Blob:
+    gcs_resource = _get_gcs_client()
+    bucket_name, gcs_path = _split_gcs_path(url)
+    bucket = gcs_resource.bucket(bucket_name)
+    blob = bucket.blob(gcs_path)
+    return blob
+
+
+@_gcs_request
+def _gcs_md5(url: str) -> Optional[str]:
+    """Get GCS object's md5."""
+    blob = _get_gcs_blob(url)
+    return blob.md5_hash
+
+
+@_gcs_request
+def _gcs_get(url: str, temp_filename: str) -> None:
+    """Pull a file directly from GCS."""
+    blob = _get_gcs_blob(url)
+    blob.download_to_filename(temp_filename)
+
+
 def _session_with_backoff() -> requests.Session:
     """
     We ran into an issue where http requests to s3 were timing out,
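These helpers are thin wrappers over the `google-cloud-storage` client, translating its `NotFound` into the `FileNotFoundError` that the caching code already raises for missing S3 objects. For orientation, a sketch of the equivalent direct client calls (hypothetical names; note that a bare `bucket.blob()` reference makes no API request, so metadata such as `md5_hash` is only populated after a fetch, e.g. via `get_blob()` or `reload()`):

```python
from google.cloud import storage

client = storage.Client()  # picks up credentials from the environment
bucket = client.bucket("my-bucket")

# get_blob() fetches object metadata (returning None if the object does
# not exist), so md5_hash is populated on the returned blob.
blob = bucket.get_blob("subdir/file.txt")
if blob is not None:
    print(blob.md5_hash)  # base64-encoded MD5; can be None for composite objects
    blob.download_to_filename("/tmp/file.txt")
```
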
@@ -923,6 +976,8 @@ def get_from_cache(url: str, cache_dir: Union[str, Path] = None) -> str:
     try:
         if url.startswith("s3://"):
             etag = _s3_etag(url)
+        elif url.startswith("gs://"):
+            etag = _gcs_md5(url)
         else:
             etag = _http_etag(url)
     except (requests.exceptions.ConnectionError, botocore.exceptions.EndpointConnectionError):
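For gs:// URLs the object's MD5 hash stands in for the HTTP ETag as the cache version stamp, so a changed object triggers a fresh download. A tiny sketch of what the helper yields (hypothetical object; the value shown is the base64-encoded MD5 of empty content):

```python
etag = _gcs_md5("gs://my-bucket/empty.txt")
print(etag)  # e.g. "1B2M2Y8AsgTpgAmY7PhCfg==", or None if no MD5 is reported
```
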
@@ -977,6 +1032,8 @@ def get_from_cache(url: str, cache_dir: Union[str, Path] = None) -> str:
         # GET file object
         if url.startswith("s3://"):
             _s3_get(url, cache_file)
+        elif url.startswith("gs://"):
+            _gcs_get(url, cache_file.name)
         else:
             _http_get(url, cache_file)

1 change: 1 addition & 0 deletions setup.py
@@ -74,6 +74,7 @@
"checklist==0.0.10",
"wandb>=0.10.0,<0.11.0",
"huggingface_hub>=0.0.8",
"google-cloud-storage",
epwalsh marked this conversation as resolved.
Show resolved Hide resolved
],
entry_points={"console_scripts": ["allennlp=allennlp.__main__:run"]},
include_package_data=True,
Expand Down
12 changes: 12 additions & 0 deletions tests/common/file_utils_test.py
@@ -19,6 +19,7 @@
     get_from_cache,
     cached_path,
     _split_s3_path,
+    _split_gcs_path,
     open_compressed,
     CacheFile,
     _Meta,
@@ -228,6 +229,17 @@ def test_split_s3_path(self):
_split_s3_path("s3://myfile.txt")
_split_s3_path("myfile.txt")

def test_split_gcs_path(self):
# Test splitting good urls.
assert _split_gcs_path("gs://my-bucket/subdir/file.txt") == ("my-bucket", "subdir/file.txt")
assert _split_gcs_path("gs://my-bucket/file.txt") == ("my-bucket", "file.txt")

# Test splitting bad urls.
with pytest.raises(ValueError):
_split_gcs_path("gs://")
_split_gcs_path("gs://myfile.txt")
_split_gcs_path("myfile.txt")

@responses.activate
def test_get_from_cache(self):
url = "http://fake.datastore.com/glove.txt.gz"
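The added test covers only the pure path-splitting logic; the helpers that actually talk to GCS are not exercised here. As a hedged sketch (not part of this PR), they could be tested without network access by stubbing the blob lookup with pytest's `monkeypatch` and `unittest.mock`:

```python
from unittest import mock

import pytest
from google.api_core.exceptions import NotFound

from allennlp.common.file_utils import _gcs_md5


def test_gcs_md5_returns_blob_hash(monkeypatch):
    # Stub the blob lookup so no real GCS request is made.
    fake_blob = mock.Mock(md5_hash="1B2M2Y8AsgTpgAmY7PhCfg==")
    monkeypatch.setattr("allennlp.common.file_utils._get_gcs_blob", lambda url: fake_blob)
    assert _gcs_md5("gs://my-bucket/file.txt") == "1B2M2Y8AsgTpgAmY7PhCfg=="


def test_gcs_request_maps_not_found(monkeypatch):
    # _gcs_request should convert the client's NotFound into FileNotFoundError.
    def raise_not_found(url):
        raise NotFound("no such object")

    monkeypatch.setattr("allennlp.common.file_utils._get_gcs_blob", raise_not_found)
    with pytest.raises(FileNotFoundError):
        _gcs_md5("gs://my-bucket/missing.txt")
```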