Browse Source

add a validator to the winning region search

tempestpy_adaptions
Sebastian Junges 5 years ago
parent
commit
5783719c05
  1. 1
      src/storm-pomdp-cli/storm-pomdp.cpp
  2. 13
      src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp
  3. 3
      src/storm-pomdp/analysis/MemlessStrategySearchQualitative.h
  4. 16
      src/storm-pomdp/analysis/WinningRegion.cpp
  5. 3
      src/storm-pomdp/analysis/WinningRegion.h
  6. 27
      src/storm-pomdp/analysis/WinningRegionQueryInterface.cpp
  7. 2
      src/storm-pomdp/analysis/WinningRegionQueryInterface.h

1
src/storm-pomdp-cli/storm-pomdp.cpp

@ -108,6 +108,7 @@ namespace storm {
}
else if(storm::utility::getLogLevel() == l3pp::LogLevel::TRACE) {
loglevel = 3;
options.validateEveryStep = true;
}
options.setDebugLevel(loglevel);

13
src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp

@ -75,6 +75,10 @@ namespace storm {
nrStatesPerObservation.push_back(states.size());
}
winningRegion = WinningRegion(nrStatesPerObservation);
if(options.validateEveryStep) {
STORM_LOG_WARN("The validator should only be created when the option is set.");
validator = std::make_shared<WinningRegionQueryInterface<ValueType>>(pomdp, winningRegion);
}
}
@ -586,14 +590,13 @@ namespace storm {
}
// TODO temporarily switched off due to initialization issues when restarting.
STORM_LOG_ASSERT(!updated.empty(), "The strategy should be new in at least one place");
if(options.computeDebugOutput()) {
winningRegion.print();
}
if(options.validateEveryStep) {
STORM_LOG_WARN("Validating every step, for debug purposes only!");
validator->validate();
}
stats.updateNewStrategySolverTime.start();
for(uint64_t observation : updated) {
updateForObservationExpressions[observation] = winningRegion.extensionExpression(observation, reachVarExpressionsPerObservation[observation]);

3
src/storm-pomdp/analysis/MemlessStrategySearchQualitative.h

@ -8,6 +8,7 @@
#include "storm/exceptions/UnexpectedException.h"
#include "storm-pomdp/analysis/WinningRegion.h"
#include "storm-pomdp/analysis/WinningRegionQueryInterface.h"
namespace storm {
namespace pomdp {
@ -45,6 +46,7 @@ namespace pomdp {
bool onlyDeterministicStrategies = false;
bool forceLookahead = true;
bool validateEveryStep = false;
private:
std::string exportSATcalls = "";
@ -236,6 +238,7 @@ namespace pomdp {
Statistics stats;
std::shared_ptr<storm::utility::solver::SmtSolverFactory>& smtSolverFactory;
std::shared_ptr<WinningRegionQueryInterface<ValueType>> validator;
};

16
src/storm-pomdp/analysis/WinningRegion.cpp

@ -130,9 +130,25 @@ namespace pomdp {
* @return
*/
uint64_t WinningRegion::getNumberOfObservations() const {
    // The number of observations is taken from the per-observation size table.
    uint64_t const nrObservations = observationSizes.size();
    // Sanity check (debug builds only): one list of winning sets is kept per observation.
    assert(winningRegion.size() == nrObservations);
    return nrObservations;
}
bool WinningRegion::empty() const {
    // The region counts as empty iff no observation has any winning set stored for it.
    for (uint64_t observation = 0; observation < winningRegion.size(); ++observation) {
        if (winningRegion[observation].empty()) {
            continue;
        }
        // Found an observation with at least one stored winning set.
        return false;
    }
    return true;
}
std::vector<storm::storage::BitVector> const& WinningRegion::getWinningSetsPerObservation(uint64_t observation) const {
assert(observation < getNumberOfObservations());
return winningRegion[observation];
}
uint64_t WinningRegion::getStorageSize() const {
uint64_t result = 0;
for (uint64_t i = 0; i < getNumberOfObservations(); ++i) {

3
src/storm-pomdp/analysis/WinningRegion.h

@ -12,6 +12,8 @@ namespace storm {
bool update(uint64_t observation, storm::storage::BitVector const& winning);
bool query(uint64_t observation, storm::storage::BitVector const& currently) const;
std::vector<storm::storage::BitVector> const& getWinningSetsPerObservation(uint64_t observation) const;
void setObservationIsWinning(uint64_t observation);
bool observationIsWinning(uint64_t observation) const;
@ -20,6 +22,7 @@ namespace storm {
uint64_t getStorageSize() const;
uint64_t getNumberOfObservations() const;
bool empty() const;
void print() const;
private:
std::vector<std::vector<storm::storage::BitVector>> winningRegion;

27
src/storm-pomdp/analysis/WinningRegionQueryInterface.cpp

@ -40,7 +40,7 @@ namespace storm {
template<typename ValueType>
bool WinningRegionQueryInterface<ValueType>::staysInWinningRegion(storm::storage::BitVector const& currentBeliefSupport, uint64_t actionIndex) const {
std::map<uint32_t, storm::storage::BitVector> successors;
STORM_LOG_DEBUG("Stays in winning region? (" << currentBeliefSupport << ", " << actionIndex << ")");
for (uint64_t oldState : currentBeliefSupport) {
uint64_t row = pomdp.getTransitionMatrix().getRowGroupIndices()[oldState] + actionIndex;
for (auto const& successor : pomdp.getTransitionMatrix().getRow(row)) {
@ -55,14 +55,39 @@ namespace storm {
}
for (auto const& entry : successors) {
if(!isInWinningRegion(entry.second)) {
STORM_LOG_DEBUG("Belief support " << entry.second << " is not winning");
return false;
} else {
STORM_LOG_DEBUG("Belief support " << entry.second << " is winning");
}
}
return true;
}
/**
 * Debugging aid: checks the stored winning region for consistency.
 *
 * For every observation and every winning belief support stored for it, the observation-local
 * offsets of the support are translated into state indices (via statesPerObservation) and the
 * actions are searched for one for which staysInWinningRegion() holds. If no such action is
 * found for some support, STORM_LOG_ASSERT reports the offending observation and support.
 */
template<typename ValueType>
void WinningRegionQueryInterface<ValueType>::validate() const {
    for (uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) {
        for (auto const& winningBelief : winningRegion.getWinningSetsPerObservation(obs)) {
            // Translate the observation-local offsets into a bit vector over all states of the POMDP.
            storm::storage::BitVector states(pomdp.getNumberOfStates());
            for (uint64_t offset : winningBelief) {
                states.set(statesPerObservation[obs][offset]);
            }

            // The action count is read off the first state of the observation. It does not change
            // while iterating over the actions, so it is looked up once instead of in the loop condition.
            // NOTE(review): presumably all states sharing an observation offer the same number of actions -- confirm.
            uint64_t const nrActions = pomdp.getTransitionMatrix().getRowGroupSize(statesPerObservation[obs][0]);

            // Stop at the first action that keeps every successor belief support inside the winning region.
            bool safeActionExists = false;
            for (uint64_t actionIndex = 0; actionIndex < nrActions && !safeActionExists; ++actionIndex) {
                safeActionExists = staysInWinningRegion(states, actionIndex);
            }
            STORM_LOG_ASSERT(safeActionExists, "No action stays in the winning region. Observation " << obs << " with associated states: " << statesPerObservation[obs] << " , support " << states);
        }
    }
}
template class WinningRegionQueryInterface<double>;
template class WinningRegionQueryInterface<storm::RationalNumber>;
}
}

2
src/storm-pomdp/analysis/WinningRegionQueryInterface.h

@ -1,3 +1,4 @@
#pragma once
#include "storm/models/sparse/Pomdp.h"
#include "storm-pomdp/analysis/WinningRegion.h"
@ -13,6 +14,7 @@ namespace storm {
bool staysInWinningRegion(storm::storage::BitVector const& beliefSupport, uint64_t actionIndex) const;
void validate() const;
private:
storm::models::sparse::Pomdp<ValueType> const& pomdp;
WinningRegion const& winningRegion;

Loading…
Cancel
Save