Skip to content

Commit 1ee78bc

Browse files
committed
Add a SockAddr type
We don't want to be stuck only supporting IP sockets
1 parent f11c451 commit 1ee78bc

File tree

5 files changed

+210
-177
lines changed

5 files changed

+210
-177
lines changed

src/lib.rs

+20-7
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,14 @@
2020
//! # Examples
2121
//!
2222
//! ```no_run
23+
//! use std::net::SocketAddr;
2324
//! use socket2::{Socket, Domain, Type};
2425
//!
2526
//! // create a TCP listener bound to two addresses
2627
//! let socket = Socket::new(Domain::ipv4(), Type::stream(), None).unwrap();
2728
//!
28-
//! socket.bind(&"127.0.0.1:12345".parse().unwrap()).unwrap();
29-
//! socket.bind(&"127.0.0.1:12346".parse().unwrap()).unwrap();
29+
//! socket.bind(&"127.0.0.1:12345".parse::<SocketAddr>().unwrap().into()).unwrap();
30+
//! socket.bind(&"127.0.0.1:12346".parse::<SocketAddr>().unwrap().into()).unwrap();
3031
//! socket.listen(128).unwrap();
3132
//!
3233
//! let listener = socket.into_tcp_listener();
@@ -45,6 +46,10 @@
4546

4647
use utils::NetInt;
4748

49+
#[cfg(unix)] use libc::{sockaddr_storage, socklen_t};
50+
#[cfg(windows)] use winapi::{SOCKADDR_STORAGE as sockaddr_storage, socklen_t};
51+
52+
mod sockaddr;
4853
mod socket;
4954
mod utils;
5055

@@ -63,13 +68,14 @@ mod utils;
6368
/// # Examples
6469
///
6570
/// ```no_run
66-
/// use socket2::{Socket, Domain, Type};
71+
/// use std::net::SocketAddr;
72+
/// use socket2::{Socket, Domain, Type, SockAddr};
6773
///
6874
/// // create a TCP listener bound to two addresses
6975
/// let socket = Socket::new(Domain::ipv4(), Type::stream(), None).unwrap();
7076
///
71-
/// socket.bind(&"127.0.0.1:12345".parse().unwrap()).unwrap();
72-
/// socket.bind(&"127.0.0.1:12346".parse().unwrap()).unwrap();
77+
/// socket.bind(&"127.0.0.1:12345".parse::<SocketAddr>().unwrap().into()).unwrap();
78+
/// socket.bind(&"127.0.0.1:12346".parse::<SocketAddr>().unwrap().into()).unwrap();
7379
/// socket.listen(128).unwrap();
7480
///
7581
/// let listener = socket.into_tcp_listener();
@@ -79,6 +85,15 @@ pub struct Socket {
7985
inner: sys::Socket,
8086
}
8187

88+
/// The address of a socket.
89+
///
90+
/// `SockAddr`s may be constructed directly to and from the standard library
91+
/// `SocketAddr`, `SocketAddrV4`, and `SocketAddrV6` types.
92+
pub struct SockAddr {
93+
storage: sockaddr_storage,
94+
len: socklen_t,
95+
}
96+
8297
/// Specification of the communication domain for a socket.
8398
///
8499
/// This is a newtype wrapper around an integer which provides a nicer API in
@@ -111,5 +126,3 @@ pub struct Type(i32);
111126
pub struct Protocol(i32);
112127

113128
fn hton<I: NetInt>(i: I) -> I { i.to_be() }
114-
115-
fn ntoh<I: NetInt>(i: I) -> I { I::from_be(i) }

src/sockaddr.rs

+130
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
use std::net::{SocketAddrV4, SocketAddrV6, SocketAddr};
2+
use std::mem;
3+
use std::ptr;
4+
use std::fmt;
5+
6+
#[cfg(unix)]
7+
use libc::{sockaddr, sockaddr_storage, sa_family_t, socklen_t, AF_INET, AF_INET6};
8+
#[cfg(windows)]
9+
use winapi::{SOCKADDR as sockaddr, SOCKADDR_STORAGE as sockaddr_storage,
10+
ADDRESS_FAMILY as sa_family_t, socklen_t, AF_INET, AF_INET6};
11+
12+
use SockAddr;
13+
14+
impl fmt::Debug for SockAddr {
15+
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
16+
let mut builder = fmt.debug_struct("SockAddr");
17+
builder.field("family", &self.family());
18+
if let Some(addr) = self.as_inet() {
19+
builder.field("inet", &addr);
20+
} else if let Some(addr) = self.as_inet6() {
21+
builder.field("inet6", &addr);
22+
}
23+
builder.finish()
24+
}
25+
}
26+
27+
impl SockAddr {
28+
/// Constructs a `SockAddr` from its raw components.
29+
pub unsafe fn from_raw_parts(addr: *const sockaddr, len: socklen_t) -> SockAddr {
30+
let mut storage = mem::uninitialized::<sockaddr_storage>();
31+
ptr::copy_nonoverlapping(addr as *const _ as *const u8,
32+
&mut storage as *mut _ as *mut u8,
33+
len as usize);
34+
35+
SockAddr {
36+
storage: storage,
37+
len: len,
38+
}
39+
}
40+
41+
unsafe fn as_<T>(&self, family: sa_family_t) -> Option<T> {
42+
if self.storage.ss_family != family {
43+
return None;
44+
}
45+
46+
Some(mem::transmute_copy(&self.storage))
47+
}
48+
49+
/// Returns this address as a `SocketAddrV4` if it is in the `AF_INET`
50+
/// family.
51+
pub fn as_inet(&self) -> Option<SocketAddrV4> {
52+
unsafe { self.as_(AF_INET as sa_family_t) }
53+
}
54+
55+
/// Returns this address as a `SocketAddrV4` if it is in the `AF_INET6`
56+
/// family.
57+
pub fn as_inet6(&self) -> Option<SocketAddrV6> {
58+
unsafe { self.as_(AF_INET6 as sa_family_t) }
59+
}
60+
61+
/// Returns this address's family.
62+
pub fn family(&self) -> sa_family_t {
63+
self.storage.ss_family
64+
}
65+
66+
/// Returns the size of this address in bytes.
67+
pub fn len(&self) -> socklen_t {
68+
self.len
69+
}
70+
71+
/// Returns a raw pointer to the address.
72+
pub fn as_ptr(&self) -> *const sockaddr {
73+
&self.storage as *const _ as *const _
74+
}
75+
}
76+
77+
// SocketAddrV4 and SocketAddrV6 are just wrappers around sockaddr_in and sockaddr_in6
78+
79+
impl From<SocketAddrV4> for SockAddr {
80+
fn from(addr: SocketAddrV4) -> SockAddr {
81+
unsafe {
82+
SockAddr::from_raw_parts(&addr as *const _ as *const _,
83+
mem::size_of::<SocketAddrV4>() as socklen_t)
84+
}
85+
}
86+
}
87+
88+
89+
impl From<SocketAddrV6> for SockAddr {
90+
fn from(addr: SocketAddrV6) -> SockAddr {
91+
unsafe {
92+
SockAddr::from_raw_parts(&addr as *const _ as *const _,
93+
mem::size_of::<SocketAddrV6>() as socklen_t)
94+
}
95+
}
96+
}
97+
98+
impl From<SocketAddr> for SockAddr {
99+
fn from(addr: SocketAddr) -> SockAddr {
100+
match addr {
101+
SocketAddr::V4(addr) => addr.into(),
102+
SocketAddr::V6(addr) => addr.into(),
103+
}
104+
}
105+
}
106+
107+
#[cfg(test)]
108+
mod test {
109+
use super::*;
110+
111+
#[test]
112+
fn inet() {
113+
let raw = "127.0.0.1:80".parse::<SocketAddrV4>().unwrap();
114+
let addr = SockAddr::from(raw);
115+
assert!(addr.as_inet6().is_none());
116+
let addr = addr.as_inet().unwrap();
117+
assert_eq!(raw, addr);
118+
}
119+
120+
#[test]
121+
fn inet6() {
122+
let raw = "[2001:db8::ff00:42:8329]:80"
123+
.parse::<SocketAddrV6>()
124+
.unwrap();
125+
let addr = SockAddr::from(raw);
126+
assert!(addr.as_inet().is_none());
127+
let addr = addr.as_inet6().unwrap();
128+
assert_eq!(raw, addr);
129+
}
130+
}

src/socket.rs

+15-13
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
use std::fmt;
1212
use std::io::{self, Read, Write};
13-
use std::net::{self, SocketAddr, Ipv4Addr, Ipv6Addr, Shutdown};
13+
use std::net::{self, Ipv4Addr, Ipv6Addr, Shutdown};
1414
use std::time::Duration;
1515

1616
#[cfg(unix)]
@@ -19,7 +19,7 @@ use libc as c;
1919
use winapi as c;
2020

2121
use sys;
22-
use {Socket, Protocol, Domain, Type};
22+
use {Socket, Protocol, Domain, Type, SockAddr};
2323

2424
impl Socket {
2525
/// Creates a new socket ready to be configured.
@@ -58,7 +58,7 @@ impl Socket {
5858
///
5959
/// An error will be returned if `listen` or `connect` has already been
6060
/// called on this builder.
61-
pub fn connect(&self, addr: &SocketAddr) -> io::Result<()> {
61+
pub fn connect(&self, addr: &SockAddr) -> io::Result<()> {
6262
self.inner.connect(addr)
6363
}
6464

@@ -81,15 +81,15 @@ impl Socket {
8181
///
8282
/// If the connection request times out, it may still be processing in the
8383
/// background - a second call to `connect` or `connect_timeout` may fail.
84-
pub fn connect_timeout(&self, addr: &SocketAddr, timeout: Duration) -> io::Result<()> {
84+
pub fn connect_timeout(&self, addr: &SockAddr, timeout: Duration) -> io::Result<()> {
8585
self.inner.connect_timeout(addr, timeout)
8686
}
8787

8888
/// Binds this socket to the specified address.
8989
///
9090
/// This function directly corresponds to the bind(2) function on Windows
9191
/// and Unix.
92-
pub fn bind(&self, addr: &SocketAddr) -> io::Result<()> {
92+
pub fn bind(&self, addr: &SockAddr) -> io::Result<()> {
9393
self.inner.bind(addr)
9494
}
9595

@@ -110,19 +110,19 @@ impl Socket {
110110
/// This function will block the calling thread until a new connection is
111111
/// established. When established, the corresponding `Socket` and the
112112
/// remote peer's address will be returned.
113-
pub fn accept(&self) -> io::Result<(Socket, SocketAddr)> {
113+
pub fn accept(&self) -> io::Result<(Socket, SockAddr)> {
114114
self.inner.accept().map(|(socket, addr)| {
115115
(Socket { inner: socket }, addr)
116116
})
117117
}
118118

119119
/// Returns the socket address of the local half of this TCP connection.
120-
pub fn local_addr(&self) -> io::Result<SocketAddr> {
120+
pub fn local_addr(&self) -> io::Result<SockAddr> {
121121
self.inner.local_addr()
122122
}
123123

124124
/// Returns the socket address of the remote peer of this TCP connection.
125-
pub fn peer_addr(&self) -> io::Result<SocketAddr> {
125+
pub fn peer_addr(&self) -> io::Result<SockAddr> {
126126
self.inner.peer_addr()
127127
}
128128

@@ -184,7 +184,7 @@ impl Socket {
184184

185185
/// Receives data from the socket. On success, returns the number of bytes
186186
/// read and the address from whence the data came.
187-
pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
187+
pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SockAddr)> {
188188
self.inner.recv_from(buf)
189189
}
190190

@@ -195,7 +195,7 @@ impl Socket {
195195
///
196196
/// On success, returns the number of bytes peeked and the address from
197197
/// whence the data came.
198-
pub fn peek_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
198+
pub fn peek_from(&self, buf: &mut [u8]) -> io::Result<(usize, SockAddr)> {
199199
self.inner.peek_from(buf)
200200
}
201201

@@ -214,7 +214,7 @@ impl Socket {
214214
///
215215
/// This is typically used on UDP or datagram-oriented sockets. On success
216216
/// returns the number of bytes that were sent.
217-
pub fn send_to(&self, buf: &[u8], addr: &SocketAddr) -> io::Result<usize> {
217+
pub fn send_to(&self, buf: &[u8], addr: &SockAddr) -> io::Result<usize> {
218218
self.inner.send_to(buf, addr)
219219
}
220220

@@ -693,12 +693,14 @@ impl From<Protocol> for i32 {
693693

694694
#[cfg(test)]
695695
mod test {
696+
use std::net::SocketAddr;
697+
696698
use super::*;
697699

698700
#[test]
699701
fn connect_timeout_unrouteable() {
700702
// this IP is unroutable, so connections should always time out
701-
let addr: SocketAddr = "10.255.255.1:80".parse().unwrap();
703+
let addr = "10.255.255.1:80".parse::<SocketAddr>().unwrap().into();
702704

703705
let socket = Socket::new(Domain::ipv4(), Type::stream(), None).unwrap();
704706
match socket.connect_timeout(&addr, Duration::from_millis(250)) {
@@ -711,7 +713,7 @@ mod test {
711713
#[test]
712714
fn connect_timeout_valid() {
713715
let socket = Socket::new(Domain::ipv4(), Type::stream(), None).unwrap();
714-
socket.bind(&"127.0.0.1:0".parse().unwrap()).unwrap();
716+
socket.bind(&"127.0.0.1:0".parse::<SocketAddr>().unwrap().into()).unwrap();
715717
socket.listen(128).unwrap();
716718

717719
let addr = socket.local_addr().unwrap();

0 commit comments

Comments
 (0)