diff --git a/src/storm/modelchecker/helper/finitehorizon/SparseNondeterministicStepBoundedHorizonHelper.cpp b/src/storm/modelchecker/helper/finitehorizon/SparseNondeterministicStepBoundedHorizonHelper.cpp index 68d7631d0..0d88c456e 100644 --- a/src/storm/modelchecker/helper/finitehorizon/SparseNondeterministicStepBoundedHorizonHelper.cpp +++ b/src/storm/modelchecker/helper/finitehorizon/SparseNondeterministicStepBoundedHorizonHelper.cpp @@ -17,6 +17,30 @@ namespace storm { namespace modelchecker { namespace helper { + template + void SparseNondeterministicStepBoundedHorizonHelper::getMaybeStatesRowGroupSizes(storm::storage::SparseMatrix const& transitionMatrix, storm::storage::BitVector const& maybeStates, std::vector& maybeStatesRowGroupSizes, uint& choiceValuesCounter) { + std::vector rowGroupIndices = transitionMatrix.getRowGroupIndices(); + for(uint counter = 0; counter < maybeStates.size(); counter++) { + if(maybeStates.get(counter)) { + maybeStatesRowGroupSizes.push_back(rowGroupIndices.at(counter)); + choiceValuesCounter += transitionMatrix.getRowGroupSize(counter); + } + } + } + + template + void SparseNondeterministicStepBoundedHorizonHelper::getChoiceValues(storm::storage::SparseMatrix const& transitionMatrix, storm::storage::BitVector const& maybeStates, std::vector const& choiceValues, std::vector& allChoices) { + auto choice_it = choiceValues.begin(); + for(uint state = 0; state < transitionMatrix.getRowGroupIndices().size() - 1; state++) { + uint rowGroupSize = transitionMatrix.getRowGroupSize(state); + if (maybeStates.get(state)) { + for(uint choice = 0; choice < rowGroupSize; choice++, choice_it++) { + allChoices.at(transitionMatrix.getRowGroupIndices().at(state) + choice) = *choice_it; + } + } + } + } + template SparseNondeterministicStepBoundedHorizonHelper::SparseNondeterministicStepBoundedHorizonHelper(/*storm::storage::SparseMatrix const& transitionMatrix, storm::storage::SparseMatrix const& backwardTransitions*/) //transitionMatrix(transitionMatrix), backwardTransitions(backwardTransitions) @@ -25,7 +49,7 @@ namespace storm { } template - std::vector SparseNondeterministicStepBoundedHorizonHelper::compute(Environment const& env, storm::solver::SolveGoal&& goal, storm::storage::SparseMatrix const& transitionMatrix, storm::storage::SparseMatrix const& backwardTransitions, storm::storage::BitVector const& phiStates, storm::storage::BitVector const& psiStates, uint64_t lowerBound, uint64_t upperBound, ModelCheckerHint const& hint) + std::vector SparseNondeterministicStepBoundedHorizonHelper::compute(Environment const& env, storm::solver::SolveGoal&& goal, storm::storage::SparseMatrix const& transitionMatrix, storm::storage::SparseMatrix const& backwardTransitions, storm::storage::BitVector const& phiStates, storm::storage::BitVector const& psiStates, uint64_t lowerBound, uint64_t upperBound, ModelCheckerHint const& hint, storm::storage::BitVector& resultMaybeStates, std::vector& choiceValues) { std::vector result(transitionMatrix.getRowGroupCount(), storm::utility::zero()); storm::storage::BitVector makeZeroColumns; @@ -60,14 +84,45 @@ namespace storm { std::vector subresult(maybeStates.getNumberOfSetBits()); auto multiplier = storm::solver::MultiplierFactory().create(env, submatrix); + + std::vector rowGroupIndices = transitionMatrix.getRowGroupIndices(); + std::vector maybeStatesRowGroupSizes; + uint choiceValuesCounter; + getMaybeStatesRowGroupSizes(transitionMatrix, maybeStates, maybeStatesRowGroupSizes, choiceValuesCounter); + choiceValues = std::vector(choiceValuesCounter, storm::utility::zero()); + if (lowerBound == 0) { - multiplier->repeatedMultiplyAndReduce(env, goal.direction(), subresult, &b, upperBound); + if(goal.isShieldingTask()) + { + multiplier->repeatedMultiplyAndReduceWithChoices(env, goal.direction(), subresult, &b, upperBound, nullptr, choiceValues, maybeStatesRowGroupSizes); + + // fill up choicesValues for shields + std::vector allChoices = std::vector(transitionMatrix.getRowCount(), storm::utility::zero()); + getChoiceValues(transitionMatrix, maybeStates, choiceValues, allChoices); + choiceValues = allChoices; + } else { + multiplier->repeatedMultiplyAndReduce(env, goal.direction(), subresult, &b, upperBound); + } } else { - multiplier->repeatedMultiplyAndReduce(env, goal.direction(), subresult, &b, upperBound - lowerBound + 1); - storm::storage::SparseMatrix submatrix = transitionMatrix.getSubmatrix(true, maybeStates, maybeStates, false); - auto multiplier = storm::solver::MultiplierFactory().create(env, submatrix); - b = std::vector(b.size(), storm::utility::zero()); - multiplier->repeatedMultiplyAndReduce(env, goal.direction(), subresult, &b, lowerBound - 1); + if(goal.isShieldingTask()) + { + multiplier->repeatedMultiplyAndReduceWithChoices(env, goal.direction(), subresult, &b, upperBound - lowerBound + 1, nullptr, choiceValues, maybeStatesRowGroupSizes); + storm::storage::SparseMatrix submatrix = transitionMatrix.getSubmatrix(true, maybeStates, maybeStates, false); + auto multiplier = storm::solver::MultiplierFactory().create(env, submatrix); + b = std::vector(b.size(), storm::utility::zero()); + multiplier->repeatedMultiplyAndReduceWithChoices(env, goal.direction(), subresult, &b, lowerBound - 1, nullptr, choiceValues, maybeStatesRowGroupSizes); + + // fill up choicesValues for shields + std::vector allChoices = std::vector(transitionMatrix.getRowCount(), storm::utility::zero()); + getChoiceValues(transitionMatrix, maybeStates, choiceValues, allChoices); + choiceValues = allChoices; + } else { + multiplier->repeatedMultiplyAndReduce(env, goal.direction(), subresult, &b, upperBound - lowerBound + 1); + storm::storage::SparseMatrix submatrix = transitionMatrix.getSubmatrix(true, maybeStates, maybeStates, false); + auto multiplier = storm::solver::MultiplierFactory().create(env, submatrix); + b = std::vector(b.size(), storm::utility::zero()); + multiplier->repeatedMultiplyAndReduce(env, goal.direction(), subresult, &b, lowerBound - 1); + } } // Set the values of the resulting vector accordingly. storm::utility::vector::setVectorValues(result, maybeStates, subresult); @@ -75,12 +130,13 @@ namespace storm { if (lowerBound == 0) { storm::utility::vector::setVectorValues(result, psiStates, storm::utility::one()); } + + resultMaybeStates = maybeStates; return result; } - template class SparseNondeterministicStepBoundedHorizonHelper; template class SparseNondeterministicStepBoundedHorizonHelper; } } -} \ No newline at end of file +} diff --git a/src/storm/modelchecker/helper/finitehorizon/SparseNondeterministicStepBoundedHorizonHelper.h b/src/storm/modelchecker/helper/finitehorizon/SparseNondeterministicStepBoundedHorizonHelper.h index 98fc49c2c..0ef64fe2c 100644 --- a/src/storm/modelchecker/helper/finitehorizon/SparseNondeterministicStepBoundedHorizonHelper.h +++ b/src/storm/modelchecker/helper/finitehorizon/SparseNondeterministicStepBoundedHorizonHelper.h @@ -4,6 +4,7 @@ #include "storm/modelchecker/hints/ModelCheckerHint.h" #include "storm/modelchecker/prctl/helper/SolutionType.h" #include "storm/storage/SparseMatrix.h" +#include "storm/storage/BitVector.h" #include "storm/utility/solver.h" #include "storm/solver/SolveGoal.h" @@ -15,7 +16,10 @@ namespace storm { class SparseNondeterministicStepBoundedHorizonHelper { public: SparseNondeterministicStepBoundedHorizonHelper(); - std::vector compute(Environment const& env, storm::solver::SolveGoal&& goal, storm::storage::SparseMatrix const& transitionMatrix, storm::storage::SparseMatrix const& backwardTransitions, storm::storage::BitVector const& phiStates, storm::storage::BitVector const& psiStates, uint64_t lowerBound, uint64_t upperBound, ModelCheckerHint const& hint = ModelCheckerHint()); + + void getMaybeStatesRowGroupSizes(storm::storage::SparseMatrix const& transitionMatrix, storm::storage::BitVector const& maybeStates, std::vector& maybeStatesRowGroupSizes, uint& choiceValuesCounter); + void getChoiceValues(storm::storage::SparseMatrix const& transitionMatrix, storm::storage::BitVector const& maybeStates, std::vector const& choiceValues, std::vector& allChoices); + std::vector compute(Environment const& env, storm::solver::SolveGoal&& goal, storm::storage::SparseMatrix const& transitionMatrix, storm::storage::SparseMatrix const& backwardTransitions, storm::storage::BitVector const& phiStates, storm::storage::BitVector const& psiStates, uint64_t lowerBound, uint64_t upperBound, ModelCheckerHint const& hint = ModelCheckerHint(), storm::storage::BitVector& resultMaybeStates = nullptr, std::vector& choiceValues = nullptr); private: }; diff --git a/src/storm/modelchecker/prctl/SparseMdpPrctlModelChecker.cpp b/src/storm/modelchecker/prctl/SparseMdpPrctlModelChecker.cpp index 8cc954f91..ffefa0e78 100644 --- a/src/storm/modelchecker/prctl/SparseMdpPrctlModelChecker.cpp +++ b/src/storm/modelchecker/prctl/SparseMdpPrctlModelChecker.cpp @@ -24,6 +24,10 @@ #include "storm/solver/SolveGoal.h" +#include "storm/storage/BitVector.h" + +#include "storm/shields/ShieldHandling.h" + #include "storm/settings/modules/GeneralSettings.h" #include "storm/exceptions/InvalidStateException.h" @@ -40,7 +44,7 @@ namespace storm { SparseMdpPrctlModelChecker::SparseMdpPrctlModelChecker(SparseMdpModelType const& model) : SparsePropositionalModelChecker(model) { // Intentionally left empty. } - + template bool SparseMdpPrctlModelChecker::canHandleStatic(CheckTask const& checkTask, bool* requiresSingleInitialState) { storm::logic::Formula const& formula = checkTask.getFormula(); @@ -57,7 +61,7 @@ namespace storm { } return false; } - + template bool SparseMdpPrctlModelChecker::canHandle(CheckTask const& checkTask) const { bool requiresSingleInitialState = false; @@ -67,7 +71,7 @@ namespace storm { return false; } } - + template std::unique_ptr SparseMdpPrctlModelChecker::computeBoundedUntilProbabilities(Environment const& env, CheckTask const& checkTask) { storm::logic::BoundedUntilFormula const& pathFormula = checkTask.getFormula(); @@ -92,21 +96,36 @@ namespace storm { ExplicitQualitativeCheckResult const& leftResult = leftResultPointer->asExplicitQualitativeCheckResult(); ExplicitQualitativeCheckResult const& rightResult = rightResultPointer->asExplicitQualitativeCheckResult(); storm::modelchecker::helper::SparseNondeterministicStepBoundedHorizonHelper helper; - std::vector numericResult = helper.compute(env, storm::solver::SolveGoal(this->getModel(), checkTask), this->getModel().getTransitionMatrix(), this->getModel().getBackwardTransitions(), leftResult.getTruthValuesVector(), rightResult.getTruthValuesVector(), pathFormula.getNonStrictLowerBound(), pathFormula.getNonStrictUpperBound(), checkTask.getHint()); + std::vector numericResult; + + //This works only with empty vectors, no nullptr + storm::storage::BitVector resultMaybeStates; + std::vector choiceValues; + + numericResult = helper.compute(env, storm::solver::SolveGoal(this->getModel(), checkTask), this->getModel().getTransitionMatrix(), this->getModel().getBackwardTransitions(), leftResult.getTruthValuesVector(), rightResult.getTruthValuesVector(), pathFormula.getNonStrictLowerBound(), pathFormula.getNonStrictUpperBound(), checkTask.getHint(), resultMaybeStates, choiceValues); + if(checkTask.isShieldingTask()) { + tempest::shields::createShield(std::make_shared>(this->getModel()), std::move(choiceValues), checkTask.getShieldingExpression(), checkTask.getOptimizationDirection(), std::move(resultMaybeStates), storm::storage::BitVector(resultMaybeStates.size(), true)); + } return std::unique_ptr(new ExplicitQuantitativeCheckResult(std::move(numericResult))); } } - + template std::unique_ptr SparseMdpPrctlModelChecker::computeNextProbabilities(Environment const& env, CheckTask const& checkTask) { storm::logic::NextFormula const& pathFormula = checkTask.getFormula(); STORM_LOG_THROW(checkTask.isOptimizationDirectionSet(), storm::exceptions::InvalidPropertyException, "Formula needs to specify whether minimal or maximal values are to be computed on nondeterministic model."); std::unique_ptr subResultPointer = this->check(env, pathFormula.getSubformula()); ExplicitQualitativeCheckResult const& subResult = subResultPointer->asExplicitQualitativeCheckResult(); - std::vector numericResult = storm::modelchecker::helper::SparseMdpPrctlHelper::computeNextProbabilities(env, checkTask.getOptimizationDirection(), this->getModel().getTransitionMatrix(), subResult.getTruthValuesVector()); - return std::unique_ptr(new ExplicitQuantitativeCheckResult(std::move(numericResult))); + auto ret = storm::modelchecker::helper::SparseMdpPrctlHelper::computeNextProbabilities(env, storm::solver::SolveGoal(this->getModel(), checkTask), checkTask.getOptimizationDirection(), this->getModel().getTransitionMatrix(), subResult.getTruthValuesVector()); + std::unique_ptr result(new ExplicitQuantitativeCheckResult(std::move(ret.values))); + if(checkTask.isShieldingTask()) { + tempest::shields::createShield(std::make_shared>(this->getModel()), std::move(ret.choiceValues), checkTask.getShieldingExpression(), checkTask.getOptimizationDirection(), std::move(ret.maybeStates), storm::storage::BitVector(ret.maybeStates.size(), true)); + } else if (checkTask.isProduceSchedulersSet() && ret.scheduler) { + result->asExplicitQuantitativeCheckResult().setScheduler(std::move(ret.scheduler)); + } + return result; } - + template std::unique_ptr SparseMdpPrctlModelChecker::computeUntilProbabilities(Environment const& env, CheckTask const& checkTask) { storm::logic::UntilFormula const& pathFormula = checkTask.getFormula(); @@ -117,12 +136,14 @@ namespace storm { ExplicitQualitativeCheckResult const& rightResult = rightResultPointer->asExplicitQualitativeCheckResult(); auto ret = storm::modelchecker::helper::SparseMdpPrctlHelper::computeUntilProbabilities(env, storm::solver::SolveGoal(this->getModel(), checkTask), this->getModel().getTransitionMatrix(), this->getModel().getBackwardTransitions(), leftResult.getTruthValuesVector(), rightResult.getTruthValuesVector(), checkTask.isQualitativeSet(), checkTask.isProduceSchedulersSet(), checkTask.getHint()); std::unique_ptr result(new ExplicitQuantitativeCheckResult(std::move(ret.values))); - if (checkTask.isProduceSchedulersSet() && ret.scheduler) { + if(checkTask.isShieldingTask()) { + tempest::shields::createShield(std::make_shared>(this->getModel()), std::move(ret.choiceValues), checkTask.getShieldingExpression(), checkTask.getOptimizationDirection(), std::move(ret.maybeStates), storm::storage::BitVector(ret.maybeStates.size(), true)); + } else if (checkTask.isProduceSchedulersSet() && ret.scheduler) { result->asExplicitQuantitativeCheckResult().setScheduler(std::move(ret.scheduler)); } return result; } - + template std::unique_ptr SparseMdpPrctlModelChecker::computeGloballyProbabilities(Environment const& env, CheckTask const& checkTask) { storm::logic::GloballyFormula const& pathFormula = checkTask.getFormula(); @@ -131,12 +152,14 @@ namespace storm { ExplicitQualitativeCheckResult const& subResult = subResultPointer->asExplicitQualitativeCheckResult(); auto ret = storm::modelchecker::helper::SparseMdpPrctlHelper::computeGloballyProbabilities(env, storm::solver::SolveGoal(this->getModel(), checkTask), this->getModel().getTransitionMatrix(), this->getModel().getBackwardTransitions(), subResult.getTruthValuesVector(), checkTask.isQualitativeSet(), checkTask.isProduceSchedulersSet()); std::unique_ptr result(new ExplicitQuantitativeCheckResult(std::move(ret.values))); - if (checkTask.isProduceSchedulersSet() && ret.scheduler) { + if(checkTask.isShieldingTask()) { + tempest::shields::createShield(std::make_shared>(this->getModel()), std::move(ret.choiceValues), checkTask.getShieldingExpression(), checkTask.getOptimizationDirection(),subResult.getTruthValuesVector(), storm::storage::BitVector(ret.maybeStates.size(), true)); + } else if (checkTask.isProduceSchedulersSet() && ret.scheduler) { result->asExplicitQuantitativeCheckResult().setScheduler(std::move(ret.scheduler)); } return result; } - + template std::unique_ptr SparseMdpPrctlModelChecker::computeConditionalProbabilities(Environment const& env, CheckTask const& checkTask) { storm::logic::ConditionalFormula const& conditionalFormula = checkTask.getFormula(); @@ -152,7 +175,7 @@ namespace storm { return storm::modelchecker::helper::SparseMdpPrctlHelper::computeConditionalProbabilities(env, storm::solver::SolveGoal(this->getModel(), checkTask), this->getModel().getTransitionMatrix(), this->getModel().getBackwardTransitions(), leftResult.getTruthValuesVector(), rightResult.getTruthValuesVector()); } - + template std::unique_ptr SparseMdpPrctlModelChecker::computeCumulativeRewards(Environment const& env, storm::logic::RewardMeasureType, CheckTask const& checkTask) { storm::logic::CumulativeRewardFormula const& rewardPathFormula = checkTask.getFormula(); @@ -176,7 +199,7 @@ namespace storm { return std::unique_ptr(new ExplicitQuantitativeCheckResult(std::move(numericResult))); } } - + template std::unique_ptr SparseMdpPrctlModelChecker::computeInstantaneousRewards(Environment const& env, storm::logic::RewardMeasureType, CheckTask const& checkTask) { storm::logic::InstantaneousRewardFormula const& rewardPathFormula = checkTask.getFormula(); @@ -185,7 +208,7 @@ namespace storm { std::vector numericResult = storm::modelchecker::helper::SparseMdpPrctlHelper::computeInstantaneousRewards(env, storm::solver::SolveGoal(this->getModel(), checkTask), this->getModel().getTransitionMatrix(), checkTask.isRewardModelSet() ? this->getModel().getRewardModel(checkTask.getRewardModel()) : this->getModel().getRewardModel(""), rewardPathFormula.getBound()); return std::unique_ptr(new ExplicitQuantitativeCheckResult(std::move(numericResult))); } - + template std::unique_ptr SparseMdpPrctlModelChecker::computeReachabilityRewards(Environment const& env, storm::logic::RewardMeasureType, CheckTask const& checkTask) { storm::logic::EventuallyFormula const& eventuallyFormula = checkTask.getFormula(); @@ -200,7 +223,7 @@ namespace storm { } return result; } - + template std::unique_ptr SparseMdpPrctlModelChecker::computeReachabilityTimes(Environment const& env, storm::logic::RewardMeasureType, CheckTask const& checkTask) { storm::logic::EventuallyFormula const& eventuallyFormula = checkTask.getFormula(); @@ -214,7 +237,7 @@ namespace storm { } return result; } - + template std::unique_ptr SparseMdpPrctlModelChecker::computeTotalRewards(Environment const& env, storm::logic::RewardMeasureType, CheckTask const& checkTask) { STORM_LOG_THROW(checkTask.isOptimizationDirectionSet(), storm::exceptions::InvalidPropertyException, "Formula needs to specify whether minimal or maximal values are to be computed on nondeterministic model."); @@ -233,18 +256,18 @@ namespace storm { STORM_LOG_THROW(checkTask.isOptimizationDirectionSet(), storm::exceptions::InvalidPropertyException, "Formula needs to specify whether minimal or maximal values are to be computed on nondeterministic model."); std::unique_ptr subResultPointer = this->check(env, stateFormula); ExplicitQualitativeCheckResult const& subResult = subResultPointer->asExplicitQualitativeCheckResult(); - + storm::modelchecker::helper::SparseNondeterministicInfiniteHorizonHelper helper(this->getModel().getTransitionMatrix()); storm::modelchecker::helper::setInformationFromCheckTaskNondeterministic(helper, checkTask, this->getModel()); auto values = helper.computeLongRunAverageProbabilities(env, subResult.getTruthValuesVector()); - + std::unique_ptr result(new ExplicitQuantitativeCheckResult(std::move(values))); if (checkTask.isProduceSchedulersSet()) { result->asExplicitQuantitativeCheckResult().setScheduler(std::make_unique>(helper.extractScheduler())); } return result; } - + template std::unique_ptr SparseMdpPrctlModelChecker::computeLongRunAverageRewards(Environment const& env, storm::logic::RewardMeasureType rewardMeasureType, CheckTask const& checkTask) { STORM_LOG_THROW(checkTask.isOptimizationDirectionSet(), storm::exceptions::InvalidPropertyException, "Formula needs to specify whether minimal or maximal values are to be computed on nondeterministic model."); @@ -258,12 +281,12 @@ namespace storm { } return result; } - + template std::unique_ptr SparseMdpPrctlModelChecker::checkMultiObjectiveFormula(Environment const& env, CheckTask const& checkTask) { return multiobjective::performMultiObjectiveModelChecking(env, this->getModel(), checkTask.getFormula()); } - + template std::unique_ptr SparseMdpPrctlModelChecker::checkQuantileFormula(Environment const& env, CheckTask const& checkTask) { STORM_LOG_THROW(checkTask.isOnlyInitialStatesRelevantSet(), storm::exceptions::InvalidOperationException, "Computing quantiles is only supported for the initial states of a model."); @@ -272,14 +295,14 @@ namespace storm { helper::rewardbounded::QuantileHelper qHelper(this->getModel(), checkTask.getFormula()); auto res = qHelper.computeQuantile(env); - + if (res.size() == 1 && res.front().size() == 1) { return std::unique_ptr(new ExplicitQuantitativeCheckResult(initialState, std::move(res.front().front()))); } else { return std::unique_ptr(new ExplicitParetoCurveCheckResult(initialState, std::move(res))); } } - + template class SparseMdpPrctlModelChecker>; #ifdef STORM_HAVE_CARL diff --git a/src/storm/modelchecker/prctl/helper/MDPModelCheckingHelperReturnType.h b/src/storm/modelchecker/prctl/helper/MDPModelCheckingHelperReturnType.h index 97efcc3bc..51f501fc7 100644 --- a/src/storm/modelchecker/prctl/helper/MDPModelCheckingHelperReturnType.h +++ b/src/storm/modelchecker/prctl/helper/MDPModelCheckingHelperReturnType.h @@ -18,7 +18,7 @@ namespace storm { MDPSparseModelCheckingHelperReturnType(MDPSparseModelCheckingHelperReturnType const&) = delete; MDPSparseModelCheckingHelperReturnType(MDPSparseModelCheckingHelperReturnType&&) = default; - MDPSparseModelCheckingHelperReturnType(std::vector&& values, std::unique_ptr>&& scheduler = nullptr) : values(std::move(values)), scheduler(std::move(scheduler)) { + MDPSparseModelCheckingHelperReturnType(std::vector&& values, storm::storage::BitVector&& maybeStates = nullptr, std::unique_ptr>&& scheduler = nullptr, std::vector&& choiceValues = nullptr) : values(std::move(values)), maybeStates(maybeStates), scheduler(std::move(scheduler)), choiceValues(std::move(choiceValues)) { // Intentionally left empty. } @@ -29,8 +29,14 @@ namespace storm { // The values computed for the states. std::vector values; + // The maybe states of the model + storm::storage::BitVector maybeStates; + // A scheduler, if it was computed. std::unique_ptr> scheduler; + + // The values computed for the available choices. + std::vector choiceValues; }; } diff --git a/src/storm/modelchecker/prctl/helper/SparseMdpPrctlHelper.cpp b/src/storm/modelchecker/prctl/helper/SparseMdpPrctlHelper.cpp index eb63176c6..11f3bcaca 100644 --- a/src/storm/modelchecker/prctl/helper/SparseMdpPrctlHelper.cpp +++ b/src/storm/modelchecker/prctl/helper/SparseMdpPrctlHelper.cpp @@ -134,16 +134,25 @@ namespace storm { } template - std::vector SparseMdpPrctlHelper::computeNextProbabilities(Environment const& env, OptimizationDirection dir, storm::storage::SparseMatrix const& transitionMatrix, storm::storage::BitVector const& nextStates) { + MDPSparseModelCheckingHelperReturnType SparseMdpPrctlHelper::computeNextProbabilities(Environment const& env, storm::solver::SolveGoal&& goal, OptimizationDirection dir, storm::storage::SparseMatrix const& transitionMatrix, storm::storage::BitVector const& nextStates) { // Create the vector with which to multiply and initialize it correctly. std::vector result(transitionMatrix.getRowGroupCount()); storm::utility::vector::setVectorValues(result, nextStates, storm::utility::one()); - + std::vector choiceValues = std::vector(transitionMatrix.getRowCount(), storm::utility::zero()); + storm::storage::BitVector allStates = storm::storage::BitVector(transitionMatrix.getRowGroupCount(), true); + auto multiplier = storm::solver::MultiplierFactory().create(env, transitionMatrix); - multiplier->multiplyAndReduce(env, dir, result, nullptr, result); - - return result; + + if(goal.isShieldingTask()) { + multiplier->multiply(env, result, nullptr, choiceValues); + multiplier->reduce(env, goal.direction(), choiceValues, transitionMatrix.getRowGroupIndices(), result, nullptr); + } + else { + multiplier->multiplyAndReduce(env, dir, result, nullptr, result); + } + + return MDPSparseModelCheckingHelperReturnType(std::move(result), std::move(allStates), nullptr, std::move(choiceValues)); } template @@ -594,7 +603,7 @@ namespace storm { // We need to identify the maybe states (states which have a probability for satisfying the until formula // that is strictly between 0 and 1) and the states that satisfy the formula with probablity 1 and 0, respectively. QualitativeStateSetsUntilProbabilities qualitativeStateSets = getQualitativeStateSetsUntilProbabilities(goal, transitionMatrix, backwardTransitions, phiStates, psiStates, hint); - + STORM_LOG_INFO("Preprocessing: " << qualitativeStateSets.statesWithProbability1.getNumberOfSetBits() << " states with probability 1, " << qualitativeStateSets.statesWithProbability0.getNumberOfSetBits() << " with probability 0 (" << qualitativeStateSets.maybeStates.getNumberOfSetBits() << " states remaining)."); // Set values of resulting vector that are known exactly. @@ -608,22 +617,34 @@ namespace storm { // Check if the values of the maybe states are relevant for the SolveGoal bool maybeStatesNotRelevant = goal.hasRelevantValues() && goal.relevantValues().isDisjointFrom(qualitativeStateSets.maybeStates); + + // create multiplier and execute the calculation for 1 additional step + auto multiplier = storm::solver::MultiplierFactory().create(env, transitionMatrix); + + uint sizeMaybeStateChoiceValues = 0; + for(uint counter = 0; counter < qualitativeStateSets.maybeStates.size(); counter++) { + if(qualitativeStateSets.maybeStates.get(counter)) { + sizeMaybeStateChoiceValues += transitionMatrix.getRowGroupSize(counter); + } + } + + std::vector maybeStateChoiceValues = std::vector(sizeMaybeStateChoiceValues, storm::utility::zero()); // Check whether we need to compute exact probabilities for some states. - if (qualitative || maybeStatesNotRelevant) { + if (qualitative || maybeStatesNotRelevant || !goal.isShieldingTask()) { // Set the values for all maybe-states to 0.5 to indicate that their probability values are neither 0 nor 1. storm::utility::vector::setVectorValues(result, qualitativeStateSets.maybeStates, storm::utility::convertNumber(0.5)); } else { if (!qualitativeStateSets.maybeStates.empty()) { // In this case we have have to compute the remaining probabilities. - + // Obtain proper hint information either from the provided hint or from requirements of the solver. SparseMdpHintType hintInformation = computeHints(env, SolutionType::UntilProbabilities, hint, goal.direction(), transitionMatrix, backwardTransitions, qualitativeStateSets.maybeStates, phiStates, qualitativeStateSets.statesWithProbability1, produceScheduler); // Declare the components of the equation system we will solve. storm::storage::SparseMatrix submatrix; std::vector b; - + // If the hint information tells us that we have to eliminate MECs, we do so now. boost::optional> ecInformation; if (hintInformation.getEliminateEndComponents()) { @@ -632,10 +653,10 @@ namespace storm { // Otherwise, we compute the standard equations. computeFixedPointSystemUntilProbabilities(goal, transitionMatrix, qualitativeStateSets, submatrix, b); } - + // Now compute the results for the maybe states. MaybeStateResult resultForMaybeStates = computeValuesForMaybeStates(env, std::move(goal), std::move(submatrix), b, produceScheduler, hintInformation); - + // If we eliminated end components, we need to extract the result differently. if (ecInformation && ecInformation.get().getEliminatedEndComponents()) { ecInformation.get().setValues(result, qualitativeStateSets.maybeStates, resultForMaybeStates.getValues()); @@ -649,6 +670,39 @@ namespace storm { extractSchedulerChoices(*scheduler, resultForMaybeStates.getScheduler(), qualitativeStateSets.maybeStates); } } + if (goal.isShieldingTask()) { + std::vector subResult; + uint sizeChoiceValues = 0; + for(uint counter = 0; counter < qualitativeStateSets.maybeStates.size(); counter++) { + if(qualitativeStateSets.maybeStates.get(counter)) { + subResult.push_back(result.at(counter)); + } + } + + submatrix = transitionMatrix.getSubmatrix(true, qualitativeStateSets.maybeStates, qualitativeStateSets.maybeStates, false); + auto sub_multiplier = storm::solver::MultiplierFactory().create(env, submatrix); + sub_multiplier->multiply(env, subResult, &b, maybeStateChoiceValues); + + } + } + } + + std::vector choiceValues = std::vector(transitionMatrix.getRowGroupIndices().at(transitionMatrix.getRowGroupIndices().size() - 1), storm::utility::zero()); + auto choice_it = maybeStateChoiceValues.begin(); + for(uint state = 0; state < transitionMatrix.getRowGroupIndices().size() - 1; state++) { + uint rowGroupSize = transitionMatrix.getRowGroupIndices().at(state + 1) - transitionMatrix.getRowGroupIndices().at(state); + if (qualitativeStateSets.maybeStates.get(state)) { + for(uint choice = 0; choice < rowGroupSize; choice++, choice_it++) { + choiceValues.at(transitionMatrix.getRowGroupIndices().at(state) + choice) = *choice_it; + } + } else if (qualitativeStateSets.statesWithProbability0.get(state)) { + for(uint choice = 0; choice < rowGroupSize; choice++) { + choiceValues.at(transitionMatrix.getRowGroupIndices().at(state) + choice) = 0; + } + } else if (qualitativeStateSets.statesWithProbability1.get(state)) { + for(uint choice = 0; choice < rowGroupSize; choice++) { + choiceValues.at(transitionMatrix.getRowGroupIndices().at(state) + choice) = 1; + } } } @@ -664,7 +718,7 @@ namespace storm { STORM_LOG_ASSERT((!produceScheduler && !scheduler) || scheduler->isMemorylessScheduler(), "Expected a memoryless scheduler"); // Return result. - return MDPSparseModelCheckingHelperReturnType(std::move(result), std::move(scheduler)); + return MDPSparseModelCheckingHelperReturnType(std::move(result), std::move(qualitativeStateSets.maybeStates), std::move(scheduler), std::move(choiceValues)); } template @@ -678,7 +732,6 @@ namespace storm { statesInPsiMecs.set(stateActionsPair.first, true); } } - return computeUntilProbabilities(env, std::move(goal), transitionMatrix, backwardTransitions, psiStates, statesInPsiMecs, qualitative, produceScheduler); } else { goal.oneMinus(); @@ -686,6 +739,9 @@ namespace storm { for (auto& element : result.values) { element = storm::utility::one() - element; } + for (auto& choice : result.choiceValues) { + choice = storm::utility::one() - choice; + } return result; } } @@ -1080,7 +1136,7 @@ namespace storm { // Prepare resulting vector. std::vector result(transitionMatrix.getRowGroupCount(), storm::utility::zero()); - + // Determine which states have a reward that is infinity or less than infinity. QualitativeStateSetsReachabilityRewards qualitativeStateSets = getQualitativeStateSetsReachabilityRewards(goal, transitionMatrix, backwardTransitions, targetStates, hint, zeroRewardStatesGetter, zeroRewardChoicesGetter); @@ -1172,8 +1228,9 @@ namespace storm { STORM_LOG_ASSERT((!produceScheduler && !scheduler) || scheduler->isDeterministicScheduler(), "Expected a deterministic scheduler"); STORM_LOG_ASSERT((!produceScheduler && !scheduler) || scheduler->isMemorylessScheduler(), "Expected a memoryless scheduler"); + std::vector choiceValues; - return MDPSparseModelCheckingHelperReturnType(std::move(result), std::move(scheduler)); + return MDPSparseModelCheckingHelperReturnType(std::move(result), std::move(qualitativeStateSets.maybeStates), std::move(scheduler), std::move(choiceValues)); } template diff --git a/src/storm/modelchecker/prctl/helper/SparseMdpPrctlHelper.h b/src/storm/modelchecker/prctl/helper/SparseMdpPrctlHelper.h index b3137841d..099533dab 100644 --- a/src/storm/modelchecker/prctl/helper/SparseMdpPrctlHelper.h +++ b/src/storm/modelchecker/prctl/helper/SparseMdpPrctlHelper.h @@ -41,7 +41,7 @@ namespace storm { static std::map computeRewardBoundedValues(Environment const& env, OptimizationDirection dir, rewardbounded::MultiDimensionalRewardUnfolding& rewardUnfolding, storm::storage::BitVector const& initialStates); - static std::vector computeNextProbabilities(Environment const& env, OptimizationDirection dir, storm::storage::SparseMatrix const& transitionMatrix, storm::storage::BitVector const& nextStates); + static MDPSparseModelCheckingHelperReturnType computeNextProbabilities(Environment const& env, storm::solver::SolveGoal&& goal, OptimizationDirection dir, storm::storage::SparseMatrix const& transitionMatrix, storm::storage::BitVector const& nextStates); static MDPSparseModelCheckingHelperReturnType computeUntilProbabilities(Environment const& env, storm::solver::SolveGoal&& goal, storm::storage::SparseMatrix const& transitionMatrix, storm::storage::SparseMatrix const& backwardTransitions, storm::storage::BitVector const& phiStates, storm::storage::BitVector const& psiStates, bool qualitative, bool produceScheduler, ModelCheckerHint const& hint = ModelCheckerHint()); diff --git a/src/storm/modelchecker/rpatl/SparseSmgRpatlModelChecker.cpp b/src/storm/modelchecker/rpatl/SparseSmgRpatlModelChecker.cpp index b193b9f09..1668bc440 100644 --- a/src/storm/modelchecker/rpatl/SparseSmgRpatlModelChecker.cpp +++ b/src/storm/modelchecker/rpatl/SparseSmgRpatlModelChecker.cpp @@ -24,7 +24,7 @@ #include "storm/models/sparse/StandardRewardModel.h" -#include "storm/shields/shield-handling.h" +#include "storm/shields/ShieldHandling.h" #include "storm/settings/modules/GeneralSettings.h" @@ -212,7 +212,7 @@ namespace storm { std::unique_ptr result(new ExplicitQuantitativeCheckResult(std::move(values))); if(checkTask.isShieldingTask()) { - tempest::shields::createOptimalShield(std::make_shared>(this->getModel()), helper.getProducedOptimalChoices(), checkTask.getShieldingExpression(), checkTask.getOptimizationDirection(), statesOfCoalition, statesOfCoalition); + tempest::shields::createQuantitativeShield(std::make_shared>(this->getModel()), helper.getProducedOptimalChoices(), checkTask.getShieldingExpression(), checkTask.getOptimizationDirection(), statesOfCoalition, statesOfCoalition); } else if (checkTask.isProduceSchedulersSet()) { result->asExplicitQuantitativeCheckResult().setScheduler(std::make_unique>(helper.extractScheduler())); } diff --git a/src/storm/modelchecker/rpatl/helper/SparseSmgRpatlHelper.cpp b/src/storm/modelchecker/rpatl/helper/SparseSmgRpatlHelper.cpp index a57a0ccef..78d554346 100644 --- a/src/storm/modelchecker/rpatl/helper/SparseSmgRpatlHelper.cpp +++ b/src/storm/modelchecker/rpatl/helper/SparseSmgRpatlHelper.cpp @@ -41,8 +41,7 @@ namespace storm { } viHelper.performValueIteration(env, x, b, goal.direction()); - //if(goal.isShieldingTask()) { - if(true) { + if(goal.isShieldingTask()) { viHelper.getChoiceValues(env, x, constrainedChoiceValues); } viHelper.fillResultVector(x, relevantStates, psiStates); @@ -107,13 +106,12 @@ namespace storm { // create multiplier and execute the calculation for 1 step auto multiplier = storm::solver::MultiplierFactory().create(env, transitionMatrix); std::vector choiceValues = std::vector(transitionMatrix.getRowCount(), storm::utility::zero()); - - //if(goal.isShieldingTask()) { - if (true) { + if (goal.isShieldingTask()) { multiplier->multiply(env, x, &b, choiceValues); + multiplier->reduce(env, goal.direction(), choiceValues, transitionMatrix.getRowGroupIndices(), x, &statesOfCoalition); + } else { + multiplier->multiplyAndReduce(env, goal.direction(), x, &b, x, nullptr, &statesOfCoalition); } - multiplier->multiplyAndReduce(env, goal.direction(), x, &b, x, nullptr, &statesOfCoalition); - return SMGSparseModelCheckingHelperReturnType(std::move(x), std::move(allStates), nullptr, std::move(choiceValues)); } diff --git a/src/storm/shields/ShieldHandling.cpp b/src/storm/shields/ShieldHandling.cpp new file mode 100644 index 000000000..69240f6f1 --- /dev/null +++ b/src/storm/shields/ShieldHandling.cpp @@ -0,0 +1,47 @@ +#include "ShieldHandling.h" + +namespace tempest { + namespace shields { + std::string shieldFilename(std::shared_ptr const& shieldingExpression) { + return shieldingExpression->getFilename() + ".shield"; + } + + template + void createShield(std::shared_ptr> model, std::vector const& choiceValues, std::shared_ptr const& shieldingExpression, storm::OptimizationDirection optimizationDirection, storm::storage::BitVector relevantStates, boost::optional coalitionStates) { + std::ofstream stream; + storm::utility::openFile(shieldFilename(shieldingExpression), stream); + if(shieldingExpression->isPreSafetyShield()) { + PreSafetyShield shield(model->getTransitionMatrix().getRowGroupIndices(), choiceValues, shieldingExpression, optimizationDirection, relevantStates, coalitionStates); + shield.construct().printToStream(stream, shieldingExpression, model); + } else if(shieldingExpression->isPostSafetyShield()) { + PostSafetyShield shield(model->getTransitionMatrix().getRowGroupIndices(), choiceValues, shieldingExpression, optimizationDirection, relevantStates, coalitionStates); + shield.construct().printToStream(stream, shieldingExpression, model); + } else { + STORM_LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Unknown Shielding Type: " + shieldingExpression->typeToString()); + storm::utility::closeFile(stream); + } + storm::utility::closeFile(stream); + } + + template + void createQuantitativeShield(std::shared_ptr> model, std::vector const& precomputedChoices, std::shared_ptr const& shieldingExpression, storm::OptimizationDirection optimizationDirection, storm::storage::BitVector relevantStates, boost::optional coalitionStates) { + std::ofstream stream; + storm::utility::openFile(shieldFilename(shieldingExpression), stream); + if(shieldingExpression->isOptimalShield()) { + OptimalShield shield(model->getTransitionMatrix().getRowGroupIndices(), precomputedChoices, shieldingExpression, optimizationDirection, relevantStates, coalitionStates); + shield.construct().printToStream(stream, shieldingExpression, model); + } else { + STORM_LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Unknown Shielding Type: " + shieldingExpression->typeToString()); + storm::utility::closeFile(stream); + } + storm::utility::closeFile(stream); + } + // Explicitly instantiate appropriate + template void createShield::index_type>(std::shared_ptr> model, std::vector const& choiceValues, std::shared_ptr const& shieldingExpression, storm::OptimizationDirection optimizationDirection, storm::storage::BitVector relevantStates, boost::optional coalitionStates); + template void createQuantitativeShield::index_type>(std::shared_ptr> model, std::vector const& precomputedChoices, std::shared_ptr const& shieldingExpression, storm::OptimizationDirection optimizationDirection, storm::storage::BitVector relevantStates, boost::optional coalitionStates); +#ifdef STORM_HAVE_CARL + template void createShield::index_type>(std::shared_ptr> model, std::vector const& choiceValues, std::shared_ptr const& shieldingExpression, storm::OptimizationDirection optimizationDirection, storm::storage::BitVector relevantStates, boost::optional coalitionStates); + template void createQuantitativeShield::index_type>(std::shared_ptr> model, std::vector const& precomputedChoices, std::shared_ptr const& shieldingExpression, storm::OptimizationDirection optimizationDirection, storm::storage::BitVector relevantStates, boost::optional coalitionStates); +#endif + } +} diff --git a/src/storm/shields/ShieldHandling.h b/src/storm/shields/ShieldHandling.h new file mode 100644 index 000000000..2b21a8522 --- /dev/null +++ b/src/storm/shields/ShieldHandling.h @@ -0,0 +1,32 @@ +#pragma once + +#include +#include +#include + +#include "storm/storage/Scheduler.h" +#include "storm/storage/BitVector.h" + +#include "storm/logic/ShieldExpression.h" + +#include "storm/shields/AbstractShield.h" +#include "storm/shields/PreSafetyShield.h" +#include "storm/shields/PostSafetyShield.h" +#include "storm/shields/OptimalShield.h" + +#include "storm/io/file.h" +#include "storm/utility/macros.h" + +#include "storm/exceptions/InvalidArgumentException.h" + +namespace tempest { + namespace shields { + std::string shieldFilename(std::shared_ptr const& shieldingExpression); + + template + void createShield(std::shared_ptr> model, std::vector const& choiceValues, std::shared_ptr const& shieldingExpression, storm::OptimizationDirection optimizationDirection, storm::storage::BitVector relevantStates, boost::optional coalitionStates); + + template + void createQuantitativeShield(std::shared_ptr> model, std::vector const& precomputedChoices, std::shared_ptr const& shieldingExpression, storm::OptimizationDirection optimizationDirection, storm::storage::BitVector relevantStates, boost::optional coalitionStates); + } +} diff --git a/src/storm/shields/shield-handling.h b/src/storm/shields/shield-handling.h deleted file mode 100644 index ef4686334..000000000 --- a/src/storm/shields/shield-handling.h +++ /dev/null @@ -1,60 +0,0 @@ -#pragma once - -#include -#include -#include - -#include "storm/storage/Scheduler.h" -#include "storm/storage/BitVector.h" - -#include "storm/logic/ShieldExpression.h" - -#include "storm/shields/AbstractShield.h" -#include "storm/shields/PreSafetyShield.h" -#include "storm/shields/PostSafetyShield.h" -#include "storm/shields/OptimalShield.h" - -#include "storm/io/file.h" -#include "storm/utility/macros.h" - -#include "storm/exceptions/InvalidArgumentException.h" - - -namespace tempest { - namespace shields { - std::string shieldFilename(std::shared_ptr const& shieldingExpression) { - return shieldingExpression->getFilename() + ".shield"; - } - - template - void createShield(std::shared_ptr> model, std::vector const& choiceValues, std::shared_ptr const& shieldingExpression, storm::OptimizationDirection optimizationDirection, storm::storage::BitVector relevantStates, boost::optional coalitionStates) { - std::ofstream stream; - storm::utility::openFile(shieldFilename(shieldingExpression), stream); - if(shieldingExpression->isPreSafetyShield()) { - PreSafetyShield shield(model->getTransitionMatrix().getRowGroupIndices(), choiceValues, shieldingExpression, optimizationDirection, relevantStates, coalitionStates); - shield.construct().printToStream(stream, shieldingExpression, model); - } else if(shieldingExpression->isPostSafetyShield()) { - PostSafetyShield shield(model->getTransitionMatrix().getRowGroupIndices(), choiceValues, shieldingExpression, optimizationDirection, relevantStates, coalitionStates); - shield.construct().printToStream(stream, shieldingExpression, model); - } else { - STORM_LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Unknown Shielding Type: " + shieldingExpression->typeToString()); - storm::utility::closeFile(stream); - } - storm::utility::closeFile(stream); - } - - template - void createOptimalShield(std::shared_ptr> model, std::vector const& precomputedChoices, std::shared_ptr const& shieldingExpression, storm::OptimizationDirection optimizationDirection, storm::storage::BitVector relevantStates, boost::optional coalitionStates) { - std::ofstream stream; - storm::utility::openFile(shieldFilename(shieldingExpression), stream); - if(shieldingExpression->isOptimalShield()) { - OptimalShield shield(model->getTransitionMatrix().getRowGroupIndices(), precomputedChoices, shieldingExpression, optimizationDirection, relevantStates, coalitionStates); - shield.construct().printToStream(stream, shieldingExpression, model); - } else { - STORM_LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Unknown Shielding Type: " + shieldingExpression->typeToString()); - storm::utility::closeFile(stream); - } - storm::utility::closeFile(stream); - } - } -} diff --git a/src/storm/solver/Multiplier.cpp b/src/storm/solver/Multiplier.cpp index 97eb22bb8..2324b7833 100644 --- a/src/storm/solver/Multiplier.cpp +++ b/src/storm/solver/Multiplier.cpp @@ -68,12 +68,53 @@ namespace storm { } } + template + void Multiplier::repeatedMultiplyAndReduceWithChoices(Environment const& env, OptimizationDirection const& dir, std::vector& x, std::vector const* b, uint64_t n, storm::storage::BitVector const* dirOverride, std::vector& choiceValues, std::vector::index_type> rowGroupIndices) const { + storm::utility::ProgressMeasurement progress("multiplications"); + progress.setMaxCount(n); + progress.startNewMeasurement(0); + for (uint64_t i = 0; i < n; ++i) { + multiply(env, x, b, choiceValues); + reduce(env, dir, choiceValues, rowGroupIndices, x); + if (storm::utility::resources::isTerminate()) { + STORM_LOG_WARN("Aborting after " << i << " of " << n << " multiplications"); + break; + } + } + } + template void Multiplier::multiplyRow2(uint64_t const& rowIndex, std::vector const& x1, ValueType& val1, std::vector const& x2, ValueType& val2) const { multiplyRow(rowIndex, x1, val1); multiplyRow(rowIndex, x2, val2); } + template + void Multiplier::reduce(Environment const& env, OptimizationDirection const& dir, std::vector const& choiceValues, std::vector::index_type> rowGroupIndices, std::vector& result, storm::storage::BitVector const* dirOverride) const { + auto choice_it = choiceValues.begin(); + for(uint state = 0; state < rowGroupIndices.size() - 1; state++) { + uint rowGroupSize = rowGroupIndices[state + 1] - rowGroupIndices[state]; + if(dirOverride != nullptr) { + if((dir == storm::OptimizationDirection::Minimize && !dirOverride->get(state)) || (dir == storm::OptimizationDirection::Maximize && dirOverride->get(state))) { + result.at(state) = *std::min_element(choice_it, choice_it + rowGroupSize); + choice_it += rowGroupSize; + } + else { + result.at(state) = *std::max_element(choice_it, choice_it + rowGroupSize); + choice_it += rowGroupSize; + } + } else { + if(dir == storm::OptimizationDirection::Minimize) { + result.at(state) = *std::min_element(choice_it, choice_it + rowGroupSize); + choice_it += rowGroupSize; + } else { + result.at(state) = *std::max_element(choice_it, choice_it + rowGroupSize); + choice_it += rowGroupSize; + } + } + } + } + template std::unique_ptr> MultiplierFactory::create(Environment const& env, storm::storage::SparseMatrix const& matrix) { auto type = env.solver().multiplier().getType(); diff --git a/src/storm/solver/Multiplier.h b/src/storm/solver/Multiplier.h index 4127a6625..5e132a628 100644 --- a/src/storm/solver/Multiplier.h +++ b/src/storm/solver/Multiplier.h @@ -8,6 +8,8 @@ #include "storm/solver/OptimizationDirection.h" #include "storm/solver/MultiplicationStyle.h" +#include "storm/storage/SparseMatrix.h" + namespace storm { @@ -119,6 +121,8 @@ namespace storm { */ void repeatedMultiplyAndReduce(Environment const& env, OptimizationDirection const& dir, std::vector& x, std::vector const* b, uint64_t n, storm::storage::BitVector const* dirOverride = nullptr) const; + void repeatedMultiplyAndReduceWithChoices(const Environment &env, const OptimizationDirection &dir, std::vector &x, const std::vector *b, uint64_t n, const storage::BitVector *dirOverride, std::vector &choiceValues, std::vector rowGroupIndices) const; + /*! * Multiplies the row with the given index with x and adds the result to the provided value * @param rowIndex The index of the considered row @@ -137,9 +141,12 @@ namespace storm { */ virtual void multiplyRow2(uint64_t const& rowIndex, std::vector const& x1, ValueType& val1, std::vector const& x2, ValueType& val2) const; + void reduce(Environment const& env, OptimizationDirection const& dir, std::vector const& choiceValues, std::vector::index_type> rowGroupIndices, std::vector& result, storm::storage::BitVector const* dirOverride = nullptr) const; + protected: mutable std::unique_ptr> cachedVector; storm::storage::SparseMatrix const& matrix; + }; template diff --git a/src/storm/solver/SolveGoal.cpp b/src/storm/solver/SolveGoal.cpp index 8913e1807..b410ff121 100644 --- a/src/storm/solver/SolveGoal.cpp +++ b/src/storm/solver/SolveGoal.cpp @@ -122,6 +122,11 @@ namespace storm { relevantValueVector = std::move(values); } + template + bool SolveGoal::isShieldingTask() const { + return shieldingTask; + } + template class SolveGoal; #ifdef STORM_HAVE_CARL diff --git a/src/storm/solver/SolveGoal.h b/src/storm/solver/SolveGoal.h index 9bd6f7e73..741d33b09 100644 --- a/src/storm/solver/SolveGoal.h +++ b/src/storm/solver/SolveGoal.h @@ -51,6 +51,7 @@ namespace storm { comparisonType = checkTask.getBoundComparisonType(); threshold = checkTask.getBoundThreshold(); } + shieldingTask = checkTask.isShieldingTask(); } SolveGoal(bool minimize); @@ -77,19 +78,23 @@ namespace storm { ValueType const& thresholdValue() const; bool hasRelevantValues() const; - + storm::storage::BitVector& relevantValues(); storm::storage::BitVector const& relevantValues() const; void restrictRelevantValues(storm::storage::BitVector const& filter); void setRelevantValues(storm::storage::BitVector&& values); - + + bool isShieldingTask() const; + private: boost::optional optimizationDirection; boost::optional comparisonType; boost::optional threshold; boost::optional relevantValueVector; + // We only want to know if it **is** a shielding task + bool shieldingTask; }; template