From 0bdda22ef9496353839e11e6921cfab4c83229b8 Mon Sep 17 00:00:00 2001 From: Tpt Date: Mon, 17 Apr 2023 20:44:44 +0200 Subject: [PATCH] Server: Adds an option to allow CORS --- server/src/main.rs | 80 +++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 72 insertions(+), 8 deletions(-) diff --git a/server/src/main.rs b/server/src/main.rs index 7d5b5235..3c807ab4 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -1,7 +1,7 @@ use anyhow::{anyhow, bail, Context, Error}; use clap::{Parser, Subcommand}; use flate2::read::MultiGzDecoder; -use oxhttp::model::{Body, HeaderName, HeaderValue, Request, Response, Status}; +use oxhttp::model::{Body, HeaderName, HeaderValue, Method, Request, Response, Status}; use oxhttp::Server; use oxigraph::io::{DatasetFormat, DatasetSerializer, GraphFormat, GraphSerializer}; use oxigraph::model::{ @@ -52,6 +52,9 @@ enum Command { /// Host and port to listen to. #[arg(short, long, default_value = "localhost:7878")] bind: String, + /// Allows cross-origin requests + #[arg(long)] + cors: bool, }, /// Start Oxigraph HTTP server in read-only mode. /// @@ -62,6 +65,9 @@ enum Command { /// Host and port to listen to. #[arg(short, long, default_value = "localhost:7878")] bind: String, + /// Allows cross-origin requests + #[arg(long)] + cors: bool, }, /// Start Oxigraph HTTP server in secondary mode. /// @@ -82,6 +88,9 @@ enum Command { /// Host and port to listen to. #[arg(short, long, default_value = "localhost:7878")] bind: String, + /// Allows cross-origin requests + #[arg(long)] + cors: bool, }, /// Creates database backup into a target directory. /// @@ -219,7 +228,7 @@ enum Command { pub fn main() -> anyhow::Result<()> { let matches = Args::parse(); match matches.command { - Command::Serve { bind } => serve( + Command::Serve { bind, cors } => serve( if let Some(location) = matches.location { Store::open(location) } else { @@ -227,8 +236,9 @@ pub fn main() -> anyhow::Result<()> { }?, bind, false, + cors, ), - Command::ServeReadOnly { bind } => serve( + Command::ServeReadOnly { bind, cors } => serve( Store::open_read_only( matches .location @@ -236,11 +246,13 @@ pub fn main() -> anyhow::Result<()> { )?, bind, true, + cors, ), Command::ServeSecondary { primary_location, secondary_location, bind, + cors, } => { let primary_location = primary_location.or(matches.location).ok_or_else(|| { anyhow!("Either the --location or the --primary-location argument is required") @@ -253,6 +265,7 @@ pub fn main() -> anyhow::Result<()> { }?, bind, true, + cors, ) } Command::Backup { destination } => { @@ -745,11 +758,18 @@ impl FromStr for GraphOrDatasetFormat { } } -fn serve(store: Store, bind: String, read_only: bool) -> anyhow::Result<()> { - let mut server = Server::new(move |request| { - handle_request(request, store.clone(), read_only) - .unwrap_or_else(|(status, message)| error(status, message)) - }); +fn serve(store: Store, bind: String, read_only: bool, cors: bool) -> anyhow::Result<()> { + let mut server = if cors { + Server::new(cors_middleware(move |request| { + handle_request(request, store.clone(), read_only) + .unwrap_or_else(|(status, message)| error(status, message)) + })) + } else { + Server::new(move |request| { + handle_request(request, store.clone(), read_only) + .unwrap_or_else(|(status, message)| error(status, message)) + }) + }; server.set_global_timeout(HTTP_TIMEOUT); server.set_server_name(concat!("Oxigraph/", env!("CARGO_PKG_VERSION")))?; eprintln!("Listening for requests at http://{}", &bind); @@ -757,6 +777,50 @@ fn serve(store: Store, bind: String, read_only: bool) -> anyhow::Result<()> { Ok(()) } +fn cors_middleware( + on_request: impl Fn(&mut Request) -> Response + Send + Sync + 'static, +) -> impl Fn(&mut Request) -> Response + Send + Sync + 'static { + let origin = HeaderName::from_str("Origin").unwrap(); + let access_control_allow_origin = HeaderName::from_str("Access-Control-Allow-Origin").unwrap(); + let access_control_request_method = + HeaderName::from_str("Access-Control-Request-Method").unwrap(); + let access_control_allow_method = HeaderName::from_str("Access-Control-Allow-Methods").unwrap(); + let access_control_request_headers = + HeaderName::from_str("Access-Control-Request-Headers").unwrap(); + let access_control_allow_headers = + HeaderName::from_str("Access-Control-Allow-Headers").unwrap(); + let star = HeaderValue::from_str("*").unwrap(); + move |request| { + if *request.method() == Method::OPTIONS { + let mut response = Response::builder(Status::NO_CONTENT); + if request.header(&origin).is_some() { + response + .headers_mut() + .append(access_control_allow_origin.clone(), star.clone()); + } + if let Some(method) = request.header(&access_control_request_method) { + response + .headers_mut() + .append(access_control_allow_method.clone(), method.clone()); + } + if let Some(headers) = request.header(&access_control_request_headers) { + response + .headers_mut() + .append(access_control_allow_headers.clone(), headers.clone()); + } + response.build() + } else { + let mut response = on_request(request); + if request.header(&origin).is_some() { + response + .headers_mut() + .append(access_control_allow_origin.clone(), star.clone()); + } + response + } + } +} + type HttpError = (Status, String); fn handle_request(