diff --git a/src/storm-pars/analysis/Lattice.cpp b/src/storm-pars/analysis/Lattice.cpp index e015358cb..a0dc25d47 100644 --- a/src/storm-pars/analysis/Lattice.cpp +++ b/src/storm-pars/analysis/Lattice.cpp @@ -12,10 +12,13 @@ namespace storm { storm::storage::BitVector bottomStates, uint_fast64_t numberOfStates) { top = new Node(); top->states = topStates; + setStatesAbove(top, storm::storage::BitVector(numberOfStates), false); + setStatesBelow(top, bottomStates, false); + bottom = new Node(); bottom->states = bottomStates; - top->below.insert(bottom); - bottom->above.insert(top); + setStatesBelow(bottom, storm::storage::BitVector(numberOfStates), false); + setStatesAbove(bottom, topStates, false); nodes = std::vector(numberOfStates); for (auto i = topStates.getNextSetIndex(0); i < numberOfStates; i = topStates.getNextSetIndex(i+1)) { @@ -36,13 +39,19 @@ namespace storm { Lattice::Lattice(Lattice* lattice) { top = new Node(); top->states = storm::storage::BitVector(lattice->getTop()->states); + setStatesAbove(top, lattice->getTop()->statesAbove, false); + setStatesBelow(top, lattice->getTop()->statesBelow, false); + bottom = new Node(); bottom->states = storm::storage::BitVector(lattice->getBottom()->states); + setStatesAbove(bottom, lattice->getBottom()->statesAbove, false); + setStatesBelow(bottom, lattice->getBottom()->statesBelow, false); + numberOfStates = top->states.size(); nodes = std::vector(numberOfStates); addedStates = storm::storage::BitVector(numberOfStates); - addedStates.operator|=(top->states); - addedStates.operator|=(bottom->states); + addedStates |= (top->states); + addedStates |= (bottom->states); for (auto i = top->states.getNextSetIndex(0); i < numberOfStates; i = top->states.getNextSetIndex(i+1)) { nodes.at(i) = top; @@ -59,6 +68,8 @@ namespace storm { if (oldNode != nullptr) { Node *newNode = new Node(); newNode->states = storm::storage::BitVector(oldNode->states); + setStatesAbove(newNode, oldNode->statesAbove, false); + setStatesBelow(newNode, oldNode->statesBelow, false); for (auto i = newNode->states.getNextSetIndex(0); i < numberOfStates; i = newNode->states.getNextSetIndex(i + 1)) { addedStates.set(i); @@ -67,17 +78,7 @@ namespace storm { } } - // Create transitions - for (auto itr = oldNodes.begin(); itr != oldNodes.end(); ++itr) { - Node* oldNode = (*itr); - if (oldNode != nullptr) { - auto state = (*itr)->states.getNextSetIndex(0); - for (auto itr2 = (*itr)->below.begin(); itr2 != (*itr)->below.end(); ++itr2) { - auto stateBelow = (*itr2)->states.getNextSetIndex(0); - addRelationNodes(getNode((state)), getNode((stateBelow))); - } - } - } + assert(addedStates == lattice->getAddedStates()); } void Lattice::addBetween(uint_fast64_t state, Node *above, Node *below) { @@ -87,12 +88,28 @@ namespace storm { Node *newNode = new Node(); newNode->states = storm::storage::BitVector(numberOfStates); newNode->states.set(state); - newNode->above = std::set({above}); - newNode->below = std::set({below}); - below->above.erase(above); - above->below.erase(below); - (below->above).insert(newNode); - above->below.insert(newNode); + + setStatesAbove(newNode, above->statesAbove | above->states, false); + setStatesBelow(newNode, below->statesBelow | below->states, false); + + newNode->statesAbove = above->statesAbove | above->states; + newNode->statesBelow = below->statesBelow | below->states; + newNode->recentlyAddedAbove = storm::storage::BitVector(above->statesAbove | above->states); + newNode->recentlyAddedBelow = storm::storage::BitVector(below->statesBelow | below->states); + setStatesBelow(above, state); + setStatesAbove(below, state); + + auto nodesBelow = getNodesBelow(below); + for (auto itr = nodesBelow.begin(); itr != nodesBelow.end(); ++itr) { + assert((*itr)->statesAbove.size() == numberOfStates); + setStatesAbove((*itr), state); + } + + auto nodesAbove = getNodesAbove(above); + for (auto itr = nodesAbove.begin(); itr != nodesAbove.end(); ++itr) { + assert((*itr)->statesBelow.size() == numberOfStates); + setStatesBelow((*itr), state); + } nodes.at(state) = newNode; addedStates.set(state); } @@ -102,6 +119,14 @@ namespace storm { node->states.set(state); nodes.at(state) = node; addedStates.set(state); + auto nodesBelow = getNodesBelow(node); + for (auto itr = nodesBelow.begin(); itr != nodesBelow.end(); ++itr) { + setStatesAbove((*itr), state); + } + auto nodesAbove = getNodesAbove(node); + for (auto itr = nodesAbove.begin(); nodesAbove.size() != 0 &&itr != nodesAbove.end(); ++itr) { + setStatesBelow((*itr), state); + } } void Lattice::add(uint_fast64_t state) { @@ -109,13 +134,16 @@ namespace storm { } void Lattice::addRelationNodes(storm::analysis::Lattice::Node *above, storm::analysis::Lattice::Node * below) { - assert(compare(above, below) == UNKNOWN || compare(above, below) == ABOVE); - above->below.insert(below); - below->above.insert(above); - } - - void Lattice::mergeNodes(storm::analysis::Lattice::Node *n1, storm::analysis::Lattice::Node * n2) { - // TODO + above->statesBelow |= below->states; + below->statesAbove |= above->states; + auto nodesBelow = getNodesBelow(below); + for (auto itr = nodesBelow.begin(); itr != nodesBelow.end(); ++itr) { + setStatesAbove((*itr), above->states, true); + } + auto nodesAbove = getNodesAbove(above); + for (auto itr = nodesAbove.begin() ; itr != nodesAbove.end(); ++itr) { + setStatesBelow((*itr), below->states, true); + } } int Lattice::compare(uint_fast64_t state1, uint_fast64_t state2) { @@ -128,12 +156,12 @@ namespace storm { return SAME; } - if (above(node1, node2, new std::set({}))) { - assert(!above(node2, node1, new std::set({}))); + if (above(node1, node2, std::make_shared>(std::set({})))) { + assert(!above(node2, node1, std::make_shared>(std::set({})))); return ABOVE; } - if (above(node2, node1, new std::set({}))) { + if (above(node2, node1, std::make_shared>(std::set({})))) { return BELOW; } } @@ -144,6 +172,26 @@ namespace storm { return nodes.at(stateNumber); } + std::set Lattice::getNodesAbove(Lattice::Node * node) { + if (node->recentlyAddedAbove.getNumberOfSetBits() != 0) { + for (auto i = node->recentlyAddedAbove.getNextSetIndex(0); i < node->recentlyAddedAbove.size(); i = node->recentlyAddedAbove.getNextSetIndex(i+1)) { + node->above.insert(getNode(i)); + } + node->recentlyAddedAbove.clear(); + } + return node->above; + } + + std::set Lattice::getNodesBelow(Lattice::Node * node) { + if (node->recentlyAddedBelow.getNumberOfSetBits() != 0) { + for (auto i = node->recentlyAddedBelow.getNextSetIndex(0); i < node->recentlyAddedBelow.size(); i = node->recentlyAddedBelow.getNextSetIndex(i+1)) { + node->below.insert(getNode(i)); + } + node->recentlyAddedBelow.clear(); + } + return node->above; + } + Lattice::Node *Lattice::getTop() { return top; } @@ -160,6 +208,8 @@ namespace storm { return addedStates; } + + void Lattice::toString(std::ostream &out) { std::vector printedNodes = std::vector({}); for (auto itr = nodes.begin(); itr != nodes.end(); ++itr) { @@ -180,6 +230,7 @@ namespace storm { out << " Address: " << node << "\n"; out << " Above: {"; + getNodesAbove(node); // such that it is updated for (auto itr2 = node->above.begin(); itr2 != node->above.end(); ++itr2) { Node *above = *itr2; index = above->states.getNextSetIndex(0); @@ -198,6 +249,8 @@ namespace storm { out << " Below: {"; + getNodesBelow(node); // To make sure it is updated + // TODO verbeteren for (auto itr2 = node->below.begin(); itr2 != node->below.end(); ++itr2) { Node *below = *itr2; out << "{"; @@ -256,15 +309,42 @@ namespace storm { out << "}" << std::endl; } - bool Lattice::above(Node *node1, Node *node2, std::set* seenNodes) { - bool result = !node1->below.empty() && std::find(node1->below.begin(), node1->below.end(), node2) != node1->below.end(); - for (auto itr = node1->below.begin(); !result && node1->below.end() != itr; ++itr) { - if (std::find(seenNodes->begin(), seenNodes->end(), (*itr)) == seenNodes->end()) { - seenNodes->insert(*itr); - result |= above(*itr, node2, seenNodes); - } + bool Lattice::above(Node *node1, Node *node2, std::shared_ptr> seenNodes) { + return node1->statesBelow.get(node2->states.getNextSetIndex(0)); + } + + void Lattice::setStatesAbove(storm::analysis::Lattice::Node *node, uint_fast64_t state) { + node->statesAbove.set(state); + node->recentlyAddedAbove.set(state); + } + + void Lattice::setStatesBelow(storm::analysis::Lattice::Node *node, uint_fast64_t state) { + node->statesBelow.set(state); + node->recentlyAddedBelow.set(state); + } + + void Lattice::setStatesAbove(storm::analysis::Lattice::Node *node, storm::storage::BitVector states, bool alreadyInitialized) { + if (alreadyInitialized) { + node->statesAbove |= states; + node->recentlyAddedAbove |= states; + } else { + node->statesAbove = storm::storage::BitVector(states); + node->recentlyAddedAbove = storm::storage::BitVector(states); } - return result; } + + void Lattice::setStatesBelow(storm::analysis::Lattice::Node *node, storm::storage::BitVector states, bool alreadyInitialized) { + if (alreadyInitialized) { + node->statesBelow |= states; + node->recentlyAddedBelow |= states; + } else { + node->statesBelow = storm::storage::BitVector(states); + node->recentlyAddedBelow = storm::storage::BitVector(states); + } + } + + + + } } diff --git a/src/storm-pars/analysis/Lattice.h b/src/storm-pars/analysis/Lattice.h index b92570e65..ce9cb7ce7 100644 --- a/src/storm-pars/analysis/Lattice.h +++ b/src/storm-pars/analysis/Lattice.h @@ -18,8 +18,12 @@ namespace storm { public: struct Node { storm::storage::BitVector states; - std::set above; - std::set below; + storm::storage::BitVector statesAbove; + storm::storage::BitVector statesBelow; + storm::storage::BitVector recentlyAddedAbove; + storm::storage::BitVector recentlyAddedBelow; + std::set above; + std::set below; }; /*! @@ -55,14 +59,12 @@ namespace storm { void add(uint_fast64_t state); /*! - * Adds a new relation between two nodes to the lattice - * @param above The node closest to the top Node of the Lattice. - * @param below The node closest to the bottom Node of the Lattice. - */ + * Adds a new relation between two nodes to the lattice + * @param above The node closest to the top Node of the Lattice. + * @param below The node closest to the bottom Node of the Lattice. + */ void addRelationNodes(storm::analysis::Lattice::Node *above, storm::analysis::Lattice::Node * below); - void mergeNodes(storm::analysis::Lattice::Node *n1, storm::analysis::Lattice::Node * n2); - /*! * Compares the level of the nodes of the states. * Behaviour unknown when one or more of the states doesnot occur at any Node in the Lattice. @@ -107,6 +109,9 @@ namespace storm { */ std::vector getNodes(); + std::set getNodesBelow(Node* node); + std::set getNodesAbove(Node* node); + /*! * Returns a BitVector in which all added states are set. * @@ -149,9 +154,17 @@ namespace storm { uint_fast64_t numberOfStates; - bool above(Node * node1, Node * node2, std::set* seenNodes); + bool above(Node * node1, Node * node2, std::shared_ptr> seenNodes); int compare(Node* node1, Node* node2); + + void setStatesAbove(Node* node, uint_fast64_t state); + + void setStatesBelow(Node* node, uint_fast64_t state); + + void setStatesAbove(storm::analysis::Lattice::Node *node, storm::storage::BitVector states, bool alreadyInitialized); + + void setStatesBelow(storm::analysis::Lattice::Node *node, storm::storage::BitVector states, bool alreadyInitialized); }; } } diff --git a/src/storm-pars/analysis/LatticeExtender.cpp b/src/storm-pars/analysis/LatticeExtender.cpp index 004418d57..89d4d6ec9 100644 --- a/src/storm-pars/analysis/LatticeExtender.cpp +++ b/src/storm-pars/analysis/LatticeExtender.cpp @@ -31,19 +31,6 @@ namespace storm { template LatticeExtender::LatticeExtender(std::shared_ptr> model) { this->model = model; - - initialMiddleStates = storm::storage::BitVector(model->getNumberOfStates()); - // Check if MC is acyclic - auto decomposition = storm::storage::StronglyConnectedComponentDecomposition(model->getTransitionMatrix(), false, false); - for (auto i = 0; i < decomposition.size(); ++i) { - auto scc = decomposition.getBlock(i); - if (scc.size() > 1) { - auto states = scc.getStates(); - // TODO: Smarter state picking - // Add one of the states of the scc - initialMiddleStates.set(*(states.begin())); - } - } } template @@ -91,6 +78,34 @@ namespace storm { } } + auto initialMiddleStates = storm::storage::BitVector(model->getNumberOfStates()); + // Check if MC is acyclic + auto decomposition = storm::storage::StronglyConnectedComponentDecomposition(model->getTransitionMatrix(), false, false); + for (auto i = 0; i < decomposition.size(); ++i) { + auto scc = decomposition.getBlock(i); + if (scc.size() > 1) { + auto states = scc.getStates(); + // check if the state has already one successor in bottom of top, in that case pick it + bool found = false; + for (auto stateItr = states.begin(); !found && stateItr < states.end(); ++stateItr) { + auto successors = stateMap[*stateItr]; + if (successors.getNumberOfSetBits() == 2) { + auto succ1 = successors.getNextSetIndex(0); + auto succ2 = successors.getNextSetIndex(succ1 + 1); + auto intersection = bottomStates | topStates; + if ((intersection[succ1] && ! intersection[succ2]) + || (intersection[succ2] && !intersection[succ1])) { + initialMiddleStates.set(*stateItr); + found = true; + } else if (intersection[succ1] && intersection[succ2]) { + found = true; + } + } + + } + } + } + // Create the Lattice storm::analysis::Lattice *lattice = new storm::analysis::Lattice(topStates, bottomStates, numberOfStates); for (auto state = initialMiddleStates.getNextSetIndex(0); state != numberOfStates; state = initialMiddleStates.getNextSetIndex(state + 1)) { @@ -127,7 +142,8 @@ namespace storm { storm::analysis::Lattice::Node *n2 = lattice->getNode(val2); if (n1 != nullptr && n2 != nullptr) { - lattice->mergeNodes(n1, n2); + assert(false); +// lattice->mergeNodes(n1, n2); } else if (n1 != nullptr) { lattice->addToNode(val2, n1); } else if (n2 != nullptr) { diff --git a/src/storm-pars/analysis/LatticeExtender.h b/src/storm-pars/analysis/LatticeExtender.h index 8737c8f9e..01b0554a8 100644 --- a/src/storm-pars/analysis/LatticeExtender.h +++ b/src/storm-pars/analysis/LatticeExtender.h @@ -54,7 +54,7 @@ namespace storm { std::map stateMap; - storm::storage::BitVector initialMiddleStates; +// storm::storage::BitVector initialMiddleStates; }; } }