Add support for Ollama open source models
medoror committed Jan 23, 2024
1 parent c890f58 commit bd904e8
Showing 3 changed files with 601 additions and 1 deletion.
370 changes: 370 additions & 0 deletions lib/chat_models/chat_ollama_ai.ex
@@ -0,0 +1,370 @@
defmodule LangChain.ChatModels.ChatOllamaAI do
@moduledoc """
Represents the [Ollama AI Chat model](https://github.com/jmorganca/ollama/blob/main/docs/api.md#generate-a-chat-completion)
Parses and validates inputs for making requests to the Ollama Chat API.
Converts responses into more specialized `LangChain` data structures.
The module's functionalities include:
- Initializing a new `ChatOllamaAI` struct with defaults or specific attributes.
- Validating and casting input data to fit the expected schema.
- Preparing and sending requests to the Ollama AI service API.
- Managing both streaming and non-streaming API responses.
- Processing API responses to convert them into suitable message formats.
The `ChatOllamaAI` struct has fields to configure the AI, including but not limited to:
- `endpoint`: URL of the Ollama AI service.
- `model`: The AI model used, e.g., "llama2:latest".
- `receive_timeout`: Max wait time for AI service responses.
- `temperature`: Influences the AI's response creativity.
For detailed info on all other parameters, see the documentation here:
https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
This module is for use within LangChain and follows the `ChatModel` behavior,
outlining callbacks AI chat models must implement.
Usage examples and more details are in the LangChain documentation or the
module's function docs.
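
A minimal usage sketch (the model name and option values are illustrative, not requirements):

    {:ok, ollama} =
      ChatOllamaAI.new(%{
        model: "llama2:latest",
        temperature: 0.5,
        stream: false
      })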
"""
use Ecto.Schema
require Logger
import Ecto.Changeset
alias __MODULE__
alias LangChain.ChatModels.ChatModel
alias LangChain.Message
alias LangChain.MessageDelta
alias LangChain.LangChainError
alias LangChain.ForOpenAIApi
alias LangChain.Utils

@behaviour ChatModel

@type t :: %ChatOllamaAI{}

@create_fields [
:endpoint,
:mirostat,
:mirostat_eta,
:mirostat_tau,
:model,
:num_ctx,
:num_gqa,
:num_gpu,
:num_predict,
:num_thread,
:receive_timeout,
:repeat_last_n,
:repeat_penalty,
:seed,
:stop,
:stream,
:temperature,
:tfs_z,
:top_k,
:top_p
]

@required_fields [:endpoint, :model]

@receive_timeout 60_000 * 5

@primary_key false
embedded_schema do
field :endpoint, :string, default: "http://localhost:11434/api/chat"

# Enable Mirostat sampling for controlling perplexity.
# (default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)
field :mirostat, :integer, default: 0

# Influences how quickly the algorithm responds to feedback from the generated text. A lower learning rate
# will result in slower adjustments, while a higher learning rate will make the algorithm more responsive.
# (Default: 0.1)
field :mirostat_eta, :float, default: 0.1

# Controls the balance between coherence and diversity of the output. A lower value will result in more focused
# and coherent text. (Default: 5.0)
field :mirostat_tau, :float, default: 5.0

field :model, :string, default: "llama2:latest"

# Sets the size of the context window used to generate the next token. (Default: 2048)
field :num_ctx, :integer, default: 2048

# The number of GQA groups in the transformer layer. Required for some models, for example it is 8 for llama2:70b
field :num_gqa, :integer

# The number of layers to send to the GPU(s). On macOS it defaults to 1 to enable metal support, 0 to disable.
field :num_gpu, :integer

# Maximum number of tokens to predict when generating text. (Default: 128, -1 = infinite generation, -2 = fill context)
field :num_predict, :integer, default: 128

# Sets the number of threads to use during computation. By default, Ollama will detect this for optimal
# performance. It is recommended to set this value to the number of physical CPU cores your system has (as
# opposed to the logical number of cores).
field :num_thread, :integer

# Duration in milliseconds to wait for the response to be received. When streaming a very
# lengthy response, a longer time limit may be required. However, when a response runs on
# too long, the model tends to hallucinate more.
# The Ollama default appears to be 5 minutes: https://github.com/jmorganca/ollama/pull/1257
field :receive_timeout, :integer, default: @receive_timeout

# Sets how far back for the model to look back to prevent repetition. (Default: 64, 0 = disabled, -1 = num_ctx)
field :repeat_last_n, :integer, default: 64

# Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly,
# while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1)
field :repeat_penalty, :float, default: 1.1

# Sets the random number seed to use for generation. Setting this to a specific number will make the
# model generate the same text for the same prompt. (Default: 0)
field :seed, :integer, default: 0

# Sets the stop sequences to use. When this pattern is encountered the LLM will stop generating text and return.
# Multiple stop patterns may be set by specifying multiple separate stop parameters in a modelfile.
field :stop, :string

field :stream, :boolean, default: false

# The temperature of the model. Increasing the temperature will make the model
# answer more creatively. (Default: 0.8)
field :temperature, :float, default: 0.8

# Tail free sampling is used to reduce the impact of less probable tokens from the output. A higher value (e.g., 2.0)
# will reduce the impact more, while a value of 1.0 disables this setting. (default: 1)
field :tfs_z, :float, default: 1.0

# Reduces the probability of generating nonsense. A higher value (e.g. 100) will give more diverse answers,
# while a lower value (e.g. 10) will be more conservative. (Default: 40)
field :top_k, :integer, default: 40

# Works together with top-k. A higher value (e.g., 0.95) will lead to more diverse text,
# while a lower value (e.g., 0.5) will generate more focused and conservative text. (Default: 0.9)
field :top_p, :float, default: 0.9
end

@doc """
Creates a new `ChatOllamaAI` struct with the given attributes.
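
For example (a minimal sketch; the temperature value is illustrative):

    {:ok, %ChatOllamaAI{temperature: 0.5}} = ChatOllamaAI.new(%{"temperature" => 0.5})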
"""
@spec new(attrs :: map()) :: {:ok, t} | {:error, Ecto.Changeset.t()}
def new(%{} = attrs \\ %{}) do
%ChatOllamaAI{}
|> cast(attrs, @create_fields, empty_values: [""])
|> common_validation()
|> apply_action(:insert)
end

@doc """
Creates a new `ChatOllamaAI` struct with the given attributes. Will raise an error if the changeset is invalid.
"""
@spec new!(attrs :: map()) :: t() | no_return()
def new!(attrs \\ %{}) do
case new(attrs) do
{:ok, chain} ->
chain

{:error, changeset} ->
raise LangChainError, changeset
end
end

defp common_validation(changeset) do
changeset
|> validate_required(@required_fields)
|> validate_number(:temperature, greater_than_or_equal_to: 0.0, less_than_or_equal_to: 1.0)
|> validate_number(:mirostat_eta, greater_than_or_equal_to: 0.0, less_than_or_equal_to: 1.0)
end

@doc """
Return the params formatted for an API request.
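
A sketch of the returned map (only a few keys shown; values reflect whatever is configured
on the struct):

    ollama = ChatOllamaAI.new!(%{model: "llama2:latest"})
    data = ChatOllamaAI.for_api(ollama, [], [])
    data.model  #=> "llama2:latest"
    data.stream #=> false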
"""
def for_api(%ChatOllamaAI{} = model, messages, _functions) do
%{
model: model.model,
temperature: model.temperature,
messages: messages |> Enum.map(&ForOpenAIApi.for_api/1),
stream: model.stream,
seed: model.seed,
num_ctx: model.num_ctx,
num_predict: model.num_predict,
repeat_last_n: model.repeat_last_n,
repeat_penalty: model.repeat_penalty,
mirostat: model.mirostat,
mirostat_eta: model.mirostat_eta,
mirostat_tau: model.mirostat_tau,
num_gqa: model.num_gqa,
num_gpu: model.num_gpu,
num_thread: model.num_thread,
receive_timeout: model.receive_timeout,
stop: model.stop,
tfs_z: model.tfs_z,
top_k: model.top_k,
top_p: model.top_p
}
end

@doc """
Calls the Ollama Chat Completion API struct with configuration, plus
either a simple message or the list of messages to act as the prompt.
**NOTE:** As of now, this API does not support functions. More
information here: https://github.com/jmorganca/ollama/issues/1729
**NOTE:** This function *can* be used directly, but the primary interface
should be through `LangChain.Chains.LLMChain`. The `ChatOllamaAI` module is more focused on
translating the `LangChain` data structures to and from the Ollama API.
Another benefit of using `LangChain.Chains.LLMChain` is that it combines the
storage of messages, adding functions, adding custom context that should be
passed to functions, and automatically applying `LangChain.MessageDelta`
structs as they are received, then converting those to the full
`LangChain.Message` once fully complete.
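
A direct-call sketch (assumes a local Ollama server is running at the default endpoint
with the configured model already pulled):

    {:ok, ollama} = ChatOllamaAI.new(%{model: "llama2:latest"})
    {:ok, message} = ChatOllamaAI.call(ollama, "Why is the sky blue?")
    message.content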
"""

@impl ChatModel
def call(ollama_ai, prompt, functions \\ [], callback_fn \\ nil)

def call(%ChatOllamaAI{} = ollama_ai, prompt, functions, callback_fn) when is_binary(prompt) do
messages = [
Message.new_system!(),
Message.new_user!(prompt)
]

call(ollama_ai, messages, functions, callback_fn)
end

def call(%ChatOllamaAI{} = ollama_ai, messages, functions, callback_fn)
when is_list(messages) do
try do
case do_api_request(ollama_ai, messages, functions, callback_fn) do
{:error, reason} ->
{:error, reason}

parsed_data ->
{:ok, parsed_data}
end
rescue
err in LangChainError ->
{:error, err.message}
end
end

# Make the API request to the Ollama server.
#
# The result of the function is:
#
# - `result` - where `result` is a data-structure like a list or map.
# - `{:error, reason}` - Where reason is a string explanation of what went wrong.
#
# **NOTE:** Callback functions are IGNORED for Ollama AI.
#
# If `stream: false`, the completed message is returned.
#
# If `stream: true`, the completed message is returned after all MessageDeltas.
#
# Retries the request up to 3 times on transient errors with a short, increasing delay.
@doc false
@spec do_api_request(t(), [Message.t()], [Function.t()], (any() -> any())) ::
list() | struct() | {:error, String.t()}
def do_api_request(ollama_ai, messages, functions, callback_fn, retry_count \\ 3)

def do_api_request(_ollama_ai, _messages, _functions, _callback_fn, 0) do
raise LangChainError, "Retries exceeded. Connection failed."
end

def do_api_request(%ChatOllamaAI{stream: false} = ollama_ai, messages, functions, callback_fn, retry_count) do
req =
Req.new(
url: ollama_ai.endpoint,
json: for_api(ollama_ai, messages, functions),
receive_timeout: ollama_ai.receive_timeout,
retry: :transient,
max_retries: 3,
retry_delay: fn attempt -> 300 * attempt end
)

req
|> Req.post()
|> case do
{:ok, %Req.Response{body: data}} ->
case do_process_response(data) do
{:error, reason} ->
{:error, reason}

result ->
result
end

{:error, %Mint.TransportError{reason: :timeout}} ->
{:error, "Request timed out"}

{:error, %Mint.TransportError{reason: :closed}} ->
# Force a retry by making a recursive call decrementing the counter
Logger.debug(fn -> "Mint connection closed: retry count = #{inspect(retry_count)}" end)
do_api_request(ollama_ai, messages, functions, callback_fn, retry_count - 1)

other ->
Logger.error("Unexpected and unhandled API response! #{inspect(other)}")
other
end
end

def do_api_request(%ChatOllamaAI{stream: true} = ollama_ai, messages, functions, callback_fn, retry_count) do
Req.new(
url: ollama_ai.endpoint,
json: for_api(ollama_ai, messages, functions),
receive_timeout: ollama_ai.receive_timeout
)
|> Req.post(into: Utils.handle_stream_fn(ollama_ai, &do_process_response/1, callback_fn))
|> case do
{:ok, %Req.Response{body: data}} ->
data

{:error, %LangChainError{message: reason}} ->
{:error, reason}

{:error, %Mint.TransportError{reason: :timeout}} ->
{:error, "Request timed out"}

{:error, %Mint.TransportError{reason: :closed}} ->
# Force a retry by making a recursive call decrementing the counter
Logger.debug(fn -> "Mint connection closed: retry count = #{inspect(retry_count)}" end)
do_api_request(ollama_ai, messages, functions, callback_fn, retry_count - 1)

other ->
Logger.error(
"Unhandled and unexpected response from streamed post call. #{inspect(other)}"
)

{:error, "Unexpected response"}
end
end
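
# Ollama includes a "done" flag with each response payload. A payload with
# "done" => true is converted into a completed `Message`; anything else is
# treated as an in-progress chunk and becomes a `MessageDelta`.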

def do_process_response(%{"message" => message, "done" => true}) do
create_message(message, :complete, Message)
end

def do_process_response(%{"message" => message, "done" => _other}) do
create_message(message, :incomplete, MessageDelta)
end

def do_process_response(%{"error" => reason}) do
Logger.error("Received error from API: #{inspect(reason)}")
{:error, reason}
end
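
# Builds either a full `Message` or a `MessageDelta` (whichever module is passed
# as `message_type`) from the raw response map, tagging it with the given status.
# Changeset errors are converted to a readable string.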

defp create_message(message, status, message_type) do
case message_type.new(Map.merge(message, %{"status" => status})) do
{:ok, new_message} ->
new_message

{:error, changeset} ->
{:error, Utils.changeset_error_to_string(changeset)}
end
end
end