diff --git a/cmd/av/stack.go b/cmd/av/stack.go index 3145b974..91e08f42 100644 --- a/cmd/av/stack.go +++ b/cmd/av/stack.go @@ -21,6 +21,7 @@ func init() { stackPrevCmd, stackReorderCmd, stackReparentCmd, + stackRestackCmd, stackSubmitCmd, stackSwitchCmd, stackSyncCmd, diff --git a/cmd/av/stack_restack.go b/cmd/av/stack_restack.go new file mode 100644 index 00000000..76a0a8ca --- /dev/null +++ b/cmd/av/stack_restack.go @@ -0,0 +1,267 @@ +package main + +import ( + "os" + "strings" + + "emperror.dev/errors" + "github.com/aviator-co/av/internal/actions" + "github.com/aviator-co/av/internal/git" + "github.com/aviator-co/av/internal/meta" + "github.com/aviator-co/av/internal/sequencer" + "github.com/aviator-co/av/internal/sequencer/planner" + "github.com/aviator-co/av/internal/utils/colors" + "github.com/aviator-co/av/internal/utils/stackutils" + "github.com/charmbracelet/bubbles/spinner" + tea "github.com/charmbracelet/bubbletea" + "github.com/charmbracelet/lipgloss" + "github.com/go-git/go-git/v5/plumbing" + "github.com/spf13/cobra" +) + +var stackRestackFlags struct { + DryRun bool + Abort bool + Continue bool + Skip bool +} + +var stackRestackCmd = &cobra.Command{ + Use: "restack", + Short: "Restack branches", + Args: cobra.NoArgs, + RunE: func(cmd *cobra.Command, args []string) error { + repo, err := getRepo() + if err != nil { + return err + } + + db, err := getDB(repo) + if err != nil { + return err + } + p := tea.NewProgram(stackRestackViewModel{ + repo: repo, + db: db, + spinner: spinner.New(spinner.WithSpinner(spinner.Dot)), + }) + model, err := p.Run() + if err != nil { + return err + } + if err := model.(stackRestackViewModel).err; err != nil { + return actions.ErrExitSilently{ExitCode: 1} + } + return nil + }, +} + +type stackRestackState struct { + InitialBranch string + StNode *stackutils.StackTreeNode + Seq *sequencer.Sequencer +} + +type stackRestackSeqResult struct { + result *git.RebaseResult + err error +} + +type stackRestackViewModel struct { + repo *git.Repo + db meta.DB + state *stackRestackState + spinner spinner.Model + + rebaseConflictErrorHeadline string + rebaseConflictHint string + abortedBranch plumbing.ReferenceName + err error +} + +func (vm stackRestackViewModel) Init() tea.Cmd { + return tea.Batch(vm.spinner.Tick, vm.initCmd) +} + +func (vm stackRestackViewModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { + switch msg := msg.(type) { + case error: + vm.err = msg + return vm, tea.Quit + case *stackRestackState: + vm.state = msg + if stackRestackFlags.DryRun { + return vm, tea.Quit + } + if stackRestackFlags.Skip || stackRestackFlags.Continue || stackRestackFlags.Abort { + if stackRestackFlags.Abort { + vm.abortedBranch = vm.state.Seq.CurrentSyncRef + } + return vm, vm.runSeqWithContinuationFlags + } + return vm, vm.runSeq + case *stackRestackSeqResult: + if msg.err == nil && msg.result == nil { + // Finished the sequence. + if err := vm.repo.WriteStateFile(git.StateFileKindRestack, nil); err != nil { + vm.err = err + } + if _, err := vm.repo.CheckoutBranch(&git.CheckoutBranch{Name: vm.state.InitialBranch}); err != nil { + vm.err = err + } + return vm, tea.Quit + } + if msg.result.Status == git.RebaseConflict { + vm.rebaseConflictErrorHeadline = msg.result.ErrorHeadline + vm.rebaseConflictHint = msg.result.Hint + if err := vm.repo.WriteStateFile(git.StateFileKindRestack, vm.state); err != nil { + vm.err = err + } + return vm, tea.Quit + } + vm.err = msg.err + if vm.err != nil { + return vm, tea.Quit + } + return vm, vm.runSeq + case spinner.TickMsg: + var cmd tea.Cmd + vm.spinner, cmd = vm.spinner.Update(msg) + return vm, cmd + case tea.KeyMsg: + switch msg.String() { + case "ctrl+c": + return vm, tea.Quit + } + } + return vm, nil +} + +func (vm stackRestackViewModel) View() string { + sb := strings.Builder{} + if vm.state != nil && vm.state.Seq != nil { + if vm.state.Seq.CurrentSyncRef != "" { + sb.WriteString("Restacking " + vm.state.Seq.CurrentSyncRef.Short() + "...\n") + } else if vm.abortedBranch != "" { + sb.WriteString("Restack aborted\n") + } else { + sb.WriteString("Restack done\n") + } + syncedBranches := map[plumbing.ReferenceName]bool{} + pendingBranches := map[plumbing.ReferenceName]bool{} + seenCurrent := false + for _, op := range vm.state.Seq.Operations { + if op.Name == vm.state.Seq.CurrentSyncRef || op.Name == vm.abortedBranch { + seenCurrent = true + } else if !seenCurrent { + syncedBranches[op.Name] = true + } else { + pendingBranches[op.Name] = true + } + } + + sb.WriteString(stackutils.RenderTree(vm.state.StNode, func(branchName string, isTrunk bool) string { + bn := plumbing.NewBranchReferenceName(branchName) + if syncedBranches[bn] { + return colors.Success("✓ " + branchName) + } + if pendingBranches[bn] { + return lipgloss.NewStyle().Foreground(colors.Amber500).Render(branchName) + } + if bn == vm.state.Seq.CurrentSyncRef { + return lipgloss.NewStyle().Foreground(colors.Amber500).Render(vm.spinner.View() + branchName) + } + if bn == vm.abortedBranch { + return colors.Failure("✗ " + branchName) + } + return branchName + })) + } + if vm.rebaseConflictErrorHeadline != "" { + sb.WriteString("\n") + sb.WriteString(colors.Failure("Rebase conflict while rebasing ", vm.state.Seq.CurrentSyncRef.Short()) + "\n") + sb.WriteString(vm.rebaseConflictErrorHeadline + "\n") + sb.WriteString(vm.rebaseConflictHint + "\n") + sb.WriteString("\n") + sb.WriteString("Resolve the conflicts and continue the restack with " + colors.CliCmd("av stack restack --continue") + "\n") + } + if vm.err != nil { + sb.WriteString(vm.err.Error() + "\n") + } + return sb.String() +} + +func (vm stackRestackViewModel) initCmd() tea.Msg { + if clean, err := vm.repo.CheckCleanWorkdir(); err != nil { + return err + } else if !clean { + return errors.New("the working directory is not clean, please stash or commit them before running restack command.") + } + + var state stackRestackState + if err := vm.repo.ReadStateFile(git.StateFileKindRestack, &state); err != nil && os.IsNotExist(err) { + var currentBranch string + if dh, err := vm.repo.DetachedHead(); err != nil { + return err + } else if !dh { + currentBranch, err = vm.repo.CurrentBranchName() + if err != nil { + return err + } + } + if _, exist := vm.db.ReadTx().Branch(currentBranch); !exist { + return errors.New("current branch is not adopted to av") + } + state.InitialBranch = currentBranch + state.StNode, err = stackutils.BuildStackTreeCurrentStack(vm.db.ReadTx(), currentBranch, true) + if err != nil { + return err + } + targetBranches, err := planner.GetTargetBranches(vm.db.ReadTx(), vm.repo, false, planner.CurrentStack) + if err != nil { + return err + } + ops, err := planner.PlanForRestack(vm.db.ReadTx(), vm.repo, targetBranches) + if err != nil { + return err + } + if len(ops) == 0 { + return errors.New("nothing to restack") + } + state.Seq = sequencer.NewSequencer("origin", vm.db, ops) + } else if err != nil { + return err + } + return &state +} + +func (vm stackRestackViewModel) runSeqWithContinuationFlags() tea.Msg { + result, err := vm.state.Seq.Run(vm.repo, vm.db, stackRestackFlags.Abort, stackRestackFlags.Continue, stackRestackFlags.Skip) + return &stackRestackSeqResult{result: result, err: err} +} + +func (vm stackRestackViewModel) runSeq() tea.Msg { + result, err := vm.state.Seq.Run(vm.repo, vm.db, false, false, false) + return &stackRestackSeqResult{result: result, err: err} +} + +func init() { + stackRestackCmd.Flags().BoolVar( + &stackRestackFlags.Continue, "continue", false, + "continue an in-progress restack", + ) + stackRestackCmd.Flags().BoolVar( + &stackRestackFlags.Abort, "abort", false, + "abort an in-progress restack", + ) + stackRestackCmd.Flags().BoolVar( + &stackRestackFlags.Skip, "skip", false, + "skip the current commit and continue an in-progress restack", + ) + stackRestackCmd.Flags().BoolVar( + &stackRestackFlags.DryRun, "dry-run", false, + "dry-run the restack", + ) + + stackRestackCmd.MarkFlagsMutuallyExclusive("continue", "abort", "skip") +} diff --git a/cmd/av/stack_switch.go b/cmd/av/stack_switch.go index 536999aa..6c87b0ea 100644 --- a/cmd/av/stack_switch.go +++ b/cmd/av/stack_switch.go @@ -38,7 +38,7 @@ var stackSwitchCmd = &cobra.Command{ } } - rootNodes := stackutils.BuildStackTree(tx, currentBranch) + rootNodes := stackutils.BuildStackTreeAllBranches(tx, currentBranch, true) var branchList []*stackTreeBranchInfo branches := map[string]*stackTreeBranchInfo{} for _, node := range rootNodes { diff --git a/cmd/av/stack_tree.go b/cmd/av/stack_tree.go index d29bbab9..622e79b9 100644 --- a/cmd/av/stack_tree.go +++ b/cmd/av/stack_tree.go @@ -38,7 +38,7 @@ var stackTreeCmd = &cobra.Command{ } } - rootNodes := stackutils.BuildStackTree(tx, currentBranch) + rootNodes := stackutils.BuildStackTreeAllBranches(tx, currentBranch, true) for _, node := range rootNodes { fmt.Print(stackutils.RenderTree(node, func(branchName string, isTrunk bool) string { stbi := getStackTreeBranchInfo(repo, tx, branchName) diff --git a/go.mod b/go.mod index de162972..715e4500 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ toolchain go1.22.1 require ( emperror.dev/errors v0.8.1 github.com/charmbracelet/bubbletea v0.26.2 + github.com/charmbracelet/lipgloss v0.10.0 github.com/fatih/color v1.17.0 github.com/golangci/golangci-lint v1.58.1 github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 @@ -88,7 +89,7 @@ require ( github.com/ccojocar/zxcvbn-go v1.0.2 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/charithe/durationcheck v0.0.10 // indirect - github.com/charmbracelet/lipgloss v0.10.0 + github.com/charmbracelet/bubbles v0.18.0 github.com/chavacava/garif v0.1.0 // indirect github.com/ckaznocha/intrange v0.1.2 // indirect github.com/curioswitch/go-reassign v0.2.0 // indirect diff --git a/go.sum b/go.sum index ac5ffe1c..0b2857c7 100644 --- a/go.sum +++ b/go.sum @@ -84,6 +84,8 @@ github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UF github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/charithe/durationcheck v0.0.10 h1:wgw73BiocdBDQPik+zcEoBG/ob8uyBHf2iyoHGPf5w4= github.com/charithe/durationcheck v0.0.10/go.mod h1:bCWXb7gYRysD1CU3C+u4ceO49LoGOY1C1L6uouGNreQ= +github.com/charmbracelet/bubbles v0.18.0 h1:PYv1A036luoBGroX6VWjQIE9Syf2Wby2oOl/39KLfy0= +github.com/charmbracelet/bubbles v0.18.0/go.mod h1:08qhZhtIwzgrtBjAcJnij1t1H0ZRjwHyGsy6AL11PSw= github.com/charmbracelet/bubbletea v0.26.2 h1:Eeb+n75Om9gQ+I6YpbCXQRKHt5Pn4vMwusQpwLiEgJQ= github.com/charmbracelet/bubbletea v0.26.2/go.mod h1:6I0nZ3YHUrQj7YHIHlM8RySX4ZIthTliMY+W8X8b+Gs= github.com/charmbracelet/lipgloss v0.10.0 h1:KWeXFSexGcfahHX+54URiZGkBFazf70JNMtwg/AFW3s= diff --git a/internal/actions/pr.go b/internal/actions/pr.go index 13355ebb..8130c5c6 100644 --- a/internal/actions/pr.go +++ b/internal/actions/pr.go @@ -879,7 +879,8 @@ func UpdatePullRequestWithStack( repoMeta := tx.Repository() - stackToWrite, err := stackutils.BuildStackTreeForPullRequest(tx, branchName) + // Don't sort based on the current branch so that the output is consistent between branches. + stackToWrite, err := stackutils.BuildStackTreeCurrentStack(tx, branchName, false) if err != nil { return err } diff --git a/internal/actions/sync_branch.go b/internal/actions/sync_branch.go index d1d46d45..fc86f392 100644 --- a/internal/actions/sync_branch.go +++ b/internal/actions/sync_branch.go @@ -606,7 +606,7 @@ func syncBranchPushAndUpdatePullRequest( var stackToWrite *stackutils.StackTreeNode if config.Av.PullRequest.WriteStack { - if stackToWrite, err = stackutils.BuildStackTreeForPullRequest(tx, branchName); err != nil { + if stackToWrite, err = stackutils.BuildStackTreeCurrentStack(tx, branchName, false); err != nil { return err } } diff --git a/internal/git/state_file.go b/internal/git/state_file.go index 0176dee0..b58ac099 100644 --- a/internal/git/state_file.go +++ b/internal/git/state_file.go @@ -11,6 +11,7 @@ type StateFileKind string const ( StateFileKindSync StateFileKind = "stack-sync.state.json" StateFileKindReorder StateFileKind = "stack-reorder.state.json" + StateFileKindRestack StateFileKind = "stack-restack.state.json" ) func (r *Repo) ReadStateFile(kind StateFileKind, msg any) error { diff --git a/internal/sequencer/planner/planner.go b/internal/sequencer/planner/planner.go index a8f21164..9833e859 100644 --- a/internal/sequencer/planner/planner.go +++ b/internal/sequencer/planner/planner.go @@ -7,7 +7,24 @@ import ( "github.com/go-git/go-git/v5/plumbing" ) -func Plan(tx meta.ReadTx, repo *git.Repo, targetBranches []plumbing.ReferenceName, syncToTrunkInsteadOfMergeCommit bool) ([]sequencer.RestackOp, error) { +func PlanForRestack(tx meta.ReadTx, repo *git.Repo, targetBranches []plumbing.ReferenceName) ([]sequencer.RestackOp, error) { + var ret []sequencer.RestackOp + for _, br := range targetBranches { + avbr, _ := tx.Branch(br.Short()) + if avbr.MergeCommit != "" { + // Skip rebasing branches that have merge commits. + continue + } + ret = append(ret, sequencer.RestackOp{ + Name: br, + NewParent: plumbing.NewBranchReferenceName(avbr.Parent.Name), + NewParentIsTrunk: avbr.Parent.Trunk, + }) + } + return ret, nil +} + +func PlanForSync(tx meta.ReadTx, repo *git.Repo, targetBranches []plumbing.ReferenceName, syncToTrunkInsteadOfMergeCommit bool) ([]sequencer.RestackOp, error) { var ret []sequencer.RestackOp for _, br := range targetBranches { avbr, _ := tx.Branch(br.Short()) diff --git a/internal/sequencer/sequencer.go b/internal/sequencer/sequencer.go index 21099fcc..d647c5c7 100644 --- a/internal/sequencer/sequencer.go +++ b/internal/sequencer/sequencer.go @@ -57,8 +57,8 @@ type Sequencer struct { Operations []RestackOp } -func NewSequencer(remoteName string, db meta.DB, ops []RestackOp) Sequencer { - return Sequencer{ +func NewSequencer(remoteName string, db meta.DB, ops []RestackOp) *Sequencer { + return &Sequencer{ RemoteName: remoteName, OriginalBranchSnapshots: getBranchSnapshots(db), Operations: ops, @@ -184,7 +184,7 @@ func (seq *Sequencer) rebaseBranch(repo *git.Repo, db meta.DB) (*git.RebaseResul // The commits from `rebaseFrom` to `snapshot.Name` should be rebased onto `rebaseOnto`. opts := git.RebaseOpts{ - Branch: op.Name.String(), + Branch: op.Name.Short(), Upstream: previousParentHash.String(), Onto: newParentHash.String(), } diff --git a/internal/utils/stackutils/stackutils.go b/internal/utils/stackutils/stackutils.go index f91fc01d..ad8589c2 100644 --- a/internal/utils/stackutils/stackutils.go +++ b/internal/utils/stackutils/stackutils.go @@ -9,7 +9,7 @@ import ( type StackTreeBranchInfo struct { BranchName string - parentBranchName string + ParentBranchName string } type StackTreeNode struct { @@ -22,7 +22,7 @@ func buildTree(currentBranchName string, branches []*StackTreeBranchInfo, sortCu branchMap := make(map[string]*StackTreeNode) for _, branch := range branches { branchMap[branch.BranchName] = &StackTreeNode{Branch: branch} - childBranches[branch.parentBranchName] = append(childBranches[branch.parentBranchName], branch.BranchName) + childBranches[branch.ParentBranchName] = append(childBranches[branch.ParentBranchName], branch.BranchName) } for _, branch := range branches { node := branchMap[branch.BranchName] @@ -34,7 +34,7 @@ func buildTree(currentBranchName string, branches []*StackTreeBranchInfo, sortCu // Find the root branches. var rootBranches []*StackTreeNode for _, branch := range branches { - if branch.parentBranchName == "" || branchMap[branch.parentBranchName] == nil { + if branch.ParentBranchName == "" || branchMap[branch.ParentBranchName] == nil { rootBranches = append(rootBranches, branchMap[branch.BranchName]) } } @@ -84,22 +84,19 @@ func buildTree(currentBranchName string, branches []*StackTreeBranchInfo, sortCu return rootBranches } -func BuildStackTree(tx meta.ReadTx, currentBranch string) []*StackTreeNode { - return buildStackTree(currentBranch, tx.AllBranches(), true) +func BuildStackTreeAllBranches(tx meta.ReadTx, currentBranch string, sortCurrent bool) []*StackTreeNode { + return buildStackTree(currentBranch, tx.AllBranches(), sortCurrent) } -func BuildStackTreeForPullRequest(tx meta.ReadTx, currentBranch string) (*StackTreeNode, error) { +func BuildStackTreeCurrentStack(tx meta.ReadTx, currentBranch string, sortCurrent bool) (*StackTreeNode, error) { branchesToInclude, err := meta.StackBranchesMap(tx, currentBranch) if err != nil { return nil, err } - - // Don't sort based on the current branch so that the output is consistent between branches. - stackTree := buildStackTree(currentBranch, branchesToInclude, false) + stackTree := buildStackTree(currentBranch, branchesToInclude, sortCurrent) if len(stackTree) != 1 { return nil, fmt.Errorf("expected one root branch, got %d", len(stackTree)) } - return stackTree[0], nil } @@ -109,7 +106,7 @@ func buildStackTree(currentBranch string, branchesToInclude map[string]meta.Bran for _, branch := range branchesToInclude { branches = append(branches, &StackTreeBranchInfo{ BranchName: branch.Name, - parentBranchName: branch.Parent.Name, + ParentBranchName: branch.Parent.Name, }) if branch.Parent.Trunk { trunks[branch.Parent.Name] = true @@ -118,7 +115,7 @@ func buildStackTree(currentBranch string, branchesToInclude map[string]meta.Bran for branch := range trunks { branches = append(branches, &StackTreeBranchInfo{ BranchName: branch, - parentBranchName: "", + ParentBranchName: "", }) } return buildTree(currentBranch, branches, sortCurrent)