Skip to content

Commit

Permalink
remove all expect call
Browse files Browse the repository at this point in the history
  • Loading branch information
ssrlive authored and cavivie committed Oct 13, 2024
1 parent 1da226b commit 3a7e0b9
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 24 deletions.
2 changes: 1 addition & 1 deletion examples/forward.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ async fn main_exec(opt: Opt) {
if fd >= 0 {
cfg.raw_fd(fd);
} else {
cfg.tun_name("utun8")
cfg.tun_name(&opt.interface)
.address("10.10.10.2")
.destination("10.10.10.1")
.mtu(tun2::DEFAULT_MTU);
Expand Down
2 changes: 1 addition & 1 deletion src/runner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,4 @@ impl<T> Future for BoxFuture<'_, T> {
}
}

pub type Runner = BoxFuture<'static, ()>;
pub type Runner = BoxFuture<'static, std::io::Result<()>>;
8 changes: 3 additions & 5 deletions src/stack.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,10 +119,8 @@ impl StackBuilder {
// ICMP is handled by TCP's Interface.
// smoltcp's interface will always send replies to EchoRequest
if self.enable_icmp && !self.enable_tcp {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
"Enabling icmp requires enabling tcp",
));
use std::io::{Error, ErrorKind::InvalidInput};
return Err(Error::new(InvalidInput, "ICMP requires TCP"));
}
let icmp_tx = if self.enable_icmp {
tcp_tx.clone()
Expand All @@ -133,7 +131,7 @@ impl StackBuilder {
let udp_socket = udp_rx.map(|udp_rx| UdpSocket::new(udp_rx, stack_tx.clone()));

let (tcp_runner, tcp_listener) = if let Some(tcp_rx) = tcp_rx {
let (tcp_runner, tcp_listener) = TcpListener::new(tcp_rx, stack_tx);
let (tcp_runner, tcp_listener) = TcpListener::new(tcp_rx, stack_tx)?;
(Some(tcp_runner), Some(tcp_listener))
} else {
(None, None)
Expand Down
37 changes: 20 additions & 17 deletions src/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,13 @@ impl TcpListenerRunner {
Runner::new(async move {
let notify = Arc::new(Notify::new());
let (socket_tx, socket_rx) = unbounded_channel::<TcpSocketCreation>();
tokio::select! {
_ = Self::handle_packet(notify.clone(), iface_ingress_tx, iface_ingress_tx_avail.clone(), tcp_rx, stream_tx, socket_tx) => {}
_ = Self::handle_socket(notify, device, iface, iface_ingress_tx_avail, sockets, socket_rx) => {}
}
let res = tokio::select! {
v = Self::handle_packet(notify.clone(), iface_ingress_tx, iface_ingress_tx_avail.clone(), tcp_rx, stream_tx, socket_tx) => v,
v = Self::handle_socket(notify, device, iface, iface_ingress_tx_avail, sockets, socket_rx) => v,
};
res?;
trace!("VirtDevice::poll thread exited");
Ok(())
})
}

Expand All @@ -93,7 +95,7 @@ impl TcpListenerRunner {
mut tcp_rx: Receiver<AnyIpPktFrame>,
stream_tx: UnboundedSender<TcpStream>,
socket_tx: UnboundedSender<TcpSocketCreation>,
) {
) -> std::io::Result<()> {
while let Some(frame) = tcp_rx.recv().await {
let packet = match IpPacket::new_checked(frame.as_slice()) {
Ok(p) => p,
Expand All @@ -107,7 +109,7 @@ impl TcpListenerRunner {
if matches!(packet.protocol(), IpProtocol::Icmp | IpProtocol::Icmpv6) {
iface_ingress_tx
.send(frame)
.expect("channel already closed");
.map_err(|e| std::io::Error::new(std::io::ErrorKind::BrokenPipe, e))?;
iface_ingress_tx_avail.store(true, Ordering::Release);
notify.notify_one();
continue;
Expand Down Expand Up @@ -165,19 +167,20 @@ impl TcpListenerRunner {
notify: notify.clone(),
control: control.clone(),
})
.expect("channel already closed");
.map_err(|e| std::io::Error::new(std::io::ErrorKind::BrokenPipe, e))?;
socket_tx
.send(TcpSocketCreation { control, socket })
.expect("channel already closed");
.map_err(|e| std::io::Error::new(std::io::ErrorKind::BrokenPipe, e))?;
}

// Pipeline tcp stream packet
iface_ingress_tx
.send(frame)
.expect("channel already closed");
.map_err(|e| std::io::Error::new(std::io::ErrorKind::BrokenPipe, e))?;
iface_ingress_tx_avail.store(true, Ordering::Release);
notify.notify_one();
}
Ok(())
}

async fn handle_socket(
Expand All @@ -187,7 +190,7 @@ impl TcpListenerRunner {
iface_ingress_tx_avail: Arc<AtomicBool>,
mut sockets: HashMap<SocketHandle, SharedControl>,
mut socket_rx: UnboundedReceiver<TcpSocketCreation>,
) {
) -> std::io::Result<()> {
let mut socket_set = SocketSet::new(vec![]);
loop {
while let Ok(TcpSocketCreation { control, socket }) = socket_rx.try_recv() {
Expand Down Expand Up @@ -354,9 +357,9 @@ impl TcpListener {
pub(super) fn new(
tcp_rx: Receiver<AnyIpPktFrame>,
stack_tx: Sender<AnyIpPktFrame>,
) -> (Runner, Self) {
) -> std::io::Result<(Runner, Self)> {
let (mut device, iface_ingress_tx, iface_ingress_tx_avail) = VirtualDevice::new(stack_tx);
let iface = Self::create_interface(&mut device);
let iface = Self::create_interface(&mut device)?;

let (stream_tx, stream_rx) = unbounded_channel();

Expand All @@ -370,10 +373,10 @@ impl TcpListener {
HashMap::new(),
);

(runner, Self { stream_rx })
Ok((runner, Self { stream_rx }))
}

fn create_interface<D>(device: &mut D) -> Interface
fn create_interface<D>(device: &mut D) -> std::io::Result<Interface>
where
D: Device + ?Sized,
{
Expand All @@ -391,13 +394,13 @@ impl TcpListener {
iface
.routes_mut()
.add_default_ipv4_route(Ipv4Address::new(0, 0, 0, 1))
.expect("IPv4 default route");
.map_err(|e| std::io::Error::new(std::io::ErrorKind::AddrNotAvailable, e))?;
iface
.routes_mut()
.add_default_ipv6_route(Ipv6Address::new(0, 0, 0, 0, 0, 0, 0, 1))
.expect("IPv6 default route");
.map_err(|e| std::io::Error::new(std::io::ErrorKind::AddrNotAvailable, e))?;
iface.set_any_ip(true);
iface
Ok(iface)
}
}

Expand Down

0 comments on commit 3a7e0b9

Please sign in to comment.