diff --git a/examples/shields/04_pre_shield_export.py b/examples/shields/04_pre_shield_export.py index be5e4dd..08bef6b 100644 --- a/examples/shields/04_pre_shield_export.py +++ b/examples/shields/04_pre_shield_export.py @@ -36,7 +36,7 @@ def pre_schield(): shield = result.shield - stormpy.shields.export_shieldDouble(model, shield) + stormpy.shields.export_shieldDouble(model, shield, "pre.shield") diff --git a/examples/shields/05_post_shield_export.py b/examples/shields/05_post_shield_export.py index b6265b0..47567e0 100644 --- a/examples/shields/05_post_shield_export.py +++ b/examples/shields/05_post_shield_export.py @@ -36,7 +36,7 @@ def pre_schield(): shield = result.shield - stormpy.shields.export_shieldDouble(model, shield) + stormpy.shields.export_shieldDouble(model, shield, "post.shield") diff --git a/examples/shields/06_optimal_shield_export.py b/examples/shields/06_optimal_shield_export.py index 7a3a870..7e789f7 100644 --- a/examples/shields/06_optimal_shield_export.py +++ b/examples/shields/06_optimal_shield_export.py @@ -33,7 +33,7 @@ def optimal_shield_export(): shield = result.shield - stormpy.shields.export_shieldDouble(model, shield) + stormpy.shields.export_shieldDouble(model, shield, "optimal.shield") if __name__ == '__main__': diff --git a/examples/shields/09_dtcontrol.py b/examples/shields/09_dtcontrol.py new file mode 100644 index 0000000..b27a9fe --- /dev/null +++ b/examples/shields/09_dtcontrol.py @@ -0,0 +1,40 @@ +import stormpy +import stormpy.core +import stormpy.simulator + + +import stormpy.shields + +import stormpy.examples +import stormpy.examples.files + +from stormpy.dtcontrol import export_decision_tree + +def export_shield_as_dot(): + path = stormpy.examples.files.prism_mdp_lava_simple + formula_str = " Pmax=? [G !\"AgentIsInLavaAndNotDone\"]" + + program = stormpy.parse_prism_program(path) + formulas = stormpy.parse_properties_for_prism_program(formula_str, program) + + options = stormpy.BuilderOptions([p.raw_formula for p in formulas]) + options.set_build_state_valuations(True) + options.set_build_choice_labels(True) + options.set_build_all_labels() + options.set_build_with_choice_origins(True) + model = stormpy.build_sparse_model_with_options(program, options) + + result = stormpy.model_checking(model, formulas[0], extract_scheduler=True) #, shielding_expression=shield_specification) + + assert result.has_shield + + shield = result.shield + stormpy.shields.export_shieldDouble(model, shield, "preshield.storm.json") + + + export_decision_tree(result.shield) + + + +if __name__ == '__main__': + export_shield_as_dot() \ No newline at end of file diff --git a/lib/stormpy/dtcontrol.py b/lib/stormpy/dtcontrol.py new file mode 100644 index 0000000..1f28158 --- /dev/null +++ b/lib/stormpy/dtcontrol.py @@ -0,0 +1,17 @@ +import json +import logging +from os.path import splitext, exists +from sklearn.linear_model import LogisticRegression +import numpy as np + +from dtcontrol.dataset.dataset_loader import DatasetLoader +from dtcontrol.decision_tree.decision_tree import DecisionTree +from dtcontrol.decision_tree.impurity.entropy import Entropy + + +from dtcontrol.decision_tree.splitting.axis_aligned import AxisAlignedSplittingStrategy +from dtcontrol.decision_tree.splitting.linear_classifier import LinearClassifierSplittingStrategy + + +def export_decision_tree(filename): + pass diff --git a/setup.py b/setup.py index 5262efc..e8fdecf 100755 --- a/setup.py +++ b/setup.py @@ -275,7 +275,7 @@ setup( cmdclass={'build_ext': CMakeBuild}, zip_safe=False, - install_requires=['pycarl>=2.0.4'], + install_requires=['pycarl>=2.0.4', 'dtcontrol'], setup_requires=['pytest-runner'], tests_require=['pytest', 'nbval'], extras_require={ diff --git a/src/shields/shield_handling.cpp b/src/shields/shield_handling.cpp index a894fa0..205806f 100644 --- a/src/shields/shield_handling.cpp +++ b/src/shields/shield_handling.cpp @@ -7,7 +7,7 @@ template void define_shield_handling(py::module& m, std::string vt_suffix) { std::string shieldHandlingname = std::string("export_shield") + vt_suffix; - m.def(shieldHandlingname.c_str(), &storm::api::exportShield, py::arg("model"), py::arg("shield")); + m.def(shieldHandlingname.c_str(), &storm::api::exportShield, py::arg("model"), py::arg("shield"), py::arg("filename")); } template void define_shield_handling::index_type>(py::module& m, std::string vt_suffix);