You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
91 lines
4.0 KiB
91 lines
4.0 KiB
#include <vector>
|
|
#include "storm/storage/expressions/Expressions.h"
|
|
#include "storm/solver/SmtSolver.h"
|
|
#include "storm/models/sparse/Pomdp.h"
|
|
#include "storm/utility/solver.h"
|
|
#include "storm/utility/Stopwatch.h"
|
|
#include "storm/exceptions/UnexpectedException.h"
|
|
|
|
namespace storm {
|
|
namespace pomdp {
|
|
|
|
template<typename ValueType>
|
|
std::set<uint32_t> extractObservations(storm::models::sparse::Pomdp<ValueType> const& pomdp, storm::storage::BitVector const& states) {
|
|
// TODO move.
|
|
std::set<uint32_t> observations;
|
|
for(auto state : states) {
|
|
observations.insert(pomdp.getObservation(state));
|
|
}
|
|
return observations;
|
|
}
|
|
|
|
template<typename ValueType>
|
|
class OneShotPolicySearch {
|
|
// Implements to the Chatterjee, Chmelik, Davies (AAAI-16) paper.
|
|
class Statistics {
|
|
public:
|
|
Statistics() = default;
|
|
|
|
storm::utility::Stopwatch totalTimer;
|
|
storm::utility::Stopwatch smtCheckTimer;
|
|
storm::utility::Stopwatch initializeSolverTimer;
|
|
};
|
|
|
|
public:
|
|
OneShotPolicySearch(storm::models::sparse::Pomdp<ValueType> const& pomdp,
|
|
storm::storage::BitVector const& targetStates,
|
|
storm::storage::BitVector const& surelyReachSinkStates,
|
|
std::shared_ptr<storm::utility::solver::SmtSolverFactory>& smtSolverFactory) :
|
|
pomdp(pomdp),
|
|
targetObservations(extractObservations(pomdp, targetStates)),
|
|
targetStates(targetStates),
|
|
surelyReachSinkStates(surelyReachSinkStates) {
|
|
this->expressionManager = std::make_shared<storm::expressions::ExpressionManager>();
|
|
smtSolver = smtSolverFactory->create(*expressionManager);
|
|
|
|
}
|
|
|
|
void setSurelyReachSinkStates(storm::storage::BitVector const& surelyReachSink) {
|
|
surelyReachSinkStates = surelyReachSink;
|
|
}
|
|
|
|
void analyzeForInitialStates(uint64_t k) {
|
|
STORM_LOG_TRACE("Bad states: " << surelyReachSinkStates);
|
|
STORM_LOG_TRACE("Target states: " << targetStates);
|
|
STORM_LOG_TRACE("Questionmark states: " << (~surelyReachSinkStates & ~targetStates));
|
|
bool result = analyze(k, ~surelyReachSinkStates & ~targetStates, pomdp.getInitialStates());
|
|
if (result) {
|
|
STORM_PRINT_AND_LOG("From initial state, one can almost-surely reach the target.");
|
|
} else {
|
|
STORM_PRINT_AND_LOG("From initial state, one may not almost-surely reach the target .");
|
|
}
|
|
}
|
|
|
|
bool analyze(uint64_t k, storm::storage::BitVector const& oneOfTheseStates, storm::storage::BitVector const& allOfTheseStates = storm::storage::BitVector());
|
|
|
|
private:
|
|
void initialize(uint64_t k);
|
|
|
|
|
|
Statistics stats;
|
|
std::unique_ptr<storm::solver::SmtSolver> smtSolver;
|
|
storm::models::sparse::Pomdp<ValueType> const& pomdp;
|
|
std::shared_ptr<storm::expressions::ExpressionManager> expressionManager;
|
|
uint64_t maxK = std::numeric_limits<uint64_t>::max();
|
|
|
|
std::set<uint32_t> targetObservations;
|
|
storm::storage::BitVector targetStates;
|
|
storm::storage::BitVector surelyReachSinkStates;
|
|
|
|
std::vector<std::vector<uint64_t>> statesPerObservation;
|
|
std::vector<std::vector<storm::expressions::Expression>> actionSelectionVarExpressions; // A_{z,a}
|
|
std::vector<std::vector<storm::expressions::Variable>> actionSelectionVars;
|
|
std::vector<storm::expressions::Variable> reachVars;
|
|
std::vector<storm::expressions::Expression> reachVarExpressions;
|
|
std::vector<std::vector<storm::expressions::Expression>> pathVars;
|
|
|
|
|
|
|
|
};
|
|
}
|
|
}
|