#include #include "storm/storage/expressions/Expressions.h" #include "storm/solver/SmtSolver.h" #include "storm/models/sparse/Pomdp.h" #include "storm/utility/solver.h" #include "storm/utility/Stopwatch.h" #include "storm/exceptions/UnexpectedException.h" namespace storm { namespace pomdp { template std::set extractObservations(storm::models::sparse::Pomdp const& pomdp, storm::storage::BitVector const& states) { // TODO move. std::set observations; for(auto state : states) { observations.insert(pomdp.getObservation(state)); } return observations; } template class OneShotPolicySearch { // Implements to the Chatterjee, Chmelik, Davies (AAAI-16) paper. class Statistics { public: Statistics() = default; storm::utility::Stopwatch totalTimer; storm::utility::Stopwatch smtCheckTimer; storm::utility::Stopwatch initializeSolverTimer; }; public: OneShotPolicySearch(storm::models::sparse::Pomdp const& pomdp, storm::storage::BitVector const& targetStates, storm::storage::BitVector const& surelyReachSinkStates, std::shared_ptr& smtSolverFactory) : pomdp(pomdp), targetObservations(extractObservations(pomdp, targetStates)), targetStates(targetStates), surelyReachSinkStates(surelyReachSinkStates) { this->expressionManager = std::make_shared(); smtSolver = smtSolverFactory->create(*expressionManager); } void setSurelyReachSinkStates(storm::storage::BitVector const& surelyReachSink) { surelyReachSinkStates = surelyReachSink; } void analyzeForInitialStates(uint64_t k) { STORM_LOG_TRACE("Bad states: " << surelyReachSinkStates); STORM_LOG_TRACE("Target states: " << targetStates); STORM_LOG_TRACE("Questionmark states: " << (~surelyReachSinkStates & ~targetStates)); bool result = analyze(k, ~surelyReachSinkStates & ~targetStates, pomdp.getInitialStates()); if (result) { STORM_PRINT_AND_LOG("From initial state, one can almost-surely reach the target."); } else { STORM_PRINT_AND_LOG("From initial state, one may not almost-surely reach the target ."); } } bool analyze(uint64_t k, storm::storage::BitVector const& oneOfTheseStates, storm::storage::BitVector const& allOfTheseStates = storm::storage::BitVector()); private: void initialize(uint64_t k); Statistics stats; std::unique_ptr smtSolver; storm::models::sparse::Pomdp const& pomdp; std::shared_ptr expressionManager; uint64_t maxK = std::numeric_limits::max(); std::set targetObservations; storm::storage::BitVector targetStates; storm::storage::BitVector surelyReachSinkStates; std::vector> statesPerObservation; std::vector> actionSelectionVarExpressions; // A_{z,a} std::vector> actionSelectionVars; std::vector reachVars; std::vector reachVarExpressions; std::vector> pathVars; }; } }