diff --git a/ydb/library/yql/dq/opt/dq_opt_hypergraph_ut.cpp b/ydb/library/yql/dq/opt/dq_opt_hypergraph_ut.cpp index 2ba71afe5d40..9a70bd4ba251 100644 --- a/ydb/library/yql/dq/opt/dq_opt_hypergraph_ut.cpp +++ b/ydb/library/yql/dq/opt/dq_opt_hypergraph_ut.cpp @@ -423,6 +423,22 @@ Y_UNIT_TEST_SUITE(HypergraphBuild) { } } + Y_UNIT_TEST(ManyCondsBetweenJoinForTransitiveClosure) { + auto join = Join(Join("A", "B", "A.PUDGE=B.PUDGE,A.DOTA=B.DOTA"), "C", "A.PUDGE=C.PUDGE,A.DOTA=C.DOTA"); + + auto graph = MakeJoinHypergraph(join); + Cout << graph.String() << Endl; + + auto B = graph.GetNodesByRelNames({"B"}); + auto C = graph.GetNodesByRelNames({"C"}); + UNIT_ASSERT(graph.FindEdgeBetween(B, C)); + + { + auto optimizedJoin = Enumerate(join, TOptimizerHints::Parse("Rows(B C # 0)")); + UNIT_ASSERT(HaveSameConditionCount(optimizedJoin, join)); + } + } + auto MakeClique(size_t size) { std::shared_ptr root = Join("R0", "R1", "R0.id=R1.id"); diff --git a/ydb/library/yql/dq/opt/dq_opt_join_hypergraph.h b/ydb/library/yql/dq/opt/dq_opt_join_hypergraph.h index b91d7f6af614..98ec18bfa98d 100644 --- a/ydb/library/yql/dq/opt/dq_opt_join_hypergraph.h +++ b/ydb/library/yql/dq/opt/dq_opt_join_hypergraph.h @@ -481,8 +481,21 @@ class TTransitiveClosureConstructor { auto iNode = Graph_.GetNodesByRelNames({joinCondById[i].RelName}); auto jNode = Graph_.GetNodesByRelNames({joinCondById[j].RelName}); - if (Graph_.FindEdgeBetween(iNode, jNode)) { - continue; + if (auto* maybeEdge = Graph_.FindEdgeBetween(iNode, jNode)) { + auto addUniqueKey = [](auto& vector, const auto& key) { + if (std::find(vector.begin(), vector.end(), key) == vector.end()) { + vector.push_back(key); + } + }; + + auto& revEdge = Graph_.GetEdge(maybeEdge->ReversedEdgeId); + addUniqueKey(revEdge.LeftJoinKeys, joinCondById[j]); + addUniqueKey(revEdge.RightJoinKeys, joinCondById[i]); + + auto& edge = Graph_.GetEdge(revEdge.ReversedEdgeId); + addUniqueKey(edge.LeftJoinKeys, joinCondById[i]); + addUniqueKey(edge.RightJoinKeys, joinCondById[j]); + continue; } Graph_.AddEdge(THyperedge(iNode, jNode, InnerJoin, false, false, true, {joinCondById[i]}, {joinCondById[j]}));