@@ -72,10 +72,15 @@ const (
72
72
type query struct {
73
73
firstWord string
74
74
Query string
75
+ File string
75
76
Line int
76
77
tp int
77
78
}
78
79
80
+ func (q * query ) location () string {
81
+ return fmt .Sprintf ("%s:%d" , q .File , q .Line )
82
+ }
83
+
79
84
type Conn struct {
80
85
// DB might be a shared one by multiple Conn, if the connection information are the same.
81
86
mdb * sql.DB
@@ -325,7 +330,7 @@ func (t *tester) addSuccess(testSuite *XUnitTestSuite, startTime *time.Time, cnt
325
330
func (t * tester ) Run () error {
326
331
t .preProcess ()
327
332
defer t .postProcess ()
328
- queries , err := t .loadQueries ()
333
+ queries , err := t .loadQueries (t . testFileName () )
329
334
if err != nil {
330
335
err = errors .Trace (err )
331
336
t .addFailure (& testSuite , & err , 0 )
@@ -338,17 +343,33 @@ func (t *tester) Run() error {
338
343
return err
339
344
}
340
345
341
- var s string
342
346
defer func () {
343
347
if t .resultFD != nil {
344
348
t .resultFD .Close ()
345
349
}
346
350
}()
347
351
348
- testCnt := 0
349
352
startTime := time .Now ()
353
+ testCnt , err := t .runQueries (queries )
354
+ if err != nil {
355
+ return err
356
+ }
357
+
358
+ fmt .Printf ("%s: ok! %d test cases passed, take time %v s\n " , t .testFileName (), testCnt , time .Since (startTime ).Seconds ())
359
+
360
+ if xmlPath != "" {
361
+ t .addSuccess (& testSuite , & startTime , testCnt )
362
+ }
363
+
364
+ return t .flushResult ()
365
+ }
366
+
367
+ func (t * tester ) runQueries (queries []query ) (int , error ) {
368
+ testCnt := 0
350
369
var concurrentQueue []query
351
370
var concurrentSize int
371
+ var s string
372
+ var err error
352
373
for _ , q := range queries {
353
374
s = q .Query
354
375
switch q .tp {
@@ -379,15 +400,15 @@ func (t *tester) Run() error {
379
400
if err != nil {
380
401
err = errors .Annotate (err , "Atoi failed" )
381
402
t .addFailure (& testSuite , & err , testCnt )
382
- return err
403
+ return testCnt , err
383
404
}
384
405
}
385
406
case Q_END_CONCURRENT :
386
407
t .enableConcurrent = false
387
408
if err = t .concurrentRun (concurrentQueue , concurrentSize ); err != nil {
388
409
err = errors .Annotate (err , fmt .Sprintf ("concurrent test failed in %v" , t .name ))
389
410
t .addFailure (& testSuite , & err , testCnt )
390
- return err
411
+ return testCnt , err
391
412
}
392
413
t .expectedErrs = nil
393
414
case Q_ERROR :
@@ -406,7 +427,7 @@ func (t *tester) Run() error {
406
427
} else if err = t .execute (q ); err != nil {
407
428
err = errors .Annotate (err , fmt .Sprintf ("sql:%v" , q .Query ))
408
429
t .addFailure (& testSuite , & err , testCnt )
409
- return err
430
+ return testCnt , err
410
431
}
411
432
412
433
testCnt ++
@@ -426,7 +447,7 @@ func (t *tester) Run() error {
426
447
if err != nil {
427
448
err = errors .Annotate (err , fmt .Sprintf ("Could not parse column in --replace_column: sql:%v" , q .Query ))
428
449
t .addFailure (& testSuite , & err , testCnt )
429
- return err
450
+ return testCnt , err
430
451
}
431
452
432
453
t .replaceColumn = append (t .replaceColumn , ReplaceColumn {col : colNr , replace : []byte (cols [i + 1 ])})
@@ -473,7 +494,7 @@ func (t *tester) Run() error {
473
494
r , err := t .executeStmtString (s )
474
495
if err != nil {
475
496
log .WithFields (log.Fields {
476
- "query" : s , "line" : q .Line },
497
+ "query" : s , "line" : q .location () },
477
498
).Error ("failed to perform let query" )
478
499
return ""
479
500
}
@@ -484,27 +505,59 @@ func (t *tester) Run() error {
484
505
case Q_REMOVE_FILE :
485
506
err = os .Remove (strings .TrimSpace (q .Query ))
486
507
if err != nil {
487
- return errors .Annotate (err , "failed to remove file" )
508
+ return testCnt , errors .Annotate (err , "failed to remove file" )
488
509
}
489
510
case Q_REPLACE_REGEX :
490
511
t .replaceRegex = nil
491
512
regex , err := ParseReplaceRegex (q .Query )
492
513
if err != nil {
493
- return errors .Annotate (err , fmt .Sprintf ("Could not parse regex in --replace_regex: line: %d sql:%v" , q .Line , q .Query ))
514
+ return testCnt , errors .Annotate (
515
+ err , fmt .Sprintf ("Could not parse regex in --replace_regex: line: %s sql:%v" ,
516
+ q .location (), q .Query ))
494
517
}
495
518
t .replaceRegex = regex
496
- default :
497
- log .WithFields (log.Fields {"command" : q .firstWord , "arguments" : q .Query , "line" : q .Line }).Warn ("command not implemented" )
498
- }
499
- }
519
+ case Q_SOURCE :
520
+ fileName := strings .TrimSpace (q .Query )
521
+ cwd , err := os .Getwd ()
522
+ if err != nil {
523
+ return testCnt , err
524
+ }
500
525
501
- fmt .Printf ("%s: ok! %d test cases passed, take time %v s\n " , t .testFileName (), testCnt , time .Since (startTime ).Seconds ())
526
+ // For security, don't allow to include files from other locations
527
+ fullpath , err := filepath .Abs (fileName )
528
+ if err != nil {
529
+ return testCnt , err
530
+ }
531
+ if ! strings .HasPrefix (fullpath , cwd ) {
532
+ return testCnt , errors .Errorf ("included file %s is not prefixed with %s" , fullpath , cwd )
533
+ }
502
534
503
- if xmlPath != "" {
504
- t .addSuccess (& testSuite , & startTime , testCnt )
505
- }
535
+ // Make sure we have a useful error message if the file can't be found or isn't a regular file
536
+ s , err := os .Stat (fileName )
537
+ if err != nil {
538
+ return testCnt , errors .Annotate (err ,
539
+ fmt .Sprintf ("file sourced with --source doesn't exist: line %s, file: %s" ,
540
+ q .location (), fileName ))
541
+ }
542
+ if ! s .Mode ().IsRegular () {
543
+ return testCnt , errors .Errorf ("file sourced with --source isn't a regular file: line %s, file: %s" ,
544
+ q .location (), fileName )
545
+ }
506
546
507
- return t .flushResult ()
547
+ // Process the queries in the file
548
+ includedQueries , err := t .loadQueries (fileName )
549
+ if err != nil {
550
+ return testCnt , errors .Annotate (err , fmt .Sprintf ("error loading queries from %s" , fileName ))
551
+ }
552
+ _ , err = t .runQueries (includedQueries )
553
+ if err != nil {
554
+ return testCnt , err
555
+ }
556
+ default :
557
+ log .WithFields (log.Fields {"command" : q .firstWord , "arguments" : q .Query , "line" : q .location ()}).Warn ("command not implemented" )
558
+ }
559
+ }
560
+ return testCnt , nil
508
561
}
509
562
510
563
func (t * tester ) concurrentRun (concurrentQueue []query , concurrentSize int ) error {
@@ -606,8 +659,8 @@ func (t *tester) concurrentExecute(querys []query, wg *sync.WaitGroup, errOccure
606
659
}
607
660
}
608
661
609
- func (t * tester ) loadQueries () ([]query , error ) {
610
- data , err := os .ReadFile (t . testFileName () )
662
+ func (t * tester ) loadQueries (fileName string ) ([]query , error ) {
663
+ data , err := os .ReadFile (fileName )
611
664
if err != nil {
612
665
return nil , err
613
666
}
@@ -623,18 +676,30 @@ func (t *tester) loadQueries() ([]query, error) {
623
676
newStmt = true
624
677
continue
625
678
} else if strings .HasPrefix (s , "--" ) {
626
- queries = append (queries , query {Query : s , Line : i + 1 })
679
+ queries = append (queries , query {
680
+ Query : s ,
681
+ Line : i + 1 ,
682
+ File : fileName ,
683
+ })
627
684
newStmt = true
628
685
continue
629
686
} else if len (s ) == 0 {
630
687
continue
631
688
}
632
689
633
690
if newStmt {
634
- queries = append (queries , query {Query : s , Line : i + 1 })
691
+ queries = append (queries , query {
692
+ Query : s ,
693
+ Line : i + 1 ,
694
+ File : fileName ,
695
+ })
635
696
} else {
636
697
lastQuery := queries [len (queries )- 1 ]
637
- lastQuery = query {Query : fmt .Sprintf ("%s\n %s" , lastQuery .Query , s ), Line : lastQuery .Line }
698
+ lastQuery = query {
699
+ Query : fmt .Sprintf ("%s\n %s" , lastQuery .Query , s ),
700
+ Line : lastQuery .Line ,
701
+ File : fileName ,
702
+ }
638
703
queries [len (queries )- 1 ] = lastQuery
639
704
}
640
705
@@ -668,8 +733,8 @@ func (t *tester) checkExpectedError(q query, err error) error {
668
733
}
669
734
}
670
735
if ! checkErr {
671
- log .Warnf ("%s:%d query succeeded, but expected error(s)! (expected errors: %s) (query: %s)" ,
672
- t . name , q . Line , strings .Join (t .expectedErrs , "," ), q .Query )
736
+ log .Warnf ("%s query succeeded, but expected error(s)! (expected errors: %s) (query: %s)" ,
737
+ q . location () , strings .Join (t .expectedErrs , "," ), q .Query )
673
738
return nil
674
739
}
675
740
return errors .Errorf ("Statement succeeded, expected error(s) '%s'" , strings .Join (t .expectedErrs , "," ))
@@ -684,7 +749,7 @@ func (t *tester) checkExpectedError(q query, err error) error {
684
749
errNo = int (innerErr .Number )
685
750
}
686
751
if errNo == 0 {
687
- log .Warnf ("%s:%d Could not parse mysql error: %s" , t . name , q . Line , err .Error ())
752
+ log .Warnf ("%s Could not parse mysql error: %s" , q . location () , err .Error ())
688
753
return err
689
754
}
690
755
for _ , s := range t .expectedErrs {
@@ -696,9 +761,9 @@ func (t *tester) checkExpectedError(q query, err error) error {
696
761
checkErrNo = i
697
762
} else {
698
763
if len (t .expectedErrs ) > 1 {
699
- log .Warnf ("%s:%d Unknown named error %s in --error %s" , t . name , q . Line , s , strings .Join (t .expectedErrs , "," ))
764
+ log .Warnf ("%s Unknown named error %s in --error %s" , q . location () , s , strings .Join (t .expectedErrs , "," ))
700
765
} else {
701
- log .Warnf ("%s:%d Unknown named --error %s" , t . name , q . Line , s )
766
+ log .Warnf ("%s Unknown named --error %s" , q . location () , s )
702
767
}
703
768
continue
704
769
}
@@ -726,11 +791,11 @@ func (t *tester) checkExpectedError(q query, err error) error {
726
791
}
727
792
}
728
793
if len (t .expectedErrs ) > 1 {
729
- log .Warnf ("%s:%d query failed with non expected error(s)! (%s not in %s) (err: %s) (query: %s)" ,
730
- t . name , q . Line , gotErrCode , strings .Join (t .expectedErrs , "," ), err .Error (), q .Query )
794
+ log .Warnf ("%s query failed with non expected error(s)! (%s not in %s) (err: %s) (query: %s)" ,
795
+ q . location () , gotErrCode , strings .Join (t .expectedErrs , "," ), err .Error (), q .Query )
731
796
} else {
732
- log .Warnf ("%s:%d query failed with non expected error(s)! (%s != %s) (err: %s) (query: %s)" ,
733
- t . name , q . Line , gotErrCode , t .expectedErrs [0 ], err .Error (), q .Query )
797
+ log .Warnf ("%s query failed with non expected error(s)! (%s != %s) (err: %s) (query: %s)" ,
798
+ q . location () , gotErrCode , t .expectedErrs [0 ], err .Error (), q .Query )
734
799
}
735
800
errStr := err .Error ()
736
801
for _ , reg := range t .replaceRegex {
0 commit comments