diff --git a/src/storm-cli-utilities/model-handling.h b/src/storm-cli-utilities/model-handling.h index 41faf8732..41df3fabf 100644 --- a/src/storm-cli-utilities/model-handling.h +++ b/src/storm-cli-utilities/model-handling.h @@ -422,7 +422,10 @@ namespace storm { STORM_LOG_THROW(model->isSparseModel(), storm::exceptions::NotSupportedException, "Counterexample generation is currently only supported for sparse models."); auto sparseModel = model->as>(); - + for (auto& rewModel : sparseModel->getRewardModels()) { + rewModel.second.reduceToStateBasedRewards(sparseModel->getTransitionMatrix(), true); + } + STORM_LOG_THROW(sparseModel->isOfType(storm::models::ModelType::Dtmc) || sparseModel->isOfType(storm::models::ModelType::Mdp), storm::exceptions::NotSupportedException, "Counterexample is currently only supported for discrete-time models."); auto counterexampleSettings = storm::settings::getModule(); diff --git a/src/storm-counterexamples/counterexamples/SMTMinimalLabelSetGenerator.h b/src/storm-counterexamples/counterexamples/SMTMinimalLabelSetGenerator.h index a3dfbd066..52a24d0dd 100644 --- a/src/storm-counterexamples/counterexamples/SMTMinimalLabelSetGenerator.h +++ b/src/storm-counterexamples/counterexamples/SMTMinimalLabelSetGenerator.h @@ -1516,7 +1516,7 @@ namespace storm { * Returns the sub-model obtained from removing all choices that do not originate from the specified filterLabelSet. * Also returns the Labelsets of the sub-model. */ - static std::pair>, std::vector>> restrictModelToLabelSet(storm::models::sparse::Model const& model, boost::container::flat_set const& filterLabelSet, boost::optional absorbState = boost::none) { + static std::pair>, std::vector>> restrictModelToLabelSet(storm::models::sparse::Model const& model, boost::container::flat_set const& filterLabelSet, boost::optional const& rewardName = boost::none, boost::optional absorbState = boost::none) { bool customRowGrouping = model.isOfType(storm::models::ModelType::Mdp); @@ -1556,13 +1556,19 @@ namespace storm { resultLabelSet.emplace_back(); ++currentRow; } + } - + + if (rewardName) { + auto const &origRewModel = model.getRewardModel(rewardName.get()); + assert(origRewModel.hasOnlyStateRewards()); + } + std::shared_ptr> resultModel; if (model.isOfType(storm::models::ModelType::Dtmc)) { - resultModel = std::make_shared>(transitionMatrixBuilder.build(), storm::models::sparse::StateLabeling(model.getStateLabeling())); + resultModel = std::make_shared>(transitionMatrixBuilder.build(), storm::models::sparse::StateLabeling(model.getStateLabeling()), model.getRewardModels()); } else { - resultModel = std::make_shared>(transitionMatrixBuilder.build(), storm::models::sparse::StateLabeling(model.getStateLabeling())); + resultModel = std::make_shared>(transitionMatrixBuilder.build(), storm::models::sparse::StateLabeling(model.getStateLabeling()), model.getRewardModels()); } return std::make_pair(resultModel, std::move(resultLabelSet)); @@ -1748,7 +1754,7 @@ namespace storm { } - auto subChoiceOrigins = restrictModelToLabelSet(model, commandSet, psiStates.getNextSetIndex(0)); + auto subChoiceOrigins = restrictModelToLabelSet(model, commandSet, rewardName, psiStates.getNextSetIndex(0)); std::shared_ptr> const& subModel = subChoiceOrigins.first; std::vector> const& subLabelSets = subChoiceOrigins.second; @@ -1958,7 +1964,8 @@ namespace storm { rewardName = rewardOperator.getRewardModelName(); STORM_LOG_THROW(!storm::logic::isLowerBound(comparisonType), storm::exceptions::NotSupportedException, "Lower bounds in counterexamples are only supported for probability formulas."); - + STORM_LOG_THROW(model.hasRewardModel(rewardName.get()), storm::exceptions::InvalidPropertyException, "Property refers to reward " << rewardName.get() << " but model does not contain such a reward model."); + STORM_LOG_THROW(model.getRewardModel(rewardName.get()).hasOnlyStateRewards(), storm::exceptions::NotSupportedException, "We only support state-based rewards at the moment."); } bool strictBound = comparisonType == storm::logic::ComparisonType::Less; storm::logic::Formula const& subformula = formula->asOperatorFormula().getSubformula();