Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Changes to address issue 148 #185

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 35 additions & 4 deletions src/hooks/executor_start.c
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
#include <nodes/nodes.h>
#include <nodes/pg_list.h>
#include <nodes/plannodes.h>
#include <utils/builtins.h>
#include <catalog/namespace.h>

#include "../hnsw/utils.h"
#include "plan_tree_walker.h"
Expand All @@ -17,25 +19,44 @@ ExecutorStart_hook_type original_ExecutorStart_hook = NULL;
// State threaded through operator_used_incorrectly_walker while it walks a plan tree.
typedef struct
{
// Operator OIDs obtained from ldb_get_operator_oids(); the walker tests
// OpExpr->opno for membership in this list.
List *oidList;
// OIDs obtained from ldb_get_distance_function_oids(); the walker also tests
// OpExpr->opno for membership in this list.
// NOTE(review): this list holds function OIDs but is compared against an
// operator OID (opno) — confirm that this comparison is intended.
List *distanceFunctionOidList;
// Function OID that the walker writes into OpExpr->opfuncid when a listed
// operator is found under a sequential scan. Seeded by validate_operator_usage
// from a lookup of "l2sq_dist"; the walker may overwrite it with an
// OpExpr's opfuncid.
Oid defaultDistanceFunction;
// True only while the walker is descending into an IndexScan node.
bool isIndexScan;
// True only while the walker is descending into a SeqScan node.
bool isSequentialScan;
} OperatorUsedCorrectlyContext;

static bool operator_used_incorrectly_walker(Node *node, void *context)
{
OperatorUsedCorrectlyContext *context_typed = (OperatorUsedCorrectlyContext *)context;

if(node == NULL) return false;
if(IsA(node, IndexScan)) {
context_typed->isIndexScan = true;
bool status = plan_tree_walker((Plan *)node, operator_used_incorrectly_walker, context);
context_typed->isIndexScan = false;
return status;
}
if (IsA(node, SeqScan)) {
context_typed->isSequentialScan = true;
bool status = plan_tree_walker((Plan *)node, operator_used_incorrectly_walker, context);
context_typed->isSequentialScan = false;
return status;
}

if(IsA(node, OpExpr)) {
OpExpr *opExpr = (OpExpr *)node;
if(list_member_oid(context_typed->oidList, opExpr->opno) && !context_typed->isIndexScan) {
if(list_member_oid(context_typed->distanceFunctionOidList, opExpr->opno)) {
context_typed->defaultDistanceFunction = opExpr->opfuncid;
}
if(list_member_oid(context_typed->oidList, opExpr->opno) && !context_typed->isIndexScan && !context_typed->isSequentialScan) {
return true;
}
if(list_member_oid(context_typed->oidList, opExpr->opno) && context_typed->isSequentialScan) {
elog(WARNING, "A sequential scan is being used. This might be because queried column doesn't have an index, or the distance function of the column is a different one.");
opExpr->opfuncid = context_typed->defaultDistanceFunction;
}
}

if(IsA(node, List)) {
List *list = (List *)node;
ListCell *lc;
Expand All @@ -53,11 +74,19 @@ static bool operator_used_incorrectly_walker(Node *node, void *context)
return false;
}

static void validate_operator_usage(Plan *plan, List *oidList)
static void validate_operator_usage(Plan *plan, List *oidList, void *distanceFunctionOidList)
{
OperatorUsedCorrectlyContext context;
context.oidList = oidList;
context.distanceFunctionOidList = distanceFunctionOidList;
FuncCandidateList funcList1 = FuncnameGetCandidates(list_make1(makeString("l2sq_dist")), -1, NIL, false, false, false, true);
if (funcList1 != NULL) {
context.defaultDistanceFunction = funcList1->oid;
} else {
elog(WARNING, "Default distance function was not found.");
}
context.isIndexScan = false;
context.isSequentialScan = false;
if(operator_used_incorrectly_walker((Node *)plan, (void *)&context)) {
elog(ERROR, "Operator <-> has no standalone meaning and is reserved for use in vector index lookups only");
}
Expand All @@ -70,16 +99,18 @@ void ExecutorStart_hook_with_operator_check(QueryDesc *queryDesc, int eflags)
}

List *oidList = ldb_get_operator_oids();
List *distanceFunctionOidList = ldb_get_distance_function_oids();
if(oidList != NULL) {
// oidList will be NULL if LanternDB extension is not fully initialized
// e.g. in statements executed as a result of CREATE EXTENSION ... statement
validate_operator_usage(queryDesc->plannedstmt->planTree, oidList);
validate_operator_usage(queryDesc->plannedstmt->planTree, oidList, distanceFunctionOidList);
ListCell *lc;
foreach(lc, queryDesc->plannedstmt->subplans) {
Plan *subplan = (Plan *)lfirst(lc);
validate_operator_usage(subplan, oidList);
validate_operator_usage(subplan, oidList, distanceFunctionOidList);
}
list_free(oidList);
list_free(distanceFunctionOidList);
}

standard_ExecutorStart(queryDesc, eflags);
Expand Down
43 changes: 43 additions & 0 deletions src/hooks/utils.c
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
#include <nodes/makefuncs.h>
#include <nodes/pg_list.h>
#include <parser/parse_oper.h>
#include <utils/builtins.h>
#include <catalog/namespace.h>



List *ldb_get_operator_oids()
{
Expand All @@ -23,5 +27,44 @@ List *ldb_get_operator_oids()

list_free(nameList);

return oidList;
}

/*
 * Resolve a function by its unqualified name and return its OID.
 *
 * Emits a WARNING ("<name> was not found.") and returns InvalidOid when
 * FuncnameGetCandidates() yields no candidate. When candidates are returned,
 * the OID of the first entry in the candidate list is used.
 */
static Oid ldb_lookup_distance_function_oid(char *funcName)
{
    FuncCandidateList candidates
        = FuncnameGetCandidates(list_make1(makeString(funcName)), -1, NIL, false, false, false, true);

    if(candidates == NULL) {
        elog(WARNING, "%s was not found.", funcName);
        return InvalidOid;
    }

    return candidates->oid;
}

/*
 * Build a list with the OIDs of the distance functions l2sq_dist, cos_dist
 * and hamming_dist, in that order.
 *
 * A function that cannot be resolved is reported with a WARNING and left out
 * of the result, so the returned list may be shorter than three entries or
 * NIL. The list is newly built on every call.
 */
List *ldb_get_distance_function_oids()
{
    static char *distanceFunctionNames[] = {"l2sq_dist", "cos_dist", "hamming_dist"};
    List        *oidList = NIL;

    for(size_t i = 0; i < sizeof(distanceFunctionNames) / sizeof(distanceFunctionNames[0]); i++) {
        /*
         * Every name goes through the same lookup-and-check path. The earlier
         * hand-unrolled version tested the cos_dist result before dereferencing
         * the hamming_dist result, which could dereference NULL.
         */
        Oid functionOid = ldb_lookup_distance_function_oid(distanceFunctionNames[ i ]);

        if(OidIsValid(functionOid)) {
            oidList = lappend_oid(oidList, functionOid);
        }
    }

    return oidList;
}
1 change: 1 addition & 0 deletions src/hooks/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,6 @@
#include <nodes/pg_list.h>

List *ldb_get_operator_oids();
List *ldb_get_distance_function_oids();

#endif // LDB_HOOKS_UTILS_H