From 96901ac84030fdaf09f23cc6ffcf396384a149ec Mon Sep 17 00:00:00 2001 From: Carlos Ballesteros Velasco Date: Fri, 27 Nov 2020 18:50:09 +0100 Subject: [PATCH] Fixes Base64 decode when no padding is provided (#37) * Add tests from Korio + add test reproducing and fixing the issue #64 * Fixes Base64 decoding when no padding is provided --- .../com/soywiz/krypto/encoding/Base64.kt | 8 ++-- .../kotlin/com/soywiz/krypto/Base64Test.kt | 45 +++++++++++++++++++ 2 files changed, 49 insertions(+), 4 deletions(-) create mode 100644 krypto/src/commonTest/kotlin/com/soywiz/krypto/Base64Test.kt diff --git a/krypto/src/commonMain/kotlin/com/soywiz/krypto/encoding/Base64.kt b/krypto/src/commonMain/kotlin/com/soywiz/krypto/encoding/Base64.kt index fa460ea..1035ad4 100644 --- a/krypto/src/commonMain/kotlin/com/soywiz/krypto/encoding/Base64.kt +++ b/krypto/src/commonMain/kotlin/com/soywiz/krypto/encoding/Base64.kt @@ -33,10 +33,10 @@ object Base64 { continue // skip character } - val b0 = DECODE[src.readU8(n++)] - val b1 = DECODE[src.readU8(n++)] - val b2 = DECODE[src.readU8(n++)] - val b3 = DECODE[src.readU8(n++)] + val b0 = if (n < src.size) DECODE[src.readU8(n++)] else 64 + val b1 = if (n < src.size) DECODE[src.readU8(n++)] else 64 + val b2 = if (n < src.size) DECODE[src.readU8(n++)] else 64 + val b3 = if (n < src.size) DECODE[src.readU8(n++)] else 64 dst[m++] = (b0 shl 2 or (b1 shr 4)).toByte() if (b2 < 64) { dst[m++] = (b1 shl 4 or (b2 shr 2)).toByte() diff --git a/krypto/src/commonTest/kotlin/com/soywiz/krypto/Base64Test.kt b/krypto/src/commonTest/kotlin/com/soywiz/krypto/Base64Test.kt new file mode 100644 index 0000000..3f06379 --- /dev/null +++ b/krypto/src/commonTest/kotlin/com/soywiz/krypto/Base64Test.kt @@ -0,0 +1,45 @@ +package com.soywiz.krypto + +import com.soywiz.krypto.encoding.Base64 +import com.soywiz.krypto.encoding.fromBase64 +import com.soywiz.krypto.encoding.toBase64 +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertTrue + +class Base64Test { + @Test + fun name() { + assertEquals("AQID", Base64.encode(byteArrayOf(1, 2, 3))) + assertEquals("aGVsbG8=", Base64.encode("hello".encodeToByteArray())) + assertEquals(byteArrayOf(1, 2, 3).toList(), Base64.decode("AQID").toList()) + assertEquals("hello", Base64.decode("aGVsbG8=").decodeToString()) + } + + @Test + fun testSeveral() { + for (item in listOf("", "a", "aa", "aaa", "aaaa", "aaaaa", "Hello World!")) { + assertEquals(item, item.encodeToByteArray().toBase64().fromBase64().decodeToString()) + } + } + + @Test + fun testGlobal() { + assertEquals("hello", globalBase64) + assertEquals("hello", ObjectBase64.globalBase64) + } + + @Test + fun testIssue64DecodeWithMissingPadding() { + assertEquals( + "eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ==", + Base64.encode(Base64.decode("eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ")) + ) + } +} + +object ObjectBase64 { + val globalBase64 = "aGVsbG8=".fromBase64().decodeToString() +} + +val globalBase64 = "aGVsbG8=".fromBase64().decodeToString()