diff --git a/internal/merger/factory/factory.go b/internal/merger/factory/factory.go index 22ab6a1..cad53f5 100644 --- a/internal/merger/factory/factory.go +++ b/internal/merger/factory/factory.go @@ -117,34 +117,40 @@ func (q QuerySpec) validateGroupBy() error { } for _, c := range q.GroupBy { if !c.Validate() { - return fmt.Errorf("%w: groupby %v", ErrInvalidColumnInfo, c.Name) + return fmt.Errorf("%w: groupby %#v", ErrInvalidColumnInfo, c) } - // 清除ASC - c.Order = merger.OrderDESC - if !slice.Contains(q.Select, c) { - return fmt.Errorf("%w: groupby %v", ErrColumnNotFoundInSelectList, c.Name) + if !slice.ContainsFunc(q.Select, func(src merger.ColumnInfo) bool { return equals(src, c) }) { + return fmt.Errorf("%w: groupby %#v", ErrColumnNotFoundInSelectList, c) } } for _, c := range q.Select { - if c.AggregateFunc == "" && !slice.Contains(q.GroupBy, c) { - return fmt.Errorf("%w: 非聚合列 %v 必须出现在groupby列表中", ErrInvalidColumnInfo, c.Name) + isInGroupByList := slice.ContainsFunc(q.GroupBy, func(src merger.ColumnInfo) bool { return equals(src, c) }) + if c.AggregateFunc == "" && !isInGroupByList { + return fmt.Errorf("%w: 非聚合列 %#v 必须出现在groupby列表中", ErrInvalidColumnInfo, c) } - if c.AggregateFunc != "" && slice.Contains(q.GroupBy, c) { - return fmt.Errorf("%w: 聚合列 %v 不能出现在groupby列表中", ErrInvalidColumnInfo, c.Name) + if c.AggregateFunc != "" && isInGroupByList { + return fmt.Errorf("%w: 聚合列 %#v 不能出现在groupby列表中", ErrInvalidColumnInfo, c) } } return nil } +func equals(a, b merger.ColumnInfo) bool { + // 这里忽略Order和Distinct字段的比较 + return a.Index == b.Index && + strings.Trim(a.Name, "`") == strings.Trim(b.Name, "`") && + strings.EqualFold(a.AggregateFunc, b.AggregateFunc) && + strings.Trim(a.Alias, "`") == strings.Trim(b.Alias, "`") +} + func (q QuerySpec) validateDistinct() error { if !slice.Contains(q.Features, query.Distinct) { return nil } - // 程序走到这q.Select的长度至少为1 + // 注意: 程序走到这q.Select的长度至少为1 for _, c := range q.Select { - // case2,3 if !c.Distinct || !c.Validate() { - return fmt.Errorf("%w: distinct %v", ErrInvalidColumnInfo, c.Name) + return fmt.Errorf("%w: distinct %#v", ErrInvalidColumnInfo, c) } } return nil @@ -158,15 +164,11 @@ func (q QuerySpec) validateOrderBy() error { return fmt.Errorf("%w: orderby", ErrEmptyColumnList) } for _, c := range q.OrderBy { - if !c.Validate() { - return fmt.Errorf("%w: orderby %v", ErrInvalidColumnInfo, c.Name) + return fmt.Errorf("%w: orderby %#v", ErrInvalidColumnInfo, c) } - _, ok := slice.Find(q.Select, func(src merger.ColumnInfo) bool { - return src.Index == c.Index && src.SelectName() == c.SelectName() - }) - if !ok { - return fmt.Errorf("%w: orderby %v", ErrColumnNotFoundInSelectList, c.Name) + if !slice.ContainsFunc(q.Select, func(src merger.ColumnInfo) bool { return equals(src, c) }) { + return fmt.Errorf("%w: orderby %#v", ErrColumnNotFoundInSelectList, c) } } return nil