From 91bcc04245aa44811af1403407a35981b825492f Mon Sep 17 00:00:00 2001 From: Tpt Date: Thu, 31 Dec 2020 22:25:05 +0100 Subject: [PATCH] Adds custom aggregate functions to SPARQL parser and algebra --- lib/src/sparql/algebra.rs | 68 ++++++++++++++++++++++++---------- lib/src/sparql/parser.rs | 44 +++++++++++----------- lib/src/sparql/plan_builder.rs | 39 ++++++++++--------- 3 files changed, 92 insertions(+), 59 deletions(-) diff --git a/lib/src/sparql/algebra.rs b/lib/src/sparql/algebra.rs index 39a7e2b3..ef4c3251 100644 --- a/lib/src/sparql/algebra.rs +++ b/lib/src/sparql/algebra.rs @@ -932,7 +932,7 @@ pub enum GraphPattern { Group { inner: Box, by: Vec, - aggregates: Vec<(Variable, SetFunction)>, + aggregates: Vec<(Variable, AggregationFunction)>, }, /// [Service](https://www.w3.org/TR/sparql11-federated-query/#defn_evalService) Service { @@ -1253,7 +1253,7 @@ impl<'a> fmt::Display for SparqlGraphPattern<'a> { "{{ SELECT {} WHERE {{ {} }} GROUP BY {} }}", aggregates .iter() - .map(|(v, a)| format!("({} AS {})", SparqlAggregation(a), v)) + .map(|(v, a)| format!("({} AS {})", SparqlAggregationFunction(a), v)) .chain(by.iter().map(|e| e.to_string())) .collect::>() .join(" "), @@ -1369,7 +1369,7 @@ fn build_sparql_select_arguments(args: &[Variable]) -> String { /// A set function used in aggregates (c.f. [`GraphPattern::Group`]) #[derive(Eq, PartialEq, Debug, Clone, Hash)] -pub enum SetFunction { +pub enum AggregationFunction { /// [Count](https://www.w3.org/TR/sparql11-query/#defn_aggCount) Count { expr: Option>, @@ -1406,12 +1406,18 @@ pub enum SetFunction { expr: Box, distinct: bool, }, + /// Custom function + Custom { + name: NamedNode, + expr: Box, + distinct: bool, + }, } -impl fmt::Display for SetFunction { +impl fmt::Display for AggregationFunction { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - SetFunction::Count { expr, distinct } => { + AggregationFunction::Count { expr, distinct } => { if *distinct { if let Some(expr) = expr { write!(f, "(count distinct {})", expr) @@ -1424,42 +1430,42 @@ impl fmt::Display for SetFunction { write!(f, "(count)") } } - SetFunction::Sum { expr, distinct } => { + AggregationFunction::Sum { expr, distinct } => { if *distinct { write!(f, "(sum distinct {})", expr) } else { write!(f, "(sum {})", expr) } } - SetFunction::Avg { expr, distinct } => { + AggregationFunction::Avg { expr, distinct } => { if *distinct { write!(f, "(avg distinct {})", expr) } else { write!(f, "(avg {})", expr) } } - SetFunction::Min { expr, distinct } => { + AggregationFunction::Min { expr, distinct } => { if *distinct { write!(f, "(min distinct {})", expr) } else { write!(f, "(min {})", expr) } } - SetFunction::Max { expr, distinct } => { + AggregationFunction::Max { expr, distinct } => { if *distinct { write!(f, "(max distinct {})", expr) } else { write!(f, "(max {})", expr) } } - SetFunction::Sample { expr, distinct } => { + AggregationFunction::Sample { expr, distinct } => { if *distinct { write!(f, "(sample distinct {})", expr) } else { write!(f, "(sample {})", expr) } } - SetFunction::GroupConcat { + AggregationFunction::GroupConcat { expr, distinct, separator, @@ -1476,16 +1482,27 @@ impl fmt::Display for SetFunction { write!(f, "(group_concat {})", expr) } } + AggregationFunction::Custom { + name, + expr, + distinct, + } => { + if *distinct { + write!(f, "({} distinct {})", name, expr) + } else { + write!(f, "({} {})", name, expr) + } + } } } } -struct SparqlAggregation<'a>(&'a SetFunction); +struct SparqlAggregationFunction<'a>(&'a AggregationFunction); -impl<'a> fmt::Display for SparqlAggregation<'a> { +impl<'a> fmt::Display for SparqlAggregationFunction<'a> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self.0 { - SetFunction::Count { expr, distinct } => { + AggregationFunction::Count { expr, distinct } => { if *distinct { if let Some(expr) = expr { write!(f, "COUNT(DISTINCT {})", SparqlExpression(expr)) @@ -1498,42 +1515,42 @@ impl<'a> fmt::Display for SparqlAggregation<'a> { write!(f, "COUNT(*)") } } - SetFunction::Sum { expr, distinct } => { + AggregationFunction::Sum { expr, distinct } => { if *distinct { write!(f, "SUM(DISTINCT {})", SparqlExpression(expr)) } else { write!(f, "SUM({})", SparqlExpression(expr)) } } - SetFunction::Min { expr, distinct } => { + AggregationFunction::Min { expr, distinct } => { if *distinct { write!(f, "MIN(DISTINCT {})", SparqlExpression(expr)) } else { write!(f, "MIN({})", SparqlExpression(expr)) } } - SetFunction::Max { expr, distinct } => { + AggregationFunction::Max { expr, distinct } => { if *distinct { write!(f, "MAX(DISTINCT {})", SparqlExpression(expr)) } else { write!(f, "MAX({})", SparqlExpression(expr)) } } - SetFunction::Avg { expr, distinct } => { + AggregationFunction::Avg { expr, distinct } => { if *distinct { write!(f, "AVG(DISTINCT {})", SparqlExpression(expr)) } else { write!(f, "AVG({})", SparqlExpression(expr)) } } - SetFunction::Sample { expr, distinct } => { + AggregationFunction::Sample { expr, distinct } => { if *distinct { write!(f, "SAMPLE(DISTINCT {})", SparqlExpression(expr)) } else { write!(f, "SAMPLE({})", SparqlExpression(expr)) } } - SetFunction::GroupConcat { + AggregationFunction::GroupConcat { expr, distinct, separator, @@ -1560,6 +1577,17 @@ impl<'a> fmt::Display for SparqlAggregation<'a> { write!(f, "GROUP_CONCAT({})", SparqlExpression(expr)) } } + AggregationFunction::Custom { + name, + expr, + distinct, + } => { + if *distinct { + write!(f, "{}(DISTINCT {})", name, SparqlExpression(expr)) + } else { + write!(f, "{}({})", name, SparqlExpression(expr)) + } + } } } } diff --git a/lib/src/sparql/parser.rs b/lib/src/sparql/parser.rs index 6c54f989..a3f41886 100644 --- a/lib/src/sparql/parser.rs +++ b/lib/src/sparql/parser.rs @@ -487,7 +487,7 @@ pub struct ParserState { namespaces: HashMap, used_bnodes: HashSet, currently_used_bnodes: HashSet, - aggregates: Vec>, + aggregates: Vec>, } impl ParserState { @@ -499,7 +499,7 @@ impl ParserState { } } - fn new_aggregation(&mut self, agg: SetFunction) -> Result { + fn new_aggregation(&mut self, agg: AggregationFunction) -> Result { let aggregates = self.aggregates.last_mut().ok_or("Unexpected aggregate")?; Ok(aggregates .iter() @@ -1825,25 +1825,27 @@ parser! { rule NotExistsFunc() -> Expression = i("NOT") _ i("EXISTS") _ p:GroupGraphPattern() { Expression::Not(Box::new(Expression::Exists(Box::new(p)))) } //[127] - rule Aggregate() -> SetFunction = - i("COUNT") _ "(" _ i("DISTINCT") _ "*" _ ")" { SetFunction::Count { expr: None, distinct: true } } / - i("COUNT") _ "(" _ i("DISTINCT") _ e:Expression() _ ")" { SetFunction::Count { expr: Some(Box::new(e)), distinct: true } } / - i("COUNT") _ "(" _ "*" _ ")" { SetFunction::Count { expr: None, distinct: false } } / - i("COUNT") _ "(" _ e:Expression() _ ")" { SetFunction::Count { expr: Some(Box::new(e)), distinct: false } } / - i("SUM") _ "(" _ i("DISTINCT") _ e:Expression() _ ")" { SetFunction::Sum { expr: Box::new(e), distinct: true } } / - i("SUM") _ "(" _ e:Expression() _ ")" { SetFunction::Sum { expr: Box::new(e), distinct: false } } / - i("MIN") _ "(" _ i("DISTINCT") _ e:Expression() _ ")" { SetFunction::Min { expr: Box::new(e), distinct: true } } / - i("MIN") _ "(" _ e:Expression() _ ")" { SetFunction::Min { expr: Box::new(e), distinct: false } } / - i("MAX") _ "(" _ i("DISTINCT") _ e:Expression() _ ")" { SetFunction::Max { expr: Box::new(e), distinct: true } } / - i("MAX") _ "(" _ e:Expression() _ ")" { SetFunction::Max { expr: Box::new(e), distinct: false } } / - i("AVG") _ "(" _ i("DISTINCT") _ e:Expression() _ ")" { SetFunction::Avg { expr: Box::new(e), distinct: true } } / - i("AVG") _ "(" _ e:Expression() _ ")" { SetFunction::Avg { expr: Box::new(e), distinct: false } } / - i("SAMPLE") _ "(" _ i("DISTINCT") _ e:Expression() _ ")" { SetFunction::Sample { expr: Box::new(e), distinct: true } } / - i("SAMPLE") _ "(" _ e:Expression() _ ")" { SetFunction::Sample { expr: Box::new(e), distinct: false } } / - i("GROUP_CONCAT") _ "(" _ i("DISTINCT") _ e:Expression() _ ";" _ i("SEPARATOR") _ "=" _ s:String() _ ")" { SetFunction::GroupConcat { expr: Box::new(e), distinct: true, separator: Some(s) } } / - i("GROUP_CONCAT") _ "(" _ i("DISTINCT") _ e:Expression() _ ")" { SetFunction::GroupConcat { expr: Box::new(e), distinct: true, separator: None } } / - i("GROUP_CONCAT") _ "(" _ e:Expression() _ ";" _ i("SEPARATOR") _ "=" _ s:String() _ ")" { SetFunction::GroupConcat { expr: Box::new(e), distinct: true, separator: Some(s) } } / - i("GROUP_CONCAT") _ "(" _ e:Expression() _ ")" { SetFunction::GroupConcat { expr: Box::new(e), distinct: false, separator: None } } + rule Aggregate() -> AggregationFunction = + i("COUNT") _ "(" _ i("DISTINCT") _ "*" _ ")" { AggregationFunction::Count { expr: None, distinct: true } } / + i("COUNT") _ "(" _ i("DISTINCT") _ e:Expression() _ ")" { AggregationFunction::Count { expr: Some(Box::new(e)), distinct: true } } / + i("COUNT") _ "(" _ "*" _ ")" { AggregationFunction::Count { expr: None, distinct: false } } / + i("COUNT") _ "(" _ e:Expression() _ ")" { AggregationFunction::Count { expr: Some(Box::new(e)), distinct: false } } / + i("SUM") _ "(" _ i("DISTINCT") _ e:Expression() _ ")" { AggregationFunction::Sum { expr: Box::new(e), distinct: true } } / + i("SUM") _ "(" _ e:Expression() _ ")" { AggregationFunction::Sum { expr: Box::new(e), distinct: false } } / + i("MIN") _ "(" _ i("DISTINCT") _ e:Expression() _ ")" { AggregationFunction::Min { expr: Box::new(e), distinct: true } } / + i("MIN") _ "(" _ e:Expression() _ ")" { AggregationFunction::Min { expr: Box::new(e), distinct: false } } / + i("MAX") _ "(" _ i("DISTINCT") _ e:Expression() _ ")" { AggregationFunction::Max { expr: Box::new(e), distinct: true } } / + i("MAX") _ "(" _ e:Expression() _ ")" { AggregationFunction::Max { expr: Box::new(e), distinct: false } } / + i("AVG") _ "(" _ i("DISTINCT") _ e:Expression() _ ")" { AggregationFunction::Avg { expr: Box::new(e), distinct: true } } / + i("AVG") _ "(" _ e:Expression() _ ")" { AggregationFunction::Avg { expr: Box::new(e), distinct: false } } / + i("SAMPLE") _ "(" _ i("DISTINCT") _ e:Expression() _ ")" { AggregationFunction::Sample { expr: Box::new(e), distinct: true } } / + i("SAMPLE") _ "(" _ e:Expression() _ ")" { AggregationFunction::Sample { expr: Box::new(e), distinct: false } } / + i("GROUP_CONCAT") _ "(" _ i("DISTINCT") _ e:Expression() _ ";" _ i("SEPARATOR") _ "=" _ s:String() _ ")" { AggregationFunction::GroupConcat { expr: Box::new(e), distinct: true, separator: Some(s) } } / + i("GROUP_CONCAT") _ "(" _ i("DISTINCT") _ e:Expression() _ ")" { AggregationFunction::GroupConcat { expr: Box::new(e), distinct: true, separator: None } } / + i("GROUP_CONCAT") _ "(" _ e:Expression() _ ";" _ i("SEPARATOR") _ "=" _ s:String() _ ")" { AggregationFunction::GroupConcat { expr: Box::new(e), distinct: true, separator: Some(s) } } / + i("GROUP_CONCAT") _ "(" _ e:Expression() _ ")" { AggregationFunction::GroupConcat { expr: Box::new(e), distinct: false, separator: None } } / + name:iri() _ "(" _ i("DISTINCT") _ e:Expression() _ ")" { AggregationFunction::Custom { name, expr: Box::new(e), distinct: true } } / + name:iri() _ "(" _ e:Expression() _ ")" { AggregationFunction::Custom { name, expr: Box::new(e), distinct: false } } //[128] rule iriOrFunction() -> Expression = i: iri() _ a: ArgList()? { diff --git a/lib/src/sparql/plan_builder.rs b/lib/src/sparql/plan_builder.rs index 447e5973..c9ace5bb 100644 --- a/lib/src/sparql/plan_builder.rs +++ b/lib/src/sparql/plan_builder.rs @@ -818,56 +818,59 @@ impl> PlanBuilder { fn build_for_aggregate( &mut self, - aggregate: &SetFunction, + aggregate: &AggregationFunction, variables: &mut Vec, graph_name: PatternValue, ) -> Result, EvaluationError> { - Ok(match aggregate { - SetFunction::Count { expr, distinct } => PlanAggregation { + match aggregate { + AggregationFunction::Count { expr, distinct } => Ok(PlanAggregation { function: PlanAggregationFunction::Count, parameter: match expr { Some(expr) => Some(self.build_for_expression(expr, variables, graph_name)?), None => None, }, distinct: *distinct, - }, - SetFunction::Sum { expr, distinct } => PlanAggregation { + }), + AggregationFunction::Sum { expr, distinct } => Ok(PlanAggregation { function: PlanAggregationFunction::Sum, parameter: Some(self.build_for_expression(expr, variables, graph_name)?), distinct: *distinct, - }, - SetFunction::Min { expr, distinct } => PlanAggregation { + }), + AggregationFunction::Min { expr, distinct } => Ok(PlanAggregation { function: PlanAggregationFunction::Min, parameter: Some(self.build_for_expression(expr, variables, graph_name)?), distinct: *distinct, - }, - SetFunction::Max { expr, distinct } => PlanAggregation { + }), + AggregationFunction::Max { expr, distinct } => Ok(PlanAggregation { function: PlanAggregationFunction::Max, parameter: Some(self.build_for_expression(expr, variables, graph_name)?), distinct: *distinct, - }, - SetFunction::Avg { expr, distinct } => PlanAggregation { + }), + AggregationFunction::Avg { expr, distinct } => Ok(PlanAggregation { function: PlanAggregationFunction::Avg, parameter: Some(self.build_for_expression(expr, variables, graph_name)?), distinct: *distinct, - }, - SetFunction::Sample { expr, distinct } => PlanAggregation { + }), + AggregationFunction::Sample { expr, distinct } => Ok(PlanAggregation { function: PlanAggregationFunction::Sample, parameter: Some(self.build_for_expression(expr, variables, graph_name)?), distinct: *distinct, - }, - SetFunction::GroupConcat { + }), + AggregationFunction::GroupConcat { expr, distinct, separator, - } => PlanAggregation { + } => Ok(PlanAggregation { function: PlanAggregationFunction::GroupConcat { separator: Rc::new(separator.clone().unwrap_or_else(|| " ".to_string())), }, parameter: Some(self.build_for_expression(expr, variables, graph_name)?), distinct: *distinct, - }, - }) + }), + AggregationFunction::Custom { .. } => Err(EvaluationError::msg( + "Custom aggregation functions are not supported yet", + )), + } } fn build_for_graph_template(