Skip to content

Commit

Permalink
fix: handle the case where asset name can be missing [WPB-10830] πŸ’ (#…
Browse files Browse the repository at this point in the history
…3002)

* Commit with unresolved merge conflicts

* merge conflicts

---------

Co-authored-by: Mohamad Jaara <[email protected]>
  • Loading branch information
github-actions[bot] and MohamadJaara authored Sep 13, 2024
1 parent 987b782 commit a04b87f
Show file tree
Hide file tree
Showing 7 changed files with 327 additions and 58 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ data class AssetContent(
}

// We should not display Preview Assets (assets w/o valid encryption keys sent by Mac/Web clients) unless they include image metadata
val shouldBeDisplayed = !isPreviewMessage || hasValidImageMetadata
val isAssetDataComplete = !isPreviewMessage || hasValidImageMetadata

sealed class AssetMetadata {
data class Image(val width: Int, val height: Int) : AssetMetadata()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,12 @@ internal class ScheduleNewAssetMessageUseCaseImpl(
FileSharingStatus.Value.Disabled -> return ScheduleNewAssetMessageResult.Failure.DisabledByTeam
FileSharingStatus.Value.EnabledAll -> { /* no-op*/ }

is FileSharingStatus.Value.EnabledSome -> if (!validateAssetFileUseCase(assetName, it.state.allowedType)) {
is FileSharingStatus.Value.EnabledSome -> if (!validateAssetFileUseCase(
fileName = assetName,
mimeType = assetMimeType,
allowedExtension = it.state.allowedType
)
) {
kaliumLogger.e("The asset message trying to be processed has invalid content data")
return ScheduleNewAssetMessageResult.Failure.RestrictedFileType
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,26 +17,95 @@
*/
package com.wire.kalium.logic.feature.asset

import com.wire.kalium.logic.kaliumLogger

/**
* Returns true if the file extension is present in file name and is allowed and false otherwise.
* @param fileName the file name (with extension) to validate.
* @param allowedExtension the list of allowed extension.
*/
interface ValidateAssetFileTypeUseCase {
operator fun invoke(fileName: String?, allowedExtension: List<String>): Boolean
operator fun invoke(
fileName: String?,
mimeType: String,
allowedExtension: List<String>
): Boolean
}

internal class ValidateAssetFileTypeUseCaseImpl : ValidateAssetFileTypeUseCase {
override operator fun invoke(fileName: String?, allowedExtension: List<String>): Boolean {
if (fileName == null) return false

val split = fileName.split(".")
return if (split.size < 2) {
false
override operator fun invoke(
fileName: String?,
mimeType: String,
allowedExtension: List<String>
): Boolean {
kaliumLogger.d("Validating file type for $fileName with mimeType $mimeType is empty ${mimeType.isBlank()}")
val extension = if (fileName != null) {
extensionFromFileName(fileName)
} else {
val allowedExtensionLowerCase = allowedExtension.map { it.lowercase() }
val extensions = split.subList(1, split.size).map { it.lowercase() }
extensions.all { it.isNotEmpty() && allowedExtensionLowerCase.contains(it) }
extensionFromMimeType(mimeType)
}
return extension?.let { allowedExtension.contains(it) } ?: false
}

private fun extensionFromFileName(fileName: String): String? =
fileName.substringAfterLast('.', "").takeIf { it.isNotEmpty() }

private fun extensionFromMimeType(mimeType: String): String? = fileExtensions[mimeType]

private companion object {
val fileExtensions = mapOf(
"video/3gpp" to "3gpp",
"audio/aac" to "aac",
"audio/amr" to "amr",
"video/x-msvideo" to "avi",
"image/bmp" to "bmp",
"text/css" to "css",
"text/csv" to "csv",
"application/msword" to "doc",
"application/vnd.openxmlformats-officedocument.wordprocessingml.document" to "docx",
"message/rfc822" to "eml",
"audio/flac" to "flac",
"image/gif" to "gif",
"text/html" to "html",
"image/vnd.microsoft.icon" to "ico",
"image/jpeg" to "jpeg",
"image/jpeg" to "jpg",
"image/jpeg" to "jfif",
"application/vnd.apple.keynote" to "key",
"audio/mp4" to "m4a",
"video/x-m4v" to "m4v",
"text/markdown" to "md",
"audio/midi" to "midi",
"video/x-matroska" to "mkv",
"video/quicktime" to "mov",
"audio/mpeg" to "mp3",
"video/mp4" to "mp4",
"video/mpeg" to "mpeg",
"application/vnd.ms-outlook" to "msg",
"application/vnd.oasis.opendocument.spreadsheet" to "ods",
"application/vnd.oasis.opendocument.text" to "odt",
"audio/ogg" to "ogg",
"application/pdf" to "pdf",
"image/jpeg" to "pjp",
"image/pjpeg" to "pjpeg",
"image/png" to "png",
"application/vnd.ms-powerpoint" to "ppt",
"application/vnd.openxmlformats-officedocument.presentationml.presentation" to "pptx",
"image/vnd.adobe.photoshop" to "psd",
"application/rtf" to "rtf",
"application/sql" to "sql",
"image/svg+xml" to "svg",
"application/x-tex" to "tex",
"image/tiff" to "tiff",
"text/plain" to "txt",
"text/x-vcard" to "vcf",
"audio/wav" to "wav",
"video/webm" to "webm",
"image/webp" to "webp",
"video/x-ms-wmv" to "wmv",
"application/vnd.ms-excel" to "xls",
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" to "xlsx",
"application/xml" to "xml"
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,12 @@ import com.wire.kalium.logic.data.message.Message
import com.wire.kalium.logic.data.message.MessageContent
import com.wire.kalium.logic.data.message.MessageRepository
import com.wire.kalium.logic.data.message.PersistMessageUseCase
import com.wire.kalium.logic.data.message.hasValidData
import com.wire.kalium.logic.data.message.getType
import com.wire.kalium.logic.feature.asset.ValidateAssetFileTypeUseCase
import com.wire.kalium.logic.functional.onFailure
import com.wire.kalium.logic.functional.onSuccess
import com.wire.kalium.logic.kaliumLogger
import com.wire.kalium.logic.sync.receiver.conversation.message.hasValidData

internal interface AssetMessageHandler {
suspend fun handle(message: Message.Regular)
Expand All @@ -42,26 +43,55 @@ internal class AssetMessageHandlerImpl(
) : AssetMessageHandler {

override suspend fun handle(message: Message.Regular) {
val messageContent = message.content
if (messageContent !is MessageContent.Asset) {
if (message.content !is MessageContent.Asset) {
kaliumLogger.e("The asset message trying to be processed has invalid content data")
return
}

val messageContent = message.content as MessageContent.Asset

userConfigRepository.isFileSharingEnabled().onSuccess {
val isThisAssetAllowed = when (it.state) {
FileSharingStatus.Value.Disabled -> false
FileSharingStatus.Value.EnabledAll -> true
FileSharingStatus.Value.Disabled -> AssetRestrictionContinuationStrategy.Restrict
FileSharingStatus.Value.EnabledAll -> AssetRestrictionContinuationStrategy.Continue

is FileSharingStatus.Value.EnabledSome -> validateAssetMimeTypeUseCase(
messageContent.value.name,
it.state.allowedType
)
is FileSharingStatus.Value.EnabledSome -> {
// If the asset message is missing the name, but it does have full
// asset data then we can not decide now if it is allowed or not
// it is safe to continue and the code later will check the original
// asset message and decide if it is allowed or not
if (
messageContent.value.name.isNullOrEmpty() &&
messageContent.value.isAssetDataComplete
) {
kaliumLogger.e("The asset message trying to be processed has invalid data looking locally")
AssetRestrictionContinuationStrategy.RestrictIfThereIsNotOldMessageWithTheSameAssetID
} else {
validateAssetMimeTypeUseCase(
fileName = messageContent.value.name,
mimeType = messageContent.value.mimeType,
allowedExtension = it.state.allowedType
).let { validateResult ->
if (validateResult) {
AssetRestrictionContinuationStrategy.Continue
} else {
AssetRestrictionContinuationStrategy.Restrict
}
}
}
}
}

if (isThisAssetAllowed) {
processNonRestrictedAssetMessage(message, messageContent)
} else {
persistRestrictedAssetMessage(message, messageContent)
when (isThisAssetAllowed) {
AssetRestrictionContinuationStrategy.Continue -> processNonRestrictedAssetMessage(message, messageContent, false)
AssetRestrictionContinuationStrategy.RestrictIfThereIsNotOldMessageWithTheSameAssetID -> processNonRestrictedAssetMessage(
message,
messageContent,
true
)

AssetRestrictionContinuationStrategy.Restrict -> persistRestrictedAssetMessage(message, messageContent)

}
}
}
Expand All @@ -77,23 +107,34 @@ internal class AssetMessageHandlerImpl(
persistMessage(newMessage)
}

private suspend fun processNonRestrictedAssetMessage(processedMessage: Message.Regular, assetContent: MessageContent.Asset) {
private suspend fun processNonRestrictedAssetMessage(
processedMessage: Message.Regular,
assetContent: MessageContent.Asset,
restrictIfNotAFollowUpMessage: Boolean
) {
messageRepository.getMessageById(processedMessage.conversationId, processedMessage.id).onFailure {
// No asset message was received previously, so just persist the preview of the asset message
// Web/Mac clients split the asset message delivery into 2. One with the preview metadata (assetName, assetSize...) and
// with empty encryption keys and the second with empty metadata but all the correct encryption keys. We just want to
// hide the preview of generic asset messages with empty encryption keys as a way to avoid user interaction with them.
val initialMessage = processedMessage.copy(
visibility = if (assetContent.value.shouldBeDisplayed) Message.Visibility.VISIBLE else Message.Visibility.HIDDEN
)
persistMessage(initialMessage)

if (restrictIfNotAFollowUpMessage) {
persistRestrictedAssetMessage(processedMessage, assetContent)
} else {
val initialMessage = processedMessage.copy(
visibility = if (assetContent.value.isAssetDataComplete) Message.Visibility.VISIBLE else Message.Visibility.HIDDEN
)
persistMessage(initialMessage)
}
}.onSuccess { persistedMessage ->
val validDecryptionKeys = assetContent.value.remoteData
// Check the second asset message is from the same original sender
if (isSenderVerified(persistedMessage, processedMessage) && persistedMessage is Message.Regular) {
// The second asset message received from Web/Mac clients contains the full asset decryption keys, so we need to update
// the preview message persisted previously with the rest of the data
persistMessage(updateAssetMessageWithDecryptionKeys(persistedMessage, validDecryptionKeys))
updateAssetMessageWithDecryptionKeys(persistedMessage, validDecryptionKeys)?.let {
persistMessage(it)
}
} else {
kaliumLogger.e("The previously persisted message has a different sender id than the one we are trying to process")
}
Expand All @@ -106,8 +147,21 @@ internal class AssetMessageHandlerImpl(
private fun updateAssetMessageWithDecryptionKeys(
persistedMessage: Message.Regular,
remoteData: AssetContent.RemoteData
): Message.Regular {
val assetMessageContent = persistedMessage.content as MessageContent.Asset
): Message.Regular? {
val assetMessageContent = when (persistedMessage.content) {
is MessageContent.Asset -> persistedMessage.content as MessageContent.Asset
is MessageContent.RestrictedAsset -> {
// original message was a restricted asset message, ignoring
return null
}

is MessageContent.FailedDecryption,
is MessageContent.Knock,
is MessageContent.Location,
is MessageContent.Composite,
is MessageContent.Text,
is MessageContent.Unknown -> error("Invalid asset message content type ${persistedMessage.content.getType()}")
}
// The message was previously received with just metadata info, so let's update it with the raw data info
return persistedMessage.copy(
content = assetMessageContent.copy(
Expand All @@ -120,3 +174,9 @@ internal class AssetMessageHandlerImpl(
)
}
}

private sealed interface AssetRestrictionContinuationStrategy {
data object Continue : AssetRestrictionContinuationStrategy
data object Restrict : AssetRestrictionContinuationStrategy
data object RestrictIfThereIsNotOldMessageWithTheSameAssetID : AssetRestrictionContinuationStrategy
}
Original file line number Diff line number Diff line change
Expand Up @@ -628,7 +628,13 @@ class ScheduleNewAssetMessageUseCaseTest {
// Then
assertTrue(result is ScheduleNewAssetMessageResult.Failure.RestrictedFileType)

coVerify { arrangement.validateAssetMimeTypeUseCase(eq("some-asset.txt"), eq(listOf("png"))) }
coVerify {
arrangement.validateAssetMimeTypeUseCase(
fileName = eq("some-asset.txt"),
mimeType = eq("text/plain"),
allowedExtension = eq(listOf("png"))
)
}
.wasInvoked(exactly = once)
}

Expand Down Expand Up @@ -667,7 +673,13 @@ class ScheduleNewAssetMessageUseCaseTest {
// Then
assertTrue(result is ScheduleNewAssetMessageResult.Success)

coVerify { arrangement.validateAssetMimeTypeUseCase(eq("some-asset.png"), eq(listOf("png"))) }
coVerify {
arrangement.validateAssetMimeTypeUseCase(
fileName = eq("some-asset.png"),
mimeType = eq("image/png"),
allowedExtension = eq(listOf("png"))
)
}
.wasInvoked(exactly = once)
}

Expand Down Expand Up @@ -721,7 +733,7 @@ class ScheduleNewAssetMessageUseCaseTest {

fun withValidateAsseMimeTypeResult(result: Boolean) = apply {
every {
validateAssetMimeTypeUseCase.invoke(any(), any())
validateAssetMimeTypeUseCase.invoke(any(), any(), any())
}.returns(result)
}

Expand Down
Loading

0 comments on commit a04b87f

Please sign in to comment.