diff --git a/Cargo.lock b/Cargo.lock index b1c5a8d..740a173 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,6 +2,12 @@ # It is not intended for manual editing. version = 3 +[[package]] +name = "adler" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" + [[package]] name = "aho-corasick" version = "0.7.18" @@ -417,6 +423,18 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37ab347416e802de484e4d03c7316c48f1ecb56574dfd4a46a80f173ce1de04d" +[[package]] +name = "flate2" +version = "1.0.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e6988e897c1c9c485f43b47a529cef42fde0547f9d8d41a7062518f1d8fc53f" +dependencies = [ + "cfg-if", + "crc32fast", + "libc", + "miniz_oxide", +] + [[package]] name = "fnv" version = "1.0.7" @@ -757,6 +775,7 @@ dependencies = [ "base64", "candid", "clap", + "flate2", "garcon", "hex", "hyper", @@ -974,6 +993,16 @@ version = "0.3.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2a60c7ce501c71e03a9c9c0d35b861413ae925bd979cc7a4e30d060069aaac8d" +[[package]] +name = "miniz_oxide" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a92518e98c078586bc6c934028adcca4c92a53d6a958196de835170a01d84e4b" +dependencies = [ + "adler", + "autocfg", +] + [[package]] name = "mio" version = "0.7.14" diff --git a/Cargo.toml b/Cargo.toml index 3d0dab1..f335dd5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,6 +21,7 @@ anyhow = "1.0.34" base64 = "0.13" candid = { version = "0.7.11", features = ["mute_warnings"] } clap = { version = "3", features = ["cargo", "derive"] } +flate2 = "1.0.0" garcon = { version = "0.2.3", features = ["async"] } hex = "0.4.3" hyper = { version = "0.14.13", features = ["full"] } diff --git a/src/main.rs b/src/main.rs index 3e399b0..0ff6f91 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,5 +1,6 @@ use crate::config::dns_canister_config::DnsCanisterConfig; use clap::{crate_authors, crate_version, AppSettings, Parser}; +use flate2::read::{DeflateDecoder, GzDecoder}; use hyper::{ body, body::Bytes, @@ -24,6 +25,7 @@ use ic_utils::{ use lazy_regex::regex_captures; use sha2::{Digest, Sha256}; use slog::Drain; +use std::io::prelude::Read; use std::{ convert::Infallible, error::Error, @@ -45,6 +47,9 @@ static MAX_HTTP_REQUEST_STREAM_CALLBACK_CALL_COUNT: i32 = 1000; // The maximum length of a body we should log as tracing. static MAX_LOG_BODY_SIZE: usize = 100; +// The limit of a buffer we should decompress ~10mb. +static MAX_BYTES_SIZE_TO_DECOMPRESS: u64 = 10_000_000; + #[derive(Parser)] #[clap( version = crate_version!(), @@ -169,6 +174,111 @@ fn resolve_canister_id( None } +fn decode_hash_tree( + name: &str, + value: Option, + logger: &slog::Logger, +) -> Result, ()> { + match value { + Some(tree) => base64::decode(tree).map_err(|e| { + slog::warn!(logger, "Unable to decode {} from base64: {}", name, e); + }), + _ => Err(()), + } +} + +struct HeadersData { + certificate: Option, ()>>, + tree: Option, ()>>, + chunk_tree: Option>, + chunk_index: String, + encoding: Option, +} + +fn extract_headers_data(headers: &[HeaderField], logger: &slog::Logger) -> HeadersData { + let mut headers_data = HeadersData { + certificate: None, + tree: None, + chunk_tree: None, + chunk_index: String::from("0"), + encoding: None, + }; + + for HeaderField(name, value) in headers { + if name.eq_ignore_ascii_case("IC-CERTIFICATE") { + for field in value.split(',') { + if let Some((_, name, b64_value)) = regex_captures!("^(.*)=:(.*):$", field.trim()) { + slog::trace!(logger, ">> certificate {}: {}", name, b64_value); + if name == "chunk_index" { + headers_data.chunk_index = b64_value.to_string(); + continue; + } + let bytes = decode_hash_tree(name, Some(b64_value.to_string()), logger); + if name == "certificate" { + headers_data.certificate = Some(match (headers_data.certificate, bytes) { + (None, bytes) => bytes, + (Some(Ok(certificate)), Ok(bytes)) => { + slog::warn!(logger, "duplicate certificate field: {:?}", bytes); + Ok(certificate) + } + (Some(Ok(certificate)), Err(_)) => { + slog::warn!( + logger, + "duplicate certificate field (failed to decode)" + ); + Ok(certificate) + } + (Some(Err(_)), bytes) => { + slog::warn!( + logger, + "duplicate certificate field (failed to decode)" + ); + bytes + } + }); + } else if name == "tree" { + headers_data.tree = Some(match (headers_data.tree, bytes) { + (None, bytes) => bytes, + (Some(Ok(tree)), Ok(bytes)) => { + slog::warn!(logger, "duplicate tree field: {:?}", bytes); + Ok(tree) + } + (Some(Ok(tree)), Err(_)) => { + slog::warn!(logger, "duplicate tree field (failed to decode)"); + Ok(tree) + } + (Some(Err(_)), bytes) => { + slog::warn!(logger, "duplicate tree field (failed to decode)"); + bytes + } + }); + } else if name == "chunk_tree" { + headers_data.chunk_tree = match (headers_data.chunk_tree, bytes) { + (None, bytes) => bytes.ok(), + (Some(chunk_tree), Ok(bytes)) => { + slog::warn!(logger, "duplicate chunk_tree field: {:?}", bytes); + Some(chunk_tree) + } + (Some(chunk_tree), Err(_)) => { + slog::warn!( + logger, + "duplicate chunk_tree field (failed to decode)" + ); + Some(chunk_tree) + } + }; + } + } + } + } else if name.eq_ignore_ascii_case("CONTENT-ENCODING") { + let enc = value.trim().to_string(); + headers_data.encoding = Some(enc); + } + } + + headers_data +} + async fn forward_request( request: Request, agent: Arc, @@ -280,76 +390,39 @@ async fn forward_request( http_response }; - let mut certificate: Option, ()>> = None; - let mut tree: Option, ()>> = None; - let mut builder = Response::builder().status(StatusCode::from_u16(http_response.status_code)?); - for HeaderField(name, value) in http_response.headers { - if name.eq_ignore_ascii_case("IC-CERTIFICATE") { - for field in value.split(',') { - if let Some((_, name, b64_value)) = regex_captures!("^(.*)=:(.*):$", field.trim()) { - slog::trace!(logger, ">> certificate {}: {}", name, b64_value); - let bytes = base64::decode(b64_value).map_err(|e| { - slog::warn!( - logger, - "Unable to decode {} in ic-certificate from base64: {}", - name, - e - ); - }); - if name == "certificate" { - certificate = Some(match (certificate, bytes) { - (None, bytes) => bytes, - (Some(Ok(certificate)), Ok(bytes)) => { - slog::warn!(logger, "duplicate certificate field: {:?}", bytes); - Ok(certificate) - } - (Some(Ok(certificate)), Err(_)) => { - slog::warn!( - logger, - "duplicate certificate field (failed to decode)" - ); - Ok(certificate) - } - (Some(Err(_)), bytes) => { - slog::warn!( - logger, - "duplicate certificate field (failed to decode)" - ); - bytes - } - }); - } else if name == "tree" { - tree = Some(match (tree, bytes) { - (None, bytes) => bytes, - (Some(Ok(tree)), Ok(bytes)) => { - slog::warn!(logger, "duplicate tree field: {:?}", bytes); - Ok(tree) - } - (Some(Ok(tree)), Err(_)) => { - slog::warn!(logger, "duplicate tree field (failed to decode)"); - Ok(tree) - } - (Some(Err(_)), bytes) => { - slog::warn!(logger, "duplicate tree field (failed to decode)"); - bytes - } - }); - } - } - } - } - - builder = builder.header(&name, value); + for HeaderField(name, value) in &http_response.headers { + builder = builder.header(name, value); } + let headers_data = extract_headers_data(&http_response.headers, &logger); let body = if logger.is_trace_enabled() { Some(http_response.body.clone()) } else { None }; - let is_streaming = http_response.streaming_strategy.is_some(); - let response = if let Some(streaming_strategy) = http_response.streaming_strategy { + + // No need to stream when get 206 HTTP partial response + let is_streaming = + http_response.streaming_strategy.is_some() && http_response.status_code != 206; + let body_valid = validate( + &headers_data, + &canister_id, + &agent, + &uri, + &http_response.body, + is_streaming, + logger.clone(), + ); + if body_valid.is_err() { + return Ok(Response::builder() + .status(StatusCode::INTERNAL_SERVER_ERROR) + .body(body_valid.unwrap_err().into()) + .unwrap()); + } + + let response = if is_streaming { + let streaming_strategy = http_response.streaming_strategy.unwrap(); let (mut sender, body) = body::Body::channel(); let agent = agent.as_ref().clone(); sender.send_data(Bytes::from(http_response.body)).await?; @@ -359,6 +432,7 @@ async fn forward_request( let streaming_canister_id_id = callback.callback.principal; let method_name = callback.callback.method; let mut callback_token = callback.token; + let chunk_index = callback_token.index.0.to_str_radix(10); let logger = logger.clone(); tokio::spawn(async move { let canister = HttpRequestCanister::create(&agent, streaming_canister_id_id); @@ -376,8 +450,31 @@ async fn forward_request( .call() .await { - Ok((StreamingCallbackHttpResponse { body, token },)) => { - if sender.send_data(Bytes::from(body)).await.is_err() { + Ok((StreamingCallbackHttpResponse { + body, + token, + chunk_tree, + },)) => { + let chunk_headers_data = HeadersData { + certificate: headers_data.certificate.clone(), + tree: headers_data.tree.clone(), + encoding: headers_data.encoding.clone(), + chunk_tree, + chunk_index: chunk_index.clone(), + }; + let body_valid = validate( + &chunk_headers_data, + &canister_id, + &agent, + &uri, + &body, + is_streaming, + logger.clone(), + ); + + if body_valid.is_err() + || sender.send_data(Bytes::from(body)).await.is_err() + { sender.abort(); break; } @@ -400,35 +497,6 @@ async fn forward_request( builder.body(body)? } else { - let body_valid = match (certificate, tree) { - (Some(Ok(certificate)), Some(Ok(tree))) => match validate_body( - &certificate, - &tree, - &canister_id, - &agent, - &uri, - &http_response.body, - logger.clone(), - ) { - Ok(valid) => valid, - Err(e) => { - return Ok(Response::builder() - .status(StatusCode::INTERNAL_SERVER_ERROR) - .body(format!("Certificate validation failed: {}", e).into()) - .unwrap()); - } - }, - (Some(_), _) | (_, Some(_)) => false, - // Canisters don't have to provide certified variables - (None, None) => true, - }; - - if !body_valid && !cfg!(feature = "skip_body_verification") { - return Ok(Response::builder() - .status(StatusCode::INTERNAL_SERVER_ERROR) - .body("Body does not pass verification".into()) - .unwrap()); - } builder.body(http_response.body.into())? }; @@ -467,18 +535,104 @@ async fn forward_request( Ok(response) } -fn validate_body( - certificate: &[u8], - tree: &[u8], +fn validate( + headers_data: &HeadersData, canister_id: &Principal, agent: &Agent, uri: &Uri, response_body: &[u8], + is_streaming: bool, + logger: slog::Logger, +) -> Result<(), String> { + let body_sha = decode_body(response_body, headers_data.encoding.clone()); + let body_valid = match (headers_data.certificate.clone(), headers_data.tree.clone()) { + (Some(Ok(certificate)), Some(Ok(tree))) => match validate_body( + Certificates { + certificate, + tree, + chunk_tree: headers_data.chunk_tree.clone(), + chunk_index: headers_data.chunk_index.clone(), + }, + canister_id, + agent, + uri, + &body_sha, + logger.clone(), + ) { + Ok(valid) => { + if valid { + Ok(()) + } else { + Err("Body does not pass verification".to_string()) + } + } + Err(e) => Err(format!("Certificate validation failed: {}", e)), + }, + (Some(_), _) | (_, Some(_)) => Err("Body does not pass verification".to_string()), + // Canisters don't have to provide certified variables + (None, None) => Ok(()), + }; + + if body_valid.is_err() && !cfg!(feature = "skip_body_verification") { + match (is_streaming, headers_data.chunk_tree.is_some()) { + (true, false) => {} // backward compatibility. Headers could not contain chunk_tree witness for streaming + _ => { + return body_valid; + } + } + } + + Ok(()) +} + +fn decode_body(body: &[u8], encoding: Option) -> [u8; 32] { + let mut sha256 = Sha256::new(); + match encoding { + Some(enc) => match enc.as_str() { + "gzip" => { + let decoded: &mut Vec = &mut vec![]; + let decoder = GzDecoder::new(body); + decoder + .take(MAX_BYTES_SIZE_TO_DECOMPRESS) + .read_to_end(decoded) + .unwrap(); + sha256.update(decoded); + } + "deflate" => { + let decoded: &mut Vec = &mut vec![]; + let decoder = DeflateDecoder::new(body); + decoder + .take(MAX_BYTES_SIZE_TO_DECOMPRESS) + .read_to_end(decoded) + .unwrap(); + sha256.update(decoded); + } + _ => sha256.update(body), + }, + _ => sha256.update(body), + }; + sha256.finalize().into() +} + +struct Certificates { + certificate: Vec, + tree: Vec, + chunk_tree: Option>, + chunk_index: String, +} + +fn validate_body( + certificates: Certificates, + canister_id: &Principal, + agent: &Agent, + uri: &Uri, + body_sha: &[u8; 32], logger: slog::Logger, ) -> anyhow::Result { let cert: Certificate = - serde_cbor::from_slice(certificate).map_err(AgentError::InvalidCborData)?; - let tree: HashTree = serde_cbor::from_slice(tree).map_err(AgentError::InvalidCborData)?; + serde_cbor::from_slice(&certificates.certificate).map_err(AgentError::InvalidCborData)?; + let tree: HashTree = + serde_cbor::from_slice(&certificates.tree).map_err(AgentError::InvalidCborData)?; if let Err(e) = agent.verify(&cert) { slog::trace!(logger, ">> certificate failed verification: {}", e); @@ -530,11 +684,37 @@ fn validate_body( }, }; - let mut sha256 = Sha256::new(); - sha256.update(response_body); - let body_sha = sha256.finalize(); + if let Some(tree) = certificates.chunk_tree { + let chunk_tree: HashTree = + serde_cbor::from_slice(&tree).map_err(AgentError::InvalidCborData)?; + + let chunk_tree_digest = chunk_tree.digest(); - Ok(&body_sha[..] == tree_sha) + if chunk_tree_digest != tree_sha { + slog::trace!( + logger, + ">> Invalid chunk_tree in the header. Digest does not equal tree_sha", + ); + return Ok(false); + } + + let index_path = [certificates.chunk_index.into()]; + let chunk_sha = match chunk_tree.lookup_path(&index_path) { + LookupResult::Found(v) => v, + _ => { + slog::trace!( + logger, + ">> Invalid Tree in the header. Does not contain path {:?}", + path + ); + return Ok(false); + } + }; + + Ok(body_sha == chunk_sha) + } else { + Ok(body_sha == tree_sha) + } } fn is_hop_header(name: &str) -> bool {