Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

make github connector idempotent #417

Merged
merged 6 commits into from
Feb 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions garden_ai/model_connectors/github_conn.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from git import Repo # type: ignore
from git.repo.fun import is_git_dir
from garden_ai.mlmodel import ModelMetadata
from garden_ai.utils.misc import trackcalls
from requests.exceptions import HTTPError
import os
import sys
import requests

Expand Down Expand Up @@ -35,8 +35,19 @@ def __init__(

@trackcalls
def stage(self) -> str:
if not os.path.exists(self.local_dir):
os.mkdir(self.local_dir)

if is_git_dir(f"{self.local_dir}/.git"):
# double check the existing repo in local_dir refers to the same
# repo as this connector before pulling
found_repo = Repo(self.local_dir)
if self.repo_url not in found_repo.remotes.origin.url:
raise ValueError(
f"Failed to clone {self.repo_url} to {self.local_dir} "
f"({found_repo.remotes.origin.url} already cloned here)."
)
else:
found_repo.remotes.origin.pull(self.branch)
return self.local_dir

Repo.clone_from(f"{self.repo_url}.git", self.local_dir, branch=self.branch)

Expand Down
8 changes: 6 additions & 2 deletions garden_ai/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,12 @@ def clean_identifier(name: str) -> str:
def trackcalls(func):
    """Decorate *func* so callers can tell whether it has been invoked.

    The returned wrapper exposes a boolean ``has_been_called`` attribute
    that starts out ``False`` and becomes ``True`` once the first call to
    the wrapped function has finished (whether it returned or raised).

    Args:
        func: the callable to track.

    Returns:
        A wrapper with the same signature/metadata as *func* (via
        ``functools.wraps``) plus the ``has_been_called`` attribute.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # note: attribute set only after func completes execution
        # so that func itself can determine if it's been called
        try:
            return func(*args, **kwargs)
        finally:
            # Runs on both normal return and exception, so a failed first
            # call still counts as "called".
            wrapper.has_been_called = True

    wrapper.has_been_called = False
    return wrapper
58 changes: 58 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
EntrypointMetadata,
)
from garden_ai.model_connectors import HFConnector, GitHubConnector
from unittest.mock import MagicMock


def test_create_empty_garden(garden_client):
Expand Down Expand Up @@ -108,3 +109,60 @@ def my_step():

assert my_step._garden_step.function_name == "my_step"
assert my_step._garden_step.description == "My nifty step"


def test_GHconnector_idempotent(mocker):
    """Calling ``stage()`` repeatedly clones once, then pulls on later calls."""
    # Filesystem: pretend local_dir already exists so nothing is created.
    mocker.patch("os.path.exists", return_value=True)
    mocker.patch("os.mkdir")

    # Stand-in for the Repo object found in local_dir on later calls. Its
    # origin URL contains the connector's repo_url, so the "same repo"
    # check in stage() passes.
    fake_repo = MagicMock()
    fake_repo.remotes.origin.pull = MagicMock()
    fake_repo.remotes.origin.url = "https://github.com/fake/repo.git"

    # Replace the Repo class as seen by the connector module.
    repo_cls = mocker.patch("garden_ai.model_connectors.github_conn.Repo")
    repo_cls.return_value = fake_repo

    # Track clone_from separately so no real clone happens.
    clone_spy = mocker.patch(
        "garden_ai.model_connectors.github_conn.Repo.clone_from"
    )

    # is_git_dir answers, one per stage() call: the first call sees no
    # git dir (clone path); the next two see one (pull path).
    git_dir_answers = [False, True, True]
    mocker.patch(
        "garden_ai.model_connectors.github_conn.is_git_dir",
        side_effect=git_dir_answers,
    )

    # enable_imports=False just bc mocking sys.path.append was hard
    connector = GitHubConnector(
        repo_url="https://github.com/fake/repo",
        local_dir="gh_model",
        branch="main",
        enable_imports=False,
    )

    # First stage(): not a git dir yet, so it must clone.
    connector.stage()
    clone_spy.assert_called_once_with(
        "https://github.com/fake/repo.git", "gh_model", branch="main"
    )

    # Forget the first clone so later assertions only see subsequent calls.
    clone_spy.reset_mock()

    # Two more stage() calls: both should pull rather than clone.
    for _ in range(2):
        connector.stage()

    clone_spy.assert_not_called()
    pull_calls = fake_repo.remotes.origin.pull.call_count
    assert pull_calls == 2, "Pull should be called on subsequent calls"
Loading