diff --git a/base64.cpp b/base64.cpp index 9354e3c..7666ef8 100644 --- a/base64.cpp +++ b/base64.cpp @@ -191,17 +191,45 @@ static std::string decode(String encoded_string, bool remove_linebreaks) { ret.reserve(approx_length_of_decoded_string); while (pos < length_of_string) { - - unsigned int pos_of_char_1 = pos_of_char(encoded_string[pos+1] ); - + // + // Iterate over encoded input string in chunks. The size of all + // chunks except the last one is 4 bytes. + // + // The last chunk might be padded with equal signs or dots + // in order to make it 4 bytes in size as well, but this + // is not required as per RFC 2045. + // + // All chunks except the last one produce three output bytes. + // + // The last chunk produces at least one and up to three bytes. + // + + size_t pos_of_char_1 = pos_of_char(encoded_string[pos+1] ); + + // + // Emit the first output byte that is produced in each chunk: + // ret.push_back(static_cast( ( (pos_of_char(encoded_string[pos+0]) ) << 2 ) + ( (pos_of_char_1 & 0x30 ) >> 4))); - if (encoded_string[pos+2] != '=' && encoded_string[pos+2] != '.') { // accept URL-safe base 64 strings, too, so check for '.' also. - + if ( ( pos + 2 < length_of_string ) && // Check for data that is not padded with equal signs (which is allowed by RFC 2045) + encoded_string[pos+2] != '=' && + encoded_string[pos+2] != '.' // accept URL-safe base 64 strings, too, so check for '.' also. + ) + { + // + // Emit a chunk's second byte (which might not be produced in the last chunk). + // unsigned int pos_of_char_2 = pos_of_char(encoded_string[pos+2] ); ret.push_back(static_cast( (( pos_of_char_1 & 0x0f) << 4) + (( pos_of_char_2 & 0x3c) >> 2))); - if (encoded_string[pos+3] != '=' && encoded_string[pos+3] != '.') { + if ( ( pos + 3 < length_of_string ) && + encoded_string[pos+3] != '=' && + encoded_string[pos+3] != '.' + ) + { + // + // Emit a chunk's third byte (which might not be produced in the last chunk). + // ret.push_back(static_cast( ( (pos_of_char_2 & 0x03 ) << 6 ) + pos_of_char(encoded_string[pos+3]) )); } } diff --git a/test.cpp b/test.cpp index 4601de3..18b0afa 100644 --- a/test.cpp +++ b/test.cpp @@ -175,6 +175,47 @@ int main() { all_tests_passed = false; } + // ---------------------------------------------- + + std::string unpadded_input = "YWJjZGVmZw"; // Note the 'missing' "==" + std::string unpadded_decoded = base64_decode(unpadded_input); + if (unpadded_decoded != "abcdefg") { + std::cout << "Failed to decode unpadded input " << unpadded_input << std::endl; + all_tests_passed = false; + } + + unpadded_input = "YWJjZGU"; // Note the 'missing' "=" + unpadded_decoded = base64_decode(unpadded_input); + if (unpadded_decoded != "abcde") { + std::cout << "Failed to decode unpadded input " << unpadded_input << std::endl; + std::cout << unpadded_decoded << std::endl; + all_tests_passed = false; + } + + unpadded_input = ""; + unpadded_decoded = base64_decode(unpadded_input); + if (unpadded_decoded != "") { + std::cout << "Failed to decode unpadded input " << unpadded_input << std::endl; + std::cout << unpadded_decoded << std::endl; + all_tests_passed = false; + } + + unpadded_input = "YQ"; + unpadded_decoded = base64_decode(unpadded_input); + if (unpadded_decoded != "a") { + std::cout << "Failed to decode unpadded input " << unpadded_input << std::endl; + std::cout << unpadded_decoded << std::endl; + all_tests_passed = false; + } + + unpadded_input = "YWI"; + unpadded_decoded = base64_decode(unpadded_input); + if (unpadded_decoded != "ab") { + std::cout << "Failed to decode unpadded input " << unpadded_input << std::endl; + std::cout << unpadded_decoded << std::endl; + all_tests_passed = false; + } + // -------------------------------------------------------------- #if __cplusplus >= 201703L