Browse Source

More and more refactoring.

Former-commit-id: b2f5b25c92
main
dehnert 11 years ago
parent
commit
92d550be12
  1. 86
      src/storage/expressions/Expression.cpp
  2. 36
      src/storage/expressions/Expression.h
  3. 34
      src/storage/expressions/ExpressionManager.cpp
  4. 16
      src/storage/expressions/ExpressionManager.h
  5. 137
      src/storage/expressions/LinearCoefficientVisitor.cpp
  6. 29
      src/storage/expressions/LinearCoefficientVisitor.h
  7. 31
      src/storage/expressions/Type.cpp
  8. 20
      src/storage/expressions/Type.h
  9. 20
      src/storage/expressions/Variable.cpp
  10. 46
      src/storage/expressions/Variable.h
  11. 18
      src/storage/expressions/VariableExpression.cpp
  12. 2
      src/storage/prism/Assignment.cpp
  13. 5
      src/storage/prism/Assignment.h
  14. 10
      src/storage/prism/BooleanVariable.cpp
  15. 17
      src/storage/prism/BooleanVariable.h
  16. 2
      src/storage/prism/Command.cpp
  17. 2
      src/storage/prism/Command.h
  18. 26
      src/storage/prism/Constant.cpp
  19. 33
      src/storage/prism/Constant.h
  20. 6
      src/storage/prism/Formula.cpp
  21. 7
      src/storage/prism/Formula.h
  22. 2
      src/storage/prism/InitialConstruct.cpp
  23. 3
      src/storage/prism/InitialConstruct.h
  24. 10
      src/storage/prism/IntegerVariable.cpp
  25. 19
      src/storage/prism/IntegerVariable.h
  26. 2
      src/storage/prism/Label.cpp
  27. 3
      src/storage/prism/Label.h
  28. 50
      src/storage/prism/Program.cpp
  29. 24
      src/storage/prism/Program.h
  30. 2
      src/storage/prism/StateReward.cpp
  31. 3
      src/storage/prism/StateReward.h
  32. 2
      src/storage/prism/TransitionReward.cpp
  33. 3
      src/storage/prism/TransitionReward.h
  34. 2
      src/storage/prism/Update.cpp
  35. 2
      src/storage/prism/Update.h
  36. 16
      src/storage/prism/Variable.cpp
  37. 30
      src/storage/prism/Variable.h

86
src/storage/expressions/Expression.cpp

@ -2,8 +2,8 @@
#include <unordered_map> #include <unordered_map>
#include "src/storage/expressions/Expression.h" #include "src/storage/expressions/Expression.h"
#include "src/storage/expressions/ExpressionManager.h"
#include "src/storage/expressions/SubstitutionVisitor.h" #include "src/storage/expressions/SubstitutionVisitor.h"
#include "src/storage/expressions/IdentifierSubstitutionVisitor.h"
#include "src/storage/expressions/LinearityCheckVisitor.h" #include "src/storage/expressions/LinearityCheckVisitor.h"
#include "src/storage/expressions/Expressions.h" #include "src/storage/expressions/Expressions.h"
#include "src/exceptions/InvalidTypeException.h" #include "src/exceptions/InvalidTypeException.h"
@ -15,24 +15,16 @@ namespace storm {
// Intentionally left empty. // Intentionally left empty.
} }
Expression::Expression(Variable const& variable) : expressionPtr(new VariableExpression(variable)) {
Expression::Expression(Variable const& variable) : expressionPtr(std::shared_ptr<BaseExpression>(new VariableExpression(variable))) {
// Intentionally left empty. // Intentionally left empty.
} }
Expression Expression::substitute(std::map<std::string, Expression> const& identifierToExpressionMap) const {
return SubstitutionVisitor<std::map<std::string, Expression>>(identifierToExpressionMap).substitute(*this);
Expression Expression::substitute(std::map<Variable, Expression> const& identifierToExpressionMap) const {
return SubstitutionVisitor<std::map<Variable, Expression>>(identifierToExpressionMap).substitute(*this);
} }
Expression Expression::substitute(std::unordered_map<std::string, Expression> const& identifierToExpressionMap) const {
return SubstitutionVisitor<std::unordered_map<std::string, Expression>>(identifierToExpressionMap).substitute(*this);
}
Expression Expression::substitute(std::map<std::string, std::string> const& identifierToIdentifierMap) const {
return IdentifierSubstitutionVisitor<std::map<std::string, std::string>>(identifierToIdentifierMap).substitute(*this);
}
Expression Expression::substitute(std::unordered_map<std::string, std::string> const& identifierToIdentifierMap) const {
return IdentifierSubstitutionVisitor<std::unordered_map<std::string, std::string>>(identifierToIdentifierMap).substitute(*this);
Expression Expression::substitute(std::unordered_map<Variable, Expression> const& identifierToExpressionMap) const {
return SubstitutionVisitor<std::unordered_map<Variable, Expression>>(identifierToExpressionMap).substitute(*this);
} }
bool Expression::evaluateAsBool(Valuation const* valuation) const { bool Expression::evaluateAsBool(Valuation const* valuation) const {
@ -117,7 +109,7 @@ namespace storm {
return this->expressionPtr; return this->expressionPtr;
} }
Type Expression::getType() const {
Type const& Expression::getType() const {
return this->getBaseExpression().getType(); return this->getBaseExpression().getType();
} }
@ -140,130 +132,110 @@ namespace storm {
Expression Expression::operator-(Expression const& other) const { Expression Expression::operator-(Expression const& other) const {
assertSameManager(this->getBaseExpression(), other.getBaseExpression()); assertSameManager(this->getBaseExpression(), other.getBaseExpression());
STORM_LOG_THROW(this->hasNumericalReturnType() && other.hasNumericalReturnType(), storm::exceptions::InvalidTypeException, "Operator '-' requires numerical operands.");
return Expression(std::shared_ptr<BaseExpression>(new BinaryNumericalFunctionExpression(this->getBaseExpression().getManager(), this->getReturnType() == ExpressionReturnType::Int && other.getReturnType() == ExpressionReturnType::Int ? ExpressionReturnType::Int : ExpressionReturnType::Double, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryNumericalFunctionExpression::OperatorType::Minus)));
return Expression(std::shared_ptr<BaseExpression>(new BinaryNumericalFunctionExpression(this->getBaseExpression().getManager(), this->getType().plusMinusTimes(other.getType()), this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryNumericalFunctionExpression::OperatorType::Minus)));
} }
Expression Expression::operator-() const { Expression Expression::operator-() const {
STORM_LOG_THROW(this->hasNumericalReturnType(), storm::exceptions::InvalidTypeException, "Operator '-' requires numerical operand.");
return Expression(std::shared_ptr<BaseExpression>(new UnaryNumericalFunctionExpression(this->getBaseExpression().getManager(), this->getReturnType(), this->getBaseExpressionPointer(), UnaryNumericalFunctionExpression::OperatorType::Minus)));
return Expression(std::shared_ptr<BaseExpression>(new UnaryNumericalFunctionExpression(this->getBaseExpression().getManager(), this->getType(), this->getBaseExpressionPointer(), UnaryNumericalFunctionExpression::OperatorType::Minus)));
} }
Expression Expression::operator*(Expression const& other) const { Expression Expression::operator*(Expression const& other) const {
assertSameManager(this->getBaseExpression(), other.getBaseExpression()); assertSameManager(this->getBaseExpression(), other.getBaseExpression());
STORM_LOG_THROW(this->hasNumericalReturnType() && other.hasNumericalReturnType(), storm::exceptions::InvalidTypeException, "Operator '*' requires numerical operands.");
return Expression(std::shared_ptr<BaseExpression>(new BinaryNumericalFunctionExpression(this->getBaseExpression().getManager(), this->getReturnType() == ExpressionReturnType::Int && other.getReturnType() == ExpressionReturnType::Int ? ExpressionReturnType::Int : ExpressionReturnType::Double, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryNumericalFunctionExpression::OperatorType::Times)));
return Expression(std::shared_ptr<BaseExpression>(new BinaryNumericalFunctionExpression(this->getBaseExpression().getManager(), this->getType().plusMinusTimes(other.getType()), this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryNumericalFunctionExpression::OperatorType::Times)));
} }
Expression Expression::operator/(Expression const& other) const { Expression Expression::operator/(Expression const& other) const {
assertSameManager(this->getBaseExpression(), other.getBaseExpression()); assertSameManager(this->getBaseExpression(), other.getBaseExpression());
STORM_LOG_THROW(this->hasNumericalReturnType() && other.hasNumericalReturnType(), storm::exceptions::InvalidTypeException, "Operator '/' requires numerical operands.");
return Expression(std::shared_ptr<BaseExpression>(new BinaryNumericalFunctionExpression(this->getBaseExpression().getManager(), this->getReturnType() == ExpressionReturnType::Int && other.getReturnType() == ExpressionReturnType::Int ? ExpressionReturnType::Int : ExpressionReturnType::Double, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryNumericalFunctionExpression::OperatorType::Divide)));
return Expression(std::shared_ptr<BaseExpression>(new BinaryNumericalFunctionExpression(this->getBaseExpression().getManager(), this->getType().divide(other.getType()), this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryNumericalFunctionExpression::OperatorType::Divide)));
} }
Expression Expression::operator^(Expression const& other) const { Expression Expression::operator^(Expression const& other) const {
assertSameManager(this->getBaseExpression(), other.getBaseExpression()); assertSameManager(this->getBaseExpression(), other.getBaseExpression());
STORM_LOG_THROW(this->hasNumericalReturnType() && other.hasNumericalReturnType(), storm::exceptions::InvalidTypeException, "Operator '^' requires numerical operands.");
return Expression(std::shared_ptr<BaseExpression>(new BinaryNumericalFunctionExpression(this->getBaseExpression().getManager(), this->getReturnType() == ExpressionReturnType::Int && other.getReturnType() == ExpressionReturnType::Int ? ExpressionReturnType::Int : ExpressionReturnType::Double, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryNumericalFunctionExpression::OperatorType::Power)));
return Expression(std::shared_ptr<BaseExpression>(new BinaryNumericalFunctionExpression(this->getBaseExpression().getManager(), this->getType().power(other.getType()), this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryNumericalFunctionExpression::OperatorType::Power)));
} }
Expression Expression::operator&&(Expression const& other) const { Expression Expression::operator&&(Expression const& other) const {
assertSameManager(this->getBaseExpression(), other.getBaseExpression()); assertSameManager(this->getBaseExpression(), other.getBaseExpression());
STORM_LOG_THROW(this->hasBooleanReturnType() && other.hasBooleanReturnType(), storm::exceptions::InvalidTypeException, "Operator '&&' requires boolean operands.");
return Expression(std::shared_ptr<BaseExpression>(new BinaryBooleanFunctionExpression(this->getBaseExpression().getManager(), ExpressionReturnType::Bool, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryBooleanFunctionExpression::OperatorType::And)));
return Expression(std::shared_ptr<BaseExpression>(new BinaryBooleanFunctionExpression(this->getBaseExpression().getManager(), this->getType().logicalConnective(other.getType()), this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryBooleanFunctionExpression::OperatorType::And)));
} }
Expression Expression::operator||(Expression const& other) const { Expression Expression::operator||(Expression const& other) const {
assertSameManager(this->getBaseExpression(), other.getBaseExpression()); assertSameManager(this->getBaseExpression(), other.getBaseExpression());
STORM_LOG_THROW(this->hasBooleanReturnType() && other.hasBooleanReturnType(), storm::exceptions::InvalidTypeException, "Operator '||' requires numerical operands.");
return Expression(std::shared_ptr<BaseExpression>(new BinaryBooleanFunctionExpression(this->getBaseExpression().getManager(), ExpressionReturnType::Bool, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryBooleanFunctionExpression::OperatorType::Or)));
return Expression(std::shared_ptr<BaseExpression>(new BinaryBooleanFunctionExpression(this->getBaseExpression().getManager(), this->getType().logicalConnective(other.getType()), this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryBooleanFunctionExpression::OperatorType::Or)));
} }
Expression Expression::operator!() const { Expression Expression::operator!() const {
STORM_LOG_THROW(this->hasBooleanReturnType(), storm::exceptions::InvalidTypeException, "Operator '!' requires boolean operand.");
return Expression(std::shared_ptr<BaseExpression>(new UnaryBooleanFunctionExpression(this->getBaseExpression().getManager(), ExpressionReturnType::Bool, this->getBaseExpressionPointer(), UnaryBooleanFunctionExpression::OperatorType::Not)));
return Expression(std::shared_ptr<BaseExpression>(new UnaryBooleanFunctionExpression(this->getBaseExpression().getManager(), this->getType().logicalConnective(), this->getBaseExpressionPointer(), UnaryBooleanFunctionExpression::OperatorType::Not)));
} }
Expression Expression::operator==(Expression const& other) const { Expression Expression::operator==(Expression const& other) const {
assertSameManager(this->getBaseExpression(), other.getBaseExpression()); assertSameManager(this->getBaseExpression(), other.getBaseExpression());
STORM_LOG_THROW(this->hasNumericalReturnType() && other.hasNumericalReturnType(), storm::exceptions::InvalidTypeException, "Operator '==' requires numerical operands.");
return Expression(std::shared_ptr<BaseExpression>(new BinaryRelationExpression(this->getBaseExpression().getManager(), ExpressionReturnType::Bool, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryRelationExpression::RelationType::Equal)));
return Expression(std::shared_ptr<BaseExpression>(new BinaryRelationExpression(this->getBaseExpression().getManager(), this->getType().numericalComparison(other.getType()), this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryRelationExpression::RelationType::Equal)));
} }
Expression Expression::operator!=(Expression const& other) const { Expression Expression::operator!=(Expression const& other) const {
assertSameManager(this->getBaseExpression(), other.getBaseExpression()); assertSameManager(this->getBaseExpression(), other.getBaseExpression());
STORM_LOG_THROW((this->hasNumericalReturnType() && other.hasNumericalReturnType()) || (this->hasBooleanReturnType() && other.hasBooleanReturnType()), storm::exceptions::InvalidTypeException, "Operator '!=' requires operands of equal type.");
if (this->hasNumericalReturnType() && other.hasNumericalReturnType()) { if (this->hasNumericalReturnType() && other.hasNumericalReturnType()) {
return Expression(std::shared_ptr<BaseExpression>(new BinaryRelationExpression(this->getBaseExpression().getManager(), ExpressionReturnType::Bool, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryRelationExpression::RelationType::NotEqual)));
return Expression(std::shared_ptr<BaseExpression>(new BinaryRelationExpression(this->getBaseExpression().getManager(), this->getType().numericalComparison(other.getType()), this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryRelationExpression::RelationType::NotEqual)));
} else { } else {
return Expression(std::shared_ptr<BaseExpression>(new BinaryBooleanFunctionExpression(this->getBaseExpression().getManager(), ExpressionReturnType::Bool, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryBooleanFunctionExpression::OperatorType::Xor)));
return Expression(std::shared_ptr<BaseExpression>(new BinaryBooleanFunctionExpression(this->getBaseExpression().getManager(), this->getType().logicalConnective(other.getType()), this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryBooleanFunctionExpression::OperatorType::Xor)));
} }
} }
Expression Expression::operator>(Expression const& other) const { Expression Expression::operator>(Expression const& other) const {
assertSameManager(this->getBaseExpression(), other.getBaseExpression()); assertSameManager(this->getBaseExpression(), other.getBaseExpression());
STORM_LOG_THROW(this->hasNumericalReturnType() && other.hasNumericalReturnType(), storm::exceptions::InvalidTypeException, "Operator '>' requires numerical operands.");
return Expression(std::shared_ptr<BaseExpression>(new BinaryRelationExpression(this->getBaseExpression().getManager(), ExpressionReturnType::Bool, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryRelationExpression::RelationType::Greater)));
return Expression(std::shared_ptr<BaseExpression>(new BinaryRelationExpression(this->getBaseExpression().getManager(), this->getType().numericalComparison(other.getType()), this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryRelationExpression::RelationType::Greater)));
} }
Expression Expression::operator>=(Expression const& other) const { Expression Expression::operator>=(Expression const& other) const {
assertSameManager(this->getBaseExpression(), other.getBaseExpression()); assertSameManager(this->getBaseExpression(), other.getBaseExpression());
STORM_LOG_THROW(this->hasNumericalReturnType() && other.hasNumericalReturnType(), storm::exceptions::InvalidTypeException, "Operator '>=' requires numerical operands.");
return Expression(std::shared_ptr<BaseExpression>(new BinaryRelationExpression(this->getBaseExpression().getManager(), ExpressionReturnType::Bool, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryRelationExpression::RelationType::GreaterOrEqual)));
return Expression(std::shared_ptr<BaseExpression>(new BinaryRelationExpression(this->getBaseExpression().getManager(), this->getType().numericalComparison(other.getType()), this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryRelationExpression::RelationType::GreaterOrEqual)));
} }
Expression Expression::operator<(Expression const& other) const { Expression Expression::operator<(Expression const& other) const {
assertSameManager(this->getBaseExpression(), other.getBaseExpression()); assertSameManager(this->getBaseExpression(), other.getBaseExpression());
STORM_LOG_THROW(this->hasNumericalReturnType() && other.hasNumericalReturnType(), storm::exceptions::InvalidTypeException, "Operator '<' requires numerical operands.");
return Expression(std::shared_ptr<BaseExpression>(new BinaryRelationExpression(this->getBaseExpression().getManager(), ExpressionReturnType::Bool, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryRelationExpression::RelationType::Less)));
return Expression(std::shared_ptr<BaseExpression>(new BinaryRelationExpression(this->getBaseExpression().getManager(), this->getType().numericalComparison(other.getType()), this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryRelationExpression::RelationType::Less)));
} }
Expression Expression::operator<=(Expression const& other) const { Expression Expression::operator<=(Expression const& other) const {
assertSameManager(this->getBaseExpression(), other.getBaseExpression()); assertSameManager(this->getBaseExpression(), other.getBaseExpression());
STORM_LOG_THROW(this->hasNumericalReturnType() && other.hasNumericalReturnType(), storm::exceptions::InvalidTypeException, "Operator '<=' requires numerical operands.");
return Expression(std::shared_ptr<BaseExpression>(new BinaryRelationExpression(this->getBaseExpression().getManager(), ExpressionReturnType::Bool, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryRelationExpression::RelationType::LessOrEqual)));
return Expression(std::shared_ptr<BaseExpression>(new BinaryRelationExpression(this->getBaseExpression().getManager(), this->getType().numericalComparison(other.getType()), this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryRelationExpression::RelationType::LessOrEqual)));
} }
Expression Expression::minimum(Expression const& lhs, Expression const& rhs) { Expression Expression::minimum(Expression const& lhs, Expression const& rhs) {
assertSameManager(lhs.getBaseExpression(), rhs.getBaseExpression()); assertSameManager(lhs.getBaseExpression(), rhs.getBaseExpression());
STORM_LOG_THROW(lhs.hasNumericalReturnType() && rhs.hasNumericalReturnType(), storm::exceptions::InvalidTypeException, "Operator 'min' requires numerical operands.");
return Expression(std::shared_ptr<BaseExpression>(new BinaryNumericalFunctionExpression(lhs.getBaseExpression().getManager(), lhs.getReturnType() == ExpressionReturnType::Int && rhs.getReturnType() == ExpressionReturnType::Int ? ExpressionReturnType::Int : ExpressionReturnType::Double, lhs.getBaseExpressionPointer(), rhs.getBaseExpressionPointer(), BinaryNumericalFunctionExpression::OperatorType::Min)));
return Expression(std::shared_ptr<BaseExpression>(new BinaryNumericalFunctionExpression(lhs.getBaseExpression().getManager(), lhs.getType().minimumMaximum(rhs.getType()), lhs.getBaseExpressionPointer(), rhs.getBaseExpressionPointer(), BinaryNumericalFunctionExpression::OperatorType::Min)));
} }
Expression Expression::maximum(Expression const& lhs, Expression const& rhs) { Expression Expression::maximum(Expression const& lhs, Expression const& rhs) {
assertSameManager(lhs.getBaseExpression(), rhs.getBaseExpression()); assertSameManager(lhs.getBaseExpression(), rhs.getBaseExpression());
STORM_LOG_THROW(lhs.hasNumericalReturnType() && rhs.hasNumericalReturnType(), storm::exceptions::InvalidTypeException, "Operator 'max' requires numerical operands.");
return Expression(std::shared_ptr<BaseExpression>(new BinaryNumericalFunctionExpression(lhs.getBaseExpression().getManager(), lhs.getReturnType() == ExpressionReturnType::Int && rhs.getReturnType() == ExpressionReturnType::Int ? ExpressionReturnType::Int : ExpressionReturnType::Double, lhs.getBaseExpressionPointer(), rhs.getBaseExpressionPointer(), BinaryNumericalFunctionExpression::OperatorType::Max)));
return Expression(std::shared_ptr<BaseExpression>(new BinaryNumericalFunctionExpression(lhs.getBaseExpression().getManager(), lhs.getType().minimumMaximum(rhs.getType()), lhs.getBaseExpressionPointer(), rhs.getBaseExpressionPointer(), BinaryNumericalFunctionExpression::OperatorType::Max)));
} }
Expression Expression::ite(Expression const& thenExpression, Expression const& elseExpression) { Expression Expression::ite(Expression const& thenExpression, Expression const& elseExpression) {
assertSameManager(this->getBaseExpression(), thenExpression.getBaseExpression()); assertSameManager(this->getBaseExpression(), thenExpression.getBaseExpression());
assertSameManager(thenExpression.getBaseExpression(), elseExpression.getBaseExpression()); assertSameManager(thenExpression.getBaseExpression(), elseExpression.getBaseExpression());
STORM_LOG_THROW(this->hasBooleanReturnType(), storm::exceptions::InvalidTypeException, "Condition of if-then-else operator must be of boolean type.");
STORM_LOG_THROW(thenExpression.hasBooleanReturnType() && elseExpression.hasBooleanReturnType() || thenExpression.hasNumericalReturnType() && elseExpression.hasNumericalReturnType(), storm::exceptions::InvalidTypeException, "'then' and 'else' expression of if-then-else operator must have equal return type.");
return Expression(std::shared_ptr<BaseExpression>(new IfThenElseExpression(this->getBaseExpression().getManager(), thenExpression.hasBooleanReturnType() && elseExpression.hasBooleanReturnType() ? ExpressionReturnType::Bool : (thenExpression.getReturnType() == ExpressionReturnType::Int && elseExpression.getReturnType() == ExpressionReturnType::Int ? ExpressionReturnType::Int : ExpressionReturnType::Double), this->getBaseExpressionPointer(), thenExpression.getBaseExpressionPointer(), elseExpression.getBaseExpressionPointer())));
return Expression(std::shared_ptr<BaseExpression>(new IfThenElseExpression(this->getBaseExpression().getManager(), this->getType().ite(thenExpression.getType(), elseExpression.getType()), this->getBaseExpressionPointer(), thenExpression.getBaseExpressionPointer(), elseExpression.getBaseExpressionPointer())));
} }
Expression Expression::implies(Expression const& other) const { Expression Expression::implies(Expression const& other) const {
assertSameManager(this->getBaseExpression(), other.getBaseExpression()); assertSameManager(this->getBaseExpression(), other.getBaseExpression());
STORM_LOG_THROW(this->hasBooleanReturnType() && other.hasBooleanReturnType(), storm::exceptions::InvalidTypeException, "Operator '&&' requires boolean operands.");
return Expression(std::shared_ptr<BaseExpression>(new BinaryBooleanFunctionExpression(this->getBaseExpression().getManager(), ExpressionReturnType::Bool, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryBooleanFunctionExpression::OperatorType::Implies)));
return Expression(std::shared_ptr<BaseExpression>(new BinaryBooleanFunctionExpression(this->getBaseExpression().getManager(), this->getType().logicalConnective(other.getType()), this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryBooleanFunctionExpression::OperatorType::Implies)));
} }
Expression Expression::iff(Expression const& other) const { Expression Expression::iff(Expression const& other) const {
assertSameManager(this->getBaseExpression(), other.getBaseExpression()); assertSameManager(this->getBaseExpression(), other.getBaseExpression());
STORM_LOG_THROW(this->hasBooleanReturnType() && other.hasBooleanReturnType(), storm::exceptions::InvalidTypeException, "Operator '&&' requires boolean operands.");
return Expression(std::shared_ptr<BaseExpression>(new BinaryBooleanFunctionExpression(this->getBaseExpression().getManager(), ExpressionReturnType::Bool, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryBooleanFunctionExpression::OperatorType::Iff)));
return Expression(std::shared_ptr<BaseExpression>(new BinaryBooleanFunctionExpression(this->getBaseExpression().getManager(), this->getType().logicalConnective(other.getType()), this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryBooleanFunctionExpression::OperatorType::Iff)));
} }
Expression Expression::floor() const { Expression Expression::floor() const {
STORM_LOG_THROW(this->hasNumericalReturnType(), storm::exceptions::InvalidTypeException, "Operator 'floor' requires numerical operand."); STORM_LOG_THROW(this->hasNumericalReturnType(), storm::exceptions::InvalidTypeException, "Operator 'floor' requires numerical operand.");
return Expression(std::shared_ptr<BaseExpression>(new UnaryNumericalFunctionExpression(this->getBaseExpression().getManager(), ExpressionReturnType::Int, this->getBaseExpressionPointer(), UnaryNumericalFunctionExpression::OperatorType::Floor)));
return Expression(std::shared_ptr<BaseExpression>(new UnaryNumericalFunctionExpression(this->getBaseExpression().getManager(), this->getType().floorCeil(), this->getBaseExpressionPointer(), UnaryNumericalFunctionExpression::OperatorType::Floor)));
} }
Expression Expression::ceil() const { Expression Expression::ceil() const {
STORM_LOG_THROW(this->hasNumericalReturnType(), storm::exceptions::InvalidTypeException, "Operator 'ceil' requires numerical operand."); STORM_LOG_THROW(this->hasNumericalReturnType(), storm::exceptions::InvalidTypeException, "Operator 'ceil' requires numerical operand.");
return Expression(std::shared_ptr<BaseExpression>(new UnaryNumericalFunctionExpression(this->getBaseExpression().getManager(), ExpressionReturnType::Int, this->getBaseExpressionPointer(), UnaryNumericalFunctionExpression::OperatorType::Ceil)));
return Expression(std::shared_ptr<BaseExpression>(new UnaryNumericalFunctionExpression(this->getBaseExpression().getManager(), this->getType().floorCeil(), this->getBaseExpressionPointer(), UnaryNumericalFunctionExpression::OperatorType::Ceil)));
} }
boost::any Expression::accept(ExpressionVisitor& visitor) const { boost::any Expression::accept(ExpressionVisitor& visitor) const {

36
src/storage/expressions/Expression.h

@ -59,44 +59,26 @@ namespace storm {
static Expression maximum(Expression const& lhs, Expression const& rhs); static Expression maximum(Expression const& lhs, Expression const& rhs);
/*! /*!
* Substitutes all occurrences of identifiers according to the given map. Note that this substitution is
* done simultaneously, i.e., identifiers appearing in the expressions that were "plugged in" are not
* Substitutes all occurrences of the variables according to the given map. Note that this substitution is
* done simultaneously, i.e., variables appearing in the expressions that were "plugged in" are not
* substituted. * substituted.
* *
* @param identifierToExpressionMap A mapping from identifiers to the expression they are substituted with.
* @param variableToExpressionMap A mapping from variables to the expression they are substituted with.
* @return An expression in which all identifiers in the key set of the mapping are replaced by the * @return An expression in which all identifiers in the key set of the mapping are replaced by the
* expression they are mapped to. * expression they are mapped to.
*/ */
Expression substitute(std::map<std::string, Expression> const& identifierToExpressionMap) const;
Expression substitute(std::map<Variable, Expression> const& variableToExpressionMap) const;
/*! /*!
* Substitutes all occurrences of identifiers according to the given map. Note that this substitution is
* done simultaneously, i.e., identifiers appearing in the expressions that were "plugged in" are not
* Substitutes all occurrences of the variables according to the given map. Note that this substitution is
* done simultaneously, i.e., variables appearing in the expressions that were "plugged in" are not
* substituted. * substituted.
* *
* @param identifierToExpressionMap A mapping from identifiers to the expression they are substituted with.
* @param variableToExpressionMap A mapping from variables to the expression they are substituted with.
* @return An expression in which all identifiers in the key set of the mapping are replaced by the * @return An expression in which all identifiers in the key set of the mapping are replaced by the
* expression they are mapped to. * expression they are mapped to.
*/ */
Expression substitute(std::unordered_map<std::string, Expression> const& identifierToExpressionMap) const;
/*!
* Substitutes all occurrences of identifiers with different names given by a mapping.
*
* @param identifierToIdentifierMap A mapping from identifiers to identifiers they are substituted with.
* @return An expression in which all identifiers in the key set of the mapping are replaced by the
* identifiers they are mapped to.
*/
Expression substitute(std::map<std::string, std::string> const& identifierToIdentifierMap) const;
/*!
* Substitutes all occurrences of identifiers with different names given by a mapping.
*
* @param identifierToIdentifierMap A mapping from identifiers to identifiers they are substituted with.
* @return An expression in which all identifiers in the key set of the mapping are replaced by the
* identifiers they are mapped to.
*/
Expression substitute(std::unordered_map<std::string, std::string> const& identifierToIdentifierMap) const;
Expression substitute(std::unordered_map<Variable, Expression> const& variableToExpressionMap) const;
/*! /*!
* Evaluates the expression under the valuation of variables given by the valuation and returns the * Evaluates the expression under the valuation of variables given by the valuation and returns the
@ -247,7 +229,7 @@ namespace storm {
* *
* @return The type of the expression. * @return The type of the expression.
*/ */
Type getType() const;
Type const& getType() const;
/*! /*!
* Retrieves whether the expression has a numerical return type, i.e., integer or double. * Retrieves whether the expression has a numerical return type, i.e., integer or double.

34
src/storage/expressions/ExpressionManager.cpp

@ -47,7 +47,7 @@ namespace storm {
} }
if (nameIndexIterator != nameIndexIteratorEnd) { if (nameIndexIterator != nameIndexIteratorEnd) {
currentElement = std::make_pair(Variable(manager, nameIndexIterator->second), manager.getVariableType(nameIndexIterator->second));
currentElement = std::make_pair(Variable(manager.getSharedPointer(), nameIndexIterator->second), manager.getVariableType(nameIndexIterator->second));
} }
} }
@ -56,15 +56,15 @@ namespace storm {
} }
Expression ExpressionManager::boolean(bool value) const { Expression ExpressionManager::boolean(bool value) const {
return Expression(std::shared_ptr<BaseExpression const>(new BooleanLiteralExpression(*this, value)));
return Expression(std::shared_ptr<BaseExpression>(new BooleanLiteralExpression(*this, value)));
} }
Expression ExpressionManager::integer(int_fast64_t value) const { Expression ExpressionManager::integer(int_fast64_t value) const {
return Expression(std::shared_ptr<BaseExpression const>(new IntegerLiteralExpression(*this, value)));
return Expression(std::shared_ptr<BaseExpression>(new IntegerLiteralExpression(*this, value)));
} }
Expression ExpressionManager::rational(double value) const { Expression ExpressionManager::rational(double value) const {
return Expression(std::shared_ptr<BaseExpression const>(new DoubleLiteralExpression(*this, value)));
return Expression(std::shared_ptr<BaseExpression>(new DoubleLiteralExpression(*this, value)));
} }
bool ExpressionManager::operator==(ExpressionManager const& other) const { bool ExpressionManager::operator==(ExpressionManager const& other) const {
@ -72,19 +72,19 @@ namespace storm {
} }
Type ExpressionManager::getBooleanType() const { Type ExpressionManager::getBooleanType() const {
return Type(std::shared_ptr<BaseType>(new BooleanType()));
return Type(this->getSharedPointer(), std::shared_ptr<BaseType>(new BooleanType()));
} }
Type ExpressionManager::getIntegerType() const { Type ExpressionManager::getIntegerType() const {
return Type(std::shared_ptr<BaseType>(new IntegerType()));
return Type(this->getSharedPointer(), std::shared_ptr<BaseType>(new IntegerType()));
} }
Type ExpressionManager::getBoundedIntegerType(std::size_t width) const { Type ExpressionManager::getBoundedIntegerType(std::size_t width) const {
return Type(std::shared_ptr<BaseType>(new BoundedIntegerType(width)));
return Type(this->getSharedPointer(), std::shared_ptr<BaseType>(new BoundedIntegerType(width)));
} }
Type ExpressionManager::getRationalType() const { Type ExpressionManager::getRationalType() const {
return Type(std::shared_ptr<BaseType>(new RationalType()));
return Type(this->getSharedPointer(), std::shared_ptr<BaseType>(new RationalType()));
} }
bool ExpressionManager::isValidVariableName(std::string const& name) { bool ExpressionManager::isValidVariableName(std::string const& name) {
@ -110,7 +110,7 @@ namespace storm {
STORM_LOG_THROW(isValidVariableName(name), storm::exceptions::InvalidArgumentException, "Invalid variable name '" << name << "'."); STORM_LOG_THROW(isValidVariableName(name), storm::exceptions::InvalidArgumentException, "Invalid variable name '" << name << "'.");
auto nameIndexPair = nameToIndexMapping.find(name); auto nameIndexPair = nameToIndexMapping.find(name);
if (nameIndexPair != nameToIndexMapping.end()) { if (nameIndexPair != nameToIndexMapping.end()) {
return Variable(*this, nameIndexPair->second);
return Variable(this->getSharedPointer(), nameIndexPair->second);
} else { } else {
std::unordered_map<Type, uint_fast64_t>::iterator typeCountPair = variableTypeToCountMapping.find(variableType); std::unordered_map<Type, uint_fast64_t>::iterator typeCountPair = variableTypeToCountMapping.find(variableType);
uint_fast64_t& oldCount = variableTypeToCountMapping[variableType]; uint_fast64_t& oldCount = variableTypeToCountMapping[variableType];
@ -122,7 +122,7 @@ namespace storm {
nameToIndexMapping[name] = newIndex; nameToIndexMapping[name] = newIndex;
indexToNameMapping[newIndex] = name; indexToNameMapping[newIndex] = name;
indexToTypeMapping[newIndex] = variableType; indexToTypeMapping[newIndex] = variableType;
return Variable(*this, newIndex);
return Variable(this->getSharedPointer(), newIndex);
} }
} }
@ -130,7 +130,7 @@ namespace storm {
STORM_LOG_THROW(isValidVariableName(name), storm::exceptions::InvalidArgumentException, "Invalid variable name '" << name << "'."); STORM_LOG_THROW(isValidVariableName(name), storm::exceptions::InvalidArgumentException, "Invalid variable name '" << name << "'.");
auto nameIndexPair = nameToIndexMapping.find(name); auto nameIndexPair = nameToIndexMapping.find(name);
if (nameIndexPair != nameToIndexMapping.end()) { if (nameIndexPair != nameToIndexMapping.end()) {
return Variable(*this, nameIndexPair->second);
return Variable(this->getSharedPointer(), nameIndexPair->second);
} else { } else {
std::unordered_map<Type, uint_fast64_t>::iterator typeCountPair = auxiliaryVariableTypeToCountMapping.find(variableType); std::unordered_map<Type, uint_fast64_t>::iterator typeCountPair = auxiliaryVariableTypeToCountMapping.find(variableType);
uint_fast64_t& oldCount = auxiliaryVariableTypeToCountMapping[variableType]; uint_fast64_t& oldCount = auxiliaryVariableTypeToCountMapping[variableType];
@ -142,14 +142,14 @@ namespace storm {
nameToIndexMapping[name] = newIndex; nameToIndexMapping[name] = newIndex;
indexToNameMapping[newIndex] = name; indexToNameMapping[newIndex] = name;
indexToTypeMapping[newIndex] = variableType; indexToTypeMapping[newIndex] = variableType;
return Variable(*this, newIndex);
return Variable(this->getSharedPointer(), newIndex);
} }
} }
Variable ExpressionManager::getVariable(std::string const& name) const { Variable ExpressionManager::getVariable(std::string const& name) const {
auto nameIndexPair = nameToIndexMapping.find(name); auto nameIndexPair = nameToIndexMapping.find(name);
STORM_LOG_THROW(nameIndexPair != nameToIndexMapping.end(), storm::exceptions::InvalidArgumentException, "Unknown variable '" << name << "'."); STORM_LOG_THROW(nameIndexPair != nameToIndexMapping.end(), storm::exceptions::InvalidArgumentException, "Unknown variable '" << name << "'.");
return Variable(*this, nameIndexPair->second);
return Variable(this->getSharedPointer(), nameIndexPair->second);
} }
Expression ExpressionManager::getVariableExpression(std::string const& name) const { Expression ExpressionManager::getVariableExpression(std::string const& name) const {
@ -240,5 +240,13 @@ namespace storm {
return ExpressionManager::const_iterator(*this, this->nameToIndexMapping.end(), this->nameToIndexMapping.end(), const_iterator::VariableSelection::OnlyRegularVariables); return ExpressionManager::const_iterator(*this, this->nameToIndexMapping.end(), this->nameToIndexMapping.end(), const_iterator::VariableSelection::OnlyRegularVariables);
} }
std::shared_ptr<ExpressionManager> ExpressionManager::getSharedPointer() {
return this->shared_from_this();
}
std::shared_ptr<ExpressionManager const> ExpressionManager::getSharedPointer() const {
return this->shared_from_this();
}
} // namespace expressions } // namespace expressions
} // namespace storm } // namespace storm

16
src/storage/expressions/ExpressionManager.h

@ -58,7 +58,7 @@ namespace storm {
/*! /*!
* This class is responsible for managing a set of typed variables and all expressions using these variables. * This class is responsible for managing a set of typed variables and all expressions using these variables.
*/ */
class ExpressionManager {
class ExpressionManager : public std::enable_shared_from_this<ExpressionManager> {
public: public:
friend class VariableIterator; friend class VariableIterator;
@ -315,6 +315,20 @@ namespace storm {
const_iterator end() const; const_iterator end() const;
private: private:
/*!
* Retrieves a shared pointer to the expression manager.
*
* @return A shared pointer to the expression manager.
*/
std::shared_ptr<ExpressionManager> getSharedPointer();
/*!
* Retrieves a shared pointer to the expression manager.
*
* @return A shared pointer to the expression manager.
*/
std::shared_ptr<ExpressionManager const> getSharedPointer() const;
/*! /*!
* Checks whether the given variable name is valid. * Checks whether the given variable name is valid.
* *

137
src/storage/expressions/LinearCoefficientVisitor.cpp

@ -6,73 +6,87 @@
namespace storm { namespace storm {
namespace expressions { namespace expressions {
std::pair<SimpleValuation, double> LinearCoefficientVisitor::getLinearCoefficients(Expression const& expression) {
return boost::any_cast<std::pair<SimpleValuation, double>>(expression.getBaseExpression().accept(*this));
LinearCoefficientVisitor::VariableCoefficients::VariableCoefficients(double constantPart) : variableToCoefficientMapping(), constantPart(constantPart) {
// Intentionally left empty.
} }
boost::any LinearCoefficientVisitor::visit(IfThenElseExpression const& expression) {
STORM_LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Expression is non-linear.");
LinearCoefficientVisitor::VariableCoefficients& LinearCoefficientVisitor::VariableCoefficients::operator+=(VariableCoefficients&& other) {
for (auto const& otherVariableCoefficientPair : other.variableToCoefficientMapping) {
this->variableToCoefficientMapping[otherVariableCoefficientPair.first] += otherVariableCoefficientPair.second;
} }
boost::any LinearCoefficientVisitor::visit(BinaryBooleanFunctionExpression const& expression) {
STORM_LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Expression is non-linear.");
constantPart += other.constantPart;
return *this;
} }
boost::any LinearCoefficientVisitor::visit(BinaryNumericalFunctionExpression const& expression) {
std::pair<SimpleValuation, double> leftResult = boost::any_cast<std::pair<SimpleValuation, double>>(expression.getFirstOperand()->accept(*this));
std::pair<SimpleValuation, double> rightResult = boost::any_cast<std::pair<SimpleValuation, double>>(expression.getSecondOperand()->accept(*this));
LinearCoefficientVisitor::VariableCoefficients& LinearCoefficientVisitor::VariableCoefficients::operator-=(VariableCoefficients&& other) {
for (auto const& otherVariableCoefficientPair : other.variableToCoefficientMapping) {
this->variableToCoefficientMapping[otherVariableCoefficientPair.first] -= otherVariableCoefficientPair.second;
}
constantPart -= other.constantPart;
return *this;
}
if (expression.getOperatorType() == BinaryNumericalFunctionExpression::OperatorType::Plus) {
// Now add the left result to the right result.
for (auto const& identifier : leftResult.first.getDoubleIdentifiers()) {
if (rightResult.first.containsDoubleIdentifier(identifier)) {
rightResult.first.setDoubleValue(identifier, leftResult.first.getDoubleValue(identifier) + rightResult.first.getDoubleValue(identifier));
} else {
rightResult.first.setDoubleValue(identifier, leftResult.first.getDoubleValue(identifier));
LinearCoefficientVisitor::VariableCoefficients& LinearCoefficientVisitor::VariableCoefficients::operator*=(VariableCoefficients&& other) {
STORM_LOG_THROW(variableToCoefficientMapping.size() == 0 || other.variableToCoefficientMapping.size() == 0, storm::exceptions::InvalidArgumentException, "Expression is non-linear.");
if (other.variableToCoefficientMapping.size() > 0) {
variableToCoefficientMapping = std::move(other.variableToCoefficientMapping);
std::swap(constantPart, other.constantPart);
} }
for (auto const& otherVariableCoefficientPair : other.variableToCoefficientMapping) {
this->variableToCoefficientMapping[otherVariableCoefficientPair.first] *= other.constantPart;
} }
rightResult.second += leftResult.second;
} else if (expression.getOperatorType() == BinaryNumericalFunctionExpression::OperatorType::Minus) {
// Now subtract the right result from the left result.
for (auto const& identifier : leftResult.first.getDoubleIdentifiers()) {
if (rightResult.first.containsDoubleIdentifier(identifier)) {
rightResult.first.setDoubleValue(identifier, leftResult.first.getDoubleValue(identifier) - rightResult.first.getDoubleValue(identifier));
} else {
rightResult.first.setDoubleValue(identifier, leftResult.first.getDoubleValue(identifier));
constantPart *= other.constantPart;
return *this;
} }
LinearCoefficientVisitor::VariableCoefficients& LinearCoefficientVisitor::VariableCoefficients::operator/=(VariableCoefficients&& other) {
STORM_LOG_THROW(other.variableToCoefficientMapping.size() == 0, storm::exceptions::InvalidArgumentException, "Expression is non-linear.");
for (auto const& otherVariableCoefficientPair : other.variableToCoefficientMapping) {
this->variableToCoefficientMapping[otherVariableCoefficientPair.first] /= other.constantPart;
} }
for (auto const& identifier : rightResult.first.getDoubleIdentifiers()) {
if (!leftResult.first.containsDoubleIdentifier(identifier)) {
rightResult.first.setDoubleValue(identifier, -rightResult.first.getDoubleValue(identifier));
constantPart /= other.constantPart;
return *this;
} }
void LinearCoefficientVisitor::VariableCoefficients::negate() {
for (auto& variableCoefficientPair : variableToCoefficientMapping) {
variableCoefficientPair.second = -variableCoefficientPair.second;
} }
rightResult.second = leftResult.second - rightResult.second;
} else if (expression.getOperatorType() == BinaryNumericalFunctionExpression::OperatorType::Times) {
// If the expression is linear, either the left or the right side must not contain variables.
STORM_LOG_THROW(leftResult.first.getNumberOfIdentifiers() == 0 || rightResult.first.getNumberOfIdentifiers() == 0, storm::exceptions::InvalidArgumentException, "Expression is non-linear.");
if (leftResult.first.getNumberOfIdentifiers() == 0) {
for (auto const& identifier : rightResult.first.getDoubleIdentifiers()) {
rightResult.first.setDoubleValue(identifier, leftResult.second * rightResult.first.getDoubleValue(identifier));
constantPart = -constantPart;
} }
} else {
for (auto const& identifier : leftResult.first.getDoubleIdentifiers()) {
rightResult.first.addDoubleIdentifier(identifier, rightResult.second * leftResult.first.getDoubleValue(identifier));
void LinearCoefficientVisitor::VariableCoefficients::setCoefficient(storm::expressions::Variable const& variable, double coefficient) {
variableToCoefficientMapping[variable] = coefficient;
} }
double LinearCoefficientVisitor::VariableCoefficients::getCoefficient(storm::expressions::Variable const& variable) {
return variableToCoefficientMapping[variable];
} }
rightResult.second *= leftResult.second;
} else if (expression.getOperatorType() == BinaryNumericalFunctionExpression::OperatorType::Divide) {
// If the expression is linear, either the left or the right side must not contain variables.
STORM_LOG_THROW(leftResult.first.getNumberOfIdentifiers() == 0 || rightResult.first.getNumberOfIdentifiers() == 0, storm::exceptions::InvalidArgumentException, "Expression is non-linear.");
if (leftResult.first.getNumberOfIdentifiers() == 0) {
for (auto const& identifier : rightResult.first.getDoubleIdentifiers()) {
rightResult.first.setDoubleValue(identifier, leftResult.second / rightResult.first.getDoubleValue(identifier));
LinearCoefficientVisitor::VariableCoefficients LinearCoefficientVisitor::getLinearCoefficients(Expression const& expression) {
return boost::any_cast<VariableCoefficients>(expression.getBaseExpression().accept(*this));
} }
} else {
for (auto const& identifier : leftResult.first.getDoubleIdentifiers()) {
rightResult.first.addDoubleIdentifier(identifier, leftResult.first.getDoubleValue(identifier) / rightResult.second);
boost::any LinearCoefficientVisitor::visit(IfThenElseExpression const& expression) {
STORM_LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Expression is non-linear.");
} }
boost::any LinearCoefficientVisitor::visit(BinaryBooleanFunctionExpression const& expression) {
STORM_LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Expression is non-linear.");
} }
rightResult.second = leftResult.second / leftResult.second;
boost::any LinearCoefficientVisitor::visit(BinaryNumericalFunctionExpression const& expression) {
VariableCoefficients leftResult = boost::any_cast<VariableCoefficients>(expression.getFirstOperand()->accept(*this));
VariableCoefficients rightResult = boost::any_cast<VariableCoefficients>(expression.getSecondOperand()->accept(*this));
if (expression.getOperatorType() == BinaryNumericalFunctionExpression::OperatorType::Plus) {
leftResult += std::move(rightResult);
} else if (expression.getOperatorType() == BinaryNumericalFunctionExpression::OperatorType::Minus) {
leftResult -= std::move(rightResult);
} else if (expression.getOperatorType() == BinaryNumericalFunctionExpression::OperatorType::Times) {
leftResult *= std::move(rightResult);
} else if (expression.getOperatorType() == BinaryNumericalFunctionExpression::OperatorType::Divide) {
leftResult /= std::move(rightResult);
} else { } else {
STORM_LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Expression is non-linear."); STORM_LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Expression is non-linear.");
} }
@ -84,15 +98,13 @@ namespace storm {
} }
boost::any LinearCoefficientVisitor::visit(VariableExpression const& expression) { boost::any LinearCoefficientVisitor::visit(VariableExpression const& expression) {
SimpleValuation valuation;
switch (expression.getReturnType()) {
case ExpressionReturnType::Bool: STORM_LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Expression is non-linear."); break;
case ExpressionReturnType::Int:
case ExpressionReturnType::Double: valuation.addDoubleIdentifier(expression.getVariableName(), 1); break;
case ExpressionReturnType::Undefined: STORM_LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Illegal expression return type."); break;
VariableCoefficients coefficients;
if (expression.getType().isNumericalType()) {
coefficients.setCoefficient(expression.getVariable(), 1);
} else {
STORM_LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Expression is non-linear.");
} }
return std::make_pair(valuation, static_cast<double>(0));
return coefficients;
} }
boost::any LinearCoefficientVisitor::visit(UnaryBooleanFunctionExpression const& expression) { boost::any LinearCoefficientVisitor::visit(UnaryBooleanFunctionExpression const& expression) {
@ -100,13 +112,10 @@ namespace storm {
} }
boost::any LinearCoefficientVisitor::visit(UnaryNumericalFunctionExpression const& expression) { boost::any LinearCoefficientVisitor::visit(UnaryNumericalFunctionExpression const& expression) {
std::pair<SimpleValuation, double> childResult = boost::any_cast<std::pair<SimpleValuation, double>>(expression.getOperand()->accept(*this));
VariableCoefficients childResult = boost::any_cast<VariableCoefficients>(expression.getOperand()->accept(*this));
if (expression.getOperatorType() == UnaryNumericalFunctionExpression::OperatorType::Minus) { if (expression.getOperatorType() == UnaryNumericalFunctionExpression::OperatorType::Minus) {
// Here, we need to negate all double identifiers.
for (auto const& identifier : childResult.first.getDoubleIdentifiers()) {
childResult.first.setDoubleValue(identifier, -childResult.first.getDoubleValue(identifier));
}
childResult.negate();
return childResult; return childResult;
} else { } else {
STORM_LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Expression is non-linear."); STORM_LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Expression is non-linear.");
@ -118,11 +127,11 @@ namespace storm {
} }
boost::any LinearCoefficientVisitor::visit(IntegerLiteralExpression const& expression) { boost::any LinearCoefficientVisitor::visit(IntegerLiteralExpression const& expression) {
return std::make_pair(SimpleValuation(), static_cast<double>(expression.getValue()));
return VariableCoefficients(static_cast<double>(expression.getValue()));
} }
boost::any LinearCoefficientVisitor::visit(DoubleLiteralExpression const& expression) { boost::any LinearCoefficientVisitor::visit(DoubleLiteralExpression const& expression) {
return std::make_pair(SimpleValuation(), expression.getValue());
return VariableCoefficients(expression.getValue());
} }
} }
} }

29
src/storage/expressions/LinearCoefficientVisitor.h

@ -4,6 +4,7 @@
#include <stack> #include <stack>
#include "src/storage/expressions/Expression.h" #include "src/storage/expressions/Expression.h"
#include "src/storage/expressions/Variable.h"
#include "src/storage/expressions/ExpressionVisitor.h" #include "src/storage/expressions/ExpressionVisitor.h"
#include "src/storage/expressions/SimpleValuation.h" #include "src/storage/expressions/SimpleValuation.h"
@ -11,6 +12,29 @@ namespace storm {
namespace expressions { namespace expressions {
class LinearCoefficientVisitor : public ExpressionVisitor { class LinearCoefficientVisitor : public ExpressionVisitor {
public: public:
struct VariableCoefficients {
public:
VariableCoefficients(double constantPart = 0);
VariableCoefficients(VariableCoefficients const& other) = default;
VariableCoefficients& operator=(VariableCoefficients const& other) = default;
VariableCoefficients(VariableCoefficients&& other) = default;
VariableCoefficients& operator=(VariableCoefficients&& other) = default;
VariableCoefficients& operator+=(VariableCoefficients&& other);
VariableCoefficients& operator-=(VariableCoefficients&& other);
VariableCoefficients& operator*=(VariableCoefficients&& other);
VariableCoefficients& operator/=(VariableCoefficients&& other);
void negate();
void setCoefficient(storm::expressions::Variable const& variable, double coefficient);
double getCoefficient(storm::expressions::Variable const& variable);
private:
std::map<storm::expressions::Variable, double> variableToCoefficientMapping;
double constantPart;
};
/*! /*!
* Creates a linear coefficient visitor. * Creates a linear coefficient visitor.
*/ */
@ -21,10 +45,9 @@ namespace storm {
* was rewritten as a sum of atoms.. If the expression is not linear, an exception is thrown. * was rewritten as a sum of atoms.. If the expression is not linear, an exception is thrown.
* *
* @param expression The expression for which to compute the coefficients. * @param expression The expression for which to compute the coefficients.
* @return A pair consisting of a mapping from identifiers to their coefficients and the coefficient of
* the constant atom.
* @return A structure representing the coefficients of the variables and the constant part.
*/ */
std::pair<SimpleValuation, double> getLinearCoefficients(Expression const& expression);
VariableCoefficients getLinearCoefficients(Expression const& expression);
virtual boost::any visit(IfThenElseExpression const& expression) override; virtual boost::any visit(IfThenElseExpression const& expression) override;
virtual boost::any visit(BinaryBooleanFunctionExpression const& expression) override; virtual boost::any visit(BinaryBooleanFunctionExpression const& expression) override;

31
src/storage/expressions/Type.cpp

@ -50,7 +50,7 @@ namespace storm {
return "rational"; return "rational";
} }
Type::Type(ExpressionManager const& manager, std::shared_ptr<BaseType> innerType) : manager(manager), innerType(innerType) {
Type::Type(std::shared_ptr<ExpressionManager const> const& manager, std::shared_ptr<BaseType> innerType) : manager(manager), innerType(innerType) {
// Intentionally left empty. // Intentionally left empty.
} }
@ -94,12 +94,16 @@ namespace storm {
return typeid(*this->innerType) == typeid(RationalType); return typeid(*this->innerType) == typeid(RationalType);
} }
storm::expressions::ExpressionManager const& Type::getManager() const {
return *manager;
}
Type Type::plusMinusTimes(Type const& other) const { Type Type::plusMinusTimes(Type const& other) const {
STORM_LOG_ASSERT(this->isNumericalType() && other.isNumericalType(), "Operator requires numerical operands."); STORM_LOG_ASSERT(this->isNumericalType() && other.isNumericalType(), "Operator requires numerical operands.");
if (this->isRationalType() || other.isRationalType()) { if (this->isRationalType() || other.isRationalType()) {
return manager.getRationalType();
return this->getManager().getRationalType();
} }
return manager.getIntegerType();
return getManager().getIntegerType();
} }
Type Type::minus() const { Type Type::minus() const {
@ -109,7 +113,15 @@ namespace storm {
Type Type::divide(Type const& other) const { Type Type::divide(Type const& other) const {
STORM_LOG_ASSERT(this->isNumericalType() && other.isNumericalType(), "Operator requires numerical operands."); STORM_LOG_ASSERT(this->isNumericalType() && other.isNumericalType(), "Operator requires numerical operands.");
return manager.getRationalType();
return this->getManager().getRationalType();
}
Type Type::power(Type const& other) const {
STORM_LOG_ASSERT(this->isNumericalType() && other.isNumericalType(), "Operator requires numerical operands.");
if (this->isRationalType() || other.isRationalType()) {
return getManager().getRationalType();
}
return this->getManager().getIntegerType();
} }
Type Type::logicalConnective(Type const& other) const { Type Type::logicalConnective(Type const& other) const {
@ -125,27 +137,28 @@ namespace storm {
Type Type::numericalComparison(Type const& other) const { Type Type::numericalComparison(Type const& other) const {
STORM_LOG_ASSERT(this->isNumericalType() && other.isNumericalType(), "Operator requires numerical operands."); STORM_LOG_ASSERT(this->isNumericalType() && other.isNumericalType(), "Operator requires numerical operands.");
if (this->isRationalType() || other.isRationalType()) { if (this->isRationalType() || other.isRationalType()) {
return manager.getRationalType();
return this->getManager().getRationalType();
} }
return manager.getIntegerType();
return this->getManager().getIntegerType();
} }
Type Type::ite(Type const& thenType, Type const& elseType) const { Type Type::ite(Type const& thenType, Type const& elseType) const {
STORM_LOG_ASSERT(this->isBooleanType(), "Operator requires boolean condition.");
STORM_LOG_ASSERT(thenType == elseType, "Operator requires equal types."); STORM_LOG_ASSERT(thenType == elseType, "Operator requires equal types.");
return thenType; return thenType;
} }
Type Type::floorCeil() const { Type Type::floorCeil() const {
STORM_LOG_ASSERT(this->isRationalType(), "Operator requires rational operand."); STORM_LOG_ASSERT(this->isRationalType(), "Operator requires rational operand.");
return manager.getIntegerType();
return this->getManager().getIntegerType();
} }
Type Type::minimumMaximum(Type const& other) const { Type Type::minimumMaximum(Type const& other) const {
STORM_LOG_ASSERT(this->isNumericalType() && other.isNumericalType(), "Operator requires numerical operands."); STORM_LOG_ASSERT(this->isNumericalType() && other.isNumericalType(), "Operator requires numerical operands.");
if (this->isRationalType() || other.isRationalType()) { if (this->isRationalType() || other.isRationalType()) {
return manager.getRationalType();
return this->getManager().getRationalType();
} }
return manager.getIntegerType();
return this->getManager().getIntegerType();
} }
std::ostream& operator<<(std::ostream& stream, Type const& type) { std::ostream& operator<<(std::ostream& stream, Type const& type) {

20
src/storage/expressions/Type.h

@ -106,7 +106,15 @@ namespace storm {
class Type { class Type {
public: public:
Type(ExpressionManager const& manager, std::shared_ptr<BaseType> innerType);
Type() = default;
/*!
* Constructs a new type of the given manager with the given encapsulated type.
*
* @param manager The manager responsible for this type.
* @param innerType The encapsulated type.
*/
Type(std::shared_ptr<ExpressionManager const> const& manager, std::shared_ptr<BaseType> innerType);
/*! /*!
* Checks whether two types are the same. * Checks whether two types are the same.
@ -179,10 +187,18 @@ namespace storm {
*/ */
bool isRationalType() const; bool isRationalType() const;
/*!
* Retrieves the manager of the type.
*
* @return The manager of the type.
*/
storm::expressions::ExpressionManager const& getManager() const;
// Functions that, given the input types, produce the output type of the corresponding function application. // Functions that, given the input types, produce the output type of the corresponding function application.
Type plusMinusTimes(Type const& other) const; Type plusMinusTimes(Type const& other) const;
Type minus() const; Type minus() const;
Type divide(Type const& other) const; Type divide(Type const& other) const;
Type power(Type const& other) const;
Type logicalConnective(Type const& other) const; Type logicalConnective(Type const& other) const;
Type logicalConnective() const; Type logicalConnective() const;
Type numericalComparison(Type const& other) const; Type numericalComparison(Type const& other) const;
@ -192,7 +208,7 @@ namespace storm {
private: private:
// The manager responsible for the type. // The manager responsible for the type.
ExpressionManager const& manager;
std::shared_ptr<ExpressionManager const> manager;
// The encapsulated type. // The encapsulated type.
std::shared_ptr<BaseType> innerType; std::shared_ptr<BaseType> innerType;

20
src/storage/expressions/Variable.cpp

@ -3,7 +3,7 @@
namespace storm { namespace storm {
namespace expressions { namespace expressions {
Variable::Variable(ExpressionManager const& manager, uint_fast64_t index) : manager(manager), index(index) {
Variable::Variable(std::shared_ptr<ExpressionManager const> const& manager, uint_fast64_t index) : manager(manager), index(index) {
// Intentionally left empty. // Intentionally left empty.
} }
@ -20,19 +20,19 @@ namespace storm {
} }
uint_fast64_t Variable::getOffset() const { uint_fast64_t Variable::getOffset() const {
return manager.getOffset(index);
return this->getManager().getOffset(index);
} }
std::string const& Variable::getName() const { std::string const& Variable::getName() const {
return manager.getVariableName(index);
return this->getManager().getVariableName(index);
} }
Type const& Variable::getType() const { Type const& Variable::getType() const {
return manager.getVariableType(index);
return this->getManager().getVariableType(index);
} }
ExpressionManager const& Variable::getManager() const { ExpressionManager const& Variable::getManager() const {
return manager;
return *manager;
} }
bool Variable::hasBooleanType() const { bool Variable::hasBooleanType() const {
@ -52,3 +52,13 @@ namespace storm {
} }
} }
} }
namespace std {
std::size_t hash<storm::expressions::Variable>::operator()(storm::expressions::Variable const& variable) const {
return std::hash<uint_fast64_t>()(variable.getIndex());
}
std::size_t less<storm::expressions::Variable>::operator()(storm::expressions::Variable const& variable1, storm::expressions::Variable const& variable2) const {
return variable1.getIndex() < variable2.getIndex();
}
}

46
src/storage/expressions/Variable.h

@ -8,6 +8,26 @@
#include "src/storage/expressions/Type.h" #include "src/storage/expressions/Type.h"
#include "src/storage/expressions/Expression.h" #include "src/storage/expressions/Expression.h"
namespace storm {
namespace expressions {
class Variable;
}
}
namespace std {
// Provide a hashing operator, so we can put variables in unordered collections.
template <>
struct hash<storm::expressions::Variable> {
std::size_t operator()(storm::expressions::Variable const& variable) const;
};
// Provide a less operator, so we can put variables in ordered collections.
template <>
struct less<storm::expressions::Variable> {
std::size_t operator()(storm::expressions::Variable const& variable1, storm::expressions::Variable const& variable2) const;
};
}
namespace storm { namespace storm {
namespace expressions { namespace expressions {
class ExpressionManager; class ExpressionManager;
@ -15,13 +35,15 @@ namespace storm {
// This class captures a simple variable. // This class captures a simple variable.
class Variable { class Variable {
public: public:
Variable() = default;
/*! /*!
* Constructs a variable with the given index and type. * Constructs a variable with the given index and type.
* *
* @param manager The manager that is responsible for this variable. * @param manager The manager that is responsible for this variable.
* @param index The (unique) index of the variable. * @param index The (unique) index of the variable.
*/ */
Variable(ExpressionManager const& manager, uint_fast64_t index);
Variable(std::shared_ptr<ExpressionManager const> const& manager, uint_fast64_t index);
// Default-instantiate some copy/move construction/assignment. // Default-instantiate some copy/move construction/assignment.
Variable(Variable const& other) = default; Variable(Variable const& other) = default;
@ -62,6 +84,8 @@ namespace storm {
/*! /*!
* Retrieves the manager responsible for this variable. * Retrieves the manager responsible for this variable.
*
* @return The manager responsible for this variable.
*/ */
ExpressionManager const& getManager() const; ExpressionManager const& getManager() const;
@ -109,7 +133,7 @@ namespace storm {
private: private:
// The manager that is responsible for this variable. // The manager that is responsible for this variable.
ExpressionManager const& manager;
std::shared_ptr<ExpressionManager const> manager;
// The index of the variable. // The index of the variable.
uint_fast64_t index; uint_fast64_t index;
@ -117,22 +141,4 @@ namespace storm {
} }
} }
namespace std {
// Provide a hashing operator, so we can put variables in unordered collections.
template <>
struct hash<storm::expressions::Variable> {
std::size_t operator()(storm::expressions::Variable const& variable) const {
return std::hash<uint_fast64_t>()(variable.getIndex());
}
};
// Provide a less operator, so we can put variables in ordered collections.
template <>
struct less<storm::expressions::Variable> {
std::size_t operator()(storm::expressions::Variable const& variable1, storm::expressions::Variable const& variable2) const {
return variable1.getIndex() < variable2.getIndex();
}
};
}
#endif /* STORM_STORAGE_EXPRESSIONS_VARIABLE_H_ */ #endif /* STORM_STORAGE_EXPRESSIONS_VARIABLE_H_ */

18
src/storage/expressions/VariableExpression.cpp

@ -18,31 +18,27 @@ namespace storm {
bool VariableExpression::evaluateAsBool(Valuation const* valuation) const { bool VariableExpression::evaluateAsBool(Valuation const* valuation) const {
STORM_LOG_ASSERT(valuation != nullptr, "Evaluating expressions with unknowns without valuation."); STORM_LOG_ASSERT(valuation != nullptr, "Evaluating expressions with unknowns without valuation.");
STORM_LOG_THROW(this->hasBooleanReturnType(), storm::exceptions::InvalidTypeException, "Cannot evaluate expression as boolean: return type is not a boolean.");
STORM_LOG_THROW(this->hasBooleanType(), storm::exceptions::InvalidTypeException, "Cannot evaluate expression as boolean: return type is not a boolean.");
return valuation->getBooleanValue(this->getVariable()); return valuation->getBooleanValue(this->getVariable());
} }
int_fast64_t VariableExpression::evaluateAsInt(Valuation const* valuation) const { int_fast64_t VariableExpression::evaluateAsInt(Valuation const* valuation) const {
STORM_LOG_ASSERT(valuation != nullptr, "Evaluating expressions with unknowns without valuation."); STORM_LOG_ASSERT(valuation != nullptr, "Evaluating expressions with unknowns without valuation.");
STORM_LOG_THROW(this->hasIntegralReturnType(), storm::exceptions::InvalidTypeException, "Cannot evaluate expression as integer: return type is not an integer.");
STORM_LOG_THROW(this->hasIntegralType(), storm::exceptions::InvalidTypeException, "Cannot evaluate expression as integer: return type is not an integer.");
return valuation->getIntegerValue(this->getVariable()); return valuation->getIntegerValue(this->getVariable());
} }
double VariableExpression::evaluateAsDouble(Valuation const* valuation) const { double VariableExpression::evaluateAsDouble(Valuation const* valuation) const {
STORM_LOG_ASSERT(valuation != nullptr, "Evaluating expressions with unknowns without valuation."); STORM_LOG_ASSERT(valuation != nullptr, "Evaluating expressions with unknowns without valuation.");
STORM_LOG_THROW(this->hasNumericalReturnType(), storm::exceptions::InvalidTypeException, "Cannot evaluate expression as double: return type is not a double.");
STORM_LOG_THROW(this->hasNumericalType(), storm::exceptions::InvalidTypeException, "Cannot evaluate expression as double: return type is not a double.");
switch (this->getReturnType()) {
case ExpressionReturnType::Int: return static_cast<double>(valuation->getIntegerValue(this->getVariable())); break;
case ExpressionReturnType::Double: valuation->getRationalValue(this->getVariable()); break;
default: break;
if (this->getType().isIntegralType()) {
return static_cast<double>(valuation->getIntegerValue(this->getVariable()));
} else {
return valuation->getRationalValue(this->getVariable());
} }
STORM_LOG_ASSERT(false, "Type of variable is required to be numeric.");
// Silence warning. This point can never be reached.
return 0;
} }
std::string const& VariableExpression::getIdentifier() const { std::string const& VariableExpression::getIdentifier() const {

2
src/storage/prism/Assignment.cpp

@ -14,7 +14,7 @@ namespace storm {
return this->expression; return this->expression;
} }
Assignment Assignment::substitute(std::map<std::string, storm::expressions::Expression> const& substitution) const {
Assignment Assignment::substitute(std::map<storm::expressions::Variable, storm::expressions::Expression> const& substitution) const {
return Assignment(this->getVariableName(), this->getExpression().substitute(substitution), this->getFilename(), this->getLineNumber()); return Assignment(this->getVariableName(), this->getExpression().substitute(substitution), this->getFilename(), this->getLineNumber());
} }

5
src/storage/prism/Assignment.h

@ -5,6 +5,7 @@
#include "src/storage/prism/LocatedInformation.h" #include "src/storage/prism/LocatedInformation.h"
#include "src/storage/expressions/Expression.h" #include "src/storage/expressions/Expression.h"
#include "src/storage/expressions/Variable.h"
#include "src/utility/OsDetection.h" #include "src/utility/OsDetection.h"
namespace storm { namespace storm {
@ -45,12 +46,12 @@ namespace storm {
storm::expressions::Expression const& getExpression() const; storm::expressions::Expression const& getExpression() const;
/*! /*!
* Substitutes all identifiers in the assignment according to the given map.
* Substitutes all variables in the assignment according to the given map.
* *
* @param substitution The substitution to perform. * @param substitution The substitution to perform.
* @return The resulting assignment. * @return The resulting assignment.
*/ */
Assignment substitute(std::map<std::string, storm::expressions::Expression> const& substitution) const;
Assignment substitute(std::map<storm::expressions::Variable, storm::expressions::Expression> const& substitution) const;
friend std::ostream& operator<<(std::ostream& stream, Assignment const& assignment); friend std::ostream& operator<<(std::ostream& stream, Assignment const& assignment);

10
src/storage/prism/BooleanVariable.cpp

@ -2,16 +2,12 @@
namespace storm { namespace storm {
namespace prism { namespace prism {
BooleanVariable::BooleanVariable(std::string const& variableName, std::string const& filename, uint_fast64_t lineNumber) : Variable(variableName, storm::expressions::Expression::createFalse(), true, filename, lineNumber) {
BooleanVariable::BooleanVariable(storm::expressions::Variable const& variable, storm::expressions::Expression const& initialValueExpression, std::string const& filename, uint_fast64_t lineNumber) : Variable(variable, initialValueExpression, false, filename, lineNumber) {
// Nothing to do here. // Nothing to do here.
} }
BooleanVariable::BooleanVariable(std::string const& variableName, storm::expressions::Expression const& initialValueExpression, std::string const& filename, uint_fast64_t lineNumber) : Variable(variableName, initialValueExpression, false, filename, lineNumber) {
// Nothing to do here.
}
BooleanVariable BooleanVariable::substitute(std::map<std::string, storm::expressions::Expression> const& substitution) const {
return BooleanVariable(this->getName(), this->getInitialValueExpression().substitute(substitution), this->getFilename(), this->getLineNumber());
BooleanVariable BooleanVariable::substitute(std::map<storm::expressions::Variable, storm::expressions::Expression> const& substitution) const {
return BooleanVariable(this->getExpressionVariable(), this->getInitialValueExpression().substitute(substitution), this->getFilename(), this->getLineNumber());
} }
std::ostream& operator<<(std::ostream& stream, BooleanVariable const& variable) { std::ostream& operator<<(std::ostream& stream, BooleanVariable const& variable) {

17
src/storage/prism/BooleanVariable.h

@ -20,23 +20,14 @@ namespace storm {
#endif #endif
/*! /*!
* Creates a boolean variable with the given name and the default initial value expression.
* Creates a boolean variable with the given constant initial value expression.
* *
* @param variableName The name of the variable.
* @param filename The filename in which the variable is defined.
* @param lineNumber The line number in which the variable is defined.
*/
BooleanVariable(std::string const& variableName, std::string const& filename = "", uint_fast64_t lineNumber = 0);
/*!
* Creates a boolean variable with the given name and the given constant initial value expression.
*
* @param variableName The name of the variable.
* @param variable The expression variable associated with this variable.
* @param initialValueExpression The constant expression that defines the initial value of the variable. * @param initialValueExpression The constant expression that defines the initial value of the variable.
* @param filename The filename in which the variable is defined. * @param filename The filename in which the variable is defined.
* @param lineNumber The line number in which the variable is defined. * @param lineNumber The line number in which the variable is defined.
*/ */
BooleanVariable(std::string const& variableName, storm::expressions::Expression const& initialValueExpression, std::string const& filename = "", uint_fast64_t lineNumber = 0);
BooleanVariable(storm::expressions::Variable const& variable, storm::expressions::Expression const& initialValueExpression, std::string const& filename = "", uint_fast64_t lineNumber = 0);
/*! /*!
* Substitutes all identifiers in the boolean variable according to the given map. * Substitutes all identifiers in the boolean variable according to the given map.
@ -44,7 +35,7 @@ namespace storm {
* @param substitution The substitution to perform. * @param substitution The substitution to perform.
* @return The resulting boolean variable. * @return The resulting boolean variable.
*/ */
BooleanVariable substitute(std::map<std::string, storm::expressions::Expression> const& substitution) const;
BooleanVariable substitute(std::map<storm::expressions::Variable, storm::expressions::Expression> const& substitution) const;
friend std::ostream& operator<<(std::ostream& stream, BooleanVariable const& variable); friend std::ostream& operator<<(std::ostream& stream, BooleanVariable const& variable);
}; };

2
src/storage/prism/Command.cpp

@ -30,7 +30,7 @@ namespace storm {
return this->globalIndex; return this->globalIndex;
} }
Command Command::substitute(std::map<std::string, storm::expressions::Expression> const& substitution) const {
Command Command::substitute(std::map<storm::expressions::Variable, storm::expressions::Expression> const& substitution) const {
std::vector<Update> newUpdates; std::vector<Update> newUpdates;
newUpdates.reserve(this->getNumberOfUpdates()); newUpdates.reserve(this->getNumberOfUpdates());
for (auto const& update : this->getUpdates()) { for (auto const& update : this->getUpdates()) {

2
src/storage/prism/Command.h

@ -82,7 +82,7 @@ namespace storm {
* @param substitution The substitution to perform. * @param substitution The substitution to perform.
* @return The resulting command. * @return The resulting command.
*/ */
Command substitute(std::map<std::string, storm::expressions::Expression> const& substitution) const;
Command substitute(std::map<storm::expressions::Variable, storm::expressions::Expression> const& substitution) const;
friend std::ostream& operator<<(std::ostream& stream, Command const& command); friend std::ostream& operator<<(std::ostream& stream, Command const& command);

26
src/storage/prism/Constant.cpp

@ -4,20 +4,24 @@
namespace storm { namespace storm {
namespace prism { namespace prism {
Constant::Constant(storm::expressions::ExpressionReturnType type, std::string const& name, storm::expressions::Expression const& expression, std::string const& filename, uint_fast64_t lineNumber) : LocatedInformation(filename, lineNumber), type(type), name(name), defined(true), expression(expression) {
Constant::Constant(storm::expressions::Variable const& variable, storm::expressions::Expression const& expression, std::string const& filename, uint_fast64_t lineNumber) : LocatedInformation(filename, lineNumber), variable(variable), defined(true), expression(expression) {
// Intentionally left empty. // Intentionally left empty.
} }
Constant::Constant(storm::expressions::ExpressionReturnType type, std::string const& name, std::string const& filename, uint_fast64_t lineNumber) : LocatedInformation(filename, lineNumber), type(type), name(name), defined(false), expression() {
Constant::Constant(storm::expressions::Variable const& variable, std::string const& filename, uint_fast64_t lineNumber) : LocatedInformation(filename, lineNumber), variable(variable), defined(false), expression() {
// Intentionally left empty. // Intentionally left empty.
} }
std::string const& Constant::getName() const { std::string const& Constant::getName() const {
return this->name;
return this->variable.getName();
} }
storm::expressions::ExpressionReturnType Constant::getType() const {
return this->type;
storm::expressions::Type const& Constant::getType() const {
return this->getExpressionVariable().getType();
}
storm::expressions::Variable const& Constant::getExpressionVariable() const {
return this->variable;
} }
bool Constant::isDefined() const { bool Constant::isDefined() const {
@ -29,18 +33,12 @@ namespace storm {
return this->expression; return this->expression;
} }
Constant Constant::substitute(std::map<std::string, storm::expressions::Expression> const& substitution) const {
return Constant(this->getType(), this->getName(), this->getExpression().substitute(substitution), this->getFilename(), this->getLineNumber());
Constant Constant::substitute(std::map<storm::expressions::Variable, storm::expressions::Expression> const& substitution) const {
return Constant(variable, this->getExpression().substitute(substitution), this->getFilename(), this->getLineNumber());
} }
std::ostream& operator<<(std::ostream& stream, Constant const& constant) { std::ostream& operator<<(std::ostream& stream, Constant const& constant) {
stream << "const ";
switch (constant.getType()) {
case storm::expressions::ExpressionReturnType::Undefined: stream << "undefined "; break;
case storm::expressions::ExpressionReturnType::Bool: stream << "bool "; break;
case storm::expressions::ExpressionReturnType::Int: stream << "int "; break;
case storm::expressions::ExpressionReturnType::Double: stream << "double "; break;
}
stream << "const " << constant.getExpressionVariable().getType();
stream << constant.getName(); stream << constant.getName();
if (constant.isDefined()) { if (constant.isDefined()) {
stream << " = " << constant.getExpression(); stream << " = " << constant.getExpression();

33
src/storage/prism/Constant.h

@ -5,6 +5,7 @@
#include "src/storage/prism/LocatedInformation.h" #include "src/storage/prism/LocatedInformation.h"
#include "src/storage/expressions/Expression.h" #include "src/storage/expressions/Expression.h"
#include "src/storage/expressions/Variable.h"
#include "src/utility/OsDetection.h" #include "src/utility/OsDetection.h"
namespace storm { namespace storm {
@ -12,25 +13,23 @@ namespace storm {
class Constant : public LocatedInformation { class Constant : public LocatedInformation {
public: public:
/*! /*!
* Creates a constant with the given type, name and defining expression.
* Creates a defined constant.
* *
* @param type The type of the constant.
* @param name The name of the constant.
* @param variable The expression variable associated with the constant.
* @param expression The expression that defines the constant. * @param expression The expression that defines the constant.
* @param filename The filename in which the transition reward is defined. * @param filename The filename in which the transition reward is defined.
* @param lineNumber The line number in which the transition reward is defined. * @param lineNumber The line number in which the transition reward is defined.
*/ */
Constant(storm::expressions::ExpressionReturnType type, std::string const& name, storm::expressions::Expression const& expression, std::string const& filename = "", uint_fast64_t lineNumber = 0);
Constant(storm::expressions::Variable const& variable, storm::expressions::Expression const& expression, std::string const& filename = "", uint_fast64_t lineNumber = 0);
/*! /*!
* Creates an undefined constant with the given type and name.
* Creates an undefined constant.
* *
* @param constantType The type of the constant.
* @param constantName The name of the constant.
* @param variable The expression variable associated with the constant.
* @param filename The filename in which the transition reward is defined. * @param filename The filename in which the transition reward is defined.
* @param lineNumber The line number in which the transition reward is defined. * @param lineNumber The line number in which the transition reward is defined.
*/ */
Constant(storm::expressions::ExpressionReturnType constantType, std::string const& constantName, std::string const& filename = "", uint_fast64_t lineNumber = 0);
Constant(storm::expressions::Variable const& variable, std::string const& filename = "", uint_fast64_t lineNumber = 0);
// Create default implementations of constructors/assignment. // Create default implementations of constructors/assignment.
Constant() = default; Constant() = default;
@ -53,7 +52,14 @@ namespace storm {
* *
* @return The type of the constant; * @return The type of the constant;
*/ */
storm::expressions::ExpressionReturnType getType() const;
storm::expressions::Type const& getType() const;
/*!
* Retrieves the expression variable associated with this constant.
*
* @return The expression variable associated with this constant.
*/
storm::expressions::Variable const& getExpressionVariable() const;
/*! /*!
* Retrieves whether the constant is defined, i.e., whether there is an expression defining its value. * Retrieves whether the constant is defined, i.e., whether there is an expression defining its value.
@ -76,16 +82,13 @@ namespace storm {
* @param substitution The substitution to perform. * @param substitution The substitution to perform.
* @return The resulting constant. * @return The resulting constant.
*/ */
Constant substitute(std::map<std::string, storm::expressions::Expression> const& substitution) const;
Constant substitute(std::map<storm::expressions::Variable, storm::expressions::Expression> const& substitution) const;
friend std::ostream& operator<<(std::ostream& stream, Constant const& constant); friend std::ostream& operator<<(std::ostream& stream, Constant const& constant);
private: private:
// The type of the constant.
storm::expressions::ExpressionReturnType type;
// The name of the constant.
std::string name;
// The expression variable associated with the constant.
storm::expressions::Variable variable;
// A flag that stores whether or not the constant is defined. // A flag that stores whether or not the constant is defined.
bool defined; bool defined;

6
src/storage/prism/Formula.cpp

@ -14,11 +14,11 @@ namespace storm {
return this->expression; return this->expression;
} }
storm::expressions::ExpressionReturnType Formula::getType() const {
return this->getExpression().getReturnType();
storm::expressions::Type const& Formula::getType() const {
return this->getExpression().getType();
} }
Formula Formula::substitute(std::map<std::string, storm::expressions::Expression> const& substitution) const {
Formula Formula::substitute(std::map<storm::expressions::Variable, storm::expressions::Expression> const& substitution) const {
return Formula(this->getName(), this->getExpression().substitute(substitution), this->getFilename(), this->getLineNumber()); return Formula(this->getName(), this->getExpression().substitute(substitution), this->getFilename(), this->getLineNumber());
} }

7
src/storage/prism/Formula.h

@ -5,6 +5,7 @@
#include "src/storage/prism/LocatedInformation.h" #include "src/storage/prism/LocatedInformation.h"
#include "src/storage/expressions/Expression.h" #include "src/storage/expressions/Expression.h"
#include "src/storage/expressions/Variable.h"
#include "src/utility/OsDetection.h" #include "src/utility/OsDetection.h"
namespace storm { namespace storm {
@ -49,15 +50,15 @@ namespace storm {
* *
* @return The return type of the formula. * @return The return type of the formula.
*/ */
storm::expressions::ExpressionReturnType getType() const;
storm::expressions::Type const& getType() const;
/*! /*!
* Substitutes all identifiers in the expression of the formula according to the given map.
* Substitutes all variables in the expression of the formula according to the given map.
* *
* @param substitution The substitution to perform. * @param substitution The substitution to perform.
* @return The resulting formula. * @return The resulting formula.
*/ */
Formula substitute(std::map<std::string, storm::expressions::Expression> const& substitution) const;
Formula substitute(std::map<storm::expressions::Variable, storm::expressions::Expression> const& substitution) const;
friend std::ostream& operator<<(std::ostream& stream, Formula const& formula); friend std::ostream& operator<<(std::ostream& stream, Formula const& formula);

2
src/storage/prism/InitialConstruct.cpp

@ -10,7 +10,7 @@ namespace storm {
return this->initialStatesExpression; return this->initialStatesExpression;
} }
InitialConstruct InitialConstruct::substitute(std::map<std::string, storm::expressions::Expression> const& substitution) const {
InitialConstruct InitialConstruct::substitute(std::map<storm::expressions::Variable, storm::expressions::Expression> const& substitution) const {
return InitialConstruct(this->getInitialStatesExpression().substitute(substitution)); return InitialConstruct(this->getInitialStatesExpression().substitute(substitution));
} }

3
src/storage/prism/InitialConstruct.h

@ -5,6 +5,7 @@
#include "src/storage/prism/LocatedInformation.h" #include "src/storage/prism/LocatedInformation.h"
#include "src/storage/expressions/Expression.h" #include "src/storage/expressions/Expression.h"
#include "src/storage/expressions/Variable.h"
#include "src/utility/OsDetection.h" #include "src/utility/OsDetection.h"
namespace storm { namespace storm {
@ -42,7 +43,7 @@ namespace storm {
* @param substitution The substitution to perform. * @param substitution The substitution to perform.
* @return The resulting initial construct. * @return The resulting initial construct.
*/ */
InitialConstruct substitute(std::map<std::string, storm::expressions::Expression> const& substitution) const;
InitialConstruct substitute(std::map<storm::expressions::Variable, storm::expressions::Expression> const& substitution) const;
friend std::ostream& operator<<(std::ostream& stream, InitialConstruct const& initialConstruct); friend std::ostream& operator<<(std::ostream& stream, InitialConstruct const& initialConstruct);

10
src/storage/prism/IntegerVariable.cpp

@ -2,11 +2,7 @@
namespace storm { namespace storm {
namespace prism { namespace prism {
IntegerVariable::IntegerVariable(std::string const& name, storm::expressions::Expression const& lowerBoundExpression, storm::expressions::Expression const& upperBoundExpression, std::string const& filename, uint_fast64_t lineNumber) : Variable(name, lowerBoundExpression, true, filename, lineNumber), lowerBoundExpression(lowerBoundExpression), upperBoundExpression(upperBoundExpression) {
// Intentionally left empty.
}
IntegerVariable::IntegerVariable(std::string const& name, storm::expressions::Expression const& lowerBoundExpression, storm::expressions::Expression const& upperBoundExpression, storm::expressions::Expression const& initialValueExpression, std::string const& filename, uint_fast64_t lineNumber) : Variable(name, initialValueExpression, false, filename, lineNumber), lowerBoundExpression(lowerBoundExpression), upperBoundExpression(upperBoundExpression) {
IntegerVariable::IntegerVariable(storm::expressions::Variable const& variable, storm::expressions::Expression const& lowerBoundExpression, storm::expressions::Expression const& upperBoundExpression, storm::expressions::Expression const& initialValueExpression, std::string const& filename, uint_fast64_t lineNumber) : Variable(variable, initialValueExpression, false, filename, lineNumber), lowerBoundExpression(lowerBoundExpression), upperBoundExpression(upperBoundExpression) {
// Intentionally left empty. // Intentionally left empty.
} }
@ -18,8 +14,8 @@ namespace storm {
return this->upperBoundExpression; return this->upperBoundExpression;
} }
IntegerVariable IntegerVariable::substitute(std::map<std::string, storm::expressions::Expression> const& substitution) const {
return IntegerVariable(this->getName(), this->getLowerBoundExpression().substitute(substitution), this->getUpperBoundExpression().substitute(substitution), this->getInitialValueExpression().substitute(substitution), this->getFilename(), this->getLineNumber());
IntegerVariable IntegerVariable::substitute(std::map<storm::expressions::Variable, storm::expressions::Expression> const& substitution) const {
return IntegerVariable(this->getExpressionVariable(), this->getLowerBoundExpression().substitute(substitution), this->getUpperBoundExpression().substitute(substitution), this->getInitialValueExpression().substitute(substitution), this->getFilename(), this->getLineNumber());
} }
std::ostream& operator<<(std::ostream& stream, IntegerVariable const& variable) { std::ostream& operator<<(std::ostream& stream, IntegerVariable const& variable) {

19
src/storage/prism/IntegerVariable.h

@ -20,27 +20,16 @@ namespace storm {
#endif #endif
/*! /*!
* Creates an integer variable with the given name and a default initial value.
* Creates an integer variable with the given initial value expression.
* *
* @param name The name of the variable.
* @param lowerBoundExpression A constant expression defining the lower bound of the domain of the variable.
* @param upperBoundExpression A constant expression defining the upper bound of the domain of the variable.
* @param filename The filename in which the variable is defined.
* @param lineNumber The line number in which the variable is defined.
*/
IntegerVariable(std::string const& name, storm::expressions::Expression const& lowerBoundExpression, storm::expressions::Expression const& upperBoundExpression, std::string const& filename = "", uint_fast64_t lineNumber = 0);
/*!
* Creates an integer variable with the given name and the given initial value expression.
*
* @param name The name of the variable.
* @param variable The expression variable associated with this variable.
* @param lowerBoundExpression A constant expression defining the lower bound of the domain of the variable. * @param lowerBoundExpression A constant expression defining the lower bound of the domain of the variable.
* @param upperBoundExpression A constant expression defining the upper bound of the domain of the variable. * @param upperBoundExpression A constant expression defining the upper bound of the domain of the variable.
* @param initialValueExpression A constant expression that defines the initial value of the variable. * @param initialValueExpression A constant expression that defines the initial value of the variable.
* @param filename The filename in which the variable is defined. * @param filename The filename in which the variable is defined.
* @param lineNumber The line number in which the variable is defined. * @param lineNumber The line number in which the variable is defined.
*/ */
IntegerVariable(std::string const& name, storm::expressions::Expression const& lowerBoundExpression, storm::expressions::Expression const& upperBoundExpression, storm::expressions::Expression const& initialValueExpression, std::string const& filename = "", uint_fast64_t lineNumber = 0);
IntegerVariable(storm::expressions::Variable const& variable, storm::expressions::Expression const& lowerBoundExpression, storm::expressions::Expression const& upperBoundExpression, storm::expressions::Expression const& initialValueExpression, std::string const& filename = "", uint_fast64_t lineNumber = 0);
/*! /*!
* Retrieves an expression defining the lower bound for this integer variable. * Retrieves an expression defining the lower bound for this integer variable.
@ -62,7 +51,7 @@ namespace storm {
* @param substitution The substitution to perform. * @param substitution The substitution to perform.
* @return The resulting boolean variable. * @return The resulting boolean variable.
*/ */
IntegerVariable substitute(std::map<std::string, storm::expressions::Expression> const& substitution) const;
IntegerVariable substitute(std::map<storm::expressions::Variable, storm::expressions::Expression> const& substitution) const;
friend std::ostream& operator<<(std::ostream& stream, IntegerVariable const& variable); friend std::ostream& operator<<(std::ostream& stream, IntegerVariable const& variable);

2
src/storage/prism/Label.cpp

@ -14,7 +14,7 @@ namespace storm {
return this->statePredicateExpression; return this->statePredicateExpression;
} }
Label Label::substitute(std::map<std::string, storm::expressions::Expression> const& substitution) const {
Label Label::substitute(std::map<storm::expressions::Variable, storm::expressions::Expression> const& substitution) const {
return Label(this->getName(), this->getStatePredicateExpression().substitute(substitution), this->getFilename(), this->getLineNumber()); return Label(this->getName(), this->getStatePredicateExpression().substitute(substitution), this->getFilename(), this->getLineNumber());
} }

3
src/storage/prism/Label.h

@ -5,6 +5,7 @@
#include "src/storage/prism/LocatedInformation.h" #include "src/storage/prism/LocatedInformation.h"
#include "src/storage/expressions/Expression.h" #include "src/storage/expressions/Expression.h"
#include "src/storage/expressions/Variable.h"
#include "src/utility/OsDetection.h" #include "src/utility/OsDetection.h"
namespace storm { namespace storm {
@ -51,7 +52,7 @@ namespace storm {
* @param substitution The substitution to perform. * @param substitution The substitution to perform.
* @return The resulting label. * @return The resulting label.
*/ */
Label substitute(std::map<std::string, storm::expressions::Expression> const& substitution) const;
Label substitute(std::map<storm::expressions::Variable, storm::expressions::Expression> const& substitution) const;
friend std::ostream& operator<<(std::ostream& stream, Label const& label); friend std::ostream& operator<<(std::ostream& stream, Label const& label);

50
src/storage/prism/Program.cpp

@ -2,34 +2,35 @@
#include <algorithm> #include <algorithm>
#include "src/storage/expressions/ExpressionManager.h"
#include "src/utility/macros.h" #include "src/utility/macros.h"
#include "exceptions/InvalidArgumentException.h"
#include "src/exceptions/InvalidArgumentException.h"
#include "src/exceptions/OutOfRangeException.h" #include "src/exceptions/OutOfRangeException.h"
#include "src/exceptions/WrongFormatException.h" #include "src/exceptions/WrongFormatException.h"
#include "src/exceptions/InvalidTypeException.h" #include "src/exceptions/InvalidTypeException.h"
namespace storm { namespace storm {
namespace prism { namespace prism {
Program::Program(ModelType modelType, std::vector<Constant> const& constants, std::vector<BooleanVariable> const& globalBooleanVariables, std::vector<IntegerVariable> const& globalIntegerVariables, std::vector<Formula> const& formulas, std::vector<Module> const& modules, std::vector<RewardModel> const& rewardModels, bool fixInitialConstruct, storm::prism::InitialConstruct const& initialConstruct, std::vector<Label> const& labels, std::string const& filename, uint_fast64_t lineNumber, bool checkValidity) : LocatedInformation(filename, lineNumber), modelType(modelType), constants(constants), constantToIndexMap(), globalBooleanVariables(globalBooleanVariables), globalBooleanVariableToIndexMap(), globalIntegerVariables(globalIntegerVariables), globalIntegerVariableToIndexMap(), formulas(formulas), formulaToIndexMap(), modules(modules), moduleToIndexMap(), rewardModels(rewardModels), rewardModelToIndexMap(), initialConstruct(initialConstruct), labels(labels), labelToIndexMap(), actions(), actionsToModuleIndexMap(), variableToModuleIndexMap() {
Program::Program(std::shared_ptr<storm::expressions::ExpressionManager> manager, ModelType modelType, std::vector<Constant> const& constants, std::vector<BooleanVariable> const& globalBooleanVariables, std::vector<IntegerVariable> const& globalIntegerVariables, std::vector<Formula> const& formulas, std::vector<Module> const& modules, std::vector<RewardModel> const& rewardModels, bool fixInitialConstruct, storm::prism::InitialConstruct const& initialConstruct, std::vector<Label> const& labels, std::string const& filename, uint_fast64_t lineNumber, bool checkValidity) : LocatedInformation(filename, lineNumber), manager(manager), modelType(modelType), constants(constants), constantToIndexMap(), globalBooleanVariables(globalBooleanVariables), globalBooleanVariableToIndexMap(), globalIntegerVariables(globalIntegerVariables), globalIntegerVariableToIndexMap(), formulas(formulas), formulaToIndexMap(), modules(modules), moduleToIndexMap(), rewardModels(rewardModels), rewardModelToIndexMap(), initialConstruct(initialConstruct), labels(labels), labelToIndexMap(), actions(), actionsToModuleIndexMap(), variableToModuleIndexMap() {
this->createMappings(); this->createMappings();
// Create a new initial construct if the corresponding flag was set. // Create a new initial construct if the corresponding flag was set.
if (fixInitialConstruct) { if (fixInitialConstruct) {
if (this->getInitialConstruct().getInitialStatesExpression().isFalse()) { if (this->getInitialConstruct().getInitialStatesExpression().isFalse()) {
storm::expressions::Expression newInitialExpression = storm::expressions::Expression::createTrue();
storm::expressions::Expression newInitialExpression = manager->boolean(true);
for (auto const& booleanVariable : this->getGlobalBooleanVariables()) { for (auto const& booleanVariable : this->getGlobalBooleanVariables()) {
newInitialExpression = newInitialExpression && (storm::expressions::Expression::createBooleanVariable(booleanVariable.getName()).iff(booleanVariable.getInitialValueExpression()));
newInitialExpression = newInitialExpression && booleanVariable.getExpression().iff(booleanVariable.getInitialValueExpression());
} }
for (auto const& integerVariable : this->getGlobalIntegerVariables()) { for (auto const& integerVariable : this->getGlobalIntegerVariables()) {
newInitialExpression = newInitialExpression && (storm::expressions::Expression::createIntegerVariable(integerVariable.getName()) == integerVariable.getInitialValueExpression());
newInitialExpression = newInitialExpression && integerVariable.getExpression() == integerVariable.getInitialValueExpression();
} }
for (auto const& module : this->getModules()) { for (auto const& module : this->getModules()) {
for (auto const& booleanVariable : module.getBooleanVariables()) { for (auto const& booleanVariable : module.getBooleanVariables()) {
newInitialExpression = newInitialExpression && (storm::expressions::Expression::createBooleanVariable(booleanVariable.getName()).iff(booleanVariable.getInitialValueExpression()));
newInitialExpression = newInitialExpression && booleanVariable.getExpression().iff(booleanVariable.getInitialValueExpression());
} }
for (auto const& integerVariable : module.getIntegerVariables()) { for (auto const& integerVariable : module.getIntegerVariables()) {
newInitialExpression = newInitialExpression && (storm::expressions::Expression::createIntegerVariable(integerVariable.getName()) == integerVariable.getInitialValueExpression());
newInitialExpression = newInitialExpression && integerVariable.getExpression() == integerVariable.getInitialValueExpression();
} }
} }
this->initialConstruct = storm::prism::InitialConstruct(newInitialExpression, this->getInitialConstruct().getFilename(), this->getInitialConstruct().getLineNumber()); this->initialConstruct = storm::prism::InitialConstruct(newInitialExpression, this->getInitialConstruct().getFilename(), this->getInitialConstruct().getLineNumber());
@ -189,7 +190,7 @@ namespace storm {
newModules.push_back(module.restrictCommands(indexSet)); newModules.push_back(module.restrictCommands(indexSet));
} }
return Program(this->getModelType(), this->getConstants(), this->getGlobalBooleanVariables(), this->getGlobalIntegerVariables(), this->getFormulas(), newModules, this->getRewardModels(), false, this->getInitialConstruct(), this->getLabels());
return Program(this->manager, this->getModelType(), this->getConstants(), this->getGlobalBooleanVariables(), this->getGlobalIntegerVariables(), this->getFormulas(), newModules, this->getRewardModels(), false, this->getInitialConstruct(), this->getLabels());
} }
void Program::createMappings() { void Program::createMappings() {
@ -240,10 +241,9 @@ namespace storm {
} }
Program Program::defineUndefinedConstants(std::map<std::string, storm::expressions::Expression> const& constantDefinitions) const {
// For sanity checking, we keep track of all undefined constants that we define in the course of this
// procedure.
std::set<std::string> definedUndefinedConstants;
Program Program::defineUndefinedConstants(std::map<storm::expressions::Variable, storm::expressions::Expression> const& constantDefinitions) const {
// For sanity checking, we keep track of all undefined constants that we define in the course of this procedure.
std::set<storm::expressions::Variable> definedUndefinedConstants;
std::vector<Constant> newConstants; std::vector<Constant> newConstants;
newConstants.reserve(this->getNumberOfConstants()); newConstants.reserve(this->getNumberOfConstants());
@ -252,25 +252,25 @@ namespace storm {
// defining expression // defining expression
if (constant.isDefined()) { if (constant.isDefined()) {
// Make sure we are not trying to define an already defined constant. // Make sure we are not trying to define an already defined constant.
STORM_LOG_THROW(constantDefinitions.find(constant.getName()) == constantDefinitions.end(), storm::exceptions::InvalidArgumentException, "Illegally defining already defined constant '" << constant.getName() << "'.");
STORM_LOG_THROW(constantDefinitions.find(constant.getExpressionVariable()) == constantDefinitions.end(), storm::exceptions::InvalidArgumentException, "Illegally defining already defined constant '" << constant.getName() << "'.");
// Now replace the occurrences of undefined constants in its defining expression. // Now replace the occurrences of undefined constants in its defining expression.
newConstants.emplace_back(constant.getType(), constant.getName(), constant.getExpression().substitute(constantDefinitions), constant.getFilename(), constant.getLineNumber()); newConstants.emplace_back(constant.getType(), constant.getName(), constant.getExpression().substitute(constantDefinitions), constant.getFilename(), constant.getLineNumber());
} else { } else {
auto const& variableExpressionPair = constantDefinitions.find(constant.getName());
auto const& variableExpressionPair = constantDefinitions.find(constant.getExpressionVariable());
// If the constant is not defined by the mapping, we leave it like it is. // If the constant is not defined by the mapping, we leave it like it is.
if (variableExpressionPair == constantDefinitions.end()) { if (variableExpressionPair == constantDefinitions.end()) {
newConstants.emplace_back(constant); newConstants.emplace_back(constant);
} else { } else {
// Otherwise, we add it to the defined constants and assign it the appropriate expression. // Otherwise, we add it to the defined constants and assign it the appropriate expression.
definedUndefinedConstants.insert(constant.getName());
definedUndefinedConstants.insert(constant.getExpressionVariable());
// Make sure the type of the constant is correct. // Make sure the type of the constant is correct.
STORM_LOG_THROW(variableExpressionPair->second.getReturnType() == constant.getType(), storm::exceptions::InvalidArgumentException, "Illegal type of expression defining constant '" << constant.getName() << "'.");
STORM_LOG_THROW(variableExpressionPair->second.getType() == constant.getType(), storm::exceptions::InvalidArgumentException, "Illegal type of expression defining constant '" << constant.getName() << "'.");
// Now create the defined constant. // Now create the defined constant.
newConstants.emplace_back(constant.getType(), constant.getName(), variableExpressionPair->second, constant.getFilename(), constant.getLineNumber());
newConstants.emplace_back(constant.getExpressionVariable(), variableExpressionPair->second, constant.getFilename(), constant.getLineNumber());
} }
} }
} }
@ -281,19 +281,19 @@ namespace storm {
STORM_LOG_THROW(definedUndefinedConstants.find(constantExpressionPair.first) != definedUndefinedConstants.end(), storm::exceptions::InvalidArgumentException, "Unable to define non-existant constant."); STORM_LOG_THROW(definedUndefinedConstants.find(constantExpressionPair.first) != definedUndefinedConstants.end(), storm::exceptions::InvalidArgumentException, "Unable to define non-existant constant.");
} }
return Program(this->getModelType(), newConstants, this->getGlobalBooleanVariables(), this->getGlobalIntegerVariables(), this->getFormulas(), this->getModules(), this->getRewardModels(), false, this->getInitialConstruct(), this->getLabels());
return Program(this->manager, this->getModelType(), newConstants, this->getGlobalBooleanVariables(), this->getGlobalIntegerVariables(), this->getFormulas(), this->getModules(), this->getRewardModels(), false, this->getInitialConstruct(), this->getLabels());
} }
Program Program::substituteConstants() const { Program Program::substituteConstants() const {
// We start by creating the appropriate substitution // We start by creating the appropriate substitution
std::map<std::string, storm::expressions::Expression> constantSubstitution;
std::map<storm::expressions::Variable, storm::expressions::Expression> constantSubstitution;
std::vector<Constant> newConstants(this->getConstants()); std::vector<Constant> newConstants(this->getConstants());
for (uint_fast64_t constantIndex = 0; constantIndex < newConstants.size(); ++constantIndex) { for (uint_fast64_t constantIndex = 0; constantIndex < newConstants.size(); ++constantIndex) {
auto const& constant = newConstants[constantIndex]; auto const& constant = newConstants[constantIndex];
STORM_LOG_THROW(constant.isDefined(), storm::exceptions::InvalidArgumentException, "Cannot substitute constants in program that contains undefined constants."); STORM_LOG_THROW(constant.isDefined(), storm::exceptions::InvalidArgumentException, "Cannot substitute constants in program that contains undefined constants.");
// Put the corresponding expression in the substitution. // Put the corresponding expression in the substitution.
constantSubstitution.emplace(constant.getName(), constant.getExpression());
constantSubstitution.emplace(constant.getExpressionVariable(), constant.getExpression());
// If there is at least one more constant to come, we substitute the costants we have so far. // If there is at least one more constant to come, we substitute the costants we have so far.
if (constantIndex + 1 < newConstants.size()) { if (constantIndex + 1 < newConstants.size()) {
@ -340,7 +340,7 @@ namespace storm {
newLabels.emplace_back(label.substitute(constantSubstitution)); newLabels.emplace_back(label.substitute(constantSubstitution));
} }
return Program(this->getModelType(), newConstants, newBooleanVariables, newIntegerVariables, newFormulas, newModules, newRewardModels, false, newInitialConstruct, newLabels);
return Program(this->manager, this->getModelType(), newConstants, newBooleanVariables, newIntegerVariables, newFormulas, newModules, newRewardModels, false, newInitialConstruct, newLabels);
} }
void Program::checkValidity() const { void Program::checkValidity() const {
@ -571,6 +571,14 @@ namespace storm {
} }
} }
storm::expressions::ExpressionManager const& Program::getManager() const {
return this->manager;
}
storm::expressions::ExpressionManager& Program::getManager() {
return this->manager;
}
std::ostream& operator<<(std::ostream& stream, Program const& program) { std::ostream& operator<<(std::ostream& stream, Program const& program) {
switch (program.getModelType()) { switch (program.getModelType()) {
case Program::ModelType::UNDEFINED: stream << "undefined"; break; case Program::ModelType::UNDEFINED: stream << "undefined"; break;

24
src/storage/prism/Program.h

@ -28,6 +28,7 @@ namespace storm {
* Creates a program with the given model type, undefined constants, global variables, modules, reward * Creates a program with the given model type, undefined constants, global variables, modules, reward
* models, labels and initial states. * models, labels and initial states.
* *
* @param manager The manager responsible for the variables and expressions of the program.
* @param modelType The type of the program. * @param modelType The type of the program.
* @param constants The constants of the program. * @param constants The constants of the program.
* @param globalBooleanVariables The global boolean variables of the program. * @param globalBooleanVariables The global boolean variables of the program.
@ -45,7 +46,7 @@ namespace storm {
* @param lineNumber The line number in which the program is defined. * @param lineNumber The line number in which the program is defined.
* @param checkValidity If set to true, the program is checked for validity. * @param checkValidity If set to true, the program is checked for validity.
*/ */
Program(ModelType modelType, std::vector<Constant> const& constants, std::vector<BooleanVariable> const& globalBooleanVariables, std::vector<IntegerVariable> const& globalIntegerVariables, std::vector<Formula> const& formulas, std::vector<Module> const& modules, std::vector<RewardModel> const& rewardModels, bool fixInitialConstruct, storm::prism::InitialConstruct const& initialConstruct, std::vector<Label> const& labels, std::string const& filename = "", uint_fast64_t lineNumber = 0, bool checkValidity = true);
Program(std::shared_ptr<storm::expressions::ExpressionManager> manager, ModelType modelType, std::vector<Constant> const& constants, std::vector<BooleanVariable> const& globalBooleanVariables, std::vector<IntegerVariable> const& globalIntegerVariables, std::vector<Formula> const& formulas, std::vector<Module> const& modules, std::vector<RewardModel> const& rewardModels, bool fixInitialConstruct, storm::prism::InitialConstruct const& initialConstruct, std::vector<Label> const& labels, std::string const& filename = "", uint_fast64_t lineNumber = 0, bool checkValidity = true);
// Provide default implementations for constructors and assignments. // Provide default implementations for constructors and assignments.
Program() = default; Program() = default;
@ -288,12 +289,12 @@ namespace storm {
/*! /*!
* Defines the undefined constants according to the given map and returns the resulting program. * Defines the undefined constants according to the given map and returns the resulting program.
* *
* @param constantDefinitions A mapping from undefined constant names to the expressions they are supposed
* @param constantDefinitions A mapping from undefined constant to the expressions they are supposed
* to be replaced with. * to be replaced with.
* @return The program after all undefined constants in the given map have been replaced with their * @return The program after all undefined constants in the given map have been replaced with their
* definitions. * definitions.
*/ */
Program defineUndefinedConstants(std::map<std::string, storm::expressions::Expression> const& constantDefinitions) const;
Program defineUndefinedConstants(std::map<storm::expressions::Variable, storm::expressions::Expression> const& constantDefinitions) const;
/*! /*!
* Substitutes all constants appearing in the expressions of the program by their defining expressions. For * Substitutes all constants appearing in the expressions of the program by their defining expressions. For
@ -311,7 +312,24 @@ namespace storm {
friend std::ostream& operator<<(std::ostream& stream, Program const& program); friend std::ostream& operator<<(std::ostream& stream, Program const& program);
/*!
* Retrieves the manager responsible for the expressions of this program.
*
* @return The manager responsible for the expressions of this program.
*/
storm::expressions::ExpressionManager const& getManager() const;
/*!
* Retrieves the manager responsible for the expressions of this program.
*
* @return The manager responsible for the expressions of this program.
*/
storm::expressions::ExpressionManager& getManager();
private: private:
// The manager responsible for the variables/expressions of the program.
std::shared_ptr<storm::expressions::ExpressionManager> manager;
// Creates the internal mappings. // Creates the internal mappings.
void createMappings(); void createMappings();

2
src/storage/prism/StateReward.cpp

@ -14,7 +14,7 @@ namespace storm {
return this->rewardValueExpression; return this->rewardValueExpression;
} }
StateReward StateReward::substitute(std::map<std::string, storm::expressions::Expression> const& substitution) const {
StateReward StateReward::substitute(std::map<storm::expressions::Variable, storm::expressions::Expression> const& substitution) const {
return StateReward(this->getStatePredicateExpression().substitute(substitution), this->getRewardValueExpression().substitute(substitution), this->getFilename(), this->getLineNumber()); return StateReward(this->getStatePredicateExpression().substitute(substitution), this->getRewardValueExpression().substitute(substitution), this->getFilename(), this->getLineNumber());
} }

3
src/storage/prism/StateReward.h

@ -5,6 +5,7 @@
#include "src/storage/prism/LocatedInformation.h" #include "src/storage/prism/LocatedInformation.h"
#include "src/storage/expressions/Expression.h" #include "src/storage/expressions/Expression.h"
#include "src/storage/expressions/Variable.h"
#include "src/utility/OsDetection.h" #include "src/utility/OsDetection.h"
namespace storm { namespace storm {
@ -51,7 +52,7 @@ namespace storm {
* @param substitution The substitution to perform. * @param substitution The substitution to perform.
* @return The resulting state reward. * @return The resulting state reward.
*/ */
StateReward substitute(std::map<std::string, storm::expressions::Expression> const& substitution) const;
StateReward substitute(std::map<storm::expressions::Variable, storm::expressions::Expression> const& substitution) const;
friend std::ostream& operator<<(std::ostream& stream, StateReward const& stateReward); friend std::ostream& operator<<(std::ostream& stream, StateReward const& stateReward);

2
src/storage/prism/TransitionReward.cpp

@ -18,7 +18,7 @@ namespace storm {
return this->rewardValueExpression; return this->rewardValueExpression;
} }
TransitionReward TransitionReward::substitute(std::map<std::string, storm::expressions::Expression> const& substitution) const {
TransitionReward TransitionReward::substitute(std::map<storm::expressions::Variable, storm::expressions::Expression> const& substitution) const {
return TransitionReward(this->getActionName(), this->getStatePredicateExpression().substitute(substitution), this->getRewardValueExpression().substitute(substitution), this->getFilename(), this->getLineNumber()); return TransitionReward(this->getActionName(), this->getStatePredicateExpression().substitute(substitution), this->getRewardValueExpression().substitute(substitution), this->getFilename(), this->getLineNumber());
} }

3
src/storage/prism/TransitionReward.h

@ -5,6 +5,7 @@
#include "src/storage/prism/LocatedInformation.h" #include "src/storage/prism/LocatedInformation.h"
#include "src/storage/expressions/Expression.h" #include "src/storage/expressions/Expression.h"
#include "src/storage/expressions/Variable.h"
#include "src/utility/OsDetection.h" #include "src/utility/OsDetection.h"
namespace storm { namespace storm {
@ -60,7 +61,7 @@ namespace storm {
* @param substitution The substitution to perform. * @param substitution The substitution to perform.
* @return The resulting transition reward. * @return The resulting transition reward.
*/ */
TransitionReward substitute(std::map<std::string, storm::expressions::Expression> const& substitution) const;
TransitionReward substitute(std::map<storm::expressions::Variable, storm::expressions::Expression> const& substitution) const;
friend std::ostream& operator<<(std::ostream& stream, TransitionReward const& transitionReward); friend std::ostream& operator<<(std::ostream& stream, TransitionReward const& transitionReward);

2
src/storage/prism/Update.cpp

@ -37,7 +37,7 @@ namespace storm {
} }
} }
Update Update::substitute(std::map<std::string, storm::expressions::Expression> const& substitution) const {
Update Update::substitute(std::map<storm::expressions::Variable, storm::expressions::Expression> const& substitution) const {
std::vector<Assignment> newAssignments; std::vector<Assignment> newAssignments;
newAssignments.reserve(this->getNumberOfAssignments()); newAssignments.reserve(this->getNumberOfAssignments());
for (auto const& assignment : this->getAssignments()) { for (auto const& assignment : this->getAssignments()) {

2
src/storage/prism/Update.h

@ -72,7 +72,7 @@ namespace storm {
* @param substitution The substitution to perform. * @param substitution The substitution to perform.
* @return The resulting update. * @return The resulting update.
*/ */
Update substitute(std::map<std::string, storm::expressions::Expression> const& substitution) const;
Update substitute(std::map<storm::expressions::Variable, storm::expressions::Expression> const& substitution) const;
friend std::ostream& operator<<(std::ostream& stream, Update const& assignment); friend std::ostream& operator<<(std::ostream& stream, Update const& assignment);

16
src/storage/prism/Variable.cpp

@ -1,19 +1,20 @@
#include <map> #include <map>
#include "src/storage/prism/Variable.h" #include "src/storage/prism/Variable.h"
#include "src/storage/expressions/ExpressionManager.h"
namespace storm { namespace storm {
namespace prism { namespace prism {
Variable::Variable(std::string const& name, storm::expressions::Expression const& initialValueExpression, bool defaultInitialValue, std::string const& filename, uint_fast64_t lineNumber) : LocatedInformation(filename, lineNumber), name(name), initialValueExpression(initialValueExpression), defaultInitialValue(defaultInitialValue) {
Variable::Variable(storm::expressions::Variable const& variable, storm::expressions::Expression const& initialValueExpression, bool defaultInitialValue, std::string const& filename, uint_fast64_t lineNumber) : LocatedInformation(filename, lineNumber), variable(variable), initialValueExpression(initialValueExpression), defaultInitialValue(defaultInitialValue) {
// Nothing to do here. // Nothing to do here.
} }
Variable::Variable(Variable const& oldVariable, std::string const& newName, std::map<std::string, std::string> const& renaming, std::string const& filename, uint_fast64_t lineNumber) : LocatedInformation(filename, lineNumber), name(newName), initialValueExpression(oldVariable.getInitialValueExpression().substitute(renaming)), defaultInitialValue(oldVariable.hasDefaultInitialValue()) {
Variable::Variable(storm::expressions::ExpressionManager& manager, Variable const& oldVariable, std::string const& newName, std::map<storm::expressions::Variable, storm::expressions::Expression> const& renaming, std::string const& filename, uint_fast64_t lineNumber) : LocatedInformation(filename, lineNumber), variable(manager.declareVariable(newName, oldVariable.variable.getType())), initialValueExpression(oldVariable.getInitialValueExpression().substitute(renaming)), defaultInitialValue(oldVariable.hasDefaultInitialValue()) {
// Intentionally left empty. // Intentionally left empty.
} }
std::string const& Variable::getName() const { std::string const& Variable::getName() const {
return this->name;
return this->variable.getName();
} }
bool Variable::hasDefaultInitialValue() const { bool Variable::hasDefaultInitialValue() const {
@ -23,5 +24,14 @@ namespace storm {
storm::expressions::Expression const& Variable::getInitialValueExpression() const { storm::expressions::Expression const& Variable::getInitialValueExpression() const {
return this->initialValueExpression; return this->initialValueExpression;
} }
storm::expressions::Variable const& Variable::getExpressionVariable() const {
return this->variable;
}
storm::expressions::Expression Variable::getExpression() const {
return variable.getExpression();
}
} // namespace prism } // namespace prism
} // namespace storm } // namespace storm

30
src/storage/prism/Variable.h

@ -4,6 +4,7 @@
#include <map> #include <map>
#include "src/storage/prism/LocatedInformation.h" #include "src/storage/prism/LocatedInformation.h"
#include "src/storage/expressions/Variable.h"
#include "src/storage/expressions/Expression.h" #include "src/storage/expressions/Expression.h"
#include "src/utility/OsDetection.h" #include "src/utility/OsDetection.h"
@ -40,36 +41,51 @@ namespace storm {
*/ */
bool hasDefaultInitialValue() const; bool hasDefaultInitialValue() const;
/*!
* Retrieves the expression variable associated with this variable.
*
* @return The expression variable associated with this variable.
*/
storm::expressions::Variable const& getExpressionVariable() const;
/*!
* Retrieves the expression associated with this variable.
*
* @return The expression associated with this variable.
*/
storm::expressions::Expression getExpression() const;
// Make the constructors protected to forbid instantiation of this class. // Make the constructors protected to forbid instantiation of this class.
protected: protected:
Variable() = default; Variable() = default;
/*! /*!
* Creates a variable with the given name and initial value.
* Creates a variable with the given initial value.
* *
* @param name The name of the variable.
* @param variable The associated expression variable.
* @param initialValueExpression The constant expression that defines the initial value of the variable. * @param initialValueExpression The constant expression that defines the initial value of the variable.
* @param hasDefaultInitialValue A flag indicating whether the initial value of the variable is its default * @param hasDefaultInitialValue A flag indicating whether the initial value of the variable is its default
* value. * value.
* @param filename The filename in which the variable is defined. * @param filename The filename in which the variable is defined.
* @param lineNumber The line number in which the variable is defined. * @param lineNumber The line number in which the variable is defined.
*/ */
Variable(std::string const& name, storm::expressions::Expression const& initialValueExpression, bool defaultInitialValue, std::string const& filename = "", uint_fast64_t lineNumber = 0);
Variable(storm::expressions::Variable const& variable, storm::expressions::Expression const& initialValueExpression, bool defaultInitialValue, std::string const& filename = "", uint_fast64_t lineNumber = 0);
/*! /*!
* Creates a copy of the given variable and performs the provided renaming. * Creates a copy of the given variable and performs the provided renaming.
* *
* @param manager The manager responsible for the variable.
* @param oldVariable The variable to copy. * @param oldVariable The variable to copy.
* @param newName New name of this variable. * @param newName New name of this variable.
* @param renaming A mapping from names that are to be renamed to the names they are to be replaced with.
* @param renaming A mapping from variables to the expressions with which they are to be replaced.
* @param filename The filename in which the variable is defined. * @param filename The filename in which the variable is defined.
* @param lineNumber The line number in which the variable is defined. * @param lineNumber The line number in which the variable is defined.
*/ */
Variable(Variable const& oldVariable, std::string const& newName, std::map<std::string, std::string> const& renaming, std::string const& filename = "", uint_fast64_t lineNumber = 0);
Variable(storm::expressions::ExpressionManager& manager, Variable const& oldVariable, std::string const& newName, std::map<storm::expressions::Variable, storm::expressions::Expression> const& renaming, std::string const& filename = "", uint_fast64_t lineNumber = 0);
private: private:
// The name of the variable.
std::string name;
// The expression variable associated with this variable.
storm::expressions::Variable variable;
// The constant expression defining the initial value of the variable. // The constant expression defining the initial value of the variable.
storm::expressions::Expression initialValueExpression; storm::expressions::Expression initialValueExpression;

Loading…
Cancel
Save