From 8de8570d11640783e845782d5351d2daa832f574 Mon Sep 17 00:00:00 2001 From: Sebastian Junges Date: Thu, 19 Apr 2018 15:03:34 +0200 Subject: [PATCH] - more expression handling - smt wrap --- src/mod_utility.cpp | 3 ++ src/storage/expressions.cpp | 25 +++++++++++++-- src/utility/smtsolver.cpp | 32 +++++++++++++++++++ src/utility/smtsolver.h | 5 +++ tests/utility/test_smtsolver.py | 55 +++++++++++++++++++++++++++++++++ 5 files changed, 118 insertions(+), 2 deletions(-) create mode 100644 src/utility/smtsolver.cpp create mode 100644 src/utility/smtsolver.h create mode 100644 tests/utility/test_smtsolver.py diff --git a/src/mod_utility.cpp b/src/mod_utility.cpp index 1b6870c..8d1a504 100644 --- a/src/mod_utility.cpp +++ b/src/mod_utility.cpp @@ -1,9 +1,12 @@ #include "common.h" #include "utility/shortestPaths.h" +#include "utility/smtsolver.h" + PYBIND11_MODULE(utility, m) { m.doc() = "Utilities for Storm"; define_ksp(m); + define_smt(m); } diff --git a/src/storage/expressions.cpp b/src/storage/expressions.cpp index 3e564ab..00799ec 100644 --- a/src/storage/expressions.cpp +++ b/src/storage/expressions.cpp @@ -7,6 +7,9 @@ //Define python bindings void define_expressions(py::module& m) { + using Expression = storm::expressions::Expression; + + // ExpressionManager py::class_>(m, "ExpressionManager", "Manages variables for expressions") @@ -16,7 +19,11 @@ void define_expressions(py::module& m) { .def("create_rational", [](storm::expressions::ExpressionManager const& manager, storm::RationalNumber number) { return manager.rational(number); }, py::arg("rational"), "Create expression from rational number") - ; + .def("create_boolean_variable", &storm::expressions::ExpressionManager::declareBooleanVariable, "create Boolean variable", py::arg("name"), py::arg("auxiliary") = false) + .def("create_integer_variable", &storm::expressions::ExpressionManager::declareIntegerVariable, "create Integer variable", py::arg("name"), py::arg("auxiliary") = false) + .def("create_rational_variable", &storm::expressions::ExpressionManager::declareRationalVariable, "create Rational variable", py::arg("name"), py::arg("auxiliary") = false) + + ; // Variable py::class_>(m, "Variable", "Represents a variable") @@ -39,7 +46,21 @@ void define_expressions(py::module& m) { .def("has_rational_type", &storm::expressions::Expression::hasRationalType, "Check if the expression is a rational") .def_property_readonly("type", &storm::expressions::Expression::getType, "Get the Type") .def("__str__", &storm::expressions::Expression::toString, "To string") - ; + + .def_static("plus", [](Expression const& lhs, Expression const& rhs) {return lhs + rhs;}) + .def_static("minus", [](Expression const& lhs, Expression const& rhs) {return lhs - rhs;}) + .def_static("multiply", [](Expression const& lhs, Expression const& rhs) {return lhs * rhs;}) + .def_static("and", [](Expression const& lhs, Expression const& rhs) {return lhs && rhs;}) + .def_static("or", [](Expression const& lhs, Expression const& rhs) {return lhs || rhs;}) + .def_static("geq", [](Expression const& lhs, Expression const& rhs) {return lhs >= rhs;}) + .def_static("eq", [](Expression const& lhs, Expression const& rhs) {return lhs == rhs;}) + .def_static("neq", [](Expression const& lhs, Expression const& rhs) {return lhs != rhs;}) + .def_static("greater", [](Expression const& lhs, Expression const& rhs) {return lhs > rhs;}) + .def_static("less", [](Expression const& lhs, Expression const& rhs) {return lhs < rhs;}) + .def_static("leq", [](Expression const& lhs, Expression const& rhs) {return lhs <= rhs;}) + .def_static("implies", [](Expression const& lhs, Expression const& rhs) {return storm::expressions::implies(lhs, rhs);}) + .def_static("iff", [](Expression const& lhs, Expression const& rhs) {return storm::expressions::iff(lhs, rhs);}) + ; py::class_(m, "ExpressionParser", "Parser for storm-expressions") .def(py::init(), "Expression Manager to use", py::arg("expression_manager")) diff --git a/src/utility/smtsolver.cpp b/src/utility/smtsolver.cpp new file mode 100644 index 0000000..4e8e258 --- /dev/null +++ b/src/utility/smtsolver.cpp @@ -0,0 +1,32 @@ +#include "smtsolver.h" +#include +#include "storm/storage/expressions/ExpressionManager.h" + +void define_smt(py::module& m) { + using SmtSolver = storm::solver::SmtSolver; + using Z3SmtSolver = storm::solver::Z3SmtSolver; + using ModelReference = storm::solver::SmtSolver::ModelReference; + + py::enum_(m, "SmtCheckResult", "Result type") + .value("Sat", SmtSolver::CheckResult::Sat) + .value("Unsat", SmtSolver::CheckResult::Unsat) + .value("Unknown", SmtSolver::CheckResult::Unknown) + ; + + py::class_> modelref(m, "ModelReference", "Lightweight Wrapper around results"); + modelref.def("get_boolean_value", &ModelReference::getBooleanValue, "get a value for a boolean variable", py::arg("variable")) + .def("get_integer_value", &ModelReference::getIntegerValue, "get a value for an integer variable", py::arg("variable")) + .def("get_rational_value", &ModelReference::getRationalValue, "get a value (as double) for an rational variable", py::arg("variable")); + + + py::class_ smtsolver(m, "SmtSolver", "Generic Storm SmtSolver Wrapper"); + smtsolver.def("push", &SmtSolver::push, "push") + .def("pop", [](SmtSolver& solver, uint64_t n) {solver.pop(n);}, "pop", py::arg("levels")) + .def("reset", &SmtSolver::reset, "reset") + .def("add", [](SmtSolver& solver, storm::expressions::Expression const& expr) {solver.add(expr);}, "addconstraint") + .def("check", &SmtSolver::check, "check") + .def_property_readonly("model", &SmtSolver::getModel, "get the model"); + + py::class_ z3solver(m, "Z3SmtSolver", "z3 API for storm smtsolver wrapper", smtsolver); + z3solver.def(pybind11::init()); +} \ No newline at end of file diff --git a/src/utility/smtsolver.h b/src/utility/smtsolver.h new file mode 100644 index 0000000..d23a479 --- /dev/null +++ b/src/utility/smtsolver.h @@ -0,0 +1,5 @@ +#pragma once + +#include "src/common.h" + +void define_smt(py::module& m); \ No newline at end of file diff --git a/tests/utility/test_smtsolver.py b/tests/utility/test_smtsolver.py new file mode 100644 index 0000000..7655f3a --- /dev/null +++ b/tests/utility/test_smtsolver.py @@ -0,0 +1,55 @@ +import stormpy +import stormpy.utility + +import pytest + +class TestSmtSolver(): + def test_smtsolver_init(self): + manager = stormpy.ExpressionManager() + solver = stormpy.utility.Z3SmtSolver(manager) + + def test_smtsolver_trivial(self): + manager = stormpy.ExpressionManager() + solver = stormpy.utility.Z3SmtSolver(manager) + solver.add(manager.create_boolean(True)) + assert solver.check() != stormpy.utility.SmtCheckResult.Unsat + assert solver.check() == stormpy.utility.SmtCheckResult.Sat + solver.add(manager.create_boolean(False)) + assert solver.check() == stormpy.utility.SmtCheckResult.Unsat + assert solver.check() != stormpy.utility.SmtCheckResult.Sat + + def test_smtsolver_arithmetic_unsat(self): + manager = stormpy.ExpressionManager() + x = manager.create_integer_variable("x") + xe = x.get_expression() + c1 = stormpy.Expression.geq(xe, manager.create_integer(1)) + c2 = stormpy.Expression.less(xe, manager.create_integer(0)) + solver = stormpy.utility.Z3SmtSolver(manager) + solver.add(c1) + solver.add(c2) + assert solver.check() == stormpy.utility.SmtCheckResult.Unsat + + def test_smtsolver_arithmetic_unsat(self): + manager = stormpy.ExpressionManager() + x = manager.create_integer_variable("x") + xe = x.get_expression() + c1 = stormpy.Expression.geq(xe, manager.create_integer(1)) + c2 = stormpy.Expression.less(xe, manager.create_integer(0)) + solver = stormpy.utility.Z3SmtSolver(manager) + solver.add(c1) + solver.add(c2) + assert solver.check() == stormpy.utility.SmtCheckResult.Unsat + + def test_smtsolver_arithmetic_unsat(self): + manager = stormpy.ExpressionManager() + x = manager.create_integer_variable("x") + xe = x.get_expression() + c1 = stormpy.Expression.geq(xe, manager.create_integer(1)) + c2 = stormpy.Expression.less(xe, manager.create_integer(2)) + solver = stormpy.utility.Z3SmtSolver(manager) + solver.add(c1) + solver.add(c2) + assert solver.check() == stormpy.utility.SmtCheckResult.Sat + assert solver.model.get_integer_value(x) == 1 + +