diff --git a/lib/chat_models/chat_anthropic.ex b/lib/chat_models/chat_anthropic.ex index 280944d1..490b889b 100644 --- a/lib/chat_models/chat_anthropic.ex +++ b/lib/chat_models/chat_anthropic.ex @@ -40,7 +40,6 @@ defmodule LangChain.ChatModels.ChatAnthropic do use Ecto.Schema require Logger import Ecto.Changeset - import LangChain.Utils.ApiOverride alias __MODULE__ alias LangChain.Config alias LangChain.ChatModels.ChatModel @@ -259,40 +258,19 @@ defmodule LangChain.ChatModels.ChatAnthropic do call(anthropic, messages, functions) end - def call(%ChatAnthropic{} = anthropic, messages, functions) - when is_list(messages) do - if override_api_return?() do - Logger.warning("Found override API response. Will not make live API call.") + def call(%ChatAnthropic{} = anthropic, messages, functions) when is_list(messages) do + try do + # make base api request and perform high-level success/failure checks + case do_api_request(anthropic, messages, functions) do + {:error, reason} -> + {:error, reason} - case get_api_override() do - {:ok, {:ok, data, callback_name}} -> - # fire callback for fake responses too - Callbacks.fire(anthropic.callbacks, callback_name, [anthropic, data]) - # return the data portion - {:ok, data} - - # fake error response - {:ok, {:error, _reason} = response} -> - response - - _other -> - raise LangChainError, - "An unexpected fake API response was set. Should be an `{:ok, value, nil_or_callback_name}`" - end - else - try do - # make base api request and perform high-level success/failure checks - case do_api_request(anthropic, messages, functions) do - {:error, reason} -> - {:error, reason} - - parsed_data -> - {:ok, parsed_data} - end - rescue - err in LangChainError -> - {:error, err.message} + parsed_data -> + {:ok, parsed_data} end + rescue + err in LangChainError -> + {:error, err.message} end end diff --git a/lib/chat_models/chat_bumblebee.ex b/lib/chat_models/chat_bumblebee.ex index 89d73849..569bbb53 100644 --- a/lib/chat_models/chat_bumblebee.ex +++ b/lib/chat_models/chat_bumblebee.ex @@ -98,7 +98,6 @@ defmodule LangChain.ChatModels.ChatBumblebee do use Ecto.Schema require Logger import Ecto.Changeset - import LangChain.Utils.ApiOverride alias __MODULE__ alias LangChain.ChatModels.ChatModel alias LangChain.Message @@ -218,35 +217,19 @@ defmodule LangChain.ChatModels.ChatBumblebee do call(model, messages, functions) end - def call(%ChatBumblebee{} = model, messages, functions) - when is_list(messages) do - if override_api_return?() do - Logger.warning("Found override API response. Will not make live API call.") + def call(%ChatBumblebee{} = model, messages, functions) when is_list(messages) do + try do + # make base api request and perform high-level success/failure checks + case do_serving_request(model, messages, functions) do + {:error, reason} -> + {:error, reason} - # fire callback for fake responses too - case get_api_override() do - {:ok, {:ok, data, callback_name}} -> - Callbacks.fire(model.callbacks, callback_name, [model, data]) - {:ok, data} - - _other -> - raise LangChainError, - "An unexpected fake API response was set. Should be an `{:ok, value, nil_or_callback_name}`" - end - else - try do - # make base api request and perform high-level success/failure checks - case do_serving_request(model, messages, functions) do - {:error, reason} -> - {:error, reason} - - parsed_data -> - {:ok, parsed_data} - end - rescue - err in LangChainError -> - {:error, err.message} + parsed_data -> + {:ok, parsed_data} end + rescue + err in LangChainError -> + {:error, err.message} end end diff --git a/lib/chat_models/chat_mistral_ai.ex b/lib/chat_models/chat_mistral_ai.ex index 1325dd88..802f6b05 100644 --- a/lib/chat_models/chat_mistral_ai.ex +++ b/lib/chat_models/chat_mistral_ai.ex @@ -2,7 +2,6 @@ defmodule LangChain.ChatModels.ChatMistralAI do use Ecto.Schema require Logger import Ecto.Changeset - import LangChain.Utils.ApiOverride alias __MODULE__ alias LangChain.Config alias LangChain.ChatModels.ChatOpenAI @@ -146,37 +145,18 @@ defmodule LangChain.ChatModels.ChatMistralAI do end def call(%ChatMistralAI{} = mistral, messages, functions) when is_list(messages) do - if override_api_return?() do - Logger.warning("Found override API response. Will not make live API call.") - - case get_api_override() do - {:ok, {:ok, data, callback_name} = response} -> - # fire callback for fake responses too - Callbacks.fire(mistral.callbacks, callback_name, [mistral, data]) - response - - # fake error response - {:ok, {:error, _reason} = response} -> - response - - _other -> - raise LangChainError, - "An unexpected fake API response was set. Should be an `{:ok, value, nil_or_callback_name}`" - end - else - try do - # make base api request and perform high-level success/failure checks - case do_api_request(mistral, messages, functions) do - {:error, reason} -> - {:error, reason} - - parsed_data -> - {:ok, parsed_data} - end - rescue - err in LangChainError -> - {:error, err.message} + try do + # make base api request and perform high-level success/failure checks + case do_api_request(mistral, messages, functions) do + {:error, reason} -> + {:error, reason} + + parsed_data -> + {:ok, parsed_data} end + rescue + err in LangChainError -> + {:error, err.message} end end diff --git a/lib/chat_models/chat_open_ai.ex b/lib/chat_models/chat_open_ai.ex index 2097499a..5e97e2f8 100644 --- a/lib/chat_models/chat_open_ai.ex +++ b/lib/chat_models/chat_open_ai.ex @@ -69,7 +69,6 @@ defmodule LangChain.ChatModels.ChatOpenAI do use Ecto.Schema require Logger import Ecto.Changeset - import LangChain.Utils.ApiOverride alias __MODULE__ alias LangChain.Config alias LangChain.ChatModels.ChatModel @@ -432,38 +431,18 @@ defmodule LangChain.ChatModels.ChatOpenAI do end def call(%ChatOpenAI{} = openai, messages, tools) when is_list(messages) do - if override_api_return?() do - Logger.warning("Found override API response. Will not make live API call.") - - case get_api_override() do - {:ok, {:ok, data, callback_name}} -> - # fire callback for fake responses too - Callbacks.fire(openai.callbacks, callback_name, [openai, data]) - # return the data portion - {:ok, data} - - # fake error response - {:ok, {:error, _reason} = response} -> - response - - _other -> - raise LangChainError, - "An unexpected fake API response was set. Should be an `{:ok, value, nil_or_callback_name}`" - end - else - try do - # make base api request and perform high-level success/failure checks - case do_api_request(openai, messages, tools) do - {:error, reason} -> - {:error, reason} - - parsed_data -> - {:ok, parsed_data} - end - rescue - err in LangChainError -> - {:error, err.message} + try do + # make base api request and perform high-level success/failure checks + case do_api_request(openai, messages, tools) do + {:error, reason} -> + {:error, reason} + + parsed_data -> + {:ok, parsed_data} end + rescue + err in LangChainError -> + {:error, err.message} end end diff --git a/lib/images/open_ai_image.ex b/lib/images/open_ai_image.ex index 66934a16..44da97d5 100644 --- a/lib/images/open_ai_image.ex +++ b/lib/images/open_ai_image.ex @@ -13,7 +13,6 @@ defmodule LangChain.Images.OpenAIImage do use Ecto.Schema require Logger import Ecto.Changeset - import LangChain.Utils.ApiOverride alias __MODULE__ alias LangChain.Images.GeneratedImage alias LangChain.Config @@ -184,35 +183,18 @@ defmodule LangChain.Images.OpenAIImage do def call(openai) def call(%OpenAIImage{} = openai) do - if override_api_return?() do - Logger.warning("Found override API response. Will not make live API call.") - - case get_api_override() do - {:ok, {:ok, data, _callback_name} = _response} -> - {:ok, data} - - # fake error response - {:ok, {:error, _reason} = response} -> - response - - _other -> - raise LangChainError, - "An unexpected fake API response was set. Should be an `{:ok, value, nil_or_callback_name}`" - end - else - try do - # make base api request and perform high-level success/failure checks - case do_api_request(openai) do - {:error, reason} -> - {:error, reason} - - {:ok, parsed_data} -> - {:ok, parsed_data} - end - rescue - err in LangChainError -> - {:error, err.message} + try do + # make base api request and perform high-level success/failure checks + case do_api_request(openai) do + {:error, reason} -> + {:error, reason} + + {:ok, parsed_data} -> + {:ok, parsed_data} end + rescue + err in LangChainError -> + {:error, err.message} end end diff --git a/lib/utils/api_override.ex b/lib/utils/api_override.ex deleted file mode 100644 index c1353ca2..00000000 --- a/lib/utils/api_override.ex +++ /dev/null @@ -1,77 +0,0 @@ -defmodule LangChain.Utils.ApiOverride do - @moduledoc """ - Tools for overriding API results. Used for testing. - - Works by setting and checking for special use of the Process dictionary. - - ## Test Example - - import LangChain.Utils.ApiOverride - - model = ChatOpenAI.new!(%{temperature: 1, stream: true}) - - # Define the fake response to return - fake_messages = [ - [MessageDelta.new!(%{role: :assistant, content: nil, status: :incomplete})], - [MessageDelta.new!(%{content: "Sock", status: :incomplete})] - ] - - # Made NOT LIVE here. Will not make the external call to the LLM - set_api_override({:ok, fake_messages}) - - # We can construct an LLMChain from a PromptTemplate and an LLM. - {:ok, updated_chain, _response} = - %{llm: model, verbose: false} - |> LLMChain.new!() - |> LLMChain.add_message( - Message.new_user!("What is a good name for a company that makes colorful socks?") - ) - |> LLMChain.run() - - assert %Message{role: :assistant, content: "Sock"} = updated_chain.last_message - """ - - @key :fake_api_response - - @doc """ - Return if an override for the API response is set. Used for testing. - """ - @spec override_api_return? :: boolean() - def override_api_return?() do - @key in Process.get_keys() - end - - @doc """ - Set the data and callback to use as a fake API response. An `:ok` tuple - indicates a successful API call. The `fake_response_data` is the data to treat - as returned. The `callback_name` is the callback handler name to execute. - - set_api_override({:ok, fake_response_data, callback_name_to_fire}) - - ## Examples - - set_api_override({:ok, Message.new_assistant!(%{content: "154 bottles"}, :on_llm_new_message}) - - set_api_override({:ok, MessageDelta.new!(%{content: "Hi"}), :on_llm_new_delta}) - """ - @spec set_api_override(term()) :: :ok - def set_api_override(config_tuple) do - Process.put(@key, config_tuple) - :ok - end - - @doc """ - Get the API override to return. Returned as `{:ok, config_tuple}`. If not set, - it returns `:not_set`. - """ - @spec get_api_override() :: {:ok, term()} | :not_set - def get_api_override() do - case Process.get(@key, :not_set) do - :not_set -> - :not_set - - value -> - {:ok, value} - end - end -end diff --git a/mix.exs b/mix.exs index e6d69033..d46a8738 100644 --- a/mix.exs +++ b/mix.exs @@ -41,7 +41,8 @@ defmodule LangChain.MixProject do {:req, ">= 0.5.0"}, {:abacus, "~> 2.1.0"}, {:nx, ">= 0.7.0", optional: true}, - {:ex_doc, "~> 0.34", only: :dev, runtime: false} + {:ex_doc, "~> 0.34", only: :dev, runtime: false}, + {:mimic, "~> 1.8", only: :test} ] end diff --git a/mix.lock b/mix.lock index 5a0f3472..eff07dd7 100644 --- a/mix.lock +++ b/mix.lock @@ -16,6 +16,7 @@ "makeup_elixir": {:hex, :makeup_elixir, "0.16.2", "627e84b8e8bf22e60a2579dad15067c755531fea049ae26ef1020cad58fe9578", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "41193978704763f6bbe6cc2758b84909e62984c7752b3784bd3c218bb341706b"}, "makeup_erlang": {:hex, :makeup_erlang, "1.0.0", "6f0eff9c9c489f26b69b61440bf1b238d95badae49adac77973cbacae87e3c2e", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "ea7a9307de9d1548d2a72d299058d1fd2339e3d398560a0e46c27dab4891e4d2"}, "mime": {:hex, :mime, "2.0.5", "dc34c8efd439abe6ae0343edbb8556f4d63f178594894720607772a041b04b02", [:mix], [], "hexpm", "da0d64a365c45bc9935cc5c8a7fc5e49a0e0f9932a761c55d6c52b142780a05c"}, + "mimic": {:hex, :mimic, "1.8.2", "f4cf6ad13a305c5ee1a6c304ee36fc7afb3ad748e2af8cd5fbf122f44a283534", [:mix], [], "hexpm", "abc982d5fdcc4cb5292980cb698cd47c0c5d9541b401e540fb695b69f1d46485"}, "mint": {:hex, :mint, "1.6.1", "065e8a5bc9bbd46a41099dfea3e0656436c5cbcb6e741c80bd2bad5cd872446f", [:mix], [{:castore, "~> 0.1.0 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: true]}, {:hpax, "~> 0.1.1 or ~> 0.2.0", [hex: :hpax, repo: "hexpm", optional: false]}], "hexpm", "4fc518dcc191d02f433393a72a7ba3f6f94b101d094cb6bf532ea54c89423780"}, "nimble_options": {:hex, :nimble_options, "1.1.1", "e3a492d54d85fc3fd7c5baf411d9d2852922f66e69476317787a7b2bb000a61b", [:mix], [], "hexpm", "821b2470ca9442c4b6984882fe9bb0389371b8ddec4d45a9504f00a66f650b44"}, "nimble_ownership": {:hex, :nimble_ownership, "0.3.1", "99d5244672fafdfac89bfad3d3ab8f0d367603ce1dc4855f86a1c75008bce56f", [:mix], [], "hexpm", "4bf510adedff0449a1d6e200e43e57a814794c8b5b6439071274d248d272a549"}, diff --git a/test/chains/llm_chain_test.exs b/test/chains/llm_chain_test.exs index 97740ea5..51bf7030 100644 --- a/test/chains/llm_chain_test.exs +++ b/test/chains/llm_chain_test.exs @@ -1,5 +1,6 @@ defmodule LangChain.Chains.LLMChainTest do use LangChain.BaseCase + use Mimic doctest LangChain.Chains.LLMChain @@ -194,7 +195,9 @@ defmodule LangChain.Chains.LLMChainTest do [MessageDelta.new!(%{content: "Sock", status: :incomplete})] ] - set_api_override({:ok, fake_messages, :on_llm_new_delta}) + expect(ChatOpenAI, :call, fn _model, _messages, _tools -> + {:ok, fake_messages} + end) # We can construct an LLMChain from a PromptTemplate and an LLM. {:ok, updated_chain, _response} = @@ -286,7 +289,10 @@ defmodule LangChain.Chains.LLMChainTest do # Made NOT LIVE here fake_message = Message.new!(%{role: :assistant, content: "Socktastic!", status: :complete}) - set_api_override({:ok, [fake_message], nil}) + + expect(ChatOpenAI, :call, fn _model, _messages, _tools -> + {:ok, [fake_message]} + end) # We can construct an LLMChain from a PromptTemplate and an LLM. {:ok, %LLMChain{} = updated_chain, message} = @@ -333,7 +339,9 @@ defmodule LangChain.Chains.LLMChainTest do [MessageDelta.new!(%{content: nil, status: :complete})] ] - set_api_override({:ok, fake_messages, :on_llm_new_delta}) + expect(ChatOpenAI, :call, fn _model, _messages, _tools -> + {:ok, fake_messages} + end) # We can construct an LLMChain from a PromptTemplate and an LLM. {:ok, updated_chain, response} = @@ -345,10 +353,7 @@ defmodule LangChain.Chains.LLMChainTest do assert %Message{role: :assistant, content: "Socktastic!", status: :complete} = response assert updated_chain.last_message == response - # we should have received at least one callback message delta - assert_received {:fake_stream_deltas, delta_1} - assert %MessageDelta{role: :assistant, status: :incomplete} = delta_1 - + # we should have received a message for the completed, combined message assert_received {:fake_full_message, message} assert %Message{role: :assistant, content: "Socktastic!"} = message end @@ -1045,7 +1050,10 @@ defmodule LangChain.Chains.LLMChainTest do Message.new_assistant!(%{content: "Not what you wanted"}) ] - set_api_override({:ok, fake_messages, nil}) + # expect it to be called 3 times + expect(ChatOpenAI, :call, 3, fn _model, _messages, _tools -> + {:ok, fake_messages} + end) messages = [ Message.new_user!("Say what I want you to say.") @@ -1106,6 +1114,11 @@ defmodule LangChain.Chains.LLMChainTest do Message.new_assistant!(%{content: "Not what you wanted"}) ] + # expects to be called 2 times + expect(ChatOpenAI, :call, 2, fn _model, _messages, _tools -> + {:ok, fake_messages} + end) + chain = LLMChain.new!(%{ llm: ChatOpenAI.new!(%{temperature: 0}), @@ -1114,8 +1127,6 @@ defmodule LangChain.Chains.LLMChainTest do callbacks: [handler] }) - set_api_override({:ok, fake_messages, :on_llm_new_message}) - {:error, error_chain, reason} = chain |> LLMChain.message_processors([JsonProcessor.new!()]) @@ -1162,7 +1173,9 @@ defmodule LangChain.Chains.LLMChainTest do Message.new_assistant!(%{content: Jason.encode!(%{value: "abc"})}) ] - set_api_override({:ok, fake_messages, nil}) + expect(ChatOpenAI, :call, fn _model, _messages, _tools -> + {:ok, fake_messages} + end) {:ok, _updated_chain, last_message} = chain @@ -1177,22 +1190,20 @@ defmodule LangChain.Chains.LLMChainTest do end test "mode: :until_success - message needs processing, fails, then succeeds", %{chat: chat} do - handler = %{ - on_message_processing_error: fn _chain, _data -> - # after the first processing error message, set to return a correct one - fake_messages = [ - Message.new_assistant!(%{content: Jason.encode!(%{value: "abc"})}) - ] - - set_api_override({:ok, fake_messages, nil}) - end - } - - # Made NOT LIVE here - set_api_override({:ok, [Message.new_assistant!(%{content: "invalid"})], nil}) + # Made NOT LIVE here - handles two consecutive calls + expect(ChatOpenAI, :call, fn _model, _messages, _tools -> + {:ok, [Message.new_assistant!(%{content: "invalid"})]} + end) + + expect(ChatOpenAI, :call, fn _model, _messages, _tools -> + {:ok, + [ + Message.new_assistant!(%{content: Jason.encode!(%{value: "abc"})}) + ]} + end) {:ok, _updated_chain, last_message} = - %{llm: chat, callbacks: [handler]} + %{llm: chat} |> LLMChain.new!() |> LLMChain.message_processors([JsonProcessor.new!()]) |> LLMChain.add_message(Message.new_system!()) @@ -1214,7 +1225,9 @@ defmodule LangChain.Chains.LLMChainTest do ]) ] - set_api_override({:ok, fake_messages, nil}) + expect(ChatOpenAI, :call, fn _model, _messages, _tools -> + {:ok, fake_messages} + end) {:ok, updated_chain, last_message} = %{llm: ChatOpenAI.new!(%{stream: false}), verbose: false} @@ -1242,7 +1255,9 @@ defmodule LangChain.Chains.LLMChainTest do ]) ] - set_api_override({:ok, fake_messages, nil}) + expect(ChatOpenAI, :call, fn _model, _messages, _tools -> + {:ok, fake_messages} + end) {:ok, updated_chain, last_message} = %{llm: ChatOpenAI.new!(%{stream: false}), verbose: false} @@ -1268,7 +1283,10 @@ defmodule LangChain.Chains.LLMChainTest do ]) ] - set_api_override({:ok, fake_messages, nil}) + # expect 3 calls + expect(ChatOpenAI, :call, 3, fn _model, _messages, _tools -> + {:ok, fake_messages} + end) {:error, updated_chain, reason} = %{llm: ChatOpenAI.new!(%{stream: false}), verbose: false} diff --git a/test/chains/routing_chain_test.exs b/test/chains/routing_chain_test.exs index 7132f247..12742a79 100644 --- a/test/chains/routing_chain_test.exs +++ b/test/chains/routing_chain_test.exs @@ -1,5 +1,6 @@ defmodule LangChain.Chains.RoutingChainTest do use LangChain.BaseCase + use Mimic doctest LangChain.Chains.RoutingChain @@ -125,9 +126,12 @@ defmodule LangChain.Chains.RoutingChainTest do describe "run/2" do test "runs and returns updated chain and last message", %{routing_chain: routing_chain} do + # Made NOT LIVE here fake_message = Message.new_assistant!("blog") - fake_response = {:ok, [fake_message], nil} - set_api_override(fake_response) + + expect(ChatOpenAI, :call, fn _model, _messages, _tools -> + {:ok, [fake_message]} + end) assert {:ok, updated_chain, last_msg} = RoutingChain.run(routing_chain) assert %LLMChain{} = updated_chain @@ -140,13 +144,22 @@ defmodule LangChain.Chains.RoutingChainTest do routing_chain: routing_chain, default_route: default_route } do - set_api_override({:ok, [Message.new_assistant!("blog")], nil}) + expect(ChatOpenAI, :call, fn _model, _messages, _tools -> + {:ok, [Message.new_assistant!("blog")]} + end) + assert %PromptRoute{name: "blog"} = RoutingChain.evaluate(routing_chain) - set_api_override({:ok, [Message.new_assistant!("memo")], nil}) + expect(ChatOpenAI, :call, fn _model, _messages, _tools -> + {:ok, [Message.new_assistant!("memo")]} + end) + assert %PromptRoute{name: "memo"} = RoutingChain.evaluate(routing_chain) - set_api_override({:ok, [Message.new_assistant!("DEFAULT")], nil}) + expect(ChatOpenAI, :call, fn _model, _messages, _tools -> + {:ok, [Message.new_assistant!("DEFAULT")]} + end) + assert default_route == RoutingChain.evaluate(routing_chain) end @@ -154,7 +167,10 @@ defmodule LangChain.Chains.RoutingChainTest do routing_chain: routing_chain, default_route: default_route } do - set_api_override({:ok, [Message.new_assistant!("invalid")], nil}) + expect(ChatOpenAI, :call, fn _model, _messages, _tools -> + {:ok, [Message.new_assistant!("invalid")]} + end) + assert default_route == RoutingChain.evaluate(routing_chain) end @@ -162,7 +178,10 @@ defmodule LangChain.Chains.RoutingChainTest do routing_chain: routing_chain, default_route: default_route } do - set_api_override({:error, "FAKE API call failure"}) + expect(ChatOpenAI, :call, fn _model, _messages, _tools -> + {:error, "FAKE API call failure"} + end) + assert default_route == RoutingChain.evaluate(routing_chain) end end diff --git a/test/chains/text_to_title_chain_test.exs b/test/chains/text_to_title_chain_test.exs index bfe3adea..4f74d9d6 100644 --- a/test/chains/text_to_title_chain_test.exs +++ b/test/chains/text_to_title_chain_test.exs @@ -1,5 +1,6 @@ defmodule LangChain.Chains.TextToTitleChainTest do use LangChain.BaseCase + use Mimic doctest LangChain.Chains.TextToTitleChain @@ -60,8 +61,11 @@ defmodule LangChain.Chains.TextToTitleChainTest do describe "run/2" do test "runs and returns updated chain and last message", %{title_chain: title_chain} do fake_message = Message.new_assistant!("Summarized Title") - fake_response = {:ok, [fake_message], nil} - set_api_override(fake_response) + + # Made NOT LIVE here + expect(ChatOpenAI, :call, fn _model, _messages, _tools -> + {:ok, [fake_message]} + end) assert {:ok, updated_chain, last_msg} = TextToTitleChain.run(title_chain) assert %LLMChain{} = updated_chain @@ -71,7 +75,11 @@ defmodule LangChain.Chains.TextToTitleChainTest do describe "evaluate/2" do test "returns the summarized title", %{title_chain: title_chain} do - set_api_override({:ok, [Message.new_assistant!("Special Title")], nil}) + # Made NOT LIVE here + expect(ChatOpenAI, :call, fn _model, _messages, _tools -> + {:ok, [Message.new_assistant!("Special Title")]} + end) + assert "Special Title" == TextToTitleChain.evaluate(title_chain) end @@ -79,7 +87,11 @@ defmodule LangChain.Chains.TextToTitleChainTest do title_chain: title_chain, fallback_title: fallback_title } do - set_api_override({:error, "FAKE API call failure"}) + # Made NOT LIVE here + expect(ChatOpenAI, :call, fn _model, _messages, _tools -> + {:error, "FAKE API call failure"} + end) + assert fallback_title == TextToTitleChain.evaluate(title_chain) end end diff --git a/test/chat_models/chat_bumblebee_test.exs b/test/chat_models/chat_bumblebee_test.exs index b8107143..cab03d53 100644 --- a/test/chat_models/chat_bumblebee_test.exs +++ b/test/chat_models/chat_bumblebee_test.exs @@ -1,6 +1,6 @@ defmodule LangChain.ChatModels.ChatBumblebeeTest do use LangChain.BaseCase - import LangChain.Utils.ApiOverride + use Mimic doctest LangChain.ChatModels.ChatBumblebee alias LangChain.ChatModels.ChatBumblebee @@ -22,22 +22,6 @@ defmodule LangChain.ChatModels.ChatBumblebeeTest do end end - describe "call/4" do - test "supports API override" do - set_api_override({:ok, [Message.new_assistant!("Colorful Threads")], :on_llm_new_message}) - - # https://js.langchain.com/docs/modules/models/chat/ - {:ok, chat} = ChatBumblebee.new(%{serving: Fake}) - - {:ok, [%Message{role: :assistant, content: response}]} = - ChatBumblebee.call(chat, [ - Message.new_user!("Return the response 'Colorful Threads'.") - ]) - - assert response =~ "Colorful Threads" - end - end - describe "do_process_response/3" do setup do handler = %{ diff --git a/test/chat_models/chat_open_ai_test.exs b/test/chat_models/chat_open_ai_test.exs index d96a6014..ab57117c 100644 --- a/test/chat_models/chat_open_ai_test.exs +++ b/test/chat_models/chat_open_ai_test.exs @@ -468,8 +468,6 @@ defmodule LangChain.ChatModels.ChatOpenAITest do describe "call/2" do @tag live_call: true, live_open_ai: true test "basic content example and fires ratelimit callback" do - # set_fake_llm_response({:ok, Message.new_assistant("\n\nRainbow Sox Co.")}) - handlers = %{ on_llm_ratelimit_info: fn _model, headers -> send(self(), {:fired_ratelimit_info, headers}) @@ -508,8 +506,6 @@ defmodule LangChain.ChatModels.ChatOpenAITest do @tag live_call: true, live_open_ai: true test "basic streamed content example's final result and fires ratelimit callback" do - # set_fake_llm_response({:ok, Message.new_assistant("\n\nRainbow Sox Co.")}) - handlers = %{ on_llm_ratelimit_info: fn _model, headers -> send(self(), {:fired_ratelimit_info, headers}) @@ -584,8 +580,6 @@ defmodule LangChain.ChatModels.ChatOpenAITest do @tag live_call: true, live_open_ai: true test "basic streamed content fires token usage callback" do - # set_fake_llm_response({:ok, Message.new_assistant("\n\nRainbow Sox Co.")}) - handlers = %{ on_llm_token_usage: fn _model, usage -> send(self(), {:fired_token_usage, usage}) diff --git a/test/support/base_case.ex b/test/support/base_case.ex index 45945dcf..80e7c446 100644 --- a/test/support/base_case.ex +++ b/test/support/base_case.ex @@ -13,7 +13,6 @@ defmodule LangChain.BaseCase do # Import conveniences for testing with AI models import LangChain.BaseCase - import LangChain.Utils.ApiOverride @doc """ Helper function for loading an image as base64 encoded text. diff --git a/test/test_helper.exs b/test/test_helper.exs index 3a695ce0..02f2e53a 100644 --- a/test/test_helper.exs +++ b/test/test_helper.exs @@ -3,6 +3,12 @@ Application.put_env(:langchain, :openai_key, System.fetch_env!("OPENAI_API_KEY") Application.put_env(:langchain, :anthropic_key, System.fetch_env!("ANTHROPIC_API_KEY")) Application.put_env(:langchain, :google_ai_key, System.fetch_env!("GOOGLE_API_KEY")) +Mimic.copy(LangChain.ChatModels.ChatOpenAI) +Mimic.copy(LangChain.ChatModels.ChatAnthropic) +Mimic.copy(LangChain.ChatModels.ChatMistralAI) +Mimic.copy(LangChain.ChatModels.ChatBumblebee) +Mimic.copy(LangChain.Images.OpenAIImage) + ExUnit.configure(capture_log: true, exclude: [live_call: true]) ExUnit.start()