Skip to content

Commit

Permalink
Merge pull request #55 from nasaharvest/ee-status-check-fix
Browse files Browse the repository at this point in the history
EE status check fix
  • Loading branch information
ivanzvonkov authored Jul 11, 2022
2 parents 5128d4e + dfe0d84 commit 635e3c2
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 2 deletions.
9 changes: 7 additions & 2 deletions openmapflow/inference_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,16 @@ def get_ee_task_amount(prefix: Optional[str] = None):
Returns:
Amount of active tasks.
"""
ee_prefix = None
if prefix is not None:
ee_prefix = prefix.replace("/", "-").replace("=", "-")
amount = 0
task_list = ee.data.getTaskList()
for t in tqdm(task_list):
valid_state = t["state"] in ["READY", "RUNNING"]
if valid_state and (prefix is None or prefix in t["description"]):
if valid_state and (
ee_prefix is None or t["description"].startswith(ee_prefix)
):
amount += 1
return amount

Expand Down Expand Up @@ -123,7 +128,7 @@ def get_status(prefix: str) -> Tuple[int, int, int]:
amount of predictions made.
"""
print_between_lines(prefix)
ee_task_amount = get_ee_task_amount(prefix=prefix.replace("/", "-"))
ee_task_amount = get_ee_task_amount(prefix=prefix)
tifs_amount = get_gcs_file_amount(bn.INFERENCE_TIFS, prefix=prefix)
predictions_amount = get_gcs_file_amount(bn.PREDS, prefix=prefix)
print(f"1) Obtaining input data: {ee_task_amount}")
Expand Down
20 changes: 20 additions & 0 deletions tests/test_inference_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,26 @@ def test_get_status(self, mock_storage, mock_ee):
] * 10
self.assertEqual(get_status("fake_prefix"), (10, 100, 100))

@patch("openmapflow.inference_utils.ee")
@patch("openmapflow.inference_utils.storage")
def test_get_status_special_chars(self, mock_storage, mock_ee):
mock_storage_client = mock_storage.Client()
mock_storage_client.list_blobs.return_value = [MockBlob("file")] * 100
mock_ee.data.getTaskList.return_value = [
{"state": "READY", "description": "fake_prefix_lon-10"}
] * 10
self.assertEqual(get_status("fake_prefix_lon=10"), (10, 100, 100))

@patch("openmapflow.inference_utils.ee")
@patch("openmapflow.inference_utils.storage")
def test_get_status_prefix(self, mock_storage, mock_ee):
mock_storage_client = mock_storage.Client()
mock_storage_client.list_blobs.return_value = [MockBlob("file")] * 100
mock_ee.data.getTaskList.return_value = [
{"state": "READY", "description": "fake_prefix_lon_10"}
] * 10
self.assertEqual(get_status("fake_prefix"), (10, 100, 100))

@patch("openmapflow.inference_utils.storage")
def test_find_missing_predictions(self, mock_storage):
mock_storage_client = mock_storage.Client()
Expand Down

0 comments on commit 635e3c2

Please sign in to comment.