Browse Source

Use matrix instead of model for iterators

refactoring
Matthias Volk 8 years ago
parent
commit
c017949d0c
  1. 8
      lib/stormpy/storage/action.py
  2. 14
      lib/stormpy/storage/state.py
  3. 2
      src/storage/matrix.cpp
  4. 7
      src/storage/model.cpp
  5. 12
      tests/storage/test_matrix.py
  6. 26
      tests/storage/test_model_iterators.py

8
lib/stormpy/storage/action.py

@ -1,17 +1,17 @@
class Action:
""" Represents an action in the model """
def __init__(self, row_group_start, row_group_end, row, model):
def __init__(self, row_group_start, row_group_end, row, matrix):
""" Initialize
:param row_group_start: Start index of the row group in the matrix
:param row_group_end: End index of the row group in the matrix
:param row: Index of the corresponding row in the matrix
:param model: Corresponding model
:param matrix: Corresponding matrix
"""
self.row_group_start = row_group_start
self.row_group_end = row_group_end
self.row = row - 1
self.model = model
self.matrix = matrix
assert row >= -1 and row + row_group_start <= row_group_end
def __iter__(self):
@ -33,4 +33,4 @@ class Action:
"""
row = self.row_group_start + self.row
#return self.model.transition_matrix().get_row(self.row_group_start + self.row)
return self.model.transition_matrix.row_iter(row, row)
return self.matrix.row_iter(row, row)

14
lib/stormpy/storage/state.py

@ -1,21 +1,21 @@
from . import action
class State:
""" Represents a state in the model """
""" Represents a state in the matrix """
def __init__(self, id, model):
def __init__(self, id, matrix):
""" Initialize
:param id: Id of the state
:param model: Corresponding model
:param matrix: Corresponding matrix
"""
self.id = id - 1
self.model = model
self.matrix = matrix
def __iter__(self):
return self
def __next__(self):
if self.id >= self.model.nr_states - 1:
if self.id >= self.matrix.nr_row_groups - 1:
raise StopIteration
else:
self.id += 1
@ -28,7 +28,7 @@ class State:
""" Get actions associated with the state
:return List of actions
"""
row_group_indices = self.model.transition_matrix._row_group_indices
row_group_indices = self.matrix._row_group_indices
start = row_group_indices[self.id]
end = row_group_indices[self.id+1]
return action.Action(start, end, 0, self.model)
return action.Action(start, end, 0, self.matrix)

2
src/storage/matrix.cpp

@ -44,6 +44,7 @@ void define_sparse_matrix(py::module& m) {
return stream.str();
})
.def_property_readonly("nr_rows", &storm::storage::SparseMatrix<double>::getRowCount, "Number of rows")
.def_property_readonly("nr_row_groups", &storm::storage::SparseMatrix<double>::getRowGroupCount, "Number of row groups")
.def_property_readonly("nr_columns", &storm::storage::SparseMatrix<double>::getColumnCount, "Number of columns")
.def_property_readonly("nr_entries", &storm::storage::SparseMatrix<double>::getEntryCount, "Number of non-zero entries")
.def_property_readonly("_row_group_indices", &storm::storage::SparseMatrix<double>::getRowGroupIndices, "Starting rows of row groups")
@ -77,6 +78,7 @@ void define_sparse_matrix(py::module& m) {
return stream.str();
})
.def_property_readonly("nr_rows", &storm::storage::SparseMatrix<storm::RationalFunction>::getRowCount, "Number of rows")
.def_property_readonly("nr_row_groups", &storm::storage::SparseMatrix<storm::RationalFunction>::getRowGroupCount, "Number of row groups")
.def_property_readonly("nr_columns", &storm::storage::SparseMatrix<storm::RationalFunction>::getColumnCount, "Number of columns")
.def_property_readonly("nr_entries", &storm::storage::SparseMatrix<storm::RationalFunction>::getEntryCount, "Number of non-zero entries")
.def_property_readonly("_row_group_indices", &storm::storage::SparseMatrix<storm::RationalFunction>::getRowGroupIndices, "Starting rows of row groups")

7
src/storage/model.cpp

@ -22,6 +22,11 @@ storm::storage::SparseMatrix<ValueType>& getTransitionMatrix(storm::models::spar
return model.getTransitionMatrix();
}
template<typename ValueType>
storm::storage::SparseMatrix<ValueType> getBackwardTransitionMatrix(storm::models::sparse::Model<ValueType>& model) {
return model.getBackwardTransitions();
}
std::set<storm::RationalFunctionVariable> probabilityVariables(storm::models::sparse::Model<storm::RationalFunction> const& model) {
return storm::models::sparse::getProbabilityParameters(model);
}
@ -72,6 +77,7 @@ void define_model(py::module& m) {
.def("labels_state", &storm::models::sparse::Model<double>::getLabelsOfState, py::arg("state"), "Get labels of state")
.def_property_readonly("initial_states", &getInitialStates<double>, "Initial states")
.def_property_readonly("transition_matrix", &getTransitionMatrix<double>, py::return_value_policy::reference, py::keep_alive<1, 0>(), "Transition matrix")
.def_property_readonly("backward_transition_matrix", &getBackwardTransitionMatrix<double>, py::return_value_policy::reference, py::keep_alive<1, 0>(), "Backward transition matrix")
;
py::class_<storm::models::sparse::Dtmc<double>, std::shared_ptr<storm::models::sparse::Dtmc<double>>>(m, "SparseDtmc", "DTMC in sparse representation", model)
;
@ -89,6 +95,7 @@ void define_model(py::module& m) {
.def("labels_state", &storm::models::sparse::Model<storm::RationalFunction>::getLabelsOfState, py::arg("state"), "Get labels of state")
.def_property_readonly("initial_states", &getInitialStates<storm::RationalFunction>, "Initial states")
.def_property_readonly("transition_matrix", &getTransitionMatrix<storm::RationalFunction>, py::return_value_policy::reference, py::keep_alive<1, 0>(), "Transition matrix")
.def_property_readonly("backward_transition_matrix", &getBackwardTransitionMatrix<storm::RationalFunction>, py::return_value_policy::reference, py::keep_alive<1, 0>(), "Backward transition matrix")
;
py::class_<storm::models::sparse::Dtmc<storm::RationalFunction>, std::shared_ptr<storm::models::sparse::Dtmc<storm::RationalFunction>>>(m, "SparseParametricDtmc", "pDTMC in sparse representation", modelRatFunc)
;

12
tests/storage/test_matrix.py

@ -14,6 +14,16 @@ class TestMatrix:
for e in matrix:
assert e.value() == 0.5 or e.value() == 0 or (e.value() == 1 and e.column > 6)
def test_backward_matrix(self):
model = stormpy.parse_explicit_model(get_example_path("dtmc", "die.tra"), get_example_path("dtmc", "die.lab"))
matrix = model.backward_transition_matrix
assert type(matrix) is stormpy.storage.SparseMatrix
assert matrix.nr_rows == model.nr_states
assert matrix.nr_columns == model.nr_states
assert matrix.nr_entries == 20 #model.nr_transitions
for e in matrix:
assert e.value() == 0.5 or e.value() == 0 or (e.value() == 1 and e.column > 6)
def test_change_sparse_matrix(self):
model = stormpy.parse_explicit_model(get_example_path("dtmc", "die.tra"), get_example_path("dtmc", "die.lab"))
matrix = model.transition_matrix
@ -58,7 +68,7 @@ class TestMatrix:
assert math.isclose(resValue, 0.06923076923076932)
# Change probabilities again
for state in stormpy.state.State(0, model):
for state in stormpy.state.State(0, model.transition_matrix):
for action in state.actions():
for transition in action.transitions():
if transition.value() == 0.3:

26
tests/storage/test_model_iterators.py

@ -4,33 +4,33 @@ from helpers.helper import get_example_path
class TestModelIterators:
def test_states_dtmc(self):
model = stormpy.parse_explicit_model(get_example_path("dtmc", "die.tra"), get_example_path("dtmc", "die.lab"))
s = stormpy.state.State(0, model)
s = stormpy.state.State(0, model.transition_matrix)
i = 0
for state in stormpy.state.State(0, model):
for state in stormpy.state.State(0, model.transition_matrix):
assert state.id == i
i += 1
assert i == model.nr_states
def test_states_mdp(self):
model = stormpy.parse_explicit_model(get_example_path("mdp", "two_dice.tra"), get_example_path("mdp", "two_dice.lab"))
s = stormpy.state.State(0, model)
s = stormpy.state.State(0, model.transition_matrix)
i = 0
for state in stormpy.state.State(0, model):
for state in stormpy.state.State(0, model.transition_matrix):
assert state.id == i
i += 1
assert i == model.nr_states
def test_actions_dtmc(self):
model = stormpy.parse_explicit_model(get_example_path("dtmc", "die.tra"), get_example_path("dtmc", "die.lab"))
s = stormpy.state.State(0, model)
for state in stormpy.state.State(0, model):
s = stormpy.state.State(0, model.transition_matrix)
for state in stormpy.state.State(0, model.transition_matrix):
for action in state.actions():
assert action.row == 0
def test_actions_mdp(self):
model = stormpy.parse_explicit_model(get_example_path("mdp", "two_dice.tra"), get_example_path("mdp", "two_dice.lab"))
s = stormpy.state.State(0, model)
for state in stormpy.state.State(0, model):
s = stormpy.state.State(0, model.transition_matrix)
for state in stormpy.state.State(0, model.transition_matrix):
for action in state.actions():
assert action.row == 0 or action.row == 1
@ -42,9 +42,9 @@ class TestModelIterators:
(9, 9, 1), (10, 10, 1), (11, 11, 1), (12, 12, 1)
]
model = stormpy.parse_explicit_model(get_example_path("dtmc", "die.tra"), get_example_path("dtmc", "die.lab"))
s = stormpy.state.State(0, model)
s = stormpy.state.State(0, model.transition_matrix)
i = 0
for state in stormpy.state.State(0, model):
for state in stormpy.state.State(0, model.transition_matrix):
for action in state.actions():
for transition in action.transitions():
transition_orig = transitions_orig[i]
@ -55,8 +55,8 @@ class TestModelIterators:
def test_transitions_mdp(self):
model = stormpy.parse_explicit_model(get_example_path("mdp", "two_dice.tra"), get_example_path("mdp", "two_dice.lab"))
s = stormpy.state.State(0, model)
for state in stormpy.state.State(0, model):
s = stormpy.state.State(0, model.transition_matrix)
for state in stormpy.state.State(0, model.transition_matrix):
i = 0
for action in state.actions():
i += 1
@ -73,7 +73,7 @@ class TestModelIterators:
]
model = stormpy.parse_explicit_model(get_example_path("dtmc", "die.tra"), get_example_path("dtmc", "die.lab"))
i = 0
for state in stormpy.state.State(0, model):
for state in stormpy.state.State(0, model.transition_matrix):
for transition in model.transition_matrix.row_iter(state.id, state.id):
transition_orig = transitions_orig[i]
i += 1

Loading…
Cancel
Save