diff --git a/src/adapters/MathsatExpressionAdapter.h b/src/adapters/MathsatExpressionAdapter.h index f15276b80..16c449adb 100644 --- a/src/adapters/MathsatExpressionAdapter.h +++ b/src/adapters/MathsatExpressionAdapter.h @@ -24,19 +24,12 @@ namespace storm { class MathsatExpressionAdapter : public storm::expressions::ExpressionVisitor { public: /*! - * Creates an expression adapter that can translate expressions to the format of Z3. + * Creates an expression adapter that can translate expressions to the format of MathSAT. * - * @warning The adapter internally creates helper variables prefixed with `__z3adapter_`. As a consequence, - * having variables with this prefix in the variableToExpressionMap might lead to unexpected results and is - * strictly to be avoided. - * - * @param context A reference to the Z3 context over which to build the expressions. The lifetime of the - * context needs to be guaranteed as long as the instance of this adapter is used. - * @param createVariables If set to true, additional variables will be created for variables that appear in - * expressions and are not yet known to the adapter. - * @param variableToDeclarationMap A mapping from variable names to their corresponding MathSAT declarations (if already existing). + * @param manager The manager that can be used to build expressions. + * @param env The MathSAT environment in which to build the expressions. */ - MathsatExpressionAdapter(msat_env& env, bool createVariables = true, std::map const& variableToDeclarationMap = std::map()) : env(env), variableToDeclarationMap(variableToDeclarationMap), createVariables(createVariables) { + MathsatExpressionAdapter(storm::expressions::ExpressionManager& manager, msat_env& env) : manager(manager), env(env), variableToDeclarationMap() { // Intentionally left empty. } @@ -51,6 +44,16 @@ namespace storm { STORM_LOG_THROW(!MSAT_ERROR_TERM(result), storm::exceptions::ExpressionEvaluationException, "Could not translate expression to MathSAT's format."); return result; } + + /*! + * Translates the given variable to an equivalent expression for Z3. + * + * @param variable The variable to translate. + * @return An equivalent term for MathSAT. + */ + msat_term translateExpression(storm::expressions::Variable const& variable) { + return msat_make_constant(env, variableToDeclarationMap[variable]); + } virtual boost::any visit(expressions::BinaryBooleanFunctionExpression const& expression) override { msat_term leftResult = boost::any_cast(expression.getFirstOperand()->accept(*this)); @@ -173,34 +176,7 @@ namespace storm { } virtual boost::any visit(expressions::VariableExpression const& expression) override { - std::map::iterator stringVariablePair = variableToDeclarationMap.find(expression.getVariableName()); - msat_decl result; - - if (stringVariablePair == variableToDeclarationMap.end() && createVariables) { - std::pair::iterator, bool> iteratorAndFlag; - switch (expression.getReturnType()) { - case storm::expressions::ExpressionReturnType::Bool: - iteratorAndFlag = this->variableToDeclarationMap.insert(std::make_pair(expression.getVariableName(), msat_declare_function(env, expression.getVariableName().c_str(), msat_get_bool_type(env)))); - result = iteratorAndFlag.first->second; - break; - case storm::expressions::ExpressionReturnType::Int: - iteratorAndFlag = this->variableToDeclarationMap.insert(std::make_pair(expression.getVariableName(), msat_declare_function(env, expression.getVariableName().c_str(), msat_get_integer_type(env)))); - result = iteratorAndFlag.first->second; - break; - case storm::expressions::ExpressionReturnType::Double: - iteratorAndFlag = this->variableToDeclarationMap.insert(std::make_pair(expression.getVariableName(), msat_declare_function(env, expression.getVariableName().c_str(), msat_get_rational_type(env)))); - result = iteratorAndFlag.first->second; - break; - default: - STORM_LOG_THROW(false, storm::exceptions::InvalidTypeException, "Encountered variable '" << expression.getVariableName() << "' with unknown type while trying to create solver variables."); - } - } else { - STORM_LOG_THROW(stringVariablePair != variableToDeclarationMap.end(), storm::exceptions::InvalidArgumentException, "Expression refers to unknown variable '" << expression.getVariableName() << "'."); - result = stringVariablePair->second; - } - - STORM_LOG_THROW(!MSAT_ERROR_DECL(result), storm::exceptions::ExpressionEvaluationException, "Unable to translate expression to MathSAT format, because a variable could not be translated."); - return msat_make_constant(env, result); + return msat_make_constant(env, variableToDeclarationMap[expression.getVariable()]); } storm::expressions::Expression translateExpression(msat_term const& term) { @@ -221,32 +197,20 @@ namespace storm { } else if (msat_term_is_leq(env, term)) { return translateExpression(msat_term_get_arg(term, 0)) <= translateExpression(msat_term_get_arg(term, 1)); } else if (msat_term_is_true(env, term)) { - return storm::expressions::Expression::createTrue(); + return manager.boolean(true); } else if (msat_term_is_false(env, term)) { - return storm::expressions::Expression::createFalse(); - } else if (msat_term_is_boolean_constant(env, term)) { - char* name = msat_decl_get_name(msat_term_get_decl(term)); - std::string name_str(name); - storm::expressions::Expression result = expressions::Expression::createBooleanVariable(name_str.substr(0, name_str.find('/'))); - msat_free(name); - return result; + return manager.boolean(false); } else if (msat_term_is_constant(env, term)) { - char* name = msat_decl_get_name(msat_term_get_decl(term)); - std::string name_str(name); - storm::expressions::Expression result; - if (msat_is_integer_type(env, msat_term_get_type(term))) { - result = expressions::Expression::createIntegerVariable(name_str.substr(0, name_str.find('/'))); - } else if (msat_is_rational_type(env, msat_term_get_type(term))) { - result = expressions::Expression::createDoubleVariable(name_str.substr(0, name_str.find('/'))); - } + std::string nameString(name); + storm::expressions::Expression result = manager.getVariableExpression(nameString.substr(0, nameString.find('/'))); msat_free(name); return result; } else if (msat_term_is_number(env, term)) { if (msat_is_integer_type(env, msat_term_get_type(term))) { - return expressions::Expression::createIntegerLiteral(std::stoll(msat_term_repr(term))); + return manager.integer(std::stoll(msat_term_repr(term))); } else if (msat_is_rational_type(env, msat_term_get_type(term))) { - return expressions::Expression::createDoubleLiteral(std::stod(msat_term_repr(term))); + return manager.rational(std::stod(msat_term_repr(term))); } } @@ -258,14 +222,14 @@ namespace storm { } private: + // The expression manager to use. + storm::expressions::ExpressionManager& manager; + // The MathSAT environment used. msat_env& env; // A mapping of variable names to their declaration in the MathSAT environment. - std::map variableToDeclarationMap; - - // A flag indicating whether variables are supposed to be created if they are not already known to the adapter. - bool createVariables; + std::unordered_map variableToDeclarationMap; }; #endif } // namespace adapters diff --git a/src/adapters/Z3ExpressionAdapter.h b/src/adapters/Z3ExpressionAdapter.h index b28744ae7..e97231c42 100644 --- a/src/adapters/Z3ExpressionAdapter.h +++ b/src/adapters/Z3ExpressionAdapter.h @@ -1,7 +1,7 @@ #ifndef STORM_ADAPTERS_Z3EXPRESSIONADAPTER_H_ #define STORM_ADAPTERS_Z3EXPRESSIONADAPTER_H_ -#include +#include // Include the headers of Z3 only if it is available. #ifdef STORM_HAVE_Z3 @@ -11,6 +11,7 @@ #include "storm-config.h" #include "src/storage/expressions/Expressions.h" +#include "src/storage/expressions/ExpressionManager.h" #include "src/utility/macros.h" #include "src/exceptions/ExpressionEvaluationException.h" #include "src/exceptions/InvalidTypeException.h" @@ -25,27 +26,25 @@ namespace storm { /*! * Creates an expression adapter that can translate expressions to the format of Z3. * - * @warning The adapter internally creates helper variables prefixed with `__z3adapter_`. As a consequence, - * having variables with this prefix in the variableToExpressionMap might lead to unexpected results and is - * strictly to be avoided. - * + * @param manager The manager that can be used to build expressions. * @param context A reference to the Z3 context over which to build the expressions. The lifetime of the * context needs to be guaranteed as long as the instance of this adapter is used. - * @param createVariables If set to true, additional variables will be created for variables that appear in - * expressions and are not yet known to the adapter. - * @param variableToExpressionMap A mapping from variable names to their corresponding Z3 expressions (if already existing). */ - Z3ExpressionAdapter(z3::context& context, bool createVariables = true, std::map const& variableToExpressionMap = std::map()) : context(context), additionalAssertions(), additionalVariableCounter(0), variableToExpressionMap(variableToExpressionMap), createVariables(createVariables) { - // Intentionally left empty. + Z3ExpressionAdapter(storm::expressions::ExpressionManager& manager, z3::context& context) : manager(manager), context(context), additionalAssertions(), variableToExpressionMap() { + // Here, we need to create the mapping from all variables known to the manager to their z3 counterparts. + for (auto const& variableTypePair : manager) { + switch (variableTypePair.second) { + case storm::expressions::ExpressionReturnType::Bool: variableToExpressionMap.insert(std::make_pair(variableTypePair.first, context.bool_const(variableTypePair.first.getName().c_str()))); break; + case storm::expressions::ExpressionReturnType::Int: variableToExpressionMap.insert(std::make_pair(variableTypePair.first, context.int_const(variableTypePair.first.getName().c_str()))); break; + case storm::expressions::ExpressionReturnType::Double: variableToExpressionMap.insert(std::make_pair(variableTypePair.first, context.real_const(variableTypePair.first.getName().c_str()))); break; + default: STORM_LOG_THROW(false, storm::exceptions::InvalidTypeException, "Encountered variable '" << variableTypePair.first.getName() << "' with unknown type while trying to create solver variables."); + } + } } /*! * Translates the given expression to an equivalent expression for Z3. * - * @warning The adapter internally creates helper variables prefixed with `__z3adapter_`. As a consequence, - * having variables with this prefix in the variableToExpressionMap might lead to unexpected results and is - * strictly to be aboost::anyed. - * * @param expression The expression to translate. * @return An equivalent expression for Z3. */ @@ -60,13 +59,25 @@ namespace storm { return result; } + /*! + * Translates the given variable to an equivalent expression for Z3. + * + * @param variable The variable to translate. + * @return An equivalent expression for Z3. + */ + z3::expr translateExpression(storm::expressions::Variable const& variable) { + auto const& variableExpressionPair = variableToExpressionMap.find(variable); + STORM_LOG_ASSERT(variableExpressionPair != variableToExpressionMap.end(), "Unable to find variable."); + return variableExpressionPair->second; + } + storm::expressions::Expression translateExpression(z3::expr const& expr) { if (expr.is_app()) { switch (expr.decl().decl_kind()) { case Z3_OP_TRUE: - return storm::expressions::Expression::createTrue(); + return manager.boolean(true); case Z3_OP_FALSE: - return storm::expressions::Expression::createFalse(); + return manager.boolean(false); case Z3_OP_EQ: return this->translateExpression(expr.arg(0)) == this->translateExpression(expr.arg(1)); case Z3_OP_ITE: @@ -130,7 +141,7 @@ namespace storm { if (expr.is_int() && expr.is_const()) { long long value; if (Z3_get_numeral_int64(expr.ctx(), expr, &value)) { - return storm::expressions::Expression::createIntegerLiteral(value); + return manager.integer(value); } else { STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Failed to convert Z3 expression. Expression is constant integer and value does not fit into 64-bit integer."); } @@ -138,7 +149,7 @@ namespace storm { long long num; long long den; if (Z3_get_numeral_rational_int64(expr.ctx(), expr, &num, &den)) { - return storm::expressions::Expression::createDoubleLiteral(static_cast(num) / static_cast(den)); + return manager.rational(static_cast(num) / static_cast(den)); } else { STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Failed to convert Z3 expression. Expression is constant real and value does not fit into a fraction with 64-bit integer numerator and denominator."); } @@ -146,16 +157,7 @@ namespace storm { case Z3_OP_UNINTERPRETED: // Currently, we only support uninterpreted constant functions. STORM_LOG_THROW(expr.is_const(), storm::exceptions::ExpressionEvaluationException, "Failed to convert Z3 expression. Encountered non-constant uninterpreted function."); - if (expr.is_bool()) { - return storm::expressions::Expression::createBooleanVariable(expr.decl().name().str()); - } else if (expr.is_int()) { - return storm::expressions::Expression::createIntegerVariable(expr.decl().name().str()); - } else if (expr.is_real()) { - return storm::expressions::Expression::createDoubleVariable(expr.decl().name().str()); - } else { - STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Failed to convert Z3 expression. Encountered constant uninterpreted function of unknown sort."); - } - + return manager.getVariable(expr.decl().name().str()).getExpression(); default: STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Failed to convert Z3 expression. Encountered unhandled Z3_decl_kind " << expr.decl().kind() <<"."); break; @@ -262,12 +264,14 @@ namespace storm { case storm::expressions::UnaryNumericalFunctionExpression::OperatorType::Minus: return 0 - childResult; case storm::expressions::UnaryNumericalFunctionExpression::OperatorType::Floor: { - z3::expr floorVariable = context.int_const(("__z3adapter_floor_" + std::to_string(additionalVariableCounter++)).c_str()); + storm::expressions::Variable freshAuxiliaryVariable = manager.declareFreshAuxiliaryVariable(storm::expressions::ExpressionReturnType::Int); + z3::expr floorVariable = context.int_const(freshAuxiliaryVariable.getName().c_str()); additionalAssertions.push_back(z3::expr(context, Z3_mk_int2real(context, floorVariable)) <= childResult && childResult < (z3::expr(context, Z3_mk_int2real(context, floorVariable)) + 1)); return floorVariable; } case storm::expressions::UnaryNumericalFunctionExpression::OperatorType::Ceil:{ - z3::expr ceilVariable = context.int_const(("__z3adapter_ceil_" + std::to_string(additionalVariableCounter++)).c_str()); + storm::expressions::Variable freshAuxiliaryVariable = manager.declareFreshAuxiliaryVariable(storm::expressions::ExpressionReturnType::Int); + z3::expr ceilVariable = context.int_const(freshAuxiliaryVariable.getName().c_str()); additionalAssertions.push_back(z3::expr(context, Z3_mk_int2real(context, ceilVariable)) - 1 <= childResult && childResult < z3::expr(context, Z3_mk_int2real(context, ceilVariable))); return ceilVariable; } @@ -283,50 +287,22 @@ namespace storm { } virtual boost::any visit(storm::expressions::VariableExpression const& expression) override { - std::map::iterator stringVariablePair = variableToExpressionMap.find(expression.getVariableName()); - z3::expr result(context); - - if (stringVariablePair == variableToExpressionMap.end() && createVariables) { - std::pair::iterator, bool> iteratorAndFlag; - switch (expression.getReturnType()) { - case storm::expressions::ExpressionReturnType::Bool: - iteratorAndFlag = this->variableToExpressionMap.insert(std::make_pair(expression.getVariableName(), context.bool_const(expression.getVariableName().c_str()))); - result = iteratorAndFlag.first->second; - break; - case storm::expressions::ExpressionReturnType::Int: - iteratorAndFlag = this->variableToExpressionMap.insert(std::make_pair(expression.getVariableName(), context.int_const(expression.getVariableName().c_str()))); - result = iteratorAndFlag.first->second; - break; - case storm::expressions::ExpressionReturnType::Double: - iteratorAndFlag = this->variableToExpressionMap.insert(std::make_pair(expression.getVariableName(), context.real_const(expression.getVariableName().c_str()))); - result = iteratorAndFlag.first->second; - break; - default: - STORM_LOG_THROW(false, storm::exceptions::InvalidTypeException, "Encountered variable '" << expression.getVariableName() << "' with unknown type while trying to create solver variables."); - } - } else { - STORM_LOG_THROW(stringVariablePair != variableToExpressionMap.end(), storm::exceptions::InvalidArgumentException, "Expression refers to unknown variable '" << expression.getVariableName() << "'."); - result = stringVariablePair->second; - } - - return result; + return variableToExpressionMap.at(expression.getVariable()); } private: + // The manager that can be used to build expressions. + storm::expressions::ExpressionManager& manager; + // The context that is used to translate the expressions. z3::context& context; - // A stack of assertions that need to be kept separate, because they were only impliclty part of an assertion that was added. + // A vector of assertions that need to be kept separate, because they were only implicitly part of an + // assertion that was added. std::vector additionalAssertions; - // A counter for the variables that were created to identify the additional assertions. - uint_fast64_t additionalVariableCounter; - // A mapping from variable names to their Z3 equivalent. - std::map variableToExpressionMap; - - // A flag that indicates whether new variables are to be created when an unkown variable is encountered. - bool createVariables; + std::unordered_map variableToExpressionMap; }; #endif } // namespace adapters diff --git a/src/solver/MathsatSmtSolver.cpp b/src/solver/MathsatSmtSolver.cpp index 4c095690b..b2793fb33 100644 --- a/src/solver/MathsatSmtSolver.cpp +++ b/src/solver/MathsatSmtSolver.cpp @@ -8,14 +8,14 @@ namespace storm { namespace solver { #ifdef STORM_HAVE_MSAT - MathsatSmtSolver::MathsatAllsatModelReference::MathsatAllsatModelReference(msat_env const& env, msat_term* model, std::unordered_map const& atomNameToSlotMapping) : env(env), model(model), atomNameToSlotMapping(atomNameToSlotMapping) { + MathsatSmtSolver::MathsatAllsatModelReference::MathsatAllsatModelReference(storm::expressions::ExpressionManager const& manager, msat_env const& env, msat_term* model, std::unordered_map const& variableNameToSlotMapping) : ModelReference(manager), env(env), model(model), variableNameToSlotMapping(variableNameToSlotMapping) { // Intentionally left empty. } - bool MathsatSmtSolver::MathsatAllsatModelReference::getBooleanValue(std::string const& name) const { - std::unordered_map::const_iterator nameSlotPair = atomNameToSlotMapping.find(name); - STORM_LOG_THROW(nameSlotPair != atomNameToSlotMapping.end(), storm::exceptions::InvalidArgumentException, "Cannot retrieve value of unknown variable '" << name << "' from model."); - msat_term selectedTerm = model[nameSlotPair->second]; + bool MathsatSmtSolver::MathsatAllsatModelReference::getBooleanValue(storm::expressions::Variable const& variable) const { + std::unordered_map::const_iterator variableSlotPair = variableNameToSlotMapping.find(variable); + STORM_LOG_THROW(variableSlotPair != variableNameToSlotMapping.end(), storm::exceptions::InvalidArgumentException, "Cannot retrieve value of unknown variable '" << variable.getName() << "' from model."); + msat_term selectedTerm = model[variableSlotPair->second]; if (msat_term_is_not(env, selectedTerm)) { return false; @@ -24,49 +24,49 @@ namespace storm { } } - int_fast64_t MathsatSmtSolver::MathsatAllsatModelReference::getIntegerValue(std::string const& name) const { + int_fast64_t MathsatSmtSolver::MathsatAllsatModelReference::getIntegerValue(storm::expressions::Variable const& variable) const { STORM_LOG_THROW(false, storm::exceptions::NotSupportedException, "Unable to retrieve integer value from model that only contains boolean values."); } - double MathsatSmtSolver::MathsatAllsatModelReference::getDoubleValue(std::string const& name) const { + double MathsatSmtSolver::MathsatAllsatModelReference::getRationalValue(storm::expressions::Variable const& variable) const { STORM_LOG_THROW(false, storm::exceptions::NotSupportedException, "Unable to retrieve double value from model that only contains boolean values."); } - MathsatSmtSolver::MathsatModelReference::MathsatModelReference(msat_env const& env, storm::adapters::MathsatExpressionAdapter& expressionAdapter) : env(env), expressionAdapter(expressionAdapter) { + MathsatSmtSolver::MathsatModelReference::MathsatModelReference(storm::expressions::ExpressionManager const& manager, msat_env const& env, storm::adapters::MathsatExpressionAdapter& expressionAdapter) : ModelReference(manager), env(env), expressionAdapter(expressionAdapter) { // Intentionally left empty. } - bool MathsatSmtSolver::MathsatModelReference::getBooleanValue(std::string const& name) const { - msat_term msatVariable = expressionAdapter.translateExpression(storm::expressions::Expression::createBooleanVariable(name)); + bool MathsatSmtSolver::MathsatModelReference::getBooleanValue(storm::expressions::Variable const& variable) const { + STORM_LOG_ASSERT(variable.hasBooleanType(), "Variable is non-boolean type."); + msat_term msatVariable = expressionAdapter.translateExpression(variable); msat_term msatValue = msat_get_model_value(env, msatVariable); - STORM_LOG_THROW(!MSAT_ERROR_TERM(msatValue), storm::exceptions::UnexpectedException, "Unable to retrieve value of variable in model. This could be caused by calls to the solver between checking for satisfiability and model retrieval."); + STORM_LOG_ASSERT(!MSAT_ERROR_TERM(msatValue), "Unable to retrieve value of variable in model. This could be caused by calls to the solver between checking for satisfiability and model retrieval."); storm::expressions::Expression value = expressionAdapter.translateExpression(msatValue); - STORM_LOG_THROW(value.hasBooleanReturnType(), storm::exceptions::InvalidArgumentException, "Unable to retrieve boolean value of non-boolean variable '" << name << "'."); return value.evaluateAsBool(); } - int_fast64_t MathsatSmtSolver::MathsatModelReference::getIntegerValue(std::string const& name) const { - msat_term msatVariable = expressionAdapter.translateExpression(storm::expressions::Expression::createBooleanVariable(name)); + int_fast64_t MathsatSmtSolver::MathsatModelReference::getIntegerValue(storm::expressions::Variable const& variable) const { + STORM_LOG_ASSERT(variable.hasBooleanType(), "Variable is non-boolean type."); + msat_term msatVariable = expressionAdapter.translateExpression(variable); msat_term msatValue = msat_get_model_value(env, msatVariable); - STORM_LOG_THROW(!MSAT_ERROR_TERM(msatValue), storm::exceptions::UnexpectedException, "Unable to retrieve value of variable in model. This could be caused by calls to the solver between checking for satisfiability and model retrieval."); + STORM_LOG_ASSERT(!MSAT_ERROR_TERM(msatValue), "Unable to retrieve value of variable in model. This could be caused by calls to the solver between checking for satisfiability and model retrieval."); storm::expressions::Expression value = expressionAdapter.translateExpression(msatValue); - STORM_LOG_THROW(value.hasIntegralReturnType(), storm::exceptions::InvalidArgumentException, "Unable to retrieve integer value of non-integer variable '" << name << "'."); return value.evaluateAsInt(); } - double MathsatSmtSolver::MathsatModelReference::getDoubleValue(std::string const& name) const { + double MathsatSmtSolver::MathsatModelReference::getRationalValue(storm::expressions::Variable const& variable) const { + STORM_LOG_ASSERT(variable.hasBooleanType(), "Variable is non-boolean type."); msat_term msatVariable = expressionAdapter.translateExpression(storm::expressions::Expression::createBooleanVariable(name)); msat_term msatValue = msat_get_model_value(env, msatVariable); - STORM_LOG_THROW(!MSAT_ERROR_TERM(msatValue), storm::exceptions::UnexpectedException, "Unable to retrieve value of variable in model. This could be caused by calls to the solver between checking for satisfiability and model retrieval."); + STORM_LOG_ASSERT(!MSAT_ERROR_TERM(msatValue), "Unable to retrieve value of variable in model. This could be caused by calls to the solver between checking for satisfiability and model retrieval."); storm::expressions::Expression value = expressionAdapter.translateExpression(msatValue); - STORM_LOG_THROW(value.hasIntegralReturnType(), storm::exceptions::InvalidArgumentException, "Unable to retrieve double value of non-double variable '" << name << "'."); return value.evaluateAsDouble(); } #endif - MathsatSmtSolver::MathsatSmtSolver(Options const& options) + MathsatSmtSolver::MathsatSmtSolver(storm::expressions::ExpressionManager const& manager, Options const& options) #ifdef STORM_HAVE_MSAT - : expressionAdapter(nullptr), lastCheckAssumptions(false), lastResult(CheckResult::Unknown) + : SmtSolver(manager), expressionAdapter(nullptr), lastCheckAssumptions(false), lastResult(CheckResult::Unknown) #endif { #ifdef STORM_HAVE_MSAT @@ -219,7 +219,7 @@ namespace storm { #endif } - storm::expressions::SimpleValuation MathsatSmtSolver::getModelAsValuation() + storm::expressions::Valuation MathsatSmtSolver::getModelAsValuation() { #ifdef STORM_HAVE_MSAT STORM_LOG_THROW(this->lastResult == SmtSolver::CheckResult::Sat, storm::exceptions::InvalidStateException, "Unable to create model for formula that was not determined to be satisfiable."); @@ -239,8 +239,8 @@ namespace storm { } #ifdef STORM_HAVE_MSAT - storm::expressions::SimpleValuation MathsatSmtSolver::convertMathsatModelToValuation() { - storm::expressions::SimpleValuation stormModel; + storm::expressions::Valuation MathsatSmtSolver::convertMathsatModelToValuation() { + storm::expressions::Valuation stormModel; msat_model_iterator modelIterator = msat_create_model_iterator(env); STORM_LOG_THROW(!MSAT_ERROR_MODEL_ITERATOR(modelIterator), storm::exceptions::UnexpectedException, "MathSat returned an illegal model iterator."); @@ -275,11 +275,11 @@ namespace storm { } #endif - std::vector MathsatSmtSolver::allSat(std::vector const& important) + std::vector MathsatSmtSolver::allSat(std::vector const& important) { #ifdef STORM_HAVE_MSAT - std::vector valuations; - this->allSat(important, [&valuations](storm::expressions::SimpleValuation const& valuation) -> bool { valuations.push_back(valuation); return true; }); + std::vector valuations; + this->allSat(important, [&valuations](storm::expressions::Valuation const& valuation) -> bool { valuations.push_back(valuation); return true; }); return valuations; #else STORM_LOG_THROW(false, storm::exceptions::NotSupportedException, "StoRM is compiled without MathSAT support."); @@ -290,14 +290,14 @@ namespace storm { #ifdef STORM_HAVE_MSAT class AllsatValuationCallbackUserData { public: - AllsatValuationCallbackUserData(msat_env& env, std::function const& callback) : env(env), callback(callback) { + AllsatValuationCallbackUserData(msat_env& env, std::function const& callback) : env(env), callback(callback) { // Intentionally left empty. } static int allsatValuationsCallback(msat_term* model, int size, void* user_data) { AllsatValuationCallbackUserData* user = reinterpret_cast(user_data); - storm::expressions::SimpleValuation valuation; + storm::expressions::Valuation valuation; for (int i = 0; i < size; ++i) { bool currentTermValue = true; msat_term currentTerm = model[i]; @@ -323,7 +323,7 @@ namespace storm { msat_env& env; // The function that is to be called when the MathSAT model has been translated to a valuation. - std::function const& callback; + std::function const& callback; }; class AllsatModelReferenceCallbackUserData { @@ -355,7 +355,7 @@ namespace storm { #endif - uint_fast64_t MathsatSmtSolver::allSat(std::vector const& important, std::function const& callback) { + uint_fast64_t MathsatSmtSolver::allSat(std::vector const& important, std::function const& callback) { #ifdef STORM_HAVE_MSAT // Create a backtracking point, because MathSAT will modify the assertions stack during its AllSat procedure. this->push(); diff --git a/src/solver/MathsatSmtSolver.h b/src/solver/MathsatSmtSolver.h index db5015bb9..38d958dbb 100644 --- a/src/solver/MathsatSmtSolver.h +++ b/src/solver/MathsatSmtSolver.h @@ -33,27 +33,27 @@ namespace storm { #ifdef STORM_HAVE_MSAT class MathsatAllsatModelReference : public SmtSolver::ModelReference { public: - MathsatAllsatModelReference(msat_env const& env, msat_term* model, std::unordered_map const& atomNameToSlotMapping); + MathsatAllsatModelReference(storm::expressions::ExpressionManager const& manager, msat_env const& env, msat_term* model, std::unordered_map const& variableNameToSlotMapping); - virtual bool getBooleanValue(std::string const& name) const override; - virtual int_fast64_t getIntegerValue(std::string const& name) const override; - virtual double getDoubleValue(std::string const& name) const override; + virtual bool getBooleanValue(storm::expressions::Variable const& variable) const override; + virtual int_fast64_t getIntegerValue(storm::expressions::Variable const& variable) const override; + virtual double getRationalValue(storm::expressions::Variable const& variable) const override; private: msat_env const& env; msat_term* model; - std::unordered_map const& atomNameToSlotMapping; + std::unordered_map const& variableNameToSlotMapping; }; #endif #ifdef STORM_HAVE_MSAT class MathsatModelReference : public SmtSolver::ModelReference { public: - MathsatModelReference(msat_env const& env, storm::adapters::MathsatExpressionAdapter& expressionAdapter); + MathsatModelReference(storm::expressions::ExpressionManager const& manager, msat_env const& env, storm::adapters::MathsatExpressionAdapter& expressionAdapter); - virtual bool getBooleanValue(std::string const& name) const override; - virtual int_fast64_t getIntegerValue(std::string const& name) const override; - virtual double getDoubleValue(std::string const& name) const override; + virtual bool getBooleanValue(storm::expressions::Variable const& variable) const override; + virtual int_fast64_t getIntegerValue(storm::expressions::Variable const& variable) const override; + virtual double getRationalValue(storm::expressions::Variable const& variable) const override; private: msat_env const& env; @@ -61,7 +61,7 @@ namespace storm { }; #endif - MathsatSmtSolver(Options const& options = Options()); + MathsatSmtSolver(storm::expressions::ExpressionManager const& manager, Options const& options = Options()); virtual ~MathsatSmtSolver(); @@ -81,15 +81,15 @@ namespace storm { virtual CheckResult checkWithAssumptions(std::initializer_list const& assumptions) override; - virtual storm::expressions::SimpleValuation getModelAsValuation() override; + virtual storm::expressions::Valuation getModelAsValuation() override; virtual std::shared_ptr getModel() override; - virtual std::vector allSat(std::vector const& important) override; + virtual std::vector allSat(std::vector const& important) override; - virtual uint_fast64_t allSat(std::vector const& important, std::function const& callback) override; + virtual uint_fast64_t allSat(std::vector const& important, std::function const& callback) override; - virtual uint_fast64_t allSat(std::vector const& important, std::function const& callback) override; + virtual uint_fast64_t allSat(std::vector const& important, std::function const& callback) override; virtual std::vector getUnsatAssumptions() override; @@ -98,7 +98,7 @@ namespace storm { virtual storm::expressions::Expression getInterpolant(std::vector const& groupsA) override; private: - storm::expressions::SimpleValuation convertMathsatModelToValuation(); + storm::expressions::Valuation convertMathsatModelToValuation(); #ifdef STORM_HAVE_MSAT // The MathSAT environment. diff --git a/src/solver/SmtSolver.cpp b/src/solver/SmtSolver.cpp index 46ffac841..0de0d48ea 100644 --- a/src/solver/SmtSolver.cpp +++ b/src/solver/SmtSolver.cpp @@ -6,7 +6,15 @@ namespace storm { namespace solver { - SmtSolver::SmtSolver() { + SmtSolver::ModelReference::ModelReference(storm::expressions::ExpressionManager const& manager) : manager(manager) { + // Intentionally left empty. + } + + storm::expressions::ExpressionManager const& SmtSolver::ModelReference::getManager() const { + return manager; + } + + SmtSolver::SmtSolver(storm::expressions::ExpressionManager& manager) : manager(manager) { // Intentionally left empty. } @@ -32,7 +40,7 @@ namespace storm { } } - storm::expressions::SimpleValuation SmtSolver::getModelAsValuation() { + storm::expressions::Valuation SmtSolver::getModelAsValuation() { STORM_LOG_THROW(false, storm::exceptions::NotSupportedException, "This solver does not support model generation."); } @@ -40,15 +48,15 @@ namespace storm { STORM_LOG_THROW(false, storm::exceptions::NotSupportedException, "This solver does not support model generation."); } - std::vector SmtSolver::allSat(std::vector const& important) { + std::vector SmtSolver::allSat(std::vector const& important) { STORM_LOG_THROW(false, storm::exceptions::NotSupportedException, "This solver does not support model generation."); } - uint_fast64_t SmtSolver::allSat(std::vector const& important, std::function const& callback) { + uint_fast64_t SmtSolver::allSat(std::vector const& important, std::function const& callback) { STORM_LOG_THROW(false, storm::exceptions::NotSupportedException, "This solver does not support model generation."); } - uint_fast64_t SmtSolver::allSat(std::vector const& important, std::function const& callback) { + uint_fast64_t SmtSolver::allSat(std::vector const& important, std::function const& callback) { STORM_LOG_THROW(false, storm::exceptions::NotSupportedException, "This solver does not support model generation."); } @@ -68,5 +76,13 @@ namespace storm { STORM_LOG_THROW(false, storm::exceptions::NotSupportedException, "This solver does not support generation of interpolants."); } + storm::expressions::ExpressionManager const& SmtSolver::getManager() const { + return manager; + } + + storm::expressions::ExpressionManager& SmtSolver::getManager() { + return manager; + } + } // namespace solver } // namespace storm \ No newline at end of file diff --git a/src/solver/SmtSolver.h b/src/solver/SmtSolver.h index 8eec37b11..389f03d8f 100644 --- a/src/solver/SmtSolver.h +++ b/src/solver/SmtSolver.h @@ -3,8 +3,9 @@ #include -#include "storage/expressions/Expressions.h" -#include "storage/expressions/SimpleValuation.h" +#include "src/storage/expressions/Expressions.h" +#include "src/storage/expressions/Valuation.h" +#include "src/storage/expressions/ExpressionManager.h" #include #include @@ -29,18 +30,39 @@ namespace storm { */ class ModelReference { public: - virtual bool getBooleanValue(std::string const& name) const = 0; - virtual int_fast64_t getIntegerValue(std::string const& name) const = 0; - virtual double getDoubleValue(std::string const& name) const = 0; + /*! + * Creates a model reference that uses the given expression manager. + * + * @param manager The manager responsible for the variables whose value can be requested. + */ + ModelReference(storm::expressions::ExpressionManager const& manager); + + virtual bool getBooleanValue(storm::expressions::Variable const& variable) const = 0; + virtual int_fast64_t getIntegerValue(storm::expressions::Variable const& variable) const = 0; + virtual double getRationalValue(storm::expressions::Variable const& variable) const = 0; + + /*! + * Retrieves the expression manager associated with this model reference. + * + * @return The expression manager associated with this model reference. + */ + storm::expressions::ExpressionManager const& getManager() const; + + private: + // The expression manager responsible for the variableswhose value can be requested via this model + // reference. + storm::expressions::ExpressionManager const& manager; }; public: /*! * Constructs a new Smt solver with the given options. * + * @param manager The expression manager responsible for all expressions that in some way or another interact + * with this solver. * @throws storm::exceptions::IllegalArgumentValueException if an option is unsupported for the solver. */ - SmtSolver(); + SmtSolver(storm::expressions::ExpressionManager& manager); /*! * Destructs the solver instance @@ -136,7 +158,7 @@ namespace storm { * * @return A valuation that holds the values of the variables in the current model. */ - virtual storm::expressions::SimpleValuation getModelAsValuation(); + virtual storm::expressions::Valuation getModelAsValuation(); /*! * If the last call to check() or checkWithAssumptions() returned Sat, this method retrieves a model that @@ -159,7 +181,7 @@ namespace storm { * * @returns the set of all valuations of the important atoms, such that the currently asserted formulas are satisfiable */ - virtual std::vector allSat(std::vector const& important); + virtual std::vector allSat(std::vector const& important); /*! * Performs AllSat over the (provided) important atoms. That is, this function determines all models of the @@ -172,7 +194,7 @@ namespace storm { * * @return The number of models of the important atoms that where found. */ - virtual uint_fast64_t allSat(std::vector const& important, std::function const& callback); + virtual uint_fast64_t allSat(std::vector const& important, std::function const& callback); /*! * Performs AllSat over the (provided) important atoms. That is, this function determines all models of the @@ -185,7 +207,7 @@ namespace storm { * * @return The number of models of the important atoms that where found. */ - virtual uint_fast64_t allSat(std::vector const& important, std::function const& callback); + virtual uint_fast64_t allSat(std::vector const& important, std::function const& callback); /*! * If the last call to check() returned Unsat, this function can be used to retrieve the unsatisfiable core @@ -231,6 +253,24 @@ namespace storm { * conjunction of I and B is inconsistent. */ virtual storm::expressions::Expression getInterpolant(std::vector const& groupsA); + + /*! + * Retrieves the expression manager associated with the solver. + * + * @return The expression manager associated with the solver. + */ + storm::expressions::ExpressionManager const& getManager() const; + + /*! + * Retrieves the expression manager associated with the solver. + * + * @return The expression manager associated with the solver. + */ + storm::expressions::ExpressionManager& getManager(); + + private: + // The manager responsible for the expressions that interact with this solver. + storm::expressions::ExpressionManager& manager; }; } } diff --git a/src/solver/Z3SmtSolver.cpp b/src/solver/Z3SmtSolver.cpp index 09daf9ae0..6d8e7e1f4 100644 --- a/src/solver/Z3SmtSolver.cpp +++ b/src/solver/Z3SmtSolver.cpp @@ -6,51 +6,54 @@ namespace storm { namespace solver { #ifdef STORM_HAVE_Z3 - Z3SmtSolver::Z3ModelReference::Z3ModelReference(z3::model model, storm::adapters::Z3ExpressionAdapter& expressionAdapter) : model(model), expressionAdapter(expressionAdapter) { + Z3SmtSolver::Z3ModelReference::Z3ModelReference(storm::expressions::ExpressionManager const& manager, z3::model model, storm::adapters::Z3ExpressionAdapter& expressionAdapter) : ModelReference(manager), model(model), expressionAdapter(expressionAdapter) { // Intentionally left empty. } #endif - bool Z3SmtSolver::Z3ModelReference::getBooleanValue(std::string const& name) const { + bool Z3SmtSolver::Z3ModelReference::getBooleanValue(storm::expressions::Variable const& variable) const { #ifdef STORM_HAVE_Z3 - z3::expr z3Expr = this->expressionAdapter.translateExpression(storm::expressions::Expression::createBooleanVariable(name)); - z3::expr z3ExprValuation = model.eval(z3Expr); - return this->expressionAdapter.translateExpression(z3ExprValuation).evaluateAsBool(); + STORM_LOG_ASSERT(variable.getManager() == this->getManager(), "Requested variable is managed by a different manager."); + z3::expr z3Expr = this->expressionAdapter.translateExpression(variable); + z3::expr z3ExprValuation = model.eval(z3Expr, true); + return this->expressionAdapter.translateExpression(z3ExprValuation).isTrue(); #else STORM_LOG_THROW(false, storm::exceptions::NotSupportedException, "StoRM is compiled without Z3 support."); #endif } - int_fast64_t Z3SmtSolver::Z3ModelReference::getIntegerValue(std::string const& name) const { + int_fast64_t Z3SmtSolver::Z3ModelReference::getIntegerValue(storm::expressions::Variable const& variable) const { #ifdef STORM_HAVE_Z3 - z3::expr z3Expr = this->expressionAdapter.translateExpression(storm::expressions::Expression::createIntegerVariable(name)); - z3::expr z3ExprValuation = model.eval(z3Expr); + STORM_LOG_ASSERT(variable.getManager() == this->getManager(), "Requested variable is managed by a different manager."); + z3::expr z3Expr = this->expressionAdapter.translateExpression(variable); + z3::expr z3ExprValuation = model.eval(z3Expr, true); return this->expressionAdapter.translateExpression(z3ExprValuation).evaluateAsInt(); #else STORM_LOG_THROW(false, storm::exceptions::NotSupportedException, "StoRM is compiled without Z3 support."); #endif } - double Z3SmtSolver::Z3ModelReference::getDoubleValue(std::string const& name) const { + double Z3SmtSolver::Z3ModelReference::getRationalValue(storm::expressions::Variable const& variable) const { #ifdef STORM_HAVE_Z3 - z3::expr z3Expr = this->expressionAdapter.translateExpression(storm::expressions::Expression::createDoubleVariable(name)); - z3::expr z3ExprValuation = model.eval(z3Expr); + STORM_LOG_ASSERT(variable.getManager() == this->getManager(), "Requested variable is managed by a different manager."); + z3::expr z3Expr = this->expressionAdapter.translateExpression(variable); + z3::expr z3ExprValuation = model.eval(z3Expr, true); return this->expressionAdapter.translateExpression(z3ExprValuation).evaluateAsDouble(); #else STORM_LOG_THROW(false, storm::exceptions::NotSupportedException, "StoRM is compiled without Z3 support."); #endif } - Z3SmtSolver::Z3SmtSolver() + Z3SmtSolver::Z3SmtSolver(storm::expressions::ExpressionManager& manager) #ifdef STORM_HAVE_Z3 - : context(nullptr), solver(nullptr), expressionAdapter(nullptr), lastCheckAssumptions(false), lastResult(CheckResult::Unknown) + : SmtSolver(manager), context(nullptr), solver(nullptr), expressionAdapter(nullptr), lastCheckAssumptions(false), lastResult(CheckResult::Unknown) #endif { z3::config config; config.set("model", true); context = std::unique_ptr(new z3::context(config)); solver = std::unique_ptr(new z3::solver(*context)); - expressionAdapter = std::unique_ptr(new storm::adapters::Z3ExpressionAdapter(*context, true)); + expressionAdapter = std::unique_ptr(new storm::adapters::Z3ExpressionAdapter(this->getManager(), *context)); } Z3SmtSolver::~Z3SmtSolver() { @@ -177,7 +180,7 @@ namespace storm { #endif } - storm::expressions::SimpleValuation Z3SmtSolver::getModelAsValuation() + storm::expressions::Valuation Z3SmtSolver::getModelAsValuation() { #ifdef STORM_HAVE_Z3 STORM_LOG_THROW(this->lastResult == SmtSolver::CheckResult::Sat, storm::exceptions::InvalidStateException, "Unable to create model for formula that was not determined to be satisfiable."); @@ -190,15 +193,15 @@ namespace storm { std::shared_ptr Z3SmtSolver::getModel() { #ifdef STORM_HAVE_Z3 STORM_LOG_THROW(this->lastResult == SmtSolver::CheckResult::Sat, storm::exceptions::InvalidStateException, "Unable to create model for formula that was not determined to be satisfiable."); - return std::shared_ptr(new Z3ModelReference(this->solver->get_model(), *this->expressionAdapter)); + return std::shared_ptr(new Z3ModelReference(this->getManager(), this->solver->get_model(), *this->expressionAdapter)); #else STORM_LOG_THROW(false, storm::exceptions::NotSupportedException, "StoRM is compiled without Z3 support."); #endif } #ifdef STORM_HAVE_Z3 - storm::expressions::SimpleValuation Z3SmtSolver::convertZ3ModelToValuation(z3::model const& model) { - storm::expressions::SimpleValuation stormModel; + storm::expressions::Valuation Z3SmtSolver::convertZ3ModelToValuation(z3::model const& model) { + storm::expressions::Valuation stormModel(this->getManager()); for (unsigned i = 0; i < model.num_consts(); ++i) { z3::func_decl variableI = model.get_const_decl(i); @@ -206,13 +209,13 @@ namespace storm { switch (variableInterpretation.getReturnType()) { case storm::expressions::ExpressionReturnType::Bool: - stormModel.addBooleanIdentifier(variableI.name().str(), variableInterpretation.evaluateAsBool()); + stormModel.setBooleanValue(this->getManager().getVariable(variableI.name().str()), variableInterpretation.isTrue()); break; case storm::expressions::ExpressionReturnType::Int: - stormModel.addIntegerIdentifier(variableI.name().str(), variableInterpretation.evaluateAsInt()); + stormModel.setIntegerValue(this->getManager().getVariable(variableI.name().str()), variableInterpretation.evaluateAsInt()); break; case storm::expressions::ExpressionReturnType::Double: - stormModel.addDoubleIdentifier(variableI.name().str(), variableInterpretation.evaluateAsDouble()); + stormModel.setRationalValue(this->getManager().getVariable(variableI.name().str()), variableInterpretation.evaluateAsDouble()); break; default: STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Variable interpretation in model is not of type bool, int or double.") @@ -225,22 +228,21 @@ namespace storm { } #endif - std::vector Z3SmtSolver::allSat(std::vector const& important) + std::vector Z3SmtSolver::allSat(std::vector const& important) { #ifdef STORM_HAVE_Z3 - std::vector valuations; - this->allSat(important, [&valuations](storm::expressions::SimpleValuation const& valuation) -> bool { valuations.push_back(valuation); return true; }); + std::vector valuations; + this->allSat(important, [&valuations](storm::expressions::Valuation const& valuation) -> bool { valuations.push_back(valuation); return true; }); return valuations; #else STORM_LOG_THROW(false, storm::exceptions::NotSupportedException, "StoRM is compiled without Z3 support."); #endif } - uint_fast64_t Z3SmtSolver::allSat(std::vector const& important, std::function const& callback) - { + uint_fast64_t Z3SmtSolver::allSat(std::vector const& important, std::function const& callback) { #ifdef STORM_HAVE_Z3 - for (storm::expressions::Expression const& atom : important) { - STORM_LOG_THROW(atom.isVariable() && atom.hasBooleanReturnType(), storm::exceptions::InvalidArgumentException, "The important atoms for AllSat must be boolean variables."); + for (storm::expressions::Variable const& variable : important) { + STORM_LOG_THROW(variable.hasBooleanType(), storm::exceptions::InvalidArgumentException, "The important atoms for AllSat must be boolean variables."); } uint_fast64_t numberOfModels = 0; @@ -255,21 +257,13 @@ namespace storm { z3::model model = this->solver->get_model(); z3::expr modelExpr = this->context->bool_val(true); - storm::expressions::SimpleValuation valuation; + storm::expressions::Valuation valuation(this->getManager()); - for (storm::expressions::Expression const& importantAtom : important) { - z3::expr z3ImportantAtom = this->expressionAdapter->translateExpression(importantAtom); + for (storm::expressions::Variable const& importantAtom : important) { + z3::expr z3ImportantAtom = this->expressionAdapter->translateExpression(importantAtom.getExpression()); z3::expr z3ImportantAtomValuation = model.eval(z3ImportantAtom, true); modelExpr = modelExpr && (z3ImportantAtom == z3ImportantAtomValuation); - if (importantAtom.getReturnType() == storm::expressions::ExpressionReturnType::Bool) { - valuation.addBooleanIdentifier(importantAtom.getIdentifier(), this->expressionAdapter->translateExpression(z3ImportantAtomValuation).evaluateAsBool()); - } else if (importantAtom.getReturnType() == storm::expressions::ExpressionReturnType::Int) { - valuation.addIntegerIdentifier(importantAtom.getIdentifier(), this->expressionAdapter->translateExpression(z3ImportantAtomValuation).evaluateAsInt()); - } else if (importantAtom.getReturnType() == storm::expressions::ExpressionReturnType::Double) { - valuation.addDoubleIdentifier(importantAtom.getIdentifier(), this->expressionAdapter->translateExpression(z3ImportantAtomValuation).evaluateAsDouble()); - } else { - STORM_LOG_THROW(false, storm::exceptions::InvalidTypeException, "Important atom has invalid type."); - } + valuation.setBooleanValue(importantAtom, this->expressionAdapter->translateExpression(z3ImportantAtomValuation).isTrue()); } // Check if we are required to proceed, and if so rule out the current model. @@ -287,12 +281,11 @@ namespace storm { #endif } - uint_fast64_t Z3SmtSolver::allSat(std::vector const& important, std::function const& callback) - { + uint_fast64_t Z3SmtSolver::allSat(std::vector const& important, std::function const& callback) { #ifdef STORM_HAVE_Z3 - for (storm::expressions::Expression const& atom : important) { - STORM_LOG_THROW(atom.isVariable() && atom.hasBooleanReturnType(), storm::exceptions::InvalidArgumentException, "The important atoms for AllSat must be boolean variables."); - } + for (storm::expressions::Variable const& variable : important) { + STORM_LOG_THROW(variable.hasBooleanType(), storm::exceptions::InvalidArgumentException, "The important atoms for AllSat must be boolean variables."); + } uint_fast64_t numberOfModels = 0; bool proceed = true; @@ -306,14 +299,14 @@ namespace storm { z3::model model = this->solver->get_model(); z3::expr modelExpr = this->context->bool_val(true); - storm::expressions::SimpleValuation valuation; + storm::expressions::Valuation valuation(this->getManager()); - for (storm::expressions::Expression const& importantAtom : important) { - z3::expr z3ImportantAtom = this->expressionAdapter->translateExpression(importantAtom); + for (storm::expressions::Variable const& importantAtom : important) { + z3::expr z3ImportantAtom = this->expressionAdapter->translateExpression(importantAtom.getExpression()); z3::expr z3ImportantAtomValuation = model.eval(z3ImportantAtom, true); modelExpr = modelExpr && (z3ImportantAtom == z3ImportantAtomValuation); } - Z3ModelReference modelRef(model, *expressionAdapter); + Z3ModelReference modelRef(this->getManager(), model, *expressionAdapter); // Check if we are required to proceed, and if so rule out the current model. proceed = callback(modelRef); diff --git a/src/solver/Z3SmtSolver.h b/src/solver/Z3SmtSolver.h index b5f222574..725b52e16 100644 --- a/src/solver/Z3SmtSolver.h +++ b/src/solver/Z3SmtSolver.h @@ -17,11 +17,11 @@ namespace storm { class Z3ModelReference : public SmtSolver::ModelReference { public: #ifdef STORM_HAVE_Z3 - Z3ModelReference(z3::model m, storm::adapters::Z3ExpressionAdapter& expressionAdapter); + Z3ModelReference(storm::expressions::ExpressionManager const& manager, z3::model m, storm::adapters::Z3ExpressionAdapter& expressionAdapter); #endif - virtual bool getBooleanValue(std::string const& name) const override; - virtual int_fast64_t getIntegerValue(std::string const& name) const override; - virtual double getDoubleValue(std::string const& name) const override; + virtual bool getBooleanValue(storm::expressions::Variable const& variable) const override; + virtual int_fast64_t getIntegerValue(storm::expressions::Variable const& variable) const override; + virtual double getRationalValue(storm::expressions::Variable const& variable) const override; private: #ifdef STORM_HAVE_Z3 @@ -34,7 +34,7 @@ namespace storm { }; public: - Z3SmtSolver(); + Z3SmtSolver(storm::expressions::ExpressionManager& manager); virtual ~Z3SmtSolver(); virtual void push() override; @@ -53,15 +53,15 @@ namespace storm { virtual CheckResult checkWithAssumptions(std::initializer_list const& assumptions) override; - virtual storm::expressions::SimpleValuation getModelAsValuation() override; + virtual storm::expressions::Valuation getModelAsValuation() override; virtual std::shared_ptr getModel() override; - virtual std::vector allSat(std::vector const& important) override; + virtual std::vector allSat(std::vector const& important) override; - virtual uint_fast64_t allSat(std::vector const& important, std::function const& callback) override; + virtual uint_fast64_t allSat(std::vector const& important, std::function const& callback) override; - virtual uint_fast64_t allSat(std::vector const& important, std::function const& callback) override; + virtual uint_fast64_t allSat(std::vector const& important, std::function const& callback) override; virtual std::vector getUnsatAssumptions() override; @@ -73,7 +73,7 @@ namespace storm { * @param model The Z3 model to convert. * @return The valuation of variables corresponding to the given model. */ - storm::expressions::SimpleValuation convertZ3ModelToValuation(z3::model const& model); + storm::expressions::Valuation convertZ3ModelToValuation(z3::model const& model); // The context used by the solver. std::unique_ptr context; diff --git a/src/storage/expressions/BaseExpression.cpp b/src/storage/expressions/BaseExpression.cpp index 753b4a129..59950fc9a 100644 --- a/src/storage/expressions/BaseExpression.cpp +++ b/src/storage/expressions/BaseExpression.cpp @@ -5,7 +5,7 @@ namespace storm { namespace expressions { - BaseExpression::BaseExpression(ExpressionReturnType returnType) : returnType(returnType) { + BaseExpression::BaseExpression(ExpressionManager const& manager, ExpressionReturnType returnType) : manager(manager), returnType(returnType) { // Intentionally left empty. } @@ -77,6 +77,10 @@ namespace storm { return false; } + ExpressionManager const& BaseExpression::getManager() const { + return manager; + } + std::shared_ptr BaseExpression::getSharedPointer() const { return this->shared_from_this(); } diff --git a/src/storage/expressions/BaseExpression.h b/src/storage/expressions/BaseExpression.h index 86f6f28a4..dd589c345 100644 --- a/src/storage/expressions/BaseExpression.h +++ b/src/storage/expressions/BaseExpression.h @@ -15,7 +15,10 @@ #include "src/utility/OsDetection.h" namespace storm { - namespace expressions { + namespace expressions { + // Forward-declare expression manager. + class ExpressionManager; + /*! * The base class of all expression classes. */ @@ -26,7 +29,7 @@ namespace storm { * * @param returnType The return type of the expression. */ - BaseExpression(ExpressionReturnType returnType); + BaseExpression(ExpressionManager const& manager, ExpressionReturnType returnType); // Create default versions of constructors and assignments. BaseExpression(BaseExpression const&) = default; @@ -149,13 +152,6 @@ namespace storm { */ virtual std::set getVariables() const = 0; - /*! - * Retrieves the mapping of all variables that appear in the expression to their return type. - * - * @return The mapping of all variables that appear in the expression to their return type. - */ - virtual std::map getVariablesAndTypes() const = 0; - /*! * Simplifies the expression according to some simple rules. * @@ -198,6 +194,13 @@ namespace storm { */ std::shared_ptr getSharedPointer() const; + /*! + * Retrieves the manager responsible for this expression. + * + * @return The manager responsible for this expression. + */ + ExpressionManager const& getManager() const; + /*! * Retrieves the return type of the expression. * @@ -206,6 +209,7 @@ namespace storm { ExpressionReturnType getReturnType() const; friend std::ostream& operator<<(std::ostream& stream, BaseExpression const& expression); + protected: /*! * Prints the expression to the given stream. @@ -215,6 +219,9 @@ namespace storm { virtual void printToStream(std::ostream& stream) const = 0; private: + // The manager responsible for this expression. + ExpressionManager const& manager; + // The return type of this expression. ExpressionReturnType returnType; }; diff --git a/src/storage/expressions/BinaryBooleanFunctionExpression.cpp b/src/storage/expressions/BinaryBooleanFunctionExpression.cpp index 953ff9f92..7be4bf766 100644 --- a/src/storage/expressions/BinaryBooleanFunctionExpression.cpp +++ b/src/storage/expressions/BinaryBooleanFunctionExpression.cpp @@ -5,7 +5,7 @@ namespace storm { namespace expressions { - BinaryBooleanFunctionExpression::BinaryBooleanFunctionExpression(ExpressionReturnType returnType, std::shared_ptr const& firstOperand, std::shared_ptr const& secondOperand, OperatorType operatorType) : BinaryExpression(returnType, firstOperand, secondOperand), operatorType(operatorType) { + BinaryBooleanFunctionExpression::BinaryBooleanFunctionExpression(ExpressionManager const& manager, ExpressionReturnType returnType, std::shared_ptr const& firstOperand, std::shared_ptr const& secondOperand, OperatorType operatorType) : BinaryExpression(manager, returnType, firstOperand, secondOperand), operatorType(operatorType) { // Intentionally left empty. } diff --git a/src/storage/expressions/BinaryBooleanFunctionExpression.h b/src/storage/expressions/BinaryBooleanFunctionExpression.h index eb32b7914..9c20d6159 100644 --- a/src/storage/expressions/BinaryBooleanFunctionExpression.h +++ b/src/storage/expressions/BinaryBooleanFunctionExpression.h @@ -16,12 +16,13 @@ namespace storm { /*! * Creates a binary boolean function expression with the given return type, operands and operator. * + * @param manager The manager responsible for this expression. * @param returnType The return type of the expression. * @param firstOperand The first operand of the expression. * @param secondOperand The second operand of the expression. * @param functionType The operator of the expression. */ - BinaryBooleanFunctionExpression(ExpressionReturnType returnType, std::shared_ptr const& firstOperand, std::shared_ptr const& secondOperand, OperatorType operatorType); + BinaryBooleanFunctionExpression(ExpressionManager const& manager, ExpressionReturnType returnType, std::shared_ptr const& firstOperand, std::shared_ptr const& secondOperand, OperatorType operatorType); // Instantiate constructors and assignments with their default implementations. BinaryBooleanFunctionExpression(BinaryBooleanFunctionExpression const& other) = default; diff --git a/src/storage/expressions/BinaryExpression.cpp b/src/storage/expressions/BinaryExpression.cpp index 78fbdaf73..7df83d390 100644 --- a/src/storage/expressions/BinaryExpression.cpp +++ b/src/storage/expressions/BinaryExpression.cpp @@ -5,7 +5,7 @@ namespace storm { namespace expressions { - BinaryExpression::BinaryExpression(ExpressionReturnType returnType, std::shared_ptr const& firstOperand, std::shared_ptr const& secondOperand) : BaseExpression(returnType), firstOperand(firstOperand), secondOperand(secondOperand) { + BinaryExpression::BinaryExpression(ExpressionManager const& manager, ExpressionReturnType returnType, std::shared_ptr const& firstOperand, std::shared_ptr const& secondOperand) : BaseExpression(manager, returnType), firstOperand(firstOperand), secondOperand(secondOperand) { // Intentionally left empty. } @@ -23,13 +23,6 @@ namespace storm { firstVariableSet.insert(secondVariableSet.begin(), secondVariableSet.end()); return firstVariableSet; } - - std::map BinaryExpression::getVariablesAndTypes() const { - std::map firstVariableSet = this->getFirstOperand()->getVariablesAndTypes(); - std::map secondVariableSet = this->getSecondOperand()->getVariablesAndTypes(); - firstVariableSet.insert(secondVariableSet.begin(), secondVariableSet.end()); - return firstVariableSet; - } std::shared_ptr const& BinaryExpression::getFirstOperand() const { return this->firstOperand; diff --git a/src/storage/expressions/BinaryExpression.h b/src/storage/expressions/BinaryExpression.h index 3f9be4d06..0c29d7614 100644 --- a/src/storage/expressions/BinaryExpression.h +++ b/src/storage/expressions/BinaryExpression.h @@ -14,11 +14,12 @@ namespace storm { /*! * Constructs a binary expression with the given return type and operands. * + * @param manager The manager responsible for this expression. * @param returnType The return type of the expression. * @param firstOperand The first operand of the expression. * @param secondOperand The second operand of the expression. */ - BinaryExpression(ExpressionReturnType returnType, std::shared_ptr const& firstOperand, std::shared_ptr const& secondOperand); + BinaryExpression(ExpressionManager const& manager, ExpressionReturnType returnType, std::shared_ptr const& firstOperand, std::shared_ptr const& secondOperand); // Instantiate constructors and assignments with their default implementations. BinaryExpression(BinaryExpression const& other) = default; @@ -35,7 +36,6 @@ namespace storm { virtual uint_fast64_t getArity() const override; virtual std::shared_ptr getOperand(uint_fast64_t operandIndex) const override; virtual std::set getVariables() const override; - virtual std::map getVariablesAndTypes() const override; /*! * Retrieves the first operand of the expression. diff --git a/src/storage/expressions/BinaryNumericalFunctionExpression.cpp b/src/storage/expressions/BinaryNumericalFunctionExpression.cpp index b6c85ebcb..4df59c0a5 100644 --- a/src/storage/expressions/BinaryNumericalFunctionExpression.cpp +++ b/src/storage/expressions/BinaryNumericalFunctionExpression.cpp @@ -7,7 +7,7 @@ namespace storm { namespace expressions { - BinaryNumericalFunctionExpression::BinaryNumericalFunctionExpression(ExpressionReturnType returnType, std::shared_ptr const& firstOperand, std::shared_ptr const& secondOperand, OperatorType operatorType) : BinaryExpression(returnType, firstOperand, secondOperand), operatorType(operatorType) { + BinaryNumericalFunctionExpression::BinaryNumericalFunctionExpression(ExpressionManager const& manager, ExpressionReturnType returnType, std::shared_ptr const& firstOperand, std::shared_ptr const& secondOperand, OperatorType operatorType) : BinaryExpression(manager, returnType, firstOperand, secondOperand), operatorType(operatorType) { // Intentionally left empty. } diff --git a/src/storage/expressions/BinaryNumericalFunctionExpression.h b/src/storage/expressions/BinaryNumericalFunctionExpression.h index 8e129a21d..28c03787c 100644 --- a/src/storage/expressions/BinaryNumericalFunctionExpression.h +++ b/src/storage/expressions/BinaryNumericalFunctionExpression.h @@ -16,12 +16,13 @@ namespace storm { /*! * Constructs a binary numerical function expression with the given return type, operands and operator. * + * @param manager The manager responsible for this expression. * @param returnType The return type of the expression. * @param firstOperand The first operand of the expression. * @param secondOperand The second operand of the expression. * @param functionType The operator of the expression. */ - BinaryNumericalFunctionExpression(ExpressionReturnType returnType, std::shared_ptr const& firstOperand, std::shared_ptr const& secondOperand, OperatorType operatorType); + BinaryNumericalFunctionExpression(ExpressionManager const& manager, ExpressionReturnType returnType, std::shared_ptr const& firstOperand, std::shared_ptr const& secondOperand, OperatorType operatorType); // Instantiate constructors and assignments with their default implementations. BinaryNumericalFunctionExpression(BinaryNumericalFunctionExpression const& other) = default; diff --git a/src/storage/expressions/BinaryRelationExpression.cpp b/src/storage/expressions/BinaryRelationExpression.cpp index bb0545e8e..3524ab08b 100644 --- a/src/storage/expressions/BinaryRelationExpression.cpp +++ b/src/storage/expressions/BinaryRelationExpression.cpp @@ -5,7 +5,7 @@ namespace storm { namespace expressions { - BinaryRelationExpression::BinaryRelationExpression(ExpressionReturnType returnType, std::shared_ptr const& firstOperand, std::shared_ptr const& secondOperand, RelationType relationType) : BinaryExpression(returnType, firstOperand, secondOperand), relationType(relationType) { + BinaryRelationExpression::BinaryRelationExpression(ExpressionManager const& manager, ExpressionReturnType returnType, std::shared_ptr const& firstOperand, std::shared_ptr const& secondOperand, RelationType relationType) : BinaryExpression(manager, returnType, firstOperand, secondOperand), relationType(relationType) { // Intentionally left empty. } diff --git a/src/storage/expressions/BinaryRelationExpression.h b/src/storage/expressions/BinaryRelationExpression.h index 898cd650f..16c11a9e6 100644 --- a/src/storage/expressions/BinaryRelationExpression.h +++ b/src/storage/expressions/BinaryRelationExpression.h @@ -16,12 +16,13 @@ namespace storm { /*! * Creates a binary relation expression with the given return type, operands and relation type. * + * @param manager The manager responsible for this expression. * @param returnType The return type of the expression. * @param firstOperand The first operand of the expression. * @param secondOperand The second operand of the expression. * @param relationType The operator of the expression. */ - BinaryRelationExpression(ExpressionReturnType returnType, std::shared_ptr const& firstOperand, std::shared_ptr const& secondOperand, RelationType relationType); + BinaryRelationExpression(ExpressionManager const& manager, ExpressionReturnType returnType, std::shared_ptr const& firstOperand, std::shared_ptr const& secondOperand, RelationType relationType); // Instantiate constructors and assignments with their default implementations. BinaryRelationExpression(BinaryRelationExpression const& other) = default; diff --git a/src/storage/expressions/BooleanLiteralExpression.cpp b/src/storage/expressions/BooleanLiteralExpression.cpp index b1970b87c..f9c72bffe 100644 --- a/src/storage/expressions/BooleanLiteralExpression.cpp +++ b/src/storage/expressions/BooleanLiteralExpression.cpp @@ -2,7 +2,7 @@ namespace storm { namespace expressions { - BooleanLiteralExpression::BooleanLiteralExpression(bool value) : BaseExpression(ExpressionReturnType::Bool), value(value) { + BooleanLiteralExpression::BooleanLiteralExpression(ExpressionManager const& manager, bool value) : BaseExpression(manager, ExpressionReturnType::Bool), value(value) { // Intentionally left empty. } @@ -25,10 +25,6 @@ namespace storm { std::set BooleanLiteralExpression::getVariables() const { return std::set(); } - - std::map BooleanLiteralExpression::getVariablesAndTypes() const { - return std::map(); - } std::shared_ptr BooleanLiteralExpression::simplify() const { return this->shared_from_this(); diff --git a/src/storage/expressions/BooleanLiteralExpression.h b/src/storage/expressions/BooleanLiteralExpression.h index 61a5056c7..f382fed62 100644 --- a/src/storage/expressions/BooleanLiteralExpression.h +++ b/src/storage/expressions/BooleanLiteralExpression.h @@ -11,9 +11,10 @@ namespace storm { /*! * Creates a boolean literal expression with the given value. * + * @param manager The manager responsible for this expression. * @param value The value of the boolean literal. */ - BooleanLiteralExpression(bool value); + BooleanLiteralExpression(ExpressionManager const& manager, bool value); // Instantiate constructors and assignments with their default implementations. BooleanLiteralExpression(BooleanLiteralExpression const& other) = default; @@ -30,7 +31,6 @@ namespace storm { virtual bool isTrue() const override; virtual bool isFalse() const override; virtual std::set getVariables() const override; - virtual std::map getVariablesAndTypes() const override; virtual std::shared_ptr simplify() const override; virtual boost::any accept(ExpressionVisitor& visitor) const override; diff --git a/src/storage/expressions/DoubleLiteralExpression.cpp b/src/storage/expressions/DoubleLiteralExpression.cpp index 772053b67..5d2af19cf 100644 --- a/src/storage/expressions/DoubleLiteralExpression.cpp +++ b/src/storage/expressions/DoubleLiteralExpression.cpp @@ -2,7 +2,7 @@ namespace storm { namespace expressions { - DoubleLiteralExpression::DoubleLiteralExpression(double value) : BaseExpression(ExpressionReturnType::Double), value(value) { + DoubleLiteralExpression::DoubleLiteralExpression(ExpressionManager const& manager, double value) : BaseExpression(manager, ExpressionReturnType::Double), value(value) { // Intentionally left empty. } @@ -17,10 +17,6 @@ namespace storm { std::set DoubleLiteralExpression::getVariables() const { return std::set(); } - - std::map DoubleLiteralExpression::getVariablesAndTypes() const { - return std::map(); - } std::shared_ptr DoubleLiteralExpression::simplify() const { return this->shared_from_this(); diff --git a/src/storage/expressions/DoubleLiteralExpression.h b/src/storage/expressions/DoubleLiteralExpression.h index 531662b7a..9630dab9f 100644 --- a/src/storage/expressions/DoubleLiteralExpression.h +++ b/src/storage/expressions/DoubleLiteralExpression.h @@ -11,9 +11,10 @@ namespace storm { /*! * Creates an double literal expression with the given value. * + * @param manager The manager responsible for this expression. * @param value The value of the double literal. */ - DoubleLiteralExpression(double value); + DoubleLiteralExpression(ExpressionManager const& manager, double value); // Instantiate constructors and assignments with their default implementations. DoubleLiteralExpression(DoubleLiteralExpression const& other) = default; @@ -28,7 +29,6 @@ namespace storm { virtual double evaluateAsDouble(Valuation const* valuation = nullptr) const override; virtual bool isLiteral() const override; virtual std::set getVariables() const override; - virtual std::map getVariablesAndTypes() const override; virtual std::shared_ptr simplify() const override; virtual boost::any accept(ExpressionVisitor& visitor) const override; diff --git a/src/storage/expressions/Expression.cpp b/src/storage/expressions/Expression.cpp index 8e7d41902..f0768c0cd 100644 --- a/src/storage/expressions/Expression.cpp +++ b/src/storage/expressions/Expression.cpp @@ -15,6 +15,10 @@ namespace storm { // Intentionally left empty. } + Expression::Expression(Variable const& variable) : expressionPtr(new VariableExpression(variable)) { + // Intentionally left empty. + } + Expression Expression::substitute(std::map const& identifierToExpressionMap) const { return SubstitutionVisitor>(identifierToExpressionMap).substitute(*this); } @@ -166,124 +170,147 @@ namespace storm { } Expression Expression::operator+(Expression const& other) const { + assertSameManager(this->getBaseExpression(), other.getBaseExpression()); STORM_LOG_THROW(this->hasNumericalReturnType() && other.hasNumericalReturnType(), storm::exceptions::InvalidTypeException, "Operator '+' requires numerical operands."); - return Expression(std::shared_ptr(new BinaryNumericalFunctionExpression(this->getReturnType() == ExpressionReturnType::Int && other.getReturnType() == ExpressionReturnType::Int ? ExpressionReturnType::Int : ExpressionReturnType::Double, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryNumericalFunctionExpression::OperatorType::Plus))); + return Expression(std::shared_ptr(new BinaryNumericalFunctionExpression(this->getBaseExpression().getManager(), this->getReturnType() == ExpressionReturnType::Int && other.getReturnType() == ExpressionReturnType::Int ? ExpressionReturnType::Int : ExpressionReturnType::Double, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryNumericalFunctionExpression::OperatorType::Plus))); } Expression Expression::operator-(Expression const& other) const { + assertSameManager(this->getBaseExpression(), other.getBaseExpression()); STORM_LOG_THROW(this->hasNumericalReturnType() && other.hasNumericalReturnType(), storm::exceptions::InvalidTypeException, "Operator '-' requires numerical operands."); - return Expression(std::shared_ptr(new BinaryNumericalFunctionExpression(this->getReturnType() == ExpressionReturnType::Int && other.getReturnType() == ExpressionReturnType::Int ? ExpressionReturnType::Int : ExpressionReturnType::Double, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryNumericalFunctionExpression::OperatorType::Minus))); + return Expression(std::shared_ptr(new BinaryNumericalFunctionExpression(this->getBaseExpression().getManager(), this->getReturnType() == ExpressionReturnType::Int && other.getReturnType() == ExpressionReturnType::Int ? ExpressionReturnType::Int : ExpressionReturnType::Double, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryNumericalFunctionExpression::OperatorType::Minus))); } Expression Expression::operator-() const { STORM_LOG_THROW(this->hasNumericalReturnType(), storm::exceptions::InvalidTypeException, "Operator '-' requires numerical operand."); - return Expression(std::shared_ptr(new UnaryNumericalFunctionExpression(this->getReturnType(), this->getBaseExpressionPointer(), UnaryNumericalFunctionExpression::OperatorType::Minus))); + return Expression(std::shared_ptr(new UnaryNumericalFunctionExpression(this->getBaseExpression().getManager(), this->getReturnType(), this->getBaseExpressionPointer(), UnaryNumericalFunctionExpression::OperatorType::Minus))); } Expression Expression::operator*(Expression const& other) const { + assertSameManager(this->getBaseExpression(), other.getBaseExpression()); STORM_LOG_THROW(this->hasNumericalReturnType() && other.hasNumericalReturnType(), storm::exceptions::InvalidTypeException, "Operator '*' requires numerical operands."); - return Expression(std::shared_ptr(new BinaryNumericalFunctionExpression(this->getReturnType() == ExpressionReturnType::Int && other.getReturnType() == ExpressionReturnType::Int ? ExpressionReturnType::Int : ExpressionReturnType::Double, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryNumericalFunctionExpression::OperatorType::Times))); + return Expression(std::shared_ptr(new BinaryNumericalFunctionExpression(this->getBaseExpression().getManager(), this->getReturnType() == ExpressionReturnType::Int && other.getReturnType() == ExpressionReturnType::Int ? ExpressionReturnType::Int : ExpressionReturnType::Double, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryNumericalFunctionExpression::OperatorType::Times))); } Expression Expression::operator/(Expression const& other) const { + assertSameManager(this->getBaseExpression(), other.getBaseExpression()); STORM_LOG_THROW(this->hasNumericalReturnType() && other.hasNumericalReturnType(), storm::exceptions::InvalidTypeException, "Operator '/' requires numerical operands."); - return Expression(std::shared_ptr(new BinaryNumericalFunctionExpression(this->getReturnType() == ExpressionReturnType::Int && other.getReturnType() == ExpressionReturnType::Int ? ExpressionReturnType::Int : ExpressionReturnType::Double, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryNumericalFunctionExpression::OperatorType::Divide))); + return Expression(std::shared_ptr(new BinaryNumericalFunctionExpression(this->getBaseExpression().getManager(), this->getReturnType() == ExpressionReturnType::Int && other.getReturnType() == ExpressionReturnType::Int ? ExpressionReturnType::Int : ExpressionReturnType::Double, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryNumericalFunctionExpression::OperatorType::Divide))); } Expression Expression::operator^(Expression const& other) const { + assertSameManager(this->getBaseExpression(), other.getBaseExpression()); STORM_LOG_THROW(this->hasNumericalReturnType() && other.hasNumericalReturnType(), storm::exceptions::InvalidTypeException, "Operator '^' requires numerical operands."); - return Expression(std::shared_ptr(new BinaryNumericalFunctionExpression(this->getReturnType() == ExpressionReturnType::Int && other.getReturnType() == ExpressionReturnType::Int ? ExpressionReturnType::Int : ExpressionReturnType::Double, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryNumericalFunctionExpression::OperatorType::Power))); + return Expression(std::shared_ptr(new BinaryNumericalFunctionExpression(this->getBaseExpression().getManager(), this->getReturnType() == ExpressionReturnType::Int && other.getReturnType() == ExpressionReturnType::Int ? ExpressionReturnType::Int : ExpressionReturnType::Double, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryNumericalFunctionExpression::OperatorType::Power))); } Expression Expression::operator&&(Expression const& other) const { + assertSameManager(this->getBaseExpression(), other.getBaseExpression()); STORM_LOG_THROW(this->hasBooleanReturnType() && other.hasBooleanReturnType(), storm::exceptions::InvalidTypeException, "Operator '&&' requires boolean operands."); - return Expression(std::shared_ptr(new BinaryBooleanFunctionExpression(ExpressionReturnType::Bool, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryBooleanFunctionExpression::OperatorType::And))); + return Expression(std::shared_ptr(new BinaryBooleanFunctionExpression(this->getBaseExpression().getManager(), ExpressionReturnType::Bool, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryBooleanFunctionExpression::OperatorType::And))); } Expression Expression::operator||(Expression const& other) const { + assertSameManager(this->getBaseExpression(), other.getBaseExpression()); STORM_LOG_THROW(this->hasBooleanReturnType() && other.hasBooleanReturnType(), storm::exceptions::InvalidTypeException, "Operator '||' requires numerical operands."); - return Expression(std::shared_ptr(new BinaryBooleanFunctionExpression(ExpressionReturnType::Bool, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryBooleanFunctionExpression::OperatorType::Or))); + return Expression(std::shared_ptr(new BinaryBooleanFunctionExpression(this->getBaseExpression().getManager(), ExpressionReturnType::Bool, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryBooleanFunctionExpression::OperatorType::Or))); } Expression Expression::operator!() const { STORM_LOG_THROW(this->hasBooleanReturnType(), storm::exceptions::InvalidTypeException, "Operator '!' requires boolean operand."); - return Expression(std::shared_ptr(new UnaryBooleanFunctionExpression(ExpressionReturnType::Bool, this->getBaseExpressionPointer(), UnaryBooleanFunctionExpression::OperatorType::Not))); + return Expression(std::shared_ptr(new UnaryBooleanFunctionExpression(this->getBaseExpression().getManager(), ExpressionReturnType::Bool, this->getBaseExpressionPointer(), UnaryBooleanFunctionExpression::OperatorType::Not))); } Expression Expression::operator==(Expression const& other) const { + assertSameManager(this->getBaseExpression(), other.getBaseExpression()); STORM_LOG_THROW(this->hasNumericalReturnType() && other.hasNumericalReturnType(), storm::exceptions::InvalidTypeException, "Operator '==' requires numerical operands."); - return Expression(std::shared_ptr(new BinaryRelationExpression(ExpressionReturnType::Bool, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryRelationExpression::RelationType::Equal))); + return Expression(std::shared_ptr(new BinaryRelationExpression(this->getBaseExpression().getManager(), ExpressionReturnType::Bool, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryRelationExpression::RelationType::Equal))); } Expression Expression::operator!=(Expression const& other) const { + assertSameManager(this->getBaseExpression(), other.getBaseExpression()); STORM_LOG_THROW((this->hasNumericalReturnType() && other.hasNumericalReturnType()) || (this->hasBooleanReturnType() && other.hasBooleanReturnType()), storm::exceptions::InvalidTypeException, "Operator '!=' requires operands of equal type."); if (this->hasNumericalReturnType() && other.hasNumericalReturnType()) { - return Expression(std::shared_ptr(new BinaryRelationExpression(ExpressionReturnType::Bool, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryRelationExpression::RelationType::NotEqual))); + return Expression(std::shared_ptr(new BinaryRelationExpression(this->getBaseExpression().getManager(), ExpressionReturnType::Bool, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryRelationExpression::RelationType::NotEqual))); } else { - return Expression(std::shared_ptr(new BinaryBooleanFunctionExpression(ExpressionReturnType::Bool, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryBooleanFunctionExpression::OperatorType::Xor))); + return Expression(std::shared_ptr(new BinaryBooleanFunctionExpression(this->getBaseExpression().getManager(), ExpressionReturnType::Bool, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryBooleanFunctionExpression::OperatorType::Xor))); } } Expression Expression::operator>(Expression const& other) const { + assertSameManager(this->getBaseExpression(), other.getBaseExpression()); STORM_LOG_THROW(this->hasNumericalReturnType() && other.hasNumericalReturnType(), storm::exceptions::InvalidTypeException, "Operator '>' requires numerical operands."); - return Expression(std::shared_ptr(new BinaryRelationExpression(ExpressionReturnType::Bool, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryRelationExpression::RelationType::Greater))); + return Expression(std::shared_ptr(new BinaryRelationExpression(this->getBaseExpression().getManager(), ExpressionReturnType::Bool, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryRelationExpression::RelationType::Greater))); } Expression Expression::operator>=(Expression const& other) const { + assertSameManager(this->getBaseExpression(), other.getBaseExpression()); STORM_LOG_THROW(this->hasNumericalReturnType() && other.hasNumericalReturnType(), storm::exceptions::InvalidTypeException, "Operator '>=' requires numerical operands."); - return Expression(std::shared_ptr(new BinaryRelationExpression(ExpressionReturnType::Bool, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryRelationExpression::RelationType::GreaterOrEqual))); + return Expression(std::shared_ptr(new BinaryRelationExpression(this->getBaseExpression().getManager(), ExpressionReturnType::Bool, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryRelationExpression::RelationType::GreaterOrEqual))); } Expression Expression::operator<(Expression const& other) const { + assertSameManager(this->getBaseExpression(), other.getBaseExpression()); STORM_LOG_THROW(this->hasNumericalReturnType() && other.hasNumericalReturnType(), storm::exceptions::InvalidTypeException, "Operator '<' requires numerical operands."); - return Expression(std::shared_ptr(new BinaryRelationExpression(ExpressionReturnType::Bool, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryRelationExpression::RelationType::Less))); + return Expression(std::shared_ptr(new BinaryRelationExpression(this->getBaseExpression().getManager(), ExpressionReturnType::Bool, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryRelationExpression::RelationType::Less))); } Expression Expression::operator<=(Expression const& other) const { + assertSameManager(this->getBaseExpression(), other.getBaseExpression()); STORM_LOG_THROW(this->hasNumericalReturnType() && other.hasNumericalReturnType(), storm::exceptions::InvalidTypeException, "Operator '<=' requires numerical operands."); - return Expression(std::shared_ptr(new BinaryRelationExpression(ExpressionReturnType::Bool, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryRelationExpression::RelationType::LessOrEqual))); + return Expression(std::shared_ptr(new BinaryRelationExpression(this->getBaseExpression().getManager(), ExpressionReturnType::Bool, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryRelationExpression::RelationType::LessOrEqual))); } Expression Expression::minimum(Expression const& lhs, Expression const& rhs) { + assertSameManager(lhs.getBaseExpression(), rhs.getBaseExpression()); STORM_LOG_THROW(lhs.hasNumericalReturnType() && rhs.hasNumericalReturnType(), storm::exceptions::InvalidTypeException, "Operator 'min' requires numerical operands."); - return Expression(std::shared_ptr(new BinaryNumericalFunctionExpression(lhs.getReturnType() == ExpressionReturnType::Int && rhs.getReturnType() == ExpressionReturnType::Int ? ExpressionReturnType::Int : ExpressionReturnType::Double, lhs.getBaseExpressionPointer(), rhs.getBaseExpressionPointer(), BinaryNumericalFunctionExpression::OperatorType::Min))); + return Expression(std::shared_ptr(new BinaryNumericalFunctionExpression(lhs.getBaseExpression().getManager(), lhs.getReturnType() == ExpressionReturnType::Int && rhs.getReturnType() == ExpressionReturnType::Int ? ExpressionReturnType::Int : ExpressionReturnType::Double, lhs.getBaseExpressionPointer(), rhs.getBaseExpressionPointer(), BinaryNumericalFunctionExpression::OperatorType::Min))); } Expression Expression::maximum(Expression const& lhs, Expression const& rhs) { + assertSameManager(lhs.getBaseExpression(), rhs.getBaseExpression()); STORM_LOG_THROW(lhs.hasNumericalReturnType() && rhs.hasNumericalReturnType(), storm::exceptions::InvalidTypeException, "Operator 'max' requires numerical operands."); - return Expression(std::shared_ptr(new BinaryNumericalFunctionExpression(lhs.getReturnType() == ExpressionReturnType::Int && rhs.getReturnType() == ExpressionReturnType::Int ? ExpressionReturnType::Int : ExpressionReturnType::Double, lhs.getBaseExpressionPointer(), rhs.getBaseExpressionPointer(), BinaryNumericalFunctionExpression::OperatorType::Max))); + return Expression(std::shared_ptr(new BinaryNumericalFunctionExpression(lhs.getBaseExpression().getManager(), lhs.getReturnType() == ExpressionReturnType::Int && rhs.getReturnType() == ExpressionReturnType::Int ? ExpressionReturnType::Int : ExpressionReturnType::Double, lhs.getBaseExpressionPointer(), rhs.getBaseExpressionPointer(), BinaryNumericalFunctionExpression::OperatorType::Max))); } Expression Expression::ite(Expression const& thenExpression, Expression const& elseExpression) { + assertSameManager(this->getBaseExpression(), thenExpression.getBaseExpression()); + assertSameManager(thenExpression.getBaseExpression(), elseExpression.getBaseExpression()); STORM_LOG_THROW(this->hasBooleanReturnType(), storm::exceptions::InvalidTypeException, "Condition of if-then-else operator must be of boolean type."); STORM_LOG_THROW(thenExpression.hasBooleanReturnType() && elseExpression.hasBooleanReturnType() || thenExpression.hasNumericalReturnType() && elseExpression.hasNumericalReturnType(), storm::exceptions::InvalidTypeException, "'then' and 'else' expression of if-then-else operator must have equal return type."); - return Expression(std::shared_ptr(new IfThenElseExpression(thenExpression.hasBooleanReturnType() && elseExpression.hasBooleanReturnType() ? ExpressionReturnType::Bool : (thenExpression.getReturnType() == ExpressionReturnType::Int && elseExpression.getReturnType() == ExpressionReturnType::Int ? ExpressionReturnType::Int : ExpressionReturnType::Double), this->getBaseExpressionPointer(), thenExpression.getBaseExpressionPointer(), elseExpression.getBaseExpressionPointer()))); + return Expression(std::shared_ptr(new IfThenElseExpression(this->getBaseExpression().getManager(), thenExpression.hasBooleanReturnType() && elseExpression.hasBooleanReturnType() ? ExpressionReturnType::Bool : (thenExpression.getReturnType() == ExpressionReturnType::Int && elseExpression.getReturnType() == ExpressionReturnType::Int ? ExpressionReturnType::Int : ExpressionReturnType::Double), this->getBaseExpressionPointer(), thenExpression.getBaseExpressionPointer(), elseExpression.getBaseExpressionPointer()))); } Expression Expression::implies(Expression const& other) const { + assertSameManager(this->getBaseExpression(), other.getBaseExpression()); STORM_LOG_THROW(this->hasBooleanReturnType() && other.hasBooleanReturnType(), storm::exceptions::InvalidTypeException, "Operator '&&' requires boolean operands."); - return Expression(std::shared_ptr(new BinaryBooleanFunctionExpression(ExpressionReturnType::Bool, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryBooleanFunctionExpression::OperatorType::Implies))); + return Expression(std::shared_ptr(new BinaryBooleanFunctionExpression(this->getBaseExpression().getManager(), ExpressionReturnType::Bool, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryBooleanFunctionExpression::OperatorType::Implies))); } Expression Expression::iff(Expression const& other) const { + assertSameManager(this->getBaseExpression(), other.getBaseExpression()); STORM_LOG_THROW(this->hasBooleanReturnType() && other.hasBooleanReturnType(), storm::exceptions::InvalidTypeException, "Operator '&&' requires boolean operands."); - return Expression(std::shared_ptr(new BinaryBooleanFunctionExpression(ExpressionReturnType::Bool, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryBooleanFunctionExpression::OperatorType::Iff))); + return Expression(std::shared_ptr(new BinaryBooleanFunctionExpression(this->getBaseExpression().getManager(), ExpressionReturnType::Bool, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryBooleanFunctionExpression::OperatorType::Iff))); } Expression Expression::floor() const { STORM_LOG_THROW(this->hasNumericalReturnType(), storm::exceptions::InvalidTypeException, "Operator 'floor' requires numerical operand."); - return Expression(std::shared_ptr(new UnaryNumericalFunctionExpression(ExpressionReturnType::Int, this->getBaseExpressionPointer(), UnaryNumericalFunctionExpression::OperatorType::Floor))); + return Expression(std::shared_ptr(new UnaryNumericalFunctionExpression(this->getBaseExpression().getManager(), ExpressionReturnType::Int, this->getBaseExpressionPointer(), UnaryNumericalFunctionExpression::OperatorType::Floor))); } Expression Expression::ceil() const { STORM_LOG_THROW(this->hasNumericalReturnType(), storm::exceptions::InvalidTypeException, "Operator 'ceil' requires numerical operand."); - return Expression(std::shared_ptr(new UnaryNumericalFunctionExpression(ExpressionReturnType::Int, this->getBaseExpressionPointer(), UnaryNumericalFunctionExpression::OperatorType::Ceil))); + return Expression(std::shared_ptr(new UnaryNumericalFunctionExpression(this->getBaseExpression().getManager(), ExpressionReturnType::Int, this->getBaseExpressionPointer(), UnaryNumericalFunctionExpression::OperatorType::Ceil))); } boost::any Expression::accept(ExpressionVisitor& visitor) const { return this->getBaseExpression().accept(visitor); } + void Expression::assertSameManager(BaseExpression const& a, BaseExpression const& b) { + STORM_LOG_THROW(a.getManager() == b.getManager(), storm::exceptions::InvalidArgumentException, "Expressions are managed by different manager."); + } + std::ostream& operator<<(std::ostream& stream, Expression const& expression) { stream << expression.getBaseExpression(); return stream; diff --git a/src/storage/expressions/Expression.h b/src/storage/expressions/Expression.h index e456f295b..395a6982e 100644 --- a/src/storage/expressions/Expression.h +++ b/src/storage/expressions/Expression.h @@ -11,16 +11,16 @@ namespace storm { namespace expressions { + // Foward-declare expression manager class. + class ExpressionManager; + class Variable; + class Expression { public: - Expression() = default; + friend class ExpressionManager; + friend class Variable; - /*! - * Creates an expression with the given underlying base expression. - * - * @param expressionPtr A pointer to the underlying base expression. - */ - Expression(std::shared_ptr const& expressionPtr); + Expression() = default; // Instantiate constructors and assignments with their default implementations. Expression(Expression const& other) = default; @@ -30,17 +30,6 @@ namespace storm { Expression& operator=(Expression&&) = default; #endif - // Static factory methods to create atomic expression parts. - static Expression createBooleanLiteral(bool value); - static Expression createTrue(); - static Expression createFalse(); - static Expression createIntegerLiteral(int_fast64_t value); - static Expression createDoubleLiteral(double value); - static Expression createBooleanVariable(std::string const& variableName); - static Expression createIntegerVariable(std::string const& variableName); - static Expression createDoubleVariable(std::string const& variableName); - static Expression createUndefinedVariable(std::string const& variableName); - // Provide operator overloads to conveniently construct new expressions from other expressions. Expression operator+(Expression const& other) const; Expression operator-(Expression const& other) const; @@ -236,19 +225,6 @@ namespace storm { * @return The set of all variables that appear in the expression. */ std::set getVariables() const; - - /*! - * Retrieves the mapping of all variables that appear in the expression to their return type. - * - * @param validate If this parameter is true, check() is called with the returnvalue before - * it is returned. - * - * @throws storm::exceptions::InvalidTypeException If a variables with the same name but different - * types occur somewhere withing the expression. - * - * @return The mapping of all variables that appear in the expression to their return type. - */ - std::map getVariablesAndTypes(bool validate = true) const; /*! * Retrieves the base expression underlying this expression object. Note that prior to calling this, the @@ -303,8 +279,34 @@ namespace storm { friend std::ostream& operator<<(std::ostream& stream, Expression const& expression); private: + /*! + * Creates an expression with the given underlying base expression. + * + * @param expressionPtr A pointer to the underlying base expression. + */ + Expression(std::shared_ptr const& expressionPtr); + + /*! + * Creates an expression representing the given variable. + * + * @param variable The variable to represent. + */ + Expression(Variable const& variable); + + /*! + * Checks whether the two expressions share the same expression manager. + * + * @param a The first expression. + * @param b The second expression. + * @return True iff the expressions share the same manager. + */ + static void assertSameManager(BaseExpression const& a, BaseExpression const& b); + // A pointer to the underlying base expression. std::shared_ptr expressionPtr; + + // A pointer to the responsible manager. + std::shared_ptr manager; }; } } diff --git a/src/storage/expressions/ExpressionManager.cpp b/src/storage/expressions/ExpressionManager.cpp new file mode 100644 index 000000000..9e518871e --- /dev/null +++ b/src/storage/expressions/ExpressionManager.cpp @@ -0,0 +1,247 @@ +#include "src/storage/expressions/ExpressionManager.h" + +#include "src/storage/expressions/Expressions.h" +#include "src/storage/expressions/Variable.h" +#include "src/utility/macros.h" +#include "src/exceptions/InvalidStateException.h" + +namespace storm { + namespace expressions { + + VariableIterator::VariableIterator(ExpressionManager const& manager, std::unordered_map::const_iterator nameIndexIterator, std::unordered_map::const_iterator nameIndexIteratorEnd, VariableSelection const& selection) : manager(manager), nameIndexIterator(nameIndexIterator), nameIndexIteratorEnd(nameIndexIteratorEnd), selection(selection) { + moveUntilNextSelectedElement(false); + } + + bool VariableIterator::operator==(VariableIterator const& other) { + return this->nameIndexIterator == other.nameIndexIterator; + } + + bool VariableIterator::operator!=(VariableIterator const& other) { + return !(*this == other); + } + + VariableIterator::value_type& VariableIterator::operator*() { + return currentElement; + } + + VariableIterator& VariableIterator::operator++(int) { + moveUntilNextSelectedElement(); + return *this; + } + + VariableIterator& VariableIterator::operator++() { + moveUntilNextSelectedElement(); + return *this; + } + + void VariableIterator::moveUntilNextSelectedElement(bool atLeastOneStep) { + if (atLeastOneStep && nameIndexIterator != nameIndexIteratorEnd) { + ++nameIndexIterator; + } + + // Move the underlying iterator forward until a variable matches the selection. + while (nameIndexIterator != nameIndexIteratorEnd + && (selection == VariableSelection::OnlyRegularVariables && (nameIndexIterator->second & ExpressionManager::auxiliaryMask) != 0) + && (selection == VariableSelection::OnlyAuxiliaryVariables && (nameIndexIterator->second & ExpressionManager::auxiliaryMask) == 0)) { + ++nameIndexIterator; + } + + ExpressionReturnType type = ExpressionReturnType::Undefined; + if ((nameIndexIterator->second & ExpressionManager::booleanMask) != 0) { + type = ExpressionReturnType::Bool; + } else if ((nameIndexIterator->second & ExpressionManager::integerMask) != 0) { + type = ExpressionReturnType::Int; + } else if ((nameIndexIterator->second & ExpressionManager::rationalMask) != 0) { + type = ExpressionReturnType::Double; + } + + if (nameIndexIterator != nameIndexIteratorEnd) { + currentElement = std::make_pair(Variable(manager, nameIndexIterator->second), type); + } + } + + ExpressionManager::ExpressionManager() : nameToIndexMapping(), variableTypeToCountMapping(), auxiliaryVariableTypeToCountMapping() { + variableTypeToCountMapping[static_cast(storm::expressions::ExpressionReturnType::Bool)] = 0; + variableTypeToCountMapping[static_cast(storm::expressions::ExpressionReturnType::Int)] = 0; + variableTypeToCountMapping[static_cast(storm::expressions::ExpressionReturnType::Double)] = 0; + auxiliaryVariableTypeToCountMapping[static_cast(storm::expressions::ExpressionReturnType::Bool)] = 0; + auxiliaryVariableTypeToCountMapping[static_cast(storm::expressions::ExpressionReturnType::Int)] = 0; + auxiliaryVariableTypeToCountMapping[static_cast(storm::expressions::ExpressionReturnType::Double)] = 0; + } + + Expression ExpressionManager::boolean(bool value) const { + return Expression(std::shared_ptr(new BooleanLiteralExpression(*this, value))); + } + + Expression ExpressionManager::integer(int_fast64_t value) const { + return Expression(std::shared_ptr(new IntegerLiteralExpression(*this, value))); + } + + Expression ExpressionManager::rational(double value) const { + return Expression(std::shared_ptr(new DoubleLiteralExpression(*this, value))); + } + + bool ExpressionManager::operator==(ExpressionManager const& other) const { + return this == &other; + } + + bool ExpressionManager::isValidVariableName(std::string const& name) { + return name.size() < 2 || name.at(0) != '_' || name.at(1) != '_'; + } + + bool ExpressionManager::variableExists(std::string const& name) const { + auto nameIndexPair = nameToIndexMapping.find(name); + return nameIndexPair != nameToIndexMapping.end(); + } + + Variable ExpressionManager::declareVariable(std::string const& name, storm::expressions::ExpressionReturnType const& variableType) { + STORM_LOG_THROW(!variableExists(name), storm::exceptions::InvalidArgumentException, "Variable with name '" << name << "' already exists."); + return declareOrGetVariable(name, variableType); + } + + Variable ExpressionManager::declareAuxiliaryVariable(std::string const& name, storm::expressions::ExpressionReturnType const& variableType) { + STORM_LOG_THROW(!variableExists(name), storm::exceptions::InvalidArgumentException, "Variable with name '" << name << "' already exists."); + return declareOrGetAuxiliaryVariable(name, variableType); + } + + Variable ExpressionManager::declareOrGetVariable(std::string const& name, storm::expressions::ExpressionReturnType const& variableType) { + STORM_LOG_THROW(isValidVariableName(name), storm::exceptions::InvalidArgumentException, "Invalid variable name '" << name << "'."); + uint_fast64_t newIndex = 0; + switch (variableType) { + case ExpressionReturnType::Bool: + newIndex = variableTypeToCountMapping[static_cast(ExpressionReturnType::Bool)]++ | booleanMask; + break; + case ExpressionReturnType::Int: + newIndex = variableTypeToCountMapping[static_cast(ExpressionReturnType::Int)]++ | integerMask; + break; + case ExpressionReturnType::Double: + newIndex = variableTypeToCountMapping[static_cast(ExpressionReturnType::Double)]++ | rationalMask; + break; + case ExpressionReturnType::Undefined: + STORM_LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Illegal variable type."); + } + + nameToIndexMapping[name] = newIndex; + indexToNameMapping[newIndex] = name; + return Variable(*this, newIndex); + } + + Variable ExpressionManager::declareOrGetAuxiliaryVariable(std::string const& name, storm::expressions::ExpressionReturnType const& variableType) { + auto nameIndexPair = nameToIndexMapping.find(name); + if (nameIndexPair != nameToIndexMapping.end()) { + return Variable(*this, nameIndexPair->second); + } else { + STORM_LOG_THROW(isValidVariableName(name), storm::exceptions::InvalidArgumentException, "Invalid variable name '" << name << "'."); + uint_fast64_t newIndex = auxiliaryMask; + switch (variableType) { + case ExpressionReturnType::Bool: + newIndex |= auxiliaryVariableTypeToCountMapping[static_cast(ExpressionReturnType::Bool)]++ | booleanMask; + break; + case ExpressionReturnType::Int: + newIndex |= auxiliaryVariableTypeToCountMapping[static_cast(ExpressionReturnType::Int)]++ | integerMask; + break; + case ExpressionReturnType::Double: + newIndex |= auxiliaryVariableTypeToCountMapping[static_cast(ExpressionReturnType::Double)]++ | rationalMask; + break; + case ExpressionReturnType::Undefined: + STORM_LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Illegal variable type."); + } + + nameToIndexMapping[name] = newIndex; + indexToNameMapping[newIndex] = name; + return Variable(*this, newIndex); + } + } + + Variable ExpressionManager::getVariable(std::string const& name) const { + auto nameIndexPair = nameToIndexMapping.find(name); + STORM_LOG_THROW(nameIndexPair != nameToIndexMapping.end(), storm::exceptions::InvalidArgumentException, "Unknown variable '" << name << "'."); + return Variable(*this, nameIndexPair->second); + } + + Expression ExpressionManager::getVariableExpression(std::string const& name) const { + return Expression(getVariable(name)); + } + + Variable ExpressionManager::declareFreshVariable(storm::expressions::ExpressionReturnType const& variableType) { + std::string newName = "__x" + std::to_string(freshVariableCounter++); + return declareVariable(newName, variableType); + } + + Variable ExpressionManager::declareFreshAuxiliaryVariable(storm::expressions::ExpressionReturnType const& variableType) { + std::string newName = "__x" + std::to_string(freshVariableCounter++); + return declareAuxiliaryVariable(newName, variableType); + } + + uint_fast64_t ExpressionManager::getNumberOfVariables(storm::expressions::ExpressionReturnType const& variableType) const { + return variableTypeToCountMapping[static_cast(variableType)]; + } + + uint_fast64_t ExpressionManager::getNumberOfVariables() const { + return numberOfVariables; + } + + uint_fast64_t ExpressionManager::getNumberOfBooleanVariables() const { + return getNumberOfVariables(storm::expressions::ExpressionReturnType::Bool); + } + + uint_fast64_t ExpressionManager::getNumberOfIntegerVariables() const { + return getNumberOfVariables(storm::expressions::ExpressionReturnType::Int); + } + + uint_fast64_t ExpressionManager::getNumberOfRationalVariables() const { + return getNumberOfVariables(storm::expressions::ExpressionReturnType::Double); + } + + uint_fast64_t ExpressionManager::getNumberOfAuxiliaryVariables(storm::expressions::ExpressionReturnType const& variableType) const { + return auxiliaryVariableTypeToCountMapping[static_cast(variableType)]; + } + + uint_fast64_t ExpressionManager::getNumberOfAuxiliaryVariables() const { + return numberOfAuxiliaryVariables; + } + + uint_fast64_t ExpressionManager::getNumberOfAuxiliaryBooleanVariables() const { + return getNumberOfAuxiliaryVariables(storm::expressions::ExpressionReturnType::Bool); + } + + uint_fast64_t ExpressionManager::getNumberOfAuxiliaryIntegerVariables() const { + return getNumberOfAuxiliaryVariables(storm::expressions::ExpressionReturnType::Int); + } + + uint_fast64_t ExpressionManager::getNumberOfAuxiliaryRationalVariables() const { + return getNumberOfAuxiliaryVariables(storm::expressions::ExpressionReturnType::Double); + } + + std::string const& ExpressionManager::getVariableName(uint_fast64_t index) const { + auto indexTypeNamePair = indexToNameMapping.find(index); + STORM_LOG_THROW(indexTypeNamePair != indexToNameMapping.end(), storm::exceptions::InvalidArgumentException, "Unknown variable index '" << index << "'."); + return indexTypeNamePair->second; + } + + ExpressionReturnType ExpressionManager::getVariableType(uint_fast64_t index) const { + if ((index & booleanMask) != 0) { + return ExpressionReturnType::Bool; + } else if ((index & integerMask) != 0) { + return ExpressionReturnType::Int; + } else if ((index & rationalMask) != 0) { + return ExpressionReturnType::Double; + } else { + return ExpressionReturnType::Undefined; + } + } + + uint_fast64_t ExpressionManager::getOffset(uint_fast64_t index) const { + return index & offsetMask; + } + + ExpressionManager::const_iterator ExpressionManager::begin() const { + return ExpressionManager::const_iterator(*this, this->nameToIndexMapping.end(), this->nameToIndexMapping.begin(), const_iterator::VariableSelection::OnlyRegularVariables); + } + + ExpressionManager::const_iterator ExpressionManager::end() const { + return ExpressionManager::const_iterator(*this, this->nameToIndexMapping.end(), this->nameToIndexMapping.end(), const_iterator::VariableSelection::OnlyRegularVariables); + } + + } // namespace expressions +} // namespace storm \ No newline at end of file diff --git a/src/storage/expressions/ExpressionManager.h b/src/storage/expressions/ExpressionManager.h new file mode 100644 index 000000000..9722e706b --- /dev/null +++ b/src/storage/expressions/ExpressionManager.h @@ -0,0 +1,333 @@ +#ifndef STORM_STORAGE_EXPRESSIONS_EXPRESSIONMANAGER_H_ +#define STORM_STORAGE_EXPRESSIONS_EXPRESSIONMANAGER_H_ + +#include +#include +#include + +#include "src/storage/expressions/Expression.h" +#include "src/utility/OsDetection.h" + +namespace storm { + namespace expressions { + // Forward-declare manager class for iterator class. + class ExpressionManager; + + class VariableIterator : public std::iterator const> { + public: + enum class VariableSelection { OnlyRegularVariables, OnlyAuxiliaryVariables, AllVariables }; + + VariableIterator(ExpressionManager const& manager, std::unordered_map::const_iterator nameIndexIterator, std::unordered_map::const_iterator nameIndexIteratorEnd, VariableSelection const& selection); + VariableIterator(VariableIterator const& other) = default; + VariableIterator& operator=(VariableIterator const& other) = default; + + // Define the basic input iterator operations. + bool operator==(VariableIterator const& other); + bool operator!=(VariableIterator const& other); + value_type& operator*(); + VariableIterator& operator++(int); + VariableIterator& operator++(); + + private: + /*! + * Moves the current element to the next selected element or the end iterator if there is no such element. + * + * @param atLeastOneStep A flag indicating whether at least one step should be made. + */ + void moveUntilNextSelectedElement(bool atLeastOneStep = true); + + // The manager responsible for the variable to iterate over. + ExpressionManager const& manager; + + // The underlying iterator that ranges over all names and the corresponding indices. + std::unordered_map::const_iterator nameIndexIterator; + + // The iterator indicating the end of the underlying iterator range. + std::unordered_map::const_iterator nameIndexIteratorEnd; + + // A field indicating which variables we are supposed to iterate over. + VariableSelection selection; + + // The current element that is shown to the outside upon dereferencing. + std::pair currentElement; + }; + + /*! + * This class is responsible for managing a set of typed variables and all expressions using these variables. + */ + class ExpressionManager { + public: + friend class VariableIterator; + + typedef VariableIterator const_iterator; + + /*! + * Creates a new manager that is unaware of any variables. + */ + ExpressionManager(); + + // Explicitly delete copy construction/assignment, since the manager is supposed to be stored as a pointer + // of some sort. This is because the expression classes store a reference to the manager and it must + // therefore be guaranteed that they do not become invalid, because the manager has been copied. + ExpressionManager(ExpressionManager const& other) = delete; + ExpressionManager& operator=(ExpressionManager const& other) = delete; +#ifndef WINDOWS + // Create default instantiations for the move construction/assignment. + ExpressionManager(ExpressionManager&& other) = default; + ExpressionManager& operator=(ExpressionManager&& other) = default; +#endif + + /*! + * Creates an expression that characterizes the given boolean literal. + * + * @param value The value of the boolean literal. + * @return The resulting expression. + */ + Expression boolean(bool value) const; + + /*! + * Creates an expression that characterizes the given integer literal. + * + * @param value The value of the integer literal. + * @return The resulting expression. + */ + Expression integer(int_fast64_t value) const; + + /*! + * Creates an expression that characterizes the given rational literal. + * + * @param value The value of the rational literal. + * @return The resulting expression. + */ + Expression rational(double value) const; + + /*! + * Compares the two expression managers for equality, which holds iff they are the very same object. + */ + bool operator==(ExpressionManager const& other) const; + + /*! + * Declares a variable with a name that must not yet exist and its corresponding type. Note that the name + * must not start with two underscores since these variables are reserved for internal use only. + * + * @param name The name of the variable. + * @param variableType The type of the variable. + * @return The newly declared variable. + */ + Variable declareVariable(std::string const& name, storm::expressions::ExpressionReturnType const& variableType); + + /*! + * Declares an auxiliary variable with a name that must not yet exist and its corresponding type. + * + * @param name The name of the variable. + * @param variableType The type of the variable. + * @return The newly declared variable. + */ + Variable declareAuxiliaryVariable(std::string const& name, storm::expressions::ExpressionReturnType const& variableType); + + /*! + * Declares a variable with the given name if it does not yet exist. + * + * @param name The name of the variable to declare. + * @param variableType The type of the variable to declare. + * @return The variable. + */ + Variable declareOrGetVariable(std::string const& name, storm::expressions::ExpressionReturnType const& variableType); + + /*! + * Declares a variable with the given name if it does not yet exist. + * + * @param name The name of the variable to declare. + * @param variableType The type of the variable to declare. + * @return The variable. + */ + Variable declareOrGetAuxiliaryVariable(std::string const& name, storm::expressions::ExpressionReturnType const& variableType); + + /*! + * Retrieves the expression that represents the variable with the given name. + * + * @param name The name of the variable to retrieve. + */ + Variable getVariable(std::string const& name) const; + + /*! + * Retrieves an expression that represents the variable with the given name. + * + * @param name The name of the variable + * @return An expression that represents the variable with the given name. + */ + Expression getVariableExpression(std::string const& name) const; + + /*! + * Declares a variable with the given type whose name is guaranteed to be unique and not yet in use. + * + * @param variableType The type of the variable to declare. + * @return The variable. + */ + Variable declareFreshVariable(storm::expressions::ExpressionReturnType const& variableType); + + /*! + * Declares an auxiliary variable with the given type whose name is guaranteed to be unique and not yet in use. + * + * @param variableType The type of the variable to declare. + * @return The variable. + */ + Variable declareFreshAuxiliaryVariable(storm::expressions::ExpressionReturnType const& variableType); + + /*! + * Retrieves the number of variables with the given type. + * + * @return The number of variables with the given type. + */ + uint_fast64_t getNumberOfVariables(storm::expressions::ExpressionReturnType const& variableType) const; + + /*! + * Retrieves the number of variables. + * + * @return The number of variables. + */ + uint_fast64_t getNumberOfVariables() const; + + /*! + * Retrieves the number of boolean variables. + * + * @return The number of boolean variables. + */ + uint_fast64_t getNumberOfBooleanVariables() const; + + /*! + * Retrieves the number of integer variables. + * + * @return The number of integer variables. + */ + uint_fast64_t getNumberOfIntegerVariables() const; + + /*! + * Retrieves the number of rational variables. + * + * @return The number of rational variables. + */ + uint_fast64_t getNumberOfRationalVariables() const; + + /*! + * Retrieves the number of auxiliary variables with the given type. + * + * @return The number of auxiliary variables with the given type. + */ + uint_fast64_t getNumberOfAuxiliaryVariables(storm::expressions::ExpressionReturnType const& variableType) const; + + /*! + * Retrieves the number of auxiliary variables. + * + * @return The number of auxiliary variables. + */ + uint_fast64_t getNumberOfAuxiliaryVariables() const; + + /*! + * Retrieves the number of boolean variables. + * + * @return The number of boolean variables. + */ + uint_fast64_t getNumberOfAuxiliaryBooleanVariables() const; + + /*! + * Retrieves the number of integer variables. + * + * @return The number of integer variables. + */ + uint_fast64_t getNumberOfAuxiliaryIntegerVariables() const; + + /*! + * Retrieves the number of rational variables. + * + * @return The number of rational variables. + */ + uint_fast64_t getNumberOfAuxiliaryRationalVariables() const; + + /*! + * Retrieves the name of the variable with the given index. + * + * @param index The index of the variable whose name to retrieve. + * @return The name of the variable. + */ + std::string const& getVariableName(uint_fast64_t index) const; + + /*! + * Retrieves the type of the variable with the given index. + * + * @param index The index of the variable whose name to retrieve. + * @return The type of the variable. + */ + ExpressionReturnType getVariableType(uint_fast64_t index) const; + + /*! + * Retrieves the offset of the variable with the given index within the group of equally typed variables. + * + * @param index The index of the variable. + * @return The offset of the variable. + */ + uint_fast64_t getOffset(uint_fast64_t index) const; + + /*! + * Retrieves an iterator to all variables managed by this manager. + * + * @return An iterator to all variables managed by this manager. + */ + const_iterator begin() const; + + /*! + * Retrieves an iterator that points beyond the last variable managed by this manager. + * + * @return An iterator that points beyond the last variable managed by this manager. + */ + const_iterator end() const; + + private: + /*! + * Checks whether the given variable name is valid. + * + * @param name The name to check. + * @return True iff the variable name is valid. + */ + static bool isValidVariableName(std::string const& name); + + /*! + * Retrieves whether a variable with the given name exists. + * + * @param name The name of the variable to check for. + * @return True iff a variable with this name exists. + */ + bool variableExists(std::string const& name) const; + + // A mapping from all variable names (auxiliary + normal) to their indices. + std::unordered_map nameToIndexMapping; + + // A mapping from all variable indices to their names. + std::unordered_map indexToNameMapping; + + // Store counts for variables. + std::vector variableTypeToCountMapping; + + // The number of declared variables. + uint_fast64_t numberOfVariables; + + // Store counts for auxiliary variables. + std::vector auxiliaryVariableTypeToCountMapping; + + // The number of declared auxiliary variables. + uint_fast64_t numberOfAuxiliaryVariables; + + // A counter used to create fresh variables. + uint_fast64_t freshVariableCounter; + + // A mask that can be used to query whether a variable is an auxiliary variable. + static const uint_fast64_t auxiliaryMask = (1 << 63); + static const uint_fast64_t booleanMask = (1 << 62); + static const uint_fast64_t integerMask = (1 << 61); + static const uint_fast64_t rationalMask = (1 << 60); + static const uint_fast64_t offsetMask = (1 << 60) - 1; + }; + } +} + +#endif /* STORM_STORAGE_EXPRESSIONS_EXPRESSIONMANAGER_H_ */ \ No newline at end of file diff --git a/src/storage/expressions/ExpressionReturnType.h b/src/storage/expressions/ExpressionReturnType.h index 0cf928ed8..04dd2b3ee 100644 --- a/src/storage/expressions/ExpressionReturnType.h +++ b/src/storage/expressions/ExpressionReturnType.h @@ -8,10 +8,28 @@ namespace storm { /*! * Each node in an expression tree has a uniquely defined type from this enum. */ - enum class ExpressionReturnType {Undefined, Bool, Int, Double}; + enum class ExpressionReturnType { Undefined = 0, Bool = 1, Int = 2, Double = 3}; std::ostream& operator<<(std::ostream& stream, ExpressionReturnType const& enumValue); } } +namespace std { + // Provide a hashing operator, so we can put variables in unordered collections. + template <> + struct hash { + std::size_t operator()(storm::expressions::ExpressionReturnType const& type) const { + return static_cast(type); + } + }; + + // Provide a less operator, so we can put variables in ordered collections. + template <> + struct less { + std::size_t operator()(storm::expressions::ExpressionReturnType const& type1, storm::expressions::ExpressionReturnType const& type2) const { + return static_cast(type1) < static_cast(type2); + } + }; +} + #endif /* STORM_STORAGE_EXPRESSIONS_EXPRESSIONRETURNTYPE_H_ */ \ No newline at end of file diff --git a/src/storage/expressions/IfThenElseExpression.cpp b/src/storage/expressions/IfThenElseExpression.cpp index 229b63b31..2e20c15b8 100644 --- a/src/storage/expressions/IfThenElseExpression.cpp +++ b/src/storage/expressions/IfThenElseExpression.cpp @@ -5,7 +5,7 @@ namespace storm { namespace expressions { - IfThenElseExpression::IfThenElseExpression(ExpressionReturnType returnType, std::shared_ptr const& condition, std::shared_ptr const& thenExpression, std::shared_ptr const& elseExpression) : BaseExpression(returnType), condition(condition), thenExpression(thenExpression), elseExpression(elseExpression) { + IfThenElseExpression::IfThenElseExpression(ExpressionManager const& manager, ExpressionReturnType returnType, std::shared_ptr const& condition, std::shared_ptr const& thenExpression, std::shared_ptr const& elseExpression) : BaseExpression(manager, returnType), condition(condition), thenExpression(thenExpression), elseExpression(elseExpression) { // Intentionally left empty. } @@ -71,15 +71,6 @@ namespace storm { result.insert(tmp.begin(), tmp.end()); return result; } - - std::map IfThenElseExpression::getVariablesAndTypes() const { - std::map result = this->condition->getVariablesAndTypes(); - std::map tmp = this->thenExpression->getVariablesAndTypes(); - result.insert(tmp.begin(), tmp.end()); - tmp = this->elseExpression->getVariablesAndTypes(); - result.insert(tmp.begin(), tmp.end()); - return result; - } std::shared_ptr IfThenElseExpression::simplify() const { std::shared_ptr conditionSimplified; @@ -94,7 +85,7 @@ namespace storm { if (conditionSimplified.get() == this->condition.get() && thenExpressionSimplified.get() == this->thenExpression.get() && elseExpressionSimplified.get() == this->elseExpression.get()) { return this->shared_from_this(); } else { - return std::shared_ptr(new IfThenElseExpression(this->getReturnType(), conditionSimplified, thenExpressionSimplified, elseExpressionSimplified)); + return std::shared_ptr(new IfThenElseExpression(this->getManager(), this->getReturnType(), conditionSimplified, thenExpressionSimplified, elseExpressionSimplified)); } } } diff --git a/src/storage/expressions/IfThenElseExpression.h b/src/storage/expressions/IfThenElseExpression.h index 47d34e3f3..85a90c2b9 100644 --- a/src/storage/expressions/IfThenElseExpression.h +++ b/src/storage/expressions/IfThenElseExpression.h @@ -11,11 +11,12 @@ namespace storm { /*! * Creates an if-then-else expression with the given return type, condition and operands. * + * @param manager The manager responsible for this expression. * @param returnType The return type of the expression. * @param thenExpression The expression evaluated if the condition evaluates true. * @param elseExpression The expression evaluated if the condition evaluates false. */ - IfThenElseExpression(ExpressionReturnType returnType, std::shared_ptr const& condition, std::shared_ptr const& thenExpression, std::shared_ptr const& elseExpression); + IfThenElseExpression(ExpressionManager const& manager, ExpressionReturnType returnType, std::shared_ptr const& condition, std::shared_ptr const& thenExpression, std::shared_ptr const& elseExpression); // Instantiate constructors and assignments with their default implementations. IfThenElseExpression(IfThenElseExpression const& other) = default; @@ -36,7 +37,6 @@ namespace storm { virtual int_fast64_t evaluateAsInt(Valuation const* valuation = nullptr) const override; virtual double evaluateAsDouble(Valuation const* valuation = nullptr) const override; virtual std::set getVariables() const override; - virtual std::map getVariablesAndTypes() const override; virtual std::shared_ptr simplify() const override; virtual boost::any accept(ExpressionVisitor& visitor) const override; diff --git a/src/storage/expressions/IntegerLiteralExpression.cpp b/src/storage/expressions/IntegerLiteralExpression.cpp index 9f4a22d6c..36b8fe2bf 100644 --- a/src/storage/expressions/IntegerLiteralExpression.cpp +++ b/src/storage/expressions/IntegerLiteralExpression.cpp @@ -2,7 +2,7 @@ namespace storm { namespace expressions { - IntegerLiteralExpression::IntegerLiteralExpression(int_fast64_t value) : BaseExpression(ExpressionReturnType::Int), value(value) { + IntegerLiteralExpression::IntegerLiteralExpression(ExpressionManager const& manager, int_fast64_t value) : BaseExpression(manager, ExpressionReturnType::Int), value(value) { // Intentionally left empty. } @@ -22,10 +22,6 @@ namespace storm { return std::set(); } - std::map IntegerLiteralExpression::getVariablesAndTypes() const { - return std::map(); - } - std::shared_ptr IntegerLiteralExpression::simplify() const { return this->shared_from_this(); } diff --git a/src/storage/expressions/IntegerLiteralExpression.h b/src/storage/expressions/IntegerLiteralExpression.h index e764f5bbc..348ea06ed 100644 --- a/src/storage/expressions/IntegerLiteralExpression.h +++ b/src/storage/expressions/IntegerLiteralExpression.h @@ -11,9 +11,10 @@ namespace storm { /*! * Creates an integer literal expression with the given value. * + * @param manager The manager responsible for this expression. * @param value The value of the integer literal. */ - IntegerLiteralExpression(int_fast64_t value); + IntegerLiteralExpression(ExpressionManager const& manager, int_fast64_t value); // Instantiate constructors and assignments with their default implementations. IntegerLiteralExpression(IntegerLiteralExpression const& other) = default; @@ -29,7 +30,6 @@ namespace storm { virtual double evaluateAsDouble(Valuation const* valuation = nullptr) const override; virtual bool isLiteral() const override; virtual std::set getVariables() const override; - virtual std::map getVariablesAndTypes() const override; virtual std::shared_ptr simplify() const override; virtual boost::any accept(ExpressionVisitor& visitor) const override; diff --git a/src/storage/expressions/SimpleValuation.cpp b/src/storage/expressions/SimpleValuation.cpp deleted file mode 100644 index ae111c653..000000000 --- a/src/storage/expressions/SimpleValuation.cpp +++ /dev/null @@ -1,177 +0,0 @@ -#include "src/storage/expressions/SimpleValuation.h" - -#include - -#include -#include "src/utility/macros.h" -#include "src/exceptions/InvalidArgumentException.h" -#include "src/exceptions/InvalidAccessException.h" - -namespace storm { - namespace expressions { - bool SimpleValuation::operator==(SimpleValuation const& other) const { - return this->identifierToValueMap == other.identifierToValueMap; - } - - void SimpleValuation::addBooleanIdentifier(std::string const& name, bool initialValue) { - STORM_LOG_THROW(this->identifierToValueMap.find(name) == this->identifierToValueMap.end(), storm::exceptions::InvalidArgumentException, "Identifier '" << name << "' already registered."); - this->identifierToValueMap.emplace(name, initialValue); - } - - void SimpleValuation::addIntegerIdentifier(std::string const& name, int_fast64_t initialValue) { - STORM_LOG_THROW(this->identifierToValueMap.find(name) == this->identifierToValueMap.end(), storm::exceptions::InvalidArgumentException, "Identifier '" << name << "' already registered."); - this->identifierToValueMap.emplace(name, initialValue); - } - - void SimpleValuation::addDoubleIdentifier(std::string const& name, double initialValue) { - STORM_LOG_THROW(this->identifierToValueMap.find(name) == this->identifierToValueMap.end(), storm::exceptions::InvalidArgumentException, "Identifier '" << name << "' already registered."); - this->identifierToValueMap.emplace(name, initialValue); - } - - void SimpleValuation::setBooleanValue(std::string const& name, bool value) { - this->identifierToValueMap[name] = value; - } - - void SimpleValuation::setIntegerValue(std::string const& name, int_fast64_t value) { - this->identifierToValueMap[name] = value; - } - - void SimpleValuation::setDoubleValue(std::string const& name, double value) { - this->identifierToValueMap[name] = value; - } - - void SimpleValuation::removeIdentifier(std::string const& name) { - auto nameValuePair = this->identifierToValueMap.find(name); - STORM_LOG_THROW(nameValuePair != this->identifierToValueMap.end(), storm::exceptions::InvalidArgumentException, "Deleting unknown identifier '" << name << "'."); - this->identifierToValueMap.erase(nameValuePair); - } - - ExpressionReturnType SimpleValuation::getIdentifierType(std::string const& name) const { - auto nameValuePair = this->identifierToValueMap.find(name); - STORM_LOG_THROW(nameValuePair != this->identifierToValueMap.end(), storm::exceptions::InvalidAccessException, "Access to unkown identifier '" << name << "'."); - if (nameValuePair->second.type() == typeid(bool)) { - return ExpressionReturnType::Bool; - } else if (nameValuePair->second.type() == typeid(int_fast64_t)) { - return ExpressionReturnType::Int; - } else { - return ExpressionReturnType::Double; - } - } - - bool SimpleValuation::containsBooleanIdentifier(std::string const& name) const { - auto nameValuePair = this->identifierToValueMap.find(name); - if (nameValuePair == this->identifierToValueMap.end()) { - return false; - } - return nameValuePair->second.type() == typeid(bool); - } - - bool SimpleValuation::containsIntegerIdentifier(std::string const& name) const { - auto nameValuePair = this->identifierToValueMap.find(name); - if (nameValuePair == this->identifierToValueMap.end()) { - return false; - } - return nameValuePair->second.type() == typeid(int_fast64_t); - } - - bool SimpleValuation::containsDoubleIdentifier(std::string const& name) const { - auto nameValuePair = this->identifierToValueMap.find(name); - if (nameValuePair == this->identifierToValueMap.end()) { - return false; - } - return nameValuePair->second.type() == typeid(double); - } - - bool SimpleValuation::getBooleanValue(std::string const& name) const { - auto nameValuePair = this->identifierToValueMap.find(name); - STORM_LOG_THROW(nameValuePair != this->identifierToValueMap.end(), storm::exceptions::InvalidAccessException, "Access to unkown identifier '" << name << "'."); - return boost::get(nameValuePair->second); - } - - int_fast64_t SimpleValuation::getIntegerValue(std::string const& name) const { - auto nameValuePair = this->identifierToValueMap.find(name); - STORM_LOG_THROW(nameValuePair != this->identifierToValueMap.end(), storm::exceptions::InvalidAccessException, "Access to unkown identifier '" << name << "'."); - return boost::get(nameValuePair->second); - } - - double SimpleValuation::getDoubleValue(std::string const& name) const { - auto nameValuePair = this->identifierToValueMap.find(name); - STORM_LOG_THROW(nameValuePair != this->identifierToValueMap.end(), storm::exceptions::InvalidAccessException, "Access to unkown identifier '" << name << "'."); - return boost::get(nameValuePair->second); - } - - std::size_t SimpleValuation::getNumberOfIdentifiers() const { - return this->identifierToValueMap.size(); - } - - std::set SimpleValuation::getIdentifiers() const { - std::set result; - for (auto const& nameValuePair : this->identifierToValueMap) { - result.insert(nameValuePair.first); - } - return result; - } - - std::set SimpleValuation::getBooleanIdentifiers() const { - std::set result; - for (auto const& nameValuePair : this->identifierToValueMap) { - if (nameValuePair.second.type() == typeid(bool)) { - result.insert(nameValuePair.first); - } - } - return result; - } - - std::set SimpleValuation::getIntegerIdentifiers() const { - std::set result; - for (auto const& nameValuePair : this->identifierToValueMap) { - if (nameValuePair.second.type() == typeid(int_fast64_t)) { - result.insert(nameValuePair.first); - } - } - return result; - } - - std::set SimpleValuation::getDoubleIdentifiers() const { - std::set result; - for (auto const& nameValuePair : this->identifierToValueMap) { - if (nameValuePair.second.type() == typeid(double)) { - result.insert(nameValuePair.first); - } - } - return result; - } - - std::ostream& operator<<(std::ostream& stream, SimpleValuation const& valuation) { - stream << "{ "; - uint_fast64_t elementIndex = 0; - for (auto const& nameValuePair : valuation.identifierToValueMap) { - stream << nameValuePair.first << " -> " << nameValuePair.second << " "; - ++elementIndex; - if (elementIndex < valuation.identifierToValueMap.size()) { - stream << ", "; - } - } - stream << "}"; - - return stream; - } - - std::size_t SimpleValuationPointerHash::operator()(SimpleValuation* valuation) const { - size_t seed = 0; - for (auto const& nameValuePair : valuation->identifierToValueMap) { - boost::hash_combine(seed, nameValuePair.first); - boost::hash_combine(seed, nameValuePair.second); - } - return seed; - } - - bool SimpleValuationPointerCompare::operator()(SimpleValuation* valuation1, SimpleValuation* valuation2) const { - return *valuation1 == *valuation2; - } - - bool SimpleValuationPointerLess::operator()(SimpleValuation* valuation1, SimpleValuation* valuation2) const { - return valuation1->identifierToValueMap < valuation2->identifierToValueMap; - } - } -} \ No newline at end of file diff --git a/src/storage/expressions/SimpleValuation.h b/src/storage/expressions/SimpleValuation.h deleted file mode 100644 index fccfe2faa..000000000 --- a/src/storage/expressions/SimpleValuation.h +++ /dev/null @@ -1,148 +0,0 @@ -#ifndef STORM_STORAGE_EXPRESSIONS_SIMPLEVALUATION_H_ -#define STORM_STORAGE_EXPRESSIONS_SIMPLEVALUATION_H_ - -#include -#include -#include - -#include "src/storage/expressions/Valuation.h" -#include "src/storage/expressions/ExpressionReturnType.h" -#include "src/utility/OsDetection.h" - -namespace storm { - namespace expressions { - class SimpleValuation : public Valuation { - public: - friend class SimpleValuationPointerHash; - friend class SimpleValuationPointerLess; - - // Instantiate some constructors and assignments with their default implementations. - SimpleValuation() = default; - SimpleValuation(SimpleValuation const&) = default; - SimpleValuation& operator=(SimpleValuation const&) = default; -#ifndef WINDOWS - SimpleValuation(SimpleValuation&&) = default; - SimpleValuation& operator=(SimpleValuation&&) = default; -#endif - virtual ~SimpleValuation() = default; - - /*! - * Compares two simple valuations wrt. equality. - */ - bool operator==(SimpleValuation const& other) const; - - /*! - * Adds a boolean identifier with the given name. - * - * @param name The name of the boolean identifier to add. - * @param initialValue The initial value of the identifier. - */ - void addBooleanIdentifier(std::string const& name, bool initialValue = false); - - /*! - * Adds a integer identifier with the given name. - * - * @param name The name of the integer identifier to add. - * @param initialValue The initial value of the identifier. - */ - void addIntegerIdentifier(std::string const& name, int_fast64_t initialValue = 0); - - /*! - * Adds a double identifier with the given name. - * - * @param name The name of the double identifier to add. - * @param initialValue The initial value of the identifier. - */ - void addDoubleIdentifier(std::string const& name, double initialValue = 0); - - /*! - * Sets the value of the boolean identifier with the given name to the given value. - * - * @param name The name of the boolean identifier whose value to set. - * @param value The new value of the boolean identifier. - */ - void setBooleanValue(std::string const& name, bool value); - - /*! - * Sets the value of the integer identifier with the given name to the given value. - * - * @param name The name of the integer identifier whose value to set. - * @param value The new value of the integer identifier. - */ - void setIntegerValue(std::string const& name, int_fast64_t value); - - /*! - * Sets the value of the double identifier with the given name to the given value. - * - * @param name The name of the double identifier whose value to set. - * @param value The new value of the double identifier. - */ - void setDoubleValue(std::string const& name, double value); - - /*! - * Removes the given identifier from this valuation. - * - * @param name The name of the identifier that is to be removed. - */ - void removeIdentifier(std::string const& name); - - /*! - * Retrieves the type of the identifier with the given name. - * - * @param name The name of the identifier whose type to retrieve. - * @return The type of the identifier with the given name. - */ - ExpressionReturnType getIdentifierType(std::string const& name) const; - - // Override base class methods. - virtual bool containsBooleanIdentifier(std::string const& name) const override; - virtual bool containsIntegerIdentifier(std::string const& name) const override; - virtual bool containsDoubleIdentifier(std::string const& name) const override; - virtual std::size_t getNumberOfIdentifiers() const override; - virtual std::set getIdentifiers() const override; - virtual std::set getBooleanIdentifiers() const override; - virtual std::set getIntegerIdentifiers() const override; - virtual std::set getDoubleIdentifiers() const override; - virtual bool getBooleanValue(std::string const& name) const override; - virtual int_fast64_t getIntegerValue(std::string const& name) const override; - virtual double getDoubleValue(std::string const& name) const override; - - friend std::ostream& operator<<(std::ostream& stream, SimpleValuation const& valuation); - - private: - // A mapping of boolean identifiers to their local indices in the value container. - boost::container::flat_map> identifierToValueMap; - }; - - /*! - * A helper class that can pe used as the hash functor for data structures that need to hash a simple valuations - * given via pointers. - */ - class SimpleValuationPointerHash { - public: - std::size_t operator()(SimpleValuation* valuation) const; - }; - - /*! - * A helper class that can be used as the comparison functor wrt. equality for data structures that need to - * store pointers to a simple valuations and need to compare the elements wrt. their content (rather than - * pointer equality). - */ - class SimpleValuationPointerCompare { - public: - bool operator()(SimpleValuation* valuation1, SimpleValuation* valuation2) const; - }; - - /*! - * A helper class that can be used as the comparison functor wrt. "<" for data structures that need to - * store pointers to a simple valuations and need to compare the elements wrt. their content (rather than - * pointer equality). - */ - class SimpleValuationPointerLess { - public: - bool operator()(SimpleValuation* valuation1, SimpleValuation* valuation2) const; - }; - } -} - -#endif /* STORM_STORAGE_EXPRESSIONS_SIMPLEVALUATION_H_ */ \ No newline at end of file diff --git a/src/storage/expressions/UnaryBooleanFunctionExpression.cpp b/src/storage/expressions/UnaryBooleanFunctionExpression.cpp index d9f7839ea..74b96a36c 100644 --- a/src/storage/expressions/UnaryBooleanFunctionExpression.cpp +++ b/src/storage/expressions/UnaryBooleanFunctionExpression.cpp @@ -5,7 +5,7 @@ namespace storm { namespace expressions { - UnaryBooleanFunctionExpression::UnaryBooleanFunctionExpression(ExpressionReturnType returnType, std::shared_ptr const& operand, OperatorType operatorType) : UnaryExpression(returnType, operand), operatorType(operatorType) { + UnaryBooleanFunctionExpression::UnaryBooleanFunctionExpression(ExpressionManager const& manager, ExpressionReturnType returnType, std::shared_ptr const& operand, OperatorType operatorType) : UnaryExpression(manager, returnType, operand), operatorType(operatorType) { // Intentionally left empty. } diff --git a/src/storage/expressions/UnaryBooleanFunctionExpression.h b/src/storage/expressions/UnaryBooleanFunctionExpression.h index 02b2fab2f..3cf7c0b55 100644 --- a/src/storage/expressions/UnaryBooleanFunctionExpression.h +++ b/src/storage/expressions/UnaryBooleanFunctionExpression.h @@ -16,11 +16,12 @@ namespace storm { /*! * Creates a unary boolean function expression with the given return type, operand and operator. * + * @param manager The manager responsible for this expression. * @param returnType The return type of the expression. * @param operand The operand of the expression. * @param operatorType The operator of the expression. */ - UnaryBooleanFunctionExpression(ExpressionReturnType returnType, std::shared_ptr const& operand, OperatorType operatorType); + UnaryBooleanFunctionExpression(ExpressionManager const& manager, ExpressionReturnType returnType, std::shared_ptr const& operand, OperatorType operatorType); // Instantiate constructors and assignments with their default implementations. UnaryBooleanFunctionExpression(UnaryBooleanFunctionExpression const& other) = default; diff --git a/src/storage/expressions/UnaryExpression.cpp b/src/storage/expressions/UnaryExpression.cpp index 3569e5453..57abeeedc 100644 --- a/src/storage/expressions/UnaryExpression.cpp +++ b/src/storage/expressions/UnaryExpression.cpp @@ -5,7 +5,7 @@ namespace storm { namespace expressions { - UnaryExpression::UnaryExpression(ExpressionReturnType returnType, std::shared_ptr const& operand) : BaseExpression(returnType), operand(operand) { + UnaryExpression::UnaryExpression(ExpressionManager const& manager, ExpressionReturnType returnType, std::shared_ptr const& operand) : BaseExpression(manager, returnType), operand(operand) { // Intentionally left empty. } @@ -20,10 +20,6 @@ namespace storm { std::set UnaryExpression::getVariables() const { return this->getOperand()->getVariables(); } - - std::map UnaryExpression::getVariablesAndTypes() const { - return this->getOperand()->getVariablesAndTypes(); - } std::shared_ptr const& UnaryExpression::getOperand() const { return this->operand; diff --git a/src/storage/expressions/UnaryExpression.h b/src/storage/expressions/UnaryExpression.h index 5473d3122..60b770622 100644 --- a/src/storage/expressions/UnaryExpression.h +++ b/src/storage/expressions/UnaryExpression.h @@ -11,10 +11,11 @@ namespace storm { /*! * Creates a unary expression with the given return type and operand. * + * @param manager The manager responsible for this expression. * @param returnType The return type of the expression. * @param operand The operand of the unary expression. */ - UnaryExpression(ExpressionReturnType returnType, std::shared_ptr const& operand); + UnaryExpression(ExpressionManager const& manager, ExpressionReturnType returnType, std::shared_ptr const& operand); // Instantiate constructors and assignments with their default implementations. UnaryExpression(UnaryExpression const& other); @@ -31,7 +32,6 @@ namespace storm { virtual uint_fast64_t getArity() const override; virtual std::shared_ptr getOperand(uint_fast64_t operandIndex) const override; virtual std::set getVariables() const override; - virtual std::map getVariablesAndTypes() const override; /*! * Retrieves the operand of the unary expression. diff --git a/src/storage/expressions/UnaryNumericalFunctionExpression.cpp b/src/storage/expressions/UnaryNumericalFunctionExpression.cpp index 75a2c0186..6237db107 100644 --- a/src/storage/expressions/UnaryNumericalFunctionExpression.cpp +++ b/src/storage/expressions/UnaryNumericalFunctionExpression.cpp @@ -6,7 +6,7 @@ namespace storm { namespace expressions { - UnaryNumericalFunctionExpression::UnaryNumericalFunctionExpression(ExpressionReturnType returnType, std::shared_ptr const& operand, OperatorType operatorType) : UnaryExpression(returnType, operand), operatorType(operatorType) { + UnaryNumericalFunctionExpression::UnaryNumericalFunctionExpression(ExpressionManager const& manager, ExpressionReturnType returnType, std::shared_ptr const& operand, OperatorType operatorType) : UnaryExpression(manager, returnType, operand), operatorType(operatorType) { // Intentionally left empty. } diff --git a/src/storage/expressions/UnaryNumericalFunctionExpression.h b/src/storage/expressions/UnaryNumericalFunctionExpression.h index 7852c73c7..b4aac426e 100644 --- a/src/storage/expressions/UnaryNumericalFunctionExpression.h +++ b/src/storage/expressions/UnaryNumericalFunctionExpression.h @@ -16,11 +16,12 @@ namespace storm { /*! * Creates a unary numerical function expression with the given return type, operand and operator. * + * @param manager The manager responsible for this expression. * @param returnType The return type of the expression. * @param operand The operand of the expression. * @param operatorType The operator of the expression. */ - UnaryNumericalFunctionExpression(ExpressionReturnType returnType, std::shared_ptr const& operand, OperatorType operatorType); + UnaryNumericalFunctionExpression(ExpressionManager const& manager, ExpressionReturnType returnType, std::shared_ptr const& operand, OperatorType operatorType); // Instantiate constructors and assignments with their default implementations. UnaryNumericalFunctionExpression(UnaryNumericalFunctionExpression const& other) = default; diff --git a/src/storage/expressions/Valuation.cpp b/src/storage/expressions/Valuation.cpp new file mode 100644 index 000000000..a875b358f --- /dev/null +++ b/src/storage/expressions/Valuation.cpp @@ -0,0 +1,82 @@ +#include "src/storage/expressions/Valuation.h" + +#include + +#include "src/storage/expressions/ExpressionManager.h" +#include "src/storage/expressions/Variable.h" + +namespace storm { + namespace expressions { + Valuation::Valuation(ExpressionManager const& manager) : manager(manager), booleanValues(nullptr), integerValues(nullptr), rationalValues(nullptr) { + if (manager.getNumberOfBooleanVariables() > 0) { + booleanValues = std::unique_ptr>(new std::vector(manager.getNumberOfBooleanVariables())); + } + if (manager.getNumberOfIntegerVariables() > 0) { + integerValues = std::unique_ptr>(new std::vector(manager.getNumberOfIntegerVariables())); + } + if (manager.getNumberOfRationalVariables() > 0) { + rationalValues = std::unique_ptr>(new std::vector(manager.getNumberOfRationalVariables())); + } + } + + Valuation::Valuation(Valuation const& other) : manager(other.manager) { + if (other.booleanValues != nullptr) { + booleanValues = std::unique_ptr>(new std::vector(*other.booleanValues)); + } + if (other.integerValues != nullptr) { + integerValues = std::unique_ptr>(new std::vector(*other.integerValues)); + } + if (other.booleanValues != nullptr) { + rationalValues = std::unique_ptr>(new std::vector(*other.rationalValues)); + } + } + + bool Valuation::operator==(Valuation const& other) const { + return manager == other.manager && booleanValues == other.booleanValues && integerValues == other.integerValues && rationalValues == other.rationalValues; + } + + bool Valuation::getBooleanValue(Variable const& booleanVariable) const { + return (*booleanValues)[booleanVariable.getOffset()]; + } + + int_fast64_t Valuation::getIntegerValue(Variable const& integerVariable) const { + return (*integerValues)[integerVariable.getOffset()]; + } + + double Valuation::getRationalValue(Variable const& rationalVariable) const { + return (*rationalValues)[rationalVariable.getOffset()]; + } + + void Valuation::setBooleanValue(Variable const& booleanVariable, bool value) { + (*booleanValues)[booleanVariable.getOffset()] = value; + } + + void Valuation::setIntegerValue(Variable const& integerVariable, int_fast64_t value) { + (*integerValues)[integerVariable.getOffset()] = value; + } + + void Valuation::setRationalValue(Variable const& rationalVariable, double value) { + (*rationalValues)[rationalVariable.getOffset()] = value; + } + + ExpressionManager const& Valuation::getManager() const { + return manager; + } + + std::size_t ValuationPointerHash::operator()(Valuation* valuation) const { + size_t seed = 0; + boost::hash_combine(seed, valuation->booleanValues); + boost::hash_combine(seed, valuation->integerValues); + boost::hash_combine(seed, valuation->rationalValues); + return seed; + } + + bool ValuationPointerCompare::operator()(Valuation* valuation1, Valuation* valuation2) const { + return *valuation1 == *valuation2; + } + + bool ValuationPointerLess::operator()(Valuation* valuation1, Valuation* valuation2) const { + return valuation1->booleanValues < valuation2->booleanValues || valuation1->integerValues < valuation2->integerValues || valuation1->rationalValues < valuation2->rationalValues; + } + } +} \ No newline at end of file diff --git a/src/storage/expressions/Valuation.h b/src/storage/expressions/Valuation.h index a09c4f07c..2d56ae106 100644 --- a/src/storage/expressions/Valuation.h +++ b/src/storage/expressions/Valuation.h @@ -1,100 +1,138 @@ #ifndef STORM_STORAGE_EXPRESSIONS_VALUATION_H_ #define STORM_STORAGE_EXPRESSIONS_VALUATION_H_ -#include -#include +#include +#include namespace storm { namespace expressions { + class ExpressionManager; + class Variable; + /*! - * The base class of all valuations where a valuation assigns a concrete value to all identifiers. This is, for - * example, used for evaluating expressions. + * A class to store a valuation of variables. This is, for example, used for evaluating expressions. */ class Valuation { public: - /*! - * Retrieves the boolean value of the identifier with the given name. - * - * @param name The name of the boolean identifier whose value to retrieve. - * @return The value of the boolean identifier. - */ - virtual bool getBooleanValue(std::string const& name) const = 0; + friend class ValuationPointerHash; + friend class ValuationPointerLess; /*! - * Retrieves the integer value of the identifier with the given name. + * Creates a valuation of all non-auxiliary variables managed by the given manager. If the manager is + * modified in the sense that additional variables are added, all valuations over its variables are + * invalidated. * - * @param name The name of the integer identifier whose value to retrieve. - * @return The value of the integer identifier. + * @param manager The manager of the variables. */ - virtual int_fast64_t getIntegerValue(std::string const& name) const = 0; + Valuation(ExpressionManager const& manager); /*! - * Retrieves the double value of the identifier with the given name. + * Deep-copies the valuation. * - * @param name The name of the double identifier whose value to retrieve. - * @return The value of the double identifier. + * @param other The valuation to copy */ - virtual double getDoubleValue(std::string const& name) const = 0; + Valuation(Valuation const& other); /*! - * Retrieves whether there exists a boolean identifier with the given name in the valuation. + * Checks whether the two valuations are semantically equivalent. * - * @param name The name of the boolean identifier to query. - * @return True iff the identifier exists and is of boolean type. + * @param other The valuation with which to compare. + * @return True iff the two valuations are semantically equivalent. */ - virtual bool containsBooleanIdentifier(std::string const& name) const = 0; + bool operator==(Valuation const& other) const; /*! - * Retrieves whether there exists a integer identifier with the given name in the valuation. + * Retrieves the value of the given boolean variable. * - * @param name The name of the integer identifier to query. - * @return True iff the identifier exists and is of boolean type. + * @param booleanVariable The boolean variable whose value to retrieve. + * @return The value of the boolean variable. */ - virtual bool containsIntegerIdentifier(std::string const& name) const = 0; + bool getBooleanValue(Variable const& booleanVariable) const; /*! - * Retrieves whether there exists a double identifier with the given name in the valuation. + * Sets the value of the given boolean variable to the provided value. * - * @param name The name of the double identifier to query. - * @return True iff the identifier exists and is of boolean type. + * @param booleanVariable The variable whose value to set. + * @param value The new value of the variable. */ - virtual bool containsDoubleIdentifier(std::string const& name) const = 0; + void setBooleanValue(Variable const& booleanVariable, bool value); /*! - * Retrieves the number of identifiers in this valuation. + * Retrieves the value of the given integer variable. * - * @return The number of identifiers in this valuation. + * @param integerVariable The integer variable whose value to retrieve. + * @return The value of the integer variable. */ - virtual std::size_t getNumberOfIdentifiers() const = 0; - + int_fast64_t getIntegerValue(Variable const& integerVariable) const; + /*! - * Retrieves the set of all identifiers contained in this valuation. + * Sets the value of the given boolean variable to the provided value. * - * @return The set of all identifiers contained in this valuation. + * @param integerVariable The variable whose value to set. + * @param value The new value of the variable. */ - virtual std::set getIdentifiers() const = 0; + void setIntegerValue(Variable const& integerVariable, int_fast64_t value); /*! - * Retrieves the set of boolean identifiers contained in this valuation. + * Retrieves the value of the given rational variable. * - * @return The set of boolean identifiers contained in this valuation. + * @param rationalVariable The rational variable whose value to retrieve. + * @return The value of the rational variable. */ - virtual std::set getBooleanIdentifiers() const = 0; - + double getRationalValue(Variable const& rationalVariable) const; + /*! - * Retrieves the set of integer identifiers contained in this valuation. + * Sets the value of the given boolean variable to the provided value. * - * @return The set of integer identifiers contained in this valuation. + * @param integerVariable The variable whose value to set. + * @param value The new value of the variable. */ - virtual std::set getIntegerIdentifiers() const = 0; - + void setRationalValue(Variable const& rationalVariable, double value); + /*! - * Retrieves the set of double identifiers contained in this valuation. + * Retrieves the manager responsible for the variables of this valuation. * - * @return The set of double identifiers contained in this valuation. + * @return The manager. */ - virtual std::set getDoubleIdentifiers() const = 0; + ExpressionManager const& getManager() const; + private: + // The manager responsible for the variables of this valuation. + ExpressionManager const& manager; + + // Containers that store the values of the variables of the appropriate type. + std::unique_ptr> booleanValues; + std::unique_ptr> integerValues; + std::unique_ptr> rationalValues; + }; + + /*! + * A helper class that can pe used as the hash functor for data structures that need to hash valuations given + * via pointers. + */ + class ValuationPointerHash { + public: + std::size_t operator()(Valuation* valuation) const; + }; + + /*! + * A helper class that can be used as the comparison functor wrt. equality for data structures that need to + * store pointers to valuations and need to compare the elements wrt. their content (rather than pointer + * equality). + */ + class ValuationPointerCompare { + public: + bool operator()(Valuation* valuation1, Valuation* valuation2) const; + }; + + /*! + * A helper class that can be used as the comparison functor wrt. "<" for data structures that need to + * store pointers to valuations and need to compare the elements wrt. their content (rather than pointer + * equality). + */ + class ValuationPointerLess { + public: + bool operator()(Valuation* valuation1, Valuation* valuation2) const; }; } } diff --git a/src/storage/expressions/Variable.cpp b/src/storage/expressions/Variable.cpp new file mode 100644 index 000000000..67683dfaa --- /dev/null +++ b/src/storage/expressions/Variable.cpp @@ -0,0 +1,54 @@ +#include "src/storage/expressions/Variable.h" +#include "src/storage/expressions/ExpressionManager.h" + +namespace storm { + namespace expressions { + Variable::Variable(ExpressionManager const& manager, uint_fast64_t index) : manager(manager), index(index) { + // Intentionally left empty. + } + + bool Variable::operator==(Variable const& other) const { + return manager == other.manager && index == other.index; + } + + storm::expressions::Expression Variable::getExpression() const { + return storm::expressions::Expression(*this); + } + + uint_fast64_t Variable::getIndex() const { + return index; + } + + uint_fast64_t Variable::getOffset() const { + return manager.getOffset(index); + } + + std::string const& Variable::getName() const { + return manager.getVariableName(index); + } + + ExpressionReturnType Variable::getType() const { + return manager.getVariableType(index); + } + + ExpressionManager const& Variable::getManager() const { + return manager; + } + + bool Variable::hasBooleanType() const { + return this->getType() == ExpressionReturnType::Bool; + } + + bool Variable::hasIntegralType() const { + return this->getType() == ExpressionReturnType::Int; + } + + bool Variable::hasRationalType() const { + return this->getType() == ExpressionReturnType::Double; + } + + bool Variable::hasNumericType() const { + return this->getType() == ExpressionReturnType::Int || this->getType() == ExpressionReturnType::Double; + } + } +} \ No newline at end of file diff --git a/src/storage/expressions/Variable.h b/src/storage/expressions/Variable.h new file mode 100644 index 000000000..10325877d --- /dev/null +++ b/src/storage/expressions/Variable.h @@ -0,0 +1,138 @@ +#ifndef STORM_STORAGE_EXPRESSIONS_VARIABLE_H_ +#define STORM_STORAGE_EXPRESSIONS_VARIABLE_H_ + +#include +#include + +#include "src/utility/OsDetection.h" +#include "src/storage/expressions/ExpressionReturnType.h" +#include "src/storage/expressions/Expression.h" + +namespace storm { + namespace expressions { + class ExpressionManager; + + // This class captures a simple variable. + class Variable { + public: + /*! + * Constructs a variable with the given index and type. + * + * @param manager The manager that is responsible for this variable. + * @param index The (unique) index of the variable. + */ + Variable(ExpressionManager const& manager, uint_fast64_t index); + + // Default-instantiate some copy/move construction/assignment. + Variable(Variable const& other) = default; + Variable& operator=(Variable const& other) = default; +#ifndef WINDOWS + Variable(Variable&& other) = default; + Variable& operator=(Variable&& other) = default; +#endif + + /*! + * Checks the two variables for equality. + * + * @param other The variable to compare with. + * @return True iff the two variables are the same. + */ + bool operator==(Variable const& other) const; + + /*! + * Retrieves the name of the variable. + * + * @return name The name of the variable. + */ + std::string const& getName() const; + + /*! + * Retrieves the type of the variable. + * + * @return The type of the variable. + */ + ExpressionReturnType getType() const; + + /*! + * Retrieves an expression that represents the variable. + * + * @return An expression that represents the varible. + */ + storm::expressions::Expression getExpression() const; + + /*! + * Retrieves the manager responsible for this variable. + */ + ExpressionManager const& getManager() const; + + /*! + * Retrieves the index of the variable. + * + * @return The index of the variable. + */ + uint_fast64_t getIndex() const; + + /*! + * Retrieves the offset of the variable in the group of all equally typed variables. + * + * @return The offset of the variable. + */ + uint_fast64_t getOffset() const; + + /*! + * Checks whether the variable is of boolean type. + * + * @return True iff the variable if of boolean type. + */ + bool hasBooleanType() const; + + /*! + * Checks whether the variable is of integral type. + * + * @return True iff the variable if of integral type. + */ + bool hasIntegralType() const; + + /*! + * Checks whether the variable is of rational type. + * + * @return True iff the variable if of rational type. + */ + bool hasRationalType() const; + + /*! + * Checks whether the variable is of boolean type. + * + * @return True iff the variable if of boolean type. + */ + bool hasNumericType() const; + + private: + // The manager that is responsible for this variable. + ExpressionManager const& manager; + + // The index of the variable. + uint_fast64_t index; + }; + } +} + +namespace std { + // Provide a hashing operator, so we can put variables in unordered collections. + template <> + struct hash { + std::size_t operator()(storm::expressions::Variable const& variable) const { + return std::hash()(variable.getIndex()); + } + }; + + // Provide a less operator, so we can put variables in ordered collections. + template <> + struct less { + std::size_t operator()(storm::expressions::Variable const& variable1, storm::expressions::Variable const& variable2) const { + return variable1.getIndex() < variable2.getIndex(); + } + }; +} + +#endif /* STORM_STORAGE_EXPRESSIONS_VARIABLE_H_ */ \ No newline at end of file diff --git a/src/storage/expressions/VariableExpression.cpp b/src/storage/expressions/VariableExpression.cpp index 685ce1d4d..0737b94d9 100644 --- a/src/storage/expressions/VariableExpression.cpp +++ b/src/storage/expressions/VariableExpression.cpp @@ -4,26 +4,30 @@ namespace storm { namespace expressions { - VariableExpression::VariableExpression(ExpressionReturnType returnType, std::string const& variableName) : BaseExpression(returnType), variableName(variableName) { + VariableExpression::VariableExpression(Variable const& variable) : BaseExpression(variable.getManager(), variable.getType()), variable(variable) { // Intentionally left empty. } std::string const& VariableExpression::getVariableName() const { - return this->variableName; + return variable.getName(); + } + + Variable const& VariableExpression::getVariable() const { + return variable; } bool VariableExpression::evaluateAsBool(Valuation const* valuation) const { STORM_LOG_ASSERT(valuation != nullptr, "Evaluating expressions with unknowns without valuation."); STORM_LOG_THROW(this->hasBooleanReturnType(), storm::exceptions::InvalidTypeException, "Cannot evaluate expression as boolean: return type is not a boolean."); - return valuation->getBooleanValue(this->getVariableName()); + return valuation->getBooleanValue(this->getVariable()); } int_fast64_t VariableExpression::evaluateAsInt(Valuation const* valuation) const { STORM_LOG_ASSERT(valuation != nullptr, "Evaluating expressions with unknowns without valuation."); STORM_LOG_THROW(this->hasIntegralReturnType(), storm::exceptions::InvalidTypeException, "Cannot evaluate expression as integer: return type is not an integer."); - return valuation->getIntegerValue(this->getVariableName()); + return valuation->getIntegerValue(this->getVariable()); } double VariableExpression::evaluateAsDouble(Valuation const* valuation) const { @@ -31,8 +35,8 @@ namespace storm { STORM_LOG_THROW(this->hasNumericalReturnType(), storm::exceptions::InvalidTypeException, "Cannot evaluate expression as double: return type is not a double."); switch (this->getReturnType()) { - case ExpressionReturnType::Int: return static_cast(valuation->getIntegerValue(this->getVariableName())); break; - case ExpressionReturnType::Double: valuation->getDoubleValue(this->getVariableName()); break; + case ExpressionReturnType::Int: return static_cast(valuation->getIntegerValue(this->getVariable())); break; + case ExpressionReturnType::Double: valuation->getRationalValue(this->getVariable()); break; default: break; } STORM_LOG_ASSERT(false, "Type of variable is required to be numeric."); @@ -56,10 +60,6 @@ namespace storm { std::set VariableExpression::getVariables() const { return {this->getVariableName()}; } - - std::map VariableExpression::getVariablesAndTypes() const { - return{ std::make_pair(this->getVariableName(), this->getReturnType()) }; - } std::shared_ptr VariableExpression::simplify() const { return this->shared_from_this(); diff --git a/src/storage/expressions/VariableExpression.h b/src/storage/expressions/VariableExpression.h index d84a12c55..505900ebc 100644 --- a/src/storage/expressions/VariableExpression.h +++ b/src/storage/expressions/VariableExpression.h @@ -2,6 +2,7 @@ #define STORM_STORAGE_EXPRESSIONS_VARIABLEEXPRESSION_H_ #include "src/storage/expressions/BaseExpression.h" +#include "src/storage/expressions/Variable.h" #include "src/utility/OsDetection.h" namespace storm { @@ -14,7 +15,7 @@ namespace storm { * @param returnType The return type of the variable expression. * @param variableName The name of the variable associated with this expression. */ - VariableExpression(ExpressionReturnType returnType, std::string const& variableName); + VariableExpression(Variable const& variable); // Instantiate constructors and assignments with their default implementations. VariableExpression(VariableExpression const&) = default; @@ -33,7 +34,6 @@ namespace storm { virtual bool containsVariables() const override; virtual bool isVariable() const override; virtual std::set getVariables() const override; - virtual std::map getVariablesAndTypes() const override; virtual std::shared_ptr simplify() const override; virtual boost::any accept(ExpressionVisitor& visitor) const override; @@ -44,13 +44,20 @@ namespace storm { */ std::string const& getVariableName() const; + /*! + * Retrieves the variable associated with this expression. + * + * @return The variable associated with this expression. + */ + Variable const& getVariable() const; + protected: // Override base class method. virtual void printToStream(std::ostream& stream) const override; private: - // The variable name associated with this expression. - std::string variableName; + // The variable that is represented by this expression. + Variable variable; }; } }