Skip to content

Commit

Permalink
Merge pull request #11 from PatilShreyas/v0.9.0
Browse files Browse the repository at this point in the history
Sync with google/generative-ai-android v0.9.0
  • Loading branch information
PatilShreyas authored Sep 21, 2024
2 parents bb2d555 + 573d41d commit c1fe685
Show file tree
Hide file tree
Showing 46 changed files with 13,309 additions and 417 deletions.
1 change: 0 additions & 1 deletion .changes/generativeai/cork-dock-apple-cobweb.json

This file was deleted.

2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ Try these apps built with this SDK by the community:
You can use the APIs mentioned in the [API Reference](https://ai.google.dev/tutorials/android_quickstart)
by the official library.

See the [Gemini API Cookbook](https://github.com/google-gemini/gemini-api-cookbook/) or [ai.google.dev](https://ai.google.dev) for complete documentation.

From the official library, there are two major changes:
- Package `com.google` is mapped to `dev.shreyaspatil`.
- `Image(Bitmap)` was there for Android, instead `PlatformImage(ByteArray)` is used for KMP.
Expand Down
1,569 changes: 1,569 additions & 0 deletions api/common/0.10.0.api

Large diffs are not rendered by default.

1,223 changes: 1,223 additions & 0 deletions api/common/0.5.0.api

Large diffs are not rendered by default.

1,414 changes: 1,414 additions & 0 deletions api/common/0.6.0.api

Large diffs are not rendered by default.

1,429 changes: 1,429 additions & 0 deletions api/common/0.7.0.api

Large diffs are not rendered by default.

1,429 changes: 1,429 additions & 0 deletions api/common/0.7.1.api

Large diffs are not rendered by default.

1,567 changes: 1,567 additions & 0 deletions api/common/0.8.0.api

Large diffs are not rendered by default.

1,568 changes: 1,568 additions & 0 deletions api/common/0.9.0.api

Large diffs are not rendered by default.

507 changes: 507 additions & 0 deletions api/generativeai/0.6.0.api

Large diffs are not rendered by default.

509 changes: 509 additions & 0 deletions api/generativeai/0.7.0.api

Large diffs are not rendered by default.

511 changes: 511 additions & 0 deletions api/generativeai/0.8.0.api

Large diffs are not rendered by default.

492 changes: 492 additions & 0 deletions api/generativeai/0.9.0.api

Large diffs are not rendered by default.

1 change: 0 additions & 1 deletion build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
*/

plugins {
//trick: for the same plugin versions in all sub-modules
alias(libs.plugins.androidLibrary).apply(false)
alias(libs.plugins.kotlinMultiplatform).apply(false)
id("org.jetbrains.dokka") version "1.9.10" apply false
Expand Down
2 changes: 2 additions & 0 deletions common/consumer-rules.pro
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,5 @@
# If you keep the line number information, uncomment this to
# hide the original source file name.
#-renamesourcefileattribute SourceFile

-keep class dev.shreyaspatil.ai.client.generativeai.common.** { *; }
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package dev.shreyaspatil.ai.client.generativeai.common
import dev.shreyaspatil.ai.client.generativeai.common.server.FinishReason
import dev.shreyaspatil.ai.client.generativeai.common.util.Log
import dev.shreyaspatil.ai.client.generativeai.common.util.decodeToFlow
import dev.shreyaspatil.ai.client.generativeai.common.util.fullModelName
import io.ktor.client.HttpClient
import io.ktor.client.call.body
import io.ktor.client.engine.HttpClientEngine
Expand All @@ -35,6 +36,7 @@ import io.ktor.http.ContentType
import io.ktor.http.HttpStatusCode
import io.ktor.http.contentType
import io.ktor.serialization.kotlinx.json.json
import io.ktor.utils.io.ByteChannel
import kotlinx.coroutines.CoroutineName
import kotlinx.coroutines.TimeoutCancellationException
import kotlinx.coroutines.flow.Flow
Expand All @@ -46,9 +48,10 @@ import kotlinx.coroutines.withTimeout
import kotlinx.serialization.json.Json
import kotlin.time.Duration

val JSON = Json {
internal val JSON = Json {
ignoreUnknownKeys = true
prettyPrint = false
isLenient = true
}

/**
Expand Down Expand Up @@ -81,6 +84,25 @@ internal constructor(
headerProvider: HeaderProvider? = null,
) : this(key, model, requestOptions, null, apiClient, headerProvider)

// @VisibleForTesting(otherwise = VisibleForTesting.NONE)
constructor(
key: String,
model: String,
requestOptions: RequestOptions,
apiClient: String,
headerProvider: HeaderProvider?,
httpEngine: HttpClientEngine?,
channel: ByteChannel,
status: HttpStatusCode,
) : this(
key,
model,
requestOptions,
httpEngine,
apiClient,
headerProvider,
)

private val model = fullModelName(model)

private val client = if (httpEngine != null) {
Expand Down Expand Up @@ -220,22 +242,16 @@ interface HeaderProvider {
suspend fun generateHeaders(): Map<String, String>
}

/**
* Ensures the model name provided has a `models/` prefix
*
* Models must be prepended with the `models/` prefix when communicating with the backend.
*/
private fun fullModelName(name: String): String = name.takeIf { it.contains("/") } ?: "models/$name"

private suspend fun validateResponse(response: HttpResponse) {
if (response.status == HttpStatusCode.OK) return
val text = response.bodyAsText()
val message =
val error =
try {
JSON.decodeFromString<GRpcErrorResponse>(text).error.message
JSON.decodeFromString<GRpcErrorResponse>(text).error
} catch (e: Throwable) {
"Unexpected Response:\n$text"
throw ServerException("Unexpected Response:\n$text $e")
}
val message = error.message
if (message.contains("API key not valid")) {
throw InvalidAPIKeyException(message)
}
Expand All @@ -246,6 +262,9 @@ private suspend fun validateResponse(response: HttpResponse) {
if (message.contains("quota")) {
throw QuotaExceededException(message)
}
if (error.details.any { "SERVICE_DISABLED" == it.reason }) {
throw ServiceDisabledException(message)
}
throw ServerException(message)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,10 @@ class RequestTimeoutException(message: String, cause: Throwable? = null) :
class QuotaExceededException(message: String, cause: Throwable? = null) :
GoogleGenerativeAIException(message, cause)

/** The service is not enabled for this project. Visit the Firebase Console to enable it. */
class ServiceDisabledException(message: String, cause: Throwable? = null) :
GoogleGenerativeAIException(message, cause)

/** Catch all case for exceptions not explicitly expected. */
class UnknownException(message: String, cause: Throwable? = null) :
GoogleGenerativeAIException(message, cause)
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,15 @@ import dev.shreyaspatil.ai.client.generativeai.common.client.Tool
import dev.shreyaspatil.ai.client.generativeai.common.client.ToolConfig
import dev.shreyaspatil.ai.client.generativeai.common.shared.Content
import dev.shreyaspatil.ai.client.generativeai.common.shared.SafetySetting
import dev.shreyaspatil.ai.client.generativeai.common.util.fullModelName
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable
import kotlinx.serialization.Transient

sealed interface Request

@Serializable
data class GenerateContentRequest(
@Transient val model: String? = null,
val model: String? = null,
val contents: List<Content>,
@SerialName("safety_settings") val safetySettings: List<SafetySetting>? = null,
@SerialName("generation_config") val generationConfig: GenerationConfig? = null,
Expand All @@ -38,5 +38,28 @@ data class GenerateContentRequest(
) : Request

@Serializable
data class CountTokensRequest(@Transient val model: String? = null, val contents: List<Content>) :
Request
data class CountTokensRequest(
val generateContentRequest: GenerateContentRequest? = null,
val model: String? = null,
val contents: List<Content>? = null,
val tools: List<Tool>? = null,
@SerialName("system_instruction") val systemInstruction: Content? = null,
) : Request {
companion object {
fun forGenAI(generateContentRequest: GenerateContentRequest) =
CountTokensRequest(
generateContentRequest =
generateContentRequest.model?.let {
generateContentRequest.copy(model = fullModelName(it))
} ?: generateContentRequest,
)

fun forVertexAI(generateContentRequest: GenerateContentRequest) =
CountTokensRequest(
model = generateContentRequest.model?.let { fullModelName(it) },
contents = generateContentRequest.contents,
tools = generateContentRequest.tools,
systemInstruction = generateContentRequest.systemInstruction,
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import dev.shreyaspatil.ai.client.generativeai.common.util.enumSerializer
import kotlinx.serialization.KSerializer
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable
import kotlinx.serialization.json.JsonObject

@Serializable
data class GenerationConfig(
Expand All @@ -29,9 +30,18 @@ data class GenerationConfig(
@SerialName("candidate_count") val candidateCount: Int?,
@SerialName("max_output_tokens") val maxOutputTokens: Int?,
@SerialName("stop_sequences") val stopSequences: List<String>?,
@SerialName("response_mime_type") val responseMimeType: String? = null,
@SerialName("presence_penalty") val presencePenalty: Float? = null,
@SerialName("frequency_penalty") val frequencyPenalty: Float? = null,
@SerialName("response_schema") val responseSchema: Schema? = null,
)

@Serializable data class Tool(val functionDeclarations: List<FunctionDeclaration>)
@Serializable
data class Tool(
val functionDeclarations: List<FunctionDeclaration>? = null,
// This is a json object because it is not possible to make a data class with no parameters.
val codeExecution: JsonObject? = null,
)

@Serializable
data class ToolConfig(
Expand All @@ -50,17 +60,14 @@ data class FunctionCallingConfig(val mode: Mode) {
}

@Serializable
data class FunctionDeclaration(
val name: String,
val description: String,
val parameters: Schema,
)
data class FunctionDeclaration(val name: String, val description: String, val parameters: Schema)

@Serializable
data class Schema(
val type: String,
val description: String? = null,
val format: String? = null,
val nullable: Boolean? = false,
val enum: List<String>? = null,
val properties: Map<String, Schema>? = null,
val required: List<String>? = null,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ data class Candidate(
val finishReason: FinishReason? = null,
val safetyRatings: List<SafetyRating>? = null,
val citationMetadata: CitationMetadata? = null,
val groundingMetadata: GroundingMetadata? = null,
)

@Serializable
Expand All @@ -73,12 +74,39 @@ data class CitationSources(
data class SafetyRating(
val category: HarmCategory,
val probability: HarmProbability,
// TODO(): any reason not to default to false?
val blocked: Boolean? = null,
val probabilityScore: Float? = null,
val severity: HarmSeverity? = null,
val severityScore: Float? = null,
)

@Serializable
data class GroundingMetadata(
@SerialName("web_search_queries") val webSearchQueries: List<String>?,
@SerialName("search_entry_point") val searchEntryPoint: SearchEntryPoint?,
@SerialName("retrieval_queries") val retrievalQueries: List<String>?,
@SerialName("grounding_attribution") val groundingAttribution: List<GroundingAttribution>?,
)

@Serializable
data class SearchEntryPoint(
@SerialName("rendered_content") val renderedContent: String?,
@SerialName("sdk_blob") val sdkBlob: String?,
)

@Serializable
data class GroundingAttribution(
val segment: Segment,
@SerialName("confidence_score") val confidenceScore: Float?,
)

@Serializable
data class Segment(
@SerialName("start_index") val startIndex: Int,
@SerialName("end_index") val endIndex: Int,
)

@Serializable(HarmProbabilitySerializer::class)
enum class HarmProbability(override val serialName: String) : SerializableEnum<HarmProbability> {
UNKNOWN("UNKNOWN"),
Expand Down Expand Up @@ -121,7 +149,6 @@ enum class FinishReason(override val serialName: String) : SerializableEnum<Fini
}

@Serializable
data class GRpcError(
val code: Int,
val message: String,
)
data class GRpcError(val code: Int, val message: String, val details: List<GRpcErrorDetails>)

@Serializable data class GRpcErrorDetails(val reason: String? = null)
Original file line number Diff line number Diff line change
Expand Up @@ -50,50 +50,62 @@ data class Content(@EncodeDefault val role: String? = "user", val parts: List<Pa
@Serializable(PartSerializer::class)
sealed interface Part

@Serializable
data class TextPart(val text: String) : Part
@Serializable data class TextPart(val text: String) : Part

@Serializable
data class BlobPart(@SerialName("inline_data") val inlineData: Blob) : Part
@Serializable data class BlobPart(@SerialName("inline_data") val inlineData: Blob) : Part

@Serializable
data class FunctionCallPart(val functionCall: FunctionCall) : Part
@Serializable data class FunctionCallPart(val functionCall: FunctionCall) : Part

@Serializable
data class FunctionResponsePart(val functionResponse: FunctionResponse) : Part
@Serializable data class FunctionResponsePart(val functionResponse: FunctionResponse) : Part

@Serializable
data class FunctionResponse(val name: String, val response: JsonObject)
@Serializable data class ExecutableCodePart(val executableCode: ExecutableCode) : Part

@Serializable
data class FunctionCall(val name: String, val args: Map<String, String>)
data class CodeExecutionResultPart(val codeExecutionResult: CodeExecutionResult) : Part

@Serializable
data class FileDataPart(@SerialName("file_data") val fileData: FileData) : Part
@Serializable data class FunctionResponse(val name: String, val response: JsonObject)

@Serializable data class FunctionCall(val name: String, val args: Map<String, String?>? = null)

@Serializable data class FileDataPart(@SerialName("file_data") val fileData: FileData) : Part

@Serializable
data class FileData(
@SerialName("mime_type") val mimeType: String,
@SerialName("file_uri") val fileUri: String,
)

@Serializable data class Blob(@SerialName("mime_type") val mimeType: String, val data: Base64)

@Serializable data class ExecutableCode(val language: String, val code: String)

@Serializable data class CodeExecutionResult(val outcome: Outcome, val output: String)

@Serializable
data class Blob(
@SerialName("mime_type") val mimeType: String,
val data: Base64,
)
enum class Outcome(override val serialName: String) : SerializableEnum<Outcome> {
UNSPECIFIED("OUTCOME_UNSPECIFIED"),
OUTCOME_OK("OUTCOME_OK"),
OUTCOME_FAILED("OUTCOME_FAILED"),
OUTCOME_DEADLINE_EXCEEDED("OUTCOME_DEADLINE_EXCEEDED"),
}

@Serializable
data class SafetySetting(val category: HarmCategory, val threshold: HarmBlockThreshold, val method: HarmBlockMethod? = null)

@Serializable
data class SafetySetting(val category: HarmCategory, val threshold: HarmBlockThreshold)
enum class HarmBlockThreshold(override val serialName: String) : SerializableEnum<HarmBlockThreshold> {
UNSPECIFIED("HARM_BLOCK_THRESHOLD_UNSPECIFIED"),
BLOCK_LOW_AND_ABOVE("BLOCK_LOW_AND_ABOVE"),
BLOCK_MEDIUM_AND_ABOVE("BLOCK_MEDIUM_AND_ABOVE"),
BLOCK_ONLY_HIGH("BLOCK_ONLY_HIGH"),
BLOCK_NONE("BLOCK_NONE"),
}

@Serializable
enum class HarmBlockThreshold {
@SerialName("HARM_BLOCK_THRESHOLD_UNSPECIFIED")
UNSPECIFIED,
BLOCK_LOW_AND_ABOVE,
BLOCK_MEDIUM_AND_ABOVE,
BLOCK_ONLY_HIGH,
BLOCK_NONE,
enum class HarmBlockMethod(override val serialName: String) : SerializableEnum<HarmCategory> {
UNSPECIFIED("HARM_BLOCK_METHOD_UNSPECIFIED"),
SEVERITY("SEVERITY"),
PROBABILITY("PROBABILITY"),
}

object PartSerializer : JsonContentPolymorphicSerializer<Part>(Part::class) {
Expand All @@ -103,8 +115,10 @@ object PartSerializer : JsonContentPolymorphicSerializer<Part>(Part::class) {
"text" in jsonObject -> TextPart.serializer()
"functionCall" in jsonObject -> FunctionCallPart.serializer()
"functionResponse" in jsonObject -> FunctionResponsePart.serializer()
"inline_data" in jsonObject -> BlobPart.serializer()
"file_data" in jsonObject -> FileDataPart.serializer()
"inlineData" in jsonObject -> BlobPart.serializer()
"fileData" in jsonObject -> FileDataPart.serializer()
"executableCode" in jsonObject -> ExecutableCodePart.serializer()
"codeExecutionResult" in jsonObject -> CodeExecutionResultPart.serializer()
else -> throw SerializationException("Unknown Part type")
}
}
Expand Down
Loading

0 comments on commit c1fe685

Please sign in to comment.