Skip to content

Commit

Permalink
[hub] List asteroid models HF's hub (#382)
Browse files Browse the repository at this point in the history
  • Loading branch information
mpariente committed Dec 8, 2020
1 parent 89f24fd commit f5939ed
Showing 1 changed file with 16 additions and 2 deletions.
18 changes: 16 additions & 2 deletions asteroid/utils/hub_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
import sys
import tempfile
from contextlib import contextmanager
from functools import partial
from functools import partial, lru_cache
from hashlib import sha256
from typing import BinaryIO, Dict, Optional, Union
from typing import BinaryIO, Dict, Optional, Union, List
from urllib.parse import urlparse

import requests
Expand Down Expand Up @@ -38,6 +38,7 @@

HF_WEIGHTS_NAME = "pytorch_model.bin"
HUGGINGFACE_CO_PREFIX = "https://huggingface.co/{model_id}/resolve/{revision}/{filename}"
ENDPOINT = "https://huggingface.co"


def cached_download(filename_or_url):
Expand Down Expand Up @@ -309,3 +310,16 @@ def _resumable_file_manager() -> "io.BufferedWriter":
json.dump(meta, meta_file)

return cache_path


@lru_cache()
def model_list(endpoint=ENDPOINT, name_only=False) -> List[Dict]:
"""Get the public list of all the models on huggingface with an 'asteroid' tag."""
path = "{}/api/models?full=true".format(endpoint)
r = requests.get(path)
r.raise_for_status()
d = r.json()
all_models = [x for x in d if "asteroid" in x.get("tags", [])]
if name_only:
return [x["modelId"] for x in all_models]
return all_models

0 comments on commit f5939ed

Please sign in to comment.