Skip to content

Commit

Permalink
Merge branch 'main' into better-opening
Browse files Browse the repository at this point in the history
  • Loading branch information
betolink authored Oct 13, 2023
2 parents c2bb1cd + 250848d commit 0b6bb98
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 29 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# Changelog

## [unreleased]
* bug fixes:
* Fix spelling mistake in `access` variable assignment (`direc` -> `direct`)
in `earthaccess.store._get_granules`.
* Pass `threads` arg to `_open_urls_https` in
`earthaccess.store._open_urls`, replacing the hard-coded value of 8.

## [v0.6.0] 2023-09-20
* bug fixes:
Expand Down
9 changes: 6 additions & 3 deletions earthaccess/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from fsspec import AbstractFileSystem

from .auth import Auth
from .results import DataGranule
from .search import CollectionQuery, DataCollections, DataGranules, GranuleQuery
from .store import Store
from .utils import _validation as validate
Expand Down Expand Up @@ -150,8 +151,8 @@ def login(strategy: str = "all", persist: bool = False) -> Auth:


def download(
granules: Union[List[earthaccess.results.DataGranule], List[str]],
local_path: Optional[str],
granules: Union[DataGranule, List[DataGranule], List[str]],
local_path: Union[str, None],
provider: Optional[str] = None,
threads: int = 8,
) -> List[str]:
Expand All @@ -161,14 +162,16 @@ def download(
* If we run it outside AWS (us-west-2 region) and the dataset is cloud hostes we'll use HTTP links
Parameters:
granules: a list of granules(DataGranule) instances or a list of granule links (HTTP)
granules: a granule, list of granules, or a list of granule links (HTTP)
local_path: local directory to store the remote data granules
provider: if we download a list of URLs we need to specify the provider.
threads: parallel number of threads to use to download the files, adjust as necessary, default = 8
Returns:
List of downloaded files
"""
if isinstance(granules, DataGranule):
granules = [granules]
try:
results = earthaccess.__store__.get(granules, local_path, provider, threads)
except AttributeError as err:
Expand Down
40 changes: 21 additions & 19 deletions earthaccess/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,9 @@ def _open_urls(
"We cannot open S3 links when we are not in-region, try using HTTPS links"
)
return None

fileset = self._open_urls_https(data_links, granules, 8, sizes)

return fileset

def get(
Expand All @@ -480,6 +482,12 @@ def get(
Returns:
List of downloaded files
"""
if local_path is None:
local_path = os.path.join(
".",
"data",
f"{datetime.datetime.today().strftime('%Y-%m-%d')}-{uuid4().hex[:6]}",
)
if len(granules):
files = self._get(granules, local_path, provider, threads)
return files
Expand All @@ -491,7 +499,7 @@ def get(
def _get(
self,
granules: Union[List[DataGranule], List[str]],
local_path: Optional[str] = None,
local_path: str,
provider: Optional[str] = None,
threads: int = 8,
) -> Union[None, List[str]]:
Expand Down Expand Up @@ -519,7 +527,7 @@ def _get(
def _get_urls(
self,
granules: List[str],
local_path: Optional[str] = None,
local_path: str,
provider: Optional[str] = None,
threads: int = 8,
) -> Union[None, List[str]]:
Expand All @@ -536,22 +544,21 @@ def _get_urls(
s3_fs = self.get_s3fs_session(provider=provider)
# TODO: make this parallel or concurrent
for file in data_links:
file_name = file.split("/")[-1]
s3_fs.get(file, local_path)
print(f"Retrieved: {file} to {local_path}")
file_name = os.path.join(local_path, os.path.basename(file))
print(f"Downloaded: {file_name}")
downloaded_files.append(file_name)
return downloaded_files

else:
# if we are not in AWS
return self._download_onprem_granules(data_links, local_path, threads)
return None

@_get.register
def _get_granules(
self,
granules: List[DataGranule],
local_path: Optional[str] = None,
local_path: str,
provider: Optional[str] = None,
threads: int = 8,
) -> Union[None, List[str]]:
Expand All @@ -560,7 +567,7 @@ def _get_granules(
provider = granules[0]["meta"]["provider-id"]
endpoint = self._own_s3_credentials(granules[0]["umm"]["RelatedUrls"])
cloud_hosted = granules[0].cloud_hosted
access = "direc" if (cloud_hosted and self.running_in_aws) else "external"
access = "direct" if (cloud_hosted and self.running_in_aws) else "external"
data_links = list(
# we are not in region
chain.from_iterable(
Expand All @@ -584,14 +591,13 @@ def _get_granules(
# TODO: make this async
for file in data_links:
s3_fs.get(file, local_path)
file_name = file.split("/")[-1]
print(f"Retrieved: {file} to {local_path}")
file_name = os.path.join(local_path, os.path.basename(file))
print(f"Downloaded: {file_name}")
downloaded_files.append(file_name)
return downloaded_files
else:
# if the data is cloud based bu we are not in AWS it will be downloaded as if it was on prem
return self._download_onprem_granules(data_links, local_path, threads)
return None

def _download_file(self, url: str, directory: str) -> str:
"""
Expand Down Expand Up @@ -625,10 +631,10 @@ def _download_file(self, url: str, directory: str) -> str:
raise Exception
else:
print(f"File {local_filename} already downloaded")
return local_filename
return local_path

def _download_onprem_granules(
self, urls: List[str], directory: Optional[str] = None, threads: int = 8
self, urls: List[str], directory: str, threads: int = 8
) -> List[Any]:
"""
downloads a list of URLS into the data directory.
Expand All @@ -645,14 +651,10 @@ def _download_onprem_granules(
"We need to be logged into NASA EDL in order to download data granules"
)
return []
if directory is None:
directory_prefix = f"./data/{datetime.datetime.today().strftime('%Y-%m-%d')}-{uuid4().hex[:6]}"
else:
directory_prefix = directory
if not os.path.exists(directory_prefix):
os.makedirs(directory_prefix)
if not os.path.exists(directory):
os.makedirs(directory)

arguments = [(url, directory_prefix) for url in urls]
arguments = [(url, directory) for url in urls]
results = pqdm(
arguments,
self._download_file,
Expand Down
12 changes: 5 additions & 7 deletions tests/integration/test_api.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# package imports
import logging
import os
import shutil
import unittest

import earthaccess
Expand Down Expand Up @@ -69,16 +68,15 @@ def test_granules_search_returns_valid_results(kwargs):
assertions.assertTrue(len(results) <= 10)


def test_earthaccess_api_can_download_granules():
@pytest.mark.parametrize("selection", [0, slice(None)])
def test_earthaccess_api_can_download_granules(tmp_path, selection):
results = earthaccess.search_data(
count=2,
short_name="ATL08",
cloud_hosted=True,
bounding_box=(-92.86, 16.26, -91.58, 16.97),
)
local_path = "./tests/integration/data/ATL08"
assertions.assertIsInstance(results, list)
assertions.assertTrue(len(results) <= 2)
files = earthaccess.download(results, local_path=local_path)
result = results[selection]
files = earthaccess.download(result, str(tmp_path))
assertions.assertIsInstance(files, list)
shutil.rmtree(local_path)
assert all(os.path.exists(f) for f in files)

0 comments on commit 0b6bb98

Please sign in to comment.