From 444929a9a3cf5af23291531282fed7a6ac4556a1 Mon Sep 17 00:00:00 2001 From: lukpo Date: Thu, 15 Jul 2021 11:52:47 +0200 Subject: [PATCH] create MDP shields if it is a shielding task --- .../prctl/SparseMdpPrctlModelChecker.cpp | 35 ++++++++++++++++--- 1 file changed, 30 insertions(+), 5 deletions(-) diff --git a/src/storm/modelchecker/prctl/SparseMdpPrctlModelChecker.cpp b/src/storm/modelchecker/prctl/SparseMdpPrctlModelChecker.cpp index 8cc954f91..30cfcf19f 100644 --- a/src/storm/modelchecker/prctl/SparseMdpPrctlModelChecker.cpp +++ b/src/storm/modelchecker/prctl/SparseMdpPrctlModelChecker.cpp @@ -24,6 +24,10 @@ #include "storm/solver/SolveGoal.h" +#include "storm/storage/BitVector.h" + +#include "storm/shields/ShieldHandling.h" + #include "storm/settings/modules/GeneralSettings.h" #include "storm/exceptions/InvalidStateException.h" @@ -92,7 +96,18 @@ namespace storm { ExplicitQualitativeCheckResult const& leftResult = leftResultPointer->asExplicitQualitativeCheckResult(); ExplicitQualitativeCheckResult const& rightResult = rightResultPointer->asExplicitQualitativeCheckResult(); storm::modelchecker::helper::SparseNondeterministicStepBoundedHorizonHelper helper; - std::vector numericResult = helper.compute(env, storm::solver::SolveGoal(this->getModel(), checkTask), this->getModel().getTransitionMatrix(), this->getModel().getBackwardTransitions(), leftResult.getTruthValuesVector(), rightResult.getTruthValuesVector(), pathFormula.getNonStrictLowerBound(), pathFormula.getNonStrictUpperBound(), checkTask.getHint()); + std::vector numericResult; + + //TODO: this does not work with nullptr as defaults for resultMaybeStates and choiceValues + storm::storage::BitVector resultMaybeStates; + std::vector choiceValues; + + if(checkTask.isShieldingTask()) { + numericResult = helper.compute(env, storm::solver::SolveGoal(this->getModel(), checkTask), this->getModel().getTransitionMatrix(), this->getModel().getBackwardTransitions(), leftResult.getTruthValuesVector(), rightResult.getTruthValuesVector(), pathFormula.getNonStrictLowerBound(), pathFormula.getNonStrictUpperBound(), checkTask.getHint(), resultMaybeStates, choiceValues); + tempest::shields::createShield(std::make_shared>(this->getModel()), std::move(choiceValues), checkTask.getShieldingExpression(), checkTask.getOptimizationDirection(), std::move(resultMaybeStates), storm::storage::BitVector(resultMaybeStates.size(), false)); + } else { + numericResult = helper.compute(env, storm::solver::SolveGoal(this->getModel(), checkTask), this->getModel().getTransitionMatrix(), this->getModel().getBackwardTransitions(), leftResult.getTruthValuesVector(), rightResult.getTruthValuesVector(), pathFormula.getNonStrictLowerBound(), pathFormula.getNonStrictUpperBound(), checkTask.getHint(), resultMaybeStates, choiceValues); + } return std::unique_ptr(new ExplicitQuantitativeCheckResult(std::move(numericResult))); } } @@ -103,8 +118,14 @@ namespace storm { STORM_LOG_THROW(checkTask.isOptimizationDirectionSet(), storm::exceptions::InvalidPropertyException, "Formula needs to specify whether minimal or maximal values are to be computed on nondeterministic model."); std::unique_ptr subResultPointer = this->check(env, pathFormula.getSubformula()); ExplicitQualitativeCheckResult const& subResult = subResultPointer->asExplicitQualitativeCheckResult(); - std::vector numericResult = storm::modelchecker::helper::SparseMdpPrctlHelper::computeNextProbabilities(env, checkTask.getOptimizationDirection(), this->getModel().getTransitionMatrix(), subResult.getTruthValuesVector()); - return std::unique_ptr(new ExplicitQuantitativeCheckResult(std::move(numericResult))); + auto ret = storm::modelchecker::helper::SparseMdpPrctlHelper::computeNextProbabilities(env, storm::solver::SolveGoal(this->getModel(), checkTask), checkTask.getOptimizationDirection(), this->getModel().getTransitionMatrix(), subResult.getTruthValuesVector()); + std::unique_ptr result(new ExplicitQuantitativeCheckResult(std::move(ret.values))); + if(checkTask.isShieldingTask()) { + tempest::shields::createShield(std::make_shared>(this->getModel()), std::move(ret.values), checkTask.getShieldingExpression(), checkTask.getOptimizationDirection(), std::move(ret.maybeStates), storm::storage::BitVector(ret.maybeStates.size(), false)); + } else if (checkTask.isProduceSchedulersSet() && ret.scheduler) { + result->asExplicitQuantitativeCheckResult().setScheduler(std::move(ret.scheduler)); + } + return result; } template @@ -117,7 +138,9 @@ namespace storm { ExplicitQualitativeCheckResult const& rightResult = rightResultPointer->asExplicitQualitativeCheckResult(); auto ret = storm::modelchecker::helper::SparseMdpPrctlHelper::computeUntilProbabilities(env, storm::solver::SolveGoal(this->getModel(), checkTask), this->getModel().getTransitionMatrix(), this->getModel().getBackwardTransitions(), leftResult.getTruthValuesVector(), rightResult.getTruthValuesVector(), checkTask.isQualitativeSet(), checkTask.isProduceSchedulersSet(), checkTask.getHint()); std::unique_ptr result(new ExplicitQuantitativeCheckResult(std::move(ret.values))); - if (checkTask.isProduceSchedulersSet() && ret.scheduler) { + if(checkTask.isShieldingTask()) { + tempest::shields::createShield(std::make_shared>(this->getModel()), std::move(ret.values), checkTask.getShieldingExpression(), checkTask.getOptimizationDirection(), std::move(ret.maybeStates), storm::storage::BitVector(ret.maybeStates.size(), false)); + } else if (checkTask.isProduceSchedulersSet() && ret.scheduler) { result->asExplicitQuantitativeCheckResult().setScheduler(std::move(ret.scheduler)); } return result; @@ -131,7 +154,9 @@ namespace storm { ExplicitQualitativeCheckResult const& subResult = subResultPointer->asExplicitQualitativeCheckResult(); auto ret = storm::modelchecker::helper::SparseMdpPrctlHelper::computeGloballyProbabilities(env, storm::solver::SolveGoal(this->getModel(), checkTask), this->getModel().getTransitionMatrix(), this->getModel().getBackwardTransitions(), subResult.getTruthValuesVector(), checkTask.isQualitativeSet(), checkTask.isProduceSchedulersSet()); std::unique_ptr result(new ExplicitQuantitativeCheckResult(std::move(ret.values))); - if (checkTask.isProduceSchedulersSet() && ret.scheduler) { + if(checkTask.isShieldingTask()) { + tempest::shields::createShield(std::make_shared>(this->getModel()), std::move(ret.values), checkTask.getShieldingExpression(), checkTask.getOptimizationDirection(), std::move(ret.maybeStates), storm::storage::BitVector(ret.maybeStates.size(), false)); + } else if (checkTask.isProduceSchedulersSet() && ret.scheduler) { result->asExplicitQuantitativeCheckResult().setScheduler(std::move(ret.scheduler)); } return result;