Skip to content

Commit 4a2376e

Browse files
committed
Add support for the source command
1 parent aa83826 commit 4a2376e

File tree

5 files changed

+106
-33
lines changed

5 files changed

+106
-33
lines changed

include/hello.inc

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
--echo Hello from the included file

r/source.result

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
first line
2+
Hello from the included file
3+
last line

src/main.go

+98-33
Original file line numberDiff line numberDiff line change
@@ -72,10 +72,15 @@ const (
7272
type query struct {
7373
firstWord string
7474
Query string
75+
File string
7576
Line int
7677
tp int
7778
}
7879

80+
func (q *query) location() string {
81+
return fmt.Sprintf("%s:%d", q.File, q.Line)
82+
}
83+
7984
type Conn struct {
8085
// DB might be a shared one by multiple Conn, if the connection information are the same.
8186
mdb *sql.DB
@@ -325,7 +330,7 @@ func (t *tester) addSuccess(testSuite *XUnitTestSuite, startTime *time.Time, cnt
325330
func (t *tester) Run() error {
326331
t.preProcess()
327332
defer t.postProcess()
328-
queries, err := t.loadQueries()
333+
queries, err := t.loadQueries(t.testFileName())
329334
if err != nil {
330335
err = errors.Trace(err)
331336
t.addFailure(&testSuite, &err, 0)
@@ -338,17 +343,33 @@ func (t *tester) Run() error {
338343
return err
339344
}
340345

341-
var s string
342346
defer func() {
343347
if t.resultFD != nil {
344348
t.resultFD.Close()
345349
}
346350
}()
347351

348-
testCnt := 0
349352
startTime := time.Now()
353+
testCnt, err := t.runQueries(queries)
354+
if err != nil {
355+
return err
356+
}
357+
358+
fmt.Printf("%s: ok! %d test cases passed, take time %v s\n", t.testFileName(), testCnt, time.Since(startTime).Seconds())
359+
360+
if xmlPath != "" {
361+
t.addSuccess(&testSuite, &startTime, testCnt)
362+
}
363+
364+
return t.flushResult()
365+
}
366+
367+
func (t *tester) runQueries(queries []query) (int, error) {
368+
testCnt := 0
350369
var concurrentQueue []query
351370
var concurrentSize int
371+
var s string
372+
var err error
352373
for _, q := range queries {
353374
s = q.Query
354375
switch q.tp {
@@ -379,15 +400,15 @@ func (t *tester) Run() error {
379400
if err != nil {
380401
err = errors.Annotate(err, "Atoi failed")
381402
t.addFailure(&testSuite, &err, testCnt)
382-
return err
403+
return testCnt, err
383404
}
384405
}
385406
case Q_END_CONCURRENT:
386407
t.enableConcurrent = false
387408
if err = t.concurrentRun(concurrentQueue, concurrentSize); err != nil {
388409
err = errors.Annotate(err, fmt.Sprintf("concurrent test failed in %v", t.name))
389410
t.addFailure(&testSuite, &err, testCnt)
390-
return err
411+
return testCnt, err
391412
}
392413
t.expectedErrs = nil
393414
case Q_ERROR:
@@ -406,7 +427,7 @@ func (t *tester) Run() error {
406427
} else if err = t.execute(q); err != nil {
407428
err = errors.Annotate(err, fmt.Sprintf("sql:%v", q.Query))
408429
t.addFailure(&testSuite, &err, testCnt)
409-
return err
430+
return testCnt, err
410431
}
411432

412433
testCnt++
@@ -426,7 +447,7 @@ func (t *tester) Run() error {
426447
if err != nil {
427448
err = errors.Annotate(err, fmt.Sprintf("Could not parse column in --replace_column: sql:%v", q.Query))
428449
t.addFailure(&testSuite, &err, testCnt)
429-
return err
450+
return testCnt, err
430451
}
431452

432453
t.replaceColumn = append(t.replaceColumn, ReplaceColumn{col: colNr, replace: []byte(cols[i+1])})
@@ -473,7 +494,7 @@ func (t *tester) Run() error {
473494
r, err := t.executeStmtString(s)
474495
if err != nil {
475496
log.WithFields(log.Fields{
476-
"query": s, "line": q.Line},
497+
"query": s, "line": q.location()},
477498
).Error("failed to perform let query")
478499
return ""
479500
}
@@ -484,27 +505,59 @@ func (t *tester) Run() error {
484505
case Q_REMOVE_FILE:
485506
err = os.Remove(strings.TrimSpace(q.Query))
486507
if err != nil {
487-
return errors.Annotate(err, "failed to remove file")
508+
return testCnt, errors.Annotate(err, "failed to remove file")
488509
}
489510
case Q_REPLACE_REGEX:
490511
t.replaceRegex = nil
491512
regex, err := ParseReplaceRegex(q.Query)
492513
if err != nil {
493-
return errors.Annotate(err, fmt.Sprintf("Could not parse regex in --replace_regex: line: %d sql:%v", q.Line, q.Query))
514+
return testCnt, errors.Annotate(
515+
err, fmt.Sprintf("Could not parse regex in --replace_regex: line: %s sql:%v",
516+
q.location(), q.Query))
494517
}
495518
t.replaceRegex = regex
496-
default:
497-
log.WithFields(log.Fields{"command": q.firstWord, "arguments": q.Query, "line": q.Line}).Warn("command not implemented")
498-
}
499-
}
519+
case Q_SOURCE:
520+
fileName := strings.TrimSpace(q.Query)
521+
cwd, err := os.Getwd()
522+
if err != nil {
523+
return testCnt, err
524+
}
500525

501-
fmt.Printf("%s: ok! %d test cases passed, take time %v s\n", t.testFileName(), testCnt, time.Since(startTime).Seconds())
526+
// For security, don't allow to include files from other locations
527+
fullpath, err := filepath.Abs(fileName)
528+
if err != nil {
529+
return testCnt, err
530+
}
531+
if !strings.HasPrefix(fullpath, cwd) {
532+
return testCnt, errors.Errorf("included file %s is not prefixed with %s", fullpath, cwd)
533+
}
502534

503-
if xmlPath != "" {
504-
t.addSuccess(&testSuite, &startTime, testCnt)
505-
}
535+
// Make sure we have a useful error message if the file can't be found or isn't a regular file
536+
s, err := os.Stat(fileName)
537+
if err != nil {
538+
return testCnt, errors.Annotate(err,
539+
fmt.Sprintf("file sourced with --source doesn't exist: line %s, file: %s",
540+
q.location(), fileName))
541+
}
542+
if !s.Mode().IsRegular() {
543+
return testCnt, errors.Errorf("file sourced with --source isn't a regular file: line %s, file: %s",
544+
q.location(), fileName)
545+
}
506546

507-
return t.flushResult()
547+
// Process the queries in the file
548+
includedQueries, err := t.loadQueries(fileName)
549+
if err != nil {
550+
return testCnt, errors.Annotate(err, fmt.Sprintf("error loading queries from %s", fileName))
551+
}
552+
_, err = t.runQueries(includedQueries)
553+
if err != nil {
554+
return testCnt, err
555+
}
556+
default:
557+
log.WithFields(log.Fields{"command": q.firstWord, "arguments": q.Query, "line": q.location()}).Warn("command not implemented")
558+
}
559+
}
560+
return testCnt, nil
508561
}
509562

510563
func (t *tester) concurrentRun(concurrentQueue []query, concurrentSize int) error {
@@ -606,8 +659,8 @@ func (t *tester) concurrentExecute(querys []query, wg *sync.WaitGroup, errOccure
606659
}
607660
}
608661

609-
func (t *tester) loadQueries() ([]query, error) {
610-
data, err := os.ReadFile(t.testFileName())
662+
func (t *tester) loadQueries(fileName string) ([]query, error) {
663+
data, err := os.ReadFile(fileName)
611664
if err != nil {
612665
return nil, err
613666
}
@@ -623,18 +676,30 @@ func (t *tester) loadQueries() ([]query, error) {
623676
newStmt = true
624677
continue
625678
} else if strings.HasPrefix(s, "--") {
626-
queries = append(queries, query{Query: s, Line: i + 1})
679+
queries = append(queries, query{
680+
Query: s,
681+
Line: i + 1,
682+
File: fileName,
683+
})
627684
newStmt = true
628685
continue
629686
} else if len(s) == 0 {
630687
continue
631688
}
632689

633690
if newStmt {
634-
queries = append(queries, query{Query: s, Line: i + 1})
691+
queries = append(queries, query{
692+
Query: s,
693+
Line: i + 1,
694+
File: fileName,
695+
})
635696
} else {
636697
lastQuery := queries[len(queries)-1]
637-
lastQuery = query{Query: fmt.Sprintf("%s\n%s", lastQuery.Query, s), Line: lastQuery.Line}
698+
lastQuery = query{
699+
Query: fmt.Sprintf("%s\n%s", lastQuery.Query, s),
700+
Line: lastQuery.Line,
701+
File: fileName,
702+
}
638703
queries[len(queries)-1] = lastQuery
639704
}
640705

@@ -668,8 +733,8 @@ func (t *tester) checkExpectedError(q query, err error) error {
668733
}
669734
}
670735
if !checkErr {
671-
log.Warnf("%s:%d query succeeded, but expected error(s)! (expected errors: %s) (query: %s)",
672-
t.name, q.Line, strings.Join(t.expectedErrs, ","), q.Query)
736+
log.Warnf("%s query succeeded, but expected error(s)! (expected errors: %s) (query: %s)",
737+
q.location(), strings.Join(t.expectedErrs, ","), q.Query)
673738
return nil
674739
}
675740
return errors.Errorf("Statement succeeded, expected error(s) '%s'", strings.Join(t.expectedErrs, ","))
@@ -684,7 +749,7 @@ func (t *tester) checkExpectedError(q query, err error) error {
684749
errNo = int(innerErr.Number)
685750
}
686751
if errNo == 0 {
687-
log.Warnf("%s:%d Could not parse mysql error: %s", t.name, q.Line, err.Error())
752+
log.Warnf("%s Could not parse mysql error: %s", q.location(), err.Error())
688753
return err
689754
}
690755
for _, s := range t.expectedErrs {
@@ -696,9 +761,9 @@ func (t *tester) checkExpectedError(q query, err error) error {
696761
checkErrNo = i
697762
} else {
698763
if len(t.expectedErrs) > 1 {
699-
log.Warnf("%s:%d Unknown named error %s in --error %s", t.name, q.Line, s, strings.Join(t.expectedErrs, ","))
764+
log.Warnf("%s Unknown named error %s in --error %s", q.location(), s, strings.Join(t.expectedErrs, ","))
700765
} else {
701-
log.Warnf("%s:%d Unknown named --error %s", t.name, q.Line, s)
766+
log.Warnf("%s Unknown named --error %s", q.location(), s)
702767
}
703768
continue
704769
}
@@ -726,11 +791,11 @@ func (t *tester) checkExpectedError(q query, err error) error {
726791
}
727792
}
728793
if len(t.expectedErrs) > 1 {
729-
log.Warnf("%s:%d query failed with non expected error(s)! (%s not in %s) (err: %s) (query: %s)",
730-
t.name, q.Line, gotErrCode, strings.Join(t.expectedErrs, ","), err.Error(), q.Query)
794+
log.Warnf("%s query failed with non expected error(s)! (%s not in %s) (err: %s) (query: %s)",
795+
q.location(), gotErrCode, strings.Join(t.expectedErrs, ","), err.Error(), q.Query)
731796
} else {
732-
log.Warnf("%s:%d query failed with non expected error(s)! (%s != %s) (err: %s) (query: %s)",
733-
t.name, q.Line, gotErrCode, t.expectedErrs[0], err.Error(), q.Query)
797+
log.Warnf("%s query failed with non expected error(s)! (%s != %s) (err: %s) (query: %s)",
798+
q.location(), gotErrCode, t.expectedErrs[0], err.Error(), q.Query)
734799
}
735800
errStr := err.Error()
736801
for _, reg := range t.replaceRegex {

src/query.go

+1
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ func ParseQueries(qs ...query) ([]query, error) {
136136
q := query{}
137137
q.tp = Q_UNKNOWN
138138
q.Line = rs.Line
139+
q.File = rs.File
139140
// a valid query's length should be at least 3.
140141
if len(s) < 3 {
141142
continue

t/source.test

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
--echo first line
2+
--source include/hello.inc
3+
--echo last line

0 commit comments

Comments
 (0)