Browse Source

Refactored some parts of expressions. In particular, visitors now can return anything they want by using boost::any.

Former-commit-id: 0f6af138ae
main
dehnert 11 years ago
parent
commit
809217c359
  1. 199
      src/adapters/MathsatExpressionAdapter.h
  2. 210
      src/adapters/Z3ExpressionAdapter.h
  3. 2
      src/storage/expressions/BaseExpression.h
  4. 4
      src/storage/expressions/BinaryBooleanFunctionExpression.cpp
  5. 2
      src/storage/expressions/BinaryBooleanFunctionExpression.h
  6. 4
      src/storage/expressions/BinaryNumericalFunctionExpression.cpp
  7. 2
      src/storage/expressions/BinaryNumericalFunctionExpression.h
  8. 4
      src/storage/expressions/BinaryRelationExpression.cpp
  9. 2
      src/storage/expressions/BinaryRelationExpression.h
  10. 4
      src/storage/expressions/BooleanLiteralExpression.cpp
  11. 2
      src/storage/expressions/BooleanLiteralExpression.h
  12. 4
      src/storage/expressions/DoubleLiteralExpression.cpp
  13. 2
      src/storage/expressions/DoubleLiteralExpression.h
  14. 24
      src/storage/expressions/Expression.cpp
  15. 18
      src/storage/expressions/Expression.h
  16. 22
      src/storage/expressions/ExpressionVisitor.h
  17. 120
      src/storage/expressions/IdentifierSubstitutionVisitor.cpp
  18. 23
      src/storage/expressions/IdentifierSubstitutionVisitor.h
  19. 4
      src/storage/expressions/IfThenElseExpression.cpp
  20. 2
      src/storage/expressions/IfThenElseExpression.h
  21. 4
      src/storage/expressions/IntegerLiteralExpression.cpp
  22. 2
      src/storage/expressions/IntegerLiteralExpression.h
  23. 83
      src/storage/expressions/LinearCoefficientVisitor.cpp
  24. 23
      src/storage/expressions/LinearCoefficientVisitor.h
  25. 106
      src/storage/expressions/LinearityCheckVisitor.cpp
  26. 23
      src/storage/expressions/LinearityCheckVisitor.h
  27. 120
      src/storage/expressions/SubstitutionVisitor.cpp
  28. 23
      src/storage/expressions/SubstitutionVisitor.h
  29. 80
      src/storage/expressions/TypeCheckVisitor.cpp
  30. 47
      src/storage/expressions/TypeCheckVisitor.h
  31. 4
      src/storage/expressions/UnaryBooleanFunctionExpression.cpp
  32. 2
      src/storage/expressions/UnaryBooleanFunctionExpression.h
  33. 4
      src/storage/expressions/UnaryNumericalFunctionExpression.cpp
  34. 2
      src/storage/expressions/UnaryNumericalFunctionExpression.h
  35. 4
      src/storage/expressions/VariableExpression.cpp
  36. 2
      src/storage/expressions/VariableExpression.h
  37. 98
      src/storage/prism/Program.cpp
  38. 1
      test/functional/solver/GlpkLpSolverTest.cpp

199
src/adapters/MathsatExpressionAdapter.h

@ -36,219 +36,171 @@ namespace storm {
* expressions and are not yet known to the adapter.
* @param variableToDeclarationMap A mapping from variable names to their corresponding MathSAT declarations (if already existing).
*/
MathsatExpressionAdapter(msat_env& env, bool createVariables = true, std::map<std::string, msat_decl> const& variableToDeclarationMap = std::map<std::string, msat_decl>()) : env(env), stack(), variableToDeclarationMap(variableToDeclarationMap), createVariables(createVariables) {
MathsatExpressionAdapter(msat_env& env, bool createVariables = true, std::map<std::string, msat_decl> const& variableToDeclarationMap = std::map<std::string, msat_decl>()) : env(env), variableToDeclarationMap(variableToDeclarationMap), createVariables(createVariables) {
// Intentionally left empty.
}
/*!
* Translates the given expression to an equivalent term for MathSAT.
*
* @param expression The expression to be translated.
* @return An equivalent term for MathSAT.
*/
* Translates the given expression to an equivalent term for MathSAT.
*
* @param expression The expression to be translated.
* @return An equivalent term for MathSAT.
*/
msat_term translateExpression(storm::expressions::Expression const& expression) {
expression.getBaseExpression().accept(this);
msat_term result = stack.top();
stack.pop();
msat_term result = boost::any_cast<msat_term>(expression.getBaseExpression().accept(*this));
STORM_LOG_THROW(!MSAT_ERROR_TERM(result), storm::exceptions::ExpressionEvaluationException, "Could not translate expression to MathSAT's format.");
return result;
}
virtual void visit(expressions::BinaryBooleanFunctionExpression const* expression) override {
expression->getFirstOperand()->accept(this);
expression->getSecondOperand()->accept(this);
virtual boost::any visit(expressions::BinaryBooleanFunctionExpression const& expression) override {
msat_term leftResult = boost::any_cast<msat_term>(expression.getFirstOperand()->accept(*this));
msat_term rightResult = boost::any_cast<msat_term>(expression.getSecondOperand()->accept(*this));
msat_term rightResult = stack.top();
stack.pop();
msat_term leftResult = stack.top();
stack.pop();
switch (expression->getOperatorType()) {
switch (expression.getOperatorType()) {
case storm::expressions::BinaryBooleanFunctionExpression::OperatorType::And:
stack.push(msat_make_and(env, leftResult, rightResult));
break;
return msat_make_and(env, leftResult, rightResult);
case storm::expressions::BinaryBooleanFunctionExpression::OperatorType::Or:
stack.push(msat_make_or(env, leftResult, rightResult));
break;
return msat_make_or(env, leftResult, rightResult);
case storm::expressions::BinaryBooleanFunctionExpression::OperatorType::Iff:
stack.push(msat_make_iff(env, leftResult, rightResult));
break;
return msat_make_iff(env, leftResult, rightResult);
case storm::expressions::BinaryBooleanFunctionExpression::OperatorType::Implies:
stack.push(msat_make_or(env, msat_make_not(env, leftResult), rightResult));
break;
return msat_make_or(env, msat_make_not(env, leftResult), rightResult);
default:
STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Cannot evaluate expression: unknown boolean binary operator '" << static_cast<uint_fast64_t>(expression->getOperatorType()) << "' in expression " << expression << ".");
STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Cannot evaluate expression: unknown boolean binary operator '" << static_cast<uint_fast64_t>(expression.getOperatorType()) << "' in expression " << expression << ".");
}
}
virtual void visit(expressions::BinaryNumericalFunctionExpression const* expression) override {
expression->getFirstOperand()->accept(this);
expression->getSecondOperand()->accept(this);
msat_term rightResult = stack.top();
stack.pop();
msat_term leftResult = stack.top();
stack.pop();
virtual boost::any visit(expressions::BinaryNumericalFunctionExpression const& expression) override {
msat_term leftResult = boost::any_cast<msat_term>(expression.getFirstOperand()->accept(*this));
msat_term rightResult = boost::any_cast<msat_term>(expression.getSecondOperand()->accept(*this));
switch (expression->getOperatorType()) {
switch (expression.getOperatorType()) {
case storm::expressions::BinaryNumericalFunctionExpression::OperatorType::Plus:
stack.push(msat_make_plus(env, leftResult, rightResult));
break;
return msat_make_plus(env, leftResult, rightResult);
case storm::expressions::BinaryNumericalFunctionExpression::OperatorType::Minus:
stack.push(msat_make_plus(env, leftResult, msat_make_times(env, msat_make_number(env, "-1"), rightResult)));
break;
return msat_make_plus(env, leftResult, msat_make_times(env, msat_make_number(env, "-1"), rightResult));
case storm::expressions::BinaryNumericalFunctionExpression::OperatorType::Times:
stack.push(msat_make_times(env, leftResult, rightResult));
break;
return msat_make_times(env, leftResult, rightResult);
case storm::expressions::BinaryNumericalFunctionExpression::OperatorType::Divide:
throw storm::exceptions::ExpressionEvaluationException() << "Cannot evaluate expression: "
<< "Unsupported numerical binary operator: '/' (division) in expression " << expression << ".";
break;
STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Cannot evaluate expression: unsupported numerical binary operator: '/' (division) in expression.");
case storm::expressions::BinaryNumericalFunctionExpression::OperatorType::Min:
stack.push(msat_make_term_ite(env, msat_make_leq(env, leftResult, rightResult), leftResult, rightResult));
break;
return msat_make_term_ite(env, msat_make_leq(env, leftResult, rightResult), leftResult, rightResult);
case storm::expressions::BinaryNumericalFunctionExpression::OperatorType::Max:
stack.push(msat_make_term_ite(env, msat_make_leq(env, leftResult, rightResult), rightResult, leftResult));
break;
return msat_make_term_ite(env, msat_make_leq(env, leftResult, rightResult), rightResult, leftResult);
default:
STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Cannot evaluate expression: unknown numerical binary operator '" << static_cast<uint_fast64_t>(expression->getOperatorType()) << "' in expression " << expression << ".");
STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Cannot evaluate expression: unknown numerical binary operator '" << static_cast<uint_fast64_t>(expression.getOperatorType()) << "' in expression " << expression << ".");
}
}
virtual void visit(expressions::BinaryRelationExpression const* expression) override {
expression->getFirstOperand()->accept(this);
expression->getSecondOperand()->accept(this);
msat_term rightResult = stack.top();
stack.pop();
msat_term leftResult = stack.top();
stack.pop();
virtual boost::any visit(expressions::BinaryRelationExpression const& expression) override {
msat_term leftResult = boost::any_cast<msat_term>(expression.getFirstOperand()->accept(*this));
msat_term rightResult = boost::any_cast<msat_term>(expression.getSecondOperand()->accept(*this));
switch (expression->getRelationType()) {
switch (expression.getRelationType()) {
case storm::expressions::BinaryRelationExpression::RelationType::Equal:
if (expression->getFirstOperand()->getReturnType() == storm::expressions::ExpressionReturnType::Bool && expression->getSecondOperand()->getReturnType() == storm::expressions::ExpressionReturnType::Bool) {
stack.push(msat_make_iff(env, leftResult, rightResult));
if (expression.getFirstOperand()->getReturnType() == storm::expressions::ExpressionReturnType::Bool && expression.getSecondOperand()->getReturnType() == storm::expressions::ExpressionReturnType::Bool) {
return msat_make_iff(env, leftResult, rightResult);
} else {
stack.push(msat_make_equal(env, leftResult, rightResult));
return msat_make_equal(env, leftResult, rightResult);
}
break;
case storm::expressions::BinaryRelationExpression::RelationType::NotEqual:
if (expression->getFirstOperand()->getReturnType() == storm::expressions::ExpressionReturnType::Bool && expression->getSecondOperand()->getReturnType() == storm::expressions::ExpressionReturnType::Bool) {
stack.push(msat_make_not(env, msat_make_iff(env, leftResult, rightResult)));
if (expression.getFirstOperand()->getReturnType() == storm::expressions::ExpressionReturnType::Bool && expression.getSecondOperand()->getReturnType() == storm::expressions::ExpressionReturnType::Bool) {
return msat_make_not(env, msat_make_iff(env, leftResult, rightResult));
} else {
stack.push(msat_make_not(env, msat_make_equal(env, leftResult, rightResult)));
return msat_make_not(env, msat_make_equal(env, leftResult, rightResult));
}
break;
case storm::expressions::BinaryRelationExpression::RelationType::Less:
stack.push(msat_make_and(env, msat_make_not(env, msat_make_equal(env, leftResult, rightResult)), msat_make_leq(env, leftResult, rightResult)));
break;
return msat_make_and(env, msat_make_not(env, msat_make_equal(env, leftResult, rightResult)), msat_make_leq(env, leftResult, rightResult));
case storm::expressions::BinaryRelationExpression::RelationType::LessOrEqual:
stack.push(msat_make_leq(env, leftResult, rightResult));
break;
return msat_make_leq(env, leftResult, rightResult);
case storm::expressions::BinaryRelationExpression::RelationType::Greater:
stack.push(msat_make_not(env, msat_make_leq(env, leftResult, rightResult)));
break;
return msat_make_not(env, msat_make_leq(env, leftResult, rightResult));
case storm::expressions::BinaryRelationExpression::RelationType::GreaterOrEqual:
stack.push(msat_make_or(env, msat_make_equal(env, leftResult, rightResult), msat_make_not(env, msat_make_leq(env, leftResult, rightResult))));
break;
return msat_make_or(env, msat_make_equal(env, leftResult, rightResult), msat_make_not(env, msat_make_leq(env, leftResult, rightResult)));
default:
STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Cannot evaluate expression: unknown boolean binary operator '" << static_cast<uint_fast64_t>(expression->getRelationType()) << "' in expression " << expression << ".");
STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Cannot evaluate expression: unknown boolean binary operator '" << static_cast<uint_fast64_t>(expression.getRelationType()) << "' in expression " << expression << ".");
}
}
virtual void visit(storm::expressions::IfThenElseExpression const* expression) override {
expression->getCondition()->accept(this);
expression->getThenExpression()->accept(this);
expression->getElseExpression()->accept(this);
msat_term conditionResult = stack.top();
stack.pop();
msat_term thenResult = stack.top();
stack.pop();
msat_term elseResult = stack.top();
stack.pop();
stack.push(msat_make_term_ite(env, conditionResult, thenResult, elseResult));
virtual boost::any visit(storm::expressions::IfThenElseExpression const& expression) override {
msat_term conditionResult = boost::any_cast<msat_term>(expression.getCondition()->accept(*this));
msat_term thenResult = boost::any_cast<msat_term>(expression.getThenExpression()->accept(*this));
msat_term elseResult = boost::any_cast<msat_term>(expression.getElseExpression()->accept(*this));
return msat_make_term_ite(env, conditionResult, thenResult, elseResult);
}
virtual void visit(expressions::BooleanLiteralExpression const* expression) override {
stack.push(expression->evaluateAsBool(nullptr) ? msat_make_true(env) : msat_make_false(env));
virtual boost::any visit(expressions::BooleanLiteralExpression const& expression) override {
return expression.getValue() ? msat_make_true(env) : msat_make_false(env);
}
virtual void visit(expressions::DoubleLiteralExpression const* expression) override {
stack.push(msat_make_number(env, std::to_string(expression->evaluateAsDouble(nullptr)).c_str()));
virtual boost::any visit(expressions::DoubleLiteralExpression const& expression) override {
return msat_make_number(env, std::to_string(expression.getValue()).c_str());
}
virtual void visit(expressions::IntegerLiteralExpression const* expression) override {
stack.push(msat_make_number(env, std::to_string(static_cast<int>(expression->evaluateAsInt(nullptr))).c_str()));
virtual boost::any visit(expressions::IntegerLiteralExpression const& expression) override {
return msat_make_number(env, std::to_string(static_cast<int>(expression.getValue())).c_str());
}
virtual void visit(expressions::UnaryBooleanFunctionExpression const* expression) override {
expression->getOperand()->accept(this);
msat_term childResult = stack.top();
stack.pop();
virtual boost::any visit(expressions::UnaryBooleanFunctionExpression const& expression) override {
msat_term childResult = boost::any_cast<msat_term>(expression.getOperand()->accept(*this));
switch (expression->getOperatorType()) {
switch (expression.getOperatorType()) {
case storm::expressions::UnaryBooleanFunctionExpression::OperatorType::Not:
stack.push(msat_make_not(env, childResult));
return msat_make_not(env, childResult);
break;
default:
STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Cannot evaluate expression: unknown boolean unary operator: '" << static_cast<uint_fast64_t>(expression->getOperatorType()) << "' in expression " << expression << ".");
STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Cannot evaluate expression: unknown boolean unary operator: '" << static_cast<uint_fast64_t>(expression.getOperatorType()) << "' in expression " << expression << ".");
}
}
virtual void visit(expressions::UnaryNumericalFunctionExpression const* expression) override {
expression->getOperand()->accept(this);
virtual boost::any visit(expressions::UnaryNumericalFunctionExpression const& expression) override {
msat_term childResult = boost::any_cast<msat_term>(expression.getOperand()->accept(*this));
msat_term childResult = stack.top();
stack.pop();
switch (expression->getOperatorType()) {
switch (expression.getOperatorType()) {
case storm::expressions::UnaryNumericalFunctionExpression::OperatorType::Minus:
stack.push(msat_make_times(env, msat_make_number(env, "-1"), childResult));
return msat_make_times(env, msat_make_number(env, "-1"), childResult);
break;
case storm::expressions::UnaryNumericalFunctionExpression::OperatorType::Floor:
stack.push(msat_make_floor(env, childResult));
return msat_make_floor(env, childResult);
break;
case storm::expressions::UnaryNumericalFunctionExpression::OperatorType::Ceil:
stack.push(msat_make_plus(env, msat_make_floor(env, childResult), msat_make_number(env, "1")));
return msat_make_plus(env, msat_make_floor(env, childResult), msat_make_number(env, "1"));
break;
default:
STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Cannot evaluate expression: unknown numerical unary operator: '" << static_cast<uint_fast64_t>(expression->getOperatorType()) << "' in expression " << expression << ".");
STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Cannot evaluate expression: unknown numerical unary operator: '" << static_cast<uint_fast64_t>(expression.getOperatorType()) << "' in expression " << expression << ".");
}
}
virtual void visit(expressions::VariableExpression const* expression) override {
std::map<std::string, msat_decl>::iterator stringVariablePair = variableToDeclarationMap.find(expression->getVariableName());
virtual boost::any visit(expressions::VariableExpression const& expression) override {
std::map<std::string, msat_decl>::iterator stringVariablePair = variableToDeclarationMap.find(expression.getVariableName());
msat_decl result;
if (stringVariablePair == variableToDeclarationMap.end() && createVariables) {
std::pair<std::map<std::string, msat_decl>::iterator, bool> iteratorAndFlag;
switch (expression->getReturnType()) {
switch (expression.getReturnType()) {
case storm::expressions::ExpressionReturnType::Bool:
iteratorAndFlag = this->variableToDeclarationMap.insert(std::make_pair(expression->getVariableName(), msat_declare_function(env, expression->getVariableName().c_str(), msat_get_bool_type(env))));
iteratorAndFlag = this->variableToDeclarationMap.insert(std::make_pair(expression.getVariableName(), msat_declare_function(env, expression.getVariableName().c_str(), msat_get_bool_type(env))));
result = iteratorAndFlag.first->second;
break;
case storm::expressions::ExpressionReturnType::Int:
iteratorAndFlag = this->variableToDeclarationMap.insert(std::make_pair(expression->getVariableName(), msat_declare_function(env, expression->getVariableName().c_str(), msat_get_integer_type(env))));
iteratorAndFlag = this->variableToDeclarationMap.insert(std::make_pair(expression.getVariableName(), msat_declare_function(env, expression.getVariableName().c_str(), msat_get_integer_type(env))));
result = iteratorAndFlag.first->second;
break;
case storm::expressions::ExpressionReturnType::Double:
iteratorAndFlag = this->variableToDeclarationMap.insert(std::make_pair(expression->getVariableName(), msat_declare_function(env, expression->getVariableName().c_str(), msat_get_rational_type(env))));
iteratorAndFlag = this->variableToDeclarationMap.insert(std::make_pair(expression.getVariableName(), msat_declare_function(env, expression.getVariableName().c_str(), msat_get_rational_type(env))));
result = iteratorAndFlag.first->second;
break;
default:
STORM_LOG_THROW(false, storm::exceptions::InvalidTypeException, "Encountered variable '" << expression->getVariableName() << "' with unknown type while trying to create solver variables.");
STORM_LOG_THROW(false, storm::exceptions::InvalidTypeException, "Encountered variable '" << expression.getVariableName() << "' with unknown type while trying to create solver variables.");
}
} else {
STORM_LOG_THROW(stringVariablePair != variableToDeclarationMap.end(), storm::exceptions::InvalidArgumentException, "Expression refers to unknown variable '" << expression->getVariableName() << "'.");
STORM_LOG_THROW(stringVariablePair != variableToDeclarationMap.end(), storm::exceptions::InvalidArgumentException, "Expression refers to unknown variable '" << expression.getVariableName() << "'.");
result = stringVariablePair->second;
}
STORM_LOG_THROW(!MSAT_ERROR_DECL(result), storm::exceptions::ExpressionEvaluationException, "Unable to translate expression to MathSAT format, because a variable could not be translated.");
stack.push(msat_make_constant(env, result));
return msat_make_constant(env, result);
}
storm::expressions::Expression translateExpression(msat_term const& term) {
@ -309,9 +261,6 @@ namespace storm {
// The MathSAT environment used.
msat_env& env;
// A stack used for communicating results between different functions.
std::stack<msat_term> stack;
// A mapping of variable names to their declaration in the MathSAT environment.
std::map<std::string, msat_decl> variableToDeclarationMap;

210
src/adapters/Z3ExpressionAdapter.h

@ -35,7 +35,7 @@ namespace storm {
* expressions and are not yet known to the adapter.
* @param variableToExpressionMap A mapping from variable names to their corresponding Z3 expressions (if already existing).
*/
Z3ExpressionAdapter(z3::context& context, bool createVariables = true, std::map<std::string, z3::expr> const& variableToExpressionMap = std::map<std::string, z3::expr>()) : context(context) , stack() , additionalAssertions() , additionalVariableCounter(0), variableToExpressionMap(variableToExpressionMap), createVariables(createVariables) {
Z3ExpressionAdapter(z3::context& context, bool createVariables = true, std::map<std::string, z3::expr> const& variableToExpressionMap = std::map<std::string, z3::expr>()) : context(context), additionalAssertions(), additionalVariableCounter(0), variableToExpressionMap(variableToExpressionMap), createVariables(createVariables) {
// Intentionally left empty.
}
@ -44,20 +44,18 @@ namespace storm {
*
* @warning The adapter internally creates helper variables prefixed with `__z3adapter_`. As a consequence,
* having variables with this prefix in the variableToExpressionMap might lead to unexpected results and is
* strictly to be avoided.
* strictly to be aboost::anyed.
*
* @param expression The expression to translate.
* @return An equivalent expression for Z3.
*/
z3::expr translateExpression(storm::expressions::Expression const& expression) {
expression.getBaseExpression().accept(this);
z3::expr result = stack.top();
stack.pop();
z3::expr result = boost::any_cast<z3::expr>(expression.getBaseExpression().accept(*this));
while (!additionalAssertions.empty()) {
result = result && additionalAssertions.top();
additionalAssertions.pop();
}
for (z3::expr const& assertion : additionalAssertions) {
result = result && assertion;
}
additionalAssertions.clear();
return result;
}
@ -167,211 +165,159 @@ namespace storm {
}
}
virtual void visit(storm::expressions::BinaryBooleanFunctionExpression const* expression) override {
expression->getFirstOperand()->accept(this);
expression->getSecondOperand()->accept(this);
const z3::expr rightResult = stack.top();
stack.pop();
const z3::expr leftResult = stack.top();
stack.pop();
virtual boost::any visit(storm::expressions::BinaryBooleanFunctionExpression const& expression) override {
z3::expr leftResult = boost::any_cast<z3::expr>(expression.getFirstOperand()->accept(*this));
z3::expr rightResult = boost::any_cast<z3::expr>(expression.getSecondOperand()->accept(*this));
switch(expression->getOperatorType()) {
switch(expression.getOperatorType()) {
case storm::expressions::BinaryBooleanFunctionExpression::OperatorType::And:
stack.push(leftResult && rightResult);
break;
return leftResult && rightResult;
case storm::expressions::BinaryBooleanFunctionExpression::OperatorType::Or:
stack.push(leftResult || rightResult);
break;
return leftResult || rightResult;
case storm::expressions::BinaryBooleanFunctionExpression::OperatorType::Xor:
stack.push(z3::expr(context, Z3_mk_xor(context, leftResult, rightResult)));
break;
return z3::expr(context, Z3_mk_xor(context, leftResult, rightResult));
case storm::expressions::BinaryBooleanFunctionExpression::OperatorType::Implies:
stack.push(z3::expr(context, Z3_mk_implies(context, leftResult, rightResult)));
break;
return z3::expr(context, Z3_mk_implies(context, leftResult, rightResult));
case storm::expressions::BinaryBooleanFunctionExpression::OperatorType::Iff:
stack.push(z3::expr(context, Z3_mk_iff(context, leftResult, rightResult)));
break;
return z3::expr(context, Z3_mk_iff(context, leftResult, rightResult));
default:
STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Cannot evaluate expression: unknown boolean binary operator '" << static_cast<int>(expression->getOperatorType()) << "' in expression " << expression << ".");
STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Cannot evaluate expression: unknown boolean binary operator '" << static_cast<int>(expression.getOperatorType()) << "' in expression " << expression << ".");
}
}
virtual void visit(storm::expressions::BinaryNumericalFunctionExpression const* expression) override {
expression->getFirstOperand()->accept(this);
expression->getSecondOperand()->accept(this);
z3::expr rightResult = stack.top();
stack.pop();
z3::expr leftResult = stack.top();
stack.pop();
switch(expression->getOperatorType()) {
virtual boost::any visit(storm::expressions::BinaryNumericalFunctionExpression const& expression) override {
z3::expr leftResult = boost::any_cast<z3::expr>(expression.getFirstOperand()->accept(*this));
z3::expr rightResult = boost::any_cast<z3::expr>(expression.getSecondOperand()->accept(*this));
switch(expression.getOperatorType()) {
case storm::expressions::BinaryNumericalFunctionExpression::OperatorType::Plus:
stack.push(leftResult + rightResult);
break;
return leftResult + rightResult;
case storm::expressions::BinaryNumericalFunctionExpression::OperatorType::Minus:
stack.push(leftResult - rightResult);
break;
return leftResult - rightResult;
case storm::expressions::BinaryNumericalFunctionExpression::OperatorType::Times:
stack.push(leftResult * rightResult);
break;
return leftResult * rightResult;
case storm::expressions::BinaryNumericalFunctionExpression::OperatorType::Divide:
stack.push(leftResult / rightResult);
break;
return leftResult / rightResult;
case storm::expressions::BinaryNumericalFunctionExpression::OperatorType::Min:
stack.push(ite(leftResult <= rightResult, leftResult, rightResult));
break;
return ite(leftResult <= rightResult, leftResult, rightResult);
case storm::expressions::BinaryNumericalFunctionExpression::OperatorType::Max:
stack.push(ite(leftResult >= rightResult, leftResult, rightResult));
break;
default: STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Cannot evaluate expression: unknown numerical binary operator '" << static_cast<int>(expression->getOperatorType()) << "' in expression " << expression << ".");
return ite(leftResult >= rightResult, leftResult, rightResult);
default:
STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Cannot evaluate expression: unknown numerical binary operator '" << static_cast<int>(expression.getOperatorType()) << "' in expression " << expression << ".");
}
}
virtual void visit(storm::expressions::BinaryRelationExpression const* expression) override {
expression->getFirstOperand()->accept(this);
expression->getSecondOperand()->accept(this);
z3::expr rightResult = stack.top();
stack.pop();
z3::expr leftResult = stack.top();
stack.pop();
switch(expression->getRelationType()) {
virtual boost::any visit(storm::expressions::BinaryRelationExpression const& expression) override {
z3::expr leftResult = boost::any_cast<z3::expr>(expression.getFirstOperand()->accept(*this));
z3::expr rightResult = boost::any_cast<z3::expr>(expression.getSecondOperand()->accept(*this));
switch(expression.getRelationType()) {
case storm::expressions::BinaryRelationExpression::RelationType::Equal:
stack.push(leftResult == rightResult);
break;
return leftResult == rightResult;
case storm::expressions::BinaryRelationExpression::RelationType::NotEqual:
stack.push(leftResult != rightResult);
break;
return leftResult != rightResult;
case storm::expressions::BinaryRelationExpression::RelationType::Less:
stack.push(leftResult < rightResult);
break;
return leftResult < rightResult;
case storm::expressions::BinaryRelationExpression::RelationType::LessOrEqual:
stack.push(leftResult <= rightResult);
break;
return leftResult <= rightResult;
case storm::expressions::BinaryRelationExpression::RelationType::Greater:
stack.push(leftResult > rightResult);
break;
return leftResult > rightResult;
case storm::expressions::BinaryRelationExpression::RelationType::GreaterOrEqual:
stack.push(leftResult >= rightResult);
break;
return leftResult >= rightResult;
default:
STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Cannot evaluate expression: unknown boolean binary operator '" << static_cast<int>(expression->getRelationType()) << "' in expression " << expression << ".");
STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Cannot evaluate expression: unknown boolean binary operator '" << static_cast<int>(expression.getRelationType()) << "' in expression " << expression << ".");
}
}
virtual void visit(storm::expressions::BooleanLiteralExpression const* expression) override {
stack.push(context.bool_val(expression->evaluateAsBool()));
virtual boost::any visit(storm::expressions::BooleanLiteralExpression const& expression) override {
return context.bool_val(expression.getValue());
}
virtual void visit(storm::expressions::DoubleLiteralExpression const* expression) override {
virtual boost::any visit(storm::expressions::DoubleLiteralExpression const& expression) override {
std::stringstream fractionStream;
fractionStream << expression->evaluateAsDouble();
stack.push(context.real_val(fractionStream.str().c_str()));
fractionStream << expression.getValue();
return context.real_val(fractionStream.str().c_str());
}
virtual void visit(storm::expressions::IntegerLiteralExpression const* expression) override {
stack.push(context.int_val(static_cast<int>(expression->evaluateAsInt())));
virtual boost::any visit(storm::expressions::IntegerLiteralExpression const& expression) override {
return context.int_val(static_cast<int>(expression.getValue()));
}
virtual void visit(storm::expressions::UnaryBooleanFunctionExpression const* expression) override {
expression->getOperand()->accept(this);
z3::expr childResult = stack.top();
stack.pop();
virtual boost::any visit(storm::expressions::UnaryBooleanFunctionExpression const& expression) override {
z3::expr childResult = boost::any_cast<z3::expr>(expression.getOperand()->accept(*this));
switch (expression->getOperatorType()) {
switch (expression.getOperatorType()) {
case storm::expressions::UnaryBooleanFunctionExpression::OperatorType::Not:
stack.push(!childResult);
break;
return !childResult;
default:
STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Cannot evaluate expression: unknown boolean binary operator '" << static_cast<int>(expression->getOperatorType()) << "' in expression " << expression << ".");
STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Cannot evaluate expression: unknown boolean binary operator '" << static_cast<int>(expression.getOperatorType()) << "' in expression " << expression << ".");
}
}
virtual void visit(storm::expressions::UnaryNumericalFunctionExpression const* expression) override {
expression->getOperand()->accept(this);
virtual boost::any visit(storm::expressions::UnaryNumericalFunctionExpression const& expression) override {
z3::expr childResult = boost::any_cast<z3::expr>(expression.getOperand()->accept(*this));
z3::expr childResult = stack.top();
stack.pop();
switch(expression->getOperatorType()) {
switch(expression.getOperatorType()) {
case storm::expressions::UnaryNumericalFunctionExpression::OperatorType::Minus:
stack.push(0 - childResult);
break;
return 0 - childResult;
case storm::expressions::UnaryNumericalFunctionExpression::OperatorType::Floor: {
z3::expr floorVariable = context.int_const(("__z3adapter_floor_" + std::to_string(additionalVariableCounter++)).c_str());
additionalAssertions.push(z3::expr(context, Z3_mk_int2real(context, floorVariable)) <= childResult && childResult < (z3::expr(context, Z3_mk_int2real(context, floorVariable)) + 1));
stack.push(floorVariable);
break;
additionalAssertions.push_back(z3::expr(context, Z3_mk_int2real(context, floorVariable)) <= childResult && childResult < (z3::expr(context, Z3_mk_int2real(context, floorVariable)) + 1));
return floorVariable;
}
case storm::expressions::UnaryNumericalFunctionExpression::OperatorType::Ceil:{
z3::expr ceilVariable = context.int_const(("__z3adapter_ceil_" + std::to_string(additionalVariableCounter++)).c_str());
additionalAssertions.push(z3::expr(context, Z3_mk_int2real(context, ceilVariable)) - 1 <= childResult && childResult < z3::expr(context, Z3_mk_int2real(context, ceilVariable)));
stack.push(ceilVariable);
break;
additionalAssertions.push_back(z3::expr(context, Z3_mk_int2real(context, ceilVariable)) - 1 <= childResult && childResult < z3::expr(context, Z3_mk_int2real(context, ceilVariable)));
return ceilVariable;
}
default: STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Cannot evaluate expression: unknown numerical unary operator '" << static_cast<int>(expression->getOperatorType()) << "'.");
default: STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Cannot evaluate expression: unknown numerical unary operator '" << static_cast<int>(expression.getOperatorType()) << "'.");
}
}
virtual void visit(storm::expressions::IfThenElseExpression const* expression) override {
expression->getCondition()->accept(this);
expression->getThenExpression()->accept(this);
expression->getElseExpression()->accept(this);
z3::expr conditionResult = stack.top();
stack.pop();
z3::expr thenResult = stack.top();
stack.pop();
z3::expr elseResult = stack.top();
stack.pop();
stack.push(z3::expr(context, Z3_mk_ite(context, conditionResult, thenResult, elseResult)));
virtual boost::any visit(storm::expressions::IfThenElseExpression const& expression) override {
z3::expr conditionResult = boost::any_cast<z3::expr>(expression.getCondition()->accept(*this));
z3::expr thenResult = boost::any_cast<z3::expr>(expression.getThenExpression()->accept(*this));
z3::expr elseResult = boost::any_cast<z3::expr>(expression.getElseExpression()->accept(*this));
return z3::expr(context, Z3_mk_ite(context, conditionResult, thenResult, elseResult));
}
virtual void visit(storm::expressions::VariableExpression const* expression) override {
std::map<std::string, z3::expr>::iterator stringVariablePair = variableToExpressionMap.find(expression->getVariableName());
virtual boost::any visit(storm::expressions::VariableExpression const& expression) override {
std::map<std::string, z3::expr>::iterator stringVariablePair = variableToExpressionMap.find(expression.getVariableName());
z3::expr result(context);
if (stringVariablePair == variableToExpressionMap.end() && createVariables) {
std::pair<std::map<std::string, z3::expr>::iterator, bool> iteratorAndFlag;
switch (expression->getReturnType()) {
switch (expression.getReturnType()) {
case storm::expressions::ExpressionReturnType::Bool:
iteratorAndFlag = this->variableToExpressionMap.insert(std::make_pair(expression->getVariableName(), context.bool_const(expression->getVariableName().c_str())));
iteratorAndFlag = this->variableToExpressionMap.insert(std::make_pair(expression.getVariableName(), context.bool_const(expression.getVariableName().c_str())));
result = iteratorAndFlag.first->second;
break;
case storm::expressions::ExpressionReturnType::Int:
iteratorAndFlag = this->variableToExpressionMap.insert(std::make_pair(expression->getVariableName(), context.int_const(expression->getVariableName().c_str())));
iteratorAndFlag = this->variableToExpressionMap.insert(std::make_pair(expression.getVariableName(), context.int_const(expression.getVariableName().c_str())));
result = iteratorAndFlag.first->second;
break;
case storm::expressions::ExpressionReturnType::Double:
iteratorAndFlag = this->variableToExpressionMap.insert(std::make_pair(expression->getVariableName(), context.real_const(expression->getVariableName().c_str())));
iteratorAndFlag = this->variableToExpressionMap.insert(std::make_pair(expression.getVariableName(), context.real_const(expression.getVariableName().c_str())));
result = iteratorAndFlag.first->second;
break;
default:
STORM_LOG_THROW(false, storm::exceptions::InvalidTypeException, "Encountered variable '" << expression->getVariableName() << "' with unknown type while trying to create solver variables.");
STORM_LOG_THROW(false, storm::exceptions::InvalidTypeException, "Encountered variable '" << expression.getVariableName() << "' with unknown type while trying to create solver variables.");
}
} else {
STORM_LOG_THROW(stringVariablePair != variableToExpressionMap.end(), storm::exceptions::InvalidArgumentException, "Expression refers to unknown variable '" << expression->getVariableName() << "'.");
STORM_LOG_THROW(stringVariablePair != variableToExpressionMap.end(), storm::exceptions::InvalidArgumentException, "Expression refers to unknown variable '" << expression.getVariableName() << "'.");
result = stringVariablePair->second;
}
stack.push(result);
return result;
}
private:
// The context that is used to translate the expressions.
z3::context& context;
// A stack that is used to communicate the translation results between method calls.
std::stack<z3::expr> stack;
// A stack of assertions that need to be kept separate, because they were only impliclty part of an assertion that was added.
std::stack<z3::expr> additionalAssertions;
std::vector<z3::expr> additionalAssertions;
// A counter for the variables that were created to identify the additional assertions.
uint_fast64_t additionalVariableCounter;

2
src/storage/expressions/BaseExpression.h

@ -168,7 +168,7 @@ namespace storm {
*
* @param visitor The visitor that is to be accepted.
*/
virtual void accept(ExpressionVisitor* visitor) const = 0;
virtual boost::any accept(ExpressionVisitor& visitor) const = 0;
/*!
* Retrieves whether the expression has a numerical return type, i.e., integer or double.

4
src/storage/expressions/BinaryBooleanFunctionExpression.cpp

@ -91,8 +91,8 @@ namespace storm {
}
}
void BinaryBooleanFunctionExpression::accept(ExpressionVisitor* visitor) const {
visitor->visit(this);
boost::any BinaryBooleanFunctionExpression::accept(ExpressionVisitor& visitor) const {
return visitor.visit(*this);
}
void BinaryBooleanFunctionExpression::printToStream(std::ostream& stream) const {

2
src/storage/expressions/BinaryBooleanFunctionExpression.h

@ -36,7 +36,7 @@ namespace storm {
virtual storm::expressions::OperatorType getOperator() const override;
virtual bool evaluateAsBool(Valuation const* valuation = nullptr) const override;
virtual std::shared_ptr<BaseExpression const> simplify() const override;
virtual void accept(ExpressionVisitor* visitor) const override;
virtual boost::any accept(ExpressionVisitor& visitor) const override;
/*!
* Retrieves the operator associated with the expression.

4
src/storage/expressions/BinaryNumericalFunctionExpression.cpp

@ -70,8 +70,8 @@ namespace storm {
}
}
void BinaryNumericalFunctionExpression::accept(ExpressionVisitor* visitor) const {
visitor->visit(this);
boost::any BinaryNumericalFunctionExpression::accept(ExpressionVisitor& visitor) const {
return visitor.visit(*this);
}
void BinaryNumericalFunctionExpression::printToStream(std::ostream& stream) const {

2
src/storage/expressions/BinaryNumericalFunctionExpression.h

@ -37,7 +37,7 @@ namespace storm {
virtual int_fast64_t evaluateAsInt(Valuation const* valuation = nullptr) const override;
virtual double evaluateAsDouble(Valuation const* valuation = nullptr) const override;
virtual std::shared_ptr<BaseExpression const> simplify() const override;
virtual void accept(ExpressionVisitor* visitor) const override;
virtual boost::any accept(ExpressionVisitor& visitor) const override;
/*!
* Retrieves the operator associated with the expression.

4
src/storage/expressions/BinaryRelationExpression.cpp

@ -46,8 +46,8 @@ namespace storm {
}
}
void BinaryRelationExpression::accept(ExpressionVisitor* visitor) const {
visitor->visit(this);
boost::any BinaryRelationExpression::accept(ExpressionVisitor& visitor) const {
return visitor.visit(*this);
}
BinaryRelationExpression::RelationType BinaryRelationExpression::getRelationType() const {

2
src/storage/expressions/BinaryRelationExpression.h

@ -36,7 +36,7 @@ namespace storm {
virtual storm::expressions::OperatorType getOperator() const override;
virtual bool evaluateAsBool(Valuation const* valuation = nullptr) const override;
virtual std::shared_ptr<BaseExpression const> simplify() const override;
virtual void accept(ExpressionVisitor* visitor) const override;
virtual boost::any accept(ExpressionVisitor& visitor) const override;
/*!
* Retrieves the relation associated with the expression.

4
src/storage/expressions/BooleanLiteralExpression.cpp

@ -34,8 +34,8 @@ namespace storm {
return this->shared_from_this();
}
void BooleanLiteralExpression::accept(ExpressionVisitor* visitor) const {
visitor->visit(this);
boost::any BooleanLiteralExpression::accept(ExpressionVisitor& visitor) const {
return visitor.visit(*this);
}
bool BooleanLiteralExpression::getValue() const {

2
src/storage/expressions/BooleanLiteralExpression.h

@ -32,7 +32,7 @@ namespace storm {
virtual std::set<std::string> getVariables() const override;
virtual std::map<std::string, ExpressionReturnType> getVariablesAndTypes() const override;
virtual std::shared_ptr<BaseExpression const> simplify() const override;
virtual void accept(ExpressionVisitor* visitor) const override;
virtual boost::any accept(ExpressionVisitor& visitor) const override;
/*!
* Retrieves the value of the boolean literal.

4
src/storage/expressions/DoubleLiteralExpression.cpp

@ -26,8 +26,8 @@ namespace storm {
return this->shared_from_this();
}
void DoubleLiteralExpression::accept(ExpressionVisitor* visitor) const {
visitor->visit(this);
boost::any DoubleLiteralExpression::accept(ExpressionVisitor& visitor) const {
return visitor.visit(*this);
}
double DoubleLiteralExpression::getValue() const {

2
src/storage/expressions/DoubleLiteralExpression.h

@ -30,7 +30,7 @@ namespace storm {
virtual std::set<std::string> getVariables() const override;
virtual std::map<std::string, ExpressionReturnType> getVariablesAndTypes() const override;
virtual std::shared_ptr<BaseExpression const> simplify() const override;
virtual void accept(ExpressionVisitor* visitor) const override;
virtual boost::any accept(ExpressionVisitor& visitor) const override;
/*!
* Retrieves the value of the double literal.

24
src/storage/expressions/Expression.cpp

@ -4,7 +4,6 @@
#include "src/storage/expressions/Expression.h"
#include "src/storage/expressions/SubstitutionVisitor.h"
#include "src/storage/expressions/IdentifierSubstitutionVisitor.h"
#include "src/storage/expressions/TypeCheckVisitor.h"
#include "src/storage/expressions/LinearityCheckVisitor.h"
#include "src/storage/expressions/Expressions.h"
#include "src/exceptions/InvalidTypeException.h"
@ -31,14 +30,6 @@ namespace storm {
Expression Expression::substitute(std::unordered_map<std::string, std::string> const& identifierToIdentifierMap) const {
return IdentifierSubstitutionVisitor<std::unordered_map<std::string, std::string>>(identifierToIdentifierMap).substitute(*this);
}
void Expression::check(std::map<std::string, storm::expressions::ExpressionReturnType> const& identifierToTypeMap) const {
return TypeCheckVisitor<std::map<std::string, storm::expressions::ExpressionReturnType>>(identifierToTypeMap).check(*this);
}
void Expression::check(std::unordered_map<std::string, storm::expressions::ExpressionReturnType> const& identifierToTypeMap) const {
return TypeCheckVisitor<std::unordered_map<std::string, storm::expressions::ExpressionReturnType>>(identifierToTypeMap).check(*this);
}
bool Expression::evaluateAsBool(Valuation const* valuation) const {
return this->getBaseExpression().evaluateAsBool(valuation);
@ -99,17 +90,6 @@ namespace storm {
std::set<std::string> Expression::getVariables() const {
return this->getBaseExpression().getVariables();
}
std::map<std::string, ExpressionReturnType> Expression::getVariablesAndTypes(bool validate) const {
if (validate) {
std::map<std::string, ExpressionReturnType> result = this->getBaseExpression().getVariablesAndTypes();
this->check(result);
return result;
}
else {
return this->getBaseExpression().getVariablesAndTypes();
}
}
bool Expression::isRelationalExpression() const {
if (!this->isFunctionApplication()) {
@ -300,6 +280,10 @@ namespace storm {
return Expression(std::shared_ptr<BaseExpression>(new UnaryNumericalFunctionExpression(ExpressionReturnType::Int, this->getBaseExpressionPointer(), UnaryNumericalFunctionExpression::OperatorType::Ceil)));
}
boost::any Expression::accept(ExpressionVisitor& visitor) const {
return this->getBaseExpression().accept(visitor);
}
std::ostream& operator<<(std::ostream& stream, Expression const& expression) {
stream << expression.getBaseExpression();
return stream;

18
src/storage/expressions/Expression.h

@ -108,22 +108,6 @@ namespace storm {
*/
Expression substitute(std::unordered_map<std::string, std::string> const& identifierToIdentifierMap) const;
/*!
* Checks that all identifiers appearing in the expression have the types given by the map. An exception
* is thrown in case a violation is found.
*
* @param identifierToTypeMap A mapping from identifiers to the types that are supposed to have.
*/
void check(std::map<std::string, storm::expressions::ExpressionReturnType> const& identifierToTypeMap) const;
/*!
* Checks that all identifiers appearing in the expression have the types given by the map. An exception
* is thrown in case a violation is found.
*
* @param identifierToTypeMap A mapping from identifiers to the types that are supposed to have.
*/
void check(std::unordered_map<std::string, storm::expressions::ExpressionReturnType> const& identifierToTypeMap) const;
/*!
* Evaluates the expression under the valuation of variables given by the valuation and returns the
* resulting boolean value. If the return type of the expression is not a boolean an exception is thrown.
@ -314,7 +298,7 @@ namespace storm {
*
* @param visitor The visitor to accept.
*/
void accept(ExpressionVisitor* visitor) const;
boost::any accept(ExpressionVisitor& visitor) const;
friend std::ostream& operator<<(std::ostream& stream, Expression const& expression);

22
src/storage/expressions/ExpressionVisitor.h

@ -1,6 +1,8 @@
#ifndef STORM_STORAGE_EXPRESSIONS_EXPRESSIONVISITOR_H_
#define STORM_STORAGE_EXPRESSIONS_EXPRESSIONVISITOR_H_
#include <boost/any.hpp>
namespace storm {
namespace expressions {
// Forward-declare all expression classes.
@ -17,16 +19,16 @@ namespace storm {
class ExpressionVisitor {
public:
virtual void visit(IfThenElseExpression const* expression) = 0;
virtual void visit(BinaryBooleanFunctionExpression const* expression) = 0;
virtual void visit(BinaryNumericalFunctionExpression const* expression) = 0;
virtual void visit(BinaryRelationExpression const* expression) = 0;
virtual void visit(VariableExpression const* expression) = 0;
virtual void visit(UnaryBooleanFunctionExpression const* expression) = 0;
virtual void visit(UnaryNumericalFunctionExpression const* expression) = 0;
virtual void visit(BooleanLiteralExpression const* expression) = 0;
virtual void visit(IntegerLiteralExpression const* expression) = 0;
virtual void visit(DoubleLiteralExpression const* expression) = 0;
virtual boost::any visit(IfThenElseExpression const& expression) = 0;
virtual boost::any visit(BinaryBooleanFunctionExpression const& expression) = 0;
virtual boost::any visit(BinaryNumericalFunctionExpression const& expression) = 0;
virtual boost::any visit(BinaryRelationExpression const& expression) = 0;
virtual boost::any visit(VariableExpression const& expression) = 0;
virtual boost::any visit(UnaryBooleanFunctionExpression const& expression) = 0;
virtual boost::any visit(UnaryNumericalFunctionExpression const& expression) = 0;
virtual boost::any visit(BooleanLiteralExpression const& expression) = 0;
virtual boost::any visit(IntegerLiteralExpression const& expression) = 0;
virtual boost::any visit(DoubleLiteralExpression const& expression) = 0;
};
}
}

120
src/storage/expressions/IdentifierSubstitutionVisitor.cpp

@ -14,138 +14,110 @@ namespace storm {
template<typename MapType>
Expression IdentifierSubstitutionVisitor<MapType>::substitute(Expression const& expression) {
expression.getBaseExpression().accept(this);
return Expression(this->expressionStack.top());
return Expression(boost::any_cast<std::shared_ptr<BaseExpression>>(expression.getBaseExpression().accept(*this)));
}
template<typename MapType>
void IdentifierSubstitutionVisitor<MapType>::visit(IfThenElseExpression const* expression) {
expression->getCondition()->accept(this);
std::shared_ptr<BaseExpression const> conditionExpression = expressionStack.top();
expressionStack.pop();
expression->getThenExpression()->accept(this);
std::shared_ptr<BaseExpression const> thenExpression = expressionStack.top();
expressionStack.pop();
expression->getElseExpression()->accept(this);
std::shared_ptr<BaseExpression const> elseExpression = expressionStack.top();
expressionStack.pop();
boost::any IdentifierSubstitutionVisitor<MapType>::visit(IfThenElseExpression const& expression) {
std::shared_ptr<BaseExpression const> conditionExpression = boost::any_cast<std::shared_ptr<BaseExpression>>(expression.getCondition()->accept(*this));
std::shared_ptr<BaseExpression const> thenExpression = boost::any_cast<std::shared_ptr<BaseExpression>>(expression.getThenExpression()->accept(*this));
std::shared_ptr<BaseExpression const> elseExpression = boost::any_cast<std::shared_ptr<BaseExpression>>(expression.getElseExpression()->accept(*this));
// If the arguments did not change, we simply push the expression itself.
if (conditionExpression.get() == expression->getCondition().get() && thenExpression.get() == expression->getThenExpression().get() && elseExpression.get() == expression->getElseExpression().get()) {
this->expressionStack.push(expression->getSharedPointer());
if (conditionExpression.get() == expression.getCondition().get() && thenExpression.get() == expression.getThenExpression().get() && elseExpression.get() == expression.getElseExpression().get()) {
return expression.getSharedPointer();
} else {
this->expressionStack.push(std::shared_ptr<BaseExpression>(new IfThenElseExpression(expression->getReturnType(), conditionExpression, thenExpression, elseExpression)));
return std::shared_ptr<BaseExpression>(new IfThenElseExpression(expression.getReturnType(), conditionExpression, thenExpression, elseExpression));
}
}
template<typename MapType>
void IdentifierSubstitutionVisitor<MapType>::visit(BinaryBooleanFunctionExpression const* expression) {
expression->getFirstOperand()->accept(this);
std::shared_ptr<BaseExpression const> firstExpression = expressionStack.top();
expressionStack.pop();
expression->getSecondOperand()->accept(this);
std::shared_ptr<BaseExpression const> secondExpression = expressionStack.top();
expressionStack.pop();
boost::any IdentifierSubstitutionVisitor<MapType>::visit(BinaryBooleanFunctionExpression const& expression) {
std::shared_ptr<BaseExpression const> firstExpression = boost::any_cast<std::shared_ptr<BaseExpression>>(expression.getFirstOperand()->accept(*this));
std::shared_ptr<BaseExpression const> secondExpression = boost::any_cast<std::shared_ptr<BaseExpression>>(expression.getSecondOperand()->accept(*this));
// If the arguments did not change, we simply push the expression itself.
if (firstExpression.get() == expression->getFirstOperand().get() && secondExpression.get() == expression->getSecondOperand().get()) {
this->expressionStack.push(expression->getSharedPointer());
if (firstExpression.get() == expression.getFirstOperand().get() && secondExpression.get() == expression.getSecondOperand().get()) {
return expression.getSharedPointer();
} else {
this->expressionStack.push(std::shared_ptr<BaseExpression>(new BinaryBooleanFunctionExpression(expression->getReturnType(), firstExpression, secondExpression, expression->getOperatorType())));
return std::shared_ptr<BaseExpression>(new BinaryBooleanFunctionExpression(expression.getReturnType(), firstExpression, secondExpression, expression.getOperatorType()));
}
}
template<typename MapType>
void IdentifierSubstitutionVisitor<MapType>::visit(BinaryNumericalFunctionExpression const* expression) {
expression->getFirstOperand()->accept(this);
std::shared_ptr<BaseExpression const> firstExpression = expressionStack.top();
expressionStack.pop();
expression->getSecondOperand()->accept(this);
std::shared_ptr<BaseExpression const> secondExpression = expressionStack.top();
expressionStack.pop();
boost::any IdentifierSubstitutionVisitor<MapType>::visit(BinaryNumericalFunctionExpression const& expression) {
std::shared_ptr<BaseExpression const> firstExpression = boost::any_cast<std::shared_ptr<BaseExpression>>(expression.getFirstOperand()->accept(*this));
std::shared_ptr<BaseExpression const> secondExpression = boost::any_cast<std::shared_ptr<BaseExpression>>(expression.getSecondOperand()->accept(*this));
// If the arguments did not change, we simply push the expression itself.
if (firstExpression.get() == expression->getFirstOperand().get() && secondExpression.get() == expression->getSecondOperand().get()) {
this->expressionStack.push(expression->getSharedPointer());
if (firstExpression.get() == expression.getFirstOperand().get() && secondExpression.get() == expression.getSecondOperand().get()) {
return expression.getSharedPointer();
} else {
this->expressionStack.push(std::shared_ptr<BaseExpression>(new BinaryNumericalFunctionExpression(expression->getReturnType(), firstExpression, secondExpression, expression->getOperatorType())));
return std::shared_ptr<BaseExpression>(new BinaryNumericalFunctionExpression(expression.getReturnType(), firstExpression, secondExpression, expression.getOperatorType()));
}
}
template<typename MapType>
void IdentifierSubstitutionVisitor<MapType>::visit(BinaryRelationExpression const* expression) {
expression->getFirstOperand()->accept(this);
std::shared_ptr<BaseExpression const> firstExpression = expressionStack.top();
expressionStack.pop();
expression->getSecondOperand()->accept(this);
std::shared_ptr<BaseExpression const> secondExpression = expressionStack.top();
expressionStack.pop();
boost::any IdentifierSubstitutionVisitor<MapType>::visit(BinaryRelationExpression const& expression) {
std::shared_ptr<BaseExpression const> firstExpression = boost::any_cast<std::shared_ptr<BaseExpression>>(expression.getFirstOperand()->accept(*this));
std::shared_ptr<BaseExpression const> secondExpression = boost::any_cast<std::shared_ptr<BaseExpression>>(expression.getSecondOperand()->accept(*this));
// If the arguments did not change, we simply push the expression itself.
if (firstExpression.get() == expression->getFirstOperand().get() && secondExpression.get() == expression->getSecondOperand().get()) {
this->expressionStack.push(expression->getSharedPointer());
if (firstExpression.get() == expression.getFirstOperand().get() && secondExpression.get() == expression.getSecondOperand().get()) {
return expression.getSharedPointer();
} else {
this->expressionStack.push(std::shared_ptr<BaseExpression>(new BinaryRelationExpression(expression->getReturnType(), firstExpression, secondExpression, expression->getRelationType())));
return std::shared_ptr<BaseExpression>(new BinaryRelationExpression(expression.getReturnType(), firstExpression, secondExpression, expression.getRelationType()));
}
}
template<typename MapType>
void IdentifierSubstitutionVisitor<MapType>::visit(VariableExpression const* expression) {
boost::any IdentifierSubstitutionVisitor<MapType>::visit(VariableExpression const& expression) {
// If the variable is in the key set of the substitution, we need to replace it.
auto const& namePair = this->identifierToIdentifierMap.find(expression->getVariableName());
auto const& namePair = this->identifierToIdentifierMap.find(expression.getVariableName());
if (namePair != this->identifierToIdentifierMap.end()) {
this->expressionStack.push(std::shared_ptr<BaseExpression>(new VariableExpression(expression->getReturnType(), namePair->second)));
return std::shared_ptr<BaseExpression>(new VariableExpression(expression.getReturnType(), namePair->second));
} else {
this->expressionStack.push(expression->getSharedPointer());
return expression.getSharedPointer();
}
}
template<typename MapType>
void IdentifierSubstitutionVisitor<MapType>::visit(UnaryBooleanFunctionExpression const* expression) {
expression->getOperand()->accept(this);
std::shared_ptr<BaseExpression const> operandExpression = expressionStack.top();
expressionStack.pop();
boost::any IdentifierSubstitutionVisitor<MapType>::visit(UnaryBooleanFunctionExpression const& expression) {
std::shared_ptr<BaseExpression const> operandExpression = boost::any_cast<std::shared_ptr<BaseExpression>>(expression.getOperand()->accept(*this));
// If the argument did not change, we simply push the expression itself.
if (operandExpression.get() == expression->getOperand().get()) {
expressionStack.push(expression->getSharedPointer());
if (operandExpression.get() == expression.getOperand().get()) {
return expression.getSharedPointer();
} else {
expressionStack.push(std::shared_ptr<BaseExpression>(new UnaryBooleanFunctionExpression(expression->getReturnType(), operandExpression, expression->getOperatorType())));
return std::shared_ptr<BaseExpression>(new UnaryBooleanFunctionExpression(expression.getReturnType(), operandExpression, expression.getOperatorType()));
}
}
template<typename MapType>
void IdentifierSubstitutionVisitor<MapType>::visit(UnaryNumericalFunctionExpression const* expression) {
expression->getOperand()->accept(this);
std::shared_ptr<BaseExpression const> operandExpression = expressionStack.top();
expressionStack.pop();
boost::any IdentifierSubstitutionVisitor<MapType>::visit(UnaryNumericalFunctionExpression const& expression) {
std::shared_ptr<BaseExpression const> operandExpression = boost::any_cast<std::shared_ptr<BaseExpression>>(expression.getOperand()->accept(*this));
// If the argument did not change, we simply push the expression itself.
if (operandExpression.get() == expression->getOperand().get()) {
expressionStack.push(expression->getSharedPointer());
if (operandExpression.get() == expression.getOperand().get()) {
return expression.getSharedPointer();
} else {
expressionStack.push(std::shared_ptr<BaseExpression>(new UnaryNumericalFunctionExpression(expression->getReturnType(), operandExpression, expression->getOperatorType())));
return std::shared_ptr<BaseExpression>(new UnaryNumericalFunctionExpression(expression.getReturnType(), operandExpression, expression.getOperatorType()));
}
}
template<typename MapType>
void IdentifierSubstitutionVisitor<MapType>::visit(BooleanLiteralExpression const* expression) {
this->expressionStack.push(expression->getSharedPointer());
boost::any IdentifierSubstitutionVisitor<MapType>::visit(BooleanLiteralExpression const& expression) {
return expression.getSharedPointer();
}
template<typename MapType>
void IdentifierSubstitutionVisitor<MapType>::visit(IntegerLiteralExpression const* expression) {
this->expressionStack.push(expression->getSharedPointer());
boost::any IdentifierSubstitutionVisitor<MapType>::visit(IntegerLiteralExpression const& expression) {
return expression.getSharedPointer();
}
template<typename MapType>
void IdentifierSubstitutionVisitor<MapType>::visit(DoubleLiteralExpression const* expression) {
this->expressionStack.push(expression->getSharedPointer());
boost::any IdentifierSubstitutionVisitor<MapType>::visit(DoubleLiteralExpression const& expression) {
return expression.getSharedPointer();
}
// Explicitly instantiate the class with map and unordered_map.

23
src/storage/expressions/IdentifierSubstitutionVisitor.h

@ -28,21 +28,18 @@ namespace storm {
*/
Expression substitute(Expression const& expression);
virtual void visit(IfThenElseExpression const* expression) override;
virtual void visit(BinaryBooleanFunctionExpression const* expression) override;
virtual void visit(BinaryNumericalFunctionExpression const* expression) override;
virtual void visit(BinaryRelationExpression const* expression) override;
virtual void visit(VariableExpression const* expression) override;
virtual void visit(UnaryBooleanFunctionExpression const* expression) override;
virtual void visit(UnaryNumericalFunctionExpression const* expression) override;
virtual void visit(BooleanLiteralExpression const* expression) override;
virtual void visit(IntegerLiteralExpression const* expression) override;
virtual void visit(DoubleLiteralExpression const* expression) override;
virtual boost::any visit(IfThenElseExpression const& expression) override;
virtual boost::any visit(BinaryBooleanFunctionExpression const& expression) override;
virtual boost::any visit(BinaryNumericalFunctionExpression const& expression) override;
virtual boost::any visit(BinaryRelationExpression const& expression) override;
virtual boost::any visit(VariableExpression const& expression) override;
virtual boost::any visit(UnaryBooleanFunctionExpression const& expression) override;
virtual boost::any visit(UnaryNumericalFunctionExpression const& expression) override;
virtual boost::any visit(BooleanLiteralExpression const& expression) override;
virtual boost::any visit(IntegerLiteralExpression const& expression) override;
virtual boost::any visit(DoubleLiteralExpression const& expression) override;
private:
// A stack of expression used to pass the results to the higher levels.
std::stack<std::shared_ptr<BaseExpression const>> expressionStack;
// A mapping of identifier names to expressions with which they shall be replaced.
MapType const& identifierToIdentifierMap;
};

4
src/storage/expressions/IfThenElseExpression.cpp

@ -99,8 +99,8 @@ namespace storm {
}
}
void IfThenElseExpression::accept(ExpressionVisitor* visitor) const {
visitor->visit(this);
boost::any IfThenElseExpression::accept(ExpressionVisitor& visitor) const {
return visitor.visit(*this);
}
std::shared_ptr<BaseExpression const> IfThenElseExpression::getCondition() const {

2
src/storage/expressions/IfThenElseExpression.h

@ -38,7 +38,7 @@ namespace storm {
virtual std::set<std::string> getVariables() const override;
virtual std::map<std::string, ExpressionReturnType> getVariablesAndTypes() const override;
virtual std::shared_ptr<BaseExpression const> simplify() const override;
virtual void accept(ExpressionVisitor* visitor) const override;
virtual boost::any accept(ExpressionVisitor& visitor) const override;
/*!
* Retrieves the condition expression of the if-then-else expression.

4
src/storage/expressions/IntegerLiteralExpression.cpp

@ -30,8 +30,8 @@ namespace storm {
return this->shared_from_this();
}
void IntegerLiteralExpression::accept(ExpressionVisitor* visitor) const {
visitor->visit(this);
boost::any IntegerLiteralExpression::accept(ExpressionVisitor& visitor) const {
return visitor.visit(*this);
}
int_fast64_t IntegerLiteralExpression::getValue() const {

2
src/storage/expressions/IntegerLiteralExpression.h

@ -31,7 +31,7 @@ namespace storm {
virtual std::set<std::string> getVariables() const override;
virtual std::map<std::string, ExpressionReturnType> getVariablesAndTypes() const override;
virtual std::shared_ptr<BaseExpression const> simplify() const override;
virtual void accept(ExpressionVisitor* visitor) const override;
virtual boost::any accept(ExpressionVisitor& visitor) const override;
/*!
* Retrieves the value of the integer literal.

83
src/storage/expressions/LinearCoefficientVisitor.cpp

@ -7,26 +7,22 @@
namespace storm {
namespace expressions {
std::pair<SimpleValuation, double> LinearCoefficientVisitor::getLinearCoefficients(Expression const& expression) {
expression.getBaseExpression().accept(this);
return resultStack.top();
return boost::any_cast<std::pair<SimpleValuation, double>>(expression.getBaseExpression().accept(*this));
}
void LinearCoefficientVisitor::visit(IfThenElseExpression const* expression) {
boost::any LinearCoefficientVisitor::visit(IfThenElseExpression const& expression) {
STORM_LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Expression is non-linear.");
}
void LinearCoefficientVisitor::visit(BinaryBooleanFunctionExpression const* expression) {
boost::any LinearCoefficientVisitor::visit(BinaryBooleanFunctionExpression const& expression) {
STORM_LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Expression is non-linear.");
}
void LinearCoefficientVisitor::visit(BinaryNumericalFunctionExpression const* expression) {
if (expression->getOperatorType() == BinaryNumericalFunctionExpression::OperatorType::Plus) {
expression->getFirstOperand()->accept(this);
std::pair<SimpleValuation, double> leftResult = resultStack.top();
resultStack.pop();
expression->getSecondOperand()->accept(this);
std::pair<SimpleValuation, double>& rightResult = resultStack.top();
boost::any LinearCoefficientVisitor::visit(BinaryNumericalFunctionExpression const& expression) {
std::pair<SimpleValuation, double> leftResult = boost::any_cast<std::pair<SimpleValuation, double>>(expression.getFirstOperand()->accept(*this));
std::pair<SimpleValuation, double> rightResult = boost::any_cast<std::pair<SimpleValuation, double>>(expression.getSecondOperand()->accept(*this));
if (expression.getOperatorType() == BinaryNumericalFunctionExpression::OperatorType::Plus) {
// Now add the left result to the right result.
for (auto const& identifier : leftResult.first.getDoubleIdentifiers()) {
if (rightResult.first.containsDoubleIdentifier(identifier)) {
@ -36,14 +32,7 @@ namespace storm {
}
}
rightResult.second += leftResult.second;
return;
} else if (expression->getOperatorType() == BinaryNumericalFunctionExpression::OperatorType::Minus) {
expression->getFirstOperand()->accept(this);
std::pair<SimpleValuation, double> leftResult = resultStack.top();
resultStack.pop();
expression->getSecondOperand()->accept(this);
std::pair<SimpleValuation, double>& rightResult = resultStack.top();
} else if (expression.getOperatorType() == BinaryNumericalFunctionExpression::OperatorType::Minus) {
// Now subtract the right result from the left result.
for (auto const& identifier : leftResult.first.getDoubleIdentifiers()) {
if (rightResult.first.containsDoubleIdentifier(identifier)) {
@ -58,14 +47,7 @@ namespace storm {
}
}
rightResult.second = leftResult.second - rightResult.second;
return;
} else if (expression->getOperatorType() == BinaryNumericalFunctionExpression::OperatorType::Times) {
expression->getFirstOperand()->accept(this);
std::pair<SimpleValuation, double> leftResult = resultStack.top();
resultStack.pop();
expression->getSecondOperand()->accept(this);
std::pair<SimpleValuation, double>& rightResult = resultStack.top();
} else if (expression.getOperatorType() == BinaryNumericalFunctionExpression::OperatorType::Times) {
// If the expression is linear, either the left or the right side must not contain variables.
STORM_LOG_THROW(leftResult.first.getNumberOfIdentifiers() == 0 || rightResult.first.getNumberOfIdentifiers() == 0, storm::exceptions::InvalidArgumentException, "Expression is non-linear.");
if (leftResult.first.getNumberOfIdentifiers() == 0) {
@ -78,14 +60,7 @@ namespace storm {
}
}
rightResult.second *= leftResult.second;
return;
} else if (expression->getOperatorType() == BinaryNumericalFunctionExpression::OperatorType::Divide) {
expression->getFirstOperand()->accept(this);
std::pair<SimpleValuation, double> leftResult = resultStack.top();
resultStack.pop();
expression->getSecondOperand()->accept(this);
std::pair<SimpleValuation, double>& rightResult = resultStack.top();
} else if (expression.getOperatorType() == BinaryNumericalFunctionExpression::OperatorType::Divide) {
// If the expression is linear, either the left or the right side must not contain variables.
STORM_LOG_THROW(leftResult.first.getNumberOfIdentifiers() == 0 || rightResult.first.getNumberOfIdentifiers() == 0, storm::exceptions::InvalidArgumentException, "Expression is non-linear.");
if (leftResult.first.getNumberOfIdentifiers() == 0) {
@ -98,54 +73,56 @@ namespace storm {
}
}
rightResult.second = leftResult.second / leftResult.second;
return;
} else {
STORM_LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Expression is non-linear.");
}
return rightResult;
}
void LinearCoefficientVisitor::visit(BinaryRelationExpression const* expression) {
boost::any LinearCoefficientVisitor::visit(BinaryRelationExpression const& expression) {
STORM_LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Expression is non-linear.");
}
void LinearCoefficientVisitor::visit(VariableExpression const* expression) {
boost::any LinearCoefficientVisitor::visit(VariableExpression const& expression) {
SimpleValuation valuation;
switch (expression->getReturnType()) {
switch (expression.getReturnType()) {
case ExpressionReturnType::Bool: STORM_LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Expression is non-linear."); break;
case ExpressionReturnType::Int:
case ExpressionReturnType::Double: valuation.addDoubleIdentifier(expression->getVariableName(), 1); break;
case ExpressionReturnType::Double: valuation.addDoubleIdentifier(expression.getVariableName(), 1); break;
case ExpressionReturnType::Undefined: STORM_LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Illegal expression return type."); break;
}
resultStack.push(std::make_pair(valuation, 0));
return std::make_pair(valuation, static_cast<double>(0));
}
void LinearCoefficientVisitor::visit(UnaryBooleanFunctionExpression const* expression) {
boost::any LinearCoefficientVisitor::visit(UnaryBooleanFunctionExpression const& expression) {
STORM_LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Expression is non-linear.");
}
void LinearCoefficientVisitor::visit(UnaryNumericalFunctionExpression const* expression) {
if (expression->getOperatorType() == UnaryNumericalFunctionExpression::OperatorType::Minus) {
boost::any LinearCoefficientVisitor::visit(UnaryNumericalFunctionExpression const& expression) {
std::pair<SimpleValuation, double> childResult = boost::any_cast<std::pair<SimpleValuation, double>>(expression.getOperand()->accept(*this));
if (expression.getOperatorType() == UnaryNumericalFunctionExpression::OperatorType::Minus) {
// Here, we need to negate all double identifiers.
std::pair<SimpleValuation, double>& valuationConstantPair = resultStack.top();
for (auto const& identifier : valuationConstantPair.first.getDoubleIdentifiers()) {
valuationConstantPair.first.setDoubleValue(identifier, -valuationConstantPair.first.getDoubleValue(identifier));
for (auto const& identifier : childResult.first.getDoubleIdentifiers()) {
childResult.first.setDoubleValue(identifier, -childResult.first.getDoubleValue(identifier));
}
return childResult;
} else {
STORM_LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Expression is non-linear.");
}
}
void LinearCoefficientVisitor::visit(BooleanLiteralExpression const* expression) {
boost::any LinearCoefficientVisitor::visit(BooleanLiteralExpression const& expression) {
STORM_LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Expression is non-linear.");
}
void LinearCoefficientVisitor::visit(IntegerLiteralExpression const* expression) {
resultStack.push(std::make_pair(SimpleValuation(), static_cast<double>(expression->getValue())));
boost::any LinearCoefficientVisitor::visit(IntegerLiteralExpression const& expression) {
return std::make_pair(SimpleValuation(), static_cast<double>(expression.getValue()));
}
void LinearCoefficientVisitor::visit(DoubleLiteralExpression const* expression) {
resultStack.push(std::make_pair(SimpleValuation(), expression->getValue()));
boost::any LinearCoefficientVisitor::visit(DoubleLiteralExpression const& expression) {
return std::make_pair(SimpleValuation(), expression.getValue());
}
}
}

23
src/storage/expressions/LinearCoefficientVisitor.h

@ -26,19 +26,16 @@ namespace storm {
*/
std::pair<SimpleValuation, double> getLinearCoefficients(Expression const& expression);
virtual void visit(IfThenElseExpression const* expression) override;
virtual void visit(BinaryBooleanFunctionExpression const* expression) override;
virtual void visit(BinaryNumericalFunctionExpression const* expression) override;
virtual void visit(BinaryRelationExpression const* expression) override;
virtual void visit(VariableExpression const* expression) override;
virtual void visit(UnaryBooleanFunctionExpression const* expression) override;
virtual void visit(UnaryNumericalFunctionExpression const* expression) override;
virtual void visit(BooleanLiteralExpression const* expression) override;
virtual void visit(IntegerLiteralExpression const* expression) override;
virtual void visit(DoubleLiteralExpression const* expression) override;
private:
std::stack<std::pair<SimpleValuation, double>> resultStack;
virtual boost::any visit(IfThenElseExpression const& expression) override;
virtual boost::any visit(BinaryBooleanFunctionExpression const& expression) override;
virtual boost::any visit(BinaryNumericalFunctionExpression const& expression) override;
virtual boost::any visit(BinaryRelationExpression const& expression) override;
virtual boost::any visit(VariableExpression const& expression) override;
virtual boost::any visit(UnaryBooleanFunctionExpression const& expression) override;
virtual boost::any visit(UnaryNumericalFunctionExpression const& expression) override;
virtual boost::any visit(BooleanLiteralExpression const& expression) override;
virtual boost::any visit(IntegerLiteralExpression const& expression) override;
virtual boost::any visit(DoubleLiteralExpression const& expression) override;
};
}
}

106
src/storage/expressions/LinearityCheckVisitor.cpp

@ -6,108 +6,84 @@
namespace storm {
namespace expressions {
LinearityCheckVisitor::LinearityCheckVisitor() : resultStack() {
LinearityCheckVisitor::LinearityCheckVisitor() {
// Intentionally left empty.
}
bool LinearityCheckVisitor::check(Expression const& expression) {
expression.getBaseExpression().accept(this);
return resultStack.top() == LinearityStatus::LinearWithoutVariables || resultStack.top() == LinearityStatus::LinearContainsVariables;
LinearityStatus result = boost::any_cast<LinearityStatus>(expression.getBaseExpression().accept(*this));
return result == LinearityStatus::LinearWithoutVariables || result == LinearityStatus::LinearContainsVariables;
}
void LinearityCheckVisitor::visit(IfThenElseExpression const* expression) {
boost::any LinearityCheckVisitor::visit(IfThenElseExpression const& expression) {
// An if-then-else expression is never linear.
resultStack.push(LinearityStatus::NonLinear);
return LinearityStatus::NonLinear;
}
void LinearityCheckVisitor::visit(BinaryBooleanFunctionExpression const* expression) {
boost::any LinearityCheckVisitor::visit(BinaryBooleanFunctionExpression const& expression) {
// Boolean function applications are not allowed in linear expressions.
resultStack.push(LinearityStatus::NonLinear);
return LinearityStatus::NonLinear;
}
void LinearityCheckVisitor::visit(BinaryNumericalFunctionExpression const* expression) {
LinearityStatus leftResult;
LinearityStatus rightResult;
switch (expression->getOperatorType()) {
boost::any LinearityCheckVisitor::visit(BinaryNumericalFunctionExpression const& expression) {
LinearityStatus leftResult = boost::any_cast<LinearityStatus>(expression.getFirstOperand()->accept(*this));
if (leftResult == LinearityStatus::NonLinear) {
return LinearityStatus::NonLinear;
}
LinearityStatus rightResult = boost::any_cast<LinearityStatus>(expression.getSecondOperand()->accept(*this));
if (rightResult == LinearityStatus::NonLinear) {
return LinearityStatus::NonLinear;
}
switch (expression.getOperatorType()) {
case BinaryNumericalFunctionExpression::OperatorType::Plus:
case BinaryNumericalFunctionExpression::OperatorType::Minus:
expression->getFirstOperand()->accept(this);
leftResult = resultStack.top();
if (leftResult == LinearityStatus::NonLinear) {
return;
} else {
resultStack.pop();
expression->getSecondOperand()->accept(this);
rightResult = resultStack.top();
if (rightResult == LinearityStatus::NonLinear) {
return;
}
resultStack.pop();
}
resultStack.push(leftResult == LinearityStatus::LinearContainsVariables || rightResult == LinearityStatus::LinearContainsVariables ? LinearityStatus::LinearContainsVariables : LinearityStatus::LinearWithoutVariables);
break;
return (leftResult == LinearityStatus::LinearContainsVariables || rightResult == LinearityStatus::LinearContainsVariables ? LinearityStatus::LinearContainsVariables : LinearityStatus::LinearWithoutVariables);
case BinaryNumericalFunctionExpression::OperatorType::Times:
case BinaryNumericalFunctionExpression::OperatorType::Divide:
expression->getFirstOperand()->accept(this);
leftResult = resultStack.top();
if (leftResult == LinearityStatus::NonLinear) {
return;
} else {
resultStack.pop();
expression->getSecondOperand()->accept(this);
rightResult = resultStack.top();
if (rightResult == LinearityStatus::NonLinear) {
return;
}
resultStack.pop();
}
if (leftResult == LinearityStatus::LinearContainsVariables && rightResult == LinearityStatus::LinearContainsVariables) {
resultStack.push(LinearityStatus::NonLinear);
return LinearityStatus::NonLinear;
}
resultStack.push(leftResult == LinearityStatus::LinearContainsVariables || rightResult == LinearityStatus::LinearContainsVariables ? LinearityStatus::LinearContainsVariables : LinearityStatus::LinearWithoutVariables);
break;
case BinaryNumericalFunctionExpression::OperatorType::Min: resultStack.push(LinearityStatus::NonLinear); break;
case BinaryNumericalFunctionExpression::OperatorType::Max: resultStack.push(LinearityStatus::NonLinear); break;
case BinaryNumericalFunctionExpression::OperatorType::Power: resultStack.push(LinearityStatus::NonLinear); break;
return (leftResult == LinearityStatus::LinearContainsVariables || rightResult == LinearityStatus::LinearContainsVariables ? LinearityStatus::LinearContainsVariables : LinearityStatus::LinearWithoutVariables);
case BinaryNumericalFunctionExpression::OperatorType::Min: return LinearityStatus::NonLinear; break;
case BinaryNumericalFunctionExpression::OperatorType::Max: return LinearityStatus::NonLinear; break;
case BinaryNumericalFunctionExpression::OperatorType::Power: return LinearityStatus::NonLinear; break;
}
}
void LinearityCheckVisitor::visit(BinaryRelationExpression const* expression) {
resultStack.push(LinearityStatus::NonLinear);
boost::any LinearityCheckVisitor::visit(BinaryRelationExpression const& expression) {
return LinearityStatus::NonLinear;
}
void LinearityCheckVisitor::visit(VariableExpression const* expression) {
resultStack.push(LinearityStatus::LinearContainsVariables);
boost::any LinearityCheckVisitor::visit(VariableExpression const& expression) {
return LinearityStatus::LinearContainsVariables;
}
void LinearityCheckVisitor::visit(UnaryBooleanFunctionExpression const* expression) {
boost::any LinearityCheckVisitor::visit(UnaryBooleanFunctionExpression const& expression) {
// Boolean function applications are not allowed in linear expressions.
resultStack.push(LinearityStatus::NonLinear);
return LinearityStatus::NonLinear;
}
void LinearityCheckVisitor::visit(UnaryNumericalFunctionExpression const* expression) {
switch (expression->getOperatorType()) {
case UnaryNumericalFunctionExpression::OperatorType::Minus: break;
boost::any LinearityCheckVisitor::visit(UnaryNumericalFunctionExpression const& expression) {
switch (expression.getOperatorType()) {
case UnaryNumericalFunctionExpression::OperatorType::Minus: return expression.getOperand()->accept(*this);
case UnaryNumericalFunctionExpression::OperatorType::Floor:
case UnaryNumericalFunctionExpression::OperatorType::Ceil: resultStack.pop(); resultStack.push(LinearityStatus::NonLinear); break;
case UnaryNumericalFunctionExpression::OperatorType::Ceil: return LinearityStatus::NonLinear;
}
}
void LinearityCheckVisitor::visit(BooleanLiteralExpression const* expression) {
resultStack.push(LinearityStatus::NonLinear);
boost::any LinearityCheckVisitor::visit(BooleanLiteralExpression const& expression) {
return LinearityStatus::NonLinear;
}
void LinearityCheckVisitor::visit(IntegerLiteralExpression const* expression) {
resultStack.push(LinearityStatus::LinearWithoutVariables);
boost::any LinearityCheckVisitor::visit(IntegerLiteralExpression const& expression) {
return LinearityStatus::LinearWithoutVariables;
}
void LinearityCheckVisitor::visit(DoubleLiteralExpression const* expression) {
resultStack.push(LinearityStatus::LinearWithoutVariables);
boost::any LinearityCheckVisitor::visit(DoubleLiteralExpression const& expression) {
return LinearityStatus::LinearWithoutVariables;
}
}
}

23
src/storage/expressions/LinearityCheckVisitor.h

@ -22,22 +22,19 @@ namespace storm {
*/
bool check(Expression const& expression);
virtual void visit(IfThenElseExpression const* expression) override;
virtual void visit(BinaryBooleanFunctionExpression const* expression) override;
virtual void visit(BinaryNumericalFunctionExpression const* expression) override;
virtual void visit(BinaryRelationExpression const* expression) override;
virtual void visit(VariableExpression const* expression) override;
virtual void visit(UnaryBooleanFunctionExpression const* expression) override;
virtual void visit(UnaryNumericalFunctionExpression const* expression) override;
virtual void visit(BooleanLiteralExpression const* expression) override;
virtual void visit(IntegerLiteralExpression const* expression) override;
virtual void visit(DoubleLiteralExpression const* expression) override;
virtual boost::any visit(IfThenElseExpression const& expression) override;
virtual boost::any visit(BinaryBooleanFunctionExpression const& expression) override;
virtual boost::any visit(BinaryNumericalFunctionExpression const& expression) override;
virtual boost::any visit(BinaryRelationExpression const& expression) override;
virtual boost::any visit(VariableExpression const& expression) override;
virtual boost::any visit(UnaryBooleanFunctionExpression const& expression) override;
virtual boost::any visit(UnaryNumericalFunctionExpression const& expression) override;
virtual boost::any visit(BooleanLiteralExpression const& expression) override;
virtual boost::any visit(IntegerLiteralExpression const& expression) override;
virtual boost::any visit(DoubleLiteralExpression const& expression) override;
private:
enum class LinearityStatus { NonLinear, LinearContainsVariables, LinearWithoutVariables };
// A stack for communicating the results of the subexpressions.
std::stack<LinearityStatus> resultStack;
};
}
}

120
src/storage/expressions/SubstitutionVisitor.cpp

@ -14,138 +14,110 @@ namespace storm {
template<typename MapType>
Expression SubstitutionVisitor<MapType>::substitute(Expression const& expression) {
expression.getBaseExpression().accept(this);
return Expression(this->expressionStack.top());
return Expression(boost::any_cast<std::shared_ptr<BaseExpression const>>(expression.getBaseExpression().accept(*this)));
}
template<typename MapType>
void SubstitutionVisitor<MapType>::visit(IfThenElseExpression const* expression) {
expression->getCondition()->accept(this);
std::shared_ptr<BaseExpression const> conditionExpression = expressionStack.top();
expressionStack.pop();
expression->getThenExpression()->accept(this);
std::shared_ptr<BaseExpression const> thenExpression = expressionStack.top();
expressionStack.pop();
expression->getElseExpression()->accept(this);
std::shared_ptr<BaseExpression const> elseExpression = expressionStack.top();
expressionStack.pop();
boost::any SubstitutionVisitor<MapType>::visit(IfThenElseExpression const& expression) {
std::shared_ptr<BaseExpression const> conditionExpression = boost::any_cast<std::shared_ptr<BaseExpression const>>(expression.getCondition()->accept(*this));
std::shared_ptr<BaseExpression const> thenExpression = boost::any_cast<std::shared_ptr<BaseExpression const>>(expression.getThenExpression()->accept(*this));
std::shared_ptr<BaseExpression const> elseExpression = boost::any_cast<std::shared_ptr<BaseExpression const>>(expression.getElseExpression()->accept(*this));
// If the arguments did not change, we simply push the expression itself.
if (conditionExpression.get() == expression->getCondition().get() && thenExpression.get() == expression->getThenExpression().get() && elseExpression.get() == expression->getElseExpression().get()) {
this->expressionStack.push(expression->getSharedPointer());
if (conditionExpression.get() == expression.getCondition().get() && thenExpression.get() == expression.getThenExpression().get() && elseExpression.get() == expression.getElseExpression().get()) {
return expression.getSharedPointer();
} else {
this->expressionStack.push(std::shared_ptr<BaseExpression>(new IfThenElseExpression(expression->getReturnType(), conditionExpression, thenExpression, elseExpression)));
return static_cast<std::shared_ptr<BaseExpression const>>(std::shared_ptr<BaseExpression>(new IfThenElseExpression(expression.getReturnType(), conditionExpression, thenExpression, elseExpression)));
}
}
template<typename MapType>
void SubstitutionVisitor<MapType>::visit(BinaryBooleanFunctionExpression const* expression) {
expression->getFirstOperand()->accept(this);
std::shared_ptr<BaseExpression const> firstExpression = expressionStack.top();
expressionStack.pop();
expression->getSecondOperand()->accept(this);
std::shared_ptr<BaseExpression const> secondExpression = expressionStack.top();
expressionStack.pop();
boost::any SubstitutionVisitor<MapType>::visit(BinaryBooleanFunctionExpression const& expression) {
std::shared_ptr<BaseExpression const> firstExpression = boost::any_cast<std::shared_ptr<BaseExpression const>>(expression.getFirstOperand()->accept(*this));
std::shared_ptr<BaseExpression const> secondExpression = boost::any_cast<std::shared_ptr<BaseExpression const>>(expression.getSecondOperand()->accept(*this));
// If the arguments did not change, we simply push the expression itself.
if (firstExpression.get() == expression->getFirstOperand().get() && secondExpression.get() == expression->getSecondOperand().get()) {
this->expressionStack.push(expression->getSharedPointer());
if (firstExpression.get() == expression.getFirstOperand().get() && secondExpression.get() == expression.getSecondOperand().get()) {
return expression.getSharedPointer();
} else {
this->expressionStack.push(std::shared_ptr<BaseExpression>(new BinaryBooleanFunctionExpression(expression->getReturnType(), firstExpression, secondExpression, expression->getOperatorType())));
return static_cast<std::shared_ptr<BaseExpression const>>(std::shared_ptr<BaseExpression>(new BinaryBooleanFunctionExpression(expression.getReturnType(), firstExpression, secondExpression, expression.getOperatorType())));
}
}
template<typename MapType>
void SubstitutionVisitor<MapType>::visit(BinaryNumericalFunctionExpression const* expression) {
expression->getFirstOperand()->accept(this);
std::shared_ptr<BaseExpression const> firstExpression = expressionStack.top();
expressionStack.pop();
expression->getSecondOperand()->accept(this);
std::shared_ptr<BaseExpression const> secondExpression = expressionStack.top();
expressionStack.pop();
boost::any SubstitutionVisitor<MapType>::visit(BinaryNumericalFunctionExpression const& expression) {
std::shared_ptr<BaseExpression const> firstExpression = boost::any_cast<std::shared_ptr<BaseExpression const>>(expression.getFirstOperand()->accept(*this));
std::shared_ptr<BaseExpression const> secondExpression = boost::any_cast<std::shared_ptr<BaseExpression const>>(expression.getSecondOperand()->accept(*this));
// If the arguments did not change, we simply push the expression itself.
if (firstExpression.get() == expression->getFirstOperand().get() && secondExpression.get() == expression->getSecondOperand().get()) {
this->expressionStack.push(expression->getSharedPointer());
if (firstExpression.get() == expression.getFirstOperand().get() && secondExpression.get() == expression.getSecondOperand().get()) {
return expression.getSharedPointer();
} else {
this->expressionStack.push(std::shared_ptr<BaseExpression>(new BinaryNumericalFunctionExpression(expression->getReturnType(), firstExpression, secondExpression, expression->getOperatorType())));
return static_cast<std::shared_ptr<BaseExpression const>>(std::shared_ptr<BaseExpression>(new BinaryNumericalFunctionExpression(expression.getReturnType(), firstExpression, secondExpression, expression.getOperatorType())));
}
}
template<typename MapType>
void SubstitutionVisitor<MapType>::visit(BinaryRelationExpression const* expression) {
expression->getFirstOperand()->accept(this);
std::shared_ptr<BaseExpression const> firstExpression = expressionStack.top();
expressionStack.pop();
expression->getSecondOperand()->accept(this);
std::shared_ptr<BaseExpression const> secondExpression = expressionStack.top();
expressionStack.pop();
boost::any SubstitutionVisitor<MapType>::visit(BinaryRelationExpression const& expression) {
std::shared_ptr<BaseExpression const> firstExpression = boost::any_cast<std::shared_ptr<BaseExpression const>>(expression.getFirstOperand()->accept(*this));
std::shared_ptr<BaseExpression const> secondExpression = boost::any_cast<std::shared_ptr<BaseExpression const>>(expression.getSecondOperand()->accept(*this));
// If the arguments did not change, we simply push the expression itself.
if (firstExpression.get() == expression->getFirstOperand().get() && secondExpression.get() == expression->getSecondOperand().get()) {
this->expressionStack.push(expression->getSharedPointer());
if (firstExpression.get() == expression.getFirstOperand().get() && secondExpression.get() == expression.getSecondOperand().get()) {
return expression.getSharedPointer();
} else {
this->expressionStack.push(std::shared_ptr<BaseExpression>(new BinaryRelationExpression(expression->getReturnType(), firstExpression, secondExpression, expression->getRelationType())));
return static_cast<std::shared_ptr<BaseExpression const>>(std::shared_ptr<BaseExpression>(new BinaryRelationExpression(expression.getReturnType(), firstExpression, secondExpression, expression.getRelationType())));
}
}
template<typename MapType>
void SubstitutionVisitor<MapType>::visit(VariableExpression const* expression) {
boost::any SubstitutionVisitor<MapType>::visit(VariableExpression const& expression) {
// If the variable is in the key set of the substitution, we need to replace it.
auto const& nameExpressionPair = this->identifierToExpressionMap.find(expression->getVariableName());
auto const& nameExpressionPair = this->identifierToExpressionMap.find(expression.getVariableName());
if (nameExpressionPair != this->identifierToExpressionMap.end()) {
this->expressionStack.push(nameExpressionPair->second.getBaseExpressionPointer());
return nameExpressionPair->second.getBaseExpressionPointer();
} else {
this->expressionStack.push(expression->getSharedPointer());
return expression.getSharedPointer();
}
}
template<typename MapType>
void SubstitutionVisitor<MapType>::visit(UnaryBooleanFunctionExpression const* expression) {
expression->getOperand()->accept(this);
std::shared_ptr<BaseExpression const> operandExpression = expressionStack.top();
expressionStack.pop();
boost::any SubstitutionVisitor<MapType>::visit(UnaryBooleanFunctionExpression const& expression) {
std::shared_ptr<BaseExpression const> operandExpression = boost::any_cast<std::shared_ptr<BaseExpression const>>(expression.getOperand()->accept(*this));
// If the argument did not change, we simply push the expression itself.
if (operandExpression.get() == expression->getOperand().get()) {
expressionStack.push(expression->getSharedPointer());
if (operandExpression.get() == expression.getOperand().get()) {
return expression.getSharedPointer();
} else {
expressionStack.push(std::shared_ptr<BaseExpression>(new UnaryBooleanFunctionExpression(expression->getReturnType(), operandExpression, expression->getOperatorType())));
return static_cast<std::shared_ptr<BaseExpression const>>(std::shared_ptr<BaseExpression>(new UnaryBooleanFunctionExpression(expression.getReturnType(), operandExpression, expression.getOperatorType())));
}
}
template<typename MapType>
void SubstitutionVisitor<MapType>::visit(UnaryNumericalFunctionExpression const* expression) {
expression->getOperand()->accept(this);
std::shared_ptr<BaseExpression const> operandExpression = expressionStack.top();
expressionStack.pop();
boost::any SubstitutionVisitor<MapType>::visit(UnaryNumericalFunctionExpression const& expression) {
std::shared_ptr<BaseExpression const> operandExpression = boost::any_cast<std::shared_ptr<BaseExpression const>>(expression.getOperand()->accept(*this));
// If the argument did not change, we simply push the expression itself.
if (operandExpression.get() == expression->getOperand().get()) {
expressionStack.push(expression->getSharedPointer());
if (operandExpression.get() == expression.getOperand().get()) {
return expression.getSharedPointer();
} else {
expressionStack.push(std::shared_ptr<BaseExpression>(new UnaryNumericalFunctionExpression(expression->getReturnType(), operandExpression, expression->getOperatorType())));
return static_cast<std::shared_ptr<BaseExpression const>>(std::shared_ptr<BaseExpression>(new UnaryNumericalFunctionExpression(expression.getReturnType(), operandExpression, expression.getOperatorType())));
}
}
template<typename MapType>
void SubstitutionVisitor<MapType>::visit(BooleanLiteralExpression const* expression) {
this->expressionStack.push(expression->getSharedPointer());
boost::any SubstitutionVisitor<MapType>::visit(BooleanLiteralExpression const& expression) {
return expression.getSharedPointer();
}
template<typename MapType>
void SubstitutionVisitor<MapType>::visit(IntegerLiteralExpression const* expression) {
this->expressionStack.push(expression->getSharedPointer());
boost::any SubstitutionVisitor<MapType>::visit(IntegerLiteralExpression const& expression) {
return expression.getSharedPointer();
}
template<typename MapType>
void SubstitutionVisitor<MapType>::visit(DoubleLiteralExpression const* expression) {
this->expressionStack.push(expression->getSharedPointer());
boost::any SubstitutionVisitor<MapType>::visit(DoubleLiteralExpression const& expression) {
return expression.getSharedPointer();
}
// Explicitly instantiate the class with map and unordered_map.

23
src/storage/expressions/SubstitutionVisitor.h

@ -28,21 +28,18 @@ namespace storm {
*/
Expression substitute(Expression const& expression);
virtual void visit(IfThenElseExpression const* expression) override;
virtual void visit(BinaryBooleanFunctionExpression const* expression) override;
virtual void visit(BinaryNumericalFunctionExpression const* expression) override;
virtual void visit(BinaryRelationExpression const* expression) override;
virtual void visit(VariableExpression const* expression) override;
virtual void visit(UnaryBooleanFunctionExpression const* expression) override;
virtual void visit(UnaryNumericalFunctionExpression const* expression) override;
virtual void visit(BooleanLiteralExpression const* expression) override;
virtual void visit(IntegerLiteralExpression const* expression) override;
virtual void visit(DoubleLiteralExpression const* expression) override;
virtual boost::any visit(IfThenElseExpression const& expression) override;
virtual boost::any visit(BinaryBooleanFunctionExpression const& expression) override;
virtual boost::any visit(BinaryNumericalFunctionExpression const& expression) override;
virtual boost::any visit(BinaryRelationExpression const& expression) override;
virtual boost::any visit(VariableExpression const& expression) override;
virtual boost::any visit(UnaryBooleanFunctionExpression const& expression) override;
virtual boost::any visit(UnaryNumericalFunctionExpression const& expression) override;
virtual boost::any visit(BooleanLiteralExpression const& expression) override;
virtual boost::any visit(IntegerLiteralExpression const& expression) override;
virtual boost::any visit(DoubleLiteralExpression const& expression) override;
private:
// A stack of expression used to pass the results to the higher levels.
std::stack<std::shared_ptr<BaseExpression const>> expressionStack;
// A mapping of identifier names to expressions with which they shall be replaced.
MapType const& identifierToExpressionMap;
};

80
src/storage/expressions/TypeCheckVisitor.cpp

@ -1,80 +0,0 @@
#include "src/storage/expressions/TypeCheckVisitor.h"
#include "src/storage/expressions/Expressions.h"
#include "src/utility/macros.h"
#include "src/exceptions/InvalidTypeException.h"
namespace storm {
namespace expressions {
template<typename MapType>
TypeCheckVisitor<MapType>::TypeCheckVisitor(MapType const& identifierToTypeMap) : identifierToTypeMap(identifierToTypeMap) {
// Intentionally left empty.
}
template<typename MapType>
void TypeCheckVisitor<MapType>::check(Expression const& expression) {
expression.getBaseExpression().accept(this);
}
template<typename MapType>
void TypeCheckVisitor<MapType>::visit(IfThenElseExpression const* expression) {
expression->getCondition()->accept(this);
expression->getThenExpression()->accept(this);
expression->getElseExpression()->accept(this);
}
template<typename MapType>
void TypeCheckVisitor<MapType>::visit(BinaryBooleanFunctionExpression const* expression) {
expression->getFirstOperand()->accept(this);
expression->getSecondOperand()->accept(this);
}
template<typename MapType>
void TypeCheckVisitor<MapType>::visit(BinaryNumericalFunctionExpression const* expression) {
expression->getFirstOperand()->accept(this);
expression->getSecondOperand()->accept(this);
}
template<typename MapType>
void TypeCheckVisitor<MapType>::visit(BinaryRelationExpression const* expression) {
expression->getFirstOperand()->accept(this);
expression->getSecondOperand()->accept(this);
}
template<typename MapType>
void TypeCheckVisitor<MapType>::visit(VariableExpression const* expression) {
auto identifierTypePair = this->identifierToTypeMap.find(expression->getVariableName());
STORM_LOG_THROW(identifierTypePair != this->identifierToTypeMap.end(), storm::exceptions::InvalidArgumentException, "No type available for identifier '" << expression->getVariableName() << "'.");
STORM_LOG_THROW(identifierTypePair->second == expression->getReturnType(), storm::exceptions::InvalidTypeException, "Type mismatch for variable '" << expression->getVariableName() << "': expected '" << identifierTypePair->first << "', but found '" << expression->getReturnType() << "'.");
}
template<typename MapType>
void TypeCheckVisitor<MapType>::visit(UnaryBooleanFunctionExpression const* expression) {
expression->getOperand()->accept(this);
}
template<typename MapType>
void TypeCheckVisitor<MapType>::visit(UnaryNumericalFunctionExpression const* expression) {
expression->getOperand()->accept(this);
}
template<typename MapType>
void TypeCheckVisitor<MapType>::visit(BooleanLiteralExpression const* expression) {
// Intentionally left empty.
}
template<typename MapType>
void TypeCheckVisitor<MapType>::visit(IntegerLiteralExpression const* expression) {
// Intentionally left empty.
}
template<typename MapType>
void TypeCheckVisitor<MapType>::visit(DoubleLiteralExpression const* expression) {
// Intentionally left empty.
}
// Explicitly instantiate the class with map and unordered_map.
template class TypeCheckVisitor<std::map<std::string, ExpressionReturnType>>;
template class TypeCheckVisitor<std::unordered_map<std::string, ExpressionReturnType>>;
}
}

47
src/storage/expressions/TypeCheckVisitor.h

@ -1,47 +0,0 @@
#ifndef STORM_STORAGE_EXPRESSIONS_TYPECHECKVISITOR_H_
#define STORM_STORAGE_EXPRESSIONS_TYPECHECKVISITOR_H_
#include <stack>
#include "src/storage/expressions/Expression.h"
#include "src/storage/expressions/ExpressionVisitor.h"
namespace storm {
namespace expressions {
template<typename MapType>
class TypeCheckVisitor : public ExpressionVisitor {
public:
/*!
* Creates a new type check visitor that uses the given map to check the types of variables and constants.
*
* @param identifierToTypeMap A mapping from identifiers to expressions.
*/
TypeCheckVisitor(MapType const& identifierToTypeMap);
/*!
* Checks that the types of the identifiers in the given expression match the ones in the previously given
* map.
*
* @param expression The expression in which to check the types.
*/
void check(Expression const& expression);
virtual void visit(IfThenElseExpression const* expression) override;
virtual void visit(BinaryBooleanFunctionExpression const* expression) override;
virtual void visit(BinaryNumericalFunctionExpression const* expression) override;
virtual void visit(BinaryRelationExpression const* expression) override;
virtual void visit(VariableExpression const* expression) override;
virtual void visit(UnaryBooleanFunctionExpression const* expression) override;
virtual void visit(UnaryNumericalFunctionExpression const* expression) override;
virtual void visit(BooleanLiteralExpression const* expression) override;
virtual void visit(IntegerLiteralExpression const* expression) override;
virtual void visit(DoubleLiteralExpression const* expression) override;
private:
// A mapping of identifier names to expressions with which they shall be replaced.
MapType const& identifierToTypeMap;
};
}
}
#endif /* STORM_STORAGE_EXPRESSIONS_TYPECHECKVISITOR_H_ */

4
src/storage/expressions/UnaryBooleanFunctionExpression.cpp

@ -45,8 +45,8 @@ namespace storm {
}
}
void UnaryBooleanFunctionExpression::accept(ExpressionVisitor* visitor) const {
visitor->visit(this);
boost::any UnaryBooleanFunctionExpression::accept(ExpressionVisitor& visitor) const {
return visitor.visit(*this);
}
void UnaryBooleanFunctionExpression::printToStream(std::ostream& stream) const {

2
src/storage/expressions/UnaryBooleanFunctionExpression.h

@ -35,7 +35,7 @@ namespace storm {
virtual storm::expressions::OperatorType getOperator() const override;
virtual bool evaluateAsBool(Valuation const* valuation = nullptr) const override;
virtual std::shared_ptr<BaseExpression const> simplify() const override;
virtual void accept(ExpressionVisitor* visitor) const override;
virtual boost::any accept(ExpressionVisitor& visitor) const override;
/*!
* Retrieves the operator associated with this expression.

4
src/storage/expressions/UnaryNumericalFunctionExpression.cpp

@ -54,8 +54,8 @@ namespace storm {
}
}
void UnaryNumericalFunctionExpression::accept(ExpressionVisitor* visitor) const {
visitor->visit(this);
boost::any UnaryNumericalFunctionExpression::accept(ExpressionVisitor& visitor) const {
return visitor.visit(*this);
}
void UnaryNumericalFunctionExpression::printToStream(std::ostream& stream) const {

2
src/storage/expressions/UnaryNumericalFunctionExpression.h

@ -36,7 +36,7 @@ namespace storm {
virtual int_fast64_t evaluateAsInt(Valuation const* valuation = nullptr) const override;
virtual double evaluateAsDouble(Valuation const* valuation = nullptr) const override;
virtual std::shared_ptr<BaseExpression const> simplify() const override;
virtual void accept(ExpressionVisitor* visitor) const override;
virtual boost::any accept(ExpressionVisitor& visitor) const override;
/*!
* Retrieves the operator associated with this expression.

4
src/storage/expressions/VariableExpression.cpp

@ -65,8 +65,8 @@ namespace storm {
return this->shared_from_this();
}
void VariableExpression::accept(ExpressionVisitor* visitor) const {
visitor->visit(this);
boost::any VariableExpression::accept(ExpressionVisitor& visitor) const {
return visitor.visit(*this);
}
void VariableExpression::printToStream(std::ostream& stream) const {

2
src/storage/expressions/VariableExpression.h

@ -35,7 +35,7 @@ namespace storm {
virtual std::set<std::string> getVariables() const override;
virtual std::map<std::string, ExpressionReturnType> getVariablesAndTypes() const override;
virtual std::shared_ptr<BaseExpression const> simplify() const override;
virtual void accept(ExpressionVisitor* visitor) const override;
virtual boost::any accept(ExpressionVisitor& visitor) const override;
/*!
* Retrieves the name of the variable associated with this expression.

98
src/storage/prism/Program.cpp

@ -360,13 +360,6 @@ namespace storm {
std::set<std::string> containedIdentifiers = constant.getExpression().getVariables();
bool isValid = std::includes(constantNames.begin(), constantNames.end(), containedIdentifiers.begin(), containedIdentifiers.end());
STORM_LOG_THROW(isValid, storm::exceptions::WrongFormatException, "Error in " << constant.getFilename() << ", line " << constant.getLineNumber() << ": defining expression refers to unknown identifiers.");
// Now check that the constants appear with the right types.
try {
constant.getExpression().check(identifierToTypeMap);
} catch (storm::exceptions::InvalidTypeException const& e) {
STORM_LOG_THROW(false, storm::exceptions::WrongFormatException, "Error in " << constant.getFilename() << ", line " << constant.getLineNumber() << ": " << e.what());
}
}
// Finally, register the type of the constant for later type checks.
@ -388,11 +381,6 @@ namespace storm {
std::set<std::string> containedIdentifiers = variable.getInitialValueExpression().getVariables();
bool isValid = std::includes(constantNames.begin(), constantNames.end(), containedIdentifiers.begin(), containedIdentifiers.end());
STORM_LOG_THROW(isValid, storm::exceptions::WrongFormatException, "Error in " << variable.getFilename() << ", line " << variable.getLineNumber() << ": initial value expression refers to unknown constants.");
try {
variable.getInitialValueExpression().check(identifierToTypeMap);
} catch (storm::exceptions::InvalidTypeException const& e) {
STORM_LOG_THROW(false, storm::exceptions::WrongFormatException, "Error in " << variable.getFilename() << ", line " << variable.getLineNumber() << ": " << e.what());
}
// Register the type of the constant for later type checks.
identifierToTypeMap.emplace(variable.getName(), storm::expressions::ExpressionReturnType::Bool);
@ -410,30 +398,15 @@ namespace storm {
std::set<std::string> containedIdentifiers = variable.getLowerBoundExpression().getVariables();
bool isValid = std::includes(constantNames.begin(), constantNames.end(), containedIdentifiers.begin(), containedIdentifiers.end());
STORM_LOG_THROW(isValid, storm::exceptions::WrongFormatException, "Error in " << variable.getFilename() << ", line " << variable.getLineNumber() << ": lower bound expression refers to unknown constants.");
try {
variable.getLowerBoundExpression().check(identifierToTypeMap);
} catch (storm::exceptions::InvalidTypeException const& e) {
STORM_LOG_THROW(false, storm::exceptions::WrongFormatException, "Error in " << variable.getFilename() << ", line " << variable.getLineNumber() << ": " << e.what());
}
containedIdentifiers = variable.getLowerBoundExpression().getVariables();
isValid = std::includes(constantNames.begin(), constantNames.end(), containedIdentifiers.begin(), containedIdentifiers.end());
STORM_LOG_THROW(isValid, storm::exceptions::WrongFormatException, "Error in " << variable.getFilename() << ", line " << variable.getLineNumber() << ": upper bound expression refers to unknown constants.");
try {
variable.getUpperBoundExpression().check(identifierToTypeMap);
} catch (storm::exceptions::InvalidTypeException const& e) {
STORM_LOG_THROW(false, storm::exceptions::WrongFormatException, "Error in " << variable.getFilename() << ", line " << variable.getLineNumber() << ": " << e.what());
}
// Check the initial value of the variable.
containedIdentifiers = variable.getInitialValueExpression().getVariables();
isValid = std::includes(constantNames.begin(), constantNames.end(), containedIdentifiers.begin(), containedIdentifiers.end());
STORM_LOG_THROW(isValid, storm::exceptions::WrongFormatException, "Error in " << variable.getFilename() << ", line " << variable.getLineNumber() << ": initial value expression refers to unknown constants.");
try {
variable.getInitialValueExpression().check(identifierToTypeMap);
} catch (storm::exceptions::InvalidTypeException const& e) {
STORM_LOG_THROW(false, storm::exceptions::WrongFormatException, "Error in " << variable.getFilename() << ", line " << variable.getLineNumber() << ": " << e.what());
}
// Register the type of the constant for later type checks.
identifierToTypeMap.emplace(variable.getName(), storm::expressions::ExpressionReturnType::Int);
@ -454,11 +427,6 @@ namespace storm {
std::set<std::string> containedIdentifiers = variable.getInitialValueExpression().getVariables();
bool isValid = std::includes(constantNames.begin(), constantNames.end(), containedIdentifiers.begin(), containedIdentifiers.end());
STORM_LOG_THROW(isValid, storm::exceptions::WrongFormatException, "Error in " << variable.getFilename() << ", line " << variable.getLineNumber() << ": initial value expression refers to unknown constants.");
try {
variable.getInitialValueExpression().check(identifierToTypeMap);
} catch (storm::exceptions::InvalidTypeException const& e) {
STORM_LOG_THROW(false, storm::exceptions::WrongFormatException, "Error in " << variable.getFilename() << ", line " << variable.getLineNumber() << ": " << e.what());
}
// Register the type of the constant for later type checks.
identifierToTypeMap.emplace(variable.getName(), storm::expressions::ExpressionReturnType::Bool);
@ -478,30 +446,15 @@ namespace storm {
std::set<std::string> containedIdentifiers = variable.getLowerBoundExpression().getVariables();
bool isValid = std::includes(constantNames.begin(), constantNames.end(), containedIdentifiers.begin(), containedIdentifiers.end());
STORM_LOG_THROW(isValid, storm::exceptions::WrongFormatException, "Error in " << variable.getFilename() << ", line " << variable.getLineNumber() << ": lower bound expression refers to unknown constants.");
try {
variable.getLowerBoundExpression().check(identifierToTypeMap);
} catch (storm::exceptions::InvalidTypeException const& e) {
STORM_LOG_THROW(false, storm::exceptions::WrongFormatException, "Error in " << variable.getFilename() << ", line " << variable.getLineNumber() << ": " << e.what());
}
containedIdentifiers = variable.getLowerBoundExpression().getVariables();
isValid = std::includes(constantNames.begin(), constantNames.end(), containedIdentifiers.begin(), containedIdentifiers.end());
STORM_LOG_THROW(isValid, storm::exceptions::WrongFormatException, "Error in " << variable.getFilename() << ", line " << variable.getLineNumber() << ": upper bound expression refers to unknown constants.");
try {
variable.getUpperBoundExpression().check(identifierToTypeMap);
} catch (storm::exceptions::InvalidTypeException const& e) {
STORM_LOG_THROW(false, storm::exceptions::WrongFormatException, "Error in " << variable.getFilename() << ", line " << variable.getLineNumber() << ": " << e.what());
}
// Check the initial value of the variable.
containedIdentifiers = variable.getInitialValueExpression().getVariables();
isValid = std::includes(constantNames.begin(), constantNames.end(), containedIdentifiers.begin(), containedIdentifiers.end());
STORM_LOG_THROW(isValid, storm::exceptions::WrongFormatException, "Error in " << variable.getFilename() << ", line " << variable.getLineNumber() << ": initial value expression refers to unknown constants.");
try {
variable.getInitialValueExpression().check(identifierToTypeMap);
} catch (storm::exceptions::InvalidTypeException const& e) {
STORM_LOG_THROW(false, storm::exceptions::WrongFormatException, "Error in " << variable.getFilename() << ", line " << variable.getLineNumber() << ": " << e.what());
}
// Record the new identifier for future checks.
variableNames.insert(variable.getName());
@ -528,11 +481,6 @@ namespace storm {
std::set<std::string> containedIdentifiers = command.getGuardExpression().getVariables();
bool isValid = std::includes(variablesAndConstants.begin(), variablesAndConstants.end(), containedIdentifiers.begin(), containedIdentifiers.end());
STORM_LOG_THROW(isValid, storm::exceptions::WrongFormatException, "Error in " << command.getFilename() << ", line " << command.getLineNumber() << ": guard refers to unknown identifiers.");
try {
command.getGuardExpression().check(identifierToTypeMap);
} catch (storm::exceptions::InvalidTypeException const& e) {
STORM_LOG_THROW(false, storm::exceptions::WrongFormatException, "Error in " << command.getFilename() << ", line " << command.getLineNumber() << ": " << e.what());
}
STORM_LOG_THROW(command.getGuardExpression().hasBooleanReturnType(), storm::exceptions::WrongFormatException, "Error in " << command.getFilename() << ", line " << command.getLineNumber() << ": expression for guard must evaluate to type 'bool'.");
// Check all updates.
@ -540,11 +488,6 @@ namespace storm {
containedIdentifiers = update.getLikelihoodExpression().getVariables();
isValid = std::includes(variablesAndConstants.begin(), variablesAndConstants.end(), containedIdentifiers.begin(), containedIdentifiers.end());
STORM_LOG_THROW(isValid, storm::exceptions::WrongFormatException, "Error in " << command.getFilename() << ", line " << command.getLineNumber() << ": likelihood expression refers to unknown identifiers.");
try {
update.getLikelihoodExpression().check(identifierToTypeMap);
} catch (storm::exceptions::InvalidTypeException const& e) {
STORM_LOG_THROW(false, storm::exceptions::WrongFormatException, "Error in " << command.getFilename() << ", line " << command.getLineNumber() << ": " << e.what());
}
// Check all assignments.
std::set<std::string> alreadyAssignedIdentifiers;
@ -563,11 +506,6 @@ namespace storm {
containedIdentifiers = assignment.getExpression().getVariables();
isValid = std::includes(variablesAndConstants.begin(), variablesAndConstants.end(), containedIdentifiers.begin(), containedIdentifiers.end());
STORM_LOG_THROW(isValid, storm::exceptions::WrongFormatException, "Error in " << command.getFilename() << ", line " << command.getLineNumber() << ": likelihood expression refers to unknown identifiers.");
try {
assignment.getExpression().check(identifierToTypeMap);
} catch (storm::exceptions::InvalidTypeException const& e) {
STORM_LOG_THROW(false, storm::exceptions::WrongFormatException, "Error in " << command.getFilename() << ", line " << command.getLineNumber() << ": " << e.what());
}
// Add the current variable to the set of assigned variables (of this update).
alreadyAssignedIdentifiers.insert(assignment.getVariableName());
@ -582,21 +520,11 @@ namespace storm {
std::set<std::string> containedIdentifiers = stateReward.getStatePredicateExpression().getVariables();
bool isValid = std::includes(variablesAndConstants.begin(), variablesAndConstants.end(), containedIdentifiers.begin(), containedIdentifiers.end());
STORM_LOG_THROW(isValid, storm::exceptions::WrongFormatException, "Error in " << stateReward.getFilename() << ", line " << stateReward.getLineNumber() << ": state reward expression refers to unknown identifiers.");
try {
stateReward.getStatePredicateExpression().check(identifierToTypeMap);
} catch (storm::exceptions::InvalidTypeException const& e) {
STORM_LOG_THROW(false, storm::exceptions::WrongFormatException, "Error in " << stateReward.getFilename() << ", line " << stateReward.getLineNumber() << ": " << e.what());
}
STORM_LOG_THROW(stateReward.getStatePredicateExpression().hasBooleanReturnType(), storm::exceptions::WrongFormatException, "Error in " << stateReward.getFilename() << ", line " << stateReward.getLineNumber() << ": state predicate must evaluate to type 'bool'.");
containedIdentifiers = stateReward.getRewardValueExpression().getVariables();
isValid = std::includes(variablesAndConstants.begin(), variablesAndConstants.end(), containedIdentifiers.begin(), containedIdentifiers.end());
STORM_LOG_THROW(isValid, storm::exceptions::WrongFormatException, "Error in " << stateReward.getFilename() << ", line " << stateReward.getLineNumber() << ": state reward value expression refers to unknown identifiers.");
try {
stateReward.getRewardValueExpression().check(identifierToTypeMap);
} catch (storm::exceptions::InvalidTypeException const& e) {
STORM_LOG_THROW(false, storm::exceptions::WrongFormatException, "Error in " << stateReward.getFilename() << ", line " << stateReward.getLineNumber() << ": " << e.what());
}
STORM_LOG_THROW(stateReward.getRewardValueExpression().hasNumericalReturnType(), storm::exceptions::WrongFormatException, "Error in " << stateReward.getFilename() << ", line " << stateReward.getLineNumber() << ": reward value expression must evaluate to numerical type.");
}
@ -604,21 +532,11 @@ namespace storm {
std::set<std::string> containedIdentifiers = transitionReward.getStatePredicateExpression().getVariables();
bool isValid = std::includes(variablesAndConstants.begin(), variablesAndConstants.end(), containedIdentifiers.begin(), containedIdentifiers.end());
STORM_LOG_THROW(isValid, storm::exceptions::WrongFormatException, "Error in " << transitionReward.getFilename() << ", line " << transitionReward.getLineNumber() << ": state reward expression refers to unknown identifiers.");
try {
transitionReward.getStatePredicateExpression().check(identifierToTypeMap);
} catch (storm::exceptions::InvalidTypeException const& e) {
STORM_LOG_THROW(false, storm::exceptions::WrongFormatException, "Error in " << transitionReward.getFilename() << ", line " << transitionReward.getLineNumber() << ": " << e.what());
}
STORM_LOG_THROW(transitionReward.getStatePredicateExpression().hasBooleanReturnType(), storm::exceptions::WrongFormatException, "Error in " << transitionReward.getFilename() << ", line " << transitionReward.getLineNumber() << ": state predicate must evaluate to type 'bool'.");
containedIdentifiers = transitionReward.getRewardValueExpression().getVariables();
isValid = std::includes(variablesAndConstants.begin(), variablesAndConstants.end(), containedIdentifiers.begin(), containedIdentifiers.end());
STORM_LOG_THROW(isValid, storm::exceptions::WrongFormatException, "Error in " << transitionReward.getFilename() << ", line " << transitionReward.getLineNumber() << ": state reward value expression refers to unknown identifiers.");
try {
transitionReward.getRewardValueExpression().check(identifierToTypeMap);
} catch (storm::exceptions::InvalidTypeException const& e) {
STORM_LOG_THROW(false, storm::exceptions::WrongFormatException, "Error in " << transitionReward.getFilename() << ", line " << transitionReward.getLineNumber() << ": " << e.what());
}
STORM_LOG_THROW(transitionReward.getRewardValueExpression().hasNumericalReturnType(), storm::exceptions::WrongFormatException, "Error in " << transitionReward.getFilename() << ", line " << transitionReward.getLineNumber() << ": reward value expression must evaluate to numerical type.");
}
}
@ -627,11 +545,6 @@ namespace storm {
std::set<std::string> containedIdentifiers = this->getInitialConstruct().getInitialStatesExpression().getVariables();
bool isValid = std::includes(variablesAndConstants.begin(), variablesAndConstants.end(), containedIdentifiers.begin(), containedIdentifiers.end());
STORM_LOG_THROW(isValid, storm::exceptions::WrongFormatException, "Error in " << this->getInitialConstruct().getFilename() << ", line " << this->getInitialConstruct().getLineNumber() << ": initial expression refers to unknown identifiers.");
try {
this->getInitialConstruct().getInitialStatesExpression().check(identifierToTypeMap);
} catch (storm::exceptions::InvalidTypeException const& e) {
STORM_LOG_THROW(false, storm::exceptions::WrongFormatException, "Error in " << this->getInitialConstruct().getFilename() << ", line " << this->getInitialConstruct().getLineNumber() << ": " << e.what());
}
// Check the labels.
for (auto const& label : this->getLabels()) {
@ -641,12 +554,6 @@ namespace storm {
std::set<std::string> containedIdentifiers = label.getStatePredicateExpression().getVariables();
bool isValid = std::includes(variablesAndConstants.begin(), variablesAndConstants.end(), containedIdentifiers.begin(), containedIdentifiers.end());
STORM_LOG_THROW(isValid, storm::exceptions::WrongFormatException, "Error in " << label.getFilename() << ", line " << label.getLineNumber() << ": label expression refers to unknown identifiers.");
try {
label.getStatePredicateExpression().check(identifierToTypeMap);
} catch (storm::exceptions::InvalidTypeException const& e) {
STORM_LOG_THROW(false, storm::exceptions::WrongFormatException, "Error in " << label.getFilename() << ", line " << label.getLineNumber() << ": " << e.what());
}
STORM_LOG_THROW(label.getStatePredicateExpression().hasBooleanReturnType(), storm::exceptions::WrongFormatException, "Error in " << label.getFilename() << ", line " << label.getLineNumber() << ": label predicate must evaluate to type 'bool'.");
}
@ -658,11 +565,6 @@ namespace storm {
std::set<std::string> containedIdentifiers = formula.getExpression().getVariables();
bool isValid = std::includes(variablesAndConstants.begin(), variablesAndConstants.end(), containedIdentifiers.begin(), containedIdentifiers.end());
STORM_LOG_THROW(isValid, storm::exceptions::WrongFormatException, "Error in " << formula.getFilename() << ", line " << formula.getLineNumber() << ": formula expression refers to unknown identifiers.");
try {
formula.getExpression().check(identifierToTypeMap);
} catch (storm::exceptions::InvalidTypeException const& e) {
STORM_LOG_THROW(false, storm::exceptions::WrongFormatException, "Error in " << formula.getFilename() << ", line " << formula.getLineNumber() << ": " << e.what());
}
// Record the new identifier for future checks.
allIdentifiers.insert(formula.getName());

1
test/functional/solver/GlpkLpSolverTest.cpp

@ -14,6 +14,7 @@ TEST(GlpkLpSolver, LPOptimizeMax) {
ASSERT_NO_THROW(solver.addLowerBoundedContinuousVariable("z", 0, 1));
ASSERT_NO_THROW(solver.update());
solver.addConstraint("", storm::expressions::Expression::createDoubleVariable("x") + storm::expressions::Expression::createDoubleVariable("y") + storm::expressions::Expression::createDoubleVariable("z") <= storm::expressions::Expression::createDoubleLiteral(12));
ASSERT_NO_THROW(solver.addConstraint("", storm::expressions::Expression::createDoubleVariable("x") + storm::expressions::Expression::createDoubleVariable("y") + storm::expressions::Expression::createDoubleVariable("z") <= storm::expressions::Expression::createDoubleLiteral(12)));
ASSERT_NO_THROW(solver.addConstraint("", storm::expressions::Expression::createDoubleLiteral(0.5) * storm::expressions::Expression::createDoubleVariable("y") + storm::expressions::Expression::createDoubleVariable("z") - storm::expressions::Expression::createDoubleVariable("x") == storm::expressions::Expression::createDoubleLiteral(5)));
ASSERT_NO_THROW(solver.addConstraint("", storm::expressions::Expression::createDoubleVariable("y") - storm::expressions::Expression::createDoubleVariable("x") <= storm::expressions::Expression::createDoubleLiteral(5.5)));

Loading…
Cancel
Save