diff --git a/asteroid/utils/hub_utils.py b/asteroid/utils/hub_utils.py index 6da7b49bb..be24927db 100644 --- a/asteroid/utils/hub_utils.py +++ b/asteroid/utils/hub_utils.py @@ -5,9 +5,9 @@ import sys import tempfile from contextlib import contextmanager -from functools import partial +from functools import partial, lru_cache from hashlib import sha256 -from typing import BinaryIO, Dict, Optional, Union +from typing import BinaryIO, Dict, Optional, Union, List from urllib.parse import urlparse import requests @@ -38,6 +38,7 @@ HF_WEIGHTS_NAME = "pytorch_model.bin" HUGGINGFACE_CO_PREFIX = "https://huggingface.co/{model_id}/resolve/{revision}/{filename}" +ENDPOINT = "https://huggingface.co" def cached_download(filename_or_url): @@ -309,3 +310,16 @@ def _resumable_file_manager() -> "io.BufferedWriter": json.dump(meta, meta_file) return cache_path + + +@lru_cache() +def model_list(endpoint=ENDPOINT, name_only=False) -> List[Dict]: + """Get the public list of all the models on huggingface with an 'asteroid' tag.""" + path = "{}/api/models?full=true".format(endpoint) + r = requests.get(path) + r.raise_for_status() + d = r.json() + all_models = [x for x in d if "asteroid" in x.get("tags", [])] + if name_only: + return [x["modelId"] for x in all_models] + return all_models