From 8142a8e004f8c0fa037555bf43ead3d25baf8e3e Mon Sep 17 00:00:00 2001 From: sjunges Date: Thu, 17 Apr 2014 14:36:57 +0200 Subject: [PATCH] some fixes for using something different from doubles for templated value type :) Former-commit-id: d26d06b2659c0861c4e904bed40d64a925595cc8 --- src/adapters/ExplicitModelAdapter.h | 54 +++++----- src/models/Ctmdp.h | 4 +- src/models/Dtmc.h | 27 +++-- src/models/Mdp.h | 6 +- .../expressions/ExpressionEvaluation.h | 99 +++++++++++++++++++ src/utility/constants.h | 16 +++ 6 files changed, 164 insertions(+), 42 deletions(-) create mode 100644 src/storage/expressions/ExpressionEvaluation.h diff --git a/src/adapters/ExplicitModelAdapter.h b/src/adapters/ExplicitModelAdapter.h index 777807c06..c1569be9f 100644 --- a/src/adapters/ExplicitModelAdapter.h +++ b/src/adapters/ExplicitModelAdapter.h @@ -30,10 +30,11 @@ #include "src/settings/Settings.h" #include "src/exceptions/ExceptionMacros.h" #include "src/exceptions/WrongFormatException.h" +#include "src/storage/expressions/ExpressionEvaluation.h" namespace storm { namespace adapters { - + using namespace storm::utility::prism; template @@ -321,7 +322,7 @@ namespace storm { return result; } - static std::list> getUnlabeledTransitions(storm::prism::Program const& program, StateInformation& stateInformation, VariableInformation const& variableInformation, uint_fast64_t stateIndex, std::queue& stateQueue) { + static std::list> getUnlabeledTransitions(storm::prism::Program const& program, StateInformation& stateInformation, VariableInformation const& variableInformation, uint_fast64_t stateIndex, std::queue& stateQueue, expressions::ExpressionEvaluation& eval) { std::list> result; StateType const* currentState = stateInformation.reachableStates[stateIndex]; @@ -346,7 +347,7 @@ namespace storm { Choice& choice = result.back(); choice.addChoiceLabel(command.getGlobalIndex()); - double probabilitySum = 0; + ValueType probabilitySum = utility::constantZero(); // Iterate over all updates of the current command. for (uint_fast64_t k = 0; k < command.getNumberOfUpdates(); ++k) { storm::prism::Update const& update = command.getUpdate(k); @@ -360,7 +361,7 @@ namespace storm { } // Update the choice by adding the probability/target state to it. - double probabilityToAdd = update.getLikelihoodExpression().evaluateAsDouble(currentState); + ValueType probabilityToAdd = eval.evaluate(update.getLikelihoodExpression(),currentState); probabilitySum += probabilityToAdd; boost::container::flat_set labels; labels.insert(update.getGlobalIndex()); @@ -368,14 +369,14 @@ namespace storm { } // Check that the resulting distribution is in fact a distribution. - LOG_THROW(std::abs(1 - probabilitySum) < storm::settings::Settings::getInstance()->getOptionByLongName("precision").getArgument(0).getValueAsDouble(), storm::exceptions::WrongFormatException, "Probabilities do not sum to one for command '" << command << "'."); + LOG_THROW(!storm::utility::isOne(probabilitySum), storm::exceptions::WrongFormatException, "Probabilities do not sum to one for command '" << command << "'."); } } return result; } - static std::list> getLabeledTransitions(storm::prism::Program const& program, StateInformation& stateInformation, VariableInformation const& variableInformation, uint_fast64_t stateIndex, std::queue& stateQueue) { + static std::list> getLabeledTransitions(storm::prism::Program const& program, StateInformation& stateInformation, VariableInformation const& variableInformation, uint_fast64_t stateIndex, std::queue& stateQueue, expressions::ExpressionEvaluation& eval) { std::list> result; for (std::string const& action : program.getActions()) { @@ -395,9 +396,9 @@ namespace storm { // As long as there is one feasible combination of commands, keep on expanding it. bool done = false; while (!done) { - std::unordered_map, StateHash, StateCompare>* currentTargetStates = new std::unordered_map, StateHash, StateCompare>(); - std::unordered_map, StateHash, StateCompare>* newTargetStates = new std::unordered_map, StateHash, StateCompare>(); - (*currentTargetStates)[new StateType(*currentState)] = storm::storage::LabeledValues(1.0); + std::unordered_map, StateHash, StateCompare>* currentTargetStates = new std::unordered_map, StateHash, StateCompare>(); + std::unordered_map, StateHash, StateCompare>* newTargetStates = new std::unordered_map, StateHash, StateCompare>(); + (*currentTargetStates)[new StateType(*currentState)] = storm::storage::LabeledValues(ValueType(1)); // FIXME: This does not check whether a global variable is written multiple times. While the // behaviour for this is undefined anyway, a warning should be issued in that case. @@ -410,9 +411,9 @@ namespace storm { for (auto const& stateProbabilityPair : *currentTargetStates) { StateType* newTargetState = applyUpdate(variableInformation, stateProbabilityPair.first, currentState, update); - storm::storage::LabeledValues newProbability; + storm::storage::LabeledValues newProbability; - double updateProbability = update.getLikelihoodExpression().evaluateAsDouble(currentState); + ValueType updateProbability = eval.evaluate(update.getLikelihoodExpression(),currentState); for (auto const& valueLabelSetPair : stateProbabilityPair.second) { // Copy the label set, so we can modify it. boost::container::flat_set newLabelSet = valueLabelSetPair.second; @@ -441,7 +442,7 @@ namespace storm { delete currentTargetStates; currentTargetStates = newTargetStates; - newTargetStates = new std::unordered_map, StateHash, StateCompare>(); + newTargetStates = new std::unordered_map, StateHash, StateCompare>(); } } @@ -458,7 +459,7 @@ namespace storm { choice.addChoiceLabel(iteratorList[i]->getGlobalIndex()); } - double probabilitySum = 0; + ValueType probabilitySum = utility::constantZero(); for (auto const& stateProbabilityPair : *newTargetStates) { std::pair flagTargetStateIndexPair = getOrAddStateIndex(stateProbabilityPair.first, stateInformation); @@ -474,7 +475,7 @@ namespace storm { } // Check that the resulting distribution is in fact a distribution. - if (std::abs(1 - probabilitySum) > storm::settings::Settings::getInstance()->getOptionByLongName("precision").getArgument(0).getValueAsDouble()) { + if (!storm::utility::isOne(probabilitySum)) { LOG4CPLUS_ERROR(logger, "Sum of update probabilities do not some to one for some command."); throw storm::exceptions::WrongFormatException() << "Sum of update probabilities do not some to one for some command."; } @@ -519,7 +520,7 @@ namespace storm { * @return A tuple containing a vector with all rows at which the nondeterministic choices of each state begin * and a vector containing the labels associated with each choice. */ - static std::vector> buildMatrices(storm::prism::Program const& program, VariableInformation const& variableInformation, std::vector const& transitionRewards, StateInformation& stateInformation, bool deterministicModel, storm::storage::SparseMatrixBuilder& transitionMatrixBuilder, storm::storage::SparseMatrixBuilder& transitionRewardMatrixBuilder) { + static std::vector> buildMatrices(storm::prism::Program const& program, VariableInformation const& variableInformation, std::vector const& transitionRewards, StateInformation& stateInformation, bool deterministicModel, storm::storage::SparseMatrixBuilder& transitionMatrixBuilder, storm::storage::SparseMatrixBuilder& transitionRewardMatrixBuilder, expressions::ExpressionEvaluation& eval) { std::vector> choiceLabels; // Initialize a queue and insert the initial state. @@ -550,8 +551,8 @@ namespace storm { uint_fast64_t currentState = stateQueue.front(); // Retrieve all choices for the current state. - std::list> allUnlabeledChoices = getUnlabeledTransitions(program, stateInformation, variableInformation, currentState, stateQueue); - std::list> allLabeledChoices = getLabeledTransitions(program, stateInformation, variableInformation, currentState, stateQueue); + std::list> allUnlabeledChoices = getUnlabeledTransitions(program, stateInformation, variableInformation, currentState, stateQueue, eval); + std::list> allLabeledChoices = getLabeledTransitions(program, stateInformation, variableInformation, currentState, stateQueue, eval); uint_fast64_t totalNumberOfChoices = allUnlabeledChoices.size() + allLabeledChoices.size(); @@ -582,7 +583,7 @@ namespace storm { // Now add all rewards that match this choice. for (auto const& transitionReward : transitionRewards) { if (transitionReward.getActionName() == "" && transitionReward.getStatePredicateExpression().evaluateAsBool(stateInformation.reachableStates.at(currentState))) { - stateToRewardMap[stateProbabilityPair.first] += ValueType(transitionReward.getRewardValueExpression().evaluateAsDouble(stateInformation.reachableStates.at(currentState))); + stateToRewardMap[stateProbabilityPair.first] += eval.evaluate(transitionReward.getRewardValueExpression(),stateInformation.reachableStates.at(currentState)); } } } @@ -595,7 +596,7 @@ namespace storm { // Now add all rewards that match this choice. for (auto const& transitionReward : transitionRewards) { if (transitionReward.getActionName() == choice.getActionLabel() && transitionReward.getStatePredicateExpression().evaluateAsBool(stateInformation.reachableStates.at(currentState))) { - stateToRewardMap[stateProbabilityPair.first] += ValueType(transitionReward.getRewardValueExpression().evaluateAsDouble(stateInformation.reachableStates.at(currentState))); + stateToRewardMap[stateProbabilityPair.first] += eval.evaluate(transitionReward.getRewardValueExpression(),stateInformation.reachableStates.at(currentState)); } } } @@ -633,7 +634,7 @@ namespace storm { // Now add all rewards that match this choice. for (auto const& transitionReward : transitionRewards) { if (transitionReward.getActionName() == "" && transitionReward.getStatePredicateExpression().evaluateAsBool(stateInformation.reachableStates.at(currentState))) { - stateToRewardMap[stateProbabilityPair.first] += ValueType(transitionReward.getRewardValueExpression().evaluateAsDouble(stateInformation.reachableStates.at(currentState))); + stateToRewardMap[stateProbabilityPair.first] += eval.evaluate(transitionReward.getRewardValueExpression(),stateInformation.reachableStates.at(currentState)); } } @@ -660,7 +661,7 @@ namespace storm { // Now add all rewards that match this choice. for (auto const& transitionReward : transitionRewards) { if (transitionReward.getActionName() == choice.getActionLabel() && transitionReward.getStatePredicateExpression().evaluateAsBool(stateInformation.reachableStates.at(currentState))) { - stateToRewardMap[stateProbabilityPair.first] += ValueType(transitionReward.getRewardValueExpression().evaluateAsDouble(stateInformation.reachableStates.at(currentState))); + stateToRewardMap[stateProbabilityPair.first] += eval.evaluate(transitionReward.getRewardValueExpression(),stateInformation.reachableStates.at(currentState)); } } @@ -694,6 +695,7 @@ namespace storm { */ static ModelComponents buildModelComponents(storm::prism::Program const& program, std::string const& rewardModelName) { ModelComponents modelComponents; + expressions::ExpressionEvaluation eval; VariableInformation variableInformation; for (auto const& integerVariable : program.getGlobalIntegerVariables()) { @@ -718,7 +720,7 @@ namespace storm { // Build the transition and reward matrices. storm::storage::SparseMatrixBuilder transitionMatrixBuilder(0, 0, 0, !deterministicModel, 0); storm::storage::SparseMatrixBuilder transitionRewardMatrixBuilder(0, 0, 0, !deterministicModel, 0); - modelComponents.choiceLabeling = buildMatrices(program, variableInformation, rewardModel.getTransitionRewards(), stateInformation, deterministicModel, transitionMatrixBuilder, transitionRewardMatrixBuilder); + modelComponents.choiceLabeling = buildMatrices(program, variableInformation, rewardModel.getTransitionRewards(), stateInformation, deterministicModel, transitionMatrixBuilder, transitionRewardMatrixBuilder, eval); // Finalize the resulting matrices. modelComponents.transitionMatrix = transitionMatrixBuilder.build(); @@ -728,7 +730,7 @@ namespace storm { modelComponents.stateLabeling = buildStateLabeling(program, variableInformation, stateInformation); // Finally, construct the state rewards. - modelComponents.stateRewards = buildStateRewards(rewardModel.getStateRewards(), stateInformation); + modelComponents.stateRewards = buildStateRewards(rewardModel.getStateRewards(), stateInformation, eval); // After everything has been created, we can delete the states. for (auto state : stateInformation.reachableStates) { @@ -780,20 +782,22 @@ namespace storm { * @param stateInformation Information about the state space. * @return A vector containing the state rewards for the state space. */ - static std::vector buildStateRewards(std::vector const& rewards, StateInformation const& stateInformation) { + static std::vector buildStateRewards(std::vector const& rewards, StateInformation const& stateInformation, expressions::ExpressionEvaluation& eval) { std::vector result(stateInformation.reachableStates.size()); for (uint_fast64_t index = 0; index < stateInformation.reachableStates.size(); index++) { result[index] = ValueType(0); for (auto const& reward : rewards) { // Add this reward to the state if the state is included in the state reward. if (reward.getStatePredicateExpression().evaluateAsBool(stateInformation.reachableStates[index])) { - result[index] += ValueType(reward.getRewardValueExpression().evaluateAsDouble(stateInformation.reachableStates[index])); + result[index] += eval.evaluate(reward.getRewardValueExpression(),stateInformation.reachableStates[index]); } } } return result; } }; + + } // namespace adapters } // namespace storm diff --git a/src/models/Ctmdp.h b/src/models/Ctmdp.h index 4e690d871..4c5d7da79 100644 --- a/src/models/Ctmdp.h +++ b/src/models/Ctmdp.h @@ -127,12 +127,10 @@ private: */ bool checkValidityOfProbabilityMatrix() { // Get the settings object to customize linear solving. - storm::settings::Settings* s = storm::settings::Settings::getInstance(); - double precision = s->getOptionByLongName("precision").getArgument(0).getValueAsDouble(); for (uint_fast64_t row = 0; row < this->getTransitionMatrix().getRowCount(); row++) { T sum = this->getTransitionMatrix().getRowSum(row); if (sum == 0) continue; - if (std::abs(sum - 1) > precision) return false; + if (storm::utility::isOne(sum)) return false; } return true; } diff --git a/src/models/Dtmc.h b/src/models/Dtmc.h index 8307aa259..f9a3e12a5 100644 --- a/src/models/Dtmc.h +++ b/src/models/Dtmc.h @@ -46,8 +46,8 @@ public: * @param optionalChoiceLabeling A vector that represents the labels associated with the choices of each state. */ Dtmc(storm::storage::SparseMatrix const& probabilityMatrix, storm::models::AtomicPropositionsLabeling const& stateLabeling, - boost::optional> const& optionalStateRewardVector, boost::optional> const& optionalTransitionRewardMatrix, - boost::optional>> const& optionalChoiceLabeling) + boost::optional> const& optionalStateRewardVector = {}, boost::optional> const& optionalTransitionRewardMatrix = {}, + boost::optional>> const& optionalChoiceLabeling = {}) : AbstractDeterministicModel(probabilityMatrix, stateLabeling, optionalStateRewardVector, optionalTransitionRewardMatrix, optionalChoiceLabeling) { if (!this->checkValidityOfProbabilityMatrix()) { LOG4CPLUS_ERROR(logger, "Probability matrix is invalid."); @@ -190,11 +190,11 @@ public: // The number of transitions of the new Dtmc is the number of transitions transfered // from the old one plus one transition for each state to s_b. - storm::storage::SparseMatrixBuilder newMatBuilder(newStateCount, subSysTransitionCount + newStateCount); + storm::storage::SparseMatrixBuilder newMatBuilder(newStateCount,newStateCount,subSysTransitionCount + newStateCount); // Now fill the matrix. newRow = 0; - T rest = 0; + T rest = utility::constantZero(); for(uint_fast64_t row = 0; row < origMat.getRowCount(); ++row) { if(subSysStates.get(row)){ // Transfer transitions @@ -299,8 +299,7 @@ private: */ bool checkValidityOfProbabilityMatrix() { // Get the settings object to customize linear solving. - storm::settings::Settings* s = storm::settings::Settings::getInstance(); - double precision = s->getOptionByLongName("precision").getArgument(0).getValueAsDouble(); + if (this->getTransitionMatrix().getRowCount() != this->getTransitionMatrix().getColumnCount()) { // not square @@ -308,20 +307,28 @@ private: return false; } for (uint_fast64_t row = 0; row < this->getTransitionMatrix().getRowCount(); ++row) { - T sum = this->getTransitionMatrix().getRowSum(row); - - if (sum == 0) { + T sum = this->getTransitionMatrix().getRowSum(row); + + if (sum == T(0)) { + + LOG4CPLUS_ERROR(logger, "Row " << row << " is a deadlock (sum == " << sum << ")."); return false; } - if (std::abs(sum - 1) > precision) { + if (storm::utility::isOne(sum)) { LOG4CPLUS_ERROR(logger, "Row " << row << " has sum " << sum << "."); return false; } } return true; } + + + + }; + + } // namespace models } // namespace storm diff --git a/src/models/Mdp.h b/src/models/Mdp.h index 61858fefa..73a73a1d9 100644 --- a/src/models/Mdp.h +++ b/src/models/Mdp.h @@ -198,16 +198,14 @@ private: */ bool checkValidityOfProbabilityMatrix() { // Get the settings object to customize linear solving. - storm::settings::Settings* s = storm::settings::Settings::getInstance(); - double precision = s->getOptionByLongName("precision").getArgument(0).getValueAsDouble(); for (uint_fast64_t row = 0; row < this->getTransitionMatrix().getRowCount(); row++) { T sum = this->getTransitionMatrix().getRowSum(row); if (sum == 0) continue; - if (std::abs(sum - 1) > precision) { + if (!storm::utility::isOne(sum)) { return false; } - } + } return true; } }; diff --git a/src/storage/expressions/ExpressionEvaluation.h b/src/storage/expressions/ExpressionEvaluation.h new file mode 100644 index 000000000..99d847555 --- /dev/null +++ b/src/storage/expressions/ExpressionEvaluation.h @@ -0,0 +1,99 @@ +/** + * @file: ExpressionEvaluation.h + * @author: Sebastian Junges + * + * @since April 4, 2014 + */ + +#ifndef STORM_STORAGE_EXPRESSIONS_EXPRESSIONEVALUATION_H_ +#define STORM_STORAGE_EXPRESSIONS_EXPRESSIONEVALUATION_H_ + +#include "ExpressionVisitor.h" +#include "BaseExpression.h" + +namespace storm { +namespace expressions { + + template + struct StateType + { + typedef int type; + }; + + template + class ExpressionEvaluationVisitor : public ExpressionVisitor + { + public: + ExpressionEvaluationVisitor(S* sharedState) + : mSharedState(sharedState) + { + + } + + virtual void visit(IfThenElseExpression const* expression) = 0; + virtual void visit(BinaryBooleanFunctionExpression const* expression) = 0; + virtual void visit(BinaryNumericalFunctionExpression const* expression) = 0; + virtual void visit(BinaryRelationExpression const* expression) = 0; + virtual void visit(BooleanConstantExpression const* expression) = 0; + virtual void visit(DoubleConstantExpression const* expression) = 0; + virtual void visit(IntegerConstantExpression const* expression) = 0; + virtual void visit(VariableExpression const* expression) = 0; + virtual void visit(UnaryBooleanFunctionExpression const* expression) = 0; + virtual void visit(UnaryNumericalFunctionExpression const* expression) = 0; + virtual void visit(BooleanLiteralExpression const* expression) = 0; + virtual void visit(IntegerLiteralExpression const* expression) = 0; + virtual void visit(DoubleLiteralExpression const* expression) = 0; + + const T& value() const + { + return mValue; + } + + private: + S* mSharedState; + T mValue; + }; + + template + class ExpressionEvaluation + { + public: + ExpressionEvaluation() : mState() + { + + } + + + T evaluate(Expression const& expr, storm::expressions::SimpleValuation const*) + { + ExpressionEvaluationVisitor::type>* visitor = new ExpressionEvaluationVisitor::type>(&mState); + expr.getBaseExpression().accept(visitor); + T result = visitor->value(); + delete visitor; + return result; + } + + protected: + typename StateType::type mState; + }; + + /** + * For doubles, we keep using the getValueAs from the expressions, as this should be more efficient. + */ + template<> + class ExpressionEvaluation + { + public: + double evaluate(Expression const& expr, storm::expressions::SimpleValuation const* val) const + { + return expr.evaluateAsDouble(val); + } + }; + + + + +} +} + +#endif \ No newline at end of file diff --git a/src/utility/constants.h b/src/utility/constants.h index e8a194619..2bc527d93 100644 --- a/src/utility/constants.h +++ b/src/utility/constants.h @@ -21,6 +21,7 @@ #include "src/exceptions/InvalidArgumentException.h" #include "src/storage/BitVector.h" #include "src/storage/LabeledValues.h" +#include "src/settings/Settings.h" namespace storm { @@ -202,6 +203,21 @@ inline storm::storage::LabeledValues constantInfinity() { /*! @endcond */ +template +inline bool isOne(T sum) +{ + return (sum-T(1)).isZero(); +} + +template<> +inline bool isOne(double sum) +{ + storm::settings::Settings* s = storm::settings::Settings::getInstance(); + double precision = s->getOptionByLongName("precision").getArgument(0).getValueAsDouble(); + return std::abs(sum - 1) < precision; +} + + } //namespace utility } //namespace storm