diff --git a/src/storm/modelchecker/rpatl/SparseSmgRpatlModelChecker.cpp b/src/storm/modelchecker/rpatl/SparseSmgRpatlModelChecker.cpp index f6161fac2..a6c230625 100644 --- a/src/storm/modelchecker/rpatl/SparseSmgRpatlModelChecker.cpp +++ b/src/storm/modelchecker/rpatl/SparseSmgRpatlModelChecker.cpp @@ -110,6 +110,34 @@ namespace storm { //std::unique_ptr result(new ExplicitQuantitativeCheckResult(std::move(values)); //return result; STORM_LOG_THROW(false, storm::exceptions::NotImplementedException, "NYI"); + template + void SparseSmgRpatlModelChecker::coalitionIndicator(Environment& env, CheckTask const& checkTask) { + storm::storage::BitVector coalitionIndicators(this->getModel().getTransitionMatrix().getRowGroupCount()); + + std::vector> formulaPlayerIds = checkTask.getFormula().getCoalition().getPlayerIds(); + std::vector playerIds; + std::vector> playerActionIndices = this->getModel().getPlayerActionIndices(); + + for(auto const& player : formulaPlayerIds) { + // If the player is given via the player name we have to look up its index + if(player.type() == typeid(std::string)) { + auto it = std::find_if(playerActionIndices.begin(), playerActionIndices.end(), + [&player](const std::pair& element){ return element.first == boost::get(player); }); + playerIds.push_back(it->second); + // If the player is given by its index we have to shift it to match internal mappings + } else if(player.type() == typeid(uint_fast64_t)) { + playerIds.push_back(boost::get(player) - 1); + } + } + + for(uint i = 0; i < playerActionIndices.size(); i++) { + if(std::find(playerIds.begin(), playerIds.end(), playerActionIndices.at(i).second) != playerIds.end()) { + coalitionIndicators.set(i); + } + } + coalitionIndicators.complement(); + + env.solver().multiplier().setOptimizationDirectionOverride(coalitionIndicators); } template class SparseSmgRpatlModelChecker>; diff --git a/src/storm/modelchecker/rpatl/SparseSmgRpatlModelChecker.h b/src/storm/modelchecker/rpatl/SparseSmgRpatlModelChecker.h index 91bb09206..0927fc618 100644 --- a/src/storm/modelchecker/rpatl/SparseSmgRpatlModelChecker.h +++ b/src/storm/modelchecker/rpatl/SparseSmgRpatlModelChecker.h @@ -31,6 +31,8 @@ namespace storm { virtual std::unique_ptr computeLongRunAverageProbabilities(Environment const& env, CheckTask const& checkTask) override; virtual std::unique_ptr computeLongRunAverageRewards(Environment const& env, storm::logic::RewardMeasureType rewardMeasureType, CheckTask const& checkTask) override; + + void coalitionIndicator(Environment& env, CheckTask const& checkTask); }; } // namespace modelchecker } // namespace storm