diff --git a/lib/chains/data_extraction_chain.ex b/lib/chains/data_extraction_chain.ex index f216fa59..37b15758 100644 --- a/lib/chains/data_extraction_chain.ex +++ b/lib/chains/data_extraction_chain.ex @@ -72,6 +72,7 @@ defmodule LangChain.Chains.DataExtractionChain do require Logger alias LangChain.PromptTemplate alias LangChain.Message + alias LangChain.Message.ToolCall alias LangChain.Chains.LLMChain @function_name "information_extraction" @@ -101,15 +102,19 @@ Passage: {:ok, chain} = LLMChain.new(%{llm: llm, verbose: verbose}) chain - |> LLMChain.add_functions(build_extract_function(json_schema)) + |> LLMChain.add_tools(build_extract_function(json_schema)) |> LLMChain.add_messages(messages) |> LLMChain.run() |> case do {:ok, _updated_chain, %Message{ role: :assistant, - function_name: @function_name, - arguments: %{"info" => info} + tool_calls: [ + %ToolCall{ + name: @function_name, + arguments: %{"info" => info} + } + ] }} when is_list(info) -> {:ok, info} @@ -136,6 +141,12 @@ Passage: LangChain.Function.new!(%{ name: @function_name, description: "Extracts the relevant information from the passage.", + function: fn args, _context -> + # NOTE: The function is not executed here because we won't be returning + # this to the LLM. The LLMChain does not run the function, but stops at + # the request for it. + {:ok, args} + end, parameters_schema: %{ type: "object", properties: %{