diff --git a/src/generator/NextStateGenerator.h b/src/generator/NextStateGenerator.h index de80bc83d..7dab9f470 100644 --- a/src/generator/NextStateGenerator.h +++ b/src/generator/NextStateGenerator.h @@ -21,7 +21,7 @@ namespace storm { virtual void load(CompressedState const& state) = 0; virtual StateBehavior expand(StateToIdCallback const& stateToIdCallback) = 0; - virtual bool satisfies(storm::expressions::Expression const& expression) = 0; + virtual bool satisfies(storm::expressions::Expression const& expression) const = 0; }; } } diff --git a/src/generator/PrismNextStateGenerator.cpp b/src/generator/PrismNextStateGenerator.cpp index 83fd4b6c8..0064636b6 100644 --- a/src/generator/PrismNextStateGenerator.cpp +++ b/src/generator/PrismNextStateGenerator.cpp @@ -13,7 +13,7 @@ namespace storm { PrismNextStateGenerator::PrismNextStateGenerator(storm::prism::Program const& program, VariableInformation const& variableInformation, bool buildChoiceLabeling) : program(program), selectedRewardModels(), buildChoiceLabeling(buildChoiceLabeling), variableInformation(variableInformation), evaluator(program.getManager()), state(nullptr), comparator() { // Intentionally left empty. } - + template void PrismNextStateGenerator::addRewardModel(storm::prism::RewardModel const& rewardModel) { selectedRewardModels.push_back(rewardModel); @@ -58,7 +58,7 @@ namespace storm { } template - bool PrismNextStateGenerator::satisfies(storm::expressions::Expression const& expression) { + bool PrismNextStateGenerator::satisfies(storm::expressions::Expression const& expression) const { return evaluator.asBool(expression); } diff --git a/src/generator/PrismNextStateGenerator.h b/src/generator/PrismNextStateGenerator.h index 4eebf0dd3..fe997d73a 100644 --- a/src/generator/PrismNextStateGenerator.h +++ b/src/generator/PrismNextStateGenerator.h @@ -18,7 +18,7 @@ namespace storm { typedef typename NextStateGenerator::StateToIdCallback StateToIdCallback; PrismNextStateGenerator(storm::prism::Program const& program, VariableInformation const& variableInformation, bool buildChoiceLabeling); - + /*! * Adds a reward model to the list of selected reward models () */ @@ -34,7 +34,7 @@ namespace storm { virtual void load(CompressedState const& state) override; virtual StateBehavior expand(StateToIdCallback const& stateToIdCallback) override; - virtual bool satisfies(storm::expressions::Expression const& expression) override; + virtual bool satisfies(storm::expressions::Expression const& expression) const override; private: /*! diff --git a/src/modelchecker/reachability/SparseMdpLearningModelChecker.cpp b/src/modelchecker/reachability/SparseMdpLearningModelChecker.cpp index 67d3c6476..a762618f8 100644 --- a/src/modelchecker/reachability/SparseMdpLearningModelChecker.cpp +++ b/src/modelchecker/reachability/SparseMdpLearningModelChecker.cpp @@ -40,7 +40,7 @@ namespace storm { STORM_LOG_THROW(program.isDeterministicModel() || checkTask.isOptimizationDirectionSet(), storm::exceptions::InvalidPropertyException, "For nondeterministic systems, an optimization direction (min/max) must be given in the property."); STORM_LOG_THROW(subformula.isAtomicExpressionFormula() || subformula.isAtomicLabelFormula(), storm::exceptions::NotSupportedException, "Learning engine can only deal with formulas of the form 'F \"label\"' or 'F expression'."); - StateGeneration stateGeneration(storm::generator::PrismNextStateGenerator(program, variableInformation, false), getTargetStateExpression(subformula)); + StateGeneration stateGeneration(program, variableInformation, getTargetStateExpression(subformula)); ExplorationInformation explorationInformation(variableInformation.getTotalBitOffset(true)); explorationInformation.optimizationDirection = checkTask.isOptimizationDirectionSet() ? checkTask.getOptimizationDirection() : storm::OptimizationDirection::Maximize; @@ -49,7 +49,7 @@ namespace storm { explorationInformation.newRowGroup(0); // Create a callback for the next-state generator to enable it to request the index of states. - std::function stateToIdCallback = createStateToIdCallback(explorationInformation); + stateGeneration.stateToIdCallback = createStateToIdCallback(explorationInformation); // Compute and return result. std::tuple boundsForInitialState = performLearningProcedure(stateGeneration, explorationInformation); @@ -114,7 +114,7 @@ namespace storm { STORM_LOG_TRACE("Did not find terminal state."); } - STORM_LOG_DEBUG("Discovered states: " << explorationInformation.getNumberOfDiscoveredStates() << " (" << stats.numberOfExploredStates << "explored, " << explorationInformation.getNumberOfUnexploredStates() << " unexplored)."); + STORM_LOG_DEBUG("Discovered states: " << explorationInformation.getNumberOfDiscoveredStates() << " (" << stats.numberOfExploredStates << " explored, " << explorationInformation.getNumberOfUnexploredStates() << " unexplored)."); STORM_LOG_DEBUG("Value of initial state is in [" << bounds.getLowerBoundForState(initialStateIndex, explorationInformation) << ", " << bounds.getUpperBoundForState(initialStateIndex, explorationInformation) << "]."); ValueType difference = bounds.getDifferenceOfStateBounds(initialStateIndex, explorationInformation); STORM_LOG_DEBUG("Difference after iteration " << stats.iterations << " is " << difference << "."); @@ -125,7 +125,7 @@ namespace storm { if (storm::settings::generalSettings().isShowStatisticsSet()) { std::cout << std::endl << "Learning summary -------------------------" << std::endl; - std::cout << "Discovered states: " << explorationInformation.getNumberOfDiscoveredStates() << " (" << stats.numberOfExploredStates << "explored, " << explorationInformation.getNumberOfUnexploredStates() << " unexplored, " << stats.numberOfTargetStates << " target states)" << std::endl; + std::cout << "Discovered states: " << explorationInformation.getNumberOfDiscoveredStates() << " (" << stats.numberOfExploredStates << " explored, " << explorationInformation.getNumberOfUnexploredStates() << " unexplored, " << stats.numberOfTargetStates << " target states)" << std::endl; std::cout << "Sampling iterations: " << stats.iterations << std::endl; std::cout << "Maximal path length: " << stats.maxPathLength << std::endl; } @@ -165,7 +165,7 @@ namespace storm { if (!foundTerminalState) { // At this point, we can be sure that the state was expanded and that we can sample according to the // probabilities in the matrix. - uint32_t chosenAction = sampleFromMaxActions(currentStateId, explorationInformation, bounds); + uint32_t chosenAction = sampleMaxAction(currentStateId, explorationInformation, bounds); stack.back().second = chosenAction; STORM_LOG_TRACE("Sampled action " << chosenAction << " in state " << currentStateId << "."); @@ -194,10 +194,19 @@ namespace storm { template bool SparseMdpLearningModelChecker::exploreState(StateGeneration& stateGeneration, StateType const& currentStateId, storm::generator::CompressedState const& currentState, ExplorationInformation& explorationInformation, BoundValues& bounds, Statistics& stats) const { - bool isTerminalState = false; bool isTargetState = false; + ++stats.numberOfExploredStates; + + // Finally, map the unexplored state to the row group. + explorationInformation.assignStateToNextRowGroup(currentStateId); + STORM_LOG_TRACE("Assigning row group " << explorationInformation.getRowGroup(currentStateId) << " to state " << currentStateId << "."); + + // Initialize the bounds, because some of the following computations depend on the values to be available for + // all states that have been assigned to a row-group. + bounds.initializeBoundsForNextState(); + // Before generating the behavior of the state, we need to determine whether it's a target state that // does not need to be expanded. stateGeneration.generator.load(currentState); @@ -232,38 +241,41 @@ namespace storm { StateType startRow = explorationInformation.matrix.size(); explorationInformation.addRowsToMatrix(behavior.getNumberOfChoices()); - // Terminate the row group. - explorationInformation.rowGroupIndices.push_back(explorationInformation.matrix.size()); - ActionType currentAction = 0; + std::pair stateBounds(storm::utility::zero(), storm::utility::zero()); + for (auto const& choice : behavior) { for (auto const& entry : choice) { - std::cout << "adding " << currentStateId << " -> " << entry.first << " with prob " << entry.second << std::endl; - explorationInformation.matrix[startRow + currentAction].emplace_back(entry.first, entry.second); + explorationInformation.getRowOfMatrix(startRow + currentAction).emplace_back(entry.first, entry.second); } - bounds.initializeActionBoundsForNextAction(computeBoundsOfAction(startRow + currentAction, explorationInformation, bounds)); + std::pair actionBounds = computeBoundsOfAction(startRow + currentAction, explorationInformation, bounds); + bounds.initializeBoundsForNextAction(actionBounds); + stateBounds = std::make_pair(std::max(stateBounds.first, actionBounds.first), std::max(stateBounds.second, actionBounds.second)); STORM_LOG_TRACE("Initializing bounds of action " << (startRow + currentAction) << " to " << bounds.getLowerBoundForAction(startRow + currentAction) << " and " << bounds.getUpperBoundForAction(startRow + currentAction) << "."); ++currentAction; } - bounds.initializeStateBoundsForNextState(computeBoundsOfState(currentStateId, explorationInformation, bounds)); - STORM_LOG_TRACE("Initializing bounds of state " << currentStateId << " to " << bounds.getLowerBoundForState(currentStateId) << " and " << bounds.getUpperBoundForState(currentStateId) << "."); + // Terminate the row group. + explorationInformation.rowGroupIndices.push_back(explorationInformation.matrix.size()); + + bounds.setBoundsForState(currentStateId, explorationInformation, stateBounds); + STORM_LOG_TRACE("Initializing bounds of state " << currentStateId << " to " << bounds.getLowerBoundForState(currentStateId, explorationInformation) << " and " << bounds.getUpperBoundForState(currentStateId, explorationInformation) << "."); } } - + if (isTerminalState) { STORM_LOG_TRACE("State does not need to be explored, because it is " << (isTargetState ? "a target state" : "a rejecting terminal state") << "."); explorationInformation.addTerminalState(currentStateId); if (isTargetState) { - bounds.initializeStateBoundsForNextState(std::make_pair(storm::utility::one(), storm::utility::one())); - bounds.initializeStateBoundsForNextAction(std::make_pair(storm::utility::one(), storm::utility::one())); + bounds.setBoundsForState(currentStateId, explorationInformation, std::make_pair(storm::utility::one(), storm::utility::one())); + bounds.initializeBoundsForNextAction(std::make_pair(storm::utility::one(), storm::utility::one())); } else { - bounds.initializeStateBoundsForNextState(std::make_pair(storm::utility::zero(), storm::utility::zero())); - bounds.initializeStateBoundsForNextAction(std::make_pair(storm::utility::zero(), storm::utility::zero())); + bounds.setBoundsForState(currentStateId, explorationInformation, std::make_pair(storm::utility::zero(), storm::utility::zero())); + bounds.initializeBoundsForNextAction(std::make_pair(storm::utility::zero(), storm::utility::zero())); } // Increase the size of the matrix, but leave the row empty. @@ -273,10 +285,6 @@ namespace storm { explorationInformation.newRowGroup(); } - // Finally, map the unexplored state to the row group. - explorationInformation.assignStateToNextRowGroup(currentStateId); - STORM_LOG_TRACE("Assigning row group " << explorationInformation.getRowGroup(currentStateId) << " to state " << currentStateId << "."); - return isTerminalState; } @@ -303,11 +311,17 @@ namespace storm { } } - std::sort(actionValues.begin(), actionValues.end(), [] (std::pair const& a, std::pair const& b) { return b.second > a.second; } ); - auto end = std::equal_range(actionValues.begin(), actionValues.end(), [this] (std::pair const& a, std::pair const& b) { return comparator.isEqual(a.second, b.second); } ); + STORM_LOG_ASSERT(!actionValues.empty(), "Values for actions must not be empty."); + + std::sort(actionValues.begin(), actionValues.end(), [] (std::pair const& a, std::pair const& b) { return a.second > b.second; } ); + + auto end = ++actionValues.begin(); + while (end != actionValues.end() && comparator.isEqual(actionValues.begin()->second, end->second)) { + ++end; + } // Now sample from all maximizing actions. - std::uniform_int_distribution distribution(0, std::distance(actionValues.begin(), end)); + std::uniform_int_distribution distribution(0, std::distance(actionValues.begin(), end) - 1); return actionValues[distribution(randomGenerator)].first; } @@ -350,7 +364,6 @@ namespace storm { // Create a mapping for faster look-up during the translation of flexible matrix to the real sparse matrix. std::unordered_map relevantStateToNewRowGroupMapping; for (StateType index = 0; index < relevantStates.size(); ++index) { - std::cout << "relevant: " << relevantStates[index] << std::endl; relevantStateToNewRowGroupMapping.emplace(relevantStates[index], index); } @@ -400,11 +413,8 @@ namespace storm { ActionSetPointer leavingChoices = std::make_shared(); for (auto const& stateAndChoices : mec) { // Compute the state of the original model that corresponds to the current state. - std::cout << "local state: " << stateAndChoices.first << std::endl; StateType originalState = relevantStates[stateAndChoices.first]; - std::cout << "original state: " << originalState << std::endl; uint32_t originalRowGroup = explorationInformation.getRowGroup(originalState); - std::cout << "original row group: " << originalRowGroup << std::endl; // TODO: This checks for a target state is a bit hackish and only works for max probabilities. if (!containsTargetState && comparator.isOne(bounds.getLowerBoundForRowGroup(originalRowGroup, explorationInformation))) { @@ -439,7 +449,7 @@ namespace storm { STORM_LOG_TRACE("MEC contains a target state."); for (auto const& stateAndChoices : mec) { // Compute the state of the original model that corresponds to the current state. - StateType originalState = relevantStates[stateAndChoices.first]; + StateType const& originalState = relevantStates[stateAndChoices.first]; STORM_LOG_TRACE("Setting lower bound of state in row group " << explorationInformation.getRowGroup(originalState) << " to 1."); bounds.setLowerBoundForState(originalState, explorationInformation, storm::utility::one()); @@ -450,7 +460,7 @@ namespace storm { // If there is no choice leaving the EC, but it contains no target state, all states have probability 0. for (auto const& stateAndChoices : mec) { // Compute the state of the original model that corresponds to the current state. - StateType originalState = relevantStates[stateAndChoices.first]; + StateType const& originalState = relevantStates[stateAndChoices.first]; STORM_LOG_TRACE("Setting upper bound of state in row group " << explorationInformation.getRowGroup(originalState) << " to 0."); bounds.setUpperBoundForState(originalState, explorationInformation, storm::utility::zero()); @@ -481,7 +491,7 @@ namespace storm { template ValueType SparseMdpLearningModelChecker::computeLowerBoundOfState(StateType const& state, ExplorationInformation const& explorationInformation, BoundValues const& bounds) const { StateType group = explorationInformation.getRowGroup(state); - ValueType result = std::make_pair(storm::utility::zero(), storm::utility::zero()); + ValueType result = storm::utility::zero(); for (ActionType action = explorationInformation.getStartRowOfGroup(group); action < explorationInformation.getStartRowOfGroup(group + 1); ++action) { ValueType actionValue = computeLowerBoundOfAction(action, explorationInformation, bounds); result = std::max(actionValue, result); @@ -492,7 +502,7 @@ namespace storm { template ValueType SparseMdpLearningModelChecker::computeUpperBoundOfState(StateType const& state, ExplorationInformation const& explorationInformation, BoundValues const& bounds) const { StateType group = explorationInformation.getRowGroup(state); - ValueType result = std::make_pair(storm::utility::zero(), storm::utility::zero()); + ValueType result = storm::utility::zero(); for (ActionType action = explorationInformation.getStartRowOfGroup(group); action < explorationInformation.getStartRowOfGroup(group + 1); ++action) { ValueType actionValue = computeUpperBoundOfAction(action, explorationInformation, bounds); result = std::max(actionValue, result); @@ -525,12 +535,6 @@ namespace storm { template void SparseMdpLearningModelChecker::updateProbabilityBoundsAlongSampledPath(StateActionStack& stack, ExplorationInformation const& explorationInformation, BoundValues& bounds) const { - std::cout << "stack is:" << std::endl; - for (auto const& el : stack) { - std::cout << el.first << "-[" << el.second << "]-> "; - } - std::cout << std::endl; - stack.pop_back(); while (!stack.empty()) { updateProbabilityOfAction(stack.back().first, stack.back().second, explorationInformation, bounds); @@ -538,51 +542,40 @@ namespace storm { } } + template + ValueType SparseMdpLearningModelChecker::computeUpperBoundOverAllOtherActions(StateType const& state, ActionType const& action, ExplorationInformation const& explorationInformation, BoundValues const& bounds) const { + ValueType max = storm::utility::zero(); + + ActionType group = explorationInformation.getRowGroup(state); + for (auto currentAction = explorationInformation.getStartRowOfGroup(group); currentAction < explorationInformation.getStartRowOfGroup(group + 1); ++currentAction) { + if (currentAction == action) { + continue; + } + + max = std::max(max, computeUpperBoundOfAction(currentAction, explorationInformation, bounds)); + } + + return max; + } + template void SparseMdpLearningModelChecker::updateProbabilityOfAction(StateType const& state, ActionType const& action, ExplorationInformation const& explorationInformation, BoundValues& bounds) const { // Compute the new lower/upper values of the action. std::pair newBoundsForAction = computeBoundsOfAction(action, explorationInformation, bounds); // And set them as the current value. - bounds.setBoundForAction(action, newBoundsForAction); + bounds.setBoundsForAction(action, newBoundsForAction); // Check if we need to update the values for the states. bounds.setNewLowerBoundOfStateIfGreaterThanOld(state, explorationInformation, newBoundsForAction.first); StateType rowGroup = explorationInformation.getRowGroup(state); - if (newBoundsForAction < bounds.getUpperBoundOfRowGroup(rowGroup)) { - - } - - ValueType upperBound = computeUpperBoundOverAllOtherActions(state, action, explorationInformation, bounds); - - uint32_t sourceRowGroup = stateToRowGroupMapping[sourceStateId]; - if (newUpperValue < upperBoundsPerState[sourceRowGroup]) { - if (rowGroupIndices[sourceRowGroup + 1] - rowGroupIndices[sourceRowGroup] > 1) { - ValueType max = storm::utility::zero(); - - for (uint32_t currentAction = rowGroupIndices[sourceRowGroup]; currentAction < rowGroupIndices[sourceRowGroup + 1]; ++currentAction) { - std::cout << "cur: " << currentAction << std::endl; - if (currentAction == action) { - continue; - } - - ValueType currentValue = storm::utility::zero(); - for (auto const& element : transitionMatrix[currentAction]) { - currentValue += element.getValue() * (stateToRowGroupMapping[element.getColumn()] == unexploredMarker ? storm::utility::one() : upperBoundsPerState[stateToRowGroupMapping[element.getColumn()]]); - } - max = std::max(max, currentValue); - std::cout << "max is " << max << std::endl; - } - - newUpperValue = std::max(newUpperValue, max); - } - - if (newUpperValue < upperBoundsPerState[sourceRowGroup]) { - STORM_LOG_TRACE("Got new upper bound for state " << sourceStateId << ": " << newUpperValue << " (was " << upperBoundsPerState[sourceRowGroup] << ")."); - std::cout << "writing at index " << sourceRowGroup << std::endl; - upperBoundsPerState[sourceRowGroup] = newUpperValue; + if (newBoundsForAction.second < bounds.getUpperBoundForRowGroup(rowGroup)) { + if (explorationInformation.getRowGroupSize(rowGroup) > 1) { + newBoundsForAction.second = std::max(newBoundsForAction.second, computeUpperBoundOverAllOtherActions(state, action, explorationInformation, bounds)); } + + bounds.setUpperBoundForState(state, explorationInformation, newBoundsForAction.second); } } diff --git a/src/modelchecker/reachability/SparseMdpLearningModelChecker.h b/src/modelchecker/reachability/SparseMdpLearningModelChecker.h index 777caf2d9..01ea2d323 100644 --- a/src/modelchecker/reachability/SparseMdpLearningModelChecker.h +++ b/src/modelchecker/reachability/SparseMdpLearningModelChecker.h @@ -62,7 +62,7 @@ namespace storm { // A struct containing the data required for state exploration. struct StateGeneration { - StateGeneration(storm::generator::PrismNextStateGenerator&& generator, storm::expressions::Expression const& targetStateExpression) : generator(std::move(generator)), targetStateExpression(targetStateExpression) { + StateGeneration(storm::prism::Program const& program, storm::generator::VariableInformation const& variableInformation, storm::expressions::Expression const& targetStateExpression) : generator(program, variableInformation, false), targetStateExpression(targetStateExpression) { // Intentionally left empty. } @@ -125,8 +125,9 @@ namespace storm { stateToRowGroupMapping[state] = rowGroup; } - void assignStateToNextRowGroup(StateType const& state) { + StateType assignStateToNextRowGroup(StateType const& state) { stateToRowGroupMapping[state] = rowGroupIndices.size() - 1; + return stateToRowGroupMapping[state]; } void newRowGroup(ActionType const& action) { @@ -154,7 +155,7 @@ namespace storm { } bool isUnexplored(StateType const& state) const { - return unexploredStates.find(state) == unexploredStates.end(); + return stateToRowGroupMapping[state] == unexploredMarker; } bool isTerminal(StateType const& state) const { @@ -165,6 +166,10 @@ namespace storm { return rowGroupIndices[group]; } + std::size_t getRowGroupSize(StateType const& group) const { + return rowGroupIndices[group + 1] - rowGroupIndices[group]; + } + void addTerminalState(StateType const& state) { terminalStates.insert(state); } @@ -216,11 +221,11 @@ namespace storm { if (index == explorationInformation.getUnexploredMarker()) { return storm::utility::one(); } else { - return getUpperBoundForRowGroup(index, explorationInformation); + return getUpperBoundForRowGroup(index); } } - ValueType getUpperBoundForRowGroup(StateType const& rowGroup, ExplorationInformation const& explorationInformation) const { + ValueType const& getUpperBoundForRowGroup(StateType const& rowGroup) const { return upperBoundsPerState[rowGroup]; } @@ -241,12 +246,12 @@ namespace storm { return bounds.second - bounds.first; } - void initializeStateBoundsForNextState(std::pair const& vals = std::pair(storm::utility::zero(), storm::utility::one())) { + void initializeBoundsForNextState(std::pair const& vals = std::pair(storm::utility::zero(), storm::utility::one())) { lowerBoundsPerState.push_back(vals.first); upperBoundsPerState.push_back(vals.second); } - void initializeActionBoundsForNextAction(std::pair const& vals = std::pair(storm::utility::zero(), storm::utility::one())) { + void initializeBoundsForNextAction(std::pair const& vals = std::pair(storm::utility::zero(), storm::utility::one())) { lowerBoundsPerAction.push_back(vals.first); upperBoundsPerAction.push_back(vals.second); } @@ -274,14 +279,18 @@ namespace storm { StateType const& rowGroup = explorationInformation.getRowGroup(state); if (lowerBoundsPerState[rowGroup] < newLowerValue) { lowerBoundsPerState[rowGroup] = newLowerValue; + return true; } + return false; } bool setNewUpperBoundOfStateIfLessThanOld(StateType const& state, ExplorationInformation const& explorationInformation, ValueType const& newUpperValue) { StateType const& rowGroup = explorationInformation.getRowGroup(state); if (newUpperValue < upperBoundsPerState[rowGroup]) { upperBoundsPerState[rowGroup] = newUpperValue; + return true; } + return false; } }; @@ -306,6 +315,7 @@ namespace storm { void updateProbabilityOfAction(StateType const& state, ActionType const& action, ExplorationInformation const& explorationInformation, BoundValues& bounds) const; std::pair computeBoundsOfAction(ActionType const& action, ExplorationInformation const& explorationInformation, BoundValues const& bounds) const; + ValueType computeUpperBoundOverAllOtherActions(StateType const& state, ActionType const& action, ExplorationInformation const& explorationInformation, BoundValues const& bounds) const; std::pair computeBoundsOfState(StateType const& currentStateId, ExplorationInformation const& explorationInformation, BoundValues const& bounds) const; ValueType computeLowerBoundOfAction(ActionType const& action, ExplorationInformation const& explorationInformation, BoundValues const& bounds) const; ValueType computeUpperBoundOfAction(ActionType const& action, ExplorationInformation const& explorationInformation, BoundValues const& bounds) const; @@ -319,7 +329,7 @@ namespace storm { storm::generator::VariableInformation variableInformation; // The random number generator. - std::default_random_engine randomGenerator; + mutable std::default_random_engine randomGenerator; // A comparator used to determine whether values are equal. storm::utility::ConstantsComparator comparator;