diff --git a/src/adapters/ExplicitModelAdapter.h b/src/adapters/ExplicitModelAdapter.h index ac0ae6455..6899f87c0 100644 --- a/src/adapters/ExplicitModelAdapter.h +++ b/src/adapters/ExplicitModelAdapter.h @@ -177,16 +177,15 @@ namespace storm { // This variable needs to be declared prior to the switch, because of C++ rules. int_fast64_t newValue = 0; for (auto const& assignment : update.getAssignments()) { - switch (assignment.getExpression().getReturnType()) { - case storm::expressions::ExpressionReturnType::Bool: newState->setBooleanValue(assignment.getVariableName(), assignment.getExpression().evaluateAsBool(baseState)); break; - case storm::expressions::ExpressionReturnType::Int: - { - newValue = assignment.getExpression().evaluateAsInt(baseState); - auto const& boundsPair = variableInformation.variableToBoundsMap.find(assignment.getVariableName()); - STORM_LOG_THROW(boundsPair->second.first <= newValue && newValue <= boundsPair->second.second, storm::exceptions::InvalidArgumentException, "Invalid value " << newValue << " for variable '" << assignment.getVariableName() << "'."); - newState->setIntegerValue(assignment.getVariableName(), newValue); break; - } - default: STORM_LOG_ASSERT(false, "Invalid type of assignment."); break; + if (assignment.getExpression().hasBooleanType()) { + newState->setBooleanValue(assignment.getVariable(), assignment.getExpression().evaluateAsBool(baseState)); + } else if (assignment.getExpression().hasIntegerType()) { + newValue = assignment.getExpression().evaluateAsInt(baseState); + auto const& boundsPair = variableInformation.variableToBoundsMap.find(assignment.getVariableName()); + STORM_LOG_THROW(boundsPair->second.first <= newValue && newValue <= boundsPair->second.second, storm::exceptions::InvalidArgumentException, "Invalid value " << newValue << " for variable '" << assignment.getVariableName() << "'."); + newState->setIntegerValue(assignment.getVariable(), newValue); + } else { + STORM_LOG_ASSERT(false, "Invalid type '" << assignment.getExpression().getType() << "' of assignment."); } } return newState; @@ -478,21 +477,7 @@ namespace storm { // Initialize a queue and insert the initial state. std::queue stateQueue; - StateType* initialState = new StateType; - for (auto const& booleanVariable : program.getGlobalBooleanVariables()) { - initialState->addBooleanIdentifier(booleanVariable.getName(), booleanVariable.getInitialValueExpression().evaluateAsBool()); - } - for (auto const& integerVariable : program.getGlobalIntegerVariables()) { - initialState->addIntegerIdentifier(integerVariable.getName(), integerVariable.getInitialValueExpression().evaluateAsInt()); - } - for (auto const& module : program.getModules()) { - for (auto const& booleanVariable : module.getBooleanVariables()) { - initialState->addBooleanIdentifier(booleanVariable.getName(), booleanVariable.getInitialValueExpression().evaluateAsBool()); - } - for (auto const& integerVariable : module.getIntegerVariables()) { - initialState->addIntegerIdentifier(integerVariable.getName(), integerVariable.getInitialValueExpression().evaluateAsInt()); - } - } + StateType* initialState = new StateType(program.getManager().getSharedPointer()); std::pair addIndexPair = getOrAddStateIndex(initialState, stateInformation); stateInformation.initialStateIndices.push_back(addIndexPair.second); diff --git a/src/adapters/MathsatExpressionAdapter.h b/src/adapters/MathsatExpressionAdapter.h index 3795ca965..854b20542 100644 --- a/src/adapters/MathsatExpressionAdapter.h +++ b/src/adapters/MathsatExpressionAdapter.h @@ -265,9 +265,9 @@ namespace storm { msat_decl msatDeclaration; if (variable.getType().isBooleanType()) { msatDeclaration = msat_declare_function(env, variable.getName().c_str(), msat_get_bool_type(env)); - } else if (variable.getType().isUnboundedIntegerType()) { + } else if (variable.getType().isIntegerType()) { msatDeclaration = msat_declare_function(env, variable.getName().c_str(), msat_get_integer_type(env)); - } else if (variable.getType().isBoundedIntegerType()) { + } else if (variable.getType().isBitVectorType()) { msatDeclaration = msat_declare_function(env, variable.getName().c_str(), msat_get_bv_type(env, variable.getType().getWidth())); } else if (variable.getType().isRationalType()) { msatDeclaration = msat_declare_function(env, variable.getName().c_str(), msat_get_rational_type(env)); diff --git a/src/adapters/Z3ExpressionAdapter.h b/src/adapters/Z3ExpressionAdapter.h index a0b671d38..419bab938 100644 --- a/src/adapters/Z3ExpressionAdapter.h +++ b/src/adapters/Z3ExpressionAdapter.h @@ -274,13 +274,13 @@ namespace storm { case storm::expressions::UnaryNumericalFunctionExpression::OperatorType::Minus: return 0 - childResult; case storm::expressions::UnaryNumericalFunctionExpression::OperatorType::Floor: { - storm::expressions::Variable freshAuxiliaryVariable = manager.declareFreshAuxiliaryVariable(manager.getIntegerType()); + storm::expressions::Variable freshAuxiliaryVariable = manager.declareFreshVariable(manager.getIntegerType(), true); z3::expr floorVariable = context.int_const(freshAuxiliaryVariable.getName().c_str()); additionalAssertions.push_back(z3::expr(context, Z3_mk_int2real(context, floorVariable)) <= childResult && childResult < (z3::expr(context, Z3_mk_int2real(context, floorVariable)) + 1)); return floorVariable; } case storm::expressions::UnaryNumericalFunctionExpression::OperatorType::Ceil:{ - storm::expressions::Variable freshAuxiliaryVariable = manager.declareFreshAuxiliaryVariable(manager.getIntegerType()); + storm::expressions::Variable freshAuxiliaryVariable = manager.declareFreshVariable(manager.getIntegerType(), true); z3::expr ceilVariable = context.int_const(freshAuxiliaryVariable.getName().c_str()); additionalAssertions.push_back(z3::expr(context, Z3_mk_int2real(context, ceilVariable)) - 1 <= childResult && childResult < z3::expr(context, Z3_mk_int2real(context, ceilVariable))); return ceilVariable; @@ -310,9 +310,9 @@ namespace storm { z3::expr z3Variable(context); if (variable.getType().isBooleanType()) { z3Variable = context.bool_const(variable.getName().c_str()); - } else if (variable.getType().isUnboundedIntegerType()) { + } else if (variable.getType().isIntegerType()) { z3Variable = context.int_const(variable.getName().c_str()); - } else if (variable.getType().isBoundedIntegerType()) { + } else if (variable.getType().isBitVectorType()) { z3Variable = context.bv_const(variable.getName().c_str(), variable.getType().getWidth()); } else if (variable.getType().isRationalType()) { z3Variable = context.real_const(variable.getName().c_str()); diff --git a/src/counterexamples/MILPMinimalLabelSetGenerator.h b/src/counterexamples/MILPMinimalLabelSetGenerator.h index f3ea782dd..9be7e4333 100644 --- a/src/counterexamples/MILPMinimalLabelSetGenerator.h +++ b/src/counterexamples/MILPMinimalLabelSetGenerator.h @@ -67,13 +67,13 @@ namespace storm { * A helper struct capturing information about the variables of the MILP formulation. */ struct VariableInformation { - std::unordered_map labelToVariableMap; - std::unordered_map> stateToChoiceVariablesMap; - std::unordered_map initialStateToChoiceVariableMap; - std::unordered_map stateToProbabilityVariableMap; - std::string virtualInitialStateVariable; - std::unordered_map problematicStateToVariableMap; - std::unordered_map, std::string, PairHash> problematicTransitionToVariableMap; + std::unordered_map labelToVariableMap; + std::unordered_map> stateToChoiceVariablesMap; + std::unordered_map initialStateToChoiceVariableMap; + std::unordered_map stateToProbabilityVariableMap; + storm::expressions::Variable virtualInitialStateVariable; + std::unordered_map problematicStateToVariableMap; + std::unordered_map, storm::expressions::Variable, PairHash> problematicTransitionToVariableMap; uint_fast64_t numberOfVariables; VariableInformation() : numberOfVariables(0) {} @@ -168,15 +168,14 @@ namespace storm { * @param relevantLabels The set of relevant labels of the model. * @return A mapping from labels to variable indices. */ - static std::pair, uint_fast64_t> createLabelVariables(storm::solver::LpSolver& solver, boost::container::flat_set const& relevantLabels) { + static std::pair, uint_fast64_t> createLabelVariables(storm::solver::LpSolver& solver, boost::container::flat_set const& relevantLabels) { std::stringstream variableNameBuffer; - std::unordered_map resultingMap; + std::unordered_map resultingMap; for (auto const& label : relevantLabels) { variableNameBuffer.str(""); variableNameBuffer.clear(); variableNameBuffer << "label" << label; - resultingMap[label] = variableNameBuffer.str(); - solver.addBinaryVariable(resultingMap[label], 1); + resultingMap[label] = solver.addBinaryVariable(variableNameBuffer.str(), 1); } return std::make_pair(resultingMap, relevantLabels.size()); } @@ -189,10 +188,10 @@ namespace storm { * @param choiceInformation The information about the choices of the model. * @return A mapping from states to a list of choice variable indices. */ - static std::pair>, uint_fast64_t> createSchedulerVariables(storm::solver::LpSolver& solver, StateInformation const& stateInformation, ChoiceInformation const& choiceInformation) { + static std::pair>, uint_fast64_t> createSchedulerVariables(storm::solver::LpSolver& solver, StateInformation const& stateInformation, ChoiceInformation const& choiceInformation) { std::stringstream variableNameBuffer; uint_fast64_t numberOfVariablesCreated = 0; - std::unordered_map> resultingMap; + std::unordered_map> resultingMap; for (auto state : stateInformation.relevantStates) { resultingMap.emplace(state, std::list()); @@ -201,8 +200,7 @@ namespace storm { variableNameBuffer.str(""); variableNameBuffer.clear(); variableNameBuffer << "choice" << row << "in" << state; - resultingMap[state].push_back(variableNameBuffer.str()); - solver.addBinaryVariable(resultingMap[state].back()); + resultingMap[state].push_back(solver.addBinaryVariable(variableNameBuffer.str())); ++numberOfVariablesCreated; } } @@ -218,10 +216,10 @@ namespace storm { * @param stateInformation The information about the states of the model. * @return A mapping from initial states to choice variable indices. */ - static std::pair, uint_fast64_t> createInitialChoiceVariables(storm::solver::LpSolver& solver, storm::models::Mdp const& labeledMdp, StateInformation const& stateInformation) { + static std::pair, uint_fast64_t> createInitialChoiceVariables(storm::solver::LpSolver& solver, storm::models::Mdp const& labeledMdp, StateInformation const& stateInformation) { std::stringstream variableNameBuffer; uint_fast64_t numberOfVariablesCreated = 0; - std::unordered_map resultingMap; + std::unordered_map resultingMap; for (auto initialState : labeledMdp.getLabeledStates("init")) { // Only consider this initial state if it is relevant. @@ -229,8 +227,7 @@ namespace storm { variableNameBuffer.str(""); variableNameBuffer.clear(); variableNameBuffer << "init" << initialState; - resultingMap[initialState] = variableNameBuffer.str(); - solver.addBinaryVariable(resultingMap[initialState]); + resultingMap[initialState] = solver.addBinaryVariable(variableNameBuffer.str()); ++numberOfVariablesCreated; } } @@ -244,17 +241,16 @@ namespace storm { * @param stateInformation The information about the states in the model. * @return A mapping from states to the index of the corresponding probability variables. */ - static std::pair, uint_fast64_t> createProbabilityVariables(storm::solver::LpSolver& solver, StateInformation const& stateInformation) { + static std::pair, uint_fast64_t> createProbabilityVariables(storm::solver::LpSolver& solver, StateInformation const& stateInformation) { std::stringstream variableNameBuffer; uint_fast64_t numberOfVariablesCreated = 0; - std::unordered_map resultingMap; + std::unordered_map resultingMap; for (auto state : stateInformation.relevantStates) { variableNameBuffer.str(""); variableNameBuffer.clear(); variableNameBuffer << "p" << state; - resultingMap[state] = variableNameBuffer.str(); - solver.addBoundedContinuousVariable(resultingMap[state], 0, 1); + resultingMap[state] = solver.addBoundedContinuousVariable(variableNameBuffer.str(), 0, 1); ++numberOfVariablesCreated; } return std::make_pair(resultingMap, numberOfVariablesCreated); @@ -268,12 +264,11 @@ namespace storm { * label-minimal subsystem of maximal probability is computed. * @return The index of the variable for the probability of the virtual initial state. */ - static std::pair createVirtualInitialStateVariable(storm::solver::LpSolver& solver, bool maximizeProbability = false) { + static std::pair createVirtualInitialStateVariable(storm::solver::LpSolver& solver, bool maximizeProbability = false) { std::stringstream variableNameBuffer; variableNameBuffer << "pinit"; - std::string variableName = variableNameBuffer.str(); - solver.addBoundedContinuousVariable(variableName, 0, 1); - return std::make_pair(variableName, 1); + storm::expressions::Variable variable = solver.addBoundedContinuousVariable(variableNameBuffer.str(), 0, 1); + return std::make_pair(variable, 1); } /*! @@ -284,10 +279,10 @@ namespace storm { * @param stateInformation The information about the states in the model. * @return A mapping from problematic states to the index of the corresponding variables. */ - static std::pair, uint_fast64_t> createProblematicStateVariables(storm::solver::LpSolver& solver, storm::models::Mdp const& labeledMdp, StateInformation const& stateInformation, ChoiceInformation const& choiceInformation) { + static std::pair, uint_fast64_t> createProblematicStateVariables(storm::solver::LpSolver& solver, storm::models::Mdp const& labeledMdp, StateInformation const& stateInformation, ChoiceInformation const& choiceInformation) { std::stringstream variableNameBuffer; uint_fast64_t numberOfVariablesCreated = 0; - std::unordered_map resultingMap; + std::unordered_map resultingMap; for (auto state : stateInformation.problematicStates) { // First check whether there is not already a variable for this state and advance to the next state @@ -296,8 +291,7 @@ namespace storm { variableNameBuffer.str(""); variableNameBuffer.clear(); variableNameBuffer << "r" << state; - resultingMap[state] = variableNameBuffer.str(); - solver.addBoundedContinuousVariable(resultingMap[state], 0, 1); + resultingMap[state] = solver.addBoundedContinuousVariable(variableNameBuffer.str(), 0, 1); ++numberOfVariablesCreated; } @@ -309,8 +303,7 @@ namespace storm { variableNameBuffer.str(""); variableNameBuffer.clear(); variableNameBuffer << "r" << successorEntry.getColumn(); - resultingMap[successorEntry.getColumn()] = variableNameBuffer.str(); - solver.addBoundedContinuousVariable(resultingMap[successorEntry.getColumn()], 0, 1); + resultingMap[successorEntry.getColumn()] = solver.addBoundedContinuousVariable(variableNameBuffer.str(), 0, 1); ++numberOfVariablesCreated; } } @@ -329,10 +322,10 @@ namespace storm { * @param choiceInformation The information about the choices in the model. * @return A mapping from problematic choices to the index of the corresponding variables. */ - static std::pair, std::string, PairHash>, uint_fast64_t> createProblematicChoiceVariables(storm::solver::LpSolver& solver, storm::models::Mdp const& labeledMdp, StateInformation const& stateInformation, ChoiceInformation const& choiceInformation) { + static std::pair, storm::expressions::Variable, PairHash>, uint_fast64_t> createProblematicChoiceVariables(storm::solver::LpSolver& solver, storm::models::Mdp const& labeledMdp, StateInformation const& stateInformation, ChoiceInformation const& choiceInformation) { std::stringstream variableNameBuffer; uint_fast64_t numberOfVariablesCreated = 0; - std::unordered_map, std::string, PairHash> resultingMap; + std::unordered_map, storm::expressions::Variable, PairHash> resultingMap; for (auto state : stateInformation.problematicStates) { std::list const& relevantChoicesForState = choiceInformation.relevantChoicesForRelevantStates.at(state); @@ -342,8 +335,7 @@ namespace storm { variableNameBuffer.str(""); variableNameBuffer.clear(); variableNameBuffer << "t" << state << "to" << successorEntry.getColumn(); - resultingMap[std::make_pair(state, successorEntry.getColumn())] = variableNameBuffer.str(); - solver.addBinaryVariable(resultingMap[std::make_pair(state, successorEntry.getColumn())]); + resultingMap[std::make_pair(state, successorEntry.getColumn())] = solver.addBinaryVariable(variableNameBuffer.str()); ++numberOfVariablesCreated; } } @@ -367,43 +359,43 @@ namespace storm { VariableInformation result; // Create variables for involved labels. - std::pair, uint_fast64_t> labelVariableResult = createLabelVariables(solver, choiceInformation.allRelevantLabels); + std::pair, uint_fast64_t> labelVariableResult = createLabelVariables(solver, choiceInformation.allRelevantLabels); result.labelToVariableMap = std::move(labelVariableResult.first); result.numberOfVariables += labelVariableResult.second; LOG4CPLUS_DEBUG(logger, "Created variables for labels."); // Create scheduler variables for relevant states and their actions. - std::pair>, uint_fast64_t> schedulerVariableResult = createSchedulerVariables(solver, stateInformation, choiceInformation); + std::pair>, uint_fast64_t> schedulerVariableResult = createSchedulerVariables(solver, stateInformation, choiceInformation); result.stateToChoiceVariablesMap = std::move(schedulerVariableResult.first); result.numberOfVariables += schedulerVariableResult.second; LOG4CPLUS_DEBUG(logger, "Created variables for nondeterministic choices."); // Create scheduler variables for nondeterministically choosing an initial state. - std::pair, uint_fast64_t> initialChoiceVariableResult = createInitialChoiceVariables(solver, labeledMdp, stateInformation); + std::pair, uint_fast64_t> initialChoiceVariableResult = createInitialChoiceVariables(solver, labeledMdp, stateInformation); result.initialStateToChoiceVariableMap = std::move(initialChoiceVariableResult.first); result.numberOfVariables += initialChoiceVariableResult.second; LOG4CPLUS_DEBUG(logger, "Created variables for the nondeterministic choice of the initial state."); // Create variables for probabilities for all relevant states. - std::pair, uint_fast64_t> probabilityVariableResult = createProbabilityVariables(solver, stateInformation); + std::pair, uint_fast64_t> probabilityVariableResult = createProbabilityVariables(solver, stateInformation); result.stateToProbabilityVariableMap = std::move(probabilityVariableResult.first); result.numberOfVariables += probabilityVariableResult.second; LOG4CPLUS_DEBUG(logger, "Created variables for the reachability probabilities."); // Create a probability variable for a virtual initial state that nondeterministically chooses one of the system's real initial states as its target state. - std::pair virtualInitialStateVariableResult = createVirtualInitialStateVariable(solver); + std::pair virtualInitialStateVariableResult = createVirtualInitialStateVariable(solver); result.virtualInitialStateVariable = virtualInitialStateVariableResult.first; result.numberOfVariables += virtualInitialStateVariableResult.second; LOG4CPLUS_DEBUG(logger, "Created variables for the virtual initial state."); // Create variables for problematic states. - std::pair, uint_fast64_t> problematicStateVariableResult = createProblematicStateVariables(solver, labeledMdp, stateInformation, choiceInformation); + std::pair, uint_fast64_t> problematicStateVariableResult = createProblematicStateVariables(solver, labeledMdp, stateInformation, choiceInformation); result.problematicStateToVariableMap = std::move(problematicStateVariableResult.first); result.numberOfVariables += problematicStateVariableResult.second; LOG4CPLUS_DEBUG(logger, "Created variables for the problematic states."); // Create variables for problematic choices. - std::pair, std::string, PairHash>, uint_fast64_t> problematicTransitionVariableResult = createProblematicChoiceVariables(solver, labeledMdp, stateInformation, choiceInformation); + std::pair, storm::expressions::Variable, PairHash>, uint_fast64_t> problematicTransitionVariableResult = createProblematicChoiceVariables(solver, labeledMdp, stateInformation, choiceInformation); result.problematicTransitionToVariableMap = problematicTransitionVariableResult.first; result.numberOfVariables += problematicTransitionVariableResult.second; LOG4CPLUS_DEBUG(logger, "Created variables for the problematic choices."); @@ -430,9 +422,9 @@ namespace storm { static uint_fast64_t assertProbabilityGreaterThanThreshold(storm::solver::LpSolver& solver, storm::models::Mdp const& labeledMdp, VariableInformation const& variableInformation, double probabilityThreshold, bool strictBound) { storm::expressions::Expression constraint; if (strictBound) { - constraint = storm::expressions::Expression::createDoubleVariable(variableInformation.virtualInitialStateVariable) > storm::expressions::Expression::createDoubleLiteral(probabilityThreshold); + constraint = variableInformation.virtualInitialStateVariable > solver.getConstant(probabilityThreshold); } else { - constraint = storm::expressions::Expression::createDoubleVariable(variableInformation.virtualInitialStateVariable) >= storm::expressions::Expression::createDoubleLiteral(probabilityThreshold); + constraint = variableInformation.virtualInitialStateVariable >= solver.getConstant(probabilityThreshold); } solver.addConstraint("ProbGreaterThreshold", constraint); return 1; @@ -450,14 +442,14 @@ namespace storm { // Assert that the policy chooses at most one action in each state of the system. uint_fast64_t numberOfConstraintsCreated = 0; for (auto state : stateInformation.relevantStates) { - std::list const& choiceVariableIndices = variableInformation.stateToChoiceVariablesMap.at(state); - storm::expressions::Expression constraint = storm::expressions::Expression::createDoubleLiteral(0); + std::list const& choiceVariableIndices = variableInformation.stateToChoiceVariablesMap.at(state); + storm::expressions::Expression constraint = solver.getConstant(0); for (auto const& choiceVariable : choiceVariableIndices) { - constraint = constraint + storm::expressions::Expression::createIntegerVariable(choiceVariable); + constraint = constraint + choiceVariable; } - constraint = constraint <= storm::expressions::Expression::createDoubleLiteral(1); + constraint = constraint <= solver.getConstant(1); solver.addConstraint("ValidPolicy" + std::to_string(numberOfConstraintsCreated), constraint); ++numberOfConstraintsCreated; @@ -465,11 +457,11 @@ namespace storm { // Now assert that the virtual initial state picks exactly one initial state from the system as its // successor state. - storm::expressions::Expression constraint = storm::expressions::Expression::createDoubleLiteral(0); + storm::expressions::Expression constraint = solver.getConstant(0); for (auto const& initialStateVariablePair : variableInformation.initialStateToChoiceVariableMap) { - constraint = constraint + storm::expressions::Expression::createIntegerVariable(initialStateVariablePair.second); + constraint = constraint + initialStateVariablePair.second; } - constraint = constraint == storm::expressions::Expression::createDoubleLiteral(1); + constraint = constraint == solver.getConstant(1); solver.addConstraint("VirtualInitialStateChoosesOneInitialState", constraint); ++numberOfConstraintsCreated; @@ -493,10 +485,10 @@ namespace storm { std::vector> const& choiceLabeling = labeledMdp.getChoiceLabeling(); for (auto state : stateInformation.relevantStates) { - std::list::const_iterator choiceVariableIterator = variableInformation.stateToChoiceVariablesMap.at(state).begin(); + std::list::const_iterator choiceVariableIterator = variableInformation.stateToChoiceVariablesMap.at(state).begin(); for (auto choice : choiceInformation.relevantChoicesForRelevantStates.at(state)) { for (auto label : choiceLabeling[choice]) { - storm::expressions::Expression constraint = storm::expressions::Expression::createIntegerVariable(variableInformation.labelToVariableMap.at(label)) - storm::expressions::Expression::createIntegerVariable(*choiceVariableIterator) >= storm::expressions::Expression::createDoubleLiteral(0); + storm::expressions::Expression constraint = variableInformation.labelToVariableMap.at(label) - *choiceVariableIterator >= solver.getConstant(0); solver.addConstraint("ChoicesImplyLabels" + std::to_string(numberOfConstraintsCreated), constraint); ++numberOfConstraintsCreated; } @@ -519,11 +511,11 @@ namespace storm { static uint_fast64_t assertZeroProbabilityWithoutChoice(storm::solver::LpSolver& solver, StateInformation const& stateInformation, ChoiceInformation const& choiceInformation, VariableInformation const& variableInformation) { uint_fast64_t numberOfConstraintsCreated = 0; for (auto state : stateInformation.relevantStates) { - storm::expressions::Expression constraint = storm::expressions::Expression::createDoubleVariable(variableInformation.stateToProbabilityVariableMap.at(state)); + storm::expressions::Expression constraint = variableInformation.stateToProbabilityVariableMap.at(state); for (auto const& choiceVariable : variableInformation.stateToChoiceVariablesMap.at(state)) { - constraint = constraint - storm::expressions::Expression::createIntegerVariable(choiceVariable); + constraint = constraint - choiceVariable; } - constraint = constraint <= storm::expressions::Expression::createDoubleLiteral(0); + constraint = constraint <= solver.getConstant(0); solver.addConstraint("ProbabilityIsZeroIfNoAction" + std::to_string(numberOfConstraintsCreated), constraint); ++numberOfConstraintsCreated; } @@ -544,20 +536,20 @@ namespace storm { static uint_fast64_t assertReachabilityProbabilities(storm::solver::LpSolver& solver, storm::models::Mdp const& labeledMdp, storm::storage::BitVector const& psiStates, StateInformation const& stateInformation, ChoiceInformation const& choiceInformation, VariableInformation const& variableInformation) { uint_fast64_t numberOfConstraintsCreated = 0; for (auto state : stateInformation.relevantStates) { - std::list::const_iterator choiceVariableIterator = variableInformation.stateToChoiceVariablesMap.at(state).begin(); + std::list::const_iterator choiceVariableIterator = variableInformation.stateToChoiceVariablesMap.at(state).begin(); for (auto choice : choiceInformation.relevantChoicesForRelevantStates.at(state)) { - storm::expressions::Expression constraint = storm::expressions::Expression::createDoubleVariable(variableInformation.stateToProbabilityVariableMap.at(state)); + storm::expressions::Expression constraint = variableInformation.stateToProbabilityVariableMap.at(state); double rightHandSide = 1; for (auto const& successorEntry : labeledMdp.getTransitionMatrix().getRow(choice)) { if (stateInformation.relevantStates.get(successorEntry.getColumn())) { - constraint = constraint - storm::expressions::Expression::createDoubleLiteral(successorEntry.getValue()) * storm::expressions::Expression::createDoubleVariable(variableInformation.stateToProbabilityVariableMap.at(successorEntry.getColumn())); + constraint = constraint - solver.getConstant(successorEntry.getValue()) * variableInformation.stateToProbabilityVariableMap.at(successorEntry.getColumn()); } else if (psiStates.get(successorEntry.getColumn())) { rightHandSide += successorEntry.getValue(); } } - constraint = constraint + storm::expressions::Expression::createIntegerVariable(*choiceVariableIterator) <= storm::expressions::Expression::createDoubleLiteral(rightHandSide); + constraint = constraint + *choiceVariableIterator <= solver.getConstant(rightHandSide); solver.addConstraint("ReachabilityProbabilities" + std::to_string(numberOfConstraintsCreated), constraint); ++numberOfConstraintsCreated; @@ -568,7 +560,7 @@ namespace storm { // Make sure that the virtual initial state is being assigned the probability from the initial state // that it selected as a successor state. for (auto const& initialStateVariablePair : variableInformation.initialStateToChoiceVariableMap) { - storm::expressions::Expression constraint = storm::expressions::Expression::createDoubleVariable(variableInformation.virtualInitialStateVariable) - storm::expressions::Expression::createDoubleVariable(variableInformation.stateToProbabilityVariableMap.at(initialStateVariablePair.first)) + storm::expressions::Expression::createDoubleVariable(initialStateVariablePair.second) <= storm::expressions::Expression::createDoubleLiteral(1); + storm::expressions::Expression constraint = variableInformation.virtualInitialStateVariable - variableInformation.stateToProbabilityVariableMap.at(initialStateVariablePair.first) + initialStateVariablePair.second <= solver.getConstant(1); solver.addConstraint("VirtualInitialStateHasCorrectProbability" + std::to_string(numberOfConstraintsCreated), constraint); ++numberOfConstraintsCreated; } @@ -591,7 +583,7 @@ namespace storm { for (auto stateListPair : choiceInformation.problematicChoicesForProblematicStates) { for (auto problematicChoice : stateListPair.second) { - std::list::const_iterator choiceVariableIterator = variableInformation.stateToChoiceVariablesMap.at(stateListPair.first).begin(); + std::list::const_iterator choiceVariableIterator = variableInformation.stateToChoiceVariablesMap.at(stateListPair.first).begin(); for (auto relevantChoice : choiceInformation.relevantChoicesForRelevantStates.at(stateListPair.first)) { if (relevantChoice == problematicChoice) { break; @@ -599,11 +591,11 @@ namespace storm { ++choiceVariableIterator; } - storm::expressions::Expression constraint = storm::expressions::Expression::createDoubleVariable(*choiceVariableIterator); + storm::expressions::Expression constraint = *choiceVariableIterator; for (auto const& successorEntry : labeledMdp.getTransitionMatrix().getRow(problematicChoice)) { - constraint = constraint - storm::expressions::Expression::createDoubleVariable(variableInformation.problematicTransitionToVariableMap.at(std::make_pair(stateListPair.first, successorEntry.getColumn()))); + constraint = constraint - variableInformation.problematicTransitionToVariableMap.at(std::make_pair(stateListPair.first, successorEntry.getColumn())); } - constraint = constraint <= storm::expressions::Expression::createDoubleLiteral(0); + constraint = constraint <= solver.getConstant(0); solver.addConstraint("UnproblematicStateReachable" + std::to_string(numberOfConstraintsCreated), constraint); ++numberOfConstraintsCreated; @@ -613,10 +605,10 @@ namespace storm { for (auto state : stateInformation.problematicStates) { for (auto problematicChoice : choiceInformation.problematicChoicesForProblematicStates.at(state)) { for (auto const& successorEntry : labeledMdp.getTransitionMatrix().getRow(problematicChoice)) { - storm::expressions::Expression constraint = storm::expressions::Expression::createDoubleVariable(variableInformation.problematicStateToVariableMap.at(state)); - constraint = constraint - storm::expressions::Expression::createDoubleVariable(variableInformation.problematicStateToVariableMap.at(successorEntry.getColumn())); - constraint = constraint + storm::expressions::Expression::createDoubleVariable(variableInformation.problematicTransitionToVariableMap.at(std::make_pair(state, successorEntry.getColumn()))); - constraint = constraint < storm::expressions::Expression::createDoubleLiteral(1); + storm::expressions::Expression constraint = variableInformation.problematicStateToVariableMap.at(state); + constraint = constraint - variableInformation.problematicStateToVariableMap.at(successorEntry.getColumn()); + constraint = constraint + variableInformation.problematicTransitionToVariableMap.at(std::make_pair(state, successorEntry.getColumn())); + constraint = constraint < solver.getConstant(1); solver.addConstraint("UnproblematicStateReachable" + std::to_string(numberOfConstraintsCreated), constraint); ++numberOfConstraintsCreated; @@ -640,7 +632,7 @@ namespace storm { uint_fast64_t numberOfConstraintsCreated = 0; for (auto label : choiceInformation.knownLabels) { - storm::expressions::Expression constraint = storm::expressions::Expression::createDoubleVariable(variableInformation.labelToVariableMap.at(label)) == storm::expressions::Expression::createDoubleLiteral(1); + storm::expressions::Expression constraint = variableInformation.labelToVariableMap.at(label) == solver.getConstant(0); solver.addConstraint("KnownLabels" + std::to_string(numberOfConstraintsCreated), constraint); ++numberOfConstraintsCreated; } @@ -666,7 +658,7 @@ namespace storm { for (auto state : stateInformation.relevantStates) { // Assert that all states, that select an action, this action either has a non-zero probability to // go to a psi state directly, or in the successor states, at least one action is selected as well. - std::list::const_iterator choiceVariableIterator = variableInformation.stateToChoiceVariablesMap.at(state).begin(); + std::list::const_iterator choiceVariableIterator = variableInformation.stateToChoiceVariablesMap.at(state).begin(); for (auto choice : choiceInformation.relevantChoicesForRelevantStates.at(state)) { bool psiStateReachableInOneStep = false; for (auto const& successorEntry : labeledMdp.getTransitionMatrix().getRow(choice)) { @@ -676,17 +668,17 @@ namespace storm { } if (!psiStateReachableInOneStep) { - storm::expressions::Expression constraint = storm::expressions::Expression::createDoubleVariable(*choiceVariableIterator); + storm::expressions::Expression constraint = *choiceVariableIterator; for (auto const& successorEntry : labeledMdp.getTransitionMatrix().getRow(choice)) { if (state != successorEntry.getColumn() && stateInformation.relevantStates.get(successorEntry.getColumn())) { - std::list const& successorChoiceVariableIndices = variableInformation.stateToChoiceVariablesMap.at(successorEntry.getColumn()); + std::list const& successorChoiceVariableIndices = variableInformation.stateToChoiceVariablesMap.at(successorEntry.getColumn()); for (auto const& choiceVariable : successorChoiceVariableIndices) { - constraint = constraint - storm::expressions::Expression::createDoubleVariable(choiceVariable); + constraint = constraint - choiceVariable; } } } - constraint = constraint <= storm::expressions::Expression::createDoubleLiteral(1); + constraint = constraint <= solver.getConstant(1); solver.addConstraint("SchedulerCuts" + std::to_string(numberOfConstraintsCreated), constraint); ++numberOfConstraintsCreated; @@ -697,10 +689,10 @@ namespace storm { // For all states assert that there is either a selected incoming transition in the subsystem or the // state is the chosen initial state if there is one selected action in the current state. - storm::expressions::Expression constraint = storm::expressions::Expression::createDoubleLiteral(0); + storm::expressions::Expression constraint = solver.getConstant(0); for (auto const& choiceVariable : variableInformation.stateToChoiceVariablesMap.at(state)) { - constraint = constraint + storm::expressions::Expression::createDoubleVariable(choiceVariable); + constraint = constraint + choiceVariable; } // Compute the set of predecessors. @@ -717,7 +709,7 @@ namespace storm { continue; } - std::list::const_iterator choiceVariableIterator = variableInformation.stateToChoiceVariablesMap.at(predecessor).begin(); + std::list::const_iterator choiceVariableIterator = variableInformation.stateToChoiceVariablesMap.at(predecessor).begin(); for (auto relevantChoice : choiceInformation.relevantChoicesForRelevantStates.at(predecessor)) { bool choiceTargetsCurrentState = false; @@ -731,7 +723,7 @@ namespace storm { // If it does, we can add the choice to the sum. if (choiceTargetsCurrentState) { - constraint = constraint - storm::expressions::Expression::createDoubleVariable(*choiceVariableIterator); + constraint = constraint - *choiceVariableIterator; } ++choiceVariableIterator; } @@ -740,27 +732,27 @@ namespace storm { // If the current state is an initial state and is selected as a successor state by the virtual // initial state, then this also justifies making a choice in the current state. if (labeledMdp.getLabeledStates("init").get(state)) { - constraint = constraint - storm::expressions::Expression::createDoubleVariable(variableInformation.initialStateToChoiceVariableMap.at(state)); + constraint = constraint - variableInformation.initialStateToChoiceVariableMap.at(state); } - constraint = constraint <= storm::expressions::Expression::createDoubleLiteral(0); + constraint = constraint <= solver.getConstant(0); solver.addConstraint("SchedulerCuts" + std::to_string(numberOfConstraintsCreated), constraint); ++numberOfConstraintsCreated; } // Assert that at least one initial state selects at least one action. - storm::expressions::Expression constraint = storm::expressions::Expression::createDoubleLiteral(0); + storm::expressions::Expression constraint = solver.getConstant(0); for (auto initialState : labeledMdp.getLabeledStates("init")) { for (auto const& choiceVariable : variableInformation.stateToChoiceVariablesMap.at(initialState)) { - constraint = constraint + storm::expressions::Expression::createDoubleVariable(choiceVariable); + constraint = constraint + choiceVariable; } } - constraint = constraint >= storm::expressions::Expression::createDoubleLiteral(1); + constraint = constraint >= solver.getConstant(1); solver.addConstraint("SchedulerCuts" + std::to_string(numberOfConstraintsCreated), constraint); ++numberOfConstraintsCreated; // Add constraints that ensure at least one choice is selected that targets a psi state. - constraint = storm::expressions::Expression::createDoubleLiteral(0); + constraint = solver.getConstant(0); std::unordered_set predecessors; for (auto psiState : psiStates) { // Compute the set of predecessors. @@ -777,7 +769,7 @@ namespace storm { continue; } - std::list::const_iterator choiceVariableIterator = variableInformation.stateToChoiceVariablesMap.at(predecessor).begin(); + std::list::const_iterator choiceVariableIterator = variableInformation.stateToChoiceVariablesMap.at(predecessor).begin(); for (auto relevantChoice : choiceInformation.relevantChoicesForRelevantStates.at(predecessor)) { bool choiceTargetsPsiState = false; @@ -791,12 +783,12 @@ namespace storm { // If it does, we can add the choice to the sum. if (choiceTargetsPsiState) { - constraint = constraint + storm::expressions::Expression::createDoubleVariable(*choiceVariableIterator); + constraint = constraint + *choiceVariableIterator; } ++choiceVariableIterator; } } - constraint = constraint >= storm::expressions::Expression::createDoubleLiteral(1); + constraint = constraint >= solver.getConstant(1); solver.addConstraint("SchedulerCuts" + std::to_string(numberOfConstraintsCreated), constraint); ++numberOfConstraintsCreated; @@ -894,7 +886,7 @@ namespace storm { std::map result; for (auto state : stateInformation.relevantStates) { - std::list::const_iterator choiceVariableIterator = variableInformation.stateToChoiceVariablesIndexMap.at(state).begin(); + std::list::const_iterator choiceVariableIterator = variableInformation.stateToChoiceVariablesIndexMap.at(state).begin(); for (auto choice : choiceInformation.relevantChoicesForRelevantStates.at(state)) { bool choiceTaken = solver.getBinaryValue(*choiceVariableIterator); ++choiceVariableIterator; diff --git a/src/counterexamples/SMTMinimalCommandSetGenerator.h b/src/counterexamples/SMTMinimalCommandSetGenerator.h index 3f8748e44..891d73ad2 100644 --- a/src/counterexamples/SMTMinimalCommandSetGenerator.h +++ b/src/counterexamples/SMTMinimalCommandSetGenerator.h @@ -1,24 +1,10 @@ -/* - * SMTMinimalCommandSetGenerator.h - * - * Created on: 01.10.2013 - * Author: Christian Dehnert - */ - #ifndef STORM_COUNTEREXAMPLES_SMTMINIMALCOMMANDSETGENERATOR_MDP_H_ #define STORM_COUNTEREXAMPLES_SMTMINIMALCOMMANDSETGENERATOR_MDP_H_ #include #include -// To detect whether the usage of Z3 is possible, this include is neccessary. -#include "storm-config.h" - -// If we have Z3 available, we have to include the C++ header. -#ifdef STORM_HAVE_Z3 -#include "z3++.h" -#include "src/adapters/Z3ExpressionAdapter.h" -#endif +#include "src/solver/Z3SmtSolver.h" #include "src/storage/prism/Program.h" #include "src/storage/expressions/Expression.h" @@ -55,20 +41,23 @@ namespace storm { }; struct VariableInformation { + // The manager responsible for the constraints we are building. + std::shared_ptr manager; + // The variables associated with the relevant labels. - std::vector labelVariables; + std::vector labelVariables; // A mapping from relevant labels to their indices in the variable vector. std::map labelToIndexMap; // A set of original auxiliary variables needed for the Fu-Malik procedure. - std::vector originalAuxiliaryVariables; + std::vector originalAuxiliaryVariables; // A set of auxiliary variables that may be modified by the MaxSAT procedure. - std::vector auxiliaryVariables; + std::vector auxiliaryVariables; // A vector of variables that can be used to constrain the number of variables that are set to true. - std::vector adderVariables; + std::vector adderVariables; // A flag whether or not there are variables reserved for encoding reachability of a target state. bool hasReachabilityVariables; @@ -80,13 +69,13 @@ namespace storm { // A vector of variables associated with each pair of relevant states (s, s') such that s' is // a successor of s. - std::vector statePairVariables; + std::vector statePairVariables; // A mapping from relevant states to the index with the corresponding order variable in the state order variable vector. std::map relevantStatesToOrderVariableIndexMap; // A vector of variables that holds all state order variables. - std::vector stateOrderVariables; + std::vector stateOrderVariables; }; /*! @@ -155,14 +144,15 @@ namespace storm { } /*! - * Creates all necessary base expressions for the relevant labels. + * Creates all necessary variables. * - * @param context The Z3 context in which to create the expressions. + * @param manager The manager in which to create the variables. * @param relevantCommands A set of relevant labels for which to create the expressions. * @return A mapping from relevant labels to their corresponding expressions. */ - static VariableInformation createVariables(z3::context& context, storm::models::Mdp const& labeledMdp, storm::storage::BitVector const& psiStates, RelevancyInformation const& relevancyInformation, bool createReachabilityVariables) { + static VariableInformation createVariables(std::shared_ptr const& manager, storm::models::Mdp const& labeledMdp, storm::storage::BitVector const& psiStates, RelevancyInformation const& relevancyInformation, bool createReachabilityVariables) { VariableInformation variableInformation; + variableInformation.manager = manager; // Create stringstream to build expression names. std::stringstream variableName; @@ -176,14 +166,14 @@ namespace storm { variableName.str(""); variableName << "c" << label; - variableInformation.labelVariables.push_back(context.bool_const(variableName.str().c_str())); + variableInformation.labelVariables.push_back(manager->declareBooleanVariable(variableName.str())); // Clear contents of the stream to construct new expression name. variableName.clear(); variableName.str(""); variableName << "h" << label; - variableInformation.originalAuxiliaryVariables.push_back(context.bool_const(variableName.str().c_str())); + variableInformation.originalAuxiliaryVariables.push_back(manager->declareBooleanVariable(variableName.str())); } // A mapping from each pair of adjacent relevant states to their index in the corresponding variable vector. @@ -191,7 +181,7 @@ namespace storm { // A vector of variables associated with each pair of relevant states (s, s') such that s' is // a successor of s. - std::vector statePairVariables; + std::vector statePairVariables; // Create variables needed for encoding reachability of a target state if requested. if (createReachabilityVariables) { @@ -208,7 +198,7 @@ namespace storm { variableName.str(""); variableName << "o" << state; - variableInformation.stateOrderVariables.push_back(context.real_const(variableName.str().c_str())); + variableInformation.stateOrderVariables.push_back(manager->declareRationalVariable(variableName.str())); for (auto const& relevantChoice : relevancyInformation.relevantChoicesForRelevantStates.at(state)) { for (auto const& entry : transitionMatrix.getRow(relevantChoice)) { @@ -230,7 +220,7 @@ namespace storm { variableName.str(""); variableName << "t" << state << "_" << entry.getColumn(); - variableInformation.statePairVariables.push_back(context.bool_const(variableName.str().c_str())); + variableInformation.statePairVariables.push_back(manager->declareBooleanVariable(variableName.str())); } } } @@ -244,7 +234,7 @@ namespace storm { variableName.str(""); variableName << "o" << psiState; - variableInformation.stateOrderVariables.push_back(context.real_const(variableName.str().c_str())); + variableInformation.stateOrderVariables.push_back(manager->declareRationalVariable(variableName.str())); } } @@ -277,7 +267,7 @@ namespace storm { * @param context The Z3 context in which to build the expressions. * @param solver The solver to use for the satisfiability evaluation. */ - static void assertExplicitCuts(storm::models::Mdp const& labeledMdp, storm::storage::BitVector const& psiStates, VariableInformation const& variableInformation, RelevancyInformation const& relevancyInformation, z3::context& context, z3::solver& solver) { + static void assertExplicitCuts(storm::models::Mdp const& labeledMdp, storm::storage::BitVector const& psiStates, VariableInformation const& variableInformation, RelevancyInformation const& relevancyInformation, storm::solver::SmtSolver& solver) { // Walk through the MDP and // * identify labels enabled in initial states // * identify labels that can directly precede a given action @@ -364,7 +354,7 @@ namespace storm { LOG4CPLUS_DEBUG(logger, "Asserting initial combination is taken."); { - std::vector formulae; + std::vector formulae; // Start by asserting that we take at least one initial label. We may do so only if there is no initial // combination that is already known. Otherwise this condition would be too strong. @@ -376,7 +366,7 @@ namespace storm { initialCombinationKnown = true; break; } else { - z3::expr conj = context.bool_val(true); + storm::expressions::Expression conj = variableInformation.manager->boolean(true); for (auto label : tmpSet) { conj = conj && variableInformation.labelVariables.at(variableInformation.labelToIndexMap.at(label)); } @@ -384,13 +374,13 @@ namespace storm { } } if (!initialCombinationKnown) { - assertDisjunction(context, solver, formulae); + assertDisjunction(solver, formulae, *variableInformation.manager); } } LOG4CPLUS_DEBUG(logger, "Asserting target combination is taken."); { - std::vector formulae; + std::vector formulae; // Likewise, if no target combination is known, we may assert that there is at least one. bool targetCombinationKnown = false; @@ -401,7 +391,7 @@ namespace storm { targetCombinationKnown = true; break; } else { - z3::expr conj = context.bool_val(true); + storm::expressions::Expression conj = variableInformation.manager->boolean(true); for (auto label : tmpSet) { conj = conj && variableInformation.labelVariables.at(variableInformation.labelToIndexMap.at(label)); } @@ -409,7 +399,7 @@ namespace storm { } } if (!targetCombinationKnown) { - assertDisjunction(context, solver, formulae); + assertDisjunction(solver, formulae, *variableInformation.manager); } } @@ -427,7 +417,7 @@ namespace storm { LOG4CPLUS_DEBUG(logger, "Asserting taken labels are followed by another label if they are not a target label."); // Now assert that for each non-target label, we take a following label. for (auto const& labelSetFollowingSetsPair : followingLabels) { - std::vector formulae; + std::vector formulae; // Only build a constraint if the combination does not lead to a target state and // no successor set is already known. @@ -448,7 +438,7 @@ namespace storm { std::set_difference(followingSet.begin(), followingSet.end(), relevancyInformation.knownLabels.begin(), relevancyInformation.knownLabels.end(), std::inserter(tmpSet, tmpSet.end())); // Construct an expression that enables all unknown labels of the current following set. - z3::expr conj = context.bool_val(true); + storm::expressions::Expression conj = variableInformation.manager->boolean(true); for (auto label : tmpSet) { conj = conj && variableInformation.labelVariables.at(variableInformation.labelToIndexMap.at(label)); } @@ -460,13 +450,13 @@ namespace storm { // This is because it could be that the commands are taken to enable other synchronizations. Therefore, we need // to add an additional clause that says that the right-hand side of the implication is also true if all commands // of the current choice have enabled synchronization options. - z3::expr finalDisjunct = context.bool_val(false); + storm::expressions::Expression finalDisjunct = variableInformation.manager->boolean(false); for (auto label : labelSetFollowingSetsPair.first) { - z3::expr alternativeExpressionForLabel = context.bool_val(false); + storm::expressions::Expression alternativeExpressionForLabel = variableInformation.manager->boolean(false); std::set> const& synchsForCommand = synchronizingLabels.at(label); for (auto const& synchSet : synchsForCommand) { - z3::expr alternativeExpression = context.bool_val(true); + storm::expressions::Expression alternativeExpression = variableInformation.manager->boolean(true); // If the current synchSet is the same as left-hand side of the implication, we need to skip it. if (synchSet == labelSetFollowingSetsPair.first) continue; @@ -483,9 +473,9 @@ namespace storm { alternativeExpression = alternativeExpression && variableInformation.labelVariables.at(variableInformation.labelToIndexMap.at(label)); } - z3::expr disjunctionOverSuccessors = context.bool_val(false); + storm::expressions::Expression disjunctionOverSuccessors = variableInformation.manager->boolean(false); for (auto successorSet : followingLabels.at(synchSet)) { - z3::expr conjunctionOverLabels = context.bool_val(true); + storm::expressions::Expression conjunctionOverLabels = variableInformation.manager->boolean(true); for (auto label : successorSet) { if (relevancyInformation.knownLabels.find(label) == relevancyInformation.knownLabels.end()) { conjunctionOverLabels = conjunctionOverLabels && variableInformation.labelVariables.at(variableInformation.labelToIndexMap.at(label)); @@ -507,7 +497,7 @@ namespace storm { } if (formulae.size() > 0) { - assertDisjunction(context, solver, formulae); + assertDisjunction(solver, formulae, *variableInformation.manager); } } } @@ -516,7 +506,7 @@ namespace storm { // Finally, assert that if we take one of the synchronizing labels, we also take one of the combinations // the label appears in. for (auto const& labelSynchronizingSetsPair : synchronizingLabels) { - std::vector formulae; + std::vector formulae; if (relevancyInformation.knownLabels.find(labelSynchronizingSetsPair.first) == relevancyInformation.knownLabels.end()) { formulae.push_back(!variableInformation.labelVariables.at(variableInformation.labelToIndexMap.at(labelSynchronizingSetsPair.first))); @@ -526,7 +516,7 @@ namespace storm { // known, which means we must not assert anything. bool allImplicantsKnownForOneSet = false; for (auto const& synchronizingSet : labelSynchronizingSetsPair.second) { - z3::expr currentCombination = context.bool_val(true); + storm::expressions::Expression currentCombination = variableInformation.manager->boolean(true); bool allImplicantsKnownForCurrentSet = true; for (auto label : synchronizingSet) { if (relevancyInformation.knownLabels.find(label) == relevancyInformation.knownLabels.end() && label != labelSynchronizingSetsPair.first) { @@ -543,7 +533,7 @@ namespace storm { } if (!allImplicantsKnownForOneSet) { - assertDisjunction(context, solver, formulae); + assertDisjunction(solver, formulae, *variableInformation.manager); } } } @@ -553,10 +543,9 @@ namespace storm { * suboptimal solutions. * * @param program The symbolic representation of the model in terms of a program. - * @param context The Z3 context in which to build the expressions. * @param solver The solver to use for the satisfiability evaluation. */ - static void assertSymbolicCuts(storm::prism::Program const& program, storm::models::Mdp const& labeledMdp, VariableInformation const& variableInformation, RelevancyInformation const& relevancyInformation, z3::context& context, z3::solver& solver) { + static void assertSymbolicCuts(storm::prism::Program& program, storm::models::Mdp const& labeledMdp, VariableInformation const& variableInformation, RelevancyInformation const& relevancyInformation, storm::solver::SmtSolver& solver) { // A container storing the label sets that may precede a given label set. std::map, std::set>> precedingLabelSets; @@ -600,45 +589,42 @@ namespace storm { } } - // Create a context and register all variables of the program with their correct type. - z3::context localContext; - z3::solver localSolver(localContext); - std::map solverVariables; - for (auto const& booleanVariable : program.getGlobalBooleanVariables()) { - solverVariables.emplace(booleanVariable.getName(), localContext.bool_const(booleanVariable.getName().c_str())); - } - for (auto const& integerVariable : program.getGlobalIntegerVariables()) { - solverVariables.emplace(integerVariable.getName(), localContext.int_const(integerVariable.getName().c_str())); - } - - for (auto const& module : program.getModules()) { - for (auto const& booleanVariable : module.getBooleanVariables()) { - solverVariables.emplace(booleanVariable.getName(), localContext.bool_const(booleanVariable.getName().c_str())); - } - for (auto const& integerVariable : module.getIntegerVariables()) { - solverVariables.emplace(integerVariable.getName(), localContext.int_const(integerVariable.getName().c_str())); - } - } - - storm::adapters::Z3ExpressionAdapter expressionAdapter(localContext, false, solverVariables); + // Create a new solver over the same variables as the given program to use it for determining the symbolic + // cuts. + std::unique_ptr localSolver(new storm::solver::Z3SmtSolver(program.getManager())); + storm::expressions::ExpressionManager const& localManager = program.getManager(); +// +// // Create a context and register all variables of the program with their correct type. +// z3::context localContext; +// z3::solver localSolver(localContext); +// std::map solverVariables; +// for (auto const& booleanVariable : program.getGlobalBooleanVariables()) { +// solverVariables.emplace(booleanVariable.getName(), localContext.bool_const(booleanVariable.getName().c_str())); +// } +// for (auto const& integerVariable : program.getGlobalIntegerVariables()) { +// solverVariables.emplace(integerVariable.getName(), localContext.int_const(integerVariable.getName().c_str())); +// } +// +// for (auto const& module : program.getModules()) { +// for (auto const& booleanVariable : module.getBooleanVariables()) { +// solverVariables.emplace(booleanVariable.getName(), localContext.bool_const(booleanVariable.getName().c_str())); +// } +// for (auto const& integerVariable : module.getIntegerVariables()) { +// solverVariables.emplace(integerVariable.getName(), localContext.int_const(integerVariable.getName().c_str())); +// } +// } +// +// storm::adapters::Z3ExpressionAdapter expressionAdapter(localContext, false, solverVariables); // Then add the constraints for bounds of the integer variables.. for (auto const& integerVariable : program.getGlobalIntegerVariables()) { - z3::expr lowerBound = expressionAdapter.translateExpression(integerVariable.getLowerBoundExpression()); - lowerBound = solverVariables.at(integerVariable.getName()) >= lowerBound; - localSolver.add(lowerBound); - z3::expr upperBound = expressionAdapter.translateExpression(integerVariable.getUpperBoundExpression()); - upperBound = solverVariables.at(integerVariable.getName()) <= upperBound; - localSolver.add(upperBound); + localSolver->add(integerVariable.getExpressionVariable() >= integerVariable.getLowerBoundExpression()); + localSolver->add(integerVariable.getExpressionVariable() <= integerVariable.getUpperBoundExpression()); } for (auto const& module : program.getModules()) { for (auto const& integerVariable : module.getIntegerVariables()) { - z3::expr lowerBound = expressionAdapter.translateExpression(integerVariable.getLowerBoundExpression()); - lowerBound = solverVariables.at(integerVariable.getName()) >= lowerBound; - localSolver.add(lowerBound); - z3::expr upperBound = expressionAdapter.translateExpression(integerVariable.getUpperBoundExpression()); - upperBound = solverVariables.at(integerVariable.getName()) <= upperBound; - localSolver.add(upperBound); + localSolver->add(integerVariable.getExpressionVariable() >= integerVariable.getLowerBoundExpression()); + localSolver->add(integerVariable.getExpressionVariable() <= integerVariable.getUpperBoundExpression()); } } @@ -667,20 +653,20 @@ namespace storm { } // Save the state of the solver so we can easily backtrack. - localSolver.push(); + localSolver->push(); // Check if the command set is enabled in the initial state. for (auto const& command : currentCommandVector) { - localSolver.add(expressionAdapter.translateExpression(command.get().getGuardExpression())); + localSolver->add(command.get().getGuardExpression()); } - localSolver.add(expressionAdapter.translateExpression(initialStateExpression)); + localSolver->add(initialStateExpression); - z3::check_result checkResult = localSolver.check(); - localSolver.pop(); - localSolver.push(); + storm::solver::SmtSolver::CheckResult checkResult = localSolver->check(); + localSolver->pop(); + localSolver->push(); // If the solver reports unsat, then we know that the current selection is not enabled in the initial state. - if (checkResult == z3::unsat) { + if (checkResult == storm::solver::SmtSolver::CheckResult::Unsat) { LOG4CPLUS_DEBUG(logger, "Selection not enabled in initial state."); storm::expressions::Expression guardConjunction; if (currentCommandVector.size() == 1) { @@ -703,22 +689,22 @@ namespace storm { } LOG4CPLUS_DEBUG(logger, "About to assert disjunction of negated guards."); - z3::expr guardExpression = localContext.bool_val(false); + storm::expressions::Expression guardExpression = localManager.boolean(false); bool firstAssignment = true; for (auto const& command : currentCommandVector) { if (firstAssignment) { - guardExpression = !expressionAdapter.translateExpression(command.get().getGuardExpression()); + guardExpression = !command.get().getGuardExpression(); } else { - guardExpression = guardExpression | !expressionAdapter.translateExpression(command.get().getGuardExpression()); + guardExpression = guardExpression || !command.get().getGuardExpression(); } } - localSolver.add(guardExpression); + localSolver->add(guardExpression); LOG4CPLUS_DEBUG(logger, "Asserted disjunction of negated guards."); // Now check the possible preceding label sets for the essential ones. for (auto const& precedingLabelSet : labelSetAndPrecedingLabelSetsPair.second) { // Create a restore point so we can easily pop-off all weakest precondition expressions. - localSolver.push(); + localSolver->push(); // Find out the commands for the currently considered preceding label set. std::vector> currentPrecedingCommandVector; @@ -737,7 +723,7 @@ namespace storm { // Assert all the guards of the preceding command set. for (auto const& command : currentPrecedingCommandVector) { - localSolver.add(expressionAdapter.translateExpression(command.get().getGuardExpression())); + localSolver->add(command.get().getGuardExpression()); } std::vector::const_iterator> iteratorVector; @@ -746,19 +732,19 @@ namespace storm { } // Iterate over all possible combinations of updates of the preceding command set. - std::vector formulae; + std::vector formulae; bool done = false; while (!done) { - std::map currentUpdateCombinationMap; + std::map currentUpdateCombinationMap; for (auto const& updateIterator : iteratorVector) { for (auto const& assignment : updateIterator->getAssignments()) { - currentUpdateCombinationMap.emplace(assignment.getVariableName(), assignment.getExpression()); + currentUpdateCombinationMap.emplace(assignment.getVariable(), assignment.getExpression()); } } LOG4CPLUS_DEBUG(logger, "About to assert a weakest precondition."); storm::expressions::Expression wp = guardConjunction.substitute(currentUpdateCombinationMap); - formulae.push_back(expressionAdapter.translateExpression(wp)); + formulae.push_back(wp); LOG4CPLUS_DEBUG(logger, "Asserted weakest precondition."); // Now try to move iterators to the next position if possible. If we could properly move it, we can directly @@ -780,19 +766,19 @@ namespace storm { } // Now assert the disjunction of all weakest preconditions of all considered update combinations. - assertDisjunction(localContext, localSolver, formulae); + assertDisjunction(*localSolver, formulae, localManager); LOG4CPLUS_DEBUG(logger, "Asserted disjunction of all weakest preconditions."); - if (localSolver.check() == z3::sat) { + if (localSolver->check() == storm::solver::SmtSolver::CheckResult::Sat) { backwardImplications[labelSetAndPrecedingLabelSetsPair.first].insert(precedingLabelSet); } - localSolver.pop(); + localSolver->pop(); } // Popping the disjunction of negated guards from the solver stack. - localSolver.pop(); + localSolver->pop(); } } @@ -810,7 +796,7 @@ namespace storm { LOG4CPLUS_DEBUG(logger, "Asserting taken labels are preceded by another label if they are not an initial label."); // Now assert that for each non-target label, we take a following label. for (auto const& labelSetImplicationsPair : backwardImplications) { - std::vector formulae; + std::vector formulae; // Only build a constraint if the combination no predecessor set is already known. if (hasKnownPredecessor.find(labelSetImplicationsPair.first) == hasKnownPredecessor.end()) { @@ -830,7 +816,7 @@ namespace storm { std::set_difference(precedingSet.begin(), precedingSet.end(), relevancyInformation.knownLabels.begin(), relevancyInformation.knownLabels.end(), std::inserter(tmpSet, tmpSet.end())); // Construct an expression that enables all unknown labels of the current following set. - z3::expr conj = context.bool_val(true); + storm::expressions::Expression conj = variableInformation.manager->boolean(true); for (auto label : tmpSet) { conj = conj && variableInformation.labelVariables.at(variableInformation.labelToIndexMap.at(label)); } @@ -842,13 +828,13 @@ namespace storm { // This is because it could be that the commands are taken to enable other synchronizations. Therefore, we need // to add an additional clause that says that the right-hand side of the implication is also true if all commands // of the current choice have enabled synchronization options. - z3::expr finalDisjunct = context.bool_val(false); + storm::expressions::Expression finalDisjunct = variableInformation.manager->boolean(false); for (auto label : labelSetImplicationsPair.first) { - z3::expr alternativeExpressionForLabel = context.bool_val(false); + storm::expressions::Expression alternativeExpressionForLabel = variableInformation.manager->boolean(false); std::set> const& synchsForCommand = synchronizingLabels.at(label); for (auto const& synchSet : synchsForCommand) { - z3::expr alternativeExpression = context.bool_val(true); + storm::expressions::Expression alternativeExpression = variableInformation.manager->boolean(true); // If the current synchSet is the same as left-hand side of the implication, we need to skip it. if (synchSet == labelSetImplicationsPair.first) continue; @@ -865,11 +851,11 @@ namespace storm { alternativeExpression = alternativeExpression && variableInformation.labelVariables.at(variableInformation.labelToIndexMap.at(label)); } - z3::expr disjunctionOverPredecessors = context.bool_val(false); + storm::expressions::Expression disjunctionOverPredecessors = variableInformation.manager->boolean(false); auto precedingLabelSetsIterator = precedingLabelSets.find(synchSet); if (precedingLabelSetsIterator != precedingLabelSets.end()) { for (auto precedingSet : precedingLabelSetsIterator->second) { - z3::expr conjunctionOverLabels = context.bool_val(true); + storm::expressions::Expression conjunctionOverLabels = variableInformation.manager->boolean(true); for (auto label : precedingSet) { if (relevancyInformation.knownLabels.find(label) == relevancyInformation.knownLabels.end()) { conjunctionOverLabels = conjunctionOverLabels && variableInformation.labelVariables.at(variableInformation.labelToIndexMap.at(label)); @@ -892,7 +878,7 @@ namespace storm { } if (formulae.size() > 0) { - assertDisjunction(context, solver, formulae); + assertDisjunction(solver, formulae, *variableInformation.manager); } } } @@ -901,7 +887,7 @@ namespace storm { /*! * Asserts constraints necessary to encode the reachability of at least one target state from the initial states. */ - static void assertReachabilityCuts(storm::models::Mdp const& labeledMdp, storm::storage::BitVector const& psiStates, VariableInformation const& variableInformation, RelevancyInformation const& relevancyInformation, z3::context& context, z3::solver& solver) { + static void assertReachabilityCuts(storm::models::Mdp const& labeledMdp, storm::storage::BitVector const& psiStates, VariableInformation const& variableInformation, RelevancyInformation const& relevancyInformation, storm::solver::SmtSolver& solver) { if (!variableInformation.hasReachabilityVariables) { throw storm::exceptions::InvalidStateException() << "Impossible to assert reachability cuts without the necessary variables."; @@ -915,7 +901,7 @@ namespace storm { // First, we add the formulas that encode // (1) if an incoming transition is chosen, an outgoing one is chosen as well (for non-initial states) // (2) an outgoing transition out of the initial states is taken. - z3::expr initialStateExpression = context.bool_val(false); + storm::expressions::Expression initialStateExpression = variableInformation.manager->boolean(false); for (auto relevantState : relevancyInformation.relevantStates) { if (!labeledMdp.getInitialStates().get(relevantState)) { // Assert the constraints (1). @@ -935,7 +921,7 @@ namespace storm { } } - z3::expr expression = context.bool_val(true); + storm::expressions::Expression expression = variableInformation.manager->boolean(true); for (auto predecessor : relevantPredecessors) { expression = expression && !variableInformation.statePairVariables.at(variableInformation.statePairToIndexMap.at(std::make_pair(predecessor, relevantState))); } @@ -978,9 +964,9 @@ namespace storm { } } } - z3::expr labelExpression = !variableInformation.statePairVariables.at(statePairIndexPair.second); + storm::expressions::Expression labelExpression = !variableInformation.statePairVariables.at(statePairIndexPair.second); for (auto choice : choicesForStatePair) { - z3::expr choiceExpression = context.bool_val(true); + storm::expressions::Expression choiceExpression = variableInformation.manager->boolean(true); for (auto element : choiceLabeling.at(choice)) { if (relevancyInformation.knownLabels.find(element) == relevancyInformation.knownLabels.end()) { choiceExpression = choiceExpression && variableInformation.labelVariables.at(variableInformation.labelToIndexMap.at(element)); @@ -991,7 +977,7 @@ namespace storm { solver.add(labelExpression); // Assert constraint for (2). - z3::expr orderExpression = !variableInformation.statePairVariables.at(statePairIndexPair.second) || variableInformation.stateOrderVariables.at(variableInformation.relevantStatesToOrderVariableIndexMap.at(sourceState)) < variableInformation.stateOrderVariables.at(variableInformation.relevantStatesToOrderVariableIndexMap.at(targetState)); + storm::expressions::Expression orderExpression = !variableInformation.statePairVariables.at(statePairIndexPair.second) || variableInformation.stateOrderVariables.at(variableInformation.relevantStatesToOrderVariableIndexMap.at(sourceState)) < variableInformation.stateOrderVariables.at(variableInformation.relevantStatesToOrderVariableIndexMap.at(targetState)); solver.add(orderExpression); } } @@ -1000,12 +986,12 @@ namespace storm { * Asserts that the disjunction of the given formulae holds. If the content of the disjunction is empty, * this corresponds to asserting false. * - * @param context The Z3 context in which to build the expressions. * @param solver The solver to use for the satisfiability evaluation. * @param formulaVector A vector of expressions that shall form the disjunction. + * @param manager The expression manager to use. */ - static void assertDisjunction(z3::context& context, z3::solver& solver, std::vector const& formulaVector) { - z3::expr disjunction = context.bool_val(false); + static void assertDisjunction(storm::solver::SmtSolver& solver, std::vector const& formulaVector, storm::expressions::ExpressionManager const& manager) { + storm::expressions::Expression disjunction = manager.boolean(false); for (auto expr : formulaVector) { disjunction = disjunction || expr; } @@ -1037,9 +1023,9 @@ namespace storm { * @return A pair whose first component represents the carry bit and whose second component represents the * result bit. */ - static std::pair createFullAdder(z3::expr in1, z3::expr in2, z3::expr carryIn) { - z3::expr resultBit = (in1 && !in2 && !carryIn) || (!in1 && in2 && !carryIn) || (!in1 && !in2 && carryIn) || (in1 && in2 && carryIn); - z3::expr carryBit = in1 && in2 || in1 && carryIn || in2 && carryIn; + static std::pair createFullAdder(storm::expressions::Expression in1, storm::expressions::Expression in2, storm::expressions::Expression carryIn) { + storm::expressions::Expression resultBit = (in1 && !in2 && !carryIn) || (!in1 && in2 && !carryIn) || (!in1 && !in2 && carryIn) || (in1 && in2 && carryIn); + storm::expressions::Expression carryBit = in1 && in2 || in1 && carryIn || in2 && carryIn; return std::make_pair(carryBit, resultBit); } @@ -1048,12 +1034,12 @@ namespace storm { * Creates an adder for the two inputs of equal size. The resulting vector represents the different bits of * the sum (and is thus one bit longer than the two inputs). * - * @param context The Z3 context in which to build the expressions. + * @param variableInformation The variable information structure. * @param in1 The first input to the adder. * @param in2 The second input to the adder. * @return A vector representing the bits of the sum of the two inputs. */ - static std::vector createAdder(z3::context& context, std::vector const& in1, std::vector const& in2) { + static std::vector createAdder(VariableInformation const& variableInformation, std::vector const& in1, std::vector const& in2) { // Sanity check for sizes of input. if (in1.size() != in2.size() || in1.size() == 0) { LOG4CPLUS_ERROR(logger, "Illegal input to adder (" << in1.size() << ", " << in2.size() << ")."); @@ -1061,13 +1047,13 @@ namespace storm { } // Prepare result. - std::vector result; + std::vector result; result.reserve(in1.size() + 1); // Add all bits individually and pass on carry bit appropriately. - z3::expr carryBit = context.bool_val(false); + storm::expressions::Expression carryBit = variableInformation.manager->boolean(false); for (uint_fast64_t currentBit = 0; currentBit < in1.size(); ++currentBit) { - std::pair localResult = createFullAdder(in1[currentBit], in2[currentBit], carryBit); + std::pair localResult = createFullAdder(in1[currentBit], in2[currentBit], carryBit); result.push_back(localResult.second); carryBit = localResult.first; @@ -1082,22 +1068,22 @@ namespace storm { * consecutive numbers of the input. If the number if input numbers is odd, the last number is simply added * to the output. * - * @param context The Z3 context in which to build the expressions. + * @param variableInformation The variable information structure. * @param in A vector or binary encoded numbers. * @return A vector of numbers that each correspond to the sum of two consecutive elements of the input. */ - static std::vector> createAdderPairs(z3::context& context, std::vector> const& in) { - std::vector> result; + static std::vector> createAdderPairs(VariableInformation const& variableInformation, std::vector> const& in) { + std::vector> result; result.reserve(in.size() / 2 + in.size() % 2); for (uint_fast64_t index = 0; index < in.size() / 2; ++index) { - result.push_back(createAdder(context, in[2 * index], in[2 * index + 1])); + result.push_back(createAdder(variableInformation, in[2 * index], in[2 * index + 1])); } if (in.size() % 2 != 0) { result.push_back(in.back()); - result.back().push_back(context.bool_val(false)); + result.back().push_back(variableInformation.manager->boolean(false)); } return result; @@ -1106,25 +1092,25 @@ namespace storm { /*! * Creates a counter circuit that returns the number of literals out of the given vector that are set to true. * - * @param context The Z3 context in which to build the expressions. + * @param variableInformation The variable information structure. * @param literals The literals for which to create the adder circuit. * @return A bit vector representing the number of literals that are set to true. */ - static std::vector createCounterCircuit(z3::context& context, std::vector const& literals) { + static std::vector createCounterCircuit(VariableInformation const& variableInformation, std::vector const& literals) { LOG4CPLUS_DEBUG(logger, "Creating counter circuit for " << literals.size() << " literals."); // Create the auxiliary vector. - std::vector> aux; + std::vector> aux; for (uint_fast64_t index = 0; index < literals.size(); ++index) { aux.emplace_back(); aux.back().push_back(literals[index]); } while (aux.size() > 1) { - aux = createAdderPairs(context, aux); + aux = createAdderPairs(variableInformation, aux); } - return aux[0]; + return aux.front(); } /*! @@ -1144,38 +1130,39 @@ namespace storm { * may at most represent the number k. The constraint is associated with a relaxation variable, that is * returned by this function and may be used to deactivate the constraint. * - * @param context The Z3 context in which to build the expressions. * @param solver The solver to use for the satisfiability evaluation. - * @param input The variables that encode the value to restrict. + * @param variableInformation The struct that holds the variable information. * @param k The bound for the binary-encoded value. * @return The relaxation variable associated with the constraint. */ - static z3::expr assertLessOrEqualKRelaxed(z3::context& context, z3::solver& solver, std::vector const& input, uint64_t k) { + static storm::expressions::Variable assertLessOrEqualKRelaxed(storm::solver::SmtSolver& solver, VariableInformation const& variableInformation, uint64_t k) { LOG4CPLUS_DEBUG(logger, "Asserting solution has size less or equal " << k << "."); - - z3::expr result(context); + + std::vector const& input = variableInformation.adderVariables; + + storm::expressions::Expression result; if (bitIsSet(k, 0)) { - result = context.bool_val(true); + result = variableInformation.manager->boolean(true); } else { result = !input.at(0); } for (uint_fast64_t index = 1; index < input.size(); ++index) { - z3::expr i1(context); - z3::expr i2(context); + storm::expressions::Expression i1; + storm::expressions::Expression i2; if (bitIsSet(k, index)) { i1 = !input.at(index); i2 = result; } else { - i1 = context.bool_val(false); - i2 = context.bool_val(false); + i1 = variableInformation.manager->boolean(false); + i2 = variableInformation.manager->boolean(false); } result = i1 || i2 || (!input.at(index) && result); } std::stringstream variableName; variableName << "relaxed" << k; - z3::expr relaxingVariable = context.bool_const(variableName.str().c_str()); + storm::expressions::Variable relaxingVariable = variableInformation.manager->declareBooleanVariable(variableName.str()); result = result || relaxingVariable; solver.add(result); @@ -1293,11 +1280,10 @@ namespace storm { /*! * Determines the set of labels that was chosen by the given model. * - * @param context The Z3 context in which to build the expressions. - * @param model The Z3 model from which to extract the information. + * @param model The model from which to extract the information. * @param variableInformation A structure with information about the variables of the solver. */ - static boost::container::flat_set getUsedLabelSet(z3::context& context, z3::model const& model, VariableInformation const& variableInformation) { + static boost::container::flat_set getUsedLabelSet(storm::solver::SmtSolver::ModelReference const& model, VariableInformation const& variableInformation) { boost::container::flat_set result; for (auto const& labelIndexPair : variableInformation.labelToIndexMap) { z3::expr auxValue = model.eval(variableInformation.labelVariables.at(labelIndexPair.second)); @@ -1320,21 +1306,20 @@ namespace storm { * Asserts an adder structure in the given solver that counts the number of variables that are set to true * out of the given variables. * - * @param context The Z3 context in which to build the expressions. * @param solver The solver for which to add the adder. * @param variableInformation A structure with information about the variables of the solver. */ - static std::vector assertAdder(z3::context& context, z3::solver& solver, VariableInformation const& variableInformation) { + static std::vector assertAdder(storm::solver::SmtSolver& solver, VariableInformation const& variableInformation) { std::stringstream variableName; - std::vector result; + std::vector result; - std::vector adderVariables = createCounterCircuit(context, variableInformation.labelVariables); + std::vector adderVariables = createCounterCircuit(variableInformation, variableInformation.labelVariables); for (uint_fast64_t i = 0; i < adderVariables.size(); ++i) { variableName.str(""); variableName.clear(); variableName << "adder" << i; - result.push_back(context.bool_const(variableName.str().c_str())); - solver.add(implies(adderVariables[i], result.back())); + result.push_back(variableInformation.manager->declareBooleanVariable(variableName.str())); + solver.add(storm::expressions::implies(adderVariables[i], result.back())); } return result; @@ -1343,29 +1328,28 @@ namespace storm { /*! * Finds the smallest set of labels such that the constraint system of the solver is still satisfiable. * - * @param context The Z3 context in which to build the expressions. * @param solver The solver to use for the satisfiability evaluation. * @param variableInformation A structure with information about the variables of the solver. * @param currentBound The currently known lower bound for the number of labels that need to be enabled * in order to satisfy the constraint system. * @return The smallest set of labels such that the constraint system of the solver is satisfiable. */ - static boost::container::flat_set findSmallestCommandSet(z3::context& context, z3::solver& solver, VariableInformation& variableInformation, uint_fast64_t& currentBound) { + static boost::container::flat_set findSmallestCommandSet(storm::solver::SmtSolver& solver, VariableInformation& variableInformation, uint_fast64_t& currentBound) { // Check if we can find a solution with the current bound. - z3::expr assumption = !variableInformation.auxiliaryVariables.back(); + storm::expressions::Expression assumption = !variableInformation.auxiliaryVariables.back(); // As long as the constraints are unsatisfiable, we need to relax the last at-most-k constraint and // try with an increased bound. - while (solver.check(1, &assumption) == z3::unsat) { + while (solver.checkWithAssumptions({assumption}) == storm::solver::SmtSolver::CheckResult::) { LOG4CPLUS_DEBUG(logger, "Constraint system is unsatisfiable with at most " << currentBound << " taken commands; increasing bound."); solver.add(variableInformation.auxiliaryVariables.back()); - variableInformation.auxiliaryVariables.push_back(assertLessOrEqualKRelaxed(context, solver, variableInformation.adderVariables, ++currentBound)); + variableInformation.auxiliaryVariables.push_back(assertLessOrEqualKRelaxed(solver, variableInformation, ++currentBound)); assumption = !variableInformation.auxiliaryVariables.back(); } // At this point we know that the constraint system was satisfiable, so compute the induced label // set and return it. - return getUsedLabelSet(context, solver.get_model(), variableInformation); + return getUsedLabelSet(*solver.getModel(), variableInformation); } /*! @@ -1491,7 +1475,7 @@ namespace storm { } LOG4CPLUS_DEBUG(logger, "Asserting reachability implications."); - assertDisjunction(context, solver, formulae); + assertDisjunction(solver, formulae, *variableInformation.manager); } @@ -1601,7 +1585,7 @@ namespace storm { } LOG4CPLUS_DEBUG(logger, "Asserting reachability implications."); - assertDisjunction(context, solver, formulae); + assertDisjunction(solver, formulae, *variableInformation.manager); } #endif @@ -1640,7 +1624,7 @@ namespace storm { auto analysisClock = std::chrono::high_resolution_clock::now(); decltype(std::chrono::high_resolution_clock::now() - analysisClock) totalAnalysisTime(0); - std::map constantDefinitions = storm::utility::prism::parseConstantDefinitionString(program, constantDefinitionString); + std::map constantDefinitions = storm::utility::prism::parseConstantDefinitionString(program, constantDefinitionString); storm::prism::Program preparedProgram = program.defineUndefinedConstants(constantDefinitions); preparedProgram = preparedProgram.substituteConstants(); @@ -1666,37 +1650,35 @@ namespace storm { // (2) Identify all states and commands that are relevant, because only these need to be considered later. RelevancyInformation relevancyInformation = determineRelevantStatesAndLabels(labeledMdp, phiStates, psiStates); - // (3) Create context for solver. - z3::context context; + // (3) Create a solver. + std::shared_ptr manager(new storm::expressions::ExpressionManager()); + std::unique_ptr solver(new storm::solver::Z3SmtSolver(*manager)); // (4) Create the variables for the relevant commands. - VariableInformation variableInformation = createVariables(context, labeledMdp, psiStates, relevancyInformation, includeReachabilityEncoding); + VariableInformation variableInformation = createVariables(manager, labeledMdp, psiStates, relevancyInformation, includeReachabilityEncoding); LOG4CPLUS_DEBUG(logger, "Created variables."); - // (5) After all variables have been created, create a solver for that context. - z3::solver solver(context); - - // (6) Now assert an adder whose result variables can later be used to constrain the nummber of label + // (5) Now assert an adder whose result variables can later be used to constrain the nummber of label // variables that were set to true. Initially, we are looking for a solution that has no label enabled // and subsequently relax that. - variableInformation.adderVariables = assertAdder(context, solver, variableInformation); - variableInformation.auxiliaryVariables.push_back(assertLessOrEqualKRelaxed(context, solver, variableInformation.adderVariables, 0)); + variableInformation.adderVariables = assertAdder(*solver, variableInformation); + variableInformation.auxiliaryVariables.push_back(assertLessOrEqualKRelaxed(*solver, variableInformation, 0)); - // (7) Add constraints that cut off a lot of suboptimal solutions. + // (6) Add constraints that cut off a lot of suboptimal solutions. LOG4CPLUS_DEBUG(logger, "Asserting cuts."); - assertExplicitCuts(labeledMdp, psiStates, variableInformation, relevancyInformation, context, solver); + assertExplicitCuts(labeledMdp, psiStates, variableInformation, relevancyInformation, *solver); LOG4CPLUS_DEBUG(logger, "Asserted explicit cuts."); - assertSymbolicCuts(preparedProgram, labeledMdp, variableInformation, relevancyInformation, context, solver); + assertSymbolicCuts(preparedProgram, labeledMdp, variableInformation, relevancyInformation, *solver); LOG4CPLUS_DEBUG(logger, "Asserted symbolic cuts."); if (includeReachabilityEncoding) { - assertReachabilityCuts(labeledMdp, psiStates, variableInformation, relevancyInformation, context, solver); + assertReachabilityCuts(labeledMdp, psiStates, variableInformation, relevancyInformation, *solver); LOG4CPLUS_DEBUG(logger, "Asserted reachability cuts."); } // As we are done with the setup at this point, stop the clock for the setup time. totalSetupTime = std::chrono::high_resolution_clock::now() - setupTimeClock; - // (8) Find the smallest set of commands that satisfies all constraints. If the probability of + // (7) Find the smallest set of commands that satisfies all constraints. If the probability of // satisfying phi until psi exceeds the given threshold, the set of labels is minimal and can be returned. // Otherwise, the current solution has to be ruled out and the next smallest solution is retrieved from // the solver. @@ -1712,7 +1694,7 @@ namespace storm { do { LOG4CPLUS_DEBUG(logger, "Computing minimal command set."); solverClock = std::chrono::high_resolution_clock::now(); - commandSet = findSmallestCommandSet(context, solver, variableInformation, currentBound); + commandSet = findSmallestCommandSet(solver, variableInformation, currentBound); totalSolverTime += std::chrono::high_resolution_clock::now() - solverClock; LOG4CPLUS_DEBUG(logger, "Computed minimal command set of size " << (commandSet.size() + relevancyInformation.knownLabels.size()) << "."); diff --git a/src/parser/PrismParser.cpp b/src/parser/PrismParser.cpp index df247824f..74b7b43d3 100644 --- a/src/parser/PrismParser.cpp +++ b/src/parser/PrismParser.cpp @@ -345,7 +345,7 @@ namespace storm { } storm::prism::Assignment PrismParser::createAssignment(std::string const& variableName, storm::expressions::Expression assignedExpression) const { - return storm::prism::Assignment(variableName, assignedExpression, this->getFilename()); + return storm::prism::Assignment(manager->getVariable(variableName), assignedExpression, this->getFilename()); } storm::prism::Update PrismParser::createUpdate(storm::expressions::Expression likelihoodExpression, std::vector const& assignments, GlobalProgramInformation& globalProgramInformation) const { @@ -459,9 +459,9 @@ namespace storm { for (auto const& assignment : update.getAssignments()) { auto const& renamingPair = renaming.find(assignment.getVariableName()); if (renamingPair != renaming.end()) { - assignments.emplace_back(renamingPair->second, assignment.getExpression().substitute(expressionRenaming), this->getFilename(), get_line(qi::_1)); + assignments.emplace_back(manager->getVariable(renamingPair->second), assignment.getExpression().substitute(expressionRenaming), this->getFilename(), get_line(qi::_1)); } else { - assignments.emplace_back(assignment.getVariableName(), assignment.getExpression().substitute(expressionRenaming), this->getFilename(), get_line(qi::_1)); + assignments.emplace_back(assignment.getVariable(), assignment.getExpression().substitute(expressionRenaming), this->getFilename(), get_line(qi::_1)); } } updates.emplace_back(globalProgramInformation.currentUpdateIndex, update.getLikelihoodExpression().substitute(expressionRenaming), assignments, this->getFilename(), get_line(qi::_1)); diff --git a/src/storage/dd/CuddDdForwardIterator.cpp b/src/storage/dd/CuddDdForwardIterator.cpp index 6a8f0fa85..b37ba1a11 100644 --- a/src/storage/dd/CuddDdForwardIterator.cpp +++ b/src/storage/dd/CuddDdForwardIterator.cpp @@ -90,9 +90,9 @@ namespace storm { } else { storm::expressions::Variable const& metaVariable = std::get<1>(this->relevantDontCareDdVariables[index]); if ((this->cubeCounter & (1ull << index)) != 0) { - currentValuation.setBoundedIntegerValue(metaVariable, ((currentValuation.getBoundedIntegerValue(metaVariable) - ddMetaVariable.getLow()) | (1ull << std::get<2>(this->relevantDontCareDdVariables[index]))) + ddMetaVariable.getLow()); + currentValuation.setBitVectorValue(metaVariable, ((currentValuation.getBitVectorValue(metaVariable) - ddMetaVariable.getLow()) | (1ull << std::get<2>(this->relevantDontCareDdVariables[index]))) + ddMetaVariable.getLow()); } else { - currentValuation.setBoundedIntegerValue(metaVariable, ((currentValuation.getBoundedIntegerValue(metaVariable) - ddMetaVariable.getLow()) & ~(1ull << std::get<2>(this->relevantDontCareDdVariables[index]))) + ddMetaVariable.getLow()); + currentValuation.setBitVectorValue(metaVariable, ((currentValuation.getBitVectorValue(metaVariable) - ddMetaVariable.getLow()) & ~(1ull << std::get<2>(this->relevantDontCareDdVariables[index]))) + ddMetaVariable.getLow()); } } } @@ -134,7 +134,7 @@ namespace storm { } } if (this->enumerateDontCareMetaVariables || metaVariableAppearsInCube) { - currentValuation.setBoundedIntegerValue(metaVariable, intValue + ddMetaVariable.getLow()); + currentValuation.setBitVectorValue(metaVariable, intValue + ddMetaVariable.getLow()); } } diff --git a/src/storage/dd/CuddDdManager.cpp b/src/storage/dd/CuddDdManager.cpp index 884da639b..32a25ae3e 100644 --- a/src/storage/dd/CuddDdManager.cpp +++ b/src/storage/dd/CuddDdManager.cpp @@ -110,8 +110,8 @@ namespace storm { std::size_t numberOfBits = static_cast(std::ceil(std::log2(high - low + 1))); - storm::expressions::Variable unprimed = manager->declareBoundedIntegerVariable(name, numberOfBits); - storm::expressions::Variable primed = manager->declareBoundedIntegerVariable(name + "'", numberOfBits); + storm::expressions::Variable unprimed = manager->declareBitVectorVariable(name, numberOfBits); + storm::expressions::Variable primed = manager->declareBitVectorVariable(name + "'", numberOfBits); std::vector> variables; std::vector> variablesPrime; diff --git a/src/storage/expressions/BaseExpression.cpp b/src/storage/expressions/BaseExpression.cpp index 90033870d..9d44215d0 100644 --- a/src/storage/expressions/BaseExpression.cpp +++ b/src/storage/expressions/BaseExpression.cpp @@ -18,8 +18,8 @@ namespace storm { return this->getType().isIntegerType(); } - bool BaseExpression::hasBoundedIntegerType() const { - return this->getType().isBoundedIntegerType(); + bool BaseExpression::hasBitVectorType() const { + return this->getType().isBitVectorType(); } bool BaseExpression::hasNumericalType() const { diff --git a/src/storage/expressions/BaseExpression.h b/src/storage/expressions/BaseExpression.h index 5f9faf1f9..6a5a0d6b6 100644 --- a/src/storage/expressions/BaseExpression.h +++ b/src/storage/expressions/BaseExpression.h @@ -182,11 +182,11 @@ namespace storm { bool hasIntegerType() const; /*! - * Retrieves whether the expression has a bounded integer type. + * Retrieves whether the expression has a bitvector type. * - * @return True iff the expression has a bounded integer type. + * @return True iff the expression has a bitvector type. */ - bool hasBoundedIntegerType() const; + bool hasBitVectorType() const; /*! * Retrieves whether the expression has a boolean type. diff --git a/src/storage/expressions/Expression.cpp b/src/storage/expressions/Expression.cpp index 839dd7c24..c65ba9748 100644 --- a/src/storage/expressions/Expression.cpp +++ b/src/storage/expressions/Expression.cpp @@ -146,8 +146,8 @@ namespace storm { return this->getBaseExpression().hasIntegerType(); } - bool Expression::hasBoundedIntegerType() const { - return this->getBaseExpression().hasBoundedIntegerType(); + bool Expression::hasBitVectorType() const { + return this->getBaseExpression().hasBitVectorType(); } boost::any Expression::accept(ExpressionVisitor& visitor) const { diff --git a/src/storage/expressions/Expression.h b/src/storage/expressions/Expression.h index 22d0c2f87..3efd79f29 100644 --- a/src/storage/expressions/Expression.h +++ b/src/storage/expressions/Expression.h @@ -274,7 +274,7 @@ namespace storm { * * @return True iff the expression has a integral return type. */ - bool hasBoundedIntegerType() const; + bool hasBitVectorType() const; /*! * Accepts the given visitor. diff --git a/src/storage/expressions/ExpressionManager.cpp b/src/storage/expressions/ExpressionManager.cpp index b624af766..aa7ffc6f8 100644 --- a/src/storage/expressions/ExpressionManager.cpp +++ b/src/storage/expressions/ExpressionManager.cpp @@ -51,7 +51,7 @@ namespace storm { } } - ExpressionManager::ExpressionManager() : nameToIndexMapping(), indexToNameMapping(), indexToTypeMapping(), numberOfVariables(0), variableTypeToCountMapping(), auxiliaryVariableTypeToCountMapping(), numberOfAuxiliaryVariables(0), freshVariableCounter(0), booleanType(nullptr), integerType(nullptr), rationalType(nullptr) { + ExpressionManager::ExpressionManager() : nameToIndexMapping(), indexToNameMapping(), indexToTypeMapping(), numberOfVariables(0), numberOfBooleanVariables(0), numberOfIntegerVariables(0), numberOfBitVectorVariables(0), numberOfRationalVariables(0), numberOfAuxiliaryVariables(0), numberOfAuxiliaryBooleanVariables(0), numberOfAuxiliaryIntegerVariables(0), numberOfAuxiliaryBitVectorVariables(0), numberOfAuxiliaryRationalVariables(0), freshVariableCounter(0), types() { // Intentionally left empty. } @@ -72,35 +72,43 @@ namespace storm { } Type const& ExpressionManager::getBooleanType() const { - if (booleanType == nullptr) { - booleanType = std::unique_ptr(new Type(this->getSharedPointer(), std::shared_ptr(new BooleanType()))); + Type type(this->getSharedPointer(), std::shared_ptr(new BooleanType())); + auto typeIterator = types.find(type); + if (typeIterator == types.end()) { + auto iteratorBoolPair = types.insert(type); + return *iteratorBoolPair.first; } - return *booleanType; + return *typeIterator; } Type const& ExpressionManager::getIntegerType() const { - if (integerType == nullptr) { - integerType = std::unique_ptr(new Type(this->getSharedPointer(), std::shared_ptr(new IntegerType()))); + Type type(this->getSharedPointer(), std::shared_ptr(new IntegerType())); + auto typeIterator = types.find(type); + if (typeIterator == types.end()) { + auto iteratorBoolPair = types.insert(type); + return *iteratorBoolPair.first; } - return *integerType; + return *typeIterator; } - Type const& ExpressionManager::getBoundedIntegerType(std::size_t width) const { - auto boundedIntegerType = boundedIntegerTypes.find(width); - if (boundedIntegerType == boundedIntegerTypes.end()) { - Type newType = Type(this->getSharedPointer(), std::shared_ptr(new BoundedIntegerType(width))); - boundedIntegerTypes[width] = newType; - return boundedIntegerTypes[width]; - } else { - return boundedIntegerType->second; + Type const& ExpressionManager::getBitVectorType(std::size_t width) const { + Type type(this->getSharedPointer(), std::shared_ptr(new BitVectorType(width))); + auto typeIterator = types.find(type); + if (typeIterator == types.end()) { + auto iteratorBoolPair = types.insert(type); + return *iteratorBoolPair.first; } + return *typeIterator; } Type const& ExpressionManager::getRationalType() const { - if (rationalType == nullptr) { - rationalType = std::unique_ptr(new Type(this->getSharedPointer(), std::shared_ptr(new RationalType()))); + Type type(this->getSharedPointer(), std::shared_ptr(new RationalType())); + auto typeIterator = types.find(type); + if (typeIterator == types.end()) { + auto iteratorBoolPair = types.insert(type); + return *iteratorBoolPair.first; } - return *rationalType; + return *typeIterator; } bool ExpressionManager::isValidVariableName(std::string const& name) { @@ -112,39 +120,30 @@ namespace storm { return nameIndexPair != nameToIndexMapping.end(); } - Variable ExpressionManager::declareVariable(std::string const& name, storm::expressions::Type const& variableType) { + Variable ExpressionManager::declareVariable(std::string const& name, storm::expressions::Type const& variableType, bool auxiliary) { STORM_LOG_THROW(!variableExists(name), storm::exceptions::InvalidArgumentException, "Variable with name '" << name << "' already exists."); - return declareOrGetVariable(name, variableType); + return declareOrGetVariable(name, variableType, auxiliary); } - Variable ExpressionManager::declareBooleanVariable(std::string const& name) { - Variable var = this->declareVariable(name, this->getBooleanType()); + Variable ExpressionManager::declareBooleanVariable(std::string const& name, bool auxiliary) { + Variable var = this->declareVariable(name, this->getBooleanType(), auxiliary); return var; } - Variable ExpressionManager::declareIntegerVariable(std::string const& name) { - return this->declareVariable(name, this->getIntegerType()); + Variable ExpressionManager::declareIntegerVariable(std::string const& name, bool auxiliary) { + return this->declareVariable(name, this->getIntegerType(), auxiliary); } - Variable ExpressionManager::declareBoundedIntegerVariable(std::string const& name, std::size_t width) { - return this->declareVariable(name, this->getBoundedIntegerType(width)); + Variable ExpressionManager::declareBitVectorVariable(std::string const& name, std::size_t width, bool auxiliary) { + return this->declareVariable(name, this->getBitVectorType(width), auxiliary); } - Variable ExpressionManager::declareRationalVariable(std::string const& name) { - return this->declareVariable(name, this->getRationalType()); - } - - Variable ExpressionManager::declareAuxiliaryVariable(std::string const& name, storm::expressions::Type const& variableType) { - STORM_LOG_THROW(!variableExists(name), storm::exceptions::InvalidArgumentException, "Variable with name '" << name << "' already exists."); - return declareOrGetAuxiliaryVariable(name, variableType); + Variable ExpressionManager::declareRationalVariable(std::string const& name, bool auxiliary) { + return this->declareVariable(name, this->getRationalType(), auxiliary); } - Variable ExpressionManager::declareOrGetVariable(std::string const& name, storm::expressions::Type const& variableType) { - return declareOrGetVariable(name, variableType, false, true); - } - - Variable ExpressionManager::declareOrGetAuxiliaryVariable(std::string const& name, storm::expressions::Type const& variableType) { - return declareOrGetVariable(name, variableType, true, true); + Variable ExpressionManager::declareOrGetVariable(std::string const& name, storm::expressions::Type const& variableType, bool auxiliary) { + return declareOrGetVariable(name, variableType, auxiliary, true); } Variable ExpressionManager::declareOrGetVariable(std::string const& name, storm::expressions::Type const& variableType, bool auxiliary, bool checkName) { @@ -153,21 +152,31 @@ namespace storm { if (nameIndexPair != nameToIndexMapping.end()) { return Variable(this->getSharedPointer(), nameIndexPair->second); } else { - std::unordered_map::iterator typeCountPair; + uint_fast64_t offset = 0; if (auxiliary) { - typeCountPair = auxiliaryVariableTypeToCountMapping.find(variableType); - if (typeCountPair == auxiliaryVariableTypeToCountMapping.end()) { - typeCountPair = auxiliaryVariableTypeToCountMapping.insert(typeCountPair, std::make_pair(variableType, 0)); + if (variableType.isBooleanType()) { + offset = numberOfBooleanVariables++; + } else if (variableType.isIntegerType()) { + offset = numberOfIntegerVariables++ + numberOfBitVectorVariables; + } else if (variableType.isBitVectorType()) { + offset = numberOfBitVectorVariables++ + numberOfIntegerVariables; + } else { + offset = numberOfRationalVariables++; } } else { - typeCountPair = variableTypeToCountMapping.find(variableType); - if (typeCountPair == variableTypeToCountMapping.end()) { - typeCountPair = variableTypeToCountMapping.insert(typeCountPair, std::make_pair(variableType, 0)); + if (variableType.isBooleanType()) { + offset = numberOfBooleanVariables++; + } else if (variableType.isIntegerType()) { + offset = numberOfIntegerVariables++ + numberOfBitVectorVariables; + } else if (variableType.isBitVectorType()) { + offset = numberOfBitVectorVariables++ + numberOfIntegerVariables; + } else { + offset = numberOfRationalVariables++; } } // Compute the index of the new variable. - uint_fast64_t newIndex = typeCountPair->second++ | variableType.getMask() | (auxiliary ? auxiliaryMask : 0); + uint_fast64_t newIndex = offset | variableType.getMask() | (auxiliary ? auxiliaryMask : 0); // Properly insert the variable into the data structure. nameToIndexMapping[name] = newIndex; @@ -191,23 +200,22 @@ namespace storm { return nameToIndexMapping.find(name) != nameToIndexMapping.end(); } - Variable ExpressionManager::declareFreshVariable(storm::expressions::Type const& variableType) { - std::string newName = "__x" + std::to_string(freshVariableCounter++); - return declareOrGetVariable(newName, variableType, false, false); - } - - Variable ExpressionManager::declareFreshAuxiliaryVariable(storm::expressions::Type const& variableType) { + Variable ExpressionManager::declareFreshVariable(storm::expressions::Type const& variableType, bool auxiliary) { std::string newName = "__x" + std::to_string(freshVariableCounter++); - return declareOrGetVariable(newName, variableType, true, false); + return declareOrGetVariable(newName, variableType, auxiliary, false); } uint_fast64_t ExpressionManager::getNumberOfVariables(storm::expressions::Type const& variableType) const { - auto typeCountPair = variableTypeToCountMapping.find(variableType); - if (typeCountPair == variableTypeToCountMapping.end()) { - return 0; - } else { - return typeCountPair->second; + if (variableType.isBooleanType()) { + return numberOfBooleanVariables; + } else if (variableType.isIntegerType()) { + return numberOfIntegerVariables; + } else if (variableType.isBitVectorType()) { + return numberOfBitVectorVariables; + } else if (variableType.isRationalType()) { + return numberOfRationalVariables; } + return 0; } uint_fast64_t ExpressionManager::getNumberOfVariables() const { @@ -215,52 +223,19 @@ namespace storm { } uint_fast64_t ExpressionManager::getNumberOfBooleanVariables() const { - return getNumberOfVariables(getBooleanType()); + return numberOfBooleanVariables; } uint_fast64_t ExpressionManager::getNumberOfIntegerVariables() const { - return getNumberOfVariables(getIntegerType()); + return numberOfIntegerVariables; } - uint_fast64_t ExpressionManager::getNumberOfBoundedIntegerVariables() const { - // The bit width of the type is not of importance here, since bit vector types are considered the same when - // it comes to counting. - return getNumberOfVariables(getBoundedIntegerType(0)); + uint_fast64_t ExpressionManager::getNumberOfBitVectorVariables() const { + return numberOfBitVectorVariables; } uint_fast64_t ExpressionManager::getNumberOfRationalVariables() const { - return getNumberOfVariables(getRationalType()); - } - - uint_fast64_t ExpressionManager::getNumberOfAuxiliaryVariables(storm::expressions::Type const& variableType) const { - auto typeCountPair = auxiliaryVariableTypeToCountMapping.find(variableType); - if (typeCountPair == auxiliaryVariableTypeToCountMapping.end()) { - return 0; - } else { - return typeCountPair->second; - } - } - - uint_fast64_t ExpressionManager::getNumberOfAuxiliaryVariables() const { - return numberOfAuxiliaryVariables; - } - - uint_fast64_t ExpressionManager::getNumberOfAuxiliaryBooleanVariables() const { - return getNumberOfAuxiliaryVariables(getBooleanType()); - } - - uint_fast64_t ExpressionManager::getNumberOfAuxiliaryIntegerVariables() const { - return getNumberOfAuxiliaryVariables(getIntegerType()); - } - - uint_fast64_t ExpressionManager::getNumberOfAuxiliaryBoundedIntegerVariables() const { - // The bit width of the type is not of importance here, since bit vector types are considered the same when - // it comes to counting. - return getNumberOfAuxiliaryVariables(getBoundedIntegerType(0)); - } - - uint_fast64_t ExpressionManager::getNumberOfAuxiliaryRationalVariables() const { - return getNumberOfAuxiliaryVariables(getRationalType()); + return numberOfRationalVariables; } std::string const& ExpressionManager::getVariableName(uint_fast64_t index) const { @@ -294,10 +269,6 @@ namespace storm { std::shared_ptr ExpressionManager::getSharedPointer() const { return this->shared_from_this(); } - - bool ExpressionManager::ManagerTypeEquality::operator()(Type const& a, Type const& b) const { - return a.getMask() == b.getMask(); - } - + } // namespace expressions } // namespace storm \ No newline at end of file diff --git a/src/storage/expressions/ExpressionManager.h b/src/storage/expressions/ExpressionManager.h index 9661cf696..a2461ff02 100644 --- a/src/storage/expressions/ExpressionManager.h +++ b/src/storage/expressions/ExpressionManager.h @@ -117,19 +117,19 @@ namespace storm { Type const& getBooleanType() const; /*! - * Retrieves the integer type. + * Retrieves the unbounded integer type. * - * @return The integer type. + * @return The unbounded integer type. */ Type const& getIntegerType() const; - + /*! - * Retrieves the bounded integer type. + * Retrieves the bit vector type of the given width. * * @param width The bit width of the bounded type. * @return The bounded integer type. */ - Type const& getBoundedIntegerType(std::size_t width) const; + Type const& getBitVectorType(std::size_t width) const; /*! * Retrieves the rational type. @@ -144,74 +144,62 @@ namespace storm { * * @param name The name of the variable. * @param variableType The type of the variable. + * @param auxiliary A flag indicating whether the new variable should be tagged as an auxiliary variable. * @return The newly declared variable. */ - Variable declareVariable(std::string const& name, storm::expressions::Type const& variableType); + Variable declareVariable(std::string const& name, storm::expressions::Type const& variableType, bool auxiliary = false); /*! * Declares a new boolean variable with a name that must not yet exist and its corresponding type. Note that * the name must not start with two underscores since these variables are reserved for internal use only. * * @param name The name of the variable. + * @param auxiliary A flag indicating whether the new variable should be tagged as an auxiliary variable. * @return The newly declared variable. */ - Variable declareBooleanVariable(std::string const& name); + Variable declareBooleanVariable(std::string const& name, bool auxiliary = false); /*! * Declares a new integer variable with a name that must not yet exist and its corresponding type. Note that * the name must not start with two underscores since these variables are reserved for internal use only. * * @param name The name of the variable. + * @param auxiliary A flag indicating whether the new variable should be tagged as an auxiliary variable. * @return The newly declared variable. */ - Variable declareIntegerVariable(std::string const& name); - + Variable declareIntegerVariable(std::string const& name, bool auxiliary = false); + /*! - * Declares a new bounded integer variable with a name that must not yet exist and the bounded type of the + * Declares a new bit vector variable with a name that must not yet exist and the bounded type of the * given bit width. Note that the name must not start with two underscores since these variables are * reserved for internal use only. * * @param name The name of the variable. - * @param width The bit width of the bounded type. + * @param width The bit width of the bit vector type. + * @param auxiliary A flag indicating whether the new variable should be tagged as an auxiliary variable. * @return The newly declared variable. */ - Variable declareBoundedIntegerVariable(std::string const& name, std::size_t width); + Variable declareBitVectorVariable(std::string const& name, std::size_t width, bool auxiliary = false); /*! * Declares a new rational variable with a name that must not yet exist and its corresponding type. Note that * the name must not start with two underscores since these variables are reserved for internal use only. * * @param name The name of the variable. + * @param auxiliary A flag indicating whether the new variable should be tagged as an auxiliary variable. * @return The newly declared variable. */ - Variable declareRationalVariable(std::string const& name); - - /*! - * Declares an auxiliary variable with a name that must not yet exist and its corresponding type. - * - * @param name The name of the variable. - * @param variableType The type of the variable. - * @return The newly declared variable. - */ - Variable declareAuxiliaryVariable(std::string const& name, storm::expressions::Type const& variableType); + Variable declareRationalVariable(std::string const& name, bool auxiliary = false); /*! * Declares a variable with the given name if it does not yet exist. * * @param name The name of the variable to declare. * @param variableType The type of the variable to declare. + * @param auxiliary A flag indicating whether the new variable should be tagged as an auxiliary variable. * @return The variable. */ - Variable declareOrGetVariable(std::string const& name, storm::expressions::Type const& variableType); - - /*! - * Declares a variable with the given name if it does not yet exist. - * - * @param name The name of the variable to declare. - * @param variableType The type of the variable to declare. - * @return The variable. - */ - Variable declareOrGetAuxiliaryVariable(std::string const& name, storm::expressions::Type const& variableType); + Variable declareOrGetVariable(std::string const& name, storm::expressions::Type const& variableType, bool auxiliary = false); /*! * Retrieves the expression that represents the variable with the given name. @@ -240,17 +228,10 @@ namespace storm { * Declares a variable with the given type whose name is guaranteed to be unique and not yet in use. * * @param variableType The type of the variable to declare. + * @param auxiliary A flag indicating whether the new variable should be tagged as an auxiliary variable. * @return The variable. */ - Variable declareFreshVariable(storm::expressions::Type const& variableType); - - /*! - * Declares an auxiliary variable with the given type whose name is guaranteed to be unique and not yet in use. - * - * @param variableType The type of the variable to declare. - * @return The variable. - */ - Variable declareFreshAuxiliaryVariable(storm::expressions::Type const& variableType); + Variable declareFreshVariable(storm::expressions::Type const& variableType, bool auxiliary = false); /*! * Retrieves the number of variables. @@ -272,13 +253,13 @@ namespace storm { * @return The number of integer variables. */ uint_fast64_t getNumberOfIntegerVariables() const; - + /*! - * Retrieves the number of bounded integer variables. + * Retrieves the number of bit vector variables. * - * @return The number of bounded integer variables. + * @return The number of bit vector variables. */ - uint_fast64_t getNumberOfBoundedIntegerVariables() const; + uint_fast64_t getNumberOfBitVectorVariables() const; /*! * Retrieves the number of rational variables. @@ -287,41 +268,6 @@ namespace storm { */ uint_fast64_t getNumberOfRationalVariables() const; - /*! - * Retrieves the number of auxiliary variables. - * - * @return The number of auxiliary variables. - */ - uint_fast64_t getNumberOfAuxiliaryVariables() const; - - /*! - * Retrieves the number of auxiliary boolean variables. - * - * @return The number of auxiliary boolean variables. - */ - uint_fast64_t getNumberOfAuxiliaryBooleanVariables() const; - - /*! - * Retrieves the number of auxiliary integer variables. - * - * @return The number of auxiliary integer variables. - */ - uint_fast64_t getNumberOfAuxiliaryIntegerVariables() const; - - /*! - * Retrieves the number of auxiliary bounded integer variables. - * - * @return The number of auxiliary bounded integer variables. - */ - uint_fast64_t getNumberOfAuxiliaryBoundedIntegerVariables() const; - - /*! - * Retrieves the number of auxiliary rational variables. - * - * @return The number of auxiliary rational variables. - */ - uint_fast64_t getNumberOfAuxiliaryRationalVariables() const; - /*! * Retrieves the name of the variable with the given index. * @@ -375,12 +321,6 @@ namespace storm { std::shared_ptr getSharedPointer() const; private: - // A functor used for treating bit vector types of different bit widths equally when it comes to the variable - // count. - struct ManagerTypeEquality { - bool operator()(Type const& a, Type const& b) const; - }; - /*! * Checks whether the given variable name is valid. * @@ -438,28 +378,31 @@ namespace storm { uint_fast64_t numberOfVariables; // Store counts for variables. - std::unordered_map, ManagerTypeEquality> variableTypeToCountMapping; - - // Store counts for auxiliary variables. - std::unordered_map, ManagerTypeEquality> auxiliaryVariableTypeToCountMapping; - + uint_fast64_t numberOfBooleanVariables; + uint_fast64_t numberOfIntegerVariables; + uint_fast64_t numberOfBitVectorVariables; + uint_fast64_t numberOfRationalVariables; + // The number of declared auxiliary variables. uint_fast64_t numberOfAuxiliaryVariables; + // Store counts for auxiliary variables. + uint_fast64_t numberOfAuxiliaryBooleanVariables; + uint_fast64_t numberOfAuxiliaryIntegerVariables; + uint_fast64_t numberOfAuxiliaryBitVectorVariables; + uint_fast64_t numberOfAuxiliaryRationalVariables; + // A counter used to create fresh variables. uint_fast64_t freshVariableCounter; // The types managed by this manager. - mutable std::unique_ptr booleanType; - mutable std::unique_ptr integerType; - mutable std::unique_ptr rationalType; - mutable std::map boundedIntegerTypes; + mutable std::unordered_set types; // A mask that can be used to query whether a variable is an auxiliary variable. - static const uint64_t auxiliaryMask = (1ull << 60); + static const uint64_t auxiliaryMask = (1ull << 50); // A mask that can be used to project a variable index to its offset (with the group of equally typed variables). - static const uint64_t offsetMask = (1ull << 60) - 1; + static const uint64_t offsetMask = (1ull << 50) - 1; }; } } diff --git a/src/storage/expressions/SimpleValuation.cpp b/src/storage/expressions/SimpleValuation.cpp index bc9454815..b634b68c3 100644 --- a/src/storage/expressions/SimpleValuation.cpp +++ b/src/storage/expressions/SimpleValuation.cpp @@ -11,51 +11,24 @@ namespace storm { // Intentionally left empty. } - SimpleValuation::SimpleValuation(std::shared_ptr const& manager) : Valuation(manager), booleanValues(nullptr), integerValues(nullptr), boundedIntegerValues(nullptr), rationalValues(nullptr) { - if (this->getManager().getNumberOfBooleanVariables() > 0) { - booleanValues = std::unique_ptr>(new std::vector(this->getManager().getNumberOfBooleanVariables())); - } - if (this->getManager().getNumberOfIntegerVariables() > 0) { - integerValues = std::unique_ptr>(new std::vector(this->getManager().getNumberOfIntegerVariables())); - } - if (this->getManager().getNumberOfBoundedIntegerVariables() > 0) { - boundedIntegerValues = std::unique_ptr>(new std::vector(this->getManager().getNumberOfBoundedIntegerVariables())); - } - if (this->getManager().getNumberOfRationalVariables() > 0) { - rationalValues = std::unique_ptr>(new std::vector(this->getManager().getNumberOfRationalVariables())); - } + SimpleValuation::SimpleValuation(std::shared_ptr const& manager) : Valuation(manager), booleanValues(this->getManager().getNumberOfBooleanVariables()), integerValues(this->getManager().getNumberOfIntegerVariables() + this->getManager().getNumberOfBitVectorVariables()), rationalValues(this->getManager().getNumberOfRationalVariables()) { + // Intentionally left empty. } SimpleValuation::SimpleValuation(SimpleValuation const& other) : Valuation(other.getManager().getSharedPointer()) { - if (other.booleanValues != nullptr) { - booleanValues = std::unique_ptr>(new std::vector(*other.booleanValues)); - } - if (other.integerValues != nullptr) { - integerValues = std::unique_ptr>(new std::vector(*other.integerValues)); - } - if (other.boundedIntegerValues != nullptr) { - boundedIntegerValues = std::unique_ptr>(new std::vector(*other.boundedIntegerValues)); - } - if (other.rationalValues != nullptr) { - rationalValues = std::unique_ptr>(new std::vector(*other.rationalValues)); + if (this != &other) { + booleanValues = other.booleanValues; + integerValues = other.integerValues; + rationalValues = other.rationalValues; } } SimpleValuation& SimpleValuation::operator=(SimpleValuation const& other) { if (this != &other) { this->setManager(other.getManager().getSharedPointer()); - if (other.booleanValues != nullptr) { - booleanValues = std::unique_ptr>(new std::vector(*other.booleanValues)); - } - if (other.integerValues != nullptr) { - integerValues = std::unique_ptr>(new std::vector(*other.integerValues)); - } - if (other.boundedIntegerValues != nullptr) { - boundedIntegerValues = std::unique_ptr>(new std::vector(*other.boundedIntegerValues)); - } - if (other.booleanValues != nullptr) { - rationalValues = std::unique_ptr>(new std::vector(*other.rationalValues)); - } + booleanValues = other.booleanValues; + integerValues = other.integerValues; + rationalValues = other.rationalValues; } return *this; } @@ -83,40 +56,39 @@ namespace storm { } bool SimpleValuation::getBooleanValue(Variable const& booleanVariable) const { - return (*booleanValues)[booleanVariable.getOffset()]; + return booleanValues[booleanVariable.getOffset()]; } int_fast64_t SimpleValuation::getIntegerValue(Variable const& integerVariable) const { - return (*integerValues)[integerVariable.getOffset()]; + return integerValues[integerVariable.getOffset()]; } - int_fast64_t SimpleValuation::getBoundedIntegerValue(Variable const& integerVariable) const { - return (*boundedIntegerValues)[integerVariable.getOffset()]; + int_fast64_t SimpleValuation::getBitVectorValue(Variable const& bitVectorVariable) const { + return integerValues[bitVectorVariable.getOffset()]; } double SimpleValuation::getRationalValue(Variable const& rationalVariable) const { - return (*rationalValues)[rationalVariable.getOffset()]; + return rationalValues[rationalVariable.getOffset()]; } void SimpleValuation::setBooleanValue(Variable const& booleanVariable, bool value) { - (*booleanValues)[booleanVariable.getOffset()] = value; + booleanValues[booleanVariable.getOffset()] = value; } void SimpleValuation::setIntegerValue(Variable const& integerVariable, int_fast64_t value) { - (*integerValues)[integerVariable.getOffset()] = value; + integerValues[integerVariable.getOffset()] = value; } - void SimpleValuation::setBoundedIntegerValue(Variable const& integerVariable, int_fast64_t value) { - (*boundedIntegerValues)[integerVariable.getOffset()] = value; + void SimpleValuation::setBitVectorValue(Variable const& bitVectorVariable, int_fast64_t value) { + integerValues[bitVectorVariable.getOffset()] = value; } void SimpleValuation::setRationalValue(Variable const& rationalVariable, double value) { - (*rationalValues)[rationalVariable.getOffset()] = value; + rationalValues[rationalVariable.getOffset()] = value; } - + std::size_t SimpleValuationPointerHash::operator()(SimpleValuation* valuation) const { - size_t seed = 0; - boost::hash_combine(seed, valuation->booleanValues); + size_t seed = std::hash>()(valuation->booleanValues); boost::hash_combine(seed, valuation->integerValues); boost::hash_combine(seed, valuation->rationalValues); return seed; diff --git a/src/storage/expressions/SimpleValuation.h b/src/storage/expressions/SimpleValuation.h index d4418641b..a10f5a927 100644 --- a/src/storage/expressions/SimpleValuation.h +++ b/src/storage/expressions/SimpleValuation.h @@ -47,18 +47,17 @@ namespace storm { virtual bool getBooleanValue(Variable const& booleanVariable) const override; virtual void setBooleanValue(Variable const& booleanVariable, bool value) override; virtual int_fast64_t getIntegerValue(Variable const& integerVariable) const override; - virtual int_fast64_t getBoundedIntegerValue(Variable const& integerVariable) const override; + virtual int_fast64_t getBitVectorValue(Variable const& bitVectorVariable) const override; virtual void setIntegerValue(Variable const& integerVariable, int_fast64_t value) override; - virtual void setBoundedIntegerValue(Variable const& integerVariable, int_fast64_t value) override; + virtual void setBitVectorValue(Variable const& bitVectorVariable, int_fast64_t value) override; virtual double getRationalValue(Variable const& rationalVariable) const override; virtual void setRationalValue(Variable const& rationalVariable, double value) override; private: // Containers that store the values of the variables of the appropriate type. - std::unique_ptr> booleanValues; - std::unique_ptr> integerValues; - std::unique_ptr> boundedIntegerValues; - std::unique_ptr> rationalValues; + std::vector booleanValues; + std::vector integerValues; + std::vector rationalValues; }; /*! diff --git a/src/storage/expressions/Type.cpp b/src/storage/expressions/Type.cpp index e3357da72..5586e0398 100644 --- a/src/storage/expressions/Type.cpp +++ b/src/storage/expressions/Type.cpp @@ -1,6 +1,7 @@ #include "src/storage/expressions/Type.h" #include +#include #include "src/storage/expressions/ExpressionManager.h" #include "src/utility/macros.h" @@ -17,6 +18,10 @@ namespace storm { return this->getMask() == other.getMask(); } + bool BaseType::isErrorType() const { + return false; + } + bool BaseType::isBooleanType() const { return false; } @@ -32,15 +37,19 @@ namespace storm { bool IntegerType::isIntegerType() const { return true; } + + bool BitVectorType::isIntegerType() const { + return true; + } - bool BaseType::isBoundedIntegerType() const { + bool BaseType::isBitVectorType() const { return false; } - - bool BoundedIntegerType::isBoundedIntegerType() const { + + bool BitVectorType::isBitVectorType() const { return true; } - + bool BaseType::isRationalType() const { return false; } @@ -65,24 +74,24 @@ namespace storm { return "int"; } - BoundedIntegerType::BoundedIntegerType(std::size_t width) : width(width) { + BitVectorType::BitVectorType(std::size_t width) : width(width) { // Intentionally left empty. } - uint64_t BoundedIntegerType::getMask() const { - return BoundedIntegerType::mask; + uint64_t BitVectorType::getMask() const { + return BitVectorType::mask; } - std::string BoundedIntegerType::getStringRepresentation() const { - return "int[" + std::to_string(width) + "]"; + std::string BitVectorType::getStringRepresentation() const { + return "bv[" + std::to_string(width) + "]"; } - std::size_t BoundedIntegerType::getWidth() const { + std::size_t BitVectorType::getWidth() const { return width; } - bool BoundedIntegerType::operator==(BaseType const& other) const { - return this->getMask() == other.getMask() && this->width == static_cast(other).width; + bool BitVectorType::operator==(BaseType const& other) const { + return BaseType::operator==(other) && this->width == static_cast(other).getWidth(); } uint64_t RationalType::getMask() const { @@ -93,6 +102,16 @@ namespace storm { return "rational"; } + bool operator<(BaseType const& first, BaseType const& second) { + if (first.getMask() < second.getMask()) { + return true; + } + if (first.isBitVectorType() && second.isBitVectorType()) { + return static_cast(first).getWidth() < static_cast(second).getWidth(); + } + return false; + } + Type::Type() : manager(nullptr), innerType(nullptr) { // Intentionally left empty. } @@ -109,32 +128,28 @@ namespace storm { return this->innerType->getMask(); } - std::string Type::getStringRepresentation() const { - return this->innerType->getStringRepresentation(); - } - - bool Type::isNumericalType() const { - return this->isIntegerType() || this->isRationalType(); + bool Type::isBooleanType() const { + return this->innerType->isBooleanType(); } bool Type::isIntegerType() const { - return this->isUnboundedIntegerType() || this->isBoundedIntegerType(); + return this->innerType->isIntegerType(); } - bool Type::isBooleanType() const { - return this->innerType->isBooleanType(); + bool Type::isBitVectorType() const { + return this->innerType->isBitVectorType(); } - bool Type::isUnboundedIntegerType() const { - return this->innerType->isIntegerType(); + bool Type::isNumericalType() const { + return this->isIntegerType() || this->isRationalType(); } - bool Type::isBoundedIntegerType() const { - return this->innerType->isBoundedIntegerType(); + std::string Type::getStringRepresentation() const { + return this->innerType->getStringRepresentation(); } std::size_t Type::getWidth() const { - return dynamic_cast(*this->innerType).getWidth(); + return static_cast(*this->innerType).getWidth(); } bool Type::isRationalType() const { @@ -147,10 +162,7 @@ namespace storm { Type Type::plusMinusTimes(Type const& other) const { STORM_LOG_THROW(this->isNumericalType() && other.isNumericalType(), storm::exceptions::InvalidTypeException, "Operator requires numerical operands."); - if (this->isRationalType() || other.isRationalType()) { - return this->getManager().getRationalType(); - } - return getManager().getIntegerType(); + return std::max(*this, other); } Type Type::minus() const { @@ -160,18 +172,14 @@ namespace storm { Type Type::divide(Type const& other) const { STORM_LOG_THROW(this->isNumericalType() && other.isNumericalType(), storm::exceptions::InvalidTypeException, "Operator requires numerical operands."); - if (this->isRationalType() || other.isRationalType()) { - return this->getManager().getRationalType(); - } - return this->getManager().getIntegerType(); + STORM_LOG_THROW(!this->isBitVectorType() && !other.isBitVectorType(), storm::exceptions::InvalidTypeException, "Operator requires non-bitvector operands."); + return std::max(*this, other); } Type Type::power(Type const& other) const { STORM_LOG_THROW(this->isNumericalType() && other.isNumericalType(), storm::exceptions::InvalidTypeException, "Operator requires numerical operands."); - if (this->isRationalType() || other.isRationalType()) { - return getManager().getRationalType(); - } - return this->getManager().getIntegerType(); + STORM_LOG_THROW(!this->isBitVectorType() && !other.isBitVectorType(), storm::exceptions::InvalidTypeException, "Operator requires non-bitvector operands."); + return std::max(*this, other); } Type Type::logicalConnective(Type const& other) const { @@ -194,12 +202,8 @@ namespace storm { if (thenType == elseType) { return thenType; } else { - STORM_LOG_THROW(thenType.isNumericalType() == elseType.isNumericalType(), storm::exceptions::InvalidTypeException, "Operator 'ite' requires proper types."); - if (thenType.isRationalType() || elseType.isRationalType()) { - return this->getManager().getRationalType(); - } else { - return this->getManager().getIntegerType(); - } + STORM_LOG_THROW(thenType.isNumericalType() && elseType.isNumericalType(), storm::exceptions::InvalidTypeException, "Operator 'ite' requires proper types."); + return std::max(thenType, elseType); } return thenType; } @@ -211,10 +215,11 @@ namespace storm { Type Type::minimumMaximum(Type const& other) const { STORM_LOG_THROW(this->isNumericalType() && other.isNumericalType(), storm::exceptions::InvalidTypeException, "Operator requires numerical operands."); - if (this->isRationalType() || other.isRationalType()) { - return this->getManager().getRationalType(); - } - return this->getManager().getIntegerType(); + return std::max(*this, other); + } + + bool operator<(storm::expressions::Type const& type1, storm::expressions::Type const& type2) { + return *type1.innerType < *type2.innerType; } std::ostream& operator<<(std::ostream& stream, Type const& type) { diff --git a/src/storage/expressions/Type.h b/src/storage/expressions/Type.h index 22ebc94ca..bf677b2d0 100644 --- a/src/storage/expressions/Type.h +++ b/src/storage/expressions/Type.h @@ -38,9 +38,10 @@ namespace storm { */ virtual std::string getStringRepresentation() const = 0; + virtual bool isErrorType() const; virtual bool isBooleanType() const; virtual bool isIntegerType() const; - virtual bool isBoundedIntegerType() const; + virtual bool isBitVectorType() const; virtual bool isRationalType() const; }; @@ -51,7 +52,7 @@ namespace storm { virtual bool isBooleanType() const override; private: - static const uint64_t mask = (1ull << 61); + static const uint64_t mask = (1ull << 60); }; class IntegerType : public BaseType { @@ -64,14 +65,14 @@ namespace storm { static const uint64_t mask = (1ull << 62); }; - class BoundedIntegerType : public BaseType { + class BitVectorType : public BaseType { public: /*! - * Creates a new bounded integer type with the given bit width. + * Creates a new bounded bitvector type with the given bit width. * * @param width The bit width of the type. */ - BoundedIntegerType(std::size_t width); + BitVectorType(std::size_t width); /*! * Retrieves the bit width of the bounded type. @@ -80,16 +81,14 @@ namespace storm { */ std::size_t getWidth() const; - virtual uint64_t getMask() const override; - virtual bool operator==(BaseType const& other) const override; - + virtual uint64_t getMask() const override; virtual std::string getStringRepresentation() const override; - - virtual bool isBoundedIntegerType() const override; + virtual bool isIntegerType() const override; + virtual bool isBitVectorType() const override; private: - static const uint64_t mask = (1ull << 61) | (1ull << 62); + static const uint64_t mask = (1ull << 61); // The bit width of the type. std::size_t width; @@ -108,15 +107,19 @@ namespace storm { class ErrorType : public BaseType { public: virtual uint64_t getMask() const override; - virtual std::string getStringRepresentation() const override; + virtual bool isErrorType() const override; private: static const uint64_t mask = 0; }; + bool operator<(BaseType const& first, BaseType const& second); + class Type { public: + friend bool operator<(storm::expressions::Type const& type1, storm::expressions::Type const& type2); + Type(); /*! @@ -150,11 +153,11 @@ namespace storm { std::string getStringRepresentation() const; /*! - * Checks whether this type is a numerical type. + * Checks whether this type is a boolean type. * - * @return True iff the type is a numerical one. + * @return True iff the type is a boolean one. */ - bool isNumericalType() const; + bool isBooleanType() const; /*! * Checks whether this type is an integral type. @@ -162,41 +165,34 @@ namespace storm { * @return True iff the type is a integral one. */ bool isIntegerType() const; - + /*! - * Checks whether this type is a boolean type. + * Checks whether this type is a bitvector type. * - * @return True iff the type is a boolean one. + * @return True iff the type is a bitvector one. */ - bool isBooleanType() const; - + bool isBitVectorType() const; + /*! - * Checks whether this type is an unbounded integral type. + * Checks whether this type is a rational type. * - * @return True iff the type is a unbounded integral one. + * @return True iff the type is a rational one. */ - bool isUnboundedIntegerType() const; - + bool isRationalType() const; + /*! - * Checks whether this type is a bounded integral type. + * Checks whether this type is a numerical type. * - * @return True iff the type is a bounded integral one. + * @return True iff the type is a numerical one. */ - bool isBoundedIntegerType() const; + bool isNumericalType() const; /*! - * Retrieves the bit width of the type, provided that it is a bounded integral type. + * Retrieves the bit width of the type, provided that it is a bitvector type. * - * @return The bit width of the bounded integral type. + * @return The bit width of the bitvector type. */ std::size_t getWidth() const; - - /*! - * Checks whether this type is a rational type. - * - * @return True iff the type is a rational one. - */ - bool isRationalType() const; /*! * Retrieves the manager of the type. @@ -226,7 +222,8 @@ namespace storm { }; std::ostream& operator<<(std::ostream& stream, Type const& type); - + + bool operator<(storm::expressions::Type const& type1, storm::expressions::Type const& type2); } } @@ -238,14 +235,6 @@ namespace std { return std::hash()(type.getMask()); } }; - - // Provide a less operator, so we can put types in ordered collections. - template <> - struct less { - std::size_t operator()(storm::expressions::Type const& type1, storm::expressions::Type const& type2) const { - return type1.getMask() < type2.getMask(); - } - }; } #endif /* STORM_STORAGE_EXPRESSIONS_EXPRESSIONRETURNTYPE_H_ */ \ No newline at end of file diff --git a/src/storage/expressions/Valuation.cpp b/src/storage/expressions/Valuation.cpp index 9696e7d8d..8b393c17c 100644 --- a/src/storage/expressions/Valuation.cpp +++ b/src/storage/expressions/Valuation.cpp @@ -7,6 +7,10 @@ namespace storm { // Intentionally left empty. } + Valuation::~Valuation() { + // Intentionally left empty. + } + ExpressionManager const& Valuation::getManager() const { return *manager; } diff --git a/src/storage/expressions/Valuation.h b/src/storage/expressions/Valuation.h index 8b420a3ec..0aeeb8eca 100644 --- a/src/storage/expressions/Valuation.h +++ b/src/storage/expressions/Valuation.h @@ -23,6 +23,11 @@ namespace storm { */ Valuation(std::shared_ptr const& manager); + /*! + * Declare virtual destructor, so we can properly delete instances later. + */ + virtual ~Valuation(); + /*! * Retrieves the value of the given boolean variable. * @@ -48,12 +53,12 @@ namespace storm { virtual int_fast64_t getIntegerValue(Variable const& integerVariable) const = 0; /*! - * Retrieves the value of the given bounded integer variable. + * Retrieves the value of the given bit vector variable. * - * @param integerVariable The bounded integer variable whose value to retrieve. - * @return The value of the bounded integer variable. + * @param bitVectorVariable The bit vector variable whose value to retrieve. + * @return The value of the bit vector variable. */ - virtual int_fast64_t getBoundedIntegerValue(Variable const& integerVariable) const = 0; + virtual int_fast64_t getBitVectorValue(Variable const& bitVectorVariable) const = 0; /*! * Sets the value of the given integer variable to the provided value. @@ -64,12 +69,12 @@ namespace storm { virtual void setIntegerValue(Variable const& integerVariable, int_fast64_t value) = 0; /*! - * Sets the value of the given bounded integer variable to the provided value. + * Sets the value of the given bit vector variable to the provided value. * - * @param integerVariable The variable whose value to set. + * @param bitVectorVariable The variable whose value to set. * @param value The new value of the variable. */ - virtual void setBoundedIntegerValue(Variable const& integerVariable, int_fast64_t value) = 0; + virtual void setBitVectorValue(Variable const& bitVectorVariable, int_fast64_t value) = 0; /*! * Retrieves the value of the given rational variable. diff --git a/src/storage/expressions/Variable.cpp b/src/storage/expressions/Variable.cpp index 1423aee11..a5f425215 100644 --- a/src/storage/expressions/Variable.cpp +++ b/src/storage/expressions/Variable.cpp @@ -47,8 +47,8 @@ namespace storm { return this->getType().isIntegerType(); } - bool Variable::hasBoundedIntegerType() const { - return this->getType().isBoundedIntegerType(); + bool Variable::hasBitVectorType() const { + return this->getType().isBitVectorType(); } bool Variable::hasRationalType() const { diff --git a/src/storage/expressions/Variable.h b/src/storage/expressions/Variable.h index d1bb31b7e..c3a18eab0 100644 --- a/src/storage/expressions/Variable.h +++ b/src/storage/expressions/Variable.h @@ -106,11 +106,11 @@ namespace storm { bool hasIntegerType() const; /*! - * Checks whether the variable is of integral type. + * Checks whether the variable is of a bit vector type. * - * @return True iff the variable if of integral type. + * @return True iff the variable is of a bit vector type. */ - bool hasBoundedIntegerType() const; + bool hasBitVectorType() const; /*! * Checks whether the variable is of rational type. diff --git a/src/storage/prism/Assignment.cpp b/src/storage/prism/Assignment.cpp index 4d1211959..c4973e231 100644 --- a/src/storage/prism/Assignment.cpp +++ b/src/storage/prism/Assignment.cpp @@ -2,12 +2,16 @@ namespace storm { namespace prism { - Assignment::Assignment(std::string const& variableName, storm::expressions::Expression const& expression, std::string const& filename, uint_fast64_t lineNumber) : LocatedInformation(filename, lineNumber), variableName(variableName), expression(expression) { + Assignment::Assignment(storm::expressions::Variable const& variable, storm::expressions::Expression const& expression, std::string const& filename, uint_fast64_t lineNumber) : LocatedInformation(filename, lineNumber), variable(variable), expression(expression) { // Intentionally left empty. } std::string const& Assignment::getVariableName() const { - return variableName; + return variable.getName(); + } + + storm::expressions::Variable const& Assignment::getVariable() const { + return variable; } storm::expressions::Expression const& Assignment::getExpression() const { @@ -15,7 +19,7 @@ namespace storm { } Assignment Assignment::substitute(std::map const& substitution) const { - return Assignment(this->getVariableName(), this->getExpression().substitute(substitution), this->getFilename(), this->getLineNumber()); + return Assignment(this->getVariable(), this->getExpression().substitute(substitution), this->getFilename(), this->getLineNumber()); } std::ostream& operator<<(std::ostream& stream, Assignment const& assignment) { diff --git a/src/storage/prism/Assignment.h b/src/storage/prism/Assignment.h index a22154650..cb6006ba7 100644 --- a/src/storage/prism/Assignment.h +++ b/src/storage/prism/Assignment.h @@ -15,12 +15,12 @@ namespace storm { /*! * Constructs an assignment using the given variable name and expression. * - * @param variableName The variable that this assignment targets. + * @param variable The variable that this assignment targets. * @param expression The expression to assign to the variable. * @param filename The filename in which the assignment is defined. * @param lineNumber The line number in which the assignment is defined. */ - Assignment(std::string const& variableName, storm::expressions::Expression const& expression, std::string const& filename = "", uint_fast64_t lineNumber = 0); + Assignment(storm::expressions::Variable const& variable, storm::expressions::Expression const& expression, std::string const& filename = "", uint_fast64_t lineNumber = 0); // Create default implementations of constructors/assignment. Assignment() = default; @@ -38,6 +38,13 @@ namespace storm { */ std::string const& getVariableName() const; + /*! + * Retrieves the variable that is written to by this assignment. + * + * @return The variable that is written to by this assignment. + */ + storm::expressions::Variable const& getVariable() const; + /*! * Retrieves the expression that is assigned to the variable. * @@ -56,8 +63,8 @@ namespace storm { friend std::ostream& operator<<(std::ostream& stream, Assignment const& assignment); private: - // The name of the variable that this assignment targets. - std::string variableName; + // The variable written in this assignment. + storm::expressions::Variable variable; // The expression that is assigned to the variable. storm::expressions::Expression expression;