diff --git a/Cargo.toml b/Cargo.toml index 3f588cf..5e0bc1f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -104,6 +104,8 @@ url = "2.0.0" env_logger = "0.9" async-std = { version = "1.0", features = ["attributes", "unstable"] } tokio = { version = "1.0", features = ["full"] } +futures-channel = "0.3" +hyper = { version = "0.14", default-features = false, features = ["http1", "server", "tcp"] } [[example]] name = "autobahn-client" @@ -140,3 +142,7 @@ required-features = ["gio-runtime"] [[example]] name = "tokio-echo" required-features = ["tokio-runtime"] + +[[example]] +name = "server-custom-accept" +required-features = ["tokio-runtime"] diff --git a/examples/server-custom-accept.rs b/examples/server-custom-accept.rs new file mode 100644 index 0000000..bf2f97f --- /dev/null +++ b/examples/server-custom-accept.rs @@ -0,0 +1,190 @@ +//! A chat server that broadcasts a message to all connections. +//! +//! This is a simple line-based server which accepts WebSocket connections, +//! reads lines from those connections, and broadcasts the lines to all other +//! connected clients. +//! +//! You can test this out by running: +//! +//! cargo run --example server 127.0.0.1:12345 +//! +//! And then in another window run: +//! +//! cargo run --example client ws://127.0.0.1:12345/socket +//! +//! You can run the second command in multiple windows and then chat between the +//! two, seeing the messages from the other client as they're received. For all +//! connected clients they'll all join the same room and see everyone else's +//! messages. + +use std::{ + collections::HashMap, + convert::Infallible, + env, + net::SocketAddr, + sync::{Arc, Mutex}, +}; + +use hyper::{ + header::{ + HeaderValue, CONNECTION, SEC_WEBSOCKET_ACCEPT, SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_VERSION, + UPGRADE, + }, + server::conn::AddrStream, + service::{make_service_fn, service_fn}, + upgrade::Upgraded, + Body, Method, Request, Response, Server, StatusCode, Version, +}; + +use futures_channel::mpsc::{unbounded, UnboundedSender}; +use futures_util::{future, pin_mut, stream::TryStreamExt, StreamExt}; + +use async_tungstenite::{tokio::TokioAdapter, WebSocketStream}; +use tungstenite::{ + handshake::derive_accept_key, + protocol::{Message, Role}, +}; + +type Tx = UnboundedSender; +type PeerMap = Arc>>; + +async fn handle_connection( + peer_map: PeerMap, + ws_stream: WebSocketStream>, + addr: SocketAddr, +) { + println!("WebSocket connection established: {}", addr); + + // Insert the write part of this peer to the peer map. + let (tx, rx) = unbounded(); + peer_map.lock().unwrap().insert(addr, tx); + + let (outgoing, incoming) = ws_stream.split(); + + let broadcast_incoming = incoming.try_for_each(|msg| { + println!( + "Received a message from {}: {}", + addr, + msg.to_text().unwrap() + ); + let peers = peer_map.lock().unwrap(); + + // We want to broadcast the message to everyone except ourselves. + let broadcast_recipients = peers + .iter() + .filter(|(peer_addr, _)| peer_addr != &&addr) + .map(|(_, ws_sink)| ws_sink); + + for recp in broadcast_recipients { + recp.unbounded_send(msg.clone()).unwrap(); + } + + future::ok(()) + }); + + let receive_from_others = rx.map(Ok).forward(outgoing); + + pin_mut!(broadcast_incoming, receive_from_others); + future::select(broadcast_incoming, receive_from_others).await; + + println!("{} disconnected", &addr); + peer_map.lock().unwrap().remove(&addr); +} + +async fn handle_request( + peer_map: PeerMap, + mut req: Request, + addr: SocketAddr, +) -> Result, Infallible> { + println!("Received a new, potentially ws handshake"); + println!("The request's path is: {}", req.uri().path()); + println!("The request's headers are:"); + for (ref header, _value) in req.headers() { + println!("* {}", header); + } + let upgrade = HeaderValue::from_static("Upgrade"); + let websocket = HeaderValue::from_static("websocket"); + let headers = req.headers(); + let key = headers.get(SEC_WEBSOCKET_KEY); + let derived = key.map(|k| derive_accept_key(k.as_bytes())); + if req.method() != Method::GET + || req.version() < Version::HTTP_11 + || !headers + .get(CONNECTION) + .and_then(|h| h.to_str().ok()) + .map(|h| { + h.split(|c| c == ' ' || c == ',') + .any(|p| p.eq_ignore_ascii_case(upgrade.to_str().unwrap())) + }) + .unwrap_or(false) + || !headers + .get(UPGRADE) + .and_then(|h| h.to_str().ok()) + .map(|h| h.eq_ignore_ascii_case("websocket")) + .unwrap_or(false) + || !headers + .get(SEC_WEBSOCKET_VERSION) + .map(|h| h == "13") + .unwrap_or(false) + || key.is_none() + || req.uri() != "/socket" + { + return Ok(Response::new(Body::from("Hello World!"))); + } + let ver = req.version(); + tokio::task::spawn(async move { + match hyper::upgrade::on(&mut req).await { + Ok(upgraded) => { + handle_connection( + peer_map, + WebSocketStream::from_raw_socket( + TokioAdapter::new(upgraded), + Role::Server, + None, + ) + .await, + addr, + ) + .await; + } + Err(e) => println!("upgrade error: {}", e), + } + }); + let mut res = Response::new(Body::empty()); + *res.status_mut() = StatusCode::SWITCHING_PROTOCOLS; + *res.version_mut() = ver; + res.headers_mut().append(CONNECTION, upgrade); + res.headers_mut().append(UPGRADE, websocket); + res.headers_mut() + .append(SEC_WEBSOCKET_ACCEPT, derived.unwrap().parse().unwrap()); + // Let's add an additional header to our response to the client. + res.headers_mut() + .append("MyCustomHeader", ":)".parse().unwrap()); + res.headers_mut() + .append("SOME_TUNGSTENITE_HEADER", "header_value".parse().unwrap()); + Ok(res) +} + +#[tokio::main] +async fn main() -> Result<(), hyper::Error> { + let state = PeerMap::new(Mutex::new(HashMap::new())); + + let addr = env::args() + .nth(1) + .unwrap_or_else(|| "127.0.0.1:8080".to_string()) + .parse() + .unwrap(); + + let make_svc = make_service_fn(move |conn: &AddrStream| { + let remote_addr = conn.remote_addr(); + let state = state.clone(); + let service = service_fn(move |req| handle_request(state.clone(), req, remote_addr)); + async { Ok::<_, Infallible>(service) } + }); + + let server = Server::bind(&addr).serve(make_svc); + + server.await?; + + Ok::<_, hyper::Error>(()) +}