From 1b7f150e76efb5b90b3997807775ea47dd939d8e Mon Sep 17 00:00:00 2001 From: TimQu Date: Tue, 18 Sep 2018 22:17:19 +0200 Subject: [PATCH] implemented functionality to rename reward model names --- src/storm/logic/Formula.cpp | 6 ++ src/storm/logic/Formula.h | 1 + .../RewardModelNameSubstitutionVisitor.cpp | 96 +++++++++++++++++++ .../RewardModelNameSubstitutionVisitor.h | 30 ++++++ src/storm/logic/TimeBoundType.h | 5 + src/storm/storage/jani/Property.cpp | 10 ++ src/storm/storage/jani/Property.h | 12 +++ 7 files changed, 160 insertions(+) create mode 100644 src/storm/logic/RewardModelNameSubstitutionVisitor.cpp create mode 100644 src/storm/logic/RewardModelNameSubstitutionVisitor.h diff --git a/src/storm/logic/Formula.cpp b/src/storm/logic/Formula.cpp index aa27d8dbf..e98806177 100644 --- a/src/storm/logic/Formula.cpp +++ b/src/storm/logic/Formula.cpp @@ -6,6 +6,7 @@ #include "storm/storage/jani/expressions/JaniExpressionSubstitutionVisitor.h" #include "storm/logic/ExpressionSubstitutionVisitor.h" #include "storm/logic/LabelSubstitutionVisitor.h" +#include "storm/logic/RewardModelNameSubstitutionVisitor.h" #include "storm/logic/ToExpressionVisitor.h" namespace storm { @@ -458,6 +459,11 @@ namespace storm { return visitor.substitute(*this); } + std::shared_ptr Formula::substituteRewardModelNames(std::map const& rewardModelNameSubstitution) const { + RewardModelNameSubstitutionVisitor visitor(rewardModelNameSubstitution); + return visitor.substitute(*this); + } + storm::expressions::Expression Formula::toExpression(storm::expressions::ExpressionManager const& manager, std::map const& labelToExpressionMapping) const { ToExpressionVisitor visitor; if (labelToExpressionMapping.empty()) { diff --git a/src/storm/logic/Formula.h b/src/storm/logic/Formula.h index 44826ff9e..fc2570666 100644 --- a/src/storm/logic/Formula.h +++ b/src/storm/logic/Formula.h @@ -201,6 +201,7 @@ namespace storm { std::shared_ptr substitute(std::function const& expressionSubstitution) const; std::shared_ptr substitute(std::map const& labelSubstitution) const; std::shared_ptr substitute(std::map const& labelSubstitution) const; + std::shared_ptr substituteRewardModelNames(std::map const& rewardModelNameSubstitution) const; /*! * Takes the formula and converts it to an equivalent expression. The formula may contain atomic labels, but diff --git a/src/storm/logic/RewardModelNameSubstitutionVisitor.cpp b/src/storm/logic/RewardModelNameSubstitutionVisitor.cpp new file mode 100644 index 000000000..314e6c74d --- /dev/null +++ b/src/storm/logic/RewardModelNameSubstitutionVisitor.cpp @@ -0,0 +1,96 @@ +#include "storm/logic/RewardModelNameSubstitutionVisitor.h" +#include "storm/logic/Formulas.h" + +#include "storm/storage/jani/Model.h" +#include "storm/storage/jani/traverser/AssignmentsFinder.h" +#include "storm/utility/macros.h" + +#include "storm/exceptions/UnexpectedException.h" +#include "storm/exceptions/InvalidPropertyException.h" + +namespace storm { + namespace logic { + + RewardModelNameSubstitutionVisitor::RewardModelNameSubstitutionVisitor(std::map const& rewardModelNameMapping) : rewardModelNameMapping(rewardModelNameMapping) { + // Intentionally left empty + } + + std::shared_ptr RewardModelNameSubstitutionVisitor::substitute(Formula const& f) const { + boost::any result = f.accept(*this, boost::any()); + return boost::any_cast>(result); + } + + boost::any RewardModelNameSubstitutionVisitor::visit(BoundedUntilFormula const& f, boost::any const& data) const { + std::vector> lowerBounds, upperBounds; + std::vector timeBoundReferences; + for (uint64_t i = 0; i < f.getDimension(); ++i) { + if (f.hasLowerBound(i)) { + lowerBounds.emplace_back(TimeBound(f.isLowerBoundStrict(i), f.getLowerBound(i))); + } else { + lowerBounds.emplace_back(); + } + if (f.hasUpperBound(i)) { + upperBounds.emplace_back(TimeBound(f.isUpperBoundStrict(i), f.getUpperBound(i))); + } else { + upperBounds.emplace_back(); + } + auto const& tbr = f.getTimeBoundReference(i); + if (tbr.isRewardBound()) { + timeBoundReferences.emplace_back(getNewName(tbr.getRewardName()), tbr.getOptionalRewardAccumulation()); + } else { + timeBoundReferences.push_back(tbr); + } + } + if (f.hasMultiDimensionalSubformulas()) { + std::vector> leftSubformulas, rightSubformulas; + for (uint64_t i = 0; i < f.getDimension(); ++i) { + leftSubformulas.push_back(boost::any_cast>(f.getLeftSubformula(i).accept(*this, data))); + rightSubformulas.push_back(boost::any_cast>(f.getRightSubformula(i).accept(*this, data))); + } + return std::static_pointer_cast(std::make_shared(leftSubformulas, rightSubformulas, lowerBounds, upperBounds, timeBoundReferences)); + } else { + std::shared_ptr left = boost::any_cast>(f.getLeftSubformula().accept(*this, data)); + std::shared_ptr right = boost::any_cast>(f.getRightSubformula().accept(*this, data)); + return std::static_pointer_cast(std::make_shared(left, right, lowerBounds, upperBounds, timeBoundReferences)); + } + } + + boost::any RewardModelNameSubstitutionVisitor::visit(CumulativeRewardFormula const& f, boost::any const& data) const { + std::vector bounds; + std::vector timeBoundReferences; + for (uint64_t i = 0; i < f.getDimension(); ++i) { + bounds.emplace_back(TimeBound(f.isBoundStrict(i), f.getBound(i))); + storm::logic::TimeBoundReference tbr = f.getTimeBoundReference(i); + if (tbr.isRewardBound()) { + tbr = storm::logic::TimeBoundReference(getNewName(tbr.getRewardName()), tbr.getOptionalRewardAccumulation()); + } + timeBoundReferences.push_back(std::move(tbr)); + } + if (f.hasRewardAccumulation()) { + return std::static_pointer_cast(std::make_shared(bounds, timeBoundReferences, f.getRewardAccumulation())); + } else { + return std::static_pointer_cast(std::make_shared(bounds, timeBoundReferences)); + } + } + + boost::any RewardModelNameSubstitutionVisitor::visit(RewardOperatorFormula const& f, boost::any const& data) const { + std::shared_ptr subformula = boost::any_cast>(f.getSubformula().accept(*this, data)); + if (f.hasRewardModelName()) { + return std::static_pointer_cast(std::make_shared(subformula, getNewName(f.getRewardModelName()), f.getOperatorInformation())); + } else { + return std::static_pointer_cast(std::make_shared(subformula, boost::none, f.getOperatorInformation())); + } + } + + std::string const& RewardModelNameSubstitutionVisitor::getNewName(std::string const& oldName) const { + auto nameIt = rewardModelNameMapping.find(oldName); + if (nameIt == rewardModelNameMapping.end()) { + return oldName; + } else { + return nameIt->second; + } + } + + + } +} diff --git a/src/storm/logic/RewardModelNameSubstitutionVisitor.h b/src/storm/logic/RewardModelNameSubstitutionVisitor.h new file mode 100644 index 000000000..759238738 --- /dev/null +++ b/src/storm/logic/RewardModelNameSubstitutionVisitor.h @@ -0,0 +1,30 @@ +#pragma once + +#include + +#include "storm/logic/CloneVisitor.h" + +#include "storm/storage/expressions/Expression.h" + +namespace storm { + namespace logic { + + class RewardModelNameSubstitutionVisitor : public CloneVisitor { + public: + RewardModelNameSubstitutionVisitor(std::map const& rewardModelNameMapping); + + std::shared_ptr substitute(Formula const& f) const; + + virtual boost::any visit(BoundedUntilFormula const& f, boost::any const& data) const override; + virtual boost::any visit(CumulativeRewardFormula const& f, boost::any const& data) const override; + virtual boost::any visit(RewardOperatorFormula const& f, boost::any const& data) const override; + + private: + + std::string const& getNewName(std::string const& oldName) const; + + std::map const& rewardModelNameMapping; + }; + + } +} diff --git a/src/storm/logic/TimeBoundType.h b/src/storm/logic/TimeBoundType.h index ec9a753b0..149f76215 100644 --- a/src/storm/logic/TimeBoundType.h +++ b/src/storm/logic/TimeBoundType.h @@ -59,6 +59,11 @@ namespace storm { return rewardAccumulation.get(); } + boost::optional const& getOptionalRewardAccumulation() const { + assert(isRewardBound()); + return rewardAccumulation; + } + }; diff --git a/src/storm/storage/jani/Property.cpp b/src/storm/storage/jani/Property.cpp index 2da62b33e..df6a71850 100644 --- a/src/storm/storage/jani/Property.cpp +++ b/src/storm/storage/jani/Property.cpp @@ -38,6 +38,16 @@ namespace storm { return Property(name, filterExpression.substituteLabels(substitution), comment); } + Property Property::substituteRewardModelNames(std::map const& rewardModelNameSubstitution) const { + return Property(name, filterExpression.substituteRewardModelNames(rewardModelNameSubstitution), comment); + } + + Property Property::clone() const { + return Property(name, filterExpression.clone(), comment); + } + + + FilterExpression const& Property::getFilter() const { return this->filterExpression; } diff --git a/src/storm/storage/jani/Property.h b/src/storm/storage/jani/Property.h index 140afa3e7..e552f9940 100644 --- a/src/storm/storage/jani/Property.h +++ b/src/storm/storage/jani/Property.h @@ -5,6 +5,7 @@ #include "storm/modelchecker/results/FilterType.h" #include "storm/logic/Formulas.h" #include "storm/logic/FragmentSpecification.h" +#include "storm/logic/CloneVisitor.h" #include "storm/utility/macros.h" #include "storm/exceptions/InvalidArgumentException.h" @@ -64,6 +65,15 @@ namespace storm { return FilterExpression(formula->substitute(labelSubstitution), ft, statesFormula->substitute(labelSubstitution)); } + FilterExpression substituteRewardModelNames(std::map const& rewardModelNameSubstitution) const { + return FilterExpression(formula->substituteRewardModelNames(rewardModelNameSubstitution), ft, statesFormula->substituteRewardModelNames(rewardModelNameSubstitution)); + } + + FilterExpression clone() const { + storm::logic::CloneVisitor cv; + return FilterExpression(cv.clone(*formula), ft, cv.clone(*statesFormula)); + } + private: // For now, we assume that the states are always the initial states. std::shared_ptr formula; @@ -111,6 +121,8 @@ namespace storm { Property substitute(std::map const& substitution) const; Property substitute(std::function const& substitutionFunction) const; Property substituteLabels(std::map const& labelSubstitution) const; + Property substituteRewardModelNames(std::map const& rewardModelNameSubstitution) const; + Property clone() const; FilterExpression const& getFilter() const;