diff --git a/src/logic/AtomicExpressionFormula.cpp b/src/logic/AtomicExpressionFormula.cpp index 5f8f0f509..b782fd368 100644 --- a/src/logic/AtomicExpressionFormula.cpp +++ b/src/logic/AtomicExpressionFormula.cpp @@ -24,10 +24,6 @@ namespace storm { atomicExpressionFormulas.push_back(std::dynamic_pointer_cast(this->shared_from_this())); } - std::shared_ptr AtomicExpressionFormula::substitute(std::map const& substitution) const { - return std::make_shared(this->expression.substitute(substitution)); - } - std::ostream& AtomicExpressionFormula::writeToStream(std::ostream& out) const { out << expression; return out; diff --git a/src/logic/AtomicExpressionFormula.h b/src/logic/AtomicExpressionFormula.h index 12636475d..2299b5eac 100644 --- a/src/logic/AtomicExpressionFormula.h +++ b/src/logic/AtomicExpressionFormula.h @@ -23,8 +23,6 @@ namespace storm { virtual void gatherAtomicExpressionFormulas(std::vector>& atomicExpressionFormulas) const override; - virtual std::shared_ptr substitute(std::map const& substitution) const override; - private: // The atomic expression represented by this node in the formula tree. storm::expressions::Expression expression; diff --git a/src/logic/AtomicLabelFormula.cpp b/src/logic/AtomicLabelFormula.cpp index 292351c5d..b0406d62c 100644 --- a/src/logic/AtomicLabelFormula.cpp +++ b/src/logic/AtomicLabelFormula.cpp @@ -25,10 +25,6 @@ namespace storm { atomicExpressionFormulas.push_back(std::dynamic_pointer_cast(this->shared_from_this())); } - std::shared_ptr AtomicLabelFormula::substitute(std::map const& substitution) const { - return std::make_shared(*this); - } - std::ostream& AtomicLabelFormula::writeToStream(std::ostream& out) const { out << "\"" << label << "\""; return out; diff --git a/src/logic/AtomicLabelFormula.h b/src/logic/AtomicLabelFormula.h index 795704305..a7a627f74 100644 --- a/src/logic/AtomicLabelFormula.h +++ b/src/logic/AtomicLabelFormula.h @@ -22,9 +22,7 @@ namespace storm { std::string const& getLabel() const; virtual void gatherAtomicLabelFormulas(std::vector>& atomicLabelFormulas) const override; - - virtual std::shared_ptr substitute(std::map const& substitution) const override; - + virtual std::ostream& writeToStream(std::ostream& out) const override; private: diff --git a/src/logic/BinaryBooleanStateFormula.cpp b/src/logic/BinaryBooleanStateFormula.cpp index d6757e508..a697b151c 100644 --- a/src/logic/BinaryBooleanStateFormula.cpp +++ b/src/logic/BinaryBooleanStateFormula.cpp @@ -30,11 +30,7 @@ namespace storm { bool BinaryBooleanStateFormula::isOr() const { return this->getOperator() == OperatorType::Or; } - - std::shared_ptr BinaryBooleanStateFormula::substitute(std::map const& substitution) const { - return std::make_shared(this->operatorType, this->getLeftSubformula().substitute(substitution), this->getRightSubformula().substitute(substitution)); - } - + std::ostream& BinaryBooleanStateFormula::writeToStream(std::ostream& out) const { out << "("; this->getLeftSubformula().writeToStream(out); diff --git a/src/logic/BinaryBooleanStateFormula.h b/src/logic/BinaryBooleanStateFormula.h index c0880d058..d94260caa 100644 --- a/src/logic/BinaryBooleanStateFormula.h +++ b/src/logic/BinaryBooleanStateFormula.h @@ -28,8 +28,6 @@ namespace storm { virtual std::ostream& writeToStream(std::ostream& out) const override; - virtual std::shared_ptr substitute(std::map const& substitution) const override; - private: OperatorType operatorType; }; diff --git a/src/logic/BooleanLiteralFormula.cpp b/src/logic/BooleanLiteralFormula.cpp index 897fbf377..2cc139cf3 100644 --- a/src/logic/BooleanLiteralFormula.cpp +++ b/src/logic/BooleanLiteralFormula.cpp @@ -24,10 +24,6 @@ namespace storm { return visitor.visit(*this, data); } - std::shared_ptr BooleanLiteralFormula::substitute(std::map const& substitution) const { - return std::make_shared(*this); - } - std::ostream& BooleanLiteralFormula::writeToStream(std::ostream& out) const { if (value) { out << "true"; diff --git a/src/logic/BooleanLiteralFormula.h b/src/logic/BooleanLiteralFormula.h index 955d1dbbe..0946852fc 100644 --- a/src/logic/BooleanLiteralFormula.h +++ b/src/logic/BooleanLiteralFormula.h @@ -19,8 +19,6 @@ namespace storm { virtual boost::any accept(FormulaVisitor const& visitor, boost::any const& data) const override; - virtual std::shared_ptr substitute(std::map const& substitution) const override; - virtual std::ostream& writeToStream(std::ostream& out) const override; private: diff --git a/src/logic/BoundedUntilFormula.cpp b/src/logic/BoundedUntilFormula.cpp index 8f3e73286..dd9f965f7 100644 --- a/src/logic/BoundedUntilFormula.cpp +++ b/src/logic/BoundedUntilFormula.cpp @@ -44,10 +44,6 @@ namespace storm { return boost::get(bounds); } - std::shared_ptr BoundedUntilFormula::substitute(std::map const& substitution) const { - return std::make_shared(this->getLeftSubformula().substitute(substitution), this->getRightSubformula().substitute(substitution), bounds); - } - std::ostream& BoundedUntilFormula::writeToStream(std::ostream& out) const { this->getLeftSubformula().writeToStream(out); diff --git a/src/logic/BoundedUntilFormula.h b/src/logic/BoundedUntilFormula.h index 0e18633ce..13ee7cefc 100644 --- a/src/logic/BoundedUntilFormula.h +++ b/src/logic/BoundedUntilFormula.h @@ -26,8 +26,6 @@ namespace storm { virtual std::ostream& writeToStream(std::ostream& out) const override; - virtual std::shared_ptr substitute(std::map const& substitution) const override; - private: boost::variant> bounds; }; diff --git a/src/logic/ConditionalFormula.cpp b/src/logic/ConditionalFormula.cpp index 55bddabd9..fd1b1b442 100644 --- a/src/logic/ConditionalFormula.cpp +++ b/src/logic/ConditionalFormula.cpp @@ -33,11 +33,7 @@ namespace storm { boost::any ConditionalFormula::accept(FormulaVisitor const& visitor, boost::any const& data) const { return visitor.visit(*this, data); } - - std::shared_ptr ConditionalFormula::substitute(std::map const& substitution) const { - return std::make_shared(this->getSubformula().substitute(substitution), this->getConditionFormula().substitute(substitution), context); - } - + void ConditionalFormula::gatherAtomicExpressionFormulas(std::vector>& atomicExpressionFormulas) const { this->getSubformula().gatherAtomicExpressionFormulas(atomicExpressionFormulas); this->getConditionFormula().gatherAtomicExpressionFormulas(atomicExpressionFormulas); diff --git a/src/logic/ConditionalFormula.h b/src/logic/ConditionalFormula.h index 87303198c..6bb4317b6 100644 --- a/src/logic/ConditionalFormula.h +++ b/src/logic/ConditionalFormula.h @@ -25,8 +25,6 @@ namespace storm { virtual std::ostream& writeToStream(std::ostream& out) const override; - virtual std::shared_ptr substitute(std::map const& substitution) const override; - virtual void gatherAtomicExpressionFormulas(std::vector>& atomicExpressionFormulas) const override; virtual void gatherAtomicLabelFormulas(std::vector>& atomicLabelFormulas) const override; virtual void gatherReferencedRewardModels(std::set& referencedRewardModels) const override; diff --git a/src/logic/CumulativeRewardFormula.cpp b/src/logic/CumulativeRewardFormula.cpp index d8318164d..f32d476e5 100644 --- a/src/logic/CumulativeRewardFormula.cpp +++ b/src/logic/CumulativeRewardFormula.cpp @@ -44,10 +44,6 @@ namespace storm { } } - std::shared_ptr CumulativeRewardFormula::substitute(std::map const& substitution) const { - return std::make_shared(*this); - } - std::ostream& CumulativeRewardFormula::writeToStream(std::ostream& out) const { if (this->hasDiscreteTimeBound()) { out << "C<=" << this->getDiscreteTimeBound(); diff --git a/src/logic/CumulativeRewardFormula.h b/src/logic/CumulativeRewardFormula.h index 305c89895..422bc09ff 100644 --- a/src/logic/CumulativeRewardFormula.h +++ b/src/logic/CumulativeRewardFormula.h @@ -32,8 +32,6 @@ namespace storm { double getContinuousTimeBound() const; - virtual std::shared_ptr substitute(std::map const& substitution) const override; - private: boost::variant timeBound; }; diff --git a/src/logic/EventuallyFormula.cpp b/src/logic/EventuallyFormula.cpp index 926c68af2..e83609e1b 100644 --- a/src/logic/EventuallyFormula.cpp +++ b/src/logic/EventuallyFormula.cpp @@ -45,11 +45,7 @@ namespace storm { boost::any EventuallyFormula::accept(FormulaVisitor const& visitor, boost::any const& data) const { return visitor.visit(*this, data); } - - std::shared_ptr EventuallyFormula::substitute(std::map const& substitution) const { - return std::make_shared(this->getSubformula().substitute(substitution), context); - } - + std::ostream& EventuallyFormula::writeToStream(std::ostream& out) const { out << "F "; this->getSubformula().writeToStream(out); diff --git a/src/logic/EventuallyFormula.h b/src/logic/EventuallyFormula.h index b21c92170..779b36f81 100644 --- a/src/logic/EventuallyFormula.h +++ b/src/logic/EventuallyFormula.h @@ -28,8 +28,6 @@ namespace storm { virtual std::ostream& writeToStream(std::ostream& out) const override; - virtual std::shared_ptr substitute(std::map const& substitution) const override; - private: FormulaContext context; }; diff --git a/src/logic/Formula.cpp b/src/logic/Formula.cpp index cc8a570be..492230a33 100644 --- a/src/logic/Formula.cpp +++ b/src/logic/Formula.cpp @@ -3,6 +3,7 @@ #include "src/logic/FragmentChecker.h" #include "src/logic/FormulaInformationVisitor.h" +#include "src/logic/VariableSubstitutionVisitor.h" #include "src/logic/LabelSubstitutionVisitor.h" #include "src/logic/ToExpressionVisitor.h" @@ -408,6 +409,11 @@ namespace storm { return referencedRewardModels; } + std::shared_ptr Formula::substitute(std::map const& substitution) const { + VariableSubstitutionVisitor visitor(substitution); + return visitor.substitute(*this); + } + std::shared_ptr Formula::substitute(std::map const& labelSubstitution) const { LabelSubstitutionVisitor visitor(labelSubstitution); return visitor.substitute(*this); diff --git a/src/logic/Formula.h b/src/logic/Formula.h index 3ac205abe..11865b950 100644 --- a/src/logic/Formula.h +++ b/src/logic/Formula.h @@ -186,8 +186,8 @@ namespace storm { std::shared_ptr asSharedPointer(); std::shared_ptr asSharedPointer() const; - virtual std::shared_ptr substitute(std::map const& substitution) const = 0; - virtual std::shared_ptr substitute(std::map const& labelSubstitution) const; + std::shared_ptr substitute(std::map const& substitution) const; + std::shared_ptr substitute(std::map const& labelSubstitution) const; /*! * Takes the formula and converts it to an equivalent expression assuming that only atomic expression formulas diff --git a/src/logic/GloballyFormula.cpp b/src/logic/GloballyFormula.cpp index df33f0b52..ae6f5b508 100644 --- a/src/logic/GloballyFormula.cpp +++ b/src/logic/GloballyFormula.cpp @@ -19,10 +19,6 @@ namespace storm { boost::any GloballyFormula::accept(FormulaVisitor const& visitor, boost::any const& data) const { return visitor.visit(*this, data); } - - std::shared_ptr GloballyFormula::substitute(std::map const& substitution) const { - return std::make_shared(this->getSubformula().substitute(substitution)); - } std::ostream& GloballyFormula::writeToStream(std::ostream& out) const { out << "G "; diff --git a/src/logic/GloballyFormula.h b/src/logic/GloballyFormula.h index 011c61de9..e17347e11 100644 --- a/src/logic/GloballyFormula.h +++ b/src/logic/GloballyFormula.h @@ -17,8 +17,6 @@ namespace storm { virtual bool isProbabilityPathFormula() const override; virtual boost::any accept(FormulaVisitor const& visitor, boost::any const& data) const override; - - virtual std::shared_ptr substitute(std::map const& substitution) const override; virtual std::ostream& writeToStream(std::ostream& out) const override; }; diff --git a/src/logic/InstantaneousRewardFormula.cpp b/src/logic/InstantaneousRewardFormula.cpp index 4f1f65429..bf39cd763 100644 --- a/src/logic/InstantaneousRewardFormula.cpp +++ b/src/logic/InstantaneousRewardFormula.cpp @@ -43,11 +43,7 @@ namespace storm { return boost::get(timeBound); } } - - std::shared_ptr InstantaneousRewardFormula::substitute(std::map const& substitution) const { - return std::make_shared(*this); - } - + std::ostream& InstantaneousRewardFormula::writeToStream(std::ostream& out) const { if (this->hasDiscreteTimeBound()) { out << "I=" << this->getDiscreteTimeBound(); diff --git a/src/logic/InstantaneousRewardFormula.h b/src/logic/InstantaneousRewardFormula.h index 069bf21bd..85d27c450 100644 --- a/src/logic/InstantaneousRewardFormula.h +++ b/src/logic/InstantaneousRewardFormula.h @@ -32,9 +32,7 @@ namespace storm { bool hasContinuousTimeBound() const; double getContinuousTimeBound() const; - - virtual std::shared_ptr substitute(std::map const& substitution) const override; - + private: boost::variant timeBound; }; diff --git a/src/logic/LabelSubstitutionVisitor.cpp b/src/logic/LabelSubstitutionVisitor.cpp index ac05c09eb..ffb28dbed 100644 --- a/src/logic/LabelSubstitutionVisitor.cpp +++ b/src/logic/LabelSubstitutionVisitor.cpp @@ -21,7 +21,6 @@ namespace storm { } else { return std::static_pointer_cast(std::make_shared(f)); } - } - + } } } diff --git a/src/logic/LongRunAverageOperatorFormula.cpp b/src/logic/LongRunAverageOperatorFormula.cpp index f2e253ead..bf6b250ea 100644 --- a/src/logic/LongRunAverageOperatorFormula.cpp +++ b/src/logic/LongRunAverageOperatorFormula.cpp @@ -18,11 +18,7 @@ namespace storm { boost::any LongRunAverageOperatorFormula::accept(FormulaVisitor const& visitor, boost::any const& data) const { return visitor.visit(*this, data); } - - std::shared_ptr LongRunAverageOperatorFormula::substitute(std::map const& substitution) const { - return std::make_shared(this->getSubformula().substitute(substitution), this->operatorInformation); - } - + std::ostream& LongRunAverageOperatorFormula::writeToStream(std::ostream& out) const { out << "LRA"; OperatorFormula::writeToStream(out); diff --git a/src/logic/LongRunAverageOperatorFormula.h b/src/logic/LongRunAverageOperatorFormula.h index 394188b67..3590ac5b2 100644 --- a/src/logic/LongRunAverageOperatorFormula.h +++ b/src/logic/LongRunAverageOperatorFormula.h @@ -17,8 +17,6 @@ namespace storm { virtual boost::any accept(FormulaVisitor const& visitor, boost::any const& data) const override; - virtual std::shared_ptr substitute(std::map const& substitution) const override; - virtual std::ostream& writeToStream(std::ostream& out) const override; }; } diff --git a/src/logic/LongRunAverageRewardFormula.cpp b/src/logic/LongRunAverageRewardFormula.cpp index 30bfd7ba9..3207b9bf2 100644 --- a/src/logic/LongRunAverageRewardFormula.cpp +++ b/src/logic/LongRunAverageRewardFormula.cpp @@ -20,10 +20,6 @@ namespace storm { return visitor.visit(*this, data); } - std::shared_ptr LongRunAverageRewardFormula::substitute(std::map const& substitution) const { - return std::shared_ptr(new LongRunAverageRewardFormula()); - } - std::ostream& LongRunAverageRewardFormula::writeToStream(std::ostream& out) const { return out << "LRA"; } diff --git a/src/logic/LongRunAverageRewardFormula.h b/src/logic/LongRunAverageRewardFormula.h index 3dfea465e..d65e55314 100644 --- a/src/logic/LongRunAverageRewardFormula.h +++ b/src/logic/LongRunAverageRewardFormula.h @@ -17,8 +17,6 @@ namespace storm { virtual bool isRewardPathFormula() const override; virtual boost::any accept(FormulaVisitor const& visitor, boost::any const& data) const override; - - virtual std::shared_ptr substitute(std::map const& substitution) const override; virtual std::ostream& writeToStream(std::ostream& out) const override; diff --git a/src/logic/NextFormula.cpp b/src/logic/NextFormula.cpp index e73171a95..a0916e14f 100644 --- a/src/logic/NextFormula.cpp +++ b/src/logic/NextFormula.cpp @@ -19,11 +19,7 @@ namespace storm { boost::any NextFormula::accept(FormulaVisitor const& visitor, boost::any const& data) const { return visitor.visit(*this, data); } - - std::shared_ptr NextFormula::substitute(std::map const& substitution) const { - return std::make_shared(this->getSubformula().substitute(substitution)); - } - + std::ostream& NextFormula::writeToStream(std::ostream& out) const { out << "X "; this->getSubformula().writeToStream(out); diff --git a/src/logic/NextFormula.h b/src/logic/NextFormula.h index bade60456..4d895a48e 100644 --- a/src/logic/NextFormula.h +++ b/src/logic/NextFormula.h @@ -17,9 +17,7 @@ namespace storm { virtual bool isProbabilityPathFormula() const override; virtual boost::any accept(FormulaVisitor const& visitor, boost::any const& data) const override; - - virtual std::shared_ptr substitute(std::map const& substitution) const override; - + virtual std::ostream& writeToStream(std::ostream& out) const override; }; } diff --git a/src/logic/ProbabilityOperatorFormula.cpp b/src/logic/ProbabilityOperatorFormula.cpp index 84d6c55d8..e630e45da 100644 --- a/src/logic/ProbabilityOperatorFormula.cpp +++ b/src/logic/ProbabilityOperatorFormula.cpp @@ -18,11 +18,7 @@ namespace storm { boost::any ProbabilityOperatorFormula::accept(FormulaVisitor const& visitor, boost::any const& data) const { return visitor.visit(*this, data); } - - std::shared_ptr ProbabilityOperatorFormula::substitute(std::map const& substitution) const { - return std::make_shared(this->getSubformula().substitute(substitution), this->operatorInformation); - } - + std::ostream& ProbabilityOperatorFormula::writeToStream(std::ostream& out) const { out << "P"; OperatorFormula::writeToStream(out); diff --git a/src/logic/ProbabilityOperatorFormula.h b/src/logic/ProbabilityOperatorFormula.h index 786d58b44..b1259262c 100644 --- a/src/logic/ProbabilityOperatorFormula.h +++ b/src/logic/ProbabilityOperatorFormula.h @@ -16,9 +16,7 @@ namespace storm { virtual bool isProbabilityOperatorFormula() const override; virtual boost::any accept(FormulaVisitor const& visitor, boost::any const& data) const override; - - virtual std::shared_ptr substitute(std::map const& substitution) const override; - + virtual std::ostream& writeToStream(std::ostream& out) const override; }; } diff --git a/src/logic/RewardOperatorFormula.cpp b/src/logic/RewardOperatorFormula.cpp index 5daea4b59..aba2ec92a 100644 --- a/src/logic/RewardOperatorFormula.cpp +++ b/src/logic/RewardOperatorFormula.cpp @@ -40,10 +40,6 @@ namespace storm { this->getSubformula().gatherReferencedRewardModels(referencedRewardModels); } - std::shared_ptr RewardOperatorFormula::substitute(std::map const& substitution) const { - return std::make_shared(this->getSubformula().substitute(substitution), this->rewardModelName, this->operatorInformation, this->rewardMeasureType); - } - RewardMeasureType RewardOperatorFormula::getMeasureType() const { return rewardMeasureType; } diff --git a/src/logic/RewardOperatorFormula.h b/src/logic/RewardOperatorFormula.h index 300a82e37..99ca17233 100644 --- a/src/logic/RewardOperatorFormula.h +++ b/src/logic/RewardOperatorFormula.h @@ -50,9 +50,7 @@ namespace storm { * @return The measure type. */ RewardMeasureType getMeasureType() const; - - virtual std::shared_ptr substitute(std::map const& substitution) const override; - + private: // The (optional) name of the reward model this property refers to. boost::optional rewardModelName; diff --git a/src/logic/TimeOperatorFormula.cpp b/src/logic/TimeOperatorFormula.cpp index 6a421a1a5..e7f711729 100644 --- a/src/logic/TimeOperatorFormula.cpp +++ b/src/logic/TimeOperatorFormula.cpp @@ -18,11 +18,7 @@ namespace storm { boost::any TimeOperatorFormula::accept(FormulaVisitor const& visitor, boost::any const& data) const { return visitor.visit(*this, data); } - - std::shared_ptr TimeOperatorFormula::substitute(std::map const& substitution) const { - return std::make_shared(this->getSubformula().substitute(substitution), this->operatorInformation, this->rewardMeasureType); - } - + RewardMeasureType TimeOperatorFormula::getMeasureType() const { return rewardMeasureType; } diff --git a/src/logic/TimeOperatorFormula.h b/src/logic/TimeOperatorFormula.h index b9f243a4c..24906bb6e 100644 --- a/src/logic/TimeOperatorFormula.h +++ b/src/logic/TimeOperatorFormula.h @@ -18,9 +18,7 @@ namespace storm { virtual bool isTimeOperatorFormula() const override; virtual boost::any accept(FormulaVisitor const& visitor, boost::any const& data) const override; - - virtual std::shared_ptr substitute(std::map const& substitution) const override; - + virtual std::ostream& writeToStream(std::ostream& out) const override; /*! diff --git a/src/logic/UnaryBooleanStateFormula.cpp b/src/logic/UnaryBooleanStateFormula.cpp index 3983e5b27..83feb7fd4 100644 --- a/src/logic/UnaryBooleanStateFormula.cpp +++ b/src/logic/UnaryBooleanStateFormula.cpp @@ -26,11 +26,7 @@ namespace storm { bool UnaryBooleanStateFormula::isNot() const { return this->getOperator() == OperatorType::Not; } - - std::shared_ptr UnaryBooleanStateFormula::substitute(std::map const& substitution) const { - return std::make_shared(this->operatorType, this->getSubformula().substitute(substitution)); - } - + std::ostream& UnaryBooleanStateFormula::writeToStream(std::ostream& out) const { switch (operatorType) { case OperatorType::Not: out << "!("; break; diff --git a/src/logic/UnaryBooleanStateFormula.h b/src/logic/UnaryBooleanStateFormula.h index 93d45f862..a30886a60 100644 --- a/src/logic/UnaryBooleanStateFormula.h +++ b/src/logic/UnaryBooleanStateFormula.h @@ -22,8 +22,6 @@ namespace storm { OperatorType getOperator() const; virtual bool isNot() const; - - virtual std::shared_ptr substitute(std::map const& substitution) const override; virtual std::ostream& writeToStream(std::ostream& out) const override; diff --git a/src/logic/UntilFormula.cpp b/src/logic/UntilFormula.cpp index a00eae77d..3aaf93da1 100644 --- a/src/logic/UntilFormula.cpp +++ b/src/logic/UntilFormula.cpp @@ -20,10 +20,6 @@ namespace storm { return visitor.visit(*this, data); } - std::shared_ptr UntilFormula::substitute(std::map const& substitution) const { - return std::make_shared(this->getLeftSubformula().substitute(substitution), this->getRightSubformula().substitute(substitution)); - } - std::ostream& UntilFormula::writeToStream(std::ostream& out) const { this->getLeftSubformula().writeToStream(out); out << " U "; diff --git a/src/logic/UntilFormula.h b/src/logic/UntilFormula.h index 887d39b45..e03d32047 100644 --- a/src/logic/UntilFormula.h +++ b/src/logic/UntilFormula.h @@ -17,9 +17,7 @@ namespace storm { virtual bool isProbabilityPathFormula() const override; virtual boost::any accept(FormulaVisitor const& visitor, boost::any const& data) const override; - - virtual std::shared_ptr substitute(std::map const& substitution) const override; - + virtual std::ostream& writeToStream(std::ostream& out) const override; }; } diff --git a/src/logic/VariableSubstitutionVisitor.cpp b/src/logic/VariableSubstitutionVisitor.cpp new file mode 100644 index 000000000..9a5ba5b8a --- /dev/null +++ b/src/logic/VariableSubstitutionVisitor.cpp @@ -0,0 +1,21 @@ +#include "src/logic/VariableSubstitutionVisitor.h" + +#include "src/logic/Formulas.h" + +namespace storm { + namespace logic { + + VariableSubstitutionVisitor::VariableSubstitutionVisitor(std::map const& substitution) : substitution(substitution) { + // Intentionally left empty. + } + + std::shared_ptr VariableSubstitutionVisitor::substitute(Formula const& f) const { + boost::any result = f.accept(*this, boost::any()); + return boost::any_cast>(result); + } + + boost::any VariableSubstitutionVisitor::visit(AtomicExpressionFormula const& f, boost::any const& data) const { + return std::static_pointer_cast(std::make_shared(f.getExpression().substitute(substitution))); + } + } +} diff --git a/src/logic/VariableSubstitutionVisitor.h b/src/logic/VariableSubstitutionVisitor.h new file mode 100644 index 000000000..11876ba59 --- /dev/null +++ b/src/logic/VariableSubstitutionVisitor.h @@ -0,0 +1,29 @@ +#ifndef STORM_LOGIC_VARIABLESUBSTITUTIONVISITOR_H_ +#define STORM_LOGIC_VARIABLESUBSTITUTIONVISITOR_H_ + +#include + +#include "src/logic/CloneVisitor.h" + +#include "src/storage/expressions/Expression.h" + +namespace storm { + namespace logic { + + class VariableSubstitutionVisitor : public CloneVisitor { + public: + VariableSubstitutionVisitor(std::map const& substitution); + + std::shared_ptr substitute(Formula const& f) const; + + virtual boost::any visit(AtomicExpressionFormula const& f, boost::any const& data) const override; + + private: + std::map const& substitution; + }; + + } +} + + +#endif /* STORM_LOGIC_VARIABLESUBSTITUTIONVISITOR_H_ */ \ No newline at end of file diff --git a/src/modelchecker/reachability/SparseMdpLearningModelChecker.cpp b/src/modelchecker/reachability/SparseMdpLearningModelChecker.cpp index 76765d367..f8adee32e 100644 --- a/src/modelchecker/reachability/SparseMdpLearningModelChecker.cpp +++ b/src/modelchecker/reachability/SparseMdpLearningModelChecker.cpp @@ -15,6 +15,7 @@ #include "src/settings/modules/GeneralSettings.h" #include "src/utility/macros.h" +#include "src/exceptions/InvalidOperationException.h" #include "src/exceptions/InvalidPropertyException.h" #include "src/exceptions/NotSupportedException.h" @@ -58,7 +59,13 @@ namespace storm { template storm::expressions::Expression SparseMdpLearningModelChecker::getTargetStateExpression(storm::logic::Formula const& subformula) const { std::shared_ptr preparedSubformula = subformula.substitute(program.getLabelToExpressionMapping()); - return preparedSubformula->toExpression(); + storm::expressions::Expression result; + try { + result = preparedSubformula->toExpression(); + } catch(storm::exceptions::InvalidOperationException const& e) { + STORM_LOG_THROW(false, storm::exceptions::InvalidPropertyException, "The property refers to unknown labels."); + } + return result; } template @@ -284,28 +291,41 @@ namespace storm { } template - uint32_t SparseMdpLearningModelChecker::sampleActionOfState(StateType const& currentStateId, ExplorationInformation const& explorationInformation, BoundValues& bounds) const { + typename SparseMdpLearningModelChecker::ActionType SparseMdpLearningModelChecker::sampleActionOfState(StateType const& currentStateId, ExplorationInformation const& explorationInformation, BoundValues& bounds) const { // Determine the values of all available actions. std::vector> actionValues; StateType rowGroup = explorationInformation.getRowGroup(currentStateId); auto choicesInEcIt = explorationInformation.stateToLeavingActionsOfEndComponent.find(currentStateId); + + // Check for cases in which we do not need to perform more work. + if (choicesInEcIt == explorationInformation.stateToLeavingActionsOfEndComponent.end()) { + if (explorationInformation.onlyOneActionAvailable(rowGroup)) { + return explorationInformation.getStartRowOfGroup(rowGroup); + } + } else { + if (choicesInEcIt->second->size() == 1) { + return *choicesInEcIt->second->begin(); + } + } + + // If there are more choices to consider, start by gathering the values of relevant actions. if (choicesInEcIt != explorationInformation.stateToLeavingActionsOfEndComponent.end()) { STORM_LOG_TRACE("Sampling from actions leaving the previously detected EC."); for (auto const& row : *choicesInEcIt->second) { - actionValues.push_back(std::make_pair(row, computeUpperBoundOfAction(row, explorationInformation, bounds))); + actionValues.push_back(std::make_pair(row, bounds.getBoundForAction(explorationInformation.optimizationDirection, row))); } } else { STORM_LOG_TRACE("Sampling from actions leaving the state."); for (uint32_t row = explorationInformation.getStartRowOfGroup(rowGroup); row < explorationInformation.getStartRowOfGroup(rowGroup + 1); ++row) { - actionValues.push_back(std::make_pair(row, computeUpperBoundOfAction(row, explorationInformation, bounds))); + actionValues.push_back(std::make_pair(row, bounds.getBoundForAction(explorationInformation.optimizationDirection, row))); } } STORM_LOG_ASSERT(!actionValues.empty(), "Values for actions must not be empty."); // Sort the actions wrt. to the optimization direction. - if (explorationInformation.optimizationDirection == storm::OptimizationDirection::Maximize) { + if (explorationInformation.maximize()) { std::sort(actionValues.begin(), actionValues.end(), [] (std::pair const& a, std::pair const& b) { return a.second > b.second; } ); } else { std::sort(actionValues.begin(), actionValues.end(), [] (std::pair const& a, std::pair const& b) { return a.second < b.second; } ); @@ -349,14 +369,13 @@ namespace storm { // Determine the set of states that was expanded. std::vector relevantStates; for (StateType state = 0; state < explorationInformation.stateStorage.numberOfStates; ++state) { - if (!explorationInformation.isUnexplored(state)) { + // Add the state to the relevant states if it's unexplored. Additionally, if we are computing minimal + // probabilities, we only consider it relevant if it's not a target state. + if (!explorationInformation.isUnexplored(state) && (explorationInformation.maximize() || !comparator.isOne(bounds.getLowerBoundForState(state, explorationInformation)))) { relevantStates.push_back(state); } } - - // Sort according to the actual row groups so we can insert the elements in order later. - std::sort(relevantStates.begin(), relevantStates.end(), [&explorationInformation] (StateType const& a, StateType const& b) { return explorationInformation.getRowGroup(a) < explorationInformation.getRowGroup(b); }); - StateType unexploredState = relevantStates.size(); + StateType sink = relevantStates.size(); // Create a mapping for faster look-up during the translation of flexible matrix to the real sparse matrix. std::unordered_map relevantStateToNewRowGroupMapping; @@ -382,14 +401,14 @@ namespace storm { } } if (unexpandedProbability != storm::utility::zero()) { - builder.addNextValue(currentRow, unexploredState, unexpandedProbability); + builder.addNextValue(currentRow, sink, unexpandedProbability); } ++currentRow; } } // Then, make the unexpanded state absorbing. builder.newRowGroup(currentRow); - builder.addNextValue(currentRow, unexploredState, storm::utility::one()); + builder.addNextValue(currentRow, sink, storm::utility::one()); STORM_LOG_TRACE("Successfully built matrix for MEC decomposition."); // Go on to step 2. @@ -399,74 +418,102 @@ namespace storm { // 3. Analyze the MEC decomposition. for (auto const& mec : mecDecomposition) { - // Ignore the (expected) MEC of the unexplored state. - if (mec.containsState(unexploredState)) { + // Ignore the (expected) MEC of the sink state. + if (mec.containsState(sink)) { continue; } - bool containsTargetState = false; + if (explorationInformation.maximize()) { + analyzeMecForMaximalProbabilities(mec, relevantStates, relevantStatesMatrix, explorationInformation, bounds); + } else { + analyzeMecForMinimalProbabilities(mec, relevantStates, relevantStatesMatrix, explorationInformation, bounds); + } + } + } + + template + void SparseMdpLearningModelChecker::analyzeMecForMaximalProbabilities(storm::storage::MaximalEndComponent const& mec, std::vector const& relevantStates, storm::storage::SparseMatrix const& relevantStatesMatrix, ExplorationInformation& explorationInformation, BoundValues& bounds) const { + // For maximal probabilities, we check (a) which MECs contain a target state, because the associated states + // have a probability of 1 (and we can therefore set their lower bounds to 1) and (b) which of the remaining + // MECs have no outgoing action, because the associated states have a probability of 0 (and we can therefore + // set their upper bounds to 0). + + bool containsTargetState = false; + + // Now we record all choices leaving the EC. + ActionSetPointer leavingChoices = std::make_shared(); + for (auto const& stateAndChoices : mec) { + // Compute the state of the original model that corresponds to the current state. + StateType originalState = relevantStates[stateAndChoices.first]; + StateType originalRowGroup = explorationInformation.getRowGroup(originalState); - // Now we record all choices leaving the EC. - ActionSetPointer leavingChoices = std::make_shared(); - for (auto const& stateAndChoices : mec) { - // Compute the state of the original model that corresponds to the current state. - StateType originalState = relevantStates[stateAndChoices.first]; - uint32_t originalRowGroup = explorationInformation.getRowGroup(originalState); - - // Check whether a target state is contained in the MEC. - if (!containsTargetState && comparator.isOne(bounds.getLowerBoundForRowGroup(originalRowGroup))) { - containsTargetState = true; - } - - // For each state, compute the actions that leave the MEC. - auto includedChoicesIt = stateAndChoices.second.begin(); - auto includedChoicesIte = stateAndChoices.second.end(); - for (auto action = explorationInformation.getStartRowOfGroup(originalRowGroup); action < explorationInformation.getStartRowOfGroup(originalRowGroup + 1); ++action) { - if (includedChoicesIt != includedChoicesIte) { - STORM_LOG_TRACE("Next (local) choice contained in MEC is " << (*includedChoicesIt - relevantStatesMatrix.getRowGroupIndices()[stateAndChoices.first])); - STORM_LOG_TRACE("Current (local) choice iterated is " << (action - explorationInformation.getStartRowOfGroup(originalRowGroup))); - if (action - explorationInformation.getStartRowOfGroup(originalRowGroup) != *includedChoicesIt - relevantStatesMatrix.getRowGroupIndices()[stateAndChoices.first]) { - STORM_LOG_TRACE("Choice leaves the EC."); - leavingChoices->insert(action); - } else { - STORM_LOG_TRACE("Choice stays in the EC."); - ++includedChoicesIt; - } - } else { - STORM_LOG_TRACE("Choice leaves the EC, because there is no more choice staying in the EC."); + // Check whether a target state is contained in the MEC. + if (!containsTargetState && comparator.isOne(bounds.getLowerBoundForRowGroup(originalRowGroup))) { + containsTargetState = true; + } + + // For each state, compute the actions that leave the MEC. + auto includedChoicesIt = stateAndChoices.second.begin(); + auto includedChoicesIte = stateAndChoices.second.end(); + for (auto action = explorationInformation.getStartRowOfGroup(originalRowGroup); action < explorationInformation.getStartRowOfGroup(originalRowGroup + 1); ++action) { + if (includedChoicesIt != includedChoicesIte) { + STORM_LOG_TRACE("Next (local) choice contained in MEC is " << (*includedChoicesIt - relevantStatesMatrix.getRowGroupIndices()[stateAndChoices.first])); + STORM_LOG_TRACE("Current (local) choice iterated is " << (action - explorationInformation.getStartRowOfGroup(originalRowGroup))); + if (action - explorationInformation.getStartRowOfGroup(originalRowGroup) != *includedChoicesIt - relevantStatesMatrix.getRowGroupIndices()[stateAndChoices.first]) { + STORM_LOG_TRACE("Choice leaves the EC."); leavingChoices->insert(action); + } else { + STORM_LOG_TRACE("Choice stays in the EC."); + ++includedChoicesIt; } + } else { + STORM_LOG_TRACE("Choice leaves the EC, because there is no more choice staying in the EC."); + leavingChoices->insert(action); } - - explorationInformation.stateToLeavingActionsOfEndComponent[originalState] = leavingChoices; } - // If one of the states of the EC is a target state, all states in the EC have probability 1. - if (containsTargetState) { - STORM_LOG_TRACE("MEC contains a target state."); - for (auto const& stateAndChoices : mec) { - // Compute the state of the original model that corresponds to the current state. - StateType const& originalState = relevantStates[stateAndChoices.first]; - - STORM_LOG_TRACE("Setting lower bound of state in row group " << explorationInformation.getRowGroup(originalState) << " to 1."); - bounds.setLowerBoundForState(originalState, explorationInformation, storm::utility::one()); - explorationInformation.addTerminalState(originalState); - } - } else if (leavingChoices->empty()) { - STORM_LOG_TRACE("MEC's leaving choices are empty."); - // If there is no choice leaving the EC, but it contains no target state, all states have probability 0. - for (auto const& stateAndChoices : mec) { - // Compute the state of the original model that corresponds to the current state. - StateType const& originalState = relevantStates[stateAndChoices.first]; - - STORM_LOG_TRACE("Setting upper bound of state in row group " << explorationInformation.getRowGroup(originalState) << " to 0."); - bounds.setUpperBoundForState(originalState, explorationInformation, storm::utility::zero()); - explorationInformation.addTerminalState(originalState); - } + explorationInformation.stateToLeavingActionsOfEndComponent[originalState] = leavingChoices; + } + + // If one of the states of the EC is a target state, all states in the EC have probability 1. + if (containsTargetState) { + STORM_LOG_TRACE("MEC contains a target state."); + for (auto const& stateAndChoices : mec) { + // Compute the state of the original model that corresponds to the current state. + StateType const& originalState = relevantStates[stateAndChoices.first]; + + STORM_LOG_TRACE("Setting lower bound of state in row group " << explorationInformation.getRowGroup(originalState) << " to 1."); + bounds.setLowerBoundForState(originalState, explorationInformation, storm::utility::one()); + explorationInformation.addTerminalState(originalState); + } + } else if (leavingChoices->empty()) { + STORM_LOG_TRACE("MEC's leaving choices are empty."); + // If there is no choice leaving the EC, but it contains no target state, all states have probability 0. + for (auto const& stateAndChoices : mec) { + // Compute the state of the original model that corresponds to the current state. + StateType const& originalState = relevantStates[stateAndChoices.first]; + + STORM_LOG_TRACE("Setting upper bound of state in row group " << explorationInformation.getRowGroup(originalState) << " to 0."); + bounds.setUpperBoundForState(originalState, explorationInformation, storm::utility::zero()); + explorationInformation.addTerminalState(originalState); } } } + template + void SparseMdpLearningModelChecker::analyzeMecForMinimalProbabilities(storm::storage::MaximalEndComponent const& mec, std::vector const& relevantStates, storm::storage::SparseMatrix const& relevantStatesMatrix, ExplorationInformation& explorationInformation, BoundValues& bounds) const { + // For minimal probabilities, all found MECs are guaranteed to not contain a target state. Hence, in all + // associated states, the probability is 0 and we can set the upper bounds of the states to 0). + + for (auto const& stateAndChoices : mec) { + // Compute the state of the original model that corresponds to the current state. + StateType originalState = relevantStates[stateAndChoices.first]; + + bounds.setUpperBoundForState(originalState, explorationInformation, storm::utility::zero()); + explorationInformation.addTerminalState(originalState); + } + } + template ValueType SparseMdpLearningModelChecker::computeLowerBoundOfAction(ActionType const& action, ExplorationInformation const& explorationInformation, BoundValues const& bounds) const { ValueType result = storm::utility::zero(); @@ -525,7 +572,7 @@ namespace storm { bounds.setBoundsForAction(action, newBoundsForAction); // Check if we need to update the values for the states. - if (explorationInformation.optimizationDirection == storm::OptimizationDirection::Maximize) { + if (explorationInformation.maximize()) { bounds.setLowerBoundOfStateIfGreaterThanOld(state, explorationInformation, newBoundsForAction.first); StateType rowGroup = explorationInformation.getRowGroup(state); @@ -542,7 +589,8 @@ namespace storm { StateType rowGroup = explorationInformation.getRowGroup(state); if (bounds.getLowerBoundForRowGroup(rowGroup) < newBoundsForAction.first) { if (explorationInformation.getRowGroupSize(rowGroup) > 1) { - newBoundsForAction.first = std::min(newBoundsForAction.first, computeBoundOverAllOtherActions(storm::OptimizationDirection::Maximize, state, action, explorationInformation, bounds)); + ValueType min = computeBoundOverAllOtherActions(storm::OptimizationDirection::Minimize, state, action, explorationInformation, bounds); + newBoundsForAction.first = std::min(newBoundsForAction.first, min); } bounds.setLowerBoundForRowGroup(rowGroup, newBoundsForAction.first); diff --git a/src/modelchecker/reachability/SparseMdpLearningModelChecker.h b/src/modelchecker/reachability/SparseMdpLearningModelChecker.h index c3a931913..b0a09e343 100644 --- a/src/modelchecker/reachability/SparseMdpLearningModelChecker.h +++ b/src/modelchecker/reachability/SparseMdpLearningModelChecker.h @@ -21,6 +21,8 @@ namespace storm { template class StateStorage; } + + class MaximalEndComponent; } namespace generator { @@ -49,7 +51,7 @@ namespace storm { private: // A struct that keeps track of certain statistics during the computation. struct Statistics { - Statistics() : iterations(0), maxPathLength(0), numberOfTargetStates(0), numberOfExploredStates(0), pathLengthUntilEndComponentDetection(27) { + Statistics() : iterations(0), maxPathLength(0), numberOfTargetStates(0), numberOfExploredStates(0), pathLengthUntilEndComponentDetection(10000) { // Intentionally left empty. } @@ -170,6 +172,10 @@ namespace storm { return rowGroupIndices[group + 1] - rowGroupIndices[group]; } + bool onlyOneActionAvailable(StateType const& group) const { + return getRowGroupSize(group) == 1; + } + void addTerminalState(StateType const& state) { terminalStates.insert(state); } @@ -185,6 +191,14 @@ namespace storm { void addRowsToMatrix(std::size_t const& count) { matrix.resize(matrix.size() + count); } + + bool maximize() const { + return optimizationDirection == storm::OptimizationDirection::Maximize; + } + + bool minimize() const { + return !maximize(); + } }; // A struct containg the lower and upper bounds per state and action. @@ -241,6 +255,14 @@ namespace storm { return upperBoundsPerAction[action]; } + ValueType const& getBoundForAction(storm::OptimizationDirection const& direction, ActionType const& action) const { + if (direction == storm::OptimizationDirection::Maximize) { + return getUpperBoundForAction(action); + } else { + return getLowerBoundForAction(action); + } + } + ValueType getDifferenceOfStateBounds(StateType const& state, ExplorationInformation const& explorationInformation) { std::pair bounds = getBoundsForState(state, explorationInformation); return bounds.second - bounds.first; @@ -312,12 +334,16 @@ namespace storm { bool exploreState(StateGeneration& stateGeneration, StateType const& currentStateId, storm::generator::CompressedState const& currentState, ExplorationInformation& explorationInformation, BoundValues& bounds, Statistics& stats) const; - uint32_t sampleActionOfState(StateType const& currentStateId, ExplorationInformation const& explorationInformation, BoundValues& bounds) const; + ActionType sampleActionOfState(StateType const& currentStateId, ExplorationInformation const& explorationInformation, BoundValues& bounds) const; StateType sampleSuccessorFromAction(ActionType const& chosenAction, ExplorationInformation const& explorationInformation) const; void detectEndComponents(StateActionStack const& stack, ExplorationInformation& explorationInformation, BoundValues& bounds) const; + void analyzeMecForMaximalProbabilities(storm::storage::MaximalEndComponent const& mec, std::vector const& relevantStates, storm::storage::SparseMatrix const& relevantStatesMatrix, ExplorationInformation& explorationInformation, BoundValues& bounds) const; + + void analyzeMecForMinimalProbabilities(storm::storage::MaximalEndComponent const& mec, std::vector const& relevantStates, storm::storage::SparseMatrix const& relevantStatesMatrix, ExplorationInformation& explorationInformation, BoundValues& bounds) const; + void updateProbabilityBoundsAlongSampledPath(StateActionStack& stack, ExplorationInformation const& explorationInformation, BoundValues& bounds) const; void updateProbabilityOfAction(StateType const& state, ActionType const& action, ExplorationInformation const& explorationInformation, BoundValues& bounds) const;