Skip to content

Commit

Permalink
Make OpenAiAiAssertionModel gemini API compatible
Browse files Browse the repository at this point in the history
  • Loading branch information
takahirom committed Jan 6, 2025
1 parent 4b0f77f commit 4e2a9c7
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class OpenAiAiAssertionModel(
private val loggingEnabled: Boolean = false,
private val temperature: Float = DefaultTemperature,
private val maxTokens: Int = DefaultMaxOutputTokens,
private val seed: Int = 1566,
private val seed: Int? = 1566,
private val requestBuilderModifier: (HttpRequestBuilder.() -> Unit) = {
header("Authorization", "Bearer $apiKey")
},
Expand Down Expand Up @@ -231,7 +231,7 @@ private data class ChatCompletionRequest(
val temperature: Float,
@SerialName("max_tokens") val maxTokens: Int,
@SerialName("response_format") val responseFormat: ResponseFormat?,
val seed: Int,
val seed: Int?,
)

@Serializable
Expand Down Expand Up @@ -260,7 +260,8 @@ private data class ImageUrl(

@Serializable
private data class ChatCompletionResponse(
val id: String,
// null on gemini
val id: String? = null,
val `object`: String,
val created: Long,
val model: String,
Expand All @@ -283,9 +284,12 @@ private data class ChoiceMessage(

@Serializable
private data class Usage(
@SerialName("prompt_tokens") val promptTokens: Int,
// null on gemini
@SerialName("prompt_tokens") val promptTokens: Int? = null,
// null on gemini
@SerialName("completion_tokens") val completionTokens: Int? = null,
@SerialName("total_tokens") val totalTokens: Int,
// null on gemini
@SerialName("total_tokens") val totalTokens: Int? = null,
)


Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package com.github.takahirom.roborazzi.sample

import androidx.compose.ui.test.junit4.createAndroidComposeRule
import androidx.test.espresso.Espresso.onView
import androidx.test.espresso.matcher.ViewMatchers
import androidx.test.ext.junit.runners.AndroidJUnit4
import com.github.takahirom.roborazzi.AiAssertionOptions
import com.github.takahirom.roborazzi.DEFAULT_ROBORAZZI_OUTPUT_DIR_PATH
import com.github.takahirom.roborazzi.ExperimentalRoborazziApi
import com.github.takahirom.roborazzi.OpenAiAiAssertionModel
import com.github.takahirom.roborazzi.ROBORAZZI_DEBUG
import com.github.takahirom.roborazzi.RobolectricDeviceQualifiers
import com.github.takahirom.roborazzi.RoborazziOptions
import com.github.takahirom.roborazzi.RoborazziRule
import com.github.takahirom.roborazzi.RoborazziTaskType
import com.github.takahirom.roborazzi.captureRoboImage
import com.github.takahirom.roborazzi.provideRoborazziContext
import com.github.takahirom.roborazzi.roboOutputName
import org.junit.Rule
import org.junit.Test
import org.junit.runner.RunWith
import org.robolectric.annotation.Config
import org.robolectric.annotation.GraphicsMode
import java.io.File

@OptIn(ExperimentalRoborazziApi::class)
@RunWith(AndroidJUnit4::class)
@GraphicsMode(GraphicsMode.Mode.NATIVE)
@Config(
sdk = [30],
qualifiers = RobolectricDeviceQualifiers.NexusOne
)
class GeminiWithOpenAiApiInterfaceTest {
@get:Rule
val composeTestRule = createAndroidComposeRule<MainActivity>()

@get:Rule
val roborazziRule = RoborazziRule(
options = RoborazziRule.Options(
roborazziOptions = RoborazziOptions(
taskType = RoborazziTaskType.Compare,
compareOptions = RoborazziOptions.CompareOptions(
aiAssertionOptions = AiAssertionOptions(
aiAssertionModel = OpenAiAiAssertionModel(
baseUrl = "https://generativelanguage.googleapis.com/v1beta/openai/",
apiKey = System.getenv("gemini_api_key").orEmpty(),
modelName = "gemini-1.5-flash",
seed = null
),
)
)
)
)
)

@Test
fun captureWithAi() {
ROBORAZZI_DEBUG = true
if (System.getenv("gemini_api_key") == null) {
println("Skip the test because openai_api_key is not set.")
return
}
File(DEFAULT_ROBORAZZI_OUTPUT_DIR_PATH + File.separator + roboOutputName() + ".png").delete()
onView(ViewMatchers.isRoot())
.captureRoboImage(
roborazziOptions = provideRoborazziContext().options.addedAiAssertions(
AiAssertionOptions.AiAssertion(
assertionPrompt = "it should have PREVIOUS button",
requiredFulfillmentPercent = 90,
),
AiAssertionOptions.AiAssertion(
assertionPrompt = "it should show First Fragment",
requiredFulfillmentPercent = 90,
)
)
)
File(DEFAULT_ROBORAZZI_OUTPUT_DIR_PATH + File.separator + roboOutputName() + "_compare.png").delete()
}
}

0 comments on commit 4e2a9c7

Please sign in to comment.