diff --git a/lib/file-source/src/fingerprinter.rs b/lib/file-source/src/fingerprinter.rs index 84bb2efe04d7d..69f5e690e22a1 100644 --- a/lib/file-source/src/fingerprinter.rs +++ b/lib/file-source/src/fingerprinter.rs @@ -77,6 +77,7 @@ enum SupportedCompressionAlgorithms { impl SupportedCompressionAlgorithms { fn values() -> Vec { + // Enumerate these from smallest magic_header_bytes to largest vec![SupportedCompressionAlgorithms::GZIP] } @@ -88,39 +89,41 @@ impl SupportedCompressionAlgorithms { } trait UncompressedReader { - fn check(fp: &mut File) -> Option; - fn reader<'a>(fp: &'a mut File) -> Box; + fn check(fp: &mut File) -> Result, std::io::Error>; + fn reader<'a>(fp: &'a mut File) -> Result, std::io::Error>; } struct UncompressedReaderImpl; impl UncompressedReader for UncompressedReaderImpl { - fn check(fp: &mut File) -> Option { - SupportedCompressionAlgorithms::values() - .iter() - .find_map(|compression_algorithm| { - // magic headers for algorithms can be of different lengths, and using a buffer too long could exceed the length of the file - // so instantiate and check the various sizes independently - let magic_header_bytes = compression_algorithm.magic_header_bytes(); - let mut magic = vec![0u8; magic_header_bytes.len()]; - - if fp.read_exact(&mut magic).is_ok() - && fp.seek(SeekFrom::Start(0)).is_ok() - && magic == magic_header_bytes - { - Some(*compression_algorithm) - } else { - None - } - }) + fn check(fp: &mut File) -> Result, std::io::Error> { + for compression_algorithm in SupportedCompressionAlgorithms::values() { + // magic headers for algorithms can be of different lengths, and using a buffer too long could exceed the length of the file + // so instantiate and check the various sizes in monotonically increasing order + let magic_header_bytes = compression_algorithm.magic_header_bytes(); + let mut magic = vec![0u8; magic_header_bytes.len()]; + + let result = fp.read_exact(&mut magic); + let reset = fp.seek(SeekFrom::Start(0)); + if reset.is_err() { + return Err(reset.unwrap_err()); + } else if result.is_err() { + return Err(result.unwrap_err()); + } else if magic == magic_header_bytes { + return Ok(Some(compression_algorithm)); + } else { + continue; + } + } + return Ok(None); } - fn reader<'a>(fp: &'a mut File) -> Box { + fn reader<'a>(fp: &'a mut File) -> Result, std::io::Error> { // To support new compression algorithms, add them below - match Self::check(fp) { + match Self::check(fp)? { Some(SupportedCompressionAlgorithms::GZIP) => { - Box::new(BufReader::new(GzDecoder::new(BufReader::new(fp)))) + Ok(Box::new(BufReader::new(GzDecoder::new(BufReader::new(fp))))) } // No compression, or read the raw bytes - None => Box::new(BufReader::new(fp)), + None => Ok(Box::new(BufReader::new(fp))), } } } @@ -164,7 +167,7 @@ impl Fingerprinter { } => { buffer.resize(self.max_line_length, 0u8); let mut fp = fs::File::open(path)?; - let mut reader = UncompressedReaderImpl::reader(&mut fp); + let mut reader = UncompressedReaderImpl::reader(&mut fp)?; skip_first_n_bytes(&mut reader, ignored_header_bytes)?; let bytes_read = fingerprinter_read_until(reader, b'\n', lines, buffer)?; @@ -728,6 +731,22 @@ mod test { .is_none()); } + #[test] + fn test_monotonic_compression_algorithms() { + // This test is necessary to handle an edge case where when assessing the magic header + // bytes of a file to determine the compression algorithm, it's possible that the length of + // the file is smaller than the size of the magic header bytes it's being assessed against. + // While this could be an indication that the file is simply too small, it could also + // just be that the compression header is a smaller one than the assessed algorithm. + // Checking this with a guarantee on the monotonically increasing order assures that this edge case doesn't happen. + let algos = super::SupportedCompressionAlgorithms::values(); + let mut smallest_byte_length = 0; + for algo in algos { + let magic_header_bytes = algo.magic_header_bytes(); + assert!(smallest_byte_length <= magic_header_bytes.len()); + smallest_byte_length = magic_header_bytes.len(); + } + } #[derive(Clone)] struct NoErrors;