Skip to content

Commit

Permalink
Merge pull request #4 from PatilShreyas/v0.2.0-kmp
Browse files Browse the repository at this point in the history
Sync with google/generative-ai-android v0.2.0
  • Loading branch information
PatilShreyas authored Feb 18, 2024
2 parents d3984b8 + f2c73e5 commit 1d56070
Show file tree
Hide file tree
Showing 11 changed files with 186 additions and 612 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ generate text from text-and-image input:

```kotlin
val generativeModel = GenerativeModel(
modelName = "gemini-pro-vision", // or "gemini-pro" for text-only input
modelName = "gemini-1.0-pro-vision-latest",
apiKey = "YOUR_API_KEY"
)

Expand Down Expand Up @@ -59,7 +59,7 @@ The versioning scheme is of the form `X-Y` where:

X is the _Generative AI Android SDK_ version that is being tracked.
Y is the _Multiplatform SDK_ version.
For example, if _Generative AI Android SDK_ is on `0.1.2` and _Multiplatform SDK_ is on `0.0.1`, the artifact for a release will be `dev.shreyaspatil.generativeai:generativeai-google:0.1.2-0.0.1`.
For example, if _Generative AI Android SDK_ is on `0.2.0` and _Multiplatform SDK_ is on `1.0.0`, the artifact for a release will be `dev.shreyaspatil.generativeai:generativeai-google:0.2.0-1.0.0`.

## Try sample app

Expand Down
292 changes: 0 additions & 292 deletions api/0.1.1.api

This file was deleted.

289 changes: 0 additions & 289 deletions api/0.1.2.api

This file was deleted.

2 changes: 1 addition & 1 deletion generativeai/gradle.properties
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
SONATYPE_HOST=DEFAULT
RELEASE_SIGNING_ENABLED=true
GROUP=dev.shreyaspatil.generativeai
VERSION_NAME=0.1.2-0.0.1
VERSION_NAME=0.2.0-1.0.0

POM_ARTIFACT_ID=generativeai-google
POM_NAME=Google Generative AI
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import dev.shreyaspatil.ai.client.generativeai.type.GenerationConfig
import dev.shreyaspatil.ai.client.generativeai.type.GoogleGenerativeAIException
import dev.shreyaspatil.ai.client.generativeai.type.PlatformImage
import dev.shreyaspatil.ai.client.generativeai.type.PromptBlockedException
import dev.shreyaspatil.ai.client.generativeai.type.RequestOptions
import dev.shreyaspatil.ai.client.generativeai.type.ResponseStoppedException
import dev.shreyaspatil.ai.client.generativeai.type.SafetySetting
import dev.shreyaspatil.ai.client.generativeai.type.SerializationException
Expand All @@ -52,6 +53,7 @@ internal constructor(
val apiKey: String,
val generationConfig: GenerationConfig? = null,
val safetySettings: List<SafetySetting>? = null,
val requestOptions: RequestOptions = RequestOptions(),
private val controller: APIController,
) {

Expand All @@ -61,7 +63,15 @@ internal constructor(
apiKey: String,
generationConfig: GenerationConfig? = null,
safetySettings: List<SafetySetting>? = null,
) : this(modelName, apiKey, generationConfig, safetySettings, APIController(apiKey, modelName))
requestOptions: RequestOptions = RequestOptions(),
) : this(
modelName,
apiKey,
generationConfig,
safetySettings,
requestOptions,
APIController(apiKey, modelName, requestOptions.apiVersion, requestOptions.timeout),
)

/**
* Generates a response from the backend with the provided [Content]s.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.channelFlow
import kotlinx.coroutines.launch
import kotlinx.serialization.json.Json
import kotlin.time.Duration

// TODO: Should these stay here or be moved elsewhere?
internal const val DOMAIN = "https://generativelanguage.googleapis.com/v1"
internal const val DOMAIN = "https://generativelanguage.googleapis.com"

internal val JSON = Json {
ignoreUnknownKeys = true
Expand All @@ -58,34 +58,38 @@ internal val JSON = Json {
* Exposed primarily for DI in tests.
* @property key The API key used for authentication.
* @property model The model to use for generation.
* @property apiVersion the endpoint version to communicate with.
* @property timeout the maximum amount of time for a request to take in the initial exchange.
*/
internal class APIController(
private val key: String,
model: String,
engine: HttpClientEngine? = null,
private val apiVersion: String,
private val timeout: Duration,
httpEngine: HttpClientEngine? = null,
) {
private val model = fullModelName(model)
private val client = getHttpClient(engine)
private val client = getHttpClient(httpEngine, timeout)

suspend fun generateContent(request: GenerateContentRequest): GenerateContentResponse {
return client
.post("$DOMAIN/$model:generateContent") { applyCommonConfiguration(request) }
suspend fun generateContent(request: GenerateContentRequest): GenerateContentResponse =
client
.post("$DOMAIN/$apiVersion/$model:generateContent") { applyCommonConfiguration(request) }
.also { validateResponse(it) }
.body()
}

fun generateContentStream(request: GenerateContentRequest): Flow<GenerateContentResponse> {
return client.postStream("$DOMAIN/$model:streamGenerateContent?alt=sse") {
return client.postStream<GenerateContentResponse>(
"$DOMAIN/$apiVersion/$model:streamGenerateContent?alt=sse",
) {
applyCommonConfiguration(request)
}
}

suspend fun countTokens(request: CountTokensRequest): CountTokensResponse {
return client
.post("$DOMAIN/$model:countTokens") { applyCommonConfiguration(request) }
suspend fun countTokens(request: CountTokensRequest): CountTokensResponse =
client
.post("$DOMAIN/$apiVersion/$model:countTokens") { applyCommonConfiguration(request) }
.also { validateResponse(it) }
.body()
}

private fun HttpRequestBuilder.applyCommonConfiguration(request: Request) {
when (request) {
Expand All @@ -98,10 +102,10 @@ internal class APIController(
}

companion object {
fun getHttpClient(engine: HttpClientEngine?): HttpClient {
fun getHttpClient(engine: HttpClientEngine?, timeout: Duration): HttpClient {
val configuration: HttpClientConfig<*>.() -> Unit = {
install(HttpTimeout) {
requestTimeoutMillis = HttpTimeout.INFINITE_TIMEOUT_MS
requestTimeoutMillis = timeout.inWholeMilliseconds
socketTimeoutMillis = 80_000
}
install(ContentNegotiation) { json(JSON) }
Expand All @@ -120,8 +124,7 @@ internal class APIController(
*
* Models must be prepended with the `models/` prefix when communicating with the backend.
*/
private fun fullModelName(name: String): String =
name.takeIf { it.startsWith("models/") } ?: "models/$name"
private fun fullModelName(name: String): String = name.takeIf { it.contains("/") } ?: "models/$name"

/**
* Makes a POST request to the specified [url] and returns a [Flow] of deserialized response objects
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package dev.shreyaspatil.ai.client.generativeai.type

import dev.shreyaspatil.ai.client.generativeai.GenerativeModel
import io.ktor.serialization.JsonConvertException
import kotlinx.coroutines.TimeoutCancellationException

/** Parent class for any errors that occur from [GenerativeModel]. */
sealed class GoogleGenerativeAIException(message: String, cause: Throwable? = null) :
Expand All @@ -39,6 +40,10 @@ sealed class GoogleGenerativeAIException(message: String, cause: Throwable? = nu
"Something went wrong while trying to deserialize a response from the server.",
cause,
)

is TimeoutCancellationException ->
RequestTimeoutException("The request failed to complete in the allotted time.")

else -> UnknownException("Something unexpected happened.", cause)
}
}
Expand Down Expand Up @@ -84,6 +89,14 @@ class ResponseStoppedException(val response: GenerateContentResponse, cause: Thr
cause,
)

/**
* A request took too long to complete.
*
* Usually occurs due to a user specified [timeout][RequestOptions.timeout].
*/
class RequestTimeoutException(message: String, cause: Throwable? = null) :
GoogleGenerativeAIException(message, cause)

/** Catch all case for exceptions not explicitly expected. */
class UnknownException(message: String, cause: Throwable? = null) :
GoogleGenerativeAIException(message, cause)
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Copyright 2024 Shreyas Patil
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dev.shreyaspatil.ai.client.generativeai.type

import io.ktor.client.plugins.HttpTimeout
import kotlin.jvm.JvmOverloads
import kotlin.time.Duration
import kotlin.time.DurationUnit
import kotlin.time.toDuration

/**
* Configurable options unique to how requests to the backend are performed.
*
* @property timeout the maximum amount of time for a request to take, from the first request to
* first response.
* @property apiVersion the api endpoint to call.
*/
class RequestOptions(val timeout: Duration, val apiVersion: String = "v1") {
@JvmOverloads
constructor(
timeout: Long? = HttpTimeout.INFINITE_TIMEOUT_MS,
apiVersion: String = "v1",
) : this(
(timeout ?: HttpTimeout.INFINITE_TIMEOUT_MS).toDuration(DurationUnit.MILLISECONDS),
apiVersion,
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,28 @@
*/
package dev.shreyaspatil.ai.client.generativeai

import dev.shreyaspatil.ai.client.generativeai.type.RequestOptions
import dev.shreyaspatil.ai.client.generativeai.type.RequestTimeoutException
import dev.shreyaspatil.ai.client.generativeai.util.commonTest
import dev.shreyaspatil.ai.client.generativeai.util.createGenerativeModel
import dev.shreyaspatil.ai.client.generativeai.util.createResponses
import dev.shreyaspatil.ai.client.generativeai.util.doBlocking
import dev.shreyaspatil.ai.client.generativeai.util.prepareStreamingResponse
import io.kotest.assertions.throwables.shouldThrow
import io.kotest.matchers.shouldBe
import io.kotest.matchers.string.shouldContain
import io.ktor.client.engine.mock.MockEngine
import io.ktor.client.engine.mock.respond
import io.ktor.http.HttpHeaders
import io.ktor.http.HttpStatusCode
import io.ktor.http.headersOf
import io.ktor.utils.io.ByteChannel
import io.ktor.utils.io.close
import io.ktor.utils.io.writeFully
import kotlinx.coroutines.flow.collect
import kotlinx.coroutines.withTimeout
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.Parameterized
import kotlin.time.Duration.Companion.seconds

internal class GenerativeModelTests {
Expand All @@ -44,4 +57,53 @@ internal class GenerativeModelTests {
}
}
}

@Test
fun `(generateContent) respects a custom timeout`() =
commonTest(requestOptions = RequestOptions(2.seconds)) {
shouldThrow<RequestTimeoutException> {
withTimeout(testTimeout) { model.generateContent("d") }
}
}
}

@RunWith(Parameterized::class)
internal class ModelNamingTests(private val modelName: String, private val actualName: String) {

@Test
fun `request should include right model name`() = doBlocking {
val channel = ByteChannel(autoFlush = true)
val mockEngine = MockEngine {
respond(
channel,
HttpStatusCode.OK,
headersOf(HttpHeaders.ContentType, "application/json"),
)
}
prepareStreamingResponse(createResponses("Random")).forEach { channel.writeFully(it) }
val model =
createGenerativeModel(modelName, "super_cool_test_key", RequestOptions(), mockEngine)

withTimeout(5.seconds) {
model.generateContentStream().collect {
it.candidates.isEmpty() shouldBe false
channel.close()
}
}

mockEngine.requestHistory.first().url.encodedPath shouldContain actualName
}

companion object {
@JvmStatic
@Parameterized.Parameters
fun data() =
listOf(
arrayOf("gemini-pro", "models/gemini-pro"),
arrayOf("x/gemini-pro", "x/gemini-pro"),
arrayOf("models/gemini-pro", "models/gemini-pro"),
arrayOf("/modelname", "/modelname"),
arrayOf("modifiedNaming/mymodel", "modifiedNaming/mymodel"),
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import dev.shreyaspatil.ai.client.generativeai.internal.api.shared.Content
import dev.shreyaspatil.ai.client.generativeai.internal.api.shared.TextPart
import dev.shreyaspatil.ai.client.generativeai.internal.util.SSE_SEPARATOR
import dev.shreyaspatil.ai.client.generativeai.internal.util.send
import dev.shreyaspatil.ai.client.generativeai.type.RequestOptions
import io.ktor.client.engine.mock.MockEngine
import io.ktor.client.engine.mock.respond
import io.ktor.http.HttpHeaders
Expand Down Expand Up @@ -92,19 +93,42 @@ internal typealias CommonTest = suspend CommonTestScope.() -> Unit
* ```
*
* @param status An optional [HttpStatusCode] to return as a response
* @param requestOptions Optional [RequestOptions] to utilize in the underlying controller
* @param block The test contents themselves, with the [CommonTestScope] implicitly provided
* @see CommonTestScope
*/
internal fun commonTest(status: HttpStatusCode = HttpStatusCode.OK, block: CommonTest) =
doBlocking {
val channel = ByteChannel(autoFlush = true)
val mockEngine = MockEngine {
respond(channel, status, headersOf(HttpHeaders.ContentType, "application/json"))
}
val controller = APIController("super_cool_test_key", "gemini-pro", mockEngine)
val model = GenerativeModel("gemini-pro", "super_cool_test_key", controller = controller)
CommonTestScope(channel, model).block()
internal fun commonTest(
status: HttpStatusCode = HttpStatusCode.OK,
requestOptions: RequestOptions = RequestOptions(),
block: CommonTest,
) = doBlocking {
val channel = ByteChannel(autoFlush = true)
val mockEngine = MockEngine {
respond(channel, status, headersOf(HttpHeaders.ContentType, "application/json"))
}
val model = createGenerativeModel("gemini-pro", "super_cool_test_key", requestOptions, mockEngine)
CommonTestScope(channel, model).block()
}

/** Simple wrapper that guarantees the model and APIController are created using the same data */
internal fun createGenerativeModel(
name: String,
apikey: String,
requestOptions: RequestOptions = RequestOptions(),
engine: MockEngine,
) =
GenerativeModel(
name,
apikey,
controller =
APIController(
"super_cool_test_key",
name,
requestOptions.apiVersion,
requestOptions.timeout,
engine,
),
)

/**
* A variant of [commonTest] for performing *streaming-based* snapshot tests.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ abstract class VersionBumpTask : DefaultTask() {

@TaskAction
fun build() {
if(newVersion.get().major > 0)
throw RuntimeException("You're trying to bump the major version. This is a no 1.0+ zone!!")

versionFile.get().rewriteLines {
when {
it.startsWith("version=") -> "version=${newVersion.get()}"
Expand Down

0 comments on commit 1d56070

Please sign in to comment.