diff --git a/setup.py b/setup.py index eab553124b..986fd04a25 100644 --- a/setup.py +++ b/setup.py @@ -106,7 +106,6 @@ "pluggy>=1.3,<2", "pydantic>=2.5.2,<3", "pydantic-settings>=2.0.3,<3", - "PyGithub>=1.59,<2", "pytest>=6.0,<8.0", "python-dateutil>=2.8.2,<3", "PyYAML>=5.0,<7", diff --git a/src/ape/managers/project/dependency.py b/src/ape/managers/project/dependency.py index 32512e1879..c4734325bf 100644 --- a/src/ape/managers/project/dependency.py +++ b/src/ape/managers/project/dependency.py @@ -15,11 +15,11 @@ from ape.utils import ( ManagerAccessMixin, cached_property, - github_client, load_config, log_instead_of_fail, pragma_str_to_specifier_set, ) +from ape.utils._github import github_client class DependencyManager(ManagerAccessMixin): @@ -221,8 +221,16 @@ def version_id(self) -> str: elif self.version and self.version != "latest": return self.version - latest_release = github_client.get_release(self.github, "latest") - return latest_release.tag_name + latest_release = github_client.get_latest_release(self.org_name, self.repo_name) + return latest_release["tag_name"] + + @cached_property + def org_name(self) -> str: + return self.github.split("/")[0] + + @cached_property + def repo_name(self) -> str: + return self.github.split("/")[1] @property def uri(self) -> AnyUrl: @@ -250,12 +258,14 @@ def extract_manifest(self, use_cache: bool = True) -> PackageManifest: temp_project_path.mkdir(exist_ok=True, parents=True) if self.ref: - github_client.clone_repo(self.github, temp_project_path, branch=self.ref) + github_client.clone_repo( + self.org_name, self.repo_name, temp_project_path, branch=self.ref + ) else: try: github_client.download_package( - self.github, self.version or "latest", temp_project_path + self.org_name, self.repo_name, self.version or "latest", temp_project_path ) except UnknownVersionError as err: logger.warning( @@ -265,7 +275,7 @@ def extract_manifest(self, use_cache: bool = True) -> PackageManifest: ) try: github_client.clone_repo( - self.github, temp_project_path, branch=self.version + self.org_name, self.repo_name, temp_project_path, branch=self.version ) except Exception: # Raise the UnknownVersionError. diff --git a/src/ape/plugins/_utils.py b/src/ape/plugins/_utils.py index 3c42581e11..c02cbd3142 100644 --- a/src/ape/plugins/_utils.py +++ b/src/ape/plugins/_utils.py @@ -11,7 +11,8 @@ from ape.__modules__ import __modules__ from ape.logging import logger from ape.plugins import clean_plugin_name -from ape.utils import BaseInterfaceModel, get_package_version, github_client, log_instead_of_fail +from ape.utils import BaseInterfaceModel, get_package_version, log_instead_of_fail +from ape.utils._github import github_client from ape.utils.basemodel import BaseModel from ape.utils.misc import _get_distributions from ape.version import version as ape_version_str diff --git a/src/ape/utils/__init__.py b/src/ape/utils/__init__.py index f8782b556a..be86215134 100644 --- a/src/ape/utils/__init__.py +++ b/src/ape/utils/__init__.py @@ -18,7 +18,6 @@ ManagerAccessMixin, injected_before_use, ) -from ape.utils.github import GithubClient, github_client from ape.utils.misc import ( DEFAULT_LIVE_NETWORK_BASE_FEE_MULTIPLIER, DEFAULT_LOCAL_TRANSACTION_ACCEPTANCE_TIMEOUT, @@ -84,8 +83,6 @@ "get_relative_path", "gas_estimation_error_message", "get_package_version", - "GithubClient", - "github_client", "GeneratedDevAccount", "generate_dev_accounts", "get_all_files_in_directory", diff --git a/src/ape/utils/_github.py b/src/ape/utils/_github.py new file mode 100644 index 0000000000..f5252794c7 --- /dev/null +++ b/src/ape/utils/_github.py @@ -0,0 +1,219 @@ +import os +import shutil +import subprocess +import tempfile +import zipfile +from io import BytesIO +from pathlib import Path +from typing import Any, Dict, List, Optional, Set, Union + +from requests import Session +from requests.adapters import HTTPAdapter +from urllib3.util.retry import Retry + +from ape.exceptions import CompilerError, ProjectError, UnknownVersionError +from ape.logging import logger +from ape.utils.misc import USER_AGENT, cached_property, stream_response + + +class GitProcessWrapper: + @cached_property + def git(self) -> str: + if path := shutil.which("git"): + return path + + raise ProjectError("`git` not installed.") + + def clone(self, url: str, target_path: Optional[Path] = None, branch: Optional[str] = None): + command = [self.git, "-c", "advice.detachedHead=false", "clone", url] + + if target_path: + command.append(str(target_path)) + + if branch is not None: + command.extend(("--branch", branch)) + + logger.debug(f"Running git command: '{' '.join(command)}'") + result = subprocess.call(command) + if result != 0: + fail_msg = f"`git clone` command failed for '{url}'." + + if branch and not branch.startswith("v"): + # Often times, `v` is required for tags. + try: + self.clone(url, target_path, branch=f"v{branch}") + except Exception: + raise ProjectError(fail_msg) + + # Succeeded when prefixing `v`. + return + + # Failed and we don't really know why. + # Shouldn't really happen. + # User will have to run command separately to debug. + raise ProjectError(fail_msg) + + +# NOTE: This client is only meant to be used internally for ApeWorX projects. +class _GithubClient: + # Generic git/github client attributes. + TOKEN_KEY = "GITHUB_ACCESS_TOKEN" + API_URL_PREFIX = "https://api.github.com" + git: GitProcessWrapper = GitProcessWrapper() + + # ApeWorX-specific attributes. + ORGANIZATION_NAME = "ApeWorX" + FRAMEWORK_NAME = "ape" + _repo_cache: Dict[str, Dict] = {} + + def __init__(self, session: Optional[Session] = None): + if session: + # NOTE: Mostly allowed for testing purposes. + self.__session = session + + else: + headers = {"Content-Type": "application/json", "User-Agent": USER_AGENT} + if auth := os.environ[self.TOKEN_KEY] if self.TOKEN_KEY in os.environ else None: + headers["Authorization"] = f"token {auth}" + + session = Session() + session.headers = {**session.headers, **headers} + adapter = HTTPAdapter( + max_retries=Retry(total=10, backoff_factor=1.0, status_forcelist=[403]), + ) + session.mount("https://", adapter) + self.__session = session + + @cached_property + def org(self) -> Dict: + """ + Our organization on ``Github``. + """ + return self.get_organization(self.ORGANIZATION_NAME) + + @cached_property + def available_plugins(self) -> Set[str]: + return { + repo["name"].replace("-", "_") + for repo in self.get_org_repos() + if not repo.get("private", False) and repo["name"].startswith(f"{self.FRAMEWORK_NAME}-") + } + + def get_org_repos(self) -> List[Dict]: + return self._get(f"orgs/{self.ORGANIZATION_NAME}/repos") + + def get_release(self, org_name: str, repo_name: str, version: str) -> Dict: + if version == "latest": + return self.get_latest_release(org_name, repo_name) + + def _try_get_release(vers): + try: + return self._get_release(org_name, repo_name, vers) + except Exception: + return None + + if release := _try_get_release(version): + return release + else: + original_version = str(version) + # Try an alternative tag style + if version.startswith("v"): + version = version.lstrip("v") + else: + version = f"v{version}" + + if release := _try_get_release(version): + return release + + raise UnknownVersionError(original_version, repo_name) + + def _get_release(self, org_name: str, repo_name: str, version: str) -> Dict: + return self._get(f"repos/{org_name}/{repo_name}/releases/tags/{version}") + + def get_repo(self, org_name: str, repo_name: str) -> Dict: + repo_path = f"{org_name}/{repo_name}" + if repo_path not in self._repo_cache: + try: + self._repo_cache[repo_path] = self._get_repo(org_name, repo_name) + return self._repo_cache[repo_path] + except Exception as err: + raise ProjectError(f"Unknown repository '{repo_path}'") from err + + else: + return self._repo_cache[repo_path] + + def _get_repo(self, org_name: str, repo_name: str) -> Dict: + return self._get(f"repos/{org_name}/{repo_name}") + + def get_latest_release(self, org_name: str, repo_name: str) -> Dict: + return self._get(f"repos/{org_name}/{repo_name}/releases/latest") + + def get_organization(self, org_name: str) -> Dict: + return self._get(f"orgs/{org_name}") + + def clone_repo( + self, + org_name: str, + repo_name: str, + target_path: Union[str, Path], + branch: Optional[str] = None, + scheme: str = "http", + ): + repo = self.get_repo(org_name, repo_name) + branch = branch or repo["default_branch"] + logger.info(f"Cloning branch '{branch}' from '{repo['name']}'.") + url = repo["git_url"] + + if "ssh" in scheme or "git" in scheme: + url = url.replace("git://github.com/", "git@github.com:") + elif "http" in scheme: + url = url.replace("git://", "https://") + else: + raise ValueError(f"Scheme '{scheme}' not supported.") + + target_path = Path(target_path) + if target_path.exists(): + # Else, cloning will fail! + target_path = target_path / repo_name + + self.git.clone(url, branch=branch, target_path=target_path) + + def download_package( + self, org_name: str, repo_name: str, version: str, target_path: Union[Path, str] + ): + target_path = Path(target_path) # Handles str + if not target_path or not target_path.is_dir(): + raise ValueError(f"'target_path' must be a valid directory (got '{target_path}').") + + release = self.get_release(org_name, repo_name, version) + description = f"Downloading {org_name}/{repo_name}@{version}" + release_content = stream_response( + release["zipball_url"], progress_bar_description=description + ) + + # Use temporary path to isolate a package when unzipping + with tempfile.TemporaryDirectory() as tmp: + temp_path = Path(tmp) + with zipfile.ZipFile(BytesIO(release_content)) as zf: + zf.extractall(temp_path) + + # Copy the directory contents into the target path. + downloaded_packages = [f for f in temp_path.iterdir() if f.is_dir()] + if len(downloaded_packages) < 1: + raise CompilerError(f"Unable to download package at '{org_name}/{repo_name}'.") + + package_path = temp_path / downloaded_packages[0] + for source_file in package_path.iterdir(): + shutil.move(str(source_file), str(target_path)) + + def _get(self, url: str) -> Any: + return self._request("GET", url) + + def _request(self, method: str, url: str, **kwargs) -> Any: + url = f"{self.API_URL_PREFIX}/{url}" + response = self.__session.request(method, url, **kwargs) + response.raise_for_status() + return response.json() + + +github_client = _GithubClient() diff --git a/src/ape/utils/github.py b/src/ape/utils/github.py deleted file mode 100644 index bee3de0567..0000000000 --- a/src/ape/utils/github.py +++ /dev/null @@ -1,226 +0,0 @@ -import os -import shutil -import subprocess -import tempfile -import zipfile -from io import BytesIO -from pathlib import Path -from typing import Dict, Optional, Set - -from github import Github, UnknownObjectException -from github.Auth import Token as GithubToken -from github.GitRelease import GitRelease -from github.Organization import Organization -from github.Repository import Repository as GithubRepository -from urllib3.util.retry import Retry - -from ape.exceptions import CompilerError, ProjectError, UnknownVersionError -from ape.logging import logger -from ape.utils.misc import USER_AGENT, cached_property, stream_response - - -class GitProcessWrapper: - @cached_property - def git(self) -> str: - if path := shutil.which("git"): - return path - - raise ProjectError("`git` not installed.") - - def clone(self, url: str, target_path: Optional[Path] = None, branch: Optional[str] = None): - command = [self.git, "-c", "advice.detachedHead=false", "clone", url] - - if target_path: - command.append(str(target_path)) - - if branch is not None: - command.extend(("--branch", branch)) - - logger.debug(f"Running git command: '{' '.join(command)}'") - result = subprocess.call(command) - if result != 0: - fail_msg = f"`git clone` command failed for '{url}'." - - if branch and not branch.startswith("v"): - # Often times, `v` is required for tags. - try: - self.clone(url, target_path, branch=f"v{branch}") - except Exception: - raise ProjectError(fail_msg) - - # Succeeded when prefixing `v`. - return - - # Failed and we don't really know why. - # Shouldn't really happen. - # User will have to run command separately to debug. - raise ProjectError(fail_msg) - - -class GithubClient: - """ - An HTTP client for the Github API. - """ - - TOKEN_KEY = "GITHUB_ACCESS_TOKEN" - _repo_cache: Dict[str, GithubRepository] = {} - git: GitProcessWrapper = GitProcessWrapper() - - def __init__(self): - token = os.environ[self.TOKEN_KEY] if self.TOKEN_KEY in os.environ else None - auth = GithubToken(token) if token else None - retry = Retry(total=10, backoff_factor=1.0, status_forcelist=[403]) - self._client = Github(auth=auth, user_agent=USER_AGENT, retry=retry) - - @cached_property - def ape_org(self) -> Organization: - """ - The ``ApeWorX`` organization on ``Github`` (https://github.com/ApeWorX). - """ - return self.get_organization("ApeWorX") - - @cached_property - def available_plugins(self) -> Set[str]: - """ - The available ``ape`` plugins, found from looking at the ``ApeWorX`` Github organization. - - Returns: - Set[str]: The plugin names as ``'ape_plugin_name'`` (module-like). - """ - return { - repo.name.replace("-", "_") - for repo in self.ape_org.get_repos() - if not repo.private and repo.name.startswith("ape-") - } - - def get_release(self, repo_path: str, version: str) -> GitRelease: - """ - Get a release from Github. - - Args: - repo_path (str): The path on Github to the repository, - e.g. ``OpenZeppelin/openzeppelin-contracts``. - version (str): The version of the release to get. Pass in ``"latest"`` - to get the latest release. - - Returns: - github.GitRelease.GitRelease - """ - repo = self.get_repo(repo_path) - - if version == "latest": - return repo.get_latest_release() - - def _try_get_release(vers): - try: - return repo.get_release(vers) - except UnknownObjectException: - return None - - if release := _try_get_release(version): - return release - else: - original_version = str(version) - # Try an alternative tag style - if version.startswith("v"): - version = version.lstrip("v") - else: - version = f"v{version}" - - if release := _try_get_release(version): - return release - - raise UnknownVersionError(original_version, repo.name) - - def get_repo(self, repo_path: str) -> GithubRepository: - """ - Get a repository from GitHub. - - Args: - repo_path (str): The path to the repository, such as - ``OpenZeppelin/openzeppelin-contracts``. - - Returns: - github.Repository.Repository - """ - - if repo_path not in self._repo_cache: - try: - self._repo_cache[repo_path] = self._client.get_repo(repo_path) - return self._repo_cache[repo_path] - except UnknownObjectException as err: - raise ProjectError(f"Unknown repository '{repo_path}'") from err - - else: - return self._repo_cache[repo_path] - - def get_organization(self, name: str) -> Organization: - return self._client.get_organization(name) - - def clone_repo( - self, - repo_path: str, - target_path: Path, - branch: Optional[str] = None, - scheme: str = "http", - ): - """ - Clone a repository from Github. - - Args: - repo_path (str): The path on Github to the repository, - e.g. ``OpenZeppelin/openzeppelin-contracts``. - target_path (Path): The local path to store the repo. - branch (Optional[str]): The branch to clone. Defaults to the default branch. - scheme (str): The git scheme to use when cloning. Defaults to `ssh`. - """ - - repo = self.get_repo(repo_path) - branch = branch or repo.default_branch - logger.info(f"Cloning branch '{branch}' from '{repo.name}'.") - url = repo.git_url - - if "ssh" in scheme or "git" in scheme: - url = url.replace("git://github.com/", "git@github.com:") - elif "http" in scheme: - url = url.replace("git://", "https://") - else: - raise ValueError(f"Scheme '{scheme}' not supported.") - - self.git.clone(url, branch=branch, target_path=target_path) - - def download_package(self, repo_path: str, version: str, target_path: Path): - """ - Download a package from Github. This is useful for managing project dependencies. - - Args: - repo_path (str): The path on ``Github`` to the repository, - such as ``OpenZeppelin/openzeppelin-contracts``. - version (str): Number to specify update types - to the downloaded package. - target_path (path): A path in your local filesystem to save the downloaded package. - """ - if not target_path or not target_path.is_dir(): - raise ValueError(f"'target_path' must be a valid directory (got '{target_path}').") - - release = self.get_release(repo_path, version) - description = f"Downloading {repo_path}@{version}" - release_content = stream_response(release.zipball_url, progress_bar_description=description) - - # Use temporary path to isolate a package when unzipping - with tempfile.TemporaryDirectory() as tmp: - temp_path = Path(tmp) - with zipfile.ZipFile(BytesIO(release_content)) as zf: - zf.extractall(temp_path) - - # Copy the directory contents into the target path. - downloaded_packages = [f for f in temp_path.iterdir() if f.is_dir()] - if len(downloaded_packages) < 1: - raise CompilerError(f"Unable to download package at '{repo_path}'.") - - package_path = temp_path / downloaded_packages[0] - for source_file in package_path.iterdir(): - shutil.move(str(source_file), str(target_path)) - - -github_client = GithubClient() diff --git a/src/ape_init/_cli.py b/src/ape_init/_cli.py index 788b7ebd99..dd68d935b5 100644 --- a/src/ape_init/_cli.py +++ b/src/ape_init/_cli.py @@ -5,7 +5,7 @@ from ape.cli import ape_cli_context from ape.managers.config import CONFIG_FILE_NAME -from ape.utils import github_client +from ape.utils._github import github_client GITIGNORE_CONTENT = """ # Ape stuff diff --git a/src/ape_plugins/_cli.py b/src/ape_plugins/_cli.py index 3a28d4ff3b..596117d70b 100644 --- a/src/ape_plugins/_cli.py +++ b/src/ape_plugins/_cli.py @@ -16,8 +16,7 @@ PluginType, ape_version, ) -from ape.utils import load_config -from ape.utils.misc import _get_distributions +from ape.utils.misc import _get_distributions, load_config @click.group(short_help="Manage ape plugins") diff --git a/tests/functional/utils/test_github.py b/tests/functional/utils/test_github.py index eaecc0db0a..681dbac322 100644 --- a/tests/functional/utils/test_github.py +++ b/tests/functional/utils/test_github.py @@ -2,19 +2,20 @@ from pathlib import Path import pytest -from github import UnknownObjectException from requests.exceptions import ConnectTimeout -from ape.utils.github import GithubClient +from ape.utils._github import _GithubClient -REPO_PATH = "test/path" +ORG_NAME = "test" +REPO_NAME = "path" +REPO_PATH = f"{ORG_NAME}/{REPO_NAME}" @pytest.fixture(autouse=True) -def clear_repo_cache(github_client_with_mocks): +def clear_repo_cache(github_client): def clear(): - if REPO_PATH in github_client_with_mocks._repo_cache: - del github_client_with_mocks._repo_cache[REPO_PATH] + if REPO_PATH in github_client._repo_cache: + del github_client._repo_cache[REPO_PATH] clear() yield @@ -22,37 +23,31 @@ def clear(): @pytest.fixture -def mock_client(mocker): - return mocker.MagicMock() - - -@pytest.fixture -def mock_repo(mocker): +def mock_session(mocker): return mocker.MagicMock() @pytest.fixture def mock_release(mocker): - return mocker.MagicMock() + release = mocker.MagicMock() + release.json.return_value = {"name": REPO_NAME} + return release @pytest.fixture -def github_client_with_mocks(mock_client, mock_repo): - client = GithubClient() - mock_client.get_repo.return_value = mock_repo - client._client = mock_client - return client +def github_client(mock_session): + return _GithubClient(session=mock_session) class TestGithubClient: def test_clone_repo(self, mocker): # NOTE: this test actually clones the repo. - client = GithubClient() - git_patch = mocker.patch("ape.utils.github.subprocess.call") + client = _GithubClient() + git_patch = mocker.patch("ape.utils._github.subprocess.call") git_patch.return_value = 0 with tempfile.TemporaryDirectory() as temp_dir: try: - client.clone_repo("dapphub/ds-test", Path(temp_dir), branch="master") + client.clone_repo("dapphub", "ds-test", Path(temp_dir), branch="master") except ConnectTimeout: pytest.xfail("Internet required to run this test.") @@ -66,42 +61,33 @@ def test_clone_repo(self, mocker): assert cmd[6] == "--branch" assert cmd[7] == "master" - def test_get_release(self, github_client_with_mocks, mock_repo): - github_client_with_mocks.get_release(REPO_PATH, "0.1.0") - - # Test that we used the given tag. - mock_repo.get_release.assert_called_once_with("0.1.0") - - # Ensure that it uses the repo cache the second time - github_client_with_mocks.get_release(REPO_PATH, "0.1.0") - assert github_client_with_mocks._client.get_repo.call_count == 1 - - def test_get_release_when_tag_fails_tries_with_v( - self, mock_release, github_client_with_mocks, mock_repo - ): - # This test makes sure that if we try to get a release and the `v` is not needed, - # it will try again without the `v`. - def side_effect(version): - if version.startswith("v"): - raise UnknownObjectException(400, {}, {}) - - return mock_release - - mock_repo.get_release.side_effect = side_effect - actual = github_client_with_mocks.get_release(REPO_PATH, "v0.1.0") - assert actual == mock_release - - def test_get_release_when_tag_fails_tries_without_v( - self, mock_release, github_client_with_mocks, mock_repo - ): - # This test makes sure that if we try to get a release and the `v` is needed, - # it will try again with the `v`. - def side_effect(version): - if not version.startswith("v"): - raise UnknownObjectException(400, {}, {}) + def test_get_release(self, github_client, mock_session): + version = "0.1.0" + github_client.get_release(ORG_NAME, REPO_NAME, "0.1.0") + base_uri = f"https://api.github.com/repos/{ORG_NAME}/{REPO_NAME}/releases/tags" + expected_uri = f"{base_uri}/{version}" + assert mock_session.request.call_args[0] == ("GET", expected_uri) + + @pytest.mark.parametrize("version", ("0.1.0", "v0.1.0")) + def test_get_release_retry(self, mock_release, github_client, mock_session, version): + """ + Ensure after failing to get a release, we re-attempt with + out a v-prefix. + """ + opposite = version.lstrip("v") if version.startswith("v") else f"v{version}" + + def side_effect(method, uri, *arg, **kwargs): + _version = uri.split("/")[-1] + if _version == version: + # Force it to try the opposite. + raise ValueError() return mock_release - mock_repo.get_release.side_effect = side_effect - actual = github_client_with_mocks.get_release(REPO_PATH, "0.1.0") - assert actual == mock_release + mock_session.request.side_effect = side_effect + actual = github_client.get_release(ORG_NAME, REPO_NAME, version) + assert actual["name"] == REPO_NAME + calls = mock_session.request.call_args_list[-2:] + expected_uri = "https://api.github.com/repos/test/path/releases/tags" + assert calls[0][0] == ("GET", f"{expected_uri}/{version}") + assert calls[1][0] == ("GET", f"{expected_uri}/{opposite}")