From 29f927530286df78aa8a93fe021c2232077fbde8 Mon Sep 17 00:00:00 2001 From: Jip Spel Date: Mon, 24 Sep 2018 14:46:56 +0200 Subject: [PATCH] Refactor validation methods --- src/storm-pars/analysis/AssumptionChecker.cpp | 163 +++++++++--------- src/storm-pars/analysis/AssumptionChecker.h | 13 +- 2 files changed, 88 insertions(+), 88 deletions(-) diff --git a/src/storm-pars/analysis/AssumptionChecker.cpp b/src/storm-pars/analysis/AssumptionChecker.cpp index fdac5ed62..b3326d519 100644 --- a/src/storm-pars/analysis/AssumptionChecker.cpp +++ b/src/storm-pars/analysis/AssumptionChecker.cpp @@ -136,10 +136,24 @@ namespace storm { // Only implemented for two successors if (row1.getNumberOfEntries() == 2 && row2.getNumberOfEntries() == 2) { - result = validateAssumptionFunction(lattice, row1, row2); - if (!result) { - result = validateAssumptionSMTSolver(lattice, row1, row2); + auto state1succ1 = row1.begin(); + auto state1succ2 = (++row1.begin()); + auto state2succ1 = row2.begin(); + auto state2succ2 = (++row2.begin()); + + if (state1succ1->getColumn() == state2succ2->getColumn() + && state2succ1->getColumn() == state1succ2->getColumn()) { + std::swap(state1succ1, state1succ2); } + + if (state1succ1->getColumn() == state2succ1->getColumn() && state1succ2->getColumn() == state2succ2->getColumn()) { + result = validateAssumptionFunction(lattice, state1succ1, state1succ2, state2succ1, state2succ2); + if (!result) { + result = validateAssumptionSMTSolver(lattice, state1succ1, state1succ2, state2succ1, state2succ2); + } + } + } else { + STORM_LOG_DEBUG("Validation only implemented for two successor states"); } } @@ -152,97 +166,76 @@ namespace storm { } template - bool AssumptionChecker::validateAssumptionFunction(storm::analysis::Lattice* lattice, typename storm::storage::SparseMatrix::rows row1, typename storm::storage::SparseMatrix::rows row2) { - bool result = false; - auto state1succ1 = row1.begin(); - auto state1succ2 = (++row1.begin()); - auto state2succ1 = row2.begin(); - auto state2succ2 = (++row2.begin()); - - if (state1succ1->getColumn() == state2succ2->getColumn() - && state2succ1->getColumn() == state1succ2->getColumn()) { - // swap them - auto temp = state1succ2; - state1succ2 = state1succ1; - state1succ1 = temp; + bool AssumptionChecker::validateAssumptionFunction(storm::analysis::Lattice* lattice, + typename storm::storage::SparseMatrix::iterator state1succ1, + typename storm::storage::SparseMatrix::iterator state1succ2, + typename storm::storage::SparseMatrix::iterator state2succ1, + typename storm::storage::SparseMatrix::iterator state2succ2) { + + ValueType prob; + auto comp = lattice->compare(state1succ1->getColumn(), state1succ2->getColumn()); + if (comp == storm::analysis::Lattice::ABOVE) { + prob = state1succ1->getValue() - state2succ1->getValue(); + } else if (comp == storm::analysis::Lattice::BELOW) { + prob = state1succ2->getValue() - state2succ2->getValue(); } - - if (state1succ1->getColumn() == state2succ1->getColumn() && state1succ2->getColumn() == state2succ2->getColumn()) { - ValueType prob; - auto comp = lattice->compare(state1succ1->getColumn(), state1succ2->getColumn()); - if (comp == storm::analysis::Lattice::ABOVE) { - prob = state1succ1->getValue() - state2succ1->getValue(); - } else if (comp == storm::analysis::Lattice::BELOW) { - prob = state1succ2->getValue() - state2succ2->getValue(); + auto vars = prob.gatherVariables(); + + // TODO: Type + std::map substitutions; + for (auto var:vars) { + auto derivative = prob.derivative(var); + assert(derivative.isConstant()); + if (derivative.constantPart() >= 0) { + substitutions[var] = 0; + } else if (derivative.constantPart() <= 0) { + substitutions[var] = 1; } - auto vars = prob.gatherVariables(); - // TODO: Type - std::map substitutions; - for (auto var:vars) { - auto derivative = prob.derivative(var); - assert(derivative.isConstant()); - if (derivative.constantPart() >= 0) { - substitutions[var] = 0; - } else if (derivative.constantPart() <= 0) { - substitutions[var] = 1; - } - } - result = prob.evaluate(substitutions) >= 0; } - return result; + return prob.evaluate(substitutions) >= 0; } template - bool AssumptionChecker::validateAssumptionSMTSolver(storm::analysis::Lattice* lattice, typename storm::storage::SparseMatrix::rows row1, typename storm::storage::SparseMatrix::rows row2) { - bool result = false; - auto state1succ1 = row1.begin(); - auto state1succ2 = (++row1.begin()); - auto state2succ1 = row2.begin(); - auto state2succ2 = (++row2.begin()); - - if (state1succ1->getColumn() == state2succ2->getColumn() - && state2succ1->getColumn() == state1succ2->getColumn()) { - std::swap(state1succ1, state1succ2); - } - - if (state1succ1->getColumn() == state2succ1->getColumn() && state1succ2->getColumn() == state2succ2->getColumn()) { - std::shared_ptr smtSolverFactory = std::make_shared(); - std::shared_ptr manager( - new storm::expressions::ExpressionManager()); - - storm::solver::Z3SmtSolver s(*manager); - storm::solver::SmtSolver::CheckResult smtResult = storm::solver::SmtSolver::CheckResult::Unknown; - storm::expressions::Variable succ1 = manager->declareRationalVariable(std::to_string(state1succ1->getColumn())); - storm::expressions::Variable succ2 = manager->declareRationalVariable(std::to_string(state1succ2->getColumn())); - auto comp = lattice->compare(state1succ1->getColumn(), state1succ2->getColumn()); - if (comp == storm::analysis::Lattice::ABOVE || comp == storm::analysis::Lattice::BELOW) { - if (comp == storm::analysis::Lattice::BELOW) { - std::swap(succ1, succ2); - } - storm::expressions::Expression exprGiven = succ1 >= succ2; - - auto valueTypeToExpression = storm::expressions::ValueTypeToExpression(manager); - storm::expressions::Expression exprToCheck = - (valueTypeToExpression.toExpression(state1succ1->getValue())*succ1 - + valueTypeToExpression.toExpression(state2succ1->getValue())*succ2 - >= valueTypeToExpression.toExpression(state1succ2->getValue())*succ1 - + valueTypeToExpression.toExpression(state1succ1->getValue())*succ2); - - storm::expressions::Expression exprBounds = manager->boolean(true); - auto variables = manager->getVariables(); - for (auto var : variables) { - exprBounds = exprBounds && var >= 0 && var <= 1; - } - - s.add(exprGiven); - s.add(exprToCheck); - s.add(exprBounds); - smtResult = s.check(); + bool AssumptionChecker::validateAssumptionSMTSolver(storm::analysis::Lattice* lattice, + typename storm::storage::SparseMatrix::iterator state1succ1, + typename storm::storage::SparseMatrix::iterator state1succ2, + typename storm::storage::SparseMatrix::iterator state2succ1, + typename storm::storage::SparseMatrix::iterator state2succ2) { + + std::shared_ptr smtSolverFactory = std::make_shared(); + std::shared_ptr manager( + new storm::expressions::ExpressionManager()); + + storm::solver::Z3SmtSolver s(*manager); + storm::solver::SmtSolver::CheckResult smtResult = storm::solver::SmtSolver::CheckResult::Unknown; + storm::expressions::Variable succ1 = manager->declareRationalVariable(std::to_string(state1succ1->getColumn())); + storm::expressions::Variable succ2 = manager->declareRationalVariable(std::to_string(state1succ2->getColumn())); + auto comp = lattice->compare(state1succ1->getColumn(), state1succ2->getColumn()); + if (comp == storm::analysis::Lattice::ABOVE || comp == storm::analysis::Lattice::BELOW) { + if (comp == storm::analysis::Lattice::BELOW) { + std::swap(succ1, succ2); + } + storm::expressions::Expression exprGiven = succ1 >= succ2; + + auto valueTypeToExpression = storm::expressions::ValueTypeToExpression(manager); + storm::expressions::Expression exprToCheck = + (valueTypeToExpression.toExpression(state1succ1->getValue())*succ1 + + valueTypeToExpression.toExpression(state2succ1->getValue())*succ2 + >= valueTypeToExpression.toExpression(state1succ2->getValue())*succ1 + + valueTypeToExpression.toExpression(state1succ1->getValue())*succ2); + + storm::expressions::Expression exprBounds = manager->boolean(true); + auto variables = manager->getVariables(); + for (auto var : variables) { + exprBounds = exprBounds && var >= 0 && var <= 1; } - result = smtResult == storm::solver::SmtSolver::CheckResult::Sat; + s.add(exprGiven); + s.add(exprToCheck); + s.add(exprBounds); + smtResult = s.check(); } - return result; + return smtResult == storm::solver::SmtSolver::CheckResult::Sat; } template diff --git a/src/storm-pars/analysis/AssumptionChecker.h b/src/storm-pars/analysis/AssumptionChecker.h index 098e6433d..cdb6faf2d 100644 --- a/src/storm-pars/analysis/AssumptionChecker.h +++ b/src/storm-pars/analysis/AssumptionChecker.h @@ -69,11 +69,18 @@ namespace storm { std::set> validatedAssumptions; - bool validateAssumptionFunction(storm::analysis::Lattice* lattice, typename storm::storage::SparseMatrix::rows row1, typename storm::storage::SparseMatrix::rows row2); + bool validateAssumptionFunction(storm::analysis::Lattice* lattice, + typename storm::storage::SparseMatrix::iterator state1succ1, + typename storm::storage::SparseMatrix::iterator state1succ2, + typename storm::storage::SparseMatrix::iterator state2succ1, + typename storm::storage::SparseMatrix::iterator state2succ2); - bool validateAssumptionSMTSolver(storm::analysis::Lattice* lattice, typename storm::storage::SparseMatrix::rows row1, typename storm::storage::SparseMatrix::rows row2); + bool validateAssumptionSMTSolver(storm::analysis::Lattice* lattice, + typename storm::storage::SparseMatrix::iterator state1succ1, + typename storm::storage::SparseMatrix::iterator state1succ2, + typename storm::storage::SparseMatrix::iterator state2succ1, + typename storm::storage::SparseMatrix::iterator state2succ2); }; } } - #endif //STORM_ASSUMPTIONCHECKER_H