diff --git a/src/adapters/Smt2ExpressionAdapter.h b/src/adapters/Smt2ExpressionAdapter.h index 7ced9fc6b..a6e7396a1 100644 --- a/src/adapters/Smt2ExpressionAdapter.h +++ b/src/adapters/Smt2ExpressionAdapter.h @@ -24,7 +24,8 @@ namespace storm { * @param manager The manager that can be used to build expressions. * @param useReadableVarNames sets whether the expressions should use human readable names for the variables or the internal representation */ - Smt2ExpressionAdapter(storm::expressions::ExpressionManager& manager, bool useReadableVarNames) : manager(manager), useReadableVarNames(useReadableVarNames) { + Smt2ExpressionAdapter(storm::expressions::ExpressionManager& manager, bool useReadableVarNames) + : useReadableVarNames(useReadableVarNames) { declaredVariables.emplace_back(std::set()); } diff --git a/src/modelchecker/region/SparseDtmcRegionModelChecker.cpp b/src/modelchecker/region/SparseDtmcRegionModelChecker.cpp index 489ed4cbc..18b373d72 100644 --- a/src/modelchecker/region/SparseDtmcRegionModelChecker.cpp +++ b/src/modelchecker/region/SparseDtmcRegionModelChecker.cpp @@ -14,6 +14,7 @@ #include "src/settings/SettingsManager.h" #include "src/settings/modules/RegionSettings.h" #include "src/solver/OptimizationDirection.h" +#include "src/solver/stateelimination/MultiValueStateEliminator.h" #include "src/storage/sparse/StateType.h" #include "src/storage/FlexibleSparseMatrix.h" #include "src/utility/constants.h" @@ -114,6 +115,8 @@ namespace storm { //The states that we consider to eliminate storm::storage::BitVector considerToEliminate(submatrix.getRowCount(), true); considerToEliminate.set(initialState, false); + + std::vector statesToEliminate; for (auto const& state : considerToEliminate) { bool eliminateThisState=true; for(auto const& entry : flexibleTransitions.getRow(state)){ @@ -145,9 +148,17 @@ namespace storm { } } if(eliminateThisState){ - storm::storage::FlexibleSparseMatrix::eliminateState(flexibleTransitions, oneStepProbabilities, state, flexibleBackwardTransitions, stateRewards); subsystem.set(state,false); + statesToEliminate.push_back(state); } + + } + if(stateRewards) { + storm::solver::stateelimination::MultiValueStateEliminator eliminator(flexibleTransitions, flexibleBackwardTransitions, statesToEliminate, oneStepProbabilities, stateRewards.get()); + eliminator.eliminateAll(); + } else { + storm::solver::stateelimination::PrioritizedStateEliminator eliminator(flexibleTransitions, flexibleBackwardTransitions, statesToEliminate, oneStepProbabilities); + eliminator.eliminateAll(); } STORM_LOG_DEBUG("Eliminated " << subsystem.size() - subsystem.getNumberOfSetBits() << " of " << subsystem.size() << " states that had constant outgoing transitions."); diff --git a/src/modelchecker/region/SparseMdpRegionModelChecker.cpp b/src/modelchecker/region/SparseMdpRegionModelChecker.cpp index ce455cd85..cc913c52c 100644 --- a/src/modelchecker/region/SparseMdpRegionModelChecker.cpp +++ b/src/modelchecker/region/SparseMdpRegionModelChecker.cpp @@ -127,7 +127,8 @@ namespace storm { statesToEliminate.push_back(state); } } - storm::solver::stateelimination::PrioritizedStateEliminator(flexibleTransitions, flexibleBackwardTransitions, statesToEliminate, oneStepProbabilities); + storm::solver::stateelimination::PrioritizedStateEliminator eliminator(flexibleTransitions, flexibleBackwardTransitions, statesToEliminate, oneStepProbabilities); + eliminator.eliminateAll(); STORM_LOG_DEBUG("Eliminated " << subsystem.size() - subsystem.getNumberOfSetBits() << " of " << subsystem.size() << " states that had constant outgoing transitions."); //Build the simple model diff --git a/src/solver/stateelimination/MultiValueStateEliminator.cpp b/src/solver/stateelimination/MultiValueStateEliminator.cpp index c9e9acbd1..6076f6283 100644 --- a/src/solver/stateelimination/MultiValueStateEliminator.cpp +++ b/src/solver/stateelimination/MultiValueStateEliminator.cpp @@ -10,7 +10,12 @@ namespace storm { MultiValueStateEliminator::MultiValueStateEliminator(storm::storage::FlexibleSparseMatrix& transitionMatrix, storm::storage::FlexibleSparseMatrix& backwardTransitions, PriorityQueuePointer priorityQueue, std::vector& stateValues, std::vector& additionalStateValuesVector) : PrioritizedStateEliminator(transitionMatrix, backwardTransitions, priorityQueue, stateValues), additionalStateValues({std::ref(additionalStateValuesVector)}) { } - + + template + MultiValueStateEliminator::MultiValueStateEliminator(storm::storage::FlexibleSparseMatrix& transitionMatrix, storm::storage::FlexibleSparseMatrix& backwardTransitions, std::vector const& statesToEliminate, std::vector& stateValues, std::vector& additionalStateValuesVector) : PrioritizedStateEliminator(transitionMatrix, backwardTransitions, statesToEliminate, stateValues), additionalStateValues({std::ref(additionalStateValuesVector)}) { + + } + template void MultiValueStateEliminator::updateValue(storm::storage::sparse::state_type const& state, ValueType const& loopProbability) { this->stateValues[state] = storm::utility::simplify(loopProbability * this->stateValues[state]); @@ -26,6 +31,14 @@ namespace storm { additionalStateValueVectorRef.get()[predecessor] = storm::utility::simplify(additionalStateValueVectorRef.get()[predecessor] + storm::utility::simplify(probability * additionalStateValueVectorRef.get()[state])); } } + + template + void MultiValueStateEliminator::clearStateValues(storm::storage::sparse::state_type const& state) { + super::clearStateValues(state); + for(auto additionStateValueVectorRef : additionalStateValues) { + additionStateValueVectorRef.get()[state] = storm::utility::zero(); + } + } template class MultiValueStateEliminator; diff --git a/src/solver/stateelimination/MultiValueStateEliminator.h b/src/solver/stateelimination/MultiValueStateEliminator.h index fcd3daaa9..c7bb1b7b3 100644 --- a/src/solver/stateelimination/MultiValueStateEliminator.h +++ b/src/solver/stateelimination/MultiValueStateEliminator.h @@ -9,16 +9,23 @@ namespace storm { template class MultiValueStateEliminator : public PrioritizedStateEliminator { + private: + typedef PrioritizedStateEliminator super; public: typedef typename std::shared_ptr PriorityQueuePointer; typedef typename std::vector ValueTypeVector; - MultiValueStateEliminator(storm::storage::FlexibleSparseMatrix& transitionMatrix, storm::storage::FlexibleSparseMatrix& backwardTransitions, PriorityQueuePointer priorityQueue, std::vector& stateValues, std::vector& additionalStateValues); - + MultiValueStateEliminator(storm::storage::FlexibleSparseMatrix& transitionMatrix, storm::storage::FlexibleSparseMatrix& backwardTransitions, + PriorityQueuePointer priorityQueue, std::vector& stateValues, std::vector& additionalStateValues); + MultiValueStateEliminator(storm::storage::FlexibleSparseMatrix& transitionMatrix,storm::storage::FlexibleSparseMatrix& backwardTransitions, + std::vector const& statesToEliminate, std::vector& stateValues, std::vector& additionalStateValues); + // Instantiaton of virtual methods. void updateValue(storm::storage::sparse::state_type const& state, ValueType const& loopProbability) override; void updatePredecessor(storm::storage::sparse::state_type const& predecessor, ValueType const& probability, storm::storage::sparse::state_type const& state) override; - + + virtual void clearStateValues(storm::storage::sparse::state_type const& state) override; + private: std::vector>additionalStateValues; }; diff --git a/src/solver/stateelimination/PrioritizedStateEliminator.cpp b/src/solver/stateelimination/PrioritizedStateEliminator.cpp index 39586caea..a2ccbb869 100644 --- a/src/solver/stateelimination/PrioritizedStateEliminator.cpp +++ b/src/solver/stateelimination/PrioritizedStateEliminator.cpp @@ -34,6 +34,25 @@ namespace storm { void PrioritizedStateEliminator::updatePriority(storm::storage::sparse::state_type const& state) { priorityQueue->update(state); } + + template + void PrioritizedStateEliminator::eliminateAll(bool removeForwardTransitions) { + while (priorityQueue->hasNext()) { + storm::storage::sparse::state_type state = priorityQueue->pop(); + this->eliminateState(priorityQueue->pop(), removeForwardTransitions); + if (removeForwardTransitions) { + clearStateValues(state); + } +#ifdef STORM_DEV + STORM_LOG_ASSERT(checkConsistent(transitionMatrix, backwardTransitions), "The forward and backward transition matrices became inconsistent."); +#endif + } + } + + template + void PrioritizedStateEliminator::clearStateValues(storm::storage::sparse::state_type const &state) { + stateValues[state] = storm::utility::zero(); + } template class PrioritizedStateEliminator; diff --git a/src/solver/stateelimination/PrioritizedStateEliminator.h b/src/solver/stateelimination/PrioritizedStateEliminator.h index a20a769f0..28865f67b 100644 --- a/src/solver/stateelimination/PrioritizedStateEliminator.h +++ b/src/solver/stateelimination/PrioritizedStateEliminator.h @@ -18,10 +18,12 @@ namespace storm { PrioritizedStateEliminator(storm::storage::FlexibleSparseMatrix& transitionMatrix, storm::storage::FlexibleSparseMatrix& backwardTransitions, std::vector const& statesToEliminate, std::vector& stateValues); // Instantiaton of virtual methods. - void updateValue(storm::storage::sparse::state_type const& state, ValueType const& loopProbability) override; - void updatePredecessor(storm::storage::sparse::state_type const& predecessor, ValueType const& probability, storm::storage::sparse::state_type const& state) override; - void updatePriority(storm::storage::sparse::state_type const& state) override; - + virtual void updateValue(storm::storage::sparse::state_type const& state, ValueType const& loopProbability) override; + virtual void updatePredecessor(storm::storage::sparse::state_type const& predecessor, ValueType const& probability, storm::storage::sparse::state_type const& state) override; + virtual void updatePriority(storm::storage::sparse::state_type const& state) override; + + virtual void eliminateAll(bool eliminateForwardTransitions = true); + virtual void clearStateValues(storm::storage::sparse::state_type const& state); protected: PriorityQueuePointer priorityQueue; std::vector& stateValues;