diff --git a/src/storm/modelchecker/multiobjective/rewardbounded/MultiDimensionalRewardUnfolding.cpp b/src/storm/modelchecker/multiobjective/rewardbounded/MultiDimensionalRewardUnfolding.cpp index 3e6991854..fd86fecb4 100644 --- a/src/storm/modelchecker/multiobjective/rewardbounded/MultiDimensionalRewardUnfolding.cpp +++ b/src/storm/modelchecker/multiobjective/rewardbounded/MultiDimensionalRewardUnfolding.cpp @@ -331,21 +331,22 @@ namespace storm { } // compute the solution for the stepChoices // For optimization purposes, we distinguish the case where the memory state does not have to be transformed + EpochSolution const& successorEpochSolution = getEpochSolution(subSolutions, successorEpoch); SolutionType choiceSolution; bool firstSuccessor = true; if (!containsLowerBoundedObjective && epochManager.compareEpochClass(epoch, successorEpoch)) { for (auto const& successor : productModel->getProduct().getTransitionMatrix().getRow(productChoice)) { if (firstSuccessor) { - choiceSolution = getScaledSolution(getStateSolution(subSolutions, successorEpoch, successor.getColumn()), successor.getValue()); + choiceSolution = getScaledSolution(getStateSolution(successorEpochSolution, successor.getColumn()), successor.getValue()); firstSuccessor = false; } else { - addScaledSolution(choiceSolution, getStateSolution(subSolutions, successorEpoch, successor.getColumn()), successor.getValue()); + addScaledSolution(choiceSolution, getStateSolution(successorEpochSolution, successor.getColumn()), successor.getValue()); } } } else { for (auto const& successor : productModel->getProduct().getTransitionMatrix().getRow(productChoice)) { uint64_t successorProductState = productModel->transformProductState(successor.getColumn(), successorEpochClass, memoryState); - SolutionType const& successorSolution = getStateSolution(subSolutions, successorEpoch, successorProductState); + SolutionType const& successorSolution = getStateSolution(successorEpochSolution, successorProductState); if (firstSuccessor) { choiceSolution = getScaledSolution(successorSolution, successor.getValue()); firstSuccessor = false; @@ -580,14 +581,17 @@ namespace storm { STORM_LOG_ASSERT((*epochSolution.productStateToSolutionVectorMap)[productState] < epochSolution.solutions.size(), "Requested solution for epoch " << epochManager.toString(epoch) << " at a state for which no solution was stored."); return epochSolution.solutions[(*epochSolution.productStateToSolutionVectorMap)[productState]]; } - template - typename MultiDimensionalRewardUnfolding::SolutionType const& MultiDimensionalRewardUnfolding::getStateSolution(std::map const& solutions, Epoch const& epoch, uint64_t const& productState) { + typename MultiDimensionalRewardUnfolding::EpochSolution const& MultiDimensionalRewardUnfolding::getEpochSolution(std::map const& solutions, Epoch const& epoch) { auto epochSolutionIt = solutions.find(epoch); STORM_LOG_ASSERT(epochSolutionIt != solutions.end(), "Requested unexisting solution for epoch " << epochManager.toString(epoch) << "."); - auto const& epochSolution = *epochSolutionIt->second; - STORM_LOG_ASSERT(productState < epochSolution.productStateToSolutionVectorMap->size(), "Requested solution for epoch " << epochManager.toString(epoch) << " at an unexisting product state."); - STORM_LOG_ASSERT((*epochSolution.productStateToSolutionVectorMap)[productState] < epochSolution.solutions.size(), "Requested solution for epoch " << epochManager.toString(epoch) << " at a state for which no solution was stored."); + return *epochSolutionIt->second; + } + + template + typename MultiDimensionalRewardUnfolding::SolutionType const& MultiDimensionalRewardUnfolding::getStateSolution(EpochSolution const& epochSolution, uint64_t const& productState) { + STORM_LOG_ASSERT(productState < epochSolution.productStateToSolutionVectorMap->size(), "Requested solution at an unexisting product state."); + STORM_LOG_ASSERT((*epochSolution.productStateToSolutionVectorMap)[productState] < epochSolution.solutions.size(), "Requested solution for epoch at a state for which no solution was stored."); return epochSolution.solutions[(*epochSolution.productStateToSolutionVectorMap)[productState]]; } diff --git a/src/storm/modelchecker/multiobjective/rewardbounded/MultiDimensionalRewardUnfolding.h b/src/storm/modelchecker/multiobjective/rewardbounded/MultiDimensionalRewardUnfolding.h index 782f53b2f..28c25747e 100644 --- a/src/storm/modelchecker/multiobjective/rewardbounded/MultiDimensionalRewardUnfolding.h +++ b/src/storm/modelchecker/multiobjective/rewardbounded/MultiDimensionalRewardUnfolding.h @@ -105,7 +105,8 @@ namespace storm { std::vector solutions; }; std::map epochSolutions; - SolutionType const& getStateSolution(std::map const& solutions, Epoch const& epoch, uint64_t const& productState); + EpochSolution const& getEpochSolution(std::map const& solutions, Epoch const& epoch); + SolutionType const& getStateSolution(EpochSolution const& epochSolution, uint64_t const& productState); storm::models::sparse::Mdp const& model; std::vector> objectives;