From 3820b994c5eaf3ee4b32207ac1719a4822d39733 Mon Sep 17 00:00:00 2001 From: radioGiorgio Date: Wed, 24 Jul 2019 09:19:39 +0200 Subject: [PATCH] product indices getter debugged --- ...rministicTransitionsBasedMemoryProduct.cpp | 35 ++++++++++++------- ...terministicTransitionsBasedMemoryProduct.h | 1 + 2 files changed, 23 insertions(+), 13 deletions(-) diff --git a/src/storm/storage/memorystructure/SparseModelNondeterministicTransitionsBasedMemoryProduct.cpp b/src/storm/storage/memorystructure/SparseModelNondeterministicTransitionsBasedMemoryProduct.cpp index 8efd085c4..82144e3fc 100644 --- a/src/storm/storage/memorystructure/SparseModelNondeterministicTransitionsBasedMemoryProduct.cpp +++ b/src/storm/storage/memorystructure/SparseModelNondeterministicTransitionsBasedMemoryProduct.cpp @@ -21,7 +21,7 @@ namespace storm { components.choiceLabeling = buildChoiceLabeling(components.transitionMatrix); // Now delete unreachable states. storm::storage::BitVector allStates(components.transitionMatrix.getRowGroupCount(), true); - auto reachableStates = storm::utility::graph::getReachableStates(components.transitionMatrix, components.stateLabeling.getStates("init"), allStates, ~allStates); + reachableStates = storm::utility::graph::getReachableStates(components.transitionMatrix, components.stateLabeling.getStates("init"), allStates, ~allStates); storm::storage::BitVector enabledActions(components.transitionMatrix.getRowCount()); for (uint64_t state : reachableStates) { for (uint64_t row = components.transitionMatrix.getRowGroupIndices()[state]; row < components.transitionMatrix.getRowGroupIndices()[state + 1]; ++ row) { @@ -34,8 +34,7 @@ namespace storm { // build the remaining components for (auto const& rewModel : model.getRewardModels()) { - components.rewardModels.emplace(rewModel.first, - buildRewardModel(rewModel.second, reachableStates, components.transitionMatrix)); + components.rewardModels.emplace(rewModel.first, buildRewardModel(rewModel.second, reachableStates, components.transitionMatrix)); } // build the offset vector, that allows to maintain getter indices @@ -53,7 +52,7 @@ namespace storm { productStates[0] = 0; for (uint64_t modelState = 0; modelState < model.getNumberOfStates(); ++ modelState) { if (modelState < model.getNumberOfStates() - 1) { - productStates[modelState + 1] = productStates[modelState] + memory.getNumberOfStates() * ( 1 + origTransitions.getRowGroupEntryCount(modelState) ); + productStates[modelState + 1] = productStates[modelState] + memory.getNumberOfStates() * (1 + origTransitions.getRowGroupEntryCount(modelState)); } for (uint64_t memState = 0; memState < memory.getNumberOfStates(); ++ memState) { for (uint64_t row = origTransitions.getRowGroupIndices()[modelState]; @@ -307,14 +306,22 @@ namespace storm { template std::vector SparseModelNondeterministicTransitionsBasedMemoryProduct::generateOffsetVector(storm::storage::BitVector const& reachableStates) { uint64_t numberOfStates = model.getNumberOfStates() * memory.getNumberOfStates() * (1 + model.getNumberOfTransitions()); - STORM_LOG_ASSERT(reachableStates.size() == numberOfStates, "wrong size for the vector reachableStates"); + STORM_LOG_ASSERT(reachableStates.size() == numberOfStates, "vector reachableStates has wrong size"); + uint64_t state = 0; uint64_t offset = 0; std::vector offsetVector(numberOfStates); - for (uint64_t state = 0; state < numberOfStates; ++ state) { - if (not reachableStates[state]) { - ++ offset; + while (state < numberOfStates) { + if (reachableStates[state]) { + offsetVector[state] = offset; + ++ state; + } + else { + uint64_t nextState = reachableStates.getNextSetIndex(state); + offset += nextState - state; + for (; state < nextState; ++ state) { + offsetVector[state] = offset; + } } - offsetVector[state] = offset; } return std::move(offsetVector); @@ -328,9 +335,9 @@ namespace storm { template bool SparseModelNondeterministicTransitionsBasedMemoryProduct::isProductStateReachable(uint64_t modelState, uint64_t memoryState) const { - STORM_LOG_ASSERT(not fullProductStatesOffset.empty(), "Model not built"); + STORM_LOG_ASSERT(not fullProductStatesOffset.empty(), "The product is not yet built"); uint64_t index = productStates[modelState] + memoryState * (1 + model.getTransitionMatrix().getRowGroupEntryCount(modelState)); - return ( index == 0 and fullProductStatesOffset[index] == 0 ) or ( index > 0 and fullProductStatesOffset[index] == fullProductStatesOffset[index - 1] ); + return reachableStates[index]; } template @@ -343,8 +350,10 @@ namespace storm { template uint64_t SparseModelNondeterministicTransitionsBasedMemoryProduct::getMemoryState(uint64_t productState) const { - uint64_t modelState = getModelState(productState); - uint64_t offset = productState - productStates[modelState]; + uint64_t productStateWithOffset = productState + (fullProductStatesOffset.empty() ? 0 : fullProductStatesOffset[productState]); + // binary search in the vector containing the product states indices + uint64_t modelState = std::upper_bound(productStates.begin(), productStates.end(), productStateWithOffset) - productStates.begin() - 1; + uint64_t offset = productStateWithOffset - productStates[modelState]; return offset / (1 + model.getTransitionMatrix().getRowGroupEntryCount(modelState)); } diff --git a/src/storm/storage/memorystructure/SparseModelNondeterministicTransitionsBasedMemoryProduct.h b/src/storm/storage/memorystructure/SparseModelNondeterministicTransitionsBasedMemoryProduct.h index 1f4ab9270..97384d9ef 100644 --- a/src/storm/storage/memorystructure/SparseModelNondeterministicTransitionsBasedMemoryProduct.h +++ b/src/storm/storage/memorystructure/SparseModelNondeterministicTransitionsBasedMemoryProduct.h @@ -45,6 +45,7 @@ namespace storm { storm::storage::NondeterministicMemoryStructure const& memory; std::vector productStates; // has a size equals to the number of states of the original model std::vector fullProductStatesOffset; // has a size equal to the number of states of the full product + storm::storage::BitVector reachableStates; bool forceLabeling; };