track models that don't support functions
- raise an error when adding functions to a model that does not support them
brainlid committed Dec 8, 2023
1 parent 1a03db8 commit 704b6fd
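
For reference, a minimal usage sketch of the behavior this commit introduces (assumed setup; the function name "get_weather" is illustrative, and supports_functions: false is set explicitly here, mirroring the new test below):

alias LangChain.Chains.LLMChain
alias LangChain.ChatModels.ChatBumblebee
alias LangChain.Function

chain = LLMChain.new!(%{llm: %ChatBumblebee{supports_functions: false}})

# Raises LangChain.LangChainError with the message "The LLM does not support functions."
LLMChain.add_functions(chain, Function.new!(%{name: "get_weather"}))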
Showing 4 changed files with 30 additions and 12 deletions.
7 changes: 7 additions & 0 deletions lib/chains/llm_chain.ex
@@ -117,6 +117,13 @@ defmodule LangChain.Chains.LLMChain do
   Add more functions to an LLMChain.
   """
   @spec add_functions(t(), Function.t() | [Function.t()]) :: t() | no_return()
+  def add_functions(
+        %LLMChain{llm: %{supports_functions: false}} = _chain,
+        _functions
+      ) do
+    raise LangChainError, "The LLM does not support functions."
+  end
+
   def add_functions(%LLMChain{} = chain, %Function{} = function) do
     add_functions(chain, [function])
   end
24 changes: 12 additions & 12 deletions lib/chat_models/chat_bumblebee.ex
@@ -1,4 +1,4 @@
-defmodule LangChain.ChatModels.ChatBumbleModel do
+defmodule LangChain.ChatModels.ChatBumblebee do
   @moduledoc """
   Represents a chat model hosted and accessed through Bumblebee.
@@ -46,7 +46,7 @@ defmodule LangChain.ChatModels.ChatBumbleModel do
     field :stream, :boolean, default: true
   end

-  @type t :: %ChatBumbleModel{}
+  @type t :: %ChatBumblebee{}

   @type call_response :: {:ok, Message.t() | [Message.t()]} | {:error, String.t()}
   @type callback_data ::
@@ -72,18 +72,18 @@ defmodule LangChain.ChatModels.ChatBumbleModel do
   @text_end_tag "</s>"

   @doc """
-  Setup a ChatBumbleModel client configuration.
+  Setup a ChatBumblebee client configuration.
   """
   @spec new(attrs :: map()) :: {:ok, t} | {:error, Ecto.Changeset.t()}
   def new(%{} = attrs \\ %{}) do
-    %ChatBumbleModel{}
+    %ChatBumblebee{}
     |> cast(attrs, @create_fields)
     |> common_validation()
     |> apply_action(:insert)
   end

   @doc """
-  Setup a ChatBumbleModel client configuration and return it or raise an error if invalid.
+  Setup a ChatBumblebee client configuration and return it or raise an error if invalid.
   """
   @spec new!(attrs :: map()) :: t() | no_return()
   def new!(attrs \\ %{}) do
@@ -113,7 +113,7 @@ defmodule LangChain.ChatModels.ChatBumbleModel do
         ) :: call_response()
   def call(model, prompt, functions \\ [], callback_fn \\ nil)

-  def call(%ChatBumbleModel{} = model, prompt, functions, callback_fn) when is_binary(prompt) do
+  def call(%ChatBumblebee{} = model, prompt, functions, callback_fn) when is_binary(prompt) do
     messages = [
       Message.new_system!(),
       Message.new_user!(prompt)
@@ -122,7 +122,7 @@ defmodule LangChain.ChatModels.ChatBumbleModel do
     call(model, messages, functions, callback_fn)
   end

-  def call(%ChatBumbleModel{} = model, messages, functions, callback_fn)
+  def call(%ChatBumblebee{} = model, messages, functions, callback_fn)
       when is_list(messages) do
     if override_api_return?() do
       Logger.warning("Found override API response. Will not make live API call.")
@@ -176,7 +176,7 @@ defmodule LangChain.ChatModels.ChatBumbleModel do
   @spec do_serving_request(t(), [Message.t()], [Function.t()], callback_fn()) ::
           list() | struct() | {:error, String.t()}
   def do_serving_request(
-        %ChatBumbleModel{stream: false} = model,
+        %ChatBumblebee{stream: false} = model,
         messages,
         _functions,
         callback_fn
@@ -210,7 +210,7 @@ defmodule LangChain.ChatModels.ChatBumbleModel do
   end

   def do_serving_request(
-        %ChatBumbleModel{stream: true} = model,
+        %ChatBumblebee{stream: true} = model,
         messages,
         _functions,
         callback_fn
@@ -289,14 +289,14 @@ defmodule LangChain.ChatModels.ChatBumbleModel do
           data :: callback_data(),
           nil | callback_fn()
         ) :: :ok
-  defp fire_callback(%ChatBumbleModel{stream: true}, _data, nil) do
+  defp fire_callback(%ChatBumblebee{stream: true}, _data, nil) do
     Logger.warning("Streaming call requested but no callback function was given.")
     :ok
   end

-  defp fire_callback(%ChatBumbleModel{stream: false}, _data, nil), do: :ok
+  defp fire_callback(%ChatBumblebee{stream: false}, _data, nil), do: :ok

-  defp fire_callback(%ChatBumbleModel{}, data, callback_fn) when is_function(callback_fn) do
+  defp fire_callback(%ChatBumblebee{}, data, callback_fn) when is_function(callback_fn) do
     # OPTIONAL: Execute callback function
     callback_fn.(data)
     :ok
3 changes: 3 additions & 0 deletions lib/chat_models/chat_open_ai.ex
@@ -52,6 +52,9 @@ defmodule LangChain.ChatModels.ChatOpenAI do
     field :n, :integer, default: 1
     field :json_response, :boolean, default: false
     field :stream, :boolean, default: false
+
+    # For compatibility between models, reflect that functions are supported
+    field :supports_functions, :boolean, default: true
   end

   @type t :: %ChatOpenAI{}
8 changes: 8 additions & 0 deletions test/chains/llm_chain_test.exs
@@ -8,6 +8,7 @@ defmodule LangChain.Chains.LLMChainTest do
   alias LangChain.Function
   alias LangChain.Message
   alias LangChain.MessageDelta
+  alias LangChain.ChatModels.ChatBumblebee

   setup do
     {:ok, chat} = ChatOpenAI.new(%{temperature: 0})
@@ -71,6 +72,13 @@ defmodule LangChain.Chains.LLMChainTest do
       assert updated_chain2.functions == [function, howdy_fn]
       assert updated_chain2.function_map == %{"hello_world" => function, "howdy" => howdy_fn}
     end
+
+    test "raises an exception when the model does not support functions" do
+      chain = LLMChain.new!(%{llm: %ChatBumblebee{supports_functions: false}})
+      assert_raise LangChain.LangChainError, "The LLM does not support functions.", fn ->
+        LLMChain.add_functions(chain, Function.new!(%{name: "test"}))
+      end
+    end
   end

   describe "cancelled_delta/1" do
