Browse Source
			
			
			make everything compile again, add/fix method for memless strategy search (CCD16) and towards iterative search
			
			
				main
			
			
		
		make everything compile again, add/fix method for memless strategy search (CCD16) and towards iterative search
	
		
	
			
			
				main
			
			
		
				 7 changed files with 450 additions and 73 deletions
			
			
		- 
					9src/storm-pomdp-cli/settings/modules/POMDPSettings.cpp
- 
					1src/storm-pomdp-cli/settings/modules/POMDPSettings.h
- 
					46src/storm-pomdp-cli/storm-pomdp.cpp
- 
					149src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp
- 
					52src/storm-pomdp/analysis/MemlessStrategySearchQualitative.h
- 
					186src/storm-pomdp/analysis/QualitativeStrategySearchNaive.cpp
- 
					74src/storm-pomdp/analysis/QualitativeStrategySearchNaive.h
| @ -0,0 +1,186 @@ | |||||
|  | 
 | ||||
|  | 
 | ||||
|  | #include "storm-pomdp/analysis/QualitativeStrategySearchNaive.h"
 | ||||
|  | 
 | ||||
|  | 
 | ||||
|  | namespace storm { | ||||
|  |     namespace pomdp { | ||||
|  | 
 | ||||
|  |         template <typename ValueType> | ||||
|  |         void QualitativeStrategySearchNaive<ValueType>::initialize(uint64_t k) { | ||||
|  |             if (maxK == std::numeric_limits<uint64_t>::max()) { | ||||
|  |                 // not initialized at all.
 | ||||
|  |                 // Create some data structures.
 | ||||
|  |                 for(uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) { | ||||
|  |                     actionSelectionVars.push_back(std::vector<storm::expressions::Variable>()); | ||||
|  |                     actionSelectionVarExpressions.push_back(std::vector<storm::expressions::Expression>()); | ||||
|  |                     statesPerObservation.push_back(std::vector<uint64_t>()); // Consider using bitvectors instead.
 | ||||
|  |                 } | ||||
|  | 
 | ||||
|  |                 // Fill the states-per-observation mapping,
 | ||||
|  |                 // declare the reachability variables,
 | ||||
|  |                 // declare the path variables.
 | ||||
|  |                 uint64_t stateId = 0; | ||||
|  |                 for(auto obs : pomdp.getObservations()) { | ||||
|  |                     pathVars.push_back(std::vector<storm::expressions::Expression>()); | ||||
|  |                     for (uint64_t i = 0; i < k; ++i) { | ||||
|  |                         pathVars.back().push_back(expressionManager->declareBooleanVariable("P-"+std::to_string(stateId)+"-"+std::to_string(i)).getExpression()); | ||||
|  |                     } | ||||
|  |                     reachVars.push_back(expressionManager->declareBooleanVariable("C-" + std::to_string(stateId))); | ||||
|  |                     reachVarExpressions.push_back(reachVars.back().getExpression()); | ||||
|  |                     statesPerObservation.at(obs).push_back(stateId++); | ||||
|  |                 } | ||||
|  |                 assert(pathVars.size() == pomdp.getNumberOfStates()); | ||||
|  | 
 | ||||
|  |                 // Create the action selection variables.
 | ||||
|  |                 uint64_t obs = 0; | ||||
|  |                 for(auto const& statesForObservation : statesPerObservation) { | ||||
|  |                     for (uint64_t a = 0; a < pomdp.getNumberOfChoices(statesForObservation.front()); ++a) { | ||||
|  |                         std::string varName = "A-" + std::to_string(obs) + "-" + std::to_string(a); | ||||
|  |                         actionSelectionVars.at(obs).push_back(expressionManager->declareBooleanVariable(varName)); | ||||
|  |                         actionSelectionVarExpressions.at(obs).push_back(actionSelectionVars.at(obs).back().getExpression()); | ||||
|  |                     } | ||||
|  |                     ++obs; | ||||
|  |                 } | ||||
|  |             } else { | ||||
|  |                 assert(false); | ||||
|  |             } | ||||
|  | 
 | ||||
|  |             uint64_t rowindex = 0; | ||||
|  |             for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) { | ||||
|  |                 if (targetStates.get(state)) { | ||||
|  |                     smtSolver->add(pathVars[state][0]); | ||||
|  |                 } else { | ||||
|  |                     smtSolver->add(!pathVars[state][0]); | ||||
|  |                 } | ||||
|  | 
 | ||||
|  |                 if (surelyReachSinkStates.get(state)) { | ||||
|  |                     smtSolver->add(!reachVarExpressions[state]); | ||||
|  |                 } else if(!targetStates.get(state)) { | ||||
|  |                     std::vector<std::vector<std::vector<storm::expressions::Expression>>> pathsubsubexprs; | ||||
|  |                     for (uint64_t j = 1; j < k; ++j) { | ||||
|  |                         pathsubsubexprs.push_back(std::vector<std::vector<storm::expressions::Expression>>()); | ||||
|  |                         for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) { | ||||
|  |                             pathsubsubexprs.back().push_back(std::vector<storm::expressions::Expression>()); | ||||
|  |                         } | ||||
|  |                     } | ||||
|  | 
 | ||||
|  |                     for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) { | ||||
|  |                         std::vector<storm::expressions::Expression> subexprreach; | ||||
|  | 
 | ||||
|  |                         subexprreach.push_back(!reachVarExpressions.at(state)); | ||||
|  |                         subexprreach.push_back(!actionSelectionVarExpressions.at(pomdp.getObservation(state)).at(action)); | ||||
|  |                         for (auto const &entries : pomdp.getTransitionMatrix().getRow(rowindex)) { | ||||
|  |                             subexprreach.push_back(reachVarExpressions.at(entries.getColumn())); | ||||
|  |                         } | ||||
|  |                         smtSolver->add(storm::expressions::disjunction(subexprreach)); | ||||
|  |                         for (auto const &entries : pomdp.getTransitionMatrix().getRow(rowindex)) { | ||||
|  |                             for (uint64_t j = 1; j < k; ++j) { | ||||
|  |                                 pathsubsubexprs[j - 1][action].push_back(pathVars[entries.getColumn()][j - 1]); | ||||
|  |                             } | ||||
|  |                         } | ||||
|  |                         rowindex++; | ||||
|  |                     } | ||||
|  |                     smtSolver->add(storm::expressions::implies(reachVarExpressions.at(state), pathVars.at(state).back())); | ||||
|  | 
 | ||||
|  |                     for (uint64_t j = 1; j < k; ++j) { | ||||
|  |                         std::vector<storm::expressions::Expression> pathsubexprs; | ||||
|  | 
 | ||||
|  |                         for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) { | ||||
|  |                             pathsubexprs.push_back(actionSelectionVarExpressions.at(pomdp.getObservation(state)).at(action) && storm::expressions::disjunction(pathsubsubexprs[j - 1][action])); | ||||
|  |                         } | ||||
|  |                         smtSolver->add(storm::expressions::iff(pathVars[state][j], storm::expressions::disjunction(pathsubexprs))); | ||||
|  |                     } | ||||
|  |                 } | ||||
|  |             } | ||||
|  | 
 | ||||
|  |             for (auto const& actionVars : actionSelectionVarExpressions) { | ||||
|  |                 smtSolver->add(storm::expressions::disjunction(actionVars)); | ||||
|  |             } | ||||
|  |         } | ||||
|  | 
 | ||||
|  |         template <typename ValueType> | ||||
|  |         bool QualitativeStrategySearchNaive<ValueType>::analyze(uint64_t k, storm::storage::BitVector const& oneOfTheseStates, storm::storage::BitVector const& allOfTheseStates) { | ||||
|  |             if (k < maxK) { | ||||
|  |                 initialize(k); | ||||
|  |             } | ||||
|  | 
 | ||||
|  |             std::vector<storm::expressions::Expression> atLeastOneOfStates; | ||||
|  | 
 | ||||
|  |             for(uint64_t state : oneOfTheseStates) { | ||||
|  |                 atLeastOneOfStates.push_back(reachVarExpressions[state]); | ||||
|  |             } | ||||
|  |             assert(atLeastOneOfStates.size() > 0); | ||||
|  |             smtSolver->add(storm::expressions::disjunction(atLeastOneOfStates)); | ||||
|  | 
 | ||||
|  |             for(uint64_t state : allOfTheseStates) { | ||||
|  |                 smtSolver->add(reachVarExpressions[state]); | ||||
|  |             } | ||||
|  | 
 | ||||
|  | 
 | ||||
|  | 
 | ||||
|  |             std::cout << smtSolver->getSmtLibString() << std::endl; | ||||
|  | 
 | ||||
|  |             auto result = smtSolver->check(); | ||||
|  |             uint64_t  i = 0; | ||||
|  |             smtSolver->push(); | ||||
|  | 
 | ||||
|  | 
 | ||||
|  | 
 | ||||
|  |             if (result == storm::solver::SmtSolver::CheckResult::Unknown) { | ||||
|  |                 STORM_LOG_THROW(false, storm::exceptions::UnexpectedException, "SMT solver yielded an unexpected result"); | ||||
|  |             } else if(result == storm::solver::SmtSolver::CheckResult::Unsat) { | ||||
|  |                 std::cout << std::endl << "Unsatisfiable!" << std::endl; | ||||
|  |                 return false; | ||||
|  |             } else { | ||||
|  | 
 | ||||
|  |                 std::cout << std::endl << "Satisfying assignment: " << std::endl << smtSolver->getModelAsValuation().toString(true) << std::endl; | ||||
|  |                 auto model = smtSolver->getModel(); | ||||
|  |                 std::cout << "states that are okay" << std::endl; | ||||
|  |                 storm::storage::BitVector observations(pomdp.getNrObservations()); | ||||
|  |                 storm::storage::BitVector remainingstates(pomdp.getNumberOfStates()); | ||||
|  |                 for (auto rv : reachVars) { | ||||
|  |                     if (model->getBooleanValue(rv)) { | ||||
|  |                         std::cout << i << " " << std::endl; | ||||
|  |                         observations.set(pomdp.getObservation(i)); | ||||
|  |                     } else { | ||||
|  |                         remainingstates.set(i); | ||||
|  |                     } | ||||
|  |                     //std::cout << i << ": " << model->getBooleanValue(rv) << ", ";
 | ||||
|  |                     ++i; | ||||
|  |                 } | ||||
|  |                 std::vector <std::set<uint64_t>> scheduler; | ||||
|  |                 for (auto const &actionSelectionVarsForObs : actionSelectionVars) { | ||||
|  |                     uint64_t act = 0; | ||||
|  |                     scheduler.push_back(std::set<uint64_t>()); | ||||
|  |                     for (auto const &asv : actionSelectionVarsForObs) { | ||||
|  |                         if (model->getBooleanValue(asv)) { | ||||
|  |                             scheduler.back().insert(act); | ||||
|  |                         } | ||||
|  |                         act++; | ||||
|  |                     } | ||||
|  |                 } | ||||
|  |                 std::cout << "the scheduler: " << std::endl; | ||||
|  |                 for (uint64_t obs = 0; obs < scheduler.size(); ++obs) { | ||||
|  |                     if (observations.get(obs)) { | ||||
|  |                         std::cout << "observation: " << obs << std::endl; | ||||
|  |                         std::cout << "actions:"; | ||||
|  |                         for (auto act : scheduler[obs]) { | ||||
|  |                             std::cout << " " << act; | ||||
|  |                         } | ||||
|  |                         std::cout << std::endl; | ||||
|  |                     } | ||||
|  |                 } | ||||
|  | 
 | ||||
|  | 
 | ||||
|  |                 return true; | ||||
|  |             } | ||||
|  | 
 | ||||
|  | 
 | ||||
|  | 
 | ||||
|  |         } | ||||
|  | 
 | ||||
|  |         template class QualitativeStrategySearchNaive<double>; | ||||
|  |         template class QualitativeStrategySearchNaive<storm::RationalNumber>; | ||||
|  |     } | ||||
|  | } | ||||
| @ -0,0 +1,74 @@ | |||||
|  | #include <vector> | ||||
|  | #include "storm/storage/expressions/Expressions.h" | ||||
|  | #include "storm/solver/SmtSolver.h" | ||||
|  | #include "storm/models/sparse/Pomdp.h" | ||||
|  | #include "storm/utility/solver.h" | ||||
|  | #include "storm/exceptions/UnexpectedException.h" | ||||
|  | 
 | ||||
|  | namespace storm { | ||||
|  |     namespace pomdp { | ||||
|  | 
 | ||||
|  |         template<typename ValueType> | ||||
|  |         class QualitativeStrategySearchNaive { | ||||
|  |             // Implements an extension to the Chatterjee, Chmelik, Davies (AAAI-16) paper. | ||||
|  | 
 | ||||
|  | 
 | ||||
|  |         public: | ||||
|  |             QualitativeStrategySearchNaive(storm::models::sparse::Pomdp<ValueType> const& pomdp, | ||||
|  |                                              std::set<uint32_t> const& targetObservationSet, | ||||
|  |                                              storm::storage::BitVector const& targetStates, | ||||
|  |                                              storm::storage::BitVector const& surelyReachSinkStates, | ||||
|  |                                              std::shared_ptr<storm::utility::solver::SmtSolverFactory>& smtSolverFactory) : | ||||
|  |                     pomdp(pomdp), | ||||
|  |                     targetStates(targetStates), | ||||
|  |                     surelyReachSinkStates(surelyReachSinkStates), | ||||
|  |                     targetObservations(targetObservationSet) { | ||||
|  |                 this->expressionManager = std::make_shared<storm::expressions::ExpressionManager>(); | ||||
|  |                 smtSolver = smtSolverFactory->create(*expressionManager); | ||||
|  | 
 | ||||
|  |             } | ||||
|  | 
 | ||||
|  |             void setSurelyReachSinkStates(storm::storage::BitVector const& surelyReachSink) { | ||||
|  |                 surelyReachSinkStates = surelyReachSink; | ||||
|  |             } | ||||
|  | 
 | ||||
|  |             void analyzeForInitialStates(uint64_t k) { | ||||
|  |                 analyze(k, pomdp.getInitialStates(), pomdp.getInitialStates()); | ||||
|  |             } | ||||
|  | 
 | ||||
|  |             void findNewStrategyForSomeState(uint64_t k) { | ||||
|  |                 std::cout << surelyReachSinkStates << std::endl; | ||||
|  |                 std::cout << targetStates << std::endl; | ||||
|  |                 std::cout << (~surelyReachSinkStates & ~targetStates) << std::endl; | ||||
|  |                 analyze(k, ~surelyReachSinkStates & ~targetStates); | ||||
|  | 
 | ||||
|  | 
 | ||||
|  |             } | ||||
|  | 
 | ||||
|  |             bool analyze(uint64_t k, storm::storage::BitVector const& oneOfTheseStates, storm::storage::BitVector const& allOfTheseStates = storm::storage::BitVector()); | ||||
|  | 
 | ||||
|  |         private: | ||||
|  |             void initialize(uint64_t k); | ||||
|  | 
 | ||||
|  | 
 | ||||
|  |             std::unique_ptr<storm::solver::SmtSolver> smtSolver; | ||||
|  |             storm::models::sparse::Pomdp<ValueType> const& pomdp; | ||||
|  |             std::shared_ptr<storm::expressions::ExpressionManager> expressionManager; | ||||
|  |             uint64_t maxK = std::numeric_limits<uint64_t>::max(); | ||||
|  | 
 | ||||
|  |             std::set<uint32_t> targetObservations; | ||||
|  |             storm::storage::BitVector targetStates; | ||||
|  |             storm::storage::BitVector surelyReachSinkStates; | ||||
|  | 
 | ||||
|  |             std::vector<std::vector<uint64_t>> statesPerObservation; | ||||
|  |             std::vector<std::vector<storm::expressions::Expression>> actionSelectionVarExpressions; // A_{z,a} | ||||
|  |             std::vector<std::vector<storm::expressions::Variable>> actionSelectionVars; | ||||
|  |             std::vector<storm::expressions::Variable> reachVars; | ||||
|  |             std::vector<storm::expressions::Expression> reachVarExpressions; | ||||
|  |             std::vector<std::vector<storm::expressions::Expression>> pathVars; | ||||
|  | 
 | ||||
|  | 
 | ||||
|  | 
 | ||||
|  |         }; | ||||
|  |     } | ||||
|  | } | ||||
						Write
						Preview
					
					
					Loading…
					
					Cancel
						Save
					
		Reference in new issue