Skip to content

Commit

Permalink
[#940] Better inheritance for repo card utils (#956)
Browse files Browse the repository at this point in the history
* 🚧 wip

* 🚧 wip

* 🚧 wip

* 💄 apply style

* 🚧 wip
  • Loading branch information
nateraw committed Aug 22, 2022
1 parent 3e0d13b commit 0f8226c
Show file tree
Hide file tree
Showing 7 changed files with 264 additions and 55 deletions.
1 change: 1 addition & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
include src/huggingface_hub/templates/modelcard_template.md
include src/huggingface_hub/templates/datasetcard_template.md
3 changes: 2 additions & 1 deletion src/huggingface_hub/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ def __dir__():
"metadata_save",
"metadata_update",
"ModelCard",
"DatasetCard",
],
"community": [
"Discussion",
Expand All @@ -217,6 +218,6 @@ def __dir__():
"DiscussionCommit",
"DiscussionTitleChange",
],
"repocard_data": ["CardData", "EvalResult"],
"repocard_data": ["CardData", "ModelCardData", "DatasetCardData", "EvalResult"],
},
)
60 changes: 35 additions & 25 deletions src/huggingface_hub/repocard.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
from huggingface_hub.hf_api import upload_file
from huggingface_hub.repocard_data import (
CardData,
DatasetCardData,
EvalResult,
ModelCardData,
eval_results_to_model_index,
model_index_to_eval_results,
)
Expand All @@ -29,12 +31,21 @@

# exact same regex as in the Hub server. Please keep in sync.
TEMPLATE_MODELCARD_PATH = Path(__file__).parent / "templates" / "modelcard_template.md"
TEMPLATE_DATASETCARD_PATH = (
Path(__file__).parent / "templates" / "datasetcard_template.md"
)

REGEX_YAML_BLOCK = re.compile(r"---[\n\r]+([\S\s]*?)[\n\r]+---[\n\r]")

logger = get_logger(__name__)


class ModelCard:
class RepoCard:

card_data_class = CardData
default_template_path = TEMPLATE_MODELCARD_PATH
repo_type = "model"

def __init__(self, content: str):
"""Initialize a RepoCard from string content. The content should be a
Markdown file with a YAML block at the beginning and a Markdown body.
Expand Down Expand Up @@ -65,18 +76,7 @@ def __init__(self, content: str):
data_dict = {}
self.text = content

model_index = data_dict.pop("model-index", None)
if model_index:
try:
model_name, eval_results = model_index_to_eval_results(model_index)
data_dict["model_name"] = model_name
data_dict["eval_results"] = eval_results
except KeyError:
logger.warning(
"Invalid model-index. Not loading eval results into CardData."
)

self.data = CardData(**data_dict)
self.data = self.card_data_class(**data_dict)

def __str__(self):
line_break = _detect_line_ending(self.content) or "\n"
Expand Down Expand Up @@ -127,20 +127,23 @@ def load(cls, repo_id_or_path: Union[str, Path], repo_type=None, token=None):
card_path = hf_hub_download(
repo_id_or_path,
REPOCARD_NAME,
repo_type=repo_type,
repo_type=repo_type or cls.repo_type,
use_auth_token=token,
)

# Preserve newlines in the existing file.
with Path(card_path).open(mode="r", newline="") as f:
return cls(f.read())

def validate(self, repo_type="model"):
def validate(self, repo_type=None):
"""Validates card against Hugging Face Hub's model card validation logic.
Using this function requires access to the internet, so it is only called
internally by `huggingface_hub.ModelCard.push_to_hub`.
"""

# If repo type is provided, otherwise, use the repo type of the card.
repo_type = repo_type or self.repo_type

body = {
"repoType": repo_type,
"content": str(self),
Expand All @@ -162,7 +165,7 @@ def push_to_hub(
self,
repo_id,
token=None,
repo_type="model",
repo_type=None,
commit_message=None,
commit_description=None,
revision=None,
Expand Down Expand Up @@ -199,14 +202,8 @@ def push_to_hub(
`str`: URL of the commit which updated the card metadata.
"""

if repo_type is None:
repo_type = "model"

if repo_type not in ["model", "space", "dataset"]:
raise RuntimeError(
"Provided repo_type '{repo_type}' should be one of ['model', 'space',"
" 'dataset']."
)
# If repo type is provided, otherwise, use the repo type of the card.
repo_type = repo_type or self.repo_type

# Validate card before pushing to hub
self.validate(repo_type=repo_type)
Expand All @@ -231,7 +228,7 @@ def push_to_hub(
def from_template(
cls,
card_data: CardData,
template_path: Optional[str] = TEMPLATE_MODELCARD_PATH,
template_path: Optional[str] = None,
**template_kwargs,
):
"""Initialize a ModelCard from a template. By default, it uses the default template.
Expand Down Expand Up @@ -296,12 +293,25 @@ def from_template(
... )
"""
template_path = template_path or cls.default_template_path
content = jinja2.Template(Path(template_path).read_text()).render(
card_data=card_data.to_yaml(), **template_kwargs
)
return cls(content)


class ModelCard(RepoCard):
card_data_class = ModelCardData
default_template_path = TEMPLATE_MODELCARD_PATH
repo_type = "model"


class DatasetCard(RepoCard):
card_data_class = DatasetCardData
default_template_path = TEMPLATE_DATASETCARD_PATH
repo_type = "dataset"


def _detect_line_ending(content: str) -> Literal["\r", "\n", "\r\n", None]:
"""
same implem as in Hub server, keep it in sync
Expand Down
107 changes: 89 additions & 18 deletions src/huggingface_hub/repocard_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@

import yaml

from .utils.logging import get_logger


logger = get_logger(__name__)


@dataclass
class EvalResult:
Expand Down Expand Up @@ -76,6 +81,38 @@ class EvalResult:

@dataclass
class CardData:
def __init__(self, **kwargs):
self.__dict__.update(kwargs)

def to_dict(self):
"""Converts CardData to a dict.
Returns:
`dict`: CardData represented as a dictionary ready to be dumped to a YAML
block for inclusion in a README.md file.
"""

data_dict = copy.deepcopy(self.__dict__)
self._to_dict(data_dict)
return _remove_none(data_dict)

def _to_dict(self, data_dict):
"""Use this method in child classes to alter the dict representation of the data. Alter the dict in-place.
Args:
data_dict (_type_): The raw dict representation of the card data.
"""
pass

def to_yaml(self, line_break=None):
"""Dumps CardData to a YAML block for inclusion in a README.md file."""
return yaml.dump(self.to_dict(), sort_keys=False, line_break=line_break).strip()

def __repr__(self):
return self.to_yaml()


class ModelCardData(CardData):
def __init__(
self,
language: Optional[Union[str, List[str]]] = None,
Expand Down Expand Up @@ -140,38 +177,72 @@ def __init__(
self.metrics = metrics
self.eval_results = eval_results
self.model_name = model_name
self.__dict__.update(kwargs)

model_index = kwargs.pop("model-index", None)
if model_index:
try:
model_name, eval_results = model_index_to_eval_results(model_index)
self.model_name = model_name
self.eval_results = eval_results
except KeyError:
logger.warning(
"Invalid model-index. Not loading eval results into CardData."
)

super().__init__(**kwargs)

if self.eval_results:
if type(self.eval_results) == EvalResult:
self.eval_results = [self.eval_results]
if self.model_name is None:
raise ValueError("`eval_results` requires `model_name` to be set.")

def to_dict(self):
"""Converts CardData to a dict. It also formats the internal eval_results to
be compatible with the model-index format.
Returns:
`dict`: CardData represented as a dictionary ready to be dumped to a YAML
block for inclusion in a README.md file.
"""

data_dict = copy.deepcopy(self.__dict__)
def _to_dict(self, data_dict):
"""Format the internal data dict. In this case, we convert eval results to a valid model index"""
if self.eval_results is not None:
data_dict["model-index"] = eval_results_to_model_index(
self.model_name, self.eval_results
)
del data_dict["eval_results"], data_dict["model_name"]

return _remove_none(data_dict)

def to_yaml(self, line_break=None):
"""Dumps CardData to a YAML block for inclusion in a README.md file."""
return yaml.dump(self.to_dict(), sort_keys=False, line_break=line_break).strip()

def __repr__(self):
return self.to_yaml()
class DatasetCardData(CardData):
def __init__(
self,
annotations_creators: Optional[Union[str, List[str]]] = None,
language_creators: Optional[Union[str, List[str]]] = None,
language: Optional[Union[str, List[str]]] = None,
license: Optional[Union[str, List[str]]] = None,
multilinguality: Optional[Union[str, List[str]]] = None,
size_categories: Optional[Union[str, List[str]]] = None,
source_datasets: Optional[Union[str, List[str]]] = None,
task_categories: Optional[Union[str, List[str]]] = None,
task_ids: Optional[Union[str, List[str]]] = None,
paperswithcode_id: Optional[str] = None,
pretty_name: Optional[str] = None,
train_eval_index: Optional[Dict] = None,
configs: Optional[Union[str, List[str]]] = None,
**kwargs,
):
self.annotations_creators = annotations_creators
self.language_creators = language_creators
self.language = language
self.license = license
self.multilinguality = multilinguality
self.size_categories = size_categories
self.source_datasets = source_datasets
self.task_categories = task_categories
self.task_ids = task_ids
self.paperswithcode_id = paperswithcode_id
self.pretty_name = pretty_name
self.configs = configs

# TODO - maybe handle this similarly to EvalResult?
self.train_eval_index = train_eval_index or kwargs.pop("train-eval-index", None)
super().__init__(**kwargs)

def _to_dict(self, data_dict):
data_dict["train-eval-index"] = data_dict.pop("train_eval_index")


def model_index_to_eval_results(model_index: List[Dict[str, Any]]):
Expand Down
Loading

0 comments on commit 0f8226c

Please sign in to comment.