diff --git a/src/adapters/ExplicitModelAdapter.h b/src/adapters/ExplicitModelAdapter.h index 3875eef8a..49a2a4ccd 100644 --- a/src/adapters/ExplicitModelAdapter.h +++ b/src/adapters/ExplicitModelAdapter.h @@ -703,7 +703,7 @@ namespace storm { std::list> result; StateType const* currentState = stateInformation.reachableStates[stateIndex]; - + // Iterate over all modules. for (uint_fast64_t i = 0; i < program.getNumberOfModules(); ++i) { storm::ir::Module const& module = program.getModule(i); @@ -757,7 +757,7 @@ namespace storm { for (std::string const& action : program.getActions()) { StateType const* currentState = stateInformation.reachableStates[stateIndex]; boost::optional>> optionalActiveCommandLists = getActiveCommandsByAction(program, currentState, action); - + // Only process this action label, if there is at least one feasible solution. if (optionalActiveCommandLists) { std::vector> const& activeCommandList = optionalActiveCommandLists.get(); diff --git a/src/counterexamples/SMTMinimalCommandSetGenerator.h b/src/counterexamples/SMTMinimalCommandSetGenerator.h index 3e662ac2f..5f0e6ae85 100644 --- a/src/counterexamples/SMTMinimalCommandSetGenerator.h +++ b/src/counterexamples/SMTMinimalCommandSetGenerator.h @@ -20,6 +20,8 @@ #include "src/modelchecker/prctl/SparseMdpPrctlModelChecker.h" #include "src/solver/GmmxxNondeterministicLinearEquationSolver.h" +#include "src/utility/IRUtility.h" + namespace storm { namespace counterexamples { @@ -218,7 +220,7 @@ namespace storm { * @param solver The solver to use for the satisfiability evaluation. */ static void assertCuts(storm::ir::Program const& program, z3::context& context, z3::solver& solver) { - + // TODO. } /*! diff --git a/src/ir/Program.cpp b/src/ir/Program.cpp index 08b4ec934..b4cd0fad3 100644 --- a/src/ir/Program.cpp +++ b/src/ir/Program.cpp @@ -77,7 +77,7 @@ namespace storm { Program::Program(Program const& otherProgram) : modelType(otherProgram.modelType), globalBooleanVariables(otherProgram.globalBooleanVariables), globalIntegerVariables(otherProgram.globalIntegerVariables), globalBooleanVariableToIndexMap(otherProgram.globalBooleanVariableToIndexMap), globalIntegerVariableToIndexMap(otherProgram.globalIntegerVariableToIndexMap), modules(otherProgram.modules), rewards(otherProgram.rewards), - actionsToModuleIndexMap(), variableToModuleIndexMap() { + actions(otherProgram.actions), actionsToModuleIndexMap(), variableToModuleIndexMap() { // Perform deep-copy of the maps. for (auto const& booleanUndefinedConstant : otherProgram.booleanUndefinedConstantExpressions) { this->booleanUndefinedConstantExpressions[booleanUndefinedConstant.first] = std::unique_ptr(new storm::ir::expressions::BooleanConstantExpression(*booleanUndefinedConstant.second)); @@ -102,6 +102,7 @@ namespace storm { this->globalIntegerVariableToIndexMap = otherProgram.globalIntegerVariableToIndexMap; this->modules = otherProgram.modules; this->rewards = otherProgram.rewards; + this->actions = otherProgram.actions; this->actionsToModuleIndexMap = otherProgram.actionsToModuleIndexMap; this->variableToModuleIndexMap = otherProgram.variableToModuleIndexMap; diff --git a/src/ir/expressions/BaseExpression.cpp b/src/ir/expressions/BaseExpression.cpp index a6b90fe56..2d85dfdc9 100644 --- a/src/ir/expressions/BaseExpression.cpp +++ b/src/ir/expressions/BaseExpression.cpp @@ -26,6 +26,16 @@ namespace storm { // Nothing to do here. } + std::unique_ptr BaseExpression::substitute(std::unique_ptr&& expression, std::map> const& substitution) { + BaseExpression* result = expression->performSubstitution(substitution); + + if (result != expression.get()) { + return std::unique_ptr(result); + } else { + return std::move(expression); + } + } + int_fast64_t BaseExpression::getValueAsInt(std::pair, std::vector> const* variableValues) const { if (type != int_) { throw storm::exceptions::ExpressionEvaluationException() << "Cannot evaluate expression of type '" @@ -70,6 +80,10 @@ namespace storm { return type; } + BaseExpression* BaseExpression::performSubstitution(std::map> const& substitution) { + return this; + } + } // namespace expressions } // namespace ir } // namespace storm diff --git a/src/ir/expressions/BaseExpression.h b/src/ir/expressions/BaseExpression.h index 2125287cd..1f045f5f4 100644 --- a/src/ir/expressions/BaseExpression.h +++ b/src/ir/expressions/BaseExpression.h @@ -33,8 +33,11 @@ namespace storm { * The base class for all expressions. */ class BaseExpression { - public: + // Forward declare friend classes to allow access to substitute. + friend class BinaryExpression; + friend class UnaryExpression; + /*! * Each node in an expression tree has a uniquely defined type from this enum. */ @@ -78,7 +81,17 @@ namespace storm { * @param variableState An object knowing about the global variable state. */ virtual std::unique_ptr clone(std::map const& renaming, storm::parser::prism::VariableState const& variableState) const = 0; - + + /*! + * Performs the given substitution by replacing each variable in the given expression that is a key in + * the map by a copy of the mapped expression. + * + * @param expression The expression in which to perform the substitution. + * @param substitution The substitution to apply. + * @return The resulting expression. + */ + static std::unique_ptr substitute(std::unique_ptr&& expression, std::map> const& substitution); + /*! * Retrieves the value of the expression as an integer given the provided variable valuation. * @@ -137,6 +150,16 @@ namespace storm { */ ReturnType getType() const; + protected: + /*! + * Performs the given substitution on the expression, i.e. replaces all variables whose names are keys + * of the map by a copy of the expression they are associated with in the map. This is intended as a helper + * function for substitute. + * + * @param substitution The substitution to perform + */ + virtual BaseExpression* performSubstitution(std::map> const& substitution); + private: // The type to which this node evaluates. ReturnType type; diff --git a/src/ir/expressions/BinaryExpression.cpp b/src/ir/expressions/BinaryExpression.cpp index 5ca3e186b..417cc9301 100644 --- a/src/ir/expressions/BinaryExpression.cpp +++ b/src/ir/expressions/BinaryExpression.cpp @@ -20,6 +20,25 @@ namespace storm { // Nothing to do here. } + BaseExpression* BinaryExpression::performSubstitution(std::map> const& substitution) { + // Get the new left successor recursively. + BaseExpression* newLeftSuccessor = left->performSubstitution(substitution); + + // If the left successor changed, we need to update it. If it did not change, this must not be executed, + // because assigning to the unique_ptr will destroy the current successor immediately. + if (newLeftSuccessor != left.get()) { + left = std::unique_ptr(newLeftSuccessor); + } + + // Now do the same thing for the right successor. + BaseExpression* newRightSuccessor = right->performSubstitution(substitution); + if (newRightSuccessor != right.get()) { + right = std::unique_ptr(newRightSuccessor); + } + + return this; + } + std::unique_ptr const& BinaryExpression::getLeft() const { return left; } diff --git a/src/ir/expressions/BinaryExpression.h b/src/ir/expressions/BinaryExpression.h index e730aaebb..1d6560863 100644 --- a/src/ir/expressions/BinaryExpression.h +++ b/src/ir/expressions/BinaryExpression.h @@ -33,7 +33,7 @@ namespace storm { * @param binaryExpression The expression to copy. */ BinaryExpression(BinaryExpression const& binaryExpression); - + /*! * Retrieves the left child of the expression node. * @@ -48,6 +48,9 @@ namespace storm { */ std::unique_ptr const& getRight() const; + protected: + virtual BaseExpression* performSubstitution(std::map> const& substitution) override; + private: // The left child of the binary expression. std::unique_ptr left; diff --git a/src/ir/expressions/UnaryExpression.cpp b/src/ir/expressions/UnaryExpression.cpp index a5c8c4be8..31463aff3 100644 --- a/src/ir/expressions/UnaryExpression.cpp +++ b/src/ir/expressions/UnaryExpression.cpp @@ -23,6 +23,17 @@ namespace storm { return child; } + BaseExpression* UnaryExpression::performSubstitution(std::map> const& substitution) { + BaseExpression* newChild = child->performSubstitution(substitution); + + // Only update the child if it changed, because otherwise the child gets destroyed. + if (newChild != child.get()) { + child = std::unique_ptr(newChild); + } + + return this; + } + } // namespace expressions } // namespace ir } // namespace storm diff --git a/src/ir/expressions/UnaryExpression.h b/src/ir/expressions/UnaryExpression.h index fa330aca6..cfc685455 100644 --- a/src/ir/expressions/UnaryExpression.h +++ b/src/ir/expressions/UnaryExpression.h @@ -39,6 +39,9 @@ namespace storm { * @return The child of the expression node. */ std::unique_ptr const& getChild() const; + + protected: + virtual BaseExpression* performSubstitution(std::map> const& substitution) override; private: // The left child of the unary expression. diff --git a/src/ir/expressions/VariableExpression.cpp b/src/ir/expressions/VariableExpression.cpp index 177e1c316..9bf09a734 100644 --- a/src/ir/expressions/VariableExpression.cpp +++ b/src/ir/expressions/VariableExpression.cpp @@ -48,6 +48,20 @@ namespace storm { } } + BaseExpression* VariableExpression::performSubstitution(std::map> const& substitution) { + // If the name of the variable is a key of the map, we need to replace it. + auto substitutionIterator = substitution.find(variableName); + + if (substitutionIterator != substitution.end()) { + std::unique_ptr expressionClone = substitutionIterator->second.get().clone(); + BaseExpression* rawPointer = expressionClone.release(); + return rawPointer; + } else { + // Otherwise, we don't need to replace anything. + return this; + } + } + void VariableExpression::accept(ExpressionVisitor* visitor) { visitor->visit(this); } diff --git a/src/ir/expressions/VariableExpression.h b/src/ir/expressions/VariableExpression.h index d24229b9e..99358d377 100644 --- a/src/ir/expressions/VariableExpression.h +++ b/src/ir/expressions/VariableExpression.h @@ -55,7 +55,7 @@ namespace storm { virtual std::unique_ptr clone() const override; virtual std::unique_ptr clone(std::map const& renaming, storm::parser::prism::VariableState const& variableState) const override; - + virtual void accept(ExpressionVisitor* visitor) override; virtual std::string toString() const override; @@ -79,7 +79,10 @@ namespace storm { * @return The global index of the variable. */ uint_fast64_t getGlobalVariableIndex() const; - + + protected: + virtual BaseExpression* performSubstitution(std::map> const& substitution) override; + private: // The global index of the variable. uint_fast64_t globalIndex; diff --git a/src/parser/prismparser/PrismGrammar.cpp b/src/parser/prismparser/PrismGrammar.cpp index 9a6819194..a6d7a62a5 100644 --- a/src/parser/prismparser/PrismGrammar.cpp +++ b/src/parser/prismparser/PrismGrammar.cpp @@ -93,7 +93,7 @@ namespace storm { void PrismGrammar::createIntegerVariable(std::string const& name, std::shared_ptr const& lower, std::shared_ptr const& upper, std::shared_ptr const& init, std::vector& vars, std::map& varids, bool isGlobalVariable) { uint_fast64_t id = this->state->addIntegerVariable(name); uint_fast64_t newLocalIndex = this->state->nextLocalIntegerVariableIndex; - vars.emplace_back(newLocalIndex, id, name, lower->clone(), upper->clone(), init->clone()); + vars.emplace_back(newLocalIndex, id, name, lower != nullptr ? lower->clone() : nullptr, upper != nullptr ? upper->clone() : nullptr, init != nullptr ? init->clone() : nullptr); varids[name] = newLocalIndex; ++this->state->nextLocalIntegerVariableIndex; this->state->localIntegerVariables_.add(name, name); @@ -105,7 +105,7 @@ namespace storm { void PrismGrammar::createBooleanVariable(std::string const& name, std::shared_ptr const& init, std::vector& vars, std::map& varids, bool isGlobalVariable) { uint_fast64_t id = this->state->addBooleanVariable(name); uint_fast64_t newLocalIndex = this->state->nextLocalBooleanVariableIndex; - vars.emplace_back(newLocalIndex, id, name, init->clone()); + vars.emplace_back(newLocalIndex, id, name, init != nullptr ? init->clone() : nullptr); varids[name] = newLocalIndex; ++this->state->nextLocalBooleanVariableIndex; this->state->localBooleanVariables_.add(name, name); @@ -125,7 +125,7 @@ namespace storm { } Update PrismGrammar::createUpdate(std::shared_ptr const& likelihood, std::map const& bools, std::map const& ints) { this->state->nextGlobalUpdateIndex++; - return Update(this->state->getNextGlobalUpdateIndex() - 1, likelihood->clone(), bools, ints); + return Update(this->state->getNextGlobalUpdateIndex() - 1, likelihood != nullptr ? likelihood->clone() : nullptr, bools, ints); } Command PrismGrammar::createCommand(std::string const& label, std::shared_ptr const& guard, std::vector const& updates) { this->state->nextGlobalCommandIndex++; diff --git a/src/utility/IRUtility.h b/src/utility/IRUtility.h index d26b0da61..106032b15 100644 --- a/src/utility/IRUtility.h +++ b/src/utility/IRUtility.h @@ -8,7 +8,7 @@ #ifndef STORM_UTILITY_IRUTILITY_H_ #define STORM_UTILITY_IRUTILITY_H_ -#include +#include #include "log4cplus/logger.h" #include "log4cplus/loggingmacros.h" @@ -28,23 +28,21 @@ namespace storm { * @param expression The expression for which to build the weakest precondition. * @param update The update with respect to which to compute the weakest precondition. */ - std::shared_ptr getWeakestPrecondition(std::shared_ptr booleanExpression, std::vector const& updates) { - std::map> variableToExpressionMap; + std::unique_ptr getWeakestPrecondition(std::unique_ptr const& booleanExpression, std::vector const& updates) { + std::map> variableToExpressionMap; // Construct the full substitution we need to perform later. for (auto const& update : updates) { - for (uint_fast64_t assignmentIndex = 0; assignmentIndex < update.getNumberOfAssignments(); ++assignmentIndex) { - storm::ir::Assignment const& update.getAssignment(assignmentIndex); - - variableToExpressionMap[assignment.getVariableName()] = assignment.getExpression(); + for (auto const& variableAssignmentPair : update.getBooleanAssignments()) { + variableToExpressionMap.emplace(variableAssignmentPair.first, *variableAssignmentPair.second.getExpression()); + } + for (auto const& variableAssignmentPair : update.getIntegerAssignments()) { + variableToExpressionMap.emplace(variableAssignmentPair.first, *variableAssignmentPair.second.getExpression()); } } // Copy the given expression and apply the substitution. - std::shared_ptr copiedExpression = booleanExpression->clone(); - copiedExpression->substitute(variableToExpressionMap); - - return copiedExpression; + return storm::ir::expressions::BaseExpression::substitute(booleanExpression->clone(), variableToExpressionMap); } } // namespace ir