From b8c88a3e9129bd2f976a8c7631d754fed0765324 Mon Sep 17 00:00:00 2001 From: Tobias Hartmann Date: Thu, 4 Jan 2024 09:16:19 +0000 Subject: [PATCH] 8321599: Data loss in AVX3 Base64 decoding Reviewed-by: chagedorn Backport-of: 13c11487f7126a370d9ce8e62f661ea83eedefe6 --- src/hotspot/cpu/x86/stubGenerator_x86_64.cpp | 6 +- .../intrinsics/base64/TestBase64.java | 121 +++++++++++++++++- 2 files changed, 124 insertions(+), 3 deletions(-) diff --git a/src/hotspot/cpu/x86/stubGenerator_x86_64.cpp b/src/hotspot/cpu/x86/stubGenerator_x86_64.cpp index c73e0759b57..9abc559090a 100644 --- a/src/hotspot/cpu/x86/stubGenerator_x86_64.cpp +++ b/src/hotspot/cpu/x86/stubGenerator_x86_64.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2003, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2003, 2024, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -2318,7 +2318,7 @@ address StubGenerator::generate_base64_decodeBlock() { const Register isURL = c_rarg5;// Base64 or URL character set __ movl(isMIME, Address(rbp, 2 * wordSize)); #else - const Address dp_mem(rbp, 6 * wordSize); // length is on stack on Win64 + const Address dp_mem(rbp, 6 * wordSize); // length is on stack on Win64 const Address isURL_mem(rbp, 7 * wordSize); const Register isURL = r10; // pick the volatile windows register const Register dp = r12; @@ -2540,10 +2540,12 @@ address StubGenerator::generate_base64_decodeBlock() { // output_size in r13 // Strip pad characters, if any, and adjust length and mask + __ addq(length, start_offset); __ cmpb(Address(source, length, Address::times_1, -1), '='); __ jcc(Assembler::equal, L_padding); __ BIND(L_donePadding); + __ subq(length, start_offset); // Output size is (64 - output_size), output mask is (all 1s >> output_size). __ kmovql(input_mask, rax); diff --git a/test/hotspot/jtreg/compiler/intrinsics/base64/TestBase64.java b/test/hotspot/jtreg/compiler/intrinsics/base64/TestBase64.java index 5d2651c3285..0d3c9569c33 100644 --- a/test/hotspot/jtreg/compiler/intrinsics/base64/TestBase64.java +++ b/test/hotspot/jtreg/compiler/intrinsics/base64/TestBase64.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018, 2022, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2018, 2024, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -46,10 +46,13 @@ import java.util.Base64; import java.util.Base64.Decoder; import java.util.Base64.Encoder; +import java.util.HexFormat; import java.util.Objects; import java.util.Random; import java.util.Arrays; +import static java.lang.String.format; + import compiler.whitebox.CompilerWhiteBoxTest; import jdk.test.whitebox.code.Compiler; import jtreg.SkippedException; @@ -69,6 +72,8 @@ public static void main(String[] args) throws Exception { warmup(); + length_checks(); + test0(FileType.ASCII, Base64Type.BASIC, Base64.getEncoder(), Base64.getDecoder(),"plain.txt", "baseEncode.txt", iters); test0(FileType.ASCII, Base64Type.URLSAFE, Base64.getUrlEncoder(), Base64.getUrlDecoder(),"plain.txt", "urlEncode.txt", iters); test0(FileType.ASCII, Base64Type.MIME, Base64.getMimeEncoder(), Base64.getMimeDecoder(),"plain.txt", "mimeEncode.txt", iters); @@ -302,4 +307,118 @@ private static final byte getBadBase64Char(Base64Type b64Type) { throw new InternalError("Internal test error: getBadBase64Char called with unknown Base64Type value"); } } + + static final int POSITIONS = 30_000; + static final int BASE_LENGTH = 256; + static final HexFormat HEX_FORMAT = HexFormat.of().withUpperCase().withDelimiter(" "); + + static int[] plainOffsets = new int[POSITIONS + 1]; + static byte[] plainBytes; + static int[] base64Offsets = new int[POSITIONS + 1]; + static byte[] base64Bytes; + + static { + // Set up ByteBuffer with characters to be encoded + int plainLength = 0; + for (int i = 0; i < plainOffsets.length; i++) { + plainOffsets[i] = plainLength; + int positionLength = (BASE_LENGTH + i) % 2048; + plainLength += positionLength; + } + // Put one of each possible byte value into ByteBuffer + plainBytes = new byte[plainLength]; + for (int i = 0; i < plainBytes.length; i++) { + plainBytes[i] = (byte) i; + } + + // Grab various slices of the ByteBuffer and encode them + ByteBuffer plainBuffer = ByteBuffer.wrap(plainBytes); + int base64Length = 0; + for (int i = 0; i < POSITIONS; i++) { + base64Offsets[i] = base64Length; + int offset = plainOffsets[i]; + int length = plainOffsets[i + 1] - offset; + ByteBuffer plainSlice = plainBuffer.slice(offset, length); + base64Length += Base64.getEncoder().encode(plainSlice).remaining(); + } + + // Decode the slices created above and ensure lengths match + base64Offsets[base64Offsets.length - 1] = base64Length; + base64Bytes = new byte[base64Length]; + for (int i = 0; i < POSITIONS; i++) { + int plainOffset = plainOffsets[i]; + ByteBuffer plainSlice = plainBuffer.slice(plainOffset, plainOffsets[i + 1] - plainOffset); + ByteBuffer encodedBytes = Base64.getEncoder().encode(plainSlice); + int base64Offset = base64Offsets[i]; + int expectedLength = base64Offsets[i + 1] - base64Offset; + if (expectedLength != encodedBytes.remaining()) { + throw new IllegalStateException(format("Unexpected length: %s <> %s", encodedBytes.remaining(), expectedLength)); + } + encodedBytes.get(base64Bytes, base64Offset, expectedLength); + } + } + + public static void length_checks() { + decodeAndCheck(); + encodeDecode(); + System.out.println("Test complete, no invalid decodes detected"); + } + + // Use ByteBuffer to cause decode() to use the base + offset form of decode + // Checks for bug reported in JDK-8321599 where padding characters appear + // within the beginning of the ByteBuffer *before* the offset. This caused + // the decoded string length to be off by 1 or 2 bytes. + static void decodeAndCheck() { + for (int i = 0; i < POSITIONS; i++) { + ByteBuffer encodedBytes = base64BytesAtPosition(i); + ByteBuffer decodedBytes = Base64.getDecoder().decode(encodedBytes); + + if (!decodedBytes.equals(plainBytesAtPosition(i))) { + String base64String = base64StringAtPosition(i); + String plainHexString = plainHexStringAtPosition(i); + String decodedHexString = HEX_FORMAT.formatHex(decodedBytes.array(), decodedBytes.arrayOffset() + decodedBytes.position(), decodedBytes.arrayOffset() + decodedBytes.limit()); + throw new IllegalStateException(format("Mismatch for %s\n\nExpected:\n%s\n\nActual:\n%s", base64String, plainHexString, decodedHexString)); + } + } + } + + // Encode strings of lengths 1-1K, decode, and ensure length and contents correct. + // This checks that padding characters are properly handled by decode. + static void encodeDecode() { + String allAs = "A(=)".repeat(128); + for (int i = 1; i <= 512; i++) { + String encStr = Base64.getEncoder().encodeToString(allAs.substring(0, i).getBytes()); + String decStr = new String(Base64.getDecoder().decode(encStr)); + + if ((decStr.length() != allAs.substring(0, i).length()) || + (!Objects.equals(decStr, allAs.substring(0, i))) + ) { + throw new IllegalStateException(format("Mismatch: Expected: %s\n Actual: %s\n", allAs.substring(0, i), decStr)); + } + } + } + + static ByteBuffer plainBytesAtPosition(int position) { + int offset = plainOffsets[position]; + int length = plainOffsets[position + 1] - offset; + return ByteBuffer.wrap(plainBytes, offset, length); + } + + static String plainHexStringAtPosition(int position) { + int offset = plainOffsets[position]; + int length = plainOffsets[position + 1] - offset; + return HEX_FORMAT.formatHex(plainBytes, offset, offset + length); + } + + static String base64StringAtPosition(int position) { + int offset = base64Offsets[position]; + int length = base64Offsets[position + 1] - offset; + return new String(base64Bytes, offset, length, StandardCharsets.UTF_8); + } + + static ByteBuffer base64BytesAtPosition(int position) { + int offset = base64Offsets[position]; + int length = base64Offsets[position + 1] - offset; + return ByteBuffer.wrap(base64Bytes, offset, length); + } }