diff --git a/src/storm-pomdp-cli/settings/modules/POMDPSettings.cpp b/src/storm-pomdp-cli/settings/modules/POMDPSettings.cpp index 6bc5b1e18..8a04f152d 100644 --- a/src/storm-pomdp-cli/settings/modules/POMDPSettings.cpp +++ b/src/storm-pomdp-cli/settings/modules/POMDPSettings.cpp @@ -6,6 +6,8 @@ #include "storm/settings/OptionBuilder.h" #include "storm/settings/ArgumentBuilder.h" +#include "storm/exceptions/InvalidArgumentException.h" + namespace storm { namespace settings { namespace modules { @@ -17,6 +19,8 @@ namespace storm { const std::string mecReductionOption = "mecreduction"; const std::string selfloopReductionOption = "selfloopreduction"; const std::string memoryBoundOption = "memorybound"; + const std::string memoryPatternOption = "memorypattern"; + std::vector memoryPatterns = {"trivial", "fixedcounter", "selectivecounter", "ring", "settablebits", "full"}; const std::string fscmode = "fscmode"; std::vector fscModes = {"standard", "simple-linear", "simple-linear-inverse"}; const std::string transformBinaryOption = "transformbinary"; @@ -29,6 +33,7 @@ namespace storm { this->addOption(storm::settings::OptionBuilder(moduleName, mecReductionOption, false, "Reduces the model size by analyzing maximal end components").build()); this->addOption(storm::settings::OptionBuilder(moduleName, selfloopReductionOption, false, "Reduces the model size by removing self loop actions").build()); this->addOption(storm::settings::OptionBuilder(moduleName, memoryBoundOption, false, "Sets the maximal number of allowed memory states (1 means memoryless schedulers).").addArgument(storm::settings::ArgumentBuilder::createUnsignedIntegerArgument("bound", "The maximal number of memory states.").setDefaultValueUnsignedInteger(1).addValidatorUnsignedInteger(storm::settings::ArgumentValidatorFactory::createUnsignedGreaterValidator(0)).build()).build()); + this->addOption(storm::settings::OptionBuilder(moduleName, memoryPatternOption, false, "Sets the pattern of the considered memory structure").addArgument(storm::settings::ArgumentBuilder::createStringArgument("name", "Pattern name.").addValidatorString(ArgumentValidatorFactory::createMultipleChoiceValidator(memoryPatterns)).setDefaultValueString("full").build()).build()); this->addOption(storm::settings::OptionBuilder(moduleName, fscmode, false, "Sets the way the pMC is obtained").addArgument(storm::settings::ArgumentBuilder::createStringArgument("type", "type name").addValidatorString(ArgumentValidatorFactory::createMultipleChoiceValidator(fscModes)).setDefaultValueString("standard").build()).build()); this->addOption(storm::settings::OptionBuilder(moduleName, transformBinaryOption, false, "Transforms the pomdp to a binary pomdp.").build()); this->addOption(storm::settings::OptionBuilder(moduleName, transformSimpleOption, false, "Transforms the pomdp to a binary and simple pomdp.").build()); @@ -61,6 +66,24 @@ namespace storm { uint64_t POMDPSettings::getMemoryBound() const { return this->getOption(memoryBoundOption).getArgumentByName("bound").getValueAsUnsignedInteger(); } + + storm::storage::PomdpMemoryPattern POMDPSettings::getMemoryPattern() const { + auto pattern = this->getOption(memoryPatternOption).getArgumentByName("name").getValueAsString(); + if (pattern == "trivial") { + return storm::storage::PomdpMemoryPattern::Trivial; + } else if (pattern == "fixedcounter") { + return storm::storage::PomdpMemoryPattern::FixedCounter; + } else if (pattern == "selectivecounter") { + return storm::storage::PomdpMemoryPattern::SelectiveCounter; + } else if (pattern == "ring") { + return storm::storage::PomdpMemoryPattern::Ring; + } else if (pattern == "settablebits") { + return storm::storage::PomdpMemoryPattern::SettableBits; + } else if (pattern == "full") { + return storm::storage::PomdpMemoryPattern::Full; + } + STORM_LOG_THROW(false, storm::exceptions::InvalidArgumentException, "The name of the memory pattern is unknown."); + } std::string POMDPSettings::getFscApplicationTypeString() const { return this->getOption(fscmode).getArgumentByName("type").getValueAsString(); @@ -78,7 +101,7 @@ namespace storm { } bool POMDPSettings::check() const { - // Ensure that at most one of min or max is set + STORM_LOG_THROW(getMemoryPattern() != storm::storage::PomdpMemoryPattern::Trivial || getMemoryBound() == 1, storm::exceptions::InvalidArgumentException, "Memory bound greater one is not possible with the trivial memory pattern."); return true; } diff --git a/src/storm-pomdp-cli/settings/modules/POMDPSettings.h b/src/storm-pomdp-cli/settings/modules/POMDPSettings.h index 213d46be4..1f8b048d3 100644 --- a/src/storm-pomdp-cli/settings/modules/POMDPSettings.h +++ b/src/storm-pomdp-cli/settings/modules/POMDPSettings.h @@ -2,6 +2,7 @@ #include "storm-config.h" #include "storm/settings/modules/ModuleSettings.h" +#include "storm-pomdp/storage/PomdpMemory.h" #include "storm-dft/builder/DftExplorationHeuristic.h" @@ -33,7 +34,7 @@ namespace storm { bool isTransformBinarySet() const; std::string getFscApplicationTypeString() const; uint64_t getMemoryBound() const; - + storm::storage::PomdpMemoryPattern getMemoryPattern() const; bool check() const override; void finalize() override; diff --git a/src/storm-pomdp-cli/storm-pomdp.cpp b/src/storm-pomdp-cli/storm-pomdp.cpp index 22574197f..875564c4c 100644 --- a/src/storm-pomdp-cli/storm-pomdp.cpp +++ b/src/storm-pomdp-cli/storm-pomdp.cpp @@ -124,7 +124,6 @@ int main(const int argc, const char** argv) { storm::analysis::UniqueObservationStates uniqueAnalysis(*pomdp); std::cout << uniqueAnalysis.analyse() << std::endl; } - if (formula) { if (formula->isProbabilityOperatorFormula()) { @@ -155,8 +154,10 @@ int main(const int argc, const char** argv) { } } if (pomdpSettings.getMemoryBound() > 1) { - STORM_PRINT_AND_LOG("Computing the unfolding for memory bound " << pomdpSettings.getMemoryBound() << "..."); - storm::transformer::PomdpMemoryUnfolder memoryUnfolder(*pomdp, pomdpSettings.getMemoryBound()); + STORM_PRINT_AND_LOG("Computing the unfolding for memory bound " << pomdpSettings.getMemoryBound() << " and memory pattern '" << storm::storage::toString(pomdpSettings.getMemoryPattern()) << "' ..."); + storm::storage::PomdpMemory memory = storm::storage::PomdpMemoryBuilder().build(pomdpSettings.getMemoryPattern(), pomdpSettings.getMemoryBound()); + std::cout << memory.toString() << std::endl; + storm::transformer::PomdpMemoryUnfolder memoryUnfolder(*pomdp, memory); pomdp = memoryUnfolder.transform(); STORM_PRINT_AND_LOG(" done." << std::endl); pomdp->printModelInformationToStream(std::cout); diff --git a/src/storm-pomdp/storage/PomdpMemory.cpp b/src/storm-pomdp/storage/PomdpMemory.cpp new file mode 100644 index 000000000..bd9c2593b --- /dev/null +++ b/src/storm-pomdp/storage/PomdpMemory.cpp @@ -0,0 +1,168 @@ +#include "storm-pomdp/storage/PomdpMemory.h" + +#include "storm/utility/macros.h" +#include "storm/exceptions/InvalidArgumentException.h" + +namespace storm { + namespace storage { + + PomdpMemory::PomdpMemory(std::vector const& transitions, uint64_t initialState) : transitions(transitions), initialState(initialState) { + STORM_LOG_THROW(this->initialState < this->transitions.size(), storm::exceptions::InvalidArgumentException, "Initial state " << this->initialState << " of pomdp memory is invalid."); + for (auto const& t : this->transitions) { + STORM_LOG_THROW(t.size() == this->transitions.size(), storm::exceptions::InvalidArgumentException, "Invalid dimension of transition matrix of pomdp memory."); + STORM_LOG_THROW(!t.empty(), storm::exceptions::InvalidArgumentException, "Invalid transition matrix of pomdp memory: No deadlock states allowed."); + } + } + + uint64_t PomdpMemory::getNumberOfStates() const { + return transitions.size(); + } + + uint64_t PomdpMemory::getInitialState() const { + return initialState; + } + + storm::storage::BitVector const& PomdpMemory::getTransitions(uint64_t state) const { + return transitions.at(state); + } + + uint64_t PomdpMemory::getNumberOfOutgoingTransitions(uint64_t state) const { + return getTransitions(state).getNumberOfSetBits(); + } + + std::vector const& PomdpMemory::getTransitions() const { + return transitions; + } + + std::string PomdpMemory::toString() const { + std::string result = "PomdpMemory with " + std::to_string(getNumberOfStates()) + " states.\n"; + result += "Initial state is " + std::to_string(getInitialState()) + ". Transitions are \n"; + + // header + result += " |"; + for (uint64_t state = 0; state < getNumberOfStates(); ++state) { + if (state < 10) { + result += " "; + } + result += std::to_string(state); + } + result += "\n"; + result += "--|"; + for (uint64_t state = 0; state < getNumberOfStates(); ++state) { + result += "--"; + } + result += "\n"; + + // transition matrix entries + for (uint64_t state = 0; state < getNumberOfStates(); ++state) { + if (state < 10) { + result += " "; + } + result += std::to_string(state) + "|"; + for (uint64_t statePrime = 0; statePrime < getNumberOfStates(); ++statePrime) { + result += " "; + if (getTransitions(state).get(statePrime)) { + result += "1"; + } else { + result += "0"; + } + } + result += "\n"; + } + return result; + } + + std::string toString(PomdpMemoryPattern const& pattern) { + switch (pattern) { + case PomdpMemoryPattern::Trivial: + return "trivial"; + case PomdpMemoryPattern::FixedCounter: + return "fixedcounter"; + case PomdpMemoryPattern::SelectiveCounter: + return "selectivecounter"; + case PomdpMemoryPattern::Ring: + return "ring"; + case PomdpMemoryPattern::SettableBits: + return "settablebits"; + case PomdpMemoryPattern::Full: + return "full"; + } + return "unknown"; + } + + PomdpMemory PomdpMemoryBuilder::build(PomdpMemoryPattern pattern, uint64_t numStates) const { + switch (pattern) { + case PomdpMemoryPattern::Trivial: + STORM_LOG_ERROR_COND(numStates == 1, "Invoked building trivial POMDP memory with " << numStates << " states. However, trivial POMDP memory always has one state."); + return buildTrivialMemory(); + case PomdpMemoryPattern::FixedCounter: + return buildFixedCountingMemory(numStates); + case PomdpMemoryPattern::SelectiveCounter: + return buildSelectiveCountingMemory(numStates); + case PomdpMemoryPattern::Ring: + return buildRingMemory(numStates); + case PomdpMemoryPattern::SettableBits: + return buildSettableBitsMemory(numStates); + case PomdpMemoryPattern::Full: + return buildFullyConnectedMemory(numStates); + } + } + + PomdpMemory PomdpMemoryBuilder::buildTrivialMemory() const { + return buildFullyConnectedMemory(1); + } + + PomdpMemory PomdpMemoryBuilder::buildFixedCountingMemory(uint64_t numStates) const { + std::vector transitions(numStates, storm::storage::BitVector(numStates, false)); + for (uint64_t state = 0; state < numStates; ++state) { + transitions[state].set(std::min(state + 1, numStates - 1)); + } + return PomdpMemory(transitions, 0); + } + + PomdpMemory PomdpMemoryBuilder::buildSelectiveCountingMemory(uint64_t numStates) const { + std::vector transitions(numStates, storm::storage::BitVector(numStates, false)); + for (uint64_t state = 0; state < numStates; ++state) { + transitions[state].set(state); + transitions[state].set(std::min(state + 1, numStates - 1)); + } + return PomdpMemory(transitions, 0); + } + + PomdpMemory PomdpMemoryBuilder::buildRingMemory(uint64_t numStates) const { + std::vector transitions(numStates, storm::storage::BitVector(numStates, false)); + for (uint64_t state = 0; state < numStates; ++state) { + transitions[state].set(state); + transitions[state].set((state + 1) % numStates); + } + return PomdpMemory(transitions, 0); + } + + PomdpMemory PomdpMemoryBuilder::buildSettableBitsMemory(uint64_t numStates) const { + // compute the number of bits, i.e., floor(log(numStates)) + uint64_t numBits = 0; + uint64_t actualNumStates = 1; + while (actualNumStates * 2 <= numStates) { + actualNumStates *= 2; + ++numBits; + } + + STORM_LOG_WARN_COND(actualNumStates == numStates, "The number of memory states for the settable bits pattern has to be a power of 2. Shrinking the number of memory states to " << actualNumStates << "."); + + std::vector transitions(actualNumStates, storm::storage::BitVector(actualNumStates, false)); + for (uint64_t state = 0; state < actualNumStates; ++state) { + transitions[state].set(state); + for (uint64_t bit = 0; bit < numBits; ++bit) { + uint64_t bitMask = 1u << bit; + transitions[state].set(state | bitMask); + } + } + return PomdpMemory(transitions, 0); + } + + PomdpMemory PomdpMemoryBuilder::buildFullyConnectedMemory(uint64_t numStates) const { + std::vector transitions(numStates, storm::storage::BitVector(numStates, true)); + return PomdpMemory(transitions, 0); + } + } +} \ No newline at end of file diff --git a/src/storm-pomdp/storage/PomdpMemory.h b/src/storm-pomdp/storage/PomdpMemory.h new file mode 100644 index 000000000..50ff309fa --- /dev/null +++ b/src/storm-pomdp/storage/PomdpMemory.h @@ -0,0 +1,60 @@ +#pragma once + +#include +#include "storm/storage/BitVector.h" +#include "storm/exceptions/InvalidArgumentException.h" + +namespace storm { + namespace storage { + + class PomdpMemory { + + public: + PomdpMemory(std::vector const& transitions, uint64_t initialState); + uint64_t getNumberOfStates() const; + uint64_t getInitialState() const; + storm::storage::BitVector const& getTransitions(uint64_t state) const; + uint64_t getNumberOfOutgoingTransitions(uint64_t state) const; + std::vector const& getTransitions() const; + std::string toString() const; + private: + std::vector transitions; + uint64_t initialState; + }; + + enum class PomdpMemoryPattern { + Trivial, FixedCounter, SelectiveCounter, Ring, SettableBits, Full + }; + + std::string toString(PomdpMemoryPattern const& pattern); + + class PomdpMemoryBuilder { + public: + // Builds a memory structure with the given pattern and the given number of states. + PomdpMemory build(PomdpMemoryPattern pattern, uint64_t numStates) const; + + // Builds a memory structure that consists of just a single memory state + PomdpMemory buildTrivialMemory() const; + + // Builds a memory structure that consists of a chain of the given number of states. + // Every state has exactly one transition to the next state. The last state has just a selfloop. + PomdpMemory buildFixedCountingMemory(uint64_t numStates) const; + + // Builds a memory structure that consists of a chain of the given number of states. + // Every state has a selfloop and a transition to the next state. The last state just has a selfloop. + PomdpMemory buildSelectiveCountingMemory(uint64_t numStates) const; + + // Builds a memory structure that consists of a ring of the given number of states. + // Every state has a transition to the successor state and a selfloop + PomdpMemory buildRingMemory(uint64_t numStates) const; + + // Builds a memory structure that represents floor(log(numStates)) bits that can only be set from zero to one or from zero to zero. + PomdpMemory buildSettableBitsMemory(uint64_t numStates) const; + + // Builds a memory structure that consists of the given number of states which are fully connected. + PomdpMemory buildFullyConnectedMemory(uint64_t numStates) const; + + }; + + } +} \ No newline at end of file diff --git a/src/storm-pomdp/transformer/PomdpMemoryUnfolder.cpp b/src/storm-pomdp/transformer/PomdpMemoryUnfolder.cpp index 83137f8a8..827493efc 100644 --- a/src/storm-pomdp/transformer/PomdpMemoryUnfolder.cpp +++ b/src/storm-pomdp/transformer/PomdpMemoryUnfolder.cpp @@ -1,6 +1,8 @@ -#include #include "storm-pomdp/transformer/PomdpMemoryUnfolder.h" + +#include #include "storm/storage/sparse/ModelComponents.h" +#include "storm/utility/graph.h" #include "storm/exceptions/NotSupportedException.h" @@ -9,41 +11,56 @@ namespace storm { template - PomdpMemoryUnfolder::PomdpMemoryUnfolder(storm::models::sparse::Pomdp const& pomdp, uint64_t numMemoryStates) : pomdp(pomdp), numMemoryStates(numMemoryStates) { + PomdpMemoryUnfolder::PomdpMemoryUnfolder(storm::models::sparse::Pomdp const& pomdp, storm::storage::PomdpMemory const& memory) : pomdp(pomdp), memory(memory) { // intentionally left empty } template std::shared_ptr> PomdpMemoryUnfolder::transform() const { + // For simplicity we first build the 'full' product of pomdp and memory (with pomdp.numStates * memory.numStates states). storm::storage::sparse::ModelComponents components; components.transitionMatrix = transformTransitions(); components.stateLabeling = transformStateLabeling(); - components.observabilityClasses = transformObservabilityClasses(); + + // Now delete unreachable states. + storm::storage::BitVector allStates(components.transitionMatrix.getRowGroupCount(), true); + auto reachableStates = storm::utility::graph::getReachableStates(components.transitionMatrix, components.stateLabeling.getStates("init"), allStates, ~allStates); + components.transitionMatrix = components.transitionMatrix.getSubmatrix(true, reachableStates, reachableStates); + components.stateLabeling = components.stateLabeling.getSubLabeling(reachableStates); + + // build the remaining components + components.observabilityClasses = transformObservabilityClasses(reachableStates); for (auto const& rewModel : pomdp.getRewardModels()) { - components.rewardModels.emplace(rewModel.first, transformRewardModel(rewModel.second)); + components.rewardModels.emplace(rewModel.first, transformRewardModel(rewModel.second, reachableStates)); } - return std::make_shared>(std::move(components)); } - template storm::storage::SparseMatrix PomdpMemoryUnfolder::transformTransitions() const { storm::storage::SparseMatrix const& origTransitions = pomdp.getTransitionMatrix(); - storm::storage::SparseMatrixBuilder builder(pomdp.getNumberOfChoices() * numMemoryStates * numMemoryStates, - pomdp.getNumberOfStates() * numMemoryStates, - origTransitions.getEntryCount() * numMemoryStates * numMemoryStates, + uint64_t numRows = 0; + uint64_t numEntries = 0; + for (uint64_t modelState = 0; modelState < pomdp.getNumberOfStates(); ++modelState) { + for (uint64_t memState = 0; memState < memory.getNumberOfStates(); ++memState) { + numRows += origTransitions.getRowGroupSize(modelState) * memory.getNumberOfOutgoingTransitions(memState); + numEntries += origTransitions.getRowGroup(modelState).getNumberOfEntries() * memory.getNumberOfOutgoingTransitions(memState); + } + } + storm::storage::SparseMatrixBuilder builder(numRows, + pomdp.getNumberOfStates() * memory.getNumberOfStates(), + numEntries, true, true, - pomdp.getNumberOfStates() * numMemoryStates); + pomdp.getNumberOfStates() * memory.getNumberOfStates()); uint64_t row = 0; for (uint64_t modelState = 0; modelState < pomdp.getNumberOfStates(); ++modelState) { - for (uint32_t memState = 0; memState < numMemoryStates; ++memState) { + for (uint64_t memState = 0; memState < memory.getNumberOfStates(); ++memState) { builder.newRowGroup(row); for (uint64_t origRow = origTransitions.getRowGroupIndices()[modelState]; origRow < origTransitions.getRowGroupIndices()[modelState + 1]; ++origRow) { - for (uint32_t memStatePrime = 0; memStatePrime < numMemoryStates; ++memStatePrime) { + for (auto const& memStatePrime : memory.getTransitions(memState)) { for (auto const& entry : origTransitions.getRow(origRow)) { builder.addNextValue(row, getUnfoldingState(entry.getColumn(), memStatePrime), entry.getValue()); } @@ -57,18 +74,18 @@ namespace storm { template storm::models::sparse::StateLabeling PomdpMemoryUnfolder::transformStateLabeling() const { - storm::models::sparse::StateLabeling labeling(pomdp.getNumberOfStates() * numMemoryStates); + storm::models::sparse::StateLabeling labeling(pomdp.getNumberOfStates() * memory.getNumberOfStates()); for (auto const& labelName : pomdp.getStateLabeling().getLabels()) { - storm::storage::BitVector newStates(pomdp.getNumberOfStates() * numMemoryStates, false); + storm::storage::BitVector newStates(pomdp.getNumberOfStates() * memory.getNumberOfStates(), false); - // The init label is only assigned to unfolding states with memState 0 + // The init label is only assigned to unfolding states with the initial memory state if (labelName == "init") { for (auto const& modelState : pomdp.getStateLabeling().getStates(labelName)) { - newStates.set(getUnfoldingState(modelState, 0)); + newStates.set(getUnfoldingState(modelState, memory.getInitialState())); } } else { for (auto const& modelState : pomdp.getStateLabeling().getStates(labelName)) { - for (uint32_t memState = 0; memState < numMemoryStates; ++memState) { + for (uint64_t memState = 0; memState < memory.getNumberOfStates(); ++memState) { newStates.set(getUnfoldingState(modelState, memState)); } } @@ -79,38 +96,55 @@ namespace storm { } template - std::vector PomdpMemoryUnfolder::transformObservabilityClasses() const { + std::vector PomdpMemoryUnfolder::transformObservabilityClasses(storm::storage::BitVector const& reachableStates) const { std::vector observations; - observations.reserve(pomdp.getNumberOfStates() * numMemoryStates); + observations.reserve(pomdp.getNumberOfStates() * memory.getNumberOfStates()); for (uint64_t modelState = 0; modelState < pomdp.getNumberOfStates(); ++modelState) { - for (uint32_t memState = 0; memState < numMemoryStates; ++memState) { - observations.push_back(getUnfoldingObersvation(pomdp.getObservation(modelState), memState)); + for (uint64_t memState = 0; memState < memory.getNumberOfStates(); ++memState) { + if (reachableStates.get(getUnfoldingState(modelState, memState))) { + observations.push_back(getUnfoldingObersvation(pomdp.getObservation(modelState), memState)); + } } } + + // Eliminate observations that are not in use (as they are not reachable). + std::set occuringObservations(observations.begin(), observations.end()); + uint32_t highestObservation = *occuringObservations.rbegin(); + std::vector oldToNewObservationMapping(highestObservation + 1, std::numeric_limits::max()); + uint32_t newObs = 0; + for (auto const& oldObs : occuringObservations) { + oldToNewObservationMapping[oldObs] = newObs; + ++newObs; + } + for (auto& obs : observations) { + obs = oldToNewObservationMapping[obs]; + } + return observations; } template - storm::models::sparse::StandardRewardModel PomdpMemoryUnfolder::transformRewardModel(storm::models::sparse::StandardRewardModel const& rewardModel) const { + storm::models::sparse::StandardRewardModel PomdpMemoryUnfolder::transformRewardModel(storm::models::sparse::StandardRewardModel const& rewardModel, storm::storage::BitVector const& reachableStates) const { boost::optional> stateRewards, actionRewards; if (rewardModel.hasStateRewards()) { stateRewards = std::vector(); - stateRewards->reserve(pomdp.getNumberOfStates() * numMemoryStates); - for (auto const& stateReward : rewardModel.getStateRewardVector()) { - for (uint32_t memState = 0; memState < numMemoryStates; ++memState) { - stateRewards->push_back(stateReward); + stateRewards->reserve(pomdp.getNumberOfStates() * memory.getNumberOfStates()); + for (uint64_t modelState = 0; modelState < pomdp.getNumberOfStates(); ++modelState) { + for (uint64_t memState = 0; memState < memory.getNumberOfStates(); ++memState) { + if (reachableStates.get(getUnfoldingState(modelState, memState))) { + stateRewards->push_back(rewardModel.getStateReward(modelState)); + } } } } if (rewardModel.hasStateActionRewards()) { actionRewards = std::vector(); - actionRewards->reserve(pomdp.getNumberOfStates() * numMemoryStates * numMemoryStates); for (uint64_t modelState = 0; modelState < pomdp.getNumberOfStates(); ++modelState) { - for (uint32_t memState = 0; memState < numMemoryStates; ++memState) { - for (uint64_t origRow = pomdp.getTransitionMatrix().getRowGroupIndices()[modelState]; origRow < pomdp.getTransitionMatrix().getRowGroupIndices()[modelState + 1]; ++origRow) { - ValueType const& actionReward = rewardModel.getStateActionReward(origRow); - for (uint32_t memStatePrime = 0; memStatePrime < numMemoryStates; ++memStatePrime) { - actionRewards->push_back(actionReward); + for (uint64_t memState = 0; memState < memory.getNumberOfStates(); ++memState) { + if (reachableStates.get(getUnfoldingState(modelState, memState))) { + for (uint64_t origRow = pomdp.getTransitionMatrix().getRowGroupIndices()[modelState]; origRow < pomdp.getTransitionMatrix().getRowGroupIndices()[modelState + 1]; ++origRow) { + ValueType const& actionReward = rewardModel.getStateActionReward(origRow); + actionRewards->insert(actionRewards->end(), memory.getNumberOfOutgoingTransitions(memState), actionReward); } } } @@ -121,33 +155,33 @@ namespace storm { } template - uint64_t PomdpMemoryUnfolder::getUnfoldingState(uint64_t modelState, uint32_t memoryState) const { - return modelState * numMemoryStates + memoryState; + uint64_t PomdpMemoryUnfolder::getUnfoldingState(uint64_t modelState, uint64_t memoryState) const { + return modelState * memory.getNumberOfStates() + memoryState; } template uint64_t PomdpMemoryUnfolder::getModelState(uint64_t unfoldingState) const { - return unfoldingState / numMemoryStates; + return unfoldingState / memory.getNumberOfStates(); } template - uint32_t PomdpMemoryUnfolder::getMemoryState(uint64_t unfoldingState) const { - return unfoldingState % numMemoryStates; + uint64_t PomdpMemoryUnfolder::getMemoryState(uint64_t unfoldingState) const { + return unfoldingState % memory.getNumberOfStates(); } template - uint32_t PomdpMemoryUnfolder::getUnfoldingObersvation(uint32_t modelObservation, uint32_t memoryState) const { - return modelObservation * numMemoryStates + memoryState; + uint32_t PomdpMemoryUnfolder::getUnfoldingObersvation(uint32_t modelObservation, uint64_t memoryState) const { + return modelObservation * memory.getNumberOfStates() + memoryState; } template uint32_t PomdpMemoryUnfolder::getModelObersvation(uint32_t unfoldingObservation) const { - return unfoldingObservation / numMemoryStates; + return unfoldingObservation / memory.getNumberOfStates(); } template - uint32_t PomdpMemoryUnfolder::getMemoryStateFromObservation(uint32_t unfoldingObservation) const { - return unfoldingObservation % numMemoryStates; + uint64_t PomdpMemoryUnfolder::getMemoryStateFromObservation(uint32_t unfoldingObservation) const { + return unfoldingObservation % memory.getNumberOfStates(); } template class PomdpMemoryUnfolder; diff --git a/src/storm-pomdp/transformer/PomdpMemoryUnfolder.h b/src/storm-pomdp/transformer/PomdpMemoryUnfolder.h index 8b72e64a5..69b18559a 100644 --- a/src/storm-pomdp/transformer/PomdpMemoryUnfolder.h +++ b/src/storm-pomdp/transformer/PomdpMemoryUnfolder.h @@ -1,6 +1,7 @@ #pragma once #include "storm/models/sparse/Pomdp.h" +#include "storm-pomdp/storage/PomdpMemory.h" #include "storm/models/sparse/StandardRewardModel.h" namespace storm { @@ -11,27 +12,27 @@ namespace storm { public: - PomdpMemoryUnfolder(storm::models::sparse::Pomdp const& pomdp, uint64_t numMemoryStates); + PomdpMemoryUnfolder(storm::models::sparse::Pomdp const& pomdp, storm::storage::PomdpMemory const& memory); std::shared_ptr> transform() const; private: storm::storage::SparseMatrix transformTransitions() const; storm::models::sparse::StateLabeling transformStateLabeling() const; - std::vector transformObservabilityClasses() const; - storm::models::sparse::StandardRewardModel transformRewardModel(storm::models::sparse::StandardRewardModel const& rewardModel) const; + std::vector transformObservabilityClasses(storm::storage::BitVector const& reachableStates) const; + storm::models::sparse::StandardRewardModel transformRewardModel(storm::models::sparse::StandardRewardModel const& rewardModel, storm::storage::BitVector const& reachableStates) const; - uint64_t getUnfoldingState(uint64_t modelState, uint32_t memoryState) const; + uint64_t getUnfoldingState(uint64_t modelState, uint64_t memoryState) const; uint64_t getModelState(uint64_t unfoldingState) const; - uint32_t getMemoryState(uint64_t unfoldingState) const; + uint64_t getMemoryState(uint64_t unfoldingState) const; - uint32_t getUnfoldingObersvation(uint32_t modelObservation, uint32_t memoryState) const; + uint32_t getUnfoldingObersvation(uint32_t modelObservation, uint64_t memoryState) const; uint32_t getModelObersvation(uint32_t unfoldingObservation) const; - uint32_t getMemoryStateFromObservation(uint32_t unfoldingObservation) const; + uint64_t getMemoryStateFromObservation(uint32_t unfoldingObservation) const; storm::models::sparse::Pomdp const& pomdp; - uint32_t numMemoryStates; + storm::storage::PomdpMemory const& memory; }; } } \ No newline at end of file