Browse Source

a new encoding for almost sure reachability

tempestpy_adaptions
Sebastian Junges 5 years ago
parent
commit
9a5b01b6f7
  1. 244
      src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp
  2. 51
      src/storm-pomdp/analysis/MemlessStrategySearchQualitative.h

244
src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp

@ -26,6 +26,8 @@ namespace storm {
}
reachVars.push_back(expressionManager->declareBooleanVariable("C-" + std::to_string(stateId)));
reachVarExpressions.push_back(reachVars.back().getExpression());
continuationVars.push_back(expressionManager->declareBooleanVariable("D-" + std::to_string(stateId)));
continuationVarExpressions.push_back(continuationVars.back().getExpression());
statesPerObservation.at(obs).push_back(stateId++);
}
assert(pathVars.size() == pomdp.getNumberOfStates());
@ -40,22 +42,68 @@ namespace storm {
actionSelectionVars.at(obs).push_back(expressionManager->declareBooleanVariable(varName));
actionSelectionVarExpressions.at(obs).push_back(actionSelectionVars.at(obs).back().getExpression());
}
schedulerVariables.push_back(expressionManager->declareBitVectorVariable("scheduler-obs-" + std::to_string(obs), statesPerObservation.size()));
schedulerVariableExpressions.push_back(schedulerVariables.back());
switchVars.push_back(expressionManager->declareBooleanVariable("S-" + std::to_string(obs)));
switchVarExpressions.push_back(switchVars.back().getExpression());
++obs;
}
} else {
assert(false);
for (auto const& actionVars : actionSelectionVarExpressions) {
smtSolver->add(storm::expressions::disjunction(actionVars));
}
uint64_t rowindex = 0;
for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) {
if (targetStates.get(state)) {
smtSolver->add(pathVars[state][0]);
} else {
smtSolver->add(!pathVars[state][0]);
}
}
uint64_t rowindex = 0;
for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) {
for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) {
std::vector<storm::expressions::Expression> subexprreachSwitch;
std::vector<storm::expressions::Expression> subexprreachNoSwitch;
subexprreachSwitch.push_back(!reachVarExpressions[state]);
subexprreachSwitch.push_back(!actionSelectionVarExpressions[pomdp.getObservation(state)][action]);
subexprreachSwitch.push_back(!switchVarExpressions[pomdp.getObservation(state)]);
subexprreachNoSwitch.push_back(!reachVarExpressions[state]);
subexprreachNoSwitch.push_back(!actionSelectionVarExpressions[pomdp.getObservation(state)][action]);
subexprreachNoSwitch.push_back(switchVarExpressions[pomdp.getObservation(state)]);
for (auto const &entries : pomdp.getTransitionMatrix().getRow(rowindex)) {
subexprreachSwitch.push_back(continuationVarExpressions.at(entries.getColumn()));
smtSolver->add(storm::expressions::disjunction(subexprreachSwitch));
subexprreachSwitch.pop_back();
subexprreachNoSwitch.push_back(reachVarExpressions.at(entries.getColumn()));
smtSolver->add(storm::expressions::disjunction(subexprreachNoSwitch));
subexprreachNoSwitch.pop_back();
}
rowindex++;
}
}
smtSolver->push();
} else {
smtSolver->pop();
smtSolver->pop();
smtSolver->push();
assert(false);
}
uint64_t rowindex = 0;
for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) {
if (surelyReachSinkStates.get(state)) {
smtSolver->add(!reachVarExpressions[state]);
for (uint64_t j = 1; j < k; ++j) {
smtSolver->add(!pathVars[state][j]);
}
smtSolver->add(!continuationVarExpressions[state]);
} else if(!targetStates.get(state)) {
std::vector<std::vector<std::vector<storm::expressions::Expression>>> pathsubsubexprs;
for (uint64_t j = 1; j < k; ++j) {
@ -68,12 +116,13 @@ namespace storm {
for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) {
std::vector<storm::expressions::Expression> subexprreach;
subexprreach.push_back(!reachVarExpressions.at(state));
subexprreach.push_back(!actionSelectionVarExpressions.at(pomdp.getObservation(state)).at(action));
for (auto const &entries : pomdp.getTransitionMatrix().getRow(rowindex)) {
subexprreach.push_back(reachVarExpressions.at(entries.getColumn()));
}
smtSolver->add(storm::expressions::disjunction(subexprreach));
// subexprreach.push_back(!reachVarExpressions.at(state));
// subexprreach.push_back(!actionSelectionVarExpressions.at(pomdp.getObservation(state)).at(action));
// subexprreach.push_back(!switchVarExpressions[pomdp.getObservation(state)]);
// for (auto const &entries : pomdp.getTransitionMatrix().getRow(rowindex)) {
// subexprreach.push_back(reachVarExpressions.at(entries.getColumn()));
// }
// smtSolver->add(storm::expressions::disjunction(subexprreach));
for (auto const &entries : pomdp.getTransitionMatrix().getRow(rowindex)) {
for (uint64_t j = 1; j < k; ++j) {
pathsubsubexprs[j - 1][action].push_back(pathVars[entries.getColumn()][j - 1]);
@ -89,22 +138,38 @@ namespace storm {
for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) {
pathsubexprs.push_back(actionSelectionVarExpressions.at(pomdp.getObservation(state)).at(action) && storm::expressions::disjunction(pathsubsubexprs[j - 1][action]));
}
pathsubexprs.push_back(switchVarExpressions.at(pomdp.getObservation(state)));
smtSolver->add(storm::expressions::iff(pathVars[state][j], storm::expressions::disjunction(pathsubexprs)));
}
}
}
for (auto const& actionVars : actionSelectionVarExpressions) {
smtSolver->add(storm::expressions::disjunction(actionVars));
uint64_t obs = 0;
for(auto const& statesForObservation : statesPerObservation) {
for(auto const& state : statesForObservation) {
smtSolver->add(!continuationVars[state] || schedulerVariableExpressions[obs] > 0);
}
++obs;
}
// These constraints ensure that the right solver is used.
// obs = 0;
// for(auto const& statesForObservation : statesPerObservation) {
// smtSolver->add(schedulerVariableExpressions[obs] >= schedulerForObs.size());
// ++obs;
// }
// TODO updateFoundSchedulers();
}
template <typename ValueType>
bool MemlessStrategySearchQualitative<ValueType>::analyze(uint64_t k, storm::storage::BitVector const& oneOfTheseStates, storm::storage::BitVector const& allOfTheseStates) {
if (k < maxK) {
initialize(k);
maxK = k;
}
std::vector<storm::expressions::Expression> atLeastOneOfStates;
for (uint64_t state : oneOfTheseStates) {
@ -119,12 +184,35 @@ namespace storm {
smtSolver->add(reachVarExpressions[state]);
}
std::cout << smtSolver->getSmtLibString() << std::endl;
smtSolver->push();
uint64_t obs = 0;
for(auto const& statesForObservation : statesPerObservation) {
smtSolver->add(schedulerVariableExpressions[obs] <= schedulerForObs.size());
++obs;
}
for (uint64_t ob = 0; ob < pomdp.getNrObservations(); ++ob) {
schedulerForObs.push_back(std::vector<uint64_t>());
}
std::vector<std::set<uint64_t>> scheduler;
InternalObservationScheduler scheduler;
scheduler.switchObservations = storm::storage::BitVector(pomdp.getNrObservations());
storm::storage::BitVector observations(pomdp.getNrObservations());
storm::storage::BitVector observationsAfterSwitch(pomdp.getNrObservations());
storm::storage::BitVector remainingstates(pomdp.getNumberOfStates());
uint64_t iterations = 0;
while(true) {
scheduler.clear();
observations.clear();
observationsAfterSwitch.clear();
remainingstates.clear();
while (true) {
++iterations;
std::cout << "Call to SMT Solver (" <<iterations << ")" << std::endl;
std::cout << smtSolver->getSmtLibString() << std::endl;
auto result = smtSolver->check();
uint64_t i = 0;
@ -133,70 +221,154 @@ namespace storm {
STORM_LOG_THROW(false, storm::exceptions::UnexpectedException, "SMT solver yielded an unexpected result");
} else if (result == storm::solver::SmtSolver::CheckResult::Unsat) {
std::cout << std::endl << "Unsatisfiable!" << std::endl;
return false;
break;
}
std::cout << std::endl << "Satisfying assignment: " << std::endl << smtSolver->getModelAsValuation().toString(true) << std::endl;
auto model = smtSolver->getModel();
std::cout << "states that are okay" << std::endl;
storm::storage::BitVector observations(pomdp.getNrObservations());
storm::storage::BitVector remainingstates(pomdp.getNumberOfStates());
observations.clear();
observationsAfterSwitch.clear();
remainingstates.clear();
scheduler.clear();
for (auto rv : reachVars) {
if (model->getBooleanValue(rv)) {
std::cout << i << " " << std::endl;
smtSolver->add(rv.getExpression());
observations.set(pomdp.getObservation(i));
} else {
remainingstates.set(i);
}
//std::cout << i << ": " << model->getBooleanValue(rv) << ", ";
++i;
}
scheduler.clear();
i = 0;
std::cout << "states from which we continue" << std::endl;
for (auto rv : continuationVars) {
if (model->getBooleanValue(rv)) {
smtSolver->add(rv.getExpression());
observationsAfterSwitch.set(pomdp.getObservation(i));
std::cout << " " << i;
}
++i;
}
std::cout << std::endl;
std::cout << "states that are okay" << std::endl;
for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) {
if (!remainingstates.get(state)) {
std::cout << " " << state;
}
}
std::vector<storm::expressions::Expression> schedulerSoFar;
uint64_t obs = 0;
for (auto const &actionSelectionVarsForObs : actionSelectionVars) {
uint64_t act = 0;
scheduler.push_back(std::set<uint64_t>());
for (auto const &asv : actionSelectionVarsForObs) {
scheduler.actions.push_back(std::set<uint64_t>());
if (observations.get(obs)) {
for (uint64_t act = 0; act < actionSelectionVarsForObs.size(); ++act) {
auto const& asv = actionSelectionVarsForObs[act];
if (model->getBooleanValue(asv)) {
scheduler.back().insert(act);
scheduler.actions.back().insert(act);
schedulerSoFar.push_back(actionSelectionVarExpressions[obs][act]);
}
act++;
}
obs++;
if (model->getBooleanValue(switchVars[obs])) {
scheduler.switchObservations.set(obs);
schedulerSoFar.push_back(switchVarExpressions[obs]);
} else {
schedulerSoFar.push_back(!switchVarExpressions[obs]);
}
std::cout << "the scheduler: " << std::endl;
for (uint64_t obs = 0; obs < scheduler.size(); ++obs) {
if (observations.get(obs)) {
std::cout << "observation: " << obs << std::endl;
std::cout << "actions:";
for (auto act : scheduler[obs]) {
std::cout << " " << act;
}
std::cout << std::endl;
if (observationsAfterSwitch.get(obs)) {
scheduler.schedulerRef.push_back(model->getIntegerValue(schedulerVariables[obs]));
schedulerSoFar.push_back(schedulerVariableExpressions[obs] == expressionManager->integer(scheduler.schedulerRef.back()));
} else {
scheduler.schedulerRef.push_back(0);
}
obs++;
}
std::cout << "the scheduler so far: " << std::endl;
scheduler.printForObservations(observations,observationsAfterSwitch);
std::vector<storm::expressions::Expression> remainingExpressions;
for (auto index : remainingstates) {
remainingExpressions.push_back(reachVarExpressions[index]);
}
smtSolver->push();
// Add scheduler
smtSolver->add(storm::expressions::conjunction(schedulerSoFar));
smtSolver->add(storm::expressions::disjunction(remainingExpressions));
}
if (scheduler.empty()) {
break;
}
smtSolver->pop();
std::cout << "states that are okay" << std::endl;
for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) {
if (!remainingstates.get(state)) {
std::cout << " " << state;
}
}
std::cout << std::endl;
std::cout << "the scheduler: " << std::endl;
scheduler.printForObservations(observations,observationsAfterSwitch);
std::vector<storm::expressions::Expression> remainingExpressions;
for (auto index : remainingstates) {
remainingExpressions.push_back(reachVarExpressions[index]);
}
smtSolver->add(storm::expressions::disjunction(remainingExpressions));
uint64_t obs = 0;
for (auto const &statesForObservation : statesPerObservation) {
if (observations.get(obs)) {
std::cout << "We have a new policy ( " << finalSchedulers.size() << " ) for states with observation " << obs << "." << std::endl;
assert(schedulerForObs.size() > obs);
schedulerForObs[obs].push_back(finalSchedulers.size());
std::cout << "We now have " << schedulerForObs[obs].size() << " policies for states with observation " << obs << std::endl;
for (auto const &state : statesForObservation) {
if (remainingstates.get(state)) {
auto constant = expressionManager->integer(schedulerForObs[obs].size());
smtSolver->add(!(continuationVarExpressions[state] && (schedulerVariableExpressions[obs] == constant)));
}
}
}
++obs;
}
finalSchedulers.push_back(scheduler);
smtSolver->push();
for (uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) {
auto constant = expressionManager->integer(schedulerForObs[obs].size());
smtSolver->add(schedulerVariableExpressions[obs] <= constant);
}
}
return true;
}
template<typename ValueType>
void MemlessStrategySearchQualitative<ValueType>::printScheduler(std::vector<InternalObservationScheduler> const& ) {
}
template <typename ValueType>
storm::expressions::Expression const& MemlessStrategySearchQualitative<ValueType>::getDoneActionExpression(uint64_t obs) const {
return actionSelectionVarExpressions[obs].back();
}
template class MemlessStrategySearchQualitative<double>;

51
src/storm-pomdp/analysis/MemlessStrategySearchQualitative.h

@ -8,6 +8,42 @@
namespace storm {
namespace pomdp {
struct InternalObservationScheduler {
std::vector<std::set<uint64_t>> actions;
std::vector<uint64_t> schedulerRef;
storm::storage::BitVector switchObservations;
void clear() {
actions.clear();
schedulerRef.clear();
switchObservations.clear();
}
bool empty() const {
return actions.empty();
}
void printForObservations(storm::storage::BitVector const& observations, storm::storage::BitVector const& observationsAfterSwitch) const {
for (uint64_t obs = 0; obs < observations.size(); ++obs) {
if (observations.get(obs)) {
std::cout << "observation: " << obs << std::endl;
std::cout << "actions:";
for (auto act : actions[obs]) {
std::cout << " " << act;
}
if (switchObservations.get(obs)) {
std::cout << " and switch.";
}
std::cout << std::endl;
}
if (observationsAfterSwitch.get(obs)) {
std::cout << "scheduler ref: " << schedulerRef[obs] << std::endl;
}
}
}
};
template<typename ValueType>
class MemlessStrategySearchQualitative {
// Implements an extension to the Chatterjee, Chmelik, Davies (AAAI-16) paper.
@ -49,6 +85,10 @@ namespace pomdp {
private:
storm::expressions::Expression const& getDoneActionExpression(uint64_t obs) const;
void printScheduler(std::vector<InternalObservationScheduler> const& );
void initialize(uint64_t k);
@ -61,13 +101,24 @@ namespace pomdp {
storm::storage::BitVector targetStates;
storm::storage::BitVector surelyReachSinkStates;
std::vector<storm::expressions::Variable> schedulerVariables;
std::vector<storm::expressions::Expression> schedulerVariableExpressions;
std::vector<std::vector<uint64_t>> statesPerObservation;
std::vector<std::vector<storm::expressions::Expression>> actionSelectionVarExpressions; // A_{z,a}
std::vector<std::vector<storm::expressions::Variable>> actionSelectionVars;
std::vector<storm::expressions::Variable> reachVars;
std::vector<storm::expressions::Expression> reachVarExpressions;
std::vector<storm::expressions::Variable> switchVars;
std::vector<storm::expressions::Expression> switchVarExpressions;
std::vector<storm::expressions::Variable> continuationVars;
std::vector<storm::expressions::Expression> continuationVarExpressions;
std::vector<std::vector<storm::expressions::Expression>> pathVars;
std::vector<InternalObservationScheduler> finalSchedulers;
std::vector<std::vector<uint64_t>> schedulerForObs;
};

Loading…
Cancel
Save