You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

63 lines
3.2 KiB

#include "shortestPaths.h"
#include "storm/utility/shortestPaths.h"
#include "src/helpers.h"
// only forward declaring Model leads to pybind compilation error
// this may be avoidable. but including certainly works.
#include "storm/models/sparse/Model.h"
#include "storm/models/sparse/StandardRewardModel.h"
void define_ksp(py::module& m) {
// long types shortened for readability
//
// this could be templated rather than hardcoding double, but the actual
// bindings must refer to instantiated versions anyway (i.e., overloaded
// for each template instantiation) -- and double is enough for me
using BitVector = storm::storage::BitVector;
using MatrixFormat = storm::utility::ksp::MatrixFormat;
using Path = storm::utility::ksp::Path<double>;
using ShortestPathsGenerator = storm::utility::ksp::ShortestPathsGenerator<double>;
using state_t = storm::utility::ksp::state_t;
using Matrix = ShortestPathsGenerator::Matrix;
using Model = ShortestPathsGenerator::Model;
using StateProbMap = ShortestPathsGenerator::StateProbMap;
py::class_<Path>(m, "Path")
// overload constructor rather than dealing with boost::optional
.def("__init__", [](Path &instance, state_t preNode, unsigned long preK, double distance) {
new (&instance) Path { boost::optional<state_t>(preNode), preK, distance };
}, "predecessorNode"_a, "predecessorK"_a, "distance"_a)
.def("__init__", [](Path &instance, unsigned long preK, double distance) {
new (&instance) Path { boost::none, preK, distance };
}, "predecessorK"_a, "distance"_a)
.def(py::self == py::self, "Compares predecessor node and index, ignoring distance")
//.def("__str__", &streamToString<Path>)
.def_readwrite("predecessorNode", &Path::predecessorNode) // TODO (un-)wrap boost::optional so it's usable
.def_readwrite("predecessorK", &Path::predecessorK)
.def_readwrite("distance", &Path::distance)
;
py::enum_<MatrixFormat>(m, "MatrixFormat")
.value("Straight", MatrixFormat::straight)
.value("I_Minus_P", MatrixFormat::iMinusP)
;
py::class_<ShortestPathsGenerator>(m, "ShortestPathsGenerator")
.def(py::init<Model const&, BitVector>(), "model"_a, "target_bitvector"_a)
.def(py::init<Model const&, state_t>(), "model"_a, "target_state"_a)
.def(py::init<Model const&, std::vector<state_t> const&>(), "model"_a, "target_state_list"_a)
.def(py::init<Model const&, std::string>(), "model"_a, "target_label"_a)
.def(py::init<Matrix const&, std::vector<double> const&, BitVector const&, MatrixFormat>(), "transition_matrix"_a, "target_prob_vector"_a, "initial_states"_a, "matrix_format"_a)
.def(py::init<Matrix const&, StateProbMap const&, BitVector const&, MatrixFormat>(), "transition_matrix"_a, "target_prob_map"_a, "initial_states"_a, "matrix_format"_a)
.def("get_distance", &ShortestPathsGenerator::getDistance, "k"_a)
.def("get_states", &ShortestPathsGenerator::getStates, "k"_a)
.def("get_path_as_list", &ShortestPathsGenerator::getPathAsList, "k"_a)
;
}