Merge pull request #76 from michalwarda/mistral_chat
Add mistral chat
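
This adds LangChain.ChatModels.ChatMistralAI, a chat model backed by the Mistral AI chat completions API, with optional streaming. Below is a minimal usage sketch; the model name, option values, and prompt are illustrative examples and not part of this diff:

alias LangChain.ChatModels.ChatMistralAI
alias LangChain.Message

# Configure the chat model. The API key can also come from the
# :mistral_api_key application config instead of being set here.
{:ok, chat} =
  ChatMistralAI.new(%{
    model: "mistral-small",
    temperature: 0.7
  })

# A plain string prompt is wrapped in a system + user message pair.
{:ok, [%Message{} = response]} = ChatMistralAI.call(chat, "Write a haiku about Elixir.")
IO.puts(response.content)
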
brainlid authored Feb 20, 2024
2 parents 4e2f466 + 52174fb commit f279837
Showing 2 changed files with 547 additions and 0 deletions.
373 changes: 373 additions & 0 deletions lib/chat_models/chat_mistral_ai.ex
@@ -0,0 +1,373 @@
defmodule LangChain.ChatModels.ChatMistralAI do
use Ecto.Schema
require Logger
import Ecto.Changeset
import LangChain.Utils.ApiOverride
alias LangChain.ForOpenAIApi
alias __MODULE__
alias LangChain.Config
alias LangChain.ChatModels.ChatModel
alias LangChain.Message
alias LangChain.MessageDelta
alias LangChain.Function
alias LangChain.LangChainError
alias LangChain.Utils

@behaviour ChatModel
@receive_timeout 60_000

@default_endpoint "https://api.mistral.ai/v1/chat/completions"

@primary_key false
embedded_schema do
field :endpoint, :string, default: @default_endpoint

# The name of the Mistral model to call (e.g. "mistral-small").
field :model, :string
field :api_key, :string

# What sampling temperature to use, between 0 and 1. Higher values like 0.8
# will make the output more random, while lower values like 0.2 will make it
# more focused and deterministic.
field :temperature, :float, default: 0.9

# The top_p parameter changes how the model selects tokens for output. Tokens
# are selected from the most to least probable until the sum of their
# probabilities equals the top_p value. For example, if tokens A, B, and C have
# probabilities of 0.3, 0.2, and 0.1 and the top_p value is 0.5, then the model
# will select either A or B as the next token by using the temperature and
# exclude C as a candidate. This field defaults to 1.0.
field :top_p, :float, default: 1.0

# Duration in milliseconds to wait for the response to be received. When
# streaming a very lengthy response, a longer time limit may be required.
# However, when it goes on too long by itself, it tends to hallucinate more.
field :receive_timeout, :integer, default: @receive_timeout

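# The maximum number of tokens to generate in the completion.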
field :max_tokens, :integer

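# Whether to inject a safety prompt before all conversations (Mistral's safe_prompt flag).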
field :safe_prompt, :boolean, default: false

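# The seed to use for random sampling. When set, repeated calls produce deterministic results.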
field :random_seed, :integer

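# Whether to stream the response back incrementally as MessageDelta structs
# instead of returning a single completed Message.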
field :stream, :boolean, default: false
end

@type t :: %ChatMistralAI{}

@create_fields [
:endpoint,
:model,
:api_key,
:temperature,
:top_p,
:receive_timeout,
:max_tokens,
:safe_prompt,
:random_seed,
:stream
]
@required_fields [
:model
]

@spec get_api_key(t) :: String.t()
defp get_api_key(%ChatMistralAI{api_key: api_key}) do
# if no API key is set on the struct, fall back to the globally configured :mistral_api_key
api_key || Config.resolve(:mistral_api_key)
end

@spec get_headers(t) :: [tuple()]
defp get_headers(%ChatMistralAI{} = chat) do
api_key = get_api_key(chat)

[
Authorization: "Bearer #{api_key}",
"Content-Type": "application/json",
Accept: "application/json"
]
end

@spec new(attrs :: map()) :: {:ok, t} | {:error, Ecto.Changeset.t()}
def new(%{} = attrs \\ %{}) do
%ChatMistralAI{}
|> cast(attrs, @create_fields)
|> common_validation()
|> apply_action(:insert)
end

@spec new!(attrs :: map()) :: t() | no_return()
def new!(attrs \\ %{}) do
case new(attrs) do
{:ok, chain} ->
chain

{:error, changeset} ->
raise LangChainError, changeset
end
end

defp common_validation(changeset) do
changeset
|> validate_required(@required_fields)
end

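# Builds the request body sent to the Mistral chat completions endpoint.
# Functions are currently ignored. Messages are converted to the
# OpenAI-compatible wire format via ForOpenAIApi.for_api/1, and optional
# fields (random_seed, max_tokens) are only included when set.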
@spec for_api(t, messages :: [map()], functions :: [map()]) :: %{atom() => any()}
def for_api(%ChatMistralAI{} = mistral, messages, _functions) do
%{
model: mistral.model,
temperature: mistral.temperature,
top_p: mistral.top_p,
safe_prompt: mistral.safe_prompt,
stream: mistral.stream,
messages: Enum.map(messages, &ForOpenAIApi.for_api/1)
}
|> Utils.conditionally_add_to_map(:random_seed, mistral.random_seed)
|> Utils.conditionally_add_to_map(:max_tokens, mistral.max_tokens)
end

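# Call the Mistral API. Accepts either a raw prompt string (which gets wrapped
# in a system + user message pair) or a list of Message structs. When a
# callback_fn is provided, it is fired with each received result or delta.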
@impl ChatModel
def call(mistral, prompt, functions \\ [], callback_fn \\ nil)

def call(%ChatMistralAI{} = mistral, prompt, functions, callback_fn) when is_binary(prompt) do
messages = [
Message.new_system!(),
Message.new_user!(prompt)
]

call(mistral, messages, functions, callback_fn)
end

def call(%ChatMistralAI{} = mistral, messages, functions, callback_fn) when is_list(messages) do
if override_api_return?() do
Logger.warning("Found override API response. Will not make live API call.")

case get_api_override() do
{:ok, {:ok, data} = response} ->
# fire callback for fake responses too
Utils.fire_callback(mistral, data, callback_fn)
response

# fake error response
{:ok, {:error, _reason} = response} ->
response

_other ->
raise LangChainError,
"An unexpected fake API response was set. Should be an `{:ok, value}`"
end
else
try do
# make base api request and perform high-level success/failure checks
case do_api_request(mistral, messages, functions, callback_fn) do
{:error, reason} ->
{:error, reason}

parsed_data ->
{:ok, parsed_data}
end
rescue
err in LangChainError ->
{:error, err.message}
end
end
end

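# Make the HTTP request to the Mistral endpoint. With stream: false the parsed
# Message structs are returned directly; with stream: true, deltas are handed to
# the stream handler and callback as they arrive. A closed connection triggers a
# retry by recursing with a decremented retry_count.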
@spec do_api_request(t(), [Message.t()], [Function.t()], (any() -> any())) ::
list() | struct() | {:error, String.t()}
def do_api_request(mistral, messages, functions, callback_fn, retry_count \\ 3)

def do_api_request(_mistral, _messages, _functions, _callback_fn, 0) do
raise LangChainError, "Retries exceeded. Connection failed."
end

def do_api_request(
%ChatMistralAI{stream: false} = mistral,
messages,
functions,
callback_fn,
retry_count
) do
req =
Req.new(
url: mistral.endpoint,
json: for_api(mistral, messages, functions),
headers: get_headers(mistral),
receive_timeout: mistral.receive_timeout,
retry: :transient,
max_retries: 3,
retry_delay: fn attempt -> 300 * attempt end
)

req
|> Req.post()
# parse the body and return it as parsed structs
|> case do
{:ok, %Req.Response{body: data}} ->
case do_process_response(data) do
{:error, reason} ->
{:error, reason}

result ->
Utils.fire_callback(mistral, result, callback_fn)
result
end

{:error, %Mint.TransportError{reason: :timeout}} ->
{:error, "Request timed out"}

{:error, %Mint.TransportError{reason: :closed}} ->
# Force a retry by making a recursive call decrementing the counter
Logger.debug(fn -> "Mint connection closed: retry count = #{inspect(retry_count)}" end)
do_api_request(mistral, messages, functions, callback_fn, retry_count - 1)

other ->
Logger.error("Unexpected and unhandled API response! #{inspect(other)}")
other
end
end

def do_api_request(
%ChatMistralAI{stream: true} = mistral,
messages,
functions,
callback_fn,
retry_count
) do
Req.new(
url: mistral.endpoint,
json: for_api(mistral, messages, functions),
headers: get_headers(mistral),
receive_timeout: mistral.receive_timeout
)
|> Req.post(into: Utils.handle_stream_fn(mistral, &do_process_response/1, callback_fn))
|> case do
{:ok, %Req.Response{body: data}} ->
data

{:error, %LangChainError{message: reason}} ->
{:error, reason}

{:error, %Mint.TransportError{reason: :timeout}} ->
{:error, "Request timed out"}

{:error, %Mint.TransportError{reason: :closed}} ->
# Force a retry by making a recursive call decrementing the counter
Logger.debug(fn -> "Mint connection closed: retry count = #{inspect(retry_count)}" end)
do_api_request(mistral, messages, functions, callback_fn, retry_count - 1)

other ->
Logger.error(
"Unhandled and unexpected response from streamed post call. #{inspect(other)}"
)

{:error, "Unexpected response"}
end
end

# Parse a new message response
@doc false
@spec do_process_response(data :: %{String.t() => any()} | {:error, any()}) ::
Message.t()
| [Message.t()]
| MessageDelta.t()
| [MessageDelta.t()]
| {:error, String.t()}
def do_process_response(%{"choices" => choices}) when is_list(choices) do
# process each response individually. Return a list of all processed choices
for choice <- choices do
do_process_response(choice)
end
end

def do_process_response(
%{"delta" => delta_body, "finish_reason" => finish, "index" => index} = _msg
) do
status =
case finish do
nil ->
:incomplete

"stop" ->
:complete

"length" ->
:length

"model_length" ->
:length

other ->
Logger.warning("Unsupported finish_reason in delta message. Reason: #{inspect(other)}")
nil
end

# more explicitly interpret the role. We treat a "function_call" as a role
# while OpenAI addresses it as an "assistant". Technically, they are correct
# that the assistant is issuing the function_call.
role =
case delta_body do
%{"role" => role} -> role
_other -> "unknown"
end

data =
delta_body
|> Map.put("role", role)
|> Map.put("index", index)
|> Map.put("status", status)

case MessageDelta.new(data) do
{:ok, message} ->
message

{:error, changeset} ->
{:error, Utils.changeset_error_to_string(changeset)}
end
end

def do_process_response(%{
"finish_reason" => finish_reason,
"message" => message,
"index" => index
}) do
status =
case finish_reason do
"stop" ->
:complete

"length" ->
:length

"model_length" ->
:length

other ->
Logger.warning("Unsupported finish_reason in message. Reason: #{inspect(other)}")
nil
end

case Message.new(Map.merge(message, %{"status" => status, "index" => index})) do
{:ok, message} ->
message

{:error, changeset} ->
{:error, Utils.changeset_error_to_string(changeset)}
end
end

def do_process_response(%{"error" => %{"message" => reason}}) do
Logger.error("Received error from API: #{inspect(reason)}")
{:error, reason}
end

def do_process_response({:error, %Jason.DecodeError{} = response}) do
error_message = "Received invalid JSON: #{inspect(response)}"
Logger.error(error_message)
{:error, error_message}
end

def do_process_response(other) do
Logger.error("Trying to process an unexpected response. #{inspect(other)}")
{:error, "Unexpected response"}
end
end