diff --git a/openmapflow/inference_utils.py b/openmapflow/inference_utils.py index 3c4b4788..5630cd92 100644 --- a/openmapflow/inference_utils.py +++ b/openmapflow/inference_utils.py @@ -73,11 +73,16 @@ def get_ee_task_amount(prefix: Optional[str] = None): Returns: Amount of active tasks. """ + ee_prefix = None + if prefix is not None: + ee_prefix = prefix.replace("/", "-").replace("=", "-") amount = 0 task_list = ee.data.getTaskList() for t in tqdm(task_list): valid_state = t["state"] in ["READY", "RUNNING"] - if valid_state and (prefix is None or prefix in t["description"]): + if valid_state and ( + ee_prefix is None or t["description"].startswith(ee_prefix) + ): amount += 1 return amount @@ -123,7 +128,7 @@ def get_status(prefix: str) -> Tuple[int, int, int]: amount of predictions made. """ print_between_lines(prefix) - ee_task_amount = get_ee_task_amount(prefix=prefix.replace("/", "-")) + ee_task_amount = get_ee_task_amount(prefix=prefix) tifs_amount = get_gcs_file_amount(bn.INFERENCE_TIFS, prefix=prefix) predictions_amount = get_gcs_file_amount(bn.PREDS, prefix=prefix) print(f"1) Obtaining input data: {ee_task_amount}") diff --git a/tests/test_inference_utils.py b/tests/test_inference_utils.py index 8d9a5ed8..469cbdae 100644 --- a/tests/test_inference_utils.py +++ b/tests/test_inference_utils.py @@ -200,6 +200,26 @@ def test_get_status(self, mock_storage, mock_ee): ] * 10 self.assertEqual(get_status("fake_prefix"), (10, 100, 100)) + @patch("openmapflow.inference_utils.ee") + @patch("openmapflow.inference_utils.storage") + def test_get_status_special_chars(self, mock_storage, mock_ee): + mock_storage_client = mock_storage.Client() + mock_storage_client.list_blobs.return_value = [MockBlob("file")] * 100 + mock_ee.data.getTaskList.return_value = [ + {"state": "READY", "description": "fake_prefix_lon-10"} + ] * 10 + self.assertEqual(get_status("fake_prefix_lon=10"), (10, 100, 100)) + + @patch("openmapflow.inference_utils.ee") + @patch("openmapflow.inference_utils.storage") + def test_get_status_prefix(self, mock_storage, mock_ee): + mock_storage_client = mock_storage.Client() + mock_storage_client.list_blobs.return_value = [MockBlob("file")] * 100 + mock_ee.data.getTaskList.return_value = [ + {"state": "READY", "description": "fake_prefix_lon_10"} + ] * 10 + self.assertEqual(get_status("fake_prefix"), (10, 100, 100)) + @patch("openmapflow.inference_utils.storage") def test_find_missing_predictions(self, mock_storage): mock_storage_client = mock_storage.Client()