Skip to content

Commit

Permalink
fix(liveness): API to add liveness version to websocket (#2572)
Browse files Browse the repository at this point in the history
Co-authored-by: Thomas Leing <[email protected]>
Co-authored-by: Matt Creaser <[email protected]>
Co-authored-by: Tyler Roach <[email protected]>
  • Loading branch information
4 people authored Sep 13, 2023
1 parent 7a9343b commit d98c4b4
Show file tree
Hide file tree
Showing 4 changed files with 125 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import android.content.Context;
import android.graphics.Bitmap;
import androidx.annotation.NonNull;
import androidx.annotation.Nullable;
import androidx.annotation.VisibleForTesting;

import com.amplifyframework.annotations.InternalAmplifyApi;
Expand Down Expand Up @@ -319,19 +320,61 @@ public InterpretOperation<?> interpret(
* @param onError Called when an error occurs during the session.
*/
@InternalAmplifyApi
public static void startFaceLivenessSession(@NonNull String sessionId,
@NonNull FaceLivenessSessionInformation sessionInformation,
@NonNull Consumer<FaceLivenessSession> onSessionStarted,
@NonNull Action onComplete,
@NonNull Consumer<PredictionsException> onError) {
startFaceLivenessSession(sessionId, sessionInformation, FaceLivenessSessionOptions.defaults(),
onSessionStarted, onComplete, onError);
}

/**
* Starts a Liveness session with the given options.
* @param sessionId ID for the session to start.
* @param sessionInformation Information about the face liveness session.
* @param options The options for this session.
* @param onSessionStarted Called when the face liveness session has been started.
* @param onComplete Called when the session is complete.
* @param onError Called when an error occurs during the session.
*/
@InternalAmplifyApi
public static void startFaceLivenessSession(@NonNull String sessionId,
@NonNull FaceLivenessSessionInformation sessionInformation,
@NonNull FaceLivenessSessionOptions options,
@NonNull Consumer<FaceLivenessSession> onSessionStarted,
@NonNull Action onComplete,
@NonNull Consumer<PredictionsException> onError) {

startFaceLivenessSession(sessionId, sessionInformation, options, null,
onSessionStarted, onComplete, onError);
}

/**
* Starts a Liveness session.
* @param sessionId ID for the session to start.
* @param sessionInformation Information about the face liveness session.
* @param livenessVersion The version of liveness, which will be attached to the user agent.
* @param onSessionStarted Called when the face liveness session has been started.
* @param onComplete Called when the session is complete.
* @param onError Called when an error occurs during the session.
*/
@InternalAmplifyApi
public static void startFaceLivenessSession(@NonNull String sessionId,
@NonNull FaceLivenessSessionInformation sessionInformation,
@Nullable String livenessVersion,
@NonNull Consumer<FaceLivenessSession> onSessionStarted,
@NonNull Action onComplete,
@NonNull Consumer<PredictionsException> onError) {
startFaceLivenessSession(sessionId, sessionInformation, FaceLivenessSessionOptions.defaults(),
onSessionStarted, onComplete, onError);
livenessVersion, onSessionStarted, onComplete, onError);
}

/**
* Starts a Liveness session with the given options.
* @param sessionId ID for the session to start.
* @param sessionInformation Information about the face liveness session.
* @param livenessVersion The version of liveness, which will be attached to the user agent.
* @param options The options for this session.
* @param onSessionStarted Called when the face liveness session has been started.
* @param onComplete Called when the session is complete.
Expand All @@ -341,6 +384,7 @@ public static void startFaceLivenessSession(@NonNull String sessionId,
public static void startFaceLivenessSession(@NonNull String sessionId,
@NonNull FaceLivenessSessionInformation sessionInformation,
@NonNull FaceLivenessSessionOptions options,
@Nullable String livenessVersion,
@NonNull Consumer<FaceLivenessSession> onSessionStarted,
@NonNull Action onComplete,
@NonNull Consumer<PredictionsException> onError) {
Expand All @@ -358,6 +402,6 @@ public static void startFaceLivenessSession(@NonNull String sessionId,
.convertToSdkCredentialsProvider(awsCredentialsProvider);
}
new RunFaceLivenessSession(sessionId, sessionInformation, credentialsProvider,
onSessionStarted, onComplete, onError);
livenessVersion, onSessionStarted, onComplete, onError);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ internal class LivenessWebSocket(
val endpoint: String,
val region: String,
val sessionInformation: FaceLivenessSessionInformation,
val livenessVersion: String?,
val onSessionInformationReceived: Consumer<SessionInformation>,
val onErrorReceived: Consumer<PredictionsException>,
val onComplete: Action
Expand Down Expand Up @@ -197,12 +198,18 @@ internal class LivenessWebSocket(
}
}

private fun getUserAgent(): String {
@VisibleForTesting(otherwise = VisibleForTesting.PRIVATE)
fun getUserAgent(): String {
val amplifyVersion = BuildConfig.VERSION_NAME
val deviceManufacturer = Build.MANUFACTURER.replace(" ", "_")
val deviceName = Build.MODEL.replace(" ", "_")
val userAgent = "${UserAgent.string()} os/Android/${Build.VERSION.SDK_INT} md/device/$deviceName " +
var userAgent = "${UserAgent.string()} os/Android/${Build.VERSION.SDK_INT} md/device/$deviceName " +
"md/device-manufacturer/$deviceManufacturer api/rekognitionstreaming/$amplifyVersion"

if (!livenessVersion.isNullOrBlank()) {
userAgent += " api/liveness/$livenessVersion"
}

return userAgent.replace(Build.MANUFACTURER, deviceManufacturer).replace(Build.MODEL, deviceName)
.replace("+", "_")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ internal class RunFaceLivenessSession(
sessionId: String,
sessionInformation: FaceLivenessSessionInformation,
val credentialsProvider: CredentialsProvider,
livenessVersion: String?,
onSessionStarted: Consumer<FaceLivenessSession>,
onComplete: Action,
onError: Consumer<PredictionsException>
Expand All @@ -55,6 +56,7 @@ internal class RunFaceLivenessSession(
"${sessionInformation.videoWidth.toInt()}&video-height=${sessionInformation.videoHeight.toInt()}",
region = sessionInformation.region,
sessionInformation = sessionInformation,
livenessVersion = livenessVersion,
onSessionInformationReceived = { sessionInformation ->
val challenges = processSessionInformation(sessionInformation)
val faceLivenessSession = FaceLivenessSession(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@

package com.amplifyframework.predictions.aws.http

import android.os.Build
import aws.smithy.kotlin.runtime.auth.awscredentials.Credentials
import aws.smithy.kotlin.runtime.auth.awscredentials.CredentialsProvider
import aws.smithy.kotlin.runtime.util.Attributes
import com.amplifyframework.core.Action
import com.amplifyframework.core.BuildConfig
import com.amplifyframework.core.Consumer
import com.amplifyframework.predictions.PredictionsException
import com.amplifyframework.predictions.aws.models.liveness.ChallengeConfig
Expand Down Expand Up @@ -68,7 +70,6 @@ import org.robolectric.RobolectricTestRunner
internal class LivenessWebSocketTest {
private val json = Json { encodeDefaults = true }

private lateinit var livenessWebSocket: LivenessWebSocket
private lateinit var server: MockWebServer

private val onComplete = mockk<Action>(relaxed = true)
Expand All @@ -77,7 +78,11 @@ internal class LivenessWebSocketTest {
private val credentialsProvider = object : CredentialsProvider {
override suspend fun resolve(attributes: Attributes): Credentials {
return Credentials(
"", "", "", null, ""
"",
"",
"",
null,
""
)
}
}
Expand All @@ -86,18 +91,7 @@ internal class LivenessWebSocketTest {
@Before
fun setUp() {
Dispatchers.setMain(Dispatchers.Unconfined)

server = MockWebServer()

livenessWebSocket = LivenessWebSocket(
credentialsProvider,
server.url("/").toString(),
"",
sessionInformation,
onSessionInformationReceived,
onErrorReceived,
onComplete
)
}

@After
Expand All @@ -109,6 +103,7 @@ internal class LivenessWebSocketTest {
@Test
fun `onClosing informs webSocket`() {
val webSocket = mockk<WebSocket>(relaxed = true)
val livenessWebSocket = createLivenessWebSocket()
livenessWebSocket.webSocket = webSocket

livenessWebSocket.webSocketListener.onClosing(webSocket, 4, "closing")
Expand All @@ -118,20 +113,23 @@ internal class LivenessWebSocketTest {

@Test
fun `normal status onClosed calls onComplete`() {
val livenessWebSocket = createLivenessWebSocket()
livenessWebSocket.webSocketListener.onClosed(mockk(), 1000, "closing")

verify { onComplete.call() }
}

@Test
fun `bad status onClosed calls onError`() {
val livenessWebSocket = createLivenessWebSocket()
livenessWebSocket.webSocketListener.onClosed(mockk(), 5000, "closing")

verify { onErrorReceived.accept(any()) }
}

@Test
fun `onClosed does not call onError if client stopped`() {
val livenessWebSocket = createLivenessWebSocket()
livenessWebSocket.clientStoppedSession = true

livenessWebSocket.webSocketListener.onClosed(mockk(), 5000, "closing")
Expand All @@ -141,6 +139,7 @@ internal class LivenessWebSocketTest {

@Test
fun `onFailure calls onError`() {
val livenessWebSocket = createLivenessWebSocket()
// Response does noted like to be mockk
val response = Response.Builder()
.code(200)
Expand All @@ -156,6 +155,7 @@ internal class LivenessWebSocketTest {

@Test
fun `onFailure does not call onError if client stopped`() {
val livenessWebSocket = createLivenessWebSocket()
livenessWebSocket.clientStoppedSession = true
// Response does noted like to be mockk
val response = Response.Builder()
Expand All @@ -172,6 +172,7 @@ internal class LivenessWebSocketTest {

@Test
fun `web socket assigned on open`() {
val livenessWebSocket = createLivenessWebSocket()
val openLatch = CountDownLatch(1)
val latchingListener = LatchingWebSocketResponseListener(
livenessWebSocket.webSocketListener,
Expand Down Expand Up @@ -200,6 +201,7 @@ internal class LivenessWebSocketTest {

@Test
fun `server session event tracked`() {
val livenessWebSocket = createLivenessWebSocket()
val event = ServerSessionInformationEvent(
sessionInformation = SessionInformation(
challenge = ServerChallenge(
Expand Down Expand Up @@ -242,6 +244,7 @@ internal class LivenessWebSocketTest {

@Test
fun `disconnect event stops websocket`() {
val livenessWebSocket = createLivenessWebSocket()
livenessWebSocket.webSocket = mockk()
val event = DisconnectionEvent(1)
val headers = mapOf(
Expand All @@ -260,6 +263,7 @@ internal class LivenessWebSocketTest {

@Test
fun `web socket error closes websocket`() {
val livenessWebSocket = createLivenessWebSocket()
livenessWebSocket.webSocket = mockk()
val event = ValidationException("ValidationException")
val headers = mapOf(
Expand All @@ -276,6 +280,43 @@ internal class LivenessWebSocketTest {
verify { livenessWebSocket.webSocket!!.close(1000, any()) }
}

@Test
fun `web socket user agent with null UI version`() {
val livenessWebSocket = createLivenessWebSocket(livenessVersion = null)
livenessWebSocket.webSocket = mockk()

val version = BuildConfig.VERSION_NAME
val os = Build.VERSION.SDK_INT
val baseline = "amplify-android:$version md/unknown/robolectric md/locale/en_UNKNOWN os/Android/$os " +
"md/device/robolectric md/device-manufacturer/unknown api/rekognitionstreaming/$version"
assertEquals(livenessWebSocket.getUserAgent(), baseline)
}

@Test
fun `web socket user agent with blank UI version`() {
val livenessWebSocket = createLivenessWebSocket(livenessVersion = " ")
livenessWebSocket.webSocket = mockk()

val version = BuildConfig.VERSION_NAME
val os = Build.VERSION.SDK_INT
val baseline = "amplify-android:$version md/unknown/robolectric md/locale/en_UNKNOWN os/Android/$os " +
"md/device/robolectric md/device-manufacturer/unknown api/rekognitionstreaming/$version"
assertEquals(livenessWebSocket.getUserAgent(), baseline)
}

@Test
fun `web socket user agent includes added UI version`() {
val livenessWebSocket = createLivenessWebSocket(livenessVersion = "1.1.1")
livenessWebSocket.webSocket = mockk()

val version = BuildConfig.VERSION_NAME
val os = Build.VERSION.SDK_INT
val baseline = "amplify-android:$version md/unknown/robolectric md/locale/en_UNKNOWN os/Android/$os " +
"md/device/robolectric md/device-manufacturer/unknown api/rekognitionstreaming/$version"
val additional = "api/liveness/1.1.1"
assertEquals(livenessWebSocket.getUserAgent(), "$baseline $additional")
}

@Test
@Ignore("Need to work on parsing the onMessage byteString from ServerWebSocketListener")
fun `sendInitialFaceDetectedEvent test`() {
Expand All @@ -300,11 +341,24 @@ internal class LivenessWebSocketTest {
@Ignore("Need to work on parsing the onMessage byteString from ServerWebSocketListener")
fun `sendVideoEvent test`() {
}

private fun createLivenessWebSocket(
livenessVersion: String? = null
) = LivenessWebSocket(
credentialsProvider,
server.url("/").toString(),
"",
sessionInformation,
livenessVersion,
onSessionInformationReceived,
onErrorReceived,
onComplete
)
}

class LatchingWebSocketResponseListener(
private val webSocketListener: WebSocketListener,
private val openLatch: CountDownLatch = CountDownLatch(1),
private val openLatch: CountDownLatch = CountDownLatch(1)
) : WebSocketListener() {

override fun onClosed(webSocket: WebSocket, code: Int, reason: String) {
Expand Down

0 comments on commit d98c4b4

Please sign in to comment.