diff --git a/lib/chat_models/chat_google_ai.ex b/lib/chat_models/chat_google_ai.ex index 27b53fc1..299bf67d 100644 --- a/lib/chat_models/chat_google_ai.ex +++ b/lib/chat_models/chat_google_ai.ex @@ -13,6 +13,9 @@ defmodule LangChain.ChatModels.ChatGoogleAI do alias LangChain.ChatModels.ChatOpenAI alias LangChain.Message alias LangChain.MessageDelta + alias LangChain.Message.ContentPart + alias LangChain.Message.ToolCall + alias LangChain.Message.ToolResult alias LangChain.LangChainError alias LangChain.Utils @@ -120,13 +123,14 @@ defmodule LangChain.ChatModels.ChatGoogleAI do end def for_api(%ChatGoogleAI{} = google_ai, messages, functions) do + messages_for_api = + messages + |> Enum.map(&for_api/1) + |> List.flatten() + |> List.wrap() + req = %{ - "contents" => - Stream.map(messages, &for_api/1) - |> Enum.flat_map(fn - list when is_list(list) -> list - not_list -> [not_list] - end), + "contents" => messages_for_api, "generationConfig" => %{ "temperature" => google_ai.temperature, "topP" => google_ai.top_p, @@ -148,32 +152,20 @@ defmodule LangChain.ChatModels.ChatGoogleAI do end end - defp for_api(%Message{role: :assistant, function_name: fun_name} = fun) - when is_binary(fun_name) do + defp for_api(%Message{role: :assistant} = message) do + content_parts = get_message_contents(message) || [] + tool_calls = Enum.map(message.tool_calls || [], &for_api/1) + %{ "role" => map_role(:assistant), - "parts" => [ - %{ - "functionCall" => %{ - "name" => fun_name, - "args" => fun.arguments - } - } - ] + "parts" => content_parts ++ tool_calls } end - defp for_api(%Message{role: :function} = message) do + defp for_api(%Message{role: :tool} = message) do %{ - "role" => map_role(:function), - "parts" => [ - %{ - "functionResponse" => %{ - "name" => message.function_name, - "response" => Jason.decode!(message.content) - } - } - ] + "role" => map_role(:tool), + "parts" => Enum.map(message.tool_results, &for_api/1) } end @@ -199,20 +191,33 @@ defmodule LangChain.ChatModels.ChatGoogleAI do } end - defp map_role(role) do - case role do - :assistant -> :model - # System prompts are not supported yet. Google recommends using user prompt. - :system -> :user - role -> role - end + defp for_api(%ContentPart{type: :text} = part) do + %{"text" => part.content} + end + + defp for_api(%ToolCall{} = call) do + %{ + "functionCall" => %{ + "args" => call.arguments, + "name" => call.name + } + } + end + + defp for_api(%ToolResult{} = result) do + %{ + "functionResponse" => %{ + "name" => result.name, + "response" => Jason.decode!(result.content) + } + } end @doc """ Calls the Google AI API passing the ChatGoogleAI struct with configuration, plus either a simple message or the list of messages to act as the prompt. - Optionally pass in a list of functions available to the LLM for requesting + Optionally pass in a list of tools available to the LLM for requesting execution in response. Optionally pass in a callback function that can be executed as data is @@ -223,27 +228,27 @@ defmodule LangChain.ChatModels.ChatGoogleAI do translating the `LangChain` data structures to and from the OpenAI API. 
Another benefit of using `LangChain.Chains.LLMChain` is that it combines the - storage of messages, adding functions, adding custom context that should be - passed to functions, and automatically applying `LangChain.MessageDelta` + storage of messages, adding tools, adding custom context that should be + passed to tools, and automatically applying `LangChain.MessageDelta` structs as they are are received, then converting those to the full `LangChain.Message` once fully complete. """ @impl ChatModel - def call(openai, prompt, functions \\ [], callback_fn \\ nil) + def call(openai, prompt, tools \\ [], callback_fn \\ nil) - def call(%ChatGoogleAI{} = google_ai, prompt, functions, callback_fn) when is_binary(prompt) do + def call(%ChatGoogleAI{} = google_ai, prompt, tools, callback_fn) when is_binary(prompt) do messages = [ Message.new_system!(), Message.new_user!(prompt) ] - call(google_ai, messages, functions, callback_fn) + call(google_ai, messages, tools, callback_fn) end - def call(%ChatGoogleAI{} = google_ai, messages, functions, callback_fn) + def call(%ChatGoogleAI{} = google_ai, messages, tools, callback_fn) when is_list(messages) do try do - case do_api_request(google_ai, messages, functions, callback_fn) do + case do_api_request(google_ai, messages, tools, callback_fn) do {:error, reason} -> {:error, reason} @@ -259,11 +264,11 @@ defmodule LangChain.ChatModels.ChatGoogleAI do @doc false @spec do_api_request(t(), [Message.t()], [Function.t()], (any() -> any())) :: list() | struct() | {:error, String.t()} - def do_api_request(%ChatGoogleAI{stream: false} = google_ai, messages, functions, callback_fn) do + def do_api_request(%ChatGoogleAI{stream: false} = google_ai, messages, tools, callback_fn) do req = Req.new( url: build_url(google_ai), - json: for_api(google_ai, messages, functions), + json: for_api(google_ai, messages, tools), receive_timeout: google_ai.receive_timeout, retry: :transient, max_retries: 3, @@ -292,15 +297,21 @@ defmodule LangChain.ChatModels.ChatGoogleAI do end end - def do_api_request(%ChatGoogleAI{stream: true} = google_ai, messages, functions, callback_fn) do + def do_api_request(%ChatGoogleAI{stream: true} = google_ai, messages, tools, callback_fn) do Req.new( url: build_url(google_ai), - json: for_api(google_ai, messages, functions), + json: for_api(google_ai, messages, tools), receive_timeout: google_ai.receive_timeout ) |> Req.Request.put_header("accept-encoding", "utf-8") |> Req.post( - into: Utils.handle_stream_fn(google_ai, &ChatOpenAI.decode_stream/1, &do_process_response(&1, MessageDelta), callback_fn) + into: + Utils.handle_stream_fn( + google_ai, + &ChatOpenAI.decode_stream/1, + &do_process_response(&1, MessageDelta), + callback_fn + ) ) |> case do {:ok, %Req.Response{body: data}} -> @@ -349,19 +360,96 @@ defmodule LangChain.ChatModels.ChatGoogleAI do |> Enum.map(&do_process_response(&1, message_type)) end - def do_process_response( - %{ - "content" => %{"parts" => [%{"functionCall" => %{"args" => raw_args, "name" => name}}]} - } = data, - message_type - ) do - case message_type.new(%{ - "role" => "assistant", - "function_name" => name, - "arguments" => raw_args, - "complete" => true, - "index" => data["index"] - }) do + def do_process_response(%{"content" => %{"parts" => parts} = content_data} = data, Message) do + text_part = + parts + |> filter_parts_for_types(["text"]) + |> Enum.map(fn part -> + ContentPart.new!(%{type: :text, content: part["text"]}) + end) + + tool_calls_from_parts = + parts + |> filter_parts_for_types(["functionCall"]) + |> 
Enum.map(fn part -> + do_process_response(part, nil) + end) + + tool_result_from_parts = + parts + |> filter_parts_for_types(["functionResponse"]) + |> Enum.map(fn part -> + do_process_response(part, nil) + end) + + %{ + role: unmap_role(content_data["role"]), + content: text_part, + complete: false, + index: data["index"] + } + |> Utils.conditionally_add_to_map(:tool_calls, tool_calls_from_parts) + |> Utils.conditionally_add_to_map(:tool_results, tool_result_from_parts) + |> Message.new() + |> case do + {:ok, message} -> + message + + {:error, changeset} -> + {:error, Utils.changeset_error_to_string(changeset)} + end + end + + def do_process_response(%{"content" => %{"parts" => parts} = content_data} = data, MessageDelta) do + text_content = + case parts do + [%{"text" => text}] -> + text + + _other -> + nil + end + + parts + |> filter_parts_for_types(["text"]) + |> Enum.map(fn part -> + ContentPart.new!(%{type: :text, content: part["text"]}) + end) + + tool_calls_from_parts = + parts + |> filter_parts_for_types(["functionCall"]) + |> Enum.map(fn part -> + do_process_response(part, nil) + end) + + %{ + role: unmap_role(content_data["role"]), + content: text_content, + complete: true, + index: data["index"] + } + |> Utils.conditionally_add_to_map(:tool_calls, tool_calls_from_parts) + |> MessageDelta.new() + |> case do + {:ok, message} -> + message + + {:error, changeset} -> + {:error, Utils.changeset_error_to_string(changeset)} + end + end + + def do_process_response(%{"functionCall" => %{"args" => raw_args, "name" => name}} = data, _) do + %{ + call_id: "call-#{name}", + name: name, + arguments: raw_args, + complete: true, + index: data["index"] + } + |> ToolCall.new() + |> case do {:ok, message} -> message @@ -430,6 +518,40 @@ defmodule LangChain.ChatModels.ChatGoogleAI do {:error, "Unexpected response"} end + @doc false + def filter_parts_for_types(parts, types) when is_list(parts) and is_list(types) do + Enum.filter(parts, fn p -> + Enum.any?(types, &Map.has_key?(p, &1)) + end) + end + + @doc """ + Return the content parts for the message. + """ + @spec get_message_contents(MessageDelta.t() | Message.t()) :: [%{String.t() => any()}] + def get_message_contents(%{content: content} = _message) when is_binary(content) do + [%{"text" => content}] + end + + def get_message_contents(%{content: contents} = _message) when is_list(contents) do + Enum.map(contents, &for_api/1) + end + + def get_message_contents(%{content: nil} = _message) do + nil + end + + defp map_role(role) do + case role do + :assistant -> :model + :tool -> :function + # System prompts are not supported yet. Google recommends using user prompt. 
+ :system -> :user + role -> role + end + end + defp unmap_role("model"), do: "assistant" + defp unmap_role("function"), do: "tool" defp unmap_role(role), do: role end diff --git a/lib/message.ex b/lib/message.ex index 1b0bc735..7375c92a 100644 --- a/lib/message.ex +++ b/lib/message.ex @@ -180,7 +180,7 @@ defmodule LangChain.Message do changeset {:ok, [%ContentPart{} | _] = value} -> - if role == :user do + if role in [:user, :assistant] do # if a list, verify all elements are a ContentPart if Enum.all?(value, &match?(%ContentPart{}, &1)) do changeset diff --git a/test/chat_models/chat_anthropic_test.exs b/test/chat_models/chat_anthropic_test.exs index 26723cce..baabd702 100644 --- a/test/chat_models/chat_anthropic_test.exs +++ b/test/chat_models/chat_anthropic_test.exs @@ -1180,7 +1180,7 @@ data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text send(test_pid, {:streamed_fn, data}) end - {:ok, result_chain, last_message} = + {:ok, _result_chain, last_message} = LLMChain.new!(%{llm: %ChatAnthropic{model: @test_model, stream: true}}) |> LLMChain.add_message(Message.new_system!("You are a helpful and concise assistant.")) |> LLMChain.add_message(Message.new_user!("Say, 'Hi!'!")) diff --git a/test/chat_models/chat_google_ai_test.exs b/test/chat_models/chat_google_ai_test.exs index 9b2d4905..abe69388 100644 --- a/test/chat_models/chat_google_ai_test.exs +++ b/test/chat_models/chat_google_ai_test.exs @@ -5,6 +5,9 @@ defmodule ChatModels.ChatGoogleAITest do doctest LangChain.ChatModels.ChatGoogleAI alias LangChain.ChatModels.ChatGoogleAI alias LangChain.Message + alias LangChain.Message.ContentPart + alias LangChain.Message.ToolCall + alias LangChain.Message.ToolResult alias LangChain.MessageDelta alias LangChain.Function @@ -102,8 +105,24 @@ defmodule ChatModels.ChatGoogleAITest do google_ai, [ Message.new_user!(message), - Message.new_function_call!("userland_action", Jason.encode!(arguments)), - Message.new_function!("userland_action", function_result) + Message.new_assistant!(%{ + tool_calls: [ + ToolCall.new!(%{ + call_id: "call_123", + name: "userland_action", + arguments: Jason.encode!(arguments) + }) + ] + }), + Message.new_tool_result!(%{ + tool_results: [ + ToolResult.new!(%{ + tool_call_id: "call_123", + name: "userland_action", + content: Jason.encode!(function_result) + }) + ] + }) ], [] ) @@ -112,7 +131,6 @@ defmodule ChatModels.ChatGoogleAITest do assert %{"role" => :user, "parts" => [%{"text" => ^message}]} = msg1 assert %{"role" => :model, "parts" => [tool_call]} = msg2 assert %{"role" => :function, "parts" => [tool_result]} = msg3 - assert %{ "functionCall" => %{ "args" => ^arguments, @@ -173,7 +191,7 @@ defmodule ChatModels.ChatGoogleAITest do assert [%Message{} = struct] = ChatGoogleAI.do_process_response(response) assert struct.role == :assistant - assert struct.content == "Hello User!" 
+ [%ContentPart{type: :text, content: "Hello User!"}] = struct.content assert struct.index == 0 assert struct.status == :complete end @@ -212,8 +230,9 @@ defmodule ChatModels.ChatGoogleAITest do assert [%Message{} = struct] = ChatGoogleAI.do_process_response(response) assert struct.role == :assistant assert struct.index == 0 - assert struct.function_name == "hello_world" - assert struct.arguments == args + [call] = struct.tool_calls + assert call.name == "hello_world" + assert call.arguments == args end test "handles receiving MessageDeltas as well" do @@ -262,4 +281,62 @@ defmodule ChatModels.ChatGoogleAITest do assert {:error, "Unexpected response"} = ChatGoogleAI.do_process_response(response) end end + + describe "filter_parts_for_types/2" do + test "returns a single functionCall type" do + parts = [ + %{"text" => "I think I'll call this function."}, + %{ + "functionCall" => %{ + "args" => %{"args" => "data"}, + "name" => "userland_action" + } + } + ] + + assert [%{"text" => _}] = ChatGoogleAI.filter_parts_for_types(parts, ["text"]) + + assert [%{"functionCall" => _}] = + ChatGoogleAI.filter_parts_for_types(parts, ["functionCall"]) + end + + test "returns a set of types" do + parts = [ + %{"text" => "I think I'll call this function."}, + %{ + "functionCall" => %{ + "args" => %{"args" => "data"}, + "name" => "userland_action" + } + } + ] + + assert parts == ChatGoogleAI.filter_parts_for_types(parts, ["text", "functionCall"]) + end + end + + describe "get_message_contents/1" do + test "returns basic text as a ContentPart" do + message = Message.new_user!("Howdy!") + + result = ChatGoogleAI.get_message_contents(message) + + assert result == [%{"text" => "Howdy!"}] + end + + test "supports a list of ContentParts" do + message = + Message.new_user!([ + ContentPart.new!(%{type: :text, content: "Hello!"}), + ContentPart.new!(%{type: :text, content: "What's up?"}) + ]) + + result = ChatGoogleAI.get_message_contents(message) + + assert result == [ + %{"text" => "Hello!"}, + %{"text" => "What's up?"} + ] + end + end end diff --git a/test/message_test.exs b/test/message_test.exs index 78874d81..78a05d68 100644 --- a/test/message_test.exs +++ b/test/message_test.exs @@ -94,9 +94,6 @@ defmodule LangChain.MessageTest do assert message.content == "Hi" # content parts not allowed for other role types - {:error, changeset} = Message.new_assistant(%{content: [part]}) - assert {"is invalid for role assistant", _} = changeset.errors[:content] - {:error, changeset} = Message.new_system([part]) assert {"is invalid for role system", _} = changeset.errors[:content]
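For reference, below is a minimal usage sketch of the tool-call round trip that the new `for_api/1` clauses and the updated `chat_google_ai_test.exs` above exercise. It relies only on functions visible in this diff; the `ChatGoogleAI.new!/1` options, the tool name, and the payload values are illustrative placeholders, not part of the change.

# Illustrative sketch only -- assumes a configured ChatGoogleAI struct;
# the :api_key value, tool name, and payloads below are placeholders.
alias LangChain.ChatModels.ChatGoogleAI
alias LangChain.Message
alias LangChain.Message.ToolCall
alias LangChain.Message.ToolResult

model = ChatGoogleAI.new!(%{api_key: "..."})

messages = [
  Message.new_user!("What's the weather in Boston?"),
  # Assistant turn that requested a tool call; for_api/1 serializes it as a
  # "functionCall" part under the :model role.
  Message.new_assistant!(%{
    tool_calls: [
      ToolCall.new!(%{
        call_id: "call_123",
        name: "get_weather",
        arguments: Jason.encode!(%{"city" => "Boston"})
      })
    ]
  }),
  # Tool result turn; serialized as a "functionResponse" part under the
  # :function role (see the map_role/1 mapping of :tool -> :function).
  Message.new_tool_result!(%{
    tool_results: [
      ToolResult.new!(%{
        tool_call_id: "call_123",
        name: "get_weather",
        content: Jason.encode!(%{"temp_f" => 68})
      })
    ]
  })
]

%{"contents" => [msg1, msg2, msg3]} = ChatGoogleAI.for_api(model, messages, [])
# msg1 => %{"role" => :user, "parts" => [%{"text" => "What's the weather in Boston?"}]}
# msg2 => %{"role" => :model, "parts" => [%{"functionCall" => %{"name" => "get_weather", ...}}]}
# msg3 => %{"role" => :function, "parts" => [%{"functionResponse" => %{"name" => "get_weather", ...}}]}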