Browse Source

optimized productInState Computation

tempestpy_adaptions
TimQu 7 years ago
parent
commit
d3e50b8769
  1. 95
      src/storm/modelchecker/multiobjective/rewardbounded/MultiDimensionalRewardUnfolding.cpp
  2. 2
      src/storm/modelchecker/multiobjective/rewardbounded/MultiDimensionalRewardUnfolding.h

95
src/storm/modelchecker/multiobjective/rewardbounded/MultiDimensionalRewardUnfolding.cpp

@ -55,6 +55,9 @@ namespace storm {
template<typename ValueType, bool SingleObjectiveMode>
void MultiDimensionalRewardUnfolding<ValueType, SingleObjectiveMode>::initialize() {
maxSolutionsStored = 0;
swInit.start();
STORM_LOG_ASSERT(!SingleObjectiveMode || (this->objectives.size() == 1), "Enabled single objective mode but there are multiple objectives.");
std::vector<Epoch> epochSteps;
@ -438,28 +441,57 @@ namespace storm {
storm::storage::BitVector MultiDimensionalRewardUnfolding<ValueType, SingleObjectiveMode>::computeProductInStatesForEpochClass(Epoch const& epoch) {
storm::storage::SparseMatrix<ValueType> const& productMatrix = memoryProduct.getProduct().getTransitionMatrix();
storm::storage::BitVector result = memoryProduct.getProduct().getInitialStates();
// Initialize the result. Initial states are only considered if the epoch contains no bottom dimension.
storm::storage::BitVector result;
bool epochHasBottomDimension = false;
for (uint64_t dim = 0; dim < dimensionCount; ++dim) {
if (isBottomDimension(epoch, dim)) {
epochHasBottomDimension = true;
break;
}
}
if (epochHasBottomDimension) {
result = storm::storage::BitVector(memoryProduct.getProduct().getNumberOfStates());
} else {
result = memoryProduct.getProduct().getInitialStates();
}
// Compute the set of objectives that can not be satisfied anymore in the current epoch
storm::storage::BitVector irrelevantObjectives(objectives.size(), false);
for (uint64_t objIndex = 0; objIndex < objectives.size(); ++objIndex) {
bool objIrrelevant = true;
for (auto const& dim : objectiveDimensions[objIndex]) {
if (!isBottomDimension(epoch, dim)) {
objIrrelevant = false;
}
}
if (objIrrelevant) {
irrelevantObjectives.set(objIndex, true);
}
}
// Perform DFS
storm::storage::BitVector reachableStates = result;
storm::storage::BitVector reachableStates = memoryProduct.getProduct().getInitialStates();
std::vector<uint_fast64_t> stack(reachableStates.begin(), reachableStates.end());
// std::cout << "Computing product In states for epoch " << epochToString(epoch) << std::endl;
while (!stack.empty()) {
uint64_t state = stack.back();
stack.pop_back();
for (uint64_t choice = productMatrix.getRowGroupIndices()[state]; choice < productMatrix.getRowGroupIndices()[state + 1]; ++choice) {
auto const& choiceStep = memoryProduct.getSteps()[choice];
if (!isZeroEpoch(choiceStep)) {
storm::storage::BitVector objectiveSet(objectives.size(), false);
// Compute the set of objectives that might or might not become irrelevant when the epoch is reached via the current choice
storm::storage::BitVector maybeIrrelevantObjectives(objectives.size(), false);
for (uint64_t dim = 0; dim < dimensionCount; ++dim) {
if (isBottomDimension(epoch, dim) && getDimensionOfEpoch(choiceStep, dim) > 0) {
objectiveSet.set(subObjectives[dim].second);
maybeIrrelevantObjectives.set(subObjectives[dim].second);
}
}
maybeIrrelevantObjectives &= ~irrelevantObjectives;
if (objectiveSet.empty()) {
// For optimization purposes, we treat the case that all objectives will be relevant seperately
if (maybeIrrelevantObjectives.empty() && irrelevantObjectives.empty()) {
for (auto const& choiceSuccessor : productMatrix.getRow(choice)) {
result.set(choiceSuccessor.getColumn(), true);
if (!reachableStates.get(choiceSuccessor.getColumn())) {
@ -468,29 +500,43 @@ namespace storm {
}
}
} else {
storm::storage::BitVector objectiveSubSet(objectiveSet.getNumberOfSetBits(), false);
// Enumerate all possible combinations of maybe relevant objectives
storm::storage::BitVector maybeObjSubset(maybeIrrelevantObjectives.getNumberOfSetBits(), false);
do {
for (auto const& choiceSuccessor : productMatrix.getRow(choice)) {
uint64_t modelState = memoryProduct.getModelState(choiceSuccessor.getColumn());
uint64_t memoryState = memoryProduct.getMemoryState(choiceSuccessor.getColumn());
storm::storage::BitVector memoryStatePrimeBv = memoryProduct.convertMemoryState(memoryState);
// Compute the successor memory state for the current objective-subset and transition
storm::storage::BitVector successorMemoryState = memoryProduct.convertMemoryState(memoryProduct.getMemoryState(choiceSuccessor.getColumn()));
// Unselect dimensions belonging to irrelevant objectives
for (auto const& irrelevantObjIndex : irrelevantObjectives) {
successorMemoryState &= ~objectiveDimensions[irrelevantObjIndex];
}
// Unselect objectives that are not in the current subset of maybe relevant objectives
// We can skip a subset if it selects an objective that is irrelevant anyway (according to the original successor memorystate).
bool skipThisSubSet = false;
uint64_t i = 0;
for (auto const& objIndex : objectiveSet) {
if (objectiveSubSet.get(i)) {
memoryStatePrimeBv &= ~objectiveDimensions[objIndex];
for (auto const& objIndex : maybeIrrelevantObjectives) {
if (maybeObjSubset.get(i)) {
if (successorMemoryState.isDisjointFrom(objectiveDimensions[objIndex])) {
skipThisSubSet = true;
break;
} else {
successorMemoryState &= ~objectiveDimensions[objIndex];
}
}
++i;
}
uint64_t successorState = memoryProduct.getProductState(modelState, memoryProduct.convertMemoryState(memoryStatePrimeBv));
result.set(successorState, true);
if (!reachableStates.get(successorState)) {
reachableStates.set(successorState);
stack.push_back(successorState);
if (!skipThisSubSet) {
uint64_t successorState = memoryProduct.getProductState(memoryProduct.getModelState(choiceSuccessor.getColumn()), memoryProduct.convertMemoryState(successorMemoryState));
result.set(successorState, true);
if (!reachableStates.get(successorState)) {
reachableStates.set(successorState);
stack.push_back(successorState);
}
}
}
objectiveSubSet.increment();
} while (!objectiveSubSet.empty());
maybeObjSubset.increment();
} while (!maybeObjSubset.empty());
}
} else {
for (auto const& choiceSuccessor : productMatrix.getRow(choice)) {
@ -502,6 +548,7 @@ namespace storm {
}
}
}
return result;
}
@ -573,6 +620,7 @@ namespace storm {
++solIt;
}
maxSolutionsStored = std::max((uint64_t) solutions.size(), maxSolutionsStored);
swInsertSol.stop();
}
@ -580,12 +628,14 @@ namespace storm {
template<typename ValueType, bool SingleObjectiveMode>
void MultiDimensionalRewardUnfolding<ValueType, SingleObjectiveMode>::setSolutionForCurrentEpoch(uint64_t const& productState, SolutionType const& solution) {
STORM_LOG_ASSERT(currentEpoch, "Tried to set a solution for the current epoch, but no epoch was specified before.");
// std::cout << "Setting solution for state " << productState << " in epoch " << epochToString(currentEpoch.get()) << std::endl;
solutions[std::make_pair(currentEpoch.get(), productState)] = solution;
}
template<typename ValueType, bool SingleObjectiveMode>
void MultiDimensionalRewardUnfolding<ValueType, SingleObjectiveMode>::setSolutionForCurrentEpoch(uint64_t const& productState, SolutionType&& solution) {
STORM_LOG_ASSERT(currentEpoch, "Tried to set a solution for the current epoch, but no epoch was specified before.");
// std::cout << "Setting solution for state " << productState << " in epoch " << epochToString(currentEpoch.get()) << std::endl;
solutions[std::make_pair(currentEpoch.get(), productState)] = std::move(solution);
}
@ -595,6 +645,7 @@ namespace storm {
//std::cout << "Getting solution for epoch " << epochToString(epoch) << " and state " << productState << std::endl;
auto solutionIt = solutions.find(std::make_pair(epoch, productState));
STORM_LOG_ASSERT(solutionIt != solutions.end(), "Requested unexisting solution for epoch " << epochToString(epoch) << ".");
//std::cout << "Retrieved solution for state " << productState << " in epoch " << epochToString(epoch) << std::endl;
swFindSol.stop();
return solutionIt->second;
}

2
src/storm/modelchecker/multiobjective/rewardbounded/MultiDimensionalRewardUnfolding.h

@ -58,6 +58,7 @@ namespace storm {
std::cout << " aux4StopWatch: " << swAux4 << " seconds." << std::endl;
std::cout << "---------------------------------------------" << std::endl;
std::cout << " Product size: " << memoryProduct.getProduct().getNumberOfStates() << std::endl;
std::cout << "maxSolutionsStored: " << maxSolutionsStored << std::endl;
std::cout << " Epoch model sizes: ";
for (auto const& i : epochModelSizes) {
std::cout << i << " ";
@ -189,6 +190,7 @@ namespace storm {
storm::utility::Stopwatch swInit, swFindSol, swInsertSol, swSetEpoch, swSetEpochClass, swAux1, swAux2, swAux3, swAux4;
std::vector<uint64_t> epochModelSizes;
uint64_t maxSolutionsStored;
};
}
}
Loading…
Cancel
Save