Skip to content

Commit

Permalink
Update tests with kotlin test usages
Browse files Browse the repository at this point in the history
  • Loading branch information
PatilShreyas committed Sep 21, 2024
1 parent c988de5 commit 573d41d
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -84,19 +84,21 @@ internal constructor(
headerProvider: HeaderProvider? = null,
) : this(key, model, requestOptions, null, apiClient, headerProvider)

// @VisibleForTesting(otherwise = VisibleForTesting.NONE)
constructor(
key: String,
model: String,
requestOptions: RequestOptions,
apiClient: String,
headerProvider: HeaderProvider?,
httpEngine: HttpClientEngine?,
channel: ByteChannel,
status: HttpStatusCode,
) : this(
key,
model,
requestOptions,
null,
httpEngine,
apiClient,
headerProvider,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,7 @@ import kotlinx.coroutines.delay
import kotlinx.coroutines.withTimeout
import kotlinx.serialization.encodeToString
import kotlinx.serialization.json.JsonObject
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.Parameterized
import kotlin.test.Test
import kotlin.time.Duration
import kotlin.time.Duration.Companion.milliseconds
import kotlin.time.Duration.Companion.seconds
Expand Down Expand Up @@ -295,42 +293,39 @@ internal class RequestFormatTests {

requestBodyAsText shouldContainJsonKey "tools[0].codeExecution"
}
}

@RunWith(Parameterized::class)
internal class ModelNamingTests(private val modelName: String, private val actualName: String) {

@Test
fun `request should include right model name`() = doBlocking {
val channel = ByteChannel(autoFlush = true)
val mockEngine = MockEngine {
respond(channel, HttpStatusCode.OK, headersOf(HttpHeaders.ContentType, "application/json"))
}
prepareStreamingResponse(createResponses("Random")).forEach { channel.writeFully(it) }
val controller =
APIController(
"super_cool_test_key",
modelName,
RequestOptions(),
mockEngine,
TEST_CLIENT_ID,
null,
)
val models = models()
models.forEach { (modelName, actualName) ->
val channel = ByteChannel(autoFlush = true)
val mockEngine = MockEngine {
respond(channel, HttpStatusCode.OK, headersOf(HttpHeaders.ContentType, "application/json"))
}
prepareStreamingResponse(createResponses("Random")).forEach { channel.writeFully(it) }
val controller =
APIController(
"super_cool_test_key",
modelName,
RequestOptions(),
mockEngine,
TEST_CLIENT_ID,
null,
)

withTimeout(5.seconds) {
controller.generateContentStream(textGenerateContentRequest("cats")).collect {
it.candidates?.isEmpty() shouldBe false
channel.close()
withTimeout(5.seconds) {
controller.generateContentStream(textGenerateContentRequest("cats")).collect {
it.candidates?.isEmpty() shouldBe false
channel.close()
}
}
}

mockEngine.requestHistory.first().url.encodedPath shouldContain actualName
mockEngine.requestHistory.first().url.encodedPath shouldContain actualName
}
}

companion object {
@JvmStatic
@Parameterized.Parameters
fun data() =
fun models() =
listOf(
arrayOf("gemini-pro", "models/gemini-pro"),
arrayOf("x/gemini-pro", "x/gemini-pro"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,9 @@ import io.kotest.matchers.shouldBe
import io.kotest.matchers.string.shouldContain
import io.ktor.http.HttpStatusCode
import kotlinx.coroutines.flow.collect
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.flow.toList
import kotlinx.coroutines.withTimeout
import org.junit.Test
import kotlin.test.Test
import kotlin.time.Duration.Companion.seconds

internal class StreamingSnapshotTests {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ import io.kotest.matchers.types.shouldBeInstanceOf
import io.ktor.http.HttpStatusCode
import kotlinx.coroutines.withTimeout
import kotlinx.serialization.Serializable
import org.junit.Test
import kotlin.test.Test
import kotlin.time.Duration.Companion.seconds

@Serializable internal data class MountainColors(val name: String, val colors: List<String>)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,16 +104,20 @@ internal fun commonTest(
block: CommonTest,
) = doBlocking {
val channel = ByteChannel(autoFlush = true)
val mockEngine = MockEngine {
respond(channel, status, headersOf(HttpHeaders.ContentType, "application/json"))
}
val apiController =
APIController(
"super_cool_test_key",
"gemini-pro",
requestOptions,
TEST_CLIENT_ID,
null,
MockEngine {
respond(
channel,
status,
headersOf(HttpHeaders.ContentType, "application/json"),
)
},
channel,
status,
)
Expand Down

0 comments on commit 573d41d

Please sign in to comment.