Browse Source

Fix SMT validation of assumptions

tempestpy_adaptions
Jip Spel 6 years ago
parent
commit
fbdce446b3
  1. 97
      src/storm-pars/analysis/AssumptionChecker.cpp

97
src/storm-pars/analysis/AssumptionChecker.cpp

@ -146,10 +146,14 @@ namespace storm {
} }
if (state1succ1->getColumn() == state2succ1->getColumn() && state1succ2->getColumn() == state2succ2->getColumn()) { if (state1succ1->getColumn() == state2succ1->getColumn() && state1succ2->getColumn() == state2succ2->getColumn()) {
result = validateAssumptionFunction(lattice, state1succ1, state1succ2, state2succ1, state2succ2);
auto comp = lattice->compare(state1succ1->getColumn(), state1succ2->getColumn());
if (comp != storm::analysis::Lattice::UNKNOWN) {
result = validateAssumptionFunction(lattice, state1succ1, state1succ2, state2succ1,
state2succ2);
if (!result) { if (!result) {
result = validateAssumptionSMTSolver(lattice, state1succ1, state1succ2, state2succ1, state2succ2);
result = validateAssumptionSMTSolver(lattice, state1succ1, state1succ2, state2succ1,
state2succ2);
} }
validatedAssumptions.insert(assumption); validatedAssumptions.insert(assumption);
@ -157,16 +161,17 @@ namespace storm {
validAssumptions.insert(assumption); validAssumptions.insert(assumption);
} }
} }
}
} else { } else {
bool subset = true; bool subset = true;
auto row1 = matrix.getRow(std::stoi(assumption->getFirstOperand()->asVariableExpression().getVariableName()));
auto row2 = matrix.getRow(std::stoi(assumption->getSecondOperand()->asVariableExpression().getVariableName()));
if (row1.getNumberOfEntries() > row2.getNumberOfEntries()) { if (row1.getNumberOfEntries() > row2.getNumberOfEntries()) {
std::swap(row1, row2); std::swap(row1, row2);
} }
storm::storage::BitVector stateNumbers(matrix.getColumnCount());
for (auto itr1 = row1.begin(); subset && itr1 != row1.end(); ++itr1) { for (auto itr1 = row1.begin(); subset && itr1 != row1.end(); ++itr1) {
bool found = false; bool found = false;
stateNumbers.set(itr1->getColumn());
for (auto itr2 = row2.begin(); !found && itr2 != row2.end(); ++itr2) { for (auto itr2 = row2.begin(); !found && itr2 != row2.end(); ++itr2) {
found = itr1->getColumn() == itr2->getColumn(); found = itr1->getColumn() == itr2->getColumn();
} }
@ -174,6 +179,16 @@ namespace storm {
} }
if (subset) { if (subset) {
// Check if they all are in the lattice
bool allInLattice = true;
for (auto i = stateNumbers.getNextSetIndex(0); allInLattice && i < stateNumbers.size(); i = stateNumbers.getNextSetIndex(i+1)) {
for (auto j = stateNumbers.getNextSetIndex(i+1); allInLattice && j < stateNumbers.size(); j = stateNumbers.getNextSetIndex(j+1)) {
auto comp = lattice->compare(i,j);
allInLattice &= comp == storm::analysis::Lattice::ABOVE || comp == storm::analysis::Lattice::BELOW || comp == storm::analysis::Lattice::SAME;
}
}
if (allInLattice) {
result = validateAssumptionSMTSolver(lattice, assumption); result = validateAssumptionSMTSolver(lattice, assumption);
validatedAssumptions.insert(assumption); validatedAssumptions.insert(assumption);
if (result) { if (result) {
@ -182,6 +197,7 @@ namespace storm {
} }
} }
} }
}
return result; return result;
} }
@ -195,15 +211,18 @@ namespace storm {
&& state1succ2->getColumn() == state2succ2->getColumn()) && state1succ2->getColumn() == state2succ2->getColumn())
|| (state1succ1->getColumn() == state2succ2->getColumn() || (state1succ1->getColumn() == state2succ2->getColumn()
&& state1succ2->getColumn() == state2succ1->getColumn())); && state1succ2->getColumn() == state2succ1->getColumn()));
bool result = true; bool result = true;
ValueType prob; ValueType prob;
auto comp = lattice->compare(state1succ1->getColumn(), state1succ2->getColumn()); auto comp = lattice->compare(state1succ1->getColumn(), state1succ2->getColumn());
assert (comp == storm::analysis::Lattice::ABOVE || comp == storm::analysis::Lattice::BELOW);
if (comp == storm::analysis::Lattice::ABOVE) { if (comp == storm::analysis::Lattice::ABOVE) {
prob = state1succ1->getValue() - state2succ1->getValue(); prob = state1succ1->getValue() - state2succ1->getValue();
} else if (comp == storm::analysis::Lattice::BELOW) { } else if (comp == storm::analysis::Lattice::BELOW) {
prob = state1succ2->getValue() - state2succ2->getValue(); prob = state1succ2->getValue() - state2succ2->getValue();
} }
auto vars = prob.gatherVariables(); auto vars = prob.gatherVariables();
// TODO: Type // TODO: Type
@ -239,14 +258,20 @@ namespace storm {
storm::solver::Z3SmtSolver s(*manager); storm::solver::Z3SmtSolver s(*manager);
storm::solver::SmtSolver::CheckResult smtResult = storm::solver::SmtSolver::CheckResult::Unknown; storm::solver::SmtSolver::CheckResult smtResult = storm::solver::SmtSolver::CheckResult::Unknown;
storm::expressions::Variable succ1 = manager->declareRationalVariable(std::to_string(state1succ1->getColumn())); storm::expressions::Variable succ1 = manager->declareRationalVariable(std::to_string(state1succ1->getColumn()));
storm::expressions::Variable succ2 = manager->declareRationalVariable(std::to_string(state1succ2->getColumn())); storm::expressions::Variable succ2 = manager->declareRationalVariable(std::to_string(state1succ2->getColumn()));
auto comp = lattice->compare(state1succ1->getColumn(), state1succ2->getColumn()); auto comp = lattice->compare(state1succ1->getColumn(), state1succ2->getColumn());
if (comp == storm::analysis::Lattice::ABOVE || comp == storm::analysis::Lattice::BELOW) {
if (comp == storm::analysis::Lattice::BELOW) {
std::swap(succ1, succ2);
storm::expressions::Expression exprGiven;
if (comp == storm::analysis::Lattice::ABOVE) {
exprGiven = succ1 >= succ2;
} else if (comp == storm::analysis::Lattice::BELOW) {
exprGiven = succ1 <= succ2;
} else {
assert (comp != storm::analysis::Lattice::UNKNOWN);
exprGiven = succ1 = succ2;
} }
storm::expressions::Expression exprGiven = succ1 >= succ2;
auto valueTypeToExpression = storm::expressions::RationalFunctionToExpression<ValueType>(manager); auto valueTypeToExpression = storm::expressions::RationalFunctionToExpression<ValueType>(manager);
storm::expressions::Expression exprToCheck = storm::expressions::Expression exprToCheck =
@ -258,72 +283,86 @@ namespace storm {
storm::expressions::Expression exprBounds = manager->boolean(true); storm::expressions::Expression exprBounds = manager->boolean(true);
auto variables = manager->getVariables(); auto variables = manager->getVariables();
for (auto var : variables) { for (auto var : variables) {
exprBounds = exprBounds && var > 0 && var < 1;
if (var != succ1 && var != succ2) {
// ensure graph-preserving
exprBounds = exprBounds && manager->rational(0) <= var && manager->rational(1) >= var;
} else {
exprBounds = exprBounds && var >= manager->rational(0) && var <= manager->rational(1);
}
} }
s.add(exprGiven); s.add(exprGiven);
s.add(exprToCheck); s.add(exprToCheck);
s.add(exprBounds); s.add(exprBounds);
smtResult = s.check(); smtResult = s.check();
}
return smtResult == storm::solver::SmtSolver::CheckResult::Sat; return smtResult == storm::solver::SmtSolver::CheckResult::Sat;
} }
template <typename ValueType> template <typename ValueType>
bool AssumptionChecker<ValueType>::validateAssumptionSMTSolver(storm::analysis::Lattice* lattice, std::shared_ptr<storm::expressions::BinaryRelationExpression> assumption) { bool AssumptionChecker<ValueType>::validateAssumptionSMTSolver(storm::analysis::Lattice* lattice, std::shared_ptr<storm::expressions::BinaryRelationExpression> assumption) {
assert (!validated(assumption));
std::shared_ptr<storm::utility::solver::SmtSolverFactory> smtSolverFactory = std::make_shared<storm::utility::solver::MathsatSmtSolverFactory>(); std::shared_ptr<storm::utility::solver::SmtSolverFactory> smtSolverFactory = std::make_shared<storm::utility::solver::MathsatSmtSolverFactory>();
std::shared_ptr<storm::expressions::ExpressionManager> manager(new storm::expressions::ExpressionManager()); std::shared_ptr<storm::expressions::ExpressionManager> manager(new storm::expressions::ExpressionManager());
bool result = true; bool result = true;
auto row1 = matrix.getRow(std::stoi(assumption->getFirstOperand()->asVariableExpression().getVariableName())); auto row1 = matrix.getRow(std::stoi(assumption->getFirstOperand()->asVariableExpression().getVariableName()));
auto row2 = matrix.getRow(std::stoi(assumption->getSecondOperand()->asVariableExpression().getVariableName())); auto row2 = matrix.getRow(std::stoi(assumption->getSecondOperand()->asVariableExpression().getVariableName()));
if (result) {
storm::solver::Z3SmtSolver s(*manager); storm::solver::Z3SmtSolver s(*manager);
std::set<storm::expressions::Variable> stateVariables;
if (row1.getNumberOfEntries() >= row2.getNumberOfEntries()) { if (row1.getNumberOfEntries() >= row2.getNumberOfEntries()) {
for (auto itr1 = row1.begin(); itr1 != row1.end(); ++itr1) {
manager->declareRationalVariable(std::to_string(itr1->getColumn()));
for (auto itr = row1.begin(); itr != row1.end(); ++itr) {
stateVariables.insert(manager->declareRationalVariable(std::to_string(itr->getColumn())));
} }
} else { } else {
for (auto itr1 = row2.begin(); itr1 != row2.end(); ++itr1) {
manager->declareRationalVariable(std::to_string(itr1->getColumn()));
for (auto itr = row2.begin(); itr != row2.end(); ++itr) {
stateVariables.insert(manager->declareRationalVariable(std::to_string(itr->getColumn())));
} }
} }
storm::expressions::Expression exprGiven = manager->boolean(true); storm::expressions::Expression exprGiven = manager->boolean(true);
for (auto itr1 = row1.begin(); result && itr1 != row1.end(); ++itr1) { for (auto itr1 = row1.begin(); result && itr1 != row1.end(); ++itr1) {
for (auto itr2 = (itr1 + 1); result && itr2 != row1.end(); ++itr2) {
for (auto itr2 = row1.begin(); result && itr2 != row1.end(); ++itr2) {
if (itr1->getColumn() != itr2->getColumn()) {
auto comp = lattice->compare(itr1->getColumn(), itr2->getColumn()); auto comp = lattice->compare(itr1->getColumn(), itr2->getColumn());
if (comp == storm::analysis::Lattice::ABOVE) { if (comp == storm::analysis::Lattice::ABOVE) {
exprGiven = exprGiven && (manager->getVariable(std::to_string(itr1->getColumn())) >= manager->getVariable(std::to_string(itr2->getColumn())));
exprGiven = exprGiven && (manager->getVariable(std::to_string(itr1->getColumn())) >=
manager->getVariable(std::to_string(itr2->getColumn())));
} else if (comp == storm::analysis::Lattice::BELOW) { } else if (comp == storm::analysis::Lattice::BELOW) {
exprGiven = exprGiven && (manager->getVariable(std::to_string(itr1->getColumn())) <= manager->getVariable(std::to_string(itr2->getColumn())));
} else if (comp == storm::analysis::Lattice::SAME) {
exprGiven = exprGiven && (manager->getVariable(std::to_string(itr1->getColumn())) = manager->getVariable(std::to_string(itr2->getColumn())));
exprGiven = exprGiven && (manager->getVariable(std::to_string(itr1->getColumn())) <=
manager->getVariable(std::to_string(itr2->getColumn())));
} else { } else {
result = false;
assert (comp != storm::analysis::Lattice::UNKNOWN);
exprGiven = exprGiven &&
(manager->getVariable(std::to_string(itr1->getColumn())) = manager->getVariable(
std::to_string(itr2->getColumn())));
}
} }
} }
} }
auto valueTypeToExpression = storm::expressions::RationalFunctionToExpression<ValueType>(manager); auto valueTypeToExpression = storm::expressions::RationalFunctionToExpression<ValueType>(manager);
storm::expressions::Expression expr1 = manager->integer(0);
storm::expressions::Expression expr1 = manager->rational(0);
for (auto itr1 = row1.begin(); result && itr1 != row1.end(); ++itr1) { for (auto itr1 = row1.begin(); result && itr1 != row1.end(); ++itr1) {
expr1 = expr1 + (valueTypeToExpression.toExpression(itr1->getValue()) * manager->getVariable(std::to_string(itr1->getColumn()))); expr1 = expr1 + (valueTypeToExpression.toExpression(itr1->getValue()) * manager->getVariable(std::to_string(itr1->getColumn())));
} }
storm::expressions::Expression expr2 = manager->integer(0);
storm::expressions::Expression expr2 = manager->rational(0);
for (auto itr2 = row2.begin(); result && itr2 != row2.end(); ++itr2) { for (auto itr2 = row2.begin(); result && itr2 != row2.end(); ++itr2) {
expr2 = expr2 + (valueTypeToExpression.toExpression(itr2->getValue()) * manager->getVariable(std::to_string(itr2->getColumn()))); expr2 = expr2 + (valueTypeToExpression.toExpression(itr2->getValue()) * manager->getVariable(std::to_string(itr2->getColumn())));
} }
storm::expressions::Expression exprToCheck = expr1 >= expr2; storm::expressions::Expression exprToCheck = expr1 >= expr2;
storm::expressions::Expression exprProb1 = manager->integer(0);
storm::expressions::Expression exprProb1 = manager->rational(0);
for (auto itr1 = row1.begin(); result && itr1 != row1.end(); ++itr1) { for (auto itr1 = row1.begin(); result && itr1 != row1.end(); ++itr1) {
exprProb1 = exprProb1 + (valueTypeToExpression.toExpression(itr1->getValue())); exprProb1 = exprProb1 + (valueTypeToExpression.toExpression(itr1->getValue()));
} }
storm::expressions::Expression exprProb2 = manager->integer(0);
storm::expressions::Expression exprProb2 = manager->rational(0);
for (auto itr2 = row2.begin(); result && itr2 != row2.end(); ++itr2) { for (auto itr2 = row2.begin(); result && itr2 != row2.end(); ++itr2) {
exprProb2 = exprProb2 + (valueTypeToExpression.toExpression(itr2->getValue())); exprProb2 = exprProb2 + (valueTypeToExpression.toExpression(itr2->getValue()));
} }
@ -335,7 +374,12 @@ namespace storm {
auto variables = manager->getVariables(); auto variables = manager->getVariables();
for (auto var : variables) { for (auto var : variables) {
exprBounds = exprBounds && var > 0 && var < 1;
if (find(stateVariables.begin(), stateVariables.end(), var) != stateVariables.end()) {
// ensure graph-preserving
exprBounds = exprBounds && manager->rational(0) <= var && manager->rational(1) >= var;
} else {
exprBounds = exprBounds && var >= manager->rational(0) && var <= manager->rational(1);
}
} }
s.add(exprGiven); s.add(exprGiven);
@ -344,7 +388,6 @@ namespace storm {
s.add(exprToCheck); s.add(exprToCheck);
auto smtRes = s.check(); auto smtRes = s.check();
result = result && smtRes == storm::solver::SmtSolver::CheckResult::Sat; result = result && smtRes == storm::solver::SmtSolver::CheckResult::Sat;
}
return result; return result;
} }

Loading…
Cancel
Save