diff --git a/common/src/commonMain/kotlin/dev/shreyaspatil/ai/client/generativeai/common/APIController.kt b/common/src/commonMain/kotlin/dev/shreyaspatil/ai/client/generativeai/common/APIController.kt index cc47b173..3382b488 100644 --- a/common/src/commonMain/kotlin/dev/shreyaspatil/ai/client/generativeai/common/APIController.kt +++ b/common/src/commonMain/kotlin/dev/shreyaspatil/ai/client/generativeai/common/APIController.kt @@ -84,19 +84,21 @@ internal constructor( headerProvider: HeaderProvider? = null, ) : this(key, model, requestOptions, null, apiClient, headerProvider) + // @VisibleForTesting(otherwise = VisibleForTesting.NONE) constructor( key: String, model: String, requestOptions: RequestOptions, apiClient: String, headerProvider: HeaderProvider?, + httpEngine: HttpClientEngine?, channel: ByteChannel, status: HttpStatusCode, ) : this( key, model, requestOptions, - null, + httpEngine, apiClient, headerProvider, ) diff --git a/common/src/test/kotlin/dev/shreyaspatil/ai/client/generativeai/common/APIControllerTests.kt b/common/src/test/kotlin/dev/shreyaspatil/ai/client/generativeai/common/APIControllerTests.kt index c279ce7b..cfa05080 100644 --- a/common/src/test/kotlin/dev/shreyaspatil/ai/client/generativeai/common/APIControllerTests.kt +++ b/common/src/test/kotlin/dev/shreyaspatil/ai/client/generativeai/common/APIControllerTests.kt @@ -41,9 +41,7 @@ import kotlinx.coroutines.delay import kotlinx.coroutines.withTimeout import kotlinx.serialization.encodeToString import kotlinx.serialization.json.JsonObject -import org.junit.Test -import org.junit.runner.RunWith -import org.junit.runners.Parameterized +import kotlin.test.Test import kotlin.time.Duration import kotlin.time.Duration.Companion.milliseconds import kotlin.time.Duration.Companion.seconds @@ -295,42 +293,39 @@ internal class RequestFormatTests { requestBodyAsText shouldContainJsonKey "tools[0].codeExecution" } -} - -@RunWith(Parameterized::class) -internal class ModelNamingTests(private val modelName: String, private val actualName: String) { @Test fun `request should include right model name`() = doBlocking { - val channel = ByteChannel(autoFlush = true) - val mockEngine = MockEngine { - respond(channel, HttpStatusCode.OK, headersOf(HttpHeaders.ContentType, "application/json")) - } - prepareStreamingResponse(createResponses("Random")).forEach { channel.writeFully(it) } - val controller = - APIController( - "super_cool_test_key", - modelName, - RequestOptions(), - mockEngine, - TEST_CLIENT_ID, - null, - ) + val models = models() + models.forEach { (modelName, actualName) -> + val channel = ByteChannel(autoFlush = true) + val mockEngine = MockEngine { + respond(channel, HttpStatusCode.OK, headersOf(HttpHeaders.ContentType, "application/json")) + } + prepareStreamingResponse(createResponses("Random")).forEach { channel.writeFully(it) } + val controller = + APIController( + "super_cool_test_key", + modelName, + RequestOptions(), + mockEngine, + TEST_CLIENT_ID, + null, + ) - withTimeout(5.seconds) { - controller.generateContentStream(textGenerateContentRequest("cats")).collect { - it.candidates?.isEmpty() shouldBe false - channel.close() + withTimeout(5.seconds) { + controller.generateContentStream(textGenerateContentRequest("cats")).collect { + it.candidates?.isEmpty() shouldBe false + channel.close() + } } - } - mockEngine.requestHistory.first().url.encodedPath shouldContain actualName + mockEngine.requestHistory.first().url.encodedPath shouldContain actualName + } } companion object { - @JvmStatic - @Parameterized.Parameters - fun data() = + fun models() = listOf( arrayOf("gemini-pro", "models/gemini-pro"), arrayOf("x/gemini-pro", "x/gemini-pro"), diff --git a/common/src/test/kotlin/dev/shreyaspatil/ai/client/generativeai/common/StreamingSnapshotTests.kt b/common/src/test/kotlin/dev/shreyaspatil/ai/client/generativeai/common/StreamingSnapshotTests.kt index 7d923716..890015e2 100644 --- a/common/src/test/kotlin/dev/shreyaspatil/ai/client/generativeai/common/StreamingSnapshotTests.kt +++ b/common/src/test/kotlin/dev/shreyaspatil/ai/client/generativeai/common/StreamingSnapshotTests.kt @@ -26,10 +26,9 @@ import io.kotest.matchers.shouldBe import io.kotest.matchers.string.shouldContain import io.ktor.http.HttpStatusCode import kotlinx.coroutines.flow.collect -import kotlinx.coroutines.flow.first import kotlinx.coroutines.flow.toList import kotlinx.coroutines.withTimeout -import org.junit.Test +import kotlin.test.Test import kotlin.time.Duration.Companion.seconds internal class StreamingSnapshotTests { diff --git a/common/src/test/kotlin/dev/shreyaspatil/ai/client/generativeai/common/UnarySnapshotTests.kt b/common/src/test/kotlin/dev/shreyaspatil/ai/client/generativeai/common/UnarySnapshotTests.kt index 6733eed3..9d30aec3 100644 --- a/common/src/test/kotlin/dev/shreyaspatil/ai/client/generativeai/common/UnarySnapshotTests.kt +++ b/common/src/test/kotlin/dev/shreyaspatil/ai/client/generativeai/common/UnarySnapshotTests.kt @@ -39,7 +39,7 @@ import io.kotest.matchers.types.shouldBeInstanceOf import io.ktor.http.HttpStatusCode import kotlinx.coroutines.withTimeout import kotlinx.serialization.Serializable -import org.junit.Test +import kotlin.test.Test import kotlin.time.Duration.Companion.seconds @Serializable internal data class MountainColors(val name: String, val colors: List) diff --git a/common/src/test/kotlin/dev/shreyaspatil/ai/client/generativeai/common/util/tests.kt b/common/src/test/kotlin/dev/shreyaspatil/ai/client/generativeai/common/util/tests.kt index c763891f..414d759c 100644 --- a/common/src/test/kotlin/dev/shreyaspatil/ai/client/generativeai/common/util/tests.kt +++ b/common/src/test/kotlin/dev/shreyaspatil/ai/client/generativeai/common/util/tests.kt @@ -104,9 +104,6 @@ internal fun commonTest( block: CommonTest, ) = doBlocking { val channel = ByteChannel(autoFlush = true) - val mockEngine = MockEngine { - respond(channel, status, headersOf(HttpHeaders.ContentType, "application/json")) - } val apiController = APIController( "super_cool_test_key", @@ -114,6 +111,13 @@ internal fun commonTest( requestOptions, TEST_CLIENT_ID, null, + MockEngine { + respond( + channel, + status, + headersOf(HttpHeaders.ContentType, "application/json"), + ) + }, channel, status, )