From f8034c68e9341010d4949f3c5532cdeba83213bc Mon Sep 17 00:00:00 2001 From: Tpt Date: Sat, 18 Nov 2023 21:20:12 +0100 Subject: [PATCH] SPARQL: refactor AggregateExpression Avoids code duplication --- lib/spargebra/src/algebra.rs | 241 +++++++++++++---------------------- lib/spargebra/src/parser.rs | 40 +++--- lib/sparopt/src/algebra.rs | 122 ++---------------- lib/sparopt/src/optimizer.rs | 13 +- lib/src/sparql/eval.rs | 99 ++++++-------- 5 files changed, 163 insertions(+), 352 deletions(-) diff --git a/lib/spargebra/src/algebra.rs b/lib/spargebra/src/algebra.rs index d694ba53..6c629246 100644 --- a/lib/spargebra/src/algebra.rs +++ b/lib/spargebra/src/algebra.rs @@ -1114,46 +1114,11 @@ impl<'a> fmt::Display for SparqlGraphRootPattern<'a> { /// A set function used in aggregates (c.f. [`GraphPattern::Group`]). #[derive(Eq, PartialEq, Debug, Clone, Hash)] pub enum AggregateExpression { - /// [Count](https://www.w3.org/TR/sparql11-query/#defn_aggCount). - Count { - expr: Option>, - distinct: bool, - }, - /// [Sum](https://www.w3.org/TR/sparql11-query/#defn_aggSum). - Sum { - expr: Box, - distinct: bool, - }, - /// [Avg](https://www.w3.org/TR/sparql11-query/#defn_aggAvg). - Avg { - expr: Box, - distinct: bool, - }, - /// [Min](https://www.w3.org/TR/sparql11-query/#defn_aggMin). - Min { - expr: Box, - distinct: bool, - }, - /// [Max](https://www.w3.org/TR/sparql11-query/#defn_aggMax). - Max { - expr: Box, - distinct: bool, - }, - /// [GroupConcat](https://www.w3.org/TR/sparql11-query/#defn_aggGroupConcat). - GroupConcat { - expr: Box, - distinct: bool, - separator: Option, - }, - /// [Sample](https://www.w3.org/TR/sparql11-query/#defn_aggSample). - Sample { - expr: Box, - distinct: bool, - }, - /// Custom function. - Custom { - name: NamedNode, - expr: Box, + /// [Count](https://www.w3.org/TR/sparql11-query/#defn_aggCount) with *. + CountSolutions { distinct: bool }, + FunctionCall { + name: AggregateFunction, + expr: Expression, distinct: bool, }, } @@ -1162,82 +1127,39 @@ impl AggregateExpression { /// Formats using the [SPARQL S-Expression syntax](https://jena.apache.org/documentation/notes/sse.html). pub(crate) fn fmt_sse(&self, f: &mut impl fmt::Write) -> fmt::Result { match self { - Self::Count { expr, distinct } => { - write!(f, "(sum")?; + Self::CountSolutions { distinct } => { + write!(f, "(count")?; if *distinct { write!(f, " distinct")?; } - if let Some(expr) = expr { - write!(f, " ")?; - expr.fmt_sse(f)?; - } - write!(f, ")") - } - Self::Sum { expr, distinct } => { - write!(f, "(sum ")?; - if *distinct { - write!(f, "distinct ")?; - } - expr.fmt_sse(f)?; write!(f, ")") } - Self::Avg { expr, distinct } => { - write!(f, "(avg ")?; - if *distinct { - write!(f, "distinct ")?; - } - expr.fmt_sse(f)?; - write!(f, ")") - } - Self::Min { expr, distinct } => { - write!(f, "(min ")?; - if *distinct { - write!(f, "distinct ")?; - } - expr.fmt_sse(f)?; - write!(f, ")") - } - Self::Max { expr, distinct } => { - write!(f, "(max ")?; - if *distinct { - write!(f, "distinct ")?; - } - expr.fmt_sse(f)?; - write!(f, ")") - } - Self::Sample { expr, distinct } => { - write!(f, "(sample ")?; - if *distinct { - write!(f, "distinct ")?; - } - expr.fmt_sse(f)?; - write!(f, ")") - } - Self::GroupConcat { + Self::FunctionCall { + name: + AggregateFunction::GroupConcat { + separator: Some(separator), + }, expr, distinct, - separator, } => { write!(f, "(group_concat ")?; if *distinct { write!(f, "distinct ")?; } expr.fmt_sse(f)?; - if let Some(separator) = separator { - write!(f, " {}", LiteralRef::new_simple_literal(separator))?; - } - write!(f, ")") + write!(f, " {})", LiteralRef::new_simple_literal(separator)) } - Self::Custom { + Self::FunctionCall { name, expr, distinct, } => { - write!(f, "({name}")?; + write!(f, "(")?; + name.fmt_sse(f)?; + write!(f, " ")?; if *distinct { - write!(f, " distinct")?; + write!(f, "distinct ")?; } - write!(f, " ")?; expr.fmt_sse(f)?; write!(f, ")") } @@ -1248,82 +1170,38 @@ impl AggregateExpression { impl fmt::Display for AggregateExpression { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - Self::Count { expr, distinct } => { + Self::CountSolutions { distinct } => { if *distinct { - if let Some(expr) = expr { - write!(f, "COUNT(DISTINCT {expr})") - } else { - write!(f, "COUNT(DISTINCT *)") - } - } else if let Some(expr) = expr { - write!(f, "COUNT({expr})") + write!(f, "COUNT(DISTINCT *)") } else { write!(f, "COUNT(*)") } } - Self::Sum { expr, distinct } => { - if *distinct { - write!(f, "SUM(DISTINCT {expr})") - } else { - write!(f, "SUM({expr})") - } - } - Self::Min { expr, distinct } => { - if *distinct { - write!(f, "MIN(DISTINCT {expr})") - } else { - write!(f, "MIN({expr})") - } - } - Self::Max { expr, distinct } => { - if *distinct { - write!(f, "MAX(DISTINCT {expr})") - } else { - write!(f, "MAX({expr})") - } - } - Self::Avg { expr, distinct } => { - if *distinct { - write!(f, "AVG(DISTINCT {expr})") - } else { - write!(f, "AVG({expr})") - } - } - Self::Sample { expr, distinct } => { - if *distinct { - write!(f, "SAMPLE(DISTINCT {expr})") - } else { - write!(f, "SAMPLE({expr})") - } - } - Self::GroupConcat { + Self::FunctionCall { + name: + AggregateFunction::GroupConcat { + separator: Some(separator), + }, expr, distinct, - separator, } => { if *distinct { - if let Some(separator) = separator { - write!( - f, - "GROUP_CONCAT(DISTINCT {}; SEPARATOR = {})", - expr, - LiteralRef::new_simple_literal(separator) - ) - } else { - write!(f, "GROUP_CONCAT(DISTINCT {expr})") - } - } else if let Some(separator) = separator { write!( f, - "GROUP_CONCAT({}; SEPARATOR = {})", + "GROUP_CONCAT(DISTINCT {}; SEPARATOR = {})", expr, LiteralRef::new_simple_literal(separator) ) } else { - write!(f, "GROUP_CONCAT({expr})") + write!( + f, + "GROUP_CONCAT({}; SEPARATOR = {})", + expr, + LiteralRef::new_simple_literal(separator) + ) } } - Self::Custom { + Self::FunctionCall { name, expr, distinct, @@ -1338,6 +1216,59 @@ impl fmt::Display for AggregateExpression { } } +/// An aggregate function name. +#[derive(Eq, PartialEq, Debug, Clone, Hash)] +pub enum AggregateFunction { + /// [Count](https://www.w3.org/TR/sparql11-query/#defn_aggCount) with *. + Count, + /// [Sum](https://www.w3.org/TR/sparql11-query/#defn_aggSum). + Sum, + /// [Avg](https://www.w3.org/TR/sparql11-query/#defn_aggAvg). + Avg, + /// [Min](https://www.w3.org/TR/sparql11-query/#defn_aggMin). + Min, + /// [Max](https://www.w3.org/TR/sparql11-query/#defn_aggMax). + Max, + /// [GroupConcat](https://www.w3.org/TR/sparql11-query/#defn_aggGroupConcat). + GroupConcat { + separator: Option, + }, + /// [Sample](https://www.w3.org/TR/sparql11-query/#defn_aggSample). + Sample, + Custom(NamedNode), +} + +impl AggregateFunction { + /// Formats using the [SPARQL S-Expression syntax](https://jena.apache.org/documentation/notes/sse.html). + pub(crate) fn fmt_sse(&self, f: &mut impl fmt::Write) -> fmt::Result { + match self { + Self::Count => write!(f, "count"), + Self::Sum => write!(f, "sum"), + Self::Avg => write!(f, "avg"), + Self::Min => write!(f, "min"), + Self::Max => write!(f, "max"), + Self::GroupConcat { .. } => write!(f, "group_concat"), + Self::Sample => write!(f, "sample"), + Self::Custom(iri) => write!(f, "{iri}"), + } + } +} + +impl fmt::Display for AggregateFunction { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Count => write!(f, "COUNT"), + Self::Sum => write!(f, "SUM"), + Self::Avg => write!(f, "AVG"), + Self::Min => write!(f, "MIN"), + Self::Max => write!(f, "MAX"), + Self::GroupConcat { .. } => write!(f, "GROUP_CONCAT"), + Self::Sample => write!(f, "SAMPLE"), + Self::Custom(iri) => iri.fmt(f), + } + } +} + /// An ordering comparator used by [`GraphPattern::OrderBy`]. #[derive(Eq, PartialEq, Debug, Clone, Hash)] pub enum OrderExpression { diff --git a/lib/spargebra/src/parser.rs b/lib/spargebra/src/parser.rs index 60f8038c..52bfbb54 100644 --- a/lib/spargebra/src/parser.rs +++ b/lib/spargebra/src/parser.rs @@ -1918,26 +1918,26 @@ parser! { rule NotExistsFunc() -> Expression = i("NOT") _ i("EXISTS") _ p:GroupGraphPattern() { Expression::Not(Box::new(Expression::Exists(Box::new(p)))) } rule Aggregate() -> AggregateExpression = - i("COUNT") _ "(" _ i("DISTINCT") _ "*" _ ")" { AggregateExpression::Count { expr: None, distinct: true } } / - i("COUNT") _ "(" _ i("DISTINCT") _ e:Expression() _ ")" { AggregateExpression::Count { expr: Some(Box::new(e)), distinct: true } } / - i("COUNT") _ "(" _ "*" _ ")" { AggregateExpression::Count { expr: None, distinct: false } } / - i("COUNT") _ "(" _ e:Expression() _ ")" { AggregateExpression::Count { expr: Some(Box::new(e)), distinct: false } } / - i("SUM") _ "(" _ i("DISTINCT") _ e:Expression() _ ")" { AggregateExpression::Sum { expr: Box::new(e), distinct: true } } / - i("SUM") _ "(" _ e:Expression() _ ")" { AggregateExpression::Sum { expr: Box::new(e), distinct: false } } / - i("MIN") _ "(" _ i("DISTINCT") _ e:Expression() _ ")" { AggregateExpression::Min { expr: Box::new(e), distinct: true } } / - i("MIN") _ "(" _ e:Expression() _ ")" { AggregateExpression::Min { expr: Box::new(e), distinct: false } } / - i("MAX") _ "(" _ i("DISTINCT") _ e:Expression() _ ")" { AggregateExpression::Max { expr: Box::new(e), distinct: true } } / - i("MAX") _ "(" _ e:Expression() _ ")" { AggregateExpression::Max { expr: Box::new(e), distinct: false } } / - i("AVG") _ "(" _ i("DISTINCT") _ e:Expression() _ ")" { AggregateExpression::Avg { expr: Box::new(e), distinct: true } } / - i("AVG") _ "(" _ e:Expression() _ ")" { AggregateExpression::Avg { expr: Box::new(e), distinct: false } } / - i("SAMPLE") _ "(" _ i("DISTINCT") _ e:Expression() _ ")" { AggregateExpression::Sample { expr: Box::new(e), distinct: true } } / - i("SAMPLE") _ "(" _ e:Expression() _ ")" { AggregateExpression::Sample { expr: Box::new(e), distinct: false } } / - i("GROUP_CONCAT") _ "(" _ i("DISTINCT") _ e:Expression() _ ";" _ i("SEPARATOR") _ "=" _ s:String() _ ")" { AggregateExpression::GroupConcat { expr: Box::new(e), distinct: true, separator: Some(s) } } / - i("GROUP_CONCAT") _ "(" _ i("DISTINCT") _ e:Expression() _ ")" { AggregateExpression::GroupConcat { expr: Box::new(e), distinct: true, separator: None } } / - i("GROUP_CONCAT") _ "(" _ e:Expression() _ ";" _ i("SEPARATOR") _ "=" _ s:String() _ ")" { AggregateExpression::GroupConcat { expr: Box::new(e), distinct: true, separator: Some(s) } } / - i("GROUP_CONCAT") _ "(" _ e:Expression() _ ")" { AggregateExpression::GroupConcat { expr: Box::new(e), distinct: false, separator: None } } / - name:iri() _ "(" _ i("DISTINCT") _ e:Expression() _ ")" { AggregateExpression::Custom { name, expr: Box::new(e), distinct: true } } / - name:iri() _ "(" _ e:Expression() _ ")" { AggregateExpression::Custom { name, expr: Box::new(e), distinct: false } } + i("COUNT") _ "(" _ i("DISTINCT") _ "*" _ ")" { AggregateExpression::CountSolutions { distinct: true } } / + i("COUNT") _ "(" _ i("DISTINCT") _ expr:Expression() _ ")" { AggregateExpression::FunctionCall { name: AggregateFunction::Count, expr, distinct: true } } / + i("COUNT") _ "(" _ "*" _ ")" { AggregateExpression::CountSolutions { distinct: false } } / + i("COUNT") _ "(" _ expr:Expression() _ ")" { AggregateExpression::FunctionCall { name: AggregateFunction::Count, expr, distinct: false } } / + i("SUM") _ "(" _ i("DISTINCT") _ expr:Expression() _ ")" { AggregateExpression::FunctionCall { name: AggregateFunction::Sum, expr, distinct: true } } / + i("SUM") _ "(" _ expr:Expression() _ ")" { AggregateExpression::FunctionCall { name: AggregateFunction::Sum, expr, distinct: false } } / + i("MIN") _ "(" _ i("DISTINCT") _ expr:Expression() _ ")" { AggregateExpression::FunctionCall { name: AggregateFunction::Min, expr, distinct: true } } / + i("MIN") _ "(" _ expr:Expression() _ ")" { AggregateExpression::FunctionCall { name: AggregateFunction::Min, expr, distinct: false } } / + i("MAX") _ "(" _ i("DISTINCT") _ expr:Expression() _ ")" { AggregateExpression::FunctionCall { name: AggregateFunction::Max, expr, distinct: true } } / + i("MAX") _ "(" _ expr:Expression() _ ")" { AggregateExpression::FunctionCall { name: AggregateFunction::Max, expr, distinct: false } } / + i("AVG") _ "(" _ i("DISTINCT") _ expr:Expression() _ ")" { AggregateExpression::FunctionCall { name: AggregateFunction::Avg, expr, distinct: true } } / + i("AVG") _ "(" _ expr:Expression() _ ")" { AggregateExpression::FunctionCall { name: AggregateFunction::Avg, expr, distinct: false } } / + i("SAMPLE") _ "(" _ i("DISTINCT") _ expr:Expression() _ ")" { AggregateExpression::FunctionCall { name: AggregateFunction::Sample, expr, distinct: true } } / + i("SAMPLE") _ "(" _ expr:Expression() _ ")" { AggregateExpression::FunctionCall { name: AggregateFunction::Sample, expr, distinct: false } } / + i("GROUP_CONCAT") _ "(" _ i("DISTINCT") _ expr:Expression() _ ";" _ i("SEPARATOR") _ "=" _ s:String() _ ")" { AggregateExpression::FunctionCall { name: AggregateFunction::GroupConcat { separator: Some(s) }, expr, distinct: true } } / + i("GROUP_CONCAT") _ "(" _ i("DISTINCT") _ expr:Expression() _ ")" { AggregateExpression::FunctionCall { name: AggregateFunction::GroupConcat { separator: None }, expr, distinct: true } } / + i("GROUP_CONCAT") _ "(" _ expr:Expression() _ ";" _ i("SEPARATOR") _ "=" _ s:String() _ ")" { AggregateExpression::FunctionCall { name: AggregateFunction::GroupConcat { separator: Some(s) }, expr, distinct: true } } / + i("GROUP_CONCAT") _ "(" _ expr:Expression() _ ")" { AggregateExpression::FunctionCall { name: AggregateFunction::GroupConcat { separator: None }, expr, distinct: false } } / + name:iri() _ "(" _ i("DISTINCT") _ expr:Expression() _ ")" { AggregateExpression::FunctionCall { name: AggregateFunction::Custom(name), expr, distinct: true } } / + name:iri() _ "(" _ expr:Expression() _ ")" { AggregateExpression::FunctionCall { name: AggregateFunction::Custom(name), expr, distinct: false } } rule iriOrFunction() -> Expression = i: iri() _ a: ArgList()? { match a { diff --git a/lib/sparopt/src/algebra.rs b/lib/sparopt/src/algebra.rs index fd7942d5..e5cb0952 100644 --- a/lib/sparopt/src/algebra.rs +++ b/lib/sparopt/src/algebra.rs @@ -3,7 +3,7 @@ use oxrdf::vocab::xsd; use rand::random; use spargebra::algebra::{ - AggregateExpression as AlAggregateExpression, Expression as AlExpression, + AggregateExpression as AlAggregateExpression, AggregateFunction, Expression as AlExpression, GraphPattern as AlGraphPattern, OrderExpression as AlOrderExpression, }; pub use spargebra::algebra::{Function, PropertyPathExpression}; @@ -1538,46 +1538,12 @@ impl Default for MinusAlgorithm { /// A set function used in aggregates (c.f. [`GraphPattern::Group`]). #[derive(Eq, PartialEq, Debug, Clone, Hash)] pub enum AggregateExpression { - /// [Count](https://www.w3.org/TR/sparql11-query/#defn_aggCount). - Count { - expr: Option>, + CountSolutions { distinct: bool, }, - /// [Sum](https://www.w3.org/TR/sparql11-query/#defn_aggSum). - Sum { - expr: Box, - distinct: bool, - }, - /// [Avg](https://www.w3.org/TR/sparql11-query/#defn_aggAvg). - Avg { - expr: Box, - distinct: bool, - }, - /// [Min](https://www.w3.org/TR/sparql11-query/#defn_aggMin). - Min { - expr: Box, - distinct: bool, - }, - /// [Max](https://www.w3.org/TR/sparql11-query/#defn_aggMax). - Max { - expr: Box, - distinct: bool, - }, - /// [GroupConcat](https://www.w3.org/TR/sparql11-query/#defn_aggGroupConcat). - GroupConcat { - expr: Box, - distinct: bool, - separator: Option, - }, - /// [Sample](https://www.w3.org/TR/sparql11-query/#defn_aggSample). - Sample { - expr: Box, - distinct: bool, - }, - /// Custom function. - Custom { - name: NamedNode, - expr: Box, + FunctionCall { + name: AggregateFunction, + expr: Expression, distinct: bool, }, } @@ -1588,48 +1554,16 @@ impl AggregateExpression { graph_name: Option<&NamedNodePattern>, ) -> Self { match expression { - AlAggregateExpression::Count { expr, distinct } => Self::Count { - expr: expr - .as_ref() - .map(|e| Box::new(Expression::from_sparql_algebra(e, graph_name))), - distinct: *distinct, - }, - AlAggregateExpression::Sum { expr, distinct } => Self::Sum { - expr: Box::new(Expression::from_sparql_algebra(expr, graph_name)), - distinct: *distinct, - }, - AlAggregateExpression::Avg { expr, distinct } => Self::Avg { - expr: Box::new(Expression::from_sparql_algebra(expr, graph_name)), - distinct: *distinct, - }, - AlAggregateExpression::Min { expr, distinct } => Self::Min { - expr: Box::new(Expression::from_sparql_algebra(expr, graph_name)), - distinct: *distinct, - }, - AlAggregateExpression::Max { expr, distinct } => Self::Max { - expr: Box::new(Expression::from_sparql_algebra(expr, graph_name)), + AlAggregateExpression::CountSolutions { distinct } => Self::CountSolutions { distinct: *distinct, }, - AlAggregateExpression::GroupConcat { - expr, - distinct, - separator, - } => Self::GroupConcat { - expr: Box::new(Expression::from_sparql_algebra(expr, graph_name)), - distinct: *distinct, - separator: separator.clone(), - }, - AlAggregateExpression::Sample { expr, distinct } => Self::Sample { - expr: Box::new(Expression::from_sparql_algebra(expr, graph_name)), - distinct: *distinct, - }, - AlAggregateExpression::Custom { + AlAggregateExpression::FunctionCall { name, expr, distinct, - } => Self::Custom { + } => Self::FunctionCall { name: name.clone(), - expr: Box::new(Expression::from_sparql_algebra(expr, graph_name)), + expr: Expression::from_sparql_algebra(expr, graph_name), distinct: *distinct, }, } @@ -1639,46 +1573,16 @@ impl AggregateExpression { impl From<&AggregateExpression> for AlAggregateExpression { fn from(expression: &AggregateExpression) -> Self { match expression { - AggregateExpression::Count { expr, distinct } => Self::Count { - expr: expr.as_ref().map(|e| Box::new(e.as_ref().into())), - distinct: *distinct, - }, - AggregateExpression::Sum { expr, distinct } => Self::Sum { - expr: Box::new(expr.as_ref().into()), - distinct: *distinct, - }, - AggregateExpression::Avg { expr, distinct } => Self::Avg { - expr: Box::new(expr.as_ref().into()), - distinct: *distinct, - }, - AggregateExpression::Min { expr, distinct } => Self::Min { - expr: Box::new(expr.as_ref().into()), - distinct: *distinct, - }, - AggregateExpression::Max { expr, distinct } => Self::Max { - expr: Box::new(expr.as_ref().into()), - distinct: *distinct, - }, - AggregateExpression::GroupConcat { - expr, - distinct, - separator, - } => Self::GroupConcat { - expr: Box::new(expr.as_ref().into()), - distinct: *distinct, - separator: separator.clone(), - }, - AggregateExpression::Sample { expr, distinct } => Self::Sample { - expr: Box::new(expr.as_ref().into()), + AggregateExpression::CountSolutions { distinct } => Self::CountSolutions { distinct: *distinct, }, - AggregateExpression::Custom { + AggregateExpression::FunctionCall { name, expr, distinct, - } => Self::Custom { + } => Self::FunctionCall { name: name.clone(), - expr: Box::new(expr.as_ref().into()), + expr: expr.into(), distinct: *distinct, }, } diff --git a/lib/sparopt/src/optimizer.rs b/lib/sparopt/src/optimizer.rs index 87902b59..5dc9d404 100644 --- a/lib/sparopt/src/optimizer.rs +++ b/lib/sparopt/src/optimizer.rs @@ -157,11 +157,14 @@ impl Optimizer { inner, variables, aggregates, - } => GraphPattern::group( - Self::normalize_pattern(*inner, input_types), - variables, - aggregates, - ), + } => { + // TODO: min, max and sample don't care about DISTINCT + GraphPattern::group( + Self::normalize_pattern(*inner, input_types), + variables, + aggregates, + ) + } GraphPattern::Service { name, inner, diff --git a/lib/src/sparql/eval.rs b/lib/src/sparql/eval.rs index 2997b83d..413fcd20 100644 --- a/lib/src/sparql/eval.rs +++ b/lib/src/sparql/eval.rs @@ -18,7 +18,7 @@ use rand::random; use regex::{Regex, RegexBuilder}; use sha1::Sha1; use sha2::{Sha256, Sha384, Sha512}; -use spargebra::algebra::{Function, PropertyPathExpression}; +use spargebra::algebra::{AggregateFunction, Function, PropertyPathExpression}; use spargebra::term::{ GroundSubject, GroundTerm, GroundTermPattern, GroundTriple, NamedNodePattern, TermPattern, TriplePattern, @@ -974,16 +974,8 @@ impl SimpleEvaluator { let aggregate_input_expressions = aggregates .iter() .map(|(_, expression)| match expression { - AggregateExpression::Count { expr, .. } => expr.as_ref().map(|e| { - self.expression_evaluator(e, encoded_variables, stat_children) - }), - AggregateExpression::Sum { expr, .. } - | AggregateExpression::Avg { expr, .. } - | AggregateExpression::Min { expr, .. } - | AggregateExpression::Max { expr, .. } - | AggregateExpression::GroupConcat { expr, .. } - | AggregateExpression::Sample { expr, .. } - | AggregateExpression::Custom { expr, .. } => { + AggregateExpression::CountSolutions { .. } => None, + AggregateExpression::FunctionCall { expr, .. } => { Some(self.expression_evaluator(expr, encoded_variables, stat_children)) } }) @@ -1101,52 +1093,26 @@ impl SimpleEvaluator { dataset: &Rc, expression: &AggregateExpression, ) -> Box Box> { - match expression { - AggregateExpression::Count { distinct, .. } => { - if *distinct { - Box::new(|| Box::new(DistinctAccumulator::new(CountAccumulator::default()))) - } else { - Box::new(|| Box::::default()) + let mut accumulator: Box Box> = match expression { + AggregateExpression::CountSolutions { .. } => { + Box::new(|| Box::::default()) + } + AggregateExpression::FunctionCall { name, .. } => match name { + AggregateFunction::Count => Box::new(|| Box::::default()), + AggregateFunction::Sum => Box::new(|| Box::::default()), + AggregateFunction::Min => { + let dataset = Rc::clone(dataset); + Box::new(move || Box::new(MinAccumulator::new(Rc::clone(&dataset)))) } - } - AggregateExpression::Sum { distinct, .. } => { - if *distinct { - Box::new(|| Box::new(DistinctAccumulator::new(SumAccumulator::default()))) - } else { - Box::new(|| Box::::default()) + AggregateFunction::Max => { + let dataset = Rc::clone(dataset); + Box::new(move || Box::new(MaxAccumulator::new(Rc::clone(&dataset)))) } - } - AggregateExpression::Min { .. } => { - let dataset = Rc::clone(dataset); - Box::new(move || Box::new(MinAccumulator::new(Rc::clone(&dataset)))) - } // DISTINCT does not make sense with min - AggregateExpression::Max { .. } => { - let dataset = Rc::clone(dataset); - Box::new(move || Box::new(MaxAccumulator::new(Rc::clone(&dataset)))) - } // DISTINCT does not make sense with max - AggregateExpression::Avg { distinct, .. } => { - if *distinct { - Box::new(|| Box::new(DistinctAccumulator::new(AvgAccumulator::default()))) - } else { - Box::new(|| Box::::default()) - } - } - AggregateExpression::Sample { .. } => Box::new(|| Box::::default()), // DISTINCT does not make sense with sample - AggregateExpression::GroupConcat { - distinct, - separator, - .. - } => { - let dataset = Rc::clone(dataset); - let separator = Rc::from(separator.as_deref().unwrap_or(" ")); - if *distinct { - Box::new(move || { - Box::new(DistinctAccumulator::new(GroupConcatAccumulator::new( - Rc::clone(&dataset), - Rc::clone(&separator), - ))) - }) - } else { + AggregateFunction::Avg => Box::new(|| Box::::default()), + AggregateFunction::Sample => Box::new(|| Box::::default()), + AggregateFunction::GroupConcat { separator } => { + let dataset = Rc::clone(dataset); + let separator = Rc::from(separator.as_deref().unwrap_or(" ")); Box::new(move || { Box::new(GroupConcatAccumulator::new( Rc::clone(&dataset), @@ -1154,9 +1120,17 @@ impl SimpleEvaluator { )) }) } - } - AggregateExpression::Custom { .. } => Box::new(|| Box::new(FailingAccumulator)), + AggregateFunction::Custom(_) => Box::new(|| Box::new(FailingAccumulator)), + }, + }; + if matches!( + expression, + AggregateExpression::CountSolutions { distinct: true } + | AggregateExpression::FunctionCall { distinct: true, .. } + ) { + accumulator = Box::new(move || Box::new(Deduplicate::new(accumulator()))); } + accumulator } fn expression_evaluator( @@ -5262,14 +5236,13 @@ trait Accumulator { fn state(&self) -> Option; } -#[derive(Default, Debug)] -struct DistinctAccumulator { +struct Deduplicate { seen: HashSet>, - inner: T, + inner: Box, } -impl DistinctAccumulator { - fn new(inner: T) -> Self { +impl Deduplicate { + fn new(inner: Box) -> Self { Self { seen: HashSet::default(), inner, @@ -5277,7 +5250,7 @@ impl DistinctAccumulator { } } -impl Accumulator for DistinctAccumulator { +impl Accumulator for Deduplicate { fn add(&mut self, element: Option) { if self.seen.insert(element.clone()) { self.inner.add(element)