diff --git a/src/parser/ExpressionParser.cpp b/src/parser/ExpressionParser.cpp new file mode 100644 index 000000000..6459d7a9d --- /dev/null +++ b/src/parser/ExpressionParser.cpp @@ -0,0 +1,400 @@ +#include "src/parser/ExpressionParser.h" +#include "src/exceptions/InvalidArgumentException.h" +#include "src/exceptions/InvalidTypeException.h" +#include "src/exceptions/WrongFormatException.h" + +namespace storm { + namespace parser { + ExpressionParser::ExpressionParser(qi::symbols const& invalidIdentifiers_) : ExpressionParser::base_type(expression), createExpressions(false), acceptDoubleLiterals(true), identifiers_(nullptr), invalidIdentifiers_(invalidIdentifiers_) { + identifier %= qi::as_string[qi::raw[qi::lexeme[((qi::alpha | qi::char_('_')) >> *(qi::alnum | qi::char_('_')))]]][qi::_pass = phoenix::bind(&ExpressionParser::isValidIdentifier, phoenix::ref(*this), qi::_1)]; + identifier.name("identifier"); + + floorCeilExpression = ((qi::lit("floor")[qi::_a = true] | qi::lit("ceil")[qi::_a = false]) >> qi::lit("(") >> plusExpression >> qi::lit(")"))[phoenix::if_(qi::_a) [qi::_val = phoenix::bind(&ExpressionParser::createFloorExpression, phoenix::ref(*this), qi::_1)] .else_ [qi::_val = phoenix::bind(&ExpressionParser::createCeilExpression, phoenix::ref(*this), qi::_1)]]; + floorCeilExpression.name("floor/ceil expression"); + + minMaxExpression = ((qi::lit("min")[qi::_a = true] | qi::lit("max")[qi::_a = false]) >> qi::lit("(") >> plusExpression >> qi::lit(",") >> plusExpression >> qi::lit(")"))[phoenix::if_(qi::_a) [qi::_val = phoenix::bind(&ExpressionParser::createMinimumExpression, phoenix::ref(*this), qi::_1, qi::_2)] .else_ [qi::_val = phoenix::bind(&ExpressionParser::createMaximumExpression, phoenix::ref(*this), qi::_1, qi::_2)]]; + minMaxExpression.name("min/max expression"); + + identifierExpression = identifier[qi::_val = phoenix::bind(&ExpressionParser::getIdentifierExpression, phoenix::ref(*this), qi::_1)]; + identifierExpression.name("identifier expression"); + + literalExpression = qi::lit("true")[qi::_val = phoenix::bind(&ExpressionParser::createTrueExpression, phoenix::ref(*this))] | qi::lit("false")[qi::_val = phoenix::bind(&ExpressionParser::createFalseExpression, phoenix::ref(*this))] | strict_double[qi::_val = phoenix::bind(&ExpressionParser::createDoubleLiteralExpression, phoenix::ref(*this), qi::_1, qi::_pass)] | qi::int_[qi::_val = phoenix::bind(&ExpressionParser::createIntegerLiteralExpression, phoenix::ref(*this), qi::_1)]; + literalExpression.name("literal expression"); + + atomicExpression = minMaxExpression | floorCeilExpression | qi::lit("(") >> expression >> qi::lit(")") | literalExpression | identifierExpression; + atomicExpression.name("atomic expression"); + + unaryExpression = atomicExpression[qi::_val = qi::_1] | (qi::lit("!") >> atomicExpression)[qi::_val = phoenix::bind(&ExpressionParser::createNotExpression, phoenix::ref(*this), qi::_1)] | (qi::lit("-") >> atomicExpression)[qi::_val = phoenix::bind(&ExpressionParser::createMinusExpression, phoenix::ref(*this), qi::_1)]; + unaryExpression.name("unary expression"); + + powerExpression = unaryExpression[qi::_val = qi::_1] >> -(qi::lit("^") > expression)[qi::_val = phoenix::bind(&ExpressionParser::createPowerExpression, phoenix::ref(*this), qi::_val, qi::_1)]; + powerExpression.name("power expression"); + + multiplicationExpression = powerExpression[qi::_val = qi::_1] >> *((qi::lit("*")[qi::_a = true] | qi::lit("/")[qi::_a = false]) >> powerExpression[phoenix::if_(qi::_a) [qi::_val = phoenix::bind(&ExpressionParser::createMultExpression, phoenix::ref(*this), qi::_val, qi::_1)] .else_ [qi::_val = phoenix::bind(&ExpressionParser::createDivExpression, phoenix::ref(*this), qi::_val, qi::_1)]]); + multiplicationExpression.name("multiplication expression"); + + plusExpression = multiplicationExpression[qi::_val = qi::_1] >> *((qi::lit("+")[qi::_a = true] | qi::lit("-")[qi::_a = false]) >> multiplicationExpression)[phoenix::if_(qi::_a) [qi::_val = phoenix::bind(&ExpressionParser::createPlusExpression, phoenix::ref(*this), qi::_val, qi::_1)] .else_ [qi::_val = phoenix::bind(&ExpressionParser::createMinusExpression, phoenix::ref(*this), qi::_val, qi::_1)]]; + plusExpression.name("plus expression"); + + relativeExpression = (plusExpression >> qi::lit(">=") >> plusExpression)[qi::_val = phoenix::bind(&ExpressionParser::createGreaterOrEqualExpression, phoenix::ref(*this), qi::_1, qi::_2)] | (plusExpression >> qi::lit(">") >> plusExpression)[qi::_val = phoenix::bind(&ExpressionParser::createGreaterExpression, phoenix::ref(*this), qi::_1, qi::_2)] | (plusExpression >> qi::lit("<=") >> plusExpression)[qi::_val = phoenix::bind(&ExpressionParser::createLessOrEqualExpression, phoenix::ref(*this), qi::_1, qi::_2)] | (plusExpression >> qi::lit("<") >> plusExpression)[qi::_val = phoenix::bind(&ExpressionParser::createLessExpression, phoenix::ref(*this), qi::_1, qi::_2)] | plusExpression[qi::_val = qi::_1]; + relativeExpression.name("relative expression"); + + equalityExpression = relativeExpression[qi::_val = qi::_1] >> *((qi::lit("=")[qi::_a = true] | qi::lit("!=")[qi::_a = false]) >> relativeExpression)[phoenix::if_(qi::_a) [ qi::_val = phoenix::bind(&ExpressionParser::createEqualsExpression, phoenix::ref(*this), qi::_val, qi::_1) ] .else_ [ qi::_val = phoenix::bind(&ExpressionParser::createNotEqualsExpression, phoenix::ref(*this), qi::_val, qi::_1) ] ]; + equalityExpression.name("equality expression"); + + andExpression = equalityExpression[qi::_val = qi::_1] >> *(qi::lit("&") >> equalityExpression)[qi::_val = phoenix::bind(&ExpressionParser::createAndExpression, phoenix::ref(*this), qi::_val, qi::_1)]; + andExpression.name("and expression"); + + orExpression = andExpression[qi::_val = qi::_1] >> *((qi::lit("|")[qi::_a = true] | qi::lit("=>")[qi::_a = false]) >> andExpression)[phoenix::if_(qi::_a) [qi::_val = phoenix::bind(&ExpressionParser::createOrExpression, phoenix::ref(*this), qi::_val, qi::_1)] .else_ [qi::_val = phoenix::bind(&ExpressionParser::createImpliesExpression, phoenix::ref(*this), qi::_val, qi::_1)] ]; + orExpression.name("or expression"); + + iteExpression = orExpression[qi::_val = qi::_1] >> -(qi::lit("?") > orExpression > qi::lit(":") > orExpression)[qi::_val = phoenix::bind(&ExpressionParser::createIteExpression, phoenix::ref(*this), qi::_val, qi::_1, qi::_2)]; + iteExpression.name("if-then-else expression"); + + expression %= iteExpression; + expression.name("expression"); + + // Enable error reporting. + qi::on_error(expression, handler(qi::_1, qi::_2, qi::_3, qi::_4)); + qi::on_error(iteExpression, handler(qi::_1, qi::_2, qi::_3, qi::_4)); + qi::on_error(orExpression, handler(qi::_1, qi::_2, qi::_3, qi::_4)); + qi::on_error(andExpression, handler(qi::_1, qi::_2, qi::_3, qi::_4)); + qi::on_error(equalityExpression, handler(qi::_1, qi::_2, qi::_3, qi::_4)); + qi::on_error(relativeExpression, handler(qi::_1, qi::_2, qi::_3, qi::_4)); + qi::on_error(plusExpression, handler(qi::_1, qi::_2, qi::_3, qi::_4)); + qi::on_error(multiplicationExpression, handler(qi::_1, qi::_2, qi::_3, qi::_4)); + qi::on_error(unaryExpression, handler(qi::_1, qi::_2, qi::_3, qi::_4)); + qi::on_error(atomicExpression, handler(qi::_1, qi::_2, qi::_3, qi::_4)); + qi::on_error(literalExpression, handler(qi::_1, qi::_2, qi::_3, qi::_4)); + qi::on_error(identifierExpression, handler(qi::_1, qi::_2, qi::_3, qi::_4)); + qi::on_error(minMaxExpression, handler(qi::_1, qi::_2, qi::_3, qi::_4)); + qi::on_error(floorCeilExpression, handler(qi::_1, qi::_2, qi::_3, qi::_4)); + } + + void ExpressionParser::setIdentifierMapping(qi::symbols const* identifiers_) { + if (identifiers_ != nullptr) { + this->createExpressions = true; + this->identifiers_ = identifiers_; + } else { + this->createExpressions = false; + this->identifiers_ = nullptr; + } + } + + void ExpressionParser::unsetIdentifierMapping() { + this->createExpressions = false; + this->identifiers_ = nullptr; + } + + void ExpressionParser::setAcceptDoubleLiterals(bool flag) { + this->acceptDoubleLiterals = flag; + } + + storm::expressions::Expression ExpressionParser::createIteExpression(storm::expressions::Expression e1, storm::expressions::Expression e2, storm::expressions::Expression e3) const { + if (this->createExpressions) { + try { + return e1.ite(e2, e3); + } catch (storm::exceptions::InvalidTypeException const& e) { + LOG_THROW(false, storm::exceptions::WrongFormatException, "Parsing error in line " << get_line(qi::_3) << ": " << e.what()); + } + } else { + return storm::expressions::Expression::createFalse(); + } + } + + storm::expressions::Expression ExpressionParser::createImpliesExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const { + if (this->createExpressions) { + try { + return e1.implies(e2); + } catch (storm::exceptions::InvalidTypeException const& e) { + LOG_THROW(false, storm::exceptions::WrongFormatException, "Parsing error in line " << get_line(qi::_3) << ": " << e.what()); + } + } else { + return storm::expressions::Expression::createFalse(); + } + } + + storm::expressions::Expression ExpressionParser::createOrExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const { + if (this->createExpressions) { + try { + return e1 || e2; + } catch (storm::exceptions::InvalidTypeException const& e) { + LOG_THROW(false, storm::exceptions::WrongFormatException, "Parsing error in line " << get_line(qi::_3) << ": " << e.what()); + } + } else { + return storm::expressions::Expression::createFalse(); + } + } + + storm::expressions::Expression ExpressionParser::createAndExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const { + if (this->createExpressions) { + try{ + return e1 && e2; + } catch (storm::exceptions::InvalidTypeException const& e) { + LOG_THROW(false, storm::exceptions::WrongFormatException, "Parsing error in line " << get_line(qi::_3) << ": " << e.what()); + } + } else { + return storm::expressions::Expression::createFalse(); + } + } + + storm::expressions::Expression ExpressionParser::createGreaterExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const { + if (this->createExpressions) { + try { + return e1 > e2; + } catch (storm::exceptions::InvalidTypeException const& e) { + LOG_THROW(false, storm::exceptions::WrongFormatException, "Parsing error in line " << get_line(qi::_3) << ": " << e.what()); + } + } else { + return storm::expressions::Expression::createFalse(); + } + } + + storm::expressions::Expression ExpressionParser::createGreaterOrEqualExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const { + if (this->createExpressions) { + try { + return e1 >= e2; + } catch (storm::exceptions::InvalidTypeException const& e) { + LOG_THROW(false, storm::exceptions::WrongFormatException, "Parsing error in line " << get_line(qi::_3) << ": " << e.what()); + } + } else { + return storm::expressions::Expression::createFalse(); + } + } + + storm::expressions::Expression ExpressionParser::createLessExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const { + if (this->createExpressions) { + try { + return e1 < e2; + } catch (storm::exceptions::InvalidTypeException const& e) { + LOG_THROW(false, storm::exceptions::WrongFormatException, "Parsing error in line " << get_line(qi::_3) << ": " << e.what()); + } + } else { + return storm::expressions::Expression::createFalse(); + } + } + + storm::expressions::Expression ExpressionParser::createLessOrEqualExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const { + if (this->createExpressions) { + try { + return e1 <= e2; + } catch (storm::exceptions::InvalidTypeException const& e) { + LOG_THROW(false, storm::exceptions::WrongFormatException, "Parsing error in line " << get_line(qi::_3) << ": " << e.what()); + } + } else { + return storm::expressions::Expression::createFalse(); + } + } + + storm::expressions::Expression ExpressionParser::createEqualsExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const { + if (this->createExpressions) { + try { + if (e1.hasBooleanReturnType() && e2.hasBooleanReturnType()) { + return e1.iff(e2); + } else { + return e1 == e2; + } + } catch (storm::exceptions::InvalidTypeException const& e) { + LOG_THROW(false, storm::exceptions::WrongFormatException, "Parsing error in line " << get_line(qi::_3) << ": " << e.what()); + } + } else { + return storm::expressions::Expression::createFalse(); + } + } + + storm::expressions::Expression ExpressionParser::createNotEqualsExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const { + if (this->createExpressions) { + try { + return e1 != e2; + } catch (storm::exceptions::InvalidTypeException const& e) { + LOG_THROW(false, storm::exceptions::WrongFormatException, "Parsing error in line " << get_line(qi::_3) << ": " << e.what()); + } + } else { + return storm::expressions::Expression::createFalse(); + } + } + + storm::expressions::Expression ExpressionParser::createPlusExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const { + if (this->createExpressions) { + try { + return e1 + e2; + } catch (storm::exceptions::InvalidTypeException const& e) { + LOG_THROW(false, storm::exceptions::WrongFormatException, "Parsing error in line " << get_line(qi::_3) << ": " << e.what()); + } + } else { + return storm::expressions::Expression::createFalse(); + } + } + + storm::expressions::Expression ExpressionParser::createMinusExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const { + if (this->createExpressions) { + try { + return e1 - e2; + } catch (storm::exceptions::InvalidTypeException const& e) { + LOG_THROW(false, storm::exceptions::WrongFormatException, "Parsing error in line " << get_line(qi::_3) << ": " << e.what()); + } + } else { + return storm::expressions::Expression::createFalse(); + } + } + + storm::expressions::Expression ExpressionParser::createMultExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const { + if (this->createExpressions) { + try { + return e1 * e2; + } catch (storm::exceptions::InvalidTypeException const& e) { + LOG_THROW(false, storm::exceptions::WrongFormatException, "Parsing error in line " << get_line(qi::_3) << ": " << e.what()); + } + } else { + return storm::expressions::Expression::createFalse(); + } + } + + storm::expressions::Expression ExpressionParser::createPowerExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const { + if (this->createExpressions) { + try { + return e1 ^ e2; + } catch (storm::exceptions::InvalidTypeException const& e) { + LOG_THROW(false, storm::exceptions::WrongFormatException, "Parsing error in line " << get_line(qi::_3) << ": " << e.what()); + } + } else { + return storm::expressions::Expression::createFalse(); + } + } + + storm::expressions::Expression ExpressionParser::createDivExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const { + if (this->createExpressions) { + try { + return e1 / e2; + } catch (storm::exceptions::InvalidTypeException const& e) { + LOG_THROW(false, storm::exceptions::WrongFormatException, "Parsing error in line " << get_line(qi::_3) << ": " << e.what()); + } + } else { + return storm::expressions::Expression::createFalse(); + } + } + + storm::expressions::Expression ExpressionParser::createNotExpression(storm::expressions::Expression e1) const { + if (this->createExpressions) { + try { + return !e1; + } catch (storm::exceptions::InvalidTypeException const& e) { + LOG_THROW(false, storm::exceptions::WrongFormatException, "Parsing error in line " << get_line(qi::_3) << ": " << e.what()); + } + } else { + return storm::expressions::Expression::createFalse(); + } + } + + storm::expressions::Expression ExpressionParser::createMinusExpression(storm::expressions::Expression e1) const { + if (this->createExpressions) { + try { + return -e1; + } catch (storm::exceptions::InvalidTypeException const& e) { + LOG_THROW(false, storm::exceptions::WrongFormatException, "Parsing error in line " << get_line(qi::_3) << ": " << e.what()); + } + } else { + return storm::expressions::Expression::createFalse(); + } + } + + storm::expressions::Expression ExpressionParser::createTrueExpression() const { + if (this->createExpressions) { + return storm::expressions::Expression::createTrue(); + } else { + return storm::expressions::Expression::createFalse(); + } + } + + storm::expressions::Expression ExpressionParser::createFalseExpression() const { + return storm::expressions::Expression::createFalse(); + } + + storm::expressions::Expression ExpressionParser::createDoubleLiteralExpression(double value, bool& pass) const { + // If we are not supposed to accept double expressions, we reject it by setting pass to false. + if (!this->acceptDoubleLiterals) { + pass = false; + } + + if (this->createExpressions) { + return storm::expressions::Expression::createDoubleLiteral(value); + } else { + return storm::expressions::Expression::createFalse(); + } + } + + storm::expressions::Expression ExpressionParser::createIntegerLiteralExpression(int value) const { + if (this->createExpressions) { + return storm::expressions::Expression::createIntegerLiteral(static_cast(value)); + } else { + return storm::expressions::Expression::createFalse(); + } + } + + storm::expressions::Expression ExpressionParser::createMinimumExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const { + if (this->createExpressions) { + try { + return storm::expressions::Expression::minimum(e1, e2); + } catch (storm::exceptions::InvalidTypeException const& e) { + LOG_THROW(false, storm::exceptions::WrongFormatException, "Parsing error in line " << get_line(qi::_3) << ": " << e.what()); + } + } else { + return storm::expressions::Expression::createFalse(); + } + } + + storm::expressions::Expression ExpressionParser::createMaximumExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const { + if (this->createExpressions) { + try { + return storm::expressions::Expression::maximum(e1, e2); + } catch (storm::exceptions::InvalidTypeException const& e) { + LOG_THROW(false, storm::exceptions::WrongFormatException, "Parsing error in line " << get_line(qi::_3) << ": " << e.what()); + } + } else { + return storm::expressions::Expression::createFalse(); + } + } + + storm::expressions::Expression ExpressionParser::createFloorExpression(storm::expressions::Expression e1) const { + if (this->createExpressions) { + try { + return e1.floor(); + } catch (storm::exceptions::InvalidTypeException const& e) { + LOG_THROW(false, storm::exceptions::WrongFormatException, "Parsing error in line " << get_line(qi::_3) << ": " << e.what()); + } + } else { + return storm::expressions::Expression::createFalse(); + } + } + + storm::expressions::Expression ExpressionParser::createCeilExpression(storm::expressions::Expression e1) const { + if (this->createExpressions) { + try { + return e1.ceil(); + } catch (storm::exceptions::InvalidTypeException const& e) { + LOG_THROW(false, storm::exceptions::WrongFormatException, "Parsing error in line " << get_line(qi::_3) << ": " << e.what()); + } + } else { + return storm::expressions::Expression::createFalse(); + } + } + + storm::expressions::Expression ExpressionParser::getIdentifierExpression(std::string const& identifier) const { + if (this->createExpressions) { + LOG_THROW(this->identifiers_ != nullptr, storm::exceptions::WrongFormatException, "Unable to substitute identifier expressions without given mapping."); + storm::expressions::Expression const* expression = this->identifiers_->find(identifier); + LOG_THROW(expression != nullptr, storm::exceptions::WrongFormatException, "Parsing error in line " << get_line(qi::_3) << ": Undeclared identifier '" << identifier << "'."); + return *expression; + } else { + return storm::expressions::Expression::createFalse(); + } + } + + bool ExpressionParser::isValidIdentifier(std::string const& identifier) { + if (this->invalidIdentifiers_.find(identifier) != nullptr) { + return false; + } + return true; + } + } +} \ No newline at end of file diff --git a/src/parser/ExpressionParser.h b/src/parser/ExpressionParser.h new file mode 100644 index 000000000..380644caf --- /dev/null +++ b/src/parser/ExpressionParser.h @@ -0,0 +1,128 @@ +#ifndef STORM_PARSER_EXPRESSIONPARSER_H_ +#define STORM_PARSER_EXPRESSIONPARSER_H_ + +#include "src/parser/SpiritParserDefinitions.h" +#include "src/storage/expressions/Expression.h" +#include "src/exceptions/ExceptionMacros.h" +#include "src/exceptions/WrongFormatException.h" + +namespace storm { + namespace parser { + class ExpressionParser : public qi::grammar { + public: + /*! + * Creates an expression parser. Initially the parser is set to a mode in which it will not generate the + * actual expressions but only perform a syntax check and return the expression "false". To make the parser + * generate the actual expressions, a mapping of valid identifiers to their expressions need to be provided + * later. + * + * @param invalidIdentifiers_ A symbol table of identifiers that are to be rejected. + */ + ExpressionParser(qi::symbols const& invalidIdentifiers_); + + /*! + * Sets an identifier mapping that is used to determine valid variables in the expression. The mapped-to + * expressions will be substituted wherever the key value appears in the parsed expression. After setting + * this, the parser will generate expressions. + * + * @param identifiers A pointer to a mapping from identifiers to expressions. + */ + void setIdentifierMapping(qi::symbols const* identifiers_); + + /*! + * Unsets a previously set identifier mapping. This will make the parser not generate expressions any more + * but merely check for syntactic correctness of an expression. + */ + void unsetIdentifierMapping(); + + /*! + * Sets whether double literals are to be accepted or not. + * + * @param flag If set to true, double literals are accepted. + */ + void setAcceptDoubleLiterals(bool flag); + + private: + // A flag that indicates whether expressions should actually be generated or just a syntax check shall be + // performed. + bool createExpressions; + + // A flag that indicates whether double literals are accepted. + bool acceptDoubleLiterals; + + // The currently used mapping of identifiers to expressions. This is used if the parser is set to create + // expressions. + qi::symbols const* identifiers_; + + // The symbol table of invalid identifiers. + qi::symbols const& invalidIdentifiers_; + + // Rules for parsing a composed expression. + qi::rule expression; + qi::rule iteExpression; + qi::rule, Skipper> orExpression; + qi::rule andExpression; + qi::rule relativeExpression; + qi::rule, Skipper> equalityExpression; + qi::rule, Skipper> plusExpression; + qi::rule, Skipper> multiplicationExpression; + qi::rule, Skipper> powerExpression; + qi::rule unaryExpression; + qi::rule atomicExpression; + qi::rule literalExpression; + qi::rule identifierExpression; + qi::rule, Skipper> minMaxExpression; + qi::rule, Skipper> floorCeilExpression; + qi::rule identifier; + + // Parser that is used to recognize doubles only (as opposed to Spirit's double_ parser). + boost::spirit::qi::real_parser> strict_double; + + // Helper functions to create expressions. + storm::expressions::Expression createIteExpression(storm::expressions::Expression e1, storm::expressions::Expression e2, storm::expressions::Expression e3) const; + storm::expressions::Expression createImpliesExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const; + storm::expressions::Expression createOrExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const; + storm::expressions::Expression createAndExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const; + storm::expressions::Expression createGreaterExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const; + storm::expressions::Expression createGreaterOrEqualExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const; + storm::expressions::Expression createLessExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const; + storm::expressions::Expression createLessOrEqualExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const; + storm::expressions::Expression createEqualsExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const; + storm::expressions::Expression createNotEqualsExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const; + storm::expressions::Expression createPlusExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const; + storm::expressions::Expression createMinusExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const; + storm::expressions::Expression createMultExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const; + storm::expressions::Expression createPowerExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const; + storm::expressions::Expression createDivExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const; + storm::expressions::Expression createNotExpression(storm::expressions::Expression e1) const; + storm::expressions::Expression createMinusExpression(storm::expressions::Expression e1) const; + storm::expressions::Expression createTrueExpression() const; + storm::expressions::Expression createFalseExpression() const; + storm::expressions::Expression createDoubleLiteralExpression(double value, bool& pass) const; + storm::expressions::Expression createIntegerLiteralExpression(int value) const; + storm::expressions::Expression createMinimumExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const; + storm::expressions::Expression createMaximumExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const; + storm::expressions::Expression createFloorExpression(storm::expressions::Expression e1) const; + storm::expressions::Expression createCeilExpression(storm::expressions::Expression e1) const; + storm::expressions::Expression getIdentifierExpression(std::string const& identifier) const; + + bool isValidIdentifier(std::string const& identifier); + + // Functor used for displaying error information. + struct ErrorHandler { + typedef qi::error_handler_result result_type; + + template + qi::error_handler_result operator()(T1 b, T2 e, T3 where, T4 const& what) const { + LOG_THROW(false, storm::exceptions::WrongFormatException, "Parsing error in line " << get_line(where) << ": " << " expecting " << what << "."); + return qi::fail; + } + }; + + // An error handler function. + phoenix::function handler; + }; + } // namespace parser +} // namespace storm + +#endif /* STORM_PARSER_EXPRESSIONPARSER_H_ */ \ No newline at end of file diff --git a/src/parser/PrismParser.cpp b/src/parser/PrismParser.cpp index d4686bcff..b026880d3 100644 --- a/src/parser/PrismParser.cpp +++ b/src/parser/PrismParser.cpp @@ -60,53 +60,11 @@ namespace storm { return result; } - PrismParser::PrismParser(std::string const& filename, Iterator first) : PrismParser::base_type(start), secondRun(false), allowDoubleLiteralsFlag(true), filename(filename), annotate(first) { + PrismParser::PrismParser(std::string const& filename, Iterator first) : PrismParser::base_type(start), secondRun(false), filename(filename), annotate(first), expressionParser(keywords_) { // Parse simple identifier. identifier %= qi::as_string[qi::raw[qi::lexeme[((qi::alpha | qi::char_('_')) >> *(qi::alnum | qi::char_('_')))]]][qi::_pass = phoenix::bind(&PrismParser::isValidIdentifier, phoenix::ref(*this), qi::_1)]; identifier.name("identifier"); - floorCeilExpression = ((qi::lit("floor")[qi::_a = true] | qi::lit("ceil")[qi::_a = false]) >> qi::lit("(") >> plusExpression >> qi::lit(")"))[phoenix::if_(qi::_a) [qi::_val = phoenix::bind(&PrismParser::createFloorExpression, phoenix::ref(*this), qi::_1)] .else_ [qi::_val = phoenix::bind(&PrismParser::createCeilExpression, phoenix::ref(*this), qi::_1)]]; - floorCeilExpression.name("floor/ceil expression"); - - minMaxExpression = ((qi::lit("min")[qi::_a = true] | qi::lit("max")[qi::_a = false]) >> qi::lit("(") >> plusExpression >> qi::lit(",") >> plusExpression >> qi::lit(")"))[phoenix::if_(qi::_a) [qi::_val = phoenix::bind(&PrismParser::createMinimumExpression, phoenix::ref(*this), qi::_1, qi::_2)] .else_ [qi::_val = phoenix::bind(&PrismParser::createMaximumExpression, phoenix::ref(*this), qi::_1, qi::_2)]]; - minMaxExpression.name("min/max expression"); - - identifierExpression = identifier[qi::_val = phoenix::bind(&PrismParser::getIdentifierExpression, phoenix::ref(*this), qi::_1)]; - identifierExpression.name("identifier expression"); - - literalExpression = qi::lit("true")[qi::_val = phoenix::bind(&PrismParser::createTrueExpression, phoenix::ref(*this))] | qi::lit("false")[qi::_val = phoenix::bind(&PrismParser::createFalseExpression, phoenix::ref(*this))] | strict_double[qi::_val = phoenix::bind(&PrismParser::createDoubleLiteralExpression, phoenix::ref(*this), qi::_1, qi::_pass)] | qi::int_[qi::_val = phoenix::bind(&PrismParser::createIntegerLiteralExpression, phoenix::ref(*this), qi::_1)]; - literalExpression.name("literal expression"); - - atomicExpression = minMaxExpression | floorCeilExpression | qi::lit("(") >> expression >> qi::lit(")") | literalExpression | identifierExpression; - atomicExpression.name("atomic expression"); - - unaryExpression = atomicExpression[qi::_val = qi::_1] | (qi::lit("!") >> atomicExpression)[qi::_val = phoenix::bind(&PrismParser::createNotExpression, phoenix::ref(*this), qi::_1)] | (qi::lit("-") >> atomicExpression)[qi::_val = phoenix::bind(&PrismParser::createMinusExpression, phoenix::ref(*this), qi::_1)]; - unaryExpression.name("unary expression"); - - multiplicationExpression = unaryExpression[qi::_val = qi::_1] >> *((qi::lit("*")[qi::_a = true] | qi::lit("/")[qi::_a = false]) >> unaryExpression[phoenix::if_(qi::_a) [qi::_val = phoenix::bind(&PrismParser::createMultExpression, phoenix::ref(*this), qi::_val, qi::_1)] .else_ [qi::_val = phoenix::bind(&PrismParser::createDivExpression, phoenix::ref(*this), qi::_val, qi::_1)]]); - multiplicationExpression.name("multiplication expression"); - - plusExpression = multiplicationExpression[qi::_val = qi::_1] >> *((qi::lit("+")[qi::_a = true] | qi::lit("-")[qi::_a = false]) >> multiplicationExpression)[phoenix::if_(qi::_a) [qi::_val = phoenix::bind(&PrismParser::createPlusExpression, phoenix::ref(*this), qi::_val, qi::_1)] .else_ [qi::_val = phoenix::bind(&PrismParser::createMinusExpression, phoenix::ref(*this), qi::_val, qi::_1)]]; - plusExpression.name("plus expression"); - - relativeExpression = (plusExpression >> qi::lit(">=") >> plusExpression)[qi::_val = phoenix::bind(&PrismParser::createGreaterOrEqualExpression, phoenix::ref(*this), qi::_1, qi::_2)] | (plusExpression >> qi::lit(">") >> plusExpression)[qi::_val = phoenix::bind(&PrismParser::createGreaterExpression, phoenix::ref(*this), qi::_1, qi::_2)] | (plusExpression >> qi::lit("<=") >> plusExpression)[qi::_val = phoenix::bind(&PrismParser::createLessOrEqualExpression, phoenix::ref(*this), qi::_1, qi::_2)] | (plusExpression >> qi::lit("<") >> plusExpression)[qi::_val = phoenix::bind(&PrismParser::createLessExpression, phoenix::ref(*this), qi::_1, qi::_2)] | plusExpression[qi::_val = qi::_1]; - relativeExpression.name("relative expression"); - - equalityExpression = relativeExpression[qi::_val = qi::_1] >> *((qi::lit("=")[qi::_a = true] | qi::lit("!=")[qi::_a = false]) >> relativeExpression)[phoenix::if_(qi::_a) [ qi::_val = phoenix::bind(&PrismParser::createEqualsExpression, phoenix::ref(*this), qi::_val, qi::_1) ] .else_ [ qi::_val = phoenix::bind(&PrismParser::createNotEqualsExpression, phoenix::ref(*this), qi::_val, qi::_1) ] ]; - equalityExpression.name("equality expression"); - - andExpression = equalityExpression[qi::_val = qi::_1] >> *(qi::lit("&") >> equalityExpression)[qi::_val = phoenix::bind(&PrismParser::createAndExpression, phoenix::ref(*this), qi::_val, qi::_1)]; - andExpression.name("and expression"); - - orExpression = andExpression[qi::_val = qi::_1] >> *((qi::lit("|")[qi::_a = true] | qi::lit("=>")[qi::_a = false]) >> andExpression)[phoenix::if_(qi::_a) [qi::_val = phoenix::bind(&PrismParser::createOrExpression, phoenix::ref(*this), qi::_val, qi::_1)] .else_ [qi::_val = phoenix::bind(&PrismParser::createImpliesExpression, phoenix::ref(*this), qi::_val, qi::_1)] ]; - orExpression.name("or expression"); - - iteExpression = orExpression[qi::_val = qi::_1] >> -(qi::lit("?") > orExpression > qi::lit(":") > orExpression)[qi::_val = phoenix::bind(&PrismParser::createIteExpression, phoenix::ref(*this), qi::_val, qi::_1, qi::_2)]; - iteExpression.name("if-then-else expression"); - - expression %= iteExpression; - expression.name("expression"); - modelTypeDefinition %= modelType_; modelTypeDefinition.name("model type"); @@ -122,25 +80,25 @@ namespace storm { undefinedConstantDefinition = (undefinedBooleanConstantDefinition | undefinedIntegerConstantDefinition | undefinedDoubleConstantDefinition); undefinedConstantDefinition.name("undefined constant definition"); - definedBooleanConstantDefinition = ((qi::lit("const") >> qi::lit("bool") >> identifier >> qi::lit("=")) > expression > qi::lit(";"))[qi::_val = phoenix::bind(&PrismParser::createDefinedBooleanConstant, phoenix::ref(*this), qi::_1, qi::_2)]; + definedBooleanConstantDefinition = ((qi::lit("const") >> qi::lit("bool") >> identifier >> qi::lit("=")) > expressionParser > qi::lit(";"))[qi::_val = phoenix::bind(&PrismParser::createDefinedBooleanConstant, phoenix::ref(*this), qi::_1, qi::_2)]; definedBooleanConstantDefinition.name("defined boolean constant declaration"); - definedIntegerConstantDefinition = ((qi::lit("const") >> qi::lit("int") >> identifier >> qi::lit("=")) > expression >> qi::lit(";"))[qi::_val = phoenix::bind(&PrismParser::createDefinedIntegerConstant, phoenix::ref(*this), qi::_1, qi::_2)]; + definedIntegerConstantDefinition = ((qi::lit("const") >> qi::lit("int") >> identifier >> qi::lit("=")) > expressionParser >> qi::lit(";"))[qi::_val = phoenix::bind(&PrismParser::createDefinedIntegerConstant, phoenix::ref(*this), qi::_1, qi::_2)]; definedIntegerConstantDefinition.name("defined integer constant declaration"); - definedDoubleConstantDefinition = ((qi::lit("const") >> qi::lit("double") >> identifier >> qi::lit("=")) > expression > qi::lit(";"))[qi::_val = phoenix::bind(&PrismParser::createDefinedDoubleConstant, phoenix::ref(*this), qi::_1, qi::_2)]; + definedDoubleConstantDefinition = ((qi::lit("const") >> qi::lit("double") >> identifier >> qi::lit("=")) > expressionParser > qi::lit(";"))[qi::_val = phoenix::bind(&PrismParser::createDefinedDoubleConstant, phoenix::ref(*this), qi::_1, qi::_2)]; definedDoubleConstantDefinition.name("defined double constant declaration"); definedConstantDefinition %= (definedBooleanConstantDefinition | definedIntegerConstantDefinition | definedDoubleConstantDefinition); definedConstantDefinition.name("defined constant definition"); - formulaDefinition = (qi::lit("formula") > identifier > qi::lit("=") > expression > qi::lit(";"))[qi::_val = phoenix::bind(&PrismParser::createFormula, phoenix::ref(*this), qi::_1, qi::_2)]; + formulaDefinition = (qi::lit("formula") > identifier > qi::lit("=") > expressionParser > qi::lit(";"))[qi::_val = phoenix::bind(&PrismParser::createFormula, phoenix::ref(*this), qi::_1, qi::_2)]; formulaDefinition.name("formula definition"); - booleanVariableDefinition = ((identifier >> qi::lit(":") >> qi::lit("bool")) > ((qi::lit("init") > expression) | qi::attr(storm::expressions::Expression::createFalse())) > qi::lit(";"))[qi::_val = phoenix::bind(&PrismParser::createBooleanVariable, phoenix::ref(*this), qi::_1, qi::_2)]; + booleanVariableDefinition = ((identifier >> qi::lit(":") >> qi::lit("bool")) > ((qi::lit("init") > expressionParser) | qi::attr(storm::expressions::Expression::createFalse())) > qi::lit(";"))[qi::_val = phoenix::bind(&PrismParser::createBooleanVariable, phoenix::ref(*this), qi::_1, qi::_2)]; booleanVariableDefinition.name("boolean variable definition"); - integerVariableDefinition = ((identifier >> qi::lit(":") >> qi::lit("[")[phoenix::bind(&PrismParser::allowDoubleLiterals, phoenix::ref(*this), false)]) > expression[qi::_a = qi::_1] > qi::lit("..") > expression > qi::lit("]")[phoenix::bind(&PrismParser::allowDoubleLiterals, phoenix::ref(*this), true)] > -(qi::lit("init") > expression[qi::_a = qi::_1]) > qi::lit(";"))[qi::_val = phoenix::bind(&PrismParser::createIntegerVariable, phoenix::ref(*this), qi::_1, qi::_2, qi::_3, qi::_a)]; + integerVariableDefinition = ((identifier >> qi::lit(":") >> qi::lit("[")[phoenix::bind(&PrismParser::allowDoubleLiterals, phoenix::ref(*this), false)]) > expressionParser[qi::_a = qi::_1] > qi::lit("..") > expressionParser > qi::lit("]")[phoenix::bind(&PrismParser::allowDoubleLiterals, phoenix::ref(*this), true)] > -(qi::lit("init") > expressionParser[qi::_a = qi::_1]) > qi::lit(";"))[qi::_val = phoenix::bind(&PrismParser::createIntegerVariable, phoenix::ref(*this), qi::_1, qi::_2, qi::_3, qi::_a)]; integerVariableDefinition.name("integer variable definition"); variableDefinition = (booleanVariableDefinition[phoenix::push_back(qi::_r1, qi::_1)] | integerVariableDefinition[phoenix::push_back(qi::_r2, qi::_1)]); @@ -149,10 +107,10 @@ namespace storm { globalVariableDefinition = (qi::lit("global") > (booleanVariableDefinition[phoenix::push_back(phoenix::bind(&GlobalProgramInformation::globalBooleanVariables, qi::_r1), qi::_1)] | integerVariableDefinition[phoenix::push_back(phoenix::bind(&GlobalProgramInformation::globalIntegerVariables, qi::_r1), qi::_1)])); globalVariableDefinition.name("global variable declaration list"); - stateRewardDefinition = (expression > qi::lit(":") > plusExpression >> qi::lit(";"))[qi::_val = phoenix::bind(&PrismParser::createStateReward, phoenix::ref(*this), qi::_1, qi::_2)]; + stateRewardDefinition = (expressionParser > qi::lit(":") > expressionParser >> qi::lit(";"))[qi::_val = phoenix::bind(&PrismParser::createStateReward, phoenix::ref(*this), qi::_1, qi::_2)]; stateRewardDefinition.name("state reward definition"); - transitionRewardDefinition = (qi::lit("[") > -(identifier[qi::_a = qi::_1]) > qi::lit("]") > expression > qi::lit(":") > plusExpression > qi::lit(";"))[qi::_val = phoenix::bind(&PrismParser::createTransitionReward, phoenix::ref(*this), qi::_a, qi::_2, qi::_3)]; + transitionRewardDefinition = (qi::lit("[") > -(identifier[qi::_a = qi::_1]) > qi::lit("]") > expressionParser > qi::lit(":") > expressionParser > qi::lit(";"))[qi::_val = phoenix::bind(&PrismParser::createTransitionReward, phoenix::ref(*this), qi::_a, qi::_2, qi::_3)]; transitionRewardDefinition.name("transition reward definition"); rewardModelDefinition = (qi::lit("rewards") > -(qi::lit("\"") > identifier[qi::_a = qi::_1] > qi::lit("\"")) @@ -162,25 +120,25 @@ namespace storm { >> qi::lit("endrewards"))[qi::_val = phoenix::bind(&PrismParser::createRewardModel, phoenix::ref(*this), qi::_a, qi::_b, qi::_c)]; rewardModelDefinition.name("reward model definition"); - initialStatesConstruct = (qi::lit("init") > expression > qi::lit("endinit"))[qi::_pass = phoenix::bind(&PrismParser::addInitialStatesConstruct, phoenix::ref(*this), qi::_1, qi::_r1)]; + initialStatesConstruct = (qi::lit("init") > expressionParser > qi::lit("endinit"))[qi::_pass = phoenix::bind(&PrismParser::addInitialStatesConstruct, phoenix::ref(*this), qi::_1, qi::_r1)]; initialStatesConstruct.name("initial construct"); - labelDefinition = (qi::lit("label") > -qi::lit("\"") > identifier > -qi::lit("\"") > qi::lit("=") > expression >> qi::lit(";"))[qi::_val = phoenix::bind(&PrismParser::createLabel, phoenix::ref(*this), qi::_1, qi::_2)]; + labelDefinition = (qi::lit("label") > -qi::lit("\"") > identifier > -qi::lit("\"") > qi::lit("=") > expressionParser >> qi::lit(";"))[qi::_val = phoenix::bind(&PrismParser::createLabel, phoenix::ref(*this), qi::_1, qi::_2)]; labelDefinition.name("label definition"); - assignmentDefinition = (qi::lit("(") > identifier > qi::lit("'") > qi::lit("=") > expression > qi::lit(")"))[qi::_val = phoenix::bind(&PrismParser::createAssignment, phoenix::ref(*this), qi::_1, qi::_2)]; + assignmentDefinition = (qi::lit("(") > identifier > qi::lit("'") > qi::lit("=") > expressionParser > qi::lit(")"))[qi::_val = phoenix::bind(&PrismParser::createAssignment, phoenix::ref(*this), qi::_1, qi::_2)]; assignmentDefinition.name("assignment"); assignmentDefinitionList %= +assignmentDefinition % "&"; assignmentDefinitionList.name("assignment list"); - updateDefinition = (((plusExpression > qi::lit(":")) | qi::attr(storm::expressions::Expression::createDoubleLiteral(1))) >> assignmentDefinitionList)[qi::_val = phoenix::bind(&PrismParser::createUpdate, phoenix::ref(*this), qi::_1, qi::_2, qi::_r1)]; + updateDefinition = (((expressionParser > qi::lit(":")) | qi::attr(storm::expressions::Expression::createDoubleLiteral(1))) >> assignmentDefinitionList)[qi::_val = phoenix::bind(&PrismParser::createUpdate, phoenix::ref(*this), qi::_1, qi::_2, qi::_r1)]; updateDefinition.name("update"); updateListDefinition %= +updateDefinition(qi::_r1) % "+"; updateListDefinition.name("update list"); - commandDefinition = (qi::lit("[") > -(identifier[qi::_a = qi::_1]) > qi::lit("]") > expression > qi::lit("->") > updateListDefinition(qi::_r1) > qi::lit(";"))[qi::_val = phoenix::bind(&PrismParser::createCommand, phoenix::ref(*this), qi::_a, qi::_2, qi::_3, qi::_r1)]; + commandDefinition = (qi::lit("[") > -(identifier[qi::_a = qi::_1]) > qi::lit("]") > expressionParser > qi::lit("->") > updateListDefinition(qi::_r1) > qi::lit(";"))[qi::_val = phoenix::bind(&PrismParser::createCommand, phoenix::ref(*this), qi::_a, qi::_2, qi::_3, qi::_r1)]; commandDefinition.name("command definition"); moduleDefinition = ((qi::lit("module") >> identifier >> *(variableDefinition(qi::_a, qi::_b))) > +commandDefinition(qi::_r1) > qi::lit("endmodule"))[qi::_val = phoenix::bind(&PrismParser::createModule, phoenix::ref(*this), qi::_1, qi::_a, qi::_b, qi::_2, qi::_r1)]; @@ -209,22 +167,6 @@ namespace storm { > qi::eoi)[qi::_val = phoenix::bind(&PrismParser::createProgram, phoenix::ref(*this), qi::_a)]; start.name("probabilistic program"); - // Enable error reporting. - qi::on_error(expression, handler(qi::_1, qi::_2, qi::_3, qi::_4)); - qi::on_error(iteExpression, handler(qi::_1, qi::_2, qi::_3, qi::_4)); - qi::on_error(orExpression, handler(qi::_1, qi::_2, qi::_3, qi::_4)); - qi::on_error(andExpression, handler(qi::_1, qi::_2, qi::_3, qi::_4)); - qi::on_error(equalityExpression, handler(qi::_1, qi::_2, qi::_3, qi::_4)); - qi::on_error(relativeExpression, handler(qi::_1, qi::_2, qi::_3, qi::_4)); - qi::on_error(plusExpression, handler(qi::_1, qi::_2, qi::_3, qi::_4)); - qi::on_error(multiplicationExpression, handler(qi::_1, qi::_2, qi::_3, qi::_4)); - qi::on_error(unaryExpression, handler(qi::_1, qi::_2, qi::_3, qi::_4)); - qi::on_error(atomicExpression, handler(qi::_1, qi::_2, qi::_3, qi::_4)); - qi::on_error(literalExpression, handler(qi::_1, qi::_2, qi::_3, qi::_4)); - qi::on_error(identifierExpression, handler(qi::_1, qi::_2, qi::_3, qi::_4)); - qi::on_error(minMaxExpression, handler(qi::_1, qi::_2, qi::_3, qi::_4)); - qi::on_error(floorCeilExpression, handler(qi::_1, qi::_2, qi::_3, qi::_4)); - // Enable location tracking for important entities. auto setLocationInfoFunction = this->annotate(qi::_val, qi::_1, qi::_3); qi::on_success(undefinedBooleanConstantDefinition, setLocationInfoFunction); @@ -247,10 +189,11 @@ namespace storm { void PrismParser::moveToSecondRun() { this->secondRun = true; + this->expressionParser.setIdentifierMapping(&this->identifiers_); } void PrismParser::allowDoubleLiterals(bool flag) { - this->allowDoubleLiteralsFlag = flag; + this->expressionParser.setAcceptDoubleLiterals(flag); } std::string const& PrismParser::getFilename() const { @@ -274,301 +217,6 @@ namespace storm { return true; } - storm::expressions::Expression PrismParser::createIteExpression(storm::expressions::Expression e1, storm::expressions::Expression e2, storm::expressions::Expression e3) const { - if (!this->secondRun) { - return storm::expressions::Expression::createFalse(); - } else { - try { - return e1.ite(e2, e3); - } catch (storm::exceptions::InvalidTypeException const& e) { - LOG_THROW(false, storm::exceptions::WrongFormatException, "Parsing error in " << this->getFilename() << ", line " << get_line(qi::_3) << ": " << e.what() << "."); - } - } - } - - storm::expressions::Expression PrismParser::createImpliesExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const { - if (!this->secondRun) { - return storm::expressions::Expression::createFalse(); - } else { - try { - return e1.implies(e2); - } catch (storm::exceptions::InvalidTypeException const& e) { - LOG_THROW(false, storm::exceptions::WrongFormatException, "Parsing error in " << this->getFilename() << ", line " << get_line(qi::_3) << ": " << e.what() << "."); - } - } - } - - storm::expressions::Expression PrismParser::createOrExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const { - if (!this->secondRun) { - return storm::expressions::Expression::createFalse(); - } else { - try { - return e1 || e2; - } catch (storm::exceptions::InvalidTypeException const& e) { - LOG_THROW(false, storm::exceptions::WrongFormatException, "Parsing error in " << this->getFilename() << ", line " << get_line(qi::_3) << ": " << e.what() << "."); - } - } - } - - storm::expressions::Expression PrismParser::createAndExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const { - if (!this->secondRun) { - return storm::expressions::Expression::createFalse(); - } else { - try{ - return e1 && e2; - } catch (storm::exceptions::InvalidTypeException const& e) { - LOG_THROW(false, storm::exceptions::WrongFormatException, "Parsing error in " << this->getFilename() << ", line " << get_line(qi::_3) << ": " << e.what() << "."); - } - } - } - - storm::expressions::Expression PrismParser::createGreaterExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const { - if (!this->secondRun) { - return storm::expressions::Expression::createFalse(); - } else { - try { - return e1 > e2; - } catch (storm::exceptions::InvalidTypeException const& e) { - LOG_THROW(false, storm::exceptions::WrongFormatException, "Parsing error in " << this->getFilename() << ", line " << get_line(qi::_3) << ": " << e.what() << "."); - } - } - } - - storm::expressions::Expression PrismParser::createGreaterOrEqualExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const { - if (!this->secondRun) { - return storm::expressions::Expression::createFalse(); - } else { - try { - return e1 >= e2; - } catch (storm::exceptions::InvalidTypeException const& e) { - LOG_THROW(false, storm::exceptions::WrongFormatException, "Parsing error in " << this->getFilename() << ", line " << get_line(qi::_3) << ": " << e.what() << "."); - } - } - } - - storm::expressions::Expression PrismParser::createLessExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const { - if (!this->secondRun) { - return storm::expressions::Expression::createFalse(); - } else { - try { - return e1 < e2; - } catch (storm::exceptions::InvalidTypeException const& e) { - LOG_THROW(false, storm::exceptions::WrongFormatException, "Parsing error in " << this->getFilename() << ", line " << get_line(qi::_3) << ": " << e.what() << "."); - } - } - } - - storm::expressions::Expression PrismParser::createLessOrEqualExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const { - if (!this->secondRun) { - return storm::expressions::Expression::createFalse(); - } else { - try { - return e1 <= e2; - } catch (storm::exceptions::InvalidTypeException const& e) { - LOG_THROW(false, storm::exceptions::WrongFormatException, "Parsing error in " << this->getFilename() << ", line " << get_line(qi::_3) << ": " << e.what() << "."); - } - } - } - - storm::expressions::Expression PrismParser::createEqualsExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const { - if (!this->secondRun) { - return storm::expressions::Expression::createFalse(); - } else { - try { - if (e1.hasBooleanReturnType() && e2.hasBooleanReturnType()) { - return e1.iff(e2); - } else { - return e1 == e2; - } - } catch (storm::exceptions::InvalidTypeException const& e) { - LOG_THROW(false, storm::exceptions::WrongFormatException, "Parsing error in " << this->getFilename() << ", line " << get_line(qi::_3) << ": " << e.what() << "."); - } - } - } - - storm::expressions::Expression PrismParser::createNotEqualsExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const { - if (!this->secondRun) { - return storm::expressions::Expression::createFalse(); - } else { - try { - if (e1.hasBooleanReturnType() && e2.hasBooleanReturnType()) { - return e1 ^ e2; - } else { - return e1 != e2; - } - } catch (storm::exceptions::InvalidTypeException const& e) { - LOG_THROW(false, storm::exceptions::WrongFormatException, "Parsing error in " << this->getFilename() << ", line " << get_line(qi::_3) << ": " << e.what() << "."); - } - } - } - - storm::expressions::Expression PrismParser::createPlusExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const { - if (!this->secondRun) { - return storm::expressions::Expression::createFalse(); - } else { - try { - return e1 + e2; - } catch (storm::exceptions::InvalidTypeException const& e) { - LOG_THROW(false, storm::exceptions::WrongFormatException, "Parsing error in " << this->getFilename() << ", line " << get_line(qi::_3) << ": " << e.what() << "."); - } - } - } - - storm::expressions::Expression PrismParser::createMinusExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const { - if (!this->secondRun) { - return storm::expressions::Expression::createFalse(); - } else { - try { - return e1 - e2; - } catch (storm::exceptions::InvalidTypeException const& e) { - LOG_THROW(false, storm::exceptions::WrongFormatException, "Parsing error in " << this->getFilename() << ", line " << get_line(qi::_3) << ": " << e.what() << "."); - } - } - } - - storm::expressions::Expression PrismParser::createMultExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const { - if (!this->secondRun) { - return storm::expressions::Expression::createFalse(); - } else { - try { - return e1 * e2; - } catch (storm::exceptions::InvalidTypeException const& e) { - LOG_THROW(false, storm::exceptions::WrongFormatException, "Parsing error in " << this->getFilename() << ", line " << get_line(qi::_3) << ": " << e.what() << "."); - } - } - } - - storm::expressions::Expression PrismParser::createDivExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const { - if (!this->secondRun) { - return storm::expressions::Expression::createFalse(); - } else { - try { - return e1 / e2; - } catch (storm::exceptions::InvalidTypeException const& e) { - LOG_THROW(false, storm::exceptions::WrongFormatException, "Parsing error in " << this->getFilename() << ", line " << get_line(qi::_3) << ": " << e.what() << "."); - } - } - } - - storm::expressions::Expression PrismParser::createNotExpression(storm::expressions::Expression e1) const { - if (!this->secondRun) { - return storm::expressions::Expression::createFalse(); - } else { - try { - return !e1; - } catch (storm::exceptions::InvalidTypeException const& e) { - LOG_THROW(false, storm::exceptions::WrongFormatException, "Parsing error in " << this->getFilename() << ", line " << get_line(qi::_3) << ": " << e.what() << "."); - } - } - } - - storm::expressions::Expression PrismParser::createMinusExpression(storm::expressions::Expression e1) const { - if (!this->secondRun) { - return storm::expressions::Expression::createFalse(); - } else { - try { - return -e1; - } catch (storm::exceptions::InvalidTypeException const& e) { - LOG_THROW(false, storm::exceptions::WrongFormatException, "Parsing error in " << this->getFilename() << ", line " << get_line(qi::_3) << ": " << e.what() << "."); - } - } - } - - storm::expressions::Expression PrismParser::createTrueExpression() const { - if (!this->secondRun) { - return storm::expressions::Expression::createFalse(); - } else { - return storm::expressions::Expression::createTrue(); - } - } - - storm::expressions::Expression PrismParser::createFalseExpression() const { - if (!this->secondRun) { - return storm::expressions::Expression::createFalse(); - } else { - return storm::expressions::Expression::createFalse(); - } - } - - storm::expressions::Expression PrismParser::createDoubleLiteralExpression(double value, bool& pass) const { - // If we are not supposed to accept double expressions, we reject it by setting pass to false. - if (!this->allowDoubleLiteralsFlag) { - pass = false; - } - - if (!this->secondRun) { - return storm::expressions::Expression::createFalse(); - } else { - return storm::expressions::Expression::createDoubleLiteral(value); - } - } - - storm::expressions::Expression PrismParser::createIntegerLiteralExpression(int value) const { - if (!this->secondRun) { - return storm::expressions::Expression::createFalse(); - } else { - return storm::expressions::Expression::createIntegerLiteral(static_cast(value)); - } - } - - storm::expressions::Expression PrismParser::createMinimumExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const { - if (!this->secondRun) { - return storm::expressions::Expression::createFalse(); - } else { - try { - return storm::expressions::Expression::minimum(e1, e2); - } catch (storm::exceptions::InvalidTypeException const& e) { - LOG_THROW(false, storm::exceptions::WrongFormatException, "Parsing error in " << this->getFilename() << ", line " << get_line(qi::_3) << ": " << e.what() << "."); - } - } - } - - storm::expressions::Expression PrismParser::createMaximumExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const { - if (!this->secondRun) { - return storm::expressions::Expression::createFalse(); - } else { - try { - return storm::expressions::Expression::maximum(e1, e2); - } catch (storm::exceptions::InvalidTypeException const& e) { - LOG_THROW(false, storm::exceptions::WrongFormatException, "Parsing error in " << this->getFilename() << ", line " << get_line(qi::_3) << ": " << e.what() << "."); - } - } - } - - storm::expressions::Expression PrismParser::createFloorExpression(storm::expressions::Expression e1) const { - if (!this->secondRun) { - return storm::expressions::Expression::createFalse(); - } else { - try { - return e1.floor(); - } catch (storm::exceptions::InvalidTypeException const& e) { - LOG_THROW(false, storm::exceptions::WrongFormatException, "Parsing error in " << this->getFilename() << ", line " << get_line(qi::_3) << ": " << e.what() << "."); - } - } - } - - storm::expressions::Expression PrismParser::createCeilExpression(storm::expressions::Expression e1) const { - if (!this->secondRun) { - return storm::expressions::Expression::createFalse(); - } else { - try { - return e1.ceil(); - } catch (storm::exceptions::InvalidTypeException const& e) { - LOG_THROW(false, storm::exceptions::WrongFormatException, "Parsing error in " << this->getFilename() << ", line " << get_line(qi::_3) << ": " << e.what() << "."); - } - } - } - - storm::expressions::Expression PrismParser::getIdentifierExpression(std::string const& identifier) const { - if (!this->secondRun) { - return storm::expressions::Expression::createFalse(); - } else { - storm::expressions::Expression const* expression = this->identifiers_.find(identifier); - LOG_THROW(expression != nullptr, storm::exceptions::WrongFormatException, "Parsing error in " << this->getFilename() << ", line " << get_line(qi::_3) << ": Undeclared identifier '" << identifier << "'."); - return *expression; - } - } - storm::prism::Constant PrismParser::createUndefinedBooleanConstant(std::string const& newConstant) const { if (!this->secondRun) { LOG_THROW(this->identifiers_.find(newConstant) == nullptr, storm::exceptions::WrongFormatException, "Parsing error in " << this->getFilename() << ", line " << get_line(qi::_3) << ": Duplicate identifier '" << newConstant << "'."); diff --git a/src/parser/PrismParser.h b/src/parser/PrismParser.h index e1a9136ba..bdf84af24 100644 --- a/src/parser/PrismParser.h +++ b/src/parser/PrismParser.h @@ -6,23 +6,8 @@ #include #include -// Include boost spirit. -#define BOOST_SPIRIT_USE_PHOENIX_V3 -#include -#include -#include -#include -#include - -namespace qi = boost::spirit::qi; -namespace phoenix = boost::phoenix; - -typedef std::string::const_iterator BaseIteratorType; -typedef boost::spirit::line_pos_iterator PositionIteratorType; -typedef PositionIteratorType Iterator; -typedef BOOST_TYPEOF(boost::spirit::ascii::space | qi::lit("//") >> *(qi::char_ - qi::eol) >> qi::eol) Skipper; -typedef BOOST_TYPEOF(qi::lit("//") >> *(qi::char_ - qi::eol) >> qi::eol | boost::spirit::ascii::space) Skipper2; - +#include "src/parser/SpiritParserDefinitions.h" +#include "src/parser/ExpressionParser.h" #include "src/storage/prism/Program.h" #include "src/storage/expressions/Expression.h" #include "src/storage/expressions/Expressions.h" @@ -113,17 +98,6 @@ namespace storm { } }; - // Functor used for displaying error information. - struct ErrorHandler { - typedef qi::error_handler_result result_type; - - template - qi::error_handler_result operator()(T1 b, T2 e, T3 where, T4 const& what) const { - LOG_THROW(false, storm::exceptions::WrongFormatException, "Parsing error in line " << get_line(where) << ": " << " expecting " << what << "."); - return qi::fail; - } - }; - // Functor used for annotating entities with line number information. class PositionAnnotation { public: @@ -165,9 +139,6 @@ namespace storm { */ void allowDoubleLiterals(bool flag); - // A flag that stores wether to allow or forbid double literals in parsed expressions. - bool allowDoubleLiteralsFlag; - // The name of the file being parsed. std::string filename; @@ -179,7 +150,6 @@ namespace storm { std::string const& getFilename() const; // A function used for annotating the entities with their position. - phoenix::function handler; phoenix::function annotate; // The starting point of the grammar. @@ -237,60 +207,18 @@ namespace storm { // Rules for identifier parsing. qi::rule identifier; - // Rules for parsing a composed expression. - qi::rule expression; - qi::rule iteExpression; - qi::rule, Skipper> orExpression; - qi::rule andExpression; - qi::rule relativeExpression; - qi::rule, Skipper> equalityExpression; - qi::rule, Skipper> plusExpression; - qi::rule, Skipper> multiplicationExpression; - qi::rule unaryExpression; - qi::rule atomicExpression; - qi::rule literalExpression; - qi::rule identifierExpression; - qi::rule, Skipper> minMaxExpression; - qi::rule, Skipper> floorCeilExpression; - - // Parser that is used to recognize doubles only (as opposed to Spirit's double_ parser). - boost::spirit::qi::real_parser> strict_double; - // Parsers that recognize special keywords and model types. storm::parser::PrismParser::keywordsStruct keywords_; storm::parser::PrismParser::modelTypeStruct modelType_; qi::symbols identifiers_; + // Parser used for recognizing expressions. + storm::parser::ExpressionParser expressionParser; + // Helper methods used in the grammar. bool isValidIdentifier(std::string const& identifier); bool addInitialStatesConstruct(storm::expressions::Expression initialStatesExpression, GlobalProgramInformation& globalProgramInformation); - storm::expressions::Expression createIteExpression(storm::expressions::Expression e1, storm::expressions::Expression e2, storm::expressions::Expression e3) const; - storm::expressions::Expression createImpliesExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const; - storm::expressions::Expression createOrExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const; - storm::expressions::Expression createAndExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const; - storm::expressions::Expression createGreaterExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const; - storm::expressions::Expression createGreaterOrEqualExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const; - storm::expressions::Expression createLessExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const; - storm::expressions::Expression createLessOrEqualExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const; - storm::expressions::Expression createEqualsExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const; - storm::expressions::Expression createNotEqualsExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const; - storm::expressions::Expression createPlusExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const; - storm::expressions::Expression createMinusExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const; - storm::expressions::Expression createMultExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const; - storm::expressions::Expression createDivExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const; - storm::expressions::Expression createNotExpression(storm::expressions::Expression e1) const; - storm::expressions::Expression createMinusExpression(storm::expressions::Expression e1) const; - storm::expressions::Expression createTrueExpression() const; - storm::expressions::Expression createFalseExpression() const; - storm::expressions::Expression createDoubleLiteralExpression(double value, bool& pass) const; - storm::expressions::Expression createIntegerLiteralExpression(int value) const; - storm::expressions::Expression createMinimumExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const; - storm::expressions::Expression createMaximumExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const; - storm::expressions::Expression createFloorExpression(storm::expressions::Expression e1) const; - storm::expressions::Expression createCeilExpression(storm::expressions::Expression e1) const; - storm::expressions::Expression getIdentifierExpression(std::string const& identifier) const; - storm::prism::Constant createUndefinedBooleanConstant(std::string const& newConstant) const; storm::prism::Constant createUndefinedIntegerConstant(std::string const& newConstant) const; storm::prism::Constant createUndefinedDoubleConstant(std::string const& newConstant) const; diff --git a/src/parser/SpiritParserDefinitions.h b/src/parser/SpiritParserDefinitions.h new file mode 100644 index 000000000..feb78f77d --- /dev/null +++ b/src/parser/SpiritParserDefinitions.h @@ -0,0 +1,21 @@ +#ifndef STORM_PARSER_SPIRITPARSERDEFINITIONS_H_ +#define STORM_PARSER_SPIRITPARSERDEFINITIONS_H_ + +// Include boost spirit. +#define BOOST_SPIRIT_USE_PHOENIX_V3 +#include +#include +#include +#include +#include + +namespace qi = boost::spirit::qi; +namespace phoenix = boost::phoenix; + +typedef std::string::const_iterator BaseIteratorType; +typedef boost::spirit::line_pos_iterator PositionIteratorType; +typedef PositionIteratorType Iterator; + +typedef BOOST_TYPEOF(boost::spirit::ascii::space | qi::lit("//") >> *(qi::char_ - qi::eol) >> qi::eol) Skipper; + +#endif /* STORM_PARSER_SPIRITPARSERDEFINITIONS_H_ */ \ No newline at end of file diff --git a/src/storage/dd/CuddDd.cpp b/src/storage/dd/CuddDd.cpp index cdff92894..57cda7018 100644 --- a/src/storage/dd/CuddDd.cpp +++ b/src/storage/dd/CuddDd.cpp @@ -2,7 +2,9 @@ #include #include "src/storage/dd/CuddDd.h" +#include "src/storage/dd/CuddOdd.h" #include "src/storage/dd/CuddDdManager.h" +#include "src/utility/vector.h" #include "src/exceptions/InvalidArgumentException.h" @@ -118,138 +120,141 @@ namespace storm { } Dd Dd::equals(Dd const& other) const { - Dd result(*this); - result.cuddAdd = result.cuddAdd.Equals(other.getCuddAdd()); - return result; + std::set metaVariableNames(this->getContainedMetaVariableNames()); + metaVariableNames.insert(other.getContainedMetaVariableNames().begin(), other.getContainedMetaVariableNames().end()); + return Dd(this->getDdManager(), this->getCuddAdd().Equals(other.getCuddAdd()), metaVariableNames); } Dd Dd::notEquals(Dd const& other) const { - Dd result(*this); - result.cuddAdd = result.cuddAdd.NotEquals(other.getCuddAdd()); - return result; + std::set metaVariableNames(this->getContainedMetaVariableNames()); + metaVariableNames.insert(other.getContainedMetaVariableNames().begin(), other.getContainedMetaVariableNames().end()); + return Dd(this->getDdManager(), this->getCuddAdd().NotEquals(other.getCuddAdd()), metaVariableNames); } Dd Dd::less(Dd const& other) const { - Dd result(*this); - result.cuddAdd = result.cuddAdd.LessThan(other.getCuddAdd()); - return result; + std::set metaVariableNames(this->getContainedMetaVariableNames()); + metaVariableNames.insert(other.getContainedMetaVariableNames().begin(), other.getContainedMetaVariableNames().end()); + return Dd(this->getDdManager(), this->getCuddAdd().LessThan(other.getCuddAdd()), metaVariableNames); } Dd Dd::lessOrEqual(Dd const& other) const { - Dd result(*this); - result.cuddAdd = result.cuddAdd.LessThanOrEqual(other.getCuddAdd()); - return result; + std::set metaVariableNames(this->getContainedMetaVariableNames()); + metaVariableNames.insert(other.getContainedMetaVariableNames().begin(), other.getContainedMetaVariableNames().end()); + return Dd(this->getDdManager(), this->getCuddAdd().LessThanOrEqual(other.getCuddAdd()), metaVariableNames); } Dd Dd::greater(Dd const& other) const { - Dd result(*this); - result.cuddAdd = result.cuddAdd.GreaterThan(other.getCuddAdd()); - return result; + std::set metaVariableNames(this->getContainedMetaVariableNames()); + metaVariableNames.insert(other.getContainedMetaVariableNames().begin(), other.getContainedMetaVariableNames().end()); + return Dd(this->getDdManager(), this->getCuddAdd().GreaterThan(other.getCuddAdd()), metaVariableNames); } Dd Dd::greaterOrEqual(Dd const& other) const { - Dd result(*this); - result.cuddAdd = result.cuddAdd.GreaterThanOrEqual(other.getCuddAdd()); - return result; + std::set metaVariableNames(this->getContainedMetaVariableNames()); + metaVariableNames.insert(other.getContainedMetaVariableNames().begin(), other.getContainedMetaVariableNames().end()); + return Dd(this->getDdManager(), this->getCuddAdd().GreaterThanOrEqual(other.getCuddAdd()), metaVariableNames); } Dd Dd::minimum(Dd const& other) const { std::set metaVariableNames(this->getContainedMetaVariableNames()); metaVariableNames.insert(other.getContainedMetaVariableNames().begin(), other.getContainedMetaVariableNames().end()); - return Dd(this->getDdManager(), this->getCuddAdd().Minimum(other.getCuddAdd()), metaVariableNames); } Dd Dd::maximum(Dd const& other) const { std::set metaVariableNames(this->getContainedMetaVariableNames()); metaVariableNames.insert(other.getContainedMetaVariableNames().begin(), other.getContainedMetaVariableNames().end()); - return Dd(this->getDdManager(), this->getCuddAdd().Maximum(other.getCuddAdd()), metaVariableNames); } - void Dd::existsAbstract(std::set const& metaVariableNames) { + Dd Dd::existsAbstract(std::set const& metaVariableNames) const { Dd cubeDd(this->getDdManager()->getOne()); + std::set newMetaVariables = this->getContainedMetaVariableNames(); for (auto const& metaVariableName : metaVariableNames) { // First check whether the DD contains the meta variable and erase it, if this is the case. if (!this->containsMetaVariable(metaVariableName)) { throw storm::exceptions::InvalidArgumentException() << "Cannot abstract from meta variable that is not present in the DD."; } - this->getContainedMetaVariableNames().erase(metaVariableName); + newMetaVariables.erase(metaVariableName); DdMetaVariable const& metaVariable = this->getDdManager()->getMetaVariable(metaVariableName); cubeDd *= metaVariable.getCube(); } - this->cuddAdd = this->cuddAdd.OrAbstract(cubeDd.getCuddAdd()); + return Dd(this->getDdManager(), this->cuddAdd.OrAbstract(cubeDd.getCuddAdd()), newMetaVariables); } - void Dd::universalAbstract(std::set const& metaVariableNames) { + Dd Dd::universalAbstract(std::set const& metaVariableNames) const { Dd cubeDd(this->getDdManager()->getOne()); + std::set newMetaVariables = this->getContainedMetaVariableNames(); for (auto const& metaVariableName : metaVariableNames) { // First check whether the DD contains the meta variable and erase it, if this is the case. if (!this->containsMetaVariable(metaVariableName)) { throw storm::exceptions::InvalidArgumentException() << "Cannot abstract from meta variable that is not present in the DD."; } - this->getContainedMetaVariableNames().erase(metaVariableName); + newMetaVariables.erase(metaVariableName); DdMetaVariable const& metaVariable = this->getDdManager()->getMetaVariable(metaVariableName); cubeDd *= metaVariable.getCube(); } - this->cuddAdd = this->cuddAdd.UnivAbstract(cubeDd.getCuddAdd()); + return Dd(this->getDdManager(), this->cuddAdd.UnivAbstract(cubeDd.getCuddAdd()), newMetaVariables); } - void Dd::sumAbstract(std::set const& metaVariableNames) { + Dd Dd::sumAbstract(std::set const& metaVariableNames) const { Dd cubeDd(this->getDdManager()->getOne()); + std::set newMetaVariables = this->getContainedMetaVariableNames(); for (auto const& metaVariableName : metaVariableNames) { // First check whether the DD contains the meta variable and erase it, if this is the case. if (!this->containsMetaVariable(metaVariableName)) { throw storm::exceptions::InvalidArgumentException() << "Cannot abstract from meta variable that is not present in the DD."; } - this->getContainedMetaVariableNames().erase(metaVariableName); + newMetaVariables.erase(metaVariableName); DdMetaVariable const& metaVariable = this->getDdManager()->getMetaVariable(metaVariableName); cubeDd *= metaVariable.getCube(); } - this->cuddAdd = this->cuddAdd.ExistAbstract(cubeDd.getCuddAdd()); + return Dd(this->getDdManager(), this->cuddAdd.ExistAbstract(cubeDd.getCuddAdd()), newMetaVariables); } - void Dd::minAbstract(std::set const& metaVariableNames) { + Dd Dd::minAbstract(std::set const& metaVariableNames) const { Dd cubeDd(this->getDdManager()->getOne()); + std::set newMetaVariables = this->getContainedMetaVariableNames(); for (auto const& metaVariableName : metaVariableNames) { // First check whether the DD contains the meta variable and erase it, if this is the case. if (!this->containsMetaVariable(metaVariableName)) { throw storm::exceptions::InvalidArgumentException() << "Cannot abstract from meta variable that is not present in the DD."; } - this->getContainedMetaVariableNames().erase(metaVariableName); + newMetaVariables.erase(metaVariableName); DdMetaVariable const& metaVariable = this->getDdManager()->getMetaVariable(metaVariableName); cubeDd *= metaVariable.getCube(); } - this->cuddAdd = this->cuddAdd.MinAbstract(cubeDd.getCuddAdd()); + return Dd(this->getDdManager(), this->cuddAdd.MinAbstract(cubeDd.getCuddAdd()), newMetaVariables); } - void Dd::maxAbstract(std::set const& metaVariableNames) { + Dd Dd::maxAbstract(std::set const& metaVariableNames) const { Dd cubeDd(this->getDdManager()->getOne()); + std::set newMetaVariables = this->getContainedMetaVariableNames(); for (auto const& metaVariableName : metaVariableNames) { // First check whether the DD contains the meta variable and erase it, if this is the case. if (!this->containsMetaVariable(metaVariableName)) { throw storm::exceptions::InvalidArgumentException() << "Cannot abstract from meta variable that is not present in the DD."; } - this->getContainedMetaVariableNames().erase(metaVariableName); + newMetaVariables.erase(metaVariableName); DdMetaVariable const& metaVariable = this->getDdManager()->getMetaVariable(metaVariableName); cubeDd *= metaVariable.getCube(); } - this->cuddAdd = this->cuddAdd.MaxAbstract(cubeDd.getCuddAdd()); + return Dd(this->getDdManager(), this->cuddAdd.MaxAbstract(cubeDd.getCuddAdd()), newMetaVariables); } bool Dd::equalModuloPrecision(Dd const& other, double precision, bool relative) const { @@ -314,6 +319,36 @@ namespace storm { return Dd(this->getDdManager(), this->cuddAdd.MatrixMultiply(otherMatrix.getCuddAdd(), summationDdVariables), containedMetaVariableNames); } + Dd Dd::greater(double value) const { + return Dd(this->getDdManager(), this->getCuddAdd().BddStrictThreshold(value).Add(), this->getContainedMetaVariableNames()); + } + + Dd Dd::greaterOrEqual(double value) const { + return Dd(this->getDdManager(), this->getCuddAdd().BddThreshold(value).Add(), this->getContainedMetaVariableNames()); + } + + Dd Dd::notZero() const { + return Dd(this->getDdManager(), this->getCuddAdd().BddPattern().Add(), this->getContainedMetaVariableNames()); + } + + Dd Dd::constrain(Dd const& constraint) const { + std::set metaVariableNames(this->getContainedMetaVariableNames()); + metaVariableNames.insert(constraint.getContainedMetaVariableNames().begin(), constraint.getContainedMetaVariableNames().end()); + + return Dd(this->getDdManager(), this->getCuddAdd().Constrain(constraint.getCuddAdd()), metaVariableNames); + } + + Dd Dd::restrict(Dd const& constraint) const { + std::set metaVariableNames(this->getContainedMetaVariableNames()); + metaVariableNames.insert(constraint.getContainedMetaVariableNames().begin(), constraint.getContainedMetaVariableNames().end()); + + return Dd(this->getDdManager(), this->getCuddAdd().Restrict(constraint.getCuddAdd()), metaVariableNames); + } + + Dd Dd::getSupport() const { + return Dd(this->getDdManager(), this->getCuddAdd().Support().Add(), this->getContainedMetaVariableNames()); + } + uint_fast64_t Dd::getNonZeroCount() const { std::size_t numberOfDdVariables = 0; for (auto const& metaVariableName : this->containedMetaVariableNames) { @@ -379,7 +414,7 @@ namespace storm { } Dd value = *this * valueEncoding; - value.sumAbstract(this->getContainedMetaVariableNames()); + value = value.sumAbstract(this->getContainedMetaVariableNames()); return static_cast(Cudd_V(value.getCuddAdd().getNode())); } @@ -395,6 +430,319 @@ namespace storm { return Cudd_IsConstant(this->cuddAdd.getNode()); } + uint_fast64_t Dd::getIndex() const { + return static_cast(this->getCuddAdd().NodeReadIndex()); + } + + template + std::vector Dd::toVector() const { + return this->toVector(Odd(*this)); + } + + template + std::vector Dd::toVector(Odd const& rowOdd) const { + std::vector result(rowOdd.getTotalOffset()); + std::vector ddVariableIndices = this->getSortedVariableIndices(); + addToVectorRec(this->getCuddAdd().getNode(), 0, ddVariableIndices.size(), 0, rowOdd, ddVariableIndices, result); + return result; + } + + storm::storage::SparseMatrix Dd::toMatrix() const { + std::set rowVariables; + std::set columnVariables; + + for (auto const& variableName : this->getContainedMetaVariableNames()) { + if (variableName.size() > 0 && variableName.back() == '\'') { + columnVariables.insert(variableName); + } else { + rowVariables.insert(variableName); + } + } + + return toMatrix(rowVariables, columnVariables, Odd(this->existsAbstract(rowVariables)), Odd(this->existsAbstract(columnVariables))); + } + + storm::storage::SparseMatrix Dd::toMatrix(storm::dd::Odd const& rowOdd, storm::dd::Odd const& columnOdd) const { + std::set rowMetaVariables; + std::set columnMetaVariables; + + for (auto const& variableName : this->getContainedMetaVariableNames()) { + if (variableName.size() > 0 && variableName.back() == '\'') { + columnMetaVariables.insert(variableName); + } else { + rowMetaVariables.insert(variableName); + } + } + + return toMatrix(rowMetaVariables, columnMetaVariables, rowOdd, columnOdd); + } + + storm::storage::SparseMatrix Dd::toMatrix(std::set const& rowMetaVariables, std::set const& columnMetaVariables, storm::dd::Odd const& rowOdd, storm::dd::Odd const& columnOdd) const { + std::vector ddRowVariableIndices; + std::vector ddColumnVariableIndices; + + for (auto const& variableName : rowMetaVariables) { + DdMetaVariable const& metaVariable = this->getDdManager()->getMetaVariable(variableName); + for (auto const& ddVariable : metaVariable.getDdVariables()) { + ddRowVariableIndices.push_back(ddVariable.getIndex()); + } + } + for (auto const& variableName : columnMetaVariables) { + DdMetaVariable const& metaVariable = this->getDdManager()->getMetaVariable(variableName); + for (auto const& ddVariable : metaVariable.getDdVariables()) { + ddColumnVariableIndices.push_back(ddVariable.getIndex()); + } + } + + // Prepare the vectors that represent the matrix. + std::vector rowIndications(rowOdd.getTotalOffset() + 1); + std::vector> columnsAndValues(this->getNonZeroCount()); + + // Create a trivial row grouping. + std::vector trivialRowGroupIndices(rowIndications.size()); + uint_fast64_t i = 0; + for (auto& entry : trivialRowGroupIndices) { + entry = i; + ++i; + } + + // Use the toMatrixRec function to compute the number of elements in each row. Using the flag, we prevent + // it from actually generating the entries in the entry vector. + toMatrixRec(this->getCuddAdd().getNode(), rowIndications, columnsAndValues, trivialRowGroupIndices, rowOdd, columnOdd, 0, 0, ddRowVariableIndices.size() + ddColumnVariableIndices.size(), 0, 0, ddRowVariableIndices, ddColumnVariableIndices, false); + + // TODO: counting might be faster by just summing over the primed variables and then using the ODD to convert + // the resulting (DD) vector to an explicit vector. + + // Now that we computed the number of entries in each row, compute the corresponding offsets in the entry vector. + uint_fast64_t tmp = 0; + uint_fast64_t tmp2 = 0; + for (uint_fast64_t i = 1; i < rowIndications.size(); ++i) { + tmp2 = rowIndications[i]; + rowIndications[i] = rowIndications[i - 1] + tmp; + std::swap(tmp, tmp2); + } + rowIndications[0] = 0; + + // Now actually fill the entry vector. + toMatrixRec(this->getCuddAdd().getNode(), rowIndications, columnsAndValues, trivialRowGroupIndices, rowOdd, columnOdd, 0, 0, ddRowVariableIndices.size() + ddColumnVariableIndices.size(), 0, 0, ddRowVariableIndices, ddColumnVariableIndices, true); + + // Since the last call to toMatrixRec modified the rowIndications, we need to restore the correct values. + for (uint_fast64_t i = rowIndications.size() - 1; i > 0; --i) { + rowIndications[i] = rowIndications[i - 1]; + } + rowIndications[0] = 0; + + // Construct matrix and return result. + return storm::storage::SparseMatrix(columnOdd.getTotalOffset(), std::move(rowIndications), std::move(columnsAndValues), std::move(trivialRowGroupIndices)); + } + + storm::storage::SparseMatrix Dd::toMatrix(std::set const& rowMetaVariables, std::set const& columnMetaVariables, std::set const& groupMetaVariables, storm::dd::Odd const& rowOdd, storm::dd::Odd const& columnOdd) const { + std::vector ddRowVariableIndices; + std::vector ddColumnVariableIndices; + std::vector ddGroupVariableIndices; + std::set rowAndColumnMetaVariables; + + for (auto const& variableName : rowMetaVariables) { + DdMetaVariable const& metaVariable = this->getDdManager()->getMetaVariable(variableName); + for (auto const& ddVariable : metaVariable.getDdVariables()) { + ddRowVariableIndices.push_back(ddVariable.getIndex()); + } + rowAndColumnMetaVariables.insert(variableName); + } + for (auto const& variableName : columnMetaVariables) { + DdMetaVariable const& metaVariable = this->getDdManager()->getMetaVariable(variableName); + for (auto const& ddVariable : metaVariable.getDdVariables()) { + ddColumnVariableIndices.push_back(ddVariable.getIndex()); + } + rowAndColumnMetaVariables.insert(variableName); + } + for (auto const& variableName : groupMetaVariables) { + DdMetaVariable const& metaVariable = this->getDdManager()->getMetaVariable(variableName); + for (auto const& ddVariable : metaVariable.getDdVariables()) { + ddGroupVariableIndices.push_back(ddVariable.getIndex()); + } + } + + // TODO: assert that the group variables are at the very top of the variable ordering? + + // Start by computing the offsets (in terms of rows) for each row group. + Dd stateToNumberOfChoices = this->notZero().existsAbstract(columnMetaVariables).sumAbstract(groupMetaVariables); + std::vector rowGroupIndices = stateToNumberOfChoices.toVector(rowOdd); + rowGroupIndices.resize(rowGroupIndices.size() + 1); + uint_fast64_t tmp = 0; + uint_fast64_t tmp2 = 0; + for (uint_fast64_t i = 1; i < rowGroupIndices.size(); ++i) { + tmp2 = rowGroupIndices[i]; + rowGroupIndices[i] = rowGroupIndices[i - 1] + tmp; + std::swap(tmp, tmp2); + } + rowGroupIndices[0] = 0; + + // Next, we split the matrix into one for each group. This only works if the group variables are at the very + // top. + std::vector> groups; + splitGroupsRec(this->getCuddAdd().getNode(), groups, ddGroupVariableIndices, 0, ddGroupVariableIndices.size(), rowAndColumnMetaVariables); + + // Create the actual storage for the non-zero entries. + std::vector> columnsAndValues(this->getNonZeroCount()); + + // Now compute the indices at which the individual rows start. + std::vector rowIndications(rowGroupIndices.back() + 1); + std::vector> statesWithGroupEnabled(groups.size()); + for (uint_fast64_t i = 0; i < groups.size(); ++i) { + auto const& dd = groups[i]; + + toMatrixRec(dd.getCuddAdd().getNode(), rowIndications, columnsAndValues, rowGroupIndices, rowOdd, columnOdd, 0, 0, ddRowVariableIndices.size() + ddColumnVariableIndices.size(), 0, 0, ddRowVariableIndices, ddColumnVariableIndices, false); + + statesWithGroupEnabled[i] = dd.notZero().existsAbstract(columnMetaVariables); + addToVectorRec(statesWithGroupEnabled[i].getCuddAdd().getNode(), 0, ddRowVariableIndices.size(), 0, rowOdd, ddRowVariableIndices, rowGroupIndices); + } + + // Since we modified the rowGroupIndices, we need to restore the correct values. + for (uint_fast64_t i = rowGroupIndices.size() - 1; i > 0; --i) { + rowGroupIndices[i] = rowGroupIndices[i - 1]; + } + rowGroupIndices[0] = 0; + + // Now that we computed the number of entries in each row, compute the corresponding offsets in the entry vector. + tmp = 0; + tmp2 = 0; + for (uint_fast64_t i = 1; i < rowIndications.size(); ++i) { + tmp2 = rowIndications[i]; + rowIndications[i] = rowIndications[i - 1] + tmp; + std::swap(tmp, tmp2); + } + rowIndications[0] = 0; + + // Now actually fill the entry vector. + for (uint_fast64_t i = 0; i < groups.size(); ++i) { + auto const& dd = groups[i]; + + toMatrixRec(dd.getCuddAdd().getNode(), rowIndications, columnsAndValues, rowGroupIndices, rowOdd, columnOdd, 0, 0, ddRowVariableIndices.size() + ddColumnVariableIndices.size(), 0, 0, ddRowVariableIndices, ddColumnVariableIndices, true); + + addToVectorRec(statesWithGroupEnabled[i].getCuddAdd().getNode(), 0, ddRowVariableIndices.size(), 0, rowOdd, ddRowVariableIndices, rowGroupIndices); + } + + // Since we modified the rowGroupIndices, we need to restore the correct values. + for (uint_fast64_t i = rowGroupIndices.size() - 1; i > 0; --i) { + rowGroupIndices[i] = rowGroupIndices[i - 1]; + } + rowGroupIndices[0] = 0; + + // Since the last call to toMatrixRec modified the rowIndications, we need to restore the correct values. + for (uint_fast64_t i = rowIndications.size() - 1; i > 0; --i) { + rowIndications[i] = rowIndications[i - 1]; + } + rowIndications[0] = 0; + + return storm::storage::SparseMatrix(columnOdd.getTotalOffset(), std::move(rowIndications), std::move(columnsAndValues), std::move(rowGroupIndices)); + } + + void Dd::toMatrixRec(DdNode const* dd, std::vector& rowIndications, std::vector>& columnsAndValues, std::vector const& rowGroupOffsets, Odd const& rowOdd, Odd const& columnOdd, uint_fast64_t currentRowLevel, uint_fast64_t currentColumnLevel, uint_fast64_t maxLevel, uint_fast64_t currentRowOffset, uint_fast64_t currentColumnOffset, std::vector const& ddRowVariableIndices, std::vector const& ddColumnVariableIndices, bool generateValues) const { + // For the empty DD, we do not need to add any entries. + if (dd == this->getDdManager()->getZero().getCuddAdd().getNode()) { + return; + } + + // If we are at the maximal level, the value to be set is stored as a constant in the DD. + if (currentRowLevel + currentColumnLevel == maxLevel) { + if (generateValues) { + columnsAndValues[rowIndications[rowGroupOffsets[currentRowOffset]]] = storm::storage::MatrixEntry(currentColumnOffset, Cudd_V(dd)); + } + ++rowIndications[rowGroupOffsets[currentRowOffset]]; + } else { + DdNode const* elseElse; + DdNode const* elseThen; + DdNode const* thenElse; + DdNode const* thenThen; + + if (ddColumnVariableIndices[currentColumnLevel] < dd->index) { + elseElse = elseThen = thenElse = thenThen = dd; + } else if (ddRowVariableIndices[currentColumnLevel] < dd->index) { + elseElse = thenElse = Cudd_E(dd); + elseThen = thenThen = Cudd_T(dd); + } else { + DdNode const* elseNode = Cudd_E(dd); + if (ddColumnVariableIndices[currentColumnLevel] < elseNode->index) { + elseElse = elseThen = elseNode; + } else { + elseElse = Cudd_E(elseNode); + elseThen = Cudd_T(elseNode); + } + + DdNode const* thenNode = Cudd_T(dd); + if (ddColumnVariableIndices[currentColumnLevel] < thenNode->index) { + thenElse = thenThen = thenNode; + } else { + thenElse = Cudd_E(thenNode); + thenThen = Cudd_T(thenNode); + } + } + + // Visit else-else. + toMatrixRec(elseElse, rowIndications, columnsAndValues, rowGroupOffsets, rowOdd.getElseSuccessor(), columnOdd.getElseSuccessor(), currentRowLevel + 1, currentColumnLevel + 1, maxLevel, currentRowOffset, currentColumnOffset, ddRowVariableIndices, ddColumnVariableIndices, generateValues); + // Visit else-then. + toMatrixRec(elseThen, rowIndications, columnsAndValues, rowGroupOffsets, rowOdd.getElseSuccessor(), columnOdd.getThenSuccessor(), currentRowLevel + 1, currentColumnLevel + 1, maxLevel, currentRowOffset, currentColumnOffset + columnOdd.getElseOffset(), ddRowVariableIndices, ddColumnVariableIndices, generateValues); + // Visit then-else. + toMatrixRec(thenElse, rowIndications, columnsAndValues, rowGroupOffsets, rowOdd.getThenSuccessor(), columnOdd.getElseSuccessor(), currentRowLevel + 1, currentColumnLevel + 1, maxLevel, currentRowOffset + rowOdd.getElseOffset(), currentColumnOffset, ddRowVariableIndices, ddColumnVariableIndices, generateValues); + // Visit then-then. + toMatrixRec(thenThen, rowIndications, columnsAndValues, rowGroupOffsets, rowOdd.getThenSuccessor(), columnOdd.getThenSuccessor(), currentRowLevel + 1, currentColumnLevel + 1, maxLevel, currentRowOffset + rowOdd.getElseOffset(), currentColumnOffset + columnOdd.getElseOffset(), ddRowVariableIndices, ddColumnVariableIndices, generateValues); + } + } + + void Dd::splitGroupsRec(DdNode* dd, std::vector>& groups, std::vector const& ddGroupVariableIndices, uint_fast64_t currentLevel, uint_fast64_t maxLevel, std::set const& remainingMetaVariables) const { + // For the empty DD, we do not need to create a group. + if (dd == this->getDdManager()->getZero().getCuddAdd().getNode()) { + return; + } + + if (currentLevel == maxLevel) { + groups.push_back(Dd(this->getDdManager(), ADD(this->getDdManager()->getCuddManager(), dd), remainingMetaVariables)); + } else if (ddGroupVariableIndices[currentLevel] < dd->index) { + splitGroupsRec(dd, groups, ddGroupVariableIndices, currentLevel + 1, maxLevel, remainingMetaVariables); + splitGroupsRec(dd, groups, ddGroupVariableIndices, currentLevel + 1, maxLevel, remainingMetaVariables); + } else { + splitGroupsRec(Cudd_E(dd), groups, ddGroupVariableIndices, currentLevel + 1, maxLevel, remainingMetaVariables); + splitGroupsRec(Cudd_T(dd), groups, ddGroupVariableIndices, currentLevel + 1, maxLevel, remainingMetaVariables); + } + } + + template + void Dd::addToVectorRec(DdNode const* dd, uint_fast64_t currentLevel, uint_fast64_t maxLevel, uint_fast64_t currentOffset, Odd const& odd, std::vector const& ddVariableIndices, std::vector& targetVector) const { + // For the empty DD, we do not need to add any entries. + if (dd == this->getDdManager()->getZero().getCuddAdd().getNode()) { + return; + } + + // If we are at the maximal level, the value to be set is stored as a constant in the DD. + if (currentLevel == maxLevel) { + targetVector[currentOffset] += static_cast(Cudd_V(dd)); + } else if (ddVariableIndices[currentLevel] < dd->index) { + // If we skipped a level, we need to enumerate the explicit entries for the case in which the bit is set + // and for the one in which it is not set. + addToVectorRec(dd, currentLevel + 1, maxLevel, currentOffset, odd.getElseSuccessor(), ddVariableIndices, targetVector); + addToVectorRec(dd, currentLevel + 1, maxLevel, currentOffset + odd.getElseOffset(), odd.getThenSuccessor(), ddVariableIndices, targetVector); + } else { + // Otherwise, we simply recursively call the function for both (different) cases. + addToVectorRec(Cudd_E(dd), currentLevel + 1, maxLevel, currentOffset, odd.getElseSuccessor(), ddVariableIndices, targetVector); + addToVectorRec(Cudd_T(dd), currentLevel + 1, maxLevel, currentOffset + odd.getElseOffset(), odd.getThenSuccessor(), ddVariableIndices, targetVector); + } + } + + std::vector Dd::getSortedVariableIndices() const { + std::vector ddVariableIndices; + for (auto const& metaVariableName : this->getContainedMetaVariableNames()) { + auto const& metaVariable = this->getDdManager()->getMetaVariable(metaVariableName); + for (auto const& ddVariable : metaVariable.getDdVariables()) { + ddVariableIndices.push_back(ddVariable.getIndex()); + } + } + + // Next, we need to sort them, since they may be arbitrarily ordered otherwise. + std::sort(ddVariableIndices.begin(), ddVariableIndices.end()); + return ddVariableIndices; + } + bool Dd::containsMetaVariable(std::string const& metaVariableName) const { auto const& metaVariable = containedMetaVariableNames.find(metaVariableName); return metaVariable != containedMetaVariableNames.end(); @@ -477,17 +825,68 @@ namespace storm { int* cube; double value; DdGen* generator = this->getCuddAdd().FirstCube(&cube, &value); - return DdForwardIterator(this->getDdManager(), generator, cube, value, Cudd_IsGenEmpty(generator), &this->getContainedMetaVariableNames(), enumerateDontCareMetaVariables); + return DdForwardIterator(this->getDdManager(), generator, cube, value, (Cudd_IsGenEmpty(generator) != 0), &this->getContainedMetaVariableNames(), enumerateDontCareMetaVariables); } DdForwardIterator Dd::end(bool enumerateDontCareMetaVariables) const { return DdForwardIterator(this->getDdManager(), nullptr, nullptr, 0, true, nullptr, enumerateDontCareMetaVariables); } + storm::expressions::Expression Dd::toExpression() const { + return toExpressionRecur(this->getCuddAdd().getNode(), this->getDdManager()->getDdVariableNames()); + } + + storm::expressions::Expression Dd::getMintermExpression() const { + // Note that we first transform the ADD into a BDD to convert all non-zero terminals to ones and therefore + // make the DD more compact. + Dd tmp(this->getDdManager(), this->getCuddAdd().BddPattern().Add(), this->getContainedMetaVariableNames()); + return getMintermExpressionRecur(this->getDdManager()->getCuddManager().getManager(), this->getCuddAdd().BddPattern().getNode(), this->getDdManager()->getDdVariableNames()); + } + + storm::expressions::Expression Dd::toExpressionRecur(DdNode const* dd, std::vector const& variableNames) { + // If the DD is a terminal node, we can simply return a constant expression. + if (Cudd_IsConstant(dd)) { + return storm::expressions::Expression::createDoubleLiteral(static_cast(Cudd_V(dd))); + } else { + return storm::expressions::Expression::createBooleanVariable(variableNames[dd->index]).ite(toExpressionRecur(Cudd_T(dd), variableNames), toExpressionRecur(Cudd_E(dd), variableNames)); + } + } + + storm::expressions::Expression Dd::getMintermExpressionRecur(::DdManager* manager, DdNode const* dd, std::vector const& variableNames) { + // If the DD is a terminal node, we can simply return a constant expression. + if (Cudd_IsConstant(dd)) { + if (Cudd_IsComplement(dd)) { + return storm::expressions::Expression::createBooleanLiteral(false); + } else { + return storm::expressions::Expression::createBooleanLiteral((dd == Cudd_ReadOne(manager)) ? true : false); + } + } else { + // Get regular versions of the pointers. + DdNode* regularDd = Cudd_Regular(dd); + DdNode* thenDd = Cudd_T(regularDd); + DdNode* elseDd = Cudd_E(regularDd); + + // Compute expression recursively. + storm::expressions::Expression result = storm::expressions::Expression::createBooleanVariable(variableNames[dd->index]).ite(getMintermExpressionRecur(manager, thenDd, variableNames), getMintermExpressionRecur(manager, elseDd, variableNames)); + if (Cudd_IsComplement(dd)) { + result = !result; + } + + return result; + } + } + std::ostream & operator<<(std::ostream& out, const Dd& dd) { dd.exportToDot(); return out; } - + + // Explicitly instantiate some templated functions. + template std::vector Dd::toVector() const; + template std::vector Dd::toVector(Odd const& rowOdd) const; + template void Dd::addToVectorRec(DdNode const* dd, uint_fast64_t currentLevel, uint_fast64_t maxLevel, uint_fast64_t currentOffset, Odd const& odd, std::vector const& ddVariableIndices, std::vector& targetVector) const; + template std::vector Dd::toVector() const; + template std::vector Dd::toVector(Odd const& rowOdd) const; + template void Dd::addToVectorRec(DdNode const* dd, uint_fast64_t currentLevel, uint_fast64_t maxLevel, uint_fast64_t currentOffset, Odd const& odd, std::vector const& ddVariableIndices, std::vector& targetVector) const; } } \ No newline at end of file diff --git a/src/storage/dd/CuddDd.h b/src/storage/dd/CuddDd.h index 9f638d329..316e30c89 100644 --- a/src/storage/dd/CuddDd.h +++ b/src/storage/dd/CuddDd.h @@ -8,6 +8,8 @@ #include "src/storage/dd/Dd.h" #include "src/storage/dd/CuddDdForwardIterator.h" +#include "src/storage/SparseMatrix.h" +#include "src/storage/expressions/Expression.h" #include "src/utility/OsDetection.h" // Include the C++-interface of CUDD. @@ -15,8 +17,9 @@ namespace storm { namespace dd { - // Forward-declare the DdManager class. + // Forward-declare some classes. template class DdManager; + template class Odd; template<> class Dd { @@ -24,6 +27,7 @@ namespace storm { // Declare the DdManager and DdIterator class as friend so it can access the internals of a DD. friend class DdManager; friend class DdForwardIterator; + friend class Odd; // Instantiate all copy/move constructors/assignments with the default implementation. Dd() = default; @@ -231,35 +235,35 @@ namespace storm { * * @param metaVariableNames The names of all meta variables from which to abstract. */ - void existsAbstract(std::set const& metaVariableNames); + Dd existsAbstract(std::set const& metaVariableNames) const; /*! * Universally abstracts from the given meta variables. * * @param metaVariableNames The names of all meta variables from which to abstract. */ - void universalAbstract(std::set const& metaVariableNames); + Dd universalAbstract(std::set const& metaVariableNames) const; /*! * Sum-abstracts from the given meta variables. * * @param metaVariableNames The names of all meta variables from which to abstract. */ - void sumAbstract(std::set const& metaVariableNames); + Dd sumAbstract(std::set const& metaVariableNames) const; /*! * Min-abstracts from the given meta variables. * * @param metaVariableNames The names of all meta variables from which to abstract. */ - void minAbstract(std::set const& metaVariableNames); + Dd minAbstract(std::set const& metaVariableNames) const; /*! * Max-abstracts from the given meta variables. * * @param metaVariableNames The names of all meta variables from which to abstract. */ - void maxAbstract(std::set const& metaVariableNames); + Dd maxAbstract(std::set const& metaVariableNames) const; /*! * Checks whether the current and the given DD represent the same function modulo some given precision. @@ -291,6 +295,59 @@ namespace storm { */ Dd multiplyMatrix(Dd const& otherMatrix, std::set const& summationMetaVariableNames) const; + /*! + * Computes a DD that represents the function in which all assignments with a function value strictly larger + * than the given value are mapped to one and all others to zero. + * + * @param value The value used for the comparison. + * @return The resulting DD. + */ + Dd greater(double value) const; + + /*! + * Computes a DD that represents the function in which all assignments with a function value larger or equal + * to the given value are mapped to one and all others to zero. + * + * @param value The value used for the comparison. + * @return The resulting DD. + */ + Dd greaterOrEqual(double value) const; + + /*! + * Computes a DD that represents the function in which all assignments with a function value unequal to zero + * are mapped to one and all others to zero. + * + * @return The resulting DD. + */ + Dd notZero() const; + + /*! + * Computes the constraint of the current DD with the given constraint. That is, the function value of the + * resulting DD will be the same as the current ones for all assignments mapping to one in the constraint + * and may be different otherwise. + * + * @param constraint The constraint to use for the operation. + * @return The resulting DD. + */ + Dd constrain(Dd const& constraint) const; + + /*! + * Computes the restriction of the current DD with the given constraint. That is, the function value of the + * resulting DD will be the same as the current ones for all assignments mapping to one in the constraint + * and may be different otherwise. + * + * @param constraint The constraint to use for the operation. + * @return The resulting DD. + */ + Dd restrict(Dd const& constraint) const; + + /*! + * Retrieves the support of the current DD. + * + * @return The support represented as a DD. + */ + Dd getSupport() const; + /*! * Retrieves the number of encodings that are mapped to a non-zero value. * @@ -393,6 +450,76 @@ namespace storm { */ bool isConstant() const; + /*! + * Retrieves the index of the topmost variable in the DD. + * + * @return The index of the topmost variable in DD. + */ + uint_fast64_t getIndex() const; + + /*! + * Converts the DD to a vector. + * + * @return The double vector that is represented by this DD. + */ + template + std::vector toVector() const; + + /*! + * Converts the DD to a vector. The given offset-labeled DD is used to determine the correct row of + * each entry. + * + * @param rowOdd The ODD used for determining the correct row. + * @return The double vector that is represented by this DD. + */ + template + std::vector toVector(storm::dd::Odd const& rowOdd) const; + + /*! + * Converts the DD to a (sparse) double matrix. All contained non-primed variables are assumed to encode the + * row, whereas all primed variables are assumed to encode the column. + * + * @return The matrix that is represented by this DD. + */ + storm::storage::SparseMatrix toMatrix() const; + + /*! + * Converts the DD to a (sparse) double matrix. All contained non-primed variables are assumed to encode the + * row, whereas all primed variables are assumed to encode the column. The given offset-labeled DDs are used + * to determine the correct row and column, respectively, for each entry. + * + * @param rowOdd The ODD used for determining the correct row. + * @param columnOdd The ODD used for determining the correct column. + * @return The matrix that is represented by this DD. + */ + storm::storage::SparseMatrix toMatrix(storm::dd::Odd const& rowOdd, storm::dd::Odd const& columnOdd) const; + + /*! + * Converts the DD to a (sparse) double matrix. The given offset-labeled DDs are used to determine the + * correct row and column, respectively, for each entry. + * + * @param rowMetaVariables The meta variables that encode the rows of the matrix. + * @param columnMetaVariables The meta variables that encode the columns of the matrix. + * @param rowOdd The ODD used for determining the correct row. + * @param columnOdd The ODD used for determining the correct column. + * @return The matrix that is represented by this DD. + */ + storm::storage::SparseMatrix toMatrix(std::set const& rowMetaVariables, std::set const& columnMetaVariables, storm::dd::Odd const& rowOdd, storm::dd::Odd const& columnOdd) const; + + /*! + * Converts the DD to a row-grouped (sparse) double matrix. The given offset-labeled DDs are used to + * determine the correct row and column, respectively, for each entry. Note: this function assumes that + * the meta variables used to distinguish different row groups are at the very top of the DD. + * + * @param rowMetaVariables The meta variables that encode the rows of the matrix. + * @param columnMetaVariables The meta variables that encode the columns of the matrix. + * @param groupMetaVariables The meta variables that are used to distinguish different row groups. + * @param rowOdd The ODD used for determining the correct row. + * @param columnOdd The ODD used for determining the correct column. + * @return The matrix that is represented by this DD. + */ + storm::storage::SparseMatrix toMatrix(std::set const& rowMetaVariables, std::set const& columnMetaVariables, std::set const& groupMetaVariables, storm::dd::Odd const& rowOdd, storm::dd::Odd const& columnOdd) const; + /*! * Retrieves whether the given meta variable is contained in the DD. * @@ -455,6 +582,26 @@ namespace storm { */ DdForwardIterator end(bool enumerateDontCareMetaVariables = true) const; + /*! + * Converts the DD into a (heavily nested) if-then-else expression that represents the very same function. + * The variable names used in the expression are derived from the meta variable name and are extended with a + * suffix ".i" if the meta variable is integer-valued, expressing that the variable is the i-th bit of the + * meta variable. + * + * @return The resulting expression. + */ + storm::expressions::Expression toExpression() const; + + /*! + * Converts the DD into a (heavily nested) if-then-else (with negations) expression that evaluates to true + * if and only if the assignment is minterm of the DD. The variable names used in the expression are derived + * from the meta variable name and are extended with a suffix ".i" if the meta variable is integer-valued, + * expressing that the variable is the i-th bit of the meta variable. + * + * @return The resulting expression. + */ + storm::expressions::Expression getMintermExpression() const; + friend std::ostream & operator<<(std::ostream& out, const Dd& dd); private: /*! @@ -485,15 +632,95 @@ namespace storm { */ void removeContainedMetaVariable(std::string const& metaVariableName); + /*! + * Performs the recursive step of toExpression on the given DD. + * + * @param dd The dd to translate into an expression. + * @param variableNames The names of the variables to use in the expression. + * @return The resulting expression. + */ + static storm::expressions::Expression toExpressionRecur(DdNode const* dd, std::vector const& variableNames); + + /*! + * Performs the recursive step of getMintermExpression on the given DD. + * + * @param manager The manager of the DD. + * @param dd The dd whose minterms to translate into an expression. + * @param variableNames The names of the variables to use in the expression. + * @return The resulting expression. + */ + static storm::expressions::Expression getMintermExpressionRecur(::DdManager* manager, DdNode const* dd, std::vector const& variableNames); + /*! * Creates a DD that encapsulates the given CUDD ADD. * * @param ddManager The manager responsible for this DD. * @param cuddAdd The CUDD ADD to store. - * @param + * @param containedMetaVariableNames The names of the meta variables that appear in the DD. */ Dd(std::shared_ptr> ddManager, ADD cuddAdd, std::set const& containedMetaVariableNames = std::set()); + /*! + * Helper function to convert the DD into a (sparse) matrix. + * + * @param dd The DD to convert. + * @param rowIndications A vector indicating at which position in the columnsAndValues vector the entries + * of row i start. Note: this vector is modified in the computation. More concretely, each entry i in the + * vector will be increased by the number of entries in the row. This can be used to count the number + * of entries in each row. If the values are not to be modified, a copy needs to be provided or the entries + * need to be restored afterwards. + * @param columnsAndValues The vector that will hold the columns and values of non-zero entries upon successful + * completion. + * @param rowGroupOffsets The row offsets at which a given row group starts. + * @param rowOdd The ODD used for the row translation. + * @param columnOdd The ODD used for the column translation. + * @param currentRowLevel The currently considered row level in the DD. + * @param currentColumnLevel The currently considered row level in the DD. + * @param maxLevel The number of levels that need to be considered. + * @param currentRowOffset The current row offset. + * @param currentColumnOffset The current row offset. + * @param ddRowVariableIndices The (sorted) indices of all DD row variables that need to be considered. + * @param ddColumnVariableIndices The (sorted) indices of all DD row variables that need to be considered. + * @param generateValues If set to true, the vector columnsAndValues is filled with the actual entries, which + * only works if the offsets given in rowIndications are already correct. If they need to be computed first, + * this flag needs to be false. + */ + void toMatrixRec(DdNode const* dd, std::vector& rowIndications, std::vector>& columnsAndValues, std::vector const& rowGroupOffsets, Odd const& rowOdd, Odd const& columnOdd, uint_fast64_t currentRowLevel, uint_fast64_t currentColumnLevel, uint_fast64_t maxLevel, uint_fast64_t currentRowOffset, uint_fast64_t currentColumnOffset, std::vector const& ddRowVariableIndices, std::vector const& ddColumnVariableIndices, bool generateValues = true) const; + + /*! + * Splits the given matrix DD into the groups using the given group variables. + * + * @param dd The DD to split. + * @param groups A vector that is to be filled with the DDs for the individual groups. + * @param ddGroupVariableIndices The (sorted) indices of all DD group variables that need to be considered. + * @param currentLevel The currently considered level in the DD. + * @param maxLevel The number of levels that need to be considered. + * @param remainingMetaVariables The meta variables that remain in the DDs after the groups have been split. + */ + void splitGroupsRec(DdNode* dd, std::vector>& groups, std::vector const& ddGroupVariableIndices, uint_fast64_t currentLevel, uint_fast64_t maxLevel, std::set const& remainingMetaVariables) const; + + /*! + * Performs a recursive step to add the given DD-based vector to the given explicit vector. + * + * @param dd The DD to add to the explicit vector. + * @param currentLevel The currently considered level in the DD. + * @param maxLevel The number of levels that need to be considered. + * @param currentOffset The current offset. + * @param odd The ODD used for the translation. + * @param ddVariableIndices The (sorted) indices of all DD variables that need to be considered. + * @param targetVector The vector to which the translated DD-based vector is to be added. + */ + template + void addToVectorRec(DdNode const* dd, uint_fast64_t currentLevel, uint_fast64_t maxLevel, uint_fast64_t currentOffset, Odd const& odd, std::vector const& ddVariableIndices, std::vector& targetVector) const; + + /*! + * Retrieves the indices of all DD variables that are contained in this DD (not necessarily in the support, + * because they could be "don't cares"). Additionally, the indices are sorted to allow for easy access. + * + * @return The (sorted) indices of all DD variables that are contained in this DD. + */ + std::vector getSortedVariableIndices() const; + // A pointer to the manager responsible for this DD. std::shared_ptr> ddManager; diff --git a/src/storage/dd/CuddDdForwardIterator.cpp b/src/storage/dd/CuddDdForwardIterator.cpp index cce5c112c..4a990746e 100644 --- a/src/storage/dd/CuddDdForwardIterator.cpp +++ b/src/storage/dd/CuddDdForwardIterator.cpp @@ -70,7 +70,7 @@ namespace storm { if (this->relevantDontCareDdVariables.empty() || this->cubeCounter >= std::pow(2, this->relevantDontCareDdVariables.size()) - 1) { // Get the next cube and check for emptiness. ABDD::NextCube(generator, &cube, &value); - this->isAtEnd = Cudd_IsGenEmpty(generator); + this->isAtEnd = (Cudd_IsGenEmpty(generator) != 0); // In case we are not done yet, we get ready to treat the next cube. if (!this->isAtEnd) { diff --git a/src/storage/dd/CuddDdManager.cpp b/src/storage/dd/CuddDdManager.cpp index fd26a782c..3f0a2abed 100644 --- a/src/storage/dd/CuddDdManager.cpp +++ b/src/storage/dd/CuddDdManager.cpp @@ -33,7 +33,7 @@ bool CuddOptionsRegistered = storm::settings::Settings::registerNewModule([] (st namespace storm { namespace dd { DdManager::DdManager() : metaVariableMap(), cuddManager() { - this->cuddManager.SetMaxMemory(storm::settings::Settings::getInstance()->getOptionByLongName("cuddmaxmem").getArgument(0).getValueAsUnsignedInteger() * 1024 * 1024); + this->cuddManager.SetMaxMemory(static_cast(storm::settings::Settings::getInstance()->getOptionByLongName("cuddmaxmem").getArgument(0).getValueAsUnsignedInteger() * 1024ul * 1024ul)); this->cuddManager.SetEpsilon(storm::settings::Settings::getInstance()->getOptionByLongName("cuddprec").getArgument(0).getValueAsDouble()); } @@ -183,8 +183,14 @@ namespace storm { std::vector> variableNamePairs; for (auto const& nameMetaVariablePair : this->metaVariableMap) { DdMetaVariable const& metaVariable = nameMetaVariablePair.second; - for (uint_fast64_t variableIndex = 0; variableIndex < metaVariable.getNumberOfDdVariables(); ++variableIndex) { - variableNamePairs.emplace_back(metaVariable.getDdVariables()[variableIndex].getCuddAdd(), metaVariable.getName() + "." + std::to_string(variableIndex)); + // If the meta variable is of type bool, we don't need to suffix it with the bit number. + if (metaVariable.getType() == DdMetaVariable::MetaVariableType::Bool) { + variableNamePairs.emplace_back(metaVariable.getDdVariables().front().getCuddAdd(), metaVariable.getName()); + } else { + // For integer-valued meta variables, we, however, have to add the suffix. + for (uint_fast64_t variableIndex = 0; variableIndex < metaVariable.getNumberOfDdVariables(); ++variableIndex) { + variableNamePairs.emplace_back(metaVariable.getDdVariables()[variableIndex].getCuddAdd(), metaVariable.getName() + "." + std::to_string(variableIndex)); + } } } diff --git a/src/storage/dd/CuddDdManager.h b/src/storage/dd/CuddDdManager.h index 14ea4d8a3..75c6a23de 100644 --- a/src/storage/dd/CuddDdManager.h +++ b/src/storage/dd/CuddDdManager.h @@ -16,6 +16,7 @@ namespace storm { class DdManager : public std::enable_shared_from_this> { public: friend class Dd; + friend class Odd; friend class DdForwardIterator; /*! @@ -138,7 +139,6 @@ namespace storm { */ void triggerReordering(); - protected: /*! * Retrieves the meta variable with the given name if it exists. * diff --git a/src/storage/dd/CuddDdMetaVariable.cpp b/src/storage/dd/CuddDdMetaVariable.cpp index b7f45dbe3..13667bd99 100644 --- a/src/storage/dd/CuddDdMetaVariable.cpp +++ b/src/storage/dd/CuddDdMetaVariable.cpp @@ -10,7 +10,7 @@ namespace storm { } } - DdMetaVariable::DdMetaVariable(std::string const& name, std::vector> const& ddVariables, std::shared_ptr> manager) : name(name), type(MetaVariableType::Bool), ddVariables(ddVariables), cube(manager->getOne()), manager(manager) { + DdMetaVariable::DdMetaVariable(std::string const& name, std::vector> const& ddVariables, std::shared_ptr> manager) : name(name), type(MetaVariableType::Bool), low(0), high(1), ddVariables(ddVariables), cube(manager->getOne()), manager(manager) { // Create the cube of all variables of this meta variable. for (auto const& ddVariable : this->ddVariables) { this->cube *= ddVariable; @@ -21,7 +21,7 @@ namespace storm { return this->name; } - typename DdMetaVariable::MetaVariableType DdMetaVariable::getType() const { + DdMetaVariable::MetaVariableType DdMetaVariable::getType() const { return this->type; } diff --git a/src/storage/dd/CuddDdMetaVariable.h b/src/storage/dd/CuddDdMetaVariable.h index b54ca1939..25375dc4d 100644 --- a/src/storage/dd/CuddDdMetaVariable.h +++ b/src/storage/dd/CuddDdMetaVariable.h @@ -13,8 +13,9 @@ namespace storm { namespace dd { - // Forward-declare the DdManager class. + // Forward-declare some classes. template class DdManager; + template class Odd; template<> class DdMetaVariable { @@ -22,6 +23,7 @@ namespace storm { // Declare the DdManager class as friend so it can access the internals of a meta variable. friend class DdManager; friend class Dd; + friend class Odd; friend class DdForwardIterator; // An enumeration for all legal types of meta variables. diff --git a/src/storage/dd/CuddOdd.cpp b/src/storage/dd/CuddOdd.cpp new file mode 100644 index 000000000..dc84d9fd7 --- /dev/null +++ b/src/storage/dd/CuddOdd.cpp @@ -0,0 +1,111 @@ +#include "src/storage/dd/CuddOdd.h" + +#include + +#include "src/storage/dd/CuddDdManager.h" +#include "src/storage/dd/CuddDdMetaVariable.h" + +namespace storm { + namespace dd { + Odd::Odd(Dd const& dd) { + std::shared_ptr> manager = dd.getDdManager(); + + // First, we need to determine the involved DD variables indices. + std::vector ddVariableIndices = dd.getSortedVariableIndices(); + + // Prepare a unique table for each level that keeps the constructed ODD nodes unique. + std::vector>>> uniqueTableForLevels(ddVariableIndices.size() + 1); + + // Now construct the ODD structure. + std::shared_ptr> rootOdd = buildOddRec(dd.getCuddAdd().getNode(), manager->getCuddManager(), 0, ddVariableIndices.size(), ddVariableIndices, uniqueTableForLevels); + + // Finally, move the children of the root ODD into this ODD. + this->dd = rootOdd->dd; + this->elseNode = std::move(rootOdd->elseNode); + this->thenNode = std::move(rootOdd->thenNode); + this->elseOffset = rootOdd->elseOffset; + this->thenOffset = rootOdd->thenOffset; + } + + Odd::Odd(ADD dd, std::shared_ptr>&& elseNode, uint_fast64_t elseOffset, std::shared_ptr>&& thenNode, uint_fast64_t thenOffset) : dd(dd), elseNode(elseNode), thenNode(thenNode), elseOffset(elseOffset), thenOffset(thenOffset) { + // Intentionally left empty. + } + + Odd const& Odd::getThenSuccessor() const { + return *this->thenNode; + } + + Odd const& Odd::getElseSuccessor() const { + return *this->elseNode; + } + + uint_fast64_t Odd::getElseOffset() const { + return this->elseOffset; + } + + void Odd::setElseOffset(uint_fast64_t newOffset) { + this->elseOffset = newOffset; + } + + uint_fast64_t Odd::getThenOffset() const { + return this->thenOffset; + } + + void Odd::setThenOffset(uint_fast64_t newOffset) { + this->thenOffset = newOffset; + } + + uint_fast64_t Odd::getTotalOffset() const { + return this->elseOffset + this->thenOffset; + } + + uint_fast64_t Odd::getNodeCount() const { + // If the ODD contains a constant (and thus has no children), the size is 1. + if (this->elseNode == nullptr && this->thenNode == nullptr) { + return 1; + } + + // If the two successors are actually the same, we need to count the subnodes only once. + if (this->elseNode == this->thenNode) { + return this->elseNode->getNodeCount(); + } else { + return this->elseNode->getNodeCount() + this->thenNode->getNodeCount(); + } + } + + std::shared_ptr> Odd::buildOddRec(DdNode* dd, Cudd const& manager, uint_fast64_t currentLevel, uint_fast64_t maxLevel, std::vector const& ddVariableIndices, std::vector>>>& uniqueTableForLevels) { + // Check whether the ODD for this node has already been computed (for this level) and if so, return this instead. + auto const& iterator = uniqueTableForLevels[currentLevel].find(dd); + if (iterator != uniqueTableForLevels[currentLevel].end()) { + return iterator->second; + } else { + // Otherwise, we need to recursively compute the ODD. + + // If we are already past the maximal level that is to be considered, we can simply create a Odd without + // successors + if (currentLevel == maxLevel) { + uint_fast64_t elseOffset = 0; + uint_fast64_t thenOffset = 0; + + // If the DD is not the zero leaf, then the then-offset is 1. + if (dd != Cudd_ReadZero(manager.getManager())) { + thenOffset = 1; + } + + return std::shared_ptr>(new Odd(ADD(manager, dd), nullptr, elseOffset, nullptr, thenOffset)); + } else if (ddVariableIndices[currentLevel] < static_cast(dd->index)) { + // If we skipped the level in the DD, we compute the ODD just for the else-successor and use the same + // node for the then-successor as well. + std::shared_ptr> elseNode = buildOddRec(dd, manager, currentLevel + 1, maxLevel, ddVariableIndices, uniqueTableForLevels); + std::shared_ptr> thenNode = elseNode; + return std::shared_ptr>(new Odd(ADD(manager, dd), std::move(elseNode), elseNode->getElseOffset() + elseNode->getThenOffset(), std::move(thenNode), thenNode->getElseOffset() + thenNode->getThenOffset())); + } else { + // Otherwise, we compute the ODDs for both the then- and else successors. + std::shared_ptr> elseNode = buildOddRec(Cudd_E(dd), manager, currentLevel + 1, maxLevel, ddVariableIndices, uniqueTableForLevels); + std::shared_ptr> thenNode = buildOddRec(Cudd_T(dd), manager, currentLevel + 1, maxLevel, ddVariableIndices, uniqueTableForLevels); + return std::shared_ptr>(new Odd(ADD(manager, dd), std::move(elseNode), elseNode->getElseOffset() + elseNode->getThenOffset(), std::move(thenNode), thenNode->getElseOffset() + thenNode->getThenOffset())); + } + } + } + } +} \ No newline at end of file diff --git a/src/storage/dd/CuddOdd.h b/src/storage/dd/CuddOdd.h new file mode 100644 index 000000000..1d36613f0 --- /dev/null +++ b/src/storage/dd/CuddOdd.h @@ -0,0 +1,131 @@ +#ifndef STORM_STORAGE_DD_CUDDODD_H_ +#define STORM_STORAGE_DD_CUDDODD_H_ + +#include + +#include "src/storage/dd/Odd.h" +#include "src/storage/dd/CuddDd.h" +#include "src/utility/OsDetection.h" + +// Include the C++-interface of CUDD. +#include "cuddObj.hh" + +namespace storm { + namespace dd { + template<> + class Odd { + public: + /*! + * Constructs an offset-labeled DD from the given DD. + * + * @param dd The DD for which to build the offset-labeled DD. + */ + Odd(Dd const& dd); + + // Instantiate all copy/move constructors/assignments with the default implementation. + Odd() = default; + Odd(Odd const& other) = default; + Odd& operator=(Odd const& other) = default; +#ifndef WINDOWS + Odd(Odd&& other) = default; + Odd& operator=(Odd&& other) = default; +#endif + + /*! + * Retrieves the then-successor of this ODD node. + * + * @return The then-successor of this ODD node. + */ + Odd const& getThenSuccessor() const; + + /*! + * Retrieves the else-successor of this ODD node. + * + * @return The else-successor of this ODD node. + */ + Odd const& getElseSuccessor() const; + + /*! + * Retrieves the else-offset of this ODD node. + * + * @return The else-offset of this ODD node. + */ + uint_fast64_t getElseOffset() const; + + /*! + * Sets the else-offset of this ODD node. + * + * @param newOffset The new else-offset of this ODD node. + */ + void setElseOffset(uint_fast64_t newOffset); + + /*! + * Retrieves the then-offset of this ODD node. + * + * @return The then-offset of this ODD node. + */ + uint_fast64_t getThenOffset() const; + + /*! + * Sets the then-offset of this ODD node. + * + * @param newOffset The new then-offset of this ODD node. + */ + void setThenOffset(uint_fast64_t newOffset); + + /*! + * Retrieves the total offset, i.e., the sum of the then- and else-offset. + * + * @return The total offset of this ODD. + */ + uint_fast64_t getTotalOffset() const; + + /*! + * Retrieves the size of the ODD. Note: the size is computed by a traversal, so this may be costlier than + * expected. + * + * @return The size (in nodes) of this ODD. + */ + uint_fast64_t getNodeCount() const; + + private: + /*! + * Constructs an offset-labeled DD with the given topmost DD node, else- and then-successor. + * + * @param dd The DD associated with this ODD node. + * @param elseNode The else-successor of thie ODD node. + * @param elseOffset The offset of the else-successor. + * @param thenNode The then-successor of thie ODD node. + * @param thenOffset The offset of the then-successor. + */ + Odd(ADD dd, std::shared_ptr>&& elseNode, uint_fast64_t elseOffset, std::shared_ptr>&& thenNode, uint_fast64_t thenOffset); + + /*! + * Recursively builds the ODD. + * + * @param dd The DD for which to build the ODD. + * @param manager The manager responsible for the DD. + * @param currentLevel The currently considered level in the DD. + * @param maxLevel The number of levels that need to be considered. + * @param ddVariableIndices The (sorted) indices of all DD variables that need to be considered. + * @param uniqueTableForLevels A vector of unique tables, one for each level to be considered, that keeps + * ODD nodes for the same DD and level unique. + * @return A pointer to the constructed ODD for the given arguments. + */ + static std::shared_ptr> buildOddRec(DdNode* dd, Cudd const& manager, uint_fast64_t currentLevel, uint_fast64_t maxLevel, std::vector const& ddVariableIndices, std::vector>>>& uniqueTableForLevels); + + // The DD associated with this ODD node. + ADD dd; + + // The then- and else-nodes. + std::shared_ptr> elseNode; + std::shared_ptr> thenNode; + + // The offsets that need to be added if the then- or else-successor is taken, respectively. + uint_fast64_t elseOffset; + uint_fast64_t thenOffset; + }; + } +} + +#endif /* STORM_STORAGE_DD_CUDDODD_H_ */ \ No newline at end of file diff --git a/src/storage/dd/Odd.h b/src/storage/dd/Odd.h new file mode 100644 index 000000000..6a7231300 --- /dev/null +++ b/src/storage/dd/Odd.h @@ -0,0 +1,13 @@ +#ifndef STORM_STORAGE_DD_ODD_H_ +#define STORM_STORAGE_DD_ODD_H_ + +#include "src/storage/dd/DdType.h" + +namespace storm { + namespace dd { + // Declare Odd class so we can then specialize it for the different DD types. + template class Odd; + } +} + +#endif /* STORM_STORAGE_DD_ODD_H_ */ \ No newline at end of file diff --git a/src/storage/expressions/BinaryNumericalFunctionExpression.cpp b/src/storage/expressions/BinaryNumericalFunctionExpression.cpp index f6172821d..7db4fa0cf 100644 --- a/src/storage/expressions/BinaryNumericalFunctionExpression.cpp +++ b/src/storage/expressions/BinaryNumericalFunctionExpression.cpp @@ -1,4 +1,5 @@ #include +#include #include "src/storage/expressions/BinaryNumericalFunctionExpression.h" #include "src/exceptions/ExceptionMacros.h" @@ -22,6 +23,7 @@ namespace storm { case OperatorType::Divide: return storm::expressions::OperatorType::Divide; break; case OperatorType::Min: return storm::expressions::OperatorType::Min; break; case OperatorType::Max: return storm::expressions::OperatorType::Max; break; + case OperatorType::Power: return storm::expressions::OperatorType::Power; break; } } @@ -37,6 +39,7 @@ namespace storm { case OperatorType::Divide: return firstOperandEvaluation / secondOperandEvaluation; break; case OperatorType::Min: return std::min(firstOperandEvaluation, secondOperandEvaluation); break; case OperatorType::Max: return std::max(firstOperandEvaluation, secondOperandEvaluation); break; + case OperatorType::Power: return static_cast(std::pow(firstOperandEvaluation, secondOperandEvaluation)); break; } } @@ -52,6 +55,7 @@ namespace storm { case OperatorType::Divide: return static_cast(firstOperandEvaluation / secondOperandEvaluation); break; case OperatorType::Min: return static_cast(std::min(firstOperandEvaluation, secondOperandEvaluation)); break; case OperatorType::Max: return static_cast(std::max(firstOperandEvaluation, secondOperandEvaluation)); break; + case OperatorType::Power: return std::pow(firstOperandEvaluation, secondOperandEvaluation); break; } } @@ -79,6 +83,7 @@ namespace storm { case OperatorType::Divide: stream << *this->getFirstOperand() << " / " << *this->getSecondOperand(); break; case OperatorType::Min: stream << "min(" << *this->getFirstOperand() << ", " << *this->getSecondOperand() << ")"; break; case OperatorType::Max: stream << "max(" << *this->getFirstOperand() << ", " << *this->getSecondOperand() << ")"; break; + case OperatorType::Power: stream << *this->getFirstOperand() << " ^ " << *this->getSecondOperand(); break; } stream << ")"; } diff --git a/src/storage/expressions/BinaryNumericalFunctionExpression.h b/src/storage/expressions/BinaryNumericalFunctionExpression.h index 13ee489df..77b8021a4 100644 --- a/src/storage/expressions/BinaryNumericalFunctionExpression.h +++ b/src/storage/expressions/BinaryNumericalFunctionExpression.h @@ -11,7 +11,7 @@ namespace storm { /*! * An enum type specifying the different operators applicable. */ - enum class OperatorType {Plus, Minus, Times, Divide, Min, Max}; + enum class OperatorType {Plus, Minus, Times, Divide, Min, Max, Power}; /*! * Constructs a binary numerical function expression with the given return type, operands and operator. diff --git a/src/storage/expressions/Expression.cpp b/src/storage/expressions/Expression.cpp index 347d9a778..c61e93b9d 100644 --- a/src/storage/expressions/Expression.cpp +++ b/src/storage/expressions/Expression.cpp @@ -207,8 +207,8 @@ namespace storm { } Expression Expression::operator^(Expression const& other) const { - LOG_THROW(this->hasBooleanReturnType() && other.hasBooleanReturnType(), storm::exceptions::InvalidTypeException, "Operator '^' requires boolean operands."); - return Expression(std::shared_ptr(new BinaryBooleanFunctionExpression(ExpressionReturnType::Bool, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryBooleanFunctionExpression::OperatorType::Xor))); + LOG_THROW(this->hasNumericalReturnType() && other.hasNumericalReturnType(), storm::exceptions::InvalidTypeException, "Operator '^' requires numerical operands."); + return Expression(std::shared_ptr(new BinaryNumericalFunctionExpression(this->getReturnType() == ExpressionReturnType::Int && other.getReturnType() == ExpressionReturnType::Int ? ExpressionReturnType::Int : ExpressionReturnType::Double, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryNumericalFunctionExpression::OperatorType::Power))); } Expression Expression::operator&&(Expression const& other) const { @@ -232,8 +232,12 @@ namespace storm { } Expression Expression::operator!=(Expression const& other) const { - LOG_THROW(this->hasNumericalReturnType() && other.hasNumericalReturnType(), storm::exceptions::InvalidTypeException, "Operator '!=' requires numerical operands."); - return Expression(std::shared_ptr(new BinaryRelationExpression(ExpressionReturnType::Bool, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryRelationExpression::RelationType::NotEqual))); + LOG_THROW((this->hasNumericalReturnType() && other.hasNumericalReturnType()) || (this->hasBooleanReturnType() && other.hasBooleanReturnType()), storm::exceptions::InvalidTypeException, "Operator '!=' requires operands of equal type."); + if (this->hasNumericalReturnType() && other.hasNumericalReturnType()) { + return Expression(std::shared_ptr(new BinaryRelationExpression(ExpressionReturnType::Bool, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryRelationExpression::RelationType::NotEqual))); + } else { + return Expression(std::shared_ptr(new BinaryBooleanFunctionExpression(ExpressionReturnType::Bool, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryBooleanFunctionExpression::OperatorType::Xor))); + } } Expression Expression::operator>(Expression const& other) const { diff --git a/src/storage/expressions/LinearityCheckVisitor.cpp b/src/storage/expressions/LinearityCheckVisitor.cpp index 9b382ed22..6f595900a 100644 --- a/src/storage/expressions/LinearityCheckVisitor.cpp +++ b/src/storage/expressions/LinearityCheckVisitor.cpp @@ -73,6 +73,7 @@ namespace storm { break; case BinaryNumericalFunctionExpression::OperatorType::Min: resultStack.push(LinearityStatus::NonLinear); break; case BinaryNumericalFunctionExpression::OperatorType::Max: resultStack.push(LinearityStatus::NonLinear); break; + case BinaryNumericalFunctionExpression::OperatorType::Power: resultStack.push(LinearityStatus::NonLinear); break; } } diff --git a/src/storage/expressions/OperatorType.h b/src/storage/expressions/OperatorType.h index 8a4f199aa..8968cf105 100644 --- a/src/storage/expressions/OperatorType.h +++ b/src/storage/expressions/OperatorType.h @@ -16,6 +16,7 @@ namespace storm { Divide, Min, Max, + Power, Equal, NotEqual, Less, diff --git a/src/storm.cpp b/src/storm.cpp index c7b40247a..8d3810f72 100644 --- a/src/storm.cpp +++ b/src/storm.cpp @@ -31,6 +31,7 @@ #include "src/modelchecker/prctl/SparseDtmcPrctlModelChecker.h" #include "src/modelchecker/prctl/SparseMdpPrctlModelChecker.h" #include "src/solver/GmmxxLinearEquationSolver.h" +#include "src/solver/NativeLinearEquationSolver.h" #include "src/solver/GmmxxNondeterministicLinearEquationSolver.h" #include "src/solver/GurobiLpSolver.h" #include "src/counterexamples/MILPMinimalLabelSetGenerator.h" @@ -117,8 +118,8 @@ void printUsage() { double userTime = uLargeInteger.QuadPart / 10000.0; std::cout << "CPU Time: " << std::endl; - std::cout << "\tKernel Time: " << std::setprecision(3) << kernelTime << std::endl; - std::cout << "\tUser Time: " << std::setprecision(3) << userTime << std::endl; + std::cout << "\tKernel Time: " << std::setprecision(5) << kernelTime << "ms" << std::endl; + std::cout << "\tUser Time: " << std::setprecision(5) << userTime << "ms" << std::endl; #endif } @@ -252,10 +253,12 @@ void cleanUp() { storm::modelchecker::prctl::AbstractModelChecker* createPrctlModelChecker(storm::models::Dtmc const & dtmc) { // Create the appropriate model checker. storm::settings::Settings* s = storm::settings::Settings::getInstance(); - std::string const chosenMatrixLibrary = s->getOptionByLongName("matrixLibrary").getArgument(0).getValueAsString(); - if (chosenMatrixLibrary == "gmm++") { + std::string const& linsolver = s->getOptionByLongName("linsolver").getArgument(0).getValueAsString(); + if (linsolver == "gmm++") { return new storm::modelchecker::prctl::SparseDtmcPrctlModelChecker(dtmc, new storm::solver::GmmxxLinearEquationSolver()); - } + } else if (linsolver == "native") { + return new storm::modelchecker::prctl::SparseDtmcPrctlModelChecker(dtmc, new storm::solver::NativeLinearEquationSolver()); + } // The control flow should never reach this point, as there is a default setting for matrixlib. std::string message = "No matrix library suitable for DTMC model checking has been set."; @@ -451,6 +454,9 @@ int main(const int argc, const char* argv[]) { stormSetAlarm(timeout); } + // Execution Time measurement, start + std::chrono::high_resolution_clock::time_point executionStart = std::chrono::high_resolution_clock::now(); + // Now, the settings are received and the specified model is parsed. The actual actions taken depend on whether // the model was provided in explicit or symbolic format. if (s->isSet("explicit")) { @@ -468,6 +474,10 @@ int main(const int argc, const char* argv[]) { std::shared_ptr> model = storm::parser::AutoParser::parseModel(chosenTransitionSystemFile, chosenLabelingFile, chosenStateRewardsFile, chosenTransitionRewardsFile); + // Model Parsing Time Measurement, End + std::chrono::high_resolution_clock::time_point parsingEnd = std::chrono::high_resolution_clock::now(); + std::cout << "Parsing the given model took " << std::chrono::duration_cast(parsingEnd - executionStart).count() << " milliseconds." << std::endl; + if (s->isSet("exportdot")) { std::ofstream outputFileStream; outputFileStream.open(s->getOptionByLongName("exportdot").getArgument(0).getValueAsString(), std::ofstream::out); @@ -475,19 +485,24 @@ int main(const int argc, const char* argv[]) { outputFileStream.close(); } - //Should there be a counterexample generated in case the formula is not satisfied? + // Should there be a counterexample generated in case the formula is not satisfied? if(s->isSet("counterexample")) { + // Counterexample Time Measurement, Start + std::chrono::high_resolution_clock::time_point counterexampleStart = std::chrono::high_resolution_clock::now(); generateCounterExample(model); + // Counterexample Time Measurement, End + std::chrono::high_resolution_clock::time_point counterexampleEnd = std::chrono::high_resolution_clock::now(); + std::cout << "Generating the counterexample took " << std::chrono::duration_cast(counterexampleEnd - counterexampleStart).count() << " milliseconds." << std::endl; } else { - // Determine which engine is to be used to choose the right model checker. - LOG4CPLUS_DEBUG(logger, s->getOptionByLongName("matrixLibrary").getArgument(0).getValueAsString()); - // Depending on the model type, the appropriate model checking procedure is chosen. storm::modelchecker::prctl::AbstractModelChecker* modelchecker = nullptr; model->printModelInformationToStream(std::cout); + // Modelchecking Time Measurement, Start + std::chrono::high_resolution_clock::time_point modelcheckingStart = std::chrono::high_resolution_clock::now(); + switch (model->getType()) { case storm::models::DTMC: LOG4CPLUS_INFO(logger, "Model is a DTMC."); @@ -529,8 +544,15 @@ int main(const int argc, const char* argv[]) { if (modelchecker != nullptr) { delete modelchecker; } + + // Modelchecking Time Measurement, End + std::chrono::high_resolution_clock::time_point modelcheckingEnd = std::chrono::high_resolution_clock::now(); + std::cout << "Running the ModelChecker took " << std::chrono::duration_cast(modelcheckingEnd - modelcheckingStart).count() << " milliseconds." << std::endl; } } else if (s->isSet("symbolic")) { + // Program Translation Time Measurement, Start + std::chrono::high_resolution_clock::time_point programTranslationStart = std::chrono::high_resolution_clock::now(); + // First, we build the model using the given symbolic model description and constant definitions. std::string const& programFile = s->getOptionByLongName("symbolic").getArgument(0).getValueAsString(); std::string const& constants = s->getOptionByLongName("constants").getArgument(0).getValueAsString(); @@ -538,6 +560,10 @@ int main(const int argc, const char* argv[]) { std::shared_ptr> model = storm::adapters::ExplicitModelAdapter::translateProgram(program, constants); model->printModelInformationToStream(std::cout); + // Program Translation Time Measurement, End + std::chrono::high_resolution_clock::time_point programTranslationEnd = std::chrono::high_resolution_clock::now(); + std::cout << "Parsing and translating the Symbolic Input took " << std::chrono::duration_cast(programTranslationEnd - programTranslationStart).count() << " milliseconds." << std::endl; + if (s->isSet("mincmd")) { if (model->getType() != storm::models::MDP) { LOG4CPLUS_ERROR(logger, "Minimal command counterexample generation is only supported for models of type MDP."); @@ -549,6 +575,9 @@ int main(const int argc, const char* argv[]) { // Determine whether we are required to use the MILP-version or the SAT-version. bool useMILP = s->getOptionByLongName("mincmd").getArgumentByName("method").getValueAsString() == "milp"; + // MinCMD Time Measurement, Start + std::chrono::high_resolution_clock::time_point minCmdStart = std::chrono::high_resolution_clock::now(); + // Now parse the property file and receive the list of parsed formulas. std::string const& propertyFile = s->getOptionByLongName("mincmd").getArgumentByName("propertyFile").getValueAsString(); std::list*> formulaList = storm::parser::PrctlFileParser(propertyFile); @@ -564,13 +593,17 @@ int main(const int argc, const char* argv[]) { // Once we are done with the formula, delete it. delete formulaPtr; } + + // MinCMD Time Measurement, End + std::chrono::high_resolution_clock::time_point minCmdEnd = std::chrono::high_resolution_clock::now(); + std::cout << "Minimal command Counterexample generation took " << std::chrono::duration_cast(minCmdStart - minCmdEnd).count() << " milliseconds." << std::endl; } else if (s->isSet("prctl")) { - // Determine which engine is to be used to choose the right model checker. - LOG4CPLUS_DEBUG(logger, s->getOptionByLongName("matrixLibrary").getArgument(0).getValueAsString()); - // Depending on the model type, the appropriate model checking procedure is chosen. storm::modelchecker::prctl::AbstractModelChecker* modelchecker = nullptr; + // Modelchecking Time Measurement, Start + std::chrono::high_resolution_clock::time_point modelcheckingStart = std::chrono::high_resolution_clock::now(); + switch (model->getType()) { case storm::models::DTMC: LOG4CPLUS_INFO(logger, "Model is a DTMC."); @@ -602,8 +635,16 @@ int main(const int argc, const char* argv[]) { if (modelchecker != nullptr) { delete modelchecker; } + + // Modelchecking Time Measurement, End + std::chrono::high_resolution_clock::time_point modelcheckingEnd = std::chrono::high_resolution_clock::now(); + std::cout << "Running the PRCTL ModelChecker took " << std::chrono::duration_cast(modelcheckingEnd - modelcheckingStart).count() << " milliseconds." << std::endl; } } + + // Execution Time Measurement, End + std::chrono::high_resolution_clock::time_point executionEnd = std::chrono::high_resolution_clock::now(); + std::cout << "Complete execution took " << std::chrono::duration_cast(executionEnd - executionStart).count() << " milliseconds." << std::endl; // Perform clean-up and terminate. cleanUp(); diff --git a/src/utility/StormOptions.cpp b/src/utility/StormOptions.cpp index 29f6dccfd..48b7aa3e0 100644 --- a/src/utility/StormOptions.cpp +++ b/src/utility/StormOptions.cpp @@ -28,7 +28,7 @@ bool storm::utility::StormOptions::optionsRegistered = storm::settings::Settings settings->addOption(storm::settings::OptionBuilder("StoRM Main", "counterExample", "", "Generates a counterexample for the given PRCTL formulas if not satisfied by the model").addArgument(storm::settings::ArgumentBuilder::createStringArgument("outputPath", "The path to the directory to write the generated counterexample files to.").build()).build()); - settings->addOption(storm::settings::OptionBuilder("StoRM Main", "transitionRewards", "", "If specified, the transition rewards are read from this file and added to the explicit model. Note that this requires an explicit model.").addArgument(storm::settings::ArgumentBuilder::createStringArgument("transitionRewardsFileName", "The file from which to read the rransition rewards.").addValidationFunctionString(storm::settings::ArgumentValidators::existingReadableFileValidator()).build()).build()); + settings->addOption(storm::settings::OptionBuilder("StoRM Main", "transitionRewards", "", "If specified, the transition rewards are read from this file and added to the explicit model. Note that this requires an explicit model.").addArgument(storm::settings::ArgumentBuilder::createStringArgument("transitionRewardsFileName", "The file from which to read the transition rewards.").addValidationFunctionString(storm::settings::ArgumentValidators::existingReadableFileValidator()).build()).build()); settings->addOption(storm::settings::OptionBuilder("StoRM Main", "stateRewards", "", "If specified, the state rewards are read from this file and added to the explicit model. Note that this requires an explicit model.").addArgument(storm::settings::ArgumentBuilder::createStringArgument("stateRewardsFileName", "The file from which to read the state rewards.").addValidationFunctionString(storm::settings::ArgumentValidators::existingReadableFileValidator()).build()).build()); diff --git a/test/functional/parser/PrismParserTest.cpp b/test/functional/parser/PrismParserTest.cpp index 5bf7bfd49..3621149b8 100644 --- a/test/functional/parser/PrismParserTest.cpp +++ b/test/functional/parser/PrismParserTest.cpp @@ -4,6 +4,7 @@ TEST(PrismParser, StandardModelTest) { storm::prism::Program result; + result = storm::parser::PrismParser::parse(STORM_CPP_TESTS_BASE_PATH "/functional/parser/prism/coin2.nm"); EXPECT_NO_THROW(result = storm::parser::PrismParser::parse(STORM_CPP_TESTS_BASE_PATH "/functional/parser/prism/coin2.nm")); EXPECT_NO_THROW(result = storm::parser::PrismParser::parse(STORM_CPP_TESTS_BASE_PATH "/functional/parser/prism/crowds5_5.pm")); EXPECT_NO_THROW(result = storm::parser::PrismParser::parse(STORM_CPP_TESTS_BASE_PATH "/functional/parser/prism/csma2_2.nm")); diff --git a/test/functional/storage/CuddDdTest.cpp b/test/functional/storage/CuddDdTest.cpp index f39e4ad48..e95348625 100644 --- a/test/functional/storage/CuddDdTest.cpp +++ b/test/functional/storage/CuddDdTest.cpp @@ -3,6 +3,7 @@ #include "src/exceptions/InvalidArgumentException.h" #include "src/storage/dd/CuddDdManager.h" #include "src/storage/dd/CuddDd.h" +#include "src/storage/dd/CuddOdd.h" #include "src/storage/dd/DdMetaVariable.h" TEST(CuddDdManager, Constants) { @@ -162,12 +163,12 @@ TEST(CuddDd, OperatorTest) { dd4 = dd3.minimum(dd1); dd4 *= manager->getEncoding("x", 2); - dd4.sumAbstract({"x"}); + dd4 = dd4.sumAbstract({"x"}); EXPECT_EQ(2, dd4.getValue()); dd4 = dd3.maximum(dd1); dd4 *= manager->getEncoding("x", 2); - dd4.sumAbstract({"x"}); + dd4 = dd4.sumAbstract({"x"}); EXPECT_EQ(5, dd4.getValue()); dd1 = manager->getConstant(0.01); @@ -187,36 +188,36 @@ TEST(CuddDd, AbstractionTest) { dd2 = manager->getConstant(5); dd3 = dd1.equals(dd2); EXPECT_EQ(1, dd3.getNonZeroCount()); - ASSERT_THROW(dd3.existsAbstract({"x'"}), storm::exceptions::InvalidArgumentException); - ASSERT_NO_THROW(dd3.existsAbstract({"x"})); + ASSERT_THROW(dd3 = dd3.existsAbstract({"x'"}), storm::exceptions::InvalidArgumentException); + ASSERT_NO_THROW(dd3 = dd3.existsAbstract({"x"})); EXPECT_EQ(1, dd3.getNonZeroCount()); EXPECT_EQ(1, dd3.getMax()); dd3 = dd1.equals(dd2); dd3 *= manager->getConstant(3); EXPECT_EQ(1, dd3.getNonZeroCount()); - ASSERT_THROW(dd3.existsAbstract({"x'"}), storm::exceptions::InvalidArgumentException); - ASSERT_NO_THROW(dd3.existsAbstract({"x"})); + ASSERT_THROW(dd3 = dd3.existsAbstract({"x'"}), storm::exceptions::InvalidArgumentException); + ASSERT_NO_THROW(dd3 = dd3.existsAbstract({"x"})); EXPECT_TRUE(dd3 == manager->getZero()); dd3 = dd1.equals(dd2); dd3 *= manager->getConstant(3); - ASSERT_THROW(dd3.sumAbstract({"x'"}), storm::exceptions::InvalidArgumentException); - ASSERT_NO_THROW(dd3.sumAbstract({"x"})); + ASSERT_THROW(dd3 = dd3.sumAbstract({"x'"}), storm::exceptions::InvalidArgumentException); + ASSERT_NO_THROW(dd3 = dd3.sumAbstract({"x"})); EXPECT_EQ(1, dd3.getNonZeroCount()); EXPECT_EQ(3, dd3.getMax()); dd3 = dd1.equals(dd2); dd3 *= manager->getConstant(3); - ASSERT_THROW(dd3.minAbstract({"x'"}), storm::exceptions::InvalidArgumentException); - ASSERT_NO_THROW(dd3.minAbstract({"x"})); + ASSERT_THROW(dd3 = dd3.minAbstract({"x'"}), storm::exceptions::InvalidArgumentException); + ASSERT_NO_THROW(dd3 = dd3.minAbstract({"x"})); EXPECT_EQ(0, dd3.getNonZeroCount()); EXPECT_EQ(0, dd3.getMax()); dd3 = dd1.equals(dd2); dd3 *= manager->getConstant(3); - ASSERT_THROW(dd3.maxAbstract({"x'"}), storm::exceptions::InvalidArgumentException); - ASSERT_NO_THROW(dd3.maxAbstract({"x"})); + ASSERT_THROW(dd3 = dd3.maxAbstract({"x'"}), storm::exceptions::InvalidArgumentException); + ASSERT_NO_THROW(dd3 = dd3.maxAbstract({"x"})); EXPECT_EQ(1, dd3.getNonZeroCount()); EXPECT_EQ(3, dd3.getMax()); } @@ -308,3 +309,115 @@ TEST(CuddDd, ForwardIteratorTest) { } EXPECT_EQ(1, numberOfValuations); } + +TEST(CuddDd, ToExpressionTest) { + std::shared_ptr> manager(new storm::dd::DdManager()); + manager->addMetaVariable("x", 1, 9); + + storm::dd::Dd dd; + ASSERT_NO_THROW(dd = manager->getIdentity("x")); + + storm::expressions::Expression ddAsExpression; + ASSERT_NO_THROW(ddAsExpression = dd.toExpression()); + + storm::expressions::SimpleValuation valuation; + for (std::size_t bit = 0; bit < manager->getMetaVariable("x").getNumberOfDdVariables(); ++bit) { + valuation.addBooleanIdentifier("x." + std::to_string(bit)); + } + + storm::dd::DdMetaVariable const& metaVariable = manager->getMetaVariable("x"); + + for (auto valuationValuePair : dd) { + for (std::size_t i = 0; i < metaVariable.getNumberOfDdVariables(); ++i) { + // Check if the i-th bit is set or not and modify the valuation accordingly. + if (((valuationValuePair.first.getIntegerValue("x") - metaVariable.getLow()) & (1ull << (metaVariable.getNumberOfDdVariables() - i - 1))) != 0) { + valuation.setBooleanValue("x." + std::to_string(i), true); + } else { + valuation.setBooleanValue("x." + std::to_string(i), false); + } + } + + // At this point, the constructed valuation should make the expression obtained from the DD evaluate to the very + // same value as the current value obtained from the DD. + EXPECT_EQ(valuationValuePair.second, ddAsExpression.evaluateAsDouble(&valuation)); + } + + storm::expressions::Expression mintermExpression = dd.getMintermExpression(); + + // Check whether all minterms are covered. + for (auto valuationValuePair : dd) { + for (std::size_t i = 0; i < metaVariable.getNumberOfDdVariables(); ++i) { + // Check if the i-th bit is set or not and modify the valuation accordingly. + if (((valuationValuePair.first.getIntegerValue("x") - metaVariable.getLow()) & (1ull << (metaVariable.getNumberOfDdVariables() - i - 1))) != 0) { + valuation.setBooleanValue("x." + std::to_string(i), true); + } else { + valuation.setBooleanValue("x." + std::to_string(i), false); + } + } + + // At this point, the constructed valuation should make the expression obtained from the DD evaluate to the very + // same value as the current value obtained from the DD. + EXPECT_TRUE(mintermExpression.evaluateAsBool(&valuation)); + } + + // Now check no additional minterms are covered. + dd = !dd; + for (auto valuationValuePair : dd) { + for (std::size_t i = 0; i < metaVariable.getNumberOfDdVariables(); ++i) { + // Check if the i-th bit is set or not and modify the valuation accordingly. + if (((valuationValuePair.first.getIntegerValue("x") - metaVariable.getLow()) & (1ull << (metaVariable.getNumberOfDdVariables() - i - 1))) != 0) { + valuation.setBooleanValue("x." + std::to_string(i), true); + } else { + valuation.setBooleanValue("x." + std::to_string(i), false); + } + } + + // At this point, the constructed valuation should make the expression obtained from the DD evaluate to the very + // same value as the current value obtained from the DD. + EXPECT_FALSE(mintermExpression.evaluateAsBool(&valuation)); + } +} + +TEST(CuddDd, OddTest) { + std::shared_ptr> manager(new storm::dd::DdManager()); + manager->addMetaVariable("a"); + manager->addMetaVariable("x", 1, 9); + + storm::dd::Dd dd = manager->getIdentity("x"); + storm::dd::Odd odd; + ASSERT_NO_THROW(odd = storm::dd::Odd(dd)); + EXPECT_EQ(9, odd.getTotalOffset()); + EXPECT_EQ(12, odd.getNodeCount()); + + std::vector ddAsVector; + ASSERT_NO_THROW(ddAsVector = dd.toVector()); + EXPECT_EQ(9, ddAsVector.size()); + for (uint_fast64_t i = 0; i < ddAsVector.size(); ++i) { + EXPECT_TRUE(i+1 == ddAsVector[i]); + } + + // Create a non-trivial matrix. + dd = manager->getIdentity("x").equals(manager->getIdentity("x'")) * manager->getRange("x"); + dd += manager->getEncoding("x", 1) * manager->getRange("x'") + manager->getEncoding("x'", 1) * manager->getRange("x"); + + // Create the ODDs. + storm::dd::Odd rowOdd; + ASSERT_NO_THROW(rowOdd = storm::dd::Odd(manager->getRange("x"))); + storm::dd::Odd columnOdd; + ASSERT_NO_THROW(columnOdd = storm::dd::Odd(manager->getRange("x'"))); + + // Try to translate the matrix. + storm::storage::SparseMatrix matrix; + ASSERT_NO_THROW(matrix = dd.toMatrix({"x"}, {"x'"}, rowOdd, columnOdd)); + + EXPECT_EQ(9, matrix.getRowCount()); + EXPECT_EQ(9, matrix.getColumnCount()); + EXPECT_EQ(25, matrix.getNonzeroEntryCount()); + + dd = manager->getRange("x") * manager->getRange("x'") * manager->getEncoding("a", 0).ite(dd, dd + manager->getConstant(1)); + ASSERT_NO_THROW(matrix = dd.toMatrix({"x"}, {"x'"}, {"a"}, rowOdd, columnOdd)); + EXPECT_EQ(18, matrix.getRowCount()); + EXPECT_EQ(9, matrix.getRowGroupCount()); + EXPECT_EQ(9, matrix.getColumnCount()); + EXPECT_EQ(106, matrix.getNonzeroEntryCount()); +} \ No newline at end of file diff --git a/test/functional/storage/ExpressionTest.cpp b/test/functional/storage/ExpressionTest.cpp index d3a920d1d..642456995 100644 --- a/test/functional/storage/ExpressionTest.cpp +++ b/test/functional/storage/ExpressionTest.cpp @@ -223,10 +223,10 @@ TEST(Expression, OperatorTest) { ASSERT_NO_THROW(tempExpression = boolVarExpression.iff(boolVarExpression)); EXPECT_TRUE(tempExpression.getReturnType() == storm::expressions::ExpressionReturnType::Bool); - ASSERT_THROW(tempExpression = trueExpression ^ piExpression, storm::exceptions::InvalidTypeException); - ASSERT_NO_THROW(tempExpression = trueExpression ^ falseExpression); + ASSERT_THROW(tempExpression = trueExpression != piExpression, storm::exceptions::InvalidTypeException); + ASSERT_NO_THROW(tempExpression = trueExpression != falseExpression); EXPECT_TRUE(tempExpression.getReturnType() == storm::expressions::ExpressionReturnType::Bool); - ASSERT_NO_THROW(tempExpression = boolVarExpression ^ boolVarExpression); + ASSERT_NO_THROW(tempExpression = boolVarExpression != boolVarExpression); EXPECT_TRUE(tempExpression.getReturnType() == storm::expressions::ExpressionReturnType::Bool); ASSERT_THROW(tempExpression = trueExpression.floor(), storm::exceptions::InvalidTypeException); @@ -240,6 +240,12 @@ TEST(Expression, OperatorTest) { EXPECT_TRUE(tempExpression.getReturnType() == storm::expressions::ExpressionReturnType::Int); ASSERT_NO_THROW(tempExpression = doubleVarExpression.ceil()); EXPECT_TRUE(tempExpression.getReturnType() == storm::expressions::ExpressionReturnType::Int); + + ASSERT_THROW(tempExpression = trueExpression ^ piExpression, storm::exceptions::InvalidTypeException); + ASSERT_NO_THROW(tempExpression = threeExpression ^ threeExpression); + EXPECT_TRUE(tempExpression.getReturnType() == storm::expressions::ExpressionReturnType::Int); + ASSERT_NO_THROW(tempExpression = intVarExpression ^ doubleVarExpression); + EXPECT_TRUE(tempExpression.getReturnType() == storm::expressions::ExpressionReturnType::Double); } TEST(Expression, SubstitutionTest) {