diff --git a/src/adapters/SymbolicExpressionAdapter.h b/src/adapters/SymbolicExpressionAdapter.h new file mode 100644 index 000000000..2e07bae83 --- /dev/null +++ b/src/adapters/SymbolicExpressionAdapter.h @@ -0,0 +1,231 @@ +/* + * SymbolicExpressionAdapter.h + * + * Created on: 27.01.2013 + * Author: Christian Dehnert + */ + +#ifndef STORM_ADAPTERS_SYMBOLICEXPRESSIONADAPTER_H_ +#define STORM_ADAPTERS_SYMBOLICEXPRESSIONADAPTER_H_ + +#include "src/ir/expressions/ExpressionVisitor.h" +#include "src/exceptions/ExpressionEvaluationException.h" + +#include "cuddObj.hh" + +#include +#include + +namespace storm { + +namespace adapters { + +class SymbolicExpressionAdapter : public storm::ir::expressions::ExpressionVisitor { +public: + SymbolicExpressionAdapter(std::unordered_map>& variableToDecisionDiagramVariableMap) : stack(), variableToDecisionDiagramVariableMap(variableToDecisionDiagramVariableMap) { + + } + + ADD* translateExpression(std::shared_ptr expression) { + expression->accept(this); + return stack.top(); + } + + virtual void visit(storm::ir::expressions::BaseExpression* expression) { + std::cout << expression->toString() << std::endl; + throw storm::exceptions::ExpressionEvaluationException() << "Cannot evaluate expression " + << " of abstract superclass type."; + } + + virtual void visit(storm::ir::expressions::BinaryBooleanFunctionExpression* expression) { + expression->getLeft()->accept(this); + expression->getRight()->accept(this); + + ADD* rightResult = stack.top(); + stack.pop(); + ADD* leftResult = stack.top(); + stack.pop(); + + switch(expression->getFunctionType()) { + case storm::ir::expressions::BinaryBooleanFunctionExpression::AND: + stack.push(new ADD(leftResult->Times(*rightResult))); + break; + case storm::ir::expressions::BinaryBooleanFunctionExpression::OR: + stack.push(new ADD(leftResult->Plus(*rightResult))); + break; + default: throw storm::exceptions::ExpressionEvaluationException() << "Cannot evaluate expression: " + << "Unknown boolean binary operator: '" << expression->getFunctionType() << "'."; + } + + // delete leftResult; + // delete rightResult; + } + + virtual void visit(storm::ir::expressions::BinaryNumericalFunctionExpression* expression) { + expression->getLeft()->accept(this); + expression->getRight()->accept(this); + + ADD* rightResult = stack.top(); + stack.pop(); + ADD* leftResult = stack.top(); + stack.pop(); + + switch(expression->getFunctionType()) { + case storm::ir::expressions::BinaryNumericalFunctionExpression::PLUS: + stack.push(new ADD(leftResult->Plus(*rightResult))); + break; + case storm::ir::expressions::BinaryNumericalFunctionExpression::MINUS: + stack.push(new ADD(leftResult->Minus(*rightResult))); + break; + case storm::ir::expressions::BinaryNumericalFunctionExpression::TIMES: + stack.push(new ADD(leftResult->Times(*rightResult))); + break; + case storm::ir::expressions::BinaryNumericalFunctionExpression::DIVIDE: + stack.push(new ADD(leftResult->Divide(*rightResult))); + break; + default: throw storm::exceptions::ExpressionEvaluationException() << "Cannot evaluate expression: " + << "Unknown boolean binary operator: '" << expression->getFunctionType() << "'."; + } + } + + virtual void visit(storm::ir::expressions::BinaryRelationExpression* expression) { + expression->getLeft()->accept(this); + expression->getRight()->accept(this); + + ADD* rightResult = stack.top(); + stack.pop(); + ADD* leftResult = stack.top(); + stack.pop(); + + switch(expression->getRelationType()) { + case storm::ir::expressions::BinaryRelationExpression::EQUAL: + stack.push(new ADD(leftResult->Equals(*rightResult))); + break; + case storm::ir::expressions::BinaryRelationExpression::NOT_EQUAL: + stack.push(new ADD(leftResult->NotEquals(*rightResult))); + break; + case storm::ir::expressions::BinaryRelationExpression::LESS: + stack.push(new ADD(leftResult->LessThan(*rightResult))); + break; + case storm::ir::expressions::BinaryRelationExpression::LESS_OR_EQUAL: + stack.push(new ADD(leftResult->LessThanOrEqual(*rightResult))); + break; + case storm::ir::expressions::BinaryRelationExpression::GREATER: + stack.push(new ADD(leftResult->GreaterThan(*rightResult))); + break; + case storm::ir::expressions::BinaryRelationExpression::GREATER_OR_EQUAL: + stack.push(new ADD(leftResult->GreaterThanOrEqual(*rightResult))); + break; + default: throw storm::exceptions::ExpressionEvaluationException() << "Cannot evaluate expression: " + << "Unknown boolean binary operator: '" << expression->getRelationType() << "'."; + } + } + + virtual void visit(storm::ir::expressions::BooleanConstantExpression* expression) { + if (!expression->isDefined()) { + throw storm::exceptions::ExpressionEvaluationException() << "Cannot evaluate expression: " + << "Boolean constant '" << expression->getConstantName() << "' is undefined."; + } + + storm::utility::CuddUtility* cuddUtility = storm::utility::cuddUtilityInstance(); + stack.push(new ADD(*cuddUtility->getConstant(expression->getValue() ? 1 : 0))); + } + + virtual void visit(storm::ir::expressions::BooleanLiteral* expression) { + storm::utility::CuddUtility* cuddUtility = storm::utility::cuddUtilityInstance(); + stack.push(new ADD(*cuddUtility->getConstant(expression->getValueAsBool(nullptr) ? 1 : 0))); + } + + virtual void visit(storm::ir::expressions::DoubleConstantExpression* expression) { + if (expression->isDefined()) { + throw storm::exceptions::ExpressionEvaluationException() << "Cannot evaluate expression: " + << "Double constant '" << expression->getConstantName() << "' is undefined."; + } + + storm::utility::CuddUtility* cuddUtility = storm::utility::cuddUtilityInstance(); + stack.push(new ADD(*cuddUtility->getConstant(expression->getValue()))); + } + + virtual void visit(storm::ir::expressions::DoubleLiteral* expression) { + storm::utility::CuddUtility* cuddUtility = storm::utility::cuddUtilityInstance(); + stack.push(new ADD(*cuddUtility->getConstant(expression->getValueAsDouble(nullptr)))); + } + + virtual void visit(storm::ir::expressions::IntegerConstantExpression* expression) { + if (!expression->isDefined()) { + throw storm::exceptions::ExpressionEvaluationException() << "Cannot evaluate expression: " + << "Integer constant '" << expression->getConstantName() << "' is undefined."; + } + + storm::utility::CuddUtility* cuddUtility = storm::utility::cuddUtilityInstance(); + stack.push(new ADD(*cuddUtility->getConstant(expression->getValue()))); + } + + virtual void visit(storm::ir::expressions::IntegerLiteral* expression) { + storm::utility::CuddUtility* cuddUtility = storm::utility::cuddUtilityInstance(); + stack.push(new ADD(*cuddUtility->getConstant(expression->getValueAsInt(nullptr)))); + } + + virtual void visit(storm::ir::expressions::UnaryBooleanFunctionExpression* expression) { + expression->getChild()->accept(this); + + ADD* childResult = stack.top(); + stack.pop(); + + switch (expression->getFunctionType()) { + case storm::ir::expressions::UnaryBooleanFunctionExpression::NOT: + stack.push(new ADD(~(*childResult))); + break; + default: throw storm::exceptions::ExpressionEvaluationException() << "Cannot evaluate expression: " + << "Unknown boolean unary operator: '" << expression->getFunctionType() << "'."; + } + } + + virtual void visit(storm::ir::expressions::UnaryNumericalFunctionExpression* expression) { + expression->getChild()->accept(this); + + ADD* childResult = stack.top(); + stack.pop(); + + storm::utility::CuddUtility* cuddUtility = storm::utility::cuddUtilityInstance(); + ADD* result = cuddUtility->getConstant(0); + switch(expression->getFunctionType()) { + case storm::ir::expressions::UnaryNumericalFunctionExpression::MINUS: + stack.push(new ADD(result->Minus(*childResult))); + break; + default: throw storm::exceptions::ExpressionEvaluationException() << "Cannot evaluate expression: " + << "Unknown numerical unary operator: '" << expression->getFunctionType() << "'."; + } + + } + + virtual void visit(storm::ir::expressions::VariableExpression* expression) { + storm::utility::CuddUtility* cuddUtility = storm::utility::cuddUtilityInstance(); + + std::vector const& variables = variableToDecisionDiagramVariableMap[expression->getVariableName()]; + + ADD* result = cuddUtility->getConstant(0); + if (expression->getType() == storm::ir::expressions::BaseExpression::bool_) { + cuddUtility->setValueAtIndex(result, 1, variables, 1); + } else { + int64_t low = expression->getLowerBound()->getValueAsInt(nullptr); + int64_t high = expression->getUpperBound()->getValueAsInt(nullptr); + + for (uint_fast64_t i = low; i <= high; ++i) { + cuddUtility->setValueAtIndex(result, i - low, variables, i); + } + } + + stack.push(result); + } + +private: + std::stack stack; + std::unordered_map>& variableToDecisionDiagramVariableMap; +}; + +} // namespace adapters + +} // namespace storm + +#endif /* STORM_ADAPTERS_SYMBOLICEXPRESSIONADAPTER_H_ */ diff --git a/src/adapters/SymbolicModelAdapter.h b/src/adapters/SymbolicModelAdapter.h index 1cdbe7c04..7e8674dae 100644 --- a/src/adapters/SymbolicModelAdapter.h +++ b/src/adapters/SymbolicModelAdapter.h @@ -11,10 +11,11 @@ #include "src/exceptions/WrongFileFormatException.h" #include "src/utility/CuddUtility.h" -#include "src/ir/expressions/ExpressionVisitor.h" +#include "SymbolicExpressionAdapter.h" #include "cuddObj.hh" #include +#include namespace storm { @@ -23,24 +24,99 @@ namespace adapters { class SymbolicModelAdapter { public: - SymbolicModelAdapter() : cuddUtility(storm::utility::cuddUtilityInstance()) { + SymbolicModelAdapter() : cuddUtility(storm::utility::cuddUtilityInstance()), allDecisionDiagramVariables(), + allRowDecisionDiagramVariables(), allColumnDecisionDiagramVariables(), booleanRowDecisionDiagramVariables(), + integerRowDecisionDiagramVariables(), booleanColumnDecisionDiagramVariables(), integerColumnDecisionDiagramVariables(), + variableToRowDecisionDiagramVariableMap(), variableToColumnDecisionDiagramVariableMap(), + variableToIdentityDecisionDiagramMap(), + rowExpressionAdapter(variableToRowDecisionDiagramVariableMap), columnExpressionAdapter(variableToColumnDecisionDiagramVariableMap) { } void toMTBDD(storm::ir::Program const& program) { LOG4CPLUS_INFO(logger, "Creating MTBDD representation for probabilistic program."); createDecisionDiagramVariables(program); + createIdentityDecisionDiagrams(program); + ADD* systemAdd = cuddUtility->getZero(); for (uint_fast64_t i = 0; i < program.getNumberOfModules(); ++i) { storm::ir::Module const& module = program.getModule(i); + ADD* moduleAdd = cuddUtility->getZero(); for (uint_fast64_t j = 0; j < module.getNumberOfCommands(); ++j) { storm::ir::Command const& command = module.getCommand(j); + ADD* commandAdd = cuddUtility->getZero(); + ADD* guard = rowExpressionAdapter.translateExpression(command.getGuard()); + if (*guard != *cuddUtility->getZero()) { + for (uint_fast64_t i = 0; i < command.getNumberOfUpdates(); ++i) { + ADD* updateAdd = cuddUtility->getOne(); + + storm::ir::Update const& update = command.getUpdate(i); + + std::map booleanAssignments = update.getBooleanAssignments(); + for (auto assignmentPair : booleanAssignments) { + ADD* updateExpr = rowExpressionAdapter.translateExpression(assignmentPair.second.getExpression()); + + ADD* temporary = cuddUtility->getZero(); + cuddUtility->setValueAtIndex(temporary, 0, variableToColumnDecisionDiagramVariableMap[assignmentPair.first], 0); + cuddUtility->setValueAtIndex(temporary, 1, variableToColumnDecisionDiagramVariableMap[assignmentPair.first], 1); + + ADD* result = new ADD(*updateExpr * *guard); + result = new ADD(result->Equals(*temporary)); + + *updateAdd = *updateAdd * *result; + } + + std::map integerAssignments = update.getIntegerAssignments(); + for (auto assignmentPair : integerAssignments) { + ADD* updateExpr = rowExpressionAdapter.translateExpression(assignmentPair.second.getExpression()); + + ADD* temporary = cuddUtility->getZero(); + + uint_fast64_t variableIndex = module.getIntegerVariableIndex(assignmentPair.first); + storm::ir::IntegerVariable integerVariable = module.getIntegerVariable(variableIndex); + int_fast64_t low = integerVariable.getLowerBound()->getValueAsInt(nullptr); + int_fast64_t high = integerVariable.getUpperBound()->getValueAsInt(nullptr); + + for (uint_fast64_t i = low; i <= high; ++i) { + cuddUtility->setValueAtIndex(temporary, i - low, variableToColumnDecisionDiagramVariableMap[assignmentPair.first], i); + } + + ADD* result = new ADD(*updateExpr * *guard); + result = new ADD(result->Equals(*temporary)); + *result *= *guard; + + *updateAdd = *updateAdd * *result; + } + for (uint_fast64_t i = 0; i < module.getNumberOfBooleanVariables(); ++i) { + storm::ir::BooleanVariable const& booleanVariable = module.getBooleanVariable(i); + + if (update.getBooleanAssignments().find(booleanVariable.getName()) == update.getBooleanAssignments().end()) { + *updateAdd = *updateAdd * *variableToIdentityDecisionDiagramMap[booleanVariable.getName()]; + } + } + for (uint_fast64_t i = 0; i < module.getNumberOfIntegerVariables(); ++i) { + storm::ir::IntegerVariable const& integerVariable = module.getIntegerVariable(i); + + if (update.getIntegerAssignments().find(integerVariable.getName()) == update.getIntegerAssignments().end()) { + *updateAdd = *updateAdd * *variableToIdentityDecisionDiagramMap[integerVariable.getName()]; + } + } + + *commandAdd += *updateAdd * *cuddUtility->getConstant(update.getLikelihoodExpression()->getValueAsDouble(nullptr)); + } + *moduleAdd += *commandAdd; + } else { + LOG4CPLUS_WARN(logger, "Guard " << command.getGuard()->toString() << " is unsatisfiable."); + } } + *systemAdd += *moduleAdd; } + performReachability(program, systemAdd); + LOG4CPLUS_INFO(logger, "Done creating MTBDD representation for probabilistic program."); } @@ -57,6 +133,108 @@ private: std::unordered_map> variableToRowDecisionDiagramVariableMap; std::unordered_map> variableToColumnDecisionDiagramVariableMap; + std::unordered_map variableToIdentityDecisionDiagramMap; + + SymbolicExpressionAdapter rowExpressionAdapter; + SymbolicExpressionAdapter columnExpressionAdapter; + + ADD* getInitialStateDecisionDiagram(storm::ir::Program const& program) { + ADD* initialStates = cuddUtility->getOne(); + for (uint_fast64_t i = 0; i < program.getNumberOfModules(); ++i) { + storm::ir::Module const& module = program.getModule(i); + + for (uint_fast64_t j = 0; j < module.getNumberOfBooleanVariables(); ++j) { + storm::ir::BooleanVariable const& booleanVariable = module.getBooleanVariable(j); + bool initialValue = booleanVariable.getInitialValue()->getValueAsBool(nullptr); + *initialStates *= *cuddUtility->getConstantEncoding(1, variableToRowDecisionDiagramVariableMap[booleanVariable.getName()]); + } + for (uint_fast64_t j = 0; j < module.getNumberOfIntegerVariables(); ++j) { + storm::ir::IntegerVariable const& integerVariable = module.getIntegerVariable(j); + int_fast64_t initialValue = integerVariable.getInitialValue()->getValueAsInt(nullptr); + int_fast64_t low = integerVariable.getLowerBound()->getValueAsInt(nullptr); + *initialStates *= *cuddUtility->getConstantEncoding(initialValue - low, variableToRowDecisionDiagramVariableMap[integerVariable.getName()]); + } + } + + cuddUtility->dumpDotToFile(initialStates, "initstates.add"); + return initialStates; + } + + void performReachability(storm::ir::Program const& program, ADD* systemAdd) { + cuddUtility->dumpDotToFile(systemAdd, "reachtransold.add"); + ADD* reachableStates = getInitialStateDecisionDiagram(program); + ADD* newReachableStates = reachableStates; + + ADD* rowCube = cuddUtility->getOne(); + for (auto variablePtr : allRowDecisionDiagramVariables) { + *rowCube *= *variablePtr; + } + + bool changed; + int iter = 0; + do { + changed = false; + std::cout << "iter " << ++iter << std::endl; + + *newReachableStates = *reachableStates * *systemAdd; + newReachableStates->ExistAbstract(*rowCube); + + cuddUtility->dumpDotToFile(newReachableStates, "reach1.add"); + + cuddUtility->permuteVariables(newReachableStates, allColumnDecisionDiagramVariables, allRowDecisionDiagramVariables, allDecisionDiagramVariables.size()); + + cuddUtility->dumpDotToFile(newReachableStates, "reach2.add"); + cuddUtility->dumpDotToFile(reachableStates, "reachplus.add"); + *newReachableStates += *reachableStates; + + cuddUtility->dumpDotToFile(newReachableStates, "reach3.add"); + cuddUtility->dumpDotToFile(reachableStates, "reach4.add"); + + if (*newReachableStates != *reachableStates) { + std::cout << "changed ..." << std::endl; + changed = true; + } + *reachableStates = *newReachableStates; + } while (changed); + + *systemAdd *= *reachableStates; + cuddUtility->dumpDotToFile(systemAdd, "reachtrans.add"); + } + + void createIdentityDecisionDiagrams(storm::ir::Program const& program) { + for (uint_fast64_t i = 0; i < program.getNumberOfModules(); ++i) { + storm::ir::Module const& module = program.getModule(i); + + for (uint_fast64_t j = 0; j < module.getNumberOfBooleanVariables(); ++j) { + storm::ir::BooleanVariable const& booleanVariable = module.getBooleanVariable(j); + ADD* identity = cuddUtility->getZero(); + cuddUtility->setValueAtIndices(identity, 0, 0, + variableToRowDecisionDiagramVariableMap[booleanVariable.getName()], + variableToColumnDecisionDiagramVariableMap[booleanVariable.getName()], 1); + cuddUtility->setValueAtIndices(identity, 1, 1, + variableToRowDecisionDiagramVariableMap[booleanVariable.getName()], + variableToColumnDecisionDiagramVariableMap[booleanVariable.getName()], 1); + variableToIdentityDecisionDiagramMap[booleanVariable.getName()] = identity; + } + + for (uint_fast64_t j = 0; j < module.getNumberOfIntegerVariables(); ++j) { + storm::ir::IntegerVariable const& integerVariable = module.getIntegerVariable(j); + + ADD* identity = cuddUtility->getZero(); + + int_fast64_t low = integerVariable.getLowerBound()->getValueAsInt(nullptr); + int_fast64_t high = integerVariable.getUpperBound()->getValueAsInt(nullptr); + + for (uint_fast64_t i = low; i <= high; ++i) { + cuddUtility->setValueAtIndices(identity, i - low, i - low, + variableToRowDecisionDiagramVariableMap[integerVariable.getName()], + variableToColumnDecisionDiagramVariableMap[integerVariable.getName()], 1); + } + variableToIdentityDecisionDiagramMap[integerVariable.getName()] = identity; + } + } + } + void createDecisionDiagramVariables(storm::ir::Program const& program) { for (uint_fast64_t i = 0; i < program.getNumberOfModules(); ++i) { storm::ir::Module const& module = program.getModule(i); diff --git a/src/ir/Module.cpp b/src/ir/Module.cpp index f76a53b88..d800256c8 100644 --- a/src/ir/Module.cpp +++ b/src/ir/Module.cpp @@ -7,6 +7,8 @@ #include "Module.h" +#include "src/exceptions/InvalidArgumentException.h" + #include namespace storm { @@ -56,6 +58,26 @@ uint_fast64_t Module::getNumberOfCommands() const { return this->commands.size(); } +// Return the index of the variable if it exists and throw exception otherwise. +uint_fast64_t Module::getBooleanVariableIndex(std::string variableName) const { + auto it = booleanVariablesToIndexMap.find(variableName); + if (it != booleanVariablesToIndexMap.end()) { + return it->second; + } + throw storm::exceptions::InvalidArgumentException() << "Cannot retrieve index of unknown " + << "boolean variable " << variableName << "."; +} + +// Return the index of the variable if it exists and throw exception otherwise. +uint_fast64_t Module::getIntegerVariableIndex(std::string variableName) const { + auto it = integerVariablesToIndexMap.find(variableName); + if (it != integerVariablesToIndexMap.end()) { + return it->second; + } + throw storm::exceptions::InvalidArgumentException() << "Cannot retrieve index of unknown " + << "variable " << variableName << "."; +} + // Return the requested command. storm::ir::Command const& Module::getCommand(uint_fast64_t index) const { return this->commands[index]; diff --git a/src/ir/Module.h b/src/ir/Module.h index 0ed2ed010..2c362da41 100644 --- a/src/ir/Module.h +++ b/src/ir/Module.h @@ -73,6 +73,20 @@ public: */ uint_fast64_t getNumberOfCommands() const; + /*! + * Retrieves the index of the boolean variable with the given name. + * @param variableName the name of the variable whose index to retrieve. + * @returns the index of the boolean variable with the given name. + */ + uint_fast64_t getBooleanVariableIndex(std::string variableName) const; + + /*! + * Retrieves the index of the integer variable with the given name. + * @param variableName the name of the variable whose index to retrieve. + * @returns the index of the integer variable with the given name. + */ + uint_fast64_t getIntegerVariableIndex(std::string variableName) const; + /*! * Retrieves a reference to the command with the given index. * @returns a reference to the command with the given index. diff --git a/src/ir/expressions/BaseExpression.h b/src/ir/expressions/BaseExpression.h index 8324f372a..fae7f270a 100644 --- a/src/ir/expressions/BaseExpression.h +++ b/src/ir/expressions/BaseExpression.h @@ -11,6 +11,7 @@ #include "src/exceptions/ExpressionEvaluationException.h" #include "src/exceptions/NotImplementedException.h" +#include "ExpressionVisitor.h" #include "src/utility/CuddUtility.h" #include @@ -66,7 +67,9 @@ public: << this->getTypeName() << " because evaluation implementation is missing."; } - virtual ADD* toAdd() const = 0; + virtual void accept(ExpressionVisitor* visitor) { + visitor->visit(this); + } virtual std::string toString() const = 0; diff --git a/src/ir/expressions/BinaryBooleanFunctionExpression.h b/src/ir/expressions/BinaryBooleanFunctionExpression.h index 9c54362c2..0cf866f74 100644 --- a/src/ir/expressions/BinaryBooleanFunctionExpression.h +++ b/src/ir/expressions/BinaryBooleanFunctionExpression.h @@ -8,7 +8,7 @@ #ifndef STORM_IR_EXPRESSIONS_BINARYBOOLEANFUNCTIONEXPRESSION_H_ #define STORM_IR_EXPRESSIONS_BINARYBOOLEANFUNCTIONEXPRESSION_H_ -#include "src/ir/expressions/BaseExpression.h" +#include "src/ir/expressions/BinaryExpression.h" #include "src/utility/CuddUtility.h" @@ -21,11 +21,11 @@ namespace ir { namespace expressions { -class BinaryBooleanFunctionExpression : public BaseExpression { +class BinaryBooleanFunctionExpression : public BinaryExpression { public: enum FunctionType {AND, OR}; - BinaryBooleanFunctionExpression(std::shared_ptr left, std::shared_ptr right, FunctionType functionType) : BaseExpression(bool_), left(left), right(right), functionType(functionType) { + BinaryBooleanFunctionExpression(std::shared_ptr left, std::shared_ptr right, FunctionType functionType) : BinaryExpression(bool_, left, right), functionType(functionType) { } @@ -34,8 +34,8 @@ public: } virtual bool getValueAsBool(std::pair, std::vector> const* variableValues) const { - bool resultLeft = left->getValueAsBool(variableValues); - bool resultRight = right->getValueAsBool(variableValues); + bool resultLeft = this->getLeft()->getValueAsBool(variableValues); + bool resultRight = this->getRight()->getValueAsBool(variableValues); switch(functionType) { case AND: return resultLeft & resultRight; break; case OR: return resultLeft | resultRight; break; @@ -44,33 +44,27 @@ public: } } - virtual ADD* toAdd() const { - ADD* leftAdd = left->toAdd(); - ADD* rightAdd = right->toAdd(); + FunctionType getFunctionType() const { + return functionType; + } - switch(functionType) { - case AND: return new ADD(leftAdd->Times(*rightAdd)); break; - case OR: return new ADD(leftAdd->Plus(*rightAdd)); break; - default: throw storm::exceptions::ExpressionEvaluationException() << "Cannot evaluate expression: " - << "Unknown boolean binary operator: '" << functionType << "'."; - } + virtual void accept(ExpressionVisitor* visitor) { + visitor->visit(this); } virtual std::string toString() const { std::stringstream result; - result << left->toString(); + result << this->getLeft()->toString(); switch (functionType) { case AND: result << " & "; break; case OR: result << " | "; break; } - result << right->toString(); + result << this->getRight()->toString(); return result.str(); } private: - std::shared_ptr left; - std::shared_ptr right; FunctionType functionType; }; diff --git a/src/ir/expressions/BinaryExpression.h b/src/ir/expressions/BinaryExpression.h new file mode 100644 index 000000000..03b933c48 --- /dev/null +++ b/src/ir/expressions/BinaryExpression.h @@ -0,0 +1,44 @@ +/* + * BinaryExpression.h + * + * Created on: 27.01.2013 + * Author: Christian Dehnert + */ + +#ifndef STORM_IR_EXPRESSIONS_BINARYEXPRESSION_H_ +#define STORM_IR_EXPRESSIONS_BINARYEXPRESSION_H_ + +#include "BaseExpression.h" + +namespace storm { + +namespace ir { + +namespace expressions { + +class BinaryExpression : public BaseExpression { +public: + BinaryExpression(ReturnType type, std::shared_ptr left, std::shared_ptr right) : BaseExpression(type), left(left), right(right) { + + } + + std::shared_ptr const& getLeft() const { + return left; + } + + std::shared_ptr const& getRight() const { + return right; + } + +private: + std::shared_ptr left; + std::shared_ptr right; +}; + +} // namespace expressions + +} // namespace ir + +} // namespace storm + +#endif /* STORM_IR_EXPRESSIONS_BINARYEXPRESSION_H_ */ diff --git a/src/ir/expressions/BinaryNumericalFunctionExpression.h b/src/ir/expressions/BinaryNumericalFunctionExpression.h index b2dd12e9c..539cb3114 100644 --- a/src/ir/expressions/BinaryNumericalFunctionExpression.h +++ b/src/ir/expressions/BinaryNumericalFunctionExpression.h @@ -18,11 +18,11 @@ namespace ir { namespace expressions { -class BinaryNumericalFunctionExpression : public BaseExpression { +class BinaryNumericalFunctionExpression : public BinaryExpression { public: enum FunctionType {PLUS, MINUS, TIMES, DIVIDE}; - BinaryNumericalFunctionExpression(ReturnType type, std::shared_ptr left, std::shared_ptr right, FunctionType functionType) : BaseExpression(type), left(left), right(right), functionType(functionType) { + BinaryNumericalFunctionExpression(ReturnType type, std::shared_ptr left, std::shared_ptr right, FunctionType functionType) : BinaryExpression(type, left, right), functionType(functionType) { } @@ -30,13 +30,17 @@ public: } + FunctionType getFunctionType() const { + return functionType; + } + virtual int_fast64_t getValueAsInt(std::pair, std::vector> const* variableValues) const { if (this->getType() != int_) { BaseExpression::getValueAsInt(variableValues); } - int_fast64_t resultLeft = left->getValueAsInt(variableValues); - int_fast64_t resultRight = right->getValueAsInt(variableValues); + int_fast64_t resultLeft = this->getLeft()->getValueAsInt(variableValues); + int_fast64_t resultRight = this->getRight()->getValueAsInt(variableValues); switch(functionType) { case PLUS: return resultLeft + resultRight; break; case MINUS: return resultLeft - resultRight; break; @@ -52,8 +56,8 @@ public: BaseExpression::getValueAsDouble(variableValues); } - double resultLeft = left->getValueAsDouble(variableValues); - double resultRight = right->getValueAsDouble(variableValues); + double resultLeft = this->getLeft()->getValueAsDouble(variableValues); + double resultRight = this->getRight()->getValueAsDouble(variableValues); switch(functionType) { case PLUS: return resultLeft + resultRight; break; case MINUS: return resultLeft - resultRight; break; @@ -64,35 +68,23 @@ public: } } - virtual ADD* toAdd() const { - ADD* leftAdd = left->toAdd(); - ADD* rightAdd = right->toAdd(); - - switch(functionType) { - case PLUS: return new ADD(leftAdd->Plus(*rightAdd)); break; - case MINUS: return new ADD(leftAdd->Minus(*rightAdd)); break; - case TIMES: return new ADD(leftAdd->Times(*rightAdd)); break; - case DIVIDE: return new ADD(leftAdd->Divide(*rightAdd)); break; - default: throw storm::exceptions::ExpressionEvaluationException() << "Cannot evaluate expression: " - << "Unknown boolean binary operator: '" << functionType << "'."; - } + virtual void accept(ExpressionVisitor* visitor) { + visitor->visit(this); } virtual std::string toString() const { - std::string result = left->toString(); + std::string result = this->getLeft()->toString(); switch (functionType) { case PLUS: result += " + "; break; case MINUS: result += " - "; break; case TIMES: result += " * "; break; case DIVIDE: result += " / "; break; } - result += right->toString(); + result += this->getRight()->toString(); return result; } private: - std::shared_ptr left; - std::shared_ptr right; FunctionType functionType; }; diff --git a/src/ir/expressions/BinaryRelationExpression.h b/src/ir/expressions/BinaryRelationExpression.h index 02907dfc5..d71072b13 100644 --- a/src/ir/expressions/BinaryRelationExpression.h +++ b/src/ir/expressions/BinaryRelationExpression.h @@ -8,7 +8,7 @@ #ifndef BINARYRELATIONEXPRESSION_H_ #define BINARYRELATIONEXPRESSION_H_ -#include "src/ir/expressions/BaseExpression.h" +#include "src/ir/expressions/BinaryExpression.h" #include "src/utility/CuddUtility.h" @@ -18,11 +18,11 @@ namespace ir { namespace expressions { -class BinaryRelationExpression : public BaseExpression { +class BinaryRelationExpression : public BinaryExpression { public: enum RelationType {EQUAL, NOT_EQUAL, LESS, LESS_OR_EQUAL, GREATER, GREATER_OR_EQUAL}; - BinaryRelationExpression(std::shared_ptr left, std::shared_ptr right, RelationType relationType) : BaseExpression(bool_), left(left), right(right), relationType(relationType) { + BinaryRelationExpression(std::shared_ptr left, std::shared_ptr right, RelationType relationType) : BinaryExpression(bool_, left, right), relationType(relationType) { } @@ -31,8 +31,8 @@ public: } virtual bool getValueAsBool(std::pair, std::vector> const* variableValues) const { - int_fast64_t resultLeft = left->getValueAsInt(variableValues); - int_fast64_t resultRight = right->getValueAsInt(variableValues); + int_fast64_t resultLeft = this->getLeft()->getValueAsInt(variableValues); + int_fast64_t resultRight = this->getRight()->getValueAsInt(variableValues); switch(relationType) { case EQUAL: return resultLeft == resultRight; break; case NOT_EQUAL: return resultLeft != resultRight; break; @@ -45,24 +45,16 @@ public: } } - virtual ADD* toAdd() const { - ADD* leftAdd = left->toAdd(); - ADD* rightAdd = right->toAdd(); + RelationType getRelationType() const { + return relationType; + } - switch(relationType) { - case EQUAL: return new ADD(leftAdd->Equals(*rightAdd)); break; - case NOT_EQUAL: return new ADD(leftAdd->NotEquals(*rightAdd)); break; - case LESS: return new ADD(leftAdd->LessThan(*rightAdd)); break; - case LESS_OR_EQUAL: return new ADD(leftAdd->LessThanOrEqual(*rightAdd)); break; - case GREATER: return new ADD(leftAdd->GreaterThan(*rightAdd)); break; - case GREATER_OR_EQUAL: return new ADD(leftAdd->GreaterThanOrEqual(*rightAdd)); break; - default: throw storm::exceptions::ExpressionEvaluationException() << "Cannot evaluate expression: " - << "Unknown boolean binary operator: '" << relationType << "'."; - } + virtual void accept(ExpressionVisitor* visitor) { + visitor->visit(this); } virtual std::string toString() const { - std::string result = left->toString(); + std::string result = this->getLeft()->toString(); switch (relationType) { case EQUAL: result += " = "; break; case NOT_EQUAL: result += " != "; break; @@ -71,14 +63,12 @@ public: case GREATER: result += " > "; break; case GREATER_OR_EQUAL: result += " >= "; break; } - result += right->toString(); + result += this->getRight()->toString(); return result; } private: - std::shared_ptr left; - std::shared_ptr right; RelationType relationType; }; diff --git a/src/ir/expressions/BooleanConstantExpression.h b/src/ir/expressions/BooleanConstantExpression.h index b16cd5ad7..9181a5449 100644 --- a/src/ir/expressions/BooleanConstantExpression.h +++ b/src/ir/expressions/BooleanConstantExpression.h @@ -40,14 +40,8 @@ public: } } - virtual ADD* toAdd() const { - if (!defined) { - throw storm::exceptions::ExpressionEvaluationException() << "Cannot evaluate expression: " - << "Boolean constant '" << this->getConstantName() << "' is undefined."; - } - - storm::utility::CuddUtility* cuddUtility = storm::utility::cuddUtilityInstance(); - return new ADD(*cuddUtility->getConstant(value ? 1 : 0)); + virtual void accept(ExpressionVisitor* visitor) { + visitor->visit(this); } virtual std::string toString() const { diff --git a/src/ir/expressions/BooleanLiteral.h b/src/ir/expressions/BooleanLiteral.h index 6b92e4b7b..04f40285f 100644 --- a/src/ir/expressions/BooleanLiteral.h +++ b/src/ir/expressions/BooleanLiteral.h @@ -32,9 +32,8 @@ public: return value; } - virtual ADD* toAdd() const { - storm::utility::CuddUtility* cuddUtility = storm::utility::cuddUtilityInstance(); - return new ADD(*cuddUtility->getConstant(value ? 1 : 0)); + virtual void accept(ExpressionVisitor* visitor) { + visitor->visit(this); } virtual std::string toString() const { diff --git a/src/ir/expressions/DoubleConstantExpression.h b/src/ir/expressions/DoubleConstantExpression.h index 9e3471f19..a027afa17 100644 --- a/src/ir/expressions/DoubleConstantExpression.h +++ b/src/ir/expressions/DoubleConstantExpression.h @@ -35,14 +35,8 @@ public: } } - virtual ADD* toAdd() const { - if (!defined) { - throw storm::exceptions::ExpressionEvaluationException() << "Cannot evaluate expression: " - << "Double constant '" << this->getConstantName() << "' is undefined."; - } - - storm::utility::CuddUtility* cuddUtility = storm::utility::cuddUtilityInstance(); - return new ADD(*cuddUtility->getConstant(value)); + virtual void accept(ExpressionVisitor* visitor) { + visitor->visit(this); } virtual std::string toString() const { diff --git a/src/ir/expressions/DoubleLiteral.h b/src/ir/expressions/DoubleLiteral.h index 1f447f7fb..72c464100 100644 --- a/src/ir/expressions/DoubleLiteral.h +++ b/src/ir/expressions/DoubleLiteral.h @@ -34,9 +34,8 @@ public: return value; } - virtual ADD* toAdd() const { - storm::utility::CuddUtility* cuddUtility = storm::utility::cuddUtilityInstance(); - return new ADD(*cuddUtility->getConstant(value)); + virtual void accept(ExpressionVisitor* visitor) { + visitor->visit(this); } virtual std::string toString() const { diff --git a/src/ir/expressions/ExpressionVisitor.h b/src/ir/expressions/ExpressionVisitor.h index 7fc05b79b..6a8d9d813 100644 --- a/src/ir/expressions/ExpressionVisitor.h +++ b/src/ir/expressions/ExpressionVisitor.h @@ -8,30 +8,41 @@ #ifndef STORM_IR_EXPRESSIONS_EXPRESSIONVISITOR_H_ #define STORM_IR_EXPRESSIONS_EXPRESSIONVISITOR_H_ -#include "Expressions.h" - namespace storm { namespace ir { namespace expressions { +class BaseExpression; +class BinaryBooleanFunctionExpression; +class BinaryNumericalFunctionExpression; +class BinaryRelationExpression; +class BooleanConstantExpression; +class BooleanLiteral; +class DoubleConstantExpression; +class DoubleLiteral; +class IntegerConstantExpression; +class IntegerLiteral; +class UnaryBooleanFunctionExpression; +class UnaryNumericalFunctionExpression; +class VariableExpression; + class ExpressionVisitor { public: - virtual void visit(BaseExpression expression) = 0; - virtual void visit(BinaryBooleanFunctionExpression expression) = 0; - virtual void visit(BinaryNumericalFunctionExpression expression) = 0; - virtual void visit(BinaryRelationExpression expression) = 0; - virtual void visit(BooleanConstantExpression expression) = 0; - virtual void visit(BooleanLiteral expression) = 0; - virtual void visit(ConstantExpression expression) = 0; - virtual void visit(DoubleConstantExpression expression) = 0; - virtual void visit(DoubleLiteral expression) = 0; - virtual void visit(IntegerConstantExpression expression) = 0; - virtual void visit(IntegerLiteral expression) = 0; - virtual void visit(UnaryBooleanFunctionExpression expression) = 0; - virtual void visit(UnaryNumericalFunctionExpression expression) = 0; - virtual void visit(VariableExpression expression) = 0; + virtual void visit(BaseExpression* expression) = 0; + virtual void visit(BinaryBooleanFunctionExpression* expression) = 0; + virtual void visit(BinaryNumericalFunctionExpression* expression) = 0; + virtual void visit(BinaryRelationExpression* expression) = 0; + virtual void visit(BooleanConstantExpression* expression) = 0; + virtual void visit(BooleanLiteral* expression) = 0; + virtual void visit(DoubleConstantExpression* expression) = 0; + virtual void visit(DoubleLiteral* expression) = 0; + virtual void visit(IntegerConstantExpression* expression) = 0; + virtual void visit(IntegerLiteral* expression) = 0; + virtual void visit(UnaryBooleanFunctionExpression* expression) = 0; + virtual void visit(UnaryNumericalFunctionExpression* expression) = 0; + virtual void visit(VariableExpression* expression) = 0; }; } // namespace expressions diff --git a/src/ir/expressions/IntegerConstantExpression.h b/src/ir/expressions/IntegerConstantExpression.h index 615aa2adf..8c58260f6 100644 --- a/src/ir/expressions/IntegerConstantExpression.h +++ b/src/ir/expressions/IntegerConstantExpression.h @@ -35,14 +35,8 @@ public: } } - virtual ADD* toAdd() const { - if (!defined) { - throw storm::exceptions::ExpressionEvaluationException() << "Cannot evaluate expression: " - << "Integer constant '" << this->getConstantName() << "' is undefined."; - } - - storm::utility::CuddUtility* cuddUtility = storm::utility::cuddUtilityInstance(); - return new ADD(*cuddUtility->getConstant(value)); + virtual void accept(ExpressionVisitor* visitor) { + visitor->visit(this); } virtual std::string toString() const { diff --git a/src/ir/expressions/IntegerLiteral.h b/src/ir/expressions/IntegerLiteral.h index 6399813c3..a5130bd5e 100644 --- a/src/ir/expressions/IntegerLiteral.h +++ b/src/ir/expressions/IntegerLiteral.h @@ -32,9 +32,8 @@ public: return value; } - virtual ADD* toAdd() const { - storm::utility::CuddUtility* cuddUtility = storm::utility::cuddUtilityInstance(); - return new ADD(*cuddUtility->getConstant(value)); + virtual void accept(ExpressionVisitor* visitor) { + visitor->visit(this); } virtual std::string toString() const { diff --git a/src/ir/expressions/UnaryBooleanFunctionExpression.h b/src/ir/expressions/UnaryBooleanFunctionExpression.h index 2322a323b..26b448548 100644 --- a/src/ir/expressions/UnaryBooleanFunctionExpression.h +++ b/src/ir/expressions/UnaryBooleanFunctionExpression.h @@ -8,7 +8,7 @@ #ifndef UNARYBOOLEANFUNCTIONEXPRESSION_H_ #define UNARYBOOLEANFUNCTIONEXPRESSION_H_ -#include "src/ir/expressions/BaseExpression.h" +#include "src/ir/expressions/UnaryExpression.h" namespace storm { @@ -16,11 +16,11 @@ namespace ir { namespace expressions { -class UnaryBooleanFunctionExpression : public BaseExpression { +class UnaryBooleanFunctionExpression : public UnaryExpression { public: enum FunctionType {NOT}; - UnaryBooleanFunctionExpression(std::shared_ptr child, FunctionType functionType) : BaseExpression(bool_), child(child), functionType(functionType) { + UnaryBooleanFunctionExpression(std::shared_ptr child, FunctionType functionType) : UnaryExpression(bool_, child), functionType(functionType) { } @@ -28,8 +28,12 @@ public: } + FunctionType getFunctionType() const { + return functionType; + } + virtual bool getValueAsBool(std::pair, std::vector> const* variableValues) const { - bool resultChild = child->getValueAsBool(variableValues); + bool resultChild = this->getChild()->getValueAsBool(variableValues); switch(functionType) { case NOT: return !resultChild; break; default: throw storm::exceptions::ExpressionEvaluationException() << "Cannot evaluate expression: " @@ -37,9 +41,8 @@ public: } } - virtual ADD* toAdd() const { - ADD* childResult = child->toAdd(); - return new ADD(~(*childResult)); + virtual void accept(ExpressionVisitor* visitor) { + visitor->visit(this); } virtual std::string toString() const { @@ -47,13 +50,12 @@ public: switch (functionType) { case NOT: result += "!"; break; } - result += child->toString(); + result += this->getChild()->toString(); return result; } private: - std::shared_ptr child; FunctionType functionType; }; diff --git a/src/ir/expressions/UnaryExpression.h b/src/ir/expressions/UnaryExpression.h new file mode 100644 index 000000000..177b5593e --- /dev/null +++ b/src/ir/expressions/UnaryExpression.h @@ -0,0 +1,39 @@ +/* + * UnaryExpression.h + * + * Created on: 27.01.2013 + * Author: Christian Dehnert + */ + +#ifndef STORM_IR_EXPRESSIONS_UNARYEXPRESSION_H_ +#define STORM_IR_EXPRESSIONS_UNARYEXPRESSION_H_ + +#include "BaseExpression.h" + +namespace storm { + +namespace ir { + +namespace expressions { + +class UnaryExpression : public BaseExpression { +public: + UnaryExpression(ReturnType type, std::shared_ptr child) : BaseExpression(type), child(child) { + + } + + std::shared_ptr const& getChild() const { + return child; + } + +private: + std::shared_ptr child; +}; + +} // namespace expressions + +} // namespace ir + +} // namespace storm + +#endif /* STORM_IR_EXPRESSIONS_UNARYEXPRESSION_H_ */ diff --git a/src/ir/expressions/UnaryNumericalFunctionExpression.h b/src/ir/expressions/UnaryNumericalFunctionExpression.h index 3ab8fe355..2a1f02125 100644 --- a/src/ir/expressions/UnaryNumericalFunctionExpression.h +++ b/src/ir/expressions/UnaryNumericalFunctionExpression.h @@ -8,7 +8,7 @@ #ifndef UNARYFUNCTIONEXPRESSION_H_ #define UNARYFUNCTIONEXPRESSION_H_ -#include "src/ir/expressions/BaseExpression.h" +#include "src/ir/expressions/UnaryExpression.h" namespace storm { @@ -16,11 +16,11 @@ namespace ir { namespace expressions { -class UnaryNumericalFunctionExpression : public BaseExpression { +class UnaryNumericalFunctionExpression : public UnaryExpression { public: enum FunctionType {MINUS}; - UnaryNumericalFunctionExpression(ReturnType type, std::shared_ptr child, FunctionType functionType) : BaseExpression(type), child(child), functionType(functionType) { + UnaryNumericalFunctionExpression(ReturnType type, std::shared_ptr child, FunctionType functionType) : UnaryExpression(type, child), functionType(functionType) { } @@ -33,7 +33,7 @@ public: BaseExpression::getValueAsInt(variableValues); } - int_fast64_t resultChild = child->getValueAsInt(variableValues); + int_fast64_t resultChild = this->getChild()->getValueAsInt(variableValues); switch(functionType) { case MINUS: return -resultChild; break; default: throw storm::exceptions::ExpressionEvaluationException() << "Cannot evaluate expression: " @@ -46,7 +46,7 @@ public: BaseExpression::getValueAsDouble(variableValues); } - double resultChild = child->getValueAsDouble(variableValues); + double resultChild = this->getChild()->getValueAsDouble(variableValues); switch(functionType) { case MINUS: return -resultChild; break; default: throw storm::exceptions::ExpressionEvaluationException() << "Cannot evaluate expression: " @@ -54,11 +54,12 @@ public: } } - virtual ADD* toAdd() const { - ADD* childResult = child->toAdd(); - storm::utility::CuddUtility* cuddUtility = storm::utility::cuddUtilityInstance(); - ADD* result = cuddUtility->getConstant(0); - return new ADD(result->Minus(*childResult)); + FunctionType getFunctionType() const { + return functionType; + } + + virtual void accept(ExpressionVisitor* visitor) { + visitor->visit(this); } virtual std::string toString() const { @@ -66,13 +67,12 @@ public: switch (functionType) { case MINUS: result += "-"; break; } - result += child->toString(); + result += this->getChild()->toString(); return result; } private: - std::shared_ptr child; FunctionType functionType; }; diff --git a/src/ir/expressions/VariableExpression.h b/src/ir/expressions/VariableExpression.h index 87374f080..1b9dfd46b 100644 --- a/src/ir/expressions/VariableExpression.h +++ b/src/ir/expressions/VariableExpression.h @@ -33,6 +33,10 @@ public: } + virtual void accept(ExpressionVisitor* visitor) { + visitor->visit(this); + } + virtual std::string toString() const { return variableName; } @@ -72,20 +76,16 @@ public: << " variable '" << variableName << "' of type double."; } - virtual ADD* toAdd() const { - storm::utility::CuddUtility* cuddUtility = storm::utility::cuddUtilityInstance(); - - return nullptr; + std::string const& getVariableName() const { + return variableName; + } - if (this->getType() == bool_) { - ADD* result = cuddUtility->getConstant(0); - //cuddUtility->addValueForEncodingOfConstant(result, 1, ) - } else { - int64_t low = lowerBound->getValueAsInt(nullptr); - int64_t high = upperBound->getValueAsInt(nullptr); - } + std::shared_ptr const& getLowerBound() const { + return lowerBound; + } - return new ADD(); + std::shared_ptr const& getUpperBound() const { + return upperBound; } private: diff --git a/src/utility/CuddUtility.cpp b/src/utility/CuddUtility.cpp index c9ab7aba6..ec9726b18 100644 --- a/src/utility/CuddUtility.cpp +++ b/src/utility/CuddUtility.cpp @@ -32,7 +32,15 @@ ADD* CuddUtility::getAddVariable(int index) const { return new ADD(manager.addVar(index)); } -ADD* CuddUtility::getConstantEncoding(uint_fast64_t constant, std::vector& variables) const { +ADD* CuddUtility::getOne() const { + return new ADD(manager.addOne()); +} + +ADD* CuddUtility::getZero() const { + return new ADD(manager.addZero()); +} + +ADD* CuddUtility::getConstantEncoding(uint_fast64_t constant, std::vector const& variables) const { if ((constant >> variables.size()) != 0) { LOG4CPLUS_ERROR(logger, "Cannot create encoding for constant " << constant << " with " << variables.size() << " variables."); @@ -62,18 +70,18 @@ ADD* CuddUtility::getConstantEncoding(uint_fast64_t constant, std::vector& return result; } -ADD* CuddUtility::addValueForEncodingOfConstant(ADD* add, uint_fast64_t constant, std::vector& variables, double value) const { - if ((constant >> variables.size()) != 0) { - LOG4CPLUS_ERROR(logger, "Cannot create encoding for constant " << constant << " with " +void CuddUtility::setValueAtIndex(ADD* add, uint_fast64_t index, std::vector const& variables, double value) const { + if ((index >> variables.size()) != 0) { + LOG4CPLUS_ERROR(logger, "Cannot create encoding for index " << index << " with " << variables.size() << " variables."); throw storm::exceptions::InvalidArgumentException() << "Cannot create encoding" - << " for constant " << constant << " with " << variables.size() + << " for index " << index << " with " << variables.size() << " variables."; } // Determine whether the new ADD will be rooted by the first variable or its complement. ADD initialNode; - if ((constant & (1 << (variables.size() - 1))) != 0) { + if ((index & (1 << (variables.size() - 1))) != 0) { initialNode = *variables[0]; } else { initialNode = ~(*variables[0]); @@ -82,21 +90,83 @@ ADD* CuddUtility::addValueForEncodingOfConstant(ADD* add, uint_fast64_t constant // Add (i.e. multiply) the other variables as well according to whether their bit is set or not. for (uint_fast64_t i = 1; i < variables.size(); ++i) { - if ((constant & (1 << (variables.size() - i - 1))) != 0) { + if ((index & (1 << (variables.size() - i - 1))) != 0) { *encoding *= *variables[i]; } else { *encoding *= ~(*variables[i]); } } - ADD* result = new ADD(add->Ite(manager.constant(value), *add)); - return result; + *add = encoding->Ite(manager.constant(value), *add); +} + +void CuddUtility::setValueAtIndices(ADD* add, uint_fast64_t rowIndex, uint_fast64_t columnIndex, std::vector const& rowVariables, std::vector const& columnVariables, double value) const { + if ((rowIndex >> rowVariables.size()) != 0) { + LOG4CPLUS_ERROR(logger, "Cannot create encoding for index " << rowIndex << " with " + << rowVariables.size() << " variables."); + throw storm::exceptions::InvalidArgumentException() << "Cannot create encoding" + << " for index " << rowIndex << " with " << rowVariables.size() + << " variables."; + } + if ((columnIndex >> columnVariables.size()) != 0) { + LOG4CPLUS_ERROR(logger, "Cannot create encoding for index " << columnIndex << " with " + << columnVariables.size() << " variables."); + throw storm::exceptions::InvalidArgumentException() << "Cannot create encoding" + << " for index " << columnIndex << " with " << columnVariables.size() + << " variables."; + } + if (rowVariables.size() != columnVariables.size()) { + LOG4CPLUS_ERROR(logger, "Number of variables for indices encodings does not match."); + throw storm::exceptions::InvalidArgumentException() + << "Number of variables for indices encodings does not match."; + } + + ADD initialNode; + if ((rowIndex & (1 << (rowVariables.size() - 1))) != 0) { + initialNode = *rowVariables[0]; + } else { + initialNode = ~(*rowVariables[0]); + } + ADD* encoding = new ADD(initialNode); + if ((columnIndex & (1 << (rowVariables.size() - 1))) != 0) { + *encoding *= *columnVariables[0]; + } else { + *encoding *= ~(*columnVariables[0]); + } + + for (uint_fast64_t i = 1; i < rowVariables.size(); ++i) { + if ((rowIndex & (1 << (rowVariables.size() - i - 1))) != 0) { + *encoding *= *rowVariables[i]; + } else { + *encoding *= ~(*rowVariables[i]); + } + if ((columnIndex & (1 << (columnVariables.size() - i - 1))) != 0) { + *encoding *= *columnVariables[i]; + } else { + *encoding *= ~(*columnVariables[i]); + } + } + + *add = encoding->Ite(manager.constant(value), *add); } + ADD* CuddUtility::getConstant(double value) const { return new ADD(manager.constant(value)); } +void CuddUtility::permuteVariables(ADD* add, std::vector fromVariables, std::vector toVariables, uint_fast64_t totalNumberOfVariables) const { + std::vector permutation; + permutation.resize(totalNumberOfVariables); + for (uint_fast64_t i = 0; i < totalNumberOfVariables; ++i) { + permutation[i] = i; + } + for (uint_fast64_t i = 0; i < fromVariables.size(); ++i) { + permutation[fromVariables[i]->NodeReadIndex()] = toVariables[i]->NodeReadIndex(); + } + add->Permute(&permutation[0]); +} + void CuddUtility::dumpDotToFile(ADD* add, std::string filename) const { std::vector nodes; nodes.push_back(*add); @@ -112,7 +182,7 @@ Cudd const& CuddUtility::getManager() const { } CuddUtility* cuddUtilityInstance() { - if (CuddUtility::instance != nullptr) { + if (CuddUtility::instance == nullptr) { CuddUtility::instance = new CuddUtility(); } return CuddUtility::instance; diff --git a/src/utility/CuddUtility.h b/src/utility/CuddUtility.h index ae2b56bb6..7eb679c82 100644 --- a/src/utility/CuddUtility.h +++ b/src/utility/CuddUtility.h @@ -25,12 +25,18 @@ public: ADD* getNewAddVariable(); ADD* getAddVariable(int index) const; - ADD* getConstantEncoding(uint_fast64_t constant, std::vector& variables) const; + ADD* getOne() const; + ADD* getZero() const; - ADD* addValueForEncodingOfConstant(ADD* add, uint_fast64_t constant, std::vector& variables, double value) const; + ADD* getConstantEncoding(uint_fast64_t constant, std::vector const& variables) const; + + void setValueAtIndex(ADD* add, uint_fast64_t index, std::vector const& variables, double value) const; + void setValueAtIndices(ADD* add, uint_fast64_t rowIndex, uint_fast64_t columnIndex, std::vector const& rowVariables, std::vector const& columnVariables, double value) const; ADD* getConstant(double value) const; + void permuteVariables(ADD* add, std::vector fromVariables, std::vector toVariables, uint_fast64_t totalNumberOfVariables) const; + void dumpDotToFile(ADD* add, std::string filename) const; Cudd const& getManager() const; @@ -38,7 +44,7 @@ public: friend CuddUtility* cuddUtilityInstance(); private: - CuddUtility() : manager(0, 0), allDecisionDiagramVariables() { + CuddUtility() : manager(), allDecisionDiagramVariables() { }