diff --git a/src/storm/storage/memorystructure/SparseModelMemoryProduct.cpp b/src/storm/storage/memorystructure/SparseModelMemoryProduct.cpp index 1d1d7f77c..2b0bd8be1 100644 --- a/src/storm/storage/memorystructure/SparseModelMemoryProduct.cpp +++ b/src/storm/storage/memorystructure/SparseModelMemoryProduct.cpp @@ -18,26 +18,24 @@ namespace storm { namespace storage { template - SparseModelMemoryProduct::SparseModelMemoryProduct(storm::models::sparse::Model const& sparseModel, storm::storage::MemoryStructure const& memoryStructure) : model(sparseModel), memory(memoryStructure) { - reachableStates = storm::storage::BitVector(model.getNumberOfStates() * memory.getNumberOfStates(), false); + SparseModelMemoryProduct::SparseModelMemoryProduct(storm::models::sparse::Model const& sparseModel, storm::storage::MemoryStructure const& memoryStructure) : memoryStateCount(memoryStructure.getNumberOfStates()), model(sparseModel), memory(memoryStructure) { + reachableStates = storm::storage::BitVector(model.getNumberOfStates() * memoryStateCount, false); } template void SparseModelMemoryProduct::addReachableState(uint64_t const& modelState, uint64_t const& memoryState) { - reachableStates.set(modelState * memory.getNumberOfStates() + memoryState, true); + reachableStates.set(modelState * memoryStateCount + memoryState, true); } template void SparseModelMemoryProduct::setBuildFullProduct() { - reachableStates.clear(); - reachableStates.complement(); + reachableStates.fill(); } template std::shared_ptr> SparseModelMemoryProduct::build(boost::optional> const& scheduler) { uint64_t modelStateCount = model.getNumberOfStates(); - uint64_t memoryStateCount = memory.getNumberOfStates(); std::vector memorySuccessors = computeMemorySuccessors(); @@ -54,7 +52,7 @@ namespace storm { // Compute the mapping to the states of the result uint64_t reachableStateCount = 0; - toResultStateMapping = std::vector (model.getNumberOfStates() * memory.getNumberOfStates(), std::numeric_limits::max()); + toResultStateMapping = std::vector (model.getNumberOfStates() * memoryStateCount, std::numeric_limits::max()); for (auto const& reachableState : reachableStates) { toResultStateMapping[reachableState] = reachableStateCount; ++reachableStateCount; @@ -82,13 +80,12 @@ namespace storm { template uint64_t const& SparseModelMemoryProduct::getResultState(uint64_t const& modelState, uint64_t const& memoryState) const { - return toResultStateMapping[modelState * memory.getNumberOfStates() + memoryState]; + return toResultStateMapping[modelState * memoryStateCount + memoryState]; } template std::vector SparseModelMemoryProduct::computeMemorySuccessors() const { uint64_t modelTransitionCount = model.getTransitionMatrix().getEntryCount(); - uint64_t memoryStateCount = memory.getNumberOfStates(); std::vector result(modelTransitionCount * memoryStateCount, std::numeric_limits::max()); for (uint64_t memoryState = 0; memoryState < memoryStateCount; ++memoryState) { @@ -106,7 +103,6 @@ namespace storm { template void SparseModelMemoryProduct::computeReachableStates(std::vector const& memorySuccessors, storm::storage::BitVector const& initialStates, boost::optional> const& scheduler) { - uint64_t memoryStateCount = memory.getNumberOfStates(); // Explore the reachable states via DFS. // A state s on the stack corresponds to the model state (s / memoryStateCount) and memory state (s % memoryStateCount) reachableStates |= initialStates; @@ -158,7 +154,6 @@ namespace storm { template storm::storage::SparseMatrix SparseModelMemoryProduct::buildDeterministicTransitionMatrix(std::vector const& memorySuccessors) const { - uint64_t memoryStateCount = memory.getNumberOfStates(); uint64_t numResStates = reachableStates.getNumberOfSetBits(); uint64_t numResTransitions = 0; for (auto const& stateIndex : reachableStates) { @@ -184,7 +179,6 @@ namespace storm { template storm::storage::SparseMatrix SparseModelMemoryProduct::buildNondeterministicTransitionMatrix(std::vector const& memorySuccessors) const { - uint64_t memoryStateCount = memory.getNumberOfStates(); uint64_t numResStates = reachableStates.getNumberOfSetBits(); uint64_t numResChoices = 0; uint64_t numResTransitions = 0; @@ -218,7 +212,6 @@ namespace storm { template storm::storage::SparseMatrix SparseModelMemoryProduct::buildTransitionMatrixForScheduler(std::vector const& memorySuccessors, storm::storage::Scheduler const& scheduler) const { - uint64_t memoryStateCount = memory.getNumberOfStates(); uint64_t numResStates = reachableStates.getNumberOfSetBits(); uint64_t numResChoices = 0; uint64_t numResTransitions = 0; @@ -314,7 +307,6 @@ namespace storm { template storm::models::sparse::StateLabeling SparseModelMemoryProduct::buildStateLabeling(storm::storage::SparseMatrix const& resultTransitionMatrix) const { uint64_t modelStateCount = model.getNumberOfStates(); - uint64_t memoryStateCount = memory.getNumberOfStates(); uint64_t numResStates = resultTransitionMatrix.getRowGroupCount(); storm::models::sparse::StateLabeling resultLabeling(numResStates); @@ -357,7 +349,6 @@ namespace storm { typedef typename RewardModelType::ValueType RewardValueType; std::unordered_map result; - uint64_t memoryStateCount = memory.getNumberOfStates(); uint64_t numResStates = resultTransitionMatrix.getRowGroupCount(); for (auto const& rewardModel : model.getRewardModels()) { @@ -471,7 +462,6 @@ namespace storm { } else if (model.isOfType(storm::models::ModelType::MarkovAutomaton)) { // We also need to translate the exit rates and the Markovian states uint64_t numResStates = components.transitionMatrix.getRowGroupCount(); - uint64_t memoryStateCount = memory.getNumberOfStates(); std::vector resultExitRates; resultExitRates.reserve(components.transitionMatrix.getRowGroupCount()); storm::storage::BitVector resultMarkovianStates(numResStates, false); diff --git a/src/storm/storage/memorystructure/SparseModelMemoryProduct.h b/src/storm/storage/memorystructure/SparseModelMemoryProduct.h index fb959daa5..43fdb2a4c 100644 --- a/src/storm/storage/memorystructure/SparseModelMemoryProduct.h +++ b/src/storm/storage/memorystructure/SparseModelMemoryProduct.h @@ -72,6 +72,8 @@ namespace storm { // Indicates which states are considered reachable. (s, m) is reachable if this BitVector is true at (s * memoryStateCount) + m storm::storage::BitVector reachableStates; + uint64_t const memoryStateCount; + storm::models::sparse::Model const& model; storm::storage::MemoryStructure const& memory;