Skip to content

Commit a960ccb

Browse files
Merge pull request #237 from ahmedcharles/example
Add an example of using tungstenite with a custom accept.
2 parents 61f5926 + 4e1559a commit a960ccb

File tree

2 files changed

+170
-0
lines changed

2 files changed

+170
-0
lines changed

Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ version = "0.22.1"
6565

6666
[dev-dependencies]
6767
futures-channel = "0.3"
68+
hyper = { version = "0.14", default-features = false, features = ["http1", "server", "tcp"] }
6869
tokio = { version = "1.0.0", default-features = false, features = ["io-std", "macros", "net", "rt-multi-thread", "time"] }
6970
url = "2.0.0"
7071
env_logger = "0.9"

examples/server-custom-accept.rs

+169
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
//! A chat server that broadcasts a message to all connections.
2+
//!
3+
//! This is a simple line-based server which accepts WebSocket connections,
4+
//! reads lines from those connections, and broadcasts the lines to all other
5+
//! connected clients.
6+
//!
7+
//! You can test this out by running:
8+
//!
9+
//! cargo run --example server 127.0.0.1:12345
10+
//!
11+
//! And then in another window run:
12+
//!
13+
//! cargo run --example client ws://127.0.0.1:12345/socket
14+
//!
15+
//! You can run the second command in multiple windows and then chat between the
16+
//! two, seeing the messages from the other client as they're received. For all
17+
//! connected clients they'll all join the same room and see everyone else's
18+
//! messages.
19+
20+
use std::{
21+
collections::HashMap,
22+
convert::Infallible,
23+
env,
24+
net::SocketAddr,
25+
sync::{Arc, Mutex},
26+
};
27+
28+
use hyper::{
29+
header::{
30+
HeaderValue, CONNECTION, SEC_WEBSOCKET_ACCEPT, SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_VERSION,
31+
UPGRADE,
32+
},
33+
server::conn::AddrStream,
34+
service::{make_service_fn, service_fn},
35+
upgrade::Upgraded,
36+
Body, Method, Request, Response, Server, StatusCode, Version,
37+
};
38+
39+
use futures_channel::mpsc::{unbounded, UnboundedSender};
40+
use futures_util::{future, pin_mut, stream::TryStreamExt, StreamExt};
41+
42+
use tokio_tungstenite::WebSocketStream;
43+
use tungstenite::{
44+
handshake::derive_accept_key,
45+
protocol::{Message, Role},
46+
};
47+
48+
type Tx = UnboundedSender<Message>;
49+
type PeerMap = Arc<Mutex<HashMap<SocketAddr, Tx>>>;
50+
51+
async fn handle_connection(
52+
peer_map: PeerMap,
53+
ws_stream: WebSocketStream<Upgraded>,
54+
addr: SocketAddr,
55+
) {
56+
println!("WebSocket connection established: {}", addr);
57+
58+
// Insert the write part of this peer to the peer map.
59+
let (tx, rx) = unbounded();
60+
peer_map.lock().unwrap().insert(addr, tx);
61+
62+
let (outgoing, incoming) = ws_stream.split();
63+
64+
let broadcast_incoming = incoming.try_for_each(|msg| {
65+
println!("Received a message from {}: {}", addr, msg.to_text().unwrap());
66+
let peers = peer_map.lock().unwrap();
67+
68+
// We want to broadcast the message to everyone except ourselves.
69+
let broadcast_recipients =
70+
peers.iter().filter(|(peer_addr, _)| peer_addr != &&addr).map(|(_, ws_sink)| ws_sink);
71+
72+
for recp in broadcast_recipients {
73+
recp.unbounded_send(msg.clone()).unwrap();
74+
}
75+
76+
future::ok(())
77+
});
78+
79+
let receive_from_others = rx.map(Ok).forward(outgoing);
80+
81+
pin_mut!(broadcast_incoming, receive_from_others);
82+
future::select(broadcast_incoming, receive_from_others).await;
83+
84+
println!("{} disconnected", &addr);
85+
peer_map.lock().unwrap().remove(&addr);
86+
}
87+
88+
async fn handle_request(
89+
peer_map: PeerMap,
90+
mut req: Request<Body>,
91+
addr: SocketAddr,
92+
) -> Result<Response<Body>, Infallible> {
93+
println!("Received a new, potentially ws handshake");
94+
println!("The request's path is: {}", req.uri().path());
95+
println!("The request's headers are:");
96+
for (ref header, _value) in req.headers() {
97+
println!("* {}", header);
98+
}
99+
let upgrade = HeaderValue::from_static("Upgrade");
100+
let websocket = HeaderValue::from_static("websocket");
101+
let headers = req.headers();
102+
let key = headers.get(SEC_WEBSOCKET_KEY);
103+
let derived = key.map(|k| derive_accept_key(k.as_bytes()));
104+
if req.method() != Method::GET
105+
|| req.version() < Version::HTTP_11
106+
|| !headers
107+
.get(CONNECTION)
108+
.and_then(|h| h.to_str().ok())
109+
.map(|h| {
110+
h.split(|c| c == ' ' || c == ',')
111+
.any(|p| p.eq_ignore_ascii_case(upgrade.to_str().unwrap()))
112+
})
113+
.unwrap_or(false)
114+
|| !headers
115+
.get(UPGRADE)
116+
.and_then(|h| h.to_str().ok())
117+
.map(|h| h.eq_ignore_ascii_case("websocket"))
118+
.unwrap_or(false)
119+
|| !headers.get(SEC_WEBSOCKET_VERSION).map(|h| h == "13").unwrap_or(false)
120+
|| key.is_none()
121+
|| req.uri() != "/socket"
122+
{
123+
return Ok(Response::new(Body::from("Hello World!")));
124+
}
125+
let ver = req.version();
126+
tokio::task::spawn(async move {
127+
match hyper::upgrade::on(&mut req).await {
128+
Ok(upgraded) => {
129+
handle_connection(
130+
peer_map,
131+
WebSocketStream::from_raw_socket(upgraded, Role::Server, None).await,
132+
addr,
133+
)
134+
.await;
135+
}
136+
Err(e) => println!("upgrade error: {}", e),
137+
}
138+
});
139+
let mut res = Response::new(Body::empty());
140+
*res.status_mut() = StatusCode::SWITCHING_PROTOCOLS;
141+
*res.version_mut() = ver;
142+
res.headers_mut().append(CONNECTION, upgrade);
143+
res.headers_mut().append(UPGRADE, websocket);
144+
res.headers_mut().append(SEC_WEBSOCKET_ACCEPT, derived.unwrap().parse().unwrap());
145+
// Let's add an additional header to our response to the client.
146+
res.headers_mut().append("MyCustomHeader", ":)".parse().unwrap());
147+
res.headers_mut().append("SOME_TUNGSTENITE_HEADER", "header_value".parse().unwrap());
148+
Ok(res)
149+
}
150+
151+
#[tokio::main]
152+
async fn main() -> Result<(), hyper::Error> {
153+
let state = PeerMap::new(Mutex::new(HashMap::new()));
154+
155+
let addr = env::args().nth(1).unwrap_or_else(|| "127.0.0.1:8080".to_string()).parse().unwrap();
156+
157+
let make_svc = make_service_fn(move |conn: &AddrStream| {
158+
let remote_addr = conn.remote_addr();
159+
let state = state.clone();
160+
let service = service_fn(move |req| handle_request(state.clone(), req, remote_addr));
161+
async { Ok::<_, Infallible>(service) }
162+
});
163+
164+
let server = Server::bind(&addr).serve(make_svc);
165+
166+
server.await?;
167+
168+
Ok::<_, hyper::Error>(())
169+
}

0 commit comments

Comments
 (0)