Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added Utils.ChainResult module #45

Merged
merged 1 commit into from
Dec 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 75 additions & 0 deletions lib/utils/chain_result.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
defmodule LangChain.Utils.ChainResult do
@moduledoc """
Module to help when working with the results of a chain.
"""
alias LangChain.LangChainError
alias LangChain.Chains.LLMChain
alias LangChain.Message
alias __MODULE__

@doc """
Return the result of the chain as a string. Returned in an `:ok` tuple format.
An `{:error, reason}` is returned for various reasons. These include:
- The last message of the chain is not an `:assistant` message.
- The last message of the chain is incomplete.
- There is no last message.
"""
@spec to_string(LLMChain.t()) :: {:ok, String.t()} | {:error, String.t()}
def to_string(%LLMChain{last_message: %Message{role: :assistant, status: :complete}} = chain) do
{:ok, chain.last_message.content}
end

def to_string(%LLMChain{last_message: %Message{role: :assistant, status: _incomplete}} = _chain) do
{:error, "Message is incomplete"}
end

def to_string(%LLMChain{last_message: %Message{}} = _chain) do
{:error, "Message is not from assistant"}
end

def to_string(%LLMChain{last_message: nil} = _chain) do
{:error, "No last message"}
end

@doc """
Return the last message's content when it is valid to use it. Otherwise it
raises and exception with the reason why it cannot be used. See the docs for
`LangChain.Utils.ChainResult.to_string/2` for details.
"""
@spec to_string!(LLMChain.t()) :: String.t() | no_return()
def to_string!(%LLMChain{} = chain) do
case ChainResult.to_string(chain) do
{:ok, result} -> result
{:error, reason} -> raise LangChainError, reason
end
end

@doc """
Write the result to the given map as the value of the given key.
"""
@spec to_map(LLMChain.t(), map(), any()) :: {:ok, map()} | {:error, String.t()}
def to_map(%LLMChain{} = chain, map, key) do
case ChainResult.to_string(chain) do
{:ok, value} ->
{:ok, Map.put(map, key, value)}

{:error, _reason} = error ->
error
end
end

@doc """
Write the result to the given map as the value of the given key. If invalid,
an exception is raised.
"""
@spec to_map!(LLMChain.t(), map(), any()) :: map() | no_return()
def to_map!(%LLMChain{} = chain, map, key) do
case ChainResult.to_map(chain, map, key) do
{:ok, updated} ->
updated

{:error, reason} ->
raise LangChainError, reason
end
end
end
86 changes: 86 additions & 0 deletions test/utils/chain_result_test.exs
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
defmodule LangChain.Utils.ChainResultTest do
use ExUnit.Case

doctest LangChain.Utils.ChainResult

alias LangChain.Utils.ChainResult
alias LangChain.Chains.LLMChain
alias LangChain.Message
alias LangChain.LangChainError

describe "to_string/1" do
test "returns {:ok, answer} when valid" do
chain = %LLMChain{last_message: Message.new_assistant!("the answer")}
assert {:ok, "the answer"} == ChainResult.to_string(chain)
end

test "returns error when no last message" do
chain = %LLMChain{last_message: nil}
assert {:error, "No last message"} == ChainResult.to_string(chain)
end

test "returns error when incomplete last message" do
chain = %LLMChain{
last_message: Message.new!(%{role: :assistant, content: "Incomplete", status: :length})
}

assert {:error, "Message is incomplete"} == ChainResult.to_string(chain)
end

test "returns error when last message is not from assistant" do
chain = %LLMChain{
last_message: Message.new_user!("The question")
}

assert {:error, "Message is not from assistant"} == ChainResult.to_string(chain)
end
end

describe "to_string!/1" do
test "returns string when valid" do
chain = %LLMChain{last_message: Message.new_assistant!("the answer")}
assert "the answer" == ChainResult.to_string!(chain)
end

test "raises LangChainError when invalid" do
chain = %LLMChain{last_message: nil}

assert_raise LangChainError, "No last message", fn ->
ChainResult.to_string!(chain)
end
end
end

describe "to_map/3" do
test "writes string result to map when valid" do
data = %{thing: "one"}
chain = %LLMChain{last_message: Message.new_assistant!("the answer")}
assert {:ok, result} = ChainResult.to_map(chain, data, :answer)
assert %{thing: "one", answer: "the answer"} == result
end

test "returns error tuple with reason when invalid" do
data = %{thing: "one"}
chain = %LLMChain{last_message: nil}
assert {:error, "No last message"} = ChainResult.to_map(chain, data, :answer)
end
end

describe "to_map!/3" do
test "writes string result to map when valid" do
data = %{thing: "one"}
chain = %LLMChain{last_message: Message.new_assistant!("the answer")}
result = ChainResult.to_map!(chain, data, :answer)
assert %{thing: "one", answer: "the answer"} == result
end

test "raises error tuple with reason when invalid" do
data = %{thing: "one"}
chain = %LLMChain{last_message: nil}

assert_raise LangChainError, "No last message", fn ->
ChainResult.to_map!(chain, data, :answer)
end
end
end
end
Loading