From 5b782b5ca7c149dfc1e4a6f7fa1c6722921b2d6a Mon Sep 17 00:00:00 2001 From: Tom Janson Date: Sun, 18 Dec 2016 01:23:30 +0100 Subject: [PATCH] add ShortestPathsGenerator matrix/vector ctor bindings --- src/utility/shortestPaths.cpp | 23 ++++++++++++++++++--- tests/utility/test_shortestpaths.py | 31 +++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 3 deletions(-) diff --git a/src/utility/shortestPaths.cpp b/src/utility/shortestPaths.cpp index 4313969..9ff97b4 100644 --- a/src/utility/shortestPaths.cpp +++ b/src/utility/shortestPaths.cpp @@ -9,11 +9,18 @@ void define_ksp(py::module& m) { // long types shortened for readability - using Path = storm::utility::ksp::Path; - using state_t = storm::utility::ksp::state_t; - using ShortestPathsGenerator = storm::utility::ksp::ShortestPathsGenerator; + // + // this could be templated rather than hardcoding double, but the actual + // bindings must refer to instantiated versions anyway (i.e., overloaded + // for each template instantiation) -- and double is enough for me using Model = storm::models::sparse::Model; using BitVector = storm::storage::BitVector; + using Matrix = storm::storage::SparseMatrix; + using MatrixFormat = storm::utility::ksp::MatrixFormat; + using Path = storm::utility::ksp::Path; + using ShortestPathsGenerator = storm::utility::ksp::ShortestPathsGenerator; + using state_t = storm::utility::ksp::state_t; + using StateProbMap = std::unordered_map; py::class_(m, "Path") // overload constructor rather than dealing with boost::optional @@ -39,10 +46,20 @@ void define_ksp(py::module& m) { .def_readwrite("distance", &Path::distance) ; + py::enum_(m, "MatrixFormat") + .value("Straight", MatrixFormat::straight) + .value("I_Minus_P", MatrixFormat::iMinusP) + ; + py::class_(m, "ShortestPathsGenerator") .def(py::init(), "model"_a, "target_bitvector"_a) .def(py::init(), "model"_a, "target_state"_a) .def(py::init const&>(), "model"_a, "target_state_list"_a) .def(py::init(), "model"_a, "target_label"_a) + .def(py::init const&, BitVector const&, MatrixFormat>(), "transition_matrix"_a, "target_prob_vector"_a, "initial_states"_a, "matrix_format"_a) + .def(py::init(), "transition_matrix"_a, "target_prob_map"_a, "initial_states"_a, "matrix_format"_a) + + // ShortestPathsGenerator(Matrix const& transitionMatrix, std::vector const& targetProbVector, BitVector const& initialStates, MatrixFormat matrixFormat); + // ShortestPathsGenerator(Matrix const& maybeTransitionMatrix, StateProbMap const& targetProbMap, BitVector const& initialStates, MatrixFormat matrixFormat); ; } \ No newline at end of file diff --git a/tests/utility/test_shortestpaths.py b/tests/utility/test_shortestpaths.py index ea01047..70451fd 100644 --- a/tests/utility/test_shortestpaths.py +++ b/tests/utility/test_shortestpaths.py @@ -2,6 +2,7 @@ import stormpy import stormpy.logic from stormpy.storage import BitVector from stormpy.utility import ShortestPathsGenerator +from stormpy.utility import MatrixFormat from helpers.helper import get_example_path import pytest @@ -40,6 +41,31 @@ def label(model): return some_label +@pytest.fixture +def transition_matrix(model): + return model.transition_matrix + + +@pytest.fixture +def target_prob_map(model, state_list): + return {i: (1.0 if i in state_list else 0.0) for i in range(model.nr_states)} + + +@pytest.fixture +def target_prob_list(target_prob_map): + return [target_prob_map[i] for i in range(max(target_prob_map.keys()))] + + +@pytest.fixture +def initial_states(model): + return BitVector(model.nr_states, model.initial_states) + + +@pytest.fixture +def matrix_format(): + return MatrixFormat.Straight + + class TestShortestPaths: def test_spg_ctor_bitvector_target(self, model, state_bitvector): _ = ShortestPathsGenerator(model, state_bitvector) @@ -53,5 +79,10 @@ class TestShortestPaths: def test_spg_ctor_label_target(self, model, label): _ = ShortestPathsGenerator(model, label) + def test_spg_ctor_matrix_vector(self, transition_matrix, target_prob_list, initial_states, matrix_format): + _ = ShortestPathsGenerator(transition_matrix, target_prob_list, initial_states, matrix_format) + + def test_spg_ctor_matrix_map(self, transition_matrix, target_prob_map, initial_states, matrix_format): + _ = ShortestPathsGenerator(transition_matrix, target_prob_map, initial_states, matrix_format) # TODO: add tests that check actual functionality