From 102eb0c226c4f25bc9e28b64049861616a74bf63 Mon Sep 17 00:00:00 2001 From: gnehzza Date: Thu, 28 Sep 2023 04:33:08 -0400 Subject: [PATCH] changes to address issue 148 --- src/hooks/executor_start.c | 39 ++++++++++++++++++++++++++++++---- src/hooks/utils.c | 43 ++++++++++++++++++++++++++++++++++++++ src/hooks/utils.h | 1 + 3 files changed, 79 insertions(+), 4 deletions(-) diff --git a/src/hooks/executor_start.c b/src/hooks/executor_start.c index 866cae499..0b3958ec2 100644 --- a/src/hooks/executor_start.c +++ b/src/hooks/executor_start.c @@ -7,6 +7,8 @@ #include #include #include +#include +#include #include "../hnsw/utils.h" #include "plan_tree_walker.h" @@ -17,12 +19,16 @@ ExecutorStart_hook_type original_ExecutorStart_hook = NULL; typedef struct { List *oidList; + List *distanceFunctionOidList; + Oid defaultDistanceFunction; bool isIndexScan; + bool isSequentialScan; } OperatorUsedCorrectlyContext; static bool operator_used_incorrectly_walker(Node *node, void *context) { OperatorUsedCorrectlyContext *context_typed = (OperatorUsedCorrectlyContext *)context; + if(node == NULL) return false; if(IsA(node, IndexScan)) { context_typed->isIndexScan = true; @@ -30,12 +36,27 @@ static bool operator_used_incorrectly_walker(Node *node, void *context) context_typed->isIndexScan = false; return status; } + if (IsA(node, SeqScan)) { + context_typed->isSequentialScan = true; + bool status = plan_tree_walker((Plan *)node, operator_used_incorrectly_walker, context); + context_typed->isSequentialScan = false; + return status; + } + if(IsA(node, OpExpr)) { OpExpr *opExpr = (OpExpr *)node; - if(list_member_oid(context_typed->oidList, opExpr->opno) && !context_typed->isIndexScan) { + if(list_member_oid(context_typed->distanceFunctionOidList, opExpr->opno)) { + context_typed->defaultDistanceFunction = opExpr->opfuncid; + } + if(list_member_oid(context_typed->oidList, opExpr->opno) && !context_typed->isIndexScan && !context_typed->isSequentialScan) { return true; } + if(list_member_oid(context_typed->oidList, opExpr->opno) && context_typed->isSequentialScan) { + elog(WARNING, "A sequential scan is being used. This might be because queried column doesn't have an index, or the distance function of the column is a different one."); + opExpr->opfuncid = context_typed->defaultDistanceFunction; + } } + if(IsA(node, List)) { List *list = (List *)node; ListCell *lc; @@ -53,11 +74,19 @@ static bool operator_used_incorrectly_walker(Node *node, void *context) return false; } -static void validate_operator_usage(Plan *plan, List *oidList) +static void validate_operator_usage(Plan *plan, List *oidList, void *distanceFunctionOidList) { OperatorUsedCorrectlyContext context; context.oidList = oidList; + context.distanceFunctionOidList = distanceFunctionOidList; + FuncCandidateList funcList1 = FuncnameGetCandidates(list_make1(makeString("l2sq_dist")), -1, NIL, false, false, false, true); + if (funcList1 != NULL) { + context.defaultDistanceFunction = funcList1->oid; + } else { + elog(WARNING, "Default distance function was not found."); + } context.isIndexScan = false; + context.isSequentialScan = false; if(operator_used_incorrectly_walker((Node *)plan, (void *)&context)) { elog(ERROR, "Operator <-> has no standalone meaning and is reserved for use in vector index lookups only"); } @@ -70,16 +99,18 @@ void ExecutorStart_hook_with_operator_check(QueryDesc *queryDesc, int eflags) } List *oidList = ldb_get_operator_oids(); + List *distanceFunctionOidList = ldb_get_distance_function_oids(); if(oidList != NULL) { // oidList will be NULL if LanternDB extension is not fully initialized // e.g. in statements executed as a result of CREATE EXTENSION ... statement - validate_operator_usage(queryDesc->plannedstmt->planTree, oidList); + validate_operator_usage(queryDesc->plannedstmt->planTree, oidList, distanceFunctionOidList); ListCell *lc; foreach(lc, queryDesc->plannedstmt->subplans) { Plan *subplan = (Plan *)lfirst(lc); - validate_operator_usage(subplan, oidList); + validate_operator_usage(subplan, oidList, distanceFunctionOidList); } list_free(oidList); + list_free(distanceFunctionOidList); } standard_ExecutorStart(queryDesc, eflags); diff --git a/src/hooks/utils.c b/src/hooks/utils.c index 34ed1adbc..8bedf2453 100644 --- a/src/hooks/utils.c +++ b/src/hooks/utils.c @@ -4,6 +4,10 @@ #include #include #include +#include +#include + + List *ldb_get_operator_oids() { @@ -23,5 +27,44 @@ List *ldb_get_operator_oids() list_free(nameList); + return oidList; +} + +List *ldb_get_distance_function_oids() +{ + List *oidList = NIL; + + Oid l2sq_dist_oid = InvalidOid; + Oid cos_dist_oid = InvalidOid; + Oid hamming_dist_oid = InvalidOid; + FuncCandidateList funcList1 = FuncnameGetCandidates(list_make1(makeString("l2sq_dist")), -1, NIL, false, false, false, true); + if (funcList1 != NULL) { + l2sq_dist_oid = funcList1->oid; + } else { + elog(WARNING, "l2sq_dist was not found."); + } + FuncCandidateList funcList2 = FuncnameGetCandidates(list_make1(makeString("cos_dist")), -1, NIL, false, false, false, true); + if (funcList2 != NULL) { + cos_dist_oid = funcList2->oid; + } else { + elog(WARNING, "cos_dist was not found."); + } + FuncCandidateList funcList3 = FuncnameGetCandidates(list_make1(makeString("hamming_dist")), -1, NIL, false, false, false, true); + if (funcList2 != NULL) { + hamming_dist_oid = funcList3->oid; + } else { + elog(WARNING, "hamming_dist was not found."); + } + + if(OidIsValid(l2sq_dist_oid)) { + oidList = lappend_oid(oidList, l2sq_dist_oid); + } + if(OidIsValid(cos_dist_oid)) { + oidList = lappend_oid(oidList, cos_dist_oid); + } + if(OidIsValid(hamming_dist_oid)) { + oidList = lappend_oid(oidList, hamming_dist_oid); + } + return oidList; } \ No newline at end of file diff --git a/src/hooks/utils.h b/src/hooks/utils.h index ea3516b3c..772935b73 100644 --- a/src/hooks/utils.h +++ b/src/hooks/utils.h @@ -6,5 +6,6 @@ #include List *ldb_get_operator_oids(); +List *ldb_get_distance_function_oids(); #endif // LDB_HOOKS_UTILS_H \ No newline at end of file