diff --git a/examples/shields/04_pre_shield_export.py b/examples/shields/04_pre_shield_export.py index 08bef6b..83b9fa7 100644 --- a/examples/shields/04_pre_shield_export.py +++ b/examples/shields/04_pre_shield_export.py @@ -36,7 +36,7 @@ def pre_schield(): shield = result.shield - stormpy.shields.export_shieldDouble(model, shield, "pre.shield") + stormpy.shields.export_shield(model, shield, "pre.shield") diff --git a/examples/shields/05_post_shield_export.py b/examples/shields/05_post_shield_export.py index 47567e0..d678dff 100644 --- a/examples/shields/05_post_shield_export.py +++ b/examples/shields/05_post_shield_export.py @@ -36,7 +36,7 @@ def pre_schield(): shield = result.shield - stormpy.shields.export_shieldDouble(model, shield, "post.shield") + stormpy.shields.export_shield(model, shield, "post.shield") diff --git a/examples/shields/06_optimal_shield_export.py b/examples/shields/06_optimal_shield_export.py index 7e789f7..ddf1453 100644 --- a/examples/shields/06_optimal_shield_export.py +++ b/examples/shields/06_optimal_shield_export.py @@ -33,7 +33,7 @@ def optimal_shield_export(): shield = result.shield - stormpy.shields.export_shieldDouble(model, shield, "optimal.shield") + stormpy.shields.export_shield(model, shield, "optimal.shield") if __name__ == '__main__': diff --git a/examples/shields/09_pre_shield_decision_tree.py b/examples/shields/09_pre_shield_decision_tree.py index 3e8a9a3..7c67809 100644 --- a/examples/shields/09_pre_shield_decision_tree.py +++ b/examples/shields/09_pre_shield_decision_tree.py @@ -8,6 +8,16 @@ import stormpy.shields import stormpy.examples import stormpy.examples.files +from sklearn.linear_model import LogisticRegression +from dtcontrol.benchmark_suite import BenchmarkSuite +from dtcontrol.decision_tree.decision_tree import DecisionTree +from dtcontrol.decision_tree.determinization.max_freq_determinizer import MaxFreqDeterminizer +from dtcontrol.decision_tree.impurity.entropy import Entropy +from dtcontrol.decision_tree.impurity.multi_label_entropy import MultiLabelEntropy +from dtcontrol.decision_tree.splitting.axis_aligned import AxisAlignedSplittingStrategy +from dtcontrol.decision_tree.splitting.linear_classifier import LinearClassifierSplittingStrategy + + from stormpy.decision_tree import create_decision_tree def export_shield_as_dot(): @@ -30,7 +40,14 @@ def export_shield_as_dot(): shield = result.shield filename = "preshield.storm.json" - stormpy.shields.export_shieldDouble(model, shield, filename) + stormpy.shields.export_shield(model, shield, filename) + + if classifiers is None: + aa = AxisAlignedSplittingStrategy() + aa.priority = 1 + + classifiers = [DecisionTree([aa], Entropy(), name)] + output_folder = "pre_trees" name = 'pre_my_output' diff --git a/examples/shields/10_post_shield_decision_tree.py b/examples/shields/10_post_shield_decision_tree.py index 0bf3100..ebaf5be 100644 --- a/examples/shields/10_post_shield_decision_tree.py +++ b/examples/shields/10_post_shield_decision_tree.py @@ -31,8 +31,8 @@ def export_shield_as_dot(): shield = result.shield filename = "postshield.storm.json" filename2 = "postshield.shield" - stormpy.shields.export_shieldDouble(model, shield, filename) - stormpy.shields.export_shieldDouble(model, shield, filename2) + stormpy.shields.export_shield(model, shield, filename) + stormpy.shields.export_shield(model, shield, filename2) output_folder = "post_trees" name = 'post_my_output' diff --git a/src/shields/abstract_shield.cpp b/src/shields/abstract_shield.cpp index 4f70e72..ab8aace 100644 --- a/src/shields/abstract_shield.cpp +++ b/src/shields/abstract_shield.cpp @@ -6,16 +6,20 @@ #include "storm/storage/BitVector.h" #include "storm/storage/Distribution.h" +#include "storm/api/export.h" + + template void define_abstract_shield(py::module& m, std::string vt_suffix) { using AbstractShield = tempest::shields::AbstractShield; std::string shieldClassName = std::string("AbstractShield") + vt_suffix; - py::class_>(m, shieldClassName.c_str()) + py::class_> shield(m, shieldClassName.c_str()); + shield .def("compute_row_group_size", &AbstractShield::computeRowGroupSizes) .def("get_class_name", &AbstractShield::getClassName) .def("get_optimization_direction", &AbstractShield::getOptimizationDirection) - ; + ; } template void define_abstract_shield::index_type>(py::module& m, std::string vt_suffix); diff --git a/src/shields/shield_handling.cpp b/src/shields/shield_handling.cpp index 205806f..eab19a8 100644 --- a/src/shields/shield_handling.cpp +++ b/src/shields/shield_handling.cpp @@ -5,10 +5,10 @@ template void define_shield_handling(py::module& m, std::string vt_suffix) { - std::string shieldHandlingname = std::string("export_shield") + vt_suffix; + std::string shieldHandlingname = std::string("export_shield"); m.def(shieldHandlingname.c_str(), &storm::api::exportShield, py::arg("model"), py::arg("shield"), py::arg("filename")); -} + } template void define_shield_handling::index_type>(py::module& m, std::string vt_suffix); template void define_shield_handling::index_type>(py::module& m, std::string vt_suffix);