Browse Source

refactored helper code

tempestpy_adaptions
Stefan Pranger 3 years ago
parent
commit
6a5f626259
  1. 86
      src/storm/modelchecker/helper/finitehorizon/SparseNondeterministicStepBoundedHorizonHelper.cpp
  2. 3
      src/storm/modelchecker/helper/finitehorizon/SparseNondeterministicStepBoundedHorizonHelper.h

86
src/storm/modelchecker/helper/finitehorizon/SparseNondeterministicStepBoundedHorizonHelper.cpp

@ -17,6 +17,30 @@ namespace storm {
namespace modelchecker {
namespace helper {
template<typename ValueType>
void SparseNondeterministicStepBoundedHorizonHelper<ValueType>::getMaybeStatesRowGroupSizes(storm::storage::SparseMatrix<ValueType> const& transitionMatrix, storm::storage::BitVector const& maybeStates, std::vector<uint64_t>& maybeStatesRowGroupSizes, uint& choiceValuesCounter) {
std::vector<uint64_t> rowGroupIndices = transitionMatrix.getRowGroupIndices();
for(uint counter = 0; counter < maybeStates.size(); counter++) {
if(maybeStates.get(counter)) {
maybeStatesRowGroupSizes.push_back(rowGroupIndices.at(counter));
choiceValuesCounter += transitionMatrix.getRowGroupSize(counter);
}
}
}
template<typename ValueType>
void SparseNondeterministicStepBoundedHorizonHelper<ValueType>::getChoiceValues(storm::storage::SparseMatrix<ValueType> const& transitionMatrix, storm::storage::BitVector const& maybeStates, std::vector<ValueType> const& choiceValues, std::vector<ValueType>& allChoices) {
auto choice_it = choiceValues.begin();
for(uint state = 0; state < transitionMatrix.getRowGroupIndices().size() - 1; state++) {
uint rowGroupSize = transitionMatrix.getRowGroupSize(state);
if (maybeStates.get(state)) {
for(uint choice = 0; choice < rowGroupSize; choice++, choice_it++) {
allChoices.at(transitionMatrix.getRowGroupIndices().at(state) + choice) = *choice_it;
}
}
}
}
template<typename ValueType>
SparseNondeterministicStepBoundedHorizonHelper<ValueType>::SparseNondeterministicStepBoundedHorizonHelper(/*storm::storage::SparseMatrix<ValueType> const& transitionMatrix, storm::storage::SparseMatrix<ValueType> const& backwardTransitions*/)
//transitionMatrix(transitionMatrix), backwardTransitions(backwardTransitions)
@ -60,34 +84,21 @@ namespace storm {
std::vector<ValueType> subresult(maybeStates.getNumberOfSetBits());
auto multiplier = storm::solver::MultiplierFactory<ValueType>().create(env, submatrix);
std::vector<uint64_t> rowGroupIndices = transitionMatrix.getRowGroupIndices();
std::vector<uint64_t> maybeStatesRowGroupSizes;
uint choiceValuesCounter;
getMaybeStatesRowGroupSizes(transitionMatrix, maybeStates, maybeStatesRowGroupSizes, choiceValuesCounter);
choiceValues = std::vector<ValueType>(choiceValuesCounter, storm::utility::zero<ValueType>());
if (lowerBound == 0) {
if(goal.isShieldingTask())
{
std::vector<storm::storage::SparseMatrix<double>::index_type> rowGroupIndices = transitionMatrix.getRowGroupIndices();
std::vector<storm::storage::SparseMatrix<double>::index_type> reducedRowGroupIndices;
uint sizeChoiceValues = 0;
for(uint counter = 0; counter < maybeStates.size(); counter++) {
if(maybeStates.get(counter)) {
sizeChoiceValues += transitionMatrix.getRowGroupSize(counter);
reducedRowGroupIndices.push_back(rowGroupIndices.at(counter));
}
}
choiceValues = std::vector<ValueType>(sizeChoiceValues, storm::utility::zero<ValueType>());
multiplier->repeatedMultiplyAndReduceWithChoices(env, goal.direction(), subresult, &b, upperBound, nullptr, choiceValues, reducedRowGroupIndices);
multiplier->repeatedMultiplyAndReduceWithChoices(env, goal.direction(), subresult, &b, upperBound, nullptr, choiceValues, maybeStatesRowGroupSizes);
// fill up choicesValues for shields
std::vector<ValueType> allChoices = std::vector<ValueType>(transitionMatrix.getRowGroupIndices().at(transitionMatrix.getRowGroupIndices().size() - 1), storm::utility::zero<ValueType>());
auto choice_it = choiceValues.begin();
for(uint state = 0; state < transitionMatrix.getRowGroupIndices().size() - 1; state++) {
uint rowGroupSize = transitionMatrix.getRowGroupIndices().at(state + 1) - transitionMatrix.getRowGroupIndices().at(state);
if (maybeStates.get(state)) {
for(uint choice = 0; choice < rowGroupSize; choice++, choice_it++) {
allChoices.at(transitionMatrix.getRowGroupIndices().at(state) + choice) = *choice_it;
}
}
}
std::vector<ValueType> allChoices = std::vector<ValueType>(transitionMatrix.getRowCount(), storm::utility::zero<ValueType>());
getChoiceValues(transitionMatrix, maybeStates, choiceValues, allChoices);
choiceValues = allChoices;
} else {
multiplier->repeatedMultiplyAndReduce(env, goal.direction(), subresult, &b, upperBound);
@ -95,34 +106,15 @@ namespace storm {
} else {
if(goal.isShieldingTask())
{
std::vector<storm::storage::SparseMatrix<double>::index_type> rowGroupIndices = transitionMatrix.getRowGroupIndices();
std::vector<storm::storage::SparseMatrix<double>::index_type> reducedRowGroupIndices;
uint sizeChoiceValues = 0;
for(uint counter = 0; counter < maybeStates.size(); counter++) {
if(maybeStates.get(counter)) {
sizeChoiceValues += transitionMatrix.getRowGroupSize(counter);
reducedRowGroupIndices.push_back(rowGroupIndices.at(counter));
}
}
choiceValues = std::vector<ValueType>(sizeChoiceValues, storm::utility::zero<ValueType>());
multiplier->repeatedMultiplyAndReduceWithChoices(env, goal.direction(), subresult, &b, upperBound - lowerBound + 1, nullptr, choiceValues, reducedRowGroupIndices);
multiplier->repeatedMultiplyAndReduceWithChoices(env, goal.direction(), subresult, &b, upperBound - lowerBound + 1, nullptr, choiceValues, maybeStatesRowGroupSizes);
storm::storage::SparseMatrix<ValueType> submatrix = transitionMatrix.getSubmatrix(true, maybeStates, maybeStates, false);
auto multiplier = storm::solver::MultiplierFactory<ValueType>().create(env, submatrix);
b = std::vector<ValueType>(b.size(), storm::utility::zero<ValueType>());
multiplier->repeatedMultiplyAndReduceWithChoices(env, goal.direction(), subresult, &b, lowerBound - 1, nullptr, choiceValues, reducedRowGroupIndices);
multiplier->repeatedMultiplyAndReduceWithChoices(env, goal.direction(), subresult, &b, lowerBound - 1, nullptr, choiceValues, maybeStatesRowGroupSizes);
// fill up choicesValues for shields
std::vector<ValueType> allChoices = std::vector<ValueType>(transitionMatrix.getRowGroupIndices().at(transitionMatrix.getRowGroupIndices().size() - 1), storm::utility::zero<ValueType>());
auto choice_it = choiceValues.begin();
for(uint state = 0; state < transitionMatrix.getRowGroupIndices().size() - 1; state++) {
uint rowGroupSize = transitionMatrix.getRowGroupIndices().at(state + 1) - transitionMatrix.getRowGroupIndices().at(state);
if (maybeStates.get(state)) {
for(uint choice = 0; choice < rowGroupSize; choice++, choice_it++) {
allChoices.at(transitionMatrix.getRowGroupIndices().at(state) + choice) = *choice_it;
}
}
}
std::vector<ValueType> allChoices = std::vector<ValueType>(transitionMatrix.getRowCount(), storm::utility::zero<ValueType>());
getChoiceValues(transitionMatrix, maybeStates, choiceValues, allChoices);
choiceValues = allChoices;
} else {
multiplier->repeatedMultiplyAndReduce(env, goal.direction(), subresult, &b, upperBound - lowerBound + 1);
@ -147,4 +139,4 @@ namespace storm {
template class SparseNondeterministicStepBoundedHorizonHelper<storm::RationalNumber>;
}
}
}
}

3
src/storm/modelchecker/helper/finitehorizon/SparseNondeterministicStepBoundedHorizonHelper.h

@ -16,6 +16,9 @@ namespace storm {
class SparseNondeterministicStepBoundedHorizonHelper {
public:
SparseNondeterministicStepBoundedHorizonHelper();
void getMaybeStatesRowGroupSizes(storm::storage::SparseMatrix<ValueType> const& transitionMatrix, storm::storage::BitVector const& maybeStates, std::vector<uint64_t>& maybeStatesRowGroupSizes, uint& choiceValuesCounter);
void getChoiceValues(storm::storage::SparseMatrix<ValueType> const& transitionMatrix, storm::storage::BitVector const& maybeStates, std::vector<ValueType> const& choiceValues, std::vector<ValueType>& allChoices);
std::vector<ValueType> compute(Environment const& env, storm::solver::SolveGoal<ValueType>&& goal, storm::storage::SparseMatrix<ValueType> const& transitionMatrix, storm::storage::SparseMatrix<ValueType> const& backwardTransitions, storm::storage::BitVector const& phiStates, storm::storage::BitVector const& psiStates, uint64_t lowerBound, uint64_t upperBound, ModelCheckerHint const& hint = ModelCheckerHint(), storm::storage::BitVector& resultMaybeStates = nullptr, std::vector<ValueType>& choiceValues = nullptr);
private:
};

Loading…
Cancel
Save