From 9a176fd6513844663376ae67265fcd1b0691e4c8 Mon Sep 17 00:00:00 2001 From: Song Gao Date: Mon, 14 Aug 2023 18:42:38 +0800 Subject: [PATCH] fix: fix alias cycle reference (#2184) Signed-off-by: yisaer --- internal/topo/planner/analyzer.go | 31 ++++++++++++++++++++------ internal/topo/planner/analyzer_test.go | 4 ++++ 2 files changed, 28 insertions(+), 7 deletions(-) diff --git a/internal/topo/planner/analyzer.go b/internal/topo/planner/analyzer.go index f4c74629fb..3a0f382420 100644 --- a/internal/topo/planner/analyzer.go +++ b/internal/topo/planner/analyzer.go @@ -256,13 +256,9 @@ func checkAliasReferenceCycle(s *ast.SelectStatement) bool { _, ok := aliasRef[f.Name] if ok { aliasRef[field.AName][f.Name] = struct{}{} - v, ok1 := aliasRef[f.Name] - if ok1 { - _, ok2 := v[field.AName] - if ok2 { - hasCycleAlias = true - return false - } + if dfsRef(aliasRef, map[string]struct{}{}, f.Name, field.AName) { + hasCycleAlias = true + return false } } } @@ -277,6 +273,27 @@ func checkAliasReferenceCycle(s *ast.SelectStatement) bool { return false } +func dfsRef(aliasRef map[string]map[string]struct{}, walked map[string]struct{}, currentName, targetName string) bool { + defer func() { + walked[currentName] = struct{}{} + }() + for refName := range aliasRef[currentName] { + if refName == targetName { + return true + } + } + for name := range aliasRef[currentName] { + _, ok := walked[name] + if ok { + continue + } + if dfsRef(aliasRef, walked, name, targetName) { + return true + } + } + return false +} + func aliasFieldTopoSort(s *ast.SelectStatement, streamStmts []*streamInfo) { nonAliasFields := make([]ast.Field, 0) aliasDegreeMap := make(map[string]*aliasTopoDegree) diff --git a/internal/topo/planner/analyzer_test.go b/internal/topo/planner/analyzer_test.go index 8a4d4cc766..5840206f08 100644 --- a/internal/topo/planner/analyzer_test.go +++ b/internal/topo/planner/analyzer_test.go @@ -142,6 +142,10 @@ var tests = []struct { sql: "select a + 1 as b, b + 1 as a from src1", r: newErrorStruct("select fields have cycled alias"), }, + { + sql: "select a + 1 as b, b * 2 as c, c + 1 as a from src1", + r: newErrorStruct("select fields have cycled alias"), + }, //{ // 19 already captured in parser // sql: `SELECT * FROM src1 GROUP BY SlidingWindow(ss,5) Over (WHEN abs(sum(a)) > 1) HAVING last_agg_hit_count() < 3`, // r: newErrorStruct("error compile sql: Not allowed to call aggregate functions in GROUP BY clause."),