Browse Source

debugged the refactoring a bit

Former-commit-id: 9df3d5d533
main
dehnert 9 years ago
parent
commit
1405cdfc46
  1. 2
      src/generator/NextStateGenerator.h
  2. 4
      src/generator/PrismNextStateGenerator.cpp
  3. 4
      src/generator/PrismNextStateGenerator.h
  4. 139
      src/modelchecker/reachability/SparseMdpLearningModelChecker.cpp
  5. 26
      src/modelchecker/reachability/SparseMdpLearningModelChecker.h

2
src/generator/NextStateGenerator.h

@ -21,7 +21,7 @@ namespace storm {
virtual void load(CompressedState const& state) = 0; virtual void load(CompressedState const& state) = 0;
virtual StateBehavior<ValueType, StateType> expand(StateToIdCallback const& stateToIdCallback) = 0; virtual StateBehavior<ValueType, StateType> expand(StateToIdCallback const& stateToIdCallback) = 0;
virtual bool satisfies(storm::expressions::Expression const& expression) = 0; virtual bool satisfies(storm::expressions::Expression const& expression) const = 0;
}; };
} }
} }

4
src/generator/PrismNextStateGenerator.cpp

@ -13,7 +13,7 @@ namespace storm {
PrismNextStateGenerator<ValueType, StateType>::PrismNextStateGenerator(storm::prism::Program const& program, VariableInformation const& variableInformation, bool buildChoiceLabeling) : program(program), selectedRewardModels(), buildChoiceLabeling(buildChoiceLabeling), variableInformation(variableInformation), evaluator(program.getManager()), state(nullptr), comparator() { PrismNextStateGenerator<ValueType, StateType>::PrismNextStateGenerator(storm::prism::Program const& program, VariableInformation const& variableInformation, bool buildChoiceLabeling) : program(program), selectedRewardModels(), buildChoiceLabeling(buildChoiceLabeling), variableInformation(variableInformation), evaluator(program.getManager()), state(nullptr), comparator() {
// Intentionally left empty. // Intentionally left empty.
} }
template<typename ValueType, typename StateType> template<typename ValueType, typename StateType>
void PrismNextStateGenerator<ValueType, StateType>::addRewardModel(storm::prism::RewardModel const& rewardModel) { void PrismNextStateGenerator<ValueType, StateType>::addRewardModel(storm::prism::RewardModel const& rewardModel) {
selectedRewardModels.push_back(rewardModel); selectedRewardModels.push_back(rewardModel);
@ -58,7 +58,7 @@ namespace storm {
} }
template<typename ValueType, typename StateType> template<typename ValueType, typename StateType>
bool PrismNextStateGenerator<ValueType, StateType>::satisfies(storm::expressions::Expression const& expression) { bool PrismNextStateGenerator<ValueType, StateType>::satisfies(storm::expressions::Expression const& expression) const {
return evaluator.asBool(expression); return evaluator.asBool(expression);
} }

4
src/generator/PrismNextStateGenerator.h

@ -18,7 +18,7 @@ namespace storm {
typedef typename NextStateGenerator<ValueType, StateType>::StateToIdCallback StateToIdCallback; typedef typename NextStateGenerator<ValueType, StateType>::StateToIdCallback StateToIdCallback;
PrismNextStateGenerator(storm::prism::Program const& program, VariableInformation const& variableInformation, bool buildChoiceLabeling); PrismNextStateGenerator(storm::prism::Program const& program, VariableInformation const& variableInformation, bool buildChoiceLabeling);
/*! /*!
* Adds a reward model to the list of selected reward models () * Adds a reward model to the list of selected reward models ()
*/ */
@ -34,7 +34,7 @@ namespace storm {
virtual void load(CompressedState const& state) override; virtual void load(CompressedState const& state) override;
virtual StateBehavior<ValueType, StateType> expand(StateToIdCallback const& stateToIdCallback) override; virtual StateBehavior<ValueType, StateType> expand(StateToIdCallback const& stateToIdCallback) override;
virtual bool satisfies(storm::expressions::Expression const& expression) override; virtual bool satisfies(storm::expressions::Expression const& expression) const override;
private: private:
/*! /*!

139
src/modelchecker/reachability/SparseMdpLearningModelChecker.cpp

@ -40,7 +40,7 @@ namespace storm {
STORM_LOG_THROW(program.isDeterministicModel() || checkTask.isOptimizationDirectionSet(), storm::exceptions::InvalidPropertyException, "For nondeterministic systems, an optimization direction (min/max) must be given in the property."); STORM_LOG_THROW(program.isDeterministicModel() || checkTask.isOptimizationDirectionSet(), storm::exceptions::InvalidPropertyException, "For nondeterministic systems, an optimization direction (min/max) must be given in the property.");
STORM_LOG_THROW(subformula.isAtomicExpressionFormula() || subformula.isAtomicLabelFormula(), storm::exceptions::NotSupportedException, "Learning engine can only deal with formulas of the form 'F \"label\"' or 'F expression'."); STORM_LOG_THROW(subformula.isAtomicExpressionFormula() || subformula.isAtomicLabelFormula(), storm::exceptions::NotSupportedException, "Learning engine can only deal with formulas of the form 'F \"label\"' or 'F expression'.");
StateGeneration stateGeneration(storm::generator::PrismNextStateGenerator<ValueType, StateType>(program, variableInformation, false), getTargetStateExpression(subformula)); StateGeneration stateGeneration(program, variableInformation, getTargetStateExpression(subformula));
ExplorationInformation explorationInformation(variableInformation.getTotalBitOffset(true)); ExplorationInformation explorationInformation(variableInformation.getTotalBitOffset(true));
explorationInformation.optimizationDirection = checkTask.isOptimizationDirectionSet() ? checkTask.getOptimizationDirection() : storm::OptimizationDirection::Maximize; explorationInformation.optimizationDirection = checkTask.isOptimizationDirectionSet() ? checkTask.getOptimizationDirection() : storm::OptimizationDirection::Maximize;
@ -49,7 +49,7 @@ namespace storm {
explorationInformation.newRowGroup(0); explorationInformation.newRowGroup(0);
// Create a callback for the next-state generator to enable it to request the index of states. // Create a callback for the next-state generator to enable it to request the index of states.
std::function<StateType (storm::generator::CompressedState const&)> stateToIdCallback = createStateToIdCallback(explorationInformation); stateGeneration.stateToIdCallback = createStateToIdCallback(explorationInformation);
// Compute and return result. // Compute and return result.
std::tuple<StateType, ValueType, ValueType> boundsForInitialState = performLearningProcedure(stateGeneration, explorationInformation); std::tuple<StateType, ValueType, ValueType> boundsForInitialState = performLearningProcedure(stateGeneration, explorationInformation);
@ -114,7 +114,7 @@ namespace storm {
STORM_LOG_TRACE("Did not find terminal state."); STORM_LOG_TRACE("Did not find terminal state.");
} }
STORM_LOG_DEBUG("Discovered states: " << explorationInformation.getNumberOfDiscoveredStates() << " (" << stats.numberOfExploredStates << "explored, " << explorationInformation.getNumberOfUnexploredStates() << " unexplored)."); STORM_LOG_DEBUG("Discovered states: " << explorationInformation.getNumberOfDiscoveredStates() << " (" << stats.numberOfExploredStates << " explored, " << explorationInformation.getNumberOfUnexploredStates() << " unexplored).");
STORM_LOG_DEBUG("Value of initial state is in [" << bounds.getLowerBoundForState(initialStateIndex, explorationInformation) << ", " << bounds.getUpperBoundForState(initialStateIndex, explorationInformation) << "]."); STORM_LOG_DEBUG("Value of initial state is in [" << bounds.getLowerBoundForState(initialStateIndex, explorationInformation) << ", " << bounds.getUpperBoundForState(initialStateIndex, explorationInformation) << "].");
ValueType difference = bounds.getDifferenceOfStateBounds(initialStateIndex, explorationInformation); ValueType difference = bounds.getDifferenceOfStateBounds(initialStateIndex, explorationInformation);
STORM_LOG_DEBUG("Difference after iteration " << stats.iterations << " is " << difference << "."); STORM_LOG_DEBUG("Difference after iteration " << stats.iterations << " is " << difference << ".");
@ -125,7 +125,7 @@ namespace storm {
if (storm::settings::generalSettings().isShowStatisticsSet()) { if (storm::settings::generalSettings().isShowStatisticsSet()) {
std::cout << std::endl << "Learning summary -------------------------" << std::endl; std::cout << std::endl << "Learning summary -------------------------" << std::endl;
std::cout << "Discovered states: " << explorationInformation.getNumberOfDiscoveredStates() << " (" << stats.numberOfExploredStates << "explored, " << explorationInformation.getNumberOfUnexploredStates() << " unexplored, " << stats.numberOfTargetStates << " target states)" << std::endl; std::cout << "Discovered states: " << explorationInformation.getNumberOfDiscoveredStates() << " (" << stats.numberOfExploredStates << " explored, " << explorationInformation.getNumberOfUnexploredStates() << " unexplored, " << stats.numberOfTargetStates << " target states)" << std::endl;
std::cout << "Sampling iterations: " << stats.iterations << std::endl; std::cout << "Sampling iterations: " << stats.iterations << std::endl;
std::cout << "Maximal path length: " << stats.maxPathLength << std::endl; std::cout << "Maximal path length: " << stats.maxPathLength << std::endl;
} }
@ -165,7 +165,7 @@ namespace storm {
if (!foundTerminalState) { if (!foundTerminalState) {
// At this point, we can be sure that the state was expanded and that we can sample according to the // At this point, we can be sure that the state was expanded and that we can sample according to the
// probabilities in the matrix. // probabilities in the matrix.
uint32_t chosenAction = sampleFromMaxActions(currentStateId, explorationInformation, bounds); uint32_t chosenAction = sampleMaxAction(currentStateId, explorationInformation, bounds);
stack.back().second = chosenAction; stack.back().second = chosenAction;
STORM_LOG_TRACE("Sampled action " << chosenAction << " in state " << currentStateId << "."); STORM_LOG_TRACE("Sampled action " << chosenAction << " in state " << currentStateId << ".");
@ -194,10 +194,19 @@ namespace storm {
template<typename ValueType> template<typename ValueType>
bool SparseMdpLearningModelChecker<ValueType>::exploreState(StateGeneration& stateGeneration, StateType const& currentStateId, storm::generator::CompressedState const& currentState, ExplorationInformation& explorationInformation, BoundValues& bounds, Statistics& stats) const { bool SparseMdpLearningModelChecker<ValueType>::exploreState(StateGeneration& stateGeneration, StateType const& currentStateId, storm::generator::CompressedState const& currentState, ExplorationInformation& explorationInformation, BoundValues& bounds, Statistics& stats) const {
bool isTerminalState = false; bool isTerminalState = false;
bool isTargetState = false; bool isTargetState = false;
++stats.numberOfExploredStates;
// Finally, map the unexplored state to the row group.
explorationInformation.assignStateToNextRowGroup(currentStateId);
STORM_LOG_TRACE("Assigning row group " << explorationInformation.getRowGroup(currentStateId) << " to state " << currentStateId << ".");
// Initialize the bounds, because some of the following computations depend on the values to be available for
// all states that have been assigned to a row-group.
bounds.initializeBoundsForNextState();
// Before generating the behavior of the state, we need to determine whether it's a target state that // Before generating the behavior of the state, we need to determine whether it's a target state that
// does not need to be expanded. // does not need to be expanded.
stateGeneration.generator.load(currentState); stateGeneration.generator.load(currentState);
@ -232,38 +241,41 @@ namespace storm {
StateType startRow = explorationInformation.matrix.size(); StateType startRow = explorationInformation.matrix.size();
explorationInformation.addRowsToMatrix(behavior.getNumberOfChoices()); explorationInformation.addRowsToMatrix(behavior.getNumberOfChoices());
// Terminate the row group.
explorationInformation.rowGroupIndices.push_back(explorationInformation.matrix.size());
ActionType currentAction = 0; ActionType currentAction = 0;
std::pair<ValueType, ValueType> stateBounds(storm::utility::zero<ValueType>(), storm::utility::zero<ValueType>());
for (auto const& choice : behavior) { for (auto const& choice : behavior) {
for (auto const& entry : choice) { for (auto const& entry : choice) {
std::cout << "adding " << currentStateId << " -> " << entry.first << " with prob " << entry.second << std::endl; explorationInformation.getRowOfMatrix(startRow + currentAction).emplace_back(entry.first, entry.second);
explorationInformation.matrix[startRow + currentAction].emplace_back(entry.first, entry.second);
} }
bounds.initializeActionBoundsForNextAction(computeBoundsOfAction(startRow + currentAction, explorationInformation, bounds)); std::pair<ValueType, ValueType> actionBounds = computeBoundsOfAction(startRow + currentAction, explorationInformation, bounds);
bounds.initializeBoundsForNextAction(actionBounds);
stateBounds = std::make_pair(std::max(stateBounds.first, actionBounds.first), std::max(stateBounds.second, actionBounds.second));
STORM_LOG_TRACE("Initializing bounds of action " << (startRow + currentAction) << " to " << bounds.getLowerBoundForAction(startRow + currentAction) << " and " << bounds.getUpperBoundForAction(startRow + currentAction) << "."); STORM_LOG_TRACE("Initializing bounds of action " << (startRow + currentAction) << " to " << bounds.getLowerBoundForAction(startRow + currentAction) << " and " << bounds.getUpperBoundForAction(startRow + currentAction) << ".");
++currentAction; ++currentAction;
} }
bounds.initializeStateBoundsForNextState(computeBoundsOfState(currentStateId, explorationInformation, bounds)); // Terminate the row group.
STORM_LOG_TRACE("Initializing bounds of state " << currentStateId << " to " << bounds.getLowerBoundForState(currentStateId) << " and " << bounds.getUpperBoundForState(currentStateId) << "."); explorationInformation.rowGroupIndices.push_back(explorationInformation.matrix.size());
bounds.setBoundsForState(currentStateId, explorationInformation, stateBounds);
STORM_LOG_TRACE("Initializing bounds of state " << currentStateId << " to " << bounds.getLowerBoundForState(currentStateId, explorationInformation) << " and " << bounds.getUpperBoundForState(currentStateId, explorationInformation) << ".");
} }
} }
if (isTerminalState) { if (isTerminalState) {
STORM_LOG_TRACE("State does not need to be explored, because it is " << (isTargetState ? "a target state" : "a rejecting terminal state") << "."); STORM_LOG_TRACE("State does not need to be explored, because it is " << (isTargetState ? "a target state" : "a rejecting terminal state") << ".");
explorationInformation.addTerminalState(currentStateId); explorationInformation.addTerminalState(currentStateId);
if (isTargetState) { if (isTargetState) {
bounds.initializeStateBoundsForNextState(std::make_pair(storm::utility::one<ValueType>(), storm::utility::one<ValueType>())); bounds.setBoundsForState(currentStateId, explorationInformation, std::make_pair(storm::utility::one<ValueType>(), storm::utility::one<ValueType>()));
bounds.initializeStateBoundsForNextAction(std::make_pair(storm::utility::one<ValueType>(), storm::utility::one<ValueType>())); bounds.initializeBoundsForNextAction(std::make_pair(storm::utility::one<ValueType>(), storm::utility::one<ValueType>()));
} else { } else {
bounds.initializeStateBoundsForNextState(std::make_pair(storm::utility::zero<ValueType>(), storm::utility::zero<ValueType>())); bounds.setBoundsForState(currentStateId, explorationInformation, std::make_pair(storm::utility::zero<ValueType>(), storm::utility::zero<ValueType>()));
bounds.initializeStateBoundsForNextAction(std::make_pair(storm::utility::zero<ValueType>(), storm::utility::zero<ValueType>())); bounds.initializeBoundsForNextAction(std::make_pair(storm::utility::zero<ValueType>(), storm::utility::zero<ValueType>()));
} }
// Increase the size of the matrix, but leave the row empty. // Increase the size of the matrix, but leave the row empty.
@ -273,10 +285,6 @@ namespace storm {
explorationInformation.newRowGroup(); explorationInformation.newRowGroup();
} }
// Finally, map the unexplored state to the row group.
explorationInformation.assignStateToNextRowGroup(currentStateId);
STORM_LOG_TRACE("Assigning row group " << explorationInformation.getRowGroup(currentStateId) << " to state " << currentStateId << ".");
return isTerminalState; return isTerminalState;
} }
@ -303,11 +311,17 @@ namespace storm {
} }
} }
std::sort(actionValues.begin(), actionValues.end(), [] (std::pair<ActionType, ValueType> const& a, std::pair<ActionType, ValueType> const& b) { return b.second > a.second; } ); STORM_LOG_ASSERT(!actionValues.empty(), "Values for actions must not be empty.");
auto end = std::equal_range(actionValues.begin(), actionValues.end(), [this] (std::pair<ActionType, ValueType> const& a, std::pair<ActionType, ValueType> const& b) { return comparator.isEqual(a.second, b.second); } ); std::sort(actionValues.begin(), actionValues.end(), [] (std::pair<ActionType, ValueType> const& a, std::pair<ActionType, ValueType> const& b) { return a.second > b.second; } );
auto end = ++actionValues.begin();
while (end != actionValues.end() && comparator.isEqual(actionValues.begin()->second, end->second)) {
++end;
}
// Now sample from all maximizing actions. // Now sample from all maximizing actions.
std::uniform_int_distribution<uint32_t> distribution(0, std::distance(actionValues.begin(), end)); std::uniform_int_distribution<ActionType> distribution(0, std::distance(actionValues.begin(), end) - 1);
return actionValues[distribution(randomGenerator)].first; return actionValues[distribution(randomGenerator)].first;
} }
@ -350,7 +364,6 @@ namespace storm {
// Create a mapping for faster look-up during the translation of flexible matrix to the real sparse matrix. // Create a mapping for faster look-up during the translation of flexible matrix to the real sparse matrix.
std::unordered_map<StateType, StateType> relevantStateToNewRowGroupMapping; std::unordered_map<StateType, StateType> relevantStateToNewRowGroupMapping;
for (StateType index = 0; index < relevantStates.size(); ++index) { for (StateType index = 0; index < relevantStates.size(); ++index) {
std::cout << "relevant: " << relevantStates[index] << std::endl;
relevantStateToNewRowGroupMapping.emplace(relevantStates[index], index); relevantStateToNewRowGroupMapping.emplace(relevantStates[index], index);
} }
@ -400,11 +413,8 @@ namespace storm {
ActionSetPointer leavingChoices = std::make_shared<ActionSet>(); ActionSetPointer leavingChoices = std::make_shared<ActionSet>();
for (auto const& stateAndChoices : mec) { for (auto const& stateAndChoices : mec) {
// Compute the state of the original model that corresponds to the current state. // Compute the state of the original model that corresponds to the current state.
std::cout << "local state: " << stateAndChoices.first << std::endl;
StateType originalState = relevantStates[stateAndChoices.first]; StateType originalState = relevantStates[stateAndChoices.first];
std::cout << "original state: " << originalState << std::endl;
uint32_t originalRowGroup = explorationInformation.getRowGroup(originalState); uint32_t originalRowGroup = explorationInformation.getRowGroup(originalState);
std::cout << "original row group: " << originalRowGroup << std::endl;
// TODO: This checks for a target state is a bit hackish and only works for max probabilities. // TODO: This checks for a target state is a bit hackish and only works for max probabilities.
if (!containsTargetState && comparator.isOne(bounds.getLowerBoundForRowGroup(originalRowGroup, explorationInformation))) { if (!containsTargetState && comparator.isOne(bounds.getLowerBoundForRowGroup(originalRowGroup, explorationInformation))) {
@ -439,7 +449,7 @@ namespace storm {
STORM_LOG_TRACE("MEC contains a target state."); STORM_LOG_TRACE("MEC contains a target state.");
for (auto const& stateAndChoices : mec) { for (auto const& stateAndChoices : mec) {
// Compute the state of the original model that corresponds to the current state. // Compute the state of the original model that corresponds to the current state.
StateType originalState = relevantStates[stateAndChoices.first]; StateType const& originalState = relevantStates[stateAndChoices.first];
STORM_LOG_TRACE("Setting lower bound of state in row group " << explorationInformation.getRowGroup(originalState) << " to 1."); STORM_LOG_TRACE("Setting lower bound of state in row group " << explorationInformation.getRowGroup(originalState) << " to 1.");
bounds.setLowerBoundForState(originalState, explorationInformation, storm::utility::one<ValueType>()); bounds.setLowerBoundForState(originalState, explorationInformation, storm::utility::one<ValueType>());
@ -450,7 +460,7 @@ namespace storm {
// If there is no choice leaving the EC, but it contains no target state, all states have probability 0. // If there is no choice leaving the EC, but it contains no target state, all states have probability 0.
for (auto const& stateAndChoices : mec) { for (auto const& stateAndChoices : mec) {
// Compute the state of the original model that corresponds to the current state. // Compute the state of the original model that corresponds to the current state.
StateType originalState = relevantStates[stateAndChoices.first]; StateType const& originalState = relevantStates[stateAndChoices.first];
STORM_LOG_TRACE("Setting upper bound of state in row group " << explorationInformation.getRowGroup(originalState) << " to 0."); STORM_LOG_TRACE("Setting upper bound of state in row group " << explorationInformation.getRowGroup(originalState) << " to 0.");
bounds.setUpperBoundForState(originalState, explorationInformation, storm::utility::zero<ValueType>()); bounds.setUpperBoundForState(originalState, explorationInformation, storm::utility::zero<ValueType>());
@ -481,7 +491,7 @@ namespace storm {
template<typename ValueType> template<typename ValueType>
ValueType SparseMdpLearningModelChecker<ValueType>::computeLowerBoundOfState(StateType const& state, ExplorationInformation const& explorationInformation, BoundValues const& bounds) const { ValueType SparseMdpLearningModelChecker<ValueType>::computeLowerBoundOfState(StateType const& state, ExplorationInformation const& explorationInformation, BoundValues const& bounds) const {
StateType group = explorationInformation.getRowGroup(state); StateType group = explorationInformation.getRowGroup(state);
ValueType result = std::make_pair(storm::utility::zero<ValueType>(), storm::utility::zero<ValueType>()); ValueType result = storm::utility::zero<ValueType>();
for (ActionType action = explorationInformation.getStartRowOfGroup(group); action < explorationInformation.getStartRowOfGroup(group + 1); ++action) { for (ActionType action = explorationInformation.getStartRowOfGroup(group); action < explorationInformation.getStartRowOfGroup(group + 1); ++action) {
ValueType actionValue = computeLowerBoundOfAction(action, explorationInformation, bounds); ValueType actionValue = computeLowerBoundOfAction(action, explorationInformation, bounds);
result = std::max(actionValue, result); result = std::max(actionValue, result);
@ -492,7 +502,7 @@ namespace storm {
template<typename ValueType> template<typename ValueType>
ValueType SparseMdpLearningModelChecker<ValueType>::computeUpperBoundOfState(StateType const& state, ExplorationInformation const& explorationInformation, BoundValues const& bounds) const { ValueType SparseMdpLearningModelChecker<ValueType>::computeUpperBoundOfState(StateType const& state, ExplorationInformation const& explorationInformation, BoundValues const& bounds) const {
StateType group = explorationInformation.getRowGroup(state); StateType group = explorationInformation.getRowGroup(state);
ValueType result = std::make_pair(storm::utility::zero<ValueType>(), storm::utility::zero<ValueType>()); ValueType result = storm::utility::zero<ValueType>();
for (ActionType action = explorationInformation.getStartRowOfGroup(group); action < explorationInformation.getStartRowOfGroup(group + 1); ++action) { for (ActionType action = explorationInformation.getStartRowOfGroup(group); action < explorationInformation.getStartRowOfGroup(group + 1); ++action) {
ValueType actionValue = computeUpperBoundOfAction(action, explorationInformation, bounds); ValueType actionValue = computeUpperBoundOfAction(action, explorationInformation, bounds);
result = std::max(actionValue, result); result = std::max(actionValue, result);
@ -525,12 +535,6 @@ namespace storm {
template<typename ValueType> template<typename ValueType>
void SparseMdpLearningModelChecker<ValueType>::updateProbabilityBoundsAlongSampledPath(StateActionStack& stack, ExplorationInformation const& explorationInformation, BoundValues& bounds) const { void SparseMdpLearningModelChecker<ValueType>::updateProbabilityBoundsAlongSampledPath(StateActionStack& stack, ExplorationInformation const& explorationInformation, BoundValues& bounds) const {
std::cout << "stack is:" << std::endl;
for (auto const& el : stack) {
std::cout << el.first << "-[" << el.second << "]-> ";
}
std::cout << std::endl;
stack.pop_back(); stack.pop_back();
while (!stack.empty()) { while (!stack.empty()) {
updateProbabilityOfAction(stack.back().first, stack.back().second, explorationInformation, bounds); updateProbabilityOfAction(stack.back().first, stack.back().second, explorationInformation, bounds);
@ -538,51 +542,40 @@ namespace storm {
} }
} }
template<typename ValueType>
ValueType SparseMdpLearningModelChecker<ValueType>::computeUpperBoundOverAllOtherActions(StateType const& state, ActionType const& action, ExplorationInformation const& explorationInformation, BoundValues const& bounds) const {
ValueType max = storm::utility::zero<ValueType>();
ActionType group = explorationInformation.getRowGroup(state);
for (auto currentAction = explorationInformation.getStartRowOfGroup(group); currentAction < explorationInformation.getStartRowOfGroup(group + 1); ++currentAction) {
if (currentAction == action) {
continue;
}
max = std::max(max, computeUpperBoundOfAction(currentAction, explorationInformation, bounds));
}
return max;
}
template<typename ValueType> template<typename ValueType>
void SparseMdpLearningModelChecker<ValueType>::updateProbabilityOfAction(StateType const& state, ActionType const& action, ExplorationInformation const& explorationInformation, BoundValues& bounds) const { void SparseMdpLearningModelChecker<ValueType>::updateProbabilityOfAction(StateType const& state, ActionType const& action, ExplorationInformation const& explorationInformation, BoundValues& bounds) const {
// Compute the new lower/upper values of the action. // Compute the new lower/upper values of the action.
std::pair<ValueType, ValueType> newBoundsForAction = computeBoundsOfAction(action, explorationInformation, bounds); std::pair<ValueType, ValueType> newBoundsForAction = computeBoundsOfAction(action, explorationInformation, bounds);
// And set them as the current value. // And set them as the current value.
bounds.setBoundForAction(action, newBoundsForAction); bounds.setBoundsForAction(action, newBoundsForAction);
// Check if we need to update the values for the states. // Check if we need to update the values for the states.
bounds.setNewLowerBoundOfStateIfGreaterThanOld(state, explorationInformation, newBoundsForAction.first); bounds.setNewLowerBoundOfStateIfGreaterThanOld(state, explorationInformation, newBoundsForAction.first);
StateType rowGroup = explorationInformation.getRowGroup(state); StateType rowGroup = explorationInformation.getRowGroup(state);
if (newBoundsForAction < bounds.getUpperBoundOfRowGroup(rowGroup)) { if (newBoundsForAction.second < bounds.getUpperBoundForRowGroup(rowGroup)) {
if (explorationInformation.getRowGroupSize(rowGroup) > 1) {
} newBoundsForAction.second = std::max(newBoundsForAction.second, computeUpperBoundOverAllOtherActions(state, action, explorationInformation, bounds));
ValueType upperBound = computeUpperBoundOverAllOtherActions(state, action, explorationInformation, bounds);
uint32_t sourceRowGroup = stateToRowGroupMapping[sourceStateId];
if (newUpperValue < upperBoundsPerState[sourceRowGroup]) {
if (rowGroupIndices[sourceRowGroup + 1] - rowGroupIndices[sourceRowGroup] > 1) {
ValueType max = storm::utility::zero<ValueType>();
for (uint32_t currentAction = rowGroupIndices[sourceRowGroup]; currentAction < rowGroupIndices[sourceRowGroup + 1]; ++currentAction) {
std::cout << "cur: " << currentAction << std::endl;
if (currentAction == action) {
continue;
}
ValueType currentValue = storm::utility::zero<ValueType>();
for (auto const& element : transitionMatrix[currentAction]) {
currentValue += element.getValue() * (stateToRowGroupMapping[element.getColumn()] == unexploredMarker ? storm::utility::one<ValueType>() : upperBoundsPerState[stateToRowGroupMapping[element.getColumn()]]);
}
max = std::max(max, currentValue);
std::cout << "max is " << max << std::endl;
}
newUpperValue = std::max(newUpperValue, max);
}
if (newUpperValue < upperBoundsPerState[sourceRowGroup]) {
STORM_LOG_TRACE("Got new upper bound for state " << sourceStateId << ": " << newUpperValue << " (was " << upperBoundsPerState[sourceRowGroup] << ").");
std::cout << "writing at index " << sourceRowGroup << std::endl;
upperBoundsPerState[sourceRowGroup] = newUpperValue;
} }
bounds.setUpperBoundForState(state, explorationInformation, newBoundsForAction.second);
} }
} }

26
src/modelchecker/reachability/SparseMdpLearningModelChecker.h

@ -62,7 +62,7 @@ namespace storm {
// A struct containing the data required for state exploration. // A struct containing the data required for state exploration.
struct StateGeneration { struct StateGeneration {
StateGeneration(storm::generator::PrismNextStateGenerator<ValueType, StateType>&& generator, storm::expressions::Expression const& targetStateExpression) : generator(std::move(generator)), targetStateExpression(targetStateExpression) { StateGeneration(storm::prism::Program const& program, storm::generator::VariableInformation const& variableInformation, storm::expressions::Expression const& targetStateExpression) : generator(program, variableInformation, false), targetStateExpression(targetStateExpression) {
// Intentionally left empty. // Intentionally left empty.
} }
@ -125,8 +125,9 @@ namespace storm {
stateToRowGroupMapping[state] = rowGroup; stateToRowGroupMapping[state] = rowGroup;
} }
void assignStateToNextRowGroup(StateType const& state) { StateType assignStateToNextRowGroup(StateType const& state) {
stateToRowGroupMapping[state] = rowGroupIndices.size() - 1; stateToRowGroupMapping[state] = rowGroupIndices.size() - 1;
return stateToRowGroupMapping[state];
} }
void newRowGroup(ActionType const& action) { void newRowGroup(ActionType const& action) {
@ -154,7 +155,7 @@ namespace storm {
} }
bool isUnexplored(StateType const& state) const { bool isUnexplored(StateType const& state) const {
return unexploredStates.find(state) == unexploredStates.end(); return stateToRowGroupMapping[state] == unexploredMarker;
} }
bool isTerminal(StateType const& state) const { bool isTerminal(StateType const& state) const {
@ -165,6 +166,10 @@ namespace storm {
return rowGroupIndices[group]; return rowGroupIndices[group];
} }
std::size_t getRowGroupSize(StateType const& group) const {
return rowGroupIndices[group + 1] - rowGroupIndices[group];
}
void addTerminalState(StateType const& state) { void addTerminalState(StateType const& state) {
terminalStates.insert(state); terminalStates.insert(state);
} }
@ -216,11 +221,11 @@ namespace storm {
if (index == explorationInformation.getUnexploredMarker()) { if (index == explorationInformation.getUnexploredMarker()) {
return storm::utility::one<ValueType>(); return storm::utility::one<ValueType>();
} else { } else {
return getUpperBoundForRowGroup(index, explorationInformation); return getUpperBoundForRowGroup(index);
} }
} }
ValueType getUpperBoundForRowGroup(StateType const& rowGroup, ExplorationInformation const& explorationInformation) const { ValueType const& getUpperBoundForRowGroup(StateType const& rowGroup) const {
return upperBoundsPerState[rowGroup]; return upperBoundsPerState[rowGroup];
} }
@ -241,12 +246,12 @@ namespace storm {
return bounds.second - bounds.first; return bounds.second - bounds.first;
} }
void initializeStateBoundsForNextState(std::pair<ValueType, ValueType> const& vals = std::pair<ValueType, ValueType>(storm::utility::zero<ValueType>(), storm::utility::one<ValueType>())) { void initializeBoundsForNextState(std::pair<ValueType, ValueType> const& vals = std::pair<ValueType, ValueType>(storm::utility::zero<ValueType>(), storm::utility::one<ValueType>())) {
lowerBoundsPerState.push_back(vals.first); lowerBoundsPerState.push_back(vals.first);
upperBoundsPerState.push_back(vals.second); upperBoundsPerState.push_back(vals.second);
} }
void initializeActionBoundsForNextAction(std::pair<ValueType, ValueType> const& vals = std::pair<ValueType, ValueType>(storm::utility::zero<ValueType>(), storm::utility::one<ValueType>())) { void initializeBoundsForNextAction(std::pair<ValueType, ValueType> const& vals = std::pair<ValueType, ValueType>(storm::utility::zero<ValueType>(), storm::utility::one<ValueType>())) {
lowerBoundsPerAction.push_back(vals.first); lowerBoundsPerAction.push_back(vals.first);
upperBoundsPerAction.push_back(vals.second); upperBoundsPerAction.push_back(vals.second);
} }
@ -274,14 +279,18 @@ namespace storm {
StateType const& rowGroup = explorationInformation.getRowGroup(state); StateType const& rowGroup = explorationInformation.getRowGroup(state);
if (lowerBoundsPerState[rowGroup] < newLowerValue) { if (lowerBoundsPerState[rowGroup] < newLowerValue) {
lowerBoundsPerState[rowGroup] = newLowerValue; lowerBoundsPerState[rowGroup] = newLowerValue;
return true;
} }
return false;
} }
bool setNewUpperBoundOfStateIfLessThanOld(StateType const& state, ExplorationInformation const& explorationInformation, ValueType const& newUpperValue) { bool setNewUpperBoundOfStateIfLessThanOld(StateType const& state, ExplorationInformation const& explorationInformation, ValueType const& newUpperValue) {
StateType const& rowGroup = explorationInformation.getRowGroup(state); StateType const& rowGroup = explorationInformation.getRowGroup(state);
if (newUpperValue < upperBoundsPerState[rowGroup]) { if (newUpperValue < upperBoundsPerState[rowGroup]) {
upperBoundsPerState[rowGroup] = newUpperValue; upperBoundsPerState[rowGroup] = newUpperValue;
return true;
} }
return false;
} }
}; };
@ -306,6 +315,7 @@ namespace storm {
void updateProbabilityOfAction(StateType const& state, ActionType const& action, ExplorationInformation const& explorationInformation, BoundValues& bounds) const; void updateProbabilityOfAction(StateType const& state, ActionType const& action, ExplorationInformation const& explorationInformation, BoundValues& bounds) const;
std::pair<ValueType, ValueType> computeBoundsOfAction(ActionType const& action, ExplorationInformation const& explorationInformation, BoundValues const& bounds) const; std::pair<ValueType, ValueType> computeBoundsOfAction(ActionType const& action, ExplorationInformation const& explorationInformation, BoundValues const& bounds) const;
ValueType computeUpperBoundOverAllOtherActions(StateType const& state, ActionType const& action, ExplorationInformation const& explorationInformation, BoundValues const& bounds) const;
std::pair<ValueType, ValueType> computeBoundsOfState(StateType const& currentStateId, ExplorationInformation const& explorationInformation, BoundValues const& bounds) const; std::pair<ValueType, ValueType> computeBoundsOfState(StateType const& currentStateId, ExplorationInformation const& explorationInformation, BoundValues const& bounds) const;
ValueType computeLowerBoundOfAction(ActionType const& action, ExplorationInformation const& explorationInformation, BoundValues const& bounds) const; ValueType computeLowerBoundOfAction(ActionType const& action, ExplorationInformation const& explorationInformation, BoundValues const& bounds) const;
ValueType computeUpperBoundOfAction(ActionType const& action, ExplorationInformation const& explorationInformation, BoundValues const& bounds) const; ValueType computeUpperBoundOfAction(ActionType const& action, ExplorationInformation const& explorationInformation, BoundValues const& bounds) const;
@ -319,7 +329,7 @@ namespace storm {
storm::generator::VariableInformation variableInformation; storm::generator::VariableInformation variableInformation;
// The random number generator. // The random number generator.
std::default_random_engine randomGenerator; mutable std::default_random_engine randomGenerator;
// A comparator used to determine whether values are equal. // A comparator used to determine whether values are equal.
storm::utility::ConstantsComparator<ValueType> comparator; storm::utility::ConstantsComparator<ValueType> comparator;

|||||||
100:0
Loading…
Cancel
Save