diff --git a/src/storm/solver/stateelimination/EliminatorBase.cpp b/src/storm/solver/stateelimination/EliminatorBase.cpp index 88b77c764..7eeb4eb27 100644 --- a/src/storm/solver/stateelimination/EliminatorBase.cpp +++ b/src/storm/solver/stateelimination/EliminatorBase.cpp @@ -61,11 +61,11 @@ namespace storm { entryIt->setValue(storm::utility::simplify((ValueType) (entryIt->getValue() * columnValue))); } } - updateValue(column, columnValue); + updateValue(row, columnValue); } // Now substitute the row entries in all other rows that contain an element whose column is the current row. - FlexibleRowType& elementsWithEntryInColumnEqualRow = transposedMatrix.getRow(row); + FlexibleRowType& elementsWithEntryInColumnEqualRow = transposedMatrix.getRow(column); // In case we have a constrained elimination, we need to keep track of the rows that keep their value // in the column equal to the current row. @@ -169,7 +169,7 @@ namespace storm { predecessorForwardTransitions = std::move(newSuccessors); STORM_LOG_TRACE("Fixed new next-state probabilities of predecessor state " << predecessor << "."); - updatePredecessor(predecessor, multiplyFactor, column); + updatePredecessor(predecessor, multiplyFactor, row); STORM_LOG_TRACE("Updating priority of predecessor."); updatePriority(predecessor); @@ -187,7 +187,7 @@ namespace storm { // Delete the current state as a predecessor of the successor state only if we are going to remove the // current state's forward transitions. if (clearRow) { - FlexibleRowIterator elimIt = std::find_if(successorBackwardTransitions.begin(), successorBackwardTransitions.end(), [&](storm::storage::MatrixEntry::index_type, typename storm::storage::FlexibleSparseMatrix::value_type> const& a) { return a.getColumn() == column; }); + FlexibleRowIterator elimIt = std::find_if(successorBackwardTransitions.begin(), successorBackwardTransitions.end(), [&](storm::storage::MatrixEntry::index_type, typename storm::storage::FlexibleSparseMatrix::value_type> const& a) { return a.getColumn() == row; }); STORM_LOG_ASSERT(elimIt != successorBackwardTransitions.end(), "Expected a proper backward transition from " << successorEntry.getColumn() << " to " << column << ", but found none."); successorBackwardTransitions.erase(elimIt); } diff --git a/src/storm/solver/stateelimination/NondeterministicModelStateEliminator.cpp b/src/storm/solver/stateelimination/NondeterministicModelStateEliminator.cpp new file mode 100644 index 000000000..300e723ba --- /dev/null +++ b/src/storm/solver/stateelimination/NondeterministicModelStateEliminator.cpp @@ -0,0 +1,39 @@ +#include "storm/solver/stateelimination/NondeterministicModelStateEliminator.h" + +#include "storm/utility/macros.h" +#include "storm/utility/constants.h" + +#include "storm/exceptions/InvalidArgumentException.h" + +namespace storm { + namespace solver { + namespace stateelimination { + + template + NondeterministicModelStateEliminator::NondeterministicModelStateEliminator(storm::storage::FlexibleSparseMatrix& transitionMatrix, storm::storage::FlexibleSparseMatrix& backwardTransitions, std::vector& rowValues) + : StateEliminator(transitionMatrix, backwardTransitions), rowValues(rowValues) { + + STORM_LOG_THROW(transitionMatrix.getRowCount() == backwardTransitions.getColumnCount() && transitionMatrix.getColumnCount() == backwardTransitions.getRowCount(), storm::exceptions::InvalidArgumentException, "Invalid matrix dimensions of forward/backwards transition matrices."); + STORM_LOG_THROW(rowValues.size() == transitionMatrix.getRowCount(), storm::exceptions::InvalidArgumentException, "Invalid size of row value vector"); + // Intentionally left empty + } + + template + void NondeterministicModelStateEliminator::updateValue(storm::storage::sparse::state_type const& row, ValueType const& loopProbability) { + rowValues[row] = storm::utility::simplify((ValueType) (loopProbability * rowValues[row])); + } + + template + void NondeterministicModelStateEliminator::updatePredecessor(storm::storage::sparse::state_type const& predecessorRow, ValueType const& probability, storm::storage::sparse::state_type const& row) { + rowValues[predecessorRow] = storm::utility::simplify((ValueType) (rowValues[predecessorRow] + storm::utility::simplify((ValueType) (probability * rowValues[row])))); + } + + template class NondeterministicModelStateEliminator; + +#ifdef STORM_HAVE_CARL + template class NondeterministicModelStateEliminator; + template class NondeterministicModelStateEliminator; +#endif + } // namespace stateelimination + } // namespace storage +} // namespace storm diff --git a/src/storm/solver/stateelimination/NondeterministicModelStateEliminator.h b/src/storm/solver/stateelimination/NondeterministicModelStateEliminator.h new file mode 100644 index 000000000..1257d8df1 --- /dev/null +++ b/src/storm/solver/stateelimination/NondeterministicModelStateEliminator.h @@ -0,0 +1,28 @@ +#ifndef STORM_SOLVER_STATEELIMINATION_NONDETERMINISTICMODELSTATEELIMINATOR_H_ +#define STORM_SOLVER_STATEELIMINATION_NONDETERMINISTICMODELSTATEELIMINATOR_H_ + +#include "storm/solver/stateelimination/StateEliminator.h" + +namespace storm { + namespace solver { + namespace stateelimination { + + template + class NondeterministicModelStateEliminator : public StateEliminator { + public: + + NondeterministicModelStateEliminator(storm::storage::FlexibleSparseMatrix& transitionMatrix, storm::storage::FlexibleSparseMatrix& backwardTransitions, std::vector& rowValues); + + // Instantiaton of virtual methods. + virtual void updateValue(storm::storage::sparse::state_type const& row, ValueType const& loopProbability) override; + virtual void updatePredecessor(storm::storage::sparse::state_type const& predecessorRow, ValueType const& probability, storm::storage::sparse::state_type const& row) override; + + protected: + std::vector& rowValues; + }; + + } // namespace stateelimination + } // namespace storage +} // namespace storm + +#endif // STORM_SOLVER_STATEELIMINATION_NONDETERMINISTICMODELSTATEELIMINATOR_H_ diff --git a/src/storm/solver/stateelimination/StateEliminator.cpp b/src/storm/solver/stateelimination/StateEliminator.cpp index 47f30b701..1ebc02904 100644 --- a/src/storm/solver/stateelimination/StateEliminator.cpp +++ b/src/storm/solver/stateelimination/StateEliminator.cpp @@ -7,7 +7,7 @@ #include "storm/utility/stateelimination.h" #include "storm/utility/macros.h" #include "storm/utility/constants.h" -#include "storm/utility/macros.h" +#include "storm/exceptions/IllegalArgumentException.h" #include "storm/exceptions/InvalidStateException.h" namespace storm { @@ -24,7 +24,12 @@ namespace storm { template void StateEliminator::eliminateState(storm::storage::sparse::state_type state, bool removeForwardTransitions) { STORM_LOG_TRACE("Eliminating state " << state << "."); - this->eliminate(state, state, removeForwardTransitions); + if(this->matrix.hasTrivialRowGrouping()) { + this->eliminate(state, state, removeForwardTransitions); + } else { + STORM_LOG_THROW(this->matrix.getRowGroupSize(state) == 1, storm::exceptions::IllegalArgumentException, "Invoked state elimination on a state with multiple choices. This is not supported."); + this->eliminate(this->matrix.getRowGroupIndices()[state], state, removeForwardTransitions); + } } template class StateEliminator; diff --git a/src/storm/transformer/SparseParametricModelSimplifier.cpp b/src/storm/transformer/SparseParametricModelSimplifier.cpp index ff22de54a..54fba2338 100644 --- a/src/storm/transformer/SparseParametricModelSimplifier.cpp +++ b/src/storm/transformer/SparseParametricModelSimplifier.cpp @@ -5,7 +5,7 @@ #include "storm/models/sparse/Dtmc.h" #include "storm/models/sparse/Mdp.h" #include "storm/models/sparse/StandardRewardModel.h" -#include "storm/solver/stateelimination/PrioritizedStateEliminator.h" +#include "storm/solver/stateelimination/NondeterministicModelStateEliminator.h" #include "storm/storage/FlexibleSparseMatrix.h" #include "storm/utility/vector.h" @@ -116,32 +116,30 @@ namespace storm { } // Find the states that are to be eliminated - std::vector statesToEliminate; - storm::storage::BitVector keptStates(sparseMatrix.getRowGroupCount(), true); - storm::storage::BitVector keptRows(sparseMatrix.getRowCount(), true); + storm::storage::BitVector selectedStates = considerForElimination; for (auto state : considerForElimination) { if (sparseMatrix.getRowGroupSize(state) == 1 && (!rewardModelName.is_initialized() || storm::utility::isConstant(actionRewards[sparseMatrix.getRowGroupIndices()[state]]))) { - bool hasOnlyConstEntries = true; for (auto const& entry : sparseMatrix.getRowGroup(state)) { if(!storm::utility::isConstant(entry.getValue())) { - hasOnlyConstEntries = false; + selectedStates.set(state, false); break; } } - if (hasOnlyConstEntries) { - statesToEliminate.push_back(state); - keptStates.set(state, false); - keptRows.set(sparseMatrix.getRowGroupIndices()[state], false); - } + } else { + selectedStates.set(state, false); } } // invoke elimination and obtain resulting transition matrix storm::storage::FlexibleSparseMatrix flexibleMatrix(sparseMatrix); storm::storage::FlexibleSparseMatrix flexibleBackwardTransitions(sparseMatrix.transpose(), true); - storm::solver::stateelimination::PrioritizedStateEliminator stateEliminator(flexibleMatrix, flexibleBackwardTransitions, statesToEliminate, actionRewards); - stateEliminator.eliminateAll(); - storm::storage::SparseMatrix newTransitionMatrix = flexibleMatrix.createSparseMatrix(keptRows, keptStates); + storm::solver::stateelimination::NondeterministicModelStateEliminator stateEliminator(flexibleMatrix, flexibleBackwardTransitions, actionRewards); + for(auto state : selectedStates) { + stateEliminator.eliminateState(state, true); + } + selectedStates.complement(); + auto keptRows = sparseMatrix.getRowIndicesOfRowGroups(selectedStates); + storm::storage::SparseMatrix newTransitionMatrix = flexibleMatrix.createSparseMatrix(keptRows, selectedStates); // obtain the reward model for the resulting system std::unordered_map rewardModels; @@ -150,7 +148,7 @@ namespace storm { rewardModels.insert(std::make_pair(*rewardModelName, typename SparseModelType::RewardModelType(boost::none, std::move(actionRewards)))); } - return std::make_shared(std::move(newTransitionMatrix), model.getStateLabeling().getSubLabeling(keptStates), std::move(rewardModels)); + return std::make_shared(std::move(newTransitionMatrix), model.getStateLabeling().getSubLabeling(selectedStates), std::move(rewardModels)); }