From c2b287a1e174bca5eb32c9ba3c55b200fa5e234f Mon Sep 17 00:00:00 2001 From: dehnert Date: Wed, 6 Apr 2016 16:05:54 +0200 Subject: [PATCH] more work on learning approach Former-commit-id: 48aa9ddd2cbe7ea50a2a9214f4fdff41163f8e05 --- src/logic/AtomicLabelFormula.cpp | 1 + src/logic/BinaryBooleanStateFormula.h | 2 +- src/logic/BooleanLiteralFormula.h | 2 +- src/logic/BoundedUntilFormula.h | 2 +- src/logic/CloneVisitor.h | 37 ++++ src/logic/ConditionalFormula.cpp | 4 + src/logic/ConditionalFormula.h | 7 +- src/logic/CopyVisitor.cpp | 106 +++++++++++ src/logic/CumulativeRewardFormula.h | 2 +- src/logic/EventuallyFormula.cpp | 4 + src/logic/EventuallyFormula.h | 2 + src/logic/Formula.cpp | 12 ++ src/logic/Formula.h | 9 + src/logic/LabelSubstitutionVisitor.cpp | 27 +++ src/logic/LabelSubstitutionVisitor.h | 29 +++ src/logic/OperatorFormula.cpp | 4 + src/logic/OperatorFormula.h | 4 +- src/logic/ToExpressionVisitor.cpp | 103 +++++++++++ src/logic/ToExpressionVisitor.h | 39 ++++ .../SparseMdpLearningModelChecker.cpp | 167 ++++++++++-------- .../SparseMdpLearningModelChecker.h | 34 ++-- src/storage/prism/Program.cpp | 8 + src/storage/prism/Program.h | 7 + 23 files changed, 515 insertions(+), 97 deletions(-) create mode 100644 src/logic/CloneVisitor.h create mode 100644 src/logic/CopyVisitor.cpp create mode 100644 src/logic/LabelSubstitutionVisitor.cpp create mode 100644 src/logic/LabelSubstitutionVisitor.h create mode 100644 src/logic/ToExpressionVisitor.cpp create mode 100644 src/logic/ToExpressionVisitor.h diff --git a/src/logic/AtomicLabelFormula.cpp b/src/logic/AtomicLabelFormula.cpp index 19a5f673f..292351c5d 100644 --- a/src/logic/AtomicLabelFormula.cpp +++ b/src/logic/AtomicLabelFormula.cpp @@ -1,5 +1,6 @@ #include "src/logic/AtomicLabelFormula.h" +#include "src/logic/AtomicExpressionFormula.h" #include "src/logic/FormulaVisitor.h" namespace storm { diff --git a/src/logic/BinaryBooleanStateFormula.h b/src/logic/BinaryBooleanStateFormula.h index 892d28bc9..c0880d058 100644 --- a/src/logic/BinaryBooleanStateFormula.h +++ b/src/logic/BinaryBooleanStateFormula.h @@ -29,7 +29,7 @@ namespace storm { virtual std::ostream& writeToStream(std::ostream& out) const override; virtual std::shared_ptr substitute(std::map const& substitution) const override; - + private: OperatorType operatorType; }; diff --git a/src/logic/BooleanLiteralFormula.h b/src/logic/BooleanLiteralFormula.h index 405044600..955d1dbbe 100644 --- a/src/logic/BooleanLiteralFormula.h +++ b/src/logic/BooleanLiteralFormula.h @@ -20,7 +20,7 @@ namespace storm { virtual boost::any accept(FormulaVisitor const& visitor, boost::any const& data) const override; virtual std::shared_ptr substitute(std::map const& substitution) const override; - + virtual std::ostream& writeToStream(std::ostream& out) const override; private: diff --git a/src/logic/BoundedUntilFormula.h b/src/logic/BoundedUntilFormula.h index 222bab142..0e18633ce 100644 --- a/src/logic/BoundedUntilFormula.h +++ b/src/logic/BoundedUntilFormula.h @@ -27,7 +27,7 @@ namespace storm { virtual std::ostream& writeToStream(std::ostream& out) const override; virtual std::shared_ptr substitute(std::map const& substitution) const override; - + private: boost::variant> bounds; }; diff --git a/src/logic/CloneVisitor.h b/src/logic/CloneVisitor.h new file mode 100644 index 000000000..a23ee4a8b --- /dev/null +++ b/src/logic/CloneVisitor.h @@ -0,0 +1,37 @@ +#ifndef STORM_LOGIC_CLONEVISITOR_H_ +#define STORM_LOGIC_CLONEVISITOR_H_ + +#include "src/logic/FormulaVisitor.h" + +namespace storm { + namespace logic { + + class CloneVisitor : public FormulaVisitor { + public: + std::shared_ptr clone(Formula const& f) const; + + virtual boost::any visit(AtomicExpressionFormula const& f, boost::any const& data) const override; + virtual boost::any visit(AtomicLabelFormula const& f, boost::any const& data) const override; + virtual boost::any visit(BinaryBooleanStateFormula const& f, boost::any const& data) const override; + virtual boost::any visit(BooleanLiteralFormula const& f, boost::any const& data) const override; + virtual boost::any visit(BoundedUntilFormula const& f, boost::any const& data) const override; + virtual boost::any visit(ConditionalFormula const& f, boost::any const& data) const override; + virtual boost::any visit(CumulativeRewardFormula const& f, boost::any const& data) const override; + virtual boost::any visit(EventuallyFormula const& f, boost::any const& data) const override; + virtual boost::any visit(TimeOperatorFormula const& f, boost::any const& data) const override; + virtual boost::any visit(GloballyFormula const& f, boost::any const& data) const override; + virtual boost::any visit(InstantaneousRewardFormula const& f, boost::any const& data) const override; + virtual boost::any visit(LongRunAverageOperatorFormula const& f, boost::any const& data) const override; + virtual boost::any visit(LongRunAverageRewardFormula const& f, boost::any const& data) const override; + virtual boost::any visit(NextFormula const& f, boost::any const& data) const override; + virtual boost::any visit(ProbabilityOperatorFormula const& f, boost::any const& data) const override; + virtual boost::any visit(RewardOperatorFormula const& f, boost::any const& data) const override; + virtual boost::any visit(UnaryBooleanStateFormula const& f, boost::any const& data) const override; + virtual boost::any visit(UntilFormula const& f, boost::any const& data) const override; + }; + + } +} + + +#endif /* STORM_LOGIC_CLONEVISITOR_H_ */ \ No newline at end of file diff --git a/src/logic/ConditionalFormula.cpp b/src/logic/ConditionalFormula.cpp index b12b831be..55bddabd9 100644 --- a/src/logic/ConditionalFormula.cpp +++ b/src/logic/ConditionalFormula.cpp @@ -18,6 +18,10 @@ namespace storm { return *conditionFormula; } + FormulaContext const& ConditionalFormula::getContext() const { + return context; + } + bool ConditionalFormula::isConditionalProbabilityFormula() const { return context == FormulaContext::Probability; } diff --git a/src/logic/ConditionalFormula.h b/src/logic/ConditionalFormula.h index 66d71ebba..87303198c 100644 --- a/src/logic/ConditionalFormula.h +++ b/src/logic/ConditionalFormula.h @@ -7,9 +7,7 @@ namespace storm { namespace logic { class ConditionalFormula : public Formula { - public: - enum class Context { Probability, Reward }; - + public: ConditionalFormula(std::shared_ptr const& subformula, std::shared_ptr const& conditionFormula, FormulaContext context = FormulaContext::Probability); virtual ~ConditionalFormula() { @@ -18,6 +16,7 @@ namespace storm { Formula const& getSubformula() const; Formula const& getConditionFormula() const; + FormulaContext const& getContext() const; virtual bool isConditionalProbabilityFormula() const override; virtual bool isConditionalRewardFormula() const override; @@ -27,7 +26,7 @@ namespace storm { virtual std::ostream& writeToStream(std::ostream& out) const override; virtual std::shared_ptr substitute(std::map const& substitution) const override; - + virtual void gatherAtomicExpressionFormulas(std::vector>& atomicExpressionFormulas) const override; virtual void gatherAtomicLabelFormulas(std::vector>& atomicLabelFormulas) const override; virtual void gatherReferencedRewardModels(std::set& referencedRewardModels) const override; diff --git a/src/logic/CopyVisitor.cpp b/src/logic/CopyVisitor.cpp new file mode 100644 index 000000000..fe7182bf7 --- /dev/null +++ b/src/logic/CopyVisitor.cpp @@ -0,0 +1,106 @@ +#include "src/logic/CloneVisitor.h" + +#include "src/logic/Formulas.h" + +namespace storm { + namespace logic { + + std::shared_ptr CloneVisitor::clone(Formula const& f) const { + boost::any result = f.accept(*this, boost::any()); + return boost::any_cast>(result); + } + + boost::any CloneVisitor::visit(AtomicExpressionFormula const& f, boost::any const& data) const { + return std::static_pointer_cast(std::make_shared(f)); + } + + boost::any CloneVisitor::visit(AtomicLabelFormula const& f, boost::any const& data) const { + return std::static_pointer_cast(std::make_shared(f)); + } + + boost::any CloneVisitor::visit(BinaryBooleanStateFormula const& f, boost::any const& data) const { + std::shared_ptr left = boost::any_cast>(f.getLeftSubformula().accept(*this, data)); + std::shared_ptr right = boost::any_cast>(f.getRightSubformula().accept(*this, data)); + return std::static_pointer_cast(std::make_shared(f.getOperator(), left, right)); + } + + boost::any CloneVisitor::visit(BooleanLiteralFormula const& f, boost::any const& data) const { + return std::static_pointer_cast(std::make_shared(f)); + } + + boost::any CloneVisitor::visit(BoundedUntilFormula const& f, boost::any const& data) const { + std::shared_ptr left = boost::any_cast>(f.getLeftSubformula().accept(*this, data)); + std::shared_ptr right = boost::any_cast>(f.getRightSubformula().accept(*this, data)); + if (f.hasDiscreteTimeBound()) { + return std::static_pointer_cast(std::make_shared(left, right, f.getDiscreteTimeBound())); + } else { + return std::static_pointer_cast(std::make_shared(left, right, f.getIntervalBounds())); + } + } + + boost::any CloneVisitor::visit(ConditionalFormula const& f, boost::any const& data) const { + std::shared_ptr subformula = boost::any_cast>(f.getSubformula().accept(*this, data)); + std::shared_ptr conditionFormula = boost::any_cast>(f.getConditionFormula().accept(*this, data)); + return std::static_pointer_cast(std::make_shared(subformula, conditionFormula, f.getContext())); + } + + boost::any CloneVisitor::visit(CumulativeRewardFormula const& f, boost::any const& data) const { + return std::static_pointer_cast(std::make_shared(f)); + } + + boost::any CloneVisitor::visit(EventuallyFormula const& f, boost::any const& data) const { + std::shared_ptr subformula = boost::any_cast>(f.getSubformula().accept(*this, data)); + return std::static_pointer_cast(std::make_shared(subformula, f.getContext())); + } + + boost::any CloneVisitor::visit(TimeOperatorFormula const& f, boost::any const& data) const { + std::shared_ptr subformula = boost::any_cast>(f.getSubformula().accept(*this, data)); + return std::static_pointer_cast(std::make_shared(subformula, f.getOperatorInformation())); + } + + boost::any CloneVisitor::visit(GloballyFormula const& f, boost::any const& data) const { + std::shared_ptr subformula = boost::any_cast>(f.getSubformula().accept(*this, data)); + return std::static_pointer_cast(std::make_shared(subformula)); + } + + boost::any CloneVisitor::visit(InstantaneousRewardFormula const& f, boost::any const& data) const { + return std::static_pointer_cast(std::make_shared(f)); + } + + boost::any CloneVisitor::visit(LongRunAverageOperatorFormula const& f, boost::any const& data) const { + std::shared_ptr subformula = boost::any_cast>(f.getSubformula().accept(*this, data)); + return std::static_pointer_cast(std::make_shared(subformula, f.getOperatorInformation())); + } + + boost::any CloneVisitor::visit(LongRunAverageRewardFormula const& f, boost::any const& data) const { + return std::static_pointer_cast(std::make_shared(f)); + } + + boost::any CloneVisitor::visit(NextFormula const& f, boost::any const& data) const { + std::shared_ptr subformula = boost::any_cast>(f.getSubformula().accept(*this, data)); + return std::static_pointer_cast(std::make_shared(subformula)); + } + + boost::any CloneVisitor::visit(ProbabilityOperatorFormula const& f, boost::any const& data) const { + std::shared_ptr subformula = boost::any_cast>(f.getSubformula().accept(*this, data)); + return std::static_pointer_cast(std::make_shared(subformula, f.getOperatorInformation())); + } + + boost::any CloneVisitor::visit(RewardOperatorFormula const& f, boost::any const& data) const { + std::shared_ptr subformula = boost::any_cast>(f.getSubformula().accept(*this, data)); + return std::static_pointer_cast(std::make_shared(subformula, f.getOptionalRewardModelName(), f.getOperatorInformation())); + } + + boost::any CloneVisitor::visit(UnaryBooleanStateFormula const& f, boost::any const& data) const { + std::shared_ptr subformula = boost::any_cast>(f.getSubformula().accept(*this, data)); + return std::static_pointer_cast(std::make_shared(f.getOperator(), subformula)); + } + + boost::any CloneVisitor::visit(UntilFormula const& f, boost::any const& data) const { + std::shared_ptr left = boost::any_cast>(f.getLeftSubformula().accept(*this, data)); + std::shared_ptr right = boost::any_cast>(f.getRightSubformula().accept(*this, data)); + return std::static_pointer_cast(std::make_shared(left, right)); + } + + } +} diff --git a/src/logic/CumulativeRewardFormula.h b/src/logic/CumulativeRewardFormula.h index 6a5e8c035..305c89895 100644 --- a/src/logic/CumulativeRewardFormula.h +++ b/src/logic/CumulativeRewardFormula.h @@ -33,7 +33,7 @@ namespace storm { double getContinuousTimeBound() const; virtual std::shared_ptr substitute(std::map const& substitution) const override; - + private: boost::variant timeBound; }; diff --git a/src/logic/EventuallyFormula.cpp b/src/logic/EventuallyFormula.cpp index 56c9decff..926c68af2 100644 --- a/src/logic/EventuallyFormula.cpp +++ b/src/logic/EventuallyFormula.cpp @@ -10,6 +10,10 @@ namespace storm { STORM_LOG_THROW(context == FormulaContext::Probability || context == FormulaContext::Reward || context == FormulaContext::Time, storm::exceptions::InvalidPropertyException, "Invalid context for formula."); } + FormulaContext const& EventuallyFormula::getContext() const { + return context; + } + bool EventuallyFormula::isEventuallyFormula() const { return true; } diff --git a/src/logic/EventuallyFormula.h b/src/logic/EventuallyFormula.h index f3716fee9..b21c92170 100644 --- a/src/logic/EventuallyFormula.h +++ b/src/logic/EventuallyFormula.h @@ -14,6 +14,8 @@ namespace storm { // Intentionally left empty. } + FormulaContext const& getContext() const; + virtual bool isEventuallyFormula() const override; virtual bool isReachabilityProbabilityFormula() const override; virtual bool isReachabilityRewardFormula() const override; diff --git a/src/logic/Formula.cpp b/src/logic/Formula.cpp index c48548f10..cc8a570be 100644 --- a/src/logic/Formula.cpp +++ b/src/logic/Formula.cpp @@ -3,6 +3,8 @@ #include "src/logic/FragmentChecker.h" #include "src/logic/FormulaInformationVisitor.h" +#include "src/logic/LabelSubstitutionVisitor.h" +#include "src/logic/ToExpressionVisitor.h" namespace storm { namespace logic { @@ -406,6 +408,16 @@ namespace storm { return referencedRewardModels; } + std::shared_ptr Formula::substitute(std::map const& labelSubstitution) const { + LabelSubstitutionVisitor visitor(labelSubstitution); + return visitor.substitute(*this); + } + + storm::expressions::Expression Formula::toExpression() const { + ToExpressionVisitor visitor; + return visitor.toExpression(*this); + } + std::shared_ptr Formula::asSharedPointer() { return this->shared_from_this(); } diff --git a/src/logic/Formula.h b/src/logic/Formula.h index b77ecfb20..3ac205abe 100644 --- a/src/logic/Formula.h +++ b/src/logic/Formula.h @@ -187,6 +187,15 @@ namespace storm { std::shared_ptr asSharedPointer() const; virtual std::shared_ptr substitute(std::map const& substitution) const = 0; + virtual std::shared_ptr substitute(std::map const& labelSubstitution) const; + + /*! + * Takes the formula and converts it to an equivalent expression assuming that only atomic expression formulas + * and boolean connectives appear in the formula. + * + * @return An equivalent expression. + */ + storm::expressions::Expression toExpression() const; std::string toString() const; virtual std::ostream& writeToStream(std::ostream& out) const = 0; diff --git a/src/logic/LabelSubstitutionVisitor.cpp b/src/logic/LabelSubstitutionVisitor.cpp new file mode 100644 index 000000000..ac05c09eb --- /dev/null +++ b/src/logic/LabelSubstitutionVisitor.cpp @@ -0,0 +1,27 @@ +#include "src/logic/LabelSubstitutionVisitor.h" + +#include "src/logic/Formulas.h" + +namespace storm { + namespace logic { + + LabelSubstitutionVisitor::LabelSubstitutionVisitor(std::map const& labelToExpressionMapping) : labelToExpressionMapping(labelToExpressionMapping) { + // Intentionally left empty. + } + + std::shared_ptr LabelSubstitutionVisitor::substitute(Formula const& f) const { + boost::any result = f.accept(*this, boost::any()); + return boost::any_cast>(result); + } + + boost::any LabelSubstitutionVisitor::visit(AtomicLabelFormula const& f, boost::any const& data) const { + auto it = labelToExpressionMapping.find(f.getLabel()); + if (it != labelToExpressionMapping.end()) { + return std::static_pointer_cast(std::make_shared(it->second)); + } else { + return std::static_pointer_cast(std::make_shared(f)); + } + } + + } +} diff --git a/src/logic/LabelSubstitutionVisitor.h b/src/logic/LabelSubstitutionVisitor.h new file mode 100644 index 000000000..dec06231e --- /dev/null +++ b/src/logic/LabelSubstitutionVisitor.h @@ -0,0 +1,29 @@ +#ifndef STORM_LOGIC_LABELSUBSTITUTIONVISITOR_H_ +#define STORM_LOGIC_LABELSUBSTITUTIONVISITOR_H_ + +#include + +#include "src/logic/CloneVisitor.h" + +#include "src/storage/expressions/Expression.h" + +namespace storm { + namespace logic { + + class LabelSubstitutionVisitor : public CloneVisitor { + public: + LabelSubstitutionVisitor(std::map const& labelToExpressionMapping); + + std::shared_ptr substitute(Formula const& f) const; + + virtual boost::any visit(AtomicLabelFormula const& f, boost::any const& data) const override; + + private: + std::map const& labelToExpressionMapping; + }; + + } +} + + +#endif /* STORM_LOGIC_FORMULAINFORMATIONVISITOR_H_ */ \ No newline at end of file diff --git a/src/logic/OperatorFormula.cpp b/src/logic/OperatorFormula.cpp index 4a8e2e215..474a9608b 100644 --- a/src/logic/OperatorFormula.cpp +++ b/src/logic/OperatorFormula.cpp @@ -50,6 +50,10 @@ namespace storm { return true; } + OperatorInformation const& OperatorFormula::getOperatorInformation() const { + return operatorInformation; + } + bool OperatorFormula::hasQualitativeResult() const { return this->hasBound(); } diff --git a/src/logic/OperatorFormula.h b/src/logic/OperatorFormula.h index 447647b99..ee7f7a87a 100644 --- a/src/logic/OperatorFormula.h +++ b/src/logic/OperatorFormula.h @@ -37,7 +37,9 @@ namespace storm { bool hasOptimalityType() const; storm::solver::OptimizationDirection const& getOptimalityType() const; virtual bool isOperatorFormula() const override; - + + OperatorInformation const& getOperatorInformation() const; + virtual bool hasQualitativeResult() const override; virtual bool hasQuantitativeResult() const override; diff --git a/src/logic/ToExpressionVisitor.cpp b/src/logic/ToExpressionVisitor.cpp new file mode 100644 index 000000000..887f9d405 --- /dev/null +++ b/src/logic/ToExpressionVisitor.cpp @@ -0,0 +1,103 @@ +#include "src/logic/ToExpressionVisitor.h" + +#include "src/logic/Formulas.h" + +#include "src/utility/macros.h" +#include "src/exceptions/InvalidOperationException.h" + +namespace storm { + namespace logic { + + storm::expressions::Expression ToExpressionVisitor::toExpression(Formula const& f) const { + boost::any result = f.accept(*this, boost::any()); + return boost::any_cast(result); + } + + boost::any ToExpressionVisitor::visit(AtomicExpressionFormula const& f, boost::any const& data) const { + return f.getExpression(); + } + + boost::any ToExpressionVisitor::visit(AtomicLabelFormula const& f, boost::any const& data) const { + STORM_LOG_THROW(false, storm::exceptions::InvalidOperationException, "Cannot assemble expression from formula that contains illegal elements."); + } + + boost::any ToExpressionVisitor::visit(BinaryBooleanStateFormula const& f, boost::any const& data) const { + storm::expressions::Expression left = boost::any_cast(f.getLeftSubformula().accept(*this, data)); + storm::expressions::Expression right = boost::any_cast(f.getRightSubformula().accept(*this, data)); + switch (f.getOperator()) { + case BinaryBooleanStateFormula::OperatorType::And: + return left && right; + break; + case BinaryBooleanStateFormula::OperatorType::Or: + return left || right; + break; + } + } + + boost::any ToExpressionVisitor::visit(BooleanLiteralFormula const& f, boost::any const& data) const { + return std::static_pointer_cast(std::make_shared(f)); + } + + boost::any ToExpressionVisitor::visit(BoundedUntilFormula const& f, boost::any const& data) const { + STORM_LOG_THROW(false, storm::exceptions::InvalidOperationException, "Cannot assemble expression from formula that contains illegal elements."); + } + + boost::any ToExpressionVisitor::visit(ConditionalFormula const& f, boost::any const& data) const { + STORM_LOG_THROW(false, storm::exceptions::InvalidOperationException, "Cannot assemble expression from formula that contains illegal elements."); + } + + boost::any ToExpressionVisitor::visit(CumulativeRewardFormula const& f, boost::any const& data) const { + STORM_LOG_THROW(false, storm::exceptions::InvalidOperationException, "Cannot assemble expression from formula that contains illegal elements."); + } + + boost::any ToExpressionVisitor::visit(EventuallyFormula const& f, boost::any const& data) const { + STORM_LOG_THROW(false, storm::exceptions::InvalidOperationException, "Cannot assemble expression from formula that contains illegal elements."); + } + + boost::any ToExpressionVisitor::visit(TimeOperatorFormula const& f, boost::any const& data) const { + STORM_LOG_THROW(false, storm::exceptions::InvalidOperationException, "Cannot assemble expression from formula that contains illegal elements."); + } + + boost::any ToExpressionVisitor::visit(GloballyFormula const& f, boost::any const& data) const { + STORM_LOG_THROW(false, storm::exceptions::InvalidOperationException, "Cannot assemble expression from formula that contains illegal elements."); + } + + boost::any ToExpressionVisitor::visit(InstantaneousRewardFormula const& f, boost::any const& data) const { + STORM_LOG_THROW(false, storm::exceptions::InvalidOperationException, "Cannot assemble expression from formula that contains illegal elements."); + } + + boost::any ToExpressionVisitor::visit(LongRunAverageOperatorFormula const& f, boost::any const& data) const { + STORM_LOG_THROW(false, storm::exceptions::InvalidOperationException, "Cannot assemble expression from formula that contains illegal elements."); + } + + boost::any ToExpressionVisitor::visit(LongRunAverageRewardFormula const& f, boost::any const& data) const { + STORM_LOG_THROW(false, storm::exceptions::InvalidOperationException, "Cannot assemble expression from formula that contains illegal elements."); + } + + boost::any ToExpressionVisitor::visit(NextFormula const& f, boost::any const& data) const { + STORM_LOG_THROW(false, storm::exceptions::InvalidOperationException, "Cannot assemble expression from formula that contains illegal elements."); + } + + boost::any ToExpressionVisitor::visit(ProbabilityOperatorFormula const& f, boost::any const& data) const { + STORM_LOG_THROW(false, storm::exceptions::InvalidOperationException, "Cannot assemble expression from formula that contains illegal elements."); + } + + boost::any ToExpressionVisitor::visit(RewardOperatorFormula const& f, boost::any const& data) const { + STORM_LOG_THROW(false, storm::exceptions::InvalidOperationException, "Cannot assemble expression from formula that contains illegal elements."); + } + + boost::any ToExpressionVisitor::visit(UnaryBooleanStateFormula const& f, boost::any const& data) const { + storm::expressions::Expression subexpression = boost::any_cast(f.getSubformula().accept(*this, data)); + switch (f.getOperator()) { + case UnaryBooleanStateFormula::OperatorType::Not: + return !subexpression; + break; + } + } + + boost::any ToExpressionVisitor::visit(UntilFormula const& f, boost::any const& data) const { + STORM_LOG_THROW(false, storm::exceptions::InvalidOperationException, "Cannot assemble expression from formula that contains illegal elements."); + } + + } +} diff --git a/src/logic/ToExpressionVisitor.h b/src/logic/ToExpressionVisitor.h new file mode 100644 index 000000000..a42258a46 --- /dev/null +++ b/src/logic/ToExpressionVisitor.h @@ -0,0 +1,39 @@ +#ifndef STORM_LOGIC_TOEXPRESSIONVISITOR_H_ +#define STORM_LOGIC_TOEXPRESSIONVISITOR_H_ + +#include "src/logic/FormulaVisitor.h" + +#include "src/storage/expressions/Expression.h" + +namespace storm { + namespace logic { + + class ToExpressionVisitor : public FormulaVisitor { + public: + storm::expressions::Expression toExpression(Formula const& f) const; + + virtual boost::any visit(AtomicExpressionFormula const& f, boost::any const& data) const override; + virtual boost::any visit(AtomicLabelFormula const& f, boost::any const& data) const override; + virtual boost::any visit(BinaryBooleanStateFormula const& f, boost::any const& data) const override; + virtual boost::any visit(BooleanLiteralFormula const& f, boost::any const& data) const override; + virtual boost::any visit(BoundedUntilFormula const& f, boost::any const& data) const override; + virtual boost::any visit(ConditionalFormula const& f, boost::any const& data) const override; + virtual boost::any visit(CumulativeRewardFormula const& f, boost::any const& data) const override; + virtual boost::any visit(EventuallyFormula const& f, boost::any const& data) const override; + virtual boost::any visit(TimeOperatorFormula const& f, boost::any const& data) const override; + virtual boost::any visit(GloballyFormula const& f, boost::any const& data) const override; + virtual boost::any visit(InstantaneousRewardFormula const& f, boost::any const& data) const override; + virtual boost::any visit(LongRunAverageOperatorFormula const& f, boost::any const& data) const override; + virtual boost::any visit(LongRunAverageRewardFormula const& f, boost::any const& data) const override; + virtual boost::any visit(NextFormula const& f, boost::any const& data) const override; + virtual boost::any visit(ProbabilityOperatorFormula const& f, boost::any const& data) const override; + virtual boost::any visit(RewardOperatorFormula const& f, boost::any const& data) const override; + virtual boost::any visit(UnaryBooleanStateFormula const& f, boost::any const& data) const override; + virtual boost::any visit(UntilFormula const& f, boost::any const& data) const override; + }; + + } +} + + +#endif /* STORM_LOGIC_TOEXPRESSIONVISITOR_H_ */ \ No newline at end of file diff --git a/src/modelchecker/reachability/SparseMdpLearningModelChecker.cpp b/src/modelchecker/reachability/SparseMdpLearningModelChecker.cpp index a762618f8..d2f1c47bf 100644 --- a/src/modelchecker/reachability/SparseMdpLearningModelChecker.cpp +++ b/src/modelchecker/reachability/SparseMdpLearningModelChecker.cpp @@ -38,7 +38,6 @@ namespace storm { storm::logic::EventuallyFormula const& eventuallyFormula = checkTask.getFormula(); storm::logic::Formula const& subformula = eventuallyFormula.getSubformula(); STORM_LOG_THROW(program.isDeterministicModel() || checkTask.isOptimizationDirectionSet(), storm::exceptions::InvalidPropertyException, "For nondeterministic systems, an optimization direction (min/max) must be given in the property."); - STORM_LOG_THROW(subformula.isAtomicExpressionFormula() || subformula.isAtomicLabelFormula(), storm::exceptions::NotSupportedException, "Learning engine can only deal with formulas of the form 'F \"label\"' or 'F expression'."); StateGeneration stateGeneration(program, variableInformation, getTargetStateExpression(subformula)); @@ -58,13 +57,8 @@ namespace storm { template storm::expressions::Expression SparseMdpLearningModelChecker::getTargetStateExpression(storm::logic::Formula const& subformula) const { - storm::expressions::Expression result; - if (subformula.isAtomicExpressionFormula()) { - result = subformula.asAtomicExpressionFormula().getExpression(); - } else { - result = program.getLabelExpression(subformula.asAtomicLabelFormula().getLabel()); - } - return result; + std::shared_ptr preparedSubformula = subformula.substitute(program.getLabelToExpressionMapping()); + return preparedSubformula->toExpression(); } template @@ -101,7 +95,7 @@ namespace storm { Statistics stats; bool convergenceCriterionMet = false; while (!convergenceCriterionMet) { - bool result = samplePathFromState(stateGeneration, explorationInformation, stack, bounds, stats); + bool result = samplePathFromInitialState(stateGeneration, explorationInformation, stack, bounds, stats); // If a terminal state was found, we update the probabilities along the path contained in the stack. if (result) { @@ -125,7 +119,7 @@ namespace storm { if (storm::settings::generalSettings().isShowStatisticsSet()) { std::cout << std::endl << "Learning summary -------------------------" << std::endl; - std::cout << "Discovered states: " << explorationInformation.getNumberOfDiscoveredStates() << " (" << stats.numberOfExploredStates << " explored, " << explorationInformation.getNumberOfUnexploredStates() << " unexplored, " << stats.numberOfTargetStates << " target states)" << std::endl; + std::cout << "Discovered states: " << explorationInformation.getNumberOfDiscoveredStates() << " (" << stats.numberOfExploredStates << " explored, " << explorationInformation.getNumberOfUnexploredStates() << " unexplored, " << stats.numberOfTargetStates << " target)" << std::endl; std::cout << "Sampling iterations: " << stats.iterations << std::endl; std::cout << "Maximal path length: " << stats.maxPathLength << std::endl; } @@ -134,11 +128,11 @@ namespace storm { } template - bool SparseMdpLearningModelChecker::samplePathFromState(StateGeneration& stateGeneration, ExplorationInformation& explorationInformation, StateActionStack& stack, BoundValues& bounds, Statistics& stats) const { - + bool SparseMdpLearningModelChecker::samplePathFromInitialState(StateGeneration& stateGeneration, ExplorationInformation& explorationInformation, StateActionStack& stack, BoundValues& bounds, Statistics& stats) const { // Start the search from the initial state. stack.push_back(std::make_pair(explorationInformation.getFirstInitialState(), 0)); + // As long as we didn't find a terminal (accepting or rejecting) state in the search, sample a new successor. bool foundTerminalState = false; while (!foundTerminalState) { StateType const& currentStateId = stack.back().first; @@ -165,7 +159,7 @@ namespace storm { if (!foundTerminalState) { // At this point, we can be sure that the state was expanded and that we can sample according to the // probabilities in the matrix. - uint32_t chosenAction = sampleMaxAction(currentStateId, explorationInformation, bounds); + uint32_t chosenAction = sampleActionOfState(currentStateId, explorationInformation, bounds); stack.back().second = chosenAction; STORM_LOG_TRACE("Sampled action " << chosenAction << " in state " << currentStateId << "."); @@ -242,7 +236,9 @@ namespace storm { explorationInformation.addRowsToMatrix(behavior.getNumberOfChoices()); ActionType currentAction = 0; - std::pair stateBounds(storm::utility::zero(), storm::utility::zero()); + + // Retrieve the lowest state bounds (wrt. to the current optimization direction). + std::pair stateBounds = getLowestBounds(explorationInformation.optimizationDirection); for (auto const& choice : behavior) { for (auto const& entry : choice) { @@ -251,7 +247,7 @@ namespace storm { std::pair actionBounds = computeBoundsOfAction(startRow + currentAction, explorationInformation, bounds); bounds.initializeBoundsForNextAction(actionBounds); - stateBounds = std::make_pair(std::max(stateBounds.first, actionBounds.first), std::max(stateBounds.second, actionBounds.second)); + stateBounds = combineBounds(explorationInformation.optimizationDirection, stateBounds, actionBounds); STORM_LOG_TRACE("Initializing bounds of action " << (startRow + currentAction) << " to " << bounds.getLowerBoundForAction(startRow + currentAction) << " and " << bounds.getUpperBoundForAction(startRow + currentAction) << "."); @@ -289,16 +285,12 @@ namespace storm { } template - uint32_t SparseMdpLearningModelChecker::sampleMaxAction(StateType const& currentStateId, ExplorationInformation const& explorationInformation, BoundValues& bounds) const { - StateType rowGroup = explorationInformation.getRowGroup(currentStateId); - - // First, determine all maximizing actions. - std::vector allMaxActions; - + uint32_t SparseMdpLearningModelChecker::sampleActionOfState(StateType const& currentStateId, ExplorationInformation const& explorationInformation, BoundValues& bounds) const { // Determine the values of all available actions. std::vector> actionValues; - auto choicesInEcIt = explorationInformation.stateToLeavingChoicesOfEndComponent.find(currentStateId); - if (choicesInEcIt != explorationInformation.stateToLeavingChoicesOfEndComponent.end()) { + StateType rowGroup = explorationInformation.getRowGroup(currentStateId); + auto choicesInEcIt = explorationInformation.stateToLeavingActionsOfEndComponent.find(currentStateId); + if (choicesInEcIt != explorationInformation.stateToLeavingActionsOfEndComponent.end()) { STORM_LOG_TRACE("Sampling from actions leaving the previously detected EC."); for (auto const& row : *choicesInEcIt->second) { actionValues.push_back(std::make_pair(row, computeUpperBoundOfAction(row, explorationInformation, bounds))); @@ -313,8 +305,14 @@ namespace storm { STORM_LOG_ASSERT(!actionValues.empty(), "Values for actions must not be empty."); - std::sort(actionValues.begin(), actionValues.end(), [] (std::pair const& a, std::pair const& b) { return a.second > b.second; } ); + // Sort the actions wrt. to the optimization direction. + if (explorationInformation.optimizationDirection == storm::OptimizationDirection::Maximize) { + std::sort(actionValues.begin(), actionValues.end(), [] (std::pair const& a, std::pair const& b) { return a.second > b.second; } ); + } else { + std::sort(actionValues.begin(), actionValues.end(), [] (std::pair const& a, std::pair const& b) { return a.second < b.second; } ); + } + // Determine the first elements of the sorted range that agree on their value. auto end = ++actionValues.begin(); while (end != actionValues.end() && comparator.isEqual(actionValues.begin()->second, end->second)) { ++end; @@ -416,14 +414,14 @@ namespace storm { StateType originalState = relevantStates[stateAndChoices.first]; uint32_t originalRowGroup = explorationInformation.getRowGroup(originalState); - // TODO: This checks for a target state is a bit hackish and only works for max probabilities. - if (!containsTargetState && comparator.isOne(bounds.getLowerBoundForRowGroup(originalRowGroup, explorationInformation))) { + // Check whether a target state is contained in the MEC. + if (!containsTargetState && comparator.isOne(bounds.getLowerBoundForRowGroup(originalRowGroup))) { containsTargetState = true; } - + + // For each state, compute the actions that leave the MEC. auto includedChoicesIt = stateAndChoices.second.begin(); auto includedChoicesIte = stateAndChoices.second.end(); - for (auto action = explorationInformation.getStartRowOfGroup(originalRowGroup); action < explorationInformation.getStartRowOfGroup(originalRowGroup + 1); ++action) { if (includedChoicesIt != includedChoicesIte) { STORM_LOG_TRACE("Next (local) choice contained in MEC is " << (*includedChoicesIt - relevantStatesMatrix.getRowGroupIndices()[stateAndChoices.first])); @@ -441,7 +439,7 @@ namespace storm { } } - explorationInformation.stateToLeavingChoicesOfEndComponent[originalState] = leavingChoices; + explorationInformation.stateToLeavingActionsOfEndComponent[originalState] = leavingChoices; } // If one of the states of the EC is a target state, all states in the EC have probability 1. @@ -488,28 +486,6 @@ namespace storm { return result; } - template - ValueType SparseMdpLearningModelChecker::computeLowerBoundOfState(StateType const& state, ExplorationInformation const& explorationInformation, BoundValues const& bounds) const { - StateType group = explorationInformation.getRowGroup(state); - ValueType result = storm::utility::zero(); - for (ActionType action = explorationInformation.getStartRowOfGroup(group); action < explorationInformation.getStartRowOfGroup(group + 1); ++action) { - ValueType actionValue = computeLowerBoundOfAction(action, explorationInformation, bounds); - result = std::max(actionValue, result); - } - return result; - } - - template - ValueType SparseMdpLearningModelChecker::computeUpperBoundOfState(StateType const& state, ExplorationInformation const& explorationInformation, BoundValues const& bounds) const { - StateType group = explorationInformation.getRowGroup(state); - ValueType result = storm::utility::zero(); - for (ActionType action = explorationInformation.getStartRowOfGroup(group); action < explorationInformation.getStartRowOfGroup(group + 1); ++action) { - ValueType actionValue = computeUpperBoundOfAction(action, explorationInformation, bounds); - result = std::max(actionValue, result); - } - return result; - } - template std::pair SparseMdpLearningModelChecker::computeBoundsOfAction(ActionType const& action, ExplorationInformation const& explorationInformation, BoundValues const& bounds) const { // TODO: take into account self-loops? @@ -524,11 +500,10 @@ namespace storm { template std::pair SparseMdpLearningModelChecker::computeBoundsOfState(StateType const& currentStateId, ExplorationInformation const& explorationInformation, BoundValues const& bounds) const { StateType group = explorationInformation.getRowGroup(currentStateId); - std::pair result = std::make_pair(storm::utility::zero(), storm::utility::zero()); + std::pair result = getLowestBounds(explorationInformation.optimizationDirection); for (ActionType action = explorationInformation.getStartRowOfGroup(group); action < explorationInformation.getStartRowOfGroup(group + 1); ++action) { std::pair actionValues = computeBoundsOfAction(action, explorationInformation, bounds); - result.first = std::max(actionValues.first, result.first); - result.second = std::max(actionValues.second, result.second); + result = combineBounds(explorationInformation.optimizationDirection, result, actionValues); } return result; } @@ -541,10 +516,44 @@ namespace storm { stack.pop_back(); } } + + template + void SparseMdpLearningModelChecker::updateProbabilityOfAction(StateType const& state, ActionType const& action, ExplorationInformation const& explorationInformation, BoundValues& bounds) const { + // Compute the new lower/upper values of the action. + std::pair newBoundsForAction = computeBoundsOfAction(action, explorationInformation, bounds); + + // And set them as the current value. + bounds.setBoundsForAction(action, newBoundsForAction); + + // Check if we need to update the values for the states. + if (explorationInformation.optimizationDirection == storm::OptimizationDirection::Maximize) { + bounds.setLowerBoundOfStateIfGreaterThanOld(state, explorationInformation, newBoundsForAction.first); + + StateType rowGroup = explorationInformation.getRowGroup(state); + if (newBoundsForAction.second < bounds.getUpperBoundForRowGroup(rowGroup)) { + if (explorationInformation.getRowGroupSize(rowGroup) > 1) { + newBoundsForAction.second = std::max(newBoundsForAction.second, computeBoundOverAllOtherActions(storm::OptimizationDirection::Maximize, state, action, explorationInformation, bounds)); + } + + bounds.setUpperBoundForRowGroup(rowGroup, newBoundsForAction.second); + } + } else { + bounds.setUpperBoundOfStateIfLessThanOld(state, explorationInformation, newBoundsForAction.second); + StateType rowGroup = explorationInformation.getRowGroup(state); + if (bounds.getLowerBoundForRowGroup(rowGroup) < newBoundsForAction.first) { + if (explorationInformation.getRowGroupSize(rowGroup) > 1) { + newBoundsForAction.first = std::min(newBoundsForAction.first, computeBoundOverAllOtherActions(storm::OptimizationDirection::Maximize, state, action, explorationInformation, bounds)); + } + + bounds.setLowerBoundForRowGroup(rowGroup, newBoundsForAction.first); + } + } + } + template - ValueType SparseMdpLearningModelChecker::computeUpperBoundOverAllOtherActions(StateType const& state, ActionType const& action, ExplorationInformation const& explorationInformation, BoundValues const& bounds) const { - ValueType max = storm::utility::zero(); + ValueType SparseMdpLearningModelChecker::computeBoundOverAllOtherActions(storm::OptimizationDirection const& direction, StateType const& state, ActionType const& action, ExplorationInformation const& explorationInformation, BoundValues const& bounds) const { + ValueType bound = getLowestBound(explorationInformation.optimizationDirection); ActionType group = explorationInformation.getRowGroup(state); for (auto currentAction = explorationInformation.getStartRowOfGroup(group); currentAction < explorationInformation.getStartRowOfGroup(group + 1); ++currentAction) { @@ -552,30 +561,36 @@ namespace storm { continue; } - max = std::max(max, computeUpperBoundOfAction(currentAction, explorationInformation, bounds)); + if (direction == storm::OptimizationDirection::Maximize) { + bound = std::max(bound, computeUpperBoundOfAction(currentAction, explorationInformation, bounds)); + } else { + bound = std::min(bound, computeLowerBoundOfAction(currentAction, explorationInformation, bounds)); + } } - - return max; + return bound; } template - void SparseMdpLearningModelChecker::updateProbabilityOfAction(StateType const& state, ActionType const& action, ExplorationInformation const& explorationInformation, BoundValues& bounds) const { - // Compute the new lower/upper values of the action. - std::pair newBoundsForAction = computeBoundsOfAction(action, explorationInformation, bounds); - - // And set them as the current value. - bounds.setBoundsForAction(action, newBoundsForAction); - - // Check if we need to update the values for the states. - bounds.setNewLowerBoundOfStateIfGreaterThanOld(state, explorationInformation, newBoundsForAction.first); - - StateType rowGroup = explorationInformation.getRowGroup(state); - if (newBoundsForAction.second < bounds.getUpperBoundForRowGroup(rowGroup)) { - if (explorationInformation.getRowGroupSize(rowGroup) > 1) { - newBoundsForAction.second = std::max(newBoundsForAction.second, computeUpperBoundOverAllOtherActions(state, action, explorationInformation, bounds)); - } - - bounds.setUpperBoundForState(state, explorationInformation, newBoundsForAction.second); + std::pair SparseMdpLearningModelChecker::getLowestBounds(storm::OptimizationDirection const& direction) const { + ValueType val = getLowestBound(direction); + return std::make_pair(val, val); + } + + template + ValueType SparseMdpLearningModelChecker::getLowestBound(storm::OptimizationDirection const& direction) const { + if (direction == storm::OptimizationDirection::Maximize) { + return storm::utility::zero(); + } else { + return storm::utility::one(); + } + } + + template + std::pair SparseMdpLearningModelChecker::combineBounds(storm::OptimizationDirection const& direction, std::pair const& bounds1, std::pair const& bounds2) const { + if (direction == storm::OptimizationDirection::Maximize) { + return std::make_pair(std::max(bounds1.first, bounds2.first), std::max(bounds1.second, bounds2.second)); + } else { + return std::make_pair(std::min(bounds1.first, bounds2.first), std::min(bounds1.second, bounds2.second)); } } diff --git a/src/modelchecker/reachability/SparseMdpLearningModelChecker.h b/src/modelchecker/reachability/SparseMdpLearningModelChecker.h index 01ea2d323..c3a931913 100644 --- a/src/modelchecker/reachability/SparseMdpLearningModelChecker.h +++ b/src/modelchecker/reachability/SparseMdpLearningModelChecker.h @@ -101,7 +101,7 @@ namespace storm { storm::OptimizationDirection optimizationDirection; StateSet terminalStates; - std::unordered_map stateToLeavingChoicesOfEndComponent; + std::unordered_map stateToLeavingActionsOfEndComponent; void setInitialStates(std::vector const& initialStates) { stateStorage.initialStateIndices = initialStates; @@ -208,11 +208,11 @@ namespace storm { if (index == explorationInformation.getUnexploredMarker()) { return storm::utility::zero(); } else { - return getLowerBoundForRowGroup(index, explorationInformation); + return getLowerBoundForRowGroup(index); } } - ValueType getLowerBoundForRowGroup(StateType const& rowGroup, ExplorationInformation const& explorationInformation) const { + ValueType const& getLowerBoundForRowGroup(StateType const& rowGroup) const { return lowerBoundsPerState[rowGroup]; } @@ -257,11 +257,19 @@ namespace storm { } void setLowerBoundForState(StateType const& state, ExplorationInformation const& explorationInformation, ValueType const& value) { - lowerBoundsPerState[explorationInformation.getRowGroup(state)] = value; + setLowerBoundForRowGroup(explorationInformation.getRowGroup(state), value); + } + + void setLowerBoundForRowGroup(StateType const& group, ValueType const& value) { + lowerBoundsPerState[group] = value; } void setUpperBoundForState(StateType const& state, ExplorationInformation const& explorationInformation, ValueType const& value) { - upperBoundsPerState[explorationInformation.getRowGroup(state)] = value; + setUpperBoundForRowGroup(explorationInformation.getRowGroup(state), value); + } + + void setUpperBoundForRowGroup(StateType const& group, ValueType const& value) { + upperBoundsPerState[group] = value; } void setBoundsForAction(ActionType const& action, std::pair const& values) { @@ -275,7 +283,7 @@ namespace storm { upperBoundsPerState[rowGroup] = values.second; } - bool setNewLowerBoundOfStateIfGreaterThanOld(StateType const& state, ExplorationInformation const& explorationInformation, ValueType const& newLowerValue) { + bool setLowerBoundOfStateIfGreaterThanOld(StateType const& state, ExplorationInformation const& explorationInformation, ValueType const& newLowerValue) { StateType const& rowGroup = explorationInformation.getRowGroup(state); if (lowerBoundsPerState[rowGroup] < newLowerValue) { lowerBoundsPerState[rowGroup] = newLowerValue; @@ -284,7 +292,7 @@ namespace storm { return false; } - bool setNewUpperBoundOfStateIfLessThanOld(StateType const& state, ExplorationInformation const& explorationInformation, ValueType const& newUpperValue) { + bool setUpperBoundOfStateIfLessThanOld(StateType const& state, ExplorationInformation const& explorationInformation, ValueType const& newUpperValue) { StateType const& rowGroup = explorationInformation.getRowGroup(state); if (newUpperValue < upperBoundsPerState[rowGroup]) { upperBoundsPerState[rowGroup] = newUpperValue; @@ -300,11 +308,11 @@ namespace storm { std::tuple performLearningProcedure(StateGeneration& stateGeneration, ExplorationInformation& explorationInformation) const; - bool samplePathFromState(StateGeneration& stateGeneration, ExplorationInformation& explorationInformation, StateActionStack& stack, BoundValues& bounds, Statistics& stats) const; + bool samplePathFromInitialState(StateGeneration& stateGeneration, ExplorationInformation& explorationInformation, StateActionStack& stack, BoundValues& bounds, Statistics& stats) const; bool exploreState(StateGeneration& stateGeneration, StateType const& currentStateId, storm::generator::CompressedState const& currentState, ExplorationInformation& explorationInformation, BoundValues& bounds, Statistics& stats) const; - uint32_t sampleMaxAction(StateType const& currentStateId, ExplorationInformation const& explorationInformation, BoundValues& bounds) const; + uint32_t sampleActionOfState(StateType const& currentStateId, ExplorationInformation const& explorationInformation, BoundValues& bounds) const; StateType sampleSuccessorFromAction(ActionType const& chosenAction, ExplorationInformation const& explorationInformation) const; @@ -315,12 +323,14 @@ namespace storm { void updateProbabilityOfAction(StateType const& state, ActionType const& action, ExplorationInformation const& explorationInformation, BoundValues& bounds) const; std::pair computeBoundsOfAction(ActionType const& action, ExplorationInformation const& explorationInformation, BoundValues const& bounds) const; - ValueType computeUpperBoundOverAllOtherActions(StateType const& state, ActionType const& action, ExplorationInformation const& explorationInformation, BoundValues const& bounds) const; + ValueType computeBoundOverAllOtherActions(storm::OptimizationDirection const& direction, StateType const& state, ActionType const& action, ExplorationInformation const& explorationInformation, BoundValues const& bounds) const; std::pair computeBoundsOfState(StateType const& currentStateId, ExplorationInformation const& explorationInformation, BoundValues const& bounds) const; ValueType computeLowerBoundOfAction(ActionType const& action, ExplorationInformation const& explorationInformation, BoundValues const& bounds) const; ValueType computeUpperBoundOfAction(ActionType const& action, ExplorationInformation const& explorationInformation, BoundValues const& bounds) const; - ValueType computeLowerBoundOfState(StateType const& action, ExplorationInformation const& explorationInformation, BoundValues const& bounds) const; - ValueType computeUpperBoundOfState(StateType const& action, ExplorationInformation const& explorationInformation, BoundValues const& bounds) const; + + std::pair getLowestBounds(storm::OptimizationDirection const& direction) const; + ValueType getLowestBound(storm::OptimizationDirection const& direction) const; + std::pair combineBounds(storm::OptimizationDirection const& direction, std::pair const& bounds1, std::pair const& bounds2) const; // The program that defines the model to check. storm::prism::Program program; diff --git a/src/storage/prism/Program.cpp b/src/storage/prism/Program.cpp index 999cf2f30..b94f288c0 100644 --- a/src/storage/prism/Program.cpp +++ b/src/storage/prism/Program.cpp @@ -368,6 +368,14 @@ namespace storm { return this->labels[labelIndexPair->second].getStatePredicateExpression(); } + std::map Program::getLabelToExpressionMapping() const { + std::map result; + for (auto const& label : labels) { + result.emplace(label.getName(), label.getStatePredicateExpression()); + } + return result; + } + std::size_t Program::getNumberOfLabels() const { return this->getLabels().size(); } diff --git a/src/storage/prism/Program.h b/src/storage/prism/Program.h index 1d091c294..3e8245a72 100644 --- a/src/storage/prism/Program.h +++ b/src/storage/prism/Program.h @@ -379,6 +379,13 @@ namespace storm { */ storm::expressions::Expression const& getLabelExpression(std::string const& label) const; + /*! + * Retrieves a mapping from all labels in the program to their defining expressions. + * + * @return A mapping from label names to their expressions. + */ + std::map getLabelToExpressionMapping() const; + /*! * Retrieves the number of labels in the program. *