diff --git a/src/storm/modelchecker/prctl/helper/SparseMdpPrctlHelper.cpp b/src/storm/modelchecker/prctl/helper/SparseMdpPrctlHelper.cpp index 73fa080c0..b4d818ad2 100644 --- a/src/storm/modelchecker/prctl/helper/SparseMdpPrctlHelper.cpp +++ b/src/storm/modelchecker/prctl/helper/SparseMdpPrctlHelper.cpp @@ -378,25 +378,20 @@ namespace storm { if (!maybeStates.empty()) { // In this case we have to compute the reward values for the remaining states. - // We can eliminate the rows and columns from the original transition probability matrix for states - // whose reward values are already known. - storm::storage::SparseMatrix submatrix = transitionMatrix.getSubmatrix(true, maybeStates, maybeStates, false); - - // Prepare the right-hand side of the equation system. - std::vector b = totalStateRewardVectorGetter(submatrix.getRowCount(), transitionMatrix, maybeStates); - - // Since we are cutting away target and infinity states, we need to account for this by giving - // choices the value infinity that have some successor contained in the infinity states. - uint_fast64_t currentRow = 0; - for (auto state : maybeStates) { - for (uint_fast64_t row = nondeterministicChoiceIndices[state]; row < nondeterministicChoiceIndices[state + 1]; ++row, ++currentRow) { - for (auto const& element : transitionMatrix.getRow(row)) { - if (infinityStates.get(element.getColumn())) { - b[currentRow] = storm::utility::infinity(); - break; - } - } - } + // Prepare matrix and vector for the equation system. + storm::storage::SparseMatrix submatrix; + std::vector b; + // Remove rows and columns from the original transition probability matrix for states whose reward values are already known. + // If there are infinity states, we additionaly have to remove choices of maybeState that lead to infinity + boost::optional selectedChoices; // if not given, all maybeState choices are selected + if (infinityStates.empty()) { + submatrix = transitionMatrix.getSubmatrix(true, maybeStates, maybeStates, false); + b = totalStateRewardVectorGetter(submatrix.getRowCount(), transitionMatrix, maybeStates); + } else { + selectedChoices = transitionMatrix.getRowFilter(maybeStates, ~infinityStates); + submatrix = transitionMatrix.getSubmatrix(false, *selectedChoices, maybeStates, false); + b = totalStateRewardVectorGetter(transitionMatrix.getRowCount(), transitionMatrix, storm::storage::BitVector(transitionMatrix.getRowGroupCount(), true)); + storm::utility::vector::filterVectorInPlace(b, *selectedChoices); } bool skipEcWithinMaybeStatesCheck = !goal.minimize() || (hint.isExplicitModelCheckerHint() && hint.asExplicitModelCheckerHint().getNoEndComponentsInMaybeStates()); @@ -408,10 +403,25 @@ namespace storm { if (produceScheduler) { storm::storage::Scheduler const& subscheduler = *resultForMaybeStates.scheduler; - uint_fast64_t currentSubState = 0; - for (auto maybeState : maybeStates) { - scheduler->setChoice(maybeState, subscheduler.getChoice(currentSubState)); - ++currentSubState; + if (selectedChoices) { + uint_fast64_t currentSubState = 0; + for (auto maybeState : maybeStates) { + uint_fast64_t subChoice = subscheduler.getChoice(currentSubState); + // find the rowindex that corresponds to the selected row of the submodel + uint_fast64_t firstRowIndex = transitionMatrix.getRowGroupIndices()[maybeState]; + uint_fast64_t selectedRowIndex = selectedChoices->getNextSetIndex(firstRowIndex); + for (uint_fast64_t choice = 0; choice < subChoice; ++choice) { + selectedRowIndex = selectedChoices->getNextSetIndex(selectedRowIndex + 1); + } + scheduler->setChoice(maybeState, selectedRowIndex - firstRowIndex); + ++currentSubState; + } + } else { + uint_fast64_t currentSubState = 0; + for (auto maybeState : maybeStates) { + scheduler->setChoice(maybeState, subscheduler.getChoice(currentSubState)); + ++currentSubState; + } } } } diff --git a/src/storm/storage/SparseMatrix.cpp b/src/storm/storage/SparseMatrix.cpp index d71fa7ba4..36319b394 100644 --- a/src/storm/storage/SparseMatrix.cpp +++ b/src/storm/storage/SparseMatrix.cpp @@ -565,6 +565,27 @@ namespace storm { return res; } + template + storm::storage::BitVector SparseMatrix::getRowFilter(storm::storage::BitVector const& groupConstraint, storm::storage::BitVector const& columnConstraint) const { + storm::storage::BitVector result(this->getRowCount(), false); + for (auto const& group : groupConstraint) { + uint_fast64_t const endOfGroup = this->getRowGroupIndices()[group + 1]; + for (uint_fast64_t row = this->getRowGroupIndices()[group]; row < endOfGroup; ++row) { + bool choiceSatisfiesColumnConstraint = true; + for (auto const& entry : this->getRow(row)) { + if (!columnConstraint.get(entry.getColumn())) { + choiceSatisfiesColumnConstraint = false; + break; + } + } + if (choiceSatisfiesColumnConstraint) { + result.set(row, true); + } + } + } + return result; + } + template void SparseMatrix::makeRowsAbsorbing(storm::storage::BitVector const& rows) { for (auto row : rows) { diff --git a/src/storm/storage/SparseMatrix.h b/src/storm/storage/SparseMatrix.h index 686574518..8aab0d397 100644 --- a/src/storm/storage/SparseMatrix.h +++ b/src/storm/storage/SparseMatrix.h @@ -571,6 +571,17 @@ namespace storm { */ storm::storage::BitVector getRowIndicesOfRowGroups(storm::storage::BitVector const& groups) const; + /*! + * Returns the indices of all rows that + * * are in a selected group and + * * only have entries within the selected columns. + * + * @param groupConstraint the selected groups + * @param columnConstraints the selected columns + * @return a bit vector that is true at position i iff row i satisfies the constraints. + */ + storm::storage::BitVector getRowFilter(storm::storage::BitVector const& groupConstraint, storm::storage::BitVector const& columnConstraints) const; + /*! * This function makes the given rows absorbing. * diff --git a/src/storm/utility/vector.h b/src/storm/utility/vector.h index 9482d35b2..c26c17774 100644 --- a/src/storm/utility/vector.h +++ b/src/storm/utility/vector.h @@ -859,6 +859,18 @@ namespace storm { return result; } + template + void filterVectorInPlace(std::vector& v, storm::storage::BitVector const& filter) { + STORM_LOG_ASSERT(v.size() == filter.size(), "The filter size does not match the size of the input vector"); + auto vIt = v.begin(); + for(auto index : filter) { + *vIt = std::move(v[index]); + ++vIt; + } + v.resize(vIt - v.begin()); + STORM_LOG_ASSERT(v.size() == filter.getNumberOfSetBits(), "Result does not match."); + } + template bool hasNegativeEntry(std::vector const& v){ return std::any_of(v.begin(), v.end(), [](T value){return value < storm::utility::zero();});