diff --git a/lib/sparopt/src/reasoning.rs b/lib/sparopt/src/reasoning.rs index 190d8de1..ff022928 100644 --- a/lib/sparopt/src/reasoning.rs +++ b/lib/sparopt/src/reasoning.rs @@ -39,7 +39,6 @@ impl QueryRewriter { } pub fn rewrite_graph_pattern(&self, pattern: &GraphPattern) -> GraphPattern { - //TODO: rewrite EXISTS match pattern { GraphPattern::QuadPattern { subject, @@ -93,9 +92,10 @@ impl QueryRewriter { self.rewrite_graph_pattern(left), self.rewrite_graph_pattern(right), ), - GraphPattern::Filter { inner, expression } => { - GraphPattern::filter(self.rewrite_graph_pattern(inner), expression.clone()) - } + GraphPattern::Filter { inner, expression } => GraphPattern::filter( + self.rewrite_graph_pattern(inner), + self.rewrite_expression(expression), + ), GraphPattern::Union { inner } => inner .iter() .map(|p| self.rewrite_graph_pattern(p)) @@ -108,7 +108,7 @@ impl QueryRewriter { } => GraphPattern::extend( self.rewrite_graph_pattern(inner), variable.clone(), - expression.clone(), + self.rewrite_expression(expression), ), GraphPattern::Minus { left, right } => GraphPattern::minus( self.rewrite_graph_pattern(left), @@ -118,9 +118,18 @@ impl QueryRewriter { variables, bindings, } => GraphPattern::values(variables.clone(), bindings.clone()), - GraphPattern::OrderBy { inner, expression } => { - GraphPattern::order_by(self.rewrite_graph_pattern(inner), expression.clone()) - } + GraphPattern::OrderBy { inner, expression } => GraphPattern::order_by( + self.rewrite_graph_pattern(inner), + expression + .iter() + .map(|e| match e { + OrderExpression::Asc(e) => OrderExpression::Asc(self.rewrite_expression(e)), + OrderExpression::Desc(e) => { + OrderExpression::Desc(self.rewrite_expression(e)) + } + }) + .collect(), + ), GraphPattern::Project { inner, variables } => { GraphPattern::project(self.rewrite_graph_pattern(inner), variables.clone()) } @@ -142,14 +151,159 @@ impl QueryRewriter { } => GraphPattern::group( self.rewrite_graph_pattern(inner), variables.clone(), - aggregates.clone(), + aggregates + .iter() + .map(|(v, e)| { + ( + v.clone(), + match e { + AggregateExpression::Count { expr, distinct } => { + AggregateExpression::Count { + expr: expr + .as_ref() + .map(|e| Box::new(self.rewrite_expression(e))), + distinct: *distinct, + } + } + AggregateExpression::Sum { expr, distinct } => { + AggregateExpression::Sum { + expr: Box::new(self.rewrite_expression(expr)), + distinct: *distinct, + } + } + AggregateExpression::Min { expr, distinct } => { + AggregateExpression::Min { + expr: Box::new(self.rewrite_expression(expr)), + distinct: *distinct, + } + } + AggregateExpression::Max { expr, distinct } => { + AggregateExpression::Max { + expr: Box::new(self.rewrite_expression(expr)), + distinct: *distinct, + } + } + AggregateExpression::Avg { expr, distinct } => { + AggregateExpression::Avg { + expr: Box::new(self.rewrite_expression(expr)), + distinct: *distinct, + } + } + AggregateExpression::Sample { expr, distinct } => { + AggregateExpression::Sample { + expr: Box::new(self.rewrite_expression(expr)), + distinct: *distinct, + } + } + AggregateExpression::GroupConcat { + expr, + distinct, + separator, + } => AggregateExpression::GroupConcat { + expr: Box::new(self.rewrite_expression(expr)), + distinct: *distinct, + separator: separator.clone(), + }, + AggregateExpression::Custom { + name, + expr, + distinct, + } => AggregateExpression::Custom { + name: name.clone(), + expr: Box::new(self.rewrite_expression(expr)), + distinct: *distinct, + }, + }, + ) + }) + .collect(), ), GraphPattern::Service { inner, silent, name, } => GraphPattern::service(self.rewrite_graph_pattern(inner), name.clone(), *silent), - GraphPattern::FixedPoint { .. } => todo!(), + GraphPattern::FixedPoint { .. } => unreachable!(), + } + } + + fn rewrite_expression(&self, expression: &Expression) -> Expression { + match expression { + Expression::NamedNode(node) => node.clone().into(), + Expression::Literal(literal) => literal.clone().into(), + Expression::Variable(variable) => variable.clone().into(), + Expression::Or(left, right) => Expression::or( + self.rewrite_expression(left), + self.rewrite_expression(right), + ), + Expression::And(left, right) => Expression::and( + self.rewrite_expression(left), + self.rewrite_expression(right), + ), + Expression::Equal(left, right) => Expression::equal( + self.rewrite_expression(left), + self.rewrite_expression(right), + ), + Expression::SameTerm(left, right) => Expression::same_term( + self.rewrite_expression(left), + self.rewrite_expression(right), + ), + Expression::Greater(left, right) => Expression::greater( + self.rewrite_expression(left), + self.rewrite_expression(right), + ), + Expression::GreaterOrEqual(left, right) => Expression::greater_or_equal( + self.rewrite_expression(left), + self.rewrite_expression(right), + ), + Expression::Less(left, right) => Expression::less( + self.rewrite_expression(left), + self.rewrite_expression(right), + ), + Expression::LessOrEqual(left, right) => Expression::less_or_equal( + self.rewrite_expression(left), + self.rewrite_expression(right), + ), + Expression::Add(left, right) => Expression::add( + self.rewrite_expression(left), + self.rewrite_expression(right), + ), + Expression::Subtract(left, right) => Expression::subtract( + self.rewrite_expression(left), + self.rewrite_expression(right), + ), + Expression::Multiply(left, right) => Expression::multiply( + self.rewrite_expression(left), + self.rewrite_expression(right), + ), + Expression::Divide(left, right) => Expression::divide( + self.rewrite_expression(left), + self.rewrite_expression(right), + ), + Expression::UnaryPlus(inner) => Expression::unary_plus(self.rewrite_expression(inner)), + Expression::UnaryMinus(inner) => { + Expression::unary_minus(self.rewrite_expression(inner)) + } + Expression::Not(inner) => Expression::not(self.rewrite_expression(inner)), + Expression::Exists(inner) => Expression::exists(self.rewrite_graph_pattern(inner)), + Expression::Bound(variable) => Expression::Bound(variable.clone()), + Expression::If(cond, then, els) => Expression::if_cond( + self.rewrite_expression(cond), + self.rewrite_expression(then), + self.rewrite_expression(els), + ), + Expression::Coalesce(inners) => Expression::coalesce( + inners + .into_iter() + .map(|a| self.rewrite_expression(a)) + .collect(), + ), + Expression::FunctionCall(name, args) => Expression::call( + name.clone(), + args.into_iter() + .map(|a| self.rewrite_expression(a)) + .collect(), + ), } }