Browse Source

store the choiceValues in the iterations and the maybeStates then return it to the SparseMdpPrctlModelChecker.cpp

tempestpy_adaptions
Lukas Posch 3 years ago
committed by Stefan Pranger
parent
commit
def9e65525
  1. 39
      src/storm/modelchecker/prctl/helper/SparseMdpPrctlHelper.cpp
  2. 2
      src/storm/modelchecker/prctl/helper/SparseMdpPrctlHelper.h

39
src/storm/modelchecker/prctl/helper/SparseMdpPrctlHelper.cpp

@ -134,16 +134,25 @@ namespace storm {
}
template<typename ValueType>
std::vector<ValueType> SparseMdpPrctlHelper<ValueType>::computeNextProbabilities(Environment const& env, OptimizationDirection dir, storm::storage::SparseMatrix<ValueType> const& transitionMatrix, storm::storage::BitVector const& nextStates) {
MDPSparseModelCheckingHelperReturnType<ValueType> SparseMdpPrctlHelper<ValueType>::computeNextProbabilities(Environment const& env, storm::solver::SolveGoal<ValueType>&& goal, OptimizationDirection dir, storm::storage::SparseMatrix<ValueType> const& transitionMatrix, storm::storage::BitVector const& nextStates) {
// Create the vector with which to multiply and initialize it correctly.
std::vector<ValueType> result(transitionMatrix.getRowGroupCount());
storm::utility::vector::setVectorValues(result, nextStates, storm::utility::one<ValueType>());
std::vector<ValueType> choiceValues = std::vector<ValueType>(transitionMatrix.getRowCount(), storm::utility::zero<ValueType>());
storm::storage::BitVector allStates = storm::storage::BitVector(transitionMatrix.getRowGroupCount(), true);
auto multiplier = storm::solver::MultiplierFactory<ValueType>().create(env, transitionMatrix);
multiplier->multiplyAndReduce(env, dir, result, nullptr, result);
return result;
if(goal.isShieldingTask()) {
multiplier->multiply(env, result, nullptr, choiceValues);
multiplier->reduce(env, goal.direction(), choiceValues, transitionMatrix.getRowGroupIndices(), result, nullptr);
}
else {
multiplier->multiplyAndReduce(env, dir, result, nullptr, result);
}
return MDPSparseModelCheckingHelperReturnType<ValueType>(std::move(result), std::move(allStates), nullptr, std::move(choiceValues));
}
template<typename ValueType>
@ -608,6 +617,10 @@ namespace storm {
// Check if the values of the maybe states are relevant for the SolveGoal
bool maybeStatesNotRelevant = goal.hasRelevantValues() && goal.relevantValues().isDisjointFrom(qualitativeStateSets.maybeStates);
// create multiplier and execute the calculation for 1 additional step
auto multiplier = storm::solver::MultiplierFactory<ValueType>().create(env, transitionMatrix);
std::vector<ValueType> choiceValues = std::vector<ValueType>(transitionMatrix.getRowCount(), storm::utility::zero<ValueType>());
// Check whether we need to compute exact probabilities for some states.
if (qualitative || maybeStatesNotRelevant) {
@ -632,10 +645,10 @@ namespace storm {
// Otherwise, we compute the standard equations.
computeFixedPointSystemUntilProbabilities(goal, transitionMatrix, qualitativeStateSets, submatrix, b);
}
// Now compute the results for the maybe states.
MaybeStateResult<ValueType> resultForMaybeStates = computeValuesForMaybeStates(env, std::move(goal), std::move(submatrix), b, produceScheduler, hintInformation);
// If we eliminated end components, we need to extract the result differently.
if (ecInformation && ecInformation.get().getEliminatedEndComponents()) {
ecInformation.get().setValues(result, qualitativeStateSets.maybeStates, resultForMaybeStates.getValues());
@ -649,6 +662,10 @@ namespace storm {
extractSchedulerChoices(*scheduler, resultForMaybeStates.getScheduler(), qualitativeStateSets.maybeStates);
}
}
if (goal.isShieldingTask()) {
multiplier->multiply(env, result, &b, choiceValues);
multiplier->reduce(env, goal.direction(), choiceValues, transitionMatrix.getRowGroupIndices(), result, nullptr);
}
}
}
@ -664,7 +681,7 @@ namespace storm {
STORM_LOG_ASSERT((!produceScheduler && !scheduler) || scheduler->isMemorylessScheduler(), "Expected a memoryless scheduler");
// Return result.
return MDPSparseModelCheckingHelperReturnType<ValueType>(std::move(result), std::move(scheduler));
return MDPSparseModelCheckingHelperReturnType<ValueType>(std::move(result), std::move(qualitativeStateSets.maybeStates), std::move(scheduler), std::move(choiceValues));
}
template<typename ValueType>
@ -1173,7 +1190,11 @@ namespace storm {
STORM_LOG_ASSERT((!produceScheduler && !scheduler) || scheduler->isMemorylessScheduler(), "Expected a memoryless scheduler");
return MDPSparseModelCheckingHelperReturnType<ValueType>(std::move(result), std::move(scheduler));
//TODO: keep nullvector for choiceValues in computeReachabilityRewardsHelper? - nullptr causes ERROR
//std::vector<ValueType> choiceValues = std::vector<ValueType>(transitionMatrix.getRowCount(), storm::utility::zero<ValueType>());
std::vector<ValueType> choiceValues = std::vector<ValueType>();
return MDPSparseModelCheckingHelperReturnType<ValueType>(std::move(result), std::move(qualitativeStateSets.maybeStates), std::move(scheduler), std::move(choiceValues));
}
template<typename ValueType>

2
src/storm/modelchecker/prctl/helper/SparseMdpPrctlHelper.h

@ -41,7 +41,7 @@ namespace storm {
static std::map<storm::storage::sparse::state_type, ValueType> computeRewardBoundedValues(Environment const& env, OptimizationDirection dir, rewardbounded::MultiDimensionalRewardUnfolding<ValueType, true>& rewardUnfolding, storm::storage::BitVector const& initialStates);
static std::vector<ValueType> computeNextProbabilities(Environment const& env, OptimizationDirection dir, storm::storage::SparseMatrix<ValueType> const& transitionMatrix, storm::storage::BitVector const& nextStates);
static MDPSparseModelCheckingHelperReturnType<ValueType> computeNextProbabilities(Environment const& env, storm::solver::SolveGoal<ValueType>&& goal, OptimizationDirection dir, storm::storage::SparseMatrix<ValueType> const& transitionMatrix, storm::storage::BitVector const& nextStates);
static MDPSparseModelCheckingHelperReturnType<ValueType> computeUntilProbabilities(Environment const& env, storm::solver::SolveGoal<ValueType>&& goal, storm::storage::SparseMatrix<ValueType> const& transitionMatrix, storm::storage::SparseMatrix<ValueType> const& backwardTransitions, storm::storage::BitVector const& phiStates, storm::storage::BitVector const& psiStates, bool qualitative, bool produceScheduler, ModelCheckerHint const& hint = ModelCheckerHint());

Loading…
Cancel
Save