Mavo
9 years ago
committed by
Matthias Volk
5 changed files with 165 additions and 0 deletions
-
1stormpy/lib/stormpy/storage/__init__.py
-
34stormpy/lib/stormpy/storage/action.py
-
34stormpy/lib/stormpy/storage/state.py
-
32stormpy/src/storage/matrix.cpp
-
64stormpy/tests/storage/test_model_iterators.py
@ -1,2 +1,3 @@ |
|||
from . import storage |
|||
from .storage import * |
|||
from . import state,action |
@ -0,0 +1,34 @@ |
|||
class Action: |
|||
""" Represents an action in the model """ |
|||
|
|||
def __init__(self, row_group_start, row_group_end, row, model): |
|||
""" Initialize |
|||
:param row_group_start: Start index of the row group in the matrix |
|||
:param row_group_end: End index of the row group in the matrix |
|||
:param row: Index of the corresponding row in the matrix |
|||
:param model: Corresponding model |
|||
""" |
|||
self.row_group_start = row_group_start |
|||
self.row_group_end = row_group_end |
|||
self.row = row - 1 |
|||
self.model = model |
|||
assert row >= -1 and row + row_group_start <= row_group_end |
|||
|
|||
def __iter__(self): |
|||
return self |
|||
|
|||
def __next__(self): |
|||
if self.row + self.row_group_start >= self.row_group_end - 1: |
|||
raise StopIteration |
|||
else: |
|||
self.row += 1 |
|||
return self |
|||
|
|||
def __str__(self): |
|||
return "{}".format(self.row) |
|||
|
|||
def transitions(self): |
|||
""" Get transitions associated with the action |
|||
:return List of tranistions |
|||
""" |
|||
return self.model.transition_matrix().get_row(self.row + self.row_group_start) |
@ -0,0 +1,34 @@ |
|||
import stormpy.storage |
|||
|
|||
class State: |
|||
""" Represents a state in the model """ |
|||
|
|||
def __init__(self, id, model): |
|||
""" Initialize |
|||
:param id: Id of the state |
|||
:param model: Corresponding model |
|||
""" |
|||
self.id = id - 1 |
|||
self.model = model |
|||
|
|||
def __iter__(self): |
|||
return self |
|||
|
|||
def __next__(self): |
|||
if self.id >= self.model.nr_states() - 1: |
|||
raise StopIteration |
|||
else: |
|||
self.id += 1 |
|||
return self |
|||
|
|||
def __str__(self): |
|||
return "{}".format(self.id) |
|||
|
|||
def actions(self): |
|||
""" Get actions associated with the state |
|||
:return List of actions |
|||
""" |
|||
row_group_indices = self.model.transition_matrix().row_group_indices() |
|||
start = row_group_indices[self.id] |
|||
end = row_group_indices[self.id+1] |
|||
return stormpy.action.Action(start, end, 0, self.model) |
@ -0,0 +1,64 @@ |
|||
import stormpy |
|||
|
|||
class TestModelIterators: |
|||
def test_states_dtmc(self): |
|||
model = stormpy.parse_explicit_model("../examples/dtmc/die/die.tra", "../examples/dtmc/die/die.lab") |
|||
s = stormpy.state.State(0, model) |
|||
i = 0 |
|||
for state in stormpy.state.State(0, model): |
|||
assert state.id == i |
|||
i += 1 |
|||
assert i == model.nr_states() |
|||
|
|||
def test_states_mdp(self): |
|||
model = stormpy.parse_explicit_model("../examples/mdp/two_dice/two_dice.tra", "../examples/mdp/two_dice/two_dice.lab") |
|||
s = stormpy.state.State(0, model) |
|||
i = 0 |
|||
for state in stormpy.state.State(0, model): |
|||
assert state.id == i |
|||
i += 1 |
|||
assert i == model.nr_states() |
|||
|
|||
def test_actions_dtmc(self): |
|||
model = stormpy.parse_explicit_model("../examples/dtmc/die/die.tra", "../examples/dtmc/die/die.lab") |
|||
s = stormpy.state.State(0, model) |
|||
for state in stormpy.state.State(0, model): |
|||
for action in state.actions(): |
|||
assert action.row == 0 |
|||
|
|||
def test_actions_mdp(self): |
|||
model = stormpy.parse_explicit_model("../examples/mdp/two_dice/two_dice.tra", "../examples/mdp/two_dice/two_dice.lab") |
|||
s = stormpy.state.State(0, model) |
|||
for state in stormpy.state.State(0, model): |
|||
for action in state.actions(): |
|||
assert action.row == 0 or action.row == 1 |
|||
|
|||
def test_transitions_dtmc(self): |
|||
transitions_orig = [(0, 0, 0), (0, 1, 0.5), (0, 2, 0.5), (1, 1, 0), (1, 3, 0.5), (1, 4, 0.5), |
|||
(2, 2, 0), (2, 5, 0.5), (2, 6, 0.5), (3, 1, 0.5), (3, 3, 0), (3, 7, 0.5), |
|||
(4, 4, 0), (4, 8, 0.5), (4, 9, 0.5), (5, 5, 0), (5, 10, 0.5), (5, 11, 0.5), |
|||
(6, 2, 0.5), (6, 6, 0), (6, 12, 0.5), (7, 7, 1), (8, 8, 1), |
|||
(9, 9, 1), (10, 10, 1), (11, 11, 1), (12, 12, 1) |
|||
] |
|||
model = stormpy.parse_explicit_model("../examples/dtmc/die/die.tra", "../examples/dtmc/die/die.lab") |
|||
s = stormpy.state.State(0, model) |
|||
i = 0 |
|||
for state in stormpy.state.State(0, model): |
|||
for action in state.actions(): |
|||
for transition in action.transitions(): |
|||
transition_orig = transitions_orig[i] |
|||
i += 1 |
|||
assert state.id == transition_orig[0] |
|||
assert transition.column() == transition_orig[1] |
|||
assert transition.val() == transition_orig[2] |
|||
|
|||
def test_transitions_mdp(self): |
|||
model = stormpy.parse_explicit_model("../examples/mdp/two_dice/two_dice.tra", "../examples/mdp/two_dice/two_dice.lab") |
|||
s = stormpy.state.State(0, model) |
|||
for state in stormpy.state.State(0, model): |
|||
i = 0 |
|||
for action in state.actions(): |
|||
i += 1 |
|||
for transition in action.transitions(): |
|||
assert transition.val() == 0.5 or transition.val() == 1 |
|||
assert i == 1 or i == 2 |
Write
Preview
Loading…
Cancel
Save
Reference in new issue