Browse Source

Merge branch 'master' into almostsurepomdp

refactoring
Sebastian Junges 5 years ago
parent
commit
865c804093
  1. 1
      .gitignore
  2. 36
      lib/stormpy/simulator.py
  3. 3
      lib/stormpy/storage/__init__.py
  4. 1
      src/mod_storage.cpp
  5. 15
      src/storage/jani.cpp
  6. 1
      src/storage/jani.h
  7. 3
      src/storage/labeling.cpp
  8. 2
      src/storage/model.cpp
  9. 4
      src/utility/smtsolver.cpp
  10. 2
      tests/storage/test_model.py

1
.gitignore

@ -10,5 +10,6 @@ __pycache__/
_build/ _build/
.pytest_cache/ .pytest_cache/
.idea/ .idea/
cmake-build-debug/
.DS_Store .DS_Store

36
lib/stormpy/simulator.py

@ -18,6 +18,7 @@ class Simulator:
self._seed = seed self._seed = seed
self._observation_mode = SimulatorObservationMode.STATE_LEVEL self._observation_mode = SimulatorObservationMode.STATE_LEVEL
self._action_mode = SimulatorActionMode.INDEX_LEVEL self._action_mode = SimulatorActionMode.INDEX_LEVEL
self._full_observe = False
def available_actions(self): def available_actions(self):
""" """
@ -61,6 +62,15 @@ class Simulator:
raise RuntimeError("Observation mode must be a SimulatorObservationMode") raise RuntimeError("Observation mode must be a SimulatorObservationMode")
self._observation_mode = mode self._observation_mode = mode
def set_full_observability(self, value):
"""
Sets whether the full state space is observable.
Default inherited from the model, but this method overrides the setting.
:param value:
"""
self._full_observe = value
class SparseSimulator(Simulator): class SparseSimulator(Simulator):
""" """
@ -74,6 +84,7 @@ class SparseSimulator(Simulator):
if seed is not None: if seed is not None:
self._engine.set_seed(seed) self._engine.set_seed(seed)
self._state_valuations = None self._state_valuations = None
self.set_full_observability(self._model.model_type != stormpy.storage.ModelType.POMDP)
def available_actions(self): def available_actions(self):
return range(self.nr_available_actions()) return range(self.nr_available_actions())
@ -81,11 +92,30 @@ class SparseSimulator(Simulator):
def nr_available_actions(self): def nr_available_actions(self):
return self._model.get_nr_available_actions(self._engine.get_current_state()) return self._model.get_nr_available_actions(self._engine.get_current_state())
def _report_observation(self):
def _report_state(self):
if self._observation_mode == SimulatorObservationMode.STATE_LEVEL: if self._observation_mode == SimulatorObservationMode.STATE_LEVEL:
return self._engine.get_current_state() return self._engine.get_current_state()
elif self._observation_mode == SimulatorObservationMode.PROGRAM_LEVEL: elif self._observation_mode == SimulatorObservationMode.PROGRAM_LEVEL:
return self._state_valuations.get_state(self._engine.get_current_state()) return self._state_valuations.get_state(self._engine.get_current_state())
assert False, "The observation mode is unexpected"
def _report_observation(self):
"""
:return:
"""
#TODO this should be ensured earlier
assert self._model.model_type == stormpy.storage.ModelType.POMDP
if self._observation_mode == SimulatorObservationMode.STATE_LEVEL:
return self._model.get_observation(self._engine.get_current_state())
elif self._observation_mode == SimulatorObservationMode.PROGRAM_LEVEL:
raise NotImplementedError("Program level observations are not implemented in storm")
assert False, "The observation mode is unexpected"
def _report_result(self):
if self._full_observe:
return self._report_state()
else:
return self._report_observation()
def step(self, action=None): def step(self, action=None):
if action is None: if action is None:
@ -98,12 +128,12 @@ class SparseSimulator(Simulator):
raise RuntimeError(f"Only {self.nr_available_actions()} actions available") raise RuntimeError(f"Only {self.nr_available_actions()} actions available")
check = self._engine.step(action) check = self._engine.step(action)
assert check assert check
return self._report_observation()
return self._report_result()
def restart(self): def restart(self):
self._engine.reset_to_initial_state() self._engine.reset_to_initial_state()
return self._report_observation()
return self._report_result()
def is_done(self): def is_done(self):
return self._model.is_sink_state(self._engine.get_current_state()) return self._model.is_sink_state(self._engine.get_current_state())

3
lib/stormpy/storage/__init__.py

@ -1,3 +1,4 @@
import stormpy.utility
from . import storage from . import storage
from .storage import * from .storage import *

1
src/mod_storage.cpp

@ -33,6 +33,7 @@ PYBIND11_MODULE(storage, m) {
define_state(m); define_state(m);
define_prism(m); define_prism(m);
define_jani(m); define_jani(m);
define_jani_transformers(m);
define_labeling(m); define_labeling(m);
define_origins(m); define_origins(m);
define_expressions(m); define_expressions(m);

15
src/storage/jani.cpp

@ -4,6 +4,8 @@
#include <storm/storage/expressions/ExpressionManager.h> #include <storm/storage/expressions/ExpressionManager.h>
#include <storm/logic/RewardAccumulationEliminationVisitor.h> #include <storm/logic/RewardAccumulationEliminationVisitor.h>
#include <storm/storage/jani/traverser/InformationCollector.h> #include <storm/storage/jani/traverser/InformationCollector.h>
#include <storm/storage/jani/JaniLocationExpander.h>
#include <storm/storage/jani/JaniScopeChanger.h>
#include "src/helpers.h" #include "src/helpers.h"
using namespace storm::jani; using namespace storm::jani;
@ -34,6 +36,7 @@ void define_jani(py::module& m) {
.def("define_constants", &Model::defineUndefinedConstants, "define constants with a mapping from the corresponding expression variables to expressions", py::arg("map")) .def("define_constants", &Model::defineUndefinedConstants, "define constants with a mapping from the corresponding expression variables to expressions", py::arg("map"))
.def("substitute_constants", &Model::substituteConstants, "substitute constants") .def("substitute_constants", &Model::substituteConstants, "substitute constants")
.def("remove_constant", &Model::removeConstant, "remove a constant. Make sure the constant does not appear in the model.", "constant_name"_a) .def("remove_constant", &Model::removeConstant, "remove a constant. Make sure the constant does not appear in the model.", "constant_name"_a)
.def("get_automaton", [](Model const& model, std::string const& name) {return model.getAutomaton(name);}, "name"_a)
.def("get_automaton_index", &Model::getAutomatonIndex, "name"_a, "get index for automaton name") .def("get_automaton_index", &Model::getAutomatonIndex, "name"_a, "get index for automaton name")
.def("add_automaton", &Model::addAutomaton, "automaton"_a, "add an automaton (with a unique name)") .def("add_automaton", &Model::addAutomaton, "automaton"_a, "add an automaton (with a unique name)")
.def("set_standard_system_composition", &Model::setStandardSystemComposition, "sets the composition to the standard composition") .def("set_standard_system_composition", &Model::setStandardSystemComposition, "sets the composition to the standard composition")
@ -44,6 +47,7 @@ void define_jani(py::module& m) {
.def_static("decode_automaton_and_edge_index", &Model::decodeAutomatonAndEdgeIndices, "get edge and automaton from edge/automaton index") .def_static("decode_automaton_and_edge_index", &Model::decodeAutomatonAndEdgeIndices, "get edge and automaton from edge/automaton index")
.def("make_standard_compliant", &Model::makeStandardJaniCompliant, "make standard JANI compliant") .def("make_standard_compliant", &Model::makeStandardJaniCompliant, "make standard JANI compliant")
.def("has_standard_composition", &Model::hasStandardComposition, "is the composition the standard composition") .def("has_standard_composition", &Model::hasStandardComposition, "is the composition the standard composition")
.def("flatten_composition", &Model::flattenComposition, py::arg("smt_solver_factory")=std::make_shared<storm::utility::solver::SmtSolverFactory>())
.def("finalize", &Model::finalize,"finalizes the model. After this action, be careful changing the data structure.") .def("finalize", &Model::finalize,"finalizes the model. After this action, be careful changing the data structure.")
.def("to_dot", [](Model& model) {std::stringstream ss; model.writeDotToStream(ss); return ss.str(); }) .def("to_dot", [](Model& model) {std::stringstream ss; model.writeDotToStream(ss); return ss.str(); })
; ;
@ -171,3 +175,14 @@ void define_jani(py::module& m) {
m.def("collect_information", [](const Model& model) {return storm::jani::collectModelInformation(model);}); m.def("collect_information", [](const Model& model) {return storm::jani::collectModelInformation(model);});
} }
void define_jani_transformers(py::module& m) {
py::class_<JaniLocationExpander>(m, "JaniLocationExpander", "A transformer for Jani expanding variables into locations")
.def(py::init<Model const&>(), py::arg("model"))
.def("transform", &JaniLocationExpander::transform, py::arg("automaton_name"), py::arg("variable_name"))
.def("get_result", &JaniLocationExpander::getResult);
py::class_<JaniScopeChanger>(m, "JaniScopeChanger", "A transformer for Jani changing variables from local to global and vice versa")
.def(py::init<>())
.def("make_variables_local", [](JaniScopeChanger const& sc, Model const& model , std::vector<Property> const& props = {}) { Model newModel(model); sc.makeVariablesLocal(newModel, props); return newModel;}, py::arg("model"), py::arg("properties") = std::vector<Property>());
}

1
src/storage/jani.h

@ -3,3 +3,4 @@
#include "common.h" #include "common.h"
void define_jani(py::module& m); void define_jani(py::module& m);
void define_jani_transformers(py::module& m);

3
src/storage/labeling.cpp

@ -32,5 +32,6 @@ void define_labeling(py::module& m) {
; ;
py::class_<storm::models::sparse::ChoiceLabeling>(m, "ChoiceLabeling", "Labeling for choices", labeling);
py::class_<storm::models::sparse::ChoiceLabeling>(m, "ChoiceLabeling", "Labeling for choices", labeling).
def("get_labels_of_choice", &storm::models::sparse::ChoiceLabeling::getLabelsOfChoice, py::arg("choice"), "get labels of a choice");
} }

2
src/storage/model.cpp

@ -169,6 +169,7 @@ void define_sparse_model(py::module& m) {
// Models with double numbers // Models with double numbers
py::class_<SparseModel<double>, std::shared_ptr<SparseModel<double>>, ModelBase> model(m, "_SparseModel", "A probabilistic model where transitions are represented by doubles and saved in a sparse matrix"); py::class_<SparseModel<double>, std::shared_ptr<SparseModel<double>>, ModelBase> model(m, "_SparseModel", "A probabilistic model where transitions are represented by doubles and saved in a sparse matrix");
model.def_property_readonly("labeling", &getLabeling<double>, "Labels") model.def_property_readonly("labeling", &getLabeling<double>, "Labels")
.def("has_choice_labeling", [](SparseModel<double> const& model) {model.hasChoiceLabeling();}, "Does the model have an associated choice labelling?")
.def_property_readonly("choice_labeling", [](SparseModel<double> const& model) {return model.getChoiceLabeling();}, "get choice labelling") .def_property_readonly("choice_labeling", [](SparseModel<double> const& model) {return model.getChoiceLabeling();}, "get choice labelling")
.def("has_choice_origins", [](SparseModel<double> const& model) {return model.hasChoiceOrigins();}, "has choice origins?") .def("has_choice_origins", [](SparseModel<double> const& model) {return model.hasChoiceOrigins();}, "has choice origins?")
.def_property_readonly("choice_origins", [](SparseModel<double> const& model) {return model.getChoiceOrigins();}) .def_property_readonly("choice_origins", [](SparseModel<double> const& model) {return model.getChoiceOrigins();})
@ -195,6 +196,7 @@ void define_sparse_model(py::module& m) {
mdp.def(py::init<SparseMdp<double>>(), py::arg("other_model")) mdp.def(py::init<SparseMdp<double>>(), py::arg("other_model"))
.def_property_readonly("nondeterministic_choice_indices", [](SparseMdp<double> const& mdp) { return mdp.getNondeterministicChoiceIndices(); }) .def_property_readonly("nondeterministic_choice_indices", [](SparseMdp<double> const& mdp) { return mdp.getNondeterministicChoiceIndices(); })
.def("get_nr_available_actions", [](SparseMdp<double> const& mdp, uint64_t stateIndex) { return mdp.getNondeterministicChoiceIndices()[stateIndex+1] - mdp.getNondeterministicChoiceIndices()[stateIndex] ; }, py::arg("state")) .def("get_nr_available_actions", [](SparseMdp<double> const& mdp, uint64_t stateIndex) { return mdp.getNondeterministicChoiceIndices()[stateIndex+1] - mdp.getNondeterministicChoiceIndices()[stateIndex] ; }, py::arg("state"))
.def("get_choice_index", [](SparseMdp<double> const& mdp, uint64_t state, uint64_t actOff) { return mdp.getNondeterministicChoiceIndices()[state]+actOff; }, py::arg("state"), py::arg("action_offset"), "gets the choice index for the offset action from the given state.")
.def("apply_scheduler", [](SparseMdp<double> const& mdp, storm::storage::Scheduler<double> const& scheduler, bool dropUnreachableStates) { return mdp.applyScheduler(scheduler, dropUnreachableStates); } , "apply scheduler", "scheduler"_a, "drop_unreachable_states"_a = true) .def("apply_scheduler", [](SparseMdp<double> const& mdp, storm::storage::Scheduler<double> const& scheduler, bool dropUnreachableStates) { return mdp.applyScheduler(scheduler, dropUnreachableStates); } , "apply scheduler", "scheduler"_a, "drop_unreachable_states"_a = true)
.def("__str__", &getModelInfoPrinter) .def("__str__", &getModelInfoPrinter)
; ;

4
src/utility/smtsolver.cpp

@ -1,6 +1,8 @@
#include "smtsolver.h" #include "smtsolver.h"
#include <storm/solver/Z3SmtSolver.h> #include <storm/solver/Z3SmtSolver.h>
#include "storm/storage/expressions/ExpressionManager.h" #include "storm/storage/expressions/ExpressionManager.h"
#include <storm/solver/SmtSolver.h>
#include <storm/utility/solver.h>
void define_smt(py::module& m) { void define_smt(py::module& m) {
using SmtSolver = storm::solver::SmtSolver; using SmtSolver = storm::solver::SmtSolver;
@ -29,4 +31,6 @@ void define_smt(py::module& m) {
py::class_<Z3SmtSolver> z3solver(m, "Z3SmtSolver", "z3 API for storm smtsolver wrapper", smtsolver); py::class_<Z3SmtSolver> z3solver(m, "Z3SmtSolver", "z3 API for storm smtsolver wrapper", smtsolver);
z3solver.def(pybind11::init<storm::expressions::ExpressionManager&>()); z3solver.def(pybind11::init<storm::expressions::ExpressionManager&>());
py::class_<storm::utility::solver::SmtSolverFactory, std::shared_ptr<storm::utility::solver::SmtSolverFactory>> (m, "SmtSolverFactory", "Factory for creating SMT Solvers");
} }

2
tests/storage/test_model.py

@ -116,7 +116,7 @@ class TestSparseModel:
program = stormpy.parse_prism_program(get_example_path("pomdp", "maze_2.prism")) program = stormpy.parse_prism_program(get_example_path("pomdp", "maze_2.prism"))
formulas = stormpy.parse_properties_for_prism_program("P=? [F \"goal\"]", program) formulas = stormpy.parse_properties_for_prism_program("P=? [F \"goal\"]", program)
model = stormpy.build_model(program, formulas) model = stormpy.build_model(program, formulas)
assert model.nr_states == 16
assert model.nr_states == 15
assert model.nr_observations == 8 assert model.nr_observations == 8
def test_build_ma(self): def test_build_ma(self):

Loading…
Cancel
Save