
Commit 0f902da
Merge pull request #115 from rhx/main
Add Ollama support
buhe authored Apr 21, 2024
2 parents eeff567 + 0547eac commit 0f902da
Showing 5 changed files with 550 additions and 0 deletions.
4 changes: 4 additions & 0 deletions README.md
@@ -38,6 +38,8 @@ OPENWEATHER_API_KEY=xxx
LLAMA2_API_KEY=xxx
GOOGLEAI_API_KEY=xxx
LMSTUDIO_URL=xxx
OLLAMA_URL=xxx
OLLAMA_MODEL=xxx
NOTION_API_KEY=xxx
NOTION_ROOT_NODE_ID=xxx
BILIBILI_SESSION=xxx
@@ -405,6 +407,7 @@ Task(priority: .background) {
- [x] Llama 2
- [x] Gemini
- [x] LMStudio API
- [x] Ollama API
- [x] Local Model
- Vectorstore
- [x] Supabase
@@ -415,6 +418,7 @@ Task(priority: .background) {
- [x] FileStore
- Embedding
- [x] OpenAI
- [x] Ollama
- [ ] Distilbert
- Chain
- [x] Base
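For a typical local setup, the two new variables might look like this (illustrative values only: 11434 is Ollama's default port, and the model name depends on what you have pulled locally; check the Ollama implementation below for the exact URL format it expects):

OLLAMA_URL=http://localhost:11434
OLLAMA_MODEL=llama3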
46 changes: 46 additions & 0 deletions Sources/LangChain/embeddings/OllamaEmbeddings.swift
@@ -0,0 +1,46 @@
//
// OllamaEmbeddings.swift
//
// Created by Rene Hexel on 20/4/2024.
//
import Foundation
import AsyncHTTPClient

extension Ollama: Embeddings {
    /// Ollama embedding request.
    struct EmbeddingRequest: Codable {
        let model: String
        let prompt: String
    }
    /// Ollama embedding response structure.
    struct Embedding: Codable {
        let embedding: [Float]
    }
    /// Create embeddings for a given text.
    ///
    /// This function sends the given text to the Ollama API and returns the resulting embeddings.
    ///
    /// - Parameter text: The text to create embeddings for.
    /// - Returns: The embedding vector for the given text, or an empty array if the request fails.
    public func embedQuery(text: String) async -> [Float] {
        do {
            return try await getEmbeddings(for: text)
        } catch {
            return []
        }
    }
    /// Get the embedding vector for a given text.
    ///
    /// This function sends the given text to the Ollama API and returns the resulting embedding vector.
    ///
    /// - Parameter text: The text to create the embedding vector for.
    /// - Returns: The embedding vector for the given text.
    public func getEmbeddings(for text: String) async throws -> [Float] {
        let embeddingRequest = EmbeddingRequest(model: model, prompt: text)
        guard let data = try await sendJSON(request: embeddingRequest, endpoint: "embeddings") else {
            return []
        }
        let apiResponse = try JSONDecoder().decode(Embedding.self, from: data)
        return apiResponse.embedding
    }
}
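A minimal usage sketch for the new `Embeddings` conformance (the base URL and model name are illustrative assumptions; the `Ollama` initialiser parameters are taken from the override declared in ChatOllama.swift below):

// Inside an async context:
let ollama = Ollama(baseURL: "http://localhost:11434", model: "nomic-embed-text")
// embedQuery never throws; it returns an empty array on failure.
let vector = await ollama.embedQuery(text: "Swift is a general-purpose programming language.")
print("Embedding dimension: \(vector.count)")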
100 changes: 100 additions & 0 deletions Sources/LangChain/llms/ChatOllama.swift
@@ -0,0 +1,100 @@
//
// ChatOllama.swift
//
// Created by Rene Hexel on 21/4/2024.
//

import Foundation
import OpenAIKit

/// Ollama class for chat functionality.
///
/// This class interfaces with the Ollama chat API.
public class ChatOllama: Ollama {
    /// The chat history.
    ///
    /// This array contains the chat history
    /// of the conversation so far.
    var history = [ChatGLMMessage]()

    /// Create a new Ollama chat instance.
    ///
    /// This initialiser creates a new Ollama chat instance with the given parameters.
    ///
    /// - Parameters:
    ///   - baseURL: The base URL for the Ollama API.
    ///   - model: The model to use for the chat instance.
    ///   - options: Additional options for the chat instance.
    ///   - timeout: The request timeout in seconds.
    ///   - callbacks: The callback handlers to use.
    ///   - cache: The cache to use.
    public override init(baseURL: String? = nil, model: String? = nil, options: [String : String]? = nil, timeout: Int = 3600, callbacks: [BaseCallbackHandler] = [], cache: BaseCache? = nil) {
        super.init(baseURL: baseURL, model: model, options: options, timeout: timeout, callbacks: callbacks, cache: cache)
    }

    /// Send text to the Ollama API.
    ///
    /// This function implements the main interaction with the Ollama API
    /// through its `chat` endpoint.
    ///
    /// - Parameters:
    ///   - text: The text to send to the Ollama API.
    ///   - stops: An array of strings that, if present in the response, will stop the generation.
    /// - Returns: An `LLMResult` containing the model's reply.
    public override func _send(text: String, stops: [String] = []) async throws -> LLMResult {
        let message = ChatGLMMessage(role: "user", content: text)
        history.append(message)
        let chatRequest = ChatRequest(model: model, options: modelOptions, format: "json", stream: false, messages: history)
        guard let data = try await sendJSON(request: chatRequest, endpoint: "chat") else {
            return LLMResult()
        }
        let llmResponse = try JSONDecoder().decode(ChatResponse.self, from: data)
        history.append(llmResponse.message)
        return LLMResult(llm_output: llmResponse.message.content)
    }
}

public extension ChatOllama {
    /// Generate the next message in a chat with a provided model.
    ///
    /// This is a streaming endpoint, so there can be a series of responses.
    /// Streaming can be disabled using `"stream": false`.
    struct ChatRequest: Codable, Sendable {
        let model: String
        let options: [String: String]?
        let format: String
        let stream: Bool
        let messages: [ChatGLMMessage]
    }
    /// Ollama response to a `ChatRequest`.
    ///
    /// This response object includes the next message in a chat conversation.
    /// The final response object will include statistics and additional data from the request.
    struct ChatResponse: Codable, Sendable {
        let message: ChatGLMMessage
        let model: String
        let done: Bool
        let totalDuration: Int?
        let loadDuration: Int?
        let promptEvalDuration: Int?
        let evalDuration: Int?
        let promptEvalCount: Int?
        let evalCount: Int?

        /// Return the message content.
        public var content: String { message.content }

        /// JSON coding keys for the `ChatResponse` struct.
        enum CodingKeys: String, CodingKey {
            case message
            case model
            case done
            case totalDuration = "total_duration"
            case loadDuration = "load_duration"
            case promptEvalDuration = "prompt_eval_duration"
            case evalDuration = "eval_duration"
            case promptEvalCount = "prompt_eval_count"
            case evalCount = "eval_count"
        }
    }
}
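A corresponding chat usage sketch (model name and prompt are illustrative; `_send` is the override shown above, though in practice it would typically be invoked through the library's higher-level send APIs):

// Inside an async, throwing context:
let chat = ChatOllama(baseURL: "http://localhost:11434", model: "llama3")
let reply = try await chat._send(text: "Why is the sky blue?")
print(reply.llm_output ?? "")
// The response is appended to `history`, so a follow-up call continues the conversation.
let followUp = try await chat._send(text: "Summarise that in one sentence.")
print(followUp.llm_output ?? "")

The snake_case `CodingKeys` above mirror the field names in Ollama's JSON chat responses (e.g. `total_duration`), so the timing statistics decode without any custom `JSONDecoder` configuration.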
