Browse Source

Length for states and actions

refactoring
Matthias Volk 8 years ago
parent
commit
9618b5ca31
  1. 2
      src/storage/matrix.cpp
  2. 4
      src/storage/state.cpp
  3. 26
      src/storage/state.h
  4. 5
      tests/storage/test_state.py

2
src/storage/matrix.cpp

@ -107,6 +107,7 @@ void define_sparse_matrix(py::module& m) {
.def("__iter__", [](storm::storage::SparseMatrix<double>::rows& rows) { .def("__iter__", [](storm::storage::SparseMatrix<double>::rows& rows) {
return py::make_iterator(rows.begin(), rows.end()); return py::make_iterator(rows.begin(), rows.end());
}, py::keep_alive<0, 1>()) }, py::keep_alive<0, 1>())
.def("__len__", &storm::storage::SparseMatrix<double>::rows::getNumberOfEntries)
.def("__str__", [](storm::storage::SparseMatrix<double>::const_rows& rows) { .def("__str__", [](storm::storage::SparseMatrix<double>::const_rows& rows) {
std::stringstream stream; std::stringstream stream;
for (auto transition : rows) { for (auto transition : rows) {
@ -120,6 +121,7 @@ void define_sparse_matrix(py::module& m) {
.def("__iter__", [](storm::storage::SparseMatrix<storm::RationalFunction>::rows& rows) { .def("__iter__", [](storm::storage::SparseMatrix<storm::RationalFunction>::rows& rows) {
return py::make_iterator(rows.begin(), rows.end()); return py::make_iterator(rows.begin(), rows.end());
}, py::keep_alive<0, 1>()) }, py::keep_alive<0, 1>())
.def("__len__", &storm::storage::SparseMatrix<storm::RationalFunction>::rows::getNumberOfEntries)
.def("__str__", [](storm::storage::SparseMatrix<storm::RationalFunction>::const_rows& rows) { .def("__str__", [](storm::storage::SparseMatrix<storm::RationalFunction>::const_rows& rows) {
std::stringstream stream; std::stringstream stream;
for (auto transition : rows) { for (auto transition : rows) {

4
src/storage/state.cpp

@ -5,9 +5,11 @@ void define_state(py::module& m) {
// SparseModelStates // SparseModelStates
py::class_<SparseModelStates<double>>(m, "SparseModelStates", "States in sparse model") py::class_<SparseModelStates<double>>(m, "SparseModelStates", "States in sparse model")
.def("__getitem__", &SparseModelStates<double>::getState) .def("__getitem__", &SparseModelStates<double>::getState)
.def("__len__", &SparseModelStates<double>::getSize)
; ;
py::class_<SparseModelStates<storm::RationalFunction>>(m, "SparseParametricModelStates", "States in sparse parametric model") py::class_<SparseModelStates<storm::RationalFunction>>(m, "SparseParametricModelStates", "States in sparse parametric model")
.def("__getitem__", &SparseModelStates<storm::RationalFunction>::getState) .def("__getitem__", &SparseModelStates<storm::RationalFunction>::getState)
.def("__len__", &SparseModelStates<storm::RationalFunction>::getSize)
; ;
// SparseModelState // SparseModelState
py::class_<SparseModelState<double>>(m, "SparseModelState", "State in sparse model") py::class_<SparseModelState<double>>(m, "SparseModelState", "State in sparse model")
@ -26,9 +28,11 @@ void define_state(py::module& m) {
// SparseModelActions // SparseModelActions
py::class_<SparseModelActions<double>>(m, "SparseModelActions", "Actions for state in sparse model") py::class_<SparseModelActions<double>>(m, "SparseModelActions", "Actions for state in sparse model")
.def("__getitem__", &SparseModelActions<double>::getAction) .def("__getitem__", &SparseModelActions<double>::getAction)
.def("__len__", &SparseModelActions<double>::getSize)
; ;
py::class_<SparseModelActions<storm::RationalFunction>>(m, "SparseParametricModelActions", "Actions for state in sparse parametric model") py::class_<SparseModelActions<storm::RationalFunction>>(m, "SparseParametricModelActions", "Actions for state in sparse parametric model")
.def("__getitem__", &SparseModelActions<storm::RationalFunction>::getAction) .def("__getitem__", &SparseModelActions<storm::RationalFunction>::getAction)
.def("__len__", &SparseModelActions<storm::RationalFunction>::getSize)
; ;
// SparseModelAction // SparseModelAction
py::class_<SparseModelAction<double>>(m, "SparseModelAction", "Action for state in sparse model") py::class_<SparseModelAction<double>>(m, "SparseModelAction", "Action for state in sparse model")

26
src/storage/state.h

@ -20,20 +20,20 @@ class SparseModelState {
} }
s_index getIndex() const { s_index getIndex() const {
return this->stateIndex;
return stateIndex;
} }
std::set<std::string> getLabels() {
std::set<std::string> getLabels() const {
return this->model.getStateLabeling().getLabelsOfState(this->stateIndex); return this->model.getStateLabeling().getLabelsOfState(this->stateIndex);
} }
SparseModelActions<ValueType> getActions() {
SparseModelActions<ValueType> getActions() const {
return SparseModelActions<ValueType>(this->model, stateIndex); return SparseModelActions<ValueType>(this->model, stateIndex);
} }
std::string toString() {
std::string toString() const {
std::stringstream stream; std::stringstream stream;
stream << this->getIndex();
stream << stateIndex;
return stream.str(); return stream.str();
} }
@ -51,7 +51,11 @@ class SparseModelStates {
length = model.getNumberOfStates(); length = model.getNumberOfStates();
} }
SparseModelState<ValueType> getState(s_index index) {
s_index getSize() const {
return length;
}
SparseModelState<ValueType> getState(s_index index) const {
if (index < length) { if (index < length) {
return SparseModelState<ValueType>(model, index); return SparseModelState<ValueType>(model, index);
} else { } else {
@ -81,9 +85,9 @@ class SparseModelAction {
return model.getTransitionMatrix().getRow(stateIndex, actionIndex); return model.getTransitionMatrix().getRow(stateIndex, actionIndex);
} }
std::string toString() {
std::string toString() const {
std::stringstream stream; std::stringstream stream;
stream << this->getIndex();
stream << actionIndex;
return stream.str(); return stream.str();
} }
@ -103,7 +107,11 @@ class SparseModelActions {
length = model.getTransitionMatrix().getRowGroupSize(stateIndex); length = model.getTransitionMatrix().getRowGroupSize(stateIndex);
} }
SparseModelAction<ValueType> getAction(size_t index) {
s_index getSize() const {
return length;
}
SparseModelAction<ValueType> getAction(size_t index) const {
if (index < length) { if (index < length) {
return SparseModelAction<ValueType>(model, stateIndex, index); return SparseModelAction<ValueType>(model, stateIndex, index);
} else { } else {

5
tests/storage/test_state.py

@ -6,6 +6,7 @@ class TestState:
model = stormpy.parse_explicit_model(get_example_path("dtmc", "die.tra"), get_example_path("dtmc", "die.lab")) model = stormpy.parse_explicit_model(get_example_path("dtmc", "die.tra"), get_example_path("dtmc", "die.lab"))
i = 0 i = 0
states = model.states states = model.states
assert len(states) == 13
for state in states: for state in states:
assert state.id == i assert state.id == i
i += 1 i += 1
@ -18,6 +19,7 @@ class TestState:
def test_states_mdp(self): def test_states_mdp(self):
model = stormpy.parse_explicit_model(get_example_path("mdp", "two_dice.tra"), get_example_path("mdp", "two_dice.lab")) model = stormpy.parse_explicit_model(get_example_path("mdp", "two_dice.tra"), get_example_path("mdp", "two_dice.lab"))
i = 0 i = 0
assert len(model.states) == 169
for state in model.states: for state in model.states:
assert state.id == i assert state.id == i
i += 1 i += 1
@ -41,12 +43,14 @@ class TestState:
def test_actions_dtmc(self): def test_actions_dtmc(self):
model = stormpy.parse_explicit_model(get_example_path("dtmc", "die.tra"), get_example_path("dtmc", "die.lab")) model = stormpy.parse_explicit_model(get_example_path("dtmc", "die.tra"), get_example_path("dtmc", "die.lab"))
for state in model.states: for state in model.states:
assert len(state.actions) == 1
for action in state.actions: for action in state.actions:
assert action.id == 0 assert action.id == 0
def test_actions_mdp(self): def test_actions_mdp(self):
model = stormpy.parse_explicit_model(get_example_path("mdp", "two_dice.tra"), get_example_path("mdp", "two_dice.lab")) model = stormpy.parse_explicit_model(get_example_path("mdp", "two_dice.tra"), get_example_path("mdp", "two_dice.lab"))
for state in model.states: for state in model.states:
assert len(state.actions) == 1 or len(state.actions) == 2
for action in state.actions: for action in state.actions:
assert action.id == 0 or action.id == 1 assert action.id == 0 or action.id == 1
@ -61,6 +65,7 @@ class TestState:
i = 0 i = 0
for state in model.states: for state in model.states:
for action in state.actions: for action in state.actions:
assert (state.id < 7 and len(action.transitions) == 3) or (state.id >= 7 and len(action.transitions) == 1)
for transition in action.transitions: for transition in action.transitions:
transition_orig = transitions_orig[i] transition_orig = transitions_orig[i]
i += 1 i += 1
Loading…
Cancel
Save