diff --git a/src/graph/util/ExpressionUtils.cpp b/src/graph/util/ExpressionUtils.cpp index cb20cfcadf0..476e3381416 100644 --- a/src/graph/util/ExpressionUtils.cpp +++ b/src/graph/util/ExpressionUtils.cpp @@ -758,6 +758,14 @@ void ExpressionUtils::pullOrs(Expression *expr) { logic->setOperands(std::move(operands)); } +void ExpressionUtils::pullXors(Expression *expr) { + DCHECK(expr->kind() == Expression::Kind::kLogicalXor); + auto *logic = static_cast(expr); + std::vector operands; + pullXorsImpl(logic, operands); + logic->setOperands(std::move(operands)); +} + void ExpressionUtils::pullAndsImpl(LogicalExpression *expr, std::vector &operands) { for (auto &operand : expr->operands()) { if (operand->kind() != Expression::Kind::kLogicalAnd) { @@ -778,6 +786,16 @@ void ExpressionUtils::pullOrsImpl(LogicalExpression *expr, std::vector &operands) { + for (auto &operand : expr->operands()) { + if (operand->kind() != Expression::Kind::kLogicalXor) { + operands.emplace_back(std::move(operand)); + continue; + } + pullXorsImpl(static_cast(operand), operands); + } +} + Expression *ExpressionUtils::flattenInnerLogicalAndExpr(const Expression *expr) { auto matcher = [](const Expression *e) -> bool { return e->kind() == Expression::Kind::kLogicalAnd; @@ -1094,8 +1112,12 @@ LogicalExpression *ExpressionUtils::reverseLogicalExpr(LogicalExpression *expr) std::vector operands; if (expr->kind() == Expression::Kind::kLogicalAnd) { pullAnds(expr); - } else { + } else if (expr->kind() == Expression::Kind::kLogicalOr) { pullOrs(expr); + } else if (expr->kind() == Expression::Kind::kLogicalXor) { + pullXors(expr); + } else { + LOG(FATAL) << "Invalid logical expression kind: " << static_cast(expr->kind()); } auto &flattenOperands = static_cast(expr)->operands(); @@ -1118,8 +1140,7 @@ Expression::Kind ExpressionUtils::getNegatedLogicalExprKind(const Expression::Ki case Expression::Kind::kLogicalOr: return Expression::Kind::kLogicalAnd; case Expression::Kind::kLogicalXor: - LOG(FATAL) << "Unsupported logical expression kind: " << static_cast(kind); - break; + return Expression::Kind::kLogicalXor; default: LOG(FATAL) << "Invalid logical expression kind: " << static_cast(kind); break; diff --git a/src/graph/util/ExpressionUtils.h b/src/graph/util/ExpressionUtils.h index 29f3e96cc72..e762b4160ae 100644 --- a/src/graph/util/ExpressionUtils.h +++ b/src/graph/util/ExpressionUtils.h @@ -153,6 +153,11 @@ class ExpressionUtils { static void pullOrs(Expression* expr); static void pullOrsImpl(LogicalExpression* expr, std::vector& operands); + // For a logical XOR expression, extracts all non-logicalXorExpr from its operands and set them as + // the new operands + static void pullXors(Expression* expr); + static void pullXorsImpl(LogicalExpression* expr, std::vector& operands); + // Constructs a nested logical OR expression // Example: // [expr1, expr2, expr3] => ((expr1 OR expr2) OR expr3) diff --git a/tests/tck/features/expression/LogicalExpression.feature b/tests/tck/features/expression/LogicalExpression.feature new file mode 100644 index 00000000000..f299cb4fe74 --- /dev/null +++ b/tests/tck/features/expression/LogicalExpression.feature @@ -0,0 +1,21 @@ +# Copyright (c) 2020 vesoft inc. All rights reserved. +# +# This source code is licensed under Apache 2.0 License. +Feature: Logical Expression + + Scenario: xor crash bug fix 1 + Given a graph with space named "nba" + When executing query: + """ + match (v0:player)-[e:serve]->(v1) where not ((e.start_year == 1997 xor e.end_year != 2016) or (e.start_year > 1000 and e.end_year < 3000)) return count(*) + """ + Then the result should be, in any order: + | count(*) | + | 0 | + When executing query: + """ + match (v0:player)-[e:serve]->(v1) where not ((e.start_year == 1997 xor e.end_year != 2016) and (e.start_year > 1000 and e.end_year < 3000)) return count(*) + """ + Then the result should be, in any order: + | count(*) | + | 140 |