From 09b2bfcf67a9656b5ef777c2190954b91d94fcac Mon Sep 17 00:00:00 2001 From: Tom Janson Date: Sun, 18 Dec 2016 02:29:13 +0100 Subject: [PATCH] add ShortestPathsGenerator method binding (and fancy test fixture) --- src/utility/shortestPaths.cpp | 5 +- tests/utility/test_shortestpaths.py | 85 ++++++++++++++++++++++++----- 2 files changed, 74 insertions(+), 16 deletions(-) diff --git a/src/utility/shortestPaths.cpp b/src/utility/shortestPaths.cpp index 9ff97b4..f0817ca 100644 --- a/src/utility/shortestPaths.cpp +++ b/src/utility/shortestPaths.cpp @@ -59,7 +59,8 @@ void define_ksp(py::module& m) { .def(py::init const&, BitVector const&, MatrixFormat>(), "transition_matrix"_a, "target_prob_vector"_a, "initial_states"_a, "matrix_format"_a) .def(py::init(), "transition_matrix"_a, "target_prob_map"_a, "initial_states"_a, "matrix_format"_a) - // ShortestPathsGenerator(Matrix const& transitionMatrix, std::vector const& targetProbVector, BitVector const& initialStates, MatrixFormat matrixFormat); - // ShortestPathsGenerator(Matrix const& maybeTransitionMatrix, StateProbMap const& targetProbMap, BitVector const& initialStates, MatrixFormat matrixFormat); + .def("get_distance", &ShortestPathsGenerator::getDistance, "k"_a) + .def("get_states", &ShortestPathsGenerator::getStates, "k"_a) + .def("get_path_as_list", &ShortestPathsGenerator::getPathAsList, "k"_a) ; } \ No newline at end of file diff --git a/tests/utility/test_shortestpaths.py b/tests/utility/test_shortestpaths.py index 70451fd..45f4d96 100644 --- a/tests/utility/test_shortestpaths.py +++ b/tests/utility/test_shortestpaths.py @@ -6,13 +6,67 @@ from stormpy.utility import MatrixFormat from helpers.helper import get_example_path import pytest +import math + + +# this is admittedly slightly overengineered + +class ModelWithKnownShortestPaths: + """Knuth's die model with reference kSP methods""" + def __init__(self): + self.target_label = "one" + + program_path = get_example_path("dtmc", "die.pm") + raw_formula = "P=? [ F \"" + self.target_label + "\" ]" + program = stormpy.parse_prism_program(program_path) + formulas = stormpy.parse_formulas_for_prism_program(raw_formula, program) + + self.model = stormpy.build_model(program, formulas[0]) + + def probability(self, k): + return (1/2)**((2 * k) + 1) + + def state_set(self, k): + return BitVector(self.model.nr_states, [0, 1, 3, 7]) + + def path(self, k): + path = [0] + k * [1, 3] + [7] + return list(reversed(path)) # SPG returns traversal from back + + +@pytest.fixture(scope="module", params=[1, 2, 3, 3000, 42]) +def index(request): + return request.param + + +@pytest.fixture(scope="module") +def model_with_known_shortest_paths(): + return ModelWithKnownShortestPaths() + + +@pytest.fixture(scope="module") +def model(model_with_known_shortest_paths): + return model_with_known_shortest_paths.model + + +@pytest.fixture(scope="module") +def expected_distance(model_with_known_shortest_paths): + return model_with_known_shortest_paths.probability @pytest.fixture(scope="module") -def model(program_path=get_example_path("dtmc", "die.pm"), raw_formula="P=? [ F \"one\" ]"): - program = stormpy.parse_prism_program(program_path) - formulas = stormpy.parse_formulas_for_prism_program(raw_formula, program) - return stormpy.build_model(program, formulas[0]) +def expected_state_set(model_with_known_shortest_paths): + return model_with_known_shortest_paths.state_set + + +@pytest.fixture(scope="module") +def expected_path(model_with_known_shortest_paths): + return model_with_known_shortest_paths.path + + +@pytest.fixture(scope="module") +def target_label(model_with_known_shortest_paths): + return model_with_known_shortest_paths.target_label @pytest.fixture @@ -34,13 +88,6 @@ def state_bitvector(model, state_list): return BitVector(length=model.nr_states, set_entries=state_list) -@pytest.fixture -def label(model): - some_label = "one" - assert some_label in model.labels, "test model does not contain label '" + some_label + "'" - return some_label - - @pytest.fixture def transition_matrix(model): return model.transition_matrix @@ -76,8 +123,8 @@ class TestShortestPaths: def test_spg_ctor_state_list_target(self, model, state_list): _ = ShortestPathsGenerator(model, state_list) - def test_spg_ctor_label_target(self, model, label): - _ = ShortestPathsGenerator(model, label) + def test_spg_ctor_label_target(self, model, target_label): + _ = ShortestPathsGenerator(model, target_label) def test_spg_ctor_matrix_vector(self, transition_matrix, target_prob_list, initial_states, matrix_format): _ = ShortestPathsGenerator(transition_matrix, target_prob_list, initial_states, matrix_format) @@ -85,4 +132,14 @@ class TestShortestPaths: def test_spg_ctor_matrix_map(self, transition_matrix, target_prob_map, initial_states, matrix_format): _ = ShortestPathsGenerator(transition_matrix, target_prob_map, initial_states, matrix_format) - # TODO: add tests that check actual functionality + def test_spg_distance(self, model, target_label, index, expected_distance): + spg = ShortestPathsGenerator(model, target_label) + assert math.isclose(spg.get_distance(index), expected_distance(index)) + + def test_spg_state_set(self, model, target_label, index, expected_state_set): + spg = ShortestPathsGenerator(model, target_label) + assert spg.get_states(index) == expected_state_set(index) + + def test_spg_state_list(self, model, target_label, index, expected_path): + spg = ShortestPathsGenerator(model, target_label) + assert spg.get_path_as_list(index) == expected_path(index)