Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Repo checks: skip docstring checks if not in the diff #32328

Merged
merged 6 commits into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ jobs:
- run: python utils/custom_init_isort.py --check_only
- run: python utils/sort_auto_mappings.py --check_only
- run: python utils/check_doc_toc.py
- run: python utils/check_docstrings.py --check_all

check_repository_consistency:
working_directory: ~/transformers
Expand Down Expand Up @@ -190,4 +191,4 @@ workflows:
- check_circleci_user
- check_code_quality
- check_repository_consistency
- fetch_all_tests
- fetch_all_tests
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ quality:
python utils/custom_init_isort.py --check_only
python utils/sort_auto_mappings.py --check_only
python utils/check_doc_toc.py
python utils/check_docstrings.py --check_all


# Format source code automatically and check is there are any problems left that need manual fixing
Expand Down
1 change: 1 addition & 0 deletions src/transformers/cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def from_dict(cls, config_dict, **kwargs):
Args:
config_dict (Dict[str, Any]): Dictionary containing configuration parameters.
**kwargs: Additional keyword arguments to override dictionary values.

Returns:
CacheConfig: Instance of CacheConfig constructed from the dictionary.
"""
Expand Down
37 changes: 34 additions & 3 deletions utils/check_docstrings.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,12 @@
from typing import Any, Optional, Tuple, Union

from check_repo import ignore_undocumented
from git import Repo

from transformers.utils import direct_transformers_import


PATH_TO_REPO = Path(__file__).parent.parent.resolve()
PATH_TO_TRANSFORMERS = Path("src").resolve() / "transformers"

# This is to make sure the transformers module imported is the one in the repo.
Expand Down Expand Up @@ -943,14 +945,33 @@ def fix_docstring(obj: Any, old_doc_args: str, new_doc_args: str):
f.write("\n".join(lines))


def check_docstrings(overwrite: bool = False):
def check_docstrings(overwrite: bool = False, check_all: bool = False):
"""
Check docstrings of all public objects that are callables and are documented.
Check docstrings of all public objects that are callables and are documented. By default, only checks the diff.

Args:
overwrite (`bool`, *optional*, defaults to `False`):
Whether to fix inconsistencies or not.
check_all (`bool`, *optional*, defaults to `False`):
Whether to check all files.
"""
module_diff_files = None
if not check_all:
module_diff_files = set()
repo = Repo(PATH_TO_REPO)
# Diff from index to unstaged files
for modified_file_diff in repo.index.diff(None):
if modified_file_diff.a_path.startswith("src/transformers"):
module_diff_files.add(modified_file_diff.a_path)
# Diff from index to `main`
for modified_file_diff in repo.index.diff(repo.refs.main.commit):
if modified_file_diff.a_path.startswith("src/transformers"):
module_diff_files.add(modified_file_diff.a_path)
# quick escape route: if there are no module files in the diff, skip this check
if len(module_diff_files) == 0:
return
print(" Checking docstrings in the following files:" + "\n - " + "\n - ".join(module_diff_files))

failures = []
hard_failures = []
to_clean = []
Expand All @@ -963,6 +984,13 @@ def check_docstrings(overwrite: bool = False):
if not callable(obj) or not isinstance(obj, type) or getattr(obj, "__doc__", None) is None:
continue

# If we are checking against the diff, we skip objects that are not part of the diff.
if module_diff_files is not None:
object_file = find_source_file(getattr(transformers, name))
object_file_relative_path = "src/" + str(object_file).split("/src/")[1]
if object_file_relative_path not in module_diff_files:
continue

# Check docstring
try:
result = match_docstring_with_signature(obj)
Expand Down Expand Up @@ -1013,6 +1041,9 @@ def check_docstrings(overwrite: bool = False):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.")
parser.add_argument(
"--check_all", action="store_true", help="Whether to check all files. By default, only checks the diff"
)
args = parser.parse_args()

check_docstrings(overwrite=args.fix_and_overwrite)
check_docstrings(overwrite=args.fix_and_overwrite, check_all=args.check_all)
Loading