// Patch overview. These notes sit before the first `diff --git` header; the patch text below is
// left exactly as found.
//
// NOTE(review): in this copy the hunks are collapsed onto a few physical lines and several generic
// argument lists look stripped (e.g. `&[Option]`, `-> Vec`, `Box,`) -- presumably an extraction
// artefact; confirm against the original commit before trying to apply or compile anything.
//
// lib/src/sparql/plan.rs
//   * Imports `std::collections::BTreeSet`.
//   * `PlanNode::LeftJoin` gains a `possible_problem_vars` field.
//   * New `PlanNode::variables()` returns a `BTreeSet` filled by `PlanNode::add_variables()`, which
//     walks the plan: bound positions of static tuples, variable positions of quad patterns,
//     the `Extend` position, `0..mapping.len()` for `Project`, and the children of every other
//     node. `PlanExpression::add_variables()` does the same for expressions (`Variable` and
//     `Bound` insert their position; constants and `BNode(None)` add nothing; every other
//     variant recurses into its operands).
//   * `PlanBuilder` now builds the left and right plans first, wraps the right plan in a `Filter`
//     unless the join expression equals the literal `true`, and sets `possible_problem_vars` to
//     `right.variables()` minus `left.variables()`.
//
// lib/src/sparql/eval.rs
//   * `bind_variables_in_set` keeps the entries of `set` that are in range of `binding` and bound
//     there; `unbind_variables` resets the given in-range positions to `None`.
//   * Evaluating `LeftJoin` computes the problem variables bound in the incoming tuple, evaluates
//     the left plan on a clone of that tuple with those positions unbound, and wraps the resulting
//     `LeftJoinIterator` in `BadLeftJoinIterator` only when at least one such variable exists.
//   * `BadLeftJoinIterator::next` skips a result whose value for a problem variable differs from
//     the input value, writes the input value into the result where the result has none, and
//     returns errors from the inner iterator unchanged.
//
// lib/tests/sparql_test_cases.rs
//   * Removes the `nested-opt-1` and `nested-opt-2` entries (commented "Bad nested optionals")
//     from `test_blacklist`.
//
// NOTE(review): in `BadLeftJoinIterator::next` the `continue` after `conflict = true` advances the
//   inner `for` over `problem_vars`, not the outer `loop`; the tuple is rejected by the later
//   `!conflict` check. A `break` looks sufficient there -- confirm intent.
// NOTE(review): the `PlanNode::Project { child, mapping }` arm of `add_variables` binds `child`
//   but never uses it; this presumably triggers an unused-variable warning.
// NOTE(review): the `Filter` arm carries `//TODO: condition vars` although it already calls
//   `expression.add_variables(set)`; the TODO may be stale -- verify.
// NOTE(review): `bind_variables_in_set` calls `into_iter()` on a slice reference, which iterates
//   by reference just like `iter()`; `iter()` would state that more directly.
diff --git a/lib/src/sparql/eval.rs b/lib/src/sparql/eval.rs index e86dd36c..068adca6 100644 --- a/lib/src/sparql/eval.rs +++ b/lib/src/sparql/eval.rs @@ -160,12 +160,30 @@ impl SimpleEvaluator { buffered_results: errors, }) } - PlanNode::LeftJoin { left, right } => Box::new(LeftJoinIterator { - eval: self.clone(), - right_plan: *right, - left_iter: self.eval_plan(*left, from), - current_right_iter: None, - }), + PlanNode::LeftJoin { + left, + right, + possible_problem_vars, + } => { + let problem_vars = bind_variables_in_set(&from, &possible_problem_vars); + let mut filtered_from = from.clone(); + unbind_variables(&mut filtered_from, &problem_vars); + let iter = LeftJoinIterator { + eval: self.clone(), + right_plan: *right, + left_iter: self.eval_plan(*left, filtered_from), + current_right_iter: None, + }; + if problem_vars.is_empty() { + Box::new(iter) + } else { + Box::new(BadLeftJoinIterator { + input: from, + iter, + problem_vars, + }) + } + } PlanNode::Filter { child, expression } => { let eval = self.clone(); Box::new(self.eval_plan(*child, from).filter(move |tuple| { @@ -635,6 +653,21 @@ fn put_value(position: usize, value: EncodedTerm, tuple: &mut EncodedTuple) { } } +fn bind_variables_in_set(binding: &[Option], set: &[usize]) -> Vec { + set.into_iter() + .cloned() + .filter(|key| *key < binding.len() && binding[*key].is_some()) + .collect() +} + +fn unbind_variables(binding: &mut [Option], variables: &[usize]) { + for var in variables { + if *var < binding.len() { + binding[*var] = None + } + } +} + fn combine_tuples(a: &[Option], b: &[Option]) -> Option { if a.len() < b.len() { let mut result = b.to_owned(); @@ -725,6 +758,42 @@ impl Iterator for LeftJoinIterator { } } +struct BadLeftJoinIterator { + input: EncodedTuple, + iter: LeftJoinIterator, + problem_vars: Vec, +} + +impl Iterator for BadLeftJoinIterator { + type Item = Result; + + fn next(&mut self) -> Option> { + loop { + match self.iter.next()? 
{ + Ok(mut tuple) => { + let mut conflict = false; + for problem_var in &self.problem_vars { + if let Some(input_value) = self.input[*problem_var] { + if let Some(result_value) = get_tuple_value(*problem_var, &tuple) { + if input_value != result_value { + conflict = true; + continue; //Binding conflict + } + } else { + put_value(*problem_var, input_value, &mut tuple); + } + } + } + if !conflict { + return Some(Ok(tuple)); + } + } + Err(error) => return Some(Err(error)), + } + } + } +} + struct UnionIterator { eval: SimpleEvaluator, children_plan: Vec, diff --git a/lib/src/sparql/plan.rs b/lib/src/sparql/plan.rs index c8c7e5ba..0832fd85 100644 --- a/lib/src/sparql/plan.rs +++ b/lib/src/sparql/plan.rs @@ -1,6 +1,7 @@ use model::vocab::xsd; use model::Literal; use sparql::algebra::*; +use std::collections::BTreeSet; use store::encoded::EncodedQuadsStore; use store::numeric_encoder::EncodedTerm; use Result; @@ -35,6 +36,7 @@ pub enum PlanNode { LeftJoin { left: Box, right: Box, + possible_problem_vars: Vec, //Variables that should not be part of the entry of the left join }, Extend { child: Box, @@ -58,6 +60,82 @@ pub enum PlanNode { }, } +impl PlanNode { + fn variables(&self) -> BTreeSet { + let mut set = BTreeSet::default(); + self.add_variables(&mut set); + set + } + + fn add_variables(&self, set: &mut BTreeSet) { + match self { + PlanNode::Init => (), + PlanNode::StaticBindings { tuples } => { + for tuple in tuples { + for (key, value) in tuple.into_iter().enumerate() { + if value.is_some() { + set.insert(key); + } + } + } + } + PlanNode::QuadPatternJoin { + child, + subject, + predicate, + object, + graph_name, + } => { + if let PatternValue::Variable(var) = subject { + set.insert(*var); + } + if let PatternValue::Variable(var) = predicate { + set.insert(*var); + } + if let PatternValue::Variable(var) = object { + set.insert(*var); + } + if let Some(PatternValue::Variable(var)) = graph_name { + set.insert(*var); + } + child.add_variables(set); + } + 
PlanNode::Filter { child, expression } => { + child.add_variables(set); + expression.add_variables(set); + } //TODO: condition vars + PlanNode::Union { entry, children } => { + entry.add_variables(set); + for child in children { + child.add_variables(set); + } + } + PlanNode::Join { left, right } => { + left.add_variables(set); + right.add_variables(set); + } + PlanNode::LeftJoin { left, right, .. } => { + left.add_variables(set); + right.add_variables(set); + } + PlanNode::Extend { + child, position, .. + } => { + set.insert(*position); + child.add_variables(set); + } + PlanNode::HashDeduplicate { child } => child.add_variables(set), + PlanNode::Skip { child, .. } => child.add_variables(set), + PlanNode::Limit { child, .. } => child.add_variables(set), + PlanNode::Project { child, mapping } => { + for i in 0..mapping.len() { + set.insert(i); + } + } + } + } +} + #[derive(Eq, PartialEq, Debug, Clone, Copy, Hash)] pub enum PatternValue { Constant(EncodedTerm), @@ -161,6 +239,61 @@ pub enum PlanExpression { StringCast(Box), } +impl PlanExpression { + fn add_variables(&self, set: &mut BTreeSet) { + match self { + PlanExpression::Constant(_) | PlanExpression::BNode(None) => (), + PlanExpression::Variable(v) | PlanExpression::Bound(v) => { + set.insert(*v); + } + PlanExpression::Or(a, b) + | PlanExpression::And(a, b) + | PlanExpression::Equal(a, b) + | PlanExpression::NotEqual(a, b) + | PlanExpression::Greater(a, b) + | PlanExpression::GreaterOrEq(a, b) + | PlanExpression::Lower(a, b) + | PlanExpression::LowerOrEq(a, b) + | PlanExpression::Add(a, b) + | PlanExpression::Sub(a, b) + | PlanExpression::Mul(a, b) + | PlanExpression::Div(a, b) + | PlanExpression::SameTerm(a, b) + | PlanExpression::LangMatches(a, b) + | PlanExpression::Regex(a, b, None) => { + a.add_variables(set); + b.add_variables(set); + } + PlanExpression::UnaryPlus(e) + | PlanExpression::UnaryMinus(e) + | PlanExpression::UnaryNot(e) + | PlanExpression::Str(e) + | PlanExpression::Lang(e) + | 
PlanExpression::Datatype(e) + | PlanExpression::IRI(e) + | PlanExpression::BNode(Some(e)) + | PlanExpression::IsIRI(e) + | PlanExpression::IsBlank(e) + | PlanExpression::IsLiteral(e) + | PlanExpression::IsNumeric(e) + | PlanExpression::BooleanCast(e) + | PlanExpression::DoubleCast(e) + | PlanExpression::FloatCast(e) + | PlanExpression::IntegerCast(e) + | PlanExpression::DecimalCast(e) + | PlanExpression::DateTimeCast(e) + | PlanExpression::StringCast(e) => { + e.add_variables(set); + } + PlanExpression::Regex(a, b, Some(c)) => { + a.add_variables(set); + b.add_variables(set); + c.add_variables(set); + } + } + } +} + pub struct PlanBuilder<'a, S: EncodedQuadsStore> { store: &'a S, } @@ -216,22 +349,28 @@ impl<'a, S: EncodedQuadsStore> PlanBuilder<'a, S> { right: Box::new(self.build_for_graph_pattern(b, input, variables, graph_name)?), }, GraphPattern::LeftJoin(a, b, e) => { - let right = Box::new(self.build_for_graph_pattern( - b, - PlanNode::Init, - variables, - graph_name, - )?); + let left = self.build_for_graph_pattern(a, input, variables, graph_name)?; + let right = + self.build_for_graph_pattern(b, PlanNode::Init, variables, graph_name)?; + //We add the extra filter if needed + let right = if *e == Expression::from(Literal::from(true)) { + right + } else { + PlanNode::Filter { + child: Box::new(right), + expression: self.build_for_expression(e, variables)?, + } + }; + let possible_problem_vars = right + .variables() + .difference(&left.variables()) + .cloned() + .collect(); + PlanNode::LeftJoin { - left: Box::new(self.build_for_graph_pattern(a, input, variables, graph_name)?), - right: if *e == Expression::from(Literal::from(true)) { - right - } else { - Box::new(PlanNode::Filter { - child: right, - expression: self.build_for_expression(e, variables)?, - }) - }, + left: Box::new(left), + right: Box::new(right), + possible_problem_vars, } } GraphPattern::Filter(e, p) => PlanNode::Filter { diff --git a/lib/tests/sparql_test_cases.rs 
b/lib/tests/sparql_test_cases.rs index 5f1662f2..d7f5e47e 100644 --- a/lib/tests/sparql_test_cases.rs +++ b/lib/tests/sparql_test_cases.rs @@ -100,13 +100,6 @@ fn sparql_w3c_query_evaluation_testsuite() { .unwrap(), ]; let test_blacklist = vec![ - // Bad nested optionals - NamedNode::from_str( - "http://www.w3.org/2001/sw/DataAccess/tests/data-r2/algebra/manifest#nested-opt-1", - ).unwrap(), - NamedNode::from_str( - "http://www.w3.org/2001/sw/DataAccess/tests/data-r2/algebra/manifest#nested-opt-2", - ).unwrap(), //Multiple writing of the same xsd:integer. Our system does strong normalization. NamedNode::from_str( "http://www.w3.org/2001/sw/DataAccess/tests/data-r2/distinct/manifest#distinct-1",