-
-
Notifications
You must be signed in to change notification settings - Fork 17.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Feature/Add Neo4j GraphRag support (#3686)
* added: Neo4j database connectivity, Neo4j credentials, supports the usage of the GraphCypherQaChain node and modifies the FewShotPromptTemplate node to handle variables from the prefix field. * Merge branch 'main' of github.com:FlowiseAI/Flowise into feature/graphragsupport * revert pnpm-lock.yaml * add: neo4j package * Refactor GraphCypherQAChain: Update version to 1.0, remove memory input, and enhance prompt handling - Changed version from 2.0 to 1.0. - Removed the 'Memory' input parameter from the GraphCypherQAChain. - Made 'cypherPrompt' optional and improved error handling for prompt validation. - Updated the 'init' and 'run' methods to streamline input processing and response handling. - Enhanced streaming response logic based on the 'returnDirect' flag. * Refactor GraphCypherQAChain: Simplify imports and update init method signature - Consolidated import statements for better readability. - Removed the 'input' and 'options' parameters from the 'init' method, streamlining its signature to only accept 'nodeData'. * add output, format final response, fix optional inputs --------- Co-authored-by: Henry <[email protected]>
- Loading branch information
1 parent
93f3a5d
commit a7c1ab8
Showing
8 changed files
with
34,325 additions
and
33,897 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
import { INodeParams, INodeCredential } from '../src/Interface' | ||
|
||
class Neo4jApi implements INodeCredential { | ||
label: string | ||
name: string | ||
version: number | ||
description: string | ||
inputs: INodeParams[] | ||
|
||
constructor() { | ||
this.label = 'Neo4j API' | ||
this.name = 'neo4jApi' | ||
this.version = 1.0 | ||
this.description = | ||
'Refer to <a target="_blank" href="https://neo4j.com/docs/operations-manual/current/authentication-authorization/">official guide</a> on Neo4j authentication' | ||
this.inputs = [ | ||
{ | ||
label: 'Neo4j URL', | ||
name: 'url', | ||
type: 'string', | ||
description: 'Your Neo4j instance URL (e.g., neo4j://localhost:7687)' | ||
}, | ||
{ | ||
label: 'Username', | ||
name: 'username', | ||
type: 'string', | ||
description: 'Neo4j database username' | ||
}, | ||
{ | ||
label: 'Password', | ||
name: 'password', | ||
type: 'password', | ||
description: 'Neo4j database password' | ||
} | ||
] | ||
} | ||
} | ||
|
||
module.exports = { credClass: Neo4jApi } |
256 changes: 256 additions & 0 deletions
256
packages/components/nodes/chains/GraphCypherQAChain/GraphCypherQAChain.ts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,256 @@ | ||
import { ICommonObject, INode, INodeData, INodeParams, INodeOutputsValue, IServerSideEventStreamer } from '../../../src/Interface' | ||
import { FromLLMInput, GraphCypherQAChain } from '@langchain/community/chains/graph_qa/cypher' | ||
import { getBaseClasses } from '../../../src/utils' | ||
import { BasePromptTemplate, PromptTemplate, FewShotPromptTemplate } from '@langchain/core/prompts' | ||
import { ConsoleCallbackHandler, CustomChainHandler, additionalCallbacks } from '../../../src/handler' | ||
import { ConsoleCallbackHandler as LCConsoleCallbackHandler } from '@langchain/core/tracers/console' | ||
import { checkInputs, Moderation, streamResponse } from '../../moderation/Moderation' | ||
import { formatResponse } from '../../outputparsers/OutputParserHelpers' | ||
|
||
class GraphCypherQA_Chain implements INode { | ||
label: string | ||
name: string | ||
version: number | ||
type: string | ||
icon: string | ||
category: string | ||
description: string | ||
baseClasses: string[] | ||
inputs: INodeParams[] | ||
sessionId?: string | ||
outputs: INodeOutputsValue[] | ||
|
||
constructor(fields?: { sessionId?: string }) { | ||
this.label = 'Graph Cypher QA Chain' | ||
this.name = 'graphCypherQAChain' | ||
this.version = 1.0 | ||
this.type = 'GraphCypherQAChain' | ||
this.icon = 'graphqa.svg' | ||
this.category = 'Chains' | ||
this.description = 'Advanced chain for question-answering against a Neo4j graph by generating Cypher statements' | ||
this.baseClasses = [this.type, ...getBaseClasses(GraphCypherQAChain)] | ||
this.sessionId = fields?.sessionId | ||
this.inputs = [ | ||
{ | ||
label: 'Language Model', | ||
name: 'model', | ||
type: 'BaseLanguageModel', | ||
description: 'Model for generating Cypher queries and answers.' | ||
}, | ||
{ | ||
label: 'Neo4j Graph', | ||
name: 'graph', | ||
type: 'Neo4j' | ||
}, | ||
{ | ||
label: 'Cypher Generation Prompt', | ||
name: 'cypherPrompt', | ||
optional: true, | ||
type: 'BasePromptTemplate', | ||
description: 'Prompt template for generating Cypher queries. Must include {schema} and {question} variables' | ||
}, | ||
{ | ||
label: 'Cypher Generation Model', | ||
name: 'cypherModel', | ||
optional: true, | ||
type: 'BaseLanguageModel', | ||
description: 'Model for generating Cypher queries. If not provided, the main model will be used.' | ||
}, | ||
{ | ||
label: 'QA Prompt', | ||
name: 'qaPrompt', | ||
optional: true, | ||
type: 'BasePromptTemplate', | ||
description: 'Prompt template for generating answers. Must include {context} and {question} variables' | ||
}, | ||
{ | ||
label: 'QA Model', | ||
name: 'qaModel', | ||
optional: true, | ||
type: 'BaseLanguageModel', | ||
description: 'Model for generating answers. If not provided, the main model will be used.' | ||
}, | ||
{ | ||
label: 'Input Moderation', | ||
description: 'Detect text that could generate harmful output and prevent it from being sent to the language model', | ||
name: 'inputModeration', | ||
type: 'Moderation', | ||
optional: true, | ||
list: true | ||
}, | ||
{ | ||
label: 'Return Direct', | ||
name: 'returnDirect', | ||
type: 'boolean', | ||
default: false, | ||
optional: true, | ||
description: 'If true, return the raw query results instead of using the QA chain' | ||
} | ||
] | ||
this.outputs = [ | ||
{ | ||
label: 'Graph Cypher QA Chain', | ||
name: 'graphCypherQAChain', | ||
baseClasses: [this.type, ...getBaseClasses(GraphCypherQAChain)] | ||
}, | ||
{ | ||
label: 'Output Prediction', | ||
name: 'outputPrediction', | ||
baseClasses: ['string', 'json'] | ||
} | ||
] | ||
} | ||
|
||
async init(nodeData: INodeData, input: string, options: ICommonObject): Promise<any> { | ||
const model = nodeData.inputs?.model | ||
const cypherModel = nodeData.inputs?.cypherModel | ||
const qaModel = nodeData.inputs?.qaModel | ||
const graph = nodeData.inputs?.graph | ||
const cypherPrompt = nodeData.inputs?.cypherPrompt as BasePromptTemplate | FewShotPromptTemplate | undefined | ||
const qaPrompt = nodeData.inputs?.qaPrompt as BasePromptTemplate | undefined | ||
const returnDirect = nodeData.inputs?.returnDirect as boolean | ||
const output = nodeData.outputs?.output as string | ||
|
||
// Handle prompt values if they exist | ||
let cypherPromptTemplate: PromptTemplate | FewShotPromptTemplate | undefined | ||
let qaPromptTemplate: PromptTemplate | undefined | ||
|
||
if (cypherPrompt) { | ||
if (cypherPrompt instanceof PromptTemplate) { | ||
cypherPromptTemplate = new PromptTemplate({ | ||
template: cypherPrompt.template as string, | ||
inputVariables: cypherPrompt.inputVariables | ||
}) | ||
if (!qaPrompt) { | ||
throw new Error('QA Prompt is required when Cypher Prompt is a Prompt Template') | ||
} | ||
} else if (cypherPrompt instanceof FewShotPromptTemplate) { | ||
const examplePrompt = cypherPrompt.examplePrompt as PromptTemplate | ||
cypherPromptTemplate = new FewShotPromptTemplate({ | ||
examples: cypherPrompt.examples, | ||
examplePrompt: examplePrompt, | ||
inputVariables: cypherPrompt.inputVariables, | ||
prefix: cypherPrompt.prefix, | ||
suffix: cypherPrompt.suffix, | ||
exampleSeparator: cypherPrompt.exampleSeparator, | ||
templateFormat: cypherPrompt.templateFormat | ||
}) | ||
} else { | ||
cypherPromptTemplate = cypherPrompt as PromptTemplate | ||
} | ||
} | ||
|
||
if (qaPrompt instanceof PromptTemplate) { | ||
qaPromptTemplate = new PromptTemplate({ | ||
template: qaPrompt.template as string, | ||
inputVariables: qaPrompt.inputVariables | ||
}) | ||
} | ||
|
||
if ((!cypherModel || !qaModel) && !model) { | ||
throw new Error('Language Model is required when Cypher Model or QA Model are not provided') | ||
} | ||
|
||
// Validate required variables in prompts | ||
if ( | ||
cypherPromptTemplate && | ||
(!cypherPromptTemplate?.inputVariables.includes('schema') || !cypherPromptTemplate?.inputVariables.includes('question')) | ||
) { | ||
throw new Error('Cypher Generation Prompt must include {schema} and {question} variables') | ||
} | ||
|
||
const fromLLMInput: FromLLMInput = { | ||
llm: model, | ||
graph, | ||
returnDirect | ||
} | ||
|
||
if (cypherModel && cypherPromptTemplate) { | ||
fromLLMInput['cypherLLM'] = cypherModel | ||
fromLLMInput['cypherPrompt'] = cypherPromptTemplate | ||
} | ||
|
||
if (qaModel && qaPromptTemplate) { | ||
fromLLMInput['qaLLM'] = qaModel | ||
fromLLMInput['qaPrompt'] = qaPromptTemplate | ||
} | ||
|
||
const chain = GraphCypherQAChain.fromLLM(fromLLMInput) | ||
|
||
if (output === this.name) { | ||
return chain | ||
} else if (output === 'outputPrediction') { | ||
nodeData.instance = chain | ||
return await this.run(nodeData, input, options) | ||
} | ||
|
||
return chain | ||
} | ||
|
||
async run(nodeData: INodeData, input: string, options: ICommonObject): Promise<string | object> { | ||
const chain = nodeData.instance as GraphCypherQAChain | ||
const moderations = nodeData.inputs?.inputModeration as Moderation[] | ||
const returnDirect = nodeData.inputs?.returnDirect as boolean | ||
|
||
const shouldStreamResponse = options.shouldStreamResponse | ||
const sseStreamer: IServerSideEventStreamer = options.sseStreamer as IServerSideEventStreamer | ||
const chatId = options.chatId | ||
|
||
// Handle input moderation if configured | ||
if (moderations && moderations.length > 0) { | ||
try { | ||
input = await checkInputs(moderations, input) | ||
} catch (e) { | ||
await new Promise((resolve) => setTimeout(resolve, 500)) | ||
if (shouldStreamResponse) { | ||
streamResponse(sseStreamer, chatId, e.message) | ||
} | ||
return formatResponse(e.message) | ||
} | ||
} | ||
|
||
const obj = { | ||
query: input | ||
} | ||
|
||
const loggerHandler = new ConsoleCallbackHandler(options.logger) | ||
const callbackHandlers = await additionalCallbacks(nodeData, options) | ||
let callbacks = [loggerHandler, ...callbackHandlers] | ||
|
||
if (process.env.DEBUG === 'true') { | ||
callbacks.push(new LCConsoleCallbackHandler()) | ||
} | ||
|
||
try { | ||
let response | ||
if (shouldStreamResponse) { | ||
if (returnDirect) { | ||
response = await chain.invoke(obj, { callbacks }) | ||
let result = response?.result | ||
if (typeof result === 'object') { | ||
result = '```json\n' + JSON.stringify(result, null, 2) | ||
} | ||
if (result && typeof result === 'string') { | ||
streamResponse(sseStreamer, chatId, result) | ||
} | ||
} else { | ||
const handler = new CustomChainHandler(sseStreamer, chatId, 2) | ||
callbacks.push(handler) | ||
response = await chain.invoke(obj, { callbacks }) | ||
} | ||
} else { | ||
response = await chain.invoke(obj, { callbacks }) | ||
} | ||
|
||
return formatResponse(response?.result) | ||
} catch (error) { | ||
console.error('Error in GraphCypherQAChain:', error) | ||
if (shouldStreamResponse) { | ||
streamResponse(sseStreamer, chatId, error.message) | ||
} | ||
return formatResponse(`Error: ${error.message}`) | ||
} | ||
} | ||
} | ||
|
||
module.exports = { nodeClass: GraphCypherQA_Chain } |
22 changes: 22 additions & 0 deletions
22
packages/components/nodes/chains/GraphCypherQAChain/graphqa.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Oops, something went wrong.