diff --git a/.github/workflows/elixir.yml b/.github/workflows/elixir.yml index 132df1e4..5237d5d6 100644 --- a/.github/workflows/elixir.yml +++ b/.github/workflows/elixir.yml @@ -15,6 +15,8 @@ env: OPENAI_API_KEY: invalid ANTHROPIC_API_KEY: invalid GOOGLE_API_KEY: invalid + AWS_ACCESS_KEY_ID: invalid + AWS_SECRET_ACCESS_KEY: invalid permissions: contents: read diff --git a/lib/chat_models/chat_anthropic.ex b/lib/chat_models/chat_anthropic.ex index a5ec19c2..552ffeec 100644 --- a/lib/chat_models/chat_anthropic.ex +++ b/lib/chat_models/chat_anthropic.ex @@ -54,6 +54,8 @@ defmodule LangChain.ChatModels.ChatAnthropic do alias LangChain.FunctionParam alias LangChain.Utils alias LangChain.Callbacks + alias LangChain.Utils.BedrockStreamDecoder + alias LangChain.Utils.BedrockConfig @behaviour ChatModel @@ -67,6 +69,9 @@ defmodule LangChain.ChatModels.ChatAnthropic do # API endpoint to use. Defaults to Anthropic's API field :endpoint, :string, default: "https://api.anthropic.com/v1/messages" + # Configuration for AWS Bedrock. Configure this instead of endpoint & api_key if you want to use Bedrock. + embeds_one :bedrock, BedrockConfig + # API key for Anthropic. If not set, will use global api key. Allows for usage # of a different API key per-call if desired. For instance, allowing a # customer to provide their own. @@ -131,12 +136,6 @@ defmodule LangChain.ChatModels.ChatAnthropic do ] @required_fields [:endpoint, :model] - @spec get_api_key(t()) :: String.t() - defp get_api_key(%ChatAnthropic{api_key: api_key}) do - # if no API key is set default to `""` which will raise an error - api_key || Config.resolve(:anthropic_key, "") - end - @doc """ Setup a ChatAnthropic client configuration. 
""" @@ -144,6 +143,7 @@ defmodule LangChain.ChatModels.ChatAnthropic do def new(%{} = attrs \\ %{}) do %ChatAnthropic{} |> cast(attrs, @create_fields) + |> cast_embed(:bedrock) |> common_validation() |> apply_action(:insert) end @@ -175,7 +175,8 @@ defmodule LangChain.ChatModels.ChatAnthropic do @spec for_api(t, message :: [map()], ChatModel.tools()) :: %{atom() => any()} def for_api(%ChatAnthropic{} = anthropic, messages, tools) do # separate the system message from the rest. Handled separately. - {system, messages} = split_system_message(messages) + {system, messages} = + Utils.split_system_message(messages, "Anthropic only supports a single System message") system_text = case system do @@ -203,6 +204,15 @@ defmodule LangChain.ChatModels.ChatAnthropic do |> Utils.conditionally_add_to_map(:max_tokens, anthropic.max_tokens) |> Utils.conditionally_add_to_map(:top_p, anthropic.top_p) |> Utils.conditionally_add_to_map(:top_k, anthropic.top_k) + |> maybe_transform_for_bedrock(anthropic.bedrock) + end + + defp maybe_transform_for_bedrock(body, nil), do: body + + defp maybe_transform_for_bedrock(body, %BedrockConfig{} = bedrock) do + body + |> Map.put(:anthropic_version, bedrock.anthropic_version) + |> Map.drop([:model, :stream]) end defp get_tools_for_api(nil), do: [] @@ -214,21 +224,6 @@ defmodule LangChain.ChatModels.ChatAnthropic do end) end - # Unlike OpenAI, Anthropic only supports one system message. - @doc false - @spec split_system_message([Message.t()]) :: {nil | Message.t(), [Message.t()]} | no_return() - def split_system_message(messages) do - # split the messages into "system" and "other". Error if more than 1 system - # message. Return the other messages as a separate list. 
- {system, other} = Enum.split_with(messages, &(&1.role == :system)) - - if length(system) > 1 do - raise LangChainError, "Anthropic only supports a single System message" - end - - {List.first(system), other} - end - @doc """ Calls the Anthropic API passing the ChatAnthropic struct with configuration, plus either a simple message or the list of messages to act as the prompt. @@ -301,13 +296,14 @@ defmodule LangChain.ChatModels.ChatAnthropic do ) do req = Req.new( - url: anthropic.endpoint, + url: url(anthropic), json: for_api(anthropic, messages, tools), - headers: headers(get_api_key(anthropic), anthropic.api_version), + headers: headers(anthropic), receive_timeout: anthropic.receive_timeout, retry: :transient, max_retries: 3, - retry_delay: fn attempt -> 300 * attempt end + retry_delay: fn attempt -> 300 * attempt end, + aws_sigv4: aws_sigv4_opts(anthropic.bedrock) ) req @@ -355,14 +351,19 @@ defmodule LangChain.ChatModels.ChatAnthropic do retry_count ) do Req.new( - url: anthropic.endpoint, + url: url(anthropic), json: for_api(anthropic, messages, tools), - headers: headers(get_api_key(anthropic), anthropic.api_version), - receive_timeout: anthropic.receive_timeout + headers: headers(anthropic), + receive_timeout: anthropic.receive_timeout, + aws_sigv4: aws_sigv4_opts(anthropic.bedrock) ) |> Req.post( into: - Utils.handle_stream_fn(anthropic, &decode_stream/1, &do_process_response(anthropic, &1)) + Utils.handle_stream_fn( + anthropic, + &decode_stream(anthropic, &1), + &do_process_response(anthropic, &1) + ) ) |> case do {:ok, %Req.Response{body: data} = response} -> @@ -393,9 +394,18 @@ defmodule LangChain.ChatModels.ChatAnthropic do end end - defp headers(api_key, api_version) do + defp aws_sigv4_opts(nil), do: nil + defp aws_sigv4_opts(%BedrockConfig{} = bedrock), do: BedrockConfig.aws_sigv4_opts(bedrock) + + @spec get_api_key(binary() | nil) :: String.t() + defp get_api_key(api_key) do + # if no API key is set default to `""` which will raise an error + 
api_key || Config.resolve(:anthropic_key, "") + end + + defp headers(%ChatAnthropic{bedrock: nil, api_key: api_key, api_version: api_version}) do %{ - "x-api-key" => api_key, + "x-api-key" => get_api_key(api_key), "content-type" => "application/json", "anthropic-version" => api_version, # https://docs.anthropic.com/claude/docs/tool-use - requires this header during beta @@ -403,6 +413,21 @@ defmodule LangChain.ChatModels.ChatAnthropic do } end + defp headers(%ChatAnthropic{bedrock: %BedrockConfig{}}) do + %{ + "content-type" => "application/json", + "accept" => "application/json" + } + end + + defp url(%ChatAnthropic{bedrock: nil} = anthropic) do + anthropic.endpoint + end + + defp url(%ChatAnthropic{bedrock: %BedrockConfig{} = bedrock, stream: stream} = anthropic) do + BedrockConfig.url(bedrock, model: anthropic.model, stream: stream) + end + # Parse a new message response @doc false @spec do_process_response(t(), data :: %{String.t() => any()} | {:error, any()}) :: @@ -527,6 +552,16 @@ defmodule LangChain.ChatModels.ChatAnthropic do {:error, error_message} end + def do_process_response(%ChatAnthropic{bedrock: %BedrockConfig{}}, %{"message" => message}) do + {:error, "Received error from API: #{message}"} + end + + def do_process_response(%ChatAnthropic{bedrock: %BedrockConfig{}}, %{ + bedrock_exception: exceptions + }) do + {:error, "Stream exception received: #{inspect(exceptions)}"} + end + def do_process_response(_model, other) do Logger.error("Trying to process an unexpected response. 
#{inspect(other)}") {:error, "Unexpected response"} @@ -597,7 +632,7 @@ defmodule LangChain.ChatModels.ChatAnthropic do end @doc false - def decode_stream({chunk, buffer}) do + def decode_stream(%ChatAnthropic{bedrock: nil}, {chunk, buffer}) do # Combine the incoming data with the buffered incomplete data combined_data = buffer <> chunk # Split data by double newline to find complete messages @@ -665,6 +700,18 @@ defmodule LangChain.ChatModels.ChatAnthropic do # assumed the response is JSON. Return as-is defp extract_data(json), do: json + @doc false + def decode_stream(%ChatAnthropic{bedrock: %BedrockConfig{}}, {chunk, buffer}, chunks \\ []) do + {chunks, remaining} = BedrockStreamDecoder.decode_stream({chunk, buffer}, chunks) + + chunks = + Enum.filter(chunks, fn chunk -> + Map.has_key?(chunk, :bedrock_exception) || relevant_event?("event: #{chunk["type"]}\n") + end) + + {chunks, remaining} + end + @doc """ Convert a LangChain structure to the expected map of data for the OpenAI API. """ diff --git a/lib/chat_models/chat_google_ai.ex b/lib/chat_models/chat_google_ai.ex index f24cdfaf..be86cef9 100644 --- a/lib/chat_models/chat_google_ai.ex +++ b/lib/chat_models/chat_google_ai.ex @@ -138,20 +138,34 @@ defmodule LangChain.ChatModels.ChatGoogleAI do end def for_api(%ChatGoogleAI{} = google_ai, messages, functions) do + {system, messages} = + Utils.split_system_message(messages, "Google AI only supports a single System message") + + system_instruction = + case system do + nil -> + nil + + %Message{role: :system, content: content} -> + %{"parts" => [%{"text" => content}]} + end + messages_for_api = messages |> Enum.map(&for_api/1) |> List.flatten() |> List.wrap() - req = %{ - "contents" => messages_for_api, - "generationConfig" => %{ - "temperature" => google_ai.temperature, - "topP" => google_ai.top_p, - "topK" => google_ai.top_k + req = + %{ + "contents" => messages_for_api, + "generationConfig" => %{ + "temperature" => google_ai.temperature, + "topP" => 
google_ai.top_p, + "topK" => google_ai.top_k + } } - } + |> LangChain.Utils.conditionally_add_to_map("system_instruction", system_instruction) if functions && not Enum.empty?(functions) do req @@ -159,7 +173,7 @@ defmodule LangChain.ChatModels.ChatGoogleAI do %{ # Google AI functions use an OpenAI compatible format. # See: https://ai.google.dev/docs/function_calling#how_it_works - "functionDeclarations" => Enum.map(functions, &ChatOpenAI.for_api/1) + "functionDeclarations" => Enum.map(functions, &for_api/1) } ]) else @@ -188,21 +202,6 @@ defmodule LangChain.ChatModels.ChatGoogleAI do } end - def for_api(%Message{role: :system} = message) do - # No system messages support means we need to fake a prompt and response - # to pretend like it worked. - [ - %{ - "role" => :user, - "parts" => [%{"text" => message.content}] - }, - %{ - "role" => :model, - "parts" => [%{"text" => ""}] - } - ] - end - def for_api(%Message{} = message) do %{ "role" => map_role(message.role), @@ -249,6 +248,18 @@ defmodule LangChain.ChatModels.ChatGoogleAI do } end + def for_api(%Function{} = function) do + encoded = ChatOpenAI.for_api(function) + + # For functions with no parameters, Google AI needs the parameters field removing, otherwise it will error + # with "* GenerateContentRequest.tools[0].function_declarations[0].parameters.properties: should be non-empty for OBJECT type\n" + if encoded["parameters"] == %{"properties" => %{}, "type" => "object"} do + Map.delete(encoded, "parameters") + else + encoded + end + end + @doc """ Calls the Google AI API passing the ChatGoogleAI struct with configuration, plus either a simple message or the list of messages to act as the prompt. 
@@ -426,6 +437,7 @@ defmodule LangChain.ChatModels.ChatGoogleAI do text_part = parts |> filter_parts_for_types(["text"]) + |> filter_text_parts() |> Enum.map(fn part -> ContentPart.new!(%{type: :text, content: part["text"]}) end) @@ -479,6 +491,7 @@ defmodule LangChain.ChatModels.ChatGoogleAI do parts |> filter_parts_for_types(["text"]) + |> filter_text_parts() |> Enum.map(fn part -> ContentPart.new!(%{type: :text, content: part["text"]}) end) @@ -597,6 +610,16 @@ defmodule LangChain.ChatModels.ChatGoogleAI do end) end + @doc false + def filter_text_parts(parts) when is_list(parts) do + Enum.filter(parts, fn p -> + case p do + %{"text" => text} -> text && text != "" + _ -> false + end + end) + end + @doc """ Return the content parts for the message. """ @@ -660,8 +683,8 @@ defmodule LangChain.ChatModels.ChatGoogleAI do defp get_token_usage(%{"usageMetadata" => usage} = _response_body) do # extract out the reported response token usage TokenUsage.new!(%{ - input: Map.get(usage, "promptTokenCount"), - output: Map.get(usage, "candidatesTokenCount") + input: Map.get(usage, "promptTokenCount", 0), + output: Map.get(usage, "candidatesTokenCount", 0) }) end diff --git a/lib/chat_models/chat_open_ai.ex b/lib/chat_models/chat_open_ai.ex index cb2a189f..ecd9f53f 100644 --- a/lib/chat_models/chat_open_ai.ex +++ b/lib/chat_models/chat_open_ai.ex @@ -122,6 +122,7 @@ defmodule LangChain.ChatModels.ChatOpenAI do # How many chat completion choices to generate for each input message. field :n, :integer, default: 1 field :json_response, :boolean, default: false + field :json_schema, :map, default: nil field :stream, :boolean, default: false field :max_tokens, :integer, default: nil # Options for streaming response. 
Only set this when you set `stream: true` @@ -153,6 +154,7 @@ defmodule LangChain.ChatModels.ChatOpenAI do :stream, :receive_timeout, :json_response, + :json_schema, :max_tokens, :stream_options, :user, @@ -263,11 +265,20 @@ defmodule LangChain.ChatModels.ChatOpenAI do %{"include_usage" => Map.get(data, :include_usage, Map.get(data, "include_usage"))} end - defp set_response_format(%ChatOpenAI{json_response: true}), - do: %{"type" => "json_object"} + defp set_response_format(%ChatOpenAI{json_response: true, json_schema: json_schema}) when not is_nil(json_schema) do + %{ + "type" => "json_schema", + "json_schema" => json_schema + } + end - defp set_response_format(%ChatOpenAI{json_response: false}), - do: %{"type" => "text"} + defp set_response_format(%ChatOpenAI{json_response: true}) do + %{"type" => "json_object"} + end + + defp set_response_format(%ChatOpenAI{json_response: false}) do + %{"type" => "text"} + end @doc """ Convert a LangChain structure to the expected map of data for the OpenAI API. @@ -908,6 +919,7 @@ defmodule LangChain.ChatModels.ChatOpenAI do :seed, :n, :json_response, + :json_schema, :stream, :max_tokens, :stream_options diff --git a/lib/function.ex b/lib/function.ex index d79485ba..87d351b7 100644 --- a/lib/function.ex +++ b/lib/function.ex @@ -124,6 +124,9 @@ defmodule LangChain.Function do field :description, :string # Optional text the UI can display for when the function is executed. field :display_text, :string + # Optional flag to indicate if the function should be executed in strict mode. + # Defaults to `false`. + field :strict, :boolean, default: false # flag if the function should be auto-evaluated. Defaults to `false` # requiring an explicit step to perform the evaluation. 
# field :auto_evaluate, :boolean, default: false @@ -146,6 +149,7 @@ defmodule LangChain.Function do :name, :description, :display_text, + :strict, :parameters_schema, :parameters, :function, diff --git a/lib/message/tool_call.ex b/lib/message/tool_call.ex index e81ef3cf..8d2cfb7d 100644 --- a/lib/message/tool_call.ex +++ b/lib/message/tool_call.ex @@ -48,6 +48,7 @@ defmodule LangChain.Message.ToolCall do def new(attrs \\ %{}) do %ToolCall{} |> cast(attrs, @create_fields) + |> assign_string_value(:arguments, attrs) |> common_validations() |> apply_action(:insert) end @@ -139,6 +140,8 @@ defmodule LangChain.Message.ToolCall do raise LangChainError, "Can only merge tool calls with the same index" end + def merge(%ToolCall{} = t1, %ToolCall{} = t2) when t1 == t2, do: t1 + def merge(%ToolCall{} = primary, %ToolCall{} = call_part) do # merge the "part" into the primary. primary @@ -150,6 +153,11 @@ defmodule LangChain.Message.ToolCall do |> update_status(call_part) end + defp append_tool_name(%ToolCall{name: primary_name} = primary, %ToolCall{name: new_name}) + when primary_name == new_name do + primary + end + defp append_tool_name(%ToolCall{} = primary, %ToolCall{name: new_name}) when is_binary(new_name) do %ToolCall{primary | name: (primary.name || "") <> new_name} @@ -212,4 +220,21 @@ defmodule LangChain.Message.ToolCall do # no arguments to merge primary end + + # The contents and arguments get streamed as a string. A tool call may be part of a delta and it + # might be expected to have strings made up of spaces. The "cast" process of the changeset turns + # this into `nil` causing us to lose data. + # + # We want to take whatever we are given here. + defp assign_string_value(changeset, field, attrs) do + # get both possible versions of the arguments. 
+ case Map.get(attrs, field) || Map.get(attrs, to_string(field)) do + "" -> + changeset + val when is_binary(val) -> + put_change(changeset, field, val) + _ -> + changeset + end + end end diff --git a/lib/message_delta.ex b/lib/message_delta.ex index af95ee77..249e2360 100644 --- a/lib/message_delta.ex +++ b/lib/message_delta.ex @@ -56,7 +56,6 @@ defmodule LangChain.MessageDelta do %MessageDelta{} |> cast(attrs, @create_fields) |> assign_string_value(:content, attrs) - |> assign_string_value(:arguments, attrs) |> validate_required(@required_fields) |> apply_action(:insert) end diff --git a/lib/utils.ex b/lib/utils.ex index 6a7ba642..11f3d3b3 100644 --- a/lib/utils.ex +++ b/lib/utils.ex @@ -288,4 +288,21 @@ defmodule LangChain.Utils do Logger.error(msg) {:error, msg} end + + @doc """ + Split the messages into "system" and "other". + Raises an error with the specified error message if more than 1 system message present. + Returns a tuple with the single system message and the list of other messages. + """ + @spec split_system_message([Message.t()], error_message :: String.t()) :: + {nil | Message.t(), [Message.t()]} | no_return() + def split_system_message(messages, error_message \\ "Only one system message is allowed") do + {system, other} = Enum.split_with(messages, &(&1.role == :system)) + + if length(system) > 1 do + raise LangChainError, error_message + end + + {List.first(system), other} + end end diff --git a/lib/utils/aws_eventstream_decoder.ex b/lib/utils/aws_eventstream_decoder.ex new file mode 100644 index 00000000..43869de7 --- /dev/null +++ b/lib/utils/aws_eventstream_decoder.ex @@ -0,0 +1,43 @@ +defmodule LangChain.Utils.AwsEventstreamDecoder do + @moduledoc """ + Decodes AWS messages in the application/vnd.amazon.eventstream content-type. + Ignores the headers because on Bedrock it's the same content type, event type & message type headers in every message. 
+ """ + + def decode(<< + message_length::32, + headers_length::32, + prelude_checksum::32, + headers::binary-size(headers_length), + body::binary-size(message_length - headers_length - 16), + message_checksum::32, + rest::bitstring + >>) do + message_without_checksum = + <<message_length::32, headers_length::32, prelude_checksum::32, headers::binary, + body::binary>> + + with :ok <- + verify_checksum(<<message_length::32, headers_length::32>>, prelude_checksum, :prelude), + :ok <- verify_checksum(message_without_checksum, message_checksum, :message) do + {:ok, body, rest} + end + end + + def decode(<<message_length::32, _rest::bitstring>> = data) do + {:incomplete_message, "Expected message length #{message_length} but got #{byte_size(data)}"} + end + + def decode(_) do + {:error, "Unable to decode message"} + end + + defp verify_checksum(data, checksum, part) do + if :erlang.crc32(data) == checksum do + :ok + else + {:error, "Checksum mismatch for #{part}"} + end + end +end diff --git a/lib/utils/bedrock_config.ex b/lib/utils/bedrock_config.ex new file mode 100644 index 00000000..b97d4edd --- /dev/null +++ b/lib/utils/bedrock_config.ex @@ -0,0 +1,36 @@ +defmodule LangChain.Utils.BedrockConfig do + @moduledoc """ + Configuration for AWS Bedrock. + """ + use Ecto.Schema + import Ecto.Changeset + + @primary_key false + embedded_schema do + # A function that returns a keyword list including access_key_id, secret_access_key, and optionally token. + # Used to configure Req's aws_sigv4 option. 
+ field :credentials, :any, virtual: true + field :region, :string + field :anthropic_version, :string, default: "bedrock-2023-05-31" + end + + def changeset(bedrock, attrs) do + bedrock + |> cast(attrs, [:credentials, :region, :anthropic_version]) + |> validate_required([:credentials, :region, :anthropic_version]) + end + + def aws_sigv4_opts(%__MODULE__{} = bedrock) do + Keyword.merge(bedrock.credentials.(), + region: bedrock.region, + service: :bedrock + ) + end + + def url(%__MODULE__{region: region}, model: model, stream: stream) do + "https://bedrock-runtime.#{region}.amazonaws.com/model/#{model}/#{action(stream: stream)}" + end + + defp action(stream: true), do: "invoke-with-response-stream" + defp action(stream: false), do: "invoke" +end diff --git a/lib/utils/bedrock_stream_decoder.ex b/lib/utils/bedrock_stream_decoder.ex new file mode 100644 index 00000000..6dce39b9 --- /dev/null +++ b/lib/utils/bedrock_stream_decoder.ex @@ -0,0 +1,77 @@ +defmodule LangChain.Utils.BedrockStreamDecoder do + alias LangChain.Utils.AwsEventstreamDecoder + require Logger + + def decode_stream({chunk, buffer}, chunks \\ []) do + combined_data = buffer <> chunk + + case decode_chunk(combined_data) do + {:ok, chunk, remaining} -> + chunks = [chunk | chunks] + finish_or_decode_remaining(chunks, remaining) + + {:incomplete_message, _} -> + {chunks, combined_data} + + {:exception_response, response, remaining} -> + chunks = [response | chunks] + finish_or_decode_remaining(chunks, remaining) + + {:error, error} -> + Logger.error("Failed to decode Bedrock chunk: #{inspect(error)}") + {chunks, combined_data} + end + end + + defp finish_or_decode_remaining(chunks, remaining) when byte_size(remaining) > 0 do + decode_stream({"", remaining}, chunks) + end + + defp finish_or_decode_remaining(chunks, remaining) do + {chunks, remaining} + end + + defp decode_chunk(chunk) do + with {:ok, decoded_message, remaining} <- AwsEventstreamDecoder.decode(chunk), + {:ok, response_json} <- 
decode_json(decoded_message), + {:ok, bytes} <- get_bytes(response_json, remaining), + {:ok, json} <- decode_base64(bytes), + {:ok, payload} <- decode_json(json) do + {:ok, payload, remaining} + end + end + + defp decode_json(data) do + case Jason.decode(data) do + {:ok, json} -> + {:ok, json} + + {:error, error} -> + {:error, "Unable to decode JSON: #{inspect(error)}"} + end + end + + defp decode_base64(bytes) do + case Base.decode64(bytes) do + {:ok, bytes} -> + {:ok, bytes} + + :error -> + {:error, "Unable to decode base64 \"bytes\" from Bedrock response"} + end + end + + defp get_bytes(%{"bytes" => bytes}, _remaining) do + {:ok, bytes} + end + + # bytes is likely missing from the response in exception cases + # https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_InvokeModelWithResponseStream.html + defp get_bytes(response, remaining) do + Logger.debug("Bedrock response is an exception: #{inspect(response)}") + exception_message = Map.keys(response) |> Enum.join(", ") + # Make it easier to match on this pattern in process_data fns + response = Map.put(response, :bedrock_exception, exception_message) + {:exception_response, response, remaining} + end +end diff --git a/mix.exs b/mix.exs index 1de843d2..5b19bd5d 100644 --- a/mix.exs +++ b/mix.exs @@ -38,7 +38,7 @@ defmodule LangChain.MixProject do [ {:ecto, "~> 3.10 or ~> 3.11"}, {:gettext, "~> 0.20"}, - {:req, ">= 0.5.0"}, + {:req, ">= 0.5.2"}, {:abacus, "~> 2.1.0"}, {:nx, ">= 0.7.0", optional: true}, {:ex_doc, "~> 0.34", only: :dev, runtime: false}, diff --git a/mix.lock b/mix.lock index eff07dd7..1429e187 100644 --- a/mix.lock +++ b/mix.lock @@ -1,28 +1,28 @@ %{ "abacus": {:hex, :abacus, "2.1.0", "b6db5c989ba3d9dd8c36d1cb269e2f0058f34768d47c67eb8ce06697ecb36dd4", [:mix], [], "hexpm", "255de08b02884e8383f1eed8aa31df884ce0fb5eb394db81ff888089f2a1bbff"}, - "castore": {:hex, :castore, "1.0.7", "b651241514e5f6956028147fe6637f7ac13802537e895a724f90bf3e36ddd1dd", [:mix], [], "hexpm", 
"da7785a4b0d2a021cd1292a60875a784b6caef71e76bf4917bdee1f390455cf5"}, + "castore": {:hex, :castore, "1.0.8", "dedcf20ea746694647f883590b82d9e96014057aff1d44d03ec90f36a5c0dc6e", [:mix], [], "hexpm", "0b2b66d2ee742cb1d9cb8c8be3b43c3a70ee8651f37b75a8b982e036752983f1"}, "complex": {:hex, :complex, "0.5.0", "af2d2331ff6170b61bb738695e481b27a66780e18763e066ee2cd863d0b1dd92", [:mix], [], "hexpm", "2683bd3c184466cfb94fad74cbfddfaa94b860e27ad4ca1bffe3bff169d91ef1"}, "decimal": {:hex, :decimal, "2.1.1", "5611dca5d4b2c3dd497dec8f68751f1f1a54755e8ed2a966c2633cf885973ad6", [:mix], [], "hexpm", "53cfe5f497ed0e7771ae1a475575603d77425099ba5faef9394932b35020ffcc"}, - "earmark_parser": {:hex, :earmark_parser, "1.4.39", "424642f8335b05bb9eb611aa1564c148a8ee35c9c8a8bba6e129d51a3e3c6769", [:mix], [], "hexpm", "06553a88d1f1846da9ef066b87b57c6f605552cfbe40d20bd8d59cc6bde41944"}, - "ecto": {:hex, :ecto, "3.11.1", "4b4972b717e7ca83d30121b12998f5fcdc62ba0ed4f20fd390f16f3270d85c3e", [:mix], [{:decimal, "~> 2.0", [hex: :decimal, repo: "hexpm", optional: false]}, {:jason, "~> 1.0", [hex: :jason, repo: "hexpm", optional: true]}, {:telemetry, "~> 0.4 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "ebd3d3772cd0dfcd8d772659e41ed527c28b2a8bde4b00fe03e0463da0f1983b"}, + "earmark_parser": {:hex, :earmark_parser, "1.4.40", "f3534689f6b58f48aa3a9ac850d4f05832654fe257bf0549c08cc290035f70d5", [:mix], [], "hexpm", "cdb34f35892a45325bad21735fadb88033bcb7c4c296a999bde769783f53e46a"}, + "ecto": {:hex, :ecto, "3.11.2", "e1d26be989db350a633667c5cda9c3d115ae779b66da567c68c80cfb26a8c9ee", [:mix], [{:decimal, "~> 2.0", [hex: :decimal, repo: "hexpm", optional: false]}, {:jason, "~> 1.0", [hex: :jason, repo: "hexpm", optional: true]}, {:telemetry, "~> 0.4 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "3c38bca2c6f8d8023f2145326cc8a80100c3ffe4dcbd9842ff867f7fc6156c65"}, "elixir_make": {:hex, :elixir_make, "0.7.8", 
"505026f266552ee5aabca0b9f9c229cbb496c689537c9f922f3eb5431157efc7", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: true]}, {:certifi, "~> 2.0", [hex: :certifi, repo: "hexpm", optional: true]}], "hexpm", "7a71945b913d37ea89b06966e1342c85cfe549b15e6d6d081e8081c493062c07"}, "ex_doc": {:hex, :ex_doc, "0.34.1", "9751a0419bc15bc7580c73fde506b17b07f6402a1e5243be9e0f05a68c723368", [:mix], [{:earmark_parser, "~> 1.4.39", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_c, ">= 0.1.0", [hex: :makeup_c, repo: "hexpm", optional: true]}, {:makeup_elixir, "~> 0.14 or ~> 1.0", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1 or ~> 1.0", [hex: :makeup_erlang, repo: "hexpm", optional: false]}, {:makeup_html, ">= 0.1.0", [hex: :makeup_html, repo: "hexpm", optional: true]}], "hexpm", "d441f1a86a235f59088978eff870de2e815e290e44a8bd976fe5d64470a4c9d2"}, - "expo": {:hex, :expo, "0.4.1", "1c61d18a5df197dfda38861673d392e642649a9cef7694d2f97a587b2cfb319b", [:mix], [], "hexpm", "2ff7ba7a798c8c543c12550fa0e2cbc81b95d4974c65855d8d15ba7b37a1ce47"}, + "expo": {:hex, :expo, "0.5.2", "beba786aab8e3c5431813d7a44b828e7b922bfa431d6bfbada0904535342efe2", [:mix], [], "hexpm", "8c9bfa06ca017c9cb4020fabe980bc7fdb1aaec059fd004c2ab3bff03b1c599c"}, "finch": {:hex, :finch, "0.18.0", "944ac7d34d0bd2ac8998f79f7a811b21d87d911e77a786bc5810adb75632ada4", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: false]}, {:mime, "~> 1.0 or ~> 2.0", [hex: :mime, repo: "hexpm", optional: false]}, {:mint, "~> 1.3", [hex: :mint, repo: "hexpm", optional: false]}, {:nimble_options, "~> 0.4 or ~> 1.0", [hex: :nimble_options, repo: "hexpm", optional: false]}, {:nimble_pool, "~> 0.2.6 or ~> 1.0", [hex: :nimble_pool, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "69f5045b042e531e53edc2574f15e25e735b522c37e2ddb766e15b979e03aa65"}, - 
"gettext": {:hex, :gettext, "0.22.3", "c8273e78db4a0bb6fba7e9f0fd881112f349a3117f7f7c598fa18c66c888e524", [:mix], [{:expo, "~> 0.4.0", [hex: :expo, repo: "hexpm", optional: false]}], "hexpm", "935f23447713954a6866f1bb28c3a878c4c011e802bcd68a726f5e558e4b64bd"}, - "hpax": {:hex, :hpax, "0.2.0", "5a58219adcb75977b2edce5eb22051de9362f08236220c9e859a47111c194ff5", [:mix], [], "hexpm", "bea06558cdae85bed075e6c036993d43cd54d447f76d8190a8db0dc5893fa2f1"}, - "jason": {:hex, :jason, "1.4.1", "af1504e35f629ddcdd6addb3513c3853991f694921b1b9368b0bd32beb9f1b63", [:mix], [{:decimal, "~> 1.0 or ~> 2.0", [hex: :decimal, repo: "hexpm", optional: true]}], "hexpm", "fbb01ecdfd565b56261302f7e1fcc27c4fb8f32d56eab74db621fc154604a7a1"}, + "gettext": {:hex, :gettext, "0.24.0", "6f4d90ac5f3111673cbefc4ebee96fe5f37a114861ab8c7b7d5b30a1108ce6d8", [:mix], [{:expo, "~> 0.5.1", [hex: :expo, repo: "hexpm", optional: false]}], "hexpm", "bdf75cdfcbe9e4622dd18e034b227d77dd17f0f133853a1c73b97b3d6c770e8b"}, + "hpax": {:hex, :hpax, "1.0.0", "28dcf54509fe2152a3d040e4e3df5b265dcb6cb532029ecbacf4ce52caea3fd2", [:mix], [], "hexpm", "7f1314731d711e2ca5fdc7fd361296593fc2542570b3105595bb0bc6d0fad601"}, + "jason": {:hex, :jason, "1.4.3", "d3f984eeb96fe53b85d20e0b049f03e57d075b5acda3ac8d465c969a2536c17b", [:mix], [{:decimal, "~> 1.0 or ~> 2.0", [hex: :decimal, repo: "hexpm", optional: true]}], "hexpm", "9a90e868927f7c777689baa16d86f4d0e086d968db5c05d917ccff6d443e58a3"}, "makeup": {:hex, :makeup, "1.1.2", "9ba8837913bdf757787e71c1581c21f9d2455f4dd04cfca785c70bbfff1a76a3", [:mix], [{:nimble_parsec, "~> 1.2.2 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "cce1566b81fbcbd21eca8ffe808f33b221f9eee2cbc7a1706fc3da9ff18e6cac"}, "makeup_elixir": {:hex, :makeup_elixir, "0.16.2", "627e84b8e8bf22e60a2579dad15067c755531fea049ae26ef1020cad58fe9578", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3 or ~> 1.3", [hex: :nimble_parsec, repo: 
"hexpm", optional: false]}], "hexpm", "41193978704763f6bbe6cc2758b84909e62984c7752b3784bd3c218bb341706b"}, "makeup_erlang": {:hex, :makeup_erlang, "1.0.0", "6f0eff9c9c489f26b69b61440bf1b238d95badae49adac77973cbacae87e3c2e", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "ea7a9307de9d1548d2a72d299058d1fd2339e3d398560a0e46c27dab4891e4d2"}, - "mime": {:hex, :mime, "2.0.5", "dc34c8efd439abe6ae0343edbb8556f4d63f178594894720607772a041b04b02", [:mix], [], "hexpm", "da0d64a365c45bc9935cc5c8a7fc5e49a0e0f9932a761c55d6c52b142780a05c"}, + "mime": {:hex, :mime, "2.0.6", "8f18486773d9b15f95f4f4f1e39b710045fa1de891fada4516559967276e4dc2", [:mix], [], "hexpm", "c9945363a6b26d747389aac3643f8e0e09d30499a138ad64fe8fd1d13d9b153e"}, "mimic": {:hex, :mimic, "1.8.2", "f4cf6ad13a305c5ee1a6c304ee36fc7afb3ad748e2af8cd5fbf122f44a283534", [:mix], [], "hexpm", "abc982d5fdcc4cb5292980cb698cd47c0c5d9541b401e540fb695b69f1d46485"}, - "mint": {:hex, :mint, "1.6.1", "065e8a5bc9bbd46a41099dfea3e0656436c5cbcb6e741c80bd2bad5cd872446f", [:mix], [{:castore, "~> 0.1.0 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: true]}, {:hpax, "~> 0.1.1 or ~> 0.2.0", [hex: :hpax, repo: "hexpm", optional: false]}], "hexpm", "4fc518dcc191d02f433393a72a7ba3f6f94b101d094cb6bf532ea54c89423780"}, + "mint": {:hex, :mint, "1.6.2", "af6d97a4051eee4f05b5500671d47c3a67dac7386045d87a904126fd4bbcea2e", [:mix], [{:castore, "~> 0.1.0 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: true]}, {:hpax, "~> 0.1.1 or ~> 0.2.0 or ~> 1.0", [hex: :hpax, repo: "hexpm", optional: false]}], "hexpm", "5ee441dffc1892f1ae59127f74afe8fd82fda6587794278d924e4d90ea3d63f9"}, "nimble_options": {:hex, :nimble_options, "1.1.1", "e3a492d54d85fc3fd7c5baf411d9d2852922f66e69476317787a7b2bb000a61b", [:mix], [], "hexpm", "821b2470ca9442c4b6984882fe9bb0389371b8ddec4d45a9504f00a66f650b44"}, "nimble_ownership": {:hex, :nimble_ownership, "0.3.1", "99d5244672fafdfac89bfad3d3ab8f0d367603ce1dc4855f86a1c75008bce56f", 
[:mix], [], "hexpm", "4bf510adedff0449a1d6e200e43e57a814794c8b5b6439071274d248d272a549"}, "nimble_parsec": {:hex, :nimble_parsec, "1.4.0", "51f9b613ea62cfa97b25ccc2c1b4216e81df970acd8e16e8d1bdc58fef21370d", [:mix], [], "hexpm", "9c565862810fb383e9838c1dd2d7d2c437b3d13b267414ba6af33e50d2d1cf28"}, "nimble_pool": {:hex, :nimble_pool, "1.1.0", "bf9c29fbdcba3564a8b800d1eeb5a3c58f36e1e11d7b7fb2e084a643f645f06b", [:mix], [], "hexpm", "af2e4e6b34197db81f7aad230c1118eac993acc0dae6bc83bac0126d4ae0813a"}, - "nx": {:hex, :nx, "0.7.1", "5f6376e3d18408116e8a84b8f4ac851fb07dfe61764a5410ebf0b5dcb69c1b7e", [:mix], [{:complex, "~> 0.5", [hex: :complex, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "e3ddd6a3f2a9bac79c67b3933368c25bb5ec814a883fc68aba8fd8a236751777"}, - "req": {:hex, :req, "0.5.0", "6d8a77c25cfc03e06a439fb12ffb51beade53e3fe0e2c5e362899a18b50298b3", [:mix], [{:brotli, "~> 0.3.1", [hex: :brotli, repo: "hexpm", optional: true]}, {:ezstd, "~> 1.0", [hex: :ezstd, repo: "hexpm", optional: true]}, {:finch, "~> 0.17", [hex: :finch, repo: "hexpm", optional: false]}, {:jason, "~> 1.0", [hex: :jason, repo: "hexpm", optional: false]}, {:mime, "~> 1.6 or ~> 2.0", [hex: :mime, repo: "hexpm", optional: false]}, {:nimble_csv, "~> 1.0", [hex: :nimble_csv, repo: "hexpm", optional: true]}, {:plug, "~> 1.0", [hex: :plug, repo: "hexpm", optional: true]}], "hexpm", "dda04878c1396eebbfdec6db6f3d4ca609e5c8846b7ee88cc56eb9891406f7a3"}, + "nx": {:hex, :nx, "0.7.3", "51ff45d9f9ff58b616f4221fa54ccddda98f30319bb8caaf86695234a469017a", [:mix], [{:complex, "~> 0.5", [hex: :complex, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "5ff29af84f08db9bda66b8ef7ce92ab583ab4f983629fe00b479f1e5c7c705a6"}, + "req": {:hex, :req, "0.5.2", "70b4976e5fbefe84e5a57fd3eea49d4e9aa0ac015301275490eafeaec380f97f", [:mix], [{:brotli, "~> 0.3.1", [hex: 
:brotli, repo: "hexpm", optional: true]}, {:ezstd, "~> 1.0", [hex: :ezstd, repo: "hexpm", optional: true]}, {:finch, "~> 0.17", [hex: :finch, repo: "hexpm", optional: false]}, {:jason, "~> 1.0", [hex: :jason, repo: "hexpm", optional: false]}, {:mime, "~> 2.0.6 or ~> 2.1", [hex: :mime, repo: "hexpm", optional: false]}, {:nimble_csv, "~> 1.0", [hex: :nimble_csv, repo: "hexpm", optional: true]}, {:plug, "~> 1.0", [hex: :plug, repo: "hexpm", optional: true]}], "hexpm", "0c63539ab4c2d6ced6114d2684276cef18ac185ee00674ee9af4b1febba1f986"}, "telemetry": {:hex, :telemetry, "1.2.1", "68fdfe8d8f05a8428483a97d7aab2f268aaff24b49e0f599faa091f1d4e7f61c", [:rebar3], [], "hexpm", "dad9ce9d8effc621708f99eac538ef1cbe05d6a874dd741de2e689c47feafed5"}, } diff --git a/test/chains/data_extraction_chain_test.exs b/test/chains/data_extraction_chain_test.exs index 185c26f4..fc7e8613 100644 --- a/test/chains/data_extraction_chain_test.exs +++ b/test/chains/data_extraction_chain_test.exs @@ -15,8 +15,8 @@ defmodule LangChain.Chains.DataExtractionChainTest do FunctionParam.new!(%{name: "person_name", type: :string}), FunctionParam.new!(%{name: "person_age", type: :number}), FunctionParam.new!(%{name: "person_hair_color", type: :string}), - FunctionParam.new!(%{name: "dog_name", type: :string}), - FunctionParam.new!(%{name: "dog_breed", type: :string}) + FunctionParam.new!(%{name: "pet_dog_name", type: :string}), + FunctionParam.new!(%{name: "pet_dog_breed", type: :string}) ] |> FunctionParam.to_parameters_schema() @@ -31,8 +31,8 @@ defmodule LangChain.Chains.DataExtractionChainTest do items: %{ "type" => "object", "properties" => %{ - "dog_breed" => %{"type" => "string"}, - "dog_name" => %{"type" => "string"}, + "pet_dog_breed" => %{"type" => "string"}, + "pet_dog_name" => %{"type" => "string"}, "person_age" => %{"type" => "number"}, "person_hair_color" => %{"type" => "string"}, "person_name" => %{"type" => "string"} @@ -55,32 +55,34 @@ defmodule LangChain.Chains.DataExtractionChainTest do 
FunctionParam.new!(%{name: "person_name", type: :string}), FunctionParam.new!(%{name: "person_age", type: :number}), FunctionParam.new!(%{name: "person_hair_color", type: :string}), - FunctionParam.new!(%{name: "dog_name", type: :string}), - FunctionParam.new!(%{name: "dog_breed", type: :string}) + FunctionParam.new!(%{name: "pet_dog_name", type: :string}), + FunctionParam.new!(%{name: "pet_dog_breed", type: :string}) ] |> FunctionParam.to_parameters_schema() # Model setup - specify the model and seed - {:ok, chat} = ChatOpenAI.new(%{model: "gpt-4o", temperature: 0, seed: 0, stream: false}) + {:ok, chat} = ChatOpenAI.new(%{model: "gpt-4o-mini-2024-07-18", temperature: 0, seed: 0, stream: false}) # run the chain, chain.run(prompt to extract data from) - data_prompt = - "Alex is 5 feet tall. Claudia is 4 feet taller than Alex and jumps higher than him. - Claudia is a brunette and Alex is blonde. Alex's dog Frosty is a labrador and likes to play hide and seek. Identify each person and their relevant information." + data_prompt = """ + Alex is 5 feet tall. Claudia is 4 feet taller than Alex and jumps higher than him. + Claudia is a brunette and Alex is blonde. + Alex's dog Frosty is a labrador and likes to play hide and seek. Identify each person and their relevant information. 
+ """ {:ok, result} = DataExtractionChain.run(chat, schema_parameters, data_prompt, verbose: true) assert result == [ %{ - "dog_breed" => "labrador", - "dog_name" => "Frosty", + "pet_dog_breed" => "labrador", + "pet_dog_name" => "Frosty", "person_age" => nil, "person_hair_color" => "blonde", "person_name" => "Alex" }, %{ - "dog_breed" => nil, - "dog_name" => nil, + "pet_dog_breed" => nil, + "pet_dog_name" => nil, "person_age" => nil, "person_hair_color" => "brunette", "person_name" => "Claudia" diff --git a/test/chat_models/chat_anthropic_test.exs b/test/chat_models/chat_anthropic_test.exs index e2bcc217..ae732fe5 100644 --- a/test/chat_models/chat_anthropic_test.exs +++ b/test/chat_models/chat_anthropic_test.exs @@ -1,5 +1,7 @@ defmodule LangChain.ChatModels.ChatAnthropicTest do + alias LangChain.Utils.BedrockStreamDecoder use LangChain.BaseCase + use Mimic doctest LangChain.ChatModels.ChatAnthropic alias LangChain.ChatModels.ChatAnthropic @@ -11,14 +13,25 @@ defmodule LangChain.ChatModels.ChatAnthropicTest do alias LangChain.TokenUsage alias LangChain.Function alias LangChain.FunctionParam + alias LangChain.BedrockHelpers @test_model "claude-3-opus-20240229" + @bedrock_test_model "anthropic.claude-3-5-sonnet-20240620-v1:0" + @apis [:anthropic, :anthropic_bedrock] defp hello_world(_args, _context) do "Hello world!" end - setup do + defp api_config_for(:anthropic_bedrock) do + %{bedrock: BedrockHelpers.bedrock_config(), model: @bedrock_test_model} + end + + defp api_config_for(_), do: %{} + + setup context do + api_config = api_config_for(context[:live_api]) + {:ok, hello_world} = Function.new(%{ name: "hello_world", @@ -26,7 +39,7 @@ defmodule LangChain.ChatModels.ChatAnthropicTest do function: fn _args, _context -> "Hello world!" 
end }) - %{hello_world: hello_world} + %{hello_world: hello_world, api_config: api_config} end describe "new/1" do @@ -224,7 +237,46 @@ defmodule LangChain.ChatModels.ChatAnthropicTest do end end - describe "do_process_response/1" do + describe "do_process_response/2 with Bedrock" do + setup do + model = + ChatAnthropic.new!(%{stream: false} |> Map.merge(api_config_for(:anthropic_bedrock))) + + %{model: model} + end + + test "handles messages the same as anthropics API", %{model: model} do + response = %{ + "id" => "id-123", + "type" => "message", + "role" => "assistant", + "content" => [%{"type" => "text", "text" => "Greetings!"}], + "model" => "claude-3-haiku-20240307", + "stop_reason" => "end_turn" + } + + assert %Message{} = struct = ChatAnthropic.do_process_response(model, response) + assert struct.role == :assistant + assert struct.content == "Greetings!" + assert is_nil(struct.index) + end + + test "handles error messages", %{model: model} do + error = "Invalid API key" + + assert {:error, "Received error from API: #{error}"} == + ChatAnthropic.do_process_response(model, %{"message" => error}) + end + + test "handles stream error messages", %{model: model} do + error = "Internal error" + + assert {:error, "Stream exception received: #{inspect(error)}"} == + ChatAnthropic.do_process_response(model, %{bedrock_exception: error}) + end + end + + describe "do_process_response/2" do setup do model = ChatAnthropic.new!(%{stream: false}) %{model: model} @@ -434,78 +486,147 @@ defmodule LangChain.ChatModels.ChatAnthropicTest do assert reason == "Authentication failure with request" end - @tag live_call: true, live_anthropic: true - test "basic streamed content example and fires ratelimit callback and token usage" do - handlers = %{ - on_llm_ratelimit_info: fn _model, headers -> - send(self(), {:fired_ratelimit_info, headers}) - end, - on_llm_token_usage: fn _model, usage -> - send(self(), {:fired_token_usage, usage}) - end - } - - {:ok, chat} = 
ChatAnthropic.new(%{stream: true, callbacks: [handlers]}) + @tag live_call: true, live_anthropic_bedrock: true + test "Bedrock: handles when invalid credentials given" do + {:ok, chat} = + ChatAnthropic.new(%{ + stream: true, + bedrock: %{ + credentials: fn -> [access_key_id: "invalid", secret_access_key: "invalid"] end, + region: "us-east-1" + } + }) - {:ok, result} = + {:error, reason} = ChatAnthropic.call(chat, [ Message.new_user!("Return the response 'Colorful Threads'.") ]) - # returns a list of MessageDeltas. - assert result == [ - %LangChain.MessageDelta{ - content: "", - status: :incomplete, - index: nil, - role: :assistant - }, - %LangChain.MessageDelta{ - content: "Color", - status: :incomplete, - index: nil, - role: :assistant - }, - %LangChain.MessageDelta{ - content: "ful", - status: :incomplete, - index: nil, - role: :assistant - }, - %LangChain.MessageDelta{ - content: " Threads", - status: :incomplete, - index: nil, - role: :assistant - }, - %LangChain.MessageDelta{ - content: "", - status: :complete, - index: nil, - role: :assistant - } - ] + assert reason == + "Received error from API: The security token included in the request is invalid." + end + + for api <- @apis do + Module.put_attribute(__MODULE__, :tag, {:"live_#{api}", true}) + @tag live_call: true, live_api: api + test "#{BedrockHelpers.prefix_for(api)}basic streamed content example and fires ratelimit callback and token usage", + %{live_api: api, api_config: api_config} do + handlers = %{ + on_llm_ratelimit_info: fn _model, headers -> + send(self(), {:fired_ratelimit_info, headers}) + end, + on_llm_token_usage: fn _model, usage -> + send(self(), {:fired_token_usage, usage}) + end + } + + {:ok, chat} = + ChatAnthropic.new(%{stream: true, callbacks: [handlers]} |> Map.merge(api_config)) + + {:ok, result} = + ChatAnthropic.call(chat, [ + Message.new_user!("Return the response 'Keep up the good work!'.") + ]) + + # returns a list of MessageDeltas. 
+ assert result == [ + %LangChain.MessageDelta{ + content: "", + status: :incomplete, + index: nil, + role: :assistant + }, + %LangChain.MessageDelta{ + content: "Keep", + status: :incomplete, + index: nil, + role: :assistant + }, + %LangChain.MessageDelta{ + content: " up the good work", + status: :incomplete, + index: nil, + role: :assistant + }, + %LangChain.MessageDelta{ + content: "!", + status: :incomplete, + index: nil, + role: :assistant + }, + %LangChain.MessageDelta{ + content: "", + status: :complete, + index: nil, + role: :assistant + } + ] + + assert_received {:fired_ratelimit_info, info} + + if api != :anthropic_bedrock do + assert %{ + "anthropic-ratelimit-requests-limit" => _, + "anthropic-ratelimit-requests-remaining" => _, + "anthropic-ratelimit-requests-reset" => _, + "anthropic-ratelimit-tokens-limit" => _, + "anthropic-ratelimit-tokens-remaining" => _, + "anthropic-ratelimit-tokens-reset" => _, + # Not always included + # "retry-after" => _, + "request-id" => _ + } = info + end + + assert_received {:fired_token_usage, usage} + assert %TokenUsage{output: 9} = usage + end + end + end - assert_received {:fired_ratelimit_info, info} - - assert %{ - "anthropic-ratelimit-requests-limit" => _, - "anthropic-ratelimit-requests-remaining" => _, - "anthropic-ratelimit-requests-reset" => _, - "anthropic-ratelimit-tokens-limit" => _, - "anthropic-ratelimit-tokens-remaining" => _, - "anthropic-ratelimit-tokens-reset" => _, - # Not always included - # "retry-after" => _, - "request-id" => _ - } = info - - assert_received {:fired_token_usage, usage} - assert %TokenUsage{output: 8} = usage + describe "decode_stream/2 with Bedrock" do + setup do + {:ok, model} = + ChatAnthropic.new( + %{} + |> Map.merge(api_config_for(:anthropic_bedrock)) + ) + + %{model: model} + end + + test "filters irrelevant events", %{model: model} do + relevant_events = [ + %{"type" => "content_block_start"}, + %{"type" => "content_block_delta"}, + %{"type" => "message_delta"} + ] + + 
BedrockStreamDecoder + |> stub(:decode_stream, fn _, _ -> + {[ + %{"type" => "message_start"}, + %{"type" => "message_stop"} + ] ++ relevant_events, ""} + end) + + {chunks, remaining} = ChatAnthropic.decode_stream(model, {"", ""}) + + assert chunks == relevant_events + assert remaining == "" + end + + test "it passes through exception_message", %{model: model} do + BedrockStreamDecoder + |> stub(:decode_stream, fn _, _ -> {[%{bedrock_exception: "internalServerError"}], ""} end) + + {chunks, remaining} = ChatAnthropic.decode_stream(model, {"", ""}) + assert chunks == [%{bedrock_exception: "internalServerError"}] + assert remaining == "" end end - describe "decode_stream/1" do + describe "decode_stream/2" do test "when data is broken" do data1 = ~s|event: content_block_delta\ndata: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"hr"} }\n\n @@ -513,7 +634,7 @@ event: content_block_delta\ndata: {"type":"content_block_delta","index":0,"delta event: content_block_delta\ndata: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" back"} }\n\n event: content_block_delta\ndata: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" what"} }\n\nevent: content_block_delta\ndata: {"type":"content_block_delta","index":0| - {processed1, incomplete} = ChatAnthropic.decode_stream({data1, ""}) + {processed1, incomplete} = ChatAnthropic.decode_stream(%ChatAnthropic{}, {data1, ""}) assert incomplete == ~s|event: content_block_delta\ndata: {"type":"content_block_delta","index":0| @@ -546,7 +667,8 @@ event: content_block_delta\ndata: {"type":"content_block_delta","index":0,"delta event: content_block_delta\ndata: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" friend"} }\n\n event: content_block_delta\ndata: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" said"} }\n\n| - {processed2, incomplete} = ChatAnthropic.decode_stream({data2, incomplete}) + 
{processed2, incomplete} = + ChatAnthropic.decode_stream(%ChatAnthropic{}, {data2, incomplete}) assert incomplete == "" @@ -585,7 +707,7 @@ event: content_block_delta\ndata: {"type":"content_block_delta","index":0,"delta """ - {parsed, buffer} = ChatAnthropic.decode_stream({chunk, ""}) + {parsed, buffer} = ChatAnthropic.decode_stream(%ChatAnthropic{}, {chunk, ""}) assert [ %{ @@ -608,7 +730,7 @@ event: content_block_delta\ndata: {"type":"content_block_delta","index":0,"delta """ - {parsed, buffer} = ChatAnthropic.decode_stream({chunk, ""}) + {parsed, buffer} = ChatAnthropic.decode_stream(%ChatAnthropic{}, {chunk, ""}) assert [ %{ @@ -632,7 +754,7 @@ event: content_block_delta\ndata: {"type":"content_block_delta","index":0,"delta """ - {parsed, buffer} = ChatAnthropic.decode_stream({chunk, ""}) + {parsed, buffer} = ChatAnthropic.decode_stream(%ChatAnthropic{}, {chunk, ""}) assert [ %{ @@ -655,7 +777,7 @@ event: content_block_delta\ndata: {"type":"content_block_delta","index":0,"delta """ - {parsed, buffer} = ChatAnthropic.decode_stream({chunk, ""}) + {parsed, buffer} = ChatAnthropic.decode_stream(%ChatAnthropic{}, {chunk, ""}) assert [ %{ @@ -677,7 +799,7 @@ event: content_block_delta\ndata: {"type":"content_block_delta","index":0,"delta chunk_1 = "event: content_blo" - {parsed, buffer} = ChatAnthropic.decode_stream({chunk_1, ""}) + {parsed, buffer} = ChatAnthropic.decode_stream(%ChatAnthropic{}, {chunk_1, ""}) assert [] = parsed assert buffer == chunk_1 @@ -685,7 +807,7 @@ event: content_block_delta\ndata: {"type":"content_block_delta","index":0,"delta chunk_2 = "ck_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"de" - {parsed, buffer} = ChatAnthropic.decode_stream({chunk_2, buffer}) + {parsed, buffer} = ChatAnthropic.decode_stream(%ChatAnthropic{}, {chunk_2, buffer}) assert [] = parsed assert buffer == chunk_1 <> chunk_2 @@ -697,7 +819,7 @@ data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text | - {parsed, buffer} = 
ChatAnthropic.decode_stream({chunk_3, buffer}) + {parsed, buffer} = ChatAnthropic.decode_stream(%ChatAnthropic{}, {chunk_3, buffer}) assert [ %{ @@ -716,29 +838,6 @@ data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text end end - describe "split_system_message/1" do - test "returns system message and rest separately" do - system = Message.new_system!() - user_msg = Message.new_user!("Hi") - assert {system, [user_msg]} == ChatAnthropic.split_system_message([system, user_msg]) - end - - test "return nil when no system message set" do - user_msg = Message.new_user!("Hi") - assert {nil, [user_msg]} == ChatAnthropic.split_system_message([user_msg]) - end - - test "raises exception with multiple system messages" do - assert_raise LangChain.LangChainError, - "Anthropic only supports a single System message", - fn -> - system = Message.new_system!() - user_msg = Message.new_user!("Hi") - ChatAnthropic.split_system_message([system, user_msg, system]) - end - end - end - describe "for_api/1" do test "turns a basic user message into the expected JSON format" do expected = %{"role" => "user", "content" => "Hi."} @@ -1090,74 +1189,88 @@ data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text end describe "image vision using message parts" do - @tag live_call: true, live_anthropic: true - test "supports multi-modal user message with image prompt" do - image_data = load_image_base64("barn_owl.jpg") + for api <- @apis do + Module.put_attribute(__MODULE__, :tag, {:"live_#{api}", true}) + @tag live_call: true, live_api: api + test "#{BedrockHelpers.prefix_for(api)} supports multi-modal user message with image prompt", + %{api_config: api_config} do + image_data = load_image_base64("barn_owl.jpg") - # https://docs.anthropic.com/claude/reference/messages-examples#vision - {:ok, chat} = ChatAnthropic.new(%{model: @test_model}) + # https://docs.anthropic.com/claude/reference/messages-examples#vision + {:ok, chat} = 
ChatAnthropic.new(%{model: @test_model} |> Map.merge(api_config)) - message = - Message.new_user!([ - ContentPart.text!("Identify what this is a picture of:"), - ContentPart.image!(image_data, media: :jpg) - ]) + message = + Message.new_user!([ + ContentPart.text!("Identify what this is a picture of:"), + ContentPart.image!(image_data, media: :jpg) + ]) - {:ok, response} = ChatAnthropic.call(chat, [message], []) + {:ok, response} = ChatAnthropic.call(chat, [message], []) - assert %Message{role: :assistant} = response - assert String.contains?(response.content, "barn owl") + assert %Message{role: :assistant} = response + assert String.contains?(response.content |> String.downcase(), "barn owl") + end end end describe "a tool use" do - @tag live_call: true, live_anthropic: true - test "uses a tool with no parameters" do - # https://docs.anthropic.com/en/docs/tool-use - {:ok, chat} = ChatAnthropic.new(%{model: @test_model}) - - message = Message.new_user!("Use the 'do_something' tool.") - - tool = - Function.new!(%{ - name: "do_something", - parameters: [], - function: fn _args, _context -> :ok end - }) - - {:ok, response} = ChatAnthropic.call(chat, [message], [tool]) - - assert %Message{role: :assistant} = response - assert [%ToolCall{} = call] = response.tool_calls - assert call.status == :complete - assert call.type == :function - assert call.name == "do_something" - # detects empty and returns nil - assert call.arguments == nil - - # %LangChain.Message{ - # content: "\nThe user has requested to use the 'do_something' tool. Let's look at the parameters for this tool:\n{\"name\": \"do_something\", \"parameters\": {\"properties\": {}, \"type\": \"object\"}}\n\nThis tool does not require any parameters. 
Since there are no required parameters missing, we can proceed with invoking the 'do_something' tool.\n", - # index: nil, - # status: :complete, - # role: :assistant, - # name: nil, - # tool_calls: [ - # %LangChain.Message.ToolCall{ - # status: :complete, - # type: :function, - # call_id: "toolu_01Pch8mywrRttVZNK3zvntuF", - # name: "do_something", - # arguments: %{}, - # index: nil - # } - # ], - # } + for api <- @apis do + Module.put_attribute(__MODULE__, :tag, {:"live_#{api}", true}) + @tag live_call: true, live_api: api + test "#{BedrockHelpers.prefix_for(api)} uses a tool with no parameters", %{ + api_config: api_config + } do + # https://docs.anthropic.com/en/docs/tool-use + {:ok, chat} = ChatAnthropic.new(%{model: @test_model} |> Map.merge(api_config)) + + message = Message.new_user!("Use the 'do_something' tool.") + + tool = + Function.new!(%{ + name: "do_something", + parameters: [], + function: fn _args, _context -> :ok end + }) + + {:ok, response} = ChatAnthropic.call(chat, [message], [tool]) + + assert %Message{role: :assistant} = response + assert [%ToolCall{} = call] = response.tool_calls + assert call.status == :complete + assert call.type == :function + assert call.name == "do_something" + # detects empty and returns nil + assert call.arguments == nil + + # %LangChain.Message{ + # content: "\nThe user has requested to use the 'do_something' tool. Let's look at the parameters for this tool:\n{\"name\": \"do_something\", \"parameters\": {\"properties\": {}, \"type\": \"object\"}}\n\nThis tool does not require any parameters. 
Since there are no required parameters missing, we can proceed with invoking the 'do_something' tool.\n", + # index: nil, + # status: :complete, + # role: :assistant, + # name: nil, + # tool_calls: [ + # %LangChain.Message.ToolCall{ + # status: :complete, + # type: :function, + # call_id: "toolu_01Pch8mywrRttVZNK3zvntuF", + # name: "do_something", + # arguments: %{}, + # index: nil + # } + # ], + # } + end end + end - @tag live_call: true, live_anthropic: true - test "uses a tool with parameters" do + for api <- @apis do + Module.put_attribute(__MODULE__, :tag, {:"live_#{api}", true}) + @tag live_call: true, live_api: api + test "#{BedrockHelpers.prefix_for(api)} uses a tool with parameters", %{ + api_config: api_config + } do # https://docs.anthropic.com/claude/reference/messages-examples#vision - {:ok, chat} = ChatAnthropic.new(%{model: @test_model}) + {:ok, chat} = ChatAnthropic.new(%{model: @test_model} |> Map.merge(api_config)) message = Message.new_user!("Use the 'do_something' tool with the value 'cat'.") @@ -1197,7 +1310,7 @@ data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text end @tag live_call: true, live_anthropic: true - test "streams a tool call with parameters" do + test "#{BedrockHelpers.prefix_for(api)} streams a tool call with parameters" do handler = %{ on_llm_new_delta: fn _model, delta -> # IO.inspect(delta, label: "DELTA") @@ -1233,105 +1346,126 @@ data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text [tool_result] = updated_chain.last_message.tool_results assert tool_result.content == "SUCCESS" end - end - describe "works within a chain" do - @tag live_call: true, live_anthropic: true - test "works with a streaming response" do - test_pid = self() + describe "#{BedrockHelpers.prefix_for(api)} works within a chain" do + Module.put_attribute(__MODULE__, :tag, {:"live_#{api}", true}) + @tag live_call: true, live_api: api + test "works with a streaming response", %{api_config: api_config} do 
+ test_pid = self() - handler = %{ - on_llm_new_delta: fn _model, delta -> - send(test_pid, {:streamed_fn, delta}) - end - } + handler = %{ + on_llm_new_delta: fn _model, delta -> + send(test_pid, {:streamed_fn, delta}) + end + } - {:ok, updated_chain} = - LLMChain.new!(%{llm: %ChatAnthropic{stream: true, callbacks: [handler]}}) - |> LLMChain.add_message(Message.new_user!("Say, 'Hi!'!")) - |> LLMChain.run() + {:ok, chat} = + ChatAnthropic.new( + %{stream: true, callbacks: [handler]} + |> Map.merge(api_config) + ) - assert updated_chain.last_message.content == "Hi!" - assert updated_chain.last_message.status == :complete - assert updated_chain.last_message.role == :assistant + {:ok, updated_chain} = + LLMChain.new!(%{llm: chat}) + |> LLMChain.add_message(Message.new_user!("Say, 'Hi!'!")) + |> LLMChain.run() - assert_received {:streamed_fn, data} - assert %MessageDelta{role: :assistant} = data - end + assert updated_chain.last_message.content == "Hi!" + assert updated_chain.last_message.status == :complete + assert updated_chain.last_message.role == :assistant - @tag live_call: true, live_anthropic: true - test "works with NON streaming response and fires ratelimit callback and token usage" do - test_pid = self() + assert_received {:streamed_fn, data} + assert %MessageDelta{role: :assistant} = data + end - handler = %{ - on_llm_new_message: fn _model, message -> - send(test_pid, {:received_msg, message}) - end, - on_llm_ratelimit_info: fn _model, headers -> - send(test_pid, {:fired_ratelimit_info, headers}) - end, - on_llm_token_usage: fn _model, usage -> - send(self(), {:fired_token_usage, usage}) + Module.put_attribute(__MODULE__, :tag, {:"live_#{api}", true}) + @tag live_call: true, live_api: api + test "works with NON streaming response and fires ratelimit callback and token usage", %{ + api_config: api_config, + live_api: api + } do + test_pid = self() + + handler = %{ + on_llm_new_message: fn _model, message -> + send(test_pid, {:received_msg, message}) + end, 
+ on_llm_ratelimit_info: fn _model, headers -> + send(test_pid, {:fired_ratelimit_info, headers}) + end, + on_llm_token_usage: fn _model, usage -> + send(self(), {:fired_token_usage, usage}) + end + } + + {:ok, updated_chain} = + LLMChain.new!(%{ + llm: + ChatAnthropic.new!(%{stream: false, callbacks: [handler]} |> Map.merge(api_config)) + }) + |> LLMChain.add_message(Message.new_user!("Say, 'Hi!'!")) + |> LLMChain.run() + + assert updated_chain.last_message.content == "Hi!" + assert updated_chain.last_message.status == :complete + assert updated_chain.last_message.role == :assistant + + assert_received {:received_msg, data} + assert %Message{role: :assistant} = data + + assert_received {:fired_ratelimit_info, info} + + if api != :anthropic_bedrock do + assert %{ + "anthropic-ratelimit-requests-limit" => _, + "anthropic-ratelimit-requests-remaining" => _, + "anthropic-ratelimit-requests-reset" => _, + "anthropic-ratelimit-tokens-limit" => _, + "anthropic-ratelimit-tokens-remaining" => _, + "anthropic-ratelimit-tokens-reset" => _, + # Not always included + # "retry-after" => _, + "request-id" => _ + } = info end - } - {:ok, updated_chain} = - LLMChain.new!(%{llm: %ChatAnthropic{stream: false, callbacks: [handler]}}) - |> LLMChain.add_message(Message.new_user!("Say, 'Hi!'!")) - |> LLMChain.run() - - assert updated_chain.last_message.content == "Hi!" 
- assert updated_chain.last_message.status == :complete - assert updated_chain.last_message.role == :assistant - - assert_received {:received_msg, data} - assert %Message{role: :assistant} = data - - assert_received {:fired_ratelimit_info, info} - - assert %{ - "anthropic-ratelimit-requests-limit" => _, - "anthropic-ratelimit-requests-remaining" => _, - "anthropic-ratelimit-requests-reset" => _, - "anthropic-ratelimit-tokens-limit" => _, - "anthropic-ratelimit-tokens-remaining" => _, - "anthropic-ratelimit-tokens-reset" => _, - # Not always included - # "retry-after" => _, - "request-id" => _ - } = info - - assert_received {:fired_token_usage, usage} - assert %TokenUsage{input: 14} = usage - end + assert_received {:fired_token_usage, usage} + assert %TokenUsage{input: 14} = usage + end - @tag live_call: true, live_anthropic: true - test "supports continuing a conversation with streaming" do - test_pid = self() + Module.put_attribute(__MODULE__, :tag, {:"live_#{api}", true}) + @tag live_call: true, live_api: api + test "supports continuing a conversation with streaming", %{api_config: api_config} do + test_pid = self() - handler = %{ - on_llm_new_delta: fn _model, delta -> - # IO.inspect(data, label: "DATA") - send(test_pid, {:streamed_fn, delta}) - end - } + handler = %{ + on_llm_new_delta: fn _model, delta -> + # IO.inspect(data, label: "DATA") + send(test_pid, {:streamed_fn, delta}) + end + } - {:ok, updated_chain} = - LLMChain.new!(%{ - llm: %ChatAnthropic{model: @test_model, stream: true, callbacks: [handler]} - }) - |> LLMChain.add_message(Message.new_system!("You are a helpful and concise assistant.")) - |> LLMChain.add_message(Message.new_user!("Say, 'Hi!'!")) - |> LLMChain.add_message(Message.new_assistant!("Hi!")) - |> LLMChain.add_message(Message.new_user!("What's the capitol of Norway?")) - |> LLMChain.run() - - assert updated_chain.last_message.content =~ "Oslo" - assert updated_chain.last_message.status == :complete - assert 
updated_chain.last_message.role == :assistant - - assert_received {:streamed_fn, data} - assert %MessageDelta{role: :assistant} = data + {:ok, updated_chain} = + LLMChain.new!(%{ + llm: + ChatAnthropic.new!( + %{model: @test_model, stream: true, callbacks: [handler]} + |> Map.merge(api_config) + ) + }) + |> LLMChain.add_message(Message.new_system!("You are a helpful and concise assistant.")) + |> LLMChain.add_message(Message.new_user!("Say, 'Hi!'!")) + |> LLMChain.add_message(Message.new_assistant!("Hi!")) + |> LLMChain.add_message(Message.new_user!("What's the capitol of Norway?")) + |> LLMChain.run() + + assert updated_chain.last_message.content =~ "Oslo" + assert updated_chain.last_message.status == :complete + assert updated_chain.last_message.role == :assistant + + assert_received {:streamed_fn, data} + assert %MessageDelta{role: :assistant} = data + end end # @tag live_call: true, live_anthropic: true diff --git a/test/chat_models/chat_google_ai_test.exs b/test/chat_models/chat_google_ai_test.exs index d5c7250d..e372f68e 100644 --- a/test/chat_models/chat_google_ai_test.exs +++ b/test/chat_models/chat_google_ai_test.exs @@ -1,5 +1,4 @@ defmodule ChatModels.ChatGoogleAITest do - alias LangChain.ChatModels.ChatGoogleAI use LangChain.BaseCase doctest LangChain.ChatModels.ChatGoogleAI @@ -11,6 +10,9 @@ defmodule ChatModels.ChatGoogleAITest do alias LangChain.TokenUsage alias LangChain.MessageDelta alias LangChain.Function + alias LangChain.FunctionParam + alias LangChain.LangChainError + alias LangChain.ChatModels.ChatGoogleAI setup do {:ok, hello_world} = @@ -239,14 +241,26 @@ defmodule ChatModels.ChatGoogleAITest do assert expected == ChatGoogleAI.for_api(tool_result) end - test "expands system messages into two", %{google_ai: google_ai} do - message = "These are some instructions." - + test "adds system instruction to the request if present", %{google_ai: google_ai} do + message = "You are a helpful assistant." 
data = ChatGoogleAI.for_api(google_ai, [Message.new_system!(message)], []) - assert %{"contents" => [msg1, msg2]} = data - assert %{"role" => :user, "parts" => [%{"text" => ^message}]} = msg1 - assert %{"role" => :model, "parts" => [%{"text" => ""}]} = msg2 + assert %{"system_instruction" => %{"parts" => [%{"text" => ^message}]}} = data + end + + test "does not add system instruction if not present", %{google_ai: google_ai} do + data = ChatGoogleAI.for_api(google_ai, [Message.new_user!("Hello!")], []) + refute Map.has_key?(data, "system_instruction") + end + + test "raises an error if more than one system message is present", %{google_ai: google_ai} do + assert_raise LangChainError, "Google AI only supports a single System message", fn -> + ChatGoogleAI.for_api( + google_ai, + [Message.new_system!("First instruction."), Message.new_system!("Second instruction.")], + [] + ) + end end test "generates a map containing function declarations", %{ @@ -262,12 +276,68 @@ defmodule ChatModels.ChatGoogleAITest do "functionDeclarations" => [ %{ "name" => "hello_world", - "description" => "Give a hello world greeting.", - "parameters" => %{"properties" => %{}, "type" => "object"} + "description" => "Give a hello world greeting." } ] } = tool_call end + + test "handles converting functions with parameters" do + {:ok, weather} = + Function.new(%{ + name: "get_weather", + description: "Get the current weather in a given US location", + parameters: [ + FunctionParam.new!(%{ + name: "city", + type: "string", + description: "The city name, e.g. San Francisco", + required: true + }), + FunctionParam.new!(%{ + name: "state", + type: "string", + description: "The 2 letter US state abbreviation, e.g. 
CA, NY, UT", + required: true + }) + ], + function: fn _args, _context -> {:ok, "75 degrees"} end + }) + + assert %{ + "description" => "Get the current weather in a given US location", + "name" => "get_weather", + "parameters" => %{ + "properties" => %{ + "city" => %{ + "description" => "The city name, e.g. San Francisco", + "type" => "string" + }, + "state" => %{ + "description" => "The 2 letter US state abbreviation, e.g. CA, NY, UT", + "type" => "string" + } + }, + "required" => ["city", "state"], + "type" => "object" + } + } == ChatGoogleAI.for_api(weather) + end + + test "handles functions without parameters" do + {:ok, function} = + Function.new(%{ + name: "hello_world", + description: "Give a hello world greeting.", + parameters: [], + function: fn _args, _context -> {:ok, "Hello User!"} end + }) + + assert %{ + "description" => "Give a hello world greeting.", + "name" => "hello_world" + } == ChatGoogleAI.for_api(function) + end end describe "do_process_response/2" do @@ -289,6 +359,21 @@ defmodule ChatModels.ChatGoogleAITest do assert struct.status == :complete end + test "handles receiving a message with an empty text part", %{model: model} do + response = %{ + "candidates" => [ + %{ + "content" => %{"role" => "model", "parts" => [%{"text" => ""}]}, + "finishReason" => "STOP", + "index" => 0 + } + ] + } + + assert [%Message{} = struct] = ChatGoogleAI.do_process_response(model, response) + assert struct.content == [] + end + test "error if receiving non-text content", %{model: model} do response = %{ "candidates" => [ @@ -351,6 +436,26 @@ defmodule ChatModels.ChatGoogleAITest do assert struct.status == :incomplete end + test "handles receiving a MessageDelta with an empty text part", %{model: model} do + response = %{ + "candidates" => [ + %{ + "content" => %{ + "role" => "model", + "parts" => [%{"text" => ""}] + }, + "finishReason" => "STOP", + "index" => 0 + } + ] + } + + assert [%MessageDelta{} = struct] = + ChatGoogleAI.do_process_response(model, 
response, MessageDelta) + + assert struct.content == "" + end + test "handles API error messages", %{model: model} do response = %{ "error" => %{ @@ -410,6 +515,22 @@ defmodule ChatModels.ChatGoogleAITest do end end + describe "filter_text_parts/1" do + test "returns only text parts that are not nil or empty" do + parts = [ + %{"text" => "I have text"}, + %{"text" => nil}, + %{"text" => ""}, + %{"text" => "I have more text"} + ] + + assert ChatGoogleAI.filter_text_parts(parts) == [ + %{"text" => "I have text"}, + %{"text" => "I have more text"} + ] + end + end + describe "get_message_contents/1" do test "returns basic text as a ContentPart" do message = Message.new_user!("Howdy!") diff --git a/test/chat_models/chat_open_ai_test.exs b/test/chat_models/chat_open_ai_test.exs index 8ea8e067..ffde2ebf 100644 --- a/test/chat_models/chat_open_ai_test.exs +++ b/test/chat_models/chat_open_ai_test.exs @@ -12,7 +12,7 @@ defmodule LangChain.ChatModels.ChatOpenAITest do alias LangChain.Message.ToolCall alias LangChain.Message.ToolResult - @test_model "gpt-3.5-turbo" + @test_model "gpt-4o-mini-2024-07-18" @gpt4 "gpt-4-1106-preview" defp hello_world(_args, _context) do @@ -73,6 +73,25 @@ defmodule LangChain.ChatModels.ChatOpenAITest do assert model.endpoint == override_url end + + test "supports setting json_response and json_schema" do + json_schema = %{ + "type" => "object", + "properties" => %{ + "name" => %{"type" => "string"}, + "age" => %{"type" => "integer"} + } + } + + {:ok, openai} = ChatOpenAI.new(%{ + "model" => @test_model, + "json_response" => true, + "json_schema" => json_schema + }) + + assert openai.json_response == true + assert openai.json_schema == json_schema + end end describe "for_api/3" do @@ -108,6 +127,34 @@ defmodule LangChain.ChatModels.ChatOpenAITest do assert data.response_format == %{"type" => "json_object"} end + test "generates a map for an API call with JSON response and schema" do + json_schema = %{ + "type" => "object", + "properties" => %{ + 
"name" => %{"type" => "string"}, + "age" => %{"type" => "integer"} + } + } + + {:ok, openai} = + ChatOpenAI.new(%{ + "model" => @test_model, + "temperature" => 1, + "frequency_penalty" => 0.5, + "json_response" => true, + "json_schema" => json_schema + }) + + data = ChatOpenAI.for_api(openai, [], []) + assert data.model == @test_model + assert data.temperature == 1 + assert data.frequency_penalty == 0.5 + assert data.response_format == %{ + "type" => "json_schema", + "json_schema" => json_schema + } + end + test "generates a map for an API call with max_tokens set" do {:ok, openai} = ChatOpenAI.new(%{ @@ -419,7 +466,7 @@ defmodule LangChain.ChatModels.ChatOpenAITest do "description" => nil, "enum" => ["yellow", "red", "green"], "type" => "string" - } + } }, "required" => ["p1"] } @@ -789,7 +836,7 @@ defmodule LangChain.ChatModels.ChatOpenAITest do @tag live_call: true, live_open_ai: true test "handles when request is too large" do {:ok, chat} = - ChatOpenAI.new(%{model: "gpt-3.5-turbo-0301", seed: 0, stream: false, temperature: 1}) + ChatOpenAI.new(%{model: "gpt-4-0613", seed: 0, stream: false, temperature: 1}) {:error, reason} = ChatOpenAI.call(chat, [too_large_user_request()]) assert reason =~ "maximum context length" @@ -1233,6 +1280,31 @@ defmodule LangChain.ChatModels.ChatOpenAITest do assert parsed == [json_1, json_2] end + test "correctly parses when data content contains spaces such as python code with indentation" do + data = + "data: {\"id\":\"chatcmpl-7e8yp1xBhriNXiqqZ0xJkgNrmMuGS\",\"object\":\"chat.completion.chunk\",\"created\":1689801995,\"model\":\"gpt-4-0613\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"def my_function(x):\\n return x + 1\"},\"finish_reason\":null}]}\n\n" + + {parsed, incomplete} = ChatOpenAI.decode_stream({data, ""}) + + assert incomplete == "" + + assert parsed == [ + %{ + "id" => "chatcmpl-7e8yp1xBhriNXiqqZ0xJkgNrmMuGS", + "object" => "chat.completion.chunk", + "created" => 1_689_801_995, + "model" => "gpt-4-0613", + 
"choices" => [ + %{ + "index" => 0, + "delta" => %{"content" => "def my_function(x):\n return x + 1"}, + "finish_reason" => nil + } + ] + } + ] + end + test "correctly parses when data split over received messages", %{json_1: json_1} do # split the data over multiple messages data = @@ -1305,7 +1377,7 @@ defmodule LangChain.ChatModels.ChatOpenAITest do @tag live_call: true, live_open_ai: true test "supports multi-modal user message with image prompt" do # https://platform.openai.com/docs/guides/vision - {:ok, chat} = ChatOpenAI.new(%{model: "gpt-4-vision-preview", seed: 0}) + {:ok, chat} = ChatOpenAI.new(%{model: "gpt-4o-2024-08-06", seed: 0}) url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" @@ -1866,8 +1938,53 @@ defmodule LangChain.ChatModels.ChatOpenAITest do "stream_options" => %{"include_usage" => true}, "temperature" => 0.0, "version" => 1, + "json_schema" => nil, "module" => "Elixir.LangChain.ChatModels.ChatOpenAI" } end end + + describe "set_response_format/1" do + test "generates a map for an API call with text format when json_response is false" do + {:ok, openai} = ChatOpenAI.new(%{ + model: @test_model, + json_response: false + }) + data = ChatOpenAI.for_api(openai, [], []) + + assert data.response_format == %{"type" => "text"} + end + + test "generates a map for an API call with json_object format when json_response is true and no schema" do + {:ok, openai} = ChatOpenAI.new(%{ + model: @test_model, + json_response: true + }) + data = ChatOpenAI.for_api(openai, [], []) + + assert data.response_format == %{"type" => "json_object"} + end + + test "generates a map for an API call with json_schema format when json_response is true and schema is provided" do + json_schema = %{ + "type" => "object", + "properties" => %{ + "name" => %{"type" => "string"}, + "age" => %{"type" => "integer"} + } + } + + {:ok, openai} = ChatOpenAI.new(%{ + model: 
@test_model, + json_response: true, + json_schema: json_schema + }) + data = ChatOpenAI.for_api(openai, [], []) + + assert data.response_format == %{ + "type" => "json_schema", + "json_schema" => json_schema + } + end + end end diff --git a/test/message/tool_call_test.exs b/test/message/tool_call_test.exs index 31bc9e25..1770b7d1 100644 --- a/test/message/tool_call_test.exs +++ b/test/message/tool_call_test.exs @@ -52,6 +52,49 @@ defmodule LangChain.Message.ToolCallTest do assert msg.call_id == "call_asdf" assert msg.index == 0 end + + test "can preserve spaces in arguments" do + assert {:ok, %ToolCall{} = msg} = + ToolCall.new(%{ + "status" => :incomplete, + "type" => "function", + "index" => 0, + "call_id" => "call_asdf", + "name" => "hello_world", + "arguments" => "{\"code\": \"def my_function(x):\n return x + 1\"}" + }) + + assert msg.arguments == "{\"code\": \"def my_function(x):\n return x + 1\"}" + end + + test "casts spaces in arguments as spaces" do + one_space = " " + assert {:ok, %ToolCall{} = msg} = + ToolCall.new(%{ + "status" => :incomplete, + "type" => "function", + "index" => 0, + "call_id" => "call_asdf", + "name" => "hello_world", + "arguments" => one_space + }) + + assert msg.arguments == one_space + + # Multiple spaces + four_spaces = " " + assert {:ok, %ToolCall{} = msg} = + ToolCall.new(%{ + "status" => :incomplete, + "type" => "function", + "index" => 0, + "call_id" => "call_asdf", + "name" => "hello_world", + "arguments" => four_spaces + }) + + assert msg.arguments == four_spaces + end end describe "complete/1" do @@ -121,6 +164,20 @@ defmodule LangChain.Message.ToolCallTest do assert result == received end + test "does not duplicate incomplete call" do + received = %ToolCall{ + status: :incomplete, + type: :function, + call_id: nil, + name: "get_weather", + arguments: nil, + index: 0 + } + + result = ToolCall.merge(received, received) + assert result == received + end + test "updates tool name" do call_1 = %ToolCall{ status: :incomplete, @@ 
-252,5 +309,41 @@ defmodule LangChain.Message.ToolCallTest do ToolCall.merge(call_1, call_2) end end + + test "preserves empty spaces when merging arguments" do + call_1 = %ToolCall{ + status: :incomplete, + type: :function, + name: "get_code", + # 2 spaces + arguments: "{\"code\": \"def my_function(x):\n ", + index: 0 + } + + call_2 = %ToolCall{ + status: :incomplete, + type: :function, + name: "get_code", + # one space + arguments: " ", + index: 0 + } + + call_3 = %ToolCall{ + status: :incomplete, + type: :function, + name: "get_code", + # one space + arguments: " return x + 1\"}", + index: 0 + } + + result = + call_1 + |> ToolCall.merge(call_2) + |> ToolCall.merge(call_3) + + assert result.arguments == "{\"code\": \"def my_function(x):\n return x + 1\"}" + end end end diff --git a/test/message_delta_test.exs b/test/message_delta_test.exs index 1c0b6511..591908f8 100644 --- a/test/message_delta_test.exs +++ b/test/message_delta_test.exs @@ -175,6 +175,136 @@ defmodule LangChain.MessageDeltaTest do assert merged == expected end + test "correctly merge message with tool_call containing empty spaces" do + first_delta = + %LangChain.MessageDelta{ + content: "", + status: :incomplete, + index: nil, + role: :assistant, + tool_calls: nil + } + + deltas = [ + %LangChain.MessageDelta{ + content: "stu", + status: :incomplete, + index: nil, + role: :assistant, + tool_calls: nil + }, + %LangChain.MessageDelta{ + content: "ff", + status: :incomplete, + index: nil, + role: :assistant, + tool_calls: nil + }, + %LangChain.MessageDelta{ + content: nil, + status: :incomplete, + index: nil, + role: :assistant, + tool_calls: [ + %LangChain.Message.ToolCall{ + status: :incomplete, + type: :function, + call_id: "toolu_123", + name: "get_code", + arguments: nil, + index: 1 + } + ] + }, + %LangChain.MessageDelta{ + content: nil, + status: :incomplete, + index: nil, + role: :assistant, + tool_calls: [ + %LangChain.Message.ToolCall{ + status: :incomplete, + type: :function, + call_id: 
"toolu_123", + name: "get_code", + arguments: "{\"code\": \"def my_function(x):\n ", + index: 1 + } + ] + }, + %LangChain.MessageDelta{ + content: nil, + status: :incomplete, + index: nil, + role: :assistant, + tool_calls: [ + %LangChain.Message.ToolCall{ + status: :incomplete, + type: :function, + call_id: "toolu_123", + name: "get_code", + arguments: " ", + index: 1 + } + ] + }, + %LangChain.MessageDelta{ + content: nil, + status: :incomplete, + index: nil, + role: :assistant, + tool_calls: [ + %LangChain.Message.ToolCall{ + status: :incomplete, + type: :function, + call_id: "toolu_123", + name: "get_code", + arguments: " ", + index: 1 + } + ] + }, + %LangChain.MessageDelta{ + content: nil, + status: :incomplete, + index: nil, + role: :assistant, + tool_calls: [ + %LangChain.Message.ToolCall{ + status: :incomplete, + type: :function, + call_id: "toolu_123", + name: "get_code", + arguments: "return x + 1\"}", + index: 1 + } + ] + } + ] + + merged = + Enum.reduce(deltas, first_delta, fn new_delta, acc -> + MessageDelta.merge_delta(acc, new_delta) + end) + + assert merged == %LangChain.MessageDelta{ + content: "stuff", + status: :incomplete, + index: nil, + role: :assistant, + tool_calls: [ + %LangChain.Message.ToolCall{ + status: :incomplete, + type: :function, + call_id: "toolu_123", + name: "get_code", + arguments: "{\"code\": \"def my_function(x):\n return x + 1\"}", + index: 1 + } + ] + } + end + test "correctly merges message with tool_call split over multiple deltas and index is not by position" do first_delta = %LangChain.MessageDelta{ diff --git a/test/support/bedrock_helpers.ex b/test/support/bedrock_helpers.ex new file mode 100644 index 00000000..540928ae --- /dev/null +++ b/test/support/bedrock_helpers.ex @@ -0,0 +1,17 @@ +defmodule LangChain.BedrockHelpers do + def bedrock_config() do + %{ + credentials: fn -> + [ + access_key_id: Application.fetch_env!(:langchain, :aws_access_key_id), + secret_access_key: Application.fetch_env!(:langchain, 
:aws_secret_access_key) + ] + end, + region: "us-east-1" + } + end + + def prefix_for(api) do + "#{api} API:" + end +end diff --git a/test/support/fixtures.ex b/test/support/fixtures.ex index b35353c3..ff234c4a 100644 --- a/test/support/fixtures.ex +++ b/test/support/fixtures.ex @@ -409,7 +409,7 @@ defmodule LangChain.Fixtures do end def too_large_user_request() do - Message.new_user!("Analyze the following text: \n\n" <> text_chunks(8)) + Message.new_user!("Analyze the following text: \n\n" <> text_chunks(16)) end def results_in_too_long_response() do diff --git a/test/test_helper.exs b/test/test_helper.exs index 02f2e53a..12ad4d25 100644 --- a/test/test_helper.exs +++ b/test/test_helper.exs @@ -2,6 +2,16 @@ Application.put_env(:langchain, :openai_key, System.fetch_env!("OPENAI_API_KEY")) Application.put_env(:langchain, :anthropic_key, System.fetch_env!("ANTHROPIC_API_KEY")) Application.put_env(:langchain, :google_ai_key, System.fetch_env!("GOOGLE_API_KEY")) +Application.put_env(:langchain, :aws_access_key_id, System.fetch_env!("AWS_ACCESS_KEY_ID")) + +Application.put_env( + :langchain, + :aws_secret_access_key, + System.fetch_env!("AWS_SECRET_ACCESS_KEY") +) + +Mimic.copy(LangChain.Utils.BedrockStreamDecoder) +Mimic.copy(LangChain.Utils.AwsEventstreamDecoder) Mimic.copy(LangChain.ChatModels.ChatOpenAI) Mimic.copy(LangChain.ChatModels.ChatAnthropic) diff --git a/test/utils/aws_eventstream_decoder_test.exs b/test/utils/aws_eventstream_decoder_test.exs new file mode 100644 index 00000000..232cd65f --- /dev/null +++ b/test/utils/aws_eventstream_decoder_test.exs @@ -0,0 +1,86 @@ +defmodule LangChain.Utils.AwsEventstreamDecoderTest do + use ExUnit.Case + + # https://github.com/aws/aws-sdk-ruby/tree/version-3/gems/aws-eventstream/spec/fixtures/encoded + @corrupted_length <<0, 0, 0, 62, 0, 0, 0, 32, 7, 253, 131, 150, 12, 99, 111, 110, 116, 101, 110, + 116, 45, 116, 121, 112, 101, 7, 0, 16, 97, 112, 112, 108, 105, 99, 97, 116, + 105, 111, 110, 47, 106, 115, 111, 110, 123, 
39, 102, 111, 111, 39, 58, 39, + 98, 97, 114, 39, 125, 141, 156, 8, 177>> + @corrupted_payload <<0, 0, 0, 29, 0, 0, 0, 0, 253, 82, 140, 90, 91, 39, 102, 111, 111, 39, 58, + 39, 98, 97, 114, 39, 125, 195, 101, 57, 54>> + @corrupted_headers <<0, 0, 0, 61, 0, 0, 0, 32, 7, 253, 131, 150, 12, 99, 111, 110, 116, 101, + 110, 116, 45, 116, 121, 112, 101, 7, 0, 16, 97, 112, 112, 108, 105, 99, 97, + 116, 105, 111, 110, 47, 106, 115, 111, 110, 123, 97, 102, 111, 111, 39, 58, + 39, 98, 97, 114, 39, 125, 141, 156, 8, 177>> + @corrupted_header_len <<0, 0, 0, 61, 0, 0, 0, 33, 7, 253, 131, 150, 12, 99, 111, 110, 116, 101, + 110, 116, 45, 116, 121, 112, 101, 7, 0, 16, 97, 112, 112, 108, 105, 99, + 97, 116, 105, 111, 110, 47, 106, 115, 111, 110, 123, 39, 102, 111, 111, + 39, 58, 39, 98, 97, 114, 39, 125, 141, 156, 8, 177>> + @all_headers <<0, 0, 0, 204, 0, 0, 0, 175, 15, 174, 100, 202, 10, 101, 118, 101, 110, 116, 45, + 116, 121, 112, 101, 4, 0, 0, 160, 12, 12, 99, 111, 110, 116, 101, 110, 116, 45, + 116, 121, 112, 101, 7, 0, 16, 97, 112, 112, 108, 105, 99, 97, 116, 105, 111, 110, + 47, 106, 115, 111, 110, 10, 98, 111, 111, 108, 32, 102, 97, 108, 115, 101, 1, 9, + 98, 111, 111, 108, 32, 116, 114, 117, 101, 0, 4, 98, 121, 116, 101, 2, 207, 8, + 98, 121, 116, 101, 32, 98, 117, 102, 6, 0, 20, 73, 39, 109, 32, 97, 32, 108, 105, + 116, 116, 108, 101, 32, 116, 101, 97, 112, 111, 116, 33, 9, 116, 105, 109, 101, + 115, 116, 97, 109, 112, 8, 0, 0, 0, 0, 0, 132, 95, 237, 5, 105, 110, 116, 49, 54, + 3, 0, 42, 5, 105, 110, 116, 54, 52, 5, 0, 0, 0, 0, 2, 135, 87, 178, 4, 117, 117, + 105, 100, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 123, 39, 102, + 111, 111, 39, 58, 39, 98, 97, 114, 39, 125, 171, 165, 241, 12>> + @empty_message <<0, 0, 0, 16, 0, 0, 0, 0, 5, 194, 72, 235, 125, 152, 200, 255>> + @int32_header <<0, 0, 0, 45, 0, 0, 0, 16, 65, 196, 36, 184, 10, 101, 118, 101, 110, 116, 45, + 116, 121, 112, 101, 4, 0, 0, 160, 12, 123, 39, 102, 111, 111, 39, 58, 39, 98, + 97, 114, 39, 
125, 54, 244, 128, 160>> + @payload_no_headers <<0, 0, 0, 29, 0, 0, 0, 0, 253, 82, 140, 90, 123, 39, 102, 111, 111, 39, 58, + 39, 98, 97, 114, 39, 125, 195, 101, 57, 54>> + @payload_one_str_header <<0, 0, 0, 61, 0, 0, 0, 32, 7, 253, 131, 150, 12, 99, 111, 110, 116, + 101, 110, 116, 45, 116, 121, 112, 101, 7, 0, 16, 97, 112, 112, 108, + 105, 99, 97, 116, 105, 111, 110, 47, 106, 115, 111, 110, 123, 39, 102, + 111, 111, 39, 58, 39, 98, 97, 114, 39, 125, 141, 156, 8, 177>> + + @payload_no_headers_twice <<@payload_no_headers::bitstring, @payload_no_headers::bitstring>> + + test "corrupted_length" do + assert decode(@corrupted_length) == + {:incomplete_message, "Expected message length 62 but got 61"} + end + + test "corrupted_payload" do + assert decode(@corrupted_payload) == {:error, "Checksum mismatch for message"} + end + + test "corrupted_headers" do + assert decode(@corrupted_headers) == {:error, "Checksum mismatch for message"} + end + + test "corrupted_header_len" do + assert decode(@corrupted_header_len) == {:error, "Checksum mismatch for prelude"} + end + + test "all_headers" do + assert decode(@all_headers) == {:ok, "{'foo':'bar'}", ""} + end + + test "empty_message" do + assert decode(@empty_message) == {:ok, "", ""} + end + + test "int32_header" do + assert decode(@int32_header) == {:ok, "{'foo':'bar'}", ""} + end + + test "payload_no_headers" do + assert decode(@payload_no_headers) == {:ok, "{'foo':'bar'}", ""} + end + + test "payload_one_str_header" do + assert decode(@payload_one_str_header) == {:ok, "{'foo':'bar'}", ""} + end + + test "payload_no_headers_twice" do + {:ok, "{'foo':'bar'}", remaining} = decode(@payload_no_headers_twice) + assert decode(remaining) == {:ok, "{'foo':'bar'}", ""} + end + + defdelegate decode(data), to: LangChain.Utils.AwsEventstreamDecoder +end diff --git a/test/utils/bedrock_config_test.exs b/test/utils/bedrock_config_test.exs new file mode 100644 index 00000000..6e0d223d --- /dev/null +++ b/test/utils/bedrock_config_test.exs 
@@ -0,0 +1,33 @@ +defmodule LangChain.Utils.BedrockConfigTest do + alias LangChain.Utils.BedrockConfig + use ExUnit.Case, async: true + + test "supports aws credentials without session token" do + bedrock_config = %BedrockConfig{ + credentials: fn -> [access_key_id: "KEY", secret_access_key: "SECRET"] end, + region: "us-east-1" + } + + assert BedrockConfig.aws_sigv4_opts(bedrock_config) == [ + access_key_id: "KEY", + secret_access_key: "SECRET", + region: "us-east-1", + service: :bedrock + ] + end + + test "supports aws credentials with session token" do + bedrock_config = %BedrockConfig{ + credentials: fn -> [access_key_id: "KEY", secret_access_key: "SECRET", token: "TOKEN"] end, + region: "ap-southeast-2" + } + + assert BedrockConfig.aws_sigv4_opts(bedrock_config) == [ + access_key_id: "KEY", + secret_access_key: "SECRET", + token: "TOKEN", + region: "ap-southeast-2", + service: :bedrock + ] + end +end diff --git a/test/utils/bedrock_stream_decoder_test.exs b/test/utils/bedrock_stream_decoder_test.exs new file mode 100644 index 00000000..adac6871 --- /dev/null +++ b/test/utils/bedrock_stream_decoder_test.exs @@ -0,0 +1,79 @@ +defmodule LangChain.Utils.BedrockStreamDecoderTest do + use ExUnit.Case + use Mimic + alias LangChain.Utils.AwsEventstreamDecoder + alias LangChain.Utils.BedrockStreamDecoder + + @message <<0, 0, 1, 17, 0, 0, 0, 75, 18, 240, 42, 230, 11, 58, 101, 118, 101, 110, 116, 45, 116, + 121, 112, 101, 7, 0, 5, 99, 104, 117, 110, 107, 13, 58, 99, 111, 110, 116, 101, 110, + 116, 45, 116, 121, 112, 101, 7, 0, 16, 97, 112, 112, 108, 105, 99, 97, 116, 105, 111, + 110, 47, 106, 115, 111, 110, 13, 58, 109, 101, 115, 115, 97, 103, 101, 45, 116, 121, + 112, 101, 7, 0, 5, 101, 118, 101, 110, 116, 123, 34, 98, 121, 116, 101, 115, 34, 58, + 34, 101, 121, 74, 48, 101, 88, 66, 108, 73, 106, 111, 105, 89, 50, 57, 117, 100, 71, + 86, 117, 100, 70, 57, 105, 98, 71, 57, 106, 97, 49, 57, 107, 90, 87, 120, 48, 89, 83, + 73, 115, 73, 109, 108, 117, 90, 71, 86, 52, 73, 
106, 111, 119, 76, 67, 74, 107, 90, + 87, 120, 48, 89, 83, 73, 54, 101, 121, 74, 48, 101, 88, 66, 108, 73, 106, 111, 105, + 100, 71, 86, 52, 100, 70, 57, 107, 90, 87, 120, 48, 89, 83, 73, 115, 73, 110, 82, + 108, 101, 72, 81, 105, 79, 105, 74, 68, 98, 50, 120, 118, 99, 109, 90, 49, 98, 67, + 66, 85, 97, 72, 74, 108, 89, 87, 82, 122, 73, 110, 49, 57, 34, 44, 34, 112, 34, 58, + 34, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, + 114, 115, 116, 117, 118, 119, 120, 121, 122, 65, 66, 67, 68, 69, 70, 71, 72, 73, 34, + 125, 181, 231, 17, 159>> + @message_decoded %{ + "delta" => %{"text" => "Colorful Threads", "type" => "text_delta"}, + "index" => 0, + "type" => "content_block_delta" + } + @message_twice @message <> @message + + describe "integration" do + test "decodes a single message" do + assert decode_stream({@message, ""}) == {[@message_decoded], ""} + end + + test "decodes multiple messages" do + assert decode_stream({@message_twice, ""}) == {[@message_decoded, @message_decoded], ""} + end + + test "returns buffer of incomplete messages" do + <> = + @message_twice + + <> = @message + + {chunks, buffer} = decode_stream({incomplete, ""}) + assert chunks == [@message_decoded] + assert buffer == expected_buffer + {chunks, buffer} = decode_stream({rest, buffer}) + assert chunks == [@message_decoded] + assert buffer == "" + end + end + + describe "unit" do + @success_message_base64 %{ + "type" => "content_block_start" + } + |> Jason.encode!() + |> Base.encode64() + @success_message %{"bytes" => @success_message_base64} |> Jason.encode!() + + @exception_message %{ + "internalServerError" => %{} + } + |> Jason.encode!() + + test "it passes through successfully decoded messages" do + stub(AwsEventstreamDecoder, :decode, fn _ -> {:ok, @success_message, ""} end) + assert decode_stream({"", ""}) == {[%{"type" => "content_block_start"}], ""} + end + + test "it passes through an exception_message when bytes aren't present" do + 
stub(AwsEventstreamDecoder, :decode, fn _ -> {:ok, @exception_message, ""} end) + message = %{"internalServerError" => %{}, bedrock_exception: "internalServerError"} + assert decode_stream({"", ""}) == {[message], ""} + end + end + + defdelegate decode_stream(data), to: BedrockStreamDecoder +end diff --git a/test/utils_test.exs b/test/utils_test.exs index fe8c40cc..d32defe4 100644 --- a/test/utils_test.exs +++ b/test/utils_test.exs @@ -2,6 +2,7 @@ defmodule LangChain.UtilsTest do use ExUnit.Case doctest LangChain.Utils + alias LangChain.Message alias LangChain.ChatModels.ChatOpenAI alias LangChain.Utils @@ -168,4 +169,39 @@ defmodule LangChain.UtilsTest do assert reason == "ChatModel module \"Elixir.Missing.Module\" not found" end end + + describe "split_system_message/2" do + test "returns system message and rest separately" do + system = Message.new_system!() + user_msg = Message.new_user!("Hi") + assert {system, [user_msg]} == Utils.split_system_message([system, user_msg]) + end + + test "return nil when no system message set" do + user_msg = Message.new_user!("Hi") + assert {nil, [user_msg]} == Utils.split_system_message([user_msg]) + end + + test "raises exception with multiple system messages" do + error_message = "Anthropic only supports a single System message" + + assert_raise LangChain.LangChainError, + error_message, + fn -> + system = Message.new_system!() + user_msg = Message.new_user!("Hi") + Utils.split_system_message([system, user_msg, system], error_message) + end + end + + test "has a default error message when no error message provided" do + assert_raise LangChain.LangChainError, + "Only one system message is allowed", + fn -> + system = Message.new_system!() + user_msg = Message.new_user!("Hi") + Utils.split_system_message([system, user_msg, system]) + end + end + end end