Browse Source

better performance when only looking for a winning policy

tempestpy_adaptions
Sebastian Junges 5 years ago
parent
commit
a90a82d271
  1. 31
      src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp

31
src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp

@ -92,8 +92,13 @@ namespace storm {
} else { } else {
lookaheadConstraintsRequired = qualitative::isLookaheadRequired(pomdp, targetStates, surelyReachSinkStates); lookaheadConstraintsRequired = qualitative::isLookaheadRequired(pomdp, targetStates, surelyReachSinkStates);
} }
if (options.pathVariableType == MemlessSearchPathVariables::RealRanking) {
k = 10; //magic constant, consider moving.
}
if (actionSelectionVars.empty()) { if (actionSelectionVars.empty()) {
for (uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) { for (uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) {
actionSelectionVars.push_back(std::vector<storm::expressions::Variable>()); actionSelectionVars.push_back(std::vector<storm::expressions::Variable>());
actionSelectionVarExpressions.push_back(std::vector<storm::expressions::Expression>()); actionSelectionVarExpressions.push_back(std::vector<storm::expressions::Expression>());
@ -537,6 +542,8 @@ namespace storm {
STORM_LOG_INFO("Start iterative solver..."); STORM_LOG_INFO("Start iterative solver...");
uint64_t iterations = 0; uint64_t iterations = 0;
bool foundWhatWeLookFor = false;
while(true) { while(true) {
stats.incrementOuterIterations(); stats.incrementOuterIterations();
// TODO consider what we really want to store about the schedulers. // TODO consider what we really want to store about the schedulers.
@ -546,20 +553,23 @@ namespace storm {
coveredStates = targetStates; coveredStates = targetStates;
coveredStatesAfterSwitch.clear(); coveredStatesAfterSwitch.clear();
observationUpdated.clear(); observationUpdated.clear();
bool newSchedulerDiscovered = false;
if (!allOfTheseAssumption.empty()) { if (!allOfTheseAssumption.empty()) {
bool foundResult = this->smtCheck(iterations, allOfTheseAssumption); bool foundResult = this->smtCheck(iterations, allOfTheseAssumption);
if (foundResult) { if (foundResult) {
// Consider storing the scheduler // Consider storing the scheduler
return true;
foundWhatWeLookFor = true;
} }
} }
bool newSchedulerDiscovered = false;
uint64_t localIterations = 0;
while (true) { while (true) {
++iterations; ++iterations;
++localIterations;
bool foundScheduler = this->smtCheck(iterations);
bool foundScheduler = foundWhatWeLookFor;
if (!foundScheduler) {
foundScheduler = this->smtCheck(iterations);
}
if (!foundScheduler) { if (!foundScheduler) {
break; break;
} }
@ -567,6 +577,7 @@ namespace storm {
stats.evaluateExtensionSolverTime.start(); stats.evaluateExtensionSolverTime.start();
auto const& model = smtSolver->getModel(); auto const& model = smtSolver->getModel();
newObservationsAfterSwitch.clear(); newObservationsAfterSwitch.clear();
newObservations.clear(); newObservations.clear();
@ -655,6 +666,11 @@ namespace storm {
scheduler.printForObservations(observations,observationsAfterSwitch); scheduler.printForObservations(observations,observationsAfterSwitch);
} }
if (foundWhatWeLookFor || (options.localIterationMaximum > 0 && (localIterations % (options.localIterationMaximum + 1) == options.localIterationMaximum))) {
stats.encodeExtensionSolverTime.stop();
break;
}
std::vector<storm::expressions::Expression> remainingExpressions; std::vector<storm::expressions::Expression> remainingExpressions;
for (auto index : ~coveredStates) { for (auto index : ~coveredStates) {
if (observationUpdated.get(pomdp.getObservation(index))) { if (observationUpdated.get(pomdp.getObservation(index))) {
@ -672,10 +688,12 @@ namespace storm {
} }
smtSolver->add(storm::expressions::disjunction(remainingExpressions)); smtSolver->add(storm::expressions::disjunction(remainingExpressions));
stats.encodeExtensionSolverTime.stop(); stats.encodeExtensionSolverTime.stop();
//smtSolver->setTimeout(options.extensionCallTimeout);
} }
if (!newSchedulerDiscovered) { if (!newSchedulerDiscovered) {
break; break;
} }
//smtSolver->unsetTimeout();
smtSolver->pop(); smtSolver->pop();
if(options.computeDebugOutput()) { if(options.computeDebugOutput()) {
@ -718,6 +736,9 @@ namespace storm {
} }
} }
stats.winningRegionUpdatesTimer.stop(); stats.winningRegionUpdatesTimer.stop();
if (foundWhatWeLookFor) {
return true;
}
if (newTargetObservations>0) { if (newTargetObservations>0) {
stats.graphSearchTime.start(); stats.graphSearchTime.start();
storm::analysis::QualitativeAnalysisOnGraphs<ValueType> graphanalysis(pomdp); storm::analysis::QualitativeAnalysisOnGraphs<ValueType> graphanalysis(pomdp);

Loading…
Cancel
Save