diff --git a/garden_ai/model_connectors/github_conn.py b/garden_ai/model_connectors/github_conn.py
index ff5e41b1..59c67d2a 100644
--- a/garden_ai/model_connectors/github_conn.py
+++ b/garden_ai/model_connectors/github_conn.py
@@ -1,8 +1,8 @@
 from git import Repo  # type: ignore
+from git.repo.fun import is_git_dir
 from garden_ai.mlmodel import ModelMetadata
 from garden_ai.utils.misc import trackcalls
 from requests.exceptions import HTTPError
-import os
 import sys
 
 import requests
@@ -35,8 +35,20 @@ def __init__(
 
     @trackcalls
     def stage(self) -> str:
-        if not os.path.exists(self.local_dir):
-            os.mkdir(self.local_dir)
+
+        if is_git_dir(f"{self.local_dir}/.git"):
+            # double check the existing repo in local_dir refers to the same
+            # repo as this connector before pulling. Compare whole URLs: a
+            # substring test would also accept e.g. ".../repo-fork.git"
+            found_repo = Repo(self.local_dir)
+            origin_url = found_repo.remotes.origin.url
+            if origin_url.rstrip("/") not in (self.repo_url, f"{self.repo_url}.git"):
+                raise ValueError(
+                    f"Failed to clone {self.repo_url} to {self.local_dir} "
+                    f"({origin_url} already cloned here)."
+                )
+            found_repo.remotes.origin.pull(self.branch)
+            return self.local_dir
 
         Repo.clone_from(f"{self.repo_url}.git", self.local_dir, branch=self.branch)
 
diff --git a/garden_ai/utils/misc.py b/garden_ai/utils/misc.py
index 69d534c9..1714cd08 100644
--- a/garden_ai/utils/misc.py
+++ b/garden_ai/utils/misc.py
@@ -84,8 +84,12 @@ def clean_identifier(name: str) -> str:
 def trackcalls(func):
     @functools.wraps(func)
     def wrapper(*args, **kwargs):
-        wrapper.has_been_called = True
-        return func(*args, **kwargs)
+        # note: attribute is set only once func finishes (even if it raises),
+        # so that func itself can determine if it's been called before
+        try:
+            return func(*args, **kwargs)
+        finally:
+            wrapper.has_been_called = True
 
     wrapper.has_been_called = False
     return wrapper
diff --git a/tests/test_models.py b/tests/test_models.py
index ae602e6b..218c7041 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -9,6 +9,7 @@
     EntrypointMetadata,
 )
 from garden_ai.model_connectors import HFConnector, GitHubConnector
+from unittest.mock import MagicMock
 
 
 def test_create_empty_garden(garden_client):
@@ -108,3 +109,60 @@ def my_step():
 
     assert my_step._garden_step.function_name == "my_step"
     assert my_step._garden_step.description == "My nifty step"
+
+
+def test_GHconnector_idempotent(mocker):
+    # Mock os.path.exists and os.mkdir
+    mocker.patch("os.path.exists", return_value=True)
+    mocker.patch("os.mkdir")
+
+    # Mock Repo to simulate both clone and pull scenarios
+    mock_repo_class = mocker.patch("garden_ai.model_connectors.github_conn.Repo")
+    mock_repo_instance = MagicMock()
+    mock_repo_class.return_value = mock_repo_instance
+    mock_repo_instance.remotes.origin.pull = MagicMock()
+    # Set up the mock to return a URL that matches the connector's repo_url
+    mock_repo_instance.remotes.origin.url = "https://github.com/fake/repo.git"
+
+    # Mock Repo.clone_from method to track calls without actually cloning
+    mock_clone_from = mocker.patch(
+        "garden_ai.model_connectors.github_conn.Repo.clone_from"
+    )
+
+    # Mock is_git_dir to control the flow in the stage method
+    mocker.patch(
+        "garden_ai.model_connectors.github_conn.is_git_dir",
+        side_effect=[
+            False,
+            True,
+            True,
+        ],  # First call: not a git dir, then it is a git dir
+    )
+
+    # enable_imports=False just bc mocking sys.path.append was hard
+    connector = GitHubConnector(
+        repo_url="https://github.com/fake/repo",
+        local_dir="gh_model",
+        branch="main",
+        enable_imports=False,
+    )
+
+    # First call should trigger clone since it's not a git dir yet
+    connector.stage()
+    mock_clone_from.assert_called_once_with(
+        "https://github.com/fake/repo.git", "gh_model", branch="main"
+    )
+
+    # Reset mock to test idempotency on subsequent calls
+    mock_clone_from.reset_mock()
+
+    # Subsequent calls should not trigger clone_from again, but should pull
+    connector.stage()
+    connector.stage()
+
+    # Assert that Repo.clone_from was not called again after the first time
+    mock_clone_from.assert_not_called()
+    # Assert that pull was called on subsequent invocations
+    assert (
+        mock_repo_instance.remotes.origin.pull.call_count == 2
+    ), "Pull should be called on subsequent calls"