Browse Source

we now compute the winning region

tempestpy_adaptions
Sebastian Junges 5 years ago
parent
commit
3f4bb4cf8d
  1. 65
      src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp
  2. 20
      src/storm-pomdp/analysis/MemlessStrategySearchQualitative.h
  3. 86
      src/storm-pomdp/analysis/WinningRegion.cpp
  4. 25
      src/storm-pomdp/analysis/WinningRegion.h

65
src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp

@ -4,6 +4,35 @@
namespace storm {
namespace pomdp {
template <typename ValueType>
MemlessStrategySearchQualitative<ValueType>::MemlessStrategySearchQualitative(storm::models::sparse::Pomdp<ValueType> const& pomdp,
std::set<uint32_t> const& targetObservationSet,
storm::storage::BitVector const& targetStates,
storm::storage::BitVector const& surelyReachSinkStates,
std::shared_ptr<storm::utility::solver::SmtSolverFactory>& smtSolverFactory) :
pomdp(pomdp),
targetStates(targetStates),
surelyReachSinkStates(surelyReachSinkStates),
targetObservations(targetObservationSet)
{
this->expressionManager = std::make_shared<storm::expressions::ExpressionManager>();
smtSolver = smtSolverFactory->create(*expressionManager);
// Initialize states per observation.
for (uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) {
statesPerObservation.push_back(std::vector<uint64_t>()); // Consider using bitvectors instead.
}
uint64_t state = 0;
for (auto obs : pomdp.getObservations()) {
statesPerObservation.at(obs).push_back(state++);
}
// Initialize winning region
std::vector<uint64_t> nrStatesPerObservation;
for (auto const &states : statesPerObservation) {
nrStatesPerObservation.push_back(states.size());
}
winningRegion = WinningRegion(nrStatesPerObservation);
}
template <typename ValueType>
void MemlessStrategySearchQualitative<ValueType>::initialize(uint64_t k) {
if (maxK == std::numeric_limits<uint64_t>::max()) {
@ -12,14 +41,13 @@ namespace storm {
for(uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) {
actionSelectionVars.push_back(std::vector<storm::expressions::Variable>());
actionSelectionVarExpressions.push_back(std::vector<storm::expressions::Expression>());
statesPerObservation.push_back(std::vector<uint64_t>()); // Consider using bitvectors instead.
}
// Fill the states-per-observation mapping,
// declare the reachability variables,
// declare the path variables.
uint64_t stateId = 0;
for(auto obs : pomdp.getObservations()) {
for(uint64_t stateId = 0; stateId < pomdp.getNumberOfStates(); ++stateId) {
pathVars.push_back(std::vector<storm::expressions::Expression>());
for (uint64_t i = 0; i < k; ++i) {
pathVars.back().push_back(expressionManager->declareBooleanVariable("P-"+std::to_string(stateId)+"-"+std::to_string(i)).getExpression());
@ -28,7 +56,6 @@ namespace storm {
reachVarExpressions.push_back(reachVars.back().getExpression());
continuationVars.push_back(expressionManager->declareBooleanVariable("D-" + std::to_string(stateId)));
continuationVarExpressions.push_back(continuationVars.back().getExpression());
statesPerObservation.at(obs).push_back(stateId++);
}
assert(pathVars.size() == pomdp.getNumberOfStates());
assert(reachVars.size() == pomdp.getNumberOfStates());
@ -115,14 +142,6 @@ namespace storm {
for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) {
std::vector<storm::expressions::Expression> subexprreach;
// subexprreach.push_back(!reachVarExpressions.at(state));
// subexprreach.push_back(!actionSelectionVarExpressions.at(pomdp.getObservation(state)).at(action));
// subexprreach.push_back(!switchVarExpressions[pomdp.getObservation(state)]);
// for (auto const &entries : pomdp.getTransitionMatrix().getRow(rowindex)) {
// subexprreach.push_back(reachVarExpressions.at(entries.getColumn()));
// }
// smtSolver->add(storm::expressions::disjunction(subexprreach));
for (auto const &entries : pomdp.getTransitionMatrix().getRow(rowindex)) {
for (uint64_t j = 1; j < k; ++j) {
pathsubsubexprs[j - 1][action].push_back(pathVars[entries.getColumn()][j - 1]);
@ -152,14 +171,7 @@ namespace storm {
++obs;
}
// These constraints ensure that the right solver is used.
// obs = 0;
// for(auto const& statesForObservation : statesPerObservation) {
// smtSolver->add(schedulerVariableExpressions[obs] >= schedulerForObs.size());
// ++obs;
// }
// TODO updateFoundSchedulers();
// TODO: Update found schedulers if k is increased.
}
template <typename ValueType>
@ -204,7 +216,6 @@ namespace storm {
uint64_t iterations = 0;
while(true) {
scheduler.clear();
observations.clear();
observationsAfterSwitch.clear();
remainingstates.clear();
@ -262,10 +273,10 @@ namespace storm {
}
}
// TODO do not repush everyting to the solver.
std::vector<storm::expressions::Expression> schedulerSoFar;
uint64_t obs = 0;
for (auto const &actionSelectionVarsForObs : actionSelectionVars) {
uint64_t act = 0;
scheduler.actions.push_back(std::set<uint64_t>());
if (observations.get(obs)) {
for (uint64_t act = 0; act < actionSelectionVarsForObs.size(); ++act) {
@ -326,6 +337,18 @@ namespace storm {
remainingExpressions.push_back(reachVarExpressions[index]);
}
for (uint64_t observation = 0; observation < pomdp.getNrObservations(); ++observation) {
storm::storage::BitVector update = storm::storage::BitVector(statesPerObservation[observation].size());
uint64_t i = 0;
for (uint64_t state : statesPerObservation[observation]) {
if (!remainingstates.get(state)) {
update.set(i);
}
}
winningRegion.update(observation, update);
++i;
}
smtSolver->add(storm::expressions::disjunction(remainingExpressions));
uint64_t obs = 0;

20
src/storm-pomdp/analysis/MemlessStrategySearchQualitative.h

@ -5,6 +5,8 @@
#include "storm/utility/solver.h"
#include "storm/exceptions/UnexpectedException.h"
#include "storm-pomdp/analysis/WinningRegion.h"
namespace storm {
namespace pomdp {
@ -67,25 +69,12 @@ namespace pomdp {
class MemlessStrategySearchQualitative {
// Implements an extension to the Chatterjee, Chmelik, Davies (AAAI-16) paper.
public:
MemlessStrategySearchQualitative(storm::models::sparse::Pomdp<ValueType> const& pomdp,
std::set<uint32_t> const& targetObservationSet,
storm::storage::BitVector const& targetStates,
storm::storage::BitVector const& surelyReachSinkStates,
std::shared_ptr<storm::utility::solver::SmtSolverFactory>& smtSolverFactory) :
pomdp(pomdp),
targetStates(targetStates),
surelyReachSinkStates(surelyReachSinkStates),
targetObservations(targetObservationSet) {
this->expressionManager = std::make_shared<storm::expressions::ExpressionManager>();
smtSolver = smtSolverFactory->create(*expressionManager);
}
void setSurelyReachSinkStates(storm::storage::BitVector const& surelyReachSink) {
surelyReachSinkStates = surelyReachSink;
}
std::shared_ptr<storm::utility::solver::SmtSolverFactory>& smtSolverFactory);
void analyzeForInitialStates(uint64_t k) {
analyze(k, pomdp.getInitialStates(), pomdp.getInitialStates());
@ -96,8 +85,6 @@ namespace pomdp {
std::cout << targetStates << std::endl;
std::cout << (~surelyReachSinkStates & ~targetStates) << std::endl;
analyze(k, ~surelyReachSinkStates & ~targetStates);
}
bool analyze(uint64_t k, storm::storage::BitVector const& oneOfTheseStates, storm::storage::BitVector const& allOfTheseStates = storm::storage::BitVector());
@ -137,6 +124,7 @@ namespace pomdp {
std::vector<InternalObservationScheduler> finalSchedulers;
std::vector<std::vector<uint64_t>> schedulerForObs;
WinningRegion winningRegion;

86
src/storm-pomdp/analysis/WinningRegion.cpp

@ -0,0 +1,86 @@
#include <iostream>
#include "storm-pomdp/analysis/WinningRegion.h"
namespace storm {
namespace pomdp {
WinningRegion::WinningRegion(std::vector<uint64_t> const& observationSizes) : observationSizes(observationSizes)
{
for (uint64_t i = 0; i < observationSizes.size(); ++i) {
winningRegion.push_back(std::vector<storm::storage::BitVector>());
}
}
void WinningRegion::update(uint64_t observation, storm::storage::BitVector const& winning) {
std::vector<storm::storage::BitVector> newWinningSupport = std::vector<storm::storage::BitVector>();
bool changed = false;
for (auto const& support : winningRegion[observation]) {
if (winning.isSubsetOf(support)) {
// This new winning support is already covered.
return;
}
if(support.isSubsetOf(winning)) {
// This new winning support extends the previous support, thus the previous support is now spurious
changed = true;
} else {
newWinningSupport.push_back(support);
}
}
// only if changed.
if (changed) {
newWinningSupport.push_back(winning);
winningRegion[observation] = newWinningSupport;
} else {
winningRegion[observation].push_back(winning);
}
}
bool WinningRegion::query(uint64_t observation, storm::storage::BitVector const& currently) const {
for(storm::storage::BitVector winning : winningRegion[observation]) {
if(currently.isSubsetOf(winning)) {
return true;
}
}
return false;
}
/**
* If we observe this observation, do we surely win?
* @param observation
* @return yes, if all supports for this observation are winning.
*/
bool WinningRegion::observationIsWinning(uint64_t observation) const {
return winningRegion[observation].size() == 1 && winningRegion[observation].front().full();
}
void WinningRegion::print() const {
uint64_t observation = 0;
for (auto const& winningSupport : winningRegion) {
std::cout << "***** observation" << observation << std::endl;
for (auto const& support : winningSupport) {
std::cout << " " << support;
}
std::cout << std::endl;
}
}
/**
* How many different observations are there?
* @return
*/
uint64_t WinningRegion::getNumberOfObservations() const {
return observationSizes.size();
}
uint64_t WinningRegion::getStorageSize() const {
uint64_t result = 0;
for (uint64_t i = 0; i < getNumberOfObservations(); ++i) {
result += winningRegion[i].size() * observationSizes[i];
}
return result;
}
}
}

25
src/storm-pomdp/analysis/WinningRegion.h

@ -0,0 +1,25 @@
#pragma once
#include <vector>
#include "storm/storage/BitVector.h"
namespace storm {
namespace pomdp {
class WinningRegion {
public:
WinningRegion(std::vector<uint64_t> const& observationSizes = {});
void update(uint64_t observation, storm::storage::BitVector const& winning);
bool query(uint64_t observation, storm::storage::BitVector const& currently) const;
bool observationIsWinning(uint64_t observation) const;
uint64_t getStorageSize() const;
uint64_t getNumberOfObservations() const;
void print() const;
private:
std::vector<std::vector<storm::storage::BitVector>> winningRegion;
std::vector<uint64_t> observationSizes;
};
}
}
Loading…
Cancel
Save