diff --git a/src/storm/modelchecker/multiobjective/deterministicScheds/DetSchedsWeightVectorChecker.cpp b/src/storm/modelchecker/multiobjective/deterministicScheds/DetSchedsWeightVectorChecker.cpp new file mode 100644 index 000000000..28296037a --- /dev/null +++ b/src/storm/modelchecker/multiobjective/deterministicScheds/DetSchedsWeightVectorChecker.cpp @@ -0,0 +1,80 @@ +#include "storm/modelchecker/multiobjective/deterministicScheds/DetSchedsWeightVectorChecker.h" + +namespace storm { + namespace modelchecker { + namespace multiobjective { + + template + DetSchedsWeightVectorChecker::DetSchedsWeightVectorChecker(std::shared_ptr> const& schedulerEvaluator) : schedulerEvaluator(schedulerEvaluator) { + // Intentionally left empty; + } + + template + std::vector::ValueType>> DetSchedsWeightVectorChecker::check(Environment const& env, std::vector const& weightVector) { + std::vector> resultStack; + auto const& transitionMatrix = schedulerEvaluator->getModel().getTransitionMatrix(); + auto const& choiceIndices = schedulerEvaluator->getModel().getNondeterministicChoiceIndices(); + + uint64_t const numObjectives = weightVector.size(); + + // perform policy-iteration and store the intermediate results on the stack + do { + schedulerEvaluator->check(env); + resultStack.push_back(schedulerEvaluator->getInitialStateResults()); + + auto const& stateResults = schedulerEvaluator->getResults(); + + // Check if scheduler choices can be improved + auto const& scheduler = schedulerEvaluator->getScheduler(); + for (uint64_t state = 0; state < scheduler.size(); ++state) { + uint64_t choiceOffset = choiceIndices[state]; + uint64_t numChoices = choiceIndices[state + 1] - choiceOffset; + uint64_t currChoice = scheduler[state]; + + ValueType currChoiceValue = storm::utility::zero(); + for (uint64_t objIndex = 0; objIndex < numObjectives; ++objIndex) { + currChoiceValue += weightVector[objIndex] * stateResults[objIndex][state]; + } + + for (uint64_t choice = 0; choice < numChoices; ++choice) { + // Skip the currently selected choice + if (choice == currChoice) { + continue; + } + + ValueType choiceValue = storm::utility::zero(); + for (uint64_t objIndex = 0; objIndex < numObjectives; ++objIndex) { + if (schedulerEvaluator->getSchedulerIndependentStates(objIndex).get(state)) { + choiceValue += weightVector[objIndex] * stateResults[objIndex][state]; + } else { + ValueType objValue = storm::utility::zero(); + for (auto const& entry : transitionMatrix.getRow(choiceOffset + choice)) { + objValue += entry.getValue() * stateResults[objIndex][entry.getColumn()]; + } + choiceValue += weightVector[objIndex] * objValue; + } + } + + if (choiceValue > currChoiceValue) { + schedulerEvaluator->setChoiceAtState(state, choice); + } + } + } + } while (!schedulerEvaluator->hasCurrentSchedulerBeenChecked()); + + } + + template + std::vector::ValueType> const& DetSchedsWeightVectorChecker::getResultForAllStates(uint64_t objIndex) const { + return schedulerEvaluator->getResultForObjective(objIndex); + } + + template + std::vector const& DetSchedsWeightVectorChecker::getScheduler() const { + return schedulerEvaluator->getScheduler(); + } + + + } + } +} \ No newline at end of file diff --git a/src/storm/modelchecker/multiobjective/deterministicScheds/DetSchedsWeightVectorChecker.h b/src/storm/modelchecker/multiobjective/deterministicScheds/DetSchedsWeightVectorChecker.h new file mode 100644 index 000000000..b7d253586 --- /dev/null +++ b/src/storm/modelchecker/multiobjective/deterministicScheds/DetSchedsWeightVectorChecker.h @@ -0,0 +1,39 @@ +#pragma once + +#include + +#include "storm/modelchecker/multiobjective/deterministicScheds/MultiObjectiveSchedulerEvaluator.h" + +namespace storm { + + class Environment; + + namespace modelchecker { + namespace multiobjective { + + template + class DetSchedsWeightVectorChecker { + public: + + typedef typename ModelType::ValueType ValueType; + + DetSchedsWeightVectorChecker(std::shared_ptr> const& schedulerEvaluator); + + /*! + * Optimizes the objectives in the given direction. + * Returns a sequence of points such that all points are achievable and the last point is the farest point in the given direction. + * After calling this, getResultForAllStates and getScheduler yield results with respect to that last point. + */ + std::vector> check(Environment const& env, std::vector const& weightVector); + + std::vector const& getResultForAllStates(uint64_t objIndex) const; + std::vector const& getScheduler() const; + + private: + std::shared_ptr> schedulerEvaluator; + + }; + + } + } +} \ No newline at end of file