
Commit

updated ChatGoogleAI for new message structures
- updated tests
brainlid committed Apr 27, 2024
1 parent e3d555f commit e400443
Showing 5 changed files with 266 additions and 70 deletions.
240 changes: 181 additions & 59 deletions lib/chat_models/chat_google_ai.ex
@@ -13,6 +13,9 @@ defmodule LangChain.ChatModels.ChatGoogleAI do
alias LangChain.ChatModels.ChatOpenAI
alias LangChain.Message
alias LangChain.MessageDelta
+ alias LangChain.Message.ContentPart
+ alias LangChain.Message.ToolCall
+ alias LangChain.Message.ToolResult
alias LangChain.LangChainError
alias LangChain.Utils

@@ -120,13 +123,14 @@ defmodule LangChain.ChatModels.ChatGoogleAI do
end

def for_api(%ChatGoogleAI{} = google_ai, messages, functions) do
+ messages_for_api =
+   messages
+   |> Enum.map(&for_api/1)
+   |> List.flatten()
+   |> List.wrap()
+
req = %{
"contents" =>
Stream.map(messages, &for_api/1)
|> Enum.flat_map(fn
list when is_list(list) -> list
not_list -> [not_list]
end),
"contents" => messages_for_api,
"generationConfig" => %{
"temperature" => google_ai.temperature,
"topP" => google_ai.top_p,
@@ -148,32 +152,20 @@
end
end

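# Editorial note (not part of the commit): for_api/1 can return either a single
# map or a list of maps, so the Enum.map/List.flatten/List.wrap pipeline in
# for_api/3 above normalizes both shapes into one flat list. A sketch with
# placeholder values:
#
#   [%{"role" => "user"}, [%{"role" => "function"}, %{"role" => "user"}]]
#   |> List.flatten()
#   |> List.wrap()
#   #=> [%{"role" => "user"}, %{"role" => "function"}, %{"role" => "user"}]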
- defp for_api(%Message{role: :assistant, function_name: fun_name} = fun)
-      when is_binary(fun_name) do
+ defp for_api(%Message{role: :assistant} = message) do
+   content_parts = get_message_contents(message) || []
+   tool_calls = Enum.map(message.tool_calls || [], &for_api/1)
+
%{
"role" => map_role(:assistant),
"parts" => [
%{
"functionCall" => %{
"name" => fun_name,
"args" => fun.arguments
}
}
]
"parts" => content_parts ++ tool_calls
}
end

- defp for_api(%Message{role: :function} = message) do
+ defp for_api(%Message{role: :tool} = message) do
%{
"role" => map_role(:function),
"parts" => [
%{
"functionResponse" => %{
"name" => message.function_name,
"response" => Jason.decode!(message.content)
}
}
]
"role" => map_role(:tool),
"parts" => Enum.map(message.tool_results, &for_api/1)
}
end

@@ -199,20 +191,33 @@
}
end

- defp map_role(role) do
-   case role do
-     :assistant -> :model
-     # System prompts are not supported yet. Google recommends using user prompt.
-     :system -> :user
-     role -> role
-   end
- end
+ defp for_api(%ContentPart{type: :text} = part) do
+   %{"text" => part.content}
+ end
+
+ defp for_api(%ToolCall{} = call) do
+   %{
+     "functionCall" => %{
+       "args" => call.arguments,
+       "name" => call.name
+     }
+   }
+ end
+
+ defp for_api(%ToolResult{} = result) do
+   %{
+     "functionResponse" => %{
+       "name" => result.name,
+       "response" => Jason.decode!(result.content)
+     }
+   }
+ end

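# Editorial sketch (not from the commit): with assumed struct values, the three
# part-level clauses above produce Gemini-style part maps along these lines:
#
#   for_api(%ContentPart{type: :text, content: "Hi!"})
#   #=> %{"text" => "Hi!"}
#
#   for_api(%ToolCall{name: "get_weather", arguments: %{"city" => "Portland"}})
#   #=> %{"functionCall" => %{"args" => %{"city" => "Portland"}, "name" => "get_weather"}}
#
#   for_api(%ToolResult{name: "get_weather", content: ~s({"temp": 17})})
#   #=> %{"functionResponse" => %{"name" => "get_weather", "response" => %{"temp" => 17}}}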
@doc """
Calls the Google AI API passing the ChatGoogleAI struct with configuration, plus
either a simple message or the list of messages to act as the prompt.
- Optionally pass in a list of functions available to the LLM for requesting
+ Optionally pass in a list of tools available to the LLM for requesting
execution in response.
Optionally pass in a callback function that can be executed as data is
@@ -223,27 +228,27 @@
translating the `LangChain` data structures to and from the Google AI API.
Another benefit of using `LangChain.Chains.LLMChain` is that it combines the
- storage of messages, adding functions, adding custom context that should be
- passed to functions, and automatically applying `LangChain.MessageDelta`
+ storage of messages, adding tools, adding custom context that should be
+ passed to tools, and automatically applying `LangChain.MessageDelta`
structs as they are received, then converting those to the full
`LangChain.Message` once fully complete.
"""
@impl ChatModel
- def call(openai, prompt, functions \\ [], callback_fn \\ nil)
+ def call(openai, prompt, tools \\ [], callback_fn \\ nil)

- def call(%ChatGoogleAI{} = google_ai, prompt, functions, callback_fn) when is_binary(prompt) do
+ def call(%ChatGoogleAI{} = google_ai, prompt, tools, callback_fn) when is_binary(prompt) do
messages = [
Message.new_system!(),
Message.new_user!(prompt)
]

-   call(google_ai, messages, functions, callback_fn)
+   call(google_ai, messages, tools, callback_fn)
end

- def call(%ChatGoogleAI{} = google_ai, messages, functions, callback_fn)
+ def call(%ChatGoogleAI{} = google_ai, messages, tools, callback_fn)
when is_list(messages) do
try do
-     case do_api_request(google_ai, messages, functions, callback_fn) do
+     case do_api_request(google_ai, messages, tools, callback_fn) do
{:error, reason} ->
{:error, reason}

@@ -259,11 +264,11 @@
@doc false
@spec do_api_request(t(), [Message.t()], [Function.t()], (any() -> any())) ::
list() | struct() | {:error, String.t()}
- def do_api_request(%ChatGoogleAI{stream: false} = google_ai, messages, functions, callback_fn) do
+ def do_api_request(%ChatGoogleAI{stream: false} = google_ai, messages, tools, callback_fn) do
req =
Req.new(
url: build_url(google_ai),
-     json: for_api(google_ai, messages, functions),
+     json: for_api(google_ai, messages, tools),
receive_timeout: google_ai.receive_timeout,
retry: :transient,
max_retries: 3,
@@ -292,15 +297,21 @@
end
end

- def do_api_request(%ChatGoogleAI{stream: true} = google_ai, messages, functions, callback_fn) do
+ def do_api_request(%ChatGoogleAI{stream: true} = google_ai, messages, tools, callback_fn) do
Req.new(
url: build_url(google_ai),
-   json: for_api(google_ai, messages, functions),
+   json: for_api(google_ai, messages, tools),
receive_timeout: google_ai.receive_timeout
)
|> Req.Request.put_header("accept-encoding", "utf-8")
|> Req.post(
-   into: Utils.handle_stream_fn(google_ai, &ChatOpenAI.decode_stream/1, &do_process_response(&1, MessageDelta), callback_fn)
+   into:
+     Utils.handle_stream_fn(
+       google_ai,
+       &ChatOpenAI.decode_stream/1,
+       &do_process_response(&1, MessageDelta),
+       callback_fn
+     )
)
|> case do
{:ok, %Req.Response{body: data}} ->
@@ -349,19 +360,96 @@
|> Enum.map(&do_process_response(&1, message_type))
end
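# Editorial note: the (truncated) clause above appears to map each returned
# candidate through do_process_response/2. An assumed Gemini-style response
# body (illustration only, not captured from this commit) looks roughly like:
#
#   %{
#     "candidates" => [
#       %{
#         "content" => %{"role" => "model", "parts" => [%{"text" => "Hello!"}]},
#         "index" => 0
#       }
#     ]
#   }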

- def do_process_response(
-       %{
-         "content" => %{"parts" => [%{"functionCall" => %{"args" => raw_args, "name" => name}}]}
-       } = data,
-       message_type
-     ) do
-   case message_type.new(%{
-          "role" => "assistant",
-          "function_name" => name,
-          "arguments" => raw_args,
-          "complete" => true,
-          "index" => data["index"]
-        }) do
def do_process_response(%{"content" => %{"parts" => parts} = content_data} = data, Message) do
text_part =
parts
|> filter_parts_for_types(["text"])
|> Enum.map(fn part ->
ContentPart.new!(%{type: :text, content: part["text"]})
end)

tool_calls_from_parts =
parts
|> filter_parts_for_types(["functionCall"])
|> Enum.map(fn part ->
do_process_response(part, nil)
end)

tool_result_from_parts =
parts
|> filter_parts_for_types(["functionResponse"])
|> Enum.map(fn part ->
do_process_response(part, nil)
end)

%{
role: unmap_role(content_data["role"]),
content: text_part,
complete: false,
index: data["index"]
}
|> Utils.conditionally_add_to_map(:tool_calls, tool_calls_from_parts)
|> Utils.conditionally_add_to_map(:tool_results, tool_result_from_parts)
|> Message.new()
|> case do
{:ok, message} ->
message

{:error, changeset} ->
{:error, Utils.changeset_error_to_string(changeset)}
end
end

def do_process_response(%{"content" => %{"parts" => parts} = content_data} = data, MessageDelta) do
text_content =
case parts do
[%{"text" => text}] ->
text

_other ->
nil
end

+   tool_calls_from_parts =
+     parts
+     |> filter_parts_for_types(["functionCall"])
+     |> Enum.map(fn part ->
+       do_process_response(part, nil)
+     end)
+
+   %{
+     role: unmap_role(content_data["role"]),
+     content: text_content,
+     complete: true,
+     index: data["index"]
+   }
+   |> Utils.conditionally_add_to_map(:tool_calls, tool_calls_from_parts)
+   |> MessageDelta.new()
+   |> case do
+     {:ok, message} ->
+       message
+
+     {:error, changeset} ->
+       {:error, Utils.changeset_error_to_string(changeset)}
+   end
+ end
+
+ def do_process_response(%{"functionCall" => %{"args" => raw_args, "name" => name}} = data, _) do
+   %{
+     call_id: "call-#{name}",
+     name: name,
+     arguments: raw_args,
+     complete: true,
+     index: data["index"]
+   }
+   |> ToolCall.new()
+   |> case do
{:ok, message} ->
message

@@ -430,6 +518,40 @@
{:error, "Unexpected response"}
end

+ @doc false
+ def filter_parts_for_types(parts, types) when is_list(parts) and is_list(types) do
+   Enum.filter(parts, fn p ->
+     Enum.any?(types, &Map.has_key?(p, &1))
+   end)
+ end
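# Editorial example (assumed inputs): a part is kept when it has any of the
# requested top-level keys.
#
#   filter_parts_for_types([%{"text" => "hi"}, %{"functionCall" => %{}}], ["text"])
#   #=> [%{"text" => "hi"}]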

@doc """
Return the content parts for the message.
"""
@spec get_message_contents(MessageDelta.t() | Message.t()) :: [%{String.t() => any()}]
def get_message_contents(%{content: content} = _message) when is_binary(content) do
[%{"text" => content}]
end

def get_message_contents(%{content: contents} = _message) when is_list(contents) do
Enum.map(contents, &for_api/1)
end

def get_message_contents(%{content: nil} = _message) do
nil
end
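# Editorial sketch of the three clauses above (assumed message values):
#
#   get_message_contents(%Message{content: "hi"})
#   #=> [%{"text" => "hi"}]
#
#   get_message_contents(%Message{content: [%ContentPart{type: :text, content: "hi"}]})
#   #=> [%{"text" => "hi"}]
#
#   get_message_contents(%Message{content: nil})
#   #=> nil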

+ defp map_role(role) do
+   case role do
+     :assistant -> :model
+     :tool -> :function
+     # System prompts are not supported yet. Google recommends using user prompt.
+     :system -> :user
+     role -> role
+   end
+ end
+
+ defp unmap_role("model"), do: "assistant"
+ defp unmap_role("function"), do: "tool"
+ defp unmap_role(role), do: role
end
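Taken together, a tool-result message now serializes into a single Gemini "function" content entry. A minimal sketch under the struct fields used above (illustrative values; for_api/1 is private and shown here only to trace the data flow):

%Message{role: :tool, tool_results: [%ToolResult{name: "get_weather", content: ~s({"temp": 17})}]}
|> for_api()
#=> %{
#     "role" => :function,
#     "parts" => [%{"functionResponse" => %{"name" => "get_weather", "response" => %{"temp" => 17}}}]
#   }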
2 changes: 1 addition & 1 deletion lib/message.ex
@@ -180,7 +180,7 @@ defmodule LangChain.Message do
changeset

{:ok, [%ContentPart{} | _] = value} ->
-       if role == :user do
+       if role in [:user, :assistant] do
# if a list, verify all elements are a ContentPart
if Enum.all?(value, &match?(%ContentPart{}, &1)) do
changeset
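With this one-line change, assistant messages may also carry a list of ContentParts (previously only user messages could). A hedged example, assuming ContentPart.text!/1:

Message.new!(%{role: :assistant, content: [ContentPart.text!("Hello!")]})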
2 changes: 1 addition & 1 deletion test/chat_models/chat_anthropic_test.exs
@@ -1180,7 +1180,7 @@ data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text
send(test_pid, {:streamed_fn, data})
end

-     {:ok, result_chain, last_message} =
+     {:ok, _result_chain, last_message} =
LLMChain.new!(%{llm: %ChatAnthropic{model: @test_model, stream: true}})
|> LLMChain.add_message(Message.new_system!("You are a helpful and concise assistant."))
|> LLMChain.add_message(Message.new_user!("Say, 'Hi!'!"))