Browse Source

use target state to initialise winning region, better timers and slight improvements in partial scheduler extension

tempestpy_adaptions
Sebastian Junges 5 years ago
parent
commit
556a884e74
  1. 226
      src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp
  2. 11
      src/storm-pomdp/analysis/MemlessStrategySearchQualitative.h
  3. 31
      src/storm-pomdp/analysis/WinningRegion.cpp
  4. 1
      src/storm-pomdp/analysis/WinningRegion.h

226
src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp

@ -41,9 +41,11 @@ namespace storm {
STORM_PRINT_AND_LOG("SAT Calls time: " << smtCheckTimer << std::endl);
STORM_PRINT_AND_LOG("Outer iterations: " << outerIterations << std::endl);
STORM_PRINT_AND_LOG("Solver initialization time: " << initializeSolverTimer << std::endl);
STORM_PRINT_AND_LOG("Extend partial scheduler time: " << updateExtensionSolverTime << std::endl);
STORM_PRINT_AND_LOG("Obtain partial scheduler time: " << evaluateExtensionSolverTime << std::endl);
STORM_PRINT_AND_LOG("Update solver to extend partial scheduler time: " << encodeExtensionSolverTime << std::endl);
STORM_PRINT_AND_LOG("Update solver with new scheduler time: " << updateNewStrategySolverTime << std::endl);
STORM_PRINT_AND_LOG("Winning regions update time: " << winningRegionUpdatesTimer << std::endl);
STORM_PRINT_AND_LOG("Graph search time: " << graphSearchTime << std::endl);
}
template <typename ValueType>
@ -155,15 +157,21 @@ namespace storm {
uint64_t obs = 0;
if (options.onlyDeterministicStrategies) {
for(auto const& statesForObservation : statesPerObservation) {
for (uint64_t a = 0; a < pomdp.getNumberOfChoices(statesForObservation.front())-1; ++a) {
for (uint64_t b = a+1; b < pomdp.getNumberOfChoices(statesForObservation.front()); ++b) {
smtSolver->add(!actionSelectionVarExpressions[obs][a] || !actionSelectionVarExpressions[obs][b]);
for(auto const& statesForObservation : statesPerObservation) {
if ( pomdp.getNumberOfChoices(statesForObservation.front()) == 1) {
++obs;
continue;
}
if (options.onlyDeterministicStrategies || statesForObservation.size() == 1) {
for (uint64_t a = 0; a < pomdp.getNumberOfChoices(statesForObservation.front()) - 1; ++a) {
for (uint64_t b = a + 1; b < pomdp.getNumberOfChoices(statesForObservation.front()); ++b) {
smtSolver->add(!(actionSelectionVarExpressions[obs][a]) ||
!(actionSelectionVarExpressions[obs][b]));
}
}
++obs;
}
++obs;
}
// PAPER COMMENT: 1
@ -302,38 +310,107 @@ namespace storm {
smtSolver->add(storm::expressions::implies(switchVarExpressions[obs], storm::expressions::disjunction(reachVarExpressionsPerObservation[obs])));
}
// PAPER COMMENT 10
if (!lookaheadConstraintsRequired) {
uint64_t rowIndex = 0;
for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) {
uint64_t enabledActions = pomdp.getNumberOfChoices(state);
if (!surelyReachSinkStates.get(state)) {
std::vector<storm::expressions::Expression> successorVars;
for (uint64_t act = 0; act < enabledActions; ++act) {
for (auto const &entries : pomdp.getTransitionMatrix().getRow(rowIndex)) {
successorVars.push_back(reachVarExpressions[entries.getColumn()]);
}
rowIndex++;
}
successorVars.push_back(!switchVars[pomdp.getObservation(state)]);
smtSolver->add(storm::expressions::implies(storm::expressions::conjunction(successorVars), reachVarExpressions[state]));
} else {
rowIndex += enabledActions;
}
// if (!lookaheadConstraintsRequired) {
// uint64_t rowIndex = 0;
// for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) {
// uint64_t enabledActions = pomdp.getNumberOfChoices(state);
// if (!surelyReachSinkStates.get(state)) {
// std::vector<storm::expressions::Expression> successorVars;
// for (uint64_t act = 0; act < enabledActions; ++act) {
// for (auto const &entries : pomdp.getTransitionMatrix().getRow(rowIndex)) {
// successorVars.push_back(reachVarExpressions[entries.getColumn()]);
// }
// rowIndex++;
// }
// successorVars.push_back(!switchVars[pomdp.getObservation(state)]);
// smtSolver->add(storm::expressions::implies(storm::expressions::conjunction(successorVars), reachVarExpressions[state]));
// } else {
// rowIndex += enabledActions;
// }
// }
// } else {
// STORM_LOG_WARN("Some optimization not implemented yet.");
// }
// TODO: Update found schedulers if k is increased.
}
template<typename ValueType>
uint64_t MemlessStrategySearchQualitative<ValueType>::getOffsetFromObservation(uint64_t state, uint64_t observation) const {
if(!useFindOffset) {
STORM_LOG_WARN("This code is slow and should only be used for debugging.");
useFindOffset = true;
}
uint64_t offset = 0;
for(uint64_t s : statesPerObservation[observation]) {
if (s == state) {
return offset;
}
} else {
STORM_LOG_WARN("Some optimization not implemented yet.");
++offset;
}
// TODO: Update found schedulers if k is increased.
assert(false); // State should have occured.
return 0;
}
template <typename ValueType>
bool MemlessStrategySearchQualitative<ValueType>::analyze(uint64_t k, storm::storage::BitVector const& oneOfTheseStates, storm::storage::BitVector const& allOfTheseStates) {
std::cout << "Surely reach sink states: " << surelyReachSinkStates << std::endl;
std::cout << "Target states " << targetStates << std::endl;
std::cout << (~surelyReachSinkStates & ~targetStates) << std::endl;
stats.initializeSolverTimer.start();
// TODO: When do we need to reinitialize? When the solver has been reset.
initialize(k);
maxK = k;
stats.winningRegionUpdatesTimer.start();
storm::storage::BitVector updated(pomdp.getNrObservations());
// TODO CODE DUPLICATION WITH UPDATE, PUT IN PROCEDURE
storm::storage::BitVector potentialWinner(pomdp.getNrObservations());
storm::storage::BitVector observationsWithPartialWinners(pomdp.getNrObservations());
for(uint64_t observation = 0; observation < pomdp.getNrObservations(); ++observation) {
if (winningRegion.observationIsWinning(observation)) {
continue;
}
bool observationIsWinning = true;
for (uint64_t state : statesPerObservation[observation]) {
if(!targetStates.get(state)) {
observationIsWinning = false;
observationsWithPartialWinners.set(observation);
} else {
potentialWinner.set(observation);
}
}
if(observationIsWinning) {
STORM_LOG_TRACE("Observation " << observation << " is winning.");
stats.incrementGraphBasedWinningObservations();
winningRegion.setObservationIsWinning(observation);
updated.set(observation);
}
}
STORM_LOG_INFO("Graph based winning obs: " << stats.getGraphBasedwinningObservations());
observationsWithPartialWinners &= potentialWinner;
for (auto const& observation : observationsWithPartialWinners) {
uint64_t nrStatesForObs = statesPerObservation[observation].size();
storm::storage::BitVector update(nrStatesForObs);
for (uint64_t i = 0; i < nrStatesForObs; ++i ) {
uint64_t state = statesPerObservation[observation][i];
if(targetStates.get(state)) {
update.set(i);
}
}
assert(!update.empty());
STORM_LOG_TRACE("Extend winning region for observation " << observation << " with target states/offsets" << update);
winningRegion.addTargetStates(observation, update);
assert(winningRegion.query(observation,update));// "Cannot continue: No scheduler known for state " << i << " (observation " << obs << ").");
updated.set(observation);
}
for (auto const& state : targetStates) {
STORM_LOG_ASSERT(winningRegion.isWinning(pomdp.getObservation(state),getOffsetFromObservation(state,pomdp.getObservation(state))), "Target state " << state << " , observation " << pomdp.getObservation(state) << " is not reflected as winning.");
}
stats.winningRegionUpdatesTimer.stop();
uint64_t maximalNrActions = 8;
STORM_LOG_WARN("We have hardcoded (an upper bound on) the number of actions");
@ -415,6 +492,7 @@ namespace storm {
storm::storage::BitVector observations(pomdp.getNrObservations());
storm::storage::BitVector observationsAfterSwitch(pomdp.getNrObservations());
storm::storage::BitVector observationUpdated(pomdp.getNrObservations());
storm::storage::BitVector uncoveredStates(pomdp.getNumberOfStates());
storm::storage::BitVector coveredStates(pomdp.getNumberOfStates());
storm::storage::BitVector coveredStatesAfterSwitch(pomdp.getNumberOfStates());
@ -449,15 +527,14 @@ namespace storm {
break;
}
newSchedulerDiscovered = true;
stats.updateExtensionSolverTime.start();
auto model = smtSolver->getModel();
stats.evaluateExtensionSolverTime.start();
auto const& model = smtSolver->getModel();
newObservationsAfterSwitch.clear();
newObservations.clear();
uint64_t obs = 0;
for (auto const& ov : observationUpdatedVariables) {
if (!observationUpdated.get(obs) && model->getBooleanValue(ov)) {
STORM_LOG_TRACE("New observation updated: " << obs);
observationUpdated.set(obs);
@ -465,32 +542,43 @@ namespace storm {
obs++;
}
uint64_t i = 0;
for (auto const& rv : reachVars) {
if (!coveredStates.get(i) && model->getBooleanValue(rv)) {
// for(uint64_t i : targetStates) {
// assert(model->getBooleanValue(reachVars[i]));
// }
uncoveredStates = ~coveredStates;
for (uint64_t i : uncoveredStates) {
auto const& rv =reachVars[i];
auto const& rvExpr =reachVarExpressions[i];
if (model->getBooleanValue(rv)) {
STORM_LOG_TRACE("New state: " << i);
smtSolver->add(rv.getExpression());
smtSolver->add(rvExpr);
assert(!surelyReachSinkStates.get(i));
newObservations.set(pomdp.getObservation(i));
coveredStates.set(i);
}
++i;
}
i = 0;
for (auto const& rv : continuationVars) {
if (!coveredStatesAfterSwitch.get(i) && model->getBooleanValue(rv) ) {
smtSolver->add(rv.getExpression());
if (!observationsAfterSwitch.get(pomdp.getObservation(i))) {
newObservationsAfterSwitch.set(pomdp.getObservation(i));
storm::storage::BitVector uncoveredStatesAfterSwitch(~coveredStatesAfterSwitch);
for (uint64_t i : uncoveredStatesAfterSwitch) {
auto const& cv = continuationVars[i];
if (model->getBooleanValue(cv)) {
uint64_t obs = pomdp.getObservation(i);
STORM_LOG_ASSERT(winningRegion.isWinning(obs,getOffsetFromObservation(i,obs)), "Cannot continue: No scheduler known for state " << i << " (observation " << obs << ").");
auto const& cvExpr =continuationVarExpressions[i];
smtSolver->add(cvExpr);
if (!observationsAfterSwitch.get(obs)) {
newObservationsAfterSwitch.set(obs);
}
++i;
}
}
stats.evaluateExtensionSolverTime.stop();
if (options.computeTraceOutput()) {
detail::printRelevantInfoFromModel(model, reachVars, continuationVars);
}
stats.encodeExtensionSolverTime.start();
for (auto obs : newObservations) {
auto const &actionSelectionVarsForObs = actionSelectionVars[obs];
observations.set(obs);
@ -534,16 +622,11 @@ namespace storm {
if (remainingExpressions.empty()) {
stats.updateExtensionSolverTime.stop();
stats.encodeExtensionSolverTime.stop();
break;
}
// Add scheduler
//std::cout << storm::expressions::disjunction(remainingExpressions) << std::endl;
smtSolver->add(storm::expressions::disjunction(remainingExpressions));
stats.updateExtensionSolverTime.stop();
stats.encodeExtensionSolverTime.stop();
}
if (!newSchedulerDiscovered) {
break;
@ -591,45 +674,58 @@ namespace storm {
}
stats.winningRegionUpdatesTimer.stop();
if (newTargetObservations>0) {
stats.graphSearchTime.start();
storm::analysis::QualitativeAnalysisOnGraphs<ValueType> graphanalysis(pomdp);
uint64_t targetStatesBefore = targetStates.getNumberOfSetBits();
STORM_LOG_INFO("Target states before graph based analysis " << targetStates.getNumberOfSetBits());
storm::storage::BitVector newtargetStates = graphanalysis.analyseProb1Max(~surelyReachSinkStates, targetStates);
uint64_t targetStatesAfter = newtargetStates.getNumberOfSetBits();
STORM_LOG_INFO("Target states after graph based analysis " << newtargetStates.getNumberOfSetBits());
targetStates = graphanalysis.analyseProb1Max(~surelyReachSinkStates, targetStates);
uint64_t targetStatesAfter = targetStates.getNumberOfSetBits();
STORM_LOG_INFO("Target states after graph based analysis " << targetStates.getNumberOfSetBits());
stats.graphSearchTime.stop();
if (targetStatesAfter - targetStatesBefore > 0) {
stats.winningRegionUpdatesTimer.start();
// TODO CODE DUPLICATION WITH INIT, PUT IN PROCEDURE
storm::storage::BitVector potentialWinner(pomdp.getNrObservations());
storm::storage::BitVector observationsWithPartialWinners(pomdp.getNrObservations());
for(uint64_t observation = 0; observation < pomdp.getNrObservations(); ++observation) {
if (winningRegion.observationIsWinning(observation)) {
continue;
}
bool observationIsWinning = true;
for (uint64_t state : statesPerObservation[observation]) {
if(!newtargetStates.get(state)) {
if(!targetStates.get(state)) {
observationIsWinning = false;
break;
observationsWithPartialWinners.set(observation);
} else {
potentialWinner.set(observation);
}
}
if(observationIsWinning) {
stats.incrementGraphBasedWinningObservations();
winningRegion.setObservationIsWinning(observation);
for(auto const& state : statesPerObservation[observation]) {
targetStates.set(state);
}
updated.set(observation);
}
}
STORM_LOG_INFO("Graph based winning obs: " << stats.getGraphBasedwinningObservations());
uint64_t nonWinObTargetStates =0;
for (uint64_t state : targetStates) {
if (!winningRegion.observationIsWinning(pomdp.getObservation(state))) {
nonWinObTargetStates++;
observationsWithPartialWinners &= potentialWinner;
for (auto const& observation : observationsWithPartialWinners) {
uint64_t nrStatesForObs = statesPerObservation[observation].size();
storm::storage::BitVector update(nrStatesForObs);
for (uint64_t i = 0; i < nrStatesForObs; ++i ) {
uint64_t state = statesPerObservation[observation][i];
if(targetStates.get(state)) {
update.set(i);
}
}
assert(!update.empty());
STORM_LOG_TRACE("Extend winning region for observation " << observation << " with target states/offsets" << update);
winningRegion.addTargetStates(observation, update);
assert(winningRegion.query(observation,update));//
updated.set(observation);
}
stats.winningRegionUpdatesTimer.stop();
if (nonWinObTargetStates > 0) {
std::cout << "Non winning target states " << nonWinObTargetStates << std::endl;
if (observationsWithPartialWinners.getNumberOfSetBits() > 0) {
STORM_LOG_WARN("This case has been barely tested and likely contains bug");
reset();
return analyze(k, ~targetStates & ~surelyReachSinkStates);

11
src/storm-pomdp/analysis/MemlessStrategySearchQualitative.h

@ -108,8 +108,10 @@ namespace pomdp {
storm::utility::Stopwatch totalTimer;
storm::utility::Stopwatch smtCheckTimer;
storm::utility::Stopwatch initializeSolverTimer;
storm::utility::Stopwatch updateExtensionSolverTime;
storm::utility::Stopwatch evaluateExtensionSolverTime;
storm::utility::Stopwatch encodeExtensionSolverTime;
storm::utility::Stopwatch updateNewStrategySolverTime;
storm::utility::Stopwatch graphSearchTime;
storm::utility::Stopwatch winningRegionUpdatesTimer;
@ -168,9 +170,6 @@ namespace pomdp {
}
void computeWinningRegion(uint64_t k) {
std::cout << surelyReachSinkStates << std::endl;
std::cout << targetStates << std::endl;
std::cout << (~surelyReachSinkStates & ~targetStates) << std::endl;
stats.totalTimer.start();
analyze(k, ~surelyReachSinkStates & ~targetStates);
stats.totalTimer.stop();
@ -180,6 +179,8 @@ namespace pomdp {
return winningRegion;
}
uint64_t getOffsetFromObservation(uint64_t state, uint64_t observation) const;
bool analyze(uint64_t k, storm::storage::BitVector const& oneOfTheseStates, storm::storage::BitVector const& allOfTheseStates = storm::storage::BitVector());
Statistics const& getStatistics() const;
@ -241,6 +242,8 @@ namespace pomdp {
std::shared_ptr<storm::utility::solver::SmtSolverFactory>& smtSolverFactory;
std::shared_ptr<WinningRegionQueryInterface<ValueType>> validator;
mutable bool useFindOffset = false;
};
}

31
src/storm-pomdp/analysis/WinningRegion.cpp

@ -18,19 +18,24 @@ namespace pomdp {
winningRegion[observation] = { storm::storage::BitVector(observationSizes[observation], true) };
}
// void WinningRegion::addTargetState(uint64_t observation, uint64_t offset) {
// std::vector<storm::storage::BitVector> newWinningSupport = std::vector<storm::storage::BitVector>();
// bool changed = true;
// for (auto const& support : winningRegion[observation]) {
// newWinningSupport.push_back(storm::storage::BitVector(support));
// if(!support.get(offset)) {
// changed = true;
// newWinningSupport.back().set(offset);
// }
// }
//
//
// }
void WinningRegion::addTargetStates(uint64_t observation, storm::storage::BitVector const& offsets) {
assert(!offsets.empty());
if(winningRegion[observation].empty()) {
winningRegion[observation].push_back(offsets);
return;
}
std::vector<storm::storage::BitVector> newWinningSupport = std::vector<storm::storage::BitVector>();
for (auto const& support : winningRegion[observation]) {
newWinningSupport.push_back(support | offsets);
}
// TODO it may be worthwhile to check whether something changed. If nothing changed, there is no need for the next routine.
// TODO the following code is bit naive.
winningRegion[observation].clear(); // This prevents some overhead.
for (auto const& newWinning : newWinningSupport) {
update(observation, newWinning);
}
}
bool WinningRegion::update(uint64_t observation, storm::storage::BitVector const& winning) {
std::vector<storm::storage::BitVector> newWinningSupport = std::vector<storm::storage::BitVector>();

1
src/storm-pomdp/analysis/WinningRegion.h

@ -19,6 +19,7 @@ namespace storm {
std::vector<storm::storage::BitVector> const& getWinningSetsPerObservation(uint64_t observation) const;
void addTargetStates(uint64_t observation, storm::storage::BitVector const& offsets);
void setObservationIsWinning(uint64_t observation);
bool observationIsWinning(uint64_t observation) const;

Loading…
Cancel
Save