From 4e2a9c7e035dec5e88b2d64811b32426879e8968 Mon Sep 17 00:00:00 2001 From: takahirom Date: Mon, 6 Jan 2025 13:30:23 +0900 Subject: [PATCH] Make OpenAiAiAssertionModel gemini API compatible --- .../roborazzi/OpenAiAiAssertionModel.kt | 14 ++-- .../GeminiWithOpenAiApiInterfaceTest.kt | 79 +++++++++++++++++++ 2 files changed, 88 insertions(+), 5 deletions(-) create mode 100644 sample-android/src/test/java/com/github/takahirom/roborazzi/sample/GeminiWithOpenAiApiInterfaceTest.kt diff --git a/roborazzi-ai-openai/src/commonMain/kotlin/com/github/takahirom/roborazzi/OpenAiAiAssertionModel.kt b/roborazzi-ai-openai/src/commonMain/kotlin/com/github/takahirom/roborazzi/OpenAiAiAssertionModel.kt index 37a58435e..afaccb982 100644 --- a/roborazzi-ai-openai/src/commonMain/kotlin/com/github/takahirom/roborazzi/OpenAiAiAssertionModel.kt +++ b/roborazzi-ai-openai/src/commonMain/kotlin/com/github/takahirom/roborazzi/OpenAiAiAssertionModel.kt @@ -41,7 +41,7 @@ class OpenAiAiAssertionModel( private val loggingEnabled: Boolean = false, private val temperature: Float = DefaultTemperature, private val maxTokens: Int = DefaultMaxOutputTokens, - private val seed: Int = 1566, + private val seed: Int? = 1566, private val requestBuilderModifier: (HttpRequestBuilder.() -> Unit) = { header("Authorization", "Bearer $apiKey") }, @@ -231,7 +231,7 @@ private data class ChatCompletionRequest( val temperature: Float, @SerialName("max_tokens") val maxTokens: Int, @SerialName("response_format") val responseFormat: ResponseFormat?, - val seed: Int, + val seed: Int?, ) @Serializable @@ -260,7 +260,8 @@ private data class ImageUrl( @Serializable private data class ChatCompletionResponse( - val id: String, + // null on gemini + val id: String? = null, val `object`: String, val created: Long, val model: String, @@ -283,9 +284,12 @@ private data class ChoiceMessage( @Serializable private data class Usage( - @SerialName("prompt_tokens") val promptTokens: Int, + // null on gemini + @SerialName("prompt_tokens") val promptTokens: Int? = null, + // null on gemini @SerialName("completion_tokens") val completionTokens: Int? = null, - @SerialName("total_tokens") val totalTokens: Int, + // null on gemini + @SerialName("total_tokens") val totalTokens: Int? = null, ) diff --git a/sample-android/src/test/java/com/github/takahirom/roborazzi/sample/GeminiWithOpenAiApiInterfaceTest.kt b/sample-android/src/test/java/com/github/takahirom/roborazzi/sample/GeminiWithOpenAiApiInterfaceTest.kt new file mode 100644 index 000000000..f4b22938f --- /dev/null +++ b/sample-android/src/test/java/com/github/takahirom/roborazzi/sample/GeminiWithOpenAiApiInterfaceTest.kt @@ -0,0 +1,79 @@ +package com.github.takahirom.roborazzi.sample + +import androidx.compose.ui.test.junit4.createAndroidComposeRule +import androidx.test.espresso.Espresso.onView +import androidx.test.espresso.matcher.ViewMatchers +import androidx.test.ext.junit.runners.AndroidJUnit4 +import com.github.takahirom.roborazzi.AiAssertionOptions +import com.github.takahirom.roborazzi.DEFAULT_ROBORAZZI_OUTPUT_DIR_PATH +import com.github.takahirom.roborazzi.ExperimentalRoborazziApi +import com.github.takahirom.roborazzi.OpenAiAiAssertionModel +import com.github.takahirom.roborazzi.ROBORAZZI_DEBUG +import com.github.takahirom.roborazzi.RobolectricDeviceQualifiers +import com.github.takahirom.roborazzi.RoborazziOptions +import com.github.takahirom.roborazzi.RoborazziRule +import com.github.takahirom.roborazzi.RoborazziTaskType +import com.github.takahirom.roborazzi.captureRoboImage +import com.github.takahirom.roborazzi.provideRoborazziContext +import com.github.takahirom.roborazzi.roboOutputName +import org.junit.Rule +import org.junit.Test +import org.junit.runner.RunWith +import org.robolectric.annotation.Config +import org.robolectric.annotation.GraphicsMode +import java.io.File + +@OptIn(ExperimentalRoborazziApi::class) +@RunWith(AndroidJUnit4::class) +@GraphicsMode(GraphicsMode.Mode.NATIVE) +@Config( + sdk = [30], + qualifiers = RobolectricDeviceQualifiers.NexusOne +) +class GeminiWithOpenAiApiInterfaceTest { + @get:Rule + val composeTestRule = createAndroidComposeRule() + + @get:Rule + val roborazziRule = RoborazziRule( + options = RoborazziRule.Options( + roborazziOptions = RoborazziOptions( + taskType = RoborazziTaskType.Compare, + compareOptions = RoborazziOptions.CompareOptions( + aiAssertionOptions = AiAssertionOptions( + aiAssertionModel = OpenAiAiAssertionModel( + baseUrl = "https://generativelanguage.googleapis.com/v1beta/openai/", + apiKey = System.getenv("gemini_api_key").orEmpty(), + modelName = "gemini-1.5-flash", + seed = null + ), + ) + ) + ) + ) + ) + + @Test + fun captureWithAi() { + ROBORAZZI_DEBUG = true + if (System.getenv("gemini_api_key") == null) { + println("Skip the test because openai_api_key is not set.") + return + } + File(DEFAULT_ROBORAZZI_OUTPUT_DIR_PATH + File.separator + roboOutputName() + ".png").delete() + onView(ViewMatchers.isRoot()) + .captureRoboImage( + roborazziOptions = provideRoborazziContext().options.addedAiAssertions( + AiAssertionOptions.AiAssertion( + assertionPrompt = "it should have PREVIOUS button", + requiredFulfillmentPercent = 90, + ), + AiAssertionOptions.AiAssertion( + assertionPrompt = "it should show First Fragment", + requiredFulfillmentPercent = 90, + ) + ) + ) + File(DEFAULT_ROBORAZZI_OUTPUT_DIR_PATH + File.separator + roboOutputName() + "_compare.png").delete() + } +} \ No newline at end of file