-
-
Notifications
You must be signed in to change notification settings - Fork 39
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #115 from rhx/main
Add Ollama support
- Loading branch information
Showing
5 changed files
with
550 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
// | ||
// OllamaEmbeddings.swift | ||
// | ||
// Created by Rene Hexel on 20/4/2024. | ||
// | ||
import Foundation | ||
import AsyncHTTPClient | ||
|
||
extension Ollama: Embeddings {
    /// Request payload for the Ollama `embeddings` endpoint.
    struct EmbeddingRequest: Codable {
        /// The name of the model to compute embeddings with.
        let model: String
        /// The input text to embed.
        let prompt: String
    }

    /// Response payload returned by the Ollama `embeddings` endpoint.
    struct Embedding: Codable {
        /// The embedding vector computed for the prompt.
        let embedding: [Float]
    }

    /// Create embeddings for a given text.
    ///
    /// Convenience wrapper around `getEmbeddings(for:)` that swallows
    /// any error and falls back to an empty vector, satisfying the
    /// non-throwing `Embeddings` requirement.
    ///
    /// - Parameter text: The text to create embeddings for.
    /// - Returns: An array of embeddings for the given text,
    ///   or an empty array if the request failed.
    public func embedQuery(text: String) async -> [Float] {
        (try? await getEmbeddings(for: text)) ?? []
    }

    /// Get the embeddings vector for a given text.
    ///
    /// Sends the text to the Ollama `embeddings` API endpoint and
    /// decodes the vector from the JSON response.
    ///
    /// - Parameter text: The text to create an embeddings vector for.
    /// - Returns: An array of embeddings for the given text;
    ///   empty if the endpoint returned no data.
    /// - Throws: Rethrows any error from `sendJSON(request:endpoint:)`,
    ///   or a `DecodingError` if the response body is not valid.
    public func getEmbeddings(for text: String) async throws -> [Float] {
        let request = EmbeddingRequest(model: model, prompt: text)
        guard let body = try await sendJSON(request: request, endpoint: "embeddings") else {
            return []
        }
        return try JSONDecoder().decode(Embedding.self, from: body).embedding
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
// | ||
// ChatOllama.swift | ||
// | ||
// Created by Rene Hexel on 21/4/2024. | ||
// | ||
|
||
import Foundation | ||
import OpenAIKit | ||
|
||
/// Ollama class for chat functionality. | ||
/// | ||
/// This class interfaces with the Ollama chat API. | ||
/// Ollama class for chat functionality.
///
/// This class interfaces with the Ollama chat API,
/// keeping track of the conversation so far.
public class ChatOllama: Ollama {
    /// The chat history.
    ///
    /// Every user message and every model reply is appended here,
    /// so the full conversation is re-sent with each request.
    var history = [ChatGLMMessage]()

    /// Create a new Ollama chat instance.
    ///
    /// This initialiser creates a new Ollama chat instance with the given parameters.
    ///
    /// - Parameters:
    ///   - baseURL: The base URL for the Ollama API.
    ///   - model: The model to use for the chat instance.
    ///   - options: Additional options for the chat instance.
    ///   - timeout: The request timeout in seconds.
    ///   - callbacks: The callback handlers to use.
    ///   - cache: The cache to use.
    public override init(baseURL: String? = nil, model: String? = nil, options: [String : String]? = nil, timeout: Int = 3600, callbacks: [BaseCallbackHandler] = [], cache: BaseCache? = nil) {
        super.init(baseURL: baseURL, model: model, options: options, timeout: timeout, callbacks: callbacks, cache: cache)
    }

    /// Send a text to the Ollama API.
    ///
    /// Appends the user message to the history, posts the whole
    /// conversation to the `chat` endpoint, records the model's
    /// reply in the history, and returns it as an `LLMResult`.
    ///
    /// - Parameters:
    ///   - text: The text to send to the Ollama API.
    ///   - stops: An array of strings that, if present in the response, will stop the generation.
    /// - Returns: The model's reply wrapped in an `LLMResult`;
    ///   an empty result if the endpoint returned no data.
    /// - Note: NOTE(review): `stops` is currently not forwarded to
    ///   the API request — confirm whether this is intentional.
    public override func _send(text: String, stops: [String] = []) async throws -> LLMResult {
        history.append(ChatGLMMessage(role: "user", content: text))
        let request = ChatRequest(model: model, options: modelOptions, format: "json", stream: false, messages: history)
        guard let body = try await sendJSON(request: request, endpoint: "chat") else {
            return LLMResult()
        }
        let response = try JSONDecoder().decode(ChatResponse.self, from: body)
        history.append(response.message)
        return LLMResult(llm_output: response.message.content)
    }
}
|
||
public extension ChatOllama {
    /// Request payload for the Ollama `chat` endpoint.
    ///
    /// Asks the model to generate the next message in a chat.
    /// This is a streaming endpoint, so there can be a series of
    /// responses; streaming is disabled by sending `stream: false`.
    struct ChatRequest: Codable, Sendable {
        /// The name of the model to chat with.
        let model: String
        /// Additional model options, if any.
        let options: [String: String]?
        /// The response format to request (e.g. `"json"`).
        let format: String
        /// Whether the endpoint should stream partial responses.
        let stream: Bool
        /// The conversation so far, including the newest user message.
        let messages: [ChatGLMMessage]
    }

    /// Ollama response to a `ChatRequest`.
    ///
    /// Contains the next message in a chat conversation.  The final
    /// response object of a request also carries statistics fields,
    /// which are therefore all optional.
    struct ChatResponse: Codable, Sendable {
        /// The generated chat message.
        let message: ChatGLMMessage
        /// The model that produced the message.
        let model: String
        /// Whether generation has finished.
        let done: Bool
        /// Total duration statistic, if reported.
        let totalDuration: Int?
        /// Model load duration statistic, if reported.
        let loadDuration: Int?
        /// Prompt evaluation duration statistic, if reported.
        let promptEvalDuration: Int?
        /// Response evaluation duration statistic, if reported.
        let evalDuration: Int?
        /// Prompt evaluation count statistic, if reported.
        let promptEvalCount: Int?
        /// Response evaluation count statistic, if reported.
        let evalCount: Int?

        /// Return the message content.
        public var content: String {
            return message.content
        }

        /// Maps the snake_case JSON keys of the Ollama API onto
        /// the camelCase property names above.
        enum CodingKeys: String, CodingKey {
            case message
            case model
            case done
            case totalDuration = "total_duration"
            case loadDuration = "load_duration"
            case promptEvalDuration = "prompt_eval_duration"
            case evalDuration = "eval_duration"
            case promptEvalCount = "prompt_eval_count"
            case evalCount = "eval_count"
        }
    }
}
Oops, something went wrong.