From 704b6fd7b4008f0b9aae6dfa5b089b665ab33925 Mon Sep 17 00:00:00 2001 From: Mark Ericksen Date: Fri, 8 Dec 2023 16:27:50 -0700 Subject: [PATCH] track models that don't support functions - raise error when adding functions and not supported --- lib/chains/llm_chain.ex | 7 ++++++ ...chat_bumble_model.ex => chat_bumblebee.ex} | 24 +++++++++---------- lib/chat_models/chat_open_ai.ex | 3 +++ test/chains/llm_chain_test.exs | 8 +++++++ 4 files changed, 30 insertions(+), 12 deletions(-) rename lib/chat_models/{chat_bumble_model.ex => chat_bumblebee.ex} (92%) diff --git a/lib/chains/llm_chain.ex b/lib/chains/llm_chain.ex index 5b3c660a..8818afa6 100644 --- a/lib/chains/llm_chain.ex +++ b/lib/chains/llm_chain.ex @@ -117,6 +117,13 @@ defmodule LangChain.Chains.LLMChain do Add more functions to an LLMChain. """ @spec add_functions(t(), Function.t() | [Function.t()]) :: t() | no_return() + def add_functions( + %LLMChain{llm: %{supports_functions: false}} = _chain, + _functions + ) do + raise LangChainError, "The LLM does not support functions." + end + def add_functions(%LLMChain{} = chain, %Function{} = function) do add_functions(chain, [function]) end diff --git a/lib/chat_models/chat_bumble_model.ex b/lib/chat_models/chat_bumblebee.ex similarity index 92% rename from lib/chat_models/chat_bumble_model.ex rename to lib/chat_models/chat_bumblebee.ex index e2c85415..65940774 100644 --- a/lib/chat_models/chat_bumble_model.ex +++ b/lib/chat_models/chat_bumblebee.ex @@ -1,4 +1,4 @@ -defmodule LangChain.ChatModels.ChatBumbleModel do +defmodule LangChain.ChatModels.ChatBumblebee do @moduledoc """ Represents a chat model hosted and accessed through Bumblebee. @@ -46,7 +46,7 @@ defmodule LangChain.ChatModels.ChatBumbleModel do field :stream, :boolean, default: true end - @type t :: %ChatBumbleModel{} + @type t :: %ChatBumblebee{} @type call_response :: {:ok, Message.t() | [Message.t()]} | {:error, String.t()} @type callback_data :: @@ -72,18 +72,18 @@ defmodule LangChain.ChatModels.ChatBumbleModel do @text_end_tag "" @doc """ - Setup a ChatBumbleModel client configuration. + Setup a ChatBumblebee client configuration. """ @spec new(attrs :: map()) :: {:ok, t} | {:error, Ecto.Changeset.t()} def new(%{} = attrs \\ %{}) do - %ChatBumbleModel{} + %ChatBumblebee{} |> cast(attrs, @create_fields) |> common_validation() |> apply_action(:insert) end @doc """ - Setup a ChatBumbleModel client configuration and return it or raise an error if invalid. + Setup a ChatBumblebee client configuration and return it or raise an error if invalid. """ @spec new!(attrs :: map()) :: t() | no_return() def new!(attrs \\ %{}) do @@ -113,7 +113,7 @@ defmodule LangChain.ChatModels.ChatBumbleModel do ) :: call_response() def call(model, prompt, functions \\ [], callback_fn \\ nil) - def call(%ChatBumbleModel{} = model, prompt, functions, callback_fn) when is_binary(prompt) do + def call(%ChatBumblebee{} = model, prompt, functions, callback_fn) when is_binary(prompt) do messages = [ Message.new_system!(), Message.new_user!(prompt) @@ -122,7 +122,7 @@ defmodule LangChain.ChatModels.ChatBumbleModel do call(model, messages, functions, callback_fn) end - def call(%ChatBumbleModel{} = model, messages, functions, callback_fn) + def call(%ChatBumblebee{} = model, messages, functions, callback_fn) when is_list(messages) do if override_api_return?() do Logger.warning("Found override API response. Will not make live API call.") @@ -176,7 +176,7 @@ defmodule LangChain.ChatModels.ChatBumbleModel do @spec do_serving_request(t(), [Message.t()], [Function.t()], callback_fn()) :: list() | struct() | {:error, String.t()} def do_serving_request( - %ChatBumbleModel{stream: false} = model, + %ChatBumblebee{stream: false} = model, messages, _functions, callback_fn @@ -210,7 +210,7 @@ defmodule LangChain.ChatModels.ChatBumbleModel do end def do_serving_request( - %ChatBumbleModel{stream: true} = model, + %ChatBumblebee{stream: true} = model, messages, _functions, callback_fn @@ -289,14 +289,14 @@ defmodule LangChain.ChatModels.ChatBumbleModel do data :: callback_data(), nil | callback_fn() ) :: :ok - defp fire_callback(%ChatBumbleModel{stream: true}, _data, nil) do + defp fire_callback(%ChatBumblebee{stream: true}, _data, nil) do Logger.warning("Streaming call requested but no callback function was given.") :ok end - defp fire_callback(%ChatBumbleModel{stream: false}, _data, nil), do: :ok + defp fire_callback(%ChatBumblebee{stream: false}, _data, nil), do: :ok - defp fire_callback(%ChatBumbleModel{}, data, callback_fn) when is_function(callback_fn) do + defp fire_callback(%ChatBumblebee{}, data, callback_fn) when is_function(callback_fn) do # OPTIONAL: Execute callback function callback_fn.(data) :ok diff --git a/lib/chat_models/chat_open_ai.ex b/lib/chat_models/chat_open_ai.ex index 09e41cfb..e988fefa 100644 --- a/lib/chat_models/chat_open_ai.ex +++ b/lib/chat_models/chat_open_ai.ex @@ -52,6 +52,9 @@ defmodule LangChain.ChatModels.ChatOpenAI do field :n, :integer, default: 1 field :json_response, :boolean, default: false field :stream, :boolean, default: false + + # For compatibility between models, reflect that functions are supported + field :supports_functions, :boolean, default: true end @type t :: %ChatOpenAI{} diff --git a/test/chains/llm_chain_test.exs b/test/chains/llm_chain_test.exs index 716e6881..68e2b23a 100644 --- a/test/chains/llm_chain_test.exs +++ b/test/chains/llm_chain_test.exs @@ -8,6 +8,7 @@ defmodule LangChain.Chains.LLMChainTest do alias LangChain.Function alias LangChain.Message alias LangChain.MessageDelta + alias LangChain.ChatModels.ChatBumblebee setup do {:ok, chat} = ChatOpenAI.new(%{temperature: 0}) @@ -71,6 +72,13 @@ defmodule LangChain.Chains.LLMChainTest do assert updated_chain2.functions == [function, howdy_fn] assert updated_chain2.function_map == %{"hello_world" => function, "howdy" => howdy_fn} end + + test "raises an exception when the model does not support functions" do + chain = LLMChain.new!(%{llm: %ChatBumblebee{supports_functions: false}}) + assert_raise LangChain.LangChainError, "The LLM does not support functions.", fn -> + LLMChain.add_functions(chain, Function.new!(%{name: "test"})) + end + end end describe "cancelled_delta/1" do