diff --git a/mybox/package/clone.py b/mybox/package/clone.py index a073794..4771235 100644 --- a/mybox/package/clone.py +++ b/mybox/package/clone.py @@ -6,6 +6,8 @@ from ..utils import RunArg, async_cached, repo_version from .destination import Destination +DEFAULT_REMOTE = "origin" + class Clone(Destination): repo: str = Field(..., alias="clone") @@ -40,6 +42,9 @@ def remote(self): async def get_remote_version(self) -> str: return await repo_version(self.remote) + async def branch_name(self, ref: str) -> str: + return await self.run_git_output("rev-parse", "--abbrev-ref", ref) + async def install(self, *, tracker: Tracker) -> None: destination = await self.destination() @@ -48,15 +53,16 @@ async def install(self, *, tracker: Tracker) -> None: if not await self.directory_exists(): await self.driver.run("git", "clone", self.remote, destination) - await self.run_git("remote", "set-url", "origin", self.remote) - await self.run_git("fetch") - default_branch = ( - await self.run_git_output("rev-parse", "--abbrev-ref", "origin/HEAD") - ).split("/")[1] - current_branch = await self.run_git_output("rev-parse", "--abbrev-ref", "HEAD") + await self.run_git("remote", "set-url", DEFAULT_REMOTE, self.remote) + + default_remote_branch = await self.branch_name(f"{DEFAULT_REMOTE}/HEAD") + default_branch = default_remote_branch.split("/")[1] + await self.run_git("fetch", "--no-tags", DEFAULT_REMOTE, default_branch) + + current_branch = await self.branch_name("HEAD") if current_branch != default_branch: await self.run_git("switch", default_branch) - await self.run_git("reset", "--hard", f"origin/{default_branch}") + await self.run_git("reset", "--hard", default_remote_branch) tracker.track(destination, root=self.root)