diff --git a/pillar_tool/git/repository.py b/pillar_tool/git/repository.py index bde0b11..8931239 100644 --- a/pillar_tool/git/repository.py +++ b/pillar_tool/git/repository.py @@ -1,9 +1,10 @@ import os +from typing import Optional import pygit2 -from pygit2 import RemoteCallbacks, CredentialType, Username, UserPass, Keypair -from pygit2.callbacks import _Credentials -from pygit2.enums import BranchType, FetchPrune, CheckoutStrategy as CS +from pygit2 import RemoteCallbacks, CredentialType, Username, UserPass, Keypair, DiffFile +from pygit2.callbacks import _Credentials, CheckoutCallbacks +from pygit2.enums import BranchType, FetchPrune, CheckoutStrategy as CS, CheckoutNotify from pillar_tool.util import Config @@ -12,6 +13,7 @@ class RepoCallbacks(RemoteCallbacks): def __init__(self, config: Config): super().__init__() self.config = config + self.checkout_notify_flags = None def credentials( self, @@ -20,10 +22,9 @@ class RepoCallbacks(RemoteCallbacks): allowed_types: CredentialType, ) -> Username | UserPass | Keypair: # compute allowed methods - allowed = [ 2**i for i,v in enumerate(reversed(bin(15)[2:])) if int(v) ] + allowed = [ 2**i for i,v in enumerate(reversed(bin(allowed_types)[2:])) if int(v) ] if CredentialType.SSH_KEY.value in allowed: - print("cred ssh_key") return Keypair( username=self.config.git.state_repo_user, privkey=self.config.git.state_repo_keyfile, @@ -31,13 +32,23 @@ class RepoCallbacks(RemoteCallbacks): passphrase=None ) elif CredentialType.USERNAME.value in allowed: - print("cred username") return Username(self.config.git.state_repo_user) print(f"The remote requested invalid credentials: {allowed}") raise RuntimeError(f"The remote requested invalid credentials: {allowed}") +class COCallbacks(CheckoutCallbacks): + def checkout_notify( + self, + why: CheckoutNotify, + path: str, + baseline: Optional[DiffFile], + target: Optional[DiffFile], + workdir: Optional[DiffFile], + ) -> None: + print(f"why: {why}") + def checkout_remote_branch(config: Config, branch_name: str) -> None: @@ -56,7 +67,8 @@ def checkout_remote_branch(config: Config, branch_name: str) -> None: or the specified branch does not exist. """ # create an instance of the RepositoryCallback class - cbs = RepoCallbacks(config) + cbs_remote = RepoCallbacks(config) + cbs_checkout = COCallbacks() # check if the repository actually exists if not os.path.isdir(config.git.state_repo_path): @@ -64,34 +76,46 @@ def checkout_remote_branch(config: Config, branch_name: str) -> None: try: print("cloning state repo") os.makedirs(os.path.dirname(config.git.state_repo_path), mode=0o700, exist_ok=True) - repository = pygit2.clone_repository(config.git.state_repo_remote, config.git.state_repo_path, callbacks=cbs, depth=1) + repository = pygit2.clone_repository(config.git.state_repo_remote, config.git.state_repo_path, callbacks=cbs_remote, depth=1) except Exception as e: print(f"Failed to clone state repo: {e}") raise ValueError(f"Unable to clone the states repository: {e}") else: # directory exists, so attempt to open the repository try: + print("opening state repo") repository = pygit2.Repository(config.git.state_repo_path) - except Exception: - raise ValueError(f"State repo at {config.git.state_repo_path} cannot be opened") + except Exception as e: + print(f"State repo at {config.git.state_repo_path} cannot be opened: {e}") + raise ValueError(f"State repo at {config.git.state_repo_path} cannot be opened: {e}") # check whether this repository has a remote named origin # this only needs to happen when the repo has not just been cloned - if "origin" not in repository.remotes: + if "origin" not in repository.remotes.names(): + print(f"No remote named origin in repo at {config.git.state_repo_path}") + print(list(repository.remotes.names())) raise ValueError(f"No remote named origin in repo at {config.git.state_repo_path}") else: - repository.remotes["origin"].fetch(prune=FetchPrune.PRUNE, depth=1, callbacks=cbs) + try: + repository.remotes["origin"].fetch(prune=FetchPrune.PRUNE, depth=1, callbacks=cbs_remote) + except Exception as e: + print(f"Failed to fetch origin remote: {e}") + raise ValueError(f"Unable to fetch origin remote: {e}") # check if the requested branch exists try: + print("checking for branch") branch_ref = repository.lookup_branch(f'origin/{branch_name}', BranchType.REMOTE) except KeyError: + print(f"Branch '{branch_name}' does not exist in the repository.") raise ValueError(f"Branch '{branch_name}' does not exist in the repository.") try: + print("checking out branch") # checkout the remote branch with force # this should be done like this, since there should never be any change made in this clone of the repository - repository.checkout(branch_ref, callbacks=cbs, strategy=CS.FORCE | CS.RECREATE_MISSING | CS.REMOVE_UNTRACKED) + repository.checkout(branch_ref, callbacks=cbs_checkout, strategy=CS.FORCE | CS.RECREATE_MISSING | CS.REMOVE_UNTRACKED) except Exception as exc: + print(f"Failed to checkout branch: {exc}") raise ValueError(f"Failed to checkout branch: {exc}")