Server: Adds an option to allow CORS

pull/467/head
Tpt 2 years ago committed by Thomas Tanon
parent 86bbebf93c
commit 284e79521d
  1. 80
      server/src/main.rs

@ -1,7 +1,7 @@
use anyhow::{anyhow, bail, Context, Error};
use clap::{Parser, Subcommand};
use flate2::read::MultiGzDecoder;
use oxhttp::model::{Body, HeaderName, HeaderValue, Request, Response, Status};
use oxhttp::model::{Body, HeaderName, HeaderValue, Method, Request, Response, Status};
use oxhttp::Server;
use oxigraph::io::{DatasetFormat, DatasetSerializer, GraphFormat, GraphSerializer};
use oxigraph::model::{
@ -52,6 +52,9 @@ enum Command {
/// Host and port to listen to.
#[arg(short, long, default_value = "localhost:7878")]
bind: String,
/// Allows cross-origin requests
#[arg(long)]
cors: bool,
},
/// Start Oxigraph HTTP server in read-only mode.
///
@ -62,6 +65,9 @@ enum Command {
/// Host and port to listen to.
#[arg(short, long, default_value = "localhost:7878")]
bind: String,
/// Allows cross-origin requests
#[arg(long)]
cors: bool,
},
/// Start Oxigraph HTTP server in secondary mode.
///
@ -82,6 +88,9 @@ enum Command {
/// Host and port to listen to.
#[arg(short, long, default_value = "localhost:7878")]
bind: String,
/// Allows cross-origin requests
#[arg(long)]
cors: bool,
},
/// Creates database backup into a target directory.
///
@ -219,7 +228,7 @@ enum Command {
pub fn main() -> anyhow::Result<()> {
let matches = Args::parse();
match matches.command {
Command::Serve { bind } => serve(
Command::Serve { bind, cors } => serve(
if let Some(location) = matches.location {
Store::open(location)
} else {
@ -227,8 +236,9 @@ pub fn main() -> anyhow::Result<()> {
}?,
bind,
false,
cors,
),
Command::ServeReadOnly { bind } => serve(
Command::ServeReadOnly { bind, cors } => serve(
Store::open_read_only(
matches
.location
@ -236,11 +246,13 @@ pub fn main() -> anyhow::Result<()> {
)?,
bind,
true,
cors,
),
Command::ServeSecondary {
primary_location,
secondary_location,
bind,
cors,
} => {
let primary_location = primary_location.or(matches.location).ok_or_else(|| {
anyhow!("Either the --location or the --primary-location argument is required")
@ -253,6 +265,7 @@ pub fn main() -> anyhow::Result<()> {
}?,
bind,
true,
cors,
)
}
Command::Backup { destination } => {
@ -745,11 +758,18 @@ impl FromStr for GraphOrDatasetFormat {
}
}
fn serve(store: Store, bind: String, read_only: bool) -> anyhow::Result<()> {
let mut server = Server::new(move |request| {
handle_request(request, store.clone(), read_only)
.unwrap_or_else(|(status, message)| error(status, message))
});
fn serve(store: Store, bind: String, read_only: bool, cors: bool) -> anyhow::Result<()> {
let mut server = if cors {
Server::new(cors_middleware(move |request| {
handle_request(request, store.clone(), read_only)
.unwrap_or_else(|(status, message)| error(status, message))
}))
} else {
Server::new(move |request| {
handle_request(request, store.clone(), read_only)
.unwrap_or_else(|(status, message)| error(status, message))
})
};
server.set_global_timeout(HTTP_TIMEOUT);
server.set_server_name(concat!("Oxigraph/", env!("CARGO_PKG_VERSION")))?;
eprintln!("Listening for requests at http://{}", &bind);
@ -757,6 +777,50 @@ fn serve(store: Store, bind: String, read_only: bool) -> anyhow::Result<()> {
Ok(())
}
fn cors_middleware(
on_request: impl Fn(&mut Request) -> Response + Send + Sync + 'static,
) -> impl Fn(&mut Request) -> Response + Send + Sync + 'static {
let origin = HeaderName::from_str("Origin").unwrap();
let access_control_allow_origin = HeaderName::from_str("Access-Control-Allow-Origin").unwrap();
let access_control_request_method =
HeaderName::from_str("Access-Control-Request-Method").unwrap();
let access_control_allow_method = HeaderName::from_str("Access-Control-Allow-Methods").unwrap();
let access_control_request_headers =
HeaderName::from_str("Access-Control-Request-Headers").unwrap();
let access_control_allow_headers =
HeaderName::from_str("Access-Control-Allow-Headers").unwrap();
let star = HeaderValue::from_str("*").unwrap();
move |request| {
if *request.method() == Method::OPTIONS {
let mut response = Response::builder(Status::NO_CONTENT);
if request.header(&origin).is_some() {
response
.headers_mut()
.append(access_control_allow_origin.clone(), star.clone());
}
if let Some(method) = request.header(&access_control_request_method) {
response
.headers_mut()
.append(access_control_allow_method.clone(), method.clone());
}
if let Some(headers) = request.header(&access_control_request_headers) {
response
.headers_mut()
.append(access_control_allow_headers.clone(), headers.clone());
}
response.build()
} else {
let mut response = on_request(request);
if request.header(&origin).is_some() {
response
.headers_mut()
.append(access_control_allow_origin.clone(), star.clone());
}
response
}
}
}
type HttpError = (Status, String);
fn handle_request(

Loading…
Cancel
Save