Browse Source

Improve graph-based analysis and clean up outputs

main
Sebastian Junges 5 years ago
parent
commit
eca148cee0
  1. 35
      src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp
  2. 23
      src/storm-pomdp/analysis/QualitativeAnalysisOnGraphs.cpp
  3. 16
      src/storm-pomdp/analysis/QualitativeStrategySearchNaive.cpp

35
src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp

@ -352,9 +352,9 @@ namespace storm {
template <typename ValueType> template <typename ValueType>
bool MemlessStrategySearchQualitative<ValueType>::analyze(uint64_t k, storm::storage::BitVector const& oneOfTheseStates, storm::storage::BitVector const& allOfTheseStates) { bool MemlessStrategySearchQualitative<ValueType>::analyze(uint64_t k, storm::storage::BitVector const& oneOfTheseStates, storm::storage::BitVector const& allOfTheseStates) {
std::cout << "Surely reach sink states: " << surelyReachSinkStates << std::endl;
std::cout << "Target states " << targetStates << std::endl;
std::cout << (~surelyReachSinkStates & ~targetStates) << std::endl;
STORM_LOG_DEBUG("Surely reach sink states: " << surelyReachSinkStates);
STORM_LOG_DEBUG("Target states " << targetStates);
STORM_LOG_DEBUG("Questionmark states " << (~surelyReachSinkStates & ~targetStates));
stats.initializeSolverTimer.start(); stats.initializeSolverTimer.start();
// TODO: When do we need to reinitialize? When the solver has been reset. // TODO: When do we need to reinitialize? When the solver has been reset.
initialize(k); initialize(k);
@ -363,7 +363,6 @@ namespace storm {
stats.winningRegionUpdatesTimer.start(); stats.winningRegionUpdatesTimer.start();
storm::storage::BitVector updated(pomdp.getNrObservations()); storm::storage::BitVector updated(pomdp.getNrObservations());
// TODO CODE DUPLICATION WITH UPDATE, PUT IN PROCEDURE
storm::storage::BitVector potentialWinner(pomdp.getNrObservations()); storm::storage::BitVector potentialWinner(pomdp.getNrObservations());
storm::storage::BitVector observationsWithPartialWinners(pomdp.getNrObservations()); storm::storage::BitVector observationsWithPartialWinners(pomdp.getNrObservations());
for(uint64_t observation = 0; observation < pomdp.getNrObservations(); ++observation) { for(uint64_t observation = 0; observation < pomdp.getNrObservations(); ++observation) {
@ -386,7 +385,7 @@ namespace storm {
updated.set(observation); updated.set(observation);
} }
} }
STORM_LOG_INFO("Graph based winning obs: " << stats.getGraphBasedwinningObservations());
STORM_LOG_DEBUG("Graph based winning obs: " << stats.getGraphBasedwinningObservations());
observationsWithPartialWinners &= potentialWinner; observationsWithPartialWinners &= potentialWinner;
for (auto const& observation : observationsWithPartialWinners) { for (auto const& observation : observationsWithPartialWinners) {
@ -411,10 +410,11 @@ namespace storm {
stats.winningRegionUpdatesTimer.stop(); stats.winningRegionUpdatesTimer.stop();
uint64_t maximalNrActions = 8;
STORM_LOG_WARN("We have hardcoded (an upper bound on) the number of actions");
uint64_t maximalNrActions = 0;
for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) {
maximalNrActions = std::max(pomdp.getTransitionMatrix().getRowGroupSize(state),maximalNrActions);
}
std::vector<storm::expressions::Expression> atLeastOneOfStates; std::vector<storm::expressions::Expression> atLeastOneOfStates;
for (uint64_t state : oneOfTheseStates) { for (uint64_t state : oneOfTheseStates) {
STORM_LOG_ASSERT(reachVarExpressions.size() > state, "state id " << state << " exceeds number of states (" << reachVarExpressions.size() << ")" ); STORM_LOG_ASSERT(reachVarExpressions.size() > state, "state id " << state << " exceeds number of states (" << reachVarExpressions.size() << ")" );
atLeastOneOfStates.push_back(reachVarExpressions[state]); atLeastOneOfStates.push_back(reachVarExpressions[state]);
@ -426,7 +426,6 @@ namespace storm {
std::set<storm::expressions::Expression> allOfTheseAssumption; std::set<storm::expressions::Expression> allOfTheseAssumption;
std::vector<storm::expressions::Expression> updateForObservationExpressions; std::vector<storm::expressions::Expression> updateForObservationExpressions;
for (uint64_t state : allOfTheseStates) { for (uint64_t state : allOfTheseStates) {
@ -484,7 +483,6 @@ namespace storm {
assert(pomdp.getNrObservations() == schedulerForObs.size()); assert(pomdp.getNrObservations() == schedulerForObs.size());
InternalObservationScheduler scheduler; InternalObservationScheduler scheduler;
scheduler.switchObservations = storm::storage::BitVector(pomdp.getNrObservations()); scheduler.switchObservations = storm::storage::BitVector(pomdp.getNrObservations());
storm::storage::BitVector newObservations(pomdp.getNrObservations()); storm::storage::BitVector newObservations(pomdp.getNrObservations());
@ -677,14 +675,13 @@ namespace storm {
stats.graphSearchTime.start(); stats.graphSearchTime.start();
storm::analysis::QualitativeAnalysisOnGraphs<ValueType> graphanalysis(pomdp); storm::analysis::QualitativeAnalysisOnGraphs<ValueType> graphanalysis(pomdp);
uint64_t targetStatesBefore = targetStates.getNumberOfSetBits(); uint64_t targetStatesBefore = targetStates.getNumberOfSetBits();
STORM_LOG_INFO("Target states before graph based analysis " << targetStates.getNumberOfSetBits());
STORM_LOG_DEBUG("Target states before graph based analysis " << targetStates.getNumberOfSetBits());
targetStates = graphanalysis.analyseProb1Max(~surelyReachSinkStates, targetStates); targetStates = graphanalysis.analyseProb1Max(~surelyReachSinkStates, targetStates);
uint64_t targetStatesAfter = targetStates.getNumberOfSetBits(); uint64_t targetStatesAfter = targetStates.getNumberOfSetBits();
STORM_LOG_INFO("Target states after graph based analysis " << targetStates.getNumberOfSetBits());
STORM_LOG_DEBUG("Target states after graph based analysis " << targetStates.getNumberOfSetBits());
stats.graphSearchTime.stop(); stats.graphSearchTime.stop();
if (targetStatesAfter - targetStatesBefore > 0) { if (targetStatesAfter - targetStatesBefore > 0) {
stats.winningRegionUpdatesTimer.start(); stats.winningRegionUpdatesTimer.start();
// TODO CODE DUPLICATION WITH INIT, PUT IN PROCEDURE
storm::storage::BitVector potentialWinner(pomdp.getNrObservations()); storm::storage::BitVector potentialWinner(pomdp.getNrObservations());
storm::storage::BitVector observationsWithPartialWinners(pomdp.getNrObservations()); storm::storage::BitVector observationsWithPartialWinners(pomdp.getNrObservations());
for(uint64_t observation = 0; observation < pomdp.getNrObservations(); ++observation) { for(uint64_t observation = 0; observation < pomdp.getNrObservations(); ++observation) {
@ -706,7 +703,7 @@ namespace storm {
updated.set(observation); updated.set(observation);
} }
} }
STORM_LOG_INFO("Graph based winning obs: " << stats.getGraphBasedwinningObservations());
STORM_LOG_DEBUG("Graph based winning obs: " << stats.getGraphBasedwinningObservations());
observationsWithPartialWinners &= potentialWinner; observationsWithPartialWinners &= potentialWinner;
for (auto const& observation : observationsWithPartialWinners) { for (auto const& observation : observationsWithPartialWinners) {
uint64_t nrStatesForObs = statesPerObservation[observation].size(); uint64_t nrStatesForObs = statesPerObservation[observation].size();
@ -726,7 +723,6 @@ namespace storm {
stats.winningRegionUpdatesTimer.stop(); stats.winningRegionUpdatesTimer.stop();
if (observationsWithPartialWinners.getNumberOfSetBits() > 0) { if (observationsWithPartialWinners.getNumberOfSetBits() > 0) {
STORM_LOG_WARN("This case has been barely tested and likely contains bug");
reset(); reset();
return analyze(k, ~targetStates & ~surelyReachSinkStates); return analyze(k, ~targetStates & ~surelyReachSinkStates);
} }
@ -783,10 +779,9 @@ namespace storm {
if(options.validateResult) { if(options.validateResult) {
STORM_LOG_WARN("Validating result is a winning region, only for debugging purposes."); STORM_LOG_WARN("Validating result is a winning region, only for debugging purposes.");
validator->validate(surelyReachSinkStates); validator->validate(surelyReachSinkStates);
STORM_LOG_WARN("Validating result is a maximal winning region, only for debugging purposes.");
STORM_LOG_WARN("Validating result is a fixed point, only for debugging purposes.");
validator->validateIsMaximal(surelyReachSinkStates); validator->validateIsMaximal(surelyReachSinkStates);
} }
winningRegion.print();
if (!allOfTheseStates.empty()) { if (!allOfTheseStates.empty()) {
for (uint64_t observation = 0; observation < pomdp.getNrObservations(); ++observation) { for (uint64_t observation = 0; observation < pomdp.getNrObservations(); ++observation) {
@ -805,10 +800,10 @@ namespace storm {
} }
return true; return true;
} }
template<typename ValueType> template<typename ValueType>
void MemlessStrategySearchQualitative<ValueType>::printCoveredStates(storm::storage::BitVector const &remaining) const { void MemlessStrategySearchQualitative<ValueType>::printCoveredStates(storm::storage::BitVector const &remaining) const {
STORM_LOG_DEBUG("states that are okay"); STORM_LOG_DEBUG("states that are okay");
for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) { for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) {
if (!remaining.get(state)) { if (!remaining.get(state)) {
@ -816,7 +811,6 @@ namespace storm {
} }
} }
std::cout << std::endl; std::cout << std::endl;
} }
template<typename ValueType> template<typename ValueType>
@ -857,7 +851,8 @@ namespace storm {
stats.incrementSmtChecks(); stats.incrementSmtChecks();
if (result == storm::solver::SmtSolver::CheckResult::Unknown) { if (result == storm::solver::SmtSolver::CheckResult::Unknown) {
STORM_LOG_THROW(false, storm::exceptions::UnexpectedException, "SMT solver yielded an unexpected result");
STORM_LOG_DEBUG("Unknown");
return false;
} else if (result == storm::solver::SmtSolver::CheckResult::Unsat) { } else if (result == storm::solver::SmtSolver::CheckResult::Unsat) {
STORM_LOG_DEBUG("Unsatisfiable!"); STORM_LOG_DEBUG("Unsatisfiable!");
return false; return false;

23
src/storm-pomdp/analysis/QualitativeAnalysisOnGraphs.cpp

@ -89,6 +89,8 @@ namespace storm {
STORM_LOG_TRACE("Prob1A states according to MDP: " << newGoalStates); STORM_LOG_TRACE("Prob1A states according to MDP: " << newGoalStates);
// Now find a set of observations such that there is a memoryless scheduler inducing prob. 1 for each state whose observation is in the set. // Now find a set of observations such that there is a memoryless scheduler inducing prob. 1 for each state whose observation is in the set.
storm::storage::BitVector potentialGoalStates = storm::utility::graph::performProb1E(pomdp.getTransitionMatrix(), pomdp.getTransitionMatrix().getRowGroupIndices(), pomdp.getBackwardTransitions(), okay, newGoalStates); storm::storage::BitVector potentialGoalStates = storm::utility::graph::performProb1E(pomdp.getTransitionMatrix(), pomdp.getTransitionMatrix().getRowGroupIndices(), pomdp.getBackwardTransitions(), okay, newGoalStates);
STORM_LOG_TRACE("Prob1E states according to MDP: " << potentialGoalStates);
storm::storage::BitVector notGoalStates = ~potentialGoalStates; storm::storage::BitVector notGoalStates = ~potentialGoalStates;
storm::storage::BitVector potentialGoalObservations(pomdp.getNrObservations(), true); storm::storage::BitVector potentialGoalObservations(pomdp.getNrObservations(), true);
for (uint64_t observation = 0; observation < pomdp.getNrObservations(); ++observation) { for (uint64_t observation = 0; observation < pomdp.getNrObservations(); ++observation) {
@ -96,23 +98,39 @@ namespace storm {
potentialGoalObservations.set(pomdp.getObservation(state), false); potentialGoalObservations.set(pomdp.getObservation(state), false);
} }
} }
STORM_LOG_TRACE("Prob1E observations according to MDP: " << potentialGoalObservations);
std::vector<std::vector<uint64_t>> statesPerObservation(pomdp.getNrObservations(), std::vector<uint64_t>()); std::vector<std::vector<uint64_t>> statesPerObservation(pomdp.getNrObservations(), std::vector<uint64_t>());
for (uint64_t state : potentialGoalStates) { for (uint64_t state : potentialGoalStates) {
statesPerObservation[pomdp.getObservation(state)].push_back(state); statesPerObservation[pomdp.getObservation(state)].push_back(state);
} }
storm::storage::BitVector singleObservationStates(pomdp.getNumberOfStates());
for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) {
if(statesPerObservation[pomdp.getObservation(state)].size() == 1) {
singleObservationStates.set(state);
}
}
storm::storage::BitVector goalStates(pomdp.getNumberOfStates()); storm::storage::BitVector goalStates(pomdp.getNumberOfStates());
while (goalStates != newGoalStates) { while (goalStates != newGoalStates) {
goalStates = storm::utility::graph::performProb1A(pomdp.getTransitionMatrix(), pomdp.getTransitionMatrix().getRowGroupIndices(), pomdp.getBackwardTransitions(), okay, newGoalStates); goalStates = storm::utility::graph::performProb1A(pomdp.getTransitionMatrix(), pomdp.getTransitionMatrix().getRowGroupIndices(), pomdp.getBackwardTransitions(), okay, newGoalStates);
goalStates = storm::utility::graph::performProb1E(pomdp.getTransitionMatrix(), pomdp.getTransitionMatrix().getRowGroupIndices(), pomdp.getBackwardTransitions(), okay & singleObservationStates, goalStates);
newGoalStates = goalStates; newGoalStates = goalStates;
STORM_LOG_INFO("Prob1A states according to MDP: " << newGoalStates);
STORM_LOG_TRACE("Prob1A states according to MDP: " << newGoalStates);
for (uint64_t observation : potentialGoalObservations) { for (uint64_t observation : potentialGoalObservations) {
STORM_LOG_TRACE("Consider observation " << observation);
uint64_t actsForObservation = pomdp.getTransitionMatrix().getRowGroupSize(statesPerObservation[observation][0]); uint64_t actsForObservation = pomdp.getTransitionMatrix().getRowGroupSize(statesPerObservation[observation][0]);
// Search whether we find an action that works for this observation. // Search whether we find an action that works for this observation.
for (uint64_t act = 0; act < actsForObservation; act++) { for (uint64_t act = 0; act < actsForObservation; act++) {
STORM_LOG_TRACE("Consider action " << act);
bool isGoalAction = true; // Assume that this works, then check whether we find a violation. bool isGoalAction = true; // Assume that this works, then check whether we find a violation.
for (uint64_t state : statesPerObservation[observation]) { for (uint64_t state : statesPerObservation[observation]) {
STORM_LOG_TRACE("Consider state " << state);
if (newGoalStates.get(state)) { if (newGoalStates.get(state)) {
STORM_LOG_TRACE("Already a goal state " << state);
// A state is only a goal state if all actions work, // A state is only a goal state if all actions work,
// or if all states with the same observation are goal states (and then, it does not matter which action is a goal action). // or if all states with the same observation are goal states (and then, it does not matter which action is a goal action).
// Notice that this can mean that we wrongly conclude that some action is okay even if this is not the correct action (but then some other action exists which is okay for all states). // Notice that this can mean that we wrongly conclude that some action is okay even if this is not the correct action (but then some other action exists which is okay for all states).
@ -123,10 +141,13 @@ namespace storm {
bool hasGoalEntry = false; bool hasGoalEntry = false;
for (auto const& entry : pomdp.getTransitionMatrix().getRow(row)) { for (auto const& entry : pomdp.getTransitionMatrix().getRow(row)) {
assert(!storm::utility::isZero(entry.getValue())); assert(!storm::utility::isZero(entry.getValue()));
if(newGoalStates.get(entry.getColumn())) { if(newGoalStates.get(entry.getColumn())) {
STORM_LOG_TRACE("Reaches state " << entry.getColumn() << " which is a PROB1e state");
hasGoalEntry = true; hasGoalEntry = true;
} }
else if(pomdp.getObservation(entry.getColumn()) != observation) { else if(pomdp.getObservation(entry.getColumn()) != observation) {
STORM_LOG_TRACE("Reaches state " << entry.getColumn() << " which is not a PROB1e state");
isGoalAction = false; isGoalAction = false;
break; break;
} }

16
src/storm-pomdp/analysis/QualitativeStrategySearchNaive.cpp

@ -50,7 +50,6 @@ namespace storm {
smtSolver->add(storm::expressions::disjunction(actionVars)); smtSolver->add(storm::expressions::disjunction(actionVars));
} }
uint64_t rowindex = 0; uint64_t rowindex = 0;
for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) { for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) {
for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) { for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) {
@ -78,7 +77,6 @@ namespace storm {
smtSolver->add(!reachVarExpressions[state]); smtSolver->add(!reachVarExpressions[state]);
rowindex += pomdp.getNumberOfChoices(state); rowindex += pomdp.getNumberOfChoices(state);
} else if(!targetStates.get(state)) { } else if(!targetStates.get(state)) {
std::cout << state << " is not a target state" << std::endl;
std::vector<std::vector<std::vector<storm::expressions::Expression>>> pathsubsubexprs; std::vector<std::vector<std::vector<storm::expressions::Expression>>> pathsubsubexprs;
for (uint64_t j = 1; j < k; ++j) { for (uint64_t j = 1; j < k; ++j) {
@ -160,7 +158,6 @@ namespace storm {
storm::storage::BitVector remainingstates(pomdp.getNumberOfStates()); storm::storage::BitVector remainingstates(pomdp.getNumberOfStates());
for (auto rv : reachVars) { for (auto rv : reachVars) {
if (model->getBooleanValue(rv)) { if (model->getBooleanValue(rv)) {
std::cout << i << " " << std::endl;
observations.set(pomdp.getObservation(i)); observations.set(pomdp.getObservation(i));
} else { } else {
remainingstates.set(i); remainingstates.set(i);
@ -179,18 +176,7 @@ namespace storm {
} }
} }
// TODO move this into a print scheduler function.
//STORM_LOG_TRACE("the scheduler: ");
for (uint64_t obs = 0; obs < scheduler.size(); ++obs) {
if (observations.get(obs)) {
//STORM_LOG_TRACE("observation: " << obs);
//std::cout << "actions:";
//for (auto act : scheduler[obs]) {
// std::cout << " " << act;
//}
//std::cout << std::endl;
}
}
return true; return true;

Loading…
Cancel
Save