Browse Source

add ShortestPathsGenerator matrix/vector ctor bindings

refactoring
Tom Janson 8 years ago
parent
commit
5b782b5ca7
  1. 23
      src/utility/shortestPaths.cpp
  2. 31
      tests/utility/test_shortestpaths.py

23
src/utility/shortestPaths.cpp

@ -9,11 +9,18 @@
void define_ksp(py::module& m) { void define_ksp(py::module& m) {
// long types shortened for readability // long types shortened for readability
using Path = storm::utility::ksp::Path<double>;
using state_t = storm::utility::ksp::state_t;
using ShortestPathsGenerator = storm::utility::ksp::ShortestPathsGenerator<double>;
//
// this could be templated rather than hardcoding double, but the actual
// bindings must refer to instantiated versions anyway (i.e., overloaded
// for each template instantiation) -- and double is enough for me
using Model = storm::models::sparse::Model<double>; using Model = storm::models::sparse::Model<double>;
using BitVector = storm::storage::BitVector; using BitVector = storm::storage::BitVector;
using Matrix = storm::storage::SparseMatrix<double>;
using MatrixFormat = storm::utility::ksp::MatrixFormat;
using Path = storm::utility::ksp::Path<double>;
using ShortestPathsGenerator = storm::utility::ksp::ShortestPathsGenerator<double>;
using state_t = storm::utility::ksp::state_t;
using StateProbMap = std::unordered_map<state_t, double>;
py::class_<Path>(m, "Path") py::class_<Path>(m, "Path")
// overload constructor rather than dealing with boost::optional // overload constructor rather than dealing with boost::optional
@ -39,10 +46,20 @@ void define_ksp(py::module& m) {
.def_readwrite("distance", &Path::distance) .def_readwrite("distance", &Path::distance)
; ;
py::enum_<MatrixFormat>(m, "MatrixFormat")
.value("Straight", MatrixFormat::straight)
.value("I_Minus_P", MatrixFormat::iMinusP)
;
py::class_<ShortestPathsGenerator>(m, "ShortestPathsGenerator") py::class_<ShortestPathsGenerator>(m, "ShortestPathsGenerator")
.def(py::init<Model const&, BitVector>(), "model"_a, "target_bitvector"_a) .def(py::init<Model const&, BitVector>(), "model"_a, "target_bitvector"_a)
.def(py::init<Model const&, state_t>(), "model"_a, "target_state"_a) .def(py::init<Model const&, state_t>(), "model"_a, "target_state"_a)
.def(py::init<Model const&, std::vector<state_t> const&>(), "model"_a, "target_state_list"_a) .def(py::init<Model const&, std::vector<state_t> const&>(), "model"_a, "target_state_list"_a)
.def(py::init<Model const&, std::string>(), "model"_a, "target_label"_a) .def(py::init<Model const&, std::string>(), "model"_a, "target_label"_a)
.def(py::init<Matrix const&, std::vector<double> const&, BitVector const&, MatrixFormat>(), "transition_matrix"_a, "target_prob_vector"_a, "initial_states"_a, "matrix_format"_a)
.def(py::init<Matrix const&, StateProbMap const&, BitVector const&, MatrixFormat>(), "transition_matrix"_a, "target_prob_map"_a, "initial_states"_a, "matrix_format"_a)
// ShortestPathsGenerator(Matrix const& transitionMatrix, std::vector<T> const& targetProbVector, BitVector const& initialStates, MatrixFormat matrixFormat);
// ShortestPathsGenerator(Matrix const& maybeTransitionMatrix, StateProbMap const& targetProbMap, BitVector const& initialStates, MatrixFormat matrixFormat);
; ;
} }

31
tests/utility/test_shortestpaths.py

@ -2,6 +2,7 @@ import stormpy
import stormpy.logic import stormpy.logic
from stormpy.storage import BitVector from stormpy.storage import BitVector
from stormpy.utility import ShortestPathsGenerator from stormpy.utility import ShortestPathsGenerator
from stormpy.utility import MatrixFormat
from helpers.helper import get_example_path from helpers.helper import get_example_path
import pytest import pytest
@ -40,6 +41,31 @@ def label(model):
return some_label return some_label
@pytest.fixture
def transition_matrix(model):
return model.transition_matrix
@pytest.fixture
def target_prob_map(model, state_list):
return {i: (1.0 if i in state_list else 0.0) for i in range(model.nr_states)}
@pytest.fixture
def target_prob_list(target_prob_map):
return [target_prob_map[i] for i in range(max(target_prob_map.keys()))]
@pytest.fixture
def initial_states(model):
return BitVector(model.nr_states, model.initial_states)
@pytest.fixture
def matrix_format():
return MatrixFormat.Straight
class TestShortestPaths: class TestShortestPaths:
def test_spg_ctor_bitvector_target(self, model, state_bitvector): def test_spg_ctor_bitvector_target(self, model, state_bitvector):
_ = ShortestPathsGenerator(model, state_bitvector) _ = ShortestPathsGenerator(model, state_bitvector)
@ -53,5 +79,10 @@ class TestShortestPaths:
def test_spg_ctor_label_target(self, model, label): def test_spg_ctor_label_target(self, model, label):
_ = ShortestPathsGenerator(model, label) _ = ShortestPathsGenerator(model, label)
def test_spg_ctor_matrix_vector(self, transition_matrix, target_prob_list, initial_states, matrix_format):
_ = ShortestPathsGenerator(transition_matrix, target_prob_list, initial_states, matrix_format)
def test_spg_ctor_matrix_map(self, transition_matrix, target_prob_map, initial_states, matrix_format):
_ = ShortestPathsGenerator(transition_matrix, target_prob_map, initial_states, matrix_format)
# TODO: add tests that check actual functionality # TODO: add tests that check actual functionality
Loading…
Cancel
Save