Browse Source

Created bit vector hash map and some necessary bit vector methods.

Former-commit-id: 4a9946a743
tempestpy_adaptions
dehnert 10 years ago
parent
commit
53196f5610
  1. 8
      CMakeLists.txt
  2. 28
      src/adapters/ExplicitModelAdapter.h
  3. 12
      src/parser/PrismParser.cpp
  4. 68
      src/storage/BitVector.cpp
  5. 56
      src/storage/BitVector.h
  6. 92
      src/storage/BitVectorHashMap.cpp
  7. 102
      src/storage/BitVectorHashMap.h
  8. 70
      src/storage/expressions/ExpressionEvaluator.cpp
  9. 67
      src/storage/expressions/ExpressionEvaluator.h
  10. 2
      src/storage/expressions/LinearityCheckVisitor.h
  11. 197
      src/storage/expressions/ToExprtkStringVisitor.cpp
  12. 37
      src/storage/expressions/ToExprtkStringVisitor.h
  13. 53
      test/functional/storage/BitVectorHashMapTest.cpp
  14. 62
      test/functional/storage/ExpressionEvalutionTest.cpp

8
CMakeLists.txt

@ -423,6 +423,14 @@ add_subdirectory("${PROJECT_SOURCE_DIR}/resources/3rdparty/glpk-4.53")
include_directories("${PROJECT_SOURCE_DIR}/resources/3rdparty/glpk-4.53/src")
target_link_libraries(storm "glpk")
#############################################################
##
## ExprTk
##
#############################################################
message (STATUS "StoRM - Including ExprTk")
include_directories("${PROJECT_SOURCE_DIR}/resources/3rdparty/exprtk")
#############################################################
##
## Z3 (optional)

28
src/adapters/ExplicitModelAdapter.h

@ -233,8 +233,8 @@ namespace storm {
* @param actionIndex The index of the action label to select.
* @return A list of lists of active commands or nothing.
*/
static boost::optional<std::vector<std::list<std::reference_wrapper<storm::prism::Command const>>>> getActiveCommandsByActionIndex(storm::prism::Program const& program, StateType const* state, uint_fast64_t const& actionIndex) {
boost::optional<std::vector<std::list<std::reference_wrapper<storm::prism::Command const>>>> result((std::vector<std::list<std::reference_wrapper<storm::prism::Command const>>>()));
static boost::optional<std::vector<std::vector<std::reference_wrapper<storm::prism::Command const>>>> getActiveCommandsByActionIndex(storm::prism::Program const& program, StateType const* state, uint_fast64_t const& actionIndex) {
boost::optional<std::vector<std::vector<std::reference_wrapper<storm::prism::Command const>>>> result((std::vector<std::vector<std::reference_wrapper<storm::prism::Command const>>>()));
// Iterate over all modules.
for (uint_fast64_t i = 0; i < program.getNumberOfModules(); ++i) {
@ -250,10 +250,10 @@ namespace storm {
// If the module contains the action, but there is no command in the module that is labeled with
// this action, we don't have any feasible command combinations.
if (commandIndices.empty()) {
return boost::optional<std::vector<std::list<std::reference_wrapper<storm::prism::Command const>>>>();
return boost::optional<std::vector<std::vector<std::reference_wrapper<storm::prism::Command const>>>>();
}
std::list<std::reference_wrapper<storm::prism::Command const>> commands;
std::vector<std::reference_wrapper<storm::prism::Command const>> commands;
// Look up commands by their indices and add them if the guard evaluates to true in the given state.
for (uint_fast64_t commandIndex : commandIndices) {
@ -266,7 +266,7 @@ namespace storm {
// If there was no enabled command although the module has some command with the required action label,
// we must not return anything.
if (commands.size() == 0) {
return boost::optional<std::vector<std::list<std::reference_wrapper<storm::prism::Command const>>>>();
return boost::optional<std::vector<std::vector<std::reference_wrapper<storm::prism::Command const>>>>();
}
result.get().push_back(std::move(commands));
@ -274,8 +274,8 @@ namespace storm {
return result;
}
static std::list<Choice<ValueType>> getUnlabeledTransitions(storm::prism::Program const& program, StateInformation& stateInformation, VariableInformation const& variableInformation, uint_fast64_t stateIndex, std::queue<uint_fast64_t>& stateQueue) {
std::list<Choice<ValueType>> result;
static std::vector<Choice<ValueType>> getUnlabeledTransitions(storm::prism::Program const& program, StateInformation& stateInformation, VariableInformation const& variableInformation, uint_fast64_t stateIndex, std::queue<uint_fast64_t>& stateQueue) {
std::vector<Choice<ValueType>> result;
StateType const* currentState = stateInformation.reachableStates[stateIndex];
@ -328,17 +328,17 @@ namespace storm {
return result;
}
static std::list<Choice<ValueType>> getLabeledTransitions(storm::prism::Program const& program, StateInformation& stateInformation, VariableInformation const& variableInformation, uint_fast64_t stateIndex, std::queue<uint_fast64_t>& stateQueue) {
std::list<Choice<ValueType>> result;
static std::vector<Choice<ValueType>> getLabeledTransitions(storm::prism::Program const& program, StateInformation& stateInformation, VariableInformation const& variableInformation, uint_fast64_t stateIndex, std::queue<uint_fast64_t>& stateQueue) {
std::vector<Choice<ValueType>> result;
for (uint_fast64_t actionIndex : program.getActionIndices()) {
StateType const* currentState = stateInformation.reachableStates[stateIndex];
boost::optional<std::vector<std::list<std::reference_wrapper<storm::prism::Command const>>>> optionalActiveCommandLists = getActiveCommandsByActionIndex(program, currentState, actionIndex);
boost::optional<std::vector<std::vector<std::reference_wrapper<storm::prism::Command const>>>> optionalActiveCommandLists = getActiveCommandsByActionIndex(program, currentState, actionIndex);
// Only process this action label, if there is at least one feasible solution.
if (optionalActiveCommandLists) {
std::vector<std::list<std::reference_wrapper<storm::prism::Command const>>> const& activeCommandList = optionalActiveCommandLists.get();
std::vector<std::list<std::reference_wrapper<storm::prism::Command const>>::const_iterator> iteratorList(activeCommandList.size());
std::vector<std::vector<std::reference_wrapper<storm::prism::Command const>>> const& activeCommandList = optionalActiveCommandLists.get();
std::vector<std::vector<std::reference_wrapper<storm::prism::Command const>>::const_iterator> iteratorList(activeCommandList.size());
// Initialize the list of iterators.
for (size_t i = 0; i < activeCommandList.size(); ++i) {
@ -505,8 +505,8 @@ namespace storm {
uint_fast64_t currentState = stateQueue.front();
// Retrieve all choices for the current state.
std::list<Choice<ValueType>> allUnlabeledChoices = getUnlabeledTransitions(program, stateInformation, variableInformation, currentState, stateQueue);
std::list<Choice<ValueType>> allLabeledChoices = getLabeledTransitions(program, stateInformation, variableInformation, currentState, stateQueue);
std::vector<Choice<ValueType>> allUnlabeledChoices = getUnlabeledTransitions(program, stateInformation, variableInformation, currentState, stateQueue);
std::vector<Choice<ValueType>> allLabeledChoices = getLabeledTransitions(program, stateInformation, variableInformation, currentState, stateQueue);
uint_fast64_t totalNumberOfChoices = allUnlabeledChoices.size() + allLabeledChoices.size();

12
src/parser/PrismParser.cpp

@ -225,7 +225,7 @@ namespace storm {
storm::prism::Constant PrismParser::createUndefinedBooleanConstant(std::string const& newConstant) const {
if (!this->secondRun) {
try {
storm::expressions::Variable newVariable = manager->declareBooleanVariable(newConstant);
storm::expressions::Variable newVariable = manager->declareBooleanVariable(newConstant, true);
this->identifiers_.add(newConstant, newVariable.getExpression());
} catch (storm::exceptions::InvalidArgumentException const& e) {
if (manager->hasVariable(newConstant)) {
@ -241,7 +241,7 @@ namespace storm {
storm::prism::Constant PrismParser::createUndefinedIntegerConstant(std::string const& newConstant) const {
if (!this->secondRun) {
try {
storm::expressions::Variable newVariable = manager->declareIntegerVariable(newConstant);
storm::expressions::Variable newVariable = manager->declareIntegerVariable(newConstant, true);
this->identifiers_.add(newConstant, newVariable.getExpression());
} catch (storm::exceptions::InvalidArgumentException const& e) {
if (manager->hasVariable(newConstant)) {
@ -257,7 +257,7 @@ namespace storm {
storm::prism::Constant PrismParser::createUndefinedDoubleConstant(std::string const& newConstant) const {
if (!this->secondRun) {
try {
storm::expressions::Variable newVariable = manager->declareRationalVariable(newConstant);
storm::expressions::Variable newVariable = manager->declareRationalVariable(newConstant, true);
this->identifiers_.add(newConstant, newVariable.getExpression());
} catch (storm::exceptions::InvalidArgumentException const& e) {
if (manager->hasVariable(newConstant)) {
@ -273,7 +273,7 @@ namespace storm {
storm::prism::Constant PrismParser::createDefinedBooleanConstant(std::string const& newConstant, storm::expressions::Expression expression) const {
if (!this->secondRun) {
try {
storm::expressions::Variable newVariable = manager->declareBooleanVariable(newConstant);
storm::expressions::Variable newVariable = manager->declareBooleanVariable(newConstant, true);
this->identifiers_.add(newConstant, newVariable.getExpression());
} catch (storm::exceptions::InvalidArgumentException const& e) {
if (manager->hasVariable(newConstant)) {
@ -289,7 +289,7 @@ namespace storm {
storm::prism::Constant PrismParser::createDefinedIntegerConstant(std::string const& newConstant, storm::expressions::Expression expression) const {
if (!this->secondRun) {
try {
storm::expressions::Variable newVariable = manager->declareIntegerVariable(newConstant);
storm::expressions::Variable newVariable = manager->declareIntegerVariable(newConstant, true);
this->identifiers_.add(newConstant, newVariable.getExpression());
} catch (storm::exceptions::InvalidArgumentException const& e) {
if (manager->hasVariable(newConstant)) {
@ -305,7 +305,7 @@ namespace storm {
storm::prism::Constant PrismParser::createDefinedDoubleConstant(std::string const& newConstant, storm::expressions::Expression expression) const {
if (!this->secondRun) {
try {
storm::expressions::Variable newVariable = manager->declareRationalVariable(newConstant);
storm::expressions::Variable newVariable = manager->declareRationalVariable(newConstant, true);
this->identifiers_.add(newConstant, newVariable.getExpression());
} catch (storm::exceptions::InvalidArgumentException const& e) {
if (manager->hasVariable(newConstant)) {

68
src/storage/BitVector.cpp

@ -1,5 +1,6 @@
#include <boost/container/flat_set.hpp>
#include <iostream>
#include <algorithm>
#include "src/storage/BitVector.h"
#include "src/exceptions/InvalidArgumentException.h"
@ -7,10 +8,7 @@
#include "src/utility/OsDetection.h"
#include "src/utility/Hash.h"
#include "log4cplus/logger.h"
#include "log4cplus/loggingmacros.h"
extern log4cplus::Logger logger;
#include "src/utility/macros.h"
namespace storm {
namespace storage {
@ -80,6 +78,10 @@ namespace storm {
set(begin, end);
}
BitVector::BitVector(uint_fast64_t bucketCount, uint_fast64_t bitCount) : bitCount(bitCount), bucketVector(bucketCount) {
STORM_LOG_ASSERT((bucketCount << 6) == bitCount, "Bit count does not match number of buckets.");
}
BitVector::BitVector(BitVector const& other) : bitCount(other.bitCount), bucketVector(other.bucketVector) {
// Intentionally left empty.
}
@ -360,6 +362,52 @@ namespace storm {
return true;
}
bool BitVector::matches(uint_fast64_t bitIndex, BitVector const& other) const {
STORM_LOG_ASSERT((bitIndex & mod64mask) == 0, "Bit index must be a multiple of 64.");
STORM_LOG_ASSERT(other.size() <= this->size() - bitIndex, "Bit vector argument is too long.");
// Compute the first bucket that needs to be checked and the number of buckets.
uint64_t index = bitIndex >> 6;
std::vector<uint64_t>::const_iterator first1 = bucketVector.begin() + index;
std::vector<uint64_t>::const_iterator first2 = other.bucketVector.begin();
std::vector<uint64_t>::const_iterator last2 = other.bucketVector.end();
for (; first2 != last2; ++first1, ++first2) {
if (*first1 != *first2) {
return false;
}
}
return true;
}
void BitVector::set(uint_fast64_t bitIndex, BitVector const& other) {
STORM_LOG_ASSERT((bitIndex & mod64mask) == 0, "Bit index must be a multiple of 64.");
STORM_LOG_ASSERT(other.size() <= this->size() - bitIndex, "Bit vector argument is too long.");
// Compute the first bucket that needs to be checked and the number of buckets.
uint64_t index = bitIndex >> 6;
std::vector<uint64_t>::iterator first1 = bucketVector.begin() + index;
std::vector<uint64_t>::const_iterator first2 = other.bucketVector.begin();
std::vector<uint64_t>::const_iterator last2 = other.bucketVector.end();
for (; first2 != last2; ++first1, ++first2) {
*first1 = *first2;
}
}
storm::storage::BitVector BitVector::get(uint_fast64_t bitIndex, uint_fast64_t numberOfBits) {
uint64_t numberOfBuckets = numberOfBits >> 6;
uint64_t index = bitIndex >> 6;
STORM_LOG_ASSERT(index + numberOfBuckets <= this->bucketCount(), "Argument is out-of-range.");
storm::storage::BitVector result(numberOfBuckets, numberOfBits);
std::copy(this->bucketVector.begin() + index, this->bucketVector.begin() + index + numberOfBuckets, result.bucketVector.begin());
return result;
}
bool BitVector::empty() const {
for (auto& element : bucketVector) {
if (element != 0) {
@ -508,7 +556,11 @@ namespace storm {
bucketVector.back() &= (1ll << (bitCount & mod64mask)) - 1ll;
}
}
size_t BitVector::bucketCount() const {
return bucketVector.size();
}
std::ostream& operator<<(std::ostream& out, BitVector const& bitVector) {
out << "bit vector(" << bitVector.getNumberOfSetBits() << "/" << bitVector.bitCount << ") [";
for (auto index : bitVector) {
@ -529,4 +581,10 @@ namespace storm {
template void BitVector::set(boost::container::flat_set<uint_fast64_t>::iterator begin, boost::container::flat_set<uint_fast64_t>::iterator end);
template void BitVector::set(boost::container::flat_set<uint_fast64_t>::const_iterator begin, boost::container::flat_set<uint_fast64_t>::const_iterator end);
}
}
namespace std {
std::size_t hash<storm::storage::BitVector>::operator()(storm::storage::BitVector const& bv) {
return boost::hash_range(bv.bucketVector.begin(), bv.bucketVector.end());
}
}

56
src/storage/BitVector.h

@ -6,6 +6,7 @@
#include <ostream>
#include <vector>
#include <iterator>
#include <functional>
namespace storm {
namespace storage {
@ -318,6 +319,36 @@ namespace storm {
*/
bool isDisjointFrom(BitVector const& other) const;
/*!
* Checks whether the given bit vector matches the bits starting from the given index in the current bit
* vector. Note: the given bit vector must be shorter than the current one minus the given index.
*
* @param bitIndex The index of the first bit that it supposed to match. This value must be a multiple of 64.
* @param other The bit vector with which to compare.
* @return bool True iff the bits match exactly.
*/
bool matches(uint_fast64_t bitIndex, BitVector const& other) const;
/*!
* Sets the exact bit pattern of the given bit vector starting at the given bit index. Note: the given bit
* vector must be shorter than the current one minus the given index.
*
* @param bitIndex The index of the first bit that it supposed to be set. This value must be a multiple of
* 64.
* @param other The bit vector whose pattern to set.
*/
void set(uint_fast64_t bitIndex, BitVector const& other);
/*!
* Retrieves the content of the current bit vector at the given index for the given number of bits as a new
* bit vector.
*
* @param bitIndex The index of the first bit to get. This value must be a multiple of 64.
* @param numberOfBits The number of bits to get. This value must be a multiple of 64.
* @return A new bit vector holding the selected bits.
*/
storm::storage::BitVector get(uint_fast64_t bitIndex, uint_fast64_t numberOfBits);
/*!
* Retrieves whether no bits are set to true in this bit vector.
*
@ -395,8 +426,17 @@ namespace storm {
uint_fast64_t getNextSetIndex(uint_fast64_t startingIndex) const;
friend std::ostream& operator<<(std::ostream& out, BitVector const& bitVector);
friend struct std::hash<storm::storage::BitVector>;
private:
/*!
* Creates an empty bit vector with the given number of buckets.
*
* @param bucketCount The number of buckets to create.
* @param bitCount This must be the number of buckets times 64.
*/
BitVector(uint_fast64_t bucketCount, uint_fast64_t bitCount);
/*!
* Retrieves the index of the next bit that is set to true after (and including) the given starting index.
*
@ -412,7 +452,14 @@ namespace storm {
* Truncate the last bucket so that no bits are set starting from bitCount.
*/
void truncateLastBucket();
/*!
* Retrieves the number of buckets of the underlying storage.
*
* @return The number of buckets of the underlying storage.
*/
size_t bucketCount() const;
// The number of bits that this bit vector can hold.
uint_fast64_t bitCount;
@ -426,4 +473,11 @@ namespace storm {
} // namespace storage
} // namespace storm
namespace std {
template <>
struct hash<storm::storage::BitVector> {
std::size_t operator()(storm::storage::BitVector const& bv);
};
}
#endif // STORM_STORAGE_BITVECTOR_H_

92
src/storage/BitVectorHashMap.cpp

@ -0,0 +1,92 @@
#include "src/storage/BitVectorHashMap.h"
#include <iostream>
#include "src/utility/macros.h"
namespace storm {
namespace storage {
template<class ValueType, class Hash1, class Hash2>
const std::vector<std::size_t> BitVectorHashMap<ValueType, Hash1, Hash2>::sizes = {5, 13, 31, 79, 163, 277, 499, 1021, 2029, 3989, 8059, 16001, 32099, 64301, 127921, 256499, 511111, 1024901, 2048003, 4096891, 8192411, 15485863};
template<class ValueType, class Hash1, class Hash2>
BitVectorHashMap<ValueType, Hash1, Hash2>::BitVectorHashMap(uint64_t bucketSize, uint64_t initialSize, double loadFactor) : loadFactor(loadFactor), bucketSize(bucketSize), numberOfElements(0) {
STORM_LOG_ASSERT(bucketSize % 64 == 0, "Bucket size must be a multiple of 64.");
currentSizeIterator = std::find_if(sizes.begin(), sizes.end(), [=] (uint64_t value) { return value > initialSize; } );
// Create the underlying containers.
buckets = storm::storage::BitVector(bucketSize * *currentSizeIterator);
occupied = storm::storage::BitVector(*currentSizeIterator);
values = std::vector<ValueType>(*currentSizeIterator);
}
template<class ValueType, class Hash1, class Hash2>
bool BitVectorHashMap<ValueType, Hash1, Hash2>::isBucketOccupied(uint_fast64_t bucket) {
return occupied.get(bucket);
}
template<class ValueType, class Hash1, class Hash2>
std::size_t BitVectorHashMap<ValueType, Hash1, Hash2>::size() const {
return numberOfElements;
}
template<class ValueType, class Hash1, class Hash2>
std::size_t BitVectorHashMap<ValueType, Hash1, Hash2>::capacity() const {
return *currentSizeIterator;
}
template<class ValueType, class Hash1, class Hash2>
void BitVectorHashMap<ValueType, Hash1, Hash2>::increaseSize() {
++currentSizeIterator;
STORM_LOG_ASSERT(currentSizeIterator != sizes.end(), "Hash map became to big.");
storm::storage::BitVector oldBuckets(bucketSize * *currentSizeIterator);
std::swap(oldBuckets, buckets);
storm::storage::BitVector oldOccupied = storm::storage::BitVector(*currentSizeIterator);
std::swap(oldOccupied, occupied);
std::vector<ValueType> oldValues = std::vector<ValueType>(*currentSizeIterator);
std::swap(oldValues, values);
// Now iterate through the elements and reinsert them in the new storage.
numberOfElements = 0;
for (auto const& bucketIndex : oldOccupied) {
this->findOrAdd(oldBuckets.get(bucketIndex * bucketSize, bucketSize), oldValues[bucketIndex]);
}
}
template<class ValueType, class Hash1, class Hash2>
ValueType BitVectorHashMap<ValueType, Hash1, Hash2>::findOrAdd(storm::storage::BitVector const& key, ValueType value) {
// If the load of the map is too high, we increase the size.
if (numberOfElements >= loadFactor * *currentSizeIterator) {
this->increaseSize();
}
uint_fast64_t initialHash = hasher1(key) % *currentSizeIterator;
uint_fast64_t bucket = initialHash;
while (isBucketOccupied(bucket)) {
if (buckets.matches(bucket * bucketSize, key)) {
return values[bucket];
}
bucket += hasher2(key);
bucket %= *currentSizeIterator;
// If we arrived at the original position, this means that we have visited all possible locations, but
// could not find a suitable position. This implies that we have to enlarge the map in order to resolve
// the issue.
if (bucket == initialHash) {
this->increaseSize();
}
}
// Insert the new bits into the bucket.
buckets.set(bucket * bucketSize, key);
occupied.set(bucket);
values[bucket] = value;
++numberOfElements;
return value;
}
template class BitVectorHashMap<uint_fast64_t>;
}
}

102
src/storage/BitVectorHashMap.h

@ -0,0 +1,102 @@
#ifndef STORM_STORAGE_BITVECTORHASHMAP_H_
#define STORM_STORAGE_BITVECTORHASHMAP_H_
#include <cstdint>
#include <functional>
#include "src/storage/BitVector.h"
namespace storm {
namespace storage {
/*!
* This class represents a hash-map whose keys are bit vectors. The value type is arbitrary. Currently, only
* queries and insertions are supported. Also, the keys must be bit vectors with a length that is a multiple of
* 64.
*/
template<class ValueType, class Hash1 = std::hash<storm::storage::BitVector>, class Hash2 = std::hash<storm::storage::BitVector>>
class BitVectorHashMap {
public:
/*!
* Creates a new hash map with the given bucket size and initial size.
*
* @param bucketSize The size of the buckets that this map can hold. This value must be a multiple of 64.
* @param initialSize The number of buckets that is initially available.
* @param loadFactor The load factor that determines at which point the size of the underlying storage is
* increased.
*/
BitVectorHashMap(uint64_t bucketSize, uint64_t initialSize, double loadFactor = 0.75);
/*!
* Searches for the given key in the map. If it is found, the mapped-to value is returned. Otherwise, the
* key is inserted with the given value.
*
* @param key The key to search or insert.
* @param value The value that is inserted if the key is not already found in the map.
* @return The found value if the key is already contained in the map and the provided new value otherwise.
*/
ValueType findOrAdd(storm::storage::BitVector const& key, ValueType value);
/*!
* Retrieves the size of the map in terms of the number of key-value pairs it stores.
*
* @return The size of the map.
*/
std::size_t size() const;
/*!
* Retrieves the capacity of the underlying container.
*
* @return The capacity of the underlying container.
*/
std::size_t capacity() const;
private:
/*!
* Retrieves whether the given bucket holds a value.
*
* @param bucket The bucket to check.
* @return True iff the bucket is occupied.
*/
bool isBucketOccupied(uint_fast64_t bucket);
/*!
* Increases the size of the hash map and performs the necessary rehashing of all entries.
*/
void increaseSize();
// The load factor determining when the size of the map is increased.
double loadFactor;
// The size of one bucket.
uint64_t bucketSize;
// The number of buckets.
std::size_t numberOfBuckets;
// The buckets that hold the elements of the map.
storm::storage::BitVector buckets;
// A bit vector that stores which buckets actually hold a value.
storm::storage::BitVector occupied;
// A vector of the mapped-to values. The entry at position i is the "target" of the key in bucket i.
std::vector<ValueType> values;
// The number of elements in this map.
std::size_t numberOfElements;
// An iterator to a value in the static sizes table.
std::vector<std::size_t>::const_iterator currentSizeIterator;
// Functor object that are used to perform the actual hashing.
Hash1 hasher1;
Hash2 hasher2;
// A static table that produces the next possible size of the hash table.
static const std::vector<std::size_t> sizes;
};
}
}
#endif /* STORM_STORAGE_BITVECTORHASHMAP_H_ */

70
src/storage/expressions/ExpressionEvaluator.cpp

@ -0,0 +1,70 @@
#include "src/storage/expressions/ExpressionEvaluator.h"
#include "src/storage/expressions/ExpressionManager.h"
namespace storm {
namespace expressions {
ExpressionEvaluator::ExpressionEvaluator(storm::expressions::ExpressionManager const& manager) : manager(manager.getSharedPointer()), booleanValues(manager.getNumberOfBooleanVariables()), integerValues(manager.getNumberOfIntegerVariables()), rationalValues(manager.getNumberOfRationalVariables()) {
for (auto const& variableTypePair : manager) {
if (variableTypePair.second.isBooleanType()) {
symbolTable.add_variable(variableTypePair.first.getName(), this->booleanValues[variableTypePair.first.getOffset()]);
} else if (variableTypePair.second.isIntegerType()) {
symbolTable.add_variable(variableTypePair.first.getName(), this->integerValues[variableTypePair.first.getOffset()]);
} else if (variableTypePair.second.isRationalType()) {
symbolTable.add_variable(variableTypePair.first.getName(), this->rationalValues[variableTypePair.first.getOffset()]);
}
}
symbolTable.add_constants();
}
bool ExpressionEvaluator::asBool(Expression const& expression) {
BaseExpression const* expressionPtr = expression.getBaseExpressionPointer().get();
auto const& expressionPair = this->compiledExpressions.find(expression.getBaseExpressionPointer().get());
if (expressionPair == this->compiledExpressions.end()) {
CompiledExpressionType const& compiledExpression = this->getCompiledExpression(expressionPtr);
return compiledExpression.value() == ValueType(1);
}
return expressionPair->second.value() == ValueType(1);
}
int_fast64_t ExpressionEvaluator::asInt(Expression const& expression) {
BaseExpression const* expressionPtr = expression.getBaseExpressionPointer().get();
auto const& expressionPair = this->compiledExpressions.find(expression.getBaseExpressionPointer().get());
if (expressionPair == this->compiledExpressions.end()) {
CompiledExpressionType const& compiledExpression = this->getCompiledExpression(expressionPtr);
return static_cast<int_fast64_t>(compiledExpression.value());
}
return static_cast<int_fast64_t>(expressionPair->second.value());
}
double ExpressionEvaluator::asDouble(Expression const& expression) {
BaseExpression const* expressionPtr = expression.getBaseExpressionPointer().get();
auto const& expressionPair = this->compiledExpressions.find(expression.getBaseExpressionPointer().get());
if (expressionPair == this->compiledExpressions.end()) {
CompiledExpressionType const& compiledExpression = this->getCompiledExpression(expressionPtr);
return static_cast<double>(compiledExpression.value());
}
return static_cast<double>(expressionPair->second.value());
}
ExpressionEvaluator::CompiledExpressionType& ExpressionEvaluator::getCompiledExpression(BaseExpression const* expression) {
std::pair<CacheType::iterator, bool> result = this->compiledExpressions.emplace(expression, CompiledExpressionType());
CompiledExpressionType& compiledExpression = result.first->second;
compiledExpression.register_symbol_table(symbolTable);
bool parsingOk = parser.compile(stringTranslator.toString(expression), compiledExpression);
return compiledExpression;
}
void ExpressionEvaluator::setBooleanValue(storm::expressions::Variable const& variable, bool value) {
this->booleanValues[variable.getOffset()] = static_cast<ValueType>(value);
}
void ExpressionEvaluator::setIntegerValue(storm::expressions::Variable const& variable, int_fast64_t value) {
this->integerValues[variable.getOffset()] = static_cast<ValueType>(value);
}
void ExpressionEvaluator::setRationalValue(storm::expressions::Variable const& variable, double value) {
this->rationalValues[variable.getOffset()] = static_cast<ValueType>(value);
}
}
}

67
src/storage/expressions/ExpressionEvaluator.h

@ -0,0 +1,67 @@
#ifndef STORM_STORAGE_EXPRESSIONS_EXPRESSIONEVALUATOR_H_
#define STORM_STORAGE_EXPRESSIONS_EXPRESSIONEVALUATOR_H_
#include <unordered_map>
#include <vector>
#include "exprtk.hpp"
#include "src/storage/expressions/Expression.h"
#include "src/storage/expressions/ExpressionVisitor.h"
#include "src/storage/expressions/ToExprtkStringVisitor.h"
namespace storm {
namespace expressions {
class ExpressionEvaluator {
public:
/*!
* Creates an expression evaluator that is capable of evaluating expressions managed by the given manager.
*
* @param manager The manager responsible for the expressions.
*/
ExpressionEvaluator(storm::expressions::ExpressionManager const& manager);
bool asBool(Expression const& expression);
int_fast64_t asInt(Expression const& expression);
double asDouble(Expression const& expression);
void setBooleanValue(storm::expressions::Variable const& variable, bool value);
void setIntegerValue(storm::expressions::Variable const& variable, int_fast64_t value);
void setRationalValue(storm::expressions::Variable const& variable, double value);
private:
typedef double ValueType;
typedef exprtk::expression<ValueType> CompiledExpressionType;
typedef std::unordered_map<BaseExpression const*, CompiledExpressionType> CacheType;
/*!
* Adds a compiled version of the given expression to the internal storage.
*
* @param expression The expression that is to be compiled.
*/
CompiledExpressionType& getCompiledExpression(BaseExpression const* expression);
// The expression manager that is used by this evaluator.
std::shared_ptr<storm::expressions::ExpressionManager const> manager;
// The parser used.
exprtk::parser<ValueType> parser;
// The symbol table used.
exprtk::symbol_table<ValueType> symbolTable;
// The actual data that is fed into the expression.
std::vector<ValueType> booleanValues;
std::vector<ValueType> integerValues;
std::vector<ValueType> rationalValues;
// A mapping of expressions to their compiled counterpart.
CacheType compiledExpressions;
// A translator that can be used for transforming an expression into the correct string format.
ToExprtkStringVisitor stringTranslator;
};
}
}
#endif /* STORM_STORAGE_EXPRESSIONS_EXPRESSIONEVALUATOR_H_ */

2
src/storage/expressions/LinearityCheckVisitor.h

@ -1,8 +1,6 @@
#ifndef STORM_STORAGE_EXPRESSIONS_LINEARITYCHECKVISITOR_H_
#define STORM_STORAGE_EXPRESSIONS_LINEARITYCHECKVISITOR_H_
#include <stack>
#include "src/storage/expressions/Expression.h"
#include "src/storage/expressions/ExpressionVisitor.h"

197
src/storage/expressions/ToExprtkStringVisitor.cpp

@ -0,0 +1,197 @@
#include "src/storage/expressions/ToExprtkStringVisitor.h"
namespace storm {
namespace expressions {
std::string ToExprtkStringVisitor::toString(Expression const& expression) {
return toString(expression.getBaseExpressionPointer().get());
}
std::string ToExprtkStringVisitor::toString(BaseExpression const* expression) {
stream = std::stringstream();
expression->accept(*this);
return std::move(stream.str());
}
boost::any ToExprtkStringVisitor::visit(IfThenElseExpression const& expression) {
stream << "if(";
expression.getCondition()->accept(*this);
stream << ",";
expression.getThenExpression()->accept(*this);
stream << ",";
expression.getElseExpression()->accept(*this);
stream << ")";
return boost::any();
}
boost::any ToExprtkStringVisitor::visit(BinaryBooleanFunctionExpression const& expression) {
switch (expression.getOperatorType()) {
case BinaryBooleanFunctionExpression::OperatorType::And:
stream << "and(";
expression.getFirstOperand()->accept(*this);
stream << ",";
expression.getFirstOperand()->accept(*this);
stream << ")";
break;
case BinaryBooleanFunctionExpression::OperatorType::Or:
stream << "or(";
expression.getFirstOperand()->accept(*this);
stream << ",";
expression.getFirstOperand()->accept(*this);
stream << ")";
break;
case BinaryBooleanFunctionExpression::OperatorType::Xor:
stream << "xor(";
expression.getFirstOperand()->accept(*this);
stream << ",";
expression.getFirstOperand()->accept(*this);
stream << ")";
break;
case BinaryBooleanFunctionExpression::OperatorType::Implies:
stream << "or(not(";
expression.getFirstOperand()->accept(*this);
stream << "),";
expression.getFirstOperand()->accept(*this);
stream << ")";
break;
case BinaryBooleanFunctionExpression::OperatorType::Iff:
expression.getFirstOperand()->accept(*this);
stream << "==";
expression.getFirstOperand()->accept(*this);
break;
}
return boost::any();
}
boost::any ToExprtkStringVisitor::visit(BinaryNumericalFunctionExpression const& expression) {
switch (expression.getOperatorType()) {
case BinaryNumericalFunctionExpression::OperatorType::Plus:
expression.getFirstOperand()->accept(*this);
stream << "+";
expression.getSecondOperand()->accept(*this);
break;
case BinaryNumericalFunctionExpression::OperatorType::Minus:
expression.getFirstOperand()->accept(*this);
stream << "-";
expression.getSecondOperand()->accept(*this);
break;
case BinaryNumericalFunctionExpression::OperatorType::Times:
expression.getFirstOperand()->accept(*this);
stream << "*";
expression.getSecondOperand()->accept(*this);
break;
case BinaryNumericalFunctionExpression::OperatorType::Divide:
expression.getFirstOperand()->accept(*this);
stream << "/";
expression.getSecondOperand()->accept(*this);
break;
case BinaryNumericalFunctionExpression::OperatorType::Power:
expression.getFirstOperand()->accept(*this);
stream << "^";
expression.getSecondOperand()->accept(*this);
break;
case BinaryNumericalFunctionExpression::OperatorType::Max:
stream << "max(";
expression.getFirstOperand()->accept(*this);
stream << ",";
expression.getSecondOperand()->accept(*this);
stream << ")";
break;
case BinaryNumericalFunctionExpression::OperatorType::Min:
stream << "min(";
expression.getFirstOperand()->accept(*this);
stream << ",";
expression.getSecondOperand()->accept(*this);
stream << ")";
break;
}
return boost::any();
}
boost::any ToExprtkStringVisitor::visit(BinaryRelationExpression const& expression) {
switch (expression.getRelationType()) {
case BinaryRelationExpression::RelationType::Equal:
expression.getFirstOperand()->accept(*this);
stream << "==";
expression.getSecondOperand()->accept(*this);
break;
case BinaryRelationExpression::RelationType::NotEqual:
expression.getFirstOperand()->accept(*this);
stream << "!=";
expression.getSecondOperand()->accept(*this);
break;
case BinaryRelationExpression::RelationType::Less:
expression.getFirstOperand()->accept(*this);
stream << "<";
expression.getSecondOperand()->accept(*this);
break;
case BinaryRelationExpression::RelationType::LessOrEqual:
expression.getFirstOperand()->accept(*this);
stream << "<=";
expression.getSecondOperand()->accept(*this);
break;
case BinaryRelationExpression::RelationType::Greater:
expression.getFirstOperand()->accept(*this);
stream << ">";
expression.getSecondOperand()->accept(*this);
break;
case BinaryRelationExpression::RelationType::GreaterOrEqual:
expression.getFirstOperand()->accept(*this);
stream << ">=";
expression.getSecondOperand()->accept(*this);
break;
}
return boost::any();
}
boost::any ToExprtkStringVisitor::visit(VariableExpression const& expression) {
stream << expression.getVariableName();
return boost::any();
}
boost::any ToExprtkStringVisitor::visit(UnaryBooleanFunctionExpression const& expression) {
switch (expression.getOperatorType()) {
case UnaryBooleanFunctionExpression::OperatorType::Not:
stream << "not(";
expression.getOperand()->accept(*this);
stream << ")";
}
return boost::any();
}
boost::any ToExprtkStringVisitor::visit(UnaryNumericalFunctionExpression const& expression) {
switch (expression.getOperatorType()) {
case UnaryNumericalFunctionExpression::OperatorType::Minus:
stream << "-(";
expression.getOperand()->accept(*this);
stream << ")";
break;
case UnaryNumericalFunctionExpression::OperatorType::Floor:
stream << "floor(";
expression.getOperand()->accept(*this);
stream << ")";
break;
case UnaryNumericalFunctionExpression::OperatorType::Ceil:
stream << "ceil(";
expression.getOperand()->accept(*this);
stream << ")";
break;
}
return boost::any();
}
boost::any ToExprtkStringVisitor::visit(BooleanLiteralExpression const& expression) {
stream << expression.getValue();
return boost::any();
}
boost::any ToExprtkStringVisitor::visit(IntegerLiteralExpression const& expression) {
stream << expression.getValue();
return boost::any();
}
boost::any ToExprtkStringVisitor::visit(DoubleLiteralExpression const& expression) {
stream << expression.getValue();
return boost::any();
}
}
}

37
src/storage/expressions/ToExprtkStringVisitor.h

@ -0,0 +1,37 @@
#ifndef STORM_STORAGE_EXPRESSIONS_TOEXPRTKSTRINGVISITOR_H_
#define STORM_STORAGE_EXPRESSIONS_TOEXPRTKSTRINGVISITOR_H_
#include <sstream>
#include "src/storage/expressions/Expression.h"
#include "src/storage/expressions/Expressions.h"
#include "src/storage/expressions/ExpressionVisitor.h"
namespace storm {
namespace expressions {
class ToExprtkStringVisitor : public ExpressionVisitor {
public:
ToExprtkStringVisitor() = default;
std::string toString(Expression const& expression);
std::string toString(BaseExpression const* expression);
virtual boost::any visit(IfThenElseExpression const& expression) override;
virtual boost::any visit(BinaryBooleanFunctionExpression const& expression) override;
virtual boost::any visit(BinaryNumericalFunctionExpression const& expression) override;
virtual boost::any visit(BinaryRelationExpression const& expression) override;
virtual boost::any visit(VariableExpression const& expression) override;
virtual boost::any visit(UnaryBooleanFunctionExpression const& expression) override;
virtual boost::any visit(UnaryNumericalFunctionExpression const& expression) override;
virtual boost::any visit(BooleanLiteralExpression const& expression) override;
virtual boost::any visit(IntegerLiteralExpression const& expression) override;
virtual boost::any visit(DoubleLiteralExpression const& expression) override;
private:
std::stringstream stream;
};
}
}
#endif /* STORM_STORAGE_EXPRESSIONS_TOEXPRTKSTRINGVISITOR_H_ */

53
test/functional/storage/BitVectorHashMapTest.cpp

@ -0,0 +1,53 @@
#include "gtest/gtest.h"
#include <cstdint>
#include "src/storage/BitVector.h"
#include "src/storage/BitVectorHashMap.h"
TEST(BitVectorHashMapTest, FindOrAdd) {
storm::storage::BitVectorHashMap<uint64_t> map(64, 3);
storm::storage::BitVector first(64);
first.set(4);
first.set(47);
ASSERT_NO_THROW(map.findOrAdd(first, 1));
storm::storage::BitVector second(64);
second.set(8);
second.set(18);
ASSERT_NO_THROW(map.findOrAdd(second, 2));
EXPECT_EQ(1, map.findOrAdd(first, 3));
storm::storage::BitVector third(64);
third.set(10);
third.set(63);
ASSERT_NO_THROW(map.findOrAdd(third, 3));
storm::storage::BitVector fourth(64);
fourth.set(12);
fourth.set(14);
ASSERT_NO_THROW(map.findOrAdd(fourth, 4));
storm::storage::BitVector fifth(64);
fifth.set(44);
fifth.set(55);
ASSERT_NO_THROW(map.findOrAdd(fifth, 5));
storm::storage::BitVector sixth(64);
sixth.set(45);
sixth.set(55);
ASSERT_NO_THROW(map.findOrAdd(sixth, 6));
EXPECT_EQ(1, map.findOrAdd(first, 0));
EXPECT_EQ(2, map.findOrAdd(second, 0));
EXPECT_EQ(3, map.findOrAdd(third, 0));
EXPECT_EQ(4, map.findOrAdd(fourth, 0));
EXPECT_EQ(5, map.findOrAdd(fifth, 0));
EXPECT_EQ(6, map.findOrAdd(sixth, 0));
}

62
test/functional/storage/ExpressionEvalutionTest.cpp

@ -0,0 +1,62 @@
#include "gtest/gtest.h"
#include "src/storage/expressions/Expression.h"
#include "src/storage/expressions/ExpressionManager.h"
#include "src/storage/expressions/SimpleValuation.h"
#include "src/storage/expressions/ExpressionEvaluator.h"
TEST(ExpressionEvaluation, NaiveEvaluation) {
std::shared_ptr<storm::expressions::ExpressionManager> manager(new storm::expressions::ExpressionManager());
storm::expressions::Variable x;
storm::expressions::Variable y;
storm::expressions::Variable z;
ASSERT_NO_THROW(x = manager->declareBooleanVariable("x"));
ASSERT_NO_THROW(y = manager->declareIntegerVariable("y"));
ASSERT_NO_THROW(z = manager->declareRationalVariable("z"));
storm::expressions::SimpleValuation eval(manager);
storm::expressions::Expression iteExpression = storm::expressions::ite(x, y + z, manager->integer(3) * z);
eval.setRationalValue(z, 5.5);
eval.setBooleanValue(x, true);
for (int_fast64_t i = 0; i < 1000; ++i) {
eval.setIntegerValue(y, 3 + i);
EXPECT_NEAR(8.5 + i, iteExpression.evaluateAsDouble(&eval), 1e-6);
}
eval.setBooleanValue(x, false);
for (int_fast64_t i = 0; i < 1000; ++i) {
double zValue = i / static_cast<double>(10);
eval.setRationalValue(z, zValue);
EXPECT_NEAR(3 * zValue, iteExpression.evaluateAsDouble(&eval), 1e-6);
}
}
TEST(ExpressionEvaluation, ExprTkEvaluation) {
std::shared_ptr<storm::expressions::ExpressionManager> manager(new storm::expressions::ExpressionManager());
storm::expressions::Variable x;
storm::expressions::Variable y;
storm::expressions::Variable z;
ASSERT_NO_THROW(x = manager->declareBooleanVariable("x"));
ASSERT_NO_THROW(y = manager->declareIntegerVariable("y"));
ASSERT_NO_THROW(z = manager->declareRationalVariable("z"));
storm::expressions::Expression iteExpression = storm::expressions::ite(x, y + z, manager->integer(3) * z);
storm::expressions::ExpressionEvaluator eval(*manager);
eval.setRationalValue(z, 5.5);
eval.setBooleanValue(x, true);
for (int_fast64_t i = 0; i < 1000; ++i) {
eval.setIntegerValue(y, 3 + i);
EXPECT_NEAR(8.5 + i, eval.asDouble(iteExpression), 1e-6);
}
eval.setBooleanValue(x, false);
for (int_fast64_t i = 0; i < 1000; ++i) {
double zValue = i / static_cast<double>(10);
eval.setRationalValue(z, zValue);
EXPECT_NEAR(3 * zValue, eval.asDouble(iteExpression), 1e-6);
}
}
Loading…
Cancel
Save