Skip to content

Commit

Permalink
formatting
Browse files — browse the repository at this point in the history
  • Loading branch information
brainlid committed Apr 27, 2024
1 parent 6f3eef2 commit 6ac8fcc
Show file tree
Hide file tree
Showing 10 changed files with 36 additions and 12 deletions.
5 changes: 4 additions & 1 deletion lib/chains/llm_chain.ex
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,9 @@ defmodule LangChain.Chains.LLMChain do
combined_results = async_results ++ sync_results ++ invalid_calls

# create a single tool message that contains all the tool results
message = Message.new_tool_result!(%{content: message.content, tool_results: combined_results})
message =
Message.new_tool_result!(%{content: message.content, tool_results: combined_results})

if chain.verbose, do: IO.inspect(message, label: "TOOL RESULTS")
fire_callback(chain, message)

Expand Down Expand Up @@ -508,6 +510,7 @@ defmodule LangChain.Chains.LLMChain do

{:error, reason} when is_binary(reason) ->
if verbose, do: IO.inspect(reason, label: "FUNCTION ERROR")

ToolResult.new!(%{
tool_call_id: call.call_id,
content: reason,
Expand Down
10 changes: 9 additions & 1 deletion lib/chat_models/chat_mistral_ai.ex
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,15 @@ defmodule Langchain.ChatModels.ChatMistralAI do
headers: get_headers(mistral),
receive_timeout: mistral.receive_timeout
)
|> Req.post(into: Utils.handle_stream_fn(mistral, &ChatOpenAI.decode_stream/1, &do_process_response/1, callback_fn))
|> Req.post(
into:
Utils.handle_stream_fn(
mistral,
&ChatOpenAI.decode_stream/1,
&do_process_response/1,
callback_fn
)
)
|> case do
{:ok, %Req.Response{body: data}} ->
data
Expand Down
3 changes: 2 additions & 1 deletion lib/chat_models/chat_model.ex
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ defmodule LangChain.ChatModels.ChatModel do
alias LangChain.MessageDelta
alias LangChain.Function

@type call_response :: {:ok, Message.t() | [Message.t()] | [MessageDelta.t()]} | {:error, String.t()}
@type call_response ::
{:ok, Message.t() | [Message.t()] | [MessageDelta.t()]} | {:error, String.t()}

@type tool :: Function.t()
@type tools :: [tool()]
Expand Down
10 changes: 9 additions & 1 deletion lib/chat_models/chat_ollama_ai.ex
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,15 @@ defmodule LangChain.ChatModels.ChatOllamaAI do
json: for_api(ollama_ai, messages, functions),
receive_timeout: ollama_ai.receive_timeout
)
|> Req.post(into: Utils.handle_stream_fn(ollama_ai, &ChatOpenAI.decode_stream/1, &do_process_response/1, callback_fn))
|> Req.post(
into:
Utils.handle_stream_fn(
ollama_ai,
&ChatOpenAI.decode_stream/1,
&do_process_response/1,
callback_fn
)
)
|> case do
{:ok, %Req.Response{body: data}} ->
data
Expand Down
2 changes: 1 addition & 1 deletion lib/chat_models/chat_open_ai.ex
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ defmodule LangChain.ChatModels.ChatOpenAI do

def for_api(%Message{role: :tool, tool_results: tool_results} = _msg)
when is_list(tool_results) do
# ToolResults turn into a list of tool messages for OpenAI
# ToolResults turn into a list of tool messages for OpenAI
Enum.map(tool_results, fn result ->
%{
"role" => :tool,
Expand Down
3 changes: 2 additions & 1 deletion lib/function_param.ex
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,8 @@ defmodule LangChain.FunctionParam do
end

def to_json_schema(%{} = _data, param) do
raise LangChainError, "Expected to receive a FunctionParam but instead received #{inspect param}"
raise LangChainError,
"Expected to receive a FunctionParam but instead received #{inspect(param)}"
end

# conditionally add the description field if set
Expand Down
1 change: 0 additions & 1 deletion test/chat_models/chat_anthropic_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -1010,7 +1010,6 @@ data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text
end
end


describe "image vision using message parts" do
@tag live_call: true, live_anthropic: true
test "supports multi-modal user message with image prompt" do
Expand Down
1 change: 1 addition & 0 deletions test/chat_models/chat_google_ai_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ defmodule ChatModels.ChatGoogleAITest do
assert %{"role" => :user, "parts" => [%{"text" => ^message}]} = msg1
assert %{"role" => :model, "parts" => [tool_call]} = msg2
assert %{"role" => :function, "parts" => [tool_result]} = msg3

assert %{
"functionCall" => %{
"args" => ^arguments,
Expand Down
4 changes: 2 additions & 2 deletions test/support/fixtures.ex
Original file line number Diff line number Diff line change
Expand Up @@ -973,7 +973,7 @@ defmodule LangChain.Fixtures do
status: :incomplete,
index: 0,
tool_calls: [%ToolCall{call_id: "call_123", name: "regions_list", index: 0}],
role: :tool_call,
role: :tool_call
}
],
[
Expand All @@ -982,7 +982,7 @@ defmodule LangChain.Fixtures do
status: :incomplete,
index: 0,
role: :tool_call,
tool_calls: [%ToolCall{arguments: "{}", index: 0}],
tool_calls: [%ToolCall{arguments: "{}", index: 0}]
}
],
[
Expand Down
9 changes: 6 additions & 3 deletions test/utils/chat_templates_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,8 @@ defmodule LangChain.Utils.ChatTemplatesTest do
Message.new_user!("user_prompt")
]

expected = "<|begin_of_text|>\n<|start_header_id|>system<|end_header_id|>\n\nsystem_message<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n\nuser_prompt<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>\n\n"
expected =
"<|begin_of_text|>\n<|start_header_id|>system<|end_header_id|>\n\nsystem_message<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n\nuser_prompt<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>\n\n"

result = ChatTemplates.apply_chat_template!(messages, :llama_3)
assert result == expected
Expand All @@ -419,7 +420,8 @@ defmodule LangChain.Utils.ChatTemplatesTest do
Message.new_user!("user_prompt")
]

expected = "<|begin_of_text|>\n<|start_header_id|>system<|end_header_id|>\n\nsystem_message<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n\nuser_prompt<|eot_id|>\n"
expected =
"<|begin_of_text|>\n<|start_header_id|>system<|end_header_id|>\n\nsystem_message<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n\nuser_prompt<|eot_id|>\n"

result =
ChatTemplates.apply_chat_template!(messages, :llama_3, add_generation_prompt: false)
Expand All @@ -430,7 +432,8 @@ defmodule LangChain.Utils.ChatTemplatesTest do
test "no system message when not provided" do
messages = [Message.new_user!("user_prompt")]

expected = "<|begin_of_text|>\n<|start_header_id|>user<|end_header_id|>\n\nuser_prompt<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>\n\n"
expected =
"<|begin_of_text|>\n<|start_header_id|>user<|end_header_id|>\n\nuser_prompt<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>\n\n"

result = ChatTemplates.apply_chat_template!(messages, :llama_3)
assert result == expected
Expand Down

0 comments on commit 6ac8fcc

Please sign in to comment.