diff --git a/src/storm/modelchecker/multiobjective/rewardbounded/MultiDimensionalRewardUnfolding.cpp b/src/storm/modelchecker/multiobjective/rewardbounded/MultiDimensionalRewardUnfolding.cpp index 9af9fe78a..0c18046af 100644 --- a/src/storm/modelchecker/multiobjective/rewardbounded/MultiDimensionalRewardUnfolding.cpp +++ b/src/storm/modelchecker/multiobjective/rewardbounded/MultiDimensionalRewardUnfolding.cpp @@ -288,7 +288,8 @@ namespace storm { for (auto const& reducedChoice : epochModel.stepChoices) { uint64_t productChoice = epochModelToProductChoiceMap[reducedChoice]; uint64_t productState = productModel->getProductStateFromChoice(productChoice); - auto memoryState = productModel->convertMemoryState(productModel->getMemoryState(productState)); + auto const& memoryState = productModel->getMemoryState(productState); + auto const& memoryStateBv = productModel->convertMemoryState(memoryState); Epoch successorEpoch = epochManager.getSuccessorEpoch(epoch, productModel->getSteps()[productChoice]); // Find out whether objective reward is earned for the current choice @@ -299,7 +300,7 @@ namespace storm { bool rewardEarned = !storm::utility::isZero(epochModel.objectiveRewards[objIndex][reducedChoice]); if (rewardEarned) { for (auto const& dim : objectiveDimensions[objIndex]) { - if (dimensions[dim].isUpperBounded == epochManager.isBottomDimension(successorEpoch, dim) && memoryState.get(dim)) { + if (dimensions[dim].isUpperBounded == epochManager.isBottomDimension(successorEpoch, dim) && memoryStateBv.get(dim)) { rewardEarned = false; break; } @@ -308,7 +309,7 @@ namespace storm { epochModel.objectiveRewardFilter[objIndex].set(reducedChoice, rewardEarned); } // compute the solution for the stepChoices - // For optimization purposes, we distinguish the easier case where the successor epoch lies in the same epoch class + // For optimization purposes, we distinguish the case where the memory state does not have to be transformed SolutionType choiceSolution; bool firstSuccessor = true; if (!containsLowerBoundedObjective && epochManager.compareEpochClass(epoch, successorEpoch)) { @@ -321,18 +322,8 @@ namespace storm { } } } else { - storm::storage::BitVector allowedRelevantDimensions(epochManager.getDimensionCount(), true); - storm::storage::BitVector forcedRelevantDimensions(epochManager.getDimensionCount(), false); - for (auto const& dim : memoryState) { - if (epochManager.isBottomDimension(successorEpoch, dim) && dimensions[dim].isUpperBounded) { - allowedRelevantDimensions &= ~objectiveDimensions[dimensions[dim].objectiveIndex]; - } else if (!epochManager.isBottomDimension(successorEpoch, dim) && !dimensions[dim].isUpperBounded) { - forcedRelevantDimensions.set(dim, true); - } - } for (auto const& successor : productModel->getProduct().getTransitionMatrix().getRow(productChoice)) { - storm::storage::BitVector successorMemoryState = (productModel->convertMemoryState(productModel->getMemoryState(successor.getColumn())) | forcedRelevantDimensions) & allowedRelevantDimensions; - uint64_t successorProductState = productModel->getProductState(productModel->getModelState(successor.getColumn()), productModel->convertMemoryState(successorMemoryState)); + uint64_t successorProductState = productModel->transformProductState(successor.getColumn(), epochManager.getEpochClass(successorEpoch), memoryState); SolutionType const& successorSolution = getStateSolution(successorEpoch, successorProductState); if (firstSuccessor) { choiceSolution = getScaledSolution(successorSolution, successor.getValue()); @@ -398,12 +389,9 @@ namespace storm { } if (!violatedLowerBoundedDimensions.empty()) { for (uint64_t state = 0; state < epochModel.epochMatrix.getRowGroupCount(); ++state) { - auto const& memoryState = productModel->convertMemoryState(productModel->getMemoryState(state)); - storm::storage::BitVector forcedRelevantDimensions = memoryState & violatedLowerBoundedDimensions; + auto const& memoryState = productModel->getMemoryState(state); for (auto& entry : epochModel.epochMatrix.getRowGroup(state)) { - storm::storage::BitVector successorMemoryState = productModel->convertMemoryState(productModel->getMemoryState(entry.getColumn())); - successorMemoryState |= forcedRelevantDimensions; - entry.setColumn(productModel->getProductState(productModel->getModelState(entry.getColumn()), productModel->convertMemoryState(successorMemoryState))); + entry.setColumn(productModel->transformProductState(entry.getColumn(), epochClass, memoryState)); } } } diff --git a/src/storm/modelchecker/multiobjective/rewardbounded/ProductModel.cpp b/src/storm/modelchecker/multiobjective/rewardbounded/ProductModel.cpp index 782b34f03..765febc24 100644 --- a/src/storm/modelchecker/multiobjective/rewardbounded/ProductModel.cpp +++ b/src/storm/modelchecker/multiobjective/rewardbounded/ProductModel.cpp @@ -71,7 +71,6 @@ namespace storm { } } } - computeReachableStatesInEpochClasses(); } @@ -325,6 +324,22 @@ namespace storm { void ProductModel::computeReachableStatesInEpochClasses() { std::set possibleSteps(steps.begin(), steps.end()); std::set> reachableEpochClasses(std::bind(&EpochManager::epochClassOrder, &epochManager, std::placeholders::_1, std::placeholders::_2)); + + collectReachableEpochClasses(reachableEpochClasses, possibleSteps); + + for (auto epochClassIt = reachableEpochClasses.rbegin(); epochClassIt != reachableEpochClasses.rend(); ++epochClassIt) { + std::vector predecessors; + for (auto predecessorIt = reachableEpochClasses.rbegin(); predecessorIt != epochClassIt; ++predecessorIt) { + if (epochManager.isPredecessorEpochClass(*predecessorIt, *epochClassIt)) { + predecessors.push_back(*predecessorIt); + } + } + computeReachableStates(*epochClassIt, predecessors); + } + } + + template + void ProductModel::collectReachableEpochClasses(std::set>& reachableEpochClasses, std::set const& possibleSteps) const { std::vector candidates({epochManager.getBottomEpoch()}); std::set newCandidates; @@ -340,19 +355,8 @@ namespace storm { candidates.assign(newCandidates.begin(), newCandidates.end()); newCandidates.clear(); } - - for (auto epochClassIt = reachableEpochClasses.rbegin(); epochClassIt != reachableEpochClasses.rend(); ++epochClassIt) { - std::vector predecessors; - for (auto predecessorIt = reachableEpochClasses.rbegin(); predecessorIt != epochClassIt; ++predecessorIt) { - if (epochManager.isPredecessorEpochClass(*predecessorIt, *epochClassIt)) { - predecessors.push_back(*predecessorIt); - } - } - computeReachableStates(*epochClassIt, predecessors); - } } - template void ProductModel::computeReachableStates(EpochClass const& epochClass, std::vector const& predecessors) { @@ -364,24 +368,12 @@ namespace storm { } storm::storage::BitVector nonBottomDimensions = ~bottomDimensions; - // Bottom dimensions corresponding to upper bounded subobjectives can not be relevant anymore - // Dimensions with a lower bound where the epoch class is not bottom should stay relevant - storm::storage::BitVector allowedRelevantDimensions(epochManager.getDimensionCount(), true); - storm::storage::BitVector forcedRelevantDimensions(epochManager.getDimensionCount(), false); - for (uint64_t dim = 0; dim < epochManager.getDimensionCount(); ++dim) { - if (dimensions[dim].isUpperBounded && bottomDimensions.get(dim)) { - allowedRelevantDimensions.set(dim, false); - } else if (!dimensions[dim].isUpperBounded && nonBottomDimensions.get(dim)) { - forcedRelevantDimensions.set(dim, true); - } - } - assert(forcedRelevantDimensions.isSubsetOf(allowedRelevantDimensions)); - storm::storage::BitVector ecInStates(getProduct().getNumberOfStates(), false); if (!epochManager.hasBottomDimensionEpochClass(epochClass)) { + storm::storage::BitVector initMemState(epochManager.getDimensionCount(), true); for (auto const& initState : getProduct().getInitialStates()) { - uint64_t transformedInitState = transformProductState(initState, allowedRelevantDimensions, forcedRelevantDimensions); + uint64_t transformedInitState = transformProductState(initState, epochClass, convertMemoryState(initMemState)); ecInStates.set(transformedInitState, true); } } @@ -396,15 +388,6 @@ namespace storm { STORM_LOG_ASSERT(reachableStates.find(predecessor) != reachableStates.end(), "Could not find reachable states of predecessor epoch class."); storm::storage::BitVector predecessorStates = reachableStates.find(predecessor)->second; for (auto const& predecessorState : predecessorStates) { - storm::storage::BitVector const& predecessorMemStateBv = convertMemoryState(getMemoryState(predecessorState)); - storm::storage::BitVector currentAllowedRelDim = allowedRelevantDimensions; - for (uint64_t dim = 0; dim < epochManager.getDimensionCount(); ++dim) { - if (!allowedRelevantDimensions.get(dim) && predecessorMemStateBv.get(dim)) { - currentAllowedRelDim &= ~objectiveDimensions[dimensions[dim].objectiveIndex]; - } - } - storm::storage::BitVector currentForcedRelDim = forcedRelevantDimensions & predecessorMemStateBv; - for (uint64_t choice = getProduct().getTransitionMatrix().getRowGroupIndices()[predecessorState]; choice < getProduct().getTransitionMatrix().getRowGroupIndices()[predecessorState + 1]; ++choice) { bool choiceLeadsToThisClass = false; Epoch const& choiceStep = getSteps()[choice]; @@ -416,7 +399,8 @@ namespace storm { if (choiceLeadsToThisClass) { for (auto const& transition : getProduct().getTransitionMatrix().getRow(choice)) { - uint64_t successorState = transformProductState(transition.getColumn(), currentAllowedRelDim, currentForcedRelDim); + uint64_t successorState = transformProductState(transition.getColumn(), epochClass, getMemoryState(predecessorState)); + ecInStates.set(successorState, true); } } @@ -432,16 +416,6 @@ namespace storm { uint64_t currentState = dfsStack.back(); dfsStack.pop_back(); - storm::storage::BitVector const& currentMemStateBv = convertMemoryState(getMemoryState(currentState)); - storm::storage::BitVector currentAllowedRelDim = allowedRelevantDimensions; - for (uint64_t dim = 0; dim < epochManager.getDimensionCount(); ++dim) { - if (!allowedRelevantDimensions.get(dim) && currentMemStateBv.get(dim)) { - currentAllowedRelDim &= ~objectiveDimensions[dimensions[dim].objectiveIndex]; - } - } - storm::storage::BitVector currentForcedRelDim = forcedRelevantDimensions & currentMemStateBv; - - for (uint64_t choice = getProduct().getTransitionMatrix().getRowGroupIndices()[currentState]; choice != getProduct().getTransitionMatrix().getRowGroupIndices()[currentState + 1]; ++choice) { bool choiceLeadsOutsideOfEpoch = false; @@ -453,7 +427,7 @@ namespace storm { } for (auto const& transition : getProduct().getTransitionMatrix().getRow(choice)) { - uint64_t successorState = transformProductState(transition.getColumn(), currentAllowedRelDim, currentForcedRelDim); + uint64_t successorState = transformProductState(transition.getColumn(), epochClass, getMemoryState(currentState)); if (choiceLeadsOutsideOfEpoch) { ecInStates.set(successorState, true); } @@ -471,15 +445,31 @@ namespace storm { } template - uint64_t ProductModel::transformProductState(uint64_t productState, storm::storage::BitVector const& allowedRelevantDimensions, storm::storage::BitVector const& forcedRelevantDimensions) const { - if (allowedRelevantDimensions.full() && forcedRelevantDimensions.empty()) { - return productState; - } else { - storm::storage::BitVector memoryStateBv = (convertMemoryState(getMemoryState(productState)) | forcedRelevantDimensions) & allowedRelevantDimensions; - return getProductState(getModelState(productState), convertMemoryState(memoryStateBv)); + uint64_t ProductModel::transformMemoryState(uint64_t const& memoryState, EpochClass const& epochClass, uint64_t const& predecessorMemoryState) const { + storm::storage::BitVector memoryStateBv = convertMemoryState(memoryState); + storm::storage::BitVector const& predecessorMemoryStateBv = convertMemoryState(predecessorMemoryState); + + for (uint64_t objIndex = 0; objIndex < objectiveDimensions.size(); ++objIndex) { + for (auto const& dim : objectiveDimensions[objIndex]) { + bool dimUpperBounded = dimensions[dim].isUpperBounded; + bool dimBottom = epochManager.isBottomDimensionEpochClass(epochClass, dim); + if (dimUpperBounded && dimBottom && predecessorMemoryStateBv.get(dim)) { + memoryStateBv &= ~objectiveDimensions[objIndex]; + break; + } else if (!dimUpperBounded && !dimBottom && predecessorMemoryStateBv.get(dim)) { + memoryStateBv.set(dim, true); + } + } } + + return convertMemoryState(memoryStateBv); } + template + uint64_t ProductModel::transformProductState(uint64_t const& productState, EpochClass const& epochClass, uint64_t const& predecessorMemoryState) const { + return getProductState(getModelState(productState), transformMemoryState(getMemoryState(productState), epochClass, predecessorMemoryState)); + } + template class ProductModel; template class ProductModel; diff --git a/src/storm/modelchecker/multiobjective/rewardbounded/ProductModel.h b/src/storm/modelchecker/multiobjective/rewardbounded/ProductModel.h index b0663145d..f030c1812 100644 --- a/src/storm/modelchecker/multiobjective/rewardbounded/ProductModel.h +++ b/src/storm/modelchecker/multiobjective/rewardbounded/ProductModel.h @@ -44,15 +44,20 @@ namespace storm { std::vector> computeObjectiveRewards(EpochClass const& epochClass, std::vector> const& objectives) const; storm::storage::BitVector const& getInStates(EpochClass const& epochClass) const; + + uint64_t transformMemoryState(uint64_t const& memoryState, EpochClass const& epochClass, uint64_t const& predecessorMemoryState) const; + uint64_t transformProductState(uint64_t const& productState, EpochClass const& epochClass, uint64_t const& predecessorMemoryState) const; + private: void setReachableProductStates(storm::storage::SparseModelMemoryProduct& productBuilder, std::vector const& originalModelSteps) const; + void collectReachableEpochClasses(std::set>& reachableEpochClasses, std::set const& possibleSteps) const; + void computeReachableStatesInEpochClasses(); void computeReachableStates(EpochClass const& epochClass, std::vector const& predecessors); - uint64_t transformProductState(uint64_t productState, storm::storage::BitVector const& allowedRelevantDimensions, storm::storage::BitVector const& forcedRelevantDimensions) const; - + std::vector> const& dimensions; std::vector const& objectiveDimensions; EpochManager const& epochManager;