diff --git a/server/src/main.rs b/server/src/main.rs index 3736d7c7..3b78b1bd 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -15,7 +15,9 @@ use async_std::io::Read; use async_std::net::{TcpListener, TcpStream}; use async_std::prelude::*; use async_std::task::{block_on, spawn}; -use http_types::{headers, Body, Error, Method, Mime, Request, Response, Result, StatusCode}; +use http_types::{ + bail_status, headers, Body, Error, Method, Mime, Request, Response, Result, StatusCode, +}; use oxigraph::io::{DatasetFormat, GraphFormat}; use oxigraph::model::{GraphName, NamedNode, NamedOrBlankNode}; use oxigraph::sparql::algebra::GraphUpdateOperation; @@ -99,12 +101,7 @@ async fn handle_request(request: Request, store: Store) -> Result { } } ("/query", Method::Get) => { - evaluate_urlencoded_sparql_query( - store, - request.url().query().unwrap_or("").as_bytes().to_vec(), - request, - ) - .await? + configure_and_evaluate_sparql_query(store, url_query(&request), None, request)? } ("/query", Method::Post) => { if let Some(content_type) = request.content_type() { @@ -116,7 +113,12 @@ async fn handle_request(request: Request, store: Store) -> Result { .take(MAX_SPARQL_BODY_SIZE) .read_to_string(&mut buffer) .await?; - evaluate_sparql_query(store, buffer, Vec::new(), Vec::new(), request).await? + configure_and_evaluate_sparql_query( + store, + url_query(&request), + Some(buffer), + request, + )? } else if content_type.essence() == "application/x-www-form-urlencoded" { let mut buffer = Vec::new(); let mut request = request; @@ -125,7 +127,7 @@ async fn handle_request(request: Request, store: Store) -> Result { .take(MAX_SPARQL_BODY_SIZE) .read_to_end(&mut buffer) .await?; - evaluate_urlencoded_sparql_query(store, buffer, request).await? + configure_and_evaluate_sparql_query(store, buffer, None, request)? } else { simple_response( StatusCode::UnsupportedMediaType, @@ -146,7 +148,12 @@ async fn handle_request(request: Request, store: Store) -> Result { .take(MAX_SPARQL_BODY_SIZE) .read_to_string(&mut buffer) .await?; - evaluate_sparql_update(store, buffer, Vec::new(), Vec::new()).await? + configure_and_evaluate_sparql_update( + store, + url_query(&request), + Some(buffer), + request, + )? } else if content_type.essence() == "application/x-www-form-urlencoded" { let mut buffer = Vec::new(); let mut request = request; @@ -155,7 +162,7 @@ async fn handle_request(request: Request, store: Store) -> Result { .take(MAX_SPARQL_BODY_SIZE) .read_to_end(&mut buffer) .await?; - evaluate_urlencoded_sparql_update(store, buffer).await? + configure_and_evaluate_sparql_update(store, buffer, None, request)? } else { simple_response( StatusCode::UnsupportedMediaType, @@ -178,17 +185,31 @@ fn simple_response(status: StatusCode, body: impl Into) -> Response { response } -async fn evaluate_urlencoded_sparql_query( +fn base_url(request: &Request) -> &str { + let url = request.url().as_str(); + url.split('?').next().unwrap_or(url) +} + +fn url_query(request: &Request) -> Vec { + request.url().query().unwrap_or("").as_bytes().to_vec() +} + +fn configure_and_evaluate_sparql_query( store: Store, encoded: Vec, + mut query: Option, request: Request, ) -> Result { - let mut query = None; let mut default_graph_uris = Vec::new(); let mut named_graph_uris = Vec::new(); for (k, v) in form_urlencoded::parse(&encoded) { match k.as_ref() { - "query" => query = Some(v.into_owned()), + "query" => { + if query.is_some() { + bail_status!(400, "Multiple query parameters provided") + } + query = Some(v.into_owned()) + } "default-graph-uri" => default_graph_uris.push(v.into_owned()), "named-graph-uri" => named_graph_uris.push(v.into_owned()), _ => { @@ -200,7 +221,7 @@ async fn evaluate_urlencoded_sparql_query( } } if let Some(query) = query { - evaluate_sparql_query(store, query, default_graph_uris, named_graph_uris, request).await + evaluate_sparql_query(store, query, default_graph_uris, named_graph_uris, request) } else { Ok(simple_response( StatusCode::BadRequest, @@ -209,14 +230,14 @@ async fn evaluate_urlencoded_sparql_query( } } -async fn evaluate_sparql_query( +fn evaluate_sparql_query( store: Store, query: String, default_graph_uris: Vec, named_graph_uris: Vec, request: Request, ) -> Result { - let mut query = Query::parse(&query, None).map_err(bad_request)?; + let mut query = Query::parse(&query, Some(base_url(&request))).map_err(bad_request)?; let default_graph_uris = default_graph_uris .into_iter() .map(|e| Ok(NamedNode::new(e)?.into())) @@ -271,13 +292,22 @@ async fn evaluate_sparql_query( } } -async fn evaluate_urlencoded_sparql_update(store: Store, encoded: Vec) -> Result { - let mut update = None; +fn configure_and_evaluate_sparql_update( + store: Store, + encoded: Vec, + mut update: Option, + request: Request, +) -> Result { let mut default_graph_uris = Vec::new(); let mut named_graph_uris = Vec::new(); for (k, v) in form_urlencoded::parse(&encoded) { match k.as_ref() { - "update" => update = Some(v.into_owned()), + "update" => { + if update.is_some() { + bail_status!(400, "Multiple update parameters provided") + } + update = Some(v.into_owned()) + } "using-graph-uri" => default_graph_uris.push(v.into_owned()), "using-named-graph-uri" => named_graph_uris.push(v.into_owned()), _ => { @@ -289,7 +319,7 @@ async fn evaluate_urlencoded_sparql_update(store: Store, encoded: Vec) -> Re } } if let Some(update) = update { - evaluate_sparql_update(store, update, default_graph_uris, named_graph_uris).await + evaluate_sparql_update(store, update, default_graph_uris, named_graph_uris, request) } else { Ok(simple_response( StatusCode::BadRequest, @@ -298,13 +328,14 @@ async fn evaluate_urlencoded_sparql_update(store: Store, encoded: Vec) -> Re } } -async fn evaluate_sparql_update( +fn evaluate_sparql_update( store: Store, update: String, default_graph_uris: Vec, named_graph_uris: Vec, + request: Request, ) -> Result { - let mut update = Update::parse(&update, None).map_err(|e| { + let mut update = Update::parse(&update, Some(base_url(&request))).map_err(|e| { let mut e = Error::from(e); e.set_status(StatusCode::BadRequest); e