From c0ac9814e10eb560e559d6d244b34bed3bbeced7 Mon Sep 17 00:00:00 2001 From: Sebastian Junges Date: Wed, 29 Apr 2020 17:07:56 -0700 Subject: [PATCH] allow for graph-analysis and sat-based analysis interleaving, and restarting sat-based solver when advantageous --- .../MemlessStrategySearchQualitative.cpp | 265 ++++++++++++------ .../MemlessStrategySearchQualitative.h | 29 +- .../analysis/QualitativeAnalysisOnGraphs.cpp | 18 +- .../analysis/QualitativeAnalysisOnGraphs.h | 1 + src/storm-pomdp/analysis/WinningRegion.cpp | 36 ++- src/storm-pomdp/analysis/WinningRegion.h | 2 + 6 files changed, 254 insertions(+), 97 deletions(-) diff --git a/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp b/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp index b8f1e6054..d09410545 100644 --- a/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp +++ b/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp @@ -3,6 +3,7 @@ #include "storm-pomdp/analysis/QualitativeStrategySearchNaive.h" #include "storm-pomdp/analysis/QualitativeAnalysis.h" +#include "storm-pomdp/analysis/QualitativeAnalysisOnGraphs.h" namespace storm { namespace pomdp { @@ -53,9 +54,9 @@ namespace storm { MemlessSearchOptions const& options) : pomdp(pomdp), surelyReachSinkStates(surelyReachSinkStates), - targetObservations(storm::pomdp::extractObservations(pomdp, targetStates)), targetStates(targetStates), - options(options) + options(options), + smtSolverFactory(smtSolverFactory) { this->expressionManager = std::make_shared(); smtSolver = smtSolverFactory->create(*expressionManager); @@ -86,113 +87,142 @@ namespace storm { } else { lookaheadConstraintsRequired = qualitative::isLookaheadRequired(pomdp, targetStates, surelyReachSinkStates); } - if (maxK == std::numeric_limits::max()) { - // not initialized at all. - // Create some data structures. 
- for(uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) { + + if (actionSelectionVars.empty()) { + for (uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) { actionSelectionVars.push_back(std::vector()); actionSelectionVarExpressions.push_back(std::vector()); } - - // Fill the states-per-observation mapping, - // declare the reachability variables, - // declare the path variables. - for(uint64_t stateId = 0; stateId < pomdp.getNumberOfStates(); ++stateId) { - if(lookaheadConstraintsRequired) { - pathVars.push_back(std::vector()); - for (uint64_t i = 0; i < k; ++i) { - pathVars.back().push_back(expressionManager->declareBooleanVariable("P-" + std::to_string(stateId) + "-" + std::to_string(i)).getExpression()); - } - } + for (uint64_t stateId = 0; stateId < pomdp.getNumberOfStates(); ++stateId) { reachVars.push_back(expressionManager->declareBooleanVariable("C-" + std::to_string(stateId))); reachVarExpressions.push_back(reachVars.back().getExpression()); - reachVarExpressionsPerObservation[pomdp.getObservation(stateId)].push_back(reachVarExpressions.back()); - continuationVars.push_back(expressionManager->declareBooleanVariable("D-" + std::to_string(stateId))); + reachVarExpressionsPerObservation[pomdp.getObservation(stateId)].push_back( + reachVarExpressions.back()); + continuationVars.push_back( + expressionManager->declareBooleanVariable("D-" + std::to_string(stateId))); continuationVarExpressions.push_back(continuationVars.back().getExpression()); } - assert(!lookaheadConstraintsRequired || pathVars.size() == pomdp.getNumberOfStates()); - assert(reachVars.size() == pomdp.getNumberOfStates()); - assert(reachVarExpressions.size() == pomdp.getNumberOfStates()); - // Create the action selection variables. 
uint64_t obs = 0; - for(auto const& statesForObservation : statesPerObservation) { + for (auto const &statesForObservation : statesPerObservation) { for (uint64_t a = 0; a < pomdp.getNumberOfChoices(statesForObservation.front()); ++a) { std::string varName = "A-" + std::to_string(obs) + "-" + std::to_string(a); actionSelectionVars.at(obs).push_back(expressionManager->declareBooleanVariable(varName)); - actionSelectionVarExpressions.at(obs).push_back(actionSelectionVars.at(obs).back().getExpression()); + actionSelectionVarExpressions.at(obs).push_back( + actionSelectionVars.at(obs).back().getExpression()); } - schedulerVariables.push_back(expressionManager->declareBitVectorVariable("scheduler-obs-" + std::to_string(obs), statesPerObservation.size())); + schedulerVariables.push_back( + expressionManager->declareBitVectorVariable("scheduler-obs-" + std::to_string(obs), + statesPerObservation.size())); schedulerVariableExpressions.push_back(schedulerVariables.back()); switchVars.push_back(expressionManager->declareBooleanVariable("S-" + std::to_string(obs))); switchVarExpressions.push_back(switchVars.back().getExpression()); - observationUpdatedVariables.push_back(expressionManager->declareBooleanVariable("U-" + std::to_string(obs))); + observationUpdatedVariables.push_back( + expressionManager->declareBooleanVariable("U-" + std::to_string(obs))); observationUpdatedExpressions.push_back(observationUpdatedVariables.back().getExpression()); - if (options.onlyDeterministicStrategies) { - for (uint64_t a = 0; a < pomdp.getNumberOfChoices(statesForObservation.front())-1; ++a) { - for (uint64_t b = a+1; b < pomdp.getNumberOfChoices(statesForObservation.front()); ++b) { - smtSolver->add(!actionSelectionVarExpressions[obs][a] || !actionSelectionVarExpressions[obs][b]); - } - } - } + followVars.push_back(expressionManager->declareBooleanVariable("F-"+std::to_string(obs))); + followVarExpressions.push_back(followVars.back().getExpression()); + ++obs; } - // PAPER COMMENT: 1 - 
for (auto const& actionVars : actionSelectionVarExpressions) { - smtSolver->add(storm::expressions::disjunction(actionVars)); + for (uint64_t stateId = 0; stateId < pomdp.getNumberOfStates(); ++stateId) { + pathVars.push_back(std::vector()); } + } - // Update at least one observation. - // PAPER COMMENT: 2 - smtSolver->add(storm::expressions::disjunction(observationUpdatedExpressions)); - - // PAPER COMMENT: 3 - if (lookaheadConstraintsRequired) { - for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) { - if (targetStates.get(state)) { - smtSolver->add(pathVars[state][0]); - } else { - smtSolver->add(!pathVars[state][0]); + uint64_t initK = 0; + if (maxK != std::numeric_limits::max()) { + initK = maxK; + } + if (initK < k) { + for (uint64_t stateId = 0; stateId < pomdp.getNumberOfStates(); ++stateId) { + if (lookaheadConstraintsRequired) { + for (uint64_t i = initK; i < k; ++i) { + pathVars[stateId].push_back(expressionManager->declareBooleanVariable( + "P-" + std::to_string(stateId) + "-" + std::to_string(i)).getExpression()); } } } + } - // PAPER COMMENT: 4 - uint64_t rowindex = 0; - for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) { + assert(!lookaheadConstraintsRequired || pathVars.size() == pomdp.getNumberOfStates()); + assert(reachVars.size() == pomdp.getNumberOfStates()); + assert(reachVarExpressions.size() == pomdp.getNumberOfStates()); - for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) { - std::vector subexprreachSwitch; - std::vector subexprreachNoSwitch; - subexprreachSwitch.push_back(!reachVarExpressions[state]); - subexprreachSwitch.push_back(!actionSelectionVarExpressions[pomdp.getObservation(state)][action]); - subexprreachSwitch.push_back(!switchVarExpressions[pomdp.getObservation(state)]); - subexprreachNoSwitch.push_back(!reachVarExpressions[state]); - subexprreachNoSwitch.push_back(!actionSelectionVarExpressions[pomdp.getObservation(state)][action]); - 
subexprreachNoSwitch.push_back(switchVarExpressions[pomdp.getObservation(state)]); - for (auto const &entries : pomdp.getTransitionMatrix().getRow(rowindex)) { - subexprreachSwitch.push_back(continuationVarExpressions.at(entries.getColumn())); - smtSolver->add(storm::expressions::disjunction(subexprreachSwitch)); - subexprreachSwitch.pop_back(); - subexprreachNoSwitch.push_back(reachVarExpressions.at(entries.getColumn())); - smtSolver->add(storm::expressions::disjunction(subexprreachNoSwitch)); - subexprreachNoSwitch.pop_back(); - } - rowindex++; + uint64_t obs = 0; + if (options.onlyDeterministicStrategies) { + for(auto const& statesForObservation : statesPerObservation) { + for (uint64_t a = 0; a < pomdp.getNumberOfChoices(statesForObservation.front())-1; ++a) { + for (uint64_t b = a+1; b < pomdp.getNumberOfChoices(statesForObservation.front()); ++b) { + smtSolver->add(!actionSelectionVarExpressions[obs][a] || !actionSelectionVarExpressions[obs][b]); + } } + ++obs; } + } - smtSolver->push(); - } else { - smtSolver->pop(); - smtSolver->pop(); - smtSolver->push(); - assert(false); + // PAPER COMMENT: 1 + obs = 0; + for (auto const& actionVars : actionSelectionVarExpressions) { + std::vector actExprs = actionVars; + //actExprs.push_back(followVarExpressions[obs]); + smtSolver->add(storm::expressions::disjunction(actExprs)); + //for (auto const& av : actionVars) { + // smtSolver->add(!followVarExpressions[obs] || !av); + //} + ++obs; + } + + + + // Update at least one observation. 
+ // PAPER COMMENT: 2 + smtSolver->add(storm::expressions::disjunction(observationUpdatedExpressions)); + + // PAPER COMMENT: 3 + if (lookaheadConstraintsRequired) { + for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) { + if (targetStates.get(state)) { + smtSolver->add(pathVars[state][0]); + } else { + smtSolver->add(!pathVars[state][0]); + } + } } + // PAPER COMMENT: 4 uint64_t rowindex = 0; + for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) { + if (targetStates.get(state) || surelyReachSinkStates.get(state)) { + continue; + } + for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) { + std::vector subexprreachSwitch; + std::vector subexprreachNoSwitch; + subexprreachSwitch.push_back(!reachVarExpressions[state]); + subexprreachSwitch.push_back(!actionSelectionVarExpressions[pomdp.getObservation(state)][action]); + subexprreachSwitch.push_back(!switchVarExpressions[pomdp.getObservation(state)]); + subexprreachNoSwitch.push_back(!reachVarExpressions[state]); + subexprreachNoSwitch.push_back(!actionSelectionVarExpressions[pomdp.getObservation(state)][action]); + subexprreachNoSwitch.push_back(switchVarExpressions[pomdp.getObservation(state)]); + for (auto const &entries : pomdp.getTransitionMatrix().getRow(rowindex)) { + subexprreachSwitch.push_back(continuationVarExpressions.at(entries.getColumn())); + smtSolver->add(storm::expressions::disjunction(subexprreachSwitch)); + subexprreachSwitch.pop_back(); + subexprreachNoSwitch.push_back(reachVarExpressions.at(entries.getColumn())); + smtSolver->add(storm::expressions::disjunction(subexprreachNoSwitch)); + subexprreachNoSwitch.pop_back(); + } + + rowindex++; + } + } + + + + rowindex = 0; for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) { // PAPER COMMENT 5 if (surelyReachSinkStates.get(state)) { @@ -244,10 +274,12 @@ namespace storm { } // PAPER COMMENT 8 - uint64_t obs = 0; + obs = 0; for(auto const& statesForObservation : 
statesPerObservation) { for(auto const& state : statesForObservation) { - smtSolver->add(!continuationVars[state] || schedulerVariableExpressions[obs] > 0); + if (!targetStates.get(state)) { + smtSolver->add(!continuationVars[state] || schedulerVariableExpressions[obs] > 0); + } } ++obs; } @@ -284,10 +316,10 @@ namespace storm { template bool MemlessStrategySearchQualitative::analyze(uint64_t k, storm::storage::BitVector const& oneOfTheseStates, storm::storage::BitVector const& allOfTheseStates) { stats.initializeSolverTimer.start(); - if (k < maxK) { - initialize(k); - maxK = k; - } + // TODO: When do we need to reinitialize? When the solver has been reset. + initialize(k); + maxK = k; + uint64_t maximalNrActions = 8; @@ -374,7 +406,7 @@ namespace storm { newObservations.clear(); uint64_t obs = 0; - for (auto ov : observationUpdatedVariables) { + for (auto const& ov : observationUpdatedVariables) { if (!observationUpdated.get(obs) && model->getBooleanValue(ov)) { STORM_LOG_TRACE("New observation updated: " << obs); @@ -384,17 +416,18 @@ namespace storm { } uint64_t i = 0; - for (auto rv : reachVars) { + for (auto const& rv : reachVars) { if (!coveredStates.get(i) && model->getBooleanValue(rv)) { STORM_LOG_TRACE("New state: " << i); smtSolver->add(rv.getExpression()); + assert(!surelyReachSinkStates.get(i)); newObservations.set(pomdp.getObservation(i)); coveredStates.set(i); } ++i; } i = 0; - for (auto rv : continuationVars) { + for (auto const& rv : continuationVars) { if (!coveredStatesAfterSwitch.get(i) && model->getBooleanValue(rv) ) { smtSolver->add(rv.getExpression()); if (!observationsAfterSwitch.get(pomdp.getObservation(i))) { @@ -478,12 +511,14 @@ namespace storm { stats.winningRegionUpdatesTimer.start(); storm::storage::BitVector updated(observations.size()); + uint64_t newTargetObservations = 0; for (uint64_t observation = 0; observation < pomdp.getNrObservations(); ++observation) { STORM_LOG_TRACE("consider observation " << observation); 
storm::storage::BitVector update(statesPerObservation[observation].size()); uint64_t i = 0; for (uint64_t state : statesPerObservation[observation]) { if (coveredStates.get(state)) { + assert(!surelyReachSinkStates.get(state)); update.set(i); } ++i; @@ -493,19 +528,77 @@ namespace storm { bool updateResult = winningRegion.update(observation, update); STORM_LOG_TRACE("Region changed:" << updateResult); if (updateResult) { + if (winningRegion.observationIsWinning(observation)) { + ++newTargetObservations; + for (uint64_t state : statesPerObservation[observation]) { + targetStates.set(state); + } + } updated.set(observation); updateForObservationExpressions[observation] = winningRegion.extensionExpression(observation, reachVarExpressionsPerObservation[observation]); } } } - STORM_LOG_ASSERT(!updated.empty(), "The strategy should be new in at least one place"); stats.winningRegionUpdatesTimer.stop(); + if (newTargetObservations>0) { + storm::analysis::QualitativeAnalysisOnGraphs graphanalysis(pomdp); + uint64_t targetStatesBefore = targetStates.getNumberOfSetBits(); + STORM_LOG_INFO("Target states before graph based analysis " << targetStates.getNumberOfSetBits()); + targetStates = graphanalysis.analyseProb1Max(~surelyReachSinkStates, targetStates); + uint64_t targetStatesAfter = targetStates.getNumberOfSetBits(); + STORM_LOG_INFO("Target states after graph based analysis " << targetStates.getNumberOfSetBits()); + if (targetStatesAfter - targetStatesBefore > 0) { + stats.winningRegionUpdatesTimer.start(); + + for(uint64_t observation = 0; observation < pomdp.getNrObservations(); ++observation) { + if (winningRegion.observationIsWinning(observation)) { + continue; + } + bool observationIsWinning = true; + for (uint64_t state : statesPerObservation[observation]) { + if(!targetStates.get(state)) { + observationIsWinning = false; + break; + } + } + if(observationIsWinning) { + stats.incrementGraphBasedWinningObservations(); + 
winningRegion.setObservationIsWinning(observation); + updated.set(observation); + } + } + STORM_LOG_INFO("Graph based winning obs: " << stats.getGraphBasedwinningObservations()); + uint64_t nonWinObTargetStates =0; + for (uint64_t state : targetStates) { + if (!winningRegion.observationIsWinning(pomdp.getObservation(state))) { + nonWinObTargetStates++; + } + } + stats.winningRegionUpdatesTimer.stop(); + if (nonWinObTargetStates > 0) { + std::cout << "Non winning target states " << nonWinObTargetStates << std::endl; + STORM_LOG_WARN("This case has been barely tested and likely contains bug"); + reset(); + return analyze(k, ~targetStates & ~surelyReachSinkStates); + } + } + + } + // TODO temporarily switched off due to initialization issues when restarting. + STORM_LOG_ASSERT(!updated.empty(), "The strategy should be new in at least one place"); + + + if(options.computeDebugOutput()) { winningRegion.print(); } stats.updateNewStrategySolverTime.start(); + for(uint64_t observation : updated) { + updateForObservationExpressions[observation] = winningRegion.extensionExpression(observation, reachVarExpressionsPerObservation[observation]); + } + uint64_t obs = 0; for (auto const &statesForObservation : statesPerObservation) { if (observations.get(obs) && updated.get(obs)) { @@ -537,7 +630,7 @@ namespace storm { } stats.updateNewStrategySolverTime.stop(); - + STORM_LOG_INFO("... after iteration " << stats.getIterations() << " so far " << stats.getChecks() << " checks." 
); } winningRegion.print(); diff --git a/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.h b/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.h index 08b0ffc04..acdb9ac6e 100644 --- a/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.h +++ b/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.h @@ -117,9 +117,26 @@ namespace pomdp { void incrementSmtChecks() { satCalls++; } + + uint64_t getChecks() { + return satCalls; + } + + uint64_t getIterations() { + return outerIterations; + } + + uint64_t getGraphBasedwinningObservations() { + return graphBasedAnalysisWinOb; + } + + void incrementGraphBasedWinningObservations() { + graphBasedAnalysisWinOb++; + } private: uint64_t satCalls = 0; uint64_t outerIterations = 0; + uint64_t graphBasedAnalysisWinOb = 0; }; MemlessStrategySearchQualitative(storm::models::sparse::Pomdp const& pomdp, @@ -167,6 +184,13 @@ namespace pomdp { private: storm::expressions::Expression const& getDoneActionExpression(uint64_t obs) const; + void reset () { + schedulerForObs.clear(); + finalSchedulers.clear(); + smtSolver->reset(); + + + } void printScheduler(std::vector const& ); void printCoveredStates(storm::storage::BitVector const& remaining) const; @@ -181,7 +205,6 @@ namespace pomdp { uint64_t maxK = std::numeric_limits::max(); storm::storage::BitVector surelyReachSinkStates; - std::set targetObservations; storm::storage::BitVector targetStates; std::vector> statesPerObservation; @@ -199,6 +222,8 @@ namespace pomdp { std::vector switchVars; std::vector switchVarExpressions; + std::vector followVars; + std::vector followVarExpressions; std::vector continuationVars; std::vector continuationVarExpressions; std::vector> pathVars; @@ -210,6 +235,8 @@ namespace pomdp { MemlessSearchOptions options; Statistics stats; + std::shared_ptr& smtSolverFactory; + }; } diff --git a/src/storm-pomdp/analysis/QualitativeAnalysisOnGraphs.cpp b/src/storm-pomdp/analysis/QualitativeAnalysisOnGraphs.cpp index 5fbc31cbd..913986553 
100644 --- a/src/storm-pomdp/analysis/QualitativeAnalysisOnGraphs.cpp +++ b/src/storm-pomdp/analysis/QualitativeAnalysisOnGraphs.cpp @@ -84,12 +84,11 @@ namespace storm { } template - storm::storage::BitVector QualitativeAnalysisOnGraphs::analyseProb1Max(storm::logic::UntilFormula const& formula) const { - // We consider the states that satisfy the formula with prob.1 under arbitrary schedulers as goal states. - storm::storage::BitVector newGoalStates = storm::utility::graph::performProb1A(pomdp.getTransitionMatrix(), pomdp.getTransitionMatrix().getRowGroupIndices(), pomdp.getBackwardTransitions(), checkPropositionalFormula(formula.getLeftSubformula()), checkPropositionalFormula(formula.getRightSubformula())); + storm::storage::BitVector QualitativeAnalysisOnGraphs::analyseProb1Max(storm::storage::BitVector const& okay, storm::storage::BitVector const& good) const { + storm::storage::BitVector newGoalStates = storm::utility::graph::performProb1A(pomdp.getTransitionMatrix(), pomdp.getTransitionMatrix().getRowGroupIndices(), pomdp.getBackwardTransitions(), okay, good); STORM_LOG_TRACE("Prob1A states according to MDP: " << newGoalStates); // Now find a set of observations such that there is a memoryless scheduler inducing prob. 1 for each state whose observation is in the set. 
- storm::storage::BitVector potentialGoalStates = storm::utility::graph::performProb1E(pomdp.getTransitionMatrix(), pomdp.getTransitionMatrix().getRowGroupIndices(), pomdp.getBackwardTransitions(), checkPropositionalFormula(formula.getLeftSubformula()), newGoalStates); + storm::storage::BitVector potentialGoalStates = storm::utility::graph::performProb1E(pomdp.getTransitionMatrix(), pomdp.getTransitionMatrix().getRowGroupIndices(), pomdp.getBackwardTransitions(), okay, newGoalStates); storm::storage::BitVector notGoalStates = ~potentialGoalStates; storm::storage::BitVector potentialGoalObservations(pomdp.getNrObservations(), true); for (uint64_t observation = 0; observation < pomdp.getNrObservations(); ++observation) { @@ -104,9 +103,9 @@ namespace storm { storm::storage::BitVector goalStates(pomdp.getNumberOfStates()); while (goalStates != newGoalStates) { - goalStates = storm::utility::graph::performProb1A(pomdp.getTransitionMatrix(), pomdp.getTransitionMatrix().getRowGroupIndices(), pomdp.getBackwardTransitions(), checkPropositionalFormula(formula.getLeftSubformula()), newGoalStates); + goalStates = storm::utility::graph::performProb1A(pomdp.getTransitionMatrix(), pomdp.getTransitionMatrix().getRowGroupIndices(), pomdp.getBackwardTransitions(), okay, newGoalStates); newGoalStates = goalStates; - STORM_LOG_TRACE("Prob1A states according to MDP: " << newGoalStates); + STORM_LOG_INFO("Prob1A states according to MDP: " << newGoalStates); for (uint64_t observation : potentialGoalObservations) { uint64_t actsForObservation = pomdp.getTransitionMatrix().getRowGroupSize(statesPerObservation[observation][0]); // Search whether we find an action that works for this observation. @@ -153,6 +152,13 @@ namespace storm { } + template + storm::storage::BitVector QualitativeAnalysisOnGraphs::analyseProb1Max(storm::logic::UntilFormula const& formula) const { + // We consider the states that satisfy the formula with prob.1 under arbitrary schedulers as goal states. 
+ return this->analyseProb1Max(checkPropositionalFormula(formula.getLeftSubformula()), + checkPropositionalFormula(formula.getRightSubformula())); + } + template storm::storage::BitVector QualitativeAnalysisOnGraphs::analyseProb1Min(storm::logic::UntilFormula const& formula) const { return storm::utility::graph::performProb1A(pomdp.getTransitionMatrix(), pomdp.getTransitionMatrix().getRowGroupIndices(), pomdp.getBackwardTransitions(), checkPropositionalFormula(formula.getLeftSubformula()), checkPropositionalFormula(formula.getRightSubformula())); diff --git a/src/storm-pomdp/analysis/QualitativeAnalysisOnGraphs.h b/src/storm-pomdp/analysis/QualitativeAnalysisOnGraphs.h index f9f14dfd4..8a7625827 100644 --- a/src/storm-pomdp/analysis/QualitativeAnalysisOnGraphs.h +++ b/src/storm-pomdp/analysis/QualitativeAnalysisOnGraphs.h @@ -11,6 +11,7 @@ namespace storm { storm::storage::BitVector analyseProb0(storm::logic::ProbabilityOperatorFormula const& formula) const; storm::storage::BitVector analyseProb1(storm::logic::ProbabilityOperatorFormula const& formula) const; storm::storage::BitVector analyseProbSmaller1(storm::logic::ProbabilityOperatorFormula const& formula) const; + storm::storage::BitVector analyseProb1Max(storm::storage::BitVector const& okay, storm::storage::BitVector const& target) const; private: storm::storage::BitVector analyseProb0or1(storm::logic::ProbabilityOperatorFormula const& formula, bool prob0) const; storm::storage::BitVector analyseProb0Max(storm::logic::UntilFormula const& formula) const; diff --git a/src/storm-pomdp/analysis/WinningRegion.cpp b/src/storm-pomdp/analysis/WinningRegion.cpp index 596bb0a5d..1fd3ae950 100644 --- a/src/storm-pomdp/analysis/WinningRegion.cpp +++ b/src/storm-pomdp/analysis/WinningRegion.cpp @@ -12,6 +12,24 @@ namespace pomdp { } } + void WinningRegion::setObservationIsWinning(uint64_t observation) { + winningRegion[observation] = { storm::storage::BitVector(observationSizes[observation], true) }; + } + +// void 
WinningRegion::addTargetState(uint64_t observation, uint64_t offset) { +// std::vector newWinningSupport = std::vector(); +// bool changed = true; +// for (auto const& support : winningRegion[observation]) { +// newWinningSupport.push_back(storm::storage::BitVector(support)); +// if(!support.get(offset)) { +// changed = true; +// newWinningSupport.back().set(offset); +// } +// } +// +// +// } + bool WinningRegion::update(uint64_t observation, storm::storage::BitVector const& winning) { std::vector newWinningSupport = std::vector(); bool changed = false; @@ -87,14 +105,24 @@ namespace pomdp { void WinningRegion::print() const { uint64_t observation = 0; + std::vector winningObservations; for (auto const& winningSupport : winningRegion) { - std::cout << "***** observation" << observation << std::endl; - for (auto const& support : winningSupport) { - std::cout << " " << support; + if (observationIsWinning(observation)) { + winningObservations.push_back(observation); + } else { + std::cout << "***** observation" << observation << std::endl; + for (auto const& support : winningSupport) { + std::cout << " " << support; + } + std::cout << std::endl; } - std::cout << std::endl; observation++; } + std::cout << " and " << winningObservations.size() << " winning observations: ("; + for (auto const& obs : winningObservations) { + std::cout << obs << " "; + } + std::cout << ")" << std::endl; } /** diff --git a/src/storm-pomdp/analysis/WinningRegion.h b/src/storm-pomdp/analysis/WinningRegion.h index e23dabf70..29e3f511f 100644 --- a/src/storm-pomdp/analysis/WinningRegion.h +++ b/src/storm-pomdp/analysis/WinningRegion.h @@ -12,6 +12,8 @@ namespace storm { bool update(uint64_t observation, storm::storage::BitVector const& winning); bool query(uint64_t observation, storm::storage::BitVector const& currently) const; + void setObservationIsWinning(uint64_t observation); + bool observationIsWinning(uint64_t observation) const; storm::expressions::Expression 
extensionExpression(uint64_t observation, std::vector& varsForStates) const;