diff --git a/lib/utils.ex b/lib/utils.ex
index e3177372..87438cd8 100644
--- a/lib/utils.ex
+++ b/lib/utils.ex
@@ -4,7 +4,6 @@ defmodule LangChain.Utils do
   """
   alias Ecto.Changeset
   require Logger
-  alias LangChain.LangChainError
 
   @doc """
   Only add the key to the map if the value is present. When the value is a list,
@@ -113,7 +112,7 @@ defmodule LangChain.Utils do
   end
 
   @doc """
-  Create a function to handle the streaming request.
+  Creates and returns an anonymous function to handle the streaming request.
   """
   @spec handle_stream_fn(
           %{optional(:stream) => boolean()},
@@ -123,9 +122,14 @@ defmodule LangChain.Utils do
   def handle_stream_fn(model, process_response_fn, callback_fn) do
     fn {:data, raw_data}, {req, response} ->
       # cleanup data because it isn't structured well for JSON.
-      new_data = decode_streamed_data(raw_data, process_response_fn)
+
+      # Fetch any previously incomplete messages buffered in the response
+      # struct and pass them in with the raw data for decoding.
+      buffered = Req.Response.get_private(response, :lang_incomplete, "")
+      {parsed_data, incomplete} = decode_streamed_data({raw_data, buffered}, process_response_fn)
+
       # execute the callback function for each MessageDelta
-      fire_callback(model, new_data, callback_fn)
+      fire_callback(model, parsed_data, callback_fn)
 
       old_body = if response.body == "", do: [], else: response.body
       # Returns %Req.Response{} where the body contains ALL the stream delta
@@ -146,13 +150,17 @@ defmodule LangChain.Utils do
       #   ]
       #
       # The reason for the inner list is for each entry in the "n" choices. By default only 1.
-      updated_response = %{response | body: old_body ++ new_data}
+      updated_response = %{response | body: old_body ++ parsed_data}
+      # write any incomplete portion to the response's private data for when
+      # more data is received.
+      updated_response = Req.Response.put_private(updated_response, :lang_incomplete, incomplete)
 
       {:cont, {req, updated_response}}
     end
   end
 
-  defp decode_streamed_data(data, process_response_fn) do
+  @doc false
+  def decode_streamed_data({raw_data, buffer}, process_response_fn) do
     # Data comes back like this:
     #
     # "data: {\"id\":\"chatcmpl-7e8yp1xBhriNXiqqZ0xJkgNrmMuGS\",\"object\":\"chat.completion.chunk\",\"created\":1689801995,\"model\":\"gpt-4-0613\",\"choices\":[{\"index\":0,\"delta\":{\"role\":\"assistant\",\"content\":null,\"function_call\":{\"name\":\"calculator\",\"arguments\":\"\"}},\"finish_reason\":null}]}\n\n
@@ -161,42 +169,35 @@ defmodule LangChain.Utils do
     # In that form, the data is not ready to be interpreted as JSON. Let's clean
     # it up first.
 
-    data
+    # As we start, the initial accumulator is an empty list of parsed results
+    # and any left-over buffer from previous processing.
+    raw_data
     |> String.split("data: ")
-    |> Enum.map(fn str ->
+    |> Enum.reduce({[], buffer}, fn str, {done, incomplete} = acc ->
+      # filter out "" and "[DONE]" by not adding them to the accumulator
       str
       |> String.trim()
       |> case do
         "" ->
-          :empty
+          acc
 
         "[DONE]" ->
-          :empty
+          acc
 
         json ->
-          json
+          # combine with any previous incomplete data
+          starting_json = incomplete <> json
+
+          starting_json
           |> Jason.decode()
           |> case do
            {:ok, parsed} ->
-              parsed
+              {done ++ [process_response_fn.(parsed)], ""}
 
-            {:error, reason} ->
-              {:error, reason}
+            {:error, _reason} ->
+              {done, starting_json}
           end
-          |> process_response_fn.()
       end
     end)
-    # returning a list of elements. "junk" elements were replaced with `:empty`.
-    # Filter those out down and return the final list of MessageDelta structs.
-    |> Enum.filter(fn d -> d != :empty end)
-    # if there was a single error returned in a list, flatten it out to just
-    # return the error
-    |> case do
-      [{:error, reason}] ->
-        raise LangChainError, reason
-
-      other ->
-        other
-    end
   end
 end
diff --git a/mix.lock b/mix.lock
index 084c086b..c614b6e3 100644
--- a/mix.lock
+++ b/mix.lock
@@ -4,6 +4,7 @@
   "decimal": {:hex, :decimal, "2.1.1", "5611dca5d4b2c3dd497dec8f68751f1f1a54755e8ed2a966c2633cf885973ad6", [:mix], [], "hexpm", "53cfe5f497ed0e7771ae1a475575603d77425099ba5faef9394932b35020ffcc"},
   "earmark_parser": {:hex, :earmark_parser, "1.4.39", "424642f8335b05bb9eb611aa1564c148a8ee35c9c8a8bba6e129d51a3e3c6769", [:mix], [], "hexpm", "06553a88d1f1846da9ef066b87b57c6f605552cfbe40d20bd8d59cc6bde41944"},
   "ecto": {:hex, :ecto, "3.11.1", "4b4972b717e7ca83d30121b12998f5fcdc62ba0ed4f20fd390f16f3270d85c3e", [:mix], [{:decimal, "~> 2.0", [hex: :decimal, repo: "hexpm", optional: false]}, {:jason, "~> 1.0", [hex: :jason, repo: "hexpm", optional: true]}, {:telemetry, "~> 0.4 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "ebd3d3772cd0dfcd8d772659e41ed527c28b2a8bde4b00fe03e0463da0f1983b"},
+  "elixir_make": {:hex, :elixir_make, "0.7.8", "505026f266552ee5aabca0b9f9c229cbb496c689537c9f922f3eb5431157efc7", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: true]}, {:certifi, "~> 2.0", [hex: :certifi, repo: "hexpm", optional: true]}], "hexpm", "7a71945b913d37ea89b06966e1342c85cfe549b15e6d6d081e8081c493062c07"},
   "ex_doc": {:hex, :ex_doc, "0.31.1", "8a2355ac42b1cc7b2379da9e40243f2670143721dd50748bf6c3b1184dae2089", [:mix], [{:earmark_parser, "~> 1.4.39", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_c, ">= 0.1.1", [hex: :makeup_c, repo: "hexpm", optional: true]}, {:makeup_elixir, "~> 0.14", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1", [hex: :makeup_erlang, repo: "hexpm", optional: false]}], "hexpm", "3178c3a407c557d8343479e1ff117a96fd31bafe52a039079593fb0524ef61b0"},
   "expo": {:hex, :expo, "0.4.1", "1c61d18a5df197dfda38861673d392e642649a9cef7694d2f97a587b2cfb319b", [:mix], [], "hexpm", "2ff7ba7a798c8c543c12550fa0e2cbc81b95d4974c65855d8d15ba7b37a1ce47"},
   "finch": {:hex, :finch, "0.17.0", "17d06e1d44d891d20dbd437335eebe844e2426a0cd7e3a3e220b461127c73f70", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: false]}, {:mime, "~> 1.0 or ~> 2.0", [hex: :mime, repo: "hexpm", optional: false]}, {:mint, "~> 1.3", [hex: :mint, repo: "hexpm", optional: false]}, {:nimble_options, "~> 0.4 or ~> 1.0", [hex: :nimble_options, repo: "hexpm", optional: false]}, {:nimble_pool, "~> 0.2.6 or ~> 1.0", [hex: :nimble_pool, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "8d014a661bb6a437263d4b5abf0bcbd3cf0deb26b1e8596f2a271d22e48934c7"},
@@ -12,7 +13,7 @@
   "jason": {:hex, :jason, "1.4.1", "af1504e35f629ddcdd6addb3513c3853991f694921b1b9368b0bd32beb9f1b63", [:mix], [{:decimal, "~> 1.0 or ~> 2.0", [hex: :decimal, repo: "hexpm", optional: true]}], "hexpm", "fbb01ecdfd565b56261302f7e1fcc27c4fb8f32d56eab74db621fc154604a7a1"},
   "makeup": {:hex, :makeup, "1.1.1", "fa0bc768698053b2b3869fa8a62616501ff9d11a562f3ce39580d60860c3a55e", [:mix], [{:nimble_parsec, "~> 1.2.2 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "5dc62fbdd0de44de194898b6710692490be74baa02d9d108bc29f007783b0b48"},
   "makeup_elixir": {:hex, :makeup_elixir, "0.16.1", "cc9e3ca312f1cfeccc572b37a09980287e243648108384b97ff2b76e505c3555", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "e127a341ad1b209bd80f7bd1620a15693a9908ed780c3b763bccf7d200c767c6"},
-  "makeup_erlang": {:hex, :makeup_erlang, "0.1.3", "d684f4bac8690e70b06eb52dad65d26de2eefa44cd19d64a8095e1417df7c8fd", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "b78dc853d2e670ff6390b605d807263bf606da3c82be37f9d7f68635bd886fc9"},
+  "makeup_erlang": {:hex, :makeup_erlang, "0.1.4", "29563475afa9b8a2add1b7a9c8fb68d06ca7737648f28398e04461f008b69521", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "f4ed47ecda66de70dd817698a703f8816daa91272e7e45812469498614ae8b29"},
   "mime": {:hex, :mime, "2.0.5", "dc34c8efd439abe6ae0343edbb8556f4d63f178594894720607772a041b04b02", [:mix], [], "hexpm", "da0d64a365c45bc9935cc5c8a7fc5e49a0e0f9932a761c55d6c52b142780a05c"},
   "mint": {:hex, :mint, "1.5.2", "4805e059f96028948870d23d7783613b7e6b0e2fb4e98d720383852a760067fd", [:mix], [{:castore, "~> 0.1.0 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: true]}, {:hpax, "~> 0.1.1", [hex: :hpax, repo: "hexpm", optional: false]}], "hexpm", "d77d9e9ce4eb35941907f1d3df38d8f750c357865353e21d335bdcdf6d892a02"},
   "nimble_options": {:hex, :nimble_options, "1.1.0", "3b31a57ede9cb1502071fade751ab0c7b8dbe75a9a4c2b5bbb0943a690b63172", [:mix], [], "hexpm", "8bbbb3941af3ca9acc7835f5655ea062111c9c27bcac53e004460dfd19008a99"},
diff --git a/test/utils_test.exs b/test/utils_test.exs
index 51251105..462fded6 100644
--- a/test/utils_test.exs
+++ b/test/utils_test.exs
@@ -92,4 +92,119 @@ defmodule LangChain.UtilsTest do
       assert result == "role: is important, is invalid; index: is numeric, is invalid"
     end
   end
+
+  def setup_expected_json(_) do
+    json_1 = %{
+      "choices" => [
+        %{
+          "delta" => %{
+            "content" => nil,
+            "function_call" => %{"arguments" => "", "name" => "calculator"},
+            "role" => "assistant"
+          },
+          "finish_reason" => nil,
+          "index" => 0
+        }
+      ],
+      "created" => 1_689_801_995,
+      "id" => "chatcmpl-7e8yp1xBhriNXiqqZ0xJkgNrmMuGS",
+      "model" => "gpt-4-0613",
+      "object" => "chat.completion.chunk"
+    }
+
+    json_2 = %{
+      "choices" => [
+        %{
+          "delta" => %{"function_call" => %{"arguments" => "{\n"}},
+          "finish_reason" => nil,
+          "index" => 0
+        }
+      ],
+      "created" => 1_689_801_995,
+      "id" => "chatcmpl-7e8yp1xBhriNXiqqZ0xJkgNrmMuGS",
+      "model" => "gpt-4-0613",
+      "object" => "chat.completion.chunk"
+    }
+
+    %{json_1: json_1, json_2: json_2}
+  end
+
+  defp send_parsed_data(%{} = parsed_data) do
+    send(self(), {:parsed_data, parsed_data})
+    parsed_data
+  end
+
+  describe "decode_streamed_data/2" do
+    setup :setup_expected_json
+
+    test "correctly handles fully formed chat completion chunks", %{json_1: json_1, json_2: json_2} do
+      data =
+        "data: {\"id\":\"chatcmpl-7e8yp1xBhriNXiqqZ0xJkgNrmMuGS\",\"object\":\"chat.completion.chunk\",\"created\":1689801995,\"model\":\"gpt-4-0613\",\"choices\":[{\"index\":0,\"delta\":{\"role\":\"assistant\",\"content\":null,\"function_call\":{\"name\":\"calculator\",\"arguments\":\"\"}},\"finish_reason\":null}]}\n\n
+         data: {\"id\":\"chatcmpl-7e8yp1xBhriNXiqqZ0xJkgNrmMuGS\",\"object\":\"chat.completion.chunk\",\"created\":1689801995,\"model\":\"gpt-4-0613\",\"choices\":[{\"index\":0,\"delta\":{\"function_call\":{\"arguments\":\"{\\n\"}},\"finish_reason\":null}]}\n\n"
+
+      {parsed, incomplete} =
+        Utils.decode_streamed_data({data, ""}, &send_parsed_data/1)
+
+      # callback should have fired with matching parsed data
+      assert_received {:parsed_data, ^json_1}
+      assert_received {:parsed_data, ^json_2}
+
+      # nothing incomplete. Parsed 2 objects.
+      assert incomplete == ""
+      assert parsed == [json_1, json_2]
+    end
+
+    test "correctly parses when data split over received messages", %{json_1: json_1} do
+      # a single message split across multiple received data chunks
+      data =
+        "data: {\"id\":\"chatcmpl-7e8yp1xBhriNXiqqZ0xJkgNrmMuGS\",\"object\":\"chat.comple
+         data: tion.chunk\",\"created\":1689801995,\"model\":\"gpt-4-0613\",\"choices\":[{\"index\":0,\"delta\":{\"role\":\"assistant\",\"content\":null,\"function_call\":{\"name\":\"calculator\",\"arguments\":\"\"}},\"finish_reason\":null}]}\n\n"
+
+      {parsed, incomplete} = Utils.decode_streamed_data({data, ""}, &send_parsed_data/1)
+
+      # callback should have fired with matching parsed data
+      assert_received {:parsed_data, ^json_1}
+
+      # nothing incomplete. Parsed 1 object.
+      assert incomplete == ""
+      assert parsed == [json_1]
+    end
+
+    test "correctly parses when data split over decode calls", %{json_1: json_1} do
+      buffered = "{\"id\":\"chatcmpl-7e8yp1xBhriNXiqqZ0xJkgNrmMuGS\",\"object\":\"chat.comple"
+
+      # incomplete message chunk processed in next call
+      data = "data: tion.chunk\",\"created\":1689801995,\"model\":\"gpt-4-0613\",\"choices\":[{\"index\":0,\"delta\":{\"role\":\"assistant\",\"content\":null,\"function_call\":{\"name\":\"calculator\",\"arguments\":\"\"}},\"finish_reason\":null}]}\n\n"
+
+      {parsed, incomplete} = Utils.decode_streamed_data({data, buffered}, &send_parsed_data/1)
+
+      # callback should have fired with matching parsed data
+      assert_received {:parsed_data, ^json_1}
+
+      # nothing incomplete. Parsed 1 object.
+      assert incomplete == ""
+      assert parsed == [json_1]
+    end
+
+    test "correctly parses when data was previously buffered, messages are split, and leftovers remain", %{json_1: json_1, json_2: json_2} do
+      buffered = "{\"id\":\"chatcmpl-7e8yp1xBhriNXiqqZ0xJkgNrmMuGS\",\"object\":\"chat.comple"
+
+      # incomplete message chunk processed in next call
+      data =
+        "data: tion.chunk\",\"created\":1689801995,\"model\":\"gpt-4-0613\",\"choices\":[{\"index\":0,\"delta\":{\"role\":\"assistant\",\"content\":null,\"function_call\":{\"name\":\"calculator\",\"arguments\":\"\"}},\"finish_reason\":null}]}\n\n
+         data: {\"id\":\"chatcmpl-7e8yp1xBhriNXiqqZ0xJkgNrmMuGS\",\"object\":\"chat.completion.chunk\",\"crea
+         data: ted\":1689801995,\"model\":\"gpt-4-0613\",\"choices\":[{\"index\":0,\"delta\":{\"function_call\":{\"argu
+         data: ments\":\"{\\n\"}},\"finish_reason\":null}]}\n\n
+         data: {\"id\":\"chatcmpl-7e8yp1xBhriNXiqqZ0xJkgNrmMuGS\",\"object\":\"chat.comp"
+
+      {parsed, incomplete} = Utils.decode_streamed_data({data, buffered}, &send_parsed_data/1)
+
+      # callback should have fired with matching parsed data
+      assert_received {:parsed_data, ^json_1}
+      assert_received {:parsed_data, ^json_2}
+
+      # one message still incomplete. Parsed 2 objects.
+      assert incomplete == "{\"id\":\"chatcmpl-7e8yp1xBhriNXiqqZ0xJkgNrmMuGS\",\"object\":\"chat.comp"
+      assert parsed == [json_1, json_2]
+    end
+  end
 end
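Reviewer note (not part of the diff): a minimal sketch of the {raw_data, buffer} -> {parsed, incomplete} contract that decode_streamed_data/2 now exposes, matching the behavior exercised by the tests above. The chunk strings and the identity capture used as process_response_fn are made-up stand-ins, not real API traffic.

alias LangChain.Utils

# A complete SSE event followed by a truncated one (illustrative data only).
chunk_1 = "data: {\"a\": 1}\n\ndata: {\"b\": 2"
# The rest of the truncated event, then the terminating [DONE] marker.
chunk_2 = "}\n\ndata: [DONE]\n\n"

# First call: the complete object is parsed; the truncated JSON is returned as the leftover.
{parsed_1, leftover} = Utils.decode_streamed_data({chunk_1, ""}, & &1)
# parsed_1 == [%{"a" => 1}] and leftover == "{\"b\": 2"

# Second call: the leftover buffer is prepended to the new data, completing the buffered object.
{parsed_2, ""} = Utils.decode_streamed_data({chunk_2, leftover}, & &1)
# parsed_2 == [%{"b" => 2}]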