diff --git a/src/generator/StateBehavior.cpp b/src/generator/StateBehavior.cpp index e14417779..6ae1f6de0 100644 --- a/src/generator/StateBehavior.cpp +++ b/src/generator/StateBehavior.cpp @@ -50,6 +50,11 @@ namespace storm { return stateRewards; } + template + std::size_t StateBehavior::getNumberOfChoices() const { + return choices.size(); + } + template class StateBehavior; template class StateBehavior; diff --git a/src/generator/StateBehavior.h b/src/generator/StateBehavior.h index 393c3280c..8f81c8b30 100644 --- a/src/generator/StateBehavior.h +++ b/src/generator/StateBehavior.h @@ -56,6 +56,11 @@ namespace storm { */ std::vector const& getStateRewards() const; + /*! + * Retrieves the number of choices in the behavior. + */ + std::size_t getNumberOfChoices() const; + private: // The choices available in the state. std::vector> choices; diff --git a/src/modelchecker/AbstractModelChecker.cpp b/src/modelchecker/AbstractModelChecker.cpp index 0ad70fd4b..850cae65d 100644 --- a/src/modelchecker/AbstractModelChecker.cpp +++ b/src/modelchecker/AbstractModelChecker.cpp @@ -13,7 +13,7 @@ namespace storm { namespace modelchecker { std::unique_ptr AbstractModelChecker::check(CheckTask const& checkTask) { storm::logic::Formula const& formula = checkTask.getFormula(); - STORM_LOG_THROW(this->canHandle(formula), storm::exceptions::InvalidArgumentException, "The model checker is not able to check the formula '" << formula << "'."); + STORM_LOG_THROW(this->canHandle(checkTask), storm::exceptions::InvalidArgumentException, "The model checker is not able to check the formula '" << formula << "'."); if (formula.isStateFormula()) { return this->checkStateFormula(checkTask.substituteFormula(formula.asStateFormula())); } diff --git a/src/modelchecker/reachability/SparseMdpLearningModelChecker.cpp b/src/modelchecker/reachability/SparseMdpLearningModelChecker.cpp index 9771e0ae1..f96308e05 100644 --- a/src/modelchecker/reachability/SparseMdpLearningModelChecker.cpp +++ b/src/modelchecker/reachability/SparseMdpLearningModelChecker.cpp @@ -47,6 +47,13 @@ namespace storm { upperBounds[stateToRowGroupMapping[sourceStateId]] = newUpperValue; } + template + void SparseMdpLearningModelChecker::updateProbabilitiesUsingStack(std::vector>& stateActionStack, StateType const& currentStateId, std::vector>> const& transitionMatrix, std::vector const& rowGroupIndices, std::vector const& stateToRowGroupMapping, std::vector& lowerBounds, std::vector& upperBounds) const { + while (!stateActionStack.empty()) { + updateProbabilities(stateActionStack.back().first, stateActionStack.back().second, currentStateId, transitionMatrix, rowGroupIndices, stateToRowGroupMapping, lowerBounds, upperBounds); + } + } + template std::unique_ptr SparseMdpLearningModelChecker::computeReachabilityProbabilities(CheckTask const& checkTask) { storm::logic::EventuallyFormula const& eventuallyFormula = checkTask.getFormula(); @@ -62,9 +69,6 @@ namespace storm { // A container for the encountered states. storm::storage::sparse::StateStorage stateStorage(variableInformation.getTotalBitOffset(true)); - // A container that stores the states that were already expanded. - storm::storage::BitVector expandedStates; - // A generator used to explore the model. storm::generator::PrismNextStateGenerator generator(program, variableInformation, false); @@ -81,8 +85,11 @@ namespace storm { std::vector lowerBounds; std::vector upperBounds; + // A mapping of unexplored IDs to their actual compressed states. + std::unordered_map unexploredStates; + // Create a callback for the next-state generator to enable it to request the index of states. - std::function stateToIdCallback = [&stateStorage] (storm::generator::CompressedState const& state) -> StateType { + std::function stateToIdCallback = [&stateStorage, &stateToRowGroupMapping, &unexploredStates] (storm::generator::CompressedState const& state) -> StateType { StateType newIndex = stateStorage.numberOfStates; // Check, if the state was already registered. @@ -90,6 +97,8 @@ namespace storm { if (actualIndexBucketPair.first == newIndex) { ++stateStorage.numberOfStates; + stateToRowGroupMapping.push_back(0); + unexploredStates[newIndex] = state; } return actualIndexBucketPair.first; @@ -99,38 +108,43 @@ namespace storm { STORM_LOG_THROW(stateStorage.initialStateIndices.size() == 1, storm::exceptions::NotSupportedException, "Currently only models with one initial state are supported by the learning engine."); // Now perform the actual sampling. - std::unordered_map unexploredStates; std::vector> stateActionStack; stateActionStack.push_back(std::make_pair(stateStorage.initialStateIndices.front(), 0)); bool foundTargetState = false; while (!foundTargetState) { StateType const& currentStateId = stateActionStack.back().first; + STORM_LOG_TRACE("State on top of stack is: " << currentStateId << "."); // If the state is not yet expanded, we need to retrieve its behaviors. - if (!expandedStates.get(currentStateId)) { + auto unexploredIt = unexploredStates.find(currentStateId); + if (unexploredIt != unexploredStates.end()) { + STORM_LOG_TRACE("State was not yet expanded."); + // First, we need to get the compressed state back from the id. - auto it = unexploredStates.find(currentStateId); - STORM_LOG_ASSERT(it != unexploredStates.end(), "Unable to find unexplored state."); - storm::storage::BitVector currentState = it->second; + STORM_LOG_ASSERT(unexploredIt != unexploredStates.end(), "Unable to find unexplored state " << currentStateId << "."); + storm::storage::BitVector const& currentState = unexploredIt->second; // Before generating the behavior of the state, we need to determine whether it's a target state that // does not need to be expanded. generator.load(currentState); if (generator.satisfies(targetStateExpression)) { + STORM_LOG_TRACE("State does not need to be expanded, because it is a target state."); + // If it's in fact a goal state, we need to go backwards in the stack and update the probabilities. foundTargetState = true; stateActionStack.pop_back(); - while (!stateActionStack.empty()) { - updateProbabilities(stateActionStack.back().first, stateActionStack.back().second, currentStateId, matrix, rowGroupIndices, stateToRowGroupMapping, lowerBounds, upperBounds); - } - break; + STORM_LOG_TRACE("Updating probabilities along states in stack."); + updateProbabilitiesUsingStack(stateActionStack, currentStateId, matrix, rowGroupIndices, stateToRowGroupMapping, lowerBounds, upperBounds); } else { + STORM_LOG_TRACE("Expanding state."); + // If it needs to be expanded, we use the generator to retrieve the behavior of the new state. storm::generator::StateBehavior behavior = generator.expand(stateToIdCallback); + STORM_LOG_TRACE("State has " << behavior.getNumberOfChoices() << " choices."); - stateToRowGroupMapping.push_back(rowGroupIndices.size()); + stateToRowGroupMapping[currentStateId] = rowGroupIndices.size(); rowGroupIndices.push_back(matrix.size()); // Next, we insert the behavior into our matrix structure. @@ -142,13 +156,23 @@ namespace storm { } // Now that we have explored the state, we can dispose of it. - unexploredStates.erase(it); + unexploredStates.erase(unexploredIt); } } - - // At this point, we can be sure that the state was expanded and that we can sample according to the probabilities. - // TODO: set action of topmost stack element + if (!foundTargetState) { + // At this point, we can be sure that the state was expanded and that we can sample according to the + // probabilities in the matrix. + STORM_LOG_TRACE("Sampling action in state."); + uint32_t chosenAction = 0; + + STORM_LOG_TRACE("Sampling successor state according to action " << chosenAction << "."); + break; + + // TODO: set action of topmost stack element + // TOOD: determine if end component (state) + + } } return nullptr; diff --git a/src/modelchecker/reachability/SparseMdpLearningModelChecker.h b/src/modelchecker/reachability/SparseMdpLearningModelChecker.h index b181a44f6..f63ce40a4 100644 --- a/src/modelchecker/reachability/SparseMdpLearningModelChecker.h +++ b/src/modelchecker/reachability/SparseMdpLearningModelChecker.h @@ -27,6 +27,8 @@ namespace storm { private: void updateProbabilities(StateType const& sourceStateId, uint32_t action, StateType const& targetStateId, std::vector>> const& transitionMatrix, std::vector const& rowGroupIndices, std::vector const& stateToRowGroupMapping, std::vector& lowerBounds, std::vector& upperBounds) const; + void updateProbabilitiesUsingStack(std::vector>& stateActionStack, StateType const& currentStateId, std::vector>> const& transitionMatrix, std::vector const& rowGroupIndices, std::vector const& stateToRowGroupMapping, std::vector& lowerBounds, std::vector& upperBounds) const; + // The program that defines the model to check. storm::prism::Program program;