Server: Adds an option to allow CORS

pull/467/head
Tpt 2 years ago committed by Thomas Tanon
parent 86bbebf93c
commit 284e79521d
  1. 76
      server/src/main.rs

@ -1,7 +1,7 @@
use anyhow::{anyhow, bail, Context, Error}; use anyhow::{anyhow, bail, Context, Error};
use clap::{Parser, Subcommand}; use clap::{Parser, Subcommand};
use flate2::read::MultiGzDecoder; use flate2::read::MultiGzDecoder;
use oxhttp::model::{Body, HeaderName, HeaderValue, Request, Response, Status}; use oxhttp::model::{Body, HeaderName, HeaderValue, Method, Request, Response, Status};
use oxhttp::Server; use oxhttp::Server;
use oxigraph::io::{DatasetFormat, DatasetSerializer, GraphFormat, GraphSerializer}; use oxigraph::io::{DatasetFormat, DatasetSerializer, GraphFormat, GraphSerializer};
use oxigraph::model::{ use oxigraph::model::{
@ -52,6 +52,9 @@ enum Command {
/// Host and port to listen to. /// Host and port to listen to.
#[arg(short, long, default_value = "localhost:7878")] #[arg(short, long, default_value = "localhost:7878")]
bind: String, bind: String,
/// Allows cross-origin requests
#[arg(long)]
cors: bool,
}, },
/// Start Oxigraph HTTP server in read-only mode. /// Start Oxigraph HTTP server in read-only mode.
/// ///
@ -62,6 +65,9 @@ enum Command {
/// Host and port to listen to. /// Host and port to listen to.
#[arg(short, long, default_value = "localhost:7878")] #[arg(short, long, default_value = "localhost:7878")]
bind: String, bind: String,
/// Allows cross-origin requests
#[arg(long)]
cors: bool,
}, },
/// Start Oxigraph HTTP server in secondary mode. /// Start Oxigraph HTTP server in secondary mode.
/// ///
@ -82,6 +88,9 @@ enum Command {
/// Host and port to listen to. /// Host and port to listen to.
#[arg(short, long, default_value = "localhost:7878")] #[arg(short, long, default_value = "localhost:7878")]
bind: String, bind: String,
/// Allows cross-origin requests
#[arg(long)]
cors: bool,
}, },
/// Creates database backup into a target directory. /// Creates database backup into a target directory.
/// ///
@ -219,7 +228,7 @@ enum Command {
pub fn main() -> anyhow::Result<()> { pub fn main() -> anyhow::Result<()> {
let matches = Args::parse(); let matches = Args::parse();
match matches.command { match matches.command {
Command::Serve { bind } => serve( Command::Serve { bind, cors } => serve(
if let Some(location) = matches.location { if let Some(location) = matches.location {
Store::open(location) Store::open(location)
} else { } else {
@ -227,8 +236,9 @@ pub fn main() -> anyhow::Result<()> {
}?, }?,
bind, bind,
false, false,
cors,
), ),
Command::ServeReadOnly { bind } => serve( Command::ServeReadOnly { bind, cors } => serve(
Store::open_read_only( Store::open_read_only(
matches matches
.location .location
@ -236,11 +246,13 @@ pub fn main() -> anyhow::Result<()> {
)?, )?,
bind, bind,
true, true,
cors,
), ),
Command::ServeSecondary { Command::ServeSecondary {
primary_location, primary_location,
secondary_location, secondary_location,
bind, bind,
cors,
} => { } => {
let primary_location = primary_location.or(matches.location).ok_or_else(|| { let primary_location = primary_location.or(matches.location).ok_or_else(|| {
anyhow!("Either the --location or the --primary-location argument is required") anyhow!("Either the --location or the --primary-location argument is required")
@ -253,6 +265,7 @@ pub fn main() -> anyhow::Result<()> {
}?, }?,
bind, bind,
true, true,
cors,
) )
} }
Command::Backup { destination } => { Command::Backup { destination } => {
@ -745,11 +758,18 @@ impl FromStr for GraphOrDatasetFormat {
} }
} }
fn serve(store: Store, bind: String, read_only: bool) -> anyhow::Result<()> { fn serve(store: Store, bind: String, read_only: bool, cors: bool) -> anyhow::Result<()> {
let mut server = Server::new(move |request| { let mut server = if cors {
Server::new(cors_middleware(move |request| {
handle_request(request, store.clone(), read_only) handle_request(request, store.clone(), read_only)
.unwrap_or_else(|(status, message)| error(status, message)) .unwrap_or_else(|(status, message)| error(status, message))
}); }))
} else {
Server::new(move |request| {
handle_request(request, store.clone(), read_only)
.unwrap_or_else(|(status, message)| error(status, message))
})
};
server.set_global_timeout(HTTP_TIMEOUT); server.set_global_timeout(HTTP_TIMEOUT);
server.set_server_name(concat!("Oxigraph/", env!("CARGO_PKG_VERSION")))?; server.set_server_name(concat!("Oxigraph/", env!("CARGO_PKG_VERSION")))?;
eprintln!("Listening for requests at http://{}", &bind); eprintln!("Listening for requests at http://{}", &bind);
@ -757,6 +777,50 @@ fn serve(store: Store, bind: String, read_only: bool) -> anyhow::Result<()> {
Ok(()) Ok(())
} }
fn cors_middleware(
on_request: impl Fn(&mut Request) -> Response + Send + Sync + 'static,
) -> impl Fn(&mut Request) -> Response + Send + Sync + 'static {
let origin = HeaderName::from_str("Origin").unwrap();
let access_control_allow_origin = HeaderName::from_str("Access-Control-Allow-Origin").unwrap();
let access_control_request_method =
HeaderName::from_str("Access-Control-Request-Method").unwrap();
let access_control_allow_method = HeaderName::from_str("Access-Control-Allow-Methods").unwrap();
let access_control_request_headers =
HeaderName::from_str("Access-Control-Request-Headers").unwrap();
let access_control_allow_headers =
HeaderName::from_str("Access-Control-Allow-Headers").unwrap();
let star = HeaderValue::from_str("*").unwrap();
move |request| {
if *request.method() == Method::OPTIONS {
let mut response = Response::builder(Status::NO_CONTENT);
if request.header(&origin).is_some() {
response
.headers_mut()
.append(access_control_allow_origin.clone(), star.clone());
}
if let Some(method) = request.header(&access_control_request_method) {
response
.headers_mut()
.append(access_control_allow_method.clone(), method.clone());
}
if let Some(headers) = request.header(&access_control_request_headers) {
response
.headers_mut()
.append(access_control_allow_headers.clone(), headers.clone());
}
response.build()
} else {
let mut response = on_request(request);
if request.header(&origin).is_some() {
response
.headers_mut()
.append(access_control_allow_origin.clone(), star.clone());
}
response
}
}
}
type HttpError = (Status, String); type HttpError = (Status, String);
fn handle_request( fn handle_request(

Loading…
Cancel
Save