diff --git a/CMakeLists.txt b/CMakeLists.txt index 4229059b3..858e42bf0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -423,6 +423,14 @@ add_subdirectory("${PROJECT_SOURCE_DIR}/resources/3rdparty/glpk-4.53") include_directories("${PROJECT_SOURCE_DIR}/resources/3rdparty/glpk-4.53/src") target_link_libraries(storm "glpk") +############################################################# +## +## ExprTk +## +############################################################# +message (STATUS "StoRM - Including ExprTk") +include_directories("${PROJECT_SOURCE_DIR}/resources/3rdparty/exprtk") + ############################################################# ## ## Z3 (optional) diff --git a/src/adapters/ExplicitModelAdapter.h b/src/adapters/ExplicitModelAdapter.h index a6e517038..6b597260e 100644 --- a/src/adapters/ExplicitModelAdapter.h +++ b/src/adapters/ExplicitModelAdapter.h @@ -233,8 +233,8 @@ namespace storm { * @param actionIndex The index of the action label to select. * @return A list of lists of active commands or nothing. */ - static boost::optional>>> getActiveCommandsByActionIndex(storm::prism::Program const& program, StateType const* state, uint_fast64_t const& actionIndex) { - boost::optional>>> result((std::vector>>())); + static boost::optional>>> getActiveCommandsByActionIndex(storm::prism::Program const& program, StateType const* state, uint_fast64_t const& actionIndex) { + boost::optional>>> result((std::vector>>())); // Iterate over all modules. for (uint_fast64_t i = 0; i < program.getNumberOfModules(); ++i) { @@ -250,10 +250,10 @@ namespace storm { // If the module contains the action, but there is no command in the module that is labeled with // this action, we don't have any feasible command combinations. if (commandIndices.empty()) { - return boost::optional>>>(); + return boost::optional>>>(); } - std::list> commands; + std::vector> commands; // Look up commands by their indices and add them if the guard evaluates to true in the given state. for (uint_fast64_t commandIndex : commandIndices) { @@ -266,7 +266,7 @@ namespace storm { // If there was no enabled command although the module has some command with the required action label, // we must not return anything. if (commands.size() == 0) { - return boost::optional>>>(); + return boost::optional>>>(); } result.get().push_back(std::move(commands)); @@ -274,8 +274,8 @@ namespace storm { return result; } - static std::list> getUnlabeledTransitions(storm::prism::Program const& program, StateInformation& stateInformation, VariableInformation const& variableInformation, uint_fast64_t stateIndex, std::queue& stateQueue) { - std::list> result; + static std::vector> getUnlabeledTransitions(storm::prism::Program const& program, StateInformation& stateInformation, VariableInformation const& variableInformation, uint_fast64_t stateIndex, std::queue& stateQueue) { + std::vector> result; StateType const* currentState = stateInformation.reachableStates[stateIndex]; @@ -328,17 +328,17 @@ namespace storm { return result; } - static std::list> getLabeledTransitions(storm::prism::Program const& program, StateInformation& stateInformation, VariableInformation const& variableInformation, uint_fast64_t stateIndex, std::queue& stateQueue) { - std::list> result; + static std::vector> getLabeledTransitions(storm::prism::Program const& program, StateInformation& stateInformation, VariableInformation const& variableInformation, uint_fast64_t stateIndex, std::queue& stateQueue) { + std::vector> result; for (uint_fast64_t actionIndex : program.getActionIndices()) { StateType const* currentState = stateInformation.reachableStates[stateIndex]; - boost::optional>>> optionalActiveCommandLists = getActiveCommandsByActionIndex(program, currentState, actionIndex); + boost::optional>>> optionalActiveCommandLists = getActiveCommandsByActionIndex(program, currentState, actionIndex); // Only process this action label, if there is at least one feasible solution. if (optionalActiveCommandLists) { - std::vector>> const& activeCommandList = optionalActiveCommandLists.get(); - std::vector>::const_iterator> iteratorList(activeCommandList.size()); + std::vector>> const& activeCommandList = optionalActiveCommandLists.get(); + std::vector>::const_iterator> iteratorList(activeCommandList.size()); // Initialize the list of iterators. for (size_t i = 0; i < activeCommandList.size(); ++i) { @@ -505,8 +505,8 @@ namespace storm { uint_fast64_t currentState = stateQueue.front(); // Retrieve all choices for the current state. - std::list> allUnlabeledChoices = getUnlabeledTransitions(program, stateInformation, variableInformation, currentState, stateQueue); - std::list> allLabeledChoices = getLabeledTransitions(program, stateInformation, variableInformation, currentState, stateQueue); + std::vector> allUnlabeledChoices = getUnlabeledTransitions(program, stateInformation, variableInformation, currentState, stateQueue); + std::vector> allLabeledChoices = getLabeledTransitions(program, stateInformation, variableInformation, currentState, stateQueue); uint_fast64_t totalNumberOfChoices = allUnlabeledChoices.size() + allLabeledChoices.size(); diff --git a/src/parser/PrismParser.cpp b/src/parser/PrismParser.cpp index 66e239c3a..5312e05cc 100644 --- a/src/parser/PrismParser.cpp +++ b/src/parser/PrismParser.cpp @@ -225,7 +225,7 @@ namespace storm { storm::prism::Constant PrismParser::createUndefinedBooleanConstant(std::string const& newConstant) const { if (!this->secondRun) { try { - storm::expressions::Variable newVariable = manager->declareBooleanVariable(newConstant); + storm::expressions::Variable newVariable = manager->declareBooleanVariable(newConstant, true); this->identifiers_.add(newConstant, newVariable.getExpression()); } catch (storm::exceptions::InvalidArgumentException const& e) { if (manager->hasVariable(newConstant)) { @@ -241,7 +241,7 @@ namespace storm { storm::prism::Constant PrismParser::createUndefinedIntegerConstant(std::string const& newConstant) const { if (!this->secondRun) { try { - storm::expressions::Variable newVariable = manager->declareIntegerVariable(newConstant); + storm::expressions::Variable newVariable = manager->declareIntegerVariable(newConstant, true); this->identifiers_.add(newConstant, newVariable.getExpression()); } catch (storm::exceptions::InvalidArgumentException const& e) { if (manager->hasVariable(newConstant)) { @@ -257,7 +257,7 @@ namespace storm { storm::prism::Constant PrismParser::createUndefinedDoubleConstant(std::string const& newConstant) const { if (!this->secondRun) { try { - storm::expressions::Variable newVariable = manager->declareRationalVariable(newConstant); + storm::expressions::Variable newVariable = manager->declareRationalVariable(newConstant, true); this->identifiers_.add(newConstant, newVariable.getExpression()); } catch (storm::exceptions::InvalidArgumentException const& e) { if (manager->hasVariable(newConstant)) { @@ -273,7 +273,7 @@ namespace storm { storm::prism::Constant PrismParser::createDefinedBooleanConstant(std::string const& newConstant, storm::expressions::Expression expression) const { if (!this->secondRun) { try { - storm::expressions::Variable newVariable = manager->declareBooleanVariable(newConstant); + storm::expressions::Variable newVariable = manager->declareBooleanVariable(newConstant, true); this->identifiers_.add(newConstant, newVariable.getExpression()); } catch (storm::exceptions::InvalidArgumentException const& e) { if (manager->hasVariable(newConstant)) { @@ -289,7 +289,7 @@ namespace storm { storm::prism::Constant PrismParser::createDefinedIntegerConstant(std::string const& newConstant, storm::expressions::Expression expression) const { if (!this->secondRun) { try { - storm::expressions::Variable newVariable = manager->declareIntegerVariable(newConstant); + storm::expressions::Variable newVariable = manager->declareIntegerVariable(newConstant, true); this->identifiers_.add(newConstant, newVariable.getExpression()); } catch (storm::exceptions::InvalidArgumentException const& e) { if (manager->hasVariable(newConstant)) { @@ -305,7 +305,7 @@ namespace storm { storm::prism::Constant PrismParser::createDefinedDoubleConstant(std::string const& newConstant, storm::expressions::Expression expression) const { if (!this->secondRun) { try { - storm::expressions::Variable newVariable = manager->declareRationalVariable(newConstant); + storm::expressions::Variable newVariable = manager->declareRationalVariable(newConstant, true); this->identifiers_.add(newConstant, newVariable.getExpression()); } catch (storm::exceptions::InvalidArgumentException const& e) { if (manager->hasVariable(newConstant)) { diff --git a/src/storage/BitVector.cpp b/src/storage/BitVector.cpp index dc9140ede..238237087 100644 --- a/src/storage/BitVector.cpp +++ b/src/storage/BitVector.cpp @@ -1,5 +1,6 @@ #include #include +#include #include "src/storage/BitVector.h" #include "src/exceptions/InvalidArgumentException.h" @@ -7,10 +8,7 @@ #include "src/utility/OsDetection.h" #include "src/utility/Hash.h" -#include "log4cplus/logger.h" -#include "log4cplus/loggingmacros.h" - -extern log4cplus::Logger logger; +#include "src/utility/macros.h" namespace storm { namespace storage { @@ -80,6 +78,10 @@ namespace storm { set(begin, end); } + BitVector::BitVector(uint_fast64_t bucketCount, uint_fast64_t bitCount) : bitCount(bitCount), bucketVector(bucketCount) { + STORM_LOG_ASSERT((bucketCount << 6) == bitCount, "Bit count does not match number of buckets."); + } + BitVector::BitVector(BitVector const& other) : bitCount(other.bitCount), bucketVector(other.bucketVector) { // Intentionally left empty. } @@ -360,6 +362,52 @@ namespace storm { return true; } + bool BitVector::matches(uint_fast64_t bitIndex, BitVector const& other) const { + STORM_LOG_ASSERT((bitIndex & mod64mask) == 0, "Bit index must be a multiple of 64."); + STORM_LOG_ASSERT(other.size() <= this->size() - bitIndex, "Bit vector argument is too long."); + + // Compute the first bucket that needs to be checked and the number of buckets. + uint64_t index = bitIndex >> 6; + + std::vector::const_iterator first1 = bucketVector.begin() + index; + std::vector::const_iterator first2 = other.bucketVector.begin(); + std::vector::const_iterator last2 = other.bucketVector.end(); + + for (; first2 != last2; ++first1, ++first2) { + if (*first1 != *first2) { + return false; + } + } + return true; + } + + void BitVector::set(uint_fast64_t bitIndex, BitVector const& other) { + STORM_LOG_ASSERT((bitIndex & mod64mask) == 0, "Bit index must be a multiple of 64."); + STORM_LOG_ASSERT(other.size() <= this->size() - bitIndex, "Bit vector argument is too long."); + + // Compute the first bucket that needs to be checked and the number of buckets. + uint64_t index = bitIndex >> 6; + + std::vector::iterator first1 = bucketVector.begin() + index; + std::vector::const_iterator first2 = other.bucketVector.begin(); + std::vector::const_iterator last2 = other.bucketVector.end(); + + for (; first2 != last2; ++first1, ++first2) { + *first1 = *first2; + } + } + + storm::storage::BitVector BitVector::get(uint_fast64_t bitIndex, uint_fast64_t numberOfBits) { + uint64_t numberOfBuckets = numberOfBits >> 6; + uint64_t index = bitIndex >> 6; + STORM_LOG_ASSERT(index + numberOfBuckets <= this->bucketCount(), "Argument is out-of-range."); + + storm::storage::BitVector result(numberOfBuckets, numberOfBits); + std::copy(this->bucketVector.begin() + index, this->bucketVector.begin() + index + numberOfBuckets, result.bucketVector.begin()); + + return result; + } + bool BitVector::empty() const { for (auto& element : bucketVector) { if (element != 0) { @@ -508,7 +556,11 @@ namespace storm { bucketVector.back() &= (1ll << (bitCount & mod64mask)) - 1ll; } } - + + size_t BitVector::bucketCount() const { + return bucketVector.size(); + } + std::ostream& operator<<(std::ostream& out, BitVector const& bitVector) { out << "bit vector(" << bitVector.getNumberOfSetBits() << "/" << bitVector.bitCount << ") ["; for (auto index : bitVector) { @@ -529,4 +581,10 @@ namespace storm { template void BitVector::set(boost::container::flat_set::iterator begin, boost::container::flat_set::iterator end); template void BitVector::set(boost::container::flat_set::const_iterator begin, boost::container::flat_set::const_iterator end); } +} + +namespace std { + std::size_t hash::operator()(storm::storage::BitVector const& bv) { + return boost::hash_range(bv.bucketVector.begin(), bv.bucketVector.end()); + } } \ No newline at end of file diff --git a/src/storage/BitVector.h b/src/storage/BitVector.h index 806febcc9..27779ea81 100644 --- a/src/storage/BitVector.h +++ b/src/storage/BitVector.h @@ -6,6 +6,7 @@ #include #include #include +#include namespace storm { namespace storage { @@ -318,6 +319,36 @@ namespace storm { */ bool isDisjointFrom(BitVector const& other) const; + /*! + * Checks whether the given bit vector matches the bits starting from the given index in the current bit + * vector. Note: the given bit vector must be shorter than the current one minus the given index. + * + * @param bitIndex The index of the first bit that it supposed to match. This value must be a multiple of 64. + * @param other The bit vector with which to compare. + * @return bool True iff the bits match exactly. + */ + bool matches(uint_fast64_t bitIndex, BitVector const& other) const; + + /*! + * Sets the exact bit pattern of the given bit vector starting at the given bit index. Note: the given bit + * vector must be shorter than the current one minus the given index. + * + * @param bitIndex The index of the first bit that it supposed to be set. This value must be a multiple of + * 64. + * @param other The bit vector whose pattern to set. + */ + void set(uint_fast64_t bitIndex, BitVector const& other); + + /*! + * Retrieves the content of the current bit vector at the given index for the given number of bits as a new + * bit vector. + * + * @param bitIndex The index of the first bit to get. This value must be a multiple of 64. + * @param numberOfBits The number of bits to get. This value must be a multiple of 64. + * @return A new bit vector holding the selected bits. + */ + storm::storage::BitVector get(uint_fast64_t bitIndex, uint_fast64_t numberOfBits); + /*! * Retrieves whether no bits are set to true in this bit vector. * @@ -395,8 +426,17 @@ namespace storm { uint_fast64_t getNextSetIndex(uint_fast64_t startingIndex) const; friend std::ostream& operator<<(std::ostream& out, BitVector const& bitVector); + friend struct std::hash; private: + /*! + * Creates an empty bit vector with the given number of buckets. + * + * @param bucketCount The number of buckets to create. + * @param bitCount This must be the number of buckets times 64. + */ + BitVector(uint_fast64_t bucketCount, uint_fast64_t bitCount); + /*! * Retrieves the index of the next bit that is set to true after (and including) the given starting index. * @@ -412,7 +452,14 @@ namespace storm { * Truncate the last bucket so that no bits are set starting from bitCount. */ void truncateLastBucket(); - + + /*! + * Retrieves the number of buckets of the underlying storage. + * + * @return The number of buckets of the underlying storage. + */ + size_t bucketCount() const; + // The number of bits that this bit vector can hold. uint_fast64_t bitCount; @@ -426,4 +473,11 @@ namespace storm { } // namespace storage } // namespace storm +namespace std { + template <> + struct hash { + std::size_t operator()(storm::storage::BitVector const& bv); + }; +} + #endif // STORM_STORAGE_BITVECTOR_H_ diff --git a/src/storage/BitVectorHashMap.cpp b/src/storage/BitVectorHashMap.cpp new file mode 100644 index 000000000..4dba68d8e --- /dev/null +++ b/src/storage/BitVectorHashMap.cpp @@ -0,0 +1,92 @@ +#include "src/storage/BitVectorHashMap.h" + +#include + +#include "src/utility/macros.h" + +namespace storm { + namespace storage { + template + const std::vector BitVectorHashMap::sizes = {5, 13, 31, 79, 163, 277, 499, 1021, 2029, 3989, 8059, 16001, 32099, 64301, 127921, 256499, 511111, 1024901, 2048003, 4096891, 8192411, 15485863}; + + template + BitVectorHashMap::BitVectorHashMap(uint64_t bucketSize, uint64_t initialSize, double loadFactor) : loadFactor(loadFactor), bucketSize(bucketSize), numberOfElements(0) { + STORM_LOG_ASSERT(bucketSize % 64 == 0, "Bucket size must be a multiple of 64."); + currentSizeIterator = std::find_if(sizes.begin(), sizes.end(), [=] (uint64_t value) { return value > initialSize; } ); + + // Create the underlying containers. + buckets = storm::storage::BitVector(bucketSize * *currentSizeIterator); + occupied = storm::storage::BitVector(*currentSizeIterator); + values = std::vector(*currentSizeIterator); + } + + template + bool BitVectorHashMap::isBucketOccupied(uint_fast64_t bucket) { + return occupied.get(bucket); + } + + template + std::size_t BitVectorHashMap::size() const { + return numberOfElements; + } + + template + std::size_t BitVectorHashMap::capacity() const { + return *currentSizeIterator; + } + + template + void BitVectorHashMap::increaseSize() { + ++currentSizeIterator; + STORM_LOG_ASSERT(currentSizeIterator != sizes.end(), "Hash map became to big."); + + storm::storage::BitVector oldBuckets(bucketSize * *currentSizeIterator); + std::swap(oldBuckets, buckets); + storm::storage::BitVector oldOccupied = storm::storage::BitVector(*currentSizeIterator); + std::swap(oldOccupied, occupied); + std::vector oldValues = std::vector(*currentSizeIterator); + std::swap(oldValues, values); + + // Now iterate through the elements and reinsert them in the new storage. + numberOfElements = 0; + for (auto const& bucketIndex : oldOccupied) { + this->findOrAdd(oldBuckets.get(bucketIndex * bucketSize, bucketSize), oldValues[bucketIndex]); + } + } + + template + ValueType BitVectorHashMap::findOrAdd(storm::storage::BitVector const& key, ValueType value) { + // If the load of the map is too high, we increase the size. + if (numberOfElements >= loadFactor * *currentSizeIterator) { + this->increaseSize(); + } + + uint_fast64_t initialHash = hasher1(key) % *currentSizeIterator; + uint_fast64_t bucket = initialHash; + + while (isBucketOccupied(bucket)) { + if (buckets.matches(bucket * bucketSize, key)) { + return values[bucket]; + } + bucket += hasher2(key); + bucket %= *currentSizeIterator; + + // If we arrived at the original position, this means that we have visited all possible locations, but + // could not find a suitable position. This implies that we have to enlarge the map in order to resolve + // the issue. + if (bucket == initialHash) { + this->increaseSize(); + } + } + + // Insert the new bits into the bucket. + buckets.set(bucket * bucketSize, key); + occupied.set(bucket); + values[bucket] = value; + ++numberOfElements; + return value; + } + + template class BitVectorHashMap; + } +} diff --git a/src/storage/BitVectorHashMap.h b/src/storage/BitVectorHashMap.h new file mode 100644 index 000000000..a6894219a --- /dev/null +++ b/src/storage/BitVectorHashMap.h @@ -0,0 +1,102 @@ +#ifndef STORM_STORAGE_BITVECTORHASHMAP_H_ +#define STORM_STORAGE_BITVECTORHASHMAP_H_ + +#include +#include + +#include "src/storage/BitVector.h" + +namespace storm { + namespace storage { + + /*! + * This class represents a hash-map whose keys are bit vectors. The value type is arbitrary. Currently, only + * queries and insertions are supported. Also, the keys must be bit vectors with a length that is a multiple of + * 64. + */ + template, class Hash2 = std::hash> + class BitVectorHashMap { + public: + /*! + * Creates a new hash map with the given bucket size and initial size. + * + * @param bucketSize The size of the buckets that this map can hold. This value must be a multiple of 64. + * @param initialSize The number of buckets that is initially available. + * @param loadFactor The load factor that determines at which point the size of the underlying storage is + * increased. + */ + BitVectorHashMap(uint64_t bucketSize, uint64_t initialSize, double loadFactor = 0.75); + + /*! + * Searches for the given key in the map. If it is found, the mapped-to value is returned. Otherwise, the + * key is inserted with the given value. + * + * @param key The key to search or insert. + * @param value The value that is inserted if the key is not already found in the map. + * @return The found value if the key is already contained in the map and the provided new value otherwise. + */ + ValueType findOrAdd(storm::storage::BitVector const& key, ValueType value); + + /*! + * Retrieves the size of the map in terms of the number of key-value pairs it stores. + * + * @return The size of the map. + */ + std::size_t size() const; + + /*! + * Retrieves the capacity of the underlying container. + * + * @return The capacity of the underlying container. + */ + std::size_t capacity() const; + + private: + /*! + * Retrieves whether the given bucket holds a value. + * + * @param bucket The bucket to check. + * @return True iff the bucket is occupied. + */ + bool isBucketOccupied(uint_fast64_t bucket); + + /*! + * Increases the size of the hash map and performs the necessary rehashing of all entries. + */ + void increaseSize(); + + // The load factor determining when the size of the map is increased. + double loadFactor; + + // The size of one bucket. + uint64_t bucketSize; + + // The number of buckets. + std::size_t numberOfBuckets; + + // The buckets that hold the elements of the map. + storm::storage::BitVector buckets; + + // A bit vector that stores which buckets actually hold a value. + storm::storage::BitVector occupied; + + // A vector of the mapped-to values. The entry at position i is the "target" of the key in bucket i. + std::vector values; + + // The number of elements in this map. + std::size_t numberOfElements; + + // An iterator to a value in the static sizes table. + std::vector::const_iterator currentSizeIterator; + + // Functor object that are used to perform the actual hashing. + Hash1 hasher1; + Hash2 hasher2; + + // A static table that produces the next possible size of the hash table. + static const std::vector sizes; + }; + } +} + +#endif /* STORM_STORAGE_BITVECTORHASHMAP_H_ */ \ No newline at end of file diff --git a/src/storage/expressions/ExpressionEvaluator.cpp b/src/storage/expressions/ExpressionEvaluator.cpp index e69de29bb..e54beaaca 100644 --- a/src/storage/expressions/ExpressionEvaluator.cpp +++ b/src/storage/expressions/ExpressionEvaluator.cpp @@ -0,0 +1,70 @@ +#include "src/storage/expressions/ExpressionEvaluator.h" +#include "src/storage/expressions/ExpressionManager.h" + +namespace storm { + namespace expressions { + ExpressionEvaluator::ExpressionEvaluator(storm::expressions::ExpressionManager const& manager) : manager(manager.getSharedPointer()), booleanValues(manager.getNumberOfBooleanVariables()), integerValues(manager.getNumberOfIntegerVariables()), rationalValues(manager.getNumberOfRationalVariables()) { + + for (auto const& variableTypePair : manager) { + if (variableTypePair.second.isBooleanType()) { + symbolTable.add_variable(variableTypePair.first.getName(), this->booleanValues[variableTypePair.first.getOffset()]); + } else if (variableTypePair.second.isIntegerType()) { + symbolTable.add_variable(variableTypePair.first.getName(), this->integerValues[variableTypePair.first.getOffset()]); + } else if (variableTypePair.second.isRationalType()) { + symbolTable.add_variable(variableTypePair.first.getName(), this->rationalValues[variableTypePair.first.getOffset()]); + } + } + symbolTable.add_constants(); + } + + bool ExpressionEvaluator::asBool(Expression const& expression) { + BaseExpression const* expressionPtr = expression.getBaseExpressionPointer().get(); + auto const& expressionPair = this->compiledExpressions.find(expression.getBaseExpressionPointer().get()); + if (expressionPair == this->compiledExpressions.end()) { + CompiledExpressionType const& compiledExpression = this->getCompiledExpression(expressionPtr); + return compiledExpression.value() == ValueType(1); + } + return expressionPair->second.value() == ValueType(1); + } + + int_fast64_t ExpressionEvaluator::asInt(Expression const& expression) { + BaseExpression const* expressionPtr = expression.getBaseExpressionPointer().get(); + auto const& expressionPair = this->compiledExpressions.find(expression.getBaseExpressionPointer().get()); + if (expressionPair == this->compiledExpressions.end()) { + CompiledExpressionType const& compiledExpression = this->getCompiledExpression(expressionPtr); + return static_cast(compiledExpression.value()); + } + return static_cast(expressionPair->second.value()); + } + + double ExpressionEvaluator::asDouble(Expression const& expression) { + BaseExpression const* expressionPtr = expression.getBaseExpressionPointer().get(); + auto const& expressionPair = this->compiledExpressions.find(expression.getBaseExpressionPointer().get()); + if (expressionPair == this->compiledExpressions.end()) { + CompiledExpressionType const& compiledExpression = this->getCompiledExpression(expressionPtr); + return static_cast(compiledExpression.value()); + } + return static_cast(expressionPair->second.value()); + } + + ExpressionEvaluator::CompiledExpressionType& ExpressionEvaluator::getCompiledExpression(BaseExpression const* expression) { + std::pair result = this->compiledExpressions.emplace(expression, CompiledExpressionType()); + CompiledExpressionType& compiledExpression = result.first->second; + compiledExpression.register_symbol_table(symbolTable); + bool parsingOk = parser.compile(stringTranslator.toString(expression), compiledExpression); + return compiledExpression; + } + + void ExpressionEvaluator::setBooleanValue(storm::expressions::Variable const& variable, bool value) { + this->booleanValues[variable.getOffset()] = static_cast(value); + } + + void ExpressionEvaluator::setIntegerValue(storm::expressions::Variable const& variable, int_fast64_t value) { + this->integerValues[variable.getOffset()] = static_cast(value); + } + + void ExpressionEvaluator::setRationalValue(storm::expressions::Variable const& variable, double value) { + this->rationalValues[variable.getOffset()] = static_cast(value); + } + } +} \ No newline at end of file diff --git a/src/storage/expressions/ExpressionEvaluator.h b/src/storage/expressions/ExpressionEvaluator.h index e69de29bb..1d5322774 100644 --- a/src/storage/expressions/ExpressionEvaluator.h +++ b/src/storage/expressions/ExpressionEvaluator.h @@ -0,0 +1,67 @@ +#ifndef STORM_STORAGE_EXPRESSIONS_EXPRESSIONEVALUATOR_H_ +#define STORM_STORAGE_EXPRESSIONS_EXPRESSIONEVALUATOR_H_ + +#include +#include + +#include "exprtk.hpp" + +#include "src/storage/expressions/Expression.h" +#include "src/storage/expressions/ExpressionVisitor.h" +#include "src/storage/expressions/ToExprtkStringVisitor.h" + +namespace storm { + namespace expressions { + class ExpressionEvaluator { + public: + /*! + * Creates an expression evaluator that is capable of evaluating expressions managed by the given manager. + * + * @param manager The manager responsible for the expressions. + */ + ExpressionEvaluator(storm::expressions::ExpressionManager const& manager); + + bool asBool(Expression const& expression); + int_fast64_t asInt(Expression const& expression); + double asDouble(Expression const& expression); + + void setBooleanValue(storm::expressions::Variable const& variable, bool value); + void setIntegerValue(storm::expressions::Variable const& variable, int_fast64_t value); + void setRationalValue(storm::expressions::Variable const& variable, double value); + + private: + typedef double ValueType; + typedef exprtk::expression CompiledExpressionType; + typedef std::unordered_map CacheType; + + /*! + * Adds a compiled version of the given expression to the internal storage. + * + * @param expression The expression that is to be compiled. + */ + CompiledExpressionType& getCompiledExpression(BaseExpression const* expression); + + // The expression manager that is used by this evaluator. + std::shared_ptr manager; + + // The parser used. + exprtk::parser parser; + + // The symbol table used. + exprtk::symbol_table symbolTable; + + // The actual data that is fed into the expression. + std::vector booleanValues; + std::vector integerValues; + std::vector rationalValues; + + // A mapping of expressions to their compiled counterpart. + CacheType compiledExpressions; + + // A translator that can be used for transforming an expression into the correct string format. + ToExprtkStringVisitor stringTranslator; + }; + } +} + +#endif /* STORM_STORAGE_EXPRESSIONS_EXPRESSIONEVALUATOR_H_ */ \ No newline at end of file diff --git a/src/storage/expressions/LinearityCheckVisitor.h b/src/storage/expressions/LinearityCheckVisitor.h index b9e1bf15c..2df8f8084 100644 --- a/src/storage/expressions/LinearityCheckVisitor.h +++ b/src/storage/expressions/LinearityCheckVisitor.h @@ -1,8 +1,6 @@ #ifndef STORM_STORAGE_EXPRESSIONS_LINEARITYCHECKVISITOR_H_ #define STORM_STORAGE_EXPRESSIONS_LINEARITYCHECKVISITOR_H_ -#include - #include "src/storage/expressions/Expression.h" #include "src/storage/expressions/ExpressionVisitor.h" diff --git a/src/storage/expressions/ToExprtkStringVisitor.cpp b/src/storage/expressions/ToExprtkStringVisitor.cpp index e69de29bb..e9d2db2c0 100644 --- a/src/storage/expressions/ToExprtkStringVisitor.cpp +++ b/src/storage/expressions/ToExprtkStringVisitor.cpp @@ -0,0 +1,197 @@ +#include "src/storage/expressions/ToExprtkStringVisitor.h" + +namespace storm { + namespace expressions { + std::string ToExprtkStringVisitor::toString(Expression const& expression) { + return toString(expression.getBaseExpressionPointer().get()); + } + + std::string ToExprtkStringVisitor::toString(BaseExpression const* expression) { + stream = std::stringstream(); + expression->accept(*this); + return std::move(stream.str()); + } + + boost::any ToExprtkStringVisitor::visit(IfThenElseExpression const& expression) { + stream << "if("; + expression.getCondition()->accept(*this); + stream << ","; + expression.getThenExpression()->accept(*this); + stream << ","; + expression.getElseExpression()->accept(*this); + stream << ")"; + return boost::any(); + } + + boost::any ToExprtkStringVisitor::visit(BinaryBooleanFunctionExpression const& expression) { + switch (expression.getOperatorType()) { + case BinaryBooleanFunctionExpression::OperatorType::And: + stream << "and("; + expression.getFirstOperand()->accept(*this); + stream << ","; + expression.getFirstOperand()->accept(*this); + stream << ")"; + break; + case BinaryBooleanFunctionExpression::OperatorType::Or: + stream << "or("; + expression.getFirstOperand()->accept(*this); + stream << ","; + expression.getFirstOperand()->accept(*this); + stream << ")"; + break; + case BinaryBooleanFunctionExpression::OperatorType::Xor: + stream << "xor("; + expression.getFirstOperand()->accept(*this); + stream << ","; + expression.getFirstOperand()->accept(*this); + stream << ")"; + break; + case BinaryBooleanFunctionExpression::OperatorType::Implies: + stream << "or(not("; + expression.getFirstOperand()->accept(*this); + stream << "),"; + expression.getFirstOperand()->accept(*this); + stream << ")"; + break; + case BinaryBooleanFunctionExpression::OperatorType::Iff: + expression.getFirstOperand()->accept(*this); + stream << "=="; + expression.getFirstOperand()->accept(*this); + break; + } + return boost::any(); + } + + boost::any ToExprtkStringVisitor::visit(BinaryNumericalFunctionExpression const& expression) { + switch (expression.getOperatorType()) { + case BinaryNumericalFunctionExpression::OperatorType::Plus: + expression.getFirstOperand()->accept(*this); + stream << "+"; + expression.getSecondOperand()->accept(*this); + break; + case BinaryNumericalFunctionExpression::OperatorType::Minus: + expression.getFirstOperand()->accept(*this); + stream << "-"; + expression.getSecondOperand()->accept(*this); + break; + case BinaryNumericalFunctionExpression::OperatorType::Times: + expression.getFirstOperand()->accept(*this); + stream << "*"; + expression.getSecondOperand()->accept(*this); + break; + case BinaryNumericalFunctionExpression::OperatorType::Divide: + expression.getFirstOperand()->accept(*this); + stream << "/"; + expression.getSecondOperand()->accept(*this); + break; + case BinaryNumericalFunctionExpression::OperatorType::Power: + expression.getFirstOperand()->accept(*this); + stream << "^"; + expression.getSecondOperand()->accept(*this); + break; + case BinaryNumericalFunctionExpression::OperatorType::Max: + stream << "max("; + expression.getFirstOperand()->accept(*this); + stream << ","; + expression.getSecondOperand()->accept(*this); + stream << ")"; + break; + case BinaryNumericalFunctionExpression::OperatorType::Min: + stream << "min("; + expression.getFirstOperand()->accept(*this); + stream << ","; + expression.getSecondOperand()->accept(*this); + stream << ")"; + break; + } + return boost::any(); + } + + boost::any ToExprtkStringVisitor::visit(BinaryRelationExpression const& expression) { + switch (expression.getRelationType()) { + case BinaryRelationExpression::RelationType::Equal: + expression.getFirstOperand()->accept(*this); + stream << "=="; + expression.getSecondOperand()->accept(*this); + break; + case BinaryRelationExpression::RelationType::NotEqual: + expression.getFirstOperand()->accept(*this); + stream << "!="; + expression.getSecondOperand()->accept(*this); + break; + case BinaryRelationExpression::RelationType::Less: + expression.getFirstOperand()->accept(*this); + stream << "<"; + expression.getSecondOperand()->accept(*this); + break; + case BinaryRelationExpression::RelationType::LessOrEqual: + expression.getFirstOperand()->accept(*this); + stream << "<="; + expression.getSecondOperand()->accept(*this); + break; + case BinaryRelationExpression::RelationType::Greater: + expression.getFirstOperand()->accept(*this); + stream << ">"; + expression.getSecondOperand()->accept(*this); + break; + case BinaryRelationExpression::RelationType::GreaterOrEqual: + expression.getFirstOperand()->accept(*this); + stream << ">="; + expression.getSecondOperand()->accept(*this); + break; + } + return boost::any(); + } + + boost::any ToExprtkStringVisitor::visit(VariableExpression const& expression) { + stream << expression.getVariableName(); + return boost::any(); + } + + boost::any ToExprtkStringVisitor::visit(UnaryBooleanFunctionExpression const& expression) { + switch (expression.getOperatorType()) { + case UnaryBooleanFunctionExpression::OperatorType::Not: + stream << "not("; + expression.getOperand()->accept(*this); + stream << ")"; + } + return boost::any(); + } + + boost::any ToExprtkStringVisitor::visit(UnaryNumericalFunctionExpression const& expression) { + switch (expression.getOperatorType()) { + case UnaryNumericalFunctionExpression::OperatorType::Minus: + stream << "-("; + expression.getOperand()->accept(*this); + stream << ")"; + break; + case UnaryNumericalFunctionExpression::OperatorType::Floor: + stream << "floor("; + expression.getOperand()->accept(*this); + stream << ")"; + break; + case UnaryNumericalFunctionExpression::OperatorType::Ceil: + stream << "ceil("; + expression.getOperand()->accept(*this); + stream << ")"; + break; + } + return boost::any(); + } + + boost::any ToExprtkStringVisitor::visit(BooleanLiteralExpression const& expression) { + stream << expression.getValue(); + return boost::any(); + } + + boost::any ToExprtkStringVisitor::visit(IntegerLiteralExpression const& expression) { + stream << expression.getValue(); + return boost::any(); + } + + boost::any ToExprtkStringVisitor::visit(DoubleLiteralExpression const& expression) { + stream << expression.getValue(); + return boost::any(); + } + } +} \ No newline at end of file diff --git a/src/storage/expressions/ToExprtkStringVisitor.h b/src/storage/expressions/ToExprtkStringVisitor.h index e69de29bb..6c285ff28 100644 --- a/src/storage/expressions/ToExprtkStringVisitor.h +++ b/src/storage/expressions/ToExprtkStringVisitor.h @@ -0,0 +1,37 @@ +#ifndef STORM_STORAGE_EXPRESSIONS_TOEXPRTKSTRINGVISITOR_H_ +#define STORM_STORAGE_EXPRESSIONS_TOEXPRTKSTRINGVISITOR_H_ + +#include + +#include "src/storage/expressions/Expression.h" +#include "src/storage/expressions/Expressions.h" +#include "src/storage/expressions/ExpressionVisitor.h" + +namespace storm { + namespace expressions { + class ToExprtkStringVisitor : public ExpressionVisitor { + public: + ToExprtkStringVisitor() = default; + + std::string toString(Expression const& expression); + std::string toString(BaseExpression const* expression); + + virtual boost::any visit(IfThenElseExpression const& expression) override; + virtual boost::any visit(BinaryBooleanFunctionExpression const& expression) override; + virtual boost::any visit(BinaryNumericalFunctionExpression const& expression) override; + virtual boost::any visit(BinaryRelationExpression const& expression) override; + virtual boost::any visit(VariableExpression const& expression) override; + virtual boost::any visit(UnaryBooleanFunctionExpression const& expression) override; + virtual boost::any visit(UnaryNumericalFunctionExpression const& expression) override; + virtual boost::any visit(BooleanLiteralExpression const& expression) override; + virtual boost::any visit(IntegerLiteralExpression const& expression) override; + virtual boost::any visit(DoubleLiteralExpression const& expression) override; + + private: + std::stringstream stream; + }; + } +} + + +#endif /* STORM_STORAGE_EXPRESSIONS_TOEXPRTKSTRINGVISITOR_H_ */ \ No newline at end of file diff --git a/test/functional/storage/BitVectorHashMapTest.cpp b/test/functional/storage/BitVectorHashMapTest.cpp new file mode 100644 index 000000000..625762b40 --- /dev/null +++ b/test/functional/storage/BitVectorHashMapTest.cpp @@ -0,0 +1,53 @@ +#include "gtest/gtest.h" + +#include + +#include "src/storage/BitVector.h" +#include "src/storage/BitVectorHashMap.h" + +TEST(BitVectorHashMapTest, FindOrAdd) { + storm::storage::BitVectorHashMap map(64, 3); + + storm::storage::BitVector first(64); + first.set(4); + first.set(47); + ASSERT_NO_THROW(map.findOrAdd(first, 1)); + + storm::storage::BitVector second(64); + second.set(8); + second.set(18); + ASSERT_NO_THROW(map.findOrAdd(second, 2)); + + EXPECT_EQ(1, map.findOrAdd(first, 3)); + + storm::storage::BitVector third(64); + third.set(10); + third.set(63); + + ASSERT_NO_THROW(map.findOrAdd(third, 3)); + + storm::storage::BitVector fourth(64); + fourth.set(12); + fourth.set(14); + + ASSERT_NO_THROW(map.findOrAdd(fourth, 4)); + + storm::storage::BitVector fifth(64); + fifth.set(44); + fifth.set(55); + + ASSERT_NO_THROW(map.findOrAdd(fifth, 5)); + + storm::storage::BitVector sixth(64); + sixth.set(45); + sixth.set(55); + + ASSERT_NO_THROW(map.findOrAdd(sixth, 6)); + + EXPECT_EQ(1, map.findOrAdd(first, 0)); + EXPECT_EQ(2, map.findOrAdd(second, 0)); + EXPECT_EQ(3, map.findOrAdd(third, 0)); + EXPECT_EQ(4, map.findOrAdd(fourth, 0)); + EXPECT_EQ(5, map.findOrAdd(fifth, 0)); + EXPECT_EQ(6, map.findOrAdd(sixth, 0)); +} diff --git a/test/functional/storage/ExpressionEvalutionTest.cpp b/test/functional/storage/ExpressionEvalutionTest.cpp index e69de29bb..c622f6e73 100644 --- a/test/functional/storage/ExpressionEvalutionTest.cpp +++ b/test/functional/storage/ExpressionEvalutionTest.cpp @@ -0,0 +1,62 @@ +#include "gtest/gtest.h" +#include "src/storage/expressions/Expression.h" +#include "src/storage/expressions/ExpressionManager.h" +#include "src/storage/expressions/SimpleValuation.h" +#include "src/storage/expressions/ExpressionEvaluator.h" + +TEST(ExpressionEvaluation, NaiveEvaluation) { + std::shared_ptr manager(new storm::expressions::ExpressionManager()); + + storm::expressions::Variable x; + storm::expressions::Variable y; + storm::expressions::Variable z; + ASSERT_NO_THROW(x = manager->declareBooleanVariable("x")); + ASSERT_NO_THROW(y = manager->declareIntegerVariable("y")); + ASSERT_NO_THROW(z = manager->declareRationalVariable("z")); + + storm::expressions::SimpleValuation eval(manager); + + storm::expressions::Expression iteExpression = storm::expressions::ite(x, y + z, manager->integer(3) * z); + + eval.setRationalValue(z, 5.5); + eval.setBooleanValue(x, true); + for (int_fast64_t i = 0; i < 1000; ++i) { + eval.setIntegerValue(y, 3 + i); + EXPECT_NEAR(8.5 + i, iteExpression.evaluateAsDouble(&eval), 1e-6); + } + + eval.setBooleanValue(x, false); + for (int_fast64_t i = 0; i < 1000; ++i) { + double zValue = i / static_cast(10); + eval.setRationalValue(z, zValue); + EXPECT_NEAR(3 * zValue, iteExpression.evaluateAsDouble(&eval), 1e-6); + } +} + +TEST(ExpressionEvaluation, ExprTkEvaluation) { + std::shared_ptr manager(new storm::expressions::ExpressionManager()); + + storm::expressions::Variable x; + storm::expressions::Variable y; + storm::expressions::Variable z; + ASSERT_NO_THROW(x = manager->declareBooleanVariable("x")); + ASSERT_NO_THROW(y = manager->declareIntegerVariable("y")); + ASSERT_NO_THROW(z = manager->declareRationalVariable("z")); + + storm::expressions::Expression iteExpression = storm::expressions::ite(x, y + z, manager->integer(3) * z); + storm::expressions::ExpressionEvaluator eval(*manager); + + eval.setRationalValue(z, 5.5); + eval.setBooleanValue(x, true); + for (int_fast64_t i = 0; i < 1000; ++i) { + eval.setIntegerValue(y, 3 + i); + EXPECT_NEAR(8.5 + i, eval.asDouble(iteExpression), 1e-6); + } + + eval.setBooleanValue(x, false); + for (int_fast64_t i = 0; i < 1000; ++i) { + double zValue = i / static_cast(10); + eval.setRationalValue(z, zValue); + EXPECT_NEAR(3 * zValue, eval.asDouble(iteExpression), 1e-6); + } +}