diff --git a/CHANGELOG.md b/CHANGELOG.md index c29006f6..3b24853e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,14 @@ * The attribute `processed_content` was added to a `LangChain.Message`. When a MessageProcessor is run on a received assistant message, the results of the processing are accumulated there. The original `content` remains unchanged for when it is sent back to the LLM and used when fixing or correcting its generated content. * Callback support for LLM rate limit information returned in API response headers. These are currently implemented for Anthropic and OpenAI. * Callback support for LLM token usage information returned when available. +* `LangChain.ChatModels.ChatModel` additions + * Added `add_callback/2`, which makes it easier to add a callback to a chat model. + * Added `serialize_config/1` to serialize an LLM chat model configuration to a map that can be restored later. + * Added `restore_from_map/1` to restore a configured LLM chat model from a database (for example). +* `LangChain.Chains.LLMChain` additions + * New function `add_callback/2` makes it easier to add a callback to an existing `LLMChain`. + * New function `add_llm_callback/2` makes it easier to add a callback to a chain's LLM. This is particularly useful when an LLM model is restored from a database while loading a past conversation and the original configuration should be preserved. + **Changed:** @@ -22,6 +30,7 @@ * Many smaller changes and contributions were made. This includes updates to the README for clarity. * `LangChain.Utils.fire_callback/3` was refactored into `LangChain.Utils.fire_streamed_callback/2` where it is only used for processing deltas and uses the new callback mechanism. * Notebooks were moved to the separate demo project. +* `LangChain.ChatModels.ChatGoogleAI`'s key `:version` was changed to `:api_version` to be more consistent with other models and allow for model serializers to use the `:version` key. ### Migration Steps diff --git a/lib/chains/llm_chain.ex b/lib/chains/llm_chain.ex index 71a90d96..e4850ecd 100644 --- a/lib/chains/llm_chain.ex +++ b/lib/chains/llm_chain.ex @@ -45,6 +45,7 @@ defmodule LangChain.Chains.LLMChain do use Ecto.Schema import Ecto.Changeset require Logger + alias LangChain.ChatModels.ChatModel alias LangChain.Callbacks alias LangChain.Chains.ChainCallbacks alias LangChain.PromptTemplate @@ -819,6 +820,15 @@ defmodule LangChain.Chains.LLMChain do %LLMChain{chain | callbacks: callbacks ++ [additional_callback]} end + @doc """ + Add a `LangChain.ChatModels.LLMCallbacks` callback map to the chain's `:llm` model if + it supports the `:callbacks` key. + """ + @spec add_llm_callback(t(), map()) :: t() + def add_llm_callback(%LLMChain{llm: model} = chain, callback_map) do + %LLMChain{chain | llm: ChatModel.add_callback(model, callback_map)} + end + # a pipe-friendly execution of callbacks that returns the chain defp fire_callback_and_return(%LLMChain{} = chain, callback_name, additional_arguments) when is_list(additional_arguments) do diff --git a/lib/chat_models/chat_anthropic.ex b/lib/chat_models/chat_anthropic.ex index ba8cdb8a..280944d1 100644 --- a/lib/chat_models/chat_anthropic.ex +++ b/lib/chat_models/chat_anthropic.ex @@ -58,6 +58,8 @@ defmodule LangChain.ChatModels.ChatAnthropic do @behaviour ChatModel + @current_config_version 1 + # allow up to 1 minute for response. 
@receive_timeout 60_000 @@ -866,4 +868,35 @@ defmodule LangChain.ChatModels.ChatAnthropic do end defp get_token_usage(_response_body), do: %{} + + @doc """ + Generate a config map that can later restore the model's configuration. + """ + @impl ChatModel + @spec serialize_config(t()) :: %{String.t() => any()} + def serialize_config(%ChatAnthropic{} = model) do + Utils.to_serializable_map( + model, + [ + :endpoint, + :model, + :api_version, + :temperature, + :max_tokens, + :receive_timeout, + :top_p, + :top_k, + :stream + ], + @current_config_version + ) + end + + @doc """ + Restores the model from the config. + """ + @impl ChatModel + def restore_from_map(%{"version" => 1} = data) do + ChatAnthropic.new(data) + end end diff --git a/lib/chat_models/chat_bumblebee.ex b/lib/chat_models/chat_bumblebee.ex index ca89a638..89d73849 100644 --- a/lib/chat_models/chat_bumblebee.ex +++ b/lib/chat_models/chat_bumblebee.ex @@ -112,6 +112,8 @@ defmodule LangChain.ChatModels.ChatBumblebee do @behaviour ChatModel + @current_config_version 1 + @primary_key false embedded_schema do # Name of the Nx.Serving to use when working with the LLM. @@ -164,6 +166,7 @@ defmodule LangChain.ChatModels.ChatBumblebee do def new(%{} = attrs \\ %{}) do %ChatBumblebee{} |> cast(attrs, @create_fields) + |> restore_serving_if_string() |> common_validation() |> apply_action(:insert) end @@ -182,6 +185,22 @@ defmodule LangChain.ChatModels.ChatBumblebee do end end + defp restore_serving_if_string(changeset) do + case get_field(changeset, :serving) do + value when is_binary(value) -> + case Utils.module_from_name(value) do + {:ok, module} -> + put_change(changeset, :serving, module) + + {:error, reason} -> + add_error(changeset, :serving, reason) + end + + _other -> + changeset + end + end + defp common_validation(changeset) do changeset |> validate_required(@required_fields) @@ -322,4 +341,30 @@ defmodule LangChain.ChatModels.ChatBumblebee do end defp fire_token_usage_callback(_model, _token_summary), do: :ok + + @doc """ + Generate a config map that can later restore the model's configuration. + """ + @impl ChatModel + @spec serialize_config(t()) :: %{String.t() => any()} + def serialize_config(%ChatBumblebee{} = model) do + Utils.to_serializable_map( + model, + [ + :serving, + :template_format, + :stream, + :seed + ], + @current_config_version + ) + end + + @doc """ + Restores the model from the config. + """ + @impl ChatModel + def restore_from_map(%{"version" => 1} = data) do + ChatBumblebee.new(data) + end end diff --git a/lib/chat_models/chat_google_ai.ex b/lib/chat_models/chat_google_ai.ex index 0fd8b911..99773981 100644 --- a/lib/chat_models/chat_google_ai.ex +++ b/lib/chat_models/chat_google_ai.ex @@ -18,9 +18,12 @@ defmodule LangChain.ChatModels.ChatGoogleAI do alias LangChain.Message.ToolResult alias LangChain.LangChainError alias LangChain.Utils + alias LangChain.Callbacks @behaviour ChatModel + @current_config_version 1 + @default_base_url "https://generativelanguage.googleapis.com" @default_api_version "v1beta" @default_endpoint "#{@default_base_url}/#{@default_api_version}" @@ -33,7 +36,7 @@ defmodule LangChain.ChatModels.ChatGoogleAI do field :endpoint, :string, default: @default_endpoint # The version of the API to use. 
- field :version, :string, default: @default_api_version + field :api_version, :string, default: @default_api_version field :model, :string, default: "gemini-pro" field :api_key, :string @@ -65,24 +68,28 @@ defmodule LangChain.ChatModels.ChatGoogleAI do field :receive_timeout, :integer, default: @receive_timeout field :stream, :boolean, default: false + + # A list of maps for callback handlers + field :callbacks, {:array, :map}, default: [] end @type t :: %ChatGoogleAI{} @create_fields [ :endpoint, - :version, + :api_version, :model, :api_key, :temperature, :top_p, :top_k, :receive_timeout, - :stream + :stream, + :callbacks ] @required_fields [ :endpoint, - :version, + :api_version, :model ] @@ -234,21 +241,21 @@ defmodule LangChain.ChatModels.ChatGoogleAI do `LangChain.Message` once fully complete. """ @impl ChatModel - def call(openai, prompt, tools \\ [], callback_fn \\ nil) + def call(openai, prompt, tools \\ []) - def call(%ChatGoogleAI{} = google_ai, prompt, tools, callback_fn) when is_binary(prompt) do + def call(%ChatGoogleAI{} = google_ai, prompt, tools) when is_binary(prompt) do messages = [ Message.new_system!(), Message.new_user!(prompt) ] - call(google_ai, messages, tools, callback_fn) + call(google_ai, messages, tools) end - def call(%ChatGoogleAI{} = google_ai, messages, tools, callback_fn) + def call(%ChatGoogleAI{} = google_ai, messages, tools) when is_list(messages) do try do - case do_api_request(google_ai, messages, tools, callback_fn) do + case do_api_request(google_ai, messages, tools) do {:error, reason} -> {:error, reason} @@ -262,9 +269,9 @@ defmodule LangChain.ChatModels.ChatGoogleAI do end @doc false - @spec do_api_request(t(), [Message.t()], [Function.t()], (any() -> any())) :: + @spec do_api_request(t(), [Message.t()], [Function.t()]) :: list() | struct() | {:error, String.t()} - def do_api_request(%ChatGoogleAI{stream: false} = google_ai, messages, tools, callback_fn) do + def do_api_request(%ChatGoogleAI{stream: false} = google_ai, messages, tools) do req = Req.new( url: build_url(google_ai), @@ -279,12 +286,12 @@ defmodule LangChain.ChatModels.ChatGoogleAI do |> Req.post() |> case do {:ok, %Req.Response{body: data}} -> - case do_process_response(data) do + case do_process_response(google_ai, data) do {:error, reason} -> {:error, reason} result -> - Utils.fire_callback(google_ai, result, callback_fn) + Callbacks.fire(google_ai.callbacks, :on_llm_new_message, [google_ai, result]) result end @@ -297,7 +304,7 @@ defmodule LangChain.ChatModels.ChatGoogleAI do end end - def do_api_request(%ChatGoogleAI{stream: true} = google_ai, messages, tools, callback_fn) do + def do_api_request(%ChatGoogleAI{stream: true} = google_ai, messages, tools) do Req.new( url: build_url(google_ai), json: for_api(google_ai, messages, tools), @@ -309,8 +316,7 @@ defmodule LangChain.ChatModels.ChatGoogleAI do Utils.handle_stream_fn( google_ai, &ChatOpenAI.decode_stream/1, - &do_process_response(&1, MessageDelta), - callback_fn + &do_process_response(google_ai, &1, MessageDelta) ) ) |> case do @@ -336,8 +342,10 @@ defmodule LangChain.ChatModels.ChatGoogleAI do end @spec build_url(t()) :: String.t() - defp build_url(%ChatGoogleAI{endpoint: endpoint, version: version, model: model} = google_ai) do - "#{endpoint}/#{version}/models/#{model}:#{get_action(google_ai)}?key=#{get_api_key(google_ai)}" + defp build_url( + %ChatGoogleAI{endpoint: endpoint, api_version: api_version, model: model} = google_ai + ) do + 
"#{endpoint}/#{api_version}/models/#{model}:#{get_action(google_ai)}?key=#{get_api_key(google_ai)}" |> use_sse(google_ai) end @@ -353,14 +361,19 @@ defmodule LangChain.ChatModels.ChatGoogleAI do update_in(data, [Access.at(-1), Access.at(-1)], &%{&1 | status: :complete}) end - def do_process_response(response, message_type \\ Message) + def do_process_response(model, response, message_type \\ Message) - def do_process_response(%{"candidates" => candidates}, message_type) when is_list(candidates) do + def do_process_response(model, %{"candidates" => candidates}, message_type) + when is_list(candidates) do candidates - |> Enum.map(&do_process_response(&1, message_type)) + |> Enum.map(&do_process_response(model, &1, message_type)) end - def do_process_response(%{"content" => %{"parts" => parts} = content_data} = data, Message) do + def do_process_response( + model, + %{"content" => %{"parts" => parts} = content_data} = data, + Message + ) do text_part = parts |> filter_parts_for_types(["text"]) @@ -372,14 +385,14 @@ defmodule LangChain.ChatModels.ChatGoogleAI do parts |> filter_parts_for_types(["functionCall"]) |> Enum.map(fn part -> - do_process_response(part, nil) + do_process_response(model, part, nil) end) tool_result_from_parts = parts |> filter_parts_for_types(["functionResponse"]) |> Enum.map(fn part -> - do_process_response(part, nil) + do_process_response(model, part, nil) end) %{ @@ -400,7 +413,11 @@ defmodule LangChain.ChatModels.ChatGoogleAI do end end - def do_process_response(%{"content" => %{"parts" => parts} = content_data} = data, MessageDelta) do + def do_process_response( + model, + %{"content" => %{"parts" => parts} = content_data} = data, + MessageDelta + ) do text_content = case parts do [%{"text" => text}] -> @@ -420,7 +437,7 @@ defmodule LangChain.ChatModels.ChatGoogleAI do parts |> filter_parts_for_types(["functionCall"]) |> Enum.map(fn part -> - do_process_response(part, nil) + do_process_response(model, part, nil) end) %{ @@ -440,7 +457,11 @@ defmodule LangChain.ChatModels.ChatGoogleAI do end end - def do_process_response(%{"functionCall" => %{"args" => raw_args, "name" => name}} = data, _) do + def do_process_response( + _model, + %{"functionCall" => %{"args" => raw_args, "name" => name}} = data, + _ + ) do %{ call_id: "call-#{name}", name: name, @@ -459,6 +480,7 @@ defmodule LangChain.ChatModels.ChatGoogleAI do end def do_process_response( + _model, %{ "finishReason" => finish, "content" => %{"parts" => parts, "role" => role}, @@ -502,18 +524,18 @@ defmodule LangChain.ChatModels.ChatGoogleAI do end end - def do_process_response(%{"error" => %{"message" => reason}}, _) do + def do_process_response(_model, %{"error" => %{"message" => reason}}, _) do Logger.error("Received error from API: #{inspect(reason)}") {:error, reason} end - def do_process_response({:error, %Jason.DecodeError{} = response}, _) do + def do_process_response(_model, {:error, %Jason.DecodeError{} = response}, _) do error_message = "Received invalid JSON: #{inspect(response)}" Logger.error(error_message) {:error, error_message} end - def do_process_response(other, _) do + def do_process_response(_model, other, _) do Logger.error("Trying to process an unexpected response. #{inspect(other)}") {:error, "Unexpected response"} end @@ -554,4 +576,34 @@ defmodule LangChain.ChatModels.ChatGoogleAI do defp unmap_role("model"), do: "assistant" defp unmap_role("function"), do: "tool" defp unmap_role(role), do: role + + @doc """ + Generate a config map that can later restore the model's configuration. 
+ """ + @impl ChatModel + @spec serialize_config(t()) :: %{String.t() => any()} + def serialize_config(%ChatGoogleAI{} = model) do + Utils.to_serializable_map( + model, + [ + :endpoint, + :model, + :api_version, + :temperature, + :top_p, + :top_k, + :receive_timeout, + :stream + ], + @current_config_version + ) + end + + @doc """ + Restores the model from the config. + """ + @impl ChatModel + def restore_from_map(%{"version" => 1} = data) do + ChatGoogleAI.new(data) + end end diff --git a/lib/chat_models/chat_mistral_ai.ex b/lib/chat_models/chat_mistral_ai.ex index 278bf011..aa949736 100644 --- a/lib/chat_models/chat_mistral_ai.ex +++ b/lib/chat_models/chat_mistral_ai.ex @@ -12,8 +12,11 @@ defmodule Langchain.ChatModels.ChatMistralAI do alias LangChain.MessageDelta alias LangChain.LangChainError alias LangChain.Utils + alias LangChain.Callbacks @behaviour ChatModel + + @current_config_version 1 @receive_timeout 60_000 @default_endpoint "https://api.mistral.ai/v1/chat/completions" @@ -51,6 +54,9 @@ defmodule Langchain.ChatModels.ChatMistralAI do field :random_seed, :integer field :stream, :boolean, default: false + + # A list of maps for callback handlers + field :callbacks, {:array, :map}, default: [] end @type t :: %ChatMistralAI{} @@ -65,7 +71,8 @@ defmodule Langchain.ChatModels.ChatMistralAI do :max_tokens, :safe_prompt, :random_seed, - :stream + :stream, + :callbacks ] @required_fields [ :model @@ -127,25 +134,25 @@ defmodule Langchain.ChatModels.ChatMistralAI do end @impl ChatModel - def call(mistral, prompt, functions \\ [], callback_fn \\ nil) + def call(mistral, prompt, functions \\ []) - def call(%ChatMistralAI{} = mistral, prompt, functions, callback_fn) when is_binary(prompt) do + def call(%ChatMistralAI{} = mistral, prompt, functions) when is_binary(prompt) do messages = [ Message.new_system!(), Message.new_user!(prompt) ] - call(mistral, messages, functions, callback_fn) + call(mistral, messages, functions) end - def call(%ChatMistralAI{} = mistral, messages, functions, callback_fn) when is_list(messages) do + def call(%ChatMistralAI{} = mistral, messages, functions) when is_list(messages) do if override_api_return?() do Logger.warning("Found override API response. Will not make live API call.") case get_api_override() do - {:ok, {:ok, data} = response} -> + {:ok, {:ok, data, callback_name} = response} -> # fire callback for fake responses too - Utils.fire_callback(mistral, data, callback_fn) + Callbacks.fire(mistral.callbacks, callback_name, [mistral, data]) response # fake error response @@ -159,7 +166,7 @@ defmodule Langchain.ChatModels.ChatMistralAI do else try do # make base api request and perform high-level success/failure checks - case do_api_request(mistral, messages, functions, callback_fn) do + case do_api_request(mistral, messages, functions) do {:error, reason} -> {:error, reason} @@ -173,11 +180,11 @@ defmodule Langchain.ChatModels.ChatMistralAI do end end - @spec do_api_request(t(), [Message.t()], [Function.t()], (any() -> any())) :: + @spec do_api_request(t(), [Message.t()], [Function.t()], integer()) :: list() | struct() | {:error, String.t()} - def do_api_request(mistral, messages, functions, callback_fn, retry_count \\ 3) + def do_api_request(mistral, messages, functions, retry_count \\ 3) - def do_api_request(_mistral, _messages, _functions, _callback_fn, 0) do + def do_api_request(_mistral, _messages, _functions, 0) do raise LangChainError, "Retries exceeded. Connection failed." 
end @@ -185,7 +192,6 @@ defmodule Langchain.ChatModels.ChatMistralAI do %ChatMistralAI{stream: false} = mistral, messages, functions, - callback_fn, retry_count ) do req = @@ -204,12 +210,12 @@ defmodule Langchain.ChatModels.ChatMistralAI do # parse the body and return it as parsed structs |> case do {:ok, %Req.Response{body: data}} -> - case do_process_response(data) do + case do_process_response(mistral, data) do {:error, reason} -> {:error, reason} result -> - Utils.fire_callback(mistral, result, callback_fn) + Callbacks.fire(mistral.callbacks, :on_llm_new_message, [mistral, result]) result end @@ -219,7 +225,7 @@ defmodule Langchain.ChatModels.ChatMistralAI do {:error, %Req.TransportError{reason: :closed}} -> # Force a retry by making a recursive call decrementing the counter Logger.debug(fn -> "Mint connection closed: retry count = #{inspect(retry_count)}" end) - do_api_request(mistral, messages, functions, callback_fn, retry_count - 1) + do_api_request(mistral, messages, functions, retry_count - 1) other -> Logger.error("Unexpected and unhandled API response! #{inspect(other)}") @@ -231,7 +237,6 @@ defmodule Langchain.ChatModels.ChatMistralAI do %ChatMistralAI{stream: true} = mistral, messages, functions, - callback_fn, retry_count ) do Req.new( @@ -245,8 +250,7 @@ defmodule Langchain.ChatModels.ChatMistralAI do Utils.handle_stream_fn( mistral, &ChatOpenAI.decode_stream/1, - &do_process_response/1, - callback_fn + &do_process_response(mistral, &1) ) ) |> case do @@ -262,7 +266,7 @@ defmodule Langchain.ChatModels.ChatMistralAI do {:error, %Req.TransportError{reason: :closed}} -> # Force a retry by making a recursive call decrementing the counter Logger.debug(fn -> "Mint connection closed: retry count = #{inspect(retry_count)}" end) - do_api_request(mistral, messages, functions, callback_fn, retry_count - 1) + do_api_request(mistral, messages, functions, retry_count - 1) other -> Logger.error( @@ -275,20 +279,21 @@ defmodule Langchain.ChatModels.ChatMistralAI do # Parse a new message response @doc false - @spec do_process_response(data :: %{String.t() => any()} | {:error, any()}) :: + @spec do_process_response(t(), data :: %{String.t() => any()} | {:error, any()}) :: Message.t() | [Message.t()] | MessageDelta.t() | [MessageDelta.t()] | {:error, String.t()} - def do_process_response(%{"choices" => choices}) when is_list(choices) do + def do_process_response(model, %{"choices" => choices}) when is_list(choices) do # process each response individually. 
Return a list of all processed choices for choice <- choices do - do_process_response(choice) + do_process_response(model, choice) end end def do_process_response( + _model, %{"delta" => delta_body, "finish_reason" => finish, "index" => index} = _msg ) do status = @@ -334,7 +339,7 @@ end end - def do_process_response(%{ + def do_process_response(_model, %{ "finish_reason" => finish_reason, "message" => message, "index" => index @@ -364,19 +369,50 @@ end end - def do_process_response(%{"error" => %{"message" => reason}}) do + def do_process_response(_model, %{"error" => %{"message" => reason}}) do Logger.error("Received error from API: #{inspect(reason)}") {:error, reason} end - def do_process_response({:error, %Jason.DecodeError{} = response}) do + def do_process_response(_model, {:error, %Jason.DecodeError{} = response}) do error_message = "Received invalid JSON: #{inspect(response)}" Logger.error(error_message) {:error, error_message} end - def do_process_response(other) do + def do_process_response(_model, other) do Logger.error("Trying to process an unexpected response. #{inspect(other)}") {:error, "Unexpected response"} end + + @doc """ + Generate a config map that can later restore the model's configuration. + """ + @impl ChatModel + @spec serialize_config(t()) :: %{String.t() => any()} + def serialize_config(%ChatMistralAI{} = model) do + Utils.to_serializable_map( + model, + [ + :endpoint, + :model, + :temperature, + :top_p, + :receive_timeout, + :max_tokens, + :safe_prompt, + :random_seed, + :stream + ], + @current_config_version + ) + end + + @doc """ + Restores the model from the config. + """ + @impl ChatModel + def restore_from_map(%{"version" => 1} = data) do + ChatMistralAI.new(data) + end end diff --git a/lib/chat_models/chat_model.ex b/lib/chat_models/chat_model.ex index ea3d38a6..145fc276 100644 --- a/lib/chat_models/chat_model.ex +++ b/lib/chat_models/chat_model.ex @@ -1,7 +1,9 @@ defmodule LangChain.ChatModels.ChatModel do + require Logger alias LangChain.Message alias LangChain.MessageDelta alias LangChain.Function + alias LangChain.Utils @type call_response :: {:ok, Message.t() | [Message.t()] | [MessageDelta.t()]} | {:error, String.t()} @@ -16,4 +18,46 @@ String.t() | [Message.t()], [LangChain.Function.t()] ) :: call_response() + + @callback serialize_config(t()) :: %{String.t() => any()} + + @callback restore_from_map(%{String.t() => any()}) :: {:ok, struct()} | {:error, String.t()} + + @doc """ + Add a `LangChain.ChatModels.LLMCallbacks` callback map to the ChatModel if + it includes the `:callbacks` key. + """ + @spec add_callback(%{optional(:callbacks) => nil | [map()]}, map()) :: map() | struct() + def add_callback(%_{callbacks: callbacks} = model, callback_map) do + existing_callbacks = callbacks || [] + %{model | callbacks: existing_callbacks ++ [callback_map]} + end + + def add_callback(model, _callback_map), do: model + + @doc """ + Create a serializable map from a ChatModel's current configuration that can + later be restored. + """ + def serialize_config(%chat_module{} = model) do + # plucks the module from the struct and, because of the behaviour, assumes + # the module defines a `serialize_config/1` function that is executed. + chat_module.serialize_config(model) + end + + @doc """ + Restore a ChatModel from a serialized config map. 
+ """ + @spec restore_from_map(nil | %{String.t() => any()}) :: {:ok, struct()} | {:error, String.t()} + def restore_from_map(nil), do: {:error, "No data to restore"} + + def restore_from_map(%{"module" => module_name} = data) do + case Utils.module_from_name(module_name) do + {:ok, module} -> + module.restore_from_map(data) + + {:error, _reason} = error -> + error + end + end end diff --git a/lib/chat_models/chat_ollama_ai.ex b/lib/chat_models/chat_ollama_ai.ex index bf45acd3..0fc54766 100644 --- a/lib/chat_models/chat_ollama_ai.ex +++ b/lib/chat_models/chat_ollama_ai.ex @@ -43,6 +43,8 @@ defmodule LangChain.ChatModels.ChatOllamaAI do @behaviour ChatModel + @current_config_version 1 + @type t :: %ChatOllamaAI{} @create_fields [ @@ -146,6 +148,9 @@ defmodule LangChain.ChatModels.ChatOllamaAI do # Works together with top-k. A higher value (e.g., 0.95) will lead to more diverse text, # while a lower value (e.g., 0.5) will generate more focused and conservative text. (Default: 0.9) field :top_p, :float, default: 0.9 + + # A list of maps for callback handlers + field :callbacks, {:array, :map}, default: [] end @doc """ @@ -227,21 +232,21 @@ defmodule LangChain.ChatModels.ChatOllamaAI do """ @impl ChatModel - def call(ollama_ai, prompt, functions \\ [], callback_fn \\ nil) + def call(ollama_ai, prompt, functions \\ []) - def call(%ChatOllamaAI{} = ollama_ai, prompt, functions, callback_fn) when is_binary(prompt) do + def call(%ChatOllamaAI{} = ollama_ai, prompt, functions) when is_binary(prompt) do messages = [ Message.new_system!(), Message.new_user!(prompt) ] - call(ollama_ai, messages, functions, callback_fn) + call(ollama_ai, messages, functions) end - def call(%ChatOllamaAI{} = ollama_ai, messages, functions, callback_fn) + def call(%ChatOllamaAI{} = ollama_ai, messages, functions) when is_list(messages) do try do - case do_api_request(ollama_ai, messages, functions, callback_fn) do + case do_api_request(ollama_ai, messages, functions) do {:error, reason} -> {:error, reason} @@ -269,11 +274,11 @@ defmodule LangChain.ChatModels.ChatOllamaAI do # # Retries the request up to 3 times on transient errors with a 1 second delay @doc false - @spec do_api_request(t(), [Message.t()], [Function.t()], (any() -> any())) :: + @spec do_api_request(t(), [Message.t()], [Function.t()]) :: list() | struct() | {:error, String.t()} - def do_api_request(ollama_ai, messages, functions, callback_fn, retry_count \\ 3) + def do_api_request(ollama_ai, messages, functions, retry_count \\ 3) - def do_api_request(_ollama_ai, _messages, _functions, _callback_fn, 0) do + def do_api_request(_ollama_ai, _messages, _functions, 0) do raise LangChainError, "Retries exceeded. Connection failed." 
end @@ -281,7 +286,6 @@ %ChatOllamaAI{stream: false} = ollama_ai, messages, functions, - callback_fn, retry_count ) do req = @@ -298,7 +302,7 @@ |> Req.post() |> case do {:ok, %Req.Response{body: data}} -> - case do_process_response(data) do + case do_process_response(ollama_ai, data) do {:error, reason} -> {:error, reason} @@ -312,7 +316,7 @@ {:error, %Req.TransportError{reason: :closed}} -> # Force a retry by making a recursive call decrementing the counter Logger.debug(fn -> "Mint connection closed: retry count = #{inspect(retry_count)}" end) - do_api_request(ollama_ai, messages, functions, callback_fn, retry_count - 1) + do_api_request(ollama_ai, messages, functions, retry_count - 1) other -> Logger.error("Unexpected and unhandled API response! #{inspect(other)}") @@ -324,7 +328,6 @@ %ChatOllamaAI{stream: true} = ollama_ai, messages, functions, - callback_fn, retry_count ) do Req.new( @@ -337,8 +340,7 @@ Utils.handle_stream_fn( ollama_ai, &ChatOpenAI.decode_stream/1, - &do_process_response/1, - callback_fn + &do_process_response(ollama_ai, &1) ) ) |> case do @@ -354,7 +356,7 @@ {:error, %Req.TransportError{reason: :closed}} -> # Force a retry by making a recursive call decrementing the counter Logger.debug(fn -> "Mint connection closed: retry count = #{inspect(retry_count)}" end) - do_api_request(ollama_ai, messages, functions, callback_fn, retry_count - 1) + do_api_request(ollama_ai, messages, functions, retry_count - 1) other -> Logger.error( @@ -365,15 +367,15 @@ end end - def do_process_response(%{"message" => message, "done" => true}) do + def do_process_response(_model, %{"message" => message, "done" => true}) do create_message(message, :complete, Message) end - def do_process_response(%{"message" => message, "done" => _other}) do + def do_process_response(_model, %{"message" => message, "done" => _other}) do create_message(message, :incomplete, MessageDelta) end - def do_process_response(%{"error" => reason}) do + def do_process_response(_model, %{"error" => reason}) do Logger.error("Received error from API: #{inspect(reason)}") {:error, reason} end @@ -387,4 +389,46 @@ {:error, Utils.changeset_error_to_string(changeset)} end end + + @doc """ + Generate a config map that can later restore the model's configuration. + """ + @impl ChatModel + @spec serialize_config(t()) :: %{String.t() => any()} + def serialize_config(%ChatOllamaAI{} = model) do + Utils.to_serializable_map( + model, + [ + :endpoint, + :model, + :mirostat, + :mirostat_eta, + :mirostat_tau, + :num_ctx, + :num_gqa, + :num_gpu, + :num_predict, + :num_thread, + :receive_timeout, + :repeat_last_n, + :repeat_penalty, + :seed, + :stop, + :stream, + :temperature, + :tfs_z, + :top_k, + :top_p + ], + @current_config_version + ) + end + + @doc """ + Restores the model from the config. 
+ """ + @impl ChatModel + def restore_from_map(%{"version" => 1} = data) do + ChatOllamaAI.new(data) + end end diff --git a/lib/chat_models/chat_open_ai.ex b/lib/chat_models/chat_open_ai.ex index 6190ff9d..1a6dd1ef 100644 --- a/lib/chat_models/chat_open_ai.ex +++ b/lib/chat_models/chat_open_ai.ex @@ -87,6 +87,8 @@ defmodule LangChain.ChatModels.ChatOpenAI do @behaviour ChatModel + @current_config_version 1 + # NOTE: As of gpt-4 and gpt-3.5, only one function_call is issued at a time # even when multiple requests could be issued based on the prompt. @@ -879,4 +881,37 @@ defmodule LangChain.ChatModels.ChatOpenAI do end defp get_token_usage(_response_body), do: nil + + @doc """ + Generate a config map that can later restore the model's configuration. + """ + @impl ChatModel + @spec serialize_config(t()) :: %{String.t() => any()} + def serialize_config(%ChatOpenAI{} = model) do + Utils.to_serializable_map( + model, + [ + :endpoint, + :model, + :temperature, + :frequency_penalty, + :receive_timeout, + :seed, + :n, + :json_response, + :stream, + :max_tokens, + :stream_options + ], + @current_config_version + ) + end + + @doc """ + Restores the model from the config. + """ + @impl ChatModel + def restore_from_map(%{"version" => 1} = data) do + ChatOpenAI.new(data) + end end diff --git a/lib/chat_models/chat_vertex_ai.ex b/lib/chat_models/chat_vertex_ai.ex index dea89301..460f81a2 100644 --- a/lib/chat_models/chat_vertex_ai.ex +++ b/lib/chat_models/chat_vertex_ai.ex @@ -22,6 +22,8 @@ defmodule LangChain.ChatModels.ChatVertexAI do @behaviour ChatModel + @current_config_version 1 + # allow up to 2 minutes for response. @receive_timeout 60_000 @@ -585,4 +587,34 @@ defmodule LangChain.ChatModels.ChatVertexAI do defp unmap_role("model"), do: "assistant" defp unmap_role("function"), do: "tool" defp unmap_role(role), do: role + + @doc """ + Generate a config map that can later restore the model's configuration. + """ + @impl ChatModel + @spec serialize_config(t()) :: %{String.t() => any()} + def serialize_config(%ChatVertexAI{} = model) do + Utils.to_serializable_map( + model, + [ + :endpoint, + :model, + :temperature, + :top_p, + :top_k, + :receive_timeout, + :json_response, + :stream + ], + @current_config_version + ) + end + + @doc """ + Restores the model from the config. + """ + @impl ChatModel + def restore_from_map(%{"version" => 1} = data) do + ChatVertexAI.new(data) + end end diff --git a/lib/images/open_ai_image.ex b/lib/images/open_ai_image.ex index 4dcb1aa7..0594de08 100644 --- a/lib/images/open_ai_image.ex +++ b/lib/images/open_ai_image.ex @@ -188,7 +188,7 @@ defmodule LangChain.Images.OpenAIImage do Logger.warning("Found override API response. Will not make live API call.") case get_api_override() do - {:ok, {:ok, data} = response} -> + {:ok, {:ok, _data, _callback_name} = response} -> response # fake error response diff --git a/lib/utils.ex b/lib/utils.ex index 20449e45..b9d9ab8d 100644 --- a/lib/utils.ex +++ b/lib/utils.ex @@ -222,4 +222,67 @@ defmodule LangChain.Utils do List.replace_at(list, index, value) end end + + @doc """ + Given a struct, create a map with the selected keys converted to strings. + Additionally includes a `version` number for the data. 
+ """ + @spec to_serializable_map(struct(), keys :: [atom()], version :: integer()) :: %{ + String.t() => any() + } + def to_serializable_map(%module{} = struct, keys, version \\ 1) do + struct + |> Map.from_struct() + |> Map.take(keys) + |> stringify_keys() + |> Map.put("module", Atom.to_string(module)) + |> Map.put("version", version) + end + + @doc """ + Convert map atom keys to strings + + Original source: https://gist.github.com/kipcole9/0bd4c6fb6109bfec9955f785087f53fb + """ + def stringify_keys(nil), do: nil + + def stringify_keys(map = %{}) do + map + |> Enum.map(fn {k, v} -> {to_string(k), stringify_keys(v)} end) + |> Enum.into(%{}) + end + + # Walk the list and stringify the keys of + # of any map members + def stringify_keys([head | rest]) do + [stringify_keys(head) | stringify_keys(rest)] + end + + def stringify_keys(not_a_map) when is_atom(not_a_map) and not is_boolean(not_a_map) do + Atom.to_string(not_a_map) + end + + def stringify_keys(not_a_map) do + not_a_map + end + + @doc """ + Return an `{:ok, module}` when the string successfully converts to an existing + module. + """ + def module_from_name("Elixir." <> _rest = module_name) do + try do + {:ok, String.to_existing_atom(module_name)} + rescue + _err -> + Logger.error("Failed to restore using module_name #{inspect(module_name)}. Not found.") + {:error, "ChatModel module #{inspect(module_name)} not found"} + end + end + + def module_from_name(module_name) do + msg = "Not an Elixir module: #{inspect(module_name)}" + Logger.error(msg) + {:error, msg} + end end diff --git a/mix.exs b/mix.exs index cae522b2..5246564a 100644 --- a/mix.exs +++ b/mix.exs @@ -37,7 +37,7 @@ defmodule LangChain.MixProject do [ {:ecto, "~> 3.10 or ~> 3.11"}, {:gettext, "~> 0.20"}, - {:req, ">= 0.4.8"}, + {:req, ">= 0.5.0"}, {:abacus, "~> 2.1.0"}, {:nx, ">= 0.7.0", optional: true}, {:ex_doc, "~> 0.27", only: :dev, runtime: false} diff --git a/test/chains/llm_chain_test.exs b/test/chains/llm_chain_test.exs index b239ad5e..1f26cf99 100644 --- a/test/chains/llm_chain_test.exs +++ b/test/chains/llm_chain_test.exs @@ -1541,6 +1541,24 @@ defmodule LangChain.Chains.LLMChainTest do end end + describe "add_llm_callback/2" do + test "appends a callback handler to the chain's LLM", %{chat: chat} do + handler1 = %{on_llm_new_message: fn _chain, _msg -> IO.puts("MESSAGE 1!") end} + handler2 = %{on_llm_new_message: fn _chain, _msg -> IO.puts("MESSAGE 2!") end} + + # none to start with + assert chat.callbacks == [] + + chain = + %{llm: chat} + |> LLMChain.new!() + |> LLMChain.add_llm_callback(handler1) + |> LLMChain.add_llm_callback(handler2) + + assert chain.llm.callbacks == [handler1, handler2] + end + end + # TODO: Sequential chains # https://js.langchain.com/docs/modules/chains/sequential_chain diff --git a/test/chat_models/chat_anthropic_test.exs b/test/chat_models/chat_anthropic_test.exs index 19c3cad6..41a5eb88 100644 --- a/test/chat_models/chat_anthropic_test.exs +++ b/test/chat_models/chat_anthropic_test.exs @@ -1290,4 +1290,39 @@ data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text # assert false # end end + + describe "serialize_config/2" do + test "does not include the API key or callbacks" do + model = ChatAnthropic.new!(%{model: "claude-3-haiku-20240307"}) + result = ChatAnthropic.serialize_config(model) + assert result["version"] == 1 + refute Map.has_key?(result, "api_key") + refute Map.has_key?(result, "callbacks") + end + + test "creates expected map" do + model = + ChatAnthropic.new!(%{ + model: 
"claude-3-haiku-20240307", + temperature: 0, + max_tokens: 1234 + }) + + result = ChatAnthropic.serialize_config(model) + + assert result == %{ + "endpoint" => "https://api.anthropic.com/v1/messages", + "model" => "claude-3-haiku-20240307", + "max_tokens" => 1234, + "receive_timeout" => 60000, + "stream" => false, + "temperature" => 0.0, + "api_version" => "2023-06-01", + "top_k" => nil, + "top_p" => nil, + "module" => "Elixir.LangChain.ChatModels.ChatAnthropic", + "version" => 1 + } + end + end end diff --git a/test/chat_models/chat_bumblebee_test.exs b/test/chat_models/chat_bumblebee_test.exs index 609b0067..b8107143 100644 --- a/test/chat_models/chat_bumblebee_test.exs +++ b/test/chat_models/chat_bumblebee_test.exs @@ -49,7 +49,7 @@ defmodule LangChain.ChatModels.ChatBumblebeeTest do end, on_llm_token_usage: fn _model, usage -> send(self(), {:callback_usage, usage}) - end, + end } %{handler: handler} @@ -174,4 +174,49 @@ defmodule LangChain.ChatModels.ChatBumblebeeTest do assert %TokenUsage{input: 38, output: 4} = usage end end + + describe "serialize_config/2" do + test "does not include the API key or callbacks" do + model = ChatBumblebee.new!(%{serving: Fake}) + result = ChatBumblebee.serialize_config(model) + assert result["version"] == 1 + refute Map.has_key?(result, "callbacks") + end + + test "creates expected map" do + model = + ChatBumblebee.new!(%{ + serving: Fake, + seed: 123, + template_format: :llama_3 + }) + + result = ChatBumblebee.serialize_config(model) + + assert result == %{ + "module" => "Elixir.LangChain.ChatModels.ChatBumblebee", + "seed" => 123, + "stream" => true, + "version" => 1, + "serving" => "Elixir.Fake", + "template_format" => "llama_3" + } + end + end + + describe "restore_from_map/1" do + test "restores from a serialized map" do + model = + ChatBumblebee.new!(%{ + serving: Fake, + seed: 123, + template_format: :llama_3 + }) + + serialized = ChatBumblebee.serialize_config(model) + + {:ok, restored} = ChatBumblebee.restore_from_map(serialized) + assert restored == model + end + end end diff --git a/test/chat_models/chat_google_ai_test.exs b/test/chat_models/chat_google_ai_test.exs index 59b24dde..d31eaab0 100644 --- a/test/chat_models/chat_google_ai_test.exs +++ b/test/chat_models/chat_google_ai_test.exs @@ -19,7 +19,9 @@ defmodule ChatModels.ChatGoogleAITest do function: fn _args, _context -> {:ok, "Hello world!"} end }) - %{hello_world: hello_world} + model = ChatGoogleAI.new!(%{}) + + %{model: model, hello_world: hello_world} end describe "new/1" do @@ -46,14 +48,14 @@ defmodule ChatModels.ChatGoogleAITest do end test "supports overriding the API version" do - version = "v1" + api_version = "v1" model = ChatGoogleAI.new!(%{ - version: version + api_version: api_version }) - assert model.version == version + assert model.api_version == api_version end end @@ -179,7 +181,7 @@ defmodule ChatModels.ChatGoogleAITest do end describe "do_process_response/2" do - test "handles receiving a message" do + test "handles receiving a message", %{model: model} do response = %{ "candidates" => [ %{ @@ -190,14 +192,14 @@ defmodule ChatModels.ChatGoogleAITest do ] } - assert [%Message{} = struct] = ChatGoogleAI.do_process_response(response) + assert [%Message{} = struct] = ChatGoogleAI.do_process_response(model, response) assert struct.role == :assistant [%ContentPart{type: :text, content: "Hello User!"}] = struct.content assert struct.index == 0 assert struct.status == :complete end - test "error if receiving non-text content" do + test "error if receiving non-text 
content", %{model: model} do response = %{ "candidates" => [ %{ @@ -208,11 +210,11 @@ defmodule ChatModels.ChatGoogleAITest do ] } - assert [{:error, error_string}] = ChatGoogleAI.do_process_response(response) + assert [{:error, error_string}] = ChatGoogleAI.do_process_response(model, response) assert error_string == "role: is invalid" end - test "handles receiving function calls" do + test "handles receiving function calls", %{model: model} do args = %{"args" => "data"} response = %{ @@ -228,7 +230,7 @@ defmodule ChatModels.ChatGoogleAITest do ] } - assert [%Message{} = struct] = ChatGoogleAI.do_process_response(response) + assert [%Message{} = struct] = ChatGoogleAI.do_process_response(model, response) assert struct.role == :assistant assert struct.index == 0 [call] = struct.tool_calls @@ -236,7 +238,7 @@ defmodule ChatModels.ChatGoogleAITest do assert call.arguments == args end - test "handles receiving MessageDeltas as well" do + test "handles receiving MessageDeltas as well", %{model: model} do response = %{ "candidates" => [ %{ @@ -250,14 +252,14 @@ defmodule ChatModels.ChatGoogleAITest do ] } - assert [%MessageDelta{} = struct] = ChatGoogleAI.do_process_response(response, MessageDelta) + assert [%MessageDelta{} = struct] = ChatGoogleAI.do_process_response(model, response, MessageDelta) assert struct.role == :assistant assert struct.content == "This is the first part of a mes" assert struct.index == 0 assert struct.status == :incomplete end - test "handles API error messages" do + test "handles API error messages", %{model: model} do response = %{ "error" => %{ "code" => 400, @@ -266,20 +268,20 @@ defmodule ChatModels.ChatGoogleAITest do } } - assert {:error, error_string} = ChatGoogleAI.do_process_response(response) + assert {:error, error_string} = ChatGoogleAI.do_process_response(model, response) assert error_string == "Invalid request" end - test "handles Jason.DecodeError" do + test "handles Jason.DecodeError", %{model: model} do response = {:error, %Jason.DecodeError{}} - assert {:error, error_string} = ChatGoogleAI.do_process_response(response) + assert {:error, error_string} = ChatGoogleAI.do_process_response(model, response) assert "Received invalid JSON:" <> _ = error_string end - test "handles unexpected response with error" do + test "handles unexpected response with error", %{model: model} do response = %{} - assert {:error, "Unexpected response"} = ChatGoogleAI.do_process_response(response) + assert {:error, "Unexpected response"} = ChatGoogleAI.do_process_response(model, response) end end @@ -340,4 +342,41 @@ defmodule ChatModels.ChatGoogleAITest do ] end end + + describe "serialize_config/2" do + test "does not include the API key or callbacks" do + model = ChatGoogleAI.new!(%{model: "gpt-4o"}) + result = ChatGoogleAI.serialize_config(model) + assert result["version"] == 1 + refute Map.has_key?(result, "api_key") + refute Map.has_key?(result, "callbacks") + end + + test "creates expected map" do + model = + ChatGoogleAI.new!(%{ + model: "gpt-4o", + temperature: 0, + frequency_penalty: 0.5, + seed: 123, + max_tokens: 1234, + stream_options: %{include_usage: true} + }) + + result = ChatGoogleAI.serialize_config(model) + + assert result == %{ + "endpoint" => "https://generativelanguage.googleapis.com/v1beta", + "model" => "gpt-4o", + "module" => "Elixir.LangChain.ChatModels.ChatGoogleAI", + "receive_timeout" => 60000, + "stream" => false, + "temperature" => 0.0, + "version" => 1, + "api_version" => "v1beta", + "top_k" => 1.0, + "top_p" => 1.0 + } + end + end end diff 
--git a/test/chat_models/chat_mistral_ai_test.exs b/test/chat_models/chat_mistral_ai_test.exs index bcad363e..f54599ca 100644 --- a/test/chat_models/chat_mistral_ai_test.exs +++ b/test/chat_models/chat_mistral_ai_test.exs @@ -5,6 +5,11 @@ defmodule LangChain.ChatModels.ChatMistralAITest do alias LangChain.Message alias LangChain.MessageDelta + setup do + model = ChatMistralAI.new!(%{"model" => "mistral-tiny"}) + %{model: model} + end + describe "new/1" do test "works with minimal attr" do assert {:ok, %ChatMistralAI{} = mistral_ai} = @@ -85,7 +90,7 @@ end describe "do_process_response/2" do - test "handles receiving a message" do + test "handles receiving a message", %{model: model} do response = %{ "choices" => [ %{ @@ -99,14 +104,14 @@ ] } - assert [%Message{} = struct] = ChatMistralAI.do_process_response(response) + assert [%Message{} = struct] = ChatMistralAI.do_process_response(model, response) assert struct.role == :assistant assert struct.content == "Hello User!" assert struct.index == 0 assert struct.status == :complete end - test "errors with invalid role" do + test "errors with invalid role", %{model: model} do response = %{ "choices" => [ %{ @@ -120,10 +125,10 @@ ] } - assert [{:error, "role: is invalid"}] = ChatMistralAI.do_process_response(response) + assert [{:error, "role: is invalid"}] = ChatMistralAI.do_process_response(model, response) end - test "handles receiving MessageDeltas as well" do + test "handles receiving MessageDeltas as well", %{model: model} do response = %{ "choices" => [ %{ @@ -137,7 +142,7 @@ ] } - assert [%MessageDelta{} = struct] = ChatMistralAI.do_process_response(response) + assert [%MessageDelta{} = struct] = ChatMistralAI.do_process_response(model, response) assert struct.role == :assistant assert struct.content == "This is the first part of a mes" @@ -145,7 +150,7 @@ assert struct.status == :incomplete end - test "handles API error messages" do + test "handles API error messages", %{model: model} do response = %{ "error" => %{ "code" => 400, @@ -154,20 +159,58 @@ } } - assert {:error, error_string} = ChatMistralAI.do_process_response(response) + assert {:error, error_string} = ChatMistralAI.do_process_response(model, response) assert error_string == "Invalid request" end - test "handles Jason.DecodeError" do + test "handles Jason.DecodeError", %{model: model} do response = {:error, %Jason.DecodeError{}} - assert {:error, error_string} = ChatMistralAI.do_process_response(response) + assert {:error, error_string} = ChatMistralAI.do_process_response(model, response) assert "Received invalid JSON:" <> _ = error_string end - test "handles unexpected response with error" do + test "handles unexpected response with error", %{model: model} do response = %{} - assert {:error, "Unexpected response"} = ChatMistralAI.do_process_response(response) + assert {:error, "Unexpected response"} = ChatMistralAI.do_process_response(model, response) + end + end + + describe "serialize_config/1" do + test "does not include the API key or callbacks" do + model = ChatMistralAI.new!(%{model: "mistral-tiny"}) + result = ChatMistralAI.serialize_config(model) + assert result["version"] == 1 + refute Map.has_key?(result, "api_key") + refute Map.has_key?(result, 
"callbacks") + end + + test "creates expected map" do + model = + ChatMistralAI.new!(%{ + model: "mistral-tiny", + temperature: 1.0, + top_p: 1.0, + max_tokens: 100, + safe_prompt: true, + random_seed: 42 + }) + + result = ChatMistralAI.serialize_config(model) + + assert result == %{ + "endpoint" => "https://api.mistral.ai/v1/chat/completions", + "model" => "mistral-tiny", + "max_tokens" => 100, + "module" => "Elixir.Langchain.ChatModels.ChatMistralAI", + "receive_timeout" => 60000, + "stream" => false, + "temperature" => 1.0, + "random_seed" => 42, + "safe_prompt" => true, + "top_p" => 1.0, + "version" => 1 + } end end end diff --git a/test/chat_models/chat_model_test.exs b/test/chat_models/chat_model_test.exs new file mode 100644 index 00000000..d2640162 --- /dev/null +++ b/test/chat_models/chat_model_test.exs @@ -0,0 +1,58 @@ +defmodule LangChain.ChatModels.ChatModelTest do + use ExUnit.Case + doctest LangChain.ChatModels.ChatModel + alias LangChain.ChatModels.ChatModel + alias LangChain.ChatModels.ChatOpenAI + + describe "add_callback/2" do + test "appends the callback to the model" do + model = %ChatOpenAI{} + assert model.callbacks == [] + handler = %{on_llm_new_message: fn _model, _msg -> :ok end} + %ChatOpenAI{} = updated = ChatModel.add_callback(model, handler) + assert updated.callbacks == [handler] + end + + test "does nothing on a model that doesn't support callbacks" do + handler = %{on_llm_new_message: fn _model, _msg -> :ok end} + non_model = %{something: "else"} + updated = ChatModel.add_callback(non_model, handler) + assert updated == non_model + end + end + + describe "serialize_config/1" do + test "creates a map from a chat model" do + model = ChatOpenAI.new!(%{model: "gpt-4o"}) + result = ChatModel.serialize_config(model) + assert Map.get(result, "module") == "Elixir.LangChain.ChatModels.ChatOpenAI" + assert Map.get(result, "model") == "gpt-4o" + assert Map.get(result, "version") == 1 + end + end + + describe "restore_from_map/1" do + test "return error when nil data given" do + assert {:error, reason} = ChatModel.restore_from_map(nil) + assert reason == "No data to restore" + end + + test "return error when module not found" do + assert {:error, reason} = + ChatModel.restore_from_map(%{ + "module" => "Elixir.InvalidModule", + "version" => 1, + "model" => "howdy" + }) + + assert reason == "ChatModel module \"Elixir.InvalidModule\" not found" + end + + test "restores using the module" do + model = ChatOpenAI.new!(%{model: "gpt-4o"}) + serialized = ChatModel.serialize_config(model) + {:ok, restored} = ChatModel.restore_from_map(serialized) + assert restored == model + end + end +end diff --git a/test/chat_models/chat_ollama_ai_test.exs b/test/chat_models/chat_ollama_ai_test.exs index e6776a35..b1fbd3f1 100644 --- a/test/chat_models/chat_ollama_ai_test.exs +++ b/test/chat_models/chat_ollama_ai_test.exs @@ -4,6 +4,12 @@ defmodule ChatModels.ChatOllamaAITest do doctest LangChain.ChatModels.ChatOllamaAI alias LangChain.ChatModels.ChatOllamaAI + setup do + model = ChatOllamaAI.new!(%{"model" => "llama2:latest"}) + + %{model: model} + end + describe "new/1" do test "works with minimal attributes" do assert {:ok, %ChatOllamaAI{} = ollama_ai} = ChatOllamaAI.new(%{"model" => "llama2:latest"}) @@ -187,7 +193,7 @@ defmodule ChatModels.ChatOllamaAITest do end describe "do_process_response/1" do - test "handles receiving a non streamed message result" do + test "handles receiving a non streamed message result", %{model: model} do response = %{ "model" => "llama2", "created_at" => 
"2024-01-15T23:02:24.087444Z", @@ -204,13 +210,13 @@ defmodule ChatModels.ChatOllamaAITest do "eval_duration" => 5_336_241_000 } - assert %Message{} = struct = ChatOllamaAI.do_process_response(response) + assert %Message{} = struct = ChatOllamaAI.do_process_response(model, response) assert struct.role == :assistant assert struct.content == "Greetings!" assert struct.index == nil end - test "handles receiving a streamed message result" do + test "handles receiving a streamed message result", %{model: model} do response = %{ "model" => "llama2", "created_at" => "2024-01-15T23:02:24.087444Z", @@ -221,10 +227,58 @@ defmodule ChatModels.ChatOllamaAITest do "done" => false } - assert %MessageDelta{} = struct = ChatOllamaAI.do_process_response(response) + assert %MessageDelta{} = struct = ChatOllamaAI.do_process_response(model, response) assert struct.role == :assistant assert struct.content == "Gre" assert struct.status == :incomplete end end + + describe "serialize_config/2" do + test "does not include the API key or callbacks" do + model = ChatOllamaAI.new!(%{model: "llama2"}) + result = ChatOllamaAI.serialize_config(model) + assert result["version"] == 1 + refute Map.has_key?(result, "callbacks") + end + + test "creates expected map" do + model = + ChatOllamaAI.new!(%{ + model: "llama2", + temperature: 0, + frequency_penalty: 0.5, + seed: 123, + num_gpu: 2, + stream: true + }) + + result = ChatOllamaAI.serialize_config(model) + + assert result == %{ + "endpoint" => "http://localhost:11434/api/chat", + "mirostat" => 0, + "mirostat_eta" => 0.1, + "mirostat_tau" => 5.0, + "model" => "llama2", + "module" => "Elixir.LangChain.ChatModels.ChatOllamaAI", + "num_ctx" => 2048, + "num_gpu" => 2, + "num_gqa" => nil, + "num_predict" => 128, + "num_thread" => nil, + "receive_timeout" => 300_000, + "repeat_last_n" => 64, + "repeat_penalty" => 1.1, + "seed" => 123, + "stop" => nil, + "stream" => true, + "temperature" => 0.0, + "tfs_z" => 1.0, + "top_k" => 40, + "top_p" => 0.9, + "version" => 1 + } + end + end end diff --git a/test/chat_models/chat_open_ai_test.exs b/test/chat_models/chat_open_ai_test.exs index a9b38f9c..d96a6014 100644 --- a/test/chat_models/chat_open_ai_test.exs +++ b/test/chat_models/chat_open_ai_test.exs @@ -1813,4 +1813,44 @@ defmodule LangChain.ChatModels.ChatOpenAITest do } ] end + + describe "serialize_config/2" do + test "does not include the API key or callbacks" do + model = ChatOpenAI.new!(%{model: "gpt-4o"}) + result = ChatOpenAI.serialize_config(model) + assert result["version"] == 1 + refute Map.has_key?(result, "api_key") + refute Map.has_key?(result, "callbacks") + end + + test "creates expected map" do + model = + ChatOpenAI.new!(%{ + model: "gpt-4o", + temperature: 0, + frequency_penalty: 0.5, + seed: 123, + max_tokens: 1234, + stream_options: %{include_usage: true} + }) + + result = ChatOpenAI.serialize_config(model) + + assert result == %{ + "endpoint" => "https://api.openai.com/v1/chat/completions", + "frequency_penalty" => 0.5, + "json_response" => false, + "max_tokens" => 1234, + "model" => "gpt-4o", + "n" => 1, + "receive_timeout" => 60000, + "seed" => 123, + "stream" => false, + "stream_options" => %{"include_usage" => true}, + "temperature" => 0.0, + "version" => 1, + "module" => "Elixir.LangChain.ChatModels.ChatOpenAI" + } + end + end end diff --git a/test/chat_models/chat_vertex_ai_test.exs b/test/chat_models/chat_vertex_ai_test.exs index e5a5acba..1157a1c5 100644 --- a/test/chat_models/chat_vertex_ai_test.exs +++ b/test/chat_models/chat_vertex_ai_test.exs @@ 
-321,4 +321,37 @@ defmodule ChatModels.ChatVertexAITest do ] end end + + describe "serialize_config/1" do + test "does not include the API key or callbacks" do + model = ChatVertexAI.new!(%{model: "gemini-pro", endpoint: "http://localhost:1234/"}) + result = ChatVertexAI.serialize_config(model) + assert result["version"] == 1 + refute Map.has_key?(result, "api_key") + refute Map.has_key?(result, "callbacks") + end + + test "creates expected map" do + model = + ChatVertexAI.new!(%{ + model: "gemini-pro", + endpoint: "http://localhost:1234/" + }) + + result = ChatVertexAI.serialize_config(model) + + assert result == %{ + "endpoint" => "http://localhost:1234/", + "model" => "gemini-pro", + "module" => "Elixir.LangChain.ChatModels.ChatVertexAI", + "receive_timeout" => 60000, + "stream" => false, + "temperature" => 0.9, + "top_k" => 1.0, + "top_p" => 1.0, + "version" => 1, + "json_response" => false + } + end + end end diff --git a/test/utils_test.exs b/test/utils_test.exs index de627475..fe8c40cc 100644 --- a/test/utils_test.exs +++ b/test/utils_test.exs @@ -2,6 +2,7 @@ defmodule LangChain.UtilsTest do use ExUnit.Case doctest LangChain.Utils + alias LangChain.ChatModels.ChatOpenAI alias LangChain.Utils defmodule FakeSchema do @@ -116,4 +117,55 @@ assert result == ["a", "b", "c"] end end + + describe "to_serializable_map/3" do + test "converts a chat model to a string keyed map with a version included" do + model = + ChatOpenAI.new!(%{ + model: "gpt-4o", + temperature: 0, + frequency_penalty: 0.5, + seed: 123, + max_tokens: 1234, + stream_options: %{include_usage: true} + }) + + result = + Utils.to_serializable_map(model, [ + :model, + :temperature, + :frequency_penalty, + :seed, + :max_tokens, + :stream_options + ]) + + assert result == %{ + "model" => "gpt-4o", + "temperature" => 0.0, + "frequency_penalty" => 0.5, + "seed" => 123, + "max_tokens" => 1234, + "stream_options" => %{"include_usage" => true}, + "version" => 1, + "module" => "Elixir.LangChain.ChatModels.ChatOpenAI" + } + end + end + + describe "module_from_name/1" do + test "returns :ok tuple with module when valid" do + assert {:ok, DateTime} = Utils.module_from_name("Elixir.DateTime") + end + + test "returns error when not a module" do + assert {:error, reason} = Utils.module_from_name("not a module") + assert reason == "Not an Elixir module: \"not a module\"" + end + + test "returns error when not an existing atom" do + assert {:error, reason} = Utils.module_from_name("Elixir.Missing.Module") + assert reason == "ChatModel module \"Elixir.Missing.Module\" not found" + end + end end
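
Reviewer note: a sketch of how the new pieces in this diff fit together end to end. This is illustrative only and not part of the patch: the `Jason` round trip stands in for whatever persistence layer an application uses, and the `handler` map is a hypothetical callback handler; only `serialize_config/1`, `restore_from_map/1`, `LLMChain.new!/1`, and `add_llm_callback/2` come from this change.

    alias LangChain.ChatModels.{ChatModel, ChatOpenAI}
    alias LangChain.Chains.LLMChain

    # Serialize a configured model to a plain string-keyed map. The map carries
    # "module" and "version" keys so it can be restored later; api_key and
    # callbacks are intentionally excluded from the serialized fields.
    model = ChatOpenAI.new!(%{model: "gpt-4o", temperature: 0})
    config = ChatModel.serialize_config(model)

    # Persist however the application prefers; JSON encoding is shown here as a
    # stand-in for a database column.
    json = Jason.encode!(config)

    # Later, e.g. when loading a past conversation, restore the model. The
    # "module" key dispatches to the matching chat model's restore_from_map/1.
    {:ok, restored} = json |> Jason.decode!() |> ChatModel.restore_from_map()

    # Callbacks are not serialized, so re-attach them after restoring.
    # add_llm_callback/2 reaches through the chain to the LLM's :callbacks list.
    handler = %{on_llm_new_message: fn _model, msg -> IO.inspect(msg, label: "LLM") end}

    chain =
      %{llm: restored}
      |> LLMChain.new!()
      |> LLMChain.add_llm_callback(handler)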