diff --git a/src/storm-pars/analysis/AssumptionChecker.cpp b/src/storm-pars/analysis/AssumptionChecker.cpp index 400020aff..f63ac68ac 100644 --- a/src/storm-pars/analysis/AssumptionChecker.cpp +++ b/src/storm-pars/analysis/AssumptionChecker.cpp @@ -134,7 +134,6 @@ namespace storm { auto row1 = matrix.getRow(std::stoi(assumption->getFirstOperand()->asVariableExpression().getVariableName())); auto row2 = matrix.getRow(std::stoi(assumption->getSecondOperand()->asVariableExpression().getVariableName())); - // Only implemented for two successors if (row1.getNumberOfEntries() == 2 && row2.getNumberOfEntries() == 2) { auto state1succ1 = row1.begin(); auto state1succ2 = (++row1.begin()); @@ -153,7 +152,7 @@ namespace storm { } } } else { - STORM_LOG_DEBUG("Validation only implemented for two successor states"); + result = validateAssumptionSMTSolver(lattice, assumption); } } @@ -171,6 +170,10 @@ namespace storm { typename storm::storage::SparseMatrix::iterator state1succ2, typename storm::storage::SparseMatrix::iterator state2succ1, typename storm::storage::SparseMatrix::iterator state2succ2) { + assert((state1succ1->getColumn() == state2succ1->getColumn() + && state1succ2->getColumn() == state2succ2->getColumn()) + || (state1succ1->getColumn() == state2succ2->getColumn() + && state1succ2->getColumn() == state2succ1->getColumn())); bool result = true; ValueType prob; @@ -201,11 +204,14 @@ namespace storm { template bool AssumptionChecker::validateAssumptionSMTSolver(storm::analysis::Lattice* lattice, - typename storm::storage::SparseMatrix::iterator state1succ1, - typename storm::storage::SparseMatrix::iterator state1succ2, - typename storm::storage::SparseMatrix::iterator state2succ1, - typename storm::storage::SparseMatrix::iterator state2succ2) { - + typename storm::storage::SparseMatrix::iterator state1succ1, + typename storm::storage::SparseMatrix::iterator state1succ2, + typename storm::storage::SparseMatrix::iterator state2succ1, + typename storm::storage::SparseMatrix::iterator state2succ2) { + assert((state1succ1->getColumn() == state2succ1->getColumn() + && state1succ2->getColumn() == state2succ2->getColumn()) + || (state1succ1->getColumn() == state2succ2->getColumn() + && state1succ2->getColumn() == state2succ1->getColumn())); std::shared_ptr smtSolverFactory = std::make_shared(); std::shared_ptr manager( new storm::expressions::ExpressionManager()); @@ -224,9 +230,9 @@ namespace storm { auto valueTypeToExpression = storm::expressions::RationalFunctionToExpression(manager); storm::expressions::Expression exprToCheck = (valueTypeToExpression.toExpression(state1succ1->getValue())*succ1 - + valueTypeToExpression.toExpression(state2succ1->getValue())*succ2 - >= valueTypeToExpression.toExpression(state1succ2->getValue())*succ1 - + valueTypeToExpression.toExpression(state1succ1->getValue())*succ2); + + valueTypeToExpression.toExpression(state1succ2->getValue())*succ2 + >= valueTypeToExpression.toExpression(state2succ1->getValue())*succ1 + + valueTypeToExpression.toExpression(state2succ2->getValue())*succ2); storm::expressions::Expression exprBounds = manager->boolean(true); auto variables = manager->getVariables(); @@ -242,6 +248,109 @@ namespace storm { return smtResult == storm::solver::SmtSolver::CheckResult::Sat; } + template + bool AssumptionChecker::validateAssumptionSMTSolver(storm::analysis::Lattice* lattice, std::shared_ptr assumption) { + std::shared_ptr smtSolverFactory = std::make_shared(); + std::shared_ptr manager(new storm::expressions::ExpressionManager()); + bool result = true; + auto row1 = matrix.getRow(std::stoi(assumption->getFirstOperand()->asVariableExpression().getVariableName())); + auto row2 = matrix.getRow(std::stoi(assumption->getSecondOperand()->asVariableExpression().getVariableName())); + + if (row1.getNumberOfEntries() <= row2.getNumberOfEntries()) { + for (auto itr1 = row1.begin(); result && itr1 != row1.end(); ++itr1) { + bool found = false; + for (auto itr2 = row2.begin(); !found && itr2 != row2.end(); ++itr2) { + found = itr1->getColumn() == itr2->getColumn(); + } + + if (!found) { + result = false; + } + } + } else { + for (auto itr1 = row2.begin(); result && itr1 != row2.end(); ++itr1) { + bool found = false; + for (auto itr2 = row1.begin(); !found && itr2 != row1.end(); ++itr2) { + found = itr1->getColumn() == itr2->getColumn(); + } + + if (!found) { + result = false; + } + } + } + + if (result) { + storm::solver::Z3SmtSolver s(*manager); + if (row1.getNumberOfEntries() >= row2.getNumberOfEntries()) { + for (auto itr1 = row1.begin(); itr1 != row1.end(); ++itr1) { + manager->declareRationalVariable(std::to_string(itr1->getColumn())); + } + } else { + for (auto itr1 = row2.begin(); itr1 != row2.end(); ++itr1) { + manager->declareRationalVariable(std::to_string(itr1->getColumn())); + } + } + + storm::expressions::Expression exprGiven = manager->boolean(true); + + for (auto itr1 = row1.begin(); result && itr1 != row1.end(); ++itr1) { + for (auto itr2 = (itr1 + 1); result && itr2 != row1.end(); ++itr2) { + auto comp = lattice->compare(itr1->getColumn(), itr2->getColumn()); + if (comp == storm::analysis::Lattice::ABOVE) { + exprGiven = exprGiven && (manager->getVariable(std::to_string(itr1->getColumn())) >= manager->getVariable(std::to_string(itr2->getColumn()))); + } else if (comp == storm::analysis::Lattice::BELOW) { + exprGiven = exprGiven && (manager->getVariable(std::to_string(itr1->getColumn())) <= manager->getVariable(std::to_string(itr2->getColumn()))); + } else if (comp == storm::analysis::Lattice::SAME) { + exprGiven = exprGiven && (manager->getVariable(std::to_string(itr1->getColumn())) = manager->getVariable(std::to_string(itr2->getColumn()))); + } else { + result = false; + } + } + } + + auto valueTypeToExpression = storm::expressions::RationalFunctionToExpression(manager); + storm::expressions::Expression expr1 = manager->integer(0); + for (auto itr1 = row1.begin(); result && itr1 != row1.end(); ++itr1) { + expr1 = expr1 + (valueTypeToExpression.toExpression(itr1->getValue()) * manager->getVariable(std::to_string(itr1->getColumn()))); + } + + storm::expressions::Expression expr2 = manager->integer(0); + for (auto itr2 = row2.begin(); result && itr2 != row2.end(); ++itr2) { + expr2 = expr2 + (valueTypeToExpression.toExpression(itr2->getValue()) * manager->getVariable(std::to_string(itr2->getColumn()))); + } + storm::expressions::Expression exprToCheck = expr1 >= expr2; + + storm::expressions::Expression exprProb1 = manager->integer(0); + for (auto itr1 = row1.begin(); result && itr1 != row1.end(); ++itr1) { + exprProb1 = exprProb1 + (valueTypeToExpression.toExpression(itr1->getValue())); + } + + storm::expressions::Expression exprProb2 = manager->integer(0); + for (auto itr2 = row2.begin(); result && itr2 != row2.end(); ++itr2) { + exprProb2 = exprProb2 + (valueTypeToExpression.toExpression(itr2->getValue())); + } + + storm::expressions::Expression exprBounds = exprProb1 >= manager->rational(0) + && exprProb1 <= manager->rational(1) + && exprProb2 >= manager->rational(0) + && exprProb2 <= manager->rational(1); + + auto variables = manager->getVariables(); + for (auto var : variables) { + exprBounds = exprBounds && var >= 0 && var <= 1; + } + + s.add(exprGiven); + s.add(exprBounds); + assert(s.check() == storm::solver::SmtSolver::CheckResult::Sat); + s.add(exprToCheck); + auto smtRes = s.check(); + result = result && smtRes == storm::solver::SmtSolver::CheckResult::Sat; + } + return result; + } + template bool AssumptionChecker::validated(std::shared_ptr assumption) { return find(validatedAssumptions.begin(), validatedAssumptions.end(), assumption) != validatedAssumptions.end(); diff --git a/src/storm-pars/analysis/AssumptionChecker.h b/src/storm-pars/analysis/AssumptionChecker.h index cdb6faf2d..6a11f7e6f 100644 --- a/src/storm-pars/analysis/AssumptionChecker.h +++ b/src/storm-pars/analysis/AssumptionChecker.h @@ -80,6 +80,9 @@ namespace storm { typename storm::storage::SparseMatrix::iterator state1succ2, typename storm::storage::SparseMatrix::iterator state2succ1, typename storm::storage::SparseMatrix::iterator state2succ2); + + bool validateAssumptionSMTSolver(storm::analysis::Lattice* lattice, + std::shared_ptr assumption); }; } }