You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 

363 lines
22 KiB

#include <map>
#include <unordered_map>
#include "src/storage/expressions/Expression.h"
#include "src/storage/expressions/ExpressionManager.h"
#include "src/storage/expressions/SubstitutionVisitor.h"
#include "src/storage/expressions/LinearityCheckVisitor.h"
#include "src/storage/expressions/Expressions.h"
#include "src/exceptions/InvalidTypeException.h"
#include "src/exceptions/InvalidArgumentException.h"
#include "src/utility/macros.h"
namespace storm {
namespace expressions {
/*!
* Checks whether the two expressions share the same expression manager.
*
* @param a The first expression.
* @param b The second expression.
* @return True iff the expressions share the same manager.
*/
static void assertSameManager(BaseExpression const& a, BaseExpression const& b) {
STORM_LOG_THROW(a.getManager() == b.getManager(), storm::exceptions::InvalidArgumentException, "Expressions are managed by different manager.");
}
Expression::Expression(std::shared_ptr<BaseExpression const> const& expressionPtr) : expressionPtr(expressionPtr) {
// Intentionally left empty.
}
Expression::Expression(Variable const& variable) : expressionPtr(std::shared_ptr<BaseExpression>(new VariableExpression(variable))) {
// Intentionally left empty.
}
Expression Expression::substitute(std::map<Variable, Expression> const& identifierToExpressionMap) const {
return SubstitutionVisitor<std::map<Variable, Expression>>(identifierToExpressionMap).substitute(*this);
}
Expression Expression::substitute(std::unordered_map<Variable, Expression> const& identifierToExpressionMap) const {
return SubstitutionVisitor<std::unordered_map<Variable, Expression>>(identifierToExpressionMap).substitute(*this);
}
bool Expression::evaluateAsBool(Valuation const* valuation) const {
return this->getBaseExpression().evaluateAsBool(valuation);
}
int_fast64_t Expression::evaluateAsInt(Valuation const* valuation) const {
return this->getBaseExpression().evaluateAsInt(valuation);
}
double Expression::evaluateAsDouble(Valuation const* valuation) const {
return this->getBaseExpression().evaluateAsDouble(valuation);
}
Expression Expression::simplify() const {
return Expression(this->getBaseExpression().simplify());
}
OperatorType Expression::getOperator() const {
return this->getBaseExpression().getOperator();
}
bool Expression::isFunctionApplication() const {
return this->getBaseExpression().isFunctionApplication();
}
uint_fast64_t Expression::getArity() const {
return this->getBaseExpression().getArity();
}
Expression Expression::getOperand(uint_fast64_t operandIndex) const {
return Expression(this->getBaseExpression().getOperand(operandIndex));
}
std::string const& Expression::getIdentifier() const {
return this->getBaseExpression().getIdentifier();
}
bool Expression::containsVariables() const {
return this->getBaseExpression().containsVariables();
}
bool Expression::isLiteral() const {
return this->getBaseExpression().isLiteral();
}
bool Expression::isVariable() const {
return this->getBaseExpression().isVariable();
}
bool Expression::isTrue() const {
return this->getBaseExpression().isTrue();
}
bool Expression::isFalse() const {
return this->getBaseExpression().isFalse();
}
std::set<storm::expressions::Variable> Expression::getVariables() const {
std::set<storm::expressions::Variable> result;
this->getBaseExpression().gatherVariables(result);
return result;
}
bool Expression::containsVariable(std::set<storm::expressions::Variable> const& variables) const {
std::set<storm::expressions::Variable> appearingVariables = this->getVariables();
std::set<storm::expressions::Variable> intersection;
std::set_intersection(variables.begin(), variables.end(), appearingVariables.begin(), appearingVariables.end(), std::inserter(intersection, intersection.begin()));
return !intersection.empty();
}
bool Expression::isRelationalExpression() const {
if (!this->isFunctionApplication()) {
return false;
}
return this->getOperator() == OperatorType::Equal || this->getOperator() == OperatorType::NotEqual
|| this->getOperator() == OperatorType::Less || this->getOperator() == OperatorType::LessOrEqual
|| this->getOperator() == OperatorType::Greater || this->getOperator() == OperatorType::GreaterOrEqual;
}
bool Expression::isLinear() const {
return LinearityCheckVisitor().check(*this);
}
BaseExpression const& Expression::getBaseExpression() const {
return *this->expressionPtr;
}
std::shared_ptr<BaseExpression const> const& Expression::getBaseExpressionPointer() const {
return this->expressionPtr;
}
ExpressionManager const& Expression::getManager() const {
return this->getBaseExpression().getManager();
}
Type const& Expression::getType() const {
return this->getBaseExpression().getType();
}
bool Expression::hasNumericalType() const {
return this->getBaseExpression().hasNumericalType();
}
bool Expression::hasRationalType() const {
return this->getBaseExpression().hasRationalType();
}
bool Expression::hasBooleanType() const {
return this->getBaseExpression().hasBooleanType();
}
bool Expression::hasIntegerType() const {
return this->getBaseExpression().hasIntegerType();
}
bool Expression::hasBitVectorType() const {
return this->getBaseExpression().hasBitVectorType();
}
boost::any Expression::accept(ExpressionVisitor& visitor) const {
return this->getBaseExpression().accept(visitor);
}
bool Expression::isInitialized() const {
if(this->getBaseExpressionPointer()) {
return true;
}
return false;
}
std::string Expression::toString() const {
std::stringstream stream;
stream << *this;
return stream.str();
}
std::ostream& operator<<(std::ostream& stream, Expression const& expression) {
stream << expression.getBaseExpression();
return stream;
}
Expression operator+(Expression const& first, Expression const& second) {
assertSameManager(first.getBaseExpression(), second.getBaseExpression());
return Expression(std::shared_ptr<BaseExpression>(new BinaryNumericalFunctionExpression(first.getManager(), first.getType().plusMinusTimes(second.getType()), first.getBaseExpressionPointer(), second.getBaseExpressionPointer(), BinaryNumericalFunctionExpression::OperatorType::Plus)));
}
Expression operator-(Expression const& first, Expression const& second) {
assertSameManager(first.getBaseExpression(), second.getBaseExpression());
return Expression(std::shared_ptr<BaseExpression>(new BinaryNumericalFunctionExpression(first.getBaseExpression().getManager(), first.getType().plusMinusTimes(second.getType()), first.getBaseExpressionPointer(), second.getBaseExpressionPointer(), BinaryNumericalFunctionExpression::OperatorType::Minus)));
}
Expression operator-(Expression const& first) {
return Expression(std::shared_ptr<BaseExpression>(new UnaryNumericalFunctionExpression(first.getBaseExpression().getManager(), first.getType().minus(), first.getBaseExpressionPointer(), UnaryNumericalFunctionExpression::OperatorType::Minus)));
}
Expression operator*(Expression const& first, Expression const& second) {
assertSameManager(first.getBaseExpression(), second.getBaseExpression());
return Expression(std::shared_ptr<BaseExpression>(new BinaryNumericalFunctionExpression(first.getBaseExpression().getManager(), first.getType().plusMinusTimes(second.getType()), first.getBaseExpressionPointer(), second.getBaseExpressionPointer(), BinaryNumericalFunctionExpression::OperatorType::Times)));
}
Expression operator/(Expression const& first, Expression const& second) {
assertSameManager(first.getBaseExpression(), second.getBaseExpression());
return Expression(std::shared_ptr<BaseExpression>(new BinaryNumericalFunctionExpression(first.getBaseExpression().getManager(), first.getType().divide(second.getType()), first.getBaseExpressionPointer(), second.getBaseExpressionPointer(), BinaryNumericalFunctionExpression::OperatorType::Divide)));
}
Expression operator^(Expression const& first, Expression const& second) {
assertSameManager(first.getBaseExpression(), second.getBaseExpression());
return Expression(std::shared_ptr<BaseExpression>(new BinaryNumericalFunctionExpression(first.getBaseExpression().getManager(), first.getType().power(second.getType()), first.getBaseExpressionPointer(), second.getBaseExpressionPointer(), BinaryNumericalFunctionExpression::OperatorType::Power)));
}
Expression operator&&(Expression const& first, Expression const& second) {
assertSameManager(first.getBaseExpression(), second.getBaseExpression());
return Expression(std::shared_ptr<BaseExpression>(new BinaryBooleanFunctionExpression(first.getBaseExpression().getManager(), first.getType().logicalConnective(second.getType()), first.getBaseExpressionPointer(), second.getBaseExpressionPointer(), BinaryBooleanFunctionExpression::OperatorType::And)));
}
Expression operator||(Expression const& first, Expression const& second) {
assertSameManager(first.getBaseExpression(), second.getBaseExpression());
return Expression(std::shared_ptr<BaseExpression>(new BinaryBooleanFunctionExpression(first.getBaseExpression().getManager(), first.getType().logicalConnective(second.getType()), first.getBaseExpressionPointer(), second.getBaseExpressionPointer(), BinaryBooleanFunctionExpression::OperatorType::Or)));
}
Expression operator!(Expression const& first) {
return Expression(std::shared_ptr<BaseExpression>(new UnaryBooleanFunctionExpression(first.getBaseExpression().getManager(), first.getType().logicalConnective(), first.getBaseExpressionPointer(), UnaryBooleanFunctionExpression::OperatorType::Not)));
}
Expression operator==(Expression const& first, Expression const& second) {
assertSameManager(first.getBaseExpression(), second.getBaseExpression());
return Expression(std::shared_ptr<BaseExpression>(new BinaryRelationExpression(first.getBaseExpression().getManager(), first.getType().numericalComparison(second.getType()), first.getBaseExpressionPointer(), second.getBaseExpressionPointer(), BinaryRelationExpression::RelationType::Equal)));
}
Expression operator!=(Expression const& first, Expression const& second) {
assertSameManager(first.getBaseExpression(), second.getBaseExpression());
if (first.hasNumericalType() && second.hasNumericalType()) {
return Expression(std::shared_ptr<BaseExpression>(new BinaryRelationExpression(first.getBaseExpression().getManager(), first.getType().numericalComparison(second.getType()), first.getBaseExpressionPointer(), second.getBaseExpressionPointer(), BinaryRelationExpression::RelationType::NotEqual)));
} else {
return Expression(std::shared_ptr<BaseExpression>(new BinaryBooleanFunctionExpression(first.getBaseExpression().getManager(), first.getType().logicalConnective(second.getType()), first.getBaseExpressionPointer(), second.getBaseExpressionPointer(), BinaryBooleanFunctionExpression::OperatorType::Xor)));
}
}
Expression operator>(Expression const& first, Expression const& second) {
assertSameManager(first.getBaseExpression(), second.getBaseExpression());
return Expression(std::shared_ptr<BaseExpression>(new BinaryRelationExpression(first.getBaseExpression().getManager(), first.getType().numericalComparison(second.getType()), first.getBaseExpressionPointer(), second.getBaseExpressionPointer(), BinaryRelationExpression::RelationType::Greater)));
}
Expression operator>=(Expression const& first, Expression const& second) {
assertSameManager(first.getBaseExpression(), second.getBaseExpression());
return Expression(std::shared_ptr<BaseExpression>(new BinaryRelationExpression(first.getBaseExpression().getManager(), first.getType().numericalComparison(second.getType()), first.getBaseExpressionPointer(), second.getBaseExpressionPointer(), BinaryRelationExpression::RelationType::GreaterOrEqual)));
}
Expression operator<(Expression const& first, Expression const& second) {
assertSameManager(first.getBaseExpression(), second.getBaseExpression());
return Expression(std::shared_ptr<BaseExpression>(new BinaryRelationExpression(first.getBaseExpression().getManager(), first.getType().numericalComparison(second.getType()), first.getBaseExpressionPointer(), second.getBaseExpressionPointer(), BinaryRelationExpression::RelationType::Less)));
}
Expression operator<=(Expression const& first, Expression const& second) {
assertSameManager(first.getBaseExpression(), second.getBaseExpression());
return Expression(std::shared_ptr<BaseExpression>(new BinaryRelationExpression(first.getBaseExpression().getManager(), first.getType().numericalComparison(second.getType()), first.getBaseExpressionPointer(), second.getBaseExpressionPointer(), BinaryRelationExpression::RelationType::LessOrEqual)));
}
Expression operator<(Expression const& first, int64_t second) {
return first < Expression(std::shared_ptr<BaseExpression>(new IntegerLiteralExpression(first.getBaseExpression().getManager(), second)));
}
Expression operator>(Expression const& first, int64_t second) {
return first > Expression(std::shared_ptr<BaseExpression>(new IntegerLiteralExpression(first.getBaseExpression().getManager(), second)));
}
Expression operator<=(Expression const& first, int64_t second) {
return first <= Expression(std::shared_ptr<BaseExpression>(new IntegerLiteralExpression(first.getBaseExpression().getManager(), second)));
}
Expression operator>=(Expression const& first, int64_t second) {
return first >= Expression(std::shared_ptr<BaseExpression>(new IntegerLiteralExpression(first.getBaseExpression().getManager(), second)));
}
Expression minimum(Expression const& first, Expression const& second) {
assertSameManager(first.getBaseExpression(), second.getBaseExpression());
return Expression(std::shared_ptr<BaseExpression>(new BinaryNumericalFunctionExpression(first.getBaseExpression().getManager(), first.getType().minimumMaximum(second.getType()), first.getBaseExpressionPointer(), second.getBaseExpressionPointer(), BinaryNumericalFunctionExpression::OperatorType::Min)));
}
Expression maximum(Expression const& first, Expression const& second) {
assertSameManager(first.getBaseExpression(), second.getBaseExpression());
return Expression(std::shared_ptr<BaseExpression>(new BinaryNumericalFunctionExpression(first.getBaseExpression().getManager(), first.getType().minimumMaximum(second.getType()), first.getBaseExpressionPointer(), second.getBaseExpressionPointer(), BinaryNumericalFunctionExpression::OperatorType::Max)));
}
Expression ite(Expression const& condition, Expression const& thenExpression, Expression const& elseExpression) {
assertSameManager(condition.getBaseExpression(), thenExpression.getBaseExpression());
assertSameManager(thenExpression.getBaseExpression(), elseExpression.getBaseExpression());
return Expression(std::shared_ptr<BaseExpression>(new IfThenElseExpression(condition.getBaseExpression().getManager(), condition.getType().ite(thenExpression.getType(), elseExpression.getType()), condition.getBaseExpressionPointer(), thenExpression.getBaseExpressionPointer(), elseExpression.getBaseExpressionPointer())));
}
Expression implies(Expression const& first, Expression const& second) {
assertSameManager(first.getBaseExpression(), second.getBaseExpression());
return Expression(std::shared_ptr<BaseExpression>(new BinaryBooleanFunctionExpression(first.getBaseExpression().getManager(), first.getType().logicalConnective(second.getType()), first.getBaseExpressionPointer(), second.getBaseExpressionPointer(), BinaryBooleanFunctionExpression::OperatorType::Implies)));
}
Expression iff(Expression const& first, Expression const& second) {
assertSameManager(first.getBaseExpression(), second.getBaseExpression());
return Expression(std::shared_ptr<BaseExpression>(new BinaryBooleanFunctionExpression(first.getBaseExpression().getManager(), first.getType().logicalConnective(second.getType()), first.getBaseExpressionPointer(), second.getBaseExpressionPointer(), BinaryBooleanFunctionExpression::OperatorType::Iff)));
}
Expression xclusiveor(Expression const& first, Expression const& second) {
assertSameManager(first.getBaseExpression(), second.getBaseExpression());
return Expression(std::shared_ptr<BaseExpression>(new BinaryBooleanFunctionExpression(first.getBaseExpression().getManager(), first.getType().logicalConnective(second.getType()), first.getBaseExpressionPointer(), second.getBaseExpressionPointer(), BinaryBooleanFunctionExpression::OperatorType::Xor)));
}
Expression floor(Expression const& first) {
STORM_LOG_THROW(first.hasNumericalType(), storm::exceptions::InvalidTypeException, "Operator 'floor' requires numerical operand.");
return Expression(std::shared_ptr<BaseExpression>(new UnaryNumericalFunctionExpression(first.getBaseExpression().getManager(), first.getType().floorCeil(), first.getBaseExpressionPointer(), UnaryNumericalFunctionExpression::OperatorType::Floor)));
}
Expression ceil(Expression const& first) {
STORM_LOG_THROW(first.hasNumericalType(), storm::exceptions::InvalidTypeException, "Operator 'ceil' requires numerical operand.");
return Expression(std::shared_ptr<BaseExpression>(new UnaryNumericalFunctionExpression(first.getBaseExpression().getManager(), first.getType().floorCeil(), first.getBaseExpressionPointer(), UnaryNumericalFunctionExpression::OperatorType::Ceil)));
}
Expression abs(Expression const& first) {
STORM_LOG_THROW(first.hasNumericalType(), storm::exceptions::InvalidTypeException, "Abs is only defined for numerical operands");
return ite(first < first.getManager().integer(0), -first, first);
}
Expression sign(Expression const& first) {
STORM_LOG_THROW(first.hasNumericalType(), storm::exceptions::InvalidTypeException, "Sign is only defined for numerical operands");
// TODO implement (via Ite?)
STORM_LOG_ERROR("Not yet implemented");
}
Expression truncate(Expression const& first) {
STORM_LOG_THROW(first.hasNumericalType(), storm::exceptions::InvalidTypeException, "Truncate is only defined for numerical operands");
// TODO implement (via Ite?)
STORM_LOG_ERROR("Not yet implemented");
}
Expression disjunction(std::vector<storm::expressions::Expression> const& expressions) {
return apply(expressions, [] (Expression const& e1, Expression const& e2) { return e1 || e2; });
}
Expression conjunction(std::vector<storm::expressions::Expression> const& expressions) {
return apply(expressions, [] (Expression const& e1, Expression const& e2) { return e1 && e2; });
}
Expression sum(std::vector<storm::expressions::Expression> const& expressions) {
return apply(expressions, [] (Expression const& e1, Expression const& e2) { return e1 + e2; });
}
Expression apply(std::vector<storm::expressions::Expression> const& expressions, std::function<Expression (Expression const&, Expression const&)> const& function) {
STORM_LOG_THROW(!expressions.empty(), storm::exceptions::InvalidArgumentException, "Cannot build disjunction of empty expression list.");
auto it = expressions.begin();
auto ite = expressions.end();
Expression result = *it;
++it;
for (; it != ite; ++it) {
result = function(result, *it);
}
return result;
}
}
}