Browse Source

Refactor validation methods

tempestpy_adaptions
Jip Spel 6 years ago
parent
commit
29f9275302
  1. 67
      src/storm-pars/analysis/AssumptionChecker.cpp
  2. 13
      src/storm-pars/analysis/AssumptionChecker.h

67
src/storm-pars/analysis/AssumptionChecker.cpp

@ -136,10 +136,24 @@ namespace storm {
// Only implemented for two successors // Only implemented for two successors
if (row1.getNumberOfEntries() == 2 && row2.getNumberOfEntries() == 2) { if (row1.getNumberOfEntries() == 2 && row2.getNumberOfEntries() == 2) {
result = validateAssumptionFunction(lattice, row1, row2);
auto state1succ1 = row1.begin();
auto state1succ2 = (++row1.begin());
auto state2succ1 = row2.begin();
auto state2succ2 = (++row2.begin());
if (state1succ1->getColumn() == state2succ2->getColumn()
&& state2succ1->getColumn() == state1succ2->getColumn()) {
std::swap(state1succ1, state1succ2);
}
if (state1succ1->getColumn() == state2succ1->getColumn() && state1succ2->getColumn() == state2succ2->getColumn()) {
result = validateAssumptionFunction(lattice, state1succ1, state1succ2, state2succ1, state2succ2);
if (!result) { if (!result) {
result = validateAssumptionSMTSolver(lattice, row1, row2);
result = validateAssumptionSMTSolver(lattice, state1succ1, state1succ2, state2succ1, state2succ2);
}
} }
} else {
STORM_LOG_DEBUG("Validation only implemented for two successor states");
} }
} }
@ -152,22 +166,12 @@ namespace storm {
} }
template <typename ValueType> template <typename ValueType>
bool AssumptionChecker<ValueType>::validateAssumptionFunction(storm::analysis::Lattice* lattice, typename storm::storage::SparseMatrix<ValueType>::rows row1, typename storm::storage::SparseMatrix<ValueType>::rows row2) {
bool result = false;
auto state1succ1 = row1.begin();
auto state1succ2 = (++row1.begin());
auto state2succ1 = row2.begin();
auto state2succ2 = (++row2.begin());
if (state1succ1->getColumn() == state2succ2->getColumn()
&& state2succ1->getColumn() == state1succ2->getColumn()) {
// swap them
auto temp = state1succ2;
state1succ2 = state1succ1;
state1succ1 = temp;
}
bool AssumptionChecker<ValueType>::validateAssumptionFunction(storm::analysis::Lattice* lattice,
typename storm::storage::SparseMatrix<ValueType>::iterator state1succ1,
typename storm::storage::SparseMatrix<ValueType>::iterator state1succ2,
typename storm::storage::SparseMatrix<ValueType>::iterator state2succ1,
typename storm::storage::SparseMatrix<ValueType>::iterator state2succ2) {
if (state1succ1->getColumn() == state2succ1->getColumn() && state1succ2->getColumn() == state2succ2->getColumn()) {
ValueType prob; ValueType prob;
auto comp = lattice->compare(state1succ1->getColumn(), state1succ2->getColumn()); auto comp = lattice->compare(state1succ1->getColumn(), state1succ2->getColumn());
if (comp == storm::analysis::Lattice::ABOVE) { if (comp == storm::analysis::Lattice::ABOVE) {
@ -176,8 +180,9 @@ namespace storm {
prob = state1succ2->getValue() - state2succ2->getValue(); prob = state1succ2->getValue() - state2succ2->getValue();
} }
auto vars = prob.gatherVariables(); auto vars = prob.gatherVariables();
// TODO: Type // TODO: Type
std::map<storm::RationalFunctionVariable, storm::RationalFunctionCoefficient> substitutions;
std::map<storm::RationalFunctionVariable, typename ValueType::CoeffType> substitutions;
for (auto var:vars) { for (auto var:vars) {
auto derivative = prob.derivative(var); auto derivative = prob.derivative(var);
assert(derivative.isConstant()); assert(derivative.isConstant());
@ -187,25 +192,16 @@ namespace storm {
substitutions[var] = 1; substitutions[var] = 1;
} }
} }
result = prob.evaluate(substitutions) >= 0;
}
return result;
return prob.evaluate(substitutions) >= 0;
} }
template <typename ValueType> template <typename ValueType>
bool AssumptionChecker<ValueType>::validateAssumptionSMTSolver(storm::analysis::Lattice* lattice, typename storm::storage::SparseMatrix<ValueType>::rows row1, typename storm::storage::SparseMatrix<ValueType>::rows row2) {
bool result = false;
auto state1succ1 = row1.begin();
auto state1succ2 = (++row1.begin());
auto state2succ1 = row2.begin();
auto state2succ2 = (++row2.begin());
bool AssumptionChecker<ValueType>::validateAssumptionSMTSolver(storm::analysis::Lattice* lattice,
typename storm::storage::SparseMatrix<ValueType>::iterator state1succ1,
typename storm::storage::SparseMatrix<ValueType>::iterator state1succ2,
typename storm::storage::SparseMatrix<ValueType>::iterator state2succ1,
typename storm::storage::SparseMatrix<ValueType>::iterator state2succ2) {
if (state1succ1->getColumn() == state2succ2->getColumn()
&& state2succ1->getColumn() == state1succ2->getColumn()) {
std::swap(state1succ1, state1succ2);
}
if (state1succ1->getColumn() == state2succ1->getColumn() && state1succ2->getColumn() == state2succ2->getColumn()) {
std::shared_ptr<storm::utility::solver::SmtSolverFactory> smtSolverFactory = std::make_shared<storm::utility::solver::MathsatSmtSolverFactory>(); std::shared_ptr<storm::utility::solver::SmtSolverFactory> smtSolverFactory = std::make_shared<storm::utility::solver::MathsatSmtSolverFactory>();
std::shared_ptr<storm::expressions::ExpressionManager> manager( std::shared_ptr<storm::expressions::ExpressionManager> manager(
new storm::expressions::ExpressionManager()); new storm::expressions::ExpressionManager());
@ -239,10 +235,7 @@ namespace storm {
s.add(exprBounds); s.add(exprBounds);
smtResult = s.check(); smtResult = s.check();
} }
result = smtResult == storm::solver::SmtSolver::CheckResult::Sat;
}
return result;
return smtResult == storm::solver::SmtSolver::CheckResult::Sat;
} }
template <typename ValueType> template <typename ValueType>

13
src/storm-pars/analysis/AssumptionChecker.h

@ -69,11 +69,18 @@ namespace storm {
std::set<std::shared_ptr<storm::expressions::BinaryRelationExpression>> validatedAssumptions; std::set<std::shared_ptr<storm::expressions::BinaryRelationExpression>> validatedAssumptions;
bool validateAssumptionFunction(storm::analysis::Lattice* lattice, typename storm::storage::SparseMatrix<ValueType>::rows row1, typename storm::storage::SparseMatrix<ValueType>::rows row2);
bool validateAssumptionFunction(storm::analysis::Lattice* lattice,
typename storm::storage::SparseMatrix<ValueType>::iterator state1succ1,
typename storm::storage::SparseMatrix<ValueType>::iterator state1succ2,
typename storm::storage::SparseMatrix<ValueType>::iterator state2succ1,
typename storm::storage::SparseMatrix<ValueType>::iterator state2succ2);
bool validateAssumptionSMTSolver(storm::analysis::Lattice* lattice, typename storm::storage::SparseMatrix<ValueType>::rows row1, typename storm::storage::SparseMatrix<ValueType>::rows row2);
bool validateAssumptionSMTSolver(storm::analysis::Lattice* lattice,
typename storm::storage::SparseMatrix<ValueType>::iterator state1succ1,
typename storm::storage::SparseMatrix<ValueType>::iterator state1succ2,
typename storm::storage::SparseMatrix<ValueType>::iterator state2succ1,
typename storm::storage::SparseMatrix<ValueType>::iterator state2succ2);
}; };
} }
} }
#endif //STORM_ASSUMPTIONCHECKER_H #endif //STORM_ASSUMPTIONCHECKER_H
Loading…
Cancel
Save