diff --git a/src/storm/models/sparse/ItemLabeling.cpp b/src/storm/models/sparse/ItemLabeling.cpp index b310b40e3..9101bfa77 100644 --- a/src/storm/models/sparse/ItemLabeling.cpp +++ b/src/storm/models/sparse/ItemLabeling.cpp @@ -64,6 +64,25 @@ namespace storm { addLabel(label, storage::BitVector(itemCount)); } + void ItemLabeling::removeLabel(std::string const& label) { + auto labelIt = nameToLabelingIndexMap.find(label); + STORM_LOG_THROW(labelIt != nameToLabelingIndexMap.end(), storm::exceptions::InvalidArgumentException, "Label '" << label << "' does not exist."); + uint64_t labelIndex = labelIt->second; + // Erase entry in map + nameToLabelingIndexMap.erase(labelIt); + // Erase label by 'swap and pop' + std::iter_swap(labelings.begin() + labelIndex, labelings.end() - 1); + labelings.pop_back(); + + // Update index of labeling we swapped from the end + for (auto& it: nameToLabelingIndexMap) { + if (it.second == labelings.size()) { + it.second = labelIndex; + break; + } + } + } + void ItemLabeling::join(ItemLabeling const& other) { STORM_LOG_THROW(this->itemCount == other.itemCount, storm::exceptions::InvalidArgumentException, "The item count of the two labelings does not match: " << this->itemCount << " vs. " << other.itemCount << "."); for (auto const& label : other.getLabels()) { @@ -112,11 +131,17 @@ namespace storm { } void ItemLabeling::addLabelToItem(std::string const& label, uint64_t item) { - STORM_LOG_THROW(this->containsLabel(label), storm::exceptions::OutOfRangeException, "Label '" << label << "' unknown."); + STORM_LOG_THROW(this->containsLabel(label), storm::exceptions::InvalidArgumentException, "Label '" << label << "' unknown."); STORM_LOG_THROW(item < itemCount, storm::exceptions::OutOfRangeException, "Item index out of range."); this->labelings[nameToLabelingIndexMap.at(label)].set(item, true); } + void ItemLabeling::removeLabelFromItem(std::string const& label, uint64_t item) { + STORM_LOG_THROW(item < itemCount, storm::exceptions::OutOfRangeException, "Item index out of range."); + STORM_LOG_THROW(this->getItemHasLabel(label, item),storm::exceptions::InvalidArgumentException, "Item " << item << " does not have label '" << label << "'."); + this->labelings[nameToLabelingIndexMap.at(label)].set(item, false); + } + bool ItemLabeling::getItemHasLabel(std::string const& label, uint64_t item) const { STORM_LOG_THROW(this->containsLabel(label), storm::exceptions::InvalidArgumentException, "The label '" << label << "' is invalid for the labeling of the model."); return this->labelings[nameToLabelingIndexMap.at(label)].get(item); diff --git a/src/storm/models/sparse/ItemLabeling.h b/src/storm/models/sparse/ItemLabeling.h index 04d9950d2..6a20881d4 100644 --- a/src/storm/models/sparse/ItemLabeling.h +++ b/src/storm/models/sparse/ItemLabeling.h @@ -51,7 +51,6 @@ namespace storm { */ bool operator==(ItemLabeling const& other) const; - /*! * Adds a new label to the labelings. Initially, no item is labeled with this label. * @@ -59,6 +58,13 @@ namespace storm { */ void addLabel(std::string const& label); + /*! + * Removes a label from the labelings. + * + * @param label The name of the label to remove. + */ + void removeLabel(std::string const& label); + /*! * Retrieves the set of labels contained in this labeling. * @@ -191,6 +197,14 @@ namespace storm { */ virtual void addLabelToItem(std::string const& label, uint64_t item); + /*! + * Removes a label from a given item. + * + * @param label The name of the label to remove. + * @param item The index of the item. + */ + virtual void removeLabelFromItem(std::string const& label, uint64_t item); + // The number of items for which this object can hold the labeling. diff --git a/src/storm/models/sparse/StateLabeling.cpp b/src/storm/models/sparse/StateLabeling.cpp index bcb00b7a9..1f0f88b8d 100644 --- a/src/storm/models/sparse/StateLabeling.cpp +++ b/src/storm/models/sparse/StateLabeling.cpp @@ -52,6 +52,10 @@ namespace storm { void StateLabeling::addLabelToState(std::string const& label, storm::storage::sparse::state_type state) { ItemLabeling::addLabelToItem(label, state); } + + void StateLabeling::removeLabelFromState(std::string const& label, storm::storage::sparse::state_type state) { + ItemLabeling::removeLabelFromItem(label, state); + } bool StateLabeling::getStateHasLabel(std::string const& label, storm::storage::sparse::state_type state) const { return ItemLabeling::getItemHasLabel(label, state); diff --git a/src/storm/models/sparse/StateLabeling.h b/src/storm/models/sparse/StateLabeling.h index d781fb4f4..ad431d0e2 100644 --- a/src/storm/models/sparse/StateLabeling.h +++ b/src/storm/models/sparse/StateLabeling.h @@ -65,6 +65,14 @@ namespace storm { * @param state The index of the state to label. */ void addLabelToState(std::string const& label, storm::storage::sparse::state_type state); + + /*! + * Removes a label from a given state. + * + * @param label The name of the label to remove. + * @param state The index of the state. + */ + void removeLabelFromState(std::string const& label, storm::storage::sparse::state_type state); /*! * Checks whether a given state is labeled with the given label. diff --git a/src/test/storm/model/StateLabelingTest.cpp b/src/test/storm/model/StateLabelingTest.cpp new file mode 100644 index 000000000..f968480f8 --- /dev/null +++ b/src/test/storm/model/StateLabelingTest.cpp @@ -0,0 +1,37 @@ +#include "test/storm_gtest.h" +#include "storm-config.h" + + +#include "storm/models/sparse/StateLabeling.h" + + + +TEST(StateLabelingTest, RemoveLabel) { + + storm::models::sparse::StateLabeling labeling(10); + EXPECT_EQ(10ul, labeling.getNumberOfItems()); + EXPECT_EQ(0ul, labeling.getNumberOfLabels()); + + storm::storage::BitVector statesTest1 = storm::storage::BitVector(10, {1, 4, 6, 7}); + labeling.addLabel("test1", statesTest1); + EXPECT_TRUE(labeling.containsLabel("test1")); + EXPECT_FALSE(labeling.containsLabel("abc")); + + storm::storage::BitVector statesTest2 = storm::storage::BitVector(10, {2, 6, 7, 8, 9}); + labeling.addLabel("test2", statesTest2); + + EXPECT_FALSE(labeling.getStateHasLabel("test2", 5)); + labeling.addLabelToState("test2", 5); + EXPECT_TRUE(labeling.getStateHasLabel("test2", 5)); + + EXPECT_TRUE(labeling.getStateHasLabel("test1", 4)); + labeling.removeLabelFromState("test1", 4); + EXPECT_FALSE(labeling.getStateHasLabel("test1", 4)); + + EXPECT_EQ(2ul, labeling.getNumberOfLabels()); + EXPECT_TRUE(labeling.getStateHasLabel("test1", 6)); + labeling.removeLabel("test1"); + EXPECT_FALSE(labeling.containsLabel("test1")); + EXPECT_EQ(1ul, labeling.getNumberOfLabels()); + EXPECT_TRUE(labeling.getStateHasLabel("test2", 5)); +}