From d86c763b94d6ad2de8778253ce472b19bce9065e Mon Sep 17 00:00:00 2001 From: Sebastian Junges Date: Mon, 21 Sep 2020 10:36:29 -0700 Subject: [PATCH] support for nonstandard predicate elimination or to-dice translation --- src/storm/storage/expressions/Expression.cpp | 5 + src/storm/storage/expressions/Expression.h | 5 + .../expressions/SimplificationVisitor.cpp | 171 ++++++++++++++++++ .../expressions/SimplificationVisitor.h | 41 +++++ .../expressions/ToDiceStringVisitor.cpp | 32 ++++ .../storage/expressions/ToDiceStringVisitor.h | 2 + src/storm/storage/prism/Assignment.cpp | 4 + src/storm/storage/prism/Assignment.h | 4 +- src/storm/storage/prism/BooleanVariable.cpp | 4 + src/storm/storage/prism/BooleanVariable.h | 3 +- src/storm/storage/prism/Command.cpp | 12 +- src/storm/storage/prism/Command.h | 3 +- src/storm/storage/prism/Formula.cpp | 9 + src/storm/storage/prism/Formula.h | 1 + src/storm/storage/prism/IntegerVariable.cpp | 4 + src/storm/storage/prism/IntegerVariable.h | 4 +- src/storm/storage/prism/Label.cpp | 9 + src/storm/storage/prism/Label.h | 3 +- src/storm/storage/prism/Module.cpp | 22 +++ src/storm/storage/prism/Module.h | 4 +- src/storm/storage/prism/Program.cpp | 46 ++++- src/storm/storage/prism/Program.h | 5 + src/storm/storage/prism/Update.cpp | 14 +- src/storm/storage/prism/Update.h | 2 + 24 files changed, 400 insertions(+), 9 deletions(-) create mode 100644 src/storm/storage/expressions/SimplificationVisitor.cpp create mode 100644 src/storm/storage/expressions/SimplificationVisitor.h diff --git a/src/storm/storage/expressions/Expression.cpp b/src/storm/storage/expressions/Expression.cpp index 17eb4d30c..6ea792b75 100644 --- a/src/storm/storage/expressions/Expression.cpp +++ b/src/storm/storage/expressions/Expression.cpp @@ -13,6 +13,7 @@ #include "storm/exceptions/InvalidTypeException.h" #include "storm/exceptions/InvalidArgumentException.h" #include "storm/utility/macros.h" +#include "storm/storage/expressions/SimplificationVisitor.h" namespace storm { namespace expressions { @@ -53,6 +54,10 @@ namespace storm { return SubstitutionVisitor>(identifierToExpressionMap).substitute(*this); } + Expression Expression::substituteNonStandardPredicates() const { + return SimplificationVisitor().substitute(*this); + } + bool Expression::evaluateAsBool(Valuation const* valuation) const { return this->getBaseExpression().evaluateAsBool(valuation); } diff --git a/src/storm/storage/expressions/Expression.h b/src/storm/storage/expressions/Expression.h index f82439a36..eb4b2c9fb 100644 --- a/src/storm/storage/expressions/Expression.h +++ b/src/storm/storage/expressions/Expression.h @@ -100,6 +100,11 @@ namespace storm { */ Expression substitute(std::map const& variableToExpressionMap) const; + /*! + * Eliminate nonstandard predicates from the expression. + * @return + */ + Expression substituteNonStandardPredicates() const; /*! * Substitutes all occurrences of the variables according to the given map. Note that this substitution is * done simultaneously, i.e., variables appearing in the expressions that were "plugged in" are not diff --git a/src/storm/storage/expressions/SimplificationVisitor.cpp b/src/storm/storage/expressions/SimplificationVisitor.cpp new file mode 100644 index 000000000..0d4c0a1b0 --- /dev/null +++ b/src/storm/storage/expressions/SimplificationVisitor.cpp @@ -0,0 +1,171 @@ +#include +#include +#include + +#include "storm/storage/expressions/SimplificationVisitor.h" +#include "storm/storage/expressions/Expressions.h" +#include "storm/storage/expressions/PredicateExpression.h" +#include "storm/storage/expressions/ExpressionManager.h" + +namespace storm { + namespace expressions { + SimplificationVisitor::SimplificationVisitor() { + // Intentionally left empty. + } + + Expression SimplificationVisitor::substitute(Expression const &expression) { + return Expression(boost::any_cast>( + expression.getBaseExpression().accept(*this, boost::none))); + } + + boost::any SimplificationVisitor::visit(IfThenElseExpression const &expression, boost::any const &data) { + std::shared_ptr conditionExpression = boost::any_cast>( + expression.getCondition()->accept(*this, data)); + std::shared_ptr thenExpression = boost::any_cast>( + expression.getThenExpression()->accept(*this, data)); + std::shared_ptr elseExpression = boost::any_cast>( + expression.getElseExpression()->accept(*this, data)); + + // If the arguments did not change, we simply push the expression itself. + if (conditionExpression.get() == expression.getCondition().get() && + thenExpression.get() == expression.getThenExpression().get() && + elseExpression.get() == expression.getElseExpression().get()) { + return expression.getSharedPointer(); + } else { + return std::const_pointer_cast(std::shared_ptr( + new IfThenElseExpression(expression.getManager(), expression.getType(), conditionExpression, + thenExpression, elseExpression))); + } + } + + boost::any + SimplificationVisitor::visit(BinaryBooleanFunctionExpression const &expression, boost::any const &data) { + std::shared_ptr firstExpression = boost::any_cast>( + expression.getFirstOperand()->accept(*this, data)); + std::shared_ptr secondExpression = boost::any_cast>( + expression.getSecondOperand()->accept(*this, data)); + + // If the arguments did not change, we simply push the expression itself. + if (firstExpression.get() == expression.getFirstOperand().get() && + secondExpression.get() == expression.getSecondOperand().get()) { + return expression.getSharedPointer(); + } else { + return std::const_pointer_cast(std::shared_ptr( + new BinaryBooleanFunctionExpression(expression.getManager(), expression.getType(), + firstExpression, secondExpression, + expression.getOperatorType()))); + } + } + + boost::any + SimplificationVisitor::visit(BinaryNumericalFunctionExpression const &expression, boost::any const &data) { + std::shared_ptr firstExpression = boost::any_cast>( + expression.getFirstOperand()->accept(*this, data)); + std::shared_ptr secondExpression = boost::any_cast>( + expression.getSecondOperand()->accept(*this, data)); + + // If the arguments did not change, we simply push the expression itself. + if (firstExpression.get() == expression.getFirstOperand().get() && + secondExpression.get() == expression.getSecondOperand().get()) { + return expression.getSharedPointer(); + } else { + return std::const_pointer_cast(std::shared_ptr( + new BinaryNumericalFunctionExpression(expression.getManager(), expression.getType(), + firstExpression, secondExpression, + expression.getOperatorType()))); + } + } + + boost::any SimplificationVisitor::visit(BinaryRelationExpression const &expression, boost::any const &data) { + std::shared_ptr firstExpression = boost::any_cast>( + expression.getFirstOperand()->accept(*this, data)); + std::shared_ptr secondExpression = boost::any_cast>( + expression.getSecondOperand()->accept(*this, data)); + + // If the arguments did not change, we simply push the expression itself. + if (firstExpression.get() == expression.getFirstOperand().get() && + secondExpression.get() == expression.getSecondOperand().get()) { + return expression.getSharedPointer(); + } else { + return std::const_pointer_cast(std::shared_ptr( + new BinaryRelationExpression(expression.getManager(), expression.getType(), firstExpression, + secondExpression, expression.getRelationType()))); + } + } + + boost::any SimplificationVisitor::visit(VariableExpression const &expression, boost::any const &) { + + return expression.getSharedPointer(); + + } + + boost::any + SimplificationVisitor::visit(UnaryBooleanFunctionExpression const &expression, boost::any const &data) { + std::shared_ptr operandExpression = boost::any_cast>( + expression.getOperand()->accept(*this, data)); + + // If the argument did not change, we simply push the expression itself. + if (operandExpression.get() == expression.getOperand().get()) { + return expression.getSharedPointer(); + } else { + return std::const_pointer_cast(std::shared_ptr( + new UnaryBooleanFunctionExpression(expression.getManager(), expression.getType(), + operandExpression, expression.getOperatorType()))); + } + } + + boost::any + SimplificationVisitor::visit(UnaryNumericalFunctionExpression const &expression, boost::any const &data) { + std::shared_ptr operandExpression = boost::any_cast>( + expression.getOperand()->accept(*this, data)); + + // If the argument did not change, we simply push the expression itself. + if (operandExpression.get() == expression.getOperand().get()) { + return expression.getSharedPointer(); + } else { + return std::const_pointer_cast(std::shared_ptr( + new UnaryNumericalFunctionExpression(expression.getManager(), expression.getType(), + operandExpression, expression.getOperatorType()))); + } + } + + boost::any SimplificationVisitor::visit(PredicateExpression const &expression, boost::any const &data) { + std::vector newExpressions; + for (uint64_t i = 0; i < expression.getArity(); ++i) { + newExpressions.emplace_back(boost::any_cast>( + expression.getOperand(i)->accept(*this, data))); + } + std::vector newSumExpressions; + for (auto const &expr : newExpressions) { + newSumExpressions.push_back( + ite(expr, expression.getManager().integer(1), expression.getManager().integer(0))); + } + + storm::expressions::Expression finalexpr; + if (expression.getPredicateType() == PredicateExpression::PredicateType::AtLeastOneOf) { + finalexpr = storm::expressions::sum(newSumExpressions) > expression.getManager().integer(0); + } else if (expression.getPredicateType() == PredicateExpression::PredicateType::AtMostOneOf) { + finalexpr = storm::expressions::sum(newSumExpressions) <= expression.getManager().integer(1); + } else if (expression.getPredicateType() == PredicateExpression::PredicateType::ExactlyOneOf) { + finalexpr = storm::expressions::sum(newSumExpressions) == expression.getManager().integer(1); + } else { + STORM_LOG_ASSERT(false, "Unknown predicate type."); + } + return std::const_pointer_cast(finalexpr.getBaseExpressionPointer()); + } + + + boost::any SimplificationVisitor::visit(BooleanLiteralExpression const &expression, boost::any const &) { + return expression.getSharedPointer(); + } + + boost::any SimplificationVisitor::visit(IntegerLiteralExpression const &expression, boost::any const &) { + return expression.getSharedPointer(); + } + + boost::any SimplificationVisitor::visit(RationalLiteralExpression const &expression, boost::any const &) { + return expression.getSharedPointer(); + } + + } +} \ No newline at end of file diff --git a/src/storm/storage/expressions/SimplificationVisitor.h b/src/storm/storage/expressions/SimplificationVisitor.h new file mode 100644 index 000000000..b1c80314e --- /dev/null +++ b/src/storm/storage/expressions/SimplificationVisitor.h @@ -0,0 +1,41 @@ +#pragma once +#include + +#include "storm/storage/expressions/Expression.h" +#include "storm/storage/expressions/ExpressionVisitor.h" + +namespace storm { + namespace expressions { + class SimplificationVisitor : public ExpressionVisitor { + public: + /*! + * Creates a new simplification visitor that replaces predicates by other (simpler?) predicates. + * + * Configuration: + * Currently, the visitor only replaces nonstandard predicates + * + */ + SimplificationVisitor(); + + /*! + * Simplifies based on the configuration. + */ + Expression substitute(Expression const& expression); + + virtual boost::any visit(IfThenElseExpression const& expression, boost::any const& data) override; + virtual boost::any visit(BinaryBooleanFunctionExpression const& expression, boost::any const& data) override; + virtual boost::any visit(BinaryNumericalFunctionExpression const& expression, boost::any const& data) override; + virtual boost::any visit(BinaryRelationExpression const& expression, boost::any const& data) override; + virtual boost::any visit(VariableExpression const& expression, boost::any const& data) override; + virtual boost::any visit(UnaryBooleanFunctionExpression const& expression, boost::any const& data) override; + virtual boost::any visit(UnaryNumericalFunctionExpression const& expression, boost::any const& data) override; + virtual boost::any visit(BooleanLiteralExpression const& expression, boost::any const& data) override; + virtual boost::any visit(IntegerLiteralExpression const& expression, boost::any const& data) override; + virtual boost::any visit(RationalLiteralExpression const& expression, boost::any const& data) override; + virtual boost::any visit(PredicateExpression const& expression, boost::any const& data) override; + + protected: + // + }; + } +} diff --git a/src/storm/storage/expressions/ToDiceStringVisitor.cpp b/src/storm/storage/expressions/ToDiceStringVisitor.cpp index a1a20ed24..2a9467cc3 100644 --- a/src/storm/storage/expressions/ToDiceStringVisitor.cpp +++ b/src/storm/storage/expressions/ToDiceStringVisitor.cpp @@ -241,6 +241,38 @@ namespace storm { return boost::any(); } + boost::any ToDiceStringVisitor::visit(PredicateExpression const& expression, boost::any const& data) { + auto pdt = expression.getPredicateType(); + STORM_LOG_ASSERT(pdt == PredicateExpression::PredicateType::ExactlyOneOf || pdt == PredicateExpression::PredicateType::AtLeastOneOf || pdt == PredicateExpression::PredicateType::AtMostOneOf, "Only some predicate types are supported."); + stream << "("; + if (expression.getPredicateType() == PredicateExpression::PredicateType::ExactlyOneOf || expression.getPredicateType() == PredicateExpression::PredicateType::AtMostOneOf) { + stream << "(true "; + for (uint64_t operandi = 0; operandi < expression.getArity(); ++operandi) { + for (uint64_t operandj = operandi + 1; operandj < expression.getArity(); ++operandj) { + stream << "&& !("; + expression.getOperand(operandi)->accept(*this, data); + stream << " && "; + expression.getOperand(operandj)->accept(*this, data); + stream << ")"; + } + } + stream << ")"; + } + if (expression.getPredicateType() == PredicateExpression::PredicateType::ExactlyOneOf) { + stream << " && "; + } + if (expression.getPredicateType() == PredicateExpression::PredicateType::ExactlyOneOf || expression.getPredicateType() == PredicateExpression::PredicateType::AtLeastOneOf) { + stream << "( false"; + for (uint64_t operandj = 0; operandj < expression.getArity(); ++operandj) { + stream << "|| "; + expression.getOperand(operandj)->accept(*this, data); + } + stream << ")"; + } + stream << ")"; + return boost::any(); + } + boost::any ToDiceStringVisitor::visit(BooleanLiteralExpression const& expression, boost::any const&) { stream << (expression.getValue() ? " true " : " false "); return boost::any(); diff --git a/src/storm/storage/expressions/ToDiceStringVisitor.h b/src/storm/storage/expressions/ToDiceStringVisitor.h index 40ca07ea4..00a91d7f9 100644 --- a/src/storm/storage/expressions/ToDiceStringVisitor.h +++ b/src/storm/storage/expressions/ToDiceStringVisitor.h @@ -22,10 +22,12 @@ namespace storm { virtual boost::any visit(VariableExpression const& expression, boost::any const& data) override; virtual boost::any visit(UnaryBooleanFunctionExpression const& expression, boost::any const& data) override; virtual boost::any visit(UnaryNumericalFunctionExpression const& expression, boost::any const& data) override; + virtual boost::any visit(PredicateExpression const& expression, boost::any const& data) override; virtual boost::any visit(BooleanLiteralExpression const& expression, boost::any const& data) override; virtual boost::any visit(IntegerLiteralExpression const& expression, boost::any const& data) override; virtual boost::any visit(RationalLiteralExpression const& expression, boost::any const& data) override; + private: std::stringstream stream; uint64_t nrBits; diff --git a/src/storm/storage/prism/Assignment.cpp b/src/storm/storage/prism/Assignment.cpp index e0ccf7a92..a760ca0a6 100644 --- a/src/storm/storage/prism/Assignment.cpp +++ b/src/storm/storage/prism/Assignment.cpp @@ -21,6 +21,10 @@ namespace storm { Assignment Assignment::substitute(std::map const& substitution) const { return Assignment(this->getVariable(), this->getExpression().substitute(substitution).simplify(), this->getFilename(), this->getLineNumber()); } + + Assignment Assignment::substituteNonStandardPredicates() const { + return Assignment(this->getVariable(), this->getExpression().substituteNonStandardPredicates().simplify(), this->getFilename(), this->getLineNumber()); + } bool Assignment::isIdentity() const { if(this->expression.isVariable()) { diff --git a/src/storm/storage/prism/Assignment.h b/src/storm/storage/prism/Assignment.h index 43777a9d6..646dddc53 100644 --- a/src/storm/storage/prism/Assignment.h +++ b/src/storm/storage/prism/Assignment.h @@ -57,7 +57,9 @@ namespace storm { * @return The resulting assignment. */ Assignment substitute(std::map const& substitution) const; - + + Assignment substituteNonStandardPredicates() const; + /*! * Checks whether the assignment is an identity (lhs equals rhs) * diff --git a/src/storm/storage/prism/BooleanVariable.cpp b/src/storm/storage/prism/BooleanVariable.cpp index c3b3d3431..141965e43 100644 --- a/src/storm/storage/prism/BooleanVariable.cpp +++ b/src/storm/storage/prism/BooleanVariable.cpp @@ -11,6 +11,10 @@ namespace storm { BooleanVariable BooleanVariable::substitute(std::map const& substitution) const { return BooleanVariable(this->getExpressionVariable(), this->getInitialValueExpression().isInitialized() ? this->getInitialValueExpression().substitute(substitution) : this->getInitialValueExpression(), this->isObservable(), this->getFilename(), this->getLineNumber()); } + + BooleanVariable BooleanVariable::substituteNonStandardPredicates() const { + return BooleanVariable(this->getExpressionVariable(), this->getInitialValueExpression().isInitialized() ? this->getInitialValueExpression().substituteNonStandardPredicates() : this->getInitialValueExpression(), this->isObservable(), this->getFilename(), this->getLineNumber()); + } void BooleanVariable::createMissingInitialValue() { if (!this->hasInitialValue()) { diff --git a/src/storm/storage/prism/BooleanVariable.h b/src/storm/storage/prism/BooleanVariable.h index acdc9865f..252e4673c 100644 --- a/src/storm/storage/prism/BooleanVariable.h +++ b/src/storm/storage/prism/BooleanVariable.h @@ -34,7 +34,8 @@ namespace storm { * @return The resulting boolean variable. */ BooleanVariable substitute(std::map const& substitution) const; - + BooleanVariable substituteNonStandardPredicates() const; + virtual void createMissingInitialValue() override; friend std::ostream& operator<<(std::ostream& stream, BooleanVariable const& variable); diff --git a/src/storm/storage/prism/Command.cpp b/src/storm/storage/prism/Command.cpp index 4b357c73a..045eca424 100644 --- a/src/storm/storage/prism/Command.cpp +++ b/src/storm/storage/prism/Command.cpp @@ -56,7 +56,17 @@ namespace storm { return Command(this->getGlobalIndex(), this->isMarkovian(), this->getActionIndex(), this->getActionName(), this->getGuardExpression().substitute(substitution).simplify(), newUpdates, this->getFilename(), this->getLineNumber()); } - + + Command Command::substituteNonStandardPredicates() const { + std::vector newUpdates; + newUpdates.reserve(this->getNumberOfUpdates()); + for (auto const& update : this->getUpdates()) { + newUpdates.emplace_back(update.substituteNonStandardPredicates()); + } + + return Command(this->getGlobalIndex(), this->isMarkovian(), this->getActionIndex(), this->getActionName(), this->getGuardExpression().substituteNonStandardPredicates().simplify(), newUpdates, this->getFilename(), this->getLineNumber()); + } + bool Command::isLabeled() const { return labeled; } diff --git a/src/storm/storage/prism/Command.h b/src/storm/storage/prism/Command.h index 8213349ce..84b95474c 100644 --- a/src/storm/storage/prism/Command.h +++ b/src/storm/storage/prism/Command.h @@ -114,7 +114,8 @@ namespace storm { * @return The resulting command. */ Command substitute(std::map const& substitution) const; - + + Command substituteNonStandardPredicates() const; /*! * Retrieves whether the command possesses a synchronization label. * diff --git a/src/storm/storage/prism/Formula.cpp b/src/storm/storage/prism/Formula.cpp index a628c63b3..c3c36e01d 100644 --- a/src/storm/storage/prism/Formula.cpp +++ b/src/storm/storage/prism/Formula.cpp @@ -44,6 +44,15 @@ namespace storm { return Formula(this->getName(), this->getExpression().substitute(substitution), this->getFilename(), this->getLineNumber()); } } + + Formula Formula::substituteNonStandardPredicates() const { + assert(this->getExpression().isInitialized()); + if (hasExpressionVariable()) { + return Formula(this->getExpressionVariable(), this->getExpression().substituteNonStandardPredicates(), this->getFilename(), this->getLineNumber()); + } else { + return Formula(this->getName(), this->getExpression().substituteNonStandardPredicates(), this->getFilename(), this->getLineNumber()); + } + } std::ostream& operator<<(std::ostream& stream, Formula const& formula) { stream << "formula " << formula.getName() << " = " << formula.getExpression() << ";"; diff --git a/src/storm/storage/prism/Formula.h b/src/storm/storage/prism/Formula.h index 379e2cdd0..b7678f571 100644 --- a/src/storm/storage/prism/Formula.h +++ b/src/storm/storage/prism/Formula.h @@ -92,6 +92,7 @@ namespace storm { * @return The resulting formula. */ Formula substitute(std::map const& substitution) const; + Formula substituteNonStandardPredicates() const; friend std::ostream& operator<<(std::ostream& stream, Formula const& formula); diff --git a/src/storm/storage/prism/IntegerVariable.cpp b/src/storm/storage/prism/IntegerVariable.cpp index d53ea618b..6cd4f5759 100644 --- a/src/storm/storage/prism/IntegerVariable.cpp +++ b/src/storm/storage/prism/IntegerVariable.cpp @@ -21,6 +21,10 @@ namespace storm { IntegerVariable IntegerVariable::substitute(std::map const& substitution) const { return IntegerVariable(this->getExpressionVariable(), this->getLowerBoundExpression().substitute(substitution), this->getUpperBoundExpression().substitute(substitution), this->getInitialValueExpression().isInitialized() ? this->getInitialValueExpression().substitute(substitution) : this->getInitialValueExpression(), this->isObservable(), this->getFilename(), this->getLineNumber()); } + + IntegerVariable IntegerVariable::substituteNonStandardPredicates() const { + return IntegerVariable(this->getExpressionVariable(), this->getLowerBoundExpression().substituteNonStandardPredicates(), this->getUpperBoundExpression().substituteNonStandardPredicates(), this->getInitialValueExpression().isInitialized() ? this->getInitialValueExpression().substituteNonStandardPredicates() : this->getInitialValueExpression(), this->isObservable(), this->getFilename(), this->getLineNumber()); + } void IntegerVariable::createMissingInitialValue() { if (!this->hasInitialValue()) { diff --git a/src/storm/storage/prism/IntegerVariable.h b/src/storm/storage/prism/IntegerVariable.h index 67618f6eb..3069ff45a 100644 --- a/src/storm/storage/prism/IntegerVariable.h +++ b/src/storm/storage/prism/IntegerVariable.h @@ -57,7 +57,9 @@ namespace storm { * @return The resulting boolean variable. */ IntegerVariable substitute(std::map const& substitution) const; - + + IntegerVariable substituteNonStandardPredicates() const; + virtual void createMissingInitialValue() override; friend std::ostream& operator<<(std::ostream& stream, IntegerVariable const& variable); diff --git a/src/storm/storage/prism/Label.cpp b/src/storm/storage/prism/Label.cpp index 4f07baeb9..4cd8f1a59 100644 --- a/src/storm/storage/prism/Label.cpp +++ b/src/storm/storage/prism/Label.cpp @@ -18,6 +18,10 @@ namespace storm { Label Label::substitute(std::map const& substitution) const { return Label(this->getName(), this->getStatePredicateExpression().substitute(substitution), this->getFilename(), this->getLineNumber()); } + + Label Label::substituteNonStandardPredicates() const { + return Label(this->getName(), this->getStatePredicateExpression().substituteNonStandardPredicates(), this->getFilename(), this->getLineNumber()); + } std::ostream& operator<<(std::ostream& stream, Label const& label) { stream << "label \"" << label.getName() << "\" = " << label.getStatePredicateExpression() << ";"; @@ -31,5 +35,10 @@ namespace storm { ObservationLabel ObservationLabel::substitute(std::map const& substitution) const { return ObservationLabel(this->getName(), this->getStatePredicateExpression().substitute(substitution), this->getFilename(), this->getLineNumber()); } + + ObservationLabel ObservationLabel::substituteNonStandardPredicates() const { + return ObservationLabel(this->getName(), this->getStatePredicateExpression().substituteNonStandardPredicates(), this->getFilename(), this->getLineNumber()); + } + } // namespace prism } // namespace storm diff --git a/src/storm/storage/prism/Label.h b/src/storm/storage/prism/Label.h index 298e5b721..aa22653a1 100644 --- a/src/storm/storage/prism/Label.h +++ b/src/storm/storage/prism/Label.h @@ -58,6 +58,7 @@ namespace storm { * @return The resulting label. */ Label substitute(std::map const& substitution) const; + Label substituteNonStandardPredicates() const; friend std::ostream& operator<<(std::ostream& stream, Label const& label); @@ -96,7 +97,7 @@ namespace storm { * @return The resulting label. */ ObservationLabel substitute(std::map const& substitution) const; - + ObservationLabel substituteNonStandardPredicates() const; }; diff --git a/src/storm/storage/prism/Module.cpp b/src/storm/storage/prism/Module.cpp index 512985ded..3c9c8f698 100644 --- a/src/storm/storage/prism/Module.cpp +++ b/src/storm/storage/prism/Module.cpp @@ -235,6 +235,28 @@ namespace storm { return Module(this->getName(), newBooleanVariables, newIntegerVariables, this->getClockVariables(), this->getInvariant(), newCommands, this->getFilename(), this->getLineNumber()); } + + Module Module::substituteNonStandardPredicates() const { + std::vector newBooleanVariables; + newBooleanVariables.reserve(this->getNumberOfBooleanVariables()); + for (auto const& booleanVariable : this->getBooleanVariables()) { + newBooleanVariables.emplace_back(booleanVariable.substituteNonStandardPredicates()); + } + + std::vector newIntegerVariables; + newBooleanVariables.reserve(this->getNumberOfIntegerVariables()); + for (auto const& integerVariable : this->getIntegerVariables()) { + newIntegerVariables.emplace_back(integerVariable.substituteNonStandardPredicates()); + } + + std::vector newCommands; + newCommands.reserve(this->getNumberOfCommands()); + for (auto const& command : this->getCommands()) { + newCommands.emplace_back(command.substituteNonStandardPredicates()); + } + + return Module(this->getName(), newBooleanVariables, newIntegerVariables, this->getClockVariables(), this->getInvariant(), newCommands, this->getFilename(), this->getLineNumber()); + } bool Module::containsVariablesOnlyInUpdateProbabilities(std::set const& undefinedConstantVariables) const { for (auto const& booleanVariable : this->getBooleanVariables()) { diff --git a/src/storm/storage/prism/Module.h b/src/storm/storage/prism/Module.h index 138dfbe09..adeb94e01 100644 --- a/src/storm/storage/prism/Module.h +++ b/src/storm/storage/prism/Module.h @@ -250,7 +250,9 @@ namespace storm { * @return The resulting module. */ Module substitute(std::map const& substitution) const; - + + Module substituteNonStandardPredicates() const; + /*! * Checks whether the given variables only appear in the update probabilities of the module and nowhere else. * diff --git a/src/storm/storage/prism/Program.cpp b/src/storm/storage/prism/Program.cpp index cdd41f7c3..f17674472 100644 --- a/src/storm/storage/prism/Program.cpp +++ b/src/storm/storage/prism/Program.cpp @@ -885,7 +885,51 @@ namespace storm { Program Program::substituteFormulas() const { return substituteConstantsFormulas(false, true); } - + + Program Program::substituteNonStandardPredicates() const { + // TODO support in constants, initial construct, and rewards + + std::vector newFormulas; + newFormulas.reserve(this->getNumberOfFormulas()); + for (auto const& oldFormula : this->getFormulas()) { + newFormulas.emplace_back(oldFormula.substituteNonStandardPredicates()); + } + + std::vector newBooleanVariables; + newBooleanVariables.reserve(this->getNumberOfGlobalBooleanVariables()); + for (auto const& booleanVariable : this->getGlobalBooleanVariables()) { + newBooleanVariables.emplace_back(booleanVariable.substituteNonStandardPredicates()); + } + + std::vector newIntegerVariables; + newBooleanVariables.reserve(this->getNumberOfGlobalIntegerVariables()); + for (auto const& integerVariable : this->getGlobalIntegerVariables()) { + newIntegerVariables.emplace_back(integerVariable.substituteNonStandardPredicates()); + } + + std::vector newModules; + newModules.reserve(this->getNumberOfModules()); + for (auto const& module : this->getModules()) { + newModules.emplace_back(module.substituteNonStandardPredicates()); + } + + + std::vector