Browse Source

graph-based analysis improved, and cleaning outputs

tempestpy_adaptions
Sebastian Junges 5 years ago
parent
commit
eca148cee0
  1. 35
      src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp
  2. 23
      src/storm-pomdp/analysis/QualitativeAnalysisOnGraphs.cpp
  3. 16
      src/storm-pomdp/analysis/QualitativeStrategySearchNaive.cpp

35
src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp

@ -352,9 +352,9 @@ namespace storm {
template <typename ValueType>
bool MemlessStrategySearchQualitative<ValueType>::analyze(uint64_t k, storm::storage::BitVector const& oneOfTheseStates, storm::storage::BitVector const& allOfTheseStates) {
std::cout << "Surely reach sink states: " << surelyReachSinkStates << std::endl;
std::cout << "Target states " << targetStates << std::endl;
std::cout << (~surelyReachSinkStates & ~targetStates) << std::endl;
STORM_LOG_DEBUG("Surely reach sink states: " << surelyReachSinkStates);
STORM_LOG_DEBUG("Target states " << targetStates);
STORM_LOG_DEBUG("Questionmark states " << (~surelyReachSinkStates & ~targetStates));
stats.initializeSolverTimer.start();
// TODO: When do we need to reinitialize? When the solver has been reset.
initialize(k);
@ -363,7 +363,6 @@ namespace storm {
stats.winningRegionUpdatesTimer.start();
storm::storage::BitVector updated(pomdp.getNrObservations());
// TODO CODE DUPLICATION WITH UPDATE, PUT IN PROCEDURE
storm::storage::BitVector potentialWinner(pomdp.getNrObservations());
storm::storage::BitVector observationsWithPartialWinners(pomdp.getNrObservations());
for(uint64_t observation = 0; observation < pomdp.getNrObservations(); ++observation) {
@ -386,7 +385,7 @@ namespace storm {
updated.set(observation);
}
}
STORM_LOG_INFO("Graph based winning obs: " << stats.getGraphBasedwinningObservations());
STORM_LOG_DEBUG("Graph based winning obs: " << stats.getGraphBasedwinningObservations());
observationsWithPartialWinners &= potentialWinner;
for (auto const& observation : observationsWithPartialWinners) {
@ -411,10 +410,11 @@ namespace storm {
stats.winningRegionUpdatesTimer.stop();
uint64_t maximalNrActions = 8;
STORM_LOG_WARN("We have hardcoded (an upper bound on) the number of actions");
uint64_t maximalNrActions = 0;
for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) {
maximalNrActions = std::max(pomdp.getTransitionMatrix().getRowGroupSize(state),maximalNrActions);
}
std::vector<storm::expressions::Expression> atLeastOneOfStates;
for (uint64_t state : oneOfTheseStates) {
STORM_LOG_ASSERT(reachVarExpressions.size() > state, "state id " << state << " exceeds number of states (" << reachVarExpressions.size() << ")" );
atLeastOneOfStates.push_back(reachVarExpressions[state]);
@ -426,7 +426,6 @@ namespace storm {
std::set<storm::expressions::Expression> allOfTheseAssumption;
std::vector<storm::expressions::Expression> updateForObservationExpressions;
for (uint64_t state : allOfTheseStates) {
@ -484,7 +483,6 @@ namespace storm {
assert(pomdp.getNrObservations() == schedulerForObs.size());
InternalObservationScheduler scheduler;
scheduler.switchObservations = storm::storage::BitVector(pomdp.getNrObservations());
storm::storage::BitVector newObservations(pomdp.getNrObservations());
@ -677,14 +675,13 @@ namespace storm {
stats.graphSearchTime.start();
storm::analysis::QualitativeAnalysisOnGraphs<ValueType> graphanalysis(pomdp);
uint64_t targetStatesBefore = targetStates.getNumberOfSetBits();
STORM_LOG_INFO("Target states before graph based analysis " << targetStates.getNumberOfSetBits());
STORM_LOG_DEBUG("Target states before graph based analysis " << targetStates.getNumberOfSetBits());
targetStates = graphanalysis.analyseProb1Max(~surelyReachSinkStates, targetStates);
uint64_t targetStatesAfter = targetStates.getNumberOfSetBits();
STORM_LOG_INFO("Target states after graph based analysis " << targetStates.getNumberOfSetBits());
STORM_LOG_DEBUG("Target states after graph based analysis " << targetStates.getNumberOfSetBits());
stats.graphSearchTime.stop();
if (targetStatesAfter - targetStatesBefore > 0) {
stats.winningRegionUpdatesTimer.start();
// TODO CODE DUPLICATION WITH INIT, PUT IN PROCEDURE
storm::storage::BitVector potentialWinner(pomdp.getNrObservations());
storm::storage::BitVector observationsWithPartialWinners(pomdp.getNrObservations());
for(uint64_t observation = 0; observation < pomdp.getNrObservations(); ++observation) {
@ -706,7 +703,7 @@ namespace storm {
updated.set(observation);
}
}
STORM_LOG_INFO("Graph based winning obs: " << stats.getGraphBasedwinningObservations());
STORM_LOG_DEBUG("Graph based winning obs: " << stats.getGraphBasedwinningObservations());
observationsWithPartialWinners &= potentialWinner;
for (auto const& observation : observationsWithPartialWinners) {
uint64_t nrStatesForObs = statesPerObservation[observation].size();
@ -726,7 +723,6 @@ namespace storm {
stats.winningRegionUpdatesTimer.stop();
if (observationsWithPartialWinners.getNumberOfSetBits() > 0) {
STORM_LOG_WARN("This case has been barely tested and likely contains bug");
reset();
return analyze(k, ~targetStates & ~surelyReachSinkStates);
}
@ -783,10 +779,9 @@ namespace storm {
if(options.validateResult) {
STORM_LOG_WARN("Validating result is a winning region, only for debugging purposes.");
validator->validate(surelyReachSinkStates);
STORM_LOG_WARN("Validating result is a maximal winning region, only for debugging purposes.");
STORM_LOG_WARN("Validating result is a fixed point, only for debugging purposes.");
validator->validateIsMaximal(surelyReachSinkStates);
}
winningRegion.print();
if (!allOfTheseStates.empty()) {
for (uint64_t observation = 0; observation < pomdp.getNrObservations(); ++observation) {
@ -805,10 +800,10 @@ namespace storm {
}
return true;
}
template<typename ValueType>
void MemlessStrategySearchQualitative<ValueType>::printCoveredStates(storm::storage::BitVector const &remaining) const {
STORM_LOG_DEBUG("states that are okay");
for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) {
if (!remaining.get(state)) {
@ -816,7 +811,6 @@ namespace storm {
}
}
std::cout << std::endl;
}
template<typename ValueType>
@ -857,7 +851,8 @@ namespace storm {
stats.incrementSmtChecks();
if (result == storm::solver::SmtSolver::CheckResult::Unknown) {
STORM_LOG_THROW(false, storm::exceptions::UnexpectedException, "SMT solver yielded an unexpected result");
STORM_LOG_DEBUG("Unknown");
return false;
} else if (result == storm::solver::SmtSolver::CheckResult::Unsat) {
STORM_LOG_DEBUG("Unsatisfiable!");
return false;

23
src/storm-pomdp/analysis/QualitativeAnalysisOnGraphs.cpp

@ -89,6 +89,8 @@ namespace storm {
STORM_LOG_TRACE("Prob1A states according to MDP: " << newGoalStates);
// Now find a set of observations such that there is a memoryless scheduler inducing prob. 1 for each state whose observation is in the set.
storm::storage::BitVector potentialGoalStates = storm::utility::graph::performProb1E(pomdp.getTransitionMatrix(), pomdp.getTransitionMatrix().getRowGroupIndices(), pomdp.getBackwardTransitions(), okay, newGoalStates);
STORM_LOG_TRACE("Prob1E states according to MDP: " << potentialGoalStates);
storm::storage::BitVector notGoalStates = ~potentialGoalStates;
storm::storage::BitVector potentialGoalObservations(pomdp.getNrObservations(), true);
for (uint64_t observation = 0; observation < pomdp.getNrObservations(); ++observation) {
@ -96,23 +98,39 @@ namespace storm {
potentialGoalObservations.set(pomdp.getObservation(state), false);
}
}
STORM_LOG_TRACE("Prob1E observations according to MDP: " << potentialGoalObservations);
std::vector<std::vector<uint64_t>> statesPerObservation(pomdp.getNrObservations(), std::vector<uint64_t>());
for (uint64_t state : potentialGoalStates) {
statesPerObservation[pomdp.getObservation(state)].push_back(state);
}
storm::storage::BitVector singleObservationStates(pomdp.getNumberOfStates());
for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) {
if(statesPerObservation[pomdp.getObservation(state)].size() == 1) {
singleObservationStates.set(state);
}
}
storm::storage::BitVector goalStates(pomdp.getNumberOfStates());
while (goalStates != newGoalStates) {
goalStates = storm::utility::graph::performProb1A(pomdp.getTransitionMatrix(), pomdp.getTransitionMatrix().getRowGroupIndices(), pomdp.getBackwardTransitions(), okay, newGoalStates);
goalStates = storm::utility::graph::performProb1E(pomdp.getTransitionMatrix(), pomdp.getTransitionMatrix().getRowGroupIndices(), pomdp.getBackwardTransitions(), okay & singleObservationStates, goalStates);
newGoalStates = goalStates;
STORM_LOG_INFO("Prob1A states according to MDP: " << newGoalStates);
STORM_LOG_TRACE("Prob1A states according to MDP: " << newGoalStates);
for (uint64_t observation : potentialGoalObservations) {
STORM_LOG_TRACE("Consider observation " << observation);
uint64_t actsForObservation = pomdp.getTransitionMatrix().getRowGroupSize(statesPerObservation[observation][0]);
// Search whether we find an action that works for this observation.
for (uint64_t act = 0; act < actsForObservation; act++) {
STORM_LOG_TRACE("Consider action " << act);
bool isGoalAction = true; // Assume that this works, then check whether we find a violation.
for (uint64_t state : statesPerObservation[observation]) {
STORM_LOG_TRACE("Consider state " << state);
if (newGoalStates.get(state)) {
STORM_LOG_TRACE("Already a goal state " << state);
// A state is only a goal state if all actions work,
// or if all states with the same observation are goal states (and then, it does not matter which action is a goal action).
// Notice that this can mean that we wrongly conclude that some action is okay even if this is not the correct action (but then some other action exists which is okay for all states).
@ -123,10 +141,13 @@ namespace storm {
bool hasGoalEntry = false;
for (auto const& entry : pomdp.getTransitionMatrix().getRow(row)) {
assert(!storm::utility::isZero(entry.getValue()));
if(newGoalStates.get(entry.getColumn())) {
STORM_LOG_TRACE("Reaches state " << entry.getColumn() << " which is a PROB1e state");
hasGoalEntry = true;
}
else if(pomdp.getObservation(entry.getColumn()) != observation) {
STORM_LOG_TRACE("Reaches state " << entry.getColumn() << " which is not a PROB1e state");
isGoalAction = false;
break;
}

16
src/storm-pomdp/analysis/QualitativeStrategySearchNaive.cpp

@ -50,7 +50,6 @@ namespace storm {
smtSolver->add(storm::expressions::disjunction(actionVars));
}
uint64_t rowindex = 0;
for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) {
for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) {
@ -78,7 +77,6 @@ namespace storm {
smtSolver->add(!reachVarExpressions[state]);
rowindex += pomdp.getNumberOfChoices(state);
} else if(!targetStates.get(state)) {
std::cout << state << " is not a target state" << std::endl;
std::vector<std::vector<std::vector<storm::expressions::Expression>>> pathsubsubexprs;
for (uint64_t j = 1; j < k; ++j) {
@ -160,7 +158,6 @@ namespace storm {
storm::storage::BitVector remainingstates(pomdp.getNumberOfStates());
for (auto rv : reachVars) {
if (model->getBooleanValue(rv)) {
std::cout << i << " " << std::endl;
observations.set(pomdp.getObservation(i));
} else {
remainingstates.set(i);
@ -179,18 +176,7 @@ namespace storm {
}
}
// TODO move this into a print scheduler function.
//STORM_LOG_TRACE("the scheduler: ");
for (uint64_t obs = 0; obs < scheduler.size(); ++obs) {
if (observations.get(obs)) {
//STORM_LOG_TRACE("observation: " << obs);
//std::cout << "actions:";
//for (auto act : scheduler[obs]) {
// std::cout << " " << act;
//}
//std::cout << std::endl;
}
}
return true;

Loading…
Cancel
Save