diff --git a/src/storm/storage/expressions/PredicateExpression.cpp b/src/storm/storage/expressions/PredicateExpression.cpp index 453a776ab..0ce7565c9 100644 --- a/src/storm/storage/expressions/PredicateExpression.cpp +++ b/src/storm/storage/expressions/PredicateExpression.cpp @@ -1,5 +1,6 @@ #include "storm/storage/expressions/PredicateExpression.h" +#include "storm/storage/expressions/BooleanLiteralExpression.h" #include "storm/storage/expressions/ExpressionVisitor.h" #include "storm/utility/macros.h" @@ -42,11 +43,47 @@ namespace storm { std::shared_ptr PredicateExpression::simplify() const { std::vector> simplifiedOperands; + uint64_t trueCount = 0; for (auto const& operand : operands) { - simplifiedOperands.push_back(operand->simplify()); + auto res = operand->simplify(); + if (res->isLiteral()) { + if (res->isTrue()) { + if (predicate == PredicateType::AtLeastOneOf) { + return res; + } else { + assert(predicate == PredicateType::AtMostOneOf || predicate == PredicateType::ExactlyOneOf); + trueCount++; + simplifiedOperands.push_back(res); + } + } else { + assert (res->isFalse()); + assert (predicate == PredicateType::AtMostOneOf || predicate == PredicateType::AtLeastOneOf || predicate == PredicateType::ExactlyOneOf); + // do nothing, in particular, do not add. + } + } else { + simplifiedOperands.push_back(res); + } + } + if (trueCount > 1 && (predicate == PredicateType::ExactlyOneOf || predicate == PredicateType::AtMostOneOf)) { + return std::shared_ptr(new BooleanLiteralExpression(this->getManager(), + false)); + } + + if (simplifiedOperands.size() == 0) { + switch(predicate) { + case PredicateType::ExactlyOneOf: return std::shared_ptr(new BooleanLiteralExpression(this->getManager(), + false)); + case PredicateType::AtLeastOneOf: return std::shared_ptr(new BooleanLiteralExpression(this->getManager(), + false)); + case PredicateType::AtMostOneOf: return std::shared_ptr(new BooleanLiteralExpression(this->getManager(), + true)); + } } // Return new expression if something changed. - for (uint64_t i = 0; i < operands.size(); ++i) { + if (simplifiedOperands.size() != operands.size()) { + return std::shared_ptr(new PredicateExpression(this->getManager(), this->getType(), simplifiedOperands, predicate)); + } + for (uint64_t i = 0; i < simplifiedOperands.size(); ++i) { if (operands[i] != simplifiedOperands[i]) { return std::shared_ptr(new PredicateExpression(this->getManager(), this->getType(), simplifiedOperands, predicate)); }