Browse Source

Refactor validation methods

tempestpy_adaptions
Jip Spel 6 years ago
parent
commit
29f9275302
  1. 163
      src/storm-pars/analysis/AssumptionChecker.cpp
  2. 13
      src/storm-pars/analysis/AssumptionChecker.h

163
src/storm-pars/analysis/AssumptionChecker.cpp

@ -136,10 +136,24 @@ namespace storm {
// Only implemented for two successors
if (row1.getNumberOfEntries() == 2 && row2.getNumberOfEntries() == 2) {
result = validateAssumptionFunction(lattice, row1, row2);
if (!result) {
result = validateAssumptionSMTSolver(lattice, row1, row2);
auto state1succ1 = row1.begin();
auto state1succ2 = (++row1.begin());
auto state2succ1 = row2.begin();
auto state2succ2 = (++row2.begin());
if (state1succ1->getColumn() == state2succ2->getColumn()
&& state2succ1->getColumn() == state1succ2->getColumn()) {
std::swap(state1succ1, state1succ2);
}
if (state1succ1->getColumn() == state2succ1->getColumn() && state1succ2->getColumn() == state2succ2->getColumn()) {
result = validateAssumptionFunction(lattice, state1succ1, state1succ2, state2succ1, state2succ2);
if (!result) {
result = validateAssumptionSMTSolver(lattice, state1succ1, state1succ2, state2succ1, state2succ2);
}
}
} else {
STORM_LOG_DEBUG("Validation only implemented for two successor states");
}
}
@ -152,97 +166,76 @@ namespace storm {
}
template <typename ValueType>
bool AssumptionChecker<ValueType>::validateAssumptionFunction(storm::analysis::Lattice* lattice, typename storm::storage::SparseMatrix<ValueType>::rows row1, typename storm::storage::SparseMatrix<ValueType>::rows row2) {
bool result = false;
auto state1succ1 = row1.begin();
auto state1succ2 = (++row1.begin());
auto state2succ1 = row2.begin();
auto state2succ2 = (++row2.begin());
if (state1succ1->getColumn() == state2succ2->getColumn()
&& state2succ1->getColumn() == state1succ2->getColumn()) {
// swap them
auto temp = state1succ2;
state1succ2 = state1succ1;
state1succ1 = temp;
bool AssumptionChecker<ValueType>::validateAssumptionFunction(storm::analysis::Lattice* lattice,
typename storm::storage::SparseMatrix<ValueType>::iterator state1succ1,
typename storm::storage::SparseMatrix<ValueType>::iterator state1succ2,
typename storm::storage::SparseMatrix<ValueType>::iterator state2succ1,
typename storm::storage::SparseMatrix<ValueType>::iterator state2succ2) {
ValueType prob;
auto comp = lattice->compare(state1succ1->getColumn(), state1succ2->getColumn());
if (comp == storm::analysis::Lattice::ABOVE) {
prob = state1succ1->getValue() - state2succ1->getValue();
} else if (comp == storm::analysis::Lattice::BELOW) {
prob = state1succ2->getValue() - state2succ2->getValue();
}
if (state1succ1->getColumn() == state2succ1->getColumn() && state1succ2->getColumn() == state2succ2->getColumn()) {
ValueType prob;
auto comp = lattice->compare(state1succ1->getColumn(), state1succ2->getColumn());
if (comp == storm::analysis::Lattice::ABOVE) {
prob = state1succ1->getValue() - state2succ1->getValue();
} else if (comp == storm::analysis::Lattice::BELOW) {
prob = state1succ2->getValue() - state2succ2->getValue();
auto vars = prob.gatherVariables();
// TODO: Type
std::map<storm::RationalFunctionVariable, typename ValueType::CoeffType> substitutions;
for (auto var:vars) {
auto derivative = prob.derivative(var);
assert(derivative.isConstant());
if (derivative.constantPart() >= 0) {
substitutions[var] = 0;
} else if (derivative.constantPart() <= 0) {
substitutions[var] = 1;
}
auto vars = prob.gatherVariables();
// TODO: Type
std::map<storm::RationalFunctionVariable, storm::RationalFunctionCoefficient> substitutions;
for (auto var:vars) {
auto derivative = prob.derivative(var);
assert(derivative.isConstant());
if (derivative.constantPart() >= 0) {
substitutions[var] = 0;
} else if (derivative.constantPart() <= 0) {
substitutions[var] = 1;
}
}
result = prob.evaluate(substitutions) >= 0;
}
return result;
return prob.evaluate(substitutions) >= 0;
}
template <typename ValueType>
bool AssumptionChecker<ValueType>::validateAssumptionSMTSolver(storm::analysis::Lattice* lattice, typename storm::storage::SparseMatrix<ValueType>::rows row1, typename storm::storage::SparseMatrix<ValueType>::rows row2) {
bool result = false;
auto state1succ1 = row1.begin();
auto state1succ2 = (++row1.begin());
auto state2succ1 = row2.begin();
auto state2succ2 = (++row2.begin());
if (state1succ1->getColumn() == state2succ2->getColumn()
&& state2succ1->getColumn() == state1succ2->getColumn()) {
std::swap(state1succ1, state1succ2);
}
if (state1succ1->getColumn() == state2succ1->getColumn() && state1succ2->getColumn() == state2succ2->getColumn()) {
std::shared_ptr<storm::utility::solver::SmtSolverFactory> smtSolverFactory = std::make_shared<storm::utility::solver::MathsatSmtSolverFactory>();
std::shared_ptr<storm::expressions::ExpressionManager> manager(
new storm::expressions::ExpressionManager());
storm::solver::Z3SmtSolver s(*manager);
storm::solver::SmtSolver::CheckResult smtResult = storm::solver::SmtSolver::CheckResult::Unknown;
storm::expressions::Variable succ1 = manager->declareRationalVariable(std::to_string(state1succ1->getColumn()));
storm::expressions::Variable succ2 = manager->declareRationalVariable(std::to_string(state1succ2->getColumn()));
auto comp = lattice->compare(state1succ1->getColumn(), state1succ2->getColumn());
if (comp == storm::analysis::Lattice::ABOVE || comp == storm::analysis::Lattice::BELOW) {
if (comp == storm::analysis::Lattice::BELOW) {
std::swap(succ1, succ2);
}
storm::expressions::Expression exprGiven = succ1 >= succ2;
auto valueTypeToExpression = storm::expressions::ValueTypeToExpression<ValueType>(manager);
storm::expressions::Expression exprToCheck =
(valueTypeToExpression.toExpression(state1succ1->getValue())*succ1
+ valueTypeToExpression.toExpression(state2succ1->getValue())*succ2
>= valueTypeToExpression.toExpression(state1succ2->getValue())*succ1
+ valueTypeToExpression.toExpression(state1succ1->getValue())*succ2);
storm::expressions::Expression exprBounds = manager->boolean(true);
auto variables = manager->getVariables();
for (auto var : variables) {
exprBounds = exprBounds && var >= 0 && var <= 1;
}
s.add(exprGiven);
s.add(exprToCheck);
s.add(exprBounds);
smtResult = s.check();
bool AssumptionChecker<ValueType>::validateAssumptionSMTSolver(storm::analysis::Lattice* lattice,
typename storm::storage::SparseMatrix<ValueType>::iterator state1succ1,
typename storm::storage::SparseMatrix<ValueType>::iterator state1succ2,
typename storm::storage::SparseMatrix<ValueType>::iterator state2succ1,
typename storm::storage::SparseMatrix<ValueType>::iterator state2succ2) {
std::shared_ptr<storm::utility::solver::SmtSolverFactory> smtSolverFactory = std::make_shared<storm::utility::solver::MathsatSmtSolverFactory>();
std::shared_ptr<storm::expressions::ExpressionManager> manager(
new storm::expressions::ExpressionManager());
storm::solver::Z3SmtSolver s(*manager);
storm::solver::SmtSolver::CheckResult smtResult = storm::solver::SmtSolver::CheckResult::Unknown;
storm::expressions::Variable succ1 = manager->declareRationalVariable(std::to_string(state1succ1->getColumn()));
storm::expressions::Variable succ2 = manager->declareRationalVariable(std::to_string(state1succ2->getColumn()));
auto comp = lattice->compare(state1succ1->getColumn(), state1succ2->getColumn());
if (comp == storm::analysis::Lattice::ABOVE || comp == storm::analysis::Lattice::BELOW) {
if (comp == storm::analysis::Lattice::BELOW) {
std::swap(succ1, succ2);
}
storm::expressions::Expression exprGiven = succ1 >= succ2;
auto valueTypeToExpression = storm::expressions::ValueTypeToExpression<ValueType>(manager);
storm::expressions::Expression exprToCheck =
(valueTypeToExpression.toExpression(state1succ1->getValue())*succ1
+ valueTypeToExpression.toExpression(state2succ1->getValue())*succ2
>= valueTypeToExpression.toExpression(state1succ2->getValue())*succ1
+ valueTypeToExpression.toExpression(state1succ1->getValue())*succ2);
storm::expressions::Expression exprBounds = manager->boolean(true);
auto variables = manager->getVariables();
for (auto var : variables) {
exprBounds = exprBounds && var >= 0 && var <= 1;
}
result = smtResult == storm::solver::SmtSolver::CheckResult::Sat;
s.add(exprGiven);
s.add(exprToCheck);
s.add(exprBounds);
smtResult = s.check();
}
return result;
return smtResult == storm::solver::SmtSolver::CheckResult::Sat;
}
template <typename ValueType>

13
src/storm-pars/analysis/AssumptionChecker.h

@ -69,11 +69,18 @@ namespace storm {
std::set<std::shared_ptr<storm::expressions::BinaryRelationExpression>> validatedAssumptions;
bool validateAssumptionFunction(storm::analysis::Lattice* lattice, typename storm::storage::SparseMatrix<ValueType>::rows row1, typename storm::storage::SparseMatrix<ValueType>::rows row2);
bool validateAssumptionFunction(storm::analysis::Lattice* lattice,
typename storm::storage::SparseMatrix<ValueType>::iterator state1succ1,
typename storm::storage::SparseMatrix<ValueType>::iterator state1succ2,
typename storm::storage::SparseMatrix<ValueType>::iterator state2succ1,
typename storm::storage::SparseMatrix<ValueType>::iterator state2succ2);
bool validateAssumptionSMTSolver(storm::analysis::Lattice* lattice, typename storm::storage::SparseMatrix<ValueType>::rows row1, typename storm::storage::SparseMatrix<ValueType>::rows row2);
bool validateAssumptionSMTSolver(storm::analysis::Lattice* lattice,
typename storm::storage::SparseMatrix<ValueType>::iterator state1succ1,
typename storm::storage::SparseMatrix<ValueType>::iterator state1succ2,
typename storm::storage::SparseMatrix<ValueType>::iterator state2succ1,
typename storm::storage::SparseMatrix<ValueType>::iterator state2succ2);
};
}
}
#endif //STORM_ASSUMPTIONCHECKER_H
Loading…
Cancel
Save