diff --git a/src/modelchecker/multiobjective/helper/SparseMultiObjectivePreprocessor.cpp b/src/modelchecker/multiobjective/helper/SparseMultiObjectivePreprocessor.cpp index 3035ad172..5972f5f0b 100644 --- a/src/modelchecker/multiobjective/helper/SparseMultiObjectivePreprocessor.cpp +++ b/src/modelchecker/multiobjective/helper/SparseMultiObjectivePreprocessor.cpp @@ -28,7 +28,7 @@ namespace storm { while(data.preprocessedModel.hasLabel(data.prob1StatesLabel)) { data.prob1StatesLabel = "_" + data.prob1StatesLabel; } - data.preprocessedModel.getStateLabeling().addLabel(data.prob1StatesLabel); + data.preprocessedModel.getStateLabeling().addLabel(data.prob1StatesLabel, storm::storage::BitVector(data.preprocessedModel.getNumberOfStates(), true)); //Invoke preprocessing on the individual objectives for(auto const& subFormula : originalFormula.getSubFormulas()){ @@ -181,20 +181,18 @@ namespace storm { storm::storage::BitVector subsystemStates; storm::storage::BitVector noIncomingTransitionFromFirstCopyStates; if(isProb0Formula) { - subsystemStates = storm::utility::graph::performProb0E(data.preprocessedModel, data.preprocessedModel.getBackwardTransitions(), duplicatorResult.firstCopy, newPsiStates); - subsystemStates |= duplicatorResult.secondCopy; + storm::storage::BitVector statesReachableInSecondCopy = storm::utility::graph::getReachableStates(data.preprocessedModel.getTransitionMatrix(), duplicatorResult.gateStates & (~newPsiStates), duplicatorResult.secondCopy, storm::storage::BitVector(data.preprocessedModel.getNumberOfStates(), false)); + subsystemStates = statesReachableInSecondCopy | storm::utility::graph::performProb0E(data.preprocessedModel, data.preprocessedModel.getBackwardTransitions(), duplicatorResult.firstCopy, newPsiStates); noIncomingTransitionFromFirstCopyStates = newPsiStates; } else { - for(auto psiState : newPsiStates) { - data.preprocessedModel.getStateLabeling().addLabelToState(data.prob1StatesLabel, psiState); - } - subsystemStates = storm::utility::graph::performProb1E(data.preprocessedModel, data.preprocessedModel.getBackwardTransitions(), duplicatorResult.firstCopy, newPsiStates); - subsystemStates |= duplicatorResult.secondCopy; - noIncomingTransitionFromFirstCopyStates = duplicatorResult.secondCopy & (~newPsiStates); + storm::storage::BitVector statesReachableInSecondCopy = storm::utility::graph::getReachableStates(data.preprocessedModel.getTransitionMatrix(), newPsiStates, duplicatorResult.secondCopy, storm::storage::BitVector(data.preprocessedModel.getNumberOfStates(), false)); + data.preprocessedModel.getStateLabeling().setStates(data.prob1StatesLabel, data.preprocessedModel.getStateLabeling().getStates(data.prob1StatesLabel) & statesReachableInSecondCopy); + subsystemStates = statesReachableInSecondCopy | storm::utility::graph::performProb1E(data.preprocessedModel, data.preprocessedModel.getBackwardTransitions(), duplicatorResult.firstCopy, newPsiStates); + noIncomingTransitionFromFirstCopyStates = duplicatorResult.gateStates & (~newPsiStates); } storm::storage::BitVector consideredActions(data.preprocessedModel.getTransitionMatrix().getRowCount(), true); for(auto state : duplicatorResult.firstCopy) { - for(uint_fast64_t action = data.preprocessedModel.getTransitionMatrix().getRowGroupIndices()[state]; action < data.preprocessedModel.getTransitionMatrix().getRowGroupIndices()[state +1] ; ++action) { + for(uint_fast64_t action = data.preprocessedModel.getTransitionMatrix().getRowGroupIndices()[state]; action < data.preprocessedModel.getTransitionMatrix().getRowGroupIndices()[state +1]; ++action) { for(auto const& entry : data.preprocessedModel.getTransitionMatrix().getRow(action)) { if(noIncomingTransitionFromFirstCopyStates.get(entry.getColumn())) { consideredActions.set(action, false); @@ -203,7 +201,6 @@ namespace storm { } } } - subsystemStates = storm::utility::graph::getReachableStates(data.preprocessedModel.getTransitionMatrix(), data.preprocessedModel.getInitialStates(), subsystemStates, storm::storage::BitVector(subsystemStates.size(), false)); auto subsystemBuilderResult = storm::transformer::SubsystemBuilder::transform(data.preprocessedModel, subsystemStates, consideredActions); updatePreprocessedModel(data, *subsystemBuilderResult.model, subsystemBuilderResult.newToOldStateIndexMapping); data.objectivesSolvedInPreprocessing.set(data.objectives.size()); @@ -276,11 +273,8 @@ namespace storm { // States of the first copy from which the second copy is not reachable with prob 1 under any scheduler can // be removed as the expected reward is not defined for these states. // We also need to enforce that the second copy will be reached eventually with prob 1. - for(auto targetState : duplicatorResult.gateStates) { - data.preprocessedModel.getStateLabeling().addLabelToState(data.prob1StatesLabel, targetState); - } - storm::storage::BitVector subsystemStates = storm::utility::graph::performProb1E(data.preprocessedModel, data.preprocessedModel.getBackwardTransitions(), duplicatorResult.firstCopy, duplicatorResult.gateStates); - subsystemStates |= duplicatorResult.secondCopy; + data.preprocessedModel.getStateLabeling().setStates(data.prob1StatesLabel, data.preprocessedModel.getStateLabeling().getStates(data.prob1StatesLabel) & duplicatorResult.secondCopy); + storm::storage::BitVector subsystemStates = duplicatorResult.secondCopy | storm::utility::graph::performProb1E(data.preprocessedModel, data.preprocessedModel.getBackwardTransitions(), duplicatorResult.firstCopy, duplicatorResult.gateStates); if(!subsystemStates.full()) { auto subsystemBuilderResult = storm::transformer::SubsystemBuilder::transform(data.preprocessedModel, subsystemStates, storm::storage::BitVector(data.preprocessedModel.getTransitionMatrix().getRowCount(), true)); updatePreprocessedModel(data, *subsystemBuilderResult.model, subsystemBuilderResult.newToOldStateIndexMapping); diff --git a/src/models/sparse/StateLabeling.cpp b/src/models/sparse/StateLabeling.cpp index 7957e6a99..86880ee57 100644 --- a/src/models/sparse/StateLabeling.cpp +++ b/src/models/sparse/StateLabeling.cpp @@ -96,6 +96,18 @@ namespace storm { STORM_LOG_THROW(this->containsLabel(label), storm::exceptions::InvalidArgumentException, "The label " << label << " is invalid for the labeling of the model."); return this->labelings[nameToLabelingIndexMap.at(label)]; } + + void StateLabeling::setStates(std::string const& label, storage::BitVector const& labeling) { + STORM_LOG_THROW(this->containsLabel(label), storm::exceptions::InvalidArgumentException, "The label " << label << " is invalid for the labeling of the model."); + STORM_LOG_THROW(labeling.size() == stateCount, storm::exceptions::InvalidArgumentException, "Labeling vector has invalid size."); + this->labelings[nameToLabelingIndexMap.at(label)] = labeling; + } + + void StateLabeling::setStates(std::string const& label, storage::BitVector&& labeling) { + STORM_LOG_THROW(this->containsLabel(label), storm::exceptions::InvalidArgumentException, "The label " << label << " is invalid for the labeling of the model."); + STORM_LOG_THROW(labeling.size() == stateCount, storm::exceptions::InvalidArgumentException, "Labeling vector has invalid size."); + this->labelings[nameToLabelingIndexMap.at(label)] = labeling; + } std::size_t StateLabeling::getSizeInBytes() const { std::size_t result = sizeof(*this); @@ -113,4 +125,4 @@ namespace storm { } } } -} \ No newline at end of file +} diff --git a/src/models/sparse/StateLabeling.h b/src/models/sparse/StateLabeling.h index c7b3411a7..678b80bf6 100644 --- a/src/models/sparse/StateLabeling.h +++ b/src/models/sparse/StateLabeling.h @@ -130,6 +130,22 @@ namespace storm { */ storm::storage::BitVector const& getStates(std::string const& label) const; + /*! + * Sets the labeling of states associated with the given label. + * + * @param label The name of the label. + * @param labeling A bit vector that represents the set of states that will get this label. + */ + void setStates(std::string const& label, storage::BitVector const& labeling); + + /*! + * Sets the labeling of states associated with the given label. + * + * @param label The name of the label. + * @param labeling A bit vector that represents the set of states that will get this label. + */ + void setStates(std::string const& label, storage::BitVector&& labeling); + /*! * Returns (an approximation of) the size of the labeling measured in bytes. *