From b2d7b1e0964c2fd9f34157e617e9912886f2f5de Mon Sep 17 00:00:00 2001 From: radioGiorgio Date: Fri, 21 Jun 2019 11:21:06 -0700 Subject: [PATCH] choice labeling --- ...rministicTransitionsBasedMemoryProduct.cpp | 98 ++++++++++++++++--- ...terministicTransitionsBasedMemoryProduct.h | 1 + 2 files changed, 86 insertions(+), 13 deletions(-) diff --git a/src/storm/storage/memorystructure/SparseModelNondeterministicTransitionsBasedMemoryProduct.cpp b/src/storm/storage/memorystructure/SparseModelNondeterministicTransitionsBasedMemoryProduct.cpp index bac35724c..8efd085c4 100644 --- a/src/storm/storage/memorystructure/SparseModelNondeterministicTransitionsBasedMemoryProduct.cpp +++ b/src/storm/storage/memorystructure/SparseModelNondeterministicTransitionsBasedMemoryProduct.cpp @@ -18,11 +18,19 @@ namespace storm { storm::storage::sparse::ModelComponents components; components.transitionMatrix = buildTransitions(); components.stateLabeling = buildStateLabeling(); + components.choiceLabeling = buildChoiceLabeling(components.transitionMatrix); // Now delete unreachable states. storm::storage::BitVector allStates(components.transitionMatrix.getRowGroupCount(), true); auto reachableStates = storm::utility::graph::getReachableStates(components.transitionMatrix, components.stateLabeling.getStates("init"), allStates, ~allStates); + storm::storage::BitVector enabledActions(components.transitionMatrix.getRowCount()); + for (uint64_t state : reachableStates) { + for (uint64_t row = components.transitionMatrix.getRowGroupIndices()[state]; row < components.transitionMatrix.getRowGroupIndices()[state + 1]; ++ row) { + enabledActions.set(row); + } + } components.transitionMatrix = components.transitionMatrix.getSubmatrix(true, reachableStates, reachableStates); components.stateLabeling = components.stateLabeling.getSubLabeling(reachableStates); + components.choiceLabeling = components.choiceLabeling->getSubLabeling(enabledActions); // build the remaining components for (auto const& rewModel : model.getRewardModels()) { @@ -130,10 +138,18 @@ namespace storm { for (uint64_t modelState = 0; modelState < model.getNumberOfStates(); ++ modelState) { if (forceLabeling) { for (uint64_t memoryState = 0; memoryState < memory.getNumberOfStates(); ++ memoryState) { - std::ostringstream stream; - stream << "m" << memoryState; - std::string labelName = stream.str(); - addLabel(labelName, getProductState(modelState, memoryState)); + if (labeling.getLabelsOfState(getProductState(modelState, memoryState)).empty()) { + std::ostringstream stream; + stream << "s" << modelState; + std::string labelName = stream.str(); + addLabel(labelName, getProductState(modelState, memoryState)); + } + { + std::ostringstream stream; + stream << "m" << memoryState; + std::string labelName = stream.str(); + addLabel(labelName, getProductState(modelState, memoryState)); + } } } uint64_t entryCount = 0; @@ -143,7 +159,9 @@ namespace storm { for (uint64_t memoryState = 0; memoryState < memory.getNumberOfStates(); ++ memoryState) { uint64_t productState = getProductState(modelState, memoryState) + 1 + entryCount; // origin state - if (model.getStateLabeling().getLabelsOfState(modelState).empty()) { + if ( model.getStateLabeling().getLabelsOfState(modelState).empty() or + (model.getStateLabeling().getLabelsOfState(modelState).size() == 1 and model.getStateLabeling().getStateHasLabel("init", modelState) + and not labeling.getStateHasLabel("init", productState)) ){ if (forceLabeling) { std::ostringstream stream; stream << "s" << modelState; @@ -152,7 +170,9 @@ namespace storm { } } else { for (auto const& labelName : model.getStateLabeling().getLabelsOfState(modelState)) { - addLabel(labelName, productState); + if (labelName != "init") { + addLabel(labelName, productState); + } } } // memory labeling @@ -175,7 +195,9 @@ namespace storm { } } // arrival state - if (model.getStateLabeling().getLabelsOfState(successor).empty()) { + if ( model.getStateLabeling().getLabelsOfState(successor).empty() or + (model.getStateLabeling().getLabelsOfState(successor).size() == 1 and model.getStateLabeling().getStateHasLabel("init", successor) + and not labeling.getStateHasLabel("init", productState)) ){ if (forceLabeling) { std::ostringstream stream; stream << "s" << successor; @@ -184,7 +206,9 @@ namespace storm { } } else { for (auto const& labelName : model.getStateLabeling().getLabelsOfState(successor)) { - addLabel(labelName, productState); + if (labelName != "init") { + addLabel(labelName, productState); + } } } } @@ -195,6 +219,54 @@ namespace storm { return labeling; } + template + storm::models::sparse::ChoiceLabeling SparseModelNondeterministicTransitionsBasedMemoryProduct::buildChoiceLabeling(storm::storage::SparseMatrix const& transitions) const { + storm::storage::SparseMatrix const& origTransitions = model.getTransitionMatrix(); + storm::models::sparse::ChoiceLabeling labeling(transitions.getRowCount()); + + auto addLabel = [&] (std::string const& labelName, uint64_t row) -> void { + if (not labeling.containsLabel(labelName)) { + labeling.addLabel(labelName); + } + labeling.addLabelToChoice(labelName, row); + }; + + uint64_t row = 0; + for (uint64_t modelState = 0; modelState < model.getNumberOfStates(); ++ modelState) { + for (uint64_t memState = 0; memState < memory.getNumberOfStates(); ++ memState) { + for (uint64_t origRow = origTransitions.getRowGroupIndices()[modelState]; origRow < origTransitions.getRowGroupIndices()[modelState + 1]; ++origRow) { + if (forceLabeling and (not model.getOptionalChoiceLabeling() + or model.getChoiceLabeling().getLabelsOfChoice(origRow).empty())) { + std::ostringstream stream; + stream << "a" << origRow; + std::string labelName = stream.str(); + addLabel(labelName, row); + } else if (model.getOptionalChoiceLabeling()) { + for (auto const &labelName : model.getChoiceLabeling().getLabelsOfChoice(origRow)) { + addLabel(labelName, row); + } + } + ++row; + } + // transition states + for (uint64_t origRow = origTransitions.getRowGroupIndices()[modelState]; origRow < origTransitions.getRowGroupIndices()[modelState + 1]; ++origRow) { + for (auto const& entry : origTransitions.getRow(origRow)) { + for (auto const& memStatePrime : memory.getTransitions(memState)) { + if (forceLabeling) { + std::ostringstream stream; + stream << "m" << memStatePrime; + std::string labelName = stream.str(); + addLabel(labelName, row); + } + ++row; + } + } + } + } + } + return labeling; + } + template storm::models::sparse::StandardRewardModel::ValueType> SparseModelNondeterministicTransitionsBasedMemoryProduct::buildRewardModel(storm::models::sparse::StandardRewardModel const& rewardModel, storm::storage::BitVector const& reachableStates, storm::storage::SparseMatrix const& resultTransitionMatrix) const { boost::optional> stateRewards, actionRewards; @@ -235,6 +307,7 @@ namespace storm { template std::vector SparseModelNondeterministicTransitionsBasedMemoryProduct::generateOffsetVector(storm::storage::BitVector const& reachableStates) { uint64_t numberOfStates = model.getNumberOfStates() * memory.getNumberOfStates() * (1 + model.getNumberOfTransitions()); + STORM_LOG_ASSERT(reachableStates.size() == numberOfStates, "wrong size for the vector reachableStates"); uint64_t offset = 0; std::vector offsetVector(numberOfStates); for (uint64_t state = 0; state < numberOfStates; ++ state) { @@ -262,18 +335,17 @@ namespace storm { template uint64_t SparseModelNondeterministicTransitionsBasedMemoryProduct::getModelState(uint64_t productState) const { + uint64_t productStateWithOffset = productState + (fullProductStatesOffset.empty() ? 0 : fullProductStatesOffset[productState]); // binary search in the vector containing the product states indices - auto search = std::upper_bound(productStates.begin(), productStates.end(), productState); - uint64_t index = search - productStates.begin() - 1; - return index - (fullProductStatesOffset.empty() ? 0 : fullProductStatesOffset[index]); + auto search = std::upper_bound(productStates.begin(), productStates.end(), productStateWithOffset); + return search - productStates.begin() - 1; } template uint64_t SparseModelNondeterministicTransitionsBasedMemoryProduct::getMemoryState(uint64_t productState) const { uint64_t modelState = getModelState(productState); uint64_t offset = productState - productStates[modelState]; - uint64_t index = offset / (1 + model.getTransitionMatrix().getRowGroupEntryCount(modelState)); - return index - (fullProductStatesOffset.empty() ? 0 : fullProductStatesOffset[index]); + return offset / (1 + model.getTransitionMatrix().getRowGroupEntryCount(modelState)); } template class SparseModelNondeterministicTransitionsBasedMemoryProduct>; diff --git a/src/storm/storage/memorystructure/SparseModelNondeterministicTransitionsBasedMemoryProduct.h b/src/storm/storage/memorystructure/SparseModelNondeterministicTransitionsBasedMemoryProduct.h index 7f2335593..1f4ab9270 100644 --- a/src/storm/storage/memorystructure/SparseModelNondeterministicTransitionsBasedMemoryProduct.h +++ b/src/storm/storage/memorystructure/SparseModelNondeterministicTransitionsBasedMemoryProduct.h @@ -36,6 +36,7 @@ namespace storm { private: storm::storage::SparseMatrix buildTransitions(); storm::models::sparse::StateLabeling buildStateLabeling() const; + storm::models::sparse::ChoiceLabeling buildChoiceLabeling(storm::storage::SparseMatrix const& transitions) const; storm::models::sparse::StandardRewardModel buildRewardModel(storm::models::sparse::StandardRewardModel const& rewardModel, storm::storage::BitVector const& reachableStates, storm::storage::SparseMatrix const& resultTransitionMatrix) const; std::vector generateOffsetVector(storm::storage::BitVector const& reachableStates);