diff --git a/github/github.go b/github/github.go index 43d7f22b8..2fb55861f 100644 --- a/github/github.go +++ b/github/github.go @@ -379,23 +379,32 @@ func (client *Client) commitMessage(pr *github.PullRequest, mergeMethod string) } } - var r *regexp.Regexp - if strings.Contains(pr.GetBody(), "==COMMIT_MSG==") { - r = regexp.MustCompile("(?s:(==COMMIT_MSG==\r\n)(.*)(\r\n==COMMIT_MSG==))") - } else if strings.Contains(pr.GetBody(), "==SQUASH_MSG==") { - r = regexp.MustCompile("(?s:(==SQUASH_MSG==\r\n)(.*)(\r\n==SQUASH_MSG==))") - } - if r != nil { - m := r.FindStringSubmatch(pr.GetBody()) - if len(m) == 4 { - commitMessage = m[2] - } + if msg, ok := extractMessageOverride(pr.GetBody()); ok { + commitMessage = msg } } return commitMessage, nil } +func extractMessageOverride(body string) (msg string, found bool) { + var r *regexp.Regexp + if strings.Contains(body, "==COMMIT_MSG==") { + r = regexp.MustCompile(`(?sm:(==COMMIT_MSG==\s*)^(.*)$(\s*==COMMIT_MSG==))`) + } else if strings.Contains(body, "==SQUASH_MSG==") { + r = regexp.MustCompile(`(?sm:(==SQUASH_MSG==\s*)^(.*)$(\s*==SQUASH_MSG==))`) + } + + if r != nil { + m := r.FindStringSubmatch(body) + if len(m) == 4 { + msg = strings.TrimSpace(m[2]) + found = true + } + } + return +} + func (client *Client) Merge(pr *github.PullRequest) error { logger := client.Logger diff --git a/github/github_test.go b/github/github_test.go index f44bed11e..0711a4113 100644 --- a/github/github_test.go +++ b/github/github_test.go @@ -771,3 +771,26 @@ func TestSquashCommitMessage(t *testing.T) { require.Nil(t, err) require.Equal(t, "", commitMessage) } + +func TestExtractMessageOverride(t *testing.T) { + _, ok := extractMessageOverride("no override here") + assert.False(t, ok, "found unexpected message override") + + _, ok = extractMessageOverride("==COMITT_MSG==\r\nUnclosed message") + assert.False(t, ok, "found unexpected message override") + + msg, ok := extractMessageOverride("==COMMIT_MSG==\r\nThe real message\r\n==COMMIT_MSG==") + if assert.True(t, ok, "override was not found") { + assert.Equal(t, "The real message", msg) + } + + msg, ok = extractMessageOverride("==COMMIT_MSG== \r\nThe real message\r\n ==COMMIT_MSG==") + if assert.True(t, ok, "override was not found") { + assert.Equal(t, "The real message", msg) + } + + msg, ok = extractMessageOverride("==SQUASH_MSG==\nThe real message\n==SQUASH_MSG==") + if assert.True(t, ok, "override was not found") { + assert.Equal(t, "The real message", msg) + } +}