From 9f52d9fa97cab6f5148987ac274d263e16e16fbb Mon Sep 17 00:00:00 2001 From: dehnert Date: Wed, 30 Mar 2016 21:30:02 +0200 Subject: [PATCH] first working version (for DTMCs only) Former-commit-id: d3c789596e92d09c6752081b95bbf075f560c2f4 --- .../SparseMdpLearningModelChecker.cpp | 67 ++++++++----------- .../SparseMdpLearningModelChecker.h | 2 +- 2 files changed, 29 insertions(+), 40 deletions(-) diff --git a/src/modelchecker/reachability/SparseMdpLearningModelChecker.cpp b/src/modelchecker/reachability/SparseMdpLearningModelChecker.cpp index d01ed5156..46cf53011 100644 --- a/src/modelchecker/reachability/SparseMdpLearningModelChecker.cpp +++ b/src/modelchecker/reachability/SparseMdpLearningModelChecker.cpp @@ -61,7 +61,6 @@ namespace storm { lowerBoundsPerAction[rowGroupIndices[stateToRowGroupMapping[sourceStateId]] + action] = newLowerValue; STORM_LOG_TRACE("Updating lower value of action " << action << " of state " << sourceStateId << " to " << newLowerValue << "."); upperBoundsPerAction[rowGroupIndices[stateToRowGroupMapping[sourceStateId]] + action] = newUpperValue; - std::cout << "writing " << newUpperValue << " at index " << (rowGroupIndices[stateToRowGroupMapping[sourceStateId]] + action) << std::endl; STORM_LOG_TRACE("Updating upper value of action " << action << " of state " << sourceStateId << " to " << newUpperValue << "."); // Check if we need to update the values for the states. @@ -78,17 +77,10 @@ namespace storm { template void SparseMdpLearningModelChecker::updateProbabilitiesUsingStack(std::vector>& stateActionStack, std::vector>> const& transitionMatrix, std::vector const& rowGroupIndices, std::vector const& stateToRowGroupMapping, std::vector& lowerBoundsPerAction, std::vector& upperBoundsPerAction, std::vector& lowerBoundsPerState, std::vector& upperBoundsPerState, StateType const& unexploredMarker) const { - std::cout << "stack:" << std::endl; - for (auto const& entry : stateActionStack) { - std::cout << entry.first << " -[" << entry.second << "]-> "; - } - std::cout << std::endl; - while (stateActionStack.size() > 1) { stateActionStack.pop_back(); updateProbabilities(stateActionStack.back().first, stateActionStack.back().second, transitionMatrix, rowGroupIndices, stateToRowGroupMapping, lowerBoundsPerAction, upperBoundsPerAction, lowerBoundsPerState, upperBoundsPerState, unexploredMarker); - } } @@ -102,30 +94,31 @@ namespace storm { std::vector allMaxActions; // Determine the maximal value of any action. -// ValueType max = 0; -// for (uint32_t row = rowGroupIndices[rowGroup]; row < rowGroupIndices[rowGroup + 1]; ++row) { -// ValueType current = 0; -// for (auto const& element : transitionMatrix[row]) { -// current += element.getValue() * upperBoundsPerState[stateToRowGroupMapping[element.getColumn()]]; -// } -// -// max = std::max(max, current); -// } - - STORM_LOG_TRACE("Looking for action with value " << upperBoundsPerState[stateToRowGroupMapping[currentStateId]] << "."); + ValueType max = 0; + for (uint32_t row = rowGroupIndices[rowGroup]; row < rowGroupIndices[rowGroup + 1]; ++row) { + ValueType current = 0; + for (auto const& element : transitionMatrix[row]) { + current += element.getValue() * (stateToRowGroupMapping[element.getColumn()] == unexploredMarker ? storm::utility::one() : upperBoundsPerState[stateToRowGroupMapping[element.getColumn()]]); + } + + max = std::max(max, current); + } + +// STORM_LOG_TRACE("Looking for action with value " << upperBoundsPerState[stateToRowGroupMapping[currentStateId]] << "."); + STORM_LOG_TRACE("Looking for action with value " << max << "."); + for (uint32_t row = rowGroupIndices[rowGroup]; row < rowGroupIndices[rowGroup + 1]; ++row) { ValueType current = 0; for (auto const& element : transitionMatrix[row]) { - std::cout << "+= " << element.getValue() << " * " << (stateToRowGroupMapping[element.getColumn()] == unexploredMarker ? storm::utility::one() : upperBoundsPerState[stateToRowGroupMapping[element.getColumn()]]) << " (col: " << element.getColumn() << " // row (grp) " << stateToRowGroupMapping[element.getColumn()] << ")" << std::endl; current += element.getValue() * (stateToRowGroupMapping[element.getColumn()] == unexploredMarker ? storm::utility::one() : upperBoundsPerState[stateToRowGroupMapping[element.getColumn()]]); } STORM_LOG_TRACE("Computed (upper) bound " << current << " for row " << row << "."); // If the action is one of the maximizing ones, insert it into our list. // TODO: should this need to be an approximate check? - if (current == upperBoundsPerState[stateToRowGroupMapping[currentStateId]]) { +// if (current == upperBoundsPerState[stateToRowGroupMapping[currentStateId]]) { + if (current == max) { allMaxActions.push_back(row); - std::cout << "found maximizing action " << row << std::endl; } } @@ -146,7 +139,9 @@ namespace storm { // Now sample according to the probabilities. std::discrete_distribution distribution(probabilities.begin(), probabilities.end()); - return transitionMatrix[row][distribution(generator)].getColumn(); + StateType offset = distribution(generator); + STORM_LOG_TRACE("Sampled " << offset << " from " << probabilities.size() << " elements."); + return transitionMatrix[row][offset].getColumn(); } template @@ -295,12 +290,12 @@ namespace storm { // its behavior. if (!foundTargetState) { // Next, we insert the behavior into our matrix structure. - matrix.resize(matrix.size() + behavior.getNumberOfChoices()); + StateType startRow = matrix.size(); + matrix.resize(startRow + behavior.getNumberOfChoices()); uint32_t currentAction = 0; for (auto const& choice : behavior) { for (auto const& entry : choice) { - std::cout << "got " << currentStateId << " (row group " << stateToRowGroupMapping[currentStateId] << ") " << " -> " << entry.first << " with prob " << entry.second << std::endl; - matrix.back().emplace_back(entry.first, entry.second); + matrix[startRow + currentAction].emplace_back(entry.first, entry.second); } lowerBoundsPerAction.push_back(storm::utility::zero()); @@ -324,6 +319,9 @@ namespace storm { // we need to determine this now. if (matrix[rowGroupIndices[stateToRowGroupMapping[currentStateId]]].empty()) { foundTargetState = true; + + // Update the bounds along the path to the terminal state. + updateProbabilitiesUsingStack(stateActionStack, matrix, rowGroupIndices, stateToRowGroupMapping, lowerBoundsPerAction, upperBoundsPerAction, lowerBoundsPerState, upperBoundsPerState, unexploredMarker); } } @@ -342,20 +340,11 @@ namespace storm { } } - for (auto const& el : lowerBoundsPerState) { - std::cout << el << " - "; - } - std::cout << std::endl; - - for (auto const& el : upperBoundsPerState) { - std::cout << el << " - "; - } - std::cout << std::endl; - - STORM_LOG_TRACE("Lower bound is " << lowerBoundsPerState[0] << "."); - STORM_LOG_TRACE("Upper bound is " << upperBoundsPerState[0] << "."); + STORM_LOG_DEBUG("Discovered states: " << stateStorage.numberOfStates << " (" << unexploredStates.size() << " unexplored)."); + STORM_LOG_DEBUG("Lower bound is " << lowerBoundsPerState[0] << "."); + STORM_LOG_DEBUG("Upper bound is " << upperBoundsPerState[0] << "."); ValueType difference = upperBoundsPerState[0] - lowerBoundsPerState[0]; - STORM_LOG_TRACE("Difference after iteration " << iteration << " is " << difference << "."); + STORM_LOG_DEBUG("Difference after iteration " << iteration << " is " << difference << "."); convergenceCriterionMet = difference < 1e-6; ++iteration; diff --git a/src/modelchecker/reachability/SparseMdpLearningModelChecker.h b/src/modelchecker/reachability/SparseMdpLearningModelChecker.h index 90e72d85a..1d91a5d14 100644 --- a/src/modelchecker/reachability/SparseMdpLearningModelChecker.h +++ b/src/modelchecker/reachability/SparseMdpLearningModelChecker.h @@ -34,7 +34,7 @@ namespace storm { uint32_t sampleFromMaxActions(StateType currentStateId, std::vector>> const& transitionMatrix, std::vector const& rowGroupIndices, std::vector const& stateToRowGroupMapping, std::vector& upperBounds, StateType const& unexploredMarker); StateType sampleSuccessorFromAction(StateType currentStateId, std::vector>> const& transitionMatrix, std::vector const& rowGroupIndices, std::vector const& stateToRowGroupMapping); - + // The program that defines the model to check. storm::prism::Program program;