diff --git a/src/storm/storage/expressions/ExpressionManager.cpp b/src/storm/storage/expressions/ExpressionManager.cpp index d1b3d15cd..5cc3ffcd5 100644 --- a/src/storm/storage/expressions/ExpressionManager.cpp +++ b/src/storm/storage/expressions/ExpressionManager.cpp @@ -52,7 +52,7 @@ namespace storm { } } - ExpressionManager::ExpressionManager() : nameToIndexMapping(), indexToNameMapping(), indexToTypeMapping(), numberOfBooleanVariables(0), numberOfIntegerVariables(0), numberOfBitVectorVariables(0), numberOfRationalVariables(0), numberOfAuxiliaryVariables(0), numberOfAuxiliaryBooleanVariables(0), numberOfAuxiliaryIntegerVariables(0), numberOfAuxiliaryBitVectorVariables(0), numberOfAuxiliaryRationalVariables(0), freshVariableCounter(0) { + ExpressionManager::ExpressionManager() : nameToIndexMapping(), indexToNameMapping(), indexToTypeMapping(), numberOfBooleanVariables(0), numberOfIntegerVariables(0), numberOfBitVectorVariables(0), numberOfRationalVariables(0), numberOfArrayVariables(0), numberOfAuxiliaryVariables(0), numberOfAuxiliaryBooleanVariables(0), numberOfAuxiliaryIntegerVariables(0), numberOfAuxiliaryBitVectorVariables(0), numberOfAuxiliaryRationalVariables(0), numberOfAuxiliaryArrayVariables(0), freshVariableCounter(0) { // Intentionally left empty. } @@ -115,6 +115,11 @@ namespace storm { return rationalType.get(); } + Type const& ExpressionManager::getArrayType(Type elementType) const { + Type type(this->getSharedPointer(), std::shared_ptr(new ArrayType(elementType))); + return *arrayTypes.insert(type).first; + } + bool ExpressionManager::isValidVariableName(std::string const& name) { return name.size() < 2 || name.at(0) != '_' || name.at(1) != '_'; } @@ -149,6 +154,10 @@ namespace storm { Variable ExpressionManager::declareRationalVariable(std::string const& name, bool auxiliary) { return this->declareVariable(name, this->getRationalType(), auxiliary); } + + Variable ExpressionManager::declareArrayVariable(std::string const& name, Type const& elementType, bool auxiliary) { + return this->declareVariable(name, this->getArrayType(elementType), auxiliary); + } Variable ExpressionManager::declareOrGetVariable(std::string const& name, storm::expressions::Type const& variableType, bool auxiliary) { return declareOrGetVariable(name, variableType, auxiliary, true); @@ -168,8 +177,12 @@ namespace storm { offset = numberOfIntegerVariables++ + numberOfBitVectorVariables; } else if (variableType.isBitVectorType()) { offset = numberOfBitVectorVariables++ + numberOfIntegerVariables; - } else { + } else if (variableType.isRationalType()) { offset = numberOfRationalVariables++; + } else if (variableType.isArrayType()) { + offset = numberOfArrayVariables++; + } else { + STORM_LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Trying to declare a variable of unsupported type: '" << variableType.getStringRepresentation() << "'."); } } else { if (variableType.isBooleanType()) { @@ -178,8 +191,12 @@ namespace storm { offset = numberOfIntegerVariables++ + numberOfBitVectorVariables; } else if (variableType.isBitVectorType()) { offset = numberOfBitVectorVariables++ + numberOfIntegerVariables; - } else { + } else if (variableType.isRationalType()) { offset = numberOfRationalVariables++; + } else if (variableType.isArrayType()) { + offset = numberOfArrayVariables++; + } else { + STORM_LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Trying to declare a variable of unsupported type: '" << variableType.getStringRepresentation() << "'."); } } @@ -240,12 +257,14 @@ namespace storm { return numberOfBitVectorVariables; } else if (variableType.isRationalType()) { return numberOfRationalVariables; + } else if (variableType.isArrayType()) { + return numberOfArrayVariables; } return 0; } uint_fast64_t ExpressionManager::getNumberOfVariables() const { - return numberOfBooleanVariables + numberOfIntegerVariables + numberOfBitVectorVariables + numberOfRationalVariables; + return numberOfBooleanVariables + numberOfIntegerVariables + numberOfBitVectorVariables + numberOfRationalVariables + numberOfArrayVariables; } uint_fast64_t ExpressionManager::getNumberOfBooleanVariables() const { @@ -264,6 +283,10 @@ namespace storm { return numberOfRationalVariables; } + uint_fast64_t ExpressionManager::getNumberOfArrayVariables() const { + return numberOfRationalVariables; + } + std::string const& ExpressionManager::getVariableName(uint_fast64_t index) const { auto indexTypeNamePair = indexToNameMapping.find(index); STORM_LOG_THROW(indexTypeNamePair != indexToNameMapping.end(), storm::exceptions::InvalidArgumentException, "Unknown variable index '" << index << "'."); @@ -300,7 +323,7 @@ namespace storm { out << "manager {" << std::endl; for (auto const& variableTypePair : manager) { - std::cout << "\t" << variableTypePair.second << " " << variableTypePair.first.getName() << " [offset " << variableTypePair.first.getOffset() << "]" << std::endl; + out << "\t" << variableTypePair.second << " " << variableTypePair.first.getName() << " [offset " << variableTypePair.first.getOffset() << "]" << std::endl; } out << "}" << std::endl; diff --git a/src/storm/storage/expressions/ExpressionManager.h b/src/storm/storage/expressions/ExpressionManager.h index 3a80261b3..cf56d5670 100644 --- a/src/storm/storage/expressions/ExpressionManager.h +++ b/src/storm/storage/expressions/ExpressionManager.h @@ -149,6 +149,11 @@ namespace storm { * @return The rational type. */ Type const& getRationalType() const; + + /*! + * Retrieves the array type with the given element type + */ + Type const& getArrayType(Type elementType) const; /*! * Declares a variable that is a copy of the provided variable (i.e. has the same type). @@ -211,6 +216,11 @@ namespace storm { */ Variable declareRationalVariable(std::string const& name, bool auxiliary = false); + /*! + * Declares a new array variable with the given name and the given element type. + */ + Variable declareArrayVariable(std::string const& name, Type const& elementType, bool auxiliary = false); + /*! * Declares a variable with the given name if it does not yet exist. * @@ -321,6 +331,11 @@ namespace storm { */ uint_fast64_t getNumberOfRationalVariables() const; + /*! + * Retrieves the number of array variables. + */ + uint_fast64_t getNumberOfArrayVariables() const; + /*! * Retrieves the name of the variable with the given index. * @@ -443,6 +458,7 @@ namespace storm { uint_fast64_t numberOfIntegerVariables; uint_fast64_t numberOfBitVectorVariables; uint_fast64_t numberOfRationalVariables; + uint_fast64_t numberOfArrayVariables; // The number of declared auxiliary variables. uint_fast64_t numberOfAuxiliaryVariables; @@ -452,6 +468,7 @@ namespace storm { uint_fast64_t numberOfAuxiliaryIntegerVariables; uint_fast64_t numberOfAuxiliaryBitVectorVariables; uint_fast64_t numberOfAuxiliaryRationalVariables; + uint_fast64_t numberOfAuxiliaryArrayVariables; // A counter used to create fresh variables. uint_fast64_t freshVariableCounter; @@ -461,6 +478,7 @@ namespace storm { mutable boost::optional integerType; mutable std::unordered_set bitvectorTypes; mutable boost::optional rationalType; + mutable std::unordered_set arrayTypes; // A mask that can be used to query whether a variable is an auxiliary variable. static const uint64_t auxiliaryMask = (1ull << 50); diff --git a/src/storm/storage/expressions/Type.cpp b/src/storm/storage/expressions/Type.cpp index b51ac812a..b585b8eae 100644 --- a/src/storm/storage/expressions/Type.cpp +++ b/src/storm/storage/expressions/Type.cpp @@ -57,6 +57,14 @@ namespace storm { bool RationalType::isRationalType() const { return true; } + + bool BaseType::isArrayType() const { + return false; + } + + bool ArrayType::isArrayType() const { + return true; + } uint64_t BooleanType::getMask() const { return BooleanType::mask; @@ -102,6 +110,26 @@ namespace storm { return "rational"; } + ArrayType::ArrayType(Type elementType) : elementType(elementType) { + // Intentionally left empty + } + + Type ArrayType::getElementType() const { + return elementType; + } + + bool ArrayType::operator==(BaseType const& other) const { + return BaseType::operator==(other) && this->elementType == static_cast(other).getElementType(); + } + + uint64_t ArrayType::getMask() const { + return ArrayType::mask; + } + + std::string ArrayType::getStringRepresentation() const { + return "array[" + elementType.getStringRepresentation() + "]"; + } + bool operator<(BaseType const& first, BaseType const& second) { if (first.getMask() < second.getMask()) { return true; @@ -109,8 +137,11 @@ namespace storm { if (first.isBitVectorType() && second.isBitVectorType()) { return static_cast(first).getWidth() < static_cast(second).getWidth(); } + if (first.isArrayType() && second.isArrayType()) { + return static_cast(first).getElementType() < static_cast(second).getElementType(); + } return false; - } + } Type::Type() : manager(nullptr), innerType(nullptr) { // Intentionally left empty. @@ -144,6 +175,10 @@ namespace storm { return this->isIntegerType() || this->isRationalType(); } + bool Type::isArrayType() const { + return this->innerType->isArrayType(); + } + std::string Type::getStringRepresentation() const { return this->innerType->getStringRepresentation(); } @@ -151,6 +186,10 @@ namespace storm { std::size_t Type::getWidth() const { return static_cast(*this->innerType).getWidth(); } + + Type Type::getElementType() const { + return static_cast(*this->innerType).getElementType(); + } bool Type::isRationalType() const { return this->innerType->isRationalType(); diff --git a/src/storm/storage/expressions/Type.h b/src/storm/storage/expressions/Type.h index b173d5d53..ed5e60679 100644 --- a/src/storm/storage/expressions/Type.h +++ b/src/storm/storage/expressions/Type.h @@ -9,113 +9,8 @@ namespace storm { namespace expressions { - // Forward-declare expression manager class. class ExpressionManager; - - class BaseType { - public: - BaseType(); - virtual ~BaseType() = default; - - /*! - * Retrieves the mask that is associated with this type. - * - * @return The mask associated with this type. - */ - virtual uint64_t getMask() const = 0; - - /*! - * Checks whether two types are actually the same. - * - * @param other The type to compare with. - * @return True iff the types are the same. - */ - virtual bool operator==(BaseType const& other) const; - - /*! - * Returns a string representation of the type. - * - * @return A string representation of the type. - */ - virtual std::string getStringRepresentation() const = 0; - - virtual bool isErrorType() const; - virtual bool isBooleanType() const; - virtual bool isIntegerType() const; - virtual bool isBitVectorType() const; - virtual bool isRationalType() const; - }; - - class BooleanType : public BaseType { - public: - virtual uint64_t getMask() const override; - virtual std::string getStringRepresentation() const override; - virtual bool isBooleanType() const override; - - private: - static const uint64_t mask = (1ull << 60); - }; - - class IntegerType : public BaseType { - public: - virtual uint64_t getMask() const override; - virtual std::string getStringRepresentation() const override; - virtual bool isIntegerType() const override; - - private: - static const uint64_t mask = (1ull << 62); - }; - - class BitVectorType : public BaseType { - public: - /*! - * Creates a new bounded bitvector type with the given bit width. - * - * @param width The bit width of the type. - */ - BitVectorType(std::size_t width); - - /*! - * Retrieves the bit width of the bounded type. - * - * @return The bit width of the bounded type. - */ - std::size_t getWidth() const; - - virtual bool operator==(BaseType const& other) const override; - virtual uint64_t getMask() const override; - virtual std::string getStringRepresentation() const override; - virtual bool isIntegerType() const override; - virtual bool isBitVectorType() const override; - - private: - static const uint64_t mask = (1ull << 61); - - // The bit width of the type. - std::size_t width; - }; - - class RationalType : public BaseType { - public: - virtual uint64_t getMask() const override; - virtual std::string getStringRepresentation() const override; - virtual bool isRationalType() const override; - - private: - static const uint64_t mask = (1ull << 63); - }; - - class ErrorType : public BaseType { - public: - virtual uint64_t getMask() const override; - virtual std::string getStringRepresentation() const override; - virtual bool isErrorType() const override; - - private: - static const uint64_t mask = 0; - }; - - bool operator<(BaseType const& first, BaseType const& second); + class BaseType; class Type { public: @@ -187,6 +82,13 @@ namespace storm { * @return True iff the type is a numerical one. */ bool isNumericalType() const; + + /*! + * Checks whether this type is an array type. + * + * @return True iff the type is an array. + */ + bool isArrayType() const; /*! * Retrieves the bit width of the type, provided that it is a bitvector type. @@ -195,6 +97,13 @@ namespace storm { */ std::size_t getWidth() const; + /*! + * Retrieves the element type of the type, provided that it is an Array type. + * + * @return The bit width of the bitvector type. + */ + Type getElementType() const; + /*! * Retrieves the manager of the type. * @@ -225,6 +134,130 @@ namespace storm { std::ostream& operator<<(std::ostream& stream, Type const& type); bool operator<(storm::expressions::Type const& type1, storm::expressions::Type const& type2); + + class BaseType { + public: + BaseType(); + virtual ~BaseType() = default; + + /*! + * Retrieves the mask that is associated with this type. + * + * @return The mask associated with this type. + */ + virtual uint64_t getMask() const = 0; + + /*! + * Checks whether two types are actually the same. + * + * @param other The type to compare with. + * @return True iff the types are the same. + */ + virtual bool operator==(BaseType const& other) const; + + /*! + * Returns a string representation of the type. + * + * @return A string representation of the type. + */ + virtual std::string getStringRepresentation() const = 0; + + virtual bool isErrorType() const; + virtual bool isBooleanType() const; + virtual bool isIntegerType() const; + virtual bool isBitVectorType() const; + virtual bool isRationalType() const; + virtual bool isArrayType() const; + }; + + class BooleanType : public BaseType { + public: + virtual uint64_t getMask() const override; + virtual std::string getStringRepresentation() const override; + virtual bool isBooleanType() const override; + + private: + static const uint64_t mask = (1ull << 60); + }; + + class IntegerType : public BaseType { + public: + virtual uint64_t getMask() const override; + virtual std::string getStringRepresentation() const override; + virtual bool isIntegerType() const override; + + private: + static const uint64_t mask = (1ull << 62); + }; + + class BitVectorType : public BaseType { + public: + /*! + * Creates a new bounded bitvector type with the given bit width. + * + * @param width The bit width of the type. + */ + BitVectorType(std::size_t width); + + /*! + * Retrieves the bit width of the bounded type. + * + * @return The bit width of the bounded type. + */ + std::size_t getWidth() const; + + virtual bool operator==(BaseType const& other) const override; + virtual uint64_t getMask() const override; + virtual std::string getStringRepresentation() const override; + virtual bool isIntegerType() const override; + virtual bool isBitVectorType() const override; + + private: + static const uint64_t mask = (1ull << 61); + + // The bit width of the type. + std::size_t width; + }; + + class RationalType : public BaseType { + public: + virtual uint64_t getMask() const override; + virtual std::string getStringRepresentation() const override; + virtual bool isRationalType() const override; + + private: + static const uint64_t mask = (1ull << 63); + }; + + class ArrayType : public BaseType { + public: + ArrayType(Type elementType); + + Type getElementType() const; + + virtual bool operator==(BaseType const& other) const override; + virtual uint64_t getMask() const override; + virtual std::string getStringRepresentation() const override; + virtual bool isArrayType() const override; + + private: + static const uint64_t mask = (1ull << 60); + + // The type of the array elements (can again be of type array). + Type elementType; + }; + + class ErrorType : public BaseType { + public: + virtual uint64_t getMask() const override; + virtual std::string getStringRepresentation() const override; + virtual bool isErrorType() const override; + + private: + static const uint64_t mask = 0; + }; + + bool operator<(BaseType const& first, BaseType const& second); } } diff --git a/src/storm/storage/jani/expressions/ArrayAccessExpression.cpp b/src/storm/storage/jani/expressions/ArrayAccessExpression.cpp new file mode 100644 index 000000000..d143fc10c --- /dev/null +++ b/src/storm/storage/jani/expressions/ArrayAccessExpression.cpp @@ -0,0 +1,36 @@ +#include "storm/storage/jani/expressions/ArrayAccessExpression.h" + +namespace storm { + namespace expressions { + + ArrayAccessExpression::ArrayAccessExpression(ExpressionManager const& manager, Type const& type, std::shared_ptr const& arrayExpression, std::shared_ptr const& indexExpression) : BinaryExpression(manager, type, arrayExpression, indexExpression) { + // Assert correct types + STORM_LOG_ASSERT(getFirstOperand()->getType().isArrayType(), "ArrayAccessExpression for an expression of type " << getFirstOperand()->getType() << "."); + STORM_LOG_ASSERT(type == getFirstOperand()->getType().getElementType(), "The ArrayAccessExpression should have type " << getFirstOperand()->getType().getElementType() << " but has " << type << " instead."); + STORM_LOG_ASSERT(getSecondOperand()->getType().isIntegerType(), "The index expression does not have an integer type."); + } + + std::shared_ptr ArrayAccessExpression::simplify() const { + return std::shared_ptr(new ArrayAccessExpression(manager, type, getFirstOperand()->simplify(), getSecondOperand()->simplify())); + } + + boost::any ArrayAccessExpression::accept(ExpressionVisitor& visitor, boost::any const& data) const { + auto janiVisitor = dynamic_cast(&visitor); + STORM_LOG_THROW(janiVisitor != nullptr, storm::exceptions::UnexpectedException, "Visitor of jani expression should be of type JaniVisitor."); + return janiVisitor->visit(*this, data); + } + + void ArrayAccessExpression::printToStream(std::ostream& stream) const { + if (firstOperand->isVariable()) { + getFirstOperand()->printToStream(stream); + } else { + stream << "("; + getFirstOperand()->printToStream(stream); + stream << ")"; + } + stream << "["; + getSecondOperand()->printToStream(stream); + stream << "]"; + } + } +} \ No newline at end of file diff --git a/src/storm/storage/jani/expressions/ArrayAccessExpression.h b/src/storm/storage/jani/expressions/ArrayAccessExpression.h new file mode 100644 index 000000000..0068b1b47 --- /dev/null +++ b/src/storm/storage/jani/expressions/ArrayAccessExpression.h @@ -0,0 +1,33 @@ +#pragma once + +#include "storm/storage/expressions/BaseExpression.h" +#include "storm/storage/jani/expressions/ExpressionManager.h" + +namespace storm { + namespace expressions { + /*! + * Represents an access to an array. + */ + class ArrayAccessExpression : public BinaryExpression { + public: + + ArrayAccessExpression(ExpressionManager const& manager, Type const& type, std::shared_ptr const& arrayExpression, std::shared_ptr const& indexExpression); + + // Instantiate constructors and assignments with their default implementations. + ArrayAccessExpression(ArrayAccessExpression const& other) = default; + ArrayAccessExpression& operator=(ArrayAccessExpression const& other) = delete; + ArrayAccessExpression(ArrayAccessExpression&&) = default; + ArrayAccessExpression& operator=(ArrayAccessExpression&&) = delete; + + virtual ~ArrayAccessExpression() = default; + + virtual std::shared_ptr simplify() const override; + virtual boost::any accept(ExpressionVisitor& visitor, boost::any const& data) const override; + + protected: + virtual void printToStream(std::ostream& stream) const override; + + + }; + } +} diff --git a/src/storm/storage/jani/expressions/ArrayExpression.cpp b/src/storm/storage/jani/expressions/ArrayExpression.cpp new file mode 100644 index 000000000..1f6aa0e49 --- /dev/null +++ b/src/storm/storage/jani/expressions/ArrayExpression.cpp @@ -0,0 +1,11 @@ +#include "storm/storage/jani/expressions/ArrayExpression.h" + +namespace storm { + namespace expressions { + + ArrayExpression::ArrayExpression(ExpressionManager const& manager, Type const& type) : BaseExpression(manager, type) { + // Intentionally left empty + } + + } +} \ No newline at end of file diff --git a/src/storm/storage/jani/expressions/ArrayExpression.h b/src/storm/storage/jani/expressions/ArrayExpression.h new file mode 100644 index 000000000..63b1bae8e --- /dev/null +++ b/src/storm/storage/jani/expressions/ArrayExpression.h @@ -0,0 +1,33 @@ +#pragma once + +#include "storm/storage/expressions/BaseExpression.h" +#include "storm/storage/jani/expressions/ExpressionManager.h" + +namespace storm { + namespace expressions { + /*! + * The base class of all array expressions. + */ + class ArrayExpression : public BaseExpression { + public: + + ArrayExpression(ExpressionManager const& manager, Type const& type); + + // Instantiate constructors and assignments with their default implementations. + ArrayExpression(ArrayExpression const& other) = default; + ArrayExpression& operator=(ArrayExpression const& other) = delete; + ArrayExpression(ArrayExpression&&) = default; + ArrayExpression& operator=(ArrayExpression&&) = delete; + + virtual ~ArrayExpression() = default; + + // Returns the size of the array + virtual std::shared_ptr size() const = 0; + + // Returns the element at position i + virtual std::shared_ptr at(uint64_t i) const = 0; + + + }; + } +} diff --git a/src/storm/storage/jani/expressions/ConstructorArrayExpression.cpp b/src/storm/storage/jani/expressions/ConstructorArrayExpression.cpp new file mode 100644 index 000000000..d643c527e --- /dev/null +++ b/src/storm/storage/jani/expressions/ConstructorArrayExpression.cpp @@ -0,0 +1,72 @@ +#include "storm/storage/jani/expressions/ConstructorArrayExpression.h" + +#include "storm/storage/jani/expressions/JaniExpressionVisitor.h" + +#include "storm/exceptions/InvalidArgumentException.h" +#include "storm/exceptions/UnexpectedException.h" + +namespace storm { + namespace expressions { + + ConstructorArrayExpression::ConstructorArrayExpression(ExpressionManager const& manager, Type const& type, std::shared_ptr const& size, storm::expressions::Variable indexVar, std::shared_ptr const& elementExpression) : ArrayExpression(manager, type), size(size), indexVar(indexVar), elementExpression(elementExpression) { + // Intentionally left empty + } + + void ConstructorArrayExpression::gatherVariables(std::set& variables) const { + // The indexVar should not be gathered (unless it is already contained). + bool indexVarContained = variables.find(indexVar) != variables.end(); + size->gatherVariables(variables); + elementExpression->gatherVariables(variables); + if (!indexVarContained) { + variables.erase(indexVar); + } + } + + bool ConstructorArrayExpression::containsVariables() const { + if (size->containsVariables()) { + return true; + } + // The index variable should not count + std::set variables; + elementExpression->gatherVariables(variables); + variables.erase(indexVar); + return !variables.empty(); + } + + std::shared_ptr ConstructorArrayExpression::simplify() const { + return std::shared_ptr(new ConstructorArrayExpression(manager, type, size->simplify(), indexVar, elementExpression->simplify())); + } + + boost::any ConstructorArrayExpression::accept(ExpressionVisitor& visitor, boost::any const& data) const { + auto janiVisitor = dynamic_cast(&visitor); + STORM_LOG_THROW(janiVisitor != nullptr, storm::exceptions::UnexpectedException, "Visitor of jani expression should be of type JaniVisitor."); + return janiVisitor->visit(*this, data); + } + + void ConstructorArrayExpression::printToStream(std::ostream& stream) const { + stream << "array[ "; + elementExpression->printToStream(stream); + stream << " | " << indexVar << "<"; + size->printToStream(stream); + stream << " ]"; + } + + std::shared_ptr ConstructorArrayExpression::size() const { + return size; + } + + std::shared_ptr ConstructorArrayExpression::at(uint64_t i) const { + STORM_LOG_THROW(i < elements.size(), storm::exceptions::InvalidArgumentException, "Tried to access the element with index " << i << " of an array of size " << elements.size() << "."); + return elements[i]; + } + + std::shared_ptr const& ConstructorArrayExpression::getElementExpression() const { + return elementExpression; + } + + storm::expressions::Variable const& ConstructorArrayExpression::getIndexVar() const { + return indexVar; + } + + } +} \ No newline at end of file diff --git a/src/storm/storage/jani/expressions/ConstructorArrayExpression.h b/src/storm/storage/jani/expressions/ConstructorArrayExpression.h new file mode 100644 index 000000000..07d52b570 --- /dev/null +++ b/src/storm/storage/jani/expressions/ConstructorArrayExpression.h @@ -0,0 +1,46 @@ +#pragma once + +#include "storm/storage/jani/expressions/ArrayExpression.h" + +namespace storm { + namespace expressions { + /*! + * Represents an array of the given size, where the i'th entry is determined by the elementExpression, where occurrences of indexVar will be substituted by i + */ + class ConstructorArrayExpression : public ArrayExpression { + public: + + ConstructorArrayExpression(ExpressionManager const& manager, Type const& type, std::shared_ptr const& size, storm::expressions::Variable indexVar, std::shared_ptr const& elementExpression); + + + // Instantiate constructors and assignments with their default implementations. + ConstructorArrayExpression(ConstructorArrayExpression const& other) = default; + ConstructorArrayExpression& operator=(ConstructorArrayExpression const& other) = delete; + ConstructorArrayExpression(ConstructorArrayExpression&&) = default; + ConstructorArrayExpression& operator=(ConstructorArrayExpression&&) = delete; + + virtual ~ConstructorArrayExpression() = default; + + virtual void gatherVariables(std::set& variables) const override; + virtual bool containsVariables() const; + virtual std::shared_ptr simplify() const override; + virtual boost::any accept(ExpressionVisitor& visitor, boost::any const& data) const override; + + // Returns the size of the array + virtual std::shared_ptr size() const override; + + // Returns the element at position i + virtual std::shared_ptr at(uint64_t i) const override; + + std::shared_ptr const& getElementExpression() const; + storm::expressions::Variable const& getIndexVar() const; + protected: + virtual void printToStream(std::ostream& stream) const override; + + private: + std::shared_ptr size; + storm::expressions::Variable indexVar; + std::shared_ptr const& elementExpression; + }; + } +} \ No newline at end of file diff --git a/src/storm/storage/jani/expressions/JaniExpressionSubstitutionVisitor.cpp b/src/storm/storage/jani/expressions/JaniExpressionSubstitutionVisitor.cpp new file mode 100644 index 000000000..0a449d3d5 --- /dev/null +++ b/src/storm/storage/jani/expressions/JaniExpressionSubstitutionVisitor.cpp @@ -0,0 +1,51 @@ +#include "storm/storage/jani/expressions/JaniExpressionSubstitutionVisitor.h" + +#include "storm/exceptions/InvalidArgumentException.h" + +namespace storm { + namespace expressions { + template + boost::any JaniExpressionSubstitutionVisitor::visit(ValueArrayExpression const& expression, boost::any const& data) { + uint64_t size = expression.getSize()->evaluateAsInt(); + std::vector> newElements; + newElements.reserve(size); + for (uint64_t i = 0; i < size; ++i) { + newElements.push_back(boost::any_cast>(expression.at(i)->accept(*this, data))); + } + return std::const_pointer_cast(std::shared_ptr(new ValueArrayExpression(expression.getManager(), expression.getType(), newElements))); + } + + template + boost::any JaniExpressionSubstitutionVisitor::visit(ConstructorArrayExpression const& expression, boost::any const& data) { + std::shared_ptr newSize = boost::any_cast>(expression.getSize()->accept(*this, data)); + std::shared_ptr elementExpression = boost::any_cast>(expression.getElementExpression()->accept(*this, data)); + STORM_LOG_THROW(this->variableToExpressionMapping.find(expression.getIndexVar()) == this->variableToExpressionMapping.end(), storm::exceptions::InvalidArgumentException, "substitution of the index variable of a constructorArrayExpression is not possible."); + + // If the arguments did not change, we simply push the expression itself. + if (newSize.get() == expression.getSize().get() && elementExpression.get() == expression.getElementExpression().get()) { + return expression.getSharedPointer(); + } else { + return std::const_pointer_cast(std::shared_ptr(new ConstructorArrayExpression(expression.getManager(), expression.getType(), newSize, expression.getIndexVar(), elementExpression))); + } + } + + template + boost::any JaniExpressionSubstitutionVisitor::visit(ArrayAccessExpression const& expression, boost::any const& data) { + std::shared_ptr firstExpression = boost::any_cast>(expression.getFirstOperand()->accept(*this, data)); + std::shared_ptr secondExpression = boost::any_cast>(expression.getSecondOperand()->accept(*this, data)); + + // If the arguments did not change, we simply push the expression itself. + if (firstExpression.get() == expression.getFirstOperand().get() && secondExpression.get() == expression.getSecondOperand().get()) { + return expression.getSharedPointer(); + } else { + return std::const_pointer_cast(std::shared_ptr(new ArrayAccessExpression(expression.getManager(), expression.getType(), firstExpression, secondExpression))); + } + } + + + // Explicitly instantiate the class with map and unordered_map. + template class JaniExpressionSubstitutionVisitor>; + template class JaniExpressionSubstitutionVisitor>; + + } +} diff --git a/src/storm/storage/jani/expressions/JaniExpressionSubstitutionVisitor.h b/src/storm/storage/jani/expressions/JaniExpressionSubstitutionVisitor.h new file mode 100644 index 000000000..131c2f8ac --- /dev/null +++ b/src/storm/storage/jani/expressions/JaniExpressionSubstitutionVisitor.h @@ -0,0 +1,26 @@ +#pragma once + +#include "storm/storage/expressions/SubstitutionVisitor.h" +#include "storm/storage/jani/expressions/JaniExpressions.h" + + +namespace storm { + namespace expressions { + template + class JaniExpressionSubstitutionVisitor : public SubstitutionVisitor, public ExpressionManager { + public: + /*! + * Creates a new substitution visitor that uses the given map to replace variables. + * + * @param variableToExpressionMapping A mapping from variables to expressions. + */ + JaniExpressionSubstitutionVisitor(MapType const& variableToExpressionMapping); + + virtual boost::any visit(ValueArrayExpression const& expression, boost::any const& data) override; + virtual boost::any visit(ConstructorArrayExpression const& expression, boost::any const& data) override; + virtual boost::any visit(ArrayAccessExpression const& expression, boost::any const& data) override; + }; + } +} + +#endif /* STORM_STORAGE_EXPRESSIONS_SUBSTITUTIONVISITOR_H_ */ diff --git a/src/storm/storage/jani/expressions/JaniExpressionVisitor.h b/src/storm/storage/jani/expressions/JaniExpressionVisitor.h new file mode 100644 index 000000000..d9c86e9af --- /dev/null +++ b/src/storm/storage/jani/expressions/JaniExpressionVisitor.h @@ -0,0 +1,19 @@ +#pragma once + +#include "storm/storage/expressions/SubstitutionVisitor.h" +#include "storm/storage/jani/expressions/JaniExpressions.h" + + +namespace storm { + namespace expressions { + template + class JaniExpressionVisitor{ + public: + virtual boost::any visit(ValueArrayExpression const& expression, boost::any const& data) = 0; + virtual boost::any visit(ConstructorArrayExpression const& expression, boost::any const& data) = 0; + virtual boost::any visit(ArrayAccessExpression const& expression, boost::any const& data) = 0; + }; + } +} + +#endif /* STORM_STORAGE_EXPRESSIONS_SUBSTITUTIONVISITOR_H_ */ diff --git a/src/storm/storage/jani/expressions/JaniExpressions.h b/src/storm/storage/jani/expressions/JaniExpressions.h new file mode 100644 index 000000000..a48ec5119 --- /dev/null +++ b/src/storm/storage/jani/expressions/JaniExpressions.h @@ -0,0 +1,4 @@ +#include "storm/storage/expressions/Expressions.h" +#include "storm/storage/jani/expressions/ArrayAccessExpression.h" +#include "storm/storage/jani/expressions/ConstructorArrayExpression.h" +#include "storm/storage/jani/expressions/ValueArrayExpression.h" diff --git a/src/storm/storage/jani/expressions/ValueArrayExpression.cpp b/src/storm/storage/jani/expressions/ValueArrayExpression.cpp new file mode 100644 index 000000000..feaf5a395 --- /dev/null +++ b/src/storm/storage/jani/expressions/ValueArrayExpression.cpp @@ -0,0 +1,68 @@ +#include "storm/storage/jani/expressions/ValueArrayExpression.h" + +#include "storm/storage/jani/expressions/JaniExpressionVisitor.h" + +#include "storm/exceptions/InvalidArgumentException.h" +#include "storm/exceptions/UnexpectedException.h" + +namespace storm { + namespace expressions { + + ValueArrayExpression::ValueArrayExpression(ExpressionManager const& manager, Type const& type, std::vector> const& elements) : ArrayExpression(manager, type), elements(elements) { + // Intentionally left empty + } + + void ValueArrayExpression::gatherVariables(std::set& variables) const { + for (auto const& e : elements) { + e->gatherVariables(variables); + } + } + + bool ValueArrayExpression::containsVariables() const { + for (auto const& e : elements) { + if (e->containsVariables()) { + return true; + } + } + return false; + } + + std::shared_ptr ValueArrayExpression::simplify() const { + std::vector> simplifiedElements; + simplifiedElements.reserve(elements.size()); + for (auto const& e : elements) { + simplifiedElements.push_back(e->simplify()); + } + return std::shared_ptr(new ValueArrayExpression(manager, type, simplifiedElements)); + } + + boost::any ValueArrayExpression::accept(ExpressionVisitor& visitor, boost::any const& data) const { + auto janiVisitor = dynamic_cast(&visitor); + STORM_LOG_THROW(janiVisitor != nullptr, storm::exceptions::UnexpectedException, "Visitor of jani expression should be of type JaniVisitor."); + return janiVisitor->visit(*this, data); + } + + void ValueArrayExpression::printToStream(std::ostream& stream) const { + stream << "array[ "; + bool first = true; + for (auto const& e : elements) { + e->printToStream(stream); + if (!first) { + stream << " , "; + } + first = false; + } + stream << " ]"; + } + + std::shared_ptr ValueArrayExpression::size() const { + return this->manager.integer(elements.size()).getBaseExpressionPointer(); + } + + std::shared_ptr ValueArrayExpression::at(uint64_t i) const { + STORM_LOG_THROW(i < elements.size(), storm::exceptions::InvalidArgumentException, "Tried to access the element with index " << i << " of an array of size " << elements.size() << "."); + return elements[i]; + } + + } +} \ No newline at end of file diff --git a/src/storm/storage/jani/expressions/ValueArrayExpression.h b/src/storm/storage/jani/expressions/ValueArrayExpression.h new file mode 100644 index 000000000..641330387 --- /dev/null +++ b/src/storm/storage/jani/expressions/ValueArrayExpression.h @@ -0,0 +1,42 @@ +#pragma once + +#include "storm/storage/jani/expressions/ArrayExpression.h" + +namespace storm { + namespace expressions { + /*! + * Represents an array with a given list of elements. + */ + class ValueArrayExpression : public ArrayExpression { + public: + + ValueArrayExpression(ExpressionManager const& manager, Type const& type, std::vector> elements); + + + // Instantiate constructors and assignments with their default implementations. + ValueArrayExpression(ValueArrayExpression const& other) = default; + ValueArrayExpression& operator=(ValueArrayExpression const& other) = delete; + ValueArrayExpression(ValueArrayExpression&&) = default; + ValueArrayExpression& operator=(ValueArrayExpression&&) = delete; + + virtual ~ValueArrayExpression() = default; + + virtual void gatherVariables(std::set& variables) const override; + virtual bool containsVariables() const; + virtual std::shared_ptr simplify() const override; + virtual boost::any accept(ExpressionVisitor& visitor, boost::any const& data) const override; + + // Returns the size of the array + virtual std::shared_ptr size() const override; + + // Returns the element at position i + virtual std::shared_ptr at(uint64_t i) const override; + + protected: + virtual void printToStream(std::ostream& stream) const override; + + private: + std::vector> elements; + }; + } +} \ No newline at end of file