|
@ -55,6 +55,9 @@ namespace storm { |
|
|
|
|
|
|
|
|
template<typename ValueType, bool SingleObjectiveMode> |
|
|
template<typename ValueType, bool SingleObjectiveMode> |
|
|
void MultiDimensionalRewardUnfolding<ValueType, SingleObjectiveMode>::initialize() { |
|
|
void MultiDimensionalRewardUnfolding<ValueType, SingleObjectiveMode>::initialize() { |
|
|
|
|
|
|
|
|
|
|
|
maxSolutionsStored = 0; |
|
|
|
|
|
|
|
|
swInit.start(); |
|
|
swInit.start(); |
|
|
STORM_LOG_ASSERT(!SingleObjectiveMode || (this->objectives.size() == 1), "Enabled single objective mode but there are multiple objectives."); |
|
|
STORM_LOG_ASSERT(!SingleObjectiveMode || (this->objectives.size() == 1), "Enabled single objective mode but there are multiple objectives."); |
|
|
std::vector<Epoch> epochSteps; |
|
|
std::vector<Epoch> epochSteps; |
|
@ -438,28 +441,57 @@ namespace storm { |
|
|
storm::storage::BitVector MultiDimensionalRewardUnfolding<ValueType, SingleObjectiveMode>::computeProductInStatesForEpochClass(Epoch const& epoch) { |
|
|
storm::storage::BitVector MultiDimensionalRewardUnfolding<ValueType, SingleObjectiveMode>::computeProductInStatesForEpochClass(Epoch const& epoch) { |
|
|
storm::storage::SparseMatrix<ValueType> const& productMatrix = memoryProduct.getProduct().getTransitionMatrix(); |
|
|
storm::storage::SparseMatrix<ValueType> const& productMatrix = memoryProduct.getProduct().getTransitionMatrix(); |
|
|
|
|
|
|
|
|
storm::storage::BitVector result = memoryProduct.getProduct().getInitialStates(); |
|
|
|
|
|
|
|
|
// Initialize the result. Initial states are only considered if the epoch contains no bottom dimension.
|
|
|
|
|
|
storm::storage::BitVector result; |
|
|
|
|
|
bool epochHasBottomDimension = false; |
|
|
|
|
|
for (uint64_t dim = 0; dim < dimensionCount; ++dim) { |
|
|
|
|
|
if (isBottomDimension(epoch, dim)) { |
|
|
|
|
|
epochHasBottomDimension = true; |
|
|
|
|
|
break; |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
if (epochHasBottomDimension) { |
|
|
|
|
|
result = storm::storage::BitVector(memoryProduct.getProduct().getNumberOfStates()); |
|
|
|
|
|
} else { |
|
|
|
|
|
result = memoryProduct.getProduct().getInitialStates(); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// Compute the set of objectives that can not be satisfied anymore in the current epoch
|
|
|
|
|
|
storm::storage::BitVector irrelevantObjectives(objectives.size(), false); |
|
|
|
|
|
for (uint64_t objIndex = 0; objIndex < objectives.size(); ++objIndex) { |
|
|
|
|
|
bool objIrrelevant = true; |
|
|
|
|
|
for (auto const& dim : objectiveDimensions[objIndex]) { |
|
|
|
|
|
if (!isBottomDimension(epoch, dim)) { |
|
|
|
|
|
objIrrelevant = false; |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
if (objIrrelevant) { |
|
|
|
|
|
irrelevantObjectives.set(objIndex, true); |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
// Perform DFS
|
|
|
// Perform DFS
|
|
|
storm::storage::BitVector reachableStates = result; |
|
|
|
|
|
|
|
|
storm::storage::BitVector reachableStates = memoryProduct.getProduct().getInitialStates(); |
|
|
std::vector<uint_fast64_t> stack(reachableStates.begin(), reachableStates.end()); |
|
|
std::vector<uint_fast64_t> stack(reachableStates.begin(), reachableStates.end()); |
|
|
|
|
|
|
|
|
// std::cout << "Computing product In states for epoch " << epochToString(epoch) << std::endl;
|
|
|
|
|
|
|
|
|
|
|
|
while (!stack.empty()) { |
|
|
while (!stack.empty()) { |
|
|
uint64_t state = stack.back(); |
|
|
uint64_t state = stack.back(); |
|
|
stack.pop_back(); |
|
|
stack.pop_back(); |
|
|
|
|
|
|
|
|
for (uint64_t choice = productMatrix.getRowGroupIndices()[state]; choice < productMatrix.getRowGroupIndices()[state + 1]; ++choice) { |
|
|
for (uint64_t choice = productMatrix.getRowGroupIndices()[state]; choice < productMatrix.getRowGroupIndices()[state + 1]; ++choice) { |
|
|
auto const& choiceStep = memoryProduct.getSteps()[choice]; |
|
|
auto const& choiceStep = memoryProduct.getSteps()[choice]; |
|
|
if (!isZeroEpoch(choiceStep)) { |
|
|
if (!isZeroEpoch(choiceStep)) { |
|
|
storm::storage::BitVector objectiveSet(objectives.size(), false); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Compute the set of objectives that might or might not become irrelevant when the epoch is reached via the current choice
|
|
|
|
|
|
storm::storage::BitVector maybeIrrelevantObjectives(objectives.size(), false); |
|
|
for (uint64_t dim = 0; dim < dimensionCount; ++dim) { |
|
|
for (uint64_t dim = 0; dim < dimensionCount; ++dim) { |
|
|
if (isBottomDimension(epoch, dim) && getDimensionOfEpoch(choiceStep, dim) > 0) { |
|
|
if (isBottomDimension(epoch, dim) && getDimensionOfEpoch(choiceStep, dim) > 0) { |
|
|
objectiveSet.set(subObjectives[dim].second); |
|
|
|
|
|
|
|
|
maybeIrrelevantObjectives.set(subObjectives[dim].second); |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
maybeIrrelevantObjectives &= ~irrelevantObjectives; |
|
|
|
|
|
|
|
|
if (objectiveSet.empty()) { |
|
|
|
|
|
|
|
|
// For optimization purposes, we treat the case that all objectives will be relevant separately
|
|
|
|
|
|
if (maybeIrrelevantObjectives.empty() && irrelevantObjectives.empty()) { |
|
|
for (auto const& choiceSuccessor : productMatrix.getRow(choice)) { |
|
|
for (auto const& choiceSuccessor : productMatrix.getRow(choice)) { |
|
|
result.set(choiceSuccessor.getColumn(), true); |
|
|
result.set(choiceSuccessor.getColumn(), true); |
|
|
if (!reachableStates.get(choiceSuccessor.getColumn())) { |
|
|
if (!reachableStates.get(choiceSuccessor.getColumn())) { |
|
@ -468,29 +500,43 @@ namespace storm { |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
} else { |
|
|
} else { |
|
|
storm::storage::BitVector objectiveSubSet(objectiveSet.getNumberOfSetBits(), false); |
|
|
|
|
|
|
|
|
// Enumerate all possible combinations of maybe relevant objectives
|
|
|
|
|
|
storm::storage::BitVector maybeObjSubset(maybeIrrelevantObjectives.getNumberOfSetBits(), false); |
|
|
do { |
|
|
do { |
|
|
for (auto const& choiceSuccessor : productMatrix.getRow(choice)) { |
|
|
for (auto const& choiceSuccessor : productMatrix.getRow(choice)) { |
|
|
uint64_t modelState = memoryProduct.getModelState(choiceSuccessor.getColumn()); |
|
|
|
|
|
uint64_t memoryState = memoryProduct.getMemoryState(choiceSuccessor.getColumn()); |
|
|
|
|
|
storm::storage::BitVector memoryStatePrimeBv = memoryProduct.convertMemoryState(memoryState); |
|
|
|
|
|
|
|
|
// Compute the successor memory state for the current objective-subset and transition
|
|
|
|
|
|
storm::storage::BitVector successorMemoryState = memoryProduct.convertMemoryState(memoryProduct.getMemoryState(choiceSuccessor.getColumn())); |
|
|
|
|
|
// Unselect dimensions belonging to irrelevant objectives
|
|
|
|
|
|
for (auto const& irrelevantObjIndex : irrelevantObjectives) { |
|
|
|
|
|
successorMemoryState &= ~objectiveDimensions[irrelevantObjIndex]; |
|
|
|
|
|
} |
|
|
|
|
|
// Unselect objectives that are not in the current subset of maybe relevant objectives
|
|
|
|
|
|
// We can skip a subset if it selects an objective that is irrelevant anyway (according to the original successor memory state).
|
|
|
|
|
|
bool skipThisSubSet = false; |
|
|
uint64_t i = 0; |
|
|
uint64_t i = 0; |
|
|
for (auto const& objIndex : objectiveSet) { |
|
|
|
|
|
if (objectiveSubSet.get(i)) { |
|
|
|
|
|
memoryStatePrimeBv &= ~objectiveDimensions[objIndex]; |
|
|
|
|
|
|
|
|
for (auto const& objIndex : maybeIrrelevantObjectives) { |
|
|
|
|
|
if (maybeObjSubset.get(i)) { |
|
|
|
|
|
if (successorMemoryState.isDisjointFrom(objectiveDimensions[objIndex])) { |
|
|
|
|
|
skipThisSubSet = true; |
|
|
|
|
|
break; |
|
|
|
|
|
} else { |
|
|
|
|
|
successorMemoryState &= ~objectiveDimensions[objIndex]; |
|
|
|
|
|
} |
|
|
} |
|
|
} |
|
|
++i; |
|
|
++i; |
|
|
} |
|
|
} |
|
|
uint64_t successorState = memoryProduct.getProductState(modelState, memoryProduct.convertMemoryState(memoryStatePrimeBv)); |
|
|
|
|
|
|
|
|
if (!skipThisSubSet) { |
|
|
|
|
|
uint64_t successorState = memoryProduct.getProductState(memoryProduct.getModelState(choiceSuccessor.getColumn()), memoryProduct.convertMemoryState(successorMemoryState)); |
|
|
result.set(successorState, true); |
|
|
result.set(successorState, true); |
|
|
if (!reachableStates.get(successorState)) { |
|
|
if (!reachableStates.get(successorState)) { |
|
|
reachableStates.set(successorState); |
|
|
reachableStates.set(successorState); |
|
|
stack.push_back(successorState); |
|
|
stack.push_back(successorState); |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
objectiveSubSet.increment(); |
|
|
|
|
|
} while (!objectiveSubSet.empty()); |
|
|
|
|
|
|
|
|
maybeObjSubset.increment(); |
|
|
|
|
|
} while (!maybeObjSubset.empty()); |
|
|
} |
|
|
} |
|
|
} else { |
|
|
} else { |
|
|
for (auto const& choiceSuccessor : productMatrix.getRow(choice)) { |
|
|
for (auto const& choiceSuccessor : productMatrix.getRow(choice)) { |
|
@ -502,6 +548,7 @@ namespace storm { |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
return result; |
|
|
return result; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
@ -573,6 +620,7 @@ namespace storm { |
|
|
|
|
|
|
|
|
++solIt; |
|
|
++solIt; |
|
|
} |
|
|
} |
|
|
|
|
|
maxSolutionsStored = std::max((uint64_t) solutions.size(), maxSolutionsStored); |
|
|
|
|
|
|
|
|
swInsertSol.stop(); |
|
|
swInsertSol.stop(); |
|
|
} |
|
|
} |
|
@ -580,12 +628,14 @@ namespace storm { |
|
|
/*!
 * Stores the given solution for the given product state in the current epoch (copying overload).
 *
 * @param productState the product state that keys the entry together with the current epoch.
 * @param solution the solution that is copy-assigned into the `solutions` entry.
 *
 * @note Asserts that `currentEpoch` has been set before this is called.
 */
// NOTE(review): every statement below appears twice in this view; this looks like residue of a side-by-side diff rather than intended code -- confirm against the real file.
template<typename ValueType, bool SingleObjectiveMode> |
|
|
template<typename ValueType, bool SingleObjectiveMode> |
|
|
void MultiDimensionalRewardUnfolding<ValueType, SingleObjectiveMode>::setSolutionForCurrentEpoch(uint64_t const& productState, SolutionType const& solution) { |
|
|
void MultiDimensionalRewardUnfolding<ValueType, SingleObjectiveMode>::setSolutionForCurrentEpoch(uint64_t const& productState, SolutionType const& solution) { |
|
|
// currentEpoch is read via .get() below, so it has to be set at this point.
STORM_LOG_ASSERT(currentEpoch, "Tried to set a solution for the current epoch, but no epoch was specified before."); |
|
|
STORM_LOG_ASSERT(currentEpoch, "Tried to set a solution for the current epoch, but no epoch was specified before."); |
|
|
|
|
|
// std::cout << "Setting solution for state " << productState << " in epoch " << epochToString(currentEpoch.get()) << std::endl;
|
|
|
// The entry is keyed by the pair (current epoch, product state).
solutions[std::make_pair(currentEpoch.get(), productState)] = solution; |
|
|
solutions[std::make_pair(currentEpoch.get(), productState)] = solution; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
/*!
 * Stores the given solution for the given product state in the current epoch (moving overload).
 *
 * @param productState the product state that keys the entry together with the current epoch.
 * @param solution the solution that is move-assigned into the `solutions` entry.
 *
 * @note Asserts that `currentEpoch` has been set before this is called.
 */
// NOTE(review): every statement below appears twice in this view; this looks like residue of a side-by-side diff rather than intended code -- confirm against the real file.
template<typename ValueType, bool SingleObjectiveMode> |
|
|
template<typename ValueType, bool SingleObjectiveMode> |
|
|
void MultiDimensionalRewardUnfolding<ValueType, SingleObjectiveMode>::setSolutionForCurrentEpoch(uint64_t const& productState, SolutionType&& solution) { |
|
|
void MultiDimensionalRewardUnfolding<ValueType, SingleObjectiveMode>::setSolutionForCurrentEpoch(uint64_t const& productState, SolutionType&& solution) { |
|
|
// currentEpoch is read via .get() below, so it has to be set at this point.
STORM_LOG_ASSERT(currentEpoch, "Tried to set a solution for the current epoch, but no epoch was specified before."); |
|
|
STORM_LOG_ASSERT(currentEpoch, "Tried to set a solution for the current epoch, but no epoch was specified before."); |
|
|
|
|
|
// std::cout << "Setting solution for state " << productState << " in epoch " << epochToString(currentEpoch.get()) << std::endl;
|
|
|
// The entry is keyed by the pair (current epoch, product state); the argument is moved from.
solutions[std::make_pair(currentEpoch.get(), productState)] = std::move(solution); |
|
|
solutions[std::make_pair(currentEpoch.get(), productState)] = std::move(solution); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
@ -595,6 +645,7 @@ namespace storm { |
|
|
//std::cout << "Getting solution for epoch " << epochToString(epoch) << " and state " << productState << std::endl;
|
|
|
//std::cout << "Getting solution for epoch " << epochToString(epoch) << " and state " << productState << std::endl;
|
|
|
auto solutionIt = solutions.find(std::make_pair(epoch, productState)); |
|
|
auto solutionIt = solutions.find(std::make_pair(epoch, productState)); |
|
|
STORM_LOG_ASSERT(solutionIt != solutions.end(), "Requested unexisting solution for epoch " << epochToString(epoch) << "."); |
|
|
STORM_LOG_ASSERT(solutionIt != solutions.end(), "Requested unexisting solution for epoch " << epochToString(epoch) << "."); |
|
|
|
|
|
//std::cout << "Retrieved solution for state " << productState << " in epoch " << epochToString(epoch) << std::endl;
|
|
|
swFindSol.stop(); |
|
|
swFindSol.stop(); |
|
|
return solutionIt->second; |
|
|
return solutionIt->second; |
|
|
} |
|
|
} |
|
|