Browse Source

StateGenerator

refactoring
Philipp Schröer 5 years ago
parent
commit
c1f2c83e1f
  1. 4
      src/storage/expressions.cpp
  2. 203
      src/storage/prism.cpp
  3. 3
      src/storage/valuation.cpp
  4. 75
      tests/storage/test_state_generation.py

4
src/storage/expressions.cpp

@ -39,6 +39,8 @@ void define_expressions(py::module& m) {
.def("has_numerical_type", &storm::expressions::Variable::hasNumericalType, "Check if the variable is of numerical type") .def("has_numerical_type", &storm::expressions::Variable::hasNumericalType, "Check if the variable is of numerical type")
.def("has_bitvector_type", &storm::expressions::Variable::hasBitVectorType, "Check if the variable is of bitvector type") .def("has_bitvector_type", &storm::expressions::Variable::hasBitVectorType, "Check if the variable is of bitvector type")
.def("get_expression", &storm::expressions::Variable::getExpression, "Get expression from variable") .def("get_expression", &storm::expressions::Variable::getExpression, "Get expression from variable")
.def("__eq__", &storm::expressions::Variable::operator==)
.def("__hash__", &storm::expressions::Variable::getIndex)
; ;
@ -130,4 +132,4 @@ void define_expressions(py::module& m) {
.def_property_readonly("is_rational", &storm::expressions::Type::isRationalType) .def_property_readonly("is_rational", &storm::expressions::Type::isRationalType)
.def("__str__", &storm::expressions::Type::getStringRepresentation); .def("__str__", &storm::expressions::Type::getStringRepresentation);
}
}

203
src/storage/prism.cpp

@ -1,12 +1,25 @@
#include "prism.h" #include "prism.h"
#include <storm/storage/prism/Program.h> #include <storm/storage/prism/Program.h>
#include <boost/variant.hpp>
#include "src/helpers.h" #include "src/helpers.h"
#include <storm/storage/expressions/ExpressionManager.h> #include <storm/storage/expressions/ExpressionManager.h>
#include <storm/storage/jani/Model.h> #include <storm/storage/jani/Model.h>
#include <storm/storage/jani/Property.h> #include <storm/storage/jani/Property.h>
#include <storm/generator/NextStateGenerator.h>
#include <storm/generator/Choice.h>
#include <storm/generator/PrismNextStateGenerator.h>
#include "storm/exceptions/NotSupportedException.h"
#include <storm/storage/expressions/SimpleValuation.h>
#include "storm/exceptions/InvalidTypeException.h"
#include "storm/exceptions/InvalidStateException.h"
#include "storm/exceptions/InvalidAccessException.h"
using namespace storm::prism; using namespace storm::prism;
template <typename StateType, typename ValueType>
void define_stateGeneration(py::module& m);
void define_prism(py::module& m) { void define_prism(py::module& m) {
py::class_<storm::prism::Program, std::shared_ptr<storm::prism::Program>> program(m, "PrismProgram", "A Prism Program"); py::class_<storm::prism::Program, std::shared_ptr<storm::prism::Program>> program(m, "PrismProgram", "A Prism Program");
program.def_property_readonly("constants", &Program::getConstants, "Get Program Constants") program.def_property_readonly("constants", &Program::getConstants, "Get Program Constants")
@ -38,22 +51,28 @@ void define_prism(py::module& m) {
.def_property_readonly("name", &Module::getName, "Name of the module") .def_property_readonly("name", &Module::getName, "Name of the module")
.def_property_readonly("integer_variables", &Module::getIntegerVariables, "All integer Variables of this module") .def_property_readonly("integer_variables", &Module::getIntegerVariables, "All integer Variables of this module")
.def_property_readonly("boolean_variables", &Module::getBooleanVariables, "All boolean Variables of this module") .def_property_readonly("boolean_variables", &Module::getBooleanVariables, "All boolean Variables of this module")
.def("__str__", &streamToString<Module>)
; ;
py::class_<Command> command(m, "PrismCommand", "A command in a Prism program"); py::class_<Command> command(m, "PrismCommand", "A command in a Prism program");
command.def_property_readonly("global_index", &Command::getGlobalIndex, "Get global index") command.def_property_readonly("global_index", &Command::getGlobalIndex, "Get global index")
.def_property_readonly("guard_expression", &Command::getGuardExpression, "Get guard expression") .def_property_readonly("guard_expression", &Command::getGuardExpression, "Get guard expression")
.def_property_readonly("updates", &Command::getUpdates, "Updates in the command");
.def_property_readonly("updates", &Command::getUpdates, "Updates in the command")
.def("__str__", &streamToString<Command>)
;
py::class_<Update> update(m, "PrismUpdate", "An update in a Prism command"); py::class_<Update> update(m, "PrismUpdate", "An update in a Prism command");
update.def_property_readonly("assignments", &Update::getAssignments, "Assignments in the update") update.def_property_readonly("assignments", &Update::getAssignments, "Assignments in the update")
.def_property_readonly("probability_expression", &Update::getLikelihoodExpression, "The probability expression for this update") .def_property_readonly("probability_expression", &Update::getLikelihoodExpression, "The probability expression for this update")
;//Added by Kevin
.def("__str__", &streamToString<Update>)
;
py::class_<Assignment> assignment(m, "PrismAssignment", "An assignment in prism"); py::class_<Assignment> assignment(m, "PrismAssignment", "An assignment in prism");
assignment.def_property_readonly("variable", &Assignment::getVariable, "Variable that is updated") assignment.def_property_readonly("variable", &Assignment::getVariable, "Variable that is updated")
.def_property_readonly("expression", &Assignment::getExpression, "Expression for the update");
.def_property_readonly("expression", &Assignment::getExpression, "Expression for the update")
.def("__str__", &streamToString<Assignment>)
;
// PrismType // PrismType
@ -85,9 +104,183 @@ void define_prism(py::module& m) {
py::class_<IntegerVariable, std::shared_ptr<IntegerVariable>> integer_variable(m, "Prism_Integer_Variable", variable, "A program integer variable in a Prism program"); py::class_<IntegerVariable, std::shared_ptr<IntegerVariable>> integer_variable(m, "Prism_Integer_Variable", variable, "A program integer variable in a Prism program");
integer_variable.def_property_readonly("lower_bound_expression", &IntegerVariable::getLowerBoundExpression, "The the lower bound expression of this integer variable") integer_variable.def_property_readonly("lower_bound_expression", &IntegerVariable::getLowerBoundExpression, "The the lower bound expression of this integer variable")
.def_property_readonly("upper_bound_expression", &IntegerVariable::getUpperBoundExpression, "The the upper bound expression of this integer variable") .def_property_readonly("upper_bound_expression", &IntegerVariable::getUpperBoundExpression, "The the upper bound expression of this integer variable")
.def("__str__", &streamToString<IntegerVariable>)
; ;
py::class_<BooleanVariable, std::shared_ptr<BooleanVariable>> boolean_variable(m, "Prism_Boolean_Variable", variable, "A program boolean variable in a Prism program"); py::class_<BooleanVariable, std::shared_ptr<BooleanVariable>> boolean_variable(m, "Prism_Boolean_Variable", variable, "A program boolean variable in a Prism program");
boolean_variable.def("__str__", &streamToString<BooleanVariable>);
define_stateGeneration<uint32_t, storm::RationalNumber>(m);
}
class ValuationMapping {
public:
std::map<storm::expressions::Variable, bool> booleanValues;
std::map<storm::expressions::Variable, int_fast64_t> integerValues;
std::map<storm::expressions::Variable, double> rationalValues;
ValuationMapping(storm::prism::Program const& program,
storm::expressions::SimpleValuation valuation) {
auto const& variables = program.getManager().getVariables();
for (auto const& variable : variables) {
if (variable.hasBooleanType()) {
booleanValues[variable] = valuation.getBooleanValue(variable);
} else if (variable.hasIntegerType()) {
integerValues[variable] = valuation.getIntegerValue(variable);
} else if (variable.hasRationalType()) {
rationalValues[variable] = valuation.getRationalValue(variable);
} else {
STORM_LOG_THROW(false, storm::exceptions::InvalidTypeException,
"Unexpected variable type.");
}
}
}
std::string toString() const {
std::vector<std::string> strs;
for (auto const& value : booleanValues) {
std::stringstream sstr;
sstr << value.first.getName() + "=";
sstr << value.second;
strs.push_back(sstr.str());
}
for (auto const& value : integerValues) {
std::stringstream sstr;
sstr << value.first.getName() + "=";
sstr << value.second;
strs.push_back(sstr.str());
}
for (auto const& value : rationalValues) {
std::stringstream sstr;
sstr << value.first.getName() + "=";
sstr << value.second;
strs.push_back(sstr.str());
}
return "[" + boost::join(strs, ",") + "]";
}
};
template <typename StateType, typename ValueType>
class StateGenerator {
typedef std::unordered_map<StateType, storm::generator::CompressedState> IdToStateMap;
typedef std::vector<std::pair<StateType, ValueType>> distribution_type;
typedef distribution_type choice_type; // currently we just collapse choices into distributions
typedef std::vector<choice_type> choice_list_type;
storm::prism::Program const& program;
storm::generator::PrismNextStateGenerator<ValueType, StateType> generator;
std::function<StateType (storm::generator::CompressedState const&)> stateToIdCallback;
// this needs to be below the generator attribute,
// because its initialization depends on the generator being initialized.
// #justcppthings
storm::storage::sparse::StateStorage<StateType> stateStorage;
bool hasComputedInitialStates = false;
IdToStateMap stateMap;
boost::optional<StateType> currentStateIndex;
public:
StateGenerator(storm::prism::Program const& program_) : program(program_), generator(program_), stateStorage(generator.getStateSize()) {
stateToIdCallback = [this] (storm::generator::CompressedState const& state) -> StateType {
StateType newIndex = stateStorage.getNumberOfStates();
std::pair<StateType, std::size_t> indexBucketPair = stateStorage.stateToId.findOrAddAndGetBucket(state, newIndex);
StateType index = indexBucketPair.first;
stateMap[index] = state;
return index;
};
}
StateType loadInitialState() {
if (!hasComputedInitialStates) {
stateStorage.initialStateIndices = generator.getInitialStates(stateToIdCallback);
hasComputedInitialStates = true;
}
STORM_LOG_THROW(stateStorage.initialStateIndices.size() == 1, storm::exceptions::NotSupportedException, "Currently only models with one initial state are supported.");
StateType initialStateIndex = stateStorage.initialStateIndices.front();
load(initialStateIndex);
return initialStateIndex;
}
void load(StateType stateIndex) {
if (currentStateIndex && *currentStateIndex == stateIndex) {
return;
}
auto search = stateMap.find(stateIndex);
if (search == stateMap.end()) {
STORM_LOG_THROW(false, storm::exceptions::InvalidAccessException,
"state id not found");
}
generator.load(search->second);
currentStateIndex = stateIndex;
}
ValuationMapping currentStateToValuation() {
if (!currentStateIndex) {
STORM_LOG_THROW(false, storm::exceptions::InvalidStateException,
"Initial state not initialized");
}
auto valuation = generator.toValuation(stateMap[*currentStateIndex]);
return ValuationMapping(program, valuation);
}
bool satisfies(storm::expressions::Expression const& expression) {
return generator.satisfies(expression);
}
choice_list_type expand() {
if (!hasComputedInitialStates) {
STORM_LOG_THROW(false, storm::exceptions::InvalidStateException,
"Initial state not initialized");
}
choice_list_type choices_result;
auto behavior = generator.expand(stateToIdCallback);
for (auto choice : behavior.getChoices()) {
choices_result.push_back(choice_type(choice.begin(), choice.end()));
}
return choices_result;
}
};
template <typename StateType, typename ValueType>
void define_stateGeneration(py::module& m) {
py::class_<ValuationMapping, std::shared_ptr<ValuationMapping>> valuation_mapping(m, "ValuationMapping", "A valuation mapping for a state consists of a mapping from variable to value for each of the three types.");
valuation_mapping
.def_readonly("boolean_values", &ValuationMapping::booleanValues)
.def_readonly("integer_values", &ValuationMapping::integerValues)
.def_readonly("rational_values", &ValuationMapping::rationalValues)
.def("__str__", &ValuationMapping::toString);
py::class_<StateGenerator<StateType, ValueType>, std::shared_ptr<StateGenerator<StateType, ValueType>>> state_generator(m, "StateGenerator", R"doc(
Interactively explore states using Storm's PrismNextStateGenerator.
:ivar program: A PRISM program.
)doc");
state_generator
.def(py::init<storm::prism::Program const&>())
.def("load_initial_state", &StateGenerator<StateType, ValueType>::loadInitialState, R"doc(
Loads the (unique) initial state.
Multiple initial states are not supported.
:rtype: the ID of the initial state.
)doc")
.def("load", &StateGenerator<StateType, ValueType>::load, R"doc(
:param state_id: The ID of the state to load.
)doc")
.def("current_state_to_valuation", &StateGenerator<StateType, ValueType>::currentStateToValuation, R"doc(
Return a valuation for the currently loaded state.
:rtype: stormpy.ValuationMapping
)doc")
.def("current_state_satisfies", &StateGenerator<StateType, ValueType>::satisfies, R"doc(
Check if the currently loaded state satisfies the given expression.
:param stormpy.Expression expression: The expression to check against.
:rtype: bool
)doc")
.def("expand", &StateGenerator<StateType, ValueType>::expand, R"doc(
Expand the currently loaded state and return its successors.
}
:rtype: [[(state_id, probability)]]
)doc");
}

3
src/storage/valuation.cpp

@ -18,4 +18,5 @@ void define_simplevaluation(py::module& m) {
py::class_<storm::expressions::SimpleValuation, std::shared_ptr<storm::expressions::SimpleValuation>> simplevaluation(m, "SimpleValuation", "Valuations for storm variables"); py::class_<storm::expressions::SimpleValuation, std::shared_ptr<storm::expressions::SimpleValuation>> simplevaluation(m, "SimpleValuation", "Valuations for storm variables");
simplevaluation.def("get_boolean_value", &storm::expressions::SimpleValuation::getBooleanValue); simplevaluation.def("get_boolean_value", &storm::expressions::SimpleValuation::getBooleanValue);
simplevaluation.def("get_integer_value", &storm::expressions::SimpleValuation::getIntegerValue); simplevaluation.def("get_integer_value", &storm::expressions::SimpleValuation::getIntegerValue);
}
simplevaluation.def("__str__", &storm::expressions::SimpleValuation::toString);
}

75
tests/storage/test_state_generation.py

@ -0,0 +1,75 @@
import stormpy
from stormpy.examples.files import prism_dtmc_die
class _DfsQueue:
def __init__(self):
self.queue = []
self.visited = set()
def push(self, state_id):
if state_id not in self.visited:
self.queue.append(state_id)
self.visited.add(state_id)
def pop(self):
if len(self.queue) > 0:
return self.queue.pop()
return None
def _dfs_explore(program, callback):
generator = stormpy.StateGenerator(program)
queue = _DfsQueue()
current_state_id = generator.load_initial_state()
queue.visited.add(current_state_id)
while True:
callback(current_state_id, generator)
successors = generator.expand()
assert len(successors) <= 1
for choice in successors:
for state_id, _prob in choice:
queue.push(state_id)
current_state_id = queue.pop()
if current_state_id is None:
break
generator.load(current_state_id)
def _load_program(filename):
program = stormpy.parse_prism_program(filename) # pylint: disable=no-member
program = program.substitute_constants()
expression_parser = stormpy.ExpressionParser(program.expression_manager)
expression_parser.set_identifier_mapping(
{var.name: var.get_expression()
for var in program.variables})
return program, expression_parser
def _find_variable(program, name):
for var in program.variables:
if var.name is name:
return var
return None
def test_knuth_yao_die():
program, expression_parser = _load_program(prism_dtmc_die)
s_variable = _find_variable(program, "s")
upper_bound_invariant = expression_parser.parse("d <= 6")
number_states = 0
def callback(_state_id, generator):
nonlocal number_states
number_states += 1
valuation = generator.current_state_to_valuation()
print(valuation)
assert valuation.integer_values[s_variable] <= 7
assert generator.current_state_satisfies(upper_bound_invariant)
_dfs_explore(program, callback)
assert number_states == 13
Loading…
Cancel
Save