Skip to content

Commit

Permalink
Align proto primitives
Browse files Browse the repository at this point in the history
  • Loading branch information
daymxn committed Jul 2, 2024
1 parent 69329f7 commit 1a88a23
Show file tree
Hide file tree
Showing 23 changed files with 214 additions and 203 deletions.
1 change: 1 addition & 0 deletions .changes/common/breath-committee-dad-curtain.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"type":"MAJOR","changes":["Better align protos in regards to primitive defaults."]}
1 change: 1 addition & 0 deletions .changes/common/carpenter-beggar-creator-celery.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"type":"MAJOR","changes":["Better align protos in regards to primitive defaults."]}
1 change: 1 addition & 0 deletions .changes/generativeai/breath-brush-achiever-boat.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"type":"MAJOR","changes":["Better align protos in regards to primitive defaults."]}
Original file line number Diff line number Diff line change
Expand Up @@ -235,19 +235,19 @@ private suspend fun validateResponse(response: HttpResponse) {
if (message.contains("quota")) {
throw QuotaExceededException(message)
}
if (error.details?.any { "SERVICE_DISABLED" == it.reason } == true) {
if (error.details.any { "SERVICE_DISABLED" == it.reason }) {
throw ServiceDisabledException(message)
}
throw ServerException(message)
}

private fun GenerateContentResponse.validate() = apply {
if ((candidates?.isEmpty() != false) && promptFeedback == null) {
if (candidates.isEmpty() && promptFeedback == null) {
throw SerializationException("Error deserializing response, found no valid fields")
}
promptFeedback?.blockReason?.let { throw PromptBlockedException(this) }
candidates
?.mapNotNull { it.finishReason }
?.firstOrNull { it != FinishReason.STOP }
.map { it.finishReason }
.firstOrNull { it != FinishReason.STOP }
?.let { throw ResponseStoppedException(this) }
}
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ class InvalidStateException(message: String, cause: Throwable? = null) :
*/
class ResponseStoppedException(val response: GenerateContentResponse, cause: Throwable? = null) :
GoogleGenerativeAIException(
"Content generation stopped. Reason: ${response.candidates?.first()?.finishReason?.name}",
"Content generation stopped. Reason: ${response.candidates.first().finishReason?.name}",
cause,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
* limitations under the License.
*/

@file:OptIn(ExperimentalSerializationApi::class)

package com.google.ai.client.generativeai.common

import com.google.ai.client.generativeai.common.client.GenerationConfig
Expand All @@ -22,45 +24,41 @@ import com.google.ai.client.generativeai.common.client.ToolConfig
import com.google.ai.client.generativeai.common.shared.Content
import com.google.ai.client.generativeai.common.shared.SafetySetting
import com.google.ai.client.generativeai.common.util.fullModelName
import kotlinx.serialization.SerialName
import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.Serializable

sealed interface Request

@Serializable
data class GenerateContentRequest(
val model: String? = null,
val model: String,
val contents: List<Content>,
@SerialName("safety_settings") val safetySettings: List<SafetySetting>? = null,
@SerialName("generation_config") val generationConfig: GenerationConfig? = null,
val tools: List<Tool>? = null,
@SerialName("tool_config") var toolConfig: ToolConfig? = null,
@SerialName("system_instruction") val systemInstruction: Content? = null,
val safetySettings: List<SafetySetting> = emptyList(),
val generationConfig: GenerationConfig? = null,
val tools: List<Tool> = emptyList(),
val toolConfig: ToolConfig? = null,
val systemInstruction: Content? = null,
) : Request

@Serializable
data class CountTokensRequest(
val model: String,
val contents: List<Content> = emptyList(),
val tools: List<Tool> = emptyList(),
val generateContentRequest: GenerateContentRequest? = null,
val model: String? = null,
val contents: List<Content>? = null,
val tools: List<Tool>? = null,
@SerialName("system_instruction") val systemInstruction: Content? = null,
val systemInstruction: Content? = null,
) : Request {
companion object {
fun forGenAI(generateContentRequest: GenerateContentRequest) =
CountTokensRequest(
generateContentRequest =
generateContentRequest.model?.let {
generateContentRequest.copy(model = fullModelName(it))
} ?: generateContentRequest
)
fun forGenAI(request: GenerateContentRequest) =
CountTokensRequest(fullModelName(request.model), request.contents, emptyList(), request)

fun forVertexAI(generateContentRequest: GenerateContentRequest) =
fun forVertexAI(request: GenerateContentRequest) =
CountTokensRequest(
model = generateContentRequest.model?.let { fullModelName(it) },
contents = generateContentRequest.contents,
tools = generateContentRequest.tools,
systemInstruction = generateContentRequest.systemInstruction,
fullModelName(request.model),
request.contents,
request.tools,
null,
request.systemInstruction,
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,20 +25,20 @@ sealed interface Response

@Serializable
data class GenerateContentResponse(
val candidates: List<Candidate>? = null,
val candidates: List<Candidate> = emptyList(),
val promptFeedback: PromptFeedback? = null,
val usageMetadata: UsageMetadata? = null,
) : Response

@Serializable
data class CountTokensResponse(val totalTokens: Int, val totalBillableCharacters: Int? = null) :
data class CountTokensResponse(val totalTokens: Int = 0, val totalBillableCharacters: Int = 0) :
Response

@Serializable data class GRpcErrorResponse(val error: GRpcError) : Response

@Serializable
data class UsageMetadata(
val promptTokenCount: Int? = null,
val candidatesTokenCount: Int? = null,
val totalTokenCount: Int? = null,
val promptTokenCount: Int = 0,
val candidatesTokenCount: Int = 0,
val totalTokenCount: Int = 0,
)
Original file line number Diff line number Diff line change
Expand Up @@ -22,32 +22,30 @@ import kotlinx.serialization.json.JsonObject

@Serializable
data class GenerationConfig(
val temperature: Float?,
@SerialName("top_p") val topP: Float?,
@SerialName("top_k") val topK: Int?,
@SerialName("candidate_count") val candidateCount: Int?,
@SerialName("max_output_tokens") val maxOutputTokens: Int?,
@SerialName("stop_sequences") val stopSequences: List<String>?,
@SerialName("response_mime_type") val responseMimeType: String? = null,
@SerialName("presence_penalty") val presencePenalty: Float? = null,
@SerialName("frequency_penalty") val frequencyPenalty: Float? = null,
@SerialName("response_schema") val responseSchema: Schema? = null,
val temperature: Float = 0f,
val topP: Float = 0f,
val topK: Int = 0,
val candidateCount: Int = 0,
val maxOutputTokens: Int = 0,
val stopSequences: List<String> = emptyList(),
val responseMimeType: String = "",
val presencePenalty: Float = 0f,
val frequencyPenalty: Float = 0f,
val responseSchema: Schema? = null,
)

@Serializable
data class Tool(
val functionDeclarations: List<FunctionDeclaration>? = null,
val functionDeclarations: List<FunctionDeclaration> = emptyList(),
// This is a json object because it is not possible to make a data class with no parameters.
val codeExecution: JsonObject? = null,
)

@Serializable
data class ToolConfig(
@SerialName("function_calling_config") val functionCallingConfig: FunctionCallingConfig
)
data class ToolConfig(val functionCallingConfig: FunctionCallingConfig = FunctionCallingConfig())

@Serializable
data class FunctionCallingConfig(val mode: Mode) {
data class FunctionCallingConfig(val mode: Mode? = null) {
@Serializable
enum class Mode {
@SerialName("MODE_UNSPECIFIED") UNSPECIFIED,
Expand All @@ -58,16 +56,20 @@ data class FunctionCallingConfig(val mode: Mode) {
}

@Serializable
data class FunctionDeclaration(val name: String, val description: String, val parameters: Schema)
data class FunctionDeclaration(
val name: String,
val description: String,
val parameters: Schema? = null,
)

@Serializable
data class Schema(
val type: String,
val description: String? = null,
val format: String? = null,
val nullable: Boolean? = false,
val enum: List<String>? = null,
val properties: Map<String, Schema>? = null,
val required: List<String>? = null,
val description: String = "",
val format: String = "",
val nullable: Boolean = false,
val enum: List<String> = emptyList(),
val properties: Map<String, Schema> = emptyMap(),
val required: List<String> = emptyList(),
val items: Schema? = null,
)
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
* limitations under the License.
*/

@file:OptIn(ExperimentalSerializationApi::class)

package com.google.ai.client.generativeai.common.server

import com.google.ai.client.generativeai.common.shared.Content
Expand All @@ -37,7 +39,7 @@ object FinishReasonSerializer :
@Serializable
data class PromptFeedback(
val blockReason: BlockReason? = null,
val safetyRatings: List<SafetyRating>? = null,
val safetyRatings: List<SafetyRating> = emptyList(),
)

@Serializable(BlockReasonSerializer::class)
Expand All @@ -52,59 +54,51 @@ enum class BlockReason {
data class Candidate(
val content: Content? = null,
val finishReason: FinishReason? = null,
val safetyRatings: List<SafetyRating>? = null,
val safetyRatings: List<SafetyRating> = emptyList(),
val citationMetadata: CitationMetadata? = null,
val groundingMetadata: GroundingMetadata? = null,
)

@Serializable
data class CitationMetadata
@OptIn(ExperimentalSerializationApi::class)
constructor(@JsonNames("citations") val citationSources: List<CitationSources>)
data class CitationMetadata(
@JsonNames("citations") val citationSources: List<CitationSources> = emptyList()
)

@Serializable
data class CitationSources(
val startIndex: Int = 0,
val endIndex: Int,
val uri: String,
val license: String? = null,
val endIndex: Int = 0,
val uri: String = "",
val license: String = "",
)

@Serializable
data class SafetyRating(
val category: HarmCategory,
val probability: HarmProbability,
val blocked: Boolean? = null, // TODO(): any reason not to default to false?
val probabilityScore: Float? = null,
val blocked: Boolean = false,
val probabilityScore: Float = 0f,
val severity: HarmSeverity? = null,
val severityScore: Float? = null,
val severityScore: Float = 0f,
)

@Serializable
data class GroundingMetadata(
@SerialName("web_search_queries") val webSearchQueries: List<String>?,
@SerialName("search_entry_point") val searchEntryPoint: SearchEntryPoint?,
@SerialName("retrieval_queries") val retrievalQueries: List<String>?,
@SerialName("grounding_attribution") val groundingAttribution: List<GroundingAttribution>?,
val webSearchQueries: List<String> = emptyList(),
val searchEntryPoint: SearchEntryPoint? = null,
val retrievalQueries: List<String> = emptyList(),
val groundingAttribution: List<GroundingAttribution> = emptyList(),
)

@Serializable
data class SearchEntryPoint(
@SerialName("rendered_content") val renderedContent: String?,
@SerialName("sdk_blob") val sdkBlob: String?,
)
data class SearchEntryPoint(val renderedContent: String = "", val sdkBlob: String = "")

// TODO() Has a different definition for labs vs vertex. May need to split into diff types in future
// (when labs supports it)
@Serializable
data class GroundingAttribution(
val segment: Segment,
@SerialName("confidence_score") val confidenceScore: Float?,
)
data class GroundingAttribution(val segment: Segment, val confidenceScore: Float = 0f)

@Serializable
data class Segment(
@SerialName("start_index") val startIndex: Int,
@SerialName("end_index") val endIndex: Int,
)
@Serializable data class Segment(val startIndex: Int = 0, val endIndex: Int = 0)

@Serializable(HarmProbabilitySerializer::class)
enum class HarmProbability {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,13 @@ typealias Base64 = String

@ExperimentalSerializationApi
@Serializable
data class Content(@EncodeDefault val role: String? = "user", val parts: List<Part>)
data class Content(@EncodeDefault val role: String = "", val parts: List<Part>)

@Serializable(PartSerializer::class) sealed interface Part

@Serializable data class TextPart(val text: String) : Part
@Serializable data class TextPart(val text: String = "") : Part

@Serializable data class BlobPart(@SerialName("inline_data") val inlineData: Blob) : Part
@Serializable data class BlobPart(val inlineData: Blob) : Part

@Serializable data class FunctionCallPart(val functionCall: FunctionCall) : Part

Expand All @@ -64,17 +64,14 @@ data class CodeExecutionResultPart(val codeExecutionResult: CodeExecutionResult)

@Serializable data class FunctionResponse(val name: String, val response: JsonObject)

@Serializable data class FunctionCall(val name: String, val args: Map<String, String?>)
@Serializable
data class FunctionCall(val name: String, val args: Map<String, String?> = emptyMap())

@Serializable data class FileDataPart(@SerialName("file_data") val fileData: FileData) : Part
@Serializable data class FileDataPart(val fileData: FileData) : Part

@Serializable
data class FileData(
@SerialName("mime_type") val mimeType: String,
@SerialName("file_uri") val fileUri: String,
)
@Serializable data class FileData(val mimeType: String, val fileUri: String)

@Serializable data class Blob(@SerialName("mime_type") val mimeType: String, val data: Base64)
@Serializable data class Blob(val mimeType: String, val data: Base64)

@Serializable data class ExecutableCode(val language: String, val code: String)

Expand Down
Loading

0 comments on commit 1a88a23

Please sign in to comment.