You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
191 lines
5.9 KiB
191 lines
5.9 KiB
2 years ago
|
//! A chat server that broadcasts each message to all other connections.
//!
//! This is a simple message-based server which accepts WebSocket connections on the
//! `/socket` path, reads messages from those connections, and broadcasts each message to
//! every other connected client.
//!
//! You can test this out by running:
//!
//! ```text
//! cargo run --example server 127.0.0.1:12345
//! ```
//!
//! And then in another window run:
//!
//! ```text
//! cargo run --example client ws://127.0.0.1:12345/socket
//! ```
//!
//! You can run the second command in multiple windows and then chat between them, seeing
//! the messages from the other clients as they're received. All connected clients join the
//! same room and see everyone else's messages.
|
||
|
|
||
|
use std::{
|
||
|
collections::HashMap,
|
||
|
convert::Infallible,
|
||
|
env,
|
||
|
net::SocketAddr,
|
||
|
sync::{Arc, Mutex},
|
||
|
};
|
||
|
|
||
|
use hyper::{
|
||
|
header::{
|
||
|
HeaderValue, CONNECTION, SEC_WEBSOCKET_ACCEPT, SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_VERSION,
|
||
|
UPGRADE,
|
||
|
},
|
||
|
server::conn::AddrStream,
|
||
|
service::{make_service_fn, service_fn},
|
||
|
upgrade::Upgraded,
|
||
|
Body, Method, Request, Response, Server, StatusCode, Version,
|
||
|
};
|
||
|
|
||
|
use futures_channel::mpsc::{unbounded, UnboundedSender};
|
||
|
use futures_util::{future, pin_mut, stream::TryStreamExt, StreamExt};
|
||
|
|
||
|
use async_tungstenite::{tokio::TokioAdapter, WebSocketStream};
|
||
|
use tungstenite::{
|
||
|
handshake::derive_accept_key,
|
||
|
protocol::{Message, Role},
|
||
|
};
|
||
|
|
||
|
type Tx = UnboundedSender<Message>;
|
||
|
type PeerMap = Arc<Mutex<HashMap<SocketAddr, Tx>>>;
|
||
|
|
||
|
async fn handle_connection(
|
||
|
peer_map: PeerMap,
|
||
|
ws_stream: WebSocketStream<TokioAdapter<Upgraded>>,
|
||
|
addr: SocketAddr,
|
||
|
) {
|
||
|
println!("WebSocket connection established: {}", addr);
|
||
|
|
||
|
// Insert the write part of this peer to the peer map.
|
||
|
let (tx, rx) = unbounded();
|
||
|
peer_map.lock().unwrap().insert(addr, tx);
|
||
|
|
||
|
let (outgoing, incoming) = ws_stream.split();
|
||
|
|
||
|
let broadcast_incoming = incoming.try_for_each(|msg| {
|
||
|
println!(
|
||
|
"Received a message from {}: {}",
|
||
|
addr,
|
||
|
msg.to_text().unwrap()
|
||
|
);
|
||
|
let peers = peer_map.lock().unwrap();
|
||
|
|
||
|
// We want to broadcast the message to everyone except ourselves.
|
||
|
let broadcast_recipients = peers
|
||
|
.iter()
|
||
|
.filter(|(peer_addr, _)| peer_addr != &&addr)
|
||
|
.map(|(_, ws_sink)| ws_sink);
|
||
|
|
||
|
for recp in broadcast_recipients {
|
||
|
recp.unbounded_send(msg.clone()).unwrap();
|
||
|
}
|
||
|
|
||
|
future::ok(())
|
||
|
});
|
||
|
|
||
|
let receive_from_others = rx.map(Ok).forward(outgoing);
|
||
|
|
||
|
pin_mut!(broadcast_incoming, receive_from_others);
|
||
|
future::select(broadcast_incoming, receive_from_others).await;
|
||
|
|
||
|
println!("{} disconnected", &addr);
|
||
|
peer_map.lock().unwrap().remove(&addr);
|
||
|
}
|
||
|
|
||
|
async fn handle_request(
|
||
|
peer_map: PeerMap,
|
||
|
mut req: Request<Body>,
|
||
|
addr: SocketAddr,
|
||
|
) -> Result<Response<Body>, Infallible> {
|
||
|
println!("Received a new, potentially ws handshake");
|
||
|
println!("The request's path is: {}", req.uri().path());
|
||
|
println!("The request's headers are:");
|
||
|
for (ref header, _value) in req.headers() {
|
||
|
println!("* {}", header);
|
||
|
}
|
||
|
let upgrade = HeaderValue::from_static("Upgrade");
|
||
|
let websocket = HeaderValue::from_static("websocket");
|
||
|
let headers = req.headers();
|
||
|
let key = headers.get(SEC_WEBSOCKET_KEY);
|
||
|
let derived = key.map(|k| derive_accept_key(k.as_bytes()));
|
||
|
if req.method() != Method::GET
|
||
|
|| req.version() < Version::HTTP_11
|
||
|
|| !headers
|
||
|
.get(CONNECTION)
|
||
|
.and_then(|h| h.to_str().ok())
|
||
|
.map(|h| {
|
||
|
h.split(|c| c == ' ' || c == ',')
|
||
|
.any(|p| p.eq_ignore_ascii_case(upgrade.to_str().unwrap()))
|
||
|
})
|
||
|
.unwrap_or(false)
|
||
|
|| !headers
|
||
|
.get(UPGRADE)
|
||
|
.and_then(|h| h.to_str().ok())
|
||
|
.map(|h| h.eq_ignore_ascii_case("websocket"))
|
||
|
.unwrap_or(false)
|
||
|
|| !headers
|
||
|
.get(SEC_WEBSOCKET_VERSION)
|
||
|
.map(|h| h == "13")
|
||
|
.unwrap_or(false)
|
||
|
|| key.is_none()
|
||
|
|| req.uri() != "/socket"
|
||
|
{
|
||
|
return Ok(Response::new(Body::from("Hello World!")));
|
||
|
}
|
||
|
let ver = req.version();
|
||
|
tokio::task::spawn(async move {
|
||
|
match hyper::upgrade::on(&mut req).await {
|
||
|
Ok(upgraded) => {
|
||
|
handle_connection(
|
||
|
peer_map,
|
||
|
WebSocketStream::from_raw_socket(
|
||
|
TokioAdapter::new(upgraded),
|
||
|
Role::Server,
|
||
|
None,
|
||
|
)
|
||
|
.await,
|
||
|
addr,
|
||
|
)
|
||
|
.await;
|
||
|
}
|
||
|
Err(e) => println!("upgrade error: {}", e),
|
||
|
}
|
||
|
});
|
||
|
let mut res = Response::new(Body::empty());
|
||
|
*res.status_mut() = StatusCode::SWITCHING_PROTOCOLS;
|
||
|
*res.version_mut() = ver;
|
||
|
res.headers_mut().append(CONNECTION, upgrade);
|
||
|
res.headers_mut().append(UPGRADE, websocket);
|
||
|
res.headers_mut()
|
||
|
.append(SEC_WEBSOCKET_ACCEPT, derived.unwrap().parse().unwrap());
|
||
|
// Let's add an additional header to our response to the client.
|
||
|
res.headers_mut()
|
||
|
.append("MyCustomHeader", ":)".parse().unwrap());
|
||
|
res.headers_mut()
|
||
|
.append("SOME_TUNGSTENITE_HEADER", "header_value".parse().unwrap());
|
||
|
Ok(res)
|
||
|
}
|
||
|
|
||
|
#[tokio::main]
|
||
|
async fn main() -> Result<(), hyper::Error> {
|
||
|
let state = PeerMap::new(Mutex::new(HashMap::new()));
|
||
|
|
||
|
let addr = env::args()
|
||
|
.nth(1)
|
||
|
.unwrap_or_else(|| "127.0.0.1:8080".to_string())
|
||
|
.parse()
|
||
|
.unwrap();
|
||
|
|
||
|
let make_svc = make_service_fn(move |conn: &AddrStream| {
|
||
|
let remote_addr = conn.remote_addr();
|
||
|
let state = state.clone();
|
||
|
let service = service_fn(move |req| handle_request(state.clone(), req, remote_addr));
|
||
|
async { Ok::<_, Infallible>(service) }
|
||
|
});
|
||
|
|
||
|
let server = Server::bind(&addr).serve(make_svc);
|
||
|
|
||
|
server.await?;
|
||
|
|
||
|
Ok::<_, hyper::Error>(())
|
||
|
}
|