diff --git a/src/storm-pars/analysis/AssumptionChecker.cpp b/src/storm-pars/analysis/AssumptionChecker.cpp index 359dc1a86..0e18d0a7e 100644 --- a/src/storm-pars/analysis/AssumptionChecker.cpp +++ b/src/storm-pars/analysis/AssumptionChecker.cpp @@ -133,48 +133,54 @@ namespace storm { // Only implemented for two successors if (row1.getNumberOfEntries() == 2 && row2.getNumberOfEntries() == 2) { - auto succ1State1 = row1.begin(); - auto succ2State1 = (++row1.begin()); - auto succ1State2 = row2.begin(); - auto succ2State2 = (++row2.begin()); - - if (succ1State1->getColumn() == succ2State2->getColumn() - && succ1State2->getColumn() == succ2State1->getColumn()) { - // swap them - auto temp = succ2State1; - succ2State1 = succ1State1; - succ1State1 = temp; - } + result = validateAssumptionOnFunction(lattice, row1, row2); + } + } + if (result) { + validatedAssumptions.insert(assumption); + } else { + STORM_LOG_DEBUG("Could not validate: " << *assumption << std::endl); + } + return result; + } - if (succ1State1->getColumn() == succ1State2->getColumn() && succ2State1->getColumn() == succ2State2->getColumn()) { - ValueType prob; - auto comp = lattice->compare(succ1State1->getColumn(), succ2State1->getColumn()); - if (comp == storm::analysis::Lattice::ABOVE) { - prob = succ1State1->getValue() - succ1State2->getValue(); - } else if (comp == storm::analysis::Lattice::BELOW) { - prob = succ2State1->getValue() - succ2State2->getValue(); - } - auto vars = prob.gatherVariables(); - // TODO: Type - std::map substitutions; - for (auto var:vars) { - auto derivative = prob.derivative(var); - assert(derivative.isConstant()); - if (derivative.constantPart() >= 0) { - substitutions[var] = 0; - } else if (derivative.constantPart() <= 0) { - substitutions[var] = 1; - } - } - result = prob.evaluate(substitutions) >= 0; - } + template + bool AssumptionChecker::validateAssumptionOnFunction(storm::analysis::Lattice* lattice, typename storm::storage::SparseMatrix::rows row1, typename storm::storage::SparseMatrix::rows row2) { + bool result = false; + auto succ1State1 = row1.begin(); + auto succ2State1 = (++row1.begin()); + auto succ1State2 = row2.begin(); + auto succ2State2 = (++row2.begin()); + + if (succ1State1->getColumn() == succ2State2->getColumn() + && succ1State2->getColumn() == succ2State1->getColumn()) { + // swap them + auto temp = succ2State1; + succ2State1 = succ1State1; + succ1State1 = temp; + } - if (result) { - validatedAssumptions.insert(assumption); - } else { - STORM_LOG_DEBUG("Could not validate: " << *assumption << std::endl); + if (succ1State1->getColumn() == succ1State2->getColumn() && succ2State1->getColumn() == succ2State2->getColumn()) { + ValueType prob; + auto comp = lattice->compare(succ1State1->getColumn(), succ2State1->getColumn()); + if (comp == storm::analysis::Lattice::ABOVE) { + prob = succ1State1->getValue() - succ1State2->getValue(); + } else if (comp == storm::analysis::Lattice::BELOW) { + prob = succ2State1->getValue() - succ2State2->getValue(); + } + auto vars = prob.gatherVariables(); + // TODO: Type + std::map substitutions; + for (auto var:vars) { + auto derivative = prob.derivative(var); + assert(derivative.isConstant()); + if (derivative.constantPart() >= 0) { + substitutions[var] = 0; + } else if (derivative.constantPart() <= 0) { + substitutions[var] = 1; } } + result = prob.evaluate(substitutions) >= 0; } return result; } diff --git a/src/storm-pars/analysis/AssumptionChecker.h b/src/storm-pars/analysis/AssumptionChecker.h index deb75cd87..bec43dd11 100644 --- a/src/storm-pars/analysis/AssumptionChecker.h +++ b/src/storm-pars/analysis/AssumptionChecker.h @@ -69,6 +69,7 @@ namespace storm { std::set> validatedAssumptions; + bool validateAssumptionOnFunction(storm::analysis::Lattice* lattice, typename storm::storage::SparseMatrix::rows row1, typename storm::storage::SparseMatrix::rows row2); }; } }