diff --git a/src/storm/modelchecker/prctl/SparseMdpPrctlModelChecker.cpp b/src/storm/modelchecker/prctl/SparseMdpPrctlModelChecker.cpp index 30cfcf19f..39b2ea6e9 100644 --- a/src/storm/modelchecker/prctl/SparseMdpPrctlModelChecker.cpp +++ b/src/storm/modelchecker/prctl/SparseMdpPrctlModelChecker.cpp @@ -102,6 +102,8 @@ namespace storm { storm::storage::BitVector resultMaybeStates; std::vector choiceValues; + //TODO: check the result for shields + if(checkTask.isShieldingTask()) { numericResult = helper.compute(env, storm::solver::SolveGoal(this->getModel(), checkTask), this->getModel().getTransitionMatrix(), this->getModel().getBackwardTransitions(), leftResult.getTruthValuesVector(), rightResult.getTruthValuesVector(), pathFormula.getNonStrictLowerBound(), pathFormula.getNonStrictUpperBound(), checkTask.getHint(), resultMaybeStates, choiceValues); tempest::shields::createShield(std::make_shared>(this->getModel()), std::move(choiceValues), checkTask.getShieldingExpression(), checkTask.getOptimizationDirection(), std::move(resultMaybeStates), storm::storage::BitVector(resultMaybeStates.size(), false)); @@ -121,7 +123,8 @@ namespace storm { auto ret = storm::modelchecker::helper::SparseMdpPrctlHelper::computeNextProbabilities(env, storm::solver::SolveGoal(this->getModel(), checkTask), checkTask.getOptimizationDirection(), this->getModel().getTransitionMatrix(), subResult.getTruthValuesVector()); std::unique_ptr result(new ExplicitQuantitativeCheckResult(std::move(ret.values))); if(checkTask.isShieldingTask()) { - tempest::shields::createShield(std::make_shared>(this->getModel()), std::move(ret.values), checkTask.getShieldingExpression(), checkTask.getOptimizationDirection(), std::move(ret.maybeStates), storm::storage::BitVector(ret.maybeStates.size(), false)); + //TODO: creating a shield for NEXT does not work + tempest::shields::createShield(std::make_shared>(this->getModel()), std::move(ret.choiceValues), checkTask.getShieldingExpression(), checkTask.getOptimizationDirection(), std::move(ret.maybeStates), storm::storage::BitVector(ret.maybeStates.size(), false)); } else if (checkTask.isProduceSchedulersSet() && ret.scheduler) { result->asExplicitQuantitativeCheckResult().setScheduler(std::move(ret.scheduler)); } diff --git a/src/storm/modelchecker/prctl/helper/SparseMdpPrctlHelper.cpp b/src/storm/modelchecker/prctl/helper/SparseMdpPrctlHelper.cpp index 04eab8cfe..9ad693722 100644 --- a/src/storm/modelchecker/prctl/helper/SparseMdpPrctlHelper.cpp +++ b/src/storm/modelchecker/prctl/helper/SparseMdpPrctlHelper.cpp @@ -620,10 +620,18 @@ namespace storm { // create multiplier and execute the calculation for 1 additional step auto multiplier = storm::solver::MultiplierFactory().create(env, transitionMatrix); - std::vector choiceValues = std::vector(transitionMatrix.getRowCount(), storm::utility::zero()); + + uint sizeChoiceValues = 0; + for(uint counter = 0; counter < qualitativeStateSets.maybeStates.size(); counter++) { + if(qualitativeStateSets.maybeStates.get(counter)) { + sizeChoiceValues += transitionMatrix.getRowGroupSize(counter); + } + } + + std::vector choiceValues = std::vector(sizeChoiceValues, storm::utility::zero()); // Check whether we need to compute exact probabilities for some states. - if (qualitative || maybeStatesNotRelevant) { + if (qualitative || maybeStatesNotRelevant || !goal.isShieldingTask()) { // Set the values for all maybe-states to 0.5 to indicate that their probability values are neither 0 nor 1. storm::utility::vector::setVectorValues(result, qualitativeStateSets.maybeStates, storm::utility::convertNumber(0.5)); } else { @@ -663,8 +671,50 @@ namespace storm { } } if (goal.isShieldingTask()) { - multiplier->multiply(env, result, &b, choiceValues); - multiplier->reduce(env, goal.direction(), choiceValues, transitionMatrix.getRowGroupIndices(), result, nullptr); + STORM_LOG_DEBUG("SparseMdpPrctlHelper::computeUntilProbabilities: Before multiply()"); + STORM_LOG_DEBUG(result); + + std::vector subResult; + uint sizeChoiceValues = 0; + for(uint counter = 0; counter < qualitativeStateSets.maybeStates.size(); counter++) { + if(qualitativeStateSets.maybeStates.get(counter)) { + subResult.push_back(result.at(counter)); + } + } + + STORM_LOG_DEBUG(subResult); + + + STORM_LOG_DEBUG(choiceValues); + //std::vector b_complete = transitionMatrix.getConstrainedRowGroupSumVector(storm::storage::BitVector(qualitativeStateSets.maybeStates.size(), true), qualitativeStateSets.statesWithProbability1); + + submatrix = transitionMatrix.getSubmatrix(true, qualitativeStateSets.maybeStates, qualitativeStateSets.maybeStates, false); + + //STORM_LOG_DEBUG(b_complete); + + auto sub_multiplier = storm::solver::MultiplierFactory().create(env, submatrix); + + + sub_multiplier->multiply(env, subResult, &b, choiceValues); + STORM_LOG_DEBUG(choiceValues); + + std::vector allChoices = std::vector(transitionMatrix.getRowGroupIndices().at(transitionMatrix.getRowGroupIndices().size() - 1), storm::utility::zero()); + auto choice_it = choiceValues.begin(); + for(uint state = 0; state < transitionMatrix.getRowGroupIndices().size() - 1; state++) { + uint rowGroupSize = transitionMatrix.getRowGroupIndices().at(state + 1) - transitionMatrix.getRowGroupIndices().at(state); + if (qualitativeStateSets.maybeStates.get(state)) { + for(uint choice = 0; choice < rowGroupSize; choice++, choice_it++) { + allChoices.at(transitionMatrix.getRowGroupIndices().at(state) + choice) = *choice_it; + } + } + } + choiceValues = allChoices; + + STORM_LOG_DEBUG(choiceValues); + + + //TODO: fill up choiceValues with Zeros (dimension!) + } } } @@ -703,6 +753,9 @@ namespace storm { for (auto& element : result.values) { element = storm::utility::one() - element; } + for (auto& choice : result.choiceValues) { + choice = storm::utility::one() - choice; + } return result; } }