|
@ -64,8 +64,6 @@ namespace storm { |
|
|
template<typename ValueType, bool SingleObjectiveMode> |
|
|
template<typename ValueType, bool SingleObjectiveMode> |
|
|
void MultiDimensionalRewardUnfolding<ValueType, SingleObjectiveMode>::initialize(std::set<storm::expressions::Variable> const& infinityBoundVariables) { |
|
|
void MultiDimensionalRewardUnfolding<ValueType, SingleObjectiveMode>::initialize(std::set<storm::expressions::Variable> const& infinityBoundVariables) { |
|
|
|
|
|
|
|
|
maxSolutionsStored = 0; |
|
|
|
|
|
|
|
|
|
|
|
STORM_LOG_ASSERT(!SingleObjectiveMode || (this->objectives.size() == 1), "Enabled single objective mode but there are multiple objectives."); |
|
|
STORM_LOG_ASSERT(!SingleObjectiveMode || (this->objectives.size() == 1), "Enabled single objective mode but there are multiple objectives."); |
|
|
std::vector<Epoch> epochSteps; |
|
|
std::vector<Epoch> epochSteps; |
|
|
initializeObjectives(epochSteps, infinityBoundVariables); |
|
|
initializeObjectives(epochSteps, infinityBoundVariables); |
|
@ -333,41 +331,27 @@ namespace storm { |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
template<typename ValueType, bool SingleObjectiveMode> |
|
|
template<typename ValueType, bool SingleObjectiveMode> |
|
|
std::vector<typename MultiDimensionalRewardUnfolding<ValueType, SingleObjectiveMode>::Epoch> MultiDimensionalRewardUnfolding<ValueType, SingleObjectiveMode>::getEpochComputationOrder(Epoch const& startEpoch) { |
|
|
|
|
|
|
|
|
std::vector<typename MultiDimensionalRewardUnfolding<ValueType, SingleObjectiveMode>::Epoch> MultiDimensionalRewardUnfolding<ValueType, SingleObjectiveMode>::getEpochComputationOrder(Epoch const& startEpoch, bool stopAtComputedEpochs) { |
|
|
// Perform a DFS to find all the reachable epochs
|
|
|
// Perform a DFS to find all the reachable epochs
|
|
|
std::vector<Epoch> dfsStack; |
|
|
std::vector<Epoch> dfsStack; |
|
|
std::set<Epoch, std::function<bool(Epoch const&, Epoch const&)>> collectedEpochs(std::bind(&EpochManager::epochClassZigZagOrder, &epochManager, std::placeholders::_1, std::placeholders::_2)); |
|
|
std::set<Epoch, std::function<bool(Epoch const&, Epoch const&)>> collectedEpochs(std::bind(&EpochManager::epochClassZigZagOrder, &epochManager, std::placeholders::_1, std::placeholders::_2)); |
|
|
|
|
|
|
|
|
|
|
|
if (!stopAtComputedEpochs || epochSolutions.count(startEpoch) == 0) { |
|
|
collectedEpochs.insert(startEpoch); |
|
|
collectedEpochs.insert(startEpoch); |
|
|
dfsStack.push_back(startEpoch); |
|
|
dfsStack.push_back(startEpoch); |
|
|
|
|
|
} |
|
|
while (!dfsStack.empty()) { |
|
|
while (!dfsStack.empty()) { |
|
|
Epoch currentEpoch = dfsStack.back(); |
|
|
Epoch currentEpoch = dfsStack.back(); |
|
|
dfsStack.pop_back(); |
|
|
dfsStack.pop_back(); |
|
|
for (auto const& step : possibleEpochSteps) { |
|
|
for (auto const& step : possibleEpochSteps) { |
|
|
Epoch successorEpoch = epochManager.getSuccessorEpoch(currentEpoch, step); |
|
|
Epoch successorEpoch = epochManager.getSuccessorEpoch(currentEpoch, step); |
|
|
/*
|
|
|
|
|
|
for (auto const& e : collectedEpochs) { |
|
|
|
|
|
std::cout << "Comparing " << epochManager.toString(e) << " and " << epochManager.toString(successorEpoch) << std::endl; |
|
|
|
|
|
if (epochManager.epochClassZigZagOrder(e, successorEpoch)) { |
|
|
|
|
|
std::cout << " " << epochManager.toString(e) << " < " << epochManager.toString(successorEpoch) << std::endl; |
|
|
|
|
|
} |
|
|
|
|
|
if (epochManager.epochClassZigZagOrder(successorEpoch, e)) { |
|
|
|
|
|
std::cout << " " << epochManager.toString(e) << " > " << epochManager.toString(successorEpoch) << std::endl; |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
*/ |
|
|
|
|
|
|
|
|
if (!stopAtComputedEpochs || epochSolutions.count(successorEpoch) == 0) { |
|
|
if (collectedEpochs.insert(successorEpoch).second) { |
|
|
if (collectedEpochs.insert(successorEpoch).second) { |
|
|
dfsStack.push_back(std::move(successorEpoch)); |
|
|
dfsStack.push_back(std::move(successorEpoch)); |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
/*
|
|
|
|
|
|
std::cout << "Resulting order: "; |
|
|
|
|
|
for (auto const& e : collectedEpochs) { |
|
|
|
|
|
std::cout << epochManager.toString(e) << ", "; |
|
|
|
|
|
} |
|
|
} |
|
|
std::cout << std::endl; |
|
|
|
|
|
*/ |
|
|
|
|
|
return std::vector<Epoch>(collectedEpochs.begin(), collectedEpochs.end()); |
|
|
return std::vector<Epoch>(collectedEpochs.begin(), collectedEpochs.end()); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
@ -629,8 +613,6 @@ namespace storm { |
|
|
epochModel.objectiveRewardFilter.push_back(storm::utility::vector::filterZero(objRewards)); |
|
|
epochModel.objectiveRewardFilter.push_back(storm::utility::vector::filterZero(objRewards)); |
|
|
epochModel.objectiveRewardFilter.back().complement(); |
|
|
epochModel.objectiveRewardFilter.back().complement(); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
epochModelSizes.push_back(epochModel.epochMatrix.getRowGroupCount()); |
|
|
|
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -846,9 +828,6 @@ namespace storm { |
|
|
solution.productStateToSolutionVectorMap = productStateToEpochModelInStateMap; |
|
|
solution.productStateToSolutionVectorMap = productStateToEpochModelInStateMap; |
|
|
solution.solutions = std::move(inStateSolutions); |
|
|
solution.solutions = std::move(inStateSolutions); |
|
|
epochSolutions[currentEpoch.get()] = std::move(solution); |
|
|
epochSolutions[currentEpoch.get()] = std::move(solution); |
|
|
|
|
|
|
|
|
maxSolutionsStored = std::max((uint64_t) epochSolutions.size(), maxSolutionsStored); |
|
|
|
|
|
|
|
|
|
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
template<typename ValueType, bool SingleObjectiveMode> |
|
|
template<typename ValueType, bool SingleObjectiveMode> |
|
|