You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
74 lines
3.2 KiB
74 lines
3.2 KiB
#pragma once

#include <cstdint>
#include <limits>
#include <memory>
#include <set>
#include <vector>

#include "storm/storage/expressions/Expressions.h"
#include "storm/solver/SmtSolver.h"
#include "storm/models/sparse/Pomdp.h"
#include "storm/utility/solver.h"
#include "storm/exceptions/UnexpectedException.h"
|
|
|
|
namespace storm {
|
|
namespace pomdp {
|
|
|
|
template<typename ValueType>
|
|
class QualitativeStrategySearchNaive {
|
|
// Implements an extension to the Chatterjee, Chmelik, Davies (AAAI-16) paper.
|
|
|
|
|
|
public:
|
|
QualitativeStrategySearchNaive(storm::models::sparse::Pomdp<ValueType> const& pomdp,
|
|
std::set<uint32_t> const& targetObservationSet,
|
|
storm::storage::BitVector const& targetStates,
|
|
storm::storage::BitVector const& surelyReachSinkStates,
|
|
std::shared_ptr<storm::utility::solver::SmtSolverFactory>& smtSolverFactory) :
|
|
pomdp(pomdp),
|
|
targetObservations(targetObservationSet),
|
|
targetStates(targetStates),
|
|
surelyReachSinkStates(surelyReachSinkStates) {
|
|
this->expressionManager = std::make_shared<storm::expressions::ExpressionManager>();
|
|
smtSolver = smtSolverFactory->create(*expressionManager);
|
|
|
|
}
|
|
|
|
void setSurelyReachSinkStates(storm::storage::BitVector const& surelyReachSink) {
|
|
surelyReachSinkStates = surelyReachSink;
|
|
}
|
|
|
|
void analyzeForInitialStates(uint64_t k) {
|
|
analyze(k, pomdp.getInitialStates(), pomdp.getInitialStates());
|
|
}
|
|
|
|
void findNewStrategyForSomeState(uint64_t k) {
|
|
std::cout << surelyReachSinkStates << std::endl;
|
|
std::cout << targetStates << std::endl;
|
|
std::cout << (~surelyReachSinkStates & ~targetStates) << std::endl;
|
|
analyze(k, ~surelyReachSinkStates & ~targetStates);
|
|
|
|
|
|
}
|
|
|
|
bool analyze(uint64_t k, storm::storage::BitVector const& oneOfTheseStates, storm::storage::BitVector const& allOfTheseStates = storm::storage::BitVector());
|
|
|
|
private:
|
|
void initialize(uint64_t k);
|
|
|
|
|
|
std::unique_ptr<storm::solver::SmtSolver> smtSolver;
|
|
storm::models::sparse::Pomdp<ValueType> const& pomdp;
|
|
std::shared_ptr<storm::expressions::ExpressionManager> expressionManager;
|
|
uint64_t maxK = std::numeric_limits<uint64_t>::max();
|
|
|
|
std::set<uint32_t> targetObservations;
|
|
storm::storage::BitVector targetStates;
|
|
storm::storage::BitVector surelyReachSinkStates;
|
|
|
|
std::vector<std::vector<uint64_t>> statesPerObservation;
|
|
std::vector<std::vector<storm::expressions::Expression>> actionSelectionVarExpressions; // A_{z,a}
|
|
std::vector<std::vector<storm::expressions::Variable>> actionSelectionVars;
|
|
std::vector<storm::expressions::Variable> reachVars;
|
|
std::vector<storm::expressions::Expression> reachVarExpressions;
|
|
std::vector<std::vector<storm::expressions::Expression>> pathVars;
|
|
|
|
|
|
|
|
};
|
|
}
|
|
}
|