Browse Source

introducing pomdp memory patterns

tempestpy_adaptions
TimQu 7 years ago
parent
commit
d06c2c791a
  1. 25
      src/storm-pomdp-cli/settings/modules/POMDPSettings.cpp
  2. 3
      src/storm-pomdp-cli/settings/modules/POMDPSettings.h
  3. 7
      src/storm-pomdp-cli/storm-pomdp.cpp
  4. 168
      src/storm-pomdp/storage/PomdpMemory.cpp
  5. 60
      src/storm-pomdp/storage/PomdpMemory.h
  6. 118
      src/storm-pomdp/transformer/PomdpMemoryUnfolder.cpp
  7. 17
      src/storm-pomdp/transformer/PomdpMemoryUnfolder.h

25
src/storm-pomdp-cli/settings/modules/POMDPSettings.cpp

@ -6,6 +6,8 @@
#include "storm/settings/OptionBuilder.h"
#include "storm/settings/ArgumentBuilder.h"
#include "storm/exceptions/InvalidArgumentException.h"
namespace storm {
namespace settings {
namespace modules {
@ -17,6 +19,8 @@ namespace storm {
const std::string mecReductionOption = "mecreduction";
const std::string selfloopReductionOption = "selfloopreduction";
const std::string memoryBoundOption = "memorybound";
const std::string memoryPatternOption = "memorypattern";
std::vector<std::string> memoryPatterns = {"trivial", "fixedcounter", "selectivecounter", "ring", "settablebits", "full"};
const std::string fscmode = "fscmode";
std::vector<std::string> fscModes = {"standard", "simple-linear", "simple-linear-inverse"};
const std::string transformBinaryOption = "transformbinary";
@ -29,6 +33,7 @@ namespace storm {
this->addOption(storm::settings::OptionBuilder(moduleName, mecReductionOption, false, "Reduces the model size by analyzing maximal end components").build());
this->addOption(storm::settings::OptionBuilder(moduleName, selfloopReductionOption, false, "Reduces the model size by removing self loop actions").build());
this->addOption(storm::settings::OptionBuilder(moduleName, memoryBoundOption, false, "Sets the maximal number of allowed memory states (1 means memoryless schedulers).").addArgument(storm::settings::ArgumentBuilder::createUnsignedIntegerArgument("bound", "The maximal number of memory states.").setDefaultValueUnsignedInteger(1).addValidatorUnsignedInteger(storm::settings::ArgumentValidatorFactory::createUnsignedGreaterValidator(0)).build()).build());
this->addOption(storm::settings::OptionBuilder(moduleName, memoryPatternOption, false, "Sets the pattern of the considered memory structure").addArgument(storm::settings::ArgumentBuilder::createStringArgument("name", "Pattern name.").addValidatorString(ArgumentValidatorFactory::createMultipleChoiceValidator(memoryPatterns)).setDefaultValueString("full").build()).build());
this->addOption(storm::settings::OptionBuilder(moduleName, fscmode, false, "Sets the way the pMC is obtained").addArgument(storm::settings::ArgumentBuilder::createStringArgument("type", "type name").addValidatorString(ArgumentValidatorFactory::createMultipleChoiceValidator(fscModes)).setDefaultValueString("standard").build()).build());
this->addOption(storm::settings::OptionBuilder(moduleName, transformBinaryOption, false, "Transforms the pomdp to a binary pomdp.").build());
this->addOption(storm::settings::OptionBuilder(moduleName, transformSimpleOption, false, "Transforms the pomdp to a binary and simple pomdp.").build());
@ -61,6 +66,24 @@ namespace storm {
uint64_t POMDPSettings::getMemoryBound() const {
return this->getOption(memoryBoundOption).getArgumentByName("bound").getValueAsUnsignedInteger();
}
storm::storage::PomdpMemoryPattern POMDPSettings::getMemoryPattern() const {
auto pattern = this->getOption(memoryPatternOption).getArgumentByName("name").getValueAsString();
if (pattern == "trivial") {
return storm::storage::PomdpMemoryPattern::Trivial;
} else if (pattern == "fixedcounter") {
return storm::storage::PomdpMemoryPattern::FixedCounter;
} else if (pattern == "selectivecounter") {
return storm::storage::PomdpMemoryPattern::SelectiveCounter;
} else if (pattern == "ring") {
return storm::storage::PomdpMemoryPattern::Ring;
} else if (pattern == "settablebits") {
return storm::storage::PomdpMemoryPattern::SettableBits;
} else if (pattern == "full") {
return storm::storage::PomdpMemoryPattern::Full;
}
STORM_LOG_THROW(false, storm::exceptions::InvalidArgumentException, "The name of the memory pattern is unknown.");
}
std::string POMDPSettings::getFscApplicationTypeString() const {
return this->getOption(fscmode).getArgumentByName("type").getValueAsString();
@ -78,7 +101,7 @@ namespace storm {
}
bool POMDPSettings::check() const {
// Ensure that at most one of min or max is set
STORM_LOG_THROW(getMemoryPattern() != storm::storage::PomdpMemoryPattern::Trivial || getMemoryBound() == 1, storm::exceptions::InvalidArgumentException, "Memory bound greater one is not possible with the trivial memory pattern.");
return true;
}

3
src/storm-pomdp-cli/settings/modules/POMDPSettings.h

@ -2,6 +2,7 @@
#include "storm-config.h"
#include "storm/settings/modules/ModuleSettings.h"
#include "storm-pomdp/storage/PomdpMemory.h"
#include "storm-dft/builder/DftExplorationHeuristic.h"
@ -33,7 +34,7 @@ namespace storm {
bool isTransformBinarySet() const;
std::string getFscApplicationTypeString() const;
uint64_t getMemoryBound() const;
storm::storage::PomdpMemoryPattern getMemoryPattern() const;
bool check() const override;
void finalize() override;

7
src/storm-pomdp-cli/storm-pomdp.cpp

@ -124,7 +124,6 @@ int main(const int argc, const char** argv) {
storm::analysis::UniqueObservationStates<storm::RationalNumber> uniqueAnalysis(*pomdp);
std::cout << uniqueAnalysis.analyse() << std::endl;
}
if (formula) {
if (formula->isProbabilityOperatorFormula()) {
@ -155,8 +154,10 @@ int main(const int argc, const char** argv) {
}
}
if (pomdpSettings.getMemoryBound() > 1) {
STORM_PRINT_AND_LOG("Computing the unfolding for memory bound " << pomdpSettings.getMemoryBound() << "...");
storm::transformer::PomdpMemoryUnfolder<storm::RationalNumber> memoryUnfolder(*pomdp, pomdpSettings.getMemoryBound());
STORM_PRINT_AND_LOG("Computing the unfolding for memory bound " << pomdpSettings.getMemoryBound() << " and memory pattern '" << storm::storage::toString(pomdpSettings.getMemoryPattern()) << "' ...");
storm::storage::PomdpMemory memory = storm::storage::PomdpMemoryBuilder().build(pomdpSettings.getMemoryPattern(), pomdpSettings.getMemoryBound());
std::cout << memory.toString() << std::endl;
storm::transformer::PomdpMemoryUnfolder<storm::RationalNumber> memoryUnfolder(*pomdp, memory);
pomdp = memoryUnfolder.transform();
STORM_PRINT_AND_LOG(" done." << std::endl);
pomdp->printModelInformationToStream(std::cout);

168
src/storm-pomdp/storage/PomdpMemory.cpp

@ -0,0 +1,168 @@
#include "storm-pomdp/storage/PomdpMemory.h"
#include "storm/utility/macros.h"
#include "storm/exceptions/InvalidArgumentException.h"
namespace storm {
namespace storage {
PomdpMemory::PomdpMemory(std::vector<storm::storage::BitVector> const& transitions, uint64_t initialState) : transitions(transitions), initialState(initialState) {
STORM_LOG_THROW(this->initialState < this->transitions.size(), storm::exceptions::InvalidArgumentException, "Initial state " << this->initialState << " of pomdp memory is invalid.");
for (auto const& t : this->transitions) {
STORM_LOG_THROW(t.size() == this->transitions.size(), storm::exceptions::InvalidArgumentException, "Invalid dimension of transition matrix of pomdp memory.");
STORM_LOG_THROW(!t.empty(), storm::exceptions::InvalidArgumentException, "Invalid transition matrix of pomdp memory: No deadlock states allowed.");
}
}
uint64_t PomdpMemory::getNumberOfStates() const {
return transitions.size();
}
uint64_t PomdpMemory::getInitialState() const {
return initialState;
}
storm::storage::BitVector const& PomdpMemory::getTransitions(uint64_t state) const {
return transitions.at(state);
}
uint64_t PomdpMemory::getNumberOfOutgoingTransitions(uint64_t state) const {
return getTransitions(state).getNumberOfSetBits();
}
std::vector<storm::storage::BitVector> const& PomdpMemory::getTransitions() const {
return transitions;
}
std::string PomdpMemory::toString() const {
std::string result = "PomdpMemory with " + std::to_string(getNumberOfStates()) + " states.\n";
result += "Initial state is " + std::to_string(getInitialState()) + ". Transitions are \n";
// header
result += " |";
for (uint64_t state = 0; state < getNumberOfStates(); ++state) {
if (state < 10) {
result += " ";
}
result += std::to_string(state);
}
result += "\n";
result += "--|";
for (uint64_t state = 0; state < getNumberOfStates(); ++state) {
result += "--";
}
result += "\n";
// transition matrix entries
for (uint64_t state = 0; state < getNumberOfStates(); ++state) {
if (state < 10) {
result += " ";
}
result += std::to_string(state) + "|";
for (uint64_t statePrime = 0; statePrime < getNumberOfStates(); ++statePrime) {
result += " ";
if (getTransitions(state).get(statePrime)) {
result += "1";
} else {
result += "0";
}
}
result += "\n";
}
return result;
}
std::string toString(PomdpMemoryPattern const& pattern) {
switch (pattern) {
case PomdpMemoryPattern::Trivial:
return "trivial";
case PomdpMemoryPattern::FixedCounter:
return "fixedcounter";
case PomdpMemoryPattern::SelectiveCounter:
return "selectivecounter";
case PomdpMemoryPattern::Ring:
return "ring";
case PomdpMemoryPattern::SettableBits:
return "settablebits";
case PomdpMemoryPattern::Full:
return "full";
}
return "unknown";
}
PomdpMemory PomdpMemoryBuilder::build(PomdpMemoryPattern pattern, uint64_t numStates) const {
switch (pattern) {
case PomdpMemoryPattern::Trivial:
STORM_LOG_ERROR_COND(numStates == 1, "Invoked building trivial POMDP memory with " << numStates << " states. However, trivial POMDP memory always has one state.");
return buildTrivialMemory();
case PomdpMemoryPattern::FixedCounter:
return buildFixedCountingMemory(numStates);
case PomdpMemoryPattern::SelectiveCounter:
return buildSelectiveCountingMemory(numStates);
case PomdpMemoryPattern::Ring:
return buildRingMemory(numStates);
case PomdpMemoryPattern::SettableBits:
return buildSettableBitsMemory(numStates);
case PomdpMemoryPattern::Full:
return buildFullyConnectedMemory(numStates);
}
}
PomdpMemory PomdpMemoryBuilder::buildTrivialMemory() const {
return buildFullyConnectedMemory(1);
}
PomdpMemory PomdpMemoryBuilder::buildFixedCountingMemory(uint64_t numStates) const {
std::vector<storm::storage::BitVector> transitions(numStates, storm::storage::BitVector(numStates, false));
for (uint64_t state = 0; state < numStates; ++state) {
transitions[state].set(std::min(state + 1, numStates - 1));
}
return PomdpMemory(transitions, 0);
}
PomdpMemory PomdpMemoryBuilder::buildSelectiveCountingMemory(uint64_t numStates) const {
std::vector<storm::storage::BitVector> transitions(numStates, storm::storage::BitVector(numStates, false));
for (uint64_t state = 0; state < numStates; ++state) {
transitions[state].set(state);
transitions[state].set(std::min(state + 1, numStates - 1));
}
return PomdpMemory(transitions, 0);
}
PomdpMemory PomdpMemoryBuilder::buildRingMemory(uint64_t numStates) const {
std::vector<storm::storage::BitVector> transitions(numStates, storm::storage::BitVector(numStates, false));
for (uint64_t state = 0; state < numStates; ++state) {
transitions[state].set(state);
transitions[state].set((state + 1) % numStates);
}
return PomdpMemory(transitions, 0);
}
PomdpMemory PomdpMemoryBuilder::buildSettableBitsMemory(uint64_t numStates) const {
// compute the number of bits, i.e., floor(log(numStates))
uint64_t numBits = 0;
uint64_t actualNumStates = 1;
while (actualNumStates * 2 <= numStates) {
actualNumStates *= 2;
++numBits;
}
STORM_LOG_WARN_COND(actualNumStates == numStates, "The number of memory states for the settable bits pattern has to be a power of 2. Shrinking the number of memory states to " << actualNumStates << ".");
std::vector<storm::storage::BitVector> transitions(actualNumStates, storm::storage::BitVector(actualNumStates, false));
for (uint64_t state = 0; state < actualNumStates; ++state) {
transitions[state].set(state);
for (uint64_t bit = 0; bit < numBits; ++bit) {
uint64_t bitMask = 1u << bit;
transitions[state].set(state | bitMask);
}
}
return PomdpMemory(transitions, 0);
}
PomdpMemory PomdpMemoryBuilder::buildFullyConnectedMemory(uint64_t numStates) const {
std::vector<storm::storage::BitVector> transitions(numStates, storm::storage::BitVector(numStates, true));
return PomdpMemory(transitions, 0);
}
}
}

60
src/storm-pomdp/storage/PomdpMemory.h

@ -0,0 +1,60 @@
#pragma once
#include <vector>
#include "storm/storage/BitVector.h"
#include "storm/exceptions/InvalidArgumentException.h"
namespace storm {
namespace storage {
class PomdpMemory {
public:
PomdpMemory(std::vector<storm::storage::BitVector> const& transitions, uint64_t initialState);
uint64_t getNumberOfStates() const;
uint64_t getInitialState() const;
storm::storage::BitVector const& getTransitions(uint64_t state) const;
uint64_t getNumberOfOutgoingTransitions(uint64_t state) const;
std::vector<storm::storage::BitVector> const& getTransitions() const;
std::string toString() const;
private:
std::vector<storm::storage::BitVector> transitions;
uint64_t initialState;
};
enum class PomdpMemoryPattern {
Trivial, FixedCounter, SelectiveCounter, Ring, SettableBits, Full
};
std::string toString(PomdpMemoryPattern const& pattern);
class PomdpMemoryBuilder {
public:
// Builds a memory structure with the given pattern and the given number of states.
PomdpMemory build(PomdpMemoryPattern pattern, uint64_t numStates) const;
// Builds a memory structure that consists of just a single memory state
PomdpMemory buildTrivialMemory() const;
// Builds a memory structure that consists of a chain of the given number of states.
// Every state has exactly one transition to the next state. The last state has just a selfloop.
PomdpMemory buildFixedCountingMemory(uint64_t numStates) const;
// Builds a memory structure that consists of a chain of the given number of states.
// Every state has a selfloop and a transition to the next state. The last state just has a selfloop.
PomdpMemory buildSelectiveCountingMemory(uint64_t numStates) const;
// Builds a memory structure that consists of a ring of the given number of states.
// Every state has a transition to the successor state and a selfloop
PomdpMemory buildRingMemory(uint64_t numStates) const;
// Builds a memory structure that represents floor(log(numStates)) bits that can only be set from zero to one or from zero to zero.
PomdpMemory buildSettableBitsMemory(uint64_t numStates) const;
// Builds a memory structure that consists of the given number of states which are fully connected.
PomdpMemory buildFullyConnectedMemory(uint64_t numStates) const;
};
}
}

118
src/storm-pomdp/transformer/PomdpMemoryUnfolder.cpp

@ -1,6 +1,8 @@
#include <storm/exceptions/NotSupportedException.h>
#include "storm-pomdp/transformer/PomdpMemoryUnfolder.h"
#include <limits>
#include "storm/storage/sparse/ModelComponents.h"
#include "storm/utility/graph.h"
#include "storm/exceptions/NotSupportedException.h"
@ -9,41 +11,56 @@ namespace storm {
template<typename ValueType>
PomdpMemoryUnfolder<ValueType>::PomdpMemoryUnfolder(storm::models::sparse::Pomdp<ValueType> const& pomdp, uint64_t numMemoryStates) : pomdp(pomdp), numMemoryStates(numMemoryStates) {
PomdpMemoryUnfolder<ValueType>::PomdpMemoryUnfolder(storm::models::sparse::Pomdp<ValueType> const& pomdp, storm::storage::PomdpMemory const& memory) : pomdp(pomdp), memory(memory) {
// intentionally left empty
}
template<typename ValueType>
std::shared_ptr<storm::models::sparse::Pomdp<ValueType>> PomdpMemoryUnfolder<ValueType>::transform() const {
// For simplicity we first build the 'full' product of pomdp and memory (with pomdp.numStates * memory.numStates states).
storm::storage::sparse::ModelComponents<ValueType> components;
components.transitionMatrix = transformTransitions();
components.stateLabeling = transformStateLabeling();
components.observabilityClasses = transformObservabilityClasses();
// Now delete unreachable states.
storm::storage::BitVector allStates(components.transitionMatrix.getRowGroupCount(), true);
auto reachableStates = storm::utility::graph::getReachableStates(components.transitionMatrix, components.stateLabeling.getStates("init"), allStates, ~allStates);
components.transitionMatrix = components.transitionMatrix.getSubmatrix(true, reachableStates, reachableStates);
components.stateLabeling = components.stateLabeling.getSubLabeling(reachableStates);
// build the remaining components
components.observabilityClasses = transformObservabilityClasses(reachableStates);
for (auto const& rewModel : pomdp.getRewardModels()) {
components.rewardModels.emplace(rewModel.first, transformRewardModel(rewModel.second));
components.rewardModels.emplace(rewModel.first, transformRewardModel(rewModel.second, reachableStates));
}
return std::make_shared<storm::models::sparse::Pomdp<ValueType>>(std::move(components));
}
template<typename ValueType>
storm::storage::SparseMatrix<ValueType> PomdpMemoryUnfolder<ValueType>::transformTransitions() const {
storm::storage::SparseMatrix<ValueType> const& origTransitions = pomdp.getTransitionMatrix();
storm::storage::SparseMatrixBuilder<ValueType> builder(pomdp.getNumberOfChoices() * numMemoryStates * numMemoryStates,
pomdp.getNumberOfStates() * numMemoryStates,
origTransitions.getEntryCount() * numMemoryStates * numMemoryStates,
uint64_t numRows = 0;
uint64_t numEntries = 0;
for (uint64_t modelState = 0; modelState < pomdp.getNumberOfStates(); ++modelState) {
for (uint64_t memState = 0; memState < memory.getNumberOfStates(); ++memState) {
numRows += origTransitions.getRowGroupSize(modelState) * memory.getNumberOfOutgoingTransitions(memState);
numEntries += origTransitions.getRowGroup(modelState).getNumberOfEntries() * memory.getNumberOfOutgoingTransitions(memState);
}
}
storm::storage::SparseMatrixBuilder<ValueType> builder(numRows,
pomdp.getNumberOfStates() * memory.getNumberOfStates(),
numEntries,
true,
true,
pomdp.getNumberOfStates() * numMemoryStates);
pomdp.getNumberOfStates() * memory.getNumberOfStates());
uint64_t row = 0;
for (uint64_t modelState = 0; modelState < pomdp.getNumberOfStates(); ++modelState) {
for (uint32_t memState = 0; memState < numMemoryStates; ++memState) {
for (uint64_t memState = 0; memState < memory.getNumberOfStates(); ++memState) {
builder.newRowGroup(row);
for (uint64_t origRow = origTransitions.getRowGroupIndices()[modelState]; origRow < origTransitions.getRowGroupIndices()[modelState + 1]; ++origRow) {
for (uint32_t memStatePrime = 0; memStatePrime < numMemoryStates; ++memStatePrime) {
for (auto const& memStatePrime : memory.getTransitions(memState)) {
for (auto const& entry : origTransitions.getRow(origRow)) {
builder.addNextValue(row, getUnfoldingState(entry.getColumn(), memStatePrime), entry.getValue());
}
@ -57,18 +74,18 @@ namespace storm {
template<typename ValueType>
storm::models::sparse::StateLabeling PomdpMemoryUnfolder<ValueType>::transformStateLabeling() const {
storm::models::sparse::StateLabeling labeling(pomdp.getNumberOfStates() * numMemoryStates);
storm::models::sparse::StateLabeling labeling(pomdp.getNumberOfStates() * memory.getNumberOfStates());
for (auto const& labelName : pomdp.getStateLabeling().getLabels()) {
storm::storage::BitVector newStates(pomdp.getNumberOfStates() * numMemoryStates, false);
storm::storage::BitVector newStates(pomdp.getNumberOfStates() * memory.getNumberOfStates(), false);
// The init label is only assigned to unfolding states with memState 0
// The init label is only assigned to unfolding states with the initial memory state
if (labelName == "init") {
for (auto const& modelState : pomdp.getStateLabeling().getStates(labelName)) {
newStates.set(getUnfoldingState(modelState, 0));
newStates.set(getUnfoldingState(modelState, memory.getInitialState()));
}
} else {
for (auto const& modelState : pomdp.getStateLabeling().getStates(labelName)) {
for (uint32_t memState = 0; memState < numMemoryStates; ++memState) {
for (uint64_t memState = 0; memState < memory.getNumberOfStates(); ++memState) {
newStates.set(getUnfoldingState(modelState, memState));
}
}
@ -79,38 +96,55 @@ namespace storm {
}
template<typename ValueType>
std::vector<uint32_t> PomdpMemoryUnfolder<ValueType>::transformObservabilityClasses() const {
std::vector<uint32_t> PomdpMemoryUnfolder<ValueType>::transformObservabilityClasses(storm::storage::BitVector const& reachableStates) const {
std::vector<uint32_t> observations;
observations.reserve(pomdp.getNumberOfStates() * numMemoryStates);
observations.reserve(pomdp.getNumberOfStates() * memory.getNumberOfStates());
for (uint64_t modelState = 0; modelState < pomdp.getNumberOfStates(); ++modelState) {
for (uint32_t memState = 0; memState < numMemoryStates; ++memState) {
observations.push_back(getUnfoldingObersvation(pomdp.getObservation(modelState), memState));
for (uint64_t memState = 0; memState < memory.getNumberOfStates(); ++memState) {
if (reachableStates.get(getUnfoldingState(modelState, memState))) {
observations.push_back(getUnfoldingObersvation(pomdp.getObservation(modelState), memState));
}
}
}
// Eliminate observations that are not in use (as they are not reachable).
std::set<uint32_t> occuringObservations(observations.begin(), observations.end());
uint32_t highestObservation = *occuringObservations.rbegin();
std::vector<uint32_t> oldToNewObservationMapping(highestObservation + 1, std::numeric_limits<uint32_t>::max());
uint32_t newObs = 0;
for (auto const& oldObs : occuringObservations) {
oldToNewObservationMapping[oldObs] = newObs;
++newObs;
}
for (auto& obs : observations) {
obs = oldToNewObservationMapping[obs];
}
return observations;
}
template<typename ValueType>
storm::models::sparse::StandardRewardModel<ValueType> PomdpMemoryUnfolder<ValueType>::transformRewardModel(storm::models::sparse::StandardRewardModel<ValueType> const& rewardModel) const {
storm::models::sparse::StandardRewardModel<ValueType> PomdpMemoryUnfolder<ValueType>::transformRewardModel(storm::models::sparse::StandardRewardModel<ValueType> const& rewardModel, storm::storage::BitVector const& reachableStates) const {
boost::optional<std::vector<ValueType>> stateRewards, actionRewards;
if (rewardModel.hasStateRewards()) {
stateRewards = std::vector<ValueType>();
stateRewards->reserve(pomdp.getNumberOfStates() * numMemoryStates);
for (auto const& stateReward : rewardModel.getStateRewardVector()) {
for (uint32_t memState = 0; memState < numMemoryStates; ++memState) {
stateRewards->push_back(stateReward);
stateRewards->reserve(pomdp.getNumberOfStates() * memory.getNumberOfStates());
for (uint64_t modelState = 0; modelState < pomdp.getNumberOfStates(); ++modelState) {
for (uint64_t memState = 0; memState < memory.getNumberOfStates(); ++memState) {
if (reachableStates.get(getUnfoldingState(modelState, memState))) {
stateRewards->push_back(rewardModel.getStateReward(modelState));
}
}
}
}
if (rewardModel.hasStateActionRewards()) {
actionRewards = std::vector<ValueType>();
actionRewards->reserve(pomdp.getNumberOfStates() * numMemoryStates * numMemoryStates);
for (uint64_t modelState = 0; modelState < pomdp.getNumberOfStates(); ++modelState) {
for (uint32_t memState = 0; memState < numMemoryStates; ++memState) {
for (uint64_t origRow = pomdp.getTransitionMatrix().getRowGroupIndices()[modelState]; origRow < pomdp.getTransitionMatrix().getRowGroupIndices()[modelState + 1]; ++origRow) {
ValueType const& actionReward = rewardModel.getStateActionReward(origRow);
for (uint32_t memStatePrime = 0; memStatePrime < numMemoryStates; ++memStatePrime) {
actionRewards->push_back(actionReward);
for (uint64_t memState = 0; memState < memory.getNumberOfStates(); ++memState) {
if (reachableStates.get(getUnfoldingState(modelState, memState))) {
for (uint64_t origRow = pomdp.getTransitionMatrix().getRowGroupIndices()[modelState]; origRow < pomdp.getTransitionMatrix().getRowGroupIndices()[modelState + 1]; ++origRow) {
ValueType const& actionReward = rewardModel.getStateActionReward(origRow);
actionRewards->insert(actionRewards->end(), memory.getNumberOfOutgoingTransitions(memState), actionReward);
}
}
}
@ -121,33 +155,33 @@ namespace storm {
}
template<typename ValueType>
uint64_t PomdpMemoryUnfolder<ValueType>::getUnfoldingState(uint64_t modelState, uint32_t memoryState) const {
return modelState * numMemoryStates + memoryState;
uint64_t PomdpMemoryUnfolder<ValueType>::getUnfoldingState(uint64_t modelState, uint64_t memoryState) const {
return modelState * memory.getNumberOfStates() + memoryState;
}
template<typename ValueType>
uint64_t PomdpMemoryUnfolder<ValueType>::getModelState(uint64_t unfoldingState) const {
return unfoldingState / numMemoryStates;
return unfoldingState / memory.getNumberOfStates();
}
template<typename ValueType>
uint32_t PomdpMemoryUnfolder<ValueType>::getMemoryState(uint64_t unfoldingState) const {
return unfoldingState % numMemoryStates;
uint64_t PomdpMemoryUnfolder<ValueType>::getMemoryState(uint64_t unfoldingState) const {
return unfoldingState % memory.getNumberOfStates();
}
template<typename ValueType>
uint32_t PomdpMemoryUnfolder<ValueType>::getUnfoldingObersvation(uint32_t modelObservation, uint32_t memoryState) const {
return modelObservation * numMemoryStates + memoryState;
uint32_t PomdpMemoryUnfolder<ValueType>::getUnfoldingObersvation(uint32_t modelObservation, uint64_t memoryState) const {
return modelObservation * memory.getNumberOfStates() + memoryState;
}
template<typename ValueType>
uint32_t PomdpMemoryUnfolder<ValueType>::getModelObersvation(uint32_t unfoldingObservation) const {
return unfoldingObservation / numMemoryStates;
return unfoldingObservation / memory.getNumberOfStates();
}
template<typename ValueType>
uint32_t PomdpMemoryUnfolder<ValueType>::getMemoryStateFromObservation(uint32_t unfoldingObservation) const {
return unfoldingObservation % numMemoryStates;
uint64_t PomdpMemoryUnfolder<ValueType>::getMemoryStateFromObservation(uint32_t unfoldingObservation) const {
return unfoldingObservation % memory.getNumberOfStates();
}
template class PomdpMemoryUnfolder<storm::RationalNumber>;

17
src/storm-pomdp/transformer/PomdpMemoryUnfolder.h

@ -1,6 +1,7 @@
#pragma once
#include "storm/models/sparse/Pomdp.h"
#include "storm-pomdp/storage/PomdpMemory.h"
#include "storm/models/sparse/StandardRewardModel.h"
namespace storm {
@ -11,27 +12,27 @@ namespace storm {
public:
PomdpMemoryUnfolder(storm::models::sparse::Pomdp<ValueType> const& pomdp, uint64_t numMemoryStates);
PomdpMemoryUnfolder(storm::models::sparse::Pomdp<ValueType> const& pomdp, storm::storage::PomdpMemory const& memory);
std::shared_ptr<storm::models::sparse::Pomdp<ValueType>> transform() const;
private:
storm::storage::SparseMatrix<ValueType> transformTransitions() const;
storm::models::sparse::StateLabeling transformStateLabeling() const;
std::vector<uint32_t> transformObservabilityClasses() const;
storm::models::sparse::StandardRewardModel<ValueType> transformRewardModel(storm::models::sparse::StandardRewardModel<ValueType> const& rewardModel) const;
std::vector<uint32_t> transformObservabilityClasses(storm::storage::BitVector const& reachableStates) const;
storm::models::sparse::StandardRewardModel<ValueType> transformRewardModel(storm::models::sparse::StandardRewardModel<ValueType> const& rewardModel, storm::storage::BitVector const& reachableStates) const;
uint64_t getUnfoldingState(uint64_t modelState, uint32_t memoryState) const;
uint64_t getUnfoldingState(uint64_t modelState, uint64_t memoryState) const;
uint64_t getModelState(uint64_t unfoldingState) const;
uint32_t getMemoryState(uint64_t unfoldingState) const;
uint64_t getMemoryState(uint64_t unfoldingState) const;
uint32_t getUnfoldingObersvation(uint32_t modelObservation, uint32_t memoryState) const;
uint32_t getUnfoldingObersvation(uint32_t modelObservation, uint64_t memoryState) const;
uint32_t getModelObersvation(uint32_t unfoldingObservation) const;
uint32_t getMemoryStateFromObservation(uint32_t unfoldingObservation) const;
uint64_t getMemoryStateFromObservation(uint32_t unfoldingObservation) const;
storm::models::sparse::Pomdp<ValueType> const& pomdp;
uint32_t numMemoryStates;
storm::storage::PomdpMemory const& memory;
};
}
}
Loading…
Cancel
Save