From f6b1ee9a3145d625965a5fb9975030767cb41907 Mon Sep 17 00:00:00 2001 From: cn-kali-team Date: Mon, 4 Nov 2024 23:59:11 +0800 Subject: [PATCH] fix keepalive --- Cargo.toml | 2 +- src/client.rs | 33 +++++++++++++++++++++++++++++++++ src/response.rs | 44 ++------------------------------------------ 3 files changed, 36 insertions(+), 43 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index af54e06..79857f0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "slinger" #改这个 -version = "0.1.13" +version = "0.1.14" edition = "2021" description = "An HTTP Client for Rust designed for hackers." homepage = "https://github.com/emo-crab/slinger" diff --git a/src/client.rs b/src/client.rs index 91e87f7..bbd5a98 100644 --- a/src/client.rs +++ b/src/client.rs @@ -281,6 +281,8 @@ impl Client { let port = u.port_u16().unwrap_or_default(); format!("{}{}{}", scheme, host, port) }; + // 先默认尝试复用连接 + let mut keepalive = true; loop { let mut record = HTTPRecord::default(); for (k, v) in self.inner.header.iter() { @@ -298,6 +300,17 @@ impl Client { } } } + // 配置了keepalive和服务器支持复用才设置请求头 + if self.inner.keepalive && keepalive { + request.headers_mut().insert( + http::header::CONNECTION, + HeaderValue::from_static("keep-alive"), + ); + } else { + request + .headers_mut() + .insert(http::header::CONNECTION, HeaderValue::from_static("close")); + } record.record_request(&request); let socket = conn .entry(uniq_key(&cur_uri)) @@ -318,13 +331,20 @@ impl Client { *s = self.inner.connector.connect_with_uri(&cur_uri)?; } } + if !self.inner.keepalive { + conn.remove(&uniq_key(&cur_uri)); + } else { + keepalive = true; + } } _ => { conn.remove(&uniq_key(&cur_uri)); + keepalive = false; } } } else { conn.remove(&uniq_key(&cur_uri)); + keepalive = false; } // 原始请求不跳转 if request.raw_request().is_some() { @@ -522,6 +542,7 @@ impl ClientBuilder { let connector = connector_builder.build()?; Ok(Client { inner: ClientRef { + keepalive: config.keepalive, timeout: config.timeout, #[cfg(feature = "cookie")] cookie_store: config.cookie_store, @@ -650,6 +671,14 @@ impl ClientBuilder { self.config.nodelay = enabled; self } + // HTTP keepalive options + + /// + /// Default is `false`. + pub fn keepalive(mut self, keepalive: bool) -> ClientBuilder { + self.config.keepalive = keepalive; + self + } #[cfg(feature = "tls")] // TLS options /// Add a custom root certificate. @@ -813,6 +842,7 @@ struct Config { referer: bool, proxy: Option, nodelay: bool, + keepalive: bool, #[cfg(feature = "tls")] root_certs: Vec, #[cfg(feature = "tls")] @@ -839,6 +869,7 @@ impl Debug for Config { .field("proxy", &self.proxy) .field("timeout", &self.timeout) .field("nodelay", &self.nodelay) + .field("keepalive", &self.keepalive) .field("hostname_verification", &self.hostname_verification) .field("certs_verification", &self.certs_verification) .field("redirect_policy", &self.redirect_policy) @@ -855,6 +886,7 @@ impl Default for Config { referer: false, proxy: None, nodelay: false, + keepalive: false, #[cfg(feature = "tls")] root_certs: vec![], #[cfg(feature = "tls")] @@ -876,6 +908,7 @@ impl Default for Config { #[derive(Clone, Debug)] struct ClientRef { + keepalive: bool, timeout: Option, #[cfg(feature = "cookie")] cookie_store: Option>, diff --git a/src/response.rs b/src/response.rs index dfc3275..c03017d 100644 --- a/src/response.rs +++ b/src/response.rs @@ -429,50 +429,10 @@ impl ResponseBuilder { config, } } - fn read_lines(&mut self) -> Option> { - let mut lines = Vec::new(); - let mut buffer = vec![0; 1]; // 定义一个缓冲区 - let mut total_bytes_read = 0; - let mut start = Instant::now(); - let timeout = self.config.timeout; - loop { - match self.reader.read(&mut buffer) { - Ok(0) => break, - Ok(n) => { - lines.extend_from_slice(&buffer[..n]); - total_bytes_read += n; - // 当有读取到数据的时候重置计时器 - start = Instant::now(); - if buffer[0] == b'\n' { - break; - } - } - Err(ref err) if err.kind() == std::io::ErrorKind::WouldBlock => { - // 如果没有数据可读,但超时尚未到达,可以在这里等待或重试 - // 当已经有数据了或者触发超时就跳出循环,防止防火墙一直把会话挂着不释放 - if total_bytes_read > 0 { - break; - } else if let Some(to) = timeout { - if start.elapsed() > to { - break; - } - } - std::thread::sleep(Duration::from_micros(100)); - } - Err(_err) => break, - } - // 检查是否读取到了全部数据,如果是,则退出循环 - if let Some(limit) = self.config.max_read { - if total_bytes_read >= limit as usize { - break; - } - } - } - Some(lines) - } fn parser_version(&mut self) -> Result<(http::Version, http::StatusCode)> { let (mut vf, mut sf) = (false, false); - if let Some(lines) = self.read_lines() { + let mut lines = Vec::new(); + if let Ok(_length) = self.reader.read_until(b'\n', &mut lines) { let mut version = http::Version::default(); let mut sc = http::StatusCode::default(); for (index, vc) in lines.splitn(3, |b| b == &b' ').enumerate() {