Browse Source

allow for graph-analysis and sat-based analysis interleaving, and restarting sat-based solver when advantageous

main
Sebastian Junges 5 years ago
parent
commit
c0ac9814e1
  1. 265
      src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp
  2. 29
      src/storm-pomdp/analysis/MemlessStrategySearchQualitative.h
  3. 18
      src/storm-pomdp/analysis/QualitativeAnalysisOnGraphs.cpp
  4. 1
      src/storm-pomdp/analysis/QualitativeAnalysisOnGraphs.h
  5. 36
      src/storm-pomdp/analysis/WinningRegion.cpp
  6. 2
      src/storm-pomdp/analysis/WinningRegion.h

265
src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp

@ -3,6 +3,7 @@
#include "storm-pomdp/analysis/QualitativeStrategySearchNaive.h"
#include "storm-pomdp/analysis/QualitativeAnalysis.h"
#include "storm-pomdp/analysis/QualitativeAnalysisOnGraphs.h"
namespace storm {
namespace pomdp {
@ -53,9 +54,9 @@ namespace storm {
MemlessSearchOptions const& options) :
pomdp(pomdp),
surelyReachSinkStates(surelyReachSinkStates),
targetObservations(storm::pomdp::extractObservations(pomdp, targetStates)),
targetStates(targetStates),
options(options)
options(options),
smtSolverFactory(smtSolverFactory)
{
this->expressionManager = std::make_shared<storm::expressions::ExpressionManager>();
smtSolver = smtSolverFactory->create(*expressionManager);
@ -86,113 +87,142 @@ namespace storm {
} else {
lookaheadConstraintsRequired = qualitative::isLookaheadRequired(pomdp, targetStates, surelyReachSinkStates);
}
if (maxK == std::numeric_limits<uint64_t>::max()) {
// not initialized at all.
// Create some data structures.
for(uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) {
if (actionSelectionVars.empty()) {
for (uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) {
actionSelectionVars.push_back(std::vector<storm::expressions::Variable>());
actionSelectionVarExpressions.push_back(std::vector<storm::expressions::Expression>());
}
// Fill the states-per-observation mapping,
// declare the reachability variables,
// declare the path variables.
for(uint64_t stateId = 0; stateId < pomdp.getNumberOfStates(); ++stateId) {
if(lookaheadConstraintsRequired) {
pathVars.push_back(std::vector<storm::expressions::Expression>());
for (uint64_t i = 0; i < k; ++i) {
pathVars.back().push_back(expressionManager->declareBooleanVariable("P-" + std::to_string(stateId) + "-" + std::to_string(i)).getExpression());
}
}
for (uint64_t stateId = 0; stateId < pomdp.getNumberOfStates(); ++stateId) {
reachVars.push_back(expressionManager->declareBooleanVariable("C-" + std::to_string(stateId)));
reachVarExpressions.push_back(reachVars.back().getExpression());
reachVarExpressionsPerObservation[pomdp.getObservation(stateId)].push_back(reachVarExpressions.back());
continuationVars.push_back(expressionManager->declareBooleanVariable("D-" + std::to_string(stateId)));
reachVarExpressionsPerObservation[pomdp.getObservation(stateId)].push_back(
reachVarExpressions.back());
continuationVars.push_back(
expressionManager->declareBooleanVariable("D-" + std::to_string(stateId)));
continuationVarExpressions.push_back(continuationVars.back().getExpression());
}
assert(!lookaheadConstraintsRequired || pathVars.size() == pomdp.getNumberOfStates());
assert(reachVars.size() == pomdp.getNumberOfStates());
assert(reachVarExpressions.size() == pomdp.getNumberOfStates());
// Create the action selection variables.
uint64_t obs = 0;
for(auto const& statesForObservation : statesPerObservation) {
for (auto const &statesForObservation : statesPerObservation) {
for (uint64_t a = 0; a < pomdp.getNumberOfChoices(statesForObservation.front()); ++a) {
std::string varName = "A-" + std::to_string(obs) + "-" + std::to_string(a);
actionSelectionVars.at(obs).push_back(expressionManager->declareBooleanVariable(varName));
actionSelectionVarExpressions.at(obs).push_back(actionSelectionVars.at(obs).back().getExpression());
actionSelectionVarExpressions.at(obs).push_back(
actionSelectionVars.at(obs).back().getExpression());
}
schedulerVariables.push_back(expressionManager->declareBitVectorVariable("scheduler-obs-" + std::to_string(obs), statesPerObservation.size()));
schedulerVariables.push_back(
expressionManager->declareBitVectorVariable("scheduler-obs-" + std::to_string(obs),
statesPerObservation.size()));
schedulerVariableExpressions.push_back(schedulerVariables.back());
switchVars.push_back(expressionManager->declareBooleanVariable("S-" + std::to_string(obs)));
switchVarExpressions.push_back(switchVars.back().getExpression());
observationUpdatedVariables.push_back(expressionManager->declareBooleanVariable("U-" + std::to_string(obs)));
observationUpdatedVariables.push_back(
expressionManager->declareBooleanVariable("U-" + std::to_string(obs)));
observationUpdatedExpressions.push_back(observationUpdatedVariables.back().getExpression());
if (options.onlyDeterministicStrategies) {
for (uint64_t a = 0; a < pomdp.getNumberOfChoices(statesForObservation.front())-1; ++a) {
for (uint64_t b = a+1; b < pomdp.getNumberOfChoices(statesForObservation.front()); ++b) {
smtSolver->add(!actionSelectionVarExpressions[obs][a] || !actionSelectionVarExpressions[obs][b]);
}
}
}
followVars.push_back(expressionManager->declareBooleanVariable("F-"+std::to_string(obs)));
followVarExpressions.push_back(followVars.back().getExpression());
++obs;
}
// PAPER COMMENT: 1
for (auto const& actionVars : actionSelectionVarExpressions) {
smtSolver->add(storm::expressions::disjunction(actionVars));
for (uint64_t stateId = 0; stateId < pomdp.getNumberOfStates(); ++stateId) {
pathVars.push_back(std::vector<storm::expressions::Expression>());
}
}
// Update at least one observation.
// PAPER COMMENT: 2
smtSolver->add(storm::expressions::disjunction(observationUpdatedExpressions));
// PAPER COMMENT: 3
if (lookaheadConstraintsRequired) {
for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) {
if (targetStates.get(state)) {
smtSolver->add(pathVars[state][0]);
} else {
smtSolver->add(!pathVars[state][0]);
uint64_t initK = 0;
if (maxK != std::numeric_limits<uint64_t>::max()) {
initK = maxK;
}
if (initK < k) {
for (uint64_t stateId = 0; stateId < pomdp.getNumberOfStates(); ++stateId) {
if (lookaheadConstraintsRequired) {
for (uint64_t i = initK; i < k; ++i) {
pathVars[stateId].push_back(expressionManager->declareBooleanVariable(
"P-" + std::to_string(stateId) + "-" + std::to_string(i)).getExpression());
}
}
}
}
// PAPER COMMENT: 4
uint64_t rowindex = 0;
for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) {
assert(!lookaheadConstraintsRequired || pathVars.size() == pomdp.getNumberOfStates());
assert(reachVars.size() == pomdp.getNumberOfStates());
assert(reachVarExpressions.size() == pomdp.getNumberOfStates());
for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) {
std::vector<storm::expressions::Expression> subexprreachSwitch;
std::vector<storm::expressions::Expression> subexprreachNoSwitch;
subexprreachSwitch.push_back(!reachVarExpressions[state]);
subexprreachSwitch.push_back(!actionSelectionVarExpressions[pomdp.getObservation(state)][action]);
subexprreachSwitch.push_back(!switchVarExpressions[pomdp.getObservation(state)]);
subexprreachNoSwitch.push_back(!reachVarExpressions[state]);
subexprreachNoSwitch.push_back(!actionSelectionVarExpressions[pomdp.getObservation(state)][action]);
subexprreachNoSwitch.push_back(switchVarExpressions[pomdp.getObservation(state)]);
for (auto const &entries : pomdp.getTransitionMatrix().getRow(rowindex)) {
subexprreachSwitch.push_back(continuationVarExpressions.at(entries.getColumn()));
smtSolver->add(storm::expressions::disjunction(subexprreachSwitch));
subexprreachSwitch.pop_back();
subexprreachNoSwitch.push_back(reachVarExpressions.at(entries.getColumn()));
smtSolver->add(storm::expressions::disjunction(subexprreachNoSwitch));
subexprreachNoSwitch.pop_back();
}
rowindex++;
uint64_t obs = 0;
if (options.onlyDeterministicStrategies) {
for(auto const& statesForObservation : statesPerObservation) {
for (uint64_t a = 0; a < pomdp.getNumberOfChoices(statesForObservation.front())-1; ++a) {
for (uint64_t b = a+1; b < pomdp.getNumberOfChoices(statesForObservation.front()); ++b) {
smtSolver->add(!actionSelectionVarExpressions[obs][a] || !actionSelectionVarExpressions[obs][b]);
}
}
++obs;
}
}
smtSolver->push();
} else {
smtSolver->pop();
smtSolver->pop();
smtSolver->push();
assert(false);
// PAPER COMMENT: 1
obs = 0;
for (auto const& actionVars : actionSelectionVarExpressions) {
std::vector<storm::expressions::Expression> actExprs = actionVars;
//actExprs.push_back(followVarExpressions[obs]);
smtSolver->add(storm::expressions::disjunction(actExprs));
//for (auto const& av : actionVars) {
// smtSolver->add(!followVarExpressions[obs] || !av);
//}
++obs;
}
// Update at least one observation.
// PAPER COMMENT: 2
smtSolver->add(storm::expressions::disjunction(observationUpdatedExpressions));
// PAPER COMMENT: 3
if (lookaheadConstraintsRequired) {
for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) {
if (targetStates.get(state)) {
smtSolver->add(pathVars[state][0]);
} else {
smtSolver->add(!pathVars[state][0]);
}
}
}
// PAPER COMMENT: 4
uint64_t rowindex = 0;
for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) {
if (targetStates.get(state) || surelyReachSinkStates.get(state)) {
continue;
}
for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) {
std::vector<storm::expressions::Expression> subexprreachSwitch;
std::vector<storm::expressions::Expression> subexprreachNoSwitch;
subexprreachSwitch.push_back(!reachVarExpressions[state]);
subexprreachSwitch.push_back(!actionSelectionVarExpressions[pomdp.getObservation(state)][action]);
subexprreachSwitch.push_back(!switchVarExpressions[pomdp.getObservation(state)]);
subexprreachNoSwitch.push_back(!reachVarExpressions[state]);
subexprreachNoSwitch.push_back(!actionSelectionVarExpressions[pomdp.getObservation(state)][action]);
subexprreachNoSwitch.push_back(switchVarExpressions[pomdp.getObservation(state)]);
for (auto const &entries : pomdp.getTransitionMatrix().getRow(rowindex)) {
subexprreachSwitch.push_back(continuationVarExpressions.at(entries.getColumn()));
smtSolver->add(storm::expressions::disjunction(subexprreachSwitch));
subexprreachSwitch.pop_back();
subexprreachNoSwitch.push_back(reachVarExpressions.at(entries.getColumn()));
smtSolver->add(storm::expressions::disjunction(subexprreachNoSwitch));
subexprreachNoSwitch.pop_back();
}
rowindex++;
}
}
rowindex = 0;
for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) {
// PAPER COMMENT 5
if (surelyReachSinkStates.get(state)) {
@ -244,10 +274,12 @@ namespace storm {
}
// PAPER COMMENT 8
uint64_t obs = 0;
obs = 0;
for(auto const& statesForObservation : statesPerObservation) {
for(auto const& state : statesForObservation) {
smtSolver->add(!continuationVars[state] || schedulerVariableExpressions[obs] > 0);
if (!targetStates.get(state)) {
smtSolver->add(!continuationVars[state] || schedulerVariableExpressions[obs] > 0);
}
}
++obs;
}
@ -284,10 +316,10 @@ namespace storm {
template <typename ValueType>
bool MemlessStrategySearchQualitative<ValueType>::analyze(uint64_t k, storm::storage::BitVector const& oneOfTheseStates, storm::storage::BitVector const& allOfTheseStates) {
stats.initializeSolverTimer.start();
if (k < maxK) {
initialize(k);
maxK = k;
}
// TODO: When do we need to reinitialize? When the solver has been reset.
initialize(k);
maxK = k;
uint64_t maximalNrActions = 8;
@ -374,7 +406,7 @@ namespace storm {
newObservations.clear();
uint64_t obs = 0;
for (auto ov : observationUpdatedVariables) {
for (auto const& ov : observationUpdatedVariables) {
if (!observationUpdated.get(obs) && model->getBooleanValue(ov)) {
STORM_LOG_TRACE("New observation updated: " << obs);
@ -384,17 +416,18 @@ namespace storm {
}
uint64_t i = 0;
for (auto rv : reachVars) {
for (auto const& rv : reachVars) {
if (!coveredStates.get(i) && model->getBooleanValue(rv)) {
STORM_LOG_TRACE("New state: " << i);
smtSolver->add(rv.getExpression());
assert(!surelyReachSinkStates.get(i));
newObservations.set(pomdp.getObservation(i));
coveredStates.set(i);
}
++i;
}
i = 0;
for (auto rv : continuationVars) {
for (auto const& rv : continuationVars) {
if (!coveredStatesAfterSwitch.get(i) && model->getBooleanValue(rv) ) {
smtSolver->add(rv.getExpression());
if (!observationsAfterSwitch.get(pomdp.getObservation(i))) {
@ -478,12 +511,14 @@ namespace storm {
stats.winningRegionUpdatesTimer.start();
storm::storage::BitVector updated(observations.size());
uint64_t newTargetObservations = 0;
for (uint64_t observation = 0; observation < pomdp.getNrObservations(); ++observation) {
STORM_LOG_TRACE("consider observation " << observation);
storm::storage::BitVector update(statesPerObservation[observation].size());
uint64_t i = 0;
for (uint64_t state : statesPerObservation[observation]) {
if (coveredStates.get(state)) {
assert(!surelyReachSinkStates.get(state));
update.set(i);
}
++i;
@ -493,19 +528,77 @@ namespace storm {
bool updateResult = winningRegion.update(observation, update);
STORM_LOG_TRACE("Region changed:" << updateResult);
if (updateResult) {
if (winningRegion.observationIsWinning(observation)) {
++newTargetObservations;
for (uint64_t state : statesPerObservation[observation]) {
targetStates.set(state);
}
}
updated.set(observation);
updateForObservationExpressions[observation] = winningRegion.extensionExpression(observation, reachVarExpressionsPerObservation[observation]);
}
}
}
STORM_LOG_ASSERT(!updated.empty(), "The strategy should be new in at least one place");
stats.winningRegionUpdatesTimer.stop();
if (newTargetObservations>0) {
storm::analysis::QualitativeAnalysisOnGraphs<ValueType> graphanalysis(pomdp);
uint64_t targetStatesBefore = targetStates.getNumberOfSetBits();
STORM_LOG_INFO("Target states before graph based analysis " << targetStates.getNumberOfSetBits());
targetStates = graphanalysis.analyseProb1Max(~surelyReachSinkStates, targetStates);
uint64_t targetStatesAfter = targetStates.getNumberOfSetBits();
STORM_LOG_INFO("Target states after graph based analysis " << targetStates.getNumberOfSetBits());
if (targetStatesAfter - targetStatesBefore > 0) {
stats.winningRegionUpdatesTimer.start();
for(uint64_t observation = 0; observation < pomdp.getNrObservations(); ++observation) {
if (winningRegion.observationIsWinning(observation)) {
continue;
}
bool observationIsWinning = true;
for (uint64_t state : statesPerObservation[observation]) {
if(!targetStates.get(state)) {
observationIsWinning = false;
break;
}
}
if(observationIsWinning) {
stats.incrementGraphBasedWinningObservations();
winningRegion.setObservationIsWinning(observation);
updated.set(observation);
}
}
STORM_LOG_INFO("Graph based winning obs: " << stats.getGraphBasedwinningObservations());
uint64_t nonWinObTargetStates =0;
for (uint64_t state : targetStates) {
if (!winningRegion.observationIsWinning(pomdp.getObservation(state))) {
nonWinObTargetStates++;
}
}
stats.winningRegionUpdatesTimer.stop();
if (nonWinObTargetStates > 0) {
std::cout << "Non winning target states " << nonWinObTargetStates << std::endl;
STORM_LOG_WARN("This case has been barely tested and likely contains bug");
reset();
return analyze(k, ~targetStates & ~surelyReachSinkStates);
}
}
}
// TODO temporarily switched off due to intiialization issues when restartin.
STORM_LOG_ASSERT(!updated.empty(), "The strategy should be new in at least one place");
if(options.computeDebugOutput()) {
winningRegion.print();
}
stats.updateNewStrategySolverTime.start();
for(uint64_t observation : updated) {
updateForObservationExpressions[observation] = winningRegion.extensionExpression(observation, reachVarExpressionsPerObservation[observation]);
}
uint64_t obs = 0;
for (auto const &statesForObservation : statesPerObservation) {
if (observations.get(obs) && updated.get(obs)) {
@ -537,7 +630,7 @@ namespace storm {
}
stats.updateNewStrategySolverTime.stop();
STORM_LOG_INFO("... after iteration " << stats.getIterations() << " so far " << stats.getChecks() << " checks." );
}
winningRegion.print();

29
src/storm-pomdp/analysis/MemlessStrategySearchQualitative.h

@ -117,9 +117,26 @@ namespace pomdp {
void incrementSmtChecks() {
satCalls++;
}
uint64_t getChecks() {
return satCalls;
}
uint64_t getIterations() {
return outerIterations;
}
uint64_t getGraphBasedwinningObservations() {
return graphBasedAnalysisWinOb;
}
void incrementGraphBasedWinningObservations() {
graphBasedAnalysisWinOb++;
}
private:
uint64_t satCalls = 0;
uint64_t outerIterations = 0;
uint64_t graphBasedAnalysisWinOb = 0;
};
MemlessStrategySearchQualitative(storm::models::sparse::Pomdp<ValueType> const& pomdp,
@ -167,6 +184,13 @@ namespace pomdp {
private:
storm::expressions::Expression const& getDoneActionExpression(uint64_t obs) const;
void reset () {
schedulerForObs.clear();
finalSchedulers.clear();
smtSolver->reset();
}
void printScheduler(std::vector<InternalObservationScheduler> const& );
void printCoveredStates(storm::storage::BitVector const& remaining) const;
@ -181,7 +205,6 @@ namespace pomdp {
uint64_t maxK = std::numeric_limits<uint64_t>::max();
storm::storage::BitVector surelyReachSinkStates;
std::set<uint32_t> targetObservations;
storm::storage::BitVector targetStates;
std::vector<std::vector<uint64_t>> statesPerObservation;
@ -199,6 +222,8 @@ namespace pomdp {
std::vector<storm::expressions::Variable> switchVars;
std::vector<storm::expressions::Expression> switchVarExpressions;
std::vector<storm::expressions::Variable> followVars;
std::vector<storm::expressions::Expression> followVarExpressions;
std::vector<storm::expressions::Variable> continuationVars;
std::vector<storm::expressions::Expression> continuationVarExpressions;
std::vector<std::vector<storm::expressions::Expression>> pathVars;
@ -210,6 +235,8 @@ namespace pomdp {
MemlessSearchOptions options;
Statistics stats;
std::shared_ptr<storm::utility::solver::SmtSolverFactory>& smtSolverFactory;
};
}

18
src/storm-pomdp/analysis/QualitativeAnalysisOnGraphs.cpp

@ -84,12 +84,11 @@ namespace storm {
}
template<typename ValueType>
storm::storage::BitVector QualitativeAnalysisOnGraphs<ValueType>::analyseProb1Max(storm::logic::UntilFormula const& formula) const {
// We consider the states that satisfy the formula with prob.1 under arbitrary schedulers as goal states.
storm::storage::BitVector newGoalStates = storm::utility::graph::performProb1A(pomdp.getTransitionMatrix(), pomdp.getTransitionMatrix().getRowGroupIndices(), pomdp.getBackwardTransitions(), checkPropositionalFormula(formula.getLeftSubformula()), checkPropositionalFormula(formula.getRightSubformula()));
storm::storage::BitVector QualitativeAnalysisOnGraphs<ValueType>::analyseProb1Max(storm::storage::BitVector const& okay, storm::storage::BitVector const& good) const {
storm::storage::BitVector newGoalStates = storm::utility::graph::performProb1A(pomdp.getTransitionMatrix(), pomdp.getTransitionMatrix().getRowGroupIndices(), pomdp.getBackwardTransitions(), okay, good);
STORM_LOG_TRACE("Prob1A states according to MDP: " << newGoalStates);
// Now find a set of observations such that there is a memoryless scheduler inducing prob. 1 for each state whose observation is in the set.
storm::storage::BitVector potentialGoalStates = storm::utility::graph::performProb1E(pomdp.getTransitionMatrix(), pomdp.getTransitionMatrix().getRowGroupIndices(), pomdp.getBackwardTransitions(), checkPropositionalFormula(formula.getLeftSubformula()), newGoalStates);
storm::storage::BitVector potentialGoalStates = storm::utility::graph::performProb1E(pomdp.getTransitionMatrix(), pomdp.getTransitionMatrix().getRowGroupIndices(), pomdp.getBackwardTransitions(), okay, newGoalStates);
storm::storage::BitVector notGoalStates = ~potentialGoalStates;
storm::storage::BitVector potentialGoalObservations(pomdp.getNrObservations(), true);
for (uint64_t observation = 0; observation < pomdp.getNrObservations(); ++observation) {
@ -104,9 +103,9 @@ namespace storm {
storm::storage::BitVector goalStates(pomdp.getNumberOfStates());
while (goalStates != newGoalStates) {
goalStates = storm::utility::graph::performProb1A(pomdp.getTransitionMatrix(), pomdp.getTransitionMatrix().getRowGroupIndices(), pomdp.getBackwardTransitions(), checkPropositionalFormula(formula.getLeftSubformula()), newGoalStates);
goalStates = storm::utility::graph::performProb1A(pomdp.getTransitionMatrix(), pomdp.getTransitionMatrix().getRowGroupIndices(), pomdp.getBackwardTransitions(), okay, newGoalStates);
newGoalStates = goalStates;
STORM_LOG_TRACE("Prob1A states according to MDP: " << newGoalStates);
STORM_LOG_INFO("Prob1A states according to MDP: " << newGoalStates);
for (uint64_t observation : potentialGoalObservations) {
uint64_t actsForObservation = pomdp.getTransitionMatrix().getRowGroupSize(statesPerObservation[observation][0]);
// Search whether we find an action that works for this observation.
@ -153,6 +152,13 @@ namespace storm {
}
template<typename ValueType>
storm::storage::BitVector QualitativeAnalysisOnGraphs<ValueType>::analyseProb1Max(storm::logic::UntilFormula const& formula) const {
// We consider the states that satisfy the formula with prob.1 under arbitrary schedulers as goal states.
return this->analyseProb1Max(checkPropositionalFormula(formula.getLeftSubformula()),
checkPropositionalFormula(formula.getRightSubformula()));
}
template<typename ValueType>
storm::storage::BitVector QualitativeAnalysisOnGraphs<ValueType>::analyseProb1Min(storm::logic::UntilFormula const& formula) const {
return storm::utility::graph::performProb1A(pomdp.getTransitionMatrix(), pomdp.getTransitionMatrix().getRowGroupIndices(), pomdp.getBackwardTransitions(), checkPropositionalFormula(formula.getLeftSubformula()), checkPropositionalFormula(formula.getRightSubformula()));

1
src/storm-pomdp/analysis/QualitativeAnalysisOnGraphs.h

@ -11,6 +11,7 @@ namespace storm {
storm::storage::BitVector analyseProb0(storm::logic::ProbabilityOperatorFormula const& formula) const;
storm::storage::BitVector analyseProb1(storm::logic::ProbabilityOperatorFormula const& formula) const;
storm::storage::BitVector analyseProbSmaller1(storm::logic::ProbabilityOperatorFormula const& formula) const;
storm::storage::BitVector analyseProb1Max(storm::storage::BitVector const& okay, storm::storage::BitVector const& target) const;
private:
storm::storage::BitVector analyseProb0or1(storm::logic::ProbabilityOperatorFormula const& formula, bool prob0) const;
storm::storage::BitVector analyseProb0Max(storm::logic::UntilFormula const& formula) const;

36
src/storm-pomdp/analysis/WinningRegion.cpp

@ -12,6 +12,24 @@ namespace pomdp {
}
}
void WinningRegion::setObservationIsWinning(uint64_t observation) {
winningRegion[observation] = { storm::storage::BitVector(observationSizes[observation], true) };
}
// void WinningRegion::addTargetState(uint64_t observation, uint64_t offset) {
// std::vector<storm::storage::BitVector> newWinningSupport = std::vector<storm::storage::BitVector>();
// bool changed = true;
// for (auto const& support : winningRegion[observation]) {
// newWinningSupport.push_back(storm::storage::BitVector(support));
// if(!support.get(offset)) {
// changed = true;
// newWinningSupport.back().set(offset);
// }
// }
//
//
// }
bool WinningRegion::update(uint64_t observation, storm::storage::BitVector const& winning) {
std::vector<storm::storage::BitVector> newWinningSupport = std::vector<storm::storage::BitVector>();
bool changed = false;
@ -87,14 +105,24 @@ namespace pomdp {
void WinningRegion::print() const {
uint64_t observation = 0;
std::vector<uint64_t> winningObservations;
for (auto const& winningSupport : winningRegion) {
std::cout << "***** observation" << observation << std::endl;
for (auto const& support : winningSupport) {
std::cout << " " << support;
if (observationIsWinning(observation)) {
winningObservations.push_back(observation);
} else {
std::cout << "***** observation" << observation << std::endl;
for (auto const& support : winningSupport) {
std::cout << " " << support;
}
std::cout << std::endl;
}
std::cout << std::endl;
observation++;
}
std::cout << " and " << winningObservations.size() << " winning observations: (";
for (auto const& obs : winningObservations) {
std::cout << obs << " ";
}
std::cout << ")" << std::endl;
}
/**

2
src/storm-pomdp/analysis/WinningRegion.h

@ -12,6 +12,8 @@ namespace storm {
bool update(uint64_t observation, storm::storage::BitVector const& winning);
bool query(uint64_t observation, storm::storage::BitVector const& currently) const;
void setObservationIsWinning(uint64_t observation);
bool observationIsWinning(uint64_t observation) const;
storm::expressions::Expression extensionExpression(uint64_t observation, std::vector<storm::expressions::Expression>& varsForStates) const;

Loading…
Cancel
Save