Skip to content

Commit

Permalink
Refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
vmanot committed Mar 15, 2024
1 parent f0f1503 commit 61211f9
Show file tree
Hide file tree
Showing 26 changed files with 310 additions and 204 deletions.
4 changes: 1 addition & 3 deletions Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@ let package = Package(
name: "AI",
platforms: [
.iOS(.v16),
.macOS(.v13),
.tvOS(.v16),
.watchOS(.v9)
.macOS(.v13)
],
products: [
.library(
Expand Down
111 changes: 69 additions & 42 deletions Sources/Anthropic/Intramodular/Anthropic+LLMRequestHandling.swift
Original file line number Diff line number Diff line change
Expand Up @@ -23,24 +23,21 @@ extension Anthropic: LLMRequestHandling {

public func complete<Prompt: AbstractLLM.Prompt>(
prompt: Prompt,
parameters: Prompt.CompletionParameters,
heuristics: AbstractLLM.CompletionHeuristics
parameters: Prompt.CompletionParameters
) async throws -> Prompt.Completion {
let _completion: Any

switch prompt {
case let prompt as AbstractLLM.TextPrompt:
_completion = try await _complete(
prompt: prompt,
parameters: try cast(parameters),
heuristics: heuristics
parameters: try cast(parameters)
)

case let prompt as AbstractLLM.ChatPrompt:
_completion = try await _complete(
prompt: prompt,
parameters: try cast(parameters),
heuristics: heuristics
parameters: try cast(parameters)
)
default:
throw LLMRequestHandlingError.unsupportedPromptType(Prompt.self)
Expand All @@ -51,8 +48,7 @@ extension Anthropic: LLMRequestHandling {

private func _complete(
prompt: AbstractLLM.TextPrompt,
parameters: AbstractLLM.TextCompletionParameters,
heuristics: AbstractLLM.CompletionHeuristics
parameters: AbstractLLM.TextCompletionParameters
) async throws -> AbstractLLM.TextCompletion {
let response = try await run(
\.complete,
Expand All @@ -74,10 +70,9 @@ extension Anthropic: LLMRequestHandling {
)
}

private func _complete(
private func _completeUsingTextPrompt(
prompt: AbstractLLM.ChatPrompt,
parameters: AbstractLLM.ChatCompletionParameters,
heuristics: AbstractLLM.CompletionHeuristics
parameters: AbstractLLM.ChatCompletionParameters
) async throws -> AbstractLLM.ChatCompletion {
let completion = try await _complete(
prompt: AbstractLLM.TextPrompt(
Expand All @@ -86,8 +81,7 @@ extension Anthropic: LLMRequestHandling {
parameters: .init(
tokenLimit: parameters.tokenLimit ?? .max,
stops: parameters.stops
),
heuristics: heuristics
)
)

let isAssistantReply = (prompt.messages.last?.role ?? .user) == .user
Expand All @@ -99,11 +93,33 @@ extension Anthropic: LLMRequestHandling {
)

return AbstractLLM.ChatCompletion(
prompt: prompt.messages,
message: message,
stopReason: .init() // FIXME!!!
stopReason: .init() // FIXME: !!!
)
}

/// Completes a chat prompt with a single (non-streaming) call to Anthropic's
/// Messages API and converts the response into an `AbstractLLM.ChatCompletion`.
///
/// - Parameters:
///   - prompt: The chat prompt whose messages are sent to the API.
///   - parameters: Completion parameters (token limit, stops, temperature, …).
/// - Throws: Any error from the network call or from converting the response.
private func _complete(
prompt: AbstractLLM.ChatPrompt,
parameters: AbstractLLM.ChatCompletionParameters
) async throws -> AbstractLLM.ChatCompletion {
// Build the request body with `stream: false` and hit the `createMessage` endpoint.
let response = try await run(
\.createMessage,
with: createMessageRequestBody(from: prompt, parameters: parameters, stream: false)
)

// Flatten the response's content blocks into a single Anthropic chat message.
let message = try Anthropic.ChatMessage(
role: response.role,
content: response.content
)

// Convert the provider-specific message and stop reason into the
// provider-agnostic `AbstractLLM` representation.
return try AbstractLLM.ChatCompletion(
prompt: prompt.messages,
message: message.__conversion(),
stopReason: response.stopReason?.__conversion()
)
}

public func completion(
for prompt: AbstractLLM.ChatPrompt
) throws -> AbstractLLM.ChatCompletionStream {
Expand All @@ -115,31 +131,7 @@ extension Anthropic: LLMRequestHandling {
private func _completion(
for prompt: AbstractLLM.ChatPrompt
) async throws -> AsyncThrowingStream<AbstractLLM.ChatCompletionStream.Event, Error> {
var prompt = prompt
let parameters: AbstractLLM.ChatCompletionParameters? = try cast(prompt.context.completionParameters)
let model: Anthropic.Model = try await _model(for: prompt)

/// Anthropic doesn't support a `system` role.
let system: String? = try prompt.messages
.removeFirst(byUnwrapping: { $0.role == .system ? $0.content : nil })
.map({ try $0._stripToText() })

let messages = try prompt.messages.map { (message: AbstractLLM.ChatMessage) in
try Anthropic.ChatMessage(from: message)
}

let requestBody = Anthropic.API.RequestBodies.CreateMessage(
model: model,
messages: messages,
system: system,
maxTokens: parameters?.tokenLimit?.fixedValue ?? 4000, // FIXME: Hardcoded,
temperature: parameters?.temperatureOrTopP?.temperature,
topP: parameters?.temperatureOrTopP?.topProbabilityMass,
topK: nil,
stopSequences: parameters?.stops,
stream: true,
metadata: nil
)
let requestBody = try await createMessageRequestBody(from: prompt, stream: true)

let request = try HTTPRequest(url: "https://api.anthropic.com/v1/messages")
.jsonBody(requestBody, keyEncodingStrategy: .convertToSnakeCase)
Expand Down Expand Up @@ -169,7 +161,8 @@ extension Anthropic: LLMRequestHandling {
let rest = line.index(line.startIndex, offsetBy: 6)
let data: Data = line[rest...].data(using: .utf8)!

let response = try JSONDecoder().decode(Anthropic.API.ResponseBodies.CreateMessageStream.self, from: data)
let decoder = JSONDecoder(keyDecodingStrategy: .convertFromSnakeCase)
let response = try decoder.decode(Anthropic.API.ResponseBodies.CreateMessageStream.self, from: data)

if
let content: Anthropic.API.ResponseBodies.CreateMessageStream.Delta = response.delta,
Expand Down Expand Up @@ -201,15 +194,49 @@ extension Anthropic: LLMRequestHandling {
return result
}

func _model(
/// Builds a `CreateMessage` request body from an abstract chat prompt.
///
/// Shared by the streaming and non-streaming completion paths.
///
/// - Parameters:
///   - prompt: The chat prompt; a leading `.system` message is extracted here.
///   - parameters: Optional completion parameters; when `nil`, they are read
///     from `prompt.context.completionParameters`.
///   - stream: Whether the request should ask the API for a streamed response.
private func createMessageRequestBody(
from prompt: AbstractLLM.ChatPrompt,
parameters: AbstractLLM.ChatCompletionParameters? = nil,
stream: Bool
) async throws -> Anthropic.API.RequestBodies.CreateMessage {
var prompt = prompt
let parameters: AbstractLLM.ChatCompletionParameters? = try parameters ?? cast(prompt.context.completionParameters)
let model: Anthropic.Model = try await _model(for: prompt)

// Anthropic doesn't support a `system` role in `messages`; pull the system
// message out of the prompt and pass it via the dedicated `system` field.
let system: String? = try prompt.messages
.removeFirst(byUnwrapping: { $0.role == .system ? $0.content : nil })
.map({ try $0._stripToText() })

let messages = try prompt.messages.map { (message: AbstractLLM.ChatMessage) in
try Anthropic.ChatMessage(from: message)
}

let requestBody = Anthropic.API.RequestBodies.CreateMessage(
model: model,
messages: messages,
system: system,
maxTokens: parameters?.tokenLimit?.fixedValue ?? 4000, // FIXME: Hardcoded fallback token limit
temperature: parameters?.temperatureOrTopP?.temperature,
topP: parameters?.temperatureOrTopP?.topProbabilityMass,
topK: nil,
stopSequences: parameters?.stops,
stream: stream,
metadata: nil
)

return requestBody
}

private func _model(
for prompt: any AbstractLLM.Prompt
) async throws -> Anthropic.Model {
do {
guard let modelIdentifierScope: _MLModelIdentifierScope = prompt.context.get(\.modelIdentifier) else {
return Anthropic.Model.claude_3_opus_20240229
}

let modelIdentifier: _MLModelIdentifier = try modelIdentifierScope._oneValue.unwrap()
let modelIdentifier: _MLModelIdentifier = try modelIdentifierScope._oneValue

return try Anthropic.Model(from: modelIdentifier)
} catch {
Expand Down
39 changes: 28 additions & 11 deletions Sources/Anthropic/Intramodular/Anthropic.API.swift
Original file line number Diff line number Diff line change
Expand Up @@ -223,8 +223,8 @@ extension Anthropic.API.RequestBodies {
extension Anthropic.API.ResponseBodies {
public struct Complete: Codable, Hashable, Sendable {
public enum StopReason: String, Codable, Hashable, Sendable {
case stopSequence = "stop_sequence"
case maxTokens = "max_tokens"
case stopSequence = "stop_sequence"
}

public var completion: String
Expand All @@ -239,31 +239,48 @@ extension Anthropic.API.ResponseBodies {
case type
case role
case content
case stopReason = "stop_reason"
case stopSequence = "stop_sequence"
case stopReason
case stopSequence
case usage
}

/// Why the model stopped generating, as reported by the Messages API.
public enum StopReason: String, Codable, Hashable, Sendable {
case endTurn = "end_turn"
case maxTokens = "max_tokens"
case stopSequence = "stop_sequence"

/// Maps the API's stop reason onto the provider-agnostic
/// `AbstractLLM.ChatCompletion.StopReason` (a one-to-one mapping).
public func __conversion() -> AbstractLLM.ChatCompletion.StopReason {
switch self {
case .endTurn:
return .endTurn
case .maxTokens:
return .maxTokens
case .stopSequence:
return .stopSequence
}
}
}

public let id: String
public let model: Anthropic.Model
public let type: String?
public let role: Anthropic.ChatMessage.Role
public let content: [Content]
public let stopReason: String?
public let stopReason: StopReason?
public let stopSequence: String?
public let usage: Usage

/// The kind of a content block in a Messages API response.
public enum ContentType: String, Codable, Hashable, Sendable {
case image // FIXME: Unimplemented
case text
}

/// A single content block in a Messages API response.
///
/// Note: `text` is decoded unconditionally; for non-`.text` blocks its
/// meaning is undefined (image content is not yet implemented).
public struct Content: Codable, Hashable, Sendable {
// Typed content kind (replaces the earlier raw `String` field).
public let type: ContentType
public let text: String
}

/// Token accounting reported by the Messages API.
public struct Usage: Codable, Hashable, Sendable {
public enum CodingKeys: String, CodingKey {
// Raw values removed for consistency with the enclosing `CreateMessage`'s
// `CodingKeys`: keys arrive as `input_tokens`/`output_tokens` and are mapped
// to camelCase by the `.convertFromSnakeCase` decoding strategy. Keeping the
// explicit snake_case raw values here would break decoding under that
// strategy (the strategy transforms the key *before* matching raw values).
case inputTokens
case outputTokens
}

public let inputTokens: Int
public let outputTokens: Int
}
Expand All @@ -275,7 +292,7 @@ extension Anthropic.API.ResponseBodies {
case index
case message
case delta
case contentBlock = "content_block"
case contentBlock
}

public struct Delta: Codable, Hashable, Sendable {
Expand Down
10 changes: 10 additions & 0 deletions Sources/Anthropic/Intramodular/Anthropic.ChatMessage.swift
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,16 @@ extension Anthropic {
}

extension Anthropic.ChatMessage {
/// Creates a chat message by flattening the API response's content blocks
/// into a single text string.
///
/// - Parameters:
///   - role: The role reported by the API for this message.
///   - content: The response's content blocks; all must be `.text`.
/// - Throws: `Never.Reason.illegal` if any block is not text. (Previously this
///   was only an `assert`, so release builds silently concatenated the `text`
///   of unsupported blocks even though the initializer is declared `throws`.)
public init(
role: Role,
content: [Anthropic.API.ResponseBodies.CreateMessage.Content]
) throws {
guard content.allSatisfy({ $0.type == .text }) else {
throw Never.Reason.illegal
}

self.role = role
self.content = content.map({ $0.text }).joined()
}

public init(
from message: AbstractLLM.ChatMessage
) throws {
Expand Down
5 changes: 5 additions & 0 deletions Sources/Anthropic/Intramodular/Anthropic.Model.swift
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ extension Anthropic {
case claude_v1_2 = "claude-v1.2"
case claude_v1_3 = "claude-v1.3"

case claude_3_haiku_20240307 = "claude-3-haiku-20240307"
case claude_3_sonnet_20240229 = "claude-3-sonnet-20240229"
case claude_3_opus_20240229 = "claude-3-opus-20240229"

Expand Down Expand Up @@ -54,6 +55,8 @@ extension Anthropic {
return "Claude 1.2"
case .claude_v1_3:
return "Claude 1.3"
case .claude_3_haiku_20240307:
return "Claude 3 Haiku"
case .claude_3_sonnet_20240229:
return "Claude 3 Sonnet"
case .claude_3_opus_20240229:
Expand All @@ -63,6 +66,8 @@ extension Anthropic {

public var contextSize: Int? {
switch self {
case .claude_3_haiku_20240307:
return 200000
case .claude_3_sonnet_20240229:
return 200000
case .claude_3_opus_20240229:
Expand Down
2 changes: 2 additions & 0 deletions Sources/Anthropic/Intramodular/Anthropic.swift
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ public final class Anthropic: HTTPClient, PersistentlyRepresentableType, _Static
/// Creates a client from an API interface and an HTTP session.
///
/// - NOTE(review): timeouts are disabled on the session — presumably because
///   long model completions can exceed the default request timeout; confirm.
public init(interface: API, session: HTTPSession) {
self.interface = interface
self.session = session

session.disableTimeouts()
}

public convenience init(apiKey: String?) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ extension _MLModelIdentifier {
case claude_v1_0 = "claude-v1.0"
case claude_v1_2 = "claude-v1.2"
case claude_v1_3 = "claude-v1.3"
case claude_3_haiku_20240307 = "claude-3-haiku-20240307"
case claude_3_sonnet_20240229 = "claude-3-sonnet-20240229"
case claude_3_opus_20240229 = "claude-3-opus-20240229"
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,18 @@ public enum _MLModelIdentifierScope: Codable, Hashable, Sendable {
case one(_MLModelIdentifier)
case choiceOf(Set<_MLModelIdentifier>)

public var _oneValue: _MLModelIdentifier? {
guard case .one(let value) = self else {
if case .choiceOf(let set) = self, let value = try? set.toCollectionOfOne().first {
return value
public var _oneValue: _MLModelIdentifier {
get throws {
guard case .one(let value) = self else {
if case .choiceOf(let set) = self, let value = try set.toCollectionOfOne().first {
return value
}

throw Never.Reason.illegal
}

return nil
return value
}

return value
}

public init(_ identifier: _MLModelIdentifier) {
Expand Down
Loading

0 comments on commit 61211f9

Please sign in to comment.