Skip to content

Commit

Permalink
Added support of custom OpenAI URL
Browse files Browse the repository at this point in the history
  • Loading branch information
kyrylo-mukha committed Nov 1, 2023
1 parent 9431444 commit 74c56d6
Show file tree
Hide file tree
Showing 7 changed files with 34 additions and 28 deletions.
2 changes: 1 addition & 1 deletion OpenAIKit.podspec
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

Pod::Spec.new do |s|
s.name = 'OpenAIKit'
s.version = '1.5.0'
s.version = '1.6.0'
s.summary = 'OpenAI is a community-maintained repository containing Swift implementation over OpenAI public API.'

s.description = <<-DESC
Expand Down
19 changes: 11 additions & 8 deletions Sources/OpenAIKit/Helpers/NetworkRoutes.swift
Original file line number Diff line number Diff line change
Expand Up @@ -21,21 +21,24 @@ enum OpenAIHTTPMethod: String {

typealias OpenAIHeaders = [String: String]

@available(swift 5.5)
@available(macOS 10.15, iOS 13, watchOS 6, tvOS 13, *)
protocol Endpoint {
var route: String { get }
var method: OpenAIHTTPMethod { get }
var baseURL: String { get }
var urlPath: String { get }
func urlPath(for aiKit: OpenAIKit) -> String
}

enum OpenAIEndpoint {
case completions
case chatCompletions
case edits
case dalleImage
case dalleImageEdit
case dalleImageEdit
}

@available(swift 5.5)
@available(macOS 10.15, iOS 13, watchOS 6, tvOS 13, *)
extension OpenAIEndpoint: Endpoint {
var route: String {
switch self {
Expand All @@ -47,8 +50,8 @@ extension OpenAIEndpoint: Endpoint {
return "/v1/edits"
case .dalleImage:
return "/v1/images/generations"
case .dalleImageEdit:
return "/v1/images/edits"
case .dalleImageEdit:
return "/v1/images/edits"
}
}

Expand All @@ -59,14 +62,14 @@ extension OpenAIEndpoint: Endpoint {
}
}

var baseURL: String {
private var baseURL: String {
switch self {
default:
return "https://api.openai.com"
}
}

var urlPath: String {
baseURL + route
func urlPath(for aiKit: OpenAIKit) -> String {
(aiKit.customOpenAIURL ?? baseURL) + route
}
}
23 changes: 13 additions & 10 deletions Sources/OpenAIKit/OpenAIKit.swift
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,13 @@ public final class OpenAIKit {

internal let jsonEncoder = JSONEncoder.aiEncoder

public let customOpenAIURL: String?

/// Initialize `OpenAIKit` with your API Token wherever convenient in your project. Organization name is optional.
public init(apiToken: String, organization: String? = nil, timeoutInterval: TimeInterval = 60) {
public init(apiToken: String, organization: String? = nil, timeoutInterval: TimeInterval = 60, customOpenAIURL: String? = nil) {
self.apiToken = apiToken
self.organization = organization
self.customOpenAIURL = customOpenAIURL

let configuration = URLSessionConfiguration.default
configuration.timeoutIntervalForRequest = timeoutInterval
Expand Down Expand Up @@ -49,17 +52,17 @@ extension OpenAIKit {
return headers
}

var baseMultipartHeaders: OpenAIHeaders {
var headers: OpenAIHeaders = [:]
var baseMultipartHeaders: OpenAIHeaders {
var headers: OpenAIHeaders = [:]

headers["Authorization"] = "Bearer \(apiToken)"
headers["Authorization"] = "Bearer \(apiToken)"

if let organization {
headers["OpenAI-Organization"] = organization
}
if let organization {
headers["OpenAI-Organization"] = organization
}

headers["content-type"] = "multipart/form-data"
headers["content-type"] = "multipart/form-data"

return headers
}
return headers
}
}
6 changes: 3 additions & 3 deletions Sources/OpenAIKit/OpenAIKitRequests/Chat.swift
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ public extension OpenAIKit {

let headers = baseHeaders

network.request(endpoint.method, url: endpoint.urlPath, body: requestData, headers: headers, completion: completion)
network.request(endpoint.method, url: endpoint.urlPath(for: self), body: requestData, headers: headers, completion: completion)
}

@available(swift 5.5)
Expand Down Expand Up @@ -122,7 +122,7 @@ public extension OpenAIKit {

let headers = baseHeaders

network.requestStream(endpoint.method, url: endpoint.urlPath, body: requestData, headers: headers) { (result: Result<AIStreamResponse<AIResponseModel>, Error>) in
network.requestStream(endpoint.method, url: endpoint.urlPath(for: self), body: requestData, headers: headers) { (result: Result<AIStreamResponse<AIResponseModel>, Error>) in
completion(result)
}
}
Expand Down Expand Up @@ -153,6 +153,6 @@ public extension OpenAIKit {

let headers = baseHeaders

return try await network.requestStream(endpoint.method, url: endpoint.urlPath, body: requestData, headers: headers)
return try await network.requestStream(endpoint.method, url: endpoint.urlPath(for: self), body: requestData, headers: headers)
}
}
6 changes: 3 additions & 3 deletions Sources/OpenAIKit/OpenAIKitRequests/Completions.swift
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ public extension OpenAIKit {

let headers = baseHeaders

network.request(endpoint.method, url: endpoint.urlPath, body: requestData, headers: headers, completion: completion)
network.request(endpoint.method, url: endpoint.urlPath(for: self), body: requestData, headers: headers, completion: completion)
}

@available(swift 5.5)
Expand Down Expand Up @@ -110,7 +110,7 @@ public extension OpenAIKit {

let headers = baseHeaders

network.requestStream(endpoint.method, url: endpoint.urlPath, body: requestData, headers: headers) { (result: Result<AIStreamResponse<AIResponseModel>, Error>) in
network.requestStream(endpoint.method, url: endpoint.urlPath(for: self), body: requestData, headers: headers) { (result: Result<AIStreamResponse<AIResponseModel>, Error>) in
completion(result)
}
}
Expand All @@ -137,6 +137,6 @@ public extension OpenAIKit {

let headers = baseHeaders

return try await network.requestStream(endpoint.method, url: endpoint.urlPath, body: requestData, headers: headers)
return try await network.requestStream(endpoint.method, url: endpoint.urlPath(for: self), body: requestData, headers: headers)
}
}
2 changes: 1 addition & 1 deletion Sources/OpenAIKit/OpenAIKitRequests/Edits.swift
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ public extension OpenAIKit {

let headers = baseHeaders

network.request(endpoint.method, url: endpoint.urlPath, body: requestData, headers: headers, completion: completion)
network.request(endpoint.method, url: endpoint.urlPath(for: self), body: requestData, headers: headers, completion: completion)
}

@available(swift 5.5)
Expand Down
4 changes: 2 additions & 2 deletions Sources/OpenAIKit/OpenAIKitRequests/Images.swift
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ public extension OpenAIKit {

let headers = baseHeaders

network.request(endpoint.method, url: endpoint.urlPath, body: requestData, headers: headers, completion: completion)
network.request(endpoint.method, url: endpoint.urlPath(for: self), body: requestData, headers: headers, completion: completion)
}

@available(swift 5.5)
Expand Down Expand Up @@ -77,7 +77,7 @@ public extension OpenAIKit {

let headers = baseMultipartHeaders

network.request(endpoint.method, url: endpoint.urlPath, body: requestData, headers: headers, completion: completion)
network.request(endpoint.method, url: endpoint.urlPath(for: self), body: requestData, headers: headers, completion: completion)
}

@available(swift 5.5)
Expand Down

0 comments on commit 74c56d6

Please sign in to comment.