Skip to content

Commit

Permalink
Add graph projection (#4630)
Browse files Browse the repository at this point in the history
* Add graph projection

* Run clang-format

---------

Co-authored-by: CI Bot <[email protected]>
  • Loading branch information
andyfengHKU and andyfengHKU authored Dec 15, 2024
1 parent 2bb28c2 commit 5460249
Show file tree
Hide file tree
Showing 32 changed files with 3,466 additions and 4,088 deletions.
40 changes: 26 additions & 14 deletions extension/fts/src/function/query_fts_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,15 @@ static std::unique_ptr<TableFuncBindData> bindFunc(ClientContext* context,
std::move(columnTypes), std::move(columnNames), std::move(config));
}

static std::unique_ptr<QueryResult> runQuery(main::ClientContext* context, std::string query) {
auto result =
context->queryInternal(query, "", false /* enumerateAllPlans*/, std::nullopt /* queryID*/);
if (!result->isSuccess()) {
throw RuntimeException(result->getErrorMessage());
}
return result;
}

static common::offset_t tableFunc(TableFuncInput& data, TableFuncOutput& output) {
// TODO(Xiyang/Ziyi): Currently we don't have a dedicated planner for queryFTS, so
// we need a wrapper call function to CALL the actual GDS function.
Expand All @@ -97,27 +106,30 @@ static common::offset_t tableFunc(TableFuncInput& data, TableFuncOutput& output)
auto avgDocLen = bindData.entry.getAvgDocLen();
auto query = common::stringFormat("UNWIND tokenize('{}') AS tk RETURN COUNT(DISTINCT tk);",
actualQuery);
auto numTermsInQuery = data.context->clientContext
->queryInternal(query, "" /* encodedJoin */,
false /* enumerateAllPlans */, std::nullopt /* queryID */)
->getNext()
->getValue(0)
->toString();
query = common::stringFormat("PROJECT GRAPH PK (`{}`, `{}`, `{}`) "
"UNWIND tokenize('{}') AS tk "
auto clientContext = data.context->clientContext;
auto result = runQuery(clientContext, query);
auto numTermsInQuery = result->getNext()->getValue(0)->toString();
// Project graph
query = stringFormat("CALL create_project_graph('PK', ['{}', '{}'], ['{}'])",
bindData.getTermsTableName(), bindData.getDocsTableName(),
bindData.getAppearsInTableName());
runQuery(clientContext, query);
// Compute score
query = common::stringFormat("UNWIND tokenize('{}') AS tk "
"WITH collect(stem(tk, '{}')) AS keywords "
"MATCH (a:`{}`) "
"WHERE list_contains(keywords, a.term) "
"CALL QFTS(PK, a, {}, {}, cast({} as UINT64), {}, {}, {}) "
"MATCH (p:`{}`) "
"WHERE _node.docID = offset(id(p)) "
"RETURN p, score",
bindData.getTermsTableName(), bindData.getDocsTableName(),
bindData.getAppearsInTableName(), actualQuery, bindData.entry.getFTSConfig().stemmer,
bindData.getTermsTableName(), bindData.config.k, bindData.config.b, numDocs, avgDocLen,
numTermsInQuery, bindData.config.isConjunctive ? "true" : "false", bindData.tableName);
localState->result = data.context->clientContext->queryInternal(query, "", false,
std::nullopt /* queryID */);
actualQuery, bindData.entry.getFTSConfig().stemmer, bindData.getTermsTableName(),
bindData.config.k, bindData.config.b, numDocs, avgDocLen, numTermsInQuery,
bindData.config.isConjunctive ? "true" : "false", bindData.tableName);
localState->result = runQuery(clientContext, query);
// Remove project graph
query = stringFormat("CALL drop_project_graph('PK')");
runQuery(clientContext, query);
}
if (localState->numRowsOutput >= localState->result->getNumTuples()) {
return 0;
Expand Down
19 changes: 2 additions & 17 deletions scripts/antlr4/Cypher.g4
Original file line number Diff line number Diff line change
Expand Up @@ -438,13 +438,7 @@ kU_InstallExtension
: INSTALL SP oC_Variable ;

oC_Query
: (kU_ProjectGraph SP? )? oC_RegularQuery ;

kU_ProjectGraph
: PROJECT SP GRAPH SP oC_SchemaName SP? '(' SP? kU_GraphProjectionTableItems SP? ')' ;

kU_GraphProjectionTableItems
: kU_GraphProjectionTableItem ( SP? ',' SP? kU_GraphProjectionTableItem )* ;
: oC_RegularQuery ;

oC_RegularQuery
: oC_SingleQuery ( SP? oC_Union )*
Expand Down Expand Up @@ -489,16 +483,7 @@ kU_LoadFrom
: LOAD ( SP WITH SP HEADERS SP? '(' SP? kU_ColumnDefinitions SP? ')' )? SP FROM SP kU_ScanSource (SP? kU_ParsingOptions)? (SP? oC_Where)? ;

kU_InQueryCall
: ( kU_ProjectGraph SP? )? CALL SP oC_FunctionInvocation (SP? oC_Where)? ;

kU_GraphProjectionTableItem
: oC_SchemaName ( SP? '{' SP? kU_GraphProjectionColumnItems SP? '}' )? ;

kU_GraphProjectionColumnItems
: kU_GraphProjectionColumnItem ( SP? ',' SP? kU_GraphProjectionColumnItem )* ;

kU_GraphProjectionColumnItem
: oC_PropertyKeyName ( SP kU_Default )? ( SP oC_Where )? ;
: CALL SP oC_FunctionInvocation (SP? oC_Where)? ;

oC_Match
: ( OPTIONAL SP )? MATCH SP? oC_Pattern ( SP oC_Where )? ( SP kU_Hint )? ;
Expand Down
2 changes: 1 addition & 1 deletion scripts/antlr4/hash.md5
Original file line number Diff line number Diff line change
@@ -1 +1 @@
81546847023f2a4c8d1fec8ef8ff0d88
8836c17d54a52a955eddd8a411a02dba
19 changes: 2 additions & 17 deletions src/antlr4/Cypher.g4
Original file line number Diff line number Diff line change
Expand Up @@ -211,13 +211,7 @@ kU_InstallExtension
: INSTALL SP oC_Variable ;

oC_Query
: (kU_ProjectGraph SP? )? oC_RegularQuery ;

kU_ProjectGraph
: PROJECT SP GRAPH SP oC_SchemaName SP? '(' SP? kU_GraphProjectionTableItems SP? ')' ;

kU_GraphProjectionTableItems
: kU_GraphProjectionTableItem ( SP? ',' SP? kU_GraphProjectionTableItem )* ;
: oC_RegularQuery ;

oC_RegularQuery
: oC_SingleQuery ( SP? oC_Union )*
Expand Down Expand Up @@ -262,16 +256,7 @@ kU_LoadFrom
: LOAD ( SP WITH SP HEADERS SP? '(' SP? kU_ColumnDefinitions SP? ')' )? SP FROM SP kU_ScanSource (SP? kU_ParsingOptions)? (SP? oC_Where)? ;

kU_InQueryCall
: ( kU_ProjectGraph SP? )? CALL SP oC_FunctionInvocation (SP? oC_Where)? ;

kU_GraphProjectionTableItem
: oC_SchemaName ( SP? '{' SP? kU_GraphProjectionColumnItems SP? '}' )? ;

kU_GraphProjectionColumnItems
: kU_GraphProjectionColumnItem ( SP? ',' SP? kU_GraphProjectionColumnItem )* ;

kU_GraphProjectionColumnItem
: oC_PropertyKeyName ( SP kU_Default )? ( SP oC_Where )? ;
: CALL SP oC_FunctionInvocation (SP? oC_Where)? ;

oC_Match
: ( OPTIONAL SP )? MATCH SP? oC_Pattern ( SP oC_Where )? ( SP kU_Hint )? ;
Expand Down
1 change: 0 additions & 1 deletion src/binder/bind/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ add_library(
OBJECT
bind_attach_database.cpp
bind_create_macro.cpp
bind_project_graph.cpp
bind_ddl.cpp
bind_detach_database.cpp
bind_explain.cpp
Expand Down
37 changes: 0 additions & 37 deletions src/binder/bind/bind_project_graph.cpp

This file was deleted.

6 changes: 0 additions & 6 deletions src/binder/bind/bind_query.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,6 @@ void validateIsAllUnionOrUnionAll(const BoundRegularQuery& regularQuery) {
}

std::unique_ptr<BoundRegularQuery> Binder::bindQuery(const RegularQuery& regularQuery) {
if (regularQuery.hasProjectGraph()) {
auto projectGraph = regularQuery.getProjectGraph();
KU_ASSERT(!graphEntrySet.hasGraph(projectGraph->getGraphName()));
auto entry = bindProjectGraph(*projectGraph);
graphEntrySet.addGraph(projectGraph->getGraphName(), entry);
}
std::vector<NormalizedSingleQuery> normalizedSingleQueries;
for (auto i = 0u; i < regularQuery.getNumSingleQueries(); i++) {
// Don't clear scope within bindSingleQuery() yet because it is also used for subquery
Expand Down
4 changes: 2 additions & 2 deletions src/binder/bind/read/bind_in_query_call.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,10 @@ std::unique_ptr<BoundReadingClause> Binder::bindInQueryCall(const ReadingClause&
}
auto varName =
functionExpr->getChild(0)->constPtrCast<ParsedVariableExpression>()->getVariableName();
if (!graphEntrySet.hasGraph(varName)) {
if (!clientContext->getGraphEntrySetUnsafe().hasGraph(varName)) {
throw BinderException(stringFormat("Cannot find graph {}.", varName));
}
auto graphEntry = graphEntrySet.getEntry(varName);
auto graphEntry = clientContext->getGraphEntrySetUnsafe().getEntry(varName);
expression_vector children;
std::vector<LogicalType> childrenTypes;
children.push_back(nullptr); // placeholder for graph variable.
Expand Down
2 changes: 2 additions & 0 deletions src/function/function_collection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,8 @@ FunctionCollection* FunctionCollection::getFunctions() {

// Standalone Table functions
STANDALONE_TABLE_FUNCTION(ClearWarningsFunction),
STANDALONE_TABLE_FUNCTION(CreateProjectGraphFunction),
STANDALONE_TABLE_FUNCTION(DropProjectGraphFunction),

// Scan functions
TABLE_FUNCTION(ParquetScanFunction), TABLE_FUNCTION(NpyScanFunction),
Expand Down
9 changes: 7 additions & 2 deletions src/function/table/bind_input.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "function/table/bind_input.h"

#include "binder/expression/expression_util.h"
#include "binder/expression/literal_expression.h"

namespace kuzu {
Expand All @@ -9,10 +10,14 @@ void TableFuncBindInput::addLiteralParam(common::Value value) {
params.push_back(std::make_shared<binder::LiteralExpression>(std::move(value), ""));
}

common::Value TableFuncBindInput::getValue(common::idx_t idx) const {
binder::ExpressionUtil::validateExpressionType(*params[idx], common::ExpressionType::LITERAL);
return params[idx]->constCast<binder::LiteralExpression>().getValue();
}

template<typename T>
T TableFuncBindInput::getLiteralVal(common::idx_t idx) const {
KU_ASSERT(params[idx]->expressionType == common::ExpressionType::LITERAL);
return params[idx]->constCast<binder::LiteralExpression>().getValue().getValue<T>();
return getValue(idx).getValue<T>();
}

template KUZU_API std::string TableFuncBindInput::getLiteralVal<std::string>(
Expand Down
2 changes: 2 additions & 0 deletions src/function/table/call/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
add_library(kuzu_table_call
OBJECT
bm_info.cpp
create_project_graph.cpp
current_setting.cpp
db_version.cpp
drop_project_graph.cpp
show_connection.cpp
show_attached_databases.cpp
show_tables.cpp
Expand Down
92 changes: 92 additions & 0 deletions src/function/table/call/create_project_graph.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
#include "catalog/catalog.h"
#include "common/exception/binder.h"
#include "common/exception/runtime.h"
#include "common/types/value/nested.h"
#include "function/table/simple_table_functions.h"
#include "graph/graph_entry.h"
#include "processor/execution_context.h"

using namespace kuzu::common;
using namespace kuzu::catalog;

namespace kuzu {
namespace function {

struct CreateProjectGraphBindData : SimpleTableFuncBindData {
std::string graphName;
std::vector<TableCatalogEntry*> nodeEntries;
std::vector<TableCatalogEntry*> relEntries;

CreateProjectGraphBindData(std::string graphName, std::vector<TableCatalogEntry*> nodeEntries,
std::vector<TableCatalogEntry*> relEntries)
: SimpleTableFuncBindData{0}, graphName{graphName}, nodeEntries{nodeEntries},
relEntries{relEntries} {}

std::unique_ptr<TableFuncBindData> copy() const override {
return std::make_unique<CreateProjectGraphBindData>(graphName, nodeEntries, relEntries);
}
};

static common::offset_t tableFunc(TableFuncInput& input, TableFuncOutput& /*output*/) {
auto bindData = ku_dynamic_cast<CreateProjectGraphBindData*>(input.bindData);
auto& graphEntrySet = input.context->clientContext->getGraphEntrySetUnsafe();
if (graphEntrySet.hasGraph(bindData->graphName)) {
throw RuntimeException(
stringFormat("Project graph {} already exists.", bindData->graphName));
}
auto entry = graph::GraphEntry(bindData->nodeEntries, bindData->relEntries);
graphEntrySet.addGraph(bindData->graphName, std::move(entry));
return 0;
}

static std::vector<std::string> getAsStringVector(const Value& value) {
std::vector<std::string> result;
for (auto i = 0u; i < NestedVal::getChildrenSize(&value); ++i) {
result.push_back(NestedVal::getChildVal(&value, i)->getValue<std::string>());
}
return result;
}

static std::vector<catalog::TableCatalogEntry*> getTableEntries(
const std::vector<std::string>& tableNames, CatalogEntryType expectedType,
main::ClientContext& context) {
std::vector<catalog::TableCatalogEntry*> entries;
for (auto& tableName : tableNames) {
auto entry = context.getCatalog()->getTableCatalogEntry(context.getTx(), tableName);
if (entry->getType() != expectedType) {
throw BinderException(stringFormat("Expect catalog entry type {} but got {}.",
CatalogEntryTypeUtils::toString(expectedType),
CatalogEntryTypeUtils::toString(entry->getType())));
}
entries.push_back(entry);
}
return entries;
}

static std::unique_ptr<TableFuncBindData> bindFunc(main::ClientContext* context,
TableFuncBindInput* input) {
auto graphName = input->getLiteralVal<std::string>(0);
auto nodeTableNames = getAsStringVector(input->getValue(1));
auto relTableNames = getAsStringVector(input->getValue(2));
auto nodeEntries =
getTableEntries(nodeTableNames, CatalogEntryType::NODE_TABLE_ENTRY, *context);
auto relEntries = getTableEntries(relTableNames, CatalogEntryType::REL_TABLE_ENTRY, *context);
return std::make_unique<CreateProjectGraphBindData>(graphName, nodeEntries, relEntries);
}

function_set CreateProjectGraphFunction::getFunctionSet() {
function_set functionSet;
auto func =
std::make_unique<TableFunction>(name, std::vector<LogicalTypeID>{LogicalTypeID::STRING,
LogicalTypeID::LIST, LogicalTypeID::LIST});
func->bindFunc = bindFunc;
func->tableFunc = tableFunc;
func->initSharedStateFunc = initSharedState;
func->initLocalStateFunc = initEmptyLocalState;
func->canParallelFunc = []() { return false; };
functionSet.push_back(std::move(func));
return functionSet;
}

} // namespace function
} // namespace kuzu
53 changes: 53 additions & 0 deletions src/function/table/call/drop_project_graph.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
#include "common/exception/runtime.h"
#include "function/table/simple_table_functions.h"
#include "graph/graph_entry.h"
#include "processor/execution_context.h"

using namespace kuzu::common;

namespace kuzu {
namespace function {

struct DropProjectGraphBindData : SimpleTableFuncBindData {
std::string graphName;

explicit DropProjectGraphBindData(std::string graphName)
: SimpleTableFuncBindData{0}, graphName{graphName} {}

std::unique_ptr<TableFuncBindData> copy() const override {
return std::make_unique<DropProjectGraphBindData>(graphName);
}
};

static common::offset_t tableFunc(TableFuncInput& input, TableFuncOutput& /*output*/) {
auto bindData = ku_dynamic_cast<DropProjectGraphBindData*>(input.bindData);
auto& graphEntrySet = input.context->clientContext->getGraphEntrySetUnsafe();
if (!graphEntrySet.hasGraph(bindData->graphName)) {
throw RuntimeException(
stringFormat("Project graph {} does not exists.", bindData->graphName));
}
graphEntrySet.dropGraph(bindData->graphName);
return 0;
}

static std::unique_ptr<TableFuncBindData> bindFunc(main::ClientContext*,
TableFuncBindInput* input) {
auto graphName = input->getLiteralVal<std::string>(0);
return std::make_unique<DropProjectGraphBindData>(graphName);
}

function_set DropProjectGraphFunction::getFunctionSet() {
function_set functionSet;
auto func =
std::make_unique<TableFunction>(name, std::vector<LogicalTypeID>{LogicalTypeID::STRING});
func->bindFunc = bindFunc;
func->tableFunc = tableFunc;
func->initSharedStateFunc = initSharedState;
func->initLocalStateFunc = initEmptyLocalState;
func->canParallelFunc = []() { return false; };
functionSet.push_back(std::move(func));
return functionSet;
}

} // namespace function
} // namespace kuzu
Loading

0 comments on commit 5460249

Please sign in to comment.