diff --git a/src/utility/shortestPaths.cpp b/src/utility/shortestPaths.cpp index 15039a4c3..ffd3730f1 100644 --- a/src/utility/shortestPaths.cpp +++ b/src/utility/shortestPaths.cpp @@ -7,15 +7,12 @@ namespace storm { namespace utility { namespace ksp { template - ShortestPathsGenerator::ShortestPathsGenerator(std::shared_ptr> model, - state_list_t const& targets) : transitionMatrix(model->getTransitionMatrix()), - numStates(model->getNumberOfStates() + 1), // one more for meta-target - metaTarget(model->getNumberOfStates()), // first unused state number - initialStates(model->getInitialStates()), - targets(targets) { - for (state_t target : targets) { - targetProbMap.emplace(target, one()); - } + ShortestPathsGenerator::ShortestPathsGenerator(storage::SparseMatrix transitionMatrix, std::unordered_map targetProbMap, BitVector initialStates) : + transitionMatrix(transitionMatrix), + numStates(transitionMatrix.getColumnCount() + 1), // one more for meta-target + metaTarget(transitionMatrix.getColumnCount()), // first unused state index + initialStates(initialStates), + targetProbMap(targetProbMap) { computePredecessors(); @@ -30,6 +27,8 @@ namespace storm { candidatePaths.resize(numStates); } + // TODO: probTargetVector [!] to probTargetMap ctor + // extracts the relevant info from the model and delegates to ctor above template ShortestPathsGenerator::ShortestPathsGenerator(std::shared_ptr> model, BitVector const& targetBV) @@ -108,14 +107,16 @@ namespace storm { for (auto const& transition : transitionMatrix.getRowGroup(i)) { // to avoid non-minimal paths, the target states are // *not* predecessors of any state but the meta-target - if (std::find(targets.begin(), targets.end(), i) == targets.end()) { + if (!isTargetState(i)) { graphPredecessors[transition.getColumn()].push_back(i); } } } // meta-target has exactly the target states as predecessors - graphPredecessors[metaTarget] = targets; + for (auto targetProbPair : targetProbMap) { // FIXME + graphPredecessors[metaTarget].push_back(targetProbPair.first); + } } template @@ -141,7 +142,7 @@ namespace storm { state_t currentNode = (*dijkstraQueue.begin()).second; dijkstraQueue.erase(dijkstraQueue.begin()); - if (targetProbMap.count(currentNode) == 0) { + if (!isTargetState(currentNode)) { // non-target node, treated normally for (auto const& transition : transitionMatrix.getRowGroup(currentNode)) { state_t otherNode = transition.getColumn(); @@ -157,7 +158,7 @@ namespace storm { } else { // target node has only "virtual edge" (with prob 1) to meta-target // no multiplication necessary - T alternateDistance = shortestPathDistances[currentNode]; + T alternateDistance = shortestPathDistances[currentNode]; // FIXME if (alternateDistance > shortestPathDistances[metaTarget]) { shortestPathDistances[metaTarget] = alternateDistance; shortestPathPredecessors[metaTarget] = boost::optional(currentNode); @@ -231,7 +232,7 @@ namespace storm { } else { // edge must be "virtual edge" from target state to meta-target assert(targetProbMap.count(tailNode) == 1); - return utility::one(); + return one(); } } diff --git a/src/utility/shortestPaths.h b/src/utility/shortestPaths.h index 402fb31f6..400485997 100644 --- a/src/utility/shortestPaths.h +++ b/src/utility/shortestPaths.h @@ -69,6 +69,7 @@ namespace storm { // vector from SamplingModel); // in this case separately specifying a target makes no sense //ShortestPathsGenerator(storm::storage::SparseMatrix maybeTransitionMatrix, std::vector targetProbVector); + ShortestPathsGenerator(storm::storage::SparseMatrix maybeTransitionMatrix, std::unordered_map targetProbMap, BitVector initialStates); inline ~ShortestPathsGenerator(){} @@ -99,7 +100,6 @@ namespace storm { storage::SparseMatrix transitionMatrix; state_t numStates; // includes meta-target, i.e. states in model + 1 state_t metaTarget; - state_list_t targets; BitVector initialStates; std::unordered_map targetProbMap; @@ -167,12 +167,19 @@ namespace storm { return find(initialStates.begin(), initialStates.end(), node) != initialStates.end(); } - inline state_list_t bitvectorToList(storage::BitVector const& bv) const { - state_list_t list; - for (state_t state : bv) { - list.push_back(state); + inline bool isTargetState(state_t node) const { + return targetProbMap.count(node) == 1; + } + + /** + * Returns a map where each state of the input BitVector is mapped to 1 (`one`). + */ + inline std::unordered_map allProbOneMap(BitVector bitVector) const& { + std::unordered_map stateProbMap; + for (state_t node : bitVector) { + stateProbMap.emplace(node, one()); } - return list; + return stateProbMap; } // ----------------------- };