You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

63 lines
3.2 KiB

8 years ago
  1. #include "shortestPaths.h"
  2. #include "storm/utility/shortestPaths.h"
  3. #include "src/helpers.h"
  4. // only forward declaring Model leads to pybind compilation error
  5. // this may be avoidable. but including certainly works.
  6. #include "storm/models/sparse/Model.h"
  7. #include "storm/models/sparse/StandardRewardModel.h"
  8. void define_ksp(py::module& m) {
  9. // long types shortened for readability
  10. //
  11. // this could be templated rather than hardcoding double, but the actual
  12. // bindings must refer to instantiated versions anyway (i.e., overloaded
  13. // for each template instantiation) -- and double is enough for me
  14. using BitVector = storm::storage::BitVector;
  15. using MatrixFormat = storm::utility::ksp::MatrixFormat;
  16. using Path = storm::utility::ksp::Path<double>;
  17. using ShortestPathsGenerator = storm::utility::ksp::ShortestPathsGenerator<double>;
  18. using state_t = storm::utility::ksp::state_t;
  19. using Matrix = ShortestPathsGenerator::Matrix;
  20. using Model = ShortestPathsGenerator::Model;
  21. using StateProbMap = ShortestPathsGenerator::StateProbMap;
  22. py::class_<Path>(m, "Path")
  23. // overload constructor rather than dealing with boost::optional
  24. .def("__init__", [](Path &instance, state_t preNode, unsigned long preK, double distance) {
  25. new (&instance) Path { boost::optional<state_t>(preNode), preK, distance };
  26. }, "predecessorNode"_a, "predecessorK"_a, "distance"_a)
  27. .def("__init__", [](Path &instance, unsigned long preK, double distance) {
  28. new (&instance) Path { boost::none, preK, distance };
  29. }, "predecessorK"_a, "distance"_a)
  30. .def(py::self == py::self, "Compares predecessor node and index, ignoring distance")
  31. //.def("__str__", &streamToString<Path>)
  32. .def_readwrite("predecessorNode", &Path::predecessorNode) // TODO (un-)wrap boost::optional so it's usable
  33. .def_readwrite("predecessorK", &Path::predecessorK)
  34. .def_readwrite("distance", &Path::distance)
  35. ;
  36. py::enum_<MatrixFormat>(m, "MatrixFormat")
  37. .value("Straight", MatrixFormat::straight)
  38. .value("I_Minus_P", MatrixFormat::iMinusP)
  39. ;
  40. py::class_<ShortestPathsGenerator>(m, "ShortestPathsGenerator")
  41. .def(py::init<Model const&, BitVector>(), "model"_a, "target_bitvector"_a)
  42. .def(py::init<Model const&, state_t>(), "model"_a, "target_state"_a)
  43. .def(py::init<Model const&, std::vector<state_t> const&>(), "model"_a, "target_state_list"_a)
  44. .def(py::init<Model const&, std::string>(), "model"_a, "target_label"_a)
  45. .def(py::init<Matrix const&, std::vector<double> const&, BitVector const&, MatrixFormat>(), "transition_matrix"_a, "target_prob_vector"_a, "initial_states"_a, "matrix_format"_a)
  46. .def(py::init<Matrix const&, StateProbMap const&, BitVector const&, MatrixFormat>(), "transition_matrix"_a, "target_prob_map"_a, "initial_states"_a, "matrix_format"_a)
  47. .def("get_distance", &ShortestPathsGenerator::getDistance, "k"_a)
  48. .def("get_states", &ShortestPathsGenerator::getStates, "k"_a)
  49. .def("get_path_as_list", &ShortestPathsGenerator::getPathAsList, "k"_a)
  50. ;
  51. }