SPARQL: refactor AggregateExpression

Avoids code duplication
pull/673/head
Tpt 1 year ago committed by Thomas Tanon
parent 98caee8f92
commit f8034c68e9
  1. 241
      lib/spargebra/src/algebra.rs
  2. 40
      lib/spargebra/src/parser.rs
  3. 122
      lib/sparopt/src/algebra.rs
  4. 13
      lib/sparopt/src/optimizer.rs
  5. 99
      lib/src/sparql/eval.rs

@ -1114,46 +1114,11 @@ impl<'a> fmt::Display for SparqlGraphRootPattern<'a> {
/// A set function used in aggregates (c.f. [`GraphPattern::Group`]).
#[derive(Eq, PartialEq, Debug, Clone, Hash)]
pub enum AggregateExpression {
/// [Count](https://www.w3.org/TR/sparql11-query/#defn_aggCount).
Count {
expr: Option<Box<Expression>>,
distinct: bool,
},
/// [Sum](https://www.w3.org/TR/sparql11-query/#defn_aggSum).
Sum {
expr: Box<Expression>,
distinct: bool,
},
/// [Avg](https://www.w3.org/TR/sparql11-query/#defn_aggAvg).
Avg {
expr: Box<Expression>,
distinct: bool,
},
/// [Min](https://www.w3.org/TR/sparql11-query/#defn_aggMin).
Min {
expr: Box<Expression>,
distinct: bool,
},
/// [Max](https://www.w3.org/TR/sparql11-query/#defn_aggMax).
Max {
expr: Box<Expression>,
distinct: bool,
},
/// [GroupConcat](https://www.w3.org/TR/sparql11-query/#defn_aggGroupConcat).
GroupConcat {
expr: Box<Expression>,
distinct: bool,
separator: Option<String>,
},
/// [Sample](https://www.w3.org/TR/sparql11-query/#defn_aggSample).
Sample {
expr: Box<Expression>,
distinct: bool,
},
/// Custom function.
Custom {
name: NamedNode,
expr: Box<Expression>,
/// [Count](https://www.w3.org/TR/sparql11-query/#defn_aggCount) with *.
CountSolutions { distinct: bool },
FunctionCall {
name: AggregateFunction,
expr: Expression,
distinct: bool,
},
}
@ -1162,82 +1127,39 @@ impl AggregateExpression {
/// Formats using the [SPARQL S-Expression syntax](https://jena.apache.org/documentation/notes/sse.html).
pub(crate) fn fmt_sse(&self, f: &mut impl fmt::Write) -> fmt::Result {
match self {
Self::Count { expr, distinct } => {
write!(f, "(sum")?;
Self::CountSolutions { distinct } => {
write!(f, "(count")?;
if *distinct {
write!(f, " distinct")?;
}
if let Some(expr) = expr {
write!(f, " ")?;
expr.fmt_sse(f)?;
}
write!(f, ")")
}
Self::Sum { expr, distinct } => {
write!(f, "(sum ")?;
if *distinct {
write!(f, "distinct ")?;
}
expr.fmt_sse(f)?;
write!(f, ")")
}
Self::Avg { expr, distinct } => {
write!(f, "(avg ")?;
if *distinct {
write!(f, "distinct ")?;
}
expr.fmt_sse(f)?;
write!(f, ")")
}
Self::Min { expr, distinct } => {
write!(f, "(min ")?;
if *distinct {
write!(f, "distinct ")?;
}
expr.fmt_sse(f)?;
write!(f, ")")
}
Self::Max { expr, distinct } => {
write!(f, "(max ")?;
if *distinct {
write!(f, "distinct ")?;
}
expr.fmt_sse(f)?;
write!(f, ")")
}
Self::Sample { expr, distinct } => {
write!(f, "(sample ")?;
if *distinct {
write!(f, "distinct ")?;
}
expr.fmt_sse(f)?;
write!(f, ")")
}
Self::GroupConcat {
Self::FunctionCall {
name:
AggregateFunction::GroupConcat {
separator: Some(separator),
},
expr,
distinct,
separator,
} => {
write!(f, "(group_concat ")?;
if *distinct {
write!(f, "distinct ")?;
}
expr.fmt_sse(f)?;
if let Some(separator) = separator {
write!(f, " {}", LiteralRef::new_simple_literal(separator))?;
}
write!(f, ")")
write!(f, " {})", LiteralRef::new_simple_literal(separator))
}
Self::Custom {
Self::FunctionCall {
name,
expr,
distinct,
} => {
write!(f, "({name}")?;
write!(f, "(")?;
name.fmt_sse(f)?;
write!(f, " ")?;
if *distinct {
write!(f, " distinct")?;
write!(f, "distinct ")?;
}
write!(f, " ")?;
expr.fmt_sse(f)?;
write!(f, ")")
}
@ -1248,82 +1170,38 @@ impl AggregateExpression {
impl fmt::Display for AggregateExpression {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Count { expr, distinct } => {
Self::CountSolutions { distinct } => {
if *distinct {
if let Some(expr) = expr {
write!(f, "COUNT(DISTINCT {expr})")
} else {
write!(f, "COUNT(DISTINCT *)")
}
} else if let Some(expr) = expr {
write!(f, "COUNT({expr})")
write!(f, "COUNT(DISTINCT *)")
} else {
write!(f, "COUNT(*)")
}
}
Self::Sum { expr, distinct } => {
if *distinct {
write!(f, "SUM(DISTINCT {expr})")
} else {
write!(f, "SUM({expr})")
}
}
Self::Min { expr, distinct } => {
if *distinct {
write!(f, "MIN(DISTINCT {expr})")
} else {
write!(f, "MIN({expr})")
}
}
Self::Max { expr, distinct } => {
if *distinct {
write!(f, "MAX(DISTINCT {expr})")
} else {
write!(f, "MAX({expr})")
}
}
Self::Avg { expr, distinct } => {
if *distinct {
write!(f, "AVG(DISTINCT {expr})")
} else {
write!(f, "AVG({expr})")
}
}
Self::Sample { expr, distinct } => {
if *distinct {
write!(f, "SAMPLE(DISTINCT {expr})")
} else {
write!(f, "SAMPLE({expr})")
}
}
Self::GroupConcat {
Self::FunctionCall {
name:
AggregateFunction::GroupConcat {
separator: Some(separator),
},
expr,
distinct,
separator,
} => {
if *distinct {
if let Some(separator) = separator {
write!(
f,
"GROUP_CONCAT(DISTINCT {}; SEPARATOR = {})",
expr,
LiteralRef::new_simple_literal(separator)
)
} else {
write!(f, "GROUP_CONCAT(DISTINCT {expr})")
}
} else if let Some(separator) = separator {
write!(
f,
"GROUP_CONCAT({}; SEPARATOR = {})",
"GROUP_CONCAT(DISTINCT {}; SEPARATOR = {})",
expr,
LiteralRef::new_simple_literal(separator)
)
} else {
write!(f, "GROUP_CONCAT({expr})")
write!(
f,
"GROUP_CONCAT({}; SEPARATOR = {})",
expr,
LiteralRef::new_simple_literal(separator)
)
}
}
Self::Custom {
Self::FunctionCall {
name,
expr,
distinct,
@ -1338,6 +1216,59 @@ impl fmt::Display for AggregateExpression {
}
}
/// An aggregate function name.
#[derive(Eq, PartialEq, Debug, Clone, Hash)]
pub enum AggregateFunction {
/// [Count](https://www.w3.org/TR/sparql11-query/#defn_aggCount) with *.
Count,
/// [Sum](https://www.w3.org/TR/sparql11-query/#defn_aggSum).
Sum,
/// [Avg](https://www.w3.org/TR/sparql11-query/#defn_aggAvg).
Avg,
/// [Min](https://www.w3.org/TR/sparql11-query/#defn_aggMin).
Min,
/// [Max](https://www.w3.org/TR/sparql11-query/#defn_aggMax).
Max,
/// [GroupConcat](https://www.w3.org/TR/sparql11-query/#defn_aggGroupConcat).
GroupConcat {
separator: Option<String>,
},
/// [Sample](https://www.w3.org/TR/sparql11-query/#defn_aggSample).
Sample,
Custom(NamedNode),
}
impl AggregateFunction {
/// Formats using the [SPARQL S-Expression syntax](https://jena.apache.org/documentation/notes/sse.html).
pub(crate) fn fmt_sse(&self, f: &mut impl fmt::Write) -> fmt::Result {
match self {
Self::Count => write!(f, "count"),
Self::Sum => write!(f, "sum"),
Self::Avg => write!(f, "avg"),
Self::Min => write!(f, "min"),
Self::Max => write!(f, "max"),
Self::GroupConcat { .. } => write!(f, "group_concat"),
Self::Sample => write!(f, "sample"),
Self::Custom(iri) => write!(f, "{iri}"),
}
}
}
impl fmt::Display for AggregateFunction {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Count => write!(f, "COUNT"),
Self::Sum => write!(f, "SUM"),
Self::Avg => write!(f, "AVG"),
Self::Min => write!(f, "MIN"),
Self::Max => write!(f, "MAX"),
Self::GroupConcat { .. } => write!(f, "GROUP_CONCAT"),
Self::Sample => write!(f, "SAMPLE"),
Self::Custom(iri) => iri.fmt(f),
}
}
}
/// An ordering comparator used by [`GraphPattern::OrderBy`].
#[derive(Eq, PartialEq, Debug, Clone, Hash)]
pub enum OrderExpression {

@ -1918,26 +1918,26 @@ parser! {
rule NotExistsFunc() -> Expression = i("NOT") _ i("EXISTS") _ p:GroupGraphPattern() { Expression::Not(Box::new(Expression::Exists(Box::new(p)))) }
rule Aggregate() -> AggregateExpression =
i("COUNT") _ "(" _ i("DISTINCT") _ "*" _ ")" { AggregateExpression::Count { expr: None, distinct: true } } /
i("COUNT") _ "(" _ i("DISTINCT") _ e:Expression() _ ")" { AggregateExpression::Count { expr: Some(Box::new(e)), distinct: true } } /
i("COUNT") _ "(" _ "*" _ ")" { AggregateExpression::Count { expr: None, distinct: false } } /
i("COUNT") _ "(" _ e:Expression() _ ")" { AggregateExpression::Count { expr: Some(Box::new(e)), distinct: false } } /
i("SUM") _ "(" _ i("DISTINCT") _ e:Expression() _ ")" { AggregateExpression::Sum { expr: Box::new(e), distinct: true } } /
i("SUM") _ "(" _ e:Expression() _ ")" { AggregateExpression::Sum { expr: Box::new(e), distinct: false } } /
i("MIN") _ "(" _ i("DISTINCT") _ e:Expression() _ ")" { AggregateExpression::Min { expr: Box::new(e), distinct: true } } /
i("MIN") _ "(" _ e:Expression() _ ")" { AggregateExpression::Min { expr: Box::new(e), distinct: false } } /
i("MAX") _ "(" _ i("DISTINCT") _ e:Expression() _ ")" { AggregateExpression::Max { expr: Box::new(e), distinct: true } } /
i("MAX") _ "(" _ e:Expression() _ ")" { AggregateExpression::Max { expr: Box::new(e), distinct: false } } /
i("AVG") _ "(" _ i("DISTINCT") _ e:Expression() _ ")" { AggregateExpression::Avg { expr: Box::new(e), distinct: true } } /
i("AVG") _ "(" _ e:Expression() _ ")" { AggregateExpression::Avg { expr: Box::new(e), distinct: false } } /
i("SAMPLE") _ "(" _ i("DISTINCT") _ e:Expression() _ ")" { AggregateExpression::Sample { expr: Box::new(e), distinct: true } } /
i("SAMPLE") _ "(" _ e:Expression() _ ")" { AggregateExpression::Sample { expr: Box::new(e), distinct: false } } /
i("GROUP_CONCAT") _ "(" _ i("DISTINCT") _ e:Expression() _ ";" _ i("SEPARATOR") _ "=" _ s:String() _ ")" { AggregateExpression::GroupConcat { expr: Box::new(e), distinct: true, separator: Some(s) } } /
i("GROUP_CONCAT") _ "(" _ i("DISTINCT") _ e:Expression() _ ")" { AggregateExpression::GroupConcat { expr: Box::new(e), distinct: true, separator: None } } /
i("GROUP_CONCAT") _ "(" _ e:Expression() _ ";" _ i("SEPARATOR") _ "=" _ s:String() _ ")" { AggregateExpression::GroupConcat { expr: Box::new(e), distinct: true, separator: Some(s) } } /
i("GROUP_CONCAT") _ "(" _ e:Expression() _ ")" { AggregateExpression::GroupConcat { expr: Box::new(e), distinct: false, separator: None } } /
name:iri() _ "(" _ i("DISTINCT") _ e:Expression() _ ")" { AggregateExpression::Custom { name, expr: Box::new(e), distinct: true } } /
name:iri() _ "(" _ e:Expression() _ ")" { AggregateExpression::Custom { name, expr: Box::new(e), distinct: false } }
i("COUNT") _ "(" _ i("DISTINCT") _ "*" _ ")" { AggregateExpression::CountSolutions { distinct: true } } /
i("COUNT") _ "(" _ i("DISTINCT") _ expr:Expression() _ ")" { AggregateExpression::FunctionCall { name: AggregateFunction::Count, expr, distinct: true } } /
i("COUNT") _ "(" _ "*" _ ")" { AggregateExpression::CountSolutions { distinct: false } } /
i("COUNT") _ "(" _ expr:Expression() _ ")" { AggregateExpression::FunctionCall { name: AggregateFunction::Count, expr, distinct: false } } /
i("SUM") _ "(" _ i("DISTINCT") _ expr:Expression() _ ")" { AggregateExpression::FunctionCall { name: AggregateFunction::Sum, expr, distinct: true } } /
i("SUM") _ "(" _ expr:Expression() _ ")" { AggregateExpression::FunctionCall { name: AggregateFunction::Sum, expr, distinct: false } } /
i("MIN") _ "(" _ i("DISTINCT") _ expr:Expression() _ ")" { AggregateExpression::FunctionCall { name: AggregateFunction::Min, expr, distinct: true } } /
i("MIN") _ "(" _ expr:Expression() _ ")" { AggregateExpression::FunctionCall { name: AggregateFunction::Min, expr, distinct: false } } /
i("MAX") _ "(" _ i("DISTINCT") _ expr:Expression() _ ")" { AggregateExpression::FunctionCall { name: AggregateFunction::Max, expr, distinct: true } } /
i("MAX") _ "(" _ expr:Expression() _ ")" { AggregateExpression::FunctionCall { name: AggregateFunction::Max, expr, distinct: false } } /
i("AVG") _ "(" _ i("DISTINCT") _ expr:Expression() _ ")" { AggregateExpression::FunctionCall { name: AggregateFunction::Avg, expr, distinct: true } } /
i("AVG") _ "(" _ expr:Expression() _ ")" { AggregateExpression::FunctionCall { name: AggregateFunction::Avg, expr, distinct: false } } /
i("SAMPLE") _ "(" _ i("DISTINCT") _ expr:Expression() _ ")" { AggregateExpression::FunctionCall { name: AggregateFunction::Sample, expr, distinct: true } } /
i("SAMPLE") _ "(" _ expr:Expression() _ ")" { AggregateExpression::FunctionCall { name: AggregateFunction::Sample, expr, distinct: false } } /
i("GROUP_CONCAT") _ "(" _ i("DISTINCT") _ expr:Expression() _ ";" _ i("SEPARATOR") _ "=" _ s:String() _ ")" { AggregateExpression::FunctionCall { name: AggregateFunction::GroupConcat { separator: Some(s) }, expr, distinct: true } } /
i("GROUP_CONCAT") _ "(" _ i("DISTINCT") _ expr:Expression() _ ")" { AggregateExpression::FunctionCall { name: AggregateFunction::GroupConcat { separator: None }, expr, distinct: true } } /
i("GROUP_CONCAT") _ "(" _ expr:Expression() _ ";" _ i("SEPARATOR") _ "=" _ s:String() _ ")" { AggregateExpression::FunctionCall { name: AggregateFunction::GroupConcat { separator: Some(s) }, expr, distinct: true } } /
i("GROUP_CONCAT") _ "(" _ expr:Expression() _ ")" { AggregateExpression::FunctionCall { name: AggregateFunction::GroupConcat { separator: None }, expr, distinct: false } } /
name:iri() _ "(" _ i("DISTINCT") _ expr:Expression() _ ")" { AggregateExpression::FunctionCall { name: AggregateFunction::Custom(name), expr, distinct: true } } /
name:iri() _ "(" _ expr:Expression() _ ")" { AggregateExpression::FunctionCall { name: AggregateFunction::Custom(name), expr, distinct: false } }
rule iriOrFunction() -> Expression = i: iri() _ a: ArgList()? {
match a {

@ -3,7 +3,7 @@
use oxrdf::vocab::xsd;
use rand::random;
use spargebra::algebra::{
AggregateExpression as AlAggregateExpression, Expression as AlExpression,
AggregateExpression as AlAggregateExpression, AggregateFunction, Expression as AlExpression,
GraphPattern as AlGraphPattern, OrderExpression as AlOrderExpression,
};
pub use spargebra::algebra::{Function, PropertyPathExpression};
@ -1538,46 +1538,12 @@ impl Default for MinusAlgorithm {
/// A set function used in aggregates (c.f. [`GraphPattern::Group`]).
#[derive(Eq, PartialEq, Debug, Clone, Hash)]
pub enum AggregateExpression {
/// [Count](https://www.w3.org/TR/sparql11-query/#defn_aggCount).
Count {
expr: Option<Box<Expression>>,
CountSolutions {
distinct: bool,
},
/// [Sum](https://www.w3.org/TR/sparql11-query/#defn_aggSum).
Sum {
expr: Box<Expression>,
distinct: bool,
},
/// [Avg](https://www.w3.org/TR/sparql11-query/#defn_aggAvg).
Avg {
expr: Box<Expression>,
distinct: bool,
},
/// [Min](https://www.w3.org/TR/sparql11-query/#defn_aggMin).
Min {
expr: Box<Expression>,
distinct: bool,
},
/// [Max](https://www.w3.org/TR/sparql11-query/#defn_aggMax).
Max {
expr: Box<Expression>,
distinct: bool,
},
/// [GroupConcat](https://www.w3.org/TR/sparql11-query/#defn_aggGroupConcat).
GroupConcat {
expr: Box<Expression>,
distinct: bool,
separator: Option<String>,
},
/// [Sample](https://www.w3.org/TR/sparql11-query/#defn_aggSample).
Sample {
expr: Box<Expression>,
distinct: bool,
},
/// Custom function.
Custom {
name: NamedNode,
expr: Box<Expression>,
FunctionCall {
name: AggregateFunction,
expr: Expression,
distinct: bool,
},
}
@ -1588,48 +1554,16 @@ impl AggregateExpression {
graph_name: Option<&NamedNodePattern>,
) -> Self {
match expression {
AlAggregateExpression::Count { expr, distinct } => Self::Count {
expr: expr
.as_ref()
.map(|e| Box::new(Expression::from_sparql_algebra(e, graph_name))),
distinct: *distinct,
},
AlAggregateExpression::Sum { expr, distinct } => Self::Sum {
expr: Box::new(Expression::from_sparql_algebra(expr, graph_name)),
distinct: *distinct,
},
AlAggregateExpression::Avg { expr, distinct } => Self::Avg {
expr: Box::new(Expression::from_sparql_algebra(expr, graph_name)),
distinct: *distinct,
},
AlAggregateExpression::Min { expr, distinct } => Self::Min {
expr: Box::new(Expression::from_sparql_algebra(expr, graph_name)),
distinct: *distinct,
},
AlAggregateExpression::Max { expr, distinct } => Self::Max {
expr: Box::new(Expression::from_sparql_algebra(expr, graph_name)),
AlAggregateExpression::CountSolutions { distinct } => Self::CountSolutions {
distinct: *distinct,
},
AlAggregateExpression::GroupConcat {
expr,
distinct,
separator,
} => Self::GroupConcat {
expr: Box::new(Expression::from_sparql_algebra(expr, graph_name)),
distinct: *distinct,
separator: separator.clone(),
},
AlAggregateExpression::Sample { expr, distinct } => Self::Sample {
expr: Box::new(Expression::from_sparql_algebra(expr, graph_name)),
distinct: *distinct,
},
AlAggregateExpression::Custom {
AlAggregateExpression::FunctionCall {
name,
expr,
distinct,
} => Self::Custom {
} => Self::FunctionCall {
name: name.clone(),
expr: Box::new(Expression::from_sparql_algebra(expr, graph_name)),
expr: Expression::from_sparql_algebra(expr, graph_name),
distinct: *distinct,
},
}
@ -1639,46 +1573,16 @@ impl AggregateExpression {
impl From<&AggregateExpression> for AlAggregateExpression {
fn from(expression: &AggregateExpression) -> Self {
match expression {
AggregateExpression::Count { expr, distinct } => Self::Count {
expr: expr.as_ref().map(|e| Box::new(e.as_ref().into())),
distinct: *distinct,
},
AggregateExpression::Sum { expr, distinct } => Self::Sum {
expr: Box::new(expr.as_ref().into()),
distinct: *distinct,
},
AggregateExpression::Avg { expr, distinct } => Self::Avg {
expr: Box::new(expr.as_ref().into()),
distinct: *distinct,
},
AggregateExpression::Min { expr, distinct } => Self::Min {
expr: Box::new(expr.as_ref().into()),
distinct: *distinct,
},
AggregateExpression::Max { expr, distinct } => Self::Max {
expr: Box::new(expr.as_ref().into()),
distinct: *distinct,
},
AggregateExpression::GroupConcat {
expr,
distinct,
separator,
} => Self::GroupConcat {
expr: Box::new(expr.as_ref().into()),
distinct: *distinct,
separator: separator.clone(),
},
AggregateExpression::Sample { expr, distinct } => Self::Sample {
expr: Box::new(expr.as_ref().into()),
AggregateExpression::CountSolutions { distinct } => Self::CountSolutions {
distinct: *distinct,
},
AggregateExpression::Custom {
AggregateExpression::FunctionCall {
name,
expr,
distinct,
} => Self::Custom {
} => Self::FunctionCall {
name: name.clone(),
expr: Box::new(expr.as_ref().into()),
expr: expr.into(),
distinct: *distinct,
},
}

@ -157,11 +157,14 @@ impl Optimizer {
inner,
variables,
aggregates,
} => GraphPattern::group(
Self::normalize_pattern(*inner, input_types),
variables,
aggregates,
),
} => {
// TODO: min, max and sample don't care about DISTINCT
GraphPattern::group(
Self::normalize_pattern(*inner, input_types),
variables,
aggregates,
)
}
GraphPattern::Service {
name,
inner,

@ -18,7 +18,7 @@ use rand::random;
use regex::{Regex, RegexBuilder};
use sha1::Sha1;
use sha2::{Sha256, Sha384, Sha512};
use spargebra::algebra::{Function, PropertyPathExpression};
use spargebra::algebra::{AggregateFunction, Function, PropertyPathExpression};
use spargebra::term::{
GroundSubject, GroundTerm, GroundTermPattern, GroundTriple, NamedNodePattern, TermPattern,
TriplePattern,
@ -974,16 +974,8 @@ impl SimpleEvaluator {
let aggregate_input_expressions = aggregates
.iter()
.map(|(_, expression)| match expression {
AggregateExpression::Count { expr, .. } => expr.as_ref().map(|e| {
self.expression_evaluator(e, encoded_variables, stat_children)
}),
AggregateExpression::Sum { expr, .. }
| AggregateExpression::Avg { expr, .. }
| AggregateExpression::Min { expr, .. }
| AggregateExpression::Max { expr, .. }
| AggregateExpression::GroupConcat { expr, .. }
| AggregateExpression::Sample { expr, .. }
| AggregateExpression::Custom { expr, .. } => {
AggregateExpression::CountSolutions { .. } => None,
AggregateExpression::FunctionCall { expr, .. } => {
Some(self.expression_evaluator(expr, encoded_variables, stat_children))
}
})
@ -1101,52 +1093,26 @@ impl SimpleEvaluator {
dataset: &Rc<DatasetView>,
expression: &AggregateExpression,
) -> Box<dyn Fn() -> Box<dyn Accumulator>> {
match expression {
AggregateExpression::Count { distinct, .. } => {
if *distinct {
Box::new(|| Box::new(DistinctAccumulator::new(CountAccumulator::default())))
} else {
Box::new(|| Box::<CountAccumulator>::default())
let mut accumulator: Box<dyn Fn() -> Box<dyn Accumulator>> = match expression {
AggregateExpression::CountSolutions { .. } => {
Box::new(|| Box::<CountAccumulator>::default())
}
AggregateExpression::FunctionCall { name, .. } => match name {
AggregateFunction::Count => Box::new(|| Box::<CountAccumulator>::default()),
AggregateFunction::Sum => Box::new(|| Box::<SumAccumulator>::default()),
AggregateFunction::Min => {
let dataset = Rc::clone(dataset);
Box::new(move || Box::new(MinAccumulator::new(Rc::clone(&dataset))))
}
}
AggregateExpression::Sum { distinct, .. } => {
if *distinct {
Box::new(|| Box::new(DistinctAccumulator::new(SumAccumulator::default())))
} else {
Box::new(|| Box::<SumAccumulator>::default())
AggregateFunction::Max => {
let dataset = Rc::clone(dataset);
Box::new(move || Box::new(MaxAccumulator::new(Rc::clone(&dataset))))
}
}
AggregateExpression::Min { .. } => {
let dataset = Rc::clone(dataset);
Box::new(move || Box::new(MinAccumulator::new(Rc::clone(&dataset))))
} // DISTINCT does not make sense with min
AggregateExpression::Max { .. } => {
let dataset = Rc::clone(dataset);
Box::new(move || Box::new(MaxAccumulator::new(Rc::clone(&dataset))))
} // DISTINCT does not make sense with max
AggregateExpression::Avg { distinct, .. } => {
if *distinct {
Box::new(|| Box::new(DistinctAccumulator::new(AvgAccumulator::default())))
} else {
Box::new(|| Box::<AvgAccumulator>::default())
}
}
AggregateExpression::Sample { .. } => Box::new(|| Box::<SampleAccumulator>::default()), // DISTINCT does not make sense with sample
AggregateExpression::GroupConcat {
distinct,
separator,
..
} => {
let dataset = Rc::clone(dataset);
let separator = Rc::from(separator.as_deref().unwrap_or(" "));
if *distinct {
Box::new(move || {
Box::new(DistinctAccumulator::new(GroupConcatAccumulator::new(
Rc::clone(&dataset),
Rc::clone(&separator),
)))
})
} else {
AggregateFunction::Avg => Box::new(|| Box::<AvgAccumulator>::default()),
AggregateFunction::Sample => Box::new(|| Box::<SampleAccumulator>::default()),
AggregateFunction::GroupConcat { separator } => {
let dataset = Rc::clone(dataset);
let separator = Rc::from(separator.as_deref().unwrap_or(" "));
Box::new(move || {
Box::new(GroupConcatAccumulator::new(
Rc::clone(&dataset),
@ -1154,9 +1120,17 @@ impl SimpleEvaluator {
))
})
}
}
AggregateExpression::Custom { .. } => Box::new(|| Box::new(FailingAccumulator)),
AggregateFunction::Custom(_) => Box::new(|| Box::new(FailingAccumulator)),
},
};
if matches!(
expression,
AggregateExpression::CountSolutions { distinct: true }
| AggregateExpression::FunctionCall { distinct: true, .. }
) {
accumulator = Box::new(move || Box::new(Deduplicate::new(accumulator())));
}
accumulator
}
fn expression_evaluator(
@ -5262,14 +5236,13 @@ trait Accumulator {
fn state(&self) -> Option<EncodedTerm>;
}
#[derive(Default, Debug)]
struct DistinctAccumulator<T: Accumulator> {
struct Deduplicate {
seen: HashSet<Option<EncodedTerm>>,
inner: T,
inner: Box<dyn Accumulator>,
}
impl<T: Accumulator> DistinctAccumulator<T> {
fn new(inner: T) -> Self {
impl Deduplicate {
fn new(inner: Box<dyn Accumulator>) -> Self {
Self {
seen: HashSet::default(),
inner,
@ -5277,7 +5250,7 @@ impl<T: Accumulator> DistinctAccumulator<T> {
}
}
impl<T: Accumulator> Accumulator for DistinctAccumulator<T> {
impl Accumulator for Deduplicate {
fn add(&mut self, element: Option<EncodedTerm>) {
if self.seen.insert(element.clone()) {
self.inner.add(element)

Loading…
Cancel
Save