From 6930a8452141225828a45a5c071df59d68c2e871 Mon Sep 17 00:00:00 2001 From: Tpt Date: Wed, 20 May 2020 12:59:25 +0200 Subject: [PATCH] Uses async-h1 in oxigraph_wikibase --- server/Cargo.toml | 4 +- wikibase/Cargo.toml | 11 +- wikibase/src/loader.rs | 85 ++++++----- wikibase/src/main.rs | 318 ++++++++++++++++++++++++++--------------- 4 files changed, 262 insertions(+), 156 deletions(-) diff --git a/server/Cargo.toml b/server/Cargo.toml index 46dc339b..635664f3 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -11,9 +11,9 @@ SPARQL server based on Oxigraph edition = "2018" [dependencies] -oxigraph = { path = "../lib", features = ["rocksdb"] } argh = "0.1" async-std = { version = "1", features = ["attributes"] } async-h1 = "1" http-types = "1" -url = "2" \ No newline at end of file +oxigraph = { path = "../lib", features = ["rocksdb"] } +url = "2" diff --git a/wikibase/Cargo.toml b/wikibase/Cargo.toml index a4dc7190..7d1fbebc 100644 --- a/wikibase/Cargo.toml +++ b/wikibase/Cargo.toml @@ -11,9 +11,12 @@ SPARQL server based on Oxigraph for Wikibase instances edition = "2018" [dependencies] -oxigraph = {path = "../lib", features = ["rocksdb"] } argh = "0.1" -rouille = "3" -reqwest = "0.9" +async-std = { version = "1", features = ["attributes"] } +async-h1 = "1" +chrono = "0.4" +http-client = { version = "2.0", features = ["h1_client"] } +http-types = "1" +oxigraph = { path = "../lib", features = ["rocksdb"] } serde_json = "1" -chrono = "0.4" \ No newline at end of file +url = "2" \ No newline at end of file diff --git a/wikibase/src/loader.rs b/wikibase/src/loader.rs index 6b4f360b..130c7329 100644 --- a/wikibase/src/loader.rs +++ b/wikibase/src/loader.rs @@ -1,20 +1,24 @@ use crate::SERVER; +use async_std::prelude::*; +use async_std::task::block_on; use chrono::{DateTime, Datelike, Utc}; +use http_client::h1::H1Client; +use http_client::HttpClient; +use http_types::{Method, Request, Result}; use oxigraph::model::NamedNode; -use oxigraph::*; -use reqwest::header::USER_AGENT; -use reqwest::{Client, Url}; +use oxigraph::{GraphSyntax, Repository, RepositoryConnection, RepositoryTransaction}; use serde_json::Value; use std::collections::{HashMap, HashSet}; -use std::io::{BufReader, Read}; +use std::io::{BufReader, Cursor, Read}; use std::thread::sleep; use std::time::Duration; +use url::{form_urlencoded, Url}; pub struct WikibaseLoader { repository: R, api_url: Url, entity_data_url: Url, - client: Client, + client: H1Client, namespaces: Vec, slot: Option, frequency: Duration, @@ -32,10 +36,9 @@ impl WikibaseLoader { ) -> Result { Ok(Self { repository, - api_url: Url::parse(api_url).map_err(Error::wrap)?, - entity_data_url: Url::parse(&(pages_base_url.to_owned() + "Special:EntityData")) - .map_err(Error::wrap)?, - client: Client::new(), + api_url: Url::parse(api_url)?, + entity_data_url: Url::parse(&(pages_base_url.to_owned() + "Special:EntityData"))?, + client: H1Client::new(), namespaces: namespaces.to_vec(), slot: slot.map(|t| t.to_owned()), start: Utc::now(), @@ -59,6 +62,7 @@ impl WikibaseLoader { parameters.insert("action".to_owned(), "query".to_owned()); parameters.insert("list".to_owned(), "allpages".to_owned()); parameters.insert("apnamespace".to_owned(), namespace.to_string()); + parameters.insert("aplimit".to_owned(), "50".to_owned()); self.api_get_with_continue(parameters, |results| { println!("*"); @@ -81,7 +85,7 @@ impl WikibaseLoader { Ok(data) => { self.load_entity_data( &(self.entity_data_url.to_string() + "/" + id), - data, + Cursor::new(data), )?; } Err(e) => eprintln!("Error while retrieving data for entity {}: {}", id, e), @@ -127,7 +131,7 @@ impl WikibaseLoader { } parameters.insert("rcend".to_owned(), start.to_rfc2822()); parameters.insert("rcprop".to_owned(), "title|ids".to_owned()); - parameters.insert("limit".to_owned(), "50".to_owned()); + parameters.insert("rclimit".to_owned(), "50".to_owned()); self.api_get_with_continue(parameters, |results| { for change in results @@ -155,7 +159,10 @@ impl WikibaseLoader { match self.get_entity_data(&id) { Ok(data) => { - self.load_entity_data(&format!("{}/{}", self.entity_data_url, id), data)?; + self.load_entity_data( + &format!("{}/{}", self.entity_data_url, id), + Cursor::new(data), + )?; } Err(e) => eprintln!("Error while retrieving data for entity {}: {}", id, e), } @@ -186,29 +193,39 @@ impl WikibaseLoader { fn api_get(&self, parameters: &mut HashMap) -> Result { parameters.insert("format".to_owned(), "json".to_owned()); - Ok(self - .client - .get(self.api_url.clone()) - .query(parameters) - .header(USER_AGENT, SERVER) - .send() - .map_err(Error::wrap)? - .error_for_status() - .map_err(Error::wrap)? - .json() - .map_err(Error::wrap)?) + Ok(serde_json::from_slice( + &self.get_request(&self.api_url, parameters)?, + )?) } - fn get_entity_data(&self, id: &str) -> Result { - Ok(self - .client - .get(self.entity_data_url.clone()) - .query(&[("id", id), ("format", "nt"), ("flavor", "dump")]) - .header(USER_AGENT, SERVER) - .send() - .map_err(Error::wrap)? - .error_for_status() - .map_err(Error::wrap)?) + fn get_entity_data(&self, id: &str) -> Result> { + Ok(self.get_request( + &self.entity_data_url, + [("id", id), ("format", "nt"), ("flavor", "dump")] + .iter() + .cloned(), + )?) + } + + fn get_request, V: AsRef>( + &self, + url: &Url, + params: impl IntoIterator, + ) -> Result> { + let mut query_serializer = form_urlencoded::Serializer::new(String::new()); + for (k, v) in params { + query_serializer.append_pair(k.as_ref(), v.as_ref()); + } + let url = url.join(&("?".to_owned() + &query_serializer.finish()))?; + let mut request = Request::new(Method::Get, url); + request.append_header("user-agent", SERVER)?; + let response = self.client.send(request); + block_on(async { + let mut response = response.await?; + let mut buffer = Vec::new(); + response.read_to_end(&mut buffer).await?; + Ok(buffer) + }) } fn load_entity_data(&self, uri: &str, data: impl Read) -> Result<()> { @@ -217,7 +234,7 @@ impl WikibaseLoader { connection.transaction(|transaction| { let to_remove = connection .quads_for_pattern(None, None, None, Some(Some(&graph_name))) - .collect::>>()?; + .collect::>>()?; for q in to_remove { transaction.remove(&q)?; } diff --git a/wikibase/src/main.rs b/wikibase/src/main.rs index 9a8f70cd..5998d2c0 100644 --- a/wikibase/src/main.rs +++ b/wikibase/src/main.rs @@ -11,18 +11,20 @@ use crate::loader::WikibaseLoader; use argh::FromArgs; +use async_std::future::Future; +use async_std::net::{TcpListener, TcpStream}; +use async_std::prelude::*; +use async_std::sync::Arc; +use async_std::task::{spawn, spawn_blocking}; +use http_types::headers::HeaderName; +use http_types::{headers, Body, Error, Method, Mime, Request, Response, Result, StatusCode}; use oxigraph::sparql::{PreparedQuery, QueryOptions, QueryResult, QueryResultSyntax}; use oxigraph::{ FileSyntax, GraphSyntax, MemoryRepository, Repository, RepositoryConnection, RocksDbRepository, }; -use rouille::input::priority_header_preferred; -use rouille::url::form_urlencoded; -use rouille::{content_encoding, start_server, Request, Response}; -use std::io::Read; use std::str::FromStr; -use std::sync::Arc; -use std::thread; use std::time::Duration; +use url::form_urlencoded; mod loader; @@ -57,23 +59,22 @@ struct Args { slot: Option, } -pub fn main() { +#[async_std::main] +pub async fn main() -> Result<()> { let args: Args = argh::from_env(); let file = args.file.clone(); if let Some(file) = file { - main_with_dataset(Arc::new(RocksDbRepository::open(file).unwrap()), args) + main_with_dataset(Arc::new(RocksDbRepository::open(file)?), args).await } else { - main_with_dataset(Arc::new(MemoryRepository::default()), args) + main_with_dataset(Arc::new(MemoryRepository::default()), args).await } } -fn main_with_dataset(repository: Arc, args: Args) +async fn main_with_dataset(repository: Arc, args: Args) -> Result<()> where for<'a> &'a R: Repository, { - println!("Listening for requests at http://{}", &args.bind); - let repo = repository.clone(); let mediawiki_api = args.mediawiki_api.clone(); let mediawiki_base_url = args.mediawiki_base_url.clone(); @@ -92,7 +93,7 @@ where }) .collect::>(); let slot = args.slot.clone(); - thread::spawn(move || { + spawn_blocking(move || { let mut loader = WikibaseLoader::new( repo.as_ref(), &mediawiki_api, @@ -106,131 +107,216 @@ where loader.update_loop(); }); - start_server(args.bind, move |request| { - content_encoding::apply( - request, - handle_request(request, repository.connection().unwrap()), - ) - .with_unique_header("Server", SERVER) + println!("Listening for requests at http://{}", &args.bind); + + http_server(args.bind, move |request| { + handle_request(request, Arc::clone(&repository)) }) + .await } -fn handle_request(request: &Request, connection: R) -> Response { - match (request.url().as_str(), request.method()) { - ("/query", "GET") => evaluate_urlencoded_sparql_query( - connection, - request.raw_query_string().as_bytes(), - request, - ), - ("/query", "POST") => { - if let Some(body) = request.data() { - if let Some(content_type) = request.header("Content-Type") { - if content_type.starts_with("application/sparql-query") { - let mut buffer = String::default(); - body.take(MAX_SPARQL_BODY_SIZE) - .read_to_string(&mut buffer) - .unwrap(); - evaluate_sparql_query(connection, &buffer, request) - } else if content_type.starts_with("application/x-www-form-urlencoded") { - let mut buffer = Vec::default(); - body.take(MAX_SPARQL_BODY_SIZE) - .read_to_end(&mut buffer) - .unwrap(); - evaluate_urlencoded_sparql_query(connection, &buffer, request) - } else { - Response::text(format!( - "No supported content Content-Type given: {}", - content_type - )) - .with_status_code(415) - } +async fn handle_request( + request: Request, + repository: Arc, +) -> Result +where + for<'a> &'a R: Repository, +{ + let mut response = match (request.url().path(), request.method()) { + ("/query", Method::Get) => { + evaluate_urlencoded_sparql_query( + repository, + request.url().query().unwrap_or("").as_bytes().to_vec(), + request, + ) + .await? + } + ("/query", Method::Post) => { + if let Some(content_type) = request.content_type() { + if essence(&content_type) == "application/sparql-query" { + let mut buffer = String::new(); + let mut request = request; + request + .take_body() + .take(MAX_SPARQL_BODY_SIZE) + .read_to_string(&mut buffer) + .await?; + evaluate_sparql_query(repository, buffer, request).await? + } else if essence(&content_type) == "application/x-www-form-urlencoded" { + let mut buffer = Vec::new(); + let mut request = request; + request + .take_body() + .take(MAX_SPARQL_BODY_SIZE) + .read_to_end(&mut buffer) + .await?; + evaluate_urlencoded_sparql_query(repository, buffer, request).await? } else { - Response::text("No Content-Type given").with_status_code(400) + simple_response( + StatusCode::UnsupportedMediaType, + format!("No supported Content-Type given: {}", content_type), + ) } } else { - Response::text("No content given").with_status_code(400) + simple_response(StatusCode::BadRequest, "No Content-Type given") } } - _ => Response::empty_404(), - } + _ => Response::new(StatusCode::NotFound), + }; + response.append_header("Server", SERVER)?; + Ok(response) } -fn evaluate_urlencoded_sparql_query( - connection: R, - encoded: &[u8], - request: &Request, -) -> Response { - if let Some((_, query)) = form_urlencoded::parse(encoded).find(|(k, _)| k == "query") { - evaluate_sparql_query(connection, &query, request) +/// TODO: bad hack to overcome http_types limitations +fn essence(mime: &Mime) -> &str { + mime.essence().split(';').next().unwrap_or("") +} + +fn simple_response(status: StatusCode, body: impl Into) -> Response { + let mut response = Response::new(status); + response.set_body(body); + response +} + +async fn evaluate_urlencoded_sparql_query( + repository: Arc, + encoded: Vec, + request: Request, +) -> Result +where + for<'a> &'a R: Repository, +{ + if let Some((_, query)) = form_urlencoded::parse(&encoded).find(|(k, _)| k == "query") { + evaluate_sparql_query(repository, query.to_string(), request).await } else { - Response::text("You should set the 'query' parameter").with_status_code(400) + Ok(simple_response( + StatusCode::BadRequest, + "You should set the 'query' parameter", + )) } } -fn evaluate_sparql_query( - connection: R, - query: &str, - request: &Request, -) -> Response { - //TODO: stream - match connection.prepare_query(query, QueryOptions::default().with_default_graph_as_union()) { - Ok(query) => { - let results = query.exec().unwrap(); - if let QueryResult::Graph(_) = results { - let supported_formats = [ +async fn evaluate_sparql_query( + repository: Arc, + query: String, + request: Request, +) -> Result +where + for<'a> &'a R: Repository, +{ + spawn_blocking(move || { + //TODO: stream + let query = repository + .connection()? + .prepare_query(&query, QueryOptions::default()) + .map_err(|e| { + let mut e = Error::from(e); + e.set_status(StatusCode::BadRequest); + e + })?; + let results = query.exec()?; + if let QueryResult::Graph(_) = results { + let format = content_negotiation( + request, + &[ GraphSyntax::NTriples.media_type(), GraphSyntax::Turtle.media_type(), GraphSyntax::RdfXml.media_type(), - ]; - let format = if let Some(accept) = request.header("Accept") { - if let Some(media_type) = - priority_header_preferred(accept, supported_formats.iter().cloned()) - .and_then(|p| GraphSyntax::from_mime_type(supported_formats[p])) - { - media_type - } else { - return Response::text(format!( - "No supported Accept given: {}. Supported format: {:?}", - accept, supported_formats - )) - .with_status_code(415); - } - } else { - GraphSyntax::NTriples - }; + ], + )?; - Response::from_data( - format.media_type(), - results.write_graph(Vec::default(), format).unwrap(), - ) - } else { - let supported_formats = [ + let mut response = Response::from(results.write_graph(Vec::default(), format)?); + response.insert_header(headers::CONTENT_TYPE, format.media_type())?; + Ok(response) + } else { + let format = content_negotiation( + request, + &[ QueryResultSyntax::Xml.media_type(), QueryResultSyntax::Json.media_type(), - ]; - let format = if let Some(accept) = request.header("Accept") { - if let Some(media_type) = - priority_header_preferred(accept, supported_formats.iter().cloned()) - .and_then(|p| QueryResultSyntax::from_mime_type(supported_formats[p])) - { - media_type - } else { - return Response::text(format!( - "No supported Accept given: {}. Supported format: {:?}", - accept, supported_formats - )) - .with_status_code(415); - } - } else { - QueryResultSyntax::Json - }; + ], + )?; + let mut response = Response::from(results.write(Vec::default(), format)?); + response.insert_header(headers::CONTENT_TYPE, format.media_type())?; + Ok(response) + } + }) + .await +} + +async fn http_server< + F: Clone + Send + Sync + 'static + Fn(Request) -> Fut, + Fut: Send + Future>, +>( + host: String, + handle: F, +) -> Result<()> { + async fn accept Fut, Fut: Future>>( + addr: String, + stream: TcpStream, + handle: F, + ) -> Result<()> { + async_h1::accept(&addr, stream, |request| async { + Ok(match handle(request).await { + Ok(result) => result, + Err(error) => simple_response(error.status(), error.to_string()), + }) + }) + .await + } - Response::from_data( - format.media_type(), - results.write(Vec::default(), format).unwrap(), - ) + let listener = TcpListener::bind(&host).await?; + let mut incoming = listener.incoming(); + while let Some(stream) = incoming.next().await { + let stream = stream?.clone(); //TODO: clone stream? + let handle = handle.clone(); + let addr = format!("http://{}", host); + spawn(async { + if let Err(err) = accept(addr, stream, handle).await { + eprintln!("{}", err); + }; + }); + } + Ok(()) +} + +fn content_negotiation(request: Request, supported: &[&str]) -> Result { + let header = request + .header(&HeaderName::from_str("Accept").unwrap()) + .and_then(|h| h.last()) + .map(|h| h.as_str().trim()) + .unwrap_or(""); + let supported: Vec = supported + .iter() + .map(|h| Mime::from_str(h).unwrap()) + .collect(); + + let mut result = supported.first().unwrap(); + let mut result_score = 0f32; + + if !header.is_empty() { + for possible in header.split(',') { + let possible = Mime::from_str(possible.trim())?; + let score = if let Some(q) = possible.param("q") { + f32::from_str(q)? + } else { + 1. + }; + if score <= result_score { + continue; + } + for candidate in &supported { + if (possible.basetype() == candidate.basetype() || possible.basetype() == "*") + && (possible.subtype() == candidate.subtype() || possible.subtype() == "*") + { + result = candidate; + result_score = score; + break; + } } } - Err(error) => Response::text(error.to_string()).with_status_code(400), } + + F::from_mime_type(essence(result)) + .ok_or_else(|| Error::from_str(StatusCode::InternalServerError, "Unknown mime type")) }