diff --git a/lib/chains/llm_chain.ex b/lib/chains/llm_chain.ex
index d57ed9b7..b0fd5a9d 100644
--- a/lib/chains/llm_chain.ex
+++ b/lib/chains/llm_chain.ex
@@ -41,6 +41,83 @@ defmodule LangChain.Chains.LLMChain do
         |> LLMChain.run()

   In the LiveView, a `handle_info` function executes with the received message.
+
+  ## Fallbacks
+
+  When running a chain, the `:with_fallbacks` option can be used to provide a
+  list of fallback chat models to try when a failure is encountered.
+
+  When working with language models, you may often encounter issues from the
+  underlying APIs, whether these be rate limiting, downtime, or something else.
+  Therefore, as you move your LLM applications into production, it becomes more
+  and more important to safeguard against these. That's what fallbacks are
+  designed to provide.
+
+  A **fallback** is an alternative plan that may be used in an emergency.
+
+  A `before_fallback` function can be provided to alter or return a different
+  chain to use with the fallback LLM model. This is important because the
+  prompts needed will often differ for a fallback LLM. This means if your
+  OpenAI completion fails, a different prompt may be needed when retrying it
+  with an Anthropic fallback.
+
+  ### Fallback for LLM API Errors
+
+  This is perhaps the most common use case for fallbacks. A request to an LLM
+  API can fail for a variety of reasons - the API could be down, you could have
+  hit rate limits, any number of things. Therefore, using fallbacks can help
+  protect against these types of failures.
+
+  ## Fallback Examples
+
+  A simple fallback that tries a different LLM chat model:
+
+      fallback_llm = ChatAnthropic.new!(%{stream: false})
+
+      {:ok, updated_chain} =
+        %{llm: ChatOpenAI.new!(%{stream: false})}
+        |> LLMChain.new!()
+        |> LLMChain.add_message(Message.new_system!("OpenAI system prompt"))
+        |> LLMChain.add_message(Message.new_user!("Why is the sky blue?"))
+        |> LLMChain.run(with_fallbacks: [fallback_llm])
+
+  Note the `with_fallbacks: [fallback_llm]` option when running the chain.
+
+  This example uses the `:before_fallback` option to provide a function that
+  can modify or return an alternate chain when used with a certain LLM. Also
+  note the utility function `LangChain.Utils.replace_system_message!/2` is used
+  for swapping out the system message when falling back to a different LLM.
+
+      fallback_llm = ChatAnthropic.new!(%{stream: false})
+
+      {:ok, updated_chain} =
+        %{llm: ChatOpenAI.new!(%{stream: false})}
+        |> LLMChain.new!()
+        |> LLMChain.add_message(Message.new_system!("OpenAI system prompt"))
+        |> LLMChain.add_message(Message.new_user!("Why is the sky blue?"))
+        |> LLMChain.run(
+          with_fallbacks: [fallback_llm],
+          before_fallback: fn chain ->
+            case chain.llm do
+              %ChatAnthropic{} ->
+                # replace the system message
+                %LLMChain{
+                  chain
+                  | messages:
+                      Utils.replace_system_message!(
+                        chain.messages,
+                        Message.new_system!("Anthropic system prompt")
+                      )
+                }
+
+              _open_ai ->
+                chain
+            end
+          end
+        )
+
+  See `LangChain.Chains.LLMChain.run/2` for more details.
+
   """

   use Ecto.Schema
   import Ecto.Changeset
@@ -93,7 +170,11 @@ defmodule LangChain.Chains.LLMChain do
     # Track the last `%Message{}` received in the chain.
     field :last_message, :any, virtual: true
     # Internally managed. The list of exchanged messages during a `run` function
-    # execution.
+    # execution. A single run can result in a number of newly created messages.
+    # It may generate an Assistant message with one or more ToolCalls, followed
+    # by a message with the tool results, some of which may have failed,
+    # requiring the LLM to try again. This list tracks the full set of messages
+    # exchanged during a single run.
     field :exchanged_messages, {:array, :any}, default: [], virtual: true
     # Track if the state of the chain expects a response from the LLM. This
     # happens after sending a user message, when a tool_call is received, or
@@ -235,6 +316,21 @@ defmodule LangChain.Chains.LLMChain do
    are evaluated, the `ToolResult` messages are returned to the LLM giving it
    an opportunity to use the `ToolResult` information in an assistant response
    message. In essence, this mode always gives the LLM the last word.
+
+  - `with_fallbacks: [...]` - Provide a list of chat models to use as a
+    fallback when one fails. This helps a production system remain operational
+    when an API limit is reached, an LLM service is overloaded or down, or
+    something else new and exciting goes wrong.
+
+    When all fallbacks fail, a `%LangChainError{type: "all_fallbacks_failed"}`
+    is returned in the error response.
+
+  - `before_fallback: fn chain -> modified_chain end` - A `before_fallback`
+    function is called before the LLM call is made. **NOTE: When provided, it
+    also fires for the first attempt.** This allows a chain to be modified or
+    replaced before running against the configured LLM. This is helpful, for
+    example, when a different system prompt is needed for Anthropic vs OpenAI.
+
   """
   @spec run(t(), Keyword.t()) :: {:ok, t()} | {:error, t(), LangChainError.t()}
   def run(chain, opts \\ [])
@@ -253,22 +349,85 @@ defmodule LangChain.Chains.LLMChain do
     # clear the set of exchanged messages.
     chain = clear_exchanged_messages(chain)

-    case Keyword.get(opts, :mode, nil) do
-      nil ->
-        # run the chain and format the return
-        case do_run(chain) do
-          {:ok, chain} ->
-            {:ok, chain}
+    # determine which function to run based on the mode.
+    function_to_run =
+      case Keyword.get(opts, :mode, nil) do
+        nil ->
+          &do_run/1

-          {:error, _chain, _reason} = error ->
-            error
-        end
+        :while_needs_response ->
+          &run_while_needs_response/1
+
+        :until_success ->
+          &run_until_success/1
+      end
+
+    # Run the chain and return the success or error results. NOTE: When no
+    # fallbacks are given, we do not add the current LLM to a fallback list and
+    # process everything through a single codepath, because failing after
+    # attempted fallbacks returns a different error.
+    if Keyword.has_key?(opts, :with_fallbacks) do
+      # run the function, using fallbacks as needed.
+      with_fallbacks(chain, opts, function_to_run)
+    else
+      # run it directly right now and return the success or error
+      function_to_run.(chain)
+    end
+  end
+
+  defp with_fallbacks(%LLMChain{} = chain, opts, run_fn) do
+    # Sources of inspiration:
+    # - https://python.langchain.com/v0.1/docs/guides/productionization/fallbacks/
+    # - https://python.langchain.com/docs/how_to/fallbacks/
+
+    llm_list = Keyword.fetch!(opts, :with_fallbacks)
+    before_fallback_fn = Keyword.get(opts, :before_fallback, nil)

-      :while_needs_response ->
-        run_while_needs_response(chain)
+    # Try the chain, going through the full list of LLMs. Add the current LLM
+    # as the first so all are processed the same way.
+    try_chain_with_llm(chain, [chain.llm | llm_list], before_fallback_fn, run_fn)
+  end
+
+  # nothing left to try
+  defp try_chain_with_llm(chain, [], _before_fallback_fn, _run_fn) do
+    {:error, chain,
+     LangChainError.exception(
+       type: "all_fallbacks_failed",
+       message: "Failed all attempts to generate response"
+     )}
+  end
+
+  defp try_chain_with_llm(chain, [llm | tail], before_fallback_fn, run_fn) do
+    use_chain = %LLMChain{chain | llm: llm}
+
+    use_chain =
+      if before_fallback_fn do
+        # use the returned chain from the before_fallback function.
+        before_fallback_fn.(use_chain)
+      else
+        use_chain
+      end
+
+    try do
+      case run_fn.(use_chain) do
+        {:ok, result} ->
+          {:ok, result}
+
+        {:error, _error_chain, reason} ->
+          # run attempt received an error. Try again with the next LLM.
+          Logger.warning("LLM call failed, using next fallback. Reason: #{inspect(reason)}")
+
+          try_chain_with_llm(use_chain, tail, before_fallback_fn, run_fn)
+      end
+    rescue
+      err ->
+        # Log the error and try again.
+        Logger.error(
+          "Rescued from exception during with_fallback processing. Error: #{inspect(err)}"
+        )

-      :until_success ->
-        run_until_success(chain)
+        try_chain_with_llm(use_chain, tail, before_fallback_fn, run_fn)
    end
  end
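Because `run/2` resolves the mode to a `function_to_run` before checking for fallbacks, each fallback attempt re-runs the chain in the selected mode. A minimal sketch combining the two options (illustrative usage only, mirroring the moduledoc examples above):

    fallback_llm = ChatAnthropic.new!(%{stream: false})

    {:ok, updated_chain} =
      %{llm: ChatOpenAI.new!(%{stream: false})}
      |> LLMChain.new!()
      |> LLMChain.add_message(Message.new_user!("Why is the sky blue?"))
      # each model attempt runs its own :until_success retry loop before
      # the chain falls back to the next model in the list
      |> LLMChain.run(mode: :until_success, with_fallbacks: [fallback_llm])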
diff --git a/lib/langchain_error.ex b/lib/langchain_error.ex
index c192313e..07d125f0 100644
--- a/lib/langchain_error.ex
+++ b/lib/langchain_error.ex
@@ -30,7 +30,7 @@ defmodule LangChain.LangChainError do
   Create the exception using either a message or a changeset whose errors
   are converted to a message.
   """
-  @spec exception(message :: String.t() | Ecto.Changeset.t()) :: t()
+  @spec exception(message :: String.t() | Ecto.Changeset.t()) :: t() | no_return()
   def exception(message) when is_binary(message) do
     %LangChainError{message: message}
   end
diff --git a/lib/utils.ex b/lib/utils.ex
index 11f3d3b3..98241b5e 100644
--- a/lib/utils.ex
+++ b/lib/utils.ex
@@ -305,4 +305,15 @@ defmodule LangChain.Utils do
     {List.first(system), other}
   end
+
+  @doc """
+  Replace the system message with a new system message. This retains all other
+  messages as-is. An error is raised if there is more than one system message.
+  """
+  @spec replace_system_message!([Message.t()], Message.t()) :: [Message.t()] | no_return()
+  def replace_system_message!(messages, new_system_message) do
+    {_old_system, rest} = split_system_message(messages)
+    # return the new system message along with the rest
+    [new_system_message | rest]
+  end
 end
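A quick sketch of the new helper's behavior (the message contents are invented for illustration; the return shape matches the tests below):

    alias LangChain.Message
    alias LangChain.Utils

    messages = [
      Message.new_system!("System A"),
      Message.new_user!("User 1")
    ]

    # returns the replacement system message followed by the other
    # messages, unchanged
    [new_system | rest] =
      Utils.replace_system_message!(messages, Message.new_system!("System B"))

    new_system.content
    #=> "System B"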
+ """ + @spec replace_system_message!([Message.t()], Message.t()) :: [Message.t()] | no_return() + def replace_system_message!(messages, new_system_message) do + {_old_system, rest} = split_system_message(messages) + # return the new system message along with the rest + [new_system_message | rest] + end end diff --git a/test/chains/llm_chain_test.exs b/test/chains/llm_chain_test.exs index 53b0d6c1..3da32f8d 100644 --- a/test/chains/llm_chain_test.exs +++ b/test/chains/llm_chain_test.exs @@ -17,6 +17,7 @@ defmodule LangChain.Chains.LLMChainTest do alias LangChain.MessageDelta alias LangChain.LangChainError alias LangChain.MessageProcessors.JsonProcessor + alias LangChain.Utils @anthropic_test_model "claude-3-opus-20240229" @@ -483,7 +484,9 @@ defmodule LangChain.Chains.LLMChainTest do ) |> LLMChain.apply_delta(MessageDelta.new!(%{content: "your "})) |> LLMChain.apply_delta(MessageDelta.new!(%{content: "favorite "})) - |> LLMChain.apply_delta({:error, LangChainError.exception(type: "overloaded", message: "Overloaded")}) + |> LLMChain.apply_delta( + {:error, LangChainError.exception(type: "overloaded", message: "Overloaded")} + ) # the delta is complete and removed from the chain assert updated_chain.delta == nil @@ -1429,6 +1432,108 @@ defmodule LangChain.Chains.LLMChainTest do assert reason.message == "Exceeded max failure count" assert updated_chain.current_failure_count == 3 end + + test "with_fallbacks: re-runs with next LLM after first fails" do + # Made NOT LIVE here - handles two calls + expect(ChatOpenAI, :call, fn _model, _messages, _tools -> + # IO.puts "FAKE OpenAI ERROR RESULT RETURNED" + {:error, + LangChainError.exception(type: "too_many_requests", message: "Too many requests!")} + end) + + expect(ChatAnthropic, :call, fn _model, _messages, _tools -> + {:ok, + [ + Message.new_assistant!(%{content: "fallback worked!"}) + ]} + end) + + {:ok, updated_chain} = + %{llm: ChatOpenAI.new!(%{stream: false})} + |> LLMChain.new!() + |> LLMChain.add_message(Message.new_system!()) + |> LLMChain.add_message(Message.new_user!("Why is the sky blue?")) + |> LLMChain.run(with_fallbacks: [ChatAnthropic.new!(%{stream: false})]) + + # stopped after processing a successful assistant response + assert updated_chain.last_message.role == :assistant + assert updated_chain.last_message.content == "fallback worked!" 
+    end
+
+    test "with_fallbacks: runs each LLM option and returns when all failed" do
+      # Made NOT LIVE here - handles two calls
+      expect(ChatOpenAI, :call, fn _model, _messages, _tools ->
+        # IO.puts "FAKE OpenAI ERROR RESULT RETURNED"
+        {:error,
+         LangChainError.exception(type: "too_many_requests", message: "Too many requests!")}
+      end)
+
+      expect(ChatAnthropic, :call, fn _model, _messages, _tools ->
+        {:error, LangChainError.exception(type: "overloaded", message: "Overloaded")}
+      end)
+
+      {:error, _updated_chain, reason} =
+        %{llm: ChatOpenAI.new!(%{stream: false})}
+        |> LLMChain.new!()
+        |> LLMChain.add_message(Message.new_system!())
+        |> LLMChain.add_message(Message.new_user!("Why is the sky blue?"))
+        |> LLMChain.run(with_fallbacks: [ChatAnthropic.new!(%{stream: false})])
+
+      assert %LangChainError{
+               type: "all_fallbacks_failed",
+               message: "Failed all attempts to generate response"
+             } == reason
+    end
+
+    test "with_fallbacks: runs before_fallback function and uses the resulting chain" do
+      # Made NOT LIVE here - handles two calls
+      expect(ChatOpenAI, :call, fn _model, _messages, _tools ->
+        # IO.puts "FAKE OpenAI ERROR RESULT RETURNED"
+        {:error,
+         LangChainError.exception(type: "too_many_requests", message: "Too many requests!")}
+      end)
+
+      expect(ChatAnthropic, :call, fn _model, _messages, _tools ->
+        {:ok, Message.new_assistant!(%{content: "Claude says it's because it's not red."})}
+      end)
+
+      {:ok, updated_chain} =
+        %{llm: ChatOpenAI.new!(%{stream: false})}
+        |> LLMChain.new!()
+        |> LLMChain.add_message(Message.new_system!("OpenAI system prompt"))
+        |> LLMChain.add_message(Message.new_user!("Why is the sky blue?"))
+        |> LLMChain.run(
+          with_fallbacks: [
+            ChatAnthropic.new!(%{stream: false})
+          ],
+          before_fallback: fn chain ->
+            send(self(), :before_fallback_fired)
+
+            case chain.llm do
+              %ChatAnthropic{} ->
+                # replace the system message
+                %LLMChain{
+                  chain
+                  | messages:
+                      Utils.replace_system_message!(
+                        chain.messages,
+                        Message.new_system!("Anthropic system prompt")
+                      )
+                }
+
+              _open_ai ->
+                chain
+            end
+          end
+        )
+
+      assert [system_msg | _rest] = updated_chain.messages
+      assert system_msg.role == :system
+      assert system_msg.content == "Anthropic system prompt"
+      assert updated_chain.last_message.role == :assistant
+      assert updated_chain.last_message.content == "Claude says it's because it's not red."
+      assert_received :before_fallback_fired
+    end
   end

   describe "increment_current_failure_count/1" do
diff --git a/test/utils_test.exs b/test/utils_test.exs
index d32defe4..e7b6a58d 100644
--- a/test/utils_test.exs
+++ b/test/utils_test.exs
@@ -204,4 +204,37 @@ defmodule LangChain.UtilsTest do
       end
     end
   end
+
+  describe "replace_system_message!/2" do
+    test "returns list with new system message" do
+      non_system = [
+        Message.new_user!("User 1"),
+        Message.new_assistant!("Assistant 1")
+      ]
+
+      [new_system | rest] =
+        Utils.replace_system_message!(
+          [Message.new_system!("System A") | non_system],
+          Message.new_system!("System B")
+        )
+
+      assert rest == non_system
+      assert new_system.role == :system
+      assert new_system.content == "System B"
+    end
+
+    test "handles when no existing system message" do
+      non_system = [
+        Message.new_user!("User 1"),
+        Message.new_assistant!("Assistant 1")
+      ]
+
+      [new_system | rest] =
+        Utils.replace_system_message!(non_system, Message.new_system!("System B"))
+
+      assert rest == non_system
+      assert new_system.role == :system
+      assert new_system.content == "System B"
+    end
+  end
 end
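When every model in the fallback list fails, the caller receives the `"all_fallbacks_failed"` error type introduced above. A sketch of how calling code might branch on it (chain and model setup elided; names are illustrative):

    case LLMChain.run(chain, with_fallbacks: [fallback_llm]) do
      {:ok, updated_chain} ->
        {:ok, updated_chain.last_message}

      {:error, _error_chain, %LangChainError{type: "all_fallbacks_failed"} = error} ->
        # every configured model failed to produce a response
        {:error, error.message}

      {:error, _error_chain, %LangChainError{} = error} ->
        # a run without fallbacks surfaces its original error
        {:error, error.message}
    end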