Browse Source

changed shield export function

refactoring
Thomas Knoll 1 year ago
parent
commit
6adfa0cde1
  1. 2
      examples/shields/04_pre_shield_export.py
  2. 2
      examples/shields/05_post_shield_export.py
  3. 2
      examples/shields/06_optimal_shield_export.py
  4. 19
      examples/shields/09_pre_shield_decision_tree.py
  5. 4
      examples/shields/10_post_shield_decision_tree.py
  6. 8
      src/shields/abstract_shield.cpp
  7. 4
      src/shields/shield_handling.cpp

2
examples/shields/04_pre_shield_export.py

@ -36,7 +36,7 @@ def pre_schield():
shield = result.shield
stormpy.shields.export_shieldDouble(model, shield, "pre.shield")
stormpy.shields.export_shield(model, shield, "pre.shield")

2
examples/shields/05_post_shield_export.py

@ -36,7 +36,7 @@ def pre_schield():
shield = result.shield
stormpy.shields.export_shieldDouble(model, shield, "post.shield")
stormpy.shields.export_shield(model, shield, "post.shield")

2
examples/shields/06_optimal_shield_export.py

@ -33,7 +33,7 @@ def optimal_shield_export():
shield = result.shield
stormpy.shields.export_shieldDouble(model, shield, "optimal.shield")
stormpy.shields.export_shield(model, shield, "optimal.shield")
if __name__ == '__main__':

19
examples/shields/09_pre_shield_decision_tree.py

@ -8,6 +8,16 @@ import stormpy.shields
import stormpy.examples
import stormpy.examples.files
from sklearn.linear_model import LogisticRegression
from dtcontrol.benchmark_suite import BenchmarkSuite
from dtcontrol.decision_tree.decision_tree import DecisionTree
from dtcontrol.decision_tree.determinization.max_freq_determinizer import MaxFreqDeterminizer
from dtcontrol.decision_tree.impurity.entropy import Entropy
from dtcontrol.decision_tree.impurity.multi_label_entropy import MultiLabelEntropy
from dtcontrol.decision_tree.splitting.axis_aligned import AxisAlignedSplittingStrategy
from dtcontrol.decision_tree.splitting.linear_classifier import LinearClassifierSplittingStrategy
from stormpy.decision_tree import create_decision_tree
def export_shield_as_dot():
@ -30,7 +40,14 @@ def export_shield_as_dot():
shield = result.shield
filename = "preshield.storm.json"
stormpy.shields.export_shieldDouble(model, shield, filename)
stormpy.shields.export_shield(model, shield, filename)
if classifiers is None:
aa = AxisAlignedSplittingStrategy()
aa.priority = 1
classifiers = [DecisionTree([aa], Entropy(), name)]
output_folder = "pre_trees"
name = 'pre_my_output'

4
examples/shields/10_post_shield_decision_tree.py

@ -31,8 +31,8 @@ def export_shield_as_dot():
shield = result.shield
filename = "postshield.storm.json"
filename2 = "postshield.shield"
stormpy.shields.export_shieldDouble(model, shield, filename)
stormpy.shields.export_shieldDouble(model, shield, filename2)
stormpy.shields.export_shield(model, shield, filename)
stormpy.shields.export_shield(model, shield, filename2)
output_folder = "post_trees"
name = 'post_my_output'

8
src/shields/abstract_shield.cpp

@ -6,16 +6,20 @@
#include "storm/storage/BitVector.h"
#include "storm/storage/Distribution.h"
#include "storm/api/export.h"
template <typename ValueType, typename IndexType>
void define_abstract_shield(py::module& m, std::string vt_suffix) {
using AbstractShield = tempest::shields::AbstractShield<ValueType, IndexType>;
std::string shieldClassName = std::string("AbstractShield") + vt_suffix;
py::class_<AbstractShield, std::shared_ptr<AbstractShield>>(m, shieldClassName.c_str())
py::class_<AbstractShield, std::shared_ptr<AbstractShield>> shield(m, shieldClassName.c_str());
shield
.def("compute_row_group_size", &AbstractShield::computeRowGroupSizes)
.def("get_class_name", &AbstractShield::getClassName)
.def("get_optimization_direction", &AbstractShield::getOptimizationDirection)
;
;
}
template void define_abstract_shield<double, typename storm::storage::SparseMatrix<double>::index_type>(py::module& m, std::string vt_suffix);

4
src/shields/shield_handling.cpp

@ -5,10 +5,10 @@
template <typename ValueType, typename IndexType>
void define_shield_handling(py::module& m, std::string vt_suffix) {
std::string shieldHandlingname = std::string("export_shield") + vt_suffix;
std::string shieldHandlingname = std::string("export_shield");
m.def(shieldHandlingname.c_str(), &storm::api::exportShield<ValueType, IndexType>, py::arg("model"), py::arg("shield"), py::arg("filename"));
}
}
template void define_shield_handling<double, typename storm::storage::SparseMatrix<double>::index_type>(py::module& m, std::string vt_suffix);
template void define_shield_handling<storm::RationalNumber, typename storm::storage::SparseMatrix<storm::RationalNumber>::index_type>(py::module& m, std::string vt_suffix);
Loading…
Cancel
Save