Stream syntax and background performance update. Update stream error log
kyrylo-mukha committed Apr 6, 2023
1 parent 0820a99 commit aa0605b
Showing 5 changed files with 86 additions and 35 deletions.
2 changes: 1 addition & 1 deletion OpenAIKit.podspec
@@ -8,7 +8,7 @@

Pod::Spec.new do |s|
s.name = 'OpenAIKit'
- s.version = '1.3.2'
+ s.version = '1.3.3'
s.summary = 'OpenAI is a community-maintained repository containing Swift implementation over OpenAI public API.'

s.description = <<-DESC
14 changes: 12 additions & 2 deletions Sources/OpenAIKit/Helpers/AIEventStream.swift
@@ -31,6 +31,8 @@ public final class AIEventStream<ResponseType: Decodable>: NSObject, URLSessionD

private var isStreamActive: Bool = false

private var fetchError: Error? = nil

init(request: URLRequest) {
self.request = request
self.operationQueue = OperationQueue()
@@ -73,6 +75,14 @@ public final class AIEventStream<ResponseType: Decodable>: NSObject, URLSessionD
public func urlSession(_ session: URLSession, dataTask: URLSessionDataTask, didReceive data: Data) {
guard isStreamActive else { return }

if let response = (dataTask.response as? HTTPURLResponse), let decodedError = try? JSONSerialization.jsonObject(with: data) as? [String: Any] {
guard 200 ... 299 ~= response.statusCode, decodedError["error"] != nil else {
let error = NSError(domain: NSURLErrorDomain, code: response.statusCode, userInfo: decodedError)
fetchError = error
return
}
}

let decoder = JSONDecoder.aiDecoder

let dataString = String(data: data, encoding: .utf8) ?? ""
@@ -91,11 +101,11 @@ public final class AIEventStream<ResponseType: Decodable>: NSObject, URLSessionD

public func urlSession(_ session: URLSession, task: URLSessionTask, didCompleteWithError error: Error?) {
guard let responseStatusCode = (task.response as? HTTPURLResponse)?.statusCode else {
- try? onCompleteCompletion?(nil, false, error as NSError?)
+ try? onCompleteCompletion?(nil, false, error ?? fetchError)
return
}

- try? onCompleteCompletion?(responseStatusCode, false, nil)
+ try? onCompleteCompletion?(responseStatusCode, false, error ?? fetchError)
}

public func urlSession(_ session: URLSession, task: URLSessionTask, willPerformHTTPRedirection response: HTTPURLResponse, newRequest request: URLRequest, completionHandler: @escaping (URLRequest?) -> Void) {
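With the change above, a streaming response that returns a non-2xx status and a body that parses as a JSON object is no longer dropped: it is stored in fetchError and handed to the completion in urlSession(_:task:didCompleteWithError:). A rough caller-side sketch of what that looks like; only init(request:), onMessage, onComplete and startStream come from the code above, while the ChatChunk type, URL and request setup are illustrative assumptions:

```swift
import Foundation

// Placeholder chunk type; the real ResponseType is whatever the caller decodes.
struct ChatChunk: Decodable { let id: String }

// Assumed endpoint and request setup, purely illustrative.
var request = URLRequest(url: URL(string: "https://api.openai.com/v1/chat/completions")!)
request.httpMethod = "POST"

let stream = AIEventStream<ChatChunk>(request: request)

stream.onMessage { _, message in
    // Fires for every decoded server-sent chunk while the stream is active.
    print("chunk:", message as Any)
}

stream.onComplete { _, _, error in
    // After this commit, a non-2xx response whose body parses as a JSON object arrives
    // here as an NSError (domain NSURLErrorDomain, code = HTTP status, userInfo = decoded body)
    // instead of being silently ignored.
    if let error { print("stream failed:", error) }
}

stream.startStream()
```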
46 changes: 46 additions & 0 deletions Sources/OpenAIKit/Helpers/OpenAIKitNetwork.swift
@@ -77,6 +77,47 @@ public final class OpenAIKitNetwork {
task.resume()
}

fileprivate struct StreamTaskState {
var isStreamFinished = false
var isStreamForceStop = false
}

func requestStream<ResponseType: Decodable>(_ method: OpenAIHTTPMethod, url: String, body: Data? = nil, headers: OpenAIHeaders? = nil, completion: @escaping (Result<AIStreamResponse<ResponseType>, Error>) -> Void) {
guard let url = URL(string: url) else {
completion(.failure(OpenAINetworkError.invalidURL))
return
}

var request = URLRequest(url: url)
request.httpMethod = method.rawValue
request.httpBody = body

headers?.forEach { key, value in
request.addValue(value, forHTTPHeaderField: key)
}

let stream = AIEventStream<ResponseType>(request: request)
var streamState = StreamTaskState()

stream.onMessage { data, message in
completion(.success(AIStreamResponse(stream: stream, message: message, data: data, isFinished: streamState.isStreamFinished, forceEnd: streamState.isStreamForceStop)))
}

stream.onComplete { _, forceEnd, error in
if let error {
completion(.failure(error))
return
}

streamState.isStreamFinished = true
streamState.isStreamForceStop = forceEnd

completion(.success(AIStreamResponse(stream: stream, message: nil, data: nil, isFinished: streamState.isStreamFinished, forceEnd: streamState.isStreamForceStop)))
}

stream.startStream()
}

func requestStream<ResponseType: Decodable>(_ method: OpenAIHTTPMethod, url: String, body: Data? = nil, headers: OpenAIHeaders? = nil) async throws -> AsyncThrowingStream<AIStreamResponse<ResponseType>, Error> {
guard let url = URL(string: url) else {
throw OpenAINetworkError.invalidURL
@@ -94,11 +135,16 @@

return AsyncThrowingStream<AIStreamResponse<ResponseType>, Error> { continuation in
Task(priority: .userInitiated) {
var streamState = StreamTaskState()

stream.onMessage { data, message in
continuation.yield(AIStreamResponse(stream: stream, message: message, data: data))
}

stream.onComplete { _, forceEnd, error in
streamState.isStreamFinished = true
streamState.isStreamForceStop = forceEnd

if let error { throw error }

continuation.yield(AIStreamResponse(stream: stream, message: nil, data: nil, isFinished: true, forceEnd: forceEnd))
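The new completion-based requestStream(_:url:body:headers:completion:) mirrors the async variant: the completion closure is called once per decoded message and once more when the stream finishes, fails, or is force-stopped. A hedged in-module sketch, assuming OpenAIHTTPMethod has a .post case and that AIStreamResponse exposes message, isFinished and forceEnd as properties; network, requestData and headers stand in for values the caller already has:

```swift
network.requestStream(.post, url: "https://api.openai.com/v1/completions", body: requestData, headers: headers) { (result: Result<AIStreamResponse<AIResponseModel>, Error>) in
    switch result {
    case .success(let chunk) where chunk.isFinished:
        // Final call for this stream; no message is attached.
        print("stream finished, force-stopped:", chunk.forceEnd)
    case .success(let chunk):
        // One call per decoded chunk while the stream is open.
        print("partial message:", chunk.message as Any)
    case .failure(let error):
        // Transport or API error reported on completion.
        print("stream error:", error)
    }
}
```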
31 changes: 15 additions & 16 deletions Sources/OpenAIKit/OpenAIKitRequests/Chat.swift
@@ -122,16 +122,8 @@ public extension OpenAIKit {

let headers = baseHeaders

- Task(priority: .userInitiated) {
- do {
- let asyncStream: AsyncThrowingStream<AIStreamResponse<AIResponseModel>, Error> = try await network.requestStream(endpoint.method, url: endpoint.urlPath, body: requestData, headers: headers)
-
- for try await result in asyncStream {
- completion(.success(result))
- }
- } catch {
- completion(.failure(error))
- }
+ network.requestStream(endpoint.method, url: endpoint.urlPath, body: requestData, headers: headers) { (result: Result<AIStreamResponse<AIResponseModel>, Error>) in
+ completion(result)
}
}

@@ -148,12 +140,19 @@
presencePenalty: Double? = nil,
logprobs: Int? = nil,
stop: [String]? = nil,
- user: String? = nil) async -> Result<AIStreamResponse<AIResponseModel>, Error>
+ user: String? = nil) async throws -> AsyncThrowingStream<AIStreamResponse<AIResponseModel>, Error>
{
- return await withCheckedContinuation { continuation in
- sendStreamChatCompletion(newMessage: newMessage, previousMessages: previousMessages, model: model, maxTokens: maxTokens, temperature: temperature, n: n, topP: topP, frequencyPenalty: frequencyPenalty, presencePenalty: presencePenalty, logprobs: logprobs, stop: stop, user: user) { result in
- continuation.resume(returning: result)
- }
- }
+ let endpoint = OpenAIEndpoint.chatCompletions
+
+ var messages = previousMessages
+ messages.append(newMessage)
+
+ let requestBody = ChatCompletionsRequest(model: model, messages: messages, temperature: temperature, n: n, maxTokens: maxTokens, topP: topP, frequencyPenalty: frequencyPenalty, presencePenalty: presencePenalty, logprobs: logprobs, stop: stop, user: user, stream: true)
+
+ let requestData = try? jsonEncoder.encode(requestBody)
+
+ let headers = baseHeaders
+
+ return try await network.requestStream(endpoint.method, url: endpoint.urlPath, body: requestData, headers: headers)
}
}
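Call sites of the async variant change accordingly: instead of awaiting a single Result, they iterate the returned AsyncThrowingStream. A hedged consumer sketch; only the AsyncThrowingStream<AIStreamResponse<AIResponseModel>, Error> return type comes from the diff, and the isFinished/message properties are assumed from the initializer arguments used above:

```swift
// Obtaining the stream; `openAI` is a placeholder OpenAIKit instance and parameters
// other than those shown keep whatever values the caller used before:
// let stream = try await openAI.sendStreamChatCompletion(newMessage: message, previousMessages: history, model: model)

func consumeChatStream(_ stream: AsyncThrowingStream<AIStreamResponse<AIResponseModel>, Error>) async {
    do {
        for try await chunk in stream {
            if chunk.isFinished { break }    // the final element carries no message
            print(chunk.message as Any)      // partial AIResponseModel for this chunk
        }
    } catch {
        print("chat stream failed:", error)  // any error thrown while iterating
    }
}
```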
28 changes: 12 additions & 16 deletions Sources/OpenAIKit/OpenAIKitRequests/Completions.swift
@@ -110,16 +110,8 @@ public extension OpenAIKit {

let headers = baseHeaders

- Task(priority: .userInitiated) {
- do {
- let asyncStream: AsyncThrowingStream<AIStreamResponse<AIResponseModel>, Error> = try await network.requestStream(endpoint.method, url: endpoint.urlPath, body: requestData, headers: headers)
-
- for try await result in asyncStream {
- completion(.success(result))
- }
- } catch {
- completion(.failure(error))
- }
+ network.requestStream(endpoint.method, url: endpoint.urlPath, body: requestData, headers: headers) { (result: Result<AIStreamResponse<AIResponseModel>, Error>) in
+ completion(result)
}
}

@@ -135,12 +127,16 @@
presencePenalty: Double? = nil,
logprobs: Int? = nil,
stop: [String]? = nil,
- user: String? = nil) async -> Result<AIStreamResponse<AIResponseModel>, Error>
+ user: String? = nil) async throws -> AsyncThrowingStream<AIStreamResponse<AIResponseModel>, Error>
{
- return await withCheckedContinuation { continuation in
- sendStreamCompletion(prompt: prompt, model: model, maxTokens: maxTokens, temperature: temperature, n: n, topP: topP, frequencyPenalty: frequencyPenalty, presencePenalty: presencePenalty, logprobs: logprobs, stop: stop, user: user) { result in
- continuation.resume(returning: result)
- }
- }
+ let endpoint = OpenAIEndpoint.completions
+
+ let requestBody = CompletionsRequest(model: model, prompt: prompt, temperature: temperature, n: n, maxTokens: maxTokens, topP: topP, frequencyPenalty: frequencyPenalty, presencePenalty: presencePenalty, logprobs: logprobs, stop: stop, user: user, stream: true)
+
+ let requestData = try? jsonEncoder.encode(requestBody)
+
+ let headers = baseHeaders
+
+ return try await network.requestStream(endpoint.method, url: endpoint.urlPath, body: requestData, headers: headers)
}
}
