diff --git a/src/storm/abstraction/AbstractionInformation.cpp b/src/storm/abstraction/AbstractionInformation.cpp index f1dca42cc..4cfd9ee3f 100644 --- a/src/storm/abstraction/AbstractionInformation.cpp +++ b/src/storm/abstraction/AbstractionInformation.cpp @@ -492,9 +492,11 @@ namespace storm { } template - std::pair, uint64_t> AbstractionInformation::addLocationVariables(uint64_t highestLocationIndex) { + std::pair, uint64_t> AbstractionInformation::addLocationVariables(storm::expressions::Variable const& locationExpressionVariable, uint64_t highestLocationIndex) { auto newMetaVariable = ddManager->addMetaVariable("loc_" + std::to_string(locationVariablePairs.size()), 0, highestLocationIndex); + locationExpressionVariables.insert(locationExpressionVariable); + locationExpressionToDdVariableMap.emplace(locationExpressionVariable, newMetaVariable); locationVariablePairs.emplace_back(newMetaVariable); allSourceLocationVariables.insert(newMetaVariable.first); sourceVariables.insert(newMetaVariable.first); @@ -524,6 +526,21 @@ namespace storm { return allSuccessorLocationVariables; } + template + storm::expressions::Variable const& AbstractionInformation::getDdLocationVariable(storm::expressions::Variable const& locationExpressionVariable, bool source) { + auto const& metaVariablePair = locationExpressionToDdVariableMap.at(locationExpressionVariable); + if (source) { + return metaVariablePair.first; + } else { + return metaVariablePair.second; + } + } + + template + std::set const& AbstractionInformation::getLocationExpressionVariables() const { + return locationExpressionVariables; + } + template storm::dd::Bdd AbstractionInformation::encodeLocation(storm::expressions::Variable const& locationVariable, uint64_t locationIndex) const { return this->getDdManager().getEncoding(locationVariable, locationIndex); diff --git a/src/storm/abstraction/AbstractionInformation.h b/src/storm/abstraction/AbstractionInformation.h index 55a9e0105..c9bdb3e7a 100644 --- a/src/storm/abstraction/AbstractionInformation.h +++ b/src/storm/abstraction/AbstractionInformation.h @@ -462,7 +462,7 @@ namespace storm { /*! * Adds a location variable of appropriate range and returns the pair of meta variables. */ - std::pair, uint64_t> addLocationVariables(uint64_t highestLocationIndex); + std::pair, uint64_t> addLocationVariables(storm::expressions::Variable const& locationExpressionVariable, uint64_t highestLocationIndex); /*! * Retrieves the location variable with the given index as either source or successor. @@ -479,6 +479,16 @@ namespace storm { */ std::set const& getSuccessorLocationVariables() const; + /*! + * Retrieves the DD variable for the given location expression variable. + */ + storm::expressions::Variable const& getDdLocationVariable(storm::expressions::Variable const& locationExpressionVariable, bool source); + + /*! + * Retrieves the source location variables. + */ + std::set const& getLocationExpressionVariables() const; + /*! * Encodes the given location index as either source or successor. */ @@ -612,6 +622,12 @@ namespace storm { /// The location variable pairs (source/successor). std::vector> locationVariablePairs; + /// A mapping from location expression variables to their source/successor counterparts. + std::map> locationExpressionToDdVariableMap; + + /// The set of all location expression variables. + std::set locationExpressionVariables; + // All source location variables. std::set allSourceLocationVariables; diff --git a/src/storm/abstraction/ExpressionTranslator.cpp b/src/storm/abstraction/ExpressionTranslator.cpp index a60ab3c12..fb9d1f5fc 100644 --- a/src/storm/abstraction/ExpressionTranslator.cpp +++ b/src/storm/abstraction/ExpressionTranslator.cpp @@ -16,7 +16,7 @@ namespace storm { using namespace storm::expressions; template - ExpressionTranslator::ExpressionTranslator(AbstractionInformation& abstractionInformation, std::unique_ptr&& smtSolver) : abstractionInformation(abstractionInformation), equivalenceChecker(std::move(smtSolver)), locationVariables(abstractionInformation.getSourceLocationVariables()), abstractedVariables(abstractionInformation.getAbstractedVariables()) { + ExpressionTranslator::ExpressionTranslator(AbstractionInformation& abstractionInformation, std::unique_ptr&& smtSolver) : abstractionInformation(abstractionInformation), equivalenceChecker(std::move(smtSolver)), locationVariables(abstractionInformation.getLocationExpressionVariables()), abstractedVariables(abstractionInformation.getAbstractedVariables()) { // Intentionally left empty. } @@ -127,7 +127,7 @@ namespace storm { STORM_LOG_THROW(false, storm::exceptions::NotSupportedException, "Expressions of this kind are currently not supported by the abstraction expression translator."); } else { - return abstractionInformation.get().getDdManager().template getIdentity(expression.getVariable()); + return abstractionInformation.get().getDdManager().template getIdentity(abstractionInformation.get().getDdLocationVariable(expression.getVariable(), true)); } } diff --git a/src/storm/abstraction/jani/AutomatonAbstractor.cpp b/src/storm/abstraction/jani/AutomatonAbstractor.cpp index 08a401608..673bb3c79 100644 --- a/src/storm/abstraction/jani/AutomatonAbstractor.cpp +++ b/src/storm/abstraction/jani/AutomatonAbstractor.cpp @@ -33,7 +33,7 @@ namespace storm { } if (automaton.getNumberOfLocations() > 1) { - locationVariables = abstractionInformation.addLocationVariables(automaton.getNumberOfLocations() - 1).first; + locationVariables = abstractionInformation.addLocationVariables(automaton.getLocationExpressionVariable(), automaton.getNumberOfLocations() - 1).first; } } diff --git a/src/storm/modelchecker/abstraction/GameBasedMdpModelChecker.cpp b/src/storm/modelchecker/abstraction/GameBasedMdpModelChecker.cpp index 21dc3e73e..a5ef4c2f5 100644 --- a/src/storm/modelchecker/abstraction/GameBasedMdpModelChecker.cpp +++ b/src/storm/modelchecker/abstraction/GameBasedMdpModelChecker.cpp @@ -8,6 +8,7 @@ #include "storm/models/symbolic/Mdp.h" #include "storm/storage/expressions/ExpressionManager.h" +#include "storm/storage/expressions/VariableSetAbstractor.h" #include "storm/storage/dd/DdManager.h" @@ -445,9 +446,25 @@ namespace storm { template std::vector GameBasedMdpModelChecker::getInitialPredicates(storm::expressions::Expression const& constraintExpression, storm::expressions::Expression const& targetStateExpression) { std::vector initialPredicates; - initialPredicates.push_back(targetStateExpression); - if (!constraintExpression.isTrue() && !constraintExpression.isFalse()) { - initialPredicates.push_back(constraintExpression); + if (preprocessedModel.isJaniModel()) { + storm::expressions::VariableSetAbstractor abstractor(preprocessedModel.asJaniModel().getAllLocationExpressionVariables()); + + storm::expressions::Expression abstractedExpression = abstractor.abstract(targetStateExpression); + if (abstractedExpression.isInitialized() && !abstractedExpression.isTrue() && !abstractedExpression.isFalse()) { + initialPredicates.push_back(abstractedExpression); + } + + abstractedExpression = abstractor.abstract(constraintExpression); + if (abstractedExpression.isInitialized() && !abstractedExpression.isTrue() && !abstractedExpression.isFalse()) { + initialPredicates.push_back(abstractedExpression); + } + } else { + if (!targetStateExpression.isTrue() && !targetStateExpression.isFalse()) { + initialPredicates.push_back(targetStateExpression); + } + if (!constraintExpression.isTrue() && !constraintExpression.isFalse()) { + initialPredicates.push_back(constraintExpression); + } } return initialPredicates; } diff --git a/src/storm/storage/expressions/VariableSetAbstractor.cpp b/src/storm/storage/expressions/VariableSetAbstractor.cpp new file mode 100644 index 000000000..5820db995 --- /dev/null +++ b/src/storm/storage/expressions/VariableSetAbstractor.cpp @@ -0,0 +1,225 @@ +#include "storm/storage/expressions/VariableSetAbstractor.h" + +#include "storm/storage/expressions/Expressions.h" + +#include "storm/utility/macros.h" +#include "storm/exceptions/InvalidArgumentException.h" + +namespace storm { + namespace expressions { + + VariableSetAbstractor::VariableSetAbstractor(std::set const& variablesToAbstract) : variablesToAbstract(variablesToAbstract) { + // Intentionally left empty. + } + + storm::expressions::Expression VariableSetAbstractor::abstract(storm::expressions::Expression const& expression) { + std::set containedVariables = expression.getVariables(); + bool onlyAbstractedVariables = std::includes(variablesToAbstract.begin(), variablesToAbstract.end(), containedVariables.begin(), containedVariables.end()); + + if (onlyAbstractedVariables) { + return storm::expressions::Expression(); + } + + std::set tmp; + std::set_intersection(containedVariables.begin(), containedVariables.end(), variablesToAbstract.begin(), variablesToAbstract.end(), std::inserter(tmp, tmp.begin())); + bool hasAbstractedVariables = !tmp.empty(); + + if (hasAbstractedVariables) { + return boost::any_cast(expression.accept(*this, boost::none)); + } else { + return expression; + } + } + + boost::any VariableSetAbstractor::visit(IfThenElseExpression const& expression, boost::any const& data) { + std::set conditionVariables; + expression.getCondition()->gatherVariables(conditionVariables); + bool conditionOnlyAbstractedVariables = std::includes(variablesToAbstract.begin(), variablesToAbstract.end(), conditionVariables.begin(), conditionVariables.end()); + + std::set tmp; + std::set_intersection(conditionVariables.begin(), conditionVariables.end(), variablesToAbstract.begin(), variablesToAbstract.end(), std::inserter(tmp, tmp.begin())); + bool conditionHasAbstractedVariables = !tmp.empty(); + + std::set thenVariables; + expression.getThenExpression()->gatherVariables(thenVariables); + bool thenOnlyAbstractedVariables = std::includes(variablesToAbstract.begin(), variablesToAbstract.end(), thenVariables.begin(), thenVariables.end()); + + tmp.clear(); + std::set_intersection(thenVariables.begin(), thenVariables.end(), variablesToAbstract.begin(), variablesToAbstract.end(), std::inserter(tmp, tmp.begin())); + bool thenHasAbstractedVariables = !tmp.empty(); + + std::set elseVariables; + expression.getElseExpression()->gatherVariables(elseVariables); + bool elseOnlyAbstractedVariables = std::includes(variablesToAbstract.begin(), variablesToAbstract.end(), elseVariables.begin(), elseVariables.end()); + + tmp.clear(); + std::set_intersection(elseVariables.begin(), elseVariables.end(), variablesToAbstract.begin(), variablesToAbstract.end(), std::inserter(tmp, tmp.begin())); + bool elseHasAbstractedVariables = !tmp.empty(); + + if (conditionHasAbstractedVariables || thenHasAbstractedVariables || elseHasAbstractedVariables) { + if (conditionOnlyAbstractedVariables && thenOnlyAbstractedVariables && elseOnlyAbstractedVariables) { + return boost::any(); + } else { + STORM_LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Cannot abstract from variable set in expression as it mixes variables of different types."); + } + } else { + return expression.toExpression(); + } + } + + boost::any VariableSetAbstractor::visit(BinaryBooleanFunctionExpression const& expression, boost::any const& data) { + std::set leftContainedVariables; + expression.getFirstOperand()->gatherVariables(leftContainedVariables); + bool leftOnlyAbstractedVariables = std::includes(variablesToAbstract.begin(), variablesToAbstract.end(), leftContainedVariables.begin(), leftContainedVariables.end()); + + std::set tmp; + std::set_intersection(leftContainedVariables.begin(), leftContainedVariables.end(), variablesToAbstract.begin(), variablesToAbstract.end(), std::inserter(tmp, tmp.begin())); + bool leftHasAbstractedVariables = !tmp.empty(); + + std::set rightContainedVariables; + expression.getSecondOperand()->gatherVariables(rightContainedVariables); + bool rightOnlyAbstractedVariables = std::includes(variablesToAbstract.begin(), variablesToAbstract.end(), rightContainedVariables.begin(), rightContainedVariables.end()); + + tmp.clear(); + std::set_intersection(rightContainedVariables.begin(), rightContainedVariables.end(), variablesToAbstract.begin(), variablesToAbstract.end(), std::inserter(tmp, tmp.begin())); + bool rightHasAbstractedVariables = !tmp.empty(); + + if (leftOnlyAbstractedVariables && rightOnlyAbstractedVariables) { + return boost::any(); + } else if (!leftHasAbstractedVariables && !rightHasAbstractedVariables) { + return expression; + } else { + if (leftHasAbstractedVariables && !rightHasAbstractedVariables) { + return expression.getFirstOperand()->toExpression(); + } else if (rightHasAbstractedVariables && !leftHasAbstractedVariables) { + return expression.getSecondOperand()->toExpression(); + } else { + storm::expressions::Expression leftResult = boost::any_cast(expression.getFirstOperand()->accept(*this, data)); + storm::expressions::Expression rightResult = boost::any_cast(expression.getFirstOperand()->accept(*this, data)); + + switch (expression.getOperatorType()) { + case storm::expressions::BinaryBooleanFunctionExpression::OperatorType::And: return leftResult && rightResult; + case storm::expressions::BinaryBooleanFunctionExpression::OperatorType::Or: return leftResult || rightResult; + case storm::expressions::BinaryBooleanFunctionExpression::OperatorType::Xor: return leftResult ^ rightResult; + case storm::expressions::BinaryBooleanFunctionExpression::OperatorType::Implies: return storm::expressions::implies(leftResult, rightResult); + case storm::expressions::BinaryBooleanFunctionExpression::OperatorType::Iff: return storm::expressions::iff(leftResult, rightResult); + } + } + } + } + + boost::any VariableSetAbstractor::visit(BinaryNumericalFunctionExpression const& expression, boost::any const& data) { + std::set leftContainedVariables; + expression.getFirstOperand()->gatherVariables(leftContainedVariables); + bool leftOnlyAbstractedVariables = std::includes(variablesToAbstract.begin(), variablesToAbstract.end(), leftContainedVariables.begin(), leftContainedVariables.end()); + + std::set tmp; + std::set_intersection(leftContainedVariables.begin(), leftContainedVariables.end(), variablesToAbstract.begin(), variablesToAbstract.end(), std::inserter(tmp, tmp.begin())); + bool leftHasAbstractedVariables = !tmp.empty(); + + std::set rightContainedVariables; + expression.getSecondOperand()->gatherVariables(rightContainedVariables); + bool rightOnlyAbstractedVariables = std::includes(variablesToAbstract.begin(), variablesToAbstract.end(), rightContainedVariables.begin(), rightContainedVariables.end()); + + tmp.clear(); + std::set_intersection(rightContainedVariables.begin(), rightContainedVariables.end(), variablesToAbstract.begin(), variablesToAbstract.end(), std::inserter(tmp, tmp.begin())); + bool rightHasAbstractedVariables = !tmp.empty(); + + if (leftOnlyAbstractedVariables && rightOnlyAbstractedVariables) { + return boost::any(); + } else if (!leftHasAbstractedVariables && !rightHasAbstractedVariables) { + return expression; + } else { + STORM_LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Cannot abstract from variable set in expression as it mixes variables of different types."); + } + } + + boost::any VariableSetAbstractor::visit(BinaryRelationExpression const& expression, boost::any const& data) { + std::set leftContainedVariables; + expression.getFirstOperand()->gatherVariables(leftContainedVariables); + bool leftOnlyAbstractedVariables = std::includes(variablesToAbstract.begin(), variablesToAbstract.end(), leftContainedVariables.begin(), leftContainedVariables.end()); + + std::set tmp; + std::set_intersection(leftContainedVariables.begin(), leftContainedVariables.end(), variablesToAbstract.begin(), variablesToAbstract.end(), std::inserter(tmp, tmp.begin())); + bool leftHasAbstractedVariables = !tmp.empty(); + + std::set rightContainedVariables; + expression.getSecondOperand()->gatherVariables(rightContainedVariables); + bool rightOnlyAbstractedVariables = std::includes(variablesToAbstract.begin(), variablesToAbstract.end(), rightContainedVariables.begin(), rightContainedVariables.end()); + + tmp.clear(); + std::set_intersection(rightContainedVariables.begin(), rightContainedVariables.end(), variablesToAbstract.begin(), variablesToAbstract.end(), std::inserter(tmp, tmp.begin())); + bool rightHasAbstractedVariables = !tmp.empty(); + + if (leftOnlyAbstractedVariables && rightOnlyAbstractedVariables) { + return boost::any(); + } else if (!leftHasAbstractedVariables && !rightHasAbstractedVariables) { + return expression; + } else { + STORM_LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Cannot abstract from variable set in expression as it mixes variables of different types."); + } + } + + boost::any VariableSetAbstractor::visit(VariableExpression const& expression, boost::any const& data) { + if (variablesToAbstract.find(expression.getVariable()) != variablesToAbstract.end()) { + return boost::any(); + } else { + return expression.toExpression(); + } + } + + boost::any VariableSetAbstractor::visit(UnaryBooleanFunctionExpression const& expression, boost::any const& data) { + std::set containedVariables; + expression.gatherVariables(containedVariables); + bool onlyAbstractedVariables = std::includes(variablesToAbstract.begin(), variablesToAbstract.end(), containedVariables.begin(), containedVariables.end()); + + if (onlyAbstractedVariables) { + return boost::any(); + } + + std::set tmp; + std::set_intersection(containedVariables.begin(), containedVariables.end(), variablesToAbstract.begin(), variablesToAbstract.end(), std::inserter(tmp, tmp.begin())); + bool hasAbstractedVariables = !tmp.empty(); + if (hasAbstractedVariables) { + storm::expressions::Expression subexpression = boost::any_cast(expression.getOperand()->accept(*this, data)); + switch (expression.getOperatorType()) { + case storm::expressions::UnaryBooleanFunctionExpression::OperatorType::Not: return !subexpression; + } + } else { + return expression.toExpression(); + } + } + + boost::any VariableSetAbstractor::visit(UnaryNumericalFunctionExpression const& expression, boost::any const& data) { + std::set containedVariables; + expression.gatherVariables(containedVariables); + bool onlyAbstractedVariables = std::includes(variablesToAbstract.begin(), variablesToAbstract.end(), containedVariables.begin(), containedVariables.end()); + + if (onlyAbstractedVariables) { + return boost::any(); + } + + std::set tmp; + std::set_intersection(containedVariables.begin(), containedVariables.end(), variablesToAbstract.begin(), variablesToAbstract.end(), std::inserter(tmp, tmp.begin())); + bool hasAbstractedVariables = !tmp.empty(); + if (hasAbstractedVariables) { + STORM_LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Cannot abstract from variable set in expression as it mixes variables of different types."); + } else { + return expression.toExpression(); + } + } + + boost::any VariableSetAbstractor::visit(BooleanLiteralExpression const& expression, boost::any const& data) { + return expression.toExpression(); + } + + boost::any VariableSetAbstractor::visit(IntegerLiteralExpression const& expression, boost::any const& data) { + return expression.toExpression(); + } + + boost::any VariableSetAbstractor::visit(RationalLiteralExpression const& expression, boost::any const& data) { + return expression.toExpression(); + } + + } +} diff --git a/src/storm/storage/expressions/VariableSetAbstractor.h b/src/storm/storage/expressions/VariableSetAbstractor.h new file mode 100644 index 000000000..e5d905598 --- /dev/null +++ b/src/storm/storage/expressions/VariableSetAbstractor.h @@ -0,0 +1,35 @@ +#pragma once + +#include + +#include "storm/storage/expressions/ExpressionVisitor.h" + +namespace storm { + namespace expressions { + + class Variable; + class Expression; + + class VariableSetAbstractor : public ExpressionVisitor { + public: + VariableSetAbstractor(std::set const& variablesToAbstract); + + storm::expressions::Expression abstract(storm::expressions::Expression const& expression); + + virtual boost::any visit(IfThenElseExpression const& expression, boost::any const& data) override; + virtual boost::any visit(BinaryBooleanFunctionExpression const& expression, boost::any const& data) override; + virtual boost::any visit(BinaryNumericalFunctionExpression const& expression, boost::any const& data) override; + virtual boost::any visit(BinaryRelationExpression const& expression, boost::any const& data) override; + virtual boost::any visit(VariableExpression const& expression, boost::any const& data) override; + virtual boost::any visit(UnaryBooleanFunctionExpression const& expression, boost::any const& data) override; + virtual boost::any visit(UnaryNumericalFunctionExpression const& expression, boost::any const& data) override; + virtual boost::any visit(BooleanLiteralExpression const& expression, boost::any const& data) override; + virtual boost::any visit(IntegerLiteralExpression const& expression, boost::any const& data) override; + virtual boost::any visit(RationalLiteralExpression const& expression, boost::any const& data) override; + + private: + std::set variablesToAbstract; + }; + + } +} diff --git a/src/storm/storage/jani/Model.cpp b/src/storm/storage/jani/Model.cpp index 1539c3329..f62e98547 100644 --- a/src/storm/storage/jani/Model.cpp +++ b/src/storm/storage/jani/Model.cpp @@ -630,7 +630,7 @@ namespace storm { return globalVariables; } - std::set Model::getAllExpressionVariables() const { + std::set Model::getAllExpressionVariables(bool includeLocationExpressionVariables) const { std::set result; for (auto const& constant : constants) { @@ -642,11 +642,22 @@ namespace storm { for (auto const& automaton : automata) { auto const& automatonVariables = automaton.getAllExpressionVariables(); result.insert(automatonVariables.begin(), automatonVariables.end()); + if (includeLocationExpressionVariables) { + result.insert(automaton.getLocationExpressionVariable()); + } } return result; } + std::set Model::getAllLocationExpressionVariables() const { + std::set result; + for (auto const& automaton : automata) { + result.insert(automaton.getLocationExpressionVariable()); + } + return result; + } + bool Model::hasGlobalVariable(std::string const& name) const { return globalVariables.hasVariable(name); } @@ -968,6 +979,14 @@ namespace storm { STORM_LOG_ASSERT(!automata.empty(), "No automata set"); STORM_LOG_ASSERT(composition != nullptr, "Composition is not set"); } + + storm::expressions::Expression Model::getLabelExpression(BooleanVariable const& transientVariable) const { + std::vector> allAutomata; + for (auto const& automaton : automata) { + allAutomata.emplace_back(automaton); + } + return getLabelExpression(transientVariable, allAutomata); + } storm::expressions::Expression Model::getLabelExpression(BooleanVariable const& transientVariable, std::vector> const& automata) const { STORM_LOG_THROW(transientVariable.isTransient(), storm::exceptions::InvalidArgumentException, "Expected transient variable."); diff --git a/src/storm/storage/jani/Model.h b/src/storm/storage/jani/Model.h index 237ee827a..0a79ab8d2 100644 --- a/src/storm/storage/jani/Model.h +++ b/src/storm/storage/jani/Model.h @@ -183,11 +183,19 @@ namespace storm { VariableSet const& getGlobalVariables() const; /*! - * Retrieves all expression variables used by this model. + * Retrieves all expression variables used by this model. Note that this does not include the location + * expression variables by default. * * @return The set of expression variables used by this model. */ - std::set getAllExpressionVariables() const; + std::set getAllExpressionVariables(bool includeLocationExpressionVariables = false) const; + + /*! + * Retrieves all location expression variables used by this model. + * + * @return The set of expression variables used by this model. + */ + std::set getAllLocationExpressionVariables() const; /*! * Retrieves whether this model has a global variable with the given name. @@ -380,7 +388,13 @@ namespace storm { * true. The provided location variables are used to encode the location of the automata. */ storm::expressions::Expression getLabelExpression(BooleanVariable const& transientVariable, std::vector> const& automata) const; - + + /*! + * Creates the expression that characterizes all states in which the provided transient boolean variable is + * true. The provided location variables are used to encode the location of the automata. + */ + storm::expressions::Expression getLabelExpression(BooleanVariable const& transientVariable) const; + /*! * Checks that undefined constants (parameters) of the model preserve the graph of the underlying model. * That is, undefined constants may only appear in the probability expressions of edge destinations as well