parent
31dac03929
commit
2e61ecbbb1
@ -0,0 +1,190 @@ |
||||
//! A chat server that broadcasts a message to all connections.
|
||||
//!
|
||||
//! This is a simple line-based server which accepts WebSocket connections,
|
||||
//! reads lines from those connections, and broadcasts the lines to all other
|
||||
//! connected clients.
|
||||
//!
|
||||
//! You can test this out by running:
|
||||
//!
|
||||
//! cargo run --example server 127.0.0.1:12345
|
||||
//!
|
||||
//! And then in another window run:
|
||||
//!
|
||||
//! cargo run --example client ws://127.0.0.1:12345/socket
|
||||
//!
|
||||
//! You can run the second command in multiple windows and then chat between the
|
||||
//! two, seeing the messages from the other client as they're received. For all
|
||||
//! connected clients they'll all join the same room and see everyone else's
|
||||
//! messages.
|
||||
|
||||
use std::{ |
||||
collections::HashMap, |
||||
convert::Infallible, |
||||
env, |
||||
net::SocketAddr, |
||||
sync::{Arc, Mutex}, |
||||
}; |
||||
|
||||
use hyper::{ |
||||
header::{ |
||||
HeaderValue, CONNECTION, SEC_WEBSOCKET_ACCEPT, SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_VERSION, |
||||
UPGRADE, |
||||
}, |
||||
server::conn::AddrStream, |
||||
service::{make_service_fn, service_fn}, |
||||
upgrade::Upgraded, |
||||
Body, Method, Request, Response, Server, StatusCode, Version, |
||||
}; |
||||
|
||||
use futures_channel::mpsc::{unbounded, UnboundedSender}; |
||||
use futures_util::{future, pin_mut, stream::TryStreamExt, StreamExt}; |
||||
|
||||
use async_tungstenite::{tokio::TokioAdapter, WebSocketStream}; |
||||
use tungstenite::{ |
||||
handshake::derive_accept_key, |
||||
protocol::{Message, Role}, |
||||
}; |
||||
|
||||
type Tx = UnboundedSender<Message>; |
||||
type PeerMap = Arc<Mutex<HashMap<SocketAddr, Tx>>>; |
||||
|
||||
async fn handle_connection( |
||||
peer_map: PeerMap, |
||||
ws_stream: WebSocketStream<TokioAdapter<Upgraded>>, |
||||
addr: SocketAddr, |
||||
) { |
||||
println!("WebSocket connection established: {}", addr); |
||||
|
||||
// Insert the write part of this peer to the peer map.
|
||||
let (tx, rx) = unbounded(); |
||||
peer_map.lock().unwrap().insert(addr, tx); |
||||
|
||||
let (outgoing, incoming) = ws_stream.split(); |
||||
|
||||
let broadcast_incoming = incoming.try_for_each(|msg| { |
||||
println!( |
||||
"Received a message from {}: {}", |
||||
addr, |
||||
msg.to_text().unwrap() |
||||
); |
||||
let peers = peer_map.lock().unwrap(); |
||||
|
||||
// We want to broadcast the message to everyone except ourselves.
|
||||
let broadcast_recipients = peers |
||||
.iter() |
||||
.filter(|(peer_addr, _)| peer_addr != &&addr) |
||||
.map(|(_, ws_sink)| ws_sink); |
||||
|
||||
for recp in broadcast_recipients { |
||||
recp.unbounded_send(msg.clone()).unwrap(); |
||||
} |
||||
|
||||
future::ok(()) |
||||
}); |
||||
|
||||
let receive_from_others = rx.map(Ok).forward(outgoing); |
||||
|
||||
pin_mut!(broadcast_incoming, receive_from_others); |
||||
future::select(broadcast_incoming, receive_from_others).await; |
||||
|
||||
println!("{} disconnected", &addr); |
||||
peer_map.lock().unwrap().remove(&addr); |
||||
} |
||||
|
||||
async fn handle_request( |
||||
peer_map: PeerMap, |
||||
mut req: Request<Body>, |
||||
addr: SocketAddr, |
||||
) -> Result<Response<Body>, Infallible> { |
||||
println!("Received a new, potentially ws handshake"); |
||||
println!("The request's path is: {}", req.uri().path()); |
||||
println!("The request's headers are:"); |
||||
for (ref header, _value) in req.headers() { |
||||
println!("* {}", header); |
||||
} |
||||
let upgrade = HeaderValue::from_static("Upgrade"); |
||||
let websocket = HeaderValue::from_static("websocket"); |
||||
let headers = req.headers(); |
||||
let key = headers.get(SEC_WEBSOCKET_KEY); |
||||
let derived = key.map(|k| derive_accept_key(k.as_bytes())); |
||||
if req.method() != Method::GET |
||||
|| req.version() < Version::HTTP_11 |
||||
|| !headers |
||||
.get(CONNECTION) |
||||
.and_then(|h| h.to_str().ok()) |
||||
.map(|h| { |
||||
h.split(|c| c == ' ' || c == ',') |
||||
.any(|p| p.eq_ignore_ascii_case(upgrade.to_str().unwrap())) |
||||
}) |
||||
.unwrap_or(false) |
||||
|| !headers |
||||
.get(UPGRADE) |
||||
.and_then(|h| h.to_str().ok()) |
||||
.map(|h| h.eq_ignore_ascii_case("websocket")) |
||||
.unwrap_or(false) |
||||
|| !headers |
||||
.get(SEC_WEBSOCKET_VERSION) |
||||
.map(|h| h == "13") |
||||
.unwrap_or(false) |
||||
|| key.is_none() |
||||
|| req.uri() != "/socket" |
||||
{ |
||||
return Ok(Response::new(Body::from("Hello World!"))); |
||||
} |
||||
let ver = req.version(); |
||||
tokio::task::spawn(async move { |
||||
match hyper::upgrade::on(&mut req).await { |
||||
Ok(upgraded) => { |
||||
handle_connection( |
||||
peer_map, |
||||
WebSocketStream::from_raw_socket( |
||||
TokioAdapter::new(upgraded), |
||||
Role::Server, |
||||
None, |
||||
) |
||||
.await, |
||||
addr, |
||||
) |
||||
.await; |
||||
} |
||||
Err(e) => println!("upgrade error: {}", e), |
||||
} |
||||
}); |
||||
let mut res = Response::new(Body::empty()); |
||||
*res.status_mut() = StatusCode::SWITCHING_PROTOCOLS; |
||||
*res.version_mut() = ver; |
||||
res.headers_mut().append(CONNECTION, upgrade); |
||||
res.headers_mut().append(UPGRADE, websocket); |
||||
res.headers_mut() |
||||
.append(SEC_WEBSOCKET_ACCEPT, derived.unwrap().parse().unwrap()); |
||||
// Let's add an additional header to our response to the client.
|
||||
res.headers_mut() |
||||
.append("MyCustomHeader", ":)".parse().unwrap()); |
||||
res.headers_mut() |
||||
.append("SOME_TUNGSTENITE_HEADER", "header_value".parse().unwrap()); |
||||
Ok(res) |
||||
} |
||||
|
||||
#[tokio::main] |
||||
async fn main() -> Result<(), hyper::Error> { |
||||
let state = PeerMap::new(Mutex::new(HashMap::new())); |
||||
|
||||
let addr = env::args() |
||||
.nth(1) |
||||
.unwrap_or_else(|| "127.0.0.1:8080".to_string()) |
||||
.parse() |
||||
.unwrap(); |
||||
|
||||
let make_svc = make_service_fn(move |conn: &AddrStream| { |
||||
let remote_addr = conn.remote_addr(); |
||||
let state = state.clone(); |
||||
let service = service_fn(move |req| handle_request(state.clone(), req, remote_addr)); |
||||
async { Ok::<_, Infallible>(service) } |
||||
}); |
||||
|
||||
let server = Server::bind(&addr).serve(make_svc); |
||||
|
||||
server.await?; |
||||
|
||||
Ok::<_, hyper::Error>(()) |
||||
} |
Loading…
Reference in new issue