Browse Source

Added some more methods to valuations. Changed visitor invocation slightly. Moves ExpressionReturnType in separate file. Finished linearity checking visitor. Started on visitor that extracts coefficients of linear expressions.

Former-commit-id: 6e3d0ec910
tempestpy_adaptions
dehnert 11 years ago
parent
commit
389fddc996
  1. 10
      src/storage/expressions/BaseExpression.cpp
  2. 8
      src/storage/expressions/BaseExpression.h
  3. 17
      src/storage/expressions/Expression.cpp
  4. 15
      src/storage/expressions/Expression.h
  5. 15
      src/storage/expressions/ExpressionReturnType.cpp
  6. 17
      src/storage/expressions/ExpressionReturnType.h
  7. 4
      src/storage/expressions/IdentifierSubstitutionVisitor.cpp
  8. 2
      src/storage/expressions/IdentifierSubstitutionVisitor.h
  9. 70
      src/storage/expressions/LinearCoefficientVisitor.cpp
  10. 46
      src/storage/expressions/LinearCoefficientVisitor.h
  11. 91
      src/storage/expressions/LinearityCheckVisitor.cpp
  12. 10
      src/storage/expressions/LinearityCheckVisitor.h
  13. 59
      src/storage/expressions/SimpleValuation.cpp
  14. 14
      src/storage/expressions/SimpleValuation.h
  15. 4
      src/storage/expressions/SubstitutionVisitor.cpp
  16. 2
      src/storage/expressions/SubstitutionVisitor.h
  17. 4
      src/storage/expressions/TypeCheckVisitor.cpp
  18. 2
      src/storage/expressions/TypeCheckVisitor.h
  19. 36
      src/storage/expressions/Valuation.h
  20. 17
      test/functional/storage/ExpressionTest.cpp

10
src/storage/expressions/BaseExpression.cpp

@ -81,16 +81,6 @@ namespace storm {
return this->shared_from_this(); return this->shared_from_this();
} }
std::ostream& operator<<(std::ostream& stream, ExpressionReturnType const& enumValue) {
switch (enumValue) {
case ExpressionReturnType::Undefined: stream << "undefined"; break;
case ExpressionReturnType::Bool: stream << "bool"; break;
case ExpressionReturnType::Int: stream << "int"; break;
case ExpressionReturnType::Double: stream << "double"; break;
}
return stream;
}
std::ostream& operator<<(std::ostream& stream, BaseExpression const& expression) { std::ostream& operator<<(std::ostream& stream, BaseExpression const& expression) {
expression.printToStream(stream); expression.printToStream(stream);
return stream; return stream;

8
src/storage/expressions/BaseExpression.h

@ -6,6 +6,7 @@
#include <set> #include <set>
#include <iostream> #include <iostream>
#include "src/storage/expressions/ExpressionReturnType.h"
#include "src/storage/expressions/Valuation.h" #include "src/storage/expressions/Valuation.h"
#include "src/storage/expressions/ExpressionVisitor.h" #include "src/storage/expressions/ExpressionVisitor.h"
#include "src/storage/expressions/OperatorType.h" #include "src/storage/expressions/OperatorType.h"
@ -14,13 +15,6 @@
namespace storm { namespace storm {
namespace expressions { namespace expressions {
/*!
* Each node in an expression tree has a uniquely defined type from this enum.
*/
enum class ExpressionReturnType {Undefined, Bool, Int, Double};
std::ostream& operator<<(std::ostream& stream, ExpressionReturnType const& enumValue);
/*! /*!
* The base class of all expression classes. * The base class of all expression classes.
*/ */

17
src/storage/expressions/Expression.cpp

@ -5,6 +5,7 @@
#include "src/storage/expressions/SubstitutionVisitor.h" #include "src/storage/expressions/SubstitutionVisitor.h"
#include "src/storage/expressions/IdentifierSubstitutionVisitor.h" #include "src/storage/expressions/IdentifierSubstitutionVisitor.h"
#include "src/storage/expressions/TypeCheckVisitor.h" #include "src/storage/expressions/TypeCheckVisitor.h"
#include "src/storage/expressions/LinearityCheckVisitor.h"
#include "src/storage/expressions/Expressions.h" #include "src/storage/expressions/Expressions.h"
#include "src/exceptions/InvalidTypeException.h" #include "src/exceptions/InvalidTypeException.h"
#include "src/exceptions/ExceptionMacros.h" #include "src/exceptions/ExceptionMacros.h"
@ -16,27 +17,27 @@ namespace storm {
} }
Expression Expression::substitute(std::map<std::string, Expression> const& identifierToExpressionMap) const { Expression Expression::substitute(std::map<std::string, Expression> const& identifierToExpressionMap) const {
return SubstitutionVisitor<std::map<std::string, Expression>>(identifierToExpressionMap).substitute(this->getBaseExpressionPointer().get());
return SubstitutionVisitor<std::map<std::string, Expression>>(identifierToExpressionMap).substitute(this);
} }
Expression Expression::substitute(std::unordered_map<std::string, Expression> const& identifierToExpressionMap) const { Expression Expression::substitute(std::unordered_map<std::string, Expression> const& identifierToExpressionMap) const {
return SubstitutionVisitor<std::unordered_map<std::string, Expression>>(identifierToExpressionMap).substitute(this->getBaseExpressionPointer().get());
return SubstitutionVisitor<std::unordered_map<std::string, Expression>>(identifierToExpressionMap).substitute(this);
} }
Expression Expression::substitute(std::map<std::string, std::string> const& identifierToIdentifierMap) const { Expression Expression::substitute(std::map<std::string, std::string> const& identifierToIdentifierMap) const {
return IdentifierSubstitutionVisitor<std::map<std::string, std::string>>(identifierToIdentifierMap).substitute(this->getBaseExpressionPointer().get());
return IdentifierSubstitutionVisitor<std::map<std::string, std::string>>(identifierToIdentifierMap).substitute(this);
} }
Expression Expression::substitute(std::unordered_map<std::string, std::string> const& identifierToIdentifierMap) const { Expression Expression::substitute(std::unordered_map<std::string, std::string> const& identifierToIdentifierMap) const {
return IdentifierSubstitutionVisitor<std::unordered_map<std::string, std::string>>(identifierToIdentifierMap).substitute(this->getBaseExpressionPointer().get());
return IdentifierSubstitutionVisitor<std::unordered_map<std::string, std::string>>(identifierToIdentifierMap).substitute(this);
} }
void Expression::check(std::map<std::string, storm::expressions::ExpressionReturnType> const& identifierToTypeMap) const { void Expression::check(std::map<std::string, storm::expressions::ExpressionReturnType> const& identifierToTypeMap) const {
return TypeCheckVisitor<std::map<std::string, storm::expressions::ExpressionReturnType>>(identifierToTypeMap).check(this->getBaseExpressionPointer().get());
return TypeCheckVisitor<std::map<std::string, storm::expressions::ExpressionReturnType>>(identifierToTypeMap).check(this);
} }
void Expression::check(std::unordered_map<std::string, storm::expressions::ExpressionReturnType> const& identifierToTypeMap) const { void Expression::check(std::unordered_map<std::string, storm::expressions::ExpressionReturnType> const& identifierToTypeMap) const {
return TypeCheckVisitor<std::unordered_map<std::string, storm::expressions::ExpressionReturnType>>(identifierToTypeMap).check(this->getBaseExpressionPointer().get());
return TypeCheckVisitor<std::unordered_map<std::string, storm::expressions::ExpressionReturnType>>(identifierToTypeMap).check(this);
} }
bool Expression::evaluateAsBool(Valuation const* valuation) const { bool Expression::evaluateAsBool(Valuation const* valuation) const {
@ -105,6 +106,10 @@ namespace storm {
|| this->getOperator() == OperatorType::Greater || this->getOperator() == OperatorType::GreaterOrEqual; || this->getOperator() == OperatorType::Greater || this->getOperator() == OperatorType::GreaterOrEqual;
} }
bool Expression::isLinear() const {
return LinearityCheckVisitor().check(*this);
}
std::set<std::string> Expression::getVariables() const { std::set<std::string> Expression::getVariables() const {
return this->getBaseExpression().getVariables(); return this->getBaseExpression().getVariables();
} }

15
src/storage/expressions/Expression.h

@ -6,6 +6,7 @@
#include <unordered_map> #include <unordered_map>
#include "src/storage/expressions/BaseExpression.h" #include "src/storage/expressions/BaseExpression.h"
#include "src/storage/expressions/ExpressionVisitor.h"
#include "src/utility/OsDetection.h" #include "src/utility/OsDetection.h"
namespace storm { namespace storm {
@ -238,6 +239,13 @@ namespace storm {
*/ */
bool isRelationalExpression() const; bool isRelationalExpression() const;
/*!
* Retrieves whether this expression is a linear expression.
*
* @return True iff the expression is linear.
*/
bool isLinear() const;
/*! /*!
* Retrieves the set of all variables that appear in the expression. * Retrieves the set of all variables that appear in the expression.
* *
@ -281,6 +289,13 @@ namespace storm {
*/ */
bool hasBooleanReturnType() const; bool hasBooleanReturnType() const;
/*!
* Accepts the given visitor.
*
* @param visitor The visitor to accept.
*/
void accept(ExpressionVisitor* visitor) const;
friend std::ostream& operator<<(std::ostream& stream, Expression const& expression); friend std::ostream& operator<<(std::ostream& stream, Expression const& expression);
private: private:

15
src/storage/expressions/ExpressionReturnType.cpp

@ -0,0 +1,15 @@
#include "src/storage/expressions/ExpressionReturnType.h"
namespace storm {
namespace expressions {
std::ostream& operator<<(std::ostream& stream, ExpressionReturnType const& enumValue) {
switch (enumValue) {
case ExpressionReturnType::Undefined: stream << "undefined"; break;
case ExpressionReturnType::Bool: stream << "bool"; break;
case ExpressionReturnType::Int: stream << "int"; break;
case ExpressionReturnType::Double: stream << "double"; break;
}
return stream;
}
}
}

17
src/storage/expressions/ExpressionReturnType.h

@ -0,0 +1,17 @@
#ifndef STORM_STORAGE_EXPRESSIONS_EXPRESSIONRETURNTYPE_H_
#define STORM_STORAGE_EXPRESSIONS_EXPRESSIONRETURNTYPE_H_
#include <iostream>
namespace storm {
namespace expressions {
/*!
* Each node in an expression tree has a uniquely defined type from this enum.
*/
enum class ExpressionReturnType {Undefined, Bool, Int, Double};
std::ostream& operator<<(std::ostream& stream, ExpressionReturnType const& enumValue);
}
}
#endif /* STORM_STORAGE_EXPRESSIONS_EXPRESSIONRETURNTYPE_H_ */

4
src/storage/expressions/IdentifierSubstitutionVisitor.cpp

@ -13,8 +13,8 @@ namespace storm {
} }
template<typename MapType> template<typename MapType>
Expression IdentifierSubstitutionVisitor<MapType>::substitute(BaseExpression const* expression) {
expression->accept(this);
Expression IdentifierSubstitutionVisitor<MapType>::substitute(Expression const& expression) {
expression.getBaseExpression().accept(this);
return Expression(this->expressionStack.top()); return Expression(this->expressionStack.top());
} }

2
src/storage/expressions/IdentifierSubstitutionVisitor.h

@ -26,7 +26,7 @@ namespace storm {
* @return The expression in which all identifiers in the key set of the previously given mapping are * @return The expression in which all identifiers in the key set of the previously given mapping are
* substituted with the mapped-to expressions. * substituted with the mapped-to expressions.
*/ */
Expression substitute(BaseExpression const* expression);
Expression substitute(Expression const& expression);
virtual void visit(IfThenElseExpression const* expression) override; virtual void visit(IfThenElseExpression const* expression) override;
virtual void visit(BinaryBooleanFunctionExpression const* expression) override; virtual void visit(BinaryBooleanFunctionExpression const* expression) override;

70
src/storage/expressions/LinearCoefficientVisitor.cpp

@ -0,0 +1,70 @@
#include "src/storage/expressions/LinearCoefficientVisitor.h"
#include "src/storage/expressions/Expressions.h"
#include "src/exceptions/ExceptionMacros.h"
#include "src/exceptions/InvalidArgumentException.h"
namespace storm {
namespace expressions {
std::pair<SimpleValuation, double> LinearCoefficientVisitor::getLinearCoefficients(Expression const& expression) {
expression.getBaseExpression().accept(this);
return resultStack.top();
}
void LinearCoefficientVisitor::visit(IfThenElseExpression const* expression) {
LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Expression is non-linear.");
}
void LinearCoefficientVisitor::visit(BinaryBooleanFunctionExpression const* expression) {
}
void LinearCoefficientVisitor::visit(BinaryNumericalFunctionExpression const* expression) {
}
void LinearCoefficientVisitor::visit(BinaryRelationExpression const* expression) {
}
void LinearCoefficientVisitor::visit(VariableExpression const* expression) {
SimpleValuation valuation;
switch (expression->getReturnType()) {
case ExpressionReturnType::Bool: LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Expression is non-linear."); break;
case ExpressionReturnType::Int:
case ExpressionReturnType::Double: valuation.addDoubleIdentifier(expression->getVariableName(), 1); break;
case ExpressionReturnType::Undefined: LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Illegal expression return type."); break;
}
resultStack.push(std::make_pair(valuation, 0));
}
void LinearCoefficientVisitor::visit(UnaryBooleanFunctionExpression const* expression) {
LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Expression is non-linear.");
}
void LinearCoefficientVisitor::visit(UnaryNumericalFunctionExpression const* expression) {
if (expression->getOperatorType() == UnaryNumericalFunctionExpression::OperatorType::Minus) {
// Here, we need to negate all double identifiers.
std::pair<SimpleValuation, double>& valuationConstantPair = resultStack.top();
for (auto const& identifier : valuationConstantPair.first.getDoubleIdentifiers()) {
valuationConstantPair.first.setDoubleValue(identifier, -valuationConstantPair.first.getDoubleValue(identifier));
}
} else {
LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Expression is non-linear.");
}
}
void LinearCoefficientVisitor::visit(BooleanLiteralExpression const* expression) {
LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Expression is non-linear.");
}
void LinearCoefficientVisitor::visit(IntegerLiteralExpression const* expression) {
resultStack.push(std::make_pair(SimpleValuation(), static_cast<double>(expression->getValue())));
}
void LinearCoefficientVisitor::visit(DoubleLiteralExpression const* expression) {
resultStack.push(std::make_pair(SimpleValuation(), expression->getValue()));
}
}
}

46
src/storage/expressions/LinearCoefficientVisitor.h

@ -0,0 +1,46 @@
#ifndef STORM_STORAGE_EXPRESSIONS_LINEARCOEFFICIENTVISITOR_H_
#define STORM_STORAGE_EXPRESSIONS_LINEARCOEFFICIENTVISITOR_H_
#include <stack>
#include "src/storage/expressions/Expression.h"
#include "src/storage/expressions/ExpressionVisitor.h"
#include "src/storage/expressions/SimpleValuation.h"
namespace storm {
namespace expressions {
class LinearCoefficientVisitor : public ExpressionVisitor {
public:
/*!
* Creates a linear coefficient visitor.
*/
LinearCoefficientVisitor() = default;
/*!
* Computes the (double) coefficients of all identifiers appearing in the expression if the expression
* was rewritten as a sum of atoms.. If the expression is not linear, an exception is thrown.
*
* @param expression The expression for which to compute the coefficients.
* @return A pair consisting of a mapping from identifiers to their coefficients and the coefficient of
* the constant atom.
*/
std::pair<SimpleValuation, double> getLinearCoefficients(Expression const& expression);
virtual void visit(IfThenElseExpression const* expression) override;
virtual void visit(BinaryBooleanFunctionExpression const* expression) override;
virtual void visit(BinaryNumericalFunctionExpression const* expression) override;
virtual void visit(BinaryRelationExpression const* expression) override;
virtual void visit(VariableExpression const* expression) override;
virtual void visit(UnaryBooleanFunctionExpression const* expression) override;
virtual void visit(UnaryNumericalFunctionExpression const* expression) override;
virtual void visit(BooleanLiteralExpression const* expression) override;
virtual void visit(IntegerLiteralExpression const* expression) override;
virtual void visit(DoubleLiteralExpression const* expression) override;
private:
std::stack<std::pair<SimpleValuation, double>> resultStack;
};
}
}
#endif /* STORM_STORAGE_EXPRESSIONS_LINEARCOEFFICIENTVISITOR_H_ */

91
src/storage/expressions/LinearityCheckVisitor.cpp

@ -6,73 +6,124 @@
namespace storm { namespace storm {
namespace expressions { namespace expressions {
bool LinearityCheckVisitor::check(BaseExpression const* expression) {
expression->accept(this);
return resultStack.top();
LinearityCheckVisitor::LinearityCheckVisitor() : resultStack() {
// Intentionally left empty.
}
bool LinearityCheckVisitor::check(Expression const& expression) {
expression.getBaseExpression().accept(this);
return resultStack.top() == LinearityStatus::LinearWithoutVariables || resultStack.top() == LinearityStatus::LinearContainsVariables;
} }
void LinearityCheckVisitor::visit(IfThenElseExpression const* expression) { void LinearityCheckVisitor::visit(IfThenElseExpression const* expression) {
// An if-then-else expression is never linear. // An if-then-else expression is never linear.
resultStack.push(false);
resultStack.push(LinearityStatus::NonLinear);
} }
void LinearityCheckVisitor::visit(BinaryBooleanFunctionExpression const* expression) { void LinearityCheckVisitor::visit(BinaryBooleanFunctionExpression const* expression) {
// Boolean function applications are not allowed in linear expressions. // Boolean function applications are not allowed in linear expressions.
resultStack.push(false);
resultStack.push(LinearityStatus::NonLinear);
} }
void LinearityCheckVisitor::visit(BinaryNumericalFunctionExpression const* expression) { void LinearityCheckVisitor::visit(BinaryNumericalFunctionExpression const* expression) {
bool leftResult = true;
bool rightResult = true;
LinearityStatus leftResult;
LinearityStatus rightResult;
switch (expression->getOperatorType()) { switch (expression->getOperatorType()) {
case BinaryNumericalFunctionExpression::OperatorType::Plus: case BinaryNumericalFunctionExpression::OperatorType::Plus:
case BinaryNumericalFunctionExpression::OperatorType::Minus: case BinaryNumericalFunctionExpression::OperatorType::Minus:
expression->getFirstOperand()->accept(this); expression->getFirstOperand()->accept(this);
leftResult = resultStack.top(); leftResult = resultStack.top();
if (!leftResult) {
if (leftResult == LinearityStatus::NonLinear) {
return;
} else { } else {
resultStack.pop(); resultStack.pop();
expression->getSecondOperand()->accept(this); expression->getSecondOperand()->accept(this);
rightResult = resultStack.top();
if (rightResult == LinearityStatus::NonLinear) {
return;
}
resultStack.pop();
} }
resultStack.push(leftResult == LinearityStatus::LinearContainsVariables || rightResult == LinearityStatus::LinearContainsVariables ? LinearityStatus::LinearContainsVariables : LinearityStatus::LinearWithoutVariables);
break;
case BinaryNumericalFunctionExpression::OperatorType::Times: case BinaryNumericalFunctionExpression::OperatorType::Times:
case BinaryNumericalFunctionExpression::OperatorType::Divide: case BinaryNumericalFunctionExpression::OperatorType::Divide:
case BinaryNumericalFunctionExpression::OperatorType::Min: resultStack.push(false); break;
case BinaryNumericalFunctionExpression::OperatorType::Max: resultStack.push(false); break;
expression->getFirstOperand()->accept(this);
leftResult = resultStack.top();
if (leftResult == LinearityStatus::NonLinear) {
return;
} else {
resultStack.pop();
expression->getSecondOperand()->accept(this);
rightResult = resultStack.top();
if (rightResult == LinearityStatus::NonLinear) {
return;
}
resultStack.pop();
}
if (leftResult == LinearityStatus::LinearContainsVariables && rightResult == LinearityStatus::LinearContainsVariables) {
resultStack.push(LinearityStatus::NonLinear);
}
resultStack.push(leftResult == LinearityStatus::LinearContainsVariables || rightResult == LinearityStatus::LinearContainsVariables ? LinearityStatus::LinearContainsVariables : LinearityStatus::LinearWithoutVariables);
break;
case BinaryNumericalFunctionExpression::OperatorType::Min: resultStack.push(LinearityStatus::NonLinear); break;
case BinaryNumericalFunctionExpression::OperatorType::Max: resultStack.push(LinearityStatus::NonLinear); break;
} }
} }
void LinearityCheckVisitor::visit(BinaryRelationExpression const* expression) { void LinearityCheckVisitor::visit(BinaryRelationExpression const* expression) {
resultStack.push(false);
LinearityStatus leftResult;
LinearityStatus rightResult;
expression->getFirstOperand()->accept(this);
leftResult = resultStack.top();
if (leftResult == LinearityStatus::NonLinear) {
return;
} else {
resultStack.pop();
expression->getSecondOperand()->accept(this);
rightResult = resultStack.top();
if (rightResult == LinearityStatus::NonLinear) {
return;
}
resultStack.pop();
}
resultStack.push(leftResult == LinearityStatus::LinearContainsVariables || rightResult == LinearityStatus::LinearContainsVariables ? LinearityStatus::LinearContainsVariables : LinearityStatus::LinearWithoutVariables);
} }
void LinearityCheckVisitor::visit(VariableExpression const* expression) { void LinearityCheckVisitor::visit(VariableExpression const* expression) {
resultStack.push(true);
resultStack.push(LinearityStatus::LinearContainsVariables);
} }
void LinearityCheckVisitor::visit(UnaryBooleanFunctionExpression const* expression) { void LinearityCheckVisitor::visit(UnaryBooleanFunctionExpression const* expression) {
// Boolean function applications are not allowed in linear expressions. // Boolean function applications are not allowed in linear expressions.
resultStack.push(false);
resultStack.push(LinearityStatus::NonLinear);
} }
void LinearityCheckVisitor::visit(UnaryNumericalFunctionExpression const* expression) { void LinearityCheckVisitor::visit(UnaryNumericalFunctionExpression const* expression) {
// Intentionally left empty (just pass subresult one level further).
switch (expression->getOperatorType()) {
case UnaryNumericalFunctionExpression::OperatorType::Minus: break;
case UnaryNumericalFunctionExpression::OperatorType::Floor:
case UnaryNumericalFunctionExpression::OperatorType::Ceil: resultStack.pop(); resultStack.push(LinearityStatus::NonLinear); break;
}
} }
void LinearityCheckVisitor::visit(BooleanLiteralExpression const* expression) { void LinearityCheckVisitor::visit(BooleanLiteralExpression const* expression) {
// Boolean function applications are not allowed in linear expressions.
resultStack.push(false);
resultStack.push(LinearityStatus::NonLinear);
} }
void LinearityCheckVisitor::visit(IntegerLiteralExpression const* expression) { void LinearityCheckVisitor::visit(IntegerLiteralExpression const* expression) {
resultStack.push(true);
resultStack.push(LinearityStatus::LinearWithoutVariables);
} }
void LinearityCheckVisitor::visit(DoubleLiteralExpression const* expression) { void LinearityCheckVisitor::visit(DoubleLiteralExpression const* expression) {
resultStack.push(true);
resultStack.push(LinearityStatus::LinearWithoutVariables);
} }
} }
} }

10
src/storage/expressions/LinearityCheckVisitor.h

@ -13,12 +13,14 @@ namespace storm {
/*! /*!
* Creates a linearity check visitor. * Creates a linearity check visitor.
*/ */
LinearityCheckVisitor() = default;
LinearityCheckVisitor();
/*! /*!
* Checks that the given expression is linear. * Checks that the given expression is linear.
*
* @param expression The expression to check for linearity.
*/ */
bool check(BaseExpression const* expression);
bool check(Expression const& expression);
virtual void visit(IfThenElseExpression const* expression) override; virtual void visit(IfThenElseExpression const* expression) override;
virtual void visit(BinaryBooleanFunctionExpression const* expression) override; virtual void visit(BinaryBooleanFunctionExpression const* expression) override;
@ -32,8 +34,10 @@ namespace storm {
virtual void visit(DoubleLiteralExpression const* expression) override; virtual void visit(DoubleLiteralExpression const* expression) override;
private: private:
enum class LinearityStatus { NonLinear, LinearContainsVariables, LinearWithoutVariables };
// A stack for communicating the results of the subexpressions. // A stack for communicating the results of the subexpressions.
std::stack<bool> resultStack;
std::stack<LinearityStatus> resultStack;
}; };
} }
} }

59
src/storage/expressions/SimpleValuation.cpp

@ -1,5 +1,8 @@
#include <boost/functional/hash.hpp>
#include "src/storage/expressions/SimpleValuation.h" #include "src/storage/expressions/SimpleValuation.h"
#include <set>
#include <boost/functional/hash.hpp>
#include "src/exceptions/ExceptionMacros.h" #include "src/exceptions/ExceptionMacros.h"
#include "src/exceptions/InvalidArgumentException.h" #include "src/exceptions/InvalidArgumentException.h"
#include "src/exceptions/InvalidAccessException.h" #include "src/exceptions/InvalidAccessException.h"
@ -43,6 +46,18 @@ namespace storm {
this->identifierToValueMap.erase(nameValuePair); this->identifierToValueMap.erase(nameValuePair);
} }
ExpressionReturnType SimpleValuation::getIdentifierType(std::string const& name) const {
auto nameValuePair = this->identifierToValueMap.find(name);
LOG_THROW(nameValuePair != this->identifierToValueMap.end(), storm::exceptions::InvalidAccessException, "Access to unkown identifier '" << name << "'.");
if (nameValuePair->second.type() == typeid(bool)) {
return ExpressionReturnType::Bool;
} else if (nameValuePair->second.type() == typeid(int_fast64_t)) {
return ExpressionReturnType::Int;
} else {
return ExpressionReturnType::Double;
}
}
bool SimpleValuation::containsBooleanIdentifier(std::string const& name) const { bool SimpleValuation::containsBooleanIdentifier(std::string const& name) const {
auto nameValuePair = this->identifierToValueMap.find(name); auto nameValuePair = this->identifierToValueMap.find(name);
if (nameValuePair == this->identifierToValueMap.end()) { if (nameValuePair == this->identifierToValueMap.end()) {
@ -85,6 +100,48 @@ namespace storm {
return boost::get<double>(nameValuePair->second); return boost::get<double>(nameValuePair->second);
} }
std::size_t SimpleValuation::getNumberOfIdentifiers() const {
return this->identifierToValueMap.size();
}
std::set<std::string> SimpleValuation::getIdentifiers() const {
std::set<std::string> result;
for (auto const& nameValuePair : this->identifierToValueMap) {
result.insert(nameValuePair.first);
}
return result;
}
std::set<std::string> SimpleValuation::getBooleanIdentifiers() const {
std::set<std::string> result;
for (auto const& nameValuePair : this->identifierToValueMap) {
if (nameValuePair.second.type() == typeid(bool)) {
result.insert(nameValuePair.first);
}
}
return result;
}
std::set<std::string> SimpleValuation::getIntegerIdentifiers() const {
std::set<std::string> result;
for (auto const& nameValuePair : this->identifierToValueMap) {
if (nameValuePair.second.type() == typeid(int_fast64_t)) {
result.insert(nameValuePair.first);
}
}
return result;
}
std::set<std::string> SimpleValuation::getDoubleIdentifiers() const {
std::set<std::string> result;
for (auto const& nameValuePair : this->identifierToValueMap) {
if (nameValuePair.second.type() == typeid(double)) {
result.insert(nameValuePair.first);
}
}
return result;
}
std::ostream& operator<<(std::ostream& stream, SimpleValuation const& valuation) { std::ostream& operator<<(std::ostream& stream, SimpleValuation const& valuation) {
stream << "{ "; stream << "{ ";
uint_fast64_t elementIndex = 0; uint_fast64_t elementIndex = 0;

14
src/storage/expressions/SimpleValuation.h

@ -6,6 +6,7 @@
#include <iostream> #include <iostream>
#include "src/storage/expressions/Valuation.h" #include "src/storage/expressions/Valuation.h"
#include "src/storage/expressions/ExpressionReturnType.h"
#include "src/utility/OsDetection.h" #include "src/utility/OsDetection.h"
namespace storm { namespace storm {
@ -85,10 +86,23 @@ namespace storm {
*/ */
void removeIdentifier(std::string const& name); void removeIdentifier(std::string const& name);
/*!
* Retrieves the type of the identifier with the given name.
*
* @param name The name of the identifier whose type to retrieve.
* @return The type of the identifier with the given name.
*/
ExpressionReturnType getIdentifierType(std::string const& name) const;
// Override base class methods. // Override base class methods.
virtual bool containsBooleanIdentifier(std::string const& name) const override; virtual bool containsBooleanIdentifier(std::string const& name) const override;
virtual bool containsIntegerIdentifier(std::string const& name) const override; virtual bool containsIntegerIdentifier(std::string const& name) const override;
virtual bool containsDoubleIdentifier(std::string const& name) const override; virtual bool containsDoubleIdentifier(std::string const& name) const override;
virtual std::size_t getNumberOfIdentifiers() const override;
virtual std::set<std::string> getIdentifiers() const override;
virtual std::set<std::string> getBooleanIdentifiers() const override;
virtual std::set<std::string> getIntegerIdentifiers() const override;
virtual std::set<std::string> getDoubleIdentifiers() const override;
virtual bool getBooleanValue(std::string const& name) const override; virtual bool getBooleanValue(std::string const& name) const override;
virtual int_fast64_t getIntegerValue(std::string const& name) const override; virtual int_fast64_t getIntegerValue(std::string const& name) const override;
virtual double getDoubleValue(std::string const& name) const override; virtual double getDoubleValue(std::string const& name) const override;

4
src/storage/expressions/SubstitutionVisitor.cpp

@ -13,8 +13,8 @@ namespace storm {
} }
template<typename MapType> template<typename MapType>
Expression SubstitutionVisitor<MapType>::substitute(BaseExpression const* expression) {
expression->accept(this);
Expression SubstitutionVisitor<MapType>::substitute(Expression const& expression) {
expression.getBaseExpression().accept(this);
return Expression(this->expressionStack.top()); return Expression(this->expressionStack.top());
} }

2
src/storage/expressions/SubstitutionVisitor.h

@ -26,7 +26,7 @@ namespace storm {
* @return The expression in which all identifiers in the key set of the previously given mapping are * @return The expression in which all identifiers in the key set of the previously given mapping are
* substituted with the mapped-to expressions. * substituted with the mapped-to expressions.
*/ */
Expression substitute(BaseExpression const* expression);
Expression substitute(Expression const& expression);
virtual void visit(IfThenElseExpression const* expression) override; virtual void visit(IfThenElseExpression const* expression) override;
virtual void visit(BinaryBooleanFunctionExpression const* expression) override; virtual void visit(BinaryBooleanFunctionExpression const* expression) override;

4
src/storage/expressions/TypeCheckVisitor.cpp

@ -12,8 +12,8 @@ namespace storm {
} }
template<typename MapType> template<typename MapType>
void TypeCheckVisitor<MapType>::check(BaseExpression const* expression) {
expression->accept(this);
void TypeCheckVisitor<MapType>::check(Expression const& expression) {
expression.getBaseExpression().accept(this);
} }
template<typename MapType> template<typename MapType>

2
src/storage/expressions/TypeCheckVisitor.h

@ -24,7 +24,7 @@ namespace storm {
* *
* @param expression The expression in which to check the types. * @param expression The expression in which to check the types.
*/ */
void check(BaseExpression const* expression);
void check(Expression const& expression);
virtual void visit(IfThenElseExpression const* expression) override; virtual void visit(IfThenElseExpression const* expression) override;
virtual void visit(BinaryBooleanFunctionExpression const* expression) override; virtual void visit(BinaryBooleanFunctionExpression const* expression) override;

36
src/storage/expressions/Valuation.h

@ -59,6 +59,42 @@ namespace storm {
*/ */
virtual bool containsDoubleIdentifier(std::string const& name) const = 0; virtual bool containsDoubleIdentifier(std::string const& name) const = 0;
/*!
* Retrieves the number of identifiers in this valuation.
*
* @return The number of identifiers in this valuation.
*/
virtual std::size_t getNumberOfIdentifiers() const = 0;
/*!
* Retrieves the set of all identifiers contained in this valuation.
*
* @return The set of all identifiers contained in this valuation.
*/
virtual std::set<std::string> getIdentifiers() const = 0;
/*!
* Retrieves the set of boolean identifiers contained in this valuation.
*
* @return The set of boolean identifiers contained in this valuation.
*/
virtual std::set<std::string> getBooleanIdentifiers() const = 0;
/*!
* Retrieves the set of integer identifiers contained in this valuation.
*
* @return The set of integer identifiers contained in this valuation.
*/
virtual std::set<std::string> getIntegerIdentifiers() const = 0;
/*!
* Retrieves the set of double identifiers contained in this valuation.
*
* @return The set of double identifiers contained in this valuation.
*/
virtual std::set<std::string> getDoubleIdentifiers() const = 0;
}; };
} }
} }

17
test/functional/storage/ExpressionTest.cpp

@ -3,6 +3,7 @@
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "src/storage/expressions/Expression.h" #include "src/storage/expressions/Expression.h"
#include "src/storage/expressions/LinearityCheckVisitor.h"
#include "src/storage/expressions/SimpleValuation.h" #include "src/storage/expressions/SimpleValuation.h"
#include "src/exceptions/InvalidTypeException.h" #include "src/exceptions/InvalidTypeException.h"
@ -333,3 +334,19 @@ TEST(Expression, SimpleEvaluationTest) {
ASSERT_THROW(tempExpression.evaluateAsInt(&valuation), storm::exceptions::InvalidTypeException); ASSERT_THROW(tempExpression.evaluateAsInt(&valuation), storm::exceptions::InvalidTypeException);
EXPECT_FALSE(tempExpression.evaluateAsBool(&valuation)); EXPECT_FALSE(tempExpression.evaluateAsBool(&valuation));
} }
TEST(Expression, VisitorTest) {
storm::expressions::Expression threeExpression;
storm::expressions::Expression piExpression;
storm::expressions::Expression intVarExpression;
storm::expressions::Expression doubleVarExpression;
ASSERT_NO_THROW(threeExpression = storm::expressions::Expression::createIntegerLiteral(3));
ASSERT_NO_THROW(piExpression = storm::expressions::Expression::createDoubleLiteral(3.14));
ASSERT_NO_THROW(intVarExpression = storm::expressions::Expression::createIntegerVariable("y"));
ASSERT_NO_THROW(doubleVarExpression = storm::expressions::Expression::createDoubleVariable("z"));
storm::expressions::Expression tempExpression = intVarExpression + doubleVarExpression * threeExpression;
storm::expressions::LinearityCheckVisitor visitor;
EXPECT_TRUE(visitor.check(tempExpression));
}
Loading…
Cancel
Save