diff --git a/CHANGELOG.md b/CHANGELOG.md index ed86c3df0..fcb656e7a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,12 @@ Changelog This changelog lists only the most important changes. Smaller (bug)fixes as well as non-mature features are not part of the changelog. The releases of major and minor versions contain an overview of changes since the last major/minor update. +Branch Changes +-------------- + +- n-ary predicates like atMostOneOf, ExactlyOneOf added +- export to Dice expressions added + Version 1.6.x ------------- diff --git a/src/storm-parsers/parser/ExpressionCreator.cpp b/src/storm-parsers/parser/ExpressionCreator.cpp index 13ffea01b..4a491e4ec 100644 --- a/src/storm-parsers/parser/ExpressionCreator.cpp +++ b/src/storm-parsers/parser/ExpressionCreator.cpp @@ -230,6 +230,21 @@ namespace storm { } return manager.boolean(false); } + + storm::expressions::Expression ExpressionCreator::createPredicateExpression(storm::expressions::OperatorType const& opTyp, std::vector const&operands, bool &pass) const { + if (this->createExpressions) { + try { + switch (opTyp) { + case storm::expressions::OperatorType::AtLeastOneOf: return storm::expressions::atLeastOneOf(operands); + case storm::expressions::OperatorType::AtMostOneOf: return storm::expressions::atMostOneOf(operands); + case storm::expressions::OperatorType::ExactlyOneOf: return storm::expressions::exactlyOneOf(operands); + } + } catch (storm::exceptions::InvalidTypeException const& e) { + pass = false; + } + } + return manager.boolean(false); + } storm::expressions::Expression ExpressionCreator::getIdentifierExpression(std::string const& identifier, bool& pass) const { if (this->createExpressions) { diff --git a/src/storm-parsers/parser/ExpressionCreator.h b/src/storm-parsers/parser/ExpressionCreator.h index 2c6f31894..21170bbe8 100644 --- a/src/storm-parsers/parser/ExpressionCreator.h +++ b/src/storm-parsers/parser/ExpressionCreator.h @@ -72,6 +72,7 @@ namespace storm { storm::expressions::Expression createFloorCeilExpression(storm::expressions::OperatorType const& operatorType, storm::expressions::Expression const& e1, bool& pass) const; storm::expressions::Expression createRoundExpression(storm::expressions::Expression const& e1, bool& pass) const; storm::expressions::Expression getIdentifierExpression(std::string const& identifier, bool& pass) const; + storm::expressions::Expression createPredicateExpression(storm::expressions::OperatorType const& opTyp, std::vector const& operands, bool& pass) const; private: diff --git a/src/storm-parsers/parser/ExpressionParser.cpp b/src/storm-parsers/parser/ExpressionParser.cpp index 4c8e0b676..0e42e156f 100644 --- a/src/storm-parsers/parser/ExpressionParser.cpp +++ b/src/storm-parsers/parser/ExpressionParser.cpp @@ -44,7 +44,14 @@ namespace storm { identifier %= qi::as_string[qi::raw[qi::lexeme[((qi::alpha | qi::char_('_') | qi::char_('.')) >> *(qi::alnum | qi::char_('_')))]]][qi::_pass = phoenix::bind(&ExpressionParser::isValidIdentifier, phoenix::ref(*this), qi::_1)]; identifier.name("identifier"); - + + if (allowBacktracking) { + predicateExpression = ((predicateOperator_ >> qi::lit("(")) >> (expression % qi::lit(",") ) >> qi::lit(")"))[qi::_val = phoenix::bind(&ExpressionCreator::createPredicateExpression, phoenix::ref(*expressionCreator), qi::_1, qi::_2, qi::_pass)]; + } else { + predicateExpression = ((predicateOperator_ >> qi::lit("(")) > (expression % qi::lit(",") ) > qi::lit(")"))[qi::_val = phoenix::bind(&ExpressionCreator::createPredicateExpression, phoenix::ref(*expressionCreator), qi::_1, qi::_2, qi::_pass)]; + } + predicateExpression.name("predicate expression"); + if (allowBacktracking) { floorCeilExpression = ((floorCeilOperator_ >> qi::lit("(")) >> expression >> qi::lit(")"))[qi::_val = phoenix::bind(&ExpressionCreator::createFloorCeilExpression, phoenix::ref(*expressionCreator), qi::_1, qi::_2, qi::_pass)]; } else { @@ -84,7 +91,7 @@ namespace storm { | qi::long_long[qi::_val = phoenix::bind(&ExpressionCreator::createIntegerLiteralExpression, phoenix::ref(*expressionCreator), qi::_1, qi::_pass)]; literalExpression.name("literal expression"); - atomicExpression = floorCeilExpression | roundExpression | prefixPowerModuloExpression | minMaxExpression | (qi::lit("(") >> expression >> qi::lit(")")) | identifierExpression | literalExpression; + atomicExpression = predicateExpression | floorCeilExpression | roundExpression | prefixPowerModuloExpression | minMaxExpression | (qi::lit("(") >> expression >> qi::lit(")")) | identifierExpression | literalExpression; atomicExpression.name("atomic expression"); unaryExpression = (*unaryOperator_ >> atomicExpression)[qi::_val = phoenix::bind(&ExpressionCreator::createUnaryExpression, phoenix::ref(*expressionCreator), qi::_1, qi::_2, qi::_pass)]; diff --git a/src/storm-parsers/parser/ExpressionParser.h b/src/storm-parsers/parser/ExpressionParser.h index ff0bb69ed..4fe673d8b 100644 --- a/src/storm-parsers/parser/ExpressionParser.h +++ b/src/storm-parsers/parser/ExpressionParser.h @@ -211,7 +211,19 @@ namespace storm { // A parser used for recognizing the operators at the "power" precedence level. prefixPowerModuloOperatorStruct prefixPowerModuloOperator_; - + + struct predicateOperatorStruct : qi::symbols { + predicateOperatorStruct() { + add + ("atLeastOneOf", storm::expressions::OperatorType::AtLeastOneOf) + ("atMostOneOf", storm::expressions::OperatorType::AtMostOneOf) + ("exactlyOneOf", storm::expressions::OperatorType::ExactlyOneOf); + } + }; + + // A parser used for recognizing the operators at the "min/max" precedence level. + predicateOperatorStruct predicateOperator_; + std::unique_ptr expressionCreator; @@ -237,6 +249,7 @@ namespace storm { qi::rule, Skipper> minMaxExpression; qi::rule, Skipper> floorCeilExpression; qi::rule roundExpression; + qi::rule predicateExpression; qi::rule identifier; // Parser that is used to recognize doubles only (as opposed to Spirit's double_ parser). diff --git a/src/storm-parsers/parser/PrismParser.h b/src/storm-parsers/parser/PrismParser.h index 2668dcb68..72611b0dc 100644 --- a/src/storm-parsers/parser/PrismParser.h +++ b/src/storm-parsers/parser/PrismParser.h @@ -128,10 +128,13 @@ namespace storm { ("max", 18) ("floor", 19) ("ceil", 20) - ("init", 21) - ("endinit", 22) - ("invariant", 23) - ("endinvariant", 24); + ("atLeastOneOf", 21) + ("atMostOneOf", 22) + ("exactlyOneOf", 23) + ("init", 24) + ("endinit", 25) + ("invariant", 26) + ("endinvariant", 27); } }; diff --git a/src/storm/storage/expressions/BaseExpression.cpp b/src/storm/storage/expressions/BaseExpression.cpp index b6f652e3f..500dcdbc2 100644 --- a/src/storm/storage/expressions/BaseExpression.cpp +++ b/src/storm/storage/expressions/BaseExpression.cpp @@ -191,6 +191,14 @@ namespace storm { VariableExpression const& BaseExpression::asVariableExpression() const { return static_cast(*this); } + + bool BaseExpression::isPredicateExpression() const { + return false; + } + + PredicateExpression const& BaseExpression::asPredicateExpression() const { + return static_cast(*this); + } std::ostream& operator<<(std::ostream& stream, BaseExpression const& expression) { expression.printToStream(stream); diff --git a/src/storm/storage/expressions/BaseExpression.h b/src/storm/storage/expressions/BaseExpression.h index 666c4d40b..cd91a6c19 100644 --- a/src/storm/storage/expressions/BaseExpression.h +++ b/src/storm/storage/expressions/BaseExpression.h @@ -32,6 +32,7 @@ namespace storm { class UnaryBooleanFunctionExpression; class UnaryNumericalFunctionExpression; class VariableExpression; + class PredicateExpression; /*! * The base class of all expression classes. @@ -286,6 +287,9 @@ namespace storm { virtual bool isVariableExpression() const; VariableExpression const& asVariableExpression() const; + + virtual bool isPredicateExpression() const; + PredicateExpression const& asPredicateExpression() const; protected: /*! diff --git a/src/storm/storage/expressions/Expression.cpp b/src/storm/storage/expressions/Expression.cpp index 873e21fc2..6ea792b75 100644 --- a/src/storm/storage/expressions/Expression.cpp +++ b/src/storm/storage/expressions/Expression.cpp @@ -13,6 +13,7 @@ #include "storm/exceptions/InvalidTypeException.h" #include "storm/exceptions/InvalidArgumentException.h" #include "storm/utility/macros.h" +#include "storm/storage/expressions/SimplificationVisitor.h" namespace storm { namespace expressions { @@ -53,6 +54,10 @@ namespace storm { return SubstitutionVisitor>(identifierToExpressionMap).substitute(*this); } + Expression Expression::substituteNonStandardPredicates() const { + return SimplificationVisitor().substitute(*this); + } + bool Expression::evaluateAsBool(Valuation const* valuation) const { return this->getBaseExpression().evaluateAsBool(valuation); } @@ -438,6 +443,33 @@ namespace storm { return ite(first < 0, floor(first), ceil(first)); } + Expression atLeastOneOf(std::vector const& expressions) { + STORM_LOG_THROW(expressions.size() > 0, storm::exceptions::InvalidArgumentException, "AtLeastOneOf requires arguments"); + std::vector> baseexpressions; + for(auto const& expr : expressions) { + baseexpressions.push_back(expr.getBaseExpressionPointer()); + } + return Expression(std::shared_ptr(new PredicateExpression(expressions.front().getManager(), expressions.front().getManager().getBooleanType(), baseexpressions, PredicateExpression::PredicateType::AtLeastOneOf))); + } + + Expression atMostOneOf(std::vector const& expressions) { + STORM_LOG_THROW(expressions.size() > 0, storm::exceptions::InvalidArgumentException, "AtMostOneOf requires arguments"); + std::vector> baseexpressions; + for(auto const& expr : expressions) { + baseexpressions.push_back(expr.getBaseExpressionPointer()); + } + return Expression(std::shared_ptr(new PredicateExpression(expressions.front().getManager(), expressions.front().getManager().getBooleanType(), baseexpressions, PredicateExpression::PredicateType::AtMostOneOf))); + } + + Expression exactlyOneOf(std::vector const& expressions) { + STORM_LOG_THROW(expressions.size() > 0, storm::exceptions::InvalidArgumentException, "ExactlyOneOf requires arguments"); + std::vector> baseexpressions; + for(auto const& expr : expressions) { + baseexpressions.push_back(expr.getBaseExpressionPointer()); + } + return Expression(std::shared_ptr(new PredicateExpression(expressions.front().getManager(), expressions.front().getManager().getBooleanType(), baseexpressions, PredicateExpression::PredicateType::ExactlyOneOf))); + } + Expression disjunction(std::vector const& expressions) { return applyAssociative(expressions, [] (Expression const& e1, Expression const& e2) { return e1 || e2; }); } diff --git a/src/storm/storage/expressions/Expression.h b/src/storm/storage/expressions/Expression.h index 22b07734e..eb4b2c9fb 100644 --- a/src/storm/storage/expressions/Expression.h +++ b/src/storm/storage/expressions/Expression.h @@ -60,6 +60,7 @@ namespace storm { friend Expression minimum(Expression const& first, Expression const& second); friend Expression maximum(Expression const& first, Expression const& second); + Expression() = default; ~Expression(); @@ -99,6 +100,11 @@ namespace storm { */ Expression substitute(std::map const& variableToExpressionMap) const; + /*! + * Eliminate nonstandard predicates from the expression. + * @return + */ + Expression substituteNonStandardPredicates() const; /*! * Substitutes all occurrences of the variables according to the given map. Note that this substitution is * done simultaneously, i.e., variables appearing in the expressions that were "plugged in" are not @@ -439,6 +445,9 @@ namespace storm { Expression modulo(Expression const& first, Expression const& second); Expression minimum(Expression const& first, Expression const& second); Expression maximum(Expression const& first, Expression const& second); + Expression atLeastOneOf(std::vector const& expressions); + Expression atMostOneOf(std::vector const& expressions); + Expression exactlyOneOf(std::vector const& expressions); Expression disjunction(std::vector const& expressions); Expression conjunction(std::vector const& expressions); Expression sum(std::vector const& expressions); diff --git a/src/storm/storage/expressions/ExpressionVisitor.cpp b/src/storm/storage/expressions/ExpressionVisitor.cpp new file mode 100644 index 000000000..3e95f2c75 --- /dev/null +++ b/src/storm/storage/expressions/ExpressionVisitor.cpp @@ -0,0 +1,11 @@ +#include "storm/storage/expressions/ExpressionVisitor.h" +#include "storm/utility/macros.h" +#include "storm/exceptions/NotImplementedException.h" + +namespace storm { + namespace expressions { + boost::any ExpressionVisitor::visit(PredicateExpression const&, boost::any const&) { + STORM_LOG_THROW(false,storm::exceptions::NotImplementedException, "Predicate Expressions are not supported by this visitor"); + } + } +} \ No newline at end of file diff --git a/src/storm/storage/expressions/ExpressionVisitor.h b/src/storm/storage/expressions/ExpressionVisitor.h index 8ef1ea90d..367178426 100644 --- a/src/storm/storage/expressions/ExpressionVisitor.h +++ b/src/storm/storage/expressions/ExpressionVisitor.h @@ -17,6 +17,7 @@ namespace storm { class BooleanLiteralExpression; class IntegerLiteralExpression; class RationalLiteralExpression; + class PredicateExpression; class ExpressionVisitor { public: @@ -32,6 +33,7 @@ namespace storm { virtual boost::any visit(BooleanLiteralExpression const& expression, boost::any const& data) = 0; virtual boost::any visit(IntegerLiteralExpression const& expression, boost::any const& data) = 0; virtual boost::any visit(RationalLiteralExpression const& expression, boost::any const& data) = 0; + virtual boost::any visit(PredicateExpression const& expression, boost::any const& data); }; } } diff --git a/src/storm/storage/expressions/Expressions.h b/src/storm/storage/expressions/Expressions.h index 272810509..bcbf2b76a 100644 --- a/src/storm/storage/expressions/Expressions.h +++ b/src/storm/storage/expressions/Expressions.h @@ -8,4 +8,5 @@ #include "storm/storage/expressions/UnaryBooleanFunctionExpression.h" #include "storm/storage/expressions/UnaryNumericalFunctionExpression.h" #include "storm/storage/expressions/VariableExpression.h" +#include "storm/storage/expressions/PredicateExpression.h" #include "storm/storage/expressions/Expression.h" diff --git a/src/storm/storage/expressions/OperatorType.cpp b/src/storm/storage/expressions/OperatorType.cpp index bba61a653..a36982465 100644 --- a/src/storm/storage/expressions/OperatorType.cpp +++ b/src/storm/storage/expressions/OperatorType.cpp @@ -27,6 +27,9 @@ namespace storm { case OperatorType::Floor: stream << "floor"; break; case OperatorType::Ceil: stream << "ceil"; break; case OperatorType::Ite: stream << "ite"; break; + case OperatorType::AtMostOneOf: stream << "atMostOneOf"; break; + case OperatorType::AtLeastOneOf: stream << "atLeastOneOf"; break; + case OperatorType::ExactlyOneOf: stream << "exactlyOneOf"; break; } return stream; } diff --git a/src/storm/storage/expressions/OperatorType.h b/src/storm/storage/expressions/OperatorType.h index e334b8104..196c900c7 100644 --- a/src/storm/storage/expressions/OperatorType.h +++ b/src/storm/storage/expressions/OperatorType.h @@ -29,7 +29,10 @@ namespace storm { Not, Floor, Ceil, - Ite + Ite, + AtLeastOneOf, + AtMostOneOf, + ExactlyOneOf }; std::ostream& operator<<(std::ostream& stream, OperatorType const& operatorType); diff --git a/src/storm/storage/expressions/PredicateExpression.cpp b/src/storm/storage/expressions/PredicateExpression.cpp new file mode 100644 index 000000000..357a0cb17 --- /dev/null +++ b/src/storm/storage/expressions/PredicateExpression.cpp @@ -0,0 +1,100 @@ + +#include "storm/storage/expressions/PredicateExpression.h" + +#include "storm/storage/expressions/ExpressionVisitor.h" +#include "storm/utility/macros.h" +#include "storm/storage/BitVector.h" +#include "storm/exceptions/InvalidTypeException.h" + +namespace storm { + namespace expressions { + OperatorType toOperatorType(PredicateExpression::PredicateType tp) { + switch (tp) { + case PredicateExpression::PredicateType::AtMostOneOf: return OperatorType::AtMostOneOf; + case PredicateExpression::PredicateType::AtLeastOneOf: return OperatorType::AtLeastOneOf; + case PredicateExpression::PredicateType::ExactlyOneOf: return OperatorType::ExactlyOneOf; + } + STORM_LOG_ASSERT(false, "Predicate type not supported"); + } + + PredicateExpression::PredicateExpression(ExpressionManager const &manager, Type const& type, std::vector > const &operands, PredicateType predicateType) : BaseExpression(manager, type), predicate(predicateType), operands(operands) {} + + // Override base class methods. + storm::expressions::OperatorType PredicateExpression::getOperator() const { + return toOperatorType(predicate); + } + + bool PredicateExpression::evaluateAsBool(Valuation const *valuation) const { + STORM_LOG_THROW(this->hasBooleanType(), storm::exceptions::InvalidTypeException, "Unable to evaluate expression as boolean."); + storm::storage::BitVector results(operands.size()); + uint64_t i = 0; + for(auto const& operand : operands) { + results.set(i, operand->evaluateAsBool(valuation)); + ++i; + } + switch(predicate) { + case PredicateType::ExactlyOneOf: return results.getNumberOfSetBits() == 1; + case PredicateType::AtMostOneOf: return results.getNumberOfSetBits() <= 1; + case PredicateType::AtLeastOneOf: return results.getNumberOfSetBits() >= 1; + } + STORM_LOG_ASSERT(false, "Unknown predicate type"); + } + + std::shared_ptr PredicateExpression::simplify() const { + std::vector> simplifiedOperands; + for (auto const& operand : operands) { + simplifiedOperands.push_back(operand->simplify()); + } + return std::shared_ptr(new PredicateExpression(this->getManager(), this->getType(), simplifiedOperands, predicate)); + } + + boost::any PredicateExpression::accept(ExpressionVisitor &visitor, boost::any const &data) const { + return visitor.visit(*this, data); + } + + bool PredicateExpression::isPredicateExpression() const { + return true; + } + + bool PredicateExpression::isFunctionApplication() const { + return true; + } + + bool PredicateExpression::containsVariables() const { + for(auto const& operand : operands) { + if(operand->containsVariables()) { + return true; + } + } + return false; + } + + uint_fast64_t PredicateExpression::getArity() const { + return operands.size(); + } + + std::shared_ptr PredicateExpression::getOperand(uint_fast64_t operandIndex) const { + STORM_LOG_ASSERT(operandIndex < this->getArity(), "Invalid operand access"); + return operands[operandIndex]; + } + + void PredicateExpression::gatherVariables(std::set& variables) const { + for(auto const& operand : operands) { + operand->gatherVariables(variables); + } + } + + /*! + * Retrieves the relation associated with the expression. + * + * @return The relation associated with the expression. + */ + PredicateExpression::PredicateType PredicateExpression::getPredicateType() const { + return predicate; + } + + void PredicateExpression::printToStream(std::ostream& stream) const { + + } + } +} \ No newline at end of file diff --git a/src/storm/storage/expressions/PredicateExpression.h b/src/storm/storage/expressions/PredicateExpression.h new file mode 100644 index 000000000..cfa8a5b6e --- /dev/null +++ b/src/storm/storage/expressions/PredicateExpression.h @@ -0,0 +1,66 @@ +#pragma once + +#include "storm/storage/expressions/BaseExpression.h" + +namespace storm { + namespace expressions { + /*! + * The base class of all binary expressions. + */ + class PredicateExpression : public BaseExpression { + public: + enum class PredicateType { AtLeastOneOf, AtMostOneOf, ExactlyOneOf }; + + PredicateExpression(ExpressionManager const &manager,Type const& type, + std::vector > const &operands, + PredicateType predicateType); + + // Instantiate constructors and assignments with their default implementations. + PredicateExpression(PredicateExpression const &other) = default; + + PredicateExpression &operator=(PredicateExpression const &other) = delete; + + PredicateExpression(PredicateExpression &&) = default; + + PredicateExpression &operator=(PredicateExpression &&) = delete; + + virtual ~PredicateExpression() = default; + + // Override base class methods. + virtual storm::expressions::OperatorType getOperator() const override; + + virtual bool evaluateAsBool(Valuation const *valuation = nullptr) const override; + + virtual std::shared_ptr simplify() const override; + + virtual boost::any accept(ExpressionVisitor &visitor, boost::any const &data) const override; + + virtual bool isPredicateExpression() const override; + + virtual bool isFunctionApplication() const override; + + virtual bool containsVariables() const override; + + virtual uint_fast64_t getArity() const override; + + virtual std::shared_ptr getOperand(uint_fast64_t operandIndex) const override; + + virtual void gatherVariables(std::set& variables) const override; + + /*! + * Retrieves the relation associated with the expression. + * + * @return The relation associated with the expression. + */ + PredicateType getPredicateType() const; + + protected: + // Override base class method. + virtual void printToStream(std::ostream& stream) const override; + + private: + PredicateType predicate; + std::vector> operands; + }; + } +} diff --git a/src/storm/storage/expressions/SimplificationVisitor.cpp b/src/storm/storage/expressions/SimplificationVisitor.cpp new file mode 100644 index 000000000..0d4c0a1b0 --- /dev/null +++ b/src/storm/storage/expressions/SimplificationVisitor.cpp @@ -0,0 +1,171 @@ +#include +#include +#include + +#include "storm/storage/expressions/SimplificationVisitor.h" +#include "storm/storage/expressions/Expressions.h" +#include "storm/storage/expressions/PredicateExpression.h" +#include "storm/storage/expressions/ExpressionManager.h" + +namespace storm { + namespace expressions { + SimplificationVisitor::SimplificationVisitor() { + // Intentionally left empty. + } + + Expression SimplificationVisitor::substitute(Expression const &expression) { + return Expression(boost::any_cast>( + expression.getBaseExpression().accept(*this, boost::none))); + } + + boost::any SimplificationVisitor::visit(IfThenElseExpression const &expression, boost::any const &data) { + std::shared_ptr conditionExpression = boost::any_cast>( + expression.getCondition()->accept(*this, data)); + std::shared_ptr thenExpression = boost::any_cast>( + expression.getThenExpression()->accept(*this, data)); + std::shared_ptr elseExpression = boost::any_cast>( + expression.getElseExpression()->accept(*this, data)); + + // If the arguments did not change, we simply push the expression itself. + if (conditionExpression.get() == expression.getCondition().get() && + thenExpression.get() == expression.getThenExpression().get() && + elseExpression.get() == expression.getElseExpression().get()) { + return expression.getSharedPointer(); + } else { + return std::const_pointer_cast(std::shared_ptr( + new IfThenElseExpression(expression.getManager(), expression.getType(), conditionExpression, + thenExpression, elseExpression))); + } + } + + boost::any + SimplificationVisitor::visit(BinaryBooleanFunctionExpression const &expression, boost::any const &data) { + std::shared_ptr firstExpression = boost::any_cast>( + expression.getFirstOperand()->accept(*this, data)); + std::shared_ptr secondExpression = boost::any_cast>( + expression.getSecondOperand()->accept(*this, data)); + + // If the arguments did not change, we simply push the expression itself. + if (firstExpression.get() == expression.getFirstOperand().get() && + secondExpression.get() == expression.getSecondOperand().get()) { + return expression.getSharedPointer(); + } else { + return std::const_pointer_cast(std::shared_ptr( + new BinaryBooleanFunctionExpression(expression.getManager(), expression.getType(), + firstExpression, secondExpression, + expression.getOperatorType()))); + } + } + + boost::any + SimplificationVisitor::visit(BinaryNumericalFunctionExpression const &expression, boost::any const &data) { + std::shared_ptr firstExpression = boost::any_cast>( + expression.getFirstOperand()->accept(*this, data)); + std::shared_ptr secondExpression = boost::any_cast>( + expression.getSecondOperand()->accept(*this, data)); + + // If the arguments did not change, we simply push the expression itself. + if (firstExpression.get() == expression.getFirstOperand().get() && + secondExpression.get() == expression.getSecondOperand().get()) { + return expression.getSharedPointer(); + } else { + return std::const_pointer_cast(std::shared_ptr( + new BinaryNumericalFunctionExpression(expression.getManager(), expression.getType(), + firstExpression, secondExpression, + expression.getOperatorType()))); + } + } + + boost::any SimplificationVisitor::visit(BinaryRelationExpression const &expression, boost::any const &data) { + std::shared_ptr firstExpression = boost::any_cast>( + expression.getFirstOperand()->accept(*this, data)); + std::shared_ptr secondExpression = boost::any_cast>( + expression.getSecondOperand()->accept(*this, data)); + + // If the arguments did not change, we simply push the expression itself. + if (firstExpression.get() == expression.getFirstOperand().get() && + secondExpression.get() == expression.getSecondOperand().get()) { + return expression.getSharedPointer(); + } else { + return std::const_pointer_cast(std::shared_ptr( + new BinaryRelationExpression(expression.getManager(), expression.getType(), firstExpression, + secondExpression, expression.getRelationType()))); + } + } + + boost::any SimplificationVisitor::visit(VariableExpression const &expression, boost::any const &) { + + return expression.getSharedPointer(); + + } + + boost::any + SimplificationVisitor::visit(UnaryBooleanFunctionExpression const &expression, boost::any const &data) { + std::shared_ptr operandExpression = boost::any_cast>( + expression.getOperand()->accept(*this, data)); + + // If the argument did not change, we simply push the expression itself. + if (operandExpression.get() == expression.getOperand().get()) { + return expression.getSharedPointer(); + } else { + return std::const_pointer_cast(std::shared_ptr( + new UnaryBooleanFunctionExpression(expression.getManager(), expression.getType(), + operandExpression, expression.getOperatorType()))); + } + } + + boost::any + SimplificationVisitor::visit(UnaryNumericalFunctionExpression const &expression, boost::any const &data) { + std::shared_ptr operandExpression = boost::any_cast>( + expression.getOperand()->accept(*this, data)); + + // If the argument did not change, we simply push the expression itself. + if (operandExpression.get() == expression.getOperand().get()) { + return expression.getSharedPointer(); + } else { + return std::const_pointer_cast(std::shared_ptr( + new UnaryNumericalFunctionExpression(expression.getManager(), expression.getType(), + operandExpression, expression.getOperatorType()))); + } + } + + boost::any SimplificationVisitor::visit(PredicateExpression const &expression, boost::any const &data) { + std::vector newExpressions; + for (uint64_t i = 0; i < expression.getArity(); ++i) { + newExpressions.emplace_back(boost::any_cast>( + expression.getOperand(i)->accept(*this, data))); + } + std::vector newSumExpressions; + for (auto const &expr : newExpressions) { + newSumExpressions.push_back( + ite(expr, expression.getManager().integer(1), expression.getManager().integer(0))); + } + + storm::expressions::Expression finalexpr; + if (expression.getPredicateType() == PredicateExpression::PredicateType::AtLeastOneOf) { + finalexpr = storm::expressions::sum(newSumExpressions) > expression.getManager().integer(0); + } else if (expression.getPredicateType() == PredicateExpression::PredicateType::AtMostOneOf) { + finalexpr = storm::expressions::sum(newSumExpressions) <= expression.getManager().integer(1); + } else if (expression.getPredicateType() == PredicateExpression::PredicateType::ExactlyOneOf) { + finalexpr = storm::expressions::sum(newSumExpressions) == expression.getManager().integer(1); + } else { + STORM_LOG_ASSERT(false, "Unknown predicate type."); + } + return std::const_pointer_cast(finalexpr.getBaseExpressionPointer()); + } + + + boost::any SimplificationVisitor::visit(BooleanLiteralExpression const &expression, boost::any const &) { + return expression.getSharedPointer(); + } + + boost::any SimplificationVisitor::visit(IntegerLiteralExpression const &expression, boost::any const &) { + return expression.getSharedPointer(); + } + + boost::any SimplificationVisitor::visit(RationalLiteralExpression const &expression, boost::any const &) { + return expression.getSharedPointer(); + } + + } +} \ No newline at end of file diff --git a/src/storm/storage/expressions/SimplificationVisitor.h b/src/storm/storage/expressions/SimplificationVisitor.h new file mode 100644 index 000000000..b1c80314e --- /dev/null +++ b/src/storm/storage/expressions/SimplificationVisitor.h @@ -0,0 +1,41 @@ +#pragma once +#include + +#include "storm/storage/expressions/Expression.h" +#include "storm/storage/expressions/ExpressionVisitor.h" + +namespace storm { + namespace expressions { + class SimplificationVisitor : public ExpressionVisitor { + public: + /*! + * Creates a new simplification visitor that replaces predicates by other (simpler?) predicates. + * + * Configuration: + * Currently, the visitor only replaces nonstandard predicates + * + */ + SimplificationVisitor(); + + /*! + * Simplifies based on the configuration. + */ + Expression substitute(Expression const& expression); + + virtual boost::any visit(IfThenElseExpression const& expression, boost::any const& data) override; + virtual boost::any visit(BinaryBooleanFunctionExpression const& expression, boost::any const& data) override; + virtual boost::any visit(BinaryNumericalFunctionExpression const& expression, boost::any const& data) override; + virtual boost::any visit(BinaryRelationExpression const& expression, boost::any const& data) override; + virtual boost::any visit(VariableExpression const& expression, boost::any const& data) override; + virtual boost::any visit(UnaryBooleanFunctionExpression const& expression, boost::any const& data) override; + virtual boost::any visit(UnaryNumericalFunctionExpression const& expression, boost::any const& data) override; + virtual boost::any visit(BooleanLiteralExpression const& expression, boost::any const& data) override; + virtual boost::any visit(IntegerLiteralExpression const& expression, boost::any const& data) override; + virtual boost::any visit(RationalLiteralExpression const& expression, boost::any const& data) override; + virtual boost::any visit(PredicateExpression const& expression, boost::any const& data) override; + + protected: + // + }; + } +} diff --git a/src/storm/storage/expressions/SubstitutionVisitor.cpp b/src/storm/storage/expressions/SubstitutionVisitor.cpp index e835806c6..47f3f7139 100644 --- a/src/storm/storage/expressions/SubstitutionVisitor.cpp +++ b/src/storm/storage/expressions/SubstitutionVisitor.cpp @@ -104,6 +104,26 @@ namespace storm { return std::const_pointer_cast(std::shared_ptr(new UnaryNumericalFunctionExpression(expression.getManager(), expression.getType(), operandExpression, expression.getOperatorType()))); } } + + template + boost::any SubstitutionVisitor::visit(PredicateExpression const& expression, boost::any const& data) { + bool changed = false; + std::vector> newExpressions; + for (uint64_t i = 0; i < expression.getArity(); ++i) { + newExpressions.push_back(boost::any_cast>(expression.getOperand(i)->accept(*this, data))); + if (!changed && newExpressions.back() == expression.getOperand(i)) { + changed = true; + } + } + + // If the arguments did not change, we simply push the expression itself. + if (!changed) { + return expression.getSharedPointer(); + } else { + return std::const_pointer_cast(std::shared_ptr(new PredicateExpression(expression.getManager(), expression.getType(), newExpressions, expression.getPredicateType()))); + } + } + template boost::any SubstitutionVisitor::visit(BooleanLiteralExpression const& expression, boost::any const&) { diff --git a/src/storm/storage/expressions/SubstitutionVisitor.h b/src/storm/storage/expressions/SubstitutionVisitor.h index 99f0c2c18..fc2028858 100644 --- a/src/storm/storage/expressions/SubstitutionVisitor.h +++ b/src/storm/storage/expressions/SubstitutionVisitor.h @@ -38,7 +38,8 @@ namespace storm { virtual boost::any visit(BooleanLiteralExpression const& expression, boost::any const& data) override; virtual boost::any visit(IntegerLiteralExpression const& expression, boost::any const& data) override; virtual boost::any visit(RationalLiteralExpression const& expression, boost::any const& data) override; - + virtual boost::any visit(PredicateExpression const& expression, boost::any const& data) override; + protected: // A mapping of variables to expressions with which they shall be replaced. MapType const& variableToExpressionMapping; diff --git a/src/storm/storage/expressions/ToDiceStringVisitor.cpp b/src/storm/storage/expressions/ToDiceStringVisitor.cpp new file mode 100644 index 000000000..2a9467cc3 --- /dev/null +++ b/src/storm/storage/expressions/ToDiceStringVisitor.cpp @@ -0,0 +1,291 @@ +#include "storm/exceptions/NotSupportedException.h" +#include "storm/storage/expressions/ToDiceStringVisitor.h" + +namespace storm { + namespace expressions { + ToDiceStringVisitor::ToDiceStringVisitor(uint64 nrBits) : nrBits(nrBits) { + + } + + std::string ToDiceStringVisitor::toString(Expression const& expression) { + return toString(expression.getBaseExpressionPointer().get()); + } + + std::string ToDiceStringVisitor::toString(BaseExpression const* expression) { + stream.str(""); + stream.clear(); + expression->accept(*this, boost::none); + return stream.str(); + } + + boost::any ToDiceStringVisitor::visit(IfThenElseExpression const& expression, boost::any const& data) { + stream << "if "; + expression.getCondition()->accept(*this, data); + stream << " then "; + expression.getThenExpression()->accept(*this, data); + stream << " else "; + expression.getElseExpression()->accept(*this, data); + stream << ""; + return boost::any(); + } + + boost::any ToDiceStringVisitor::visit(BinaryBooleanFunctionExpression const& expression, boost::any const& data) { + switch (expression.getOperatorType()) { + case BinaryBooleanFunctionExpression::OperatorType::And: + stream << "("; + expression.getFirstOperand()->accept(*this, data); + stream << " && "; + expression.getSecondOperand()->accept(*this, data); + stream << ")"; + break; + case BinaryBooleanFunctionExpression::OperatorType::Or: + stream << "("; + expression.getFirstOperand()->accept(*this, data); + stream << " || "; + expression.getSecondOperand()->accept(*this, data); + stream << ")"; + break; + case BinaryBooleanFunctionExpression::OperatorType::Xor: + stream << "("; + expression.getFirstOperand()->accept(*this, data); + stream << " ^ "; + expression.getSecondOperand()->accept(*this, data); + stream << ")"; + break; + case BinaryBooleanFunctionExpression::OperatorType::Implies: + stream << "(!("; + expression.getFirstOperand()->accept(*this, data); + stream << ") || "; + expression.getSecondOperand()->accept(*this, data); + stream << ")"; + break; + case BinaryBooleanFunctionExpression::OperatorType::Iff: + expression.getFirstOperand()->accept(*this, data); + stream << " <=> "; + expression.getSecondOperand()->accept(*this, data); + break; + } + return boost::any(); + } + + boost::any ToDiceStringVisitor::visit(BinaryNumericalFunctionExpression const& expression, boost::any const& data) { + switch (expression.getOperatorType()) { + case BinaryNumericalFunctionExpression::OperatorType::Plus: + stream << "("; + expression.getFirstOperand()->accept(*this, data); + stream << "+"; + expression.getSecondOperand()->accept(*this, data); + stream << ")"; + break; + case BinaryNumericalFunctionExpression::OperatorType::Minus: + stream << "("; + expression.getFirstOperand()->accept(*this, data); + stream << "-"; + expression.getSecondOperand()->accept(*this, data); + stream << ")"; + break; + case BinaryNumericalFunctionExpression::OperatorType::Times: + stream << "("; + expression.getFirstOperand()->accept(*this, data); + stream << "*"; + expression.getSecondOperand()->accept(*this, data); + stream << ")"; + break; + case BinaryNumericalFunctionExpression::OperatorType::Divide: { + STORM_LOG_THROW(expression.getSecondOperand()->isIntegerLiteralExpression(), + storm::exceptions::NotSupportedException, + "Dice does not support modulo with nonconst rhs"); + uint64_t denominator = expression.getSecondOperand()->evaluateAsInt(); + int shifts = 0; + while (denominator % 2 == 0) { + denominator = denominator >> 1; + shifts++; + } + denominator = denominator >> 1; + if (denominator > 0) { + STORM_LOG_THROW(false, storm::exceptions::NotSupportedException, + "Dice does not support division with non-powers of two"); + } + if (shifts > 0) { + stream << "("; + expression.getFirstOperand()->accept(*this, data); + stream << " >> " << shifts; + stream << ")"; + } else { + expression.getFirstOperand()->accept(*this, data); + } + + } + break; + case BinaryNumericalFunctionExpression::OperatorType::Power: + stream << "("; + expression.getFirstOperand()->accept(*this, data); + stream << "^"; + expression.getSecondOperand()->accept(*this, data); + stream << ")"; + break; + case BinaryNumericalFunctionExpression::OperatorType::Modulo: + STORM_LOG_THROW(expression.getSecondOperand()->isIntegerLiteralExpression(), storm::exceptions::NotSupportedException, "Dice does not support modulo with nonconst rhs"); + STORM_LOG_THROW(expression.getSecondOperand()->evaluateAsInt() == 2, storm::exceptions::NotSupportedException, "Dice does not support modulo with rhs != 2"); + + stream << "( nth_bit(int(" << nrBits << "," << nrBits-1 << "), "; + expression.getFirstOperand()->accept(*this, data); + stream << "))"; + break; + case BinaryNumericalFunctionExpression::OperatorType::Max: + stream << "max("; + expression.getFirstOperand()->accept(*this, data); + stream << ","; + expression.getSecondOperand()->accept(*this, data); + stream << ")"; + break; + case BinaryNumericalFunctionExpression::OperatorType::Min: + stream << "min("; + expression.getFirstOperand()->accept(*this, data); + stream << ","; + expression.getSecondOperand()->accept(*this, data); + stream << ")"; + break; + } + return boost::any(); + } + + boost::any ToDiceStringVisitor::visit(BinaryRelationExpression const& expression, boost::any const& data) { + switch (expression.getRelationType()) { + case BinaryRelationExpression::RelationType::Equal: + if (expression.getFirstOperand()->isBinaryNumericalFunctionExpression()) { + if (expression.getFirstOperand()->asBinaryNumericalFunctionExpression().getOperatorType() == BinaryNumericalFunctionExpression::OperatorType::Modulo) { + expression.getFirstOperand()->accept(*this, data); + } + } else { + stream << "("; + expression.getFirstOperand()->accept(*this, data); + stream << "=="; + expression.getSecondOperand()->accept(*this, data); + stream << ")"; + } + + break; + case BinaryRelationExpression::RelationType::NotEqual: + stream << "("; + expression.getFirstOperand()->accept(*this, data); + stream << "!="; + expression.getSecondOperand()->accept(*this, data); + stream << ")"; + break; + case BinaryRelationExpression::RelationType::Less: + stream << "("; + expression.getFirstOperand()->accept(*this, data); + stream << "<"; + expression.getSecondOperand()->accept(*this, data); + stream << ")"; + break; + case BinaryRelationExpression::RelationType::LessOrEqual: + stream << "("; + expression.getFirstOperand()->accept(*this, data); + stream << "<="; + expression.getSecondOperand()->accept(*this, data); + stream << ")"; + break; + case BinaryRelationExpression::RelationType::Greater: + stream << "("; + expression.getFirstOperand()->accept(*this, data); + stream << ">"; + expression.getSecondOperand()->accept(*this, data); + stream << ")"; + break; + case BinaryRelationExpression::RelationType::GreaterOrEqual: + stream << "("; + expression.getFirstOperand()->accept(*this, data); + stream << ">="; + expression.getSecondOperand()->accept(*this, data); + stream << ")"; + break; + } + return boost::any(); + } + + boost::any ToDiceStringVisitor::visit(VariableExpression const& expression, boost::any const&) { + stream << expression.getVariable().getName(); + return boost::any(); + } + + boost::any ToDiceStringVisitor::visit(UnaryBooleanFunctionExpression const& expression, boost::any const& data) { + switch (expression.getOperatorType()) { + case UnaryBooleanFunctionExpression::OperatorType::Not: + stream << "!("; + expression.getOperand()->accept(*this, data); + stream << ")"; + } + return boost::any(); + } + + boost::any ToDiceStringVisitor::visit(UnaryNumericalFunctionExpression const& expression, boost::any const& data) { + switch (expression.getOperatorType()) { + case UnaryNumericalFunctionExpression::OperatorType::Minus: + stream << "-("; + expression.getOperand()->accept(*this, data); + stream << ")"; + break; + case UnaryNumericalFunctionExpression::OperatorType::Floor: + stream << "floor("; + expression.getOperand()->accept(*this, data); + stream << ")"; + break; + case UnaryNumericalFunctionExpression::OperatorType::Ceil: + stream << "ceil("; + expression.getOperand()->accept(*this, data); + stream << ")"; + break; + } + return boost::any(); + } + + boost::any ToDiceStringVisitor::visit(PredicateExpression const& expression, boost::any const& data) { + auto pdt = expression.getPredicateType(); + STORM_LOG_ASSERT(pdt == PredicateExpression::PredicateType::ExactlyOneOf || pdt == PredicateExpression::PredicateType::AtLeastOneOf || pdt == PredicateExpression::PredicateType::AtMostOneOf, "Only some predicate types are supported."); + stream << "("; + if (expression.getPredicateType() == PredicateExpression::PredicateType::ExactlyOneOf || expression.getPredicateType() == PredicateExpression::PredicateType::AtMostOneOf) { + stream << "(true "; + for (uint64_t operandi = 0; operandi < expression.getArity(); ++operandi) { + for (uint64_t operandj = operandi + 1; operandj < expression.getArity(); ++operandj) { + stream << "&& !("; + expression.getOperand(operandi)->accept(*this, data); + stream << " && "; + expression.getOperand(operandj)->accept(*this, data); + stream << ")"; + } + } + stream << ")"; + } + if (expression.getPredicateType() == PredicateExpression::PredicateType::ExactlyOneOf) { + stream << " && "; + } + if (expression.getPredicateType() == PredicateExpression::PredicateType::ExactlyOneOf || expression.getPredicateType() == PredicateExpression::PredicateType::AtLeastOneOf) { + stream << "( false"; + for (uint64_t operandj = 0; operandj < expression.getArity(); ++operandj) { + stream << "|| "; + expression.getOperand(operandj)->accept(*this, data); + } + stream << ")"; + } + stream << ")"; + return boost::any(); + } + + boost::any ToDiceStringVisitor::visit(BooleanLiteralExpression const& expression, boost::any const&) { + stream << (expression.getValue() ? " true " : " false "); + return boost::any(); + } + + boost::any ToDiceStringVisitor::visit(IntegerLiteralExpression const& expression, boost::any const&) { + stream << "int(" << nrBits << "," << expression.getValue() << ")"; + return boost::any(); + } + + boost::any ToDiceStringVisitor::visit(RationalLiteralExpression const& expression, boost::any const&) { + stream << std::scientific << std::setprecision(std::numeric_limits::max_digits10) << "(" << expression.getValueAsDouble() << ")"; + return boost::any(); + } + } +} diff --git a/src/storm/storage/expressions/ToDiceStringVisitor.h b/src/storm/storage/expressions/ToDiceStringVisitor.h new file mode 100644 index 000000000..00a91d7f9 --- /dev/null +++ b/src/storm/storage/expressions/ToDiceStringVisitor.h @@ -0,0 +1,37 @@ +#pragma once + +#include + +#include "storm/storage/expressions/Expression.h" +#include "storm/storage/expressions/Expressions.h" +#include "storm/storage/expressions/ExpressionVisitor.h" + +namespace storm { + namespace expressions { + class ToDiceStringVisitor : public ExpressionVisitor { + public: + ToDiceStringVisitor(uint64 nrBits); + + std::string toString(Expression const& expression); + std::string toString(BaseExpression const* expression); + + virtual boost::any visit(IfThenElseExpression const& expression, boost::any const& data) override; + virtual boost::any visit(BinaryBooleanFunctionExpression const& expression, boost::any const& data) override; + virtual boost::any visit(BinaryNumericalFunctionExpression const& expression, boost::any const& data) override; + virtual boost::any visit(BinaryRelationExpression const& expression, boost::any const& data) override; + virtual boost::any visit(VariableExpression const& expression, boost::any const& data) override; + virtual boost::any visit(UnaryBooleanFunctionExpression const& expression, boost::any const& data) override; + virtual boost::any visit(UnaryNumericalFunctionExpression const& expression, boost::any const& data) override; + virtual boost::any visit(PredicateExpression const& expression, boost::any const& data) override; + virtual boost::any visit(BooleanLiteralExpression const& expression, boost::any const& data) override; + virtual boost::any visit(IntegerLiteralExpression const& expression, boost::any const& data) override; + virtual boost::any visit(RationalLiteralExpression const& expression, boost::any const& data) override; + + + private: + std::stringstream stream; + uint64_t nrBits; + }; + } +} + diff --git a/src/storm/storage/prism/Assignment.cpp b/src/storm/storage/prism/Assignment.cpp index e0ccf7a92..a760ca0a6 100644 --- a/src/storm/storage/prism/Assignment.cpp +++ b/src/storm/storage/prism/Assignment.cpp @@ -21,6 +21,10 @@ namespace storm { Assignment Assignment::substitute(std::map const& substitution) const { return Assignment(this->getVariable(), this->getExpression().substitute(substitution).simplify(), this->getFilename(), this->getLineNumber()); } + + Assignment Assignment::substituteNonStandardPredicates() const { + return Assignment(this->getVariable(), this->getExpression().substituteNonStandardPredicates().simplify(), this->getFilename(), this->getLineNumber()); + } bool Assignment::isIdentity() const { if(this->expression.isVariable()) { diff --git a/src/storm/storage/prism/Assignment.h b/src/storm/storage/prism/Assignment.h index 43777a9d6..646dddc53 100644 --- a/src/storm/storage/prism/Assignment.h +++ b/src/storm/storage/prism/Assignment.h @@ -57,7 +57,9 @@ namespace storm { * @return The resulting assignment. */ Assignment substitute(std::map const& substitution) const; - + + Assignment substituteNonStandardPredicates() const; + /*! * Checks whether the assignment is an identity (lhs equals rhs) * diff --git a/src/storm/storage/prism/BooleanVariable.cpp b/src/storm/storage/prism/BooleanVariable.cpp index c3b3d3431..141965e43 100644 --- a/src/storm/storage/prism/BooleanVariable.cpp +++ b/src/storm/storage/prism/BooleanVariable.cpp @@ -11,6 +11,10 @@ namespace storm { BooleanVariable BooleanVariable::substitute(std::map const& substitution) const { return BooleanVariable(this->getExpressionVariable(), this->getInitialValueExpression().isInitialized() ? this->getInitialValueExpression().substitute(substitution) : this->getInitialValueExpression(), this->isObservable(), this->getFilename(), this->getLineNumber()); } + + BooleanVariable BooleanVariable::substituteNonStandardPredicates() const { + return BooleanVariable(this->getExpressionVariable(), this->getInitialValueExpression().isInitialized() ? this->getInitialValueExpression().substituteNonStandardPredicates() : this->getInitialValueExpression(), this->isObservable(), this->getFilename(), this->getLineNumber()); + } void BooleanVariable::createMissingInitialValue() { if (!this->hasInitialValue()) { diff --git a/src/storm/storage/prism/BooleanVariable.h b/src/storm/storage/prism/BooleanVariable.h index acdc9865f..252e4673c 100644 --- a/src/storm/storage/prism/BooleanVariable.h +++ b/src/storm/storage/prism/BooleanVariable.h @@ -34,7 +34,8 @@ namespace storm { * @return The resulting boolean variable. */ BooleanVariable substitute(std::map const& substitution) const; - + BooleanVariable substituteNonStandardPredicates() const; + virtual void createMissingInitialValue() override; friend std::ostream& operator<<(std::ostream& stream, BooleanVariable const& variable); diff --git a/src/storm/storage/prism/Command.cpp b/src/storm/storage/prism/Command.cpp index 4b357c73a..045eca424 100644 --- a/src/storm/storage/prism/Command.cpp +++ b/src/storm/storage/prism/Command.cpp @@ -56,7 +56,17 @@ namespace storm { return Command(this->getGlobalIndex(), this->isMarkovian(), this->getActionIndex(), this->getActionName(), this->getGuardExpression().substitute(substitution).simplify(), newUpdates, this->getFilename(), this->getLineNumber()); } - + + Command Command::substituteNonStandardPredicates() const { + std::vector newUpdates; + newUpdates.reserve(this->getNumberOfUpdates()); + for (auto const& update : this->getUpdates()) { + newUpdates.emplace_back(update.substituteNonStandardPredicates()); + } + + return Command(this->getGlobalIndex(), this->isMarkovian(), this->getActionIndex(), this->getActionName(), this->getGuardExpression().substituteNonStandardPredicates().simplify(), newUpdates, this->getFilename(), this->getLineNumber()); + } + bool Command::isLabeled() const { return labeled; } diff --git a/src/storm/storage/prism/Command.h b/src/storm/storage/prism/Command.h index 8213349ce..84b95474c 100644 --- a/src/storm/storage/prism/Command.h +++ b/src/storm/storage/prism/Command.h @@ -114,7 +114,8 @@ namespace storm { * @return The resulting command. */ Command substitute(std::map const& substitution) const; - + + Command substituteNonStandardPredicates() const; /*! * Retrieves whether the command possesses a synchronization label. * diff --git a/src/storm/storage/prism/Formula.cpp b/src/storm/storage/prism/Formula.cpp index a628c63b3..c3c36e01d 100644 --- a/src/storm/storage/prism/Formula.cpp +++ b/src/storm/storage/prism/Formula.cpp @@ -44,6 +44,15 @@ namespace storm { return Formula(this->getName(), this->getExpression().substitute(substitution), this->getFilename(), this->getLineNumber()); } } + + Formula Formula::substituteNonStandardPredicates() const { + assert(this->getExpression().isInitialized()); + if (hasExpressionVariable()) { + return Formula(this->getExpressionVariable(), this->getExpression().substituteNonStandardPredicates(), this->getFilename(), this->getLineNumber()); + } else { + return Formula(this->getName(), this->getExpression().substituteNonStandardPredicates(), this->getFilename(), this->getLineNumber()); + } + } std::ostream& operator<<(std::ostream& stream, Formula const& formula) { stream << "formula " << formula.getName() << " = " << formula.getExpression() << ";"; diff --git a/src/storm/storage/prism/Formula.h b/src/storm/storage/prism/Formula.h index 379e2cdd0..b7678f571 100644 --- a/src/storm/storage/prism/Formula.h +++ b/src/storm/storage/prism/Formula.h @@ -92,6 +92,7 @@ namespace storm { * @return The resulting formula. */ Formula substitute(std::map const& substitution) const; + Formula substituteNonStandardPredicates() const; friend std::ostream& operator<<(std::ostream& stream, Formula const& formula); diff --git a/src/storm/storage/prism/IntegerVariable.cpp b/src/storm/storage/prism/IntegerVariable.cpp index d53ea618b..6cd4f5759 100644 --- a/src/storm/storage/prism/IntegerVariable.cpp +++ b/src/storm/storage/prism/IntegerVariable.cpp @@ -21,6 +21,10 @@ namespace storm { IntegerVariable IntegerVariable::substitute(std::map const& substitution) const { return IntegerVariable(this->getExpressionVariable(), this->getLowerBoundExpression().substitute(substitution), this->getUpperBoundExpression().substitute(substitution), this->getInitialValueExpression().isInitialized() ? this->getInitialValueExpression().substitute(substitution) : this->getInitialValueExpression(), this->isObservable(), this->getFilename(), this->getLineNumber()); } + + IntegerVariable IntegerVariable::substituteNonStandardPredicates() const { + return IntegerVariable(this->getExpressionVariable(), this->getLowerBoundExpression().substituteNonStandardPredicates(), this->getUpperBoundExpression().substituteNonStandardPredicates(), this->getInitialValueExpression().isInitialized() ? this->getInitialValueExpression().substituteNonStandardPredicates() : this->getInitialValueExpression(), this->isObservable(), this->getFilename(), this->getLineNumber()); + } void IntegerVariable::createMissingInitialValue() { if (!this->hasInitialValue()) { diff --git a/src/storm/storage/prism/IntegerVariable.h b/src/storm/storage/prism/IntegerVariable.h index 67618f6eb..3069ff45a 100644 --- a/src/storm/storage/prism/IntegerVariable.h +++ b/src/storm/storage/prism/IntegerVariable.h @@ -57,7 +57,9 @@ namespace storm { * @return The resulting boolean variable. */ IntegerVariable substitute(std::map const& substitution) const; - + + IntegerVariable substituteNonStandardPredicates() const; + virtual void createMissingInitialValue() override; friend std::ostream& operator<<(std::ostream& stream, IntegerVariable const& variable); diff --git a/src/storm/storage/prism/Label.cpp b/src/storm/storage/prism/Label.cpp index 4f07baeb9..4cd8f1a59 100644 --- a/src/storm/storage/prism/Label.cpp +++ b/src/storm/storage/prism/Label.cpp @@ -18,6 +18,10 @@ namespace storm { Label Label::substitute(std::map const& substitution) const { return Label(this->getName(), this->getStatePredicateExpression().substitute(substitution), this->getFilename(), this->getLineNumber()); } + + Label Label::substituteNonStandardPredicates() const { + return Label(this->getName(), this->getStatePredicateExpression().substituteNonStandardPredicates(), this->getFilename(), this->getLineNumber()); + } std::ostream& operator<<(std::ostream& stream, Label const& label) { stream << "label \"" << label.getName() << "\" = " << label.getStatePredicateExpression() << ";"; @@ -31,5 +35,10 @@ namespace storm { ObservationLabel ObservationLabel::substitute(std::map const& substitution) const { return ObservationLabel(this->getName(), this->getStatePredicateExpression().substitute(substitution), this->getFilename(), this->getLineNumber()); } + + ObservationLabel ObservationLabel::substituteNonStandardPredicates() const { + return ObservationLabel(this->getName(), this->getStatePredicateExpression().substituteNonStandardPredicates(), this->getFilename(), this->getLineNumber()); + } + } // namespace prism } // namespace storm diff --git a/src/storm/storage/prism/Label.h b/src/storm/storage/prism/Label.h index 298e5b721..aa22653a1 100644 --- a/src/storm/storage/prism/Label.h +++ b/src/storm/storage/prism/Label.h @@ -58,6 +58,7 @@ namespace storm { * @return The resulting label. */ Label substitute(std::map const& substitution) const; + Label substituteNonStandardPredicates() const; friend std::ostream& operator<<(std::ostream& stream, Label const& label); @@ -96,7 +97,7 @@ namespace storm { * @return The resulting label. */ ObservationLabel substitute(std::map const& substitution) const; - + ObservationLabel substituteNonStandardPredicates() const; }; diff --git a/src/storm/storage/prism/Module.cpp b/src/storm/storage/prism/Module.cpp index 512985ded..3c9c8f698 100644 --- a/src/storm/storage/prism/Module.cpp +++ b/src/storm/storage/prism/Module.cpp @@ -235,6 +235,28 @@ namespace storm { return Module(this->getName(), newBooleanVariables, newIntegerVariables, this->getClockVariables(), this->getInvariant(), newCommands, this->getFilename(), this->getLineNumber()); } + + Module Module::substituteNonStandardPredicates() const { + std::vector newBooleanVariables; + newBooleanVariables.reserve(this->getNumberOfBooleanVariables()); + for (auto const& booleanVariable : this->getBooleanVariables()) { + newBooleanVariables.emplace_back(booleanVariable.substituteNonStandardPredicates()); + } + + std::vector newIntegerVariables; + newBooleanVariables.reserve(this->getNumberOfIntegerVariables()); + for (auto const& integerVariable : this->getIntegerVariables()) { + newIntegerVariables.emplace_back(integerVariable.substituteNonStandardPredicates()); + } + + std::vector newCommands; + newCommands.reserve(this->getNumberOfCommands()); + for (auto const& command : this->getCommands()) { + newCommands.emplace_back(command.substituteNonStandardPredicates()); + } + + return Module(this->getName(), newBooleanVariables, newIntegerVariables, this->getClockVariables(), this->getInvariant(), newCommands, this->getFilename(), this->getLineNumber()); + } bool Module::containsVariablesOnlyInUpdateProbabilities(std::set const& undefinedConstantVariables) const { for (auto const& booleanVariable : this->getBooleanVariables()) { diff --git a/src/storm/storage/prism/Module.h b/src/storm/storage/prism/Module.h index 138dfbe09..adeb94e01 100644 --- a/src/storm/storage/prism/Module.h +++ b/src/storm/storage/prism/Module.h @@ -250,7 +250,9 @@ namespace storm { * @return The resulting module. */ Module substitute(std::map const& substitution) const; - + + Module substituteNonStandardPredicates() const; + /*! * Checks whether the given variables only appear in the update probabilities of the module and nowhere else. * diff --git a/src/storm/storage/prism/Program.cpp b/src/storm/storage/prism/Program.cpp index cdd41f7c3..f17674472 100644 --- a/src/storm/storage/prism/Program.cpp +++ b/src/storm/storage/prism/Program.cpp @@ -885,7 +885,51 @@ namespace storm { Program Program::substituteFormulas() const { return substituteConstantsFormulas(false, true); } - + + Program Program::substituteNonStandardPredicates() const { + // TODO support in constants, initial construct, and rewards + + std::vector newFormulas; + newFormulas.reserve(this->getNumberOfFormulas()); + for (auto const& oldFormula : this->getFormulas()) { + newFormulas.emplace_back(oldFormula.substituteNonStandardPredicates()); + } + + std::vector newBooleanVariables; + newBooleanVariables.reserve(this->getNumberOfGlobalBooleanVariables()); + for (auto const& booleanVariable : this->getGlobalBooleanVariables()) { + newBooleanVariables.emplace_back(booleanVariable.substituteNonStandardPredicates()); + } + + std::vector newIntegerVariables; + newBooleanVariables.reserve(this->getNumberOfGlobalIntegerVariables()); + for (auto const& integerVariable : this->getGlobalIntegerVariables()) { + newIntegerVariables.emplace_back(integerVariable.substituteNonStandardPredicates()); + } + + std::vector newModules; + newModules.reserve(this->getNumberOfModules()); + for (auto const& module : this->getModules()) { + newModules.emplace_back(module.substituteNonStandardPredicates()); + } + + + std::vector