diff --git a/src/storm-pars/analysis/AssumptionChecker.cpp b/src/storm-pars/analysis/AssumptionChecker.cpp index b8cfd9d19..b454cf451 100644 --- a/src/storm-pars/analysis/AssumptionChecker.cpp +++ b/src/storm-pars/analysis/AssumptionChecker.cpp @@ -146,27 +146,32 @@ namespace storm { } if (state1succ1->getColumn() == state2succ1->getColumn() && state1succ2->getColumn() == state2succ2->getColumn()) { - result = validateAssumptionFunction(lattice, state1succ1, state1succ2, state2succ1, state2succ2); + auto comp = lattice->compare(state1succ1->getColumn(), state1succ2->getColumn()); + if (comp != storm::analysis::Lattice::UNKNOWN) { + result = validateAssumptionFunction(lattice, state1succ1, state1succ2, state2succ1, + state2succ2); if (!result) { - result = validateAssumptionSMTSolver(lattice, state1succ1, state1succ2, state2succ1, state2succ2); + result = validateAssumptionSMTSolver(lattice, state1succ1, state1succ2, state2succ1, + state2succ2); } - validatedAssumptions.insert(assumption); - if (result) { - validAssumptions.insert(assumption); + validatedAssumptions.insert(assumption); + if (result) { + validAssumptions.insert(assumption); + } } } } else { bool subset = true; - auto row1 = matrix.getRow(std::stoi(assumption->getFirstOperand()->asVariableExpression().getVariableName())); - auto row2 = matrix.getRow(std::stoi(assumption->getSecondOperand()->asVariableExpression().getVariableName())); if (row1.getNumberOfEntries() > row2.getNumberOfEntries()) { std::swap(row1, row2); } + storm::storage::BitVector stateNumbers(matrix.getColumnCount()); for (auto itr1 = row1.begin(); subset && itr1 != row1.end(); ++itr1) { bool found = false; + stateNumbers.set(itr1->getColumn()); for (auto itr2 = row2.begin(); !found && itr2 != row2.end(); ++itr2) { found = itr1->getColumn() == itr2->getColumn(); } @@ -174,10 +179,21 @@ namespace storm { } if (subset) { - result = validateAssumptionSMTSolver(lattice, assumption); - validatedAssumptions.insert(assumption); - if (result) { - validAssumptions.insert(assumption); + // Check if they all are in the lattice + bool allInLattice = true; + for (auto i = stateNumbers.getNextSetIndex(0); allInLattice && i < stateNumbers.size(); i = stateNumbers.getNextSetIndex(i+1)) { + for (auto j = stateNumbers.getNextSetIndex(i+1); allInLattice && j < stateNumbers.size(); j = stateNumbers.getNextSetIndex(j+1)) { + auto comp = lattice->compare(i,j); + allInLattice &= comp == storm::analysis::Lattice::ABOVE || comp == storm::analysis::Lattice::BELOW || comp == storm::analysis::Lattice::SAME; + } + } + + if (allInLattice) { + result = validateAssumptionSMTSolver(lattice, assumption); + validatedAssumptions.insert(assumption); + if (result) { + validAssumptions.insert(assumption); + } } } } @@ -195,15 +211,18 @@ namespace storm { && state1succ2->getColumn() == state2succ2->getColumn()) || (state1succ1->getColumn() == state2succ2->getColumn() && state1succ2->getColumn() == state2succ1->getColumn())); + bool result = true; ValueType prob; auto comp = lattice->compare(state1succ1->getColumn(), state1succ2->getColumn()); + assert (comp == storm::analysis::Lattice::ABOVE || comp == storm::analysis::Lattice::BELOW); if (comp == storm::analysis::Lattice::ABOVE) { prob = state1succ1->getValue() - state2succ1->getValue(); } else if (comp == storm::analysis::Lattice::BELOW) { prob = state1succ2->getValue() - state2succ2->getValue(); } + auto vars = prob.gatherVariables(); // TODO: Type @@ -239,112 +258,136 @@ namespace storm { storm::solver::Z3SmtSolver s(*manager); storm::solver::SmtSolver::CheckResult smtResult = storm::solver::SmtSolver::CheckResult::Unknown; + storm::expressions::Variable succ1 = manager->declareRationalVariable(std::to_string(state1succ1->getColumn())); storm::expressions::Variable succ2 = manager->declareRationalVariable(std::to_string(state1succ2->getColumn())); auto comp = lattice->compare(state1succ1->getColumn(), state1succ2->getColumn()); - if (comp == storm::analysis::Lattice::ABOVE || comp == storm::analysis::Lattice::BELOW) { - if (comp == storm::analysis::Lattice::BELOW) { - std::swap(succ1, succ2); - } - storm::expressions::Expression exprGiven = succ1 >= succ2; - - auto valueTypeToExpression = storm::expressions::RationalFunctionToExpression(manager); - storm::expressions::Expression exprToCheck = - (valueTypeToExpression.toExpression(state1succ1->getValue())*succ1 - + valueTypeToExpression.toExpression(state1succ2->getValue())*succ2 - >= valueTypeToExpression.toExpression(state2succ1->getValue())*succ1 - + valueTypeToExpression.toExpression(state2succ2->getValue())*succ2); - - storm::expressions::Expression exprBounds = manager->boolean(true); - auto variables = manager->getVariables(); - for (auto var : variables) { - exprBounds = exprBounds && var > 0 && var < 1; - } - s.add(exprGiven); - s.add(exprToCheck); - s.add(exprBounds); - smtResult = s.check(); + storm::expressions::Expression exprGiven; + if (comp == storm::analysis::Lattice::ABOVE) { + exprGiven = succ1 >= succ2; + } else if (comp == storm::analysis::Lattice::BELOW) { + exprGiven = succ1 <= succ2; + } else { + assert (comp != storm::analysis::Lattice::UNKNOWN); + exprGiven = succ1 = succ2; + } + + auto valueTypeToExpression = storm::expressions::RationalFunctionToExpression(manager); + storm::expressions::Expression exprToCheck = + (valueTypeToExpression.toExpression(state1succ1->getValue())*succ1 + + valueTypeToExpression.toExpression(state1succ2->getValue())*succ2 + >= valueTypeToExpression.toExpression(state2succ1->getValue())*succ1 + + valueTypeToExpression.toExpression(state2succ2->getValue())*succ2); + + storm::expressions::Expression exprBounds = manager->boolean(true); + auto variables = manager->getVariables(); + for (auto var : variables) { + if (var != succ1 && var != succ2) { + // ensure graph-preserving + exprBounds = exprBounds && manager->rational(0) <= var && manager->rational(1) >= var; + } else { + exprBounds = exprBounds && var >= manager->rational(0) && var <= manager->rational(1); + } } + + s.add(exprGiven); + s.add(exprToCheck); + s.add(exprBounds); + smtResult = s.check(); + return smtResult == storm::solver::SmtSolver::CheckResult::Sat; } template bool AssumptionChecker::validateAssumptionSMTSolver(storm::analysis::Lattice* lattice, std::shared_ptr assumption) { + assert (!validated(assumption)); + std::shared_ptr smtSolverFactory = std::make_shared(); std::shared_ptr manager(new storm::expressions::ExpressionManager()); + bool result = true; auto row1 = matrix.getRow(std::stoi(assumption->getFirstOperand()->asVariableExpression().getVariableName())); auto row2 = matrix.getRow(std::stoi(assumption->getSecondOperand()->asVariableExpression().getVariableName())); - if (result) { - storm::solver::Z3SmtSolver s(*manager); - if (row1.getNumberOfEntries() >= row2.getNumberOfEntries()) { - for (auto itr1 = row1.begin(); itr1 != row1.end(); ++itr1) { - manager->declareRationalVariable(std::to_string(itr1->getColumn())); - } - } else { - for (auto itr1 = row2.begin(); itr1 != row2.end(); ++itr1) { - manager->declareRationalVariable(std::to_string(itr1->getColumn())); - } + storm::solver::Z3SmtSolver s(*manager); + std::set stateVariables; + if (row1.getNumberOfEntries() >= row2.getNumberOfEntries()) { + for (auto itr = row1.begin(); itr != row1.end(); ++itr) { + stateVariables.insert(manager->declareRationalVariable(std::to_string(itr->getColumn()))); } + } else { + for (auto itr = row2.begin(); itr != row2.end(); ++itr) { + stateVariables.insert(manager->declareRationalVariable(std::to_string(itr->getColumn()))); + } + } - storm::expressions::Expression exprGiven = manager->boolean(true); + storm::expressions::Expression exprGiven = manager->boolean(true); - for (auto itr1 = row1.begin(); result && itr1 != row1.end(); ++itr1) { - for (auto itr2 = (itr1 + 1); result && itr2 != row1.end(); ++itr2) { + for (auto itr1 = row1.begin(); result && itr1 != row1.end(); ++itr1) { + for (auto itr2 = row1.begin(); result && itr2 != row1.end(); ++itr2) { + if (itr1->getColumn() != itr2->getColumn()) { auto comp = lattice->compare(itr1->getColumn(), itr2->getColumn()); + if (comp == storm::analysis::Lattice::ABOVE) { - exprGiven = exprGiven && (manager->getVariable(std::to_string(itr1->getColumn())) >= manager->getVariable(std::to_string(itr2->getColumn()))); + exprGiven = exprGiven && (manager->getVariable(std::to_string(itr1->getColumn())) >= + manager->getVariable(std::to_string(itr2->getColumn()))); } else if (comp == storm::analysis::Lattice::BELOW) { - exprGiven = exprGiven && (manager->getVariable(std::to_string(itr1->getColumn())) <= manager->getVariable(std::to_string(itr2->getColumn()))); - } else if (comp == storm::analysis::Lattice::SAME) { - exprGiven = exprGiven && (manager->getVariable(std::to_string(itr1->getColumn())) = manager->getVariable(std::to_string(itr2->getColumn()))); + exprGiven = exprGiven && (manager->getVariable(std::to_string(itr1->getColumn())) <= + manager->getVariable(std::to_string(itr2->getColumn()))); } else { - result = false; + assert (comp != storm::analysis::Lattice::UNKNOWN); + exprGiven = exprGiven && + (manager->getVariable(std::to_string(itr1->getColumn())) = manager->getVariable( + std::to_string(itr2->getColumn()))); } } } + } - auto valueTypeToExpression = storm::expressions::RationalFunctionToExpression(manager); - storm::expressions::Expression expr1 = manager->integer(0); - for (auto itr1 = row1.begin(); result && itr1 != row1.end(); ++itr1) { - expr1 = expr1 + (valueTypeToExpression.toExpression(itr1->getValue()) * manager->getVariable(std::to_string(itr1->getColumn()))); - } + auto valueTypeToExpression = storm::expressions::RationalFunctionToExpression(manager); + storm::expressions::Expression expr1 = manager->rational(0); + for (auto itr1 = row1.begin(); result && itr1 != row1.end(); ++itr1) { + expr1 = expr1 + (valueTypeToExpression.toExpression(itr1->getValue()) * manager->getVariable(std::to_string(itr1->getColumn()))); + } - storm::expressions::Expression expr2 = manager->integer(0); - for (auto itr2 = row2.begin(); result && itr2 != row2.end(); ++itr2) { - expr2 = expr2 + (valueTypeToExpression.toExpression(itr2->getValue()) * manager->getVariable(std::to_string(itr2->getColumn()))); - } - storm::expressions::Expression exprToCheck = expr1 >= expr2; + storm::expressions::Expression expr2 = manager->rational(0); + for (auto itr2 = row2.begin(); result && itr2 != row2.end(); ++itr2) { + expr2 = expr2 + (valueTypeToExpression.toExpression(itr2->getValue()) * manager->getVariable(std::to_string(itr2->getColumn()))); + } + storm::expressions::Expression exprToCheck = expr1 >= expr2; - storm::expressions::Expression exprProb1 = manager->integer(0); - for (auto itr1 = row1.begin(); result && itr1 != row1.end(); ++itr1) { - exprProb1 = exprProb1 + (valueTypeToExpression.toExpression(itr1->getValue())); - } + storm::expressions::Expression exprProb1 = manager->rational(0); + for (auto itr1 = row1.begin(); result && itr1 != row1.end(); ++itr1) { + exprProb1 = exprProb1 + (valueTypeToExpression.toExpression(itr1->getValue())); + } - storm::expressions::Expression exprProb2 = manager->integer(0); - for (auto itr2 = row2.begin(); result && itr2 != row2.end(); ++itr2) { - exprProb2 = exprProb2 + (valueTypeToExpression.toExpression(itr2->getValue())); - } + storm::expressions::Expression exprProb2 = manager->rational(0); + for (auto itr2 = row2.begin(); result && itr2 != row2.end(); ++itr2) { + exprProb2 = exprProb2 + (valueTypeToExpression.toExpression(itr2->getValue())); + } - storm::expressions::Expression exprBounds = exprProb1 >= manager->rational(0) - && exprProb1 <= manager->rational(1) - && exprProb2 >= manager->rational(0) - && exprProb2 <= manager->rational(1); + storm::expressions::Expression exprBounds = exprProb1 >= manager->rational(0) + && exprProb1 <= manager->rational(1) + && exprProb2 >= manager->rational(0) + && exprProb2 <= manager->rational(1); - auto variables = manager->getVariables(); - for (auto var : variables) { - exprBounds = exprBounds && var > 0 && var < 1; + auto variables = manager->getVariables(); + for (auto var : variables) { + if (find(stateVariables.begin(), stateVariables.end(), var) != stateVariables.end()) { + // ensure graph-preserving + exprBounds = exprBounds && manager->rational(0) <= var && manager->rational(1) >= var; + } else { + exprBounds = exprBounds && var >= manager->rational(0) && var <= manager->rational(1); } - - s.add(exprGiven); - s.add(exprBounds); - assert(s.check() == storm::solver::SmtSolver::CheckResult::Sat); - s.add(exprToCheck); - auto smtRes = s.check(); - result = result && smtRes == storm::solver::SmtSolver::CheckResult::Sat; } + + s.add(exprGiven); + s.add(exprBounds); + assert(s.check() == storm::solver::SmtSolver::CheckResult::Sat); + s.add(exprToCheck); + auto smtRes = s.check(); + result = result && smtRes == storm::solver::SmtSolver::CheckResult::Sat; return result; }