From 4a2376e5da7f112fb996f6f235ea2e1a39f41c3e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20van=20Eeden?= Date: Fri, 26 Jul 2024 10:49:14 +0200 Subject: [PATCH 1/4] Add support for the source command --- include/hello.inc | 1 + r/source.result | 3 ++ src/main.go | 131 ++++++++++++++++++++++++++++++++++------------ src/query.go | 1 + t/source.test | 3 ++ 5 files changed, 106 insertions(+), 33 deletions(-) create mode 100644 include/hello.inc create mode 100644 r/source.result create mode 100644 t/source.test diff --git a/include/hello.inc b/include/hello.inc new file mode 100644 index 0000000..3b2bbc4 --- /dev/null +++ b/include/hello.inc @@ -0,0 +1 @@ +--echo Hello from the included file diff --git a/r/source.result b/r/source.result new file mode 100644 index 0000000..911f76d --- /dev/null +++ b/r/source.result @@ -0,0 +1,3 @@ +first line +Hello from the included file +last line diff --git a/src/main.go b/src/main.go index 9ddaf45..97fcd4b 100644 --- a/src/main.go +++ b/src/main.go @@ -72,10 +72,15 @@ const ( type query struct { firstWord string Query string + File string Line int tp int } +func (q *query) location() string { + return fmt.Sprintf("%s:%d", q.File, q.Line) +} + type Conn struct { // DB might be a shared one by multiple Conn, if the connection information are the same. mdb *sql.DB @@ -325,7 +330,7 @@ func (t *tester) addSuccess(testSuite *XUnitTestSuite, startTime *time.Time, cnt func (t *tester) Run() error { t.preProcess() defer t.postProcess() - queries, err := t.loadQueries() + queries, err := t.loadQueries(t.testFileName()) if err != nil { err = errors.Trace(err) t.addFailure(&testSuite, &err, 0) @@ -338,17 +343,33 @@ func (t *tester) Run() error { return err } - var s string defer func() { if t.resultFD != nil { t.resultFD.Close() } }() - testCnt := 0 startTime := time.Now() + testCnt, err := t.runQueries(queries) + if err != nil { + return err + } + + fmt.Printf("%s: ok! %d test cases passed, take time %v s\n", t.testFileName(), testCnt, time.Since(startTime).Seconds()) + + if xmlPath != "" { + t.addSuccess(&testSuite, &startTime, testCnt) + } + + return t.flushResult() +} + +func (t *tester) runQueries(queries []query) (int, error) { + testCnt := 0 var concurrentQueue []query var concurrentSize int + var s string + var err error for _, q := range queries { s = q.Query switch q.tp { @@ -379,7 +400,7 @@ func (t *tester) Run() error { if err != nil { err = errors.Annotate(err, "Atoi failed") t.addFailure(&testSuite, &err, testCnt) - return err + return testCnt, err } } case Q_END_CONCURRENT: @@ -387,7 +408,7 @@ func (t *tester) Run() error { if err = t.concurrentRun(concurrentQueue, concurrentSize); err != nil { err = errors.Annotate(err, fmt.Sprintf("concurrent test failed in %v", t.name)) t.addFailure(&testSuite, &err, testCnt) - return err + return testCnt, err } t.expectedErrs = nil case Q_ERROR: @@ -406,7 +427,7 @@ func (t *tester) Run() error { } else if err = t.execute(q); err != nil { err = errors.Annotate(err, fmt.Sprintf("sql:%v", q.Query)) t.addFailure(&testSuite, &err, testCnt) - return err + return testCnt, err } testCnt++ @@ -426,7 +447,7 @@ func (t *tester) Run() error { if err != nil { err = errors.Annotate(err, fmt.Sprintf("Could not parse column in --replace_column: sql:%v", q.Query)) t.addFailure(&testSuite, &err, testCnt) - return err + return testCnt, err } t.replaceColumn = append(t.replaceColumn, ReplaceColumn{col: colNr, replace: []byte(cols[i+1])}) @@ -473,7 +494,7 @@ func (t *tester) Run() error { r, err := t.executeStmtString(s) if err != nil { log.WithFields(log.Fields{ - "query": s, "line": q.Line}, + "query": s, "line": q.location()}, ).Error("failed to perform let query") return "" } @@ -484,27 +505,59 @@ func (t *tester) Run() error { case Q_REMOVE_FILE: err = os.Remove(strings.TrimSpace(q.Query)) if err != nil { - return errors.Annotate(err, "failed to remove file") + return testCnt, errors.Annotate(err, "failed to remove file") } case Q_REPLACE_REGEX: t.replaceRegex = nil regex, err := ParseReplaceRegex(q.Query) if err != nil { - return errors.Annotate(err, fmt.Sprintf("Could not parse regex in --replace_regex: line: %d sql:%v", q.Line, q.Query)) + return testCnt, errors.Annotate( + err, fmt.Sprintf("Could not parse regex in --replace_regex: line: %s sql:%v", + q.location(), q.Query)) } t.replaceRegex = regex - default: - log.WithFields(log.Fields{"command": q.firstWord, "arguments": q.Query, "line": q.Line}).Warn("command not implemented") - } - } + case Q_SOURCE: + fileName := strings.TrimSpace(q.Query) + cwd, err := os.Getwd() + if err != nil { + return testCnt, err + } - fmt.Printf("%s: ok! %d test cases passed, take time %v s\n", t.testFileName(), testCnt, time.Since(startTime).Seconds()) + // For security, don't allow to include files from other locations + fullpath, err := filepath.Abs(fileName) + if err != nil { + return testCnt, err + } + if !strings.HasPrefix(fullpath, cwd) { + return testCnt, errors.Errorf("included file %s is not prefixed with %s", fullpath, cwd) + } - if xmlPath != "" { - t.addSuccess(&testSuite, &startTime, testCnt) - } + // Make sure we have a useful error message if the file can't be found or isn't a regular file + s, err := os.Stat(fileName) + if err != nil { + return testCnt, errors.Annotate(err, + fmt.Sprintf("file sourced with --source doesn't exist: line %s, file: %s", + q.location(), fileName)) + } + if !s.Mode().IsRegular() { + return testCnt, errors.Errorf("file sourced with --source isn't a regular file: line %s, file: %s", + q.location(), fileName) + } - return t.flushResult() + // Process the queries in the file + includedQueries, err := t.loadQueries(fileName) + if err != nil { + return testCnt, errors.Annotate(err, fmt.Sprintf("error loading queries from %s", fileName)) + } + _, err = t.runQueries(includedQueries) + if err != nil { + return testCnt, err + } + default: + log.WithFields(log.Fields{"command": q.firstWord, "arguments": q.Query, "line": q.location()}).Warn("command not implemented") + } + } + return testCnt, nil } func (t *tester) concurrentRun(concurrentQueue []query, concurrentSize int) error { @@ -606,8 +659,8 @@ func (t *tester) concurrentExecute(querys []query, wg *sync.WaitGroup, errOccure } } -func (t *tester) loadQueries() ([]query, error) { - data, err := os.ReadFile(t.testFileName()) +func (t *tester) loadQueries(fileName string) ([]query, error) { + data, err := os.ReadFile(fileName) if err != nil { return nil, err } @@ -623,7 +676,11 @@ func (t *tester) loadQueries() ([]query, error) { newStmt = true continue } else if strings.HasPrefix(s, "--") { - queries = append(queries, query{Query: s, Line: i + 1}) + queries = append(queries, query{ + Query: s, + Line: i + 1, + File: fileName, + }) newStmt = true continue } else if len(s) == 0 { @@ -631,10 +688,18 @@ func (t *tester) loadQueries() ([]query, error) { } if newStmt { - queries = append(queries, query{Query: s, Line: i + 1}) + queries = append(queries, query{ + Query: s, + Line: i + 1, + File: fileName, + }) } else { lastQuery := queries[len(queries)-1] - lastQuery = query{Query: fmt.Sprintf("%s\n%s", lastQuery.Query, s), Line: lastQuery.Line} + lastQuery = query{ + Query: fmt.Sprintf("%s\n%s", lastQuery.Query, s), + Line: lastQuery.Line, + File: fileName, + } queries[len(queries)-1] = lastQuery } @@ -668,8 +733,8 @@ func (t *tester) checkExpectedError(q query, err error) error { } } if !checkErr { - log.Warnf("%s:%d query succeeded, but expected error(s)! (expected errors: %s) (query: %s)", - t.name, q.Line, strings.Join(t.expectedErrs, ","), q.Query) + log.Warnf("%s query succeeded, but expected error(s)! (expected errors: %s) (query: %s)", + q.location(), strings.Join(t.expectedErrs, ","), q.Query) return nil } return errors.Errorf("Statement succeeded, expected error(s) '%s'", strings.Join(t.expectedErrs, ",")) @@ -684,7 +749,7 @@ func (t *tester) checkExpectedError(q query, err error) error { errNo = int(innerErr.Number) } if errNo == 0 { - log.Warnf("%s:%d Could not parse mysql error: %s", t.name, q.Line, err.Error()) + log.Warnf("%s Could not parse mysql error: %s", q.location(), err.Error()) return err } for _, s := range t.expectedErrs { @@ -696,9 +761,9 @@ func (t *tester) checkExpectedError(q query, err error) error { checkErrNo = i } else { if len(t.expectedErrs) > 1 { - log.Warnf("%s:%d Unknown named error %s in --error %s", t.name, q.Line, s, strings.Join(t.expectedErrs, ",")) + log.Warnf("%s Unknown named error %s in --error %s", q.location(), s, strings.Join(t.expectedErrs, ",")) } else { - log.Warnf("%s:%d Unknown named --error %s", t.name, q.Line, s) + log.Warnf("%s Unknown named --error %s", q.location(), s) } continue } @@ -726,11 +791,11 @@ func (t *tester) checkExpectedError(q query, err error) error { } } if len(t.expectedErrs) > 1 { - log.Warnf("%s:%d query failed with non expected error(s)! (%s not in %s) (err: %s) (query: %s)", - t.name, q.Line, gotErrCode, strings.Join(t.expectedErrs, ","), err.Error(), q.Query) + log.Warnf("%s query failed with non expected error(s)! (%s not in %s) (err: %s) (query: %s)", + q.location(), gotErrCode, strings.Join(t.expectedErrs, ","), err.Error(), q.Query) } else { - log.Warnf("%s:%d query failed with non expected error(s)! (%s != %s) (err: %s) (query: %s)", - t.name, q.Line, gotErrCode, t.expectedErrs[0], err.Error(), q.Query) + log.Warnf("%s query failed with non expected error(s)! (%s != %s) (err: %s) (query: %s)", + q.location(), gotErrCode, t.expectedErrs[0], err.Error(), q.Query) } errStr := err.Error() for _, reg := range t.replaceRegex { diff --git a/src/query.go b/src/query.go index 6a128d8..f9bb41b 100644 --- a/src/query.go +++ b/src/query.go @@ -136,6 +136,7 @@ func ParseQueries(qs ...query) ([]query, error) { q := query{} q.tp = Q_UNKNOWN q.Line = rs.Line + q.File = rs.File // a valid query's length should be at least 3. if len(s) < 3 { continue diff --git a/t/source.test b/t/source.test new file mode 100644 index 0000000..a204f48 --- /dev/null +++ b/t/source.test @@ -0,0 +1,3 @@ +--echo first line +--source include/hello.inc +--echo last line From 3c938b540ee366624544db17b648f411b74df303 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20van=20Eeden?= Date: Fri, 26 Jul 2024 11:04:31 +0200 Subject: [PATCH 2/4] Include included tests in test count --- include/hello.inc | 1 + r/source.result | 3 +++ src/main.go | 5 +++-- 3 files changed, 7 insertions(+), 2 deletions(-) diff --git a/include/hello.inc b/include/hello.inc index 3b2bbc4..b9ed56d 100644 --- a/include/hello.inc +++ b/include/hello.inc @@ -1 +1,2 @@ --echo Hello from the included file +SELECT 1 diff --git a/r/source.result b/r/source.result index 911f76d..053b6f8 100644 --- a/r/source.result +++ b/r/source.result @@ -1,3 +1,6 @@ first line Hello from the included file +SELECT 1 +1 +1 last line diff --git a/src/main.go b/src/main.go index 97fcd4b..fb0307b 100644 --- a/src/main.go +++ b/src/main.go @@ -425,7 +425,7 @@ func (t *tester) runQueries(queries []query) (int, error) { if t.enableConcurrent { concurrentQueue = append(concurrentQueue, q) } else if err = t.execute(q); err != nil { - err = errors.Annotate(err, fmt.Sprintf("sql:%v", q.Query)) + err = errors.Annotate(err, fmt.Sprintf("sql:%v line:%s", q.Query, q.location())) t.addFailure(&testSuite, &err, testCnt) return testCnt, err } @@ -549,10 +549,11 @@ func (t *tester) runQueries(queries []query) (int, error) { if err != nil { return testCnt, errors.Annotate(err, fmt.Sprintf("error loading queries from %s", fileName)) } - _, err = t.runQueries(includedQueries) + includeCnt, err := t.runQueries(includedQueries) if err != nil { return testCnt, err } + testCnt += includeCnt default: log.WithFields(log.Fields{"command": q.firstWord, "arguments": q.Query, "line": q.location()}).Warn("command not implemented") } From 41d573d3d2a9b31f325675355014b8162235ace4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20van=20Eeden?= Date: Tue, 30 Jul 2024 09:03:58 +0200 Subject: [PATCH 3/4] Add enable_source/disable_source --- include/hello.inc | 2 -- r/source.result | 8 ++++++-- src/main.go | 9 +++++++++ src/query.go | 2 ++ src/type.go | 2 ++ t/source.test | 12 +++++++++--- 6 files changed, 28 insertions(+), 7 deletions(-) delete mode 100644 include/hello.inc diff --git a/include/hello.inc b/include/hello.inc deleted file mode 100644 index b9ed56d..0000000 --- a/include/hello.inc +++ /dev/null @@ -1,2 +0,0 @@ ---echo Hello from the included file -SELECT 1 diff --git a/r/source.result b/r/source.result index 053b6f8..ad266c0 100644 --- a/r/source.result +++ b/r/source.result @@ -1,6 +1,10 @@ -first line +before source Hello from the included file SELECT 1 1 1 -last line +after source +Goodbye from the included file +SELECT 3 +3 +3 diff --git a/src/main.go b/src/main.go index fb0307b..dfaa51a 100644 --- a/src/main.go +++ b/src/main.go @@ -47,6 +47,7 @@ var ( retryConnCount int collationDisable bool checkErr bool + disableSource bool ) func init() { @@ -516,7 +517,15 @@ func (t *tester) runQueries(queries []query) (int, error) { q.location(), q.Query)) } t.replaceRegex = regex + case Q_ENABLE_SOURCE: + disableSource = false + case Q_DISABLE_SOURCE: + disableSource = true case Q_SOURCE: + if disableSource { + log.WithFields(log.Fields{"line": q.location()}).Warn("source command disabled") + break + } fileName := strings.TrimSpace(q.Query) cwd, err := os.Getwd() if err != nil { diff --git a/src/query.go b/src/query.go index f9bb41b..8f4518f 100644 --- a/src/query.go +++ b/src/query.go @@ -124,6 +124,8 @@ const ( Q_COMMENT /* Comments, ignored. */ Q_COMMENT_WITH_COMMAND Q_EMPTY_LINE + Q_DISABLE_SOURCE + Q_ENABLE_SOURCE ) // ParseQueries parses an array of string into an array of query object. diff --git a/src/type.go b/src/type.go index 50ea5a6..d4b4766 100644 --- a/src/type.go +++ b/src/type.go @@ -114,6 +114,8 @@ var commandMap = map[string]int{ "single_query": Q_SINGLE_QUERY, "begin_concurrent": Q_BEGIN_CONCURRENT, "end_concurrent": Q_END_CONCURRENT, + "disable_source": Q_DISABLE_SOURCE, + "enable_source": Q_ENABLE_SOURCE, } func findType(cmdName string) int { diff --git a/t/source.test b/t/source.test index a204f48..05b440d 100644 --- a/t/source.test +++ b/t/source.test @@ -1,3 +1,9 @@ ---echo first line ---source include/hello.inc ---echo last line +--echo before source +--source include/hello1.inc +--echo after source + +--disable_source +--source include/hello2.inc +--enable_source + +--source include/hello3.inc From 32d8f5387df40e45d2035fccbdca2ac72ebfdbba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20van=20Eeden?= Date: Tue, 30 Jul 2024 09:09:02 +0200 Subject: [PATCH 4/4] Disable source by default --- src/main.go | 5 ++++- t/source.test | 2 ++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/src/main.go b/src/main.go index dfaa51a..5e09363 100644 --- a/src/main.go +++ b/src/main.go @@ -51,6 +51,9 @@ var ( ) func init() { + // Disable the `--source` command by default to avoid breaking existing tests + disableSource = true + flag.StringVar(&host, "host", "127.0.0.1", "The host of the TiDB/MySQL server.") flag.StringVar(&port, "port", "4000", "The listen port of TiDB/MySQL server.") flag.StringVar(&user, "user", "root", "The user for connecting to the database.") @@ -523,7 +526,7 @@ func (t *tester) runQueries(queries []query) (int, error) { disableSource = true case Q_SOURCE: if disableSource { - log.WithFields(log.Fields{"line": q.location()}).Warn("source command disabled") + log.WithFields(log.Fields{"line": q.location()}).Warn("source command disabled, add '--enable_source' to your file to enable") break } fileName := strings.TrimSpace(q.Query) diff --git a/t/source.test b/t/source.test index 05b440d..05140a3 100644 --- a/t/source.test +++ b/t/source.test @@ -1,3 +1,5 @@ +--enable_source + --echo before source --source include/hello1.inc --echo after source