Skip to content

Commit 3cdc5f5

Browse files
committed
TcpStreamTask
1 parent f4935f8 commit 3cdc5f5

File tree

1 file changed

+297
-0
lines changed

1 file changed

+297
-0
lines changed

src/tcp/sys.rs

+297
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,297 @@
1+
use std::io;
2+
use std::io::Error;
3+
use std::pin::Pin;
4+
use std::time::Duration;
5+
6+
use bytes::{Buf, BytesMut};
7+
use pnet_packet::ip::IpNextHeaderProtocols;
8+
use tokio::sync::mpsc::error::TryRecvError;
9+
use tokio::sync::mpsc::{Receiver, Sender};
10+
use tokio::time::Instant;
11+
12+
use crate::ip_stack::{IpStack, NetworkTuple, TransportPacket};
13+
use crate::tcp::tcb::Tcb;
14+
15+
/// Per-connection worker state: owns the TCP control block and the channels
/// that bridge the IP stack (raw segments) and the application layer (bytes).
/// One task instance drives exactly one TCP connection.
#[derive(Debug)]
pub struct TcpStreamTask {
    // Per-connection TCP state (sequencing, buffering, timers) — see `Tcb`.
    tcb: Tcb,
    // Handle to the owning stack; used to send packets and to unregister
    // this connection's socket entry on drop.
    ip_stack: IpStack,
    // Bytes the application wants transmitted on this connection.
    application_layer_receiver: Receiver<BytesMut>,
    // Outgoing data that did not fit in the send window; retried by `flush`.
    last_buffer: Option<BytesMut>,
    // Inbound segments for this connection, demultiplexed by the IP stack.
    packet_receiver: Receiver<TransportPacket>,
    // Received, in-order payload bytes destined for the application.
    application_layer_sender: Sender<BytesMut>,
    // Copied from `ip_stack.config.retransmission_timeout` in `new`.
    // NOTE(review): not referenced anywhere in this file's visible code —
    // confirm it is used elsewhere or remove.
    timeout: Duration,
    // True once EOF (an empty buffer) has been delivered to the application.
    read_half_closed: bool,
    // True once our FIN has been queued; no further outgoing writes.
    write_half_closed: bool,
    // True while the previous loop iteration had to retransmit; suppresses
    // flushing new data until the retransmission pressure clears.
    retransmission: bool,
}
28+
29+
impl Drop for TcpStreamTask {
30+
fn drop(&mut self) {
31+
let peer_addr = self.tcb.peer_addr();
32+
let local_addr = self.tcb.local_addr();
33+
let network_tuple = NetworkTuple::new(peer_addr, local_addr, IpNextHeaderProtocols::Tcp);
34+
self.ip_stack.remove_tcp_socket(&network_tuple);
35+
}
36+
}
37+
38+
impl TcpStreamTask {
39+
pub fn new(
40+
tcb: Tcb,
41+
ip_stack: IpStack,
42+
application_layer_sender: Sender<BytesMut>,
43+
application_layer_receiver: Receiver<BytesMut>,
44+
packet_receiver: Receiver<TransportPacket>,
45+
) -> Self {
46+
let timeout = ip_stack.config.retransmission_timeout;
47+
Self {
48+
tcb,
49+
ip_stack,
50+
application_layer_receiver,
51+
last_buffer: None,
52+
packet_receiver,
53+
application_layer_sender,
54+
timeout,
55+
read_half_closed: false,
56+
write_half_closed: false,
57+
retransmission: false,
58+
}
59+
}
60+
}
61+
62+
impl TcpStreamTask {
    /// Drives the connection to completion, then makes a final attempt to
    /// hand any remaining readable bytes (and the EOF marker) to the
    /// application layer — even when the loop exited with an error.
    pub async fn run(&mut self) -> io::Result<()> {
        let result = self.run0().await;
        self.push_application_layer().await;
        result
    }
    /// Main event loop: each iteration waits for one event (inbound segment,
    /// outbound application data, or a timer), updates the TCB, and emits
    /// whatever packets the state machine requires.
    pub async fn run0(&mut self) -> io::Result<()> {
        loop {
            if self.tcb.is_close() {
                return Ok(());
            }
            if self.tcb.decelerate() {
                // todo Need for more efficient flow control
                // Crude backpressure: cooperatively yield for up to ~100µs
                // (capped at 1000 yields) instead of sleeping, so other
                // tasks can drain before we continue.
                let target_duration = Duration::from_micros(100);
                let start = Instant::now();
                for _ in 0..1000 {
                    if start.elapsed() >= target_duration {
                        break;
                    }
                    tokio::task::yield_now().await;
                }
            }
            // Timer to arm for this iteration: the TIME-WAIT expiry takes
            // precedence; otherwise the write (retransmission) timeout, if any.
            let deadline = if let Some(v) = self.tcb.time_wait() {
                Some(v.into())
            } else {
                self.tcb.write_timeout().map(|v| v.into())
            };

            // While retransmitting, blocked on a full window, half-closed for
            // writing, or rate-limited (`only_recv_in`), we stop pulling new
            // application data and listen to the network only.
            let data = if let Some(deadline) = deadline {
                if self.only_recv_in() {
                    self.recv_in_timeout(deadline).await
                } else {
                    self.recv_timeout(deadline).await
                }
            } else {
                if self.only_recv_in() {
                    self.recv_in().await
                } else {
                    self.recv().await
                }
            };
            // Retry previously unsent data before dispatching the new event,
            // unless the write side is closed or we are mid-retransmission.
            if !self.write_half_closed && !self.retransmission {
                self.flush().await?;
            }
            match data {
                // Inbound segment: feed it to the TCB; send any immediate
                // reply it produces (e.g. an ACK), then surface newly
                // readable payload to the application.
                TaskRecvData::In(buf) => {
                    if let Some(reply_packet) = self.tcb.push_packet(buf) {
                        self.ip_stack.send_packet(reply_packet).await?;
                    }
                    self.push_application_layer().await;
                }
                // Outbound application bytes: transmit what the window
                // allows; the remainder is parked in `last_buffer`.
                TaskRecvData::Out(buf) => {
                    self.write(buf).await?;
                }
                // The stack-side channel vanished: treat as network failure.
                TaskRecvData::InClose => return Err(Error::new(io::ErrorKind::Other, "NetworkDown")),
                // Application dropped its writer: begin an orderly close.
                // `only_recv_in` guarantees we only polled this channel while
                // `last_buffer` was empty, hence the assert.
                TaskRecvData::OutClose => {
                    assert!(self.last_buffer.is_none());
                    self.write_half_closed = true;
                    let packet = self.tcb.fin_packet();
                    self.ip_stack.send_packet(packet).await?;
                    self.tcb.sent_fin();
                }
                // A timer fired: let the TCB advance its state; it may now be
                // fully closed, or require (re)sending our FIN.
                TaskRecvData::Timeout => {
                    self.tcb.timeout();
                    if self.tcb.is_close() {
                        return Ok(());
                    }
                    if self.tcb.cannot_write() {
                        let packet = self.tcb.fin_packet();
                        self.ip_stack.send_packet(packet).await?;
                    }
                }
            }
            // Retransmission takes priority over ACKing; the flag also gates
            // the next iteration's event sources (see `only_recv_in`).
            if self.try_retransmission().await? {
                self.retransmission = true;
            } else {
                self.retransmission = false;
                self.try_send_ack().await?;
            }
            // Peer finished sending: deliver EOF to the application once.
            if !self.read_half_closed && self.tcb.cannot_read() {
                self.close_read().await;
            }
        }
    }
    /// True when the loop must not accept new application data this
    /// iteration: retransmitting, unsent data pending, write half closed,
    /// or the TCB's rate/window limit is in effect.
    fn only_recv_in(&self) -> bool {
        self.retransmission || self.last_buffer.is_some() || self.write_half_closed || self.tcb.limit()
    }
    /// Moves readable payload out of the TCB into the application channel.
    /// After the read side is closed, drains the TCB instead (`read_none`)
    /// so its receive buffer does not pin memory.
    async fn push_application_layer(&mut self) {
        if self.read_half_closed {
            self.tcb.read_none();
        } else {
            let len = self.tcb.readable();
            if len > 0 {
                let mut buffer = BytesMut::zeroed(len);
                // `read` may return fewer bytes than `readable` reported;
                // trim the buffer to what was actually copied.
                let len = self.tcb.read(&mut buffer);
                buffer.truncate(len);
                if !buffer.is_empty() {
                    match self.application_layer_sender.send(buffer).await {
                        Ok(_) => {}
                        Err(_e) => {
                            // Ignore the closure of reading
                            self.read_half_closed = true;
                        }
                    }
                }
            }
            if self.tcb.cannot_read() {
                self.close_read().await;
            }
        }
    }
    /// Signals EOF to the application (by convention, an empty `BytesMut`)
    /// and marks the read half closed. The send result is ignored: the
    /// receiver may already be gone.
    async fn close_read(&mut self) {
        _ = self.application_layer_sender.send(BytesMut::new()).await;
        self.read_half_closed = true;
    }
    /// Writes as much of `buf` as the TCB accepts, sending one packet per
    /// accepted chunk. Returns the number of bytes consumed (possibly 0 when
    /// the send window is full). Free function over `&mut Tcb` so `flush`
    /// can call it while `self.last_buffer` is mutably borrowed.
    async fn write_slice0(tcb: &mut Tcb, ip_stack: &IpStack, mut buf: &[u8]) -> io::Result<usize> {
        let len = buf.len();
        while !buf.is_empty() {
            if let Some((packet, len)) = tcb.write(&buf) {
                // A zero-length acceptance means the window is full — stop.
                if len == 0 {
                    break;
                }
                ip_stack.send_packet(packet).await?;
                buf = &buf[len..];
            } else {
                break;
            }
        }
        // Consumed = original length minus what remains.
        Ok(len - buf.len())
    }
    /// Convenience wrapper over `write_slice0` using this task's TCB/stack.
    async fn write_slice(&mut self, buf: &[u8]) -> io::Result<usize> {
        Self::write_slice0(&mut self.tcb, &self.ip_stack, buf).await
    }
    /// Sends application data; anything the window rejects is parked in
    /// `last_buffer` for a later `flush`. Returns bytes sent now.
    async fn write(&mut self, mut buf: BytesMut) -> io::Result<usize> {
        let len = self.write_slice(&buf).await?;
        if len != buf.len() {
            // Buffer is full
            buf.advance(len);
            self.last_buffer.replace(buf);
        }
        Ok(len)
    }
    /// Retries the parked `last_buffer`; clears it when fully sent, otherwise
    /// advances past the bytes that did go out.
    async fn flush(&mut self) -> io::Result<()> {
        if let Some(buf) = self.last_buffer.as_mut() {
            let len = Self::write_slice0(&mut self.tcb, &self.ip_stack, buf).await?;
            if buf.len() == len {
                self.last_buffer.take();
            } else {
                buf.advance(len);
            }
        }
        Ok(())
    }

    /// Sends at most one retransmission. Returns `Ok(true)` when a packet was
    /// resent (caller then sets the `retransmission` flag). Checks an
    /// immediately-available retransmission first, then — only if packets are
    /// in flight — asks the TCB whether its timer demands one.
    async fn try_retransmission(&mut self) -> io::Result<bool> {
        if self.write_half_closed {
            return Ok(false);
        }
        if let Some(v) = self.tcb.retransmission() {
            self.ip_stack.send_packet(v).await?;
            return Ok(true);
        }
        if self.tcb.no_inflight_packet() {
            return Ok(false);
        }
        if self.tcb.need_retransmission() {
            if let Some(v) = self.tcb.retransmission() {
                self.ip_stack.send_packet(v).await?;
                return Ok(true);
            }
        }
        Ok(false)
    }
    /// Emits a bare ACK when the TCB reports one is owed; `set_ack` records
    /// that the acknowledgement is being sent before building the packet.
    async fn try_send_ack(&mut self) -> io::Result<()> {
        if self.tcb.need_ack() {
            self.tcb.set_ack();
            let packet = self.tcb.ack_packet();
            self.ip_stack.send_packet(packet).await?;
        }
        Ok(())
    }

    /// Waits on both channels and the deadline; a closed channel maps to the
    /// corresponding `*Close` event, the deadline to `Timeout`.
    async fn recv_timeout(&mut self, deadline: Instant) -> TaskRecvData {
        tokio::select! {
            rs=self.packet_receiver.recv()=>{
                rs.map(|v| TaskRecvData::In(v.buf)).unwrap_or(TaskRecvData::InClose)
            }
            rs=self.application_layer_receiver.recv()=>{
                rs.map(|v| TaskRecvData::Out(v)).unwrap_or(TaskRecvData::OutClose)
            }
            _=tokio::time::sleep_until(deadline)=>{
                TaskRecvData::Timeout
            }
        }
    }
    /// Waits on both channels with no deadline.
    async fn recv(&mut self) -> TaskRecvData {
        tokio::select! {
            rs=self.packet_receiver.recv()=>{
                rs.map(|v| TaskRecvData::In(v.buf)).unwrap_or(TaskRecvData::InClose)
            }
            rs=self.application_layer_receiver.recv()=>{
                rs.map(|v| TaskRecvData::Out(v)).unwrap_or(TaskRecvData::OutClose)
            }
        }
    }
    /// Non-blocking poll of the network channel.
    /// NOTE(review): not called anywhere in this file's visible code —
    /// confirm it is used elsewhere or remove.
    fn try_recv_in(&mut self) -> Option<TaskRecvData> {
        match self.packet_receiver.try_recv() {
            Ok(rs) => Some(TaskRecvData::In(rs.buf)),
            Err(e) => match e {
                TryRecvError::Empty => None,
                TryRecvError::Disconnected => Some(TaskRecvData::InClose),
            },
        }
    }
    /// Network-only wait bounded by `deadline`; elapsing maps to `Timeout`.
    async fn recv_in_timeout(&mut self, deadline: Instant) -> TaskRecvData {
        tokio::time::timeout_at(deadline, self.recv_in())
            .await
            .unwrap_or_else(|_| TaskRecvData::Timeout)
    }
    /// Network-only wait with no deadline.
    async fn recv_in(&mut self) -> TaskRecvData {
        let rs = self.packet_receiver.recv().await;
        rs.map(|v| TaskRecvData::In(v.buf)).unwrap_or(TaskRecvData::InClose)
    }
    /// Application-only wait.
    /// NOTE(review): not called anywhere in this file's visible code —
    /// confirm it is used elsewhere or remove.
    async fn recv_out(&mut self) -> TaskRecvData {
        let rs = self.application_layer_receiver.recv().await;
        rs.map(|v| TaskRecvData::Out(v)).unwrap_or(TaskRecvData::OutClose)
    }
}
290+
291+
/// One event consumed per iteration of the `TcpStreamTask` loop.
enum TaskRecvData {
    /// A segment arrived from the network side.
    In(BytesMut),
    /// The application submitted bytes to transmit.
    Out(BytesMut),
    /// The network-side channel disconnected.
    InClose,
    /// The application closed its write channel (triggers FIN).
    OutClose,
    /// The armed deadline (time-wait or write timeout) elapsed first.
    Timeout,
}

0 commit comments

Comments
 (0)