|
|
@ -467,12 +467,12 @@ namespace storm { |
|
|
|
|
|
|
|
template<typename ValueType> |
|
|
|
MultiDimensionalRewardUnfolding<ValueType>::MemoryProduct::MemoryProduct(storm::models::sparse::Mdp<ValueType> const& model, storm::storage::MemoryStructure const& memory, std::vector<storm::storage::BitVector>&& memoryStateMap, std::vector<std::vector<uint64_t>> const& originalModelSteps, std::vector<storm::storage::BitVector> const& objectiveDimensions) : memoryStateMap(std::move(memoryStateMap)) { |
|
|
|
|
|
|
|
storm::storage::SparseModelMemoryProduct<ValueType> productBuilder(memory.product(model)); |
|
|
|
|
|
|
|
setReachableStates(productBuilder, originalModelSteps, objectiveDimensions); |
|
|
|
|
|
|
|
sw1.stop(); |
|
|
|
product = productBuilder.build()->template as<storm::models::sparse::Mdp<ValueType>>(); |
|
|
|
sw2.stop(); |
|
|
|
|
|
|
|
uint64_t numModelStates = productBuilder.getOriginalModel().getNumberOfStates(); |
|
|
|
uint64_t numMemoryStates = productBuilder.getMemory().getNumberOfStates(); |
|
|
@ -530,25 +530,46 @@ namespace storm { |
|
|
|
|
|
|
|
template<typename ValueType> |
|
|
|
void MultiDimensionalRewardUnfolding<ValueType>::MemoryProduct::setReachableStates(storm::storage::SparseModelMemoryProduct<ValueType>& productBuilder, std::vector<std::vector<uint64_t>> const& originalModelSteps, std::vector<storm::storage::BitVector> const& objectiveDimensions) const { |
|
|
|
std::vector<storm::storage::BitVector> additionalReachableStates(memoryStateMap.size(), storm::storage::BitVector(productBuilder.getOriginalModel().getNumberOfStates(), false)); |
|
|
|
for (uint64_t memState = 0; memState < memoryStateMap.size(); ++memState) { |
|
|
|
auto const& memStateBv = memoryStateMap[memState]; |
|
|
|
storm::storage::BitVector stepChoices(productBuilder.getOriginalModel().getTransitionMatrix().getRowCount(), false); |
|
|
|
for (auto const& subObjectives : objectiveDimensions) { |
|
|
|
if (subObjectives.isDisjointFrom(memStateBv)) { |
|
|
|
for (auto const& subObj : subObjectives) { |
|
|
|
stepChoices |= storm::utility::vector::filterGreaterZero(originalModelSteps[subObj]); |
|
|
|
} |
|
|
|
storm::storage::BitVector consideredObjectives(objectiveDimensions.size(), false); |
|
|
|
do { |
|
|
|
storm::storage::BitVector memStatePrimeBv = memStateBv; |
|
|
|
for (auto const& objIndex : consideredObjectives) { |
|
|
|
memStatePrimeBv &= ~objectiveDimensions[objIndex]; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
storm::storage::BitVector stepChoiceSuccessors(productBuilder.getOriginalModel().getNumberOfStates(), false); |
|
|
|
for (auto const& choice : stepChoices) { |
|
|
|
for (auto const& successor : productBuilder.getOriginalModel().getTransitionMatrix().getRow(choice)) { |
|
|
|
stepChoiceSuccessors.set(successor.getColumn(), true); |
|
|
|
if (memStatePrimeBv != memStateBv) { |
|
|
|
for (uint64_t choice = 0; choice < productBuilder.getOriginalModel().getTransitionMatrix().getRowCount(); ++choice) { |
|
|
|
bool consideredChoice = true; |
|
|
|
for (auto const& objIndex : consideredObjectives) { |
|
|
|
bool objectiveHasStep = false; |
|
|
|
for (auto const& dim : objectiveDimensions[objIndex]) { |
|
|
|
if (originalModelSteps[dim][choice] > 0) { |
|
|
|
objectiveHasStep = true; |
|
|
|
break; |
|
|
|
} |
|
|
|
} |
|
|
|
if (!objectiveHasStep) { |
|
|
|
consideredChoice = false; |
|
|
|
break; |
|
|
|
} |
|
|
|
} |
|
|
|
if (consideredChoice) { |
|
|
|
for (auto const& successor : productBuilder.getOriginalModel().getTransitionMatrix().getRow(choice)) { |
|
|
|
if (productBuilder.isStateReachable(successor.getColumn(), memState)) { |
|
|
|
additionalReachableStates[convertMemoryState(memStatePrimeBv)].set(successor.getColumn()); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
for (auto const& modelState : stepChoiceSuccessors) { |
|
|
|
consideredObjectives.increment(); |
|
|
|
} while (!consideredObjectives.empty()); |
|
|
|
} |
|
|
|
|
|
|
|
for (uint64_t memState = 0; memState < memoryStateMap.size(); ++memState) { |
|
|
|
for (auto const& modelState : additionalReachableStates[memState]) { |
|
|
|
productBuilder.addReachableState(modelState, memState); |
|
|
|
} |
|
|
|
} |
|
|
|