From 840654ea21a1aaa0eef828e4e379cc223c6ef681 Mon Sep 17 00:00:00 2001 From: Masaya Suzuki Date: Wed, 8 May 2024 18:23:07 -0700 Subject: [PATCH] Add new sequencer object This is similar to git's sequencer, but it's for re-stacking branches. https://github.com/git/git/blob/master/sequencer.h This object is responsible for re-stacking branches based on the planned operations. --- internal/git/git.go | 15 +- internal/sequencer/planner/planner.go | 65 ++++++ internal/sequencer/planner/targets.go | 93 +++++++++ internal/sequencer/sequencer.go | 289 ++++++++++++++++++++++++++ 4 files changed, 461 insertions(+), 1 deletion(-) create mode 100644 internal/sequencer/planner/planner.go create mode 100644 internal/sequencer/planner/targets.go create mode 100644 internal/sequencer/sequencer.go diff --git a/internal/git/git.go b/internal/git/git.go index b3677925..7ec8faf9 100644 --- a/internal/git/git.go +++ b/internal/git/git.go @@ -12,6 +12,7 @@ import ( "time" "emperror.dev/errors" + "github.com/go-git/go-git/v5" "github.com/sirupsen/logrus" giturls "github.com/whilp/git-urls" ) @@ -21,16 +22,24 @@ var ErrRemoteNotFound = errors.Sentinel("this repository doesn't have a remote o type Repo struct { repoDir string gitDir string + gitRepo *git.Repository log logrus.FieldLogger } func OpenRepo(repoDir string, gitDir string) (*Repo, error) { + repo, err := git.PlainOpenWithOptions(repoDir, &git.PlainOpenOptions{ + DetectDotGit: true, + EnableDotGitCommonDir: true, + }) + if err != nil { + return nil, errors.Errorf("failed to open git repo: %v", err) + } r := &Repo{ repoDir, gitDir, + repo, logrus.WithFields(logrus.Fields{"repo": filepath.Base(repoDir)}), } - return r, nil } @@ -46,6 +55,10 @@ func (r *Repo) AvDir() string { return filepath.Join(r.GitDir(), "av") } +func (r *Repo) GoGitRepo() *git.Repository { + return r.gitRepo +} + func (r *Repo) AvTmpDir() string { dir := filepath.Join(r.AvDir(), "tmp") // Try to create the directory, but swallow the error since it will diff --git a/internal/sequencer/planner/planner.go b/internal/sequencer/planner/planner.go new file mode 100644 index 00000000..9833e859 --- /dev/null +++ b/internal/sequencer/planner/planner.go @@ -0,0 +1,65 @@ +package planner + +import ( + "github.com/aviator-co/av/internal/git" + "github.com/aviator-co/av/internal/meta" + "github.com/aviator-co/av/internal/sequencer" + "github.com/go-git/go-git/v5/plumbing" +) + +func PlanForRestack(tx meta.ReadTx, repo *git.Repo, targetBranches []plumbing.ReferenceName) ([]sequencer.RestackOp, error) { + var ret []sequencer.RestackOp + for _, br := range targetBranches { + avbr, _ := tx.Branch(br.Short()) + if avbr.MergeCommit != "" { + // Skip rebasing branches that have merge commits. + continue + } + ret = append(ret, sequencer.RestackOp{ + Name: br, + NewParent: plumbing.NewBranchReferenceName(avbr.Parent.Name), + NewParentIsTrunk: avbr.Parent.Trunk, + }) + } + return ret, nil +} + +func PlanForSync(tx meta.ReadTx, repo *git.Repo, targetBranches []plumbing.ReferenceName, syncToTrunkInsteadOfMergeCommit bool) ([]sequencer.RestackOp, error) { + var ret []sequencer.RestackOp + for _, br := range targetBranches { + avbr, _ := tx.Branch(br.Short()) + if avbr.MergeCommit != "" { + // Skip rebasing branches that have merge commits. + continue + } + if !avbr.Parent.Trunk { + // Check if the parent branch is merged. + avpbr, _ := tx.Branch(avbr.Parent.Name) + if avpbr.MergeCommit != "" { + // The parent is merged. Sync to either trunk or merge commit. + trunk, _ := meta.Trunk(tx, br.Short()) + var newParentHash plumbing.Hash + if syncToTrunkInsteadOfMergeCommit { + // By setting this to ZeroHash, the sequencer will sync to + // the remote tracking branch. + newParentHash = plumbing.ZeroHash + } else { + newParentHash = plumbing.NewHash(avpbr.MergeCommit) + } + ret = append(ret, sequencer.RestackOp{ + Name: br, + NewParent: plumbing.NewBranchReferenceName(trunk), + NewParentIsTrunk: true, + NewParentHash: newParentHash, + }) + continue + } + } + ret = append(ret, sequencer.RestackOp{ + Name: br, + NewParent: plumbing.NewBranchReferenceName(avbr.Parent.Name), + NewParentIsTrunk: avbr.Parent.Trunk, + }) + } + return ret, nil +} diff --git a/internal/sequencer/planner/targets.go b/internal/sequencer/planner/targets.go new file mode 100644 index 00000000..45ff1165 --- /dev/null +++ b/internal/sequencer/planner/targets.go @@ -0,0 +1,93 @@ +package planner + +import ( + "github.com/aviator-co/av/internal/git" + "github.com/aviator-co/av/internal/meta" + "github.com/go-git/go-git/v5/plumbing" +) + +type TargetBranchMode int + +const ( + // Target all branches in the repository. + AllBranches TargetBranchMode = iota + // The current branch and all its predecessors. + CurrentAndParents + // The current branch and all its successors. + CurrentAndChildren + // Branches of the current stack. (The stack root and all its successors.) + CurrentStack +) + +// GetTargetBranches returns the branches to be restacked. +// +// If `includeStackRoots` is true, the stack root branches (the immediate children of the trunk +// branches) are included in the result. +func GetTargetBranches(tx meta.ReadTx, repo *git.Repo, includeStackRoots bool, mode TargetBranchMode) ([]plumbing.ReferenceName, error) { + var ret []plumbing.ReferenceName + if mode == AllBranches { + for _, br := range tx.AllBranches() { + if !br.IsStackRoot() { + continue + } + if includeStackRoots { + ret = append(ret, plumbing.NewBranchReferenceName(br.Name)) + } + for _, n := range meta.SubsequentBranches(tx, br.Name) { + ret = append(ret, plumbing.NewBranchReferenceName(n)) + } + } + return ret, nil + } + if mode == CurrentAndParents { + curr, err := repo.CurrentBranchName() + if err != nil { + return nil, err + } + prevs, err := meta.PreviousBranches(tx, curr) + if err != nil { + return nil, err + } + for _, n := range prevs { + br, _ := tx.Branch(n) + if !br.IsStackRoot() || includeStackRoots { + ret = append(ret, plumbing.NewBranchReferenceName(n)) + } + } + br, _ := tx.Branch(curr) + if !br.IsStackRoot() || includeStackRoots { + ret = append(ret, plumbing.NewBranchReferenceName(curr)) + } + return ret, nil + } + if mode == CurrentAndChildren { + curr, err := repo.CurrentBranchName() + if err != nil { + return nil, err + } + br, _ := tx.Branch(curr) + if !br.IsStackRoot() || includeStackRoots { + ret = append(ret, plumbing.NewBranchReferenceName(curr)) + } + // The rest of the branches cannot be a stack root. + for _, n := range meta.SubsequentBranches(tx, curr) { + ret = append(ret, plumbing.NewBranchReferenceName(n)) + } + return ret, nil + } + curr, err := repo.CurrentBranchName() + if err != nil { + return nil, err + } + brs, err := meta.StackBranches(tx, curr) + if err != nil { + return nil, err + } + for _, n := range brs { + br, _ := tx.Branch(n) + if !br.IsStackRoot() || includeStackRoots { + ret = append(ret, plumbing.NewBranchReferenceName(n)) + } + } + return ret, nil +} diff --git a/internal/sequencer/sequencer.go b/internal/sequencer/sequencer.go new file mode 100644 index 00000000..d647c5c7 --- /dev/null +++ b/internal/sequencer/sequencer.go @@ -0,0 +1,289 @@ +package sequencer + +import ( + "fmt" + "os" + "path/filepath" + + "emperror.dev/errors" + "github.com/aviator-co/av/internal/git" + "github.com/aviator-co/av/internal/meta" + "github.com/go-git/go-git/v5/config" + "github.com/go-git/go-git/v5/plumbing" +) + +type RestackOp struct { + Name plumbing.ReferenceName + + // New parent branch to sync to. + NewParent plumbing.ReferenceName + + // Mark the new parent branch as trunk. + NewParentIsTrunk bool + + // The new parent branch's hash. If not specified, the sequencer will use the new parent's + // branch hash if the new parent is not trunk. Or if the new parent is trunk, the sequencer + // will use the remote tracking branch's hash. + NewParentHash plumbing.Hash +} + +type branchSnapshot struct { + // The branch name. + Name plumbing.ReferenceName + // The parent branch name. + ParentBranch plumbing.ReferenceName + // True if the parent branch is the trunk branch (refs/heads/master etc.). + IsParentTrunk bool + // Commit hash that the parent branch was previously at last time this was synced. + // This is plumbing.ZeroHash if the parent branch is a trunk. + PreviouslySyncedParentBranchHash plumbing.Hash +} + +// Sequencer re-stacks the specified branches. +// +// This entire Sequencer object should be JSON serializable. The caller is expected to save this to +// file when the sequencer needs to be paused for more input. +type Sequencer struct { + // The name of the remote (e.g. "origin"). + RemoteName string + // All branch information initially when the sequencer started. + OriginalBranchSnapshots map[plumbing.ReferenceName]*branchSnapshot + // Ref that is currently being synced. Next time the sequencer runs, it will rebase this + // ref. + CurrentSyncRef plumbing.ReferenceName + // If the rebase is stopped, these fields are set. + SequenceInterruptedNewParentHash plumbing.Hash + + Operations []RestackOp +} + +func NewSequencer(remoteName string, db meta.DB, ops []RestackOp) *Sequencer { + return &Sequencer{ + RemoteName: remoteName, + OriginalBranchSnapshots: getBranchSnapshots(db), + Operations: ops, + CurrentSyncRef: ops[0].Name, + } +} + +func getBranchSnapshots(db meta.DB) map[plumbing.ReferenceName]*branchSnapshot { + ret := map[plumbing.ReferenceName]*branchSnapshot{} + for name, avbr := range db.ReadTx().AllBranches() { + snapshot := &branchSnapshot{ + Name: plumbing.ReferenceName("refs/heads/" + name), + ParentBranch: plumbing.ReferenceName("refs/heads/" + avbr.Parent.Name), + } + ret[snapshot.Name] = snapshot + if avbr.Parent.Trunk { + snapshot.IsParentTrunk = true + } else { + snapshot.PreviouslySyncedParentBranchHash = plumbing.NewHash(avbr.Parent.Head) + } + } + return ret +} + +func (seq *Sequencer) Run(repo *git.Repo, db meta.DB, seqAbort, seqContinue, seqSkip bool) (*git.RebaseResult, error) { + if seqAbort || seqContinue || seqSkip { + return seq.runFromInterruptedState(repo, db, seqAbort, seqContinue, seqSkip) + } + + if seq.CurrentSyncRef == "" { + return nil, nil + } + return seq.rebaseBranch(repo, db) +} + +func (seq *Sequencer) runFromInterruptedState(repo *git.Repo, db meta.DB, seqAbort, seqContinue, seqSkip bool) (*git.RebaseResult, error) { + if seq.CurrentSyncRef == "" { + return nil, errors.New("no sync in progress") + } + if seq.SequenceInterruptedNewParentHash.IsZero() { + panic("broken interruption state: no new parent hash") + } + if seqAbort { + // Abort the rebase if we need to + if stat, _ := os.Stat(filepath.Join(repo.GitDir(), "REBASE_HEAD")); stat != nil { + if _, err := repo.Rebase(git.RebaseOpts{Abort: true}); err != nil { + return nil, errors.Errorf("failed to abort in-progress rebase: %v", err) + } + } + seq.CurrentSyncRef = "" + seq.SequenceInterruptedNewParentHash = plumbing.ZeroHash + return nil, nil + } + if seqContinue { + if err := seq.checkNoUnstagedChanges(repo); err != nil { + return nil, err + } + result, err := repo.RebaseParse(git.RebaseOpts{Continue: true}) + if err != nil { + return nil, errors.Errorf("failed to continue in-progress rebase: %v", err) + } + if result.Status == git.RebaseConflict { + return result, nil + } + if err := seq.postRebaseBranchUpdate(db, seq.SequenceInterruptedNewParentHash); err != nil { + return nil, err + } + return result, nil + } + if seqSkip { + result, err := repo.RebaseParse(git.RebaseOpts{Skip: true}) + if err != nil { + return nil, errors.Errorf("failed to skip in-progress rebase: %v", err) + } + if result.Status == git.RebaseConflict { + return result, nil + } + if err := seq.postRebaseBranchUpdate(db, seq.SequenceInterruptedNewParentHash); err != nil { + return nil, err + } + return result, nil + } + panic("unreachable") +} + +func (seq *Sequencer) rebaseBranch(repo *git.Repo, db meta.DB) (*git.RebaseResult, error) { + op := seq.getCurrentOp() + snapshot, ok := seq.OriginalBranchSnapshots[op.Name] + if !ok { + panic(fmt.Sprintf("branch %q not found in original branch infos", op.Name)) + } + + var previousParentHash plumbing.Hash + if snapshot.IsParentTrunk { + // Use the current remote tracking branch hash as the previous parent hash. + var err error + previousParentHash, err = seq.getRemoteTrackingBranchCommit(repo, snapshot.ParentBranch) + if err != nil { + return nil, err + } + } else { + previousParentHash = snapshot.PreviouslySyncedParentBranchHash + } + + var newParentHash plumbing.Hash + if op.NewParentHash.IsZero() { + if op.NewParentIsTrunk { + var err error + newParentHash, err = seq.getRemoteTrackingBranchCommit(repo, op.NewParent) + if err != nil { + return nil, err + } + } else { + var err error + newParentHash, err = seq.getBranchCommit(repo, op.NewParent) + if err != nil { + return nil, err + } + } + } else { + newParentHash = op.NewParentHash + } + + // The commits from `rebaseFrom` to `snapshot.Name` should be rebased onto `rebaseOnto`. + opts := git.RebaseOpts{ + Branch: op.Name.Short(), + Upstream: previousParentHash.String(), + Onto: newParentHash.String(), + } + result, err := repo.RebaseParse(opts) + if err != nil { + return nil, err + } + if result.Status == git.RebaseConflict { + seq.SequenceInterruptedNewParentHash = newParentHash + return result, nil + } + if err := seq.postRebaseBranchUpdate(db, newParentHash); err != nil { + return nil, err + } + return result, nil +} + +func (seq *Sequencer) checkNoUnstagedChanges(repo *git.Repo) error { + diff, err := repo.Diff(&git.DiffOpts{Quiet: true}) + if err != nil { + return err + } + if !diff.Empty { + return errors.New( + "refusing to sync: there are unstaged changes in the working tree (use `git add` to stage changes)", + ) + } + return nil +} + +func (seq *Sequencer) postRebaseBranchUpdate(db meta.DB, newParentHash plumbing.Hash) error { + op := seq.getCurrentOp() + newParentBranchState := meta.BranchState{ + Name: op.NewParent.Short(), + Trunk: op.NewParentIsTrunk, + } + if !op.NewParentIsTrunk { + newParentBranchState.Head = newParentHash.String() + } + + tx := db.WriteTx() + br, _ := tx.Branch(op.Name.Short()) + br.Parent = newParentBranchState + tx.SetBranch(br) + if err := tx.Commit(); err != nil { + return err + } + seq.SequenceInterruptedNewParentHash = plumbing.ZeroHash + for i, op := range seq.Operations { + if op.Name == seq.CurrentSyncRef { + if i+1 < len(seq.Operations) { + seq.CurrentSyncRef = seq.Operations[i+1].Name + } else { + seq.CurrentSyncRef = plumbing.ReferenceName("") + } + break + } + } + return nil +} + +func (seq *Sequencer) getCurrentOp() RestackOp { + for _, op := range seq.Operations { + if op.Name == seq.CurrentSyncRef { + return op + } + } + panic(fmt.Sprintf("op not found for ref %q", seq.CurrentSyncRef)) +} + +func (seq *Sequencer) getRemoteTrackingBranchCommit(repo *git.Repo, ref plumbing.ReferenceName) (plumbing.Hash, error) { + remote, err := repo.GoGitRepo().Remote(seq.RemoteName) + if err != nil { + return plumbing.ZeroHash, errors.Errorf("failed to get remote %q: %v", seq.RemoteName, err) + } + rtb := mapToRemoteTrackingBranch(remote.Config(), ref) + if rtb == nil { + return plumbing.ZeroHash, errors.Errorf("failed to get remote tracking branch in %q for %q", seq.RemoteName, ref) + } + return seq.getBranchCommit(repo, *rtb) +} + +func (seq *Sequencer) getBranchCommit(repo *git.Repo, ref plumbing.ReferenceName) (plumbing.Hash, error) { + refObj, err := repo.GoGitRepo().Reference(ref, false) + if err != nil { + return plumbing.ZeroHash, errors.Errorf("failed to get branch %q: %v", ref, err) + } + if refObj.Type() != plumbing.HashReference { + return plumbing.ZeroHash, errors.Errorf("unexpected reference type for branch %q: %v", ref, refObj.Type()) + } + return refObj.Hash(), nil +} + +func mapToRemoteTrackingBranch(remoteConfig *config.RemoteConfig, refName plumbing.ReferenceName) *plumbing.ReferenceName { + for _, fetch := range remoteConfig.Fetch { + if fetch.Match(refName) { + dst := fetch.Dst(refName) + return &dst + } + } + return nil +}